diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b5f4dfb8..6610f21f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -96,7 +96,9 @@ jobs: UV_TORCH_BACKEND: cpu run: uv sync --group dev - name: Build distributions - run: uv run python -m build --outdir dist . + run: | + uv run python scripts/validate_fabric_generated_catalogs.py + uv run python -m build --outdir dist . - name: Check distributions run: uv run twine check dist/* - name: Smoke test wheel diff --git a/.gitignore b/.gitignore index a0617b89..52e29bce 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,5 @@ dist/ build/ .ruff_cache/ .pyright/ +tmp/ ai_docs/runs/ diff --git a/AGENTS.md b/AGENTS.md index 3d8d0cfc..696a0e69 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -6,6 +6,9 @@ This file is only the top-level router for AI agents. Detailed workflows live in - If session-start hooks report warnings or action items, relay them to the user before doing work. - Before starting non-trivial work, inspect `skills/` for a relevant skill by name or description. +- For Cortical Fabric work, start with `skills/cb.fabric-workflow-router/SKILL.md` unless the user names a narrower + Fabric skill directly; include `skills/cb.fabric-compiler-boundary-audit/SKILL.md` for compiler/backend/performance or + semantic changes. - When a relevant skill exists, read its `SKILL.md` and follow that workflow. - Prefer the narrowest applicable skill. Use multiple skills only when the task clearly spans multiple workflows. - Also check for a more specific `AGENTS.md` in the subtree you are editing. diff --git a/ai_docs/FABRIC_THROUGHPUT_CLOSURE.md b/ai_docs/FABRIC_THROUGHPUT_CLOSURE.md new file mode 100644 index 00000000..bbd726e2 --- /dev/null +++ b/ai_docs/FABRIC_THROUGHPUT_CLOSURE.md @@ -0,0 +1,13509 @@ +# Fabric Throughput Closure + +Created: 2026-05-03. + +This is the active progress document for Fabric throughput work after compiler +closure. It starts from the compiler-closure checkpoint in +`ai_docs/REDO2_FIXMASS.md` and deliberately does not begin optimization by +itself. The first throughput task is measurement and owner attribution. + +## Starting Baseline + +- Compiler closure baseline: Fabric-focused registered temporal compiler sweep + green with `504 passed`, `9 warnings`, `342.77s`. +- No known compiler-closure blocker remains in the tracked registered temporal + compiler sweep. +- At document creation, dirty/untracked working-tree paths are documentation + only: `ai_docs/REDO2_FIXMASS.md`, `ai_docs/REDO_FIXMAASS.md`, + `ai_docs/AWS_RECOVERY_TRAIL.md`, `ai_docs/additonal_goals.md`, and + `ai_docs/prompt.tx`. +- No throughput benchmark, profiler, CUDA tuning, or owner-specific + optimization has been run for this throughput phase yet. +- April 21 comparison data source: + `audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json`. + +April 21 is the target matrix and historical scorecard. It is not code to copy. +Any recovered implementation idea must be re-expressed as a registered Fabric +compiler strategy over primitive rows, tensor bindings, memory/liveness rows, +executor rows, and reducer rows. + +## Non-Negotiable Rules + +- Throughput evidence must come from the declared Fabric program path: + `output = model(x, ...)`, external loss, `loss.backward()`, and optional + optimizer step. 
+- Benchmark code may report metadata, but it must not choose tiling, temporal + chunking, checkpointing, detach policy, private runtime helpers, direct CUDA + wrappers, or streaming-loss loops to make a row pass. +- A metadata label is not evidence. The active runtime owner must physically + move or shrink in the warmed profile. +- Every optimization must be a registered strategy over existing compiler + products. If the primitive rows, tensor roles, legality, memory/liveness + contract, or reducer rows do not exist, stop and add compiler semantics first. +- Do not add formulas to temporal scheduler files, fixed slot enums, + cell-family route selectors, benchmark branches, or compatibility wrappers. +- Treat CUDA-native and Triton as peer implementation strategies, not as Fabric + semantics and not as hidden special cases. The registered temporal compiler + may currently live under `backend/cuda`, but generic rows, liveness policy, + artifact policy, and scheduling policy must not be written as if CUDA-native + is the only possible implementation. Triton support must enter through the + same primitive rows, tensor bindings, legality checks, memory/liveness rows, + and explicit strategy records as CUDA-native support. +- Any new strategy-facing metadata must keep backend choice explicit: + implementation backend, runtime entrypoint, supported device, dtype/layout + contract, forward/backward support, artifact contract, workspace policy, and + fail-closed unsupported reason. Do not route by family name, benchmark row, + hidden size, graph shape, or ad hoc `if triton`/`if cuda` branches in the + temporal scheduler. +- Parity gates follow the touched owner. Speed does not count until outputs, + exposed state, input/carry gradients, and all nonzero parameter gradients are + green for the affected rows. +- Memory is a closure gate. Passing throughput with a material April 21 memory + regression remains open unless the regression has a named memory/liveness + owner. +- `audits/fabric/` is reserved for final full audit results and historical + baselines. Partial probes, owner-table experiments, rejected liveness routes, + and intermediate throughput measurements must write under the gitignored + `tmp/fabric_audits/partials/` tree. Promote a result into `audits/fabric/` + only when it is a final full result set for the documented closure gate. + +### 2026-05-03 - Active Semantics Change: Default Context-Nudge + Axon Norm + +Status: semantic/default change, not a throughput optimization. + +- The default Fabric `DotProduct` message rule is now fixed-slot + context-nudge, not dynamic key/value. Throughput work must profile the active + default unless a run is explicitly labeled as a legacy dynamic control. +- Axon now follows the same public norm epilogue shape as sLSTM: output + projection emits `public_y_raw`, then + `norm_or_identity(outnorm_weight, outnorm_eps)` emits `public_y`. +- For default context-nudge, recurrent messages are already projected-message + tensors. Axon input projection must consume the fused + `message_to_cell_weight -> input_proj_weight` path and its reducer must carry + `selected_static_source=message_to_cell_weight`; using old + `value_to_cell_weight -> input_proj_weight` semantics is a legacy dynamic + control, not the active default. +- Treat `outnorm_weight`, `outnorm_eps`, and `grad_outnorm_weight` as required + Axon parity and throughput sinks. Missing these bindings means the Axon row is + not testing the current compiler path. 
+- CUDA-native and Triton remain peer backend strategies selected by + rows/bindings. Do not encode context-nudge, dynamic dot-product, CUDA-native, + Triton, sLSTM, or Axon as hidden temporal-engine route selectors. + +## Fresh Agent Semantic Locator + +This section is the orientation map for a new agent entering the throughput +thread. If it disagrees with a newer dated section below, trust the newest dated +section and update this locator before proceeding. + +Read in this order: + +1. `## Non-Negotiable Rules`. +2. `### 2026-05-03 - Active Semantics Change: Default Context-Nudge + Axon Norm`. +3. This locator. +4. The newest dated section at the end of this file. +5. The April 21 baseline file: + `audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json`. + +Current semantic state: + +- Active default message math is fixed-slot context-nudge. Dynamic key/value is + only a labeled legacy/control run. +- Axon includes the same public norm epilogue path as sLSTM. Axon throughput and + parity must include `outnorm_weight`, `outnorm_eps`, and + `grad_outnorm_weight`. +- CUDA-native and Triton are peer strategy backends. They must be selected + through rows, bindings, legality, liveness, and explicit strategy metadata, + not hidden scheduler branches. + +Current performance location: + +- April 21 target floor for `h32_t1_bxparams` remains `58732.71 tok/s`, + `2.07 GiB`. +- Accepted steering baseline before the latest row-group liveness fix: + `tmp/fabric_audits/partials/2026-05-03/t1_direct_keyless_readout_forward_h32_100m_b1024`. +- Latest accepted forward memory move: + `tmp/fabric_audits/partials/2026-05-03/t1_slstm_transition_rowgroup_liveness_fix_forward_h32_100m_b1024`. + It moved sLSTM forward peak from `9.389 GiB` to `7.139 GiB` with speed about + flat, and kept Axon about flat at `13.750 GiB`. +- Rejected shortcut examples to avoid reopening blindly: + `t1_direct_message_forward_h32_100m_b1024`, + `t1_message_projection_out_forward_h32_100m_b1024`, + `t1_recurrent_matmul_bmm_forward_guard_h32_100m_b1024`, and + `t1_readout_bmm_forward_guard_h32_100m_b1024`. + +Current owner map: + +- sLSTM forward: transition row-group liveness moved, but the remaining peak is + native temporary/allocator high-water inside the row-group execution. Next + useful step is finer native-stage telemetry or a row-group workspace move + with physical owner movement. +- Axon forward: next owner remains recurrent K/V-before plus recurrent-message + producer-consumer materialization. Direct per-edge projection is rejected + because it loses grouped dense work and regresses throughput. +- Training: keep as parity/liveness guardrail until forward owners move again. + Do not let training optimization distract from the current forward owner. + +Code pointers: + +- Audit runner and April 21 reference loading: + `benchmarks/fabric/audit.py`, `benchmarks/fabric/run_audit.py`. +- Compiler rows, access, liveness, and registered executor bindings: + `src/cortical/fabric/backend/cuda/sequence_surface/compiler/`. +- Registered CUDA-native forward/backward strategy bodies: + `src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/`. +- Native callable strategy registration: + `src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py` + and + `src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/native_callables/`. 
+- Boundary skills to read before edits: + `skills/cb.fabric-performance-loop/SKILL.md`, + `skills/cb.fabric-throughput-strategy/SKILL.md`, + `skills/cb.fabric-compiler-boundary-audit/SKILL.md`, + `skills/cb.fabric-native-strategy-onboarding/SKILL.md`, and + `skills/cb.fabric-reducer-liveness/SKILL.md` when reducer/artifact liveness + is touched. + +Before editing: + +- Run `git status --short`; do not overwrite another agent's dirty work. +- Write the hypothesis, smallest representative probe, artifact path, keep rule, + and follow-up representative row in this file before a long run. +- A semantic-transfer idea from April 21 counts only when it is expressed as a + current registered compiler strategy and the physical owner moves in timing, + launch shape, allocator/high-water telemetry, storage identity, or named + memory stage. + +## First Closure Target: T=1 + +T=1 is the base physical Fabric execution unit. K, T>1, horizon-H, and +per-timestep loss must not be treated as closed while matched T=1 training is +below April 21 on the same graph, batch, params, hidden, population, loss +boundary, reset, and state/materialization contract. + +### First Task: Current-Code Owner Table + +Before changing performance code, build a current-code T=1 owner table. Each row +must record: + +- April 21 reference key and exact mapping status; +- family and population shape: sLSTM, Axon, single-pop, mixed-pop; +- parameter target and actual parameter count; +- hidden size, batch size, graph shape, reset policy, output boundary, and final + state/materialization contract; +- mode: forward or training/backward; +- current tok/s and peak GiB; +- April 21 tok/s and peak GiB when an exact match exists; +- compiler/runtime metadata: forward owner, reverse owner, memory/artifact + owner, reducer owner, and dominant host/kernel owner; +- parity status and any missing parity axis; +- whether the run was cold, warmed, interrupted, stale, or concurrent. + +No optimization starts until this table identifies the dominant live owner for +the headline T=1 miss. + +### T=1 Rows To Map First + +The April 21 baseline file is the required T=1 closure target, not just a +source of boundary examples. The file records summary coverage rather than raw +per-row results for every member row, so this phase must reconstruct and run the +full covered matrices from the April 21 coverage strings before claiming +closure. Boundary rows such as Axon 100M train B=1024 are steering rows only; +they do not stand in for the complete matrix. + +- `h32_t1_bxparams`: sLSTM + Axon, 100M/500M/1B, forward + training, + B=1024/16384. Boundary reference: Axon 100M train B=1024, + `58732.71 tok/s`, `2.07 GiB`. +- `h32_small_params_high_batch`: sLSTM + Axon, 1M/10M, forward + training, + B=16384/65536/131072. Boundary reference: Axon 1M train B=131072, + `1986978.16 tok/s`, `12.97 GiB`. +- `h4_many_cell_stress`: sLSTM + Axon, 100M/500M/1B, forward + training, + B=1024/16384. Boundary reference: Axon 500M train B=1024, + `12513.0 tok/s`, `18.09 GiB`. +- `h8_many_cell_stress_focused_warmed_rerun`: Axon 1B train B=1024, + `16723.8 tok/s`, `46.71 GiB`. +- `h16_many_cell_stress`: sLSTM + Axon, 100M/500M/1B, forward + training, + B=1024/16384. Boundary reference: Axon 500M train B=1024, + `35558.75 tok/s`, `10.63 GiB`. +- `flat_graph_factorization_invariance`: equivalent flat graphs with only + user-side factorization labels varied. Run after the main T=1 owner has moved. 
+ April 21 also has B=16 factorization evidence and B=128 rollout evidence, but + those are not substitutes for the T=1 throughput matrices above. Any added + small-B T=1 latency rows are supplemental guardrails unless they are mapped to + an April 21 T=1 throughput coverage row. +- `h32_small_batch_latency_guardrail`: supplemental current-code T=1 rows for + sLSTM + Axon, 100M, forward + training, B=1/64/512, h=32, single-pop, + reset-absent, terminal output boundary. These rows catch launch overhead, + front-end/runtime setup cost, fixed allocation overhead, and small-batch path + divergence. They have no April 21 throughput target and cannot replace the + required April 21 matrices. + +### T=1 Closure Gate + +T=1 closes only when: + +- every April 21 T=1 throughput coverage row has been rerun on the current + compiler path and matched or exceeded on throughput and memory: `h32_t1_bxparams` + (`24/24` rows), `h32_small_params_high_batch` (`24/24` rows), + `h4_many_cell_stress` (`24/24` rows), the accepted h8 stress confirmation + (`h8_many_cell_stress_focused_warmed_rerun`, `2/2` rows, plus any broad h8 + rows used as non-regression evidence), and `h16_many_cell_stress` (`24/24` + rows); +- no subset, average, boundary row, focused Axon row, four-row guardrail, or + small-B smoke row may close T=1 while any April 21 T=1 member row is missing, + slower, materially higher-memory, stale, interrupted, or only explained by an + unaccepted open owner; +- supplemental small-B T=1 guardrails B=1/64/512 pass for sLSTM and Axon + forward + training without Python replay/fallback, benchmark-owned policy, + row-specific routing, or unstable owner metadata; +- representative h=32 rows pass for single-pop sLSTM and Axon; +- mixed-pop T=1 passes the same conceptual surfaces, because mixed-pop is bucket + cardinality under the same flat-bucket engine; +- small-hidden h=4/8/16 stress rows are green or have accepted open owners; +- small-param high-batch rows show launch/batch scaling is not broken; +- reset-present rows pass once state, artifacts, tape, or backward owners are + touched; +- current compiler path meets or exceeds comparable April 21 throughput and + memory, or the miss is explicitly left open with a named owner; +- metadata reports registered compiler-owned forward, backward, memory/artifact, + and reducer owners, with no Python replay/fallback or benchmark-owned route. + +## After T=1 + +Only after T=1 is healthy: + +1. close K>1 against matched current-code T=1 divided by K; +2. close T>1 streaming against matched T=1 per-token throughput; +3. close horizon-H/TBPTT with bounded internal memory; +4. close ordinary per-timestep loss through `output = model(x)`, external loss, + and `loss.backward()` with no benchmark-owned streaming-loss helper. + +## Evidence Log + +### 2026-05-03 - Partial Audit Artifact Location Policy + +Status: cleanup only. No throughput optimization was run. + +- Moved the scratch `2026-05-03` partial throughput audit tree from + `audits/fabric/2026-05-03` to + `tmp/fabric_audits/partials/2026-05-03`. +- Updated this document's partial artifact references and command `--out-dir` + examples to use the gitignored temp tree. +- Added `tmp/` to `.gitignore`. +- Policy going forward: `audits/fabric/` stores final full results and + historical baselines only. Partial probes stay in `tmp/fabric_audits/partials/` + until promoted deliberately as a full closure result. + +### 2026-05-03 - Document Created + +- Created this standalone throughput progress doc. 
+- No throughput benchmark, profiler, or optimization was run. +- Next action: analyze current benchmark/audit entrypoints and build the first + current-code T=1 owner table against April 21. + +### 2026-05-03 - Initial T=1 Owner Table + +Status: analysis only. No optimization or code change was made. + +Commands: + +- `CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_throughput_t1_owner_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_throughput_t1_owner_20260503 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir tmp/fabric_audits/partials/2026-05-03/t1_owner_table_h32_100m_b1024_initial --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families slstm,axoncell --sizes 100m --modes forward,forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner` +- `CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_throughput_t1_owner_axon_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_throughput_t1_owner_axon_20260503 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir tmp/fabric_audits/partials/2026-05-03/t1_owner_table_h32_100m_b1024_axon_train_isolated --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families axoncell --sizes 100m --modes forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent --warmup 1 --iterations 2` + +Artifacts: + +- `tmp/fabric_audits/partials/2026-05-03/t1_owner_table_h32_100m_b1024_initial/summary.json` +- `tmp/fabric_audits/partials/2026-05-03/t1_owner_table_h32_100m_b1024_initial/cases.jsonl` +- `tmp/fabric_audits/partials/2026-05-03/t1_owner_table_h32_100m_b1024_axon_train_isolated/summary.json` +- `tmp/fabric_audits/partials/2026-05-03/t1_owner_table_h32_100m_b1024_axon_train_isolated/cases.jsonl` + +Reference: `h32_t1_bxparams`, April 21 summary floor +`58732.71 tok/s`, `2.07 GiB`. This reference is a summary floor for the +April 21 h32 T=1 matrix; exact per-row April 21 raw rows are not present in +the JSON, so the table below is a current-code owner table against the summary +floor. 
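The `vs Apr21` columns in the owner tables are derived directly from that summary floor: throughput is reported as a share of the April 21 tok/s floor and memory as a multiple of the April 21 peak-GiB floor. A minimal sketch of the arithmetic, using the sLSTM forward row as the worked example (the helper is illustrative, not audit-runner code):

```python
# Sketch only: how the "vs Apr21" comparison columns are computed against the
# h32_t1_bxparams summary floor quoted above. Not part of the audit runner.
APR21_FLOOR_TOKS = 58732.71  # tok/s summary floor
APR21_FLOOR_GIB = 2.07       # peak GiB summary floor


def vs_apr21(current_toks: float, current_peak_gib: float) -> tuple[str, str]:
    toks_share = 100.0 * current_toks / APR21_FLOOR_TOKS
    mem_multiple = current_peak_gib / APR21_FLOOR_GIB
    return f"{toks_share:.2f}%", f"{mem_multiple:.1f}x"


# sLSTM 100M h32 B1024 T1 forward row from the table below.
print(vs_apr21(5690.59, 20.52))  # ('9.69%', '9.9x')
```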
+ +| row | current status | current tok/s | vs Apr21 tok/s | current peak GiB | vs Apr21 memory | owner metadata | +| --- | --- | ---: | ---: | ---: | ---: | --- | +| sLSTM 100M h32 B1024 T1 forward | ok | 5690.59 | 9.69% | 20.52 | 9.9x | forward `registered_fused_forward_program_cuda`; runtime `registered_temporal_fused_forward_program_cuda`; shape `[8,768]`, active receivers `4608`, actual params `100619488` | +| sLSTM 100M h32 B1024 T1 training | ok | 973.50 | 1.66% | 56.79 | 27.4x | forward `registered_fused_forward_program_cuda`; backward `registered_reverse_executor_bindings`; runtime `registered_temporal_fused_forward_program_cuda`; backward executors `physical_tiny_message_backward_executor`, `physical_receiver_affine_backward_executor`, `physical_state_epilogue_backward_executor`, `registered_sender_kv_projection_backward_executor`, `projection_reduction_boundary_backward`, `physical_temporal_bucket_sequence_backward`, `cuda_temporal_backward_glue`; full transition tape; `store_step_artifacts` | +| Axon 100M h32 B1024 T1 forward | ok | 3951.09 | 6.73% | 93.23 | 45.0x | forward `registered_fused_forward_program_cuda`; runtime `registered_temporal_fused_forward_program_cuda`; shape `[16,1024]`, active receivers `14336`, actual params `99924192` | +| Axon 100M h32 B1024 T1 training | OOM | - | - | >139 GiB process use | >67x before failure | isolated rerun also OOMed; forward `registered_fused_forward_program_cuda`; backward `registered_reverse_executor_bindings`; backward executors `physical_tiny_message_backward_executor`, `physical_diagonal_recurrence_backward_executor`, `registered_sender_kv_projection_backward_executor`, `projection_reduction_boundary_backward`, `physical_temporal_bucket_sequence_backward`, `cuda_temporal_backward_glue`; full transition tape; `store_step_artifacts` | + +Owner assessment: + +- The current path is compiler-owned at the registered program level: primitive + executor blockers are absent, forward owner is + `registered_fused_forward_program_cuda`, and runtime entrypoint is + `registered_temporal_fused_forward_program_cuda`. +- At the time of this first owner table, the audit option + `--require-cuda-temporal-owner` was stale for the new closure naming and + still expected `cuda_temporal_superop`. That audit gate was fixed in the + later registered memory/liveness slice below; this first table should be read + as historical pre-fix evidence. +- The first live throughput blocker is not missing compiler ownership. It is + the registered program's current execution/memory behavior: + T=1 forward is already far below April 21 and far above the April 21 memory + floor; T=1 training adds the registered reverse executor chain, full + transition tape, and `store_step_artifacts`; Axon training does not fit on an + H200 at this row. +- `CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1` did not produce populated + `backward_owner_timing_ms` for the registered path in this audit. The owner + table therefore uses planner/runtime owners and launch-count metadata. Before + changing throughput code, the next analysis step should get a warmed kernel + or registered-stage profile for the sLSTM training row and the Axon forward + memory row. + +Next owner hypothesis: + +1. **Memory/liveness/artifact owner is first-class.** Axon forward uses + `93.23 GiB` before backward and Axon training OOMs; this points at + registered forward program materialization/layout/workspace and T=1 artifact + liveness, not only reverse math. +2. 
**Training compute owner is the registered reverse chain.** sLSTM training + is `973.50 tok/s`, with backward routed through message, transition, + recurrent K/V, readout, temporal glue, full tape, and stored artifacts. +3. **Instrumentation owner is open.** The registered path needs owner timing or + profiler attribution before code changes, otherwise the next optimization + would be guesswork. + +### 2026-05-03 - Plan To Close Highest-Impact T=1 Owner + +Status: plan only. Do not start implementation from this section without first +running the third prompt. + +Highest-impact owner: + +- Primary: registered-program memory/liveness/artifact execution. +- Evidence: Axon 100M T=1 forward already uses `93.23 GiB`, and Axon T=1 + training OOMs even in an isolated process. sLSTM T=1 training also uses + `56.79 GiB`, `27.4x` April 21 memory, and only `1.66%` of April 21 tok/s. +- Secondary: registered reverse execution chain compute, but only after the + memory/liveness owner is measured and narrowed. Starting with reverse math + before memory attribution risks optimizing the wrong surface. + +Compiler-boundary invariant: + +- Keep all work inside compiler-owned products: + `memory_liveness_rows`, `memory_runtime_schedule_rows`, + `forward_artifact_route_rows`, `forward_artifact_merge_rows`, + `reverse_artifact_consumer_route_rows`, executor rows, binding rows, and + registered forward/backward program entrypoints. +- Do not add family, benchmark, hidden-size, single-pop, or Axon/sLSTM branches. +- Do not move message, readout, gated, or diagonal formulas into temporal + scheduler code. Primitive math remains in registered primitive executors. +- Do not use benchmark-side tiling, checkpointing, detach policy, or direct + runtime helper calls as a fix. + +Implementation plan: + +1. **Fix attribution before changing strategy.** + - Update the audit owner gate naming so registered temporal program owners + are accepted as compiler-owned. Do not keep a legacy runtime alias in the + active gate; historical artifacts are historical evidence, not active + acceptance criteria. + - Add registered-program stage timing for the current path: + forward program, message span, transition span, readout span, layout/state + materialization, reverse readout/message, reverse transition, recurrent + K/V projection, parameter reducer, boundary projection, artifact store, and + runtime buffer allocation. + - Add a memory ledger to the audit result that separates model parameters, + static tensor tables, registered runtime buffers, forward artifacts, + transition tape, user-visible input/output/target tensors, and PyTorch + allocator reserve. + - Rerun only the measured owner rows: + sLSTM 100M T=1 training and Axon 100M T=1 forward/training. No optimization + is accepted until the dominant memory class is visible. + +2. **Make memory/liveness rows executable, not descriptive.** + - Audit `memory_liveness_rows` against actual allocations in + `registered_temporal_fused_forward_program_cuda` and + `registered_temporal_fused_backward_program_cuda`. + - Any runtime tensor that is allocated outside the compiler memory plan must + either be assigned to a row/lifetime/alias set or removed. + - Make runtime buffer allocation consume the compiler plan for workspace + reuse and aliasing. Alias only where the compiler says lifetimes do not + overlap. + - Add fail-closed validation: if a registered program needs a workspace or + artifact not declared in the memory plan, reject before launch. + +3. 
**Replace T=1 full artifact storage with route-minimal artifacts.** + - For `physical_time_steps == 1`, derive the required reverse artifact set + from `reverse_artifact_consumer_route_rows` and + `forward_artifact_route_rows`, not from blanket `store_step_artifacts`. + - Store only tensors consumed by the selected reverse spans. Avoid full-cell + and full-bank materialization when only recurrent/public slices are needed. + - Preserve correctness by keeping every artifact role explicit: + boundary step, recurrent K/V before/after, recurrent hidden before/after, + recurrent/output messages, output cells, transition state-before, and tape + inputs are included only when a reverse consumer row requires them. + - Add a negative test where a reverse consumer requests an undeclared artifact + and launch fails closed. + +4. **Make transition tape policy row-owned for T=1.** + - Keep the current full-tape path as the legal baseline while measuring. + - Add a compiler-selected T=1 tape policy only when the primitive executor + declares that logits/preprojection can be recomputed or elided from + existing row-owned inputs. + - Select compact or full tape through memory/liveness and primitive executor + legality, not through family names or benchmark row labels. + - Gate on parity for sLSTM and Axon training before using compact tape in + performance rows. + +5. **Then attack registered reverse compute.** + - Once Axon forward memory is below the current OOM cliff and sLSTM training + memory is materially lower, rerun warmed stage timing. + - If reverse remains dominant, fuse the registered reverse window stages over + the existing reverse executor rows and binding rows: + readout/message, transition, recurrent K/V, boundary, and reducer. + - The output must still be `registered_temporal_fused_backward_program_cuda` + consuming compiler tables; do not reintroduce direct reverse wrappers or + temporal-side primitive formulas. + +6. **Validation gates for each slice.** + - Source/static: registered program still requires compiler tables and fails + closed on missing memory/artifact rows. + - Parity: targeted T=1 sLSTM and Axon forward/training, reset absent first; + reset-present before touching reset/tape/state ownership further. + - Performance: rerun the same owner-table rows with private cache dirs and + warmed iterations. Record artifacts in this doc. + - Memory: current peak GiB must move down on Axon forward before reverse-only + compute work counts as progress. + +Expected first implementation outcome: + +- The next code slice should not try to hit April 21 throughput immediately. + It should produce a trustworthy registered-program timing/memory ledger and + reduce the T=1 memory cliff enough that Axon 100M training no longer OOMs. + Only after that should reverse-kernel throughput be the top owner. + +### 2026-05-03 - Registered T=1 Memory/Liveness Slice 1 + +Status: partial progress. This slice stayed inside registered compiler-owned +strategies and did not add benchmark-side tiling, family routes, direct CUDA +wrappers, or temporal-scheduler formulas. + +Code changes: + +- The audit CUDA-owner gate now accepts the current registered compiler owners: + `registered_fused_forward_program_cuda`, + `registered_temporal_fused_forward_program_cuda`, + `registered_reverse_executor_bindings`, and + `physical_temporal_bucket_sequence_backward`. The stale expectation that the + active owner must be named `cuda_temporal_superop` was removed. 
+- Runtime memory buffer plans now emit byte ledgers: + `planned_buffer_bytes`, `estimated_allocated_buffer_bytes`, + `bytes_by_workspace`, and `bytes_by_runtime_role`. +- Runtime planner metadata now records compiler memory buffer/artifact/schedule + summaries in `workspace_aliases`. +- Forward transition output materialization now follows registered primitive + output-binding requirements. For T=1 forward paths with no final-state carry, + optional transition outputs are removed from the active + `forward_executor_binding_rows` and runtime buffer requests. +- Transition state-before artifacts stay complete for backward. The forward + program skips only intentionally omitted next-state copies by checking whether + the carry source output binding is active in `forward_executor_binding_rows`. + This keeps reverse state-before inputs available without forcing every + optional next-state output to be carried in inference. +- Reverse currently keeps optional transition outputs materialized during its + adjoint recompute. That is the correctness baseline until a reverse strategy + proves which optional outputs can be elided or recomputed. + +Validation commands: + +- `python -m py_compile benchmarks/fabric/audit.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py` +- `uv run pytest -q tests/test_fabric_audit_runner.py --tb=short` +- `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes tests/test_fabric_audit_runner.py --tb=short` +- `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_throughput_registered_path_20260503c TRITON_CACHE_DIR=/tmp/cortical_triton_throughput_registered_path_20260503c uv run pytest -q tests/test_fabric_runtime.py::test_fabric_supported_cuda_route_uses_registered_temporal_program --tb=short` +- `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_throughput_registered_axon_smoke_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_throughput_registered_axon_smoke_20260503 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_flat_bucket_route_matches_pytorch_reference --tb=short -k axoncell` + +Validation results: + +- Source/audit/compiler tests: `21 passed`. +- Registered CUDA route: `2 passed`. +- Axon registered parity smoke: `1 passed, 1 deselected`. 
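The new byte ledger is easiest to read by ranking roles. A minimal sketch of folding the per-case ledger out of an audit `cases.jsonl`; the four ledger keys are the ones named in the change list above, while the `runtime_memory` nesting and the `row` key are assumptions about where the runner records them:

```python
import json

# Sketch only: rank planned runtime-buffer roles for each audit case. The four
# ledger keys come from this slice; the "runtime_memory" nesting and the "row"
# key are assumptions, not confirmed audit-runner field names.
def top_buffer_roles(ledger: dict, limit: int = 5) -> list[tuple[str, int]]:
    by_role = ledger.get("bytes_by_runtime_role", {})
    return sorted(by_role.items(), key=lambda kv: kv[1], reverse=True)[:limit]


with open("cases.jsonl", encoding="utf-8") as fh:
    for line in fh:
        case = json.loads(line)
        ledger = case.get("runtime_memory", {})
        print(
            case.get("row"),
            ledger.get("planned_buffer_bytes"),
            ledger.get("estimated_allocated_buffer_bytes"),
            ledger.get("bytes_by_workspace"),
            top_buffer_roles(ledger),
        )
```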
+ +Performance/audit commands: + +- `CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_throughput_registered_axon_forward_metadata_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_throughput_registered_axon_forward_metadata_20260503 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir tmp/fabric_audits/partials/2026-05-03/t1_registered_optional_forward_policy_axon_h32_100m_b1024 --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families axoncell --sizes 100m --modes forward --batches 1024 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner` +- `CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_throughput_registered_axon_optional3_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_throughput_registered_axon_optional3_20260503 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir tmp/fabric_audits/partials/2026-05-03/t1_registered_optional_forward_only_axon_h32_100m_b1024 --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families axoncell --sizes 100m --modes forward,forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner` +- `CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_throughput_registered_slstm_train_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_throughput_registered_slstm_train_20260503 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir tmp/fabric_audits/partials/2026-05-03/t1_registered_optional_forward_only_slstm_train_h32_100m_b1024 --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families slstm --sizes 100m --modes forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner` + +Results: + +| row | status | tok/s | peak GiB | planned runtime buffer bytes | owner note | +| --- | --- | ---: | ---: | ---: | --- | +| Axon 100M h32 B1024 T1 forward | ok | 4083.11 | 79.20 | 15.30B | registered forward program; memory down from initial `93.23 GiB`; metadata reports `transition_optional_outputs=elided;omitted_bindings=8;carry_rows=10`; still far above April 21 `2.07 GiB` | +| Axon 100M h32 B1024 T1 training | OOM | - | >139 GiB process use | 17.45B before failure | reverse adjoint still materializes optional transition outputs and full artifact/tape path | +| sLSTM 100M h32 B1024 T1 training | ok | 976.09 | 56.79 | 10.67B | essentially unchanged throughput/memory; registered reverse chain remains dominant | + +Open owner after this slice: + +- The first memory move was real but not sufficient. Axon forward improved by + about `14 GiB`, but still uses `79.20 GiB`; Axon training still OOMs. 
+- The byte ledger shows large runtime-owned buffers remain: + Axon forward records `transition_forward_diag_output=7.52B`, + `transition_forward_linear_output=3.76B`, `forward_recurrent_hidden_after=1.88B`, + and `forward_recurrent_msg=1.88B`. +- Training OOM is now specifically the reverse/tape/artifact owner, not a + launch-ownership failure. Reverse needs a compiler-declared compact + transition-tape/artifact strategy before Axon T=1 training can fit. +- Next slice should make `reverse_artifact_consumer_route_rows` and transition + tape policy drive T=1 artifact/tape allocation. It should not start with + reverse compute throughput. + +### 2026-05-03 - Current T=1 Gap Owner Table + +Status: analysis only. No optimization was done for this section. + +Purpose: + +- Rebuild the current-code T=1 owner table after the registered memory/liveness + slice. +- Compare the representative h32 100M B1024 single-pop rows against the + April 21 `h32_t1_bxparams` summary floor: + `58732.71 tok/s`, `2.07 GiB`. +- Identify the first live owner before any next optimization. + +Commands/artifacts: + +- `CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_t1_owner_slstm_forward_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_owner_slstm_forward_20260503 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir tmp/fabric_audits/partials/2026-05-03/t1_current_owner_table_slstm_forward_h32_100m_b1024 --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families slstm --sizes 100m --modes forward --batches 1024 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner` +- `tmp/fabric_audits/partials/2026-05-03/t1_registered_optional_forward_only_slstm_train_h32_100m_b1024/cases.jsonl` +- `tmp/fabric_audits/partials/2026-05-03/t1_registered_optional_forward_policy_axon_h32_100m_b1024/cases.jsonl` +- `tmp/fabric_audits/partials/2026-05-03/t1_registered_optional_forward_only_axon_h32_100m_b1024/cases.jsonl` + +Owner table: + +| row | current status | current tok/s | vs Apr21 tok/s | current peak GiB | vs Apr21 memory | planned runtime buffers | active owners | +| --- | --- | ---: | ---: | ---: | ---: | ---: | --- | +| sLSTM 100M h32 B1024 T1 forward | ok | 5690.56 | 9.69% | 20.52 | 9.9x | 9.87B | forward `registered_fused_forward_program_cuda`; runtime `registered_temporal_fused_forward_program_cuda`; no primitive blockers | +| sLSTM 100M h32 B1024 T1 training | ok | 976.09 | 1.66% | 56.79 | 27.4x | 10.67B | forward `registered_fused_forward_program_cuda`; backward `registered_reverse_executor_bindings`; reverse chain includes message, receiver affine, state epilogue, sender K/V, readout boundary, temporal backward glue | +| Axon 100M h32 B1024 T1 forward | ok | 4083.11 | 6.95% | 79.20 | 38.3x | 15.30B | forward `registered_fused_forward_program_cuda`; optional transition outputs elided for forward; no primitive blockers | +| Axon 100M h32 B1024 T1 training | OOM | - | - | >139 GiB process use | >67x before failure | 17.45B before failure | forward `registered_fused_forward_program_cuda`; backward `registered_reverse_executor_bindings`; reverse chain includes diagonal recurrence backward and temporal backward glue | + +Runtime-buffer breakdown highlights: + +- sLSTM forward planned buffers are dominated by 
transition outputs: + `transition_forward_linear_output=3.02B`, + `transition_forward_matmul_output=2.42B`, + `transition_forward_state_output=2.42B`, + `transition_forward_norm_output=0.60B`, + plus recurrent hidden/message buffers. +- sLSTM training adds `forward_cells_prev_artifact=0.81B` and + `store_step_artifacts`, but the peak jump from `20.52 GiB` forward to + `56.79 GiB` training is larger than the forward runtime-buffer delta. The + remaining owner is the registered reverse chain, parameter gradients, saved + artifacts/tape, and allocator residency around backward. +- Axon forward planned buffers are dominated by + `transition_forward_diag_output=7.52B`, + `transition_forward_linear_output=3.76B`, + `forward_recurrent_hidden_after=1.88B`, and + `forward_recurrent_msg=1.88B`. Optional transition outputs are elided for + forward (`omitted_bindings=8`), but the row still uses `79.20 GiB`. +- Axon training switches to `store_step_artifacts`, records + `forward_cells_prev_artifact=2.15B`, keeps the diagonal trace/tape path + available for reverse, and OOMs with the process at about `138.54 GiB`. + +Owner conclusion: + +- This is not currently a compiler-ownership miss. The measured rows report + registered compiler forward/backward owners and no primitive executor + blockers. +- The first live T=1 gap owner is **memory/liveness/artifact/tape execution**, + not raw kernel math. Forward rows are already 9.9x to 38.3x over the April 21 + memory floor, and Axon training cannot produce a throughput number. +- The second owner is **registered reverse execution cost**. sLSTM training + runs at only `1.66%` of April 21 while forward runs at `9.69%`, so backward is + a larger compute gap after memory fits. +- The current audit table is still single-pop only. Mixed-pop T=1 remains a + required closure axis because mixed-pop should be bucket cardinality under the + same flat-bucket engine, not a separate route. + +Next analysis-only owner to narrow before optimization: + +1. Attribute the training peak beyond planned runtime buffers: model/static + tensors, parameter gradients, reverse artifacts, transition tape, and PyTorch + allocator reserve. +2. Inspect which `reverse_artifact_consumer_route_rows` and transition seed/tape + rows require full diagonal trace state for Axon T=1. +3. Run a current mixed-pop T=1 owner probe after the single-pop owner is + narrowed enough to avoid conflating bucket-cardinality issues with the Axon + OOM. + +### 2026-05-03 - Plan To Close Highest-Impact T=1 Owner + +Status: plan only. No optimization was done for this section. + +Highest-impact owner: + +- Primary owner: **registered memory/liveness/artifact/tape execution**. +- Blocking symptom: Axon 100M h32 B1024 T1 training OOMs before throughput can + be measured, while forward-only rows are already 9.9x to 38.3x over the + April 21 memory floor. +- Secondary owner: **registered reverse compute**. sLSTM training is only + `1.66%` of April 21 while sLSTM forward is `9.69%`, but reverse compute should + not be optimized before Axon training fits. + +Compiler-boundary rules for the closure work: + +- Optimize only through registered compiler products: + `memory_liveness_rows`, `memory_runtime_schedule_rows`, + `forward_artifact_route_rows`, `forward_artifact_merge_rows`, + `reverse_artifact_consumer_route_rows`, `transition_reverse_seed_role_rows`, + executor rows, binding rows, native callable rows, and reducer rows. 
+- Do not add family-specific, hidden-size-specific, benchmark-row-specific, or + single-pop-only policy. +- Do not move diagonal, gated, attention, projection, normalization, or readout + formulas into the temporal scheduler. Primitive math stays in registered + primitive executors. +- If a reverse artifact, tape tensor, workspace, or gradient buffer is needed, + it must have a compiler row, lifetime, route, and legality reason. Otherwise + launch fails closed before execution. + +Plan: + +1. **Make the training memory ledger complete.** + - Extend audit/runtime metadata so every large training allocation is + attributed to one class: model/static tensors, parameter gradients, + runtime buffers, reverse artifacts, transition tape, reducer outputs, or + PyTorch allocator reserve. + - Add explicit byte summaries for reverse artifact tensors and transition + tape tensors, parallel to the existing runtime-buffer byte ledger. + - Rerun only the current owner rows: + sLSTM 100M T1 training, Axon 100M T1 forward, and Axon 100M T1 training. + - Acceptance for this step: Axon OOM must have a named byte owner, not just + a generic CUDA OOM message. + +2. **Replace blanket T1 artifact storage with consumer-routed artifacts.** + - Use `reverse_artifact_consumer_route_rows` to compute the exact artifact + demand set for each selected reverse span. + - Materialize only demanded `(producer route, artifact role, bucket, step)` + entries. Keep truly global roles global only when required, such as + boundary step and cells-prev. + - Make output message/cells and transition state-before route-owned, not + broad role-owned. + - Delete or fail-close any path that asks the runtime for + `store_step_artifacts` as a blanket training mode without a consumer route. + - Acceptance: sLSTM training peak moves down, Axon training gets past the + previous artifact allocation frontier or fails with a narrower named tape + owner. + +3. **Add compiler-owned compact transition tape policy for T1.** + - Extend transition primitive strategy metadata with a tape contract: + required saved outputs, recomputable outputs, state-before needs, parameter + gradient needs, and reset compatibility. + - For each transition primitive, legality chooses one of: + `full_tape`, `state_before_plus_recompute`, `public_state_only`, or + `unsupported_for_compact_tape`. + - Select the policy through compiler rows and memory/liveness planning, not + through Axon/sLSTM names. + - Reverse must consume the selected tape contract through binding/access + rows. If a primitive backward requests a tensor outside its contract, fail + closed before launch. + - Acceptance: Axon T1 training fits without changing semantics, or the + remaining OOM is no longer transition trace/tape storage. + +4. **Make runtime-buffer aliasing executable for non-overlapping lifetimes.** + - Use memory liveness rows to alias same-shape scratch and transition output + buffers when lifetimes do not overlap. + - Start with T1-safe aliases: forward transition intermediate buffers that + are not reverse artifacts and not part of selected tape. + - Add validation that aliased buffers have matching dtype, device, shape, + workspace class, and non-overlapping live intervals. + - Acceptance: Axon forward peak drops below the current `79.20 GiB` and the + planned/allocated buffer ledger shows the moved owner. + +5. **Only then optimize registered reverse compute.** + - After Axon training fits, collect warmed owner timing for the registered + reverse program. 
+ - Fuse or specialize reverse stages only as registered strategies over + existing reverse executor rows and binding rows: readout, message, + transition, recurrent K/V, boundary, and reducer. + - Acceptance: sLSTM training and Axon training both produce valid throughput + numbers with improved tok/s and no memory regression. + +6. **Parity and performance gates after each slice.** + - Targeted parity: registered CUDA route, Axon single-pop parity smoke, + sLSTM/Axon T1 training where feasible, reset-absent first. + - Perf/audit: same h32 100M B1024 rows as the owner table, with private + cache dirs and warmed iterations. + - Metadata gate: registered forward/backward owners, no primitive blockers, + no Python replay/fallback, and byte owners recorded for any remaining + memory miss. + +Expected next implementation slice: + +- Start with steps 1 and 2 together: complete the training memory ledger and + replace blanket T1 artifact storage with route-owned artifact allocation. +- Do not start with reverse math or April 21 throughput matching. The immediate + closure target is: **Axon 100M h32 B1024 T1 training must fit and report a + named registered owner table.** + +### 2026-05-03 - Route-Minimal Transition Artifact Slice + +Status: implementation and measurement slice. This stayed inside registered +compiler products: executor binding rows, artifact route rows, runtime buffer +plans, and primitive output contracts. No benchmark-side tiling, direct CUDA +wrapper, family branch, or temporal-side primitive formula was added. + +What changed: + +- Fused forward reverse-artifact storage no longer clones every artifact + blindly. Immutable produced artifacts reuse their existing storage; only + `transition_state_before` takes a defensive copy because it snapshots mutable + transition state before carry update. +- `transition_state_before` artifacts are now route-minimal. The fused forward + program stores a transition state-before tensor only when compiler binding + rows prove one of two things: + - a selected reverse transition executor consumes the forward input binding; + - the corresponding forward carry output binding is materialized for a + longer recompute/tape window. +- The diagonal forward primitive wrapper now treats trace-state inputs as + optional when the compiler binding rows omit the optional `next_E_*` trace + outputs. This makes the primitive output contract executable instead of only + descriptive. +- Reverse transition recompute group binding rows now receive the same + optional-output decision as the top-level registered program. T=1 terminal + training can run the non-trace diagonal recompute contract; T>1/reset parity + still keeps trace inputs when the compiler materializes trace outputs. +- Runtime metadata now records the reverse artifact tensor store byte ledger: + tensor count, binding rows, logical bytes, unique storage bytes, + counts/bytes by artifact role, and storage policy. + +Rejected intermediate result: + +- `tmp/fabric_audits/partials/2026-05-03/t1_route_min_transition_artifacts_axon_h32_100m_b1024` + reduced Axon artifacts from 23 tensors to 15 tensors but failed backward with + `registered diag e_nu_c1 references an empty program tensor`. +- Root cause: per-transition reverse recompute binding groups still used the + full forward binding rows, so the diagonal primitive requested optional trace + inputs that the route-minimal artifact policy had correctly omitted. 
+- Fix: thread the optional-output policy into the transition recompute group + binding rows. + +Validation commands: + +- `python -m py_compile src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py` +- `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_plan.py::test_forward_artifact_routes_are_compiler_owned_rows tests/test_fabric_backend_plan.py::test_forward_artifact_merge_rows_are_compiler_owned_rows --tb=short` +- `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_route_min_artifacts_route2_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_route_min_artifacts_route2_20260503 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_supported_cuda_route_uses_registered_temporal_program --tb=short` +- `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_route_min_artifacts_axon_parity3_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_route_min_artifacts_axon_parity3_20260503 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_flat_bucket_route_matches_pytorch_reference --tb=short -k axoncell` +- `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_route_min_artifacts_slstm_parity_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_route_min_artifacts_slstm_parity_20260503 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_flat_bucket_route_matches_pytorch_reference --tb=short -k slstm` +- `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes tests/test_fabric_audit_runner.py --tb=short` + +Validation results: + +- Source/compiler artifact tests: `3 passed`. +- Registered CUDA route: `2 passed`. +- Axon registered parity smoke: `1 passed, 1 deselected`. +- sLSTM registered parity smoke: `1 passed, 1 deselected`. +- Focused source/audit suite: `21 passed`, `14 warnings`. 
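Restating the route-minimal `transition_state_before` rule from the change list above as a small predicate makes the decision boundary explicit. The row and field names below are hypothetical illustrations, not the actual compiler row schema; only the two conditions come from this slice:

```python
# Sketch only: the two-condition storage decision for one transition
# state-before artifact. All names here are hypothetical; the two conditions
# are the ones this slice describes.
def needs_state_before_artifact(
    forward_input_binding: str,
    selected_reverse_rows,            # reverse transition executor binding rows
    carry_output_for_input: dict,     # forward input binding -> carry output binding
    materialized_carry_outputs: set,  # carry outputs kept for recompute/tape
    recompute_tape_window_steps: int,
) -> bool:
    # Condition 1: a selected reverse transition executor consumes this
    # forward input binding.
    reverse_consumes_input = any(
        row.consumed_forward_input_binding == forward_input_binding
        for row in selected_reverse_rows
    )
    # Condition 2: the corresponding forward carry output binding is
    # materialized for a longer recompute/tape window.
    carry_output = carry_output_for_input.get(forward_input_binding)
    carried_for_longer_window = (
        carry_output in materialized_carry_outputs
        and recompute_tape_window_steps > 1
    )
    return reverse_consumes_input or carried_for_longer_window
```

If neither condition holds, the fused forward program stores no state-before snapshot for that binding at all, which is what shrank the Axon transition-state artifact set in the results below.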
+ +Performance/audit commands: + +- `CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_route_min_artifacts_axon_audit2_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_route_min_artifacts_axon_audit2_20260503 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir tmp/fabric_audits/partials/2026-05-03/t1_route_min_transition_artifacts_axon_h32_100m_b1024_rerun --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families axoncell --sizes 100m --modes forward,forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner` +- `CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_route_min_artifacts_slstm_audit_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_route_min_artifacts_slstm_audit_20260503 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir tmp/fabric_audits/partials/2026-05-03/t1_route_min_transition_artifacts_slstm_train_h32_100m_b1024 --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families slstm --sizes 100m --modes forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner` + +Results: + +| row | previous status | new status | new tok/s | new peak GiB | artifact/tape note | +| --- | --- | --- | ---: | ---: | --- | +| Axon 100M h32 B1024 T1 forward | ok, `4083.11 tok/s`, `79.20 GiB` | ok | 4079.15 | 79.20 | forward unchanged; optional trace outputs still elided | +| Axon 100M h32 B1024 T1 training | OOM before throughput | ok | 821.42 | 137.57 | transition-state artifacts reduced from 10 roles / `18.79B` to 2 roles / `3.76B`; total reverse artifact store `34.76B -> 19.73B` | +| sLSTM 100M h32 B1024 T1 training | ok, `963.85 tok/s`, `56.70 GiB` | ok | 961.89 | 56.70 | unchanged; gated recurrence needs its four state-before tensors | + +Historical owner note after this slice: + +- The route-minimal artifact rerun temporarily removed the Axon OOM cliff, so + T=1 training produced one measurable registered compiler row. +- This is still not throughput closure. Axon training uses `137.57 GiB`, + roughly `66x` the April 21 `2.07 GiB` memory floor, and only `1.40%` of the + April 21 `58732.71 tok/s` floor. +- The next highest owner remains memory/liveness execution: + - Axon forward still uses `79.20 GiB`. + - Axon training still carries `19.73B` logical reverse artifacts, including + full recurrent hidden/K/V/message artifacts and `2.15B` cells-prev. + - sLSTM training remains at `56.70 GiB`. +- Next implementation slice should make artifact roles shape/slice-aware and + make runtime-buffer aliasing executable for non-overlapping lifetimes before + spending time on reverse math throughput. + +### 2026-05-03 - Current T=1 Owner Table After Route-Minimal Artifacts + +Status: analysis-only. No throughput optimization or execution-code change was +made in this pass. The purpose was to rebuild the T=1 owner table against the +current tree and the April 21 `h32_t1_bxparams` summary floor. 
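The audit commands below again pass `--require-cuda-temporal-owner`; after the gate fix in the memory/liveness slice, that flag accepts the registered owner names rather than the legacy `cuda_temporal_superop` label. A minimal fail-closed sketch of that kind of gate, with a hypothetical function name and metadata keys and the accepted-owner set taken from the earlier slice:

```python
# Sketch only: the fail-closed shape of a compiler-owner gate. The accepted
# owner names are the ones listed in the memory/liveness slice; the function
# name and metadata keys are hypothetical, not the actual audit-runner API.
ACCEPTED_COMPILER_OWNERS = {
    "registered_fused_forward_program_cuda",
    "registered_temporal_fused_forward_program_cuda",
    "registered_reverse_executor_bindings",
    "physical_temporal_bucket_sequence_backward",
}


def require_cuda_temporal_owner(case_metadata: dict) -> None:
    owners = {
        case_metadata.get("forward_owner"),
        case_metadata.get("backward_owner"),
    }
    unexpected = {o for o in owners if o and o not in ACCEPTED_COMPILER_OWNERS}
    if unexpected:
        raise RuntimeError(
            f"non-registered temporal owner(s) reported: {sorted(unexpected)}"
        )
```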
+ +Reference: + +- April 21 source: + `audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json` +- Reference row: `h32_t1_bxparams`, summary floor `58732.71 tok/s`, + `2.07 GiB`. +- The April 21 file reports this as a 24-row matrix summary, not exact + per-family raw rows, so the current table uses the same summary floor for + all four current T=1 rows. + +Commands: + +- `CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_t1_current_owner_after_route_min_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_current_owner_after_route_min_20260503 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir tmp/fabric_audits/partials/2026-05-03/t1_current_owner_table_after_route_min_h32_100m_b1024 --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families slstm,axoncell --sizes 100m --modes forward,forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner` +- `CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_t1_current_owner_after_route_min_axon_train_iso_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_current_owner_after_route_min_axon_train_iso_20260503 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir tmp/fabric_audits/partials/2026-05-03/t1_current_owner_table_after_route_min_axon_train_isolated_h32_100m_b1024 --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families axoncell --sizes 100m --modes forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner` +- `CUDA_VISIBLE_DEVICES=1 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_t1_current_owner_after_route_min_axon_train_iso_gpu1_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_current_owner_after_route_min_axon_train_iso_gpu1_20260503 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir tmp/fabric_audits/partials/2026-05-03/t1_current_owner_table_after_route_min_axon_train_isolated_gpu1_h32_100m_b1024 --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families axoncell --sizes 100m --modes forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner` + +Current owner table: + +| row | current status | current tok/s | vs Apr21 tok/s | current peak GiB | vs Apr21 memory | compiler owners | current byte owner | +| --- | --- | ---: | ---: | ---: | ---: | --- | --- | +| sLSTM 100M h32 B1024 T1 forward | ok | 5674.70 | 9.66% | 20.52 | 9.92x | forward `registered_fused_forward_program_cuda`; implementation `registered_temporal_fused_forward_program_cuda` | planned runtime buffers `9.87B`; largest roles: linear output `3.02B`, matmul output `2.42B`, state output `2.42B`, norm/recurrent message/recurrent hidden `0.60B` each | +| sLSTM 100M 
h32 B1024 T1 training | ok | 968.24 | 1.65% | 56.70 | 27.39x | forward owner above; backward `registered_reverse_executor_bindings`; reverse executors include tiny message, receiver affine, state epilogue, sender K/V, readout boundary, temporal sequence backward, CUDA glue | planned runtime buffers `10.67B`; reverse artifact store `7.95B`; artifact roles include full recurrent hidden/K/V/message banks and `transition_state_before=2.42B` | +| Axon 100M h32 B1024 T1 forward | ok | 4098.17 | 6.98% | 79.23 | 38.28x | forward `registered_fused_forward_program_cuda`; implementation `registered_temporal_fused_forward_program_cuda` | planned runtime buffers `15.30B`; largest roles: diagonal output `7.52B`, linear output `3.76B`, recurrent message `1.88B`, recurrent hidden `1.88B` | +| Axon 100M h32 B1024 T1 training | current OOM | - | - | OOM before peak recorded | >67x failure frontier inferred from prior `137.57 GiB` fit | forward owner above; backward `registered_reverse_executor_bindings`; reverse executors include tiny message, diagonal recurrence, sender K/V, readout boundary, temporal sequence backward, CUDA glue | planned runtime buffers `17.45B`; reverse artifact store `19.73B`; artifact roles include full recurrent hidden/K/V/message banks and `transition_state_before=3.76B`; current isolated reruns OOM with `Triton Error [CUDA]: out of memory` | + +Important correction: + +- `tmp/fabric_audits/partials/2026-05-03/t1_route_min_transition_artifacts_axon_h32_100m_b1024_rerun` + previously fit Axon training once at `821.42 tok/s`, `137.57 GiB`. +- Current reruns of the same logical row OOMed in the combined four-row audit, + then OOMed again in isolated processes on GPU 0 and GPU 1 with no other GPU + memory resident. +- Therefore the previous "OOM cliff removed" note is historical evidence only, + not closure. The current owner table treats Axon T=1 training as not fitting + reliably. + +Owner ranking: + +1. **T=1 training memory/liveness reliability.** Axon training is again the + boundary row because the current row OOMs even after route-minimal transition + artifacts. The historical fit at `137.57 GiB` left only a small margin on a + `143.77 GiB` H200, so allocator reserve, unplanned reverse workspaces, + parameter-gradient buffers, or transient Triton allocations can push the row + over the frontier. +2. **Forward program runtime materialization.** Forward-only rows are already + far above the April 21 memory floor: sLSTM `9.92x`, Axon `38.28x`. The + largest named buffers are transition output/materialization roles, not output + sequence storage. +3. **Reverse artifact/tape and hidden/K/V bank retention.** sLSTM training adds + roughly `36.17 GiB` over sLSTM forward while the compiler-reported reverse + artifact store is `7.95B`; Axon training reports `19.73B` artifacts before + OOM. The missing delta must be accounted for in an executable memory ledger + before reverse math throughput can be interpreted. +4. **Backward compute throughput.** Current successful sLSTM training is only + `1.65%` of April 21 tok/s, and the historical Axon fit was `1.40%`, but + compute tuning is secondary until the row fits with named memory owners. +5. **Mixed-pop T=1 coverage.** This table is still single-pop only. Mixed-pop is + still part of T=1 closure, but the single-pop Axon training memory owner must + fit first or mixed-pop results will be dominated by the same frontier. 
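
A minimal sketch of the bookkeeping behind owner-ranking item 3, using only the
table values above; it reads the reported `B` figures as decimal billions of
bytes and is illustrative arithmetic, not audit tooling.

```python
# Illustrative only: how much of the sLSTM training-over-forward memory delta
# the reported reverse artifact store explains, per the owner table above.
GIB = 1024 ** 3

slstm_forward_peak_gib = 20.52
slstm_training_peak_gib = 56.70
reported_reverse_artifact_bytes = 7.95e9  # "7.95B" compiler-reported store

training_delta_gib = slstm_training_peak_gib - slstm_forward_peak_gib
unexplained_gib = training_delta_gib - reported_reverse_artifact_bytes / GIB
print(f"training adds {training_delta_gib:.2f} GiB over forward")
print(f"unexplained after the artifact store: {unexplained_gib:.2f} GiB")
```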
+ +Next analysis/implementation target: + +- Build a complete current memory ledger for Axon T=1 training that accounts + for compiler-planned runtime buffers, reverse artifact tensor store, model + parameters, parameter gradients, static tensor packs, transition/reverse + workspaces, Triton temporaries, CUDA allocator reserve, and any unreported + autograd saves. +- The next code slice should reduce or alias the named liveness products inside + compiler-owned rows. It should not start with reverse math kernels or copied + April 21 strategies. + +### 2026-05-03 - Plan To Close Current Highest-Impact T=1 Owner + +Highest-impact owner: **Axon 100M h32 B1024 T1 training memory/liveness +reliability**. The row currently OOMs even though the active path is the +registered compiler-owned temporal program. This must be closed before reverse +math throughput tuning, because the row cannot produce stable timing evidence. + +Boundary invariant: + +- Every runtime tensor retained for backward must be demanded by a compiler + product: `memory_liveness_rows`, `memory_runtime_schedule_rows`, + `forward_artifact_route_rows`, `forward_artifact_merge_rows`, + `reverse_artifact_consumer_route_rows`, executor binding rows, or parameter + binding rows. +- External categories such as model parameters, parameter gradients, static + tensor packs, CUDA allocator reserve, and framework autograd saves must be + recorded separately. They cannot be hidden inside "workspace" or benchmark + behavior. +- No optimization in this owner may add a family branch, benchmark-shape branch, + temporal-side recurrence/message formula, private fallback, or copied April 21 + route. + +Plan: + +1. **Make the OOM row explain itself.** + - Extend the audit/runtime metadata for the T=1 row with a complete memory + ledger: compiler runtime buffers, reverse artifact tensor store, model + parameter bytes, parameter-gradient bytes, static/prepacked tensors, + transition/reverse runtime buffers, Triton/native-callable temporaries, + PyTorch allocated/reserved/free snapshots, and any unclassified remainder. + - Capture this ledger on success and on OOM. OOM evidence that only says + `Triton Error [CUDA]: out of memory` is not enough for the next slice. + - Acceptance: Axon T1 training OOM artifacts identify the largest named byte + classes and the unclassified remainder; no code path is optimized yet. + +2. **Turn unclassified retained tensors into compiler-owned products or delete + them.** + - Add a source/metadata guardrail: every reverse artifact tensor and runtime + buffer used by the registered path must have a row id, role, bucket, + physical step, producer route, consumer route, lifetime policy, and storage + policy. + - If a tensor is required but lacks a row, add the missing compiler row. If + it is not required by any consumer route, stop allocating it. + - Acceptance: the Axon T1 training ledger has no large unclassified Fabric + tensor class before the first memory-reduction patch lands. + +3. **Remove the full `cells_prev` artifact from T=1 terminal training unless it + is truly demanded.** + - Current Axon artifacts include a `2.15B` full-cell `cells_prev` tensor. For + terminal T=1 training without materialized final state, reverse generally + needs boundary slices, recurrent state-before, output-cell artifacts, and a + shape/template contract, not necessarily a full `[B, all_cells, H]` clone. 
+ - Replace this with route-owned component artifacts or shape metadata: + `boundary_step`, recurrent-hidden-before if consumed, output cells if + consumed, and a compiler-owned state-template descriptor for final carry + materialization only when needed. + - Acceptance: `cells_prev` disappears from the Axon T1 reverse artifact byte + ledger for the terminal/no-final-state row, or remains with a specific + reverse consumer route that proves it is required. + +4. **Prune or recompute recurrent banks by consumer route.** + - Current Axon reverse artifacts retain full recurrent hidden/K/V/message + banks: before and after K/V, hidden before/after, and recurrent message. + - For T=1 absent-reset/terminal training, introduce compiler route policies + such as `materialize`, `reuse_existing_storage`, `implicit_initial_zero`, + and `recompute_from_program_tensor` where the reverse executor contract can + prove equivalence. + - Do not hardcode "Axon" or "T=1" inside the temporal engine. The legality + check must come from reset policy, state-provided policy, output + materialization, and reverse consumer routes. + - Acceptance: recurrent bank artifact bytes drop materially, and parity still + passes for sLSTM and Axon T1 single-pop CUDA rows. + +5. **Make runtime-buffer aliasing executable for non-overlapping lifetimes.** + - The current compiler can report alias sets, but high-cost named roles still + allocate separately when they are not generic workspace roles. The biggest + forward owners are Axon `transition_forward_diag_output=7.52B` and + `transition_forward_linear_output=3.76B`. + - Extend the runtime schedule so same-dtype/device buffers with proven + non-overlapping lifetimes can share storage through compiler rows. Use exact + shape aliasing first; only add larger backing-buffer views if the row schema + records offset/size/alignment and C++ validation enforces it. + - Acceptance: planned allocated bytes, not just planned logical bytes, move + down; source tests reject aliasing without a compiler schedule proof. + +6. **Fuse producer/consumer transition materialization only through registered + primitive executor contracts.** + - Where a transition primitive output has a single immediate consumer and no + reverse artifact/tape consumer, let the registered native callable write + directly to the consumer/output slot selected by binding rows. + - This is not a temporal-engine formula optimization. It is a primitive + executor storage contract selected by row legality, tensor bindings, and + memory schedule. + - Acceptance: Axon forward peak drops below the current `79.23 GiB` before + attempting deeper reverse throughput work. + +7. **Prove the owner is closed before moving to compute throughput.** + - Required evidence: + - Axon T1 training fits in isolated runs on at least two fresh H200 devices. + - The four-row single-pop table runs in one process without OOM: + sLSTM forward, sLSTM training, Axon forward, Axon training. + - Axon training has meaningful headroom below the H200 frontier, not a + one-off `137 GiB` near-fit. + - Metadata reports registered forward/backward owners, no primitive + blockers, no hidden fallback, and named byte owners for the remaining + peak. + - Targeted parity passes for affected T1 rows. 
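
Plan step 1 could take roughly the following shape. The `torch.cuda` calls are
real PyTorch APIs; every Fabric-specific field name below is an assumption, not
the audit's actual schema.

```python
# Hypothetical sketch of the success/OOM memory ledger from plan step 1.
import torch

def capture_memory_ledger(named_fabric_bytes: dict[str, int]) -> dict[str, int]:
    """Snapshot allocator state and expose the unclassified remainder."""
    allocated = torch.cuda.memory_allocated()
    reserved = torch.cuda.memory_reserved()
    peak = torch.cuda.max_memory_allocated()
    named_total = sum(named_fabric_bytes.values())
    return {
        **named_fabric_bytes,  # runtime buffers, artifacts, params, grads, ...
        "cuda_allocated_bytes": allocated,
        "cuda_reserved_bytes": reserved,
        "cuda_peak_allocated_bytes": peak,
        "cuda_reserved_gap_bytes": reserved - allocated,
        "unclassified_peak_bytes": max(peak - named_total, 0),
    }
```

Capturing the same dictionary inside whatever exception handler wraps the
benchmark iteration would satisfy the step's requirement that OOM evidence
carry more than the raw Triton error string.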
+ +First implementation slice after this plan: + +- Add the success/OOM memory ledger and the guardrail for unclassified + registered-path tensors, then use that evidence to choose between + `cells_prev` deletion and recurrent-bank recompute/aliasing as the first + reducing patch. + +### 2026-05-03 - Current T=1 Owner Table, No New Optimization + +Status: analysis only for this prompt. No throughput optimization was started. +Before running the audit, an interrupted half-applied autograd-neededness edit +was removed and Python compilation passed for the touched temporal modules. + +Command: + +- `CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_t1_owner_current_noopt_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_owner_current_noopt_20260503 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir tmp/fabric_audits/partials/2026-05-03/t1_owner_table_current_noopt_h32_100m_b1024 --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families slstm,axoncell --sizes 100m --modes forward,forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner` + +Artifacts: + +- `tmp/fabric_audits/partials/2026-05-03/t1_owner_table_current_noopt_h32_100m_b1024/summary.json` +- `tmp/fabric_audits/partials/2026-05-03/t1_owner_table_current_noopt_h32_100m_b1024/cases.jsonl` + +Reference: April 21 `h32_t1_bxparams` summary floor, +`58732.71 tok/s`, `2.07 GiB`. + +| row | status | tok/s | vs Apr21 tok/s | peak GiB | vs Apr21 memory | named runtime/artifact owners | +| --- | --- | ---: | ---: | ---: | ---: | --- | +| sLSTM 100M h32 B1024 T1 forward | ok | 5688.73 | 9.69% | 20.52 | 9.9x | forward owner `registered_fused_forward_program_cuda`; runtime buffers `9865142576` bytes; artifact store n/a | +| sLSTM 100M h32 B1024 T1 training | ok | 978.51 | 1.67% | 56.70 | 27.4x | forward owner `registered_fused_forward_program_cuda`; backward owner `registered_reverse_executor_bindings`; runtime buffers `9865142576` bytes; reverse artifacts `7147094016` bytes | +| Axon 100M h32 B1024 T1 forward | ok | 4094.65 | 6.97% | 79.23 | 38.3x | forward owner `registered_fused_forward_program_cuda`; runtime buffers `15300960560` bytes; artifact store n/a | +| Axon 100M h32 B1024 T1 training | OOM | - | - | 137.57 max allocated before failure | 66.5x before failure | forward owner `registered_fused_forward_program_cuda`; backward owner `registered_reverse_executor_bindings`; runtime buffers `15300960560` bytes; reverse artifacts `17582522368` bytes | + +Current owner details: + +- The active path is still the registered compiler-owned path: forward owner is + `registered_fused_forward_program_cuda`; training backward owner is + `registered_reverse_executor_bindings`; primitive executor blockers are absent + in these rows. +- The table is dominated by memory/liveness, not missing owner selection. + Forward-only Axon already consumes `79.23 GiB`; Axon training OOMs after + recording `137.57 GiB` max allocated. +- Runtime buffer owners are explicit. 
Axon forward/training planned runtime + buffers total `15300960560` bytes, led by + `transition_forward_diag_output=7516192768`, + `transition_forward_linear_output=3758096384`, + `forward_recurrent_hidden_after=1879048192`, and + `forward_recurrent_msg=1879048192`. +- Reverse artifact owners are explicit. Axon training stores `17582522368` + logical artifact bytes, led by `transition_state_before=3758096384` and six + recurrent/message banks at `1879048192` bytes each: + recurrent hidden before/after, recurrent K/V before/after, and recurrent + message. +- The current next owner remains **T=1 training memory/liveness reliability**. + Compute throughput tuning is still premature because Axon T1 training cannot + produce a stable timing row. + +Next analysis target before any optimization: + +- Account for the remaining gap between named Fabric-owned bytes and the + `137.57 GiB` allocator peak on Axon training: runtime buffers, reverse + artifacts, model params/param grads, static/prepacked tensors, transition + native-callable temporaries, CUDA allocator reserve, and unclassified autograd + saves. +- The first code-changing prompt should reduce a named compiler product such as + route-owned recurrent bank artifacts, transition-state-before artifacts, or + executable runtime-buffer aliasing. It should not start with reverse math + formulas or benchmark-side policy. + +### 2026-05-03 - Plan To Close Current T=1 Highest-Impact Owner + +Status: plan only. No optimization was started in this prompt. + +Highest-impact owner: + +- **T=1 training memory/liveness reliability**, with Axon 100M h32 B1024 + terminal training as the boundary row. +- Current blocking symptom: Axon training OOMs after recording `137.57 GiB` + max allocated. This prevents stable timing and makes reverse compute tuning + premature. +- Current named byte owners: + - forward/runtime buffers: `15300960560` bytes; + - reverse artifact tensor store: `17582522368` bytes; + - largest runtime roles: `transition_forward_diag_output=7516192768`, + `transition_forward_linear_output=3758096384`, + `forward_recurrent_hidden_after=1879048192`, + `forward_recurrent_msg=1879048192`; + - largest reverse artifacts: `transition_state_before=3758096384`, plus + recurrent hidden/K/V/message banks at `1879048192` bytes each. + +Compiler-boundary rule for this owner: + +- Every reduction must be expressed as a compiler-owned row/binding/storage + policy change: memory liveness rows, runtime schedule rows, artifact route + rows, merge rows, consumer route rows, executor binding rows, or registered + primitive strategy contracts. +- Do not add Axon/sLSTM branches, `T=1` branches in temporal math, benchmark + tiling, direct CUDA wrappers, copied April 21 code, or recurrence/message + formulas inside the temporal scheduler. + +Plan: + +1. **Make the remaining allocator gap measurable before reducing it.** + - Extend result metadata enough to split Axon OOM peak into: + model params, param grads, compiler runtime buffers, reverse artifacts, + static/prepacked tensors, transition native-callable temporaries, reverse + runtime buffers, CUDA reserved/unallocated memory, and unclassified bytes. + - Acceptance: the Axon OOM row explains the difference between the named + `~32.9B` Fabric runtime/artifact bytes and the `137.57 GiB` max allocation. + +2. 
**Elide final carry/state-gradient materialization when autograd does not + require input-state gradients.** + - The public training row uses external loss on terminal output and does not + request/materialize final state. If the flattened input state tensors do + not require gradients, the reverse program should not allocate a full + `[B, input+recurrent+output, H]` grad-carry cells buffer only to return a + discarded state gradient. + - Represent this as a backward runtime support row / memory runtime policy + bit such as `materialize_grad_carry_cells`, derived from autograd + requiredness, window position, output contract, and materialized-final-state + policy. + - C++ validation must reject missing carry buffers when the policy says they + are required. + - Acceptance: targeted parity still passes for routes where state gradients + are required; terminal/no-final-state benchmark rows omit the carry buffer + by policy and Axon training peak moves down. + +3. **Route-prune recurrent-bank reverse artifacts.** + - Current Axon training stores recurrent hidden before/after, K/V before/after, + and recurrent message. Each is `1879048192` bytes. + - Add artifact access requiredness by reverse consumer route, not global + role-only requiredness. A producer artifact should be stored only if a + selected reverse executor consumes its route. + - For T=1 absent-reset/provided-initial-state cases, legal alternatives + include `reuse_existing_storage`, `implicit_initial_zero`, and + `recompute_from_program_tensor`, but only if the route verifier proves the + reverse executor contract is satisfied. + - Acceptance: reverse artifact tensor-store bytes drop materially without + adding cell-family or benchmark-row selectors. + +4. **Reduce `transition_state_before` artifacts by binding-level demand.** + - Current Axon training stores `transition_state_before=3758096384` bytes. + - Keep only state-before bindings consumed by reverse transition executors or + parameter reducers. If a state-before value is only a template for an + output that is not materialized, remove it from artifact rows. + - Acceptance: artifact binding rows explain every retained + `transition_state_before` tensor by producer route, binding index, consumer + route, and reducer need. + +5. **Make runtime-buffer aliasing executable for forward transition products.** + - The largest forward memory owners are transition outputs. Some are + producer/consumer temporaries with non-overlapping lifetimes. + - Extend the runtime schedule from "reports alias sets" to "allocates shared + backing storage" for proven same-dtype/device non-overlapping buffers. + Start with exact-shape aliases; only add offset/size views if rows carry + explicit offset, span, and alignment. + - Acceptance: allocated runtime bytes decrease, not just logical bytes, and + source/C++ guards reject aliasing without a schedule proof. + +6. **Rebuild the owner table and only then choose compute work.** + - Required post-patch evidence: + - registered route smoke; + - sLSTM and Axon single-pop parity for the touched routes; + - the four-row T=1 owner table in one process; + - isolated Axon training rerun; + - metadata showing registered forward/backward owners, no primitive + blockers, and named remaining byte owners. + - Closure for this owner means Axon T1 training fits with meaningful memory + headroom. Only after that should the next plan target reverse compute + throughput. 
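
Plan step 2's policy bit could be derived from facts the compiler already has.
The dataclass fields and function below are assumed names for illustration, not
the real row schema.

```python
# Hypothetical derivation of materialize_grad_carry_cells (plan step 2 above).
from dataclasses import dataclass

@dataclass(frozen=True)
class BackwardWindowFacts:
    input_state_requires_grad: bool  # autograd requiredness of flattened input state
    reaches_program_start: bool      # this backward window ends at step 0
    has_earlier_local_step: bool     # an earlier step in the window consumes the carry
    materializes_final_state: bool   # public contract returns/needs final state

def materialize_grad_carry_cells(facts: BackwardWindowFacts) -> bool:
    if facts.has_earlier_local_step:
        return True  # the carry is consumed inside the window itself
    if facts.materializes_final_state:
        return True  # final-state policy still demands the buffer
    # Otherwise the carry exists only to return an input-state gradient.
    return facts.input_state_requires_grad and facts.reaches_program_start
```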
+ +First implementation prompt should start with steps 1 and 2: improve the memory +ledger and make grad-carry cells materialization policy-owned. If that does not +move the Axon row enough, continue immediately into recurrent-bank route pruning. + +### 2026-05-03 - Grad-Carry Materialization Policy Slice + +Status: implemented and measured. This is a reliability/memory-liveness slice, +not throughput closure. + +Compiler-owned changes: + +- Added reverse runtime support metadata for + `grad_carry_materialization_policy`. +- Threaded `materialize_grad_carry_cells` from autograd state-gradient + requiredness and backward-window position into the registered reverse program. +- Split local reverse workspace from returned carry materialization: + `reverse_grad_cells_work` is still allocated for the fused step, while + `reverse_grad_carry_cells` is omitted for terminal single-step windows that do + not need to return input-state gradients. +- Updated fused C++ reverse validation so a missing carry buffer is legal only + when there is no earlier local step to consume it. +- Added benchmark memory-ledger attribution for compiler runtime buffers, + reverse artifact storage, named runtime/artifact bytes, CUDA reserved gap, and + unclassified peak bytes. + +Checks: + +- `python -m py_compile benchmarks/fabric/suite_common.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/physical_autograd.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py` +- `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes tests/test_fabric_backend_plan.py::test_fused_program_runtime_support_rejections_are_compiler_owned_rows tests/test_fabric_audit_runner.py --tb=short` +- `CUDA_VISIBLE_DEVICES=0 ... uv run pytest -q tests/test_fabric_runtime.py::test_fabric_supported_cuda_route_uses_registered_temporal_program tests/test_fabric_runtime.py::test_fabric_cuda_single_population_flat_bucket_route_matches_pytorch_reference --tb=short -k 'registered_temporal_program or axoncell'` + +Perf artifacts: + +- Isolated Axon training: + `tmp/fabric_audits/partials/2026-05-03/t1_grad_carry_policy_axon_train_fixed_h32_100m_b1024` +- Four-row owner table: + `tmp/fabric_audits/partials/2026-05-03/t1_grad_carry_policy_owner_table_h32_100m_b1024` + +Current T=1 table after this slice: + +| row | status | tok/s | peak GiB | named runtime bytes | reverse artifact bytes | unclassified peak bytes | +| --- | --- | ---: | ---: | ---: | ---: | ---: | +| sLSTM forward | ok | 5666.00 | 20.52 | 9865142576 | 0 | 11770109776 | +| sLSTM training | ok | 977.43 | 55.95 | 9865142576 | 7147094016 | 42254132432 | +| Axon forward | ok | 4092.61 | 79.23 | 15300960560 | 0 | 69375647568 | +| Axon training | ok | 824.08 | 135.57 | 15300960560 | 17582522368 | 111881994448 | + +Effect: + +- Axon T=1 training now produces a valid row in the four-row owner table instead + of the previous OOM boundary row. +- Peak only moved by about 2 GiB, so the next highest-impact owner is still + memory/liveness, not compute. 
The remaining problem is now explicitly named: + large transition forward buffers, large reverse artifact banks, and a very + large unclassified allocator peak. + +Next owner: + +1. Route-prune recurrent-bank reverse artifacts by consumer route. +2. Reduce `transition_state_before` artifacts by binding-level demand. +3. Make forward transition runtime-buffer aliasing executable for proven + non-overlapping lifetimes. + +### 2026-05-03 - Plan To Close Highest-Impact T=1 Owner After Grad-Carry Slice + +Status: plan only. Do not start compute-kernel tuning yet. + +Highest-impact owner: + +- **T=1 training memory/liveness**, still led by Axon 100M h32 B1024 terminal + training. +- Current row now fits, but only at `824.08 tok/s` and `135.57 GiB`. +- Named byte owners: + - forward/runtime buffers: `15300960560` bytes; + - reverse artifact tensor store: `17582522368` bytes; + - unclassified CUDA peak: `111881994448` bytes; + - CUDA reserved gap: `144813624832` bytes. +- This means the next owner is still storage/liveness and allocator pressure. + Compute fusion is premature until the representative row has meaningful memory + headroom. + +Compiler-boundary rule: + +- The next patch must only change compiler-owned products: artifact route rows, + reverse consumer route rows, memory/liveness rows, runtime schedule rows, + executor binding rows, registered strategy records, and C++/Python validation + over those rows. +- It must not add Axon/sLSTM branches, benchmark-row branches, hidden-size + rules, direct wrappers, copied April21 kernels, or primitive formulas in + temporal scheduler/scan/reverse code. + +Plan: + +1. **Make reverse artifact requiredness route-owned, then prune recurrent-bank + artifacts.** + - Build a required-artifact-role/route set from + `reverse_artifact_consumer_route_rows`, reverse executor rows, and reducer + route rows. + - Store a producer artifact only when a selected reverse consumer route + requires that exact producer route. + - Keep truly global artifacts global: `boundary_step` and shape/state + template facts. Everything else should be route-owned. + - First target roles: + `recurrent_hidden_backend_order`, + `recurrent_hidden_before_backend_order`, + `recurrent_k`, `recurrent_v`, + `recurrent_k_before`, `recurrent_v_before`, + `recurrent_msg_backend_order`. + - Legal alternatives must be explicit route policies: + `materialize`, `reuse_existing_storage`, `implicit_initial_zero`, + `recompute_from_program_tensor`, or `not_required`. + - Acceptance: Axon reverse artifact bytes drop by at least one recurrent bank + chunk (`1879048192` bytes) without a cell-family or benchmark selector, and + routed C++ reverse rejects any missing artifact that is still required. + +2. **Reduce `transition_state_before` by binding-level demand.** + - Current Axon row stores `transition_state_before=3758096384` bytes. + - Split state-before artifacts by primitive row, bucket ordinal, logical + binding, and consumer route. + - Retain only state-before tensors consumed by the selected transition + reverse executor, tape contract, or parameter reducer. + - Prune state-before entries that only exist as templates for omitted + optional forward outputs. + - Acceptance: every retained transition-state-before artifact has an + auditable producer route, binding index, consumer route, and reason. + +3. 
**Turn runtime-buffer alias metadata into executable allocation for exact + same-shape non-overlapping transition buffers.** + - Current forward runtime buffers include + `transition_forward_diag_output=7516192768` and + `transition_forward_linear_output=3758096384`. + - Start with exact-shape aliasing only. No subviews/offset packing until rows + encode offset, size, and alignment. + - Add allocator validation that same alias-set storage is used only when the + runtime schedule proves non-overlap, same dtype, same device, compatible + shape, and no simultaneous read/write lifetime. + - Acceptance: `estimated_allocated_buffer_bytes` and actual peak both move + down. Logical bytes alone are not enough. + +4. **Use the unclassified CUDA peak as the steering signal, not as a closure + excuse.** + - Keep the new memory ledger fields and add more attribution only if the + above named owners fail to move the peak. + - If named bytes move but peak does not, run a focused current-code memory + profile before writing more kernels. + - Candidate subowners: autograd saved tensors, CUDA extension temporaries, + transition native-callable temporaries, allocator fragmentation/reserve, or + parameter-gradient reduction buffers. + +5. **Parity and perf gates for each reducing patch.** + - Required before accepting a patch: + - source guardrail: no primitive formula/family/benchmark selector leaked + into temporal scheduler or registered program glue; + - CUDA registered-route smoke; + - sLSTM and Axon single-pop parity for the touched T=1 route; + - isolated Axon T=1 training audit; + - four-row T=1 owner table in one process; + - metadata showing registered forward/backward owners, no primitive + blockers, and named remaining byte owners. + - If route pruning changes state/reset semantics, also run reset-present and + materialized-final-state parity before citing performance. + +First implementation slice: + +- Implement step 1: route-owned recurrent-bank artifact requiredness and + pruning. +- Do not start with compute kernels. The current owner is storage and allocator + pressure. +- Stop criteria for this slice: either reverse artifact bytes drop materially + and parity/perf gates pass, or a C++/Python verifier proves a recurrent-bank + artifact is genuinely required and the plan moves to `transition_state_before`. + +Implementation update: + +- Added compiler route requiredness for recomputable after-transition recurrent + K/V artifacts: + - `message.recurrent_k` and `message.recurrent_v` route rows are now + `required=0`. + - `message.recurrent_k_before` and `message.recurrent_v_before` remain + `required=1` because recurrent-message backward consumes the before-bank + values directly. +- The fused forward program now honors the route-row `required` flag before + appending reverse artifact tensors. +- The registered reverse output-message step now treats missing routed + after-transition recurrent K/V artifacts as a compiler-selected recompute + case. Recompute is owned by the registered reverse message strategy through a + new `recurrent_kv_forward_recompute` native phase, not by temporal scheduler + formulas. +- This keeps the liveness policy in compiler rows and the math in native + message strategies. +Validation: + +- `uv run python scripts/validate_fabric_generated_catalogs.py` + - passed; generated native-callable catalog is up to date. 
+- `uv run pytest -q tests/test_fabric_backend_plan.py::test_forward_artifact_routes_are_compiler_owned_rows tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + - `3 passed`; + - covers forward artifact route requiredness, native callable phase catalog + coverage, and fused launch contract source guards. +- `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_pooled_output_uses_registered_reverse_program_window --tb=short` + - `2 passed`; + - covers the registered T=1 forward artifact tensor store into registered + reverse program, both reset-absent and reset-present. +- `git diff --check` over the touched compiler/runtime/doc/test files passed. + +Targeted Axon audit: + +- Command: + `CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_recurrent_kv_prune_axon_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_recurrent_kv_prune_axon_20260503 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir tmp/fabric_audits/partials/2026-05-03/t1_recurrent_kv_artifact_prune_axon_h32_100m_b1024 --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families axoncell --sizes 100m --modes forward,forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner` +- Result: `status=ok`, `cases=2`. +- Axon forward row: + - `4080.85 tok/s`, `79.201 GiB`; + - runtime buffers `15300960560` bytes; + - reverse artifacts `0` bytes. +- Axon forward/backward row: + - `796.55 tok/s`, `132.068 GiB`; + - runtime buffers `15300960560` bytes; + - reverse artifacts `13824425984` bytes; + - named runtime/artifact bytes `29125386544` bytes. +- Compared with the pre-slice four-row owner table: + - reverse artifacts moved from `17582522368` to `13824425984` + bytes; + - peak memory moved from `135.568 GiB` to `132.068 GiB`; + - `recurrent_k` and `recurrent_v` after-transition role entries are gone; + - `recurrent_k_before` and `recurrent_v_before` remain because + recurrent-message backward consumes them directly. +- Throughput moved from `824.08` to `796.55 tok/s`; this is expected for this + memory/liveness slice because it trades stored recurrent K/V banks for + strategy-owned recompute. The throughput goal should tune that recompute only + after the remaining memory/liveness owners are closed. 
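
A minimal sketch of the trade described above: artifacts are stored only when
their route row is required, and the reverse message step recomputes the
after-transition K/V banks when the routed artifacts are absent. All names are
assumptions standing in for the actual compiler rows and strategy APIs.

```python
# Illustrative shape of the route-owned store-vs-recompute decision.
def store_reverse_artifacts(route_rows, produced, artifact_store):
    for row in route_rows:
        if row.required:  # compiler-owned requiredness flag on the route row
            artifact_store[row.role] = produced[row.role]

def resolve_after_transition_kv(artifact_store, recurrent_hidden, message_strategy):
    k = artifact_store.get("message.recurrent_k")
    v = artifact_store.get("message.recurrent_v")
    if k is None or v is None:
        # Compiler-selected recompute case, owned by the registered message
        # strategy's recurrent_kv_forward_recompute phase.
        k, v = message_strategy.recurrent_kv_forward_recompute(recurrent_hidden)
    return k, v
```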
+ +Four-row owner table: + +- Command: + `CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_recurrent_kv_prune_owner_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_recurrent_kv_prune_owner_20260503 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir tmp/fabric_audits/partials/2026-05-03/t1_recurrent_kv_artifact_prune_owner_table_h32_100m_b1024 --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families slstm,axoncell --sizes 100m --modes forward,forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner` +- Result: `status=ok`, `cases=4`. + +| Family | Mode | tok/s | peak GiB | runtime bytes | reverse artifact bytes | named runtime/artifact bytes | +| --- | ---: | ---: | ---: | ---: | ---: | ---: | +| sLSTM | forward | 5661.58 | 20.524 | 9865142576 | 0 | 9865142576 | +| sLSTM | forward+backward | 965.07 | 54.821 | 9865142576 | 5939134464 | 15804277040 | +| Axon | forward | 4094.43 | 79.233 | 15300960560 | 0 | 15300960560 | +| Axon | forward+backward | 795.64 | 132.068 | 15300960560 | 13824425984 | 29125386544 | + +Role-level movement: + +- sLSTM training no longer stores `recurrent_k`/`recurrent_v`; artifact bytes + moved from `7147094016` to `5939134464`. +- Axon training no longer stores `recurrent_k`/`recurrent_v`; artifact bytes + moved from `17582522368` to `13824425984`. +- The largest remaining reverse artifact owners are now: + - Axon `transition_state_before=3758096384` bytes; + - Axon `recurrent_hidden_backend_order=1879048192` bytes; + - Axon `recurrent_hidden_before_backend_order=1879048192` bytes; + - Axon `recurrent_k_before=1879048192` bytes; + - Axon `recurrent_v_before=1879048192` bytes; + - Axon `recurrent_msg_backend_order=1879048192` bytes. + +Next owner after this slice: + +- Continue with `transition_state_before` binding-level demand pruning before + compute-kernel tuning. +- The current patch proves route-requiredness can move real peak memory, but it + deliberately trades two stored banks for recompute and is not the throughput + closure itself. + +### 2026-05-03 - Plan To Close Highest-Impact T=1 Owner After K/V Prune + +Status: plan only. Do not start compute-kernel tuning from this section; the +current owner is still memory/liveness reliability. + +Highest-impact owner: + +- **Axon 100M h32 B1024 T=1 terminal training memory/liveness.** +- Current evidence after recurrent K/V artifact pruning: + - Axon training now fits, but only at `795.64 tok/s`, `132.068 GiB`; + - runtime buffers are `15300960560` bytes; + - reverse artifacts are `13824425984` bytes; + - largest named reverse artifact is + `transition_state_before=3758096384` bytes; + - remaining recurrent/message banks are each `1879048192` bytes: + `recurrent_hidden_backend_order`, + `recurrent_hidden_before_backend_order`, `recurrent_k_before`, + `recurrent_v_before`, and `recurrent_msg_backend_order`. +- Compute fusion is still premature. The active row is memory-bound enough that + a faster reverse kernel could still fail the April 21 memory gate or preserve + allocator pressure. 
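
The memory gate referenced above reduces to simple arithmetic over constants
this document already fixes; the sketch below is illustrative, not audit code.

```python
# Illustrative memory-gate check for a candidate row, using the April 21
# summary floor and the H200 frontier cited in this document.
APR21_FLOOR_GIB = 2.07
H200_FRONTIER_GIB = 143.77

def memory_gate(peak_gib: float) -> dict[str, float]:
    return {
        "vs_apr21_x": peak_gib / APR21_FLOOR_GIB,
        "h200_headroom_gib": H200_FRONTIER_GIB - peak_gib,
    }

# Axon T=1 training after the K/V prune: ~63.8x over April 21, ~11.7 GiB headroom.
print(memory_gate(132.068))
```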
+ +Compiler-boundary rule: + +- All changes must be compiler products: `forward_artifact_route_rows`, + `reverse_artifact_consumer_route_rows`, transition dynamic binding rows, + transition seed/tape rows, memory liveness rows, memory runtime schedule rows, + registered strategy records, or registered forward/reverse program kernels. +- No Axon/sLSTM branch, benchmark-shape branch, copied April 21 path, + scheduler-owned recurrence/message formula, direct pybind wrapper, or hidden + replay/fallback is allowed. + +Implementation plan: + +1. **Make transition-state-before demand binding-level and executable.** + - Today the forward program stores transition state-before tensors when the + reverse transition input binding consumes the forward binding or when the + forward output binding is materialized. + - Tighten this to a compiler-declared transition reverse dynamic binding + demand set: store only the exact state-before input bindings consumed by + selected reverse executor rows. + - Keep C++ fail-closed validation: if a reverse dynamic binding is marked + required and the corresponding state-before artifact is missing, reject at + launch/bind time with the binding id and bucket ordinal. + - Acceptance: Axon `transition_state_before` artifact bytes drop below the + current `3758096384` bytes, or the verifier proves every retained binding + is required. + +2. **Route-prune the remaining recurrent/message banks by consumer, not role.** + - Use `reverse_artifact_consumer_route_rows` plus reverse executor access + rows to decide whether each of + `recurrent_hidden_backend_order`, + `recurrent_hidden_before_backend_order`, `recurrent_k_before`, + `recurrent_v_before`, and `recurrent_msg_backend_order` is stored, + recomputed, sliced, or aliased. + - Prefer recompute or alias only through registered native phases and memory + liveness rows. Do not infer from role names in temporal glue. + - Acceptance: at least one additional `1879048192`-byte bank disappears from + Axon artifacts without parity regression. + +3. **Make runtime buffer aliasing/lifetimes physically reduce peak memory.** + - The named runtime buffers are still `15.30B` for Axon forward/training and + contain large transition forward outputs. Convert non-overlapping + transition output lifetimes into executable alias groups or in-place + producer/consumer handoff rows where legal. + - Keep aliasing legality in the memory/liveness plan: same dtype/device, + non-overlapping live ranges, no observable output/materialization, and no + backward consumer requiring the pre-alias value. + - Acceptance: named runtime bytes and actual Axon forward/training peak both + move down; metadata reports the alias group owner. + +4. **Use the unclassified peak as the steering signal.** + - After each named-byte reduction, rerun the Axon isolated training audit and + the four-row owner table. + - If named bytes drop but peak stays near `132 GiB`, run a focused memory + snapshot/profile to classify the remaining allocator peak into autograd + saved tensors, CUDA extension temporaries, static/prepacked tensors, + PyTorch reserve, or missing compiler-owned runtime buffers. + - Convert any persistent Fabric-owned class into rows, or delete it if it is + stale. + +5. **Validation and stop criteria for this owner.** + - Source/static: guard that transition-state artifacts and recurrent banks + are requested through compiler rows, not role-only temporal lookups. 
+ - Parity: targeted registered-route T=1 sLSTM and Axon parity, including + reset-absent first; reset-present if reset/artifact ownership is touched. + - Perf/audit: isolated Axon training audit plus the four-row h32 100M B1024 + owner table with private cache dirs and warmed iterations. + - Close this owner only when Axon T=1 training has a materially lower peak, + all retained artifact/runtime byte owners are named, registered + forward/backward owners remain active, and no primitive blockers or hidden + fallbacks appear. + +Next implementation slice: + +- Start with step 1: transition-state-before binding-level demand pruning. +- If it cannot move bytes because every state binding is required, record that + proof and move immediately to step 2, recurrent/message bank route pruning. +- Do not start reverse compute fusion until this memory/liveness owner is below + the current cliff and the remaining peak is fully attributed. + +### 2026-05-03 - Recurrent K/V-Before Artifact Prune + +Status: implemented and validated as a compiler-owned memory/liveness slice. + +What changed: + +- `message.recurrent_k_before` and `message.recurrent_v_before` are now + recomputable forward artifact routes. The forward fused program does not store + them when `forward_artifact_route_rows.required == 0`. +- The registered reverse recurrent-message step now resolves those artifacts + optionally through routed access rows. If either bank is absent, it recomputes + both through the selected registered message strategy's + `recurrent_kv_forward_recompute` phase using + `recurrent_hidden_before_backend_order`. +- Transition state-before storage was tightened to reverse-demand only. The + forward program no longer stores a state-before input merely because the + matching forward output binding is materialized. +- The change stays inside compiler products and registered strategy products: + forward artifact route rows, routed reverse artifact access, reverse executor + binding rows, and registered native message strategy phases. + +Validation: + +- `uv run python scripts/validate_fabric_generated_catalogs.py` + - Result: passed. +- `uv run pytest -q tests/test_fabric_backend_plan.py::test_forward_artifact_routes_are_compiler_owned_rows tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + - Result: `2 passed`. +- `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured tests/test_fabric_backend_boundaries.py::test_parameter_reducer_native_callables_are_registry_owned --tb=short` + - Result: `2 passed`. +- `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_kv_before_prune_parity_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_kv_before_prune_parity_20260503 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_pooled_output_uses_registered_reverse_program_window --tb=short` + - Result: `2 passed`. 
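
The reverse-demand tightening above amounts to a set comparison over binding
rows plus a fail-closed check. A hypothetical shape, with every field name an
assumption:

```python
# Hypothetical sketch of reverse-demand-only state-before storage.
def demanded_state_before_keys(reverse_dynamic_binding_rows):
    return {
        (row.primitive_row_id, row.bucket_ordinal, row.binding_id)
        for row in reverse_dynamic_binding_rows
        if row.consumes_state_before and row.required
    }

def check_state_before_store(demanded_keys, stored_keys):
    missing = demanded_keys - stored_keys
    if missing:
        # Fail closed with binding id and bucket ordinal, mirroring the C++ check.
        raise RuntimeError(f"missing required state-before artifacts: {sorted(missing)}")
    # Anything stored but not demanded is a candidate for pruning.
    return stored_keys - demanded_keys
```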
+ +Isolated Axon audit: + +- Command: + `CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_recurrent_kv_before_prune_axon_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_recurrent_kv_before_prune_axon_20260503 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir tmp/fabric_audits/partials/2026-05-03/t1_recurrent_kv_before_artifact_prune_axon_h32_100m_b1024 --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families axoncell --sizes 100m --modes forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner` +- Result: `status=ok`, `cases=1`. +- Axon forward/backward row: + - `765.03 tok/s`, `128.568 GiB`; + - runtime buffers `15300960560` bytes; + - reverse artifacts `10066329600` bytes; + - named runtime/artifact bytes `25367290160` bytes. +- Compared with the previous recurrent K/V prune row: + - reverse artifacts moved from `13824425984` to `10066329600` bytes; + - peak memory moved from `132.068 GiB` to `128.568 GiB`; + - `recurrent_k_before` and `recurrent_v_before` role entries are gone; + - `transition_state_before` remains `3758096384` bytes, now stored only for + reverse-consumed transition input bindings. + +Four-row owner table: + +- Command: + `CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_recurrent_kv_before_prune_owner_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_recurrent_kv_before_prune_owner_20260503 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir tmp/fabric_audits/partials/2026-05-03/t1_recurrent_kv_before_artifact_prune_owner_table_h32_100m_b1024 --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families slstm,axoncell --sizes 100m --modes forward,forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner` +- Result: `status=ok`, `cases=4`. + +| Family | Mode | tok/s | peak GiB | runtime bytes | reverse artifact bytes | named runtime/artifact bytes | +| --- | ---: | ---: | ---: | ---: | ---: | ---: | +| sLSTM | forward | 5688.08 | 20.524 | 9865142576 | 0 | 9865142576 | +| sLSTM | forward+backward | 951.24 | 53.696 | 9865142576 | 4731174912 | 14596317488 | +| Axon | forward | 4097.54 | 79.233 | 15300960560 | 0 | 15300960560 | +| Axon | forward+backward | 765.13 | 128.567 | 15300960560 | 10066329600 | 25367290160 | + +Role-level movement: + +- sLSTM training no longer stores `recurrent_k_before`/`recurrent_v_before`; + reverse artifact bytes moved from `5939134464` to `4731174912`. +- Axon training no longer stores `recurrent_k_before`/`recurrent_v_before`; + reverse artifact bytes moved from `13824425984` to `10066329600`. +- The largest remaining Axon reverse artifact owners are now: + - `transition_state_before=3758096384` bytes; + - `recurrent_hidden_backend_order=1879048192` bytes; + - `recurrent_hidden_before_backend_order=1879048192` bytes; + - `recurrent_msg_backend_order=1879048192` bytes. + +Next owner after this slice: + +- Continue memory/liveness closure before compute fusion. 
+- The next highest named reverse-artifact candidates are the hidden/message + banks. They require either route-specific recompute, aliasing, or an + executable memory plan proof that they are genuinely live for the selected + registered reverse strategies. +- Runtime buffers remain `15300960560` bytes for Axon; the large named owners + are still transition forward outputs. The next pass should decide whether + their liveness can be shortened or aliased through compiler memory rows. + +### 2026-05-03 - Optional Transition State-Carry Output Prune + +Status: implemented and validated as a registered primitive contract and +memory/liveness slice. + +What changed: + +- Transition primitive output contracts now distinguish observable/tape-required + outputs from carry-only outputs: + - gated logspace recurrence keeps `next_y` required and makes `next_c`, + `next_n`, and `next_m` optional; + - diagonal RTU keeps `preproj` required and makes `next_hc1` and `next_hc2` + optional, matching the existing optional trace-state outputs. +- The registered forward transition program now tolerates omitted optional + output bindings and passes null output pointers to the primitive CUDA kernels + only when the compiler binding rows have omitted those outputs. +- The gated and diagonal device kernels now guard writes to omitted optional + state-carry outputs. +- The reverse dynamic binder now materializes compiler-owned zero tensors for + missing optional state-before artifacts. This keeps initial/reset private + trace states legal without forcing every T=1 path to store trace artifacts. + +Why this is a compiler-boundary-safe optimization: + +- The decision is expressed in registered primitive output-binding contracts and + the existing compiler-owned `forward_executor_binding_rows` filtering. +- There is no family, batch, hidden-size, benchmark-row, or temporal-scheduler + branch. +- T>1, final-state materialization, and incoming final-state gradients still + request optional carry outputs through the same compiler schedule policy. + +Validation: + +- `uv run python scripts/validate_fabric_generated_catalogs.py` + - Result: passed. +- `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows --tb=short` + - Result: `2 passed`. +- `uv run pytest -q tests/test_fabric_backend_plan.py::test_transition_executor_selection_uses_registered_structural_records tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings --tb=short` + - Result: `3 passed`. +- `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_optional_state_carry_parity_retry_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_optional_state_carry_parity_retry_20260503 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_flat_bucket_route_matches_pytorch_reference --tb=short` + - Result: `2 passed`. +- `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_pooled_output_uses_registered_reverse_program_window --tb=short` + - Result: `2 passed`. 
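
A minimal sketch (assumed names) of the optional state-carry contract described
above: carry-only outputs get buffers only when the compiler schedule demands
them, and omitted bindings become null output pointers that the device kernels
guard.

```python
# Illustrative output-binding planning for optional state-carry outputs.
from dataclasses import dataclass

@dataclass(frozen=True)
class OutputBinding:
    name: str
    required: bool  # observable/tape-required output, e.g. next_y or preproj
    demanded: bool  # compiler schedule asked for this carry (T>1, final state, ...)

def plan_output_pointers(bindings, allocate):
    pointers = {}
    for binding in bindings:
        if binding.required or binding.demanded:
            pointers[binding.name] = allocate(binding.name)  # buffer from schedule rows
        else:
            pointers[binding.name] = None  # kernels guard writes on a null pointer
    return pointers
```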
+ +Four-row owner table: + +- Command: + `CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_throughput_t1_optional_state_carry_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_throughput_t1_optional_state_carry_20260503 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir tmp/fabric_audits/partials/2026-05-03/t1_optional_transition_state_carry_h32_100m_b1024 --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families slstm,axoncell --sizes 100m --modes forward,forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner` +- Result: `status=ok`, `cases=4`. + +| Family | Mode | status | tok/s | peak GiB | runtime bytes | reverse artifact bytes | +| --- | --- | --- | ---: | ---: | ---: | ---: | +| sLSTM | forward | ok | 5581.61 | 18.837 | 8053203248 | 0 | +| sLSTM | forward+backward | ok | 956.57 | 50.321 | 8053203248 | 4731174912 | +| Axon | forward | ok | 4104.46 | 75.733 | 11542864176 | 0 | +| Axon | forward+backward | ok | 768.36 | 121.567 | 11542864176 | 10066329600 | + +Role-level movement versus the recurrent K/V-before artifact prune owner table: + +- sLSTM runtime buffers moved from `9865142576` to `8053203248` bytes. + `transition_forward_state_output` moved from `2415919104` to `603979776` + bytes because only `next_y` remains materialized for the T=1 terminal path. +- Axon runtime buffers moved from `15300960560` to `11542864176` bytes. + `transition_forward_diag_output` moved from `7516192768` to `3758096384` + bytes because only `preproj` remains materialized for the T=1 terminal path. +- Axon 100M h32 B1024 T1 training now completes instead of OOMing, at + `768.36 tok/s` and `121.567 GiB`. + +Next owner after this slice: + +- Runtime buffers are materially lower, and the Axon training row is measurable. + The remaining T=1 closure owner is still memory/liveness first: + - Axon reverse artifacts remain `10066329600` bytes, dominated by + `transition_state_before`, recurrent hidden before/after, and recurrent + message banks. + - Axon peak memory is still `58.7x` over the April 21 `2.07 GiB` summary + floor, so compute fusion alone is not a closure path. +- The next registered-strategy slice should attack routed hidden/message + artifact storage or executable aliasing/lifetime shortening. Reverse compute + throughput becomes the primary owner only after the remaining named memory + banks are either reduced or proven live by compiler rows. + +### 2026-05-03 - Plan To Close Highest-Impact T=1 Owner After Optional State-Carry Prune + +Status: plan only. Do not start compute-kernel tuning from this section. + +Highest-impact owner: + +- **Registered T=1 memory/liveness execution**, with Axon 100M h32 B1024 + terminal training as the boundary row. 
+
- Current evidence after optional transition state-carry pruning:
  - Axon training now fits, but only at `768.36 tok/s`, `121.567 GiB`;
  - Axon forward alone is still `4104.46 tok/s`, `75.733 GiB`;
  - Axon runtime buffers are `11542864176` bytes;
  - Axon reverse artifacts are `10066329600` bytes;
  - the largest runtime roles are
    `transition_forward_diag_output=3758096384`,
    `transition_forward_linear_output=3758096384`,
    `forward_recurrent_hidden_after=1879048192`, and
    `forward_recurrent_msg=1879048192`;
  - the largest reverse artifacts are
    `transition_state_before=3758096384`,
    `recurrent_hidden_backend_order=1879048192`,
    `recurrent_hidden_before_backend_order=1879048192`, and
    `recurrent_msg_backend_order=1879048192`;
  - `fabric_cuda_reserved_gap_bytes` is still very large, so named logical
    byte reductions must be checked against actual peak, not accepted by
    metadata alone.

Why this owner is first:

- It affects both forward and training rows. Axon forward is already
  `36.6x` over the April 21 memory floor before reverse compute starts.
- Training cannot close while peak memory is `58.7x` over April 21, even if a
  faster reverse kernel improves tok/s.
- The remaining byte owners are already compiler products: runtime buffer
  rows, memory liveness rows, artifact route rows, reverse consumer route rows,
  transition dynamic binding rows, and registered strategy access rows. That
  means the next work can be real compiler-owned throughput work rather than a
  compatibility rewrite.

Compiler-boundary rule:

- Changes must be expressed through `memory_liveness_rows`,
  `memory_runtime_schedule_rows`, runtime buffer rows, artifact route rows,
  reverse consumer route rows, registered native strategy records, or
  registered program kernels.
- Do not add Axon/sLSTM branches, benchmark-row branches, hidden-size rules,
  copied April 21 kernels, direct pybind wrappers, or primitive formulas in
  temporal scheduler code.
- Do not optimize benchmark-side chunking, detach policy, checkpoint policy, or
  private runtime helper calls.

Plan:

1. **Make runtime-buffer aliasing executable for exact-shape transition
   products.**
   - Start with compiler-proven exact-shape aliases only:
     same dtype, device, shape, bucket ordinal, non-overlapping live range, and
     no backward consumer requiring both values simultaneously.
   - First candidate group: transition forward intermediate/output buffers that
     are producer-consumer temporaries rather than user-visible outputs.
   - Express alias sets in `memory_liveness_rows` and
     `memory_runtime_schedule_rows`; have Python allocation and C++ validation
     require the same alias proof.
   - Acceptance: allocated runtime bytes and actual Axon forward/training peak
     both move down. Logical alias metadata alone is not enough.

2. **Route-prune or recompute the remaining hidden/message reverse artifacts.**
   - Use `reverse_artifact_consumer_route_rows` plus registered reverse strategy
     access rows to decide whether each producer route is `materialize`,
     `reuse_existing_storage`, `recompute_from_program_tensor`, or
     `not_required`.
   - First candidates:
     `recurrent_hidden_backend_order`,
     `recurrent_hidden_before_backend_order`, and
     `recurrent_msg_backend_order`.
   - Any recompute must be a registered native strategy phase, not temporal
     scheduler math.
+ - Acceptance: at least one additional `1879048192`-byte artifact bank + disappears, or the verifier records a route-level proof that the bank is + genuinely live for the selected reverse strategies. + +3. **Split and validate `transition_state_before` by binding demand.** + - Current Axon still stores `3758096384` bytes under + `transition_state_before`. + - Split this artifact by primitive row, bucket ordinal, binding id, and + reverse consumer route. Retain only bindings consumed by selected reverse + transition executors, tape contracts, or parameter reducers. + - Acceptance: bytes drop, or every retained binding has an auditable + consumer reason in route metadata. + +4. **Use the allocator gap as a steering gate.** + - After each named-byte reduction, rerun the isolated Axon training audit and + four-row owner table. + - If named runtime/artifact bytes move but peak stays near `121 GiB`, run a + focused memory profile before writing compute kernels. + - Candidate unclassified owners: CUDA extension temporaries, transition + native-callable temporaries, autograd saved tensors, static/prepacked + tensors, parameter-gradient buffers, and allocator reserve. + +5. **Only then plan reverse compute fusion.** + - Reverse compute becomes the next owner only after Axon T=1 training has + meaningful memory headroom and the remaining peak is named. + - The compute plan must fuse registered reverse strategy phases over the + existing rows; it must not move message/transition/readout formulas into + temporal scheduler code. + +Validation gates for each implementation slice: + +- Source/static guardrails: no fixed slot, family, benchmark, hidden-size, or + primitive formula leaks into temporal scheduler. +- Registry/catalog validation: generated native-callable catalog stays current. +- Parity: targeted registered-route sLSTM and Axon T=1 parity; include + reset-present and materialized-final-state parity if artifact/state ownership + changes. +- Perf/audit: isolated Axon T=1 training audit plus four-row h32 100M B1024 + owner table with private cache dirs and warmed iterations. +- Metadata: registered forward/backward owners remain active, primitive + executor blockers absent, and remaining byte owners named. + +First implementation slice: + +- Start with step 1: executable runtime-buffer aliasing for exact-shape + transition products. +- If alias validation proves those buffers overlap or are simultaneously live, + record the proof and move immediately to step 2, hidden/message artifact + route pruning. +- Do not start reverse compute kernels until this memory/liveness owner is + materially lower or fully proven live by compiler rows. + +### 2026-05-03 - Runtime Public-State Alias Slice + +Status: implemented and measured; not enough to close the T=1 training owner. + +Compiler-owned change: + +- Added an executable runtime alias class for the transition public-state output + and `forward_recurrent_hidden_after` buffer. +- The alias is selected from `forward_program_access_rows` via the registered + `transition_public_state_output` opcode, not from cell names or benchmark + shapes. +- The alias is legal only when: + - the memory runtime policy enables scheduler aliasing; + - the physical runtime buffer plan has exactly one recurrent-hidden-after + step; + - the transition public-state output and recurrent-hidden-after buffer have + identical shape, dtype, device, and empty-init policy. 
+- Added C++ self-copy elision in the registered fused forward program so an + aliased public-state output does not call `copy_` into the exact same storage. + +Files changed: + +- `src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py` +- `src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py` +- `src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/forward_program.cuh` +- `tests/test_fabric_backend_plan.py` + +Validation: + +- `python -m py_compile src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py` +- `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes --tb=short` + - `1 passed` +- `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows --tb=short` + - `1 passed` +- `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_forward_program_access_and_state_carry_rows_are_compiler_owned tests/test_fabric_backend_boundaries.py::test_forward_transition_access_uses_compiler_program_access_rows tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + - `3 passed` +- `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_t1_alias_nocopy_torch_extensions TRITON_CACHE_DIR=/tmp/cortical_t1_alias_nocopy_triton_cache uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_readout_closed_region_matches_pytorch_reference --tb=short` + - `4 passed` + +Perf/audit commands: + +- Isolated Axon training boundary: + `CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_t1_public_state_alias_nocopy_axon_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_public_state_alias_nocopy_axon_20260503 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir tmp/fabric_audits/partials/2026-05-03/t1_public_state_runtime_alias_nocopy_axon_h32_100m_b1024 --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families axoncell --sizes 100m --modes forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner` +- Final four-row owner table: + `CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_t1_public_state_alias_nocopy_owner_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_public_state_alias_nocopy_owner_20260503 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir tmp/fabric_audits/partials/2026-05-03/t1_public_state_runtime_alias_nocopy_owner_table_h32_100m_b1024 --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families slstm,axoncell --sizes 100m --modes forward,forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none 
--checkpoint-steps none --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner` + +Final owner table: + +| Row | Status | tok/s | peak GiB | runtime bytes | reverse artifact bytes | Notes | +| --- | --- | ---: | ---: | ---: | ---: | --- | +| sLSTM 100M h32 B1024 T1 forward | ok | `5608.56` | `18.274` | `7449223472` | `0` | Runtime allocation down from `8053203248`; forward peak down from `18.837`. | +| sLSTM 100M h32 B1024 T1 training | ok | `952.73` | `50.321` | `7449223472` | `4731174912` | Runtime allocation down, but peak unchanged; training owner is elsewhere. | +| Axon 100M h32 B1024 T1 forward | ok | `4112.55` | `73.983` | `9663815984` | `0` | Runtime allocation down from `11542864176`; forward peak down from `75.733`. | +| Axon 100M h32 B1024 T1 training | OOM | - | - | `9663815984` | `10066329600` | Still OOMs while trying to allocate `1.75 GiB`; named runtime/artifact bytes are `19730145584`, but `fabric_unclassified_cuda_peak_bytes` is `128021193552`. | + +Conclusion: + +- This slice is a valid compiler-owned throughput strategy and it physically + reduces runtime allocation for exact-shape T=1 public-state output. +- It does **not** close the highest-impact training row. The final no-copy + rerun still OOMs and should be treated as a failed training probe, not a + successful memory/liveness closure. +- Do not continue expanding this alias route until the hidden allocator owner is + explained. + +Comparison against the last successful Axon training row: + +| Probe | Status | tok/s | peak/max allocated | runtime bytes | reverse artifact bytes | named runtime+artifact bytes | unclassified peak bytes | Notes | +| --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | --- | +| `t1_optional_transition_state_carry_h32_100m_b1024` | ok | `768.36` | `121.567 GiB` / `130531166208` | `11542864176` | `10066329600` | `21609193776` | `108122587344` | Last successful Axon training row. | +| `t1_public_state_runtime_alias_nocopy_owner_table_h32_100m_b1024` | OOM | - | OOM at `148153534464` max allocated | `9663815984` | `10066329600` | `19730145584` | `128019358544` | Named bytes dropped by `1879048192`, but max allocated rose by about `17622368256` and the row OOMed trying to allocate another `1.75 GiB`. | + +Interpretation: + +- The alias changed the compiler runtime-buffer allocation as intended: + `planned_buffer_bytes=11542864176` while + `estimated_allocated_buffer_bytes=9663815984`, and the runtime summary shows + `forward_recurrent_hidden_after_step_0` and + `transition_forward_linear_output_row_20_public_y` sharing + `runtime_alias.transition_public_state.1024,14336,32`. +- The alias did **not** explain or reduce the actual training peak. The + successful row ended with full parameter gradients allocated + (`model_parameter_grad_bytes=399688320`), while the failed no-copy row OOMed + before that point (`model_parameter_grad_bytes=4333568`) with roughly + `65 GiB` still allocated. That means the `128 GiB` unclassified peak is not a + simple final allocator-reserve gap; it is actual live allocation at failure, + or a CUDA temporary/autograd-saved lifetime that the current ledger does not + classify. +- Current evidence cannot distinguish the hidden owner between a lifetime bug, + autograd saved tensor, CUDA temporary, or allocator ordering/fragmentation. + The next step must instrument and name that hidden owner before adding more + alias metadata. 
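For reference, the alias legality conditions listed earlier in this slice reduce to a small fail-closed predicate. A minimal sketch, assuming hypothetical field names rather than the real `memory_plan.py` schema:

```python
# Illustrative only: the real check reads compiler rows such as
# forward_program_access_rows; these dataclass fields are stand-ins.
from dataclasses import dataclass


@dataclass(frozen=True)
class BufferSpec:
    shape: tuple[int, ...]
    dtype: str
    device: str
    empty_init: bool


def public_state_alias_is_legal(
    scheduler_aliasing_enabled: bool,
    recurrent_hidden_after_steps: int,
    public_state_output: BufferSpec,
    recurrent_hidden_after: BufferSpec,
) -> bool:
    """Fail closed unless every documented alias condition holds."""
    if not scheduler_aliasing_enabled:
        return False
    # Only proven for a physical runtime plan with exactly one
    # recurrent-hidden-after step.
    if recurrent_hidden_after_steps != 1:
        return False
    # Exact-shape aliasing: identical shape, dtype, device, empty-init policy.
    return public_state_output == recurrent_hidden_after
```

The equality check is the point: anything short of an exact shape/dtype/device/empty-init match must refuse the alias rather than approximate it.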
+ +Alias decision: + +- Keep the byte-accounting diagnostics and the failed-probe documentation. +- Do **not** keep the public-state alias active for training unless the hidden + `fabric_unclassified_cuda_peak_bytes` owner is explained. +- If the alias is retained, narrow it to forward-only inference/no-artifact + execution (`collect_artifacts=False`) because forward peak moved down and + parity passed. The training path should use the previous non-aliased storage + until the hidden allocator owner is identified. +- If a quick narrow is not acceptable, revert the active alias pieces: + `alias_runtime_role`, the `runtime_alias.transition_public_state.*` alias + selection/validation helpers, the `forward_program_access_rows` propagation + into transition runtime-buffer requests, and the public-state alias test. The + C++ self-copy skip is harmless but only useful if a forward-only alias remains. + +Next owner: + +- Name the hidden allocator/unclassified owner first. Instrument the Axon T=1 + training row around forward artifact creation, registered reverse readout, + message reverse, transition reverse, parameter reducers, and optimizer/grad + materialization. Attribute the `~128 GiB` unclassified peak before doing more + runtime-buffer aliasing or reverse compute fusion. + +### 2026-05-03 - Plan To Close Highest-Impact T=1 Owner After Failed Alias Probe + +Status: plan only. Do not optimize or expand aliasing from this section until +the hidden allocator owner is named. + +Highest-impact owner: + +- **Hidden unclassified training allocation in the registered Axon T=1 path**. +- Current boundary row: + `Axon 100M h32 B1024 T=1 terminal training`. +- The failed no-copy alias probe reduced named runtime/artifact bytes from + `21609193776` to `19730145584`, but max allocated rose from + `130531166208` to `148153534464` and the row OOMed. +- The active compiler path still reports registered forward/backward owners. + This is not an ownership-selection failure; it is a memory/lifetime + attribution and liveness failure inside the registered path. + +Compiler-boundary rules: + +- Do not add more public-state/runtime aliasing until the `~128 GiB` + unclassified peak is attributed. +- Do not add Axon/sLSTM branches, benchmark-row branches, hidden-size policy, + direct CUDA wrappers, copied April 21 code, or primitive formulas in temporal + scheduler code. +- Any kept fix must be expressed through compiler products: + memory/liveness rows, runtime schedule rows, artifact route rows, reverse + consumer rows, executor binding rows, native strategy contracts, reducer rows, + or audit metadata. + +Plan: + +1. **Freeze or narrow the failed alias path before measuring.** + - Preferred: disable the public-state alias for training/artifact-producing + rows and keep it only for forward-only `collect_artifacts=False` if that + path still passes parity and lowers peak. + - Alternative: revert the active alias pieces entirely and keep only the + byte-ledger diagnostics plus C++ self-copy skip if still useful. + - Acceptance: the next Axon training measurement starts from the last + successful non-aliased row, not from the failed OOM alias probe. + +2. 
**Add stage-scoped memory checkpoints for the registered backward program.** + - Record allocated/reserved/max-allocated bytes before and after: + forward runtime buffer allocation, artifact tensor-store construction, + readout reverse, message reverse, transition reverse, recurrent K/V + recompute, boundary projection backward, parameter reducer execution, + grad materialization, and autograd return. + - Include CUDA synchronize only around measurement gates, not as a permanent + execution policy. + - Acceptance: the ledger names the stage where the unclassified peak appears + or grows by tens of GiB. + +3. **Classify live tensors at the peak by owner.** + - For Fabric tensors created by registered code, attach or report: + row id, artifact role, runtime role, bucket, shape, dtype, storage data + pointer, live interval, and producer/consumer route. + - For non-Fabric tensors, report model parameter bytes, parameter-gradient + bytes, static/prepacked tensor bytes, PyTorch saved-tensor bytes where + discoverable, CUDA extension temporary class, and allocator reserve gap. + - Acceptance: the difference between named Fabric bytes and max allocated is + split into concrete classes; large "unclassified" Fabric allocations are a + blocker. + +4. **Use the named owner to select the reducing patch.** + - If the peak is reverse artifacts: continue route-pruning hidden/message + artifacts or split `transition_state_before` by binding demand. + - If the peak is transition reverse temporaries: add a registered native + strategy workspace contract and planner-owned scratch/lifetime rows. + - If the peak is autograd saves: change the custom autograd save contract so + only compiler-declared tensors are saved; missing rows fail closed. + - If the peak is parameter reducers/grad materialization: make reducer output + lifetimes and accumulation buffers compiler-owned rows. + - If the peak is allocator fragmentation/reserve with no live tensor owner: + adjust allocation order/reuse through memory runtime schedule rows, not + benchmark-side workarounds. + +5. **Run the minimal gates after the first reducing patch.** + - Source/static guardrail: no family, benchmark, hidden-size, fixed-slot, or + primitive-formula leaks. + - Parity: registered CUDA route plus sLSTM/Axon T=1 single-pop parity for + touched routes; reset/materialized-final-state parity if state/tape policy + changes. + - Performance: isolated Axon T=1 training row plus the four-row h32 100M + B1024 owner table. + - Acceptance: Axon training fits with clear headroom, peak moves down, and + the remaining peak has named compiler/runtime owners. + +Next implementation slice: + +- Start with steps 1 and 2: narrow/revert the training alias path and add + stage-scoped memory checkpoints around the registered backward program. +- Do not write reverse compute kernels, do not add more alias groups, and do + not tune benchmark policy until the hidden allocator owner is named. + +Implementation result: + +- The public-state runtime alias is now explicit opt-in in + `build_temporal_runtime_buffer_plan`. +- The registered fused forward path opts into the alias only when + `collect_artifacts=False`. +- The registered training/backward path explicitly passes + `enable_public_state_runtime_alias=False`, so artifact-producing training + rows do not use the failed no-copy alias route. +- The registered backward program now emits stage memory rows into backend + metadata while it runs. 
OOM rows append the last error stage to the existing + execution record, so the audit memory ledger can name the failing owner. +- `benchmarks/fabric/suite_common.py` now parses + `flat_bucket_temporal_registered_backward_memory_stage:*` rows into the + memory ledger as per-stage allocated/reserved/max-allocated fields. + +Targeted evidence: + +- Static/compiler checks: + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: `173 passed`. +- Focused memory-plan check: + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes --tb=short` + passed. +- Focused audit-runner check: + `uv run pytest -q tests/test_fabric_audit_runner.py --tb=short` + passed: `19 passed`. +- Exact Axon owner row after alias narrowing: + `tmp/fabric_audits/partials/2026-05-03/t1_public_state_alias_narrowed_errorstage_axon_h32_100m_b1024`. + Result: still OOM, but now attributed. + +Exact Axon row memory ledger: + +| Row | Status | Runtime bytes | Reverse artifact bytes | Named runtime/artifact bytes | Max allocated | Unclassified peak | Peak stage | +|---|---:|---:|---:|---:|---:|---:|---| +| `t1_public_state_alias_narrowed_errorstage_axon_h32_100m_b1024` | OOM | `11542864176` | `10066329600` | `21609193776` | `148155369472` | `126142145360` | `fused_backward_program_error` | + +Stage ledger: + +| Stage | Allocated | Reserved | Max allocated | +|---|---:|---:|---:| +| `reverse_tensor_table_built` | `53534444032` | `85435875328` | `85046158336` | +| `runtime_buffer_plan_built` | `53534444032` | `85435875328` | `85046158336` | +| `runtime_buffers_allocated` | `65077216768` | `91341455360` | `85046158336` | +| `before_fused_backward_program` | `65077216768` | `91341455360` | `85046158336` | +| `fused_backward_program_error` | `65077216768` | `148232994816` | `148155369472` | + +Interpretation: + +- The alias did **not** remain active for training: named runtime/artifact bytes + returned to the previous non-aliased value, `21609193776`. +- The `~148.15GB` max-allocated peak is not explained by runtime buffers or + reverse artifact tensors. It appears inside + `registered_temporal_fused_backward_program_cuda` before span outputs return. +- Allocated bytes at the error stage fall back to `65077216768`, while reserved + and max allocated hit `~148GB`; this points at a large CUDA/PyTorch temporary + or internal span-output/lifetime burst inside the fused backward program, not + allocator reserve gap alone and not the public-state alias route. +- Keep: alias narrowing, stage ledger metadata, and ledger parsing. +- Revert/narrow: do not keep public-state runtime alias for any + artifact-producing training row until the fused backward internal temporary + owner is reduced or proven harmless. +- Next owner: split the fused backward program internal peak by C++ stage + (`readout`, `transition`, `message`, boundary/reducer staging) and move the + hot temporary into compiler-owned workspace/lifetime rows or eliminate the + materialization. + +### 2026-05-03 - Plan To Close Highest-Impact T=1 Owner: Registered Backward Span Materialization + +Owner: + +- Current failing representative row: + `t1_public_state_alias_narrowed_errorstage_axon_h32_100m_b1024`. +- Peak: `148155369472` max allocated, with `126142145360` unclassified after + named runtime/artifact bytes. 
+- Current stage owner: `registered_temporal_fused_backward_program_cuda`, with + the peak appearing between `before_fused_backward_program` and + `fused_backward_program_error`. +- Likely high-impact sub-owner from source inspection: reverse span output + materialization. The C++ program accumulates nested step/span output groups + and stabilizes every defined tensor through `contiguous().clone()` before + returning them to Python. + +Boundary audit: + +- This is throughput strategy work over existing compiler products, not a + semantic extension. +- Unchanged semantic rows: primitive rows, forward executor rows, reverse + executor rows, tensor binding rows, artifact route rows, output route rows, + message/cell/readout primitive semantics, and parameter-gradient meanings. +- Changed strategy/runtime rows: reverse program stage telemetry, reverse span + output materialization/lifetime policy, runtime buffer ownership, and reducer + consumption policy. +- Tensor/route/liveness rows consumed directly: `reverse_span_output_rows`, + `reverse_output_route_rows`, `reverse_parameter_reducer_route_rows`, + `memory_liveness_rows`, `memory_runtime_schedule_rows`, `runtime_buffer_rows`, + `reverse_program_access_rows`, and transition dynamic binding rows. +- Old route to delete or fail-close: returning all full reverse span output + groups to Python for supported registered rows. Unsupported rows may fail + before launch, but supported registered rows should use compiler-owned output + lifetimes. + +Plan: + +1. Add C++ internal stage memory telemetry before optimization. + - Record allocated/reserved/max-allocated around readout reverse, + output-message reverse, recurrent-KV reverse, transition reverse, + recurrent-message reverse, boundary-KV reverse, initial recurrent-KV + reverse, span-output stabilization, and return. + - Emit rows using the registered backward program stage metadata path, not + ad hoc benchmark labels. + - Acceptance: the exact Axon T=1 row names the internal stage that creates + the 128GB unclassified peak. + +2. Confirm whether the clone/retained-output hypothesis is true. + - If the peak is in `append_stable_reverse_program_output_groups` or + `stable_reverse_program_output_group`, move to step 3. + - If the peak is inside a native callable stage, keep the owner on that + stage and add a native strategy workspace/lifetime contract instead. + - If the peak is allocator reserve without live tensor growth, solve it + through memory runtime schedule ordering and reuse rows, not benchmark + knobs. + +3. Add compiler-owned reverse output materialization policy. + - Extend the reverse output/liveness plan so each output role declares one + of: return to Python, consume by reducer, alias runtime buffer, accumulate + into carry/state buffer, or skip after local use. + - Validate this policy in Python and C++ before launch. + - Keep global required outputs only: boundary gradients, carry/state seeds, + and compact reducer outputs or reducer-owned buffers. + - Do not infer policy from role names, cell families, hidden size, benchmark + row, or old fixed-slot layout. + +4. Stream or consume reducer outputs instead of retaining every span group. + - Use `reverse_parameter_reducer_route_rows` to route readout/message/ + transition parameter-gradient inputs directly to reducer-owned buffers or + compact route outputs. + - Avoid storing full intermediate front/boundary groups across the whole + backward program when the next consumer is known from compiler rows. 
+ - Acceptance: Python no longer receives all intermediate span groups for + supported registered rows. + +5. Replace stable clone behavior with row-owned lifetime rules. + - Remove unconditional `contiguous().clone()` for outputs backed by stable + runtime buffers or consumed before return. + - Clone only outputs whose row policy explicitly requires independent + returned storage. + - Add validation that aliasable outputs point at compiler-owned runtime + buffers with compatible shape, dtype, and lifetime interval. + +6. Run gates in this order. + - Static/compiler: + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + - Focused parity on the registered temporal physical-autograd route for + sLSTM and Axon T=1, with output/state/input/carry/parameter gradients. + - Reset/materialized-final-state parity if carry/state lifetime rows change. + - Exact perf row: + `t1_public_state_alias_narrowed_errorstage_axon_h32_100m_b1024`, rerun + under a new artifact directory after the first reducing patch. + - Non-regression perf: the four-row T=1 owner table after the Axon row fits. + +Acceptance: + +- The representative Axon T=1 training row fits with named memory owners. +- Max allocated moves down because the registered backward program no longer + retains/clones unnecessary reverse span outputs. +- Runtime metadata still reports compiler-owned registered forward/backward + owners and no hidden fallback/replay/compat route. +- No primitive formulas, family selectors, benchmark selectors, or fixed-slot + aliases are added to scheduler or benchmark code. + +Implementation result: + +- Changed the registered fused backward program return lifetime policy: + T=1 reverse windows now return compiler-owned span-output tensors without + unconditional `contiguous().clone()`; multi-step windows keep the stabilizing + clone path because later reverse steps may overwrite runtime-backed tensors. +- Skipped transition next-seed materialization at `local_step == 0`. There is + no earlier reverse step to consume those seeds, so the final-step seed clone + was pure lifetime waste under the compiler schedule. +- Added a compact single-executor return route: when there is exactly one + readout or message executor row, Python routes that executor's reducer inputs + through the aggregate front/boundary compiler rows instead of requiring + duplicate per-executor groups from C++. +- This stayed inside registered compiler-owned products: primitive rows, + executor rows, reverse span output rows, reverse output/reducer route rows, + memory liveness rows, runtime buffer rows, and transition seed rows. No + primitive math, family selector, benchmark selector, or fixed-slot alias was + added. +- Internal C++ sub-stage telemetry is still open. This pass used the existing + Python-visible registered backward stage ledger and a targeted lifetime patch; + the next owner needs either finer C++ stage rows or direct in-program reducer + consumption. + +Targeted checks: + +- Static/compiler: + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: `173 passed`. 
+- T=1 CUDA parity: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_span_policy_v2 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_span_policy_v2 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_pooled_output_uses_registered_reverse_program_window tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients --tb=short` + passed: `4 passed`. +- T>1 registered-window parity guard: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_span_policy_v2 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_span_policy_v2 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: `2 passed`. +- Python compile sanity: + `python -m py_compile src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py` + passed. + +Perf evidence: + +| Row | Status | Tokens/s | Step ms | Peak GiB | Max allocated | Unclassified peak | Notes | +|---|---:|---:|---:|---:|---:|---:|---| +| `t1_public_state_alias_narrowed_errorstage_axon_h32_100m_b1024` | OOM | n/a | n/a | n/a | `148155369472` | `126142145360` | pre-patch fused backward error | +| `t1_backward_span_policy_axon_h32_100m_b1024` | ok | `774.1936` | `1322.6665` | `100.7468` | `108176049664` | `85767470800` | T=1 no-clone + final-step seed elision | +| `t1_backward_compact_span_outputs_axon_h32_100m_b1024` | ok | `774.0694` | `1322.8789` | `100.7468` | `108176049664` | `85767470800` | plus single-executor duplicate group elision | + +Interpretation: + +- The exact Axon T=1 owner row now completes through the registered temporal + owner. This is a real owner move, not a metadata-only change. +- Max allocated moved down by `39979319808` bytes from the previous OOM row. +- The remaining peak is still too high: after the fused backward call, allocated + memory jumps from `65077216768` to `100005545472`. That means the next owner + is still inside the fused backward program's returned outputs/native + temporaries, not the public-state alias path and not duplicate single-executor + groups. +- The compact single-executor route is active in metadata via + `temporal_backward_glue:registered_fused_backward_program_single_executor_span_output_elision`, + but it did not materially change the large peak. Keep it as a small + route/lifetime cleanup, but do not count it as throughput closure. +- The summary file marks shared temporal coverage as failed only because this + command intentionally ran single-population only. The case status itself is + `ok`; mixed-pop T=1 remains a required closure row. + +Next owner: + +- Add internal C++ stage memory telemetry or direct route-owned reducer/carry + consumption inside `registered_temporal_fused_backward_program_cuda`. +- Specifically split the remaining `~35GB` after-fused-backward allocation + across transition outputs, message/recurrent-KV outputs, boundary outputs, + and parameter reducer inputs. +- The next reducing patch should stop returning full transition/message + intermediate tensors to Python when compiler rows prove they can be consumed + by reducers or accumulated into carry/state buffers inside the registered + program. 
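The stage ledger rows used throughout this document are snapshots of the standard CUDA allocator counters taken at named points. A minimal sketch, assuming a hypothetical `record_stage` helper and stage names; the real rows are emitted through the registered backward stage metadata path, not this helper:

```python
# Hedged sketch of a stage-scoped CUDA memory ledger; not the production path.
import torch


def record_stage(ledger: list[dict], stage: str, device: int = 0, sync: bool = False) -> None:
    if sync:
        # Synchronize only around measurement gates, never as a permanent policy.
        torch.cuda.synchronize(device)
    ledger.append(
        {
            "stage": stage,
            "allocated": torch.cuda.memory_allocated(device),
            "reserved": torch.cuda.memory_reserved(device),
            "max_allocated": torch.cuda.max_memory_allocated(device),
        }
    )


# Usage around a hypothetical backward launch:
# ledger: list[dict] = []
# record_stage(ledger, "before_fused_backward_program")
# run_fused_backward_program(...)  # placeholder for the registered program call
# record_stage(ledger, "after_fused_backward_program", sync=True)
```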
+ +### 2026-05-03 - Registered Transition-State Grad Liveness + +Status: implemented and measured; improvement, not throughput closure. + +Change: + +- The registered reverse launch now passes a compiler-owned + `return_window_start_transition_state_grads` liveness flag into + `registered_temporal_fused_backward_program_cuda`. +- The flag is true when the initial/window-start state can observe those + gradients (`state_requires_grad`) or when the window must hand private-state + gradients to an earlier window (`window_start > 0`). +- The fused reverse program still returns transition-state gradients for + earlier local steps (`local_step > 0`), so T>1 reverse seeds remain intact. +- For the T=1 terminal/no-initial-state-grad row, transition handlers no longer + allocate previous private-state banks solely to return empty/unconsumed state + gradients. +- This is strategy/liveness work over registered compiler products only: + primitive rows, executor rows, reverse seed rows, memory liveness rows, and + registered transition handlers. No cell-family or benchmark branch was added. + +Targeted checks: + +- Python compile sanity: + `python -m py_compile src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py` + passed. +- Focused static/compiler: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows tests/test_fabric_backend_boundaries.py::test_reverse_transition_native_handlers_use_logical_binding_schema tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_plan.py::test_temporal_backward_requires_registered_reverse_binding_plan --tb=short` + passed: `4 passed`. +- Full static/compiler: + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: `173 passed`. +- T=1 CUDA parity plus state-gradient guard: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_transition_state_liveness_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_transition_state_liveness_20260503 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific --tb=short` + passed: `4 passed`. + +Perf evidence: + +| Row | Status | Tokens/s | Step ms | Peak GiB | Max allocated | Unclassified peak | Notes | +|---|---:|---:|---:|---:|---:|---:|---| +| `t1_backward_compact_span_outputs_axon_h32_100m_b1024` | ok | `774.0694` | `1322.8789` | `100.7468` | `108176049664` | `85767470800` | pre-patch compact span route | +| `t1_transition_state_grad_liveness_axon_h32_100m_b1024` | ok | `777.7296` | `1316.6529` | `97.2468` | `104417953280` | `82009374416` | transition private-state grads returned only when window-start liveness needs them | + +Interpretation: + +- Max allocated dropped by `3758096384` bytes, matching one large Axon + transition-state bank at this shape. 
+- Allocated bytes immediately after the fused backward call dropped from + `100005545472` to `96247449088`. +- Runtime/artifact named bytes stayed fixed at `21609193776`, so this patch + reduced a native registered reverse output allocation rather than hiding it + in metadata. +- The row is still far from April 21 (`777.73 tok/s`, `97.25 GiB` vs + `58732.71 tok/s`, `2.07 GiB`), so this is not throughput closure. + +Next owner: + +- Continue on the same registered fused backward program. +- The remaining after-fused-backward allocation is still `~31.17 GB` above + the pre-call stage. Split or eliminate it by routing transition/message + reducer inputs and carry/state outputs directly through compiler-owned + reducer/runtime buffers instead of returning full intermediate groups to + Python. +- If the next patch cannot reduce this further, add C++ sub-stage memory rows + inside readout, output-message, recurrent-KV, transition, recurrent-message, + boundary-KV, and initial-KV stages before changing strategy behavior again. + +### 2026-05-03 - Registered Local-Only Reverse Span Output Elision + +Status: implemented and measured; live allocation improvement, not throughput +closure. + +Change: + +- Marked reverse front-span outputs that are consumed inside the C++ reverse + step as optional returned compiler rows: + `grad_recurrent_hidden_backend_direct`, `grad_input_k_from_output`, + `grad_input_v_from_output`, and + `grad_recurrent_hidden_from_kv_graph_order`. +- The registered fused backward program now consumes those tensors locally for + transition/message-boundary work and returns empty route slots instead of + keeping large Python-visible references alive after the C++ call. +- Message strategy extra key-bank outputs are now materialized only when the + active reverse message native callable returns the fixed-slot/context extra + output arity. Default neighborhood attention no longer returns + `grad_input_key_bank` and `grad_recurrent_key_bank` as unused global extra + reducer roles. +- Python no longer resolves local-only `output_grad`/direct hidden routes after + the fused reverse program returns. +- This is a registered throughput/liveness strategy over existing compiler + products: reverse span output rows, reverse output/reducer route rows, + native callable output arity, memory liveness rows, and runtime buffer rows. + No primitive math, family selector, benchmark selector, hidden-size branch, or + fixed-slot temporal ABI was added. + +Targeted checks: + +- Focused static/compiler: + `uv run pytest -q tests/test_fabric_backend_plan.py::test_reverse_span_outputs_are_compiler_owned_rows tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: `2 passed`. +- Full static/compiler: + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: `173 passed`.
+- T=1/T>1 CUDA parity: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_local_only_span_elision2_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_local_only_span_elision2_20260503 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific --tb=short` + passed: `4 passed`. + +Perf evidence: + +- Exact Axon owner row: + `CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_local_only_span_elision2_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_local_only_span_elision2_20260503 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir tmp/fabric_audits/partials/2026-05-03/t1_local_only_span_output_elision_axon_h32_100m_b1024 --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families axoncell --sizes 100m --modes forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner` + passed with one `ok` case. +- Four-row owner table: + `CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_local_only_span_elision2_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_local_only_span_elision2_20260503 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir tmp/fabric_audits/partials/2026-05-03/t1_local_only_span_output_elision_owner_table_h32_100m_b1024 --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families slstm,axoncell --sizes 100m --modes forward,forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner` + passed with four `ok` cases. 
+ +| Row | Status | Tokens/s | Step ms | Peak GiB | Max allocated | Unclassified peak | After fused backward allocated | +|---|---:|---:|---:|---:|---:|---:|---:| +| `t1_transition_state_grad_liveness_axon_h32_100m_b1024` | ok | `777.7296` | `1316.6529` | `97.2468` | `104417953280` | `82009374416` | `96247449088` | +| `t1_local_only_span_output_elision_axon_h32_100m_b1024` | ok | `776.8233` | `1318.1891` | `97.2468` | `104417953280` | `82009374416` | `90207651328` | + +Four-row owner table after the patch: + +| Row | Status | Tokens/s | Step ms | Peak GiB | Max allocated | Runtime bytes | Reverse artifact bytes | After fused backward allocated | +|---|---:|---:|---:|---:|---:|---:|---:|---:| +| sLSTM 100M h32 B1024 T1 forward | ok | `5699.1491` | `179.6759` | `18.2742` | `19621811200` | `7449223472` | `0` | n/a | +| sLSTM 100M h32 B1024 T1 training | ok | `990.5307` | `1033.7893` | `29.0275` | `31168031232` | `8053203248` | `4731174912` | `24968391168` | +| Axon 100M h32 B1024 T1 forward | ok | `4116.3304` | `248.7653` | `73.9835` | `79439160320` | `9663815984` | `0` | n/a | +| Axon 100M h32 B1024 T1 training | ok | `777.7564` | `1316.6076` | `97.2451` | `104416118272` | `11542864176` | `10066329600` | `90205816320` | + +Interpretation: + +- The patch removed about `6039797760` bytes of live allocation after the fused + backward call on the exact Axon row: + `96247449088 -> 90207651328`. +- The named runtime/artifact bytes stayed fixed at `21609193776`, so the + movement came from returned/native reverse outputs, not metadata relabeling. +- Max allocated did not move (`104417953280` before and after). The remaining + peak therefore occurs inside the C++ fused backward/native callable lifetime, + before the Python-visible post-call stage can observe the reduced live set. +- Runtime metadata includes + `temporal_backward_glue:registered_fused_backward_program_local_only_span_output_elision`, + so the active path proves this strategy ran. +- The row is still far from April 21 (`~777.76 tok/s`, `97.25 GiB` vs + `58732.71 tok/s`, `2.07 GiB`), so this is not throughput closure. + +Next owner: + +- Add internal C++ sub-stage memory telemetry for + `registered_temporal_fused_backward_program_cuda` around readout, + output-message, recurrent-KV, transition, recurrent-message, boundary-KV, + initial-recurrent-KV, span-output packing, and return. +- The next reducing patch should target the stage that creates the unchanged + `104.42GB` max allocation. Candidate owners are native callable temporaries + and reducer-input tensors before they are dropped from the returned span + groups. +- Do not expand public-state aliasing and do not start compute-only reverse + tuning until the remaining max-allocated owner is named. + +### 2026-05-03 - Plan To Close Highest-Impact T=1 Owner After Local-Only Span Elision + +Status: implementation in progress. The proceed prompt selected this owner. + +Highest-impact owner: + +- **Registered fused backward program internal memory lifetime**, on + `Axon 100M h32 B1024 T=1 terminal training`. +- Current evidence: + - exact Axon row completes at about `777.76 tok/s`; + - peak remains `97.25 GiB`, `104416118272` max allocated; + - named runtime/artifact bytes are about `21.61B`; + - live allocation after fused backward dropped by about `6.04B`, but max + allocated did not move; + - therefore the remaining peak is created inside the C++ fused + backward/native-callable lifetime, before Python can observe returned + output pruning. 
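The span-materialization plan earlier in this document calls for a per-output lifetime policy instead of unconditional returns. A minimal sketch of that idea, with an illustrative enum and row fields that are not the real `reverse_span_output_rows` schema:

```python
# Illustrative per-output lifetime policy; the values mirror the prose above.
from enum import Enum


class ReverseOutputLifetime(Enum):
    RETURN_TO_PYTHON = "return_to_python"
    CONSUME_BY_REDUCER = "consume_by_reducer"
    ALIAS_RUNTIME_BUFFER = "alias_runtime_buffer"
    ACCUMULATE_INTO_CARRY = "accumulate_into_carry"
    SKIP_AFTER_LOCAL_USE = "skip_after_local_use"


def resolve_output_lifetime(row: dict) -> ReverseOutputLifetime:
    """Resolve the declared lifetime for one reverse span output row, failing closed."""
    declared = row.get("lifetime_policy")
    if declared is None:
        # Fail closed: a supported registered row must declare its lifetime
        # explicitly, never infer it from role names, cell families, or
        # benchmark shape.
        raise ValueError(
            f"reverse output row {row.get('row_id')!r} declares no lifetime policy"
        )
    return ReverseOutputLifetime(declared)
```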
+ +Boundary manifest: + +- Unchanged semantic rows: current message, readout, transition, boundary, + reducer, reset, artifact, and output-route semantics. +- Changed strategy/runtime rows allowed: registered reverse program stage + telemetry rows, memory/liveness rows, native callable workspace/lifetime + rows, reverse span output rows, reverse artifact consumer route rows, and + reducer route rows. +- Bindings consumed directly: primitive rows, executor rows, native callable + binding/output rows, reverse program access rows, forward artifact route rows, + reverse artifact consumer rows, memory liveness rows, runtime buffer rows, and + reducer rows. +- Old route to avoid: no fixed-slot wrapper, no direct pybind helper, no + benchmark-side tiling/detach/checkpoint, no Axon/sLSTM or hidden-size branch, + no primitive formula inside temporal scheduler code. + +Plan: + +1. **Add C++ internal memory telemetry, not an optimization.** + - Add a registered fused-backward stage ledger around: + readout reverse, output-message reverse, recurrent-K/V projection reverse, + transition reverse groups, recurrent-message reverse, boundary K/V + projection reverse, initial recurrent K/V projection reverse, span-output + packing, and return packing. + - Record allocated, reserved, and max-allocated bytes at each stage through + runtime metadata/audit output. + - Keep the telemetry generic: stage rows and runtime metadata only, no family + or benchmark row keys. + - Acceptance: the exact Axon row names the stage where the `104.42GB` max + allocation is produced. + +2. **Pick one reducing patch from the measured stage.** + - If the peak is span-output packing, replace materialized output groups with + compiler-routed in-place consumers or single-owner runtime buffers. + - If the peak is a native callable temporary, add a registered native + strategy workspace/lifetime contract and make that callable write into + compiler-owned runtime buffers where legal. + - If the peak is reducer input materialization, make reducer routes consume + local stage outputs before global materialization. + - If the peak is transition reverse group output lifetime, split + transition-output consumers so recurrent-message gradients are accumulated + into the registered runtime buffer before keeping full group outputs alive. + - Acceptance: max allocated, not just after-call live allocation or named + logical bytes, moves down on the exact Axon row. + +3. **Keep compute tuning blocked until memory has headroom.** + - Do not optimize reverse math or launch count while the row remains around + `97 GiB`. + - After the stage peak moves, rerun the exact Axon row and the four-row T=1 + owner table to confirm sLSTM and Axon share the same registered program + path. + +4. **Parity and guardrails.** + - Focused source/compiler tests: + native-callable strategy locality, fused CUDA launch contract, memory + liveness rows, reverse span output rows, reducer route rows. + - CUDA parity: + T=1 terminal loss registered route, provided-state gradient route, + final-state-only loss route, and at least one T>1 guard proving the change + is not a T=1-only scheduler branch. + - Perf: + exact Axon owner row plus the four-row h32 100M B1024 owner table. + +Closure for this owner: + +- Exact Axon T=1 training peak moves materially below the current + `104416118272` max allocated. +- The moved bytes are attributed to a compiler-owned stage, route, liveness + policy, workspace contract, or reducer contract. 
+- Runtime metadata proves the registered fused backward program path ran. +- No public-state aliasing expansion, hidden fallback, fixed slot wrapper, + benchmark-owned route, or primitive formula leak is introduced. + +Implementation update: + +- Added native allocator stage rows inside the registered fused backward + program. The C++ program now appends a stripped telemetry group with + `(local_step, stage_id, allocated, reserved, max_allocated)` rows, and the + Python registered executor records those rows through the existing + `flat_bucket_temporal_registered_backward_memory_stage` audit ledger. +- Added stage names for the hidden C++ interval that Python could not observe: + native entry, grad-cells seed, readout reverse, output-message reverse, + recurrent-K/V reverse, front-output declaration, transition reverse, + recurrent-message runtime-buffer accumulation, recurrent-message reverse, + boundary-K/V reverse, initial recurrent-K/V reverse, boundary-output + declaration, step return, seed update, carry update, stable append, and + final return. +- Tightened registered fused backward liveness for the single-executor span + case. After the compiler-declared front outputs are built, the strategy now + drops unused readout/output-message/recurrent-KV intermediate vectors and + elided single-span groups before entering transition and recurrent-message + reverse stages. This is a compiler-owned registered strategy change: it uses + span/output rows and does not add family, benchmark, hidden-size, or fixed + slot branches. + +Checks started for this update: + +- Source guardrail: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows --tb=short` + -> `1 passed`. +- Full static/compiler guardrail bundle: + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + -> `173 passed in 9.47s`. +- CUDA parity: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_native_backward_stage_liveness_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_native_backward_stage_liveness_20260503 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific --tb=short`. + -> `4 passed in 62.52s`. +- Exact Axon T=1 perf/owner row: + `CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_native_backward_stage_liveness_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_native_backward_stage_liveness_20260503 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir tmp/fabric_audits/partials/2026-05-03/t1_native_backward_stage_liveness_axon_h32_100m_b1024 --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families axoncell --sizes 100m --modes forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner`. 
+ -> `ok`, `775.6514 tok/s`, `93.6218 GiB`, + `100525639168` max allocated, + `78117060304` unclassified peak, + `90207651328` allocated after fused backward. +- This moved max allocated down by `3892314112` bytes from the previous + `104417953280` row while keeping the registered program path active. +- Native current-allocation telemetry now names the hottest in-call owner: + `native_after_initial_recurrent_kv_local0` at `98529150464` allocated, + followed by `native_after_boundary_kv_local0` at `96532661760` and + `native_after_recurrent_message_local0` at `96125814272`. The next reducing + patch should target initial recurrent-K/V reverse and boundary-K/V reverse + output/workspace lifetimes, not public-state aliasing. +- Four-row T=1 owner table: + `CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_native_backward_stage_liveness_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_native_backward_stage_liveness_20260503 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir tmp/fabric_audits/partials/2026-05-03/t1_native_backward_stage_liveness_owner_table_h32_100m_b1024 --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families slstm,axoncell --sizes 100m --modes forward,forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner`. + -> `ok`, 4 cases. + +| row | status | tokens/s | peak GiB | max allocated | unclassified peak | hottest native current stage | +| --- | --- | ---: | ---: | ---: | ---: | --- | +| sLSTM forward | ok | `5712.6934` | `18.2742` | `19621811200` | `11770109776` | n/a | +| sLSTM training | ok | `995.0450` | `27.8087` | `29859408384` | `16270082768` | `native_after_initial_recurrent_kv_local0=27988290048` | +| Axon forward | ok | `4116.2302` | `73.9835` | `79439160320` | `69375647568` | n/a | +| Axon training | ok | `777.7249` | `93.6201` | `100523804160` | `78115225296` | `native_after_initial_recurrent_kv_local0=98527315456` | + +Owner conclusion: + +- This pass produced a real memory movement on the registered compiler-owned + strategy path: Axon T=1 training max allocated moved from `104417953280` to + `100523804160` bytes in the four-row run. +- The remaining high-impact owner is no longer unknown. It is the native + recurrent/boundary sender-KV reverse tail: + `native_after_initial_recurrent_kv_local0`, + `native_after_boundary_kv_local0`, and + `native_after_recurrent_message_local0`. +- The next throughput pass should optimize those registered message reverse + strategy outputs through compiler-owned reducer routes or runtime buffers. + Do not reopen public-state aliasing. + +### 2026-05-03 - Skill Boundary Maintenance Before Next Throughput Pass + +Status: completed for the active skill set. + +Purpose: + +- Make future throughput work, native strategy work, and new cell/message/ + readout/primitive/graph additions follow the same compiler-boundary rules + without relying on chat memory. + +Skill updates: + +- `cb.fabric-workflow-router`: added explicit work lanes + (`semantic`, `strategy`, `native`, `evidence`, `cleanup`) and a required + handoff packet for analyze -> plan -> proceed prompts. 
+- `cb.fabric-compiler-boundary-audit`: added row-fingerprint stability for + throughput/native work, row-delta proof for semantic work, and active-route + ownership evidence before closure. +- `cb.fabric-throughput-strategy`: added semantic-stability guards, + row/binding fingerprint checks, and native output/reducer/metadata handling + rules, plus an explicit handoff to reducer/liveness review when a strategy + changes returned tensors, reducer inputs, runtime buffers, workspace, or + artifact lifetime. +- `cb.fabric-native-strategy-onboarding`: added ABI categories, allocation-site + classification, metadata-only telemetry rules, and singleton-assumption + rejection. +- `cb.fabric-compiler-extension`, `cb.fabric-primitive-op-onboarding`, + `cb.fabric-message-rule-onboarding`, `cb.fabric-cell-onboarding`, and + `cb.fabric-readout-rule-onboarding`: added row-delta/locality packets so new + formulas or declarations must change compiler rows or fail closed. +- `cb.fabric-parity-gate`: added native ABI locality and route-aware output + parity requirements. +- `cb.fabric-performance-loop`: added native-stage telemetry guidance and a + hard stop when a performance owner requires missing semantics. +- `cb.fabric-public-api-cleanup`, `cb.fabric-backend-boundaries`, and + `cb.fabric-skill-maintenance`: tightened routing, public-field ownership, and + cross-skill lane maintenance rules. + +New skill decision: + +- Added `cb.fabric-reducer-liveness` because reducer inputs, reverse span + outputs, native return groups, runtime buffers, workspace reuse, and + artifact/tape lifetime are now a recurring throughput risk. This is not a + duplicate of native strategy or performance work: it is the workflow that + decides whether a tensor is a semantic return, reducer input, carry/state + input, artifact/tape, workspace, metadata-only, or illegal. +- Existing skills were also strengthened so future throughput, native strategy, + compiler extension, parity, performance, and backend-boundary work must route + reducer/lifetime changes through compiler-owned rows rather than fixed return + groups or local ABI assumptions. +- Added `cb.fabric-boundary-guardrails` because source/static guardrails are a + recurring workflow and should not be written as broad string-presence checks. + The new skill requires every guardrail to pair a positive compiler-product + check with a negative stale-route ban, and states that guardrails do not + replace parity, active-route metadata, or warmed performance evidence. +- Tightened router and narrow skills so throughput, native strategy, reducer + liveness, cell, message-rule, readout, graph, primitive-op, public API + cleanup, parity, performance, and scaling work all route guardrail changes + through `cb.fabric-boundary-guardrails`. + +Validation: + +- Frontmatter check across `21` local skills: `bad=0`. +- Important repository hygiene: several of the newer skill directories are + still untracked and must be included when committing the skill set. + +### 2026-05-03 - Sender-KV Reverse Tail Liveness Slice + +Status: implemented, parity-checked, and measured. This is accepted as a small +registered-strategy memory movement, not T=1 throughput closure. + +Owner: + +- Selected owner: native recurrent/boundary sender-KV reverse tail in + `registered_temporal_fused_backward_program_cuda`. +- Prior four-row Axon training evidence: + `100523804160` max allocated, `93.6201 GiB`, hottest current stage + `native_after_initial_recurrent_kv_local0=98527315456`. 
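Given stage rows like those quoted above, owner selection is mechanical: take the stage with the largest live allocation and look at per-stage growth. A minimal sketch over a hypothetical row format; the field names are stand-ins for the recorded ledger entries:

```python
# Illustrative attribution over recorded stage rows; not the audit-runner code.
def hottest_stage(stage_rows: list[dict]) -> dict:
    """Return the stage row with the largest current allocation."""
    return max(stage_rows, key=lambda row: row["allocated"])


def stage_growth(stage_rows: list[dict]) -> list[tuple[str, int]]:
    """Per-stage allocated-byte growth, largest first, to pick the next reducing patch."""
    deltas = [
        (curr["stage"], curr["allocated"] - prev["allocated"])
        for prev, curr in zip(stage_rows, stage_rows[1:])
    ]
    return sorted(deltas, key=lambda item: item[1], reverse=True)
```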
+ +Boundary audit: + +- Lane: throughput strategy / native implementation over existing compiler + products. +- Unchanged semantic rows: message, readout, transition, boundary, reset, + artifact, output-route, parameter-reducer, and primitive rows. +- Changed strategy/runtime behavior: single-message-span native reverse no + longer collects unused per-span output groups, and boundary K/V input grads + reuse the recurrent-message grad tensors in place before the initial + recurrent-K/V reverse call. +- Bindings/routes/liveness consumed directly: + `reverse_span_output_rows`, reverse executor rows, message native strategy + rows, artifact route rows, runtime buffer rows, and parameter reducer routes. +- Old route: no new route was added. The previous local lifetime behavior is + invalidated by the registered fused backward strategy itself. + +Implementation: + +- In `registered_program/backward_program.cuh`, single-span recurrent-message, + boundary-KV, and initial-recurrent-KV native calls now pass no per-span output + sink when compiler routing does not need per-span groups. +- `grad_input_k` / `grad_input_v` are accumulated in place into the + recurrent-message output tensors instead of allocating extra summed tensors. +- Large transient references are released before the initial recurrent-K/V + reverse strategy runs when the compiler-declared outputs do not need them. +- No primitive math, fixed-slot wrapper, benchmark policy, family selector, or + hidden-size selector was added. + +Validation: + +- Focused compiler guardrail: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows --tb=short` + -> `1 passed`. +- Static compiler/source bundle: + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + -> `173 passed in 66.36s`. +- Whitespace/source diff check: + `git diff --check` -> passed. +- CUDA parity: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_sender_kv_tail_liveness_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_sender_kv_tail_liveness_20260503 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific --tb=short` + -> `4 passed in 61.00s`. + +Perf artifacts: + +- Exact Axon T=1 row: + `tmp/fabric_audits/partials/2026-05-03/t1_sender_kv_tail_liveness_axon_h32_100m_b1024`. +- Four-row owner table: + `tmp/fabric_audits/partials/2026-05-03/t1_sender_kv_tail_liveness_owner_table_h32_100m_b1024`. + +Exact Axon result: + +- Status: `ok`. +- Tokens/s: `776.4592`. +- Peak: `93.1218 GiB`. +- Max allocated: `99988768256`. +- Unclassified CUDA peak: `77580189392`. +- Allocated after fused backward: `90207651328`. +- Stage movement: + `native_after_initial_recurrent_kv_local0=96113231360`, + `native_after_boundary_kv_local0=96264226304`, + `native_after_recurrent_message_local0=96125814272`. 
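The in-place accumulation described above avoids allocating an extra full-size summed tensor exactly at the peak. A hedged illustration with placeholder tensor names; the real buffers are the compiler-declared recurrent-message gradient outputs, not fresh allocations made here:

```python
# Illustrative only: placeholder names, not the registered program's buffers.
import torch


def accumulate_boundary_kv_grads_inplace(
    grad_recurrent_msg_k: torch.Tensor,  # already-live recurrent-message grad buffer
    grad_recurrent_msg_v: torch.Tensor,
    grad_input_k: torch.Tensor,          # boundary-KV contribution for the same role
    grad_input_v: torch.Tensor,
) -> None:
    # Out-of-place form (`summed = grad_recurrent_msg_k + grad_input_k`) would
    # allocate another full-size temporary at the peak; in-place add reuses
    # storage that is already live.
    grad_recurrent_msg_k.add_(grad_input_k)
    grad_recurrent_msg_v.add_(grad_input_v)
```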
+ +Four-row result: + +| row | status | tokens/s | peak GiB | max allocated | unclassified peak | hottest native current stage | +| --- | --- | ---: | ---: | ---: | ---: | --- | +| sLSTM forward | ok | `5699.3972` | `18.2742` | `19621811200` | `11770109776` | n/a | +| sLSTM training | ok | `993.5858` | `27.8087` | `29859408384` | `16270082768` | `native_after_boundary_kv_local0=27145234944` | +| Axon forward | ok | `4114.2480` | `73.9835` | `79439160320` | `69375647568` | n/a | +| Axon training | ok | `777.0165` | `93.1201` | `99986933248` | `77578354384` | `native_after_boundary_kv_local0=96262391296` | + +Accepted/rejected: + +- Accepted: the max allocation moved down by about `0.54 GiB` in both exact + Axon and four-row evidence, and the hottest current stage moved from initial + recurrent-KV to boundary-KV. This proves the patch changed actual native + lifetime/storage pressure, not only metadata. +- Rejected as closure: the movement is too small to close T=1 memory or + throughput. Axon training remains around `93 GiB`, far above the April 21 + `2.07 GiB` memory floor, and throughput remains around `777 tok/s` versus + `58732.71 tok/s`. +- Public-state aliasing remains rejected for training. + +Next owner: + +- The remaining local owner is now `native_after_boundary_kv_local0`, with + `native_after_recurrent_message_local0` and the hidden peak inside + initial-recurrent-KV still close behind. +- The next reducing patch should target the registered boundary-KV/initial-KV + native strategy workspace and graph-order materialization lifetime through + compiler-owned reducer/runtime-buffer contracts. Do not move to compute + throughput or reopen aliasing until that memory owner has more headroom. + +### 2026-05-03 - Plan To Close Highest-Impact T=1 Owner: Value-Only Sender-KV Reverse + +Status: implemented as a native-strategy cleanup probe; rejected as the active +Axon T=1 throughput owner after measurement. + +Selected owner: + +- Highest-impact row remains Axon 100M h32 B1024 T=1 terminal training. +- Latest accepted four-row owner evidence: + `777.0165 tok/s`, `93.1201 GiB`, `99986933248` max allocated, + `77578354384` unclassified peak. +- Hottest native stage: + `native_after_boundary_kv_local0=96262391296`, with + `native_after_recurrent_message_local0=96125814272` and + `native_after_initial_recurrent_kv_local0=96113231360` close behind. + +Boundary classification: + +- Lane: throughput strategy / native implementation over existing compiler + products. +- Semantic rows must stay unchanged: graph, message, readout, transition, + primitive, artifact, output-route, reset, parameter-reducer, and memory rows. +- The strategy may only change registered native implementation and liveness: + no new message math, no scheduler formula, no benchmark policy, no fixed slot + temporal ABI, no family or hidden-size selector. +- Required compiler products already on the active path: + message native strategy rows, native callable binding schema rows, + reverse executor rows/bindings, reverse span output rows, parameter reducer + routes, runtime buffer rows, memory liveness rows, and artifact routes. + +Hypothesis to test before code: + +- The current fixed-slot-context sender-KV reverse path constructs synthetic + K/V projection weights by prefixing zero key channels onto value-only weights + before calling the generic sender-KV projection backward. This appears in both + boundary-KV and initial recurrent-KV reverse. 
That zero-prefix + concat path + is a native-strategy implementation detail, not Fabric semantics. +- If allocator telemetry confirms the peak is inside this synthetic weight + materialization or generic K/V projection temporary, the next legal strategy + is a registered value-only sender-value reverse executor that consumes the + same message primitive rows and parameter bindings directly. + +Plan: + +1. **Add stage telemetry, not optimization, first.** + - Split the current native stage into sub-stages: + boundary value-weight lookup, boundary zero-prefix materialization, + boundary sender-value projection backward, initial value-weight lookup, + initial zero-prefix materialization, initial sender-value projection + backward, graph-order materialization, and reducer-route handoff. + - Record current allocated/reserved at each stage as metadata only. + - Acceptance: the hidden peak is named as a concrete native sub-stage, not + broad `unclassified` memory. + +2. **Register a value-only reverse strategy for fixed-slot-context sender + value projections.** + - Add or reuse native callable rows so fixed-slot-context boundary and + recurrent initial K/V reverse can select a value-only implementation from + message primitive rows. + - Required inputs: boundary/recurrent hidden state, value weight, optional + grouped value weight, `grad_v`, graph/backend order rows, grouping rows, + and reducer/output route rows. + - Explicitly reject non-value-only layouts, missing bindings, grouped shapes + the strategy cannot implement, and any row where `grad_k` is semantically + required. + +3. **Implement value-only native reverse without synthetic concatenation.** + - Do not build `[zero_key, value_weight]` tensors. + - Compute `grad_hidden` and `grad_value_weight` directly from the value + projection roles. + - Produce the same compiler-declared reverse span outputs: + boundary projection raw grad, input/recurrent value-weight grad, grouped + flag, hidden graph-order grad, and reducer-route inputs. + - Keep any graph-order materialization behind route/liveness checks; if the + reducer can consume backend-order gradients, route through compiler rows + instead of allocating a graph-order copy. + +4. **Route outputs through reducer/runtime ownership.** + - `grad_input_kv_weight` and + `grad_initial_recurrent_kv_weight_graph_order` remain reducer inputs, not + semantic Python returns. + - Local-only K-bank placeholders stay illegal/empty for value-only rows. + - Boundary hidden/input gradients either write the declared runtime/carry + buffer or return through the existing reverse span rows only when required. + +5. **Fail closed for unsupported cases.** + - If the message strategy has true key gradients, grouped key/value weights + requiring the generic K/V path, multiple spans without route rows, or reset + behavior not represented by rows, reject before launch with a typed native + strategy reason. + - Do not silently fall back to the old synthetic K/V path for supported + rows. Keep the generic path only for rows that are still legally outside + the value-only strategy. + +Parity gate: + +- Targeted CUDA parity before perf: + `test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store`, + `test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients`, + `test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window`, + and `test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific`. 
+- Parameter-gradient parity is mandatory for the message value-weight, + message context/key/nudge parameters, sender K/V projection parameters, + recurrent query, readout, and transition reducers touched by the route. + +Perf gate: + +- First rerun the exact Axon row: + `t1-single-pop`, `axoncell`, `100m`, `forward_backward`, `B=1024`, + `T=1`, `h=32`, terminal boundary, no reset. +- Then rerun the four-row owner table: + sLSTM/Axon forward and training for `100m`, `B=1024`, `T=1`, `h=32`. +- Accepted movement requires lower max allocated and lower named native stage + allocation, not just lower after-call live allocation or metadata relabeling. + +Keep/revert rule: + +- Keep if the hidden peak moves from `native_after_boundary_kv_local0` / + `native_after_initial_recurrent_kv_local0` and parity is green. +- Narrow if forward or non-value-only rows are unaffected but one reverse route + grows allocator pressure. +- Revert if max allocated or unclassified peak grows, or if the strategy relies + on tensor-name/fixed-slot assumptions outside registered message native + strategy bindings. + +### 2026-05-03 - Value-Only Sender-KV Reverse Probe Result + +Status: parity-safe cleanup, not an accepted throughput win. + +Implemented: + +- Added a registered sender value-projection backward native kernel for + fixed-slot-context reverse paths. +- Removed the fixed-slot-context synthetic `[zero_key, value_weight]` K/V + concat helper. +- Updated the sender K/V parameter reducer to accept value-only weight-gradient + rows and accumulate only the value projection side. + +Validation: + +- `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_parameter_reducer_outputs_are_compiler_provided_tensor_table --tb=short` + - Result: `1 passed`. +- Targeted CUDA parity: + `test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store`, + `test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients`, + `test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window`, + `test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific`. + - Result: `4 passed`. + +Exact Axon T=1 measurement: + +- Command artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_value_only_sender_reverse_axon_h32_100m_b1024`. +- Result: `774.374 tok/s`, `93.1218 GiB`, + `cuda_max_allocated_bytes=99988768256`, + `fabric_unclassified_cuda_peak_bytes=77580189392`. +- Native stages: + `native_after_transition_local0=92097447424`, + `native_after_recurrent_message_local0=96125814272`, + `native_after_boundary_kv_local0=96264226304`, + `native_after_initial_recurrent_kv_local0=96113231360`. + +Conclusion: + +- The exact active Axon row still selects + `neighborhood_attention_project`, not the fixed-slot-context strategy. +- Keep the value-only patch as a source cleanup because it removes a synthetic + native concat and keeps reducer semantics compiler-row-owned. +- Do not count it as throughput closure. The active owner remains registered + reverse transition/message liveness in the `neighborhood_attention_project` + row. + +### 2026-05-03 - Transition Output Keep-Slot Liveness + +Status: accepted memory/liveness improvement; throughput still open. + +Boundary audit: + +- Lane: registered compiler-owned throughput/liveness strategy. +- Semantics unchanged: primitive rows, executor rows, tensor bindings, + artifacts, output routes, reset rows, and reducer routes are unchanged. 
+- Changed product: transition output keep-slot rows, produced from existing + reverse output bindings and transition parameter-gradient bindings. +- Old behavior: each transition group returned the full per-bucket program + tensor table to the full reverse step, keeping recomputed forward tensors and + internal reverse intermediates live through recurrent-message and sender-KV + reverse. +- New behavior: after transition state reset handling, the fused backward + program keeps only declared output slots needed by recurrent-message input, + transition state seeds, and transition parameter reducers. Slot indices remain + stable; non-kept slots are explicit empty tensors. + +Implementation: + +- Added `transition_output_keep_slot_row_groups` to the registered fused + backward launch ABI. +- Built keep-slot rows in + `src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py` + from: + `grad_aggregated_message`, transition state-gradient outputs, and + transition parameter-gradient binding outputs. +- Added C++ validation and filtering in + `src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/backward_program.cuh`. +- Added source guardrails so the keep-slot rows remain in the Python launch + ABI, C++ fused program, and tests. + +Validation: + +- Focused source/compiler checks: + - `test_parameter_reducer_outputs_are_compiler_provided_tensor_table`: `1 passed`. + - `test_transition_reverse_seed_roles_are_compiler_owned_rows`: `1 passed`. +- Static/compiler guardrail bundle: + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + - Result: `173 passed`. +- Targeted CUDA parity: + `test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store`, + `test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients`, + `test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window`, + `test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific`. + - Result: `4 passed`. + +Exact Axon T=1 measurement: + +- Command artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_transition_keep_slots_axon_h32_100m_b1024`. +- Result: `776.270 tok/s`, `91.0224 GiB`, + `cuda_max_allocated_bytes=97734592000`, + `fabric_unclassified_cuda_peak_bytes=75326013136`. +- Stage movement versus the prior accepted Axon training row: + - Max allocated: `99986933248 -> 97734592000`. + - Peak memory: `93.1201 GiB -> 91.0224 GiB`. + - `native_after_boundary_kv_local0`: `96262391296 -> 77473744384`. + - `native_after_initial_recurrent_kv_local0`: `96113231360 -> 77322749440`. + - `native_after_transition_local0` remains `92095612416`. + +Four-row owner table: + +- Artifact: `tmp/fabric_audits/partials/2026-05-03/t1_transition_keep_slots_four_row`. +- `slstm forward`: `5723.979 tok/s`, `18.274 GiB`, + `max=19621811200`. +- `slstm forward_backward`: `994.377 tok/s`, `27.809 GiB`, + `max=29859408384`, + `native_after_transition_local0=25631549952`. +- `axoncell forward`: `4114.706 tok/s`, `73.983 GiB`, + `max=79439160320`. +- `axoncell forward_backward`: `778.464 tok/s`, `91.021 GiB`, + `max=97732756992`, + `native_after_transition_local0=92095612416`, + `native_after_boundary_kv_local0=77471909376`. + +Conclusion: + +- Accepted: this moves the named liveness owner after transition and avoids + carrying full transition program tables through later reverse stages. 
+- Not closed: the peak is now the transition reverse stage itself. The next + owner is the registered transition reverse implementation that recomputes + forward transition tensors and materializes large intermediate gradients + inside `native_after_transition_local0`. + +### 2026-05-03 - Transition Reverse Internal Liveness + +Status: accepted memory/liveness improvement; throughput still open. + +Boundary audit: + +- Lane: registered compiler-owned throughput/native liveness strategy. +- Semantics unchanged: primitive rows, executor rows, tensor bindings, + artifact/output routes, reset rows, and reducer rows are unchanged. +- Changed implementation: the registered transition reverse native handlers now + release local-only reverse intermediates after their compiler-declared + consumers have run. +- Added metadata-only native substage telemetry inside the registered transition + reverse group. The telemetry reports entry, parameter binding, dynamic binding, + forward recompute, and reverse primitive allocation points. It is stripped into + audit metadata and does not become a semantic return. + +Implementation: + +- Added transition-reverse group memory stage rows in + `src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_reverse_program.cuh`. +- Passed stage rows through the fused registered backward program in + `src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/backward_program.cuh`. +- Released local-only gated/diagonal reverse intermediates in + `src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_reverse_handlers.cuh` + after their downstream native consumers completed. +- Exposed the new substage names in + `src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py`. + +Validation: + +- Focused source/compiler checks: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_parameter_reducer_outputs_are_compiler_provided_tensor_table tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows tests/test_fabric_backend_plan.py::test_transition_reverse_seed_roles_are_compiler_owned_rows --tb=short` + - Result: `4 passed`. +- Targeted CUDA parity: + `test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store`, + `test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients`, + `test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window`, + `test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific`. + - Result: `4 passed`. + +Exact Axon T=1 measurement: + +- Command artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_transition_reverse_liveness_axon_h32_100m_b1024`. +- Result: `775.593 tok/s`, `89.216 GiB`, + `cuda_max_allocated_bytes=95794988544`, + `fabric_unclassified_cuda_peak_bytes=73386409680`. +- Stage movement versus the prior accepted keep-slot Axon training row: + - Max allocated: `97734592000 -> 95794988544`. + - Peak memory: `91.0224 GiB -> 89.216 GiB`. + - `native_after_transition_local0`: `92095612416 -> 92097447424` + effectively unchanged. + - New substage owner: + `native_transition_group_dynamic_bound_local0=90030179840`, + `native_transition_group_after_forward_recompute_local0=90030179840`, + `native_transition_group_after_reverse_primitive_local0=92097447424`. 
+ - Keep-slot release point remains effective: + `native_after_transition_keep_slots_local0=73306965504`. + +Four-row owner table: + +- Artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_transition_reverse_liveness_four_row`. +- `slstm forward`: `5717.999 tok/s`, `18.274 GiB`, + `max=19621811200`. +- `slstm forward_backward`: `985.176 tok/s`, `26.121 GiB`, + `max=28046875136`, + `native_transition_group_after_reverse_primitive_local0=25631549952`, + `native_after_transition_keep_slots_local0=23215630848`. +- `axoncell forward`: `4118.635 tok/s`, `73.983 GiB`, + `max=79439160320`. +- `axoncell forward_backward`: `775.116 tok/s`, `89.214 GiB`, + `max=95793153536`, + `native_transition_group_after_reverse_primitive_local0=92095612416`, + `native_after_transition_keep_slots_local0=73305130496`. + +Conclusion: + +- Accepted: this shortens local reverse-intermediate lifetimes after their + compiler-declared consumers run and lowers Axon training peak memory by about + `1.81 GiB` beyond the keep-slot row. +- Not closed: the high-water mark is still the registered transition reverse + primitive itself. The next owner is not later span output retention; it is the + transition reverse strategy's dynamic binding and forward-recompute workspace + inside the transition group. +- Next work should move those tensors into explicit compiler-owned workspace or + reducer/liveness rows, then prove the stage movement with the new substage + telemetry before attempting compute-throughput tuning. + +### 2026-05-03 - Transition Reverse Zero-Seed Dynamic Binding Cache + +Status: accepted registered-strategy liveness improvement; throughput still open. + +Boundary audit: + +- Lane: registered compiler-owned throughput/native liveness strategy. +- Semantics unchanged: primitive rows, executor rows, tensor bindings, + artifact/output routes, reset rows, reducer rows, and transition seed role + rows are unchanged. +- Changed implementation: missing `SeedOrZeros` dynamic transition reverse + bindings now reuse a read-only zero tensor per shape/device/dtype inside the + registered transition reverse group. Present seed tensors are still consumed + from compiler-owned transition seed rows. +- No public-state alias route was expanded. This patch targets only missing + reverse seed inputs selected by compiler dynamic-binding rows. + +Implementation: + +- Added `transition_seed_tensor_or_cached_zeros` in + `src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/reverse_artifacts_and_resets.cuh`. +- Threaded a per-group `cached_zero_seed_tensors` cache through + `bind_transition_dynamic_tensors_for_handlers` in + `src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_reverse_program.cuh`. +- Deleted the unused uncached `transition_seed_tensor_or_zeros` helper so the + registered path has one dynamic seed-zero implementation. +- Added a source guardrail in `tests/test_fabric_backend_boundaries.py` so the + dynamic seed cache remains part of the registered compiler-owned program. + +Validation: + +- Focused compiler/source checks: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_plan.py::test_transition_reverse_seed_roles_are_compiler_owned_rows --tb=short` + - Result: `3 passed`. 
+- Targeted CUDA parity: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_seed_cache_parity_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_seed_cache_parity_20260503 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific --tb=short` + - Result: `4 passed`. + +Exact Axon T=1 measurement: + +- Command artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_transition_seed_zero_cache_axon_h32_100m_b1024`. +- Result: `775.431 tok/s`, `87.466 GiB`, + `cuda_max_allocated_bytes=93915940352`, + `fabric_unclassified_cuda_peak_bytes=71507361488`. +- Stage movement versus prior accepted transition reverse internal-liveness + Axon row: + - Max allocated: `95794988544 -> 93915940352`. + - Peak memory: `89.216 GiB -> 87.466 GiB`. + - `native_transition_group_dynamic_bound_local0`: + `90030179840 -> 88151131648`. + - `native_transition_group_after_forward_recompute_local0`: + `90030179840 -> 88151131648`. + - `native_transition_group_after_reverse_primitive_local0`: + `92097447424 -> 90218399232`. + - Keep-slot release point remains effective: + `native_after_transition_keep_slots_local0=73306965504`. + +Four-row owner table: + +- Artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_transition_seed_zero_cache_four_row`. +- `slstm forward`: `5686.225 tok/s`, `18.274 GiB`, + `max=19621811200`. +- `slstm forward_backward`: `973.252 tok/s`, `24.433 GiB`, + `max=26234935808`, + `native_transition_group_dynamic_bound_local0=23061092864`, + `native_transition_group_after_reverse_primitive_local0=23819610624`, + `native_after_transition_keep_slots_local0=23215630848`. +- `axoncell forward`: `4112.290 tok/s`, `73.983 GiB`, + `max=79439160320`. +- `axoncell forward_backward`: `775.914 tok/s`, `87.464 GiB`, + `max=93914105344`, + `native_transition_group_dynamic_bound_local0=88149558784`, + `native_transition_group_after_reverse_primitive_local0=90216564224`, + `native_after_transition_keep_slots_local0=73305130496`. + +Conclusion: + +- Accepted: the intended owner moved. The dynamic-bound and forward-recompute + stages dropped by about `1.88 GiB` on the exact Axon training row, and no + unexplained allocator spike appeared. +- Not closed: the high-water mark is still inside the registered transition + reverse primitive and its forward-recompute inputs. The next owner is to move + state-before fallback zeros and recurrent-message slices into explicit + compiler-owned workspace/liveness rows, then remeasure + `native_transition_group_dynamic_bound_local0` and + `native_transition_group_after_reverse_primitive_local0`. + +### 2026-05-03 - Transition Reverse Dynamic Workspace Narrowing + +Status: accepted registered-strategy liveness improvement; throughput still open. + +Boundary audit: + +- Lane: registered compiler-owned throughput/native liveness strategy. +- Semantics unchanged: primitive rows, executor rows, tensor bindings, + artifact/output routes, reset rows, reducer rows, and transition seed role + rows are unchanged. +- Changed implementation: transition reverse dynamic binding no longer + materializes fallback recurrent-hidden template slices just to obtain zero + seed shapes. 
Full-span recurrent-message inputs use the compiler-owned + artifact directly; non-full recurrent-message spans and missing state-before + fallbacks require explicit runtime-buffer/liveness rows before launch. +- Runtime rows added: + `transition_reverse_recurrent_msg_span` and + `transition_reverse_state_before_zero`. + +Rejected broad probe: + +- First version allocated both reverse-dynamic buffers for every transition + bucket. It improved Axon but regressed the sLSTM training guardrail by about + `1.2 GiB`, because the single-bucket path did not need a copied recurrent + message span and did not need a missing state-before zero buffer. +- The accepted version narrows allocation to only real compiler-proven needs: + non-full recurrent-message spans or dynamic state-before rows missing an + artifact binding. + +Implementation: + +- Added compiler memory-liveness entries and runtime role opcodes for + transition reverse dynamic buffers in + `src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py` + and `registered_program/constants_and_checks.cuh`. +- Added `_transition_reverse_dynamic_runtime_buffer_requests(...)` in + `src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py`. +- Updated + `src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_reverse_program.cuh` + so dynamic binding consumes planned buffers when required and otherwise uses + compiler-owned artifact views without ad hoc contiguous copies. + +Validation: + +- Focused compiler/source checks: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_plan.py::test_transition_reverse_seed_roles_are_compiler_owned_rows --tb=short` + - Result: `3 passed`. +- Targeted CUDA parity: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_dyn_workspace_parity_narrow_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_dyn_workspace_parity_narrow_20260503 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific --tb=short` + - Result: `4 passed`. + +Exact Axon T=1 measurement: + +- Command artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_transition_dynamic_workspace_narrow_axon_h32_100m_b1024`. +- Result: `776.033 tok/s`, `79.237 GiB`, + `cuda_max_allocated_bytes=85079713792`, + `fabric_unclassified_cuda_peak_bytes=62671134920`. +- Stage movement versus prior accepted zero-seed cache Axon row: + - Max allocated: `93915940352 -> 85079713792`. + - Peak memory: `87.466 GiB -> 79.237 GiB`. + - `native_transition_group_dynamic_bound_local0`: + `88151131648 -> 74997794816`. + - `native_transition_group_after_forward_recompute_local0`: + `88151131648 -> 74997794816`. + - `native_transition_group_after_reverse_primitive_local0`: + `90218399232 -> 77065062400`. + - Runtime buffers allocated: + `68835313152 -> 66956265472` versus the broad-buffer probe. 
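+
+The narrowing that separates the accepted version from the rejected broad
+probe is a per-bucket predicate over compiler-proven facts, not a tuning knob:
+request a reverse-dynamic buffer only when the recurrent-message span is
+non-full or a dynamic state-before row has no artifact binding. A minimal
+sketch of that request shape, where the two runtime row names are the ones
+added above and every other field name is an assumption for illustration:
+
+```python
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class TransitionReverseDynamicFacts:
+    """Illustrative stand-in for the compiler facts the predicate consumes."""
+
+    recurrent_span_is_full: bool  # span covers the whole recurrent-message artifact
+    state_before_has_artifact_binding: bool  # state-before already bound
+    span_shape: tuple[int, ...]
+    state_shape: tuple[int, ...]
+
+
+def dynamic_runtime_buffer_requests(
+    facts: TransitionReverseDynamicFacts,
+) -> dict[str, tuple[int, ...]]:
+    """Request reverse-dynamic buffers only for compiler-proven needs.
+
+    Full-span recurrent-message inputs read the compiler-owned artifact view
+    directly, and present state-before bindings need no zero buffer, so the
+    common single-bucket path requests nothing. The rejected broad probe
+    allocated both buffers for every bucket and regressed the sLSTM guardrail.
+    """
+    requests: dict[str, tuple[int, ...]] = {}
+    if not facts.recurrent_span_is_full:
+        requests["transition_reverse_recurrent_msg_span"] = facts.span_shape
+    if not facts.state_before_has_artifact_binding:
+        requests["transition_reverse_state_before_zero"] = facts.state_shape
+    return requests
+```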
+
+Four-row owner table:
+
+- Artifact:
+  `tmp/fabric_audits/partials/2026-05-03/t1_transition_dynamic_workspace_narrow_four_row`.
+- `slstm forward`: `5664.277 tok/s`, `18.274 GiB`,
+  `max=19621812224`.
+- `slstm forward_backward`: `970.713 tok/s`, `24.433 GiB`,
+  `max=26234936832`,
+  `native_transition_group_dynamic_bound_local0=23061093888`,
+  `native_transition_group_after_reverse_primitive_local0=23819611648`,
+  `native_after_transition_keep_slots_local0=23215631872`.
+- `axoncell forward`: `4114.261 tok/s`, `73.983 GiB`,
+  `max=79439161344`.
+- `axoncell forward_backward`: `775.599 tok/s`, `79.236 GiB`,
+  `max=85078927360`,
+  `native_transition_group_dynamic_bound_local0=74996221952`,
+  `native_transition_group_after_reverse_primitive_local0=77063227392`,
+  `native_after_transition_keep_slots_local0=75184179200`.
+
+Conclusion:
+
+- Accepted: the high-impact Axon dynamic/reverse-primitive owner moved by about
+  `12.25 GiB` at dynamic bind and `12.25 GiB` after reverse primitive versus
+  the previous accepted row, while the sLSTM guardrail returned to the previous
+  memory level.
+- Not closed: the largest remaining T=1 training owner is now later in the
+  registered reverse program, especially recurrent-message/boundary KV and
+  boundary-output stages around `81 GiB` on the exact Axon row. Next work should
+  build the owner table from the accepted narrow artifact before choosing the
+  next registered strategy.
+
+### 2026-05-03 - Initial Recurrent Hidden Grad Optional Liveness
+
+Status: accepted local registered-strategy liveness/throughput improvement;
+T=1 memory high-water still open.
+
+Boundary audit:
+
+- Lane: registered compiler-owned throughput/native liveness strategy.
+- Semantics unchanged: primitive rows, executor rows, tensor bindings,
+  artifact/output routes, reset rows, reducer routes, and parameter-gradient
+  contracts are unchanged.
+- Changed implementation: reverse message recurrent-K/V strategies now accept a
+  `return_input_grad` implementation flag. The recurrent K/V reverse path still
+  requests hidden input gradients because transition reverse consumes them. The
+  initial recurrent K/V tail requests hidden input gradients only when the
+  compiler runtime-buffer plan materializes `reverse_grad_carry_cells`.
+- The recurrent K/V weight gradient is still returned and reduced through the
+  same compiler-owned sender-KV reducer rows. Only the local hidden-gradient
+  materialization and graph-order `index_select` are skipped for terminal T=1
+  rows with no state/carry consumer.
+
+Rejected probe before the accepted slice:
+
+- Artifact:
+  `tmp/fabric_audits/partials/2026-05-03/t1_sender_keybank_drop_axon_h32_100m_b1024`.
+- Result: `775.903 tok/s`, `79.2367 GiB`,
+  `cuda_max_allocated_bytes=85079713792`,
+  `fabric_unclassified_cuda_peak_bytes=62671134920`.
+- Rejected and reverted: dropping the recurrent key-bank span output did not
+  move the high-water allocation or produce a useful boundary-output owner
+  change. The next patch stayed on the compiler liveness signal instead of
+  deleting another returned span slot by hand.
+
+Implementation:
+
+- Added `return_input_grad` to the registered reverse message recurrent-K/V
+  strategy ABI in
+  `src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/backward_surface_steps.cuh`.
+- Updated neighborhood-attention and fixed-slot-context reverse message + strategies in + `src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/native_callables/message_reverse_strategies.cuh` + so sender projection backward kernels skip sender hidden gradients when the + compiler caller declares no consumer. +- Updated the fused reverse program in + `src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/backward_program.cuh` + to pass the compiler runtime-buffer ownership bit + `reverse_grad_carry_cells` to the initial recurrent K/V tail. An earlier + attempt keyed this on transition state-gradient ownership; targeted parity + caught that state-only and T>1 losses still require carry-cell + materialization, so the accepted condition is the runtime-buffer liveness row. + +Validation: + +- Focused compiler/source checks: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_plan.py::test_transition_reverse_seed_roles_are_compiler_owned_rows --tb=short` + - Result: `3 passed`. +- Targeted CUDA parity: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_initial_hidden_optional_parity2_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_initial_hidden_optional_parity2_20260503 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific --tb=short` + - Result: `4 passed`. + +Exact Axon T=1 measurement: + +- Command artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_initial_recurrent_hidden_optional_axon_h32_100m_b1024`. +- Result: `856.533 tok/s`, `79.2367 GiB`, + `cuda_max_allocated_bytes=85079713792`, + `fabric_unclassified_cuda_peak_bytes=62671134920`. +- Stage movement versus prior accepted dynamic-workspace-narrow Axon row: + - Max allocated unchanged: `85079713792 -> 85079713792`. + - `native_after_recurrent_message_local0` unchanged: + `79214381056 -> 79214381056`. + - `native_after_boundary_kv_local0` unchanged: + `79352793088 -> 79352793088`. + - `native_after_initial_recurrent_kv_local0`: + `79201798144 -> 77322749952`. + - `native_after_boundary_outputs_local0`: + `77322749952 -> 75443701760`. + +Four-row owner table: + +- Artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_initial_recurrent_hidden_optional_four_row`. +- `slstm forward`: `5712.180 tok/s`, `18.274 GiB`, + `max=19621812224`. +- `slstm forward_backward`: `1011.725 tok/s`, `24.433 GiB`, + `max=26234936832`, + `native_after_initial_recurrent_kv_local0=23961759232`, + `native_after_boundary_outputs_local0=23357779456`. +- `axoncell forward`: `4114.929 tok/s`, `73.983 GiB`, + `max=79439161344`. +- `axoncell forward_backward`: `857.242 tok/s`, `79.236 GiB`, + `max=85078927360`, + `native_after_recurrent_message_local0=79212546048`, + `native_after_boundary_kv_local0=79350958080`, + `native_after_initial_recurrent_kv_local0=77320914944`, + `native_after_boundary_outputs_local0=75441866752`. + +Conclusion: + +- Accepted: the intended local lifetime edge moved. 
Terminal no-state T=1 rows + no longer materialize an unused initial recurrent hidden gradient, while + state-only, provided-state, and T>1 parity remain green because they still + own `reverse_grad_carry_cells`. +- Not closed: peak memory and the high-water allocator owner did not move. The + remaining owner is earlier in the registered reverse program: + recurrent-message and boundary-KV reverse still reach about `85.08GB` + max allocation, with current allocations near `79.35GB`. +- Next work should target the boundary-KV/recurrent-message reverse owner by + moving local-only sender K/V intermediates into compiler-owned reducer or + runtime-buffer lifetimes, not by adding more span-output deletion shortcuts. + +### 2026-05-03 - Singleton Recurrent Message Grad Buffer Elision + +Status: accepted narrow registered-strategy liveness improvement; T=1 +max-allocated high-water still open. + +Boundary audit: + +- Lane: registered compiler-owned runtime-buffer/liveness strategy. +- Semantics unchanged: primitive rows, executor rows, tensor bindings, + artifact/output routes, reset rows, readout/message math, transition math, + and parameter reducer routes are unchanged. +- Compiler legality condition: the reverse runtime-buffer plan elides + `reverse_grad_recurrent_msg` only when `recurrent_msg_output_rows` has one + compiler row covering `[0, recurrent_count)` and one transition group. Any + multi-transition or partial-span program still requests the runtime buffer. +- Implementation shape: the fused reverse program consumes the transition + `grad_aggregated_message` tensor directly as the recurrent-message input + gradient, then clears the consumed transition output slot before later + stages. No fixed-slot or family-specific route owns the decision. + +Rejected probe before the accepted slice: + +- Artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_reduced_key_bank_axon_h32_100m_b1024`. +- Result: `857.402 tok/s`, `79.2367 GiB`, + `cuda_max_allocated_bytes=85079713792`. +- Rejected and reverted: reducing the fixed-slot key-bank gradient shape passed + parity, but the active Axon row uses the registered + `neighborhood_attention_project` message strategy, so this did not move the + real active owner. + +Implementation: + +- Added `_requires_reverse_grad_recurrent_msg_runtime_buffer()` in + `src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py` + and threaded the resulting optional shape into + `build_temporal_runtime_buffer_plan()`. +- Updated + `src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/backward_program.cuh` + so singleton full-transition programs use the transition + `grad_aggregated_message` tensor directly when the compiler runtime-buffer + rows do not contain `reverse_grad_recurrent_msg`. + +Validation: + +- Focused compiler/source checks: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_plan.py::test_transition_reverse_seed_roles_are_compiler_owned_rows --tb=short` + - Result: `3 passed`. 
+- Targeted CUDA parity: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_recurrent_msg_direct_parity_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_recurrent_msg_direct_parity_20260503 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific --tb=short` + - Result: `4 passed`. + +Exact Axon T=1 measurement: + +- Command artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_recurrent_msg_direct_axon_h32_100m_b1024`. +- Result: `858.574 tok/s`, `79.2367 GiB`, + `cuda_max_allocated_bytes=85079713792`, + `fabric_unclassified_cuda_peak_bytes=62671134920`. +- Stage movement versus the previous accepted Axon row: + - Max allocated unchanged: `85079713792 -> 85079713792`. + - `runtime_buffers_allocated`: `66956265472 -> 65077217280` + (`-1879048192` bytes). + - `native_after_recurrent_msg_buffer_local0`: + `75186014208 -> 73306966016` (`-1879048192` bytes). + - `native_after_recurrent_message_local0`: + `79214381056 -> 77335332864` (`-1879048192` bytes). + - `native_after_boundary_kv_local0`: + `79352793088 -> 75594696704` (`-3758096384` bytes). + - `native_after_boundary_outputs_local0`: + `75443701760 -> 71685605376` (`-3758096384` bytes). + +Four-row owner table: + +- Artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_recurrent_msg_direct_four_row`. +- `slstm forward`: `5700.100 tok/s`, `18.274 GiB`, + `max=19621812224`. +- `slstm forward_backward`: `1023.468 tok/s`, `23.871 GiB`, + `max=25630957056`, + `runtime_buffers_allocated=19698114048`, + `native_after_recurrent_message_local0=24021528064`, + `native_after_boundary_kv_local0=23521357312`. +- `axoncell forward`: `4111.745 tok/s`, `73.983 GiB`, + `max=79439161344`. +- `axoncell forward_backward`: `859.051 tok/s`, `79.236 GiB`, + `max=85078927360`, + `runtime_buffers_allocated=65076430848`, + `native_after_recurrent_message_local0=77333497856`, + `native_after_boundary_kv_local0=75592861696`, + `native_after_boundary_outputs_local0=71683770368`. + +Conclusion: + +- Accepted: this patch moves the real registered Axon and sLSTM training live + allocations through compiler-owned runtime-buffer liveness, and keeps targeted + parity green. +- Not closed: `cuda_max_allocated_bytes` and + `fabric_unclassified_cuda_peak_bytes` did not move. The exact owner is now the + hidden high-water allocation that is already present by every registered + backward stage max sample, not the current recurrent-message buffer lifetime. +- Next work should name and split that high-water owner before adding another + liveness alias. The immediate target is allocator/stage instrumentation around + reverse tensor-table construction, runtime-buffer allocation, and early native + entry so the `85.08GB` max allocation has a concrete compiler product owner. + +### 2026-05-03 - Plan To Close Highest-Impact T=1 Owner: Phase-Scoped High-Water Attribution + +Status: planned; no optimization in this step. + +Current owner: + +- Exact Axon artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_recurrent_msg_direct_axon_h32_100m_b1024`. 
+- Result: `858.574 tok/s`, `79.2367 GiB`, + `cuda_max_allocated_bytes=85079713792`, + `fabric_unclassified_cuda_peak_bytes=62671134920`. +- The last accepted liveness slice lowered live allocations: + `runtime_buffers_allocated=65077217280`, + `native_after_recurrent_message_local0=77335332864`, + `native_after_boundary_kv_local0=75594696704`. +- The max/high-water owner is still unnamed: every registered backward native + max sample reports the same `85079713792`, while the first Python-side + registered backward current sample is only + `reverse_tensor_table_built=53534444032`. That means the next owner is earlier + than the current native-stage liveness rows or is hidden by global allocator + peak accounting. + +Boundary manifest: + +```text +Lane: throughput evidence -> registered liveness/native strategy +Public declaration/spec owner: unchanged Fabric model declaration +Rows expected to change: none for the attribution pass +Rows expected to stay stable: primitive rows, executor rows, tensor bindings, + artifact/output routes, reset rows, reducer routes, memory liveness semantics +Bindings/routes/liveness consumed: existing registered reverse artifact rows, + runtime-buffer rows, native memory-stage rows, reducer rows +Reference executor: unchanged PyTorch Fabric reference +Native strategy, if any: existing registered fused reverse program only +Backward/reducer owner: unchanged registered reverse program and parameter reducers +Unsupported typed blocker: unchanged legality; no new supported declarations +Old route deleted or fail-closed: no new route; no fixed-slot/fallback wrapper +Evidence gates: phase-scoped allocator telemetry, targeted parity, exact Axon row, + four-row guardrail +``` + +Plan: + +1. **Name the hidden high-water owner before optimizing.** + Add audit-only phase-scoped allocator telemetry around the registered T=1 + backward handoff: + - autograd backward entry, before reverse artifact/window construction; + - after reverse artifact tensor-store/window access materialization; + - after output-grad window materialization; + - after reverse executable tensor table construction; + - after runtime-buffer plan construction; + - after runtime-buffer allocation; + - before native fused reverse entry; + - first native stage entry. + The normal performance ledger must remain comparable. If local peak resets + are required to identify the high-water, keep them behind an explicit + owner-timing/debug flag and record those rows as attribution evidence, not + closure metrics. + +2. **Classify the high-water as a compiler product.** + The next patch is chosen only after the attribution row names one of: + - reverse artifact/tape tensor store; + - output-grad window materialization or route merge; + - runtime-buffer allocation overlap; + - native temporary/workspace inside registered fused reverse; + - autograd saved tensor or carry/state clone; + - allocator reserve/cache gap. + If the owner cannot be mapped to a compiler product, do not add aliases or + delete return slots; add narrower telemetry first. + +3. **Close the named owner with a compiler-owned strategy.** + Use the owner classification to pick exactly one legal strategy: + - artifact/tape owner: shorten artifact lifetime through artifact access rows + or recompute policy; do not drop required roles by name. + - output-grad owner: use output route rows to avoid full-window clone/zero + materialization where a routed view or singleton route is legal. 
+ - runtime-buffer overlap: delay allocation or alias/reuse only through + memory-liveness rows; no ad hoc native temporaries. + - native workspace owner: move the temporary to a declared runtime workspace + or consume/reduce it locally through reducer/liveness rows. + - autograd saved-tensor owner: make the saved tensor/tape policy explicit in + compiler artifact rows before changing storage behavior. + +4. **Reject probes that do not move the real owner.** + A patch is rejected or narrowed if: + - primitive rows, tensor roles, or gradient contracts change; + - parity fails on outputs, state/carry/input gradients, or parameter grads; + - the active route demotes to fallback/replay/compat; + - current allocations move but the named high-water owner does not move and + the patch cannot be described as a narrow real lifetime improvement; + - max memory increases or unclassified memory grows without classification. + +5. **Validation gates.** + - Source/compiler: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_plan.py::test_transition_reverse_seed_roles_are_compiler_owned_rows --tb=short` + - Targeted CUDA parity: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_highwater_parity_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_highwater_parity_20260503 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific --tb=short` + - Exact Axon perf: + `CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 ... run_audit --plan t1-single-pop --families axoncell --sizes 100m --modes forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent` + - Four-row guardrail: + same T=1 matrix for `slstm,axoncell` and `forward,forward_backward`. + +Acceptance: + +- The attribution pass is accepted only if it names the high-water owner more + precisely than "unclassified CUDA peak." +- The optimization pass is accepted only if that named owner moves in + allocator/stage telemetry and targeted parity remains green. +- T=1 closure still requires the normal non-debug exact Axon row and four-row + guardrail to improve or stay non-regressed; debug peak-reset telemetry alone + is not performance evidence. + +### 2026-05-03 - Phase-Scoped High-Water Attribution Result + +Status: accepted as evidence/measurement hygiene; not accepted as throughput +closure. + +Boundary classifier: + +- Lane: evidence + measurement hygiene for registered compiler-owned throughput + work. +- Semantic rows changed: none. +- Strategy/runtime rows changed: none. 
+- Accepted code changes: + - clear warmup training gradients before CUDA peak reset; + - run Python GC before the measurement reset so dead warmup references are not + charged to measured iterations; + - reset Fabric temporal memory-stage telemetry at the same boundary as CUDA + peak reset; + - keep forward/backward memory stages across all measured iterations instead + of resetting at each forward entry; + - add `measurement_start_*` CUDA memory ledger fields. + +Validation: + +- Focused compiler/source/measurement guard: + `uv run pytest -q tests/test_fabric_benchmark_suite_common.py::test_training_measurement_clears_warmup_grads_before_cuda_peak_reset tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_plan.py::test_transition_reverse_seed_roles_are_compiler_owned_rows --tb=short` + - Result: `4 passed`. +- Targeted CUDA parity after telemetry changes: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_stage_window_parity_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_stage_window_parity_20260503 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific --tb=short` + - Result: `4 passed`. +- Targeted CUDA parity after the rejected reducer probe: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_reducer_ondemand_parity_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_reducer_ondemand_parity_20260503 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific --tb=short` + - Result: `4 passed`. + +Exact Axon attribution artifacts: + +- `tmp/fabric_audits/partials/2026-05-03/t1_stage_window_axon_h32_100m_b1024` + - `833.344 tok/s`, `79.2367 GiB`, + `cuda_max_allocated_bytes=85079713792`. + - Stage rows now span both measured iterations: + `fabric_registered_backward_memory_stage_count=130`. + - Highest current registered stage: + `forward_final_state_materialized=77412393984`. + - `after_fused_forward_program=77412393984`. + - `native_after_recurrent_message_local0=77335332864`. +- `tmp/fabric_audits/partials/2026-05-03/t1_measure_start_axon_h32_100m_b1024` + - `834.513 tok/s`, `79.2367 GiB`, + `cuda_max_allocated_bytes=85079713792`. + - `measurement_start_cuda_allocated_bytes=1239984128`. + - `measurement_start_cuda_max_allocated_bytes=1239984128`. + - `forward_entry=41165039616` while the global max is already + `85079713792` by that first Fabric temporal stage. + +Finding: + +- The previous "after parameter reducer" high-water label was misleading: it was + selected because every later stage saw the same global max. 
With stage-window + alignment, the highest current registered stage is the forward + artifact/final-state materialization group at `77412393984`, and the + recurrent-message reverse stage is effectively tied at `77335332864`. +- The `85.08GB` max is not stale warmup memory at reset time: + `measurement_start_cuda_max_allocated_bytes=1239984128`. +- The `85.08GB` max happens after benchmark reset and before the first + registered temporal `forward_entry` memory sample. That makes the next owner + the pre-temporal Fabric front-end handoff into the registered program: + boundary input projection/state preparation/runtime call setup, or a CUDA + temporary created there. It is not yet owned by a temporal stage row. + +Rejected probe: + +- Probe: transition parameter reducer source rows were reduced on demand instead + of precomputed into a vector. +- Artifact: `tmp/fabric_audits/partials/2026-05-03/t1_reducer_ondemand_axon_h32_100m_b1024`. +- Result: `832.622 tok/s`, `79.2367 GiB`, + `cuda_max_allocated_bytes=85079713792`. +- Movement: none. `after_fused_forward_program`, `native_after_recurrent_message`, + `before_parameter_reducer`, and `after_parameter_reducer` were unchanged. +- Decision: reverted. It was not accepted as a throughput strategy because the + named owner did not move. + +Next owner: + +- Add compiler-boundary-safe allocator telemetry around the high-level Fabric + front-end handoff, before the registered temporal program entry: + input projection, boundary tensor reshape/scatter, initial state creation, + sender K/V setup, and the call into `run_shared_temporal_bucket_forward_scan`. +- The next accepted optimization must either: + - move this pre-temporal high-water into a compiler-owned runtime buffer or + workspace row; + - shorten/alias a named front-end tensor lifetime through compiler liveness + rows; or + - prove the high-water is a CUDA library temporary and replace it with a + registered compiler-owned strategy. +- Do not add more reverse aliases or reducer changes until the pre-temporal + owner is named. + +### 2026-05-03 - Plan To Close Highest-Impact T=1 Owner: Front-End Handoff High-Water + +Status: planned; no optimization in this step. + +Selected owner: + +- Highest-impact row remains Axon 100M h32 B1024 T=1 terminal training. +- Latest accepted evidence: + `tmp/fabric_audits/partials/2026-05-03/t1_measure_start_axon_h32_100m_b1024`. +- Result: `834.513 tok/s`, `79.2367 GiB`, + `cuda_max_allocated_bytes=85079713792`, + `measurement_start_cuda_allocated_bytes=1239984128`, + `measurement_start_cuda_max_allocated_bytes=1239984128`. +- The first registered temporal memory sample is already after the hidden peak: + `forward_entry=41165039616` while global max is already + `85079713792`. +- Highest current registered stage is later: + `forward_final_state_materialized=77412393984`, and reverse recurrent-message + is effectively tied at `77335332864`. That means the remaining high-water + owner is before registered temporal entry, not another reverse span-output + alias. 
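+
+The attribution step planned below is metadata-only allocator sampling at
+named front-end handoff points, recorded into the audit ledger rather than
+returned from any program. A minimal sketch of that recorder shape, assuming
+CUDA is available and using illustrative stage names (nothing here changes
+tensor lifetimes, and the adapter/state calls in the usage comment are
+assumptions, not repository functions):
+
+```python
+import torch
+
+
+class FrontEndMemoryStages:
+    """Metadata-only allocator samples at named front-end handoff points."""
+
+    def __init__(self) -> None:
+        self.rows: list[tuple[str, int, int]] = []
+
+    def record(self, stage: str) -> None:
+        if not torch.cuda.is_available():
+            return
+        self.rows.append(
+            (
+                stage,
+                torch.cuda.memory_allocated(),
+                torch.cuda.max_memory_allocated(),
+            )
+        )
+
+    def ledger(self) -> dict[str, dict[str, int]]:
+        return {
+            stage: {"allocated_bytes": cur, "max_allocated_bytes": peak}
+            for stage, cur, peak in self.rows
+        }
+
+
+# Illustrative use around the handoff points named in the plan below:
+#   stages = FrontEndMemoryStages()
+#   stages.record("frontend_before_boundary_projection")
+#   boundary_seq = input_adapter(source_hidden_seq)   # assumed adapter call
+#   stages.record("frontend_after_boundary_projection")
+#   stages.record("frontend_before_ensure_state")
+#   ...
+#   audit_metadata.update(stages.ledger())            # metadata only
+```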
+ +Boundary classifier: + +```text +Lane: throughput evidence -> runtime front-end handoff -> compiler-owned + liveness/strategy patch after attribution +Public declaration/spec owner: unchanged Blueprint/Interface/Input/Output and + current Fabric model declaration +Rows expected to change during attribution: none +Rows expected to stay stable: graph/message/readout/transition primitive rows, + executor rows, tensor bindings, artifact/output routes, reset rows, reducer + routes, memory-liveness semantics +Potential rows if optimization is needed: input-adapter/boundary-projection + runtime-buffer rows, fresh-state virtual rows, static-materialization cache + rows, or registered adapter/native strategy rows +Reference executor: unchanged PyTorch Fabric reference +Backward/reducer owner: unchanged registered reverse program and parameter + reducers unless the named owner is adapter/static materialization +Old route deleted or fail-closed: no benchmark tiling, private runtime helper, + fallback, fixed-slot wrapper, or public-state alias expansion +``` + +Plan: + +1. **Instrument the public-call handoff before optimizing.** + Add metadata-only allocator stages around the path from model forward to the + registered temporal entry: + - model `_forward_sequence_with_readout` entry; + - readout batch-tile decision; + - runtime `forward_output_cells_for_readout` entry; + - before/after projected boundary source validation; + - before/after `_project_boundary_source_sequence`; + - before/after static tensor resolution; + - before/after `_ensure_state`; + - before/after reset expansion; + - before `execute_temporal_bucket_sequence`; + - inside `execute_temporal_bucket_sequence` entry and after cached static + tensor normalization, before the physical autograd call. + Record allocated/reserved/max-allocated plus tensor-byte summaries for + `source_hidden_seq`, `boundary_seq`, static tensor payload, initialized + cells/population state, sender K/V state, reset tensors, and adapter output. + +2. **Classify the hidden peak into one owner.** + The accepted attribution must name one of: + - public input adapter projection output or CUDA linear temporary; + - static tensor materialization/prepack; + - fresh `cells`/population state initialization; + - sender K/V setup; + - reset normalization; + - runtime call setup or autograd-saved tensor; + - allocator reserve/cache gap. + If it still appears only as broad unclassified CUDA peak, do not optimize; + add narrower telemetry. + +3. **Choose exactly one compiler-respecting closure strategy.** + - If the owner is boundary projection output: lower the input adapter into an + explicit boundary-projection/runtime-buffer product, so the public adapter + writes directly into the planned `boundary_seq` storage consumed by the + registered temporal program. Adapter backward must remain ordinary + PyTorch-module parity or become a registered adapter backward strategy with + reducer rows; no benchmark-side chunking. + - If the owner is a CUDA linear temporary: add a registered input-adapter + strategy over the existing input declaration and boundary tensor role, or + keep it as public adapter cost if the temporary is outside Fabric backend + ownership. Do not hide it in temporal scheduler code. 
+ - If the owner is fresh state: represent fresh zero state as compiler-owned + virtual/zero state rows until a consumer requires materialization; avoid + allocating full `cells` or population private state before the registered + program when the program can consume zero-state semantics directly. + - If the owner is static materialization: make the static/prepack lifetime + explicit and persistent through compiler static-value/cache rows; remove + repeated large temporary materialization from the measured call while + preserving parameter-gradient reducer ownership. + - If the owner is sender K/V setup: move K/V setup behind registered + message/input-adapter rows or liveness rows; do not infer from old tensor + names. + - If the owner is allocator reserve only: document it separately and target + current allocated/lifetime owners; do not claim reserve movement as + throughput closure. + +4. **Reject shortcuts.** + Reject or revert any patch that: + - changes primitive math, tensor roles, or gradient semantics; + - depends on Axon/sLSTM, hidden size, benchmark row, shape label, or + single-pop route identity; + - calls private runtime helpers from benchmarks; + - adds temporal-scheduler formulas or fixed slots; + - reopens public-state aliasing for training; + - moves current live allocations but leaves the named high-water owner and + `cuda_max_allocated_bytes` unchanged without a documented narrow lifetime + reason. + +5. **Validation gates.** + - Source/compiler guardrails: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_plan.py::test_transition_reverse_seed_roles_are_compiler_owned_rows --tb=short` + - Handoff/source guardrail to add with the implementation: + front-end code may prepare inputs but must not contain primitive formulas, + strategy selection, fixed slots, family/shape selectors, or benchmark + policy. + - Targeted CUDA parity: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_frontend_handoff_parity_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_frontend_handoff_parity_20260503 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific --tb=short` + - Exact Axon row: + `CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 ... run_audit --plan t1-single-pop --families axoncell --sizes 100m --modes forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent`. + - Four-row guardrail: + same T=1 matrix for `slstm,axoncell` and `forward,forward_backward`. + +Acceptance: + +- Attribution is accepted only if the front-end handoff ledger names the hidden + high-water owner before `forward_entry`. +- Optimization is accepted only if that named owner moves in + `cuda_max_allocated_bytes`, stage telemetry, or storage lifetime, with parity + green and registered compiler-owned forward/backward owners still active. 
+- T=1 throughput closure is still not claimed until the normal non-debug exact + Axon row and four-row guardrail improve materially and stay non-regressed. + +### 2026-05-03 - Front-End Handoff Attribution And Fresh-State Cache Slice + +Status: implemented, measured, and accepted as a narrow compiler-owned +memory/liveness improvement. This is not T=1 throughput closure. + +Boundary packet: + +```text +Lane: throughput strategy / runtime front-end handoff +Unchanged semantic rows: graph, message, readout, transition, reset, output, + artifact, reducer, and primitive rows +Changed strategy/runtime rows: memory-stage attribution plus fresh-state + population-cache handoff into the registered forward program +Tensor/liveness rows consumed: temporal plan carry.fresh_state_population_cache + and registered forward initial_population_state_cache +Old route deleted or fail-closed: none reopened; no benchmark/family/shape + selector added +``` + +Implementation: + +- Added `sequence_surface/runtime/memory_stages.py` so front-end handoff allocator + stages and tensor-byte summaries enter the same compiler-owned audit ledger + as registered temporal stages when owner timing is enabled. +- Instrumented the high-level `model(...)` path through boundary projection, + state preparation, static tensor resolution, and `execute_temporal_bucket_sequence` + handoff. +- Extended benchmark ledger parsing with + `fabric_frontend_tensor_*` keys so pre-registered-entry tensors are visible in + audit JSONL instead of appearing only as unclassified CUDA peak. +- Threaded the compiler plan's `fresh_state_population_cache` fact through + `forward_output_cells_for_readout`, `execute_temporal_bucket_sequence`, + `run_shared_temporal_bucket_forward_scan`, and the registered forward executor. + The registered forward path now reuses the compiler-owned fresh population + cache instead of rematerializing a second backend population cache before the + fused forward program. + +Attribution result before the optimization: + +- Artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_frontend_handoff_attr_axon_h32_100m_b1024`. +- Exact Axon T=1 row: `831.8335 tok/s`, `1231.0155 ms`, + `79.2367 GiB`, `cuda_max_allocated_bytes=85079713792`. +- Front-end ledger proved the original "pre-temporal hidden peak" was too + broad: + - `frontend_after_ensure_state=22312167424` + - `forward_entry=41165039616` + - `before_fused_forward_program=73385862144` + - `after_fused_forward_program=77412393984` +- The first large active allocation movement was therefore inside the registered + forward program setup, not the public input adapter or benchmark handoff. + +Accepted optimization result: + +- Artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_fresh_state_cache_axon_h32_100m_b1024`. +- Exact Axon T=1 row: `838.7215 tok/s`, `1220.9059 ms`, + `75.5241 GiB`, `cuda_max_allocated_bytes=81093429248`. +- Movement: + - `before_fused_forward_program`: `73385862144 -> 54595380224` + - `after_fused_forward_program`: `77412393984 -> 58621912064` + - global max allocated: `85079713792 -> 81093429248` + - peak allocated moved down by `3986284544` bytes. +- Accepted because the intended compiler-owned liveness edge moved, parity stayed + green, and no semantic rows or benchmark policies changed. + +Remaining owner after this slice: + +- The current high-water is no longer the front-end handoff or fused forward + setup. 
In raw stage chronology it first rises inside registered fused backward, + reaching `native_after_recurrent_message_local0` with + `allocated=77335332864` and `max_allocated=81093429248`. +- The next owner is the registered backward native memory/liveness path around + transition reverse, recurrent-message reverse, boundary K/V, and parameter + reducer inputs. The next patch should name and shorten one of those native + lifetimes through compiler rows, reducer/liveness rows, or registered native + strategy workspace reuse. Do not return to front-end aliases or reducer + relabeling unless this owner moves. + +Validation: + +- Source/compiler guardrails: + `uv run pytest -q tests/test_fabric_benchmark_suite_common.py::test_training_measurement_clears_warmup_grads_before_cuda_peak_reset tests/test_fabric_backend_boundaries.py::test_frontend_handoff_memory_attribution_enters_compiler_ledger tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_plan.py::test_transition_reverse_seed_roles_are_compiler_owned_rows --tb=short` + -> `5 passed`. +- Targeted CUDA parity: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_fresh_state_cache_parity_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_fresh_state_cache_parity_20260503 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific --tb=short` + -> `4 passed`. +- Four-row T=1 guardrail: + `tmp/fabric_audits/partials/2026-05-03/t1_fresh_state_cache_guard_h32_100m_b1024` + -> `4 ok` cases. + +| family | mode | tok/s | ms | peak GiB | max allocated bytes | +|---|---:|---:|---:|---:|---:| +| sLSTM | forward | `5085.3137` | `201.3642` | `16.0242` | `17205893120` | +| sLSTM | forward_backward | `992.9951` | `1031.2236` | `23.8707` | `25630957056` | +| Axon | forward | `3712.2391` | `275.8443` | `56.4835` | `60648679424` | +| Axon | forward_backward | `839.6328` | `1219.5808` | `75.5224` | `81091594240` | + +The guardrail confirms the fresh-state cache slice did not break the registered +single-pop T=1 forward/training matrix. It does not close T=1 throughput: the +large Axon training row remains far below the April 21 baseline and still peaks +inside the registered backward native memory/liveness path. + +### 2026-05-03 - Rejected Native Liveness Probes After Fresh-State Cache + +Status: rejected and reverted. No code from these probes is kept because neither +probe moved the active Axon T=1 high-water owner. 
+ +Boundary packet: + +```text +Lane: throughput strategy / registered backward reducer-liveness probe +Unchanged semantic rows: graph, message, readout, transition, primitive, + tensor-binding, reset, artifact/output-route, and reducer rows +Changed strategy/runtime rows attempted: none accepted +Old route deleted or fail-closed: no old route reopened; rejected diffs reverted +``` + +Rejected probe 1: transition recurrent-message source release + +- Attempt: clear the transition recurrent-message source slot immediately after + copying it into the compiler-owned `ReverseGradRecurrentMsg` runtime buffer. +- Source/compiler guardrails: + `tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows`, + `tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned`, + `tests/test_fabric_backend_plan.py::test_transition_reverse_seed_roles_are_compiler_owned_rows` + -> `3 passed`. +- Targeted CUDA parity: + T=1 terminal artifact-store route, provided-state gradient route, + final-state-only zero output grad route, and T>1 artifact-store route + -> `4 passed`. +- Exact Axon T=1 artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_transition_recurrent_msg_source_release_axon_h32_100m_b1024`. +- Result: `839.5983 tok/s`, `75.5241 GiB`, + `cuda_max_allocated_bytes=81093429248`. +- Owner movement: none. + - `native_after_transition_keep_slots_local0=73306966016` + - `native_after_recurrent_msg_buffer_local0=73306966016` + - `native_after_recurrent_message_local0=77335332864` +- Decision: rejected as label/lifetime-no-op for the active high-water. The + exact diff was reverted. + +Rejected probe 2: fixed-slot-context reduced key-bank return + +- Attempt: have the fixed-slot context message reverse strategy return + reducer-owned reduced key-bank gradients instead of full `[B, sender, 2H]` + key-bank tensors. +- Focused source/reducer checks: + `tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned` + and + `tests/test_fabric_backend_plan.py::test_registered_parameter_reducer_cuda_executes_transition_trainable_rows` + -> `2 passed`. +- Targeted CUDA parity: + T=1 terminal artifact-store route, provided-state gradient route, + final-state-only zero output grad route, and T>1 artifact-store route + -> `4 passed`. +- Exact Axon T=1 artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_fixed_slot_reduced_key_axon_h32_100m_b1024`. +- Result: `840.6210 tok/s`, `75.5241 GiB`, + `cuda_max_allocated_bytes=81093429248`. +- Owner movement: none. + - `native_after_recurrent_message_local0=77335332864` unchanged. + - `after_fused_backward_program` worsened from `65780025344` to + `67793291264`. +- Decision: rejected and reverted. This repeated an earlier wrong-owner lesson: + the representative Axon row is using the registered + `neighborhood_attention_project` message strategy, so fixed-slot key-bank + liveness does not close the active owner. + +Updated owner: + +- The remaining active owner is the registered + `neighborhood_attention_project` recurrent-message reverse output contract. + The strategy materializes full recurrent K/V gradient banks before the + initial recurrent K/V projection stage can consume them. The next accepted + throughput slice must change that compiler-owned strategy contract, for + example by streaming recurrent K/V adjoints directly into the registered + initial recurrent K/V projection/reducer outputs or into a declared + workspace/reducer route. 
More slot clearing or fixed-slot-only reducer edits + are rejected unless `native_after_recurrent_message_local0` and + `cuda_max_allocated_bytes` move. + +### 2026-05-03 - Priority Update: Attack Forward Before Backward + +Status: plan only. No optimization was done for this section. + +Decision: + +- Before changing the registered message reverse K/V-adjoint contract, attack + the T=1 forward runtime materialization owner. +- Rationale: forward memory is paid by both forward-only and training rows. + Axon forward is already `3712.2391 tok/s`, `56.4835 GiB`, about `15.82x` + slower and `27.29x` over the April 21 memory floor. Reducing forward runtime + buffers also lowers the starting point for backward and makes later reverse + owner measurements cleaner. +- The reverse plan below remains valid, but it is deferred until the forward + owner has either moved or been proven not to be the next high-impact slice. + +Highest-impact forward owner: + +- `registered_temporal_fused_forward_program_cuda`, specifically forward + runtime/materialized transition-message buffers in the registered program. +- Current accepted Axon T=1 forward guardrail: + `3712.2391 tok/s`, `56.4835 GiB`, + `cuda_max_allocated_bytes=60648679424`. +- Named compiler runtime buffers are about `9.66B`, but peak allocation is much + larger. The forward stage ledger shows high allocation around + `forward_runtime_buffers_allocated`, `before_fused_forward_program`, + `after_fused_forward_program`, and `forward_final_state_materialized`. +- Dominant named roles from the current owner table include + `transition_forward_diag_output`, `transition_forward_linear_output`, + `forward_recurrent_hidden_after`, and `forward_recurrent_msg`. The next pass + must separate true runtime buffers from hidden/native temporaries and + persistent materializations. + +Boundary manifest: + +```text +Lane: throughput strategy / registered forward memory-liveness +Expected semantic delta: none +Unchanged rows: graph, message primitive rows, readout rows, transition rows, + tensor bindings, output routes, reset rows, artifact routes, reducer rows, + and parameter meanings +Changed rows/contracts: forward memory_liveness rows, forward runtime-buffer + roles, local-workspace/output materialization policy, native metadata +Forward owner: registered_fused_forward_program_cuda / + registered_temporal_fused_forward_program_cuda +Backward owner: unchanged; no reverse strategy change in this slice +Memory/liveness owner: compiler memory_liveness_rows, memory_runtime_schedule + rows, forward artifact/output routes, registered native runtime buffers +Old route to delete/fail-close: materializing every forward transition/message + intermediate as a persistent runtime buffer when compiler rows prove it has + only local next-stage consumers +Guardrail invariant: no primitive math, fixed slots, family selectors, + hidden-size selectors, benchmark policy, or single/mixed-pop route branches +``` + +Implementation plan: + +1. **Build the forward consumer map first.** + - For Axon T=1 forward, list each large forward role and its compiler + consumers: user output, final state, reverse artifact, transition next + primitive, readout input, recurrent carry, or no downstream consumer. + - Start with the roles visible in the memory ledger: + `transition_forward_diag_output`, `transition_forward_linear_output`, + `forward_recurrent_hidden_after`, `forward_recurrent_msg`, + `forward_output_msg`, `forward_output_cells`, and output sequence. 
+ - Add fail-closed validation if a compact/local policy is requested without + a complete consumer route. + +2. **Add forward liveness classes.** + - Classify forward outputs as: + `semantic_return`, `carry_state`, `reverse_artifact`, + `local_next_primitive`, `workspace`, `metadata`, or `drop_after_use`. + - This classification must come from compiler rows and output/artifact + routes, not from Axon/sLSTM names, hidden size, batch, or benchmark row. + +3. **Localize transition/message intermediates inside the registered forward + program.** + - Keep full materialization only for roles required by user output, final + state, or backward artifact policy. + - Convert one large local-only transition/message role at a time into + workspace or direct producer-consumer handoff inside the registered forward + program. + - Preserve training artifact correctness: if a role is required by backward, + it remains an artifact or gets a declared recompute route; do not silently + drop it. + +4. **Measure forward-only before training.** + - First acceptance target is Axon 100M h32 B1024 T1 forward: + `cuda_max_allocated_bytes` and the forward high-water stage must drop. + - Then run the four-row guardrail to ensure sLSTM/Axon forward and training + do not regress. + - If training peak is unchanged after forward peak drops, the next owner + remains the backward plan below. + +5. **Do not expand rejected alias routes.** + - Do not reopen public-state alias/no-copy for training. + - Do not keep a metadata-only forward liveness label unless storage identity, + allocated bytes, or stage high-water moves. + +Targeted gates for the proceed pass: + +- Static/compiler guardrails: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_plan.py::test_forward_fused_program_runtime_facts_are_compiler_owned_rows tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes --tb=short` +- Targeted CUDA parity: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_forward_liveness_parity_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_forward_liveness_parity_20260503 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific --tb=short` +- Exact forward perf row: + `CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 ... run_audit --plan t1-single-pop --families axoncell --sizes 100m --modes forward --batches 1024 --seq-lens 1 --inner-steps 1 --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent` +- Four-row non-regression: + same T=1 matrix for `slstm,axoncell` and `forward,forward_backward`. + +Rollback rule: + +- Revert the patch if parity fails, if semantic rows change, if the active route + demotes to fallback/replay/compat, if forward peak does not move, or if + training regresses without a named follow-up owner. 
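+
+As a rough illustration of steps 1-2 in the plan above, the consumer map and
+forward liveness classes could be modeled as follows. The class names come
+from the plan; `ForwardRoleRow`, `validate_local_policy`, and the field
+layout are assumptions for illustration, not existing compiler rows:
+
+```python
+# Sketch only: liveness classes from step 2 plus the fail-closed rule from
+# step 1. Row and helper names are hypothetical.
+from dataclasses import dataclass
+from enum import Enum
+
+
+class ForwardLiveness(Enum):
+    SEMANTIC_RETURN = "semantic_return"
+    CARRY_STATE = "carry_state"
+    REVERSE_ARTIFACT = "reverse_artifact"
+    LOCAL_NEXT_PRIMITIVE = "local_next_primitive"
+    WORKSPACE = "workspace"
+    METADATA = "metadata"
+    DROP_AFTER_USE = "drop_after_use"
+
+
+@dataclass(frozen=True)
+class ForwardRoleRow:
+    role: str                  # e.g. "transition_forward_diag_output"
+    liveness: ForwardLiveness  # derived from compiler rows, not family names
+    consumers: tuple           # downstream compiler route rows
+
+
+LOCAL_ONLY = {
+    ForwardLiveness.LOCAL_NEXT_PRIMITIVE,
+    ForwardLiveness.WORKSPACE,
+    ForwardLiveness.DROP_AFTER_USE,
+}
+
+
+def validate_local_policy(row: ForwardRoleRow) -> None:
+    """Fail closed: a compact/local policy needs a complete consumer route."""
+    if row.liveness in LOCAL_ONLY and not row.consumers:
+        raise ValueError(
+            f"role {row.role!r} requests a local-only policy without a "
+            "declared consumer route"
+        )
+```
+
+The property that matters is that classification keys on compiler rows and
+output/artifact routes, never on Axon/sLSTM names, hidden size, batch, or
+benchmark row, matching the guardrail invariant in the boundary manifest.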
+ +### 2026-05-03 - Plan: Registered Neighborhood-Attention Reverse Liveness + +Status: plan only; deferred behind the forward-first liveness plan above. No +optimization was done for this section. + +Highest-impact owner: + +- `registered_temporal_fused_backward_program_cuda`, specifically the registered + `neighborhood_attention_project` recurrent-message reverse strategy. +- Current accepted Axon T=1 training guardrail: + `839.6328 tok/s`, `75.5224 GiB`, + `cuda_max_allocated_bytes=81091594240`. +- Current named peak: + `native_after_recurrent_message_local0` at about `77.33GB` allocated. +- Rejected lessons: + - clearing the transition recurrent-message source slot did not move the + peak; + - reducing fixed-slot-context key-bank returns did not move the peak because + the representative row uses the registered `neighborhood_attention_project` + strategy, not that fixed-slot extra-output path; + - public-state alias/no-copy routes are rejected for training until the hidden + allocator owner is explained. + +Boundary manifest: + +```text +Lane: throughput strategy / native reducer-liveness +Expected semantic delta: none +Unchanged rows: graph, message primitive rows, readout rows, transition rows, + tensor bindings, output routes, reset rows, artifact routes, and parameter + meanings +Changed rows/contracts: reverse message strategy output contract, + reverse_parameter_reducer_route_rows consumption, runtime buffer/liveness + ownership for recurrent K/V adjoints, native metadata +Forward owner: registered_fused_forward_program_cuda +Backward owner: registered_reverse_executor_bindings / + registered_temporal_fused_backward_program_cuda +Memory/liveness owner: compiler memory_liveness_rows and registered native + runtime buffers +Old route to delete/fail-close: returning full recurrent K/V adjoint banks from + the neighborhood-attention reverse strategy when all consumers are declared + reducer/runtime routes +Guardrail invariant: no fixed-slot, family, hidden-size, benchmark, or + temporal-scheduler message-formula branch may be added +``` + +Implementation plan: + +1. **Map consumers before editing native outputs.** + - Locate the exact `neighborhood_attention_project_backward` output roles + that create the full recurrent K/V adjoint banks. + - For each role, list compiler consumers by route row: initial recurrent K/V + projection reducer, recurrent sender K/V parameter reducer, carry/state + seed, boundary input gradient, or no consumer. + - Add a fail-closed validation if a planned compact path is selected but a + consumer row is missing. + +2. **Add a route-owned compact K/V adjoint contract.** + - Introduce strategy metadata/runtime rows that distinguish: + `return_full_bank`, `reduce_to_parameter_route`, + `accumulate_to_runtime_buffer`, and `drop_after_local_consume`. + - Keep this contract selected from executor rows, binding rows, reducer rows, + and liveness rows. Do not infer it from tensor names, Axon/sLSTM, hidden + size, or population mode. + +3. **Change the registered message reverse strategy, not the temporal + scheduler.** + - Update the native message reverse implementation so recurrent K/V adjoints + are streamed or accumulated directly into declared reducer/runtime outputs + when no downstream row requires the full bank. + - Preserve the full-bank path only behind legality for rows that actually + require it. + - Do not touch message math or add Q/K/V semantics to scheduler files. + +4. 
**Bind reduced outputs through compiler-owned reducers.** + - Route recurrent sender K/V parameter-gradient inputs and initial recurrent + K/V projection-gradient inputs through `reverse_parameter_reducer_route_rows` + or the equivalent registered runtime buffer rows. + - Ensure Python receives only semantic outputs, reducer outputs, carry/state + outputs, artifacts, or metadata. Full temporary banks must not be returned + just to satisfy an old ABI. + +5. **Measure the exact owner movement.** + - Required accepting movement: + `native_after_recurrent_message_local0` decreases and + `cuda_max_allocated_bytes` decreases on the Axon 100M h32 B1024 T1 + training row. + - If only post-return live allocation moves while max allocated does not, + keep the patch only as a narrow liveness improvement and immediately add + finer native-stage telemetry before further changes. + +Targeted gates for the proceed pass: + +- Static/compiler guardrails: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_plan.py::test_registered_parameter_reducer_cuda_executes_transition_trainable_rows tests/test_fabric_backend_plan.py::test_reverse_span_outputs_are_compiler_owned_rows --tb=short` +- Targeted CUDA parity: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_message_reverse_liveness_parity_20260503 TRITON_CACHE_DIR=/tmp/cortical_triton_t1_message_reverse_liveness_parity_20260503 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific --tb=short` +- Exact perf row: + `CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 ... run_audit --plan t1-single-pop --families axoncell --sizes 100m --modes forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --hidden-sizes 32 --training-output-boundaries terminal --population-modes single --reset-modes absent` +- Four-row non-regression: + same T=1 matrix for `slstm,axoncell` and `forward,forward_backward`. + +Rollback rule: + +- Revert the patch if parity fails, if semantic rows change, if the active route + demotes to fallback/replay/compat, or if neither + `native_after_recurrent_message_local0` nor `cuda_max_allocated_bytes` moves. +- Do not keep metadata-only labels or fixed-slot-only changes as progress. + +### 2026-05-03 - T=1 Remaining Throughput Deep Dive + +Status: analysis only. No optimization was run in this pass. + +Reference target: + +- April 21 `h32_t1_bxparams` boundary: `58732.71 tok/s`, `2.07 GiB`. +- This is the main T=1 target row family: sLSTM + Axon, 100M/500M/1B, + forward + training, B=1024/16384, h=32. + +Current accepted guardrail baseline: + +- Artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_fresh_state_cache_guard_h32_100m_b1024`. +- This is the last clean four-row guardrail after the accepted fresh-state cache + liveness slice. Later reverse-liveness probes that moved only local + allocations are useful owner evidence, but they are not T=1 closure evidence + because `cuda_max_allocated_bytes` did not move. 
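+
+For orientation, the "vs Apr21" columns in the tables below are plain ratios
+against the April 21 `h32_t1_bxparams` row quoted above; a minimal sketch of
+that arithmetic (`vs_apr21` is an illustrative helper, not audit tooling):
+
+```python
+# April 21 h32_t1_bxparams targets quoted above.
+APR21_TOK_S = 58732.71
+APR21_PEAK_GIB = 2.07
+
+
+def vs_apr21(tok_s, peak_gib):
+    """Return (throughput as % of April 21, memory multiple over the floor)."""
+    return 100.0 * tok_s / APR21_TOK_S, peak_gib / APR21_PEAK_GIB
+
+
+# e.g. the Axon training guardrail row:
+# vs_apr21(839.63, 75.52) -> (~1.43, ~36.5)
+```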
+ +| row | tok/s | vs Apr21 | peak GiB | vs Apr21 memory | named runtime bytes | reverse artifact bytes | unclassified peak bytes | +|---|---:|---:|---:|---:|---:|---:|---:| +| sLSTM forward | `5085.31` | `8.66%` | `16.02` | `7.7x` | `7.45B` | `0` | `9.35B` | +| sLSTM train | `993.00` | `1.69%` | `23.87` | `11.5x` | `8.05B` | `4.73B` | `12.04B` | +| Axon forward | `3712.24` | `6.32%` | `56.48` | `27.3x` | `9.66B` | `0` | `50.59B` | +| Axon train | `839.63` | `1.43%` | `75.52` | `36.5x` | `11.54B` | `10.07B` | `58.68B` | + +Largest current named byte owners: + +- Axon forward runtime buffers: + - `transition_forward_diag_output=3758096384` + - `transition_forward_linear_output=3758096384` + - `forward_recurrent_hidden_after=1879048192` + - `forward_recurrent_msg=1879048192` +- Axon training reverse artifacts: + - `transition_state_before=3758096384` + - `recurrent_hidden_backend_order=1879048192` + - `recurrent_hidden_before_backend_order=1879048192` + - `recurrent_msg_backend_order=1879048192` +- The named compiler products are not enough to explain the peak. The largest + remaining allocator owner is still the unclassified/high-water gap: + `50.59B` on Axon forward and `58.68B` on Axon training. + +Most important remaining T=1 work, in priority order: + +1. **Forward materialization/liveness first.** + - Axon forward alone is already `15.8x` slower and `27.3x` over the April 21 + memory floor. Any training row pays this cost before backward starts. + - The next forward patch should consume compiler liveness/consumer rows to + avoid returning or retaining local-only transition/message intermediates: + transition linear/diag outputs, recurrent-message buffers, and the final + program tensor table. + - Acceptance is not metadata. The Axon forward + `cuda_max_allocated_bytes`, `forward_runtime_buffers_allocated`, + `before_fused_forward_program`, `after_fused_forward_program`, and + `forward_final_state_materialized` stages must move. + +2. **Registered backward native high-water second.** + - After the fresh-state cache slice, the training peak is again inside the + registered backward path. The strongest named stages are + `native_after_recurrent_message_local0`, `native_after_boundary_kv_local0`, + and transition reverse group stages. + - The next backward strategy should target the registered + `neighborhood_attention_project` recurrent-message reverse contract, not + fixed-slot-only paths. Rejected probes already showed fixed-slot key-bank + edits do not move the active Axon owner. + - Likely closure direction: route recurrent K/V adjoints directly into + reducer/runtime buffers or declared workspaces instead of materializing full + banks, but only through reducer/liveness rows. + +3. **Unclassified allocator owner must be named.** + - The biggest memory gap is still not a fully named compiler product. + - Before another alias/no-copy route, add or use phase-scoped telemetry to + decide whether the peak is a forward native temporary, returned program + tensor group, autograd saved tensor, CUDA library temporary, or allocator + reserve/cache gap. + - Public-state alias/no-copy remains rejected for training until the hidden + allocator owner is explained. + +4. **Representative surface is still too narrow.** + - The current evidence is mostly h32 100M B1024 single-pop. + - T=1 closure still needs mixed-pop T=1, B=16384 rows, 500M/1B rows, + h4/h8/h16 stress rows, high-batch small-param rows, reset-present rows, + and final-state/materialized-state axes. 
+ - These should not get separate performance paths. They should validate that + the same registered compiler strategy generalizes. + +5. **Throughput compute fusion comes after liveness moves.** + - Even if compute kernels are slow, current memory peaks are too far from the + April 21 floor to treat compute as the first closure lever. + - Once Axon forward/training memory is within a sane range, the likely compute + owners are registered forward transition/message/readout launches and + registered reverse message/transition/reducer stages. + +Open conclusion: + +- T=1 is not close to April 21 yet. The headline Axon training row is about + `1.43%` of April 21 throughput and `36.5x` the April 21 memory floor. +- The biggest next change should be forward liveness/materialization in the + registered fused forward program, because it affects forward-only and + training rows and is already a major standalone miss. +- The next analysis before implementation should rebuild the current-code + four-row owner table if the working tree has changed since the last accepted + guardrail, then implement exactly one forward owner-moving strategy. + +### 2026-05-03 - Accepted T=1 Forward Fresh-State Cache Slice + +Status: accepted narrow forward/liveness improvement. This is not throughput +closure. + +Boundary classifier: + +```text +Lane: throughput strategy / runtime front-end handoff / liveness +Expected semantic delta: none +Unchanged semantic rows: graph rows, message primitive rows, transition primitive + rows, readout rows, tensor bindings, parameter bindings, output routes, + reset rows, artifact routes, and backward reducer meaning +Changed strategy/runtime rows: planner-owned fresh-state cache policy for + registered CUDA inference without final-state materialization; fresh-zero + compiler state sentinel bindings; fused forward final program tensor return + compaction; transition matmul fresh-zero input handling +Rows consumed directly: registered executor rows, transition program tensor + binding rows, state carry rows, memory/liveness rows, runtime schedule rows +Rejected legacy shape: materializing full current-state population cache for + single-pop T=1 inference when the compiler plan proves no final state is + requested +``` + +What changed: + +- `planner.py` now lets single-pop registered CUDA inference use the same + planner-owned fresh-state cache policy as mixed-pop when there is no requested + final state and no training artifact collection. +- `program_tensors.py` binds missing initial transition state through a + zero-numel sentinel shaped `[B,0,H]`, so the binding carries batch/hidden + metadata without allocating the full state bank. +- The fused forward program clears local transition output binding slots after + declared consumers run and compacts returned program tensors when final state + is not materialized. +- Registered transition matmul treats a fresh-zero input sentinel as a + compiler-owned zero input and writes the declared output buffer as zeros + instead of requiring a materialized state tensor. + +Probe path: + +- Return compaction alone: + `tmp/fabric_audits/partials/2026-05-03/t1_forward_program_tensor_return_compact_axon_h32_100m_b1024` + moved Axon forward peak only from `56.483 GiB` to `56.451 GiB`; kept as a + narrow liveness cleanup, not headline progress. 
+- Fresh-zero sentinel without planner activation: + `tmp/fabric_audits/partials/2026-05-03/t1_fresh_zero_sentinel_axon_forward_h32_100m_b1024` + also stayed at `56.451 GiB`; useful mechanism, but not accepted by itself. +- Single-pop planner fresh cache: + `tmp/fabric_audits/partials/2026-05-03/t1_single_fresh_cache_axon_forward_h32_100m_b1024` + moved Axon forward to `4063.38 tok/s`, `21.450 GiB`. + +Accepted four-row guardrail artifact: + +- `tmp/fabric_audits/partials/2026-05-03/t1_single_fresh_cache_guard_matmul_h32_100m_b1024` + +| row | previous tok/s | new tok/s | previous peak GiB | new peak GiB | result | +|---|---:|---:|---:|---:|---| +| sLSTM forward | `5085.31` | `8610.38` | `16.024` | `11.526` | moved | +| sLSTM train | `993.00` | `989.48` | `23.871` | `23.874` | unchanged | +| Axon forward | `3712.24` | `4061.44` | `56.483` | `21.485` | moved | +| Axon train | `839.63` | `841.93` | `75.522` | `75.525` | unchanged | + +Relative to the April 21 `58732.71 tok/s`, `2.07 GiB` target: + +- sLSTM forward is now `14.66%` of target throughput and `5.6x` the memory + floor. +- Axon forward is now `6.91%` of target throughput and `10.4x` the memory + floor. +- Training rows are still effectively unchanged: sLSTM train is `1.68%` of + target throughput and Axon train is `1.43%`. + +Targeted gates run: + +```bash +uv run pytest -q \ + tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings \ + tests/test_fabric_backend_plan.py::test_fabric_temporal_execution_plan_records_fresh_multi_population_cache_policy \ + tests/test_fabric_backend_plan.py::test_fabric_temporal_execution_plan_records_fresh_single_population_cache_policy \ + --tb=short +# 3 passed + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_fresh_zero_matmul_parity_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_fresh_zero_matmul_parity_20260503 \ +uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific \ + --tb=short +# 4 passed +``` + +Next owner: + +- Forward inference still has large named runtime buffers: + `transition_forward_linear_output`, `transition_forward_diag_output`, + `transition_forward_matmul_output`, `transition_forward_norm_output`, + `forward_recurrent_msg`, and `forward_recurrent_hidden_after`. +- Axon training still peaks inside the registered backward path around + `native_after_recurrent_message_local0`; this remains the first training + owner after forward liveness stops moving. +- Do not expand public-state alias/no-copy routes until the hidden allocator + owner is named. The failed 128GB alias probe remains rejected. + +### 2026-05-03 - Remaining T=1 Throughput Owners After Fresh-State Cache + +Status: analysis only. No optimization was done in this pass. + +Current accepted current-code artifact: + +- `tmp/fabric_audits/partials/2026-05-03/t1_single_fresh_cache_guard_matmul_h32_100m_b1024` + +Reference target: + +- April 21 `h32_t1_bxparams`: `58732.71 tok/s`, `2.07 GiB`. 
+ +Current owner table: + +| row | tok/s | vs Apr21 | peak GiB | memory vs Apr21 | named runtime bytes | reverse artifact bytes | estimated unclassified/current-gap bytes | +|---|---:|---:|---:|---:|---:|---:|---:| +| sLSTM forward | `8610.38` | `14.66%` | `11.526` | `5.6x` | `7.45B` | `0` | `4.93B` | +| sLSTM train | `989.48` | `1.68%` | `23.874` | `11.5x` | `8.05B` | `4.73B` | `12.85B` | +| Axon forward | `4061.44` | `6.92%` | `21.485` | `10.4x` | `9.66B` | `0` | `13.41B` | +| Axon train | `841.93` | `1.43%` | `75.525` | `36.5x` | `11.54B` | `10.07B` | `59.49B` | + +Top forward runtime roles: + +- Axon forward: + - `transition_forward_linear_output=3.758B` + - `transition_forward_diag_output=3.758B` + - `forward_recurrent_msg=1.879B` + - `forward_recurrent_hidden_after=1.879B` +- sLSTM forward: + - `transition_forward_linear_output=3.020B` + - `transition_forward_matmul_output=2.416B` + - `transition_forward_state_output=0.604B` + - `transition_forward_norm_output=0.604B` + - `forward_recurrent_msg=0.604B` + - `forward_recurrent_hidden_after=0.604B` + +Top training/native stages: + +- Axon train: + - `native_after_recurrent_message_local0=77.336B allocated` + - `native_after_boundary_kv_local0=75.595B allocated` + - `native_transition_group_after_reverse_primitive_local0=75.187B allocated` + - `native_after_transition_local0=75.187B allocated` +- sLSTM train: + - `native_after_recurrent_message_local0=24.025B allocated` + - `native_after_boundary_kv_local0=23.525B allocated` + - `native_transition_group_after_reverse_primitive_local0=23.219B allocated` + +Important interpretation: + +- The accepted fresh-state cache fixed forward-only state materialization for + no-final-state inference. It did not move training because the current planner + still reports `training_requires_materialized_state`. +- Forward-only is still far from closed. Axon forward is still about `14.0 GiB` + live at the registered runtime-buffer stage and `21.5 GiB` high-water. +- The high-water gap inside forward is still not fully named. Axon forward has + about `15.1B` allocated after runtime buffers but `23.1B` max allocated. The + missing owner is likely a native fused-forward temporary, returned tensor + group, or allocator high-water inside the C++ program launch. It needs + native-stage telemetry before another alias/no-copy route. +- Training has two distinct blockers: + - forward/tape setup still materializes full fresh state and full artifacts; + - reverse native recurrent-message/KV stages still materialize large adjoint + banks and span outputs. + +Remaining work, in priority order: + +1. **Forward runtime-buffer liveness and workspace reuse.** + - Convert local-only transition primitive outputs from persistent runtime + buffers into compiler-liveness-owned workspaces/aliases. + - The first target is Axon forward: + `transition_forward_linear_output`, `transition_forward_diag_output`, + `forward_recurrent_msg`, and `forward_recurrent_hidden_after`. + - sLSTM must follow the same mechanism for + `transition_forward_linear_output`, `transition_forward_matmul_output`, + `transition_forward_norm_output`, and `transition_forward_state_output`. + - Acceptance: Axon forward `forward_runtime_buffers_allocated`, + `after_fused_forward_program`, and `cuda_max_allocated_bytes` all move. + +2. 
**Training fresh-state/tape artifact liveness.** + - Extend the fresh-zero state contract from inference into training artifact + and tape rows where the compiler proves the initial state is zero and no + user-provided state tensor is required. + - Backward should consume compiler-owned zero/tape sentinels for + state-before artifacts instead of forcing frontend `current_state` + materialization. + - This is forward/tape compiler work, not reverse math work. It is likely the + highest-impact training memory prerequisite because Axon training still + carries about `20.94B` of fresh current-state tensor before registered + execution. + +3. **Name and then reduce the forward high-water gap.** + - Add or use native fused-forward memory stage rows around message recurrent + K/V, message aggregation, transition groups, readout, output assembly, and + program-tensor return. + - Do not accept metadata-only stage labels; the next optimization must point + to a real allocation/lifetime edge. + - Keep public-state alias/no-copy rejected until the hidden `~7-8B` Axon + forward high-water gap is classified. + +4. **Registered reverse recurrent-message/KV liveness.** + - Once forward/tape liveness stops moving training, attack + `native_after_recurrent_message_local0`. + - Route recurrent K/V adjoints directly into reducer/runtime buffers when + compiler rows prove no full bank is needed. + - The implementation must consume reverse span-output rows, reducer rows, + tensor bindings, and liveness rows. It must not revive fixed-slot or + role-only output assumptions. + +5. **Compute/launch fusion after memory is sane.** + - Even after memory moves, current forward throughput is only `6.9%` to + `14.7%` of April 21 and training is about `1.4%` to `1.7%`. + - The likely compute closure work is registered program-level fusion across + transition/message/readout primitive rows, not benchmark tiling or legacy + kernels. + - Start compute tuning only after the representative rows no longer sit + `5x-36x` over the memory floor. + +6. **Expand the representative T=1 surface.** + - Current accepted evidence is h32 100M B1024 single-pop. + - Before claiming T=1 closure, run the same compiler-owned strategy through: + mixed-pop T=1, B=16384, 500M/1B, h4/h8/h16, reset-present rows, + materialized-final-state rows, and the dot-product semantic stress test + queued before throughput closure. + +Next best target: + +- Start with item 1 if the immediate goal is forward-only progress. +- Start with item 2 if the immediate goal is the April21 boundary row, because + training cannot move while fresh training state/tape still materializes large + state banks before backward starts. + +### 2026-05-03 - Plan: T=1 Forward Runtime-Buffer Liveness + +Status: plan only. No optimization was done in this pass. + +Selected owner: + +- `registered_temporal_fused_forward_program_cuda` runtime-buffer liveness. +- First representative row: + `t1-single-pop_axoncell_100m_forward_b1024_t1_k1_h32`. +- Current row: + `4061.44 tok/s`, `21.485 GiB`, + `forward_runtime_buffers_allocated=15.133B allocated`. 
+ +Boundary manifest: + +```text +Lane: throughput strategy / memory-liveness +Expected semantic delta: none +Unchanged rows: graph rows, message primitive rows, transition primitive rows, + readout rows, tensor bindings, parameter bindings, output routes, reset rows, + artifact rows, and backward gradient meaning +Changed rows/contracts: runtime-buffer liveness, workspace/alias assignment, + optional native memory telemetry, and possibly output-init policy for + local-only primitive outputs +Rows consumed directly: memory_liveness_rows, memory_runtime_schedule_rows, + forward_executor_rows, forward_executor_binding_rows, native callable output + rows, program tensor binding rows, transition state-carry rows +Illegal shortcut: no cell-family, benchmark, hidden-size, single-pop, + Q/K/V/gated/diagonal semantic branch; no benchmark tiling; no old kernel copy +``` + +Hypothesis: + +- Several forward transition/message tensors are allocated for the whole + runtime-buffer table even when their live interval is local to a primitive + group or a single consumer. +- If the memory plan routes local-only outputs to reusable workspace/alias sets, + Axon forward should reduce: + - `forward_runtime_buffers_allocated` + - `after_fused_forward_program` + - `cuda_max_allocated_bytes` + +Target tensors: + +- Axon: + - `transition_forward_linear_output=3.758B` + - `transition_forward_diag_output=3.758B` + - `forward_recurrent_msg=1.879B` + - `forward_recurrent_hidden_after=1.879B` +- sLSTM: + - `transition_forward_linear_output=3.020B` + - `transition_forward_matmul_output=2.416B` + - `transition_forward_state_output=0.604B` + - `transition_forward_norm_output=0.604B` + - `forward_recurrent_msg=0.604B` + - `forward_recurrent_hidden_after=0.604B` + +Implementation plan: + +1. **Classify actual consumers by compiler row.** + - For each forward runtime role, map producer primitive row, output binding, + downstream input binding, artifact requirement, final-state requirement, + and backward/tape requirement. + - Add a fail-closed validation if a role is marked local-only but has a + downstream artifact/tape/final-state consumer. + +2. **Add native fused-forward memory stage telemetry.** + - Record metadata-only stage rows around: + recurrent K/V projection, recurrent message aggregation, transition group + entry/exit, readout, cell assembly, and program tensor return. + - This must not add semantic returns. It is only to name the current + `~7-8B` Axon forward high-water gap. + +3. **Introduce compiler-owned workspace/alias rows for local-only outputs.** + - Extend the runtime buffer plan so eligible primitive outputs can share a + workspace alias set when liveness proves non-overlap. + - Start with one role family, preferably transition linear/diag outputs, + because those are the largest named Axon forward owners. + - Keep the liveness decision in `memory_liveness_rows` / + `memory_runtime_schedule_rows`, not inside the temporal scheduler or + strategy body. + +4. **Make fused forward consume the workspace contract.** + - Native output lookup should still use native callable output rows and + program tensor bindings. + - If a tensor is local-only, it may be reused or dropped after its declared + consumers run. + - If a tensor is artifact/final-state/tape required, preserve it. + +5. **Run narrow probe, then four-row guardrail.** + - Probe first on Axon forward only. + - If it moves the named stages, run sLSTM/Axon forward + training guardrail. 
+ - If forward moves but training does not, record that and proceed next to + training fresh-state/tape artifact liveness. + +Targeted gates: + +```bash +uv run pytest -q \ + tests/test_fabric_backend_boundaries.py::test_registered_program_allocations_are_compiler_classified \ + tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings \ + tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes \ + tests/test_fabric_backend_plan.py::test_forward_fused_program_runtime_facts_are_compiler_owned_rows \ + --tb=short + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_forward_runtime_liveness_parity_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_forward_runtime_liveness_parity_20260503 \ +uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific \ + --tb=short +``` + +Perf probe: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_forward_runtime_liveness_probe_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_forward_runtime_liveness_probe_20260503 \ +uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-03/t1_forward_runtime_liveness_axon_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families axoncell \ + --sizes 100m \ + --modes forward \ + --batches 1024 \ + --seq-lens 1 \ + --inner-steps 1 \ + --gradient-horizon-steps none \ + --checkpoint-steps none \ + --hidden-sizes 32 \ + --training-output-boundaries terminal \ + --population-modes single \ + --reset-modes absent \ + --warmup 1 \ + --iterations 2 \ + --require-cuda-temporal-owner +``` + +Four-row non-regression: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_forward_runtime_liveness_guard_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_forward_runtime_liveness_guard_20260503 \ +uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-03/t1_forward_runtime_liveness_guard_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell \ + --sizes 100m \ + --modes forward,forward_backward \ + --batches 1024 \ + --seq-lens 1 \ + --inner-steps 1 \ + --gradient-horizon-steps none \ + --checkpoint-steps none \ + --hidden-sizes 32 \ + --training-output-boundaries terminal \ + --population-modes single \ + --reset-modes absent \ + --warmup 1 \ + --iterations 2 \ + --require-cuda-temporal-owner +``` + +Keep / narrow / revert rule: + +- Keep only if Axon forward moves in at least one physical owner: + `forward_runtime_buffers_allocated`, `after_fused_forward_program`, native + fused-forward stage allocation, or `cuda_max_allocated_bytes`. +- Narrow if only one role moves and parity is green. 
+- Revert if movement is metadata-only, if peak memory increases, if training + regresses without a named owner, or if the active route demotes to fallback, + replay, compatibility, or benchmark-owned scheduling. + +Follow-up if accepted: + +- If forward-only moves but training remains unchanged, next plan is + `training fresh-state/tape artifact liveness`, not compute fusion. +- If forward high-water is still mostly unnamed, first add finer native + fused-forward telemetry before additional alias/no-copy work. + +## T=1 Forward Runtime-Buffer Liveness Result + +Status: accepted as a narrow compiler-owned liveness/reducer fix; not a +throughput-closure result. + +Implemented: + +- Added compiler-owned `allocation=deferred_local` runtime-buffer specs for + local transition-forward outputs when the scheduler plan is `none`, artifacts + are not collected, final state is not materialized, and the role is not a + public-state alias. Deferred specs allocate a zero-length placeholder and are + materialized only when a native callable output row actually consumes them. +- Kept the policy in memory/liveness rows and registered program buffer + validation. The temporal scheduler does not infer local-only transition tensor + lifetimes by name. +- Fixed fixed-slot context message static materialization so + `recurrent_sender_value_weight` is compiler-bound even when full cell K/V + weights are not materialized. +- Fixed transition input-projection reducer routing for the fixed-slot message + path: when compiler bindings select `message_to_cell_weight`, 2D direct + gradients reduce to that static source and do not enter factorized + `[R,V,P]` projection-backward code. + +Verification: + +```bash +uv run pytest -q \ + tests/test_fabric_backend_boundaries.py::test_registered_program_allocations_are_compiler_classified \ + tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings \ + tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes \ + tests/test_fabric_backend_plan.py::test_forward_fused_program_runtime_facts_are_compiler_owned_rows \ + --tb=short +# 4 passed + +uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_transition_unfuse_routes_direct_message_to_cell_grad_from_compiler_source \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_dot_product_selects_registered_strategy \ + --tb=short +# 2 passed + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_final_parity_after_liveness_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_final_parity_after_liveness_20260503 \ +uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific \ + --tb=short +# 4 passed in 61.26s + +uv run pytest -q \ + tests/test_fabric_backend_boundaries.py \ + tests/test_fabric_backend_plan.py \ + tests/test_fabric_execution_imports.py \ + --tb=short +# 175 passed +``` + +Perf/audit artifacts: + +- Forward liveness probe: + `tmp/fabric_audits/partials/2026-05-03/t1_forward_runtime_liveness_axon_h32_100m_b1024` +- Axon training route check after reducer fix: + 
`tmp/fabric_audits/partials/2026-05-03/t1_forward_runtime_liveness_axon_train_fixed_h32_100m_b1024` +- Reduced four-row route guard: + `tmp/fabric_audits/partials/2026-05-03/t1_forward_runtime_liveness_guard_accept_h32_100m_b1024` + +Observed movement: + +| Row | Status | tok/s | Peak GiB | Runtime buffers | Forward runtime-buffer stage | Unclassified peak | +| --- | --- | ---: | ---: | ---: | ---: | ---: | +| Axon forward probe, warmed | ok | 4052.27 | 17.81 | 4.03B | 9.46B | 14.70B | +| sLSTM forward, reduced guard | ok | 18.02 | 12.76 | 1.41B | 3.74B | 11.89B | +| sLSTM training, reduced guard | ok | 13.18 | 24.91 | 8.05B | 15.23B | 13.05B | +| Axon forward, reduced guard | ok | 479.89 | 21.53 | 4.03B | 9.57B | 18.69B | +| Axon training, reduced guard | ok | 3.94 | 81.16 | 13.42B | 56.55B | 62.72B | + +Notes: + +- The reduced four-row guard used `warmup=0, iterations=1` to keep the route + check bounded while Axon training remains extremely slow. Its tok/s values + are current-route evidence, not warmed throughput acceptance numbers. +- The Axon training row previously failed in the reducer with + `factorized recurrent input projection backward expects [R,H,P], [R,V,P], + and [H,V]`; it now completes through the registered compiler-owned route. +- Deferred-local transition outputs materially reduce the named forward runtime + buffer allocation for forward inference. They do not close training memory: + Axon training is still dominated by reverse artifacts, backward runtime + buffers, and a large unclassified native/allocator peak. + +Next owner: + +- Training memory/liveness, not forward local transition outputs. The next pass + should classify and reduce `reverse_artifacts`, backward runtime buffers, and + the `native_after_transition_keep_slots_local0` / unclassified peak owner for + Axon training. Public-state alias/no-copy remains rejected until that owner is + named and bounded. + +## T=1 Remaining Throughput Analysis After Fixed-Slot Stress + +Status: analysis only. No optimization was run in this pass. + +Important evidence correction: + +- The latest branch now lowers the default dot-product message rule to the + fixed-slot/context-nudge executor: + `fixed_slot_context_nudge_message` / + `fixed_slot_context_nudge_message_backward`. +- Older warmed rows that report `neighborhood_attention_project` are useful for + memory/liveness owner comparison, but they are stale for final T=1 throughput + closure on the current semantics. +- The latest fixed-slot four-row guard was intentionally reduced + (`warmup=0, iterations=1`) to prove the route after the reducer fix. Its + tok/s values are not acceptance throughput numbers. + +Current target: + +- April 21 `h32_t1_bxparams`: `58732.71 tok/s`, `2.07 GiB`, covering sLSTM + + Axon, 100M/500M/1B, forward + training, B=1024/16384. 
+ +Current usable evidence: + +| Evidence | Executor family | Scope | What it proves | +| --- | --- | --- | --- | +| `t1_single_fresh_cache_guard_matmul_h32_100m_b1024` | old neighborhood message | warmed four-row matrix | last comparable memory/liveness baseline before fixed-slot semantics | +| `t1_forward_runtime_liveness_axon_h32_100m_b1024` | old neighborhood message | warmed Axon forward probe | deferred-local liveness moves forward runtime-buffer allocation | +| `t1_forward_runtime_liveness_slstm_fix_h32_100m_b1024` | fixed-slot context nudge | single sLSTM forward row | fixed-slot route is active and much slower than the old warmed forward path | +| `t1_forward_runtime_liveness_guard_accept_h32_100m_b1024` | fixed-slot context nudge | reduced four-row route guard | all four current semantic rows complete; Axon training reducer failure is fixed | + +Representative gaps: + +| Row/evidence | tok/s | vs Apr21 | Peak GiB | Memory vs Apr21 | Notes | +| --- | ---: | ---: | ---: | ---: | --- | +| sLSTM forward, old warmed neighborhood | `8610.38` | `14.66%` | `11.53` | `5.6x` | stale for current fixed-slot semantics | +| Axon forward, old warmed neighborhood after deferred-local liveness | `4052.27` | `6.90%` | `17.81` | `8.6x` | stale for current fixed-slot semantics, useful memory proof | +| sLSTM forward, fixed-slot single row | `1457.60` | `2.48%` | `12.81` | `6.2x` | current semantics; not full four-row acceptance | +| Axon forward, fixed-slot reduced guard | `479.89` | `0.82%` | `21.53` | `10.4x` | current semantics; reduced/cold guard | +| Axon training, fixed-slot reduced guard | `3.94` | `<0.01%` | `81.16` | `39.2x` | current semantics; reduced/cold guard, route now completes | + +Biggest remaining owners, in priority order: + +1. **Rebuild a warmed current fixed-slot T=1 owner table.** + - The current semantic route changed. Before accepting any throughput claim, + rerun the four-row h32 100M B1024 single-pop matrix with warmup and private + extension/cache dirs. + - The reduced guard can only prove route health. It cannot rank final compute + or closure progress. + +2. **Optimize fixed-slot/context-nudge message executors through registered + strategy rows.** + - The new message program has more primitive rows: query slot projection, + recurrent value projection, query nudge, sender slot/context key work, + input/recurrent value projection, attention, weighted sum, output + projection, and normalize. + - This is now a top forward compute owner. The old neighborhood-attention + performance path is no longer the semantic target. + - Any optimization must stay inside the registered message strategy over + primitive rows/tensor bindings. Do not revive Q/K/V fixed slots or temporal + scheduler formulas. + +3. **Training memory/liveness remains the top training prerequisite.** + - Latest fixed-slot Axon training route guard peaks at `81.16 GiB`. + - The dominant named stages are still registered backward native stages: + `native_after_recurrent_message_local0`, `native_after_boundary_kv_local0`, + `native_transition_group_after_reverse_primitive_local0`, and + `native_after_transition_keep_slots_local0`. + - The next training memory patch should classify and reduce reverse + artifacts, backward runtime buffers, and recurrent-message/KV adjoint + materialization through reducer/liveness rows. + +4. 
**Forward high-water still needs native-stage attribution.** + - Deferred-local transition outputs reduce named runtime-buffer allocation, + but current fixed-slot forward still peaks around `12.8-21.5 GiB`. + - The forward gap is now partly compute and partly native high-water. Add or + use fused-forward native-stage telemetry before adding more alias/no-copy + routes. + +5. **Coverage expansion is still required after the representative row moves.** + - T=1 closure still requires mixed-pop T=1, B=16384, 500M/1B, h4/h8/h16, + reset-present rows, final-state/materialized-state axes, and the queued + dot-product stress semantics as the same compiler-owned strategy surface. + +Boundary constraints for the next implementation: + +- Semantics are fixed for throughput. Do not change message/cell/readout math in + a throughput pass. +- New speed must come from registered compiler-owned strategies over primitive + rows, tensor bindings, output/artifact routes, reducer rows, and + memory/liveness rows. +- April 21 is the baseline target, not a source to copy. +- Public-state alias/no-copy remains rejected until the native/unclassified + high-water owner is named and bounded. + +## April 21 Code Reference Translation + +Status: research note only. No code change was made for this section. + +Source inspected: + +- Local git commit: `d30d1b64786a54777fc44e1f4d099e6da65f63d1`, the latest + local source commit before the April 21 baseline window available in this + checkout. +- Old implementation surface: + `src/cortex/fabric/runtime.py`, + `src/cortex/fabric/families.py`, + `src/cortex/fabric/anatomy.py`, and + `src/cortex/kernels/cuda/fabric/sparse_message_kernels.cu`. +- This is not the current compiler path. The old code is useful only as a + semantic/performance oracle for mechanisms that must be re-expressed through + current declaration rows, primitive rows, executor rows, tensor bindings, + artifact/output routes, reducer rows, and memory/liveness rows. + +Semantically useful references: + +1. **Consumer-projected message values.** + The old runtime folded message-output projection into consumer projections: + `value_to_cell_weight = msg_to_cell.weight @ msg_out.weight` and + `value_to_output_weight = einsum(msg_out.weight, output_cell_weight)`. + Translation for the current compiler path: + - add registered strategy variants that consume the existing message weighted + value role and write transition-input or output-cell roles directly when + artifact/tape policy proves `output_msg` or `recurrent_msg` is not needed; + - keep the fold as a tensor-binding/strategy choice, not a temporal scheduler + formula; + - preserve artifact routes for training, debug, and backward rows that + require the logical message tensor. + +2. **T=1/K=1 step shape.** + The old `_forward_stream_step_k1` path did not run a generic sequence window + for the common T=1, K=1 case. It projected sender K/V, computed one + recurrent message, ran one recurrent transition update, then computed output + cells from the after-transition state. + Translation: + - implement this as a registered T=1 terminal forward strategy over the same + primitive rows, not as a family or benchmark selector; + - use output/artifact route rows to decide whether recurrent K-after, + recurrent V-after, `recurrent_msg`, `output_msg`, and output cells must be + materialized; + - if final state is not requested and artifacts are not collected, the + strategy may write only the declared public output and required carry + products. + +3. 
**Packed state carry for step execution.** + The old family modules had `pack_step_state`, `forward_step_packed`, and + `unpack_step_state` paths for both sLSTM and Axon. That avoided repeated + TensorDict reshaping and per-step state materialization. + Translation: + - carry transition state through compiler state-carry rows in backend-native + packed layouts; + - require explicit state layout metadata and reset ownership; + - do not expose packed state as a cell-family side route or public API + shortcut. + +4. **Partitioned graph layout as a plan fact.** + The old runtime used a fast partitioned layout when input, recurrent, and + output cell banks were contiguous. In that case it used slices and + concatenation instead of index-select/scatter. + Translation: + - represent this as graph/layout rows and access rows; + - let registered strategies specialize on legal contiguous layout facts; + - never branch on `single_pop`, lattice constructor name, or benchmark id. + +5. **Constant step/reset facts.** + The old code cached constant step tensors and used `all_active` for K=1 rows + to skip per-row blend logic. + Translation: + - promote constant-step, all-active, no-reset, and reset-present facts into + reset/activity/liveness rows; + - let strategies choose branchless kernels when those rows prove legality. + +6. **Sparse attention kernel shape.** + The old CUDA sparse-message kernel emitted only an attention-weighted value + tensor. Projection, output cell projection, and readout lived outside that + kernel. + Translation: + - keep attention as a registered message primitive strategy; + - for the current fixed-slot context-nudge semantics, do not copy the old + dynamic Q/K/V kernel; + - use the old shape as evidence that weighted-value first, then + consumer-owned projection, is a good compiler strategy boundary. + +Non-actions: + +- Do not restore `src/cortex/fabric` or any old `cortex.kernels.cuda.fabric` + route as a sibling execution path. +- Do not copy April 21 kernels into the registered program. Any mechanism from + the old code must be rebuilt as a strategy over current compiler rows and must + pass parity plus owner movement. +- Do not use old Config fields, family names, graph constructors, or benchmark + switches as strategy selectors. + +Next forward research task from this translation: + +- Build a row-owned "consumer-projected message" strategy for T=1 forward: + recurrent weighted-value to transition-input directly and output weighted-value + to output-cell directly when artifact/output routes prove the logical message + tensor is dead. Expected owner movement: `forward_recurrent_msg`, `output_msg`, + message GEMM/normalization temporaries, and some transition/readout high-water + allocation. Reject if primitive rows or tensor bindings change, or if peak + memory/timing does not move on the warmed h32 100M B1024 forward rows. + +## T=1 Forward Fixed-Slot Message Strategy Slice + +Status: accepted forward-only registered strategy improvement. This is not T=1 +closure. + +Boundary manifest: + +- Surface: message primitive executor strategy. +- Lane: throughput strategy / native implementation. +- Unchanged semantics: default fixed-slot context-nudge dot-product message + rule, Axon public norm epilogue, transition rows, readout rows, output routes, + artifact routes, and memory/liveness rows. +- Changed implementation: fixed-slot context message forward native callable. 
+ The strategy now computes attention-weighted values with a warp-per-row path + for degree `<= 32`, applies the value-to-message projection through a + registered GEMM step, then applies a native row-normalization epilogue. The + old scalar fixed-slot message kernel was deleted. +- Rows consumed: existing primitive rows, forward executor rows, native callable + binding rows, program access rows, and runtime buffer rows for + `forward_recurrent_msg`. +- Rejected shortcut: no semantic math toggle, no scheduler formula, no fixed + Q/K/V route revival, no April21 code copy. + +Verification: + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_fixedslot_gemm_message_tests_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_fixedslot_gemm_message_tests_20260503 \ +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_matches_nudge_when_scalar_binding_is_equal \ + --tb=short +# 3 passed in 60.15s + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_fixedslot_gemm_message_cleanup_tests_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_fixedslot_gemm_message_cleanup_tests_20260503 \ +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_matches_nudge_when_scalar_binding_is_equal \ + --tb=short +# 3 passed in 60.22s after deleting the stale scalar fixed-slot message kernel + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_fixedslot_warp_message_tests_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_fixedslot_warp_message_tests_20260503 \ +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_matches_nudge_when_scalar_binding_is_equal \ + --tb=short +# 3 passed in 60.25s + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_fixedslot_warp_message_parity_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_fixedslot_warp_message_parity_20260503 \ +uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific \ + --tb=short +# 4 passed in 59.61s + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_transition_gate_bmm_tests_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_transition_gate_bmm_tests_20260503 \ +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + 
tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_matches_nudge_when_scalar_binding_is_equal \ + --tb=short +# 3 passed in 58.92s + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_transition_gate_bmm_parity_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_transition_gate_bmm_parity_20260503 \ +uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific \ + --tb=short +# 4 passed in 61.27s + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_forward_bmm_final_tests_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_forward_bmm_final_tests_20260503 \ +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_matches_nudge_when_scalar_binding_is_equal \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + --tb=short +# 5 passed in 62.58s after reverting the recurrent-matmul BMM probe + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_forward_dense_final_tests_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_forward_dense_final_tests_20260503 \ +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_matches_nudge_when_scalar_binding_is_equal \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + --tb=short +# 5 passed in 60.49s after reverting the readout-projection BMM probe +``` + +Perf artifacts: + +- Current fixed-slot owner table: + `tmp/fabric_audits/partials/2026-05-03/t1_fixedslot_current_owner_table_h32_100m_b1024`. +- Row-oriented message attention/projection baseline: + `tmp/fabric_audits/partials/2026-05-03/t1_fixedslot_weighted_row_message_forward_guard_h32_100m_b1024`. +- Accepted GEMM-backed message projection: + `tmp/fabric_audits/partials/2026-05-03/t1_fixedslot_gemm_message_forward_guard_h32_100m_b1024`. +- Accepted warp+GEMM message strategy: + `tmp/fabric_audits/partials/2026-05-03/t1_fixedslot_warp_message_forward_guard_h32_100m_b1024`. +- Accepted transition gate-affine batched-GEMM strategy: + `tmp/fabric_audits/partials/2026-05-03/t1_transition_gate_bmm_forward_guard_h32_100m_b1024`. 
+- Declared dynamic-key/value semantic sanity control:
+  `tmp/fabric_audits/partials/2026-05-03/t1_message_math_sanity_dynamic_key_value_h32_100m_b1024`.
+
+Forward movement:
+
+| Row | Initial fixed-slot | Row attention/projection | GEMM projection | Warp+GEMM | Transition gate BMM | Dynamic-key/value control | April21 summary floor |
+| --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
+| sLSTM forward tok/s | `1458.53` | `7433.82` | `8364.25` | `9715.57` | `10888.76` | `8519.98` | `58732.71` |
+| sLSTM peak GiB | `12.76` | `12.76` | `12.76` | `12.76` | `12.76` | `11.53` | `2.07` |
+| Axon forward tok/s | `480.36` | `3474.13` | `4139.69` | `5379.03` | `5354.83` | `3926.88` | `58732.71` |
+| Axon peak GiB | `21.53` | `21.50` | `21.50` | `21.50` | `21.50` | `19.56` | `2.07` |
+
+Sanity-check conclusion:
+
+- Switching the declared message rule back to `dynamic_key_value` does not
+  recover April21. It lands near the improved fixed-slot route, not near
+  `58732.71 tok/s`.
+- Therefore the new attention math explains only part of the regression. The
+  remaining forward gap is not a semantic-math-only problem.
+- The next forward owner is still the registered fused forward
+  program/native-stage path: remaining message/readout launch geometry,
+  transition program materialization, and the forward high-water allocator
+  stage.
+- Dense primitive work should prefer registered GEMM/batched-GEMM strategies
+  over scalar per-element CUDA loops when the compiler rows expose dense affine
+  or projection structure. The message output projection and sLSTM gate affine
+  probes both moved throughput through that route.
+- Keep the warp+GEMM message strategy because it moved the actual high-level
+  route, passed parity, deleted the stale scalar message kernel, and did not
+  increase the measured forward peak in the representative guard.
+- Keep the transition gate-affine batched-GEMM path as a narrow sLSTM forward
+  improvement. It passed the same T=1 registered parity subset and did not
+  increase the measured representative peak; Axon was flat because this owner
+  mainly affects the gated transition row.
+- Rejected recurrent-matmul generic BMM probe:
+  `tmp/fabric_audits/partials/2026-05-03/t1_recurrent_matmul_bmm_forward_guard_h32_100m_b1024`.
+  It changed sLSTM forward from `10888.76` to `10545.54 tok/s` and Axon from
+  `5354.83` to `5365.04 tok/s` with no memory movement. The patch was reverted;
+  this primitive needs either the existing scalar path or a future fused/tiled
+  registered strategy that avoids layout-copy overhead.
+- Rejected readout-projection generic BMM probe:
+  `tmp/fabric_audits/partials/2026-05-03/t1_readout_bmm_forward_guard_h32_100m_b1024`.
+  It changed sLSTM forward from `10888.76` to `10784.30 tok/s` and peak from
+  `12.76` to `12.95 GiB`; Axon changed only from `5354.83` to `5378.30 tok/s`.
+  The patch was reverted. Generic BMM with layout copies is not automatically a
+  win; accepted dense strategies need row layout that feeds GEMM without adding
+  temporary-heavy transposes.
+
+## Rejected Shortcut List
+
+- Copying April 21 code directly.
+- Optimizing benchmark-side chunking, detach policy, or private helper calls.
+- Treating small 1M/high-batch rows as T=1 closure by themselves.
+- Accepting compiler-owned metadata without active runtime owner movement.
+- Adding formulas or strategy rules to temporal scheduler files instead of
+  registered primitive executors.
+- Closing by averages while boundary/worst rows still fail.
+ +## T=1 Remaining Throughput Deep Dive After Fixed-Slot Forward Strategy + +Status: analysis only. No optimization was run in this pass. + +Current strict target remains the April 21 `h32_t1_bxparams` summary floor: +`58732.71 tok/s`, `2.07 GiB`, covering sLSTM + Axon, 100M/500M/1B, +forward + training, B=1024/16384. + +Current accepted forward-only evidence after the fixed-slot message and +transition gate-affine strategy work: + +| Row | Current best tok/s | vs Apr21 | Current peak GiB | Memory vs Apr21 | Accepted artifact | +| --- | ---: | ---: | ---: | ---: | --- | +| sLSTM 100M h32 B1024 T1 forward | `10888.76` | `18.54%` | `12.76` | `6.2x` | `t1_transition_gate_bmm_forward_guard_h32_100m_b1024` | +| Axon 100M h32 B1024 T1 forward | `5354.83` | `9.12%` | `21.50` | `10.4x` | `t1_transition_gate_bmm_forward_guard_h32_100m_b1024` | + +The accepted forward compute changes are real: + +- fixed-slot message attention/output projection moved from the initial scalar + route to warp weighted-value plus registered GEMM projection; +- sLSTM gate affine moved to batched GEMM; +- recurrent-matmul generic BMM and readout generic BMM were rejected because + layout-copy overhead outweighed the math win. + +The remaining forward gap is therefore not "use BMM everywhere." It is: + +1. **Program-level message fusion without extra global temporaries.** + The fixed-slot message strategy still stages weighted values, GEMM output, + and row normalization as separate native operations. The next useful message + strategy should either fuse weighted-value/projection/normalization for h32 + rows or use compiler-owned workspace so GEMM output and normalization do not + inflate the high-water stage. It must consume the existing message primitive + rows and bindings; no Q/K/V route revival. + +2. **Transition program fusion over primitive rows.** + Current forward runtime roles remain large: + sLSTM has `transition_forward_linear_output=2.81 GiB`, + `transition_forward_matmul_output=2.25 GiB`, and smaller state/norm outputs; + Axon has `transition_forward_linear_output=3.5 GiB`, + `transition_forward_diag_output=3.5 GiB`, and + `transition_forward_norm_output=1.75 GiB`. The next transition strategy + should reduce intermediate materialization by fusing adjacent registered + transition primitives or allocating them as row-scoped workspace. Generic BMM + that forces transposes/copies is already rejected. + +3. **Forward high-water attribution.** + The named forward runtime allocation is much smaller than the CUDA high + water. For the current forward artifact, sLSTM reports about `1.31 GiB` + runtime buffers but `12.76 GiB` max allocated; Axon reports about `3.75 GiB` + runtime buffers but `21.50 GiB` max allocated. Native-stage telemetry points + the peak at/after `after_fused_forward_program`, but the exact owner is still + unclassified. Before adding more alias/no-copy routes, name whether the peak + is message GEMM temporaries, transition primitive temporaries, output/readout + materialization, allocator reserve, or frontend/static tensor overlap. + +4. **Training artifact/tape/reverse memory.** + T=1 training still uses `store_step_artifacts` and `transition_tape=full`. 
+ Current fixed-slot training rows route through the registered reverse + program, but memory is still dominated by reverse artifacts, backward runtime + buffers, and native reverse stages such as + `native_after_recurrent_message_local0`, `native_after_boundary_kv_local0`, + `native_transition_group_after_reverse_primitive_local0`, and + `native_after_transition_keep_slots_local0`. This remains the largest + training prerequisite after forward stops moving. + +5. **Warmed current-semantics training table.** + The latest accepted forward-only artifacts are current semantics. Some + training artifacts in the doc are stale old-neighborhood evidence, while the + strict current fixed-slot route guard is useful for route health but too + reduced/cold to close throughput. Before choosing the next backward compute + patch, rerun a warmed current fixed-slot four-row table for sLSTM/Axon + forward + training at h32 100M B1024 with private extension/cache dirs. + +6. **T=1 coverage expansion after the owner moves.** + Closure still needs the April21-shaped T=1 matrix, not only the h32 100M + B1024 steering rows: B=16384, 500M/1B, h4/h8/h16 stress rows, mixed-pop T=1, + reset-present rows, and final-state/materialization axes. + +Priority for the next plan: + +1. Forward first: add native-stage timing/allocator attribution around + registered message, transition, readout, and output materialization. +2. Then implement the largest named forward owner as a registered strategy: + message fusion if message temporaries dominate, transition fusion/workspace + if transition intermediates dominate. +3. Only after forward high-water stops moving, attack training artifact/tape and + reverse native-stage liveness. + +## 2026-05-03 - Forward Transition Liveness Strategy + +Boundary classifier: throughput strategy plus native liveness implementation. + +Unchanged semantic rows: + +- No message, readout, cell, transition, graph, reset, or output-route semantics + change. +- Primitive rows and tensor bindings remain the compiler authority. +- The change targets only registered transition forward primitive-output + lifetimes inside the fused forward program. + +Owner table from +`tmp/fabric_audits/partials/2026-05-03/t1_forward_native_stage_telemetry_h32_100m_b1024`: + +| Row | tok/s | peak GiB | runtime GiB | unclassified GiB | dominant native allocated stage | +| --- | ---: | ---: | ---: | ---: | --- | +| sLSTM 100M h32 B1024 T1 forward | `11012.86` | `12.76` | `1.31` | `11.08` | transition rises to `11.08 GiB`, then recurrent-KV-after/readout/output route reaches `12.76 GiB` | +| Axon 100M h32 B1024 T1 forward | `5390.31` | `21.50` | `3.75` | `17.38` | transition peaks at `21.50 GiB` | + +Runtime-role attribution: + +- sLSTM: `transition_forward_linear_output=2.81 GiB`, + `transition_forward_matmul_output=2.25 GiB`, + `transition_forward_norm_output=0.56 GiB`, + `transition_forward_state_output=0.56 GiB`. +- Axon: `transition_forward_linear_output=3.50 GiB`, + `transition_forward_diag_output=3.50 GiB`, + `transition_forward_norm_output=1.75 GiB`. + +Selected strategy: + +- Add compiler-row last-use clearing for transition primitive input bindings + after each primitive executes. +- Only clear bindings that are proven by `forward_executor_binding_rows` to be + outputs of an earlier primitive in the same transition span and have no later + input consumer in that span. +- Keep public-state outputs and state-carry outputs alive until the existing + compiler-owned public output copy/state-carry code consumes them. 
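+
+A minimal sketch of the release rule, assuming a hypothetical binding-row
+layout of `(primitive_index, slot_id, is_input)` tuples plus a set of
+protected state-carry slot ids; the real `forward_executor_binding_rows` and
+`forward_transition_state_carry_rows` schemas stay compiler-owned and are not
+reproduced here:
+
+```python
+# Sketch only: hypothetical row layout, not the compiler's real schema.
+# binding_rows: iterable of (primitive_index, slot_id, is_input) tuples.
+# carry_slots: slot ids named by state-carry rows; never cleared here.
+
+
+def plan_last_use_clears(binding_rows, carry_slots):
+    """Map primitive index -> slot ids whose last input use is that primitive."""
+    last_input_use = {}
+    produced_in_span = set()
+    for prim_idx, slot_id, is_input in binding_rows:
+        if is_input:
+            last_input_use[slot_id] = prim_idx
+        else:
+            produced_in_span.add(slot_id)
+    clears = {}
+    for slot_id, prim_idx in last_input_use.items():
+        # Clear only tensors produced earlier in the same transition span,
+        # and never an active state-carry source.
+        if slot_id in produced_in_span and slot_id not in carry_slots:
+            clears.setdefault(prim_idx, []).append(slot_id)
+    return clears
+
+
+def run_transition_span(primitives, tensor_table, binding_rows, carry_slots):
+    """Execute transition primitives, dropping inputs at their last use."""
+    clears = plan_last_use_clears(binding_rows, carry_slots)
+    for prim_idx, primitive in enumerate(primitives):
+        primitive(tensor_table)  # each primitive writes its outputs here
+        for slot_id in clears.get(prim_idx, ()):
+            tensor_table[slot_id] = None  # drop the last live reference
+```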
+ +Rows consumed directly: + +- `primitive_rows` +- `forward_executor_rows` +- `forward_executor_binding_rows` +- `program_tensor_binding_rows` +- `memory_liveness_rows` + +Old route avoided: + +- No fixed tensor slots, no family/benchmark/hidden-size branch, no scheduler + formula, no public-state alias expansion, and no April21 code copy. + +Keep/narrow/revert rule: + +- Keep only if the targeted parity subset passes and warmed forward allocator + telemetry shows the transition/native high-water owner moves down without + route fallback. +- Revert if max allocation is flat or higher, or if any training artifact path + loses required transition tape/state artifacts. + +Result: + +- Rejected and reverted. Both span-local and bucket-local transition input + clearing passed the targeted subset, but warmed forward telemetry was flat: + sLSTM stayed at `12.764 GiB` and Axon stayed at `21.502 GiB`. +- The observed owner is not stale C++ references to consumed transition inputs. + It is still live native/runtime transition output materialization plus the + message K/V banks that remain live across transition/readout. + +Next liveness target: + +- Clear `recurrent_k_before` and `recurrent_v_before` from + `RegisteredForwardMessageStepState` immediately after the recurrent message + is emitted when `return_reverse_artifacts == false`. +- This is legal because the compiler artifact policy says no reverse artifacts + are returned, and the forward program's later readout path consumes + `recurrent_k_after`/`recurrent_v_after`, not the before banks. +- Keep only if the same warmed forward telemetry reduces the current/max + native stages without affecting the training artifact parity subset. + +Accepted result: + +- Patch: forward-only/no-artifact registered fused forward clears + `recurrent_k_before` and `recurrent_v_before` after emitting + `recurrent_msg`. +- Artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_message_kv_before_liveness_forward_guard_h32_100m_b1024`. + +| Row | Before tok/s | After tok/s | Before peak GiB | After peak GiB | Before unclassified GiB | After unclassified GiB | +| --- | ---: | ---: | ---: | ---: | ---: | ---: | +| sLSTM 100M h32 B1024 T1 forward | `11012.86` | `10997.97` | `12.764` | `11.077` | `11.076` | `9.388` | +| Axon 100M h32 B1024 T1 forward | `5390.31` | `5391.71` | `21.502` | `16.252` | `17.376` | `12.126` | + +Native-stage movement: + +- sLSTM transition stage dropped from `11.076 GiB` to `9.389 GiB`; the later + recurrent-KV-after/readout/output stages dropped from `12.764 GiB` to + `11.077 GiB`. +- Axon transition stage dropped from `21.502 GiB` to `16.252 GiB`; the later + recurrent-KV-after/readout/output stages dropped from `19.752 GiB` to + `14.502 GiB`. 
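+
+The stage numbers above come from the audit harness owner-timing telemetry. As
+a standalone sketch of the same accounting, using only the public allocator
+counters and illustrative stage names rather than the registered native-stage
+markers:
+
+```python
+import torch
+
+GIB = 1024 ** 3
+
+
+def record_stage(ledger, stage_name):
+    """Append current and peak allocator bytes at a named stage boundary."""
+    torch.cuda.synchronize()
+    ledger.append(
+        {
+            "stage": stage_name,
+            "current_gib": torch.cuda.memory_allocated() / GIB,
+            "max_gib": torch.cuda.max_memory_allocated() / GIB,
+        }
+    )
+
+
+def profile_stages(stages, state):
+    """Run (name, fn) stages in order and return the per-stage ledger."""
+    ledger = []
+    torch.cuda.reset_peak_memory_stats()
+    for name, fn in stages:
+        state = fn(state)
+        record_stage(ledger, name)
+    return state, ledger
+```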
+ +Validation: + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_message_kv_before_liveness_tests_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_message_kv_before_liveness_tests_20260503 \ +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + --tb=short +# 4 passed in 59.67s +``` + +Remaining forward owner: + +- The high-water is now transition-output materialization itself, not K/V-before + lifetime across transition. +- Next forward strategy should reduce or fuse transition primitive output + materialization under compiler output contracts, or route readout to consume + transition public state without retaining extra output banks. It must keep the + same row/binding ownership and must not revive public-state aliasing for + training. + +Rejected follow-up probe: + +- Probe: + `tmp/fabric_audits/partials/2026-05-03/t1_gate_affine_direct_forward_probe_h32_100m_b1024`. +- Change tested: replace only the registered sLSTM gate-affine BMM staging with + the existing direct CUDA kernel while keeping primitive rows, tensor bindings, + and output contracts unchanged. +- Result: reverted. The targeted parity smoke passed, but warmed forward + evidence did not move the owner. sLSTM fell from `10997.97` to + `9833.80 tok/s` while peak stayed flat at `11.077 GiB`; Axon stayed near + `5381.96 tok/s` and `16.252 GiB`, as expected because the gate-affine owner is + not active there. +- Conclusion: the current sLSTM gate-affine BMM remains the better registered + strategy. The next transition patch should not revert to scalar/direct + affine; it should fuse adjacent transition primitives or remove downstream + output retention through compiler-owned state/output contracts. + +Rejected recurrent-message lifetime probe: + +- Probe: + `tmp/fabric_audits/partials/2026-05-03/t1_recurrent_msg_liveness_forward_probe_h32_100m_b1024`. +- Change tested: in forward-only/no-artifact runs, clear the transition + aggregate-message program input and `RegisteredForwardMessageStepState` + `recurrent_msg` immediately after the transition program consumes it. +- Result: reverted. The targeted parity subset passed, but warmed allocator + telemetry was flat. sLSTM stayed at `11.077 GiB` peak with + `native_forward_after_recurrent_kv_after_local0=11.077 GiB`; Axon stayed at + `16.252 GiB` peak with `native_forward_after_transition_local0=16.252 GiB`. +- Conclusion: the remaining high-water is not a stale recurrent-message + reference after transition. The next forward patch needs to change the actual + transition/readout materialization strategy or the recurrent-K/V-after + production/consumption strategy. + +Rejected transition norm in-place probe: + +- Probe: + `tmp/fabric_audits/partials/2026-05-03/t1_transition_norm_inplace_forward_guard_h32_100m_b1024`. +- Change tested: when the compiler runtime buffer plan made a transition norm + output deferred-local, reuse the dead producer input binding and run a + row-wise in-place norm kernel instead of materializing the separate norm + output. +- Result: reverted. 
The targeted compiler/parity subset passed, but warmed + forward telemetry did not move peak memory. sLSTM stayed at `11.077 GiB` + peak (`11011.34 tok/s`), and Axon stayed at `16.252 GiB` peak while dipping + to `5367.46 tok/s`. +- Conclusion: the remaining forward high-water is not the deferred-local norm + output allocation alone. The next forward patch should name or reduce the + transition group's larger live outputs/temporaries (`linear`, `diag`, + `matmul`, K/V-after/readout), preferably through a registered row-group + fusion or workspace plan, not more local aliasing. + +Accepted transition output binding clear-column probe: + +- Probe: + `tmp/fabric_audits/partials/2026-05-03/t1_transition_output_binding_clear_forward_guard_h32_100m_b1024`. +- Hypothesis: `clear_forward_transition_output_binding_slots` was checking + `forward_executor_binding_rows[:, 2]` as if it were a surface column, but + that column is the executor id. Transition spans are identified by the + compiler bucket ordinal column, so most transition output bindings could stay + live after the transition stage. +- Boundary classifier: throughput/liveness fix over existing compiler-owned + `forward_executor_binding_rows` and `program_tensor_binding_rows`. Primitive + rows, formulas, tensor roles, output routes, and artifact semantics are + unchanged. +- Expected owner movement: lower forward transition/native high-water and later + recurrent-KV-after/readout/output stages, especially for gated and primitive + transition spans whose executor ids are not equal to the transition surface + opcode. +- Result: kept. The targeted source/parity gate passed, and warmed forward + telemetry moved the sLSTM post-transition native high-water down without + fallback; the peak owner shifted back to the transition stage itself. Axon + stayed flat, so this closes the stale binding-column liveness bug but does + not close the Axon transition materialization owner. 
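+
+A shape-only illustration of the corrected selection, assuming a hypothetical
+column order `(bucket_ordinal, executor_id, surface, slot_id, is_output)`; the
+real `forward_executor_binding_rows` column layout is compiler-owned and is
+exactly what the fix reads instead of treating column 2 as a surface column:
+
+```python
+from collections import defaultdict
+
+# Hypothetical column indices, for illustration only.
+COL_BUCKET_ORDINAL = 0
+COL_SLOT_ID = 3
+COL_IS_OUTPUT = 4
+
+
+def transition_output_slots_by_bucket(binding_rows, transition_buckets):
+    """Group clearable output slot ids by transition bucket ordinal.
+
+    Keying on the executor-id column (the bug) misses transition spans whose
+    executor id differs from the transition surface opcode; keying on the
+    compiler bucket ordinal column covers every transition span.
+    """
+    slots = defaultdict(set)
+    for row in binding_rows:
+        if row[COL_BUCKET_ORDINAL] in transition_buckets and row[COL_IS_OUTPUT]:
+            slots[row[COL_BUCKET_ORDINAL]].add(row[COL_SLOT_ID])
+    return slots
+```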
+ +| Row | Previous accepted tok/s | New tok/s | Previous peak GiB | New peak GiB | Previous transition stage GiB | New transition stage GiB | +| --- | ---: | ---: | ---: | ---: | ---: | ---: | +| sLSTM 100M h32 B1024 T1 forward | `10997.97` | `10709.16` | `11.077` | `9.389` | `9.389` | `9.389` | +| Axon 100M h32 B1024 T1 forward | `5391.71` | `5376.55` | `16.252` | `16.252` | `16.252` | `16.252` | + +Validation: + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_transition_clear_outputs_tests_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_transition_clear_outputs_tests_20260503 \ +uv run pytest -q \ + tests/test_fabric_backend_boundaries.py::test_forward_transition_access_uses_compiler_program_access_rows \ + tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings \ + tests/test_fabric_backend_plan.py::test_forward_fused_program_runtime_facts_are_compiler_owned_rows \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific \ + --tb=short +# 5 passed in 60.18s +``` + +Perf probe: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_transition_clear_outputs_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_transition_clear_outputs_20260503 \ +uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-03/t1_transition_output_binding_clear_forward_guard_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward --batches 1024 \ + --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none \ + --hidden-sizes 32 --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner +# status=ok, cases=2 +``` + +Remaining forward owner: + +- sLSTM now peaks at the transition stage itself (`9.389 GiB`) and drops to + `6.014 GiB` after recurrent-KV-after/readout/output route stages. +- Axon still peaks at the transition stage (`16.252 GiB`) and remains + `14.502 GiB` through recurrent-KV-after/readout/output route stages. +- The next forward patch should target Axon transition materialization through + compiler-owned primitive row-group fusion or workspace/liveness rows for the + large `linear`/`diag`/`norm` outputs. Do not return to public-state aliasing + or local helper aliases without naming the native allocator owner first. + +### 2026-05-03 - Current T=1 Owner Table After Transition Clear + +Status: analysis only. No optimization was run. + +Artifact: +`tmp/fabric_audits/partials/2026-05-03/t1_current_owner_table_after_transition_clear_h32_100m_b1024`. + +Command intent: + +- Rerun the current compiler-owned single-pop h32 100M B1024 T=1 owner table + after the accepted transition-output binding clear fix. +- Use warmup `1`, iterations `2`, terminal output boundary, reset absent, + `--require-cuda-temporal-owner`. + +Result: + +- Partial/stalled. Three rows completed in the JSONL. The final Axon training + row stayed active for more than 13 minutes while using about `102 GiB` on GPU + and was stopped as non-closure steering evidence rather than promoted as a + valid completed audit. 
+- The summary reports `completed_count=3` of `case_count=4`; this artifact must + remain under `tmp/fabric_audits/partials/` and must not be cited as final + closure evidence. + +Current completed rows versus the April 21 `h32_t1_bxparams` summary floor +(`58732.71 tok/s`, `2.07 GiB`): + +| Row | tok/s | vs Apr21 | Peak GiB | Memory vs Apr21 | Runtime buffers GiB | Reverse artifacts GiB | Dominant native stage | +| --- | ---: | ---: | ---: | ---: | ---: | ---: | --- | +| sLSTM 100M h32 B1024 T1 forward | `10965.12` | `18.67%` | `9.389` | `4.54x` | `1.313` | `0` | `native_forward_after_transition_local0=9.389 GiB` | +| sLSTM 100M h32 B1024 T1 training | `13.29` | `0.023%` | `24.908` | `12.03x` | `7.500` | `4.500` | `native_after_recurrent_message_local0=23.221 GiB`; `native_after_boundary_kv_local0=22.754 GiB` | +| Axon 100M h32 B1024 T1 forward | `5372.89` | `9.15%` | `16.285` | `7.87x` | `3.750` | `0` | `native_forward_after_transition_local0=16.285 GiB`; post-transition stays `14.535 GiB` | +| Axon 100M h32 B1024 T1 training | no completed row | open | `~102 GiB active before stop` | `>49x active` | open | open | active long-running training/reverse path | + +Updated owner ranking: + +1. **Training/reverse path is the largest closure blocker.** + sLSTM training is now only `0.023%` of the April 21 summary floor and Axon + training did not finish in the warmed table. This is no longer just a forward + gap. +2. **Forward transition materialization remains the largest forward owner.** + sLSTM and Axon both peak at `native_forward_after_transition_local0`. + Reducing transition primitive output materialization is still required before + forward can approach the April 21 memory floor. +3. **Reverse message/KV/tape liveness is the largest successful-training owner.** + The successful sLSTM training row peaks around recurrent message and boundary + K/V reverse stages, with full runtime buffers and reverse artifacts still + materialized. +4. **Axon training needs a bounded diagnostic row before further large audits.** + The full warmed Axon training row is too slow/high-memory for a routine + four-row loop. The next Axon training probe should be isolated and bounded, + with allocator/native-stage telemetry, so the stop condition itself is + measurable. + +Open conclusion: + +- Forward has improved materially, but T=1 cannot move to closure while current + training is this far below April 21 and Axon training does not complete in the + warmed matrix. +- The next implementation plan should choose between two high-impact owners: + either transition row-group/workspace fusion for forward memory, or a bounded + reverse-message/KV/tape liveness strategy that makes Axon training complete + quickly enough to profile. + +### 2026-05-03 - T=1 Fixed-Slot Reverse Key-Reducer Liveness Slice + +Status: accepted as a narrow compiler-owned reducer/liveness improvement; not +T=1 closure. + +Boundary manifest: + +- Unchanged semantics: fixed-slot context dot-product rows, output/readout + routes, transition rows, artifact rows, reducer routes, and parameter + meanings are unchanged. +- Changed implementation: the registered fixed-slot reverse message callable now + collapses its input/recurrent key-bank gradients from full + `[B, sender, 2H]` banks into reducer-shaped `[sender, 2H]` tensors before + returning them to the fused reverse program. 
+- Reducer contract: the fixed-slot message parameter reducer accepts either + already-reduced `[N,2H]` tensors or legacy `[B,N,2H]` tensors through the same + `message_strategy_parameter_grad` reducer rows. +- Runtime liveness: the fused reverse program keeps the reduced key-bank tensors + for reducer routes and uses zero-size rank-3 sentinels for the value-only + fixed-slot boundary/initial K/V adjoint calls. +- Rejected claim: this does not remove the native operator's temporary full + key-bank allocation yet, so it does not close peak memory. + +Verification: + +```bash +uv run pytest -q tests/test_fabric_backend_boundaries.py --tb=short +# 41 passed in 5.11s + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_reduced_key_liveness_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_reduced_key_liveness_20260503 \ +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_registered_parameter_reducer_cuda_executes_transition_trainable_rows \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_matches_nudge_when_scalar_binding_is_equal \ + --tb=short +# 4 passed in 61.50s +``` + +Bounded Axon training probe: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_reduced_key_liveness_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_reduced_key_liveness_20260503 \ +timeout 420s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-03/t1_reduced_key_liveness_axon_training_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families axoncell --sizes 100m --modes forward_backward --batches 1024 \ + --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none \ + --hidden-sizes 32 --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 0 --iterations 1 --require-cuda-temporal-owner +# status ok; cases=1 +``` + +Comparison against the last successful Axon training row +`t1_forward_runtime_liveness_guard_accept_h32_100m_b1024`: + +| Metric | Last successful Axon training row | Reduced key-reducer liveness | Movement | +| --- | ---: | ---: | ---: | +| tokens/s | `3.936` | `3.959` | `+0.6%` | +| peak GiB | `81.162` | `81.318` | `+0.156 GiB` | +| current allocated at `native_after_recurrent_message_local0` | `75.912 GiB` | `72.165 GiB` | `-3.747 GiB` | +| current allocated at `native_after_boundary_kv_local0` | `74.289 GiB` | `70.542 GiB` | `-3.747 GiB` | +| current allocated after fused backward program | `66.844 GiB` | `63.097 GiB` | `-3.747 GiB` | +| unclassified CUDA peak | `58.410 GiB` | `58.566 GiB` | `+0.156 GiB` | +| runtime buffers | `12.500 GiB` | `12.500 GiB` | unchanged | +| reverse artifacts | `9.500 GiB` | `9.500 GiB` | unchanged | + +Keep/revert decision: + +- Keep the patch as a narrow liveness win because live tensor retention after + the fixed-slot reverse message stage moved by the expected key-bank amount and + parity is green. +- Do not claim peak-memory or throughput closure. 
The max allocation and + unclassified peak are essentially unchanged because + `flat_bucket_registered_backward_fixed_slot_context_message_cuda` still + allocates full key-bank gradient tensors internally before the registered + strategy reduces them. +- Next owner: replace the fixed-slot reverse message operator's full key-bank + gradient outputs with compiler-declared reducer/runtime outputs directly. The + kernel should accumulate input/recurrent key gradients into reducer-shaped + buffers or planned runtime buffers and never allocate the full + `[B, sender, 2H]` key-bank adjoint tensors on the successful fixed-slot route. + +### 2026-05-03 - T=1 Deep Dive After Keyless Readout Probe + +Status: analysis only. No optimization was run in this pass. + +Current target remains the April 21 `h32_t1_bxparams` summary floor: +`58732.71 tok/s`, `2.07 GiB`, covering sLSTM + Axon, +100M/500M/1B, forward + training, B=1024/16384. + +Latest forward-only evidence: + +- Artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_forward_keyless_readout_h32_100m_b1024`. +- Scope: single-pop, h32, 100M, B1024, T=1, forward only, reset absent, + terminal output boundary. +- This is not a closure artifact. It is a two-row forward steering table. + +| Row | tok/s | vs Apr21 | Peak GiB | Memory vs Apr21 | Runtime buffers GiB | Dominant allocated stage | +| --- | ---: | ---: | ---: | ---: | ---: | --- | +| sLSTM forward | `11312.37` | `19.26%` | `9.389` | `4.54x` | `1.313` | `native_forward_after_transition_local0=9.389 GiB` | +| Axon forward | `5569.78` | `9.48%` | `16.252` | `7.85x` | `3.750` | `native_forward_after_transition_local0=16.252 GiB` | + +Current forward stage shape: + +| Row | input/KV stage | transition stage | post-KV/readout/output-route stage | +| --- | ---: | ---: | ---: | +| sLSTM forward | `3.764 GiB` | `9.389 GiB` | `4.889 GiB` | +| Axon forward | `9.252 GiB` | `16.252 GiB` | `11.002 GiB` | + +Logical runtime role sizes still visible in the compiler ledger: + +- sLSTM: + - `transition_forward_linear_output=2.812 GiB` + - `transition_forward_matmul_output=2.250 GiB` + - `transition_forward_norm_output=0.562 GiB` + - `transition_forward_state_output=0.562 GiB` + - `forward_recurrent_msg=0.562 GiB` + - `forward_recurrent_hidden_after=0.562 GiB` +- Axon: + - `transition_forward_linear_output=3.500 GiB` + - `transition_forward_diag_output=3.500 GiB` + - `transition_forward_norm_output=1.750 GiB` + - `forward_recurrent_msg=1.750 GiB` + - `forward_recurrent_hidden_after=1.750 GiB` + +Interpretation: + +- The legal keyless-readout path improves forward throughput modestly and keeps + output readout from requiring recurrent K-after on the fixed-slot route, but + the peak remains the transition stage. +- For sLSTM, transition peak is now the main forward memory wall. Later + recurrent-KV-after/readout/output-route stages fall to `4.889 GiB`. +- For Axon, transition peak is also the main forward memory wall. Later stages + remain high at `11.002 GiB`, so Axon still has both transition + materialization and post-transition state/readout liveness pressure. +- The latest rejected transition-dead-input probe improved Axon forward to + about `14.501 GiB`, but it broke sLSTM with + `registered fused forward transition state carry source references an empty + program tensor`. That path is not legal for carry-bearing transition programs + and must not be expanded as-is. 
+ +Training status: + +- The latest completed sLSTM warmed training steering row, from + `t1_current_owner_table_after_transition_clear_h32_100m_b1024`, is + `13.29 tok/s`, `24.908 GiB`, about `0.023%` of the April 21 throughput floor + and `12.0x` the memory floor. +- The latest bounded Axon training row, from + `t1_reduced_key_liveness_axon_training_h32_100m_b1024`, is `3.96 tok/s`, + `81.318 GiB`, about `0.0067%` of the April 21 throughput floor and `39.3x` + the memory floor. +- Training remains a severe gap, but forward stays the current optimization + priority. Training should be used as bounded parity/liveness guardrail unless + a forward change breaks artifact/tape/reducer legality. + +Remaining T=1 work, highest impact first: + +1. **Forward transition row-group materialization.** + - Biggest forward owner today. + - Need a registered transition strategy that reduces the large + `linear`/`diag`/`matmul`/`norm` live outputs, either through row-group + fusion, compiler-owned workspace reuse, or direct write of the final public + state role. + - This must consume transition primitive rows and binding rows directly; no + family selector, no benchmark shape branch, and no scheduler-owned + transition math. + +2. **Consumer-projected message/readout strategy.** + - April 21 reference inspection suggests a useful compiler-safe direction: + produce transition-input and output-cell roles directly from weighted + values when artifact/output routes prove logical `recurrent_msg` or + `output_msg` is dead. + - Expected movement: `forward_recurrent_msg`, `output_msg`, + message projection/normalization temporaries, and some post-transition + high-water allocation. + - This is a strategy over existing message/readout rows, not a semantic math + change. + +3. **Forward native high-water attribution.** + - Runtime buffer bytes are much smaller than peak. For Axon forward the + named runtime buffer total is `3.750 GiB`, but the peak is `16.252 GiB`. + - Before more alias/no-copy experiments, add or use native-stage telemetry to + split transition temporaries, message GEMM/projection temporaries, + output-route materialization, static/front-end overlap, and allocator + reserve. + +4. **Training artifact/tape/reverse liveness after forward moves.** + - sLSTM and Axon training are far below April 21 and still dominated by + reverse artifacts, backward runtime buffers, and native reverse + recurrent-message/KV stages. + - The next training-specific owner is still full key-bank adjoint allocation + inside the fixed-slot reverse message operator, followed by artifact/tape + policy and reverse transition buffers. + +5. **Coverage expansion after the steering rows move.** + - T=1 closure still requires the full April 21 matrix, not only h32 100M + B1024: B=16384, 500M/1B, h4/h8/h16 stress, high-batch 1M/10M rows, + mixed-pop T=1, reset-present, final-state/materialized-state axes, and the + dot-product semantic stress case. + +Open conclusion: + +- Forward is better than the initial fixed-slot route but still far from April + 21: sLSTM forward is roughly `5.2x` slower and Axon forward roughly `10.5x` + slower than the summary floor, with memory `4.5x` and `7.9x` above the floor. +- The biggest next forward change should attack transition materialization, + not broad training optimization and not more keyless-readout-only work. 
+- The best candidate implementation shape is a registered transition row-group + fusion/workspace strategy, followed by a consumer-projected message/readout + strategy if transition telemetry shows message/readout temporaries are the + next wall. + +### 2026-05-03 - Narrow Transition Dead-Input Liveness Slice + +Status: kept as a narrow registered-strategy liveness improvement. + +Boundary manifest: + +- Semantic rows unchanged: graph rows, message rows, readout rows, transition + primitive rows, tensor binding meanings, output routes, and artifact routes + are unchanged. +- Implementation changed only inside the registered fused forward transition + program. After each transition primitive, the program clears row-local input + tensor-table slots once compiler binding rows prove there is no future input + use. +- The release rule is guarded by compiler rows: + `forward_executor_binding_rows` and `forward_transition_state_carry_rows`. + If a binding is an active state-carry source, it is preserved. +- This replaces the rejected broad transition-dead-input probe. The broad probe + cleared carry-producing tensors and broke sLSTM with + `registered fused forward transition state carry source references an empty + program tensor`; this narrow slice excludes that case. +- No family selector, benchmark-shape selector, semantic formula change, or + compatibility/fallback path was added. + +Commands: + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_transition_liveness_narrow_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_transition_liveness_narrow_20260503 \ +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + --tb=short + +uv run pytest -q \ + tests/test_fabric_backend_boundaries.py::test_forward_transition_access_uses_compiler_program_access_rows \ + tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned \ + --tb=short + +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_transition_dead_input_narrow_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_transition_dead_input_narrow_20260503 \ +timeout 420s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-03/t1_transition_dead_input_narrow_forward_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward --batches 1024 \ + --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none \ + --hidden-sizes 32 --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner +``` + +Results: + +| Row | Previous tok/s | New tok/s | Previous peak GiB | New peak GiB | Current allocated after transition | +| --- | ---: | ---: | ---: | ---: | ---: | +| sLSTM forward | `11312.37` | `11287.62` | `9.389` | `9.389` | `9.389 GiB` unchanged | +| Axon forward | `5569.78` | `5566.93` | `16.252` | `14.502` | `16.252 -> 9.252 GiB` | + 
+Interpretation: + +- The sLSTM carry-producing route is no longer broken and remains unchanged, + which is expected because active state-carry rows prevent clearing its carry + source bindings. +- Axon releases stale transition primitive outputs after their last compiler + use. The current allocation at `native_forward_after_transition_local0` + drops by about `7.000 GiB`. +- Axon max allocated still peaks at `14.502 GiB` during the transition + sequence. That means the remaining owner is not stale tensor-table retention + after the transition marker; it is temporary lifetime inside the registered + transition strategy while `diag`, `linear`, and `norm` outputs overlap. +- Throughput does not materially move in this slice; this is a memory/liveness + cleanup that narrows the next owner. + +Keep/revert decision: + +- Keep the patch. It is compiler-owned, parity passed, the sLSTM failure from + the broad probe is gone, and Axon memory moves in the intended direction. +- Do not claim T=1 closure. The next forward owner is still transition + materialization, specifically reducing the max allocation inside the + registered transition row group rather than only clearing outputs after the + group. +- Revert only if a later mixed-pop/reset/materialized-state guardrail shows an + active carry source was incorrectly classified as dead. + +### 2026-05-03 - Registered Transition Dense-Affine GEMM Slice + +Status: kept as a compiler-owned forward throughput improvement, but not a +memory-closure move. + +Boundary manifest: + +- Semantic rows unchanged: transition primitive rows still lower to `linear`, + `matmul`, `gated_logspace_recurrence`, `diag_rtu`, and `norm_or_identity`. +- The registered transition `linear` primitive now dispatches through the + shared receiver-major dense-affine strategy, which chooses large GEMM, + batched GEMM, or grouped GEMM from tensor layout facts. +- The same change was applied to both the fused full-program transition + entrypoint and the standalone registered transition-linear primitive + entrypoint. The scalar `program_transition_linear_forward_kernel` launch is no + longer used for registered transition-linear forward. +- No family selector, April21-code copy, benchmark-shape selector, scheduler + formula, semantic math change, or compatibility path was added. +- This is a native strategy implementation swap under existing compiler rows, + not a declaration or lowering change. 
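+
+A minimal sketch of that layout-fact dispatch, assuming a hypothetical
+`shared_weight` fact and receiver-major activations; the registered strategy
+reads these facts from compiler tensor rows rather than inspecting shapes at
+call time, and the grouped-GEMM branch is omitted here:
+
+```python
+import torch
+
+
+def dense_affine_forward(x, weight, bias, *, shared_weight: bool):
+    """Receiver-major dense affine: pick a GEMM shape from layout facts.
+
+    x:      [receivers, batch, in_features] receiver-major activations
+    weight: [out_features, in_features] when shared across receivers,
+            otherwise [receivers, out_features, in_features]
+    bias:   [out_features] or [receivers, out_features], or None
+    """
+    if shared_weight:
+        # One large GEMM over the flattened receiver*batch rows.
+        receivers, batch, features = x.shape
+        flat = x.reshape(receivers * batch, features)
+        out = torch.nn.functional.linear(flat, weight, bias)
+        return out.reshape(receivers, batch, -1)
+    # Per-receiver weights: one batched GEMM, no per-receiver Python loop.
+    out = torch.bmm(x, weight.transpose(1, 2))
+    if bias is not None:
+        out = out + bias.unsqueeze(1)
+    return out
+```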
+ +Commands: + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_transition_dense_affine_2_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_transition_dense_affine_2_20260503 \ +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + --tb=short + +uv run pytest -q \ + tests/test_fabric_backend_boundaries.py::test_forward_transition_access_uses_compiler_program_access_rows \ + tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned \ + --tb=short + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_dense_affine_guard_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_dense_affine_guard_20260503 \ +uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_dense_affine_receiver_major_matches_reference \ + --tb=short + +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured \ + tests/test_fabric_backend_plan.py::test_forward_fused_program_runtime_facts_are_compiler_owned_rows \ + --tb=short + +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_transition_dense_affine_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_transition_dense_affine_20260503 \ +timeout 480s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-03/t1_transition_dense_affine_forward_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward --batches 1024 \ + --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none \ + --hidden-sizes 32 --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner +``` + +Results: + +| Row | Previous tok/s | New tok/s | Speedup | vs Apr21 | Previous peak GiB | New peak GiB | +| --- | ---: | ---: | ---: | ---: | ---: | ---: | +| sLSTM forward | `11287.62` | `11393.42` | `1.009x` | `19.40%` | `9.389` | `9.389` | +| Axon forward | `5566.93` | `6536.31` | `1.174x` | `11.13%` | `14.502` | `14.502` | + +Current allocated stage comparison: + +| Row | entry | recurrent message | transition | post-KV/readout | compaction | final state | +| --- | ---: | ---: | ---: | ---: | ---: | ---: | +| sLSTM previous | `3.764` | `4.327` | `9.389` | `4.889` | `3.764` | `3.483` | +| sLSTM new | `3.764` | `4.327` | `9.389` | `4.889` | `3.764` | `3.483` | +| Axon previous | `9.252` | `11.002` | `9.252` | `11.002` | `9.252` | `8.877` | +| Axon new | `9.252` | `11.002` | `9.252` | `11.002` | `9.252` | `8.877` | + +Interpretation: + +- Keep the change. It is compiler-owned, parity passed, and it gives Axon a + real forward speedup without increasing peak memory. +- The unchanged memory rows prove the remaining forward wall is not the scalar + linear kernel alone. 
The next owner is transition row-group materialization + and tensor traffic: `diag`/`linear`/`norm` outputs overlap inside the + transition sequence, and sLSTM still carries large gate/matmul/norm/state + outputs. +- The next high-leverage design should be a generic compiler row-group planner + that groups legal dense subgraphs into fewer GEMM/BGEMM stages and reduces + data movement between them. It should operate over primitive rows, tensor + bindings, effects, liveness, and artifact routes, not over cell names. +- Do not move into broad training optimization from this result. Training + remains a guardrail lane until forward row-group traffic is reduced. + +Next owner: + +1. Add compiler-owned dense row-group candidate generation for transition + primitive spans. +2. Prove legality with existing rows/effects/bindings/liveness. +3. Implement the first row-group strategy for the Axon + `linear -> diag_rtu -> linear -> norm_or_identity` span or the sLSTM + `linear -> matmul -> gated_logspace_recurrence -> norm_or_identity` span, + whichever has the cleaner no-extra-artifact legality proof. +4. Measure whether max allocation inside `native_forward_after_transition_local0` + moves; if not, add finer native-stage telemetry before changing more code. + +### 2026-05-03 - Axon Transition Row-Group Liveness Slice + +Status: kept as a compiler-owned forward liveness improvement. This is not T=1 +closure. + +Boundary manifest: + +- Semantic rows unchanged: the active Axon transition still lowers to + `linear -> diag_rtu -> linear -> norm_or_identity`. +- Strategy/runtime rows changed: the fused transition program now recognizes + that compiler-selected row group only when optional transition outputs, + reverse artifacts, final-state materialization, and active private-state carry + are absent. +- Tensor rows consumed directly: primitive rows, forward executor rows, + forward executor binding rows, native callable binding schema rows, native + callable output rows, program tensor binding rows, and transition state-carry + rows. +- Implementation: input projection still uses the registered dense-affine + strategy; `diag_rtu -> output linear` is run as bounded batch chunks so the + full diagonal preprojection tensor is not materialized for the whole batch; + the public `norm_or_identity` primitive remains the compiler-selected final + output producer. +- Unsupported cases fall back to the ordinary primitive-row walk: artifact + collection, final-state materialization, active state carry, or any non-matching + row/binding pattern. +- No family selector, benchmark row selector, hidden-size selector, scheduler + formula, April21 code copy, or compatibility wrapper was added. + +Rejected/narrowed probes: + +- `tmp/fabric_audits/partials/2026-05-03/t1_transition_rowgroup_forward_h32_100m_b1024` + used a scalar custom `diag_rtu + output projection` kernel. It reduced Axon + peak from `14.502 GiB` to `12.752 GiB`, but regressed Axon forward from + `6536.31` to `3934.74 tok/s`. Rejected. +- `tmp/fabric_audits/partials/2026-05-03/t1_transition_rowgroup_chunked_forward_h32_100m_b1024` + used `512 MiB` diagonal-preprojection chunks. It reduced Axon peak to + `13.252 GiB`, but still cost about `2%` throughput. Narrowed. +- Accepted chunk budget is `1 GiB`: it preserves GEMM throughput while still + removing `0.75 GiB` of Axon forward peak. 
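+
+A minimal sketch of how a byte budget such as the accepted `1 GiB` bound turns
+into a batch chunk size, with a deliberately simplified elementwise stand-in
+for the `diag_rtu -> output linear` pair. Shapes, names, and the helper itself
+are illustrative assumptions, not the registered kernel.
+
+```python
+import torch
+
+GIB = 1024 ** 3
+
+
+def preprojection_chunk_rows(batch, receivers, hidden, dtype=torch.float32,
+                             budget_bytes=1 * GIB):
+    """Largest batch slice whose diagonal preprojection stays under the budget."""
+    bytes_per_sample = receivers * hidden * (torch.finfo(dtype).bits // 8)
+    return max(1, min(batch, budget_bytes // bytes_per_sample))
+
+
+def chunked_diag_then_output(x, diag_weight, out_weight, chunk_rows):
+    """Run the simplified diag-then-project pair per chunk so the full
+    preprojection tensor is never materialized for the whole batch."""
+    outs = []
+    for start in range(0, x.shape[0], chunk_rows):
+        pre = x[start:start + chunk_rows] * diag_weight   # stand-in for diag_rtu
+        outs.append(pre @ out_weight.t())                 # output projection per chunk
+    return torch.cat(outs, dim=0)
+```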
+ +Commands: + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_transition_rowgroup_chunk1g_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_transition_rowgroup_chunk1g_20260503 \ +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + --tb=short + +uv run pytest -q \ + tests/test_fabric_backend_boundaries.py::test_forward_transition_access_uses_compiler_program_access_rows \ + tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned \ + --tb=short + +git diff --check + +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_transition_rowgroup_chunk1g_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_transition_rowgroup_chunk1g_20260503 \ +timeout 480s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-03/t1_transition_rowgroup_chunk1g_forward_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward --batches 1024 \ + --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none \ + --hidden-sizes 32 --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner +``` + +Verification: + +- Focused CUDA parity/route tests: `4 passed`. +- Source/compiler guardrails: `2 passed`. +- `git diff --check`: clean. +- Forward audit artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_transition_rowgroup_chunk1g_forward_h32_100m_b1024`. + +Results versus the dense-affine slice: + +| Row | Previous tok/s | New tok/s | Speed ratio | Previous peak GiB | New peak GiB | Peak movement | +| --- | ---: | ---: | ---: | ---: | ---: | ---: | +| sLSTM forward | `11393.42` | `11425.59` | `1.003x` | `9.389` | `9.389` | unchanged | +| Axon forward | `6536.31` | `6516.74` | `0.997x` | `14.502` | `13.750` | `-0.752 GiB` | + +Native-stage movement: + +| Row | Transition stage | Recurrent-KV/readout/output-route stage | Peak stage | +| --- | ---: | ---: | --- | +| sLSTM forward | `9.389 GiB` | `4.889 GiB` | transition | +| Axon forward | `9.252 GiB` | `11.002 GiB` | output route / post-transition | + +Interpretation: + +- Keep the change. It is compiler-owned, parity passed, sLSTM is unchanged, and + Axon retains the dense-affine speedup while reducing peak memory. +- The accepted row-group moved the Axon peak out of the transition row-group + itself. Axon now peaks later around recurrent-KV-after/readout/output-route + ownership, with a large unclassified allocator gap still present. +- This does not close T=1. The April21 floor is still `58732.71 tok/s`, + `2.07 GiB`; current accepted h32 100M B1024 forward rows are still about + `19.45%` of target for sLSTM and `11.10%` for Axon. + +Next owner: + +1. Forward post-transition owner: recurrent K/V-after, readout message/output + route materialization, and the remaining unclassified allocator peak. +2. 
sLSTM transition owner: transition still peaks at `9.389 GiB`; a separate + row-group/liveness strategy is needed for the gated row group. +3. Keep training as bounded parity/liveness evidence only until forward owners + move further. + +### 2026-05-03 - Direct Keyless Readout Forward Strategy + +Status: accepted forward-only registered strategy improvement. This is not T=1 +closure. + +Boundary manifest: + +- Semantic rows unchanged: message and readout declarations still lower through + the compiler-owned fixed-slot-context message strategy and readout strategy. +- Strategy metadata changed: fixed-slot-context forward message strategies now + declare an optional `direct_keyless_readout_message` phase in addition to the + existing bind, recurrent-K/V, message, and keyless-readout phases. +- Runtime behavior changed only when the registered strategy advertises that + phase, reverse artifacts are not requested, and the existing keyless readout + legality conditions hold. +- Implementation avoids materializing full recurrent V-after for forward + readout. The readout message kernel projects recurrent values from + `recurrent_hidden` and the strategy-bound recurrent value weight on demand. +- Unsupported cases still use the existing registered keyless readout or the + ordinary compiler-owned message/readout route. No benchmark branch, family + branch, April21 code copy, fixed tensor-slot enum, or compatibility wrapper + was added. + +Commands: + +```bash +uv run pytest -q \ + tests/test_fabric_backend_boundaries.py::test_forward_message_readout_handlers_use_native_strategy_access_schema \ + tests/test_fabric_backend_boundaries.py::test_message_readout_native_callable_bodies_are_strategy_local \ + tests/test_fabric_backend_boundaries.py::test_parameter_reducer_native_callables_are_registry_owned \ + --tb=short + +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured \ + tests/test_fabric_backend_plan.py::test_message_executor_patterns_follow_registered_message_specs \ + --tb=short + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_direct_keyless_readout_compile \ +TRITON_CACHE_DIR=/tmp/cortical_triton_direct_keyless_readout_compile \ +timeout 360s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + --tb=short + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_direct_keyless_readout_parity \ +TRITON_CACHE_DIR=/tmp/cortical_triton_direct_keyless_readout_parity \ +timeout 360s uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + --tb=short + +git diff --check + +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_direct_keyless_readout_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_direct_keyless_readout_20260503 \ +timeout 480s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-03/t1_direct_keyless_readout_forward_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward --batches 
1024 \ + --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none \ + --hidden-sizes 32 --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner +``` + +Verification: + +- Source/compiler guardrails: `5 passed`. +- CUDA compile/parity guardrails: `4 passed`. +- `git diff --check`: clean. +- Forward audit artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_direct_keyless_readout_forward_h32_100m_b1024`. + +Results versus the accepted transition row-group slice: + +| Row | Previous tok/s | New tok/s | Speed ratio | Previous peak GiB | New peak GiB | Readout/output-route current GiB | +| --- | ---: | ---: | ---: | ---: | ---: | ---: | +| sLSTM forward | `11425.59` | `11628.62` | `1.018x` | `9.389` | `9.389` | `4.889 -> 4.327` | +| Axon forward | `6516.74` | `6845.31` | `1.050x` | `13.750` | `13.750` | `11.002 -> 9.252` | + +Interpretation: + +- Keep the change. It is compiler-owned, preserves the registered fixed-slot + semantics, and moves the post-readout/output-route allocation down without + increasing peak memory. +- The unchanged peak is expected from the stage ledger. sLSTM still peaks inside + transition at `9.389 GiB`; Axon still has an earlier recurrent-K/V-before and + recurrent-message current allocation of `11.002 GiB`, and the overall + allocator high-water remains `13.750 GiB`. +- The next forward owner is not readout V-after anymore. It is the recurrent + K/V-before plus recurrent-message materialization path for Axon, and the gated + transition row group for sLSTM. Training remains a bounded parity/liveness + guardrail, not the main optimization lane. + +### 2026-05-03 - Rejected Direct Recurrent-Message Probe + +Status: rejected and removed from the active registry/native path. This was a +valid compiler-owned throughput probe, but it is not a strategy to keep. + +Boundary manifest: + +- Semantic rows unchanged: fixed-slot-context message primitive rows and tensor + bindings were unchanged. +- Probe strategy: add an optional registered forward message phase that produced + recurrent message directly from recurrent hidden and the strategy-bound + recurrent value weight when reverse artifacts were not requested. +- Old route interaction: the probe did not add a scheduler formula or benchmark + branch, but it replaced a fast recurrent-value-bank GEMM-style projection with + per-edge hidden projection inside the attention loop. +- Revert decision: removed the `direct_message` phase, native callable, kernel, + and registry entry. The active route is back to the accepted + direct-keyless-readout state above. 
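+
+A minimal sketch of the structural contrast that made this probe slow. Tensor
+names and the edge layout are illustrative assumptions; the point is only that
+one value-bank GEMM plus a gather is a very different work shape from
+re-projecting the sender hidden state once per edge.
+
+```python
+import torch
+
+
+def values_via_bank_gemm(hidden, value_weight, sender_index):
+    """Kept shape: project the whole value bank once, then gather per edge."""
+    value_bank = hidden @ value_weight.t()   # [N, H] -> [N, V] in a single GEMM
+    return value_bank[sender_index]          # [E, V] gather, no recompute
+
+
+def values_via_per_edge_projection(hidden, value_weight, sender_index):
+    """Rejected shape: re-project the sender hidden state separately per edge."""
+    return torch.stack([hidden[s] @ value_weight.t() for s in sender_index])  # E matvecs
+```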
+ +Commands: + +```bash +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured \ + tests/test_fabric_backend_plan.py::test_message_executor_patterns_follow_registered_message_specs \ + tests/test_fabric_backend_boundaries.py::test_message_readout_native_callable_bodies_are_strategy_local \ + --tb=short + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_direct_message_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_direct_message_20260503 \ +timeout 420s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + --tb=short + +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_direct_message_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_direct_message_20260503 \ +timeout 600s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-03/t1_direct_message_forward_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward --batches 1024 \ + --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none \ + --hidden-sizes 32 --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner +``` + +Verification: + +- Source/compiler guardrails before revert: `3 passed`. +- CUDA route/parity before revert: `4 passed`. +- Source/compiler guardrails after revert: `3 passed`. +- CUDA route/parity after revert: + `tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store`, + `tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window`, + `tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program`, + and + `tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program` + all passed (`4 passed`). +- `git diff --check`: clean. +- Probe artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_direct_message_forward_h32_100m_b1024`. + +Result versus the accepted direct-keyless-readout row: + +| Row | Previous tok/s | Probe tok/s | Speed ratio | Previous peak GiB | Probe peak GiB | Live recurrent K/V-before current | +| --- | ---: | ---: | ---: | ---: | ---: | ---: | +| sLSTM forward | `11628.62` | `6314.06` | `0.543x` | `9.389` | `9.389` | `4.327 -> 3.764 GiB` | +| Axon forward | `6845.31` | `2513.91` | `0.367x` | `13.750` | `13.750` | `11.002 -> 9.252 GiB` | + +Interpretation: + +- The probe shortened the intended live recurrent-value-bank edge, but it did + not reduce high-water peak memory and it badly regressed throughput. +- The reason is structural: recurrent value projection was recomputed per + receiver/edge/head lane instead of being represented as a grouped/batched GEMM + or streamed producer-consumer tile. +- Do not revive this direct per-edge projection route. 
The next forward strategy + should keep the compiler rows stable but change the implementation shape to a + row-owned grouped GEMM/BMM or chunked producer-consumer plan that reduces full + bank lifetime without multiplying hidden-projection work inside attention. + +### 2026-05-03 - Rejected Message Projection-Out Probe + +Status: rejected and removed. This probe replaced the fixed-slot message +projection `matmul -> projected temporary -> normalize into recurrent_msg` with +`dense_affine_out` directly into the compiler-planned recurrent-message buffer, +then in-place normalization. It preserved rows and parity, but did not move the +measured owner. + +Commands: + +```bash +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured \ + tests/test_fabric_backend_plan.py::test_message_executor_patterns_follow_registered_message_specs \ + tests/test_fabric_backend_boundaries.py::test_message_readout_native_callable_bodies_are_strategy_local \ + --tb=short + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_message_projection_out_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_message_projection_out_20260503 \ +timeout 420s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + --tb=short + +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_message_projection_out_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_message_projection_out_20260503 \ +timeout 600s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-03/t1_message_projection_out_forward_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward --batches 1024 \ + --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none \ + --hidden-sizes 32 --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner +``` + +Verification: + +- Source/compiler guardrails: `3 passed`. +- CUDA route/parity: `4 passed`. +- Source/compiler guardrails after revert: `3 passed`. +- CUDA route/parity after revert: + `tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store`, + `tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window`, + `tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program`, + and + `tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program` + all passed (`4 passed`). +- `git diff --check`: clean. +- Probe artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_message_projection_out_forward_h32_100m_b1024`. 
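+
+For reference, a minimal sketch of the probe's shape: a `dense_affine_out`-style
+write straight into the compiler-planned recurrent-message buffer followed by
+in-place normalization, with no separate projected temporary. Buffer and weight
+names are illustrative assumptions.
+
+```python
+import torch
+
+
+def project_out_then_normalize(weighted_values, message_weight, recurrent_msg, eps=1e-6):
+    """Probe shape: project directly into the planned buffer, then normalize in place."""
+    torch.matmul(weighted_values, message_weight.t(), out=recurrent_msg)
+    recurrent_msg.div_(recurrent_msg.norm(dim=-1, keepdim=True).clamp_min(eps))
+    return recurrent_msg
+```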
+ +Result versus the accepted direct-keyless-readout row: + +| Row | Previous tok/s | Probe tok/s | Speed ratio | Previous peak GiB | Probe peak GiB | +| --- | ---: | ---: | ---: | ---: | ---: | +| sLSTM forward | `11628.62` | `11327.36` | `0.974x` | `9.389` | `9.389` | +| Axon forward | `6845.31` | `6890.43` | `1.007x` | `13.750` | `13.750` | + +Interpretation: + +- The change preserved correctness but did not move current-stage or high-water + memory; sLSTM got slower and Axon moved only within noise. +- Rejected as a closure step. The next forward owner is not merely the + projection-output tensor. It is the broader recurrent-message/transition + producer-consumer shape: fewer materialized banks and fewer GEMM boundaries, + planned as grouped/batched GEMM or a chunked row-owned producer-consumer + strategy. + +### 2026-05-03 - Remaining T=1 Throughput Work After Rejected Forward Probes + +Status: analysis only. No optimization was accepted in this pass. + +Active accepted forward steering baseline: + +- Artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_direct_keyless_readout_forward_h32_100m_b1024`. +- Target floor remains April 21 `h32_t1_bxparams`: + `58732.71 tok/s`, `2.07 GiB`. + +Current accepted h32 100M B1024 forward rows: + +| Row | tok/s | vs Apr21 | Peak GiB | Memory vs Apr21 | Dominant live owner | +| --- | ---: | ---: | ---: | ---: | --- | +| sLSTM forward | `11628.62` | `19.8%` | `9.389` | `4.5x` | transition row group | +| Axon forward | `6845.31` | `11.7%` | `13.750` | `6.6x` | recurrent K/V-before + recurrent-message current; high-water still unclassified around transition/post-transition | + +Accepted/rejected lesson from the last two probes: + +- Direct per-edge recurrent-message projection is wrong for throughput. It + removed the recurrent-value bank lifetime but replaced a GEMM-shaped producer + with per-edge hidden projection and regressed Axon to `2513.91 tok/s`. +- Direct projection into the recurrent-message buffer did not move the owner. + The problem is not just one projected-message temporary. +- The next valid strategy must preserve grouped/batched GEMM/BMM structure while + reducing producer-consumer materialization. Do not expand per-edge projection + or local no-copy swaps without owner movement. + +Highest-impact remaining T=1 work, in order: + +1. **sLSTM transition row-group strategy.** + - sLSTM still peaks at `native_forward_after_transition_local0=9.389 GiB`. + - Runtime role ledger shows large transition intermediates: + `transition_forward_linear_output=2.812 GiB`, + `transition_forward_matmul_output=2.250 GiB`, + `transition_forward_norm_output=0.562 GiB`, + `transition_forward_state_output=0.562 GiB`. + - Required direction: a compiler-owned gated row-group strategy over + existing primitive rows that reduces gate/matmul/norm/state overlap, + preferably by grouped/BMM packing and row-owned workspace reuse. + +2. **Axon recurrent-message/transition producer-consumer strategy.** + - Accepted Axon forward has + `native_forward_after_recurrent_kv_before_local0=11.002 GiB` and + `native_forward_after_recurrent_message_local0=11.002 GiB`. + - Direct per-edge projection was rejected; the correct shape is a chunked or + grouped/BMM producer-consumer strategy that avoids full recurrent V/message + lifetime without recomputing hidden projections inside each edge. + - This should consume message executor rows, transition aggregate-input + access rows, output/artifact routes, and memory-liveness rows directly. + +3. 
**Forward native high-water attribution.** + - Named runtime buffers do not explain peak. Axon has about `3.750 GiB` of + planned runtime buffers but `13.750 GiB` peak and a large allocator/reserve + gap. + - Before more alias/no-copy work, add or use native-stage telemetry inside + the registered forward program to split high-water into transition + temporaries, recurrent-message projection/normalization, readout/output + routing, static/front-end overlap, and allocator reserve. + +4. **Training reverse/artifact/reducer liveness after forward moves.** + - Training is still far below target, but forward remains the main lane. + - Latest bounded steering rows still show large reverse owners: + sLSTM training around `13.29 tok/s`, `24.91 GiB`; + Axon training around `3.96 tok/s`, `81.32 GiB`. + - The biggest training-specific owner remains fixed-slot reverse recurrent + message/KV adjoint materialization and transition reverse span outputs. + Keep training as parity/liveness guardrail until forward producer-consumer + owners move. + +5. **Closure coverage expansion.** + - The current steering rows are not closure. T=1 still needs the April21 + matrix: `100M/500M/1B`, `B=1024/16384`, sLSTM + Axon, forward + training. + - Required extra axes before claiming T=1: mixed-pop T=1, h4/h8/h16 stress, + high-batch 1M/10M rows, reset-present, final-state/materialized-state + axes, and the dot-product semantic stress case. + +Current conclusion: + +- The next plan should not be broad training optimization and should not revive + rejected message shortcuts. +- The biggest forward work is a real row-group/producer-consumer strategy: + first sLSTM transition row-group or Axon recurrent-message-to-transition + producer-consumer, with native high-water attribution if the owner is still + ambiguous. + +### 2026-05-03 - Plan: sLSTM Transition Row-Group Forward Strategy + +Status: implemented as a bounded native-strategy probe. Safe to keep, but not +counted as forward owner closure. + +Selected next owner: + +- Start with the sLSTM transition row group. It is the clearest accepted T=1 + forward owner after the rejected message probes: + `native_forward_after_transition_local0=9.389 GiB`. +- Keep Axon in the perf row as a non-regression guard, but do not make Axon + recurrent-message producer-consumer the first implementation target in this + pass. The two latest message probes proved that naive direct projection + either regresses throughput or does not move the owner. + +Boundary manifest: + +```text +Lane: throughput strategy +Expected semantic delta: none +Unchanged semantic rows: existing sLSTM transition primitive rows and tensor roles +Changed rows: registered forward strategy/native callable and memory/liveness rows only +Tensor rows consumed: aggregate input, transition parameter roles, transition state-before/state-after, public hidden output +Artifact rows consumed: recurrent_hidden_after, transition state-before/state-after, forward artifact routes required for backward +Output routes consumed: none, except existing downstream carry/output routes +Old route being replaced: per-primitive transition temporaries kept live across the row group +Old route not allowed: cell-family branch, fixed-slot transition ABI, benchmark shape selector, hidden-size selector +``` + +Implementation shape for the next "proceed" pass: + +1. **Preflight the row group.** + - Inspect the compiled primitive/executor/access rows for the active sLSTM + T=1 forward path. 
+ - Record the row fingerprint and binding fingerprint before editing. + - Confirm the candidate strategy can match structurally from rows and roles, + not from the cell family name. + - Typed reject if required artifact, reset, materialized-state, dtype, or + layout contracts are not represented by rows. + +2. **Add native-stage transition attribution if the owner is still too broad.** + - Add metadata-only telemetry inside the registered forward program around + the transition row group: input projection, gate/matmul workspace, recurrence + update, normalization/public-state output, and state/artifact writeback. + - Artifact path: + `tmp/fabric_audits/partials/2026-05-03/t1_slstm_transition_rowgroup_attribution_h32_100m_b1024`. + - Keep rule: use this only to choose the exact liveness/workspace edge. Do + not accept it as a throughput optimization. + +3. **Implement the first row-group strategy.** + - Add a registered `gated_transition_rowgroup` style strategy over existing + primitive rows. + - Use row-owned workspace/liveness to avoid simultaneously retaining full + `transition_forward_linear_output`, `transition_forward_matmul_output`, + `transition_forward_norm_output`, and `transition_forward_state_output` + beyond their declared consumers. + - Preserve GEMM/BMM-shaped work. Do not move to per-edge or per-cell scalar + recompute. + - Return only compiler-declared semantic outputs, artifacts, reducer inputs, + and metadata. Local scratch must be workspace, not a new semantic return. + +4. **Verification gates.** + - Source/compiler guardrails: + + ```bash + uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured \ + tests/test_fabric_backend_plan.py::test_message_executor_patterns_follow_registered_message_specs \ + tests/test_fabric_backend_boundaries.py::test_message_readout_native_callable_bodies_are_strategy_local \ + --tb=short + ``` + + - CUDA route/parity guardrails: + + ```bash + CUDA_VISIBLE_DEVICES=0 \ + TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_slstm_transition_rowgroup_20260503 \ + TRITON_CACHE_DIR=/tmp/cortical_triton_slstm_transition_rowgroup_20260503 \ + timeout 420s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + --tb=short + ``` + + - If reset, state materialization, or artifact lifetimes are touched, add one + targeted reset-present/materialized-state parity row before running perf. + +5. 
**Representative perf row.** + + ```bash + CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ + TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_slstm_transition_rowgroup_20260503 \ + TRITON_CACHE_DIR=/tmp/cortical_triton_t1_slstm_transition_rowgroup_20260503 \ + timeout 600s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-03/t1_slstm_transition_rowgroup_forward_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward --batches 1024 \ + --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none \ + --hidden-sizes 32 --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner + ``` + +Acceptance rule: + +- Accept only if the named sLSTM transition owner moves materially below the + accepted baseline peak/current stage, or sLSTM forward throughput improves + materially without increasing peak memory. +- Axon forward must not regress by more than noise and must not exceed the + accepted `13.750 GiB` peak. +- Runtime metadata must name the registered strategy and prove the row/binding + fingerprints stayed semantic-stable. + +Revert or narrow rule: + +- Revert if parity fails, if the owner does not move, if speed regresses without + a measured memory win, or if implementation requires fixed slots, family + selectors, hidden-size selectors, benchmark policy, or temporal-scheduler + formulas. +- If current live memory drops but high-water does not, keep only if telemetry + identifies the remaining high-water owner and document it as a narrow + liveness win, not T=1 closure. + +Next owner after this plan: + +- If the sLSTM transition row group moves, rerun the owner table and then attack + Axon recurrent-message-to-transition producer-consumer with the same + compiler-owned row/liveness discipline. +- If it does not move, stop transition work and use native-stage attribution to + name the hidden allocator/native owner before adding more aliases or + workspace metadata. + +### 2026-05-03 - Accepted sLSTM Gated Transition Row-Group Liveness Strategy + +Status: kept as a compiler-owned T=1 forward memory/liveness improvement. This +is not throughput closure. + +Boundary manifest: + +- Semantic rows unchanged: the sLSTM transition still lowers to + `linear -> linear -> matmul -> gated_logspace_recurrence -> norm_or_identity`. +- Strategy/runtime rows changed: the registered fused forward transition + program now has a gated row-group path for that structural span when reverse + artifacts and final program tensors are not requested. +- Tensor rows consumed directly: primitive rows, forward executor rows, forward + executor binding rows, program tensor binding rows, native callable binding + schema rows, native callable output rows, memory/liveness rows, runtime buffer + rows, and transition state-carry rows. +- Implementation preserves GEMM/BMM-shaped work: aggregate input projection + still uses dense affine, gate affine uses chunked batched GEMM, recurrent + matmul uses the registered recurrent-matmul primitive kernel by chunk, and + norm remains the compiler-selected public output primitive. +- The previous active-state guard was too broad. 
Row-group strategies are now + allowed when `!return_reverse_artifacts && !return_final_program_tensors`; + dead-input clearing preserves active state-carry source slots by compiler + tensor-table storage identity, not only by binding id. +- No family selector, benchmark row selector, hidden-size selector, fixed slot, + scheduler formula, April21 code copy, or compatibility wrapper was added. + +Rejected/fixed first attempt: + +- Artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_slstm_transition_rowgroup_forward_h32_100m_b1024`. +- sLSTM failed with + `registered fused forward transition state carry source references an empty program tensor`. +- Cause: clearing a dead input binding could clear the same tensor-table slot as + an active carry source binding. The fix preserves by tensor-table slot + identity through + `forward_transition_binding_slot_aliases_active_state_carry_source`. + +Commands: + +```bash +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured \ + tests/test_fabric_backend_plan.py::test_message_executor_patterns_follow_registered_message_specs \ + tests/test_fabric_backend_boundaries.py::test_message_readout_native_callable_bodies_are_strategy_local \ + --tb=short + +uv run pytest -q \ + tests/test_fabric_backend_boundaries.py::test_forward_transition_access_uses_compiler_program_access_rows \ + tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned \ + --tb=short + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_slstm_transition_rowgroup_liveness_fix_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_slstm_transition_rowgroup_liveness_fix_20260503 \ +timeout 420s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + --tb=short + +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_slstm_transition_rowgroup_liveness_fix_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_slstm_transition_rowgroup_liveness_fix_20260503 \ +timeout 600s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-03/t1_slstm_transition_rowgroup_liveness_fix_forward_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward --batches 1024 \ + --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none \ + --hidden-sizes 32 --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner +``` + +Verification: + +- Source/compiler guardrails: `3 passed`, `2 passed`, and the focused + slot-alias guardrail passed. +- CUDA route/parity after the slot-alias fix: `4 passed`. +- `git diff --check`: clean. +- Clean forward audit artifact: + `tmp/fabric_audits/partials/2026-05-03/t1_slstm_transition_rowgroup_liveness_fix_forward_h32_100m_b1024`. 
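+
+A minimal sketch of the slot-identity guard behind the fix, using raw storage
+pointers as a stand-in for compiler tensor-table slot identity. The real check
+is the compiler-owned
+`forward_transition_binding_slot_aliases_active_state_carry_source` predicate,
+not this pointer comparison.
+
+```python
+import torch
+
+
+def aliases_active_carry_source(candidate, carry_sources):
+    """True when a dead-input candidate shares storage with an active carry source.
+
+    Binding ids can differ while the underlying tensor-table slot is the same,
+    so the guard compares storage identity, not ids."""
+    cand_ptr = candidate.untyped_storage().data_ptr()
+    return any(src.untyped_storage().data_ptr() == cand_ptr for src in carry_sources)
+
+
+def maybe_release(slot_name, tensor_table, carry_sources):
+    """Clear a dead input slot only when it does not alias an active carry source."""
+    tensor = tensor_table[slot_name]
+    if tensor is not None and not aliases_active_carry_source(tensor, carry_sources):
+        tensor_table[slot_name] = None
+```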
+ +Results versus the accepted direct-keyless-readout row: + +| Row | Previous tok/s | New tok/s | Speed ratio | Previous peak GiB | New peak GiB | Current transition stage | +| --- | ---: | ---: | ---: | ---: | ---: | ---: | +| sLSTM forward | `11628.62` | `11521.96` | `0.991x` | `9.389` | `7.139` | `9.389 -> 4.327 GiB` | +| Axon forward | `6845.31` | `6879.16` | `1.005x` | `13.750` | `13.750` | `9.252 -> 9.252 GiB` | + +Interpretation: + +- Keep the patch. It moves the named sLSTM transition owner materially: + transition current allocation drops by about `5.062 GiB`, and max allocated + drops by about `2.250 GiB`. +- The small sLSTM speed regression is within the keep rule because the named + memory owner moved substantially and Axon did not regress. +- Planned runtime-role bytes still list the same transition roles because the + compiler rows are unchanged; the improvement is active local materialization + and slot-lifetime behavior inside the registered strategy. +- The sLSTM high-water is now a native temporary/allocator high-water inside the + row-group execution, not retained full gate/recurrent-logit outputs after the + transition stage. + +Next owner: + +1. Rerun the owner table before another optimization pass. +2. For sLSTM, add finer native-stage telemetry inside the gated row group if the + next target remains the `7.139 GiB` max allocated high-water. +3. For Axon, the next forward owner remains recurrent K/V-before plus + recurrent-message producer-consumer materialization; direct per-edge + projection remains rejected. +4. Keep training as bounded parity/liveness guardrail until forward owners move + again. + +April 21 semantic-transfer directive: + +- Use the April 21 evidence to recover mechanisms, not code. The target + mechanisms are low-live-memory registered temporal execution, grouped dense + producer-consumer work, and compiler-owned artifact/reducer contracts that + preserve the declared semantics. +- Before another long probe, identify the April 21 mechanism being transferred: + grouped GEMM/BMM route, transition row-group lifetime, recurrent K/V/message + producer-consumer lifetime, output/readout materialization, or reverse/reducer + artifact policy. +- The accepted current-code version must be expressed through primitive rows, + tensor bindings, memory/liveness rows, executor strategy records, and explicit + backend strategy metadata. Do not copy April 21 kernels, add benchmark-side + tiling, route by family/shape/hidden size, or hide CUDA/Triton choices in the + temporal scheduler. +- A semantic-transfer probe is useful only if the physical owner moves in timing, + launch shape, allocator/high-water telemetry, storage identity, or named + memory stage. Otherwise record it as a rejected mechanism and stop expanding + that direction. + +### 2026-05-03 - Current T=1 Owner Table After sLSTM Row-Group Fix + +Status: analysis only. No optimization was performed in this pass. 
+ +Fresh current-code owner-table command: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_owner_after_slstm_rowgroup_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_owner_after_slstm_rowgroup_20260503 \ +timeout 900s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-03/t1_owner_table_after_slstm_rowgroup_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward,forward_backward \ + --batches 1024 --seq-lens 1 --inner-steps 1 \ + --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner +``` + +Artifact: +`tmp/fabric_audits/partials/2026-05-03/t1_owner_table_after_slstm_rowgroup_h32_100m_b1024`. + +The run completed three rows and timed out while measuring Axon training. Treat +the Axon-training row as still open for a bounded rerun, not as a completed +current-code closure result. + +| Row | Status | tok/s | vs Apr21 `58732.71` | Peak GiB | vs Apr21 `2.07 GiB` | Fresh owner signal | +| --- | --- | ---: | ---: | ---: | ---: | --- | +| sLSTM 100M h32 B1024 T1 forward | ok | `11244.68` | `19.15%` | `7.139` | `3.45x` | current live stages top out at `4.327 GiB`; peak is now `5.451 GiB` unclassified/native high-water | +| sLSTM 100M h32 B1024 T1 training | ok | `13.23` | `0.02%` | `25.049` | `12.10x` | runtime buffers `7.500 GiB`, reverse artifacts `4.500 GiB`, unclassified peak `12.298 GiB` | +| Axon 100M h32 B1024 T1 forward | ok | `6856.22` | `11.67%` | `13.783` | `6.66x` | recurrent K/V-before and recurrent-message current stage `11.035 GiB`; unclassified peak `9.657 GiB` | +| Axon 100M h32 B1024 T1 training | timed out | - | - | - | - | latest completed steering row remains stale: about `3.96 tok/s`, `81.32 GiB`; rerun separately after forward owners move | + +Fresh forward owner ranking: + +1. **Axon recurrent K/V-before plus recurrent-message producer-consumer.** + The current Axon forward row still rises from `9.285 GiB` at native entry to + `11.035 GiB` after recurrent K/V-before/message, while peak remains + `13.783 GiB`. Direct per-edge recurrent-message projection was already + rejected, so the next legal shape is a grouped/BMM producer-consumer strategy + that reduces full K/V/message lifetime without replacing GEMM-shaped work + with per-edge recompute. +2. **sLSTM native high-water attribution inside the gated row group.** + The accepted row-group fix moved current live memory down to `4.327 GiB`, + but max allocated is still `7.139 GiB`. Before another sLSTM liveness patch, + split the row group into gate-affine, recurrent-matmul, recurrence, norm, and + output/writeback high-water stages. +3. **Training reverse/artifact/reducer liveness.** + Training is still much farther from April 21, but it should stay a bounded + guardrail until forward owners move. The current sLSTM training row already + shows `7.500 GiB` runtime buffers, `4.500 GiB` reverse artifacts, and + `12.298 GiB` unclassified peak; Axon training still needs a separate bounded + rerun because the four-row table timed out. +4. **Coverage expansion.** + This is still only the h32 100M B1024 single-pop steering row. 
T=1 closure + still requires the April21 matrix: 100M/500M/1B, B=1024/16384, sLSTM + Axon, + forward + training, plus mixed-pop, h4/h8/h16, high-batch small-param, + reset-present, final-state/materialized-state, and the dot-product semantic + stress case. + +Conclusion: + +- Forward remains the correct priority. The largest fresh forward miss is Axon + producer-consumer lifetime, not broad reverse optimization. +- Do not reopen the rejected direct-message/direct-projection probes. The next + implementation should preserve grouped/BMM structure and consume compiler + message/readout/transition rows, program access rows, output/artifact routes, + and memory-liveness rows directly. +- Do not claim throughput closure from the current h32 100M B1024 row. It is a + steering owner table only. + +### 2026-05-03 - Plan: Axon Forward Producer-Consumer Strategy + +Status: implemented as a bounded native-strategy probe. Safe to keep, but not +counted as forward owner closure. + +Selected owner: + +- Attack Axon forward recurrent K/V-before plus recurrent-message + producer-consumer lifetime. +- Fresh current-code evidence: + `tmp/fabric_audits/partials/2026-05-03/t1_owner_table_after_slstm_rowgroup_h32_100m_b1024`. +- Axon forward rises from `9.285 GiB` at native entry to `11.035 GiB` after + recurrent K/V-before and recurrent-message, with peak `13.783 GiB`. +- Direct per-edge recurrent-message projection and direct projection-out probes + are rejected. The next implementation must preserve grouped/BMM/GEMM-shaped + work while reducing full K/V/message lifetime. + +Boundary manifest: + +```text +Lane: throughput strategy +Expected semantic delta: none +Unchanged semantic rows: fixed-slot context message rows, Axon transition rows, readout rows, output routes +Changed strategy/runtime rows: registered forward message/readout/transition strategy and memory/liveness rows only +Tensor rows consumed: message executor rows, message program access rows, transition aggregate-input access rows, readout access rows, program tensor binding rows +Artifact/output routes consumed: recurrent_msg, output_msg, recurrent_hidden_after, output_cells, output route rows, forward artifact route rows +Memory/liveness rows consumed: forward_recurrent_msg, forward_recurrent_hidden_after, transition_forward_* outputs, output/readout buffers, dead-input and keep-slot rows +Old route being replaced: full materialization of recurrent K/V-before and recurrent_msg across transition/readout when compiler consumers can stream or consume projected chunks +Old route not allowed: per-edge hidden projection, fixed-slot ABI expansion, family/hidden-size/benchmark selectors, temporal-scheduler formulas, April21 code copy +``` + +Implementation shape for the next `proceed` pass: + +1. **Preflight row and binding proof.** + - Dump or inspect the active Axon message, transition, readout, access, and + liveness rows for the h32 100M B1024 T1 forward path. + - Record primitive-row and tensor-binding fingerprints before editing. + - Confirm the strategy can match by row pattern and access roles, not by + `axoncell`, hidden size, shape, or benchmark id. + - Typed reject if artifact/output routes require full `recurrent_msg`, + `output_msg`, or recurrent K/V storage for the active mode. + +2. **Add a small producer-consumer probe, not a broad rewrite.** + - Start with Axon forward-only, no reverse artifacts, no final-state materialization. + - Keep existing key/value projection GEMM/BMM structure. 
+ - Introduce a row-owned strategy phase that either: + - streams recurrent value/message chunks directly into the transition + aggregate input when `recurrent_msg` is dead after transition, or + - aliases/chunks recurrent-message storage with transition aggregate-input + storage when compiler liveness proves non-overlap. + - Do not use direct per-edge projection. That path already moved current live + bytes but badly regressed throughput. + +3. **Preserve readout contract.** + - Existing direct-keyless readout remains the accepted readout path for + forward-only rows. + - If readout needs the same producer-consumer data, route it through output + route/access rows rather than role-only or singleton assumptions. + - If multiple readout/message producers are not supported by this strategy, + reject through compiler legality before launch. + +4. **Add targeted native-stage telemetry only where needed.** + - If the first probe moves current allocation but not peak, split native + high-water into: + recurrent K/V projection, recurrent-message aggregation, transition + aggregate input, transition row-group, readout message/projection, output + route, and tensor compaction. + - Telemetry is metadata only; semantic returns stay row-owned tensors. + +5. **Source and parity gates.** + + ```bash + uv run pytest -q \ + tests/test_fabric_backend_boundaries.py::test_forward_message_readout_handlers_use_native_strategy_access_schema \ + tests/test_fabric_backend_boundaries.py::test_message_readout_native_callable_bodies_are_strategy_local \ + tests/test_fabric_backend_boundaries.py::test_forward_transition_access_uses_compiler_program_access_rows \ + --tb=short + + CUDA_VISIBLE_DEVICES=0 \ + TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_axon_producer_consumer_parity_20260503 \ + TRITON_CACHE_DIR=/tmp/cortical_triton_t1_axon_producer_consumer_parity_20260503 \ + timeout 420s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + --tb=short + ``` + +6. **Representative perf gate.** + + ```bash + CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ + TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_axon_producer_consumer_20260503 \ + TRITON_CACHE_DIR=/tmp/cortical_triton_t1_axon_producer_consumer_20260503 \ + timeout 600s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-03/t1_axon_producer_consumer_forward_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward --batches 1024 \ + --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none \ + --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner + ``` + +Acceptance rule: + +- Accept if Axon forward current allocation after recurrent K/V-before/message + drops materially below `11.035 GiB` or Axon peak drops materially below + `13.783 GiB`, while Axon throughput does not regress beyond noise. 
+- sLSTM forward must stay within noise and must not exceed `7.139 GiB` peak. +- Runtime metadata must prove the strategy consumed compiler rows/bindings and + did not add hidden fixed slots, family selectors, benchmark selectors, or + scheduler formulas. + +Revert/narrow rule: + +- Revert if it resembles the rejected direct-message path: lower current memory + but large throughput regression from per-edge or scalar recompute. +- Narrow to telemetry-only if current allocation moves but peak does not and + the remaining high-water is still unclassified. +- Stop and reclassify as compiler-extension work if the implementation needs a + new tensor role, new output route, new message/readout semantics, or new + backward/reducer contract. + +Next owner after this plan: + +- If Axon forward producer-consumer moves, rerun the four-row owner table and + then decide between sLSTM native high-water attribution and training + reverse/artifact/reducer liveness. +- If it does not move, stop message/readout work and add finer native-stage + telemetry before another liveness or alias patch. + +Implementation result: + +- Changed only the registered fixed-slot context message native strategy: + `message_forward_strategies.cuh` now chunks the large + weighted-value-to-message projection GEMM before normalizing into + `forward_recurrent_msg`. It does not add semantic rows, fixed slots, family + selectors, benchmark selectors, or scheduler-owned formulas. +- This keeps the work GEMM-shaped and avoids the rejected direct per-edge + recurrent-message route. +- It is safe as a local native-strategy liveness improvement, but the audit shows + the full projected-message temporary is not the dominant Axon T=1 forward + owner. + +Validation: + +```bash +uv run pytest -q \ + tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned \ + tests/test_fabric_backend_boundaries.py::test_message_readout_native_callable_bodies_are_strategy_local \ + tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings \ + tests/test_fabric_backend_boundaries.py::test_registered_program_allocations_are_compiler_classified \ + tests/test_fabric_backend_plan.py::test_registered_message_strategies_consume_compiler_sender_tables_not_geometry \ + --tb=short +# 5 passed + +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + --tb=short +# 1 passed + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_axon_producer_consumer_parity_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_axon_producer_consumer_parity_20260503 \ +timeout 420s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + --tb=short +# 3 passed +``` + +Perf gate: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_axon_producer_consumer_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_axon_producer_consumer_20260503 \ +timeout 600s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-03/t1_axon_producer_consumer_forward_h32_100m_b1024 \ + --baseline-json 
audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward --batches 1024 \ + --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none \ + --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner +# status=ok, cases=2 +``` + +Result table: + +```text +Row Tok/s Peak GiB Peak vs previous Native current owner +sLSTM forward 11627.87 7.139 flat after output route, 4.327 GiB current +Axon forward 6886.94 13.750 -0.033 GiB after recurrent message, 11.002 GiB current +``` + +Comparison to the previous successful owner table: + +- sLSTM forward was `11244.68 tok/s`, `7.139 GiB`; this patch is within noise + on memory and slightly faster. +- Axon forward was `6856.22 tok/s`, `13.783 GiB`; this patch is slightly faster + and only `0.033 GiB` lower peak. +- Axon native current after recurrent K/V-before/message was `11.035 GiB`; it is + now `11.002 GiB`. This is not material closure. + +Keep/revert decision: + +- Keep the chunked projection path because it is compiler-owned, row-local, and + did not regress throughput or parity. +- Do not treat it as the solved owner. +- Stop expanding the message-projection route for now. The current owner is + already present at native forward entry and in preallocated compiler runtime + buffers, especially transition forward intermediates: + `transition_forward_diag_output`, `transition_forward_linear_output`, + `transition_forward_norm_output`, plus `forward_recurrent_hidden_after` and + `forward_recurrent_msg`. + +Next exact forward owner: + +1. Move transition forward intermediates from unconditional full runtime-buffer + allocation into compiler-liveness-driven local/chunked transition strategy + storage. +2. The next patch should target `transition_forward_*` runtime roles and the + registered Axon transition row group, not message/readout projection. +3. Acceptance should require a material drop in Axon native entry/current memory, + not just a small high-water change inside the message native strategy. + +### 2026-05-03 - Axon Transition Deferred-Local Row-Group Liveness + +Boundary manifest: + +```text +Lane: throughput strategy/native strategy implementation +Expected semantic delta: none +Unchanged semantic rows: fixed-slot context message rows, Axon diag_rtu transition rows, readout rows, output routes +Changed strategy/runtime rows: registered transition native row-group liveness only +Rows consumed directly: primitive rows, forward executor rows, native callable binding rows, native callable output rows, program tensor binding rows, memory runtime buffer rows +Memory/liveness rows consumed: transition_forward_linear_output, transition_forward_diag_output, transition_forward_norm_output, forward_recurrent_hidden_after, forward_recurrent_msg +Old route replaced: immediate materialization of deferred-local transition row-group outputs into full [B,R,H] buffers +Old route not allowed: fixed slots, family/hidden-size/benchmark selectors, scheduler formulas, new semantic tensor roles +``` + +Implementation: + +- Added an unmaterialized native-callable output lookup for compiler-declared + deferred-local runtime buffers. The normal helper still materializes when a + full runtime tensor is required. 
+- Updated the registered `diag_rtu`/Axon transition row-group so, when the + compiler liveness plan marks both input-projection output and raw public-y + output as deferred local, it runs input projection, diagonal recurrence, + output projection, and normalization in batch chunks. +- Artifact/training cases still use the existing full-materialization route + because those tensors are required by reverse artifact rows. + +Validation: + +```bash +uv run pytest -q \ + tests/test_fabric_backend_boundaries.py::test_forward_transition_access_uses_compiler_program_access_rows \ + tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned \ + tests/test_fabric_backend_boundaries.py::test_registered_program_allocations_are_compiler_classified \ + --tb=short +# 3 passed + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_axon_transition_liveness_parity_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_axon_transition_liveness_parity_20260503 \ +timeout 420s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + --tb=short +# 4 passed +``` + +Perf gate: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_axon_transition_liveness_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_axon_transition_liveness_20260503 \ +timeout 600s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-03/t1_axon_transition_liveness_forward_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward --batches 1024 \ + --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none \ + --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner +# status=ok, cases=2 +``` + +Result table: + +```text +Row Tok/s Peak GiB Peak vs previous Native current owner +sLSTM forward 11628.33 7.139 flat after output route, 4.327 GiB current +Axon forward 6843.29 12.002 -1.748 GiB after recurrent message, 11.002 GiB current +``` + +Comparison to the previous successful producer-consumer row: + +- sLSTM forward stayed flat on memory: `7.139 GiB` peak. +- Axon forward peak dropped from `13.750 GiB` to `12.002 GiB`. +- Axon native current is still `9.252 GiB` at native entry, rises to + `11.002 GiB` after recurrent K/V-before/message, and returns to `9.252 GiB` + after transition. The patch removed the transition high-water, but it did not + solve the pre-transition recurrent K/V/message current owner. +- Planned runtime-role bytes still show large logical transition roles: + `transition_forward_linear_output=3.500 GiB`, + `transition_forward_diag_output=3.500 GiB`, + `transition_forward_norm_output=1.750 GiB`. These are now partly logical + deferred-local roles for the forward-only path, so closure decisions should + use active allocator stages, not role-byte totals alone. 
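+
+For orientation, a minimal Python-level sketch of the batch-chunked transition
+shape described in the implementation notes for this pass. The recurrence and
+norm here are placeholders, not the registered `diag_rtu` math, and every name
+is illustrative; the only point is that per-chunk temporaries replace full
+`[B, R, H]` materialization and outputs are written only when a consumer
+requires them.
+
+```python
+# Hypothetical sketch; not the registered CUDA row-group implementation.
+import torch
+
+def chunked_transition_forward(x, h_prev, w_in, decay, w_out, norm_w, eps=1e-6,
+                               chunk_rows=4096, out_y=None, out_h=None):
+    batch = x.shape[0]
+    for start in range(0, batch, chunk_rows):
+        end = min(start + chunk_rows, batch)
+        # Chunk-sized temporaries only; no full-batch intermediate survives.
+        lin = x[start:end] @ w_in                  # input projection
+        h_new = decay * h_prev[start:end] + lin    # placeholder diagonal recurrence
+        y_raw = h_new @ w_out                      # output projection
+        rms = torch.rsqrt(y_raw.pow(2).mean(-1, keepdim=True) + eps)
+        y = y_raw * rms * norm_w                   # placeholder norm epilogue
+        if out_h is not None:                      # materialize only for real consumers
+            out_h[start:end] = h_new
+        if out_y is not None:
+            out_y[start:end] = y
+    return out_y, out_h
+```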
+ +Keep/revert decision: + +- Keep the patch. It is compiler-owned, parity-clean, and materially reduces + Axon forward peak without hurting sLSTM memory. +- Do not call T=1 forward closed. The remaining active owner is pre-transition + recurrent K/V-before/message current allocation plus native-entry/runtime + buffer pressure. + +Next exact forward owner: + +1. Attack recurrent K/V-before/message current allocation while preserving the + grouped/BMM/GEMM shape. Direct per-edge projection remains rejected. +2. Use compiler liveness/access rows to shorten producer-consumer lifetime + around recurrent K/V-before and recurrent message, or prove the remaining + owner is actually frontend/static/state allocation before registered entry. +3. Keep training only as a bounded parity/liveness guardrail until the forward + owner table moves again. + +### 2026-05-03 - Deferred Recurrent Value Producer-Consumer Slice + +Status: kept as a compiler-owned T=1 forward strategy improvement. This is not +T=1 closure. + +Boundary manifest: + +```text +Lane: throughput strategy/native strategy implementation +Expected semantic delta: none +Unchanged semantic rows: fixed-slot context message rows, Axon transition rows, readout rows, output routes +Changed strategy/runtime rows: registered fixed-slot context message native implementation only +Rows consumed directly: forward executor rows, native callable binding rows, program tensor binding rows, runtime buffer rows +Memory/liveness rows consumed: forward_recurrent_msg, forward_recurrent_hidden_after, transition/output route liveness rows +Old route replaced: full forward-only recurrent value bank materialization before recurrent-message attention +Old route not allowed: per-edge hidden projection, fixed-slot ABI expansion, family/hidden-size/benchmark selectors, temporal-scheduler formulas, April21 code copy +``` + +Implementation: + +- Replaced the recurrent sender-value step scalar projection kernel with the + existing compiler-owned receiver-major dense affine strategy. This keeps the + operation batched/grouped-GEMM shaped. +- Extended the registered fixed-slot context message strategy so forward-only + paths can defer recurrent value-bank materialization. The message strategy now + projects recurrent value in batch chunks and immediately feeds those chunks to + the weighted-message producer before dropping the chunk. +- Reverse-artifact/training cases still materialize the full recurrent value + bank when artifact rows require it. +- No semantic row, message rule, output route, scheduler formula, compatibility + wrapper, benchmark selector, or April21 code path was added. + +Intermediate steering row: + +```text +Pure dense-affine sender-value step: + artifact: tmp/fabric_audits/partials/2026-05-03/t1_sender_value_dense_affine_forward_h32_100m_b1024 + Axon forward: 7226.47 tok/s, 12.002 GiB + Decision: useful compute strategy, but not sufficient memory movement because the full recurrent value bank still lived across the message stage. 
+``` + +Validation: + +```bash +uv run pytest -q \ + tests/test_fabric_backend_boundaries.py::test_message_readout_native_callable_bodies_are_strategy_local \ + tests/test_fabric_backend_boundaries.py::test_registered_program_allocations_are_compiler_classified \ + tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned \ + tests/test_fabric_backend_plan.py::test_registered_message_strategies_consume_compiler_sender_tables_not_geometry \ + --tb=short +# 4 passed + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_deferred_recurrent_value_parity_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_deferred_recurrent_value_parity_20260503 \ +timeout 420s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + --tb=short +# 4 passed +``` + +Perf gate: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_deferred_recurrent_value_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_deferred_recurrent_value_20260503 \ +timeout 600s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-03/t1_deferred_recurrent_value_forward_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward --batches 1024 \ + --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none \ + --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner +# status=ok, cases=2 +``` + +Result table: + +```text +Row Tok/s Peak GiB Peak vs previous Native current owner +sLSTM forward 11808.85 7.139 flat output route, 4.327 GiB current +Axon forward 7217.04 11.248 -0.754 GiB native return, 9.252 GiB current +``` + +Owner movement: + +- Previous Axon transition-liveness row: `6843.29 tok/s`, `12.002 GiB`; + native current rose from `9.252 GiB` at entry to `11.002 GiB` after recurrent + K/V-before/message. +- This row: `7217.04 tok/s`, `11.248 GiB`; native current stays flat at + `9.252 GiB` through recurrent K/V-before, recurrent message, transition, + readout, and return. +- The former `+1.750 GiB` recurrent-value-bank current jump is gone. +- Peak is still `11.248 GiB`, with max allocation occurring outside the named + current-stage jump; the remaining gap is now native-entry/runtime-buffer and + high-water attribution, not the full recurrent value bank lifetime. + +Keep/revert decision: + +- Keep the patch. It is compiler-owned, parity-clean, improves Axon throughput, + and materially reduces Axon forward peak while keeping sLSTM memory flat. +- Do not call T=1 forward closed. The next forward owner is the remaining + native-entry/runtime-buffer high-water plus output/readout/current-state + buffers, not recurrent K/V-before/message materialization. + +Next exact forward owner: + +1. 
Build the next owner table from this accepted row and separate
+   `native_entry/runtime_buffer` from output/readout high-water.
+2. Target the largest remaining forward current/high-water owner through
+   registered strategy or compiler memory-liveness rows only.
+3. Keep training as a bounded parity/liveness guardrail until the forward owner
+   table moves again.
+
+### 2026-05-03 - Forward Fixed-Slot Message High-Water Attribution
+
+Status: telemetry kept; direct-projection and chunk-size probes reverted.
+
+Boundary manifest:
+
+```text
+Lane: throughput strategy/native strategy implementation
+Expected semantic delta: none
+Unchanged semantic rows: fixed-slot context message rows, transition rows, readout rows, output routes, reset rows
+Changed strategy/runtime rows: registered fixed-slot context message native implementation only
+Rows consumed directly: forward executor rows, native callable binding rows, program tensor binding rows, runtime buffer rows
+Memory/liveness rows consumed: forward_recurrent_msg, forward_recurrent_hidden_after, transition/output-route liveness rows
+Old route replaced: none; telemetry only
+Old route not allowed: per-edge hidden projection, fixed-slot ABI expansion, family/hidden-size/benchmark selectors, temporal-scheduler formulas, April21 code copy
+```
+
+Hypothesis:
+
+- The previous accepted Axon forward row stayed at `9.252 GiB` current
+  allocation through native stages but hit `11.248 GiB` max allocation. That
+  points to a temporary inside a native strategy, not a persistent runtime
+  buffer.
+- In the forward-only fixed-slot context message path, the compiler already
+  defers recurrent value-bank materialization. The strategy still builds a
+  weighted-value chunk and projected-message chunk before normalizing into the
+  compiler-owned `forward_recurrent_msg` runtime buffer.
+- Added native sub-stage rows around output-weight transpose, weighted-value,
+  projected-message, and normalize points so the allocator owner is visible
+  while the tensors are live.
+
+Validation:
+
+```bash
+uv run pytest -q \
+  tests/test_fabric_backend_boundaries.py::test_message_readout_native_callable_bodies_are_strategy_local \
+  tests/test_fabric_backend_boundaries.py::test_registered_program_allocations_are_compiler_classified \
+  tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned \
+  tests/test_fabric_backend_plan.py::test_registered_message_strategies_consume_compiler_sender_tables_not_geometry \
+  --tb=short
+# 4 passed
+
+CUDA_VISIBLE_DEVICES=0 \
+TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_message_stage_telemetry_parity_20260503 \
+TRITON_CACHE_DIR=/tmp/cortical_triton_t1_message_stage_telemetry_parity_20260503 \
+timeout 420s uv run pytest -q \
+  tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \
+  tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \
+  tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \
+  tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \
+  --tb=short
+# 4 passed
+
+# Final exact-state parity also passed with
+# TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_message_telemetry_final_parity_20260503.
+``` + +Perf/evidence run: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_message_stage_telemetry_20260503 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_message_stage_telemetry_20260503 \ +timeout 600s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-03/t1_message_stage_telemetry_forward_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward --batches 1024 \ + --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none \ + --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner +# status=ok, cases=2 +``` + +Result: + +```text +Row Tok/s Peak GiB New current owner +sLSTM forward 11735.02 7.139 message_after_projected/normalize, 4.764 GiB current +Axon forward 7226.86 11.248 message_after_projected/normalize, 10.250 GiB current +``` + +Rejected probes from this pass: + +- Direct projection into `forward_recurrent_msg` plus in-place normalize: + parity passed, but Axon stayed at `11.248 GiB` and throughput was slightly + lower (`7199.23 tok/s`). Reverted. +- Smaller projected-message chunk (`256 MiB` for the forward-only deferred path): + moved Axon projected current from `10.250 GiB` to `10.001 GiB`, but + `cuda_max_allocated_bytes` stayed `11.248 GiB` and throughput was slightly + lower (`7204.76 tok/s`). Reverted. +- Smaller recurrent-value chunk (`256 MiB` for the forward-only deferred path): + moved Axon projected/normalize current to `9.751 GiB`, but + `cuda_max_allocated_bytes` stayed `11.248 GiB` and throughput dropped to + `7171.82 tok/s`. Reverted. + +Keep/revert decision: + +- Keep the telemetry rows. They are compiler-owned metadata, parity-clean, and + identify a real native current owner without changing semantics. +- Do not keep the direct projection or chunk-size probes because they did not + move the global forward peak. + +Next exact forward owner: + +1. The persistent Axon forward current owner is now visible: + `native_forward_message_after_projected_local0=10.250 GiB` over a + `native_forward_entry=9.252 GiB` base. +2. The remaining `11.248 GiB` max-allocated gap is not explained by persistent + projected/recurrent-value current after the attempted chunk reductions. The + next pass should name the transient allocator owner inside the projection + call itself, likely GEMM/cuBLAS workspace or an allocator high-water before + the post-op stage can observe current allocation. +3. Do not add more alias/no-copy or chunk-size probes until that transient owner + is named. + +### 2026-05-04 - Plan: Forward Message Projection High-Water Attribution + +Status: completed, narrowed. This telemetry pass named the persistent message +current owner, but later high-water attribution showed transition scratch was +the larger global peak owner. 
+ +Boundary manifest: + +```text +Lane: throughput strategy/native strategy implementation +Expected semantic delta: none +Unchanged semantic rows: fixed-slot context message rows, transition rows, readout rows, output routes, reset rows +Changed strategy/runtime rows: registered fixed-slot context message native implementation telemetry only +Rows consumed directly: forward executor rows, native callable binding rows, program tensor binding rows, runtime buffer rows +Memory/liveness rows consumed: forward_recurrent_msg, forward_recurrent_hidden_after, transition/output-route liveness rows +Old route replaced: none in this pass; attribution before further strategy replacement +Old route not allowed: alias/no-copy expansion, per-edge hidden projection, fixed-slot ABI expansion, family/hidden-size/benchmark selectors, temporal-scheduler formulas, April21 code copy +``` + +Hypothesis: + +- The current Axon h32 100M B1024 T=1 forward row peaks at `11.248 GiB`. +- Persistent native current is visible at + `native_forward_message_after_projected_local0=10.250 GiB`, above + `native_forward_entry=9.252 GiB`, but previous chunk-size probes did not move + global peak. +- The remaining high-water may be inside the projected-message GEMM allocation + or cuBLAS/workspace path, before post-op coarse stage rows see the transient + owner. + +Instrumentation plan: + +- Add registered native forward stage rows around weighted-value allocation, + output-weight transpose, projected-message GEMM, projected-message + contiguous materialization, and normalize/writeback. +- Keep these rows metadata-only; they are stripped into audit metadata and do + not become semantic returns. + +Keep/narrow/revert rule: + +- Keep the telemetry if parity is green and it names the owner. +- Accept a strategy change only if Axon forward peak drops below `11.248 GiB` + or the hidden allocator owner becomes clearly actionable. +- Revert any non-telemetry probe that only relabels stages, regresses + throughput, or leaves `cuda_max_allocated_bytes` unchanged. 
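+
+For reference, the stage rows in the instrumentation plan above amount to
+metadata-only allocator samples taken while the message tensors are live. A
+minimal sketch of that recording shape; the recorder helper and stage names
+are assumptions, not the real native stage-row API.
+
+```python
+# Hypothetical sketch; real rows are emitted by the native strategy and
+# stripped into audit metadata, never returned as semantic tensors.
+import torch
+
+def record_stage(rows, stage, device=None):
+    """Append (stage, current bytes, max-allocated bytes) without touching semantics."""
+    if not torch.cuda.is_available():
+        return
+    rows.append({
+        "stage": stage,
+        "current_bytes": torch.cuda.memory_allocated(device),
+        "max_allocated_bytes": torch.cuda.max_memory_allocated(device),
+    })
+
+stage_rows = []
+record_stage(stage_rows, "message_before_weighted_value_alloc")
+# ... weighted-value allocation and output-weight transpose happen here ...
+record_stage(stage_rows, "message_after_projected_gemm")
+# ... projected-message contiguous materialization and normalize happen here ...
+record_stage(stage_rows, "message_after_normalize")
+```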
+ +Validation: + +```bash +uv run pytest -q \ + tests/test_fabric_backend_boundaries.py::test_message_readout_native_callable_bodies_are_strategy_local \ + tests/test_fabric_backend_boundaries.py::test_registered_program_allocations_are_compiler_classified \ + tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned \ + tests/test_fabric_backend_plan.py::test_registered_message_strategies_consume_compiler_sender_tables_not_geometry \ + --tb=short +# 4 passed + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_message_projection_telemetry_parity_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_message_projection_telemetry_parity_20260504 \ +timeout 420s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + --tb=short +# 4 passed +``` + +Perf/evidence run: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_message_projection_telemetry_perf_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_message_projection_telemetry_perf_20260504 \ +timeout 600s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-04/t1_message_projection_highwater_forward_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward --batches 1024 \ + --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none \ + --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner +# status=ok, cases=2 +``` + +Result: + +```text +Row Tok/s Peak GiB Peak vs previous Current owner +sLSTM forward 11418.22 7.139 flat message projected/normalize, 4.764 GiB current +Axon forward 7192.68 11.248 flat message projected/normalize, 10.250 GiB current +``` + +Native current-stage attribution: + +```text +Axon native_forward_entry_local0 9.252 GiB current +Axon message_before_weighted_value_alloc_local0 9.751 GiB current +Axon message_after_weighted_value_alloc_local0 9.751 GiB current +Axon message_before_output_weight_local0 9.751 GiB current +Axon message_before_projected_gemm_local0 9.751 GiB current +Axon message_after_projected_gemm_local0 10.250 GiB current +Axon message_after_projected_contiguous_local0 10.250 GiB current +Axon message_before_normalize_local0 10.250 GiB current +Axon message_after_normalize_local0 10.250 GiB current +Axon native_forward_return_local0 9.252 GiB current +``` + +Interpretation: + +- The projected-message GEMM output accounts for the visible `+0.499 GiB` + current increase. +- No additional live GEMM/cuBLAS workspace is visible in current allocator + telemetry after the GEMM returns. +- Native C++ stage rows use allocator stats that report a global/stale + `max_allocated` value. They are reliable for current/reserved attribution but + do not identify the first scoped `cuda_max_allocated_bytes` jump. The Python + audit peak is still `11.248 GiB`. 
+- The remaining peak is therefore not solved by projected-message chunking or + direct projection; those were already rejected because `cuda_max_allocated` + did not move. + +Keep/revert decision: + +- Keep this telemetry. It is compiler-owned metadata, parity-clean, and proves + the projected-message output is only a `0.499 GiB` visible current owner, not + the whole `11.248 GiB` peak. +- Do not add another chunk-size or alias/no-copy probe from this point. + +Next exact forward owner: + +1. Move from projection-output tweaking to native-entry/runtime-buffer base + pressure. Axon enters native forward at `9.252 GiB` and exits at the same + current allocation, while runtime buffers account for `3.750 GiB`. +2. The next strategy should shorten or remove forward-only runtime buffers that + survive across the whole program, especially compiler-owned state/output + buffers that are not needed after their consumer runs. +3. If more peak attribution is needed, add a scoped Python-side peak reset or + per-stage peak probe in the benchmark/audit harness only as evidence, not as + an optimization. Do not alter backend execution policy to make the metric + easier to read. + +### 2026-05-04 - Plan: Forward Step-Local Runtime Buffer Liveness + +Status: completed, narrowed. + +Boundary manifest: + +```text +Lane: throughput strategy/native strategy implementation +Expected semantic delta: none +Primitive row fingerprint before/after: unchanged by this pass +Tensor-role/binding fingerprint before/after: unchanged by this pass +Unchanged semantic rows: message rows, transition rows, readout rows, output routes, reset rows, artifact routes +Changed strategy/runtime rows: runtime-buffer allocation mode for compiler-declared forward step-local buffers +Rows consumed directly: primitive rows, executor rows, native callable rows, program access rows, runtime buffer rows, memory liveness rows +Memory/liveness owner: compiler memory liveness plan +Old route replaced: none; this is a liveness strategy over the registered fused forward program +Old route not allowed: benchmark tiling, family/hidden-size selectors, April21 code copy, fixed-slot ABI expansion, alias/no-copy expansion +``` + +Hypothesis: + +- The latest h32 100M B1024 T=1 Axon forward row enters native forward at + `9.252 GiB`, peaks at `11.248 GiB`, and reports `3.750 GiB` of compiler + runtime buffers. +- The projected-message output only explains a visible `+0.499 GiB` current + increase, so the higher-impact owner is whole-program runtime-buffer + lifetime. +- For forward-only, no-final-state runs, these roles are consumer-scope values + and should not be eagerly allocated for the whole fused program: + `forward_recurrent_hidden_after`, `forward_recurrent_msg`, + `forward_output_msg`, `forward_output_cells`. + +Implementation plan: + +- Reuse the compiler-owned `deferred_local` allocation mechanism for + forward-only step-local buffers. +- Keep `output_seq` eager because it is the semantic return. +- Keep reverse artifacts and final-state buffers eager when + `collect_artifacts=True` or `materialize_final_state=True`. +- Materialize deferred forward buffers inside registered fused program role + lookups, so native strategies still consume runtime-buffer rows rather than + ad hoc allocations. +- Return an empty final recurrent-hidden tensor from the forward-only ABI when + no final state or reverse artifact consumer exists. 
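+
+A minimal sketch of the allocation-mode decision in the implementation plan
+above. The role names come from this document; the function shape and the
+`deferred_local`/`eager` labels as a Python API are assumptions, not the real
+`memory_plan.py` interface.
+
+```python
+# Hypothetical sketch of the deferral decision, not the compiler code.
+FORWARD_STEP_LOCAL_ROLES = (
+    "forward_recurrent_hidden_after",
+    "forward_recurrent_msg",
+    "forward_output_msg",
+    "forward_output_cells",
+)
+
+def allocation_mode(role, *, collect_artifacts, materialize_final_state):
+    if role == "output_seq":
+        return "eager"  # the semantic return stays eagerly allocated
+    if collect_artifacts or materialize_final_state:
+        return "eager"  # reverse artifacts / final state keep these consumers alive
+    if role in FORWARD_STEP_LOCAL_ROLES:
+        return "deferred_local"  # forward-only step buffers move to consumer scope
+    return "eager"
+
+assert allocation_mode("forward_recurrent_msg",
+                       collect_artifacts=False,
+                       materialize_final_state=False) == "deferred_local"
+assert allocation_mode("forward_recurrent_msg",
+                       collect_artifacts=True,
+                       materialize_final_state=False) == "eager"
+```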
+ +Keep/narrow/revert rule: + +- Keep only if parity stays green and the forward owner table shows runtime + buffer bytes, native-entry current, or `cuda_max_allocated_bytes` moving down. +- Narrow if storage identity/lifetime moves but global peak remains unchanged. +- Revert if parity fails, if training artifact paths are affected, or if peak + memory rises with an unclassified owner. + +Implementation: + +- Added `defer_forward_step_buffers` to the compiler runtime-buffer plan. +- Forward-only, no-final-state runs now emit `deferred_local` allocation for + `forward_recurrent_hidden_after`, `forward_recurrent_msg`, + `forward_output_msg`, and `forward_output_cells`. +- Registered fused C++ role lookup now materializes deferred-local step buffers + at the consumer site, using the existing compiler-owned runtime-buffer rows. +- Forward-only ABI now returns an empty final recurrent-hidden tensor when no + final-state or reverse-artifact consumer exists. + +Validation: + +```bash +git diff --check -- \ + src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py \ + src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py \ + src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/memory_runtime_buffers.cuh \ + src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/forward_program.cuh \ + tests/test_fabric_backend_plan.py \ + tests/test_fabric_backend_boundaries.py \ + ai_docs/FABRIC_THROUGHPUT_CLOSURE.md +# clean + +python -m py_compile \ + src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py \ + src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py +# clean + +uv run pytest -q \ + tests/test_fabric_backend_boundaries.py::test_registered_program_allocations_are_compiler_classified \ + tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned \ + tests/test_fabric_backend_boundaries.py::test_message_readout_native_callable_bodies_are_strategy_local \ + tests/test_fabric_backend_plan.py::test_registered_message_strategies_consume_compiler_sender_tables_not_geometry \ + --tb=short +# 4 passed + +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes \ + tests/test_fabric_backend_boundaries.py::test_registered_program_allocations_are_compiler_classified \ + --tb=short +# 2 passed + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_forward_step_liveness_parity_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_forward_step_liveness_parity_20260504 \ +timeout 420s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + --tb=short +# 2 passed + +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + --tb=short +# 2 passed +``` + +Perf/evidence run: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_forward_step_liveness_perf_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_forward_step_liveness_perf_20260504 \ +timeout 900s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + 
--out-dir tmp/fabric_audits/partials/2026-05-04/t1_forward_step_liveness_forward_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward --batches 1024 \ + --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none \ + --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner +# status=ok, cases=2 +``` + +Result: + +```text +Row Previous Tok/s New Tok/s Previous Peak New Peak Runtime Buffers Native Entry Current +sLSTM forward 11418.22 11723.23 7.139 GiB 6.951 GiB 1.313 -> 0.563 3.295 -> 3.014 GiB +Axon forward 7192.68 7199.35 11.248 GiB 10.998 GiB 3.750 -> 1.750 9.252 -> 7.252 GiB +``` + +Keep/narrow decision: + +- Keep the diff. The intended compiler-owned lifetime moved: + runtime-buffer allocation dropped by `0.750 GiB` for sLSTM and `2.000 GiB` + for Axon, native-entry current dropped, and global peak moved down. +- Treat as a narrow liveness win, not T=1 closure. Axon is still + `10.998 GiB` and `7199 tok/s`, far from the April 21 row. +- Next owner remains forward native live set after deferred liveness: + Axon still reaches `10.000 GiB` around message projected/normalize stages and + returns to `7.252 GiB`; the next strategy should reduce the projected/transition + producer-consumer live overlap or replace the current message/readout staging + with a row-owned fused/streaming strategy. + +### 2026-05-04 - Plan: Forward Message Projection Output Liveness + +Status: completed, narrowed. + +Boundary manifest: + +```text +Lane: throughput strategy/native strategy implementation +Expected semantic delta: none +Primitive row fingerprint before/after: unchanged by this pass +Tensor-role/binding fingerprint before/after: unchanged by this pass +Unchanged semantic rows: fixed-slot context message rows, transition rows, readout rows, output routes, reset rows, artifact routes +Changed strategy/runtime rows: fixed-slot context forward message native implementation only +Rows consumed directly: message executor rows, native callable rows, tensor binding/access rows, runtime-buffer rows, memory liveness rows +Memory/liveness owner: compiler memory liveness plan plus registered native message strategy local outputs +Old route replaced: full projected-message temporary before normalization when projection output can be written into the compiler-owned recurrent-message buffer +Old route not allowed: per-edge hidden projection, benchmark chunking, family/hidden-size selectors, April21 code copy, alias/no-copy expansion +``` + +Hypothesis: + +- Latest accepted h32 100M B1024 T=1 Axon forward reaches + `native_forward_message_after_projected_gemm_local0=10.000 GiB` and + `native_forward_message_after_normalize_local0=10.000 GiB`. +- The fixed-slot context message path still computes + `weighted_value -> projected temporary -> normalize into recurrent_msg`. +- For the common `value_dim != message_dim` row, the projected GEMM output can + be written directly into the compiler-owned `forward_recurrent_msg` runtime + buffer, then normalized in place. This preserves GEMM shape and avoids the + separate projected-message allocation. + +Keep/narrow/revert rule: + +- Keep only if parity stays green and Axon forward native projected/normalize + current allocation or global peak moves down without a throughput regression. 
+- Narrow if the storage/lifetime change is correct but only moves stage current + allocation, not global `cuda_max_allocated_bytes`. +- Revert if parity fails, if `value_dim == message_dim` aliasing is not guarded, + or if training artifact paths are affected. + +Implementation result: + +- Added a registered native-message projection path that writes non-aliasing + `weighted_value -> recurrent_msg` projections through the existing + compiler-owned dense-affine strategy and then normalizes `recurrent_msg` + in-place. +- Did not keep the row-local square in-place projection/normalization kernel as + an active default. It removed the square-case projected tensor, but it was too + slow and did not move the case-level peak. +- Current h32 T=1 single-pop workloads have `value_dim == message_dim`, so the + kept non-alias path does not close the active owner yet. + +Measured result: + +```text +Audit: tmp/fabric_audits/partials/2026-05-04/t1_message_projection_out_forward_h32_100m_b1024 +Change tested: ATen mm_out into recurrent_msg for non-alias projection +Result: rejected as owner closure; projected/normalize current allocation unchanged. + +Audit: tmp/fabric_audits/partials/2026-05-04/t1_message_projection_dense_affine_forward_h32_100m_b1024 +Change tested: registered dense-affine projection into recurrent_msg for non-alias projection +Result: kept narrowly; parity-clean, but current square workloads do not use it and owner allocation is unchanged. + +family tok/s peak GiB before_projected after_projected after_normalize +sLSTM 11444.84 6.951 4.077 4.577 4.577 +Axon 7207.21 10.998 9.501 10.000 10.000 + +Audit: tmp/fabric_audits/partials/2026-05-04/t1_message_project_norm_fused_forward_h32_100m_b1024 +Change tested: square value/message row-local fused project+normalize, no projected tensor +Result: rejected as default throughput strategy; stage current moved, global peak did not, throughput regressed. 
+ +family tok/s peak GiB before_projected after_projected after_normalize +sLSTM 9859.22 6.951 4.077 4.077 4.077 +Axon 5210.68 10.998 9.501 9.501 9.501 +``` + +Verification: + +```text +git diff --check -- src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/native_callables/message_forward_strategies.cuh src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_device_kernels.cuh ai_docs/FABRIC_THROUGHPUT_CLOSURE.md +PASS + +uv run pytest -q tests/test_fabric_backend_boundaries.py::test_message_readout_native_callable_bodies_are_strategy_local tests/test_fabric_backend_boundaries.py::test_registered_program_allocations_are_compiler_classified --tb=short +2 passed in 5.14s + +CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_msg_projection_final_parity_20260504 TRITON_CACHE_DIR=/tmp/cortical_triton_msg_projection_final_parity_20260504 timeout 420s uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window --tb=short +2 passed in 61.22s + +uv run pytest -q tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program --tb=short +2 passed in 63.24s +``` + +Next owner: + +- The active square `value_dim == message_dim` fixed-slot context message path + still needs a fast registered strategy that preserves GEMM/BMM-class throughput + while eliminating the projected tensor. A naive row-local kernel is not good + enough. +- Because global peak stayed at `6.951 GiB` for sLSTM and `10.998 GiB` for Axon + even when the projected stage moved, the next forward owner must be chosen from + the full native live set, not only the message projection stage. + +### 2026-05-04 - Plan: Forward High-Water Attribution Before Next Kernel + +Status: completed. + +Boundary manifest: + +```text +Lane: throughput evidence/performance loop +Expected semantic delta: none +Primitive row fingerprint before/after: unchanged +Tensor-role/binding fingerprint before/after: unchanged +Changed runtime behavior: none +Changed audit behavior: memory ledger records first max-allocated owner and per-stage max-allocated deltas +Old route not allowed: benchmark-side tiling, family/hidden-size selectors, extra alias/no-copy metadata, April21 code copy +``` + +Hypothesis: + +- The prior message-projection attempts proved that current allocated bytes move + around message projected/normalize stages, but global peak can remain flat. +- The audit ledger was selecting the last stage tied at global + `max_allocated`, which can mislabel the owner as a late return/finalization + stage. +- Before adding another registered CUDA strategy, the owner table needs two + sharper facts: + `fabric_registered_backward_first_peak_stage.*` and + `fabric_registered_backward_peak_stage_by_max_delta.*`. + +Implementation plan: + +- Keep native execution unchanged. +- Add benchmark-ledger attribution for first high-water stage and max-allocated + delta by stage. +- Run the smallest representative forward audit: h32, 100M, B1024, T=1, + single-pop, forward-only, sLSTM and Axon. +- Choose the next strategy from the measured owner: + message/readout/transition/liveness/allocator-reserve. 
Do not add another + CUDA kernel if the high-water delta points somewhere else. + +Implementation/result: + +- Added benchmark ledger fields: + `fabric_registered_backward_first_peak_stage.*`, + `fabric_registered_backward_first_peak_stage_max_allocated_bytes`, + `fabric_registered_backward_stage_max_delta_bytes.*`, and + `fabric_registered_backward_peak_stage_by_max_delta.*`. +- Imported native forward memory-stage rows before Python post-call/final-table + markers so the first high-water owner is not hidden by broad Python markers. +- Attribution audit: + `tmp/fabric_audits/partials/2026-05-04/t1_forward_highwater_attribution_native_first_h32_100m_b1024`. + +Measured owner table: + +```text +Row Tok/s Peak GiB First high-water owner Largest max-delta owner +sLSTM forward 11737.63 6.951 native_forward_after_transition native_forward_after_transition, +2.375 GiB +Axon forward 7220.52 10.998 native_forward_after_transition native_forward_message_before_weighted_value_alloc, +2.249 GiB +``` + +Decision: + +- Transition scratch is a real high-water owner for both families and is the + first owner for the final peak. +- Message weighted/projected remains the largest persistent current owner after + transition is narrowed. + +### 2026-05-04 - Plan: Forward Transition Scratch Liveness Strategy + +Status: completed, kept. + +Boundary manifest: + +```text +Lane: throughput strategy/native strategy implementation +Expected semantic delta: none +Primitive row fingerprint before/after: unchanged +Tensor-role/binding fingerprint before/after: unchanged +Changed strategy/runtime rows: registered transition forward row-group scratch chunk target +Rows consumed directly: transition primitive rows, native callable output rows, program tensor bindings, runtime-buffer rows +Memory/liveness owner: registered native transition strategy local scratch buffers +Old route replaced: none; strategy-local temporary lifetime narrowed +Old route not allowed: benchmark tiling, family/hidden-size selectors, April21 code copy, semantic row changes +``` + +Implementation: + +- Changed registered transition row-group scratch target from `1 GiB` to + `256 MiB` for both gated-logspace and diagonal-RTU forward row groups. +- This keeps the same compiler-emitted primitive rows and executor bindings; it + only narrows native scratch lifetime inside the registered strategy. 
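+
+The scratch target change above is only a chunk-size bound: the registered row
+group sizes its per-launch batch chunk so strategy-local scratch stays under
+the target. A minimal sketch of that arithmetic; the per-row cost model and
+scratch-tensor count are assumptions, not the real `.cuh` scratch layout.
+
+```python
+# Hypothetical sketch of how a scratch byte budget bounds the batch chunk.
+def rows_per_chunk(scratch_target_bytes, hidden, n_scratch_tensors=3, dtype_bytes=4):
+    per_row_bytes = hidden * n_scratch_tensors * dtype_bytes
+    return max(1, scratch_target_bytes // per_row_bytes)
+
+GIB = 1 << 30
+MIB = 1 << 20
+old_chunk = rows_per_chunk(1 * GIB, hidden=32)    # previous 1 GiB target
+new_chunk = rows_per_chunk(256 * MIB, hidden=32)  # new 256 MiB target
+print(old_chunk // new_chunk)  # chunk shrinks 4x; compiler rows are unchanged
+```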
+ +Perf/evidence run: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_forward_transition_scratch_256m_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_forward_transition_scratch_256m_20260504 \ +timeout 900s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-04/t1_forward_transition_scratch_256m_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward --batches 1024 \ + --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none \ + --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner +# status=ok, cases=2 +``` + +Result: + +```text +Row Previous Tok/s New Tok/s Previous Peak New Peak New first high-water owner +sLSTM forward 11444.84 11783.81 6.951 GiB 5.260 GiB native_forward_after_transition +Axon forward 7207.21 7146.38 10.998 GiB 10.000 GiB native_forward_message_after_projected_gemm +``` + +Verification: + +```text +python -m py_compile benchmarks/fabric/suite_common.py +# clean + +python -m py_compile src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py +# clean + +git diff --check -- \ + benchmarks/fabric/suite_common.py \ + tests/test_fabric_benchmark_suite_common.py \ + src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py \ + src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_forward_program.cuh \ + ai_docs/FABRIC_THROUGHPUT_CLOSURE.md +# clean + +uv run pytest -q \ + tests/test_fabric_benchmark_suite_common.py::test_compiler_memory_ledger_records_first_peak_and_max_delta_stage \ + tests/test_fabric_backend_boundaries.py::test_registered_program_allocations_are_compiler_classified \ + tests/test_fabric_backend_boundaries.py::test_message_readout_native_callable_bodies_are_strategy_local \ + --tb=short +# 3 passed + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_transition_scratch_parity_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_transition_scratch_parity_20260504 \ +timeout 420s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + --tb=short +# 2 passed + +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + --tb=short +# 2 passed +``` + +Keep/narrow/revert decision: + +- Keep. The strategy moved the actual T=1 forward high-water owner: + `-1.691 GiB` for sLSTM and `-0.998 GiB` for Axon, with parity green. +- Treat the Axon throughput dip as acceptable for this memory-owner pass, but + not as throughput closure. +- Next forward owner is the message weighted/projected live set: + Axon now peaks at `10.000 GiB`, with current owner + `native_forward_message_after_normalize_local0=10.000 GiB` and largest + max-delta owner + `native_forward_message_before_weighted_value_alloc_local0=2.249 GiB`. 
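+
+The first high-water and largest max-delta owners quoted in these tables come
+from the benchmark-ledger attribution added earlier on 2026-05-04. A minimal
+sketch of how those two attributions can be derived from ordered per-stage
+allocator samples; keys and byte values here are illustrative, not measured.
+
+```python
+# Hypothetical sketch of the ledger derivation, not the real benchmark code.
+def attribute_peak(samples):
+    """samples: ordered list of (stage_name, max_allocated_bytes)."""
+    global_peak = max(v for _, v in samples)
+    # Pick the FIRST stage that reaches the global peak, not the last tied one.
+    first_peak_stage = next(s for s, v in samples if v == global_peak)
+    # Largest per-stage increase, treating the first sample as the baseline.
+    deltas = [(samples[i][0], samples[i][1] - samples[i - 1][1])
+              for i in range(1, len(samples))]
+    peak_stage_by_max_delta = max(deltas, key=lambda kv: kv[1])
+    return first_peak_stage, peak_stage_by_max_delta
+
+samples = [  # illustrative GiB values, not audit output
+    ("native_forward_entry", 7),
+    ("native_forward_after_transition", 10),
+    ("native_forward_message_after_projected_gemm", 10),
+    ("native_forward_return", 10),
+]
+print(attribute_peak(samples))
+# ('native_forward_after_transition', ('native_forward_after_transition', 3))
+```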
+ +### 2026-05-04 - Current T=1 Throughput Deep Dive After Transition Scratch + +Status: analysis only. No optimization was run in this pass. + +Reference target: + +- April 21 `h32_t1_bxparams`: `58732.71 tok/s`, `2.07 GiB`. +- This remains the summary floor for T=1 closure, not only a steering row. + +Current-code audit: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_current_deepdive_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_current_deepdive_20260504 \ +timeout 1200s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-04/t1_current_deepdive_after_transition_scratch_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward,forward_backward \ + --batches 1024 --seq-lens 1 --inner-steps 1 \ + --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 1 --require-cuda-temporal-owner +# status=ok, cases=4 +``` + +Current owner table: + +| Row | tok/s | vs Apr21 | slowdown | Peak GiB | Memory vs Apr21 | Runtime GiB | Artifact GiB | Unclassified GiB | Current owner | +| --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | --- | +| sLSTM forward | `11614.43` | `19.78%` | `5.06x` | `5.260` | `2.54x` | `0.563` | `0.000` | `4.322` | `native_forward_message_after_normalize_local0` | +| sLSTM train | `13.31` | `0.023%` | `4411.65x` | `25.049` | `12.10x` | `7.500` | `4.500` | `12.298` | `native_after_recurrent_message_local0` | +| Axon forward | `7114.53` | `12.11%` | `8.26x` | `10.036` | `4.85x` | `1.750` | `0.000` | `7.910` | `native_forward_message_after_normalize_local0` | +| Axon train | `3.99` | `0.0068%` | `14734.73x` | `81.319` | `39.28x` | `12.500` | `9.500` | `58.567` | `native_after_recurrent_message_local0` | + +Important current facts: + +- Forward is still the first optimization lane. Training is orders of magnitude + slower, but every training row pays the forward message live-set cost first. +- Transition scratch liveness moved the forward high-water owner, but the active + forward owner is now message weighted/projected/normalized live set. +- Axon forward current reaches `10.036 GiB` at + `native_forward_message_after_projected_gemm_local0` and remains there through + normalize. Largest forward max-delta is + `native_forward_message_before_weighted_value_alloc_local0=2.250 GiB`. +- sLSTM forward current reaches `4.577 GiB` at the same message + projected/normalize stages. Largest forward max-delta is + `native_forward_message_before_weighted_value_alloc_local0=1.062 GiB`. +- Training current owner is not forward anymore; it is registered backward + recurrent-message/recurrent-KV/transition span materialization: + `native_after_recurrent_message_local0`, `native_after_recurrent_kv_local0`, + `native_after_boundary_kv_local0`, and transition reverse group stages. +- Training unclassified/high-water memory remains very large: + `12.298 GiB` for sLSTM and `58.567 GiB` for Axon. + +Remaining T=1 work, highest impact first: + +1. **Forward message producer-consumer strategy.** + - The next registered strategy should reduce the + `weighted_value -> projected/normalized recurrent_msg` live set without + changing semantic rows. + - It must preserve GEMM/BMM-class structure. 
Prior row-local fused projection + removed the temporary but regressed throughput and was rejected. + - Likely direction: grouped/BMM or streaming producer-consumer strategy that + writes/normalizes route-owned recurrent message chunks while shortening + weighted-value and projected-message lifetimes. + +2. **Forward message/readout/transition overlap.** + - After the message projection owner moves, the next forward live set is the + overlap between recurrent message, transition public state, output message, + and output cells. + - Candidate direction: route-owned direct readout/output strategy that avoids + keeping recurrent message/output cells live past their consumers when the + compiler output route proves terminal-only semantics. + +3. **Training reverse recurrent-message/KV liveness.** + - Once forward moves again, training should target the reverse native stages + around recurrent message, recurrent K/V, boundary K/V, and transition + reverse group outputs. + - The change must be reducer/liveness-row owned: route adjoints directly into + reducer/runtime/workspace destinations instead of materializing full banks + and span outputs. + +4. **Training artifact/tape reduction.** + - Current training still carries large reverse artifacts: + `transition_state_before`, `recurrent_msg_backend_order`, + `recurrent_hidden_backend_order`, and + `recurrent_hidden_before_backend_order`. + - These should be recomputed, compacted, or consumed through artifact-route + rows only when backward legality requires them. + +5. **Coverage expansion after the forward owner moves.** + - Current table is still h32, 100M, B1024, single-pop only. + - Closure still requires the April21-shaped matrix: + `100M/500M/1B`, `B=1024/16384`, sLSTM + Axon, forward + training. + - Additional required axes before closure: mixed-pop T=1, h4/h8/h16 stress, + reset-present, final-state/materialized-state, high-batch small-param rows, + and the dot-product semantic stress case. + +Conclusion: + +- The biggest next T=1 change is not broad training optimization yet. It is a + forward registered message producer-consumer strategy that keeps compiler rows + stable while reducing the message weighted/projected live set. +- Training remains a severe blocker for full T=1 closure, but it should be + treated as parity/liveness guardrail until the active forward owner moves + again. + +### 2026-05-04 - Plan: Forward Message Producer-Consumer Strategy + +Status: implementation pass starting. This is throughput strategy work, not a +semantic compiler-extension pass. + +Boundary packet: + +- Unchanged semantic rows: message primitive rows, message executor rows, + tensor-role bindings, output routes, artifact routes, reset rows, and + transition/readout primitive rows. +- Changed strategy/runtime rows: the registered native forward fixed-slot + context message implementation may change its workspace/liveness strategy for + the existing message row group. +- Rows consumed directly: native message strategy row, forward executor rows, + tensor binding rows, program access rows, runtime buffer rows, memory stage + rows, and compiler-owned recurrent-message runtime buffer. +- Old route being replaced/narrowed: the square `value == message` path that + aliases weighted attention output onto the final recurrent-message buffer and + then materializes projected-message chunks before normalization. +- No semantic delta: the attention, projection, and normalization math must stay + the same; this pass only shortens producer-consumer lifetime. 
+ +Hypothesis: + +- Current forward peaks are dominated by the live overlap of recurrent message, + weighted value, projected message, and transition/readout consumers. +- For fixed-slot context rows where `value == message`, writing weighted value + into the final recurrent-message buffer forces a full-message live set before + projection can overwrite it. +- A registered strategy-owned bounded weighted-value workspace can stream + `weighted_value_chunk -> dense_affine_out -> normalize_inplace` into the + route-owned recurrent-message chunk. This keeps GEMM/BMM-class projection and + avoids the previously rejected row-local fused projection kernel. + +Expected owner movement: + +- `native_forward_message_after_normalize_local0` should drop from the current + `5.260 GiB` sLSTM / `10.036 GiB` Axon row. +- The largest max-delta stages should move away from + `native_forward_message_before_weighted_value_alloc_local0` and + `native_forward_message_after_projected_gemm_local0`. +- Throughput should not materially regress versus the current rows: + sLSTM `11614.43 tok/s`, Axon `7114.53 tok/s`. + +Keep/narrow/revert rule: + +- Keep if the named message owner or peak allocated memory physically moves and + targeted parity remains green. +- Narrow if only one family benefits but the change is still row-owned and + legal through strategy selection. +- Revert if the owner does not move, if throughput regresses like the rejected + row-local fused projection probe, or if the implementation needs family, + benchmark, hidden-size, or scheduler-owned formula selectors. + +Acceptance rows: + +- High-level forward current-code rows: + sLSTM and Axon, single-pop, `100m`, `h32`, `B=1024`, `T=1`, terminal output. +- Parity/liveness guardrails: + fused forward artifact-store path, final-state-only zero-output-grad path, and + the fixed-slot context nudge/gate registered training gates. + +### 2026-05-04 - Result: Forward Message Producer-Consumer Strategy + +Status: accepted narrow forward liveness improvement. + +Implementation: + +- Changed only the registered fixed-slot context forward message native strategy. +- The strategy now uses a bounded weighted-value producer workspace for + `value == message` rows, then streams + `weighted_value_chunk -> dense_affine_out -> normalize_inplace` into the + compiler-owned recurrent-message runtime buffer. +- Strategy-local chunk target: `768 MiB`. +- Semantic rows, tensor bindings, output routes, artifact routes, reset rows, + and transition/readout rows are unchanged. +- No scheduler formula, benchmark policy, family selector, hidden-size selector, + or April21 code copy was added. 
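+
+A minimal PyTorch-level sketch of the bounded producer-consumer shape described
+above. The attention/projection math, the normalizer, and every tensor name are
+assumptions; the registered strategy owns the real kernels, the chunk policy,
+and the compiler-row bindings.
+
+```python
+# Hypothetical sketch; not the registered message_forward_strategies.cuh code.
+import torch
+
+def stream_recurrent_message(attn, value, proj_w, recurrent_msg,
+                             workspace_bytes=768 * (1 << 20), eps=1e-6):
+    batch = recurrent_msg.shape[0]
+    rows = max(1, workspace_bytes // (value.shape[1] * value.element_size()))
+    for start in range(0, batch, rows):
+        end = min(start + rows, batch)
+        # Producer: bounded weighted-value workspace, dropped every iteration.
+        weighted_value = attn[start:end] * value[start:end]
+        # Consumer: GEMM-shaped dense affine straight into the route-owned buffer.
+        torch.mm(weighted_value, proj_w, out=recurrent_msg[start:end])
+        # Normalize in place so no separate projected-message tensor survives.
+        chunk = recurrent_msg[start:end]
+        chunk.div_(chunk.norm(dim=-1, keepdim=True).clamp_min(eps))
+    return recurrent_msg
+```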
+ +Accepted artifact: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_msg_pc_768_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_msg_pc_768_20260504 \ +timeout 1200s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-04/t1_forward_message_producer_consumer_768mb_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward \ + --batches 1024 --seq-lens 1 --inner-steps 1 \ + --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 2 --require-cuda-temporal-owner +# status=ok, cases=2 +``` + +Result versus the current post-transition-scratch owner table: + +| Row | Previous tok/s | New tok/s | Previous peak GiB | New peak GiB | Previous message after-normalize GiB | New message after-normalize GiB | +| --- | ---: | ---: | ---: | ---: | ---: | ---: | +| sLSTM forward | `11614.43` | `11786.89` | `5.260` | `5.260` | `4.577` | `4.077` | +| Axon forward | `7114.53` | `7152.40` | `10.036` | `9.501` | `10.036` | `9.501` | + +Rejected steering variant: + +- `512 MiB` chunk target: + `tmp/fabric_audits/partials/2026-05-04/t1_forward_message_producer_consumer_h32_100m_b1024`. +- It moved Axon memory to the same `9.501 GiB` peak but slowed the row: + sLSTM `11417.57 tok/s`, Axon `6858.88 tok/s`. +- Rejected because `768 MiB` preserves the memory move with less launch overhead. + +Parity/source gates: + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_msg_pc_768_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_msg_pc_768_20260504 \ +timeout 900s uv run pytest -q \ + tests/test_fabric_backend_boundaries.py::test_registered_program_allocations_are_compiler_classified \ + tests/test_fabric_backend_boundaries.py::test_message_readout_native_callable_bodies_are_strategy_local \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program \ + tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program \ + --tb=short +# 6 passed +``` + +Keep/narrow/revert decision: + +- Keep. The strategy moved the physical message-stage live set and the Axon + forward peak without throughput regression. +- This is not T=1 closure. The row remains far below April 21 + `58732.71 tok/s`, `2.07 GiB`. +- Next forward owner: route-owned message/readout/transition live overlap after + recurrent message. The accepted row now peaks at readout projection for sLSTM + and still peaks in message normalize for Axon, so the next pass should target + direct producer-consumer handoff from recurrent message into transition/readout + consumers or route-owned output materialization, not broad training backward. + +### 2026-05-04 - Strategic Correction: Reconstruct April21 Physical Execution Shape + +Status: strategic correction before the next implementation pass. Do not start +another narrow performance patch until this section and the hypothesis packet +below are the active plan. 
+ +Fresh current-code audit: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_t1_deepdive_current_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_t1_deepdive_current_20260504 \ +timeout 1200s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-04/t1_current_deepdive_after_message_pc_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward,forward_backward \ + --batches 1024 --seq-lens 1 --inner-steps 1 \ + --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 1 --require-cuda-temporal-owner +# status=ok, cases=4 +``` + +Current local wins: + +- The recent registered-strategy work made valid local liveness improvements. +- sLSTM forward memory is now around `5.26 GiB`. +- Axon forward memory is now around `9.5 GiB`. +- Forward throughput is still only around `11.6k tok/s` for sLSTM and + `7.1k tok/s` for Axon, versus the April 21 `58.7k tok/s` floor. +- Training remains catastrophically far behind: + about `13 tok/s` for sLSTM and `4 tok/s` for Axon. + +Current owner table: + +| Row | tok/s | vs Apr21 | slowdown | Peak GiB | Memory vs Apr21 | Runtime GiB | Artifact GiB | Unclassified GiB | Current owner | +| --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | --- | +| sLSTM forward | `11640.60` | `19.82%` | `5.05x` | `5.260` | `2.54x` | `0.563` | `0.000` | `4.322` | `native_forward_after_readout_projection_local0` | +| sLSTM train | `13.37` | `0.0228%` | `4393.77x` | `25.049` | `12.10x` | `7.500` | `4.500` | `12.298` | `native_after_recurrent_message_local0` | +| Axon forward | `7128.07` | `12.14%` | `8.24x` | `9.534` | `4.61x` | `1.750` | `0.000` | `7.408` | `native_forward_message_after_normalize_local0` | +| Axon train | `4.01` | `0.0068%` | `14657.62x` | `81.319` | `39.28x` | `12.500` | `9.500` | `58.567` | `native_after_recurrent_message_local0` | + +Diagnosis: + +- The current implementation is still materializing too many compiler stage + boundaries. +- The real missing thing is likely a streaming-step physical execution strategy + whose first acceptance row is T=1/K=1, not another isolated buffer shortcut. +- April 21 should be treated as evidence for a low-live-memory physical + execution model: + `message -> transition/readout -> output` with minimal materialization, then + matching backward/reducer liveness. +- The current compiler semantics must remain compiler-owned: + no April21 code copy, no benchmark tiling, no family/shape/hidden-size + selectors, and no hidden CUDA-only scheduler policy. +- The local one-buffer wins are useful, but they are not converging fast enough + toward April 21 because they keep the same broad materialized stage shape. + +New direction: + +- Stop expanding narrow message projection, alias, no-copy, or one-buffer + liveness probes unless they are explicitly part of the broader physical + execution strategy below. +- The next plan must define a registered streaming-step physical + producer-consumer strategy across message, transition, readout, and output + routes. T=1/K=1 is only the first row through that step body. +- That strategy must name the compiler rows, tensor bindings, liveness rows, + artifact routes, output routes, and reducer routes it consumes. 
+- It must state which semantic row fingerprints remain stable. +- It must identify which tensors cease to be fully materialized and where their + consumers move. + +Required hypothesis packet before implementation: + +- April21 mechanism being semantically transferred. +- Current compiler products that express the mechanism. +- Rows/fingerprints expected to stay stable. +- Runtime/liveness/artifact rows consumed. +- Small synthetic probe, if useful. +- First high-level representative row. +- Keep/narrow/revert rule. +- Parity and boundary gates. + +Hypothesis packet for the next pass: + +- April21 mechanism being semantically transferred: + a low-live-memory streaming physical step body with producer-consumer handoff + across message, transition, readout, and output, rather than independent + full-stage materialization. April21 is evidence for this execution shape only; + no April21 code, benchmark tiling, family/shape selectors, or fixed-slot route + identities may be copied. +- Current compiler products that express it: + primitive rows, forward executor rows, native strategy rows, tensor binding + rows, program access rows, runtime buffer rows, memory liveness rows, + artifact route rows, output route rows, transition carry rows, and parameter + reducer rows. +- Rows/fingerprints expected to stay stable: + message primitive rows, transition primitive rows, readout primitive rows, + tensor-role bindings, output routes, reset rows, and reducer semantics. + Expected semantic delta is `none`. +- Runtime/liveness/artifact rows consumed: + forward recurrent-message runtime buffer, transition state/public-output + runtime buffers, output-message/output-cell runtime buffers, final-state + materialization policy, reverse artifact routes, transition-state-before + artifact policy, recurrent-message artifact policy, and parameter reducer + liveness rows. +- Tensors that should stop being fully materialized in the streaming physical + step strategy: + recurrent message after its transition/readout consumers finish, + transition linear/diag/matmul/norm intermediate outputs when their next + primitive consumes them immediately, output message/output cells when the + output route is terminal-only, and reverse span outputs that can be routed + directly into reducers or compact artifacts. +- Consumer movement: + transition and readout should consume route-owned message slices directly; + output projection should consume route-owned output cells/message without + preserving unrelated full banks; backward should route adjoints directly to + reducer/runtime/artifact destinations instead of rebuilding full forward + stage boundaries. +- T/K extension: + the same physical step body runs once for `T=1,K=1`; loops over outer time + for `T>1,K=1`; and loops over both outer time and inner microsteps for + `T>1,K>1`. Carried state, reset handling, output-route emission, + artifact/tape materialization, checkpoint/recompute windows, and reducer + liveness must be selected by compiler rows, not by benchmark policy or a + separate temporal route identity. +- Small synthetic probe, if useful: + a tiny high-level T=1 program with the same compiler route that emits memory + stages for `message -> transition -> readout -> output` and verifies that a + selected producer-consumer strategy removes at least one full-stage allocation + without changing primitive-row fingerprints. 
+- First high-level representative row: + sLSTM and Axon, single-pop, `100m`, `h32`, `B=1024`, `T=1`, forward only, + then the same rows with `forward_backward` after forward owner movement is + proven. +- T/K streaming guardrail: + after the first accepted T=1/K=1 forward row, run a small high-level + `T>1,K>1` row through the same compiler-owned route and prove no new route + identity, benchmark-owned time chunking, detach policy, Python replay, or full + `[T, cells, state]` materialization appears. For T, compare primarily to the + same-row Fabric T=1 per-token line. For K, compare to the matched Fabric + `T=1,K=1` training line divided by K. +- Keep/narrow/revert rule: + keep only if the named current owner moves physically in current allocated + bytes, peak allocated bytes, or launch/lifetime telemetry while parity stays + green; narrow if only one surface benefits but the strategy remains + row-owned; revert if the change is metadata-only, throughput regresses + materially, or the implementation needs benchmark, family, hidden-size, + fixed-slot, or scheduler-owned formula selectors. +- Parity and boundary gates: + focused source guardrails for compiler-owned rows/bindings/liveness, no + hidden fallback/compat path, no scheduler formulas, CUDA/reference parity for + outputs/final state/input gradients/parameter gradients on touched rows, and + the high-level audit path with `--require-cuda-temporal-owner`. + +Strategic consequence: + +- The next implementation pass should not be another isolated buffer tweak. + It should start by designing the registered streaming-step physical + producer-consumer strategy and the exact row/liveness contract that lets + message, transition, readout, output, backward artifacts, and reducers share a + compact physical schedule while preserving compiler-owned semantics across + T=1/K=1, T>1/K=1, and T>1/K>1. + +### 2026-05-04 - T x K Streaming Constraint For T=1 Work + +Status: architectural constraint for the next physical strategy. + +The T=1 optimization must not become a one-step special case. The intended +physical model is that `T x K` is efficient streaming of the same T=1 step +kernel over time and inner steps. A good T=1 strategy should therefore be the +single-step body of the streaming executor, not a terminal-only shortcut that +has to be thrown away for sequence rows. + +Required design constraint: + +- Treat T=1 as the streaming physical step unit for T x K. +- The registered producer-consumer handoff across + `message -> transition/readout -> output` must be expressible inside the + temporal scan loop with carried state, reset policy, output-route policy, + artifact/tape policy, and reducer liveness owned by compiler rows. +- The same strategy record should explain: + `T=1,K=1`, `T>1,K=1`, and `T>1,K>1`; differences should be loop count, + checkpoint/recompute/tape policy, and output emission policy, not a separate + semantic or scheduler route. +- Do not encode a T=1-only terminal-output special path unless the compiler + legality rows explicitly prove it is a specialization of the streaming step + and the generic streaming path remains the owner for T/K rows. +- The T=1 physical strategy must not depend on benchmark-owned time chunking, + detach policy, private replay loops, or full `[T, cells, state]` + materialization. + +Implication for the next plan: + +- The "compact T=1 physical execution" plan should be written as a + "streaming-step physical execution" plan. 
+- The first acceptance row can remain the h32 100M B1024 T=1 forward steering + row, but the design packet must include the follow-up T/K guardrail: + run a small `T>1,K>1` high-level row through the same compiler-owned route and + prove no new route identity, benchmark policy, full time-surface + materialization, or Python replay owner appears. +- For large T, judge throughput primarily against the same-row Fabric T=1 + per-token line. It should stay flat or improve as T streams, subject to + expected output/loss emission and checkpoint policy. +- For K, compare to the matched current-code Fabric `T=1,K=1` training line + divided by K. K is repeated temporal work, not a separate model-capacity axis. + +### 2026-05-04 - Next Plan: Streaming-Step Physical Execution Strategy + +Status: active plan constraint before the next implementation pass. Any +existing work-in-progress that only describes a compact T=1 strategy must be +reframed as this streaming-step plan before more code changes continue. + +Goal: + +- Implement a compiler-owned streaming physical step strategy whose first + acceptance row is the h32 100M B1024 T=1 forward row, but whose legality and + ABI are explicitly the same route used by T>1/K=1 and T>1/K>1. + +April21 mechanism being semantically transferred: + +- Low-live-memory physical execution shape: + `message producer -> transition/readout consumers -> output route`, followed + by matching backward artifact/tape/reducer liveness. +- The transfer is semantic and structural only. The current compiler rows own + meaning; April21 is evidence for producer-consumer scheduling and low live + memory, not source to copy. + +Compiler products consumed: + +- Primitive rows and primitive-row fingerprints. +- Forward/reverse executor rows and native strategy rows. +- Tensor binding rows and program access rows. +- Memory liveness rows and runtime schedule rows. +- Physical strategy rows for the streaming-step route. +- Forward artifact route rows and artifact merge rows. +- Forward output route rows. +- Reverse artifact consumer route rows. +- Reverse parameter reducer route rows and transition parameter-gradient + binding rows. +- Reset rows, output-emission rows, checkpoint/recompute rows, and tape policy + rows from the scheduler/memory plan. + +Semantic fingerprints expected to stay stable: + +- Message primitive rows and tensor-role bindings. +- Transition primitive rows and tensor-role bindings. +- Readout primitive rows and tensor-role bindings. +- Output route semantics. +- Reset semantics. +- Artifact/tape role semantics. +- Reducer and parameter-gradient semantics. +- Expected semantic delta: `none`. + +Tensors that stop being fully materialized: + +- Recurrent message banks after transition/readout route consumers complete. +- Transition primitive intermediates that have a single next primitive consumer. +- Output message/output cells when output routes can consume them immediately. +- Reverse span outputs that can be routed directly into reducer/runtime/artifact + destinations. +- Full time-surface `[T, cells, state]` tensors unless a compiler artifact/tape, + output, checkpoint, or user-state row requires them. + +Consumer movement in the streaming step: + +- Transition consumers move to route-owned recurrent-message slices or chunks. +- Readout consumers move to route-owned message/output slices or chunks. +- Output routing moves to immediate route materialization or reduction. 
+- Backward consumers move to declared artifact/tape/reducer/runtime buffers + instead of rebuilding full forward stage boundaries. +- State carry updates remain the streaming loop carry, not a hidden global bank. + +T/K extension: + +- `T=1,K=1`: run one streaming physical step body. +- `T>1,K=1`: repeat the same body over outer time with carried recurrent/public + state, output emission controlled by output-route rows, and checkpoint/tape + policy controlled by memory/runtime rows. +- `T>1,K>1`: repeat the same body over outer time and inner microsteps; K only + changes step count and message-step index, not route identity. +- T and K must not introduce benchmark-owned time chunking, detach policy, + Python replay, or a second terminal-only path. + +Required guardrails: + +- First acceptance row may be h32 100M B1024 T=1 forward. +- Follow-up guardrail must be a small high-level T>1/K>1 row through the same + compiler-owned route. +- Prove no new route identity, benchmark-owned time chunking, detach policy, + Python replay, or full `[T, cells, state]` materialization appears. +- For T, compare primarily to the same-row Fabric T=1 per-token line. +- For K, compare to the matched Fabric T=1/K=1 training line divided by K. +- Keep/narrow/revert rule remains: keep only if the named physical owner moves; + narrow if the route is still generic but only one surface benefits; revert if + the patch is metadata-only or creates a terminal-only T=1 shortcut. + +### 2026-05-04 - Result: Streaming-Step Physical Strategy Contract + +Status: accepted as compiler/ABI contract groundwork only. This is not a +throughput win and does not claim T=1 closure. + +Implemented compiler products: + +- Added `physical_strategy_rows` as a concrete compiler launch table. +- The active executable row is `stage_materialized`, matching current runtime + behavior. +- The next strategy row is + `streaming_step_producer_consumer`, currently blocked with + `pending_registered_streaming_step_program_body`. +- The row records physical step count, inner-step count, output-boundary + policy, reset policy, required producer/consumer surfaces, consumed compiler + tables, executable status, and blocker code. +- The row exists for T=1/K=1 and T/K rows; it is not a terminal-only T=1 + shortcut. + +ABI/route wiring: + +- Forward registered program launch now requires `physical_strategy_rows`. +- Backward registered program launch now requires `physical_strategy_rows`. +- Forward and backward native validation check that: + - exactly one active executable physical strategy exists; + - the streaming-step producer-consumer strategy row exists; + - physical step count matches `memory_runtime_schedule_rows`; + - rows declare message, transition, readout, artifact, and reducer surfaces; + - rows consume primitive, executor, binding, memory/liveness, artifact-route, + output-route, and runtime-schedule tables. +- The forward artifact store records physical strategy fingerprints and rows so + backward rejects stale or mismatched strategy rows before launch. +- Runtime metadata now emits `flat_bucket_temporal_physical_strategy:*` and + `flat_bucket_temporal_physical_strategy_rows:*`. + +Semantic stability: + +- Primitive rows, tensor bindings, output routes, artifact routes, reset + meaning, tape meaning, and reducer semantics are unchanged. +- Expected semantic delta: `none`. 
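+
+For orientation only, the row shape and the fail-closed pre-launch check
+described above could look roughly like the sketch below. The dataclass form
+and field names are hypothetical; the actual schema is owned by the compiler
+tables listed earlier.
+
+```python
+# Hypothetical sketch of a physical strategy row and its pre-launch validation.
+# Field names and types are illustrative, not the real compiler schema.
+from dataclasses import dataclass
+
+@dataclass(frozen=True)
+class PhysicalStrategyRow:
+    strategy: str                    # "stage_materialized" or "streaming_step_producer_consumer"
+    executable: bool
+    blocker: str | None              # e.g. "pending_registered_streaming_step_program_body"
+    physical_step_count: int
+    inner_step_count: int
+    output_boundary_policy: str
+    reset_policy: str
+    required_surfaces: tuple[str, ...]   # message/transition/readout/artifact/reducer
+    consumed_tables: tuple[str, ...]
+
+def validate_physical_strategy_rows(
+    rows: list[PhysicalStrategyRow], schedule_step_count: int
+) -> PhysicalStrategyRow:
+    active = [row for row in rows if row.executable]
+    if len(active) != 1:
+        raise RuntimeError("exactly one executable physical strategy row is required")
+    if not any(r.strategy == "streaming_step_producer_consumer" for r in rows):
+        raise RuntimeError("streaming-step producer-consumer row is missing")
+    if active[0].physical_step_count != schedule_step_count:
+        raise RuntimeError("physical step count does not match runtime schedule rows")
+    return active[0]
+```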
+ +Validation: + +```bash +python -m py_compile \ + src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py \ + src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py \ + src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py \ + tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py +``` + +```bash +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names \ + tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes \ + tests/test_fabric_backend_plan.py::test_temporal_backward_validates_memory_artifact_plan_fingerprint \ + tests/test_fabric_backend_boundaries.py::test_registered_program_allocations_are_compiler_classified \ + --tb=short +# 4 passed +``` + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_stream_step_contract_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_stream_step_contract_20260504 \ +timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + --tb=short +# 1 passed +``` + +Next owner: + +- Implement the blocked `streaming_step_producer_consumer` program body as a + registered physical strategy over the existing rows. The first acceptance row + remains h32 100M B1024 T=1 forward, followed by the required small high-level + T>1/K>1 same-route guardrail. + +### 2026-05-04 - Implementation Pass: Activate Streaming-Step Forward Contract + +Status: in progress. This pass is throughput/native strategy work only; expected +semantic delta is `none`. + +Boundary manifest: + +- Lane: registered throughput/native strategy over existing compiler products. +- Declaration/spec owner: unchanged Fabric graph/cell/message/readout + declarations. +- Primitive rows: unchanged message, transition, readout, output-route, artifact + route, and reducer rows. +- Tensor/binding rows consumed: primitive rows, forward executor rows, + native-strategy rows, program tensor binding rows, forward program access rows, + runtime buffer rows, memory liveness rows, artifact route rows, output route + rows, runtime schedule rows, and physical strategy rows. +- Changed strategy/runtime rows: + `physical_strategy_rows` may select + `streaming_step_producer_consumer` when the registered forward body is legal: + no reverse artifacts, no final program tensor retention, reset absent, and + compiler-owned deferred-local step buffers are present. +- Old route being narrowed: + `stage_materialized` remains the explicit candidate/fallback for rows that + need artifact/tape/final-state retention. It is not the active strategy for + the first forward-only streaming row. +- Native body movement: + the active streaming forward body must release recurrent-message storage after + transition consumes it and before readout/output routing, instead of retaining + that full message bank through the rest of the physical step. +- First perf row: + h32 100M B1024 T=1/K=1 forward, sLSTM + Axon. +- Follow-up guardrail: + small high-level T>1/K>1 row through the same compiler-owned route, with no + benchmark-owned chunking, detach policy, Python replay, route identity change, + or full `[T, cells, state]` materialization. 
+- Keep/narrow/revert: + keep only if runtime metadata shows the streaming strategy active and the + measured forward owner physically moves; narrow or revert if this is + metadata-only, if parity fails, or if the implementation requires family, + hidden-size, benchmark, fixed-slot temporal scheduler, or April21 code-copy + assumptions. + +### 2026-05-04 - Result: Streaming-Step Forward Body, Narrow Accepted + +Status: accepted only as a narrow T=1/K=1 streaming producer-consumer body. +This is not throughput closure. + +What changed: + +- `physical_strategy_rows` now selects + `streaming_step_producer_consumer` for forward-only rows when all legality + conditions hold: no reverse artifacts, no final program tensor retention, no + reset tensors, and compiler-owned deferred-local step buffers. +- `stage_materialized` remains the active candidate for artifact/tape, + final-state, reset-present, and backward/training rows. +- The active streaming forward program consumes the same compiler rows and + releases the recurrent-message bank after transition consumes it. +- The keyless/direct readout shortcut was rejected for this pass. It is not + semantically equivalent on the current row set and must not run until a + separate compiler legality row proves that readout can skip recurrent K/V. + +Semantic stability: + +- Primitive rows, tensor bindings, output routes, artifact routes, reset + semantics, tape semantics, and reducer semantics are unchanged. +- Expected semantic delta: `none`. +- April21 was used only as physical-shape evidence; no April21 code or fixed + slot route was copied. + +Validation: + +```bash +python -m py_compile \ + src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py \ + src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py \ + src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py \ + tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py \ + tests/test_fabric_runtime.py +``` + +```bash +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names \ + tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes \ + tests/test_fabric_backend_plan.py::test_temporal_backward_validates_memory_artifact_plan_fingerprint \ + tests/test_fabric_backend_boundaries.py::test_registered_program_allocations_are_compiler_classified \ + --tb=short +# 4 passed +``` + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_streaming_step_body_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_streaming_step_body_20260504 \ +timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells \ + --tb=short +# 1 passed +``` + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_streaming_step_body_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_streaming_step_body_20260504 \ +timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + --tb=short +# 1 passed +``` + +Performance/owner row: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_streaming_step_body_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_streaming_step_body_20260504 \ +timeout 1200s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir 
tmp/fabric_audits/partials/2026-05-04/t1_streaming_step_body_forward_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward \ + --batches 1024 --seq-lens 1 --inner-steps 1 \ + --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 1 --require-cuda-temporal-owner +# status=ok, cases=2 +``` + +| Row | tok/s | peak GiB | Accepted movement | +| --- | ---: | ---: | --- | +| sLSTM forward | `11493.23` | `5.452` | `native_forward_after_streaming_message_release_local0` drops current allocation by `0.5625 GiB` | +| Axon forward | `6872.64` | `12.768` | `native_forward_after_streaming_message_release_local0` drops current allocation by `1.75 GiB` | + +Interpretation: + +- The streaming row is real and active on the T=1 forward acceptance row. +- The recurrent-message lifetime move is real, but the high-water owner remains + readout/recurrent K/V-after and readout projection. +- Axon peak is worse than the previous invalid keyless-readout probe because + that shortcut was removed for correctness. +- The next forward owner is therefore not another metadata/liveness shortcut: + it is a compiler-legal readout/message producer-consumer strategy that proves + which readout rows can consume projected value/KV streams without fully + materializing recurrent K/V-after. + +T/K guardrail status: + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_streaming_step_body_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_streaming_step_body_20260504 \ +timeout 900s uv run pytest -q \ + 'tests/test_fabric_runtime.py::test_fabric_cuda_single_flat_bucket_output_only_pooled_readout_uses_temporal_superop[False-2-slstm-spec0]' \ + --tb=short +# failed before the new strategy assertions: +# max abs output diff versus PyTorch reference: 0.072310075 +``` + +- The same custom probe with zero reset tensors forced + `stage_materialized` and produced identical output to the streaming run, but + both differed from the PyTorch reference. +- That means the failed T>1/K>1 guardrail is not specifically the new streaming + row; it is an existing K>1 high-level parity owner that must be closed before + claiming generic T/K streaming acceptance. +- Keep/narrow decision: keep the T=1 streaming-message lifetime move; keep T/K + streaming marked open; do not expand readout/KV elision until the K>1 parity + owner and readout legality row are explicit. + +Next owner: + +- Define and implement a compiler-owned readout/message producer-consumer + legality row. It must name the readout primitive rows, message/value + projection rows, tensor bindings, and output-route rows that permit recurrent + K/V-after elision. Without that legality row, the full recurrent K/V readout + path remains the only accepted path. + +### 2026-05-04 - Plan: Readout/Message Producer-Consumer Legality Row + +Status: next implementation pass. This is still throughput/native strategy +work only; expected semantic delta is `none`. + +Hypothesis: + +- The next forward high-water owner is the full recurrent K/V-after bank kept + alive for readout. The compiler needs an explicit readout/message + producer-consumer row before any strategy may elide or stream that bank. 
+- April21 is evidence for a low-live-memory physical step shape: + message -> transition/readout -> output with minimal materialization. The + transfer must be semantic, not code-copy: current primitive rows, tensor + bindings, output routes, artifact routes, reset policy, tape policy, and + reducer policy remain compiler-owned and stable. + +Compiler products consumed: + +- Primitive rows: existing message, transition, and readout rows. +- Executor rows: forward message/readout executor ids and bucket ordinals. +- Tensor bindings: program tensor bindings, forward executor binding rows, and + forward program access rows. +- Liveness/artifacts/output: memory liveness rows, runtime schedule rows, + runtime buffer rows, forward artifact route/merge rows, and forward output + route rows. +- Strategy rows: existing `physical_strategy_rows` plus a new + readout/message producer-consumer legality table. + +Stable fingerprints: + +- Semantic rows: unchanged. +- Tensor-role/binding rows: unchanged. +- Output/artifact/reducer/reset/tape semantics: unchanged. +- New row delta: strategy/legality metadata only, plus native launch ABI + validation that consumes it. + +Required row behavior: + +- Start with an active `materialized_recurrent_kv_after` row and a blocked + `stream_readout_from_message_projection` row. +- A future streaming readout row may become executable only when it proves the + exact producer route, consumer route, required tensor roles, output route, + reset/tape/artifact legality, and semantic equivalence. +- The keyless/direct readout shortcut remains invalid until this row proves + equivalence. No strategy may skip recurrent K/V just because the output route + is singleton or T=1. + +Tensors that must eventually stop being fully materialized: + +- `recurrent_k_after` and `recurrent_v_after` for readout-only consumers. +- Any readout-local message/output intermediates that have a single declared + local consumer. + +Where consumers move: + +- Readout message consumption moves from role-only access to the + producer-consumer row route. +- Output materialization remains owned by forward output route rows. +- Backward/reducer/tape consumers remain blocked unless artifact/tape rows prove + the needed storage or recompute route. + +T/K streaming constraint: + +- The row is part of the same streaming physical step body used for T=1/K=1, + T>1/K=1, and T>1/K>1. T=1 may specialize only by legality rows proving it is + a specialization of that streaming route. +- The follow-up T/K guardrail remains open until the existing K>1 parity owner + is fixed and the same compiler-owned route runs without Python replay, + benchmark chunking, detach policy, route identity changes, or full + `[T, cells, state]` materialization. + +Keep/narrow/revert: + +- Keep row/ABI work if malformed rows fail before launch and current forward + parity stays green. +- Keep an executable streaming-readout body only if output parity stays green + and allocator/native stage telemetry shows the recurrent K/V-after owner + actually moved. +- Narrow or revert if it is metadata-only, if it reintroduces keyless/direct + readout without proof, or if it keys on family, hidden size, benchmark row, + April21 code shape, or fixed-slot names outside compiler rows. + +### 2026-05-04 - Result: Readout/Message Producer-Consumer Row ABI + +Status: accepted as compiler legality/ABI groundwork. This is not yet the +executable streaming-readout body and does not claim a throughput win. 
+ +What changed: + +- Added compiler-owned `readout_message_producer_consumer_rows`. +- The active row is `materialized_recurrent_kv_after`. +- The future row `stream_readout_from_message_projection` is present but + blocked with `pending_registered_readout_streaming_body`. +- The row records producer message executor, consumer readout executor, + output-route row, required tensor-role mask, executable status, typed blocker, + and schema version. +- The registered forward program now receives and validates the row table before + launch. An active streaming-readout row fails closed until a real registered + body is implemented. +- Runtime metadata records both the active materialized rows and the template + streaming rows. + +Boundary result: + +- Primitive rows: unchanged. +- Tensor bindings: unchanged. +- Output/artifact/reducer/reset/tape semantics: unchanged. +- Native ABI delta: forward launch now consumes the producer-consumer row table. +- Deleted/rejected route: the invalid keyless/direct readout shortcut remains + inactive; this pass does not reintroduce it. + +Validation: + +```bash +python -m py_compile \ + src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py \ + src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py \ + src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py \ + src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py \ + tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py +# passed +``` + +```bash +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_readout_message_producer_consumer_rows_are_compiler_owned_legality_rows \ + tests/test_fabric_backend_plan.py::test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names \ + tests/test_fabric_backend_boundaries.py::test_registered_program_allocations_are_compiler_classified \ + --tb=short +# 3 passed +``` + +```bash +uv run pytest -q \ + tests/test_fabric_backend_boundaries.py \ + tests/test_fabric_backend_plan.py \ + tests/test_fabric_execution_imports.py \ + --tb=short +# 176 passed +``` + +```bash +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_forward_output_routes_are_compiler_owned_rows \ + tests/test_fabric_backend_plan.py::test_forward_multi_output_concat_routes_are_compiler_owned_rows \ + tests/test_fabric_backend_plan.py::test_forward_output_routes_reject_non_concat_offsets_before_launch \ + tests/test_fabric_backend_plan.py::test_forward_multi_output_routes_require_explicit_merge_semantics \ + --tb=short +# 4 passed +``` + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_readout_pc_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_readout_pc_20260504 \ +timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells \ + --tb=short +# 1 passed +``` + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_readout_pc_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_readout_pc_20260504 \ +timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + --tb=short +# 1 passed +``` + +Next owner: + +- Implement the executable `stream_readout_from_message_projection` body only + after the row can prove exact equivalence for the current message/readout + bindings. 
The body must consume recurrent hidden and recurrent K/V projection + bindings through compiler rows, move the readout consumer into the streaming + step, and stop fully materializing recurrent K/V-after for readout-only + consumers. + +### 2026-05-04 - Plan: Executable Streaming Readout/Message Body + +Status: implementation pass. This is throughput/native-strategy work only. +Expected semantic delta is `none`. + +Boundary manifest: + +- Surface: forward message/readout producer-consumer execution. +- Lane: throughput strategy plus native strategy implementation. +- Declaration/spec owner: unchanged message/readout declarations. +- Primitive rows: unchanged. +- Tensor/parameter bindings: unchanged. +- Changed strategy/runtime rows: + `readout_message_producer_consumer_rows` may select + `stream_readout_from_message_projection` for legal forward-only rows. +- Bindings/routes/liveness consumed directly: + primitive rows, forward executor rows, native strategy rows, forward program + access rows, forward executor binding rows, runtime buffer rows, + memory-liveness rows, forward output route rows, and + readout/message producer-consumer rows. +- Old route being bypassed for the accepted row: + full `recurrent_k_after`/`recurrent_v_after` materialization solely for + readout consumption. +- Old route retained: + materialized recurrent K/V-after remains active for artifact/tape, reset, + materialized-final-state, backward/training, and unsupported binding rows. + +Hypothesis: + +- April21 mechanism being transferred: low-live-memory physical step where + readout consumes message-producer bindings inside the same streaming step + instead of forcing a full recurrent K/V-after bank. +- Rows/fingerprints expected to stay stable: + primitive rows, tensor-role rows, output route rows, artifact route rows, and + semantic executor ids. +- Native body: + compute readout message from `input_k`, `input_v`, recurrent hidden, + recurrent-value projection binding, fixed sender key/context bindings, and + readout output query. The recurrent value projection may use bounded + compiler-owned batch chunks, but must not materialize full recurrent K/V-after + for readout-only consumers. +- Tensors expected to stop being fully materialized on the accepted row: + `recurrent_k_after` and `recurrent_v_after`. +- Consumers move: + readout message consumes the producer route directly; output projection and + output route materialization stay owned by readout/output route rows. +- T/K extension: + this is the per-physical-step body. T=1/K=1 is only the first acceptance row; + T>1/K=1 and T>1/K>1 must use the same compiler-owned route once the existing + K>1 parity owner is fixed. +- Keep/narrow/revert: + keep only if parity stays green and native/allocator telemetry shows the + recurrent K/V-after owner physically moves. Narrow/revert if the change is + metadata-only, reintroduces keyless/direct readout without proof, uses family + or benchmark selectors, changes semantic rows, or regresses output parity. + +### 2026-05-04 - Result: Streaming Readout/Message Body Probe + +Status: correctness accepted, throughput not accepted. The compiler-owned +streaming row now has a registered native body and can run through the active +forward program without changing semantic rows, but the representative T=1 +forward row did not improve throughput. 
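+
+Before the detailed deltas, a rough sketch of the chunk-local readout shape
+that the plan above describes and that this probe's native body follows.
+Shapes, the projection form, and `attention_fn` are assumptions for
+exposition, not the registered partitioned-attention primitive.
+
+```python
+# Illustrative-only sketch: readout consumes chunk-local recurrent K/V material,
+# so the full recurrent K/V-after bank is never live at once for readout-only
+# consumers. Names and shapes are assumptions.
+import torch
+
+def chunk_local_readout(
+    recurrent_hidden: torch.Tensor,   # [batch, cells, hidden]
+    k_proj: torch.Tensor,             # [hidden, hidden] recurrent key projection binding
+    v_proj: torch.Tensor,             # [hidden, hidden] recurrent value projection binding
+    readout_query: torch.Tensor,      # [batch, ...] readout output query
+    attention_fn,                     # generic partitioned-attention primitive (assumed)
+    out: torch.Tensor,                # [batch, ...] preallocated readout output
+    chunk_rows: int,
+) -> None:
+    batch = recurrent_hidden.shape[0]
+    for start in range(0, batch, chunk_rows):
+        stop = min(start + chunk_rows, batch)
+        hidden_chunk = recurrent_hidden[start:stop]
+        k_chunk = hidden_chunk @ k_proj   # chunk-local K material
+        v_chunk = hidden_chunk @ v_proj   # chunk-local V material
+        out[start:stop] = attention_fn(readout_query[start:stop], k_chunk, v_chunk)
+        # k_chunk / v_chunk fall out of scope here; only one chunk is ever live.
+```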
+ +What changed: + +- `stream_readout_from_message_projection` can be selected from + `readout_message_producer_consumer_rows` for legal no-artifact/no-reset/ + no-final-state forward rows. +- The selected body consumes producer/consumer executor rows, access rows, + binding rows, runtime-buffer rows, output-route rows, and memory-liveness + rows. +- The body computes readout output through chunk-local recurrent K/V material + and the same generic partitioned-attention primitive used by the materialized + route. This avoids relying on the invalid keyless/direct readout shortcut. +- Artifact/tape/reset/final-state/training rows still use the materialized + route. +- A custom fixed-slot stream-readout warp body was rejected and deleted after + parity investigation; the retained native body is the generic chunk-local + implementation. + +Validation: + +```bash +uv run pytest -q \ + tests/test_fabric_backend_boundaries.py \ + tests/test_fabric_backend_plan.py \ + tests/test_fabric_execution_imports.py \ + --tb=short +# 176 passed +``` + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_stream_readout_20260504f \ +TRITON_CACHE_DIR=/tmp/cortical_triton_stream_readout_20260504f \ +timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + --tb=short +# 2 passed +``` + +Representative audit: + +```text +tmp/fabric_audits/partials/2026-05-04/t1_stream_readout_body_h32_100m_b1024 +``` + +Results: + +- sLSTM forward: `11026.44 tok/s`, `5.260 GiB`. +- Axon forward: `6555.22 tok/s`, `9.501 GiB`. +- Previous current-code line before this probe was about `11640.60 tok/s`, + `5.260 GiB` for sLSTM and `7128.07 tok/s`, `9.534 GiB` for Axon. + +Decision: + +- Keep the compiler row, native ABI, route validation, and parity coverage. +- Do not count this as throughput closure. It preserves the low-live-memory + shape but is slower than the previous current-code forward line. +- Do not pursue more one-buffer readout tweaks from this body until the next + owner is framed as a broader producer-consumer physical strategy. + +Next owner: + +- The peak owner remains the message/transition physical shape, especially + full `native_forward_message_*` weighted-value/projected-message boundaries + and transition output boundaries. +- The next plan should reduce stage boundaries across message -> transition -> + readout using compiler rows, likely by grouping/fusing GEMM/BMM producer- + consumer chains rather than adding another readout-local shortcut. + +### 2026-05-04 - Plan: Message -> Transition Producer-Consumer Row + +Status: implementation pass. This is throughput/native-strategy work only. +Expected semantic delta is `none`. + +Boundary manifest: + +- Surface: forward message producer consumed by transition aggregate input. +- Lane: throughput strategy plus native strategy implementation. +- Declaration/spec owner: unchanged graph, message, transition, readout, and + output declarations. +- Primitive rows expected stable: message primitive rows, transition primitive + rows, readout primitive rows, and output route rows. +- Tensor/parameter bindings expected stable: message/readout/transition + parameter bindings and transition aggregate-input program-access binding. 
+- Changed strategy/runtime rows: + `message_transition_producer_consumer_rows` may select + `stream_message_to_transition_input` for legal forward-only rows. +- Bindings/routes/liveness consumed directly: + primitive rows, forward executor rows, native strategy rows, forward program + access rows, forward executor binding rows, runtime-buffer rows, + memory-liveness rows, physical strategy rows, forward output route rows, and + message/transition producer-consumer rows. +- Old route being bypassed for the accepted row: + transition aggregate input must not be produced by role-only full + `recurrent_msg.slice(...).contiguous()` when the compiler row proves a + singleton producer/consumer route. +- Old route retained: + materialized recurrent-message routing remains active for reverse artifacts, + reset-present rows, materialized final state, backward/training, multi- + transition rows without explicit merge semantics, and unsupported binding + rows. + +Hypothesis: + +- April21 mechanism being transferred: low-live-memory physical step in which + message output is consumed by transition inside the same physical step rather + than through a separate transition input copy. +- Current compiler products expressing it: + `forward_program_access_rows` identify the transition aggregate binding; + `runtime_buffer_rows` and `memory_liveness_rows` own the step-local recurrent + message buffer; `physical_strategy_rows` gates streaming-step legality; + `message_transition_producer_consumer_rows` owns the producer/consumer edge. +- Rows/fingerprints expected to stay stable: + primitive rows, tensor-role rows, forward output route rows, forward artifact + route rows, and semantic executor ids. +- First executable row: + legal singleton message -> singleton transition row, no artifacts, no resets, + no materialized final state, forward-only T=1. +- Tensors expected to stop being fully duplicated: + transition aggregate input copy from `recurrent_msg.slice(...).contiguous()`. + Full recurrent-message materialization remains open until the next native + row can produce transition chunks directly without a full message bank. +- Consumers move: + transition aggregate input binding consumes the message producer buffer + directly; readout remains route-owned and may use the existing streaming + readout body where legal. +- T/K extension: + this is a per-physical-step route. T=1/K=1 is the first acceptance row; + T>1/K=1 and T>1/K>1 must use the same compiler-owned route once the existing + K guardrail is rerun. +- Keep/narrow/revert: + keep if parity stays green and the transition aggregate copy path is absent + from the active route with stable semantic rows. Narrow/revert if the change + is metadata-only, changes primitive rows, uses family/shape/benchmark + selectors, or fails to move named memory/time owner evidence. + +### 2026-05-04 - Result: Message -> Transition Producer-Consumer Row + +Status: kept narrowly as compiler routing infrastructure, not accepted as a +T=1 throughput owner closure. + +Implemented: + +- Added compiler-owned `message_transition_producer_consumer_rows` with schema + version, strategy opcode, status opcode, producer message executor row, + consumer transition executor row, transition aggregate access opcode, required + role mask, and typed blocker. +- Threaded the rows through `RegisteredTemporalExecutorProgram`, fused CUDA + launch planning, the Python extension wrapper, pybind binding, and the + registered forward program. 
+- The registered forward program validates the rows before launch and can bind + the transition aggregate input directly to the message producer buffer when + rows prove a legal singleton `stream_message_to_transition_input` route. +- Mixed-pop/multi-transition rows remain materialized and explicitly block the + direct route with `multiple_transition_consumers_need_merge_rows` until + merge/chunk semantics are compiler-owned. + +Focused checks: + +```bash +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_message_transition_producer_consumer_rows_are_compiler_owned_legality_rows \ + tests/test_fabric_backend_plan.py::test_readout_message_producer_consumer_rows_are_compiler_owned_legality_rows \ + tests/test_fabric_backend_plan.py::test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names \ + tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows \ + --tb=short +# 4 passed + +uv run pytest -q \ + tests/test_fabric_backend_plan.py \ + tests/test_fabric_backend_boundaries.py \ + tests/test_fabric_execution_imports.py \ + --tb=short +# 177 passed + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_msg_transition_pc_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_msg_transition_pc_20260504 \ +timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells \ + --tb=short +# 1 passed +``` + +Representative forward audit: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_msg_transition_pc_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_msg_transition_pc_20260504 \ +timeout 1200s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-04/t1_message_transition_pc_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward \ + --batches 1024 --seq-lens 1 --inner-steps 1 \ + --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 1 --require-cuda-temporal-owner +# status=ok, cases=2 +``` + +| Row | tok/s | peak GiB | current peak owner | accepted movement | +| --- | ---: | ---: | --- | --- | +| sLSTM forward | `10808.82` | `5.260` | `native_forward_after_transition_local0` | none; transition aggregate input binding is compiler-routed but peak unchanged | +| Axon forward | `6636.77` | `9.501` | `native_forward_message_after_normalize_local0` | none; message-stage high-water remains dominant | + +Interpretation: + +- The row is real: the active launch contract requires + `message_transition_producer_consumer_rows`, the high-level CUDA path passes + parity, and the audit reports registered fused forward ownership. +- The route does not close the high-impact owner. For the singleton T=1 row, + the previous full-span `recurrent_msg.slice(...).contiguous()` path was not + the memory/time owner; the remaining peak is transition output materialization + for sLSTM and message weighted/projected/normalized intermediates for Axon. +- Keep the compiler row because it is the legality surface needed for the next + real producer-consumer strategy. Do not count this as a throughput win. 
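+
+For reference, the singleton-route legality gate described under "Implemented"
+above behaves roughly like the sketch below. Row fields and blocker strings
+follow this log; the dataclass form and helper function are hypothetical, not
+the compiler implementation.
+
+```python
+# Hypothetical sketch of the singleton message -> transition route selection.
+# Anything that is not a provably legal singleton route stays materialized.
+from dataclasses import dataclass
+
+@dataclass(frozen=True)
+class MessageTransitionRow:
+    strategy: str                       # e.g. "stream_message_to_transition_input"
+    producer_message_executor: int
+    consumer_transition_executor: int
+    blocker: str | None                 # e.g. "multiple_transition_consumers_need_merge_rows"
+
+def select_direct_transition_input(
+    rows: list[MessageTransitionRow],
+) -> MessageTransitionRow | None:
+    streaming = [r for r in rows if r.strategy == "stream_message_to_transition_input"]
+    if len(streaming) != 1 or streaming[0].blocker is not None:
+        return None                     # fall back to the materialized route
+    if len({r.consumer_transition_executor for r in rows}) != 1:
+        return None                     # mixed-pop rows need explicit merge rows first
+    return streaming[0]
+```
+
+The design point is that the gate fails back to the materialized route rather
+than failing the launch, because that route remains the accepted owner for
+multi-consumer, artifact, reset, and training rows.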
+ +Next owner: + +- Implement a true producer-consumer native strategy that moves consumers into + the message/transition step body instead of only rebinding a full message + tensor: + - for sLSTM, reduce or stream transition forward linear/matmul/norm/state + outputs that are currently fully materialized; + - for Axon, reduce the weighted-value/projected/normalized message + intermediates before transition/readout. +- The next plan must name which compiler rows consume the chunk/stream route + and which full tensors stop being materialized. It should not add another + binding-only route unless it moves allocator telemetry. + +### 2026-05-04 - Deep Dive: Remaining T=1 Throughput Gap + +Status: analysis checkpoint only. No optimization code was changed in this +pass. + +Current target: + +- April 21 `h32_t1_bxparams` summary floor remains `58732.71 tok/s`, + `2.07 GiB`. +- That target covers sLSTM + Axon, `100M/500M/1B`, forward + training, + `B=1024/16384`. The current rows below are steering rows, not closure. + +Latest current-code evidence: + +- Single-pop forward/training owner table: + `tmp/fabric_audits/partials/2026-05-04/t1_current_deepdive_after_message_pc_h32_100m_b1024` +- Latest message -> transition row check: + `tmp/fabric_audits/partials/2026-05-04/t1_message_transition_pc_h32_100m_b1024` +- Mixed-pop forward steering row: + `tmp/fabric_audits/partials/2026-05-04/t1_mixed_forward_analysis_h32_100m_b1024` + +Current owner table: + +| Row | tok/s | vs Apr21 | slowdown | Peak GiB | Memory vs Apr21 | Current owner | +| --- | ---: | ---: | ---: | ---: | ---: | --- | +| sLSTM single forward | `10808.82` | `18.40%` | `5.43x` | `5.260` | `2.54x` | `native_forward_after_transition_local0` | +| Axon single forward | `6636.77` | `11.30%` | `8.85x` | `9.501` | `4.59x` | `native_forward_message_after_normalize_local0` | +| sLSTM mixed forward | `6448.85` | `10.98%` | `9.11x` | `9.847` | `4.76x` | `native_forward_after_transition_local0` | +| Axon mixed forward | `6471.34` | `11.02%` | `9.08x` | `10.879` | `5.26x` | `native_forward_after_transition_local0` | +| sLSTM single train | `13.37` | `0.0228%` | `4393.77x` | `25.049` | `12.10x` | `native_after_recurrent_message_local0` | +| Axon single train | `4.01` | `0.0068%` | `14657.62x` | `81.319` | `39.28x` | `native_after_recurrent_message_local0` | + +Diagnosis: + +- Forward remains the right priority. Training is catastrophically behind, but + the forward physical step is still materially different from April 21 and + training will inherit that over-materialization. +- Binding-only producer-consumer rows are no longer enough. The latest + `message_transition_producer_consumer_rows` route is real and compiler-owned, + but peak did not move because the full producer tensors are still + materialized. +- For sLSTM, the single-pop forward peak is transition output materialization. + The runtime roles behind this are the transition forward linear/matmul/norm/ + state outputs, especially `transition_forward_linear_output`, + `transition_forward_matmul_output`, `transition_forward_norm_output`, and + `transition_forward_state_output`. +- For Axon single-pop forward, the largest live stage is still the message + weighted/projected/normalized path. The message body still allocates + `weighted_value`, projects it, normalizes it, and then hands a full message + bank to later consumers. 
+- Mixed-pop forward is now useful evidence: it does not expose a separate + semantic route, but it raises peak to roughly `9.8-10.9 GiB` and blocks the + direct message -> transition route with four producer-consumer rows until + merge/chunk semantics are explicit. Mixed-pop closure therefore depends on + the same generic chunk/stream producer-consumer strategy, not a separate + single-pop shortcut. + +Highest-impact remaining work: + +1. Native forward producer-consumer body across message -> transition. + Move transition consumers into the same step/chunk that produces message + data, so transition forward outputs do not require full-span runtime buffers + unless artifact/tape/reset/final-state rows demand them. +2. Message weighted/projected/normalized streaming. + Convert the fixed-slot context message body from + weighted-value -> projected-message -> normalize -> full message bank into + chunked/grouped GEMM/BMM producer-consumer execution selected by compiler + rows. This is the Axon forward owner and also helps mixed-pop. +3. Transition output liveness/reducer contract. + Decide by compiler rows which transition outputs are public state, tape, + reducer input, or local-only workspace. Local-only outputs should be consumed + or released inside the native body, not held as full runtime buffers. +4. Mixed-pop generic route. + Extend the same route through merge/chunk rows for multiple transition + consumers. Do not add a single-pop-only fast path. +5. Training/backward after forward moves. + Once the forward live set moves, revisit reverse artifacts, tape, + recurrent-message gradients, and parameter reducers. Current training + numbers are too far behind for closure but should stay a parity/liveness + guardrail until the forward physical step is compact. +6. Full closure matrix. + After the steering row moves, rerun April 21 coverage: `100M/500M/1B`, + `B=1024/16384`, sLSTM + Axon, forward + training, mixed-pop T=1, + small-hidden h4/h8/h16, high-batch small-param, reset-present, + materialized/no-final-state, and T/K streaming guardrails. + +### 2026-05-04 - Plan: Streaming Forward Producer-Consumer Step Body + +Status: plan only. Do not start another narrow liveness patch before this +packet is implemented or explicitly rejected. + +Hypothesis: + +- The April 21 mechanism to semantically transfer is a compact physical step + body: message production, transition consumption, readout/output production, + and local temporary release happen inside one streaming step unit instead of + across fully materialized compiler stage boundaries. +- The current compiler products are sufficient to express the first version: + primitive rows, forward executor rows, tensor binding rows, program access + rows, native callable output rows, memory liveness rows, output route rows, + artifact route rows, and `message_transition_producer_consumer_rows`. +- Expected semantic delta: none. Primitive row fingerprints and tensor-role + fingerprints must stay stable. Only strategy/runtime/liveness behavior may + change. + +Lane: + +- Throughput strategy plus native implementation. +- Not a semantic extension. +- Not a benchmark-side T=1 shortcut. +- Not an April 21 code copy. + +Rows and bindings consumed directly: + +- `TemporalPrimitiveTablePlan` rows for message, transition, readout, and + output surfaces. +- Registered forward executor rows for message, transition, and readout. +- Program access rows for transition aggregate input and readout/message input. 
+- Tensor binding rows for message value/projection/normalization inputs, + transition projection/state inputs, readout projection inputs, and output + routes. +- Native callable output rows for local-only transition and message outputs. +- Memory/liveness rows for workspace, runtime buffers, artifact/tape retention, + and local release. +- `forward_output_route_rows` for output materialization. +- `message_transition_producer_consumer_rows` for producer-consumer legality. + +First implementation target: + +1. Add a registered streaming forward strategy row that owns one physical step + body across message -> transition -> readout. +2. Keep the existing compiler rows stable, but add strategy metadata that says + the message producer may stream directly into transition aggregate input + chunks when liveness rows do not require a full recurrent-message artifact. +3. In the native forward body, process a batch chunk through: + - weighted message/value production, + - projection/normalization into transition input or a short-lived workspace, + - transition row-group computation, + - readout/output route production when required, + - immediate release/reuse of local-only workspaces. +4. Do not add family, benchmark, hidden-size, single-pop, or graph-constructor + selectors. Legality must come from rows/bindings/liveness only. + +Tensors expected to stop being fully materialized for forward-only T=1 rows: + +- Full-span recurrent message where no artifact/tape/output route requires it. +- Full-span transition aggregate input where it is only consumed by the + transition row group. +- Local-only transition forward linear/matmul/norm/state intermediates. +- Message weighted-value/projected/normalized intermediates that can be + consumed inside the same chunk. + +Where consumers move: + +- Transition consumes streamed message projection/normalization chunks instead + of reading a full recurrent-message bank. +- Readout consumes routed output chunks from the same step body where legal. +- Artifact/tape routes remain the only reason to materialize a full logical + tensor. + +Smallest representative probe: + +- High-level Fabric row, not direct helper call: + `h32 100M B1024 T=1 K=1 forward`, single-pop sLSTM and Axon. +- Artifact path: + `tmp/fabric_audits/partials/2026-05-04/t1_streaming_forward_body_h32_100m_b1024`. +- Expected owner movement: + - sLSTM peak must move below `native_forward_after_transition_local0` and + reduce below the current `5.260 GiB`. + - Axon peak must move below `native_forward_message_after_normalize_local0` + and reduce below the current `9.501 GiB`. + - Throughput should improve materially; if memory moves without throughput, + collect launch/native timing before accepting. + +Keep/narrow/revert rule: + +- Keep only if the named native owner moves in allocator telemetry, launch + shape, storage lifetime, or tok/s on the representative high-level row. +- Narrow if it is correct for forward-only but extends artifact/tape lifetime or + increases training reserved/unclassified memory. +- Revert if it is only metadata, only changes labels, requires hidden route + identity, or fails to move the owner. + +Follow-up representative rows: + +- Mixed-pop `h32 100M B1024 T=1 K=1 forward` through the same strategy id. +- sLSTM + Axon training guardrail after forward owner moves, using the same + compiler route and no Python replay/compat fallback. +- April 21 T=1 matrix only after the steering rows move. 
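+
+As a reference for the step-body shape named in the first implementation
+target, the minimal Python/torch sketch below walks one batch chunk through
+message -> transition -> readout inside a single unit and releases chunk-local
+workspaces before the next chunk. All names (the `weights` keys,
+`streaming_forward_step`) are hypothetical illustrations, and the math is
+placeholder linear algebra, not the registered primitive rows.
+
+```python
+import torch
+
+
+def streaming_forward_step(x_chunks, weights):
+    """Sketch only: one streaming step body over bounded batch chunks."""
+    outputs = []
+    for xb in x_chunks:                                   # one bounded batch chunk
+        weighted = xb @ weights["message_value"]          # weighted message/value production
+        proj = weighted @ weights["message_proj"]         # projection into a short-lived workspace
+        msg = torch.nn.functional.layer_norm(proj, proj.shape[-1:])  # normalize chunk-locally
+        hidden = torch.tanh(msg @ weights["transition"])  # transition row-group on this chunk
+        outputs.append(hidden @ weights["readout"])       # routed output for this chunk only
+        del weighted, proj, msg, hidden                   # local-only workspaces end here
+    return torch.cat(outputs, dim=0)
+
+
+# Toy usage with made-up shapes; only the routed output stays live across chunks.
+H = 32
+w = {k: torch.randn(H, H) for k in ("message_value", "message_proj", "transition", "readout")}
+y = streaming_forward_step(torch.randn(1024, H).chunk(4), w)
+```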
+ +T/K streaming guardrail: + +- The strategy must be the physical step unit for `T x K`, not a terminal-only + T=1 shortcut. +- Follow-up small high-level row: `T>1,K>1` through the same route identity. +- Prove no benchmark-owned time chunking, detach policy, Python replay, or full + `[T, cells, state]` materialization appears. +- For `T`, compare primarily against the same-row Fabric T=1 per-token line. +- For `K`, compare against matched Fabric `T=1,K=1` training divided by `K`. + +Parity and boundary gates: + +- Targeted parity: output, final state, input/carry gradients, and parameter + gradients for affected rows before any closure claim. +- Source/boundary grep must show no new family selectors, benchmark selectors, + fixed-slot scheduler formulas, or hidden compatibility path. +- Active metadata must report registered compiler-owned strategy rows, + consumed binding/liveness/output/artifact rows, and no hidden fallback. + +### 2026-05-04 - Implementation Pass: Gated Final-Step Local Transition Outputs + +Status: implementation in progress. + +Boundary manifest: + +- Surface: registered temporal fused forward program. +- Lane: throughput strategy/native implementation. +- Semantic delta: none. +- Declaration/spec owner: unchanged message, transition, readout, output + declarations. +- Primitive rows: unchanged. The gated transition row group is still + `linear -> linear -> matmul -> gated_logspace_recurrence -> norm_or_identity`. +- Tensor bindings/routes consumed: existing program tensor bindings, transition + state-carry rows, native callable output rows, memory runtime-buffer rows, + physical strategy rows, and `message_transition_producer_consumer_rows`. +- Memory/liveness owner: compiler runtime-buffer plan with `deferred_local` + transition outputs. +- Old route deleted/fail-closed: none in this pass; the materialized transition + program remains for artifacts, final program tensors, recompute, and non-final + streaming steps. + +Change: + +- Reused the existing compiler-owned deferred-local runtime-buffer ABI for the + gated transition row group. +- `transition_forward_linear_output` for gated aggregate input projection can + now remain deferred local and be produced per batch chunk. +- On the terminal local forward step, required gated `next_y` can remain + deferred local: the row group computes `next_y` into a chunk workspace, + immediately runs the compiler-selected norm output for that chunk, and does + not materialize private state carry that no downstream row consumes. +- The final-step state-carry copy is skipped only under the registered + streaming-step physical strategy, no reverse artifacts, no final program + tensors, and the last physical step. This keeps the same step body valid for + `T>1/K>1`: non-final steps still materialize state carry through the existing + compiler carry rows. + +Expected owner movement: + +- sLSTM forward should reduce `native_forward_after_transition_local0` by + removing the full transition input and terminal private `next_y` live buffers. +- Axon may not move much in this pass because its current owner is still the + message weighted/projected/normalized path. + +Keep/narrow/revert rule: + +- Keep if the high-level sLSTM T=1 forward row moves the transition owner in + allocated memory or tok/s without parity regressions. +- Narrow if forward moves but a training/T>1 guardrail exposes a private-state + lifetime issue. 
+- Revert if the patch only changes labels, fails CUDA parity, introduces hidden + route identity, or increases unclassified/reserved memory. + +Focused checks: + +```bash +uv run pytest -q \ + tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows \ + tests/test_fabric_backend_plan.py::test_message_transition_producer_consumer_rows_are_compiler_owned_legality_rows \ + tests/test_fabric_backend_boundaries.py::test_registered_program_allocations_are_compiler_classified \ + --tb=short +# 3 passed + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_streaming_gated_local_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_streaming_gated_local_20260504 \ +timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells \ + --tb=short +# 1 passed + +uv run pytest -q \ + tests/test_fabric_backend_plan.py \ + tests/test_fabric_backend_boundaries.py \ + tests/test_fabric_execution_imports.py \ + --tb=short +# 177 passed +``` + +Representative T=1 forward audit: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_streaming_gated_local_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_streaming_gated_local_20260504 \ +timeout 1200s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-04/t1_streaming_forward_body_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward \ + --batches 1024 --seq-lens 1 --inner-steps 1 \ + --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 1 --require-cuda-temporal-owner +# status=ok, cases=2 +``` + +| Row | tok/s | peak GiB | current owner | result | +| --- | ---: | ---: | --- | --- | +| sLSTM single forward | `10271.49` | `4.394` | `native_forward_message_after_normalize_local0` | accepted memory movement; transition owner moved | +| Axon single forward | `6628.35` | `9.501` | `native_forward_message_after_normalize_local0` | unchanged; message owner remains | + +Movement vs previous current-code steering row: + +- sLSTM peak moved from `5.260 GiB` to `4.394 GiB`. +- sLSTM `native_forward_after_transition_local0` moved from about `4.139 GiB` + to `3.577 GiB`. +- The sLSTM current owner moved from transition to message normalization. +- Axon stayed at `9.501 GiB`; this confirms Axon needs message + weighted/projected/normalized streaming next. 
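+
+For reference, the terminal skip condition described in the change list above
+reduces to a small predicate. This is an illustrative sketch with hypothetical
+argument names, not the actual check in the registered native forward body.
+
+```python
+def skip_terminal_state_carry(strategy, step_index, num_steps,
+                              return_reverse_artifacts, wants_final_program_tensors):
+    """All four conditions must hold before the state-carry copy is skipped."""
+    return (
+        strategy == "streaming_step_producer_consumer"  # registered streaming-step strategy
+        and not return_reverse_artifacts                # no reverse artifacts collected
+        and not wants_final_program_tensors             # no final program tensors retained
+        and step_index == num_steps - 1                 # terminal physical step only
+    )
+```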
+ +Mixed-pop forward audit: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_streaming_gated_local_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_streaming_gated_local_20260504 \ +timeout 1200s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-04/t1_streaming_forward_body_mixed_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward \ + --batches 1024 --seq-lens 1 --inner-steps 1 \ + --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes mixed \ + --reset-modes absent --warmup 1 --iterations 1 --require-cuda-temporal-owner +# status=ok, cases=2 +``` + +| Row | tok/s | peak GiB | current owner | result | +| --- | ---: | ---: | --- | --- | +| sLSTM mixed forward | `6456.63` | `9.121` | `native_forward_after_transition_local0` | memory moved down from `9.847 GiB`; transition still owner | +| Axon mixed forward | `6402.11` | `10.152` | `native_forward_after_transition_local0` | memory moved down from `10.879 GiB`; transition still owner | + +T/K streaming guardrail: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_streaming_gated_local_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_streaming_gated_local_20260504 \ +timeout 1200s uv run python -m benchmarks.fabric.run_audit \ + --plan tk-scaling \ + --out-dir tmp/fabric_audits/partials/2026-05-04/tk_streaming_forward_body_guardrail_h32_100m_b128_t2_k2 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward \ + --batches 128 --seq-lens 2 --inner-steps 2 \ + --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 1 --require-cuda-temporal-owner +# status=ok, cases=2 +``` + +| Row | tok/s | peak GiB | current owner | +| --- | ---: | ---: | --- | +| sLSTM `T=2,K=2` forward | `2235.06` | `2.914` | `native_forward_message_after_normalize_local3` | +| Axon `T=2,K=2` forward | `2324.76` | `11.702` | `native_forward_message_after_normalize_local3` | + +Training guardrail observation: + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_streaming_gated_local_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_streaming_gated_local_20260504 \ +timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads \ + --tb=short +# 4 failed on forward output parity vs PyTorch reference +``` + +- This is recorded as an existing training/parity blocker, not as accepted + throughput evidence. +- The new terminal-local transition route is not active in this test: + autograd/training collects reverse artifacts, so `return_reverse_artifacts` + keeps `terminal_local_transition_state` false and the materialized transition + path is used. +- Do not expand training optimization from this pass. Keep training as a + parity/liveness guardrail until the forward physical owner moves further. + +Interpretation: + +- The patch is accepted as a real compiler-owned liveness movement for gated + transition forward. 
It is not T=1 throughput closure. +- The next forward owner is no longer the gated terminal transition buffer for + sLSTM. It is the shared message weighted/projected/normalized live set, which + also remains Axon's primary owner. +- Mixed-pop still needs generic merge/chunk producer-consumer handling because + multiple transition consumer rows keep transition as the mixed owner. + +### 2026-05-04 - Deep Dive: Remaining T=1 Gap After Gated Local Transition + +Status: analysis checkpoint only. No optimization code was changed in this +pass. + +Baseline: + +- April 21 `h32_t1_bxparams` summary floor remains `58732.71 tok/s`, + `2.07 GiB`. +- The April 21 row covers sLSTM + Axon, `100M/500M/1B`, forward + training, + `B=1024/16384`. The current rows below are steering rows, not closure. +- `origin/main` is treated as mechanism evidence only. Do not copy old + execution code; transfer the physical idea as registered compiler-owned + strategies over current rows/bindings/liveness routes. + +Latest accepted current-code evidence: + +- Single-pop forward: + `tmp/fabric_audits/partials/2026-05-04/t1_streaming_forward_body_h32_100m_b1024` +- Mixed-pop forward: + `tmp/fabric_audits/partials/2026-05-04/t1_streaming_forward_body_mixed_h32_100m_b1024` +- T/K streaming guardrail: + `tmp/fabric_audits/partials/2026-05-04/tk_streaming_forward_body_guardrail_h32_100m_b128_t2_k2` + +Current owner table: + +| Row | tok/s | vs Apr21 | slowdown | Peak GiB | Memory vs Apr21 | Current allocated owner | First high-water owner | +| --- | ---: | ---: | ---: | ---: | ---: | --- | --- | +| sLSTM single forward | `10271.49` | `17.49%` | `5.72x` | `4.394` | `2.12x` | `native_forward_message_after_normalize_local0` | `native_forward_after_readout_message_local0` | +| Axon single forward | `6628.35` | `11.29%` | `8.86x` | `9.501` | `4.59x` | `native_forward_message_after_normalize_local0` | `native_forward_message_after_weighted_value_alloc_local0` | +| sLSTM mixed forward | `6456.63` | `10.99%` | `9.10x` | `9.121` | `4.41x` | `native_forward_after_transition_local0` | `native_forward_after_transition_local0` | +| Axon mixed forward | `6402.11` | `10.90%` | `9.17x` | `10.152` | `4.90x` | `native_forward_after_transition_local0` | `native_forward_after_transition_local0` | +| sLSTM `T=2,K=2` forward, `B=128` | `2235.06` | steering only | steering only | `2.914` | steering only | `native_forward_message_after_normalize_local3` | `native_forward_after_transition_local1` | +| Axon `T=2,K=2` forward, `B=128` | `2324.76` | steering only | steering only | `11.702` | steering only | `native_forward_message_after_normalize_local3` | `native_forward_after_transition_local1` | + +Training status: + +- Latest single-pop training steering evidence before the gated-local pass was + still around `13.37 tok/s`, `25.049 GiB` for sLSTM and `4.01 tok/s`, + `81.319 GiB` for Axon, with `native_after_recurrent_message_local0` as the + owner. +- The current small `T>1` training parity guardrail fails forward-output parity + before it can be used as throughput evidence. The new terminal-local forward + route is not active in that guardrail because reverse artifacts are collected. +- Training remains a blocker for full T=1 closure, but forward remains the + optimization lane. Training should be used as bounded parity/liveness + evidence until the forward physical step is compact. + +Diagnosis: + +- The last pass correctly moved the sLSTM transition owner down: + single-pop sLSTM memory improved from about `5.260 GiB` to `4.394 GiB`. 
+ It did not close throughput: single-pop sLSTM is still `5.72x` below the + April 21 floor and Axon is `8.86x` below. +- The dominant single-pop forward owner is now the message + weighted-value -> output projection -> row-normalize live set. Current code + still allocates or keeps a full `weighted_value` chunk and full + `recurrent_msg`/projected-message stage around the same point in the physical + step. +- Earlier dense-affine/message-projection/readout variants were steering + probes, not accepted closure: + - `t1_message_project_norm_fused_forward_h32_100m_b1024` regressed memory. + - `t1_message_projection_dense_affine_forward_h32_100m_b1024` and + `t1_message_projection_out_forward_h32_100m_b1024` improved tok/s + modestly but left peak around `6.951/10.998 GiB`. + - `t1_stream_readout_body_h32_100m_b1024` did not move the message owner. + - `t1_stream_readout_cost_rejected_h32_100m_b1024` regressed memory badly. +- Mixed-pop is not a separate fast-path request. It exposes the next compiler + generality issue: multiple transition consumers need merge/chunk rows for + message -> transition producer-consumer execution. Current singleton + transition streaming cannot close mixed-pop. +- The frontend/static/state handoff is not the current peak owner, but it is + already large enough to matter for closure. In the latest single-pop forward + row, frontend peak is about `2.09 GiB` for sLSTM and `4.61 GiB` for Axon; + April 21's whole memory floor is `2.07 GiB`. After message-stage memory moves, + static/state lifetime will become a real closure owner. +- April 21 mechanism evidence points to a low-live-memory physical model: + compact message projection, receiver/state-affine chunking, phase reuse for + projected-message/reset sources when legal, and direct/persistent state + affine execution. In this branch those must be expressed through current + primitive rows, native callable rows, access rows, producer-consumer rows, + output/artifact routes, and liveness rows. + +Highest-impact remaining changes before T=1 can match/exceed April 21: + +1. **Message projection/normalization streaming body.** + Replace the current `weighted_value -> projected/recurrent_msg -> normalize` + materialization pattern with a registered producer-consumer body that + consumes message rows and writes only required route-owned outputs. The + implementation should reduce or eliminate full `weighted_value` and + projected-message live overlap, and should use grouped/batched GEMM/BMM or a + fused native strategy where the compiler rows make it legal. +2. **Message -> transition/readout consumer movement.** + Move transition and readout consumers into the streaming step where legal. + The next strategy should state which consumers read chunk-local projected or + normalized message data and which artifacts force full materialization. +3. **Mixed-pop merge/chunk producer-consumer rows.** + Generalize the route beyond singleton transition consumers. Multiple + transition/readout consumers need explicit concat/sum/select/route ownership + through compiler rows; do not add a single-pop-only shortcut. +4. **Frontend/static/state lifetime cleanup after message owner moves.** + Once the message high-water drops, classify and shorten `state`, + `static_tensors`, `route_static_tensors`, and initial recurrent hidden/KV + lifetimes through compiler-owned handoff/liveness rows. This must not become + benchmark-side tiling or Config policy. +5. 
**Training parity/liveness unblock.** + Fix the current small `T>1` training parity blocker, then revisit reverse + artifacts, recurrent-message gradients, transition reverse spans, and + reducer liveness. Do not tune training broadly until forward owner movement + has a compact step body. +6. **Full closure matrix.** + After the steering owner moves, rerun April 21 coverage: `100M/500M/1B`, + `B=1024/16384`, sLSTM + Axon, forward + training, mixed-pop T=1, + small-hidden h4/h8/h16, high-batch small-param, reset-present, final-state + materialized/unmaterialized, and T/K streaming guardrails. + +Next highest-impact plan should target item 1 and item 2 together. Another +readout-only, alias-only, or one-buffer message tweak is not enough unless it +is explicitly part of the registered streaming physical step body and moves the +named message owner in allocator telemetry and tok/s. + +### 2026-05-04 - Plan: Streaming Message Projection/Consumer Body + +Status: plan only. Do not implement a narrow message/readout shortcut outside +this plan. + +Hypothesis: + +- The next April 21 mechanism to semantically transfer is compact + message-stage execution: produce attention-weighted values, project them, + normalize them, and feed transition/readout consumers inside the streaming + physical step rather than holding full producer stage boundaries live. +- This should be expressed as a registered physical strategy over existing + compiler rows. User-visible message, cell, readout, graph, reset, output, and + loss semantics do not change. +- Expected owner movement: + - single-pop sLSTM: `native_forward_message_after_normalize_local0` and first + high-water `native_forward_after_readout_message_local0` must drop below the + current `4.394 GiB` peak. + - single-pop Axon: `native_forward_message_after_normalize_local0` and first + high-water `native_forward_message_after_weighted_value_alloc_local0` must + drop below the current `9.501 GiB` peak. + - mixed-pop: transition peak should not regress; follow-up work should move + the `multiple_transition_consumers_need_merge_rows` blocker into explicit + merge/chunk rows. + +Boundary manifest: + +- Surface: registered temporal fused forward program and message native + callable strategy. +- Lane: throughput strategy plus native implementation. +- Semantic delta: none. +- Declaration/spec owner: unchanged graph, message-rule, transition/cell, + readout, output-route, reset, and execution declarations. +- Primitive rows expected stable: message primitive rows, transition primitive + rows, readout primitive rows, output-route rows, artifact rows, reducer rows. +- Tensor/binding rows consumed: message native callable access rows, message + tensor binding rows, transition aggregate input access rows, readout access + rows, output route rows, artifact route rows, memory liveness rows, + runtime-buffer rows, physical strategy rows, message-transition + producer-consumer rows, and readout-message producer-consumer rows. +- Memory/liveness owner: compiler runtime-buffer plan and native strategy + local workspaces. Full materialization remains legal only when an artifact, + tape, output route, or downstream consumer route requires it. +- Old route to narrow/fail-close: do not delete the materialized message route + yet; it remains required for artifact/tape/training and unsupported consumer + shapes. New streaming route must be selected by legality rows, not by family, + hidden size, benchmark id, or population mode. + +Implementation slices: + +1. 
**Add explicit streaming message-stage strategy metadata.** + - Extend the message native callable/output contract to declare a + `stream_project_normalize_message` phase. + - Add output/liveness contract rows for: + `weighted_value_workspace`, `projected_message_workspace`, + `normalized_message_route`, and optional `full_recurrent_msg_materialized`. + - The strategy must declare whether it writes route-owned normalized message + directly, returns a full recurrent message, or provides chunk-local + consumer input only. + - Add source guardrails that the strategy is selected from native callable + rows and memory liveness rows, not from old fixed-slot/global tensor slots. + +2. **Implement a bounded chunk-local message project+normalize body.** + - Start in the registered message native callable body, currently around + `weighted_value = at::empty(...)` in + `registered_program/native_callables/message_forward_strategies.cuh`. + - Keep the existing attention-weighted value computation semantic. + - Replace the current full live overlap with a bounded producer-consumer + pipeline: + `weighted_value_chunk -> batched/grouped GEMM or dense_affine_out -> + normalize -> route-owned normalized chunk`. + - Prefer out-variant GEMM/BMM/dense-affine where the output storage is + compiler-owned. Avoid `at::matmul` paths that allocate unclassified + temporaries unless the probe proves they are faster and lower-memory. + - Keep formulas inside the message native strategy. Do not move message math + into the temporal scheduler. + +3. **Thread normalized-message chunks to transition/readout consumers.** + - For singleton legal routes, make transition aggregate input consume the + normalized chunk directly instead of re-reading a full recurrent-message + buffer. + - For readout, only stream when `readout_message_producer_consumer_rows` + declares an executable route and prior rejected readout-cost conditions are + not hit. + - If a full logical `forward_recurrent_msg` is required by artifact/tape or + unsupported consumers, materialize it explicitly and record the reason in + metadata. + +4. **Add mixed-pop merge/chunk planning as the follow-up inside the same design.** + - Do not unblock mixed-pop with a separate fast path. + - Replace singleton-only legality in the message-transition plan with + explicit merge/chunk route rows for multiple transition consumers. + - First version may still block non-singleton at launch, but the plan must + expose which merge rows are missing and must not silently fall back to + role-only materialization for supported rows. + +5. **Add measurement and rollback gates.** + - Record native stage telemetry before/after: + `message_before_weighted_value_alloc`, `message_after_weighted_value`, + `message_after_projected_gemm`, `message_after_normalize`, + transition/readout consumer stages, and reserved gap. + - Keep only if the named message owner physically moves in allocated memory, + max allocated memory, launch shape, or tok/s on the high-level row. + - Narrow to forward-only if artifact/tape/training lifetimes grow. + - Revert if the change only relabels metadata, uses hidden route identity, + regresses Axon like the rejected readout variants, or increases + unclassified/reserved memory without a named owner. + +First implementation target: + +- Implement slice 1 and slice 2 for legal forward-only single-pop T=1 rows. +- Do not require mixed-pop to close in the first patch, but leave a typed + compiler blocker and no hidden fallback for missing merge/chunk rows. 
+- Do not broaden training optimization in this patch. Run the existing training + parity guardrail only as a liveness/parity check. + +Representative perf row: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_stream_message_project_norm_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_stream_message_project_norm_20260504 \ +timeout 1200s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-04/t1_stream_message_project_norm_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward \ + --batches 1024 --seq-lens 1 --inner-steps 1 \ + --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 1 --require-cuda-temporal-owner +``` + +Follow-up guardrails: + +```bash +# Mixed-pop same route, expected to expose merge/chunk row status. +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_stream_message_project_norm_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_stream_message_project_norm_20260504 \ +timeout 1200s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-04/t1_stream_message_project_norm_mixed_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward \ + --batches 1024 --seq-lens 1 --inner-steps 1 \ + --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes mixed \ + --reset-modes absent --warmup 1 --iterations 1 --require-cuda-temporal-owner + +# T/K streaming route identity. 
+CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_stream_message_project_norm_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_stream_message_project_norm_20260504 \ +timeout 1200s uv run python -m benchmarks.fabric.run_audit \ + --plan tk-scaling \ + --out-dir tmp/fabric_audits/partials/2026-05-04/tk_stream_message_project_norm_guardrail_h32_100m_b128_t2_k2 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward \ + --batches 128 --seq-lens 2 --inner-steps 2 \ + --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 1 --require-cuda-temporal-owner +``` + +Static/parity gates: + +```bash +uv run pytest -q \ + tests/test_fabric_backend_plan.py \ + tests/test_fabric_backend_boundaries.py \ + tests/test_fabric_execution_imports.py \ + --tb=short + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_stream_message_project_norm_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_stream_message_project_norm_20260504 \ +timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells \ + --tb=short +``` + +Acceptance: + +- Accept only if the single-pop forward row moves the message owner below the + current `4.394 GiB` sLSTM / `9.501 GiB` Axon steering peaks or materially + improves tok/s without memory regression. +- Treat mixed-pop and T/K as guardrails for route identity and no hidden + terminal-only shortcut. +- No T=1 closure claim until the full April 21 coverage matrix passes. + +### 2026-05-04 - Rejected Probe: Bounded Message Project/Normalize Chunks + +Status: rejected and reverted. This pass tested the first bounded slice of the +streaming message projection/consumer-body plan, but it did not move the named +owner. + +Boundary classifier: + +- Lane: throughput strategy plus native implementation. +- Semantic delta: none. +- Rows/fingerprints expected stable: message primitive rows, transition + primitive rows, readout primitive rows, output-route rows, artifact rows, + runtime-buffer rows, physical-strategy rows, and native callable binding rows. +- Native body changed: registered fixed-slot-context message native strategy + only. +- Runtime/liveness rows consumed: existing forward recurrent-message runtime + buffer role, message native callable rows, executor rows, and stage telemetry. +- Tested route narrowing: remove the unbounded full-batch message + project/normalize call from the registered strategy body and chunk + `weighted_value -> projected_message -> normalized_message` by a bounded + per-batch byte budget. +- Still not claimed: mixed-pop merge/chunk routing, full transition/readout + chunk-local consumer movement, and artifact/training collection closure. + +Keep/narrow/revert rule used: + +- Keep only if the representative single-pop T=1 forward row moves the named + message owner in allocator telemetry, max allocated memory, or tok/s without + increasing hidden/unclassified memory. +- Narrow if it helps forward-only but regresses artifact/training liveness. +- Revert if it only adds launch/GEMM overhead, fails compile/parity, or leaves + owner telemetry unchanged. + +Guardrail status: + +- A temporary source guardrail was added during the probe, then removed with the + reverted implementation. 
The invariant is not durable because the probe did + not become an accepted path. + +Commands/results: + +- Focused static guardrails: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows tests/test_fabric_backend_plan.py::test_message_transition_producer_consumer_rows_are_compiler_owned_legality_rows tests/test_fabric_backend_boundaries.py::test_registered_program_allocations_are_compiler_classified --tb=short` + passed: `3 passed in 5.78s`. +- CUDA smoke: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_stream_message_project_norm_20260504 TRITON_CACHE_DIR=/tmp/cortical_triton_stream_message_project_norm_20260504 timeout 900s uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells --tb=short` + passed: `1 passed in 62.02s`. +- Representative audit: + `tmp/fabric_audits/partials/2026-05-04/t1_stream_message_project_norm_h32_100m_b1024`. + +Measured result: + +- sLSTM single-pop T=1 forward: `10695.57 tok/s`, `4.3936 GiB`. + Previous accepted steering row was about `10271.49 tok/s`, `4.394 GiB`. + The owner remained `native_forward_message_after_normalize_local0` / + `native_forward_after_readout_message_local0`; memory did not move. +- Axon single-pop T=1 forward: `6585.96 tok/s`, `9.5012 GiB`. + Previous accepted steering row was about `6628.35 tok/s`, `9.501 GiB`. + The owner remained in the message/transition high-water; memory did not move + and tok/s regressed slightly. + +Decision: + +- Reverted the bounded message chunk implementation and removed the temporary + guardrail. This was a valid mechanism probe, but not an accepted throughput + strategy. +- Next work should not continue one-buffer project/normalize chunk tweaks. + The owner table points back to the broader streaming physical step body: + avoid materializing the full recurrent-message and transition/readout + boundary as independent stage products, or move their consumers into the same + registered producer-consumer strategy with explicit compiler route/liveness + rows. + +Next active owner: + +- The current forward-only route already avoids full recurrent K/V-before + materialization where legal; `run_registered_forward_message_recurrent_kv_handler` + passes `materialize_key_bank=return_reverse_artifacts`, and the fixed-slot + message strategy computes recurrent values in bounded chunks when the bank is + deferred. +- The remaining forward T=1 peak is therefore not the isolated recurrent-value + bank. It is the full logical `forward_recurrent_msg` runtime buffer and the + transition/readout consumer boundary that still requires `aggregate_input = + producer_state.recurrent_msg`. +- The next implementation should add a registered streaming + message-to-transition producer-consumer body. That body must consume the same + message-transition producer-consumer rows, forward program access rows, + native callable binding rows, runtime-buffer/liveness rows, and artifact + rows, but should generate chunk-local normalized message and immediately feed + the transition input projection/row-group for the same route. +- This is a real fused physical strategy, not a helper-level allocation tweak. + It must keep semantic row fingerprints stable and fail closed if artifacts, + reverse tape, multiple transition consumers, reset/materialization policy, or + missing merge/chunk rows require full recurrent-message materialization. 
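+
+The deferred-bank behavior described in the first bullet above can be pictured
+with a small sketch: when reverse artifacts force the key bank to materialize,
+the full value bank stays live for the tape; otherwise values are produced and
+consumed in bounded chunks. The function and argument names below are
+hypothetical stand-ins, not the real handler signature.
+
+```python
+def recurrent_value_rows(hidden, value_w, consume, materialize_key_bank, chunk=4096):
+    """Sketch only: bounded-chunk value production when the bank is deferred.
+
+    `hidden` and `value_w` stand in for torch tensors; `consume` stands in for
+    the downstream message stage that uses each chunk immediately.
+    """
+    if materialize_key_bank:                  # reverse artifacts requested: keep the full bank
+        bank = hidden @ value_w
+        consume(bank)
+        return bank
+    for start in range(0, hidden.shape[0], chunk):
+        consume(hidden[start:start + chunk] @ value_w)  # one bounded chunk live at a time
+    return None
+```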
+ +Required next hypothesis packet: + +- April21 mechanism: low-live-memory physical step where message projection, + transition input projection, transition recurrence/public output, and readout + route are producer-consumer stages of one streaming body. +- Compiler products consumed: `message_transition_producer_consumer_rows`, + `readout_message_producer_consumer_rows`, `forward_program_access_rows`, + `native_callable_binding_schema_rows`, `native_callable_output_rows`, + `runtime_buffer_rows`, `memory_liveness_rows`, `forward_output_route_rows`, + artifact rows, and physical strategy rows. +- Tensors expected to cease full materialization on the first legal row: + full `forward_recurrent_msg` for transition consumption; later, selected + transition intermediate outputs and readout message where route rows prove + local consumption. +- First representative row: h32 100M B1024 T=1 single-pop forward for sLSTM and + Axon. +- Guardrail row: small `T=2,K=2` through the same route identity. +- Revert rule: reject if the full recurrent-message stage still allocates the + same logical buffer, if transition/readout consumers still fetch role-only + recurrent message, or if the patch selects by family/hidden size/benchmark + instead of compiler rows. + +### 2026-05-04 - Deep Dive: Fresh Current T=1 Forward Gap + +Status: analysis checkpoint only. No optimization code was changed in this +pass. + +Fresh evidence: + +- Single-pop forward: + `tmp/fabric_audits/partials/2026-05-04/t1_current_forward_deepdive_fresh_h32_100m_b1024` +- Mixed-pop forward: + `tmp/fabric_audits/partials/2026-05-04/t1_current_mixed_forward_deepdive_fresh_h32_100m_b1024` +- Baseline: April 21 `h32_t1_bxparams`, `58732.71 tok/s`, `2.07 GiB`. + This remains the summary floor for the full T=1 closure matrix. + +Current owner table: + +| Row | tok/s | vs Apr21 | Slowdown | Peak GiB | Memory vs Apr21 | Current allocated owner | First high-water owner | +| --- | ---: | ---: | ---: | ---: | ---: | --- | --- | +| sLSTM single forward | `10669.88` | `18.17%` | `5.50x` | `4.394` | `2.12x` | `native_forward_message_after_normalize_local0` | `native_forward_after_readout_message_local0` | +| Axon single forward | `6623.80` | `11.28%` | `8.87x` | `9.501` | `4.59x` | `native_forward_message_after_normalize_local0` | `native_forward_message_after_weighted_value_alloc_local0` | +| sLSTM mixed forward | `6434.41` | `10.96%` | `9.13x` | `9.121` | `4.41x` | `native_forward_after_transition_local0` | `native_forward_after_transition_local0` | +| Axon mixed forward | `6427.31` | `10.94%` | `9.14x` | `10.152` | `4.90x` | `native_forward_after_transition_local0` | `native_forward_after_transition_local0` | + +Important live runtime roles: + +- sLSTM single forward: `forward_recurrent_msg=0.562 GiB`, + `transition_forward_linear_output=2.812 GiB`, + `transition_forward_matmul_output=2.250 GiB`, + `transition_forward_norm_output=0.562 GiB`, + `transition_forward_state_output=0.562 GiB`. +- Axon single forward: `forward_recurrent_msg=1.750 GiB`, + `transition_forward_diag_output=3.500 GiB`, + `transition_forward_linear_output=3.500 GiB`, + `transition_forward_norm_output=1.750 GiB`. +- Mixed forward: `forward_recurrent_msg=1.453 GiB`, + `transition_forward_linear_output=5.086 GiB`, + `transition_forward_matmul_output=2.906 GiB`, + `transition_forward_norm_output=1.453 GiB`, + `transition_forward_state_output=0.727 GiB`. 
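+
+As a sanity check on these role sizes, the sLSTM single-forward numbers are
+consistent with fp32 buffers of shape `B x cells x features` at `B=1024`,
+`hidden=32`, and roughly 4.6k cells. The cell count below is inferred from the
+byte totals and is an illustration only, not a value read from the compiler
+tables.
+
+```python
+def role_gib(batch, cells, features, dtype_bytes=4):
+    """Bytes for one B x cells x features buffer, expressed in GiB."""
+    return batch * cells * features * dtype_bytes / 2**30
+
+
+# Hypothetical shape reconstruction for the sLSTM single-forward row.
+print(role_gib(1024, 4608, 32))      # 0.5625 -> recurrent_msg / norm / state outputs
+print(role_gib(1024, 4608, 4 * 32))  # 2.2500 -> transition_forward_matmul_output
+print(role_gib(1024, 4608, 5 * 32))  # 2.8125 -> transition_forward_linear_output
+```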
+ +Code-path diagnosis: + +- The current compiler route is real: `physical_strategy_rows` selects + `streaming_step_producer_consumer`, and both producer-consumer row tables are + present in the launch contract. +- The active message->transition route is still not a true chunk-local fused + body. `forward_program.cuh` still resolves the transition aggregate input as + `aggregate_input = producer_state.recurrent_msg` for the streaming row. That + means the full logical `forward_recurrent_msg` buffer is still produced before + transition consumes it. +- Mixed-pop exposes the same issue more clearly: the row count rises to four + message-transition producer-consumer rows, and direct singleton consumption + cannot be generalized until explicit merge/chunk semantics exist. +- The last bounded message project/normalize probe was correctly rejected: + helper-level chunking did not move the owner because the consumer boundary + still required the full recurrent-message stage product. +- Frontend/static/state lifetime is already near or above the April21 memory + floor: single sLSTM frontend peak is about `2.09 GiB`, single Axon about + `4.61 GiB`, and mixed about `5.45 GiB`. It is not the first forward owner, + but it will block final memory closure after the native live set moves. + +Largest remaining changes before T=1 can match or exceed April21: + +1. **Executable message -> transition streaming body.** + This is the highest-impact next owner. It must consume + `message_transition_producer_consumer_rows`, `forward_program_access_rows`, + native callable binding/output rows, runtime-buffer rows, liveness rows, and + artifact/output route rows, then feed transition input projection from + chunk-local normalized message. It must not fetch role-only + `producer_state.recurrent_msg` for supported rows. +2. **Transition output workspace/liveness contraction inside the same body.** + Transition intermediate roles are several GiB. Local-only linear/matmul/diag/ + norm/state outputs must become compiler-owned workspace or chunk-local + producer-consumer values, with only public state, artifact/tape, or output + route products retained. +3. **Message projection/normalization as a consumer-coupled strategy, not a + helper tweak.** + The message weighted/projected/normalized live set remains the dominant + single-pop owner. The next implementation should stream this as part of the + physical step body and use grouped/batched GEMM/BMM where legal. +4. **Mixed-pop merge/chunk route semantics.** + Mixed-pop should go through the same strategy. It needs explicit compiler + rows for multiple transition consumers, concat/sum/select or chunk routes, + and typed blockers where unsupported. Do not add a single-pop-only fast path. +5. **Readout route only after equivalence is proven.** + `materialize_recurrent_kv_after(true)` still appears on the materialized + readout path. The prior readout streaming probes were not accepted. A new + readout strategy must prove semantic equivalence through + `readout_message_producer_consumer_rows` before it can elide recurrent K/V. +6. **Frontend/static/state lifetime after native peak moves.** + Once the registered native body stops dominating peak memory, static tensors, + state tensors, boundary projection, and registered handoff lifetimes need to + move behind compiler-owned buffers/liveness. This is a pre-throughput-closure + memory owner, not a benchmark-side shortcut. +7. 
**Training/backward/reducer liveness after forward step compacts.** + Last known training steering rows remain around `13 tok/s` for sLSTM and + `4 tok/s` for Axon, with very high memory. Training should remain a bounded + parity/liveness guardrail until the forward physical step is compact, then + reverse artifacts, recurrent-message gradients, transition reverse spans, + and parameter reducers become the main owner. +8. **Full April21 closure matrix.** + The steering row is not closure. T=1 still requires `100M/500M/1B`, + `B=1024/16384`, sLSTM + Axon, forward + training, single + mixed, + small-hidden stress, high-batch small-param rows, reset-present, + materialized/no-final-state, and T/K streaming guardrails. + +Next plan should target item 1 and item 2 together. Another isolated +message-buffer, readout-buffer, alias, or one-buffer liveness probe should be +rejected unless it is explicitly part of the executable producer-consumer body +and moves the named owner in allocator telemetry. + +### 2026-05-04 - Plan: Executable Message -> Transition Streaming Body + +Status: plan only. Do not implement another isolated message/readout/alias +patch before this producer-consumer body is either implemented or explicitly +rejected. + +Hypothesis: + +- The largest remaining T=1 forward gap is the materialized boundary between + message and transition. The current route has compiler-owned + `message_transition_producer_consumer_rows`, but the active C++ body still + feeds transition through `aggregate_input = producer_state.recurrent_msg`. +- The April21 mechanism to transfer is the low-live-memory physical step: + message weighted value + projection + normalization feeds transition input + projection/recurrent row-group without keeping a full message-stage product + live unless artifact/tape/reset/final-state routes require it. +- Expected semantic delta: none. Primitive rows, tensor bindings, output + routes, artifact routes, reset semantics, and transition/message formulas + must remain stable. + +Boundary manifest: + +- Surface: registered temporal fused forward program. +- Lane: throughput strategy plus native implementation. +- Declaration/spec owner: unchanged graph, message rule, cell/transition, + readout, output, reset, and loss declarations. +- Compiler rows consumed directly: + `message_transition_producer_consumer_rows`, + `readout_message_producer_consumer_rows`, `physical_strategy_rows`, + `forward_program_access_rows`, `native_callable_binding_schema_rows`, + `native_callable_output_rows`, `forward_executor_rows`, + `forward_executor_binding_rows`, `program_tensor_binding_rows`, + `runtime_buffer_rows`, `memory_liveness_rows`, `forward_output_route_rows`, + and forward artifact/merge rows. +- Old route to delete/fail-close for supported rows: + transition may not fetch `producer_state.recurrent_msg` or role-only + `forward_recurrent_msg` when an executable streaming message->transition row + is selected. +- Old route to keep for unsupported rows: + materialized recurrent message remains legal for reverse artifacts, training + tape, reset-present rows, final program tensor retention, multiple transition + consumers without merge/chunk rows, and failed legality checks. + +Implementation slices: + +1. **Make the producer-consumer row select a real body.** + - Use the existing `stream_message_to_transition_input` strategy row as the + legal selector. + - Keep it blocked for mixed-pop/multiple transition consumers until explicit + merge/chunk rows exist. 
+ - Make executable status depend on forward-only no-artifact/no-reset/no-final + tensor retention plus singleton receiver span equality. + - Add a launch/source guardrail: if the row is executable, transition input + must not be assigned from `producer_state.recurrent_msg`. + +2. **Split message native strategy into chunk producer API.** + - Add a registered message native callable entry that produces a normalized + message chunk into caller-owned workspace. + - It should reuse the existing fixed-slot context math and bindings, but + write to a chunk-local tensor instead of requiring the full + `forward_recurrent_msg` runtime buffer. + - The chunk producer must take compiler-owned rows/bindings and stage memory + telemetry; no family, hidden-size, benchmark, or April21-code selector. + +3. **Add transition row-group chunk consumer API.** + - Add row-group entrypoints that accept a chunk-local aggregate message input + directly instead of reading the aggregate binding from `program_tensors`. + - First target the existing supported row groups: + gated logspace and diagonal RTU. + - Keep transition formulas inside transition native row-group functions; the + forward scheduler only orchestrates producer/consumer order. + - Local-only transition outputs should be chunk-local workspace where + liveness rows prove no artifact/tape/final consumer exists. + +4. **Wire the fused streaming loop in `forward_program.cuh`.** + - For each executable message->transition route: + produce normalized message chunk -> feed transition input projection and + row-group chunk -> write public recurrent hidden/output routes. + - Do not allocate full `forward_recurrent_msg` for that route. + - Preserve materialized route for unsupported or required-materialization + cases. + - Keep readout on the existing materialized path initially unless readout + producer-consumer rows prove an executable equivalent. + +5. **Stage and runtime-buffer telemetry.** + - Add stages that distinguish: + `message_transition_stream_before_chunk`, + `message_transition_stream_after_message_chunk`, + `message_transition_stream_after_transition_chunk`, + `message_transition_stream_after_public_state_write`. + - Keep existing message stages so the owner table can prove the previous + `native_forward_message_after_normalize_local0` / full + `forward_recurrent_msg` owner moved. + +6. **Mixed-pop route design in the same pass, implementation optional.** + - Do not add a separate single-pop-only fast path. + - Add or extend typed blockers for multiple transition consumers to name the + missing merge/chunk route rows. + - If practical, add the row schema for chunk routes now; if not, keep mixed + blocked before launch with explicit reason. + +7. **Guardrails and tests.** + - Source guardrail: executable streaming row cannot assign transition + aggregate input from `producer_state.recurrent_msg`. + - Negative legality: malformed/executable streaming rows fail before launch + when artifact/reset/multiple-consumer conditions require materialization. + - Positive compiler product: row table, access rows, runtime buffer rows, and + physical strategy rows are consumed in the launch ABI. + - CUDA smoke: fused forward no-artifact output-cells test. 
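+
+The executable condition in slice 1 amounts to a fail-closed predicate over
+the producer-consumer row; only when it holds may the fused loop in slice 4
+skip the full `forward_recurrent_msg` buffer. The field and function names
+below are illustrative, not the real compiler row schema.
+
+```python
+from dataclasses import dataclass
+
+
+@dataclass
+class StreamRowSketch:
+    forward_only: bool
+    reverse_artifacts: bool
+    reset_present: bool
+    final_program_tensors: bool
+    transition_consumers: int
+    receiver_span_matches_producer: bool
+
+
+def stream_row_executable(row: StreamRowSketch) -> bool:
+    """Sketch of slice 1: executable only for forward-only, no-artifact,
+    no-reset, no-final-state routes with a single span-matched consumer."""
+    return (
+        row.forward_only
+        and not row.reverse_artifacts
+        and not row.reset_present
+        and not row.final_program_tensors
+        and row.transition_consumers == 1
+        and row.receiver_span_matches_producer
+    )
+```
+
+When the predicate holds, the source guardrail in slice 1 additionally requires
+that the transition aggregate input is never assigned from
+`producer_state.recurrent_msg`; otherwise the materialized route stays in
+force.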
+ +Representative commands: + +```bash +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_message_transition_producer_consumer_rows_are_compiler_owned_legality_rows \ + tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows \ + tests/test_fabric_backend_boundaries.py::test_registered_program_allocations_are_compiler_classified \ + --tb=short +``` + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_message_transition_stream_body_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_message_transition_stream_body_20260504 \ +timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells \ + --tb=short +``` + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_message_transition_stream_body_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_message_transition_stream_body_20260504 \ +timeout 1200s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-04/t1_message_transition_stream_body_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward \ + --batches 1024 --seq-lens 1 --inner-steps 1 \ + --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 1 --require-cuda-temporal-owner +``` + +Follow-up guardrails: + +```bash +# Mixed-pop must either use the same route with explicit merge/chunk rows or +# fail closed with the typed blocker; no hidden singleton shortcut. +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_message_transition_stream_body_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_message_transition_stream_body_20260504 \ +timeout 1200s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-04/t1_message_transition_stream_body_mixed_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward \ + --batches 1024 --seq-lens 1 --inner-steps 1 \ + --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes mixed \ + --reset-modes absent --warmup 1 --iterations 1 --require-cuda-temporal-owner +``` + +```bash +# T/K route identity guardrail after the T=1 row moves. 
+CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_message_transition_stream_body_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_message_transition_stream_body_20260504 \ +timeout 1200s uv run python -m benchmarks.fabric.run_audit \ + --plan tk-scaling \ + --out-dir tmp/fabric_audits/partials/2026-05-04/tk_message_transition_stream_body_guardrail_h32_100m_b128_t2_k2 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward \ + --batches 128 --seq-lens 2 --inner-steps 2 \ + --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 1 --require-cuda-temporal-owner +``` + +Acceptance: + +- The single-pop forward row must physically move the full + `forward_recurrent_msg`/message-stage owner. Evidence must show allocator + stage movement, not just metadata. +- sLSTM should drop below the current `4.394 GiB` peak and Axon below the + current `9.501 GiB` peak without a major tok/s regression. +- If memory moves but tok/s regresses, narrow only if launch count/workspace + evidence says the next fusion step will recover it; otherwise reject. +- Reject immediately if semantic rows change, if the scheduler owns primitive + math, if route selection keys on family/hidden size/benchmark, if mixed-pop + silently uses a singleton route, or if training/artifact paths are broken + instead of cleanly falling back to materialized rows. + +### 2026-05-04 - Implementation Note: Streaming Message/Hidden-After Alias Edge + +Boundary packet: + +- Semantic delta: none. Message, transition, readout, output, reset, and + artifact primitive rows are unchanged. +- April21 mechanism being transferred: shorten the live producer-consumer edge + between message aggregate and transition/readout state update instead of + retaining a full recurrent-message stage boundary. +- Compiler products consumed: active + `message_transition_producer_consumer_rows`, active + `physical_strategy_rows`, `runtime_buffer_rows`, + `memory_liveness_rows`, message executor rows, and transition aggregate + access rows. +- Strategy/runtime change: when the compiler selects + `stream_message_to_transition_input`, reverse artifacts are off, final + program tensors are off, and the message output shape matches the recurrent + hidden-after carry shape, the native message callable writes the recurrent + message aggregate into the compiler-owned + `forward_recurrent_hidden_after` runtime buffer. Transition then consumes that + aggregate and overwrites the same storage with the public hidden-after state. +- Materialized fallback: artifact/training/final-state paths, non-singleton + producer routes, shape-mismatched routes, and mixed-pop routes still use the + materialized recurrent-message path until explicit merge/chunk rows exist. +- This is not the final direct chunk body. It is the first executable liveness + edge in the streaming physical step. Acceptance still requires the measured + owner to move on the high-level forward row. + +Verification started: + +```bash +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_message_transition_producer_consumer_rows_are_compiler_owned_legality_rows \ + --tb=short +``` + +Result: passed. 
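+
+The alias edge described in this packet can be summarized as a small sketch:
+when the streaming row is selected and no artifact or final-state retention is
+required, the message aggregate is written into the compiler-owned hidden-after
+carry storage, which transition then overwrites in place with the public
+hidden-after state. The names and the copy/overwrite calls below are
+hypothetical illustrations, not the native callable ABI.
+
+```python
+def route_message_aggregate(hidden_after_buf, normalized_msg,
+                            streaming_selected, return_reverse_artifacts,
+                            wants_final_program_tensors):
+    """Sketch only: alias the aggregate into the hidden-after carry when legal.
+
+    Both buffer arguments stand in for torch tensors owned by the compiler
+    runtime-buffer plan.
+    """
+    alias_ok = (
+        streaming_selected
+        and not return_reverse_artifacts
+        and not wants_final_program_tensors
+        and hidden_after_buf.shape == normalized_msg.shape
+    )
+    if alias_ok:
+        hidden_after_buf.copy_(normalized_msg)  # aggregate occupies the carry storage
+        return hidden_after_buf                 # transition consumes and overwrites it in place
+    return normalized_msg.clone()               # materialized fallback keeps a separate buffer
+```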
+ +```bash +uv run pytest -q \ + tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows \ + --tb=short +``` + +Result: passed. + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_streaming_msg_alias_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_streaming_msg_alias_20260504 \ +timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells \ + --tb=short +``` + +Result: passed. + +Representative single-pop forward audit: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_streaming_msg_alias_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_streaming_msg_alias_20260504 \ +timeout 1200s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-04/t1_streaming_message_hidden_alias_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward \ + --batches 1024 --seq-lens 1 --inner-steps 1 \ + --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 1 --require-cuda-temporal-owner +``` + +Result: + +- sLSTM: `10495.02 tok/s`, `4.394 GiB` peak. Current message-stage + allocation moved down from `4.377 GiB` to `3.773 GiB`, but the total peak did + not move because the later readout/output high-water still reaches + `4.717 GiB`. +- Axon: `6597.59 tok/s`, `8.777 GiB` peak, improved from the fresh current + `9.501 GiB` peak. The message-stage current allocation moved from + `10.202 GiB` to `8.323 GiB`. +- Decision: keep the alias edge as a valid streaming-step liveness improvement, + but do not count it as T=1 closure. The remaining single-pop forward owner is + now the streaming readout/output high-water for sLSTM and remaining native + message/readout temp ownership for Axon. + +Mixed-pop guardrail: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_streaming_msg_alias_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_streaming_msg_alias_20260504 \ +timeout 1200s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-04/t1_streaming_message_hidden_alias_mixed_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward \ + --batches 1024 --seq-lens 1 --inner-steps 1 \ + --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes mixed \ + --reset-modes absent --warmup 1 --iterations 1 --require-cuda-temporal-owner +``` + +Result: mixed-pop stayed on the materialized/merge-required route, as expected: +sLSTM `6449.77 tok/s`, `9.121 GiB`; Axon `6426.57 tok/s`, `10.152 GiB`. 
+ +T/K streaming guardrail: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_streaming_msg_alias_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_streaming_msg_alias_20260504 \ +timeout 1200s uv run python -m benchmarks.fabric.run_audit \ + --plan tk-scaling \ + --out-dir tmp/fabric_audits/partials/2026-05-04/tk_streaming_message_hidden_alias_guardrail_h32_100m_b128_t2_k2 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward \ + --batches 128 --seq-lens 2 --inner-steps 2 \ + --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 1 --require-cuda-temporal-owner +``` + +Result: route remains runnable through the high-level compiler-owned temporal +path. sLSTM `2221.32 tok/s`, `2.914 GiB`; Axon `2310.39 tok/s`, +`11.702 GiB`. + +Rejected probe: + +- I tried replacing the fixed-slot streaming readout's recurrent K/V chunks with + a direct readout kernel over input keys, static recurrent key parts, recurrent + hidden, and recurrent value weights. +- It failed + `tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells` + with output mismatch (`128/128` elements mismatched, max abs diff about + `1.16`), so that direct readout kernel was reverted. +- Do not revive this as a performance patch until the readout/message semantic + fingerprint is proven equivalent. The next valid direction is either a + parity-proven direct readout primitive strategy or the larger + message->transition->readout streaming chunk body that preserves the current + partitioned attention primitive exactly. + +### 2026-05-04 - Boundary Packet: Streaming Forward Step Body + +Status: implementation starting. This is the active next T=1 forward owner. + +Semantic delta: none. + +Stable semantic fingerprints expected: + +- Graph/topology rows unchanged. +- Message primitive rows unchanged. +- Transition primitive rows unchanged. +- Readout primitive rows unchanged. +- Output-route rows unchanged. +- Reset, artifact, and tape semantics unchanged. + +Changed layer: + +- Registered forward physical strategy and memory/liveness execution only. +- The strategy may alter workspace lifetime, producer-consumer scheduling, and + native return groups for already verified compiler rows. + +Compiler products consumed directly: + +- `message_transition_producer_consumer_rows` +- `readout_message_producer_consumer_rows` +- `forward_program_access_rows` +- `memory_liveness_rows` +- `runtime_buffer_rows` +- `forward_output_route_rows` +- `forward_artifact_route_rows` +- registered message, transition, and readout executor rows + +April21 mechanism being semantically transferred: + +- Low-live-memory physical execution shape: + `message producer -> transition consumer -> readout/output consumer`. +- April21 is evidence for the physical shape only. No April21 code copy, no + benchmark-side chunking, no family/hidden-size/shape selectors, and no + scheduler-owned primitive formulas. + +Implementation hypothesis: + +- Current forward still materializes too many compiler stage boundaries. +- The next owner to move is transition/message live overlap: recurrent message + and transition primitive intermediates should become local workspace wherever + artifact/tape/final-state rows do not require them. 
+- Readout remains correctness-first. The previously rejected direct/keyless + readout shortcut stays deleted until a row proves semantic equivalence. + +Tensors that should stop being fully materialized where legal: + +- `forward_recurrent_msg` after its transition/readout consumers finish. +- `transition_forward_linear_output` +- `transition_forward_matmul_output` +- `transition_forward_diag_output` +- `transition_forward_norm_output` +- `transition_forward_state_output` + +Where consumers move: + +- Transition consumes message aggregate inside the streaming step. +- Transition intermediate outputs are consumed by the transition public-state + epilogue and then released or reused through compiler workspace/liveness rows. +- Readout/output continues to use the existing semantic route unless + `readout_message_producer_consumer_rows` proves executable streaming readout + equivalence. + +Keep/narrow/revert rule: + +- Keep only if output parity is green and measured allocator telemetry shows the + named owner physically moved. +- Narrow if only a no-artifact/no-reset/no-final-state forward route is legal. +- Revert if semantic rows change, if selection keys on family/hidden size or + benchmark row, if the scheduler owns primitive math, or if memory only moves + in metadata. + +Representative row: + +```bash +CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 \ +timeout 1200s uv run python -m benchmarks.fabric.run_audit \ + --plan t1-single-pop \ + --out-dir tmp/fabric_audits/partials/2026-05-04/t1_streaming_forward_step_body_h32_100m_b1024 \ + --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json \ + --families slstm,axoncell --sizes 100m --modes forward \ + --batches 1024 --seq-lens 1 --inner-steps 1 \ + --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 \ + --training-output-boundaries terminal --population-modes single \ + --reset-modes absent --warmup 1 --iterations 1 --require-cuda-temporal-owner +``` + +Follow-up guardrails after the owner moves: + +- Mixed-pop T=1 forward through the same compiler-owned route or explicit typed + merge/chunk blocker. +- Small T>1,K>1 high-level row through the same streaming route identity. + +### 2026-05-04 - Rejected Probe: Eager Compiler Chunk Workspace + +Status: rejected and reverted. + +Probe: + +- Added a compiler runtime-buffer role for transition row-group chunk outputs. +- Routed row-group temporary chunks through eager compiler runtime buffers + instead of native `at::empty` chunk allocations. +- Semantic rows were unchanged and parity passed, but the live set grew because + the chunk buffers were allocated before the forward program body. 
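+
+To make the rejection reason concrete, a toy live-set model of why
+pre-allocating chunk workspace before the program body raises the forward
+high-water even though the math is unchanged. Byte counts are arbitrary
+placeholders, not measured Fabric sizes.
+
+```python
+# Toy live-set model (pure Python, illustrative only): reserving chunk
+# workspace eagerly, before the program body runs, raises the peak even when
+# each chunk temporary is no larger than before.
+def peak_live_bytes(allocations):
+    live = 0
+    peak = 0
+    for delta in allocations:
+        live += delta
+        peak = max(peak, live)
+    return peak
+
+
+stage = 1 << 30    # one materialized stage boundary (placeholder: 1 GiB)
+chunk = 256 << 20  # one chunk temporary (placeholder: 256 MiB)
+
+# Lazy/native path: each chunk temporary is allocated, consumed, and freed
+# inside the forward body.
+lazy = [+stage, +chunk, -chunk, +chunk, -chunk, -stage]
+
+# Eager runtime-buffer probe: all chunk workspace is allocated before the body.
+eager = [+4 * chunk, +stage, -stage, -4 * chunk]
+
+print(peak_live_bytes(lazy) / 2**30)   # ~1.25 GiB
+print(peak_live_bytes(eager) / 2**30)  # ~2.00 GiB
+```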
+ +Validation before rejection: + +```bash +uv run pytest -q \ + tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows \ + tests/test_fabric_backend_plan.py::test_message_transition_producer_consumer_rows_are_compiler_owned_legality_rows \ + tests/test_fabric_backend_plan.py::test_readout_message_producer_consumer_rows_are_compiler_owned_legality_rows \ + --tb=short +# 3 passed +``` + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_streaming_forward_chunk_workspace_20260504b \ +TRITON_CACHE_DIR=/tmp/cortical_triton_streaming_forward_chunk_workspace_20260504b \ +timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells \ + --tb=short +# 1 passed +``` + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_streaming_forward_chunk_workspace_20260504b \ +TRITON_CACHE_DIR=/tmp/cortical_triton_streaming_forward_chunk_workspace_20260504b \ +timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + --tb=short +# 1 passed +``` + +Representative audit: + +```text +tmp/fabric_audits/partials/2026-05-04/t1_streaming_forward_chunk_workspace_h32_100m_b1024 +``` + +Result versus the last accepted streaming-message/hidden-after alias row: + +| Row | Before tok/s | Before peak | Probe tok/s | Probe peak | Decision | +| --- | ---: | ---: | ---: | ---: | --- | +| sLSTM single T=1 forward | `10495.02` | `4.394 GiB` | `10742.38` | `5.390 GiB` | reject | +| Axon single T=1 forward | `6597.59` | `8.777 GiB` | `6612.23` | `9.527 GiB` | reject | + +Owner movement: + +- The runtime buffer table gained `transition_forward_chunk_output` storage: + about `1.07 GiB` for sLSTM and `0.80 GiB` for Axon. +- That increased `fabric_compiler_runtime_buffer_bytes` and moved allocation + earlier to `forward_runtime_buffers_allocated`. +- It did not lower the readout/message high-water enough to satisfy the keep + rule. + +Keep/narrow/revert decision: + +- Revert the eager chunk-workspace code path. +- Keep this only as evidence: chunk temporaries must be compiler-owned without + becoming eager full forward live-set buffers. A future attempt needs lazy + workspace pooling or a true fused producer-consumer body that removes the + temporary, not an eager runtime-buffer copy of it. + +### 2026-05-04 - Rejected Probe: Direct Message Projection Warp Body + +Status: rejected and reverted. + +Probe: + +- Added a guarded registered native message path for the existing + fixed-slot-context message row. +- The probe tried to compute + `attention(input/recurrent value) -> message_output_weight -> rownorm` + directly into `forward_recurrent_msg` for small message widths, avoiding the + full `weighted_value`, output-weight transpose, and projected-message scratch. +- Semantic rows were intended to remain unchanged. 
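+
+For readability, the stage chain the probe tried to fuse, written as plain
+reference-shaped torch ops. Shapes, the transpose convention, and the
+row-normalization epilogue here are placeholders; the authoritative formula is
+the registered message primitive row, not this sketch.
+
+```python
+# Placeholder reference shape only (not the Fabric primitive definition).
+import torch
+
+
+def projected_message_reference(attn_weights, values, message_output_weight, eps=1e-6):
+    weighted_value = attn_weights @ values                 # stage the probe kept in registers
+    projected = weighted_value @ message_output_weight.T   # message_output_weight projection
+    # placeholder row normalization standing in for the current rownorm epilogue
+    return projected / (projected.norm(dim=-1, keepdim=True) + eps)
+```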
+ +Result: + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_streaming_msg_direct_project_guard2_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_streaming_msg_direct_project_guard2_20260504 \ +timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells \ + --tb=short +# failed: output mismatch with NaNs +``` + +Revert validation: + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_streaming_msg_direct_project_revert_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_streaming_msg_direct_project_revert_20260504 \ +timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells \ + --tb=short +# 1 passed +``` + +Decision: + +- Do not revive this as a throughput patch without first proving the fused + projected-message formula as a compiler-owned primitive strategy. +- The useful mechanism remains valid: avoid materializing weighted/projected + message boundaries. The implementation must be expressed as a verified + registered message strategy over current rows, with an isolated parity probe + before it is wired into the high-level forward path. +- The next forward pass should either target a lower-risk liveness owner with + exact existing operators, or add a proper reference-vs-native projected + message strategy test before high-level activation. + +### 2026-05-04 - Rejected Probe: Scalar Projected-Message Body + +Status: rejected and reverted. + +Probe: + +- Reconstructed the April21-style tiny/direct projected-message physical shape + as a current registered native message strategy over the existing + fixed-slot-context row. +- The scalar-row body kept the weighted value in registers, projected with + `message_output_weight`, applied the current row normalization, and wrote + directly into `forward_recurrent_msg`. +- Semantic rows stayed unchanged and the focused fused-forward parity smoke + passed. + +Validation: + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_streaming_msg_scalar_project_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_streaming_msg_scalar_project_20260504 \ +timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells \ + --tb=short +# 1 passed +``` + +Representative audit: + +```text +tmp/fabric_audits/partials/2026-05-04/t1_streaming_message_scalar_project_h32_100m_b1024 +``` + +Result versus last accepted streaming-message/hidden-after alias row: + +| Row | Before tok/s | Before peak | Probe tok/s | Probe peak | Decision | +| --- | ---: | ---: | ---: | ---: | --- | +| sLSTM single T=1 forward | `10495.02` | `4.394 GiB` | `9719.60` | `4.394 GiB` | reject | +| Axon single T=1 forward | `6597.59` | `8.777 GiB` | `5351.26` | `8.777 GiB` | reject | + +Owner movement: + +- The native message-stage allocation moved down internally, but the true + forward high-water stayed at readout/output. +- Throughput regressed substantially, so this is not a valid forward closure + step. + +Decision: + +- Revert the scalar projected-message body. +- Keep the mechanism as steering evidence only: a projected-message direct path + needs a parallel/tiled registered strategy that preserves current semantics + and moves the readout/output high-water, not a scalar per-row fallback. 
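+
+A minimal sketch of the isolated reference-vs-native parity probe called for
+above, assuming hypothetical strategy callables; the only point is that the
+comparison gate runs before the candidate is wired into the high-level forward
+path.
+
+```python
+# Illustrative gate only; `reference_fn` and `native_fn` stand in for the
+# registered reference and candidate native message strategies.
+import torch
+
+
+def parity_gate(reference_fn, native_fn, inputs, rtol=1e-4, atol=1e-5):
+    ref = reference_fn(*inputs)
+    out = native_fn(*inputs)
+    if not torch.allclose(ref, out, rtol=rtol, atol=atol):
+        max_abs = (ref - out).abs().max().item()
+        raise AssertionError(f"projected-message parity failed, max abs diff {max_abs:.3e}")
+    return out
+```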
+ +### 2026-05-04 - Streaming Step Slice: Message To Transition Input + +Status: kept as a narrow gated-transition producer-consumer slice. + +Mechanism: + +- Added a registered message native phase named `stream_transition_input`. +- The active executable `stream_message_to_transition_input` producer-consumer + row no longer hands the transition span a fully materialized + `producer_state.recurrent_msg` when the message strategy supports the direct + phase. +- The fixed-slot-context message strategy now streams + `message -> transition input projection` in batch chunks and writes the + compiler-owned transition input binding directly. +- The gated and diagonal transition row-groups now accept a prefilled compiler + transition-input/cell-input binding and skip their first linear projection + when that binding was produced by the message/transition row. +- The direct phase is legal only when the consumer transition span fingerprint + contains `gated_logspace_recurrence`. A diagonal-row attempt improved speed + but raised Axon peak memory to about `11.0 GiB`, so diagonal keeps the prior + lower-memory materialized route until it has its own fused streaming body. + +Compiler products consumed: + +- `message_transition_producer_consumer_rows` selects the producer-consumer + route. +- `forward_executor_binding_rows` and `native_callable_binding_schema_rows` + identify the transition input projection binding, weight, bias, and output. +- `native_callable_output_rows` and `runtime_buffer_rows` own the destination + transition-input runtime buffer. +- `memory_liveness_rows` still decides whether downstream transition outputs + are materialized, deferred local, or returned. + +Stable semantic fingerprints: + +- Message primitive rows and fixed-slot-context math are unchanged. +- Transition primitive rows are unchanged; the first linear primitive is still + present and verified. Only its producer moves from a materialized message + table to a compiler-owned producer-consumer stream. +- No benchmark-owned time chunking, hidden family selector, route identity, or + April21 code copy was introduced. + +Expected live-set movement: + +- Full `forward_recurrent_msg` should stop being a required forward live tensor + for the fixed-slot-context message -> transition edge when no reverse + artifacts/final tensors are requested. +- The remaining high-water is expected to move to transition/readout/output + consumers until those consumers are folded into the same streaming step body. + +Validation so far: + +```bash +uv run pytest -q \ + tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows \ + tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned \ + --tb=short +# 2 passed +``` + +```bash +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured \ + --tb=short +# 1 passed +``` + +```bash +uv run python scripts/validate_fabric_generated_catalogs.py +# Generated Fabric catalog headers are up to date. 
+``` + +```bash +uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells \ + --tb=short +# 1 passed +``` + +Representative audit: + +```text +tmp/fabric_audits/partials/2026-05-04/t1_stream_transition_input_gated_h32_100m_b1024 +``` + +Result versus the last accepted forward steering rows: + +| Row | Before tok/s | Before peak | New tok/s | New peak | Decision | +| --- | ---: | ---: | ---: | ---: | --- | +| sLSTM single T=1 forward | `10495.02` | `4.394 GiB` | `13192.93` | `4.394 GiB` | keep | +| Axon single T=1 forward | `6597.59` | `8.777 GiB` | `7946.75` | `8.776 GiB` | keep | + +Rejected narrower packet: + +- `tmp/fabric_audits/partials/2026-05-04/t1_stream_transition_input_h32_100m_b1024` + enabled the direct phase for the diagonal consumer too. It produced + sLSTM `13264.64 tok/s`, `4.394 GiB`, but Axon rose to `7525.21 tok/s`, + `11.000 GiB`. +- `tmp/fabric_audits/partials/2026-05-04/t1_stream_transition_input_narrow_h32_100m_b1024` + used a message-width predicate and still selected Axon, with the same + `11.000 GiB` peak. That predicate was rejected. + +Keep/narrow/revert rule applied: + +- Keep the gated-row producer-consumer direct phase because parity passed and + representative forward speed moved without increasing peak memory. +- Keep the route disabled for diagonal rows because the direct phase raised peak + memory. Diagonal needs a fused streaming body that removes its own + transition/readout live-set before direct message handoff is re-enabled. +- This does not close the forward memory owner. `forward_recurrent_msg` remains + in the planned runtime ledger, and the live high-water is still inside + message projection/readout/transition boundaries. The next pass must fold a + larger streaming step body, not widen this route. + +### 2026-05-04 - Rejected Probe: Direct Output-Route Projection + +Status: rejected and reverted. + +Mechanism tested: + +- Reused the registered readout projection output-route path to write pooled + output cells directly into the output route when the route was singleton and + forward-only. +- Semantic rows stayed unchanged. The probe was intended to remove a small + readout/output-route materialization edge after the message and transition + stages. + +Validation: + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_output_route_pooled_direct_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_output_route_pooled_direct_20260504 \ +timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells \ + --tb=short +# 1 passed +``` + +Representative audit: + +```text +tmp/fabric_audits/partials/2026-05-04/t1_output_route_pooled_direct_projection_h32_100m_b1024 +``` + +Measured result: + +| Row | tok/s | peak GiB | peak/current owner | decision | +| --- | ---: | ---: | --- | --- | +| sLSTM single forward | `10658.95` | `4.394` | `native_forward_message_after_normalize_local0` | reject | +| Axon single forward | `6625.15` | `8.776` | `native_forward_message_after_normalize_local0` | reject | + +Decision: + +- The output-route current allocation moved down after the route, but the + high-water stayed inside message normalization and throughput did not improve + enough to justify keeping the extra path. +- The patch was reverted. Do not continue output-route-only work until the + producer-consumer body has moved the message/transition owner. 
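+
+As a reminder of what "the owner moved" must mean physically, a generic
+stage-bracketed allocator probe. This is not the benchmark harness, only the
+kind of measurement the peak/current owner columns summarize; it requires a
+CUDA device.
+
+```python
+# Generic sketch, not the Fabric telemetry implementation: bracket each named
+# stage with allocator readings so owner claims point at physical movement,
+# not at a metadata label.
+import torch
+
+
+def profile_stage_owners(stages):
+    """stages: ordered (name, zero-arg callable) pairs executed on CUDA."""
+    torch.cuda.synchronize()
+    torch.cuda.reset_peak_memory_stats()
+    rows = []
+    for name, run_stage in stages:
+        run_stage()
+        torch.cuda.synchronize()
+        rows.append((name,
+                     torch.cuda.memory_allocated() / 2**30,       # current GiB after stage
+                     torch.cuda.max_memory_allocated() / 2**30))  # peak GiB so far
+    return rows
+```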
+ +Current confirmation after revert: + +```text +tmp/fabric_audits/partials/2026-05-04/t1_stream_transition_confirm_h32_100m_b1024 +``` + +| Row | tok/s | peak GiB | current owner | +| --- | ---: | ---: | --- | +| sLSTM single forward | `10156.76` | `4.394` | `native_forward_message_after_normalize_local0` | +| Axon single forward | `6621.97` | `8.776` | `native_forward_message_after_normalize_local0` | + +These confirmation numbers are lower than the older +`t1_stream_transition_input_gated_h32_100m_b1024` row. Treat the confirmation +row as the active truth for the next owner unless a warmed rerun supersedes it. + +### 2026-05-04 - Kept Slice: Transition Input Uses Planned Public-State Buffer + +Status: kept as a narrow forward-only streaming-step liveness slice, not +throughput closure. + +Mechanism: + +- For legal forward-only `stream_message_to_transition_input` gated rows, the + streamed transition input projection now writes into the planned + `forward_recurrent_hidden_after` public-state runtime buffer instead of + allocating a separate full transition-input output. +- The transition primitive rows are unchanged. The first linear primitive still + exists; the implementation changes only the compiler-owned storage used by + its streamed output when artifacts, final program tensors, resets, and + multi-consumer routes do not require materialization. +- The route remains gated-row only. Diagonal/Axon direct transition streaming + stays blocked because prior diagonal direct probes raised peak memory. + +Compiler products consumed: + +- `message_transition_producer_consumer_rows` selects the direct + message-to-transition route. +- `forward_executor_binding_rows`, `native_callable_binding_schema_rows`, and + `native_callable_output_rows` identify the same transition input projection + primitive and output binding. +- `runtime_buffer_rows` owns the target public-state buffer. +- `memory_liveness_rows` and the forward-only no-artifact/no-final route decide + that transition input does not need an independent full materialization. 
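+
+Illustration of the storage-reuse idea in this kept slice, using plain torch:
+the streamed projection lands directly in a preallocated, compiler-owned
+buffer instead of materializing a separate transition-input tensor. Names and
+shapes below are placeholders, not the real bindings.
+
+```python
+# Illustrative only; `hidden_after_buf` stands in for the planned
+# forward_recurrent_hidden_after public-state runtime buffer.
+import torch
+
+B, H_msg, H_state = 1024, 32, 32
+msg_chunk = torch.randn(B, H_msg)
+w_in = torch.randn(H_msg, H_state)
+hidden_after_buf = torch.empty(B, H_state)
+
+# Streamed transition input projection writes into the planned buffer directly.
+torch.matmul(msg_chunk, w_in, out=hidden_after_buf)
+# ...the transition public-state epilogue then overwrites this same storage.
+```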
+ +Validation: + +```bash +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_message_transition_producer_consumer_rows_are_compiler_owned_legality_rows \ + tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows \ + tests/test_fabric_backend_boundaries.py::test_registered_program_allocations_are_compiler_classified \ + --tb=short +# 3 passed +``` + +```bash +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_transition_alias_smoke_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_transition_alias_smoke_20260504 \ +timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells \ + --tb=short +# 1 passed +``` + +Representative audit: + +```text +tmp/fabric_audits/partials/2026-05-04/t1_transition_input_public_state_alias_h32_100m_b1024 +``` + +Measured result versus the current confirmation row: + +| Row | Confirm tok/s | Alias tok/s | Confirm peak | Alias peak | owner movement | +| --- | ---: | ---: | ---: | ---: | --- | +| sLSTM single forward | `10156.76` | `10689.81` | `4.394` | `4.394` | message current allocation `4.076 -> 3.513 GiB` | +| Axon single forward | `6621.97` | `6614.23` | `8.776` | `8.776` | unchanged; diagonal route remains materialized | + +Decision: + +- Keep the forward-only gated alias slice because it consumes compiler-owned + route/liveness products, passes focused CUDA parity, and moves the sLSTM + message-stage current allocation without increasing peak memory. +- Do not call this T=1 closure. Max allocated still stays at `4.394/8.776 GiB` + and the dominant owner is still the message projection/normalization stage. +- The next real owner remains the executable chunk-local + message -> transition/readout streaming body. It must avoid both the full + recurrent-message boundary and the remaining transition/readout stage + products where route rows prove local consumption. + +Rejected follow-up cleanup: + +- A wrapper around the existing registered CUDA final-state cell-layout kernel + was tested to replace Python `index_select + torch.cat` final-state + materialization. +- Artifact: + `tmp/fabric_audits/partials/2026-05-04/t1_transition_alias_final_layout_h32_100m_b1024`. +- Result: sLSTM `10693.70 tok/s`, `4.394 GiB`; Axon `6633.68 tok/s`, + `8.776 GiB`. `forward_final_state_materialized` remained the max-allocated + stage and current allocations were unchanged. +- Decision: reverted. It is not wrong cleanup, but it does not move the active + throughput owner and should not distract from the streaming physical body. + +### 2026-05-04 - Deep Dive: Remaining T=1 Throughput Work After Transition Alias + +Status: analysis only. No new optimization patch was made in this pass. + +Baseline and evidence: + +- April21 source of truth: `audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json`. +- April21 `h32_t1_bxparams` summary floor: `58732.71 tok/s`, + `2.07 GiB`, `24/24` rows closed across sLSTM + Axon, + `100M/500M/1B`, forward + training, `B=1024/16384`. +- Current accepted single-pop forward artifact: + `tmp/fabric_audits/partials/2026-05-04/t1_transition_input_public_state_alias_h32_100m_b1024`. +- Fresh mixed-pop forward artifact from this analysis pass: + `tmp/fabric_audits/partials/2026-05-04/t1_current_mixed_after_transition_alias_h32_100m_b1024`. +- Current T/K streaming guardrail artifact: + `tmp/fabric_audits/partials/2026-05-04/tk_streaming_forward_body_guardrail_h32_100m_b128_t2_k2`. 
+- Latest known completed training steering rows remain the earlier + `13.37 tok/s`, `25.049 GiB` sLSTM and `4.01 tok/s`, `81.319 GiB` + Axon rows with `native_after_recurrent_message_local0` as owner. Training + was not refreshed in this analysis pass because forward remains the active + optimization lane. + +Owner table versus April21 `h32_t1_bxparams`: + +| Row | Current tok/s | % of April21 | Slowdown | Peak GiB | Mem x April21 | Current owner | First high-water | +| --- | ---: | ---: | ---: | ---: | ---: | --- | --- | +| sLSTM single T=1 forward | `10689.81` | `18.20%` | `5.49x` | `4.394` | `2.12x` | `native_forward_message_after_normalize_local0` | `native_forward_after_readout_message_local0` | +| Axon single T=1 forward | `6614.23` | `11.26%` | `8.88x` | `8.776` | `4.24x` | `native_forward_message_after_normalize_local0` | `native_forward_after_readout_message_local0` | +| sLSTM mixed T=1 forward | `6416.09` | `10.92%` | `9.15x` | `9.121` | `4.41x` | `native_forward_after_transition_local0` | `native_forward_after_transition_local0` | +| Axon mixed T=1 forward | `6443.18` | `10.97%` | `9.12x` | `10.152` | `4.90x` | `native_forward_after_transition_local0` | `native_forward_after_transition_local0` | +| sLSTM T=2,K=2 forward guardrail | `2235.06` | `3.81%` | `26.28x` | `2.914` | `1.41x` | `native_forward_message_after_normalize_local3` | `native_forward_after_transition_local1` | +| Axon T=2,K=2 forward guardrail | `2324.76` | `3.96%` | `25.26x` | `11.702` | `5.65x` | `native_forward_message_after_normalize_local3` | `native_forward_after_transition_local1` | +| sLSTM single training, latest known | `13.37` | `0.0228%` | `4392.87x` | `25.049` | `12.10x` | `native_after_recurrent_message_local0` | not refreshed | +| Axon single training, latest known | `4.01` | `0.0068%` | `14646.56x` | `81.319` | `39.28x` | `native_after_recurrent_message_local0` | not refreshed | + +Important diagnosis: + +- The accepted transition-input/public-state alias is a valid compiler-owned + liveness slice, but it is not the missing physical execution model. It moves + sLSTM message-stage current allocation from about `4.076 GiB` to + `3.513 GiB`, while max allocated remains `4.394 GiB` and throughput remains + `5.49x` below April21. +- The current branch is still materializing too many stage boundaries inside + the registered forward body. In single-pop rows, the active current owner is + `native_forward_message_after_normalize_local0`; in mixed-pop rows, the owner + shifts to `native_forward_after_transition_local0`. +- The fresh mixed-pop row matters: directly targeting single-pop-only message + aliases cannot close the real surface. Mixed-pop needs explicit compiler + producer-consumer and merge/chunk routes across multiple transition/readout + consumers. +- The current forward program has the right compiler products + (`primitive_rows`, executor rows, program access rows, + `runtime_buffer_rows`, `memory_liveness_rows`, + message/transition and readout/message producer-consumer rows), but the + executable shape still falls back to materialized `recurrent_msg`, + transition outputs, readout messages, output cells, and output-route + concatenation for important rows. +- April21 should be treated as evidence for a low-live-memory physical shape: + one streaming step body where message projection, transition input, + transition state update, readout message, readout projection, and output + route consume each other with minimal full-bank materialization. 
Do not copy + April21 code and do not introduce benchmark-owned tiling, family selectors, + hidden-size selectors, or CUDA-only scheduler policy. + +Biggest remaining changes before T=1 can match or exceed April21: + +1. **Executable message -> transition -> readout streaming step body.** + This is the largest forward blocker. Implement a registered strategy that + consumes compiler rows and bindings directly, keeps semantic row + fingerprints stable, and moves consumers into the same physical step. The + tensors that should stop being fully materialized where legality proves + local consumption are `forward_recurrent_msg`, transition primitive + intermediates, readout message, and route-local output cells. Consumers move + into the streaming step via `message_transition_producer_consumer_rows`, + `readout_message_producer_consumer_rows`, access rows, output-route rows, + runtime-buffer rows, and memory-liveness rows. + +2. **Message projection/normalization as a grouped/BMM-oriented producer.** + The current owner is still the message projection/normalization live set. + The next strategy should group compatible message rows into batched or + grouped GEMM/BMM work and reduce traffic between weighted-value, + projected-message, normalization, transition input, and readout consumers. + This must be selected through native strategy legality, not by benchmark, + family, shape, or hidden-size checks. + +3. **Diagonal/Axon streaming body, not widened gated shortcuts.** + Prior diagonal direct-transition probes raised peak memory. Axon will remain + about `9x` behind until diagonal transition and readout consumers have their + own compiler-owned streaming body. Do not re-enable diagonal direct handoff + until its local transition/readout live set is removed or narrowed. + +4. **Mixed-pop merge/chunk ownership.** + Mixed-pop now peaks at transition. The compiler must own multi-consumer + merge/chunk rows for message-to-transition and readout/output routes. Any + strategy that only proves singleton single-pop behavior is steering work, + not closure. + +5. **Final-state/frontend/static lifetime after native body moves.** + `forward_final_state_materialized` is still the max allocated stage in the + ledger, and `frontend_execute_after_static_cache` is already near the + April21 memory floor. These become the next owners only after the native + message/transition/readout body stops dominating. + +6. **Training/backward/reducer liveness after forward compacts.** + Training remains catastrophically far behind, but broad training tuning is + premature while forward still materializes the same live sets. Once the + forward streaming body is compact, close reverse artifact/tape/recurrent + message gradients and parameter reducer liveness through compiler-owned + rows. + +7. **T/K streaming proof.** + T=1 must remain the step unit for `T x K`, not a terminal-only shortcut. + The same route identity must pass a small `T>1,K>1` high-level row without + benchmark-owned time chunking, detach policy, Python replay, or full + `[T, cells, state]` materialization. + +8. **Full closure matrix.** + A steering row is not closure. T=1 throughput closure still requires + `100M/500M/1B`, `B=1024/16384`, sLSTM + Axon, forward + training, + single + mixed populations, small hidden widths, reset-present rows, + materialized/no-final-state rows, and the T/K streaming guardrail. 
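+
+To pin down the lifetime shape item 1 above asks for, a structural toy in
+plain torch. Every op below is a placeholder, not Fabric math; the point is
+that message, transition, and readout intermediates are step-local and the
+output is written straight into the preallocated route tensor instead of a
+per-step stage boundary.
+
+```python
+# Structural toy only (placeholder ops, not Fabric semantics).
+import torch
+
+
+def toy_streaming_step(x_t, carry, w_msg, w_trans, w_read, output_seq, t):
+    msg = torch.tanh(x_t @ w_msg)                 # stands in for message project + normalize
+    carry = torch.tanh((msg + carry) @ w_trans)   # stands in for transition/public-state update
+    readout = torch.tanh(carry @ w_read)          # stands in for readout message + projection
+    output_seq[t].copy_(readout)                  # direct write into the output route slice
+    return carry
+
+
+B, H, T = 4, 8, 3
+x = torch.randn(T, B, H)
+carry = torch.zeros(B, H)
+w_msg, w_trans, w_read = (torch.randn(H, H) for _ in range(3))
+output_seq = torch.empty(T, B, H)
+for t in range(T):
+    carry = toy_streaming_step(x[t], carry, w_msg, w_trans, w_read, output_seq, t)
+```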
+ +Next-plan constraint: + +- Do not spend another pass on isolated output-route, final-state wrapper, + one-buffer alias, or narrow no-copy probes unless the change is explicitly + part of the executable streaming physical step body above. +- The next implementation plan should start from the owner table here and name + exactly which producer-consumer rows, runtime buffer rows, memory-liveness + rows, output-route rows, artifact rows, and semantic fingerprints it consumes. + +### 2026-05-04 - Hypothesis Packet: Streaming Step Body v1 + +Hypothesis: + +- The remaining T=1 forward gap is not a missing semantic row. It is the + physical execution shape: the registered forward body still materializes + producer/consumer boundaries that April21 avoided. +- A small `streaming_step_body_v1` slice should reduce the current owner only + if it consumes compiler route/liveness rows and shortens a real live edge. + +April21 mechanism being transferred: + +- Low-live-memory physical step shape: + `message projection -> transition input/public state -> readout message -> + readout projection -> output route`. +- April21 remains evidence only. No code copy, no benchmark-owned time + chunking, no family selector, no hidden-size selector, no fixed-slot policy + outside registered native strategy implementation. + +Lane: + +- Throughput/native implementation over existing compiler rows. +- Semantic delta: none. + +Rows/fingerprints expected to stay stable: + +- Primitive rows. +- Forward executor rows. +- Forward executor binding rows. +- Program tensor binding rows. +- Native callable binding schema rows. +- Output route rows. +- Artifact route rows. +- Memory/liveness rows except for strategy-status metadata if required. + +Compiler products consumed directly: + +- `message_transition_producer_consumer_rows`. +- `readout_message_producer_consumer_rows`. +- `forward_output_route_rows`. +- `forward_program_access_rows`. +- `runtime_buffer_rows`. +- `memory_liveness_rows`. +- `physical_strategy_rows`. +- `native_strategy_rows`. +- `native_callable_output_rows`. + +First implementation slice: + +- Keep the existing legal direct message-to-transition route. +- Add a route-owned streaming step body helper around the active readout route + so the readout consumer runs under the same producer-consumer legality + contract and can later share chunk-local message/transition workspace. +- Keep unsupported cases fail-closed before launch. If artifacts, final program + tensors, reset policy, multi-producer merge, or diagonal transition require + materialized state, do not silently select the streaming body. + +Tensors intended to stop being fully materialized where legality proves local +consumption: + +- `forward_recurrent_msg`. +- Transition input/output intermediates. +- `forward_output_msg`. +- Route-local `forward_output_cells`. + +Smallest representative row: + +- h32 100M B1024 T=1 K=1 forward, sLSTM + Axon, single population first. + +Follow-up guardrails: + +- Mixed-pop T=1 forward through the same compiler-owned route or explicit typed + blocker. +- Small high-level T>1,K>1 row through the same streaming route identity. +- Focused CUDA parity for fused forward output/state. + +Keep/narrow/revert rule: + +- Keep only if semantic row fingerprints stay stable and the named owner moves + in memory-stage telemetry, allocator peak/current, or tok/s without parity + regression. +- Narrow to singleton/sLSTM only if legality rows explicitly block unsupported + consumers before launch. 
+- Revert if the patch only changes labels, adds another wrapper, increases + peak memory, introduces benchmark/time/family/shape selectors, or leaves the + owner unchanged. + +### 2026-05-04 - Implementation: Output Route Projection-Into Slice + +Implemented slice: + +- Added a forward readout native strategy phase, `projection_into`, so a + registered readout executor can write projection output into a caller-owned + compiler route target. +- The fused forward program now writes legal reset-free output routes directly + into the compiler-owned `output_seq` step/slice instead of materializing + route-local `forward_output_cells` and then copying/cat-ing. +- Added a strided output projection kernel for `output_seq.select(t)` and + concat-route slices. This is required for the same step body to work for + `T>1/K>1`; the change is not a terminal-only T=1 shortcut. +- Legality is intentionally bounded to reset-free forward rows for this slice. + Reset-present rows continue through the materialized route until reset-aware + direct output routing gets separate parity. + +Boundary proof: + +- Semantic delta: none. +- Stable compiler products consumed: readout executor rows, native callable + binding schema rows, readout program access rows, output route rows, + runtime-buffer rows, physical strategy rows, and reset rows. +- No benchmark/family/hidden-size route selector was added. +- The scheduler still does not own readout math; projection lives behind the + registered native readout strategy phase. + +Validation: + +```text +uv run pytest -q \ + tests/test_fabric_backend_boundaries.py::test_forward_message_readout_handlers_use_native_strategy_access_schema \ + tests/test_fabric_backend_boundaries.py::test_message_readout_native_callable_bodies_are_strategy_local \ + tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured \ + tests/test_fabric_backend_plan.py::test_readout_executor_patterns_follow_registered_readout_specs \ + --tb=short +# 4 passed + +uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells \ + --tb=short +# 1 passed +``` + +Steering audit artifacts: + +```text +tmp/fabric_audits/partials/2026-05-04/t1_output_route_projection_into_h32_100m_b1024 +tmp/fabric_audits/partials/2026-05-04/t1_output_route_projection_into_mixed_h32_100m_b1024 +tmp/fabric_audits/partials/2026-05-04/tk_output_route_projection_into_guardrail_h32_100m_b128_t2_k2 +``` + +Measured result versus the prior accepted rows: + +| Row | Before tok/s | After tok/s | Before peak | After peak | Status | +| --- | ---: | ---: | ---: | ---: | --- | +| sLSTM single T=1 forward | `10689.81` | `14058.92` | `4.394 GiB` | `4.394 GiB` | keep | +| Axon single T=1 forward | `6614.23` | `8579.00` | `8.776 GiB` | `8.776 GiB` | keep | +| sLSTM mixed T=1 forward | `6416.09` | `7548.00` | `9.121 GiB` | `9.121 GiB` | keep | +| Axon mixed T=1 forward | `6443.18` | `7540.89` | `10.152 GiB` | `10.152 GiB` | keep | +| sLSTM T=2,K=2 forward guardrail | `2235.06` | `2663.16` | `2.914 GiB` | `2.852 GiB` | keep | +| Axon T=2,K=2 forward guardrail | `2324.76` | `2621.03` | `11.702 GiB` | `11.702 GiB` | keep | + +Decision: + +- Keep the slice. It is compiler-owned, moves throughput materially on + single-pop, mixed-pop, and T/K forward guardrails, and does not increase peak + memory. +- Do not call this throughput closure. 
The dominant owner is still the + message/transition live set: + - single-pop current owner remains + `native_forward_message_after_normalize_local0`; + - mixed-pop current owner remains + `native_forward_after_transition_local0`; + - T/K first high-water remains transition/readout-body related. + +Residual parity note: + +- `tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_readout_closed_region_matches_pytorch_reference` + still fails for reset-present mixed-pop output parity. +- I disabled the new direct output-route projection path locally and the test + failed with the same numeric diff, so this is not introduced by the + `projection_into` slice. It remains a pre-existing reset/mixed route parity + issue to close before full throughput acceptance. + +Next owner: + +- Continue with the broader streaming physical step body: + message projection/normalization and transition/readout consumers need to + share chunk-local producer-consumer work so `forward_recurrent_msg`, + transition primitive intermediates, and `forward_output_msg` stop dominating + the live set. + +### 2026-05-04 - Deep Dive: Remaining T=1 Work After Projection-Into + +Scope: + +- Analysis only. No new optimization patch in this section. +- Forward numbers are fresh from the accepted `projection_into` slice. +- Training numbers are the latest known current-code single-pop rows from + `t1_current_deepdive_after_message_pc_h32_100m_b1024`; they were not rerun + after `projection_into` and should be treated as stale guardrail evidence. +- April21 target remains + `audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json`, + row `h32_t1_bxparams`: `58732.71 tok/s`, `2.07 GiB`. + +Current owner table: + +| Row | Current tok/s | % of Apr21 | Slowdown | Peak GiB | Mem x Apr21 | Current allocated owner | First high-water | +| --- | ---: | ---: | ---: | ---: | ---: | --- | --- | +| sLSTM single T=1 forward | `14058.92` | `23.94%` | `4.18x` | `4.394` | `2.12x` | `native_forward_message_after_normalize_local0` | `native_forward_after_readout_message_local0` | +| Axon single T=1 forward | `8579.00` | `14.61%` | `6.85x` | `8.776` | `4.24x` | `native_forward_message_after_normalize_local0` | `native_forward_after_readout_message_local0` | +| sLSTM mixed T=1 forward | `7548.00` | `12.85%` | `7.78x` | `9.121` | `4.41x` | `native_forward_after_transition_local0` | `native_forward_after_transition_local0` | +| Axon mixed T=1 forward | `7540.89` | `12.84%` | `7.79x` | `10.152` | `4.90x` | `native_forward_after_transition_local0` | `native_forward_after_transition_local0` | +| sLSTM T=2,K=2 forward guardrail, B=128 | `2663.16` | steering only | steering only | `2.852` | steering only | `native_forward_message_after_normalize_local3` | `native_forward_after_transition_local1` | +| Axon T=2,K=2 forward guardrail, B=128 | `2621.03` | steering only | steering only | `11.702` | steering only | `native_forward_message_after_normalize_local3` | `native_forward_after_transition_local1` | +| sLSTM single T=1 training, stale | `13.37` | `0.0228%` | `4393.8x` | `25.049` | `12.10x` | `native_after_recurrent_message_local0` | `native_after_recurrent_message_local0` | +| Axon single T=1 training, stale | `4.01` | `0.0068%` | `14657.6x` | `81.319` | `39.29x` | `native_after_recurrent_message_local0` | `native_after_recurrent_message_local0` | + +Largest materialized runtime roles still visible in the fresh forward ledgers: + +| Row | Largest compiler runtime role tensors | +| --- | --- | +| sLSTM single T=1 forward | 
`transition_forward_linear_output=2.812 GiB`, `transition_forward_matmul_output=2.250 GiB`, `transition_forward_state_output=0.562 GiB`, `transition_forward_norm_output=0.562 GiB`, `forward_recurrent_msg=0.562 GiB`, `forward_recurrent_hidden_after=0.562 GiB` | +| Axon single T=1 forward | `transition_forward_linear_output=3.500 GiB`, `transition_forward_diag_output=3.500 GiB`, `transition_forward_norm_output=1.750 GiB`, `forward_recurrent_msg=1.750 GiB`, `forward_recurrent_hidden_after=1.750 GiB` | +| Mixed T=1 forward | `transition_forward_linear_output=5.086 GiB`, `transition_forward_matmul_output=2.906 GiB`, `transition_forward_norm_output=1.453 GiB`, `transition_forward_diag_output=1.453 GiB`, `forward_recurrent_msg=1.453 GiB`, `forward_recurrent_hidden_after=1.453 GiB` | +| Axon T=2,K=2 guardrail | `transition_forward_diag_output=2.625 GiB`, `forward_recurrent_msg=0.875 GiB`, `forward_recurrent_hidden_after=0.875 GiB` | + +Diagnosis: + +- `projection_into` was a valid local liveness/throughput win, but it did not + reconstruct the low-live-memory April21 physical execution shape. The active + T=1 path still materializes too many compiler stage boundaries. +- The single-pop forward owner has moved to the message projection/normalization + live set. That is now the direct path to more forward improvement. +- The mixed-pop forward owner remains transition materialization. Mixed-pop + cannot close by proving only singleton single-pop shortcuts. +- The T/K guardrail uses the same kind of owner across local steps, so the next + strategy must stay a streaming step body for `T>1,K>1`, not a terminal-only + T=1 optimization. +- Training remains far worse than forward, but broad training optimization + should remain a parity/liveness guardrail until the forward streaming body + stops materializing the same producer/consumer boundaries. + +Biggest remaining changes before T=1 can match or exceed April21: + +1. **Program-level streaming physical step body v2.** + Build a registered physical strategy where message projection, + normalization, transition input/state update, readout message, readout + projection, and output route are one compiler-owned producer-consumer step. + It must consume primitive rows, executor rows, program access rows, + producer-consumer rows, runtime buffer rows, memory-liveness rows, artifact + rows, and output-route rows directly. It must keep primitive/tensor-role + fingerprints stable. + +2. **Grouped/BMM/GEMM lowering for message and transition producers.** + The useful April21 mechanism to transfer is not code; it is the physical + shape of reducing producer boundaries into larger grouped or batched GEMM + work with less intermediate traffic. The current message path still exposes + weighted-value, output-weight, projected, contiguous, and normalized stages. + Compatible message/transition/readout row groups need a registered strategy + that chooses fewer larger matrix operations and streams the result to + consumers instead of saving each stage. + +3. **Transition output liveness removal, especially Axon diagonal output.** + The ledger still shows multi-GiB `transition_forward_linear_output`, + `transition_forward_matmul_output`, and `transition_forward_diag_output` + tensors. These should become chunk-local workspace or direct consumer inputs + when legality proves they are not required as artifacts/tape. Prior diagonal + direct probes raised peak memory, so the next diagonal move must be inside + the broader streaming body, not a widened shortcut. + +4. 
**Mixed-pop merge/chunk ownership.** + Mixed-pop forward is already transition-owned. Closure requires compiler + merge/chunk rows for multiple message/transition/readout producers and + consumers. If the strategy is legal only for singleton rows, it is useful + steering work, not T=1 closure. + +5. **Reset-present mixed route parity.** + The current reset-present mixed-pop readout parity failure predates + `projection_into`, but it still blocks full acceptance. Keep it out of the + main forward optimization lane unless the next strategy touches reset or + routed readout semantics. + +6. **Forward final-state/frontend/static lifetime after native body moves.** + Some ledgers still name final-state/static/front-end high-water points, but + they are not the biggest actionable owner while native message/transition + stage materialization dominates. Revisit after the streaming body moves the + native owners. + +7. **Training/reverse/reducer liveness after forward compacts.** + Once forward uses the compact physical body, the training lane must close + reverse artifacts/tape, recurrent-message gradients, reducer liveness, and + checkpoint/recompute policy through the same compiler rows. Do not implement + broad reverse shortcuts before the forward physical step is compact. + +8. **Full April21 closure matrix.** + Current steering rows are not closure. T=1 still requires the April21-shaped + surface: `100M/500M/1B`, `B=1024/16384`, sLSTM + Axon, forward + training, + single + mixed populations, reset absent/present, small hidden stress rows, + materialized/no-final-state modes, plus the T/K streaming guardrail. + +Recommended next implementation target: + +- Start `streaming_physical_step_v2` as a registered compiler-owned strategy. +- Initial legal row: reset-free h32 100M B1024 T=1 forward, single-pop sLSTM + and Axon. +- Required follow-up in the same pass family: mixed-pop T=1 forward and small + high-level `T>1,K>1` forward guardrail through the same route identity or an + explicit typed legality blocker. +- Tensors that should stop being fully materialized where legality proves + local consumption: `forward_recurrent_msg`, + `transition_forward_linear_output`, `transition_forward_matmul_output`, + `transition_forward_diag_output`, `transition_forward_norm_output`, + `forward_output_msg`, and route-local `forward_output_cells`. +- Consumers move into the step body through + `message_transition_producer_consumer_rows`, + `readout_message_producer_consumer_rows`, `forward_program_access_rows`, + `forward_output_route_rows`, `runtime_buffer_rows`, + `memory_liveness_rows`, and artifact/reducer rows. +- Keep/narrow/revert rule: keep only if the named owner physically moves in + allocator-stage telemetry, launch shape, or tok/s without parity regression. + Narrow if unsupported rows fail closed through legality. Revert if the patch + only changes labels, adds benchmark/family/shape selectors, copies April21 + code, or leaves the owner unchanged. + +### 2026-05-04 - Rejected Slice: Gated Recurrent Gate Add-In-Place + +Hypothesis: + +- April21 mechanism being semantically transferred: reduce chunk-local + transition intermediates inside the streaming physical step body. +- Current compiler products consumed: gated transition primitive rows, native + callable rows, tensor bindings, runtime-buffer rows, and existing transition + row-group liveness. +- Semantic row fingerprints expected to stay stable: no declaration, primitive + row, tensor role, output-route, reset, or tape change. 
+- Proposed tensor lifetime move: avoid allocating a separate + `recurrent_gate_logits_chunk` by adding recurrent gate matmul output into the + feedforward `gate_logits_chunk` before recurrence. +- Keep/narrow/revert rule: keep only if representative T=1 forward throughput + or named owner improves without peak-memory or parity regression. + +Validation before audit: + +```text +uv run python scripts/validate_fabric_generated_catalogs.py +# passed + +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_message_transition_producer_consumer_rows_are_compiler_owned_legality_rows \ + tests/test_fabric_backend_boundaries.py::test_temporal_engine_table_sources_do_not_use_cell_or_benchmark_route_selectors \ + --tb=short +# 2 passed + +CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_gated_gate_add_20260504 \ + TRITON_CACHE_DIR=/tmp/cortical_triton_gated_gate_add_20260504 \ + timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells \ + --tb=short +# 1 passed + +CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_gated_gate_add_20260504 \ + TRITON_CACHE_DIR=/tmp/cortical_triton_gated_gate_add_20260504 \ + timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store \ + --tb=short +# 1 passed +``` + +Representative audit artifact: + +```text +tmp/fabric_audits/partials/2026-05-04/t1_gated_gate_add_forward_h32_100m_b1024 +``` + +Measured result versus the accepted `projection_into` rows: + +| Row | Accepted tok/s | Probe tok/s | Accepted peak | Probe peak | Decision | +| --- | ---: | ---: | ---: | ---: | --- | +| sLSTM single T=1 forward | `14058.92` | `10797.59` | `4.394 GiB` | `4.394 GiB` | revert | +| Axon single T=1 forward | `8579.00` | `6616.60` | `8.776 GiB` | `8.776 GiB` | revert | + +Decision: + +- Reverted this slice. It preserved semantics and passed narrow parity, but it + materially regressed the representative T=1 forward throughput and produced + no peak-memory win. +- The result is useful evidence: removing one small chunk-local gated tensor is + not enough, and the next step should not continue isolated one-buffer tweaks. + The remaining owner still points at the broader streaming physical step body: + reduce full producer boundaries and data traffic across message, transition, + readout, and output routes through compiler-owned rows. + +### 2026-05-04 - Hypothesis: Fixed-Slot Stream-Transition Producer Consumer + +Hypothesis: + +- Replace the current direct message-to-transition streaming implementation + with a registered native producer-consumer body that computes weighted value, + message output projection, message normalization, and transition input + projection inside the same legal route. +- This targets the single-pop owner + `native_forward_message_after_normalize_local0` and the broader producer + boundary that still materializes message-stage chunks. + +April21 mechanism being semantically transferred: + +- Low-live-memory physical execution shape: message producer feeds transition + consumer without a fully materialized projected-message boundary. +- This is a semantic transfer only. No April21 code is copied. + +Lane: + +- Throughput/native implementation over existing compiler rows. +- Semantic delta: none. + +Rows/fingerprints expected to stay stable: + +- Primitive rows. +- Forward executor rows. +- Forward executor binding rows. +- Program tensor binding rows. 
+- Native callable binding schema rows. +- Message-transition producer-consumer rows. +- Output route rows. +- Artifact/tape rows. + +Compiler products consumed directly: + +- `message_transition_producer_consumer_rows`. +- `forward_program_access_rows`. +- `program_tensor_binding_rows`. +- `runtime_buffer_rows`. +- `memory_liveness_rows`. +- fixed-slot message native strategy rows. +- transition linear native callable output rows. + +Implementation slice: + +- Add a dynamic-shape warp-level kernel under + `run_fixed_slot_context_stream_transition_input`. +- The fused path is selected only when the existing compiler-owned direct + message-to-transition route has already been selected and shape/shared-memory + constraints are legal. +- Unsupported rows fall back to the existing materialized chunk implementation; + no benchmark, family, hidden-size, or single/mixed-pop selector is added. + +Tensors intended to stop being fully materialized on legal rows: + +- route-local `message_chunk`; +- route-local `weighted_value`; +- projected-message normalization output before transition input projection. + +Smallest representative row: + +- h32 100M B1024 T=1 K=1 forward, sLSTM + Axon, single population. + +Follow-up guardrails: + +- Focused CUDA fused-forward parity smoke. +- Terminal-loss artifact-store smoke. +- Mixed-pop T=1 forward through the same route or existing typed blocker. +- Small high-level `T>1,K>1` forward guardrail through the same streaming route. + +Keep/narrow/revert rule: + +- Keep only if the single-pop owner physically moves in native memory-stage + telemetry or tok/s improves without peak-memory regression. +- Narrow if the fused kernel is useful only for a typed legal subset and + unsupported rows remain compiler-owned. +- Revert if throughput regresses, peak memory grows, parity fails, or the named + owner remains unchanged. + +Result: + +- Rejected and reverted. +- Artifact: + `tmp/fabric_audits/partials/2026-05-04/t1_stream_transition_fused_body_h32_100m_b1024`. +- Validation before representative audit passed: + - `uv run python scripts/validate_fabric_generated_catalogs.py` + - focused source/plan guardrails: `3 passed` + - fused-forward CUDA parity smoke: `1 passed` + - terminal-loss artifact-store CUDA smoke: `1 passed` + +Measured result versus the accepted `projection_into` rows: + +| Row | Accepted tok/s | Probe tok/s | Accepted peak | Probe peak | Decision | +| --- | ---: | ---: | ---: | ---: | --- | +| sLSTM single T=1 forward | `14058.92` | `6128.77` | `4.394 GiB` | `4.394 GiB` | revert | +| Axon single T=1 forward | `8579.00` | `6632.75` | `8.776 GiB` | `8.778 GiB` | revert | + +Interpretation: + +- The fused warp-level per-row body reduced one local materialization shape but + destroyed arithmetic efficiency. This confirms that the next useful direction + should preserve large GEMM/BMM structure while reducing producer boundaries, + not replace GEMM-shaped work with per-row scalar loops. +- The next plan should focus on compiler-owned grouped/BMM/GEMM lowering and + liveness around those calls: precompose legal linear maps where semantics + allow it, keep normalization boundaries explicit, and stream only the + boundaries that do not sacrifice matrix throughput. + +### 2026-05-04 - Deep Dive: Remaining T=1 Throughput Work After Rejected Scalar Fusion + +Accepted current baseline: + +- Use `t1_output_route_projection_into_h32_100m_b1024` and + `t1_output_route_projection_into_mixed_h32_100m_b1024` as the active + accepted rows. 
+- The later gated-add and stream-transition fused-body probes are rejected + evidence, not the current baseline. + +Current accepted owner table: + +| Row | Current tok/s | Current peak | April21 tok/s | April21 peak | First/peak owner | +| --- | ---: | ---: | ---: | ---: | --- | +| sLSTM single T=1 forward | `14058.92` | `4.394 GiB` | `58732.71` | `2.07 GiB` | `native_forward_message_after_normalize_local0` | +| Axon single T=1 forward | `8579.00` | `8.776 GiB` | `58732.71` | `2.07 GiB` | `native_forward_message_after_normalize_local0` | +| sLSTM mixed T=1 forward | `7548.00` | `9.121 GiB` | `58732.71` | `2.07 GiB` | `native_forward_after_transition_local0` | +| Axon mixed T=1 forward | `7540.89` | `10.152 GiB` | `58732.71` | `2.07 GiB` | `native_forward_after_transition_local0` | + +Largest accepted live roles: + +- sLSTM single: `transition_forward_linear_output` (`2.812 GiB`), + `transition_forward_matmul_output` (`2.250 GiB`), then + `transition_forward_state_output`, `transition_forward_norm_output`, + `forward_recurrent_msg`, and `forward_recurrent_hidden_after` + (`0.562 GiB` each). +- Axon single: `transition_forward_linear_output` (`3.500 GiB`), + `transition_forward_diag_output` (`3.500 GiB`), then + `transition_forward_norm_output`, `forward_recurrent_msg`, and + `forward_recurrent_hidden_after` (`1.750 GiB` each). +- Mixed rows combine both families and move the first peak to transition, + with `transition_forward_linear_output`, `transition_forward_matmul_output`, + `transition_forward_norm_output`, and `transition_forward_diag_output` + all live around the same physical step. + +Diagnosis: + +- Current T=1 forward is still a compiler-correct registered path, but it is + not yet reconstructing the low-live-memory physical execution shape that the + April21 row demonstrated. +- The gap is not one remaining temporary. The active path still exposes too + many compiler stage boundaries as full tensors: projected message, + transition input/linear output, transition matmul/diag output, normalized + transition output, recurrent hidden after, and output-route intermediates. +- Narrow aliases and per-row scalar fusion have not worked. They either left + the peak owner unchanged or regressed throughput by replacing dense matrix + work with inefficient scalar loops. +- The useful transfer from April21 is therefore physical shape, not code: + message -> transition/readout -> output with minimal live stage boundaries, + while preserving compiler-owned rows, bindings, output routes, artifact/tape + policy, and T/K streaming legality. + +Biggest remaining changes before April21-class T=1 forward: + +1. Build a GEMM-preserving streaming physical step strategy. + - The strategy must keep dense affine/projection work as large GEMM/BMM or + grouped-GEMM-shaped operations. + - It should reduce materialized producer boundaries around those dense calls, + not replace the dense calls with per-receiver scalar kernels. + - Legal linear-only producer/consumer maps may be precomposed by compiler + rows. Nonlinear boundaries such as normalization/gating must remain + explicit unless the strategy proves an equivalent fused implementation. + +2. Make transition liveness executable, not only descriptive. + - The largest current buffers are transition-side outputs. 
+ - `transition_forward_linear_output`, `transition_forward_matmul_output`, + `transition_forward_diag_output`, and `transition_forward_norm_output` + need workspace/lifetime ownership from compiler rows so consumers can read + them before they become long-lived logical artifacts. + - Artifacts required for training must be explicitly routed into tape rows; + forward-only rows should not retain full transition products. + +3. Generalize direct producer-consumer rows for mixed populations. + - Single-pop rows already select a direct message-to-transition route, but + mixed rows still peak at transition because merge/chunk ownership is not + compact enough. + - The next mixed-pop work is not a separate benchmark path. It should reuse + the same streaming step body with explicit chunk/merge/output-route rows. + +4. Preserve the T/K streaming contract. + - T=1 remains the first acceptance row, but only as one invocation of the + same streaming step body used by `T>1,K=1` and `T>1,K>1`. + - The follow-up guardrail must prove no benchmark-owned time chunking, + Python replay, detach policy, hidden route identity, or full + `[T, cells, state]` materialization appears. + +5. Keep training as a bounded guardrail until forward shape is compact. + - Training is still far behind, but optimizing reverse before the forward + step body is compact will chase the wrong owners. + - After forward liveness is fixed, the next high-impact training work is + reverse/tape/reducer liveness using the same artifact routes. + +Next implementation direction: + +- Define a registered streaming physical step strategy over existing semantic + rows that names the compiler access rows, producer-consumer rows, + output-route rows, memory-liveness rows, runtime-buffer rows, artifact rows, + and reducer/tape rows it consumes. +- Start with the single-pop h32 100M B1024 T=1 forward row and require the + same strategy to pass a small high-level `T>1,K>1` guardrail before treating + it as an accepted direction. +- Keep/narrow/revert rule: keep only if the owner table moves away from full + transition/message materialization while throughput improves or peak memory + drops without parity regressions. + +### 2026-05-04 - Plan: GEMM-Preserving Streaming Transition Workspace + +Classifier: + +- Lane: throughput strategy / native implementation. +- Semantic delta: none. +- April21 usage: evidence for physical execution shape only; no code copy. + +Hypothesis: + +- The next recoverable gap is not another one-buffer alias. It is the lack of + a compiler-owned streaming workspace around the message-to-transition + producer-consumer route. +- The accepted path already selects `message_transition_producer_consumer_rows`, + but `run_fixed_slot_context_stream_transition_input` still creates a + route-local projected message chunk and then transition forward allocates + large transition-side products. We need to preserve GEMM/BMM shape while + shortening those lifetimes. + +Rows/fingerprints expected to stay stable: + +- Primitive rows. +- Forward executor rows. +- Forward executor binding rows. +- Program tensor binding rows. +- Forward program access rows. +- Message-transition producer-consumer rows. +- Output route rows. +- Artifact/tape rows. + +Compiler products consumed directly: + +- `forward_program_access_rows`. +- `message_transition_producer_consumer_rows`. +- `program_tensor_binding_rows`. +- `runtime_buffer_rows`. +- `memory_liveness_rows`. +- `native_callable_output_rows`. +- `forward_artifact_route_rows`. 
+- `forward_output_route_rows`. + +Implementation plan: + +1. Add a transition streaming-workspace contract. + - Use existing runtime-buffer/liveness rows first; add only row fields or + typed validation needed to express workspace ownership. + - The contract must say which native callable outputs are logical artifacts + and which are route-local workspace. + - Unsupported rows remain typed blockers before launch. + +2. Replace message-to-transition chunk materialization with workspace reuse, + while preserving GEMM. + - Keep the current weighted-value kernel and dense affine GEMM. + - Keep message normalization explicit. + - Route projected/normalized message through compiler-owned workspace, then + immediately into transition input projection. + - Do not add per-receiver scalar projection loops or family/shape selectors. + +3. Shorten transition-side output lifetimes. + - For forward-only rows, make `transition_forward_linear_output`, + `transition_forward_matmul_output` / `transition_forward_diag_output`, and + `transition_forward_norm_output` route-local workspace unless an artifact + or downstream binding requires them. + - For training/artifact rows, route only required tape tensors through + artifact rows; do not retain full products by default. + +4. Keep mixed-pop on the same route. + - If mixed rows need explicit merge/chunk ownership, express it through + compiler rows rather than singleton checks. + - The same streaming body must handle single and mixed populations through + row spans/chunks, not separate benchmark paths. + +5. Validate in the smallest useful sequence. + - First: source/catalog/static guardrails that row fingerprints are stable + and the active route consumes the intended rows. + - Then: focused CUDA fused-forward parity smoke. + - Then representative current-code row: + `tmp/fabric_audits/partials/2026-05-04/t1_gemm_streaming_transition_workspace_h32_100m_b1024`. + - Then mixed row: + `tmp/fabric_audits/partials/2026-05-04/t1_gemm_streaming_transition_workspace_mixed_h32_100m_b1024`. + - Then T/K guardrail: + `tmp/fabric_audits/partials/2026-05-04/tk_gemm_streaming_transition_workspace_guardrail_h32_100m_b128_t2_k2`. + +Keep/narrow/revert rule: + +- Keep if the named owner moves away from full message/transition + materialization and either tok/s improves or peak memory drops without parity + regressions. +- Narrow if only one legal route shape benefits and the blocker is represented + as compiler legality metadata. +- Revert if throughput regresses like the scalar fusion probe, peak memory + grows, row fingerprints change, or the same owner remains. + +Result: + +- Rejected and reverted. +- Artifact: + `tmp/fabric_audits/partials/2026-05-04/t1_gemm_streaming_transition_workspace_h32_100m_b1024`. 
+- Validation before audit: + - `uv run python scripts/validate_fabric_generated_catalogs.py` + - `uv run pytest -q tests/test_fabric_backend_plan.py::test_message_transition_producer_consumer_rows_are_compiler_owned_legality_rows --tb=short` + - `uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings --tb=short` + - `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_forward_transition_access_uses_compiler_program_access_rows tests/test_fabric_backend_plan.py::test_message_transition_producer_consumer_rows_are_compiler_owned_legality_rows --tb=short` + - `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells --tb=short` + - `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store --tb=short` + +Measured result versus accepted `projection_into`: + +| Row | Accepted tok/s | Probe tok/s | Accepted peak | Probe peak | Decision | +| --- | ---: | ---: | ---: | ---: | --- | +| sLSTM single T=1 forward | `14058.92` | `10972.27` | `4.394 GiB` | `4.394 GiB` | revert | +| Axon single T=1 forward | `8579.00` | `6602.65` | `8.776 GiB` | `8.776 GiB` | revert | + +Interpretation: + +- Keeping the gate-affine BMM result in a packed temporary and consuming it + from a packed recurrence kernel preserved the semantic row path, but it did + not move the named runtime role owners. `transition_forward_linear_output`, + `transition_forward_matmul_output` / `transition_forward_diag_output`, + `forward_recurrent_msg`, and `forward_recurrent_hidden_after` stayed at the + same sizes. +- Throughput regressed materially, likely because the recurrence now reads the + BMM output in a less favorable layout while the high-water allocations remain. +- Do not retry packed-gate layout as a standalone optimization. The next useful + direction must move an actual full producer boundary or use a larger + compiler-planned GEMM/grouped-GEMM producer-consumer transform whose owner + movement is visible before representative audit. + +### 2026-05-04 - Deep Dive: Remaining T=1 Work After GEMM Workspace Rejection + +Accepted current baseline: + +- Keep `projection_into` as the accepted active row: + `tmp/fabric_audits/partials/2026-05-04/t1_output_route_projection_into_h32_100m_b1024`. +- Treat `t1_stream_transition_fused_body_h32_100m_b1024` and + `t1_gemm_streaming_transition_workspace_h32_100m_b1024` as rejected evidence. 
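+
+Every probe row in this document is judged by the same keep/narrow/revert
+comparison against the accepted `projection_into` row. A minimal sketch of that
+check, assuming hypothetical field names for the partial-audit summaries (the
+real artifacts under `tmp/fabric_audits/partials/` are not constrained to this
+layout):
+
+```python
+from dataclasses import dataclass
+
+
+@dataclass
+class RowSummary:
+    """Hypothetical per-row summary pulled from a partial-audit artifact."""
+
+    tok_per_s: float
+    peak_gib: float
+    owner_moved: bool  # a named allocator/stage owner physically moved or shrank
+    parity_ok: bool  # outputs, exposed state, and gradients green for the touched owner
+    fingerprints_stable: bool  # primitive/tensor-role row fingerprints unchanged
+
+
+def decide(accepted: RowSummary, probe: RowSummary, rel_tol: float = 0.01) -> str:
+    """Keep/revert comparison used for the T=1 probe rows; the narrow outcome is
+    omitted because it depends on typed legality metadata, not on these numbers."""
+    if not probe.parity_ok or not probe.fingerprints_stable:
+        return "revert"
+    if probe.peak_gib > accepted.peak_gib * (1.0 + rel_tol):
+        return "revert"  # memory is a closure gate
+    faster = probe.tok_per_s > accepted.tok_per_s * (1.0 + rel_tol)
+    leaner = probe.peak_gib < accepted.peak_gib * (1.0 - rel_tol)
+    if probe.owner_moved and (faster or leaner):
+        return "keep"
+    # A metadata label or a flat result without owner movement is not evidence.
+    return "revert"
+
+
+# The rejected GEMM-workspace probe for the Axon single T=1 row, from the table above.
+accepted = RowSummary(8579.00, 8.776, owner_moved=False, parity_ok=True, fingerprints_stable=True)
+probe = RowSummary(6602.65, 8.776, owner_moved=False, parity_ok=True, fingerprints_stable=True)
+print(decide(accepted, probe))  # -> "revert"
+```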
+ +Current accepted owner table: + +| Row | Current tok/s | Current peak | April21 tok/s | April21 peak | Observed peak stage | +| --- | ---: | ---: | ---: | ---: | --- | +| sLSTM single T=1 forward | `14058.92` | `4.394 GiB` | `58732.71` | `2.07 GiB` | `native_forward_message_after_normalize_local0` | +| Axon single T=1 forward | `8579.00` | `8.776 GiB` | `58732.71` | `2.07 GiB` | `native_forward_message_after_normalize_local0` | +| sLSTM mixed T=1 forward | `7548.00` | `9.121 GiB` | `58732.71` | `2.07 GiB` | `native_forward_after_transition_local0` | +| Axon mixed T=1 forward | `7540.89` | `10.152 GiB` | `58732.71` | `2.07 GiB` | `native_forward_after_transition_local0` | + +Important distinction: + +- The single-pop high-water is observed during message normalization/readout + stages, but the logical runtime-role bytes live at that point are dominated + by transition products and recurrent public/message banks. +- Therefore the next owner is not simply "the message kernel." The useful owner + is the lifetime overlap between message-stage work and full transition/runtime + buffers. + +Largest accepted runtime roles: + +- sLSTM single: + - `transition_forward_linear_output`: `2.812 GiB` + - `transition_forward_matmul_output`: `2.250 GiB` + - `forward_recurrent_hidden_after`: `0.562 GiB` + - `forward_recurrent_msg`: `0.562 GiB` + - `transition_forward_norm_output`: `0.562 GiB` + - `transition_forward_state_output`: `0.562 GiB` +- Axon single: + - `transition_forward_diag_output`: `3.500 GiB` + - `transition_forward_linear_output`: `3.500 GiB` + - `forward_recurrent_hidden_after`: `1.750 GiB` + - `forward_recurrent_msg`: `1.750 GiB` + - `transition_forward_norm_output`: `1.750 GiB` +- Mixed: + - `transition_forward_linear_output`: `5.086 GiB` + - `transition_forward_matmul_output`: `2.906 GiB` + - `forward_recurrent_hidden_after`: `1.453 GiB` + - `forward_recurrent_msg`: `1.453 GiB` + - `transition_forward_diag_output`: `1.453 GiB` + - `transition_forward_norm_output`: `1.453 GiB` + +Rejected directions now ruled out: + +- Per-row scalar stream-transition fusion: peak unchanged and throughput + regressed to `6128.77` sLSTM / `6632.75` Axon tok/s. +- Packed gate-logit BMM output consumed by a packed recurrence kernel: peak + unchanged and throughput regressed to `10972.27` sLSTM / `6602.65` Axon + tok/s. +- Standalone alias/no-copy tweaks that do not classify allocator high-water or + move storage lifetime. + +What remains before T=1 can match or exceed April21: + +1. Real transition runtime-buffer liveness. + - Existing `memory_liveness_rows` and deferred-local runtime-buffer support + are present, but the active row still reports full logical transition + products. + - The next change must make `transition_forward_linear_output`, + `transition_forward_matmul_output` / `transition_forward_diag_output`, and + `transition_forward_norm_output` stop overlapping with message/readout + high-water for forward-only rows. + - This likely means a row-group transition executor that produces only the + public output/carry required by downstream rows and treats intermediate + transition products as chunk-local workspace unless artifact/tape rows + require them. + +2. Message-to-transition/readout fanout without full recurrent-message bank. + - `run_fixed_slot_context_stream_transition_input` still creates a + route-local `message_chunk`, and the regular message carrier still writes + a full `forward_recurrent_msg` unless the direct row bypasses it. 
+ - The larger strategy should stream normalized projected-message chunks into + both transition input and readout/output consumers through compiler route + rows, instead of materializing a full recurrent-message bank for + forward-only rows. + - Dense projection should remain GEMM/BMM-shaped. Normalization remains an + explicit boundary unless a fused strategy proves equivalence. + +3. Mixed-pop chunk/merge ownership. + - Mixed rows are currently worse than single-pop and peak after transition. + - This is the same problem at a larger row group: transition products for + multiple populations overlap because chunk/merge ownership is not compact + enough. + - The fix must reuse the same streaming step body with explicit + chunk/merge/output-route rows, not introduce a mixed-pop special path. + +4. T/K streaming proof. + - Any accepted T=1 physical strategy must run as the body for `T>1,K=1` and + `T>1,K>1`. + - The guardrail remains a small high-level `T>1,K>1` row proving no new + route identity, benchmark time chunking, Python replay, detach policy, or + full `[T, cells, state]` materialization. + +5. Training after forward compaction. + - Training remains far behind, but broad reverse optimization before the + forward body is compact will chase inflated tape/reducer owners. + - Once forward live products are compact, reopen reverse/tape/reducer + liveness through the same artifact routes. + +Highest-impact next plan: + +- Do not retry a layout-only or scalar-kernel probe. +- Define a compiler-row-owned transition row-group executor contract with + explicit workspace-vs-artifact outputs. +- First prove with a small mechanism probe that at least one of the full + transition runtime roles disappears from forward-only `projection_into` rows. +- Only then run the representative h32 100M B1024 T=1 forward audit. + +### 2026-05-04 - Plan: Transition Row-Group Workspace Contract + +Classifier: + +- Lane: throughput strategy / liveness mechanism probe. +- Semantic delta: none. +- April21 usage: physical-shape evidence only. +- Do not start by changing math kernels. Start by proving the compiler can make + transition products workspace-only for forward-only streaming rows. + +Boundary manifest: + +- Surface: registered fused forward program, transition row-group execution. +- Declaration/spec owner: unchanged graph/cell/message/readout declarations. +- Primitive rows: unchanged transition row groups + (`linear -> gated/diag -> norm/output`). +- Tensor/binding owner: `program_tensor_binding_rows`, + `native_callable_output_rows`, `forward_program_access_rows`. +- Route owner: `message_transition_producer_consumer_rows`, + `forward_output_route_rows`, `forward_artifact_route_rows`. +- Memory/liveness owner: `memory_liveness_rows`, `runtime_buffer_rows`, and + physical strategy rows. +- Forward owner: `registered_temporal_fused_forward_program_cuda`. +- Backward/reducer owner: unchanged in this slice; artifact/training rows keep + existing materialized tape behavior. + +Hypothesis: + +- The current single-pop peak is observed during message stages because full + transition/runtime products are already live or reserved by then. +- If forward-only streaming rows make transition intermediate outputs + workspace-only instead of logical runtime products, the owner table should + show at least one of these roles shrink or disappear: + `transition_forward_linear_output`, + `transition_forward_matmul_output`, + `transition_forward_diag_output`, + `transition_forward_norm_output`. 
+- If those roles do not move, do not run a representative audit; reject the + slice as metadata-only. + +Rows/fingerprints expected to stay stable: + +- Primitive row fingerprint. +- Forward executor row fingerprint. +- Forward executor binding fingerprint. +- Program tensor binding fingerprint. +- Message-transition producer-consumer row fingerprint. +- Output route fingerprint. +- Artifact route fingerprint. + +Implementation phases: + +1. Add an explicit workspace-vs-artifact contract. + - Use existing row products first; add only typed validation or metadata + needed to prove which native callable outputs are workspace-only. + - Legal only when: + - `return_reverse_artifacts == false`; + - `return_final_program_tensors == false`; + - no artifact/tape route requires the intermediate; + - downstream primitive consumer is inside the same transition row group. + - Training/artifact rows stay on the existing materialized path. + +2. Make transition row-group outputs actually local. + - For gated rows, the input projection, gate affine output, recurrent + matmul output, recurrence private output, and norm output should be + local/chunk workspace unless a later binding or artifact route requires + them. + - For diagonal rows, the input projection, diagonal preprojection/output + projection, and norm/public output should follow the same rule. + - Continue using GEMM/BMM where current code already does; do not introduce + per-receiver scalar projection loops. + +3. Clear or avoid program bindings for dead transition intermediates. + - Use the existing binding/liveness machinery rather than hidden C++ + side channels. + - After each primitive consumer runs, dead input/output bindings should be + cleared or never materialized as logical runtime buffers. + - Public state output and output-route products remain visible through + compiler access/output rows. + +4. Add a mechanism probe before representative audit. + - Artifact path: + `tmp/fabric_audits/partials/2026-05-04/t1_transition_workspace_mechanism_probe_h32_100m_b1024`. + - Run a small or representative forward-only row only far enough to inspect + runtime-role bytes and memory-stage owner movement. + - Success means at least one full transition role shrinks/disappears and + no new unclassified allocator owner appears. + +5. Representative gates only after mechanism success. + - Single-pop h32 100M B1024 T=1 forward: + `tmp/fabric_audits/partials/2026-05-04/t1_transition_workspace_rowgroup_h32_100m_b1024`. + - Mixed h32 100M B1024 T=1 forward: + `tmp/fabric_audits/partials/2026-05-04/t1_transition_workspace_rowgroup_mixed_h32_100m_b1024`. + - T/K guardrail: + `tmp/fabric_audits/partials/2026-05-04/tk_transition_workspace_rowgroup_guardrail_h32_100m_b128_t2_k2`. + +Tests and guardrails: + +- Catalog validation. +- Source guardrail that transition workspace ownership is row/binding/liveness + owned and does not use family, hidden-size, benchmark, or single/mixed + selectors. +- CUDA fused-forward smoke. +- Terminal-loss artifact-store smoke to prove training/artifact rows still use + safe materialized tape. + +Keep/narrow/revert rule: + +- Keep only if the mechanism probe physically moves role bytes or named native + stage high-water, then representative rows improve tok/s or peak memory. +- Narrow if only a typed legal subset can use workspace-only transition + outputs and unsupported cases fail closed before launch. 
+- Revert if role bytes remain unchanged, peak memory grows, throughput + regresses, row fingerprints change, or the change only relabels metadata. + +### 2026-05-04 - Result: Transition Row-Group Workspace Probe Rejected + +Artifact paths: + +- Accepted current comparison row: + `tmp/fabric_audits/partials/2026-05-04/t1_output_route_projection_into_h32_100m_b1024`. +- Mechanism probe: + `tmp/fabric_audits/partials/2026-05-04/t1_transition_workspace_mechanism_probe_h32_100m_b1024`. +- Selector-disabled control: + `tmp/fabric_audits/partials/2026-05-04/t1_transition_workspace_selector_disabled_h32_100m_b1024`. + +Probe result: + +| Row | Accepted tok/s | Probe tok/s | Accepted peak | Probe peak | Decision | +| --- | ---: | ---: | ---: | ---: | --- | +| sLSTM single T=1 forward | `14058.92` | `10683.08` | `4.394 GiB` | `4.394 GiB` | reject | +| Axon single T=1 forward | `8579.00` | `6576.92` | `8.776 GiB` | `8.776 GiB` | reject | + +The probe did not satisfy the keep rule: + +- Planned runtime role bytes stayed unchanged for the named transition roles: + - sLSTM: + - `transition_forward_linear_output`: `3019898880` + - `transition_forward_matmul_output`: `2415919104` + - `transition_forward_norm_output`: `603979776` + - `transition_forward_state_output`: `603979776` + - `forward_recurrent_hidden_after`: `603979776` + - `forward_recurrent_msg`: `603979776` + - Axon: + - `transition_forward_linear_output`: `3758096384` + - `transition_forward_diag_output`: `3758096384` + - `transition_forward_norm_output`: `1879048192` + - `forward_recurrent_hidden_after`: `1879048192` + - `forward_recurrent_msg`: `1879048192` +- The current allocated peak stage remained + `native_forward_message_after_normalize_local0` for both rows. +- Peak memory did not improve. +- Throughput regressed materially. + +Important accounting correction: + +- `fabric_compiler_runtime_role_bytes.*` reports planned logical role bytes, + including roles that may be `deferred_local`. +- Future liveness probes must compare `estimated_allocated_buffer_bytes`, + allocated/max stage telemetry, storage identity/lifetime evidence, and + unclassified allocator growth. Planned role bytes alone are not proof that a + tensor is physically live. + +Selector-disabled control: + +- Disabling the transition row-group selector is invalid as a revert path. +- sLSTM fails before launch with + `registered transition linear input references an empty program tensor`. +- Axon runs but regresses to about `6686.97` tok/s and `12.501 GiB`. +- Reason: the compiler-owned message-to-transition streaming route intentionally + uses an empty aggregate-input sentinel. The row-group consumer is the legal + owner that consumes the streamed transition input. The primitive-loop fallback + cannot consume that route. + +Decision: + +- Reject the transition row-group workspace probe as a throughput optimization. +- Do not run the planned representative single-pop, mixed-pop, or T/K audits for + this mechanism. +- Keep the compiler-required transition row-group route active; it is required + for direct stream-transition legality. +- Do not remove or disable row-group execution as a performance revert. +- Do not continue one-buffer transition role-byte tweaks unless the next + hypothesis names the physical allocator/native-stage owner and predicts a real + stage movement. + +Next owner: + +- The next forward owner is not another metadata-only transition role-byte + change. 
+- Work should either: + - add compiler-row-owned cost/workspace policy for row-group launch shape so + the current 256 MiB chunking does not destroy GEMM/BMM efficiency when the + peak owner is elsewhere; or + - attack the message-stage native temporary owner directly with a registered + producer-consumer/GEMM-grouping strategy that names the tensors that stop + being fully materialized and the downstream consumers that move inside the + streaming step. +- Both directions must keep primitive rows, route rows, artifact rows, and + output routes stable and must remain usable as the T/K streaming step body. + +Validation after restoring the required row-group route: + +- `uv run python scripts/validate_fabric_generated_catalogs.py`: passed. +- `uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells --tb=short`: + passed, `2 passed`. + +### 2026-05-04 - Deep Dive: Remaining T=1 Throughput After Row-Group Workspace Rejection + +Scope: + +- Analysis only. No implementation or optimization patch in this section. +- Accepted current baseline remains: + - `tmp/fabric_audits/partials/2026-05-04/t1_output_route_projection_into_h32_100m_b1024` + - `tmp/fabric_audits/partials/2026-05-04/t1_output_route_projection_into_mixed_h32_100m_b1024` + - `tmp/fabric_audits/partials/2026-05-04/tk_output_route_projection_into_guardrail_h32_100m_b128_t2_k2` +- Rejected evidence remains rejected: + - `t1_gated_gate_add_forward_h32_100m_b1024` + - `t1_stream_transition_fused_body_h32_100m_b1024` + - `t1_gemm_streaming_transition_workspace_h32_100m_b1024` + - `t1_transition_workspace_mechanism_probe_h32_100m_b1024` +- April21 comparison row remains `h32_t1_bxparams`: + `58732.71 tok/s`, `2.07 GiB`. 
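+
+Per the accounting correction above, planned role bytes alone cannot accept or
+reject a liveness probe. A minimal sketch of the comparison the next probes
+should make, assuming a hypothetical JSON layout for the partial-audit
+artifacts (`summary.json`, `max_stage_owner`, and the surrounding structure are
+invented names; only `fabric_compiler_runtime_role_bytes` and
+`estimated_allocated_buffer_bytes` come from this document):
+
+```python
+import json
+from pathlib import Path
+
+GIB = 1024 ** 3
+
+
+def load_summary(path: Path) -> dict:
+    """Load one partial-audit summary; the layout assumed here is not the real schema."""
+    return json.loads(path.read_text())
+
+
+def liveness_probe_moved(accepted: dict, probe: dict, min_delta_gib: float = 0.25) -> bool:
+    """True only if physically allocated evidence moved, not just planned role bytes."""
+    planned_before = accepted.get("fabric_compiler_runtime_role_bytes", {})
+    planned_after = probe.get("fabric_compiler_runtime_role_bytes", {})
+    alloc_before = accepted.get("estimated_allocated_buffer_bytes", 0)
+    alloc_after = probe.get("estimated_allocated_buffer_bytes", 0)
+    owner_before = accepted.get("max_stage_owner", "")  # hypothetical key
+    owner_after = probe.get("max_stage_owner", "")
+
+    planned_shrank = any(
+        planned_after.get(role, 0) < planned_bytes
+        for role, planned_bytes in planned_before.items()
+    )
+    allocated_shrank = (alloc_before - alloc_after) >= min_delta_gib * GIB
+    owner_changed = owner_before != owner_after
+
+    if planned_shrank and not (allocated_shrank or owner_changed):
+        # Exactly the metadata-only outcome that rejected the row-group probe.
+        return False
+    return allocated_shrank or owner_changed
+
+
+# Usage sketch against two artifact directories named earlier in this document.
+base = Path("tmp/fabric_audits/partials/2026-05-04")
+accepted = load_summary(base / "t1_output_route_projection_into_h32_100m_b1024" / "summary.json")
+probe = load_summary(base / "t1_transition_workspace_mechanism_probe_h32_100m_b1024" / "summary.json")
+print(liveness_probe_moved(accepted, probe))
+```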
+ +Current accepted owner table: + +| Row | Current tok/s | % April21 | Slowdown | Peak GiB | Mem x April21 | Current allocated owner | First high-water | +| --- | ---: | ---: | ---: | ---: | ---: | --- | --- | +| sLSTM single T=1 forward | `14058.92` | `23.94%` | `4.18x` | `4.394` | `2.12x` | `native_forward_message_after_normalize_local0` | `native_forward_after_readout_message_local0` | +| Axon single T=1 forward | `8579.00` | `14.61%` | `6.85x` | `8.776` | `4.24x` | `native_forward_message_after_normalize_local0` | `native_forward_after_readout_message_local0` | +| sLSTM mixed T=1 forward | `7548.00` | `12.85%` | `7.78x` | `9.121` | `4.41x` | `native_forward_after_transition_local0` | `native_forward_after_transition_local0` | +| Axon mixed T=1 forward | `7540.89` | `12.84%` | `7.79x` | `10.152` | `4.90x` | `native_forward_after_transition_local0` | `native_forward_after_transition_local0` | +| sLSTM T=2,K=2 forward guardrail, B=128 | `2663.16` | steering | steering | `2.852` | steering | `native_forward_message_after_normalize_local3` | `native_forward_after_transition_local1` | +| Axon T=2,K=2 forward guardrail, B=128 | `2621.03` | steering | steering | `11.702` | steering | `native_forward_message_after_normalize_local3` | `native_forward_after_transition_local1` | + +Stale training guardrail: + +| Row | Latest known tok/s | % April21 | Slowdown | Peak GiB | Mem x April21 | Owner | +| --- | ---: | ---: | ---: | ---: | ---: | --- | +| sLSTM single T=1 training | `13.37` | `0.02%` | `4392.87x` | `25.049` | `12.10x` | `native_after_recurrent_message_local0` | +| Axon single T=1 training | `4.01` | `0.01%` | `14646.56x` | `81.319` | `39.28x` | `native_after_recurrent_message_local0` | + +Training rows are stale relative to `projection_into`; use them as liveness +warnings only. Forward remains the main optimization lane until the physical +step body is compact. + +Largest accepted forward runtime roles: + +- sLSTM single: + - `transition_forward_linear_output`: `2.812 GiB` + - `transition_forward_matmul_output`: `2.250 GiB` + - `transition_forward_state_output`: `0.562 GiB` + - `transition_forward_norm_output`: `0.562 GiB` + - `forward_recurrent_msg`: `0.562 GiB` + - `forward_recurrent_hidden_after`: `0.562 GiB` +- Axon single: + - `transition_forward_linear_output`: `3.500 GiB` + - `transition_forward_diag_output`: `3.500 GiB` + - `transition_forward_norm_output`: `1.750 GiB` + - `forward_recurrent_msg`: `1.750 GiB` + - `forward_recurrent_hidden_after`: `1.750 GiB` +- Mixed: + - `transition_forward_linear_output`: `5.086 GiB` + - `transition_forward_matmul_output`: `2.906 GiB` + - `transition_forward_norm_output`: `1.453 GiB` + - `transition_forward_diag_output`: `1.453 GiB` + - `forward_recurrent_msg`: `1.453 GiB` + - `forward_recurrent_hidden_after`: `1.453 GiB` + +Diagnosis: + +- The accepted forward path is compiler-owned and materially better than the + earlier local rows, but it is still not April21-class. It is reconstructing + pieces of the low-live-memory physical shape, not the whole step body. +- Single-pop peaks during message/readout stages, but the logical live set is + dominated by transition outputs plus recurrent public/message banks. The + useful owner is the overlap between message-stage work and transition/runtime + products, not simply one message kernel. +- Mixed-pop peaks after transition. That means singleton-only streaming is not + enough; chunk/merge ownership across multiple producers remains open. 
+- Rejected probes show what not to repeat: + - per-row scalar fusion destroys matrix efficiency; + - packed gate temporary layout did not move the owner; + - row-group workspace metadata did not move actual allocator stages; + - disabling row groups breaks the legal stream-transition route. + +Biggest remaining changes, in priority order: + +1. **Compiler-owned GEMM/BMM grouping and cost policy.** + The next high-leverage move is a registered physical strategy that groups + compatible affine/projection work into larger GEMM/BMM or grouped-GEMM + units while preserving explicit nonlinear boundaries. This should be a + compiler strategy/cost-model product over rows and bindings, not a family, + shape, or benchmark selector. + +2. **Program-level streaming physical step body.** + Message projection, message normalization, transition input/state update, + readout message, readout projection, and output route need to become one + producer-consumer step body. The strategy must consume + `forward_program_access_rows`, + `message_transition_producer_consumer_rows`, + `readout_message_producer_consumer_rows`, + `forward_output_route_rows`, `runtime_buffer_rows`, + `memory_liveness_rows`, and artifact/tape rows directly. + +3. **Executable transition liveness, not planned-role relabeling.** + Full transition outputs must stop overlapping with message/readout + high-water on forward-only rows. Acceptance requires movement in allocator + telemetry, stage max/current bytes, launch shape, or storage lifetime; role + byte labels alone are not enough. + +4. **Mixed-pop chunk/merge ownership.** + The same streaming body must handle mixed rows through explicit chunk, + merge, and output-route rows. A singleton-only improvement can be steering + evidence, but not T=1 closure. + +5. **T/K streaming proof.** + T=1 is only the first invocation of the T/K streaming step. Any accepted + strategy needs a small high-level `T>1,K>1` guardrail through the same route + identity, with no benchmark time chunking, detach policy, Python replay, or + full `[T, cells, state]` materialization. + +6. **Reset-present mixed route parity.** + The reset-present mixed-pop readout parity issue predates the latest + forward slices, but it still blocks full acceptance. Keep it separate unless + the next strategy touches reset or readout routing. + +7. **Training/reverse/reducer liveness after forward compaction.** + Training remains catastrophically behind, but it should not become the main + lane until the forward producer-consumer body stops inflating the same + artifact/tape/recurrent-message owners. Once forward compacts, reopen + reverse artifacts, recurrent-message gradients, parameter reducers, and + checkpoint/recompute policy through the same compiler rows. + +Practical next target: + +- Plan a `streaming_physical_step_v2` registered strategy focused on preserving + GEMM/BMM efficiency while reducing producer boundaries. +- First acceptance row: h32 100M B1024 T=1 forward, sLSTM + Axon, single-pop. +- Required follow-up before acceptance: mixed-pop T=1 forward plus small + high-level `T>1,K>1` forward guardrail through the same compiler route. +- Keep only if a named owner physically moves and tok/s or peak improves + without parity regression. Narrow through typed legality blockers when only a + subset is legal. Revert if the row fingerprints change or the change merely + relabels metadata. 
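+
+Priority items 1-3 above reduce to one physical idea: keep dense work
+GEMM/BMM-shaped while shrinking how much of each stage boundary is live at
+once. A minimal framework-level sketch of that shape (toy sizes and names; the
+chunk size here is a constant for illustration only, whereas the registered
+strategy must take it from compiler-owned cost/liveness rows):
+
+```python
+import torch
+
+
+def full_materialization_step(x, w_msg, w_trans):
+    """Baseline shape: the projected message is fully materialized before its consumer."""
+    msg = x @ w_msg  # full [B, d_msg] stage boundary stays live
+    msg = torch.nn.functional.layer_norm(msg, msg.shape[-1:])  # explicit nonlinear boundary
+    return msg @ w_trans  # transition input projection
+
+
+def streamed_step(x, w_msg, w_trans, chunk_rows=256):
+    """Streaming shape: same math, but only one chunk of the boundary is live at a time.
+
+    Each chunk still runs as a large GEMM, which avoids the per-row scalar-loop
+    failure mode of the rejected fused-body probe."""
+    out = torch.empty(x.shape[0], w_trans.shape[1], dtype=x.dtype, device=x.device)
+    for start in range(0, x.shape[0], chunk_rows):
+        rows = slice(start, min(start + chunk_rows, x.shape[0]))
+        msg_chunk = x[rows] @ w_msg  # chunk-local workspace, not a long-lived runtime role
+        msg_chunk = torch.nn.functional.layer_norm(msg_chunk, msg_chunk.shape[-1:])
+        out[rows] = msg_chunk @ w_trans  # consumed immediately by the next producer
+    return out
+
+
+if __name__ == "__main__":
+    torch.manual_seed(0)
+    x = torch.randn(4096, 512)
+    w_msg = torch.randn(512, 1024) / 512 ** 0.5
+    w_trans = torch.randn(1024, 768) / 1024 ** 0.5
+    baseline = full_materialization_step(x, w_msg, w_trans)
+    streamed = streamed_step(x, w_msg, w_trans)
+    # Same semantics; only the lifetime of the projected-message boundary differs.
+    print(torch.allclose(baseline, streamed, atol=1e-4))
+```
+
+The sketch keeps normalization as an explicit boundary and changes only which
+producer output stays live; it is a shape illustration, not the registered
+strategy itself.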
+ +### 2026-05-04 - Plan: Streaming Physical Step v2 With GEMM/BMM Grouping + +Classifier: + +- Lane: throughput strategy / native implementation plan. +- Semantic delta: none. +- April21 usage: physical execution-shape evidence only; no code copy. +- Do not optimize training in this slice except as a bounded artifact/liveness + smoke if forward changes affect training legality. + +Goal: + +- Implement the next accepted T=1 forward move as a registered + `streaming_physical_step_v2` strategy that preserves large GEMM/BMM work + while reducing full producer boundaries across message, transition, readout, + and output routes. +- Treat T=1 as one invocation of the same step body used for `T>1,K=1` and + `T>1,K>1`; do not create a terminal-only T=1 shortcut. + +Hypothesis: + +- The remaining forward gap is caused by the active path materializing too many + row boundaries while already holding transition/runtime products live. +- Prior scalar and one-buffer probes failed because they did not preserve the + dense matrix shape or did not move a physical allocator owner. +- A legal strategy that keeps affine work GEMM/BMM-shaped and moves consumers + inside the step body should reduce either: + - current allocated stage bytes at + `native_forward_message_after_normalize_local*` or + `native_forward_after_transition_local*`; or + - full runtime product overlap for + `forward_recurrent_msg`, + `transition_forward_linear_output`, + `transition_forward_matmul_output`, + `transition_forward_diag_output`, + `transition_forward_norm_output`, + `forward_output_msg`, and `forward_output_cells`. + +Rows/fingerprints expected to stay stable: + +- Primitive rows. +- Forward executor rows. +- Forward executor binding rows. +- Program tensor binding rows. +- Native callable binding schema rows. +- Message-transition producer-consumer rows, except for explicit strategy/cost + metadata if needed. +- Readout-message producer-consumer rows, except for explicit strategy/cost + metadata if needed. +- Forward output route rows. +- Forward artifact route rows. +- Reset rows and runtime schedule rows. + +Compiler products consumed directly: + +- `physical_strategy_rows`. +- `memory_liveness_rows`. +- `runtime_buffer_rows`. +- `memory_runtime_schedule_rows`. +- `forward_program_access_rows`. +- `message_transition_producer_consumer_rows`. +- `readout_message_producer_consumer_rows`. +- `forward_output_route_rows`. +- `forward_artifact_route_rows`. +- `program_tensor_binding_rows`. +- `native_callable_output_rows`. + +Boundary manifest: + +- Declaration/spec owner: unchanged graph, cell, message, readout, interface, + and execution declarations. +- Forward owner: `registered_temporal_fused_forward_program_cuda`. +- Strategy owner: `physical_strategy_rows` with + `streaming_step_producer_consumer` active. +- Native implementation surfaces: + - `forward_program.cuh` for route selection and step-body orchestration; + - `message_forward_strategies.cuh` for message/readout producer work; + - `transition_forward_program.cuh` for transition row-group consumption; + - compiler row builders in `program_execution.py` / `memory_plan.py` only if + an explicit cost/legality row is required. +- Backward/reducer owner: unchanged for this forward slice. Artifact/training + paths must keep materialized tape unless explicitly planned in a later + reverse/reducer pass. + +Implementation phases: + +1. 
**Add a plan-level strategy contract, not a hidden kernel branch.** + - If existing `physical_strategy_rows` are sufficient, consume them and + validate the active `streaming_step_producer_consumer` row before launch. + - If more distinction is needed, add compiler-owned strategy/cost metadata + for `streaming_physical_step_v2`; do not encode the choice as a C++ + hidden-size, family, single/mixed, or benchmark selector. + - Record typed blockers for unsupported reset, artifact, mixed-merge, + dtype/layout, or multiple-consumer cases. + +2. **Make the first mechanism probe a launch-shape/liveness probe.** + - Do not begin with another one-buffer alias. + - Add or reuse native-stage telemetry to prove whether the owner is: + - message weighted-value/output-weight/projected/normalize temporary; + - transition row-group chunk workspace; + - readout projection/output-route temporary; + - allocator reserve/unclassified high-water. + - The first useful evidence is owner movement in stage telemetry, launch + count/shape, or storage lifetime, not role-byte metadata. + +3. **Preserve GEMM/BMM shape while moving consumers.** + - Keep current dense affine/projection work in GEMM/BMM/grouped-GEMM form. + - Move consumers inside the step body only across legal producer-consumer + rows: + - message -> transition input projection; + - message -> readout message; + - readout message -> output projection/output route; + - transition row-group intermediate -> transition public/carry output. + - Keep normalization, gating, reset, and output-route semantics explicit + unless the strategy proves an equivalent fused implementation. + +4. **Make workspace policy compiler-owned.** + - Replace hardcoded local chunk policy only where the compiler strategy row + can prove legality. + - Candidate policy knobs must be expressed as strategy/cost/liveness rows or + derived from compiler memory rows, not from benchmark shape checks. + - The immediate suspect is row-group/message chunking that preserves memory + but damages GEMM/BMM efficiency; the fix should be a legal strategy choice, + not a larger global constant. + +5. **Single-pop mechanism gate before representative acceptance.** + - First artifact path: + `tmp/fabric_audits/partials/2026-05-04/t1_streaming_physical_step_v2_mechanism_h32_100m_b1024`. + - Row: h32 100M B1024 T=1 forward, sLSTM + Axon, single-pop. + - Success criteria: + - named stage owner moves or shrinks; + - no peak-memory increase; + - no row fingerprint semantic change; + - throughput improves versus accepted `projection_into`, or peak memory + drops materially with flat throughput. + +6. **Representative follow-ups only after mechanism success.** + - Single-pop representative: + `tmp/fabric_audits/partials/2026-05-04/t1_streaming_physical_step_v2_h32_100m_b1024`. + - Mixed-pop representative: + `tmp/fabric_audits/partials/2026-05-04/t1_streaming_physical_step_v2_mixed_h32_100m_b1024`. + - T/K guardrail: + `tmp/fabric_audits/partials/2026-05-04/tk_streaming_physical_step_v2_guardrail_h32_100m_b128_t2_k2`. + - Follow-up rows must use the same compiler-owned route identity or fail + closed through typed legality blockers. + +Tests and guardrails: + +- Catalog validation. +- Source/static guardrails: + - active path consumes `physical_strategy_rows`; + - producer-consumer rows remain compiler-owned; + - no family, hidden-size, benchmark, fixed-slot enum, single/mixed-pop, or + April21-code selector is added. 
+- CUDA fused-forward smoke: + `tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells`. +- Terminal-loss artifact-store smoke: + `tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store`. +- If reset/readout routing is touched, include the reset-present mixed readout + parity test before accepting the slice. + +Keep/narrow/revert rule: + +- Keep only if a named physical owner moves and the representative row improves + tok/s or peak memory without parity regression. +- Narrow if only a typed legal subset benefits and unsupported rows fail closed + through compiler legality. +- Revert if: + - primitive/tensor-role fingerprints change; + - the change is metadata-only; + - throughput regresses like the scalar/packed-gate probes; + - peak memory grows; + - selection depends on family, hidden size, benchmark row, single/mixed + labels, or old fixed-slot ABI shortcuts. + +Deferred work explicitly out of scope for this plan: + +- Broad reverse/training reducer optimization. +- Dot-product semantic stress test. +- Public Config/Blueprint cleanup. +- April21 code copy or direct kernel transplant. + +### 2026-05-04 - Result: Rejected Stream Readout Project-Into Hook + +Attempted mechanism: + +- Add a message native callable hook that kept the existing fixed-slot + partitioned attention primitive and recurrent-value GEMM, but made the + readout output-message chunk step-local and immediately projected it into the + compiler-owned output route. +- Compiler-owned surfaces touched during the probe: + `readout_message_producer_consumer_rows`, forward output route rows, message + native callable phases, generated native callable catalog, and fused forward + program route selection. +- Semantic rows, tensor bindings, output routes, artifact routes, reset rows, + message math, transition math, and readout math were intended to stay stable. + +Validation before the performance decision: + +```text +CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_stream_readout_project_into_20260504c \ +TRITON_CACHE_DIR=/tmp/cortical_triton_stream_readout_project_into_20260504c \ +timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells \ + --tb=short +# 1 passed in 61.82s + +uv run python scripts/validate_fabric_generated_catalogs.py +# Generated Fabric catalog headers are up to date. 
+ +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_readout_message_producer_consumer_rows_are_compiler_owned_legality_rows \ + tests/test_fabric_backend_plan.py::test_message_transition_producer_consumer_rows_are_compiler_owned_legality_rows \ + tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows \ + --tb=short +# 3 passed in 6.29s after reverting the rejected hook +``` + +Mechanism audit: + +```text +tmp/fabric_audits/partials/2026-05-04/t1_stream_readout_project_into_h32_100m_b1024 +``` + +Comparison against accepted `t1_output_route_projection_into_h32_100m_b1024`: + +| Row | Accepted tok/s | Probe tok/s | Accepted peak | Probe peak | Owner movement | +| --- | ---: | ---: | ---: | ---: | --- | +| sLSTM 100M h32 B1024 T1 forward | `14058.92` | `10572.47` | `4.394 GiB` | `4.394 GiB` | none | +| Axon 100M h32 B1024 T1 forward | `8579.00` | `6562.63` | `8.776 GiB` | `8.776 GiB` | none | + +Owner-table detail: + +- `native_forward_message_after_normalize_local0` stayed the peak-current + owner: + - sLSTM: `3.513 GiB`. + - Axon: `7.750 GiB`. +- `native_forward_after_readout_message_local0` stayed the first/max-delta + owner: + - sLSTM first peak: `4.394 GiB`; readout-message stage current + `3.108 GiB`. + - Axon first peak: `8.776 GiB`; readout-message stage current + `7.376 GiB`. +- `forward_output_msg` role bytes stayed unchanged: + - sLSTM: `0.094 GiB`. + - Axon: `0.125 GiB`. + +Decision: + +- Rejected and reverted. +- Reason: the hook passed parity but did not change actual storage lifetime or + allocator ownership, and it materially reduced throughput. It only moved a + local chunk boundary while the compiler runtime buffer and allocator + high-water owners stayed unchanged. + +Implication for the next pass: + +- Do not re-add a readout-message-to-output projection hook unless the compiler + liveness plan also removes or defers the `forward_output_msg` runtime buffer + allocation itself for the legal route. +- The next viable owner must attack the real allocated owner: + `native_forward_message_after_normalize_local*` and the transition/runtime + overlap, not just the final readout projection consumer. + +### 2026-05-04 - Hypothesis Packet: Stream Message Projection Into Transition Input + +Classifier: + +- Lane: throughput strategy / native implementation. +- Semantic delta: none. +- April21 usage: physical execution-shape evidence only; no code copy. + +Hypothesis: + +- The accepted single-pop route already has compiler-owned + `message_transition_producer_consumer_rows` selecting direct + message-to-transition input. +- That route still materializes a full normalized message chunk before applying + the transition input projection. +- A registered native implementation can preserve the same attention and GEMM + projection semantics while consuming the normalized message chunk immediately + into the transition input target, reducing the live projected-message + boundary in the `native_forward_message_after_normalize_local*` owner. + +Compiler products consumed: + +- `physical_strategy_rows` with the active streaming-step strategy. +- `message_transition_producer_consumer_rows`. +- `forward_program_access_rows`. +- `program_tensor_binding_rows`. +- `native_callable_binding_schema_rows`. +- `native_callable_output_rows`. +- `runtime_buffer_rows`. +- `memory_liveness_rows`. + +Rows/fingerprints expected to stay stable: + +- Primitive rows. +- Forward executor rows. +- Program tensor binding rows. 
+- Message/readout/transition declaration rows. +- Artifact/output/reset rows. + +Implementation boundary: + +- Change only the registered fixed-slot-context message native strategy phase + selected by the compiler-owned `stream_transition_input` row. +- Keep weighted-value attention and output projection GEMM-shaped. +- Do not add scheduler-owned formulas, family selectors, hidden-size selectors, + benchmark selectors, fixed tensor slot shortcuts, or April21 code. + +First artifact path: + +- `tmp/fabric_audits/partials/2026-05-04/t1_stream_message_transition_project_into_h32_100m_b1024`. + +Keep/narrow/revert rule: + +- Keep only if the targeted CUDA parity smoke passes and the performance row + moves a named owner or improves tok/s/peak without increasing memory. +- Narrow if the direct path is legal only for defined recurrent-value banks. +- Revert if the patch is metadata-only, changes row fingerprints, regresses + throughput without owner movement, or increases peak memory. + +Result: + +- Rejected and reverted. +- Artifact: + `tmp/fabric_audits/partials/2026-05-04/t1_stream_message_transition_project_into_h32_100m_b1024`. +- Parity/compile smoke passed: + +```text +CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_stream_msg_transition_project_into_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_stream_msg_transition_project_into_20260504 \ +timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells \ + --tb=short +# 1 passed in 63.66s +``` + +Measured against accepted +`tmp/fabric_audits/partials/2026-05-04/t1_output_route_projection_into_h32_100m_b1024`: + +| Row | Accepted tok/s | Probe tok/s | Accepted peak | Probe peak | Decision | +| --- | ---: | ---: | ---: | ---: | --- | +| sLSTM 100M h32 B1024 T1 forward | `14058.92` | `10797.06` | `4.394 GiB` | `4.394 GiB` | reject | +| Axon 100M h32 B1024 T1 forward | `8579.00` | `6620.76` | `8.776 GiB` | `8.776 GiB` | reject | + +Owner movement: + +- None. `native_forward_message_after_normalize_local0` stayed the + peak-current owner for both rows. +- `native_forward_after_readout_message_local0` stayed the first high-water + owner for both rows. +- Runtime role bytes were unchanged. + +Interpretation: + +- Eliminating only the full message chunk inside this phase did not shorten the + measured allocator lifetime. The weighted-value/projection/transition live + overlap and runtime transition products still dominate the high-water. +- The extra projected subchunk loop reduced GEMM efficiency, so this is another + rejected narrow probe. +- Do not retry this shape unless it is part of a broader row-group GEMM/BMM + producer-consumer strategy that also moves transition product liveness. + +### 2026-05-04 - Hypothesis Packet: Gated Transition BMM Output Consumed Directly + +Classifier: + +- Lane: throughput strategy / native implementation. +- Semantic delta: none. +- April21 usage: physical execution-shape evidence only; no code copy. + +Hypothesis: + +- The registered gated transition row group already uses a GEMM/BMM-shaped gate + affine, but the active implementation materializes the BMM output and then a + second gate-major copy before recurrence consumes it. +- A registered native implementation can keep the same compiler primitive rows, + tensor bindings, and gate affine math while letting the recurrence kernel + consume the BMM output layout directly. 
+- This should reduce one full gate-logit-sized transition temporary for legal + forward-only streaming rows without changing message, transition, readout, + output, reset, artifact, or reducer semantics. + +Compiler products consumed: + +- `physical_strategy_rows` with the active streaming-step strategy. +- Transition primitive rows and forward executor binding rows. +- `program_tensor_binding_rows`. +- `native_callable_binding_schema_rows`. +- `native_callable_output_rows`. +- `runtime_buffer_rows`. +- `memory_liveness_rows`. + +Rows/fingerprints expected to stay stable: + +- Primitive rows. +- Forward executor rows. +- Forward executor binding rows. +- Program tensor binding rows. +- Message-transition and readout-message producer-consumer rows. +- Forward output/artifact/reset rows. + +Implementation boundary: + +- Change only the registered gated transition forward row-group implementation. +- Preserve BMM/GEMM shape for the gate affine. +- Do not add scheduler-owned formulas, family selectors, hidden-size selectors, + benchmark selectors, fixed tensor slot shortcuts, or April21 code. + +First artifact path: + +- `tmp/fabric_audits/partials/2026-05-04/t1_gated_transition_bmm_direct_consume_h32_100m_b1024`. + +Follow-up artifact paths if the first row moves: + +- `tmp/fabric_audits/partials/2026-05-04/t1_gated_transition_bmm_direct_consume_mixed_h32_100m_b1024`. +- `tmp/fabric_audits/partials/2026-05-04/tk_gated_transition_bmm_direct_consume_guardrail_h32_100m_b128_t2_k2`. + +Keep/narrow/revert rule: + +- Keep if CUDA parity smoke passes and a named forward owner or stage shrinks, + or tok/s improves without increasing peak memory. +- Narrow to gated-transition rows if Axon/diagonal rows are unaffected. +- Revert if row fingerprints change, peak memory grows, throughput regresses + without owner movement, or the implementation adds a hidden route selector. + +Result: + +- Rejected and reverted. +- Artifact: + `tmp/fabric_audits/partials/2026-05-04/t1_gated_transition_bmm_direct_consume_h32_100m_b1024`. +- Compiler/static gates passed before the perf decision: + +```text +uv run python scripts/validate_fabric_generated_catalogs.py +# Generated Fabric catalog headers are up to date. + +CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_gate_affine_head_major_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_gate_affine_head_major_20260504 \ +timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells \ + --tb=short +# 1 passed in 61.58s + +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes \ + tests/test_fabric_backend_plan.py::test_message_transition_producer_consumer_rows_are_compiler_owned_legality_rows \ + tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows \ + --tb=short +# 3 passed in 5.09s +``` + +Measured against accepted +`tmp/fabric_audits/partials/2026-05-04/t1_output_route_projection_into_h32_100m_b1024`: + +| Row | Accepted tok/s | Probe tok/s | Accepted peak | Probe peak | Decision | +| --- | ---: | ---: | ---: | ---: | --- | +| sLSTM 100M h32 B1024 T1 forward | `14058.92` | `10208.17` | `4.394 GiB` | `4.394 GiB` | reject | +| Axon 100M h32 B1024 T1 forward | `8579.00` | `6357.74` | `8.776 GiB` | `8.776 GiB` | reject | + +Owner movement: + +- None useful. The peak memory stayed unchanged for both rows. 
+- sLSTM still peaked at `native_forward_after_readout_message_local0` by + first/high-water and `native_forward_message_after_normalize_local0` by + peak-current ownership. +- Axon still showed the same message/transition runtime ownership shape; the + gated-only change did not improve the diagonal path. + +Interpretation: + +- Consuming the gate affine BMM layout directly preserved parity, but it moved + the hot path into a less favorable memory layout for the recurrence and cut + throughput sharply. +- This confirms that isolated removal of the gate-major copy is not the + missing April21 physical shape. The next viable forward owner needs a broader + producer-consumer grouping that keeps GEMM/BMM efficiency while reducing + full transition/readout stage materialization, rather than a layout-only + recurrence consumer tweak. + +### 2026-05-04 - Deep Dive: Remaining T=1 Throughput After BMM Direct Rejection + +Scope: + +- Analysis only. No throughput optimization patch in this section. +- Current accepted baseline remains: + - `tmp/fabric_audits/partials/2026-05-04/t1_output_route_projection_into_h32_100m_b1024` + - `tmp/fabric_audits/partials/2026-05-04/t1_output_route_projection_into_mixed_h32_100m_b1024` + - `tmp/fabric_audits/partials/2026-05-04/tk_output_route_projection_into_guardrail_h32_100m_b128_t2_k2` +- April21 comparison row remains `h32_t1_bxparams`: + `58732.71 tok/s`, `2.07 GiB`. + +Current accepted owner table: + +| Row | Current tok/s | % April21 | Slowdown | Peak GiB | Mem x April21 | Current allocated owner | First high-water | +| --- | ---: | ---: | ---: | ---: | ---: | --- | --- | +| sLSTM single T=1 forward | `14058.92` | `23.94%` | `4.18x` | `4.394` | `2.12x` | `native_forward_message_after_normalize_local0` | `native_forward_after_readout_message_local0` | +| Axon single T=1 forward | `8579.00` | `14.61%` | `6.85x` | `8.776` | `4.24x` | `native_forward_message_after_normalize_local0` | `native_forward_after_readout_message_local0` | +| sLSTM mixed T=1 forward | `7548.00` | `12.85%` | `7.78x` | `9.121` | `4.41x` | `native_forward_after_transition_local0` | `native_forward_after_transition_local0` | +| Axon mixed T=1 forward | `7540.89` | `12.84%` | `7.79x` | `10.152` | `4.90x` | `native_forward_after_transition_local0` | `native_forward_after_transition_local0` | +| sLSTM T=2,K=2 forward guardrail, B=128 | `2663.16` | steering | steering | `2.852` | steering | `native_forward_message_after_normalize_local3` | `native_forward_after_transition_local1` | +| Axon T=2,K=2 forward guardrail, B=128 | `2621.03` | steering | steering | `11.702` | steering | `native_forward_message_after_normalize_local3` | `native_forward_after_transition_local1` | + +Accepted single-pop role bytes: + +| Role | sLSTM GiB | Axon GiB | +| --- | ---: | ---: | +| `transition_forward_linear_output` | `2.812` | `3.500` | +| `transition_forward_matmul_output` | `2.250` | `0.000` | +| `transition_forward_diag_output` | `0.000` | `3.500` | +| `transition_forward_norm_output` | `0.562` | `1.750` | +| `transition_forward_state_output` | `0.562` | `0.000` | +| `forward_recurrent_msg` | `0.562` | `1.750` | +| `forward_recurrent_hidden_after` | `0.562` | `1.750` | + +Accepted mixed-pop role bytes: + +| Role | GiB | +| --- | ---: | +| `transition_forward_linear_output` | `5.086` | +| `transition_forward_matmul_output` | `2.906` | +| `transition_forward_diag_output` | `1.453` | +| `transition_forward_norm_output` | `1.453` | +| `forward_recurrent_msg` | `1.453` | +| `forward_recurrent_hidden_after` | 
`1.453` | + +Raw findings: + +- The accepted `projection_into` row is still the best current single-pop row + among the measured artifacts. Most later probes kept peak flat but cut + throughput, which means the branch is no longer blocked by one missing + output-route shortcut. +- Single-pop first high-water occurs at readout-message time, but the large + logical live set is transition products plus recurrent public/message banks. + The remaining gap is overlap and launch shape across the whole step body. +- Mixed-pop peaks after transition. The current compiler plan still blocks + message-to-transition streaming for multiple transition consumers with + `multiple_transition_consumers_need_merge_rows`, so mixed cannot close from + singleton-only improvements. +- T/K already uses the registered route, but the guardrail still peaks inside + repeated message normalization. A T=1-only terminal shortcut would not solve + the streaming-body problem. + +Rejected directions that should not be repeated as narrow patches: + +- Readout message project-into: parity passed, but output-message storage + lifetime did not move and throughput regressed. +- Message-to-transition project-into: parity passed, but allocator ownership + did not move and smaller projection chunks reduced GEMM efficiency. +- Gate BMM direct consume: parity passed, but the recurrence layout became + worse for throughput and peak did not move. +- Scalar/local fusion attempts: reduced matrix efficiency and did not recover + April21-class execution shape. +- Alias/no-copy probes: must not continue unless they name and move the + actual allocator/lifetime owner; previous broad alias routes exposed + unclassified high-water risk. + +Remaining T=1 work, in priority order: + +1. **Registered streaming physical step body.** + The current forward path is compiler-owned, but it is still a sequence of + materialized stage boundaries. The main remaining work is a registered + producer-consumer strategy over `physical_strategy_rows`, + `memory_liveness_rows`, `runtime_buffer_rows`, access rows, artifact rows, + output route rows, and producer-consumer rows. It must be the same step body + for `T=1,K=1`, `T>1,K=1`, and `T>1,K>1`. + +2. **GEMM/BMM grouping and cost policy.** + The next strategy must preserve large matrix work. The failed probes show + that smaller chunks, scalar projection, and layout-only recurrence tweaks + lose more throughput than they save. Candidate strategies should group + compatible projection/affine work by compiler rows and choose chunk sizes as + explicit strategy/cost metadata, not constants hidden in kernels or + benchmark shape selectors. + +3. **Executable transition liveness.** + Transition intermediates are still the largest nominal roles. Existing + deferred-local outputs helped, but full transition row-group products still + overlap with message/readout high-water. The next accepted change must + physically move allocator telemetry for `transition_forward_linear_output`, + `transition_forward_matmul_output`, `transition_forward_diag_output`, and + `transition_forward_norm_output`, not just relabel role bytes. + +4. **Message native strategy that feeds consumers without breaking GEMM shape.** + The fixed-slot context message path still builds weighted-value/projected + message boundaries before transition/readout consumers. 
The desired change + is not a per-row scalar fusion; it is a compiler-selected native strategy + that keeps the weighted-value/projection work GEMM/BMM-shaped while routing + outputs directly to legal transition/readout consumers. + +5. **Readout output route liveness owned by rows.** + `project_into` is accepted, but output-message/output-cells route + allocation still appears in the runtime plan. A future readout strategy must + make output-message and output-cells lifetime disappear or shrink through + compiler liveness/output-route rows, not by adding another local + post-readout hook. + +6. **Mixed-pop merge/chunk rows.** + Mixed T=1 is as important as single T=1 because it exercises the compiler's + multi-executor design. The remaining blocker is explicit multi-producer / + multi-consumer chunk and merge ownership so streaming remains legal across + more than one transition/readout span. + +7. **T/K streaming guardrail.** + Any accepted T=1 strategy must immediately run a small high-level + `T>1,K>1` guardrail through the same route identity. No benchmark-owned + time chunking, detach policy, Python replay, or full `[T, cells, state]` + materialization may appear. + +8. **Training/reverse/reducer liveness after forward moves.** + Training remains far behind, but it should stay a bounded guardrail until + the forward step body becomes compact. Once forward moves, reopen reverse + artifacts, recurrent-message gradients, parameter reducers, checkpoint / + recompute, and tape liveness through the same compiler rows. + +April21 transfer boundary: + +- Origin/main / April21 is useful evidence for a low-live-memory physical + shape: receiver/message work, transition public-state emission, readout, and + output are scheduled with much less live materialization. +- Do not copy old execution code. The transferable idea must be re-expressed + as registered compiler-owned strategies over current primitive rows, + bindings, liveness rows, artifact routes, output routes, and reducer rows. + +Practical next target: + +- Plan a `streaming_step_producer_consumer` strategy extension rather than another one-buffer + probe. +- First acceptance row: h32 100M B1024 T=1 forward, sLSTM + Axon, single-pop. +- Required follow-ups before accepting the strategy: mixed-pop T=1 forward and + small high-level T>1,K>1 forward guardrail through the same compiler route. +- Keep only if a named physical owner moves and throughput or peak improves + without parity or boundary regression. + +### 2026-05-04 - Hypothesis Packet: Streaming Step Producer-Consumer Strategy + +Hypothesis: + +- The registered streaming step producer-consumer strategy should become the + compiler-owned owner for the T=1 body and the T/K scan body. The strategy + should consume the same rows that already describe message, transition, + readout, output routing, artifacts, runtime buffers, and memory liveness. + +April21 mechanism being semantically transferred: + +- Low-live-memory physical execution shape: message projection, transition + public-state emission, readout, and output routing are scheduled as one + producer-consumer step body with minimal full-stage materialization. +- This is a mechanism transfer only. No April21 code copy, no benchmark-owned + tiling, no hidden family/shape selectors, and no semantic changes. + +Lane: + +- Throughput strategy / native strategy over existing compiler rows. 
+ +Rows/fingerprints expected to stay stable: + +- Primitive rows, tensor binding rows, message/cell/readout declarations, + output route semantics, reset semantics, and parameter bindings. +- Expected semantic delta: none. + +Bindings/liveness/artifact rows consumed: + +- `physical_strategy_rows` +- `memory_liveness_rows` +- `runtime_buffer_rows` +- `memory_runtime_schedule_rows` +- `forward_program_access_rows` +- `forward_artifact_route_rows` +- `forward_artifact_merge_rows` +- `forward_output_route_rows` +- `readout_message_producer_consumer_rows` +- `message_transition_producer_consumer_rows` +- forward/reverse executor binding rows + +Smallest representative row: + +- h32 100M B1024 T=1 forward, sLSTM + Axon, single-pop. + +Artifact path: + +- Planned: `tmp/fabric_audits/partials/2026-05-04/t1_streaming_step_producer_consumer_h32_100m_b1024`. + +Expected owner movement: + +- First useful proof is route ownership: metadata reports + `selected_strategy=streaming_step_producer_consumer`. +- Performance acceptance requires a named physical owner to move or improve: + `native_forward_message_after_normalize_local*`, + `native_forward_after_readout_message_local*`, or + `native_forward_after_transition_local*`. + +Keep/narrow/revert rule: + +- Keep row plumbing if parity/static gates pass and route metadata proves the + `streaming_step_producer_consumer` compiler strategy owns the active path. +- Keep performance behavior only if a named owner physically moves or tok/s / + peak improves without memory regression. +- Revert or narrow behavior if row fingerprints change, a hidden selector is + introduced, peak grows, or throughput regresses without owner movement. + +Follow-up representative rows: + +- Mixed-pop T=1 forward through the same strategy identity. +- Small high-level T>1,K>1 forward guardrail through the same strategy + identity. + +T/K streaming guardrail: + +- The strategy is not a terminal-only T=1 shortcut. It must remain the + streaming step body inside the temporal scan loop for `T>1,K=1` and + `T>1,K>1`, with carried state, reset policy, output-route policy, + artifact/tape policy, checkpoint/recompute policy, and reducer liveness owned + by compiler rows. + +Result: + +- Accepted as route-gating/ABI cleanup, not as a throughput win. +- Production code continues to use the semantic strategy identity + `streaming_step_producer_consumer`; no hypothesis/version label is used in + the runtime ABI. +- The active physical strategy is now gated by compiler-owned executable + producer-consumer rows instead of only the no-artifact/no-reset/final-state + policy. At least one executable streaming producer-consumer row must be + selected before the physical strategy can claim the streaming step body. +- Existing blocked rows remain visible. For example, transition edges that + still report `multiple_transition_consumers_need_merge_rows` do not become + hidden special cases; they are still explicit compiler blockers inside the + selected step strategy. +- No semantic rows, tensor roles, output routes, artifact routes, reset policy, + tape policy, or reducer contracts changed. 
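+The gating shape described above can be pictured as a small selection check
+over producer-consumer rows. The sketch below is illustrative only; the row
+record, the field names (`executable`, `blocker`), and the function name are
+assumptions, not the real compiler ABI, which lives in the memory-plan and
+registered-executor code:
+
+```python
+from dataclasses import dataclass
+from typing import Optional, Sequence
+
+
+@dataclass(frozen=True)
+class ProducerConsumerRow:
+    # Hypothetical row record; the real rows carry far more metadata.
+    edge: str
+    executable: bool
+    blocker: Optional[str] = None
+
+
+def may_claim_streaming_step_body(
+    rows: Sequence[ProducerConsumerRow],
+) -> tuple[bool, list[str]]:
+    """Allow `streaming_step_producer_consumer` to own the step body only when
+    at least one executable producer-consumer row is selected; keep every
+    blocked edge visible as an explicit blocker instead of a hidden case."""
+    blockers = [
+        f"{row.edge}:{row.blocker}"
+        for row in rows
+        if not row.executable and row.blocker
+    ]
+    if not any(row.executable for row in rows):
+        # Fail closed: fall back to the previous policy gate, never guess.
+        return False, blockers or ["no_executable_producer_consumer_rows"]
+    return True, blockers
+
+
+ok, blockers = may_claim_streaming_step_body([
+    ProducerConsumerRow(
+        "message->transition", executable=False,
+        blocker="multiple_transition_consumers_need_merge_rows"),
+    ProducerConsumerRow("readout->message", executable=True),
+])
+# ok is True; the mixed-pop merge blocker stays visible in the result.
+```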
+ +Validation: + +```bash +python -m py_compile \ + src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py \ + src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py +# passed + +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes \ + tests/test_fabric_backend_plan.py::test_message_transition_producer_consumer_rows_are_compiler_owned_legality_rows \ + tests/test_fabric_backend_plan.py::test_readout_message_producer_consumer_rows_are_compiler_owned_legality_rows \ + --tb=short +# 3 passed + +uv run pytest -q \ + tests/test_fabric_backend_plan.py::test_temporal_backward_validates_memory_artifact_plan_fingerprint \ + tests/test_fabric_backend_boundaries.py::test_registered_program_allocations_are_compiler_classified \ + --tb=short +# 2 passed + +CUDA_VISIBLE_DEVICES=0 \ +TORCH_EXTENSIONS_DIR=/tmp/cortical_ext_streaming_strategy_gate_or_20260504 \ +TRITON_CACHE_DIR=/tmp/cortical_triton_streaming_strategy_gate_or_20260504 \ +timeout 900s uv run pytest -q \ + tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells \ + --tb=short +# 1 passed + +git diff --check -- \ + ai_docs/FABRIC_THROUGHPUT_CLOSURE.md \ + src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py \ + src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py \ + src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/constants_and_checks.cuh \ + src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/memory_runtime_buffers.cuh +# passed +``` + +Performance status: + +- No new throughput claim from this slice. The accepted owner table remains + `t1_output_route_projection_into_h32_100m_b1024` and + `t1_output_route_projection_into_mixed_h32_100m_b1024`. +- The next implementation still needs to move a named allocator/native-stage + owner. The highest-impact owner remains the streaming step body that preserves + GEMM/BMM grouping while reducing message/transition/readout materialization. + +### 2026-05-04 - Deep Dive: Remaining T=1 Throughput Work After Strategy Gate + +Status: analysis only. No optimization was started in this pass. + +Current accepted evidence: + +- April21 reference floor remains `h32_t1_bxparams`: `58732.71 tok/s`, + `2.07 GiB`. +- Accepted current single-pop artifact: + `tmp/fabric_audits/partials/2026-05-04/t1_output_route_projection_into_h32_100m_b1024`. +- Accepted current mixed-pop artifact: + `tmp/fabric_audits/partials/2026-05-04/t1_output_route_projection_into_mixed_h32_100m_b1024`. +- Accepted current T/K guardrail artifact: + `tmp/fabric_audits/partials/2026-05-04/tk_output_route_projection_into_guardrail_h32_100m_b128_t2_k2`. 
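+The derived columns in the owner table below come directly from the accepted
+measurements and the April21 reference row; a minimal sketch of that
+arithmetic, using the sLSTM single T=1 forward row:
+
+```python
+# April21 reference row (h32_t1_bxparams) and one accepted current row.
+APRIL21_TOKS, APRIL21_GIB = 58732.71, 2.07
+current_toks, current_gib = 14058.92, 4.394  # sLSTM single T=1 forward
+
+share = current_toks / APRIL21_TOKS      # 0.2394 -> 23.94% of April21
+slowdown = APRIL21_TOKS / current_toks   # 4.18x slower than April21
+peak_ratio = current_gib / APRIL21_GIB   # 2.12x the April21 peak memory
+print(f"{share:.2%} share, {slowdown:.2f}x slowdown, {peak_ratio:.2f}x peak")
+```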
+ +Owner table: + +| Row | Current tok/s | April21 share | Slowdown | Peak GiB | Peak vs April21 | Peak-current owner | First/high-water owner | +| --- | ---: | ---: | ---: | ---: | ---: | --- | --- | +| sLSTM single T=1 forward | `14058.92` | `23.94%` | `4.18x` | `4.394` | `2.12x` | `native_forward_message_after_normalize_local0` | `native_forward_after_readout_message_local0` | +| Axon single T=1 forward | `8579.00` | `14.61%` | `6.85x` | `8.776` | `4.24x` | `native_forward_message_after_normalize_local0` | `native_forward_after_readout_message_local0` | +| sLSTM mixed T=1 forward | `7548.00` | `12.85%` | `7.78x` | `9.121` | `4.41x` | `native_forward_after_transition_local0` | `native_forward_after_transition_local0` | +| Axon mixed T=1 forward | `7540.89` | `12.84%` | `7.79x` | `10.152` | `4.90x` | `native_forward_after_transition_local0` | `native_forward_after_transition_local0` | +| sLSTM T=2,K=2 forward guardrail, B=128 | `2663.16` | steering | steering | `2.852` | steering | `native_forward_message_after_normalize_local3` | `native_forward_after_transition_local1` | +| Axon T=2,K=2 forward guardrail, B=128 | `2621.03` | steering | steering | `11.702` | steering | `native_forward_message_after_normalize_local3` | `native_forward_after_transition_local1` | + +Largest runtime role bytes still visible: + +- sLSTM single: `transition_forward_linear_output=2.812 GiB`, + `transition_forward_matmul_output=2.250 GiB`, + `forward_recurrent_msg=0.562 GiB`, + `forward_recurrent_hidden_after=0.562 GiB`, + `transition_forward_norm_output=0.562 GiB`, + `transition_forward_state_output=0.562 GiB`. +- Axon single: `transition_forward_linear_output=3.500 GiB`, + `transition_forward_diag_output=3.500 GiB`, + `forward_recurrent_msg=1.750 GiB`, + `forward_recurrent_hidden_after=1.750 GiB`, + `transition_forward_norm_output=1.750 GiB`. +- Mixed: `transition_forward_linear_output=5.086 GiB`, + `transition_forward_matmul_output=2.906 GiB`, + `transition_forward_diag_output=1.453 GiB`, + `transition_forward_norm_output=1.453 GiB`, + `forward_recurrent_msg=1.453 GiB`, + `forward_recurrent_hidden_after=1.453 GiB`. + +Diagnosis: + +- The active path is compiler-owned and uses registered strategy rows, but the + physical execution still materializes too many producer/consumer stage + boundaries. +- Local liveness wins are not enough. The rejected probes already showed that + scalar projection, smaller chunks, direct BMM consume, no-copy aliases, and + one-buffer shortcuts either did not move the named owner or damaged GEMM/BMM + efficiency. +- The real remaining gap is the compact streaming physical step body: message + projection, transition public-state emission, readout, and output routing + must run as one compiler-selected producer-consumer unit while preserving + large GEMM/BMM/grouped-GEMM shapes. +- Single-pop and mixed-pop now expose different blockers. Single-pop is still + message/readout high-water dominated; mixed-pop is blocked by + `multiple_transition_consumers_need_merge_rows` and still peaks after + transition. +- T/K must stay in scope. A terminal-only T=1 shortcut would be invalid even if + it improved the steering row, because the same step body must extend to + `T>1,K=1` and `T>1,K>1`. + +Biggest remaining changes before T=1 can match or exceed April21: + +1. 
**Implement the real streaming producer-consumer native body.** + The current `streaming_step_producer_consumer` strategy is a valid ownership + gate, but it still needs a broader native body that shortens the + message/transition/readout live ranges together. This is the highest-impact + remaining forward owner. + +2. **Preserve GEMM/BMM grouping while reducing boundaries.** + The next body should combine or stream producer/consumer work without + breaking large matrix operations into slow scalar or tiny chunks. Chunk size + and grouping must be cost/strategy metadata over compiler rows, not + benchmark or hidden shape policy. + +3. **Make transition outputs true local workspace.** + `transition_forward_linear_output`, `transition_forward_matmul_output`, + `transition_forward_diag_output`, and `transition_forward_norm_output` must + stop overlapping with later message/readout high-water stages. The accepted + result must show allocator/native-stage movement, not only role metadata. + +4. **Close message-to-transition merge rows for mixed-pop.** + Mixed-pop cannot close through singleton-only streaming. The compiler needs + explicit multi-consumer merge/chunk rows so streaming remains legal for more + than one transition span. + +5. **Close readout/message producer-consumer execution without keyless shortcuts.** + Readout must consume compiler-proven producer routes. It may not skip + recurrent K/V or use direct/keyless readout unless the row proves semantic + equivalence and parity. + +6. **Keep T/K as the streaming guardrail.** + Every accepted T=1 step-body change must immediately run a small high-level + `T>1,K>1` guardrail through the same route identity. No Python replay, + detach policy, benchmark time chunking, or full `[T, cells, state]` + materialization may appear. + +7. **Reopen training only after forward moves.** + Training remains catastrophically behind, but broad reverse/reducer tuning + should stay a guardrail lane until the forward physical step body moves. + Once forward moves, reopen artifact/tape/reducer liveness against the same + compiler rows. + +Boundary rule for the next plan: + +- Do not add another narrow alias, direct-consume, project-into, or scalar + fusion patch unless it is explicitly part of the streaming physical step + body and predicts which named owner will move. +- Do not copy origin/main / April21 code. Use it only as evidence for the + physical execution shape, then express the mechanism through current + primitive rows, executor rows, tensor bindings, liveness rows, artifact + routes, output routes, and reducer rows. + +### 2026-05-04 - Plan: Streaming Step Producer-Consumer Native Body + +Status: planned next implementation. Do not start from a narrow buffer patch. +The next change must implement a broader compiler-owned streaming physical +step body. + +Boundary manifest: + +- Lane: throughput strategy / native implementation over existing compiler + products. +- Expected semantic delta: none. +- Stable rows: primitive rows, tensor binding rows, executor rows, output + route rows, artifact route rows, reset rows, tape rows, reducer rows. +- Changed rows/runtime: strategy/liveness/runtime validation and native + producer-consumer execution. +- Old path to avoid: materializing full message, transition, and readout stage + products when compiler rows prove a streaming producer-consumer route. 
+- Hard rejects: April21 code copy, benchmark tiling, hidden family/shape + selectors, keyless readout without legality proof, scalar/tiny-chunk GEMM + breakage, Python replay, or terminal-only T=1 shortcuts. + +April21 mechanism being semantically transferred: + +- Low-live-memory producer-consumer physical shape: + message -> transition/readout -> output with minimal full-stage + materialization. +- The transfer is only the mechanism. The implementation must consume current + compiler rows directly. + +Compiler products consumed: + +- `physical_strategy_rows` +- `memory_liveness_rows` +- `runtime_buffer_rows` +- `memory_runtime_schedule_rows` +- `forward_program_access_rows` +- `message_transition_producer_consumer_rows` +- `readout_message_producer_consumer_rows` +- `forward_output_route_rows` +- `forward_artifact_route_rows` +- `forward_artifact_merge_rows` +- forward executor rows and forward executor binding rows + +Implementation plan: + +1. **Build a step workgroup map from rows.** + Create a native-side workgroup for each physical step from message + producers, transition consumers, readout consumers, output routes, runtime + buffers, and liveness rows. This must be route-owned, not role-only. + +2. **Preserve large GEMM/BMM groups.** + Group compatible message projection, transition affine/projection, and + readout projection work by compiler row shape/layout facts. Do not split the + row into scalar or tiny chunks unless cost rows explicitly select that + strategy. + +3. **Make transition intermediates local scratch.** + For legal forward-only rows, keep `transition_forward_linear_output`, + `transition_forward_matmul_output`, `transition_forward_diag_output`, and + `transition_forward_norm_output` inside the step body as scratch/workspace. + Only public state, output route products, and required artifacts may escape. + +4. **Stream message to transition without retaining full recurrent message.** + Use the message-transition producer-consumer row to feed transition input + directly when legal. If multiple transition consumers exist, do not guess: + require explicit merge/chunk rows and keep the current typed blocker until + those rows exist. + +5. **Stream readout without keyless shortcuts.** + Readout may avoid fully retained recurrent K/V-after only when + readout-message producer-consumer rows prove the producer route and required + tensor roles. Computing K/V/readout inputs as local route scratch is allowed; + skipping semantic K/V is not. + +6. **Keep output route ownership unchanged.** + Output sequence/cells writes continue through `forward_output_route_rows`. + The streaming body may write into route targets, but it must not invent + output ownership. + +7. **Add C++ validation before execution.** + Validate the selected strategy, producer/consumer row coverage, liveness row + coverage, reset/artifact/tape constraints, and route offsets before launch. + Unsupported combinations fail closed with typed blockers. + +8. **Run the smallest representative proof.** + First row: h32 100M B1024 T=1 forward, sLSTM + Axon, single-pop. + Required follow-up before acceptance: mixed-pop T=1 forward and small + high-level `T>1,K>1` forward through the same route identity. + +Expected tensors to stop being fully materialized for accepted forward rows: + +- Full recurrent message after transition has consumed it. +- Transition linear/matmul/diag/norm products beyond local step scratch. 
+- Recurrent K/V-after for readout-only consumers when producer-consumer rows + prove a local projected route. +- Readout output-message intermediates when output route projection can consume + them locally. + +Expected owner movement: + +- Single-pop should move below `native_forward_message_after_normalize_local0` + / `native_forward_after_readout_message_local0`. +- Mixed-pop should move below `native_forward_after_transition_local0` only + after merge/chunk rows make multi-consumer streaming legal. +- Runtime role bytes for transition intermediates and recurrent message/hidden + must shrink or become scratch-owned in allocator telemetry. + +Keep/narrow/revert rule: + +- Keep if parity passes and a named native-stage or allocator owner moves with + no semantic row delta. +- Keep only as groundwork if validation/row plumbing improves but no owner + moves; do not claim throughput. +- Narrow or revert if throughput regresses without owner movement, peak grows, + row fingerprints change, keyless readout appears without proof, or strategy + selection depends on family, hidden size, benchmark row, or April21 code + shape. + +Validation gates: + +- Static/ABI: producer-consumer row tests, liveness row tests, output-route + validation, allocation classification guardrail, `git diff --check`. +- CUDA parity: fused forward no-artifact output-cells test, T=1 terminal-loss + artifact-store test, representative sLSTM/Axon forward parity. +- Performance steering: `t1-single-pop` h32 100M B1024 forward for sLSTM and + Axon, followed by mixed-pop T=1 and small `T>1,K>1` forward guardrail. +- Training remains guardrail-only until forward owner movement is real. diff --git a/ai_docs/REDO2_FIXMASS.md b/ai_docs/REDO2_FIXMASS.md new file mode 100644 index 00000000..ba0b4095 --- /dev/null +++ b/ai_docs/REDO2_FIXMASS.md @@ -0,0 +1,10655 @@ +# REDO2 Fixmass Plan + +Status: ACTIVE PLAN. + +Purpose: replace the drifted REDO_FIXMAASS execution order with a +compiler-first, throughput-first, scalability-first plan. REDO2 is the plan to +make Fabric a genuine compiler: user-declared graph, message, cell, readout, +reset, output, and temporal semantics lower into IR, then primitive op/tensor +rows, parameter bindings, primitive executors, and finally one shared temporal +runtime. The same compiled program must support single-pop and mixed-pop, +`T*K`, rolling horizon `H`, reset/state axes, and the required April 21 +throughput/memory gates. + +## Deep-Dive Refinement - 2026-04-29 + +Sources re-read for this refinement: + +- `skills/cb.fabric-backend-boundaries/SKILL.md` +- `skills/cb.fabric-performance-loop/SKILL.md` +- `skills/cb.fabric-scaling-horizon/SKILL.md` +- `skills/cb.fabric-parity-gate/SKILL.md` +- `ai_docs/prompt.tx` +- `ai_docs/additonal_goals.md` +- `ai_docs/AWS_RECOVERY_TRAIL.md` +- `ai_docs/recovered_core.py` +- `audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json` +- `benchmarks/fabric/audit.py` +- `benchmarks/fabric/suite_common.py` + +Corrections added by this read: + +- Forward-only long-T closure is separate from training T/H closure. Forward + rows do not use horizon-H, but they still must prove T streaming scaling. +- Training T/H and T*K closure must inherit the full T=1 matrix. Representative + rows steer backend work; they do not close scalable Fabric. +- Reset-present rows are closure rows, not optional smoke tests. Reset parity + and reset throughput must be covered wherever state, message, projection, + checkpoint, or temporal backward behavior changes. 
+- T=1 state cases are closure rows, not optional semantics. Cover backend-created + fresh state, detached provided state, and differentiable provided state. + Detached provided state still contributes values to recurrent/message + parameter gradients. State-provided versus backend-created state is an audit + axis over the same shared temporal path, not a route identity, shortcut key, + or permission to skip generic recurrent/message adjoint work. +- The recovered April 24 state had semantic rolling-H scaffolding but not the + efficient CUDA rolling-H backward. REDO2 must close the real + `BackwardWindowPlan`/streaming adjoint executor, not just planner metadata. +- The audit runner can help run and resume experiments, but dry-run case lists + are never progress evidence. Progress is measured tok/s, memory, parity, and + truthful owner movement from high-level API calls. +- R15 cleanup is the full recovered cleanup board: message rules, graph + ownership, Config/anatomy leakage, old public API paths, population identity, + hidden-size policy keys, stale routes, and tests. It is not only the recent + `config.py` and `anatomy.py` complaint. + +## Compiler Closure Contract + +Fixmass is not complete unless Fabric is a compiler in the active high-level +API path. This is the central backend item, not an R15 side task. + +Required chain for every supported Fabric path: + +`declaration -> Fabric IR -> primitive op rows -> tensor-table roles -> parameter bindings -> primitive executors -> shared temporal runtime` + +Adding a new primitive op must require only registry/lowering metadata, a +reference executor, an optional fused CUDA strategy, and tests. It must not +require editing temporal scheduler ownership, numeric tensor-slot enums, +cell-family route selectors, or monolithic scan/reverse ABIs. + +The compiler closure contract applies to: + +- Graph facts and flat bucket identity. +- Message-passing math, including source selection, aggregation, distance/delay + terms, normalization choices, reset behavior, and adjoints. +- Cell transition math, including projections, recurrent/local operators, + normalization, activation, public/private state emission, and adjoints. +- Readout/output/boundary math and parameter-gradient reductions. +- Temporal execution policy: `T*K`, horizon `H`, reset scheduling, + checkpoint/recompute, materialization, dependency ordering, and workspace. + +The temporal runtime is the scheduler for the compiled primitive program. It +may own physical time, `T*K`, H windows, reset masks, checkpoints, +materialization, dependency order, workspace, and flat-bucket traversal. It must +not own primitive formulas such as dot-product Q/K/V attention, gated sLSTM, +diagonal recurrence, layernorm, projection/readout equations, or their +adjoints. Those formulas are valid only inside primitive executors selected by +lowered op rows. + +The compiler is allowed and expected to use composite primitives and fused +primitive blocks for throughput. A recurrence such as +`gated_logspace_recurrence` or `diag_rtu` may be a first-class composite +primitive instead of a long scalar-op expansion. That composite is valid only +when it is declared in `fabric.cuda.nn`, lowered into primitive rows/tensor +roles, selected by the primitive executor, and paired with explicit forward, +backward, recompute/tape, and parameter-gradient behavior. A fused or composite +block must not become a hidden canonical algorithm that ignores user-declared +math. 
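+As a shape illustration of that extensibility contract, registering a new or
+composite primitive should be a local registry entry plus executors and tests.
+The sketch below is a hedged Python approximation; every name
+(`PrimitiveOpSpec`, `register_primitive`, `select_executor`) is hypothetical
+and the actual registry/lowering surfaces are not reproduced here:
+
+```python
+from dataclasses import dataclass
+from typing import Callable, Optional
+
+
+class UnsupportedPrimitiveError(RuntimeError):
+    """Typed fail-closed reason for declarations the compiler cannot lower."""
+
+
+@dataclass(frozen=True)
+class PrimitiveOpSpec:
+    # Hypothetical registry record; field names are illustrative only.
+    opcode: str                              # e.g. "gated_logspace_recurrence"
+    reference_executor: Callable[..., object]
+    reverse_executor: Callable[..., object]  # explicit backward contract
+    fused_cuda_strategy: Optional[str] = None  # optional, never the semantic owner
+
+
+PRIMITIVE_REGISTRY: dict[str, PrimitiveOpSpec] = {}
+
+
+def register_primitive(spec: PrimitiveOpSpec) -> None:
+    # Adding an op touches only this registry, its executors, and tests; it
+    # must not edit temporal scheduler ownership, slot enums, or family routes.
+    if spec.opcode in PRIMITIVE_REGISTRY:
+        raise ValueError(f"duplicate primitive opcode: {spec.opcode}")
+    PRIMITIVE_REGISTRY[spec.opcode] = spec
+
+
+def select_executor(opcode: str) -> PrimitiveOpSpec:
+    spec = PRIMITIVE_REGISTRY.get(opcode)
+    if spec is None:
+        # Fail closed with a precise reason instead of a hidden default route.
+        raise UnsupportedPrimitiveError(f"unsupported primitive opcode: {opcode}")
+    return spec
+```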
+ +Compiler extensibility is a closure requirement, not a cleanup nicety. Adding a +new primitive op must require only registry/lowering metadata, a reference +executor, optional fused CUDA strategy, and tests. It must not require editing +the temporal scheduler ownership, fixed tensor slot enums, cell-family route +selectors, or monolithic scan/reverse ABIs. If a new op cannot either enter this +chain locally or fail closed with a typed unsupported reason, Fabric is not yet a +true compiler. + +Target cell declaration style after R2.3/R15 cleanup is concise and +declarative. Cell files should describe the primitive program; primitive +executors should own math. sLSTM should look like this shape, with all declared +private states emitted: + +```cpp +struct SLSTM { + static fabric::cuda::nn::CellTransitionIR cell_transition_ir_host( + int64_t hidden_dim, + int64_t message_dim) { + fabric::cuda::nn::CellTransitionBuilder b("slstm"); + + auto y = b.private_state("y", {hidden_dim}); + auto c = b.private_state("c", {hidden_dim}); + auto n = b.private_state("n", {hidden_dim}); + auto m = b.private_state("m", {hidden_dim}); + auto msg = b.message_input("aggregated_message", {message_dim}); + + auto gate_logits = b.linear(msg, b.parameter("gate_weight"), b.parameter("bias")); + auto recurrent_gate_logits = b.matmul(y, b.parameter("recurrent_kernel")); + auto next = b.gated_logspace_recurrence(gate_logits, recurrent_gate_logits, c, n, m); + auto public_y = b.norm_or_identity(next.y, b.parameter("outnorm_weight")); + + b.emit_state("y", next.y); + b.emit_state("c", next.c); + b.emit_state("n", next.n); + b.emit_state("m", next.m); + b.emit_public(public_y); + return b.build_cell_transition(); + } +}; +``` + +Axon should use the same declaration style. The diagonal transition/eligibility +trace update is the supported composite primitive. 
Input projection, output +projection, state emission, public emission, and composite options remain +declared compiler rows/attributes; they must not be hidden in an Axon family +bundle: + +```cpp +struct Axon { + static fabric::cuda::nn::CellTransitionIR cell_transition_ir_host( + int64_t hidden_dim, + int64_t message_dim) { + fabric::cuda::nn::CellTransitionBuilder b("axoncell"); + + auto hc1 = b.private_state("hc1", {hidden_dim}); + auto hc2 = b.private_state("hc2", {hidden_dim}); + auto e_nu_c1 = b.private_state("E_nu_c1", {hidden_dim}); + auto e_nu_c2 = b.private_state("E_nu_c2", {hidden_dim}); + auto e_th_c1 = b.private_state("E_th_c1", {hidden_dim}); + auto e_th_c2 = b.private_state("E_th_c2", {hidden_dim}); + auto e_w1_c1 = b.private_state("E_w1_c1", {hidden_dim}); + auto e_w1_c2 = b.private_state("E_w1_c2", {hidden_dim}); + auto e_w2_c1 = b.private_state("E_w2_c1", {hidden_dim}); + auto e_w2_c2 = b.private_state("E_w2_c2", {hidden_dim}); + auto msg = b.message_input("aggregated_message", {message_dim}); + + auto input_weight = b.parameter("input_proj_weight"); + auto input_bias = b.parameter("recurrent_cell_bias"); + auto nu_log = b.parameter("nu_log"); + auto theta_log = b.parameter("theta_log"); + auto w1 = b.parameter("w1"); + auto w2 = b.parameter("w2"); + auto activation = b.parameter("activation_id"); + auto out_weight = b.parameter("out_proj_weight"); + auto out_bias = b.parameter("out_proj_bias"); + + auto cell_input = b.linear(msg, input_weight, input_bias); + auto next = b.diag_rtu( + cell_input, + hc1, + hc2, + e_nu_c1, + e_nu_c2, + e_th_c1, + e_th_c2, + e_w1_c1, + e_w1_c2, + e_w2_c1, + e_w2_c2, + nu_log, + theta_log, + w1, + w2, + activation); + auto public_y = b.linear(next.preproj, out_weight, out_bias); + + b.emit_state("hc1", next.hc1); + b.emit_state("hc2", next.hc2); + b.emit_state("E_nu_c1", next.E_nu_c1); + b.emit_state("E_nu_c2", next.E_nu_c2); + b.emit_state("E_th_c1", next.E_th_c1); + b.emit_state("E_th_c2", next.E_th_c2); + b.emit_state("E_w1_c1", next.E_w1_c1); + b.emit_state("E_w1_c2", next.E_w1_c2); + b.emit_state("E_w2_c1", next.E_w2_c1); + b.emit_state("E_w2_c2", next.E_w2_c2); + b.emit_public(public_y); + return b.build_cell_transition(); + } +}; +``` + +These examples are not facade APIs. Closure requires each builder call to lower +into Fabric IR, primitive rows, tensor-table roles, parameter bindings, and +primitive executor dispatch. If changing one of these declarations would not +change the compiled program or fail closed as unsupported, the compiler owner is +still open. + +The current code is not allowed to call this complete just because the target +sketch exists. The old indexed C++ builder shape, the PyTorch Axon +`activation_name` side channel, and any CUDA helper that selects diagonal or +gated math by cell-family shape are R2.3/R15 blockers until they are replaced by +row-owned composite primitive executors selected from the compiled transition +program. + +Unsupported declarations must fail closed at lowering or executor selection. +No broad public API, metadata label, `fabric.cuda.nn` wrapper, or benchmark +result counts as compiler closure if the active backend silently runs the same +hidden dot-product/gated/diagonal/readout route after the declaration changes. + +## Non-Negotiable Rules + +- Compiler closure is a hard pass/fail gate. REDO2 cannot close on parity or + throughput from a facade path. The active high-level row must prove the full + declaration-to-primitive-executor chain for both cell math and message math. 
+- Throughput evidence means measured `tokens_per_s`, `peak_mem_gib`, owner + metadata, and artifact paths. Dry-run case lists are not progress. +- T=1 is the base streaming case. If T=1 training is broken, T/H/K results are + diagnostic only and cannot close REDO2. +- H is not a T=1 closure axis. T=1 provides the baseline full-step reference; + horizon-H is the rolling TBPTT window for `T>1` sequence training. +- TBPTT horizon is clipped at actual loss-emission points and actual available + physical steps. If `T*K <= H`, there is no extra horizon scan beyond the + sequence; the backward window is just the full available dependency range for + the emitted losses. H must never cause replay, checkpointing, or materialized + scan work outside the real `T*K` stream. +- T/H closure is a full scalability matrix. It must cover every shape, size, + family, batch, hidden-size, population mode, reset mode, and output-boundary + contract used for T=1 closure, not only a representative row. +- K closure also inherits the scalability matrix and must run as T*K, not only + as `T=1,K>1`. Representative K probes steer implementation, but final K=128 + support cannot be claimed from one small row. +- Fabric parallelism has spatial axes and temporal streaming axes. `B`, params, + graph/node count, topology, population buckets, and `h` are the spatial/work + axes. `T` and `K` are temporal streaming axes over the same T=1 substrate: + `T*K` is repeated T=1 execution in one backend-owned temporal engine, with + different output materialization and checkpoint/recompute policy, not a + separate algorithm or route. +- Benchmarks are high-level API consumers only: model forward, external loss, + `loss.backward()`, and optimizer step where applicable. No benchmark-owned + temporal tiling, detach policy, checkpoint policy, private planner/runtime + calls, or streaming-loss helper can be used for closure. +- The backend must stay generic. No cell-kind selectors, population-name + selectors, benchmark-row branches, lattice/config fields, hidden-size policy + keys, or separate single/mixed execution identities in temporal kernels. +- Planner metadata must be truthful. `cuda_temporal_superop` can appear as a + backward owner only after the active measured path physically runs there. +- April 21 JSON is the historical score source. Current-code matched T=1 rows + are required for T/H/K scaling floors. +- Checkpoint steps are planner-determined unless a case explicitly requests an + override. The selected policy must be recorded in metadata. +- Checkpointing and replay are planner/workspace decisions derived from shape, + size, tape pressure, output/loss materialization, reset policy, `T*K`, H, and + device memory. They must not be fixed routes. Small or `T*K <= H` cases can + store/directly feed the needed generic reverse-table artifacts when that is + cheaper; larger cases use planner-selected checkpoint/recompute windows. +- Every temporal CUDA kernel/superop must pass a manual boundary review before + it is accepted: generic tensor/op table ABI, no cell-family bundle, no + population-name route, no benchmark-row route, no hidden-size policy key, and + no separate single/mixed path. +- No compatibility shims are accepted as the target design. During cleanup, + callsites move to the correct owner and the old route/export is deleted or + fails closed. +- Experiments for this REDO2 run use GPUs 0-4 only, with private Torch/Triton + cache directories under `/tmp` or another run-specific path. 
Do not use the + global extension cache for CUDA audit work. +- Policy changes must be generic. A fix can target the current blocking owner, + but the code path must derive from plan/runtime facts such as `T*K`, horizon, + reset policy, output materialization, tensor/op tables, graph facts, bucket + identity, dtype/device, and workspace pressure. It must not derive from an + audit row id, family name, target parameter label, hidden-size constant, + single-vs-mixed label, or a fabricated unit-test reference. +- Unit tests for audit gates must use the April 21 reference loader or explicit + planner/runtime metadata fixtures. Do not add fake throughput rows to encode + REDO2 policy; if a synthetic result is needed, it must be testing a generic + comparator contract and not pretending to be an audit closure row. + +## Current Measured State + +- Blocking row: + `sLSTM 500M, B=512, h=32, T=1, K=1, sequence loss`. +- Artifact: + `/tmp/redo_fixmaass_t1_slstm500m_b512_h32_sequence_owner_timing/cases.jsonl`. +- Current result: + `52.781 tok/s`, `9700.488 ms`, peak `125.569 GiB`. +- April 21 floor: + `streaming_per_timestep_sequence_loss`, `95,638.46 tok/s`, `61.21 GiB`. +- Current owner: + forward reports `cuda_temporal_superop`; backward still reports + `python_autograd_scan`, `python_host_reverse_loop`, and + `temporal_artifacts:recompute_step_artifacts`. +- Largest measured owners: + `temporal_artifact_recompute=4096.730 ms`, + `artifact.recompute.cuda_temporal_replay_scan=4027.582 ms`, + `state_epilogue.core=1029.082 ms`. + +Conclusion: R4/shared temporal backward is the immediate blocker. REDO2 starts +by fixing T=1 training throughput and memory before treating T/H/K as closure. + +Additional diagnostic already attempted: + +- `sLSTM 500M, B=512, h=32, T=4096, K=1, H=64, sequence loss` failed OOM after + reaching about 129 GiB in-use and attempting another 32 GiB allocation. +- That row is not enough to diagnose a long-T-specific owner because the matched + T=1 training row is already badly regressed. Keep it as evidence that the + current path is not streaming-bounded, then fix T=1/shared backward first. +- A newer current-code T=1 guardrail after the reverse-window sequence-binding + probe also failed: + `/tmp/redo_fixmaass_t1_slstm500m_b512_h32_sequence_reverse_seq/cases.jsonl`. + Result: `44.607 tok/s`, `11477.999 ms`, peak `125.569 GiB`. Metadata still + showed `temporal_plan_backward_owners=["python_autograd_scan"]`, + `flat_bucket_temporal_reverse_scan_owner:python_host_reverse_loop`, and + `temporal_artifacts:recompute_step_artifacts`. Dominant owners remained + `temporal_artifact_recompute=4102.506 ms`, + `artifact.recompute.cuda_temporal_replay_scan=4030.251 ms`, + `public_projection=1537.729 ms`, and `state_epilogue.core=1028.936 ms`. + Conclusion: the sequence-binding probe is not R2.1 closure; the next actual + backend owner is the T=1 CUDA temporal checkpoint-consumption and + artifact/replay path. +- A direct T=1 forward-stored reverse-table artifact probe also failed to move + the row: + `/tmp/redo_fixmaass_t1_slstm500m_b512_h32_sequence_direct_reverse/cases.jsonl`. + Result: `44.607 tok/s`, peak `125.569 GiB`, with the same + `temporal_artifact_recompute`/`cuda_temporal_replay_scan` owner. The probe did + not trigger because the current reverse-table executor requires both a gated + bucket and a diagonal bucket; the April21-shaped single-pop sLSTM row has one + transition bucket. This is rejected as closure evidence. 
Next backend owner: + make the reverse-table/sequence owner generic over one-or-many primitive + buckets instead of assuming a two-bucket gated+diagonal mix. + +Generic policy update in progress: + +- Planner horizon policy now clips requested TBPTT H to the actual available + physical stream (`T*K`) before selecting the backward window and default + checkpoint stride. If requested H covers the full stream, the plan records a + full-horizon window instead of creating a rolling-H scan. This applies to all + Fabric rows and is not tied to the current 500M/B512 blocker. +- Active reverse-only fallback failures now report both the generic reverse + engine rejection reason and the CUDA reverse-table error. The next backend + probe must use that reason to fix the table-owned reverse executor rather + than adding row-specific admission logic. + +Current-code guardrail after reverse primitive-row binding: + +- Row: + `t1-single-pop_slstm_1m_forward_backward_b16384_t1_k1_h32_ghnone_ckplanner_losssequence_popsingle_resetabsent`. +- Artifact: + `/tmp/redo2_t1_guard_1m_b16384_h32_sequence/cases.jsonl`. +- Result: + `231,971.14 tok/s`, `70.629 ms`, peak `1.904 GiB`. +- April 21 summary floor: + `h32_small_params_high_batch`, `1,986,978.16 tok/s`, peak `12.97 GiB`. +- Gate: + failed throughput (`tokens_per_s_below_april21_reference`). Memory is below + floor, but throughput is about 8.6x below the April21 protected score. +- Owners: + forward and backward both report `cuda_temporal_superop`; reverse artifacts + report `forward_reverse_tables`; backward launch counts include + `cuda_transition_message_reverse_table_window` and + `cuda_transition_message_reverse_table_device_loop`. +- Remaining blocker: + the owner gate still fails for + `temporal_primitive_executor_blockers_present` across message, readout, + parameter reduction, affine, recurrence, and norm rows. This confirms the + next owner is real primitive executor dispatch and deletion of the bundled + transition/message reverse program, not more planner relabeling. + +Warm confirmation and owner split: + +- Forward-only warm row: + `/tmp/redo2_t1_guard_1m_b16384_h32_forward_warm/cases.jsonl`, + `1,849,370.53 tok/s`, `8.859 ms`, peak `0.820 GiB`. This is close to, but + still below, the April21 summary floor. +- Training warm row: + `/tmp/redo2_t1_guard_1m_b16384_h32_sequence_warm/cases.jsonl`, + `236,306.83 tok/s`, `69.334 ms`, peak `1.904 GiB`. +- Timed training row: + `/tmp/redo2_t1_guard_1m_b16384_h32_sequence_timed/cases.jsonl`, + `230,088.37 tok/s`, `71.207 ms`. +- Dominant CUDA-event owner: + `transition_message_reverse_table_device_loop: 37.762 ms`. The next + throughput owner is the bundled reverse table program and its primitive + executor/parameter-reduction split, not forward scan admission. + +## Audit Matrix Contract + +### April 21 Score Source + +The April 21 JSON is the historical score source and defines the required +surface classes: + +- `h32_t1_bxparams`: sLSTM + Axon, 100M/500M/1B, forward + training, + B=1024/16384, 24/24 rows. +- `h32_small_params_high_batch`: sLSTM + Axon, 1M/10M, forward + training, + B=16384/65536/131072, 24/24 rows. +- `h4_many_cell_stress`: sLSTM + Axon, 100M/500M/1B, forward + training, + B=1024/16384, 24/24 rows. +- `h8_many_cell_stress_broad` plus + `h8_many_cell_stress_focused_warmed_rerun`: broad h=8 coverage plus the + focused Axon 1B train B=1024 rerun that superseded the noisy row. +- `h16_many_cell_stress`: sLSTM + Axon, 100M/500M/1B, forward + training, + B=1024/16384, 24/24 rows. 
+- `streaming_per_timestep_sequence_loss`: exact streaming rows for sLSTM/Axon, + 500M/1B, B=512, T=512/4096, h=32. +- `flat_graph_factorization_invariance`: equivalent flat graphs with only + user-side factorization labels varied. +- `rollout_curves_vs_hf`: SmolLM2 135M, SmolLM2 360M, TinyLlama 1.1B, + sLSTM + Axon, B=128, h=32. + +Closure rows must cite either the exact streaming reference key or the April 21 +summary audit that controls the row. If the current audit runner only has a +summary floor for a broad group, the runner must still emit enough row metadata +to show which April 21 surface class is being protected. + +### T=1 Matrix + +The T=1 closure matrix is the base matrix. Every later scalability matrix +inherits these axes unless a row is explicitly inapplicable and documented: + +- Families: `slstm`, `axoncell`. +- Population modes: `single`, `mixed` where mixed means the same shared engine + with multiple flat buckets, not a wrapper. +- Parameter targets: every April 21 T=1 size group, including `1M`, `10M`, + `100M`, `500M`, and `1B` where present. +- Batches: every April 21 T=1 batch group, including `B=1024`, `B=16384`, + high-batch `B=65536/131072`, plus streaming-sequence reference rows such as + `B=512` where they are the controlling April 21 references. +- Hidden sizes: headline `h=32` plus stress `h=4/8/16/32` for hidden-size + closure. +- Modes: forward and forward_backward. Horizon-H applies to training rows; + forward rows still need T-scaling coverage without backward H. +- Output boundaries: terminal and per-timestep sequence loss/output where + applicable. +- Reset modes: absent and present for semantic/reset closure. +- State modes: backend-created fresh state, detached provided state, and + differentiable provided state. Provided state affects the current recurrent + message/transition computation even when detached; differentiable provided + state additionally exposes user-state gradients. All three modes must run + through the same generic flat-bucket temporal engine and must not use + state-presence-specific backend routes or sender-gradient shortcuts. +- Graph/factorization rows: every factorization/shape group used for T=1 + closure or hidden-size/factorization closure. + +T=1 pass criteria: + +- Throughput must be greater than or equal to the matched April 21 reference. +- Peak memory must be less than or equal to the matched April 21 reference. +- Metadata must show shared temporal engine ownership. Training rows cannot + close with `python_autograd_scan`, `python_host_reverse_loop`, old sequence + helpers, or mixed-pop wrappers. + +### Forward T Matrix + +Forward T scaling is required even though horizon-H is a training concept. + +- For every T=1 forward matrix row, run `T=512`, `T=4096`, and frontier `T=16K` + where memory allows with `K=1`. +- Use the same graph, batch, params, hidden size, population mode, reset mode, + and output-boundary contract as the T=1 row. +- Compare against the matched current-code T=1 forward row and any applicable + April 21/HF reference. + +Forward T pass criteria: + +- Increasing T must not materially reduce per-token throughput. +- Peak memory must show streaming behavior and must not retain full recurrent + state surfaces unless the public API explicitly requested them. +- Metadata must show the same shared temporal engine route, not a terminal-only + shortcut or a separate long-T path. + +### Training T/H Matrix + +T/H is not a small representative audit. 
For every T=1 matrix row that has a +training contract, run the corresponding long-T rows: + +- T sweep values for staged closure and diagnosis: `T=1,2,4,8,...,4096`, + plus frontier `T=16K` where memory allows. `T=512` and `T=4096` remain + required named closure rows; smaller powers of two are guardrails used to + find the first scaling break. +- K value for T/H base closure: `K=1`. +- Horizon: `H=64` is the first closure target for `T>1` rolling TBPTT. +- Output boundaries: terminal and per-timestep sequence loss. +- Checkpoint policy: planner-selected unless explicitly provided by the audit + row; every choice must be recorded in metadata. +- Batches/model sizes/hidden sizes/population modes are inherited from the T=1 + matrix. `B=1` or tiny parameter rows are smoke only unless they are part of + the closure matrix. + +T/H pass criteria: + +- For each row, compare against the matched current-code T=1,K=1 training row + with the same graph, batch, params, hidden size, population mode, reset mode, + and output-boundary contract. +- `T=512,K=1,H=64` and `T=4096,K=1,H=64` must stay flat or above the matched + T=1 training throughput floor. +- If an exact April 21 streaming-sequence reference exists for that row, the + row must also meet or exceed that April 21 throughput and stay within its + memory ceiling. +- Peak memory must be bounded by backend-owned streaming/checkpoint policy. + Shrinking batch, hidden size, graph shape, or model size to make the row fit + is not closure. +- T=16K is required as a frontier sweep where memory allows. If it cannot fit, + record the exact memory owner and reopen the backend stage; do not silently + drop the row. + +### T*K Matrix + +K is the same temporal stream as T*K. The only semantic difference is output +materialization/emission. + +- K sweep: `K=1,2,4,8,...,128`. +- K=128 is the current accepted ceiling. +- K rows must run over the inherited scalability matrix and over long-T rows. + During development, representative probes can use a smaller subset, but final + closure cannot. +- Final closure requires at least `T=512` and `T=4096` with `H=64` for the full + inherited training matrix. `T=16K,K=1` is the required frontier; `T=16K,K>1` + is explored where memory allows and any failure must name the owner. +- `T=1,K>1` rows are implementation probes for the K executor and the T=1/K + floor; they do not by themselves exercise or close H because H is a rolling + sequence horizon. +- For K>1, the throughput floor is matched current-code T=1,K=1 training + throughput divided by K. +- K rows must cover terminal and per-timestep loss and must retain reset parity. +- Report raw tok/s, `tokens_per_s * K`, active-cell microstep rate where + available, peak memory, horizon, checkpoint policy, and emission policy. +- Scalar K is the current closure surface, but the plan representation must be + future-shaped for per-timestep K through cumulative scan offsets. Per-timestep + K must not require a separate executor identity. + +### H Matrix + +- H=64 is the first hard closure target. +- H sweep: `H=1,2,4,8,...,64`. H=64 is the pass gate. +- H>64 exploration is allowed only after H=64 closes. It cannot weaken H=64 + criteria. +- Increasing H may reduce tok/s, but H=64 must remain at or above the matched + K-adjusted T=1 floor where expected. + +## REDO2 Stages + +### R2.0 - Evidence And Doc Discipline + +Status: OPEN. + +Goals: + +- Treat this document as the REDO2 scratchpad. 
+- Every measured row must include command, artifact path, tok/s, memory, + reference key, owner metadata, and pass/fail reason. +- Every code slice must update this doc before or immediately after execution. + +Exit: + +- Current blockers, measured owners, and next row are always visible in this + file and committed at useful checkpoints. + +### R2.1 - Compiler Front Door And Semantic IR + +Status: ACTIVE. + +Owner: Fabric compiler lowering. + +Goals: + +- Make the high-level Fabric API compile user declarations into semantic IR, + not into old `Config`/anatomy/runtime shortcuts. +- Graph constructors own topology and boundary facts. Generic Fabric IR owns + flat graph facts, edge/group tables, boundary sets, flat bucket identity, and + reset/state/output requests. +- Message declarations own message math. Cell declarations own transition + math. Readout declarations own output/readout math. Execution request types + own `T`, K, H, output materialization, state materialization, and checkpoint + requests. +- `compile_fabric_ir` must preserve declared message, cell, graph, readout, + reset, state, and temporal semantics. It must not synthesize or substitute a + default dot-product/gated/diagonal/readout program when the user declared a + different program. +- Unsupported declarations fail closed during lowering with a precise + unsupported-op/source/shape reason. +- Remove or fail-close front-door surfaces that make broad promises while still + lowering into one hidden canonical algorithm. + +Exit: + +- Changing a supported declaration changes the compiled IR and primitive + program that executes. +- Changing an unsupported declaration fails closed before execution. +- Backend IR carries the declared graph/message/cell/readout/output/state + semantics needed by the primitive table builder. +- Tests prove declared message rules and cell/readout declarations are not + silently replaced by defaults. +- No T=1/T/H/K throughput row is accepted as closure unless it comes through + this compiler front door. + +### R2.2 - Primitive Program And Tensor-Table Lowering + +Status: ACTIVE. + +Owner: compiler primitive-program lowering. + +Goals: + +- Lower semantic IR into a single executable primitive program: + message rows, projection/affine rows, recurrent/local transition rows, + normalization rows, readout/boundary rows, parameter-reduction rows, and + reset/materialization/tape roles. +- `TemporalPrimitiveTablePlan` or its successor is the compiler product, not a + metadata side channel. It must describe all executable primitive rows, tensor + roles, parameter bindings, flat bucket identity, receiver ranges, dependency + ordering, reset policy, tape/artifact needs, and forward/backward behavior. +- Message rules lower from `MessageRuleIR` into message primitive rows. Q/K/V + may exist only as roles inside the message primitive program and executor, + not as temporal-engine ABI. +- Cell transitions lower from cell declarations into transition primitive rows. + Gated or diagonal recurrence can be supported primitive opcodes, but not + hidden temporal-kernel formulas. +- Readout, boundary backward, and parameter reductions lower into primitive + rows instead of runtime glue. +- Delete or fail-close any path that fabricates private gated/diagonal/message + primitive rows outside the compiler product. + +Exit: + +- The active runtime table contains every primitive surface needed for forward + and backward: message, transition, norm, readout, boundary, and parameter + reduction. 
+- Missing primitive rows are compiler failures, not audit warnings. +- Unit and runtime metadata prove the table came from Fabric IR, not from a + hardcoded temporal binding fallback. +- This stage does not close on labels alone; it closes only when the primitive + program is the only source of truth accepted by the temporal runtime. + +### R2.3 - Primitive Executors And Shared Temporal Runtime + +Status: ACTIVE. + +Owner: primitive executor registry plus shared temporal CUDA runtime. + +Goals: + +- Add the real primitive executor boundary: primitive rows select forward and + backward executors through opcode, tensor-table roles, parameter bindings, + shape metadata, and reset/tape policy. +- Primitive executors own math: message operators, projections, recurrent/local + transition operators, normalization, activations, readout, boundary adjoints, + carry/input gradients, and parameter reductions. +- Add composite/fused executor dispatch as a compiler optimization, not a + semantic shortcut. Supported groups are selected from primitive rows and + tensor roles after lowering. For example, `gated_logspace_recurrence`, + `diag_rtu`, message projection + segment ops, and readout/boundary + reductions may execute as CUDA blocks only if the block is selected by the + compiled rows and reports row/tensor-role coverage in metadata. +- The shared temporal runtime owns only physical time, `T*K`, H windows, + checkpoint/recompute, output/loss materialization, reset scheduling, + dependency order, workspace, and flat-bucket traversal. +- Replace the hardcoded forward temporal formulas with primitive executor + dispatch. `flat_bucket_temporal_scan_kernels.cu` or its replacement may + schedule and fuse primitive rows, but it may not directly encode cell/message + formulas as the semantic source of truth. +- Replace the hardcoded reverse temporal formulas with reverse primitive + executor dispatch. `transition_message_reverse_table_device_loop` and + similar bundled reverse programs cannot be closure owners. +- After a primitive executor replaces any legacy kernel, delete the old + kernel/export/binding or make it fail closed immediately. This applies across + the whole compiler surface: message, transition, projection, normalization, + activation, readout, boundary adjoints, parameter reductions, temporal + scheduling, and any old route-specific kernels. The compiler target is less + code with one executable source of truth, not a new compiler layer sitting + beside legacy kernels. +- Single-pop and mixed-pop use the same compiled primitive program and temporal + runtime. Population cardinality changes only flat bucket rows and + parameter/state bindings. +- T=1, T>1, K>1, and T*K use the same temporal substrate. T=1 is the + degenerate stream, not a special executor identity. + +Exit: + +- Runtime/audit metadata reports no primitive executor blockers. +- Code review greps show no primitive formulas in temporal scheduler files + outside selected primitive executors. +- Training metadata truthfully reports + `temporal_plan_backward_owners=["cuda_temporal_superop"]` only when backward + is physically in the compiled primitive temporal runtime. +- No `python_host_reverse_loop`, `python_autograd_scan`, + `transition_message_reverse_table_device_loop` as semantic owner, + hardcoded gated/diagonal temporal formula owner, or dot-product-only message + runtime owner appears in closure rows. 
+- Single-pop and mixed-pop rows report the same temporal engine identity and + differ only by flat bucket/parameter/state bindings. +- Representative T=1 parity and reset parity pass through this compiled path + before throughput is accepted. + +### R2.4 - Planner Truth And Policy Ownership + +Status: OPEN. + +Owner: planner and temporal compiler policy. + +Goals: + +- Planner owns K schedule, H horizon, checkpoint steps, emission/materialization + policy, reset policy, static/tape policy, and backend selection. +- Planner records output request, autograd seed surface, finite-H backward + window, checkpoint stride, recompute window, reverse artifact kind, and + required primitive backward surfaces for the whole `T*K` stream. +- Runtime executes the compiled primitive program under the plan and reports + metadata; it does not reconstruct policy from batch, hidden size, family + name, population count, or audit row. +- Checkpoint steps are planner-determined unless explicitly provided. +- H is clipped to actual loss-emission points and available physical steps. + `T*K <= H` is a full available-window backward plan, not a separate scan. + +Exit: + +- Audit metadata exposes the selected policies and owners. +- No runtime or benchmark branch chooses temporal policy from batch, model row, + hidden size, cell family, population count, or lattice shape. +- T=1, T>1, K>1, terminal output, per-timestep output, provided state, and + reset-present rows enter the same planned compiler/runtime path. + +### R2.5 - Continuous Representative Throughput Probes + +Status: OPEN. + +Owner: performance/audit and canonical benchmark runner. + +Goals: + +- Run measured guardrail rows during backend work, not only at the end. +- Guardrail failures immediately reopen the responsible compiler, executor, + backend, or planner stage. A throughput pass through a facade path is a + failure, not progress. +- Keep the canonical audit entrypoint under `benchmarks/fabric/` as a high-level + API consumer. It may build case definitions, shard cases, and record metadata; + it must not choose backend policy or implement temporal execution. +- Extend the runner with explicit high-level scopes for T=1, forward-T, + training-T/H, T*K/H, reset, mixed-pop, hidden-size, and factorization closure. + Comparator unit tests may validate mechanics, but closure criteria come from + this document plus the April 21 JSON and measured current-code baselines, not + synthetic fake-reference rows. +- Use GPUs 0-4 only for this REDO2 work, and set private + `TORCH_EXTENSIONS_DIR`/`TRITON_CACHE_DIR` for CUDA runs. +- Periodically run broader backend sweeps before closing an owner. Do not wait + for final closure to discover that a backend slice broke T=1, reset, + mixed-pop, small-h, or long-T rows. +- Every guardrail must record compiler metadata: declaration source, IR message + rule/cell/readout identities, primitive row count/families, primitive + executor blockers, temporal forward/backward owners, materialization policy, + checkpoint policy, and forbidden-route markers. + +Required guardrails after backend slices: + +- T=1 owner-timed representatives: + `sLSTM 500M, B=512, h=32, sequence loss` plus April21-shaped + `100M/500M/1B` rows such as `B=1024, h=32` terminal/per-timestep training + where applicable. The `100M/500M/1B` rows are not optional because + current-code large-model failures expose memory and parameter-gradient + regressions that the 1M fast probe misses. 
+- At least one April 21 T=1 `B=1024` row and one high-batch row when the path + changes. +- Fast 1M/high-batch rows are only guardrail probes. They never close T=1 and + cannot replace 100M/500M/1B April21-shaped rows in the representative or full + closure audit. +- T=1 K probes are allowed only as K-executor diagnostics; do not attach H to + them or claim H closure from them. +- T/H representative: + `T=512,K=1,H=64` and `T=4096,K=1,H=64` on the matched representative row. +- K representative: + `K=1,2,8,32,128` on the matched representative row. +- Reset parity: + T=1 and T>1 reset/no-reset. +- Mixed-pop T=1 training. +- Small-h/shape guard after bucket/layout changes. +- A current-code T=4096,K=1,H=64 sequence-loss row at the same B/params/h as a + healthy matched T=1 row before claiming any long-T owner movement. + +Exit: + +- Every accepted backend slice has current measured evidence and artifact + paths. No slice closes on a dry-run case list or metadata-only change. +- The benchmark folder has one canonical Fabric audit path with selectable + scopes for quick, closure, and frontier runs. + +### R2.6 - Full T=1 Closure + +Status: OPEN. + +Owner: performance/audit. + +Goals: + +- Run the full T=1 matrix for single-pop and mixed-pop. +- Compare single-pop to April 21 references. +- Compare mixed-pop to matched same-parameter stack/MoE and April 21 + controlling references. + +Exit: + +- Every T=1 row passes throughput, memory, parity, and compiler/shared-engine + metadata. +- Every T=1 row proves the declaration-to-primitive-executor chain for message, + cell, readout, boundary, and parameter-gradient surfaces. +- Any T=1 row with primitive executor blockers, hardcoded temporal formula + owners, or hidden dot-product/gated/diagonal route owners remains open even + if values match. + +### R2.7 - Full Forward-T And Training T/H Closure + +Status: OPEN. + +Owner: performance/audit and temporal backend. + +Goals: + +- For every T=1 forward matrix row, run forward T scaling at `T=512`, `T=4096`, + and frontier `T=16K` where memory allows. +- For every T=1 training matrix row, run `T=512,K=1,H=64` and + `T=4096,K=1,H=64`. +- Run T=16K frontier for every row where memory allows; record exact failures + for rows that do not fit. +- Cover terminal and per-timestep loss. +- Include single-pop and mixed-pop. +- Include h stress and factorization rows as inherited from the T=1 matrix. +- Include reset-present and reset-absent rows inherited from the T=1 matrix. + +Exit: + +- T scaling is flat or better against matched T=1 for every row. +- Exact April 21 streaming rows pass their historical floor. +- H=64 works without backend memory blowups. +- No row passes via benchmark-owned tiling or route-specific shortcuts. +- Forward-T and training T/H rows use the same compiled primitive program and + temporal runtime identity as their matched T=1 rows. + +### R2.8 - Full T*K/H Closure Across The T=1 Matrix + +Status: OPEN. + +Owner: performance/audit and temporal backend. + +Goals: + +- Run K sweep `1,2,4,8,...,128` over the inherited matrix. +- Run T sweep `1,2,4,8,...,4096` plus frontier `16K` where memory allows. +- Run H sweep `1,2,4,8,...,64`, with H=64 as the closure gate. +- Cover long-T rows, not only `T=1,K>1`: `T=512` and `T=4096` are closure rows + with H=64; `T=16K` is frontier where memory allows. +- Judge every K>1 row against matched T=1/K. +- Preserve H=64, reset parity, terminal loss, and per-timestep loss. 
+- Prove scalar K uses the same executor family as K=1 and that planner metadata + is future-shaped for per-timestep K. + +Exit: + +- K=128 is accepted across the scalability matrix, not only a small smoke row. +- K extra-work scaling is measured and documented row-by-row. +- `tokens_per_s * K` remains stable enough to show that K is extra work over + the same streaming engine, not benchmark-side looping or route drift. +- K rows prove the same primitive program and temporal runtime as K=1, with + only planner-recorded K schedule, materialization, checkpoint, and H policy + changing. + +### R2.9 - Hidden-Size And Factorization Closure + +Status: OPEN. + +Owner: bucket lowering/performance. + +Goals: + +- Rerun h=4/8/16/32 across representative and closure rows. +- Rerun factorization/shape groups at 1B where applicable. +- Prove flat graph identity, not lattice shape, determines backend behavior. + +Exit: + +- Reducing hidden size does not reduce tok/s. +- Shape/factorization spread remains within the closure threshold. +- No backend policy consumes lattice width/height/depth/wrap or hidden-size + constants as route identity. + +### R2.10 - Continuous R15 Cleanup And Legacy Deletion + +Status: OPEN. + +Owner: cleanup and API ownership. + +Goals: + +- Delete or fail-close legacy execution routes as soon as the compiler/runtime + stage that replaces them lands. Final R15 closure still waits for R2.6-R2.9, + but stale sibling routes must not accumulate during compiler work. +- Delete legacy CUDA kernels, pybind exports, Python wrappers, and tests across + all compiler-owned surfaces once their primitive executors have parity, reset + coverage, owner metadata, and representative throughput evidence. This + includes message, transition, projection, normalization, activation, readout, + boundary adjoints, parameter reductions, temporal scheduling, and any + route-specific kernels. The final compiler should reduce code volume by + removing obsolete kernels instead of permanently wrapping them. +- Remove old sequence helpers, direct-grad/checkpoint paths, stale route names, + population wrappers, benchmark wrappers, hidden-size/cell-family planner + policy, and Config-era truth paths. +- Redesign ownership so graph constructors own lattice facts, message rules own + message semantics, cells own local transition declarations, readout owns + readout declarations, and execution specs own planner requests. +- Move message and cell math through `fabric.cuda.nn`/IR primitive declarations + and tensor/op rows. +- Delete or move the recovered additional-goals inventory: + 1. `Blueprint.message_passing` cannot be concretely `DotProduct`. + 2. `DotProduct.to_ir()` cannot discard declaration fields. + 3. Backend IR must receive lowered message-rule declarations. + 4. Message-rule classification cannot recognize only dot-product attention. + 5. CUDA message lowering cannot depend on one exact dot-product graph shape. + 6. CUDA message-rule builders need real generic lowerer ownership. + 7. PyTorch reference message passing must execute the same rule IR semantics. + 8. Topology bucketing and message semantics must be separated. + 9. Blueprint cannot remain a facade over old `Config`. + 10. Old config concepts cannot remain central internal truth. + 11. Runtime old `d_hidden`/construction paths cannot be peer public APIs. + 12. Public tests cannot use old internals as the correctness oracle. + 13. Graph API must become graph-family generic, not lattice-only. + 14. 
Interface dimensions cannot be implicitly forced through hidden size + outside one explicit normalization restriction. + 15. Named inputs/outputs must either compile generally or the public surface + must be narrowed honestly. + 16. Population cardinality cannot be constrained by old config. + 17. Population placement must be graph/declaration-owned, not anatomy-owned + random/x-band logic. + 18. Bucket identity cannot include user population names. + 19. Execution policy cannot be selected by hidden-size thresholds or magic + hardcoded tiles. + 20. Cell families cannot drive backend surface selection or runtime policy. +- Clean up `src/cortical/fabric/config.py` by splitting ownership into graph + constructors, message declarations, cell declarations, readout declarations, + initialization utilities, and planner request types. Do not add a compatibility + shim as the migration target. +- Clean up `src/cortical/fabric/anatomy.py` so lattice-only facts such as + width/height/depth/wrap/bands/projection regions live in lattice graph + constructors. Generic Fabric anatomy/backend surfaces expose flat graph facts, + boundary sets, edge/group tables, flat bucket identity, tensor tables, and op + rows. +- Shrink `runtime/core.py` by moving temporal/planner-owned logic into narrower + backend/runtime modules once the shared temporal engine is measured healthy. +- Split `src/cortical/fabric/backend/cuda/sequence_surface/temporal_backward.py` + only along semantic ownership boundaries, not arbitrary file size: + - `temporal_artifacts.py`: checkpoint policy, artifact-store data classes, + recompute-window payloads, replay payload construction, and checkpoint + lookup/window helpers. + - `temporal_reverse_tables.py`: reverse-window table construction, tensor-role + packing, active reverse-only payload validation, and generic table ABI + guards. + - Boundary/readout adjoints should be registered executor implementations + over compiler bindings, not a standalone runtime-phase module. Output + gradient materialization can live with output backward, while boundary + projection and query/parameter reductions belong to the selected reverse + executor strategy. + - `temporal_reverse_executor.py`: the CUDA reverse sequence owner, H-window + controller, reverse seed carry, reset-aware reverse propagation, and + table-owned temporal backward execution. + - `temporal_param_binding.py`: deferred transition/message/readout parameter + gradient accumulation and trainable-parameter binding. + - `temporal_autograd.py`: public autograd function and the small user-facing + entrypoint that calls forward scan plus physical temporal backward. +- Delete the host-loop reverse scan, Python replay bridge, and stale helper + functions only after the corresponding CUDA-owned semantic file has parity, + reset coverage, owner metadata, and representative throughput evidence. A + split that preserves the same host owner is not cleanup closure. + +Exit: + +- Repository greps and audit metadata show no legacy execution/config truth + path remains. +- Public API and backend/anatomy surfaces are graph-generic. +- All additional-goals issues are closed by code and tests, or remain explicit + open blockers that prevent final REDO2 closure. + +### R2.11 - Final Closure Report + +Status: OPEN. + +Owner: audit/docs. + +Goals: + +- Link all closure artifacts. +- State pass/fail thresholds and measured results. +- List deleted legacy paths. +- State exactly how every Fabric flow enters the shared temporal engine. 
+- State exactly how every supported declaration compiles into IR, primitive + rows, tensor-table roles, parameter bindings, primitive executors, and the + shared temporal runtime. + +Exit: + +- REDO2 can be called complete only when compiler closure, parity, throughput, + memory, shared ownership, T/H/K, hidden-size/factorization, mixed-pop, and + legacy-deletion gates all pass. + +## Next Action + +The active owner is the Fabric compiler path, not a cleanup side quest and not +an optimization of the current hardcoded temporal route. Work R2.1-R2.3 first: + +1. Finish the compiler front door so declarations reach Fabric IR without + silent substitution. +2. Lower Fabric IR into one primitive program/tensor table for message, cell, + readout, boundary, and parameter-gradient surfaces. +3. Add real primitive executor dispatch and fail closed for rows without + executors. +4. Migrate forward and backward temporal execution to schedule those primitive + rows instead of hidden gated/diagonal/QKV/readout formulas. +5. Run representative high-level T=1 guardrails throughout this work, including + April21-shaped 100M/500M/1B rows, reset/state axes, and mixed-pop rows. +6. Only after T=1 is compiler-correct and healthy, run the required T/H and + T*K guardrails, then full closure matrices. + +R15 cleanup continues inside every compiler/runtime slice: when a compiler path +replaces a legacy route, delete or fail-close the old route immediately. Do not +restore replay, host loops, compatibility shims, or benchmark-owned execution +to make an audit row pass. + +## Active Implementation Log - 2026-04-29 R2.1-R2.3 + +Current owner: + +- `src/cortical/fabric/backend/cuda/sequence_surface/temporal_backward.py` + no longer retains the old host reverse-loop body as an executable fallback. + If the generic CUDA reverse table path cannot accept a window, backward must + fail closed and the missing table-owned path becomes the owner. +- The unvalidated CUDA `transition_message_reverse_table_window_sequence` + binding and host-side reverse batch slicing were rejected and removed. They + did not move the owner into the table-owned temporal superop and added another + route to reason about. +- `src/cortical/fabric/runtime/core.py` no longer keeps the model-level + direct-grad/checkpoint sequence wrappers as active high-level API routes. + Remaining sequence failures should be repaired inside the shared temporal + engine/planner, not by reintroducing model-level wrappers. + +Boundary review for the next slice: + +- Acceptable ABI: flat tensor-table roles, op rows, scalar metadata, and generic + transition/message primitive tables already used by the single-window reverse + table call. +- Not acceptable: cell-name selectors, population-name route switches, + benchmark-row selectors, hidden-size policy keys, or separate single/mixed + execution identities. +- Next cleanup slice: make unsupported sequence/backward rows fail closed at + the shared temporal engine boundary, then delete unreachable host wrapper + helpers and update tests that still assert `python_autograd_scan` as an + expected owner. If a row needs a reference path, it must be an explicit + reference test path, not the production high-level Fabric route. + +## Active Cleanup Log - 2026-04-29 R15 + +Cleanup target for this pass: + +1. Runtime high-level Fabric sequence APIs must enter the shared temporal + engine for `T=1`, `T>1`, `K>1`, terminal loss, and per-timestep loss. +2. Unsupported CUDA rows fail closed with planner/runtime rejection detail. 
No + production fallback to `_forward_sequence_checkpointed`, + `_reduce_sequence_outputs_checkpointed`, or `forward_cells` sequence + wrappers. +3. Temporal backward may call the generic reverse table path. If it cannot, it + fails closed instead of recording `python_host_reverse_loop`. +4. Planner/audit/test metadata must stop treating `python_autograd_scan` as a + valid closure owner. Tests may assert rejection/demotion, but not legacy + owner success. +5. After active-route deletion, split the remaining temporal backward code by + semantic owner and continue deleting helpers that are no longer reachable. + +Actions completed in this cleanup slice: + +- Deleted `src/cortical/fabric/backend/cuda/sequence_surface/replay.py`. + Full/reference replay backward and phase probe routing are no longer + production paths. +- Removed the `CORTICAL_FABRIC_BACKWARD_MODE` switch from the sequence + autograd surface. The backward entry always calls the physical/shared + backward executor. +- Deleted the old compiled single-pop flat-bucket temporal scan route in + `temporal_executor.py`. Inference now reaches the shared temporal bucket + scan instead of the separate sequence-surface dispatcher path. +- Removed Python step replay recompute from `temporal_backward.py`. Checkpointed + windows must now be reconstructed through registered temporal executor + artifacts or fail closed. +- Deleted `benchmarks/run_fabric_scaling_profile.py`. Canonical audit entry + points are under `benchmarks/fabric/` and must stay high-level API consumers. +- Renamed audit guardrails from transitional-owner language to forbidden + legacy-owner language. `python_autograd_scan`, `python_step_replay`, and + `open_cuda_temporal_superop` markers are rejection signals only. +- Renamed checkpoint-consumption metadata away from + `physical_recompute_bridge_from_cuda_checkpoints` to + `cuda_temporal_table_from_cuda_checkpoints`. +- Removed the production `backend_host_loop` owner marker from the CUDA + dispatcher and Python sequence-surface metadata. If the old direct + sequence-surface host scan entry is reached, it now raises and directs the + caller to the shared flat-bucket temporal engine. +- Removed the stale `temporal_plan_legacy_recurrence_populations` execution + record field and the runtime test expectations that kept it alive after the + planner stopped exposing a legacy recurrence selector. +- Renamed stale production/debug metadata that still carried legacy or + transitional language without representing live compatibility routes: + `disabled_sequence_surface`, `pack_cublas`, + `message_bucket_missing_explicitly_demoted`, + `identity_projection_copy`, and checkpoint-source metadata now describe the + current owner/rejection directly. +- Deleted the old CUDA dispatcher request builder and the old direct + sequence-surface host-scan entrypoint. Remaining sequence-surface callers now + enter through the compiler-owned temporal executor path; there is no live + `FabricExecutionRequest` construction path or `_run_backend_sequence_surface_once` + tombstone. +- Removed the unused runtime `_run_backend_projected_sequence_surface` wrapper + that only forwarded to the older projected source sequence-surface helper. +- Deleted the projected-source sequence-surface host chunking helper and its + time-chunk policy. Projected-source inputs must now enter through the shared + sequence surface or fail closed; there is no separate projected-boundary + streaming executor, output-consumer route, or benchmark-side chunk policy to + tune. 
+- Removed stale public runtime tests that exercised `stream_sequence_outputs` + and `reduce_sequence_outputs` as old direct CPU/PyTorch sequence routes. Those + APIs remain cleanup targets until rebuilt on the shared temporal engine; they + are not correctness or performance closure paths for REDO2. +- Renamed graph-capture cache placeholder metadata from fallback language to + `_GraphCaptureCacheEntry`. +- Renamed supported-surface `disallowed_fallbacks` to `forbidden_routes` and + replaced runtime-reference fallback strings with explicit forbidden route + names. +- Audit runner dry-run paths no longer load April 21 score JSON. April 21 + references remain the score source for real audit execution, but unit tests + now exercise temporal-owner and mixed-stack gates with synthetic dictionaries + instead of reading recovered result files. + +Validation/fallout: + +- Runtime tests that monkeypatched deleted replay methods or expected + `python_autograd_scan` owners were rewritten as new-engine-only assertions. +- `python -m py_compile` passed for the touched runtime/backend/audit/test + files. +- `uv run pytest -q tests/test_fabric_audit_runner.py + tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py -n0 + --tb=short` passed: 73 tests. +- `uv run pytest --collect-only -q tests/test_fabric_runtime.py -n0` passed: + 336 tests collected. +- After removing `backend_host_loop`, `uv run pytest -q + tests/test_fabric_execution_imports.py -n0 --tb=short` passed: 4 tests, and + runtime collection still passed. +- Stale marker scan passed with no matches in `src/cortical/fabric`, tests, or + `benchmarks/fabric` for: `disabled_legacy_sequence_surface`, + `pack_cublas_transitional`, `legacy_message_path_explicitly_demoted`, + `identity_projection_legacy_copy`, `legacy Runtime`, `legacy Python`, + `fallback=replayed_transition_checkpoints`, or `legacy_recurrence`. +- `uv run pytest -q tests/test_fabric_execution_imports.py + tests/test_fabric_backend_plan.py -n0 --tb=short` passed: 55 tests. +- `uv run pytest --collect-only -q tests/test_fabric_runtime.py -n0` passed: + 336 tests collected. +- `python -m py_compile` passed for the touched audit, runtime, sequence + surface, support, surfaces, and test files after the unreachable host-scan + deletion. +- `uv run pytest -q tests/test_fabric_audit_runner.py + tests/test_fabric_execution_imports.py tests/test_fabric_backend_plan.py -n0 + --tb=short` passed: 72 tests. +- `uv run pytest --collect-only -q tests/test_fabric_runtime.py -n0` passed: + 336 tests collected after this continuation. +- After deleting the projected-source helper, `CUDA_VISIBLE_DEVICES= + python -m py_compile` passed for + `src/cortical/fabric/backend/cuda/sequence_surface/surface.py`, + `src/cortical/fabric/backend/cuda/sequence_surface/policy.py`, and + `tests/test_fabric_runtime.py`. +- After deleting the stale direct stream/reduce tests, `CUDA_VISIBLE_DEVICES= + uv run pytest --collect-only -q tests/test_fabric_runtime.py -n0` passed: + 322 tests collected. +- `CUDA_VISIBLE_DEVICES= uv run pytest -q tests/test_fabric_audit_runner.py + tests/test_fabric_execution_imports.py tests/test_fabric_backend_plan.py -n0 + --tb=short` passed: 72 tests after this projected-source deletion. +- A broader runtime `-x` pytest attempt was killed and is not closure evidence: + it entered CUDA extension compilation through the default global cache instead + of the REDO2 private-cache/GPU-0-4 experiment discipline. 
+- Scan note: remaining broad matches for legacy/fallback/transitional terms are + guardrail assertions/test names, explicit forbidden owner markers + (`python_autograd_scan`, `python_step_replay`), HuggingFace stack cache API + names in `benchmarks/fabric/suite_common.py`, projected-message primitive + names, and fail-closed direct sequence-route error strings. They are not R15 + closure; config/anatomy, direct sequence API cleanup, and temporal-backward + semantic split remain open scan targets. +- Any audit/runtime row that fails after this deletion is a real shared-engine + implementation gap, not permission to restore replay, shims, or host loops. + +## Active Backend Log - 2026-04-29 R4/R13 + +Decision: + +- R15 remains open, but active execution cleanup has removed enough stale + replay/direct sequence pollution to move back to the throughput-critical + backend owners. +- The next owner is R4/R13: table-owned CUDA temporal reverse scan for + training, followed by current-code T=1 and T/H guardrail audits. +- Do not spend the next iteration on `config.py`/`anatomy.py` ownership cleanup + unless a backend path directly requires it. + +Boundary review for the next backend slice: + +- Acceptable ABI: flat bucket identity, tensor-table/op-table roles, + primitive-table bucket indices, reset windows, output materialization policy, + checkpoint/recompute windows, and scalar schedule metadata. +- Not acceptable: cell-family route switches, population-name route switches, + benchmark-row selectors, hidden-size policy keys, lattice/config fields, + or separate single-pop/mixed-pop temporal routes. +- If a CUDA reverse path only works for a specific pair of primitive families, + the fix is to make missing primitive buckets zero-count table rows through the + generic table ABI or add a generic primitive-table row, not add a cell-specific + branch. + +Immediate guardrails: + +- Run a matched high-level T=1 training guardrail before attributing any T/H + failure. Use private `TORCH_EXTENSIONS_DIR` and `TRITON_CACHE_DIR`; restrict + CUDA visibility to GPUs 0-4. +- Run a representative T/H guardrail after the T=1 owner is known: + `T=512,K=1,H=64` and then `T=4096,K=1,H=64` on the same matched row, terminal + and per-timestep where feasible. +- Guardrail rows are steering evidence only. Closure still requires the full + T=1 matrix and inherited T/H/K matrix. + +Open owner expectation: + +- Current stale evidence showed `temporal_artifact_recompute`, + `artifact.recompute.cuda_temporal_replay_scan`, and + `python_autograd_scan`/`python_host_reverse_loop`-style ownership on the + T=1 training row. After R15 deletion, rerun current-code evidence before + changing kernels. + +Current-code guardrail evidence: + +- `sLSTM 500M, B=1024, h=32, T=1, K=1, terminal training`: + `/tmp/redo2_r4_t1_slstm500m_b1024_h32_terminal/cases.jsonl`. + Status: OOM before owner metadata. The row attempted another 23 GiB with + about 118 GiB already in use. This is a T=1 memory regression signal, not a + T/H diagnosis. +- `sLSTM 100M, B=1024, h=32, T=1, K=1, terminal training`: + `/tmp/redo2_r4_t1_slstm100m_b1024_h32_terminal_timing/cases.jsonl`. + Status: ran, but failed closure. Result: `351.694 tok/s`, peak + `27.992 GiB`, current-code elapsed `2911.624 ms`. + April 21 summary floor for `h32_t1_bxparams`: `58,732.71 tok/s`, + `2.07 GiB`. 
+- Active metadata on the 100M/B1024 row: + forward owner `cuda_temporal_superop`; backward plan owner + `cuda_temporal_superop`; runtime scan owner `cuda_temporal_superop`; backward + executor `physical_temporal_bucket_sequence_backward`; reverse-scan workspace + alias `flat_bucket_temporal_reverse_scan_owner:cuda_temporal_superop`. +- Dominant measured backward owners: + `transition_message_reverse_table_device_loop=1695.380 ms`, + `temporal_artifact_recompute=587.759 ms`, + `artifact.recompute.cuda_temporal_replay_scan=580.494 ms`. + Next code must reduce the CUDA reverse table/recompute owner; metadata + relabeling is not progress. +- Audit reporter fix in progress: emit `launch_temporal_scan_owners` from + `benchmarks/fabric/suite_common.py` and allow the generic + `physical_temporal_bucket_sequence_backward` executor through the owner gate + so the row fails on forbidden timing owners instead of a reporter omission. +- Corrected gate rerun: + `/tmp/redo2_r4_t1_slstm100m_b1024_h32_terminal_gatefix/cases.jsonl`. + Result: `352.883 tok/s`, peak `27.992 GiB`, `2901.808 ms`. Gate failure is + now `forbidden_backward_temporal_owner_timing_present`, with forbidden owners + `transition_message_reverse_table_device_loop`, `temporal_artifact_recompute`, + `artifact.recompute.cuda_temporal_replay_scan`, and + `artifact.recompute.cuda_replay_input_projection`. + +## Planner Output-Request Refactor - 2026-04-29 + +Correction recorded: + +- `terminal training` is not a backend mode. It is a compact output/loss + request: emit only the final outer timestep, then ordinary PyTorch autograd + may seed loss from that returned tensor. Per-timestep loss is the same path + with all requested outer timesteps emitted. Future explicit timestep loss + should be represented as explicit requested outer steps, not as another + backend training path. +- H is a TBPTT/reverse-window property for the loss/output request over the + available physical stream. For `T*K <= H`, the planner must clip to the + available stream and must not create an extra rolling scan just because H was + requested. +- Checkpoint/recompute policy must be planner-owned and based on the output + request, final-state materialization, gradient surfaces, H, `T*K`, and memory + policy. Local shortcuts inside temporal backward are rejected. + +Rejected probe: + +- A local T=1 direct-artifact shortcut was tried and backed out. Artifact: + `/tmp/redo2_r4_t1_slstm100m_b1024_h32_direct_surface_retry/cases.jsonl`. + It failed with reverse-engine rejection + `non_backend_order_public_carry` and fell back to stored step artifacts. This + is evidence that forward-produced artifact consumption must be planned through + the shared tensor-table/flat-bucket contract, not a row-local shortcut. + +Landed planner slice: + +- Removed the legacy planner `TemporalEmissionPlan` and the planner-side + `boundary.output_boundary`. +- Added `TemporalOutputRequestPlan`, which records: + output selector (`all_outer_steps`, `terminal_outer_step`, future + `explicit_outer_steps`), compact outer-step schedule, compact physical-step + schedule, output surface, readout surface, final-state materialization, + autograd seed kind, required generic backward surfaces, and checkpoint policy + basis. +- Runtime and benchmark metadata now report `temporal_plan_output_*`, + `temporal_plan_autograd_seed_kinds`, + `temporal_plan_required_backward_surfaces`, and + `temporal_plan_checkpoint_policy_basis` instead of planner + `temporal_plan_emission_*`. 
+- Verification for this slice: + `CUDA_VISIBLE_DEVICES= python -m py_compile ...` passed for touched planner, + backend metadata, benchmark, and test files. + `CUDA_VISIBLE_DEVICES= uv run pytest -q tests/test_fabric_backend_plan.py + tests/test_fabric_audit_runner.py -n0 --tb=short` passed: 69 tests. + `CUDA_VISIBLE_DEVICES= uv run pytest -q tests/test_fabric_execution_imports.py + -n0 --tb=short` passed: 4 tests. + Focused runtime planner metadata test passed. + +Open backend owner after this slice: + +- R4/R13 remain open. The next useful code must make temporal backward consume + the planner output request when constructing reverse tables and checkpoints: + no local terminal/sequence training branch, no cell-specific route, no + population-name branch, and no hidden-size policy key. +- Current measured blocker remains the table-owned CUDA reverse scan and + recompute owner on the high-level T=1 training guardrail. Planner metadata is + only useful if it moves that owner next. + +## State-Path Correction And Profiling Plan - 2026-04-29 + +Correction: + +- A fresh-zero/state-provided split is rejected. Whether the caller passes no + state, a detached state, or a differentiable state must not select a different + temporal backward route, CUDA scalar path, or recurrent-sender-gradient + shortcut. The only semantic differences are tensor values and whether + user-provided differentiable state receives exposed state gradients. +- The attempted recurrent-sender-gradient elision probe is rejected and removed. + It also did not move the measured owner before removal: + `/tmp/redo2_r21_sender_elision_t1_slstm100m_b1024_h32_sequence_timing2/cases.jsonl` + stayed around `428.959 tok/s`, `2387.175 ms`, peak `27.992 GiB`, with + `transition_message_reverse_table_device_loop=1692.787 ms`. + +Immediate profiling requirement: + +- Before the next structural CUDA edit, run deeper current-code profiles on the + same high-level API row. Required artifacts: + owner-timed audit JSONL, warmed repeated timing, PyTorch profiler or CUDA + event timeline where available, and a focused CUDA-kernel timing breakdown for + `transition_message_reverse_table_device_loop`. +- The profiler must compare state modes as an axis over the same route: + backend-created state, detached provided state, and differentiable provided + state. If metadata or timing shows a route split, reopen R2.1/R2.2 before + interpreting throughput. +- The first target remains the matched T=1 training guardrail on an April + 21-shaped row (`B=1024`, `h=32`, same graph/params/loss contract), then the + first T/H guardrail only after T=1 is healthy enough to attribute the long-T + owner scientifically. + +Current profiling evidence: + +- Focused CUDA guards passed: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/redo2_ext_generic_state_guard2 + TRITON_CACHE_DIR=/tmp/redo2_triton_generic_state_guard2 uv run pytest -q + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts + tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients + -n0 --tb=short`, `2 passed in 4.58s` after cache build. These guards assert + the no-state and provided-state rows use the same generic + `cuda_temporal_superop`/flat-bucket sequence path; provided state is not a + route key. +- High-level T=1 terminal audit without owner timing: + `/tmp/redo2_deep_profile_t1_slstm100m_b1024_h32_terminal/cases.jsonl`. 
+ Result: `429.544 tok/s`, `2383.923 ms`, peak `27.993 GiB`, reference + `h32_t1_bxparams` April 21 floor `58,732.71 tok/s`, `2.07 GiB`. + Metadata shows `temporal_artifacts:store_step_artifacts` with + `source=cuda_temporal_superop_forward_reverse_tables` and + `recompute_elided=forward_reverse_table`. +- High-level T=1 terminal audit with owner timing: + `/tmp/redo2_deep_profile_t1_slstm100m_b1024_h32_terminal_timing/cases.jsonl`. + Result: `428.929 tok/s`, `2387.343 ms`, peak `27.993 GiB`. Gate fails only + because timing exposes the forbidden dominant owner: + `transition_message_reverse_table_device_loop=1703.527 ms`. + Other owners are small: `public_projection=6.715 ms`, + `message.fused_receiver_sender=1.599 ms`, `readout=0.570 ms`. +- High-level T=1 sequence/per-timestep audit with owner timing: + `/tmp/redo2_deep_profile_t1_slstm100m_b1024_h32_sequence_timing/cases.jsonl`. + Result: `429.177 tok/s`, `2385.960 ms`, peak `27.993 GiB`. + Dominant owner remains + `transition_message_reverse_table_device_loop=1692.380 ms`, so terminal + versus sequence output materialization is not the cause. +- PyTorch/CUDA profiler artifact: + `/tmp/redo2_deep_profile_t1_slstm100m_b1024_h32_sequence_torch_profiler/summary.json` + and chrome trace + `/tmp/redo2_deep_profile_t1_slstm100m_b1024_h32_sequence_torch_profiler/trace.json`. + Top CUDA events: `_TemporalBucketSequenceFunctionBackward/fabric.backward.physical_temporal_bucket_sequence` + around `1772 ms`, `transition_message_reverse_table_device_loop_kernel` + around `1690.6 ms`, forward `flat_bucket_gated_diagonal_temporal_scan_kernel` + around `599.6 ms`, recurrent K/V projection backward hidden/weight around + `39.4 ms` and `18.7 ms`. + +Current owner: + +- R2.1-R2.3 remain open. Recompute replay is no longer the T=1 blocker on this + row; the blocker moved to the table-owned CUDA reverse kernel. Under the + compiler-first stage order, that kernel cannot be optimized as a hidden + reverse formula owner. The next code must move the work behind primitive + executor dispatch and then retile/split the generic executor work as needed, + not add state-provided/no-state, cell-family, row-id, hidden-size, or + single-population shortcuts. + +April 26-style design correction: + +- A T=1-only reverse kernel split was rejected before completion. The issue is + not that the single-step case needs a separate owner; the issue is that the + shared temporal plan must own the materialization policy for the whole `T*K` + stream. +- The target is the recovered April 26 design shape: plan first, execute second. + The temporal planner records output request, autograd seed surface, finite-H + backward window, checkpoint stride, recompute window, and reverse artifact + kind. The CUDA temporal engine consumes that plan and flat tensor/op tables. + `T=1` is only the degenerate full-stream case of the same plan. +- Direct forward-owned reverse tables are allowed only when the planner says the + requested `T*K` dependency range fits the checkpoint/window contract and the + current generic table surface supports the requested reset/output/final-state + contract. Larger streams use planner-recorded checkpoint/recompute windows. + This is not a row shortcut and must not be keyed by family, hidden size, + benchmark id, or single/mixed population. 
+- Implemented first design slice: `TemporalMaterializationPlan` is now part of + `TemporalExecutionPlan`, with metadata for + `temporal_plan_reverse_artifact_kinds`, + `temporal_plan_recompute_window_steps`, and + `temporal_plan_materialization_reasons`. The old local + `_should_materialize_forward_reverse_tables` predicate in + `temporal_backward.py` was removed; forward reverse-table selection now reads + the planner materialization decision. +- This refactor does not close R2.1/R2.2 by itself. It removes a wrong owner for + the materialization decision so the next backend work can target a generic + temporal-window/table executor and then re-profile throughput. +- Verification after the planner-materialization refactor: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/redo2_ext_temporal_materialization_plan1 + TRITON_CACHE_DIR=/tmp/redo2_triton_temporal_materialization_plan1 uv run pytest -q + tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts + tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_forward_reverse_table_plan_is_not_t1_specific + tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients + -n0 --tb=short`, `3 passed in 4.26s`. +- The new T>1 guard proves the planner direct-table materialization is not a + single-step route: `T=2,K=1`, no final state, sequence/terminal last-step + parity, `temporal_plan_reverse_artifact_kinds=["forward_reverse_tables"]`, + `temporal_plan_recompute_window_steps=["2"]`, and direct-table recompute + elision. +- Current owner-timed audit after the refactor: + `/tmp/redo2_materialization_plan_t1_slstm100m_b1024_h32_sequence_timing/cases.jsonl`. + Result: `429.350 tok/s`, `2385.003 ms`, peak `27.993 GiB`; gate still fails + because `transition_message_reverse_table_device_loop=1694.805 ms`. + Planner metadata now records `temporal_plan_reverse_artifact_kinds=["forward_reverse_tables"]` + and `temporal_plan_materialization_reasons=["materialization=forward_reverse_tables;physical_steps=1;window_steps=1;checkpoint_steps=1"]`. + This is accepted as design/metadata progress only, not throughput closure. +- Next throughput owner: the generic reverse table still performs recurrent + sender K/V projection adjoints inside the cooperative transition/message + loop. That work is mathematically generic projection backward: + `grad_hidden += grad_kv @ W^T` and `grad_W += hidden^T @ grad_kv`. It should + move to a table-owned projection-adjoint phase/reduction that is valid for any + Fabric primitive row, not to a T=1, cell-family, or benchmark-specific kernel. + +## Boundary Failure And Correction - 2026-04-29 + +Rejected probes: + +- Forward probe: split the cooperative forward scan into `regular_gated_*` + phase kernels and ATen-batched gated affine calls inside + `flat_bucket_temporal_scan_kernels.cu`. +- Reverse probe: split a gated reverse-state adjoint and sender reverse phase + out of `transition_message_reverse_table_device_loop_kernel`. + +Why rejected: + +- The probes improved representative throughput, but they put declared + primitive math directly in the temporal scan/reverse engine. That violates + the Fabric boundary. The temporal engine may schedule time, `T*K`, H, + checkpoint/recompute, reset, materialization, and dependency ordering, but it + must not own gated recurrence, diagonal recurrence, dot-product message, + normalization, projection, or readout formulas. +- Tensor tables and op rows are the required ABI. 
Primitive math must live + behind `fabric.cuda.nn`/lowered primitive executors selected by op-row + metadata. The temporal owner dispatches primitive rows; it does not inline + primitive formulas. +- REDO2 cannot close with a fake `fabric.cuda.nn` internal path. A declaration + only counts if it is the source of truth for Fabric IR, primitive op rows, + tensor-table roles, parameter bindings, shape metadata, and primitive + executor selection. If `fabric.cuda.nn` is only a label while the backend + still infers semantics from hardcoded gated/diagonal/message formulas, + cell-specific parameter bundles, or temporal-kernel branches, the owner + remains open. +- Faster tok/s from these probes is invalid evidence. It cannot close R2.1, + R2.2, R2.3, or R15. + +Actions taken: + +- Restored `flat_bucket_temporal_scan_kernels.cu` to remove the invalid + `regular_gated_*`/ATen gated-affine forward path. +- Restored `flat_bucket_temporal_backward_kernels.cu` to remove the invalid + reverse split path and its copied message/transition primitive math. +- Updated `skills/cb.fabric-backend-boundaries/SKILL.md` and + `skills/cb.fabric-performance-loop/SKILL.md` with an explicit hard stop: + no temporal CUDA probe is acceptable if it inlines declared primitive math + instead of dispatching tensor-table/op-row primitive executors. + +Next valid backend design: + +- Introduce a real tensor-table primitive-executor layer for temporal forward + and reverse. The temporal CUDA/C++ owner should iterate physical steps and + execute generic op rows such as message primitive, receiver affine, recurrent + primitive, norm/readout primitive, projection adjoint, and parameter-reduction + primitive by dispatching declared primitive executors. +- The owner API remains flat bucket identity plus tensor/op tables, output + request, H/window, reset policy, checkpoint/recompute policy, and + materialization plan. +- Primitive executors may be specialized by primitive opcode lowered from + `fabric.cuda.nn`; the temporal engine must not branch on cell name, + population name, benchmark row, hidden-size policy, or hardcoded cell/message + formulas. +- Closure evidence must include code-review and runtime metadata showing the + active high-level API row used the declaration -> IR -> primitive op row -> + tensor table -> primitive executor chain. Metadata that merely says + `fabric.cuda.nn` without that chain is not closure evidence. +- R2.1-R2.3 remain open. The next implementation must complete the compiler + front door, primitive program, and primitive dispatch boundary, then move the + hot forward/backward work through that boundary and rerun T=1 guardrails. + +## Active Plan - Fabric Compiler And Primitive Executor Closure - 2026-04-29 + +Status: ACTIVE. + +Scope: + +- This plan covers the whole active Fabric cell-math path and the whole active + Fabric message-passing path. It is not limited to temporal files, sLSTM, + Axon, dot-product attention, or the latest throughput row. +- Any path that does not prove + `declaration -> Fabric IR -> primitive op row -> tensor table -> primitive executor` + is a compiler closure blocker. Metadata or wrapper names that mention + `fabric.cuda.nn` are not enough. + +1. Audit the active compiler/runtime path for facade violations. + + - Grep and review temporal forward/backward, runtime lowering, + `transition_execution.py`, cell specs, message-rule declarations, message + CUDA lowering, readout/boundary backward, and parameter-gradient binding. 
+ - Flag hardcoded primitive formulas, cell-family bundles, dot-product-only + message assumptions, metadata-only `fabric.cuda.nn`, population-name + selectors, native-cell-kind selectors, and bundled gated/diagonal temporal + arguments. + - Mark every violation as a compiler closure blocker, with file references + and the owner that must remove, replace with a primitive executor, or + fail-close it. + +2. Define the real primitive-executor boundary. + + - Add or clean the CUDA-side primitive executor layer so op rows select + primitive executors through tensor-table roles and parameter bindings. + - Temporal engine ownership is limited to physical time, `T*K`, H, + reset scope, checkpoint/recompute, output/loss materialization, + dependency order, and workspace policy. + - Primitive executors own message operators, projection operators, + recurrence operators, normalization, readout, adjoints, and parameter + reductions. + - Primitive executors may specialize by primitive opcode and tensor shape + metadata. They must not specialize by cell family, population name, + benchmark row, hidden-size policy key, or single/mixed route identity. + +3. Wire forward through tensor/op rows first. + + - Forward temporal execution should consume flat bucket identity plus + lowered primitive rows for message, receiver projection, recurrent/local + transition, normalization/public emission, readout, and output + materialization. + - Remove or fail-close forward paths that infer semantics from `slstm`, + `axoncell`, `gated`, `diagonal`, dot-product bundles, or old Config-era + message metadata instead of primitive rows. + - Keep T=1 as the base streaming contract, but do not add T=1-specific + logic. T=1, T>1, and T*K must be the same temporal substrate with + different materialization/checkpoint policy. + +4. Wire backward through the same primitive row system. + + - Build reverse primitive executors for message adjoints, projection + adjoints, transition adjoints, normalization adjoints, readout/boundary + adjoints, carry/input gradients, and parameter reductions. + - The reverse temporal owner consumes the same op/tensor tables plus the + planner's `BackwardWindowPlan`/materialization plan. It must not infer + semantics from cell names, bundled gated/diagonal assumptions, or one + dot-product message implementation. + - Remove or fail-close any reverse path that uses Python host scan loops, + replay bridges, cell-family payload lists, or hidden formulas as the + semantic source of truth. + +5. Run guardrails continuously. + + - First guardrail: T=1 high-level API training rows from the April21-shaped + B/params/h matrix, including reset/no-reset and provided-state/no-state. + - Then representative T/H rows: `T=512,K=1,H=64` and + `T=4096,K=1,H=64`, with terminal and per-timestep output/loss requests as + applicable. + - Only after T=1 is healthy, expand to the full inherited T/K/H matrix: + T sweep `1,2,4,8,...,4096,16K`, K sweep `1,2,4,8,...,128`, and H sweep + `1,2,4,8,...,64`. + - Benchmarks remain high-level API consumers only: model forward, external + loss, `loss.backward()`, optimizer step where applicable. They do not own + backend policy, checkpoint policy, detach policy, or private temporal + calls. + +Immediate next implementation: + +- Scan the current Fabric CUDA lowering/runtime files and identify the smallest + real primitive-dispatch boundary that already exists or is missing. +- Implement that boundary before touching performance kernels again. 
+- No throughput result closes R2.1-R2.3/R15 unless the active high-level row + reports and code-review proves the full declaration-to-primitive-executor + chain for both cell math and message passing math. + +## Fabric Compiler Blocker Findings - 2026-04-29 + +Status: ACTIVE COMPILER BLOCKERS FOR R2.1-R2.3. + +Audited scope: + +- Public declaration/config path: + `src/cortical/fabric/blueprint.py`, + `src/cortical/fabric/config.py`, + `src/cortical/fabric/message_rules/declarations.py`. +- IR/planner/runtime path: + `src/cortical/fabric/backend/ir.py`, + `src/cortical/fabric/backend/cell_backend.py`, + `src/cortical/fabric/backend/cell_specs.py`, + `src/cortical/fabric/backend/planner.py`, + `src/cortical/fabric/runtime/core.py`. +- CUDA execution path audited at the time: + `src/cortical/fabric/backend/cuda/nn/ir.cuh`, + `src/cortical/fabric/backend/cuda/message_passing/*`, + `src/cortical/fabric/backend/cuda/ops/*`, + `src/cortical/fabric/backend/cuda/transition_execution.py`, + `src/cortical/fabric/backend/cuda/sequence_surface/compiler/tables.py`, + `src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py`, + `src/cortical/fabric/backend/cuda/sequence_surface/temporal_backward.py`, + `src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_temporal_scan_kernels.cu`, + `src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_temporal_backward_kernels.cu`. + +Findings: + +1. Public message declarations are not the active generic source of truth. + + - `message_rules/declarations.py` exposes `DotProduct` only and validates + receiver slot, sender public, one head, and projected-message output as a + fixed v1 surface. + - Follow-up cleanup now routes `DotProduct` through the registered + `MessageRuleBackendSpec` builder path. The remaining compiler issue is + narrow public rule coverage and narrow executor support, not a separate + CUDA-side dot-product declaration. + - `backend/message_rules.py` now requires registered rule metadata for a + supported lowering. Raw structural message patterns fail closed unless a + backend spec supplies lowering kind and primitive bindings. + - `backend/cuda/nn/ir.cuh:1051-1120` repeats the same exact message-rule + classifier and rejects all non-dot-product lowerings. + - Compiler blocker: message rules need real declaration objects that lower + into message primitive rows and tensor-table roles. Dot-product can be one + primitive composition, but it cannot be the hidden message backend policy. + +2. Blueprint/config normalization still routes through old generic-looking + Config fields. + + - `blueprint.py:95-162` normalizes Blueprint into `Config` with + lattice/dimension/message fields such as width/height, local radius, + wrap, head_dim, `num_heads=1`, and projection-region shape. + - `config.py:8-62` still owns cell-family literals, head fields, lattice + graph facts, message/KV grouping, readout policy, backend request, K, H, + and initialization in one object. + - Compiler blocker: ownership must be split so graph constructors own graph + facts, message declarations own message math, cells own local transition + declarations, readout owns readout declarations, and execution specs own + planner requests. No compatibility shim is a closure target. + +3. Cell declarations and backend specs are metadata bundles, not full + `fabric.cuda.nn` lowering. + + - `cells/declarations.py:12-31` maps user cells to + `CellPopulationConfig(cell_type=...)`. 
+ - `backend/cell_backend.py:17-71` has useful Python `TransitionOp` and + `CellTransitionIR` structs, but they are not yet lowered into executable + CUDA primitive rows as the active path. + - `backend/cell_specs.py:16-101` and `backend/cell_specs.py:104-281` + register sLSTM/Axon specs as named transition bundles, parameter names, + surface keys, and implemented variants. + - `cells/slstm.py:24-69` and `cells/axon.py:38-99` register cell specs by + `cell_type` and materialize hand-shaped parameter bundles. + - Compiler blocker: cells may declare primitive math, state schema, + parameter schema, and local recurrence meaning. They must not be the + selector for temporal execution, message semantics, projection semantics, + or CUDA route identity. + +4. Runtime and planner still select behavior from cell names, population names, + and fixed message lowering labels. + + - `runtime/core.py:96-118` stores `_population_cell_types`; `core.py:468-483` + exposes surface selection by cell type or population name. + - `runtime/core.py:1663-1667` records `"mixed"` as a separate cell-type-like + identity in PyTorch execution metadata. + - `runtime/core.py:1699-1795` materializes message/projection/fused + recurrent inputs directly from module fields and static tensor names. + - `runtime/core.py:1817-1845` branches on `cell_spec.public_schema.kind` + (`hidden` versus `preproj`) to choose backend cell tensor bundles. + - `runtime/core.py:3298-3419` routes recurrent message and transition + execution through backend-order transition buckets or population update + helpers rather than a generic primitive-row executor. + - `backend/ir.py:116-119` always installs + `default_dot_product_message_rule_summary(...)`. + - `backend/ir.py:237-260` rebuilds backend transition signatures from + `cell_type`. + - `backend/planner.py:656-714` constructs physical backward plans by bucket + population and transition signatures; `planner.py:1600-1975` has a useful + primitive backward registry, but active executors still do not dispatch + primitive rows from it. + - Compiler blocker: planner/runtime should build a single primitive + execution plan from declared graph/message/cell/readout IR. Runtime should + execute the plan, not reconstruct semantic routing from names or + schema-kind shortcuts. + +5. CUDA message execution is a hardcoded dot-product message implementation, + not a generic message primitive row executor. + + - Deleted `backend/cuda/message_rules/dot_product.cuh` because it was a + second CUDA-side declaration source for the dot-product rule. Message + semantics must come from `fabric/message_rules/declarations.py` and + registered `backend/message_rule_specs.py` / `MessageRuleBackendSpec` + lowering, parallel to cells coming from `backend/cell_specs.py` / + `CellBackendSpec` / `CellTransitionIR`. + - `backend/cuda/message_passing/local_message_kernels.cu:52-137` and + `local_message_kernels.cu:139-244` implement dot-product logits, + distance penalty, softmax, and weighted sum directly. + - `backend/cuda/message_passing/sparse_message_kernels.cu:53-131` and + `sparse_message_kernels.cu:133-251` implement the sparse dot-product + forward/backward formula directly. + - `backend/cuda/ops/dense_message_kernels.cu:136-218` and onward provide a + more primitive-looking dense message layer, but it is still selected by + lowered dot-product/message-bucket assumptions, not by a general + message-rule op-row program. 
+ - Compiler blocker: message forward/backward must lower declared message ops + into message primitive rows. The temporal engine may schedule a message row + but must not treat dot-product attention as Fabric's message semantics. + +6. The active temporal forward path has tensor-table names but still dispatches + hardcoded gated/diagonal formulas. + + - `temporal_tables.py:23-55` defines `TemporalTensorTableSlot`, + `TemporalPrimitiveRow`, and `TemporalPrimitiveTablePlan`; this is the + closest existing boundary to keep. + - `temporal_tables.py:90-140` builds rows from backend buckets, but + `temporal_tables.py:143-165` immediately narrows scan support to at most + one `gated_logspace_recurrence` bucket and one `diag_rtu` bucket. + - `flat_bucket_temporal_scan_kernels.cu:25-151` and + `flat_bucket_temporal_scan_kernels.cu:426-552` expose a + `flat_bucket_gated_diagonal_temporal_scan_kernel` ABI with explicit gated + and diagonal argument lists. + - `flat_bucket_temporal_scan_kernels.cu:656-729` computes gated projection, + recurrent affine, logspace recurrence, and outnorm directly. + - `flat_bucket_temporal_scan_kernels.cu:776-848` computes diagonal/Axon + projection, recurrence, traces, activation, and output projection + directly. + - `flat_bucket_temporal_scan_kernels.cu:1842-1978` maps tensor roles to + primitive-specific names, then still calls the hardcoded gated/diagonal + scan kernel. + - Compiler blocker: keep the table concept, but replace the gated/diagonal + scan ABI with a temporal primitive scheduler over generic op rows and + tensor slots. + +7. The active temporal backward path has table wrappers but still owns + primitive formulas and cell-family parameter bundles. + + - `transition_execution.py:406-422`, `transition_execution.py:656-672`, + and `transition_execution.py:2868-2898` select sLSTM/Axon-style execution + by checking exact op-name sequences such as `gated_logspace_recurrence` + and `diag_rtu`. + - `transition_execution.py:2715-2865` computes gated input projection, + recurrent affine, recurrence/outnorm, public K/V projection, and tape + construction directly. + - `transition_execution.py:1320-1363` and + `transition_execution.py:2327-2661` implement recurrent input projection + tapes and adjoints using named static tensors such as + `value_to_cell_weight`. + - `temporal_backward.py:562-618` rebuilds transition tapes by explicitly + naming gated and diagonal buckets. + - `temporal_backward.py:5987-6588` builds a reverse engine payload around + `gated_*` and `diagonal_*` tensors, message dot-product tables, head/value + dimensions, and parameter names, then calls + `try_transition_message_reverse_table_window_cuda`. + - `temporal_backward.py:6714-6775` maps reverse results back to hardcoded + parameter names (`gate_weight`, `recurrent_kernel`, `nu_log`, `theta_log`, + `value_to_cell_weight`, etc.). + - `flat_bucket_temporal_backward_kernels.cu:397-555` and + `flat_bucket_temporal_backward_kernels.cu:555-760` implement gated + logspace recurrence, outnorm, recurrent-affine, and input-projection + adjoints directly. + - `flat_bucket_temporal_backward_kernels.cu:258-397` implements diagonal + recurrence backward directly. + - `flat_bucket_temporal_backward_kernels.cu:3059-3795` implements the + transition/message reverse table loop as a fixed gated+diagonal+regular + local message program, then performs hardcoded parameter reductions. 
+ - Compiler blocker: reverse must dispatch reverse primitive rows for + message, projection, recurrence, normalization, readout/boundary, + carry/input gradients, and parameter reductions. The temporal reverse + owner should only own H/window, scan direction, reset/mask policy, + materialization, dependency order, and workspace. + +8. `fabric.cuda.nn` exists but is not yet the live end-to-end source of truth. + + - `backend/cuda/nn/ir.cuh:1239-1269` defines `CellTransitionIR`, + `FabricStepIR`, and `LoweredPhaseIR`. + - `backend/cuda/nn/ir.cuh:1432-1628` has a C++ Builder for primitive-ish + cell declarations, state affines, and diagonal recurrence. + - `backend/cuda/nn/ir.cuh:1674-1726` lowers affine/message/diagonal pieces + into phase IR. + - The active Python runtime and temporal CUDA kernels do not consume this as + the executable source of truth; they consume Python `CellBackendSpec` + bundles, static tensor dictionaries, and primitive-specific temporal + arguments. + - Compiler blocker: either bridge Python declarations into this IR or move + the source of truth to a Python-side equivalent, but the active path must + be auditable from declaration to primitive executor. + +Existing useful boundary to keep: + +- `TemporalPrimitiveTablePlan` in `temporal_tables.py` is the smallest + meaningful existing boundary. It already records bucket ordinal, receiver + range, primitive name/family, inputs, outputs, attributes, and flat bucket + identity. +- The missing layer is an executable primitive-dispatch plan: + `TemporalPrimitiveTablePlan -> executable op rows -> tensor roles/parameter + bindings -> primitive executor registry -> forward/reverse executor calls`. +- Today, op rows mostly validate or label hardcoded routes. The next code patch + should introduce the primitive executor registry/plan object and make the + temporal path fail closed unless the active primitive rows have registered + forward and backward executors. After that, migrate one primitive family at a + time behind the registry, starting with projection/message primitives that + are already relatively standalone in `backend/cuda/ops`. + +Immediate implementation slice: + +1. Add a CUDA temporal primitive execution contract in Python first, close to + `temporal_tables.py`, with: + - primitive name/opcode, + - forward executor key, + - backward executor key, + - required tensor roles, + - parameter binding names, + - reset/tape policy, + - whether the current active executor is real or fail-closed. +2. Populate it from `TemporalPrimitiveTablePlan` for message, linear/matmul, + recurrence, norm/readout, and parameter-reduction rows. +3. Add guardrails so a temporal scan/reverse cannot report + `cuda_temporal_superop` closure when any active row is missing a real + primitive executor or is routed through `gated_diagonal_temporal_scan` or + `transition_message_reverse_table_device_loop` as semantic owner. +4. Then migrate forward projection/message rows through existing + `backend/cuda/ops` executors. Do not touch performance kernels again until + this primitive-dispatch boundary exists. + +Code log - 2026-04-29 primitive-dispatch guardrail slice: + +- Added `temporal_primitive_dispatch.py` as a fail-closed contract layer over + `TemporalPrimitiveTablePlan`. It does not claim closure. It classifies the + current transition rows as missing/legacy primitive executor ownership and + emits explicit blockers for missing message rows, readout/boundary rows, and + parameter-reduction rows. 
+- `BackendExecutionRecord` now has + `temporal_primitive_executor_contracts` and + `temporal_primitive_executor_blockers`, and + `record_temporal_bucket_step_loop_execution(...)` records them in both + structured fields and workspace aliases. This makes current-code audits show + fake genericity instead of hiding it behind `cuda_temporal_superop` labels. +- Added + `test_temporal_primitive_executor_plan_fails_closed_for_missing_generic_dispatch` + to ensure the guardrail stays generic and does not use sLSTM/Axon text as + executor identity. +- The fabric audit runner now exports + `temporal_primitive_executor_contracts`/`blockers` and fails the CUDA + temporal-owner gate when blockers are present. A `cuda_temporal_superop` + label alone no longer passes ownership closure if primitive math is still + hidden in legacy/hardcoded owners. +- Updated `skills/cb.fabric-backend-boundaries/SKILL.md` with the cleanup + lesson: shared-engine edits must stay small and must delete, fail-close, or + expose stale legacy routes immediately instead of adding parallel + abstractions. +- Still open: the guardrail is not a primitive executor implementation. R4/R13 + stay open until message, projection, recurrence, normalization, readout, + boundary backward, and parameter reductions execute through real primitive + rows and the old hardcoded temporal/message routes are deleted. + +Cleanup log - 2026-04-29 benchmark legacy deletion: + +- Deleted stale top-level Fabric audit scripts: + `benchmarks/run_fabric_bxt_scaling_audit.py`, + `benchmarks/run_fabric_factorization_invariance_audit.py`, and + `benchmarks/run_fabric_mixed_population_audit.py`. The canonical audit + entrypoint is `python -m benchmarks.fabric.run_audit`. +- Deleted the top-level `benchmarks/fabric_suite_common.py` re-export shim and + updated benchmark imports to use `benchmarks.fabric.suite_common` directly. +- Moved the shadowed top-level Fabric benchmark module into + `benchmarks/fabric/benchmark.py` and taught `benchmarks/run.py` to discover + package-local `benchmark.py` modules. This removes the file/package name + collision around `benchmarks.fabric`. +- This is R15 cleanup only. It reduces benchmark/audit route confusion but does + not close the backend primitive executor blockers above. + +Code log - 2026-04-29 message primitive rows: + +- `FabricIR.message_rule` now keeps the message `MessageRuleIR`, not only the + summary. The summary fields remain available through properties so existing + planner/runtime metadata does not need a side channel. +- `TemporalPrimitiveTablePlan` now appends executable message primitive rows + from the message IR nodes. Dot-product lowering becomes generic message + op rows such as `linear`, `attention_logits`, `add`, `segment_softmax`, and + `weighted_sum`, with `surface=message` flat-bucket identity. +- The primitive executor guardrail no longer reports + `message_primitive_rows_missing_from_temporal_table` when those rows are + present. It now reports + `message_primitive_not_dispatched_by_temporal_primitive_row`, which is the + correct remaining blocker: rows exist, real primitive executors are not yet + wired into the temporal engine. +- Still open: message math is still executed by the old message runtime path. + The next backend implementation must bind message rows to primitive + executors and delete the hidden dot-product runtime route as closure + evidence. +- Current-code CUDA smoke: + `/tmp/redo2_message_rows_smoke/cases.jsonl` failed the owner gate as + intended. 
The active table now reports `12` primitive rows and lists message + primitives (`linear`, `attention_logits`, `add`, `segment_softmax`, + `weighted_sum`) instead of the old synthetic missing-message-row blocker. + Throughput was not a closure row (`B=1`, one iteration, compile/cache probe) + and is recorded only as guardrail evidence. + +Code log - 2026-04-29 readout/boundary primitive rows: + +- `TemporalPrimitiveTablePlan` now appends `readout_project` and + `reduction_boundary` rows for the output/readout surface. This removes the + synthetic `readout_boundary_rows_missing_from_temporal_table` blocker when + readout nodes exist. +- The guardrail now reports + `readout_boundary_not_dispatched_by_temporal_primitive_row` for those rows. + That is the real remaining owner: readout/boundary math still runs through + the old runtime path instead of a temporal primitive executor. +- Current-code CUDA smoke: + `/tmp/redo2_readout_rows_smoke/cases.jsonl` failed the owner gate as + intended with `14` primitive rows. The old missing-readout-row blocker is + gone; the remaining readout blockers are primitive executor dispatch + blockers. + +Code log - 2026-04-29 parameter reduction primitive rows: + +- `TemporalPrimitiveTablePlan` now appends parameter-gradient + `reduction_boundary` rows for each transition parameter in each flat bucket. + Parameter gradient ownership is now visible in the same table as message, + readout, and transition primitives. +- The guardrail now reports + `parameter_reduction_not_dispatched_by_temporal_primitive_row` instead of + `parameter_reduction_rows_missing_from_temporal_table` when parameter rows + exist. This removes the last synthetic missing-surface blocker from the + ordinary active table; remaining blockers are executor-dispatch blockers. +- Current-code CUDA smoke: + `/tmp/redo2_param_rows_smoke/cases.jsonl` failed the owner gate as intended + with `20` primitive rows. The remaining blockers are message/readout/parameter + dispatch blockers plus transition affine/recurrent/norm dispatch blockers. + +Code log - 2026-04-29 active scan consumes real primitive rows: + +- Added `temporal_primitive_rows_tensor(...)` to encode the + `TemporalPrimitiveTablePlan` rows for the CUDA scan ABI. The active forward + scan and replay scan now pass this encoded table into + `try_gated_diagonal_temporal_scan_cuda(...)` instead of letting the Python + binding fabricate a private two-row gated/diagonal table. +- The CUDA table-scan kernel still only executes the gated/diagonal recurrence + opcodes. This does not close R4/R13. It removes one stale binding shortcut + and makes the next blocker precise: implement real row-dispatch primitive + executors for the rows already in the table, then delete the hardcoded + gated/diagonal/message/readout formula owners. +- Current-code CUDA smoke: + `/tmp/redo2_real_primitive_rows_smoke_v2/cases.jsonl` failed the owner gate + as intended with `flat_bucket_temporal_scan_primitive_rows:fabric_ir_temporal_table` + and `flat_bucket_temporal_table_primitive_rows:20`. The row source is now the + Fabric IR temporal table; the remaining failure is executor dispatch blockers. +- Follow-up cleanup removed the Python binding fallback that fabricated a + private two-row gated/diagonal primitive table when no rows were provided. + `try_gated_diagonal_temporal_scan_cuda(...)` now requires caller-provided + primitive rows, so a temporal scan cannot silently bypass the + declaration/IR/table source. 
This is still not R4/R13 closure; it is a + fail-closed boundary cleanup before real primitive-executor dispatch. +- Current-code CUDA smoke after the fallback deletion: + `/tmp/redo2_required_rows_smoke/cases.jsonl` executed with + `flat_bucket_temporal_scan_primitive_rows:fabric_ir_temporal_table` and + `flat_bucket_temporal_table_primitive_rows:20`, then failed the owner gate for + `temporal_primitive_executor_blockers_present`. The row had no runtime error; + the remaining blockers are the expected missing message, readout, parameter, + affine, recurrence, and norm primitive executors. +- Follow-up cleanup deleted the direct pybind export + `gated_diagonal_temporal_scan`; the extension now exposes only + `temporal_table_scan`. Python callsites were renamed to + `try_flat_bucket_temporal_scan_cuda(...)`, and a boundary test now prevents + reintroducing the direct gated/diagonal bypass. The implementation still + delegates supported recurrence rows to the old internal recurrence executor, + so R4/R13 remain open until that internal owner is replaced by primitive-row + executors. +- Current-code CUDA smoke after the direct-export deletion: + `/tmp/redo2_table_only_smoke/cases.jsonl` compiled in a private cache and ran + without runtime error through `flat_bucket_temporal_table_extension`. The gate + failed for `temporal_primitive_executor_blockers_present`, as expected. The + low tok/s from this artifact is compile/cold-smoke noise (`warmup=0`, + `iterations=1`) and is not throughput evidence. + +Code log - 2026-04-29 reverse primitive rows: + +- Added `temporal_reverse_executor_rows_tensor(...)` so the active reverse + table path derives its op rows from `TemporalPrimitiveTablePlan` instead of + fabricating a private gated/diagonal/message row list inside the CUDA wrapper. +- `try_transition_message_reverse_table_window_cuda(...)` now requires + caller-provided reverse executor rows. This mirrors the forward scan cleanup: + reverse execution cannot silently bypass the Fabric IR temporal table when binding the + CUDA reverse table. +- The CUDA reverse table no longer requires both gated and diagonal primitive + rows unconditionally. It now requires a row for each active transition bucket + and still requires the message row. This reopens the intended one-bucket + single-population reverse path without adding single-population logic. +- Verification: `tests/test_fabric_backend_plan.py` now covers both mixed + reverse rows and one-transition-bucket reverse rows; the direct CUDA reverse + extension test + `test_fabric_cuda_transition_message_reverse_table_window_matches_mixed_step_loop` + passes in a private cache. High-level smoke + `/tmp/redo2_reverse_rows_smoke/cases.jsonl` ran without runtime error and + failed the known primitive-executor owner gate; it is a cold one-iteration + guardrail, not throughput evidence. +- Still open: the reverse CUDA executor still delegates to the old internal + transition/message reverse table program. R4/R13 remain open until message, + transition, norm, readout, and parameter adjoints dispatch real primitive + rows rather than a bundled reverse program. + +Rejected boundary review - 2026-04-29 recurrent K/V reduction split: + +- Owner being edited: `transition_message_reverse_table_device_loop`. +- Rejected probe: adding a deferred recurrent K/V adjoint-window mode to the + existing reverse table. 
Even though K/V reduction is a projection primitive, + the edit still treated Q/K/V as temporal-engine concepts and kept the active + message rule hardcoded in the bundled reverse table. +- Correct invariant: message passing and cell math follow the same + `fabric.cuda.nn` contract. User-declared message rules and cell transitions + lower into IR, primitive op rows, tensor roles, parameter bindings, and + primitive executors. The shared temporal engine sees flat bucket identity, + dependencies, reset/checkpoint/materialization policy, and opaque primitive + rows; it must not name or assume Q/K/V, gated recurrence, diagonal recurrence, + layernorm, projection formulas, or readout formulas. +- Compiler/library contract: Fabric is a compiler for declared Fabric programs + and a PyTorch-style user library. Declarations lower into IR, op/tensor rows, + parameter bindings, primitive executors, and then temporal scheduling. If a + user declares an unsupported cell/message/readout/projection/norm program, + lowering or executor selection must fail closed. REDO2 cannot close with a + broad API that silently maps arbitrary-looking `fabric.cuda.nn` programs into + one hardcoded dot-product/gated/diagonal route. +- Next backend slice must therefore implement or fail-close real message-rule + primitive dispatch from `MessageRuleIR`, not optimize the existing hardcoded + Q/K/V temporal route. + +Code log - 2026-04-29 message-rule compiler handoff: + +- Fixed the first compiler breach in Python IR lowering: `Spec` can now carry a + normalized `MessageRuleIR`, Blueprint normalization installs the user-declared + `message_passing` rule after graph normalization, and `compile_fabric_ir` + consumes `spec.message_rule` instead of always synthesizing + `default_dot_product_message_rule_ir`. +- Added guard tests proving Blueprint normalization preserves message-rule + sharing metadata and backend IR/planner use a declared `Spec.message_rule` + rather than silently substituting the default. +- This is compiler-boundary progress only. It does not close R4/R13 because the + active CUDA temporal forward/backward path still has hardcoded message tensor + roles and must next lower `MessageRuleIR` into executable primitive dispatch + instead of assuming Q/K/V dot-product semantics. + +Code log - 2026-04-29 compiled message program: + +- Closed the next R2.1/R2.2 substage for message-rule lowering: `MessageRuleIR` + now compiles into a `CompiledMessageRule` with explicit + `CompiledMessagePrimitiveOp` rows. `FabricIR` carries both the declared + `message_rule` and the compiled `message_program`. +- `compile_fabric_ir` now fails closed when a declared message rule is + unsupported or has no executable primitive rows. Unsupported message rules no + longer reach temporal table construction as empty metadata. +- `TemporalPrimitiveTablePlan` now consumes `runtime.backend_ir.message_program` + when appending message rows. It no longer remaps raw `MessageRuleIR.nodes` + inside `temporal_tables.py`, and it raises if backend IR has no compiled + message program. +- Added tests proving the default and declared message rules expose a compiled + primitive program, temporal message rows come from the compiled program, and + unsupported message rules fail before runtime execution. +- Follow-up cleanup aligned message declarations with cell declarations: + `DotProduct` now lowers through a registered `MessageRuleBackendSpec` + builder, not a hardcoded CUDA-side declaration or structural temporal + classifier. 
Adding a message rule should now be local to the public semantic + declaration, backend spec builder, primitive bindings, executor coverage, and + parity tests. +- Verification: + `uv run pytest -q tests/test_fabric_backend_plan.py::test_default_message_rule_contract_is_planner_visible tests/test_fabric_backend_plan.py::test_backend_ir_uses_declared_spec_message_rule_not_default_substitution tests/test_fabric_backend_plan.py::test_temporal_message_rows_come_from_compiled_message_program tests/test_fabric_backend_plan.py::test_message_rule_compiler_rejects_unsupported_rule_before_runtime_execution tests/test_fabric_backend_plan.py::test_temporal_primitive_executor_plan_fails_closed_for_missing_generic_dispatch` + passed: 5 tests. +- Still open: R2.3 primitive executor dispatch. The compiled message program is + now the source of temporal message rows, but those rows still report + `message_primitive_not_dispatched_by_temporal_primitive_row` until real + message primitive forward/backward executors replace the old dot-product + runtime path. + +Code log - 2026-04-29 compiled transition program: + +- Closed the matching R2.1/R2.2 substage for cell-transition lowering: + `CellTransitionIR` now compiles into `CompiledTransitionProgram` with + explicit `CompiledTransitionPrimitiveOp` rows, parameter inputs, schema roles, + and binding-slot ownership. `FabricIR` carries one compiled transition + program per population binding slot. +- `TemporalPrimitiveTablePlan` now consumes + `runtime.backend_ir.transition_program_for_binding_slot(...)` for transition + tensor slots, transition primitive rows, and parameter-reduction rows. It no + longer reads raw transition ops or parameter schema from + `runtime._backend_population_specs` while building the temporal primitive + table. +- Unsupported transition ops fail closed during transition-program compilation + with an explicit "add the op to fabric.cuda.nn lowering" error. They do not + reach temporal execution as metadata-only rows. +- Added tests proving backend IR carries compiled transition programs, + temporal transition rows come from those programs, unsupported transition ops + fail closed, and the primitive-executor plan still reports the real remaining + blocker rather than missing rows. +- Verification: + `uv run pytest -q tests/test_fabric_backend_plan.py::test_fabric_backend_ir_compiles_receiver_sets_and_buckets tests/test_fabric_backend_plan.py::test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names tests/test_fabric_backend_plan.py::test_temporal_transition_rows_come_from_compiled_transition_programs tests/test_fabric_backend_plan.py::test_transition_program_compiler_rejects_unsupported_transition_op tests/test_fabric_backend_plan.py::test_temporal_primitive_executor_plan_fails_closed_for_missing_generic_dispatch` + passed: 5 tests. +- Still open: R2.3 primitive executor dispatch. The transition program is now + the source of transition rows, but recurrence, affine/projection, norm, + message, readout, boundary, and parameter-gradient math still need real + primitive executors before R4/R13/R2.3 can close and before legacy kernels can + be deleted. + +Code log - 2026-04-29 readout program and primitive registry: + +- Added `backend/primitives.py` as the shared `fabric.cuda.nn` primitive + registry used by both planner backward-behavior metadata and transition + compilation. 
The transition compiler no longer has a small private admission + list; it accepts any registered callable primitive such as `tanh`, `sigmoid`, + `add`, `exp`, reductions, message segment ops, or the current composite + recurrence primitives, then executor dispatch decides whether the row is + implemented or blocked. +- Important semantic note: accepting fine-grained primitives in the compiler is + not enough. The existing sLSTM/Axon declarations still contain composite + primitive rows (`gated_logspace_recurrence`, `diag_rtu`) whose internal + elementwise math is hidden inside legacy executors. R2.3 remains open until + either those are explicit supported fused primitive executors selected from + rows, or the cell declarations decompose into fine-grained primitive rows that + a fusion pass groups back into fast CUDA blocks. +- Added a compiled readout program: `ReadoutRuleIR` lowers into + `CompiledReadoutRule` with `readout_project` and `reduction_boundary` rows. + `FabricIR` carries the readout rule/program, and + `TemporalPrimitiveTablePlan` now consumes the compiled readout program rather + than fabricating rows from `runtime.config.readout_pool`. +- Documented the required fusion model in R2.3: fine-grained IR preserves + user-definable math, while executor selection may fuse adjacent primitive rows + into message/transition/readout/reverse CUDA blocks for throughput as long as + the fused block is selected by compiled rows and reports row coverage. +- Verification: + `uv run pytest -q tests/test_fabric_backend_plan.py::test_fabric_backend_ir_compiles_receiver_sets_and_buckets tests/test_fabric_backend_plan.py::test_temporal_readout_rows_come_from_compiled_readout_program tests/test_fabric_backend_plan.py::test_readout_rule_compiler_rejects_unsupported_rule tests/test_fabric_backend_plan.py::test_transition_program_compiler_uses_cuda_nn_primitive_registry tests/test_fabric_backend_plan.py::test_temporal_primitive_executor_plan_fails_closed_for_missing_generic_dispatch` + passed: 5 tests. +- Still open: R2.3 executor dispatch and fusion. The compiler products now + cover message, transition, readout, boundary, and parameter rows, but the + active CUDA route still has executor blockers and legacy bundled kernels. + +Code log - 2026-04-29 fused executor groups: + +- Added explicit fused executor groups to the temporal primitive executor plan. + The plan now reports row coverage for compiler-selected message, transition, + readout/boundary, and parameter-reduction fusion groups instead of only + per-row blockers. +- This is not a performance implementation yet. It is the compiler artifact the + next CUDA executor work must consume: a fused CUDA block is valid only if it + covers lowered row indices from the primitive table and reports the covered + primitives in metadata. +- The current transition fusion groups are deliberately blocked with + `declared_composite_transition_primitive_not_dispatched_by_primitive_executor` + when they contain `gated_logspace_recurrence`, `diag_rtu`, or + `diagonal_recurrence`. This records the true current state: composite + recurrence is the accepted primitive granularity, but those composites still + need row-owned primitive executors before closure. +- Deleted the old CUDA execution-cell formula bundles + `src/cortical/fabric/backend/cuda/cells/slstm.cuh` and + `src/cortical/fabric/backend/cuda/cells/axon.cuh` with their registration + `.cu` files. 
The cell semantic definition remains in + `src/cortical/fabric/backend/cell_specs.py` as `CellBackendSpec` / + `CellTransitionIR`; CUDA implementation must come from primitive executor + bindings selected by lowered rows. The acceptable end state is either: + 1. cell files declare `fabric.cuda.nn` programs in primitive builder terms and + the compiler/fusion pass creates fast CUDA fused blocks from the lowered + rows, or + 2. a composite such as `gated_logspace_recurrence`/`diag_rtu` is an explicit + supported fused primitive with its own tensor-role ABI, forward executor, + backward executor, parameter reductions, and fail-closed declaration + boundary. + The unacceptable state is a generic-looking compiler table that still executes + arbitrary user-declared cell math through hidden sLSTM/Axon formula bundles. +- Still open: replace the remaining bundled reverse/scan kernels such as + `transition_message_reverse_table_device_loop` with primitive/fused executor + dispatch, then delete the superseded kernels/exports once parity and + representative throughput pass. + +Code log - 2026-04-30 composite recurrence pivot: + +- Reverted the sLSTM fine-grained semantic-row split. The active design now + keeps sLSTM on the declared composite primitive + `gated_logspace_recurrence`, alongside Axon's `diag_rtu` composite. This is a + deliberate compiler granularity choice, not a permission to route by cell + name or hidden formula bundles. +- Removed the attempted eager semantic transition interpreter from the CUDA + path before committing it. PyTorch-style op interpretation may exist only as + reference/debug machinery; it is not REDO2 CUDA backend closure. +- Updated executor-plan guardrails so composite recurrence rows fail closed as + `declared_composite_transition_primitive_not_dispatched_by_*`, not as a + semantic-row blocker and not as an accepted legacy owner. +- R2.3/R4/R13 remain open. The next backend owner is still the row-owned + composite primitive executor and table-owned CUDA temporal reverse path for + `gated_logspace_recurrence`/`diag_rtu`, plus current-code T=1 guardrails + before broader T/H/K audits. + +Code log - 2026-04-30 compiled transition-program executor selection: + +- Moved CUDA transition forward/backward dispatch one step closer to the real + compiler boundary: active composite transition execution now resolves + `runtime.backend_ir.transition_program_for_binding_slot(...)` through the + population binding slot and selects `gated_logspace_recurrence`/`diag_rtu` + from compiled primitive rows. The old direct `population_spec.transition_ir` + pattern checks are no longer the selection source. +- Added runtime metadata for this slice: + `_last_transition_executor_program_source="compiled_transition_program"`, + `_last_transition_executor_binding_slot`, and + `_last_transition_executor_primitives`. +- This does not close R2.3/R4/R13. The selected composite executor still calls + the existing primitive/kernel bodies, and the temporal reverse owner is still + open until the active reverse scan/recompute path is table-owned inside the + CUDA temporal superop and legacy helper routes can be deleted. +- Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/transition_execution.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py` + passed: 65 tests. 
+- Current-code T=1 guardrail on GPU 0 with isolated cache dirs: + `CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_composite TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_composite ...` + passed forward parity for sLSTM `B=2,T=1,h=32`, reset absent and reset + present, both with max diff `8.493661880493164e-07`. Metadata reported + `compiled_transition_program` and primitives + `('linear', 'matmul', 'gated_logspace_recurrence', 'norm_or_identity')`. + A tiny high-level training backward smoke on the same row completed with loss + `0.05161779746413231`. This is only a guardrail, not full T=1 audit closure. + +Code log - 2026-04-30 composite cell declaration target: + +- Accepted the concise composite primitive authoring style for sLSTM and Axon: + cell declarations should read like a small Fabric program, with `linear`, + `matmul`, `norm_or_identity`, state emission, and public emission as ordinary + declared ops, while `gated_logspace_recurrence` and `diag_rtu` are explicit + composite recurrence primitives. +- Tightened the Axon target sketch so `diag_rtu` owns only the diagonal + transition/eligibility update. Input projection and output projection remain + separate declared primitives, and activation selection is an explicit + op-row/tensor/attribute concern rather than a hidden runtime side channel. +- Recorded the remaining blocker: the current indexed C++ builder, + PyTorch-side `activation_name` lowering, and any helper that dispatches by + Axon/sLSTM-shaped bundles are still facade risks. R2.3/R15 remain open until + the active path is declaration -> IR -> primitive rows -> tensor-table roles + -> parameter bindings -> primitive executors -> shared temporal runtime. +- Refreshed CUDA parity guardrails so PyTorch comparison models use the explicit + `_fabric_forward_reference` helper instead of the disabled direct sequence + route. This keeps the no-legacy-route invariant intact while preserving + parity checks for mixed-pop T=1 and forced flat-bucket sequence rows. +- Verification: + `CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_composite TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_composite uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_forward_reverse_table_plan_is_not_t1_specific` + passed: 4 tests. +- Closed the immediate low-memory recompute blocker exposed by that broader + run. Recompute previously disabled transition tape under tight memory, but + reset-present/materialized-final-state rows were excluded from the compact + CUDA reverse-window table path, so backward rejected with + `reverse_engine_reject=missing_transition_tape`. The recompute planner now + permits those rows to use table-owned reverse-window payloads; reset masks are + passed as transition/message reset windows to the CUDA reverse table, and + unsupported state-gradient cases still fail closed instead of reopening the + Python host loop. +- Verification after the backend fix: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal_backward.py` + passed. 
+ `CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_recompute_reset TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_recompute_reset uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t_gt1_recomputed_artifacts_match_pytorch_reference` + passed: 1 test. + The broader affected runtime set using the same isolated cache passed: + 16 tests. + +Code log - 2026-04-30 primitive-row parameter ABI: + +- Closed a small but necessary R2.2/R2.3 compiler-boundary substage: + parameter ownership is now part of the primitive row and executor contract, + not an inferred side channel. Compiled message rows record message-rule + parameter bindings (`q_weight`, `k_weight`, `v_weight`, `out_weight`), + compiled transition rows carry their per-op parameter inputs, readout rows + carry `readout_weight`, and parameter-reduction rows carry the parameter they + reduce. +- `TemporalPrimitiveExecutorContract` now records row index, surface, inputs, + outputs, and parameter bindings; fused executor groups also report their + combined parameter bindings. This is still fail-closed: statuses remain + `missing_executor`, so audits still reject the active path until real + primitive executors replace the legacy message/readout/transition/parameter + routes. +- R2.3/R4/R13 remain open. The next backend implementation must consume these + row-owned bindings in the CUDA primitive executor registry instead of + resolving message, projection, recurrence, norm, readout, or parameter + gradients from hardcoded runtime bundles. +- Verification: + `python -m compileall -q src/cortical/fabric/backend/message_rules.py src/cortical/fabric/backend/readout_rules.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/tables.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/primitive_dispatch.py` + passed. + `uv run ruff check src/cortical/fabric/backend/message_rules.py src/cortical/fabric/backend/readout_rules.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/tables.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/primitive_dispatch.py tests/test_fabric_backend_plan.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py tests/test_fabric_audit_runner.py::test_fabric_audit_cuda_temporal_owner_gate_fails_primitive_executor_blockers` + passed: 62 tests, 14 warnings. + +Code log - 2026-04-30 compiled transition parameter bindings: + +- Moved active CUDA transition parameter resolution off the old + `CellBackendSpec.transition_parameter_bindings` side channel. sLSTM and Axon + transition forward/backward now resolve projection, recurrence, norm, and + output-projection parameters through the compiled transition program's + parameter bindings. Eligibility trace state discovery now also reads the + compiled transition program schema. +- This is still not R2.3 closure: the active executors remain legacy-backed + primitive bodies, but parameter binding now follows the compiler chain needed + by the real primitive executor registry: + compiled transition program -> primitive row inputs -> compiled parameter + bindings -> CUDA executor resolution. +- Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/transition_execution.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/transition_execution.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py` + passed: 65 tests. 
+ `CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_program_params TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_program_params uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_uses_high_level_reset_parity tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_shared_mixed_population_reset_parity tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t_gt1_training_uses_flat_bucket_route tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t_gt1_recomputed_artifacts_match_pytorch_reference tests/test_fabric_runtime.py::test_fabric_cuda_single_population_flat_bucket_route_matches_pytorch_reference tests/test_fabric_runtime.py::test_fabric_cuda_single_population_flat_bucket_forward_uses_sequence_executor` + passed: 16 tests in 237.75s. + +Current-code audit guardrail - 2026-04-30 T=1 100M regression: + +- Ran the high-level audit entrypoint on GPU 0 with private cache dirs: + `CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_t1_guardrail TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_t1_guardrail uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --families slstm --sizes 100m --modes forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --hidden-sizes 32 --population-modes single --reset-modes absent --warmup 1 --iterations 2 --out-dir /tmp/redo2_t1_program_params_guardrail` + returned `ok` as an informational run, not a closure run. +- Warmed confirmation on the same cache: + `/tmp/redo2_t1_program_params_guardrail_repeat` with `iterations=3`. + Result stayed bad: `430.67 tok/s`, `2377.69 ms`, `27.99 GiB`. + April 21 reference key `h32_t1_bxparams` is `58732.71 tok/s`, + `2.07 GiB`. This is a severe R11/R2.3 regression; T=1 is not healthy. +- Owner timing with + `CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1` at + `/tmp/redo2_t1_program_params_owner_timing` confirms the dominant owner: + `transition_message_reverse_table_device_loop:ms=1686.269;count=1`. + Other CUDA owners are small by comparison + (`public_projection:6.715 ms`, `message.fused_receiver_sender:1.596 ms`, + `readout:1.213 ms`). +- The same artifact reports primitive executor blockers for message, readout, + parameter reduction, transition affine, `gated_logspace_recurrence`, and + `norm_or_identity`. Therefore the next implementation must replace the + bundled `transition_message_reverse_table_device_loop` with row-owned reverse + primitive executors. More planner relabeling or manifest work is invalid + until this owner physically moves. + +Rejected probe - 2026-04-30 sender-reverse policy: + +- Tested a generic policy probe that disabled sender-reverse accumulation in + the reverse table executor. Targeted CUDA parity passed: + `CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_sender_probe TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_sender_probe uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads` + passed: 6 tests. 
+- Performance was worse on the same high-level T=1 100M B=1024 row: + `/tmp/redo2_t1_sender_reverse_off_probe` reported `177.79 tok/s`, + `5759.51 ms`, and + `transition_message_reverse_table_device_loop:ms=5066.727;count=1`. +- Reverted the probe. The issue is not just sender-reverse policy; the bundled + reverse executor itself is the owner. Continue with row-owned reverse + primitive executors and delete the bundled transition/message reverse loop + once parity and throughput recover. + +Code log - 2026-04-30 gated/message reverse primitive group: + +- Wired the existing table-based gated/message reverse executor into the active + temporal backward path for the pure `gated_logspace_recurrence + message` + primitive program. The route is selected from the temporal reverse primitive + rows and tensor roles, not from population name or an audit row. Mixed + `gated + diag_rtu` windows still fall back to the bundled + `transition_message_reverse_table_device_loop`, so R2.3/R4/R13 stay open. +- The row-owned group still is not full closure: it is a C++ window loop over + primitive executors and does not yet support the mixed transition buckets, + reset-split semantics, sparse direct-public gradients, or the final CUDA + temporal reverse superop. It is accepted only as a throughput-moving slice + away from the worst single-pop bundled reverse owner. +- CUDA parity guardrail: + `CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_gated_primitive_probe TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_gated_primitive_probe uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads` + passed: 7 tests in 230.03s. +- Current-code T=1 100M/B=1024/h=32 no-reset single-pop guardrail: + `/tmp/redo2_t1_gated_primitive_owner_timing_repeat` reported + `860.83 tok/s`, `1189.55 ms`, `29.55 GiB`. The previous warmed row was + `430.67 tok/s`, `2377.69 ms`, `27.99 GiB`; April21 reference remains + `58732.71 tok/s`, `2.07 GiB`. +- Owner timing moved for that row from + `transition_message_reverse_table_device_loop:ms=1686.269` to + `gated_message_reverse_primitive_group:ms=488.148`. This is real progress + but not enough; memory regressed and throughput is still far below the T=1 + training floor. +- Next R2.3/R4/R13 owner: make the gated/message primitive group a real CUDA + temporal reverse executor rather than a C++ per-window loop, remove the + remaining hidden message/transition parameter reductions from the temporal + glue, and then extend the same primitive-row executor to mixed + `gated + diag_rtu` windows before deleting the bundled reverse loop. + +Rejected probe - 2026-04-30 dense sender-reverse inside gated/message group: + +- Tried the dense sender-reverse message table inside the row-owned + gated/message reverse group instead of the receiver-partitioned fallback. + Parity was clean: + `CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_gated_sender_dense TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_gated_sender_dense uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed: 4 tests. 
+- Representative T=1 100M/B=1024 timing was worse: + `/tmp/redo2_t1_gated_sender_dense_owner_timing` reported `849.36 tok/s`, + `1205.62 ms`, `29.57 GiB`, and + `gated_message_reverse_primitive_group:ms=503.056`. The accepted + receiver-partitioned row remains `860.83 tok/s`, `1189.55 ms`, + `29.55 GiB`, `488.148 ms` for that owner. +- Reverted the probe. The next useful optimization is not changing sender-table + direction; it is reducing the forward/reverse state materialization footprint + and moving the primitive group out of the C++ per-window loop. + +Measurement log - 2026-04-30 affine parameter reduction timing: + +- Added owner timing around the remaining Python-side gated affine parameter + reductions in the accepted gated/message primitive route. This is measurement + only; it does not close R2.3/R4/R13 because the reduction still is not a + primitive-row executor. +- Current-code T=1 100M/B=1024/h=32 no-reset single-pop guardrail with the + timing label: + `/tmp/redo2_t1_affine_reduce_owner_timing` reported `861.60 tok/s`, + `1188.49 ms`, `29.55 GiB`. +- Timing breakdown: + `gated_message_reverse_primitive_group:ms=489.001;count=1`, + `gated_message_reverse_affine_param_reduce:ms=8.716;count=1`, + `public_projection:ms=6.731;count=1`, + `message.fused_receiver_sender:ms=1.596;count=1`. +- Conclusion: the Python affine reduction is real compiler debt but not the + current throughput owner. Prioritize the row-owned reverse primitive executor + and forward/reverse materialization footprint before moving this 8-9 ms + reduction. + +Rejected probe - 2026-04-30 fresh-source private-state recurrence policy: + +- Tested two generic source-policy variants for the gated recurrence primitive: + first a runtime flag that skipped recurrent-kernel accumulation when the + reverse table source was planner-known zero, then a templated executor variant + that skipped private-state adjoints for fresh one-step no-state windows. + Both kept provided-state and T>1 windows on the normal path. +- Targeted CUDA parity passed for both variants: + `CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_zero_source_recurrent_kernel TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_zero_source_recurrent_kernel uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed: 5 tests; the templated variant repeated the same 5-test pass with + `/tmp/cortical_torch_ext_redo2_private_state_template`. +- Performance did not improve. Runtime-flag runs at + `/tmp/redo2_t1_zero_source_recurrent_kernel_skip_repeat` reported + `786.03 tok/s`, `1302.74 ms`, `29.55 GiB`, with + `gated_message_reverse_primitive_group:ms=604.221`. The templated variant at + `/tmp/redo2_t1_private_state_template` returned to baseline at + `861.04 tok/s`, `1189.25 ms`, `29.55 GiB`, with + `gated_message_reverse_primitive_group:ms=489.059`. +- Reverted both source-policy variants. Do not retry this as a local flag in + the hot recurrence kernel. The next backend move must be structural: either + a real table-owned reverse primitive executor/superop that reduces the + materialized windows and reductions together, or a fused primitive-executor + reduction selected from the compiled parameter-reduction rows. 
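+
+Sketch (illustrative only): the structural fix named above is selection of a
+fused reverse executor from compiled primitive rows with explicit row
+coverage, not a local flag in the hot recurrence kernel. A minimal Python
+sketch of that selection shape follows; `CompiledRow`, `FusedReverseGroup`,
+and `REVERSE_EXECUTORS` are hypothetical names used for illustration, not the
+real registry API.
+
+```python
+# Hypothetical sketch of row-coverage-based fused reverse executor selection.
+# Names are illustrative only; they are not the real Fabric registry API.
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class CompiledRow:
+    index: int
+    surface: str          # e.g. "message", "transition", "param_reduce"
+    primitive: str        # e.g. "gated_logspace_recurrence", "weighted_sum"
+    param_bindings: tuple[str, ...] = ()
+
+
+@dataclass(frozen=True)
+class FusedReverseGroup:
+    executor_key: str
+    covered_rows: tuple[int, ...]
+    covered_primitives: tuple[str, ...]
+    param_bindings: tuple[str, ...]
+
+
+# Fused reverse executors keyed by the exact primitive sequence they cover.
+# Anything not registered must fail closed instead of routing to a legacy owner.
+REVERSE_EXECUTORS: dict[tuple[str, ...], str] = {
+    ("gated_logspace_recurrence", "weighted_sum"): "gated_message_reverse_group",
+}
+
+
+def select_fused_reverse_group(rows: list[CompiledRow]) -> FusedReverseGroup:
+    primitives = tuple(row.primitive for row in rows)
+    executor_key = REVERSE_EXECUTORS.get(primitives)
+    if executor_key is None:
+        # Fail closed: surface the blocker rather than fall back to a bundled loop.
+        raise NotImplementedError(
+            f"no fused reverse executor covers primitive sequence {primitives}"
+        )
+    params: list[str] = []
+    for row in rows:
+        params.extend(row.param_bindings)
+    return FusedReverseGroup(
+        executor_key=executor_key,
+        covered_rows=tuple(row.index for row in rows),
+        covered_primitives=primitives,
+        param_bindings=tuple(dict.fromkeys(params)),  # de-duplicate, keep order
+    )
+```
+
+A group selected this way carries its covered row indices and parameter
+bindings in metadata, so an audit can verify that every active reverse row is
+owned by some registered executor before any throughput number is accepted.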
+ +Kernel profile - 2026-04-30 T=1 100M/B=1024: + +- Ran nsys on the same high-level audit row with private caches: + `CUDA_VISIBLE_DEVICES=0 CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_affine_reduce_probe TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_affine_reduce_probe nsys profile --force-overwrite=true --stats=true --output=/tmp/redo2_t1_kernel_profile uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --families slstm --sizes 100m --modes forward_backward --batches 1024 --seq-lens 1 --inner-steps 1 --hidden-sizes 32 --population-modes single --reset-modes absent --warmup 1 --iterations 1 --out-dir /tmp/redo2_t1_kernel_profile_audit`. +- Artifacts: + `/tmp/redo2_t1_kernel_profile.nsys-rep`, + `/tmp/redo2_t1_kernel_profile.sqlite`, + `/tmp/redo2_t1_kernel_profile_audit`. +- GPU kernel summary shows the real owners: + `flat_bucket_gated_diagonal_temporal_scan_kernel` took about `594 ms` per + run under nsys and + `gated_logspace_core_recurrent_affine_backward_window_kernel` took about + `419 ms` per run. The remaining recurrent K/V backward kernels were each + about `18-40 ms`; message/readout kernels were much smaller. +- Conclusion: T=1 health cannot be restored by planner metadata, benchmark + changes, or the 8-9 ms Python affine reduction. The next throughput work must + change primitive-executor tiling/fusion and then the forward scan tiling. R4/R13 + stay open. + +Code log - 2026-04-30 gated recurrence backward warp-row executor: + +- Added a warp-row executor variant inside the + `gated_logspace_recurrence` composite backward primitive for hidden widths + `H <= 32`. This is a primitive-executor tiling choice based on the lowered + primitive tensor width; it does not add a population-name, benchmark-row, or + temporal-scheduler branch. The existing block-row executor remains for wider + hidden widths. +- Targeted parity: + `CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_gated_warp32_backward TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_gated_warp32_backward uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed: 5 tests. +- Mixed-pop guardrail: + `CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_gated_warp32_backward TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_gated_warp32_backward uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop --tb=short` + passed: 2 tests. +- Current-code T=1 100M/B=1024/h=32 no-reset single-pop audit: + `/tmp/redo2_t1_gated_warp32_backward` reported `1045.88 tok/s`, + `979.08 ms`, `29.55 GiB`. The accepted previous row was `861.60 tok/s`, + `1188.49 ms`, `29.55 GiB`. +- Owner timing moved + `gated_message_reverse_primitive_group` from about `489 ms` to + `277.895 ms`; the Python affine reduction remains about `8.704 ms`. +- This is accepted R2.3/R4/R13 progress but not closure. T=1 is still far below + the April21 reference `58732.71 tok/s` and memory is still far above + `2.07 GiB`. The next high-priority owner is the forward + `flat_bucket_gated_diagonal_temporal_scan_kernel`, which is still the largest + kernel in the T=1 profile and still contains temporal-scheduler-owned + primitive math that must move toward row-owned primitive executors. 
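+
+Sketch (illustrative only): the warp-row change above is an executor tiling
+variant keyed off the lowered primitive tensor width, never off cell-family
+names, benchmark rows, or T=1 labels. A minimal Python sketch of that
+selection rule follows; `ExecutorVariant` and `select_gated_backward_variant`
+are hypothetical names for illustration.
+
+```python
+# Hypothetical sketch: choose a backward executor tiling variant from the
+# lowered primitive tensor width only; fail closed when no variant applies.
+from dataclasses import dataclass
+from typing import Optional
+
+
+@dataclass(frozen=True)
+class ExecutorVariant:
+    key: str
+    max_hidden: Optional[int]  # None means no upper bound on hidden width
+
+
+GATED_BACKWARD_VARIANTS = (
+    ExecutorVariant(key="gated_logspace_backward_warp_row", max_hidden=32),
+    ExecutorVariant(key="gated_logspace_backward_block_row", max_hidden=None),
+)
+
+
+def select_gated_backward_variant(hidden_dim: int) -> ExecutorVariant:
+    if hidden_dim <= 0:
+        raise ValueError(f"invalid hidden_dim={hidden_dim}")
+    for variant in GATED_BACKWARD_VARIANTS:
+        if variant.max_hidden is None or hidden_dim <= variant.max_hidden:
+            return variant
+    raise NotImplementedError(
+        f"no registered executor variant supports hidden_dim={hidden_dim}"
+    )
+
+
+# h=32 stays on the warp-row variant; h=64 falls through to the block-row one.
+assert select_gated_backward_variant(32).key == "gated_logspace_backward_warp_row"
+assert select_gated_backward_variant(64).key == "gated_logspace_backward_block_row"
+```
+
+The point of the sketch is the input, not the numbers: the only legal
+selection signal is a property of the lowered primitive tensors, which keeps
+the temporal scheduler free of family- or row-specific branches.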
+ +Code log - 2026-04-30 gated recurrence forward small-hidden scratch: + +- Added a small-hidden executor variant for the declared + `gated_logspace_recurrence` composite forward row inside the current flat + bucket temporal scan kernel. For `hidden_dim <= 32`, the row executor now uses + a 32-wide scratch buffer instead of the legacy 256-wide scratch buffer. The + selection is based on primitive tensor width, not on population name, audit + row, `T=1`, single/mixed labels, or family-specific policy. +- Boundary review: this is accepted only as a throughput-moving primitive + executor storage fix. It does not close R2.3/R4/R13 because the forward scan + body is still the legacy monolithic temporal kernel and still contains + message/readout/gated/diagonal primitive formulas instead of a true + row-owned primitive executor registry. The remaining closure item is still to + split the active forward scan into temporal scheduling over tensor/op rows and + primitive executors, then delete the hardcoded monolithic kernel once parity + and throughput recover. +- Targeted parity after the final patch: + `CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_forward_small_scratch2 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_forward_small_scratch2 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop --tb=short` + passed: 3 tests. + `CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_forward_small_scratch2 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_forward_small_scratch2 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed: 4 tests. +- Current-code T=1 100M/B=1024/h=32 no-reset single-pop audit after the final + patch: + `/tmp/redo2_t1_forward_small_scratch2` reported `1113.73 tok/s`, + `919.43 ms`, `29.55 GiB`; backward owner timing stayed around + `gated_message_reverse_primitive_group:ms=277.445`. + The previous accepted row was `/tmp/redo2_t1_gated_warp32_backward` at + `1045.88 tok/s`, `979.08 ms`, `29.55 GiB`. +- Nsys confirmation artifact: + `/tmp/redo2_t1_forward_small_scratch_profile.nsys-rep` and + `/tmp/redo2_t1_forward_small_scratch_profile.sqlite`. + The forward `flat_bucket_gated_diagonal_temporal_scan_kernel` averaged about + `539 ms` under nsys versus the prior profile's about `594 ms`; the backward + `gated_logspace_core_recurrent_affine_backward_window_warp32_kernel` averaged + about `207 ms`. +- R2.3/R4/R13 remain open. T=1 is still far below the April21 score reference + `h32_t1_bxparams = 58732.71 tok/s` and memory remains far above `2.07 GiB`. + The next high-priority backend owner is not another local scratch tweak; it is + moving the forward message/transition/readout primitive execution out of the + monolithic temporal scan and into compiled tensor-table primitive executors, + with a corresponding materialization-footprint reduction. + +Rejected probe - 2026-04-30 message small-degree scratch: + +- Tried the same scratch-size idea on the forward dot-product message primitive: + for degree `<= 8`, use an 8-wide local logits/weight scratch instead of the + legacy 32-wide scratch. This passed targeted parity but did not move the + representative row. 
+- Verification before revert: + `CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_message_small_scratch TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_message_small_scratch uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed: 5 tests. + `CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_message_small_scratch TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_message_small_scratch uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop --tb=short` + passed: 2 tests. +- Performance before revert: + `/tmp/redo2_t1_message_small_scratch` reported `1113.17 tok/s`, + `919.90 ms`, `29.55 GiB`, effectively no improvement over + `/tmp/redo2_t1_forward_small_scratch2` at `1113.73 tok/s`, `919.43 ms`, + `29.55 GiB`. +- Reverted. Do not retry local message scratch-size edits as R2.3/R4/R13 + closure. The owner remains structural: primitive executor dispatch and + materialization footprint, not another local array-size branch. + +Code log - 2026-04-30 transition tape mode ABI and gated compact recompute: + +- Added an explicit CUDA temporal scan transition-tape mode: + `disabled`, `input_projection`, and `full`. The previous scan ABI only had a + boolean and therefore always allocated/wrote full gated logits and diagonal + preprojection tape whenever any transition tape was requested. The scan now + allocates/writes input-projection tape separately from full transition tape. +- Added a compact recompute path inside the declared + `gated_logspace_recurrence` composite backward primitive executor. When a + reverse table provides input-projection tape but omits full gate/recurrent + logits, the primitive executor can recompute those logits from its own + lowered primitive tensors (`transition.input_window`, `gate_weight`, + `gate_bias`, `recurrent_kernel`, `y_prev_window`) instead of requiring the + temporal scheduler to materialize full logits. This keeps the recompute inside + the composite primitive executor, not as cell-family logic in the temporal + scheduler. +- Boundary review: accepted as infrastructure only. The ABI is still not R2.3 + closure because the active forward scan remains the monolithic + `flat_bucket_gated_diagonal_temporal_scan_kernel`, message/readout/reduction + primitive rows still report missing primitive executors, and the default T=1 + path stays on full tape for throughput until compact recompute is faster or + selected by a memory-pressure policy. +- Targeted parity: + `CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_compact_gated_tape TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_compact_gated_tape uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts --tb=short` + passed: 1 test. + `CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_compact_gated_tape TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_compact_gated_tape uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop --tb=short` + passed: 2 tests. 
+ `CUDA_VISIBLE_DEVICES=1 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_compact_gated_tape_g1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_compact_gated_tape_g1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed: 4 tests. +- Rejected as default T=1 policy: forcing compact input-projection tape for the + 100M/B=1024/h=32 T=1 forward-backward row reduced peak memory but regressed + throughput. `/tmp/redo2_t1_compact_gated_tape` reported `1062.90 tok/s`, + `963.40 ms`, `25.05 GiB`, with + `gated_message_reverse_primitive_group:ms=384.516`. The previous accepted + default was `/tmp/redo2_t1_forward_small_scratch2` at `1113.73 tok/s`, + `919.43 ms`, `29.55 GiB`. +- Accepted default guardrail after restoring full tape for T=1: + `/tmp/redo2_t1_compact_gated_tape_full_default` reported `1106.69 tok/s`, + `925.28 ms`, `29.55 GiB`, with + `gated_message_reverse_primitive_group:ms=279.332`. This preserves the + throughput-first T=1 default while retaining compact tape/recompute support + for longer-T or memory-pressure planning. +- Next R4/R13 owner remains throughput: move the active forward + message/transition/readout primitive bodies out of the monolithic temporal + scan into row-owned primitive executors, and only then reconsider compact tape + as a default if the recompute executor no longer costs more than full-logit + materialization. + +April26 recovered design correction - 2026-04-30 training is always streamed: + +- Re-read `ai_docs/recovered_core.py` and `ai_docs/AWS_RECOVERY_TRAIL.md` + around the April24-April26 temporal design notes. The recovered design treats + ordinary training as streamed execution: the public call is still + `model(...)`, external loss, and `loss.backward()`, but the backend streams + the planned `T*K` substrate, saves only planner-selected compact checkpoints, + and walks reverse windows right-to-left during backward. +- Dense sequence outputs are user-visible materialization, not permission for + full internal liveness. Internal state, message, K/V, transition tape, + readout, boundary-adjoint, and parameter-reduction artifacts must be bounded + by `EmissionPlan`, `CheckpointPlan`, and `BackwardWindowPlan`, then discarded + per segment. +- A full `[T, B, cells, ...]` internal training artifact is now explicitly a + streaming-liveness regression unless it is the requested user output or a + recorded compact checkpoint/recompute artifact. This applies to T=1, T>1, + K>1, terminal loss, and per-timestep sequence loss under the same engine. +- Updated durable skills: + `skills/cb.fabric-backend-boundaries/SKILL.md` and + `skills/cb.fabric-performance-loop/SKILL.md` now state this invariant + directly so future backend/performance work does not reinterpret training as + a separate retained-sequence route. + +Boundary failure - 2026-04-30 flat-bucket scan ABI is still not a generic +compiler superop: + +- User review called out the live CUDA signature + `flat_bucket_gated_diagonal_temporal_scan_cuda(...)`. The concern is correct: + the ABI still mixes message/readout tensors (`input_k_seq`, `input_v_seq`, + `recurrent_q`, sender indices, distances, output query/bias/projection), + gated/sLSTM state (`initial_gated_y/c/n/m`), and Axon/diagonal state and + trace tensors (`initial_diagonal_*`) in one monolithic temporal scan entry + point. +- The tensor-table wrapper in + `flat_bucket_temporal_scan_kernels.cu` does not fix that by itself. 
It maps + roles such as `message.recurrent_q`, + `primitive.gated_logspace_recurrence.param.*`, and + `primitive.diag_rtu.param.*` into the same hardcoded fused function. That is + a transitional table facade, not REDO2 compiler closure. +- Accepted final direction: temporal superop ABI must be tensor/op-row owned. + It may schedule flat bucket identity, `T*K`, resets, horizon, + checkpoint/recompute, materialization, and dependency order. Message, + projection, readout, normalization, `gated_logspace_recurrence`, and + `diag_rtu` math must be reached only through primitive executors selected by + lowered op rows. Composite recurrence primitives are acceptable, but they + still need explicit row-owned forward/backward/recompute/tape/param-gradient + executor bindings and fail-closed unsupported-op behavior. +- This keeps R2.3/R4/R13/R15 open even when a fused scan benchmark runs. A + benchmark through `flat_bucket_gated_diagonal_temporal_scan_cuda` is + throughput evidence for the transitional route only; it cannot close generic + Fabric compiler ownership until the hardcoded ABI and kernels are replaced or + moved behind true primitive row dispatch and the old route is deleted. + +Code log - 2026-04-30 bounded reverse workspace attempt for T/H: + +- Added a compact checkpoint-artifact mode for active gated reverse windows, + compact gated transition recompute, direct output-boundary reverse-payload + use, compact one-step recurrent-message gradient workspace, and split K/V + projection backward to reduce retained `[T,B,...]` tensors in + `T=512,K=1,H=64` sequence-loss training. +- Targeted parity stayed green: + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_split_kv_proj TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_msg_work_fix uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed: 4 tests. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=1 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_split_kv_proj_t1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_msg_work_fix uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts --tb=short` + passed: 1 test. +- Representative current-code T/H guardrail still fails: + `/tmp/redo2_th_t512_k1_h64_b128_split_kv_proj` for + `T=512,K=1,H=64,B=128,100M,h=32,single-pop,sequence-loss` OOMed on another + `4.50 GiB` allocation. Peak allocator state was about `128.99 GiB` + allocated and `6.49 GiB` reserved. This is not acceptable against the + April21 streaming per-timestep reference (`95638.46 tok/s`, `61.21 GiB`). +- The same run's metadata explicitly reports primitive executor blockers for + message primitives, readout boundary, parameter reduction, transition affine, + `gated_logspace_recurrence`, and `norm_or_identity`. That confirms the next + high-priority owner is not more facade optimization: replace the active + monolithic scan/reverse ABI with real primitive row dispatch, while preserving + the streaming liveness invariant and deleting the hardcoded route as each + primitive executor takes ownership. + +Code log - 2026-04-30 forward primitive executor extraction: + +- Moved forward message/readout, `gated_logspace_recurrence`, and `diag_rtu` + formulas out of `flat_bucket_temporal_scan_kernels.cu` into + `sequence_surface/flat_bucket/flat_bucket_temporal_forward_primitives.cuh`. 
The + scan scheduler now calls primitive executor functions for these row bodies + instead of carrying the Axon/diagonal and gated recurrence math inline in the + temporal loop. +- Removed the misleading `flat_bucket/primitive_executors/` folder. The helper + header now sits beside the flat-bucket scan kernel it serves; real primitive + executor ownership remains the compiler Pass 2/3 target, not a staging + subfolder around fixed scan math. +- This closes only a narrow R2.3/R15 substage: primitive formulas are no longer + embedded directly in the forward scan scheduler body. It does not close + R2.3/R4/R13 because the kernel launch ABI still passes a gated+diagonal+ + fixed-message/readout bundle, and runtime metadata still reports primitive + executor blockers for message, readout, parameter reduction, transition + affine, `gated_logspace_recurrence`, and `norm_or_identity`. +- Targeted parity: + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_forward_primitive_header TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_forward_primitive_header uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts --tb=short` + passed: 1 test. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_forward_primitive_header TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_forward_primitive_header uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed: 4 tests. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_forward_primitive_header TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_forward_primitive_header uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py --tb=short` + passed: 3 tests. +- Performance guardrail: + `/tmp/redo2_t1_forward_primitive_header` with `warmup=0` reported + `333.50 tok/s` and a cold `public_projection` owner; rejected as cold-only + evidence. + `/tmp/redo2_t1_forward_primitive_header_warm` with `warmup=1` reported + `1106.91 tok/s`, `925.10 ms`, `29.55 GiB`, with + `gated_message_reverse_primitive_group:ms=279.392`, + `public_projection:ms=6.912`. This matches the accepted current-code T=1 + line and preserves the post-scratch throughput level, but remains far below + the April21 reference (`58732.71 tok/s`, `2.07 GiB`). +- Next R2.3 owner: replace the remaining monolithic forward launch ABI with a + tensor-table/op-row primitive executor dispatch boundary. The primitive + header is only a staging point; it must not become a permanent wrapper around + the old gated+diagonal bundle. + +Code log - 2026-04-30 forward table-bound host implementation: + +- Removed the private host entry + `flat_bucket_gated_diagonal_temporal_scan_cuda(...)` with the giant + gated+diagonal+message+readout argument list from + `flat_bucket_temporal_scan_kernels.cu`. The exported pybind path already + called `flat_bucket_temporal_table_scan_cuda(...)`; that wrapper now invokes + `flat_bucket_temporal_table_scan_impl_cuda(...)` with tensor-table roles, + primitive-row ranges, and scalar tables instead of re-expanding the old ABI at + the wrapper boundary. 
+- This answers the scan-body location precisely: the forward row formulas were + moved in the previous slice from + `flat_bucket_temporal_scan_kernels.cu` into + `sequence_surface/flat_bucket/flat_bucket_temporal_forward_primitives.cuh`. This + slice removed the old host-call surface around that scan body, not the CUDA + kernel body itself. +- Manual boundary review from that slice: this was still not REDO2 compiler + closure. The CUDA + kernel symbol is still `flat_bucket_gated_diagonal_temporal_scan_kernel`, and + the table implementation still binds role names for the current supported + message/readout/`gated_logspace_recurrence`/`diag_rtu` tensors before calling + the monolithic kernel. The remaining temporal math needed to be removed or + moved behind primitive executors selected by lowered op rows. In particular, + `temporal_backward.py` and + `flat_bucket_temporal_backward_kernels.cu` remain open R4/R13/R15 targets for + semantic splitting and CUDA reverse primitive executor ownership. +- Targeted verification: + `git diff --check` passed. + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py --tb=short` + passed: 3 tests. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_forward_table_impl TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_forward_table_impl uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts --tb=short` + passed: 1 test. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_forward_table_impl TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_forward_table_impl uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed: 4 tests. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_forward_table_impl TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_forward_table_impl uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop --tb=short` + passed: 2 tests. +- Warmed T=1 guardrail: + `/tmp/redo2_t1_forward_table_impl` reported `1103.16 tok/s`, + `928.24 ms`, and `29.55 GiB` for the high-level API + `100M/B=1024/T=1/K=1/h=32/single-pop/terminal training` row. Runtime + metadata reports + `flat_bucket_temporal_scan_binding_abi:flat_bucket_temporal_table_extension`, + `20` primitive rows, and `16` + primitive-executor blockers. This preserves the current post-scratch T=1 + line, but remains far below the April21 score reference + (`58732.71 tok/s`, `2.07 GiB`) and does not close throughput. +- Next high-priority R2.3/R4/R13 work: replace the kernel-level fixed + gated+diagonal/message/readout pointer ABI with tensor-table pointer rows and + primitive executor dispatch, then do the same for reverse scan/recompute and + parameter reductions. The cleanup rule remains deletion-first: as each + primitive executor becomes real, remove the old bundled kernel/wrapper route + rather than retaining parallel paths. + +Code log - 2026-04-30 forward kernel descriptor ABI: + +- Collapsed the cooperative forward scan kernel signature from a long fixed + tensor/scalar argument list into two launch descriptors: + `TemporalScanKernelTensors` and `TemporalScanKernelScalars`. 
The host launch + now passes those two descriptors through `kernel_args` instead of expanding + every message/readout/transition/checkpoint tensor as a separate CUDA + parameter. +- This is a structural staging step toward table-owned primitive dispatch, not + the final compiler boundary. The descriptor still contains the current + supported tensor fields, and the CUDA body still calls the existing forward + primitive executor functions. The next real closure step is to replace this + fixed descriptor with tensor/op-row indexed primitive executor inputs, then + remove the old `flat_bucket_gated_diagonal_temporal_scan_kernel` identity + entirely. +- Verification: + `git diff --check` passed. + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py --tb=short` + passed: 3 tests. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_forward_kernel_descriptor TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_forward_kernel_descriptor uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts --tb=short` + passed: 1 test. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_forward_kernel_descriptor TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_forward_kernel_descriptor uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed: 4 tests. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_forward_kernel_descriptor TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_forward_kernel_descriptor uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop --tb=short` + passed: 2 tests. +- Warmed T=1 guardrail: + `/tmp/redo2_t1_forward_kernel_descriptor` reported `1105.11 tok/s`, + `926.60 ms`, and `29.55 GiB` on the same + `100M/B=1024/T=1/K=1/h=32/single-pop/terminal training` row. The result is + stable relative to `/tmp/redo2_t1_forward_table_impl` and still far below the + April21 reference (`58732.71 tok/s`, `2.07 GiB`). Metadata still reports + `16` primitive-executor blockers, so R2.3/R4/R13/R15 remain open. + +Code log - 2026-04-30 forward primitive program gate: + +- Rejected the descriptor-only direction as another facade. The active forward + temporal scan now validates the lowered primitive-row programs before launch: + message bucket `-1` must be the compiled dot-product + `linear,linear,linear,attention_logits,add,segment_softmax,weighted_sum,linear` + program; readout bucket `-2` must be `readout_project,reduction_boundary`; + transition buckets must be either + `linear,matmul,gated_logspace_recurrence,norm_or_identity` or + `linear,diag_rtu,linear`. +- Added explicit `message_program_id` and `readout_program_id` scalar fields + into the forward scan descriptor and routed the CUDA body through + row-selected primitive executor wrapper calls for the currently supported + fused message/readout programs. Unsupported changed declarations now fail + closed before launch with a `fabric.cuda.nn primitive executor/lowering` + error instead of silently running the fixed Q/K/V/gated/diag route. +- Closed substage: forward scan no longer accepts arbitrary primitive rows + while executing the current fixed math. This is a compiler-boundary guard, + not full compiler closure. 
R2.3/R4/R13/R15 remain open because the scan + descriptor still contains fixed message/readout/transition tensor roles, the + kernel symbol is still the bundled + `flat_bucket_gated_diagonal_temporal_scan_kernel`, and audit metadata still + reports primitive executor blockers for message, readout, transition affine, + recurrence/norm, parameter reduction, and reverse. +- Verification: + `git diff --check` passed. + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py::test_temporal_forward_primitive_row_tensor_encodes_supported_program_groups --tb=short` + passed: 5 tests. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_forward_program_gate TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_forward_program_gate uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts --tb=short` + passed: 1 test. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_forward_program_gate TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_forward_program_gate uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed: 4 tests. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=1 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_forward_program_gate_mixed TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_forward_program_gate_mixed uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py --tb=short` + passed: 66 tests. +- Warmed T=1 guardrail: + `/tmp/redo2_t1_forward_program_gate` reported `1143.56 tok/s`, + `895.45 ms`, and `29.55 GiB` for the high-level API + `100M/B=1024/T=1/K=1/h=32/single-pop/terminal training` row. This preserves + the current regressed line and remains far below the April21 reference + (`58732.71 tok/s`, `2.07 GiB`), so no throughput owner closes. +- Next high-priority owner: replace the fixed forward descriptor with true + tensor-table/op-row pointer dispatch for message, transition, readout, and + checkpoint artifacts, then mirror the same boundary in reverse scan and + parameter reduction. Do not mark metadata blockers closed until those + executors are real and the bundled kernel identity is deleted. + +Code log - 2026-04-30 forward tensor pointer-table launch ABI: + +- Removed the `TemporalScanKernelTensors` CUDA launch struct and replaced it + with a device tensor-pointer table passed as + `const int64_t* __restrict__ tensor_table_ptrs`. The host side now builds a + pointer table, checks it against `kTsTensorSlotCount`, copies it to CUDA, and + launches the cooperative forward scan with `{&tensor_table_ptrs_arg, + &kernel_scalars}`. +- Closed substage: the forward launch ABI no longer passes a C++ struct whose + type name pretends to be a generic tensor table while still being a fixed + field bundle. This is only launch-shape progress. It does not close + R2.3/R4/R13/R15 because the slot enum is still fixed around the currently + supported message/readout/gated/diagonal tensors, the kernel symbol is still + `flat_bucket_gated_diagonal_temporal_scan_kernel`, and runtime metadata still + reports primitive executor blockers. 
+- Manual boundary review: no new cell-family selector, benchmark-row selector, + hidden-size policy key, or separate single/mixed route was added. The slice is + still not a real compiler executor because tensor slots are positional and + primitive-specific; the next owner must replace those fixed slots with + lowered tensor-role rows selected from primitive op rows, then do the same in + reverse and parameter reduction. +- Verification: + `git diff --check` passed. + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py --tb=short` + passed: 4 tests. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_forward_ptr_table TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_forward_ptr_table uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts --tb=short` + passed: 1 test. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_forward_ptr_table TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_forward_ptr_table uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed: 4 tests. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_forward_ptr_table TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_forward_ptr_table uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop --tb=short` + passed: 2 tests. +- Warmed T=1 guardrail: + `/tmp/redo2_t1_forward_ptr_table` reported `1056.07 tok/s`, + `969.63 ms`, and `29.55 GiB`. A warmed repeat at + `/tmp/redo2_t1_forward_ptr_table_repeat` reported `1062.27 tok/s`, + `963.98 ms`, and `29.55 GiB` for the high-level API + `100M/B=1024/T=1/K=1/h=32/single-pop/terminal training` row. This preserves + the current regressed line but remains far below the April21 reference + (`58732.71 tok/s`, `2.07 GiB`), so no throughput stage closes. +- Next high-priority owner: make the pointer table semantic instead of + positional by lowering primitive tensor-role bindings into compact tables + consumed by message, transition, readout, checkpoint, reverse, and parameter + primitive executors. Delete the bundled kernel identity once those executors + own the active path. + +Rejected probe - 2026-04-30 temporal-side tensor-role id map: + +- Rejected an uncommitted attempt to replace C++ string role lookup with a + fixed Python map such as `"primitive.gated_logspace_recurrence.param.gate_weight" + -> 13`. This is still a facade: it aliases the same hardcoded + message/readout/gated/diagonal signature and would break or silently + misrepresent execution when a user declaration changes. +- New invariant: tensor bindings must be compiler products. The declaration and + compiled primitive program own input names, output names, parameter bindings, + and required artifacts. The temporal engine may consume lowered binding rows, + but it must not own a fixed role-name or role-id map for user signatures. +- Immediate consequence for R2.3/R15: the real breach is not the string lookup + itself. 
The active scan still receives hidden tensors such as + `primitive.gated_logspace_recurrence.param.value_to_state_weight` and + `primitive.gated_logspace_recurrence.param.recurrent_bias` while the sLSTM + transition declaration does not currently expose that projection in its + primitive program. The next real code must make the declared program match the + executed program, or fail closed. Do not retry temporal-side numeric role + maps. + +Code log - 2026-04-30 compiled transition parameter binding guard: + +- Added `src/cortical/fabric/backend/cuda/temporal_param_binding.py` as the + shared resolver for compiled transition-program parameter bindings. CUDA + transition execution and the active temporal scan/recompute/reverse paths now + resolve transition parameters such as `value_to_state_weight`, + `recurrent_bias`, `input_proj_weight`, `nu_log`, `out_proj_weight`, and + `out_proj_bias` through the compiled transition program instead of reading + old runtime tensor names directly. +- Updated the sLSTM declaration to expose the input projection that the active + scan already executes: + `linear(aggregated_message,value_to_state_weight,recurrent_bias) -> + transition_input`, followed by the gate affine, recurrent matmul, + `gated_logspace_recurrence`, and `norm_or_identity`. This fixes the immediate + mismatch where the executed program had hidden `value_to_cell_weight` and + `recurrent_cell_bias` tensors that were not represented in the sLSTM + transition IR. +- Boundary status: this is fail-closed compiler-boundary progress, not generic + executor closure. If a user signature changes and the compiled transition + program does not bind the requested logical parameter, the CUDA temporal path + now raises at the compiled binding boundary instead of guessing an old tensor + by convention. The active forward/reverse scan still has fixed role strings, + a fixed bundled scan kernel, and primitive executor blockers for message, + transition, readout, and parameter reduction. +- Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/temporal_param_binding.py src/cortical/fabric/backend/cuda/transition_execution.py src/cortical/fabric/backend/cuda/sequence_surface/temporal_backward.py tests/test_fabric_backend_plan.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py --tb=short` + passed: 68 tests. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_compiled_param_binding TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_compiled_param_binding uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts --tb=short` + passed: 1 test. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_compiled_param_binding TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_compiled_param_binding uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed: 4 tests. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=1 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_compiled_param_binding_mixed TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_compiled_param_binding_mixed uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop --tb=short` + passed: 2 tests. 
+- T=1 guardrail: + `/tmp/redo2_t1_compiled_param_binding_final` reported `1055.58 tok/s`, + `970.08 ms`, and `29.55 GiB` for the high-level API + `100M/B=1024/T=1/K=1/h=32/single-pop/terminal training` row. This remains far + below the April21 reference (`58732.71 tok/s`, `2.07 GiB`) and does not close + any throughput owner. +- Next high-priority owner: replace the remaining fixed role-string temporal + scan/reverse ABI with compiler-owned binding rows and primitive executors. + The temporal engine may consume lowered binding rows, but it must not know a + user signature such as Q/K/V, gated params, diagonal params, or readout params. + +Code log - 2026-04-30 compiler-owned temporal tensor binding rows: + +- Closed the next R2.2/R2.3 boundary substage: `TemporalPrimitiveTablePlan` + now carries `TemporalTensorBindingRow` entries produced from compiled + message, readout, transition, and parameter-reduction primitive rows. The + binding rows record surface, primitive row index, logical input/output or + parameter name, and compiled source binding. This is a compiler product; it + is not a temporal-side role-id map. +- Added fail-closed behavior for signature drift. If a compiled primitive row + declares a parameter input that the corresponding compiled message/readout/ + transition binding source cannot explain, temporal table construction raises + before the active scan can run a hidden fixed ABI. +- Tightened parameter-reduction rows: they now lower only from compiled + transition parameter bindings, not every transition tensor schema entry. + This removed fake reduction rows for schema-only placeholders such as + receiver/query or cache-like tensors. Durable skill update added the same + rule. +- Runtime metadata now records `flat_bucket_temporal_table_tensor_binding_rows` + and per-row `flat_bucket_temporal_tensor_binding:*` summaries when the active + path builds the temporal table. This is guardrail evidence only. It does not + close R2.3/R4/R13 because the forward/reverse scan still projects these rows + into the existing fixed bundled tensor ABI and primitive executor blockers + remain open. +- Manual boundary review: no cell-kind selector, population-name selector, + benchmark-row selector, hidden-size policy key, separate single/mixed route, + or temporal-owned user signature map was added. Q/K/V-like and recurrence + names remain only as message/transition logical primitive inputs in binding + summaries, not as a claimed generic temporal executor ABI. The remaining open + owner is to make the CUDA scan/reverse executors consume these binding rows + directly and delete the bundled kernel identity. +- Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/tables.py src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py src/cortical/fabric/backend/cuda/sequence_surface/temporal_backward.py tests/test_fabric_backend_plan.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py --tb=short` + passed: 70 tests. + `git diff --check` passed. + Earlier interrupted attempts at the CUDA row exited 143 while unrelated long + evaluation jobs occupied GPUs 0-4. Those evaluation jobs were stopped on user + request, then the CUDA checks below were rerun on GPU 0 with the same private + cache. 
+ `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_tensor_bindings TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_tensor_bindings uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts --tb=short` + passed: 1 test. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_tensor_bindings TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_tensor_bindings uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed: 4 tests. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_tensor_bindings TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_tensor_bindings uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop --tb=short` + passed: 2 tests. +- T=1 guardrail: + `/tmp/redo2_t1_tensor_binding_rows` reported `1053.19 tok/s`, + `972.28 ms`, and `29.55 GiB` for the high-level API + `100M/B=1024/T=1/K=1/h=32/single-pop/terminal training` row. This is stable + with the prior regressed line but remains far below the April21 reference + (`58732.71 tok/s`, `2.07 GiB`), so no throughput owner closes. +- Next high-priority owner: wire the active CUDA scan/reverse table projection + to consume compiler-owned binding rows directly, starting with message and + readout parameter/source bindings, then transition artifacts and parameter + reductions. Keep deleting or fail-closing stale fixed ABI routes as each + primitive executor becomes real. + +Code log - 2026-04-30 supported scan binding projection guard: + +- Added `validate_temporal_supported_scan_binding_projection(...)` as the + explicit boundary between compiler-owned tensor binding rows and the current + fixed forward/recompute scan ABI. The active scan now accepts only the + currently implemented compiled projections: + `message=dot_product_segment_softmax_weighted_sum`, + `readout=readout_project_reduction_boundary`, + `transition_bucket=N:gated_logspace_recurrence`, and + `transition_bucket=N:diag_rtu`. +- This closes a narrow facade risk: if message/readout/transition declarations + change consistently and have compiled parameter sources, temporal table + construction can still succeed, but the fixed scan projection now fails + before launch instead of silently running the old Q/K/V/readout/gated/diagonal + ABI. Tests cover missing parameter bindings plus consistent message and + readout signature drift. +- Runtime metadata now records + `flat_bucket_temporal_scan_binding_projection:*` entries so audits can see + which compiler binding projection the active fixed scan consumed. This is not + primitive-executor closure; the metadata still reports missing primitive + executor blockers and the CUDA scan/reverse code still projects into the + bundled supported ABI. +- Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/tables.py src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py src/cortical/fabric/backend/cuda/sequence_surface/temporal_backward.py tests/test_fabric_backend_plan.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py --tb=short` + passed: 72 tests. + `git diff --check` passed. 
+ `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_tensor_bindings TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_tensor_bindings uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts --tb=short` + passed: 1 test. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_tensor_bindings TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_tensor_bindings uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed: 4 tests. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_tensor_bindings TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_tensor_bindings uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop --tb=short` + passed: 2 tests. +- T=1 guardrail: + `/tmp/redo2_t1_scan_binding_projection` reported `1060.00 tok/s`, + `966.04 ms`, and `29.55 GiB` for the high-level API + `100M/B=1024/T=1/K=1/h=32/single-pop/terminal training` row. Metadata + included + `flat_bucket_temporal_scan_binding_projection:message=dot_product_segment_softmax_weighted_sum`, + `flat_bucket_temporal_scan_binding_projection:readout=readout_project_reduction_boundary`, + and + `flat_bucket_temporal_scan_binding_projection:transition_bucket=0:gated_logspace_recurrence`. + This is stable with the current regressed line but remains far below the + April21 reference (`58732.71 tok/s`, `2.07 GiB`). +- Next high-priority owner: replace the fixed projection validator with real + row-owned primitive executor dispatch, starting with the forward message and + readout blocks, then carry the same binding rows into reverse and parameter + reduction. R2.3/R4/R13 remain open until the bundled scan/reverse ABI is + deleted and throughput recovers. + +Code log - 2026-04-30 forward executor rows and universal-op guard: + +- Added compiler-produced `TemporalForwardExecutorRow` records and passed their + tensor form into the CUDA temporal scan. This is the first active boundary + where the scan receives a selected executor table separately from primitive + op rows. It is still a fixed supported executor projection, not final + compiler closure, because the scan still calls the bundled implementation + after validating the executor rows. +- Renamed the new executor row identities to universal primitive-sequence + names: `neighborhood_attention_project` and `projection_reduction_boundary`. + The surfaces remain `message` and `readout`, but those are scheduling/use + surfaces, not primitive types. `fabric.cuda.nn` primitives are universal ops; + the same primitive opcode can appear in message, transition, readout, and + parameter-reduction rows. +- Added a static test proving universal primitive identity across surfaces: + `linear` keeps the same primitive opcode on message and transition rows, and + `reduction_boundary` keeps the same opcode on readout-boundary and + parameter-reduction rows. The test also verifies forward executor rows are a + fused selection layer over primitive rows rather than a new primitive kind. 
+- Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/tables.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/primitive_dispatch.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_temporal_scan_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal_backward.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py --tb=short` + passed: 73 tests. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_forward_executor_rows TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_forward_executor_rows uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts --tb=short` + passed: 1 test after cold extension build. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_forward_executor_rows TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_forward_executor_rows uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed: 4 tests. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=1 TRITON_CACHE_DIR=/tmp/cortical_triton_redo2_forward_executor_rows TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_redo2_forward_executor_rows uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop --tb=short` + passed: 2 tests. +- Guardrail: + `/tmp/redo2_t1_forward_executor_rows` reported `1055.61 tok/s`, + `970.06 ms`, and `29.55 GiB` for the high-level API + `100M/B=1024/T=1/K=1/h=32/single-pop/terminal training` row. This is still + far below the April21 reference (`58732.71 tok/s`, `2.07 GiB`), so no + throughput owner closes. + +Current-tree assessment - 2026-04-30: + +- Bottom line: this is useful compiler-boundary scaffolding, not true REDO2 + compiler closure. Treat the current tree as roughly `6/10` architecture + hygiene and `2/10` completed compiler closure. +- What is real progress: the old sequence-surface sprawl is split into + `compiler/`, `runtime/`, `temporal/`, `flat_bucket/`, + `ops/temporal_backward/`, and `transition_execution/`; runtime dispatch now + imports split transition lowering/types and flat-bucket modules; variable-K + and direct hidden-input temporal routes fail closed instead of silently + falling back to old Python loops; compiler metadata now includes primitive + rows, tensor binding rows, forward/reverse executor rows, verifier status, + strategy blockers, compatibility debt, and memory-plan summaries. +- Current validation reported by the tree assessment: + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed with `105` tests, and the targeted runtime projection-gate test passed + with `6` cases. +- What is not closed: the compiler now has a separate forward executable plan, + but the active forward route still uses the quarantined compatibility launch + (`cuda_temporal_scan_table_extension_compatibility`) while backward still uses + `cuda_temporal_reverse_table_extension_compatibility`. 
Primitive dispatch + still reports mostly `fixed_composite_abi` or `missing_executor`, active + forward scan still calls the fixed ABI wrapper with message/readout/gated/ + diagonal tensors plus compiler rows, active reverse still calls the fixed + transition/message reverse wrapper with explicit gated/diagonal/message + arguments, and the CUDA wrappers/kernels still accept a fixed semantic bundle + rather than dispatching registered primitive executors from tensor bindings. +- Hard active limits still remaining: fixed scan supports at most one gated + bucket and one diagonal bucket, the wrapper caps hidden/head/value/degree + dimensions, the scan kernel accepts only the current neighborhood-attention, + projection-boundary, gated-transition, and diagonal-transition executor IDs, + and `transition_execution` still recognizes exactly + `gated_logspace_recurrence` and `diagonal_rtu` structural patterns. +- Future issue risk is high if this is described as "true compiler closure": + new cell ops, extra transition buckets, different message/readout + declarations, larger dimensions, or non-gated/diag recurrences will hit fixed + ABI limits or `missing_executor` blockers. If described honestly as partial + compiler scaffolding with fail-closed guardrails, the risk is much lower. +- Immediate process risk before any commit: the replacement source directories + are currently untracked (`ops/temporal_backward/`, `sequence_surface/compiler/`, + `sequence_surface/flat_bucket/`, `sequence_surface/runtime/`, + `sequence_surface/temporal/`, and `transition_execution/`). A commit that + records only tracked deletions/modifications and misses these directories + will break the tree. + +Code log - 2026-04-30 forward executable plan split: + +- `sequence_surface/compiler/forward_plan.py` now owns only the true compiler + forward executable plan: `TemporalForwardExecutablePlan` over compiler + executor binding rows, strategy candidates, legality reasons, and the future + `registered_fused_forward_program_cuda` runtime entrypoint. +- The fixed scan tensor-slot projection moved out of the compiler executable + plan into `sequence_surface/compiler/forward_compatibility.py` as + `TemporalForwardCompatibilityLaunchPlan`. That file is explicit compatibility + debt: it still names the fixed slot ABI for the current CUDA scan wrapper, + and it is not presented as compiler closure. +- Runtime metadata now records the two facts separately: + `_last_flat_bucket_temporal_forward_executable_plan` for the compiler-owned + plan and `_last_flat_bucket_temporal_forward_compatibility_launch_plan` for + the active fixed scan bridge. The active route remains open work until the + scan consumes registered executor bindings directly and the compatibility + launch is deleted. +- The remaining forward pass-2 goal is no longer naming cleanup. It is to make + `registered_fused_forward_program_cuda` real: replace the fixed scan ABI with + row-owned primitive executor dispatch that consumes `PrimitiveRowIR`, + tensor-binding rows, executor-binding rows, and effect/lifetime plans without + top-level gated/diagonal/QKV/readout slots. 
+- `sequence_surface/compiler/backward_plan.py` was split the same way: + `TemporalBackwardExecutablePlan` now represents the compiler-owned reverse + executable plan over reverse executor-binding rows and future + `registered_reverse_executor_bindings`, while + `sequence_surface/compiler/backward_compatibility.py` contains the fixed + reverse tensor-slot launch used by the current CUDA reverse table bridge. +- Runtime metadata now records + `_last_flat_bucket_temporal_backward_executable_plan` separately from + `_last_flat_bucket_temporal_backward_compatibility_launch_plan`. The active + reverse route is still compatibility debt until the reverse table consumes + registered executor bindings directly and the top-level gated/diagonal/ + message reverse slot ABI is deleted. +- `TemporalArtifactStore` now stores + `forward_compatibility_launch_plan` and + `backward_compatibility_launch_plan` explicitly. The old generic + `forward_launch_plan`/`backward_launch_plan` names were removed from + source/tests/docs so the active artifact path cannot be mistaken for the + future registered executor-binding path. +- The shared temporal forward/recompute path now fails closed before launching + the fixed scan kernel whenever the compiler executable forward or backward + plan is not legal. The recorded reject is + `registered_executor_bindings_required` with + `fixed_scan_reverse_compatibility_kernels_disabled=1`. +- The temporal reverse executor now rejects stored compatibility reverse launch + plans unless `registered_reverse_executor_bindings` is legal. The fixed + reverse table kernel is no longer an accepted active reverse path for blocked + compiler plans. +- This checkpoint was superseded by the registered executor activation below. + At that earlier point, the cutoff was not conditional on future legality: if + forward/reverse executable plans became legal before registered kernels + existed, the active paths rejected with + `registered_forward_executor_kernel_not_implemented`, + `registered_forward_executor_recompute_kernel_not_implemented`, or + `registered_reverse_executor_kernel_not_implemented` instead of falling back + to the old fixed kernels. +- This intentionally regresses the old CUDA compatibility runtime smoke until + the registered forward/reverse executor kernels land. That is the correct + fail-closed compiler state: unsupported rows no longer run through the fixed + gated/diagonal/QKV reverse-table ABI. +- Verification after the executable/compatibility plan split: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py` + passed. 
+ `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_forward_compatibility_launch_plan_owns_scan_tensor_slots tests/test_fabric_backend_boundaries.py::test_temporal_backward_compatibility_launch_plan_owns_reverse_tensor_slots tests/test_fabric_backend_boundaries.py::test_reverse_executor_rows_are_selected_from_pattern_registry tests/test_fabric_backend_plan.py::test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names tests/test_fabric_backend_plan.py::test_temporal_table_runtime_metadata_records_executor_blockers tests/test_fabric_backend_plan.py::test_temporal_backward_rejects_compatibility_launch_until_registered_reverse_exists --tb=short` + passed across the focused targets. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 111 tests. + The old CUDA compatibility runtime smoke is no longer a passing target after + this cutoff; it must be replaced by a registered executor-binding runtime + smoke when the new kernels land. + `git diff --check && git diff --cached --check` passed. + +Aggressive five-pass compiler conversion plan: + +Compiler strategy/fusion contract: + +- The required architecture is: + user declarations -> semantic IR -> canonical primitive rows -> tensor + roles/parameter bindings -> effect annotations -> verifier -> legal strategy + candidates -> cost model/planner -> temporal schedule -> executor bindings + -> audited execution. +- The named compiler pipeline is: + frontend declarations -> `SemanticIR` -> graph/message/cell normalization -> + shape/dtype/device analysis -> `PrimitiveRowIR` -> canonicalization/row + grouping -> tensor-role/layout assignment -> active-region plan -> + boundary/carry/tape plan -> memory/liveness/workspace plan -> forward + physical plan -> backward physical plan -> strategy matching -> cost + selection -> executable launch plan -> audit metadata. +- Verification is required at every boundary: + `verify_semantic_ir(program)`, `verify_primitive_rows(rows)`, + `verify_tensor_roles(rows, bindings)`, `verify_temporal_plan(plan)`, + `verify_strategy_legality(strategy, executable_plan)`, + `verify_backward_coverage(forward_plan, backward_plan)`, and + `verify_no_hidden_fallback(execution_record)`. +- Semantics are compiler-owned. Executor strategies, including fused CUDA + kernels, are replaceable implementations of verified compiler products. A + strategy is never the mathematical authority for a cell, message rule, + readout, or gradient. +- Fusion is a legality-preserving rewrite plus an executor strategy over a + verified semantic subgraph/row group. Every fused strategy must declare and + prove the row pattern it implements, materialization points, + reset/replay/checkpoint behavior, observable emissions, forward/backward + equivalence, memory effects, aliasing behavior, workspace needs, ABI + versions, scheduler dependency rewrites, and typed failure modes. +- Legality and cost are separate phases. Legality filtering proves a strategy + may run; cost/ranking only chooses among legal strategies. A bad cost model + may make execution slower, never incorrect. +- Row-pattern matching must be structural: canonical ops, tensor roles, + dependency edges, state reads/writes, emissions, reset behavior, layout + constraints, and gradient requirements. Strategy matching must not depend on + cell names, benchmark labels, hidden-size fixtures, or route identities. 
+- The scheduler remains math-ignorant but effect-aware. It may schedule + abstract effects such as state read/write, output emission, tape consumption, + materialization, replay requirement, temporal dependency, checkpoint boundary, + reversibility, workspace mutation, aliasing, and layout transform. It must + not implement primitive formulas. +- Backward executors implement compiler-level gradient specs. They must declare + gradient inputs, saved/recomputable tensors, tape schema, accumulation + behavior, parameter-gradient outputs, state-gradient outputs, tolerance + expectations, and deterministic/nondeterministic behavior. +- Executor metadata is versioned: + `strategy_id`, `strategy_version`, `row_schema_version`, + `tensor_binding_schema_version`, `metadata_schema_version`, and + `cuda_kernel_abi_version`. +- Strategy records must split `can_implement(plan)` legality, + `estimate_cost(plan, facts)` planning, and `execute(plan, tensors)` runtime. + Required fields are stable ID/version, primitive pattern signature, legality + predicate, required tensor roles, required layouts/strides/contiguity, + supported dtypes/devices/compute capabilities, stream and CUDA graph capture + compatibility, workspace requirements, aliasing permissions, + saved-tensor/recompute contract, forward executor, backward executor, + gradient accumulation contract, determinism/tolerance class, cost-model hook, + fallback/demotion reason enum, and audit metadata schema. +- Strategy rejection is typed and auditable: + `UNSUPPORTED_PATTERN`, `UNSUPPORTED_DTYPE`, `UNSUPPORTED_LAYOUT`, + `INSUFFICIENT_WORKSPACE`, `RESET_POLICY_MISMATCH`, + `TAPE_POLICY_MISMATCH`, `DEVICE_CAPABILITY_MISMATCH`, + `SHAPE_OUT_OF_RANGE`, and `MISSING_REQUIRED_BINDING` are minimum categories. +- Memory/layout planning is a compiler product: activation/tape/workspace + lifetimes, aliasing, layout transforms, contiguous/strided requirements, + alignment, persistent buffers, temporary buffers, and recompute-vs-store + choices are planned facts, not hidden strategy side effects. +- Every supported row group must have a boring reference implementation + (`semantic IR -> reference primitive executor`) that defines truth for + optimized primitive, fused, and specialized throughput executors. +- Strategy registry governance is mandatory: stable ID, owner, matching rules, + required tests, benchmark coverage, deprecation/removal criteria, emitted + audit metadata, verifier dependency, and planner explain output for candidate + strategies, typed rejections, selected strategy, schedule, materialization + policy, kernels run, and any fail-closed event. +- Compiler-boundary tests are layered: semantic IR goldens, primitive-row + goldens, canonicalization tests, strategy legality positive/negative tests, + fail-closed unsupported-row tests, no-hidden-fallback source/profile tests, + forward/backward differential tests, finite-difference gradient spot checks, + shape/dtype/reset/materialization fuzz tests, memory-budget/recompute tests, + CUDA graph capture tests, and legacy-path deletion tests. + +Five-pass compiler checklist: + +Audit note, 2026-05-02: this block is the historical checklist from the +fixed-ABI review. It is intentionally kept as the closure target that guided +the work, but it is not the current signoff state. The current pre-throughput +compiler closure checklist and deep-dive audit live at the end of this log; use +that later checklist for "what is still open now." 
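+
+Before walking the historical checklist, the strategy-record and
+typed-rejection contract above can be made concrete with a short sketch. This
+is a minimal illustration only, not the real registry API: every class, enum,
+and function name below is an assumption for this note, and a real record also
+carries the layout, workspace, ABI-version, and audit-metadata-schema fields
+listed above.
+
+```python
+# Minimal sketch of the legality/cost/execute split described above. All
+# names are hypothetical; the real Fabric strategy registry API differs.
+from dataclasses import dataclass
+from enum import Enum, auto
+from typing import Any, Callable, Mapping, Optional, Sequence
+
+
+class StrategyRejection(Enum):
+    # Typed, auditable rejection categories (the minimum set named above).
+    UNSUPPORTED_PATTERN = auto()
+    UNSUPPORTED_DTYPE = auto()
+    UNSUPPORTED_LAYOUT = auto()
+    INSUFFICIENT_WORKSPACE = auto()
+    RESET_POLICY_MISMATCH = auto()
+    TAPE_POLICY_MISMATCH = auto()
+    DEVICE_CAPABILITY_MISMATCH = auto()
+    SHAPE_OUT_OF_RANGE = auto()
+    MISSING_REQUIRED_BINDING = auto()
+
+
+@dataclass(frozen=True)
+class StrategyRecord:
+    # Stable identity/version, checked before cached plans or traces are trusted.
+    strategy_id: str
+    strategy_version: int
+    row_pattern_signature: tuple[str, ...]  # canonical primitive opcodes
+    required_tensor_roles: frozenset[str]
+    # Legality phase: prove the strategy may run; None means legal.
+    can_implement: Callable[[Any], Optional[StrategyRejection]]
+    # Cost phase: only ranks already-legal candidates, never decides legality.
+    estimate_cost: Callable[[Any, Mapping[str, Any]], float]
+    # Runtime phase: execute the verified plan over bound tensors.
+    execute: Callable[[Any, Mapping[str, Any]], Mapping[str, Any]]
+
+
+def select_strategy(
+    plan: Any,
+    facts: Mapping[str, Any],
+    registry: Sequence[StrategyRecord],
+) -> tuple[Optional[StrategyRecord], dict[str, StrategyRejection]]:
+    """Filter by legality first, then rank the legal candidates by cost."""
+    rejections: dict[str, StrategyRejection] = {}
+    legal: list[StrategyRecord] = []
+    for record in registry:
+        reason = record.can_implement(plan)
+        if reason is None:
+            legal.append(record)
+        else:
+            rejections[record.strategy_id] = reason  # kept for audit metadata
+    if not legal:
+        return None, rejections  # fail closed with typed reasons only
+    best = min(legal, key=lambda record: record.estimate_cost(plan, facts))
+    return best, rejections
+```
+
+The property the sketch encodes is the phase split: `can_implement` may only
+reject with a typed reason, `estimate_cost` may only reorder already-legal
+candidates, and a bad cost estimate can make execution slower but never
+incorrect.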
+ +- [x] Pass 1 - compiler ABI freeze, verifier, and legacy fail-close. + Status: SUPERSEDED/CLOSED by the later active closure checklist below. This + historical review entry originally marked the fixed-ABI baseline as partial; + the current route now rejects unsupported rows before launch, records typed + blocker metadata, and executes supported CUDA rows through compiler-owned + primitive/executor/binding products. + Historical target: keep the fail-closed runtime behavior, but do + not count metadata as closure until every active shared-engine entrypoint + rejects unsupported declarations before launch and no active route infers + semantics from fixed role names such as Q/K/V, gated, diagonal, or readout. + Deliverable: every active shared-engine entrypoint accepts only compiler + products (`IR -> primitive rows -> tensor bindings -> effects -> verified + executor rows`) and fails before launch on unsupported rows with typed, + audited rejection; no Python replay fallback, legacy planner route, + compatibility shim, or fixed tensor-name ABI remains active for the surfaces + this pass owns. +- [x] Pass 2 - universal strategy registry and forward executor table. + Status: SUPERSEDED/CLOSED by the later active closure checklist below. This + historical review entry originally marked forward strategy ownership as + partial; the active supported forward path now consumes registered executor + programs, access/carry/output-route rows, native callable records, and typed + fail-closed strategy blocker metadata. + Historical target: the fixed + `try_flat_bucket_temporal_scan_cuda(...)` argument bundle and scan extension + source have been deleted; replace the resulting fail-closed gap with + registered primitive executor dispatch. Forward scan must consume tensor + binding rows as the ABI, not alongside hardcoded `message.recurrent_q`, + `primitive.gated_logspace_recurrence.*`, `primitive.diag_rtu.*`, and readout + parameter roles. + Deliverable: forward/recompute scan consumes universal primitive executor + rows and tensor binding rows for message, readout, transition, and parameter + surfaces; strategies match canonical row groups, emit typed rejections, and + tests prove primitive identity is surface-independent. +- [x] Pass 3 - reverse compiler executor, autodiff contract, and parameter binding. + Status: SUPERSEDED/CLOSED by the later active closure checklist below. This + historical review entry originally marked reverse ownership as partial; the + active supported reverse path now runs through the registered tensor-store + reverse program, route-aware artifacts, compiler-declared span outputs, and + registered parameter reducers. + Historical target: the fixed + `try_transition_message_reverse_table_window_cuda(...)` explicit + gated/diagonal/message argument bundle, Python wrapper, pybind export, and + device-loop kernel have been deleted; replace the resulting fail-closed gap + with registered backward executor dispatch. At the time, primitive dispatch + could report blockers; later registered reverse/tensor-store/reducer cuts + closed this for the supported active CUDA route. + Deliverable: reverse scan, boundary adjoints, primitive adjoints, and + parameter reductions are selected from the same primitive/tensor binding + system; backward strategies are legality-checked implementations of compiler + gradient specs; `temporal_backward.py` is split into semantic files and the + host-loop replay path is deleted. 
The op-specific temporal backward + `try_*` wrappers must be removed from the active route here, replaced by + compiler-selected backward strategy bindings with typed rejection. +- [x] Pass 4 - effect-aware temporal scheduler and memory/layout planner. + Status: SUPERSEDED/CLOSED by the later active closure checklist below. This + historical review entry originally marked scheduler/memory ownership as in + progress; the active path now validates executable runtime buffers against + compiler memory/liveness rows and uses scheduler/output/artifact rows for the + supported temporal CUDA route. + Historical target: remove fixed scheduler/kernel limits that still + make many-population and larger-shape support a compatibility projection: + at most one gated bucket plus one diagonal bucket, fixed executor-ID set, + hidden/head/value/degree caps, and T/K/H behavior coupled to the current + superop argument layout. + Deliverable: planner records output emissions, autograd seed surfaces, + horizon windows, checkpoint/recompute policy, materialization requests, + effects, activation/tape/workspace lifetimes, aliasing, layout transforms, + and recompute/store choices as explicit data; temporal kernels schedule + executor rows over `T*K` without owning primitive math or T=1/K-only cases. +- [x] Pass 5 - governance, explainability, legacy deletion, and audits. + Status: SUPERSEDED/CLOSED by the later active closure checklist below. This + historical review entry originally marked governance/deletion/audits as + pending; the current active checklist records the no-compatibility source + sweep, typed `missing_executor` classification, full Fabric-focused pytest + sweep, generated-catalog validation, and tree-hygiene classification. + Historical target: do not merge or describe this as true compiler + closure while `fixed_composite_abi`, compatibility entrypoints, or untyped + fallback paths remain in the active path. The current source sweep found no + such live active-path hits; remaining `missing_executor` strings are typed + unsupported-row metadata only. + Deliverable: old kernels/wrappers/pybind exports/planner flags/tests for + replaced routes are removed; strategy registry governance, planner explain + output, versioned audit metadata, reference parity, benchmark coverage, and + strict static compiler-boundary guardrails are in place; no strategy can be + selected without verifier checks. The old + `ops/temporal_backward/temporal_backward_cuda.py` file has been deleted; the + active temporal backward route now reaches supported CUDA work through + registered executor bindings and typed unsupported-row rejection. + +Pass 1 - compiler ABI freeze, verifier, and legacy fail-close: + +- Freeze the only accepted active-route contract: + declaration -> IR -> canonical primitive op rows -> tensor binding rows -> + effect annotations -> verifier -> executor rows -> temporal scheduler. The + temporal scheduler can see surface, dependency, reset, checkpoint, + materialization, effect, and time metadata; it cannot infer primitive math + from tensor names, parameter names, or cell/message family. +- Add a compiler-owned verifier before execution. It validates row-pattern + well-formedness, tensor-role completeness, shape/layout/dtype consistency, + reset/materialization/tape compatibility, scheduler dependencies, + forward/backward contract presence, device/workspace constraints, ABI/schema + versions, and absence of hidden route identities. 
+- Convert current projection guards into hard fail-closed gates for every + forward/recompute/reverse entrypoint. Unsupported op rows or executor rows + must fail before CUDA launch with typed rejection reasons. No Python replay + fallback, legacy planner route, compatibility shim, or fixed tensor-name ABI + is allowed to remain active for the shared engine. +- Version the compiler/executor ABI at the boundary: + strategy/schema/kernel versions must be emitted into audit metadata and + checked before cached plans or benchmark traces are trusted. +- Delete or quarantine old sequence-surface paths as soon as the compiler gate + owns the same surface. Use commit history for April21 behavior, not live + fallback code. +- Current remaining work from the 2026-04-30 tree assessment: + - Convert active entrypoint validation from metadata-adjacent checks to a + hard compiler verifier gate that owns the launch decision. + - Make unsupported message, readout, transition, recurrence, reset, and + shape cases fail with typed verifier rejection before CUDA launch, not by + falling through to fixed wrapper shape guards. + - Keep the `temporal_table_fixed_scan_compatibility` blocker visible in + runtime metadata until Pass 2/4 remove the fixed bucket-cardinality + projection. + +Code log - 2026-04-30 Pass 1.1/1.2 fail-close cleanup: + +- Deleted the dormant active-output Python route from + `sequence_surface/temporal_executor.py`: + `execute_temporal_bucket_active_output_window`, + `_active_output_receiver_window`, and + `record_temporal_bucket_sequence_surface_execution`. This removed a second + message/transition/readout implementation that bypassed temporal compiler + rows and was not referenced by production call sites. +- Boundary-space `T=1` calls now enter the same shared temporal scan branch as + longer boundary sequences when CUDA flat buckets and constant K are selected. + Direct hidden-input adapter cleanup remains open because it must be routed + through the public input projection/boundary adapter path, not a private + temporal shortcut. +- Boundary-space variable-K CUDA calls now fail closed until per-timestep K is + lowered by the planner/compiler. The old Python step route is not allowed to + execute an unlowered variable-K program under a `cuda_temporal_superop` plan. +- CUDA replay artifact recompute no longer sparse-replays output messages with + a host message call. If any output-message artifact is required, the CUDA + replay scan is asked for output-message artifacts; missing artifacts now fail + closed instead of calling `_compute_messages_step_subset_partitioned_raw`. +- Guardrails added: + `test_sequence_executor_does_not_keep_active_output_python_route` and + `test_temporal_recompute_does_not_sparse_replay_output_messages_on_host`. +- Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py src/cortical/fabric/backend/cuda/sequence_surface/temporal_backward.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py --tb=short` + passed: 77 tests. +- CUDA guardrails passed with private caches on GPUs 0 and 1: + `test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts` and + `test_fabric_cuda_t_gt1_forward_reverse_table_plan_is_not_t1_specific`. + Both targeted tests passed after extension rebuild (`360.24s` and + `357.18s` respectively). 
+- Remaining Pass 1 blockers: + the forward scan wrapper and CUDA binding still expose the old positional + scan tensor ABI next to compiler rows. This is fail-closed by row/executor + validation but not yet the final tensor-binding executor ABI; Pass 2 owns the + actual universal forward primitive executor table replacement. + +Code log - 2026-04-30 Pass 2/3 temporal backward migration: + +- Deleted `sequence_surface/temporal_backward.py` entirely after initially + recreating it as a too-small barrel module. The active code now imports the + real compiler-owned modules directly; there is no top-level temporal backward + facade and no moved-monolith import. +- Deleted the old monolithic implementation/reference file after migration. + Static guardrails now reject any active import of + `temporal_backward_engine` or `temporal_backward_legacy_reference`. +- Split active code into the `sequence_surface/temporal/` package: + `types.py`, `windows.py`, `common.py`, `forward_scan.py`, + `output_backward.py`, `param_binding.py`, `reverse_executor.py`, and + `physical_autograd.py`. The later registered-executor pass deleted + `boundary_backward.py` once boundary projection backward was owned by the + selected reverse executor implementation. +- Deleted the dormant one-step host-style temporal API from active source: + `compute_temporal_bucket_step_artifacts`, + `_run_temporal_bucket_step_backward_result`, and + `run_temporal_bucket_step_backward`. The old CUDA parity test that existed + only for that route was removed; active training coverage now remains on the + sequence/shared temporal path. +- Removed the remaining direct-hidden `_forward_stream_step` loop from + `temporal_executor.py`. CUDA temporal execution now requires planner-lowered + boundary sequences; unlowered direct hidden input fails closed until the + public input projection is lowered before the temporal scheduler. +- Verification after migration: + `uv run ruff check` on the temporal split modules and touched tests passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py --tb=short` + passed: 80 tests. + +Code log - 2026-04-30 sequence-surface semantic package split: + +- Organized the active CUDA sequence surface into semantic packages: + `compiler/` for temporal bucket/table/scan-schedule/compiler-dispatch rows, + `flat_bucket/` for CUDA flat-bucket bindings, kernels, and fused device + helpers, `runtime/` for runtime mixins, policy, support, and the public + sequence executor, and `temporal/` for forward scan, reverse execution, + autograd, boundary/output adjoints, parameter binding, shared types, and + reverse-window helpers. +- The package root now only exports `CudaSequenceSurfaceMixin` from + `runtime.surface`. Top-level modules such as `temporal_backward.py`, + `temporal_executor.py`, `temporal_tables.py`, `temporal_scan.py`, + `flat_bucket_temporal_scan_cuda.py`, and `flat_bucket_layout_cuda.py` are no + longer active import locations. +- Updated active imports and static guardrails so old sequence-surface module + names resolve to no module while the new compiler/runtime/flat-bucket/ + temporal package paths import directly. The legacy temporal-backward + reference file remains deleted. +- Removed the remaining temporal split wildcard imports from + `temporal.common`. Shared temporal helpers are now imported explicitly, and a + boundary test rejects reintroducing `common import *` or `F403/F405` + suppressions in temporal modules. 
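+- For orientation, a guardrail of that kind can be pictured as a plain source
+  scan. The path constant, test name, and patterns below are illustrative; the
+  actual boundary test differs.
+
+  ```python
+  from pathlib import Path
+
+  # Hypothetical path constant; the real test resolves the package location itself.
+  TEMPORAL_DIR = Path("src/cortical/fabric/backend/cuda/sequence_surface/temporal")
+
+
+  def test_temporal_modules_do_not_reintroduce_wildcard_imports() -> None:
+      offenders: list[str] = []
+      for module in sorted(TEMPORAL_DIR.glob("*.py")):
+          text = module.read_text()
+          # Reject both the wildcard import itself and lint suppressions that hide it.
+          if "import *" in text or "F403" in text or "F405" in text:
+              offenders.append(module.name)
+      assert not offenders, f"wildcard imports or suppressions reintroduced in: {offenders}"
+  ```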
+- Deleted adjacent old execution dispatcher code, including + `src/cortical/fabric/backend/cuda/recurrence_executor.py`, after routing + remaining sequence-surface calls through the compiler-owned temporal executor. + The transition math/tape/adjoint owner remains in the split + `transition_execution` package and still needs generic-lowering closure. +- Removed sequence-surface imports of the whole `transition_execution` module. + Sequence-surface modules now import the exact transition helpers they still + depend on, and a boundary test rejects reintroducing module-wide + `transition_execution` coupling from sequence-surface code. +- Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_runtime.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_runtime.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py --tb=short` + passed: 82 tests. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_transition_explicit_imports_t1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_transition_explicit_imports_t1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts --tb=short` + passed: 1 test in `359.36s`. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=1 TRITON_CACHE_DIR=/tmp/cortical_triton_transition_explicit_imports_tgt1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_transition_explicit_imports_tgt1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_forward_reverse_table_plan_is_not_t1_specific --tb=short` + passed: 1 test in `360.75s`. + +Code log - 2026-04-30 transition-execution compiler-shape cleanup: + +- Deleted the `src/cortical/fabric/backend/cuda/transition_execution.py` + monolith and replaced it with semantic transition-execution modules: + `types.py` for ABI dataclasses, `projection.py` for affine/projection tape + and deferred parameter-gradient helpers, `program.py` for compiled transition + program executor selection and reset/record helpers, and `lowering.py` for + the public CUDA transition forward/backward lowerers. The later + `temporal_fusion.py` experiment was rejected and deleted; fusion planning + belongs to compiler strategy records, not a transition-side no-op facade. +- Runtime dispatch, sequence-surface modules, and tests now import the semantic + transition submodules directly. Static guardrails reject reintroducing the + old transition package-root import or the deleted monolith. +- Replaced direct transition lowerer branching on gated/diagonal predicate + helpers with `select_transition_program_executor(...)`, which returns a + compiled transition executor plan. The currently registered executors remain + narrow (`gated_logspace_recurrence` and `diagonal_rtu`); unsupported compiled + primitive programs still fail closed. +- The attempted transition-side fusion request/result layer was not kept as an + active abstraction. It renamed old gated/diagonal temporal-backward hooks + without providing compiler-selected executor implementations, so it is + deleted. Real fusion records now live in + `sequence_surface/compiler/executor_patterns.py`, + `primitive_dispatch.py`, and `strategy_selection.py`, where they are blocked + until a verified rewrite plus executor strategy exists. 
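+- Sketch of the selection step described above, under the assumption of a simple
+  registry keyed by primitive sequence; the record fields, registry contents,
+  and function signature are illustrative, not the real
+  `select_transition_program_executor(...)` contract.
+
+  ```python
+  from dataclasses import dataclass
+
+
+  @dataclass(frozen=True)
+  class TransitionExecutorPlan:
+      # Illustrative fields; the real compiled plan carries far more structure.
+      executor_name: str
+      primitive_sequence: tuple[str, ...]
+
+
+  # Hypothetical registry; only the two executor names mirror the narrow set
+  # registered at this checkpoint.
+  _REGISTERED_TRANSITION_EXECUTORS = {
+      ("gate", "logspace_scan"): "gated_logspace_recurrence",
+      ("diagonal_decay", "add"): "diagonal_rtu",
+  }
+
+
+  def select_executor(primitives: tuple[str, ...]) -> TransitionExecutorPlan:
+      name = _REGISTERED_TRANSITION_EXECUTORS.get(primitives)
+      if name is None:
+          # Typed fail-closed rejection; no implicit family-name fallback.
+          raise LookupError(f"no registered transition executor for {primitives!r}")
+      return TransitionExecutorPlan(executor_name=name, primitive_sequence=primitives)
+  ```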
+- Moved the temporal backward extension loader out of + `sequence_surface.flat_bucket` and into + `cortical.fabric.backend.cuda.ops.temporal_backward.extension`. + Transition execution and sequence surface depend on narrow backend ops + modules instead of importing sequence-surface flat-bucket modules through a + fusion wrapper. +- Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/transition_execution src/cortical/fabric/backend/cuda/ops/temporal_backward src/cortical/fabric/backend/cuda/sequence_surface src/cortical/fabric/backend/runtime_dispatch.py tests/test_fabric_backend_boundaries.py tests/test_fabric_execution_imports.py tests/test_fabric_runtime.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/transition_execution src/cortical/fabric/backend/cuda/ops/temporal_backward src/cortical/fabric/backend/cuda/sequence_surface src/cortical/fabric/backend/runtime_dispatch.py tests/test_fabric_backend_boundaries.py tests/test_fabric_execution_imports.py tests/test_fabric_runtime.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_execution_imports.py --tb=short` + passed: 20 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py --tb=short` + passed: 85 tests. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_fusion_engine_t1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fusion_engine_t1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts --tb=short` + passed: 1 test in `357.87s`. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=1 TRITON_CACHE_DIR=/tmp/cortical_triton_fusion_engine_tgt1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fusion_engine_tgt1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_forward_reverse_table_plan_is_not_t1_specific --tb=short` + passed: 1 test in `360.78s`. + +Code log - 2026-04-30 Pass 4 scheduler-plan wiring: + +- Added `sequence_surface/temporal/scheduler.py` with explicit runtime records + for output emission physical steps, autograd seed kind, required backward + surfaces, checkpoint steps, backward-window steps, reverse artifact kind, and + replay artifact requests. +- `run_shared_temporal_bucket_forward_scan` now builds a + `TemporalRuntimeSchedulerPlan` from the planner `TemporalExecutionPlan` and + uses it instead of local `getattr` helpers for checkpoint/window and + forward-reverse-table decisions. The old unused temporal artifact store policy + helper was deleted. +- `TemporalPhysicalBackwardScanExecutor` now builds the same scheduler plan and + asks it for replay requests. Backward recompute no longer assembles + `output_message_physical_steps` inline from local tuple math. +- Output-gradient window materialization can consume the scheduler's explicit + physical-to-output index map, while still using the CUDA materializer when + the map matches the current scalar all-steps or terminal emission schedule. +- Follow-up wiring now passes `runtime._last_temporal_execution_plan` into the + inference shared scan, carries `ctx.output_boundary` into physical backward, + and records the runtime scheduler summary in execution workspace aliases. + The scheduler no longer infers terminal/sequence ownership from gradient + tensor shape in the backward executor. 
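+- Reading aid only: the runtime scheduler records above can be pictured roughly
+  as below. The field names and the toy derivation are assumptions, not the
+  actual `temporal/scheduler.py` contents.
+
+  ```python
+  from dataclasses import dataclass, field
+
+
+  @dataclass(frozen=True)
+  class SchedulerPlanSketch:
+      # Illustrative subset of the runtime scheduler records discussed above.
+      output_emission_steps: tuple[int, ...]
+      checkpoint_steps: tuple[int, ...]
+      backward_window_steps: tuple[int, ...]
+      autograd_seed_kind: str          # e.g. "terminal" or "sequence"
+      reverse_artifact_kind: str       # e.g. "stored" or "recompute"
+      replay_artifact_requests: tuple[str, ...] = field(default_factory=tuple)
+
+
+  def build_sketch_plan(total_steps: int, checkpoint_stride: int, terminal_only: bool) -> SchedulerPlanSketch:
+      """Toy derivation: emit every step or only the last, checkpoint on a fixed stride."""
+      emissions = (total_steps - 1,) if terminal_only else tuple(range(total_steps))
+      checkpoints = tuple(range(0, total_steps, checkpoint_stride))
+      return SchedulerPlanSketch(
+          output_emission_steps=emissions,
+          checkpoint_steps=checkpoints,
+          backward_window_steps=tuple(range(total_steps)),
+          autograd_seed_kind="terminal" if terminal_only else "sequence",
+          reverse_artifact_kind="recompute" if checkpoint_stride > 1 else "stored",
+      )
+  ```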
+- Guardrails added: + `test_cuda_temporal_runtime_scheduler_consumes_planner_records`, + `test_cuda_temporal_runtime_scheduler_uses_terminal_output_plan`, + `test_temporal_backward_replay_requests_come_from_scheduler_plan`, and + `test_temporal_scheduler_plan_is_plumbed_through_forward_and_backward`. +- Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface src/cortical/fabric/backend/cuda/transition_execution src/cortical/fabric/backend/cuda/ops/temporal_backward src/cortical/fabric/backend/runtime_dispatch.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed after the follow-up wiring. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed after the follow-up wiring. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py --tb=short` + passed: 89 tests. + `uv run pytest -q tests/test_fabric_execution_imports.py --tb=short` + passed: 4 tests. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_scheduler_plan_t1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_scheduler_plan_t1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts --tb=short` + passed: 1 test in `353.96s`. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=1 TRITON_CACHE_DIR=/tmp/cortical_triton_scheduler_plan_tgt1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_scheduler_plan_tgt1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_forward_reverse_table_plan_is_not_t1_specific --tb=short` + passed: 1 test in `355.53s`. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_scheduler_plan_plumb_t1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_scheduler_plan_plumb_t1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts --tb=short` + passed: 1 test in `355.14s`. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=1 TRITON_CACHE_DIR=/tmp/cortical_triton_scheduler_plan_plumb_tgt1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_scheduler_plan_plumb_tgt1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_forward_reverse_table_plan_is_not_t1_specific --tb=short` + passed: 1 test in `354.43s`. + +Code log - 2026-04-30 fusion pattern registry slice: + +- Added `sequence_surface/compiler/executor_patterns.py` as the shared + representation of current fused executor groups. Message projection, + readout/boundary, gated transition, and diagonal transition now have explicit + row-pattern records with executor names and fixed-ABI blocker reasons. +- `temporal_forward_executor_rows(...)` now selects its current fused forward + executor rows through those row patterns instead of repeating one-off tuple + checks in each surface helper. 
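+- Shape of a row-pattern record and its matching step, sketched with invented
+  fields; the real `executor_patterns.py` records also carry versions, legality
+  predicates, cost hooks, runtime entrypoints, and blocker reasons.
+
+  ```python
+  from dataclasses import dataclass
+  from typing import Callable
+
+
+  @dataclass(frozen=True)
+  class RowPatternSketch:
+      # Illustrative pattern record: which primitive sequence it covers and
+      # which executor implements it.
+      pattern_id: str
+      primitive_sequence: tuple[str, ...]
+      executor_name: str
+      is_legal: Callable[[dict], bool]
+
+
+  _PATTERNS = (
+      RowPatternSketch(
+          pattern_id="message_projection.v0",
+          primitive_sequence=("matmul", "add_bias"),
+          executor_name="message_projection_executor",
+          is_legal=lambda row_group: row_group.get("dtype") == "float32",
+      ),
+  )
+
+
+  def match_pattern(row_group: dict) -> RowPatternSketch | None:
+      """Return the first pattern whose primitive sequence and legality predicate match."""
+      sequence = tuple(row_group.get("primitives", ()))
+      for pattern in _PATTERNS:
+          if pattern.primitive_sequence == sequence and pattern.is_legal(row_group):
+              return pattern
+      return None
+  ```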
+- `build_temporal_primitive_executor_plan(...)` now reports row-selected + current fusion groups as `fixed_composite_abi` blockers with precise reasons + such as + `transition_rows_selected_but_scan_kernel_uses_fixed_gated_composite_abi`. + This is intentionally not Pass 4 closure; it makes the remaining CUDA scan + ABI debt visible and machine-checkable. +- Fusion pattern records now live under `sequence_surface/compiler/` instead + of `transition_execution/temporal_fusion.py`. The current groups declare + primitive groups, required effects, versions, legality predicates, cost + hooks, runtime entrypoints, and `verified_rewrite_required` blockers. +- Guardrails added/updated: + `test_temporal_executor_fusion_patterns_are_structured`, + `test_temporal_primitive_executor_plan_fails_closed_for_missing_generic_dispatch`, + and the later + `test_rejected_transition_temporal_fusion_facades_were_deleted`. +- Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/transition_execution src/cortical/fabric/backend/cuda/sequence_surface/compiler tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/transition_execution src/cortical/fabric/backend/cuda/sequence_surface/compiler tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_transition_temporal_fusion_uses_generic_executor_requests tests/test_fabric_backend_plan.py::test_temporal_primitive_executor_plan_fails_closed_for_missing_generic_dispatch tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured --tb=short` + passed: 3 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py --tb=short` + passed: 90 tests. + `uv run pytest -q tests/test_fabric_execution_imports.py --tb=short` + passed: 4 tests. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_fusion_patterns_t1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fusion_patterns_t1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts --tb=short` + passed: 1 test in `353.62s`. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=1 TRITON_CACHE_DIR=/tmp/cortical_triton_fusion_patterns_tgt1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fusion_patterns_tgt1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_forward_reverse_table_plan_is_not_t1_specific --tb=short` + passed: 1 test in `353.37s`. + +Code log - 2026-04-30 temporal table metadata owner: + +- Added `sequence_surface/compiler/runtime_metadata.py` as the single helper + that stamps temporal table review fields, tensor-binding summaries, + scan-binding projection summaries, and primitive-executor blocker summaries + onto runtime metadata. +- `temporal/forward_scan.py` now calls that helper when the CUDA scan/replay + path accepts a temporal table instead of duplicating table metadata fields in + the forward and artifact-recompute paths. +- This moves fixed-composite ABI blockers to the CUDA table-consumption point: + current scan/replay rows now stamp + `_last_flat_bucket_temporal_primitive_executor_blockers` before late backend + record reconstruction. The blockers remain real blockers, not closure. +- Guardrails added: + `test_temporal_table_runtime_metadata_records_executor_blockers` and + `test_forward_scan_records_temporal_table_metadata_through_compiler_helper`. 
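+- Illustration of the single-owner stamping idea; the helper name and metadata
+  keys below are placeholders, not the actual `runtime_metadata.py` fields.
+
+  ```python
+  def stamp_temporal_table_metadata(metadata: dict, *, table_summary: dict, blockers: list[str]) -> dict:
+      """Toy sketch: one helper owns every temporal-table field written onto runtime metadata."""
+      metadata["_last_temporal_table_summary"] = dict(table_summary)
+      metadata["_last_temporal_executor_blockers"] = list(blockers)
+      return metadata
+
+
+  # Usage sketch: forward and artifact-recompute paths call the same owner
+  # instead of duplicating field writes.
+  runtime_metadata: dict = {}
+  stamp_temporal_table_metadata(
+      runtime_metadata,
+      table_summary={"buckets": 2, "primitive_rows": 14},
+      blockers=["fixed_composite_abi:transition_rows_selected"],
+  )
+  ```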
+- Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler src/cortical/fabric/backend/cuda/sequence_surface/temporal tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_forward_scan_records_temporal_table_metadata_through_compiler_helper tests/test_fabric_backend_plan.py::test_temporal_table_runtime_metadata_records_executor_blockers --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py --tb=short` + passed: 92 tests. + `uv run pytest -q tests/test_fabric_execution_imports.py --tb=short` + passed: 4 tests. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_runtime_metadata_t1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_runtime_metadata_t1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts --tb=short` + passed: 1 test in `358.88s`. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=1 TRITON_CACHE_DIR=/tmp/cortical_triton_runtime_metadata_tgt1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_runtime_metadata_tgt1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_forward_reverse_table_plan_is_not_t1_specific --tb=short` + passed: 1 test in `356.40s`. + +Code log - 2026-04-30 reverse executor row records: + +- Added `TemporalReverseExecutorRow` and + `temporal_reverse_executor_rows(...)` to the temporal compiler table layer. + Reverse execution now has structured compiler rows for the message reverse + executor and supported composite transition reverse executors. +- Added `temporal_reverse_executor_rows_tensor(...)` as the compatibility + projection from those reverse executor rows into the current 4-column CUDA + reverse-table ABI. The raw tensor is explicitly an executor-row projection, + not a primitive-row source representation. +- Runtime table metadata now records + `_last_flat_bucket_temporal_reverse_executor_summaries`, and backend + execution aliases include `flat_bucket_temporal_reverse_executor:*` entries. + This makes the reverse-side row selection visible to audits while preserving + the explicit blocker that the CUDA reverse kernel is still fixed ABI. +- Guardrails extended: + reverse executor row names/params are checked for mixed and one-transition + tables, and metadata recording now verifies reverse executor summaries are + stamped with primitive-executor blockers. +- Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. 
+ `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names tests/test_fabric_backend_plan.py::test_temporal_reverse_executor_rows_allow_one_transition_bucket tests/test_fabric_backend_plan.py::test_temporal_table_runtime_metadata_records_executor_blockers tests/test_fabric_backend_boundaries.py::test_forward_scan_records_temporal_table_metadata_through_compiler_helper --tb=short` + passed: 4 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py --tb=short` + passed: 92 tests. + `uv run pytest -q tests/test_fabric_execution_imports.py --tb=short` + passed: 4 tests. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_reverse_executor_rows_t1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_executor_rows_t1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts --tb=short` + passed: 1 test in `355.69s`. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=1 TRITON_CACHE_DIR=/tmp/cortical_triton_reverse_executor_rows_tgt1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_executor_rows_tgt1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_forward_reverse_table_plan_is_not_t1_specific --tb=short` + passed: 1 test in `352.62s`. + +Code log - 2026-04-30 reverse executor pattern registry: + +- Moved reverse executor IDs, names, and row signatures out of + `compiler/tables.py` into `compiler/executor_patterns.py`. +- Added `TemporalReverseExecutorPattern`, + `temporal_reverse_executor_patterns(...)`, and + `match_temporal_reverse_executor_pattern(...)`, mirroring the forward + executor pattern registry. +- `temporal_reverse_executor_rows(...)` now selects message and composite + transition reverse rows through those patterns, then projects the selected + rows into the current 4-column CUDA reverse ABI. Unsupported message reverse + row shapes fail closed before launch. +- `test_temporal_executor_fusion_patterns_are_structured` now covers forward + and reverse pattern records, including reverse fixed-ABI reasons. +- Static boundary coverage now rejects reintroducing reverse opcode/name maps + in `compiler/tables.py`, rejects the old + `temporal_reverse_primitive_rows_tensor(...)` name, and requires reverse CUDA + calls to pass `reverse_executor_rows`. +- Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler tests/test_fabric_backend_plan.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler tests/test_fabric_backend_plan.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured tests/test_fabric_backend_plan.py::test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names tests/test_fabric_backend_plan.py::test_temporal_reverse_executor_rows_allow_one_transition_bucket --tb=short` + passed: 3 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py --tb=short` + passed: 92 tests. + `uv run pytest -q tests/test_fabric_execution_imports.py --tb=short` + passed: 4 tests. 
+ `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_reverse_patterns_t1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_patterns_t1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts --tb=short` + passed: 1 test in `354.14s`. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=1 TRITON_CACHE_DIR=/tmp/cortical_triton_reverse_patterns_tgt1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_patterns_tgt1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_forward_reverse_table_plan_is_not_t1_specific --tb=short` + passed: 1 test in `352.96s`. + +Code log - 2026-04-30 verifier/effects/strategy governance integration: + +- Integrated the missing compiler architecture pieces into the five-pass plan: + verifier, legality-before-cost, structural row-group schemas, explicit + effects, compiler-owned autodiff contracts, strategy/ABI versioning, typed + rejection reasons, memory/layout planning, reference executors, registry + governance, and planner explainability. +- Tightened the plan with the additional refinements: named pass pipeline, + verifier gates at every compiler boundary, fusion as a legality-preserving + rewrite plus an executor strategy, explicit `can_implement`/`estimate_cost`/ + `execute` strategy phases, physical forward/backward plan mirroring, + memory/liveness/workspace ownership, and layered compiler-boundary tests. +- Added `sequence_surface/compiler/verification.py`, the first code-level + verifier scaffold. It records the named compiler pass pipeline, schema/ABI + versions, typed strategy rejection categories, inferred effect annotations, + verifier issues, and planner-explain summaries for the temporal primitive + table. +- `runtime_metadata.py` now stamps verifier status, pass pipeline, schema + versions, typed rejection codes, effect summaries, verifier issues, and + planner explain output onto runtime metadata alongside executor blockers. + Current supported rows still report `blocked`, not closed, because the CUDA + path still uses fixed compatibility ABIs. +- Structured forward and reverse strategy pattern records now carry stable + strategy IDs, strategy/schema/kernel versions, legality predicate, cost model, + runtime entrypoint, required effects, layout/dtype/device constraints, + workspace/aliasing/saved-tensor policies, gradient accumulation contract, + determinism/tolerance class, fallback policy, audit metadata schema, and a + `verified_rewrite_required` flag. The current entries still point at fixed + compatibility ABIs and must not be treated as throughput-closed strategies. +- Added `sequence_surface/compiler/row_groups.py` and moved strategy matching + onto canonical row-group schemas. Candidate rows are canonicalized with + surface, bucket, primitive sequence, parameter bindings, inputs/outputs, + attribute keys, and inferred effects; forward strategy matching now requires + the row-group effects it declares. Reverse strategies still match the + canonical forward row group while separately declaring backward effects. +- Superseded checkpoint: added `sequence_surface/compiler/strategy_selection.py` + as the first explicit + legality-vs-cost selection report. It records candidate match status, + legality status, typed rejection code, rejection reason, cost model, rank, and + runtime entrypoint. 
At that checkpoint fixed-ABI strategies could match + compiled row groups, but they were blocked with `UNVERIFIED_REWRITE` and + received no cost rank until verified executor replacements landed. +- Added `sequence_surface/compiler/memory_plan.py` as the first compiler-owned + memory/liveness/workspace plan. It derives tensor-role entries and primitive + effect entries from the temporal table, recording layout, lifetime, + workspace class, alias set, recompute policy, and owning compiler boundary. + This makes memory policy a planner product instead of executor-side hidden + tensor retention. +- `runtime_metadata.py` now stamps memory-plan review rows, memory-plan + summaries, workspace policy, layout policy, alias policy, and the current + peak-workspace estimate placeholder alongside verifier and strategy + selection metadata. +- Superseded checkpoint: added a compiler-owned + `TemporalForwardExecutorLaunchPlan` that records + selected forward executor rows, forward strategy IDs, the active scan + compatibility entrypoint, strategy legality status, and typed legality + blockers. Forward became audited the same way as backward: row groups could + match strategy records, but remained blocked with `UNVERIFIED_REWRITE` until + the fixed scan extension was replaced with registered executor bindings. +- Superseded checkpoint: the active forward scan and recompute paths failed + closed before the fixed scan compatibility wrapper could launch. They built + the compiler forward and backward executable plans, required legal registered + executor bindings, and rejected with + `registered_forward_executor_kernel_not_implemented` or + `registered_forward_executor_recompute_kernel_not_implemented` until a real + registered executor kernel existed. `temporal/forward_scan.py` no longer + imports or calls `try_flat_bucket_temporal_scan_cuda(...)`. +- Added explicit compiler compatibility-debt metadata for the remaining fixed + extension surfaces, including the disabled forward/reverse compatibility + wrappers and the temporal-backward materialization/reduction helpers. Each + entry records the owning pass and replacement contract so the wrappers cannot + be mistaken for finished primitive executors. +- Removed direct runtime tests and `__all__` exports that treated the remaining + temporal-backward `try_*` helpers as public APIs. They remain callable only + from the remaining materialization/reduction internals until the compiler + executor bindings replace them and Pass 5 deletes the wrappers/kernels. The + fixed forward scan and fixed reverse table calls were later disabled from the + temporal active path. +- Deleted the monolithic + `ops/temporal_backward/temporal_backward_cuda.py` compatibility file. The + remaining compatibility extension code is split by semantic surface into + `reverse_table.py`, `materialization.py`, and `reductions.py`, with shared + extension loading in `extension.py`. This is still compatibility debt, but it + is no longer one broad temporal-backward API file. +- Tightened strategy selection so matching is per canonical row group rather + than once per strategy pattern. Multi-bucket fabrics now produce one matched + strategy candidate per matching bucket, and candidates carry typed legality + reason strings instead of a single hidden blocker. 
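+- Sketch of the match/legality/cost separation; candidate fields, status
+  strings, and the cost rule are simplified assumptions rather than the actual
+  `strategy_selection.py` report.
+
+  ```python
+  from dataclasses import dataclass
+
+
+  @dataclass(frozen=True)
+  class StrategyCandidateSketch:
+      strategy_id: str
+      matched: bool
+      legality_status: str          # e.g. "legal" or "UNVERIFIED_REWRITE"
+      rejection_reason: str | None
+      cost_rank: int | None         # ranked only when legal
+
+
+  def select(candidates: list[StrategyCandidateSketch]) -> StrategyCandidateSketch | None:
+      """Legality first, cost second: illegal candidates never receive a rank."""
+      legal = [c for c in candidates if c.matched and c.legality_status == "legal"]
+      if not legal:
+          return None
+      return min(legal, key=lambda c: c.cost_rank if c.cost_rank is not None else float("inf"))
+  ```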
+- Added the first reverse launch-plan metadata pass; this was later split into + `TemporalBackwardExecutablePlan` for compiler-owned reverse executor-binding + rows and `TemporalBackwardCompatibilityLaunchPlan` for the active fixed slot + bridge. +- Superseded checkpoint: tightened the backward compatibility launch so + compatibility was no longer + described as a gateway. The compatibility plan reports + `cuda_temporal_reverse_table_extension_compatibility`, records + `strategy_legality_status`, and emits typed legality blockers such as + `UNVERIFIED_REWRITE` until the fixed reverse table extension is replaced + with registered compiler-selected backward executors. Later registered + reverse program work deleted that active compatibility bridge. +- Removed the attempted `ops/temporal_backward/fusion_gateway.py` wrapper. That + gateway only renamed old `try_*` calls and was rejected as hidden legacy. + The follow-on `transition_execution/temporal_fusion.py` no-op registry was + also deleted. A later hard-closure cut deleted + `transition_execution/lowering.py` entirely; transition math now reaches CUDA + through registered temporal program dispatch, not a direct eager lowering + sibling. +- Removed the direct Python wrappers and direct pybind exports for + `try_gated_logspace_recurrence_core_backward_window_cuda`, + `try_gated_logspace_recurrence_core_recurrent_affine_backward_window_cuda`, + and `try_diagonal_recurrence_core_backward_window_cuda`. The long CUDA tests + that preserved those direct APIs as step-loop oracles were deleted. The + dormant standalone C++/CUDA core wrapper functions and kernels were also + deleted from `flat_bucket_temporal_backward_kernels.cu`; the active reverse + table remains the only compatibility reverse launch for this surface. +- Removed stale transition audit helpers that only recorded the deleted + gated/diagonal temporal core-window fusion routes. The remaining transition + lowerer records compiled executor selection, not a fake temporal fusion + launch. +- Removed the optional direct + `try_gated_message_reverse_table_window_cuda(...)` reverse fast path, its + pybind export, and the matching C++ wrapper. The later hard cutoff also + removed the fixed + `try_transition_message_reverse_table_window_cuda(...)` call from + `temporal/reverse_executor.py`; reverse windows now reject before any fixed + reverse compatibility kernel can launch. True Pass 3 closure requires adding + the registered compiler-selected backward executors that replace that disabled + route, then deleting the remaining pybind/C++ compatibility wrapper. +- Removed direct recurrent-message backward Python wrappers and direct pybind + exports for `recurrent_message_initial_kv_backward`, + `recurrent_message_initial_kv_backward_window`, + `recurrent_message_initial_kv_backward_window_with_state`, and + `recurrent_message_table_backward_window_with_state`. Their low-level wrapper + tests were deleted. The dormant standalone recurrent-message initial-K/V C++ + functions and their private message kernels were also deleted from + `flat_bucket_temporal_recurrent_backward_kernels.cu`; that file now keeps + only the exported recurrent K/V weight-gradient reducer used by the active + compatibility path. +- Removed two more unused direct temporal-backward glue APIs: + `try_merge_temporal_carry_output_grad_step_cuda(...)` and + `try_reduce_temporal_recurrent_query_grad_cuda(...)`. The merge pybind export + and kernel were deleted because no active compiler path called them. 
The + remaining generic tensor reduction helper now uses a + `reduce_temporal_tensor_grad` pybind/CUDA entrypoint instead of the stale + recurrent-query naming until parameter reductions are moved behind registered + executor bindings. +- Updated the pass statuses accordingly. The active-route cleanup from Passes + 1-3 remains useful and tested, but those passes are no longer described as + fully closed under the stricter compiler bar until verifier/effect/strategy + contracts are implemented as first-class compiler products. +- Renamed the reverse CUDA Python boundary from primitive-row wording to + executor-row wording: `temporal_reverse_executor_rows_tensor(...)`, + `reverse_executor_rows`, and the CUDA wrapper keyword now match the actual + compiler product being passed into the current fixed compatibility ABI. +- Static guardrails now reject reintroducing the old + `temporal_reverse_primitive_rows_tensor(...)` function name and require the + temporal reverse executor to call CUDA wrappers with `reverse_executor_rows`. +- Moved the fixed reverse compatibility launch from string role lookup to a + compiler-owned tensor-slot table, then disabled that launch from the active + reverse executor. `TemporalBackwardCompatibilityLaunchPlan` still records the + fixed slot table as compatibility debt, but `temporal/reverse_executor.py` + no longer builds the plan or passes `tensor_slot_rows` into a fixed reverse + kernel. +- Applied the same slot-table cleanup to the forward CUDA scan compatibility + launch, then disabled that launch from the active forward/recompute paths. + `TemporalForwardCompatibilityLaunchPlan` still records global and + executor-local tensor slot rows as compatibility debt, but the fixed forward + scan Python wrapper, pybind binding, and CUDA kernel source have now been + deleted. `temporal/forward_scan.py` no longer calls + `try_flat_bucket_temporal_scan_cuda(...)` or passes + `forward_compatibility_launch_plan.tensor_slot_rows`. +- Deleted the now-dead `kTf*` fixed forward scan slot constants, then deleted + the whole fixed forward scan compatibility source set: + `flat_bucket_temporal_scan_cuda.py`, + `flat_bucket_temporal_scan_binding.cpp`, and + `flat_bucket_temporal_scan_kernels.cu`. There is no remaining forward scan + compatibility kernel for the temporal path to fall back to. +- Added a non-compatibility compiler executor-binding product: + `compiler/executor_bindings.py` builds forward and reverse + `TemporalExecutorBindingPlan` rows directly from primitive rows, executor + rows, and `TemporalTensorBindingRow` records. Runtime metadata now records + forward/reverse executor binding row tensors, summaries, and blockers, and + the verifier treats missing executor tensor bindings as + `MISSING_REQUIRED_BINDING` errors. This is the compiler launch contract that + should replace the fixed scan/reverse wrapper signatures; it does not use + `kTf*` slots, gated/diagonal start arguments, or compatibility role maps. +- Wired the strategy-selection report to those executor-binding plans. + Strategy candidates now report `binding_rows=` and `binding_blockers=` for + their matched row group before cost selection. This keeps legality/cost + explainability tied to actual compiler tensor bindings instead of pattern + matches alone. +- Tightened the naming around that state: forward/backward launch summaries now + say `compiler_projected_fixed_compatibility`, not `compiler_owned`. 
The + compatibility launch plans are audit/debt records, not active launch records: + temporal forward/recompute/reverse now fail closed until registered primitive + executor bindings replace the fixed scan and reverse superops. +- Added compiler-binding proof checks for the forward compatibility slot table. + The slot rows are still numeric compatibility slots, but any slot sourced from + message/readout/transition parameters must now prove the corresponding + compiler tensor binding exists before launch. A missing binding for the fixed + projection slot, for example `projection.value_to_output_weight`, fails closed + instead of silently relying on slot order. +- Added the same compiler-binding proof checks to the reverse compatibility + slot table. Reverse slots for message query/K/V, recurrent K/V projection, + recurrent message replay, and gated/diagonal transition parameters now prove + their compiler tensor bindings before the reverse launch plan is built. A + missing binding for a fixed reverse slot such as `message.recurrent_q` fails + closed with the missing compiler binding instead of relying on slot order. +- Focused verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py src/cortical/fabric/backend/cuda/ops/temporal_backward tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py src/cortical/fabric/backend/cuda/ops/temporal_backward tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_reverse_executor_rows_are_selected_from_pattern_registry tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured tests/test_fabric_backend_plan.py::test_temporal_reverse_executor_rows_allow_one_transition_bucket tests/test_fabric_backend_plan.py::test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names --tb=short` + passed: 4 tests. + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_compiler_has_named_verifier_effects_and_typed_rejections tests/test_fabric_backend_boundaries.py::test_forward_scan_records_temporal_table_metadata_through_compiler_helper tests/test_fabric_backend_plan.py::test_temporal_table_runtime_metadata_records_executor_blockers tests/test_fabric_backend_plan.py::test_temporal_compiler_verifier_reports_effects_and_typed_legality_blockers --tb=short` + passed: 4 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_strategy_patterns_declare_legality_cost_runtime_contracts tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured --tb=short` + passed: 2 tests. 
+ `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_strategy_patterns_declare_legality_cost_runtime_contracts tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured tests/test_fabric_backend_plan.py::test_temporal_strategy_matching_uses_canonical_row_group_schema --tb=short` + passed: 3 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_strategy_selection_has_separate_legality_and_cost_phase tests/test_fabric_backend_boundaries.py::test_forward_scan_records_temporal_table_metadata_through_compiler_helper tests/test_fabric_backend_plan.py::test_temporal_strategy_selection_separates_match_legality_and_cost tests/test_fabric_backend_plan.py::test_temporal_table_runtime_metadata_records_executor_blockers --tb=short` + passed: 4 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py --tb=short` + passed: 99 tests. + `uv run pytest -q tests/test_fabric_execution_imports.py --tb=short` + passed: 4 tests. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names tests/test_fabric_backend_boundaries.py::test_temporal_forward_compatibility_launch_plan_owns_scan_tensor_slots --tb=short` + passed: 2 tests after the forward executor-local tensor-slot handoff. + `git diff --check` passed. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_verifier_plan_t1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_verifier_plan_t1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts --tb=short` + passed: 1 test in `361.93s`. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_verifier_effects_t1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_verifier_effects_t1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts --tb=short` + passed: 1 test in `363.36s`. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=1 TRITON_CACHE_DIR=/tmp/cortical_triton_verifier_effects_tgt1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_verifier_effects_tgt1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_forward_reverse_table_plan_is_not_t1_specific --tb=short` + passed: 1 test in `363.19s`. + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler tests/test_fabric_backend_plan.py` + passed after the memory-plan slice. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler tests/test_fabric_backend_plan.py` + passed after the memory-plan slice. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_table_runtime_metadata_records_executor_blockers tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py --tb=short` + passed: 100 tests. + `uv run pytest -q tests/test_fabric_execution_imports.py --tb=short` + passed: 4 tests. + `git diff --check` passed. + +Code log - 2026-04-30 registered executor activation: + +- Implemented the active registered forward executor path in + `temporal/registered_executors.py`. 
`temporal/forward_scan.py` now builds the + compiler forward/backward executable plans and runs + `run_registered_temporal_forward_executor_scan(...)` over scheduler steps and + compiler binding rows instead of failing closed with + `registered_forward_executor_kernel_not_implemented`. +- Implemented the stored-artifact registered reverse executor path in the same + module. `temporal/reverse_executor.py` now calls + `run_registered_temporal_reverse_executor_window(...)` for artifact windows + and no longer depends on a fixed reverse compatibility launch plan. +- Deleted the forward/reverse compatibility launch-plan modules and removed + fixed tensor-slot helpers from the compiler path. A later cleanup deleted the + remaining `compiler/compatibility.py` debt ledger entirely; deleted entrypoint + history is now enforced by guardrails and git history, not a live module. +- Removed the fixed one-gated-plus-one-diagonal scan compatibility projection + from active bucket selection. Primitive tables now expose transition + recurrent bucket kinds, allowing many-bucket plans to proceed through + registered strategy selection rather than hitting a cardinality ABI limit. +- Changed executor pattern records from fixed-ABI blockers to registered + implementation contracts. Strategy selection now reports legal candidates + for implemented message/readout/gated/diagonal row groups, and primitive + dispatch reports implemented registered executor groups instead of + `fixed_composite_abi`. +- Remaining after this slice: CUDA throughput kernels that execute the + registered plan without the current per-step Python orchestration. +- This supersedes earlier 2026-04-30 compatibility-plan notes below that + described `registered_*_not_implemented` fail-closed placeholders and + `TemporalForwardCompatibilityLaunchPlan` / + `TemporalBackwardCompatibilityLaunchPlan`. Those were accurate at that + checkpoint, but the launch-plan modules and active placeholders are now + deleted. +- Verification for this slice: + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py` + passed. + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal src/cortical/fabric/backend/cuda/sequence_surface/compiler tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: `111` tests. + `uv run pytest -q tests/test_fabric_runtime.py --collect-only` + collected `314` tests. + `git diff --check` and `git diff --cached --check` passed. + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal src/cortical/fabric/backend/cuda/sequence_surface/compiler src/cortical/fabric/backend/cuda/ops/temporal_backward tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py` + passed after deleting the fixed forward scan source set and fixed reverse + table wrapper/kernel export. 
+ `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py src/cortical/fabric/backend/cuda/ops/temporal_backward tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py` + passed after the same deletion. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 111 tests after the same deletion. + `MAX_JOBS=1 CUDA_VISIBLE_DEVICES=0 TRITON_CACHE_DIR=/tmp/cortical_triton_strategy_selection_t1 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_strategy_selection_t1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts --tb=short` + passed: 1 test in `358.98s`; this run started before the memory-plan + metadata slice, so it is evidence for the strategy-selection slice only. + `python -m compileall -q src/cortical/fabric/backend/cuda/transition_execution src/cortical/fabric/backend/cuda/sequence_surface/compiler src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed after deleting the rejected fusion gateway. + `uv run ruff check src/cortical/fabric/backend/cuda/transition_execution src/cortical/fabric/backend/cuda/sequence_surface/compiler src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed after deleting the rejected fusion gateway. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_transition_temporal_fusion_uses_generic_executor_requests tests/test_fabric_backend_boundaries.py::test_reverse_executor_rows_are_selected_from_pattern_registry tests/test_fabric_backend_plan.py::test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names tests/test_fabric_backend_plan.py::test_temporal_table_runtime_metadata_records_executor_blockers --tb=short` + passed: 4 tests after deleting the rejected fusion gateway. This command is + historical; the guardrail was later replaced by + `test_rejected_transition_temporal_fusion_facades_were_deleted` when the + transition-side fusion facade itself was deleted. + `python -m compileall -q src/cortical/fabric/backend/cuda/transition_execution src/cortical/fabric/backend/cuda/sequence_surface/compiler tests/test_fabric_backend_boundaries.py` + passed after deleting the rejected transition-side temporal fusion facade. + `uv run ruff check src/cortical/fabric/backend/cuda/transition_execution/lowering.py tests/test_fabric_backend_boundaries.py` + passed after deleting the rejected transition-side temporal fusion facade. + `python -m compileall -q src/cortical/fabric/backend/cuda/transition_execution src/cortical/fabric/backend/cuda/sequence_surface/compiler src/cortical/fabric/backend/cuda/ops/temporal_backward tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed after deleting the direct temporal-backward fusion/core Python APIs + and the transition-side fusion facade. 
+ `uv run ruff check src/cortical/fabric/backend/cuda/transition_execution src/cortical/fabric/backend/cuda/sequence_surface/compiler src/cortical/fabric/backend/cuda/ops/temporal_backward tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed after the same deletion slice. + `uv run pytest -q tests/test_fabric_runtime.py --collect-only` collected + `312` tests successfully after deleting the long direct-wrapper CUDA tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_rejected_transition_temporal_fusion_facades_were_deleted tests/test_fabric_backend_boundaries.py::test_transition_execution_monolith_was_deleted --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py --tb=short` + passed: 100 tests. + `uv run pytest -q tests/test_fabric_execution_imports.py --tb=short` + passed: 4 tests. + `git diff --check` passed. + +Code log - 2026-04-30 registered artifact recompute and payload deletion: + +- Wired checkpointed artifact recompute through + `recompute_registered_temporal_artifact_window(...)`. Recompute now rebuilds + normal `TemporalBucketStepArtifacts` from the nearest compiler-owned + checkpoint using the same registered per-step executor used by forward and + backward; it no longer fails closed with + `registered_forward_executor_recompute_requires_stored_artifacts_or_registered_replay`. +- Removed the old reverse-only payload route from active temporal backward. + `TemporalReverseWindowPayload`, `TemporalReverseWindowTables`, + `active_reverse_only`, reverse-window table helpers, and payload-specific + output backward are gone from `sequence_surface/temporal`. Checkpointed + windows now load as stored artifacts or registered recompute artifacts only. +- `TemporalPhysicalBackwardScanExecutor` now verifies the registered reverse + executable plan for artifact windows and rejects stale primitive-table + fingerprints directly; there is no payload admission branch or CUDA reverse + table escape hatch. +- Updated tests at the time to assert registered recompute wiring, payload + deletion, and the absence of the old recompute/replay placeholder. Later + compiler tensor-store closure superseded the + `registered_temporal_executor_recompute` owner with + `registered_fused_forward_program_tensor_store_direct` plus + `registered_fused_reverse_program_tensor_table`. +- Verification for this slice: + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/output_backward.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/windows.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/types.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_runtime.py` + passed. + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_runtime.py` + passed. 
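+- Toy illustration of the nearest-checkpoint recompute idea behind that slice;
+  the real window selection is owned by the compiler memory and scheduler
+  plans, not by this arithmetic.
+
+  ```python
+  def recompute_window(checkpoint_steps: list[int], needed_step: int) -> tuple[int, list[int]]:
+      """Toy sketch: replay forward from the nearest checkpoint at or before the needed step."""
+      starts = [s for s in checkpoint_steps if s <= needed_step]
+      if not starts:
+          raise ValueError("no checkpoint at or before the requested step; cannot recompute")
+      start = max(starts)
+      return start, list(range(start, needed_step + 1))
+
+
+  # Example: checkpoints every 4 steps, artifacts needed for step 10.
+  start, window = recompute_window([0, 4, 8, 12], 10)
+  assert (start, window) == (8, [8, 9, 10])
+  ```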
+ +Code log - 2026-04-30 temporal-backward helper extension deletion: + +- Removed the remaining `ops/temporal_backward` Python/CUDA extension surface: + `materialization.py`, `reductions.py`, `extension.py`, + `flat_bucket_temporal_backward_binding.cpp`, + `flat_bucket_temporal_backward_kernels.cu`, and + `flat_bucket_temporal_recurrent_backward_kernels.cu`. +- Replaced active output-gradient window materialization, boundary-gradient + accumulation, recurrent-state gradient materialization, generic temporal + tensor reductions, and recurrent K/V projection weight-gradient reduction + with temporal scheduler/backward-local tensor operations and existing + registered parameter binding paths. There is no remaining active + `ops.temporal_backward` import. +- Materialization/reduction helper wrappers are no longer open compatibility + debt. A later cleanup removed the remaining `compiler/compatibility.py` debt + ledger entirely after the fixed scan/reverse compatibility entrypoints were + deleted from active code. +- Updated runtime/source guardrails so deleted helper wrappers and pybind/CUDA + files cannot be reintroduced as hidden launch paths. +- Verification for this slice: + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal src/cortical/fabric/backend/cuda/sequence_surface/compiler tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_runtime.py` + passed. + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal src/cortical/fabric/backend/cuda/sequence_surface/compiler tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_runtime.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 111 tests. + `uv run pytest -q tests/test_fabric_runtime.py --collect-only` collected + 314 tests. + `git diff --check && git diff --cached --check` passed. + A source scan for `ops.temporal_backward`, deleted `try_materialize_*`, + `try_reduce_*`, and temporal-backward pybind/kernel filenames now finds only + negative tests and audit metadata strings, not active imports or launch + callsites. + +Code log - 2026-04-30 registered owner metadata cleanup: + +- Replaced the remaining sequence-surface planner owner labels that still said + `cuda_temporal_superop` with registered executor owners: + `registered_fused_forward_program_cuda`, + `registered_reverse_executor_bindings`, and + `registered_temporal_executor_bindings`. +- Registered forward execution now records + `flat_bucket_temporal_scan_binding_abi:registered_executor_binding_rows` + and `flat_bucket_temporal_scan_primitive_rows:compiler_primitive_rows`. + Runtime defaults no longer invent a CUDA superop owner if the registered + executor did not set one. +- Reverse-scan validation now treats `cuda_temporal_superop` and the old + reverse-table ABI as legacy claims that raise, rather than as acceptable + owners requiring a compatibility binding. +- Runtime expectations from that slice asserted the temporary registered loop + implementation. The current tree has since renamed that implementation to + `registered_temporal_fused_forward_program_cuda` so the registered program cannot + be mistaken for fused CUDA closure. 
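+- Sketch only: treating a legacy owner label as a hard error can be pictured as
+  below. The owner strings are the ones named in this slice; the validation
+  helper itself is illustrative.
+
+  ```python
+  # Label sets taken from this slice; the validation rule is a sketch.
+  LEGACY_OWNER_LABELS = {"cuda_temporal_superop"}
+  REGISTERED_OWNER_LABELS = {
+      "registered_fused_forward_program_cuda",
+      "registered_reverse_executor_bindings",
+      "registered_temporal_executor_bindings",
+  }
+
+
+  def validate_owner_label(owner: str) -> str:
+      if owner in LEGACY_OWNER_LABELS:
+          # Legacy claims raise instead of being accepted with a compatibility binding.
+          raise ValueError(f"legacy owner label is no longer an acceptable claim: {owner}")
+      if owner not in REGISTERED_OWNER_LABELS:
+          raise ValueError(f"unknown owner label: {owner}")
+      return owner
+  ```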
+- Verification for this slice: + `uv run ruff check src/cortical/fabric/backend/temporal_plan.py src/cortical/fabric/backend/planner.py src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/common.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_plan.py tests/test_fabric_runtime.py` + passed. + `python -m compileall -q src/cortical/fabric/backend/temporal_plan.py src/cortical/fabric/backend/planner.py src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/common.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_plan.py tests/test_fabric_runtime.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py --tb=short` passed: + 81 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 111 tests. + `uv run pytest -q tests/test_fabric_runtime.py --collect-only` collected + 314 tests. + +Code log - 2026-05-01 registered executor program hardening: + +- Added `RegisteredTemporalExecutorProgram` as the runtime object that joins + the primitive-table fingerprint, bucket count, forward/backward executable + plans, and compiler-owned tensor binding rows. Active forward scan, + checkpoint recompute, and reverse artifact windows now build and pass this + program before executing any registered temporal step. +- The registered executor program validates executable-plan legality, stale + binding-row mismatches, message/readout executor ownership, and transition + forward/reverse coverage against `TemporalPrimitiveTablePlan` bucket + cardinality. This closes the previous loophole where registered execution + could run against independently selected helpers without proving that the + compiled executor/binding program covered every transition bucket. +- Forward/recompute execution and reverse execution now iterate transition + buckets from the compiled program rather than `runtime._population_names`. + Population names remain state keys only; executor coverage is a compiler + product derived from flat-bucket primitive rows. +- Source guardrails were tightened so the active forward/reverse paths must + call `build_registered_temporal_executor_program(...)` and pass + `executor_program=executor_program` into registered execution. The guardrails + also assert forward and reverse transition coverage helpers exist. +- Added compiler source bindings for runtime tensors consumed by registered + message/readout execution: recurrent query, input/recurrent K/V projection + weights, output query, value-to-output weight, and output bias ownership. + `RegisteredTemporalExecutorHandle` now resolves those tensors through + `static_tensor:` and `runtime_attr:` source bindings. The active + `registered_executors.py` path no longer reads `static_tensors["output_q"]`, + `static_tensors["recurrent_q_backend_order"]`, or + `static_tensors["value_to_output_weight"]` directly. +- Checkpointed artifact recompute now resolves input K/V projection weights + through the registered message executor binding rows as well. This keeps the + recompute path aligned with the same compiler program that forward and + reverse validate. 
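+
+  A minimal sketch of the coverage/staleness validation this entry describes.
+  Field and error strings are hypothetical; the real
+  `RegisteredTemporalExecutorProgram` also carries the forward/backward
+  executable plans and compiler tensor binding rows.
+
+  ```python
+  from dataclasses import dataclass
+  from typing import FrozenSet
+
+
+  @dataclass(frozen=True)
+  class ExecutorProgramSketch:
+      primitive_table_fingerprint: str
+      bucket_count: int
+      forward_buckets: FrozenSet[int]  # transition buckets covered forward
+      reverse_buckets: FrozenSet[int]  # transition buckets covered in reverse
+
+      def validate(self, runtime_fingerprint: str) -> None:
+          if runtime_fingerprint != self.primitive_table_fingerprint:
+              raise ValueError("stale primitive-table fingerprint")
+          expected = frozenset(range(self.bucket_count))
+          if self.forward_buckets != expected:
+              raise ValueError("forward rows miss at least one transition bucket")
+          if self.reverse_buckets != expected:
+              raise ValueError("reverse rows miss at least one transition bucket")
+  ```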
+- Remaining after this slice: implement fused CUDA throughput executors that + consume the same `RegisteredTemporalExecutorProgram`; make memory/liveness + planning allocate launch workspace instead of reporting audit summaries only; + and move `transition_execution` from exact gated/diagonal composite + recognition to declared-op lowering with typed unsupported-op rejection. +- Verification for this slice: + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal src/cortical/fabric/backend/cuda/sequence_surface/compiler tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_runtime.py` + passed. + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal src/cortical/fabric/backend/cuda/sequence_surface/compiler tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_runtime.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_plan.py::test_temporal_tensor_binding_rows_are_compiler_products --tb=short` + passed: 3 tests. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_forward_executor_bindings_fail_when_compiler_binding_is_missing tests/test_fabric_backend_plan.py::test_reverse_executor_bindings_fail_when_compiler_binding_is_missing tests/test_fabric_backend_plan.py::test_temporal_scan_binding_projection_fails_for_consistent_message_signature_drift tests/test_fabric_backend_plan.py::test_temporal_scan_binding_projection_fails_for_readout_signature_drift --tb=short` + passed: 4 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 111 tests. + +Code log - 2026-05-01 executable memory plan and transition registry slice: + +- Promoted the temporal memory/liveness plan from audit metadata into an + executable artifact policy. `TemporalMemoryRuntimeArtifactPlan` now carries + the planned artifact mode, checkpoint stride, recompute window, checkpoint + owner, reason, and workspace aliases derived from compiler memory entries + and the runtime scheduler plan. +- Active registered forward execution consumes that memory artifact plan. The + `TemporalArtifactStore` mode, stored-step-vs-recompute choice, checkpoint + stride, and recompute window now come from compiler memory planning instead + of the old ad hoc `registered_checkpoint_stride` / + `registered_recompute_window_len` values. When the plan selects checkpoint + recompute, registered forward stores checkpoints and avoids retaining the + full `artifacts_by_step` list. +- `RegisteredTemporalExecutorProgram` now owns the compiler memory/liveness + plan next to the forward/backward executable plans and binding rows, so the + active registered path has one program object for executor, binding, and + artifact-memory policy. +- Converted `transition_execution` executor selection from direct gated/diag + branch recognition into registered structural records. The selector now + walks `TransitionProgramExecutorRecord` entries with declared primitive + sequences, state slices, tensor edges, message inputs, and arity constraints. + Unsupported compiled transition programs fail closed with typed + `NO_REGISTERED_TRANSITION_EXECUTOR` instead of falling through an implicit + family route. 
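+
+  A minimal sketch of structural-record selection with fail-closed rejection.
+  The record fields and primitive sequences shown here are illustrative; the
+  real records also declare state slices, tensor edges, and arity constraints.
+
+  ```python
+  from dataclasses import dataclass
+  from typing import Sequence, Tuple
+
+
+  @dataclass(frozen=True)
+  class TransitionRecordSketch:
+      executor_id: str
+      primitive_sequence: Tuple[str, ...]  # declared op chain, in order
+      message_inputs: Tuple[str, ...]      # required message tensor roles
+
+
+  REGISTRY = (
+      TransitionRecordSketch(
+          "gated_logspace_recurrence_v1",
+          ("linear", "matmul", "gated_logspace_recurrence", "norm_or_identity"),
+          ("projected_message",),
+      ),
+      TransitionRecordSketch(
+          "diag_rtu_v1",
+          ("linear", "diag_rtu"),
+          ("projected_message",),
+      ),
+  )
+
+
+  def select_transition_executor(
+      ops: Sequence[str], messages: Sequence[str]
+  ) -> TransitionRecordSketch:
+      for record in REGISTRY:
+          matches_ops = tuple(ops) == record.primitive_sequence
+          matches_messages = tuple(messages) == record.message_inputs
+          if matches_ops and matches_messages:
+              return record
+      raise LookupError("NO_REGISTERED_TRANSITION_EXECUTOR")
+  ```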
+- Split message/readout primitive math out of the temporal registered scan + orchestrator. `registered_executors.py` now calls row-owned + `surface_executor_runtime.py` functions for recurrent message forward, + readout forward, output-message backward, and recurrent-message backward. + The temporal scan file no longer calls `_compute_messages_*`, + `_run_backend_message_backward_phase`, or the fused readout kernel directly; + those names live under the registered surface executor implementation + boundary selected by compiler handles. +- Added `TemporalFusedCudaProgramPlan` as the compiler-owned legality boundary + for the program-direct CUDA launch. Runtime metadata and backend execution + records report fused program status and blocker code. The plan now becomes + `legal` when registered forward and reverse span dispatch bodies are present; + typed transition primitive blockers still fail closed before launch. +- Added the first program-level fused CUDA ABI boundary: + `registered_temporal_fused_forward_program_validate_cuda` and + `registered_temporal_fused_backward_program_validate_cuda` now consume + compiler-owned primitive rows, forward/reverse executor rows, + forward/reverse binding rows, and memory-liveness rows. The C++ side + validates the row schema and decodes forward/reverse executor span tables + keyed by direction, executor row, primitive row range, binding row range, and + memory-plan coverage. This is intentionally named as validation rather than + execution until the fused program body exists; it is the required C++ program + decoder that the fused execution body must consume next. +- Added the first active fused forward execution entrypoint: + `registered_temporal_fused_forward_program_cuda` accepts a real CUDA + `boundary_seq`, compiler rows, executor binding rows, memory-liveness rows, + and a compiler tensor table. The entrypoint now owns the no-artifact + output-cells forward sequence for supported rows: it launches input K/V + projection once for the sequence, loops the temporal program in C++ over + message, transition, recurrent K/V, and readout spans, and emits output cells + without returning to the Python per-step orchestrator. It is deliberately + gated to inference-style rows with no resets, no final-state materialization, + and no artifact collection until the reverse/artifact program body lands. +- Tightened the blocker so it now comes from the transition primitive executor + registry, not from a temporal-side hardcoded label. Each + `TransitionPrimitiveExecutorRecord` declares its current per-step CUDA + executor, the required program-layer forward/backward C++ symbols, and a + program-layer blocker. `TransitionProgramExecutorPlan` and + `TemporalFusedCudaProgramPlan` now report the missing program symbols such + as `program_transition_linear_forward`, + `program_transition_gated_logspace_recurrence_forward`, and + `program_transition_diag_rtu_backward`. This makes the next closure step + concrete: implement and register those program-layer symbols, then flip the + corresponding primitive records to `program_layer_status=callable`. +- Implemented the first real program-layer transition primitive: + `program_transition_linear_forward` and + `program_transition_linear_backward` are now C++/CUDA symbols in the + registered fused-program extension and are exposed through + `registered_program_transition_linear_forward_cuda` and + `registered_program_transition_linear_backward_cuda`. 
They consume compiler + executor rows and binding rows, validate the selected executor bucket, and + execute receiver-major linear projection/adjoints without going through the + old transition lowering path. The `linear` primitive record now marks + `program_layer_status=callable`; full fused transition remains blocked + because `program_transition_gated_logspace_recurrence_*`, + `program_transition_norm_or_identity_*`, and `program_transition_diag_rtu_*` + are still missing. +- Implemented the second program-layer transition primitive: + `program_transition_recurrent_matmul_forward` and + `program_transition_recurrent_matmul_backward` now execute the gated + recurrent affine matmul shape directly from the registered fused-program + extension, with compiler executor row/binding validation. The `matmul` + primitive record now also marks `program_layer_status=callable`, so the + fused-program transition blocker is narrowed to the true recurrence/norm + primitives rather than affine scaffolding. +- Implemented the third program-layer transition primitive: + `program_transition_gated_logspace_recurrence_forward` and + `program_transition_gated_logspace_recurrence_backward` now execute the + sLSTM core recurrence row directly from compiler executor rows and binding + rows. This is intentionally the core primitive only: + `linear -> matmul -> gated_logspace_recurrence -> norm_or_identity` remains + four compiler rows, so outnorm is still owned by the separate + `norm_or_identity` primitive rather than folded back into a hidden gated + bundle. The `gated_logspace_recurrence` primitive record now marks + `program_layer_status=callable`, and the fused-program transition blocker is + narrowed to `program_transition_norm_or_identity_*` and + `program_transition_diag_rtu_*`. +- Implemented the fourth program-layer transition primitive: + `program_transition_norm_or_identity_forward` and + `program_transition_norm_or_identity_backward` now execute the public-state + normalization row directly from compiler executor rows and binding rows. + This completes the sLSTM transition primitive stack at the program layer: + `linear`, `matmul`, `gated_logspace_recurrence`, and `norm_or_identity` all + report `program_layer_status=callable`. The remaining transition primitive + blocker is now the Axon `diag_rtu` program-layer forward/backward symbols. +- Implemented the fifth program-layer transition primitive: + `program_transition_diag_rtu_forward` and + `program_transition_diag_rtu_backward` now execute the core Axon diagonal + RTU recurrence over compiler executor rows and binding rows. The forward + path materializes `preproj`, `next_hc1`, `next_hc2`, and optional eligibility + trace state; the backward path covers the current compiler contract used by + temporal BPTT, returning cell/state gradients plus reduced `nu_log`, + `theta_log`, `w1`, and `w2` gradients. All currently registered transition + primitives now report `program_layer_status=callable`; the fused CUDA plan's + remaining blocker is no longer transition primitive availability, but the + missing program-level span dispatch body over the decoded forward/reverse + span table. +- Implemented the first program-level fused span dispatcher: + `registered_temporal_fused_forward_transition_program_cuda` consumes a + compiler tensor table, explicit `program_tensor_binding_rows`, primitive + rows, forward executor rows, forward executor binding rows, and + memory-liveness rows. 
It decodes transition executor spans and dispatches + `linear`, `matmul`, `gated_logspace_recurrence`, `norm_or_identity`, and + `diag_rtu` by primitive opcode and binding rows, writing outputs back into + compiler-owned tensor table slots. This is not a fixed `kTf*` ABI: tensor + access is by compiler binding index and tensor-table mapping. Full fused + temporal forward remains fail-closed until message/readout/layout spans and + reverse spans consume the same program binding model. +- Extended the program-layer `linear` primitive to cover the real sLSTM gate + affine form: rank-4 `[R,Heads,D,4D]` weights plus `[R,4,Heads,D]` bias now + produce `[B,R,4,H]` gate logits directly. The fused transition span parity + test now executes the full sLSTM-style transition chain + `linear -> linear -> matmul -> gated_logspace_recurrence -> norm_or_identity` + from compiler binding rows and a tensor table, rather than testing only an + already-materialized recurrence core. +- Remaining after this slice: fused CUDA throughput executors still need to + execute complete decoded message/readout/layout/reverse span groups directly; + memory planning still needs real launch workspace allocation/alias buffers + beyond artifact policy; adding a new transition executor still requires an + executor implementation, but selector legality is now registry-local rather + than hardwired. +- Verification for this slice: + `uv run python - <<'PY' ... registered_temporal_fused_forward_program_validate_cuda(...) ... PY` + passed and returned schema summary plus `[executor_count, 12]` decoded forward + and reverse span tables. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 3 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_span_rank4 uv run pytest -q tests/test_fabric_backend_plan.py::test_program_transition_linear_forward_cuda_uses_compiler_rows tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings --tb=short` + passed: 2 tests, covering rank-4 gate-affine program linear + forward/backward and the full compiler-bound sLSTM transition span. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py src/cortical/fabric/backend/cuda/transition_execution/program.py src/cortical/fabric/backend/cuda/transition_execution/lowering.py tests/test_fabric_backend_plan.py` + passed. + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py src/cortical/fabric/backend/cuda/transition_execution/program.py src/cortical/fabric/backend/cuda/transition_execution/lowering.py tests/test_fabric_backend_plan.py` + passed. 
+ `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes tests/test_fabric_backend_plan.py::test_transition_executor_selection_uses_registered_structural_records tests/test_fabric_backend_plan.py::test_transition_program_compiler_uses_cuda_nn_primitive_registry tests/test_fabric_backend_plan.py::test_transition_program_compiler_rejects_unsupported_transition_op tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed: 6 tests. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names tests/test_fabric_backend_plan.py::test_transition_executor_selection_uses_registered_structural_records tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_rejected_transition_temporal_fusion_facades_were_deleted --tb=short` + passed: 4 tests after the program-layer transition blocker was moved into + the primitive registry. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_transition_executor_selection_uses_registered_structural_records tests/test_fabric_backend_plan.py::test_program_transition_linear_forward_cuda_uses_compiler_rows tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_rejected_transition_temporal_fusion_facades_were_deleted --tb=short` + passed: 4 tests, including CUDA parity for the new + `program_transition_linear_forward` and + `program_transition_linear_backward` primitives against torch `einsum` + forward/autograd backward. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_program_transition_recurrent_matmul_cuda_uses_compiler_rows --tb=short` + passed after validating the new recurrent matmul program-layer forward and + backward kernels against torch autograd. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_program_transition_gated_logspace_recurrence_cuda_uses_compiler_rows tests/test_fabric_backend_plan.py::test_transition_executor_selection_uses_registered_structural_records tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 3 tests, including CUDA parity for the core + `program_transition_gated_logspace_recurrence_forward` and + `program_transition_gated_logspace_recurrence_backward` primitives against + torch autograd. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_program_transition_norm_or_identity_cuda_uses_compiler_rows tests/test_fabric_backend_plan.py::test_transition_executor_selection_uses_registered_structural_records tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 3 tests, including CUDA parity for the + `program_transition_norm_or_identity_forward` and + `program_transition_norm_or_identity_backward` primitives against torch + autograd. 
+ `uv run pytest -q tests/test_fabric_backend_plan.py::test_program_transition_diag_rtu_cuda_uses_compiler_rows tests/test_fabric_backend_plan.py::test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names tests/test_fabric_backend_plan.py::test_transition_executor_selection_uses_registered_structural_records tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 4 tests, including CUDA parity for the + `program_transition_diag_rtu_forward` and + `program_transition_diag_rtu_backward` core recurrence primitives against a + torch reference/autograd. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 121 tests after adding the core gated-logspace, + norm-or-identity, and diag-RTU program-layer primitives. + +Code log - 2026-05-01 registered surface executor ownership slice: + +- Removed the remaining direct primitive/runtime calls from the active + `registered_executors.py` temporal loop. It now contains no direct `try_*` + CUDA calls, `_compute_*` message calls, `_project_*` public-projection calls, + or `_run_backend_*` primitive/backward calls. +- Added registered surface executor functions for recurrent K/V projection + forward/backward, input K/V replay projection, transition forward/backward, + readout projection backward, fallback grouped recurrent K/V backward, + recurrent-query parameter backward, boundary-public backward, initial + recurrent backward, and cell-layout assembly. These functions consume the + same compiler-selected executor handles that the temporal program validated, + so primitive math remains behind the row-owned surface executor boundary + instead of leaking back into the scheduler/orchestrator. +- The active temporal loop still schedules per-step execution and artifact + windows, but transition/message/readout/projection/layout implementation + calls now live in `temporal/surface_executor_runtime.py`. This is not fused + CUDA throughput closure; it is the active-path ownership cleanup needed + before a program-direct fused CUDA launch can consume the same registered + executor/binding program without inheriting temporal-side primitive math. +- Source guardrails now assert that `registered_executors.py` delegates to the + registered surface executors and does not call deleted fixed scan/reverse + ABIs, direct message kernels, direct recurrent public-projection kernels, + direct transition bucket helpers, direct output/query/boundary/initial + backward helpers, or direct layout CUDA helpers. +- Remaining after this slice: implement fused CUDA throughput executors that + consume `RegisteredTemporalExecutorProgram` and the compiler memory plan + directly; make the memory/liveness plan allocate and alias launch workspace; + keep expanding the transition executor registry so new declared ops are + local registry/lowering additions with typed fail-closed rejection. +- Verification for this slice: + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py tests/test_fabric_backend_boundaries.py` + passed. + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py` + passed. 
+ `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 112 tests. + +Code log - 2026-05-01 registered executor fail-closed/program-executor split: + +- Removed silent registered-surface fallbacks that could hide unsupported + executor shapes behind old runtime helper paths. Recurrent K/V projection now + raises when the registered CUDA projection executor cannot implement the + compiled tensor roles/shapes; readout now raises when the fused CUDA readout + executor is not legal for the compiled roles/shapes; cell layout now raises + instead of falling back to `torch.cat` assembly. +- Deleted the grouped/raw recurrent K/V backward and initial recurrent raw + backward fallback functions from the registered surface runtime. If recurrent + K/V projection backward has actual K/V gradients and the registered CUDA + backward executor cannot handle the compiled role/shape contract, execution + fails closed instead of invoking the old grouped/raw backward helper. +- Deleted the unreferenced old `run_temporal_output_backward_sequence(...)` + host implementation. It still carried direct output-message/query/recurrent + projection backward helper calls, and registered reverse windows now bind + those gradients through `run_registered_temporal_reverse_executor_window(...)`. +- Invalidated the stale deferred query/KV/transition parameter-reduction route + in `temporal/reverse_executor.py`. Registered reverse windows must now bind + parameter gradients inside each window; returning deferred reductions raises + instead of launching old query/KV/transition reducer helpers after the window + loop. +- Added `TemporalRegisteredProgramExecutorPlan` next to + `TemporalFusedCudaProgramPlan`. Runtime metadata now explicitly separates + the current registered executor program from the still-blocked fused CUDA + program launch and records the program demotion policy as + `fail_closed_no_legacy_abi_or_unregistered_executor_demotion`. +- The active scan implementation metadata was renamed from + `registered_temporal_executor_loop` to + `registered_temporal_fused_forward_program_cuda`. This is intentionally not + a closure claim: it makes the registered executor-program status visible + until the fused CUDA program consumes `RegisteredTemporalExecutorProgram` + directly. +- Source guardrails now assert the registered executor program does not retain the + deleted grouped/raw projection fallback hooks, old sender projection fallback + tag, output-message projection fallback tag, `torch_cat` layout fallback, + unreferenced output-backward host sequence, or deferred parameter binders. +- Remaining after this slice: the hard Pass 2/3/4 blocker is still implementing + program-direct fused CUDA forward/backward executors over + `RegisteredTemporalExecutorProgram`, executor/binding rows, and the compiler + memory/liveness plan. The registered executor program is explicit, + fail-closed, and not allowed to masquerade as fused-throughput closure. 
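+
+  A minimal sketch of the fail-closed demotion policy in this entry
+  (hypothetical names): an unsupported compiled shape raises a typed error
+  instead of routing back to an old helper such as `torch.cat` layout
+  assembly.
+
+  ```python
+  class UnsupportedExecutorShape(RuntimeError):
+      """Typed fail-closed error; no legacy-ABI or executor demotion."""
+
+
+  def resolve_cell_layout_executor(layout_kind: str, supported: frozenset) -> str:
+      if layout_kind not in supported:
+          raise UnsupportedExecutorShape(
+              "fail_closed_no_legacy_abi_or_unregistered_executor_demotion: "
+              + layout_kind
+          )
+      return f"registered_cuda_layout_executor::{layout_kind}"
+  ```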
+- Verification for this slice: + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_runtime.py` + passed. + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_table_runtime_metadata_records_executor_blockers tests/test_fabric_backend_plan.py::test_temporal_reverse_executor_rows_allow_one_transition_bucket --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_plan.py::test_temporal_table_runtime_metadata_records_executor_blockers tests/test_fabric_backend_plan.py::test_temporal_reverse_executor_rows_allow_one_transition_bucket --tb=short` + passed: 4 tests after the stale output-backward/deferred-reducer deletion. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 112 tests. + `git diff --check` passed. + +Code log - 2026-05-01 registered program executor and memory-window fail-closed cleanup: + +- Removed the active-path legacy interpreter contract naming. The compiler + now emits `TemporalRegisteredProgramExecutorPlan`, runtime metadata records + `_last_flat_bucket_temporal_registered_program_executor_*`, and scan + implementation records use `registered_temporal_fused_forward_program_cuda`. + This keeps the active path honest: it is a registered executor program over + compiler rows/bindings, not a hidden interpreter or legacy fallback. +- Replaced program/launch fallback wording with explicit demotion policy: + `fail_closed_no_unregistered_program_demotion` for the fused CUDA launch + contract and + `fail_closed_no_legacy_abi_or_unregistered_executor_demotion` for the active + registered executor program. +- Made temporal backward require the executable memory plan's + `artifact_store.backward_windows`. Missing or non-covering windows now raise; + the old local `temporal_artifact_windows(...)` stride/window helper was + deleted from `temporal/windows.py`, so reverse execution cannot silently + reconstruct scheduler policy from legacy stride fields. +- Runtime execution records now include the registered program executor plan + alongside the fused CUDA launch contract and fused-program blocker. 
Reviewers + can see both the active registered executor program and the remaining fused + CUDA blocker in the same launch metadata. +- Remaining after this slice: implement program-direct fused CUDA + forward/backward kernels over the same executor rows, binding rows, and + memory plan; make workspace alias/lifetime entries allocate actual buffers + and enforce CUDA graph constraints; continue extending transition lowering by + registry-local reference/optimized executors for new declared ops. +- Verification for this slice: + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/windows.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed. + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/windows.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_table_runtime_metadata_records_executor_blockers tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes tests/test_fabric_backend_plan.py::test_temporal_backward_requires_compiler_planned_memory_windows tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows --tb=short` + passed: 5 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 115 tests. + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed: 4 tests. + `git diff --check` passed. + +Code log - 2026-05-01 temporal primitive ABI registry cleanup: + +- Moved temporal primitive opcodes, scheduling-surface opcodes, binding-kind + opcodes, transition tape labels, and full-tape extra-state factors out of + `compiler/tables.py` into `compiler/primitive_registry.py`. Temporal table + construction now asks the primitive registry for opcode/tape facts; adding a + new primitive no longer requires editing table-local fixed maps. +- Added `executor_id` to `TemporalForwardExecutorPattern` and removed the + table-local forward-executor opcode map. Forward executor row IDs now come + from registered strategy records, matching the reverse executor side. 
+- Added guardrails proving `tables.py` no longer owns + `_TEMPORAL_PRIMITIVE_OPCODE_BY_NAME` or + `_TEMPORAL_FORWARD_EXECUTOR_OPCODE_BY_NAME`, and that primitive row tensor + opcodes match `temporal_primitive_opcode(...)`. +- Remaining after this slice: the primitive registry still needs to grow into + the permanent add-op workflow with reference executors, optional fused CUDA + strategies, and compiler-boundary tests for every new primitive; the current + slice removes a drift point but does not itself implement additional math. +- Verification for this slice: + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/primitive_registry.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/tables.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py` + passed. + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/primitive_registry.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/tables.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_fabric_cuda_nn_primitives_are_surface_independent_temporal_rows tests/test_fabric_backend_boundaries.py::test_temporal_executor_bindings_are_compiler_products_not_compatibility_slots --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 115 tests. + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed: 4 tests. + `git diff --check` passed. + +Pass 2 - universal strategy registry and forward executor table: + +- Replace the monolithic forward scan argument projection with executor rows + and tensor binding rows for all supported universal primitives: + `linear`, `matmul`, `add`, `attention_logits`, `segment_softmax`, + `weighted_sum`, `readout_project`, `reduction_boundary`, + `norm_or_identity`, `gated_logspace_recurrence`, and `diag_rtu`. +- Build a structural strategy registry over canonical row groups, not + string/name bundles. Matching uses ops, tensor roles, dependency edges, state + read/write effects, emissions, reset behavior, layout constraints, and + gradient requirements. +- Split strategy selection into legality filtering and cost/ranking. A + rejected strategy records a typed reason; the cost model only ranks legal + strategies. +- Every forward/fused strategy declares owner, stable ID/version, row schema + version, tensor-binding schema version, metadata schema version, + CUDA-kernel ABI version, workspace/layout requirements, and failure modes. +- Message, readout, transition, and parameter reduction are scheduling + surfaces over the same primitive vocabulary. Tests must prove the same + primitive opcode/executor identity can appear on different surfaces. +- Keep composite recurrence primitives explicit and narrow: + `gated_logspace_recurrence` for sLSTM recurrence math, `diag_rtu` for the + custom Axon recurrence/trace math. Projections, normalization, activation, + readout, and parameter reductions stay separate primitive rows. +- Keep a reference primitive executor for every supported row group. Optimized + and fused forward strategies prove against that reference and never redefine + semantics. 
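+
+  A minimal sketch of the legality-then-cost selection split described in this
+  pass (hypothetical types): rejected strategies keep a typed reason, and only
+  legal strategies are ranked by the cost model.
+
+  ```python
+  from dataclasses import dataclass
+  from typing import Callable, List, Optional, Tuple
+
+
+  @dataclass(frozen=True)
+  class RowGroup:
+      ops: Tuple[str, ...]
+      requires_grad: bool
+
+
+  @dataclass(frozen=True)
+  class StrategySketch:
+      strategy_id: str
+      owner: str
+      legality: Callable[[RowGroup], Optional[str]]  # None if legal, else reason
+      cost: Callable[[RowGroup], float]
+
+
+  def select_strategy(
+      rows: RowGroup, candidates: List[StrategySketch]
+  ) -> Tuple[Optional[StrategySketch], List[Tuple[str, str]]]:
+      rejections: List[Tuple[str, str]] = []
+      legal: List[StrategySketch] = []
+      for strategy in candidates:
+          reason = strategy.legality(rows)
+          if reason is None:
+              legal.append(strategy)
+          else:
+              rejections.append((strategy.strategy_id, reason))
+      if not legal:
+          return None, rejections  # caller fails closed with the typed reasons
+      return min(legal, key=lambda s: s.cost(rows)), rejections
+  ```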
+- Current remaining work from the 2026-04-30 tree assessment: + - The registered forward executor path over `TemporalExecutorBindingPlan` + rows is now active for stored-artifact scans. It executes message, readout, + transition, and recurrent K/V projection through compiler-selected + primitive executor bindings and existing primitive CUDA/runtime kernels. + As of 2026-05-01, active forward/recompute must first build a + `RegisteredTemporalExecutorProgram`; stale binding rows or missing + transition bucket coverage reject before execution. Message/readout + tensors used by the registered executor are now resolved from compiled + `static_tensor:` / `runtime_attr:` source bindings, not direct temporal + loop role-name lookups. + - The old fixed C++ executor-ID validation was deleted with + `flat_bucket_temporal_scan_kernels.cu`; new dispatch is driven by + registered strategy records and tensor binding rows. + - The one-gated-plus-one-diagonal compatibility projection has been removed + from active selection. Many transition buckets are represented as flat + bucket cardinality in the primitive table and strategy report. + - As of 2026-05-01, the compiler also emits a + `TemporalFusedCudaLaunchContract` from forward executor rows, reverse + executor rows, executor binding rows, and memory liveness entries. This is + the only allowed ABI for the future fused CUDA program launch. It names + required compiler tables and explicitly uses + `fail_closed_no_unregistered_program_demotion`; it does not introduce fixed tensor + slot enums or gated/diagonal temporal-side slot identities. + - Remaining: implement the fused CUDA forward program kernel that consumes + this launch contract directly. Until then the registered program + executor is an explicit fail-closed program path, not fused CUDA closure. + - As of 2026-05-01, transition execution has a lowering-dispatch registry: + the shared forward/backward entrypoints call a registered + `_TransitionLoweringRecord` rather than branching on + `executor_plan.executor`. New transition executors must add a structural + executor record plus forward/backward lowering implementations, and + unsupported programs still fail closed with + `NO_REGISTERED_TRANSITION_EXECUTOR`. + - Remaining: move the transition executor registry beyond the current + supported composite records by adding reference executors and lowering + metadata for new declared transition ops. The current supported rows are + still the explicit `gated_logspace_recurrence` and `diag_rtu` composites. + +Pass 3 - reverse compiler executor, autodiff contract, and parameter binding: + +- Build the reverse executor table from the same primitive rows, tensor + bindings, and forward artifact policy. Reverse must not infer semantics from + gated/diagonal tensor bundles, Q/K/V names, or bucket family. +- Implement primitive adjoints for message/boundary, projection, normalization, + composite recurrences, and parameter reductions behind row-selected + executors. Deferred parameter gradients must come only from compiled + parameter binding rows. +- Define compiler-owned gradient specs for every supported row group. Backward + strategies implement those specs and declare gradient inputs, saved tensors, + recomputable tensors, tape schema, accumulation behavior, parameter-gradient + outputs, state-gradient outputs, numerical tolerance expectations, and + deterministic/nondeterministic behavior. 
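+
+  A minimal dataclass sketch of a compiler-owned gradient spec. Field names
+  are illustrative; the real spec is declared per supported row group.
+
+  ```python
+  from dataclasses import dataclass
+  from typing import Tuple
+
+
+  @dataclass(frozen=True)
+  class GradientSpecSketch:
+      row_group_id: str
+      gradient_inputs: Tuple[str, ...]       # upstream output/state gradients
+      saved_tensors: Tuple[str, ...]         # must be materialized by forward
+      recomputable_tensors: Tuple[str, ...]  # may be rebuilt from checkpoints
+      tape_schema_version: int
+      accumulation: str                      # e.g. "sum_into_param_grad"
+      parameter_grad_outputs: Tuple[str, ...]
+      state_grad_outputs: Tuple[str, ...]
+      rtol: float
+      atol: float
+      deterministic: bool
+  ```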
+- Replace active calls to op-specific temporal backward CUDA wrappers such as + `try_gated_*`, `try_diagonal_*`, and `try_transition_message_*` with + compiler-selected backward strategy bindings. Those wrappers may remain only + as temporary implementation references until Pass 5 deletion. +- Verify forward/backward equivalence through reference execution and parity + gates. A fused backward kernel may be selected only if it preserves the same + observable emissions, reset/replay/checkpoint behavior, materialization + points, and accumulation semantics as the compiler gradient spec. +- Split `temporal_backward.py` semantically while doing this: + reverse artifacts/tables, output-gradient materialization, + registered reverse executor implementations, parameter binding, and autograd + entrypoints. Boundary projection backward must stay in the registered reverse + executor strategy, not in a separate runtime-phase module. Delete the + host-loop scan/recompute path as each reverse executor path lands. +- Current remaining work from the 2026-04-30 tree assessment: + - The registered reverse executor path is active for stored artifact windows. + `temporal/reverse_executor.py` calls + `run_registered_temporal_reverse_executor_window(...)` and no longer calls + `try_transition_message_reverse_table_window_cuda(...)`. + As of 2026-05-01, reverse windows must first build the same registered + executor program and verify reverse transition-bucket coverage before any + backward step runs. + - Reverse executor rows have moved from compatibility validation to actual + dispatch for stored artifacts. Tensor binding rows drive parameter routing; + no fixed role-string reverse table remains in the active path. + - As of the 2026-05-01 reverse-window cut, no-carry/no-reset `output_cells` + training windows first run through + `registered_reverse_program_window`, a compiler-owned window driver over + the selected reverse executor rows and binding rows. This removes the + per-step registered bucket backward call from that supported route without + claiming the full C++ fused reverse sequence kernel is done. + - Registered checkpoint recompute now rebuilds normal artifact windows. + The fused CUDA launch contract now covers reverse executor rows and reverse + executor binding rows too. Remaining: implement the fused CUDA backward + program kernel that consumes that contract directly and preserves the + compiler autodiff/tape/materialization contract. + - Keep deleted direct gated/diagonal/gated-message/recurrent-message wrappers + deleted; do not reintroduce them as facades or "fusion gateways." + +Pass 4 - effect-aware temporal scheduler and memory/layout planner: + +- Refactor the planner data structures into a PyTorch-style execution plan: + emitted output steps, autograd seed surfaces, horizon windows, checkpoint + steps, recompute windows, artifact kind, and materialization requests are + explicit plan records, not ad hoc flags. +- Add an effect system the scheduler can consume without learning primitive + math: state reads/writes, output emissions, tape consumption, + materialization, replay requirements, temporal dependencies, reversibility, + checkpoint support, workspace mutation, aliasing, and layout transforms. +- Elevate memory/layout planning to first-class compiler output: activation + lifetime, tape lifetime, workspace lifetime, aliases, contiguous/strided + requirements, alignment, persistent buffers, temporary buffers, and + recompute-vs-store tradeoffs. +- Training is always streamed. 
`T=1` is the base physical step; `T*K` is the + same substrate over the temporal axis. Horizon `H` clips reverse windows only + where loss/output gradients are emitted. Checkpoint steps are planner-chosen + unless explicitly provided. +- The temporal kernel/superop executes executor rows over the planned stream. + It does not own primitive math, cell identities, message rule identities, or + special T=1/K-only cases. +- Current remaining work from the 2026-04-30 tree assessment: + - Lift hidden/head/value/degree caps and bucket-cardinality constraints into + planner legality and strategy selection. The fixed scan cardinality limit + has been removed from active selection; remaining numeric caps belong in + strategy legality and cost metadata. + - Ensure many-population fabrics are handled as flat bucket cardinality in the + same plan, not as a separate unsupported route. + - Make the memory/liveness plan executable: workspace allocation, aliases, + checkpoint/recompute windows, tape materialization, and CUDA graph capture + constraints must drive launches instead of being audit-only summaries. + As of 2026-05-01, the memory liveness plan is part of the fused CUDA launch + contract, which blocks any future fused launch from bypassing compiler + workspace/layout/alias policy. Remaining: turn those entries into concrete + allocator/workspace/tape actions in the fused executors. + - Throughput tuning cannot start until scheduler and memory policy feed the + real executor launch path; current guardrails remain roughly + `1055.61 tok/s` and `29.55 GiB`, far from the April21 + `58732.71 tok/s`, `2.07 GiB` reference. + +Pass 5 - governance, explainability, legacy deletion, and audits: + +- Remove old kernels, wrappers, pybind exports, stale planner flags, and tests + for any route replaced by the compiler executor path. The codebase should + shrink after the compiler path is real. +- The old `ops/temporal_backward/temporal_backward_cuda.py` monolith and the + split helper wrappers are deleted. Keep extension loading in registered + strategy entrypoints; primitive-math `try_*` probes and pybind helper launch + surfaces must not return. +- Static guardrails: no cell-family selectors in temporal engine sources, no + hidden surface-specific primitive meanings, no backend/private planner calls + from benchmarks, no old audit fixture loads in unit tests except the official + April21 score-reference loader in benchmark/audit code. +- Add planner explain output before throughput tuning is considered closed: + why each row group lowered the way it did, candidate strategies, typed + rejection reasons, selected strategy, produced schedule, materialization + policy, kernels run, workspace/layout plan, and any fail-closed event. +- Enforce strategy registry governance: stable ID, owner, matching rules, + required tests, benchmark coverage, deprecation/removal criteria, emitted + audit metadata, verifier dependency, and no selection without passing + verifier/legal-strategy checks. +- Parity/audit order after each pass, not only at the end: T=1 high-level + training rows first, including 100M/500M/1B and reset/provided-state axes; + then representative `T=512,K=1,H=64` and `T=4096,K=1,H=64`; then full + T/K/H sweeps (`T=1,2,4,...,4096,16K`, `K=1,2,4,...,128`, + `H=1,2,...,64`) with terminal and per-timestep losses. Use April21 JSON for + score references and current matched T=1 training divided by K for K-scaling + gates. 
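+
+  One plausible reading of the K-scaling gate arithmetic, as a hypothetical
+  helper. The tolerance and argument names are assumptions, not the audit
+  contract.
+
+  ```python
+  def k_scaling_reference_tok_per_s(matched_t1_tok_per_s: float, k: int) -> float:
+      """Reference throughput for a T*K row: matched T=1 training throughput
+      divided by K."""
+      if k < 1:
+          raise ValueError("K must be >= 1")
+      return matched_t1_tok_per_s / k
+
+
+  def k_scaling_gate(measured_tok_per_s: float, matched_t1_tok_per_s: float,
+                     k: int, tolerance: float = 0.95) -> bool:
+      # Illustrative pass criterion: within `tolerance` of the scaled reference.
+      reference = k_scaling_reference_tok_per_s(matched_t1_tok_per_s, k)
+      return measured_tok_per_s >= tolerance * reference
+  ```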
+- Current remaining work from the 2026-04-30 tree assessment: + - Before committing, stage/track the new replacement source directories: + `src/cortical/fabric/backend/cuda/ops/temporal_backward/`, + `src/cortical/fabric/backend/cuda/sequence_surface/compiler/`, + `src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/`, + `src/cortical/fabric/backend/cuda/sequence_surface/runtime/`, + `src/cortical/fabric/backend/cuda/sequence_surface/temporal/`, and + `src/cortical/fabric/backend/cuda/transition_execution/`. + - Keep source guardrails that prevent deleted files from returning: + `sequence_surface/temporal_backward.py`, + `ops/temporal_backward/temporal_backward_cuda.py`, direct gated/diagonal + core wrappers, gated-message wrappers, recurrent-message direct wrappers, + and fake fusion gateways. + - Do not close Pass 5 while any compatibility wrapper source or pybind/C++ + export can launch active temporal work. Current guardrails assert those + wrappers are deleted, and planner/runtime metadata now reports registered + executor owners rather than the deleted CUDA superop label; the remaining + closeout is fused throughput kernels, executable memory/liveness planning, + and benchmark parity against the April21 reference. + - 2026-05-01 code log: added the compiler-owned fused CUDA launch contract + to `compiler/program_execution.py`, threaded it through metadata recording + and active registered executor programs, and added source/runtime guardrails + proving the contract is row/binding/memory-owned rather than a fixed-slot + compatibility launch. Verification: + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 113 tests. `git diff --check` passed. + - 2026-05-01 code log: transition shared lowering no longer branches on + `executor_plan.executor`. The selected transition executor now dispatches + through a `_TransitionLoweringRecord` with forward/backward implementations. + The registered temporal surface executor also no longer hard-checks the + current gated/diag transition executor-name pair; it validates transition + surface coverage and compiler parameter bindings, leaving unsupported + transition programs to fail closed at the transition executor registry. + Registered reverse windows now record + `flat_bucket_temporal_reverse_scan_binding_abi:registered_executor_binding_rows` + into the backend execution record before validation. Verification: + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed: 4 tests. + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t_gt1_training_uses_flat_bucket_route --tb=short` + passed: 1 test. + - 2026-05-01 code log: removed dead fixed-scan tensor packing helpers from + `temporal/forward_scan.py` (`_gated_temporal_scan_parameter_tensors`, + `_diagonal_temporal_scan_parameter_tensors`, and their empty-state + builders). Those helpers resolved concrete gated/diagonal parameter names + for the deleted scan ABI and are no longer part of the compiler-owned + forward path. Guardrails now assert they do not return to the temporal + forward entrypoint. + - 2026-05-01 code log: active registered temporal execution no longer + imports or calls the concrete surface executor implementation functions + directly from `registered_executors.py`. 
`RegisteredTemporalExecutorProgram` + now carries a `RegisteredTemporalExecutorKernelRegistry`, and forward, + recompute, and reverse step execution dispatch through that registry using + the selected executor rows and binding rows. The concrete implementation + functions remain in `surface_executor_runtime.py` as strategy + implementations, but the active temporal caller is no longer wired to the + current message/readout/transition function names. Source guardrails assert + those direct calls do not return to `registered_executors.py`. + Verification: + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 113 tests. + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed: 4 tests. + - 2026-05-01 code log: executor implementation-name legality moved out of + the temporal caller and surface runtime. `registered_executors.py` no + longer exposes a general `require_executor(...)` helper and no longer names + the current message, readout, gated, or diagonal executor strings. + `surface_executor_runtime.py` also no longer hard-checks those executor + names. `RegisteredTemporalExecutorKernelRegistry` validates handles against + compiler-selected executor rows and binding rows, and its legal + executor-name sets are derived from `compiler/executor_patterns.py` rather + than duplicated in runtime code. This makes adding a new strategy a + compiler-pattern/lowering/executor action instead of a temporal-caller + edit, while unsupported rows still fail closed before launch. Remaining + compiler-closure blockers are unchanged and still highest priority: + implement program-direct fused CUDA forward/backward executors over the + fused launch contract, make memory/liveness planning executable, and expand + generic transition lowering beyond the current supported composite records. + Verification: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 113 tests. `git diff --check` passed. + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed: 4 tests. + - 2026-05-01 code log: primitive dispatch no longer keeps a separate + hardcoded list of supported composite transition primitives. The declared + composite transition rows considered implemented by + `build_temporal_primitive_executor_plan(...)` are now derived from the + registered forward/reverse transition executor patterns. Adding a new + supported composite transition therefore requires registering its + structural executor pattern and lowering/executor implementation, not + editing an extra primitive-dispatch allowlist. Verification: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_rejected_transition_temporal_fusion_facades_were_deleted tests/test_fabric_backend_plan.py::test_temporal_primitive_executor_plan_fails_closed_for_missing_generic_dispatch --tb=short` + passed: 2 tests. 
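+
+    A minimal sketch of deriving the supported composite set from registered
+    patterns instead of a separate allowlist (hypothetical helper name):
+
+    ```python
+    from typing import FrozenSet, Iterable
+
+
+    def supported_composite_transitions(
+        forward_patterns: Iterable[str], reverse_patterns: Iterable[str]
+    ) -> FrozenSet[str]:
+        """A composite transition primitive counts as implemented only when
+        both a forward and a reverse executor pattern are registered for it."""
+        return frozenset(forward_patterns) & frozenset(reverse_patterns)
+    ```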
+ - 2026-05-01 code log: transition structural executor records moved out of + `transition_execution/program.py` into + `transition_execution/registry.py`. `select_transition_program_executor` + now reads registered structural records from that compiler-facing registry. + A later cut deleted the direct `transition_execution/lowering.py` runtime + dispatch module, so executor selection and implementation metadata now live + in `program.py`/`registry.py` and registered temporal program dispatch. + Adding a new transition program now requires a structural record plus + registered program implementation coverage, with + `NO_REGISTERED_TRANSITION_EXECUTOR` for unsupported programs. Verification: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_rejected_transition_temporal_fusion_facades_were_deleted tests/test_fabric_backend_plan.py::test_transition_executor_selection_uses_registered_structural_records --tb=short` + passed: 2 tests. + - 2026-05-01 code log: memory/liveness planning now emits concrete + checkpoint steps and backward artifact windows in + `TemporalMemoryRuntimeArtifactPlan`. The registered forward executor uses + those planned checkpoint steps directly instead of recomputing checkpoint + placement with local modulo arithmetic, and the reverse executor consumes + `artifact_store.backward_windows` as the only planned backward window + schedule. The old `temporal_artifact_windows(...)` stride/window helper was + deleted, and missing or non-covering windows now fail closed. This moves + checkpoint/recompute window execution into the compiler memory plan; + remaining Pass 4 work is still to turn workspace aliases/lifetimes into + actual allocator and CUDA graph capture constraints. Verification: + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes --tb=short` + passed: 1 test. + - 2026-05-01 code log: transition lowering implementation names are now part + of the structural executor registry. `TransitionProgramExecutorRecord` + now carries `forward_strategy_id` and `backward_strategy_id`, selected + `TransitionProgramExecutorPlan`s preserve those symbols, and + `transition_execution/program.py` records runtime execution status without + naming a deleted lowering route. This removes another executor-name drift + path: adding a transition program now requires one structural record plus + registered forward/backward program implementation coverage, with + unsupported programs failing closed at selection. + - 2026-05-01 code log: removed the registered temporal path's last + gated-bucket/diagonal-bucket scan-selection helper. `temporal/common.py` + now exposes `_flat_bucket_temporal_table_plan(...)`, and forward, + recompute, and reverse callers consume the compiler primitive table + directly instead of asking for a fixed `(gated_bucket, diagonal_bucket)` + pair. The dead `_transition_tape_by_population_from_cuda_scan(...)` helper + was deleted as well; it still reconstructed backward tapes from + CUDA-scan-specific gated/diagonal tensors and had no active registered + caller. This keeps bucket discovery on compiler rows rather than a + two-transition-family compatibility shape. + - 2026-05-01 code log: transition bucket-kind discovery in + `compiler/tables.py` no longer carries an explicit + `gated_logspace_recurrence` / `diag_rtu` / `diagonal_recurrence` allowlist. + It asks the primitive registry via `temporal_transition_tape_kind(...)` + whether a primitive row is a transition-tape-producing row. 
Future + transition primitives therefore extend the primitive registry and executor + records instead of editing this table helper. + - 2026-05-01 code log: the active registered transition executor no longer + calls the legacy backend-order transition bucket runners. Forward execution + in `surface_executor_runtime.py` now walks the compiler-selected transition + executor handles, validates each handle against its compiler bucket range, + and, as of the later forward-span slice below, launches the row/binding + owned fused transition program dispatcher per selected transition span. + Backward still uses + `lower_backend_population_transition_backward_shared(...)`, then binds + row-owned transition parameter gradients through + `bind_temporal_transition_param_grads(...)`. Guardrails now reject + `run_backend_order_transition_buckets_step_cached_eager_result(...)`, + `run_backend_order_transition_buckets_backward_step_cached(...)`, and + `runtime._run_backend_order_transition_buckets_*` calls from the registered + surface runtime. This removes the old composite bucket runner from the + active registered transition path without changing the selected primitive + rows or bindings. + - Same slice deletion: after migrating the registered reverse path, removed + the unreferenced cached backend-order transition backward helper exports + from `flat_bucket/flat_buckets.py` + (`run_backend_order_transition_buckets_backward_step_cached` and + `run_backend_order_transition_buckets_backward_step_cached_unbound`). The + remaining registered reverse path owns backend-state tape execution through + selected transition handles plus compiler lowering and parameter binding. + - 2026-05-01 code log: the registered temporal surface executor no longer + delegates message, readout projection, recurrent query parameter, boundary + public projection, or input K/V replay work to runtime-private phase + methods. `surface_executor_runtime.py` now launches the selected primitive + implementations directly from compiler-bound tensor roles: + partitioned message forward/backward CUDA kernels, output projection + gradient reduction, query parameter reduction, and sender K/V projection + backward. Guardrails now reject + `runtime._compute_messages_step_subset_partitioned_raw(...)`, + `runtime._project_sender_kv_from_cells_sequence(...)`, + `runtime._run_backend_message_backward_phase(...)`, + `runtime._run_backend_output_projection_backward_phase(...)`, + `runtime._run_backend_query_param_backward_phase(...)`, and + `runtime._run_backend_boundary_public_backward_phase(...)` from the + registered surface runtime. This is still not the final fused CUDA + throughput executor, but it deletes another active legacy ownership + boundary: registered executor rows now own these primitive launches rather + than borrowing monolithic runtime phases. + - Same slice deletion: `temporal/boundary_backward.py` was removed after the + registered reverse executor stopped using its window helpers. Those helpers + were unreferenced and still delegated to the runtime-private boundary + projection backward phase. Primitive-dispatch missing-coverage records were + also renamed from `*_runtime_path` to explicit registered-executor + requirements, so fail-closed diagnostics no longer imply a hidden runtime + fallback path. + - Same slice cleanup: `temporal/param_binding.py` now only binds transition + parameter gradients. 
Removed the unused recurrent-query and initial + recurrent parameter-binding sequences because they still called + runtime-private query/KV binding helpers after the registered reverse + executor took ownership of those reductions. + - 2026-05-01 continuation: registered reverse execution no longer binds + recurrent K/V projection raw gradients through + `runtime._sender_kv_projection_param_grad_tuple_from_raw_grads(...)`. + `RegisteredTemporalExecutorKernelRegistry` now owns + `recurrent_kv_projection_param_binding(...)`, which validates the selected + message executor binding and reduces raw recurrent K/V projection gradients + inside `surface_executor_runtime.py`. This removes another runtime-private + binder from the active registered reverse path. + - Same continuation: transition parameter binding no longer calls + `runtime._state_public_explicit_param_grad_tuple(...)`. The temporal + compiler param-binding module now reduces materialized/static transition + parameter gradients into trainable parameter gradients directly from the + compiler accumulator. This keeps registered reverse transition gradients + inside the temporal executor/parameter-binding layer instead of borrowing + the old state/public backward binder. + - Same continuation: `temporal/forward_scan.py` no longer precomputes input + K/V through `runtime._project_sender_kv_from_cells_sequence(...)` before + entering registered execution. The registered forward executor program now + obtains input K/V through its selected message executor and compiler tensor + bindings at scan entry, matching recompute and preventing the forward caller + from owning public-projection primitive work. + - Deleted two unreferenced common helpers that still encoded the old + recurrent K/V projection route: + `_initial_recurrent_kv_backend_order(...)` and + `_temporal_backend_order_recurrent_kv_projection_backward(...)`. They were + dormant after the registered executor migration but still contained runtime + projection fallback and private group-id binding logic, so keeping them + would preserve a misleading legacy sibling. + - 2026-05-01 continuation: the old transition execution lowering path no + longer materializes recurrent K/V by calling + `runtime._project_sender_kv_from_cells_step(...)`; that module has since + been deleted. A later hard-closure continuation deleted the remaining + one-row transition K/V projection plan; transition public K/V + materialization must now use the compiled forward executor rows/binding + rows and the selected message executor identity. + - Primitive-dispatch summaries no longer use `*_runtime_path` names for + implemented message, readout, parameter-reduction, affine-transition, or + transition-normalization rows. The executor labels now point at registered + primitive/transition executors, so compiler diagnostics do not imply that + old runtime fallback paths remain valid execution owners. + - 2026-05-01 continuation: deleted + `sequence_surface/compiler/compatibility.py` and stopped recording + `_last_flat_bucket_temporal_compatibility_debt`. The deleted fixed scan and + reverse-table entrypoints are now protected by source guardrails and git + history, not by a live compiler compatibility ledger that could be confused + with an accepted execution surface. + - Same continuation: the registered temporal surface runtime no longer calls + active `try_*` CUDA primitive probes for recurrent K/V projection, readout, + or graph layout. 
`flat_bucket_public_projection_cuda.py`, + `flat_bucket_readout_cuda.py`, and `flat_bucket_layout_cuda.py` now expose + strict registered primitive entrypoints that either execute the compiler- + selected CUDA primitive or raise; `surface_executor_runtime.py` calls those + strict functions. The old `try_*` probes remain only for unrelated direct + parity tests/callers until their surfaces are migrated, but the registered + compiler temporal path is off probe/fallback semantics. + - 2026-05-01 hard-closure slice: added a real registered CUDA program + epilogue under `flat_bucket_registered_program_*` and wired it into the + active no-artifact forward path. The kernel consumes + `forward_executor_rows` and `forward_executor_binding_rows`, validates the + compiler-selected readout executor/bindings, and fuses readout projection + with backend-to-graph cell assembly. This is not a facade over the deleted + fixed scan ABI: it has no `kTf*` slot enum and no `InitialGated` / + `InitialDiagonal` temporal-side slots. The artifact/recompute path still + materializes output-message tape through registered primitive executors + because backward needs that tape. + - Same hard-closure slice, reverse side: readout backward is now a first- + class reverse executor surface instead of borrowing the forward readout + row. `executor_patterns.py`, strategy selection, executor rows, binding + rows, and the active backward step now include + `reverse.readout.projection_reduction_boundary.v1` / + `projection_reduction_boundary_backward`. The registered reverse step uses + `reverse_handle(surface="readout", bucket_ordinal=-2)`, and + `flat_bucket_registered_program_*` now includes a registered CUDA reverse + readout/layout projection kernel. That kernel consumes + `reverse_executor_rows` and `reverse_executor_binding_rows`, validates the + compiler-selected readout backward executor/bindings, splits + `grad_cells_out` into boundary/recurrent/output surfaces, and computes the + readout projection adjoints. This deletes the implicit forward-handle + backward authority for readout and makes readout reverse coverage visible + in the compiler plan. + - The registered program Python wrapper now rejects non-CPU/non-int64 + executor row tables instead of silently converting them before launch. The + compiler rows must already satisfy the registered executor ABI when the + CUDA program is called. + - Same hard-closure slice, memory side: `TemporalArtifactStore` now carries + the compiler memory/liveness fingerprint and the runtime artifact-plan + fingerprint. Registered backward recomputes the current primitive table, + memory liveness plan, checkpoint/recompute artifact plan, checkpoint + steps, and backward windows, then fails closed before reverse execution if + the stored artifact plan differs. This removes the old + `checkpoint_stride` / `recompute_window_len` window derivation as an + accepted runtime authority; reverse windows must now be compiler memory + products. + - Same memory continuation: added a compiler runtime-buffer plan and wired + the active registered path through it for the two live buffers that were + still allocated ad hoc. Forward output sequence storage is allocated from + `memory_runtime_buffer_plan` instead of accumulating a Python + `output_steps` list and stacking it, and backward boundary-gradient + storage is allocated from the same plan instead of `torch.zeros_like(...)`. 
+ This is still not the final whole-program workspace allocator, but + output and boundary-gradient runtime buffers now have explicit shape, + dtype, device, workspace class, alias set, init policy, and compiler owner. + - 2026-05-01 follow-up hard-closure slice: the readout primitive ABI no + longer exposes one broad `readout_weight` logical binding that secretly + meant `output_q`, `value_to_output_weight`, and `output_cell_bias`. + `compile_readout_rule(...)` now lowers those as distinct parameter roles, + tensor binding rows record exact static/runtime sources for each role, and + the registered readout forward/reverse executors resolve those exact + compiler roles without preferred-key bundle selection. This makes adding + or changing readout parameters a local declaration/lowering/binding + change instead of editing a monolithic temporal-side weight bundle. + - Same hard-closure direction for message: removed the fake message + `out_weight` primitive row/binding. The active temporal path never applied + that projection inside message execution; `msg_out` is already represented + by transition/readout physical bindings (`value_to_state_weight` and + `value_to_output_weight`). Keeping `out_weight` in the message table was a + compiler facade because it was bound to K/V projection tensors. The + compiled message row group now ends at `weighted_sum` and declares exact + message physical roles: + `recurrent_q_weight`, `input_sender_kv_weight`, `input_group_kv_weight`, + and `recurrent_sender_kv_weight`. + - Same message ABI cleanup: registered executor handles no longer accept + `preferred_keys` / `preferred_names`. Static/runtime tensors are resolved + only from compiler binding rows, and repeated logical roles across the + key/value primitive rows are merged as one row-group binding. The forward + scan no longer checks + `static_tensors["recurrent_sender_input_to_kv_weight_backend_order"]` + directly; it verifies availability through the compiler tensor-binding + table for `recurrent_sender_kv_weight`. + - 2026-05-01 hard-closure continuation: local and sparse partitioned + attention are no longer launched through the standalone + `fabric_local_message_*` / `fabric_sparse_message_*` wrappers from the + active registered temporal path. `flat_bucket_registered_program_*` now + owns registered partitioned-attention and sparse-attention forward/backward + CUDA entrypoints. They consume selected executor rows plus executor binding + rows, validate the exact forward/reverse executor ID, bucket ordinal, + receiver count, and compiler-owned parameter binding rows, then run the + attention weighted-sum and adjoints. The same registered kernels are used + by recurrent message execution and readout output-message tape execution, + with the expected executor identity coming from the selected compiler + handle instead of a fixed temporal-side slot table. + - Same continuation: artifact/tape readout projection no longer calls the + standalone `fused_local_readout_cuda(...)` wrapper from the registered + surface runtime. The registered readout executor now first materializes the + output-message tape through registered message attention, then launches + `registered_forward_readout_projection_cuda(...)` with the selected readout + executor row and binding rows to apply `value_to_output_weight` and + `output_cell_bias`. The no-artifact path still uses the registered fused + readout/layout epilogue. Both readout forward modes are now compiler + binding-owned instead of borrowing a fixed readout wrapper. 
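+  - Illustrative aside, not repository code: a minimal sketch of exact-role
+    binding resolution without preferred-key fallback, in the spirit of the
+    readout/message role split above. The row fields and role strings are
+    assumptions standing in for the compiler tensor-binding rows; the point
+    is that a missing or ambiguous role fails closed instead of selecting a
+    preferred bundle.
+
+    ```python
+    from dataclasses import dataclass
+
+
+    @dataclass(frozen=True)
+    class TensorBindingRow:
+        executor_id: str
+        logical_role: str   # e.g. "value_to_output_weight", "output_cell_bias"
+        static_source: str  # exact compiler-owned parameter name
+
+
+    def resolve_role(rows: list[TensorBindingRow], executor_id: str,
+                     role: str) -> str:
+        matches = [r.static_source for r in rows
+                   if r.executor_id == executor_id and r.logical_role == role]
+        if not matches:
+            raise LookupError(f"MISSING_BINDING: {executor_id}/{role}")
+        if len(set(matches)) > 1:
+            raise LookupError(f"AMBIGUOUS_BINDING: {executor_id}/{role}")
+        # No preferred-key fallback: the compiler row is the only authority.
+        return matches[0]
+
+
+    rows = [
+        TensorBindingRow("fwd.readout.v1", "value_to_output_weight",
+                         "w_value_to_output"),
+        TensorBindingRow("fwd.readout.v1", "output_cell_bias", "b_output_cell"),
+    ]
+    assert resolve_role(rows, "fwd.readout.v1", "output_cell_bias") == "b_output_cell"
+    ```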
+ - Remaining after this hard-closure slice: the full fused CUDA temporal + forward program and the fused CUDA reverse program are still open. The + new registered partitioned-attention, forward epilogue, and reverse + readout/layout kernels remove active program-level gaps and prove the + binding shape for compiler-owned fused launches, and the memory + fingerprints make checkpoint/recompute windows a compiler-owned runtime + contract. They do not close the sequence-wide scan/reverse throughput + executor, generic transition lowering beyond the current structural + composites, or a real allocator/aliasing owner for all workspaces. + - Verification for this continuation: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 116 tests after the explicit readout-role and runtime-buffer + plan changes. + CUDA direct epilogue parity against the prior registered primitive readout + plus layout kernels passed on random tensors. + CUDA direct reverse readout/layout parity against tensor-reference + reductions passed on random tensors, including the strict row-table wrapper. + CUDA no-artifact forward probe reported + `_last_flat_bucket_readout_backend=registered_temporal_program_fused_readout_layout_epilogue`. + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed: 4 tests. + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t_gt1_training_uses_flat_bucket_route --tb=short` + passed: 5 tests. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_primitive_executor_plan_fails_closed_for_missing_generic_dispatch tests/test_fabric_backend_plan.py::test_temporal_executor_binding_plan_groups_compiled_bindings_by_executor_row tests/test_fabric_backend_plan.py::test_temporal_tensor_binding_rows_are_compiler_products tests/test_fabric_backend_plan.py::test_temporal_scan_binding_projection_fails_for_readout_signature_drift tests/test_fabric_backend_plan.py::test_temporal_readout_rows_come_from_compiled_readout_program --tb=short` + passed: 5 tests after the explicit readout-role split. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows --tb=short` + passed: 2 tests after runtime-buffer plan ownership was wired into the + active registered path. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 116 tests after the exact message-role split and fake + `out_weight` deletion. 
+ `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 3 tests after routing local/sparse attention through the + registered program extension. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 116 tests after the registered message-attention migration. + A direct CUDA sparse probe comparing + `registered_forward_sparse_attention_cuda(...)` and + `registered_backward_sparse_attention_cuda(...)` against the prior sparse + partitioned primitive passed on random tensors. + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t_gt1_training_uses_flat_bucket_route --tb=short` + passed: 5 tests after the registered message-attention migration. + The same CUDA training smoke passed again after replacing + `fused_local_readout_cuda(...)` in the registered artifact/readout path. + - 2026-05-01 hard-closure continuation: sender K/V projection is now a + registered program surface for the active temporal compiler path. + `surface_executor_runtime.py` no longer imports or calls + `project_sender_kv_from_cells_sequence(...)`, + `project_recurrent_kv_backend_order_cuda(...)`, or + `project_recurrent_kv_backend_order_backward_cuda(...)`. Forward input + sequence K/V, recurrent carry K/V, recurrent K/V backward, and boundary + input K/V backward now launch through + `flat_bucket_registered_program_*` entrypoints that consume selected + executor rows plus executor binding rows and validate compiler-owned + parameter bindings. The backward kernel covers direct recurrent weights, + direct input weights, and grouped input weights, so the registered + training path no longer borrows `receiver_major_affine_backward_cuda(...)` + or `fabric_grouped_projection_backward_cuda(...)` for this semantic + surface. + - This is real closure of one fixed projection ABI leak, not a new facade: + the registered sender K/V kernels have no temporal `kTf*` slot table, no + string role lookup, and no try/fallback probe. The remaining hard compiler + blockers are still the sequence-wide fused forward program, the + sequence-wide fused reverse program, executable whole-program + memory/aliasing beyond the current runtime-buffer plan, and generic + transition lowering/reference executor coverage beyond the current + structural gated/diagonal composites. + - Verification for the sender K/V closure: + direct CUDA parity passed for registered forward sender K/V sequence + projection against `project_sender_kv_from_cells_sequence(...)`, including + direct and grouped input weights, and for registered recurrent K/V step + projection against `project_recurrent_kv_backend_order_cuda(...)`. + Direct CUDA parity passed for registered sender K/V backward against + `project_recurrent_kv_backend_order_backward_cuda(...)`, including a + missing-V-gradient case, and grouped input K/V backward matched an autograd + reference. 
+ `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 3 tests. + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t_gt1_training_uses_flat_bucket_route --tb=short` + passed: 5 tests. + - 2026-05-01 hard-closure continuation: deleted + `TransitionPublicKVProjectionPlan` and the + `transition_public_kv_projection:sender_kv:v1` one-row mini ABI. Transition + public K/V projection now fails closed unless the caller supplies the + compiled forward executor rows, compiled forward binding rows, and selected + message executor id/bucket ordinal. Registered temporal forward threads + those compiler products from `RegisteredTemporalExecutorProgram` into + transition lowering, so this path can no longer synthesize an executor + table beside the compiler-owned plan. + - This closes the transition-side sender K/V projection leak specifically by + invalidating the old one-row projection path. The remaining transition + compiler work is still broader: generic transition lowering/reference + executor coverage beyond the current structural gated/diagonal records, + plus sequence-wide fused forward/reverse programs and full executable + memory/alias ownership. + - Same continuation: transition primitive support now has an explicit + primitive-executor registry separate from the structural program executor + records. `TransitionPrimitiveExecutorRecord` names CUDA-backed primitives + (`linear`, `matmul`, `gated_logspace_recurrence`, `norm_or_identity`, + `diag_rtu`, and now `tanh` forward/reverse) plus any future reference-only blocked + primitives with an explicit primitive blocker. Program executor selection + now fails closed with `UNREGISTERED_TRANSITION_PRIMITIVE` when a declared + op has no primitive executor record, with the primitive blocker when a + declared op is registered but lacks the needed CUDA/program layer, and with + `NO_REGISTERED_TRANSITION_EXECUTOR` when primitive records exist but no + legal forward/backward program executor covers the row pattern. This makes + the add-op path local and explicit instead of hiding everything behind two + gated/diagonal composite pattern checks. + - Verification for the transition public-K/V closure: + direct CUDA parity passed for registered sender K/V step projection against + `project_sender_kv_from_cells_step(...)` for both direct and grouped + weights. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_rejected_transition_temporal_fusion_facades_were_deleted tests/test_fabric_backend_plan.py::test_transition_executor_selection_uses_registered_structural_records --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_transition_parameters_resolve_from_compiled_bindings tests/test_fabric_backend_plan.py::test_temporal_transition_rows_come_from_compiled_transition_programs --tb=short` + passed: 2 tests. 
+ `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t_gt1_training_uses_flat_bucket_route --tb=short` + passed: 5 tests. + - 2026-05-01 hard-closure continuation: deleted the remaining standalone + flat-bucket layout, recurrent K/V public-projection, and readout extension + source sets: + `flat_bucket_layout_{cuda.py,binding.cpp,kernels.cu}`, + `flat_bucket_public_projection_{cuda.py,binding.cpp,kernels.cu}`, and + `flat_bucket_readout_{cuda.py,binding.cpp,kernels.cu}`. These files are + no longer implementation references, wrappers, or fallback targets for the + registered temporal compiler path; the only flat-bucket CUDA extension left + for these surfaces is `flat_bucket_registered_program_*`. + - Same continuation: graph-order recurrent/cell assembly now runs through + `registered_forward_cells_layout_cuda(...)` in the active registered + temporal path. That kernel consumes `forward_executor_rows` and + `forward_executor_binding_rows`, validates the compiler-selected readout + executor plus compiler-owned parameter binding rows, and then assembles + graph-order recurrent cells and full cells. The shared backward + materialization helper no longer imports the old layout extension; it uses + the compiler/runtime graph-order permutation tensor directly. + - Same continuation: the two recurrent K/V backward runtime tests were + migrated from `try_project_recurrent_kv_backend_order_backward_cuda(...)` + to `registered_backward_sender_kv_projection_cuda(...)` with explicit + reverse executor rows and binding rows. The direct tests now exercise the + registered backward executor contract rather than keeping the deleted + public-projection wrapper alive as a test-only ABI. + - Source status after this slice: `rg` for `try_*_cuda(` under + `src/cortical/fabric/backend/cuda/sequence_surface` finds no active + temporal compatibility probes; `rg` for `flat_bucket_layout_cuda`, + `flat_bucket_public_projection_cuda`, and `flat_bucket_readout_cuda` finds + only deletion guardrails in tests. This supersedes the earlier 2026-05-01 + note saying those modules remained as strict wrappers. + - Verification for the legacy flat-bucket source deletion: + direct CUDA parity for `registered_forward_cells_layout_cuda(...)` matched + the prior graph-order layout result on random tensors before deleting the + old layout wrapper. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fixed_temporal_scan_extension_sources_were_deleted tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 2 tests after deleting layout/public-projection/readout sources. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 116 tests after the deletions. 
+ `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_backend_order_recurrent_kv_projection_backward_matches_autograd tests/test_fabric_runtime.py::test_fabric_cuda_recurrent_kv_backward_uses_public_outnorm_sender_state tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t_gt1_training_uses_flat_bucket_route --tb=short` + passed: 7 tests after the deletions. + - 2026-05-01 hard-closure continuation: runtime buffer allocation is now a + compiler memory-plan product. `TemporalRuntimeBufferPlan` exposes + `temporal_runtime_buffer_spec(...)` and + `allocate_temporal_runtime_buffer(...)`; active registered forward uses it + for `output_seq`, and active reverse uses it for `grad_boundary_seq`. + This removes another direct tensor-shape allocation from the temporal + runtime and makes the executable memory/liveness plan drive real runtime + buffers, not only metadata. + - Same continuation: deleted the stale fixed-formula + `sequence_surface/flat_bucket/flat_bucket_temporal_forward_primitives.cuh` + header. It was no longer a callable path, but it still encoded old + message/gated/diagonal formula bundles in the flat-bucket layer; the + deletion guard now keeps it from returning as an implementation reference. + - Same continuation: the registered CUDA readout/layout/readout-backward + entrypoints no longer validate hardcoded readout executor constants. + `flat_bucket_registered_program_*` now receives the compiler-selected + `executor_id` and `bucket_ordinal` from the readout executor row for + forward readout projection, forward readout+layout epilogue, cells layout, + and reverse readout/layout projection. The old `kReadoutExecutorId`, + `kReverseReadoutExecutorId`, and `kReadoutBucketOrdinal` constants were + removed and source-guarded. + - Same continuation: active registered temporal runtime no longer looks up + synthetic message/readout handles by literal `bucket_ordinal=-1/-2`. + `RegisteredTemporalExecutorProgram` now exposes + surface handle tuple selection, so runtime orchestration selects + message/readout handles from compiled executor rows by surface and passes + the selected row identity into CUDA. The synthetic ordinals remain only as + named compiler IR constants + (`TEMPORAL_MESSAGE_BUCKET_ORDINAL`, `TEMPORAL_READOUT_BUCKET_ORDINAL`, + `TEMPORAL_PARAMETER_REDUCTION_BUCKET_ORDINAL`) inside compiler passes. + - Same continuation: deleted the registered surface runtime's per-call + executor mini-table shims. `surface_executor_runtime.py` no longer builds + `_registered_executor_rows_tensor(...)` or + `_registered_executor_binding_rows_tensor(...)`; active forward/recompute + callers pass `executor_program.forward_plan.forward_executor_rows` and + `executor_program.forward_plan.executor_binding_rows` directly into every + registered sender-K/V, message, and readout CUDA surface launch. Active + reverse callers likewise pass + `executor_program.backward_plan.reverse_executor_rows` and + `executor_program.backward_plan.executor_binding_rows` into registered + output/recurrent message backward launches. This removes another facade + layer where runtime code could silently narrow the compiled program into a + one-row ABI before calling CUDA. + - Same continuation: registered forward output storage is now allocated + before the temporal loop from the compiler memory/liveness plan and the + temporal output contract. 
`_temporal_forward_output_step_shape_for_contract(...)` + derives the output step shape from the compiled runtime facts, then + `build_temporal_runtime_buffer_plan(...)` and + `allocate_temporal_runtime_buffer(...)` allocate `output_seq` before any + step executes. The prior lazy `if output_seq is None` allocation after the + first emitted output was deleted and source-guarded. This moves another + executable memory decision out of step-local runtime observation and into + the memory-plan/product boundary. + - Source status after this slice: `rg` for active `try_*_cuda(` probes under + `sequence_surface`, `ops/temporal_backward`, and `transition_execution` + returns no active source matches; remaining occurrences are deletion/source + guardrails in tests. `rg` for `bucket_ordinal=-1/-2/-3` in compiler/runtime + source now returns only test guardrails; compiler passes use named IR + constants. `rg` for `_registered_executor_rows_tensor` and + `_registered_executor_binding_rows_tensor` returns no active source + matches. + - Remaining hard compiler blockers after this slice are unchanged in + priority: sequence-wide fused CUDA forward/reverse program kernels that + consume executor rows, binding rows, and the memory plan directly; fuller + executable memory/liveness ownership for workspace aliasing and tape + materialization; and generic transition lowering/reference executor + coverage beyond the current structural gated/diagonal transition records. + - Verification for this continuation: + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/row_groups.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/tables.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/strategy_selection.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/primitive_dispatch.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py` + passed. + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/row_groups.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/tables.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/strategy_selection.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/primitive_dispatch.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_strategy_selection_has_separate_legality_and_cost_phase tests/test_fabric_backend_plan.py::test_temporal_primitive_executor_plan_fails_closed_for_missing_generic_dispatch --tb=short` + passed: 5 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 116 tests. 
+ `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_backend_order_recurrent_kv_projection_backward_matches_autograd tests/test_fabric_runtime.py::test_fabric_cuda_recurrent_kv_backward_uses_public_outnorm_sender_state tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t_gt1_training_uses_flat_bucket_route --tb=short` + passed: 7 tests. + - Verification for the mini-table deletion: + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py` + passed. + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 3 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 116 tests. + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_backend_order_recurrent_kv_projection_backward_matches_autograd tests/test_fabric_runtime.py::test_fabric_cuda_recurrent_kv_backward_uses_public_outnorm_sender_state tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t_gt1_training_uses_flat_bucket_route --tb=short` + passed: 7 tests. + - Verification for the upfront output-buffer memory-plan slice: + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py` + passed. + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 116 tests. + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_backend_order_recurrent_kv_projection_backward_matches_autograd tests/test_fabric_runtime.py::test_fabric_cuda_recurrent_kv_backward_uses_public_outnorm_sender_state tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t_gt1_training_uses_flat_bucket_route --tb=short` + passed: 7 tests. 
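+  - Illustrative aside, not repository code: a minimal sketch of a
+    plan-owned runtime buffer spec and upfront allocation, in the spirit of
+    `temporal_runtime_buffer_spec(...)` /
+    `allocate_temporal_runtime_buffer(...)` above. Field names, the example
+    shape, and the init-policy strings are assumptions for the sketch.
+
+    ```python
+    from dataclasses import dataclass
+
+    import torch
+
+
+    @dataclass(frozen=True)
+    class RuntimeBufferSpec:
+        name: str
+        shape: tuple[int, ...]
+        dtype: torch.dtype
+        device: torch.device
+        workspace_class: str  # e.g. "output_sequence", "boundary_gradient"
+        init_policy: str      # "zeros" or "uninitialized"
+
+
+    def allocate(spec: RuntimeBufferSpec) -> torch.Tensor:
+        if spec.init_policy == "zeros":
+            return torch.zeros(spec.shape, dtype=spec.dtype, device=spec.device)
+        return torch.empty(spec.shape, dtype=spec.dtype, device=spec.device)
+
+
+    # Allocate the whole output sequence up front from the plan instead of
+    # growing a Python list of per-step outputs and stacking it afterwards.
+    output_seq = allocate(RuntimeBufferSpec(
+        name="output_seq",
+        shape=(8, 4, 16),  # (T, B, H) chosen arbitrarily for the sketch
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+        workspace_class="output_sequence",
+        init_policy="uninitialized",
+    ))
+    ```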
+ - 2026-05-01 hard-closure continuation: the active registered transition + forward path no longer calls + `lower_backend_population_transition_forward_result_shared(...)` or passes + the compiled message executor as a public-K/V compatibility side channel. + `run_registered_transition_forward_executor(...)` now materializes the + transition program tensor table from compiler-owned executor bindings, + passes `primitive_rows`, selected forward executor rows, selected binding + rows, and `memory_liveness_rows` into + `registered_temporal_fused_forward_transition_program_cuda(...)`, and + reads public outputs plus next private state from the returned program + tensor slots. The old forward transition lowering path remains deleted from + the active temporal surface; backward transition still uses the lowering + reference and is the next hard Pass 3 closure target. + - This moves the forward transition surface from per-bucket Python lowering + into a compiler-owned fused transition span. It is still not final + sequence-wide fused execution: message, readout, K/V, layout, and temporal + scheduling remain separate registered surfaces around the loop, and reverse + transition spans still need the equivalent row/binding-owned program + dispatcher before Pass 3/5 can close. + - Verification for the active forward transition-span adoption: + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/executor_registry.py tests/test_fabric_backend_boundaries.py` + passed. + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/executor_registry.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 122 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_active_transition uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_flat_bucket_forward_uses_sequence_executor --tb=short` + passed: 2 tests, covering the active sLSTM and Axon transition families. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_active_transition uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t_gt1_training_uses_flat_bucket_route --tb=short` + passed: 1 test, covering the multi-transition-bucket active CUDA path. + - 2026-05-01 hard-closure continuation: the active registered transition + backward path no longer calls + `lower_backend_population_transition_backward_shared(...)`. Reverse + transition executor bindings now declare explicit backward tensor-role + contracts for the registered recurrence span: forward intermediates, + public/state gradient inputs, parameter bindings, state-gradient outputs, + message-gradient output, and parameter-gradient outputs. 
The runtime + recomputes transition intermediates through the registered forward + transition program, extends the compiler tensor table with reverse + gradient roles, and calls + `registered_temporal_fused_reverse_transition_program_cuda(...)` with + `primitive_rows`, the selected reverse executor row, selected reverse + binding rows, and `memory_liveness_rows`. + - The CUDA registered-program extension now exposes + `fused_reverse_transition_program_execute`. That entrypoint validates + compiler rows/binding rows/memory rows, dispatches the selected reverse + transition span by primitive opcode, and writes gradients back through + compiler tensor binding slots. It composes the existing registered + primitive kernels for normalization, recurrence, recurrent affine/input + affine, diagonal output projection, diagonal recurrence, and input + projection inside one reverse transition program entrypoint instead of + returning to the legacy transition lowering route. + - Parameter binding for transition input projections now converts fused + input-projection gradients back into compiler parameter-gradient maps + before `bind_temporal_transition_param_grads(...)`. Factorized projections + use the existing projection unfuse helper; unfactorized projections bind + the direct value-to-cell and recurrent-bias gradients. This keeps the + active reverse path on registered executor products while preserving the + existing trainable parameter contracts. + - Source status after this slice: `surface_executor_runtime.py` has no + active `lower_backend_population_transition_forward_result_shared(...)` or + `lower_backend_population_transition_backward_shared(...)` call, and its + guardrails reject literal gated/diagonal transition executor names in the + active runtime. The remaining transition compiler work is no longer the + active fixed lowering route; it is broader generic transition + lowering/reference coverage beyond the current registered gated/diagonal + structural records. + - Remaining hard compiler blockers after this slice: + sequence-wide fused CUDA forward/reverse program kernels still need to + consume the complete temporal executor program directly rather than + per-surface/per-step registered launches; memory/liveness planning still + needs to own workspace aliasing, tape materialization, and CUDA graph + constraints as executable policy; generic transition lowering still needs + reference executors and typed fail-closed coverage for new declared ops + beyond the current gated/diagonal recurrence records. April21-class + throughput remains a later performance-closure target after these compiler + ownership blockers are closed. + - Verification for the active reverse transition-span adoption: + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_bindings.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py` + passed. 
+ `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_bindings.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 122 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_program uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_flat_bucket_route_matches_pytorch_reference --tb=short` + passed: 2 tests, covering active sLSTM and Axon forward/backward parity. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_program uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t_gt1_training_uses_flat_bucket_route --tb=short` + passed: 1 test, covering multiple transition buckets through the registered + reverse program path. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_program uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_flat_bucket_forward_uses_sequence_executor --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_program uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed: 4 tests, including reset/no-reset parameter-gradient coverage. + - 2026-05-01 hard-closure continuation: started the sequence-wide fused + CUDA forward program cutover. `registered_temporal_fused_forward_program_cuda` + is no longer a validate-and-throw stub for supported inference rows. The + pybind ABI now takes a compiler tensor table, primitive rows, forward and + reverse executor rows, executor binding rows, memory-liveness rows, and + flat-bucket graph tables. The C++ body decodes the compiler span table and + runs the no-artifact `output_cells` forward sequence inside one program + entrypoint: input K/V sequence projection, recurrent K/V projection, + recurrent message attention, registered transition program spans, output + message attention, and readout projection. + - Active runtime cutover is intentionally narrow and honest. The forward + runner calls the fused sequence program only when the request needs no + artifacts, no materialized final state, no resets, CUDA float32 tensors, + local-message tables, and the `output_cells` contract. All training, + replay, reset, and final-state requests continue to use registered + executor rows, not legacy fixed scan/reverse ABIs. This was the interim + state before the reverse span body landed; the current fused program plan + requires both forward and reverse span dispatch bodies before reporting + `legal`. + - Transition constants are now compiler tensor bindings instead of temporal + launch side channels. `norm_or_identity` rows bind `outnorm_eps`, `diag_rtu` + rows bind `activation_id`, and the fused forward/reverse transition program + kernels read those scalars from `program_tensor_binding_rows`. This removes + the earlier uniform-constant compromise and lets each transition bucket + carry its own compiler-owned scalar values. 
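+  - Illustrative aside, not repository code: a minimal sketch of reading
+    per-bucket scalar constants (such as `outnorm_eps` or `activation_id`)
+    from compiler binding rows instead of uniform launch arguments. The
+    two-column row layout and role codes are invented for the sketch; the
+    real `program_tensor_binding_rows` schema is assumed to be richer.
+
+    ```python
+    import torch
+
+    # One row per (bucket, role): column 0 is the bucket ordinal, column 1 a
+    # small role code (0 = outnorm_eps, 1 = activation_id). The scalar value
+    # for each row lives in a parallel table indexed by row position.
+    binding_rows = torch.tensor([[0, 0], [1, 0], [1, 1]], dtype=torch.int64)
+    scalar_table = torch.tensor([1e-6, 1e-5, 2.0], dtype=torch.float64)
+
+
+    def bucket_scalar(bucket: int, role_code: int) -> float:
+        for row_index in range(binding_rows.shape[0]):
+            if (int(binding_rows[row_index, 0]) == bucket
+                    and int(binding_rows[row_index, 1]) == role_code):
+                return float(scalar_table[row_index])
+        raise LookupError(
+            f"MISSING_SCALAR_BINDING: bucket={bucket} role={role_code}")
+
+
+    # Each bucket carries its own eps rather than one uniform launch constant.
+    assert bucket_scalar(0, 0) == 1e-6
+    assert bucket_scalar(1, 0) == 1e-5
+    ```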
+ - Remaining hard compiler blockers after this slice: implement the fused + reverse sequence program over the same row/binding/memory contract; widen + the fused forward sequence body to artifact/final-state/reset policies + using executable memory/liveness rows; then delete the registered per-step + training orchestration once reverse parity and artifact policy are + compiler-owned through the program entrypoints. + - 2026-05-01 reverse-window cut: the active reverse window now attempts a + compiler-owned `registered_reverse_program_window` before the old + per-step `run_registered_temporal_bucket_step_backward(...)` loop. The + cut is intentionally narrow: CUDA float32, `output_cells`, no reset tensors, + no final-state carry gradient, no incoming backend-state gradient, and + stored artifacts with backend-order recurrent-hidden-before tensors. The + selected path still composes registered primitive CUDA executors from + Python at the window level, so the true fused C++ reverse sequence + entrypoint remains the top fused-kernel blocker. + - The new ownership tag is + `flat_bucket_temporal_reverse_scan_owner:registered_reverse_program_window` + plus `flat_bucket_temporal_backward_binding_abi:registered_executor_binding_rows`. + Reset/final-state/reverse-table artifact windows are no longer allowed to + demote into the old per-step registered fallback from this entrypoint; they + must either satisfy the fused reverse-program contract or fail closed with a + typed `registered_reverse_program_window_reject:`. + - Memory/liveness planning now names the reverse artifact roles needed by the + program-window path: boundary/public tensors, recurrent K/V before/after, + recurrent hidden before/after, recurrent/output messages, output cells, + backend state cache, and transition tape. `TemporalArtifactStore` carries + those roles from `TemporalMemoryRuntimeArtifactPlan`, and + `registered_reverse_program_window` refuses to own a window unless the + compiler artifact role plan and the stored artifact tensors both satisfy + that contract. This is the enabling contract for replacing the remaining + Python artifact object with a fused reverse-program tensor table. + - Verification started for the sequence-wide fused forward cutover: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence4 uv run python - <<'PY' ... _load_ext() ... PY` + passed and loaded the registered program CUDA extension. 
+ `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence4 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells --tb=short` + passed: 1 test, proving the no-artifact `output_cells` route is owned by + `registered_temporal_fused_forward_program_cuda` and matches the registered + executor-row path for the same program. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence4 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_flat_bucket_forward_uses_sequence_executor --tb=short` + passed: 2 tests, covering the existing registered sequence executor path. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 122 tests after the fused sequence cutover. + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence5 uv run python - <<'PY' ... _load_ext() ... PY` + passed after removing the fused transition program scalar launch arguments. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence5 uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings --tb=short` + passed: 1 test, proving `outnorm_eps` is consumed from compiler tensor + bindings by the fused transition program body. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence5 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells --tb=short` + passed: 1 test after binding-owned transition constants. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence5 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_flat_bucket_forward_uses_sequence_executor --tb=short` + passed: 2 tests after binding-owned transition constants. + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_runtime.py` + passed after the reverse-window cut. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_runtime.py` + passed after the reverse-window cut. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/types.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py tests/test_fabric_backend_plan.py tests/test_fabric_runtime.py` + passed after adding compiler-owned reverse artifact roles. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes tests/test_fabric_backend_plan.py::test_temporal_backward_validates_memory_artifact_plan_fingerprint --tb=short` + passed: 2 tests, including the new reverse artifact role contract in the + memory runtime artifact plan. 
+ `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence5 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test, proving the no-carry/no-reset `output_cells` backward + route is owned by `registered_reverse_program_window` and matches the + registered reset-present fallback for outputs, boundary gradients, and + parameter gradients. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence5 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_forward_reverse_table_plan_is_not_t1_specific --tb=short` + passed: 5 tests, covering the new reverse-window owner plus reset/state + routes that intentionally remain on `registered_reverse_executor_bindings`. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 122 tests after binding-owned transition constants. + - 2026-05-01 reverse-program artifact-table cut: added + `compiler/reverse_artifacts.py` as the compiler-owned reverse artifact role + table and threaded those rows into the active registered reverse window. + The new C++/pybind entrypoint + `fused_backward_program_output_grad_window` consumes + `primitive_rows`, forward/reverse executor rows, executor binding rows, + `memory_liveness_rows`, reverse artifact role rows, and reverse artifact + tensor binding rows. For the currently legal no-carry/no-reset + `output_cells` reverse-program window, output-gradient materialization is + now done by that registered backward program stage instead of by Python + `_full_cells_grad_for_output_contract(...)` shape glue. + - This is a real Pass 3/4 cut, but not full fused reverse closure. The fused + stage validates the compiler program and reverse artifact table, builds the + full-cell gradient window, and records + `temporal_backward_glue:registered_fused_backward_program_output_grad_window`. + The remaining hard blocker is to move the readout backward, message + backward, transition backward, boundary projection backward, and parameter + reductions into the same program-level reverse C++ entrypoint, then delete + the Python window orchestration/per-step registered fallback for supported + routes. + - Verification for this cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/reverse_artifacts.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed. 
+ `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/reverse_artifacts.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_reverse_artifact_roles_are_compiler_binding_rows tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence7 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test, proving the active reverse-program window uses the fused + output-gradient stage and still matches the registered fallback for the + compared gradients. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence7 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_forward_reverse_table_plan_is_not_t1_specific --tb=short` + passed: 5 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. + `git diff --check` passed. + - 2026-05-01 reverse transition dynamic-binding row cut: moved the + registered reverse transition dynamic tensor binder off handler-kind + branches for gated-logspace versus diagonal-RTU state/seed positions. The + compiler now emits per-transition-group dynamic binding rows with explicit + source kinds: reverse artifact access, transition state-before artifact, or + seed-or-zero with a declared template binding. The fused backward program + threads those rows through Python, pybind, and the CUDA program executor, + and the C++ binder only interprets the row table; it no longer owns + gated/diagonal input slot positions or seed slot positions. + - Remaining transition compiler work after this cut: the registered reverse + primitive implementations are still the current gated and diagonal + executor bodies, transition parameter reducer algebra still needs to move + from monolithic reducer transforms into strategy-owned reducer rows, and + memory/liveness rows still need to become executable workspace/aliasing + allocation rather than mostly validation metadata. + - Verification for the reverse transition dynamic-binding row cut: + `python -m py_compile src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py` + passed. 
+ `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 128 tests. + `git diff --check` passed. CUDA extension rebuild/runtime parity was not + rerun for this source cut yet. + - 2026-05-01 program-access opcode cut: widened forward/reverse temporal + program access rows from local slot rows to compiler access rows carrying + both the handler-local slot and a stable access opcode produced from the + registered strategy pattern. The fused C++ program now binds message, + readout, and transition tensors by access opcode (`message_recurrent_query`, + `message_recurrent_kv_weight`, `readout_value_to_output_weight`, + `transition_aggregated_message_input`, etc.) instead of temporal-side + `kExecutorAccessSlot*` constants. This keeps the current neighborhood + attention/readout/transition strategies narrow, but moves their tensor-role + ABI under the registered handler/access contract rather than under the + temporal scheduler. + - Same cut: forward and reverse message/readout span selection now finds + handlers by declared surface plus capability flag. The active C++ no longer + asks for a fixed message/readout handler enum in forward binding, readout + backward, output-message backward, recurrent-message backward, recurrent + K/V projection backward, initial recurrent K/V projection, or boundary K/V + projection. The remaining fixed handler IDs are implementation strategy + registrations and transition primitive adjoint implementations, not the + shared temporal program's tensor-access ABI. + - Remaining work after the program-access opcode cut: replace the remaining + transition-adjoint implementation switch with a registry table that maps + handler kind to callable primitive executor, move parameter reducer + transforms into compiler reducer-transform rows, and let memory/liveness + rows allocate workspace/artifact/gradient buffers instead of only auditing + them. + - 2026-05-01 transition primitive executor registry cut: split the fused + forward transition program off its direct opcode switch. The shared fused + transition loop now validates primitive rows and binding rows, then calls a + `RegisteredTransitionForwardPrimitiveExecutor` selected by primitive opcode. + The current linear, recurrent-matmul, gated-logspace recurrence, + norm-or-identity, and diag-RTU implementations are registered executor + handlers; the shared loop no longer embeds the per-op formula branches or + the old unsupported-op branch. This is not yet arbitrary-op closure, but + it changes the extension point: a new CUDA transition primitive should add + a primitive executor record/handler and tests instead of editing the + temporal program loop. + - The reverse transition handler selection received the same registry shape: + `RegisteredTransitionReversePrimitiveExecutor` maps the registered reverse + handler and primitive opcode to the current gated/diag adjoint + implementation. The remaining reverse work is still deeper than this: + the current gated/diag adjoints are composite handlers, and parameter + reducer algebra still has specialized transforms that must become + binding-owned reducer primitive rows. 
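+  - Illustration only, not the real extension ABI: a Python-shaped sketch of
+    the registry dispatch described above. The record fields and helper names
+    are invented stand-ins for the C++
+    `RegisteredTransitionForwardPrimitiveExecutor` records; the point is that
+    the shared loop validates and dispatches, while per-op formulas live
+    behind registration.
+
+    ```python
+    from dataclasses import dataclass
+    from typing import Callable, Dict, Mapping
+
+    @dataclass(frozen=True)
+    class ForwardPrimitiveExecutorRecord:
+        opcode: int
+        name: str
+        run: Callable[[Mapping[str, object]], object]
+
+    _FORWARD_EXECUTORS: Dict[int, ForwardPrimitiveExecutorRecord] = {}
+
+    def register_forward_executor(record: ForwardPrimitiveExecutorRecord) -> None:
+        # Adding a primitive is a registration plus tests, not a loop edit.
+        if record.opcode in _FORWARD_EXECUTORS:
+            raise ValueError(f"duplicate primitive opcode {record.opcode}")
+        _FORWARD_EXECUTORS[record.opcode] = record
+
+    def run_transition_forward_program(primitive_rows, bindings_by_row):
+        # The shared loop only validates rows and dispatches by opcode.
+        results = []
+        for index, row in enumerate(primitive_rows):
+            record = _FORWARD_EXECUTORS.get(row["opcode"])
+            if record is None:
+                raise LookupError(f"unregistered transition primitive opcode {row['opcode']}")
+            results.append(record.run(bindings_by_row[index]))
+        return results
+    ```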
+ - 2026-05-01 transition parameter-gradient contract cut: moved the + reverse parameter-gradient output contract out of + `sequence_surface.compiler.executor_bindings` and into + `transition_execution.registry.TransitionPrimitiveExecutorRecord`. The + temporal transition parameter-gradient binding plan now derives reducer + rows from the primitive executor record instead of a local gated/diag map. + - 2026-05-01 transition trainable reducer handler-table cut: replaced the + C++ transition trainable reducer kind switch with a + `RegisteredTransitionTrainableReducerRunFn` handler table. The current + materialized, value-to-cell, and recurrent-bias transforms are still the + supported reducer set, but the reducer program now dispatches through + registered handler entries instead of embedding the transform switch in the + shared reducer runner. Remaining reducer closure is to make those handler + entries generated from binding-owned transform contracts and to add any new + transform by registering a reducer primitive handler plus legality tests. + - Remaining compiler-closure order after this cut: + 1. Replace whole-pattern transition program selection with legal primitive + DAG coverage plus optional fused strategy records. + 2. Generate reducer handler rows from binding-owned transform contracts + instead of today’s fixed transform ids. + 3. Make memory/liveness allocate and alias transition workspaces, tapes, + artifact buffers, grad workspaces, and parameter accumulators. + 4. Generalize message/readout carriers so new strategies register through + primitive rows and access rows without editing the shared program loop. + 5. Only then widen coverage for output contracts, reset variants, variable + K, and additional primitives. + - 2026-05-01 handler-strategy registry/effect-verifier cut: moved fused + forward/reverse handler-row construction off private + executor-name-to-handler dictionaries in `program_execution.py`. Handler + kind, capability, and required effect metadata now live on the structured + executor strategy records in `executor_patterns.py`, and handler rows are + emitted by matching the compiled primitive row group back to that strategy + record. This makes the current message/readout/gated/diagonal handlers + local strategy registrations rather than hidden temporal-side maps. + - The fused C++ program now carries the handler required-effect mask in the + decoded span and verifies it against compiler memory-liveness rows. This + exposed a real memory-plan gap: reverse handlers require `grad_read` and + `parameter_grad_emit` facts, while the memory plan only emitted forward + effects for message/readout/transition rows. The memory plan now emits the + reverse-window effects needed by registered reverse handlers, so effect + legality is proved by the compiler memory rows instead of inferred from a + broad surface default. + - Remaining work after this cut: the C++ registered transition primitive + forward body is still a monolithic opcode dispatcher, transition parameter + reducers still encode current parameter algebra in C++, and memory/liveness + rows still need to drive actual workspace allocation and aliasing rather + than only proving legality. Those remain the next hard compiler-closure + blockers. + - C++ handler validation was tightened in the same pass. The fused program no + longer accepts handler kinds by a broad numeric enum range or derives names + from switch-only helper functions. 
It resolves each decoded span through a + registered handler spec keyed by handler kind plus surface opcode, validates + the declared capability flags, then validates the required effect mask + against memory-liveness facts. This prevents a malformed compiler row from + claiming a message/readout/transition handler through the wrong surface or + missing capability bit. + - Verification for this cut: + `python -m py_compile src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured tests/test_fabric_backend_plan.py::test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 3 tests. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_strategy_matching_uses_canonical_row_group_schema tests/test_fabric_backend_plan.py::test_temporal_table_runtime_metadata_records_executor_blockers tests/test_fabric_backend_plan.py::test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 4 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_handler_effects2 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test after rebuilding the registered-program extension. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_handler_specs uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test after rebuilding the registered-program extension with + C++ handler specs. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 128 tests. + `git diff --check` passed. + - 2026-05-01 compiler-emitted executor handler-row cut: moved fused + forward/reverse program handler selection out of C++ static handler tables + and into compiler-owned `forward_handler_rows` / `reverse_handler_rows`. + `build_temporal_fused_cuda_program_plan(...)` now requires those row + tables before a fused launch can be selected, the runtime records them + beside primitive/executor/binding/memory rows, and every registered fused + forward/backward pybind entrypoint receives the handler rows as part of the + executable program ABI. The registered C++ program now decodes spans from + executor rows plus handler rows, so the active program no longer treats a + compiled executor id as sufficient evidence of a supported strategy. + - This is a real compiler-closure cut, not final strategy generality. 
The + supported handlers are still the current neighborhood-attention message, + projection-reduction readout, gated-logspace transition, and diagonal-RTU + transition strategies. The next hard closure step is to make new + strategies register through the Python compiler registry with their + primitive pattern, tensor roles, memory/liveness effects, and forward/ + reverse CUDA implementation, without editing temporal scheduler ownership + or reintroducing fixed C++ handler arrays. + - Verification for the handler-row cut so far: + `python -m py_compile src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_handler_rows2 uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings --tb=short` + passed. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_handler_rows_runtime uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 128 tests. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `git diff --check` passed. + - 2026-05-01 parameter-reducer handlerization cut: moved the registered + temporal parameter reducer program off a monolithic reducer-family body. + The active C++ reducer entrypoint now decodes compiler parameter reducer + rows into `RegisteredParameterReducerHandler` records, validates expected + tensor counts from those records, and dispatches sender K/V, recurrent + query, output query, readout output, and transition parameter reductions + through handler functions. Transition trainable algebra is likewise behind + `RegisteredTransitionTrainableReducerHandler` records, so adding another + trainable reducer is a handler registration plus source/target binding + contract instead of expanding the scheduler body. + - Remaining reducer compiler work after this cut: the registered handler set + still implements the current projection/query/readout/transition reducer + strategies, and the C++ formulas still live in those handlers. 
The next + closure step is making new parameter reducer strategies register from + primitive-row/binding metadata and then widening memory/liveness rows so + reducer workspace and aliasing are planned instead of just validated. + - 2026-05-01 parameter-reducer strategy-table cut: removed the remaining + common reducer semantic dispatch switch from the active registered C++ + parameter reducer path. The compiler now emits explicit + `parameter_reducer_strategy_rows` beside per-request `parameter_reducer_rows`; + C++ decodes those strategy rows to choose the run function and + expected-count contract for each reducer kind. The temporal program still + invokes the current readout, sender K/V, recurrent-query, output-query, and + transition reducer implementations, but it no longer branches on + `RegisteredParameterReducerHandlerKind` or `switch (handler.kind)` to pick + them, and it no longer relies only on a built-in reducer-kind table for + strategy selection. Static guardrails now reject reintroducing that common + reducer enum or switch and assert the compiler-owned strategy-row ABI. + - Remaining reducer compiler work after this strategy-table cut: the table + is still populated by the current Python-built strategy contract and C++ + callable implementations. Full closure requires reducer strategy + registration from primitive-row/parameter-binding metadata, typed + legality/cost records for new reducer strategies, and memory/liveness-owned + reducer workspace allocation. Transition trainable reducer formulas are + already function-pointer handlers, but their implementation registry still + needs to become external strategy data rather than C++ table literals + before new reducer algebra is truly local to a plugin. + - Verification for the parameter-reducer handlerization cut: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_param_handlers uv run pytest -q tests/test_fabric_backend_plan.py::test_registered_parameter_reducer_cuda_executes_transition_trainable_rows --tb=short` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 128 tests. + - 2026-05-01 executable memory-contract cut: the registered fused C++ + program now converts compiler memory-liveness rows into per-span memory + facts before executing a registered handler. Forward and reverse spans + prove required effects and workspaces for their surface (`state_read`, + `parameter_read`, `message_emit`, `output_emit`, `message_read`, + `state_write`, `tape_policy`, and the matching workspace class) against + the compiler liveness table. This turns memory rows from row-count audit + metadata into active legality checks on handler execution. + - Remaining memory/liveness work after this cut: the C++ program still uses + these rows as effect/workspace contracts, not as a full allocator for all + temporary buffers. The next closure step is to make handler workspace + allocations and aliasing consume planned buffer specs directly, then remove + any remaining local `new_empty` / `zeros_like` allocation choices that are + not backed by a compiler memory buffer or artifact plan. 
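+  - Illustrative Python shape of the per-span effect check described above.
+    The real check runs inside the fused C++ program over compiler row
+    encodings; the dict keys and effect strings here are stand-ins, not the
+    actual row schema.
+
+    ```python
+    def proved_effects(memory_liveness_rows, surface, bucket):
+        # Effect facts the compiler emitted for this surface/bucket pair.
+        return {
+            row["effect"]
+            for row in memory_liveness_rows
+            if row["surface"] == surface and row["bucket"] == bucket
+        }
+
+    def check_span_effects(span, memory_liveness_rows):
+        # Legality check only: refuse to execute a handler whose required
+        # effects are not proved by the compiler liveness table.
+        missing = set(span["required_effects"]) - proved_effects(
+            memory_liveness_rows, span["surface"], span["bucket"]
+        )
+        if missing:
+            raise RuntimeError(f"memory-liveness facts missing: {sorted(missing)}")
+    ```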
+ - 2026-05-01 runtime buffer table cut: extended + `TemporalRuntimeBufferPlan` so memory-liveness entries that require + runtime storage lower into dense `runtime_buffer_rows` plus allocated + runtime buffer tensors. Registered fused forward and reverse program + launches now pass those buffer rows/tensors into the compiled extension, + and C++ validates every allocatable memory-liveness row has a matching + runtime buffer with matching workspace, surface, bucket, effect, and alias + contract before executing the fused program. This is the first allocator + ownership cut: runtime buffers are now a compiler product passed through + the active registered CUDA ABI, not just review metadata. + - Remaining memory/liveness work after the runtime buffer table cut: many + handler internals still allocate local temporary tensors. The next cut is + to replace those local temporaries with actual slices/views of the + compiler runtime buffer table and then make alias groups enforce reuse + rather than only validating that allocation ownership exists. + - 2026-05-01 fused forward output-buffer ownership cut: the registered fused + forward program no longer allocates `output_seq` inside the C++ temporal + loop. The Python launch now asks `TemporalRuntimeBufferPlan` for the + concrete output-sequence shape, allocates it as a compiler runtime buffer, + and the C++ program resolves that tensor through `runtime_buffer_rows` by + output workspace/effect/shape before writing emitted readout cells. This + turns the most visible fused-forward result buffer from a local + `at::empty(...)` into a compiler-owned runtime allocation. + - Remaining memory/liveness work after the output-buffer cut: recurrent + carry workspaces, reverse carry-gradient buffers, transition/readout + temporary tensors, and parameter accumulators still need to consume + runtime-buffer rows directly instead of local `empty_like` / `zeros_like` + allocations. The next allocator pass should add shape/lifetime rows for + those workspaces and enforce alias-set reuse, not just ownership presence. + - Verification for the executable memory-contract cut: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_memory_contracts uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_memory_contracts uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings --tb=short` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 128 tests. + `git diff --check` passed. + - 2026-05-01 transition primitive blocker cut: tightened + `select_transition_program_executor(...)` so a declared transition program + containing a registered-but-unimplemented primitive fails at primitive + capability legality with the primitive blocker code before the selector + reports a missing composite program executor. This makes adding a new + primitive op local and explicit: register/lower it, provide a CUDA/reference + status, and either add a legal program executor strategy or fail closed with + the primitive registry blocker instead of an ambiguous pattern miss. 
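+  - Sketch of the selection ordering only, with invented argument shapes and
+    blocker strings: primitive capability legality is checked before any
+    composite-pattern match, so an unimplemented primitive fails with its own
+    blocker code rather than an ambiguous pattern miss.
+
+    ```python
+    def select_transition_program_executor_sketch(program_ops, primitive_status, strategies):
+        # Primitive registry legality first: fail closed per unimplemented op.
+        for op in program_ops:
+            status = primitive_status.get(op)
+            if status is None or not status.get("implemented"):
+                raise RuntimeError(f"transition_primitive_registry_blocker:{op}")
+        # Only then look for a composite program executor strategy.
+        for strategy in strategies:
+            if strategy["ops"] == list(program_ops):
+                return strategy
+        raise RuntimeError("transition_program_executor_missing")
+    ```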
+ - Verification for the transition primitive blocker cut: + `uv run pytest -q tests/test_fabric_backend_plan.py::test_transition_executor_selection_uses_registered_structural_records --tb=short` + passed. + - 2026-05-01 fused forward training-artifact tensor-table cut: the + registered fused forward program now accepts compiler-owned forward reset + rows and can return reverse artifact tensors by role when + `collect_artifacts=True`. The tensor table covers boundary step, + cells-prev template, input K/V, recurrent K/V before/after, recurrent + hidden before/after, recurrent message, output message/cells, and + transition state-before rows keyed by compiler transition bindings. The + temporal artifact store now has a `TemporalReverseArtifactTensorStore` + path, so the physical backward loader can consume the fused forward tensor + table directly instead of using the old recompute source as the artifact + owner for this supported store-step route. + - This cut intentionally keeps a lightweight `TemporalBucketStepArtifacts` + view materializer at the reverse executor boundary so the already-fused + reverse program can keep its current window API while receiving values from + compiler tensor rows. The next cleanup after this artifact closure is to + make the reverse window API consume the tensor store directly and delete + the view bridge. + - Verification for the fused forward training-artifact cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/types.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/types.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py` + passed. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_artifacts uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test, with active metadata asserting + `registered_temporal_fused_forward_program_cuda`, + `registered_fused_forward_program_cuda`, and + `registered_fused_forward_program_tensor_store`. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 128 tests. + `git diff --check` passed. + + - 2026-05-01 strategy-owned forward/reverse program access cut: moved the + full-program tensor access contract into registered executor strategy + records. `TemporalForwardExecutorPattern` and + `TemporalReverseExecutorPattern` now declare `program_accesses`, and forward + state carry is driven by `state_carry_rules` instead of transition-output + name conventions. Reverse executable program tensor tables now emit + `reverse_program_access_rows`, and the active fused reverse full-step ABI + passes those rows through Python, pybind, and CUDA. 
+ - This removes the active reverse helper dependence on positional primitive + offsets such as `readout_primitive_start`, `message_primitive_start`, and + `message_primitive_start + 1`. Readout/message reverse helpers now resolve + `output_q`, `value_to_output_weight`, `recurrent_q`, direct/group input K/V + weights, and recurrent K/V weights through compiler-owned program access + rows plus tensor-binding rows. The remaining forward-strategy work is + expanding the registered strategy set, not adding new fixed row-position + assumptions. + - The same pass also tightened reverse reset ownership: transition + state-reset rows are reset-capability metadata. The fused reverse full-step + kernel now applies them only when the compiler-supplied transition reset + tensor is present, instead of treating row presence as an active reset. + - Verification for the strategy-owned access cut: + `uv run python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/forward_program.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_tensors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/forward_program.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_tensors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_forward_program_access_and_state_carry_rows_are_compiler_owned tests/test_fabric_backend_plan.py::test_temporal_reverse_program_access_rows_are_compiler_owned tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed: 4 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_access_rows uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test after rebuilding the registered-program extension. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 128 tests. + - 2026-05-01 reverse-program transition-stage entrypoint cut: added + `fused_backward_program_transition_stage` as a registered CUDA entrypoint + over the selected transition executor row groups. The active + `registered_reverse_program_window` no longer imports or calls the old + `run_registered_transition_reverse_program_stage(...)` helper and no longer + goes through the generic `registered_temporal_reverse_program_stage_cuda` + dispatcher for recurrent-message/boundary/initial-KV. 
Transition reverse + now enters CUDA once per local step with compiler primitive rows, executor + rows, binding rows, memory-liveness rows, reverse artifact rows, seed rows, + and transition parameter rows; Python still owns the parameter reducers. + - Deleted the stale high-level transition reverse stage helper from + `surface_executor_runtime.py`, deleted the generic reverse stage dispatcher + from `flat_bucket_registered_program_cuda.py`, and tightened the guardrails + so the active path must call the direct registered entrypoints. This is a + real active-path cut, not a compatibility relabel: remaining reverse work is + now reducer orchestration, executable memory/liveness policy, and broader + program-level fusion across the still-separate reverse entrypoints. + - Verification for the transition-stage entrypoint cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed. + `uv run ruff check tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_transition_stage uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test after rebuilding the registered-program extension; rerun + against the same extension also passed with the new + `registered_fused_backward_program_transition_stage` launch-count guard. + - 2026-05-01 transition-boundary reverse fusion continuation: added + `fused_reverse_program_transition_boundary_step`, which runs the selected + transition reverse executor groups, assembles the full backend-order + recurrent-message gradient from compiler-declared transition output slots + and bucket ranges, and immediately runs recurrent-message/boundary-KV/ + initial-recurrent-KV backward through the same registered C++ program + entrypoint. The active Python window no longer launches + `registered_temporal_fused_backward_program_transition_stage_cuda(...)` + followed by + `registered_temporal_fused_backward_program_recurrent_message_boundary_initial_kv_step_cuda(...)`; + it calls one combined registered transition-boundary step and then runs the + compiler parameter reducers over the returned tensors. + - This is still not final program-level reverse closure: parameter reducers + remain Python-owned, and executable memory/liveness policy is still mostly + verified metadata rather than allocator/tape authority. But the active + transition + recurrent-boundary math now crosses into CUDA as one + compiler-row entrypoint instead of a pair of host-orchestrated stage calls. 
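+  - Host-side analogue only, with invented row keys: a minimal sketch of
+    assembling the full backend-order recurrent-message gradient from
+    declared transition output slots and bucket ranges, which the combined
+    registered step is described to do on-device before the
+    recurrent-message/boundary/initial-KV adjoints.
+
+    ```python
+    import torch
+
+    def assemble_backend_order_message_grad(total_cells, message_dim, group_outputs, slot_rows):
+        # group_outputs[i] maps a declared output slot to that group's gradient
+        # of shape [bucket_len, message_dim]; slot_rows[i] carries the slot and
+        # the group's bucket range in backend cell order.
+        full = torch.zeros(total_cells, message_dim)
+        for outputs, row in zip(group_outputs, slot_rows):
+            start, end = row["bucket_start"], row["bucket_end"]
+            full[start:end] = outputs[row["output_slot"]]
+        return full
+    ```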
+ - Verification for the transition-boundary fusion continuation: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed. + `uv run ruff check tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_transition_boundary uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test after rebuilding the registered-program extension. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. + `git diff --check` passed. + - 2026-05-01 parameter-reducer batch program cut: the active + `registered_reverse_program_window` no longer calls the readout-output, + sender-K/V, recurrent-query, or transition parameter reducer functions + surface-by-surface inside the reverse sweep. The window now builds typed + `TemporalParameterReducerRequest` records from compiler executor handles, + stage rows, transition parameter binding rows, and raw gradient outputs, + then calls one `run_temporal_parameter_reducer_program(...)` after the + reverse sweep. The old reducer functions are private implementations behind + that batch program, and the transition-boundary reverse helper now returns + transition reducer requests instead of trainable-parameter gradient tuples. + Static guardrails assert the active registered executor cannot call the old + per-surface reducer entrypoints. + - This is a real active-path ownership cut, not final reducer CUDA closure. + The batch reducer still performs host-side PyTorch reductions inside + `temporal/param_binding.py`; the next hard closure is a registered CUDA + parameter-reducer program that consumes reducer requests/stage rows, + executor binding rows, parameter binding rows, and memory-liveness rows + directly. + - Verification for the parameter-reducer batch program cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py` + passed. 
+ `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_transition_boundary uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. + - 2026-05-01 parameter-reducer row-program cut: tightened the previous + batch reducer into an explicit `TemporalParameterReducerProgram`. The + active reverse window now calls `build_temporal_parameter_reducer_program` + with typed requests and compiler reverse stage rows, producing CPU int64 + reducer rows with reducer kind, executor row, executor id, bucket, request + index, tensor count, and flags. `run_temporal_parameter_reducer_program` + executes from those rows and records the row tensor/summaries on runtime + metadata. The old `_run_temporal_*_param_reducer_program(...)` functions + were deleted as callable units and replaced by row executors such as + `_execute_readout_output_parameter_reducer_row(...)` and + `_execute_transition_parameter_reducer_row(...)`. + - This removes the active host-side `isinstance` dispatch as the owner of + legality: stage-row legality is now checked while building the reducer + program, and execution follows the row program. It is still not the final + CUDA reducer closure because the row executors perform ATen/PyTorch tensor + reductions in `temporal/param_binding.py`. The next hard closure is moving + the row-program executor itself behind a registered CUDA parameter-reducer + entrypoint that consumes reducer rows, executor binding rows, transition + parameter binding rows, trainable parameter rows, and memory-liveness rows. + - Verification for the parameter-reducer row-program cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed: 1 test. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_transition_boundary uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. + - 2026-05-01 parameter-reducer CUDA common row executor cut: moved the + common readout-output, sender K/V projection, recurrent-query, and + output-query parameter reductions behind + `registered_temporal_parameter_reducer_program_cuda(...)`. The registered + extension consumes the compiler-owned reducer row tensor plus typed tensor + lists and returns the fixed common trainable-gradient outputs. 
The old + Python row executors for those common reducers and the direct + `torch.einsum`/`index_add_` formulas were deleted from + `temporal/param_binding.py`; the active common reducer path now has to pass + through `parameter_reducer_program_execute`. + - This is not pretending transition parameter binding is solved. The + transition reducer remains the only host-side reducer row executor because + it still depends on dynamic population parameter-name mapping, + materialized-vs-static-source reducer bindings, and trainable parameter + shape alignment. The next hard reducer closure is to compile transition + reducer bindings into trainable-parameter rows and move that path into the + same registered CUDA parameter-reducer program. + - Verification for the parameter-reducer CUDA common row executor cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_param_reducer_cuda2 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. + - 2026-05-01 transition trainable-parameter row cut: replaced the remaining + transition reducer execution-time trainable-name synthesis with compiler + trainable reducer rows. `build_temporal_parameter_reducer_program(...)` + now receives `trainable_param_names`, compiles transition source names to + source ids, resolves materialized/static transition gradient sources to + target trainable parameter indices, and emits `transition_trainable_rows` + plus summaries. Runtime execution no longer constructs names such as + `population_modules.._delta` while reducing gradients; + it follows the trainable row opcode, target index, aux index, and source + id. The old `_registered_transition_population_param_grad_tuple(...)` and + singular `_execute_transition_parameter_reducer_row(...)` path were + deleted. + - This still leaves the final reducer throughput closure open: transition + trainable rows are compiler-owned and fail closed when a live source has no + trainable row, but their tensor algebra still runs in + `temporal/param_binding.py`. The remaining hard cut is to pass the + transition trainable rows/source tensors into + `parameter_reducer_program_execute` so transition parameter algebra runs in + the same registered CUDA reducer program as readout, sender K/V, and query + reductions. 
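+  - Rough sketch of the build-time row resolution described above, not the
+    real `build_temporal_parameter_reducer_program(...)` internals: every live
+    gradient source is resolved to a trainable-parameter index up front and
+    fails closed when no trainable row exists, so runtime execution follows
+    row indices instead of synthesizing parameter names. Field names here are
+    invented.
+
+    ```python
+    def build_transition_trainable_rows_sketch(grad_sources, trainable_param_names):
+        name_to_index = {name: i for i, name in enumerate(trainable_param_names)}
+        rows = []
+        for source_id, source in enumerate(grad_sources):
+            target_index = name_to_index.get(source["target_param_name"])
+            if target_index is None:
+                # Fail closed: a live source without a trainable row is a bug.
+                raise LookupError(f"no trainable row for source {source['source_name']!r}")
+            rows.append(
+                {
+                    "opcode": source["reducer_opcode"],
+                    "target_index": target_index,
+                    "aux_index": source.get("aux_index", -1),
+                    "source_id": source_id,
+                }
+            )
+        return rows
+    ```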
+ - Verification for the transition trainable-parameter row cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed: 1 test. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_transition_trainable_rows uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. + - 2026-05-01 transition parameter-reducer CUDA program cut: moved the + remaining transition trainable-parameter reducer tensor algebra out of + `temporal/param_binding.py` and into the registered + `parameter_reducer_program_execute` extension. The reducer program now + emits `transition_source_rows` and `transition_trainable_rows`; Python only + validates compiler binding coverage, gathers source tensors by source row, + and maps returned per-trainable outputs back to autograd parameters. The + registered CUDA reducer consumes the compiler source rows, trainable rows, + transition source tensor table, recurrent-index side table, and trainable + parameter tensor table, then executes materialized-base/delta, + value-to-cell/msg projection, and recurrent-bias reducer opcodes directly. + - The old `_execute_transition_parameter_reducer_rows(...)`, + `_reduce_transition_named_grad_sequence(...)`, and temporal-side transition + `matmul`/recurrent-bias formulas are now deleted from + `temporal/param_binding.py`. Static guardrails assert that transition + reducer algebra lives in the registered CUDA row program, not in the + temporal Python scheduler. This is still not final end-to-end throughput + closure because the reverse window still orchestrates several registered + entrypoints, but parameter-reduction algebra is no longer a host-side + compiler facade. + - Verification for the transition parameter-reducer CUDA program cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed: 1 test. 
+ `uv run pytest -q tests/test_fabric_backend_plan.py::test_registered_parameter_reducer_cuda_executes_transition_trainable_rows --tb=short` + passed: 1 test after rebuilding the registered-program extension with the + transition source/trainable reducer ABI, covering materialized, value-to- + cell, and recurrent-bias trainable reducer opcodes. + - 2026-05-01 reverse-program boundary K/V projection cut: added + `fused_backward_program_boundary_kv_projection_step` for the input/boundary + K/V projection adjoint. The stage resolves direct/grouped input K/V weights + from compiler tensor bindings, reads `boundary_step` from reverse artifact + rows, preserves compiler-owned grouped/direct weight selection, and returns + boundary/input parameter gradients for the existing parameter-binding + reducer. The active supported route no longer calls + `kernel_registry.boundary_public_backward(...)`. + - Remaining reverse Python orchestration after this cut: transition backward, + recurrent query binding, carry/state propagation, and final parameter + binding. The message-surface projection and attention adjoints in the + supported reverse-program window now run through fused compiler-owned + program stages. + - Verification for the boundary K/V projection cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py` + passed. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence13 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test, with active metadata asserting + `registered_fused_backward_program_boundary_kv_projection_step`. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence13 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_forward_reverse_table_plan_is_not_t1_specific --tb=short` + passed: 5 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. + `git diff --check` passed. + - 2026-05-01 reverse fallback deletion: removed the older per-step + registered reverse executor loop from + `run_registered_temporal_reverse_executor_window`. The registered reverse + path now has one active outcome for supported rows: + `registered_reverse_program_window`. If that compiler-owned window rejects + the request, the runtime fails closed with + `registered_reverse_program_window_reject:` instead of falling back + to `registered_reverse_executor_bindings`. + - The supported reverse-program output contracts now include `output_cells` + and `pooled_output_cells` for `mean`/`flatten` pooling. 
Pooled gradients are + converted into an output-cell gradient window before the compiler-owned + fused output-gradient program stage. Attention-style pooled readout remains + fail-closed until its readout-query parameter-gradient contract is carried + by the registered program rows. + - The fused reverse-program window now seeds its loop from incoming + final-state carry-cell gradients and next backend-state gradients instead + of rejecting them. This makes no-reset final-state losses and provided + state-gradient paths compiler-owned through the same reverse window. + - Current explicit fail-closed axes from the fused reverse-program gate: + unsupported output contracts/readout pools, missing output-gradient windows, + unsupported carry-cell gradient dtype/device, missing reverse artifact + roles, reset/transition-reset windows, non-backend public carry order, + unsupported boundary dtype/device, and missing output-message artifacts. + These are open compiler work items, not compatibility fallbacks. + - Verification for the reverse fallback deletion: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py` + passed. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence14 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence14 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_pooled_output_uses_registered_reverse_program_window --tb=short` + passed: 2 tests, covering mean-pooled public output and reset fail-closed + rejection. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence14 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_forward_reverse_table_plan_is_not_t1_specific --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence14 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_uses_high_level_reset_parity tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_shared_mixed_population_reset_parity tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients --tb=short` + passed: 5 tests, covering no-reset final-state/private-state gradients and + reset fail-closed rejection. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence14 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_pooled_output_uses_registered_reverse_program_window tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_forward_reverse_table_plan_is_not_t1_specific --tb=short` + passed: 5 tests. 
+ `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. + `git diff --check` passed. + - 2026-05-01 reverse transition registry-wrapper deletion: removed the + reverse transition executor callable from + `RegisteredTemporalExecutorKernelRegistry`. The registered reverse program + window now calls `run_registered_transition_backward_executor(...)` + directly, and that executor calls + `registered_temporal_fused_reverse_transition_program_cuda(...)` over + compiler primitive rows, reverse executor rows, binding rows, and + memory-liveness rows. This removes one more live wrapper layer from the + supported reverse window without changing the fused transition program + contract. + - Verification for the reverse transition wrapper deletion: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/executor_registry.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/executor_registry.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed: 1 test. + - 2026-05-01 reverse-program front-half fusion cut: exposed + `fused_backward_program_readout_message_kv_step` as the only public C++/ + pybind/Python ABI for the readout -> output-message -> recurrent-K/V + front half of each reverse step. The active `registered_reverse_program_window` + now launches that fused compiler-owned program stage once per local step + instead of calling separate exported readout, output-message, and recurrent + K/V projection stages from Python. + - This cut deletes the smaller Python wrappers and pybind exports + `registered_temporal_fused_backward_program_readout_step_cuda`, + `registered_temporal_fused_backward_program_output_message_step_cuda`, + and + `registered_temporal_fused_backward_program_recurrent_kv_projection_step_cuda`. + Their C++ bodies are now file-local implementation helpers behind the + fused entrypoint rather than public reverse-program launch surfaces. Source + guardrails assert the old wrappers/pybind names stay absent from the active + Python ABI and reverse window. + - The fused stage consumes the same compiler products as the smaller stages: + primitive rows, forward/reverse executor rows, forward/reverse binding rows, + memory-liveness rows, program tensor bindings, reverse artifact role rows, + and reverse artifact tensor binding rows. It returns boundary/carry direct + gradients, readout parameter gradients, output-query/input-K/V gradients, + and graph-order recurrent K/V projection gradients for the existing + compiler parameter-binding reducers. + - Remaining reverse Python orchestration after this cut: transition backward, + recurrent-message backward, initial recurrent K/V projection, boundary K/V + projection, query/K/V/readout parameter binding, carry/state propagation, + and final parameter accumulation. The later transition tape/state-table + cuts below remove the stale non-tensor transition artifacts that blocked + folding the transition stage into the same program-level reverse entrypoint. 
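+  - Illustrative shape of the source guardrail mentioned above; the real
+    guard lives with the backend boundary tests and may scan more than one
+    module, so the path and assertion style here are assumptions.
+
+    ```python
+    from pathlib import Path
+
+    DELETED_WRAPPERS = (
+        "registered_temporal_fused_backward_program_readout_step_cuda",
+        "registered_temporal_fused_backward_program_output_message_step_cuda",
+        "registered_temporal_fused_backward_program_recurrent_kv_projection_step_cuda",
+    )
+
+    def test_front_half_wrappers_stay_deleted():
+        source = Path(
+            "src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py"
+        ).read_text()
+        # The fused entrypoint must be referenced; the deleted wrappers must not.
+        assert "fused_backward_program_readout_message_kv_step" in source
+        for name in DELETED_WRAPPERS:
+            assert name not in source
+    ```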
+ - Verification for the front-half fusion cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_front_half_* uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test, with active metadata asserting + `registered_fused_backward_program_readout_message_kv_step` and absence of + the old three launch tags. + `git diff --check` passed. + - 2026-05-01 reverse transition tape artifact deletion: removed + `transition_backward_tape_by_population` from the compiler reverse artifact + role table, memory runtime artifact plan, `TemporalBucketStepArtifacts`, + and the registered transition forward executor. The registered transition + backward path recomputes from compiler primitive/executor/binding rows and + does not consume that Python `TransitionBackwardTape`, so carrying it as a + non-tensor reverse artifact was stale compatibility debt. + - After this cut, source tests asserted + `transition_backward_tape_by_population` was an unknown artifact role. The + next cut below removes the remaining non-tensor transition state artifact + as well. + - Verification for the transition tape artifact deletion: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/reverse_artifacts.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/types.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_plan.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/reverse_artifacts.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/types.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_plan.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes tests/test_fabric_backend_plan.py::test_temporal_reverse_artifact_roles_are_compiler_binding_rows --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. 
+ `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_no_transition_tape_* uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test after rebuilding the CUDA extension with the role-14 + reverse artifact ABI. + - 2026-05-01 reverse transition state artifact table cut: replaced + `backend_state_cache_before` with tensorized transition state-before + artifact tables on `TemporalBucketStepArtifacts`. Each step now carries + `transition_state_before_tensors` plus CPU binding rows keyed by transition + bucket ordinal and executor binding index. Registered transition backward + reconstructs the state-before tensor map from those binding rows instead + of indexing `artifacts.backend_state_cache_before` or falling back to + population TensorDict state. + - The compiler reverse artifact role table no longer contains any non-tensor + transition object roles. The C++ reverse artifact ABI now ends at role 13 + `output_cells`; source tests assert both `backend_state_cache_before` and + `transition_backward_tape_by_population` are unknown artifact roles. The + remaining reverse work is no longer object-artifact cleanup: it is fusing + the still-separate transition, recurrent-message, boundary/initial K/V, and + parameter-binding stages into fewer program-level executor entrypoints. + - Verification for the transition state artifact table cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/reverse_artifacts.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/types.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_plan.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/reverse_artifacts.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/types.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_plan.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes tests/test_fabric_backend_plan.py::test_temporal_reverse_artifact_roles_are_compiler_binding_rows --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_transition_state_table_fix_* uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test after rebuilding the CUDA extension with the role-13 + reverse artifact ABI and reading transition state-before tensors through + the matching forward transition executor bindings. + - 2026-05-01 reverse-program recurrent-message/initial-KV fusion cut: + exposed `fused_backward_program_recurrent_message_initial_kv_step` as the + only public C++/pybind/Python ABI for the recurrent-message adjoint followed + by the initial recurrent K/V projection adjoint. 
The active + `registered_reverse_program_window` now launches one compiler-owned stage + that consumes `grad_recurrent_msg`, recurrent-message topology tensors, + compiler primitive/executor/binding/memory rows, program tensor bindings, + and reverse artifact bindings, then returns recurrent-query/input-K/V + gradients plus graph-order hidden and recurrent K/V weight gradients. + - This cut deletes the smaller public Python wrappers and pybind exports + `registered_temporal_fused_backward_program_recurrent_message_step_cuda` + and + `registered_temporal_fused_backward_program_initial_recurrent_kv_projection_step_cuda`. + Their C++ bodies are file-local helpers behind the fused entrypoint; source + guardrails assert the old wrappers/pybind names stay absent from the active + ABI and reverse loop. The remaining reverse stage split is now transition + backward, boundary K/V projection, and Python-side parameter binding/ + reduction. + - Verification for the recurrent-message/initial-KV fusion cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_recurrent_message_initial_kv_* uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test, with active metadata asserting + `registered_fused_backward_program_recurrent_message_initial_kv_step` and + absence of the old recurrent-message and initial-K/V launch tags. + - 2026-05-01 reverse-program recurrent-message/boundary/initial-KV fusion + cut: superseded the recurrent-message/initial-KV stage with + `fused_backward_program_recurrent_message_boundary_initial_kv_step`. The + active registered reverse program now consumes the recurrent-message + adjoint plus output-message K/V adjoints inside one C++/pybind/Python + entrypoint, runs recurrent-message backward, boundary K/V projection + backward, and initial recurrent K/V projection backward behind compiler + primitive/executor/binding/memory rows, and returns recurrent-query, + optional boundary-state, boundary K/V weight, recurrent hidden-before, and + initial recurrent K/V weight gradients. + - This cut deletes the public boundary K/V projection stage from the + registered program ABI: + `registered_temporal_fused_backward_program_boundary_kv_projection_step_cuda` + and `fused_backward_program_boundary_kv_projection_step` are no longer + exported or callable from the active route. It also removes the now-obsolete + public recurrent-message/initial-KV wrapper name from the active caller. 
+ The old C++ boundary helper remains file-local only as implementation + detail inside the fused registered executor. Source guardrails now require + the combined ABI and assert the split recurrent-message, boundary-K/V, and + initial-K/V public stages stay absent. + - Remaining reverse Python orchestration after this cut: transition backward, + recurrent query/boundary/initial projection parameter binding, carry/state + propagation, and final reductions. The next hard compiler closure target is + fusing transition backward and parameter-gradient binding/reduction into + program-level executor products so the Python loop stops sequencing + primitive math stages. + - Verification for the recurrent-message/boundary/initial-KV fusion cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_recurrent_message_boundary_initial_kv_* uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test after a fresh CUDA extension rebuild. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. + `git diff --check` passed. + - 2026-05-01 reverse per-surface registry/export deletion cut: after the + readout/message/KV reverse work moved behind fused registered program + entrypoints, the old direct registered reverse Python/pybind APIs were + removed instead of kept as a parallel callable route. Deleted public + Python wrappers and pybind exports include + `registered_backward_sender_kv_projection_cuda`, + `registered_backward_partitioned_attention_cuda`, + `registered_backward_sparse_attention_cuda`, + `registered_backward_readout_layout_projection_cuda`, + `backward_sender_kv_projection`, `backward_partitioned_attention`, + `backward_sparse_attention`, and `backward_readout_layout_projection`. + The underlying C++ helper kernels are no longer exported through pybind or + Python wrappers; they are implementation details behind fused registered + program entrypoints. + - The registered temporal kernel registry no longer carries callable reverse + hooks for `output_message_backward`, `recurrent_message_backward`, + `recurrent_kv_projection_backward`, `readout_layout_projection_backward`, + or `boundary_public_backward`. The active reverse loop can still bind + parameter gradients and run transition reverse, but it can no longer + reconstruct the old readout/message/KV backward route from per-surface + Python executor methods. 
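+  - Minimal sketch of the source-guardrail pattern these deletion cuts rely
+    on. The test name and module path below are placeholders; the symbol names
+    are the ones listed in the deletion cut above. The guardrail simply reads
+    the active module source and fails if a deleted public reverse ABI name
+    ever reappears.
+
+    ```python
+    from pathlib import Path
+
+    DELETED_REVERSE_ABI_SYMBOLS = (
+        "registered_backward_sender_kv_projection_cuda",
+        "registered_backward_partitioned_attention_cuda",
+        "registered_backward_sparse_attention_cuda",
+        "registered_backward_readout_layout_projection_cuda",
+    )
+
+    def test_deleted_reverse_abi_symbols_stay_absent(
+        module_path: str = "path/to/active_reverse_module.py",  # placeholder path
+    ) -> None:
+        source = Path(module_path).read_text(encoding="utf-8")
+        reappeared = [name for name in DELETED_REVERSE_ABI_SYMBOLS if name in source]
+        assert not reappeared, f"deleted reverse ABI symbols reappeared: {reappeared}"
+    ```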
+ - Verification for the reverse per-surface registry/export deletion cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/executor_registry.py tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/executor_registry.py tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_registry_delete_* uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test after a fresh CUDA extension rebuild. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. + `git diff --check` passed. + - 2026-05-01 transition parameter-gradient binding row cut: added + compiler-owned transition parameter-gradient binding rows that connect + reverse transition program output slots to `parameter_reduction` primitive + rows. The plan records the reverse executor row, bucket, transition + primitive row, parameter-reduction row, grad logical tensor, grad binding + index, parameter binding index, reducer kind, and source bindings. + - The active registered transition backward path now consumes those binding + rows through `_transition_param_grad_accumulator_from_binding_rows(...)`. + It no longer infers the parameter contract from Python checks like + `"grad_recurrent_kernel" in reverse_logical_to_slot` or + `"grad_nu_log" in reverse_logical_to_slot`. Input projection weight/bias + grads are still unfused through the existing projection helper, but the + choice is now driven by explicit compiler reducer kinds + `input_projection_weight`, `input_projection_bias`, and `materialized`. + - Runtime/compiler metadata now exposes + `_last_flat_bucket_temporal_transition_param_grad_binding_rows` and + `_last_flat_bucket_temporal_transition_param_grad_binding_summaries`. + Remaining parameter-gradient work is to move query/readout/KV + trainable-parameter reductions out of Python binding helpers and into an + executable reducer program over the same binding products. + - Verification for the transition parameter-gradient binding row cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_bindings.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. 
+ `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_bindings.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_executor_binding_plan_groups_compiled_bindings_by_executor_row tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test. + - 2026-05-01 transition executable reducer first cut: removed the active + `bind_temporal_transition_param_grads(...)` route from the registered + reverse program. Transition trainable-parameter gradients now run through + `run_temporal_transition_param_reducer_program(...)`, with each population + reduced immediately from the reverse executor row's + `TemporalTransitionParamGradBinding` rows. The reducer validates that every + materialized/static-source gradient has a matching compiler reducer row, + records `_last_flat_bucket_transition_param_reducer_program`, and fails + closed if a reverse executor produces gradients outside the compiled + reducer contract. Source guardrails now reject the old binder symbol in the + active surface runtime and binding module. + - Verification for the transition executable reducer cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_plan.py::test_temporal_executor_binding_plan_groups_compiled_bindings_by_executor_row --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test. + - 2026-05-01 sender K/V projection executable reducer cut: removed + `_recurrent_kv_projection_param_binding` from the temporal executor + registry and deleted `run_registered_recurrent_kv_projection_param_binding_executor(...)` + from the surface runtime. The registered reverse program now calls + `run_temporal_sender_kv_projection_param_reducer_program(...)` directly for + recurrent, boundary-input, and initial recurrent K/V projection weight + gradients. That reducer validates the role against compiler-bound message + parameter bindings (`recurrent_sender_kv_weight`, `input_sender_kv_weight`, + or `input_group_kv_weight`) before mapping trainable gradients, and records + `_last_flat_bucket_sender_kv_param_reducer_program`. 
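+  - Rough sketch of the fail-closed, binding-row-driven reducer shape shared
+    by these reducer cuts. The row fields, the reducer-kind string, and the
+    example gradient names are illustrative placeholders, not the compiled row
+    schema; only the behavior is taken from the cuts above: reduce gradients
+    strictly through compiler binding rows and fail closed on anything outside
+    that contract.
+
+    ```python
+    from dataclasses import dataclass
+    from typing import Dict, Sequence
+
+    import torch
+
+    @dataclass(frozen=True)
+    class ParamGradBindingRow:
+        grad_logical_tensor: str  # gradient slot emitted by the reverse executor row
+        parameter_name: str       # trainable parameter the gradient reduces into
+        reducer_kind: str         # explicit compiler-declared reduction kind
+
+    def reduce_param_grads(
+        rows: Sequence[ParamGradBindingRow],
+        produced: Dict[str, torch.Tensor],
+    ) -> Dict[str, torch.Tensor]:
+        """Accumulate only gradients covered by binding rows; reject the rest."""
+        covered = {row.grad_logical_tensor for row in rows}
+        extra = sorted(set(produced) - covered)
+        if extra:
+            raise RuntimeError(f"gradients outside the reducer contract: {extra}")
+        reduced: Dict[str, torch.Tensor] = {}
+        for row in rows:
+            if row.reducer_kind != "materialized":
+                raise RuntimeError(f"unsupported reducer kind: {row.reducer_kind}")
+            grad = produced[row.grad_logical_tensor]
+            previous = reduced.get(row.parameter_name)
+            reduced[row.parameter_name] = grad if previous is None else previous + grad
+        return reduced
+
+    rows = (ParamGradBindingRow("grad_q_proj", "q_proj.weight", "materialized"),)
+    grads = reduce_param_grads(rows, {"grad_q_proj": torch.ones(2, 2)})
+    assert grads["q_proj.weight"].shape == (2, 2)
+    ```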
+ - Verification for the sender K/V projection reducer cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/executor_registry.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/types.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/executor_registry.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/types.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_plan.py::test_temporal_executor_binding_plan_groups_compiled_bindings_by_executor_row --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test. + - 2026-05-01 recurrent query executable reducer cut: removed + `_recurrent_query_param_backward` from the temporal executor registry and + deleted `run_registered_recurrent_query_param_backward_executor(...)` from + the surface runtime. The registered reverse program now calls + `run_temporal_recurrent_query_param_reducer_program(...)` directly. The + reducer validates `recurrent_q_weight` and `output_q` through compiler + executor bindings before accumulating `slot_embed` and `q_proj.weight` + gradients, and records + `_last_flat_bucket_recurrent_query_param_reducer_program`. + - Verification for the recurrent query reducer cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/executor_registry.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/executor_registry.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_plan.py::test_temporal_executor_binding_plan_groups_compiled_bindings_by_executor_row --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test. 
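+  - The `_last_flat_bucket_*` records named in these cuts share one shape:
+    after the active route runs, it stores a small structured summary that an
+    audit can assert directly instead of re-deriving the route from logs. A
+    rough sketch under hypothetical names (the attribute, fields, and stage
+    string below are placeholders, not the runtime-metadata schema):
+
+    ```python
+    from dataclasses import dataclass, field
+    from typing import Any, Dict, Sequence
+
+    @dataclass
+    class RuntimeReviewMetadata:
+        """Holds the most recent per-route summary for audit assertions."""
+        last_reducer_program: Dict[str, Any] = field(default_factory=dict)
+
+    def record_reducer_program(
+        metadata: RuntimeReviewMetadata,
+        stage_name: str,
+        reduced_parameters: Sequence[str],
+    ) -> None:
+        # Overwrite with the latest run so tests assert the active route, not a stale one.
+        metadata.last_reducer_program = {
+            "stage_name": stage_name,
+            "reduced_parameters": list(reduced_parameters),
+        }
+
+    meta = RuntimeReviewMetadata()
+    record_reducer_program(meta, "recurrent_query_param_reducer", ("slot_embed", "q_proj.weight"))
+    assert meta.last_reducer_program["stage_name"] == "recurrent_query_param_reducer"
+    ```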
+ - 2026-05-01 readout output executable reducer cut: moved the + `value_to_output_weight` / `output_cell_bias` trainable-gradient binding + out of the registered reverse loop and into + `run_temporal_readout_output_param_reducer_program(...)`. The reducer + validates `value_to_output_weight` through compiler static tensor bindings + and `output_cell_bias` through compiler runtime-attribute bindings before + mapping `msg_out.weight`, `output_cell_weight`, and `output_cell_bias` + gradients. Source guardrails now reject the old inline + `output_projection_param_grads = tuple(...)` reducer in + `registered_executors.py`. + - Verification for the readout output reducer cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_plan.py::test_temporal_executor_binding_plan_groups_compiled_bindings_by_executor_row --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test. + - 2026-05-01 reverse program stage-row cut: added + `TemporalReverseProgramStagePlan` with explicit stage rows for + `output_grad_window`, `readout_message_kv_step`, per-bucket + `transition_step`, `recurrent_message_boundary_initial_kv_step`, and + `parameter_reducer_step`. Runtime metadata now records + `_last_flat_bucket_temporal_reverse_program_stage_rows` and summaries. + The active registered reverse path now dispatches CUDA-owned reverse chunks + through `registered_temporal_reverse_program_stage_cuda(...)` using those + rows, so `registered_executors.py` no longer directly calls the named + readout/message/KV or recurrent-message/boundary/initial-KV fused CUDA + wrappers. This cut left transition reverse as the next active bypass + because state-cache/TensorDict assembly had not yet been moved into the + executable stage ABI. + - Verification for the reverse program stage-row cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. 
+ `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_plan.py::test_temporal_executor_binding_plan_groups_compiled_bindings_by_executor_row --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. + - 2026-05-01 transition reverse stage-dispatch cut: renamed the active + transition reverse boundary to + `run_registered_transition_reverse_program_stage(...)`, passed + `reverse_program_stage_rows` into it from the registered reverse window, and + routed the transition reverse CUDA launch through + `registered_temporal_reverse_program_stage_cuda(stage_name="transition_step", ...)`. + `registered_executors.py` no longer imports or calls the old + `run_registered_transition_backward_executor(...)`, and + `surface_executor_runtime.py` no longer calls + `registered_temporal_fused_reverse_transition_program_cuda(...)` directly. + The remaining transition closure gap is the stage ABI itself: forward + recompute tensors, state-cache/TensorDict packing, and transition parameter + reducer assembly are still Python boundary work around the stage dispatch + and must be flattened into compiler-owned tensor/binding rows before this + can count as full transition-stage closure. + - Verification for the transition reverse stage-dispatch cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_plan.py::test_temporal_executor_binding_plan_groups_compiled_bindings_by_executor_row --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test. + - 2026-05-01 parameter-reducer stage-row gate: expanded + `TemporalReverseProgramStagePlan` so every reverse executor row receives a + `parameter_reducer_step` row on the `parameter_reduction` surface instead + of the previous single message-owned reducer row. The active sender K/V, + recurrent-query, readout-output, and transition parameter reducer programs + now require matching compiler stage rows before reducing gradients. This + removes another implicit direct reducer trigger from the registered reverse + loop. It is still not full reducer closure: the reductions themselves are + Python-host tensor reductions over compiler binding rows, not a fused + program-level CUDA reducer launch. 
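+  - Minimal sketch of the stage-row gate described above, with hypothetical
+    field names and stage strings; only the gating behavior is taken from the
+    cut: a reducer or reverse chunk runs only when the compiler emitted a
+    matching stage row, and anything without a row fails closed.
+
+    ```python
+    from dataclasses import dataclass
+    from typing import Sequence
+
+    @dataclass(frozen=True)
+    class ReverseProgramStageRow:
+        stage_name: str            # e.g. "transition_step" or "parameter_reducer_step"
+        reverse_executor_row: int
+
+    def require_stage_row(
+        rows: Sequence[ReverseProgramStageRow],
+        stage_name: str,
+        reverse_executor_row: int,
+    ) -> ReverseProgramStageRow:
+        """Return the compiler stage row for this dispatch, or fail closed."""
+        for row in rows:
+            if row.stage_name != stage_name:
+                continue
+            if row.reverse_executor_row == reverse_executor_row:
+                return row
+        raise RuntimeError(
+            f"no compiler stage row for stage={stage_name!r} row={reverse_executor_row}"
+        )
+
+    plan = (
+        ReverseProgramStageRow("output_grad_window", 0),
+        ReverseProgramStageRow("parameter_reducer_step", 0),
+    )
+    require_stage_row(plan, "parameter_reducer_step", 0)      # allowed: a row exists
+    try:
+        require_stage_row(plan, "parameter_reducer_step", 1)  # fails closed: no row
+    except RuntimeError:
+        pass
+    ```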
+ - Verification for the parameter-reducer stage-row gate: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_plan.py::test_temporal_executor_binding_plan_groups_compiled_bindings_by_executor_row --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test. + - 2026-05-01 transition state-before reverse artifact ABI cut: added the + compiler reverse artifact role `transition_state_before` and encoded its + bucket/executor binding identity in the existing reverse artifact binding + rows. `TemporalBucketStepArtifacts` no longer carries + `transition_state_before_tensors` or `transition_state_before_binding_rows`; + the registered reverse artifact table now materializes transition + state-before tensors from `population_state_before` under the shared + artifact role, and `run_registered_transition_reverse_program_stage(...)` + reads those tensors from the common reverse artifact table. The old + `build_registered_transition_state_before_artifact_table(...)` side ABI is + removed from active code. Remaining gap: transition forward recompute and + reverse tensor-table extension still happen in Python around the stage + dispatch; the next closure step is a transition-stage ABI that consumes + artifact/program rows directly without host-side table assembly. + - Verification for the transition state-before reverse artifact ABI cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/reverse_artifacts.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/types.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/reverse_artifacts.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/types.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. 
+ `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_plan.py::test_temporal_reverse_artifact_roles_are_compiler_binding_rows tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes tests/test_fabric_backend_plan.py::test_temporal_executor_binding_plan_groups_compiled_bindings_by_executor_row --tb=short` + passed: 4 tests. + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test. + - 2026-05-01 transition-stage artifact-table input cut: + `run_registered_transition_reverse_program_stage(...)` no longer accepts + the full `TemporalBucketStepArtifacts` object. The stage now receives the + shared `reverse_artifact_tensors`, `reverse_artifact_binding_rows`, and + `reverse_artifact_role_rows`, validates required roles, loads + `recurrent_msg_backend_order` and `transition_state_before` from those rows + by `local_step`, and then performs the current forward-recompute/reverse + transition subprogram. This removes the outer reverse loop's last + transition-specific artifact object dependency. Remaining gap: the + transition stage still assembles forward/reverse transition program tensor + tables in Python inside the stage boundary; the next cut is to move that + table assembly into a registered program-level stage executor. + - Verification for the transition-stage artifact-table input cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_plan.py::test_temporal_reverse_artifact_roles_are_compiler_binding_rows tests/test_fabric_backend_plan.py::test_temporal_executor_binding_plan_groups_compiled_bindings_by_executor_row --tb=short` + passed: 3 tests. + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test. + - 2026-05-01 transition reverse stage-executor boundary cut: + introduced `RegisteredTransitionReverseStageResult` and + `registered_transition_reverse_stage_program(...)` as the single owner for + transition reverse stage recompute, reverse tensor-table extension, + reverse artifact reads, registered `transition_step` CUDA launch, state + gradient extraction, and parameter-gradient accumulation maps. The outer + `run_registered_transition_reverse_program_stage(...)` now only selects the + matching forward executor, invokes the typed stage program, places the + returned bucket gradient into the full recurrent-message gradient tensor, + and runs the compiler parameter reducer. Static guardrails now assert the + outer function does not call `_transition_program_tensor_table(...)`, + `_extend_transition_reverse_program_tensor_table(...)`, or + `registered_temporal_reverse_program_stage_cuda(...)` directly. 
Remaining + gap: the transition stage executor is still a Python stage boundary around + registered program CUDA calls; the next closure cut is to move this boundary + into a program-level CUDA entrypoint that consumes primitive rows, executor + rows, binding rows, reverse artifact rows, and memory-liveness rows + directly. + - Verification for the transition reverse stage-executor boundary cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_plan.py::test_temporal_reverse_artifact_roles_are_compiler_binding_rows tests/test_fabric_backend_plan.py::test_temporal_executor_binding_plan_groups_compiled_bindings_by_executor_row --tb=short` + passed: 3 tests. + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test. + - 2026-05-01 registered transition-step CUDA program entrypoint cut: + added `registered_temporal_fused_backward_program_transition_step_cuda(...)` + and the bound extension symbol `fused_backward_program_transition_step`. + The transition reverse stage now extends the compiler tensor table once, + then enters a single registered program-layer CUDA call that performs + forward transition recompute followed by reverse transition execution over + the same primitive rows, forward/reverse executor rows, tensor binding rows, + memory-liveness rows, and `transition_step` stage row. The stage no longer + makes separate host-level calls to + `registered_temporal_fused_forward_transition_program_cuda(...)` and + `registered_temporal_reverse_program_stage_cuda(...)` for transition + backward. Remaining gap: Python still materializes the transition + program-tensor table and gradient-output slots before the program-layer + CUDA call; the next hard cut is to make the C++/CUDA stage consume reverse + artifact rows and executor bindings directly for transition inputs, + state-before values, and gradient seed slots. + - Verification for the registered transition-step CUDA program entrypoint + cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_transition_step_program uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test after rebuilding the registered-program extension. 
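+  - The recurring move in these transition-step cuts is replacing hand-built,
+    name-keyed tensor tables with row-driven binding: the stage receives a
+    flat tensor list plus compiler rows that state which slot holds which
+    logical input at which step. A rough sketch; the row fields and the helper
+    name are hypothetical, while the logical names reuse ones mentioned in
+    this document.
+
+    ```python
+    from dataclasses import dataclass
+    from typing import Dict, Sequence
+
+    import torch
+
+    @dataclass(frozen=True)
+    class TensorBindingRow:
+        logical_name: str  # e.g. "aggregated_message" or "transition_state_before"
+        slot_index: int    # position in the flat program tensor list
+        local_step: int    # step the binding belongs to
+
+    def bind_stage_inputs(
+        rows: Sequence[TensorBindingRow],
+        program_tensors: Sequence[torch.Tensor],
+        local_step: int,
+    ) -> Dict[str, torch.Tensor]:
+        """Bind a stage's inputs purely from compiler rows; nothing is bound implicitly."""
+        bound: Dict[str, torch.Tensor] = {}
+        for row in rows:
+            if row.local_step == local_step:
+                bound[row.logical_name] = program_tensors[row.slot_index]
+        return bound
+
+    rows = (
+        TensorBindingRow("aggregated_message", 0, local_step=0),
+        TensorBindingRow("transition_state_before", 1, local_step=0),
+    )
+    tensors = [torch.zeros(4, 8), torch.ones(4, 3)]
+    inputs = bind_stage_inputs(rows, tensors, local_step=0)
+    assert set(inputs) == {"aggregated_message", "transition_state_before"}
+    ```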
+ - 2026-05-01 transition-step artifact/seed row binding cut: moved the + dynamic transition-step inputs out of Python tensor-table construction and + into the registered C++ program entrypoint. Python now provides only the + compiler parameter/slot table plus `transition_seed_rows`; the C++ stage + consumes `reverse_artifact_tensors`, `reverse_artifact_binding_rows`, + forward/reverse executor binding rows, and `local_step` to bind + `recurrent_msg_backend_order`, all available `transition_state_before` + inputs, and the incoming public/state gradient seeds before running the + forward recompute and reverse transition program. This removes the Python + decoder/table path for transition state-before artifacts from + `surface_executor_runtime.py`; diagonal trace state inputs are now picked + up through forward binding rows rather than a hardcoded Python state table. + Remaining gap: transition parameters are still resolved into a Python + parameter/slot table before this CUDA entrypoint. The next hard cut is to + make parameter binding rows themselves sufficient for the C++ program + tensor table, then delete the remaining `_transition_program_tensor_table` + forward-stage compatibility path. + - Verification for the transition-step artifact/seed row binding cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_transition_artifact_stage uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test after rebuilding the registered-program extension. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. + - 2026-05-01 forward transition-step dynamic binding cut: removed the + active `_transition_program_tensor_table(...)` path from + `surface_executor_runtime.py`. The forward transition stage now builds the + same compiler parameter/slot table used by the reverse stage, emits compact + `transition_state_rows` for state-before inputs, and enters + `registered_temporal_fused_forward_program_transition_step_cuda(...)`. + The C++ stage binds the `aggregated_message` input from the selected + forward executor binding row, binds state-before tensors from + `transition_state_rows`, then runs the registered forward transition + program over primitive rows, executor rows, binding rows, and + memory-liveness rows. This deletes the remaining Python full transition + tensor-table materializer from the active forward transition path. + Remaining gap: parameter values are still populated into the program slot + table by Python. 
The next compiler-closure cut is parameter binding + execution: C++ should consume compiler parameter binding rows plus provided + tensor slots directly, so the temporal/runtime layer no longer resolves + transition parameter names or cell-param/static-tensor source strings. + - Verification for the forward transition-step dynamic binding cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed: 3 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_forward_transition_step uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test after rebuilding the registered-program extension. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. + - 2026-05-01 transition parameter row binding cut: moved transition + parameter population out of the per-step transition program tensor-table + builder and into a compiler parameter tensor table consumed by the + registered C++ transition-step entrypoints. The active forward and reverse + transition stages now pass `transition_parameter_tensors` plus + `transition_parameter_rows`; C++ validates those rows and calls + `bind_transition_parameter_tensors(...)` before running the registered + forward recompute or reverse transition program. This deletes the active + `_transition_program_tensor_table(...)` path and removes transition + parameter source-name resolution from `surface_executor_runtime.py`. + Remaining gap: message/readout parameter materialization still has + surface-local Python table builders, and transition parameter tensor + materialization is still Python-side compiler binding materialization rather + than a single whole-program tensor materializer shared by all surfaces. + - Verification for the transition parameter row binding cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_parameters.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py` + passed. 
+ `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_parameters.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed: 3 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_transition_param_rows uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test after rebuilding the registered-program extension. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. + - 2026-05-01 shared program parameter tensor materialization cut: moved the + message/readout program parameter table builder out of + `registered_executors.py` into the same compiler/runtime parameter-table + module that builds transition parameter rows. Fused forward and reverse + program launches now ask `surface_parameter_tensor_table(...)` for + compiler-bound parameter tensors instead of each launch path interpreting + static/runtime parameter source bindings locally. This keeps + `registered_executors.py` focused on assembling launch products and leaves + parameter source interpretation in one compiler-owned materialization + boundary. Remaining gap: full whole-program tensor materialization still + stitches surface parameters, transition parameters, and initial state rows + in Python before launch; the next closure cut is a single executable + program tensor materializer shared by forward/reverse launches and later + backed by C++ row consumption. + - Verification for the shared program parameter tensor materialization cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_parameters.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_parameters.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_transition_param_rows uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. 
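+  - Rough sketch of a single parameter-materialization boundary of the kind
+    described above. The repository's `surface_parameter_tensor_table(...)` is
+    named in the cut, but the function, signature, row schema, and source-kind
+    strings below are illustrative assumptions, not its real interface; the
+    design point is that launch paths stop interpreting parameter sources
+    locally and one boundary fails closed on unknown source kinds.
+
+    ```python
+    from dataclasses import dataclass
+    from typing import Dict, Sequence
+
+    import torch
+
+    @dataclass(frozen=True)
+    class ParameterBindingRow:
+        parameter_name: str
+        source_kind: str  # "static_tensor" or "runtime_attribute" in this sketch
+        source_key: str   # key into the matching source table
+
+    def materialize_parameter_table(
+        rows: Sequence[ParameterBindingRow],
+        static_tensors: Dict[str, torch.Tensor],
+        runtime_attributes: Dict[str, torch.Tensor],
+    ) -> Dict[str, torch.Tensor]:
+        """Resolve every parameter source in one place; callers never branch on kind."""
+        table: Dict[str, torch.Tensor] = {}
+        for row in rows:
+            if row.source_kind == "static_tensor":
+                table[row.parameter_name] = static_tensors[row.source_key]
+            elif row.source_kind == "runtime_attribute":
+                table[row.parameter_name] = runtime_attributes[row.source_key]
+            else:
+                raise RuntimeError(f"unsupported parameter source kind: {row.source_kind}")
+        return table
+
+    rows = (
+        ParameterBindingRow("value_to_output_weight", "static_tensor", "msg_out.weight"),
+        ParameterBindingRow("output_cell_bias", "runtime_attribute", "output_cell_bias"),
+    )
+    table = materialize_parameter_table(
+        rows,
+        static_tensors={"msg_out.weight": torch.zeros(8, 8)},
+        runtime_attributes={"output_cell_bias": torch.zeros(8)},
+    )
+    assert set(table) == {"value_to_output_weight", "output_cell_bias"}
+    ```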
+ - 2026-05-01 executable program tensor-table cut: added + `TemporalExecutableProgramTensorTable` in `temporal/program_tensors.py` + and migrated fused forward/reverse launches to consume that single + compiler-owned product. `registered_executors.py` no longer owns + `_fused_forward_program_tensor_table(...)` or + `_fused_reverse_program_tensor_table(...)`; it requests + `build_forward_executable_program_tensor_table(...)` or + `build_reverse_executable_program_tensor_table(...)` and passes the + resulting tensors/rows into registered CUDA entrypoints. The table product + now carries program tensors, program binding rows, transition parameter + tensors/rows, and review metadata. Remaining gap: transition step C++ still + receives transition parameters as a side table, and reverse execution is + still sequenced through multiple Python stage calls. The next hard closure + cut is collapsing those reverse stages behind a larger registered C++ + program entrypoint that consumes this executable tensor table directly. + - Verification for the executable program tensor-table cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_parameters.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_tensors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_parameters.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_tensors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_program_tensor_table uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test after rebuilding the registered-program extension. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. + - 2026-05-01 reverse full-step registered CUDA program cut: added + `registered_temporal_fused_reverse_program_full_step_cuda(...)` and the + pybind/C++ entrypoint `fused_reverse_program_full_step`. The active + registered reverse window now calls one compiler-owned full-step program per + local reverse step. That program consumes reverse artifact rows, primitive + rows, executor rows, executable tensor binding rows, transition program + tensor groups, transition seed rows, public-y seed rows, parameter rows, and + memory-liveness rows; then it runs the front readout/output-message/ + recurrent-KV half and the transition/recurrent-message/boundary/initial-KV + half in C++ without returning through Python for the intermediate reverse + stage boundary. 
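+  - Shape-only sketch of the reverse window after a full-step cut like this
+    one: seed once from the output-gradient window plus any incoming carry
+    gradients, then make exactly one fused call per local reverse step and
+    feed each step's carry gradients into the previous step. Names, shapes,
+    and the arithmetic inside the step are hypothetical stand-ins, not Fabric
+    semantics.
+
+    ```python
+    from typing import List, Tuple
+
+    import torch
+
+    def fused_reverse_full_step(
+        step: int,
+        grad_output_window: torch.Tensor,  # [T, N] per-step output-cell gradient seeds
+        grad_carry_in: torch.Tensor,       # [N] carry gradient flowing back from step + 1
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Stand-in for the per-step launch: (parameter-grad contribution, carry out)."""
+        grad_step = grad_output_window[step] + grad_carry_in
+        return grad_step.sum(), 0.5 * grad_step  # toy math only
+
+    def reverse_window(
+        grad_output_window: torch.Tensor,
+        grad_final_carry: torch.Tensor,
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        steps = grad_output_window.shape[0]
+        carry = grad_final_carry             # incoming final-state carry gradients seed the loop
+        param_grad = torch.zeros(())
+        carries: List[torch.Tensor] = []
+        for step in reversed(range(steps)):  # one fused launch per local reverse step
+            contribution, carry = fused_reverse_full_step(step, grad_output_window, carry)
+            param_grad = param_grad + contribution
+            carries.append(carry)
+        return param_grad, carries
+
+    param_grad, carries = reverse_window(torch.ones(3, 4), grad_final_carry=torch.zeros(4))
+    assert len(carries) == 3
+    ```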
+ - This cut removed the public Python wrappers and pybind exports for the old + intermediate reverse stages: + `registered_temporal_fused_backward_program_output_grad_cuda(...)`, + `registered_temporal_fused_backward_program_readout_message_kv_step_cuda(...)`, + `registered_temporal_fused_reverse_program_window_step_cuda(...)`, + `registered_temporal_fused_reverse_program_transition_boundary_step_cuda(...)`, + and + `registered_temporal_fused_backward_program_recurrent_message_boundary_initial_kv_step_cuda(...)`. + The remaining same-file C++ helper functions are now internal + implementation details called only by the full-step entrypoint, not public + runtime ABI. + - Remaining reverse compiler work after the full-step cut: delete or fold the + now-internal C++ helper functions once the full-step body owns the + row-dispatched kernels directly, then continue replacing any fixed + transition/message slot names inside those helpers with executor/binding row + lookups. Parameter reducers are already executed through the registered + reducer row program. + - Verification for the reverse full-step registered CUDA program cut: + `uv run python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_registered_parameter_reducer_cuda_executes_transition_trainable_rows --tb=short` + passed: 1 test after rebuilding the registered-program extension. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_full_step_cut uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test after rebuilding the registered-program extension and + asserting `registered_fused_reverse_program_full_step` metadata. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_execution_imports.py --tb=short` + passed: 32 tests. + `git diff --check` passed. + - 2026-05-01 reverse full-step internal-stage deletion cut: folded the + remaining staged reverse C++ wrapper bodies into + `flat_bucket_registered_temporal_fused_reverse_program_full_step_cuda(...)`. + The registered program kernel no longer defines the deleted staged symbols: + `flat_bucket_registered_temporal_fused_backward_program_output_grad_cuda`, + `flat_bucket_registered_temporal_fused_backward_program_readout_message_kv_step_cuda`, + `flat_bucket_registered_temporal_fused_reverse_program_window_step_cuda`, + `flat_bucket_registered_temporal_fused_backward_program_transition_stage_cuda`, + `flat_bucket_registered_temporal_fused_reverse_program_transition_boundary_step_cuda`, + or + `flat_bucket_registered_temporal_fused_backward_program_recurrent_message_boundary_initial_kv_step_cuda`. 
+ Full-step now directly constructs the output-cell gradient seed, invokes the + selected readout/output-message/recurrent-KV primitive helpers, runs each + transition step from transition executor/binding/memory rows, materializes + the recurrent-message gradient from transition output rows, and executes the + recurrent-message/boundary/initial-KV tail before returning grouped outputs. + - This removes the previous "hidden old path" objection for the reverse + full-step launch: there is no public Python wrapper, no pybind export, and + no same-file staged C++ wrapper symbol for those intermediate reverse + phases. Remaining compiler work is now narrower: replace the fixed role + constants still used by the primitive helper implementations with fully + row-owned tensor binding lookups, then generalize/remove fixed supported + message/readout/transition assumptions. + - 2026-05-01 reverse artifact access-row cut: added + `TemporalReverseArtifactAccess` and + `temporal_reverse_artifact_access_rows_tensor(...)` as compiler-owned + access declarations for the registered full-step reverse program. The + active reverse launch now passes both `reverse_artifact_role_rows` and + `reverse_artifact_access_rows` into `fused_reverse_program_full_step`. + The C++ full-step body and its internal primitive helpers no longer fetch + artifacts through direct fixed role calls such as + `reverse_artifact_tensor_for_role_step(..., kReverseArtifactInputK, ...)`. + Instead, every full-step artifact read resolves an executor access id + through `reverse_artifact_access_rows`, then resolves the compiler role + through the artifact binding table. Transition state-before reads use the + same access plan plus the existing bucket/binding flags. + - This does not claim new primitive semantics yet. It removes the fixed + artifact-read ABI from the active registered reverse launch, so the next + closure blockers are the remaining fixed supported message/readout/ + transition executor assumptions. The backward transition-step path is no + longer a Python wrapper or pybind ABI; it is now only an internal same-file + subroutine of the full-step registered reverse program. + - Verification for the reverse artifact access-row cut: + `uv run python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/reverse_artifacts.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/reverse_artifacts.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_reverse_artifact_access_rows_are_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 3 tests. 
+ `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_artifact_access_no_transition_pybind uv run pytest -q tests/test_fabric_backend_plan.py::test_registered_parameter_reducer_cuda_executes_transition_trainable_rows --tb=short` + passed: 1 test after rebuilding the registered-program extension. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_artifact_access_no_transition_pybind uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 125 tests. + `git diff --check` passed. + - 2026-05-01 forward transition-step ABI deletion cut: removed the public + Python wrapper, pybind export, and C++ entrypoint + `registered_temporal_fused_forward_program_transition_step_cuda` / + `fused_forward_program_transition_step`. The active registered transition + forward executor now binds aggregate-message, state, and parameter tensors + into the compiler program tensor table in + `surface_executor_runtime.py`, using compiler binding rows, then calls + `registered_temporal_fused_forward_transition_program_cuda(...)` directly. + This leaves the supported active route as executor/binding-row program + execution, not a separate transition-step compatibility ABI. + - Verification for the forward transition-step ABI deletion cut: + `uv run python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 2 tests. + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_forward_no_transition_step_pybind uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings --tb=short` + passed: 1 test after rebuilding the registered-program extension. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_forward_no_transition_step_pybind uv run python - <<'PY' ...` + passed a CUDA-only registered forward route smoke and reported + `cuda_forward_registered_transition_program_ok`. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 125 tests. + - 2026-05-01 reverse reset ownership cut: removed the + `reset_not_program_owned` rejection from the registered reverse window and + made reset behavior an explicit registered-program input. 
The active + reverse full-step call now builds compiler reset tensor rows for message + and transition reset policies, passes `reverse_reset_rows` and + `transition_state_reset_rows` into + `fused_reverse_program_full_step`, and the C++ full-step masks only the + compiler-declared state-carry gradient slots plus the recurrent hidden + carry crossing the reset boundary. Message gradients and parameter + reductions remain owned by their executor rows. + - The same cut fixed a correctness bug in saved reverse artifacts: when the + forward path uses the backend population-state cache, `population_state_before` + is now materialized from that reset-aware cache, not from the stale + TensorDict state. That makes transition-state reverse artifacts match the + actual state consumed by the forward transition executor. + - This closes the previous reset training blocker for registered reverse + windows. It does not claim full forward-program generality yet: the + remaining hard blocker is still generic program-level forward scan + dispatch over message/readout/transition executor rows, instead of fixed + assumptions about the currently supported strategy shapes. + - Verification for the reverse reset ownership cut: + `uv run python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/reset_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/reset_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed. + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_reset_rows uv run pytest -q tests/test_fabric_backend_plan.py::test_registered_parameter_reducer_cuda_executes_transition_trainable_rows --tb=short` + passed: 1 test after rebuilding the registered-program extension. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_reverse_reset_rows_are_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_reset_rows uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_uses_high_level_reset_parity[present] --tb=short` + passed: 1 test. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_reset_rows uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_shared_mixed_population_reset_parity[present] --tb=short` + passed: 1 test. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_reset_rows uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_pooled_output_uses_registered_reverse_program_window[True] --tb=short` + passed: 1 test. + - 2026-05-01 forward program access/carry row cut: added compiler-owned + forward program access rows and transition state-carry rows. 
The fused + forward program now receives `forward_program_access_rows` and + `forward_transition_state_carry_rows`, resolves message/readout/transition + bindings through those compiler rows, and no longer reaches into the full + scan body with assumptions such as "message primitive start + 1" for + sender K/V weights or gated/diagonal-specific private-state carry branches. + Transition private state is advanced through compiler rows of + `[bucket_ordinal, input_binding, output_binding]`, so current `y/c/n/m` + and diagonal trace carries are state-carry bindings rather than temporal + scan semantics. + - This is a real forward-program ABI cut, not a throughput claim. The active + program still supports the current registered message/readout/transition + strategies, but the full forward scan now consumes compiler-produced + binding access/carry tables instead of positional fixed role assumptions + inside the program body. Remaining forward closure is to generalize the + executor strategy set itself so additional message/readout/transition + primitive row groups can register legal program executors or fail closed. + - Verification for the forward program access/carry row cut: + `uv run python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/forward_program.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_tensors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/forward_program.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_tensors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_forward_program_access_and_state_carry_rows_are_compiler_owned tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings --tb=short` + passed: 3 tests. + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_forward_access_rows uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings --tb=short` + passed: 1 test after rebuilding the registered-program extension. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_forward_access_rows uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells --tb=short` + passed: 1 test after rebuilding the registered-program extension. + - 2026-05-01 forward program executor-local access slot cut: changed + `forward_program_access_rows` / `reverse_program_access_rows` from global + semantic access IDs to executor-local access slots keyed by + `(slot, executor_row_index, bucket_ordinal)`. The registered forward + program C++ lookup now resolves tensors through the executor row span plus + access slot, and launch validation rejects spans with no matching compiler + access rows. 
This removes the fixed top-level `kForwardProgramAccess*` + ABI from the registered-program kernel; the remaining slot meanings are + strategy-local to the currently registered message/readout/transition + executors. + - Remaining after this cut: the supported forward program still has one + current message executor span and one current readout executor span because + the scheduler surface has one active message rule and one active readout + boundary. The next hard closure is registering additional forward executor + strategies against primitive row groups and dispatching them through the + same executor-local access slot contract, rather than adding another + global temporal slot enum. + - Verification for the executor-local access slot cut: + `uv run python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/forward_program.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/forward_program.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_forward_program_access_and_state_carry_rows_are_compiler_owned tests/test_fabric_backend_plan.py::test_temporal_reverse_program_access_rows_are_compiler_owned tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings --tb=short` + passed: 4 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_forward_access_slots uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test after rebuilding the registered-program extension. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_forward_access_slots uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells --tb=short` + passed: 1 test using the rebuilt registered-program extension. + - 2026-05-01 fused forward final-state ownership cut: removed the + `materialize_final_state` demotion from the registered fused forward + program for the supported no-reset/output-cells route. The fused C++ launch + now returns the final executable program tensor table alongside + `output_seq` and final backend-order recurrent hidden state. Python + materializes final population state by reading compiler tensor binding rows + and `forward_transition_state_carry_rows`, then assembles final cells + through the registered cells-layout executor. This keeps final-state + ownership on compiler program rows instead of falling back to the + per-step surface loop. + - Remaining after this cut: fused forward still rejects artifact collection + and reset tensors. The next compiler closure target is to make reset rows + and memory/artifact liveness executable inside the same fused program + launch, so training artifact collection does not demote to the registered + per-step orchestration path. + - Verification for the fused forward final-state ownership cut: + `uv run python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_runtime.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_runtime.py` + passed. 
+ `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_forward_final_state uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells --tb=short` + passed: 1 test after rebuilding the registered-program extension. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 128 tests. + - 2026-05-01 fused forward reset-row ownership cut: added + compiler-owned `forward_reset_rows` and forward reset tensors to the + registered fused forward program ABI. The active fused launch no longer + demotes solely because population or transition reset tensors are present. + Message reset rows zero backend recurrent hidden before the boundary + message step; transition reset rows zero state-carry input bindings through + `forward_transition_state_carry_rows` before the transition program runs. + Population resets also feed transition reset behavior when no separate + transition reset tensor is declared, matching the scalar scan schedule. + - Remaining after this cut: artifact collection still demotes because the + fused forward program does not yet write the reverse artifact/liveness + buffers. The next compiler closure target is making memory liveness and + reverse artifact role rows executable inside the fused forward launch for + training. + - Verification for the fused forward reset-row ownership cut: + `uv run python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/reset_plan.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_runtime.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/reset_plan.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_runtime.py` + passed. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_forward_reset_rows uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells --tb=short` + passed: 1 test after rebuilding the registered-program extension. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 128 tests. + - Verification for the internal-stage deletion cut: + `uv run python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 2 tests. 
+ `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_full_step_inline uv run pytest -q tests/test_fabric_backend_plan.py::test_registered_parameter_reducer_cuda_executes_transition_trainable_rows --tb=short` + passed: 1 test after rebuilding the registered-program extension. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_full_step_inline uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 124 tests. + `git diff --check` passed. + - 2026-05-01 reverse window-step entrypoint cut: added + `registered_temporal_fused_reverse_program_window_step_cuda(...)` and the + C++ pybind/kernel entrypoint `fused_reverse_program_window_step`. The + active reverse loop no longer calls + `registered_temporal_reverse_program_stage_cuda(stage_name="output_grad_window")` + followed by + `registered_temporal_reverse_program_stage_cuda(stage_name="readout_message_kv_step")`. + Instead, C++ materializes the output-gradient contribution for the selected + local reverse step, merges the current carry-cell seed, and immediately + runs the registered readout/output-message/recurrent-KV reverse program + over reverse artifact rows, primitive rows, executor rows, executable + program tensors, and memory-liveness rows. This is bounded to the first + reverse chunk because the readout/message/KV stage depends on + `current_grad_carry_cells`, which is produced later in each reverse + iteration by transition and recurrent-boundary stages. + - Remaining reverse Python orchestration after the window-step cut: + transition reverse, recurrent-message/boundary/initial-KV, query/KV/readout + reducers, and carry/state propagation are still sequenced in Python. The + next closure target is to fold transition plus recurrent-boundary work into + the same larger registered reverse program entrypoint, then reduce + parameter binding through compiler-owned reducer rows. + - Verification for the reverse window-step entrypoint cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_window_step uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test after rebuilding the registered-program extension. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. 
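+  - Illustration (non-normative): the "with active metadata asserting ..."
+    checks in these verification runs share one pattern; a hedged sketch with a
+    hypothetical `run_info` mapping standing in for however the runtime exposes
+    its execution metadata.
+
+    ```python
+    def assert_active_route(run_info: dict, expected_route: str) -> None:
+        """Fail loudly when the registered route did not actually execute."""
+        active = run_info.get("active_reverse_route")
+        if active != expected_route:
+            raise AssertionError(
+                f"expected registered route {expected_route!r}, got {active!r}"
+            )
+
+    # In the real tests the mapping comes from an executed run, not a literal.
+    assert_active_route(
+        {"active_reverse_route": "registered_fused_reverse_program_full_step"},
+        "registered_fused_reverse_program_full_step",
+    )
+    ```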
+ - 2026-05-01 reverse-program initial recurrent K/V projection cut: added + `fused_backward_program_initial_recurrent_kv_projection_step` for the + K/V-gradient path emitted by recurrent-message backward. The stage resolves + the recurrent K/V projection weight from compiler tensor bindings, reads + `recurrent_hidden_before_backend_order` from reverse artifact rows, and + returns graph-order hidden and raw recurrent K/V weight gradients for the + compiler parameter-binding reducer. The active supported route no longer + calls `kernel_registry.recurrent_kv_projection_backward(...)` for the + recurrent-before K/V gradients. + - Remaining reverse Python orchestration after this cut: transition backward, + recurrent query binding, boundary projection backward, carry/state + propagation, and final parameter binding. The next compiler closure target + is boundary K/V projection backward because it is the last message-surface + K/V projection still calling the registry helper in this supported route. + - Verification for the initial recurrent K/V projection cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py` + passed. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence12 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test, with active metadata asserting + `registered_fused_backward_program_initial_recurrent_kv_projection_step`. + The later boundary-cut verification reran the same small CUDA reverse + regression group and static guardrails after this stage stayed active. + `git diff --check` passed. + - 2026-05-01 reverse-program recurrent-message cut: added + `fused_backward_program_recurrent_message_step` as the sibling registered + reverse program stage for the recurrent attention/message adjoint. The + stage resolves `recurrent_q` from compiler tensor bindings, reads + `input_k`, `input_v`, `recurrent_k_before`, and `recurrent_v_before` from + reverse artifact binding rows, validates compiler primitive/executor/ + binding/memory rows, and runs recurrent-message attention backward inside + the fused reverse program. The active supported route no longer calls + `kernel_registry.recurrent_message_backward(...)`. + - Remaining reverse Python orchestration after this cut: transition backward, + initial recurrent K/V projection/parameter binding, recurrent query binding, + boundary projection backward, carry/state propagation, and final parameter + binding. The next compiler closure target is the initial recurrent K/V + projection backward stage, followed by boundary K/V projection backward. + - Verification for the recurrent-message cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py` + passed. 
+ `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py` + passed. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence11 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test, with active metadata asserting + `registered_fused_backward_program_recurrent_message_step`. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence11 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_forward_reverse_table_plan_is_not_t1_specific --tb=short` + passed: 5 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. + `git diff --check` passed. + - 2026-05-01 transition primitive-DAG legality cut: changed + `select_transition_program_executor(...)` so primitive-program legality is + no longer equivalent to matching one of the two whole-pattern transition + executor records. The selector now always builds a compiler-owned + `TransitionPrimitiveDagExecutorPlan` from registered primitive executor + records, validates external inputs, inter-op tensor edges, produced state/ + public outputs, callable program-layer symbols, tape policies, and + parameter-gradient contracts. Current gated-logspace and diagonal-RTU + records remain registered fused lowering strategies with eager + forward/backward registered program strategies; a different legal primitive DAG + now selects `transition_executor:primitive_dag:v1` with + `runtime_execution_status=registered_primitive_dag_program` instead of failing with + the old "no registered whole-pattern executor" error. Unsupported + primitives, reference-only primitives, forward references, duplicate + outputs, and unbound inputs still fail closed with typed rejection codes. + - This closes the semantic compiler-selection blocker, not all transition + execution. A new transition composed from registered primitive rows can now + be represented as a legal compiler product without adding a whole-program + `TransitionProgramExecutorRecord`; execution now goes through the + registered temporal program path or an explicit fused strategy. That split + is intentional: semantics are primitive-DAG owned, fused strategies are + optional implementations. + - Remaining transition/compiler closure after this cut: + 1. Generate transition reducer handler rows from binding-owned transform + contracts rather than fixed reducer-kind algebra. + 2. Make memory/liveness allocate and alias transition/message/readout + workspaces, tape buffers, artifact buffers, grad workspaces, and + parameter accumulators instead of mostly documenting those lifetimes. + 3. Generalize message/readout carrier handlers so adding a new forward or + reverse strategy is a local primitive-row/access-row registration, not a + fused temporal-loop edit. 
+    4. Widen output-contract/reset/variable-K/dtype coverage only after the
+       compiler-owned strategy and memory contracts are executable.
+  - Verification for the transition primitive-DAG legality cut:
+    `python -m py_compile src/cortical/fabric/backend/cuda/transition_execution/program.py src/cortical/fabric/backend/cuda/transition_execution/lowering.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`
+    passed.
+    `uv run ruff check src/cortical/fabric/backend/cuda/transition_execution/program.py src/cortical/fabric/backend/cuda/transition_execution/lowering.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`
+    passed.
+    `uv run pytest -q tests/test_fabric_backend_plan.py::test_transition_executor_selection_uses_registered_structural_records tests/test_fabric_backend_boundaries.py::test_rejected_transition_temporal_fusion_facades_were_deleted --tb=short`
+    passed: 2 tests.
+    `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short`
+    passed: 1 test.
+  - 2026-05-01 transition reverse binding registry cut: moved the transition
+    reverse input/parameter/output binding contracts for currently callable
+    recurrence primitives into `TransitionPrimitiveExecutorRecord` metadata.
+    `executor_bindings.py` now asks the primitive registry for
+    `reverse_input_bindings`, `parameter_bindings`, `reverse_output_bindings`,
+    and parameter-gradient outputs instead of branching on
+    `gated_logspace_recurrence` or `diag_rtu`. Transition DAG legality also
+    resolves lowered primitive aliases through the registry helper, so
+    `diagonal_recurrence` and `diag_rtu` normalization is registry-owned.
+  - Remaining transition genericity after the reverse-binding registry cut:
+    eager transition lowering functions still exist only for today's fused
+    gated-logspace and diagonal-RTU strategies. The compiler can represent a
+    legal primitive DAG, but executing a new transition op still requires
+    either a registered program-level primitive executor body or an explicit
+    fused lowering strategy with forward/backward/runtime-buffer contracts.
+  - Verification for the output-buffer and reverse-binding registry cuts:
+    `python -m py_compile src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/transition_execution/registry.py src/cortical/fabric/backend/cuda/transition_execution/program.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_bindings.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py`
+    passed.
+    `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/transition_execution/registry.py src/cortical/fabric/backend/cuda/transition_execution/program.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_bindings.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py`
+    passed.
+    `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_memory_buffers_2 uv run pytest -q tests/test_fabric_backend_plan.py::test_registered_parameter_reducer_cuda_executes_transition_trainable_rows --tb=short`
+    passed.
+    `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short`
+    passed: 128 tests.
+    `git diff --check` passed.
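+  - Illustration (non-normative): the reverse-binding registry cut replaces
+    primitive-name branching with record lookups; a minimal sketch with
+    hypothetical record fields and toy registry contents.
+
+    ```python
+    from dataclasses import dataclass
+
+    @dataclass(frozen=True)
+    class PrimitiveExecutorRecord:
+        """Per-primitive metadata the binding builder asks for instead of branching."""
+
+        reverse_input_bindings: tuple[str, ...]
+        parameter_bindings: tuple[str, ...]
+        reverse_output_bindings: tuple[str, ...]
+        aliases: tuple[str, ...] = ()
+
+    _REGISTRY: dict[str, PrimitiveExecutorRecord] = {
+        # Toy entries, not the real registry contents.
+        "gated_logspace_recurrence": PrimitiveExecutorRecord(
+            reverse_input_bindings=("state_before", "grad_state_after"),
+            parameter_bindings=("transition_weight",),
+            reverse_output_bindings=("grad_state_before",),
+        ),
+        "diag_rtu": PrimitiveExecutorRecord(
+            reverse_input_bindings=("state_before", "grad_state_after"),
+            parameter_bindings=("diag_weight",),
+            reverse_output_bindings=("grad_state_before",),
+            aliases=("diagonal_recurrence",),
+        ),
+    }
+
+    def resolve_record(op_name: str) -> PrimitiveExecutorRecord:
+        """Alias normalization is registry-owned; unknown ops fail closed."""
+        for name, record in _REGISTRY.items():
+            if op_name == name or op_name in record.aliases:
+                return record
+        raise KeyError(f"no registered transition primitive record for {op_name!r}")
+    ```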
+ - 2026-05-01 reverse-program recurrent K/V projection cut: wired the + registered fused reverse program stage + `fused_backward_program_recurrent_kv_projection_step` into the active + `registered_reverse_program_window`. The stage consumes output-message + recurrent K/V gradients, validates compiler primitive/executor/binding/ + memory rows, resolves the recurrent K/V projection weight from compiler + tensor bindings, reads `recurrent_hidden_backend_order` from the reverse + artifact binding table, and returns graph-order hidden/weight gradients for + the existing compiler parameter-binding reducer. The supported route no + longer calls `kernel_registry.recurrent_kv_projection_backward(...)` for + output-message K/V gradients. + - Remaining reverse Python orchestration after this cut: transition backward, + recurrent-message backward, initial recurrent K/V projection/parameter + binding, recurrent query binding, boundary projection backward, carry/state + propagation, and final parameter-gradient binding. The next compiler + closure target is recurrent-message backward because it is the other major + message-surface adjoint still outside the fused reverse program. + - Verification for the recurrent K/V projection cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py` + passed. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence10 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test, with active metadata asserting + `registered_fused_backward_program_recurrent_kv_projection_step`. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence10 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_forward_reverse_table_plan_is_not_t1_specific --tb=short` + passed: 5 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. + `git diff --check` passed. + - 2026-05-01 reverse-program output-message cut: added + `fused_backward_program_output_message_step` as the next registered + reverse program C++ stage. The active `registered_reverse_program_window` + now resolves `output_q` from reverse program tensor bindings, reads + `input_k`, `input_v`, `recurrent_k`, and `recurrent_v` from reverse artifact + tensor bindings, validates the compiler program/executor/memory rows, and + runs output-message attention backward inside the fused reverse program + stage. The Python window driver no longer calls + `kernel_registry.output_message_backward(...)` for the supported + no-reset/no-carry route. 
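+  - Illustration (non-normative): the adjoint computed by the recurrent K/V
+    projection stage above is the ordinary linear-projection backward. A
+    standalone sketch of that math, assuming a plain `k = h @ w_k.T`,
+    `v = h @ w_v.T` layout for one flattened step; the registered stage's actual
+    layouts come from compiler tensor bindings.
+
+    ```python
+    import torch
+
+    def kv_projection_backward(
+        grad_k: torch.Tensor,  # [N, d_k] gradient w.r.t. projected keys
+        grad_v: torch.Tensor,  # [N, d_v] gradient w.r.t. projected values
+        hidden: torch.Tensor,  # [N, d_h] recurrent hidden consumed by the projection
+        w_k: torch.Tensor,     # [d_k, d_h]
+        w_v: torch.Tensor,     # [d_v, d_h]
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Backward of k = h @ w_k.T and v = h @ w_v.T."""
+        grad_hidden = grad_k @ w_k + grad_v @ w_v   # [N, d_h]
+        grad_w_k = grad_k.transpose(0, 1) @ hidden  # [d_k, d_h]
+        grad_w_v = grad_v.transpose(0, 1) @ hidden  # [d_v, d_h]
+        return grad_hidden, grad_w_k, grad_w_v
+    ```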
+ - Remaining reverse Python orchestration after this cut: recurrent K/V + projection backward for the output-message K/V gradients, transition + backward, recurrent-message backward, recurrent query/initial K/V parameter + binding, boundary projection backward, carry/state propagation, and final + parameter-gradient binding. The next compiler closure target is the + recurrent K/V projection backward stage because it immediately consumes the + output-message K/V gradients now produced by the fused reverse program. + - Verification for the output-message cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence9 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test, with active metadata asserting + `registered_fused_backward_program_output_grad_window`, + `registered_fused_backward_program_readout_step`, and + `registered_fused_backward_program_output_message_step`. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence9 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_forward_reverse_table_plan_is_not_t1_specific --tb=short` + passed: 5 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. + `git diff --check` passed. + - 2026-05-01 fused-artifact direct reverse cut: removed the active + tensor-store-to-`TemporalBucketStepArtifacts` bridge for fused forward + training artifacts. `reverse_executor.py` now sends + `TemporalReverseArtifactTensorStore` windows directly into + `run_registered_temporal_reverse_executor_tensor_store_window(...)`, which + rebases the compiler artifact binding rows for the requested backward + window and feeds those rows/tensors to the registered fused reverse program. + The old `materialize_registered_temporal_artifact_window_from_tensor_store` + function is no longer exported or called by the active reverse executor. + - This was a real active-path deletion but not yet the final reverse-span + closure. The next closure target was to move the span loop, + state-gradient seed carry, and parameter-source collection behind a single + `registered_temporal_fused_backward_program_cuda` entrypoint consuming + primitive rows, executor rows, binding rows, artifact rows, reset rows, and + memory-liveness rows. 
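+  - Illustration (non-normative): "rebases the compiler artifact binding rows
+    for the requested backward window" reduces to an index shift plus a bounds
+    filter; a minimal sketch assuming rows carry an absolute step index in one
+    integer column.
+
+    ```python
+    import torch
+
+    def rebase_artifact_rows(
+        rows: torch.Tensor,  # [R, C] integer binding rows
+        step_column: int,    # column holding the absolute step index
+        window_start: int,
+        window_len: int,
+    ) -> torch.Tensor:
+        """Keep rows whose step lies inside the backward window and rewrite
+        their step index to a window-local offset."""
+        steps = rows[:, step_column]
+        in_window = (steps >= window_start) & (steps < window_start + window_len)
+        rebased = rows[in_window].clone()
+        rebased[:, step_column] -= window_start
+        return rebased
+    ```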
+ - Verification for the fused-artifact direct reverse cut: + `python -m py_compile src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py tests/test_fabric_backend_boundaries.py tests/test_fabric_runtime.py` + passed. + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test, with active metadata asserting + `registered_fused_forward_program_tensor_store_direct`. + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_pooled_output_uses_registered_reverse_program_window tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_forward_reverse_table_plan_is_not_t1_specific --tb=short` + passed: 5 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 128 tests. + - 2026-05-01 reverse-span program cut: added the public + `registered_temporal_fused_backward_program_cuda(...)` / pybind + `fused_backward_program_execute` entrypoint. The active + `registered_reverse_program_window` now launches one C++ reverse span per + backward window, not one public full-step launch per local step. The span + consumes compiler primitive rows, executor rows, tensor binding rows, + memory-liveness rows, reverse artifact rows, reset row groups, transition + parameter rows, and transition seed rows; it carries dynamic transition + state-gradient seeds inside the C++ span using compiler seed-output rows. + Parameter-gradient binding remains a separate registered parameter reducer + program over the source tensors returned by the span, not a Python math + fallback. + The old public `registered_temporal_fused_reverse_program_full_step_cuda` + / `fused_reverse_program_full_step` launch surface was removed; the + full-step body remains only as a private C++ implementation helper inside + the span. + - `build_temporal_fused_cuda_program_plan(...)` now reports `legal` when + registered forward/reverse span bodies and registered transition primitive + CUDA symbols are present. Unsupported transition primitive sets still fail + closed through typed transition blocker codes before runtime launch. + - Verification for the reverse-span program cut: + `python -m py_compile src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_runtime.py` + passed. 
+ `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_runtime.py` + passed. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_span uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_span uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_pooled_output_uses_registered_reverse_program_window tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_forward_reverse_table_plan_is_not_t1_specific --tb=short` + passed: 5 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 128 tests. + `git diff --check` passed. + - 2026-05-01 fused forward executor-handler dispatch cut: replaced the + sequence-wide fused forward program's top-level message/readout span lookup + with registered C++ forward executor handlers. The program-level validator + now resolves each forward span through a handler table keyed by + compiler-owned `executor_id` plus surface opcode, rejects unsupported + handlers with an explicit error, and checks access-row ownership per + handler. The active fused forward body now binds/runs the current + neighborhood-attention message carrier and projection-reduction readout + through handler state objects instead of directly resolving `message_span` + / `readout_span` in the scheduler body. Transition spans remain + multi-bucket and are validated as registered transition handlers before the + transition primitive program runs. + - This is not final forward strategy generality yet: the supported handler + table still contains the current message/readout/gated/diagonal strategies, + and the current artifact schema still expects one temporal message carrier + and one temporal readout handler. The next closure step is to make adding a + new forward strategy a local handler registration plus primitive-row tests, + then widen artifact/readout/message carrier rows when multiple carriers or + readout strategies are declared. + - 2026-05-01 reverse helper handler-dispatch cut: carried the same + compiler-owned handler lookup into the registered reverse helper stages. + Readout backward, output-message backward, recurrent-message backward, + recurrent K/V projection backward, initial recurrent K/V projection, and + boundary K/V projection now resolve message/readout executors through + `RegisteredReverseExecutorHandler` rather than open-coded scans for a + message or readout span. The helpers still implement the current supported + reverse strategies, but their entry checks now fail through handler + registration instead of temporal-side fixed span identities. 
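+  - Illustration (non-normative): the handler-dispatch cuts replace open-coded
+    span scans with a keyed lookup that fails closed; a minimal sketch with
+    hypothetical handler fields.
+
+    ```python
+    from dataclasses import dataclass
+    from typing import Callable
+
+    @dataclass(frozen=True)
+    class ReverseExecutorHandler:
+        """One registered reverse handler keyed by compiler executor id + opcode."""
+
+        executor_id: str
+        surface_opcode: int
+        expected_primitive_opcodes: tuple[int, ...]
+        run: Callable[..., dict]
+
+    _HANDLERS: dict[tuple[str, int], ReverseExecutorHandler] = {}
+
+    def register_handler(handler: ReverseExecutorHandler) -> None:
+        key = (handler.executor_id, handler.surface_opcode)
+        if key in _HANDLERS:
+            raise ValueError(f"duplicate reverse handler registration for {key}")
+        _HANDLERS[key] = handler
+
+    def resolve_handler(executor_id: str, surface_opcode: int) -> ReverseExecutorHandler:
+        """Fail closed: an unregistered pair is a typed error, not a silent
+        fallback to a fixed span identity."""
+        try:
+            return _HANDLERS[(executor_id, surface_opcode)]
+        except KeyError:
+            raise LookupError(
+                "no registered reverse executor handler for "
+                f"executor_id={executor_id!r}, surface_opcode={surface_opcode}"
+            ) from None
+    ```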
+ - Remaining reverse genericity after this cut: transition-adjoint internals + and parameter-source collection still need the same handler-owned shape for + adding new reverse strategies, and the artifact schema still encodes the + current single message/readout carrier assumptions until multi-carrier + compiler artifact rows are introduced. + - 2026-05-01 reverse transition-adjoint handlerization cut: moved the + registered fused reverse transition program off direct gated/diagonal + opcode branching. Reverse transition spans now resolve through + `RegisteredReverseExecutorHandler` records keyed by compiler executor id + plus surface opcode; each handler declares its primitive opcode/count + proof obligation, validates the row signature, binds tensor-role inputs + from compiler binding rows, and invokes the current gated or diagonal + primitive adjoint implementation behind the handler boundary. The dynamic + transition binding path now iterates reverse transition handler spans and + validates matching forward transition handler spans instead of enforcing a + single forward row plus single reverse row transition-step ABI. The active + C++ span still supports only the current gated-logspace and diagonal-RTU + handlers, but adding another transition adjoint is now a handler + registration plus tensor-binding contract rather than a temporal reverse + scheduler opcode branch. + - Remaining transition compiler work after this cut: transition + parameter-source/reducer algebra is still monolithic inside the registered + parameter reducer program, memory/liveness rows still need to drive real + workspace allocation and aliasing, and generic transition lowering still + needs registry/reference/executor support for new declared primitive ops. + - Verification for the reverse transition-adjoint handlerization cut: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_transition_handlers uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 128 tests. + `git diff --check` passed. + - 2026-05-01 reverse-program readout cut: added + `fused_backward_program_readout_step` as the next registered reverse + program C++ stage. The active `registered_reverse_program_window` now builds + a reverse program tensor table from compiler executor bindings, resolves + `value_to_output_weight` through those bindings, reads `output_msg` from the + reverse artifact tensor table, and runs readout layout/projection backward + through the fused reverse program stage. The Python window driver no longer + calls `kernel_registry.readout_layout_projection_backward(...)` on this + supported no-reset/no-carry route. + - This cut is deliberately a step-level readout program stage, not a full + readout-window precompute, because earlier physical steps must include the + dynamic recurrent carry gradient produced by later transition backward + steps. 
The stage still consumes the compiler program rows, executor rows, + binding rows, memory rows, artifact role rows, and artifact tensor binding + rows; the remaining Python orchestration is now around message backward, + transition backward, boundary projection, carry/state propagation, and + parameter reductions. + - Verification for the readout cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence8 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test, with active metadata asserting both + `registered_fused_backward_program_output_grad_window` and + `registered_fused_backward_program_readout_step`. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_sequence8 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_forward_reverse_table_artifacts tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_forward_reverse_table_plan_is_not_t1_specific --tb=short` + passed: 5 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 123 tests. + `git diff --check` passed. + - 2026-05-01 registered handler-row and binding-legality cut: removed the + fixed C++ forward/reverse handler kind/spec tables from the registered + fused program. The validator now decodes handler contracts from compiler + handler rows, checks executor id plus surface opcode, proves primitive + opcode/count against the actual primitive rows, validates capability/effect + bits, and records generic `registered.forward.compiler_handler` / + `registered.reverse.compiler_handler` ownership. This invalidates the old + fixed handler-table route instead of wrapping it. + - Tightened tensor-binding legality around the active executable program: + surface parameter resolution now raises with direction/surface/bucket/ + logical/source context instead of returning an empty placeholder tensor; + transition forward aggregate input binding is resolved from the + `aggregated_message` compiler input binding rather than primitive-row + position; missing external transition state inputs fail at tensor-table + binding time while internal producer slots remain compiler-owned output + destinations; and an unused reverse tensor-table extender that allocated + empty binding slots was deleted. 
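+  - Illustration (non-normative): raising with direction/surface/bucket/logical/source
+    context instead of returning an empty placeholder tensor means the
+    unresolved case becomes a structured error; a minimal sketch with
+    hypothetical field names.
+
+    ```python
+    from dataclasses import dataclass
+
+    import torch
+
+    @dataclass(frozen=True)
+    class ParameterSourceKey:
+        direction: str      # "forward" or "reverse"
+        surface: str        # message/readout/transition surface name
+        bucket_ordinal: int
+        logical_name: str   # compiler-facing logical parameter name
+        source: str         # declared parameter source
+
+    class UnresolvedSurfaceParameter(RuntimeError):
+        def __init__(self, key: ParameterSourceKey):
+            super().__init__(
+                "unresolved surface parameter: "
+                f"direction={key.direction} surface={key.surface} "
+                f"bucket={key.bucket_ordinal} logical={key.logical_name} "
+                f"source={key.source}"
+            )
+            self.key = key
+
+    def resolve_surface_parameter(
+        table: dict[ParameterSourceKey, torch.Tensor], key: ParameterSourceKey
+    ) -> torch.Tensor:
+        """Fail closed instead of materializing an empty placeholder tensor."""
+        tensor = table.get(key)
+        if tensor is None:
+            raise UnresolvedSurfaceParameter(key)
+        return tensor
+    ```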
+ - Moved program-access opcodes onto the registered strategy access records + (`TemporalProgramAccessPattern.access_opcode`) so adding or changing a + strategy access is local to the strategy pattern metadata. The + `forward_program.py` builder no longer owns a central `_PROGRAM_ACCESS_OPCODE` + map; it emits access rows from strategy-owned access records and validates + opcode uniqueness through the registered patterns. + - Verification for this cut: + `python -m py_compile src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/forward_program.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_tensors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_parameters.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_forward_program_access_and_state_carry_rows_are_compiler_owned tests/test_fabric_backend_plan.py::test_temporal_reverse_program_access_rows_are_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings --tb=short` + passed: 3 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_plan.py::test_surface_parameter_tensor_table_fails_closed_on_unresolved_compiler_source --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_handler_rows uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells --tb=short` + passed: 1 test. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_handler_rows uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test. + - Superseded closure note: this cut still had hard compiler blockers in the + forward/reverse fused program, artifact routing, memory/liveness ownership, + transition parameter reducers, and generic transition lowering. Those + blockers were attacked in the later 2026-05-02/2026-05-03 cuts: output + route rows, artifact merge rows, tensor-store reverse, registered native + callables, memory buffer validation, transition reducer ownership, direct + transition API deletion, DotProduct stress-gate execution, and the broad + no-`--lf` Fabric closure sweep below. This paragraph is retained as + historical context only, not current open work. + - Required pre-throughput compiler stress-test stage: after the compiler path + is considered closed, but before throughput optimization starts, integrate a + semantic message-math change equivalent to the fixed-slot context-nudge + DotProduct update from + `https://app.graphite.com/github/pr/Metta-AI/cortical/10/Integrate-Fabric-fixed-slot-context-nudge-DotProduct-math`. + Status: completed in the 2026-05-02 ten-iteration counter below as + Iteration 6/10. This stress stage is no longer a future pre-throughput + blocker for this branch. + Graphite can be used through the `gt` command. 
The acceptance criterion is + architectural: adding the new dot-product attention semantics, additional + normalization, and gating must be expressible as + declaration/IR/primitive-row/tensor-binding/strategy changes plus + reference/fused executor tests. It must not require editing + temporal scheduler ownership, fixed tensor slot enums, monolithic scan/ + reverse ABIs, cell-family route selectors, or hidden compatibility + fallbacks. If this stress test is large or invasive, the compiler path is + not real enough to move to throughput work. + - Required post-compiler cleanup stage before throughput tuning: once the + active compiler path is closed and before April21-class throughput tuning + starts, run a cleanup-only pass over non-compiler public API ownership that + has been intentionally deferred during the hard compiler work. Status: + runtime/anatomy/stale-CUDA cleanup completed in Iteration 7/10 and guarded + in Iteration 8/10. The post-counter Config/graph cleanup and final closure + checklist below supersede the earlier open bullets. Future API expansion is + product work and must not reintroduce compatibility shims, route selectors, + or a second execution source of truth. + - `src/cortical/fabric/config.py` cleanup is closed for the live public + constructor shape. `Config` no longer accepts lattice dimensions, + boundary-port fields, message/head fields, backend/K/checkpoint fields, + population placement, and initialization as one flat object. It now + requires a user graph declaration plus owned sections: + `interface`, `message`, `populations`, `readout`, `execution`, and + `initialization`. `lattice2d.Graph` is one graph declaration, not the + definition of Fabric graphs; `graphs.flat.Graph` is the explicit + user-defined flat-graph declaration for arbitrary node/edge sets. + - `src/cortical/fabric/anatomy.py` cleanup completed for the runtime + compiler path. Lattice-only coordinate construction, ports, graph edges, + KV groups, slot features, population layout, and local sender tables now + live under `src/cortical/fabric/graphs/lattice_anatomy.py`; generic + Fabric anatomy exposes flat graph facts and precomputed sender tables. + - Public constructor/blueprint cleanup is closed for the old broad-config + route. `Blueprint` now passes the graph declaration object through + `Config(graph=...)` and populates owned config sections instead of + reconstructing a flat lattice config. Remaining Blueprint limitations are + typed legality limitations, not a compatibility config path: current + runtime still requires one shared hidden size and one output aggregate. + - Message-rule public API cleanup is no longer a compiler-closure blocker + for the supported surface. `DotProduct`, `fixed_slot_context_nudge`, and + `fixed_slot_context_gate` lower through registered + `MessageRuleBackendSpec` records and train through registered fused + temporal programs. Expanding public rule coverage is future product work: + add registered specs, reference implementations, native strategies, and + typed fail-closed legality, without editing the temporal scheduler. + - Cell declaration cleanup is no longer a compiler-closure blocker for the + supported surface. Cell families lower through `CellBackendSpec` / + transition primitive records and the old broad public `Config` constructor + path has been removed. Expanding cell declaration ergonomics is future + product work as long as the lowering chain remains declaration -> IR -> + primitive rows -> tensor bindings -> registered executors. 
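+  - Shape sketch of the live graph-plus-sections declaration, assuming
+    invented section payloads; the dataclasses below stand in for
+    `graphs.flat.Graph` and `Config` and do not reproduce the real
+    constructor signatures.
+
+    ```python
+    from dataclasses import dataclass
+    from typing import Any
+
+
+    @dataclass(frozen=True)
+    class FlatGraph:
+        """Illustrative stand-in for graphs.flat.Graph: explicit nodes/edges."""
+
+        nodes: tuple[str, ...]
+        edges: tuple[tuple[str, str], ...]
+
+
+    @dataclass(frozen=True)
+    class FabricConfigSketch:
+        """Shape of the live Config: graph declaration plus owned sections."""
+
+        graph: FlatGraph
+        interface: dict[str, Any]
+        message: dict[str, Any]
+        populations: dict[str, Any]
+        readout: dict[str, Any]
+        execution: dict[str, Any]
+        initialization: dict[str, Any]
+
+
+    # One graph declaration plus owned sections, instead of the deleted flat
+    # width/height/hidden_size config object.
+    declaration = FabricConfigSketch(
+        graph=FlatGraph(
+            nodes=("n0", "n1", "n2"),
+            edges=(("n0", "n1"), ("n1", "n2")),
+        ),
+        interface={"ports": ["input", "output"]},
+        message={"rule": "fixed_slot_context_nudge"},
+        populations={"placement": "uniform"},
+        readout={"aggregate": "mean"},
+        execution={"backend": "cuda", "collect_artifacts": True},
+        initialization={"seed": 0},
+    )
+    ```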
+ - Backend surface cleanup is no longer an active closure blocker for the + supported route. Guardrails now reject hidden compatibility/fallback + surfaces, and supported/unsupported behavior is expressed as compiler + legality and typed blocker metadata. Additional capability-query API work + belongs after closure. + - Runtime/public-test cleanup is closed for the compiler-critical suites: + tests now assert declaration lowering, primitive rows, executor metadata, + fail-closed unsupported rows, and parity through real registered runs. + Further test pruning is normal maintenance, not a pre-throughput blocker. + - `runtime/core.py` and runtime construction cleanup is closed for the + compiler-critical route: graph facts come from graph declarations/anatomy + metadata, planner-owned behavior stays in planner/runtime products, and + backend execution ownership stays in registered compiler-selected + executors. Further file-size reduction is product maintainability work. + - Direct/chunked sequence API cleanup is no longer a compiler-closure + blocker for supported CUDA training. The active CUDA temporal route is + registered-program owned; remaining public helper APIs must remain narrow + semantic entrypoints or route through compiler execution records. + - Old CUDA execution-surface cleanup is closed for the runtime-critical + direct CUDA siblings. The pre-compiler + dispatcher package `backend/cuda/execution/**`, its C++ cell-dispatch + registry helpers, CUDA cell formula headers, and `recurrence_executor.py` + have been deleted from the active tree. The standalone direct + `backend/cuda/message_passing/**` pybind/kernels and the standalone + grouped-projection CUDA bridge were also deleted after their tests were + migrated or invalidated. Remaining CUDA operator files must either be + registered primitive executors selected by primitive rows, or explicit + low-level ops whose callers are compiler-selected strategy bodies. + - Test-suite cleanup is closed for old flat Fabric `Config` construction + fixtures in the core Fabric runtime/anatomy/public API/visualizer/backend + guard suites. Runtime tests now build graph-plus-section declarations + through owned test helpers, and the only remaining flat + `Config(width=..., height=..., hidden_size=...)` construction in the + targeted guard set is an explicit rejection assertion for the deleted + constructor shape. Remaining test cleanup is narrower: keep + source-contract tests only where they prevent a known forbidden path from + returning; move correctness coverage to declaration lowering, reference + parity, executor legality, artifact/reset behavior, and audit metadata + produced by real runs. + - Documentation hygiene is classified for compiler closure. The canonical + compiler/cleanup/throughput state is this REDO2 document. The tracked + `ai_docs/REDO_FIXMAASS.md` remains a historical alias, while + `ai_docs/AWS_RECOVERY_TRAIL.md`, `ai_docs/additonal_goals.md`, and + `ai_docs/prompt.tx` are untracked recovered planning/context inputs + referenced by the tracked REDO docs; they are not active runtime/compiler + artifacts and are not closure blockers. Decide whether to add or omit + them at commit time. + - 2026-05-02 public Config/test cleanup verification: migrated the + remaining old flat `Config(...)` constructions in + `tests/test_fabric_runtime.py` to graph-plus-section declarations and + replaced stale `runtime.config.d_public`/backend mutation references with + runtime/compiler-owned fields or section updates. 
Verification: + `uv run ruff check tests/test_fabric_runtime.py`, + `uv run python -m compileall -q tests/test_fabric_runtime.py`, + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_stream_sequence_single_population_matches_repeated_steps_for_cells_and_state tests/test_fabric_runtime.py::test_fabric_message_fast_path_matches_sparse_reference_2d tests/test_fabric_runtime.py::test_fabric_message_fast_path_matches_sparse_reference_unwrapped_lattice tests/test_fabric_runtime.py::test_fabric_stream_step_subset_messages_match_full_reference tests/test_fabric_runtime.py::test_fabric_stream_step_k1_fast_path_matches_previous_reference --tb=short`, + `uv run pytest -q tests/test_fabric_anatomy.py tests/test_fabric_public_api.py tests/test_fabric_visualizer.py tests/test_fabric_backend_boundaries.py::test_lattice_config_cleanup_stays_out_of_backend_runtime --tb=short`, + and targeted `git diff --check` passed. + - Superseded order note: the hard compiler blockers, public + API/config/anatomy cleanup, DotProduct semantic stress gate, broad + validation sweep, and final tree-hygiene classification are now completed + in the later entries below. If future cleanup exposes missing lowering or + unsupported declarations, add typed fail-closed compiler errors rather + than preserving any old config route. + - 2026-05-01 fused-forward mandatory cut: deleted the active per-step + registered forward executor fallback from + `run_registered_temporal_forward_executor_scan`. A registered CUDA forward + row now either returns through `registered_fused_forward_program_cuda` or + raises a typed fail-closed rejection tagged + `per_step_registered_forward_fallback_deleted=1`; the older per-step + orchestration remains only as explicit artifact recompute/reference + machinery, not as the supported scan owner. + - The fused forward C++ program now routes current message/readout role + assumptions through compiler-selected strategy handler records. The outer + temporal loop binds a compiler handler span and calls the strategy dispatch; + the current `neighborhood_attention_project` and + `projection_reduction_boundary` implementations own their Q/K/V/readout + access rows internally. This is still narrow strategy coverage, but it + moves another fixed assumption out of scheduler ownership. + - Verification for this cut: + `python -m py_compile src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_forward_scan_fails_closed_on_unsupported_primitive_programs tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings --tb=short` + passed: 2 tests. + - 2026-05-01 fused-forward training artifact policy cut: made + compiler-owned stored reverse artifact tensor tables the default supported + training materialization for flat-bucket temporal execution. The planner, + runtime scheduler, and memory runtime artifact plan now default + `collect_artifacts=True` to `store_step_artifacts`; the old active + checkpoint-recompute training materialization is no longer the registered + forward path's fallback target. `forward_scan.py` no longer carries the + `store_forward_reverse_tables` compatibility flag, because the active + registered fused forward program always owns artifact tensor-table emission + for supported training rows. 
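+  - Minimal sketch of the no-fallback forward entry described above: return
+    through the fused program or raise a typed rejection carrying
+    `per_step_registered_forward_fallback_deleted=1`. The exception and
+    function names are illustrative, not the real
+    `run_registered_temporal_forward_executor_scan` signature.
+
+    ```python
+    from typing import Any, Callable, Mapping, Optional
+
+
+    class UnsupportedRegisteredForwardRow(RuntimeError):
+        """Typed fail-closed rejection with machine-readable blocker metadata."""
+
+        def __init__(self, message: str, metadata: Mapping[str, int]) -> None:
+            super().__init__(message)
+            self.metadata = dict(metadata)
+
+
+    def run_forward_scan_sketch(
+        program: Any,
+        fused_forward_program: Optional[Callable[[Any], Any]],
+    ) -> Any:
+        # Supported rows return through the registered fused forward program.
+        if fused_forward_program is not None:
+            return fused_forward_program(program)
+        # There is no per-step orchestration fallback: unsupported rows fail
+        # closed and name the deleted route in the blocker metadata.
+        raise UnsupportedRegisteredForwardRow(
+            "registered CUDA forward row has no fused forward program",
+            metadata={"per_step_registered_forward_fallback_deleted": 1},
+        )
+    ```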
+ - Verification for the fused-forward training artifact policy cut: + `python -m py_compile src/cortical/fabric/backend/planner.py src/cortical/fabric/backend/temporal_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/scheduler.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py` + passed. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_fused_mandatory2 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells --tb=short` + passed: 1 test. + - 2026-05-01 old CUDA execution package deletion: removed + `backend/cuda/execution/**`, `backend/cuda/recurrence_executor.py`, the + old C++ cell-dispatch registry helpers, and the old CUDA cell formula + headers/registration `.cu` files. The remaining sequence-surface internal + callers now go through `_execute_compiler_temporal_sequence_surface`, which + builds a TensorDict state and calls `execute_temporal_bucket_sequence` + instead of constructing an execution-dispatch request. Cell semantics are + still defined by `backend/cell_specs.py` and lowered through + `CellTransitionIR` / temporal primitive rows; the deleted files were only + the old dispatcher implementation surface. + - 2026-05-01 parameter-reducer trainable/runtime role ABI cut: removed the + fixed common trainable/runtime tensor arguments from the registered + temporal parameter reducer program ABI. `build_temporal_parameter_reducer_program` + now emits compiler-owned `parameter_reducer_trainable_role_rows` from the + active `trainable_param_names`, and the runtime supplies explicit + `parameter_reducer_runtime_metadata_rows` for topology tensors such as the + backend recurrent inverse order and recurrent/output cell indices. The + C++ reducer entrypoint returns one tensor per trainable parameter and + resolves sender K/V, recurrent-query, output-query, and readout reducer + parameters through role rows rather than through fixed arguments such as + `public_proj_weight`, `q_proj_weight`, `slot_embed`, or + `output_cell_bias`. + - Verification for the parameter-reducer role ABI cut: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_plan.py::test_registered_parameter_reducer_cuda_executes_transition_trainable_rows --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 128 tests. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `git diff --check` passed. + - Remaining parameter-reducer compiler work after this cut: reducer + strategy implementations are still the current built-in C++ handlers for + projection/query/readout/transition algebra. New reducer strategies still + need to register from primitive-row/parameter-binding metadata with typed + legality/cost records, and reducer workspace/aliasing still needs to be + allocated from the executable memory/liveness plan rather than local + handler allocation choices. 
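+  - Illustrative emission of `parameter_reducer_trainable_role_rows` from the
+    active `trainable_param_names`; the row dataclass and the opcode map below
+    are assumptions, since the real opcode assignment is owned by the compiler
+    reducer registry rather than any local table like this one.
+
+    ```python
+    from dataclasses import dataclass
+
+
+    @dataclass(frozen=True)
+    class TrainableRoleRow:
+        """Illustrative reducer role row: one row per trainable parameter."""
+
+        row_index: int
+        param_name: str
+        reducer_opcode: int
+
+
+    # Hypothetical opcode assignment for illustration only.
+    _REDUCER_OPCODES = {
+        "public_proj_weight": 11,
+        "q_proj_weight": 12,
+        "slot_embed": 13,
+        "output_cell_bias": 14,
+    }
+
+
+    def build_trainable_role_rows(
+        trainable_param_names: list[str],
+    ) -> list[TrainableRoleRow]:
+        # One row per active trainable parameter, so the native reducer resolves
+        # sender-K/V, recurrent-query, output-query, and readout reducer
+        # parameters by role row instead of by fixed positional arguments.
+        rows = []
+        for index, name in enumerate(trainable_param_names):
+            opcode = _REDUCER_OPCODES.get(name)
+            if opcode is None:
+                raise ValueError(
+                    f"no registered reducer opcode for trainable parameter {name!r}"
+                )
+            rows.append(
+                TrainableRoleRow(row_index=index, param_name=name, reducer_opcode=opcode)
+            )
+        return rows
+    ```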
+ - Same-pass cleanup: deleted the unused + `RegisteredTransitionTrainableReducerHandlerKind` enum from the transition + trainable reducer handler table. Transition trainable rows now carry the + reducer opcode, and the active C++ path validates that opcode by resolving + the callable handler directly instead of retaining a second handler-kind + identity. + - 2026-05-01 temporal executor strategy registry cut: moved forward/reverse + temporal executor strategy ownership behind + `TemporalExecutorStrategyRegistry`. The compiler still exposes the older + pattern query helpers for existing callers/tests, but active table + construction, primitive dispatch, fused program handler rows, program + access rows, strategy selection, and temporal executor kernel validation + now ask the registry object for strategy matching and handler contracts. + The registry validates direction-prefixed strategy IDs, row signatures, + compiler-owned implementation contracts, and stable program-access opcodes + before any temporal plan can select an executor. + - Verification for the temporal executor strategy registry cut: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured tests/test_fabric_backend_plan.py::test_temporal_strategy_matching_uses_canonical_row_group_schema --tb=short` + passed: 5 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 128 tests. + `uv run ruff check` on the touched compiler/runtime/test files passed, and + `git diff --check` passed. + - Remaining strategy-registry compiler work after this cut: the C++ fused + program still has native callable tables for the currently supported + message/readout/transition/reducer handlers. Those functions are now + selected through compiler handler rows, but adding a new native strategy + still requires adding a C++ callable implementation and registering it with + the native handler table. The next hard closure is making that native + handler table mirror the Python compiler strategy registry contract + directly, including legality, access rows, memory/workspace requests, and + backward/reducer coverage. + - 2026-05-01 native temporal strategy row ABI cut: added a compiler-owned + `native_strategy_rows` table to the registered fused temporal program + contract. Python planning now emits one native ABI row per unique + forward/reverse strategy implementation from the + `TemporalExecutorStrategyRegistry`, records it in runtime metadata, stores + it on `RegisteredTemporalExecutorProgram`, and passes it into full + forward/backward fused program launches and transition subprogram launches. + The C++ fused program validator now checks that every selected handler row + has a matching native strategy row with direction, surface, executor id, + handler kind, primitive opcode, row signature count, capability mask, + effect mask, and schema versions. The active wrappers reject empty + native-strategy tables, so handler rows alone are no longer enough to + launch the compiler-owned fused program path. 
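+  - Python-shaped sketch of the native-strategy-row match that the fused
+    C++ program validator performs; the row dataclasses are illustrative, but
+    the matched fields follow the contract listed above, and an empty native
+    table rejects the launch outright.
+
+    ```python
+    from dataclasses import dataclass
+
+
+    @dataclass(frozen=True)
+    class HandlerRow:
+        direction: str
+        surface: int
+        executor_id: int
+        handler_kind: int
+        primitive_opcode: int
+
+
+    @dataclass(frozen=True)
+    class NativeStrategyRow:
+        direction: str
+        surface: int
+        executor_id: int
+        handler_kind: int
+        primitive_opcode: int
+        row_signature_count: int
+        capability_mask: int
+        effect_mask: int
+        schema_version: int
+
+
+    def validate_native_strategy_rows(
+        handler_rows: list[HandlerRow],
+        native_rows: list[NativeStrategyRow],
+    ) -> None:
+        # Handler rows alone are not enough to launch: empty tables reject.
+        if not native_rows:
+            raise RuntimeError(
+                "native_strategy_rows is empty; fused program launch rejected"
+            )
+        index = {
+            (r.direction, r.surface, r.executor_id, r.handler_kind, r.primitive_opcode): r
+            for r in native_rows
+        }
+        for row in handler_rows:
+            key = (
+                row.direction,
+                row.surface,
+                row.executor_id,
+                row.handler_kind,
+                row.primitive_opcode,
+            )
+            if key not in index:
+                raise RuntimeError(
+                    f"handler row {key} has no matching native strategy row"
+                )
+    ```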
+ - Verification for the native temporal strategy row ABI cut: + `uv run python -m py_compile src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured --tb=short` + passed: 2 tests. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed, and `git diff --check` passed for the touched files. + - 2026-05-01 native strategy/handler contract mirror cut: extended both + compiler handler rows and native strategy rows with the same strategy + identity hash, program-access count, state-carry rule count, and verified + rewrite flag. The fused C++ program span now carries those fields, matches + native strategy rows by strategy identity instead of opcode alone, and + rejects launches when the handler contract and native strategy contract do + not match. This closes the immediate duplicate-alias risk where current + reverse diagonal strategies share the same primitive opcode and handler + kind but represent different registered compiler strategies. + The C++ side does not maintain a second string-name strategy registry for + this check; it validates the compiler-emitted native rows against handler + rows and then checks executable forward access/state-carry row counts + against the selected strategy contract. + - Boundary status for this cut: this is a real ABI hardening step, not full + generic strategy-table closure. The native C++ callable set is still the + current supported strategy set, but it can no longer be selected by + fixed primitive opcode/handler kind alone. The next closure remains moving + the native callable registration itself to generated/registered compiler + strategy records, including memory/liveness and backward/reducer coverage. + - Verification for the native strategy/handler contract mirror cut: + `uv run python -m py_compile src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 2 tests. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py` + passed. 
+ `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_strategy_contract_v3 uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings --tb=short` + passed: 1 test. + - 2026-05-01 native callable contract cut: the native C++ callable + dispatch for forward message/readout and reverse transition now matches + the compiler-emitted strategy contract, not handler kind or primitive + opcode alone. Forward native callables require executor id, surface, + first primitive opcode, row count, strategy hash, access count, carry + count, and rewrite flag to match the fused program span. Reverse + transition callables do the same and split the `diag_rtu` and + `diagonal_recurrence` strategy hashes even though those lowered + primitives currently share the diagonal recurrence opcode and callable. + Reverse primitive input/parameter/output arity checks moved into the + native callable contract record before primitive math dispatch, so + primitive functions no longer own separate positional arity policy. + - Boundary status for this cut: this is another hardening step toward + generated/registered native callable tables. The supported callable set + is still narrow, but selection now requires the compiler strategy + identity and binding contract all the way to native dispatch. Remaining + closure is to generate or register the native callable metadata from the + same strategy registry product instead of maintaining static C++ records + by hand. + - Verification for the native callable contract cut: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_native_callable_contract_v2 uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings --tb=short` + passed: 1 test. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 128 tests. + - 2026-05-01 native callable row-source cut: C++ callable dispatch still + owns native function pointers, but it no longer duplicates compiler + strategy hashes or access/carry/rewrite metadata in static strategy + records. Forward message/readout dispatch and reverse transition + dispatch now resolve the exact `RegisteredNativeStrategyRow` emitted by + the compiler and use that row as the strategy contract before selecting + the native callable. The remaining C++ static tables are now callable + maps keyed by executor/handler/primitive function shape, not a second + source of strategy identity. + - Boundary status for this cut: this removes the most misleading duplicate + native strategy truth. It does not yet make new strategies fully + registration-only because adding a new native callable still requires a + C++ function-table entry. The next native-callable closure is generating + or registering those callable map entries from the same strategy registry + product, or making the callable id an explicit compiler-emitted field + with a typed missing-callable rejection. + - 2026-05-01 native callable id cut: executor strategy records now declare + a direction-scoped `native_callable` id, and `native_strategy_rows` + carry its stable hash as compiler data. 
The fused native C++ program + still owns compiled function pointers, but message/readout/reverse + transition callable tables are now keyed by the compiler-emitted + native callable hash instead of duplicating handler kind, executor id, + surface opcode, or primitive opcode in the callable structs. This is a + stronger fail-closed boundary: a strategy can be legal and still reject + at native launch with a typed missing-callable hash until the native + function is registered. + - Boundary status for this cut: adding a new native throughput strategy + still requires implementing the C++ function pointer, but the callable + selection identity now comes from the compiler strategy registry rather + than a hand-maintained C++ semantic contract mirror. The next closure is + moving native callable registration itself to generated/registered + compiler metadata. + - Verification for the native callable id cut: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_native_callable_id_v2 uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings --tb=short` + passed: 1 test. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 128 tests. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `git diff --check` passed. + - 2026-05-01 transition primitive callable-row cut: the compiler now emits + `transition_primitive_callable_rows` from the transition primitive + registry and passes those rows into the registered fused forward, + transition-forward, and reverse programs. Forward transition primitive + dispatch no longer selects native C++ primitive functions by primitive + opcode directly. The C++ path maps primitive opcode to a compiler-emitted + forward callable hash, then selects the registered native function pointer + by that callable id. Reference-only transition primitives are excluded + from the callable table and remain typed fail-closed for native execution. + - Boundary status for this cut: the active transition primitive dispatch now + carries compiler-owned callable identity to the native primitive function + boundary. The remaining native-callable closure is to generate or register + the native C++ function pointer tables from compiler strategy/primitive + metadata, and to mirror this callable-row ownership for any remaining + backward primitive maps that still have hand-maintained native entries. + - Verification for the transition primitive callable-row cut: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured --tb=short` + passed: 2 tests. 
+ `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_transition_callable_rows_v3 uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings --tb=short` + passed: 1 test. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 128 tests. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `git diff --check` passed. + - 2026-05-01 reverse transition backward-callable row cut: reverse transition + native dispatch now also consumes `transition_primitive_callable_rows`. + The reverse transition executor must match both the compiler strategy + native callable hash and the compiler-emitted primitive backward callable + hash for the recurrence primitive opcode before selecting the native C++ + reverse function pointer. This removes the remaining asymmetry where + forward primitive dispatch was callable-row owned while reverse transition + primitive ownership still came only from a hand-maintained native strategy + table. + - Boundary status for this cut: this does not generate native C++ function + pointers yet, but both forward and reverse transition primitive callable + identity now reaches native dispatch from compiler rows. The remaining + native-callable work is the generator/registration layer for those function + pointer tables, not another fixed primitive-opcode selection path. + - Verification for the reverse transition backward-callable row cut: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_backward_callable_rows_v2 uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings --tb=short` + passed: 1 test. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `git diff --check` passed. + - 2026-05-02 parameter reducer native-callable row cut: the registered + parameter reducer program no longer selects common reducer handlers or + transition trainable reducer handlers from fixed strategy-kind ids alone. + `parameter_reducer_strategy_rows` now carry the compiler-emitted native + callable hash for readout-output, sender-K/V, recurrent-query, + transition, and output-query reducers. `transition_trainable_rows` now + carry a native callable hash for each trainable transition reducer + (`materialized_base`, `materialized_delta`, + `value_to_cell_msg_to_cell`, `value_to_cell_msg_out`, + `recurrent_bias_slot_embed`, and `recurrent_bias_cell_bias_proj`). 
The + C++ side still owns the compiled function pointers, but it binds them by + compiler callable hash rather than treating reducer-kind integers as the + implementation selector. The old C++ transition-trainable reducer selector + constants were deleted after the hash-owned path passed. + - Boundary status for this cut: parameter reduction is still a native + reducer-function catalog, not a generated C++ callable table. This cut + removes another hand-owned selector from the active reverse/trainable + path and makes missing native reducer implementations fail by callable + id. Remaining closure is the same generator/registration layer for all + native function pointers plus executable memory/liveness ownership. + - Verification for the parameter reducer native-callable row cut: + `uv run python -m py_compile src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured --tb=short` + passed: 3 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_parameter_reducer_callable_rows_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_registered_parameter_reducer_cuda_executes_transition_trainable_rows --tb=short` + passed: 1 test. + After deleting the dead C++ transition-trainable reducer selector constants, + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_parameter_reducer_callable_rows_v2 uv run pytest -q tests/test_fabric_backend_plan.py::test_registered_parameter_reducer_cuda_executes_transition_trainable_rows --tb=short` + passed: 1 test. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 128 tests. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `git diff --check` passed. + - 2026-05-02 native callable compiler-catalog cut: native callable + identity now has a compiler-owned catalog in + `sequence_surface/compiler/native_callables.py`. Registered temporal + program construction builds `native_callable_catalog_rows`, validates that + every `native_strategy_rows` callable hash and every transition primitive + forward/backward callable hash is present in that catalog, and records the + catalog rows/summaries in runtime metadata. Parameter reducer and + transition-trainable reducer callable ids were moved out of the temporal + reducer implementation module into the same compiler catalog, so reducer + rows now consume catalog-owned ids rather than local private maps. 
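+  - Hedged sketch of the catalog membership check: every callable hash
+    referenced by strategy rows or transition primitive callable rows must
+    resolve in `native_callable_catalog_rows`, or the launch fails closed.
+    The row and exception names below are illustrative.
+
+    ```python
+    from dataclasses import dataclass
+
+
+    @dataclass(frozen=True)
+    class NativeCallableCatalogRow:
+        callable_hash: int
+        callable_name: str
+
+
+    class MissingNativeCallableError(RuntimeError):
+        """Typed rejection: a row references an unregistered native callable."""
+
+
+    def validate_callable_hashes(
+        catalog_rows: list[NativeCallableCatalogRow],
+        strategy_callable_hashes: list[int],
+        transition_primitive_callable_hashes: list[int],
+    ) -> None:
+        registered = {row.callable_hash for row in catalog_rows}
+        # Every strategy-level and transition-primitive-level callable id must
+        # resolve to a catalog entry before any native launch is allowed.
+        for source, hashes in (
+            ("native_strategy_rows", strategy_callable_hashes),
+            ("transition_primitive_callable_rows", transition_primitive_callable_hashes),
+        ):
+            for callable_hash in hashes:
+                if callable_hash not in registered:
+                    raise MissingNativeCallableError(
+                        f"{source} references callable hash {callable_hash:#x} "
+                        "that is not present in native_callable_catalog_rows"
+                    )
+    ```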
+ - C++ native dispatch status for this cut: the active CUDA program still + uses compiled function pointers, but the remaining tables have been moved + out of the monolithic CUDA program file into the dedicated native + implementation catalog header + `flat_bucket_registered_native_callables.cuh` + (`registered_native_transition_forward_primitive_catalog_*`, + `registered_native_forward_message_catalog_*`, + `registered_native_forward_readout_catalog_*`, + `registered_native_transition_reverse_primitive_catalog_*`, + `registered_native_parameter_reducer_catalog_*`, and + `registered_native_transition_trainable_reducer_catalog_*`) instead of + local `kExecutors`/`kStrategies`/`kHandlers` lists inside lookup + functions. This is not throughput closure and not the final generated + registry, but it removes the last scattered local selector-list shape and + gives the next pass a single generator/registration target. + - Verification for this cut: + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 2 tests. Full static guardrails, ruff, CUDA rebuild checks, and + `git diff --check` were then run: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 3 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 128 tests. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_native_callable_catalog_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings tests/test_fabric_backend_plan.py::test_registered_parameter_reducer_cuda_executes_transition_trainable_rows --tb=short` + passed: 2 tests. + After moving the implementation catalogs to + `flat_bucket_registered_native_callables.cuh`, + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_native_callable_header_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings tests/test_fabric_backend_plan.py::test_registered_parameter_reducer_cuda_executes_transition_trainable_rows --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 128 tests after the catalog-header move. 
+ `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed after the catalog-header move. + `git diff --check` passed. + - 2026-05-02 follow-up: the native callable C++ catalog is now a generated + source target owned by `compiler/native_callables.py`. The compiler + registry renders `flat_bucket_registered_native_callables.cuh`, records a + catalog fingerprint, and `test_temporal_executor_fusion_patterns_are_structured` + verifies the checked-in header exactly matches the compiler-generated + text. This closes the previous duplicated manual C++ catalog truth for + registered message/readout/transition/reducer native callables; adding a + callable now means updating the compiler registry/generator, not editing + a separate C++ selector table by hand. + - Same follow-up: executable memory/liveness ownership advanced one concrete + allocation class. The reverse full-step `grad_cells_out` workspace is now + a compiler runtime-buffer role (`reverse_grad_cells_work`) allocated by + `build_temporal_runtime_buffer_plan`, validated through runtime buffer + rows, and consumed by the fused reverse program. The old local + `at::zeros({B,total_cells,hidden})` allocation and allocating carry add + are replaced by a planned workspace buffer plus in-place accumulation. + This is not full memory closure; transition/message/readout primitive + temporaries and tape/recompute workspaces still need the same treatment. + - 2026-05-02 transition-forward runtime-buffer cut: transition `linear` + primitive outputs are now planned runtime buffers instead of local CUDA + output allocations. `build_temporal_runtime_buffer_plan` accepts + compiler-emitted `TemporalTransitionForwardRuntimeBufferRequest` entries + keyed by primitive row, emits the `transition_forward_linear_output` + runtime role, and the registered transition-forward primitive executor + retrieves that buffer by `(role, primitive_row_index)` before launching + the linear kernel. This covers both ordinary `[B,R,H]` transition affine + outputs and gate-affine `[B,R,4,H]` outputs; the gate-affine path now has + a direct-layout CUDA kernel that writes into the compiler-provided buffer + rather than allocating a temporary flat output and repacking it. The + recurrent matmul primitive output is also planned through the same request + table via `transition_forward_matmul_output`. Gated recurrence state + outputs (`next_y/c/n/m`) now use `transition_forward_state_output` + buffers keyed by output binding, and norm/public-output writes now use + `transition_forward_norm_output`. Diagonal RTU preprojection/state/trace + outputs now use `transition_forward_diag_output`, also keyed by output + binding. Full forward and reverse recompute paths both build these + requests from registered transition executor bindings; the direct + per-surface transition caller now also supplies runtime rows for the same + roles. 
The remaining transition-forward memory work is to enforce alias + reuse instead of merely allocating distinct planned buffers, then apply + the same runtime-buffer ownership to message/readout temporaries and + backward-side workspaces. + - 2026-05-02 forward message/readout runtime-buffer cut: the fused forward + program now allocates per-physical-step compiler runtime buffers for + recurrent message output, readout/output message output, and readout output + cells. The active fused message/readout strategy signatures receive + `runtime_buffer_tensors` plus `runtime_buffer_rows`, lookup + `forward_recurrent_msg`, `forward_output_msg`, and `forward_output_cells` + by `(role, physical_step)`, and launch the partitioned attention/readout + projection kernels directly into those buffers. This preserves reverse + artifact tensor identity without cloning while removing another active-path + local allocation class from the fused forward program. Remaining memory + closure is now concentrated in sender K/V buffers, sparse/direct helper + outputs, readout/recurrent-message backward gradients, parameter-reducer + scratch/full accumulators, and real alias reuse across compatible lifetimes. + - Same pass, reverse recurrent-message gradient workspace cut: the fused + reverse program now obtains `grad_recurrent_msg` from a compiler runtime + buffer role (`reverse_grad_recurrent_msg`) and zeros/fills that planned + buffer before running the recurrent-message backward helper. This removes + the previous local `at::zeros({B,recurrent_count,value_dim})` allocation + inside the reverse full-step loop and makes another backward workspace part + of the executable memory/liveness product. + - Verification for the follow-up so far: + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes --tb=short` + passed: 2 tests. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py tests/test_fabric_backend_plan.py` + passed. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_native_memory_v2 uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings tests/test_fabric_backend_plan.py::test_registered_parameter_reducer_cuda_executes_transition_trainable_rows --tb=short` + passed: 2 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_native_memory_v2 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_transition_forward_buffers_v6 uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings --tb=short` + passed: 1 test after the transition-forward runtime-buffer cut. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_transition_forward_buffers_v6 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test after the reverse recompute path consumed the same + runtime-buffer role table. 
+ `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes --tb=short` + passed: 4 tests after the transition-forward runtime-buffer cut. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py tests/test_fabric_backend_plan.py` + passed after the transition-forward runtime-buffer cut. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 128 tests after the transition-forward runtime-buffer cut. + `git diff --check` passed. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes --tb=short` + passed: 2 tests after the forward message/readout runtime-buffer cut. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_forward_msg_runtime_buffers_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings --tb=short` + passed: 1 test after the forward message/readout runtime-buffer cut. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_forward_msg_runtime_buffers_v1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test after the forward message/readout runtime-buffer cut. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py` + passed after the reverse recurrent-message gradient workspace cut. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes --tb=short` + passed: 2 tests after the reverse recurrent-message gradient workspace cut. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_forward_msg_runtime_buffers_v2 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test after the reverse recurrent-message gradient workspace cut. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_forward_msg_runtime_buffers_v2 uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings --tb=short` + passed: 1 test after the reverse recurrent-message gradient workspace cut. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 128 tests after the forward message/readout and reverse recurrent-message runtime-buffer cuts. 
+ `git diff --check` passed after the forward message/readout and reverse + recurrent-message runtime-buffer cuts. + - 2026-05-02 reverse seed-role compiler-row cut: transition reverse seed + identity is no longer owned by a temporal-runtime `_TRANSITION_*` map. + `compiler/native_callables.py` now emits + `transition_reverse_seed_role_rows` for `grad_public_y` and every + `grad_next_*` state seed; `RegisteredTemporalExecutorProgram` stores that + row table; the fused backward launch contract lists it as a required + compiler table; and the C++ fused backward program validates seed rows, + dynamic binding rows, public-Y carry injection, and next-step seed outputs + against that row table instead of accepting a fixed seed-role range. + Public-Y seed rows now carry the compiler-emitted seed role id too, so the + reverse full-step loop no longer inserts a hardcoded public seed role. + This is a real Pass 3 ABI cut, not throughput work. Remaining reverse + genericity is now in the reverse helper internals and native transition + reverse handler coverage, not in Python seed-state loop ownership. + - Verification for the reverse seed-role compiler-row cut: + `python -m py_compile src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_transition_reverse_seed_roles_are_compiler_owned_rows tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed: 3 tests. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_seed_roles_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings --tb=short` + passed: 1 test. + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_seed_roles_v3 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed: 1 test after tightening the native callable output-row selector + so exact compiler output binding rows win over shape-compatible fallback + rows for multi-output primitives such as `diag_rtu`. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 129 tests. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py` + passed. + `git diff --check` passed. 
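+  - Illustrative construction of `transition_reverse_seed_role_rows`: one
+    public-Y seed row plus one `grad_next_*` row per transition state, with no
+    hardcoded public seed role injected by the reverse loop. The dataclass and
+    the example state names are assumptions used only to show the row shape.
+
+    ```python
+    from dataclasses import dataclass
+
+
+    @dataclass(frozen=True)
+    class ReverseSeedRoleRow:
+        """Illustrative reverse seed row: one per gradient seed tensor."""
+
+        seed_role_id: int
+        seed_name: str        # "grad_public_y" or "grad_next_<state>"
+        is_public_seed: bool
+
+
+    def build_reverse_seed_role_rows(state_names: list[str]) -> list[ReverseSeedRoleRow]:
+        # The public-Y seed is a compiler-emitted row like any other seed, so
+        # the reverse full-step loop never inserts a hardcoded public seed role.
+        rows = [
+            ReverseSeedRoleRow(seed_role_id=0, seed_name="grad_public_y", is_public_seed=True)
+        ]
+        for offset, state in enumerate(state_names, start=1):
+            rows.append(
+                ReverseSeedRoleRow(
+                    seed_role_id=offset,
+                    seed_name=f"grad_next_{state}",
+                    is_public_seed=False,
+                )
+            )
+        return rows
+
+
+    # Example: a gated cell carrying y/c/n/m state would emit five seed rows.
+    rows = build_reverse_seed_role_rows(["y", "c", "n", "m"])
+    assert [r.seed_name for r in rows] == [
+        "grad_public_y",
+        "grad_next_y",
+        "grad_next_c",
+        "grad_next_n",
+        "grad_next_m",
+    ]
+    ```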
+ - 2026-05-02 registered program source split: the + `flat_bucket_registered_program_kernels.cu` translation unit is now an + ordered include list over semantic registered-program headers instead of a + single monolithic CUDA file. The split keeps one NVCC translation unit while + moving shared constants/checks, span decoding, memory/runtime buffers, + native callable binding helpers, program spans/handlers, reverse artifacts + and reset helpers, program tensor access, transition device kernels, + transition math helpers, layout kernels, forward program, reverse helpers, + transition forward/reverse programs, parameter reducers, and operator + exports into separate files under `flat_bucket/registered_program/`. + Boundary tests now read that semantic source set instead of only the small + include driver. This is cleanup for maintainability of the real registered + compiler path, not a new compatibility layer. + - Verification for the registered program source split: + `python -m py_compile tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py` + passed. + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_registered_program_split2 uv run pytest -q tests/test_fabric_backend_plan.py::test_program_transition_linear_forward_cuda_uses_compiler_rows --tb=short` + passed: 1 test and rebuilt the registered program CUDA extension. + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_registered_program_split2 uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py --tb=short` + passed: 125 tests. + - 2026-05-02 reverse native transition binding-schema cut: the diagonal RTU + reverse native handler no longer treats `inputs[n]`, `params[n]`, or + `outputs[n]` as semantic role truth. It now mirrors the gated reverse + handler by resolving every input, parameter, and output through + `native_callable_program_binding_for(...)` and the compiler-emitted native + callable binding schema rows. This closes the remaining positional role + leak inside the registered transition reverse native handlers; remaining + reverse genericity work is now in message/readout/KV helper strategy + dispatch and in generated-catalog validation, not in transition reverse + handler positional ABIs. + - 2026-05-02 forward native strategy access-schema cut: forward + message/readout handlers no longer fetch their parameters by central + `kProgramAccessMessage*` or `kProgramAccessReadout*` constants inside the + handler bodies. `native_callables.py` now emits native callable binding + schema rows for executor-strategy program accesses, and the registered + forward message/readout handlers resolve their tensors through the selected + native strategy row plus those logical access schema rows. The current + dot-product/readout implementations remain the only registered strategies, + but their binding contract is now strategy-owned rather than fused-program + fixed-access-opcode owned. + - 2026-05-02 reverse message/readout native access-schema cut: reverse + readout, output-message, recurrent-K/V, recurrent-message, initial-K/V, and + boundary-K/V helpers now resolve message/readout tensors through + `program_tensor_for_native_strategy_access(...)`, the selected native + strategy row, and native callable binding schema rows. The helper bodies no + longer reference central `kProgramAccessMessage*` or `kProgramAccessReadout*` + constants. 
This does not add new reverse strategies yet, but it removes the + fixed access-opcode role truth from the active reverse helper internals. + - 2026-05-02 forward transition/public-state native access-schema cut: + the active fused forward program no longer uses `kProgramAccessTransition*` + constants, `forward_program_tensor_for_access_opcode(...)`, or + access-opcode binding helpers to write aggregated transition inputs or read + transition public outputs. Transition aggregate input and public-state output + now resolve through `program_binding_for_native_strategy_access(...)` / + `program_tensor_for_native_strategy_access(...)`, the compiler-emitted + native strategy row, and the native callable binding schema. + - 2026-05-02 reverse helper native implementation-dispatch cut: reverse + readout/message/KV helper implementations now route through generated native + reverse readout/message catalogs keyed by compiler-emitted native callable + hashes. The current attention, readout projection, and sender-KV kernels are + still the only registered implementations, but they are now strategy + implementations selected by native strategy rows rather than direct choices + inside the public reverse helper bodies. + - 2026-05-02 message step-index runtime-buffer cut: forward and reverse + registered programs no longer allocate per-step `step_flat` tensors with + local `at::full(...)` calls in the fused scheduler. The compiler memory + plan now emits `forward_message_step_flat` and `reverse_message_step_flat` + runtime-buffer roles with `torch.int64` dtype, the Python allocator supports + mixed float/int64 planned buffers, and C++ runtime-buffer validation accepts + int64 only for these message-step roles. Forward recurrent/output message + execution and reverse recurrent/output message adjoints now fill and reuse + those compiler-planned buffers. This is a memory/liveness execution cut: the + scheduler no longer creates hidden message-step workspace outside the + runtime buffer plan. + - 2026-05-02 transition primitive-DAG direct-eager deletion cut: the direct + population transition lowering path no longer has an eager primitive-DAG + executor at all. Primitive-DAG transition programs remain valid compiler + products, but CUDA execution is now `registered_program_only`: supported + temporal training must lower through the registered temporal executor + program, and unsupported direct population lowering fails closed before any + per-op eager interpreter can run. Source guards now reject + `_lower_transition_primitive_dag_forward`, + `_transition_primitive_dag_forward_executors`, direct + `primitive_forward_executors.get(...)` dispatch, and + `primitive_dag_eager_forward` runtime status in transition lowering. This + removes a real legacy execution sibling rather than relabeling it. + - Verification for the 2026-05-02 binding/access schema cuts: + `python -m py_compile src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py tests/test_fabric_backend_boundaries.py` + passed. 
+ `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_forward_message_readout_handlers_use_native_strategy_access_schema tests/test_fabric_backend_boundaries.py::test_reverse_message_readout_helpers_use_native_strategy_access_schema tests/test_fabric_backend_boundaries.py::test_reverse_transition_native_handlers_use_logical_binding_schema --tb=short` + passed: 3 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_forward_transition_access_uses_native_strategy_access_schema tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_reverse_message_readout_helpers_use_native_strategy_access_schema tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured --tb=short` + passed: 2 tests. + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_native_strategy_access_current_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_program_transition_linear_forward_cuda_uses_compiler_rows --tb=short` + passed: 1 test and rebuilt the registered-program CUDA extension against + the current tree. + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_forward_native_transition_access_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_program_transition_linear_forward_cuda_uses_compiler_rows --tb=short` + passed: 1 test and rebuilt the registered-program CUDA extension after the + transition/public-state access cut. + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_strategy_dispatch_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_program_transition_linear_forward_cuda_uses_compiler_rows --tb=short` + passed: 1 test and rebuilt the registered-program CUDA extension after the + reverse native helper-dispatch cut. + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py --tb=short` + passed: 128 tests. + `git diff --check` passed. + - Verification for the message step-index runtime-buffer and transition + primitive-DAG dispatch/direct-eager deletion cuts: + `python -m py_compile src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/transition_execution/lowering.py` + passed. + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/transition_execution/lowering.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py` + passed. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_transition_executor_selection_uses_registered_structural_records tests/test_fabric_backend_boundaries.py::test_transition_execution_monolith_was_deleted --tb=short` + passed: 2 tests. + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_planned_step_flat_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_program_transition_linear_forward_cuda_uses_compiler_rows --tb=short` + passed: 1 test and rebuilt the registered-program CUDA extension. 
+ `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 133 tests after updating the import guard to read the semantic + registered-program source split. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_rejected_transition_temporal_fusion_facades_were_deleted tests/test_fabric_backend_plan.py::test_primitive_dag_transition_direct_lowering_is_registered_program_only tests/test_fabric_backend_plan.py::test_temporal_forward_executor_rows_fall_back_to_registered_primitive_dag_strategies tests/test_fabric_backend_plan.py::test_stateful_transition_primitive_dag_derives_carry_rows_from_compiler_bindings tests/test_fabric_backend_plan.py::test_stateful_transition_primitive_dag_registers_dynamic_reverse_seed_roles tests/test_fabric_backend_plan.py::test_temporal_reverse_executor_rows_cover_parameterless_primitive_dag_adjoint tests/test_fabric_backend_plan.py::test_transition_executor_selection_uses_registered_structural_records --tb=short` + passed: 7 tests. + - 2026-05-02 registered-program allocation ownership audit: added a compiler + audit for allocation sites in the semantic registered-program CUDA headers. + The audit classifies every `at::empty`, `at::zeros`, `at::full`, + `zeros_like`, `empty_like`, and tensor `new_empty`/`new_zeros` use as a + primitive output, planned runtime buffer, metadata row, or illegal scheduler + allocation. Current registered-program sources classify as primitive + outputs or metadata rows, while the previously hidden scheduler workspaces + have been moved to planned runtime buffers. The boundary test now fails if + a new allocation appears without an explicit owner, which keeps future + memory/liveness cuts honest instead of letting new unplanned workspace drift + back into the fused scheduler. + - 2026-05-02 active-path wording cleanup: active CUDA compiler-path errors + and metadata no longer describe removed Python/per-step/monolithic routes + as fallbacks. Unsupported execution now reports the required compiler-owned + product directly: planner-lowered K schedule rows, boundary sequences, + registered fused forward/reverse programs, registered executor bindings, + compiler memory-plan windows, and compiler-owned reducer programs. This is + a small governance cut but it matters for review hygiene: the code should + present one real compiler path rather than explaining which deleted path it + refused to run. + - 2026-05-02 DotProduct semantic stress-gate start: added an explicit + `fixed_slot_context_nudge` `DotProduct` math variant at the declaration and + message-rule registry layer. The new rule lowers to distinct compiler-owned + primitive rows for receiver public context, sender slot identity, query + nudge scaling, dynamic value projection, output projection, and message + normalization. This is intentionally not claimed as active CUDA support yet: + executor selection fails closed until a registered forward/reverse native + strategy implements that row group. This is the first piece of the required + pre-throughput stress stage; the remaining work is the actual native + executor, tensor-role bindings, public-update gate/norm transition rows, + and backward/parameter-gradient contracts. 
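+ - Illustrative sketch for the fail-closed executor selection described in the
+   stress-gate start entry above. This is a minimal, non-authoritative Python
+   sketch: `PrimitiveRow`, `StrategyRecord`, and `select_forward_strategy` are
+   hypothetical names rather than the real compiler API; only the strategy id
+   and the opcode sequence come from the entries in this document.
+
+   ```python
+   from dataclasses import dataclass
+
+
+   @dataclass(frozen=True)
+   class PrimitiveRow:
+       opcode: str
+
+
+   @dataclass(frozen=True)
+   class StrategyRecord:
+       strategy_id: str
+       opcode_signature: tuple
+
+
+   def select_forward_strategy(rows, registry):
+       """Match the bucket's primitive-row opcodes against registered strategy
+       signatures; fail closed instead of guessing an implementation."""
+       signature = tuple(row.opcode for row in rows)
+       for record in registry:
+           if record.opcode_signature == signature:
+               return record
+       raise LookupError(f"no registered forward strategy for {signature}")
+
+
+   registry = [
+       StrategyRecord(
+           strategy_id="native.forward.msg_fixed_slot_context_nudge.v1",
+           opcode_signature=(
+               "linear", "linear", "mul", "concat", "linear", "concat",
+               "linear", "attention_logits", "add", "segment_softmax",
+               "weighted_sum", "linear", "normalize",
+           ),
+       ),
+   ]
+   rows = [PrimitiveRow(op) for op in registry[0].opcode_signature]
+   assert select_forward_strategy(rows, registry) is registry[0]
+   ```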
+ - 2026-05-02 DotProduct semantic stress-gate strategy contract: the
+   `fixed_slot_context_nudge` message row group now registers explicit
+   forward and reverse compiler strategy patterns over its full primitive
+   sequence:
+   `linear+linear+mul+concat+linear+concat+linear+attention_logits+add+segment_softmax+weighted_sum+linear+normalize`.
+   The temporal primitive registry now includes the `normalize` opcode, and
+   the native callable catalog and binding schema include
+   `native.forward.msg_fixed_slot_context_nudge.v1` and
+   `native.reverse.msg_fixed_slot_context_nudge.v1`. This checkpoint
+   originally blocked executable planning with `UNVERIFIED_REWRITE`; the
+   later native forward/reverse/reducer cuts removed that gate after the
+   compiler-owned native implementations landed. The important contract
+   remains: new message math is visible to the compiler through rows,
+   bindings, native callable ids, and typed legality, without editing
+   scheduler bodies or pretending CUDA support exists before the native
+   implementation is proven.
+ - Verification for the DotProduct semantic stress-gate strategy contract:
+   `uv run pytest -q tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_dot_product_lowers_as_distinct_message_program tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_dot_product_matches_strategy_but_requires_verified_rewrite --tb=short`
+   passed: 2 tests.
+   `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_dot_product_matches_strategy_but_requires_verified_rewrite --tb=short`
+   passed: 2 tests after updating the registry-structure guard to require
+   exactly the context-nudge strategies to be rewrite-gated.
+   `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_context_nudge_catalog_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_program_transition_linear_forward_cuda_uses_compiler_rows --tb=short`
+   passed: 1 test and rebuilt the registered-program CUDA extension with the
+   generated context-nudge native callable catalog entries.
+   `uv run pytest -q tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py tests/test_fabric_execution_imports.py --tb=short`
+   passed: 136 tests. `uv run ruff check` for the touched compiler/test
+   Python files, native callable header validation, and `git diff --check`
+   also passed.
+ - 2026-05-02 DotProduct semantic stress-gate parameter-binding cut:
+   `message_parameter:*` compiler sources now resolve into executable
+   program tensor-table slots instead of being metadata-only declarations.
+   The fixed-slot context-nudge runtime owns the new semantic parameters for
+   query-slot projection, sender-slot key projection, query nudge scale, and
+   sender context key; static materialization now exposes compiler-named
+   tensors for query slot, nudge scale, sender slot key, sender context key,
+   input/group/recurrent value projection, and message output projection.
+   The executable surface parameter table resolves these through the
+   compiler binding rows and still fails closed for unresolved message
+   parameters; it does not synthesize empty fallback tensors. At this
+   checkpoint the native forward/reverse context-nudge callables remained
+   blocked by `UNVERIFIED_REWRITE`; the later native forward/reverse/reducer
+   cuts consume these bindings directly and remove that strategy gate.
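+ - Illustrative sketch for the fail-closed parameter-table resolution in the
+   parameter-binding cut above. Non-authoritative: `resolve_parameter_table`
+   and the exact `message_parameter:*` source spelling are assumptions; only
+   the fail-closed rule (no synthesized fallback tensors) comes from the
+   entry itself.
+
+   ```python
+   import torch
+
+
+   def resolve_parameter_table(binding_rows, named_sources):
+       """Resolve compiler binding rows into tensor-table slots; unresolved
+       message parameters fail closed instead of getting fallback tensors."""
+       table = {}
+       for row in binding_rows:
+           tensor = named_sources.get(row["source"])
+           if tensor is None:
+               raise KeyError(f"unresolved compiler source {row['source']!r}")
+           table[row["slot"]] = tensor
+       return table
+
+
+   sources = {"message_parameter:query_nudge_scale": torch.ones(1)}
+   rows = [{"slot": "message_query_nudge_scale",
+            "source": "message_parameter:query_nudge_scale"}]
+   table = resolve_parameter_table(rows, sources)
+   assert table["message_query_nudge_scale"].shape == (1,)
+   ```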
+ - Verification for the DotProduct semantic stress-gate parameter-binding cut: + `uv run pytest -q tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_message_parameters_materialize_from_compiler_bindings tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_dot_product_matches_strategy_but_requires_verified_rewrite tests/test_fabric_backend_plan.py::test_surface_parameter_tensor_table_fails_closed_on_unresolved_compiler_source --tb=short` + passed: 3 tests. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_message_parameters_materialize_from_compiler_bindings tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings --tb=short` + passed: 2 tests, including the boundary guard that rejects hidden + tensor-table fallback allocation. + `uv run pytest -q tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py tests/test_fabric_execution_imports.py --tb=short` + passed: 137 tests. + `uv run ruff check` for the touched runtime/compiler/test Python files + passed. + - 2026-05-02 DotProduct semantic stress-gate forward-native cut: + the generated forward native callable catalog now points + `native.forward.msg_fixed_slot_context_nudge.v1` at real registered-program + symbols instead of pending stubs. The forward program binds the compiler + access rows for query-slot weight, query nudge scale, sender slot key, + sender context key, input/group/recurrent value weights, and message output + weight. The CUDA registered-program split now includes context-nudge key + bank materialization, value-only sender projection, and a context-nudge + attention/message kernel that emits the message-rule-owned output width. + This was not full stress-gate closure at the time of the cut: the strategy + still failed closed at executable planning because reverse/backward support + was incomplete. The later reverse-native and parameter-reducer cuts removed + the rewrite gate; the remaining activation blocker is now end-to-end + training through fused forward artifact tensors, including pooled readout + output-contract handling. + - Verification for the DotProduct semantic stress-gate forward-native cut: + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_context_nudge_forward_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_program_transition_linear_forward_cuda_uses_compiler_rows --tb=short` + passed: 1 test and rebuilt the registered-program CUDA extension with the + real forward context-nudge symbols in the generated callable catalog. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_message_parameters_materialize_from_compiler_bindings tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_dot_product_matches_strategy_but_requires_verified_rewrite --tb=short` + passed: 2 tests. + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_registered_program_allocations_are_compiler_classified --tb=short` + passed after classifying the new context-nudge key/value projection + allocations as primitive outputs. + `uv run pytest -q tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py tests/test_fabric_execution_imports.py --tb=short` + passed: 137 tests. + - 2026-05-02 DotProduct semantic stress-gate reverse-native cut: + reverse message strategy callables now own their compiler binding lookup + instead of receiving dynamic-dot-product-specific tensors from the shared + reverse step. 
The shared recurrent K/V, recurrent-message, initial-K/V, + and boundary-K/V reverse steps now pass the compiler program tensor table, + reverse access rows, native strategy row, and native callable binding + schema to the selected reverse message strategy. The generated native + callable catalog points `native.reverse.msg_fixed_slot_context_nudge.v1` + at real context-nudge reverse symbols, and the pending reverse symbols were + removed. The context-nudge reverse native path now includes a + registered-program CUDA kernel for the PR-10-style output-projected and + normalized message adjoint, plus value-only recurrent and boundary + projection adjoints that consume compiler-owned logical bindings. The + strategy intentionally remained blocked at this checkpoint: output readout + key-width compatibility and context-nudge-specific parameter reducers still + needed closure before the stress gate could become an active trainable + path. Those native/reducer blockers have since been addressed; the active + blocker is the fused forward training artifact path and end-to-end parity. + - Verification for the DotProduct semantic stress-gate reverse-native cut: + `uv run pytest -q tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_dot_product_matches_strategy_but_requires_verified_rewrite --tb=short` + passed: 1 test. + `uv run pytest -q tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed: 104 tests. + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_context_nudge_reverse_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_program_transition_linear_forward_cuda_uses_compiler_rows --tb=short` + passed: 1 test and forced a fresh registered-program CUDA extension build + with the new reverse context-nudge native symbols and kernel compiled. + - 2026-05-02 DotProduct semantic stress-gate key-width cut: + registered attention/readout program kernels no longer assume the message + key bank width equals the query width. The partitioned attention, sparse + attention, and readout-layout epilogue entrypoints now carry `head_dim` + from the selected query row and `key_dim` from the compiler-bound key bank + separately, requiring only `key_dim >= head_dim` for strategies that read a + prefix of a widened key bank. This is the compiler-owned shape rule needed + by the fixed-slot context-nudge message strategy, whose key bank is + `slot_key || context_key`. The temporal scheduler still sees generic + message/readout executor rows; the widened key layout is contained inside + the registered primitive strategy and its tensor bindings. + - Verification for the DotProduct semantic stress-gate key-width cut: + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_key_dim_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_program_transition_linear_forward_cuda_uses_compiler_rows --tb=short` + passed: 1 test and rebuilt the registered-program CUDA extension from + source with separated query/key-bank widths. Focused strategy/access tests + passed with + `uv run pytest -q tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_dot_product_matches_strategy_but_requires_verified_rewrite tests/test_fabric_backend_boundaries.py::test_reverse_message_readout_helpers_use_native_strategy_access_schema --tb=short`. + `uv run ruff check` for the touched compiler/test Python files, + native-callable generated-header validation, and `git diff --check` passed. 
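+ - Illustrative sketch of the widened key-bank rule from the key-width cut
+   above. This is not the CUDA kernel; `prefix_attention_logits` is a
+   hypothetical helper that only demonstrates the compiler-owned shape rule:
+   queries of width `head_dim` attend against the first `head_dim` columns of
+   a key bank whose width `key_dim` may be larger, e.g. `slot_key ||
+   context_key`.
+
+   ```python
+   import torch
+
+
+   def prefix_attention_logits(query, key_bank):
+       """Attend a head_dim-wide query against the head_dim prefix of a
+       possibly wider key bank; require key_dim >= head_dim."""
+       head_dim = query.shape[-1]
+       key_dim = key_bank.shape[-1]
+       if key_dim < head_dim:
+           raise ValueError("strategy requires key_dim >= head_dim")
+       return query @ key_bank[..., :head_dim].transpose(-1, -2)
+
+
+   q = torch.randn(4, 16)    # head_dim = 16, from the selected query row
+   kb = torch.randn(6, 32)   # key_dim = 32, the compiler-bound widened bank
+   assert prefix_attention_logits(q, kb).shape == (4, 6)
+   ```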
+ - 2026-05-02 DotProduct semantic stress-gate reverse parameter-output cut: + the context-nudge reverse message kernel now materializes gradients for + the rule-owned `message_query_nudge_scale` scalar and + `message_output_weight` projection in addition to the existing query/key/ + value adjoints. The fused backward program propagates those additional + context-nudge output slots through the compiler-owned boundary output group + instead of dropping them inside the C++ strategy. This was not activation + at the time: the reducer path still needed dedicated context-nudge + parameter reducer rows before the rewrite gate could be removed. + - Verification for the DotProduct semantic stress-gate reverse + parameter-output cut: + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_context_nudge_param_outputs_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_program_transition_linear_forward_cuda_uses_compiler_rows --tb=short` + passed: 1 test after compiling the context-nudge scalar/output-weight + gradient kernel outputs. + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_context_nudge_reducer_v2 uv run pytest -q tests/test_fabric_backend_plan.py::test_program_transition_linear_forward_cuda_uses_compiler_rows --tb=short` + passed: 1 test after recompiling the registered-program extension with + context-nudge output-group propagation and the new reducer ABI. + - 2026-05-02 DotProduct semantic stress-gate parameter-reducer cut: + fixed-slot context-nudge message backward now has a registered + `context_nudge_message` parameter reducer strategy selected from compiler + reducer rows, not from the old recurrent-query or sender-K/V reducer + assumptions. The reducer owns query-slot projection gradients, + sender-slot-key projection gradients, direct sender-context-key gradients, + the rule scalar `message_query_nudge_scale`, and `message_output_weight`. + Runtime metadata rows now expose `input_cell_idx` alongside recurrent and + backend-order metadata so input/recurrent key-bank gradients reduce through + row-owned cell-index bindings. Value-only sender projection gradients are + still allowed through the shared sender projection reducer, but only after + validating the context-nudge value-weight logical bindings. This was still + not final activation at the time: active end-to-end parity still had to + prove the full semantic rewrite and gradient contract before the temporary + rewrite gate could be removed. + - Verification for the DotProduct semantic stress-gate parameter-reducer cut: + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_context_nudge_reducer_v2 uv run pytest -q tests/test_fabric_backend_plan.py::test_registered_parameter_reducer_cuda_executes_transition_trainable_rows --tb=short` + passed: 1 test, including the new context-nudge reducer case over + compiler trainable-role rows and runtime metadata rows. + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_dot_product_matches_strategy_but_requires_verified_rewrite --tb=short` + passed: 2 tests and validated the generated native callable catalog now + includes `native.reverse.parameter_reduction.context_nudge_message.v1`. + `uv run ruff check ... && git diff --check` passed for the touched + compiler/runtime/test files. + - 2026-05-02 DotProduct semantic stress-gate activation cut: + the context-nudge forward and reverse strategy patterns no longer require + the temporary `UNVERIFIED_REWRITE` legality gate. 
+   Executable planning now selects the registered context-nudge
+   forward/reverse strategies and their dedicated parameter reducer from
+   compiler rows and binding tables. The active failure moved out of
+   strategy selection: training activation now requires the fused forward
+   program to serve the compiler-planned output contract and reverse
+   artifact tensor table end to end. The current hard target is
+   pooled-output fused forward artifact execution followed by reverse
+   tensor-store consumption and parity for the PR-10-style message math.
+ - Verification for the DotProduct semantic stress-gate activation cut:
+   `uv run pytest -q tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_dot_product_selects_registered_strategy tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured --tb=short`
+   passed: 2 tests, verifying that context-nudge executable plans are legal
+   without rewrite blockers.
+ - 2026-05-02 DotProduct semantic stress-gate fused-training cut:
+   the fused forward program now supports compiler-planned pooled output
+   buffers while still emitting raw output-cell reverse artifacts. The fused
+   reverse full-step program now collects `grad_aggregated_message` into a
+   compiler-sized recurrent-message buffer instead of assuming `value_dim`,
+   and its boundary output group accepts strategy-declared extra parameter
+   gradient slots. Context-nudge transition input-projection gradients now
+   route directly to the compiler-selected `message_to_cell_weight` source
+   instead of the old composed `value_to_cell_weight` reducers. A small CUDA
+   training probe with the declared PR-10-style context-nudge DotProduct
+   rule now runs through the fused artifact tensor-store and registered
+   reverse program end to end, producing gradients for boundary input,
+   `msg_to_cell.weight`, `msg_out.weight`, and `message_query_nudge_scale`.
+ - Verification for the DotProduct semantic stress-gate fused-training cut:
+   `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_context_nudge_boundary_group_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_program_transition_linear_forward_cuda_uses_compiler_rows --tb=short`
+   passed after recompiling the registered program extension. The manual
+   CUDA stress probe reported
+   `registered_temporal_fused_backward_program_cuda;window_start=0;window_len=8`
+   with reverse artifact roles
+   `boundary_step,cells_prev,input_k,input_v,recurrent_k_before,recurrent_v_before,recurrent_k,recurrent_v,recurrent_hidden_before_backend_order,recurrent_hidden_backend_order,recurrent_msg_backend_order,output_msg,output_cells,transition_state_before`.
+ - 2026-05-02 message reverse parameter-output metadata cut: removed the
+   context-nudge-specific branch from the temporal reverse scheduler.
+   Message reverse strategies now declare parameter-gradient outputs in the
+   strategy registry, including whether each output comes from the
+   recurrent-query gradient slot or from extra boundary output slots, and
+   `temporal/registered_executors.py` builds generic message-strategy
+   reducer requests from those compiler-owned rows. This keeps the
+   context-nudge CUDA primitive and reducer implementation, but the temporal
+   scheduler no longer branches on
+   `fixed_slot_context_nudge_message_backward` or positional strategy names.
+   Remaining reducer work is to make the reducer C++ ABI itself accept
+   generic strategy-declared output roles rather than the current
+   context-nudge-specific reducer tensor groups.
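+ - Illustrative sketch for the strategy-declared parameter-gradient outputs
+   in the metadata cut above. Non-authoritative: `ParameterGradOutput`,
+   `build_reducer_requests`, and the role/source strings are hypothetical
+   names; only the two declared gradient sources (recurrent-query slot vs
+   extra boundary slot) and the no-strategy-name-branching rule come from
+   the entry itself.
+
+   ```python
+   from dataclasses import dataclass
+
+
+   @dataclass(frozen=True)
+   class ParameterGradOutput:
+       role: str     # e.g. "grad_query_scale" (hypothetical role string)
+       source: str   # "recurrent_query_slot" or "extra_boundary_slot"
+       slot: int     # index into the strategy's extra boundary outputs
+
+
+   def build_reducer_requests(declared_outputs, recurrent_query_grad,
+                              extra_boundary_grads):
+       """Turn strategy-declared gradient-output rows into reducer requests
+       without branching on any strategy name."""
+       requests = []
+       for out in declared_outputs:
+           if out.source == "recurrent_query_slot":
+               tensor = recurrent_query_grad
+           elif out.source == "extra_boundary_slot":
+               tensor = extra_boundary_grads[out.slot]
+           else:
+               raise ValueError(f"unknown gradient source {out.source!r}")
+           requests.append((out.role, tensor))
+       return requests
+
+
+   declared = [ParameterGradOutput("grad_query_scale", "extra_boundary_slot", 0)]
+   reqs = build_reducer_requests(declared, recurrent_query_grad=None,
+                                 extra_boundary_grads=["g0"])
+   assert reqs == [("grad_query_scale", "g0")]
+   ```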
+ - Verification for the message reverse parameter-output metadata cut: + `uv run ruff check ...` passed for the touched compiler/temporal/test files; + `python -m compileall -q ...` passed; targeted pytest passed for + `tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings`, + `tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_dot_product_selects_registered_strategy`, + `tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_message_parameters_materialize_from_compiler_bindings`, + and CUDA + `tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program`. + - 2026-05-02 message output-dimension role cut: message rule compilation now + carries an explicit `output_dim_role` from `MessageRuleBackendSpec` through + `MessageRuleIR` into `CompiledMessageRule`. The shared temporal helper + `temporal_message_output_dim()` no longer imports or branches on + `DOT_PRODUCT_FIXED_SLOT_CONTEXT_NUDGE`; it maps the compiler-owned output + dimension role (`value_dim` or `d_msg`) to the runtime dimension. This + removes another temporal-side message-rule identity check while preserving + the current baseline and PR-10-style context-nudge output widths. + - Verification for the message output-dimension role cut: + `python -m compileall -q ...` and `uv run ruff check ...` passed for the + touched message-rule, temporal-common, and test files. Targeted pytest + passed for + `tests/test_fabric_backend_plan.py::test_default_message_rule_contract_is_planner_visible`, + `tests/test_fabric_backend_plan.py::test_message_rule_backend_specs_are_registered_like_cell_specs`, + `tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_dot_product_lowers_as_distinct_message_program`, + `tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_message_parameters_materialize_from_compiler_bindings`, + `tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings`, + and CUDA + `tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program`. + - 2026-05-02 message runtime materialization metadata cut: runtime-owned + message parameters and static tensors are now declared by the compiled + message program instead of a runtime branch on + `DOT_PRODUCT_FIXED_SLOT_CONTEXT_NUDGE`. `MessageRuleBackendSpec` now + carries runtime module specs, runtime parameter specs, and static tensor + source specs through `MessageRuleIR` into `CompiledMessageRule`; the runtime + installs those modules/parameters generically and materializes + message-rule static tensors by declared source kind. The transition + parameter table also stopped importing the context-nudge constant and now + chooses projected-message transition input weights from the compiled + message `output_dim_role`. This removes another active-route message-rule + identity check while keeping the context-nudge CUDA primitive itself inside + the registered message strategy boundary. + - Verification for the message runtime materialization metadata cut: + `uv run ruff check ...` and `python -m compileall -q ...` passed for the + touched message-rule, runtime, temporal-parameter, and test files. 
Targeted + pytest passed for + `tests/test_fabric_backend_boundaries.py::test_message_rule_runtime_materialization_is_declared_by_message_program`, + `tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_dot_product_lowers_as_distinct_message_program`, + `tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_message_parameters_materialize_from_compiler_bindings`, + and CUDA + `tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program`. + - 2026-05-02 message strategy reducer tensor-table cut: the registered + parameter reducer launch no longer has context-nudge-specific gradient + tensor groups in its Python/C++ ABI. Message strategy parameter gradients + now enter the reducer as a generic tensor table plus + `message_strategy_grad_rows`, where each row records the selected reducer + kind, strategy-declared output role, and tensor slot. The context-nudge + reducer remains a strategy-specific native handler, but it resolves its + query-slot, input-key-bank, recurrent-key-bank, nudge-scale, and + output-weight gradients by role rows rather than by fixed launch + arguments. This keeps strategy math local to the strategy while making the + reducer program ABI extensible for additional message strategies. + - Verification for the message strategy reducer tensor-table cut: + `python -m compileall -q ...` and `uv run ruff check ...` passed for the + touched reducer Python/test files. CUDA + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_message_strategy_reducer_rows_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_registered_parameter_reducer_cuda_executes_transition_trainable_rows --tb=short` + passed after rebuilding the registered-program extension with the generic + message strategy grad-row ABI. CUDA + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_message_strategy_reducer_rows_train_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program --tb=short` + passed through the active fused training path with the same reducer ABI. + - 2026-05-02 transition binding message-role cut: removed the remaining + context-nudge lowering-name check from + `compiler/executor_bindings.py`. Transition parameter-gradient binding now + selects `message_to_cell_weight` from the compiler-owned message + `output_dim_role` carried on primitive rows, not from + `dot_product_fixed_slot_context_nudge` or `compiled_lowering_kind`. This + keeps the context-nudge native strategy specific where it belongs while + making transition input-projection binding a generic row-role decision. + - Verification for the transition binding message-role cut: + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_bindings.py tests/test_fabric_backend_boundaries.py` + passed, `python -m compileall -q + src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_bindings.py` + passed, and targeted pytest passed for + `tests/test_fabric_backend_boundaries.py::test_message_rule_runtime_materialization_is_declared_by_message_program`, + `tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_dot_product_selects_registered_strategy`, + and + `tests/test_fabric_backend_plan.py::test_temporal_executor_binding_plan_groups_compiled_bindings_by_executor_row`. 
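+ - Illustrative sketch of the role-driven transition input-weight selection
+   from the output-dimension role and message-role cuts above.
+   Non-authoritative: `transition_input_weight_source` is a hypothetical
+   helper, and the exact mapping of `d_msg` to `message_to_cell_weight` and
+   `value_dim` to `value_to_cell_weight` is an assumption consistent with
+   these entries; the point is that the decision reads the compiler-owned
+   `output_dim_role`, not a message-rule identity.
+
+   ```python
+   def transition_input_weight_source(output_dim_role):
+       """Pick the transition input-projection weight from the compiled
+       message output-dimension role rather than a lowering-kind name."""
+       if output_dim_role == "d_msg":
+           # Projected-message rules feed the fused message_to_cell path.
+           return "message_to_cell_weight"
+       if output_dim_role == "value_dim":
+           # Legacy dynamic key/value rules feed the value projection path.
+           return "value_to_cell_weight"
+       raise ValueError(f"unknown output_dim_role {output_dim_role!r}")
+
+
+   assert transition_input_weight_source("d_msg") == "message_to_cell_weight"
+   assert transition_input_weight_source("value_dim") == "value_to_cell_weight"
+   ```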
+ - 2026-05-02 strategy C++ entrypoint registry cut: moved message/readout and + reverse-transition native C++ entrypoint lists out of + `compiler/native_callables.py` executor-name maps and into the registered + executor strategy records in `compiler/executor_patterns.py`. The native + callable catalog now reads `pattern.cxx_entrypoints`, so adding a new + message/readout strategy localizes the C++ handler contract to the + strategy registry instead of expanding scheduler/catalog switch tables. + - Verification for the strategy C++ entrypoint registry cut: + `python -m compileall -q ...` and `uv run ruff check ...` passed for the + touched compiler/test files. Targeted pytest passed for + `tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned` + and + `tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured`; + the checked-in generated native-callable header still validates against + `temporal_native_callable_generated_header_text()`. + - 2026-05-02 transition primitive callable-contract registry cut: moved + transition primitive forward C++ entrypoint, forward input/parameter/output + binding schema, forward output runtime-buffer contracts, and reverse + native-callable ownership from `compiler/native_callables.py` symbol maps + onto `TransitionPrimitiveExecutorRecord` in + `transition_execution/registry.py`. The native callable catalog now reads + those compiler registry rows directly, so adding a transition primitive + callable updates the transition registry instead of expanding a second + symbol-keyed ABI table in the catalog layer. + - Verification for the transition primitive callable-contract registry cut: + `python -m compileall -q ...` and `uv run ruff check ...` passed for the + touched transition registry, native callable, and test files. Targeted + pytest passed for + `tests/test_fabric_backend_boundaries.py::test_rejected_transition_temporal_fusion_facades_were_deleted`, + `tests/test_fabric_backend_plan.py::test_transition_program_compiler_uses_cuda_nn_primitive_registry`, + and + `tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured`. + - 2026-05-02 parameter-reducer output tensor-table cut: moved registered + parameter-reducer trainable-gradient outputs out of lazy C++ `zeros_like` + allocation and into a compiler-built output tensor table. The Python + parameter-reducer program now preallocates active trainable-gradient + outputs from reducer rows and passes `parameter_output_tensors` into the + registered CUDA reducer. C++ validates that each requested role resolves to + a compiler-provided output tensor and only mutates that table; inactive + trainables remain zero-size sentinels. This closes one concrete + memory/liveness leak where reducer strategies could allocate cross-stage + parameter outputs internally. + - 2026-05-02 parameter-reducer native-callable registry cut: moved + parameter-reducer and transition-trainable-reducer native callable ids plus + C++ implementation symbols out of `compiler/native_callables.py` maps and + into `compiler/reducer_patterns.py` strategy records. The native callable + catalog now reads reducer callables from reducer strategy records, matching + the executor-strategy and transition-primitive callable ownership model. + - Verification for these reducer/memory cuts: + `uv run ruff check ...` passed for the touched compiler, wrapper, temporal, + and test files; `python -m compileall -q ...` passed for the same Python + files. 
Targeted pytest passed for + `tests/test_fabric_backend_boundaries.py::test_parameter_reducer_outputs_are_compiler_provided_tensor_table`, + `tests/test_fabric_backend_boundaries.py::test_registered_program_allocations_are_compiler_classified`, + `tests/test_fabric_backend_boundaries.py::test_parameter_reducer_native_callables_are_registry_owned`, + `tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured`, + and the CUDA reducer test + `tests/test_fabric_backend_plan.py::test_registered_parameter_reducer_cuda_executes_transition_trainable_rows`. + - 2026-05-02 generic transition primitive onboarding cut: promoted the + declared `tanh` transition op from a `reference_only` registry row into a + compiler-owned CUDA forward/reverse primitive. The primitive now has a + temporal opcode, transition primitive executor record, forward and backward + native callable hashes, logical binding schema, runtime-buffer output + contract, registered program CUDA kernels, direct CUDA wrappers, + primitive-DAG eager forward execution, forward strategy row, reverse + strategy row, and a fused reverse native handler selected by + `primitive_rows`, `reverse_executor_rows`, `reverse_executor_binding_rows`, + `native_strategy_rows`, and `transition_primitive_callable_rows`. The fused + program binding validator now accepts parameterless primitive rows instead + of requiring every primitive span to contain a parameter binding; + parameterized primitives still validate their callable binding schema. This + is the first non-composite, parameterless transition op that can execute + through both program-level transition primitive dispatch directions. + - Verification for the generic `tanh` transition cut: + `uv run ruff check ...` and `python -m compileall -q ...` passed for the + touched transition registry/lowering, primitive registry, memory-plan, + wrapper, and test files. Targeted pytest passed for + `tests/test_fabric_backend_plan.py::test_transition_program_compiler_uses_cuda_nn_primitive_registry`, + `tests/test_fabric_backend_plan.py::test_transition_executor_selection_uses_registered_structural_records`, + `tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured`, + `tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned`, + CUDA + `tests/test_fabric_backend_plan.py::test_program_transition_tanh_cuda_uses_compiler_rows_without_parameter_bindings`, + and CUDA + `tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_registered_tanh_callable`. + - 2026-05-02 generic transition primitive reverse-handler cut: added + `native.reverse.transition_tanh.v1` and + `run_registered_tanh_reverse_transition_handler` so a parameterless unary + adjoint is selected through the same reverse native strategy table as the + gated and diagonal composite handlers. The handler resolves `output`, + `grad_output`, and `grad_input` through compiler binding schema and calls + the row-owned `program_transition_tanh_backward` kernel; it does not index + positional C++ argument bundles or special-case a cell family. + - Verification for the `tanh` reverse-handler cut: + `uv run ruff check ...` and `python -m compileall -q ...` passed for the + touched transition registry, executor-pattern registry, C++ headers, and + tests. The native callable generated-header validator passed. 
Targeted + pytest passed for + `tests/test_fabric_backend_plan.py::test_transition_executor_selection_uses_registered_structural_records`, + `tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured`, + `tests/test_fabric_backend_boundaries.py::test_reverse_transition_native_handlers_use_logical_binding_schema`, + `tests/test_fabric_backend_boundaries.py::test_rejected_transition_temporal_fusion_facades_were_deleted`, + CUDA + `tests/test_fabric_backend_plan.py::test_program_transition_tanh_cuda_uses_compiler_rows_without_parameter_bindings`, + CUDA + `tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_registered_tanh_callable`, + and CUDA + `tests/test_fabric_backend_plan.py::test_fused_reverse_transition_program_cuda_dispatches_registered_tanh_callable`. + - 2026-05-02 primitive-DAG forward strategy generality cut: transition + forward executor selection now first tries a registered bucket-level fused + strategy and then falls back to registered per-primitive transition + strategies. `linear`, `matmul`, `norm_or_identity`, and `tanh` can now be + selected as local forward primitive rows without adding another + bucket-sized fixed pattern. Row matching supports explicit wildcard + parameter lists for primitive-local strategies, while the existing + gated/diagonal fused strategies keep exact row signatures and therefore + still win when their structural bucket programs match. + - 2026-05-02 primitive-output contract cut: transition forward runtime + buffer planning now resolves primitive output allocation by local output + position when a compiled DAG tensor name differs from the primitive-local + role name. A declared `tanh(transition_input) -> public_y` can therefore + use the same `program_transition_tanh_forward` output contract as + `tanh(input) -> output`; the compiler maps program tensor names to + primitive binding positions instead of requiring op authors to use internal + primitive role names. + - 2026-05-02 reverse primitive-DAG fail-closed cut: transition reverse table + construction now rejects primitive rows that are not covered by a + compiler-owned reverse primitive or by an explicitly registered composite + reverse executor whose declared reverse inputs/parameters cover the + surrounding forward rows. This prevents a generic primitive-DAG program + from silently training only the subset of rows that happen to have a + reverse handler today. + - Verification for the primitive-DAG strategy/coverage cuts: + targeted pytest passed for + `tests/test_fabric_backend_plan.py::test_temporal_forward_executor_rows_fall_back_to_registered_primitive_dag_strategies`, + `tests/test_fabric_backend_plan.py::test_temporal_reverse_executor_rows_cover_parameterless_primitive_dag_adjoint`, + `tests/test_fabric_backend_plan.py::test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names`, + and + `tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured`. + - 2026-05-02 parameterized primitive-DAG training cut: standalone + transition `linear` now has a registered reverse native callable, + `transition_linear_primitive_backward` reverse executor strategy, logical + reverse binding schema, and C++ native handler selected by + `reverse_executor_rows`, `reverse_executor_binding_rows`, and the native + callable table. 
Reverse tensor-gradient bindings are compiler-owned roles: + `tanh` consumes the public-gradient seed, produces the intermediate + `grad_transition_input`, and `linear` consumes that same binding before + producing `grad_aggregated_message`, `grad_value_to_state_weight`, and + `grad_recurrent_bias`. Parameter-gradient reducer choice is derived from + actual primitive parameter bindings (`value_to_state_weight` -> + input-projection-weight reducer, `recurrent_bias` -> + input-projection-bias reducer), not fixed `weight`/`bias` slot names. + Runtime transition program tables now accept multiple transition executor + rows per bucket for primitive DAGs, so a `linear -> tanh` bucket is a real + grouped compiler program instead of a single composite compatibility row. + - Verification for the parameterized primitive-DAG training cut: + `uv run python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler src/cortical/fabric/backend/cuda/sequence_surface/temporal src/cortical/fabric/backend/cuda/transition_execution` + passed; `git diff --check` passed; `uv run ruff check` passed for the + touched compiler, transition runtime, transition registry, and backend-plan + test files. Targeted pytest passed for + `tests/test_fabric_backend_plan.py::test_temporal_forward_executor_rows_fall_back_to_registered_primitive_dag_strategies` + and + `tests/test_fabric_backend_plan.py::test_temporal_reverse_executor_rows_cover_parameterless_primitive_dag_adjoint`; + full backend-plan pytest passed with 106 tests; the boundary/plan/import + guardrail set passed with 147 tests. CUDA + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_linear_reverse_native_handler_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_reverse_transition_program_cuda_dispatches_registered_linear_callable --tb=short` + passed after rebuilding the registered-program extension and dispatching + `native.reverse.transition_linear_primitive.v1` through the fused reverse + transition program; the paired linear/tanh fused reverse native-handler + CUDA test set also passed. + - 2026-05-02 norm primitive-DAG training cut: standalone + `norm_or_identity` now follows the same compiler-owned reverse path as + `linear` and `tanh`. The transition primitive registry declares + `native.reverse.transition_norm_or_identity_primitive.v1`, logical reverse + inputs/parameters/outputs, and the `grad_weight` parameter-gradient + contract; the reverse strategy registry adds + `transition_norm_or_identity_primitive_backward`; the checked-in native + callable catalog includes the C++ handler; and + `run_registered_norm_or_identity_reverse_transition_handler` resolves + `input`, `grad_output`, `weight`, optional `eps`, `grad_input`, and + `grad_weight` through binding rows before calling the row-owned + norm backward primitive. This removes the remaining obvious parameterized + primitive reverse-handler gap for the currently registered transition + primitive DAG set. + - Verification for the norm primitive-DAG training cut: + `uv run ruff check ...`, `uv run python -m compileall -q ...`, and + `git diff --check` passed. Targeted strategy/registry pytest passed for + `tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured` + and + `tests/test_fabric_backend_plan.py::test_transition_program_compiler_uses_cuda_nn_primitive_registry`. 
+ CUDA + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_norm_reverse_native_handler_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_reverse_transition_program_cuda_dispatches_registered_norm_callable --tb=short` + passed after rebuilding the registered-program extension; the paired + tanh/linear/norm fused reverse native-handler CUDA set passed; and the + boundary/plan/import guardrail set passed with 148 tests. + - 2026-05-02 end-to-end primitive-DAG training cut: a real compiled + transition program `linear(aggregated_message, value_to_state_weight, + recurrent_bias) -> norm_or_identity(outnorm_weight, outnorm_eps) -> tanh` + now trains through the registered fused temporal forward and backward + programs. The fused forward program no longer asks transition rows for the + old composite logical binding names through the native callable schema. + Instead, `transition_aggregated_message_input` and + `transition_public_state_output` are compiler access rows that can point at + any primitive input/output binding in the DAG. Intermediate primitive rows + can have no temporal access row at all. Reverse seed allocation now falls + back to compiler-owned recurrent-hidden reverse artifacts when a stateless + primitive DAG has no state tensor to use as a seed template, and the + transition parameter reducer accepts one bucket-level reducer request whose + binding coverage spans multiple reverse executor rows. The CUDA regression + test asserts boundary, input-projection, bias, and norm parameter gradients + all materialize through the registered fused program path. + - Verification for the end-to-end primitive-DAG training cut: + `uv run pytest -q tests/test_fabric_backend_plan.py::test_generic_transition_primitive_dag_trains_through_registered_fused_temporal_program --tb=short` + passed. The focused compiler/strategy checks + `test_temporal_forward_program_access_and_state_carry_rows_are_compiler_owned`, + `test_temporal_forward_executor_rows_fall_back_to_registered_primitive_dag_strategies`, + and + `test_temporal_reverse_executor_rows_cover_parameterless_primitive_dag_adjoint` + passed. `uv run ruff check` passed for the touched Python files and + `uv run python -m compileall -q` passed for the compiler/temporal/test + Python targets. The boundary/plan/import guardrail set passed with + 149 tests after updating the static boundaries to require transition + program I/O through compiler access rows instead of native composite + binding names. + - 2026-05-02 message-strategy extensibility cut: added a new + `DotProduct(math="fixed_slot_context_gate")` semantic variant as a + compiler stress slice, not as a temporal-engine branch. It registers a + message backend spec, lowers to distinct primitive rows with + `message_query_context_gate`, selects separate forward/reverse message + strategy IDs, and binds the neutral native scalar access + `message_query_context_scalar` through compiler access rows so the native + callable can consume the same access contract without the scheduler knowing + the semantic parameter name. This removes one piece of the fixed-strategy + assumption: adding a new message math variant is now local to + declaration/lowering/strategy/binding metadata plus optional native + callable support. + - 2026-05-02 fixed-slot context message reducer cleanup: replaced the + nudge-specific message parameter-reducer contract with the generic + `fixed_slot_context_message` reducer and + `grad_query_context_scalar` output role. 
The nudge and gate message + declarations still keep separate semantic parameter names and strategy IDs, + but the shared native reducer, generated callable catalog entry, temporal + parameter reducer rows, and C++ reducer handler now describe the compiler + implementation unit rather than one semantic spelling. The registered + program helper/kernel names and diagnostics were cleaned up from + context-nudge wording to fixed-slot-context wording where the code is shared + by both message variants. + - 2026-05-02 fixed-slot context gate training cut: added CUDA training + coverage for the `fixed_slot_context_gate` DotProduct variant through the + registered fused temporal forward/backward programs. The test asserts + boundary, readout/message, and `message_query_context_gate` gradients + materialize, verifies the compiler-owned extra message parameter-gradient + slots are returned, and checks the backward owner is + `registered_temporal_fused_backward_program_cuda`. + - 2026-05-02 fixed-slot context gate equivalence cut: added a CUDA + nudge-vs-gate equivalence test with identical scalar bindings. It compares + forward outputs, boundary gradients, all common parameter gradients, and + the mapped scalar gradient + `message_query_context_gate == message_query_nudge_scale` while requiring + the gate path to report the registered fused backward owner. This is not + the final PR-10 math-change stress test, but it proves the new semantic + spelling is carried by compiler access rows and the shared reducer rather + than by a hidden nudge parameter name. + - 2026-05-02 DotProduct C++ IR stress-gate cut: tightened the + `fabric.cuda.nn` message-rule IR boundary so the CUDA-side semantic + classifier is driven by generated registry catalog rows for the nudge and + gate message rules rather than by hand-coded shape probes. Fixed-slot + context rules now require explicit parameter-value nodes for the scalar + context gate/nudge and sender context-key table instead of accepting any + loose receiver-public/sender-slot/concat/normalize shape as nudge. This + keeps the PR-10-style stress gate compiler-owned at declaration/IR + classification: changing the semantic scalar binding changes the lowered + message rule or fails closed before strategy selection. + - Verification for the fixed-slot context reducer/gate training cut: + `uv run ruff check ...`, `uv run python -m compileall -q ...`, + `git diff --check`, and the native-callable generated-header validator + passed. Targeted compiler checks passed for + `test_temporal_executor_fusion_patterns_are_structured`, + `test_fixed_slot_context_nudge_dot_product_selects_registered_strategy`, + `test_fixed_slot_context_nudge_message_parameters_materialize_from_compiler_bindings`, + `test_fixed_slot_context_gate_dot_product_selects_registered_strategy_with_access_remap`, + and + `test_fixed_slot_context_gate_message_parameters_materialize_from_compiler_bindings`. + CUDA checks passed for + `test_registered_parameter_reducer_cuda_executes_transition_trainable_rows`, + `test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program`, + `test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program`, + and + `test_fixed_slot_context_gate_matches_nudge_when_scalar_binding_is_equal`. 
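+ - Illustrative sketch for the nudge-vs-gate equivalence check described in
+   the fixed-slot context gate equivalence cut above. Non-authoritative:
+   `compare_named_gradients` is a hypothetical helper and not the test's
+   actual code; only the scalar name mapping
+   `message_query_context_gate -> message_query_nudge_scale` comes from the
+   entry itself.
+
+   ```python
+   import torch
+
+
+   def compare_named_gradients(lhs, rhs, name_map, atol=1e-6, rtol=1e-5):
+       """Compare two runs' parameter gradients by name; name_map handles
+       names that differ only in semantic spelling."""
+       mismatches = []
+       for name, grad in lhs.items():
+           other = rhs.get(name_map.get(name, name))
+           if other is None or not torch.allclose(grad, other,
+                                                  atol=atol, rtol=rtol):
+               mismatches.append(name)
+       return mismatches
+
+
+   gate = {"message_query_context_gate": torch.tensor([0.25])}
+   nudge = {"message_query_nudge_scale": torch.tensor([0.25])}
+   assert compare_named_gradients(
+       gate, nudge,
+       {"message_query_context_gate": "message_query_nudge_scale"}) == []
+   ```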
+ - Verification for the DotProduct C++ IR stress-gate cut: + `uv run pytest -q tests/test_fabric_execution_imports.py::test_sequence_surface_uses_registered_temporal_compiler_path tests/test_fabric_backend_boundaries.py::test_cuda_message_rule_ir_distinguishes_context_gate_from_nudge --tb=short`. + A standalone `g++ -std=c++17 -I src -x c++ -fsyntax-only` snippet also + built fixed-slot context nudge and gate rules through `MessageRuleBuilder` + and asserted they classify to different lowering kinds. + - 2026-05-02 reverse zero-output-gradient window cut: closed a registered + reverse coverage gap for final-state-only losses. The backward scheduler can + request a reverse window because final-state carry gradients are present + while the user output tensor is unused and autograd supplies no output + gradient. The artifact-list and tensor-store output-gradient window builders + now materialize compiler-owned zero windows from reverse artifact roles + instead of passing `None` into the registered reverse program and tripping + the late `missing_grad_output_window` gate. This keeps the fused reverse ABI + tensor-table based and does not add an output backward fallback. + - Verification for the reverse zero-output-gradient window cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/output_backward.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py tests/test_fabric_runtime.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/output_backward.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py tests/test_fabric_runtime.py`, + and `git diff --check` passed. CUDA parity is covered by + `tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window`, + `tests/test_fabric_runtime.py::test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store`, + `tests/test_fabric_runtime.py::test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients`, + and + `tests/test_fabric_runtime.py::test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific`. + - 2026-05-02 active-path identity cleanup: removed the compatibility-era + `registered_forward_executor_bindings` and + `registered_temporal_forward_executor_program` names from active forward + planning, strategy metadata, runtime records, and tests. Forward ownership + now reports `registered_fused_forward_program_cuda`, and executable + forward entrypoints report + `registered_temporal_fused_forward_program_cuda`. The registered program + plan now points at the fused CUDA forward/backward launch contract instead + of describing a separate executor-program phase. + - Verification for the active-path identity cleanup: + `python -m compileall -q ...`, `uv run ruff check ...`, + `git diff --check`, `uv run pytest -q tests/test_fabric_backend_plan.py --tb=short`, + and `uv run pytest -q tests/test_fabric_execution_imports.py --tb=short` + passed. + - 2026-05-02 forward program runtime-fact row cut: moved fused forward scan + topology/schedule/scalar ABI facts out of loose C++/Python arguments and + into a compiler-owned `forward_program_runtime_rows` tensor plus runtime + fact tensor table. 
The fused forward program now resolves recurrent sender + topology, output sender topology, distance/delay tables, inner-step count, + terminal-output policy, distance scale, head/value dimensions, and delay + enablement through required role opcodes before launching the program body. + This is still the same fused CUDA program, but the scan ABI is now driven by + rows/tensor tables rather than positional fixed-slot facts. + - Verification for the forward program runtime-fact row cut: + `python -m compileall -q ...`, `uv run ruff check ...`, + `git diff --check`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_forward_fused_program_runtime_facts_are_compiler_owned_rows tests/test_fabric_backend_plan.py::test_temporal_table_runtime_metadata_records_executor_blockers --tb=short`, + `uv run pytest -q tests/test_fabric_backend_plan.py --tb=short`, and + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window --tb=short` + passed. + - 2026-05-02 reverse program runtime-fact row cut: mirrored the forward + runtime-fact row design in the fused reverse program. Graph order, + backend-to-graph order, output/recurrent sender topology, neighbor + metadata, distance/delay tables, message-step indices, input/recurrent + counts, sparse-message policy, delay policy, group/head/value dimensions, + distance scale, and boundary-gradient policy now flow into + `registered_temporal_fused_backward_program_cuda` through + `reverse_program_runtime_rows` and a runtime fact tensor table. The reverse + Python/C++ ABI no longer passes those facts as loose scan arguments. The + fused launch contract now names both `forward_program_runtime_rows` and + `reverse_program_runtime_rows` as required compiler-owned tables. + - Verification for the reverse program runtime-fact row cut: + `python -m compileall -q ...`, `uv run ruff check ...`, + `git diff --check`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_forward_fused_program_runtime_facts_are_compiler_owned_rows tests/test_fabric_backend_plan.py::test_reverse_fused_program_runtime_facts_are_compiler_owned_rows --tb=short`, + `uv run pytest -q tests/test_fabric_backend_plan.py --tb=short`, and + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window --tb=short` + passed. + - 2026-05-02 residual reverse-glue ownership cut: removed the remaining + active-path `thin_reverse_path:explicit_executor` and + `explicit_*_thin_reverse` ownership labels from the planner and CUDA + sequence runtime metadata for public projection, readout projection, and + state epilogue backward. These surfaces now report compiler-owned + registered reverse handlers such as + `registered_sender_kv_projection_backward_executor` and + `projection_reduction_boundary_backward`; residual glue demotions are empty + instead of accepted as closure evidence. Boundary tests now reject the old + explicit thin-reverse strings in the active planner/runtime sources. 
+ - Verification for the residual reverse-glue ownership cut: + `python -m compileall -q src/cortical/fabric/backend/planner.py src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/planner.py src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py`, + `git diff --check -- src/cortical/fabric/backend/planner.py src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py ai_docs/REDO2_FIXMASS.md`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short`, + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_backward_physical_plan_is_typed_and_semantic --tb=short`, + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_training_surface_records_backward_phase_plans --tb=short`, + `uv run pytest -q tests/test_fabric_execution_imports.py --tb=short`, + and `uv run pytest -q tests/test_fabric_backend_plan.py --tb=short` + passed. A broader mixed-pop CUDA parity probe, + `tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_t_gt1_training_uses_flat_bucket_route`, + still failed on existing parameter-gradient tolerance and remains part of + the pre-throughput parity matrix rather than this metadata ownership cut. + - 2026-05-02 executable memory-policy row cut: the memory liveness table now + carries explicit compiler-owned runtime-policy rows for local seed rows, + metadata rows, primitive outputs, tape policy, alias policy, recompute + window policy, materialization policy, and CUDA graph constraints. These + rows use a registered `runtime_policy` surface and `compiler_memory_policy` + owner, travel through the same `memory_liveness_rows` ABI as primitive and + tensor-role entries, and are intentionally not allocated as ordinary + runtime buffers. This makes the remaining memory/tape/checkpoint work + concrete: each policy must either become a planned runtime buffer, + primitive-owned output/workspace, metadata/policy row, or an illegal + scheduler allocation. + - Verification for the executable memory-policy row cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/primitive_registry.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/primitive_registry.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `git diff --check -- src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/primitive_registry.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes --tb=short`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows --tb=short`, + and `uv run pytest -q tests/test_fabric_backend_plan.py --tb=short` + passed. 
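+
+    A minimal sketch of what these policy rows look like as data and how the
+    fail-closed check behaves, assuming illustrative field and effect names
+    (the real schema and validator live in the compiler memory-plan code and
+    the registered program C++, not here):
+
+    ```python
+    # Hedged sketch: runtime-policy rows travel in the same liveness table as
+    # buffer rows but are compiler-owned and never allocated at runtime.
+    from dataclasses import dataclass
+
+    REQUIRED_POLICY_EFFECTS = {
+        "local_seed_policy", "metadata_policy", "primitive_output_policy",
+        "tape_policy", "alias_policy", "recompute_window_policy",
+        "materialization_policy", "cuda_graph_constraint",
+    }
+
+    @dataclass(frozen=True)
+    class MemoryLivenessRow:
+        surface: str     # "runtime_policy" for policy rows
+        workspace: str   # "policy_table" for policy rows
+        owner: str       # "compiler_memory_policy" for policy rows
+        effect: str      # one of REQUIRED_POLICY_EFFECTS for policy rows
+        allocates: bool  # policy rows are never runtime buffers
+
+    def check_policy_rows(rows: list[MemoryLivenessRow]) -> None:
+        policy = [r for r in rows if r.surface == "runtime_policy"]
+        if not policy:
+            return  # no policy set declared for this program
+        missing = REQUIRED_POLICY_EFFECTS - {r.effect for r in policy}
+        if missing:
+            raise ValueError(f"incomplete runtime policy set: {sorted(missing)}")
+        for row in policy:
+            if row.owner != "compiler_memory_policy" or row.allocates:
+                raise ValueError("policy rows must be compiler-owned and non-allocating")
+    ```
+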
+ - 2026-05-02 transition matmul reverse primitive closure: standalone + transition `matmul` now has a registered reverse executor strategy, + reverse native callable ID, compiler-owned reverse input/parameter/output + binding schema, parameter-gradient output contract, generated native + transition reverse catalog entry, and C++ handler + `run_registered_matmul_reverse_transition_handler`. The fused reverse + transition program can now dispatch matmul backward from + `primitive_rows`, `reverse_executor_rows`, logical binding rows, native + callable schema rows, and `memory_liveness_rows` without falling back to a + composite gated reverse handler. + - Verification for the transition matmul reverse primitive closure: + `python -m compileall -q tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py src/cortical/fabric/backend/cuda/transition_execution/registry.py`, + `uv run ruff check tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py src/cortical/fabric/backend/cuda/transition_execution/registry.py`, + `git diff --check -- src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py src/cortical/fabric/backend/cuda/transition_execution/registry.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_primitive_forward_ops.cuh src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_reverse_handlers.cuh src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_native_callables.cuh tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured --tb=short`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_reverse_transition_native_handlers_use_logical_binding_schema --tb=short`, + and `uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_reverse_transition_program_cuda_dispatches_registered_matmul_callable --tb=short` + passed. + - 2026-05-02 runtime policy row validation cut: the registered fused program + C++ validator now treats `runtime_policy` / `policy_table` rows as a typed + memory-liveness contract instead of generic positive integers. Runtime + policy rows must use the `runtime_policy` surface, `policy_table` + workspace, `compiler_memory_policy` owner, no primitive row, and one of + the allowed policy effects. When a policy set is present, validation + requires local seed, metadata, primitive-output, tape, alias, + recompute-window, materialization, and CUDA graph policy rows. Policy rows + are also explicitly non-allocating, so runtime buffer rows cannot claim + them. 
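+
+    For the transition matmul reverse primitive closure recorded above, the
+    adjoint the registered reverse handler has to reproduce is the standard
+    matmul backward. A hedged reference check in plain PyTorch autograd
+    (illustrative shapes; this is the reference math, not the CUDA handler):
+
+    ```python
+    # y = x @ w  =>  grad_x = grad_y @ w.T  and  grad_w = x.T @ grad_y
+    import torch
+
+    x = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
+    w = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
+    y = x @ w
+    grad_y = torch.randn_like(y)
+    y.backward(grad_y)
+
+    assert torch.allclose(x.grad, grad_y @ w.T)
+    assert torch.allclose(w.grad, x.T @ grad_y)
+    ```
+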
+ - Verification for the runtime policy row validation cut: + `python -m compileall -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py`, + `uv run ruff check tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py`, + `git diff --check -- src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/constants_and_checks.cuh src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/memory_runtime_buffers.cuh tests/test_fabric_backend_boundaries.py`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows --tb=short`, + and `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes tests/test_fabric_backend_plan.py::test_fused_reverse_transition_program_cuda_dispatches_registered_matmul_callable --tb=short` + passed. + - 2026-05-02 transition primitive contract verifier cut: transition + primitives can no longer be treated as program-layer callable merely + because `program_layer_status` says so. The transition registry now has a + contract blocker check requiring callable primitives to declare forward + symbol, C++ forward entrypoint, forward input/output binding schema, + forward output contracts, backward symbol, reverse native callable, + reverse input bindings, and reverse output bindings. Missing contracts now + produce typed blocker codes such as + `INCOMPLETE_FUSED_TRANSITION_PRIMITIVE_REVERSE_CONTRACT` instead of + silently passing through primitive-DAG selection. + - Verification for the transition primitive contract verifier cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/transition_execution/registry.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/transition_execution/registry.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + and `uv run pytest -q tests/test_fabric_backend_plan.py::test_transition_program_compiler_uses_cuda_nn_primitive_registry tests/test_fabric_backend_boundaries.py::test_forward_transition_access_uses_compiler_program_access_rows --tb=short` + passed. + - 2026-05-02 stateful primitive-DAG binding cut: forward transition state + carry rows are now inferred from compiler tensor bindings across the whole + transition bucket before strategy-declared composite carry rules are used + for validation/coverage. A primitive-DAG transition that declares an + arbitrary private state such as `mem -> next_mem` now produces carry rows + without adding a `tanh_transition` state-carry special case. Reverse + generic primitive bindings now preserve `grad_next_*` for state-output + seeds, and the registered temporal executor program emits dynamic + transition reverse seed-role rows from the reverse binding plan instead of + limiting the seed catalog to built-in `y/c/n/m/hc1/hc2` roles. + - 2026-05-02 stateful primitive-DAG training cut: runtime state + initialization, temporal autograd state flattening, final-state + materialization, and population-gradient conversion now use the compiled + transition private-state schema instead of the built-in cell module schema + when a program declares custom private state. Parameterless transition + programs now produce empty compiler parameter tables/reducer outputs + instead of failing as if every transition row had trainable parameters. 
+ The fused forward CUDA program now accepts compiler-derived state-carry + rows for primitive-DAG strategies that declare no fixed carry contract, and + transition state-before reverse artifacts are emitted once per bucket per + step rather than once per transition span. This closes the active-path + proof that `mem -> next_mem` can train through registered fused forward and + registered reverse tensor-store execution with `K>1`, reset-present + windows, materialized final state, boundary gradients, and initial + state-carry gradients. + - Verification for the stateful primitive-DAG binding cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/forward_program.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_bindings.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_plan.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/forward_program.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_bindings.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_plan.py`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_stateful_transition_primitive_dag_derives_carry_rows_from_compiler_bindings tests/test_fabric_backend_plan.py::test_stateful_transition_primitive_dag_registers_dynamic_reverse_seed_roles --tb=short`, + and + `uv run pytest -q tests/test_fabric_backend_plan.py::test_transition_reverse_seed_roles_are_compiler_owned_rows tests/test_fabric_backend_plan.py::test_temporal_forward_program_access_and_state_carry_rows_are_compiler_owned tests/test_fabric_backend_plan.py::test_temporal_forward_executor_rows_fall_back_to_registered_primitive_dag_strategies tests/test_fabric_backend_plan.py::test_temporal_reverse_executor_rows_cover_parameterless_primitive_dag_adjoint --tb=short` + passed. 
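+
+    Before the follow-up verification below, a hedged reference-semantics
+    sketch of the training contract this cut closes, written in plain PyTorch
+    rather than the Fabric API or the registered program: a private
+    `mem -> next_mem` carry crosses steps, the loss is external, and gradients
+    must reach both the parameters and the initial carry.
+
+    ```python
+    # Reference semantics only: a custom private state carried across steps,
+    # trained with an external loss and loss.backward().
+    import torch
+
+    torch.manual_seed(0)
+    T, B, H = 4, 2, 8
+    w = torch.nn.Parameter(torch.randn(H, H) * 0.1)
+    x = torch.randn(T, B, H)
+    mem = torch.zeros(B, H, requires_grad=True)  # initial carry
+
+    carry, outputs = mem, []
+    for t in range(T):
+        carry = torch.tanh(x[t] + carry @ w)  # next_mem from mem and step input
+        outputs.append(carry)
+
+    loss = torch.stack(outputs).pow(2).mean()
+    loss.backward()
+    assert w.grad is not None and mem.grad is not None
+    ```
+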
+ Follow-up verification for the stateful primitive-DAG training cut: + `uv run python -m compileall -q src/cortical/fabric/runtime/core.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/physical_autograd.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_parameters.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_buckets.py tests/test_fabric_backend_plan.py`, + `uv run ruff check src/cortical/fabric/runtime/core.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/physical_autograd.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_parameters.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_buckets.py tests/test_fabric_backend_plan.py`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_generic_transition_primitive_dag_trains_through_registered_fused_temporal_program tests/test_fabric_backend_plan.py::test_stateful_transition_primitive_dag_trains_with_resets_through_registered_temporal_program tests/test_fabric_backend_plan.py::test_stateful_transition_primitive_dag_derives_carry_rows_from_compiler_bindings tests/test_fabric_backend_plan.py::test_stateful_transition_primitive_dag_registers_dynamic_reverse_seed_roles --tb=short`, + `uv run pytest -q tests/test_fabric_runtime.py::test_backend_originated_population_state_roundtrips_without_copy tests/test_fabric_runtime.py::test_ensure_state_preserves_backend_originated_population_views --tb=short`, + `uv run pytest -q tests/test_fabric_backend_plan.py --tb=short` (`120` + passed), + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_execution_imports.py --tb=short` + (`41` passed), + and `git diff --check -- src/cortical/fabric/runtime/core.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/physical_autograd.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_parameters.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_buckets.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/program_tensor_access.cuh src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/forward_program.cuh tests/test_fabric_backend_plan.py ai_docs/REDO2_FIXMASS.md` + passed. + - Remaining compiler work before throughput after the native callable, + transition primitive callable-row, reverse backward-callable row, and + parameter-reducer native-callable row cuts: + - 2026-05-02 hard-blocker cut after the direct transition API review: + the direct per-population CUDA transition bucket API is no longer a + runtime entrypoint. `flat_buckets.py` was trimmed to the shared + transition-gradient accumulator types and population/backend grad-state + conversions only; it no longer exposes transition bucket forward, + backward, active-window, or cached eager runners. `runtime_dispatch.py` + no longer has `_run_*transition_bucket*` helpers or direct + `_lower_backend_population_transition_*` dispatch methods, and + `runtime/core.py` no longer calls the direct transition bucket helpers + from CUDA stream-step or boundary-multistep paths. 
The training sequence + surface also stopped wrapping the registered temporal program in + `_CapturedTrainingSequenceSurface`; gradients now flow through the + compiler-owned temporal physical autograd path returned by + `execute_temporal_bucket_sequence(...)`. The old + `_PhysicalBackwardSequenceExecutor` and the runtime surface + step-object artifact helper were deleted, so the active surface no longer + recomputes transition/message/readout artifacts through direct runtime + calls. Boundary guards now reject reintroducing those methods/symbols. + Verification: + `uv run python -m compileall -q src/cortical/fabric/backend/runtime_dispatch.py src/cortical/fabric/backend/cuda/sequence_surface/runtime/surface.py src/cortical/fabric/backend/cuda/sequence_surface/runtime/support.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/runtime_dispatch.py src/cortical/fabric/backend/cuda/sequence_surface/runtime/surface.py src/cortical/fabric/backend/cuda/sequence_surface/runtime/support.py tests/test_fabric_backend_boundaries.py`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_shared_temporal_forward_scan_has_no_python_step_loop_fallback --tb=short`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_primitive_dag_transition_direct_lowering_is_registered_program_only tests/test_fabric_backend_plan.py::test_stateful_transition_primitive_dag_trains_with_resets_through_registered_temporal_program --tb=short`, + and targeted `git diff --check` passed. Follow-up cleanup on + 2026-05-02 deleted the historical `runtime/backward.py` mixin as a live + CUDA runtime sibling. The two still-used policy helpers were moved into + their owning surfaces: transition core-state counting now lives on + `Runtime`, and receiver-window static tensor slicing now lives on the + sequence surface mixin. `CudaSequenceSurfaceMixin` no longer inherits a + backward mixin, boundary guards assert the old file/import/class stay + absent, and active execution metadata now reports the registered + message/transition/readout program owners instead of backend-order + transition-bucket executor names. 
Verification: + `uv run python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/runtime/surface.py src/cortical/fabric/runtime/core.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/runtime/surface.py src/cortical/fabric/runtime/core.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_runtime.py::test_fabric_mixed_population_sequence_route_is_planner_owned tests/test_fabric_runtime.py::test_fabric_supported_cuda_route_has_no_legacy_cell_surface --tb=short`, + and targeted `git diff --check` passed; + - 2026-05-02 transition native locality cut: the broad + `transition_primitive_ops.cuh` file was deleted and split into + `transition_primitive_forward_ops.cuh` for direct primitive + forward/backward bodies and `transition_reverse_handlers.cuh` for + registered reverse span handlers. The fused temporal program include now + names those two semantic units directly, and allocation audit rules track + them separately. This is source-ownership cleanup only: dispatch still + flows through compiler rows, native callable catalogs, executor rows, and + registered program bindings; + - 2026-05-02 direct transition primitive export cut: the + `registered_program_transition_*_cuda` Python entrypoints and + `program_transition_*` pybind exports were deleted. Transition primitive + CUDA bodies remain only as internal registered-program strategy bodies + reached through compiler rows, native callable catalogs, executor rows, + tensor binding rows, and fused transition/full temporal program launches. + Boundary guards now assert those direct Python/pybind launch surfaces stay + absent, and primitive parity coverage moved to registered fused program + tests instead of importing primitive probes; + - 2026-05-02 registered temporal route-name cleanup: planner/runtime + metadata no longer describes the selected sequence implementation as + `flat_transition_buckets`. `SequenceSurfaceRoute` now reports + `implementation_executor="registered_temporal_program"` and + `surface_key="registered_temporal_sequence_surface"`, runtime capability + checks use `uses_registered_temporal_program`, and the planner rejection + reason is `requires_registered_temporal_program`. This is a naming + cleanup over the already-active compiler path, but it removes misleading + labels that implied the registered program was still a flat transition + bucket route. 
Verification: + `uv run python -m compileall -q src/cortical/fabric/backend/temporal_plan.py src/cortical/fabric/backend/planner.py src/cortical/fabric/backend/runtime_dispatch.py src/cortical/fabric/runtime/core.py src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py tests/test_fabric_backend_plan.py tests/test_fabric_runtime.py`, + `uv run ruff check src/cortical/fabric/backend/temporal_plan.py src/cortical/fabric/backend/planner.py src/cortical/fabric/backend/runtime_dispatch.py src/cortical/fabric/runtime/core.py src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py tests/test_fabric_backend_plan.py tests/test_fabric_runtime.py`, + and + `uv run pytest -q tests/test_fabric_backend_plan.py::test_fabric_temporal_execution_plan_records_output_request_schedule tests/test_fabric_runtime.py::test_fabric_supported_cuda_route_uses_registered_temporal_program tests/test_fabric_runtime.py::test_fabric_mixed_population_sequence_route_is_planner_owned tests/test_fabric_runtime.py::test_fabric_supported_cuda_route_has_no_legacy_cell_surface --tb=short` + passed; + - 2026-05-02 direct CUDA runtime-op bridge deletion DONE: the active + runtime dispatcher no longer imports + `cortical.fabric.backend.cuda.runtime_ops` or reaches direct + message/projection CUDA wrappers when a request misses the registered + temporal program. The compiler-owned registered temporal program remains + the CUDA execution path; non-registered helper execution uses the + PyTorch/reference backend operations. The now-source-inactive + `src/cortical/fabric/backend/cuda/runtime_ops.py` bridge was deleted, + and import guardrails assert it stays absent. A later cut deleted the + standalone direct message/projection CUDA probe modules and tests. + Verification: + `uv run python -m compileall -q src/cortical/fabric/backend/runtime_dispatch.py src/cortical/fabric/runtime/core.py tests/test_fabric_execution_imports.py tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/runtime_dispatch.py src/cortical/fabric/runtime/core.py tests/test_fabric_execution_imports.py tests/test_fabric_runtime.py tests/test_fabric_backend_boundaries.py`, + `uv run pytest -q tests/test_fabric_execution_imports.py::test_sequence_surface_uses_registered_temporal_compiler_path tests/test_fabric_runtime.py::test_fabric_supported_cuda_route_uses_registered_temporal_program tests/test_fabric_runtime.py::test_fabric_mixed_population_sequence_route_is_planner_owned tests/test_fabric_runtime.py::test_fabric_supported_cuda_route_has_no_legacy_cell_surface --tb=short`, + and `uv run pytest -q tests/test_fabric_execution_imports.py --tb=short` + passed; + - 2026-05-02 standalone grouped-projection CUDA bridge deletion DONE: + after `runtime_dispatch.py` stopped importing `cuda/runtime_ops.py`, the + remaining `cuda/projections.py` and + `cuda/projection/grouped_projection_*` files were only direct probe + surfaces. They were deleted, the remaining projection reference math test + now imports the PyTorch/reference projection primitive, and import + guardrails assert the deleted CUDA bridge files stay absent. Registered + temporal program projection work remains owned by compiler rows/native + strategy bodies, not the old grouped-projection extension. 
Verification: + `uv run python -m compileall -q src/cortical/fabric/backend/runtime_dispatch.py src/cortical/fabric/backend/cuda/projection tests/test_fabric_backend_plan.py tests/test_fabric_runtime.py tests/test_fabric_execution_imports.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/runtime_dispatch.py src/cortical/fabric/backend/cuda/projection tests/test_fabric_backend_plan.py tests/test_fabric_runtime.py tests/test_fabric_execution_imports.py tests/test_fabric_backend_boundaries.py`, + and + `uv run pytest -q tests/test_fabric_execution_imports.py tests/test_fabric_backend_plan.py::test_backend_projection_primitives_match_reference_math tests/test_fabric_runtime.py::test_fabric_supported_cuda_route_uses_registered_temporal_program --tb=short` + passed; + - 2026-05-02 hard-blocker cut after the ten-iteration counter: native + message/readout C++ bodies are no longer embedded in the broad fused + temporal program files. `forward_program.cuh` and + `backward_surface_steps.cuh` now keep only shared strategy state, + dispatch, and generated-catalog lookup; concrete message/readout native + callable bodies live under + `registered_program/native_callables/message_forward_strategies.cuh`, + `readout_forward_strategies.cuh`, `message_reverse_strategies.cuh`, and + `readout_reverse_strategies.cuh`. The generated native-callable catalog + still points directly at those strategy symbols; this did not introduce a + facade or wrapper path. In the same cut, primitive row executor contracts + now derive implemented/missing status from registered fused strategy + group coverage and native forward/reverse strategy matching instead of + bucket ordinals. Message/readout/parameter rows that are present in a + primitive table but not covered by a registered executor group now fail + closed with row-level blockers. The inactive object-artifact recompute + branch was deleted from `forward_scan.py` and + `registered_executors.py`; active backward already requires + `TemporalReverseArtifactTensorStore`, and there is no remaining + `recompute_registered_temporal_artifact_window` symbol in runtime source. 
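+
+    A hedged sketch of the fail-closed coverage rule described above: every
+    primitive row must be claimed by some registered executor strategy group,
+    otherwise a typed row-level blocker is emitted instead of silently
+    skipping the row (names and the blocker code here are illustrative):
+
+    ```python
+    from dataclasses import dataclass
+
+    @dataclass(frozen=True)
+    class RowBlocker:
+        code: str
+        row_id: str
+        reason: str
+
+    def coverage_blockers(primitive_rows: list[str],
+                          covered_rows: set[str]) -> list[RowBlocker]:
+        return [
+            RowBlocker(
+                code="PRIMITIVE_ROW_NOT_COVERED_BY_REGISTERED_EXECUTOR",
+                row_id=row,
+                reason="row present in primitive table but no strategy group claims it",
+            )
+            for row in primitive_rows
+            if row not in covered_rows
+        ]
+
+    blockers = coverage_blockers(["message.0", "readout.0"], {"message.0"})
+    assert [b.row_id for b in blockers] == ["readout.0"]  # fail closed, do not skip
+    ```
+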
+ Verification: + `uv run python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/primitive_dispatch.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/primitive_dispatch.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py`, + `uv run python scripts/validate_fabric_generated_catalogs.py`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_message_readout_native_callable_bodies_are_strategy_local tests/test_fabric_backend_boundaries.py::test_forward_scan_fails_closed_on_unsupported_primitive_programs tests/test_fabric_backend_boundaries.py::test_shared_temporal_forward_scan_has_no_python_step_loop_fallback tests/test_fabric_backend_boundaries.py::test_temporal_object_artifact_recompute_branch_was_deleted tests/test_fabric_backend_boundaries.py::test_temporal_backward_replay_requests_come_from_scheduler_plan tests/test_fabric_backend_plan.py::test_temporal_primitive_executor_plan_fails_closed_for_missing_generic_dispatch --tb=short`, + and `git diff --check` passed. Remaining native-body work before + throughput is generator/registration closure for adding brand-new + strategy C++ bodies without touching broad temporal program files or the + generated catalog by hand; + - generated native implementation catalog closure is now complete for the + checked-in source model. `scripts/validate_fabric_generated_catalogs.py` + validates both generated C++ catalog headers against the live compiler + registries (`message_rule_lowering_catalog_header_text()` and + `temporal_native_callable_generated_header_text()`), supports `--write` + regeneration, and runs in the package-build CI job before distributions + are produced. Verification: + `uv run python scripts/validate_fabric_generated_catalogs.py`, + `uv run ruff check scripts/validate_fabric_generated_catalogs.py`, + and + `git diff --check -- scripts/validate_fabric_generated_catalogs.py .github/workflows/ci.yml` + passed; + - continue tightening executable memory/liveness/tape/checkpoint policy: + message-step workspace and parameter-reducer output tensors are now + planned by compiler-owned tables; runtime-policy rows now have both C++ + validation and a Python `TemporalMemoryRuntimePolicy` extracted from the + compiler memory plan; runtime buffer planning now fails closed when any + required policy row is missing; and compiler-declared scratch workspace + aliases now drive actual buffer reuse. Artifact stores now also carry a + `memory_runtime_policy_fingerprint`, and registered backward validates + that fingerprint before consuming checkpoint/recompute windows. The next + cut added `TemporalMemoryRuntimeSchedulePlan`: checkpoint stride, + recompute window length, planned checkpoint steps, backward windows, + output physical steps, primitive-output/tape/materialization policy, and + CUDA graph constraint are now one compiler runtime-schedule product + consumed by `TemporalMemoryRuntimeArtifactPlan`. 
`TemporalArtifactStore` + carries `memory_runtime_schedule_fingerprint`, and registered backward + rejects stale schedule products before using artifact windows. Remaining + memory work is to connect local seed-row and metadata-row policy to + concrete scheduler/runtime table placement, then carry the same + schedule-plan fingerprint through transition tape allocation, + recompute-window materialization, and CUDA graph launch constraints. The + next cut turned the local seed-row, metadata-row, policy-row, checkpoint, + backward-window, and output-step schedule into concrete compiler runtime + schedule rows: `TemporalArtifactStore` now carries + `memory_runtime_schedule_rows`, registered backward compares those rows + against the current compiler schedule before consuming artifacts, and + runtime buffer plans carry the same schedule fingerprint/rows when + allocating transition tape and replay buffers. The follow-up C++ ABI cut + passes `memory_runtime_schedule_rows` into the full fused forward and + backward program entrypoints and validates policy-row effects, + memory-row ownership, recompute-policy opcodes, checkpoint/recompute + scalars, backward windows, and CUDA graph constraint presence against + `memory_liveness_rows` before launch. The next guard cut makes the CUDA + graph constraint executable at launch: fused forward/backward validation + now requires the compiler-owned `cuda_graph_guard_policy` schedule row, + queries the current CUDA stream capture status, and rejects invalidated + capture streams before any registered program launch proceeds. Remaining + memory work is now limited to future non-scratch alias/lifetime decisions + that need device-side schedule facts; + - close generic transition lowering: primitive-DAG direct eager dispatch has + been deleted from transition lowering, so primitive-DAG CUDA training now + has to execute through the registered temporal executor program. + Per-primitive forward strategies are selectable after bucket-level fusion + misses, parameterless `tanh` has CUDA forward/reverse coverage, and + `linear -> tanh` now has a compiler-owned parameterized reverse DAG plus + parameter-gradient binding. Standalone + `norm_or_identity` and standalone recurrent `matmul` now have the same + reverse native-handler and parameter-gradient coverage, and + `linear -> norm_or_identity -> tanh` now has end-to-end CUDA training + coverage from a real compiled transition program. Stateful primitive-DAG + carry rows and dynamic `grad_next_*` reverse seed roles are now + compiler-binding owned, and `mem -> next_mem` now has CUDA end-to-end + stateful training coverage with reset-present `K>1` materialized-state + gradients. The primitive-DAG executor plan now also carries explicit + per-op tape/recompute contracts: saved input/output tensors, recompute + inputs/outputs, reverse input bindings, and primitive tape policy are + derived from transition primitive registry records and exposed in the + compiler plan review. Follow-up legality closure: transition primitive + registry records now fail closed when their compiler-owned logical + binding contracts are internally invalid: duplicate forward/reverse + binding names, invalid forward output contract shape/index metadata, + tape bindings that do not map to declared forward inputs/outputs, or + parameter-gradient outputs that target undeclared reverse parameters. 
+ Full primitive-DAG closure still needs broader reset/state parity against + a reference path, executable memory placement for those tape contracts, + optional fused CUDA strategies, and typed fail-closed selection for + unsupported CUDA execution. The older direct per-population CUDA + transition lowering module has now been deleted as a live source path: + `transition_execution/lowering.py` no longer exists, no active source + imports `lower_backend_population_transition_*`, and boundary guards + assert the file stays absent. Transition program metadata remains in + `transition_execution/program.py`/`registry.py`, and executable CUDA + transition math is reached through registered temporal program dispatch; + - forward strategy genericity advanced: the shared fused-forward + message/readout executor state no longer stores fixed DotProduct/readout + fields such as `recurrent_q`, fixed-slot sender key tensors, + `output_q`, or `value_to_output_weight`. Forward message/readout + strategies now bind compiler-emitted program access rows into + strategy-owned tensor tables, and the strategy body asks for the access + it owns. The follow-up hard-blocker cut above moved concrete + message/readout C++ native bodies out of the broad temporal program files + and into strategy-local native-callable headers. The remaining work is to + make new message/readout strategy registration fully generated/local to + declaration/lowering, executor pattern, native callable, and tests, with + no additions to shared program structs, broad temporal-side fixed roles, + or checked-in generated catalog edits; + - Readout strategy locality advanced: readout rules now have the same + registry-owned backend spec shape as message rules for static tensor + access rows and native forward/reverse executor contracts. The temporal + executor registry derives readout primitive row signatures, program + access rows, strategy IDs, native callable IDs, and named C++ entrypoint + phases from `ReadoutRuleBackendSpec`/`compile_readout_rule(...)` instead + of duplicating the projection/reduction formula in shared temporal + strategy code. The generated native-callable C++ catalog now validates + readout entrypoints by named phases (`bind/message/projection` and + `readout_backward/output_message_backward`) instead of positional tuple + order. Concrete readout C++ bodies now live in strategy-local + native-callable headers. Remaining strategy-locality work is now the + native callable body generator/registration contract for adding a new + readout implementation without touching temporal program structs or the + checked-in generated catalog by hand; + - reverse readout helper genericity advanced in the same pass: shared + reverse readout/output-message step bodies now select the native readout + strategy and pass compiler program tensor/access tables through to it. + The strategy-owned readout native callable resolves + `readout_value_to_output_weight` and `readout_output_query` from + compiler access rows; the shared reverse step no longer hardcodes those + readout access names before dispatch; + - parameter reducer genericity advanced: reducer kind opcodes, count + contracts, active trainable roles, message strategy static binding + requirements, and message parameter-gradient output roles now live in + `compiler/reducer_patterns.py`. 
The active temporal parameter binding + path consumes that metadata instead of branching on + `fixed_slot_context_message` or choosing `message_query_context_gate` + versus `message_query_nudge_scale` itself. Fixed-slot context remains a + registered reducer strategy implementation, but the host program builder + now treats it as registry metadata plus tensor rows rather than a + one-off adapter. The C++ parameter-reducer executor also no longer stores + strategy expected counts in named fields such as + `fixed_slot_context_message`; expected counts are accumulated and queried + by reducer count-target opcode from the compiler-emitted strategy rows; + - C++ message-rule classification is now generated from the Python + `MessageRuleBackendSpec` registry. `fabric/backend/cuda/nn/ir.cuh` no + longer contains procedural `message_rule_matches_*DotProduct*` classifier + bodies or scalar-name branches in `classify_message_rule_lowering`; it + validates the generic `MessageRuleIR` and iterates + `registered_message_rule_lowering_patterns_*` from + `message_rule_lowering_catalog.cuh`. The checked-in catalog is validated + byte-for-byte against + `message_rule_lowering_catalog_header_text()`, and the default DotProduct + semantic spec now declares its final projected-message projection instead + of relying on a hidden C++ catalog node. This does not finish generated + native message strategy implementation registration, but it removes the + C++ classifier/catalog as a second DotProduct formula owner. Remaining + message-rule compiler work is to prove the PR-10-style DotProduct + math-change stress gate can be added without scheduler, fixed-slot ABI, or + shared classifier edits; + - Message-rule catalog fidelity and forward strategy locality advanced: + C++ `MessageRuleIR` nodes now carry explicit `parameter_indices`, and the + generated lowering catalog emits per-node parameter-index arrays instead + of collapsing parameter-binding ops to one `parameter_index`. Multi-param + message primitives such as the fixed-slot value projection now preserve + all declared parameter bindings in the catalog ABI. The fixed-slot + message forward executor row patterns are also derived from the registered + `MessageRuleBackendSpec`/`compile_message_rule(...)` output instead of + duplicating the primitive formula in the temporal executor registry, and + fixed-slot message program access names/opcodes now live on the + message-rule static tensor specs before the executor registry derives + `TemporalProgramAccessPattern` rows from them. Fixed-slot reverse message + parameter-gradient output metadata and reducer kind now also live on the + registered message-rule spec before reverse executor patterns derive + their `message_param_grad_outputs`. Follow-up strategy locality cut: + `MessageRuleBackendSpec` now also owns the native forward/reverse + executor contract and the message parameter-reducer implementation + contract. `executor_patterns.py` derives fixed-slot forward/reverse + message strategies from `spec.native_executors`, and + `reducer_patterns.py` derives the fixed-slot message reducer from + `spec.parameter_reducer`; the fixed-slot native handler symbols and + reducer implementation symbol no longer live in the shared temporal + executor/reducer registries. Concrete fixed-slot C++ body placement is + now strategy-local under `registered_program/native_callables/`. 
+ Remaining strategy-locality work is the native callable body generator + contract itself, so new message math does not require broad temporal-side + edits; + - Verification for the memory runtime schedule and forward strategy-access + cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/types.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/types.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes tests/test_fabric_backend_plan.py::test_temporal_backward_validates_memory_artifact_plan_fingerprint --tb=short`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_reverse_message_readout_helpers_use_native_strategy_access_schema tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short`, + and + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_forward_strategy_access_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings --tb=short`, + and + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_readout_access_v1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed; + - Verification for the concrete runtime schedule-row cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/types.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/types.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + and + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes tests/test_fabric_backend_plan.py::test_temporal_backward_validates_memory_artifact_plan_fingerprint tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows --tb=short` + passed; + - Verification for the C++ fused-program runtime schedule-row ABI cut: + `python -m compileall -q 
src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes tests/test_fabric_backend_plan.py::test_temporal_backward_validates_memory_artifact_plan_fingerprint tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows --tb=short`, + and + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_schedule_rows_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program --tb=short` + passed; + - Verification for the CUDA graph launch-guard cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py`, + `uv run ruff check tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows --tb=short`, + and + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_cuda_graph_guard_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program --tb=short` + passed; + - Verification for the reducer-pattern metadata and primitive-DAG tape + contract cuts: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/reducer_patterns.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/transition_execution/program.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/reducer_patterns.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/transition_execution/program.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_parameter_reducer_native_callables_are_registry_owned --tb=short`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_reverse_executor_rows_cover_parameterless_primitive_dag_adjoint --tb=short`, + `CUDA_VISIBLE_DEVICES=0 
TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reducer_registry_contract_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_registered_parameter_reducer_cuda_executes_transition_trainable_rows --tb=short`, + and + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_message_reducer_registry_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program --tb=short` + passed; + - Verification for the generated message-rule C++ lowering catalog cut: + `uv run python -m compileall -q src/cortical/fabric/backend/message_rules.py src/cortical/fabric/backend/message_rule_specs.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/message_rules.py src/cortical/fabric/backend/message_rule_specs.py tests/test_fabric_backend_boundaries.py`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_cuda_message_rule_ir_distinguishes_context_gate_from_nudge --tb=short`, + `uv run pytest -q tests/test_fabric_public_api.py::test_message_rule_declaration_lowers_to_projected_message_boundary tests/test_fabric_public_api.py::test_dot_product_context_nudge_math_lowers_to_distinct_ir tests/test_fabric_public_api.py::test_dot_product_context_gate_math_lowers_to_distinct_ir tests/test_fabric_backend_plan.py::test_default_message_rule_contract_is_planner_visible tests/test_fabric_backend_plan.py::test_message_rule_backend_specs_are_registered_like_cell_specs tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_dot_product_lowers_as_distinct_message_program tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_dot_product_lowers_as_distinct_message_program tests/test_fabric_backend_plan.py::test_default_message_rule_has_python_semantic_equivalent --tb=short`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_backend_ir_uses_declared_spec_message_rule_not_default_substitution tests/test_fabric_backend_plan.py::test_message_rule_compiler_rejects_unsupported_rule_before_runtime_execution tests/test_fabric_backend_plan.py::test_message_rule_builder_rejects_unregistered_rule_type --tb=short`, + a lightweight `g++ -std=c++17 -I src` include compile for + `cortical/fabric/backend/cuda/nn/ir.cuh`, and exact catalog validation via + `validate_message_rule_lowering_catalog_header(...)` passed; + - Verification for the explicit message catalog parameter-binding and + fixed-slot row-pattern locality cut: + `python -m compileall -q src/cortical/fabric/backend/message_rules.py tests/test_fabric_backend_boundaries.py`, + `python -m compileall -q src/cortical/fabric/backend/message_rules.py src/cortical/fabric/backend/message_rule_specs.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py tests/test_fabric_backend_plan.py`, + `uv run ruff check src/cortical/fabric/backend/message_rules.py src/cortical/fabric/backend/message_rule_specs.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py`, + exact catalog validation via + `validate_message_rule_lowering_catalog_header(...)`, + a `g++ -std=c++17 -I src` C++ sanity check that builds a fixed-slot + context nudge message rule with a multi-parameter `linear` node and + classifies it through the generated catalog, and + `uv run pytest -q 
tests/test_fabric_backend_boundaries.py::test_cuda_message_rule_ir_distinguishes_context_gate_from_nudge tests/test_fabric_backend_plan.py::test_message_rule_backend_specs_are_registered_like_cell_specs tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_dot_product_lowers_as_distinct_message_program tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_dot_product_lowers_as_distinct_message_program tests/test_fabric_backend_plan.py::test_message_executor_patterns_follow_registered_message_specs --tb=short` + passed; + follow-up verification for moving fixed-slot program-access and + parameter-gradient metadata into message specs: + `python -m compileall -q src/cortical/fabric/backend/message_rules.py src/cortical/fabric/backend/message_rule_specs.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py tests/test_fabric_backend_plan.py`, + `uv run ruff check src/cortical/fabric/backend/message_rules.py src/cortical/fabric/backend/message_rule_specs.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py tests/test_fabric_backend_plan.py`, + and + `uv run pytest -q tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_dot_product_lowers_as_distinct_message_program tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_dot_product_lowers_as_distinct_message_program tests/test_fabric_backend_plan.py::test_message_executor_patterns_follow_registered_message_specs --tb=short` + passed; + follow-up verification for moving fixed-slot native executor and reducer + implementation contracts into message specs: + `python -m compileall -q src/cortical/fabric/backend/message_rules.py src/cortical/fabric/backend/message_rule_specs.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/reducer_patterns.py tests/test_fabric_backend_plan.py`, + `uv run ruff check src/cortical/fabric/backend/message_rules.py src/cortical/fabric/backend/message_rule_specs.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/reducer_patterns.py tests/test_fabric_backend_plan.py`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured tests/test_fabric_backend_plan.py::test_message_executor_patterns_follow_registered_message_specs tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_dot_product_selects_registered_strategy tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_dot_product_selects_registered_strategy_with_access_remap --tb=short`, + and + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_parameter_reducer_native_callables_are_registry_owned tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed; + - Verification for the readout strategy-locality and generated native + callable catalog phase cut: + `uv run python -m compileall -q src/cortical/fabric/backend/readout_rules.py src/cortical/fabric/backend/readout_rule_specs.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py tests/test_fabric_backend_plan.py`, + `uv run ruff check src/cortical/fabric/backend/readout_rules.py src/cortical/fabric/backend/readout_rule_specs.py 
src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py tests/test_fabric_backend_plan.py`, + `uv run python scripts/validate_fabric_generated_catalogs.py`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_readout_executor_patterns_follow_registered_readout_specs tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured tests/test_fabric_backend_plan.py::test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names --tb=short`, + and targeted `git diff --check` passed; + - Verification for the transition primitive registry logical-contract + legality cut: + `uv run python -m compileall -q src/cortical/fabric/backend/cuda/transition_execution/registry.py tests/test_fabric_backend_plan.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/transition_execution/registry.py tests/test_fabric_backend_plan.py`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_transition_executor_selection_uses_registered_structural_records --tb=short`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_rejected_transition_temporal_fusion_facades_were_deleted --tb=short`, + and targeted `git diff --check` passed; + - Follow-up reverse native transition callable contract cut: + `transition_reverse_program.cuh` now validates reverse input, + parameter, and output binding vectors against the compiler-emitted + `native_callable_binding_schema_rows` for the selected generated native + callable before dispatching a transition primitive adjoint. The generated + catalog-selected `native_callable_hash` is threaded into each reverse + transition handler instead of hardcoding `native.reverse.transition_*` + identities inside handler bodies, and generated catalog counts are checked + against the same binding-schema rows. This also fixed a discovered schema + drift: `norm_or_identity` reverse `eps` is optional in the handler and is + now optional in the compiler binding schema/header as well. Follow-up: + the runtime `schema_version` is now threaded through the reverse handler + ABI and used by each logical binding lookup instead of hardcoding `1` in + primitive bodies. 
This closes the immediate drift hole where checked-in + C++ handler bodies and compiler-emitted reverse binding schemas could + disagree silently; + - Verification for the reverse native transition callable contract cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_reverse_transition_native_handlers_use_logical_binding_schema tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_plan.py::test_transition_executor_selection_uses_registered_structural_records --tb=short`, + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_schema_v2 uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_reverse_transition_program_cuda_dispatches_registered_tanh_callable tests/test_fabric_backend_plan.py::test_fused_reverse_transition_program_cuda_dispatches_registered_linear_callable tests/test_fabric_backend_plan.py::test_fused_reverse_transition_program_cuda_dispatches_registered_matmul_callable tests/test_fabric_backend_plan.py::test_fused_reverse_transition_program_cuda_dispatches_registered_norm_callable --tb=short`, + and `git diff --check --` on the touched compiler/C++/test files passed; + - DotProduct semantic math-change stress gate: completed in Iteration 6/10 + below. The `fixed_slot_context_nudge` and `fixed_slot_context_gate` + message strategies now lower, bind, train, and reduce extra parameter + gradients through the registered compiler path; + - deferred config/anatomy/stale-CUDA cleanup: runtime-critical cleanup + completed in Iteration 7/10 and guarded in Iteration 8/10 below. The + remaining `Config`/Blueprint split is broad public API cleanup and must + preserve the compiler-owned execution source of truth. + - 2026-05-02 ten-iteration compiler closure counter, before throughput: + - Iteration 1/10 DONE: remove the remaining fixed reverse-output + tuple ABI by making fused reverse span outputs and reducer requests + compiler-declared rows/roles instead of Python/C++ positional slots. + Implemented `reverse_span_output_rows`, threaded them into the registered + fused backward CUDA program, assembled C++ output groups from declared + roles, consumed Python outputs by role, and replaced message extra + parameter-gradient slot metadata with declared role metadata. 
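+
+    A hedged sketch of role-based output consumption, the shape Iteration 1/10
+    replaces positional tuples with: reverse outputs are matched to
+    compiler-declared (role, index) rows rather than unpacked by slot order
+    (row and role names here are illustrative):
+
+    ```python
+    import torch
+
+    def outputs_by_role(output_rows: list[tuple[str, int]],
+                        raw_outputs: list[torch.Tensor]) -> dict[str, torch.Tensor]:
+        if len(output_rows) != len(raw_outputs):
+            raise ValueError("reverse program output count does not match declared rows")
+        return {role: raw_outputs[index] for role, index in output_rows}
+
+    rows = [("grad_input_window", 0), ("grad_initial_state", 1), ("grad_message_weight", 2)]
+    raw = [torch.zeros(2), torch.zeros(3), torch.zeros(4)]
+    grads = outputs_by_role(rows, raw)
+    assert set(grads) == {"grad_input_window", "grad_initial_state", "grad_message_weight"}
+    ```
+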
+ Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_reverse_span_outputs_are_compiler_owned_rows tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short`, + and + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_span_outputs_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program --tb=short` + passed. + - Iteration 2/10 DONE: make message native strategy additions local to + the message-rule spec/native-callable catalog contract, with no shared + temporal struct or scheduler edits for new message math. Message rule + native executors now declare named C++ entrypoint phases, executor + patterns validate the required forward/reverse phase contracts, and the + generated native-callable message catalogs bind by phase instead of + positional tuple order. Verification: + `python -m compileall -q src/cortical/fabric/backend/message_rules.py src/cortical/fabric/backend/message_rule_specs.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/message_rules.py src/cortical/fabric/backend/message_rule_specs.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_executor_fusion_patterns_are_structured tests/test_fabric_backend_plan.py::test_message_executor_patterns_follow_registered_message_specs tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_parameter_reducer_native_callables_are_registry_owned --tb=short`, + and `git diff --check --` on the touched message/compiler/test files + passed. + - Iteration 3/10 DONE: close the generic transition primitive extension + tail by making per-op tape/recompute placement a registry-owned contract + instead of hidden string inference. Transition primitive executor records + now declare saved/recomputed input and output bindings, executor + selection fails closed on incomplete tape contracts, and reset/state + training coverage runs through the registered temporal program. 
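
A minimal sketch of a registry-owned tape contract and the fail-closed selection it implies; the record and field names are assumptions, not the actual transition primitive executor record fields.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TapeContractSketch:
    """Every forward binding is explicitly saved or recomputed, never string-inferred."""

    saved_inputs: frozenset[str]
    recomputed_inputs: frozenset[str]
    saved_outputs: frozenset[str]


def select_transition_executor(
    op: str,
    forward_inputs: frozenset[str],
    forward_outputs: frozenset[str],
    contracts: dict[str, TapeContractSketch],
) -> TapeContractSketch:
    contract = contracts.get(op)
    if contract is None:
        raise LookupError(f"no registered executor record for {op!r}")
    covered_inputs = contract.saved_inputs | contract.recomputed_inputs
    if covered_inputs != forward_inputs or contract.saved_outputs != forward_outputs:
        # Fail closed: an incomplete tape contract must not fall back to a
        # default save/recompute placement.
        raise ValueError(f"incomplete tape contract for {op!r}")
    return contract
```
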
+ Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/transition_execution/registry.py src/cortical/fabric/backend/cuda/transition_execution/program.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/transition_execution/registry.py src/cortical/fabric/backend/cuda/transition_execution/program.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_transition_executor_selection_uses_registered_structural_records tests/test_fabric_backend_boundaries.py::test_reverse_transition_native_handlers_use_logical_binding_schema tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short`, + `git diff --check -- src/cortical/fabric/backend/cuda/transition_execution/registry.py src/cortical/fabric/backend/cuda/transition_execution/program.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + and + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_transition_tape_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_stateful_transition_primitive_dag_trains_with_resets_through_registered_temporal_program --tb=short` + passed. + - Iteration 4/10 DONE: finish executable memory/liveness cleanup for + runtime buffer ownership. `build_temporal_runtime_buffer_plan` now + validates every executable runtime buffer against the compiler + memory-liveness rows before CUDA row emission: buffer specs must name a + real runtime memory row, cannot target policy/parameter-only rows, must + match workspace/surface/effect, must carry compiler memory-plan ownership, + and full fused launches require workspace coverage plus runtime schedule + rows. Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows --tb=short`, + `git diff --check -- ai_docs/REDO2_FIXMASS.md src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + and + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_memory_buffer_validation_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program --tb=short` + passed. + - Iteration 5/10 DONE: delete the active object-artifact reverse window + branch. Registered temporal backward now requires a compiler-owned + `TemporalReverseArtifactTensorStore`; if forward did not provide tensor + artifact rows, backward fails closed instead of loading step objects, + checkpointing through `_recompute_temporal_bucket_artifact_window`, or + converting `TemporalBucketStepArtifacts` into reverse tensor tables. + The registered reverse executor exposes only the tensor-store window + entrypoint for the active CUDA training path. 
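
A minimal sketch of the fail-closed requirement, with invented names; only the compiler-owned tensor-store requirement itself comes from the cut above.

```python
from typing import Mapping, Optional


def require_reverse_tensor_store(
    store: Optional[Mapping[str, object]],
) -> Mapping[str, object]:
    """Backward refuses to run without compiler-owned tensor artifact rows.

    There is deliberately no fallback here to step-object replay or an
    object-artifact checkpoint window.
    """
    if store is None or not store:
        raise RuntimeError(
            "registered temporal backward requires a compiler-owned reverse "
            "artifact tensor store from forward"
        )
    return store
```
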
Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py tests/test_fabric_backend_boundaries.py`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_forward_scan_fails_closed_on_unsupported_primitive_programs tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_temporal_backward_replay_requests_come_from_scheduler_plan --tb=short`, + `git diff --check -- src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py tests/test_fabric_backend_boundaries.py ai_docs/REDO2_FIXMASS.md`, + and + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_tensor_store_reverse_only_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program --tb=short` + passed. + - Iteration 6/10 DONE: run the PR-10-style DotProduct semantic stress + test by implementing the attention normalization/gating math through + declaration, IR, tensor bindings, native forward/reverse executors, and + reducers only. The `fixed_slot_context_nudge` and + `fixed_slot_context_gate` message rules lower as distinct declared + message programs, bind their context scalar through compiler rows, train + through the registered fused temporal forward/reverse program, and reduce + the declared extra message parameter-gradient outputs through the + registered parameter reducer. The gate stress path exposed and fixed a + metadata bug: extra parameter-gradient slots are now reported from the + reverse strategy's compiler-declared `message_param_grad_outputs` rather + than variant-specific runtime state. Verification: + `uv run pytest -q tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_dot_product_lowers_as_distinct_message_program tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_dot_product_lowers_as_distinct_message_program tests/test_fabric_backend_plan.py::test_message_executor_patterns_follow_registered_message_specs tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_dot_product_selects_registered_strategy tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_dot_product_selects_registered_strategy_with_access_remap --tb=short` + passed: 5 tests; + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_dotproduct_stress_counter_v2 uv run pytest -q tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_matches_nudge_when_scalar_binding_is_equal --tb=short` + passed: 3 tests. + - Iteration 7/10 DONE: run deferred Config/anatomy/public API cleanup + and remove stale CUDA execution surfaces/directories that are not + compiler-selected primitive executors. 
This slice deleted the old + package-level CUDA message/readout mini registries and the test-only CUDA + source reference directory, stopped `cortical.fabric.backend.cuda` from + re-exporting direct message/projection kernels as a public execution API, + and moved the remaining test reference helper into the test that owns it. + It also split lattice-specific anatomy construction out of + `anatomy.py` into `graphs/lattice_anatomy.py`: coordinate construction, + band/explicit ports, offset-neighborhood graph construction, KV grouping, + slot features, population layout, and local sender tables are now graph + constructor responsibilities. Runtime no longer rebuilds local sender + tables from `spec.config.coord_shape`/`spec.config.wrap`, and Fabric IR + no longer carries lattice `wrap` as backend identity. Same-slice bug fix: + the default message strategy matcher now follows the current compiler + primitive rows for split input-direct/input-group/recurrent KV projection + instead of the stale bundled KV row assumption. Verification: + `uv run python -m compileall -q src/cortical/fabric/anatomy.py src/cortical/fabric/graphs/lattice_anatomy.py src/cortical/fabric/runtime/core.py src/cortical/fabric/backend/ir.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py tests/test_fabric_anatomy.py tests/test_fabric_public_api.py tests/test_fabric_execution_imports.py tests/test_fabric_backend_plan.py`, + `uv run ruff check src/cortical/fabric/anatomy.py src/cortical/fabric/graphs/lattice_anatomy.py src/cortical/fabric/runtime/core.py src/cortical/fabric/backend/ir.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py tests/test_fabric_anatomy.py tests/test_fabric_public_api.py tests/test_fabric_execution_imports.py tests/test_fabric_backend_plan.py`, + `uv run pytest -q tests/test_fabric_anatomy.py tests/test_fabric_public_api.py tests/test_fabric_execution_imports.py tests/test_fabric_backend_plan.py::test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names --tb=short`, + `uv run pytest -q tests/test_fabric_execution_imports.py tests/test_fabric_runtime.py::test_receiver_major_projection_backward_gate_is_work_shape_based --tb=short`, + and targeted `git diff --check` passed. + Remaining public API cleanup is narrower: the broad `Config` object still + exists as compatibility input and Blueprint still normalizes through it, + but backend/runtime lattice factorization is no longer consumed from + `Spec.config`. + - Iteration 8/10 DONE: tighten source and metadata guardrails so legacy + fixed-slot, compatibility, fallback, facade, or direct helper routes + cannot re-enter the active CUDA training path. Added a boundary guard that + rejects reintroducing runtime-side local sender table construction from + `spec.config.coord_shape`/`spec.config.wrap`, backend-IR lattice `wrap` + identity, old `backend/cuda/reference` and `backend/cuda/registry` + directories, top-level direct CUDA kernel re-exports, and the deleted + message/readout mini registries. Existing import guards now also assert + the `cuda` package stays empty as an execution API. 
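
An illustrative shape for such guards, not the actual boundary tests; the repo-root lookup and the exact legacy directory list shown are assumptions.

```python
import importlib
from pathlib import Path

# Assumed test placement one level below the repository root.
REPO_ROOT = Path(__file__).resolve().parents[1]


def test_deleted_cuda_surfaces_stay_deleted() -> None:
    backend_cuda = REPO_ROOT / "src" / "cortical" / "fabric" / "backend" / "cuda"
    for legacy in ("reference", "registry"):
        assert not (backend_cuda / legacy).exists(), f"legacy directory {legacy!r} reappeared"


def test_cuda_package_stays_empty_as_an_execution_api() -> None:
    module = importlib.import_module("cortical.fabric.backend.cuda")
    public = [name for name in vars(module) if not name.startswith("_")]
    # No direct kernel wrappers re-exported at the package top level.
    assert not any(name.endswith("_cuda") for name in public), public
```
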
Verification: + `uv run python -m compileall -q tests/test_fabric_backend_boundaries.py`, + `uv run ruff check tests/test_fabric_backend_boundaries.py`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_lattice_config_cleanup_stays_out_of_backend_runtime tests/test_fabric_backend_boundaries.py::test_temporal_engine_table_sources_do_not_use_cell_or_benchmark_route_selectors --tb=short`, + and `uv run pytest -q tests/test_fabric_backend_boundaries.py --tb=short` + passed. + - Iteration 9/10 DONE: run targeted compiler/parity validation only for + the changed surfaces; avoid broad water-is-wet suites unless a boundary + change justifies them. This caught stale guardrails from the default + message rule's split recurrent-KV projection row: structural row-group + signatures, primitive opcode rows, summary expectations, and + missing-binding negative tests now assert the current compiler-declared + message program instead of the older bundled KV row. Verification: + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_primitive_executor_plan_fails_closed_for_missing_generic_dispatch tests/test_fabric_backend_plan.py::test_temporal_strategy_matching_uses_canonical_row_group_schema tests/test_fabric_backend_plan.py::test_forward_executor_bindings_fail_when_compiler_binding_is_missing tests/test_fabric_backend_plan.py::test_reverse_executor_bindings_fail_when_compiler_binding_is_missing tests/test_fabric_backend_plan.py::test_temporal_forward_primitive_row_tensor_encodes_supported_program_groups --tb=short` + passed; + `uv run pytest -q tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py tests/test_fabric_execution_imports.py tests/test_fabric_anatomy.py tests/test_fabric_public_api.py --tb=short` + passed: 187 tests; + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_validation_counter_v3 uv run pytest -q tests/test_fabric_backend_plan.py::test_stateful_transition_primitive_dag_trains_with_resets_through_registered_temporal_program --tb=short` + passed. + - Iteration 10/10 DONE: doc closure sweep and git hygiene. The stale + future-task wording for the DotProduct stress stage and anatomy cleanup + now points to the completed counter entries instead of presenting closed + work as pending. Git hygiene review found the new runtime-critical source + file `src/cortical/fabric/graphs/lattice_anatomy.py` as the only new + source file that must be included in the eventual commit; the other + untracked files are unrelated `ai_docs/` scratch docs already outside + this compiler closure slice. Final Config cleanup in this iteration also + narrowed lattice graph construction from public `Config` to + `LatticeAnatomyConfig`, so graph/topology constructors consume only the + fields they own. 
Verification: + `uv run python -m compileall -q src/cortical/fabric/anatomy.py src/cortical/fabric/graphs/lattice_anatomy.py tests/test_fabric_anatomy.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py`, + `uv run ruff check src/cortical/fabric/anatomy.py src/cortical/fabric/graphs/lattice_anatomy.py tests/test_fabric_anatomy.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py`, + `uv run pytest -q tests/test_fabric_anatomy.py tests/test_fabric_backend_boundaries.py::test_lattice_config_cleanup_stays_out_of_backend_runtime tests/test_fabric_backend_plan.py::test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names --tb=short`, + `uv run pytest -q tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py tests/test_fabric_execution_imports.py tests/test_fabric_anatomy.py tests/test_fabric_public_api.py --tb=short`, + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_validation_counter_v3 uv run pytest -q tests/test_fabric_backend_plan.py::test_stateful_transition_primitive_dag_trains_with_resets_through_registered_temporal_program --tb=short`, + `rg -n "finish the DotProduct|This is a scheduled acceptance gate|anatomy.py\\` cleanup remains open|Iteration 10/10" ai_docs/REDO2_FIXMASS.md`, + `git status --short`, + `git diff --check -- ai_docs/REDO2_FIXMASS.md tests/test_fabric_backend_plan.py`, + and `uv run ruff check tests/test_fabric_backend_plan.py` passed. + - Post-counter public Config/graph cleanup DONE: removed the remaining + flat public `Config` constructor path from the live source. Fabric config + now requires a graph declaration and owned sections for interface, + message, population layout, readout, execution, and initialization. + Boundary ports and topology inputs moved into graph declarations: + `lattice2d.Graph` owns lattice bands/explicit ports/connectivity, and + the new `graphs.flat.Graph` supports user-defined node/edge graphs with + explicit inputs, outputs, recurrent nodes, populations, and KV groups. + `Blueprint` now preserves the graph declaration and fills owned config + sections instead of translating back into flat lattice fields. Runtime, + backend IR, CUDA sequence-surface helpers, and the visualizer now read + owned section fields or anatomy metadata instead of `spec.config.width`, + `coord_shape`, `wrap`, `d_msg`, `readout_pool`, `backend`, or K/checkpoint + flat fields. A boundary test now rejects `Config(width=..., height=..., + hidden_size=...)` as a validation error. 
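
A loose sketch of the intended constructor shape, with invented section names and fields; only the requirement for a graph declaration plus owned sections, and the rejection of flat lattice kwargs, come from the cut above.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class InterfaceSectionSketch:
    """Illustrative owned section; the real sections carry more fields."""

    input_size: int
    output_size: int


@dataclass(frozen=True)
class FabricConfigSketch:
    """Config built from a graph declaration plus owned sections, not flat lattice fields."""

    graph: Any  # e.g. a lattice2d.Graph or graphs.flat.Graph declaration
    interface: InterfaceSectionSketch

    def __post_init__(self) -> None:
        if self.graph is None:
            raise ValueError("a graph declaration is required")


def reject_flat_constructor(**fields: Any) -> None:
    """Mirror the boundary test: flat lattice kwargs are a validation error."""
    rejected = {"width", "height", "hidden_size"} & fields.keys()
    if rejected:
        raise ValueError(f"flat constructor fields are no longer accepted: {sorted(rejected)}")
```
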
+ Verification: + `uv run python -m compileall -q src/cortical/fabric/config.py src/cortical/fabric/blueprint.py src/cortical/fabric/anatomy.py src/cortical/fabric/graphs src/cortical/fabric/runtime/core.py src/cortical/fabric/backend/ir.py src/cortical/visualization/fabric.py tests/test_fabric_anatomy.py tests/test_fabric_public_api.py tests/test_fabric_visualizer.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_runtime.py`, + `uv run ruff check src/cortical/fabric/config.py src/cortical/fabric/blueprint.py src/cortical/fabric/anatomy.py src/cortical/fabric/graphs src/cortical/fabric/runtime/core.py src/cortical/fabric/backend/ir.py src/cortical/visualization/fabric.py tests/test_fabric_anatomy.py tests/test_fabric_public_api.py tests/test_fabric_visualizer.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_runtime.py`, + `uv run pytest -q tests/test_fabric_anatomy.py tests/test_fabric_public_api.py tests/test_fabric_backend_boundaries.py::test_lattice_config_cleanup_stays_out_of_backend_runtime --tb=short`, + and `uv run pytest -q tests/test_fabric_visualizer.py --tb=short` + passed. + - Post-counter hard-blocker cut DONE: direct transition primitive launch + wrappers were removed from the active Python and pybind surfaces. The C++ + primitive bodies remain only under `registered_program/` as strategy + internals used by fused transition/full temporal program dispatch. Deleted + direct primitive probe tests so transition math is validated through + compiler-owned program launches, while boundary guards now reject + reintroducing `registered_program_transition_*_cuda` Python functions or + `program_transition_*` pybind exports. Verification: + `uv run python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short`, + and + `uv run pytest -q tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_registered_tanh_callable tests/test_fabric_backend_plan.py::test_fused_reverse_transition_program_cuda_dispatches_registered_tanh_callable tests/test_fabric_backend_plan.py::test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings --tb=short` + passed. + - Post-counter direct transition lowering deletion DONE: the dormant + `src/cortical/fabric/backend/cuda/transition_execution/lowering.py` + module was deleted. It was no longer imported by active source and still + contained direct gated/diagonal eager CUDA formula routing, so keeping it + as a sibling would preserve a non-compiler execution path in the tree. + Boundary tests now require the file to be absent and keep the + `transition_execution` package as metadata/projection/registry/program + modules only. 
Verification: + `uv run python -m compileall -q src/cortical/fabric/backend/cuda/transition_execution tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py src/cortical/fabric/backend/cuda/transition_execution`, + and + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_transition_execution_monolith_was_deleted tests/test_fabric_backend_boundaries.py::test_rejected_transition_temporal_fusion_facades_were_deleted --tb=short` + passed. + - Post-counter stale transition-lowering metadata cleanup DONE: after + deleting the eager lowering module, transition executor records/plans no + longer expose `forward_lowering`, `backward_lowering`, or + `runtime_lowering_status`. The remaining metadata now names + `forward_strategy_id`, `backward_strategy_id`, and + `runtime_execution_status`, with fused records reporting + `registered_fused_program` and primitive DAGs reporting + `registered_primitive_dag_program`. Verification: + `uv run python -m compileall -q src/cortical/fabric/backend/cuda/transition_execution tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/transition_execution tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + and + `uv run pytest -q tests/test_fabric_backend_plan.py::test_transition_executor_selection_uses_registered_structural_records tests/test_fabric_backend_boundaries.py::test_rejected_transition_temporal_fusion_facades_were_deleted --tb=short` + passed. + - Post-counter standalone message CUDA probe deletion DONE: deleted + `src/cortical/fabric/backend/cuda/message_passing/**`, including the + direct local/sparse message Python wrappers, pybind bindings, backend + headers, and standalone CUDA kernels. The active registered temporal + program continues to reach message math through compiler-selected native + strategy bodies and dense message ops; tests that exercised the deleted + wrappers as public/direct APIs were removed or shifted to registered + strategy/source-contract coverage. Import and boundary guards now require + the `cuda/message_passing` directory to stay absent. + - Post-counter parameter-reducer count-target genericity DONE: the shared + registered parameter-reducer ABI no longer has a fixed-slot-specific + expected-count target. Fixed-slot context remains a message-strategy + reducer implementation selected by message-rule metadata, but the common + reducer row/count validation now names and validates the generic + `message_strategy` tensor table target. This keeps future message + reducers local to message-rule specs/native callable records instead of + adding a new shared reducer count enum for every message strategy. 
+ Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/reducer_patterns.py src/cortical/fabric/backend/message_rule_specs.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/reducer_patterns.py src/cortical/fabric/backend/message_rule_specs.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py`, + and + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_parameter_reducer_native_callables_are_registry_owned --tb=short` + passed. The targeted CUDA reducer test + `uv run pytest -q tests/test_fabric_backend_plan.py::test_registered_parameter_reducer_cuda_executes_transition_trainable_rows --tb=short` + also passed. + - Post-counter message-rule lowering ID genericity DONE: C++ message-rule + lowering no longer owns a `MessageRuleLoweringKind` enum or a DotProduct + allowlist. The generated message-rule catalog now emits registry-derived + integer lowering IDs next to each structural pattern, and + `lower_message_rule_to_bucket` accepts any catalog-registered message + lowering while leaving executor support to primitive/strategy legality. + This means adding a new message rule no longer requires editing C++ + enum cases or the CUDA NN lowering allowlist; it requires registry + metadata, generated catalog validation, primitive/executor strategy + coverage, and tests. Verification: + `python -m compileall -q src/cortical/fabric/backend/message_rules.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/message_rules.py tests/test_fabric_backend_boundaries.py`, + `uv run python scripts/validate_fabric_generated_catalogs.py`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_cuda_message_rule_ir_distinguishes_context_gate_from_nudge --tb=short`, + and + `uv run pytest -q tests/test_fabric_backend_plan.py::test_message_executor_patterns_follow_registered_message_specs --tb=short` + passed. + - Post-counter default-message strategy registration DONE: the default + DotProduct message strategy (`neighborhood_attention_project`) now comes + from `MessageRuleBackendSpec.native_executors` and declared message + static-tensor access rows, the same path used by fixed-slot context + message variants. `executor_patterns.py` now derives all registered + message forward/reverse executor patterns from ordered registered message + specs instead of hand-inserting the default DotProduct strategy and then + appending fixed variants. This keeps new message strategies local to + message-rule metadata/native-callable registration rather than requiring + edits in temporal strategy tuples. 
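
A minimal sketch of the registry-driven derivation; the spec fields shown are a small invented subset of `MessageRuleBackendSpec`, and the phase-contract check stands in for the real pattern validation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MessageRuleSpecSketch:
    rule_id: str
    forward_phases: tuple[str, ...]
    reverse_phases: tuple[str, ...]


def derive_message_executor_patterns(
    registered_specs: tuple[MessageRuleSpecSketch, ...],
) -> list[dict[str, object]]:
    """Every pattern comes from the ordered registry; nothing is hand-inserted or appended."""
    patterns: list[dict[str, object]] = []
    for spec in registered_specs:  # registry order, default rule included like any other
        if not spec.forward_phases or not spec.reverse_phases:
            raise ValueError(f"{spec.rule_id!r} is missing a required phase contract")
        patterns.append(
            {
                "rule_id": spec.rule_id,
                "forward_phases": spec.forward_phases,
                "reverse_phases": spec.reverse_phases,
            }
        )
    return patterns
```
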
Verification: + `python -m compileall -q src/cortical/fabric/backend/message_rule_specs.py src/cortical/fabric/backend/message_rules.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py`, + `uv run ruff check src/cortical/fabric/backend/message_rule_specs.py src/cortical/fabric/backend/message_rules.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py`, + `uv run python scripts/validate_fabric_generated_catalogs.py`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_message_rule_backend_specs_are_registered_like_cell_specs tests/test_fabric_backend_plan.py::test_message_executor_patterns_follow_registered_message_specs --tb=short`, + and + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_boundaries.py::test_parameter_reducer_native_callables_are_registry_owned --tb=short` + passed. + `uv run pytest -q tests/test_fabric_execution_imports.py --tb=short` + also passed after updating the generic CUDA NN import guard to require + generated message lowering IDs instead of the deleted C++ enum. + - Post-counter transition primitive alias genericity DONE: the lowered + primitive alias `diagonal_recurrence -> diag_rtu` moved from lookup logic + into `TransitionPrimitiveExecutorRecord.aliases`. The transition + primitive lookup now checks registry record aliases generically, so a new + lowered primitive spelling is registered with the primitive executor + record instead of adding another hardcoded branch. This same pass marked + default DotProduct message program access specs as + `existing_static_tensor`, so the message rule owns program access + metadata without asking runtime to rematerialize base static tensors that + already exist in the compiler static tensor table. Verification: + `python -m compileall -q src/cortical/fabric/backend/message_rule_specs.py src/cortical/fabric/runtime/core.py src/cortical/fabric/backend/cuda/transition_execution/registry.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/message_rule_specs.py src/cortical/fabric/runtime/core.py src/cortical/fabric/backend/cuda/transition_execution/registry.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + and + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_transition_execution_monolith_was_deleted tests/test_fabric_backend_plan.py::test_transition_executor_selection_uses_registered_structural_records tests/test_fabric_backend_plan.py::test_message_rule_backend_specs_are_registered_like_cell_specs --tb=short` + passed. + - Post-counter direct forward helper surface deletion and registry + genericity DONE: the remaining registered-forward helper kernels were + removed from the Python and pybind runtime surfaces. The only + runtime-callable registered temporal forward path is now the fused + compiler program entrypoint; sender-K/V, attention, readout projection, + and layout helper bodies remain only as internal native-callable/C++ + implementation details behind selected strategy rows. 
The old + `compute_registered_temporal_bucket_step_artifacts(...)` sibling route + was deleted, so training artifacts come from the fused forward program's + reverse artifact tensor table instead of reconstructing Python step + objects through per-surface executors. Final-state materialization now + assembles cells from the fused program tensor table/public carry directly + instead of calling a registered helper surface. Readout strategy patterns + now come from registered readout-rule backend specs for every supported + lowering kind, with primitive attribute constraints distinguishing + `mean`, `flatten`, `attn`, and `attention` rows. Transition strategy + patterns now come from `transition_execution.registry` + `TransitionExecutorStrategySpec` records, so new transition strategies + are registered with transition metadata instead of central + sequence-surface tuples. Shared native-callable IDs are now merged only + when their direction/surface/version and named C++ entrypoint phases + match; the generated catalog records the shared callable identity instead + of whichever strategy alias happened to be deduped first. Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py`, + `uv run python scripts/validate_fabric_generated_catalogs.py --write`, + `uv run ruff check tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/row_groups.py src/cortical/fabric/backend/cuda/transition_execution/registry.py src/cortical/fabric/backend/readout_rule_specs.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/executor_registry.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_message_executor_patterns_follow_registered_message_specs tests/test_fabric_backend_plan.py::test_readout_executor_patterns_follow_registered_readout_specs --tb=short`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_forward_scan_fails_closed_on_unsupported_primitive_programs tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings --tb=short`, + and + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed (`160 passed in 59.55s`). + - Post-counter reverse output-route row cut DONE: the active registered + reverse program no longer maps fused reverse output tensors to parameter + reducers by host-local `front_output("...")` / `boundary_output("...")` + logical-name calls. `program_execution.py` now emits + `TemporalReverseOutputRouteRow` records that map route kind and target + reducer role to the compiler-owned reverse span output group/role, and + the fused CUDA launch contract records `reverse_output_route_rows` as a + required compiler table. `RegisteredTemporalExecutorProgram` carries that + table, and `registered_executors.py` resolves readout gradients, + recurrent/output query gradients, sender-K/V gradients, transition + boundary gradients, carry gradients, and message-strategy extra + parameter-gradient outputs through the route rows. 
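
A minimal sketch of how such a route row can be addressed from Python; the field spellings and helper are assumptions, and only the mapping itself (route kind and target reducer role to reverse span output group/role) is stated above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ReverseOutputRouteRowSketch:
    route_kind: str    # e.g. a readout, query, sender-K/V, or carry gradient route
    target_role: str   # reducer role or gradient slot the routed tensor feeds
    output_group: str  # compiler-owned reverse span output group
    output_role: str   # role inside that output group


def resolve_routed_gradient(
    rows: list[ReverseOutputRouteRowSketch],
    route_kind: str,
    target_role: str,
    outputs_by_role: dict[tuple[str, str], object],
) -> object:
    matches = [
        r for r in rows if r.route_kind == route_kind and r.target_role == target_role
    ]
    if len(matches) != 1:
        raise ValueError(
            f"expected one route for ({route_kind!r}, {target_role!r}), found {len(matches)}"
        )
    row = matches[0]
    return outputs_by_role[(row.output_group, row.output_role)]
```
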
This is not throughput + tuning; it removes another remaining reverse helper genericity leak before + throughput work. Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names tests/test_fabric_backend_plan.py::test_reverse_span_outputs_are_compiler_owned_rows tests/test_fabric_backend_plan.py::test_reverse_output_routes_are_compiler_owned_rows --tb=short`, + and + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed. Follow-up static compiler suite + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed (`162` tests). + - Post-counter runtime support-row cut DONE: the active fused + forward/reverse Python entrypoints no longer own the first layer of + output-contract, readout-pool, device/dtype, local-message-step, artifact + storage, grad-window, carry-grad, and reverse-artifact role legality + checks as local conditionals. `program_runtime.py` now emits + `TemporalProgramRuntimeSupportPlan` requirement rows and summaries for + forward and reverse runtime facts; `registered_executors.py` records those + rows in runtime metadata and uses the compiler-owned rejection reason to + fail closed before launch. This is still support legality, not throughput + tuning, but it moves another hardcoded active-path support gate into the + compiler contract. Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_plan.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_runtime.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_plan.py`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_forward_fused_program_runtime_facts_are_compiler_owned_rows tests/test_fabric_backend_plan.py::test_reverse_fused_program_runtime_facts_are_compiler_owned_rows tests/test_fabric_backend_plan.py::test_fused_program_runtime_support_rejections_are_compiler_owned_rows --tb=short`, + and + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed. + - Post-counter message/readout surface-cardinality cut DONE: the Python + registered executor program no longer exposes singleton + `forward_surface_handle(...)` / `reverse_surface_handle(...)` APIs for + message/readout surfaces. 
Surface selection now returns ordered executor + handle tuples, forward/reverse program tensor tables consume those + handle tuples, and reverse program stage rows are emitted for every + compiler-selected message/readout executor row instead of using an + `_first_reverse_executor_row` singleton. The global transition binding + scan for a single message `output_dim_role` was also deleted; transition + parameter reducers now choose projected-message static sources from the + binding's own source rows. C++ fused-program span lookup now has + collector helpers + `registered_forward_handler_span_indices_by_capability` and + `registered_reverse_handler_span_indices_by_capability`; the current + message/readout reducer bodies still require a single route-reduced span + explicitly, so the remaining hard follow-up is to make those route rows + describe multi-span reduction/merge semantics before broad multi-carrier + message/readout programs can be accepted. Verification: + `uv run python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler src/cortical/fabric/backend/cuda/sequence_surface/temporal tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/forward_program.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_bindings.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_tensors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/executor_registry.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + and + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_forward_program_access_and_state_carry_rows_are_compiler_owned tests/test_fabric_backend_plan.py::test_temporal_reverse_program_access_rows_are_compiler_owned tests/test_fabric_backend_plan.py::test_temporal_executor_binding_plan_groups_compiled_bindings_by_executor_row tests/test_fabric_backend_boundaries.py::test_forward_message_readout_handlers_use_native_strategy_access_schema tests/test_fabric_backend_boundaries.py::test_reverse_message_readout_helpers_use_native_strategy_access_schema --tb=short` + passed. + - Post-counter reverse surface span-merge cut DONE: the registered fused + reverse C++ surface helpers no longer call a unique reverse + message/readout span helper before executing readout, output-message, + recurrent-K/V, recurrent-message, initial recurrent-K/V, or boundary-K/V + reverse work. The helpers now collect all compiler-matched reverse spans + for the required surface capability and combine their compiler-routed + output tensors with shape/dtype/device validation, using OR/max semantics + for scalar long flags such as grouped sender-K/V. This makes reverse + helper math consume compiler-selected span sets instead of assuming one + fixed span. Remaining before broad multi-carrier reverse programs can be + accepted: Python parameter reducer binding still needs per-span reducer + route rows so parameter-gradient outputs are not collapsed before reducer + ownership is chosen; the active path now fails closed there with + `reverse_surface_parameter_reducers_require_per_span_route_rows` rather + than hiding a unique reverse helper assumption. 
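
A minimal sketch of the combine step over plain Python values; the additive combine of same-role tensor outputs is an assumption, while the shape/dtype/device validation and the OR/max handling of scalar long flags are the behaviors stated above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SpanOutputSketch:
    shape: tuple[int, ...]
    dtype: str
    device: str
    values: tuple[float, ...]
    grouped_flag: int  # scalar long flag such as grouped sender-K/V


def combine_matched_spans(spans: list[SpanOutputSketch]) -> tuple[tuple[float, ...], int]:
    if not spans:
        raise ValueError("the capability requires at least one compiler-matched span")
    first = spans[0]
    for span in spans[1:]:
        if (span.shape, span.dtype, span.device) != (first.shape, first.dtype, first.device):
            raise ValueError("compiler-routed span outputs disagree on shape/dtype/device")
    # Assumed additive combine for same-role outputs; flags use max (logical OR).
    combined = tuple(sum(vals) for vals in zip(*(s.values for s in spans)))
    flag = max(s.grouped_flag for s in spans)
    return combined, flag
```
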
Verification for the + reverse span-merge and forward span-collector cuts: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_forward_message_readout_handlers_use_native_strategy_access_schema tests/test_fabric_backend_boundaries.py::test_reverse_message_readout_helpers_use_native_strategy_access_schema --tb=short`, + and + `uv run pytest -q tests/test_fabric_backend_plan.py::test_reverse_output_routes_are_compiler_owned_rows tests/test_fabric_backend_plan.py::test_temporal_reverse_program_access_rows_are_compiler_owned tests/test_fabric_backend_plan.py::test_temporal_forward_program_access_and_state_carry_rows_are_compiler_owned --tb=short` + passed. CUDA smoke + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_span_merge_smoke uv run pytest -q tests/test_fabric_backend_plan.py::test_stateful_transition_primitive_dag_trains_with_resets_through_registered_temporal_program --tb=short` + passed. + - Post-counter forward surface span-collector cut DONE: the registered + fused forward program no longer owns `require_unique_*` helper functions + for message/readout handler lookup. Forward message/readout binding now + collects all compiler-selected handler spans through the same collector + row scan used by reverse. The current fused forward artifact table still + stores one logical input-K/V, recurrent-K/V, recurrent-message, + output-message, and output-cells artifact stream, so broad multi-span + forward programs fail closed at the artifact-merge boundary with + `requires compiler artifact merge rows for multiple ... spans` until + those route rows drive per-span artifact streams at execution time. This + deletes the old singleton helper API while keeping the remaining + artifact-route execution gap explicit. + - Post-counter compiler route-row product cut DONE: added explicit + compiler products for the remaining span-ownership gaps. Forward reverse + artifacts now have `TemporalForwardArtifactRouteRow` / + `forward_artifact_route_rows`, carrying producer surface, executor row, + executor id, bucket ordinal, artifact role, logical route id, required + flag, and schema version. Reverse parameter reducers now have + `TemporalReverseParameterReducerRouteRow` / + `reverse_parameter_reducer_route_rows`, carrying reducer route kind, + target role, source span-output group/role, source surface, executor row, + executor id, bucket ordinal, required flag, and schema version. The + registered executor program, runtime metadata, and fused CUDA launch + contract expose both row tables, and the forward C++ program boundary now + validates `forward_artifact_route_rows` before launching. The reverse C++ + program boundary now receives and validates `reverse_program_stage_rows` + instead of treating them as Python-only metadata; full stage-row-driven + execution remains the next cut. 
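
For orientation, a minimal sketch of the two row shapes; attribute spellings are assumptions, the carried facts are the ones listed above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ForwardArtifactRouteRowSketch:
    producer_surface: str
    executor_row: int
    executor_id: str
    bucket_ordinal: int
    artifact_role: str
    logical_route_id: int
    required: bool
    schema_version: int


@dataclass(frozen=True)
class ReverseParameterReducerRouteRowSketch:
    reducer_route_kind: str
    target_role: str
    source_output_group: str
    source_output_role: str
    source_surface: str
    executor_row: int
    executor_id: str
    bucket_ordinal: int
    required: bool
    schema_version: int


def require_forward_routes(
    rows: list[ForwardArtifactRouteRowSketch], required_roles: set[str]
) -> None:
    """Boundary-style check: every required artifact role has a producer route."""
    covered = {r.artifact_role for r in rows if r.required}
    missing = required_roles - covered
    if missing:
        raise ValueError(f"missing required forward artifact routes: {sorted(missing)}")
```
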
Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + and + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_primitive_executor_plan_fails_closed_for_missing_generic_dispatch tests/test_fabric_backend_plan.py::test_forward_artifact_routes_are_compiler_owned_rows tests/test_fabric_backend_plan.py::test_reverse_parameter_reducer_routes_are_compiler_owned_rows tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + and + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_reverse_program_access_rows_are_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed. CUDA smoke + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_route_rows_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_stateful_transition_primitive_dag_trains_with_resets_through_registered_temporal_program --tb=short` + passed. Remaining hard blocker: make C++ reverse surface helpers return + per-span output groups keyed by these reducer route rows instead of + combining same-role outputs before Python reducer request construction, + then delete the active + `reverse_surface_parameter_reducers_require_per_span_output_groups` guard + and the forward multi-span artifact-route guards. + - Post-counter artifact/reducer route-row execution cut DONE: forward + reverse-artifact binding rows now carry the compiler forward artifact + route row as their fifth column + `[role_id, tensor_index, local_step, flags, forward_artifact_route_row]`. + The fused forward C++ program resolves every stored runtime-policy, + message, transition, and readout artifact through + `forward_artifact_route_rows` before writing the reverse artifact tensor + table, and both the Python and C++ reverse artifact validators require + the route-row column. Reverse parameter-gradient reducer construction now + resolves readout/query/sender-KV/message-strategy parameter outputs + through `reverse_parameter_reducer_route_rows` filtered by surface, + executor row, executor id, and bucket ordinal; ordinary value/carry + gradients still use the base reverse output route table. This removes + another active helper assumption from the registered training path. The + remaining hard blocker is narrower: C++ reverse span output groups are + still merged before Python sees them, so multi-message/readout reducer + programs continue to fail closed at + `reverse_surface_parameter_reducers_require_per_span_output_groups` until + C++ returns per-span output groups or equivalent route-addressed slots. 
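
A minimal sketch of the row layout and the check it implies; the helper and the example values are assumptions, the column order is the one stated above.

```python
ARTIFACT_BINDING_COLUMNS = (
    "role_id",
    "tensor_index",
    "local_step",
    "flags",
    "forward_artifact_route_row",
)


def validate_artifact_binding_rows(rows: list[list[int]], route_row_count: int) -> None:
    """Every stored artifact must name a real compiler forward artifact route row."""
    route_column = ARTIFACT_BINDING_COLUMNS.index("forward_artifact_route_row")
    for i, row in enumerate(rows):
        if len(row) != len(ARTIFACT_BINDING_COLUMNS):
            raise ValueError(
                f"binding row {i} needs {len(ARTIFACT_BINDING_COLUMNS)} columns, got {len(row)}"
            )
        route_row = row[route_column]
        if not 0 <= route_row < route_row_count:
            raise ValueError(f"binding row {i} points at unknown route row {route_row}")


# Hypothetical example: one stored artifact bound to route row 3 of a 4-row table.
validate_artifact_binding_rows([[7, 0, 2, 0, 3]], route_row_count=4)
```
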
+ Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py`, + and + `uv run pytest -q tests/test_fabric_backend_plan.py::test_forward_artifact_routes_are_compiler_owned_rows tests/test_fabric_backend_plan.py::test_reverse_parameter_reducer_routes_are_compiler_owned_rows tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed. CUDA smoke + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_route_artifact_rows_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_stateful_transition_primitive_dag_trains_with_resets_through_registered_temporal_program --tb=short` + passed. + - Post-counter reverse per-span reducer group cut DONE: the registered + fused reverse C++ program now returns the original merged front/boundary + groups for carry/value propagation plus additional compiler-ordered + per-readout front groups, per-message front groups, and per-message + boundary groups. `registered_executors.py` consumes those groups by + executor row and uses `reverse_parameter_reducer_route_rows` to build + readout, output-query, recurrent-query, sender-K/V, and message-strategy + reducer requests per selected executor span. The active + `reverse_surface_parameter_reducers_require_per_span_output_groups` + rejection was deleted; multi-span reverse reducer ownership no longer + depends on a merged same-role group. Remaining hard blocker before this + owner is fully closed: forward multi-span artifact streams still fail + closed at the artifact-route boundary until the forward artifact table + stores per-producer streams instead of one logical message/readout + stream. Verification so far: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_reverse_parameter_reducer_routes_are_compiler_owned_rows tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short`, + and + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_plan.py::test_reverse_parameter_reducer_routes_are_compiler_owned_rows --tb=short` + passed. 
CUDA smoke + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_reverse_per_span_groups_v2 uv run pytest -q tests/test_fabric_backend_plan.py::test_stateful_transition_primitive_dag_trains_with_resets_through_registered_temporal_program --tb=short` + passed. + - Post-counter routed reverse artifact consumption cut DONE: the registered + fused backward program now receives `forward_artifact_route_rows` through + the active Python/C++ launch ABI and uses compiler producer routes to + fetch span-owned reverse artifacts. Readout output-message artifacts and + message input/recurrent K/V, recurrent-message, and recurrent-hidden + artifacts are now resolved from the reverse artifact tensor table by + producer surface, bucket, artifact role, and local step instead of by + role/local-step position alone. Producer route lookup deliberately fails + closed when multiple forward producers for the same surface/bucket/role + exist; that remaining gap is now explicitly `compiler artifact merge + rows`, not missing route-row plumbing. Runtime-policy artifacts such as + `boundary_step` and `cells_prev` remain global role-owned artifacts. + Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py tests/test_fabric_backend_boundaries.py`, + and + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings --tb=short` + passed. CUDA smoke with a fresh extension directory + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_route_artifact_access_v2 uv run pytest -q tests/test_fabric_backend_plan.py::test_stateful_transition_primitive_dag_trains_with_resets_through_registered_temporal_program --tb=short`; + passed. + - Post-counter compiler artifact merge-row execution cut DONE: added + `TemporalForwardArtifactMergeRow` / `forward_artifact_merge_rows` as a + compiler product with producer surface, executor row/id, bucket ordinal, + artifact role, merge kind, output route id, required flag, and schema + version. The active registered forward and backward fused CUDA launch + ABI now carries the merge rows through Python, C++ binding, forward + program, backward program, and runtime metadata. Forward message/readout + selection no longer has direct `message_executors.size() == 1` or + `readout_executors.size() == 1` guards; it resolves the selected producer + through compiler merge rows. Reverse span-owned artifact consumption now + resolves through the same merge rows rather than role-only lookup. + `identity_singleton` is executable now; `concat_or_error` and + `sum_or_error` remain explicit legality/error merge kinds until their + output semantics are intentionally accepted. Same cut closed a reducer + follow-up exposed by CUDA smoke: transition parameter reducers now scatter + bucket-local materialized and static-source gradients through + compiler-owned recurrent-row source maps, not old full-shape alignment + assumptions or graph-cell IDs. 
Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler src/cortical/fabric/backend/cuda/sequence_surface/temporal tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py`, + `uv run python scripts/validate_fabric_generated_catalogs.py`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed (`165` tests), and + `CUDA_LAUNCH_BLOCKING=1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + plus the same targeted CUDA test without launch blocking both passed. + - Post-counter forward span cardinality legality cut DONE: the fused + forward span validator no longer requires exactly one message-carrier + handler and exactly one readout handler. It requires at least one + compiler-selected producer for each surface, then leaves multi-producer + legality to the compiler artifact merge rows. This cut made aggregate + artifact merges a planner-visible blocker before the follow-up executable + merge cut below replaced the blocker with shape-checked merge execution. + - Post-counter forward output ownership route cut DONE: added + `TemporalForwardOutputRouteRow` / `forward_output_route_rows` as an + executable compiler product for output sequence ownership. The active + registered fused forward/backward launch ABI now carries output route + rows through the executor program, runtime metadata, Python CUDA wrapper, + C++ binding, forward program, and backward program validation. Forward + output materialization now selects its readout executor through the + compiler output route row, not through a readout artifact route or hidden + singleton readout assumption. Multiple output owner routes now fail + closed in `build_temporal_fused_cuda_program_plan` with + `FORWARD_OUTPUT_ROUTE_UNSUPPORTED` before launch; the current executable + policy is intentionally one logical output owner until concat/reduce/select + route semantics are defined. Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler src/cortical/fabric/backend/cuda/sequence_surface/temporal tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_forward_output_routes_are_compiler_owned_rows tests/test_fabric_backend_plan.py::test_forward_multi_output_routes_block_before_cuda_launch tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed (`168` tests), and CUDA smoke + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed. 
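
A minimal sketch of the single-owner policy from the forward output ownership route cut above; the class and helper names are assumptions, the fail-closed reason string is the one named above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ForwardOutputRouteRowSketch:
    readout_executor_row: int
    route_kind: str  # one logical output owner is the only executable policy here


def select_output_owner(rows: list[ForwardOutputRouteRowSketch]) -> ForwardOutputRouteRowSketch:
    if not rows:
        raise ValueError("FORWARD_OUTPUT_ROUTE_UNSUPPORTED: no output route row")
    if len(rows) > 1:
        # Multiple owners fail closed before launch until concat/reduce/select
        # route semantics are defined (see the later output-route execution cut).
        raise ValueError("FORWARD_OUTPUT_ROUTE_UNSUPPORTED: multiple output owners")
    return rows[0]
```
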
+ - Post-counter reverse artifact consumer route cut DONE: added + `TemporalReverseArtifactConsumerRouteRow` / + `reverse_artifact_consumer_route_rows` as the missing compiler mapping + from reverse executor row/id to the forward artifact producer route row. + This fixes the remaining reverse-helper genericity breach where routed + reverse artifact access ignored the reverse executor identity and + resolved through merged surface/bucket/role lookup. The fused backward + launch ABI now carries the consumer route rows through the executor + program, runtime metadata, Python CUDA wrapper, C++ binding, backward + program validation, and the routed reverse artifact helper. Reverse + readout/message spans now consume artifacts by compiler consumer route, + so forward and reverse executor ids are no longer assumed to be equal. + Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler src/cortical/fabric/backend/cuda/sequence_surface/temporal tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_reverse_artifact_consumer_routes_map_reverse_spans_to_forward_artifacts tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py --tb=short` + passed (`169` tests), `uv run python scripts/validate_fabric_generated_catalogs.py` + passed, and CUDA smoke + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells --tb=short` + passed. + - Post-counter forward producer-route execution cut DONE: the registered + fused forward program no longer selects one global message executor + before the scan. It now executes every compiler-selected forward message + producer for each physical step, stores input K/V, recurrent K/V + before/after, recurrent hidden before/after, and recurrent message + artifacts under the producer's `forward_artifact_route_rows`, then + resolves the logical tensors consumed by transition and readout through + `forward_artifact_merge_rows`. The program body no longer treats "the + message executor" as a temporal-engine singleton. + Same cut removed the obsolete surface+role artifact merge helper and + changed the reverse output-message helper to select its message span from + `reverse_artifact_consumer_route_rows`; reverse output-message no longer + discovers message artifacts through a broad surface/role merged lookup. 
+ Verification: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_plan.py::test_forward_artifact_aggregate_merges_block_before_cuda_launch --tb=short`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned tests/test_fabric_backend_boundaries.py::test_reverse_message_readout_helpers_use_native_strategy_access_schema tests/test_fabric_backend_plan.py::test_reverse_artifact_consumer_routes_map_reverse_spans_to_forward_artifacts --tb=short`, + `uv run ruff check tests/test_fabric_backend_boundaries.py`, + targeted `git diff --check`, and CUDA smoke + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_forward_route_multi_producer_v2 uv run pytest -q tests/test_fabric_backend_plan.py::test_stateful_transition_primitive_dag_trains_with_resets_through_registered_temporal_program --tb=short` + passed. + - Post-counter executable artifact merge cut DONE: non-identity + `forward_artifact_merge_rows` are no longer rejected by + `build_temporal_fused_cuda_program_plan`. The fused forward program now + resolves merge rows directly and supports `identity_singleton`, + shape-checked `concat_or_error`, and shape-checked `sum_or_error` for + message artifacts consumed by transition/readout. The old singleton-only + `forward_artifact_merged_route_row_for_surface_bucket_role` helper was + deleted, and the remaining single-output readout helper was renamed to + `single_executable_forward_output_readout_route` so the output policy is + explicit. Remaining output work: define multi-output route semantics + (`select`/`concat`/`reduce`) or keep them as compiler legality failures + before launch. Verification: + `uv run pytest -q tests/test_fabric_backend_plan.py::test_forward_artifact_aggregate_merges_are_executable_compiler_rows tests/test_fabric_backend_plan.py::test_forward_multi_output_routes_block_before_cuda_launch tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + targeted `git diff --check`, and CUDA smoke with a fresh extension + directory + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_artifact_merge_exec_v2 uv run pytest -q tests/test_fabric_backend_plan.py::test_stateful_transition_primitive_dag_trains_with_resets_through_registered_temporal_program --tb=short`. + passed. + - Post-counter forward output route execution cut DONE: forward output + route rows now carry executable route semantics for base/select, + concat, and sum output ownership. `build_temporal_fused_cuda_program_plan` + no longer rejects multiple output route rows when the compiler rows + explicitly select concat or sum semantics, and the fused forward C++ + program no longer calls the singleton readout-output helper. It resolves + every `forward_output_route_rows` producer to its compiler-selected + readout executor, executes the routed readout outputs, and materializes + `output_seq` through the route semantics. Ambiguous multi-output rows + still fail closed before launch with + `registered_fused_program_requires_explicit_multi_output_route_merge_kind`. 
+ Same cut made reverse transition recurrent-message output ownership an + explicit keyed compiler contract: `transition_recurrent_msg_output_rows` + are now `[group_index, output_slot, bucket_start, bucket_stop]`, and the + fused reverse C++ program validates one row for each transition group by + group id instead of assuming row order equals group order. Remaining + output-route follow-up before broad multi-output training support: + route-aware output-cell artifact merging/splitting for pooled/multi + readout backward, so `grad_output_window` is split by the same output + route rows rather than by the current single logical output artifact. + Verification: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + and + `uv run pytest -q tests/test_fabric_backend_plan.py::test_forward_multi_output_concat_routes_are_compiler_owned_rows tests/test_fabric_backend_plan.py::test_forward_multi_output_routes_require_explicit_merge_semantics tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed. CUDA smoke: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_output_routes_v2 uv run pytest -q tests/test_fabric_backend_plan.py::test_stateful_transition_primitive_dag_trains_with_resets_through_registered_temporal_program --tb=short` + passed. + - Verification for the executable memory runtime-policy cut: + `python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/types.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/types.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes tests/test_fabric_backend_plan.py::test_temporal_backward_validates_memory_artifact_plan_fingerprint --tb=short`, + and `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_memory_plan_drives_checkpoint_and_backward_windows --tb=short` + passed. 
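+
+The three executable merge kinds above have simple shape-checked semantics. A
+minimal sketch, assuming torch tensors and treating the merge as host-side
+Python for readability; the live implementation is the fused C++ forward
+program, and `merge_artifacts` is an illustrative name, not an existing helper.
+
+```python
+# Shape-checked merge semantics for message artifacts feeding transition/readout.
+import torch
+
+
+def _shape_without(shape: torch.Size, dim: int) -> tuple[int, ...]:
+    dim = dim % len(shape)
+    return tuple(s for i, s in enumerate(shape) if i != dim)
+
+
+def merge_artifacts(tensors: list[torch.Tensor], merge_kind: str, dim: int = -1) -> torch.Tensor:
+    if merge_kind == "identity_singleton":
+        if len(tensors) != 1:
+            raise ValueError("identity_singleton requires exactly one producer")
+        return tensors[0]
+    if merge_kind == "concat_or_error":
+        ref = _shape_without(tensors[0].shape, dim)
+        if any(_shape_without(t.shape, dim) != ref for t in tensors[1:]):
+            raise ValueError("concat_or_error: non-concat dims must match")
+        return torch.cat(tensors, dim=dim)
+    if merge_kind == "sum_or_error":
+        ref = tuple(tensors[0].shape)
+        if any(tuple(t.shape) != ref for t in tensors[1:]):
+            raise ValueError("sum_or_error: all producer shapes must match")
+        return torch.stack(tensors, dim=0).sum(dim=0)
+    raise ValueError(f"unsupported merge kind: {merge_kind}")
+```
+
+Keeping concat and sum shape-checked preserves the fail-closed posture: a
+mismatched producer raises instead of silently broadcasting into a wrong merge.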
+ - Current pre-throughput compiler closure checklist after the output-route + execution review: + - [x] Close the output-route ABI by validating + `forward_output_route_rows[:, 9]` as compiler-owned `output_offset`, + enforcing exact concat offsets in fused forward validation, and + requiring zero offsets for non-concat route semantics. + - [x] Make Python registered reverse tensor-store consumption + route-aware: `boundary_step` and `cells_prev` remain global + role-owned artifacts; `output_msg` and `output_cells` must resolve + through compiler producer routes. Pooled output backward must + reconstruct logical output cells through `forward_output_route_rows` + instead of role-only artifact lookup. + - [x] Make C++ registered readout backward split aggregate + `grad_output_window` by compiler output route rows: concat slices by + `output_offset`, sum shares the aggregate output grad, and singleton or + select routes stay one-owner. Input/recurrent carry gradients must not + be duplicated across multiple readout spans. + - [x] Delete the remaining role-only output assumptions from the active + reverse path, especially C++ full-step `output_cells` role-only access + and Python role-only `output_msg`/`output_cells` preflight. Add focused + source guardrails so these cannot return. + - [x] Run CUDA extension smoke for the route-aware output backward cut. + Verification: + `uv run python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py tests/test_fabric_backend_plan.py tests/test_fabric_backend_boundaries.py`, + and + `uv run pytest -q tests/test_fabric_backend_plan.py::test_forward_output_routes_are_compiler_owned_rows tests/test_fabric_backend_plan.py::test_forward_multi_output_concat_routes_are_compiler_owned_rows tests/test_fabric_backend_plan.py::test_forward_output_routes_reject_non_concat_offsets_before_launch tests/test_fabric_backend_boundaries.py::test_fused_cuda_launch_contract_is_compiler_owned --tb=short` + passed. CUDA smoke: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_output_route_backward_v2 uv run pytest -q tests/test_fabric_backend_plan.py::test_stateful_transition_primitive_dag_trains_with_resets_through_registered_temporal_program --tb=short` + passed after fixing the full-step reverse ABI to pass + `forward_output_route_rows`. + - [x] Finish the public Blueprint/Config lowering cleanup: remove + `_blueprint_to_config()` as the normalization choke point, build the + normalized declaration/spec directly from `Blueprint`, and keep + `Config` only as an internal section container if it remains useful. + `src/cortical/fabric/blueprint.py` now lowers into an explicit + `_BlueprintLowering` product before creating the internal runtime section + container; `normalize()` no longer calls a Blueprint-to-config + translator or invokes message-rule lowering from the raw source object. 
+ - [x] Run the final bounded support-surface gate: supported CUDA + declarations must compile through registered primitive rows and + unsupported declarations must fail closed before launch. No hidden + compatibility, fallback, facade, or scheduler-owned primitive math is + allowed. + Verification: + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_blueprint_normalization_is_not_old_config_translation tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_temporal_executor_bindings_are_compiler_products_not_compatibility_slots tests/test_fabric_backend_boundaries.py::test_forward_scan_fails_closed_on_unsupported_primitive_programs tests/test_fabric_backend_boundaries.py::test_shared_temporal_forward_scan_has_no_python_step_loop_fallback tests/test_fabric_backend_plan.py::test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names tests/test_fabric_backend_plan.py::test_temporal_primitive_executor_plan_fails_closed_for_missing_generic_dispatch tests/test_fabric_backend_plan.py::test_message_rule_backend_specs_are_registered_like_cell_specs tests/test_fabric_backend_plan.py::test_message_executor_patterns_follow_registered_message_specs tests/test_fabric_backend_plan.py::test_readout_executor_patterns_follow_registered_readout_specs tests/test_fabric_backend_plan.py::test_message_rule_compiler_rejects_unsupported_rule_before_runtime_execution tests/test_fabric_backend_plan.py::test_transition_program_compiler_rejects_unsupported_transition_op tests/test_fabric_backend_plan.py::test_backend_ir_uses_declared_spec_message_rule_not_default_substitution tests/test_fabric_backend_plan.py::test_fabric_supported_backend_surface_matrix_exposes_no_fallback_contract --tb=short` + passed: 15 tests. + - [x] Deep-dive active-path audit: old fixed scan/reverse ABI markers are + not present in live `sequence_surface` runtime source. A source audit for + `try_flat_bucket_temporal_scan_cuda`, + `try_transition_message_reverse_table_window_cuda`, + `build_temporal_forward_compatibility_launch_plan`, + `build_temporal_backward_compatibility_launch_plan`, + `fixed_composite_abi`, `compatibility_launch_plan`, and the old + singleton message/readout executor guards returns no live source matches + under `src/cortical/fabric/backend/cuda/sequence_surface`; remaining hits + are static guardrail tests. + - [x] Deep-dive `missing_executor` classification: the remaining + `missing_executor` strings in + `compiler/primitive_dispatch.py` are typed fail-closed contract metadata + for unsupported or uncovered primitive rows. They are allowed only when a + declaration is rejected before launch; they are not accepted as a + supported CUDA execution path. + - [x] Reconcile the stale historical five-pass status table: the earlier + PARTIAL/PENDING pass checklist is now explicitly labeled as the + historical fixed-ABI review target. The current checklist here is the + current signoff source of truth. + - [x] Run a broader pre-throughput validation sweep beyond the targeted + compiler/source/CUDA smoke gates above. Minimum expected sweep before + throughput work is a current Fabric-focused pytest slice that includes + backend boundaries, backend plan, execution imports, runtime, public API, + anatomy, and visualizer tests, plus `scripts/validate_fabric_generated_catalogs.py`. 
+ Any failure must be classified as compiler-closure regression, + unrelated dirty-tree regression, or accepted unsupported-surface gap + before throughput optimization starts. + Historical status, 2026-05-02: not closed at that time. + `uv run python scripts/validate_fabric_generated_catalogs.py` + passed, but the broad pytest command + `uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py tests/test_fabric_runtime.py tests/test_fabric_public_api.py tests/test_fabric_anatomy.py tests/test_fabric_visualizer.py --tb=short` + failed with `172 failed, 334 passed`. One compiler allocation-audit + miss was fixed immediately by classifying the route-split readout + backward workspace, and the focused allocation audit now passes. The + remaining compiler-relevant failures are registered transition parameter + reducer closure: small-hidden T>1 sLSTM reaches backward but has + parameter-gradient mismatches above tolerance, and small-hidden T>1 Axon + is missing `input_proj_weight_base` / `input_proj_weight_delta` parameter + gradients. Many other failures come from broad runtime tests that still + exercise disabled direct/non-registered surfaces and must be updated, + deleted, or moved behind explicit unsupported-surface assertions before + final signoff. + Follow-up, 2026-05-03: fixed the compiler-relevant registered + parameter-gradient blockers from this sweep. Axon transition + `input_proj_weight_base` / `input_proj_weight_delta` were missing because + the transition binding plan selected the old `message_to_cell_weight` + static source when the compiler had already produced + `value_to_cell_weight`; the binding now keeps the compiler-owned + value-to-cell source. Axon `out_proj_bias_base` also needed the + materialized-base reducer to align directly to its target parameter + instead of scattering through recurrent-row indices. + sLSTM T>1 parameter gradients were a forward-artifact liveness/reset + bug, not a reducer tolerance issue: the fused forward program stored + reverse artifacts with `tensor.contiguous()`, which aliases contiguous + program/runtime buffers, and it captured transition state-before + artifacts before applying transition reset zeroing. Forward reverse + artifacts are now snapshots, and reset zeroing happens before + state-before artifact capture. Focused verification: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_closure_reset_snapshot_v1 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_parameter_grads --tb=short` + passed (`4` tests: sLSTM/Axon, reset/no-reset). 
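+
+The aliasing half of that fix is a general PyTorch pitfall, independent of
+Fabric internals: `Tensor.contiguous()` returns the same storage when the
+tensor is already contiguous, so an artifact "saved" that way is rewritten by
+any later in-place update, while `clone()` takes a real snapshot. A minimal
+standalone repro, not Fabric code:
+
+```python
+import torch
+
+state = torch.zeros(4)
+saved_alias = state.contiguous()   # already contiguous -> same storage as `state`
+saved_snapshot = state.clone()     # independent copy
+
+state.add_(1.0)                    # in-place update, e.g. a reset or the next step
+
+print(saved_alias)                 # tensor([1., 1., 1., 1.])  -> corrupted "artifact"
+print(saved_snapshot)              # tensor([0., 0., 0., 0.])  -> true snapshot
+```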
Additional targeted + checks passed: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_closure_reset_snapshot_v1 uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_executor_binding_plan_groups_compiled_bindings_by_executor_row tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_dot_product_selects_registered_strategy tests/test_fabric_backend_plan.py::test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program tests/test_fabric_backend_plan.py::test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program --tb=short`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py::test_temporal_forward_registered_executor_owns_scan_bindings tests/test_fabric_backend_boundaries.py::test_temporal_backward_registered_executor_owns_reverse_bindings tests/test_fabric_backend_boundaries.py::test_parameter_reducer_native_callables_are_registry_owned --tb=short`, + `uv run python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cell_specs.py tests/test_fabric_backend_plan.py`, + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py src/cortical/fabric/backend/cell_specs.py tests/test_fabric_backend_plan.py`, + and targeted `git diff --check`. + Closure-risk note from the same audit: active message strategy files still + contain `fixed_slot_context_*` strategy names. Those are registered + declared message strategies now, not the deleted fixed scan/reverse ABI, + but the naming is easy to misread and should stay visible for the + post-compiler dot-product stress/cleanup stage. + Follow-up, 2026-05-03: fixed the remaining compiler-relevant terminal + K>1 mixed-pop reverse blocker. The registered tensor-store reverse path + was only materializing a grad-output window for backward windows with an + active emitted output. Terminal output plans emit only at the final + physical microstep, so earlier reverse windows received `None` and the + compiler-owned fused reverse program correctly failed closed with + `registered_reverse_program_window_reject:missing_or_mismatched_grad_output_window`. + The reverse tensor-store path now always builds the compiler-owned + local-step output-gradient table and fills non-emission steps from routed + zero output artifacts. Focused verification passed: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_terminal_grad_fix_v2 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_terminal_replay_uses_temporal_superop tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_terminal_loss_maps_final_outer_emission_gradient tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_terminal_loss_propagates_provided_state_gradients --tb=short -x` + (`4` tests). The same pass updated stale assertions that still expected + the deleted `registered_temporal_executor_recompute` and old CUDA + backward-glue tags; the live owner is now + `registered_fused_forward_program_tensor_store_direct` with reverse + ownership recorded as + `reverse_owner=registered_fused_reverse_program_tensor_table` and + `flat_bucket_temporal_reverse_scan_owner:registered_reverse_program_window`. + Follow-up, 2026-05-03: broad Fabric-focused closure sweep is now green. 
+ The current full command, without last-failed filtering, + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_closure_broad_full_v1 uv run pytest -q tests/test_fabric_backend_boundaries.py tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py tests/test_fabric_runtime.py tests/test_fabric_public_api.py tests/test_fabric_anatomy.py tests/test_fabric_visualizer.py --tb=short` + passed (`504 passed`, `9 warnings`, `342.77s`). During the sweep cleanup, sparse + patch-edge and nonzero-delay CUDA rows were made explicit fail-closed + surfaces because the registered fused temporal program does not yet prove + sparse/delay message parity through compiler-owned executor rows. This is + intentional closure behavior: unsupported message declarations must fail + before launch rather than run a mismatched CUDA path. The affected tests + now assert the typed unsupported route, while the PyTorch reference still + executes for comparison-shape sanity. Additional checks passed: + `uv run ruff check src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py tests/test_fabric_runtime.py`, + `uv run python -m compileall -q src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py tests/test_fabric_runtime.py`, + `uv run python scripts/validate_fabric_generated_catalogs.py`, and + targeted `git diff --check`. + - [x] Do final tree hygiene before signoff/commit: inspect `git status` and + make sure all new compiler-owned source files and generated catalog files + are tracked, stale scratch docs are intentional, and no commit can include + old monolith deletions without the replacement registered-program source. + Status, 2026-05-02: inspected but not final-signoff clean. `git status` + shows no untracked compiler source directories, but it does show + untracked scratch/planning docs: + `ai_docs/AWS_RECOVERY_TRAIL.md`, `ai_docs/additonal_goals.md`, and + `ai_docs/prompt.tx`. Decide whether these docs are intentional before + commit. + Follow-up, 2026-05-03: final tree-hygiene classification is complete for + this closure cut. `git status --short` still shows no untracked + compiler/runtime source directories or generated catalog files. The only + untracked paths are the recovered planning/context docs listed above; + they are referenced by the tracked REDO docs as historical inputs and are + not active runtime/compiler artifacts. They should be either intentionally + added as planning provenance or omitted before the eventual commit, but + they are not a compiler-closure blocker. + - Result: current Fabric-focused compiler closure sweep is green as of + 2026-05-03. The explicitly tracked post-compiler stress/cleanup stages + have also been reconciled in this document. There is no known remaining + compiler-closure blocker in the registered temporal compiler test sweep + above; the next major work category is throughput/performance closure. + +## Compiler Closure Checkpoint Before Throughput + +Checkpoint date: 2026-05-03. + +This checkpoint freezes the compiler-closure baseline before any throughput +optimization starts. It is intentionally documentation-only for throughput: +no benchmark, profiler, CUDA tuning, or owner-specific optimization was run as +part of this checkpoint. + +Active throughput progress document: +`ai_docs/FABRIC_THROUGHPUT_CLOSURE.md`. 
+ +Baseline state: + +- Fabric-focused compiler closure sweep is green: + `504 passed`, `9 warnings`, `342.77s`. +- No known compiler-closure blocker remains in the registered temporal compiler + closure sweep tracked above. +- `git status --short` at checkpoint shows no dirty `src/`, `tests/`, + generated-catalog, or skill source paths. The dirty/untracked paths are docs: + `ai_docs/REDO2_FIXMASS.md`, `ai_docs/REDO_FIXMAASS.md`, + `ai_docs/AWS_RECOVERY_TRAIL.md`, `ai_docs/additonal_goals.md`, and + `ai_docs/prompt.tx`. +- The untracked docs are planning/provenance context, not active runtime or + compiler artifacts. Decide whether to add or omit them before the eventual + commit. +- April 21 comparison data source for the next phase is + `audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json`. +- April 21 is the throughput and memory target matrix, not code to copy. + Throughput ideas must be expressed as registered compiler strategies over + primitive rows, tensor bindings, memory/liveness rows, executor rows, and + reducer rows. + +Next throughput prompt should start with analysis, not optimization: + +1. build the current-code T=1 owner table against the April 21 rows; +2. identify the dominant registered compiler owners for the gap; +3. only then propose the first throughput strategy change. + +## Throughput Closure Plan - T=1 First + +Decision, 2026-05-03: throughput work starts with T=1. T=1 is the base physical +Fabric execution unit; K, T>1, horizon-H, and per-timestep loss are not allowed +to close while the matched T=1 compiler path is below the April 21 reference. +Use April 21 as the score/memory target and shape matrix, not as code to copy. +Any April 21 implementation idea must be re-expressed as a registered compiler +strategy over primitive rows, tensor bindings, memory/liveness rows, forward and +backward executors, and reducer rows. + +### T=1 Scope + +Do expand T=1 beyond a single 100M smoke row. Do not expand it so broadly that +we delay owner attribution forever. Use a tiered gate: + +1. **T=1 baseline owner table, required first.** + - h=32, single-pop sLSTM and Axon. + - 100M/500M/1B where feasible. + - B=1024 and B=16384 where feasible. + - forward and training/backward. + - terminal loss/output boundary. + - reset absent first; reset-present as soon as the owner touches state, + artifacts, tape, or backward. + - record April 21 target row, current tok/s, current peak GiB, stack/context + baseline if available, route metadata, forward owner, reverse owner, + reducer owner, dominant kernel/host owner, and whether parity was green. + +2. **T=1 stress guardrails before declaring T=1 closed.** + - small-parameter/high-batch rows: 1M/10M with B=16384/65536/131072, because + these catch launch overhead and batch scaling regressions that large-param + rows can hide. + - small-hidden many-cell rows: h=4, h=8, h=16 over 100M/500M/1B where + feasible, because hidden-size shrinkage must not become a route selector or + destroy many-cell scaling. + - mixed-pop sLSTM+Axon T=1 row, because single-pop and mixed-pop are bucket + cardinalities of the same flat-bucket engine. + - reset-present T=1 training row for each touched family once backward, + artifact, memory, or state-carry owners are involved. + - materialized-final-state and no-final-state coverage when the optimized + owner touches final state, carry, artifact, or output materialization. + +3. 
**T=1 graph/factorization sanity before moving to K/T/H.** + - at least one equivalent flat-graph/factorization pair should show that the + backend consumes flat graph facts rather than user-side lattice/factor + labels. + - this is a sanity gate, not the first owner. Run it after the main T=1 owner + moves. + +4. **T=1 closure criteria.** + - parity green for outputs, exposed state, input/carry gradients, and all + nonzero parameter gradients on the affected rows. + - current compiler path meets or exceeds the comparable April 21 Fabric tok/s + target, or any miss is explicitly accepted as a still-open throughput + blocker. + - peak memory does not meaningfully regress versus the comparable April 21 + row, or the regression has a named open memory/liveness owner. + - runtime/audit metadata reports registered compiler-owned forward, + backward, memory/artifact, and reducer owners; no Python replay/fallback, + benchmark tiling, direct wrapper, or hidden fixed-slot route is accepted. + - any optimization is a registered strategy over existing primitive rows. If + the primitive rows do not exist, stop and add compiler semantics/legality + before profiling. + +### April 21 T=1 Rows To Map First + +Use `audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json` +as the initial target map: + +- `h32_t1_bxparams`: sLSTM + Axon, 100M/500M/1B, forward + training, + B=1024/16384. Boundary reference: Axon 100M train B=1024, + `58732.71 tok/s`, `2.07 GiB`. +- `h32_small_params_high_batch`: sLSTM + Axon, 1M/10M, forward + training, + B=16384/65536/131072. Boundary reference: Axon 1M train B=131072, + `1986978.16 tok/s`, `12.97 GiB`. +- `h4_many_cell_stress`: sLSTM + Axon, 100M/500M/1B, forward + training, + B=1024/16384. Boundary reference: Axon 500M train B=1024, + `12513.0 tok/s`, `18.09 GiB`. +- `h8_many_cell_stress_focused_warmed_rerun`: Axon 1B train B=1024, + `16723.8 tok/s`, `46.71 GiB`. +- `h16_many_cell_stress`: sLSTM + Axon, 100M/500M/1B, forward + training, + B=1024/16384. Boundary reference: Axon 500M train B=1024, + `35558.75 tok/s`, `10.63 GiB`. +- `flat_graph_factorization_invariance`: equivalent flat graphs with only + user-side factorization labels varied. Run after the main T=1 owner has moved, + to verify graph-shape genericity. + +### Additional Throughput Concerns To Include + +- **Reproducibility before attribution.** Every throughput row must record GPU, + command, commit/dirty state, private `TORCH_EXTENSIONS_DIR` and + `TRITON_CACHE_DIR`, warmup count, measured repeat count, and whether the row + was cold, warmed, interrupted, or concurrent with another GPU job. A cold-only + result cannot select an owner. +- **April 21 mapping must be exact.** Before comparing tok/s or memory, map the + current row to the closest April 21 key by family, parameter target, actual + params, batch, hidden size, mode, loss boundary, reset policy, graph shape, + and output/materialization contract. If no exact April 21 row exists, label it + as context/guardrail rather than closure target. +- **Single-pop first for attribution, mixed-pop same gate for closure.** + Single-pop T=1 is the diagnostic first target because it removes cross-bucket + layout/scheduling noise. Mixed-pop T=1 must then run the same conceptual + surfaces before T=1 closes. If mixed-pop needs a separate route, T=1 remains + open. 
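+
+To make the row contract concrete before any measurement run, here is one
+possible shape for an owner-table row, assuming the fields named in the
+baseline owner table and reproducibility bullets above. The class and field
+names are note-taking assumptions, not an existing Fabric or benchmark API.
+
+```python
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+
+@dataclass
+class OwnerTableRow:
+    # Row identity and April 21 mapping
+    april21_key: str                  # e.g. "h32_t1_bxparams / axon / 100M / train / B=1024"
+    family: str                       # "slstm" | "axon" | "mixed"
+    hidden: int
+    param_target: str                 # "1M" ... "1B"
+    batch: int
+    mode: str                         # "forward" | "train"
+    reset_present: bool
+    # Measurements
+    current_tok_s: float
+    current_peak_gib: float
+    april21_tok_s: float | None       # None when the row is context/guardrail only
+    april21_peak_gib: float | None
+    stack_baseline_tok_s: float | None
+    # Owner attribution from runtime/audit metadata
+    forward_owner: str
+    reverse_owner: str
+    reducer_owner: str
+    dominant_kernel_or_host_owner: str
+    parity_green: bool
+    # Reproducibility
+    gpu: str
+    command: str
+    commit: str
+    dirty_tree: bool
+    torch_extensions_dir: str
+    warmup_iters: int
+    measured_iters: int
+    run_condition: str                # "cold" | "warmed" | "interrupted" | "concurrent"
+```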
+- **Owner taxonomy must be explicit.** Each miss should be assigned to one live + owner before code changes: forward program, reverse program, message/readout, + recurrent K/V projection, transition primitive, parameter reducer, + artifact/memory liveness, layout/copy, launch count, allocator/workspace, or + benchmark/user-visible tensor cost. Avoid generic labels such as "CUDA slow". +- **Memory is a first-class gate.** Passing tok/s with April 21 memory + regression is not closure. Split memory into user-visible tensors, planner + checkpoints/artifacts, workspace, parameter/reducer temporaries, and accidental + full-bank/full-sequence materialization. +- **Profiler evidence must be warmed and layered.** Use high-level audit timing + first, then owner-timed runtime metadata, then PyTorch/CUDA profiler or + kernel-level breakdown for the dominant owner. Do not start CUDA edits from a + stale profile or a first-run compile artifact. +- **Parity gates follow the touched owner.** Throughput changes need targeted + parity for the affected surfaces before speed counts. Reducer work must check + all parameter-gradient keys and nonzero values; artifact/state work must check + input/carry/state gradients and reset-present rows. +- **Benchmark harness stays user-like.** The measured path must be ordinary + `output = model(x, ...)`, external loss, `loss.backward()`, and optional + optimizer step. No benchmark-side time chunking, detach policy, checkpoint + selection, direct runtime helper calls, or streaming-loss hooks are allowed as + throughput evidence. +- **Stop rules for strategy work.** If an optimization requires adding primitive + formulas to temporal scheduler files, fixed slot enums, cell-family selectors, + direct wrappers, or benchmark-row branches, reject it. Re-express the idea as + a registered compiler strategy or keep the row open. +- **No closure by averages only.** Report boundary/worst rows, not just average + speedup. The lowest passing row for each family/hidden/batch/param bucket is + the one that determines whether the gate closes. +- **Skill and doc hygiene continue.** If throughput work discovers a durable + rule, update the relevant Fabric skill in the same pass. Record accepted and + rejected probes in this document so future iterations do not repeat them. + +### After T=1 + +Only after the T=1 gate is healthy: + +1. close K>1 against matched current-code T=1 divided by K; +2. close T>1 streaming against matched T=1 per-token throughput; +3. close horizon-H/TBPTT with bounded internal memory; +4. close ordinary per-timestep loss through `output = model(x)`, external loss, + and `loss.backward()` with no benchmark-owned streaming-loss helper. diff --git a/ai_docs/REDO_FIXMAASS.md b/ai_docs/REDO_FIXMAASS.md new file mode 100644 index 00000000..679592e7 --- /dev/null +++ b/ai_docs/REDO_FIXMAASS.md @@ -0,0 +1,9937 @@ +# REDO_FIXMAASS + +Running scratchpad, staged redo plan, and audit log for rebuilding the lost +fixmaas/fixmass work from the April 21 codebase plus recovered artifacts. + +This file is the working memory for the redo. Keep appending concrete facts, +decisions, owner changes, audit results, and reopen reasons here before making +large changes elsewhere. Every closed stage must leave this doc current and be +committed. + +## Operating Rules + +- Current local code is treated as the April 21 baseline, except for recovered + artifacts under `ai_docs/`. +- `ai_docs/recovered_core.py` is evidence, not automatically truth. 
Compare it + against live code and port only designs that still match the intended + architecture. +- `ai_docs/AWS_RECOVERY_TRAIL.md` is stale after April 24. It is useful for + intent, stage names, and prior failure modes, but it cannot close redo stages. +- The redo is complete only when all Fabric execution flows use one shared + temporal engine, all parity/performance audits pass, and legacy paths are + deleted. +- Audit historical references are always resolved from the April 21 audit JSON: + `audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json`. + Exact April 21 T=1 rows remain strict throughput/memory regression gates. + T*K/H closure also requires matched current-code T=1 training evidence, + because the recovered April 21/April 26 artifacts do not prove that the live + horizon-H implementation is semantically correct. Recovered April 26 notes + can explain intent or expected ceilings, but they do not replace current-code + matched T1/K scaling evidence. +- Planner ownership is mandatory: runtime, model helpers, benchmarks, and tests + must not make hidden execution-policy decisions. +- Benchmarks and audits are high-level API consumers only. They must use normal + Fabric model calls and PyTorch-style training: + `output = model(x, ...)`, external loss construction, `loss.backward()`, and + optimizer step when applicable. This applies to T=1, T>1, K>1, and T*K/H + audits. +- Benchmarks must never design backend/planner behavior, call private planner or + runtime helpers for closure, split time/batch/workspace to make rows fit, or + provide special streaming-loss/per-chunk-backward/detach hooks as evidence. + Backend/runtime/planner code owns all temporal chunking, rolling tape, + checkpointing, horizon semantics, workspace policy, and physical execution. +- Eight GPUs are available as an operator resource for faster experiments. + Multi-GPU parallel launch helpers are useful, but deterministic GPU sharding + is not an audit closure requirement. Audit correctness must not depend on how + cases are scheduled across devices. +- There must be one Fabric audit entrypoint. Benchmark code should be refactored + into an organized Fabric folder instead of the current flat `benchmarks/` + sprawl. +- If an audit identifies a semantic or performance issue, reopen the stage that + owns the broken decision. Do not patch around owner boundaries in the audit. + +## Baseline Evidence + +- April 21 single-pop audit result: + `audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json` + This is the authoritative score reference for audit gates. +- Recovered original redo prompt: + `ai_docs/prompt.tx` +- Recovered backend evidence: + `ai_docs/recovered_core.py` +- Recovered stale trail: + `ai_docs/AWS_RECOVERY_TRAIL.md` +- Additional cleanup goals: + `ai_docs/additonal_goals.md` + +Important April 21 baseline rows already inspected: + +- `h32_t1_bxparams`: sLSTM and Axon, 100M/500M/1B, forward and training, + B=1024/16384, 24/24 fit. +- `h32_small_params_high_batch`: 1M/10M, B=16384/65536/131072, 24/24 fit. +- `h4_many_cell_stress`: 100M/500M/1B, B=1024/16384, 24/24 fit. +- `h8_many_cell_stress_focused_warmed_rerun`: Axon 1B train B=1024, + Fabric 16,723.8 tok/s, stack 5,631.7 tok/s. +- `h16_many_cell_stress`: boundary Axon 500M train B=1024, + Fabric 35,558.75 tok/s. +- `streaming_per_timestep_sequence_loss`: sLSTM and Axon 500M/1B, + B=512, T=512/4096, h=32, 8/8 fit. +- `factorization_invariance`: spread around 0.31%. 
+- `rollout_curves`: Axon 360M boundary Fabric 140,917.7 tok/s vs + HF 6,048.16 tok/s. + +April 21 streaming sequence loss baselines inspected: + +- sLSTM 500M B512 T512: 103,575.41 tok/s, 14.09 GiB. +- sLSTM 500M B512 T4096: 104,115.17 tok/s, 29.00 GiB. +- sLSTM 1B B512 T512: 100,449.76 tok/s, 28.00 GiB. +- sLSTM 1B B512 T4096: 95,638.46 tok/s, 61.21 GiB. +- Axon 500M B512 T512: 99,768.20 tok/s, 18.25 GiB. +- Axon 500M B512 T4096: 98,607.96 tok/s, 51.88 GiB. +- Axon 1B B512 T512: 101,065.73 tok/s, 27.15 GiB. +- Axon 1B B512 T4096: 99,896.35 tok/s, 60.94 GiB. + +## Recovered Intent + +The target architecture is a shared temporal engine: + +- Single-pop and mixed-pop Fabric must become the same backend problem. +- Backend normalizes all declared populations into flat buckets with flat + bucket identity. +- Flat buckets lower into superops. User population names may be debug metadata + or parameter binding labels, but not semantic bucket identity. +- Forward and backward consume the same temporal plan. +- Internal thinking steps `K` are internal temporal scan steps. Conceptually the + stream length is `T*K`; only emission/materialization differs. +- Future per-timestep `K` must be possible, so design around an emission map or + temporal schedule, not scalar-only repeat expansion. +- Fabric is always streamed. `T=1` is the base case; `T*K > 1` is the same + engine over more internal scan steps. +- Training must support a horizon `H` for TBPTT-style bounded backward. H must + be an explicit planner semantic, not an artifact of host helper chunking. +- Checkpointing, recomputation, materialization, active regions, and output + emission are planner-owned decisions. + +Recovered stale fixmaas stages imply these durable concepts: + +- `TemporalExecutionPlan` +- `SubstratePlan` or active graph substrate +- `ActiveRegionPlan` +- `BoundaryPlan` +- `CarryPlan` +- `EmissionPlan` +- `CheckpointPlan` +- `GradientBoundaryPlan` +- `BackwardWindowPlan` +- `ExecutorPlan` + +## Recovered Prompt Traceability + +`ai_docs/prompt.tx` was deep-read on 2026-04-27. It is a recovered copy of the +original redo request and is now treated as the compact requirement checklist +for this plan. The current plan must continue to cover every item below: + +- Code state: local code starts from the April 21 Fabric backend state, with + single-pop audit results in the April 21 JSON. +- Architecture target: refactor Fabric into one shared temporal engine and + eventually delete all legacy paths. +- Cardinality target: single-pop and multi-pop Fabric are the same backend + problem. Both must flatten through flat bucket identity and lower into + superops. +- Temporal target: forward and backward follow the same temporal design. + Internal thinking steps `K` are internal scan steps, so the stream is `T*K` + and output materialization/emission is the only user-visible difference. +- Future-proofing: per-timestep `K` must remain representable, so scalar-K + repeat expansion cannot become the semantic abstraction. +- Streaming target: Fabric is always streamed. `T=1` is the base case, and + `T*K > 1` must be the same engine over more internal steps. +- Horizon target: `T*K` training uses TBPTT-style horizon `H`, with + planner-owned checkpointing and materialization. Recovered April 26 evidence + says H=64, K=128, T=4096, and sometimes T=16K were practical targets. +- Planner target: all planning decisions go through the planner. 
The planner + may be split into multiple files/modules, and supporting abstractions may be + introduced elsewhere, but execution policy ownership must not move into + runtime, model helpers, benchmarks, or tests. +- Cleanup target: `ai_docs/additonal_goals.md` cleanup items must be included, + and final closure requires unused/legacy paths and logic to be deleted. +- Process target: update this redo doc continuously as scratchpad/log/memory, + update relevant skills when durable rules are discovered, and commit useful + checkpoints. +- Audit target: T=1 single-pop covers every April 21 JSON case with no + throughput or memory regression; T=1 mixed-pop covers the same shapes and + beats same-parameter MoE/stack; T*K/H proves streaming throughput, H behavior, + K extra-work scaling, and per-timestep loss; small `h` and shape + factorization do not reduce throughput. +- Reopen target: audit failures can change the owner and move work back to an + earlier stage. The plan must explicitly allow stage reopen/jump-back. + +Prompt requirement matrix: + +| ID | Source lines | Requirement | Owning stages | Closure evidence | +| --- | --- | --- | --- | --- | +| P0 | 1 | Treat local code as April 21 baseline and recovered files as evidence sources. | R0, R16 | Baseline JSON, `prompt.tx`, `recovered_core.py`, stale trail, and additional goals are listed in this doc; final report links exact audit artifacts. | +| P1 | 3 | Refactor Fabric into one shared temporal engine and delete legacy paths after audits. | R1-R5, R15, R16 | Runtime metadata and repository greps show every Fabric path enters the shared temporal engine; legacy route/helpers are removed. | +| P2 | 3 | Single-pop and multi-pop use the same backend through flat bucket identity lowered to superops. | R2, R3, R8, R12, R14 | Bucket signatures/cache keys no longer include user population names; single/mixed audits report the same temporal engine with population differences only as parameter/state bindings. | +| P3 | 3 | Forward and backward follow the same temporal design. | R1, R3, R4, R10 | Forward and backward consume the same `TemporalExecutionPlan`; parity covers outputs, states, input/carry grads, and parameter grads. | +| P4 | 3 | K internal thinking steps are the same stream as `T*K`; only output materialization/emission changes. | R1, R3, R4, R13 | Plan records internal scan steps and `EmissionPlan`; K>1 uses the same public model call and temporal executor family as K=1. | +| P5 | 3 | Future per-timestep K must remain representable. | R1, R3, R10 | Temporal schedule representation supports non-scalar K shape even if initial audits use scalar K; tests include future-shaped schedule stubs/fail-closed behavior. | +| P6 | 3, 7 | Fabric is always streamed; T=1 is the base working case. | R3, R10, R11 | T=1 and T>1 use the same engine; every April 21 T=1 row passes parity, throughput, and memory gates. | +| P7 | 3, 11 | T*K training uses TBPTT-style horizon H with checkpointing/materialization. | R1, R4, R13 | `GradientBoundaryPlan`, `CheckpointPlan`, and `BackwardWindowPlan` record H; H=64 closure passes through normal `loss.backward()`. | +| P8 | 3 | Planning decisions always go through the planner; planner can be split into focused modules. | R1, R8, R15 | Runtime/model/tests/benchmarks do not select route, tile, tape, checkpoint, horizon, population, or family policy; planner modules own those decisions. | +| P9 | 3 | Include all additional cleanup goals. 
| R6, R7, R8, R15 | All 20 imported cleanup goals have closed owner stages or explicit final-report deferrals. | +| P10 | 3 | Keep docs/skills updated and commit useful checkpoints. | R0, all stages, R16 | Working log records owner decisions and audit results; durable user corrections are encoded in skills; each checkpoint is committed. | +| P11 | 7 | T=1 single-pop covers all April 21 JSON cases with no throughput or memory regression. | R9, R10, R11 | Canonical audit runner emits exact April 21 reference row for every case and all rows pass strict gates. | +| P12 | 9 | T=1 mixed-pop Axon+sLSTM covers the same shapes and beats same-parameter MoE/stack. | R2, R3, R4, R9, R12 | Mixed-pop audit rows use high-level model calls, cite controlling April 21 references, compare to matched current MoE/stack, and pass tok/s gates. | +| P13 | 11 | T*K/H audit proves T scaling, H behavior, K extra-work scaling, and per-timestep loss. | R1, R4, R9, R10, R13 | Audit includes T=1/512/4096 and frontier T=16K where feasible, K up to 128, H up to 64, terminal and per-timestep loss, all through normal autograd. | +| P14 | 13 | Smaller hidden size and graph shape/factorization must not reduce throughput. | R2, R8, R9, R14 | h=4/8/16/32 and factorization rows reference April 21 h-stress/factorization groups and pass spread/throughput gates. | +| P15 | 15 | Redo completes only after strict audits, strict parity, shared engine flow, and legacy removal. | R10-R16 | Final closure report states all parity/performance gates, deleted paths, and shared-engine proof. | +| P16 | 17 | Audit failures can change owner and jump back to an earlier stage. | Stage reopen rules, all audit stages | Every failed audit row records owner stage and reopen reason; fixes happen in owning architecture/planner/runtime stage, not in the harness. | +| P17 | 1, 5 | April 26 had all fixmaas stages closed and additional audit stages introduced; redo must include both architecture rebuild stages and post-refactor audit stages. | R1-R16 | Stages R1-R10 close architecture/parity/tooling; R11-R14 close required extra audits; R16 links the final audit artifacts. | +| P18 | 15, 17 | Semantic correctness and performance goals are co-equal hard gates. | R10-R16 | No performance score counts before strict parity; no semantic refactor closes without April 21 score gates and memory gates. | +| P19 | 17 | Continue stage-to-stage and return to the user only for significant confusion or a true blocker. | All stages | Working log records the next owner/gate at each checkpoint; blockers are explicit, rare, and tied to missing facts that cannot be recovered locally. | + +Prompt source handling rules: + +- `prompt.tx` line 1 says the AWS trail is stale and `recovered_core.py` is a + recovered key backend file. Treat both as design evidence, not current-code + proof or closure evidence. +- The local code remains the April 21 baseline for implementation and the April + 21 audit JSON remains the score source, even when recovered April 26 notes + describe a higher ceiling. +- Before implementing R1-R5, compare any recovered-core idea against the live + planner/runtime/backend path being edited. Do not paste recovered code blindly + or preserve recovered fallback branches as final design. +- A prompt ID closes only from current code, current tests, current audit + artifacts, and the final deletion sweep. Citing a recovered artifact is not + enough. + +For every future code or audit stage, record the affected prompt IDs in the +Working Log entry and audit manifest. 
R16 cannot close until P0-P19 are marked +closed or explicitly deferred with an owner and reason. + +## Live Code Findings So Far + +### Public API and Blueprint + +Live files inspected: + +- `src/cortical/fabric/blueprint.py` +- `src/cortical/fabric/config.py` +- `src/cortical/fabric/anatomy.py` +- `src/cortical/fabric/graphs/lattice2d.py` +- `src/cortical/fabric/message_rules/declarations.py` +- `src/cortical/fabric/backend/message_rules.py` +- `src/cortical/fabric/backend/ir.py` +- `src/cortical/fabric/backend/buckets.py` + +Current gaps: + +- `Blueprint.message_passing` is typed as `DotProduct`, not a generic message + rule declaration. +- `DotProduct.to_ir()` validates a fixed surface and returns a default IR; its + semantic fields are effectively discarded. +- `compile_fabric_ir()` constructs a default dot-product message rule instead + of receiving the Blueprint-provided rule IR. +- `classify_message_rule()` recognizes only the current dot-product graph. +- `Blueprint.normalize()` still bridges through old `Config`. +- `Config` still owns old concepts: `default_k`, `k_max`, `readout_pool`, + `population_mix`, `cell_arrangement`. +- `Config` rejects more than two populations. +- `Blueprint` currently supports only lattice2d graph declarations and forces a + single shared hidden dimension through interfaces. +- Named inputs/outputs are effectively constrained to one external input and + one external output. +- Bucket signatures include user population names, so current bucket identity + is not flat bucket identity yet. + +### CUDA and Message Lowering + +Live files inspected: + +- `src/cortical/fabric/backend/cuda/nn/ir.cuh` +- `src/cortical/fabric/backend/cuda/message_rules/dot_product.cuh` +- `src/cortical/fabric/backend/cuda/execution/registry.py` +- `src/cortical/fabric/backend/pytorch/message_passing.py` + +Current gaps: + +- CUDA has generic-looking message IR structs, but the lowering registry only + supports `dot_product_segment_softmax_weighted_sum`. +- CUDA message lowering exact-matches the current dot-product graph. +- The only CUDA-side message builder is dot-product-specific. +- PyTorch reference message passing is dot-product-specific. +- Message topology and message semantics remain coupled. + +### Planner and Runtime + +Live files inspected: + +- `src/cortical/fabric/backend/planner.py` +- `src/cortical/fabric/runtime/core.py` +- `src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py` +- `src/cortical/fabric/backend/cuda/sequence_surface/temporal_buckets.py` +- `src/cortical/fabric/backend/cuda/sequence_surface/temporal_backward.py` +- `src/cortical/fabric/backend/cuda/sequence_surface/surface.py` +- `src/cortical/fabric/backend/cuda/sequence_surface/flat_buckets.py` + +Current gaps: + +- Planner still exposes `SequenceSurfaceRoute`, `kind`, `executor`, and + `implementation_executor` instead of a complete temporal execution plan. +- `plan_sequence_surface_route()` is shallow and mostly chooses the flat-bucket + sequence surface. It does not own active regions, emissions, checkpointing, + backward windows, or horizon semantics. +- Planner heuristics still use hidden-size thresholds and fixed tiles. +- Runtime `forward_cells()` and `forward_output_cells_for_readout()` make local + route/static/materialization decisions after calling the planner. +- Runtime still has active-output special paths and fresh-state paths that are + not a single shared temporal engine. +- Single-pop and mixed-pop paths are not fully identical. 
Some caches and + windows are gated on population count. +- `K` is currently implemented in temporal execution by repeat-expanding + `boundary_seq` with `repeat_interleave`. This is scalar-K behavior, not a + generic temporal schedule or emission map. +- Runtime rejects per-timestep variable `K` inside a sequence. +- `supported_backend_surfaces()` still exists. +- Model still has old sequence helpers and policy knobs: + `stream_sequence_outputs`, `reduce_sequence_outputs`, + `stream_sequence_mse_loss`, direct-grad sequence path, checkpointed sequence + path, `_sequence_direct_grad_target_bytes`, + `_sequence_checkpoint_target_bytes`, and cell-family sequence memory policy. +- Current CUDA temporal backward is real physical backward, but checkpoint and + recompute policy is still local to sequence-surface code. It is not driven by + an explicit planner horizon `H` or `EmissionPlan`. +- `CudaSequenceSurfaceMixin` still contains old single-pop cell recurrence + surface logic and local active receiver/materialization decisions. + +Additional planner/executor details found in the deeper pass: + +- `FabricExecutionPlanner.plan_sequence_surface_route()` only validates + CUDA/float32/partitioned layout and population support, then returns + `surface_key="flat_bucket_sequence_surface"` with + `implementation_executor="flat_transition_buckets"`. +- The route does not receive output boundary, readout boundary, final-state + materialization, time length, H, emission schedule, active region, checkpoint + steps, or backward-window semantics. +- `FabricExecutionPlanner.plan_execution()` and + `plan_backward_execution()` build bucket plans from current IR bucket + signatures. The signatures include population names through `FabricBucket`. +- `PlanCacheKey` uses those signatures, so population names currently affect + planning/cache identity. +- `_heuristic_plan_bucket()` still chooses microkernel vs grouped-GEMM from + dim signatures and hardcoded thresholds. +- `execute_temporal_bucket_sequence()` implements K>1 training by + `repeat_interleave(boundary_seq, inner_steps, dim=1)` and then slices emitted + outputs with `output_step_indices`. This is a scalar-K expansion, not a + generic schedule/emission plan. +- `execute_temporal_bucket_sequence()` contains separate paths for active-output + window, physical training, step mode, and fallback step loop. +- `supports_temporal_bucket_active_output_window()` currently rejects + `len(active_populations) <= 1`, requires local message support, rejects edge + delay/sparse message, and requires the active output region to be compact and + full. This confirms active-output is not a generic shared-engine plan. +- Active-output execution allocates full recurrent K/V banks and uses + `index_copy_` into active recurrent indices. This is a useful migration + control but should not be a permanent separate execution identity. +- Temporal execution record metadata still names `active_output_window`, + `flat_bucket_temporal_scan`, `stored_temporal_physical_scan`, and + `windowed_temporal_physical_scan`. Redo tests should assert semantic plan + fields rather than these route-era strings. 
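+
+To pin down the scalar-K limitation in shapes, a small sketch: the current
+expansion described above repeats every outer step `inner_steps` times and
+recovers emissions by a fixed stride, while an emission-map view carries the
+same information explicitly and can express per-step K. The schedule names in
+the second half are assumptions, not live executor code.
+
+```python
+import torch
+
+B, T, K = 2, 3, 4                                    # batch, outer steps, internal steps per outer step
+boundary_seq = torch.arange(T).repeat(B, 1)          # [B, T] outer boundary sequence
+
+# Current scalar-K behavior: every outer step expands to K identical internal
+# steps, and emitted outputs are recovered by fixed indices afterwards.
+internal_seq = boundary_seq.repeat_interleave(K, dim=1)     # [B, T*K]
+output_step_indices = torch.arange(K - 1, T * K, K)         # emit at the last internal step
+assert internal_seq.shape == (B, T * K)
+assert output_step_indices.tolist() == [3, 7, 11]
+
+# An emission-map/schedule view keeps the same facts explicit and generalizes
+# to per-step K (hypothetical K_t = [1, 4, 2]) without a scalar expansion.
+k_per_step = torch.tensor([1, 4, 2])
+emission_steps = torch.cumsum(k_per_step, dim=0) - 1        # internal index of each emission
+total_internal_steps = int(k_per_step.sum())
+assert emission_steps.tolist() == [0, 4, 6] and total_internal_steps == 7
+```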
+ +### Flat Buckets and Transition Lowering + +Live files inspected: + +- `src/cortical/fabric/backend/runtime_dispatch.py` +- `src/cortical/fabric/backend/cuda/sequence_surface/flat_buckets.py` +- `src/cortical/fabric/backend/cuda/sequence_surface/temporal_buckets.py` +- `src/cortical/fabric/backend/cuda/transition_execution.py` + +Current gaps: + +- Runtime dispatch still exposes flat-bucket helpers such as + `_run_transition_bucket_step`, `_run_transition_buckets_step`, + `_run_backend_order_transition_buckets_step`, and + `_run_active_window_transition_buckets_step`. These are useful pieces, but the + shared temporal engine should own their orchestration. +- `TemporalPopulationBucket` is keyed by `name`, `backend_start`, + `backend_stop`, and recurrent indices. This is population-bucketed, not + flat bucket identity/superop lowering yet. +- `temporal_bucket_plan()` builds one bucket per active population in backend + order and caches that in `static_tensors`. +- `temporal_backward_owner_plan()` still reports owner families in terms of + message, receiver_affine, state_epilogue, diagonal_recurrence, readout, and + glue/layout. This metadata is useful, but ownership should be generated from + the shared temporal/backward plan. +- `transition_execution.lower_backend_population_transition_forward_result_shared()` + dispatches by recognized transition IR patterns: + gated-logspace sLSTM-like recurrence and diagonal Axon-like recurrence. +- Transition lowering still needs a `population_name` to resolve backend specs + and materialized parameters. In the redo, population labels can remain + parameter/state binding labels, but bucket/superop identity must be semantic. + +### Temporal Backward + +Live files inspected: + +- `src/cortical/fabric/backend/cuda/sequence_surface/temporal_backward.py` +- `src/cortical/fabric/backend/cuda/sequence_surface/policy.py` + +Current gaps: + +- `_TemporalBucketSequenceFunction` is the current physical temporal autograd + path. It streams forward under `torch.no_grad()`, stores or recomputes + artifacts, and invokes `TemporalPhysicalBackwardScanExecutor` in backward. +- This is a real backend-owned physical backward, but it is not yet the target + shared temporal backward because it has no explicit `EmissionPlan`, + `GradientBoundaryPlan`, or `BackwardWindowPlan`. +- `temporal_transition_tape_policy()` picks transition tape mode from memory, + transition kinds, `time_steps`, and `tape_policy_bin`; it is not planner H + semantics. +- `_temporal_artifact_store_policy()` chooses store vs recompute artifacts, + checkpoint stride, and recompute window length from memory heuristics after + seeing the first artifact. This must move into planner-owned checkpoint/window + planning. +- `_temporal_artifact_windows()` slices windows from checkpoint stride and + window length, not from loss emissions or rolling horizon H. +- Backward receives normal `grad_output` from PyTorch autograd, which is the + right public training boundary. The redo should preserve this while replacing + the internal window/checkpoint policy with planner-owned temporal semantics. +- Current output emission is every physical internal step in the physical path, + with K outputs later selected by index. The redo needs emission-aware forward + and backward so non-emission internal steps do not look like user outputs. 
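To make the missing plan objects concrete, here is a minimal dataclass sketch of what planner-owned `EmissionPlan`, `GradientBoundaryPlan`, `CheckpointPlan`, and `BackwardWindowPlan` records could look like. The record names come from this document; every field name is an assumption for illustration, and the real records must be designed in Stages R1/R4, not copied from this sketch.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EmissionPlan:
    # Map from user-visible output timestep to internal scan step index.
    emission_steps: tuple[int, ...]


@dataclass(frozen=True)
class GradientBoundaryPlan:
    # "full_horizon" or "rolling_horizon"; horizon_steps carries H when rolling.
    mode: str
    horizon_steps: int | None = None


@dataclass(frozen=True)
class CheckpointPlan:
    # Internal scan steps where compact carry checkpoints are stored.
    checkpoint_steps: tuple[int, ...]


@dataclass(frozen=True)
class BackwardWindowPlan:
    # Reverse-order (start_step, stop_step) segments replayed during backward,
    # each reloaded or recomputed from its nearest checkpoint.
    windows: tuple[tuple[int, int], ...]
```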
+ +### Tests + +Live files inspected: + +- `tests/test_fabric_public_api.py` +- `tests/test_fabric_backend_plan.py` +- `tests/test_fabric_runtime.py` + +Current gaps: + +- Public API tests still use old internals as oracles in several places. +- Backend plan tests assert current route/surface names and dot-product-only + assumptions. +- Runtime tests verify current flat-bucket behavior for small T/K cases, but + they also encode stale route names such as `flat_bucket_sequence_surface`, + `temporal_bucket_sequence`, and `flat_bucket_temporal_scan`. +- Existing parity tests cover some single-pop and mixed-pop T>1/K>1 CUDA paths, + but only at tiny shapes. They do not close the required performance or + horizon semantics. +- Current tests still exercise `stream_sequence_outputs` and + `reduce_sequence_outputs`, which are slated for removal or re-expression + through the shared temporal engine. + +### Benchmark and Audit Tooling + +Live files inspected: + +- `benchmarks/common.py` +- `benchmarks/run.py` +- `benchmarks/README.md` +- `benchmarks/fabric.py` +- `benchmarks/fabric_scaling.py` +- `benchmarks/fabric_suite_common.py` +- `benchmarks/run_fabric_scaling_profile.py` +- `benchmarks/run_fabric_bxt_scaling_audit.py` +- `benchmarks/run_fabric_mixed_population_audit.py` +- `benchmarks/run_fabric_factorization_invariance_audit.py` + +Current gaps: + +- Fabric benchmark code is spread across flat files in `benchmarks/` instead of + an organized Fabric package. +- There is no single canonical Fabric audit script. Current audit-like scripts + include separate scaling, BxT, mixed-population, and factorization entrypoints. +- The generic `benchmarks/run.py` imports every flat benchmark module and is not + enough for closure auditing because it does not know baseline comparisons, + owner stages, audit manifests, high-level API proof, or strict gates. +- `fabric_suite_common.py` contains core model construction, parameter matching, + measurement, rollout, HF comparison, planner-signature extraction, and + training loss helpers in one large shared file. +- `_BackboneWithHead.grad_sequence_strategy()` calls private model helpers such + as `_should_use_direct_grad_sequence`. This makes benchmark reporting depend + on legacy model internals. +- `_run_sequence_forward()` passes `materialize_final_state=False` and + `output_boundary=...` directly to private Fabric behavior when available. +- Current training measurement uses ordinary external loss/backward in the main + sequence case, which is good. The redo must preserve that and remove any + closure dependence on model-owned streaming loss hooks. +- The April 26 intended benchmark style was high-level Fabric API only, including + T*K/H experiments. The audit refactor must keep benchmark code as a user of + model forward/backward, not a backend/planner co-designer. +- `run_fabric_bxt_scaling_audit.py` writes JSONL, JSON summaries, markdown, and + sqlite planner-policy data. It runs cases serially through subprocesses; + optional parallel launch support may be useful for experiments, but lack of a + parallel scheduler is not itself an audit blocker. +- `run_fabric_mixed_population_audit.py` is a separate script with its own model + builder, stack matching, output writer, and backend-summary extraction. +- `run_fabric_factorization_invariance_audit.py` is another separate script and + uses `fabric_suite_common.run_sequence_case`. 
+- `run_fabric_scaling_profile.py` is a large profiler/diagnostic script with + row definitions, kernel-pattern gates, glue/backward attribution tables, and + current date/output defaults. It is useful as a source for profiler gates but + should not remain the canonical audit entrypoint. +- Existing audit scripts default output paths under old `docs/user/subho/...` + style locations. The redo needs a stable audit artifact root under the repo + audit tree. +- Current scripts do not load the April 21 baseline JSON as a strict row-by-row + gate. +- Current scripts do not emit owner-stage attribution in a consistent schema. + +Required benchmark/audit refactor shape: + +- Create a dedicated Fabric benchmark/audit package, e.g. `benchmarks/fabric/`. +- Split reusable concerns into small modules: + `models.py`, `cases.py`, `measure.py`, `baselines.py`, `manifest.py`, + `report.py`, optional parallel-launch helpers, and profiler helpers. +- Provide one canonical entrypoint for Fabric closure, e.g. + `benchmarks/fabric/run_audit.py`. +- Keep temporary wrappers at old script paths only during migration, and delete + them in Stage R15. +- Add an explicit manifest with case IDs, owner stage, suite (`quick`, + `closure`, `frontier`), resource hints, baseline key, and pass/fail criteria. +- Optional parallel launch support may distribute independent cases across the + available GPUs for faster experiments and should record device metadata when + used. This is an execution convenience, not a semantic or performance closure + gate. +- Use one output schema for every case: input axes, semantic axes, engine plan + metadata, performance metrics, memory metrics, baseline comparison, pass/fail, + owner stage, and reopen reason. +- The runner may read public/high-level backend metadata emitted after + `model.forward()`/`loss.backward()`, but must not steer the planner or call + private execution helpers to produce a passing result. + +### Cell and Physical Backend Contracts + +Live files inspected: + +- `src/cortical/fabric/cells/slstm.py` +- `src/cortical/fabric/cells/axon.py` +- `src/cortical/fabric/registry/cells.py` +- `src/cortical/fabric/backend/cell_backend.py` +- `src/cortical/fabric/backend/cell_specs.py` +- `src/cortical/fabric/backend/reuse.py` +- `src/cortical/fabric/backend/surfaces.py` + +Current gaps: + +- Cell specs carry `sequence_memory_policy` and `supports_direct_grad_sequence` + metadata. Current model/runtime code consumes these as execution policy. +- Backend cell specs still advertise old per-cell sequence surfaces such as + `slstm_recurrence` and `axon_recurrence`. +- Surface support is keyed by cell type and old surface names. +- Axon and sLSTM backend specs expose useful transition IR, parameter bindings, + primitive support, and reuse scopes. The redo should preserve those facts but + move policy decisions into the planner. +- `ExecutionFamily` and `MathBackend` are still only + `receiver_major`/`edge_major`/`sequence_major` and + `microkernel`/`grouped_gemm`. The new planner can extend or reinterpret these, + but runtime should not branch on them directly. + +### Recovered Core and Trail Comparison + +Recovered files inspected: + +- `ai_docs/recovered_core.py` +- `ai_docs/AWS_RECOVERY_TRAIL.md` + +Useful recovered design: + +- `Runtime` in recovered core inherits a missing + `SharedTemporalBackwardHelperMixin`, suggesting backward helper extraction was + started after the April 21 codebase. 
+- Recovered runtime imports a missing `runtime/model_temporal.py` with + `ModelTemporalMixin`, `_expand_resets_for_time`, `_flatten_tensordict`, + `_slice_sequence_k`, and `_unflatten_tensordict`. +- Recovered runtime seeds `FabricExecutionPlanner` with active/output regions, + sender tables, sender valid masks, and input sender counts. Live planner does + not receive those facts. +- Recovered runtime has `_plan_temporal_execution(...)` returning + `TemporalExecutionPlan` and passing device, dtype, partitioned layout, edge + delay, constant K, output boundary, readout boundary, fresh state, training, + time steps, helper/direct-grad facts, `gradient_horizon_steps`, and + `checkpoint_steps`. +- Recovered runtime deleted the old public `supported_backend_surfaces()` + exposure and old surface-key path in some places. +- Recovered runtime uses planner fields such as + `executor.selected_implementation`, `executor.temporal_strategy`, + `carry.carry_policy`, and `backward.static_values_mode`. + +Recovered design that is not sufficient by itself: + +- Recovered `forward_cells()` still has a large fallback step loop and still + calls `_forward_stream_step`. +- Recovered `forward_output_cells_for_readout()` still has separate active + output window and shared temporal sequence branches, though they are driven by + a temporal plan. +- Recovered active-output strategy is still a named execution identity. The redo + must ensure active-output is a materialization/region plan inside the shared + engine, not a permanent sibling route. +- Recovered code still uses `constant_k_host` in active-output branches, so it + does not fully close future per-timestep K. +- Recovered code references `gradient_horizon_steps` and checkpoint steps, but + no recovered planner file exists here. Rebuild the planner intentionally; do + not try to blindly paste core fragments. + +Durable recovered algorithm notes from the stale trail: + +- `EmissionPlan` maps user-visible output timesteps to internal scan steps. +- `SubstratePlan` owns the flat active graph substrate. `H=full` uses the exact + finite dependency closure; finite H uses the declared boundary substrate. +- `GradientBoundaryPlan` records `full_horizon` or `rolling_horizon(H)`. H is a + user/declaration semantic and must not be silently shrunk by the planner. +- `CheckpointPlan` places compact carry checkpoints from emission schedule, H, + reset/K schedule, active substrate size, memory budget, and recompute/launch + cost. Fixed stride is a special case, not the abstraction. +- `BackwardWindowPlan` groups emissions into bounded reverse chunks over the + same substrate. Accepted finite-H implementation used non-overlapping chunks + of length at most H; overlapping per-emission windows would be a separate + higher-fidelity semantic. +- Forward streams the planned substrate once over `T*K`, emits normal dense user + outputs, and saves only compact checkpoints. +- Ordinary external `loss.backward()` supplies `grad_output`; backend walks + `BackwardWindowPlan` right-to-left, reloads/recomputes each segment, injects + gradients at emission steps, runs the physical adjoint, accumulates + input/carry/parameter grads, and discards segment-local artifacts. 
+- Rejected alternatives from the trail: do not mask inactive recurrent state + after transition, do not make CUDA-only fixed-window shortcuts, do not compare + fixed-substrate CUDA against a full-surface PyTorch reference, do not use + benchmark streaming-loss helpers for closure, do not add K=1/single-pop/family + selectors, and do not restore the April 21 sequence-major path as a permanent + sibling route. + +## Additional Cleanup Goals + +Imported from `ai_docs/additonal_goals.md` and mapped to redo stages: + +1. `Blueprint.message_passing` typed as `DotProduct`, not generic. +2. `DotProduct.to_ir` discards fields and uses default IR. +3. Backend IR does not receive Blueprint message rule. +4. `classify_message_rule` recognizes only dot-product. +5. CUDA message lowering exact-matches dot-product graph. +6. Only CUDA-side message builder is DotProduct. +7. PyTorch reference message passing is dot-product-specific. +8. Message topology and semantics are coupled. +9. Blueprint facade over old Config. +10. Old Config concepts are central. +11. Runtime exposes old `d_hidden` Model construction path. +12. Public tests use old internals as oracle. +13. Graph API is lattice-only. +14. Interface dims are forced equal to hidden size. +15. Named inputs/outputs support only one compile path. +16. Config rejects more than two populations. +17. Old population placement exists outside graph declarations. +18. Bucket identity includes user population names. +19. Planner uses hidden-size thresholds and hardcoded tiles. +20. Cell families drive backend surface selection/runtime policy. + +## Audit Requirements + +### Score reference policy + +- The April 21 audit JSON is the only authoritative historical score source: + `audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json`. +- Exact T=1 single-pop rows must match an April 21 row and use that row's + throughput and peak-memory values as strict gates. +- Streaming/per-timestep-loss rows must reference the April 21 + `streaming_sequence_loss` rows where exact matches exist. +- Factorization and hidden-size rows must reference the April 21 + factorization/h-stress rows where exact matches exist. +- Mixed-pop rows that did not exist in April 21 must name the controlling April + 21 shape/family/parameter reference row for Fabric floor checks and must also + compare against a same-parameter MoE/stack baseline measured by the current + audit. +- T*K/H rows without exact April 21 matches must name the controlling April 21 + T=1 row and, when applicable, the April 21 streaming sequence-loss row. April + 26 H/K/T observations are target intent and ceiling context, not score gates. +- T*K/H semantic closure is not allowed to cite April 21/April 26 H behavior as + proof. It must rerun the matched current-code T=1,K=1 training row on the same + graph/batch/parameter/hidden/population/loss-boundary contract, then judge + K=1 T-scaling against that current-code T1 line and K>1 rows against that + current-code T1 line divided by K. +- Every audit output row must include the April 21 reference key or reference + group used for the pass/fail decision. + +### Required audit organization + +- Refactor Fabric benchmark/audit code under a dedicated Fabric folder, e.g. + `benchmarks/fabric/`. +- Provide one canonical audit entrypoint, e.g. + `benchmarks/fabric/run_audit.py` or equivalent. 
+- The audit entrypoint must support: + - quick local smoke runs, + - full closure runs, + - frontier/performance exploration, + - JSONL case logs, + - summary JSON, + - machine-readable baseline comparisons, + - owner-stage attribution for failures, + - prompt requirement IDs (`P0`-`P19`) affected by the case, + - shared temporal-engine metadata proving the case did not use a legacy route, + - high-level API proof showing the measurement path used normal model + forward/loss/backward/optimizer calls. +- All score references must be looked up from the April 21 audit JSON. If a new + case has no exact April 21 row, the manifest must name the controlling April + 21 reference row or reference group, such as the matched T=1 family/parameter + row or the April 21 streaming sequence-loss row. Do not use recovered April + 26 numbers as closure gates. +- Do not create legacy wrappers as a migration target. Update callsites + to the right graph/message/cell/readout/planner declarations and delete stale + paths. + +### T=1 single-pop closure + +- Start with `T=1`; because Fabric is always streaming, T=1 must be the base + working case. +- Cover all cases in the April 21 baseline JSON. +- Throughput must be greater than or equal to April 21 results for each matched + row. +- Peak memory must not regress against April 21 for each matched row. +- Both forward and training modes must pass parity and performance gates. + +### T=1 mixed-pop closure + +- Evaluate mixed-pop Axon+sLSTM Fabric for the same shape/size grid used by the + T=1 single-pop closure. +- Compare against a same-parameter MoE/stack baseline. +- Mixed-pop Fabric must be faster in tokens/sec than the matched baseline. +- Mixed-pop must use the same shared temporal engine as single-pop; no mixed + wrapper or route identity may be the reason it passes. + +### T*K, horizon H, and per-timestep loss closure + +- Audit `T*K` as one internal temporal stream with an emission plan. +- For small H, increasing T must not materially change Fabric tok/s. For K=1 + and H=64, T=512/T=4096 rows must remain at or above the matched current-code + T=1,K=1 training throughput for the same graph/batch/parameter/hidden/ + population/loss-boundary contract. +- Increasing H may reduce tok/s, but H=64 is the initial closure target and + should still remain above the matched K-adjusted T=1 training floor where + possible. +- Increasing K should reduce tok/s by approximately factor K because K performs + K times the internal work. +- For K>1 rows, the throughput floor is not raw T=1 tok/s. It is the matched + current-code T=1,K=1 training throughput divided by K, because K internal + thinking steps do K times the work. For the current accepted ceiling, + K=128/H=64 should be judged against matched current-code T=1 training + throughput / 128, while still checking memory, parity, owner metadata, and H + behavior. +- Per-timestep loss must be evaluated, not only terminal loss. +- Required axes include at least T=1/512/4096 and frontier T=16K when memory + allows; K=1 and K up to 128; H up to 64. + +### Hidden-size and shape closure + +- Reducing hidden size `h` must not reduce tok/s. +- h=4 should remain comparable to h=32, matching the April 26 observation. +- Fabric shape/factorization should have very small tok/s spread because all + shapes flatten before lowering. +- Repeat the April 21 factorization-invariance audit and extend it to mixed-pop + where applicable. + +## Staged Redo Plan + +Status legend: + +- `PENDING`: not started. 
+- `ACTIVE`: current owner is working. +- `BLOCKED`: cannot advance without an explicit missing fact or failed + dependency. +- `REOPENED`: an audit found a regression owned by this stage. +- `CLOSED`: exit criteria met and committed. + +### Stage reopen rules + +- Public API, message-rule, graph, interface, or Config-era failures reopen R6 + or R7. +- Population-count-specific behavior, population-name bucket identity, or + shape/factorization spread reopens R2. +- Runtime-owned route/materialization/checkpoint/horizon decisions reopen R1, + R3, R4, or R8 according to the broken decision. +- Any benchmark case that calls private planner/runtime helpers, implements + temporal chunking itself, or uses non-public streaming-loss closure reopens + R9 before the score can count. +- T=1 throughput or memory regressions against the April 21 JSON reopen the + owning planner/backend stage, not the audit harness. +- T*K, emission, per-timestep loss, or H regressions reopen R1/R3/R4 first. +- Remaining legacy execution paths after closure reopen R15. + +### Stage R0 - Evidence freeze and doc discipline + +Status: CLOSED + +Owner: docs/audit lead. + +Goals: + +- Preserve the April 21 baseline result paths and recovered artifacts. +- Keep this file as the redo scratchpad/log. +- Record every durable finding before major code changes. +- Commit this doc at each useful checkpoint. + +Exit criteria: + +- Current baseline, recovered evidence, live-code gaps, and audit targets are + documented. +- No code architecture stage starts without a matching owner and audit gate. + +### Stage R1 - Planner-owned temporal model + +Status: ACTIVE + +Owner: planner. + +Goals: + +- Replace `SequenceSurfaceRoute` with a real `TemporalExecutionPlan`. +- Add planner-owned structures for substrate, active regions, boundary, carry, + emission, checkpointing, gradient boundary, backward windows, executor + choice, and audit attribution. +- Split planner implementation into focused modules if that keeps ownership + clearer, for example temporal plan declarations, emission schedules, + substrate/active-region planning, checkpoint/window planning, and executor + selection. +- Planner must receive normalized message-rule IR and graph/interface facts. +- Runtime may execute the plan but must not reconstruct policy. + +Exit criteria: + +- Forward and backward callers can obtain one temporal plan for T=1 and T>1. +- Plan records scalar K and has a representation ready for per-timestep K. +- Existing route/surface strings are either compatibility metadata or removed. + +R1 implementation checklist: + +- Prompt IDs: P1, P3, P4, P5, P7, P8, P10, P17, P18, P19. +- Add planner-owned temporal plan structures for substrate, boundary, carry, + emission, checkpoint, gradient boundary, backward windows, and executor + selection. +- Preserve current behavior in the first slice by treating legacy + `SequenceSurfaceRoute` as compatibility metadata inside the temporal plan. +- Runtime may consume the plan and pass the selected route to existing executor + code, but it must not independently reconstruct temporal schedule, emission, + or H ownership for new R1 metadata. +- Tests should prove the planner emits the temporal plan for T=1/T>1/K>1 and + that runtime obtains route compatibility through the plan. + +### Stage R2 - Pop-agnostic flat substrate and bucket identity + +Status: PENDING + +Owner: backend IR / bucket lowering. + +Goals: + +- Normalize single-pop and mixed-pop declarations into one flat temporal + substrate. 
+- Remove user population names from semantic bucket identity. +- Lower all active receivers into flat bucket identity/superops. +- Keep population names only where needed for parameter ownership, debug labels, + or state mapping. + +Exit criteria: + +- Single-pop and mixed-pop produce the same kind of flat temporal substrate. +- Bucket equality and audit attribution cannot depend on user population names. +- More than two populations are not structurally blocked by Config-era limits. + +### Stage R3 - Shared temporal forward engine + +Status: PENDING + +Owner: runtime/backend forward. + +Goals: + +- Implement one streaming temporal forward engine for T=1, T>1, K=1, and K>1. +- Replace scalar-K repeat expansion with an emission/schedule model. +- Route all current forward paths through this engine. +- Materialize only requested outputs/states according to the plan. +- Remove separate single-pop/mixed-pop active-output branches as execution + identities. + +Exit criteria: + +- T=1 and T>1 use the same forward backend. +- K is represented as temporal scan work with explicit emissions. +- Scalar-K behavior is implemented as one case of the generic emission/schedule + abstraction, not as repeat-expanded user-visible sequence semantics. +- Output materialization choices do not change the semantic engine. + +### Stage R4 - Shared temporal backward engine + +Status: PENDING + +Owner: runtime/backend backward. + +Goals: + +- Backward consumes the same `TemporalExecutionPlan` as forward. +- Add explicit `GradientBoundaryPlan` and `BackwardWindowPlan`. +- Implement full-horizon and rolling-horizon H semantics. +- Checkpoint/recompute policy is planner-owned and recorded. +- Checkpointing and materialization preserve the prompt requirement that H is a + TBPTT-style semantic horizon while throughput remains comparable to the + matched T=1 reference where expected. +- External PyTorch autograd loss remains normal; benchmark-owned streaming loss + closures must not be required for correctness or memory behavior. + +Exit criteria: + +- Physical temporal backward works for single-pop and mixed-pop through the same + path. +- H=1 and H>1 are explicit audited modes. +- Per-timestep and terminal losses both pass parity. + +### Stage R5 - Model temporal API cleanup + +Status: PENDING + +Owner: model/runtime API. + +Goals: + +- Split temporal model helpers out of `runtime/core.py` if needed. +- Remove or re-express `stream_sequence_outputs`, `reduce_sequence_outputs`, + and `stream_sequence_mse_loss` through the shared temporal engine. +- Delete direct-grad/checkpointed model sequence paths once the shared engine + covers them. +- Remove old user-facing `d_hidden` construction path when replacement API is + ready. + +Exit criteria: + +- Public model forward is a thin call into the shared temporal plan/engine. +- Tests no longer use old internals as the semantic oracle. + +### Stage R6 - Generic message rules and graph interfaces + +Status: PENDING + +Owner: public API / IR / CUDA and PyTorch lowering. + +Goals: + +- Introduce generic message-rule declarations. +- Pass Blueprint message-rule IR through backend IR. +- Decouple topology from message semantics. +- Keep dot-product as one implementation, not the public type of all message + passing. +- Fail closed for unsupported rules with clear planner/audit attribution. +- Start unblocking non-lattice graph declarations, decoupled interface dims, and + multiple named inputs/outputs. 
+ +Exit criteria: + +- Dot-product no longer appears as the only public message passing abstraction. +- PyTorch and CUDA reference paths consume the same normalized message rule IR. +- Existing dot-product audits remain parity-clean. + +### Stage R7 - Config/anatomy/population cleanup + +Status: PENDING + +Owner: public API / anatomy. + +Goals: + +- Move Blueprint away from old `Config` as the real normalized form. +- Delete or quarantine Config-era concepts that should not drive the backend: + `population_mix`, `cell_arrangement`, `default_k`, `k_max`, and + `readout_pool`. +- Move population placement into graph declarations. +- Remove the two-population cap. +- Ensure cell family metadata does not choose runtime execution policy. + +Exit criteria: + +- Backend receives graph/message/cell facts directly from the new normalized + declaration path. +- Legacy Config fields no longer own graph, message, cell, readout, + initialization, or planner semantics. + +### Stage R8 - Planner physical-policy cleanup + +Status: PENDING + +Owner: planner / physical backend registry. + +Goals: + +- Remove hidden-size thresholds and hardcoded tiles from runtime decisions. +- Move microkernel/grouped-GEMM/superop selection into planner-owned policy. +- Replace op-name checks such as recurrence-family strings with capability + records. +- Ensure audit metadata names the planner decision that selected each executor. + +Exit criteria: + +- Runtime no longer branches on hidden size, population count, or cell family to + select physical policy. +- Physical policy can be audited and reopened to planner ownership. + +### Stage R9 - Benchmark and audit package refactor + +Status: PENDING + +Owner: benchmark/audit tooling. + +Goals: + +- Move Fabric benchmark/audit code under a dedicated folder such as + `benchmarks/fabric/`. +- Provide one canonical Fabric audit script. +- Keep closure/smoke audits runnable without multi-GPU scheduling. Optional + parallel launch helpers may use the available GPUs for independent + experiments, but this is not a closure gate. +- Produce JSONL case logs, summary JSON, and optional sqlite/manifest output. +- Record baseline row IDs, measured row IDs, pass/fail criteria, and owner stage + for every case. +- Record prompt requirement IDs (`P0`-`P19`) for every case so final closure can + prove coverage of `ai_docs/prompt.tx`. +- Resolve every score threshold from the April 21 audit JSON and record the + exact reference key/row in the case output. +- Make old benchmark scripts wrappers only during migration. +- Ensure every audit case calls the high-level Fabric API exactly like user code: + model forward, external loss, backward, optional optimizer step. T*K, H, + terminal/per-timestep loss, and mixed-pop cases must not use private runtime + or planner hooks to implement the temporal behavior. + +Exit criteria: + +- A single command can run quick, closure, and frontier Fabric audit suites. +- Audit output can directly compare against the April 21 baseline JSON. +- No audit score is accepted without an April 21 JSON reference row or explicit + April 21 reference group. +- Benchmark code no longer owns temporal chunking, loss streaming, or detach + policy. +- The audit package contains no closure path that calls private Fabric planner, + runtime, stream-sequence, reduce-sequence, or per-chunk backward helpers to + make a case pass. + +### Stage R10 - Strict semantic parity matrix + +Status: PENDING + +Owner: tests/parity. 
+ +Goals: + +- Expand parity before relying on performance numbers. +- Cover T=1, T>1, K=1, K>1, scalar K, future-shaped K schedule stubs, resets, + terminal output, per-timestep output, materialized and unmaterialized final + state, direct and chunked execution, single-pop and mixed-pop. +- Cover gradients for input, state, and parameters. +- Replace stale tests that assert legacy route names with tests that assert the + shared temporal plan semantics. + +Exit criteria: + +- Parity matrix passes on CPU/PyTorch reference and CUDA shared engine. +- No test depends on a legacy path as the correctness oracle. + +### Stage R11 - T=1 single-pop performance closure + +Status: PENDING + +Owner: performance/audit. + +Goals: + +- Rerun all April 21 T=1 single-pop cases. +- Require throughput >= April 21 row throughput. +- Require peak memory <= April 21 row peak memory. +- Include both sLSTM and Axon, forward and training, small/high-batch and large + parameter rows. + +Exit criteria: + +- Every matched row passes strict throughput and memory gates. +- Any failure reopens the owning architecture or planner stage. + +### Stage R12 - T=1 mixed-pop performance closure + +Status: PENDING + +Owner: performance/audit. + +Goals: + +- Run Axon+sLSTM mixed-pop Fabric over the same shape/size grid as T=1 + single-pop closure. +- Compare to same-parameter MoE/stack baseline. +- Require Fabric tok/s > matched baseline tok/s. +- Verify audit metadata shows shared temporal engine, not mixed-pop wrappers. + +Exit criteria: + +- All mixed-pop rows pass parity and performance criteria. +- Any population-count-specific backend identity reopens R2/R3/R4. + +### Stage R13 - T*K and horizon H performance closure + +Status: PENDING + +Owner: performance/audit and temporal backward. + +Goals: + +- Audit T scaling with small H: increasing T should not materially change tok/s. +- Audit K scaling: tok/s should fall approximately by factor K because the + engine is doing K times the internal work. +- Audit K from 1 through 128. K=128 is the currently accepted ceiling for + internal thinking steps, and the full sweep must support the extra-work + scaling claim. +- Audit H scaling: H=64 is the first closure target; larger H can be explored. +- Include per-timestep loss and terminal loss. +- Exercise T=4096 and frontier T=16K where memory allows. +- Use April 21 T=1 and streaming sequence-loss rows as the score references; + April 26 H/K observations are target intent, not score source. +- After H=64 closure, frontier runs should attempt to raise the H ceiling while + staying at or above the matched K-adjusted T=1 training floor where feasible. + H>64 exploration is not allowed to weaken the H=64 closure gate. + +Exit criteria: + +- T scaling is flat enough to prove streaming semantics. +- H=64 works without breaking the matched K-adjusted T=1 training throughput + floor where expected. +- K=1 through K=128 run through the shared temporal engine, K=128 is accepted + as the current ceiling, and the sweep matches the extra-work model. + +### Stage R14 - Hidden-size and factorization closure + +Status: PENDING + +Owner: performance/audit and bucket lowering. + +Goals: + +- Reproduce h=4/h=8/h=16/h=32 behavior. +- Confirm reducing h does not reduce tok/s. +- Confirm graph shape/factorization has only small spread after flattening. +- Extend factorization invariance to mixed-pop where meaningful. + +Exit criteria: + +- h=4 remains comparable to h=32. +- Shape spread is within the closure threshold chosen in the audit manifest. 
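The K-adjusted throughput floor used by Stage R13 above reduces to simple arithmetic. The sketch below is a hypothetical helper, not part of the audit runner; it only restates the gate that a K>1 row is judged against the matched current-code T=1,K=1 training throughput divided by K, while memory, parity, owner metadata, and H behavior remain separate gates.

```python
def k_adjusted_floor(t1_k1_training_tok_s: float, inner_steps: int) -> float:
    # K internal thinking steps do K times the internal work, so the
    # throughput floor scales down by K.
    return t1_k1_training_tok_s / inner_steps


def k_row_meets_floor(
    measured_tok_s: float, t1_k1_training_tok_s: float, inner_steps: int
) -> bool:
    return measured_tok_s >= k_adjusted_floor(t1_k1_training_tok_s, inner_steps)


# Example: against a matched T=1,K=1 training row at 1.0e6 tok/s, a K=128 row
# must sustain at least 1.0e6 / 128 = 7812.5 tok/s to clear the floor.
```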

### Stage R15 - Legacy deletion and Fabric surface cleanup sweep

Status: PENDING

Owner: cleanup, public Fabric surface, graph/message ownership, and backend
boundary review.

Goals:

- Treat R15 as the full recovered cleanup board, not as the latest local
  `Config`/`anatomy.py` complaint. The `Config`/anatomy cleanup is one
  important R15 slice; it does not replace the broader message, graph,
  planner, runtime, benchmark, and legacy-route deletion work.
- Delete legacy execution paths after strict shared-temporal audits pass.
- Remove stale route types, old sequence surfaces, old stream/reduce/loss helper
  paths, hidden-size/cell-family runtime policy, population-name bucket
  identity, stale benchmark wrappers, and any sibling single-pop/mixed-pop
  route identity.
- Redesign the Fabric public ownership surface so `Config` is no longer the
  generic machinery center. Graph construction, message declarations, cell
  declarations, readout declarations, initialization, and planner requests must
  each own their own fields.
- Split lattice-specific facts out of generic Fabric/anatomy code. Lattice owns
  rectangular dimensions, coordinates, wrap, band ports, projection regions,
  offset neighborhoods, and lattice KV grouping. Backend-facing Fabric code
  consumes only flat graph/tensor/op tables and recorded planner decisions.
- Remove direct message/cell math hidden in runtime/core when it should be
  expressed through user-declared `fabric.cuda.nn` primitives and lowered
  tensor/op rows.
- Add static and manual review gates that prevent cell-family names,
  population names, rectangular factorization, config-level head counts,
  benchmark ids, or hidden-size constants from becoming backend policy keys.
- Refactor benchmarks/audits so they stay in the Fabric audit folder, use the
  high-level model forward/loss/backward path, and never design backend or
  planner behavior from the harness.
- Use repository-wide greps and audit metadata to prove no legacy execution or
  config truth path remains.
- Close every item imported from `ai_docs/additonal_goals.md`, including the
  message-rule genericity board, Blueprint/Config facade removal, graph API
  ownership cleanup, old public-test cleanup, population/cardinality cleanup,
  planner policy cleanup, and cell-family backend-surface cleanup.

Exit criteria:

- All Fabric paths flow through the shared temporal engine.
- Fabric's generic backend/anatomy surface is graph-generic: no lattice,
  population-name, cell-family, hidden-size, benchmark, or config-era field is
  used as backend identity or route policy.
- No legacy execution/config truth path remains, and R15 lists the removed files,
  callsites, and guardrails that prevent reintroduction.
- Every recovered additional cleanup issue has either a code-level closure with
  tests/audit evidence or an explicit final-report deferral with the reason it
  cannot be removed before the shared temporal engine is complete.

### Stage R16 - Final closure report

Status: PENDING

Owner: audit/docs.

Goals:

- Summarize final architecture.
- Link every audit artifact.
- State exact pass/fail thresholds and measured results.
- Record deleted legacy paths.
- Mark P0-P19 from the prompt requirement matrix as closed or explicitly
  deferred with owner and reason.
- Commit final docs and code.

Exit criteria:

- We can clearly state that all Fabric paths flow through the new shared
  temporal engine and all strict audits passed.
+- Final report includes a prompt-traceability checklist showing every + requirement from `ai_docs/prompt.tx` is either closed or explicitly deferred + with owner and reason. + +## Implementation Read Checklist + +This R0 pass inspected the public API, Config/anatomy path, message-rule IR, +planner, runtime core, temporal CUDA executor, flat bucket helpers, temporal +backward, cell/backend specs, tests, current benchmark scripts, recovered core, +stale trail, and April 21 audit JSON. Before each code stage starts, do a +line-local read of the files being edited and append the exact owner decision +here. + +High-risk code areas for the first implementation stages: + +- `src/cortical/fabric/backend/planner.py` +- `src/cortical/fabric/runtime/core.py` +- `src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py` +- `src/cortical/fabric/backend/cuda/sequence_surface/temporal_backward.py` +- `src/cortical/fabric/backend/buckets.py` +- `src/cortical/fabric/backend/ir.py` +- `benchmarks/fabric_suite_common.py` +- `benchmarks/run_fabric_scaling_profile.py` +- `benchmarks/run_fabric_bxt_scaling_audit.py` +- `benchmarks/run_fabric_mixed_population_audit.py` +- `tests/test_fabric_public_api.py` +- `tests/test_fabric_backend_plan.py` +- `tests/test_fabric_runtime.py` + +## Working Log + +### 2026-04-27 UTC - Scratchpad seeded + +- User clarified that this redo doc is the scratchpad/log and should be updated + continuously. +- User clarified 8 GPUs are available for audit experiments; this is an + execution resource, not by itself an audit closure requirement. +- User requested one Fabric audit script and a benchmark refactor into a Fabric + folder. +- Current read found live code still has route-level planner, runtime-owned + temporal policy, dot-product-only message rules, Config-era public API + normalization, population-name bucket identity, scalar-K repeat expansion, and + legacy model streaming helpers. + +### 2026-04-27 UTC - Benchmark/recovered-code pass + +- Added benchmark audit findings: Fabric scripts are split across flat files, + `fabric_suite_common.py` mixes model building/measurement/reporting, and no + single canonical audit entrypoint currently handles baseline gates, + owner-stage attribution, high-level API proof, or strict score gates. +- Added required audit refactor: dedicated `benchmarks/fabric/` package, + manifest-driven canonical runner, JSONL case logs, summary JSON, optional + sqlite, row-level owner/reopen schema, and optional parallel launch helpers + for experiments. +- Added cell/backend contract finding: cell specs expose useful transition and + reuse facts, but current runtime still reads cell metadata for execution + memory policy and direct-grad decisions. +- Compared `recovered_core.py` and the stale trail. The recoverable value is the + temporal-plan contract and model/backward helper split; the recovered core is + not a finished shared engine and still contains fallback/active-output branch + identities that must be redesigned, not copied blindly. + +### 2026-04-27 UTC - Benchmark API correction + +- User clarified that benchmark/audit code must always use Fabric high-level API + and normal PyTorch-style forward/backward. Even T*K/H experiments should be + expressed as user-level model calls with external losses, matching the April + 26 intended audit style. 
+- Benchmarks may report planner/backend metadata after the fact, but must never + design planner behavior, call private backend helpers as the measurement path, + or use benchmark-side streaming-loss/per-chunk-backward/time-tiling/detach + tricks to claim closure. + +### 2026-04-27 UTC - April 21 score-source rule + +- User clarified that benchmarks and audits should always use the April 21 + audit JSON for score references. +- Added operating rule that every audit score threshold resolves to + `audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json`. + For new K/H/mixed/frontier cases without exact April 21 matches, the manifest + must name the controlling April 21 row or group. + +### 2026-04-27 UTC - Deep read pass closed for planning + +- Completed the initial R0 read across planner, runtime, temporal executor, + temporal backward, flat buckets, transition lowering, tests, benchmark + scripts, cell/backend specs, recovered core, stale trail, additional goals, + and April 21 score JSON. +- Added reopen routing so audit failures go back to the stage that owns the + broken decision instead of being patched inside the audit harness. +- R0 evidence/doc stage is closed for this checkpoint. Future implementation + stages must continue appending owner decisions, audit results, and reopen + reasons here before committing. + +### 2026-04-27 UTC - Recovered prompt deep dive + +- Deep-read `ai_docs/prompt.tx`. It is a compact recovered copy of the original + redo request, not a separate implementation design, but it is now listed as a + baseline evidence source. +- Added explicit prompt traceability so the plan covers every recovered + requirement: April 21 code state, shared temporal engine, flat bucket identity, + buckets, superop lowering, shared forward/backward, K as `T*K`, future + per-timestep K, always-streaming T=1 base case, TBPTT horizon H, + planner-owned decisions, additional cleanup goals, strict audits, stage + reopen/jump-back behavior, continuous doc/skill updates, and commits. +- Added a stricter score-reference policy for exact rows and new rows. New + mixed-pop and T*K/H cases must still name controlling April 21 JSON reference + rows or groups, while April 26 recovered observations remain target context. + +### 2026-04-27 UTC - Prompt coverage matrix pass + +- Re-read `ai_docs/prompt.tx` line-by-line and converted the recovered prompt + into stable requirement IDs P0-P16 with source-line references, owning stages, + and closure evidence. +- Added audit-manifest requirements to record prompt IDs, shared-engine + metadata, and high-level API proof for every audit case. +- Added final-report requirement to mark P0-P16 closed or explicitly deferred + with owner and reason. + +### 2026-04-27 UTC - Prompt hard-gate pass + +- Re-read `ai_docs/prompt.tx` again and strengthened clauses that were present + but too implicit in the plan. +- Extended prompt IDs to P0-P19. New IDs cover April 26's additional audit + stage expectation, semantic/performance co-equal hard gates, and autonomous + stage-to-stage continuation unless there is significant confusion or a true + local blocker. +- Added prompt source handling rules: recovered trail and recovered core are + design evidence only; April 21 local code and April 21 audit JSON remain the + implementation/score baseline; recovered code must be compared against live + code before use; prompt IDs close only from current code/tests/audits/deletion. 
+- Strengthened R13 so H>64 frontier exploration cannot weaken the H=64 closure + gate. + +### 2026-04-27 UTC - GPU parallelism clarification + +- User clarified that the 8-GPU note is for using scripts to run parallel + experiments, not a requirement that audit runs deterministically shard cases + across GPUs. +- Updated the plan so multi-GPU launch support is optional execution + convenience. Audit closure remains about canonical high-level API runs, + April 21 score references, strict parity/performance gates, owner attribution, + and shared-engine proof. + +### 2026-04-27 UTC - K=1..128 R13 gate + +- User noted that Stage R13 did not explicitly name the K=128 target. +- User clarified that K should be audited from 1 through 128, with K=128 as the + currently accepted ceiling. +- Updated R13 so closure requires the K=1..128 sweep through the shared temporal + engine, with the sweep matching the extra-work scaling model. + +### 2026-04-27 UTC - R1 implementation start + +- R1 set ACTIVE. +- First code slice will introduce planner-owned `TemporalExecutionPlan` + structures and route runtime through that plan while preserving current + executor behavior. This is intentionally a scaffolding/migration slice, not + the final deletion of legacy route names. + +### 2026-04-27 UTC - R1 temporal plan scaffold + +- Added planner-owned temporal plan structures in + `src/cortical/fabric/backend/temporal_plan.py`: route compatibility, + schedule, substrate, boundary, carry, emission, checkpoint, gradient boundary, + backward window, executor, and `TemporalExecutionPlan`. +- Added `FabricExecutionPlanner.plan_temporal_execution(...)`. It records scalar + K schedules, runtime-variable K representation for future per-timestep K, + output emission, state materialization, full/rolling horizon H, checkpoint + ownership, and compatibility executor selection. +- Runtime now exposes `plan_temporal_execution(...)`, records + `last_temporal_execution_plan`, and obtains current sequence route + compatibility from the temporal plan in `forward_cells()` and + `forward_output_cells_for_readout()`. +- Current executor behavior is intentionally preserved. `SequenceSurfaceRoute` + remains compatibility metadata inside the temporal plan until later R1/R3/R4 + slices remove route-era ownership. +- Validation: + `uv run pytest -q tests/test_fabric_backend_plan.py -n0` passed, + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_stream_sequence_matches_repeated_steps_for_cells_and_state tests/test_fabric_runtime.py::test_fabric_runtime_records_planner_temporal_plan_for_forward_cells tests/test_fabric_runtime.py::test_fabric_supported_cuda_surface_selection -n0` passed, + `uv run pytest -q tests/test_fabric_execution_imports.py -n0` passed, + `uv run ruff check ...` passed, and `uv run ruff format --check ...` passed + for touched Python files. +- R1 remains ACTIVE. Open work: move more temporal policy decisions out of + runtime/local CUDA policy, connect temporal plan metadata into physical + backward planning, and reduce compatibility route assertions after semantic + parity gates are in place. + +### 2026-04-27 UTC - R1 execution-record metadata + +- Added temporal-plan metadata fields to `BackendExecutionRecord` so audit code + can read planner-owned schedule, scan length, emission, horizon, checkpoint, + substrate, and executor facts after a normal model forward/backward path. 
+- Threaded temporal-plan record metadata into PyTorch fallback records, + sequence-surface records, and flat-bucket temporal sequence records without + changing executor behavior. +- Added runtime test coverage that `forward_cells()` records both + `last_temporal_execution_plan` and matching `BackendExecutionRecord` + temporal metadata. +- Validation: + `uv run pytest -q tests/test_fabric_backend_plan.py -n0` passed, + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_runtime_records_planner_temporal_plan_for_forward_cells tests/test_fabric_runtime.py::test_fabric_stream_sequence_matches_repeated_steps_for_cells_and_state tests/test_fabric_runtime.py::test_fabric_supported_cuda_surface_selection -n0` passed, + `uv run pytest -q tests/test_fabric_execution_imports.py -n0` passed, + `uv run ruff check ...` passed, and `uv run ruff format --check ...` passed + for touched Python files. + +### 2026-04-27 UTC - R1 temporal plan attached to backward planning + +- `PlannedFabricBackwardExecution` can now carry the matching + `TemporalExecutionPlan`. +- `Runtime.plan_backend_backward_execution(...)` accepts the temporal plan, and + the flat-bucket temporal physical training path passes the current planner + temporal plan into backward planning. +- Added planner test coverage that backward planning preserves the same + temporal plan object and its scan/horizon metadata. +- Validation: + `uv run pytest -q tests/test_fabric_backend_plan.py -n0` passed, + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_runtime_records_planner_temporal_plan_for_forward_cells tests/test_fabric_runtime.py::test_fabric_stream_sequence_matches_repeated_steps_for_cells_and_state tests/test_fabric_runtime.py::test_fabric_supported_cuda_surface_selection -n0` passed, + `uv run pytest -q tests/test_fabric_execution_imports.py -n0` passed, + `uv run ruff check ...` passed, and `uv run ruff format --check ...` passed + for touched Python files. + +### 2026-04-27 UTC - R1 full runtime test pass + +- Ran `uv run pytest -q tests/test_fabric_runtime.py -n0`. +- Result: 280 passed, 1 profiler warning, runtime 346.08s. +- R1 remains ACTIVE. The scaffold is stable, but runtime still owns some + temporal/executor policy decisions that must move under planner ownership + before closing the stage. + +### 2026-04-27 UTC - R1 planner-owned backend selection slice + +- Active invariant: backend identity is a temporal execution-plan decision, not + a runtime-side duplicate selector after route compatibility has already been + planned. +- Moved selected backend name into `TemporalExecutorPlan` and + `BackendExecutionRecord.temporal_plan_backend_names`. +- `forward_cells()` and `forward_output_cells_for_readout()` now use + `temporal_plan.executor.backend_name` for downstream CUDA/PyTorch branching. +- Route-only compatibility queries still force `configured_backend="auto"` so + support inspection does not depend on the user backend request. 
+- Validation: + `uv run pytest -q tests/test_fabric_backend_plan.py -n0` passed, + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_runtime_records_planner_temporal_plan_for_forward_cells tests/test_fabric_runtime.py::test_fabric_stream_sequence_matches_repeated_steps_for_cells_and_state tests/test_fabric_runtime.py::test_fabric_supported_cuda_surface_selection tests/test_fabric_runtime.py::test_fabric_backend_cuda_requires_supported_cuda_surface -n0` passed, + `uv run pytest -q tests/test_fabric_execution_imports.py -n0` passed, + `uv run ruff check ...` passed, + `uv run ruff format --check ...` passed, and `git diff --check` passed. + +### 2026-04-27 UTC - R1 planner-owned static value policy slice + +- Active invariant: static tensor ownership, detachment, and native + materialization are temporal-plan policy decisions. Runtime may materialize + tensors, but should not independently decide the static-value mode after the + temporal plan has selected the backend/executor. +- Added `TemporalStaticValuePlan` and exposed it through + `BackendExecutionRecord` metadata. +- `forward_cells()` and `forward_output_cells_for_readout()` now use + `temporal_plan.static_values` for static KV inclusion, training detachment, + native static materialization, and PyTorch-autograd static-value mode. +- Validation: + `uv run pytest -q tests/test_fabric_backend_plan.py -n0` passed, + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_runtime_records_planner_temporal_plan_for_forward_cells tests/test_fabric_runtime.py::test_fabric_stream_sequence_matches_repeated_steps_for_cells_and_state tests/test_fabric_runtime.py::test_fabric_supported_cuda_surface_selection tests/test_fabric_runtime.py::test_fabric_backend_cuda_requires_supported_cuda_surface -n0` passed, + `uv run pytest -q tests/test_fabric_execution_imports.py -n0` passed, + focused temporal static metadata tests passed, `uv run ruff check ...` + passed, `uv run ruff format --check ...` passed, and `git diff --check` + passed. + +### 2026-04-27 UTC - R8 canonical Fabric audit runner start + +- Active invariant: benchmark/audit code is a high-level Fabric API consumer. + It may construct cases, run `model(...)`, external loss/backward, collect + metadata, and compare against April 21 references; it must not own planner, + backend, tiling, checkpoint, or temporal-loop policy. +- This was an early infrastructure scaffold because the user requested one + canonical Fabric audit script and dedicated Fabric benchmark folder. R8 is not + closed; it must not be used to skip R1/R2/R3 refactor ownership work. +- Added `benchmarks/fabric/` as the dedicated Fabric audit package with + canonical entrypoint `python -m benchmarks.fabric.run_audit`. +- Initial runner loads the April 21 JSON, writes `manifest.json`, `cases.jsonl`, + and `summary.json`, supports smoke/T=1/TK case manifests, records owner stage + and prompt requirement IDs, records high-level API proof, and supports simple + deterministic case slicing for manual multi-GPU runs. +- Updated benchmark planner signatures to include temporal-plan metadata from + normal model execution records. 
+- Validation: + `uv run python -m benchmarks.fabric.run_audit --help` passed, + `uv run python -m benchmarks.fabric.run_audit --plan smoke --dry-run ...` + passed, + `uv run python -m benchmarks.fabric.run_audit --plan smoke --device cpu --warmup 0 --iterations 1 ...` + passed and produced temporal-plan metadata through the normal model forward + path, + `uv run pytest -q tests/test_fabric_audit_runner.py tests/test_fabric_benchmark_suite_common.py -n0` + passed, `uv run ruff check ...` passed, `uv run ruff format --check ...` + passed, and `git diff --check` passed. + +### 2026-04-27 UTC - R1 planner-owned fresh carry cache policy slice + +- Active invariant: fresh carry/cache virtualization is part of temporal carry + planning. Runtime may allocate or skip state tensors, but the decision that a + fresh multi-population inference row can use the backend population cache + belongs in the temporal plan. +- Added `TemporalCarryPlan.fresh_state_population_cache` and reason metadata. +- `forward_cells()` now consumes this carry-plan field instead of recomputing + the condition from backend name, route shape, population count, training mode, + materialization, and K. +- Benchmark planner signatures now include the fresh-cache temporal metadata. +- Validation: + `uv run pytest -q tests/test_fabric_backend_plan.py -n0` passed, + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_runtime_records_planner_temporal_plan_for_forward_cells tests/test_fabric_runtime.py::test_fabric_stream_sequence_matches_repeated_steps_for_cells_and_state tests/test_fabric_runtime.py::test_fabric_supported_cuda_surface_selection tests/test_fabric_runtime.py::test_fabric_backend_cuda_requires_supported_cuda_surface tests/test_fabric_audit_runner.py tests/test_fabric_benchmark_suite_common.py -n0` + passed, + `uv run pytest -q tests/test_fabric_execution_imports.py -n0` passed, + `uv run ruff check ...` passed, `uv run ruff format --check ...` passed, + and `git diff --check` passed. + +### 2026-04-27 UTC - R1 planner-owned legacy recurrence population slice + +- Active invariant: the old single-population recurrence population identity is + compatibility metadata inside the temporal executor plan. Runtime should not + recompute that legacy path selection after the temporal planner has already + classified population cardinality and K. +- Added `TemporalExecutorPlan.legacy_recurrence_population_name` and record + metadata. +- `forward_cells()` and `forward_output_cells_for_readout()` now read the + legacy recurrence population from the temporal plan instead of calling the + runtime selector directly. +- This does not close the legacy surface deletion; it only records the + transitional identity under planner ownership so later R3/R4 cleanup can + remove the compatibility path with audit proof. 
+- Validation: + `uv run pytest -q tests/test_fabric_backend_plan.py -n0` passed, + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_runtime_records_planner_temporal_plan_for_forward_cells tests/test_fabric_runtime.py::test_fabric_stream_sequence_matches_repeated_steps_for_cells_and_state tests/test_fabric_runtime.py::test_fabric_supported_cuda_surface_selection tests/test_fabric_runtime.py::test_fabric_backend_cuda_requires_supported_cuda_surface tests/test_fabric_audit_runner.py tests/test_fabric_benchmark_suite_common.py -n0` + passed, + `uv run pytest -q tests/test_fabric_execution_imports.py -n0` passed, + `uv run ruff check ...` passed, `uv run ruff format --check ...` passed, + and `git diff --check` passed. + +### 2026-04-27 UTC - User correction: one temporal superop target + +- User clarified that the shared temporal engine should eventually be a fully + owned temporal superop design, matching the April 26 direction. +- Durable invariant: Python may construct declarations, invoke the high-level + model API, and pass tensors/metadata, but should not own temporal scanning, + K microsteps, horizon windows, checkpoint/recompute, materialization policy, + or backward loops beyond unavoidable API glue. +- R1/R2/R3 implementation must therefore move toward one planner-owned temporal + superop path and delete Python loop/sibling route logic after parity and audit + gates, not merely make the current loops better organized. + +### 2026-04-27 UTC - R3 internal K schedule implementation start + +- Active invariant: K internal thinking is part of the temporal scan schedule. + The backend must not make K>1 work by materializing a user-visible repeated + boundary sequence; outputs are emitted according to the emission plan. +- Current target slice: update the physical temporal autograd superop to accept + outer `[B,T,...]` boundary inputs plus `inner_steps=K`, execute `T*K` + internally, accumulate boundary gradients back to the outer T inputs, and emit + only outer timestep outputs. +- This is R3/R4 shared-engine work and should keep normal public model + forward/loss/backward behavior unchanged. +- Validation: + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_uses_temporal_bucket_sequence -n0` + passed, + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_flat_bucket_route_matches_pytorch_reference tests/test_fabric_runtime.py::test_fabric_cuda_training_surface_uses_flat_bucket_route_without_single_population_selector tests/test_fabric_runtime.py::test_fabric_cuda_single_population_flat_bucket_forward_uses_sequence_executor -n0` + passed, + `uv run pytest -q tests/test_fabric_backend_plan.py tests/test_fabric_execution_imports.py -n0` + passed, + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_runtime_records_planner_temporal_plan_for_forward_cells tests/test_fabric_runtime.py::test_fabric_stream_sequence_matches_repeated_steps_for_cells_and_state tests/test_fabric_runtime.py::test_fabric_supported_cuda_surface_selection tests/test_fabric_runtime.py::test_fabric_backend_cuda_requires_supported_cuda_surface tests/test_fabric_audit_runner.py tests/test_fabric_benchmark_suite_common.py -n0` + passed, `uv run ruff check ...` passed, + `uv run ruff format --check ...` passed, and `git diff --check` passed. 
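+- Toy PyTorch sketch of the internal-K contract in the R3 slice above: each outer boundary input
+  is reused for K internal steps, only outer timesteps are emitted, and an external
+  `loss.backward()` folds boundary gradients back onto the outer `[B, T, ...]` input. The
+  transition function is a stand-in, not Fabric cell math.
+
+```python
+import torch
+
+def toy_scan(x: torch.Tensor, h0: torch.Tensor, inner_steps: int) -> torch.Tensor:
+    """Toy T*K scan: reuse each outer boundary input for K internal steps, emit outer steps only."""
+    outer_steps = x.shape[1]
+    h = h0
+    emitted = []
+    for physical in range(outer_steps * inner_steps):
+        outer, inner = divmod(physical, inner_steps)
+        h = torch.tanh(h + x[:, outer])      # stand-in transition, not Fabric math
+        if inner == inner_steps - 1:         # emission plan: outer timesteps only
+            emitted.append(h)
+    return torch.stack(emitted, dim=1)       # [B, T, H]
+
+x = torch.randn(2, 3, 4, requires_grad=True)             # outer [B, T, D] boundary inputs
+out = toy_scan(x, torch.zeros(2, 4), inner_steps=4)      # 12 physical steps, 3 emissions
+out.sum().backward()                                     # external loss, normal backward
+assert out.shape == (2, 3, 4)
+assert x.grad is not None and x.grad.shape == x.shape    # T*K work folds back onto the outer T grads
+```
+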
+ +### 2026-04-27 UTC - R4 planner-owned H/checkpoint implementation start + +- Active invariant: H horizon and checkpoint/materialization policy are temporal + planner decisions. High-level Fabric calls and benchmark cases may request H, + and may provide checkpoint steps as an explicit override, but absent that + override checkpoint steps must be determined by the planner. +- Current target slice: add `gradient_horizon_steps` and `checkpoint_steps` to + Fabric config/blueprint execution declarations and canonical Fabric audit case + construction as optional requests/overrides, default runtime temporal planning + from config, and thread the resulting plan into the physical temporal scan + policy. +- For this slice, H is interpreted as a physical scan-step horizon over the + flattened T*K stream, matching the prompt invariant that internal K is lowered + as temporal scan work. Future per-timestep K must keep the same semantic owner + even if the schedule representation becomes ragged. +- Validation: + `uv run pytest -q tests/test_fabric_backend_plan.py tests/test_fabric_audit_runner.py tests/test_fabric_public_api.py -n0` + passed, + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_supported_training_backward_uses_reset_aware_physical_policy -n0` + passed with reset-present K>1 single-pop, reset-present K>1 mixed-pop, and + reset-aware physical backward coverage, + `uv run python -m benchmarks.fabric.run_audit --plan smoke --dry-run --out-dir /tmp/redo_fixmaass_h_plan_smoke` + passed, + `uv run python -m benchmarks.fabric.run_audit --plan tk-scaling --dry-run --out-dir /tmp/redo_fixmaass_h64_k128_dry --families slstm --sizes 1m --modes forward_backward --batches 2 --seq-lens 4 --inner-steps 1,2,4,8,16,32,64,128 --hidden-sizes 8 --gradient-horizon-steps 64 --checkpoint-steps planner --limit 4` + passed and produced planner-owned checkpoint manifests (`checkpoint_steps: + null`, H=64), + `uv run ruff check ...` passed, + `uv run ruff format --check ...` passed for Python files, and + `git diff --check` passed. + +### 2026-04-27 UTC - User correction: fully owned CUDA temporal engine + +- User clarified that the April 26 final direction had a fully owned CUDA + temporal engine, not just better planner metadata around Python scan loops. +- Durable invariant: the acceptable end state is one CUDA-owned temporal superop + over flat buckets. Python `autograd.Function` temporal scan bodies and + Python-side T/K/H loops are transitional scaffolding only; they can be used to + keep parity while rebuilding, but R3/R4 are not complete until the temporal + scan/backward/checkpoint/materialization owner is CUDA/backend code and the + legacy Python loop/sibling route paths are deleted after audit closure. +- Current R4 patch still matters because it moves H/checkpoint intent into the + planner and high-level API; the follow-on implementation must consume that + plan in the CUDA temporal engine rather than stopping at Python execution. + +### 2026-04-27 UTC - User correction: reset parity hard gate + +- User clarified that reset parity failed intermittently during the April 24-26 + work and must be explicitly checked. 
+- Durable gate: every temporal engine slice that affects forward, backward, + K, H, checkpoint/recompute, materialization, active-output windows, or legacy + route deletion must include reset-present and reset-absent parity where the + path supports resets. Reset checks must cover outputs, final state when + materialized, input/boundary gradients, and nonzero parameter gradients; a + no-reset-only pass is not enough to close the owner. + +### 2026-04-27 UTC - R4 high-level H/reset parity follow-up + +- Live route check: normal high-level CUDA sequence training is using the + flat-bucket temporal surface (`surface_key=flat_bucket_sequence_surface`) for + the tested single-population H/reset row, not the old single-population + recurrence sequence surface. Do not polish the old recurrence route as if it + were shared temporal engine closure. +- Current target slice: add a high-level model parity test that requests + planner-owned H through config, runs reset-present sequence training through + the flat-bucket temporal path, and compares CUDA against PyTorch for outputs, + materialized state, input gradients, and all parameter gradients. +- Validation: + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_uses_high_level_reset_parity -n0` + passed. + +### 2026-04-27 UTC - R8 reset-aware audit manifest slice + +- Active invariant: reset-present rows are first-class audit cases, not + one-off local probes. The canonical Fabric audit entrypoint should make reset + coverage visible in manifests/results while still using only the high-level + Fabric API. +- Current target slice: add reset modes to `benchmarks.fabric.run_audit` case + generation. `reset_mode=present` builds a deterministic `[B,T]` mask and + passes it to `model(..., resets=...)`; it does not alter backend/planner + policy or add benchmark-owned temporal tiling. +- This complements the R4 high-level H/reset parity test. Reset-present + performance rows still use the April 21 JSON as the score source, with reset + mode recorded as additional audit metadata. +- Validation: + `uv run pytest -q tests/test_fabric_audit_runner.py tests/test_fabric_benchmark_suite_common.py -n0` + passed, + `uv run python -m benchmarks.fabric.run_audit --plan tk-scaling --dry-run ... --reset-modes absent,present` + produced reset-absent and reset-present manifests, `uv run ruff check ...` + passed, and `uv run ruff format --check ...` passed. + +### 2026-04-27 UTC - R5 temporal engine owner gate start + +- Active invariant: the redo must not confuse planner metadata with CUDA + temporal-superop closure. The planner/audit record must say who owns the + temporal scan and backward loop today, and strict audits must be able to fail + until that owner is the target `cuda_temporal_superop`. +- Current target slice: add `TemporalEnginePlan` to the planner output and + backend execution metadata. Current flat-bucket CUDA rows are explicitly + labeled `python_autograd_scan` with status `transitional_python_scan`; PyTorch + rows are labeled `pytorch_reference`; the target owner is recorded as + `cuda_temporal_superop`. +- The canonical Fabric audit runner now has a `--require-cuda-temporal-owner` + gate. This is expected to fail on the transitional implementation and should + be enabled for final closure audits after the CUDA temporal superop is real. 
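+- Minimal sketch of what the `--require-cuda-temporal-owner` gate amounts to (the record field
+  names are assumptions; the owner labels are the ones quoted above). It is expected to fail
+  while the active owner is `python_autograd_scan`.
+
+```python
+TARGET_OWNER = "cuda_temporal_superop"
+
+def check_cuda_temporal_owner(case_records: list[dict], require: bool) -> list[str]:
+    """Return failure reasons for cases whose active temporal owner is not the CUDA superop.
+
+    "owner" and "case_id" are placeholder keys; only the owner label strings come from the log.
+    """
+    failures = []
+    for case in case_records:
+        owner = case.get("owner", "unknown")
+        if require and owner != TARGET_OWNER:
+            failures.append(f"{case.get('case_id', '?')}: temporal owner is {owner}, not {TARGET_OWNER}")
+    return failures
+
+# Transitional rows fail the gate, as intended.
+records = [{"case_id": "smoke-0", "owner": "python_autograd_scan"}]
+assert check_cuda_temporal_owner(records, require=True)
+```
+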
+- Validation: + `uv run pytest -q tests/test_fabric_backend_plan.py tests/test_fabric_audit_runner.py tests/test_fabric_benchmark_suite_common.py -n0` + passed, + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_runtime_records_planner_temporal_plan_for_forward_cells tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_uses_high_level_reset_parity -n0` + passed, + `uv run python -m benchmarks.fabric.run_audit --plan smoke --dry-run --require-cuda-temporal-owner ...` + passed parser/manifest dry-run, `uv run ruff check ...` passed, + `uv run ruff format --check ...` passed, and `git diff --check` passed. + +### 2026-04-27 UTC - R2/R5 flat bucket identity slice + +- Active invariant: the shared temporal engine cannot be considered shared + while plan/cache identity uses user population names. Population names may + remain as transitional state/parameter binding handles, but the bucket + identity consumed by the planner and future CUDA temporal superop must be + derived from semantic transition schema, parameter-binding slot, dimensions, + topology bucket, delay/reset layout, and sharing policy. +- Current target slice: add explicit flat bucket identity to `FabricBucket` + and move planner/cache bucket signatures to that identity. Keep + `population_name` available for existing TensorDict/module lookup only, then + record the temporal substrate as `flat_bucket_identity` when the identity is + no longer name-keyed. +- Reset parity remains a hard gate for every follow-up temporal engine owner + change. This slice should not change math; validation focuses on planner + cache behavior and existing high-level reset-present temporal parity. +- Validation: + `uv run pytest -q tests/test_fabric_backend_plan.py -n0` passed, + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_uses_high_level_reset_parity -n0` + passed, `uv run ruff check ...` passed, and + `uv run ruff format --check ...` passed. User corrected the label during the + slice: use `flat_bucket_identity`. + +### 2026-04-27 UTC - R2 runtime temporal flat bucket naming slice + +- Active invariant: backend IR now uses `flat_bucket_identity`, but the runtime + temporal plan still exposes `TemporalPopulationBucket` and active code imports + `backend_order_population_buckets`. That is a naming/ownership bridge, not + final CUDA temporal superop closure. +- Current target slice: introduce `TemporalFlatBucket` with `binding_name`, + `binding_slot`, and `flat_bucket_identity`; switch active temporal forward and + backward callers to `backend_order_flat_buckets`. Keep compatibility aliases + only where state and parameter TensorDict lookup still requires the binding + name. +- Validation: + `uv run pytest -q tests/test_fabric_backend_plan.py -n0` passed, + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_uses_high_level_reset_parity -n0` + passed, `uv run ruff check ...` passed, and + `uv run ruff format --check ...` passed. + +### 2026-04-27 UTC - User correction: temporal owner is shared multi-pop + +- User clarified that temporal owner closure must be judged as shared + multi-pop ownership. Single-pop is only the one-bucket case of the same + temporal engine, not a separate owner or sufficient closure proof. 
+- Durable gate: every temporal owner, reset, H/checkpoint, K, backward, or + CUDA-superop closure must include mixed-pop coverage alongside single-pop + coverage where the path supports it. Do not mark a temporal owner closed on a + single-pop-only row. +- Current target follow-up: the reset metadata/H parity slice now needs + high-level mixed-pop reset-absent and reset-present parity in addition to the + single-pop parameterized row. +- Validation: + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_uses_high_level_reset_parity tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_shared_mixed_population_reset_parity -n0` + passed with reset-absent and reset-present rows for both single-pop and + mixed-pop; `uv run pytest -q tests/test_fabric_benchmark_suite_common.py -n0` + passed; `uv run ruff check ...` and `uv run ruff format --check ...` passed. + +### 2026-04-27 UTC - R8 shared temporal coverage audit guard + +- Active invariant: final temporal owner audits must not close from a single-pop + manifest. Single-pop is one bucket; mixed-pop is the shared temporal owner + proof. +- Current target slice: add `--require-shared-temporal-coverage` to the + canonical Fabric audit runner. The gate requires both `single` and `mixed` + `population_mode` cases in the manifest. The current generator only emits + single-pop cases, so this flag intentionally fails until mixed-pop audit case + generation is implemented through the same high-level API path. +- Validation: + `uv run pytest -q tests/test_fabric_audit_runner.py -n0` passed, + `uv run python -m benchmarks.fabric.run_audit --plan smoke --dry-run --out-dir /tmp/redo_fixmaass_shared_coverage_gate --require-shared-temporal-coverage` + failed as expected with + `shared_temporal_owner_requires_single_and_mixed_population_coverage`; + `uv run ruff check ...` and `uv run ruff format --check ...` passed. + +### 2026-04-27 UTC - R8/R12 mixed-pop audit generation start + +- User clarified that the temporal owner is shared multi-pop: single-pop and + mixed-pop are the same temporal engine cardinality, and single-pop is only the + one-bucket case. Do not leave the shared coverage guard as a permanently + failing manifest-only check. +- Live code finding: `benchmarks/fabric/audit.py` still types + `population_mode` as `single` and the canonical case generator emits no + mixed-pop rows. Public `fabric.Blueprint` already supports multi-population + declarations through explicit `Population.nodes`, so the audit runner can add + mixed-pop rows without calling private planner/runtime helpers. +- Current target slice: add mixed-pop case generation and execution to the + canonical audit runner through normal high-level Fabric calls. Mixed rows + must record `population_mode=mixed`, include P2/P12 ownership, use the same + reset/K/H/checkpoint request fields as single rows, and make + `--require-shared-temporal-coverage` pass only when both population modes are + present. +- This is audit/tooling enablement, not R12 performance closure. R12 remains + open until mixed-pop rows cover the April 21 T=1 shape grid, compare against + matched stack/MoE baselines, pass parity first, and meet throughput/memory + gates with shared temporal owner metadata. 
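+- Minimal sketch of the shared-coverage guard described above (manifest field names are
+  assumptions; the failure reason string is the one quoted from the dry-run):
+
+```python
+def check_shared_temporal_coverage(cases: list[dict]) -> str | None:
+    """Fail unless the manifest contains both single- and mixed-population cases."""
+    modes = {case.get("population_mode") for case in cases}
+    if not {"single", "mixed"} <= modes:
+        # Reason string quoted from the audit-runner entry above.
+        return "shared_temporal_owner_requires_single_and_mixed_population_coverage"
+    return None
+
+assert check_shared_temporal_coverage([{"population_mode": "single"}]) is not None
+assert check_shared_temporal_coverage(
+    [{"population_mode": "single"}, {"population_mode": "mixed"}]
+) is None
+```
+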
+- Implemented the slice by adding `--population-modes`, mixed-pop manifest + generation, public-Blueprint mixed Fabric model construction with explicit + `Population.nodes`, and mixed execution through the same high-level + `run_sequence_case` wrapper. The runner now fails + `--require-shared-temporal-coverage` for single-only manifests and passes when + both `single` and `mixed` cases are present. +- Current-code smoke proof: + `CUDA_VISIBLE_DEVICES=0 uv run python -m benchmarks.fabric.run_audit --plan smoke --out-dir /tmp/redo_fixmaass_shared_smoke_actual --population-modes single,mixed --require-shared-temporal-coverage --warmup 0 --iterations 1` + passed with two cases. Both rows reported `flat_bucket_sequence_surface`, + `temporal_plan_bucket_identity=["flat_bucket_identity"]`, current forward + owner `python_autograd_scan`, and target owner `cuda_temporal_superop`. + This confirms shared audit coverage and also keeps R3/R4/R12 open because the + CUDA temporal superop is not yet the active owner. +- Validation: + `uv run pytest -q tests/test_fabric_audit_runner.py tests/test_fabric_benchmark_suite_common.py -n0` + passed, `uv run ruff check ...` passed, + `uv run ruff format --check ...` passed, and `git diff --check` passed. + +### 2026-04-27 UTC - R12 canonical mixed stack-baseline start + +- Live code finding: the repo still has a separate + `benchmarks/run_fabric_mixed_population_audit.py` script that compares + mixed-pop Fabric to a mixed Axon+sLSTM stack, but the canonical + `benchmarks.fabric.run_audit` runner does not yet include that comparison. +- Active invariant: R12 audit evidence must live in the single canonical Fabric + audit entrypoint and must remain a high-level API consumer. The audit may + build matched Fabric and stack models, but it must not call planner/runtime + private helpers or choose backend temporal policy. +- Current target slice: move the reusable mixed-stack model construction and + param matching into `benchmarks/fabric_suite_common.py`, then attach a + same-parameter mixed stack baseline to canonical mixed-pop audit results. + This is still not R12 closure; it only makes the required Fabric-vs-stack + score evidence available in the correct artifact. +- Implementation note: the first smoke run exposed a bad comparison before + commit. Mixed Fabric was being compared against a much larger stack because + the matcher used approximate backbone counts while the measured audit models + include the wrapper head and exact Fabric shared-runtime parameters. The slice + now uses sequence-model parameter formulas for Fabric and mixed stack matches, + fixes the Fabric shared-runtime parameter formula, and adds a mixed-stack gate + that fails when the baseline is not parameter matched within 5 percent. +- Current-code audit finding: after correcting parameter matching, the + canonical mixed rows are valid but the existing backend does not yet pass the + mixed Fabric-vs-stack speed gate on small smoke rows. The corrected smoke row + matched Fabric `1,008,000` params against mixed stack `958,858` params + (`mixed_stack_param_error=0.04875`) and reported + `mixed_fabric_stack_ratio=0.194` at `B=1`. A more realistic `B=1024` T=1 + row still reported `mixed_fabric_stack_ratio=0.556`. This is useful failure + evidence for R12/R3/R4, not audit-script closure. 
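+- The corrected parameter-match gate reduces to a small calculation. The sketch below reproduces
+  the quoted smoke-row numbers, assuming the error is taken relative to the Fabric parameter
+  count; the function name and tolerance argument are placeholders.
+
+```python
+def mixed_stack_param_gate(
+    fabric_params: int, stack_params: int, tolerance: float = 0.05
+) -> tuple[float, bool]:
+    """Relative parameter mismatch between the Fabric model and the mixed-stack baseline."""
+    error = abs(fabric_params - stack_params) / fabric_params
+    return error, error <= tolerance
+
+# Smoke-row values quoted above: 1,008,000 Fabric params vs 958,858 mixed-stack params.
+error, ok = mixed_stack_param_gate(1_008_000, 958_858)
+assert ok and round(error, 5) == 0.04875
+```
+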
+- Current backend owner on these rows remains `python_autograd_scan` with + target owner `cuda_temporal_superop`; shared mixed-pop evidence must not close + until the temporal owner is the fully shared CUDA superop path. +- Validation: + `uv run pytest -q tests/test_fabric_audit_runner.py tests/test_fabric_benchmark_suite_common.py -n0` + passed, `uv run ruff check ...` passed, + `uv run ruff format --check ...` passed, + `CUDA_VISIBLE_DEVICES=0 uv run python -m benchmarks.fabric.run_audit --plan smoke --out-dir /tmp/redo_fixmaass_mixed_stack_smoke_v2 --population-modes mixed --warmup 0 --iterations 1` + completed and recorded the expected parameter-matched mixed-stack gate + failure, and + `CUDA_VISIBLE_DEVICES=0 uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir /tmp/redo_fixmaass_mixed_stack_t1_b1024 --population-modes mixed --families slstm --sizes 1m --modes forward --batches 1024 --seq-lens 1 --inner-steps 1 --hidden-sizes 32 --warmup 0 --iterations 1` + completed with the same corrected baseline fields. + +### 2026-04-27 UTC - R3/R4 explicit physical scan schedule start + +- Active invariant: the shared temporal engine is a T*K physical scan with an + emission plan. Python scan bodies are transitional, but they must already + consume the same explicit schedule contract that the CUDA temporal superop + will eventually own. +- Live code finding: `_TemporalBucketSequenceFunction.forward` still computes + `outer_step_index, inner_step_index = divmod(step_index, inner_steps)` inside + the scan loop and hard-codes outer-step emission with + `inner_step_index == inner_steps - 1`. This is semantically correct for scalar + K but not a mature schedule abstraction, and it is a bad place to keep + per-timestep K future logic. +- Current target slice: introduce a small physical scan schedule helper for + scalar K that emits physical step, outer step, inner step, boundary reset, and + output emission metadata. Use it in the temporal autograd scan path without + changing behavior, and add unit coverage for K=1, K>1, and K=128 schedule + shape. +- Implemented `temporal_scan.py` with `TemporalPhysicalScanSchedule` and + `TemporalPhysicalScanStep`, then changed the physical temporal autograd + forward path to consume `scan_schedule.steps` for reset and emission + decisions. The step contract separates boundary/state reset timing from + transition reset timing so the April 24-26 reset parity risk is explicit in + the schedule. This is not CUDA-superop closure; it is the transitional + schedule contract the CUDA owner must replace/consume. +- Validation: + `uv run pytest -q tests/test_fabric_backend_plan.py::test_scalar_temporal_scan_schedule_marks_outer_emissions_and_resets tests/test_fabric_backend_plan.py::test_scalar_temporal_scan_schedule_covers_k128_ceiling -n0` + passed, + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_uses_high_level_reset_parity tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_shared_mixed_population_reset_parity -n0` + passed with reset-present and reset-absent coverage, and + the audit manifest now has an explicit K sweep guard that includes K=128 for + both single and mixed population modes. `uv run ruff check ...` / + `uv run ruff format --check ...` passed. 
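+- Simplified stand-in for the scalar schedule contract above (`TemporalPhysicalScanSchedule` and
+  `TemporalPhysicalScanStep` are the real names; the fields, flags, and helper below are
+  assumptions chosen only to show the shape of the contract):
+
+```python
+from dataclasses import dataclass
+
+@dataclass(frozen=True)
+class ScanStep:                      # placeholder for TemporalPhysicalScanStep
+    physical_index: int
+    outer_index: int
+    inner_index: int
+    is_outer_boundary: bool          # where boundary/state reset timing applies (assumed flag)
+    emits_output: bool               # outer emission per the emission plan (assumed flag)
+
+def scalar_scan_schedule(outer_time_steps: int, inner_steps: int) -> list[ScanStep]:
+    """Scalar-K schedule over the flattened T*K stream; one emission per outer step."""
+    steps = []
+    for physical in range(outer_time_steps * inner_steps):
+        outer, inner = divmod(physical, inner_steps)
+        steps.append(ScanStep(physical, outer, inner, inner == 0, inner == inner_steps - 1))
+    return steps
+
+schedule = scalar_scan_schedule(outer_time_steps=2, inner_steps=128)   # K=128 ceiling shape
+assert len(schedule) == 256
+assert sum(step.emits_output for step in schedule) == 2
+assert schedule[0].is_outer_boundary and schedule[-1].emits_output
+```
+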
+ +### 2026-04-27 UTC - R9 benchmark package cleanup start + +- Active invariant: canonical Fabric audits live under `benchmarks/fabric/` and + use the high-level model API. Legacy top-level benchmark scripts can remain + temporarily, but the redo audit path should not keep adding Fabric-specific + logic to the top-level benchmark namespace. +- Current target slice: move the shared Fabric benchmark helpers into + `benchmarks/fabric/suite_common.py`, update the canonical audit runner and + tests to import from the package path, and leave `benchmarks/fabric_suite_common.py` + as a compatibility wrapper for older scripts until R15 cleanup deletes or + rewrites the legacy entrypoints. +- Validation: + `uv run pytest -q tests/test_fabric_audit_runner.py tests/test_fabric_benchmark_suite_common.py -n0` + passed, `uv run ruff check ...` passed, + `uv run ruff format --check ...` passed, and + `uv run python -m benchmarks.fabric.run_audit --plan smoke --dry-run --out-dir /tmp/redo_fixmaass_import_path_smoke --population-modes single,mixed --require-shared-temporal-coverage` + passed with two canonical dry-run cases. + +### 2026-04-27 UTC - R12/R13 mixed stack gate scope fix + +- Live audit finding before K128 smoke: the canonical mixed-pop runner attached + the same-parameter mixed stack baseline to every mixed case, including + `tk-scaling` rows where Fabric intentionally performs K internal scan steps. + That would incorrectly compare K>1 Fabric work against a 1x stack baseline. +- Fix: scope the mixed Fabric-vs-stack baseline gate to mixed T=1/K=1 rows. + R13 K/H rows still use April 21 T=1 and streaming sequence-loss references + plus K/T/H scaling criteria; they do not use the mixed stack speed gate. +- Follow-up bug found by actual K128 training smoke: temporal output backward + sliced output gradients by physical window indices. For K>1, output gradients + are indexed by emitted outer steps, not every physical T*K step, so + `T=1,K=128,H=64` could hand a zero-length grad-output window to output + backward. +- Fix: map output gradients from physical step to emitted outer step, create + zero grad-output entries for non-emitting physical steps, and handle terminal + output gradients as a one-step emission at the final outer step. +- Validation: + `CUDA_VISIBLE_DEVICES=0 uv run python -m benchmarks.fabric.run_audit --plan tk-scaling --out-dir /tmp/redo_fixmaass_k128_scope_train_smoke_v2 --population-modes mixed --families slstm --sizes 1m --modes forward_backward --batches 1 --seq-lens 1 --inner-steps 128 --hidden-sizes 8 --gradient-horizon-steps 64 --checkpoint-steps planner --warmup 0 --iterations 1` + passed and recorded no mixed-stack baseline, `inner_steps=128`, + `horizon=64`, and `checkpoint=64`. + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k128_backward_maps_outer_emission_gradients -n0` + passed against the PyTorch reference, the existing K>1 and H/reset parity + tests passed, and `uv run pytest -q tests/test_fabric_audit_runner.py -n0` + passed. + +### 2026-04-27 UTC - Substage closure and R3/R4 owner handoff + +- Closed substage: the R9 benchmark package cleanup slice is complete and + committed in `971097a`. The canonical Fabric audit helpers now live under + `benchmarks/fabric/`, and the top-level helper is only a migration wrapper. + Full R9 remains pending until every closure path is proven high-level API + only and legacy wrappers are deleted in R15. 
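+- Minimal sketch of the K>1 emitted-outer-step gradient mapping fix described in the mixed-stack
+  gate scope entry above (function and argument names are placeholders): output gradients are
+  indexed by emitted outer steps, non-emitting physical steps get no grad-output, and a terminal
+  loss is a one-step emission at the final outer step.
+
+```python
+import torch
+
+def grad_output_for_physical_step(
+    grad_outputs: torch.Tensor,   # [B, E, H]: gradients indexed by emitted outer step
+    physical_index: int,
+    inner_steps: int,
+    output_boundary: str,         # "sequence" or "terminal"
+    outer_time_steps: int,
+) -> torch.Tensor | None:
+    """Return the emitted-step gradient for one physical step, or None for non-emitting steps."""
+    outer, inner = divmod(physical_index, inner_steps)
+    if inner != inner_steps - 1:
+        return None                               # non-emitting physical step: no grad-output
+    if output_boundary == "terminal":
+        # Terminal loss: a single emission slot, mapped to the final outer step only.
+        return grad_outputs[:, 0] if outer == outer_time_steps - 1 else None
+    return grad_outputs[:, outer]                 # sequence loss: one slot per emitted outer step
+
+# T=1, K=128: only physical step 127 maps onto the single emitted gradient slot.
+g = torch.ones(2, 1, 8)
+assert grad_output_for_physical_step(g, 126, 128, "sequence", 1) is None
+assert grad_output_for_physical_step(g, 127, 128, "sequence", 1) is not None
+```
+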
+- Closed substage: the R13 K128 outer-emission gradient indexing bug is fixed + and committed in `1d0dd92`. The K128 smoke now runs through the high-level + model/loss/backward path and records planner-owned H/checkpoint metadata. + Full R13 remains pending until the full K=1..128, T, H, terminal, and + per-timestep audit matrix passes against April 21 references. +- Current open owner: R3/R4 temporal forward/backward. The next high-priority + slice is to remove remaining ad hoc physical-step mapping from the + transitional backward path, make schedule-owned output-gradient indexing cover + both per-timestep and terminal emissions, and add mixed-pop K>1 terminal-loss + parity. This is still transitional `python_autograd_scan`; it must not be + relabeled as `cuda_temporal_superop` until the fully owned CUDA temporal + superop exists and passes the audit gates. + +### 2026-04-27 UTC - R3/R4 schedule-owned backward emission mapping + +- Implemented the R3/R4 handoff slice: scalar physical-step mapping now lives in + `temporal_scan.py` through `scalar_temporal_scan_step`, and terminal vs + per-timestep output-gradient indexing goes through + `emitted_output_index_for_scan_step`. The transitional backward recompute, + output-gradient, and boundary-gradient accumulation paths now consume this + schedule contract instead of open-coded physical-step arithmetic. +- Added parity for the missing high-level terminal-loss surface: + mixed-population `T=3,K=2,H=2`, reset-present, `materialize_final_state=False`, + `output_boundary="terminal"`, normal model forward, external loss, and + `loss.backward()`. This guards the exact final outer-emission gradient mapping + used by terminal T*K audit rows. +- This improves the transitional Python scan semantics only. The active + metadata remains `python_autograd_scan` with target `cuda_temporal_superop`; + R3/R4 remain open until the fully owned CUDA temporal superop replaces the + Python scan/backward loops and passes the full parity/performance audit. +- Validation: + `uv run pytest -q tests/test_fabric_backend_plan.py -n0` passed, + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_terminal_loss_maps_final_outer_emission_gradient tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k128_backward_maps_outer_emission_gradients tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_uses_high_level_reset_parity tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_shared_mixed_population_reset_parity -n0` + passed, and `uv run ruff check ...` / `uv run ruff format --check ...` + passed on the touched files. + +### 2026-04-27 UTC - R3/R4 lazy physical scan schedule + +- Follow-up R3/R4 performance hygiene: `TemporalPhysicalScanSchedule` no longer + stores one Python step object per physical `T*K` scan step. It keeps only + `outer_time_steps` and `inner_steps`, exposes lazy `iter_steps()` and + `step_at()` helpers, and preserves the small-test `steps` compatibility + property for explicit assertions. +- This matters for the audit frontier because `T=16K,K=128` is 2,097,152 + physical steps. 
The transitional Python scan still loops, so this is not final + CUDA temporal-superop closure, but it removes an avoidable schedule-table + materialization that would make the frontier less representative. +- Validation: + `uv run pytest -q tests/test_fabric_backend_plan.py -n0` passed with a lazy + `T=16K,K=128` schedule assertion, + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_terminal_loss_maps_final_outer_emission_gradient tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k128_backward_maps_outer_emission_gradients tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_uses_high_level_reset_parity tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_shared_mixed_population_reset_parity -n0` + passed, and `uv run ruff check ...` / `uv run ruff format --check ...` + passed on the touched files. + +### 2026-04-27 UTC - R13 canonical terminal and per-timestep audit coverage + +- Audit gap found after the terminal-loss parity slice: the canonical + `tk-scaling` manifest defaulted every training row to `sequence` loss. That + meant R13 could not prove both required output materialization modes unless a + caller remembered to run a separate terminal-only command. +- Fix: `tk-scaling` now emits both `sequence` and `terminal` training output + boundaries by default for `forward_backward` cases. Forward-only cases stay at + one sequence-output row to avoid duplicate measurements. Case IDs now include + `losssequence` or `lossterminal`, and the summary records the boundaries + covered by the run. `--training-output-boundaries` can still narrow or + explicitly request boundaries for smoke/debug runs. +- Validation: + `uv run pytest -q tests/test_fabric_audit_runner.py -n0` passed, + `uv run python -m benchmarks.fabric.run_audit --plan tk-scaling --dry-run --out-dir /tmp/redo_fixmaass_tk_boundaries_dry --families slstm --sizes 1m --modes forward,forward_backward --batches 2 --seq-lens 4 --inner-steps 2 --hidden-sizes 8 --gradient-horizon-steps 2 --checkpoint-steps planner --population-modes mixed --warmup 0 --iterations 1` + wrote three cases: forward sequence, training sequence, and training terminal, + and + `CUDA_VISIBLE_DEVICES=0 uv run python -m benchmarks.fabric.run_audit --plan tk-scaling --out-dir /tmp/redo_fixmaass_tk_boundaries_smoke --families slstm --sizes 1m --modes forward_backward --batches 1 --seq-lens 2 --inner-steps 2 --hidden-sizes 8 --gradient-horizon-steps 2 --checkpoint-steps planner --population-modes mixed --training-output-boundaries sequence,terminal --warmup 0 --iterations 1` + ran both high-level training boundaries successfully. Both rows still record + transitional temporal owner `python_autograd_scan`, so this is audit coverage, + not R13 closure. + +### 2026-04-27 UTC - R4 provided-state gradient parity for terminal K>1 + +- Closed substage: the R13 canonical terminal/per-timestep audit coverage slice + is committed in `4a1dc4e`; it remains audit coverage only, not R13 closure, + because temporal owners still report `python_autograd_scan`. 
+- New R4 parity gap found while continuing the open temporal owner: high-level + mixed-pop `T=3,K=2,H=2`, reset-present, `output_boundary="terminal"`, and + `materialize_final_state=True` returned the terminal output but recorded the + backend execution as `sequence_output_boundary:all_steps` with + `sequence_output_contract:full_cells`. The cause was a fallback through + `forward_cells`, which plans full-cell sequence output and slices afterward. +- Fix: broadened the high-level readout path so flat-bucket CUDA executes + provided-state and materialized-final-state readout directly through + `execute_temporal_bucket_sequence` with the requested readout output contract + and output boundary. This keeps the active path under the shared flat-bucket + temporal executor and lets terminal metadata reflect the requested semantics. +- Added parity for provided initial-state gradients on that row. The test + compares CUDA and PyTorch outputs, materialized final state, input gradients, + provided state/carry gradients, and full parameter gradients. +- This is still transitional R4 work. It closes a semantic fallback leak, but + the active temporal forward/backward owner remains `python_autograd_scan` + until the CUDA temporal superop owns the scan and backward loops. +- Validation: + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_terminal_loss_propagates_provided_state_gradients -n0` + passed, + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_terminal_loss_maps_final_outer_emission_gradient tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_terminal_loss_propagates_provided_state_gradients tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k128_backward_maps_outer_emission_gradients tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_uses_high_level_reset_parity tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_shared_mixed_population_reset_parity tests/test_fabric_runtime.py::test_fabric_cuda_unmaterialized_final_state_preserves_sequence_output_math -n0` + passed, and `uv run ruff check ...` / `uv run ruff format --check ...` + passed on the touched files. + +### 2026-04-27 UTC - R3/R4 backend temporal owner seam started + +- Status answer for the current owner: the final CUDA temporal superop is still + not implemented. The real remaining R3/R4 backend work is to replace the + Python/autograd temporal scan and backward loops with a CUDA-owned temporal + superop over flat bucket identity. The current code is semantically stronger + than the April 21 baseline, but active owners still report + `python_autograd_scan`. +- Backend slice started: added a CUDA-dispatcher `TemporalScanDescriptor` in + `backend/cuda/execution/common.cuh` with explicit owner, physical scan length, + emission stride, terminal/sequence output mapping, and a reserved + `CudaTemporalSuperOp` owner value. The existing dispatcher now consumes this + descriptor for readout emission instead of local terminal-output lambdas. +- Runtime/audit metadata now records `launch_temporal_scan_owners`, + outer/inner/physical scan steps, emission count/stride, and output boundary. 
+ Legacy recurrence-surface dispatcher rows report `backend_host_loop`, while + flat-bucket temporal rows report the current transitional + `python_autograd_scan`. This prevents accidental relabeling as + `cuda_temporal_superop` before the actual CUDA owner exists. +- This is actual backend work but not closure: it creates the backend seam the + CUDA superop must take over next. R3/R4 remain open until temporal forward and + backward scan execution move out of Python host loops and parity/performance + gates pass. +- High-priority remaining work after this checkpoint: + R3/R4 implement the CUDA-owned temporal forward/reverse scan body; R2 remove + all population-name semantic bucket identity and Config-era population caps; + R9 finish audit cleanup and wrapper deletion; R10 expand strict parity; R11 + through R14 run April-21-referenced performance closures; R15 delete legacy + routes; R16 write final closure. +- Validation: + `CUDA_VISIBLE_DEVICES=0 uv run python - <<'PY' ... dispatcher_cuda._load_ext() ... PY` + loaded `fabric_dispatcher_cuda`, + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_terminal_loss_maps_final_outer_emission_gradient tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_terminal_loss_propagates_provided_state_gradients -n0` + passed, + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_uses_temporal_bucket_sequence -n0` + passed, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_scalar_temporal_scan_schedule_marks_outer_emissions_and_resets tests/test_fabric_backend_plan.py::test_scalar_temporal_scan_schedule_covers_k128_ceiling tests/test_fabric_backend_plan.py::test_scalar_temporal_scan_schedule_does_not_store_large_tk_step_table tests/test_fabric_backend_plan.py::test_scalar_temporal_scan_schedule_maps_sequence_and_terminal_emissions -n0` + passed, + `uv run pytest -q tests/test_fabric_execution_imports.py::test_fabric_cuda_execution_sources_keep_scalable_backend_contract -n0` + passed, and ruff/check-format plus `git diff --check` passed on touched + files. +- Reset parity note: while checking broader reset coverage, the existing row + `tests/test_fabric_runtime.py::test_fabric_cuda_slstm_sequence_backend_matches_pytorch_reference_forward[True-16-0-microkernel-sequence_major]` + failed on both this worktree and a detached clean worktree at `8cd2468`, so + it is not introduced by this backend seam patch. This remains an open reset + parity issue for R3/R4 and should be handled by the temporal owner, not hidden + by audit filtering. + +### 2026-04-27 UTC - R3/R4 reset parity root cause + +- Active owner: R3/R4 temporal forward/backward reset semantics. +- Invariant: a provided reset tensor is runtime data and must be honored by the + shared temporal step even when the caller did not precompute per-step reset + flags. Reset flags are only an optional skip/cache hint, not permission to + ignore the reset mask. +- Reproduced the open reset row: + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q 'tests/test_fabric_runtime.py::test_fabric_cuda_slstm_sequence_backend_matches_pytorch_reference_forward[True-16-0-microkernel-sequence_major]' -n0` + fails before the fix with output mismatch on reset-affected timesteps. 
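+- The invariant above can be made concrete with a small sketch of the tri-state reset-hint
+  contract the fix below adopts (names, shapes, and the zero-reset semantics here are
+  placeholders): `None` means "no precomputed flags, still honor the supplied mask"; only an
+  explicit `False` skips.
+
+```python
+import torch
+
+def apply_reset_if_needed(
+    state: torch.Tensor,
+    reset_mask: torch.Tensor | None,   # [B] boolean mask supplied as runtime data
+    has_resets: bool | None,           # optional skip hint, not permission to ignore the mask
+) -> torch.Tensor:
+    """Honor a supplied reset mask unless the caller explicitly asserts there are no resets."""
+    if reset_mask is None or has_resets is False:
+        return state
+    return torch.where(reset_mask.unsqueeze(-1), torch.zeros_like(state), state)
+
+state = torch.ones(2, 4)
+mask = torch.tensor([True, False])
+# has_resets=None (no precomputed host flags) must still apply the supplied mask.
+assert apply_reset_if_needed(state, mask, has_resets=None)[0].abs().sum() == 0
+# has_resets=False is the explicit skip hint for known-empty reset steps.
+assert apply_reset_if_needed(state, mask, has_resets=False)[0].abs().sum() == 4
+```
+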
+- Debugged the high-level readout route and found it calls + `execute_temporal_bucket_sequence(..., step_reset_flags=None, ...)`; the + shared `_forward_stream_step` reset block currently runs only for graph + capture or `has_resets is True`, so this route silently skips a supplied reset + mask. Diff localization matched the missed reset semantics: batch 0 diverged + from timestep 1 onward and batch 1 diverged only on the later timestep-3 + reset, while a timestep-0 reset from a fresh zero state was numerically + unchanged. +- Fix: `_forward_stream_step` now applies a supplied reset mask unless the + caller explicitly passes `has_resets=False`. This preserves the skip hint for + known empty reset steps but makes `has_resets=None` safe for high-level routes + that pass reset tensors without precomputed host flags. +- Closed substage: the open reset parity row is fixed for the transitional + shared temporal path. R3/R4 remain open for the CUDA-owned temporal superop, + but this reset issue is no longer blocking parity coverage. +- Validation: + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q 'tests/test_fabric_runtime.py::test_fabric_cuda_slstm_sequence_backend_matches_pytorch_reference_forward[True-16-0-microkernel-sequence_major]' 'tests/test_fabric_runtime.py::test_fabric_cuda_slstm_sequence_backend_matches_pytorch_reference_forward[False-16-0-microkernel-sequence_major]' -n0` + passed, + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_uses_high_level_reset_parity tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_shared_mixed_population_reset_parity tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_terminal_loss_maps_final_outer_emission_gradient tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_terminal_loss_propagates_provided_state_gradients tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k128_backward_maps_outer_emission_gradients -n0` + passed, `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py -k 'reset or resets' -n0` + passed, and ruff/check-format passed on the touched runtime/test files. + +### 2026-04-27 UTC - R2/R7 flat bucket cardinality and identity cleanup + +- Active owner moved to the next high-priority substrate cleanup after the reset + substage: R2 flat bucket identity and R7 Config-era population cardinality. +- Root issue: `Config` still rejected more than two populations, even though the + shared temporal bucket executor already iterates active population bindings. + This was a Config-era cap, not a backend substrate limit. +- Root identity invariant: `flat_bucket_identity` is semantic bucket identity, + not a parameter/state binding label. Binding slots may remain available for + state/parameter ownership and debugging, but they must not be embedded in the + flat bucket identity used to describe the lowered bucket kind. +- Implemented slice: + removed the `len(cell_populations) > 2` validation cap, and split + `TemporalFlatBucket.binding_identity` from `flat_bucket_identity`. The latter + now records cell type, cell kind, state schema, and public schema only. +- Added a three-population flat-bucket planning test with two sLSTM bindings + and one Axon binding. This proves more than two binding populations can build + and produce the same shared temporal bucket plan shape without user names or + binding slots leaking into semantic flat bucket identity. 
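+- Minimal sketch of the identity split above (the field set matches the log: cell type, cell
+  kind, state schema, public schema; the concrete schema strings are illustrative only):
+
+```python
+from dataclasses import dataclass
+
+@dataclass(frozen=True)
+class FlatBucketIdentity:
+    """Semantic bucket identity only: no user population name, no parameter-binding slot."""
+    cell_type: str
+    cell_kind: str
+    state_schema: tuple[str, ...]
+    public_schema: tuple[str, ...]
+
+@dataclass(frozen=True)
+class BindingIdentity:
+    """Transitional handle for state/parameter TensorDict lookup and debugging only."""
+    binding_name: str
+    binding_slot: int
+
+# Two bindings of the same lowered bucket kind share one semantic identity.
+identity = FlatBucketIdentity("slstm", "recurrent", ("h", "c"), ("y",))   # illustrative entries
+pop_a = (BindingIdentity("cortex_a", 0), identity)
+pop_b = (BindingIdentity("cortex_b", 1), identity)
+assert pop_a[1] == pop_b[1]          # same semantic bucket kind
+assert pop_a[0] != pop_b[0]          # different binding handles, never part of the identity
+```
+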
+- Added a CUDA high-level forward parity test for the same three-population + binding setup with reset-present `T=2`. It compares CUDA shared temporal + output/state against PyTorch and asserts the flat-bucket temporal sequence + route, not a single-bucket executor. +- Validation: + `uv run pytest -q tests/test_fabric_backend_plan.py::test_temporal_bucket_plan_exposes_flat_bucket_identity tests/test_fabric_backend_plan.py::test_temporal_bucket_plan_accepts_more_than_two_binding_populations -n0` + passed, + `uv run pytest -q tests/test_fabric_backend_plan.py -n0` passed, + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_sequence_supports_three_population_bindings -n0` + passed, + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_shared_mixed_population_reset_parity -n0` + passed, `uv run pytest -q tests/test_fabric_public_api.py -n0` passed, and + ruff/check-format passed on the touched files. + +### 2026-04-27 UTC - R2/R8 planner cache uses flat bucket identity + +- Continued the same substrate cleanup into planner policy: executable plan + cache keys must use flat bucket identity plus execution-relevant shape/topology + facts, not parameter binding slots or user population labels. +- Implemented `FabricBucket.planner_signature`, which excludes + `parameter_binding` but includes receiver count, degree range, dimensions, + delay depth, stencil template, sharing pattern, receiver kind, transition + signature, and sparse-overlay status. +- Updated forward and backward plan-cache keys to use `planner_signature`. + Cached plans are rebound with the current `bucket_id` when reused, so sharing + the executable policy does not corrupt downstream bucket attribution. +- Added a planner regression where two buckets have different population names, + population indices, and parameter bindings but identical planner signatures. + The forward planner now records one miss and one hit, the backward receiver + and sender phases record two misses and two hits, and all returned plan bucket + ids remain `(0, 1)`. +- Validation: + `uv run pytest -q tests/test_fabric_backend_plan.py::test_planner_cache_keys_flat_bucket_identity_not_binding_slot -n0` + passed, + `uv run pytest -q tests/test_fabric_backend_plan.py -n0` passed, and + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_sequence_supports_three_population_bindings tests/test_fabric_runtime.py::test_fabric_cuda_single_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_uses_temporal_bucket_sequence -n0` + passed. + +### 2026-04-27 UTC - R3/R4 terminal emission materialization owner + +- Closed substage before moving on: the R2/R8 planner-cache flat-bucket + identity slice is complete and committed in `3926d95`. R2 remains open until + all remaining semantic bucket identity cleanup and legacy attribution removal + are finished. +- Active owner returns to R3/R4 temporal forward/backward. Invariant: terminal + output materialization is a temporal engine emission decision, not a + high-level caller slice after every timestep has been carried through Python. 
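+- The R2/R8 planner-cache behavior above reduces to the following pattern (signature fields and
+  plan contents are placeholders): identical planner signatures share one cached executable
+  policy, and reuse rebinds the plan to the caller's bucket id so attribution stays correct.
+
+```python
+from dataclasses import dataclass, replace
+
+@dataclass(frozen=True)
+class ExecutablePlan:
+    bucket_id: int
+    policy: str
+
+class PlanCache:
+    """Cache keyed by a semantic planner signature; reused plans are rebound to the caller's bucket id."""
+    def __init__(self) -> None:
+        self._plans: dict[tuple, ExecutablePlan] = {}
+        self.misses = 0
+        self.hits = 0
+
+    def plan_for(self, planner_signature: tuple, bucket_id: int) -> ExecutablePlan:
+        cached = self._plans.get(planner_signature)
+        if cached is None:
+            self.misses += 1
+            cached = ExecutablePlan(bucket_id=bucket_id, policy="build_from_signature")
+            self._plans[planner_signature] = cached
+            return cached
+        self.hits += 1
+        return replace(cached, bucket_id=bucket_id)   # shared policy, rebound attribution
+
+cache = PlanCache()
+signature = ("slstm", "recurrent", 8, 8, 1, "dense")   # illustrative signature fields only
+a = cache.plan_for(signature, bucket_id=0)
+b = cache.plan_for(signature, bucket_id=1)             # same signature, different binding
+assert (cache.misses, cache.hits) == (1, 1)
+assert (a.bucket_id, b.bucket_id) == (0, 1)
+```
+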
+- Live issue found: the transitional flat temporal executor recorded + `sequence_output_materialization:terminal_step_only`, but the non-grad step + loop and physical autograd forward still materialized all outer emissions for + terminal calls. `Runtime.forward_output_cells_for_readout` then sliced + `[:, -1:]`, hiding the backend ownership violation. +- Current implementation slice: + pass `output_boundary` into the physical temporal autograd function, emit only + the final outer timestep for terminal output, make the non-grad flat temporal + step loop allocate a one-step output buffer for terminal calls, and replace + the high-level flat-route slice with a contract assertion that terminal flat + execution returns exactly one timestep. +- Added high-level CUDA/PyTorch parity coverage for mixed-pop `T=4,K=2`, + reset-present, `materialize_final_state=False`, and terminal output in + inference semantics. This guards the non-grad flat temporal loop while keeping + the user path ordinary `model(...)`. +- Validation: + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_temporal_terminal_output_materializes_only_final_step_in_executor tests/test_fabric_runtime.py::test_fabric_cuda_terminal_output_boundary_matches_sequence_last_step tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_terminal_loss_maps_final_outer_emission_gradient tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_terminal_loss_propagates_provided_state_gradients -n0` + passed, + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k128_backward_maps_outer_emission_gradients tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_uses_high_level_reset_parity tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_shared_mixed_population_reset_parity tests/test_fabric_runtime.py::test_fabric_cuda_unmaterialized_final_state_preserves_sequence_output_math -n0` + passed, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_scalar_temporal_scan_schedule_marks_outer_emissions_and_resets tests/test_fabric_backend_plan.py::test_scalar_temporal_scan_schedule_covers_k128_ceiling tests/test_fabric_backend_plan.py::test_scalar_temporal_scan_schedule_does_not_store_large_tk_step_table tests/test_fabric_backend_plan.py::test_scalar_temporal_scan_schedule_maps_sequence_and_terminal_emissions -n0` + passed, and ruff/check-format plus `git diff --check` passed on touched + Python files. + +### 2026-04-27 UTC - R2/R3 stale legacy owner marker removal + +- Closed substage before moving on: terminal emission materialization ownership + is complete and committed in `a2df6d2`. +- Active invariant: single-pop and mixed-pop are cardinalities of the same flat + temporal engine. A single-pop `K=1` row must not carry a legacy recurrence + population owner just because older recurrence-surface metadata used one. +- Live issue found: the planner always selects + `selected_implementation="flat_transition_buckets"` for the supported CUDA + sequence route, but still filled + `TemporalExecutorPlan.legacy_recurrence_population_name` for single-pop + `K=1`. 
Runtime currently ignores that field because the route is not a + legacy cell recurrence surface, but audit metadata could still make the row + look partially legacy-owned. +- Current implementation slice: + set `legacy_recurrence_population_name=None` for the flat temporal route and + remove the stale helper that inferred a legacy owner from population + cardinality. The single-pop K=1 planner test now asserts flat transition + buckets, flat bucket identity, no legacy recurrence population, and target + owner `cuda_temporal_superop`. +- Validation: + `uv run pytest -q tests/test_fabric_backend_plan.py::test_fabric_temporal_execution_plan_keeps_single_pop_k1_on_flat_temporal_owner tests/test_fabric_backend_plan.py::test_fabric_temporal_execution_plan_records_schedule_and_emission tests/test_fabric_backend_plan.py::test_fabric_temporal_execution_plan_records_fresh_multi_population_cache_policy -n0` + passed, + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_sequence_supports_three_population_bindings tests/test_fabric_runtime.py::test_fabric_cuda_temporal_terminal_output_materializes_only_final_step_in_executor -n0` + passed, `uv run pytest -q tests/test_fabric_backend_plan.py -n0` passed, and + ruff/check-format plus `git diff --check` passed on touched files. + +### 2026-04-27 UTC - R3 active legacy sequence-surface selector disabled + +- Closed substage before moving on: stale legacy owner metadata removal is + complete and committed in `2d4dc2e`. +- Active invariant: supported CUDA Fabric sequence execution chooses the shared + flat temporal bucket route directly. Runtime should not ask for a + single-population cell recurrence surface before falling through to the flat + route. +- Live issue found: `forward_cells` and `forward_output_cells_for_readout` + still called `_select_backend_sequence_surface` even though planner-selected + CUDA routes are `flat_transition_buckets`. The selector returned `None`, so + this was dead active-route logic, but it preserved a misleading legacy branch + in the forward/readout path. +- Current implementation slice: + remove the active `_select_backend_sequence_surface` calls and delete the + selector method. `_select_output_cells_stream_backend_population` remains + inert in this historical checkpoint and returns `None`; final cleanup must + delete it rather than preserving it as a legacy bridge. Runtime tests now + assert supported CUDA routes use flat transition buckets and have no legacy + cell recurrence surface. 
+- Validation: + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_supported_cuda_route_uses_flat_temporal_bucket_executor tests/test_fabric_runtime.py::test_fabric_supported_cuda_route_has_no_legacy_cell_surface -n0` + passed, + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_training_surface_uses_flat_bucket_route_without_single_population_selector tests/test_fabric_runtime.py::test_fabric_cuda_single_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_temporal_terminal_output_materializes_only_final_step_in_executor -n0` + passed, and ruff/check-format plus `git diff --check` passed on touched + files. + +### 2026-04-27 UTC - R2/R3 legacy cell recurrence route identity removed + +- Closed substage before moving on: active legacy sequence-surface selector + removal is complete and committed in `df19d75`. +- Active invariant: the temporal planner should expose only the shared flat + temporal route for supported CUDA Fabric sequence execution. A dead + `cell_recurrence_surface` route kind is still a legacy execution identity, + even if no active call reaches it. +- Current implementation slice: + remove `cell_recurrence_surface`, `uses_cell_recurrence_surface`, + `legacy_cell_recurrence_surface`, and + `legacy_population_recurrence_identity` from planner/type metadata. The + non-flat readout tiling fallback now treats backend sequence surface support + as false, while supported CUDA routes continue to use flat transition buckets. +- Validation: + `uv run pytest -q tests/test_fabric_backend_plan.py tests/test_fabric_runtime.py::test_fabric_supported_cuda_route_uses_flat_temporal_bucket_executor tests/test_fabric_runtime.py::test_fabric_supported_cuda_route_has_no_legacy_cell_surface -n0` + passed, + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_training_surface_uses_flat_bucket_route_without_single_population_selector tests/test_fabric_runtime.py::test_fabric_cuda_single_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_sequence_supports_three_population_bindings -n0` + passed, `rg` finds the removed legacy route identifiers only in this log + entry, and ruff/check-format plus `git diff --check` passed on touched files. + +### 2026-04-27 UTC - R3 active-output identity cleanup start + +- Closed substage before moving on: legacy cell recurrence route identity + removal is complete and committed in `93d1917`. +- Active owner remains R3 shared temporal forward. Invariant: active-output + execution is an active-region/materialization policy inside the shared flat + temporal scan, not a separate temporal scan implementation or owner. +- Live issue found: the readout-only fresh-state path still records + `launch_scan_implementations=("active_output_window",)` and + `physical_op_executors` includes `active_window_static_buckets`. That makes + the row look like a sibling execution identity even though the planner should + expose one flat temporal scan over flat bucket identity. 
+- Current target slice: reclassify active-output metadata under + `flat_bucket_temporal_scan`, keep the readout dependency window only as + active-region/workspace policy, and add a regression test that the high-level + readout path no longer reports `active_output_window` as the scan + implementation. +- User correction received during this slice: backend throughput work has + priority over cleanup. Metadata/route cleanup is not closure evidence unless + the actual backend temporal engine exists and passes parity/performance. +- Priority reset: continue R3/R4 by moving hot temporal scan work into + backend-owned CUDA execution first, then use cleanup only after it unblocks or + validates that backend path. +- Actual backend slice started: the K>1 boundary temporal step path still used + Python population update logic inside the inner K loop after CUDA message + computation. The current code change routes CUDA flat-bucket rows through the + backend-order transition bucket executor for each internal K step, so + transition math and state updates stay in the Fabric CUDA flat-bucket backend + instead of the population update fallback. This is still not the final + one-shot CUDA temporal superop; it is a concrete reduction of Python-owned + inner-loop work while preserving the high-level model API. +- User reinforced that the shared temporal engine is the crux of REDO_FIXMASS. + The implementation priority is now explicitly: build/migrate the shared + temporal engine and throughput-critical CUDA owner first; use cleanup, + metadata relabeling, and route deletion only after they unblock or validate + the backend engine. +- Updated cortical skills generically so future agents do not misprioritize: + `cb.fabric-backend-boundaries`, `cb.fabric-performance-loop`, and + `cb.fabric-scaling-horizon` now state that shared temporal-engine/backend + throughput ownership comes before cleanup, and metadata-only changes do not + close backend stages. +- Implemented the backend-focused R3 slice: + `run_shared_temporal_bucket_forward_scan(...)` now owns scalar T*K scan + iteration, reset mapping, emission/materialization, recurrent KV carry reuse, + artifact collection, and final-state materialization for flat temporal bucket + forward. The physical autograd forward path calls this shared engine instead + of carrying its own duplicated scan loop. +- No-grad K>1 flat-bucket inference now also enters the shared temporal forward + scan, so the same scan/emission/reset path feeds both inference and the + physical backward path. The shared engine executes transition updates through + backend-order flat transition buckets and records + `forward_transition=backend_order_transition_buckets`. +- This is still transitional: `launch_temporal_scan_owners` remains + `python_autograd_scan` because the one-shot CUDA temporal superop has not yet + replaced the host loop. R3/R4 remain open until the scan loop itself moves + into the CUDA temporal superop and throughput audits pass. 
+- Validation: + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_temporal_terminal_output_materializes_only_final_step_in_executor tests/test_fabric_runtime.py::test_fabric_cuda_single_population_k_gt1_uses_temporal_bucket_sequence tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_uses_temporal_bucket_sequence -n0` + passed, + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_terminal_loss_maps_final_outer_emission_gradient tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_terminal_loss_propagates_provided_state_gradients tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k128_backward_maps_outer_emission_gradients -n0` + passed, + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_uses_high_level_reset_parity tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_shared_mixed_population_reset_parity tests/test_fabric_runtime.py::test_fabric_cuda_unmaterialized_final_state_preserves_sequence_output_math tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_readout_closed_region_matches_pytorch_reference -n0` + passed, + `CUDA_VISIBLE_DEVICES=0 uv run pytest -q tests/test_fabric_runtime.py -k 'k_gt1 or terminal_output_boundary or temporal_terminal_output or flat_temporal_horizon' -n0` + passed, + `uv run pytest -q tests/test_fabric_backend_plan.py::test_scalar_temporal_scan_schedule_marks_outer_emissions_and_resets tests/test_fabric_backend_plan.py::test_scalar_temporal_scan_schedule_covers_k128_ceiling tests/test_fabric_backend_plan.py::test_scalar_temporal_scan_schedule_does_not_store_large_tk_step_table tests/test_fabric_backend_plan.py::test_scalar_temporal_scan_schedule_maps_sequence_and_terminal_emissions -n0` + passed, `uv run pytest -q tests/test_fabric_execution_imports.py -n0` + passed, and ruff/check-format plus `git diff --check` passed on touched code. + +### 2026-04-27 UTC - R3/R4 CUDA temporal superop implementation start + +- Closed substage before moving on: shared flat-bucket forward scan + consolidation is complete and committed in `4809c4c`. +- Active owner remains R3/R4 temporal forward/backward. User explicitly + directed that the next work must be the CUDA temporal superop/kernel path, not + more cleanup. +- Current invariant: the shared temporal engine is not closed while the scan + loop is owned by Python. The next backend slice must move actual temporal + scan work into a Fabric CUDA-owned execution path and prove parity before + metadata or cleanup can count. +- Current target slice: implement the first CUDA temporal-superop forward owner + for a narrow flat-bucket row, route it from the shared temporal engine without + creating a public/benchmark knob, and keep unsupported rows fail-closed or on + the explicitly transitional shared Python scan until the full mixed-pop CUDA + superop is implemented. + +### 2026-04-27 UTC - Skill correction: kernel work is mandatory when it is the owner + +- User correction accepted as durable REDO_FIXMASS rule: if the high-priority + backend owner is a missing CUDA temporal kernel/superop, implementing that + kernel is the next owner. Planner-only edits, metadata relabeling, benchmark + organization, route polish, or cleanup cannot substitute for moving the hot + temporal scan/K/H/backward work into backend-owned CUDA execution. 
+- Updated `skills/cb.fabric-backend-boundaries/SKILL.md`, + `skills/cb.fabric-performance-loop/SKILL.md`, and + `skills/cb.fabric-scaling-horizon/SKILL.md` to make this explicit and generic. + +### 2026-04-27 UTC - Rejected narrow cell-specific CUDA temporal probe + +- During R3/R4 CUDA temporal-superop implementation, a tempting first slice was + an Axon-only diagonal recurrence kernel path. Rejected before implementation: + it would encode cell-family logic in the backend temporal owner and violate + the Fabric rule that cell differences are `fabric.cuda.nn` declarations and + low-level op buckets, not backend route identities. +- The active implementation target is therefore generic flat-bucket temporal + execution: the shared temporal engine may specialize on declared physical op + capabilities and bucket metadata, but not on population names, cell names, or + native cell ids. Single-population rows are one flat transition bucket; mixed + rows are multiple flat transition buckets under the same CUDA temporal owner. + +### 2026-04-27 UTC - GPU allocation constraint for audits + +- Durable audit constraint from user: my parallel experiments and benchmark + runs must use only GPUs 0-4. Use `CUDA_VISIBLE_DEVICES=0,1,2,3,4` for + multi-GPU audit workers and single devices within that range for local smoke + tests. Do not schedule my work on GPUs 5-7 unless the user explicitly changes + this constraint. +- Durable compile-cache constraint from user: do not use the global/default CUDA + extension cache because it can hang compiles. A stable project/run cache such + as `/tmp/cortical_torch_ext_$USER_redo_fixmass` is the default; create + separate cache roots only for parallel workers or when a stale compile is + suspected. It is not necessary to create a brand-new cache for every single + experiment. + +### 2026-04-27 UTC - R3/R4 compiled temporal dispatcher slice + +- Implemented the first backend-owned temporal forward slice for the shared + flat-bucket route. The CUDA dispatcher now accepts an explicit temporal scan + shape: `temporal_outer_time_steps`, `temporal_inner_steps`, and physical + `T=outer*K`. It materializes sequence outputs at outer-step cadence and + terminal outputs at the final physical step. +- Split reset semantics in the compiled dispatcher request: message/boundary + reset masks zero recurrent message sources only at inner step 0, while + transition reset masks apply at every physical inner step. This directly + targets the reset parity failures that occurred during April 24-26. +- Routed eligible high-level `model.forward(..., k>1)` inference rows through + the compiled dispatcher when the row is a single flat transition bucket, + CUDA float32, no edge delays, output-cell contract, and no autograd path. This + is recorded as `compiled_flat_bucket_temporal_scan:backend_host_loop` and + `single_flat_bucket_eligibility_slice`; it is not final R3/R4 closure because + mixed flat buckets and the device-side temporal superop are still open. +- Kept the owner label honest: launch metadata reports `backend_host_loop`, not + `cuda_temporal_superop`. R3/R4 remain open until the physical scan loop moves + into the CUDA temporal superop and throughput audits pass. +- Added high-level API parity coverage: + `test_fabric_cuda_single_flat_bucket_k_gt1_inference_uses_compiled_temporal_dispatcher` + exercises SLSTM and Axon through `model.forward`, K=2, reset rows, outer-step + output emission, and final-state materialization against the PyTorch backend. 
+- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_$USER_redo_fixmass`: + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_flat_bucket_k_gt1_inference_uses_compiled_temporal_dispatcher -n0` + passed 2 cases in 104.18s; the broader K>1 group passed 7 cases in 81.16s. + A K=1 sequence-surface regression group passed 3 cases in 4.68s with the same + cache and GPU constraint. + +### 2026-04-27 UTC - R3 K=1 base-case dispatcher widening + +- Closed substage before moving on: the compiled dispatcher K>1 eligibility + slice is complete and committed in `1f665f5`. +- Active invariant: `K=1` is not a separate Fabric route. It is the base + streaming temporal case and must share the same flat-bucket temporal + dispatcher family as `K>1` when the row is eligible. +- Current implementation slice: + widen the compiled flat-bucket dispatcher attempt from `constant_k > 1` to + `constant_k >= 1` for eligible high-level no-grad single flat-bucket CUDA + rows. The fallback shared Python scan remains K>1-only after a compiled + dispatcher miss, so this does not create a new user-visible route. +- Updated high-level API parity coverage: + `test_fabric_cuda_single_flat_bucket_inference_uses_compiled_temporal_dispatcher` + now covers both `K=1` and `K=2` for SLSTM and Axon with reset rows, sequence + output emission, and final-state materialization. The existing + single-population flat-bucket K=1 forward test now expects the compiled + dispatcher metadata instead of the old temporal bucket sequence identity. +- Owner remains honest and open: metadata records + `compiled_flat_bucket_temporal_scan:backend_host_loop` and + `launch_temporal_scan_owners=("backend_host_loop",)`. This is still not + `cuda_temporal_superop`; R3/R4 stay open until the physical scan loop, mixed + flat buckets, and backward scan move into the fully owned CUDA temporal + superop and pass the April 21-referenced throughput/parity audits. +- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_$USER_redo_fixmass`: + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_flat_bucket_inference_uses_compiled_temporal_dispatcher -n0` + passed 4 cases in 5.34s; + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_flat_bucket_forward_uses_sequence_executor tests/test_fabric_runtime.py::test_fabric_cuda_single_flat_bucket_inference_uses_compiled_temporal_dispatcher -n0` + passed 6 cases in 5.77s; the focused K/reset/materialization group covering + single-pop K>1, mixed-pop K>1, terminal-loss gradients, fresh-state mixed-pop, + and this widened dispatcher route passed 12 cases in 5.43s after formatting. + Ruff check, Ruff format check, and `git diff --check` passed on the touched + Python files. + +### 2026-04-27 UTC - R3 first CUDA temporal superop owner slice + +- Closed substage before moving on: K=1 base-case dispatcher widening is + complete and committed in `1549359`. +- Active invariant: the remaining R3 owner is not metadata. The physical + temporal scan loop must move from `backend_host_loop` into a Fabric CUDA + temporal superop where the low-level ops are declared through + `fabric.cuda.nn`/physical bucket capabilities, not through cell names. +- Current implementation slice: + added `execution/temporal_superop.cu` and `.cuh`, compiled into the Fabric + dispatcher extension. 
The first device-owned temporal scan is a cooperative + CUDA kernel for eligible receiver-owned single flat-bucket rows whose + transition bucket declares the generic diagonal-recurrence IR, regular-local + tiny-message direct projection, full receiver surface, no forward-carry + checkpoint, no autograd path, and materialized carry. It fuses the physical + T*K loop for message projection, diagonal transition, public projection, and + readout emission inside one CUDA temporal owner. +- Attribution is now owner-derived instead of hardcoded: the compiled + flat-bucket dispatcher records `cuda_temporal_superop` for eligible diagonal + recurrence rows and keeps `backend_host_loop` for unsupported rows such as + state-affine SLSTM. The route still records + `compiled_flat_bucket_temporal_scan:` so audit metadata distinguishes + the first real CUDA superop slice from the remaining transitional dispatcher. +- Reset parity remains explicit: the superop consumes the split message/boundary + reset mask and transition reset mask, preserving the April 24-26 reset lesson. +- Open owner after this slice: + R3/R4 are still not closed. State-affine transition buckets, mixed flat + buckets, output-dependency windows, autograd/backward, H/checkpoint/recompute, + and throughput audits still need the fully shared CUDA temporal superop path. + No cleanup/deletion stage may claim closure until those owners move and pass + April 21-referenced parity/performance gates. +- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_$USER_redo_fixmass_temporal_superop`: + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_flat_bucket_inference_uses_compiled_temporal_dispatcher -n0` + passed 4 cases in 5.67s, covering SLSTM host-loop fallback plus Axon + `cuda_temporal_superop` for K=1 and K=2. The focused six-case set covering + the existing K=1 flat-bucket route and the widened dispatcher route passed in + 5.96s. The broader K/reset/materialization group covering single-pop K>1, + mixed-pop K>1, terminal-loss gradients, fresh-state mixed-pop, and the new + superop route passed 12 cases in 81.12s. + +### 2026-04-28 UTC - R3 state-affine CUDA temporal superop owner slice + +- Closed substage before moving on: the first diagonal-recurrence CUDA temporal + superop slice is complete and committed in `86c03db`. +- Active invariant: SLSTM/state-affine rows are not allowed to stay on + `backend_host_loop` merely because their local math differs. The shared + temporal owner may dispatch through cell-local `CellCore` math, but execution + selection must come from `fabric.cuda.nn`/physical bucket capabilities, not + from cell-family names. +- Current implementation slice: + added a registry-dispatched state-affine CUDA temporal superop for eligible + receiver-owned single flat-bucket no-grad rows. Eligibility is the generic IR + family: regular-local tiny-message direct projection, two receiver-major + state affines (`projected_message` overwrite plus `state_prev` accumulate), + full flat receiver surface, materialized carry, no output-dependency window, + no forward-carry checkpoint, no autograd path, and separate port-owned + readout. +- The new superop fuses the physical `T*K` loop for message projection, + state-affine contribution materialization, `CellCore` state update/emit, + public projection, and readout emission inside one cooperative CUDA temporal + owner. 
The cell-local recurrent equation is still supplied by the registered + `CellCore`; the backend route is selected from the state-affine IR family and + flat bucket identity. +- Attribution moved for the high-level no-grad single flat-bucket SLSTM rows: + both SLSTM/state-affine and Axon/diagonal rows now report + `compiled_flat_bucket_temporal_scan:cuda_temporal_superop` and + `launch_temporal_scan_owners=("cuda_temporal_superop",)` for K=1 and K=2. + Reset semantics remain split between message/boundary reset masks and + transition reset masks. +- Open owner after this slice: + R3/R4 remain open. Mixed flat buckets, output-dependency windows, autograd + and temporal backward, H/checkpoint/recompute ownership, optimized + state-affine parallelization, K=1/T/K/H throughput audits against the April + 21 score reference, and legacy path deletion are still required before + REDO_FIXMASS closure. This state-affine kernel is real owner movement, not a + performance closure claim. +- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_$USER_redo_fixmass_state_affine_superop`: + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_flat_bucket_inference_uses_compiled_temporal_dispatcher -n0` + passed 4 cases in 70.80s after the first compile; + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_population_flat_bucket_forward_uses_sequence_executor -n0` + passed 2 cases in 4.20s; the broader K/reset/materialization group covering + single-pop K>1, mixed-pop K>1, terminal-loss gradients, fresh-state mixed-pop, + and the new SLSTM/Axon CUDA temporal superop route passed 12 cases in 82.63s. + Ruff check, Ruff format check, and `git diff --check` passed on touched files. + +### 2026-04-28 UTC - R3 state-affine temporal superop throughput slice + +- Closed substage before moving on: state-affine no-grad single flat-bucket + rows now have a real `cuda_temporal_superop` owner and are committed in + `edff617`. +- Mixed flat-bucket scan inspection: + the current mixed-pop path still represents state as per-population local + banks and uses `compute_temporal_bucket_step_artifacts` / + `run_backend_order_transition_buckets_step_cached_eager_result` from a + Python temporal loop. Moving mixed-pop into one CUDA temporal superop remains + the next structural R3 owner, but it needs a graph-wide flat state/public + representation across multiple bucket schemas rather than a per-pop launch + wrapper. Do not close mixed-pop R3 with metadata or a wrapper. +- Current implementation slice: + optimized the new state-affine CUDA temporal superop by changing the + state-update/emit phase from one thread serializing all hidden lanes for a + `(batch, receiver)` row to one warp per row. `CellCore::forward_state_chunk` + now runs with lane/stride ownership, reduction stats are combined across the + warp, and `CellCore::emit_public_chunk` emits raw public values in parallel. +- Scope: + this is a throughput-oriented kernel improvement inside the existing + state-affine superop owner. It does not weaken the remaining open owners: + mixed flat buckets, output-dependency windows, autograd/backward, + H/checkpoint/recompute ownership, and April 21-referenced throughput audits + remain open. 
+- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_$USER_redo_fixmass_state_affine_superop`: + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_flat_bucket_inference_uses_compiled_temporal_dispatcher -n0` + passed 4 cases in 69.16s after recompilation; the broader + K/reset/materialization group covering single-pop K>1, mixed-pop K>1, + terminal-loss gradients, fresh-state mixed-pop, and the CUDA temporal superop + route passed 12 cases in 5.56s. Ruff check, Ruff format check, and + `git diff --check` passed. + +### 2026-04-28 UTC - R3 output-only temporal superop readout boundary probe + +Status: ACCEPTED + +Owner: shared temporal forward / CUDA temporal superop readout boundary. + +- Active invariant: + output-only `T*K>1` is still the same shared temporal scan, but the CUDA + temporal superop must implement the requested readout output boundary itself. + It is not valid to write per-output-port values into a pooled output buffer or + to use a one-step compact readout dependency window across multiple recurrent + timesteps unless the window is closed over recurrent/message dependencies. +- Rejected probe: + compact active readout windows were temporarily relaxed for the temporal + superop. High-level SLSTM and Axon output-only rows exposed parity drift, and + a host-loop check showed the one-step readout dependency cone is not a valid + temporal substrate for SLSTM over `T*K>1`. The safe planner rule is therefore + to keep compact readout dependency windows single physical step only until a + graph-closure proof exists. +- Current code finding: + the host dispatcher has explicit mean-pooled readout handling, but the CUDA + temporal superop readout kernel currently iterates `plan.output_ports` and + writes each physical output port directly into `readout.output_state`. For a + pooled boundary the output tensor has one slot, so output-only `T>1` rows can + report `cuda_temporal_superop` while producing wrong pooled values. +- Next implementation slice: + completed. The diagonal and state-affine CUDA temporal superops now compute + mean-pooled output boundaries directly by averaging the per-port projected + readout values into the single pooled slot. Output-only `T>1` rows can now + keep `cuda_temporal_superop` ownership without writing per-output-port values + into a pooled output tensor. Metadata records sequence-output materialization, + final-state materialization, and pooled mean readout boundary. +- Safety boundary: + the compact active-window temporal-superop probe was rolled back from the + accepted patch. The follow-up compact-window slice below narrows this to a + one-physical-step rule: compact readout dependency windows are eligible only + when `time_steps == 1` and `temporal_inner_steps == 1`; `T*K>1` output-only + rows use the full recurrent surface until a graph-closure proof and parity + matrix exist for compact temporal windows. +- Regression coverage added: + `test_fabric_cuda_single_flat_bucket_output_only_pooled_readout_uses_temporal_superop` + covers the high-level `model(...)` API, SLSTM and Axon, K=1 and K=2, + reset/no-reset, sequence output, `materialize_final_state=False`, empty + returned state, full-surface ownership, and `cuda_temporal_superop` + attribution. 
+- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_$USER_redo_fixmass_pooled_readout_superop`: + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_flat_bucket_output_only_pooled_readout_uses_temporal_superop -n0` + passed 8 cases in 73.35s after recompilation; + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_single_flat_bucket_inference_uses_compiled_temporal_dispatcher tests/test_fabric_runtime.py::test_fabric_cuda_single_flat_bucket_output_only_pooled_readout_uses_temporal_superop -n0` + passed 12 cases in 5.85s; the broader K/reset/materialization group covering + single-pop K>1, mixed-pop K>1, terminal-loss gradients, fresh-state mixed-pop, + and both materialized and output-only CUDA temporal superop routes passed + 20 cases in 7.29s. Ruff check, Ruff format check, and `git diff --check` + passed. +- Next owner: + mixed flat-bucket forward remains Python-loop/transitional because the live + request still carries per-population local state/public banks. The next R3/R4 + backend step is graph-wide flat state/public packing for multiple transition + buckets under one shared CUDA temporal owner, followed by temporal backward + ownership. Compact temporal active windows remain a later owner after graph + closure semantics are explicit. + +### 2026-04-28 UTC - R3 one-physical-step compact temporal superop probe + +Status: ACCEPTED + +Owner: shared temporal forward / CUDA temporal superop active readout window. + +- Active invariant: + compact readout dependency windows are only semantically safe when the + physical temporal scan has exactly one recurrent step, unless the planner has + proven graph closure over recurrent/message dependencies. Therefore the guard + is `outer_time_steps * inner_steps == 1`, not merely `outer_time_steps == 1`. +- Implementation target: + completed. Compact-window support is now enabled only for the + one-physical-step CUDA temporal superop case. The diagonal and state-affine + kernels keep compact state/public rows local while using global receiver ids + for receiver-major parameters and public-projection bindings. `T*K>1` + compact windows remain demoted to the full recurrent surface until + graph-closure semantics are explicit. +- Dispatcher/planner guard: + `state_output_required` and `public_output_required` now include + `preserve_internal_carry`, allowing a one-step output-only superop to allocate + the compact internal state/public banks it needs without returning user-visible + state. The single-pop compiled scan requests this internal carry only when + `physical_time_steps == 1`. The sequence-surface active-window predicate now + requires both `time_steps == 1` and `temporal_inner_steps == 1`. +- Regression coverage added: + `test_fabric_cuda_single_flat_bucket_output_only_one_physical_step_window_uses_temporal_superop` + covers high-level `model(...)` output-only parity against PyTorch for SLSTM + and Axon, `T=1,K=1` reset/no-reset with `readout_dependency_cone` and + `cuda_temporal_superop`, and `T=1,K=2` reset/no-reset full-surface demotion + with `cuda_temporal_superop`. 
+- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_$USER_redo_fixmass_compact_temporal_superop`: + the new compact-window test passed 8 cases in 110.77s after recompilation; + the dispatcher/materialized + output-only superop group passed 20 cases in + 4.77s; the broader K/reset/materialization group covering single-pop K>1, + mixed-pop K>1, terminal-loss gradients, fresh-state mixed-pop, pooled + output-only full-surface, and compact one-step CUDA temporal superop routes + passed 28 cases in 82.08s. Ruff check, Ruff format check, and + `git diff --check` passed. +- Next owner: + this closes the single-bucket output-only compact T=1 owner, but does not + close mixed flat-bucket temporal ownership. Mixed-pop forward still needs the + graph-wide flat state/public representation that can be consumed by one CUDA + temporal superop across multiple transition buckets. + +### 2026-04-28 UTC - R3 mixed flat carry cache materialization slice + +Status: ACCEPTED + +Owner: mixed flat-bucket temporal forward / graph-wide flat carry boundary. + +- Active invariant: + mixed-population Fabric must remain a flat-bucket engine problem. The + backend may branch on flat bucket identity, state schema, public schema, + physical op family, and binding metadata, but not on cell names such as SLSTM + or Axon. A population-local TensorDict view is a user/materialization view; + it should not be rematerialized inside each physical temporal step when the + backend already owns a packed flat-bucket state cache. +- Implementation: + completed. The shared temporal bucket scan now keeps intermediate mixed-pop + state in the backend-order flat bucket state cache when no artifact store and + no final user-visible state require population TensorDict materialization. + Population views are only rematerialized for artifact/checkpoint collection + or final-state materialization. Reset handling now resets the backend-order + cache and graph-level `cells` tensor without requiring empty population views + to pass through cell-local reset code. +- Metadata: + backend records now expose `flat_bucket_state_cache:*`, + `flat_bucket_state_cache_materialized_steps:*`, and + `flat_bucket_state_cache_elided_steps:*` workspace aliases. The accepted + no-final-state mixed-pop K>1 route reports + `backend_order_flat_cache_population_views_elided` and six elided physical + steps for `T=3,K=2`. +- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_$USER_redo_fixmass_mixed_flat_carry_cache_rerun`: + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_output_only_uses_backend_order_flat_carry_cache -n0` + passed 2 cases in 4.89s after the reset fix. The broader guard covering + single-pop K>1, compiled single-bucket inference, pooled and one-step compact + output-only temporal superop routes, mixed-pop K>1, mixed terminal-loss + gradients, single-pop sequence executor, and mixed fresh-state/no-final-state + parity passed 30 cases in 75.33s. +- Remaining owner: + this reduces per-step mixed-pop TensorDict materialization, but does not + close R3/R4. The scan owner is still `python_autograd_scan`; mixed flat-bucket + forward/backward remain open until the temporal scan loop itself is owned by + the shared CUDA temporal superop and the throughput audit passes. + +### 2026-04-28 UTC - R3 constant-K no-grad route convergence + +Status: ACCEPTED + +Owner: shared temporal forward / mixed flat-bucket route convergence. 
+ +- Active invariant: + K=1 is the base streaming case, not a separate active-output runtime path. + If the compiled single-bucket CUDA temporal superop cannot own a no-grad + constant-K sequence row, the row should fall through to the shared temporal + bucket scan for both single and mixed population cardinalities. Separate + active-output host loops are legacy route divergence unless they are backed + by the same shared temporal engine contract. +- Implementation: + completed. The projected-boundary and boundary-sequence no-grad CUDA routes + no longer take the separate active-output shortcut before the shared temporal + scan. After the compiled single-bucket CUDA temporal superop declines a row, + the shared temporal bucket scan now handles constant `K >= 1`, not only + `K > 1`. Fresh no-final-state rows that arrive with only graph-level `cells` + now initialize backend-order zero packed state for each active flat bucket, + preserving rolling carry without materializing population TensorDict views. +- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_$USER_redo_fixmass_constant_k_route`: + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k1_output_only_uses_shared_temporal_bucket_scan -n0` + passed 2 cases in 118.70s after first compile. The fresh-state regression + guard plus the new K=1 route guard passed 3 cases in 5.57s after initializing + fresh backend-order packed state. The broader adjacency guard covering single + temporal superop rows, mixed K=1/K>1 output-only rows, mixed K>1 training, + terminal-loss gradients, single-pop sequence executor, fresh no-final-state, + and mixed readout-closure parity passed 33 cases in 6.76s. +- Remaining owner: + this removes one legacy host-loop route divergence and makes K=1 use the same + shared temporal scan surface as K>1 when the compiled CUDA temporal superop + cannot own the row. R3/R4 still remain open because mixed forward/backward + scan ownership is still `python_autograd_scan`, not the final + `cuda_temporal_superop`. + +### 2026-04-28 UTC - R3 flat bucket identity schema cleanup + +Status: ACCEPTED + +Owner: shared temporal forward / mixed flat-bucket CUDA temporal superop prep. + +- Active invariant: + flat bucket identity is the semantic transition/schema identity that the + shared temporal engine consumes. It must not be keyed by population name, + public cell type, or cell kind. Population slot and binding names belong only + in binding/materialization metadata. +- Implementation target: + completed. `cell_type=` and `cell_kind=` were removed from live runtime + flat-bucket identity and backend IR transition signatures. Identities now use + generic transition IR/schema fields: state/public schema, + parameter/projection schema, transition state/message/parameter inputs, + state/public/recompute outputs, backward decomposition, transition ops, and + generic parameter binding sources. This is a guard before graph-wide + multi-bucket CUDA temporal superop descriptors are introduced. +- Verification: + `uv run pytest -q tests/test_fabric_backend_plan.py::test_fabric_bucket_identity_is_not_keyed_by_population_name tests/test_fabric_backend_plan.py::test_temporal_bucket_plan_exposes_flat_bucket_identity tests/test_fabric_backend_plan.py::test_temporal_bucket_plan_accepts_more_than_two_binding_populations tests/test_fabric_backend_plan.py::test_planner_cache_keys_flat_bucket_identity_not_binding_slot -n0` + passed 4 tests in 5.05s. 
`uv run pytest -q tests/test_fabric_backend_plan.py + -n0` passed 48 tests in 5.26s. On GPU 0 with + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_$USER_redo_fixmass_identity_schema`, + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_horizon_shared_mixed_population_reset_parity tests/test_fabric_runtime.py::test_fabric_cuda_flat_temporal_sequence_supports_three_population_bindings -n0` + passed 3 tests in 116.22s. Ruff check, Ruff format check, and + `git diff --check` passed. +- Closure note: + this does not close R3/R4 by itself. R3/R4 remain open until the mixed + flat-bucket temporal scan loop and temporal backward owner move into the + CUDA temporal superop and audits pass. + +### 2026-04-28 UTC - R3 mixed flat-bucket graph-order layout kernel + +Status: ACCEPTED + +Owner: mixed flat-bucket temporal forward / CUDA temporal superop layout prep. + +- Active invariant: + graph-order assembly is a backend layout operation over flat bucket identity, + not population/cell logic. The mixed temporal scan should not pay a per-step + PyTorch `index_select` plus `cat` tax when the backend already owns + backend-order recurrent buckets and the inverse graph-order map. +- Implementation target: + completed. Added Fabric CUDA flat-bucket layout kernels under the sequence + surface: one kernel reorders recurrent hidden cells from backend bucket order + to graph order, and one kernel assembles boundary, recurrent, and output cells + into the full graph-order `cells` bank. The mixed temporal bucket scan now + routes CUDA float32 rows through these kernels and keeps the existing PyTorch + layout path as a closed fallback for unsupported devices/dtypes. Runtime + metadata records `flat_bucket_recurrent_graph_layout:*` and + `flat_bucket_graph_order_layout:*` aliases. +- Verification: + on GPU 0 with + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_$USER_redo_fixmass_flat_layout`, + `uv run pytest -q tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k1_output_only_uses_shared_temporal_bucket_scan tests/test_fabric_runtime.py::test_fabric_cuda_mixed_population_k_gt1_output_only_uses_backend_order_flat_carry_cache -n0` + passed 4 tests in 153.91s after first compile. The adjacent mixed temporal + parity group covering reset/no-reset horizon, three-population bindings, + K>1 sequence ownership, terminal-loss emission gradients, and provided-state + gradients passed 6 tests in 4.69s against the same cache. Ruff check, Ruff + format check, and `git diff --check` passed. +- Closure note: + this is a kernel-backed hot-loop cleanup, not final R3/R4 closure. The + remaining owner is still moving the scan loop, transition buckets, message, + public projection, readout, reset, checkpoint, and temporal backward into the + shared CUDA temporal superop. + +### 2026-04-28 UTC - R3 mixed backend-order public carry slice + +Status: ACCEPTED + +Owner: mixed flat-bucket temporal forward / CUDA temporal superop carry prep. + +- Active invariant: + recurrent public carry is a flat-bucket layout object. For no-tape mixed + inference rows, the temporal scan should keep recurrent hidden/K/V in backend + bucket order through readout and only materialize graph-order cells at the + graph boundary. Output sender tables may be remapped by flat bucket layout, + but backend selection must not branch on population names, cell kinds, or + cell-specific logic. +- Implementation: + completed. 
Runtime now registers flat-bucket carry-order sender tables for + recurrent and output message receivers without changing the existing + graph-order sender tables used by training/backward artifacts. Inference + static tensors now include backend-order recurrent public projection weights. + The shared temporal bucket step uses backend-order recurrent hidden/K/V for + the full no-final-state output-only scan when transition tape is disabled, + and only reorders recurrent hidden to graph order at the final graph boundary + for `cells` assembly. Training/artifact rows intentionally remain on + `flat_bucket_public_carry:graph_order` until the temporal backward owner + moves. +- Metadata: + backend execution records now expose `flat_bucket_public_carry:*`. Accepted + mixed K=1/K>1 output-only CUDA rows with and without resets report + `flat_bucket_public_carry:backend_order`, plus the existing CUDA recurrent + graph-order layout and graph-order cell assembly aliases. +- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_$USER_redo_fixmass_backend_public_carry`: + the mixed K=1/K>1 output-only rows passed 4 tests in 149.63s after first + compile. The adjacent reset/backward guard covering shared mixed reset + parity, three-population bindings, terminal-loss emission gradients, and + provided-state gradients passed 5 tests in 4.85s. Ruff check, Ruff format + check, and `git diff --check` passed. +- Closure note: + this closes the backend-order public-carry hot-loop slice only. R3/R4 remain + open because `launch_temporal_scan_owners` is still `python_autograd_scan`; + the next backend owner is moving the scan loop itself into the shared CUDA + temporal superop without introducing cell-specific backend logic. + +### 2026-04-28 UTC - R3 mixed backend-order public projection kernel + +Status: ACCEPTED + +Owner: mixed flat-bucket temporal forward / CUDA temporal superop public carry prep. + +- Active invariant: + backend-order recurrent K/V projection is a flat-bucket public projection + operation. It should be owned by Fabric CUDA low-level ops and keyed by the + flat bucket carry layout, not by cell names or population-specific backend + routing. +- Implementation: + completed. Added a Fabric CUDA flat-bucket public projection extension that + projects backend-order recurrent hidden directly into separate K/V banks. The + mixed no-final-state temporal scan uses it for backend-order public carry; + unsupported rows fall back to the existing runtime projection, and + graph-order training/artifact rows stay on the existing path. +- Metadata: + backend execution records now expose `flat_bucket_public_projection:*`. + Accepted mixed output-only rows report + `flat_bucket_public_projection:fabric_cuda_flat_bucket_public_projection`. +- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_$USER_redo_fixmass_public_projection_kernel`: + mixed K=1/K>1 output-only rows passed 4 tests in 188.44s after first + compile. The adjacent reset/backward guard covering shared mixed reset + parity, three-population bindings, terminal-loss emission gradients, and + provided-state gradients passed 5 tests in 4.76s. Ruff check, Ruff format + check, and `git diff --check` passed. +- Closure note: + this closes only the public-projection kernel slice. R3/R4 remain open until + the temporal scan loop, transition bucket dispatch, readout, reset handling, + and temporal backward are owned by the shared CUDA temporal superop. 
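For orientation only, the following PyTorch sketch mirrors the backend-order public-carry contract from the two entries above: recurrent hidden is projected into separate K/V banks while staying in backend bucket order, and graph-order cells are materialized once at the user-visible boundary. The tensor and index names (`hidden_backend`, `key_weight`, `value_weight`, `graph_order_index`) are illustrative placeholders, not the runtime's real bindings, and the real path runs through the Fabric CUDA flat-bucket extensions rather than host matmuls.

```python
# Hedged sketch only: backend-order K/V projection with a single graph-order
# materialization at the boundary. All names and shapes are hypothetical.
import torch


def project_public_carry_backend_order(
    hidden_backend: torch.Tensor,  # [batch, cells_in_backend_order, hidden]
    key_weight: torch.Tensor,      # [hidden, key_dim]
    value_weight: torch.Tensor,    # [hidden, value_dim]
) -> tuple[torch.Tensor, torch.Tensor]:
    """Project recurrent hidden into separate K/V banks without leaving backend order."""
    k_bank = hidden_backend @ key_weight
    v_bank = hidden_backend @ value_weight
    return k_bank, v_bank


def materialize_graph_order_cells(
    hidden_backend: torch.Tensor,     # [batch, cells_in_backend_order, hidden]
    graph_order_index: torch.Tensor,  # [cells]; backend position of each graph slot
) -> torch.Tensor:
    """Reorder to graph order only at the user-visible graph boundary."""
    return hidden_backend.index_select(1, graph_order_index)
```

The point of the layout rule is that only the graph-order materialization pays the reorder cost, and only once at the boundary; the per-step carry and K/V projection never leave backend bucket order.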
+ +### 2026-04-28 UTC - R3 mixed flat-bucket fused readout kernel + +Status: ACCEPTED + +Owner: mixed flat-bucket temporal forward / CUDA temporal superop readout prep. + +- Active invariant: + output readout is graph-level flat-bucket work. The mixed no-final-state scan + should not require a Python-owned message call followed by a separate output + projection call when the backend already owns input/recurrent public banks, + output sender tables, and readout parameters. +- Implementation: + completed. Added a Fabric CUDA flat-bucket fused readout extension for the + backend-order public-carry path. The kernel consumes partitioned + input/recurrent K/V banks, backend-order flat-bucket output sender tables, + output queries, `value_to_output_weight`, and `output_cell_bias`, computes + the local attention readout, and materializes output cells directly. It is + enabled only for local-message, no-delay, backend-order public-carry rows; + unsupported rows stay on the existing message-plus-output-projection path. +- Metadata: + temporal sequence records now expose `flat_bucket_readout:*`. Accepted mixed + output-only K=1/K>1 rows report + `flat_bucket_readout:fabric_cuda_flat_bucket_readout_fused`. +- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_$USER_redo_fixmass_fused_readout_kernel`: + mixed K=1/K>1 output-only reset/no-reset rows passed 4 tests in 185.59s + after first compile. The adjacent reset/backward guard covering shared mixed + reset parity, three-population bindings, terminal-loss emission gradients, + and provided-state gradients passed 5 tests in 45.38s. Ruff check, Ruff + format check, and `git diff --check` passed. +- Closure note: + this closes only the fused-readout kernel slice. R3/R4 remain open until the + scan loop, transition bucket dispatch, reset handling, checkpointing, and + temporal backward are owned by the shared CUDA temporal superop and throughput + audits pass. + +### 2026-04-28 UTC - R3 mixed backend-order graph materialization elision + +Status: ACCEPTED + +Owner: mixed flat-bucket temporal forward / CUDA temporal superop carry prep. + +- Active invariant: + graph-order cells are a user/materialization boundary, not the internal + rolling carry for backend-order flat buckets. A no-final-state mixed scan + that already owns backend-order recurrent state, K/V carry, public + projection, and fused readout should not rebuild graph-order recurrent cells + and full graph cells at every physical step. +- Implementation: + completed. No-final-state backend-order public-carry scans now keep compact + recurrent K/V and backend population state cache as the rolling internal + carry, instead of rebuilding graph-order recurrent cells and full graph cells + every physical step. `compute_temporal_bucket_step_artifacts()` still + materializes graph-order tensors for backward artifacts, final-state + materialization, and graph-order carry paths. +- Metadata: + accepted output-only mixed rows now report + `flat_bucket_recurrent_graph_layout:elided_backend_order_no_final_state` and + `flat_bucket_graph_order_layout:elided_backend_order_no_final_state`. +- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_$USER_redo_fixmass_fused_readout_kernel`: + mixed K=1/K>1 output-only reset/no-reset rows passed 4 tests in 5.29s on + the warmed cache. 
The adjacent reset/backward guard covering shared mixed + reset parity, three-population bindings, terminal-loss emission gradients, + and provided-state gradients passed 5 tests in 5.08s. Ruff check, Ruff + format check, and `git diff --check` passed. +- Closure note: + this removes another per-step graph-order layout/materialization owner from + mixed no-final-state inference. R3/R4 remain open because the temporal scan + loop and multi-bucket transition dispatch are still host/Python-owned rather + than one shared CUDA temporal superop. + +### 2026-04-28 UTC - R3 mixed flat-bucket CUDA temporal scan owner + +Status: ACCEPTED + +Owner: mixed flat-bucket temporal forward / CUDA temporal superop. + +- Active invariant: + the next high-priority backend owner is the physical scan loop itself. The + backend may specialize on transition IR families, but not on population names + or cell identities. Single and mixed population cardinalities remain the same + flat-bucket execution model; this substage is a guarded mixed no-final-state + output path and must fall back cleanly for unsupported state/backward/audit + rows. +- Implementation: + completed for the guarded fresh no-final-state mixed output route. Added a + Fabric CUDA flat-bucket temporal scan extension that owns the physical `T*K` + loop for backend-order mixed flat buckets over the current generic transition + IR families: gated log-space recurrence and diagonal RTU. The cooperative + kernel consumes input K/V sequence, recurrent/output local sender tables, + transition parameters, backend-order public projection weights, and readout + parameters. It computes recurrent messages, transition updates, recurrent + K/V projection, reset handling, and output readout inside the device temporal + scan. Accepted rows now report `launch_temporal_scan_owners` and + `launch_scan_implementations` as `cuda_temporal_superop`. +- Reset parity note: + recurrent reset rows must keep valid recurrent edges in the attention + denominator with zero K/V value contribution. The CUDA scan now matches the + existing temporal reset semantics for K=1 and K>1 reset rows. +- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_$USER_redo_fixmass_mixed_scan_superop3`: + mixed K=1/K>1 output-only reset/no-reset rows passed 4 tests in 5.31s on the + warmed cache and reported `cuda_temporal_superop`. Manual parity against the + PyTorch reference for K=1/K=2 with and without resets had max absolute diff + `2.98e-08`. The adjacent guard covering shared mixed reset parity, + three-population fallback, terminal-loss gradient routing, provided-state + gradients, and fresh no-final-state runtime parity passed 6 tests in 6.73s. + The combined 10-test group passed in 5.20s. Ruff check, Ruff format check, + and `git diff --check` passed. +- Non-closure note: + this is not full R3/R4 closure by itself. Provided-state rows, final-state + materialization, training artifacts, temporal backward, checkpointing, and + throughput audits remain gates before the legacy scan owner can be deleted. + +### 2026-04-28 UTC - R3 mixed CUDA temporal scan provided-state carry + +Status: ACCEPTED + +Owner: mixed flat-bucket temporal forward / CUDA temporal superop provided-state carry. + +- Active invariant: + provided state is just the initial flat-bucket carry for the shared temporal + engine. 
It must not introduce population-name or cell-specific routing in the + backend, and it must preserve reset semantics: reset rows ignore provided + recurrent values for value contribution while still keeping valid recurrent + edges in the attention denominator. +- Implementation target: + completed for the guarded mixed no-final-state output route. The CUDA + temporal superop now accepts optional initial flat-bucket carry tensors: + backend-order recurrent K/V, gated log-space private state, and diagonal RTU + private state. Fresh rows still synthesize zero carry inside the superop; + provided-state rows clone the supplied flat carry once before the device + scan. Reset state is now controlled by the fresh/provided initial-state flag, + so non-reset first steps consume provided carry while reset rows still clear + transition state and recurrent value contribution correctly. +- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_$USER_redo_fixmass_mixed_scan_provided_state`: + the new provided-state output-only parity row passed K=1/K=2 with and + without resets (`4 passed in 226.79s` after first compile) and reported + `cuda_temporal_superop`. The adjacent mixed temporal guard covering fresh + K=1/K>1 output rows, provided-state output rows, shared mixed reset parity, + three-population fallback, terminal-loss gradient routing, provided-state + gradients, and fresh no-final-state runtime parity passed `14 passed in + 7.61s` on the warmed cache. A final warmed focused pass for fresh and + provided-state output-only rows passed `8 passed in 5.08s`. Ruff check, Ruff + format check, and `git diff --check` passed. +- Non-closure note: + this closes only provided-state carry for guarded no-final-state mixed + output inference. Final-state materialization, training artifacts, temporal + backward, checkpointing, single-pop convergence onto the same superop, and + throughput audits remain open before R3/R4 can close or legacy scan paths can + be deleted. + +### 2026-04-28 UTC - R3 mixed CUDA temporal scan final-state materialization + +Status: ACCEPTED + +Owner: mixed flat-bucket temporal forward / CUDA temporal superop final-state materialization. + +- Active invariant: + final-state materialization is a graph boundary over the same flat-bucket + temporal scan. The scan may return final flat carry and state buffers, but it + must not reintroduce host-owned temporal loops or population-specific backend + routes. Diagonal trace state is part of the generic diagonal transition + family state, so final-state parity must include it. +- Implementation target: + completed for the guarded no-grad mixed route. The CUDA temporal superop can + now optionally return final backend-order recurrent hidden, recurrent K/V, + gated private state, diagonal private state, and diagonal trace state. The + runtime assembles user-visible graph-order `cells`, sender K/V, and + population TensorDict state at the graph boundary. For materialized-state + rows, the scan materializes output cells internally even when the public + output contract is pooled, so final `cells` and user output are both derived + from the same device-owned scan. +- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_$USER_redo_fixmass_mixed_scan_final_state`: + an ABI compile/smoke run for the existing no-final-state mixed K=1 row + passed in 76.55s after first compile. 
The new materialized final-state + parity row passed fresh/provided initial state, K=1/K=2, and reset coverage + (`4 passed in 41.69s`) with `cuda_temporal_superop` ownership. The broader + adjacent mixed temporal guard covering fresh/provided no-final-state, + fresh/provided final-state, shared mixed reset parity, three-population + fallback, terminal-loss gradient routing, provided-state gradients, and + fresh no-final-state runtime parity passed `18 passed in 84.99s`. Ruff + check, Ruff format check, and `git diff --check` passed. +- Non-closure note: + this closes no-grad final-state materialization for the guarded mixed + temporal superop only. Training artifact capture, temporal backward, + checkpoint/recompute ownership, single-pop convergence onto the same superop, + throughput audits, benchmark refactor, and legacy path deletion remain open. + +### 2026-04-28 UTC - R3/R4 mixed CUDA temporal scan training forward with recompute artifacts + +Status: ACCEPTED + +Owner: mixed flat-bucket temporal training forward / CUDA temporal superop artifact bridge. + +- Active invariant: + training uses the same flat-bucket temporal forward semantics as inference. + The forward scan should be device-owned when the row is otherwise supported; + backward may still recompute artifacts through the existing temporal backward + path until the CUDA backward superop owner moves. This must be recorded as a + transitional recompute-artifact bridge, not as full backward closure. +- Implementation: + completed for guarded mixed training rows. `collect_artifacts=True` no + longer forces the supported mixed flat-bucket route back to the Python + autograd scan. CUDA owns the forward temporal scan and final-state + materialization, then the runtime attaches a + `recompute_step_artifacts` store rooted at the initial checkpoint so the + existing physical temporal backward can rebuild step artifacts for + gradients. +- Metadata: + accepted training rows now report `launch_temporal_scan_owners` and + `launch_scan_implementations` as `cuda_temporal_superop`, + `temporal_artifacts:recompute_step_artifacts`, and + `flat_bucket_state_cache:cuda_temporal_superop_internal_flat_state`. + Backward still reports `physical_temporal_bucket_sequence_backward`. +- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_$USER_redo_fixmass_mixed_scan_final_state`: + terminal-loss and provided-state gradient parity rows passed `2 passed in + 5.07s`. The broader adjacent mixed temporal guard covering fresh/provided + output-only rows, final-state materialization, reset parity, three-population + bindings, terminal-loss gradients, provided-state gradients, and K=128 + backward gradient mapping passed `19 passed in 7.35s`. Ruff check, Ruff + format check, and `git diff --check` passed. +- Non-closure note: + this closes CUDA-owned forward execution for supported mixed training rows + only. The backward pass is still the existing physical temporal bucket + sequence backward with recompute artifacts, and intermediate checkpoint + production is still not owned by the CUDA temporal superop. R3/R4 remain open + for the CUDA backward superop, planner-owned checkpoint production, single-pop + convergence onto the same temporal superop, throughput audits, benchmark + refactor, and legacy path deletion. 
+ +### 2026-04-28 UTC - R4 CUDA temporal superop planner-stride checkpoints + +Status: ACCEPTED AS ABI SLICE; CONSUMPTION OPEN + +Owner: mixed flat-bucket temporal checkpoint production / CUDA temporal superop. + +- Active invariant: + recompute checkpoints are part of the shared temporal engine, not a Python + replay policy. For supported mixed flat-bucket rows, the CUDA temporal + superop should materialize planner-stride flat carry checkpoints while it owns + the forward scan. Backward can still recompute artifacts from those + checkpoints through the existing physical temporal backward until the full + CUDA backward superop lands. +- Implementation: + completed the guarded ABI/kernel production slice. The gated+diagonal + flat-bucket temporal scan can now optionally emit planner-stride checkpoint + tensors: backend-order recurrent hidden, recurrent K/V, gated private state, + diagonal private state, diagonal trace state, and output cells for graph + boundary reconstruction. Training metadata records + `checkpoint_owner=cuda_temporal_superop` and the produced checkpoint count. +- Rejected consumption probe: + consuming those CUDA-produced checkpoints inside the current Python/physical + recompute path produced a hybrid numerical path and broke strict parameter + gradient parity. The first focused probe failed the two mixed terminal-loss + rows with max parameter-gradient absolute differences up to about `0.0159` + against the PyTorch reference. Because reset/parity gates are strict, the + runtime now produces the checkpoint tensors but keeps checkpoint consumption + disabled until the CUDA temporal backward/window-recompute owner can consume + them consistently. +- Metadata: + accepted rows report `produced_checkpoint_count=2`, + `consumed_checkpoint_count=0`, `checkpoint_owner=cuda_temporal_superop`, and + `checkpoint_consumption=disabled_until_cuda_temporal_backward` in + `backward_recompute_mode`. Backward still reports + `physical_temporal_bucket_sequence_backward`. +- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_$USER_redo_fixmass_cuda_checkpoints`: + terminal-loss and provided-state gradient parity rows passed `2 passed in + 5.45s` with checkpoint-production metadata asserted. The broader adjacent + mixed temporal guard covering fresh/provided output-only rows, final-state + materialization, reset parity, three-population bindings, terminal-loss + gradients, provided-state gradients, and K=128 backward gradient mapping + passed `19 passed in 8.37s`. Python compile, Ruff check, Ruff format check, + and `git diff --check` passed. +- Non-closure note: + this closes only CUDA checkpoint tensor production in the forward superop ABI. + It does not close R4 checkpoint consumption or temporal backward. The next + high-priority owner is CUDA temporal backward/window recompute so the engine + can consume CUDA-produced checkpoints without the current hybrid recompute + parity drift. + +### 2026-04-28 UTC - R4 CUDA checkpoint consumption / temporal backward boundary + +Status: ACCEPTED AS PHYSICAL RECOMPUTE BRIDGE; CUDA BACKWARD OPEN + +Owner: CUDA temporal backward/window recompute and checkpoint-boundary gradient accounting. + +- Active invariant: + CUDA-produced planner-stride checkpoints must be consumed by the shared + temporal engine without changing any public output, exposed state, input + gradient, provided-state gradient, or parameter gradient. 
If consuming a + checkpoint creates drift, the bug is in the generic temporal boundary + contract between flat carry, transition private state, backend state cache, + recurrent K/V, and window gradients; it must not be fixed with population or + cell-specific routing. +- Current evidence: + CUDA checkpoint tensors match Python recompute final-state values tightly + for reset/no-reset and K>1 probes, so the rejected consumption probe is not a + simple forward state-value mismatch. The next investigation is the backward + boundary: how a recompute window started from a device checkpoint attributes + gradients to transition parameters, backend state cache, recurrent K/V, and + the previous window. +- Implementation: + accepted the generic layout-boundary fix. Temporal checkpoints and recompute + step artifacts now record whether recurrent K/V carry is graph-order or + backend-order flat carry. CUDA-produced checkpoints are inserted into the + artifact store with `backend_order` carry, and recompute/message backward + chooses recurrent/output sender tables from that layout. Backend-order K/V + gradients are converted back to graph order before recurrent K/V projection + parameter binding, preserving the public parameter contract without + population or cell-specific routes. +- Metadata: + mixed training rows now report `produced_checkpoint_count=2`, + `consumed_checkpoint_count=2`, `checkpoint_owner=cuda_temporal_superop`, and + `checkpoint_consumption=physical_recompute_bridge_from_cuda_checkpoints`. + Backward still reports `physical_temporal_bucket_sequence_backward`, so this + is checkpoint consumption through the existing physical recompute bridge, not + full CUDA temporal backward closure. +- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_$USER_redo_fixmass_checkpoint_consume1`: + focused mixed terminal-loss and provided-state full gradient parity passed + `2 passed in 5.05s` after metadata update. The adjacent mixed temporal guard + covering fresh/provided output-only rows, final-state materialization, reset + parity, three-population bindings, terminal-loss gradients, provided-state + gradients, and K=128 backward gradient mapping passed `19 passed in 7.54s`. + Python compile, Ruff check, Ruff format check, and `git diff --check` passed. +- Non-closure note: + this closes R4 checkpoint consumption for the supported mixed CUDA forward + route only as a transitional bridge. The next high-priority owner remains the + CUDA temporal backward superop itself: move the reverse scan/window recompute + loop and transition/message/readout backward scheduling out of Python/physical + replay, then make training metadata report the backward owner as + `cuda_temporal_superop`. + +### 2026-04-28 UTC - R4 CUDA temporal backward superop start + +Status: ACCEPTED AS PARTIAL BACKWARD SCHEDULING SLICE; R4 REVERSE SUPEROP OPEN + +Owner: CUDA temporal backward superop / reverse flat-bucket scan. + +- Active invariant: + the remaining R4 owner is not another checkpoint bridge or metadata rename. + Training backward must eventually be one shared CUDA temporal reverse + superop over flat bucket identity, with reset, emission, checkpoint/window, + message, transition, readout, recurrent K/V projection, boundary projection, + query, and parameter-gradient attribution represented as engine work rather + than a Python loop over recomputed step artifacts. 
No cell-specific backend + routing is allowed; supported transition families come from the flat bucket + IR and low-level `fabric.cuda.nn`/CUDA op declarations. +- Current starting point: + forward is CUDA-owned for the guarded mixed gated+diagonal route and now + produces/consumes planner-stride CUDA checkpoints. Backward still reports + `physical_temporal_bucket_sequence_backward` because + `TemporalPhysicalBackwardScanExecutor.run()` owns the reverse window loop, + calls Python recompute for artifacts, then dispatches existing CUDA physical + ops step by step. +- First implementation target: + identify the smallest real CUDA temporal-backward kernel slice that moves + work out of that Python reverse loop without changing semantics. Candidate + slices must be generic flat-bucket work, for example a backend-order + per-window reverse carry/readout/message seed over CUDA checkpoint windows, + not planner-only relabeling or population-specific local math. +- Current slice: + implement CUDA temporal backward glue kernels for generic scan-index work: + materialize per-physical-step output-gradient windows from emitted user + gradients, and accumulate per-physical-step boundary gradients back to outer + T positions. This is not full R4 closure because transition/message reverse + dependencies still run through `TemporalPhysicalBackwardScanExecutor`, but it + removes two Python temporal-index loops from backward and establishes the + CUDA-owned reverse-window boundary the larger superop can absorb next. +- Verification target: + focused mixed terminal/provided-state parity remains the gate after every + slice, followed by reset/final-state/K=128 adjacent guards. The owner remains + open until metadata can truthfully report the backward temporal owner as + `cuda_temporal_superop`. + +### 2026-04-28 UTC - R4 CUDA temporal backward scan-index glue kernels + +Status: ACCEPTED AS PARTIAL KERNEL SLICE; R4 REVERSE SUPEROP OPEN + +Owner: CUDA temporal backward glue / generic physical scan-index kernels. + +- Active invariant: + this slice is shared temporal engine work over flat bucket identity. It must + not inspect cell families or route around the planner. It only handles generic + physical-step indexing for output-gradient windows and boundary-gradient + accumulation, leaving transition/message math owned by the existing physical + reverse executor until the next CUDA reverse-superop slice. +- Implementation: + added `fabric_flat_bucket_temporal_backward_cuda` with two CUDA kernels: + one materializes a per-window physical output-gradient tensor from emitted + user gradients for sequence or terminal output boundaries, and one + accumulates per-physical-step boundary gradients back into outer `T` + positions with CUDA atomics. `TemporalPhysicalBackwardScanExecutor` now uses + these kernels before falling back to the scalar Python loops, and the runtime + records `cuda_temporal_backward_glue` plus + `temporal_backward_glue:cuda_output_grad_window` and + `temporal_backward_glue:cuda_boundary_grad_accumulate` only when the kernels + actually run. +- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_$USER_redo_fixmass_temporal_backward_glue`: + the standalone scalar-schedule kernel parity test passed `1 passed in + 4.32s`. Focused mixed terminal-loss and provided-state full gradient parity + passed `2 passed in 262.70s` including CUDA-glue metadata assertions. 
The + adjacent mixed temporal guard covering output-only rows, final-state + materialization, reset parity, three-population bindings, terminal/provided + state gradients, fresh-state output-only parity, and K=128 backward gradient + mapping passed `19 passed in 7.92s`. +- Non-closure note: + R4 is still open. The remaining high-priority owner is the actual CUDA + temporal reverse superop for the transition/message recurrent dependency + loop inside `_run_backward_window`; the backward record must continue to + report `physical_temporal_bucket_sequence_backward` until that work moves + out of the Python reverse loop. + +### 2026-04-28 UTC - R4 CUDA backend-order recurrent K/V reverse edge + +Status: ACCEPTED AS PARTIAL KERNEL SLICE; R4 REVERSE SUPEROP OPEN + +Owner: CUDA temporal backward recurrent public-carry projection. + +- Active invariant: + recurrent public carry is part of the shared temporal engine. When the + forward temporal superop carries recurrent K/V in backend-order flat bucket + layout, backward should reverse that projection in the same backend-order + layout instead of detouring through graph-order recurrent banks. This must be + a generic flat-bucket K/V projection reverse edge, not cell-family logic. +- Implementation: + added `flat_bucket_project_recurrent_kv_backward_cuda` under the flat-bucket + public projection extension. The kernel computes backend-order hidden + gradients and backend-order public K/V projection weight gradients from + backend-order recurrent hidden plus backend-order recurrent K/V gradients. + Temporal backward now attempts this CUDA edge when artifacts declare + `public_carry_order=backend_order` for terminal/sequence output backward and + for the per-step recurrent carry edge. Only the external hidden-state + boundary is converted back to graph order; public K/V parameter gradients are + rebound through the existing public projection parameter binding. +- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_${USER}_redo_fixmass_backend_order_kv_backward`: + standalone backend-order recurrent K/V backward parity against autograd + passed `1 passed in 41.59s`. Focused mixed K>1 terminal-loss and + provided-state full gradient parity passed `2 passed in 261.28s` with + `temporal_backward_glue:backend_order_recurrent_kv_projection_backward` + metadata required. The adjacent temporal guard covering output-only rows, + final-state materialization, reset parity, three-population bindings, + terminal/provided-state gradients, fresh-state output-only parity, and K=128 + backward gradient mapping passed `19 passed in 7.48s`. Python compile, Ruff + check, Ruff format check, and `git diff --check` passed after formatting. +- Non-closure note: + this closes only the backend-order recurrent K/V projection reverse edge. + R4 remains open until the transition/message reverse scan loop itself moves + into the CUDA temporal superop and throughput audits prove the new path is + not regressing the April 21 score references. + +### 2026-04-28 UTC - R4 CUDA reverse-loop gradient seed/materialization + +Status: ACCEPTED AS PARTIAL KERNEL SLICE; R4 REVERSE SUPEROP OPEN + +Owner: CUDA temporal backward reverse-loop generic gradient assembly. + +- Active invariant: + the reverse scan loop should not keep host-side tensor cloning/slice + assignment for generic flat-cell gradient surfaces. 
Carry/output gradient + merge and recurrent-state gradient materialization are flat bucket identity + operations over `[B, cells, h]`, not cell-specific math. +- Implementation: + extended `fabric_flat_bucket_temporal_backward_cuda` with two more generic + reverse-loop kernels. `merge_carry_output_grad` builds a full flat-cell + gradient for one physical step from an incoming carry gradient plus an + emitted output-cell gradient slice. `materialize_recurrent_state_grad` + writes recurrent hidden gradients into the recurrent slice of a full + `[B, cells, h]` state-gradient tensor and zeros non-recurrent cells. + `TemporalPhysicalBackwardScanExecutor` now routes these two reverse-loop + setup/materialization surfaces through CUDA when shapes allow and records + `temporal_backward_glue:cuda_carry_output_grad_merge` or + `temporal_backward_glue:cuda_recurrent_state_grad_materialize` only when the + CUDA kernels run. +- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_grad_seed`: + standalone temporal backward glue scalar parity passed `1 passed in 38.91s`. + The first focused two-row run passed numerical parity but caught an + over-broad metadata expectation: the no-final-state terminal row has no + carry/output merge surface, so only recurrent state materialization is + required there. After narrowing that assertion, focused mixed K>1 + terminal-loss/provided-state full gradient parity passed `2 passed in + 5.07s`. The adjacent temporal guard covering output-only rows, final-state + materialization, reset parity, three-population bindings, + terminal/provided-state gradients, fresh-state output-only parity, and K=128 + backward gradient mapping passed `19 passed in 7.76s`. Python compile, Ruff + check, Ruff format check, and `git diff --check` passed. +- Non-closure note: + this is still a generic reverse-loop setup/materialization slice, not full + R4 closure. The transition/message recurrent dependency scan remains + host-orchestrated by `TemporalPhysicalBackwardScanExecutor`; the remaining + high-priority R4 owner is to move that reverse recurrence itself into the + CUDA temporal superop and then run throughput audits against the April 21 + score references. + +### 2026-04-28 UTC - R4 CUDA recurrent query temporal reduction + +Status: ACCEPTED AS PARTIAL KERNEL SLICE; R4 REVERSE SUPEROP OPEN + +Owner: CUDA temporal backward recurrent query-adjoint reduction. + +- Active invariant: + recurrent query adjoints produced at each physical step are part of the + shared temporal backward window. Reducing them across physical time before + query-parameter binding is a generic flat receiver reduction over + `[physical_steps, recurrent_cells, head_dim]`, not population or cell-family + logic. +- Implementation: + extended `fabric_flat_bucket_temporal_backward_cuda` with + `reduce_recurrent_query_grad`, a CUDA reduction from + `[physical_steps, recurrent_cells, head_dim]` to + `[recurrent_cells, head_dim]`. Deferred temporal query-parameter binding now + uses this CUDA reduction when every step in the window has a CUDA tensor and + records `temporal_backward_glue:cuda_recurrent_query_grad_reduce` only when + the kernel runs. Unsupported or sparse optional-step windows keep the prior + scalar accumulation fallback. +- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_${USER}_redo_fixmass_query_reduce`: + standalone temporal backward glue scalar parity passed `1 passed in 41.37s`. 
+ Focused mixed K>1 terminal-loss/provided-state full gradient parity passed + `2 passed in 253.97s` with query-reduction metadata required on the + no-final-state terminal row. The adjacent temporal guard covering output-only + rows, final-state materialization, reset parity, three-population bindings, + terminal/provided-state gradients, fresh-state output-only parity, and K=128 + backward gradient mapping passed `19 passed in 8.35s`. Python compile, Ruff + check, Ruff format check, and `git diff --check` passed. +- Non-closure note: + this closes only the recurrent query-adjoint temporal reduction. The + transition/message recurrent dependency scan is still host-orchestrated by + `TemporalPhysicalBackwardScanExecutor`; R4 remains open until that reverse + recurrence and its per-step CUDA physical dispatch move into the CUDA + temporal superop and throughput audits clear against the April 21 score + references. + +### 2026-04-28 UTC - R4 CUDA initial recurrent K/V param-grad reduction + +Status: ACCEPTED AS PARTIAL KERNEL SLICE; R4 REVERSE SUPEROP OPEN + +Owner: CUDA temporal backward recurrent public-projection parameter reduction. + +- Active invariant: + recurrent public K/V projection parameter adjoints emitted at each physical + step are temporal window outputs. Summing compatible raw projection weight + gradients across physical time before public parameter binding is a generic + tensor reduction over flat bucket identity, not a population or cell-family + decision. +- Implementation: + added a generic Python wrapper over the temporal CUDA reduction kernel for + arbitrary trailing tensor shapes, then used it to reduce compatible + `TemporalInitialRecurrentBackwardStep.raw_param_grad.grad_weight` tensors + before public K/V projection parameter binding. The existing raw-grad binding + remains the single public parameter attribution path, and unsupported role, + grouping, device, group-id, or shape combinations keep the old tuple + fallback. Active CUDA reductions record + `temporal_backward_glue:cuda_initial_recurrent_kv_param_grad_reduce`. +- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_${USER}_redo_fixmass_query_reduce`: + standalone temporal backward glue scalar parity, including the generic + tensor-reduction check, passed `1 passed in 4.92s`. Focused mixed K>1 + terminal-loss/provided-state full gradient parity passed `2 passed in 5.71s` + with initial recurrent K/V raw-param reduction metadata required. The + adjacent temporal guard covering output-only rows, final-state + materialization, reset parity, three-population bindings, + terminal/provided-state gradients, fresh-state output-only parity, and K=128 + backward gradient mapping passed `19 passed in 7.94s`. Python compile, Ruff + check, Ruff format check, and `git diff --check` passed. +- Non-closure note: + this closes only another temporal reduction around parameter binding. R4 + remains open for the actual transition/message recurrent dependency reverse + scan and per-step physical dispatch to move into the CUDA temporal superop. + +### 2026-04-28 UTC - R4 CUDA recurrent message + initial K/V reverse edge + +Status: ACCEPTED AS PARTIAL KERNEL SLICE; R4 REVERSE SUPEROP OPEN + +Owner: CUDA temporal backward recurrent dependency edge. + +- Active invariant: + the next R4 owner must touch the recurrent reverse dependency itself, not only + another reduction or metadata tag. 
The supported backend-order public-carry + path should reverse recurrent message passing and the initial recurrent K/V + projection through a flat-bucket CUDA edge. It must not branch on cell family + or population names; transition math remains bucket-declared separately. +- Current target slice: + added a temporal backward CUDA entrypoint that takes the per-step recurrent + message adjoint, backend-order recurrent query, input/recurrent K/V banks, + backend-order recurrent hidden-before, flat local sender table, reset rows, + and backend-order recurrent public-projection weights. The entrypoint returns + recurrent query/input K/V adjoints plus backend-order hidden and public + projection weight adjoints, so `_run_temporal_bucket_step_backward_result` + now uses one backend-owned recurrent edge before falling back to the older + message + initial K/V projection sequence. The route records + `temporal_backward_glue:cuda_recurrent_message_initial_kv_backward` only when + the fused recurrent edge runs. +- Debug note: + the first focused run kept numerical parity but missed the new metadata + because the fast path used the local-delay dtype for `step_flat`; the fused + message kernel requires int64 step indices. The selected path now uses + `torch.long`, matching the existing local-message CUDA contract. +- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_${USER}_redo_fixmass_recurrent_msg_kv_edge`: + standalone fused recurrent message + initial K/V edge parity against the + existing CUDA message and backend-order K/V projection edges passed `1 passed + in 113.00s` after first compile. Focused mixed K>1 terminal-loss and + provided-state full gradient parity passed `2 passed in 5.45s` with the new + metadata required. The adjacent temporal guard covering output-only rows, + final-state materialization, reset parity, three-population bindings, + terminal/provided-state gradients, fresh-state output-only parity, and K=128 + backward gradient mapping passed `19 passed in 8.46s`. Python compile, Ruff + check, Ruff format check, and `git diff --check` passed. +- Non-goal for this slice: + this is not full R4 closure. The Python reverse loop and transition + backward dispatch still remain open until the CUDA temporal reverse superop + owns the whole window scan and throughput audits clear the April 21 score + references. + +### 2026-04-28 UTC - R4 CUDA transition parameter temporal reduction + +Status: ACCEPTED AS PARTIAL KERNEL SLICE; R4 REVERSE SUPEROP OPEN + +Owner: CUDA temporal backward transition parameter attribution. + +- Active invariant: + transition parameter adjoints emitted by each physical reverse step are + temporal window outputs. The shared temporal backward path should not add + every transition parameter tensor on the host while walking the Python + reverse loop. Compatible flat-bucket transition gradients should be reduced + by the CUDA temporal reduction path before public parameter binding. +- Current target slice: + changed the transition-parameter accumulator in + `TemporalPhysicalBackwardScanExecutor` from eager host tensor addition to + per-population/per-name grad collection, then reduce same-shape CUDA float32 + sequences with `try_reduce_temporal_tensor_grad_cuda` at bind time. Record a + temporal backward glue tag only when this CUDA transition reduction actually + runs. This is transition-owner work, not cell-family routing: population + names are binding scopes, and math still comes from the transition IR bucket + executor. 
+- Verification on GPU 0 with non-global cache + `TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_${USER}_redo_fixmass_transition_param_reduce`: + the first focused mixed K>1 terminal/provided-state run passed `2 passed in + 297.71s` after fresh extension compile. A metadata probe confirmed + `temporal_backward_glue:cuda_transition_param_grad_reduce` on the high-level + route. After making that metadata required, focused mixed terminal-loss and + provided-state full gradient parity passed `2 passed in 6.23s`. The adjacent + temporal guard covering output-only rows, final-state materialization, reset + parity, three-population bindings, terminal/provided-state gradients, + fresh-state output-only parity, and K=128 backward gradient mapping passed + `19 passed in 8.82s`. Python compile, Ruff check, Ruff format check, and + `git diff --check` passed after formatting. +- Non-goal for this slice: + the transition backward kernels themselves and the reverse dependency scan + are still called per physical step. R4 remains open until those move into the + CUDA temporal reverse superop. + +### 2026-04-28 UTC - R4 temporal artifact recompute owner + +Status: ACCEPTED AS PARTIAL BACKWARD SCHEDULING SLICE; R4 REVERSE SUPEROP OPEN + +Owner: CUDA temporal backward artifact recompute bridge. + +- Active invariant: + backward-window artifact recompute is part of the shared temporal engine. It + must not replay avoidable boundary/public projection work once per physical + step when the input boundary is indexed by outer timestep and K microsteps + reuse the same input K/V. This is generic temporal scheduling over flat bucket + identity, not cell-family logic. +- Current evidence: + GPU 0 warmed focused mixed K>1 terminal-loss owner timing with + `CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1` and private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_owner_timing` shows + `temporal_artifact_recompute:ms=15.156;count=3` as the largest warm backward + owner. The cold first pass was dominated by JIT/compile and is rejected as + closure evidence. +- Current implementation target: + hoist recompute-window input-boundary K/V projection out of the physical-step + replay loop. The recompute bridge should project the needed outer-time + boundary slice once with the existing backend sequence projection and pass the + matching K/V row into each physical step, including K>1 repeated outer-step + rows and reset-present rows. +- Follow-up target in the same owner: + avoid materializing readout/output artifacts for non-emitting K microsteps + during backward replay. Internal K steps advance recurrent state; user output + materialization happens only on emission steps, so non-emitting replay steps + should not run output message/readout work unless a future API explicitly + requests intermediate output materialization. +- Adjacent backward-output target: + prebatched temporal output backward should batch only emission microsteps. + Non-emitting K microsteps have no user output gradient and should contribute + all-None output-adjoint steps instead of being sent through readout/message + backward with zero gradients. +- Current continuation: + apply the same emission-only rule to materialized-final-state backward. After + the final physical step, reverse carry gradients are recurrent-state + gradients; non-final, non-emitting K microsteps should bypass output/readout + backward when no emitted output gradient is present. 
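- Illustrative sketch (assumed emission convention, simplified shapes):
  with `K` microsteps per outer step, only emission microsteps carry user
  output gradients, so output/readout backward batches over those physical
  indices and every other step contributes a `None` output adjoint. The
  last-microstep emission convention below is an assumption for the sketch,
  not a statement about the planner's schedule.

  ```python
  from typing import Optional

  import torch


  def emission_step_indices(outer_time: int, k: int) -> list[int]:
      # Physical step p = outer * k + inner; assume emission on the last microstep.
      return [outer * k + (k - 1) for outer in range(outer_time)]


  def output_grad_windows(
      emitted_grads: torch.Tensor, outer_time: int, k: int
  ) -> list[Optional[torch.Tensor]]:
      """Spread emitted [T, B, H] user output gradients over T*K physical steps."""
      active = set(emission_step_indices(outer_time, k))
      windows: list[Optional[torch.Tensor]] = []
      for p in range(outer_time * k):
          windows.append(emitted_grads[p // k] if p in active else None)
      return windows
  ```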
+- Implementation: + `_recompute_temporal_bucket_artifact_window` now preprojects the needed + outer-time input K/V window once and reuses those rows for K microsteps. + Recompute also elides readout/output materialization on non-emitting K + microsteps by writing zero output artifacts for those internal steps. The + prebatched temporal output backward path now receives explicit active + emission-step indices and only batches readout/message/output projection + backward over those emitted physical steps; skipped steps return all-None + output adjoints into the existing reverse scan. +- Evidence: + focused mixed K>1 terminal-loss and provided-state/reset full-gradient parity + passed on GPU 0 with private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_recompute_hoist`: `2 passed in + 6.24s` after the recompute hoist/elision and `2 passed in 6.42s`, then + `2 passed in 5.99s` after formatting, after the emission-only output-backward + batching. The adjacent temporal guard covering output-only rows, + final-state materialization, reset parity, three-population bindings, + terminal/provided-state gradients, fresh-state output-only parity, and K=128 + backward gradient mapping passed `19 passed in 8.04s`. + After extending the same skip to materialized-final-state backward, focused + terminal/provided-state parity passed on GPU 0 with private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_materialized_output_skip`: + `2 passed in 299.79s` after fresh compile. The same adjacent temporal guard + passed again: `19 passed in 8.96s`. +- Timing note: + the recompute hoist/elision alone did not materially move the warm + `temporal_artifact_recompute` owner (`~15 ms` remained the warm line), so it + is recorded only as cleanup inside the accepted slice. The emission-only + output-backward batching did move active output/readout work on the same + focused terminal row: warm `readout` count dropped from 3 to 1 and + `message.fused_receiver_sender` count dropped from 5 to 3. The main open + warm owner remains `temporal_artifact_recompute:ms=14.580;count=3`, followed + by transition backward families, so this is not R4 closure. + On the materialized-final provided-state row, warm readout remained one call + and `message.fused_receiver_sender` remained three calls after the continuation + (`temporal_artifact_recompute:ms=14.497;count=3` remained the open owner). +- Hygiene: + Python compile, Ruff check, Ruff format check, and `git diff --check` passed. +- Non-closure note: + this does not close R4. The reverse dependency scan and transition/message + backward dispatch remain host-orchestrated until the CUDA temporal reverse + superop owns the window. + +### 2026-04-28 UTC - R4 CUDA temporal scan replay ABI + +Status: ACCEPTED AS ABI/KERNEL SLICE; R4 REVERSE SUPEROP OPEN + +Owner: CUDA temporal backward artifact replay surface. + +- Active invariant: + checkpoint-window replay is part of the shared temporal engine. The CUDA + temporal scan must be able to replay arbitrary physical windows, including + windows that start inside a K microstep group, without assuming that + `physical_step == 0` or that a checkpoint is outer-step aligned. This is + generic flat-bucket temporal indexing; it is not population or cell-family + routing. +- Implementation: + extended `fabric_flat_bucket_temporal_scan_cuda` with `physical_start`, + `physical_steps`, and optional recurrent-message artifact outputs. 
The + cooperative scan kernel now maps local replay steps back to absolute + physical/outer/inner positions for reset and emission semantics, while the + Python wrapper can request `checkpoint_recurrent_msg_backend_order` plus + `final_recurrent_msg_backend_order`. The default forward path still passes + `physical_start=0` and full `outer_time * K` physical steps, so existing + public high-level model calls keep the same route and metadata. +- Evidence: + a first owner probe using a new private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_owner_probe` was aborted after + it spent several minutes compiling without producing runtime timing; it is + rejected as performance evidence. The accepted verification used GPU 0 and + private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_scan_replay_abi`. The focused + mixed K>1 terminal-loss row passed after the fresh scan-extension rebuild: + `1 passed in 290.24s`. The paired terminal/provided-state reset parity rows + then passed warm: `2 passed in 4.92s`. The adjacent temporal guard covering + output-only rows, final-state materialization, reset parity, + three-population bindings, terminal/provided-state gradients, fresh-state + output-only parity, and K=128 backward gradient mapping passed + `19 passed in 7.52s`. +- Non-closure note: + this slice gives the recompute bridge the required CUDA replay ABI, but it + does not yet consume recurrent-message artifacts in + `_recompute_temporal_bucket_artifact_window`, and it does not move the + reverse transition/message dependency loop out of + `TemporalPhysicalBackwardScanExecutor._run_backward_window`. R4 remains open; + the next owner is to consume this replay surface for checkpoint-window + artifacts, then fold transition/message reverse recurrence into the CUDA + temporal backward superop. + +### 2026-04-28 UTC - R4 CUDA temporal scan transition-tape ABI + +Status: ACCEPTED AS ABI/KERNEL SLICE; R4 REVERSE SUPEROP OPEN + +Owner: CUDA temporal backward artifact replay surface and transition-tape +handoff. + +- Active invariant: + checkpoint-window replay must be able to produce the same transition backward + tapes that the Python recompute bridge currently rebuilds step by step. + These tapes are transition-IR artifacts over flat bucket identity + (gated-logspace and diagonal-RTU buckets in the current guarded CUDA scan), + not cell-name routes or benchmark hooks. +- Implementation: + extended `fabric_flat_bucket_temporal_scan_cuda` with optional + `return_transition_tape_artifacts`. When requested, the cooperative temporal + scan now emits per-physical-step and final replay tensors for: + gated input projection, gated input gate logits, gated recurrent gate logits, + diagonal input projection, and diagonal preprojection. The Python wrapper + parses these tensors without changing the default high-level forward path. + `transition_execution` input-projection tape helpers now accept a + scan-owned `output_override`, so future recompute consumers can construct + truthful `TransitionBackwardTape` objects without recomputing those forward + outputs. `temporal_backward` has the internal tape-construction helper ready + for the next consumer slice. +- Replay correctness fix: + the prior replay ABI seeded checkpoint state into the local `*_a` buffers but + selected scan source/destination buffers using absolute physical-step parity. + That is correct only for windows beginning at physical step zero. 
The kernel + now uses local replay-window parity for rolling gated/diagonal state buffers, + while absolute physical/outer/inner indexing remains responsible for resets + and output emission semantics. +- Evidence: + GPU 0, private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_scan_tape_abi_compile`. + Fresh focused mixed K>1 terminal-loss compile/parity passed: + `1 passed in 298.65s`. A direct high-level runtime probe requested + transition-tape artifacts through `_try_cuda_mixed_flat_bucket_temporal_scan` + and returned the CUDA temporal-superop artifact store successfully. Warm + mixed K>1 terminal-loss and provided-state gradient parity passed + `2 passed in 4.84s`. The adjacent temporal/reset/K128 guard matrix passed + `19 passed in 6.92s`. Python compile, Ruff check, Ruff format check, and + `git diff --check` passed. +- Non-closure note: + this is still not R4 closure. The new tape tensors are available and parsed, + but `_recompute_temporal_bucket_artifact_window` still needs the full CUDA + artifact consumer. The reverse dependency loop in + `TemporalPhysicalBackwardScanExecutor._run_backward_window` remains open + until the CUDA temporal backward superop owns checkpoint replay, + transition/message reverse recurrence, and throughput audits pass. + +### 2026-04-28 UTC - R4 CUDA temporal scan output-message ABI + +Status: ACCEPTED AS ABI/KERNEL SLICE; FULL REPLAY CONSUMER REJECTED FOR NOW + +Owner: CUDA temporal backward artifact replay surface. + +- Active invariant: + a CUDA replay consumer cannot replace `_recompute_temporal_bucket_artifact_window` + unless it can provide every step artifact needed by the ordinary high-level + `loss.backward()` path. Transition tapes and recurrent messages are not + sufficient; output backward also needs the pre-projection `output_msg`. +- Implementation: + extended `fabric_flat_bucket_temporal_scan_cuda` with optional + `return_output_msg_artifacts`. The cooperative temporal scan now can emit + per-physical-step and final output-message tensors alongside recurrent-message + and transition-tape artifacts. The default high-level model path does not + request these artifacts, so no benchmark/user route owns backend policy. +- Rejected probe: + a guarded full replay consumer was prototyped but left inactive. It rebuilt + complete step artifacts from scan replay tensors, but the strict K>1 + terminal-loss parity gate failed on parameter gradients when using scan-owned + gated logits as the full transition tape. The largest reported mismatch was + `runtime.population_modules.slstm.bias_base` with max abs diff + `0.01118628` against the `1e-2` gate. That means the next consumer pass must + make the scan-owned gated tape numerically identical to the existing + transition tape, or route that tape production through an exact CUDA owner, + before replacing Python artifact recompute. This probe is rejected as closure + evidence and is not allowed to weaken the parity gate. +- Evidence: + GPU 0, private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_scan_output_msg_abi`. + Fresh focused mixed K>1 terminal-loss compile/parity passed + `1 passed in 295.81s`. Optional output-message plus transition-tape artifact + parsing was exercised through `_try_cuda_mixed_flat_bucket_temporal_scan` and + returned successfully for all requested artifact combinations. After the + rejected consumer was disabled, focused mixed K>1 terminal/provided-state + parity passed `2 passed in 4.86s`. 
The adjacent temporal/reset/K128 guard + matrix passed `19 passed in 6.94s`. Python compile, Ruff check, Ruff format + check, and `git diff --check` passed. +- Non-closure note: + R4 remains open. The scan ABI now has recurrent messages, output messages, + and transition-tape tensors, but active checkpoint-window artifact recompute + still falls back to the existing Python step replay until the exact gated-tape + parity issue is fixed and the reverse dependency scan moves into the CUDA + temporal backward superop. + +### 2026-04-28 UTC - R4 CUDA replay artifact consumer enabled + +Status: ACCEPTED AS BACKEND SLICE; R4 REVERSE SUPEROP OPEN + +Owner: CUDA temporal backward artifact replay consumer. + +- Active invariant: + checkpoint-window artifact recompute is part of the shared temporal engine. + The consumer must rebuild scalar-step artifacts from flat-bucket temporal scan + replay tensors without changing high-level benchmark/model APIs and without + population-name routing. Reset parity is part of the gate: boundary resets + affect public carry/cells at the boundary step, while transition resets apply + on every physical microstep. +- Probe result: + the prior rejected consumer was not failing because scan-owned gated logits + were materially different. A focused K>1 terminal-loss probe compared the + inactive CUDA replay artifacts against the existing Python step replay and + found gated input/recurrent gate logits within about `1e-7`. The real replay + issue was artifact semantics around resets: the consumer had copied + `reset_step` into `transition_reset_step`, which loses K-microstep reset + semantics, and it exposed pre-reset `cells_prev`, recurrent K/V, and backend + state cache on boundary-reset steps. +- Implementation: + `_recompute_temporal_bucket_artifact_window` now first attempts the guarded + CUDA temporal replay consumer for supported mixed flat-bucket windows and + records `cuda_temporal_superop_replay` when accepted. The replay consumer now + computes boundary reset and transition reset independently, applies boundary + resets to `cells_prev`, recurrent K/V, and the backend state cache handed to + transition backward, and keeps the scan-owned `state_after`/message/output + tensors as the source of forward replay truth. The K>1 terminal/provided-state + tests now assert this recompute owner directly. +- Evidence: + GPU 0, private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_exact_tape_probe`. Focused + K>1 terminal-loss and provided-state/reset parity passed after enabling the + replay consumer: `2 passed in 4.88s`. The adjacent temporal guard covering + output-only rows, final-state materialization, reset parity, + three-population bindings, terminal/provided-state gradients, fresh-state + output-only parity, and K=128 backward gradient mapping passed + `19 passed in 6.84s`. Python compile, Ruff check, Ruff format check, and + `git diff --check` passed. +- Non-closure note: + R4 remains open. Artifact recompute now uses CUDA temporal replay on the + guarded mixed flat-bucket surface, but the reverse dependency scan and + transition/message backward recurrence still execute as host-orchestrated + physical steps. The next owner is to move that reverse loop itself into the + CUDA temporal backward superop, then run throughput audits before declaring + R3/R4 closure. + +### 2026-04-28 UTC - R4 temporal query-param reduction batched + +Status: ACCEPTED AS BACKEND SLICE; R4 REVERSE SUPEROP OPEN + +Owner: temporal backward query-gradient binding. 
+ +- Active invariant: + window-level backward reductions should be owned by the shared temporal + backend, not repeated from benchmark code or per-step model API logic. Query + parameter gradients are generic message-superop gradients over recurrent and + output receivers; batching them is flat-bucket temporal glue, not + population-specific logic. +- Implementation: + `TemporalRecurrentQueryBackwardStep` now carries both recurrent-query and + output-query adjoints. `_run_temporal_bucket_step_backward_result` defers + query parameter binding for every temporal backward window, including + materialized-final-state paths where output query gradients are produced + inside the physical step. `run_temporal_recurrent_query_backward_sequence` + now reduces recurrent-query and output-query adjoint sequences with the + shared CUDA temporal tensor reducer before invoking the runtime query-param + binding once per window. The provided-state K>1 regression now asserts both + `cuda_recurrent_query_grad_reduce` and `cuda_output_query_grad_reduce`. +- Evidence: + GPU 0, private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_exact_tape_probe`. Focused + K>1 terminal-loss and provided-state/reset parity passed `2 passed in + 4.84s`. The adjacent temporal guard covering output-only rows, final-state + materialization, reset parity, three-population bindings, + terminal/provided-state gradients, fresh-state output-only parity, and K=128 + backward gradient mapping passed `19 passed in 6.71s`. Python compile, Ruff + check, Ruff format check, and `git diff --check` passed. +- Non-closure note: + this removes another repeated per-step parameter binding, but R4 remains + open. The reverse dependency scan still iterates physical steps on the host + and still calls transition and message backward per step. The next backend + owner remains a CUDA temporal backward superop for the reverse recurrence + itself. + +### 2026-04-28 UTC - R4 CUDA replay stays in flat cache + +Status: ACCEPTED AS BACKEND SLICE; R4 REVERSE SUPEROP OPEN + +Owner: CUDA temporal replay artifact materialization. + +- Active invariant: + replay artifacts for cache-backed temporal backward should stay in flat + backend identity. Population TensorDict views are a public/reference + convenience and should not be rebuilt per physical replay step when the + backward path consumes `backend_state_cache_before`. +- Implementation: + factored mixed CUDA replay assembly into flat cell assembly and backend state + cache construction. `_try_cuda_mixed_flat_bucket_recompute_artifact_window` + now carries `cells_before/cells_after` plus backend-order state cache through + the CUDA replay loop, applies boundary resets directly to the flat cells and + backend cache, and stores empty population views in replay artifacts because + the cache-backed transition backward path does not consume those views. The + ordinary final-state/materialized forward assembly still constructs public + population views where user-visible state requires them. +- Evidence: + GPU 0, private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_replay_flat_cache`. Focused + K>1 terminal-loss and provided-state/reset parity passed after cold rebuild: + `2 passed in 221.61s`. The adjacent temporal guard covering output-only rows, + final-state materialization, reset parity, three-population bindings, + terminal/provided-state gradients, fresh-state output-only parity, and K=128 + backward gradient mapping passed `19 passed in 47.68s`. 
Hygiene passed: + Python compile, Ruff check, Ruff format check, and `git diff --check`. + Warm owner timing on the same provided-state K>1 row moved + `temporal_artifact_recompute` from the prior accepted `~5.918 ms;count=3` + line to `4.689 ms;count=3`; cold timing is rejected as compile noise. +- Non-closure note: + this makes replay artifact materialization more faithful to flat bucket + identity and modestly reduces replay cost, but R4 remains open. The reverse + transition/message dependency scan is still host-orchestrated and remains the + next high-priority CUDA temporal backward superop owner. + +### 2026-04-28 UTC - R4 transition input-projection param window + +Status: ACCEPTED AS BACKEND SLICE; R4 REVERSE SUPEROP OPEN + +Owner: temporal transition input-projection backward. + +- Active invariant: + transition input-projection parameter gradients are temporal-window + reductions over flat bucket identity. The reverse scan still needs + per-physical-step input adjoints immediately, but weight/bias reductions do + not need to bind or reduce through the public parameter surface once per + step. This is projection-tape capability, not population-name or cell-family + routing. +- Implementation: + added a Fabric-owned receiver-major input-only backward executor and extended + transition projection tapes with an optional deferred parameter-gradient + step. The temporal unbound transition backward path now computes the + per-step `grad_recurrent_msg` needed by recurrent/message dependencies, saves + projection parameter-gradient work as flat tape steps, and reduces those + steps once per population/window before ordinary public parameter binding + when the checkpoint window is large enough to amortize the extra reducer. + Small windows keep the immediate transition path because the visible reducer + timing showed H=2 was not large enough to win. The reducer supports fused + recurrent input projections, static recurrent input projections, + diagonal/factorized input-projection weights, and unfused diagonal recurrent + input projections through transition-IR tape structure. The K=128/H=64 guard + now asserts `cuda_transition_input_projection_param_grad_window`. +- Evidence: + GPU 0, private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_inputproj_window` and + `/tmp/cortical_triton_${USER}_redo_fixmass_inputproj_window`. Focused H=2 + K>1 terminal-loss, provided-state/reset, and K=128/H=64 parity passed: + `3 passed in 6.09s` after the thresholded policy change. The adjacent + temporal guard covering output-only rows, + final-state materialization, reset parity, three-population bindings, + terminal/provided-state gradients, fresh-state output-only parity, and K=128 + backward gradient mapping passed `19 passed in 6.61s`. Hygiene passed: + Python compile, Ruff check, Ruff format check, and `git diff --check`. + Warm H=2 owner timing stayed on the immediate path, with no deferred + projection launch tag and `receiver_affine.input_projection_backward` at + `2.975 ms;count=12`, effectively the prior accepted `2.956 ms;count=12` + line. Warm K=128/H=64 owner timing exercises the deferred path and reports + `receiver_affine.input_projection_backward:ms=46.680;count=256` plus + `receiver_affine.input_projection_param_window:ms=1.096;count=4`, with + warmed launch metadata including + `temporal_backward_glue:cuda_transition_input_projection_param_grad_window`. + Cold timing is rejected as compile noise. 
+- Non-closure note: + this removes repeated transition projection parameter-gradient work from the + host step loop only for large enough checkpoint windows, but R4 remains open. + The reverse dependency loop itself still iterates physical steps on the host + and transition/message backward recurrence must still move into the CUDA + temporal backward superop before R3/R4 closure or throughput audits. + +### 2026-04-28 UTC - R4 recurrent-message reverse owner visible + +Status: ACCEPTED AS BACKEND VISIBILITY SLICE; R4 REVERSE SUPEROP OPEN + +Owner: temporal recurrent-message/initial-KV reverse edge. + +- Active invariant: + owner timing must expose every dominant per-step reverse edge before choosing + the next CUDA temporal superop cut. Launch metadata already proved the fused + recurrent-message plus initial-KV backward CUDA path was active, but it was + not visible in the owner timing summary, which made the transition kernels + look like the only remaining large per-step owners. +- Implementation: + wrapped `try_recurrent_message_initial_kv_backward_cuda` in + `message.recurrent_initial_kv_backward` owner timing. This does not change + semantics or routing; it makes the already-active generic flat-bucket + recurrent reverse edge measurable in the same high-level + `model(...); loss.backward()` path as the transition owners. +- Evidence: + GPU 0, private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_msg_timing` and + `/tmp/cortical_triton_${USER}_redo_fixmass_msg_timing`. Focused H=2 K>1 + terminal/provided-state plus K=128/H=64 parity passed after cold compile: + `3 passed in 225.67s`. The adjacent temporal/reset/K128 guard passed + `19 passed in 50.92s`. Warm K=128/H=64 owner timing now reports + `message.recurrent_initial_kv_backward:ms=12.191;count=128` alongside + `receiver_affine.input_projection_backward:ms=42.309;count=256`, + `receiver_affine.gate_affine_backward:ms=22.430;count=128`, + `diagonal_recurrence.output_projection_backward:ms=19.836;count=128`, + `receiver_affine.recurrent_affine_backward:ms=17.212;count=128`, + `state_epilogue.core:ms=16.242;count=128`, and + `diagonal_recurrence.core:ms=15.541;count=128`. Cold timing is rejected as + compile noise. +- Non-closure note: + this is visibility, not R4 closure. The next implementation owner remains + the CUDA temporal backward superop that carries the reverse dependency loop + across transition backward and recurrent-message/initial-KV backward inside + the backend, eliminating the host physical-step loop rather than only timing + it. + +Rejected probe: + +- A follow-up attempted to defer only the recurrent initial-K/V weight-gradient + kernel across large temporal windows, while keeping recurrent-message and + hidden-state adjoints per step. The probe extended the fused recurrent + message backward ABI to optionally skip the per-step weight-gradient kernel + and return recurrent K/V adjoints for a window reducer. It passed focused + parity (`3 passed`) and the adjacent temporal/reset/K128 guard (`19 passed`), + but warm K=128/H=64 timing did not improve: + `message.recurrent_initial_kv_backward` stayed around `12.8 ms;count=128` + versus the prior visible `12.2 ms;count=128`, and the first implementation + accidentally re-entered the generic message fallback until fixed. The probe + was reverted before commit. Do not repeat this narrow weight-gradient + deferral as R4 closure work; the needed owner is still the full reverse + dependency scan inside the CUDA temporal backward superop. 
+ +### 2026-04-28 UTC - R4 recurrent-message reverse window ABI + +Status: ACCEPTED AS PARTIAL KERNEL/ABI SLICE; R4 REVERSE SUPEROP OPEN + +Owner: CUDA temporal backward recurrent dependency window. + +- Active invariant: + the recurrent-message plus initial recurrent K/V reverse edge is a temporal + window operation over flat bucket identity. A future CUDA temporal backward + superop must be able to consume a physical window with explicit per-step + reset and step-index tensors, reduce query/weight gradients across `T*B`, + and return per-step input/hidden adjoints without Python owning the edge + launch per physical step. +- Implementation: + added `recurrent_message_initial_kv_backward_window` to the Fabric temporal + backward CUDA extension. The C++ ABI accepts rank-4 `[T,B,...]` window + tensors plus rank-2 reset and step-index windows, flattens `T*B` inside the + extension, launches the existing fused recurrent-message/initial-KV CUDA + kernels once for the window, and reshapes the per-step adjoints back to + window layout. The Python wrapper exposes this as + `try_recurrent_message_initial_kv_backward_window_cuda`. A standalone CUDA + parity test compares the window ABI against the existing step edge across + multiple physical steps, nonzero delay masks, and reset rows. +- Evidence: + GPU 0, private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_recurrent_window` and + `/tmp/cortical_triton_${USER}_redo_fixmass_recurrent_window`. Focused + recurrent edge tests passed after cold extension build: + `2 passed in 114.54s`; warm rerun passed `2 passed in 3.77s`. The + adjacent temporal/reset/K128 guard covering output-only rows, provided and + final state, reset horizon parity, three-population bindings, K>1 + terminal/provided-state gradients, fresh-state no-final-state parity, and + K=128 backward gradient mapping passed `19 passed in 167.89s`. +- Non-closure note: + this creates the window ABI the CUDA reverse superop needs for the recurrent + dependency edge, but `_run_backward_window` still has a loop-carried + transition/message dependency that prevents dropping this in as a pure + post-loop batch call. R4 remains open until transition backward and recurrent + message backward are fused into the CUDA temporal reverse scan owner and the + H/K/T throughput audits pass. + +### 2026-04-28 UTC - R4 schedule-owned temporal message step indices + +Status: ACCEPTED AS SEMANTIC BACKEND SLICE; R4 REVERSE SUPEROP OPEN + +Owner: shared temporal backward message-delay semantics. + +- Active invariant: + recurrent and readout message delay masks in the shared temporal path must be + driven by the temporal scan schedule, not by a hardcoded step `1`. K>1 uses + the same public model call and temporal artifact stream as K=1, so forward + artifacts and backward message kernels need an explicit schedule-owned + message step before the reverse loop can move into a CUDA temporal superop. +- Implementation: + `TemporalBucketStepArtifacts` now carries both `physical_step_index` and + `message_step_index`. Stored and recomputed artifacts set + `message_step_index = inner_step + 1` from the scalar physical scan schedule. + Temporal recurrent/readout message forward artifact creation uses that index, + output-message backward sequences pass a flat per-row step tensor into the + shared backend message backward phase, and the recurrent + message/initial-KV reverse edge uses the artifact step index instead of + rebuilding a constant step-1 tensor. 
The generic CUDA message backward helper + now accepts either an integer step or an explicit flat `[B*T]` step tensor. +- Evidence: + GPU 0, private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_step_index` and + `/tmp/cortical_triton_${USER}_redo_fixmass_step_index`. The new delayed K>1 + high-level parity case passed: `1 passed in 191.97s`. Focused temporal and + local-message backward guards passed `7 passed in 81.27s`. The adjacent + temporal/reset/K128 guard passed `19 passed in 51.06s`. +- Non-closure note: + this removes a hardcoded temporal-delay assumption from the active shared + backend path and makes the recurrent reverse edge window-ready, but it does + not close R4. `_run_backward_window` still owns the reverse physical-step + carry loop in Python. The next high-priority owner remains the CUDA temporal + backward superop that fuses transition backward with recurrent-message + backward over flat bucket identity. + +### 2026-04-28 UTC - R4 active recurrent reverse window ABI consumption + +Status: ACCEPTED AS BACKEND ABI CONSUMPTION SLICE; R4 REVERSE SUPEROP OPEN + +Owner: CUDA temporal backward recurrent dependency edge. + +- Active invariant: + the active recurrent reverse edge should consume the same window-shaped CUDA + ABI that the eventual temporal reverse superop will call. Keeping the + step-only ABI on the hot path would leave the final owner with an unexercised + interface and make the reverse scan migration riskier. +- Implementation: + `_run_temporal_bucket_step_backward_result` now routes the backend-order + recurrent message plus initial recurrent K/V reverse edge through + `try_recurrent_message_initial_kv_backward_window_cuda` with a one-step + `[T=1,B,...]` window. It passes the artifact-owned message step and reset + row as explicit rank-2 tensors, unwraps the per-step input/hidden adjoints, + and keeps the scalar ABI only as a fallback. High-level K>1 terminal-loss + tests now assert + `temporal_backward_glue:cuda_recurrent_message_initial_kv_backward_window` + in addition to the existing recurrent reverse tag. +- Evidence: + GPU 0, private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_window_active` and + `/tmp/cortical_triton_${USER}_redo_fixmass_window_active`. Focused + high-level mixed K>1 terminal/provided-state gradients plus the delayed K>1 + single-pop schedule-step parity test passed after cold extension/cache setup: + `3 passed in 304.06s`; the same focused set with stricter launch-count + assertions passed warm in `5.00s`. The adjacent temporal/reset/K128 guard + passed `19 passed in 10.29s`. +- Non-closure note: + this consumes the window ABI in the active path but still calls it once per + physical step from `_run_backward_window`. R4 remains open until the CUDA + temporal backward superop carries the reverse recurrence across transition + backward and recurrent message backward without the Python physical-step + loop. + +### 2026-04-28 UTC - R4 CUDA hidden-before reverse window materialization + +Status: ACCEPTED AS PARTIAL KERNEL/ABI SLICE; R4 REVERSE SUPEROP OPEN + +Owner: CUDA temporal backward reverse-window flat-bucket state identity. + +- Active invariant: + the recurrent reverse dependency edge must consume hidden-before state in + backend flat-bucket order. Recomputing it inside each Python reverse step by + slicing graph-order `cells_prev` and reordering hides a piece of the temporal + reverse window from the CUDA owner, and reset rows must be applied before the + recurrent initial-K/V reverse edge sees that state. 
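- Reference-semantics sketch (illustrative; simplified shapes, batch-level
  resets assumed): the materialization described above is, per window,
  hidden-before at local step 0 taken from the checkpoint, hidden-before at
  step `t > 0` taken from hidden-after at `t - 1`, with rows zeroed wherever
  the physical reset window marks that step. The exact boundary-versus-
  transition reset split is owned by the extension; this only shows the
  indexing shape.

  ```python
  import torch


  def materialize_hidden_before_window(
      checkpoint_hidden: torch.Tensor,    # [B, R, H], backend order, at window start
      hidden_after_window: torch.Tensor,  # [P, B, R, H], replayed hidden-after per step
      reset_window: torch.Tensor,         # [P, B] bool, physical reset rows
  ) -> torch.Tensor:
      steps = hidden_after_window.shape[0]
      before = torch.empty_like(hidden_after_window)
      before[0] = checkpoint_hidden
      if steps > 1:
          before[1:] = hidden_after_window[:-1]
      # Reset rows see zero hidden-before state at the reset step.
      return before.masked_fill(reset_window[:, :, None, None], 0.0)
  ```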
+- Implementation: + added `materialize_recurrent_hidden_before_window` to the Fabric temporal + backward CUDA extension. Given the checkpoint hidden state, the replayed + recurrent-hidden-after window, and a physical reset window, it materializes + `[physical_steps, B, recurrent, H]` hidden-before tensors in backend order + with reset rows zeroed. CUDA replay artifact construction now calls this + kernel once per recompute window and stores + `recurrent_hidden_before_backend_order` on `TemporalBucketStepArtifacts`. + The active recurrent-message/initial-KV backward edge consumes this flat + hidden-before artifact instead of rebuilding it from graph-order cells per + physical step. Non-CUDA/stored paths still populate the same artifact field + when they have materialized cells, preserving the shared flat-bucket ABI. +- Evidence: + GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_hidden_before_window` and + `/tmp/cortical_triton_${USER}_redo_fixmass_hidden_before_window`. + Low-level temporal backward glue parity passed `1 passed in 40.28s`. + High-level mixed K>1 terminal and provided-state backward parity passed + `2 passed in 191.29s`. The K=128 mixed-pop backward smoke passed + `1 passed in 7.23s`. Reset parity plus recurrent-message window parity passed + `5 passed in 45.93s`. +- Non-closure note: + this is real CUDA reverse-window preparation, but it does not move the + loop-carried transition backward plus recurrent-message backward recurrence + out of `_run_backward_window`. R4 remains open. The next owner is still the + CUDA temporal reverse superop that carries those dependencies inside the + device-side flat-bucket reverse scan, followed by H/K/T throughput audits. + +### 2026-04-28 UTC - R4 CUDA recurrent reverse full-cell carry materialization + +Status: ACCEPTED AS PARTIAL KERNEL/ABI SLICE; R4 REVERSE SUPEROP OPEN + +Owner: CUDA temporal backward recurrent dependency edge and flat-cell carry +materialization. + +- Active invariant: + the recurrent-message reverse edge is the producer of the carry gradient for + the previous physical step. That carry is a flat-cell identity tensor, so the + hot reverse path should not rebuild it through a Python graph-order reorder + plus a separate per-step recurrent-state materialization kernel. The backend + has enough flat identity metadata to materialize the full carry directly. +- Implementation: + added + `recurrent_message_initial_kv_backward_window_with_state` to the Fabric + temporal backward CUDA extension. It runs the existing window recurrent + message plus initial K/V backward edge, then scatters the backend-order + hidden adjoint into `[T, B, full_cells, H]` using the runtime-provided + backend recurrent inverse-order tensor and recurrent slice start. The active + K>1 temporal backward path now tries this ABI first, records + `temporal_backward_glue:cuda_recurrent_message_initial_kv_backward_window_state`, + and consumes the returned full flat-cell carry directly. The old + `cuda_recurrent_state_grad_materialize` fallback remains for unsupported + layouts, but the supported mixed CUDA temporal path no longer launches it per + physical step. +- Evidence: + GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_window_state` and + `/tmp/cortical_triton_${USER}_redo_fixmass_window_state`. Low-level recurrent + window parity, including nontrivial backend inverse-order full-cell carry + scatter, passed `1 passed in 39.26s`. 
High-level mixed K>1 terminal-loss and + provided-state parity passed `2 passed in 196.12s`, with launch metadata + asserting the new full-cell carry ABI and absence of the old per-step + recurrent-state materialization tag. The K=128 mixed-pop backward smoke + passed `1 passed in 7.20s`. Hygiene passed: + `python -m py_compile ...`, `uv run ruff format --check ...`, + `uv run ruff check ...`, and `git diff --check`. +- Non-closure note: + this closes the recurrent reverse full-cell carry materialization substage, + not R4. `_run_backward_window` still carries the transition backward plus + recurrent-message dependency loop in Python. The next owner remains the CUDA + temporal reverse superop that moves transition backward and recurrent-message + backward into one device-side flat-bucket reverse scan; only after that can + H/K/T throughput audits close R4. + +### 2026-04-28 UTC - R4 CUDA diagonal transition-core reverse window scan + +Status: ACCEPTED AS PARTIAL KERNEL/ABI SLICE; R4 REVERSE SUPEROP OPEN + +Owner: CUDA temporal backward transition-core reverse scan. + +- Active invariant: + transition families are physical op buckets under the shared flat temporal + engine. The reverse temporal owner must be able to scan a physical window in + device code and carry state adjoints backward across reset boundaries, rather + than asking Python to call one transition core per physical step. +- Implementation: + added `diagonal_recurrence_core_backward_window` to the Fabric temporal + backward CUDA extension. The kernel scans `[T, B, R, H]` diagonal recurrence + core tensors backward inside CUDA, carries `hc1/hc2` adjoints from the final + window boundary to the initial boundary, applies reset-row masking inside the + reverse scan, writes per-step `grad_cell_input`, and accumulates + `nu_log/theta_log/w1/w2` gradients across `T*B`. The Python wrapper exposes + `try_diagonal_recurrence_core_backward_window_cuda`, and the low-level test + compares it against the existing scalar stepwise + `diagonal_recurrence_backward_cuda` loop with reset rows and final-state + gradient seeds. +- Evidence: + GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_diag_reverse_window` and + `/tmp/cortical_triton_${USER}_redo_fixmass_diag_reverse_window`. The new + diagonal reverse-window scalar parity test passed `1 passed in 41.12s`. + Adjacent temporal backward glue, recurrent-message window, and K=128 + mixed-pop backward smoke passed `3 passed in 190.82s`. Hygiene passed: + `python -m py_compile ...`, `uv run ruff format --check ...`, + `uv run ruff check ...`, and `git diff --check`. +- Non-closure note: + this is the first transition-core reverse scan kernel for the CUDA temporal + backward superop. It is not yet wired as the active transition/message + dependency owner because recurrent-message backward still has to be fused + with transition backward for semantic correctness. R4 remains open until the + CUDA temporal reverse superop carries both transition backward and + recurrent-message backward together over flat bucket identity. + +### 2026-04-28 UTC - R4 CUDA gated logspace transition-core reverse window scan + +Status: ACCEPTED AS PARTIAL KERNEL/ABI SLICE; R4 REVERSE SUPEROP OPEN + +Owner: CUDA temporal backward transition-core reverse scan. + +- Active invariant: + gated logspace recurrence is a physical transition bucket under the shared + flat temporal engine. 
Its reverse core over `c/n/m` state must scan a + physical `[T,B,R,H]` window in CUDA, including reset-boundary masking and + hidden-width outnorm reductions, instead of relying on Python to call the + scalar core once per physical step. +- Implementation: + added `gated_logspace_core_backward_window` to the Fabric temporal backward + CUDA extension. The kernel launches one block per flat `(batch, receiver)` + row, uses the hidden dimension as the block reduction axis for outnorm + backward, scans the window backward over time, writes per-step `grad_raw`, + carries `c/n/m` adjoints to the previous window boundary, masks carry across + reset rows, and accumulates `outnorm_weight` gradients across `T*B`. The + Python wrapper exposes + `try_gated_logspace_recurrence_core_backward_window_cuda`, and the standalone + test compares the window ABI against the existing scalar + `gated_logspace_recurrence_outnorm_backward_cuda` schedule with reset rows, + non-power-of-two hidden width, a first-state `n=0` row, terminal state seeds, + and public/output gradients. +- Evidence: + GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_gated_reverse_window` and + `/tmp/cortical_triton_${USER}_redo_fixmass_gated_reverse_window`. Focused + gated reverse-window scalar parity passed + `1 passed, 327 deselected in 40.57s`. The adjacent temporal extension guard + covering temporal glue, diagonal reverse scan, gated reverse scan, recurrent + message window, and the mixed-pop K=128 backward smoke passed + `5 passed, 323 deselected in 191.85s`. Hygiene passed: + `python -m py_compile ...`, `uv run ruff format --check ...`, + `uv run ruff check ...`, and `git diff --check`. +- Non-closure note: + this closes the gated transition-core scan ABI substage. It still does not + close R4 because the active `_run_backward_window` host loop carries the + semantic dependency through gated recurrent-affine `y` state and the + recurrent-message/initial-KV edge. The next R4 owner is the combined CUDA + temporal backward superop that fuses transition core, recurrent affine/input + projection gradient production, and recurrent-message reverse carry over flat + bucket identity; only then can the Python physical-step loop be removed and + H/K/T throughput audits close. + +### 2026-04-28 UTC - R4 CUDA gated y-carry recurrent-affine reverse window scan + +Status: ACCEPTED AS PARTIAL KERNEL/ABI SLICE; R4 REVERSE SUPEROP OPEN + +Owner: CUDA temporal backward gated transition y-state dependency. + +- Active invariant: + gated logspace transition backward cannot be a true temporal reverse scan if + the `y` state carry is produced later by a per-step Python recurrent-affine + backward call. The recurrent-affine dependency is physical transition math + over flat bucket identity and must be carried inside the same CUDA reverse + scan that handles gated core `c/n/m` state. +- Implementation: + added + `gated_logspace_core_recurrent_affine_backward_window` to the Fabric temporal + backward CUDA extension. The kernel scans a `[T,B,R,H]` gated bucket window + backward, performs the outnorm/core backward reductions, stores per-step + `grad_raw`, immediately applies the recurrent-affine transpose for the same + physical step, carries `y/c/n/m` adjoints to the previous window boundary, + masks all carries across reset rows, and accumulates both + `recurrent_kernel` and `outnorm_weight` gradients across `T*B`. The Python + wrapper exposes + `try_gated_logspace_recurrence_core_recurrent_affine_backward_window_cuda`. 
+ The standalone parity test compares against the existing scalar gated core + backward plus an explicit `einsum("bngho,nghoi->bnhi")` recurrent-affine + backward schedule with reset rows, nontrivial `head_dim`, terminal state + seeds, and public/output gradients. +- Evidence: + GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_gated_affine_window` and + `/tmp/cortical_triton_${USER}_redo_fixmass_gated_affine_window`. Focused + fused gated y-carry parity passed + `1 passed, 328 deselected in 41.42s`. The adjacent temporal extension guard + covering temporal glue, diagonal reverse scan, gated core reverse scan, + fused gated core/recurrent-affine reverse scan, recurrent-message window, and + the mixed-pop K=128 backward smoke passed + `6 passed, 323 deselected in 191.22s`. Hygiene passed: + `python -m py_compile ...`, `uv run ruff format --check ...`, + `uv run ruff check ...`, and `git diff --check`. +- Non-closure note: + this removes the main gated transition `y`-carry semantic blocker at the ABI + level, but it is not yet wired into active `_run_backward_window`. + R4 remains open until the active temporal backward path calls a combined + CUDA temporal superop that also fuses recurrent-message/initial-KV reverse + carry and the remaining gate/input projection gradient production over the + same flat bucket window. After that integration, the Python physical-step + loop can be removed and H/K/T throughput audits can run as closure evidence. + +### 2026-04-28 UTC - R4 active gated y-carry window ABI consumption + +Status: ACCEPTED AS ACTIVE-PATH BACKEND CONSUMPTION SLICE; R4 REVERSE SUPEROP OPEN + +Owner: active CUDA temporal backward gated transition edge. + +- Active invariant: + the supported mixed-pop temporal backward path should consume the fused gated + core plus recurrent-affine reverse-window ABI rather than keeping the + transition `y` carry split across a scalar gated core launch and a separate + recurrent-affine backward launch. This keeps the active path aligned with the + eventual CUDA temporal reverse superop without adding population-specific + logic. +- Implementation: + `_lower_gated_logspace_recurrence_backward` now tries the + `gated_logspace_core_recurrent_affine_backward_window` ABI with `T=1` for + supported CUDA float32 gated transition buckets. On success it consumes the + fused `grad_raw`, `y/c/n/m` carry, `recurrent_kernel`, and `outnorm_weight` + gradients and skips the old per-step recurrent-affine backward launch. On + unsupported shapes it falls back to the previous scalar gated core plus + recurrent-affine path. Reset semantics remain explicit: reset rows are passed + into the fused ABI and state gradients are still omitted when + `need_grad_packed_state_before=False`. Active metadata records + `temporal_backward_glue:cuda_gated_core_recurrent_affine_window` so audits can + prove the path physically moved. +- Evidence: + GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_active_gated_affine` and + `/tmp/cortical_triton_${USER}_redo_fixmass_active_gated_affine`. High-level + mixed-pop terminal-loss, provided-state gradient, and K=128 backward smoke + passed with active launch metadata assertions: + `3 passed, 326 deselected in 5.97s`. Adjacent guards covering temporal glue, + diagonal reverse scan, gated core reverse scan, fused gated y-carry reverse + scan, recurrent-message window, single-pop K>1 parity, and mixed-pop K>1 + parity passed `8 passed, 321 deselected in 5.01s`. 
The existing mixed-pop + K=1 output-only owner check, which guards the restored forward metadata + assertion, passed `2 passed, 327 deselected in 3.99s`. Hygiene passed: + `python -m py_compile ...`, `uv run ruff format --check ...`, + `uv run ruff check ...`, and `git diff --check`. +- Non-closure note: + this closes active consumption of the fused gated `y/c/n/m` reverse-window + edge. R4 remains open because `_run_backward_window` still owns the + physical-step reverse loop and the recurrent-message/initial-KV edge still + has to be fused with transition backward into one CUDA temporal reverse + superop before throughput audits can close H/K/T. + +### 2026-04-28 UTC - R4 active diagonal core window ABI consumption + +Status: ACCEPTED AS ACTIVE-PATH BACKEND CONSUMPTION SLICE; R4 REVERSE SUPEROP OPEN + +Owner: active CUDA temporal backward diagonal transition edge. + +- Active invariant: + both transition families in the supported mixed-pop temporal path should + consume reverse-window CUDA ABIs on the live path. Leaving diagonal recurrence + on the scalar per-step core after the window kernel existed would keep the + future CUDA temporal reverse superop unexercised for one physical bucket + family. +- Implementation: + `_lower_diagonal_recurrence_backward` now tries + `diagonal_recurrence_core_backward_window` with `T=1` for supported CUDA + float32 diagonal buckets whose trace states are eligibility traces and do not + require reverse propagation. The fused path handles missing `grad_preproj` + with a zero preprojection gradient window, supports one-sided `hc1/hc2` + carry seeds by filling the missing side with zeros, passes reset rows into + the ABI, returns the same `grad_cell_input` and `nu_log/theta_log/w1/w2` + gradients, and leaves unsupported trace-gradient cases on the existing + scalar diagonal backward. Active metadata records + `temporal_backward_glue:cuda_diagonal_recurrence_core_backward_window`. +- Evidence: + GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_active_diagonal_window` and + `/tmp/cortical_triton_${USER}_redo_fixmass_active_diagonal_window`. + High-level mixed-pop terminal-loss, provided-state gradient, and K=128 + backward smoke passed with active gated and diagonal window metadata + assertions: `3 passed, 326 deselected in 223.86s`. Adjacent guards covering + temporal glue, diagonal reverse scan, gated core reverse scan, fused gated + y-carry reverse scan, recurrent-message window, single-pop K>1 parity, and + mixed-pop K>1 parity passed `8 passed, 321 deselected in 48.18s`. Hygiene + passed: `python -m py_compile ...`, `uv run ruff format --check ...`, + `uv run ruff check ...`, and `git diff --check`. +- Non-closure note: + this closes active consumption of the diagonal transition-core reverse-window + ABI. R4 remains open until the active backward window no longer iterates + physical steps in Python and instead fuses transition buckets plus + recurrent-message/initial-KV reverse carry into one CUDA temporal reverse + superop over flat bucket identity. + +### 2026-04-28 UTC - R4 gated gate-affine parameter window deferral + +Status: ACCEPTED AS PARTIAL BACKEND PARAM-REDUCTION SLICE; R4 REVERSE SUPEROP OPEN + +Owner: active CUDA temporal backward gated receiver-affine parameter reduction. + +- Active invariant: + gated gate-affine weight/bias accumulation is temporal-window parameter + reduction work, not per-physical-step transition logic. 
The per-step + gate-affine input gradient still has a real dependency because it feeds the + recurrent-message backward edge, but the weight/bias reduction can be grouped + by flat bucket identity and reduced once over the window. +- Implementation: + `TransitionBackwardResult` now carries generic deferred transition parameter + gradient steps in addition to the existing input-projection step. The + flat-bucket backward accumulator forwards all deferred steps into the existing + temporal param-window reducer. The receiver-major param reducer now supports + rank-4 head-grouped gate affine weights, and + `_lower_gated_logspace_recurrence_backward` defers gated `gate_weight` and + gate `bias` gradients when temporal param binding is active while still + producing the per-step gate input gradient needed by recurrent-message + backward. +- Evidence: + GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_gate_param_window` and + `/tmp/cortical_triton_${USER}_redo_fixmass_gate_param_window`. High-level + mixed-pop terminal-loss, provided-state gradient, and K=128 backward parity + passed `3 passed, 326 deselected in 229.70s`. Adjacent temporal guards + covering temporal glue, diagonal reverse scan, gated reverse scan, fused + gated y-carry reverse scan, recurrent-message window, single-pop K>1 parity, + and mixed-pop K>1 parity passed `8 passed, 321 deselected in 46.35s`. + Current-code owner timing for K=128 reports + `receiver_affine.gate_affine_backward:ms=422.116;count=128` and + `receiver_affine.input_projection_param_window:ms=13.606;count=4`. +- Non-closure note: + this is correctness-preserving and removes gate-affine parameter reductions + from the physical-step loop, but it does not close the hot owner. The + remaining `gate_affine_backward` cost is the per-step input-gradient + dependency. R4 remains open until transition backward and + recurrent-message/initial-KV backward are fused into one CUDA temporal + reverse superop over flat bucket identity, eliminating the Python + physical-step scan loop rather than only deferring parameter reductions. + +### 2026-04-28 UTC - R4 fused transition input-gradient window slices + +Status: ACCEPTED AS ACTIVE-PATH CUDA FUSION SLICE; R4 REVERSE SUPEROP OPEN + +Owner: active CUDA temporal backward transition-to-message gradient edge. + +- Active invariant: + transition-local input-gradient work that feeds the recurrent-message edge + should be produced by the same CUDA temporal transition window that computes + the local recurrence adjoints. It must not remain a separate per-step + receiver-affine launch once the core/recurrent-affine window kernel has the + required local gradients and weights. +- Implementation: + extended the gated logspace core plus recurrent-affine CUDA window ABI with + an optional gate-affine input-gradient output. The active gated transition + path passes `gate_weight` into that kernel, consumes the returned + `grad_population_input`, and records + `temporal_backward_glue:cuda_gated_core_recurrent_affine_input_window`. + Gate `weight/bias` reductions remain deferred through the temporal + param-window reducer. For diagonal recurrence, extended the core reverse + window ABI with optional output-projection gradient and output weight inputs; + the kernel now computes the output-projection input gradient internally + before applying the diagonal recurrence adjoint. The active K=128 path + records + `temporal_backward_glue:cuda_diagonal_recurrence_core_output_projection_window`. 
+ Diagonal output-projection weight/bias reductions are deferred through the + same flat-bucket temporal param-window reducer. +- Evidence: + GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_diagonal_output_fused` and + `/tmp/cortical_triton_${USER}_redo_fixmass_diagonal_output_fused`. Focused + low-level gated recurrent-affine window parity, including the new gate-input + output, passed `1 passed, 328 deselected in 41.11s`. High-level mixed-pop + terminal-loss, provided-state gradient, and K=128 backward parity passed + `3 passed, 326 deselected in 5.89s` after the final active-path condition. + Adjacent temporal guards covering temporal glue, diagonal reverse scan, + gated reverse scan, fused gated y-carry reverse scan, recurrent-message + window, single-pop K>1 parity, and mixed-pop K>1 parity passed + `8 passed, 321 deselected in 4.89s`. +- Owner movement: + K=128 current-code timing now reports + `receiver_affine.gate_affine_backward:ms=2.265;count=128`, down from the + prior `422.116 ms` after the param-only slice, and includes the new gated + input-window launch metadata. The earlier separate diagonal + `output_projection_backward` top owner is gone after broadening the active + condition to materialized and recomputed `preproj` cases. The current top + owner is + `receiver_affine.input_projection_backward:ms=453.892;count=256`. +- Non-closure note: + this removes two per-step transition-local receiver-affine launches from the + hot path, but R4 is still open. The next owner is the shared input-projection + backward that maps transition-local gradients back to recurrent messages for + both gated and diagonal buckets. That edge must move into the same shared + temporal reverse superop together with recurrent-message/initial-KV backward + before the Python physical-step reverse loop can close. + +### 2026-04-28 UTC - R4 fused transition input-projection backward windows + +Status: ACCEPTED AS ACTIVE-PATH CUDA FUSION SLICE; R4 REVERSE SUPEROP OPEN + +Owner: active CUDA temporal backward transition input-projection edge. + +- Active invariant: + after transition-local adjoints are produced inside the temporal window, the + input-projection input gradient that maps them back to recurrent-message + space is still transition-window work. Weight/bias accumulation can remain a + deferred temporal param-window reduction, but the per-step input-gradient + matmul should not be a separate receiver-affine launch for each physical + bucket step. +- Implementation: + extended the gated temporal reverse window with optional rank-2 + input-projection weight support. When active, the same CUDA kernel that + computes gated core, recurrent-affine, and gate-affine input adjoints also + emits `grad_recurrent_msg_window`, and `_lower_gated_logspace_recurrence_backward` + consumes it while creating only a deferred `value_to_cell_weight` / + `recurrent_cell_bias` param step. Extended the diagonal temporal reverse + window with optional receiver-major rank-3 input-projection weight support. + The diagonal kernel now emits `grad_recurrent_msg_window` from + `grad_cell_input_window`, and `_lower_diagonal_recurrence_backward` consumes + it while keeping the diagonal input-projection weight/bias reduction on the + temporal param-window reducer. Active metadata records + `temporal_backward_glue:cuda_gated_input_projection_backward_window` and + `temporal_backward_glue:cuda_diagonal_input_projection_backward_window`. 
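+- Reference sketch (illustrative only): in plain einsum terms the split above is
+  roughly the following; the per-step message-space gradient must stay in-window
+  because it feeds recurrent-message backward, while the weight reduction has no
+  step-to-step dependency and can go through the param-window reducer. Shapes and
+  names are assumptions.
+
+  ```python
+  import torch
+
+  def input_projection_backward_window_reference(grad_cell_input_window, recurrent_msg_window, weight):
+      # grad_cell_input_window: [T, B, R, H] transition-local input adjoints for the window
+      # recurrent_msg_window:   [T, B, R, M] saved recurrent-message inputs for the window
+      # weight:                 [H, M]       rank-2 input-projection weight (gated case)
+      # Needed per physical step: the gradient consumed by the recurrent reverse edge.
+      grad_recurrent_msg_window = torch.einsum("tbrh,hm->tbrm", grad_cell_input_window, weight)
+      # No step-to-step dependency: reduce the weight gradient once over the window.
+      grad_weight = torch.einsum("tbrh,tbrm->hm", grad_cell_input_window, recurrent_msg_window)
+      return grad_recurrent_msg_window, grad_weight
+  ```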
+- Evidence: + GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_input_projection_window` and + `/tmp/cortical_triton_${USER}_redo_fixmass_input_projection_window`. Focused + low-level diagonal and gated temporal window tests, including the new + `grad_recurrent_msg` outputs, passed + `2 passed, 327 deselected in 43.17s`. High-level mixed-pop terminal-loss, + provided-state gradient, and K=128 backward parity passed + `3 passed, 326 deselected in 193.89s`. Adjacent temporal guards covering + temporal glue, diagonal reverse scan, gated reverse scan, fused gated + y-carry reverse scan, recurrent-message window, single-pop K>1 parity, and + mixed-pop K>1 parity passed `8 passed, 321 deselected in 45.62s`. +- Owner movement: + `receiver_affine.input_projection_backward` dropped from the prior + `453.892 ms;count=256` to `24.167 ms;count=256` on the first timed K=128 + run and `23.672 ms;count=256` on a warmed confirmation. The new active + launch metadata for both gated and diagonal input-projection window fusion + is present. The current top owner in the warmed K=128 smoke is now + `public_projection:ms=388.966;count=2`, followed by diagonal/gated core and + recurrent-message owners around the 40 ms class. +- Non-closure note: + this closes the input-projection backward substage, but R4 remains open. + The next owner is the public projection backward path, which is now exposed + after the transition input-projection launches moved into the CUDA temporal + windows. The Python physical-step loop is also still present; final R4 + closure still requires a single CUDA temporal reverse superop over flat + bucket identity that owns transition, recurrent-message/initial-KV, public + projection, checkpoint/recompute, and carry materialization semantics. + +### 2026-04-28 UTC - R4 public-projection timing correction and next owner + +Status: PUBLIC-PROJECTION OWNER RECLASSIFIED; R4 REVERSE SUPEROP OPEN + +Owner: warmed active CUDA temporal backward reverse-loop owners. + +- Correction: + the apparent `public_projection:ms~=389-409;count=2` top owner after the + input-projection fusion was a first-use Triton specialization inside the + grouped input-boundary K/V weight-gradient reducer, not steady-state Fabric + backend work. The measured shapes were generic boundary projection shapes, + `sender=(64, 4, 8)`, grouped weight `(2, 8, 8)`, and gradient output + `(64, 4, 8)`. In-place instrumentation showed the first grouped + weight-gradient call at `~407.7 ms` and the second at `~0.09 ms`; a raw + warmed microbenchmark of the same grouped backward was `~0.046 ms`. +- Current warmed evidence: + rerunning the exact K=128/H=64 mixed-pop smoke twice in one process with GPU + 0 and private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_public_projection_diag` and + `/tmp/cortical_triton_${USER}_redo_fixmass_public_projection_diag` produced + a second-pass owner profile of + `diagonal_recurrence.core:ms=42.850;count=128`, + `state_epilogue.core:ms=40.992;count=128`, + `message.recurrent_initial_kv_backward:ms=40.426;count=128`, + `receiver_affine.input_projection_backward:ms=23.387;count=256`, + `temporal_artifact_recompute:ms=18.066;count=2`, and + `public_projection:ms=0.480;count=2`. +- Accepted next gate: + do not spend R4 work on the public-projection boundary reducer unless a + warmed audit row makes it hot again. 
The high-priority owner is still the + Python physical-step reverse loop itself: diagonal core, gated state + epilogue, recurrent-message/initial-KV, and the remaining transition-local + glue are being launched per physical K step. Closure requires moving the + scan loop into a shared CUDA temporal reverse superop over flat bucket + identity, with Python only passing tensors/metadata and binding returned + gradients. + +### 2026-04-28 UTC - R4 recurrent K/V param-window deferral + +Status: ACCEPTED AS ABI/BACKEND SLICE; R4 REVERSE SUPEROP OPEN + +Owner: recurrent-message/initial-KV backward parameter edge. + +- Active invariant: + recurrent K/V projection weight gradients produced during recurrent-message + reverse do not feed the reverse carry for the previous physical step. The + carry still needs input K/V gradients and recurrent hidden gradients in-step, + but recurrent K/V weight accumulation can be lowered as a temporal + window-level parameter reduction rather than as one weight-gradient launch per + physical step. +- Implementation: + extended the flat-bucket temporal recurrent-message/initial-KV CUDA ABI with + optional `return_weight_grad`, optional recurrent K/V gradient outputs, and a + backend-order recurrent K/V weight-gradient window reducer. The active + temporal backward path now skips per-step recurrent K/V weight gradients when + `defer_initial_recurrent_param_binding` is enabled, stores backend-order + hidden plus recurrent K/V adjoints in `TemporalSenderKVProjectionWindowParamGrad`, + and reduces them once in + `_reduce_temporal_initial_recurrent_raw_param_grads`. Active metadata records + `temporal_backward_glue:cuda_initial_recurrent_kv_param_grad_window` in + addition to the existing reduce tag. +- Rejected probe: + an optional fused hidden-gradient mode was added to the low-level recurrent + message kernel and covered by parity, but it was not enabled on the active + path. The warmed K=128 row was slightly slower when the message kernel + directly accumulated hidden gradients, so the active path keeps the existing + separate hidden-gradient kernel while using only the parameter-window + deferral. +- Evidence: + GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_recurrent_kv_hidden_fused` and + `/tmp/cortical_triton_${USER}_redo_fixmass_recurrent_kv_hidden_fused`. + Focused recurrent-message temporal window parity, including skipped + per-step weight grad, deferred window weight grad, and optional fused hidden + grad, passed `1 passed, 328 deselected in 40.72s`. High-level mixed-pop + terminal-loss, provided-state gradient, and K=128 backward parity passed + `3 passed, 326 deselected in 189.89s`. +- Owner movement: + the active K=128 second-pass warmed timing after disabling the slower hidden + fusion was + `diagonal_recurrence.core:ms=42.081;count=128`, + `message.recurrent_initial_kv_backward:ms=41.389;count=128`, + `state_epilogue.core:ms=40.350;count=128`, + `receiver_affine.input_projection_backward:ms=23.352;count=256`, and + `public_projection:ms=0.399;count=2`. This slice removes the recurrent K/V + parameter-gradient work from the per-step edge and makes the parameter + reduction window-owned, but it does not close R4 because the warmed top owners + remain the per-step transition/message reverse loop. +- Next gate: + continue with the full CUDA temporal reverse superop. 
The next performance + owner is not a parameter binding or public projection; it is the physical + scan dependency itself, where transition backward, recurrent-message + backward, carry propagation, and checkpoint/recompute are still orchestrated + by `_run_backward_window` one physical step at a time. + +### 2026-04-28 UTC - R4 reverse dependency superop continuation + +Status: ACCEPTED AS ACTIVE-PATH CUDA FUSION SLICE; R4 REVERSE SUPEROP OPEN + +Owner: full CUDA temporal reverse superop over flat bucket identity. + +- Active invariant: + the current code has real CUDA kernels for transition-core reverse windows, + recurrent-message/initial-KV reverse windows, hidden-before materialization, + recurrent-query reduction, boundary accumulation, and parameter reductions. + The open R4 owner is still the dependency recurrence that connects those + edges: transition backward at physical step `t-1` needs the recurrent-message + hidden adjoint produced from step `t`, so `_run_backward_window` still owns a + reverse Python loop across physical steps. +- Recovered-core check: + `ai_docs/recovered_core.py` confirms the April 26 direction was a + `SharedTemporalBackwardHelperMixin` plus planner-owned temporal execution + fields, not benchmark-side streaming or a sibling route. No recovered CUDA + reverse-superop implementation is present in the recovered file, so current + code must build this boundary intentionally. +- Current next gate: + add backend-owned reverse-window ABI/metadata that reduces real loop work and + prepares a single combined transition/message reverse scan. This must stay + flat-bucket based: low-level transition families may appear as + `fabric.cuda.nn` physical op kinds, but user population names, single/mixed + cardinality, and benchmark rows must not drive selection. +- Non-goals for this slice: + do not relabel `temporal_plan_backward_owners` as `cuda_temporal_superop` + until the active reverse dependency loop physically moves; do not chase + cold-only `public_projection`; do not add cell-name-specific backend logic. +- Implementation: + fused the active recurrent-message `with_state` non-direct-hidden path so the + backend-order recurrent K/V hidden projection and full-cell gradient + materialization run in one CUDA window kernel. This removes one adjacent + projection/materialization launch from each recurrent-message reverse edge + while preserving the same flat bucket recurrent-message ABI and without + branching on user population names or benchmark rows. The direct-hidden + optional path remains supported separately because it was parity-clean but + slower on the active warmed row in the previous slice. +- Evidence: + GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_recurrent_with_state_fused` + and `/tmp/cortical_triton_${USER}_redo_fixmass_recurrent_with_state_fused`. + Focused recurrent-message temporal window parity passed + `1 passed, 328 deselected in 40.56s`. High-level mixed-pop K>1 terminal, + provided-state gradient, and K=128 rows passed + `3 passed, 326 deselected in 187.43s`. Adjacent temporal backward guard + passed `8 passed, 321 deselected in 49.35s`. +- Owner movement: + the warmed second pass for the K=128/H=64 mixed-pop row reported + `message.recurrent_initial_kv_backward:ms=38.996;count=128`, compared to the + previous accepted warmed value of about `41.389 ms`. 
Other dominant owners + remained the reverse scan dependency: + `diagonal_recurrence.core:ms=42.033;count=128`, + `state_epilogue.core:ms=40.730;count=128`, + `receiver_affine.input_projection_backward:ms=25.036;count=256`, and + `temporal_artifact_recompute:ms=17.173;count=2`. First-pass + `public_projection` was again cold specialization and warmed to + `0.400 ms`. +- Next gate: + R4 remains open. The accepted slice reduces a real active recurrent-message + edge, but the actual reverse recurrence is still driven by + `_run_backward_window`. The next high-priority owner is still the combined + CUDA temporal reverse superop that carries transition backward, + recurrent-message backward, carry propagation, and checkpoint/recompute + across the physical window. + +### 2026-04-28 UTC - R4 combined reverse-superop ABI slice + +Status: IN PROGRESS + +Owner: full CUDA temporal reverse superop over flat bucket identity. + +- Active invariant: + the next valid backend slice must reduce Python-owned reverse dependency + work or create active ABI for a single backend-owned reverse scan. It must + remain flat-bucket based: no cell-family route forks, no single-pop versus + mixed-pop identities, and no benchmark-owned K/H/T loop policy. +- Starting diagnosis: + `_run_backward_window` still drives physical steps in Python. The warmed + K=128/H=64 owner stack is transition core, gated state epilogue, recurrent + message/initial-KV, and transition input-projection glue, all repeated per + physical step. The next edit will target backend-owned recurrent-message / + transition reverse ABI needed for the combined CUDA temporal scan and will + be accepted only with low-level parity plus high-level K>1 parity. +- Guardrails: + GPU runs use devices 0-4 only with private torch/triton caches. Do not mark + `temporal_plan_backward_owners=cuda_temporal_superop` until the active scan + loop itself moves and throughput audits pass. +- Current edit target: + add a sender-reverse-table CUDA ABI for the recurrent-message/initial-KV + temporal backward edge. This uses the live flat bucket identity + `local_receiver_idx_by_sender` tables already present in the runtime and + prepares the combined reverse scan to be sender-owned rather than relying on + receiver-major atomics. The edit must not introduce population-name or + cell-family routing; physical op families may remain only as lower-level + transition kernels. + +### 2026-04-28 UTC - R4 recurrent reverse ABI probe and hidden-output trim + +Status: ACCEPTED ACTIVE TRIM; SENDER-REVERSE ACTIVE PROBE REJECTED; R4 OPEN + +Owner: recurrent-message/initial-KV reverse edge inside the Python physical +reverse loop. + +- Implemented: + added a sender-reverse-table CUDA ABI for recurrent-message/initial-KV + temporal backward. The ABI accepts flat-bucket `sender_receiver_idx` tables + and runs a receiver stats/grad-query kernel followed by a sender-owned K/V + gradient kernel. This is covered by low-level parity and remains available + for the future combined reverse superop, but it is not the active hot path + for the current K=128/H=64 row. +- Rejected active probe: + enabling sender-reverse on the active with-state recurrent edge raised the + warmed `message.recurrent_initial_kv_backward` owner to + `48.980 ms;count=128`, compared with the prior accepted `~38.996 ms` class. + The active path was therefore restored to the receiver-major fused message + kernel. The sender-reverse ABI is retained as backend surface, not closure + evidence. 
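+- Reference sketch (illustrative only): the receiver-major path that stayed
+  active and the sender-reverse table ABI differ mainly in which side owns the
+  K/V gradient reduction. Index-table names and shapes are assumptions, and
+  padding handling is omitted.
+
+  ```python
+  import torch
+
+  def kv_grad_receiver_major(grad_msg, sender_idx_by_receiver, num_senders):
+      # grad_msg: [B, R, H]; sender_idx_by_receiver: [R] flat-bucket sender per receiver.
+      # Each receiver row scatters into its sender slot; shared senders need atomics on GPU.
+      grad_kv = grad_msg.new_zeros(grad_msg.shape[0], num_senders, grad_msg.shape[-1])
+      grad_kv.index_add_(1, sender_idx_by_receiver, grad_msg)
+      return grad_kv
+
+  def kv_grad_sender_owned(grad_msg, receiver_idx_by_sender):
+      # receiver_idx_by_sender: [S, fan_out] receiver indices owned by each sender row.
+      # Each sender gathers its receivers and reduces locally, with no atomics.
+      gathered = grad_msg[:, receiver_idx_by_sender, :]  # [B, S, fan_out, H]
+      return gathered.sum(dim=2)
+  ```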
+- Accepted active trim: + the with-state recurrent-message CUDA path can now skip returning the + backend-order hidden-gradient window when the caller already consumes the + full flat-cell carry materialized by the same kernel. The active temporal + backward bridge requests `return_hidden_grad=False`, avoids indexing the + skipped tensor, and still propagates the full-cell carry plus recurrent K/V + parameter-window adjoints. +- Evidence: + GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_hidden_skip` and + `/tmp/cortical_triton_${USER}_redo_fixmass_hidden_skip`. Focused + recurrent-message parity, including sender-reverse and skipped hidden + output, passed `2 passed, 327 deselected in 3.89s` on the warmed cache + after the initial compile/parity pass. High-level mixed-pop + K>1 terminal/provided-state/K=128 parity passed + `3 passed, 326 deselected in 6.99s` after cache warm. Adjacent temporal + backward guard passed `8 passed, 321 deselected in 46.40s`. +- Owner movement: + the warmed second pass for K=128/H=64 reported + `message.recurrent_initial_kv_backward:ms=37.979;count=128`, + `diagonal_recurrence.core:ms=43.278;count=128`, + `state_epilogue.core:ms=40.163;count=128`, + `receiver_affine.input_projection_backward:ms=24.511;count=256`, and + `temporal_artifact_recompute:ms=16.346;count=2`. This is a small accepted + active recurrent-edge trim; it does not close R4 because the Python reverse + physical-step loop and per-step transition/message launches remain. +- Next gate: + continue toward the combined CUDA temporal reverse superop. Do not spend the + next pass on sender-reverse as the active hot-row replacement unless a larger + current-code row shows atomics dominate; the immediate active owners remain + the reverse scan dependency across diagonal core, gated core, and recurrent + message/carry propagation. + +### 2026-04-28 UTC - R4 backend-cache carry trim inside reverse loop + +Status: ACCEPTED AS ACTIVE-PATH HOST-CARRY TRIM; R4 REVERSE SUPEROP OPEN + +Owner: Python-owned reverse physical-step loop inside +`TemporalPhysicalBackwardScanExecutor._run_backward_window`. + +- Current invariant: + the next accepted slice must move real temporal reverse work toward the + backend-owned flat-bucket superop. It must not introduce population-name, + cell-family, benchmark-row, hidden-size, or single-pop/mixed-pop routing. The + temporal owner is shared multi-pop: one-pop and many-pop are just flat bucket + cardinalities. +- Live code refresh: + `_run_backward_window` still loops over physical steps in Python, and each + step calls transition backward before recurrent-message/initial-KV backward + can produce the previous carry. The active dependency is therefore not closed + by planner labels or audit script cleanup. Current hot rows are expected to + remain transition core, state epilogue, recurrent message/carry propagation, + and receiver-affine input projection until a combined CUDA temporal reverse + scan owns that recurrence. +- Next action: + reran the warmed K=128/H=64 owner profile on current code with GPU 0 and + private caches. Pre-edit warmed pass reported + `diagonal_recurrence.core:ms=45.469;count=128`, + `state_epilogue.core:ms=41.495;count=128`, + `message.recurrent_initial_kv_backward:ms=39.519;count=128`, + `receiver_affine.input_projection_backward:ms=25.029;count=256`, and + `temporal_artifact_recompute:ms=17.078;count=2`. 
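+- Reference sketch (illustrative only): the host-carry trim recorded in the
+  Implemented item below amounts to threading a backend-order cache through the
+  reverse loop and building public per-population gradients once at the window
+  boundary. This is control-flow shape only; step objects and field names are
+  assumptions, not the runtime API.
+
+  ```python
+  def run_backward_window_sketch(reverse_steps, convert_backend_cache_to_public):
+      backend_cache = None
+      for step in reverse_steps:
+          result = step.backward(state_grad_cache=backend_cache)
+          # Keep the carry in backend flat-bucket order inside the hot loop.
+          backend_cache = result.grad_backend_state_cache
+      # Public per-population state gradients are built only at the window boundary.
+      return convert_backend_cache_to_public(backend_cache)
+  ```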
+- Implemented: + kept backend-order population state gradients as backend-order cache across + the active physical reverse loop. When a step returns + `grad_backend_state_cache`, the next step consumes that backend cache + directly, so `_population_grad_dict(...)` conversion is now deferred until + the window boundary where public autograd state gradients are returned. This + removes repeated TensorDict/public-state conversion from the open host scan + without changing transition/message math or adding population/cell routing. +- Evidence: + GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_r4_state_cache_trim` and + `/tmp/cortical_triton_${USER}_redo_fixmass_r4_state_cache_trim`. High-level + mixed-pop K>1 terminal/provided-state/K=128 parity passed + `3 passed, 326 deselected in 224.30s`; the provided-state row covers the + backend-cache-to-public-gradient boundary. Adjacent temporal backward guard + passed `8 passed, 321 deselected in 47.75s`. +- Owner movement: + warmed post-edit K=128/H=64 pass reported + `diagonal_recurrence.core:ms=41.087;count=128`, + `state_epilogue.core:ms=39.484;count=128`, + `message.recurrent_initial_kv_backward:ms=36.883;count=128`, + `receiver_affine.input_projection_backward:ms=24.702;count=256`, and + `temporal_artifact_recompute:ms=16.918;count=2`. Treat this as a safe + active-path host-carry cleanup, not R4 closure; the same per-step CUDA + transition/message owners still dominate. +- Next gate: + R4 remains the combined CUDA temporal reverse superop. The next high-priority + implementation should move an actual transition/message dependency edge into + a backend-owned reverse scan, not spend time on planner labels or audit + cleanup. + +### 2026-04-28 UTC - R4 backend-order recurrent carry slice + +Status: ACCEPTED AS BACKEND-CARRY SLICE; R4 REVERSE SUPEROP OPEN + +Owner: recurrent-message to transition carry edge inside the Python physical reverse loop. + +- Current invariant: + the reverse loop should carry backend-owned recurrent gradients in flat bucket + backend order until a public full-cell gradient is actually required. The + current active path still asks recurrent-message backward to materialize a + full flat-cell gradient every physical step, then the next transition step + immediately slices/reorders that full tensor back to backend order. That is a + host-loop compatibility shape, not the final CUDA temporal reverse-superop + carry representation. +- Planned edit: + extend the recurrent-message/initial-KV `with_state` CUDA ABI so it can return + backend-order recurrent hidden carry without materializing full flat-cell + carry for each step. Thread an optional backend-order recurrent carry through + `_run_backward_window`, only materializing full cells at the window boundary + when autograd needs public state gradients. This must remain flat-bucket and + physical-op generic: no population names, row IDs, hidden-size constants, or + benchmark-specific routing. +- Implemented: + the recurrent-message/initial-KV `with_state` CUDA ABI now accepts + `return_full_cell_grad`. The active temporal backward window requests + backend-order recurrent hidden carry and skips the per-step full flat-cell + carry materialization; the window materializes full cells only at the public + autograd boundary. The step backward path also keeps active output-cell + gradients visible when there is backend-only recurrent carry, which preserved + mixed-pop sequence-loss/reset parity. 
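+- Reference sketch (illustrative only): the boundary-only full-cell
+  materialization this slice relies on is roughly one scatter from backend order
+  into the flat-cell layout. Index names, the slice-offset convention, and shapes
+  are assumptions.
+
+  ```python
+  def materialize_full_cell_grad_reference(backend_recurrent_grad, recurrent_inverse_order,
+                                           recurrent_slice_start, full_cells):
+      # backend_recurrent_grad:  [B, R, H] carry kept in backend flat-bucket order
+      # recurrent_inverse_order: [R]       backend-order -> flat-cell-order index map
+      # recurrent_slice_start:   int       offset of recurrent cells on the flat-cell axis
+      batch, _, hidden = backend_recurrent_grad.shape
+      full = backend_recurrent_grad.new_zeros(batch, full_cells, hidden)
+      full[:, recurrent_slice_start + recurrent_inverse_order, :] = backend_recurrent_grad
+      return full
+  ```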
+- Evidence: + GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_backend_carry` and + `/tmp/cortical_triton_${USER}_redo_fixmass_backend_carry`. Low-level + recurrent-message/initial-KV parity, including the new backend-carry/no + full-cell-output mode and reset handling, passed + `2 passed, 327 deselected in 3.37s` after warm cache. Adjacent temporal + backward guard passed `8 passed, 321 deselected in 4.58s`. High-level + mixed-pop K>1 terminal/provided-state/K=128 parity through public + `model(...); loss.backward()` passed `3 passed, 326 deselected in 5.42s`. + The reset/per-timestep mixed-pop row was explicitly rerun after the + output-cell-gradient fix and passed `1 passed, 328 deselected in 4.19s`. +- Owner movement: + warmed K=128/H=64 profile with backend recurrent carry reported + `diagonal_recurrence.core:ms=42.147;count=128`, + `state_epilogue.core:ms=39.301;count=128`, + `message.recurrent_initial_kv_backward:ms=38.172;count=128`, + `receiver_affine.input_projection_backward:ms=23.305;count=256`, and + `temporal_artifact_recompute:ms=17.384;count=2`. A temporary same-cache + full-cell-carry comparison reported + `message.recurrent_initial_kv_backward:ms=38.203;count=128`, so this slice + is throughput-neutral on the small K=128/H=64 row while removing per-step + full-cell carry materialization from the active backend path. It is accepted + as a necessary carry-representation step, not as R4 closure. +- Next gate: + move the transition/message dependency itself into the CUDA temporal reverse + superop. The current Python physical-step loop and per-step transition, + recurrent-message, and state-epilogue launches remain the R4 owner. + +### 2026-04-28 UTC - R4 full reverse-superop owner removal pass + +Status: IN PROGRESS + +Owner: full CUDA temporal reverse superop over flat bucket identity. + +- Current invariant: + redo_fixmaass cannot move into audit stages until the backward temporal + dependency recurrence is backend-owned. The active forward path already + streams `T*K` through the temporal scan superop, but backward still rebuilds + step artifacts and then runs transition backward, recurrent-message + backward, carry propagation, and state/public materialization one physical + step at a time in `_run_backward_window`. +- Closure target: + remove the Python-owned reverse dependency owner before audit. The first + acceptable path is an active CUDA temporal reverse-superop ABI that consumes + planner/tape artifacts and owns the recurrent carry across physical steps. + Low-level physical op families may appear as lowered operator kinds, but the + reverse owner must operate over flat bucket identity and must not branch on + user population names, single-vs-mixed cardinality, benchmark row, or hidden + size. +- Immediate implementation direction: + inspect the existing cooperative temporal scan and transition-backward window + kernels, then move the next dependency edge that is still orchestrated by the + Python reverse loop into backend-owned window/superop code. Accepted slices + must pass low-level CUDA parity plus high-level public + `model(...); loss.backward()` K>1/K=128 parity before being committed. + +### 2026-04-28 UTC - Shared temporal scan admits one-bucket fabrics + +Status: ACCEPTED FOR FORWARD SCAN CARDINALITY; R4 REVERSE SUPEROP OPEN + +Owner: flat-bucket temporal scan admission and CUDA ABI cardinality. + +- Current invariant: + single-population and mixed-population are the same shared temporal engine. 
+ A one-population fabric is one flat transition bucket; it must not fall back + to a separate Python temporal scan just because the other transition family + is absent. The backend identity is flat bucket identity, not + single-pop/mixed-pop routing. +- Implemented: + generalized `_try_cuda_mixed_flat_bucket_temporal_scan` and CUDA replay + artifact reconstruction so the existing cooperative temporal scan accepts + one or two transition buckets. Missing gated or diagonal families are passed + as zero-count tensors with the same ABI, so the kernel stays a shared + flat-bucket temporal superop and does not branch on population names or + benchmark rows. Final-state assembly, checkpoint state cache, and transition + tape reconstruction now include only the buckets that exist. +- CUDA ABI: + `flat_bucket_gated_diagonal_temporal_scan_cuda` now allows either bucket + count to be zero while requiring the provided buckets to cover all recurrent + receivers and not overlap. This moves the one-bucket `T*K` forward scan from + `python_autograd_scan` to `cuda_temporal_superop`; it is not a reverse-loop + closure. +- Evidence: + GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_single_bucket_scan` and + `/tmp/cortical_triton_${USER}_redo_fixmass_single_bucket_scan`. + High-level public `model(...); loss.backward()` single-pop K>1 parity for + sLSTM and Axon passed `2 passed in 4.66s`, and the record now reports + `launch_temporal_scan_owners == ("cuda_temporal_superop",)`. Mixed-pop K>1, + terminal loss, provided-state gradient, and K=128 rows passed + `4 passed in 8.70s`. Reset-sensitive high-level horizon parity for single + and mixed-pop rows passed `4 passed in 4.90s`. +- Remaining owner: + R4 remains open until `TemporalPhysicalBackwardScanExecutor._run_backward_window` + no longer owns the reverse physical-step dependency recurrence in Python and + the active training backward owner reports a CUDA temporal reverse superop. + +### 2026-04-28 UTC - R4 combined reverse scan ABI checkpoint + +Status: IN PROGRESS + +Owner: combined CUDA temporal reverse scan over flat bucket identity. + +- Current invariant: + do not move into audit stages while the temporal backward dependency scan is + Python-owned. The remaining dependency is semantic, not cosmetic: + transition backward at physical step `t-1` needs the recurrent hidden/KV + carry produced by recurrent-message backward at physical step `t`. +- Live code state: + transition families already have CUDA reverse-window kernels, recurrent + message/initial-KV has CUDA temporal-window kernels, and boundary/query/param + reductions have CUDA window reducers. The active path still stitches those + kernels with a Python loop in `_run_backward_window`, so each physical K step + repeats transition core, recurrent-message, and carry propagation launches. +- Next implementation target: + introduce the backend ABI for a combined reverse scan that consumes flat + bucket transition tapes plus recurrent-message tables and owns the hidden + carry across the physical window. The ABI must be flat-bucket based and may + lower to gated/diagonal physical op families, but it must not branch on + population names, single-vs-mixed cardinality, benchmark rows, or hidden-size + constants. Planner/backward owner metadata must remain transitional until the + active path physically calls this combined reverse scan and parity passes. 
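+- Reference sketch (illustrative only): the dependency the combined reverse
+  superop has to own is the loop-carried recurrence below; until it runs
+  device-side over flat bucket identity, Python has to drive it one physical
+  step at a time. The callables are placeholders for the lowered ops.
+
+  ```python
+  def reverse_dependency_recurrence(transition_backward, message_backward,
+                                    physical_steps, boundary_carry):
+      # The carry produced by recurrent-message backward at step t is an input to
+      # transition backward at step t - 1, so steps cannot launch independently.
+      carry = boundary_carry
+      for t in reversed(range(physical_steps)):
+          local_adjoints = transition_backward(t, carry)
+          carry = message_backward(t, local_adjoints)
+      return carry
+  ```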
+- Guard added before touching the combined reverse scan: + single-pop terminal/H replay now has a high-level public + `model(...); loss.backward()` parity test for both sLSTM and Axon. This + prevents the one-bucket forward admission change from hiding a Python replay + fallback in terminal training. +- Evidence: + GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_one_bucket_replay` and + `/tmp/cortical_triton_${USER}_redo_fixmass_one_bucket_replay`. + `test_fabric_cuda_single_population_terminal_replay_uses_temporal_superop` + passed `2 passed in 228.04s`; ruff and `py_compile` passed for the touched + runtime test and temporal backward module. + +### 2026-04-28 UTC - R4 rejected reverse-carry probe and stricter review gate + +Status: REJECTED PROBE; R4 remains OPEN. + +Owner: shared temporal engine ABI and CUDA reverse dependency recurrence. + +- Rejected probe: + an attempted `reverse_carry_scan` binding/source started to encode + gated/diagonal transition arguments directly in the temporal engine. That is + not fixmass design, even if it targets the right bottleneck. It would make the + shared temporal engine a cell-family executor instead of a generic + flat-bucket/tensor-table scheduler. +- Cleanup performed: + removed the unlanded reverse-carry binding and deleted the unlanded + `flat_bucket_temporal_reverse_scan_kernels.cu` source. No closure credit is + assigned to that probe. +- Corrected implementation target: + the next temporal reverse owner must introduce a generic tensor-table/op-table + ABI lowered from `fabric.cuda.nn`/Fabric IR. The temporal engine may schedule + primitive op rows, tensor slots, state slots, dependencies, reset policy, + horizon/checkpoint policy, and materialization policy. It must not branch on + `sLSTM`, `Axon`, gated/diagonal cell family, population name, benchmark row, + hidden-size policy, or single-vs-mixed population route. Reusable low-level + primitive kernels may implement generic physical op families; they are not the + shared temporal scheduler. +- Manual code-review gate added for every future temporal kernel/superop: + before binding or accepting a kernel, this doc must record its ABI inputs and + explicitly confirm no cell-kind selector, no population-name selector, no + benchmark-row selector, no hidden-size policy key, no separate single/mixed + route, and no cell-family parameter bundle. Parity and throughput do not close + a backend stage if this review fails. +- Design constraint: + user pressure, audit pressure, or throughput pressure must never override the + Fabric design goals. If a requested or tempting shortcut violates the shared + temporal engine/tensor-table boundary, reject it and log the rejection. + +### 2026-04-28 UTC - R3 forward temporal ABI review and tensor-table scaffold + +Status: IN PROGRESS; R3/R4 remain OPEN. + +Owner: shared temporal forward ABI and generic primitive/tensor-table lowering. + +- Manual review finding: + forward is not fully clean yet. The active sequence-surface CUDA scan does not + select on population names or `sLSTM`/`Axon`, and one-bucket and mixed-bucket + fabrics share the same scan owner. However its Python/CUDA ABI is still a + bespoke gated/diagonal primitive bundle. That is a transitional primitive + adapter, not the final shared temporal-engine ABI. +- Correct boundary: + forward scan admission and future scan kernels must be driven by flat bucket + identity plus tensor-table/op-table rows lowered from `fabric.cuda.nn`/Fabric + IR. 
Primitive names such as `gated_logspace_recurrence` or `diag_rtu` may + appear as reusable primitive row declarations; they must not become cell + family, population-name, benchmark-row, hidden-size, or single/mixed route + policy. +- Implemented in this slice: + added `sequence_surface/temporal_tables.py` with `TemporalPrimitiveTablePlan`, + `TemporalPrimitiveRow`, and `TemporalTensorTableSlot`. The plan is built from + backend-order flat buckets and primitive backward capability metadata, then + scan admission consumes this table rather than tuple-matching cell transition + op sequences. Runtime metadata now records the temporal table review summary, + primitive rows, and primitive families when the CUDA temporal scan is active. +- Manual code-review gate for this slice: + ABI inputs are flat bucket rows, schema/tensor slot roles, primitive names, + primitive families, and primitive backward behavior. Review confirmed no + cell-kind selector, no population-name selector, no benchmark-row selector, no + hidden-size policy key, and no separate single/mixed route in the new table + selection layer. The legacy scan binding still passes primitive-specific + tensors, so R3 is explicitly not closed by this scaffold. +- Next required R3/R4 work: + replace the forward scan binding itself with a tensor-table/op-table temporal + superop, then apply the same table ABI to the reverse temporal scan. Do not + let the existing primitive-bundle adapter become final architecture. +- Evidence: + `ruff check` and `py_compile` passed for the touched temporal/table/test + files. `tests/test_fabric_backend_plan.py` passed `49 passed in 5.48s`. + GPU 0 high-level public CUDA guards with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_one_bucket_replay` and + `/tmp/cortical_triton_${USER}_redo_fixmass_one_bucket_replay` passed + `test_fabric_cuda_mixed_population_k1_output_only_uses_shared_temporal_bucket_scan` + and `test_fabric_cuda_single_population_terminal_replay_uses_temporal_superop` + as `4 passed in 22.14s`. + +### 2026-04-28 UTC - R3/R4 temporal tape policy uses primitive table facts + +Status: ACCEPTED AS TABLE-POLICY SLICE; R3/R4 remain OPEN. + +Owner: planner/runtime temporal tape policy and primitive table ABI. + +- Implemented: + transition tape memory classification now uses `TemporalPrimitiveTablePlan` + rows and primitive tape capability facts instead of a helper that walked + population names and local op-name sets. Full-tape extra state factors are + keyed by reusable primitive names (`gated_logspace_recurrence`, `diag_rtu`, + `diagonal_recurrence`) and bucket ordinals. +- Boundary review: + this policy layer consumes flat bucket row ordinals and primitive rows. It + does not select on cell kind, population name, benchmark row, hidden-size + policy, or single/mixed route. The legacy scan binding remains a separate + open R3 kernel ABI owner. +- Evidence: + `ruff check` and `py_compile` passed for the touched files; the focused + primitive-table test passed `1 passed in 4.88s`; full + `tests/test_fabric_backend_plan.py` passed `49 passed in 5.47s`. + GPU 0 high-level CUDA training guards with the same private caches passed + single-pop terminal replay, mixed reset horizon parity, and mixed K=128 + backward as `5 passed in 8.27s`. + +### 2026-04-28 UTC - R3/R4 execution-record metadata uses primitive table facts + +Status: ACCEPTED AS METADATA CLEANUP SLICE; R3/R4 remain OPEN. + +Owner: runtime execution metadata and temporal primitive-table boundary. 
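+
+- Reference sketch (illustrative only): the primitive/tensor table consumed by
+  these slices is, in spirit, a small set of frozen rows keyed by bucket ordinal
+  and reusable primitive name. Field names below are assumptions, not the real
+  dataclasses in `sequence_surface/temporal_tables.py`.
+
+  ```python
+  from dataclasses import dataclass
+  from typing import Tuple
+
+  @dataclass(frozen=True)
+  class PrimitiveRowSketch:
+      primitive: str             # reusable primitive name, e.g. "gated_logspace_recurrence"
+      bucket_ordinal: int        # flat-bucket ordinal, never a population or cell name
+      supports_backward_window: bool
+
+  @dataclass(frozen=True)
+  class TensorSlotSketch:
+      role: str                  # schema role, e.g. "bucket_state" or "reset_rows"
+      bucket_ordinal: int
+
+  @dataclass(frozen=True)
+  class PrimitiveTablePlanSketch:
+      primitive_rows: Tuple[PrimitiveRowSketch, ...]
+      tensor_slots: Tuple[TensorSlotSketch, ...]
+  ```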
+ +- Implemented: + transition-backward and state-affine execution-record metadata now build a + `TemporalPrimitiveTablePlan` from flat bucket identity and static tensor + slots. The metadata no longer walks active population names or per-population + transition op lists to decide which lowered backward owners to report. +- Boundary review: + this slice consumes primitive table rows and primitive names only as reusable + `fabric.cuda.nn` capability facts. It does not select on cell kind, + population name, benchmark row, hidden-size policy, or single/mixed route. + It is still metadata-only with respect to the open R3/R4 owners: the forward + CUDA scan ABI remains a transitional primitive-bundle adapter and the reverse + physical-step dependency recurrence is still Python-owned. +- Evidence: + `ruff check` and `py_compile` passed for the touched temporal executor. + GPU 0 high-level public CUDA guards with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_one_bucket_replay` and + `/tmp/cortical_triton_${USER}_redo_fixmass_one_bucket_replay` passed + single-pop terminal replay plus mixed K=128 backward as `3 passed in 5.18s`, + and mixed K=1 forward-scan metadata guards as `2 passed in 3.43s`. +- Next required owner: + stop spending cycles on record-label cleanup unless it directly supports the + kernel ABI. The next R3/R4 step must move the active forward scan binding and + backward reverse dependency scan toward a generic tensor-table/op-table CUDA + temporal superop, not a cell-family primitive bundle. + +### 2026-04-28 UTC - R3 forward scan extension boundary uses tensor/op tables + +Status: ACCEPTED AS FORWARD ABI SLICE; R3/R4 remain OPEN. + +Owner: active forward CUDA temporal scan binding ABI. + +- Implemented: + the active Python-to-extension call for the flat-bucket temporal scan now + enters C++ as a tensor-role table plus primitive-row table and scalar scan + descriptor. The extension adapter decodes primitive rows + (`gated_logspace_recurrence`, `diag_rtu`) and dispatches into the existing + cooperative CUDA scan kernel. Runtime execution metadata records + `flat_bucket_temporal_scan_binding_abi:flat_bucket_temporal_table_extension` + when this active path is used. +- Manual boundary review: + ABI inputs are `vector`, tensor role strings, `primitive_rows`, + `scalar_i64`, and `scalar_f64`. The table boundary has no cell-kind selector, + no population-name selector, no benchmark-row selector, no hidden-size policy + key, no separate single/mixed route, and no cell-family parameter bundle. + Primitive names remain reusable `fabric.cuda.nn` capability rows. This slice + does not close R3 because the decoded adapter still feeds the old + primitive-specialized cooperative scan kernel; the kernel body itself is not + yet a fully generic tensor/op-table temporal superop. +- Evidence: + `ruff check`, `py_compile`, and `git diff --check` passed for the touched + Python files before the CUDA compile. GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_table_scan_abi` and + `/tmp/cortical_triton_${USER}_redo_fixmass_table_scan_abi` passed mixed K=1 + forward metadata/parity guards as `2 passed in 76.96s`, single-pop terminal + replay plus mixed K=128 high-level training guards as `3 passed in 152.99s`, + and reset-sensitive mixed horizon parity as `2 passed in 5.69s`. 
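+- Reference sketch (illustrative only): at the Python-to-extension boundary the
+  table ABI is essentially "tensors plus role strings plus primitive/op rows plus
+  scalar descriptors in, tensors out". The entry-point name, role strings, and
+  row fields below are placeholders, not the compiled extension's signature.
+
+  ```python
+  def call_table_scan_sketch(extension, tensors, tensor_roles, primitive_rows,
+                             physical_steps, outnorm_eps):
+      # Placeholder boundary shape only; the real adapter validates roles and rows
+      # in C++ before dispatching into the existing cooperative CUDA scan kernel.
+      scalar_i64 = {"physical_steps": int(physical_steps)}
+      scalar_f64 = {"outnorm_eps": float(outnorm_eps)}
+      return extension.flat_bucket_temporal_table_scan(
+          list(tensors),         # tensor vector
+          list(tensor_roles),    # role strings aligned with the tensor vector
+          list(primitive_rows),  # primitive rows keyed by bucket ordinal
+          scalar_i64,
+          scalar_f64,
+      )
+  ```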
+- Next required owner: + replace the extension adapter and kernel internals with a CUDA temporal + superop that consumes the tensor/op table directly, then apply the same table + ABI to the reverse dependency scan. Do not claim R3 closure while the kernel + is still the primitive-specialized scan implementation. + +### 2026-04-28 UTC - R4 recurrent-message backward extension boundary uses tables + +Status: ACCEPTED AS BACKWARD ABI SLICE; R4 remains OPEN. + +Owner: active recurrent-message backward window binding ABI. + +- Implemented: + the active recurrent-message + backend-order K/V projection backward window + with state materialization now enters the CUDA extension through a tensor-role + table, op-row table, and scalar descriptor. The table adapter dispatches into + the existing CUDA recurrent-message backward window kernel. Backward record + updates now preserve + `flat_bucket_temporal_backward_binding_abi:flat_bucket_temporal_recurrent_message_table_extension` + once the table path is used. +- Manual boundary review: + ABI inputs are `vector`, tensor role strings, `op_rows`, + `scalar_i64`, and `scalar_f64`. The binding has no cell-kind selector, no + population-name selector, no benchmark-row selector, no hidden-size policy + key, no separate single/mixed route, and no cell-family parameter bundle. + This is a recurrent-message primitive/table boundary, not a cell route. R4 is + still open because `_run_backward_window` continues to orchestrate transition + backward, recurrent-message backward, carry propagation, boundary/query + reductions, and parameter binding one physical step at a time in Python. +- Evidence: + `ruff check`, `py_compile`, and `git diff --check` passed for the touched + Python files. GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_backward_table_abi` and + `/tmp/cortical_triton_${USER}_redo_fixmass_backward_table_abi` compiled the + updated backward extension and passed mixed K=128 high-level + `model(...); loss.backward()` parity as `1 passed in 5.20s` after the first + metadata assertion exposed and fixed late backward-alias recording. Terminal + single-pop replay and reset-sensitive mixed horizon parity passed + `4 passed in 8.03s`. +- Next required owner: + build the actual CUDA temporal reverse superop over flat bucket tensor/op + tables. Table-boundary adapters are useful checkpoints, but they do not close + the reverse dependency owner while the Python physical-step loop remains. + +### 2026-04-28 UTC - R4 transition-core backward extension boundary uses tables + +Status: ACCEPTED AS TRANSITION BACKWARD ABI SLICE; R4 remains OPEN. + +Owner: active transition-core backward window binding ABI. + +- Implemented: + gated logspace core, gated recurrent-affine core, and diagonal recurrence + core backward windows now enter the backward CUDA extension through + tensor-role tables, op-row tables, and scalar descriptors. The C++ table + adapters validate primitive/op rows and decode tensor roles before dispatching + into the existing CUDA window kernels. The active high-level transition + backward path records + `flat_bucket_temporal_transition_backward_binding_abi:flat_bucket_temporal_transition_table_extension` + when the table ABI is consumed. +- Manual boundary review: + ABI inputs are `vector`, tensor role strings, `op_rows`, + `scalar_i64`, and `scalar_f64`. 
The table boundary has no cell-kind selector, + no population-name selector, no benchmark-row selector, no hidden-size policy + key, no separate single/mixed route, and no cell-family parameter bundle. + Primitive names are used only as reusable `fabric.cuda.nn` capability rows. + This slice intentionally does not close R4 because `_run_backward_window` + still owns the reverse physical-step dependency loop in Python, and the table + adapters still dispatch into existing primitive-specific CUDA kernels. +- Evidence: + `ruff check`, `py_compile`, and `git diff --check` passed for the touched + Python files before CUDA compile. GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_transition_table_abi` and + `/tmp/cortical_triton_${USER}_redo_fixmass_transition_table_abi` compiled the + updated backward extension and passed mixed K=128 high-level + `model(...); loss.backward()` parity as `1 passed in 230.06s`. The same cache + passed direct gated-core window parity, single-pop terminal replay, and + reset-sensitive mixed horizon parity as `5 passed in 8.53s`. +- Next required owner: + collapse the remaining active reverse dependency scheduler into a CUDA + temporal reverse superop over flat bucket tensor/op tables. Do not close R4 + until transition backward, recurrent-message backward, carry propagation, + boundary/query reductions, and parameter-gradient windows are scheduled from + that shared table-owned reverse superop without the Python physical-step + dependency loop. + +### 2026-04-28 UTC - R4 reverse-scan fake-closure guard + +Status: ACCEPTED AS GUARDRAIL; guardrail owner only, R4 remains OPEN. + +Owner: active reverse-scan ownership evidence and boundary guard. + +- Implementing: + the current physical backward window now records + `flat_bucket_temporal_reverse_scan_owner:python_host_reverse_loop` while the + Python reverse physical-step dependency loop is active. A runtime validator + fails closed if any backward record claims `cuda_temporal_superop` or + `cuda_temporal_reverse_superop` ownership without the required + `flat_bucket_temporal_reverse_scan_binding_abi:flat_bucket_temporal_reverse_table_extension` + evidence, or while the Python host-loop owner is still present. +- Manual boundary review: + ABI/metadata inputs are only the existing `BackendExecutionRecord`, backward + owner metadata, backward physical executor names, and flat-bucket temporal + reverse-scan owner/binding tags. This does not add a temporal kernel. It adds + no cell-kind selector, no population-name selector, no benchmark-row + selector, no hidden-size policy key, no separate single/mixed route, and no + cell-family parameter bundle. A new source-level boundary test also audits + the temporal table/scan CUDA source files for those forbidden route + selectors. +- Closure rule: + this guard does not close R4. R4 remains open until the active reverse scan + itself moves into a table-owned CUDA temporal superop over flat-bucket + tensor/op tables and the high-level parity plus throughput gates pass. +- Evidence: + `ruff check`, `py_compile`, `git diff --check`, and the source boundary test + `tests/test_fabric_backend_boundaries.py` passed. GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_owner_guard` and + `/tmp/cortical_triton_${USER}_redo_fixmass_reverse_owner_guard` passed mixed + K=128 high-level `model(...); loss.backward()` parity as + `1 passed in 229.62s`. 
The same cache passed terminal/provided-state K>1 and + reset-sensitive mixed horizon parity as `4 passed in 6.78s`. + +### 2026-04-28 UTC - R4 real reverse temporal engine implementation + +Status: ACCEPTED AS FIRST REAL ENGINE ABI; active mixed-bucket integration +remains OPEN. + +Owner: combined table-owned CUDA temporal reverse engine over flat bucket +identity. + +- User correction: + do not keep accepting small guard/slice work for this owner. The missing work + is the engine: the recurrent transition/message reverse dependency must move + out of `_run_backward_window` and into a backend-owned temporal reverse + executor. Guardrails are useful only insofar as they prevent fake closure; + they are not progress on the remaining throughput owner by themselves. +- Engine invariant: + the shared reverse engine consumes flat-bucket tensor tables, op rows, + dependency/carry rows, reset schedule, checkpoint/window schedule, and + materialization policy. It may dispatch reusable `fabric.cuda.nn` primitive + physical rows, but it must not branch on cell kind, population name, + benchmark row, hidden-size policy, or single-vs-mixed route. Single-bucket + and multi-bucket fabrics are the same engine with different table row counts. +- Implementation target now: + add the combined temporal reverse table ABI and active executor path that + owns the transition -> recurrent-message -> carry recurrence for a whole + backward window. The old Python loop remains only as an unsupported/fallback + compatibility path until the CUDA reverse engine passes high-level K/T/H, + reset, provided-state, terminal/per-timestep parity, and throughput audits. +- Implemented in this checkpoint: + added `gated_message_reverse_table_window` to the Fabric temporal backward + CUDA extension. The ABI consumes tensor roles, op rows, and scalar + descriptors, then owns the reverse recurrence inside the backend extension: + gated transition backward produces `grad_recurrent_msg[t]`, recurrent-message + backward produces the public hidden carry, and that carry is consumed by the + previous transition step inside the same table-owned reverse-window call. + This is the first real combined transition/message reverse engine ABI; it is + not just metadata or a guard. +- Manual boundary review: + ABI inputs are `vector`, tensor role strings, `op_rows`, + `scalar_i64`, and `scalar_f64`. The engine schedules reusable primitive rows + (`gated_logspace_recurrence` and regular-local recurrent message) plus tensor + slots/dependency carries. It adds no cell-kind selector, no population-name + selector, no benchmark-row selector, no hidden-size policy key, no separate + single/mixed route, and no cell-family parameter bundle. The active mixed + flat-bucket path is not switched yet because the multi-bucket diagonal+gated + scheduler and executor integration still need to land without breaking those + boundaries. +- Evidence: + `ruff check`, `py_compile`, `git diff --check`, and the source boundary test + passed. GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_engine` and + `/tmp/cortical_triton_${USER}_redo_fixmass_reverse_engine` passed the focused + coupled reverse-engine parity test + `test_fabric_cuda_gated_message_reverse_table_window_matches_step_loop` as + `1 passed in 43.79s`. The same cache passed mixed K=128 high-level + `model(...); loss.backward()` parity as `1 passed in 191.06s`, and + terminal/provided-state K>1 plus reset-sensitive mixed horizon parity as + `4 passed in 6.79s`. 
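- Reference shape of the coupled reverse recurrence: a minimal host-side
  sketch, under toy semantics (diagonal decay transition, linear recurrent
  message, identity public boundary, terminal sum loss), of the dependency the
  engine owns per reverse step: transition backward emits the
  recurrent-message adjoint, message backward returns the public hidden carry,
  and that carry feeds the previous transition step. The names and shapes are
  illustrative, not the Fabric kernels; the autograd comparison only sanity
  checks the sketch itself.

```python
import torch

torch.manual_seed(0)
T, H = 5, 8
decay = torch.rand(H, requires_grad=True)               # toy diagonal transition
W_msg = (0.3 * torch.randn(H, H)).requires_grad_()      # toy recurrent message projection
x = torch.randn(T, H)
h0 = torch.zeros(H)

def forward(decay, W_msg):
    hs = [h0]
    for t in range(T):
        msg = W_msg @ hs[-1]                 # message from previous public hidden
        hs.append(decay * hs[-1] + x[t] + msg)
    return hs

loss = forward(decay, W_msg)[-1].sum()
loss.backward()                              # autograd reference gradients

# Hand-rolled reverse window over the same coupled dependency chain.
hs = forward(decay.detach(), W_msg.detach())
grad_h = torch.ones(H)                       # dL/dh_T for the terminal sum loss
g_decay, g_Wmsg = torch.zeros(H), torch.zeros(H, H)
for t in reversed(range(T)):
    h_prev = hs[t]
    g_decay += grad_h * h_prev               # transition parameter gradient
    grad_msg = grad_h                        # transition backward -> message adjoint
    g_Wmsg += torch.outer(grad_msg, h_prev)  # message parameter gradient
    carry = W_msg.detach().T @ grad_msg      # message backward -> public hidden carry
    grad_h = decay.detach() * grad_h + carry # carry consumed by the previous step

assert torch.allclose(g_decay, decay.grad, rtol=1e-4, atol=1e-4)
assert torch.allclose(g_Wmsg, W_msg.grad, rtol=1e-4, atol=1e-4)
```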
+- Remaining owner: + wire the combined reverse engine into `_run_backward_window` for the active + flat-bucket executor, extend the table scheduler to cover the diagonal row + and multi-bucket mixed fabrics in one reverse scan, then change the reverse + owner from `python_host_reverse_loop` only after high-level parity and + throughput pass. + +### 2026-04-28 UTC - R4 mixed transition/message reverse table engine + +Status: ACCEPTED AS MIXED REVERSE ENGINE ABI; active executor wiring remains +OPEN. + +Owner: active shared temporal backward reverse dependency scan. + +- Current step: + extend the first gated-only reverse dependency ABI into a table-owned + transition/message reverse engine that schedules both supported primitive + transition rows over one flat bucket identity. The scheduler must consume + tensor roles, op rows, bucket offsets/counts, reset schedule, and recurrent + message tables; it must not select on cell kind, population name, benchmark + row, hidden-size policy, or separate single/mixed route. +- Manual boundary review before code: + proposed ABI inputs are `vector`, tensor role strings, `op_rows`, + `scalar_i64`, and `scalar_f64`. Primitive rows are reusable physical rows + lowered from Fabric IR (`gated_logspace_recurrence`, `diag_rtu`, and + regular-local recurrent message). Bucket start/count values are flat-bucket + identity, not user population identity. This patch is not allowed to claim + CUDA temporal reverse ownership until the active high-level executor consumes + it and parity/performance gates pass. +- Implemented: + added `transition_message_reverse_table_window` to the temporal backward + CUDA extension and Python wrapper. It schedules gated-recurrent-affine, + diagonal-RTU, and regular-local recurrent-message primitive rows inside one + reverse window over full flat-bucket recurrent identity. The engine composes + per-bucket transition adjoints into a full backend-order recurrent-message + adjoint, runs recurrent-message backward, and carries the produced public + hidden adjoint into the previous transition step inside the backend + extension. This removes the missing diagonal/mixed coverage from the + combined reverse engine ABI, but the live `_run_backward_window` path still + has to be wired to consume it. +- Evidence: + `uv run ruff check`, `uv run ruff format --check`, `py_compile`, + `git diff --check`, and `tests/test_fabric_backend_boundaries.py` passed. + GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_mixed_reverse_engine` and + `/tmp/cortical_triton_${USER}_redo_fixmass_mixed_reverse_engine` passed the + focused mixed reverse-engine parity test + `test_fabric_cuda_transition_message_reverse_table_window_matches_mixed_step_loop` + as `1 passed in 43.27s`. The same cache passed adjacent gated reverse-engine + parity plus terminal-loss, provided-state gradient, and K=128 high-level + `model(...); loss.backward()` guards as `4 passed in 193.74s`. +- Remaining owner: + wire the mixed transition/message reverse engine into + `TemporalPhysicalBackwardScanExecutor._run_backward_window` without changing + reverse ownership metadata until the active path actually consumes this ABI. + Output/boundary/query parameter binding and transition projection parameter + reductions must stay table/window-owned or deferred outside the reverse + dependency loop, and reset parity remains a required gate. 
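- Wiring sketch for the remaining owner above: try the table engine for a
  window, fall back to the legacy per-step loop when it declines, and record
  the glue tag only when the engine actually ran. Function names and the
  callback signature are hypothetical; reverse-scan ownership metadata stays a
  separate, honesty-gated record and is not modeled here.

```python
from typing import Callable, Optional

GLUE_TAG = "temporal_backward_glue:cuda_transition_message_reverse_table_window"

def run_backward_window(
    window: object,
    try_table_engine: Callable[[object], Optional[dict]],
    legacy_step_loop: Callable[[object], dict],
    record_metadata: Callable[[str], None],
) -> dict:
    result = try_table_engine(window)        # returns None when the window is unsupported
    if result is not None:
        record_metadata(GLUE_TAG)            # evidence only when the engine actually ran
        return result
    return legacy_step_loop(window)          # legacy per-step reverse loop, no glue tag

tags: list[str] = []
out = run_backward_window(
    window=object(),
    try_table_engine=lambda w: None,         # simulate an unsupported window
    legacy_step_loop=lambda w: {"param_grads": {}},
    record_metadata=tags.append,
)
assert GLUE_TAG not in tags and out == {"param_grads": {}}
```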
+ +### 2026-04-28 UTC - R4 active mixed reverse table consumption + +Status: ACCEPTED AS ACTIVE-PATH TABLE CONSUMPTION SLICE; R4 REVERSE SUPEROP +OPEN. + +Owner: active shared temporal backward reverse dependency scan. + +- Implemented: + `TemporalPhysicalBackwardScanExecutor._run_backward_window` now attempts the + mixed `transition_message_reverse_table_window` engine for supported + no-final-state output-cell windows before falling back to the legacy + per-step reverse loop. The active hook builds flat-bucket tensor windows, + carries direct output adjoints into the table engine, binds boundary/query, + recurrent K/V, and transition parameter gradients, and records + `temporal_backward_glue:cuda_transition_message_reverse_table_window` only + when the engine actually runs. +- Reset correction: + the mixed reverse ABI now carries separate `reset.transition_step_window` and + `reset.message_step_window` roles. This keeps Fabric reset semantics explicit: + transition reset and message/source reset are different primitive inputs even + when a test row happens to use the same mask. +- Primitive-dimension correction: + the message primitive `head_dim` and the gated recurrent-affine primitive + head width are separate table scalars. The latter is inferred from + `gated_recurrent_kernel` shape and named + `gated_recurrent_affine_head_dim`; it is primitive table metadata, not a + cell-family route or selector. The active mixed row exposed this because the + recurrent message primitive used head width 4 while the gated primitive used + one head of width 8. +- Manual boundary review: + no cell-kind selector, population-name selector, benchmark selector, hidden + size policy branch, or single/mixed route branch was added. The engine still + schedules reusable primitive rows (`gated_logspace_recurrence`, `diag_rtu`, + and regular-local recurrent message) over flat bucket identity. The reverse + owner remains `python_host_reverse_loop` because the host-side table engine + still contains the reverse step loop through primitive calls; the patch + records only the reverse table binding ABI, not CUDA temporal-superop + ownership. +- Parity issue found and fixed: + the first active consumption pass matched input gradients but dropped output + projection parameter grads because the active path started a fresh parameter + accumulator after `output_backward_sequence` had already produced those + grads. The active engine path now accumulates + `output_backward_sequence.param_grads` before binding boundary/query, + recurrent K/V, and transition parameter grads. +- Evidence: + `uv run ruff check`, `uv run ruff format --check`, `py_compile`, and + `git diff --check` passed on touched files. GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_mixed_reverse_engine_primitive_head` + and + `/tmp/cortical_triton_${USER}_redo_fixmass_mixed_reverse_engine_primitive_head` + passed the low-level mixed reverse table parity row and the high-level K>1 + terminal-loss mixed-pop row using normal `model(...); loss.backward()`. The + terminal row now requires + `temporal_backward_glue:cuda_transition_message_reverse_table_window`, the + reverse table binding ABI, host-loop reverse owner, and no CUDA reverse-owner + overclaim. Adjacent provided-state/final-state and K=128 high-level guards + passed as `3 passed in 8.32s`. +- Remaining owner: + R4 stays open. 
The table engine still needs to support materialized final + state/provided-state rows, sequence-output/K=128 rows, and then move the + reverse step loop itself into the fully owned CUDA temporal superop before + throughput audits can start. Do not claim CUDA reverse ownership until the + Python host reverse loop is gone and parity plus throughput gates pass. + +### 2026-04-28 UTC - R4 materialized-state/K128 reverse table consumption + +Status: ACCEPTED AS ACTIVE-PATH MATERIALIZED-STATE CONSUMPTION SLICE; R4 +REVERSE SUPEROP OPEN. + +Owner: active shared temporal backward reverse dependency scan. + +- Implemented: + the active mixed reverse table path now supports materialized final-state + rows. For engine-only execution, final output-cell carry gradients are merged + into the last output-backward step, while recurrent carry gradients seed the + reverse table dependency scan. This keeps the legacy fallback behavior + unchanged if the table path rejects. +- Coverage widened: + high-level provided-state/final-state K>1 terminal loss and high-level K=128 + mixed-pop sequence output now both consume + `temporal_backward_glue:cuda_transition_message_reverse_table_window` through + normal `model(...); loss.backward()`. The tests assert the reverse table ABI + binding and the honest host-loop reverse owner, and still forbid + `cuda_temporal_superop` reverse ownership until the actual reverse loop moves + off the host. +- Parity evidence: + the provided-state row was manually checked for input-gradient parity, + provided-state gradient parity, parameter-gradient key equality, and + parameter-gradient value parity. The focused GPU 0 test group with private + caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_mixed_reverse_engine_final_state` + and + `/tmp/cortical_triton_${USER}_redo_fixmass_mixed_reverse_engine_final_state` + passed: + `test_fabric_cuda_mixed_population_k_gt1_terminal_loss_maps_final_outer_emission_gradient`, + `test_fabric_cuda_mixed_population_k_gt1_terminal_loss_propagates_provided_state_gradients`, + and `test_fabric_cuda_mixed_population_k128_backward_maps_outer_emission_gradients` + as `3 passed in 5.90s`. +- Remaining owner: + R4 remains open because the C++ table engine still owns the reverse + dependency loop as a host-side loop over reusable primitive calls. Next + backend work must move that loop into the fully owned CUDA temporal superop, + then run T=1/mixed/T*K/H/K throughput audits against April 21 references. + +### 2026-04-28 UTC - R4 mixed reverse dependency loop moved onto CUDA + +Status: ACCEPTED AS ACTIVE CUDA REVERSE-LOOP OWNERSHIP; R4 THROUGHPUT OWNER +REMAINS OPEN. + +Owner: active shared temporal backward reverse dependency scan. + +- User correction handled: + the K=128 rows are not the closure basis by themselves. They are retained as + regression guards because K must extend the same temporal engine to more + physical steps, but T=1 remains the base streaming case. This checkpoint + therefore added an explicit mixed-population T=1, K=1 high-level training row + before treating K>1/K128 as follow-on guards. +- Implemented: + `flat_bucket_temporal_transition_message_reverse_table_window_cuda` no longer + owns the mixed transition/message reverse dependency as a C++ host loop over + timesteps. It launches one CUDA device-loop kernel for the active mixed + reverse table call. 
The high-level active path now records + `temporal_backward_glue:cuda_transition_message_reverse_table_device_loop` + and `flat_bucket_temporal_reverse_scan_owner:cuda_temporal_superop` only when + that table ABI runs. +- T=1 base-case guard: + added + `test_fabric_cuda_mixed_population_t1_k1_training_uses_shared_reverse_device_loop`. + It uses the public high-level model call, normal external loss construction, + and `loss.backward()` for reset and no-reset T=1 mixed-pop rows. The test + asserts flat bucket identity, no single-bucket executor, the reverse table + binding ABI, CUDA reverse-loop owner, and no host-loop owner alias. +- Manual boundary review: + the active ABI remains the existing reverse table extension: tensor role + slots, op-row/scalar descriptors, flat bucket start/count metadata, reset + windows, recurrent message tables, step index window, and parameter/state + tensor bindings. This patch adds no cell-kind selector, no population-name + selector, no benchmark-row selector, no hidden-size policy key, no + single-vs-mixed route branch, and no cell-family parameter bundle. Primitive + names such as gated recurrent-affine and diagonal RTU describe reusable + physical rows lowered from Fabric IR/table facts; they are not route keys. +- Important limitation: + the device-loop kernel is correctness-first and serial per batch. It removes + the active host timestep loop for the mixed reverse dependency, but it is not + the final throughput implementation. R4 remains open for parallelization, + full tensor/op-table generalization where the current wrapper is still too + bespoke, T=1 audit throughput, mixed T=1 throughput, and T*K/H/K audit + closure against April 21 JSON references. +- Evidence: + `python -m py_compile + src/cortical/fabric/backend/cuda/sequence_surface/temporal_backward.py`, + `uv run ruff check`, `uv run ruff format --check`, `git diff --check`, and + `tests/test_fabric_backend_boundaries.py` passed. GPU 0 with private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_t1_owner` passed: + new mixed T=1/K=1 reset/no-reset high-level parity as + `2 passed in 230.41s`, K>1 terminal/provided-state plus K=128 high-level + parity as `3 passed in 6.10s`, and low-level mixed reverse table parity as + `1 passed in 3.41s`. +- Next owner: + run a manual code review focused on the new device-loop ABI and any remaining + bespoke primitive argument bundles, then either split the primitive math into + clearer `fabric.cuda.nn` primitive kernels under the shared temporal table + scheduler or proceed to the T=1 audit harness only if the boundary review + confirms no design violation. Audit stages must start with T=1 references + from the April 21 JSON before T*K/H/K scaling rows. + +### 2026-04-28 UTC - R4 reverse device-loop ABI narrowed to tensor table + +Status: ACCEPTED AS BOUNDARY CLEANUP FOR ACTIVE CUDA REVERSE LOOP; R4 +THROUGHPUT OWNER REMAINS OPEN. + +Owner: active shared temporal backward reverse dependency scan. + +- Manual review result: + the previous device-loop launch removed the C++ timestep loop, but the CUDA + kernel signature still accepted a long bespoke list of gated/diagonal/message + tensor pointers. That was too close to a primitive-specific argument bundle + for the shared temporal-engine boundary, even though selection still came + from table roles and op rows. +- Implemented: + the active mixed reverse device-loop kernel now takes a single device tensor + pointer table (`tensor_table_ptrs`) plus scalar/table metadata. 
The host + wrapper still resolves role strings and validates shapes, then binds input, + output, reset, state, graph, and parameter tensors into table slots before + launch. The kernel loads typed views from table slots internally instead of + receiving each primitive tensor as a formal launch argument. +- Guardrail: + `tests/test_fabric_backend_boundaries.py` now asserts that + `transition_message_reverse_table_device_loop_kernel` keeps the launch + signature table-owned and does not regress to formal `float*`, `bool*`, or + `int32_t*` primitive tensor arguments. +- Manual boundary review: + ABI inputs are now one tensor-pointer table, op-row/scalar metadata, flat + bucket offsets/counts, reset flags, primitive dimensions inferred from table + tensors, and schedule scalars. No cell-kind selector, population-name + selector, benchmark-row selector, hidden-size policy key, single/mixed route + branch, or cell-family parameter bundle was added. Primitive slot names still + exist as transitional table slots inside this CUDA source; R4 remains open to + move the primitive row implementations behind a cleaner `fabric.cuda.nn` + primitive dispatch layer and to parallelize the correctness-first loop. +- Evidence: + `uv run pytest -q tests/test_fabric_backend_boundaries.py -n0` passed as + `2 passed in 0.02s`; `uv run ruff check`, `uv run ruff format --check`, + `python -m py_compile + src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_temporal_backward_cuda.py`, + and `git diff --check` passed. GPU 0 with private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_tableptr` passed the + low-level mixed reverse table parity row as `1 passed in 43.70s` and the + active high-level T=1 reset/no-reset plus K>1/provided-state/K128 group as + `5 passed in 192.02s`. +- Next owner: + continue R4 by either extracting the primitive row math behind reusable + `fabric.cuda.nn`/physical primitive helpers or beginning T=1 audit harness + work only after the active reverse path remains table-owned under review. + +### 2026-04-28 UTC - R4 reverse message receiver-row parallelization + +Status: ACCEPTED AS PERFORMANCE-ORIENTED CUDA SLICE; R4 THROUGHPUT OWNER +REMAINS OPEN. + +Owner: active shared temporal backward reverse dependency scan. + +- Implemented: + the mixed transition/message reverse device-loop kernel now runs the + recurrent-message backward stage across receiver rows inside the CUDA block + instead of leaving the entire table scan on one thread. Transition primitive + math is still serialized per batch for this slice, but seed initialization, + public carry clearing, receiver-row message reverse, and per-step carry + materialization are parallelized over block threads. Cross-receiver + collisions into recurrent sender public carry and input K/V gradients now use + atomic accumulation. +- Boundary review: + this change kept the table-owned ABI introduced in the prior checkpoint: + one tensor-pointer table plus scalar/op metadata, flat bucket start/counts, + reset flags, primitive dimensions inferred from tensor shapes, sender tables, + and schedule scalars. No cell-kind selector, population-name selector, + benchmark-row selector, hidden-size policy key, separate single/mixed route, + or cell-family parameter bundle was added. Primitive slot names remain + transitional table slots; R4 still needs cleaner reusable primitive row + dispatch and throughput audit closure. 
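- Collision-accumulation sketch: the receiver-row parallelization above relies
  on atomic accumulation when several receiver rows hit the same sender carry.
  A minimal host-side analogue, with made-up shapes, uses `index_add_` to get
  the same reduction semantics and checks it against a serial loop.

```python
import torch

torch.manual_seed(0)
R, S, H = 6, 3, 4                                 # receiver rows, sender rows, hidden width
sender_idx = torch.randint(0, S, (R,))            # receiver r reads sender sender_idx[r]
per_receiver_carry = torch.randn(R, H)            # adjoint each receiver sends back

# Colliding receivers must accumulate into the same sender row (atomicAdd on
# device); index_add_ gives the same reduction semantics on the host.
carry = torch.zeros(S, H)
carry.index_add_(0, sender_idx, per_receiver_carry)

reference = torch.zeros(S, H)
for r in range(R):                                # serial reference accumulation
    reference[sender_idx[r]] += per_receiver_carry[r]

assert torch.allclose(carry, reference)
```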
+- Evidence: + current verification after the doc/test updates passed `git diff --check`, + `uv run ruff check tests/test_fabric_runtime.py`, + `uv run pytest -q tests/test_fabric_backend_boundaries.py -n0` as + `2 passed in 0.01s`, GPU 0 low-level reverse table parity with private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_msg_parallel_verify` + as `1 passed in 44.09s`, and GPU 0 high-level public API parity with private + cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_msg_parallel_high_verify` + as `5 passed in 229.91s`. +- Earlier evidence before the semantic guard: + GPU 0 with private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_msg_parallel` passed + the low-level mixed reverse table parity row as `1 passed in 44.22s` and the + active high-level T=1/K=1 plus K>1/provided-state/K128 group as + `5 passed in 187.87s`. `git diff --check` also passed. +- Remaining owner: + R4 remains open because the transition half of the device-loop kernel is + still serial per batch, the primitive row dispatch is still too directly + encoded inside the temporal source, and no April-21-referenced T=1/mixed/T*K + throughput closure has been run after this performance slice. + +### 2026-04-28 UTC - R4 sender public-normalization backward audit + +Status: ACCEPTED AS SEMANTIC GUARD; R4 THROUGHPUT OWNER REMAINS OPEN. + +Owner: active shared temporal backward reverse dependency scan. + +- User bug report checked: + the suspected failure mode was a recurrent K/V backward mismatch where + forward projected normalized/public sender state, but backward used raw + unnormalized sender state for K/V weight gradients and skipped the + public-normalization adjoint before raw recurrent-state gradients. +- Current-code inspection: + the CUDA temporal scan writes gated rows into `public_hidden` after + per-receiver outnorm and then projects that `public_hidden` into recurrent + K/V. Checkpoints store that same public hidden. The active reverse table + consumes `recurrent_hidden_before_window` for recurrent K/V weight gradients; + that window is materialized from initial public cells and checkpoint public + hidden, not raw gated state. The K/V input adjoint accumulates into + `public_carry_work`, and the previous reverse transition step consumes that + public carry through the outnorm/public-boundary backward before writing raw + transition gradients. +- Regression guard added: + `test_fabric_cuda_recurrent_kv_backward_uses_public_outnorm_sender_state` + constructs a non-uniform public-normalization row, projects the normalized + public state into K/V, and compares CUDA projection backward plus outnorm + backward against autograd. It also asserts that the wrong raw-state K/V + weight gradient is materially different, so a future raw-state projection + shortcut will fail the test. +- Skill update: + `skills/cb.fabric-backend-boundaries/SKILL.md` now records the generic rule: + projection backward must differentiate exactly the tensor projected in + forward; public/normalized projection inputs must produce public/normalized + weight gradients and route sender-state adjoints back through the owning + public-boundary backward before raw recurrent state. +- Evidence: + GPU 0 with private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_sender_ln_guard` passed the new + sender public-normalization guard as `1 passed in 42.40s`. 
The same final + verification pass also reran the boundary tests, low-level mixed reverse + table row, and high-level public API group listed in the receiver-row + parallelization evidence above. A separate high-level T=2 CUDA-vs-PyTorch + probe with non-uniform outnorm and ordinary `model(...); loss.backward()` + passed full output, input-gradient, and parameter-gradient parity before the + guard was added. +- Remaining owner: + this audit did not close R4. It blocks one semantic regression, but the next + code owner is still the shared temporal backward engine: keep moving hot + reverse transition/message work into table-owned CUDA primitive rows, then + run T=1, mixed T=1, and T*K/H/K throughput audits against the April 21 JSON + references. + +### 2026-04-28 UTC - R4 reverse transition row parallelization + +Status: ACCEPTED AS CUDA TRANSITION-SIDE PARALLELIZATION SLICE; R4 THROUGHPUT +OWNER REMAINS OPEN. + +Owner: active shared temporal backward reverse dependency scan. + +- Implemented: + the same table-owned reverse device-loop now distributes transition rows + across block threads instead of making `threadIdx.x == 0` own every gated and + diagonal transition row for each reverse timestep. Each thread owns complete + flat-bucket transition rows, writes that row's raw/input/state gradients, and + emits that row's recurrent-message adjoint before the existing block barrier + hands control to the receiver-row message reverse stage. +- Boundary review: + this is still the tensor-table reverse ABI over flat bucket identity, reset + windows, sender tables, primitive dimensions inferred from tensor shapes, and + schedule scalars. It adds no cell-kind selector, population-name selector, + benchmark-row selector, hidden-size policy key, separate single/mixed route, + or cell-family parameter bundle. The implementation is still a transitional + primitive-row body inside the temporal source, not final fully generic + `fabric.cuda.nn` row dispatch. +- Evidence: + GPU 0 with private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_transition_row_parallel` + passed the low-level mixed reverse table parity row as `1 passed in 44.02s`. + GPU 0 with private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_transition_high_verify` + passed the high-level public API T=1/K=1 reset/no-reset plus + K>1/provided-state/K128 group as `5 passed in 228.22s`. After formatting the + touched CUDA block, GPU 0 with private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_transition_postformat` + reran the low-level mixed reverse table parity row as `1 passed in 45.61s`; + `git diff --check` also passed. +- Remaining owner: + R4 remains open. This reduces the obvious single-thread transition ownership + inside the active reverse device loop, but it is not a throughput closure and + it does not yet remove the primitive bodies from the temporal kernel. Next + backend work should either extract cleaner reusable primitive-row dispatch + under the table scheduler or run a focused current-code profile to pick the + next hot physical owner before April-21-referenced T=1/T*K audits. + +### 2026-04-28 UTC - R4 reverse table owner timing made visible + +Status: ACCEPTED AS OWNER-ATTRIBUTION GUARDRAIL; R4 THROUGHPUT OWNER REMAINS +OPEN. + +Owner: active shared temporal backward reverse dependency scan. 
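- Attribution shape: the bullets below wrap the active reverse-table call in
  `runtime._backend_owner_timing(...)` so the owner shows up in
  `backward_owner_timing_ms`. As a rough host-side analogue only, a named
  timing scope accumulated per owner looks like the sketch here; the helper
  names and wall-clock timing are illustrative, not the runtime implementation.

```python
import time
from collections import defaultdict
from contextlib import contextmanager

owner_ms: dict[str, float] = defaultdict(float)
owner_count: dict[str, int] = defaultdict(int)

@contextmanager
def owner_timing(owner: str):
    start = time.perf_counter()
    try:
        yield
    finally:
        owner_ms[owner] += (time.perf_counter() - start) * 1e3
        owner_count[owner] += 1

with owner_timing("transition_message_reverse_table_device_loop"):
    sum(range(10_000))        # stand-in for the timed reverse-table call

assert owner_count["transition_message_reverse_table_device_loop"] == 1
```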
+ +- Problem found: + the first current-code high-level audit profile after transition-row + parallelization did not include the reverse table device-loop in + `backward_owner_timing_ms`, even though launch metadata showed + `temporal_backward_glue:cuda_transition_message_reverse_table_device_loop`. + That made `temporal_artifact_recompute` look like the largest recorded owner + while the actual reverse CUDA superop was unaccounted. +- Implemented: + `_run_backward_window` now wraps the active + `try_transition_message_reverse_table_window_cuda(...)` call in + `runtime._backend_owner_timing("transition_message_reverse_table_device_loop")`. + This is timing attribution only; it does not alter planner routing, public + API behavior, tensor-table ABI, or primitive semantics. +- Evidence: + `uv run ruff check + src/cortical/fabric/backend/cuda/sequence_surface/temporal_backward.py`, + `python -m py_compile + src/cortical/fabric/backend/cuda/sequence_surface/temporal_backward.py`, and + `git diff --check` passed. GPU 0 high-level audit profile with + `CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1`, private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_owner_profile_post_transition` + and + `/tmp/cortical_triton_${USER}_redo_fixmass_owner_profile_post_transition`, + and normal `model(...); loss.backward()` wrote + `/tmp/redo_fixmaass_owner_profile_reverse_table_timing` with + `transition_message_reverse_table_device_loop:ms=85.864;count=2` as the top + CUDA-time owner, followed by `temporal_artifact_recompute:ms=22.546;count=2`. + The row was mixed-pop, T=1, K=128, H=64, terminal loss, h=8, B=1, and reported + `tokens_per_s=7.178`. +- Remaining owner: + do not treat this as closure. It proves the active reverse table device-loop + is now the next measured R4 hot owner. `temporal_plan_forward_owners` and + `temporal_plan_backward_owners` still report `python_autograd_scan`, so final + owner metadata and audit gates remain open. Next implementation should + optimize the table-owned reverse device-loop itself, not chase + `temporal_artifact_recompute` based on incomplete timing. + +### 2026-04-28 UTC - R4 cooperative reverse table device-loop + +Status: ACCEPTED AS MULTI-BLOCK CUDA REVERSE-SCAN SLICE; R4 THROUGHPUT OWNER +REMAINS OPEN. + +Owner: active shared temporal backward reverse dependency scan. + +- Implemented: + `transition_message_reverse_table_device_loop_kernel` now launches as a CUDA + cooperative kernel and uses grid-wide synchronization between reverse + transition, public-carry clearing, recurrent-message backward, and + message-to-transition carry materialization stages. Work is distributed over + the full cooperative grid for seed initialization, transition rows, message + receiver rows, and carry copies, so B=1 no longer implies one CUDA block owns + the entire K/H reverse window. +- Boundary review: + the ABI remains the flat tensor-table reverse extension plus scalar/table + metadata. The kernel still consumes flat bucket identity, reset windows, + sender tables, primitive dimensions inferred from tensor shapes, and schedule + scalars. It adds no cell-kind selector, population-name selector, + benchmark-row selector, hidden-size policy key, separate single/mixed route, + or cell-family parameter bundle. Cooperative launch is a backend execution + policy for the existing table-owned reverse superop, not a public knob. 
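- Work-distribution sketch: the cooperative launch above splits each stage's
  rows over the whole grid, with grid-wide barriers between stages. A serial
  host emulation of grid-stride ownership, with hypothetical helper names,
  shows the coverage property the launch relies on: every row is owned exactly
  once per stage.

```python
def grid_stride_rows(thread_id: int, grid_threads: int, n_rows: int) -> range:
    # Each cooperative "thread" owns a strided slice of the stage's rows.
    return range(thread_id, n_rows, grid_threads)

def run_stage(stage_fn, n_rows: int, grid_threads: int) -> None:
    # On device this is one cooperative stage between grid syncs; here it is
    # emulated serially for clarity.
    for tid in range(grid_threads):
        for row in grid_stride_rows(tid, grid_threads, n_rows):
            stage_fn(row)

owned: list[int] = []
run_stage(owned.append, n_rows=10, grid_threads=4)
assert sorted(owned) == list(range(10))   # every row owned exactly once per stage
```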
+- Evidence: + GPU 0 with private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_coop_compile` passed + the low-level mixed reverse table parity row as `1 passed in 45.09s`. GPU 0 + with private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_coop_high_verify` + passed high-level public API T=1/K=1 reset/no-reset plus + K>1/provided-state/K128 guards as `5 passed in 230.75s`. +- Current timing: + GPU 0 timed high-level audit row with + `CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1`, private extension cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_coop_high_verify`, and + Triton cache `/tmp/cortical_triton_${USER}_redo_fixmass_reverse_coop_profile` + wrote `/tmp/redo_fixmaass_owner_profile_reverse_coop_timing`. The mixed-pop + T=1, K=128, H=64, terminal-loss, h=8, B=1 row improved from the prior + measured `transition_message_reverse_table_device_loop:ms=85.864;count=2` to + `transition_message_reverse_table_device_loop:ms=76.774;count=2`, and + tokens/s moved from `7.178` to `7.508`. +- Remaining owner: + R4 is still open. The reverse table device-loop remains the top measured + CUDA owner, and planner owner metadata still reports + `python_autograd_scan`. Next backend work should continue inside the + cooperative reverse table superop: reduce per-thread local transition work, + split primitive row bodies behind cleaner table dispatch, or attack the + recurrent-message atomic/carry path based on the next profile. + +### 2026-04-28 UTC - T=1 audit discipline correction and B1024 probes + +Status: ACCEPTED AS AUDIT-STRATEGY CORRECTION; NOT A FULL AUDIT CLOSURE. + +Owner: audit discipline / R4 regression guard. + +- User correction: + do not confuse K=128/H=64 owner profiles with the T=1 audit gate. K/H rows + are useful for picking backend owners after the base path is healthy, but the + audit strategy starts with T=1, K=1 and April 21 JSON references. Periodic + T=1 probes are required while R4 backend work proceeds so we do not optimize + K/H while silently regressing the base streaming case. +- Correction made: + the earlier mixed-pop command accidentally used `--plan smoke`, which emitted + a hardcoded B=1 forward-only smoke case. That row is explicitly rejected as a + B1024 training audit signal. The correct mixed T=1 path is the + `t1-single-pop` audit plan with `--population-modes mixed`; despite the plan + name, the manifest supports both single and mixed population modes. +- Current T=1 probes: + GPU 0 single-pop high-level API probe with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_t1_b1024_probe` and + `/tmp/cortical_triton_${USER}_redo_fixmass_t1_b1024_probe` ran + `t1-single-pop_slstm_1m_forward_backward_b1024_t1_k1_h8_..._popsingle`. + It reported `tokens_per_s=53475.779`, `peak_mem_gib=0.821`, attached April + 21 reference key `h8_many_cell_stress_broad`, reference `tokens_per_s=5470.79`, + and ratio `9.775x`. +- Current mixed T=1 probe: + GPU 0 mixed-pop high-level API probe with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_t1_b1024_mixed_train_probe` + and `/tmp/cortical_triton_${USER}_redo_fixmass_t1_b1024_mixed_train_probe` + ran + `t1-single-pop_slstm_1m_forward_backward_b1024_t1_k1_h8_..._popmixed`. + It reported `tokens_per_s=43890.872`, `peak_mem_gib=0.863`, attached April + 21 reference key `h8_many_cell_stress_broad`, reference `tokens_per_s=5470.79`, + and ratio `8.023x`. 
The planner signature recorded mixed population families + `slstm,axoncell`, `launch_scan_implementations=["cuda_temporal_superop"]`, + and active reverse table launch counts for the mixed backward path. +- Non-closure note: + both T=1 probes were informational one-row checks, not the full T=1 audit + matrix. They show no immediate B1024 T=1 regression against the April 21 h8 + floor while R4 continues, but final closure still requires the planned T=1 + coverage across April 21 cases, mixed-pop matched-stack/MoE comparison, and + then T*K/H/K audits. + +### 2026-04-28 UTC - R4 cooperative reverse table sync trim + +Status: ACCEPTED AS NARROW CUDA REVERSE-SUPEROP CLEANUP; R4 THROUGHPUT OWNER +REMAINS OPEN. + +Owner: active shared temporal backward reverse dependency scan. + +- Implemented: + removed the final grid-wide barrier at the end of each reverse timestep after + copying `public_carry_work` into `grad_hidden_message_window`. The next + reverse transition reads the carry, and the following transition-stage + barrier still protects the later carry clear before the message stage, so the + removed barrier was stronger than required. +- Evidence: + GPU 0 with private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_coop_sync_trim` passed + the low-level mixed reverse table parity row as `1 passed in 46.99s`. GPU 0 + with private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_coop_sync_trim_high` + passed high-level public API T=1/K=1 reset/no-reset plus + K>1/provided-state/K128 guards as `5 passed in 233.89s`. +- Timing: + GPU 0 timed high-level K=128/H=64 owner profile with + `CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1` and private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_coop_sync_trim_high` + and `/tmp/cortical_triton_${USER}_redo_fixmass_reverse_sync_trim_profile` + wrote `/tmp/redo_fixmaass_owner_profile_reverse_sync_trim_timing`. The + reverse table owner moved only slightly from the cooperative baseline + `transition_message_reverse_table_device_loop:ms=76.774;count=2` to + `76.635;count=2`; tokens/s moved from `7.508` to `7.658`. +- Non-closure note: + this is not stack comparison evidence and not an audit result. The K=128/H=64 + B=1 owner profile remains far below the April 21 floor if compared directly; + it exists only to guide R4 backend implementation. The T=1 B1024 probes above + are the base audit regression signal for this slice. + +### 2026-04-28 UTC - R4 fused public-carry clear into transition rows + +Status: ACCEPTED AS CORRECTNESS-PRESERVING CLEANUP; NOT A MATERIAL +THROUGHPUT WIN. + +Owner: active shared temporal backward reverse dependency scan. + +- Implemented: + removed the separate full-grid `public_carry_work` clear pass from each + reverse timestep. Each transition row now clears its own public-carry row + after consuming it. The existing transition-to-message grid barrier ensures + all rows are cleared before recurrent-message backward writes the next + timestep's carry. Because the next transition would otherwise race with the + diagnostic `grad_hidden_message_window` copy, the post-copy grid barrier is + retained for this fused-clear variant. +- Boundary review: + this only changes carry lifetime inside the existing cooperative tensor-table + reverse superop. It adds no cell-kind selector, population-name selector, + benchmark-row selector, hidden-size policy key, separate single/mixed route, + or cell-family parameter bundle. 
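- Carry-lifetime sketch: the fused clear above has each transition row consume
  its public-carry slot and zero it in place, instead of a dedicated full
  clear pass before the message stage writes the next carry. A toy host
  version, with illustrative shapes, shows the invariant the barrier placement
  must preserve.

```python
import torch

torch.manual_seed(0)
R, H = 3, 4
carry = torch.randn(R, H)
reference = carry.clone()

consumed = torch.empty_like(carry)
for r in range(R):                      # conceptually one thread per transition row
    consumed[r] = carry[r]              # use the carry for this row's backward math
    carry[r].zero_()                    # fused clear replaces the separate pass

assert torch.equal(consumed, reference)         # backward math saw the right values
assert torch.count_nonzero(carry) == 0          # message stage writes into zeros
```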
+- Evidence: + GPU 0 with private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_carry_clear_fused` + passed low-level mixed reverse table parity as `1 passed in 45.10s`. GPU 0 + with private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_carry_clear_high` + passed high-level public API T=1/K=1 reset/no-reset plus + K>1/provided-state/K128 guards as `5 passed in 232.08s`. +- Timing: + GPU 0 timed high-level K=128/H=64 owner profile with + `CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1` and private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_carry_clear_high` and + `/tmp/cortical_triton_${USER}_redo_fixmass_reverse_carry_clear_profile` + wrote `/tmp/redo_fixmaass_owner_profile_reverse_carry_clear_timing`. The + reverse table owner was essentially unchanged versus the sync-trim baseline: + `transition_message_reverse_table_device_loop:ms=76.693;count=2`, with + `tokens_per_s=7.667`. +- Remaining owner: + keep R4 focused on the reverse table device-loop internals. The next + meaningful work is not another barrier-level cleanup unless profiling proves + it; it should reduce the row-local primitive bodies, recurrent-message + atomic/carry path, or table-dispatch structure while preserving T=1 probes. + +### 2026-04-28 UTC - R4 recurrent query-gradient local accumulation + +Status: ACCEPTED AS SMALL GENERIC ATOMIC REDUCTION; R4 THROUGHPUT OWNER +REMAINS OPEN. + +Owner: active shared temporal backward reverse dependency scan. + +- Implemented: + the cooperative reverse table message row now accumulates recurrent-query + gradients locally across the sender offsets for a receiver row and issues one + atomic add per query dimension after all offsets are processed. Previously it + issued a query-gradient atomic inside the sender-offset loop. This preserves + generic receiver/sender table semantics; input K/V, recurrent K/V weight, and + public carry paths still use atomics where multiple receiver rows can collide + on the same sender. +- Boundary review: + this is a primitive-dimension optimization inside the existing tensor-table + reverse superop. The only added bound is the generic message primitive + `head_dim <= 64` for the local accumulator array; it is inferred from tensor + shape and is not a cell-family, population, benchmark, hidden-size policy, or + single/mixed route key. +- Evidence: + GPU 0 with private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_q_grad_accum` passed + the low-level mixed reverse table parity row as `1 passed in 45.17s`. GPU 0 + with private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_q_grad_high` passed + high-level public API T=1/K=1 reset/no-reset plus + K>1/provided-state/K128 guards as `5 passed in 233.90s`. +- Timing: + GPU 0 timed high-level K=128/H=64 owner profile with + `CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1` and private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_q_grad_high` and + `/tmp/cortical_triton_${USER}_redo_fixmass_reverse_q_grad_profile` wrote + `/tmp/redo_fixmaass_owner_profile_reverse_q_grad_timing`. The reverse table + owner moved slightly from the carry-clear baseline + `transition_message_reverse_table_device_loop:ms=76.693;count=2` to + `76.269;count=2`; the row reported `tokens_per_s=7.532`, which is within + noise of the prior tiny K=128 owner profiles and is not audit evidence. +- Remaining owner: + R4 remains open and the reverse table device-loop remains the top measured + owner. 
Do not keep doing only small atomic cleanups; the next meaningful + implementation should target the recurrent-message sender/carry collision + pattern or split row-local primitive bodies behind cleaner table dispatch. + +### 2026-04-28 UTC - R4 active sender-owned recurrent K/V phase + +Status: ACCEPTED AS ACTIVE GENERIC SENDER-PHASE PERFORMANCE SLICE; R4 +THROUGHPUT OWNER REMAINS OPEN. + +Owner: active shared temporal backward reverse dependency scan. + +- Implemented: + made the already-passed `use_sender_reverse` ABI real inside the active + cooperative transition/message reverse table kernel. Receiver rows now write + generic message weight and logit-adjoint work tables keyed by flat + receiver/offset. A sender-owned phase consumes the existing + `message.sender_receiver_idx` tensor table and produces recurrent K/V + projection parameter gradients plus previous-step public carry for recurrent + senders. Input-sender K/V gradients and recurrent-query gradients remain + receiver-row work. +- Boundary review: + the ABI remains the flat-bucket temporal tensor table. The new inputs are + `message.sender_receiver_idx` and two generic scratch tables for message + weights/dlogits. There is no cell-kind selector, no population-name selector, + no benchmark-row selector, no hidden-size policy key, no separate single/mixed + route, and no cell-family parameter bundle. The sender phase is keyed only by + flat sender rows, receiver offsets, message head/value dimensions inferred + from tensor shapes, reset masks, and tensor-table projection weights. +- Evidence: + GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_sender_phase` and + `/tmp/cortical_triton_${USER}_redo_fixmass_sender_phase` passed the low-level + mixed reverse table parity row as `1 passed in 46.58s`. The same cache passed + high-level public API T=1/K=1 mixed-pop training reset/no-reset, + K>1 provided-state gradient, and K=128 terminal-emission gradient guards as + `4 passed in 194.27s`. +- Timing: + GPU 0 timed high-level K=128/H=64 owner profile with + `CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1` and private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_sender_phase` and + `/tmp/cortical_triton_${USER}_redo_fixmass_sender_phase_profile`, writing + `/tmp/redo_fixmaass_owner_profile_sender_phase_timing`. The reverse table + owner moved from the prior q-gradient baseline + `transition_message_reverse_table_device_loop:ms=76.269;count=2` to + `28.179;count=2`; row throughput moved from `7.532 tok/s` to + `12.267 tok/s`. This is an owner-profile improvement, not audit closure. +- T=1 regression probes: + GPU 0 single-pop B=1024/T=1/K=1 training wrote + `/tmp/redo_fixmaass_sender_phase_t1_b1024_single` and reported + `53313.293 tok/s`, peak `0.820897 GiB`, April 21 reference + `h8_many_cell_stress_broad=5470.79 tok/s`, ratio `9.745x`. GPU 1 mixed-pop + B=1024/T=1/K=1 training wrote + `/tmp/redo_fixmaass_sender_phase_t1_b1024_mixed` and reported + `62000.143 tok/s`, peak `0.862957 GiB`, same reference, ratio `11.333x`. + These are one-row regression probes only; the full T=1 audit matrix remains + pending. +- Message-declaration cleanup note: + message passing is a user-declared semantic surface like cells. Dot-product + attention, Q/K/V projection sources, normalization/public-boundary use, + distance/delay terms, aggregation, reset behavior, and future message + operators must lower from Fabric declarations into generic message primitive + rows/tensor tables. 
The current active kernel still embeds the lowered + dot-product row body as transitional primitive execution; later cleanup must + split this behind generic message primitive rows and must not treat the + current built-in message rule as hidden engine policy. +- Remaining owner: + R4 remains open. The active reverse table device-loop moved materially, but + the K=128/H=64 owner profile is still far below audit goals and planner + metadata still reports `python_autograd_scan` for forward/backward owners. + +### 2026-04-28 UTC - R4 sender K/V adjoint scratch split + +Status: ACCEPTED AS SMALL GENERIC SENDER-PHASE OPTIMIZATION; R4 THROUGHPUT +OWNER REMAINS OPEN. + +Owner: active shared temporal backward reverse dependency scan. + +- Implemented: + split the active sender-owned recurrent K/V phase into a sender-dimension + adjoint scratch table and a hidden-dimension carry/parameter phase. The + sender K/V adjoint depends on flat `(batch, sender, kv_dim)` and receiver + offsets, not on hidden lane, so the new scratch avoids recomputing the same + offset loop once per hidden element. This is still a generic message + primitive/tensor-table optimization and does not introduce cell or population + selectors. +- Evidence: + GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_sender_gradkv` and + `/tmp/cortical_triton_${USER}_redo_fixmass_sender_gradkv` passed the low-level + mixed reverse table parity row as `1 passed in 45.28s`. The same cache passed + high-level public API T=1/K=1 mixed-pop training reset/no-reset, + K>1 provided-state gradient, and K=128 terminal-emission gradient guards as + `4 passed in 193.77s`. +- Timing: + GPU 0 timed high-level K=128/H=64 owner profile with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_sender_gradkv` and + `/tmp/cortical_triton_${USER}_redo_fixmass_sender_gradkv_profile`, writing + `/tmp/redo_fixmaass_owner_profile_sender_gradkv_timing`. The reverse table + owner moved from the previous sender-phase `28.179;count=2` to + `26.667;count=2`; row throughput moved from `12.267 tok/s` to + `12.432 tok/s`. This is useful but small, and artifact recompute is now a + comparable owner at `22.627;count=2`. +- T=1 regression probes: + GPU 0 single-pop B=1024/T=1/K=1 training wrote + `/tmp/redo_fixmaass_sender_gradkv_t1_b1024_single` and reported + `53504.130 tok/s`, peak `0.820897 GiB`, April 21 reference + `5470.79 tok/s`, ratio `9.780x`. GPU 1 mixed-pop B=1024/T=1/K=1 training + wrote `/tmp/redo_fixmaass_sender_gradkv_t1_b1024_mixed` and reported + `59138.204 tok/s`, peak `0.862957 GiB`, same reference, ratio `10.810x`. + These remain one-row probes only; the full T=1 audit matrix remains pending. +- Remaining owner: + R4 remains open. The next high-priority work should target temporal artifact + recompute / CUDA replay cost or a larger reverse-table structural fusion. + Avoid spending more passes on tiny sender atomics unless a current profile + shows a specific regression there. + +### 2026-04-28 UTC - R4 artifact replay sparse output-message materialization + +Status: ACCEPTED AS MODEST GENERIC REPLAY MATERIALIZATION IMPROVEMENT; R4 +THROUGHPUT OWNER REMAINS OPEN. + +Owner: temporal artifact recompute / CUDA replay materialization. + +- Current owner context: + after the sender K/V adjoint split, the K=128/H=64 high-level owner profile + reports `transition_message_reverse_table_device_loop:ms=26.667;count=2` + and `temporal_artifact_recompute:ms=22.627;count=2`. 
Artifact recompute is + now comparable to the reverse-table owner, so R4 should reduce replay cost + before returning to smaller reverse-table atomics. +- Boundary invariant: + output-message materialization is a temporal schedule/materialization-policy + decision, not a cell or population rule. Sparse emission must be driven by + `scalar_temporal_scan_step(...).emit_output`, flat message tables, and generic + Fabric message primitives. It must not introduce cell-kind selectors, + population-name selectors, benchmark-row selectors, hidden-size policy keys, + separate single/mixed routes, or message math hidden outside Fabric + declarations. +- Planned change: + the CUDA replay scan currently writes output-message artifacts for every + physical microstep whenever backward may need output-message gradients. For + K>1 terminal/per-outer emission rows, only the emitted physical steps need + those artifacts. Keep the dense in-scan path when every replay step emits; + otherwise replay recurrent/transition artifacts through the CUDA temporal + superop and materialize output-message tensors only on emitted steps through + the existing generic `fabric.cuda.nn` message primitive path. +- Acceptance checks: + run the low-level mixed reverse-table parity row, high-level public API + T=1/K=1 and K=128 guards, a warmed K=128/H=64 owner profile, T=1 B=1024 + single/mixed regression probes, backend-boundary tests, `git diff --check`, + and commit the accepted slice. +- Implemented: + `_try_cuda_mixed_flat_bucket_recompute_artifact_window` now keeps the dense + in-scan output-message artifact path only when every replay physical step + needs an output message. Sparse-emission windows, such as K=128 terminal + output, replay recurrent/transition artifacts through the CUDA temporal + superop and materialize output-message tensors only for requested emitted + steps through the generic partitioned Fabric message primitive. Inactive + output-message artifact slots are explicit zero tensors and are not consumed + by output backward. +- Evidence: + GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_sparse_output_msg` and + `/tmp/cortical_triton_${USER}_redo_fixmass_sparse_output_msg` passed the + low-level mixed reverse table parity row as `1 passed in 45.60s`. The same + cache passed high-level public API T=1/K=1 mixed-pop training reset/no-reset, + K>1 provided-state gradient, and K=128 terminal-emission gradient guards as + `4 passed in 194.25s`. +- Timing: + GPU 0 timed high-level K=128/H=64 owner profile with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_sparse_output_msg` and + `/tmp/cortical_triton_${USER}_redo_fixmass_sparse_output_msg_profile`, + writing `/tmp/redo_fixmaass_owner_profile_sparse_output_msg_timing`. The + replay owner moved from the sender-gradkv baseline + `temporal_artifact_recompute:ms=22.627;count=2` to + `21.131;count=2`. The reverse table owner remained comparable at + `transition_message_reverse_table_device_loop:ms=26.830;count=2`; throughput + was `12.303 tok/s`. The sparse output-message row itself measured + `artifact.recompute.sparse_output_message:ms=0.010;count=1`, confirming the + output-message work moved out of the dense per-microstep replay loop for this + terminal K row. +- T=1 regression probes: + GPU 0 single-pop B=1024/T=1/K=1 training wrote + `/tmp/redo_fixmaass_sparse_output_t1_b1024_single` and reported + `51759.503 tok/s`, peak `0.820897 GiB`, April 21 reference + `5470.79 tok/s`, ratio `9.461x`. 
GPU 1 mixed-pop B=1024/T=1/K=1 training + wrote `/tmp/redo_fixmaass_sparse_output_t1_b1024_mixed` and reported + `57176.552 tok/s`, peak `0.862957 GiB`, same reference, ratio `10.451x`; + the matched mixed-stack row was `16920.441 tok/s`, so mixed Fabric remained + `3.379x` above the stack comparison. These remain one-row regression probes, + not full T=1 audit closure. +- Static guard: + `tests/test_fabric_backend_boundaries.py` passed as `2 passed in 0.01s`, and + `git diff --check` passed. +- Remaining owner: + R4 remains open. The next useful work should target either the reverse table + device loop, now the largest event-timed owner again, or the remaining replay + bridge/assembly cost that keeps `temporal_artifact_recompute` above 20 ms. + Planner metadata still reports `python_autograd_scan`, so no temporal closure + is claimed. + +### 2026-04-28 UTC - R4 reverse sender-KV phase cooperative grid sizing + +Status: ACCEPTED AS SMALL GENERIC LAUNCH-GEOMETRY FIX; R4 THROUGHPUT OWNER +REMAINS OPEN. + +Owner: active shared temporal backward reverse table launch geometry. + +- Current owner context: + after sparse output-message replay, the K=128/H=64 owner profile reports + `transition_message_reverse_table_device_loop:ms=26.830;count=2` and + `temporal_artifact_recompute:ms=21.131;count=2`. Reverse table is again the + largest event-timed owner. +- Boundary invariant: + cooperative grid sizing is a tensor-table execution property. It may use + batch, flat recurrent count, hidden width, message head/value dimensions, and + whether the generic sender-owned phase is active; it must not use cell + family, population name, benchmark row, hidden-size policy, or separate + single/mixed route logic. +- Planned change: + the reverse-table cooperative launch currently sizes requested blocks from + transition rows, receiver-message rows, and hidden carry rows. With the + sender-owned recurrent K/V phase enabled, the `(B, R, head_dim + value_dim)` + sender-adjoint phase can be larger than `(B, R, H)` in small-hidden rows. Add + that generic phase to the work estimate so small-h reverse rows are not + underfilled by the launch geometry. +- Implemented: + `flat_bucket_temporal_backward_kernels.cu` now includes the sender-owned + `(B, R, head_dim + value_dim)` phase in the cooperative work-item estimate + when `use_sender_reverse` is active. This uses only tensor-table dimensions + and the generic sender-phase flag. +- Evidence: + GPU 0 with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_sender_grid` and + `/tmp/cortical_triton_${USER}_redo_fixmass_sender_grid` passed the low-level + mixed reverse table parity row as `1 passed in 45.38s`. The same cache passed + high-level public API T=1/K=1 mixed-pop training reset/no-reset, + K>1 provided-state gradient, and K=128 terminal-emission gradient guards as + `4 passed in 189.08s`. +- Timing: + GPU 0 timed high-level K=128/H=64 owner profile with private caches + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_sender_grid` and + `/tmp/cortical_triton_${USER}_redo_fixmass_sender_grid_profile`, writing + `/tmp/redo_fixmaass_owner_profile_sender_grid_timing`. The reverse table + moved from the sparse-output baseline + `transition_message_reverse_table_device_loop:ms=26.830;count=2` to + `26.291;count=2`; `temporal_artifact_recompute` was essentially unchanged at + `21.177;count=2`; throughput was `12.434 tok/s`. 
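+
+  As a rough illustration of the launch-geometry change implemented above, the
+  sketch below sizes the cooperative grid from the largest per-timestep phase,
+  including the sender-owned `(B, R, head_dim + value_dim)` phase when sender
+  reverse is active. The function name, block size, and block cap here are
+  hypothetical, not the kernel's actual constants.
+
+  ```python
+  from math import ceil
+
+  def cooperative_blocks_estimate(
+      batch: int,
+      recurrent_count: int,
+      hidden: int,
+      head_dim: int,
+      value_dim: int,
+      transition_rows: int,
+      receiver_message_rows: int,
+      use_sender_reverse: bool,
+      threads_per_block: int = 256,
+      max_blocks: int = 1024,
+  ) -> int:
+      # Requested blocks follow the widest phase, not only the hidden carry.
+      phases = [
+          transition_rows,
+          receiver_message_rows,
+          batch * recurrent_count * hidden,
+      ]
+      if use_sender_reverse:
+          # Sender K/V adjoint phase over (B, R, head_dim + value_dim).
+          phases.append(batch * recurrent_count * (head_dim + value_dim))
+      return max(1, min(ceil(max(phases) / threads_per_block), max_blocks))
+  ```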
+- T=1 regression probes: + GPU 0 single-pop B=1024/T=1/K=1 training wrote + `/tmp/redo_fixmaass_sender_grid_t1_b1024_single` and reported + `52018.968 tok/s`, peak `0.820897 GiB`, April 21 reference + `5470.79 tok/s`, ratio `9.508x`. GPU 1 mixed-pop B=1024/T=1/K=1 training + wrote `/tmp/redo_fixmaass_sender_grid_t1_b1024_mixed` and reported + `58785.249 tok/s`, peak `0.862957 GiB`, same reference, ratio `10.745x`; + the matched mixed-stack row was `18688.086 tok/s`, so mixed Fabric remained + `3.146x` above the stack comparison. These are one-row regression probes + only. +- Static guard: + `tests/test_fabric_backend_boundaries.py` passed as `2 passed in 0.01s`, and + `git diff --check` passed. Because reset parity has been a recurring + April 24-26 risk, GPU 0 also ran + `test_fabric_cuda_flat_temporal_horizon_shared_mixed_population_reset_parity` + for absent and present resets; both passed as `2 passed in 6.63s`. +- Remaining owner: + R4 remains open. The current top owners are still the reverse table at + `26.291 ms` and artifact replay at `21.177 ms`; the next useful work needs a + larger structural reduction in one of those owners, not route metadata or + cleanup. + +### 2026-04-28 UTC - R4 active reverse replay graph-assembly elision + +Status: ACCEPTED AS MODEST ACTIVE-REVERSE REPLAY MATERIALIZATION IMPROVEMENT; +R4 THROUGHPUT OWNER REMAINS OPEN. + +Owner: temporal artifact recompute / replay bridge materialization. + +- Current owner context: + after reverse sender-grid sizing, the K=128/H=64 profile reports + `transition_message_reverse_table_device_loop:ms=26.291;count=2` and + `temporal_artifact_recompute:ms=21.177;count=2`. Replay is still a + first-order owner next to the reverse table. +- Boundary invariant: + graph-order cell materialization is a user-visible/fallback artifact, not a + required internal representation for the active backend-order reverse table. + Eliding it is valid only when the active CUDA reverse engine will consume the + replayed backend-order tensor tables, and the code must fail closed rather + than silently falling back to a host loop with placeholder graph-order cells. + The decision must be based on output contract, reset/materialization policy, + flat bucket identity, and lowered tensor-table availability; it must not + depend on cell family names, population names, benchmark rows, or separate + single/mixed routes. +- Planned change: + for reset-absent recompute windows that target the active CUDA reverse table, + skip `_assemble_mixed_cuda_flat_bucket_cells` inside replay and carry only a + full-cell shape anchor from the checkpoint. Keep dense graph-order assembly + for reset-present windows and unsupported output contracts. Add an explicit + active-reverse-only artifact flag so fallback refuses to run if the CUDA + reverse table rejects the window. +- Implemented: + `TemporalBucketStepArtifacts` now carries an `active_reverse_only` flag. + Reset-absent active CUDA reverse-table replay windows elide graph-order full + cell assembly and retain only the checkpoint full-cell shape anchor. The + elision requires lowered backend-order recurrent K/V projection weights so + output backward can stay in backend-order form; output backward and the + window executor now fail closed if an active-reverse-only artifact would + reach a graph-order projection fallback or the host temporal reverse loop. + This is still generic flat bucket identity logic: no cell-family, + population-name, benchmark-row, or single/mixed route condition was added. 
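+
+  A minimal sketch of the fail-closed shape described in the implemented
+  bullet above; the class and function names are illustrative stand-ins, not
+  the real `TemporalBucketStepArtifacts` surface.
+
+  ```python
+  from dataclasses import dataclass
+  from typing import Any, Optional
+
+  @dataclass
+  class StepArtifactsSketch:
+      # Graph-order cells are elided for active-reverse-only windows.
+      active_reverse_only: bool = False
+      graph_order_cells: Optional[Any] = None
+
+  def consume(artifacts: StepArtifactsSketch, reverse_table_accepts: bool) -> str:
+      if artifacts.active_reverse_only and not reverse_table_accepts:
+          # Refuse to run instead of silently falling back to a host loop
+          # that would read placeholder graph-order cells.
+          raise RuntimeError(
+              "active-reverse-only artifacts require the CUDA reverse table"
+          )
+      return "cuda_reverse_table" if reverse_table_accepts else "graph_order_fallback"
+  ```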
+- Evidence: + GPU 0 low-level mixed reverse table parity passed with private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_active_reverse_replay_guard_low` + as `1 passed in 45.99s`. GPU 1 high-level public API guards for T=1/K=1 + mixed-pop training, K>1 provided-state gradients, and K=128 terminal + emission gradients passed as `4 passed in 237.48s`. The high-level K=16/H=16 + warmed probe wrote + `/tmp/redo_fixmaass_owner_profile_active_reverse_replay_guard_k16h16_warmed` + and showed the new + `temporal_backward_glue:cuda_recompute_active_reverse_graph_assembly_elided` + marker alongside the active CUDA reverse table marker. +- Timing: + GPU 0 high-level K=128/H=64 warmed owner profile wrote + `/tmp/redo_fixmaass_owner_profile_active_reverse_replay_guard_k128h64_warmed`. + Throughput was `12.696 tok/s` versus the previous accepted `12.434 tok/s`. + `temporal_artifact_recompute` moved from `21.177;count=2` to + `20.911;count=2`; `transition_message_reverse_table_device_loop` was + essentially unchanged at `26.277;count=2`. +- T=1 regression probes: + GPU 0 single-pop B=1024/T=1/K=1 training wrote + `/tmp/redo_fixmaass_active_reverse_replay_guard_t1_b1024_single` and + reported `52278.500 tok/s`, peak `0.820897 GiB`, April 21 reference + `5470.79 tok/s`. GPU 1 mixed-pop B=1024/T=1/K=1 training wrote + `/tmp/redo_fixmaass_active_reverse_replay_guard_t1_b1024_mixed` and reported + `59801.870 tok/s`, peak `0.831707 GiB`; matched mixed-stack comparison was + `18571.260 tok/s`, so mixed Fabric was `3.220x` above stack. +- Static and reset guards: + `tests/test_fabric_backend_boundaries.py` passed as `2 passed in 0.01s`; + `git diff --check` passed. GPU 3 reset parity for + `test_fabric_cuda_flat_temporal_horizon_shared_mixed_population_reset_parity` + passed absent and present resets as `2 passed in 5.53s`. +- Remaining owner: + R4 remains open. The current measured top owners are still the reverse table + at `26.277 ms` and replay at `20.911 ms` on K=128/H=64. The planner metadata + still reports transitional `temporal_plan_forward_owners` and + `temporal_plan_backward_owners`, so this is not temporal-engine closure. + +### 2026-04-28 UTC - R4 sender-owned reverse carry writes grad window + +Status: REJECTED PROBE; R4 remains OPEN. + +Owner: active shared temporal backward reverse table synchronization/copy +phase. + +- Current owner context: + after active reverse replay graph-assembly elision, K=128/H=64 reports + `transition_message_reverse_table_device_loop:ms=26.277;count=2` and + `temporal_artifact_recompute:ms=20.911;count=2`. +- Boundary invariant: + the sender-owned recurrent K/V reverse phase is a generic message primitive + phase over tensor-table dimensions `(B, R, head_dim + value_dim)` and + `(B, R, H)`. It is not cell-specific. When this phase is active, it computes + the per-sender hidden carry that is both the next-step public carry and the + `grad_hidden_message_window` output for the current reverse timestep. +- Planned change: + write `grad_hidden_message_window` directly from the sender-owned hidden + carry phase and skip the following generic public-carry copy loop for that + active path. Preserve the existing copy loop for non-sender-reverse windows. + This should remove one `B*R*H` pass and one cooperative grid sync per reverse + timestep without changing transition or message math. 
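+
+  A host-side sketch of the identity this planned change relies on, with
+  illustrative shapes: the sender-phase hidden carry and the copied
+  `grad_hidden_message_window` hold the same values, so writing the window
+  directly from the sender phase is semantically neutral.
+
+  ```python
+  import torch
+
+  B, R, H, KV = 2, 3, 4, 5                  # illustrative sizes
+  grad_kv = torch.randn(B, R, KV)           # sender K/V adjoints
+  recurrent_kv_weight = torch.randn(KV, H)  # generic recurrent K/V projection
+
+  # Existing path: compute the carry, then a separate B*R*H copy pass.
+  carry = grad_kv @ recurrent_kv_weight
+  grad_hidden_message_window = torch.empty(B, R, H)
+  grad_hidden_message_window.copy_(carry)
+
+  # Probed path: the sender phase writes the window directly and reuses it
+  # as the next-step carry.
+  direct_window = grad_kv @ recurrent_kv_weight
+
+  assert torch.equal(grad_hidden_message_window, direct_window)
+  ```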
+- Rejection: + the low-level reverse-table parity passed and high-level public API guards + passed, but the timing path did not produce a usable `cases.jsonl` row before + being killed/blocked in the local tool environment. More importantly, this + slice was being judged by owner deltas before re-establishing the paired + T=1/K throughput floor. The kernel edit was rolled back. Do not revive this + micro-optimization until the audit row is paired with its matched T=1/K=1 + training baseline and the K-adjusted floor is computed first. + +### 2026-04-28 UTC - R15 legacy Fabric config surface cleanup owner + +Status: PENDING; cleanup/design owner recorded, not a backend closure. + +Owner: public graph/config normalization and backend-boundary review. + +- User correction: + Fabric is graph-defined. `src/cortical/fabric/config.py` still exposes legacy + geometry/config fields such as `width`, `height`, `depth`, band widths, + projection-region shape, and config-level `num_heads`. These fields are + acceptable only as public graph/message-construction or compatibility inputs; + they must not become backend identity, planner policy, temporal-engine + admission, kernel specialization, or audit gates. +- Current code reality: + the preferred public API is already `Blueprint` plus graph declarations, but + normalization still lowers through `Config`, and runtime/anatomy still read + legacy config geometry for graph construction and compatibility. This is not + final fixmass shape. It needs a staged cleanup after the shared temporal + backend owner is no longer blocked. +- Cleanup criteria: + move backend-facing ownership to explicit graph facts and message primitive + metadata: flat node ids, input/output/recurrent node sets, adjacency/degree + buckets, edge distance/delay, kv group ids, population bucket assignment, + interface dims, message-rule tensor/op rows, reset policy, and materialization + policy. Backend code must not use rectangular factorization, coordinate + labels, config-level head counts, population names, or cell family names as + route selectors. +- Skill update: + `skills/cb.fabric-backend-boundaries/SKILL.md` now explicitly treats legacy + Fabric config geometry/head fields as graph/message-construction surface only. + `skills/cb.fabric-scaling-horizon/SKILL.md` now records that K>1 throughput + gates use matched T=1 training throughput divided by K. +- Test hygiene correction: + do not add fake hand-constructed `FabricAuditCase` rows with invented + throughput numbers to benchmark tests. Audit criteria should be validated by + real audit manifests/runs or narrow parser/dry-run checks, not synthetic + benchmark rows that look like evidence. +- Next cleanup action: + add static/manual review gates that distinguish graph-construction use from + backend-policy use, then migrate runtime planner inputs away from direct + `Config` geometry where graph facts already exist. Do not mix this cleanup + with R4 kernel work unless the config leak is the measured owner or the + reason a shared temporal row cannot be represented generically. + +### 2026-04-28 UTC - R15 anatomy/config cleanup correction + +Status: HIGH-PRIORITY CLEANUP RECORDED; WRONG DIRECTION REJECTED; BACKEND +WORK RESUMED. + +Owner: public graph/config normalization and backend-boundary review. + +- User correction: + the config cleanup owner also includes `src/cortical/fabric/anatomy.py`, and + the design rule is stricter than "move config geometry into anatomy." Fabric + knows graphs. Lattice is one graph constructor. 
Lattice-specific facts such + as `width`, `height`, `depth`, `coord_shape`, `wrap`, x/y/z coordinates, + boundary bands, projection regions, and lattice offset neighborhoods belong + to the lattice graph/anatomy constructor. Backend-facing Fabric code must + consume flat graph tables: node ids, input/output/recurrent node sets, + sender/receiver adjacency, valid masks, degree buckets, edge distance/delay, + kv groups, flat bucket identity, tensor-table rows, op-table rows, reset + policy, checkpoint policy, and materialization policy. +- Audit criterion correction: + K/H throughput gates are judged from matched current-code T=1,K=1 training + throughput divided by K. April 21 remains the score-reference context and + regression source, but K=128,H=64 is not compared against raw T=1 tok/s. + It must meet matched T1 training tok/s / 128, plus parity, reset parity, + per-timestep/terminal loss reporting, memory non-regression, and shared + temporal owner evidence. +- Rejected local direction: + do not promote `coord_shape` or `wrap` into backend-facing `AnatomySpec` as a + cleanup endpoint. That still leaks lattice into Fabric. The runtime currently + has legacy reads because local sender tables are still rebuilt from lattice + offsets; those reads are a cleanup debt, not a design to preserve. The correct + replacement is to construct local sender/receiver tables before backend + runtime, from the graph constructor, and then have runtime/backend consume the + resulting flat graph tables without knowing how they were made. +- High-priority cleanup tasks: + severely clean `src/cortical/fabric/config.py` so it is no longer a global + fabric identity object mixing graph construction, message declaration, + execution policy, populations, readout, and legacy defaults. Lattice + fields should move into lattice graph builders. + Message fields should move to user-declared message primitive configs and + tensor/op tables. Execution fields should move to planner policy inputs and + recorded decisions. +- R15 scope clarification: + the recent `Config`/`anatomy.py` correction is only one part of R15, not the + whole cleanup stage. R15 also owns deletion of stale Fabric execution routes, + stale single/mixed wrappers, benchmark-side backend/planner logic, direct + message/cell math hidden in runtime/core instead of declared Fabric + primitives, old fallback paths, stale audit harness organization, and manual + review gates that prevent cell-specific or lattice-specific facts from + leaking back into backend code. R15 remains high priority and must close + before final fixmass completion; it is sequenced after the current R4 backend + owner only because throughput-critical shared temporal ownership is still + open. +- Full config redesign scope: + every field in `src/cortical/fabric/config.py` must get a real owner. The + target is not to rename `Config`; it is to remove this global mixed-purpose + object from generic Fabric machinery. + - `CellPopulationConfig.cell_type`, `num_heads`, and `activation` belong to + cell/primitive declarations and lowered op rows. They must not be global + Fabric config policy or backend route keys. + - `width`, `height`, `depth`, `local_radius`, `patch_edges_per_cell`, + `patch_min_dist`, `patch_max_dist`, `wrap`, `graph_edges`, + `projection_region_shape`, `input_band_width`, `output_band_width`, + `cell_arrangement`, `population_mix`, and `population_node_indices` belong + to graph/anatomy constructors. Lattice owns the + lattice subset. 
Explicit graph edges belong to an explicit graph builder. + Backend gets only normalized flat graph tables. + - `hidden_size`, `d_public`, `d_msg`, and `d_slot` belong to interface, + message-rule, cell-state, slot-feature, and tensor-table declarations. They + are dimensions of declared tensors, not backend scheduling identity. + - config-level `num_heads` and `head_dim` should disappear as generic Fabric + policy. Message head width and any transition primitive head/tile width must + come from user-declared message/cell primitives and tensor shapes. + - `distance_logit_scale`, `conduction_speed`, and `max_delay` belong to graph + edge/message-rule declarations, lowering to edge facts and message op rows. + - `kv_group_ids` belongs to grouping declarations/lowering, not lattice + projection-region defaults in generic backend code. + - `readout_pool` and `readout_slots` belong to readout declarations/lowering. + - `backend`, `gradient_horizon_steps`, `checkpoint_steps`, `k_max`, + `default_k`, and `inject_every_step` belong to high-level execution request + or planner policy inputs; planner records the final decision. They must not + be hardcoded public scheduling knobs inside benchmarks or backend kernels. + - `population_init_noise_std` and `seed` belong to initialization utilities, + not runtime/backend identity. + No legacy wrapper is an acceptable target. The implementation must move + ownership to graph declarations, message declarations, cell declarations, + readout declarations, initialization declarations, and planner request + fields, update callsites to those owners, and delete the old config truth + path. The backend surface must consume graph/tensor/op tables only. +- High-priority anatomy cleanup tasks: + split lattice-specific coordinate, band-port, offset-neighborhood, + projection-region KV grouping, and slot-feature construction out of + `src/cortical/fabric/anatomy.py` into graph/anatomy constructor modules. + Keep the backend-facing anatomy/spec surface flat and graph-generic. Do not + close this cleanup while backend code depends on lattice coordinate math to + rebuild sender tables. +- Current priority: + this cleanup is important but does not supersede the open backend owner. R4 + remains the immediate backend target: reduce/replace Python-owned temporal + scan/replay/reverse work with the shared flat-bucket temporal engine, then + return to config/anatomy cleanup before final legacy path deletion. + +### 2026-04-28 UTC - R15 cleanup scope reaffirmed from recovered goals + +Status: PENDING; R15 BREADTH LOCKED. + +Owner: R15 cleanup board and manual review memory. + +- User correction: + do not let R15 collapse into only the recent `src/cortical/fabric/config.py` + and `src/cortical/fabric/anatomy.py` cleanup. Those files are severe cleanup + targets, but R15 is the full legacy Fabric surface and deletion owner. 
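+
+  As a reference point for the config/anatomy cleanup endpoint described
+  above, the sketch below shows the kind of flat graph-table surface backend
+  code would consume; how the tables were built stays in the graph
+  constructor. All field and function names here are illustrative, not the
+  real spec surface.
+
+  ```python
+  from dataclasses import dataclass
+  import torch
+
+  @dataclass
+  class FlatGraphTablesSketch:
+      # Backend-facing facts only; no width/height/depth, coordinates, bands,
+      # or population names.
+      node_ids: torch.Tensor         # (N,) flat node identity
+      input_nodes: torch.Tensor      # (N_in,) indices into node_ids
+      output_nodes: torch.Tensor     # (N_out,)
+      recurrent_nodes: torch.Tensor  # (N_rec,)
+      edge_senders: torch.Tensor     # (E,) sender node per edge
+      edge_receivers: torch.Tensor   # (E,) receiver node per edge
+      edge_delay: torch.Tensor       # (E,) integer delay per edge
+      kv_group_ids: torch.Tensor     # (E,) grouping for shared K/V
+      degree_bucket: torch.Tensor    # (N,) execution bucket assignment
+
+  def from_toy_chain(n: int) -> FlatGraphTablesSketch:
+      # A constructor owns how the tables are built; backend code never sees
+      # the construction recipe, only the resulting flat tables.
+      ids = torch.arange(n)
+      return FlatGraphTablesSketch(
+          node_ids=ids,
+          input_nodes=ids[:1],
+          output_nodes=ids[-1:],
+          recurrent_nodes=ids,
+          edge_senders=ids[:-1],
+          edge_receivers=ids[1:],
+          edge_delay=torch.zeros(n - 1, dtype=torch.long),
+          kv_group_ids=torch.zeros(n - 1, dtype=torch.long),
+          degree_bucket=torch.zeros(n, dtype=torch.long),
+      )
+  ```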
+- Recovered cleanup inventory: + `ai_docs/additonal_goals.md` names a broader set of issues that must remain + on the R15 board: public `Blueprint.message_passing` being effectively + dot-product only; message declaration fields not really lowering into IR; + backend IR recreating default message passing; CUDA and PyTorch message paths + hardcoding dot-product/projected-message semantics; message topology and + message math being coupled; Blueprint normalizing through old `Config`; old + config concepts such as default K, readout, placement, and population mix + remaining central; runtime/public tests still using old construction paths; + graph API support being lattice-specific rather than graph-protocol generic; + global interface dimensions being forced through old hidden-size assumptions; + named adapter support being narrower than the public type; population + cardinality and placement still constrained by config/anatomy; bucket identity + including user population names; planner policy relying on hidden-size and + magic thresholds; and cell-family names/surfaces still influencing backend + execution and memory policy. +- R15 non-negotiable: + each item above must be closed by moving ownership to the correct layer and + deleting the old path, not by adding compatibility wrappers. Message math and + cell math must be user-declared Fabric primitives lowered into tensor/op + tables; backend/planner code consumes graph facts, primitive rows, reset, + checkpoint, materialization, and workspace policy. +- Current sequencing: + R4 remains the immediate high-priority owner because the shared temporal + engine is still not fully CUDA-owned. R15 remains high priority and must not + be re-scoped or forgotten while R4 proceeds; after the measured temporal + owners move and audits pass, R15 deletes the legacy surfaces instead of + leaving them as accepted debt. + +### 2026-04-28 UTC - R4 sender-owned reverse carry writes grad window retry + +Status: REJECTED; R4 remains OPEN. + +Owner: active shared temporal backward reverse table synchronization/copy +phase. + +- Boundary review: + this is a generic message/reverse-table change, not a cell-specific kernel. + ABI inputs stay as tensor-table/op-table rows. There is no cell-kind selector, + no population-name selector, no benchmark-row selector, no hidden-size policy + key, no separate single/mixed route, and no cell-family parameter bundle. The + primitive dimension names are message/transition tensor dimensions inferred + from table tensors. +- Invariant: + when the sender-owned recurrent K/V reverse phase is active, it computes the + per-sender hidden carry from generic message K/V gradients and recurrent K/V + projection weights. That carry is also the `grad_hidden_message_window` value + for the current reverse timestep. Writing it directly in the sender phase is + semantically identical to the following copy loop, while removing one + `B*R*H` pass and one cooperative-grid barrier on the active path. +- Guard: + keep the existing copy loop for the non-sender-owned path. Do not add any + family/population/benchmark condition. +- Evidence: + GPU 0 low-level reverse table parity passed with private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_carry_write_low` as + `1 passed in 47.25s`. GPU 1 high-level K>1 terminal-loss, provided-state, and + K=128 terminal-gradient guards passed as `3 passed in 227.49s`. GPU 2 reset + parity passed absent and present reset modes as `2 passed in 226.89s`. 
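+
+  As context for the gate check in the rejection recorded below, the corrected
+  K-adjusted gate is plain arithmetic over matched current-code T=1/K=1
+  training throughput; a minimal sketch using this retry row's reported
+  values:
+
+  ```python
+  def k_adjusted_gate(matched_t1_tok_s: float, k: int, row_tok_s: float):
+      # Floor is matched T=1/K=1 training tok/s divided by K; the reported
+      # ratio is the measured row throughput over that floor.
+      floor = matched_t1_tok_s / k
+      return floor, row_tok_s / floor
+
+  floor, ratio = k_adjusted_gate(matched_t1_tok_s=68.869, k=128, row_tok_s=8.762)
+  print(round(floor, 3), round(ratio, 3))  # 0.538 and roughly 16.285
+  ```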
+- Rejection: + the warmed high-level K=128/H=64 owner row wrote + `/tmp/redo_fixmaass_owner_profile_sender_carry_grad_window_retry_k128h64_warmed` + and reported `8.762 tok/s`, below the previous accepted active-reverse replay + guard row of `12.696 tok/s`. The row did pass the corrected K-adjusted T1 + floor (`matched T1=68.869 tok/s`, floor `0.538 tok/s`, ratio `16.285x`), but + this does not justify landing a throughput regression against the current + owner profile. The kernel edit was rolled back. Do not revive this exact + copy-elision probe without a stronger owner-profile reason and comparable + warmed timing. + +### 2026-04-28 UTC - R4 table-owned reverse replay window + +Status: ACCEPTED AS TABLE-OWNED REPLAY/REVERSE BRIDGE REDUCTION; R4 REMAINS +OPEN. + +Owner: temporal artifact replay / reverse-window table ownership. + +- User correction: + do the correct backend refactor, not the convenient local patch. The next + R4 change must reduce the Python replay/artifact bridge that keeps supported + K/H backward rows from being a fully owned shared temporal engine. +- Current owner: + CUDA replay already returns window-major temporal tensors for recurrent + messages, state-before/state-after, recurrent K/V, transition tapes, and + input projections. The live backward path still turns those tensors into a + list of `TemporalBucketStepArtifacts`, then `_try_run_transition_message_reverse_engine_window` + re-stacks many of the same tensors before calling the reverse table. That is + host-owned window assembly inside a supported active reverse path. +- Boundary review before code: + the replacement must be a flat-bucket reverse-window table consumed by the + existing tensor-table/op-row reverse executor. ABI inputs are graph tables, + message tables, reset windows, temporal step rows, primitive tensor windows, + static tensor-table slots, and flat bucket identity. It must not add a + cell-kind selector, population-name selector, benchmark-row selector, + hidden-size policy key, separate single/mixed route, or cell-family parameter + bundle. Primitive names are allowed only as lowered `fabric.cuda.nn` primitive + row roles, not as route identities. +- Implementation target: + attach a `TemporalReverseWindowTables` object to active CUDA replay artifacts + and have the reverse executor consume those window tensors directly. Active + reverse-only artifacts must fail closed if the table window is absent or + mismatched. Keep reset parity intact; reset-present rows can continue using + the existing fully materialized artifacts until the reset window table path is + proven. +- Implementation: + `TemporalReverseWindowTables` now carries the active replay window tensors + produced from the CUDA temporal scan: input K/V, recurrent K/V-before, + recurrent hidden-before, recurrent message, primitive state/input/logit + windows, reset windows, and temporal step rows. The active reverse executor + consumes those tensors directly instead of re-stacking the same values from + `TemporalBucketStepArtifacts`. Active reverse-only artifacts fail closed if + the shared table is absent or mismatched. After an initial table-only probe + measured only `9.455 tok/s`, the duplicate per-step transition tape and + backend-state-cache materialization was removed for active reverse-only + replay windows; those values are owned by the table path. +- Boundary review after code: + the ABI is still flat bucket graph/message/tensor/op-table data. 
There is no + cell-kind selector, population-name selector, benchmark-row selector, + hidden-size policy key, separate single/mixed route, or cell-family parameter + bundle. Primitive labels are lowered `fabric.cuda.nn` primitive row roles. +- Validation: + `py_compile`, `ruff format --check`, `ruff check`, and + `tests/test_fabric_backend_boundaries.py` passed. GPU 0 low-level reverse + table parity passed as `1 passed in 3.29s` using private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_window_steps`. GPU 1 + high-level K>1 terminal loss, provided-state, and K=128 guards passed as + `3 passed in 5.61s`. GPU 2 reset parity passed absent and present reset + modes as `2 passed in 4.87s`. +- K=128/H=64 profile: + GPU 0 high-level audit wrote + `/tmp/redo_fixmaass_owner_profile_reverse_window_table_k128h64_materialization_trim` + with `CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1`. It reported + `13.387 tok/s`, `74.702 ms`, peak `0.107805 GiB`, above the previous + accepted active-reverse replay row of `12.696 tok/s`. The corrected K gate + passed with matched T=1 training `65.376 tok/s`, divisor `128`, floor + `0.511 tok/s`, and ratio `26.210x`. +- Owner timings: + `transition_message_reverse_table_device_loop:ms=26.247;count=2` and + `temporal_artifact_recompute:ms=20.823;count=2` remain the top owners. The + new metadata markers include + `temporal_backward_glue:cuda_recompute_reverse_window_table` and + `temporal_backward_glue:cuda_reverse_engine_uses_replay_window_table`. +- T=1 regression probes: + GPU 0 single-pop B=1024/T=1/K=1 training wrote + `/tmp/redo_fixmaass_reverse_window_table_t1_b1024_single` and reported + `55383.513 tok/s`, peak `0.820897 GiB`, against April 21 reference + `5470.79 tok/s`. GPU 1 mixed-pop B=1024/T=1/K=1 training wrote + `/tmp/redo_fixmaass_reverse_window_table_t1_b1024_mixed` and reported + `62615.817 tok/s`, peak `0.691072 GiB`; the mixed stack comparison passed + with Fabric/stack ratio `3.562x`. +- Remaining: + R4 remains open. The table bridge now avoids duplicate per-step + tape/cache materialization and improves throughput, but the supported + training row still reports transitional `temporal_plan_forward_owners` and + `temporal_plan_backward_owners` as `python_autograd_scan`, and the top + measured owners are still the reverse table device loop and artifact + recompute. Next R4 work must move more of that reverse/replay ownership into + the CUDA temporal superop rather than adding route or benchmark logic. + +### 2026-04-28 UTC - R4 reverse message arithmetic cleanup probe + +Status: ACCEPTED AS SMALL GENERIC CUDA ARITHMETIC CLEANUP; R4 THROUGHPUT +OWNER REMAINS OPEN. + +Owner: active shared temporal backward reverse table device-loop internals. + +- Boundary review: + this probe touches only the existing table-owned reverse CUDA primitive. ABI + inputs remain tensor-table pointers, graph/message tables, op rows, reset + policy flags, temporal step rows, and flat bucket identity. It does not add a + cell-kind selector, population-name selector, benchmark-row selector, + hidden-size policy key, separate single/mixed route, or cell-family parameter + bundle. +- Implementation intent: + remove repeated `sqrtf(message_head_dim)` work from the receiver/sender + message reverse rows by using one reciprocal square-root scalar, and keep the + receiver-row sender scratch as integer identity rather than float scratch. 
+ This is not R4 closure; it is a measured cleanup inside the current top + reverse-table owner before attempting a larger replay/reverse fusion. +- Acceptance rule: + accept only if low-level reverse parity, K>1 high-level guards, reset parity, + and the K=128/H=64 profile are at least as good as the current accepted + `13.387 tok/s` row. Otherwise revert and log as rejected. +- Implementation: + `transition_message_reverse_table_device_loop_kernel` now computes a single + reciprocal square-root for the message head dimension and reuses it in the + receiver and sender reverse rows. Receiver-row sender identity scratch is now + integer scratch rather than float scratch. +- Validation: + `py_compile` passed for the touched Python bridge files and + `tests/test_fabric_backend_boundaries.py` passed as `2 passed in 0.01s`. + GPU 0 low-level mixed reverse table parity passed as `1 passed in 45.97s` + with private cache + `/tmp/cortical_torch_ext_${USER}_redo_fixmass_reverse_msg_arith_low`. GPU 1 + K>1 high-level terminal/provided-state/K=128 guards passed as + `3 passed in 231.82s`. GPU 2 reset parity passed absent/present reset rows + as `2 passed in 229.56s`. +- K=128/H=64 profile: + GPU 0 audit wrote `/tmp/redo_fixmaass_owner_profile_reverse_msg_arith` and + reported `13.477 tok/s`, `74.198 ms`, peak `0.107805 GiB`, matched T=1 + training `65.147 tok/s`, K-adjusted floor `0.509 tok/s`, and ratio + `26.480x`. The reverse table device-loop owner improved from + `26.247 ms` to `25.783 ms`; temporal artifact recompute remained + first-order at `20.871 ms`. +- T=1 probes: + GPU 0 single-pop B=1024/T=1/K=1 wrote + `/tmp/redo_fixmaass_reverse_msg_arith_t1_b1024_single` and reported + `56541.455 tok/s`, peak `0.820897 GiB`, above the April 21 reference + `5470.79 tok/s`. A concurrent mixed-pop run was discarded as measurement + contention after reporting `35357.802 tok/s`; rerunning mixed-pop alone on + GPU 1 wrote `/tmp/redo_fixmaass_reverse_msg_arith_t1_b1024_mixed_rerun` and + reported `66898.775 tok/s`, peak `0.691072 GiB`, with + `launch_scan_implementations=['cuda_temporal_superop']`. +- Remaining: + R4 is still open. This cleanup reduced arithmetic overhead inside the + table-owned reverse message primitive, but it does not remove the remaining + replay/reverse split. The next high-priority R4 owner remains moving more of + artifact recompute plus reverse recurrence into one CUDA temporal superop + over flat tensor/op tables. + +### 2026-04-28 UTC - R4 artifact recompute owner split + +Status: ACCEPTED GUARDRAIL; R4 REPLAY/REVERSE FUSION OWNER. + +Owner: temporal artifact replay measurement and next-kernel targeting. + +- Boundary review: + this is measurement inside the active shared temporal path only. It does not + change planner policy, benchmark behavior, tensor-table ABI, or route + selection. It records subowners for the existing CUDA replay scan and + recurrent hidden-before materialization so the next R4 implementation targets + the actual measured replay/reverse cost instead of guessing. +- Implementation intent: + split the opaque `temporal_artifact_recompute` timing into + `artifact.recompute.cuda_temporal_replay_scan` and + `artifact.recompute.recurrent_hidden_before_window` for active CUDA + reverse-only windows. The remainder of `temporal_artifact_recompute` after + those timed subowners is the Python/table bridge and per-step artifact shell. + The next kernel/refactor step should be chosen from those measured subowners. 
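+
+  A simplified host-side sketch of the subowner attribution described above;
+  the active path uses its own owner-timing plumbing and device events for
+  CUDA work, so treat the wall-clock timer and structure here as illustrative
+  only.
+
+  ```python
+  import time
+  from collections import defaultdict
+  from contextlib import contextmanager
+
+  owner_ms = defaultdict(lambda: [0.0, 0])  # label -> [total_ms, count]
+
+  @contextmanager
+  def timed_owner(label: str):
+      start = time.perf_counter()
+      try:
+          yield
+      finally:
+          row = owner_ms[label]
+          row[0] += (time.perf_counter() - start) * 1000.0
+          row[1] += 1
+
+  with timed_owner("temporal_artifact_recompute"):
+      with timed_owner("artifact.recompute.cuda_temporal_replay_scan"):
+          pass  # CUDA replay scan would run here
+      with timed_owner("artifact.recompute.recurrent_hidden_before_window"):
+          pass  # hidden-before materialization would run here
+      # Whatever is left of the parent owner after the timed subowners is the
+      # Python/table bridge and per-step artifact shell.
+  ```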
+- Validation: + `py_compile` passed for `temporal_backward.py`; `ruff format --check`, + `ruff check`, and `tests/test_fabric_backend_boundaries.py` passed. +- K=128/H=64 attribution profile: + GPU 0 audit wrote `/tmp/redo_fixmaass_owner_profile_artifact_split` and + reported `13.454 tok/s`, `74.327 ms`, peak `0.107805 GiB`, matched T=1 + training `64.463 tok/s`, K-adjusted floor `0.504 tok/s`, and ratio + `26.715x`. Owner timings show + `transition_message_reverse_table_device_loop:ms=25.783;count=2`, + `temporal_artifact_recompute:ms=20.972;count=2`, + `artifact.recompute.cuda_temporal_replay_scan:ms=20.069;count=2`, and + `artifact.recompute.recurrent_hidden_before_window:ms=0.015;count=2`. +- Remaining: + the artifact cost is the CUDA replay scan itself, not hidden-before + materialization or table shell assembly. The next real R4 backend slice should + fuse or co-own replay scan plus reverse recurrence in the CUDA temporal + superop, while preserving the same flat tensor/op-table ABI. + +### 2026-04-28 UTC - R4 reverse replay output materialization trim + +Status: ACCEPTED; R4 REPLAY/REVERSE FUSION STILL OPEN. + +Owner: active reverse-only artifact replay inside the shared temporal scan. + +- Boundary review: + this change is not a cell-family route and does not add a single/mixed-pop + path. It adds a generic scan-table materialization flag for output readout + tensors. The flag controls whether a CUDA temporal replay materializes output + readout/checkpoint tensors; recurrence, message, transition tape, reset, + checkpoint, and reverse tables remain unchanged. The active reverse-only + backward path already consumes direct output adjoints and reverse replay + tables, not replayed output readouts, so output readout materialization is + backend-owned dead work for that owner. +- Implementation intent: + keep default scan behavior unchanged, but pass + `materialize_output_readout=False` when recomputing active reverse-only + artifacts for the transition-message reverse table path. The CUDA scan should + skip checkpoint output readout writes and final output-sequence readout work + for that replay window while still returning the recurrent/message/transition + artifacts consumed by reverse. +- Validation: + `py_compile` passed for the touched Python files; `ruff format --check`, + `ruff check`, `git diff --check`, and `tests/test_fabric_backend_boundaries.py` + passed. GPU 2 low-level reverse table parity passed via + `test_fabric_cuda_transition_message_reverse_table_window_matches_mixed_step_loop`. + GPU 0 T=1 mixed-pop shared reverse parity passed for reset absent and reset + present. GPU 1 K=128/H=64 mixed-pop backward parity smoke passed. +- Current-code T=1 probes: + GPU 1 mixed-pop B=1024 wrote + `/tmp/redo_fixmaass_materialize_t1_b1024_mixed` and reported + `67543.529 tok/s`, peak `0.691072 GiB`, mixed Fabric/stack ratio `3.432x`. + GPU 2 single-pop B=1024 wrote + `/tmp/redo_fixmaass_materialize_t1_b1024_single` and reported + `56430.735 tok/s`, peak `0.820897 GiB`. Both remain above the April 21 + `h8_many_cell_stress_broad` reference of `5470.79 tok/s`. +- K=128/H=64 owner evidence: + first profile wrote `/tmp/redo_fixmaass_owner_profile_materialize_output_trim` + and reported `13.429 tok/s`, `74.464 ms`, peak `0.107557 GiB`, with + `artifact.recompute.cuda_temporal_replay_scan:ms=19.040;count=2`. + Because tok/s was slightly below the previous profile, an isolated rerun was + required. 
The accepted rerun wrote + `/tmp/redo_fixmaass_owner_profile_materialize_output_trim_rerun` and reported + `13.676 tok/s`, `73.121 ms`, peak `0.107557 GiB`, matched T=1 training + `68.271 tok/s`, K-adjusted floor `0.533 tok/s`, and ratio `25.641x`. + Owner timings were + `transition_message_reverse_table_device_loop:ms=25.848;count=2`, + `temporal_artifact_recompute:ms=19.867;count=2`, + `artifact.recompute.cuda_temporal_replay_scan:ms=18.931;count=2`, + `artifact.recompute.cuda_replay_input_projection:ms=0.284;count=2`, and + `artifact.recompute.recurrent_hidden_before_window:ms=0.015;count=2`. +- Accepted change: + the materialization flag moved the targeted CUDA replay scan owner from the + prior accepted `20.069 ms` profile to `18.931 ms` without changing public + benchmark/API behavior or backend route identity. The largest remaining R4 + owner is still the table-owned reverse device loop plus replay/reverse split; + full R4 closure still requires deeper CUDA temporal superop ownership, not + this materialization trim alone. + +### 2026-04-29 UTC - R15 cleanup breadth re-locked + +Status: PENDING; R15 FULL CLEANUP OWNER PRESERVED. + +Owner: R15 cleanup board and final legacy deletion audit. + +- User correction: + R15 is not just the recently discussed `src/cortical/fabric/config.py` and + `src/cortical/fabric/anatomy.py` cleanup. Those are severe symptoms, but the + owner is the full Fabric cleanup/deletion stage recorded from recovered + additional goals and later corrections. +- R15 must continue to include: + stale execution-route deletion, old single/mixed sibling paths, benchmark-side + planner/backend logic, hidden message/cell math outside declared + `fabric.cuda.nn` primitives, legacy config truth paths, lattice-specific facts + leaking into generic Fabric surfaces, Blueprint facade removal, message-rule + genericity, graph protocol ownership, old public-test cleanup, + population/cardinality cleanup, hidden-size and magic-threshold planner + cleanup, cell-family backend-surface cleanup, manual/static guardrails, and + final grep/code-review audits that prove those leaks were deleted. +- Priority rule: + R15 remains high priority and cannot be narrowed for convenience. It is not + the active owner only because R4 shared temporal backend ownership is still + open and throughput-critical. When R4/R11-R14 audit gates are ready, R15 must + close the full cleanup board, not only the latest config/anatomy examples. + +### 2026-04-29 UTC - R4 reverse recurrent-message materialization owner + +Status: REJECTED; CODE ROLLED BACK. + +Owner: table-owned reverse device loop materialization policy. + +- Current owner: + after the output-readout replay trim, the warmed K=128/H=64 profile still + reports `transition_message_reverse_table_device_loop` as the largest timed + owner at roughly `25.8 ms` for two backward windows. The next change must + reduce actual reverse-table work, not only rename metadata. +- Boundary review before edit: + the proposed change is a generic materialization policy for an intermediate + recurrent-message gradient table. The reverse kernel already computes that + table so the message backward primitive can consume it inside the same + temporal loop. The active high-level training path discards the returned full + `T x B x R x value_dim` tensor, so materializing that full surface is dead + work for active reverse-only windows. The low-level/default ABI should keep + emitting the full tensor for parity tests and debugging. 
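+
+  Both the accepted `materialize_output_readout` trim above and the flag this
+  probe targets are the same kind of generic scan materialization policy; a
+  minimal sketch of that shape follows. The request object, function, and
+  tensor shapes are hypothetical; only the two flag names come from this
+  document.
+
+  ```python
+  from dataclasses import dataclass
+  import torch
+
+  @dataclass
+  class ReplayRequestSketch:
+      window_steps: int
+      materialize_output_readout: bool = True
+      materialize_grad_recurrent_msg_window: bool = True
+
+  def replay_outputs(req: ReplayRequestSketch, B: int, R: int, value_dim: int):
+      out = {}
+      full = (req.window_steps, B, R, value_dim)
+      if req.materialize_output_readout:
+          out["output_readout_window"] = torch.zeros(full)
+      if req.materialize_grad_recurrent_msg_window:
+          out["grad_recurrent_msg_window"] = torch.zeros(full)
+      else:
+          # One-step work surface; the full window is never allocated.
+          out["grad_recurrent_msg_step"] = torch.zeros(B, R, value_dim)
+      return out
+  ```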
+- Design constraints: + do not add a cell-family selector, population selector, benchmark-row branch, + hidden-size policy key, or separate single/mixed route. The flag must be + expressed as tensor-table materialization policy, like the previous output + readout materialization trim. The kernel still consumes flat bucket identity, + message tables, reset policy, tensor slots, and primitive op rows. +- Intended implementation: + add a `materialize_grad_recurrent_msg_window` flag to the transition/message + reverse-table binding. When true, preserve the existing full-window return. + When false, use a one-step `[B, R, value_dim]` work surface inside the device + loop and return an empty placeholder for the unused output slot. +- Validation before rejection: + static checks passed; low-level transition/message reverse-table parity + passed on GPU 2; high-level T=1 mixed-pop reset absent/present parity passed + on GPU 0; K=128/H=64 backward smoke passed on GPU 1. +- Rejection evidence: + the first K=128/H=64 owner profile on GPU 3 wrote + `/tmp/redo_fixmaass_owner_profile_grad_msg_materialize` and reported + `13.446 tok/s`, `transition_message_reverse_table_device_loop:ms=25.940` + and `artifact.recompute.cuda_temporal_replay_scan:ms=19.120`. A rerun in the + already-compiled private cache wrote + `/tmp/redo_fixmaass_owner_profile_grad_msg_materialize_rerun` and reported + `13.545 tok/s`, `transition_message_reverse_table_device_loop:ms=25.917`, + `artifact.recompute.cuda_temporal_replay_scan:ms=19.005`, peak memory + unchanged at `0.107557 GiB`. +- Decision: + reject and roll back this code path. It was semantically generic and + parity-clean, but it did not move the measured owner relative to the prior + accepted profile (`13.676 tok/s`, reverse-table `25.848 ms`, replay scan + `18.931 ms`). Do not repeat this one-step recurrent-message materialization + flag unless a future profile shows full recurrent-message output allocation + is a real owner on a different row. + +### 2026-04-29 UTC - R4 reverse table single-batch atomic owner + +Status: REJECTED; CODE ROLLED BACK. + +Owner: transition/message reverse-table parameter-gradient atomics. + +- Current owner: + after rejecting the recurrent-message materialization probe, the active + K=128/H=64 owner remains `transition_message_reverse_table_device_loop`. + The accepted baseline is still the materialization-trim rerun: + `13.676 tok/s`, reverse-table `25.848 ms`, replay scan `18.931 ms`. +- Boundary review before edit: + this is a runtime-shape optimization inside the generic table-owned reverse + kernel. It does not add a cell-family selector, population selector, + hidden-size policy key, benchmark-row selector, or separate single/mixed + route. The optimization applies only when `B == 1`, where each transition row + and each sender row has exactly one batch contributor to its own parameter + gradient slot. Multi-batch execution keeps atomics. +- Intended implementation: + replace unnecessary parameter-gradient `atomicAdd` calls with direct + additions for the `B == 1` case in the transition primitive parameter writes, + recurrent-query parameter write, and sender-reverse recurrent K/V weight + write. Keep input K/V and receiver-shared sender paths atomic because those + can still have multiple receiver contributors even at `B == 1`. 
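+
+  A small illustration of the batch-uniqueness argument above: a per-row
+  parameter slot reduces only over the batch dimension, so at `B == 1` each
+  atomic add has exactly one contributor, while receiver-shared slots still
+  fan in across receivers and keep atomics regardless of batch size. Shapes
+  are illustrative.
+
+  ```python
+  import torch
+
+  def rowwise_param_grad(contrib: torch.Tensor) -> torch.Tensor:
+      # contrib: (B, rows, P); each (row, p) slot accumulates over batch only.
+      return contrib.sum(dim=0)
+
+  contrib = torch.randn(1, 6, 4)  # B == 1: one contributor per slot
+  assert torch.equal(rowwise_param_grad(contrib), contrib[0])
+  ```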
+- Validation before rejection: + static guardrails and `tests/test_fabric_backend_boundaries.py` passed; + low-level transition/message reverse-table parity passed on GPU 2; high-level + T=1 mixed-pop reset absent/present parity passed on GPU 0; K=128/H=64 + backward smoke passed on GPU 1. +- Rejection evidence: + the first K=128/H=64 owner profile on GPU 3 wrote + `/tmp/redo_fixmaass_owner_profile_single_batch_atomic` and reported + `12.834 tok/s`, `transition_message_reverse_table_device_loop:ms=28.993`, + `artifact.recompute.cuda_temporal_replay_scan:ms=19.015`, peak memory + `0.107557 GiB`. A rerun in the already-compiled private cache wrote + `/tmp/redo_fixmaass_owner_profile_single_batch_atomic_rerun` and reported + `12.995 tok/s`, `transition_message_reverse_table_device_loop:ms=28.978`, + and `artifact.recompute.cuda_temporal_replay_scan:ms=18.993`. +- Decision: + reject and roll back. Although the change was generic and parity-clean, it + made the reverse-table kernel slower than the accepted baseline. Do not + reintroduce direct single-batch global additions for these parameter-gradient + writes without lower-level profiling evidence that the compiler/hardware path + has changed. + +### 2026-04-29 UTC - R4 gated reverse preactivation reuse + +Status: REJECTED; CODE ROLLED BACK. + +Owner: transition/message reverse-table gated primitive arithmetic. + +- Current owner: + the active K=128/H=64 row is still dominated by + `transition_message_reverse_table_device_loop`; two smaller local probes were + rejected and rolled back because they did not move throughput. +- Boundary review before edit: + this change stays inside the existing generic reverse table kernel and the + lowered gated-logspace primitive row. It does not add a cell-family selector, + population selector, benchmark-row selector, hidden-size policy key, route + split, or planner behavior. The ABI remains the same tensor-table/op-table + surface. +- Intended implementation: + the gated reverse row currently reconstructs gate preactivations once to + produce normalized public output and then reloads/readds the same gate logits + during the gradient pass. Reuse the existing per-row scratch arrays to hold + those preactivation values across the two passes, then overwrite the same + arrays with gate-gradient values before the recurrent-affine projection pass. +- Validation before rejection: + static guardrails and `tests/test_fabric_backend_boundaries.py` passed; + low-level transition/message reverse-table parity passed on GPU 2; high-level + T=1 mixed-pop reset absent/present parity passed on GPU 0; K=128/H=64 + backward smoke passed on GPU 1. +- Rejection evidence: + the required K=128/H=64 audit profile failed to produce a case result twice. + GPU 4 wrote only + `/tmp/redo_fixmaass_owner_profile_gated_preact/manifest.json` before the + process exited with code `-1`; GPU 3 retry wrote only + `/tmp/redo_fixmaass_owner_profile_gated_preact_retry/manifest.json` before + the same exit. Since no current-code owner timing was produced, this cannot + be accepted as a throughput improvement. +- Decision: + reject and roll back. Do not reintroduce this scratch reuse unless it is first + isolated with lower-level profiling that explains the profile-run process + exits and proves the K=128/H=64 row can complete. + +### 2026-04-29 UTC - R4 reverse replay/message artifact ownership plan + +Status: REJECTED; CODE ROLLED BACK. + +Owner: temporal backward / shared reverse table superop. 
+ +- Current owner: + accepted warmed K=128/H=64 evidence still has + `transition_message_reverse_table_device_loop` and + `artifact.recompute.cuda_temporal_replay_scan` as the active hot owners. The + last three local reverse-kernel probes were rejected because they did not + improve that row. +- Boundary review before edit: + active reverse-only replay still asks the temporal scan to copy + recurrent-message artifacts so Python can later build transition + input-projection parameter-gradient steps. The edit must keep the reverse + ABI generic over lowered primitive/tensor roles and flat bucket identity. + It must not add cell-family, population-name, benchmark-row, hidden-size, or + single/mixed route selectors. +- Intended implementation: + make the reverse table emit transition input-projection parameter gradients + directly for the existing primitive projection tensor roles: + gated input projection weight/bias and diagonal input projection weight/bias. + Then active reverse-only replay no longer needs recurrent-message artifact + materialization from the scan. +- Acceptance gate: + low-level reverse-table parity, T=1 reset parity, K=128/H=64 smoke, and a + warmed owner profile must pass. Keep the change only if owner timing or + throughput moves without violating the generic Fabric backend boundary. + +Validation and decision: + +- Implemented probe: + the reverse table emitted gated and diagonal input-projection parameter + gradients directly, and active reverse-only replay stopped requesting + recurrent-message artifacts. The first in-kernel fusion attempt crashed on + the sender-reverse cooperative path, so the projection-gradient work was + split into a separate table-owned CUDA kernel after the cooperative reverse + loop. +- Parity before profiling: + static checks passed; `tests/test_fabric_backend_boundaries.py` passed; + low-level transition/message reverse-table parity passed with + sender-reverse coverage on GPU 2; high-level T=1 mixed-pop reset + absent/present parity passed on GPU 0; K=128/H=64 backward smoke passed on + GPU 1. +- Rejection evidence: + the first K=128/H=64 profile wrote + `/tmp/redo_fixmaass_owner_profile_proj_grad_split` and reported + `13.509 tok/s`, peak `0.104627 GiB`. The warm rerun in the same private + cache wrote `/tmp/redo_fixmaass_owner_profile_proj_grad_split_rerun` and + reported `13.581 tok/s`, `transition_message_reverse_table_device_loop:ms=26.021`, + `artifact.recompute.cuda_temporal_replay_scan:ms=18.973`, and peak + `0.104627 GiB`. +- Decision: + reject and roll back the code. The change was generic and parity-clean after + the split, but it remained slower than the accepted baseline + (`13.676 tok/s`, reverse-table `25.848 ms`, replay scan `18.931 ms`). Do not + repeat this projection-param split unless a future owner profile shows + recurrent-message artifact materialization has become a measured bottleneck + large enough to pay for the extra CUDA pass. + +### 2026-04-29 UTC - R4 fused replay/reverse structural owner + +Status: ACCEPTED AS REPLAY MATERIALIZATION TRIM; R4 REPLAY/REVERSE FUSION +REMAINS OPEN. + +Owner: temporal backward / shared flat-bucket reverse superop. + +- Current owner: + the accepted active row remains the output-readout replay trim: + `13.676 tok/s`, `transition_message_reverse_table_device_loop:ms=25.848`, + `artifact.recompute.cuda_temporal_replay_scan:ms=18.931`. The last four + local reverse/materialization probes were rejected. 
The next useful backend + work must therefore target the replay/reverse split itself. +- Boundary invariant: + active reverse must be table-owned over flat bucket identity, lowered + primitive tensor/op roles, message primitive rows, reset policy, checkpoint + policy, and materialization policy. It must not introduce cell-family names, + population-name route keys, benchmark-row selectors, hidden-size policy keys, + separate single/mixed route identities, or bespoke cell bundles in the shared + temporal engine. Primitive dimensions such as head width are tensor/op-table + metadata only. +- Structural target: + active reverse-only windows should stop requiring Python to assemble a replay + artifact object graph whose main purpose is to copy forward intermediates from + a CUDA replay scan into a second CUDA reverse launch. The reverse owner should + accept checkpoint/initial state plus scan inputs and run the necessary window + recompute and reverse work through one table-owned temporal path, or expose a + guardrail that refuses to claim full CUDA temporal ownership while that split + remains physical. +- First implementation slice: + active reverse-only replay does not consume diagonal eligibility-trace state + tensors from the CUDA replay scan. Those traces are useful when final state or + later checkpoint state is user/materialization visible, but the active reverse + table recomputes parameter gradients directly and only needs the primitive + state-before windows (`hc1`, `hc2`) plus the transition/message tensors. + Introduce an explicit scan materialization policy for this primitive extra + state, keep the default behavior unchanged, and disable it only for + active-reverse replay windows. This targets + `artifact.recompute.cuda_temporal_replay_scan`, not benchmark metadata. +- Acceptance gate: + manual boundary review, static guardrails, low-level reverse-table parity, + T=1 mixed reset parity, K=128/H=64 smoke, and warmed K=128/H=64 owner profile. + Keep code only if the active owner physically moves or the guardrail prevents + a false R4 close without weakening the shared-engine target. + +Validation and decision: + +- Manual boundary review: + the scan binding now has a materialization policy for a primitive extra-state + table. The default remains materialized. The only opt-out is the + active-reverse replay path where the reverse table does not consume those + trace-state tensors. This adds no cell-kind selector, population-name route + key, benchmark-row selector, hidden-size policy key, separate single/mixed + route identity, or cell-family parameter bundle. The implementation still + flows through the existing temporal tensor/op table binding and flat bucket + identity. +- Static checks: + `git diff --check`, `ruff check` on the touched Python files, `py_compile` on + the touched Python files, and `tests/test_fabric_backend_boundaries.py` + passed. +- Parity/smoke: + GPU 3 default trace-materialized final-state temporal-superop test passed + (`4 passed`) with private caches + `/tmp/cortical_torch_ext_${USER}_redo_trace_default_v2` and + `/tmp/cortical_triton_${USER}_redo_trace_default_v2`. GPU 0 T=1 mixed + reset/no-reset training passed (`2 passed`) with private caches + `/tmp/cortical_torch_ext_${USER}_redo_trace_t1_reset_v2` and + `/tmp/cortical_triton_${USER}_redo_trace_t1_reset_v2`. 
GPU 1 K=128/H=64 + mixed backward smoke passed (`1 passed`) with private caches + `/tmp/cortical_torch_ext_${USER}_redo_trace_optout_k128_v2` and + `/tmp/cortical_triton_${USER}_redo_trace_optout_k128_v2`. GPU 2 low-level + transition/message reverse-table parity passed with private caches + `/tmp/cortical_torch_ext_${USER}_redo_trace_reverse_lowlevel` and + `/tmp/cortical_triton_${USER}_redo_trace_reverse_lowlevel`. +- K=128/H=64 owner evidence: + first profile wrote `/tmp/redo_fixmaass_owner_profile_trace_optout` and + reported `13.792 tok/s`, `artifact.recompute.cuda_temporal_replay_scan:ms=17.707`, + `transition_message_reverse_table_device_loop:ms=25.793`, and peak memory + `0.107511 GiB`. Warm rerun in the same private cache wrote + `/tmp/redo_fixmaass_owner_profile_trace_optout_rerun` and reported + `13.885 tok/s`, `artifact.recompute.cuda_temporal_replay_scan:ms=17.611`, + `temporal_artifact_recompute:ms=18.552`, + `transition_message_reverse_table_device_loop:ms=25.768`, and peak memory + `0.107511 GiB`. +- T=1 non-regression spot-check: + longer same-cache B=1024 probes wrote + `/tmp/redo_fixmaass_trace_t1_b1024_mixed_long` and + `/tmp/redo_fixmaass_trace_t1_b1024_single_long`. Mixed-pop reported + `68501.683 tok/s`, peak `0.691072 GiB`, and Fabric/stack `3.468x`, above the + prior accepted mixed artifact (`67543.529 tok/s`). Single-pop reported + `56387.724 tok/s`, peak `0.820897 GiB`, effectively flat with the prior + accepted single artifact (`56430.735 tok/s`) and far above the April 21 + `5470.79 tok/s` reference. +- Decision: + accept this replay materialization trim. It physically moved the measured + replay owner from the prior accepted `18.931 ms` to `17.611 ms` and improved + the active K=128/H=64 row from `13.676` to `13.885 tok/s` without T=1 memory + regression. R4 remains open because the active path still has separate replay + and reverse launches, and `transition_message_reverse_table_device_loop` + remains the largest measured owner. + +### 2026-04-29 UTC - R4 reverse gated local-memory owner + +Status: REJECTED; CODE ROLLED BACK. + +Owner: temporal backward / shared flat-bucket reverse superop internals. + +- Current owner: + the accepted active K=128/H=64 row is + `/tmp/redo_fixmaass_owner_profile_trace_optout_rerun` at `13.885 tok/s`, + with `transition_message_reverse_table_device_loop:ms=25.768` and + `artifact.recompute.cuda_temporal_replay_scan:ms=17.611`. The replay owner + moved, so the next high-priority owner is the reverse table device loop. +- Boundary invariant: + the edit may only change the existing lowered primitive row implementation + inside the tensor-table reverse superop. ABI inputs remain flat bucket + identity, tensor-table/op-row roles, graph/message tables, reset policy, and + primitive tensor slots. It must not introduce a cell-kind selector, + population-name route key, benchmark-row selector, hidden-size policy key, + separate single/mixed route identity, or cell-family parameter bundle. +- Implementation target: + the gated-logspace primitive backward row currently allocates fixed-size + per-thread local scratch arrays (`raw_i`, `raw_f`, `raw_z`, `raw_o`, + `y_out`, `grad_public_buf`) sized to the maximum supported hidden width. + Active audit rows include small hidden sizes where this is wasted local + memory pressure in the dominant reverse kernel. 
Replace those arrays with a + primitive-local helper that recomputes forward scalars as needed and reuses + the already-materialized `grad_gated_raw_window` table for recurrent-affine + parameter accumulation. Accept only if parity passes and the warmed owner row + moves or stays flat without T=1 regression. + +Validation and decision: + +- Implemented probe: + removed the fixed local scratch arrays from the gated primitive row, added a + primitive-local forward-scalar helper, and reused `grad_gated_raw_window` as + the table-owned source for recurrent-affine accumulation. The edit stayed + inside the existing reverse tensor-table ABI and added no cell-kind, + population-name, benchmark-row, hidden-size, or single/mixed route selector. +- Parity: + `git diff --check` and `tests/test_fabric_backend_boundaries.py` passed. + GPU 2 low-level transition/message reverse-table parity passed (`1 passed`) + with private caches + `/tmp/cortical_torch_ext_${USER}_redo_gated_localmem_low` and + `/tmp/cortical_triton_${USER}_redo_gated_localmem_low`. GPU 0 high-level + T=1 mixed reset/no-reset training passed (`2 passed`) with private caches + `/tmp/cortical_torch_ext_${USER}_redo_gated_localmem_t1_reset` and + `/tmp/cortical_triton_${USER}_redo_gated_localmem_t1_reset`. GPU 1 K=128/H=64 + backward smoke passed (`1 passed`) with private caches + `/tmp/cortical_torch_ext_${USER}_redo_gated_localmem_k128_smoke` and + `/tmp/cortical_triton_${USER}_redo_gated_localmem_k128_smoke`. +- Rejection evidence: + GPU 3 first active profile wrote + `/tmp/redo_fixmaass_owner_profile_gated_localmem` and reported + `12.837 tok/s`, `transition_message_reverse_table_device_loop:ms=30.906`, + `artifact.recompute.cuda_temporal_replay_scan:ms=17.670`, peak + `0.107511 GiB`. Same-cache warm confirmation wrote + `/tmp/redo_fixmaass_owner_profile_gated_localmem_rerun` and reported + `12.917 tok/s`, `transition_message_reverse_table_device_loop:ms=30.735`, + `artifact.recompute.cuda_temporal_replay_scan:ms=17.631`, peak + `0.107511 GiB`. +- Decision: + reject and roll back the kernel code. The slice was parity-clean and generic, + but the recomputation/global-table-read tradeoff regressed the accepted + active row (`13.885 tok/s`, reverse table `25.768 ms`). Do not repeat this + scratch-removal probe unless a later profile shows local-memory pressure has + changed enough to justify a different implementation. + +### 2026-04-29 UTC - R4 reverse table policy specialization owner + +Status: REJECTED; CODE ROLLED BACK. + +Owner: temporal backward / shared flat-bucket reverse superop internals. + +- Current owner: + after rejecting the scratch-removal probe, the accepted active baseline + remains `/tmp/redo_fixmaass_owner_profile_trace_optout_rerun` at + `13.885 tok/s`, with `transition_message_reverse_table_device_loop:ms=25.768` + and `artifact.recompute.cuda_temporal_replay_scan:ms=17.611`. +- Boundary invariant: + specializing the reverse kernel is allowed only for generic backend policy + bits already present in the tensor-table ABI: sender-reverse availability and + delay policy. It must not specialize on cell family, population names, + benchmark rows, hidden-size thresholds, single/mixed identity, or graph + constructor shape. +- Implementation target: + convert the reverse table cooperative kernel to compile-time variants for + `use_sender_reverse` and `use_delay`, selected by the existing runtime + policy flags before launch. 
This should remove sender-reverse and delay
+  branches from the inner message reverse loops while preserving the same
+  tensor/op-table ABI and reset behavior.
+
+Validation and decision:
+
+- Implemented probe:
+  templated the cooperative reverse table device loop on `use_sender_reverse`
+  and `use_delay`, and selected the variant from the existing runtime policy
+  flags before launch. Reset behavior remained runtime-driven. The edit added
+  no cell-family, population-name, benchmark-row, hidden-size, single/mixed, or
+  graph-constructor specialization.
+- Parity:
+  `git diff --check` and `tests/test_fabric_backend_boundaries.py` passed. GPU
+  2 low-level reverse-table parity passed (`1 passed`) with private caches
+  `/tmp/cortical_torch_ext_${USER}_redo_policy_specialize_low` and
+  `/tmp/cortical_triton_${USER}_redo_policy_specialize_low`. GPU 1 K=128/H=64
+  backward smoke passed (`1 passed`) with private caches
+  `/tmp/cortical_torch_ext_${USER}_redo_policy_specialize_k128_smoke` and
+  `/tmp/cortical_triton_${USER}_redo_policy_specialize_k128_smoke`. GPU 0 T=1
+  mixed reset/no-reset training passed (`2 passed`) with private caches
+  `/tmp/cortical_torch_ext_${USER}_redo_policy_specialize_t1_reset` and
+  `/tmp/cortical_triton_${USER}_redo_policy_specialize_t1_reset`.
+- Rejection evidence:
+  GPU 3 first active profile wrote
+  `/tmp/redo_fixmaass_owner_profile_policy_specialize` and reported
+  `13.688 tok/s`, `transition_message_reverse_table_device_loop:ms=25.768`,
+  `artifact.recompute.cuda_temporal_replay_scan:ms=17.656`, peak
+  `0.107511 GiB`. Same-cache warm confirmation wrote
+  `/tmp/redo_fixmaass_owner_profile_policy_specialize_rerun` and reported
+  `13.723 tok/s`, `transition_message_reverse_table_device_loop:ms=25.688`,
+  `artifact.recompute.cuda_temporal_replay_scan:ms=17.640`, peak
+  `0.107511 GiB`.
+- Decision:
+  reject and roll back the kernel code. The probe was generic and parity-clean,
+  but it did not materially move the reverse-table owner and stayed below the
+  accepted active throughput baseline (`13.885 tok/s`). Do not repeat this
+  policy-specialization-only edit unless paired with a larger reverse-kernel
+  structural change.
+
+### 2026-04-29 UTC - R4 reverse sender-accumulation policy probe
+
+Status: REJECTED; CODE ROLLED BACK.
+
+Owner: temporal backward / shared flat-bucket reverse message primitive.
+
+- Current owner:
+  the reverse table device loop remains the dominant measured owner after two
+  rejected kernel-only probes. The active shape is generic flat-bucket runtime
+  metadata (`B=1`, `R=384`, `h=8`, horizon window `64`), not a cell-family
+  route.
+- Boundary invariant:
+  sender-gradient accumulation policy is a backend message-primitive execution
+  decision. It may depend on generic runtime table facts and work estimates
+  such as batch, receiver count, hidden width, head/value width, offsets, delay
+  policy, and workspace, but not on population names, cell family, benchmark
+  rows, lattice shape, or single/mixed identity.
+- Probe target:
+  the active path currently forces the sender-reverse message accumulation
+  algorithm. That path avoids receiver-loop atomics but adds sender-reverse
+  work buffers and cooperative-grid synchronization phases each timestep.
+  Probe the direct receiver-side accumulation algorithm on the same reverse
+  table ABI to see whether this choice should become planner policy driven by
+  the generic work regime instead of a hardcoded runtime decision; a rough
+  shape of that policy framing is sketched below.
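+
+A minimal sketch of the policy framing above, written as planner-level backend
+policy rather than temporal-scheduler code. The names (`ReverseWorkEstimate`,
+`choose_sender_accumulation`) and the threshold are illustrative assumptions,
+not repository code; any real decision would be owned by a registered strategy
+and validated by warmed owner profiles, not by this heuristic.
+
+```python
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class ReverseWorkEstimate:
+    """Generic runtime work facts only: no population, family, or benchmark keys."""
+
+    batch: int
+    receiver_count: int
+    recurrent_sender_count: int
+    hidden_width: int
+    window_steps: int
+    delay_enabled: bool
+
+
+def choose_sender_accumulation(work: ReverseWorkEstimate) -> str:
+    """Pick 'sender_reverse' or 'receiver_atomic' for the reverse table call.
+
+    Sender-reverse avoids receiver-loop atomics but pays for per-step scratch
+    buffers and cooperative-grid synchronization; receiver-side atomics avoid
+    that overhead but contend when many edges target the same parameters.
+    """
+    edge_count = work.receiver_count * max(work.recurrent_sender_count, 1)
+    atomic_pressure = edge_count * work.hidden_width
+    sync_cost = work.window_steps * max(work.batch, 1)
+    # Illustrative threshold: prefer sender-reverse once estimated atomic
+    # contention dominates the per-step synchronization cost.
+    return "sender_reverse" if atomic_pressure >= 64 * sync_cost else "receiver_atomic"
+```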
+ +Validation and decision: + +- Implemented probe: + changed the active reverse table call to use direct receiver-side recurrent + sender accumulation (`use_sender_reverse=False`) while keeping the same + tensor/op-table ABI and reverse kernel. This was only a probe; if it had won, + the decision would have moved into generic backend policy rather than + staying as a hardcoded callsite flag. +- Parity: + static checks passed (`git diff --check`, `ruff check`, `py_compile`). GPU 1 + K=128/H=64 backward smoke passed (`1 passed`) with private caches + `/tmp/cortical_torch_ext_${USER}_redo_sender_policy_k128_smoke` and + `/tmp/cortical_triton_${USER}_redo_sender_policy_k128_smoke`. GPU 0 T=1 + mixed reset/no-reset training passed (`2 passed`) with private caches + `/tmp/cortical_torch_ext_${USER}_redo_sender_policy_t1_reset` and + `/tmp/cortical_triton_${USER}_redo_sender_policy_t1_reset`. +- Rejection evidence: + GPU 3 active profile wrote + `/tmp/redo_fixmaass_owner_profile_sender_policy` and reported + `8.144 tok/s`, `transition_message_reverse_table_device_loop:ms=76.212`, + `artifact.recompute.cuda_temporal_replay_scan:ms=17.638`, peak + `0.107511 GiB`. +- Decision: + reject and roll back. The direct receiver-side accumulation algorithm is + parity-clean but far slower for the active generic work regime. Keep + sender-reverse active for this row; do not spend more R4 cycles on this policy + axis unless a different shape profile identifies sender-reverse workspace or + synchronization as the measured owner. + +### 2026-04-29 UTC - R4 full-window replay artifact table probe + +Status: REJECTED; CODE ROLLED BACK. + +Owner: temporal artifact replay / reverse-window table bridge. + +- Current owner: + after the sender-policy rejection, the active accepted baseline remains + `/tmp/redo_fixmaass_owner_profile_trace_optout_rerun` at `13.885 tok/s`. + The largest device owner is still the reverse table, but the replay/reverse + split also keeps Python assembling contiguous reverse-window tensors from + `checkpoint_*` plus `final_*` scan outputs. +- Boundary invariant: + replay artifacts are tensor-table materialization policy, not cell or route + policy. A full-window replay table may be selected for active reverse if the + ABI remains flat bucket tensor/op roles and does not add population names, + cell-family selectors, benchmark rows, hidden-size policy, or single/mixed + route identities. +- Probe target: + allocate replay artifact/tape tensors with `physical_steps` rows instead of + `physical_steps - 1`, write the final step into both the full-window table + and the existing final tensor, and let `_scan_replay_tensor_window` consume + the full table directly. This should remove Python `torch.cat` window + assembly for recurrent message and transition tape tensors while keeping + default final tensor fields available to existing callers. + +Validation and decision: + +- Implemented probe: + replay recurrent-message, output-message, and transition-tape artifact + tensors were allocated as full `physical_steps` tables, with final tensors + still populated for existing callers. `_scan_replay_tensor_window` consumed a + full table directly when present. The edit stayed within replay + materialization policy and did not add cell-family, population-name, + benchmark-row, hidden-size, single/mixed, or graph-constructor selectors. +- Passing checks before rejection: + static checks passed (`git diff --check`, `ruff check`, `py_compile`, + `tests/test_fabric_backend_boundaries.py`). 
GPU 2 low-level reverse-table + parity passed (`1 passed`) with private caches + `/tmp/cortical_torch_ext_${USER}_redo_full_replay_low` and + `/tmp/cortical_triton_${USER}_redo_full_replay_low`. GPU 3 default + final-state temporal-superop tests passed (`4 passed`) with private caches + `/tmp/cortical_torch_ext_${USER}_redo_full_replay_default` and + `/tmp/cortical_triton_${USER}_redo_full_replay_default`. +- Rejection evidence: + high-level GPU 0 T=1 mixed reset/no-reset training and GPU 1 K=128/H=64 smoke + both exited with process code `-1` and no pytest output using private caches + `/tmp/cortical_torch_ext_${USER}_redo_full_replay_t1_reset` and + `/tmp/cortical_torch_ext_${USER}_redo_full_replay_k128_smoke`. A + `CUDA_LAUNCH_BLOCKING=1` K=128/H=64 debug rerun on GPU 0 also exited with + code `-1` and no traceback using + `/tmp/cortical_torch_ext_${USER}_redo_full_replay_debug`. +- Decision: + reject and roll back. The probe was not safe on the high-level active + training path. Do not repeat full-window replay artifact materialization + without first adding a smaller low-level reproducer that covers the + high-level active reverse replay window shape. + +### 2026-04-29 UTC - R4 reverse sender scratch traffic trim + +Status: REJECTED; CODE ROLLED BACK. + +Owner: active shared flat-bucket reverse message primitive. + +- Current owner: + R4 remains open. The accepted active row is still + `/tmp/redo_fixmaass_owner_profile_trace_optout_rerun` at `13.885 tok/s`, + with `transition_message_reverse_table_device_loop:ms=25.768` and + `artifact.recompute.cuda_temporal_replay_scan:ms=17.611`. A live current-code + rerun on GPU 0 using private caches + `/tmp/cortical_torch_ext_${USER}_redo_live_owner_pre` and + `/tmp/cortical_triton_${USER}_redo_live_owner_pre` exited with code `-1` + after writing only `/tmp/redo_fixmaass_live_owner_pre/manifest.json`; this is + non-evidence and must not replace the accepted warmed baseline. +- Boundary invariant: + the edit stays inside the existing table-owned recurrent message primitive. + ABI inputs remain flat bucket identity, tensor-table/op-table roles, + graph/message sender tables, reset/delay policy, and primitive tensor slots. + No cell-kind selector, population-name route key, benchmark-row selector, + hidden-size policy key, single/mixed route split, or cell-family parameter + bundle is allowed. +- Implementation target: + the sender-reverse phase only consumes scratch rows for recurrent senders. + Receiver rows currently initialize and write sender-reverse scratch for every + valid sender, including input-boundary senders that are never read by the + recurrent-sender phase. Also, when delay is disabled, every recurrent + sender/receiver table entry is written before it is read, so the per-row + scratch zeroing for delayed absences is unnecessary. Trim those generic + scratch writes while preserving the existing delayed-edge path. +- Validation and rejection: + the first probe skipped scratch zeroing when delay was disabled and skipped + scratch writes for input-boundary senders. GPU 1 low-level reverse-table + parity exited with code `-1` and no pytest output using private cache + `/tmp/cortical_torch_ext_${USER}_redo_scratch_trim_low_serial`. A narrowed + probe restored unconditional scratch zeroing and only skipped input-boundary + writes; GPU 1 low-level reverse-table parity again exited with code `-1` + using `/tmp/cortical_torch_ext_${USER}_redo_scratch_trim_input_low`. 
+ Restoring the accepted kernel made the same low-level parity pass + (`1 passed in 45.22s`) with + `/tmp/cortical_torch_ext_${USER}_redo_scratch_trim_reverted_low`. +- Decision: + reject and roll back. The sender-reverse scratch layout is more coupled to + the sender/receiver offset table than this local traffic trim assumed. Do not + retry this axis without first changing the sender scratch ABI to carry an + explicit receiver-slot mapping or adding a low-level reproducer that proves a + new zero/store policy cannot read stale work slots. + +### 2026-04-29 UTC - R4 explicit sender reverse slot ABI + +Status: ACCEPTED AS GENERIC SENDER-REVERSE TABLE ABI; R4 REMAINS OPEN. + +Owner: active shared flat-bucket reverse message primitive. + +- Current owner: + the rejected scratch trim showed that the sender-reverse scratch table + implicitly couples sender-list position to receiver-local offset. That hidden + coupling makes local traffic reductions unsafe and is the wrong ABI for a + generic table-owned temporal message primitive. +- Boundary invariant: + this slice may only add graph/message table metadata: a sender-owned receiver + table plus the matching receiver-local slot table. It must not add + cell-family names, population names, benchmark-row selectors, hidden-size + policy keys, single/mixed route identities, or message math hardcoded outside + the lowered primitive roles. +- Implementation target: + build compact sender-reverse tables at runtime from the flat + receiver-sender table: one table for receiver id, one table for the + receiver-local slot whose scratch row contains that receiver/sender edge. + Pass both through the reverse tensor-table ABI and let the CUDA sender phase + use the explicit slot instead of assuming the sender-list column is also the + receiver-local column. Keep the old same-offset behavior as the fallback for + callers that have not supplied compact slot tables. +- Implementation: + runtime now materializes compact sender-owned reverse receiver tables and the + matching receiver-local slot tables for backend-order and flat-bucket carry + order recurrent message tables. The transition/message reverse tensor-table + ABI has a new `message.sender_receiver_slot_idx` role and a compact-table + scalar flag. The CUDA sender phase uses the explicit slot to read the + receiver-owned work rows; callers without a slot table keep the old same + offset fallback. This is graph/message table metadata only, with no + cell-family, population-name, hidden-size, benchmark-row, or single/mixed + route selector. +- Validation: + static checks passed (`py_compile`, `ruff check`, `git diff --check`) plus + `tests/test_fabric_backend_boundaries.py` (`2 passed`). GPU 1 low-level + reverse-table parity passed with private cache + `/tmp/cortical_torch_ext_${USER}_redo_sender_slot_low` (`1 passed in + 45.16s`). High-level mixed T=1 reset/no-reset training passed on GPU 0 with + warmed private cache + `/tmp/cortical_torch_ext_${USER}_redo_sender_slot_debug_t1_heartbeat` + (`2 passed in 4.21s`), and mixed K=128/H=64 backward passed (`1 passed in + 5.59s`). +- Audit evidence: + K=128/H=64 mixed terminal training with owner timing: + `/tmp/redo_fixmaass_owner_profile_sender_slot_long`, `13.945 tok/s`, + `71.713 ms`, peak `0.107572 GiB`. The K-adjusted T=1 floor gate passed + (`matched_t1_training_tokens_per_s=65.673`, divisor `128`, floor `0.513 + tok/s`). 
Owner timing remained flat against the accepted warmed row: + `transition_message_reverse_table_device_loop:ms=25.877;count=2`, + `temporal_artifact_recompute:ms=18.510;count=2`, + `artifact.recompute.cuda_temporal_replay_scan:ms=17.632;count=2`. For T=1 + guard coverage, `/tmp/redo_fixmaass_sender_slot_t1_b1024_mixed_plain` + recorded `64,464.909 tok/s`, peak `0.691133 GiB`, and mixed Fabric/stack + ratio `3.646x`. +- Environment note: + several fresh high-level runs exited with process code `-1` while PyTorch's + extension loader was compiling unchanged forward/backward extensions and left + partial build directories. Manual `ninja -j1` completed those same private + cache builds, after which the high-level gates passed. Treat those early + exits as extension-build non-evidence, not semantic failures of this ABI. +- Decision: + accept this ABI slice. It does not close R4 by itself because the active path + still has separate replay and reverse device-loop owners and planner metadata + still reports the transitional `python_autograd_scan` owner. The next R4 work + should use the explicit slot table to safely reduce sender-reverse scratch + traffic or fuse more replay/reverse ownership into the CUDA temporal superop, + with no cell-specific backend logic. + +### 2026-04-29 UTC - R4 sender-reverse scratch writes after explicit slot ABI + +Status: ACCEPTED AS GENERIC SCRATCH TRAFFIC TRIM; R4 REMAINS OPEN. + +Owner: active shared flat-bucket reverse message primitive. + +- Current owner: + the accepted explicit slot ABI removed the hidden assumption that + sender-list column equals receiver-local slot. The active K=128/H=64 row is + still dominated by `transition_message_reverse_table_device_loop` and replay, + so the next safe kernel slice is to reduce scratch traffic inside the + sender-reverse phase. +- Boundary invariant: + this may only use generic graph/message facts already present in the + tensor-table ABI: sender id, receiver id, receiver-local slot, delay policy, + and reset policy. It must not branch on cell family, population name, hidden + size, benchmark case, or single/mixed route identity. +- Implementation target: + when delay is disabled, sender-reverse work rows read by the compact + recurrent-sender table are written in the same message pass, so pre-zeroing + `message_weight_work` and `message_dlogit_work` is unnecessary. Also, + receiver edges from input-boundary senders are consumed by direct input K/V + gradients and are never read by the recurrent sender phase, so their + sender-reverse scratch rows should not be written. Delay-enabled paths should + keep conservative zeroing because a valid static edge can be dynamically + inactive for the current step. +- Implementation: + inside the shared transition/message reverse CUDA primitive, scratch + pre-zeroing is now gated by `use_sender_reverse && use_delay`, and + sender-reverse scratch rows are only written for recurrent senders. Input + sender edges still produce direct input K/V gradients in the receiver pass; + they no longer populate recurrent-sender scratch that no sender phase reads. + This uses only generic sender-bank identity and delay policy. +- Validation: + after rebuilding the touched CUDA extension with private cache + `/tmp/cortical_torch_ext_${USER}_redo_sender_slot_debug_t1_heartbeat`, + GPU focused gates passed: low-level transition/message reverse parity + (`1 passed in 4.72s`), mixed T=1 reset/no-reset training (`2 passed in + 7.60s`), and mixed K=128/H=64 backward (`1 passed in 8.49s`). 
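+- Illustrative sketch (hypothetical, not repository code):
+  a small Python shape of the compact sender-reverse tables from the accepted
+  slot ABI and of the scratch rows the recurrent-sender phase can actually
+  read. Input-boundary sender edges never enter that read set, which is the
+  fact this scratch trim relies on. The names and toy data below are
+  assumptions for illustration only.
+
+  ```python
+  from collections import defaultdict
+
+
+  def build_sender_reverse_tables(receiver_sender, recurrent_senders):
+      """receiver_sender: per-receiver sender-id lists in flat bucket order.
+
+      Returns sender_id -> [(receiver_id, receiver_local_slot), ...] plus the
+      set of (receiver_id, receiver_local_slot) scratch rows that the
+      recurrent-sender phase reads. Only recurrent-sender edges appear.
+      """
+      sender_tables = defaultdict(list)
+      read_scratch_rows = set()
+      for receiver_id, senders in enumerate(receiver_sender):
+          for slot, sender_id in enumerate(senders):
+              if sender_id in recurrent_senders:
+                  sender_tables[sender_id].append((receiver_id, slot))
+                  read_scratch_rows.add((receiver_id, slot))
+      return dict(sender_tables), read_scratch_rows
+
+
+  # Receiver 0 sees recurrent sender 3 and input-boundary sender 7.
+  _, read_rows = build_sender_reverse_tables([[3, 7], [3]], recurrent_senders={3})
+  assert (0, 1) not in read_rows  # the input-boundary scratch row is never read
+  ```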
+- Audit evidence: + `/tmp/redo_fixmaass_owner_profile_scratch_after_slot` recorded K=128/H=64 + mixed terminal training at `13.938 tok/s`, `71.745 ms`, peak `0.107572 GiB`. + The K-adjusted T=1 floor gate passed (`matched_t1_training_tokens_per_s=64.332`, + divisor `128`, floor `0.503 tok/s`). The measured reverse loop improved from + the sender-slot row's `25.877 ms` to + `transition_message_reverse_table_device_loop:ms=25.534;count=2`; replay + stayed comparable at `artifact.recompute.cuda_temporal_replay_scan:ms=17.660`. + T=1 B=1024 mixed without owner timing stayed healthy at + `/tmp/redo_fixmaass_scratch_after_slot_t1_b1024_mixed_plain`, + `66,384.758 tok/s`, peak `0.691133 GiB`, mixed Fabric/stack ratio `3.472x`. +- Decision: + accept this small traffic trim. It is not R4 closure: reverse and replay are + still separate owners, and the planner still reports transitional + `python_autograd_scan` metadata. The next high-priority R4 target remains + deeper replay/reverse fusion or a larger reduction inside the table-owned + reverse device loop. + +### 2026-04-29 UTC - R4 full-window replay artifact table retry + +Status: REJECTED; CODE ROLLED BACK. + +Owner: temporal artifact replay / reverse-window table bridge. + +- Current owner: + after the scratch trim, K=128/H=64 still spends about `17.66 ms` in + `artifact.recompute.cuda_temporal_replay_scan` and still assembles active + reverse windows from artifact checkpoint tensors plus separate final tensors. + The previous full-window replay probe was rejected because high-level runs + exited with silent `-1`; later validation showed identical silent exits can + come from PyTorch extension-loader partial builds, so this owner deserves a + properly rebuilt retry. +- Boundary invariant: + replay artifacts are temporal tensor-table materialization, not cell or route + policy. This retry may alter artifact table shape and replay-window selection + only. It must not add cell-family, population-name, hidden-size, benchmark + row, reset-mode, or single/mixed route selectors. +- Implementation target: + materialize recurrent-message, output-message, and transition-tape replay + artifacts as full `physical_steps` tables. Continue writing the separate + final tensors for compatibility. Teach `_scan_replay_tensor_window` to return + the full table directly when available, avoiding `torch.cat` on the active + reverse path. +- Validation: + the temporal scan extension was rebuilt manually with `ninja -j1` in private + cache `/tmp/cortical_torch_ext_${USER}_redo_sender_slot_debug_t1_heartbeat`. + Active replay/K gates passed: mixed K>1 terminal loss (`1 passed`), mixed + K>1 provided-state gradients (`1 passed`), and mixed K=128/H=64 backward + (`1 passed`). The older T>1 recompute test completed forward/backward but + failed a route assertion because current execution reported + `cuda_temporal_superop` instead of its expected + `windowed_temporal_physical_scan`; this was not used as acceptance evidence. +- Rejection evidence: + `/tmp/redo_fixmaass_owner_profile_full_window_retry` recorded K=128/H=64 + mixed terminal training at `13.881 tok/s`, `72.039 ms`, peak `0.101804 GiB`. + The K-adjusted T=1 floor gate passed, but throughput was below the accepted + scratch-trim row (`13.938 tok/s`) and the replay owner worsened from + `artifact.recompute.cuda_temporal_replay_scan:ms=17.660;count=2` to + `17.954 ms`; total `temporal_artifact_recompute` worsened from `18.607 ms` + to `18.865 ms`. +- Decision: + reject and roll back. 
The full table reduced peak memory in this probe but + did not improve throughput or replay owner time, so it violates the R4 + throughput priority. Do not reattempt this exact artifact-shape change unless + a separate CUDA-side consumer removes enough host/window work to offset the + extra scan writes. + +### 2026-04-29 UTC - R4 sender-reverse grad-KV direct accumulation probe + +Status: ACCEPTED; SUBSTAGE CLOSED. + +Owner: active shared flat-bucket reverse message primitive. + +- Current owner: + after the accepted scratch trim, the largest measured kernel owner remains + `transition_message_reverse_table_device_loop` at about `25.53 ms` for the + K=128/H=64 row. The sender-reverse phase still materializes + `message_grad_kv_work`, synchronizes, and then runs a second pass over + hidden dimensions to accumulate recurrent K/V weight and public carry. +- Boundary invariant: + this probe may only change generic message/projection accumulation inside the + existing tensor-table primitive. It must not add cell-family, population-name, + hidden-size, benchmark, reset-mode, or single/mixed route selectors. +- Implementation target: + compute each recurrent sender's K/V gradient and immediately accumulate its + recurrent K/V weight gradient plus public carry contribution, removing the + intermediate `message_grad_kv_work` write/read and one grid synchronization. + This trades the second pass for atomics into public carry, so acceptance + requires measured throughput improvement and strict parity. +- Manual boundary review: + the change stays inside the generic transition/message reverse tensor-table + primitive. ABI inputs remain flat bucket/message tables (`receiver_sender`, + explicit sender reverse receiver/slot tables, K/V tensors, reset windows, + step indices, recurrent hidden table, and projection weights). It adds no + cell-kind selector, population-name selector, benchmark-row selector, + hidden-size policy key, separate single/mixed route, or cell-family parameter + bundle. The removed `message_grad_kv_work` slot was scratch only; recurrent + K/V parameter gradients and public carry are still computed from the same + lowered message/projection tables. +- Implementation: + sender-reverse now computes each recurrent sender/dimension `grad_kv` and + immediately accumulates recurrent K/V weight and public carry contributions. + The intermediate `message_grad_kv_work` tensor-table slot and allocation were + removed after acceptance. +- Validation: + rebuilt + `/tmp/cortical_torch_ext_${USER}_redo_sender_slot_debug_t1_heartbeat/fabric_flat_bucket_temporal_backward_cuda` + with `ninja -v -j1`. Static checks passed (`git diff --check`; `ruff` on the + touched Python glue paths). Targeted parity passed after scratch-slot removal: + low-level transition/message reverse table parity (`1 passed in 4.49s`), + mixed T=1 reset/no-reset shared reverse device loop (`2 passed in 7.36s`), + and mixed K=128/H=64 high-level backward (`1 passed in 8.47s`). +- Audit evidence: + `/tmp/redo_fixmaass_owner_profile_direct_gradkv_noscratch` recorded K=128/H=64 + mixed terminal training at `14.005 tok/s`, `71.403 ms`, peak + `0.107572 GiB`. The K-adjusted T=1 training floor gate passed + (`matched_t1_training_tokens_per_s=63.048`, divisor `128`, floor + `0.493 tok/s`, ratio `28.43x`). 
The measured reverse loop improved versus + the accepted scratch-trim row from + `transition_message_reverse_table_device_loop:ms=25.534;count=2` to + `25.217 ms`; a pre-scratch-removal confirmation reached `14.052 tok/s` and + `25.179 ms`, so the owner movement is repeatable within warmed noise. + T=1 B=1024 mixed final-code probe + `/tmp/redo_fixmaass_direct_gradkv_noscratch_t1_b1024_mixed_plain` stayed + healthy at `67,416.635 tok/s`, peak `0.691133 GiB`, mixed Fabric/stack ratio + `3.554x`, through the high-level + `model_forward_external_loss_backward_optimizer_step` contract. +- Decision: + accept this substage. It is still not R4 closure: the active K=128/H=64 + training path still reports transitional `python_autograd_scan` metadata and + still has separate replay/reverse owners (`artifact.recompute.cuda_temporal_replay_scan` + about `17.704 ms`, reverse loop about `25.217 ms`). The next high-priority + R4/R13 owner remains moving more replay/reverse scan work into the table-owned + CUDA temporal superop without adding cell-specific or route-specific logic. + +### 2026-04-29 UTC - R13 long T*K frontier refresh + +Status: CLOSED AS LONG-T K=128/H=64 STREAMING REPAIR; R13/R4 REMAIN OPEN. + +Owner: temporal scaling/audit and whichever backend owner blocks long-T K=128/H=64. + +- User correction: + do not get stuck proving only T=1 or T=1,K=128. The redo audit must exercise + T*K with `T=4096` and frontier `T=16K` where memory allows. T=1 remains the + matched base throughput floor, but long-T streaming is the semantic target. +- Immediate plan: + run current-code high-level API probes for mixed flat-bucket Fabric with + `seq_len=4096`, `inner_steps=128`, `gradient_horizon_steps=64`, then attempt + `seq_len=16384` under the same K/H row. Use private caches and GPUs 0-4 only. + Record whether failures are compile/cache, OOM, timeout, parity/gate, or a + named backend owner. Any follow-up code must target the long-T owner, not only + the T=1 owner. +- Early probe notes: + parallel owner-timed T=4096/T=16K probes under the warmed shared cache were + aborted as contaminated measurement probes after long low-utilization runs; + they were likely dominated by per-window timing/cache/process contention and + are not closure evidence. A T=4096 no-owner probe launched in parallel under + the same warmed cache was also aborted as contaminated. +- Current evidence: + `/tmp/redo_fixmaass_tk_t512_k128_h64_no_owner_probe` completed through the + high-level API for mixed Fabric at `T=512`, `K=128`, `H=64`: `14.204 tok/s`, + `36.047 s`, peak `0.271243 GiB`, with + `time_steps=65536`, `checkpoint_stride=64`, `produced_checkpoint_count=1023`, + and `recompute_window_len=64`. This supports flat T*K streaming at T=512, but + also shows 1023 Python-visible H windows that must not become the long-T + closure owner. +- Clean rerun: + T=4096 is being rerun solo with dedicated caches + `/tmp/cortical_torch_ext_${USER}_redo_long_t4096_solo` and + `/tmp/cortical_triton_${USER}_redo_long_t4096_solo` so the result is not + polluted by shared-cache or parallel owner-timing overhead. + +Update: + +- The first clean T=4096 run before the window-tape fix was aborted after more + than 26 minutes without a `cases.jsonl` row. A post-output-window-skip T=4096 + retry was also aborted after nearly 12 minutes. Both were low-GPU/high-host + runs and are not closure evidence. 
+- Implemented generic long-T fixes in + `src/cortical/fabric/backend/cuda/sequence_surface/temporal_backward.py`: + inactive terminal-output H windows no longer materialize zero output-gradient + windows, reverse-engine transition projection gradients are flattened over + the H-window time/batch dimensions before entering the existing generic + receiver-major gradient reducers, and eligible no-reset reverse-only + flat-bucket windows are guarded against silent Python step-replay fallback. + These edits do not add cell-family, population-name, hidden-size, reset-mode, + benchmark-row, or single/mixed route selectors. +- Root cause for the T>=1024 stall: + the planner correctly disabled full-sequence transition tape once the full + T*K tape exceeded the bounded budget, but CUDA reverse-only replay still + needs only H=64 window-local transition artifacts. Before the guard this + silently degraded into `_run_temporal_bucket_step_backward_result` and burned + host time. The replay path now requests bounded window-local transition + artifacts for eligible reverse-only windows even when the global sequence tape + mode is disabled. +- Validation: + `git diff --check` and `ruff` passed. CUDA parity passed for terminal K>1 + final-emission gradients, provided-state terminal gradients, and K=128/H=64 + mixed-pop backward smoke (`3 passed in 5.40s`) on GPU1 with private caches. +- Long-T evidence after the fix: + T=256/K=128/H=64 closed at `15.126 tok/s`, `16.924 s`, peak + `0.189717 GiB`. T=512/K=128/H=64 closed at `15.525 tok/s`, `32.979 s`, peak + `0.271243 GiB`. T=1024/K=128/H=64 closed at `15.736 tok/s`, `65.074 s`, + peak `0.430399 GiB`, with global `transition_tape:disabled` and the + K-adjusted T=1 training floor gate passing. T=4096/K=128/H=64 closed at + `15.953 tok/s`, `256.761 s`, peak `1.468093 GiB`, consumed `8191` H=64 + checkpoints, and passed the K-adjusted T=1 training floor gate + (`matched_t1=45.424 tok/s`, floor `0.355 tok/s`, ratio `44.95x`). +- Current run: + T=16K/K=128/H=64 closed through the same high-level audit API on GPU0 using + the warmed private cache: `15.901 tok/s`, `1030.383 s`, peak + `5.757155 GiB`, global `transition_tape:disabled`, `time_steps=2097152`, + `checkpoint_stride=64`, `produced_checkpoint_count=32767`, + `consumed_checkpoint_count=32767`, and `recompute_window_len=64`. The + K-adjusted T=1 training floor gate passed (`matched_t1=41.618 tok/s`, floor + `0.325 tok/s`, ratio `48.90x`). This restores the April 26 frontier target + at K=128/H=64 for T=16K on the current redo backend. +- Decision: + close this R13 substage for long-T K=128/H=64 streaming repair. R13/R4 as + whole remain open until the metadata/planner owner no longer reports + transitional `python_autograd_scan` and the remaining replay/reverse H-window + host orchestration is fully owned by the CUDA temporal superop, but T=4096 + and T=16K are now usable audit rows again and no longer silently degrade into + Python step replay when full-sequence transition tape is over budget. + +### 2026-04-29 UTC - R4 inactive-output reverse window ownership + +Status: ACCEPTED AS GENERIC INACTIVE-WINDOW REVERSE BRIDGE TRIM; R4 REMAINS OPEN. + +Owner: active shared flat-bucket reverse table bridge / terminal-loss inactive H windows. + +- Closed before this owner: + the long-T K=128/H=64 repair above is accepted as a streaming frontier + substage. It does not close R4/R13 overall because the plan-level temporal + owners still say `python_autograd_scan`. 
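+- Illustrative sketch (hypothetical, not repository code):
+  a rough Python shape of the window-level artifact decision behind that
+  long-T repair. The function name, return labels, and eligibility flag are
+  assumptions for illustration; the real policy belongs to the registered
+  temporal compiler path, not to benchmark or scheduler code.
+
+  ```python
+  def plan_window_transition_artifacts(
+      sequence_tape_enabled: bool,
+      window_is_reverse_only_eligible: bool,
+      window_len: int,
+  ) -> str:
+      if sequence_tape_enabled:
+          # The full T*K transition tape fits the bounded budget; reverse
+          # windows can slice it directly.
+          return "full_sequence_tape"
+      if window_is_reverse_only_eligible:
+          # Eligible reverse-only windows request bounded, O(window_len)
+          # window-local transition artifacts so CUDA replay keeps ownership
+          # instead of silently degrading into per-step Python replay.
+          return "window_local_transition_artifacts"
+      # Windows that are not reverse-only eligible keep the explicit
+      # step-replay path; the guard only forbids losing eligibility silently.
+      return "explicit_step_replay"
+  ```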
+- Current measured design issue: + terminal-loss K=128/H=64 has one active output physical step but many inactive + H windows. Those inactive windows already use the CUDA replay plus table-owned + reverse engine, yet Python still builds a full tuple of empty + `TemporalOutputBackwardStep` objects only to provide an all-zero direct-public + gradient window to the CUDA reverse table. +- Implementation rule: + do not add cell-family, population-name, benchmark-row, hidden-size, reset, or + single/mixed route selectors. The reverse table may consume a generic + tensor/op-table window and treat a missing output-backward sequence as a zero + direct-public gradient window. Materialized-state carries must still force the + output backward path when an output-cell carry gradient can exist. +- Exit criteria for this slice: + inactive terminal-output windows skip the per-step Python output-backward + sequence, the CUDA reverse table still owns the dependency recurrence, K=128 + parity remains strict, and at least one high-level K=128/H=64 probe remains + above the matched T=1/K-adjusted floor. +- Implementation: + `TemporalPhysicalBackwardScanExecutor` no longer builds an all-empty + `TemporalOutputBackwardSequence` for inactive terminal-output windows. The + generic reverse-table path now accepts a missing output-backward sequence as + an all-zero direct-public gradient window, records + `cuda_reverse_engine_zero_direct_public_window`, and keeps materialized-state + carries on the explicit output-backward path. A mismatch guard rejects any + non-empty output-backward sequence whose step count does not match the H + window. +- Validation: + `git diff --check`, `ruff`, and `py_compile` passed. CUDA parity on GPU1 with + private caches passed for terminal K>1 final-emission gradients, + provided-state/reset propagation, and K=128/H=64 mixed backward smoke + (`3 passed in 5.45s` after warm compile). +- High-level audit evidence: + `/tmp/redo_fixmaass_inactive_zero_direct_t512_k128_h64_confirm` ran through + `python -m benchmarks.fabric.run_audit` using the high-level Fabric API for + mixed population, `T=512`, `K=128`, `H=64`, terminal loss, no resets, + `B=1`, `h=8`, planner checkpoints, and private caches on GPU1. Result: + `16.125 tok/s`, `31.753 s`, peak `0.270731 GiB`. The matched T=1 training + baseline was `72.656 tok/s`; the K-adjusted floor was `0.568 tok/s`; the gate + passed at `28.41x`. The row records the new + `temporal_backward_glue:cuda_reverse_engine_zero_direct_public_window` path + and still records the active reverse table owner while plan-level + `temporal_plan_backward_owners` remains `python_autograd_scan`. +- Decision: + accept this substage as a real active-path bridge reduction, not R4 closure. + It removes one repeated Python per-step inactive-output object path, but the + remaining R4 owner is still the broader replay/reverse H-window host shell + and transitional planner owner metadata. Next R4 work should continue moving + replay/reverse window ownership deeper into the table-owned CUDA temporal + superop without adding cell-specific logic. + +### 2026-04-29 UTC - R4 initial public carry reverse-table ABI + +Status: ACCEPTED AS ACTIVE GENERIC REVERSE-TABLE ABI SLICE; R4 REMAINS OPEN. + +Owner: active shared flat-bucket reverse table bridge / cross-window recurrent carry. + +- Follow-up issue from the inactive-output bridge: + terminal-loss H windows before the final output still carry recurrent public + gradients from the later window. 
The first bridge removed the empty + output-backward step tuple, but the active path still materialized a mostly + zero `[H,B,R,h]` direct-public tensor only to place that carry on the final + local step. +- Implementation: + the transition/message reverse table ABI now has a generic + `grad.initial_public_y_carry` tensor-table role. The CUDA device loop + initializes its public carry workspace from that `[B,R,h]` tensor when + present and treats `grad.direct_public_y_window` as optional. Inactive + terminal-output windows can therefore pass no direct-public window at all, + while active output windows still pass direct output gradients through the + same generic table role. This does not add any cell-family, population-name, + route, shape, reset, or benchmark selector. +- Validation: + `git diff --check`, `ruff`, and `py_compile` passed. CUDA parity on GPU2 with + private caches passed for terminal K>1 final-emission gradients, + provided-state/reset propagation, and K=128/H=64 mixed backward smoke + (`3 passed in 5.39s` after the ABI fix). +- High-level audit evidence: + `/tmp/redo_fixmaass_initial_public_carry_t512_k128_h64` ran through + `python -m benchmarks.fabric.run_audit` using the high-level Fabric API for + mixed population, `T=512`, `K=128`, `H=64`, terminal loss, no resets, + `B=1`, `h=8`, planner checkpoints, and private caches on GPU2. Result: + `16.144 tok/s`, `31.715 s`, peak `0.270731 GiB`. The matched T=1 training + baseline was `67.399 tok/s`; the K-adjusted floor was `0.527 tok/s`; the gate + passed at `30.66x`. The row records + `temporal_backward_glue:cuda_reverse_engine_absent_direct_public_window`, + proving the active row no longer materializes the direct-public H-window just + to carry recurrent gradients across windows. Plan-level + `temporal_plan_backward_owners` still correctly remains `python_autograd_scan`. +- Decision: + accept this ABI slice. It removes another real active-path bridge allocation, + but R4 remains open because replay/reverse is still launched and sequenced by + the Python autograd/H-window shell. Next R4 work should target that remaining + shell or another measured replay/reverse table owner, not metadata relabeling. + +### 2026-04-29 UTC - R4 diagonal output projection param grads in reverse table + +Status: ACCEPTED AS TABLE-OWNED PARAM-GRAD BRIDGE REMOVAL; R4 REMAINS OPEN. + +Owner: active shared flat-bucket reverse table bridge / transition projection +parameter-gradient reduction. + +- Current measured design issue: + after the table-owned reverse device loop computes direct-plus-carried public + gradients, Python still materializes `transition_public_grad_window` and then + runs the diagonal output projection parameter-gradient reducer as a separate + receiver-affine bridge. This is hot in K=128/H=64 because it repeats for every + H-window, including inactive terminal-output windows whose only public + gradient is cross-window recurrent carry. +- Boundary review before edit: + the ABI inputs must remain tensor-table primitive roles: direct public grad + window (optional), initial public carry grad (optional), diag-RTU transition + state/params, message primitive rows, reset windows, graph/message tables, and + flat bucket start/count. The kernel must not add a cell-kind selector, + population-name selector, benchmark-row selector, hidden-size policy key, + separate single/mixed route, or cell-family parameter bundle. 
Output + projection weight/bias grads are primitive parameter-gradient outputs of the + same flat-bucket reverse table, not a cell-specific shortcut. +- Implementation: + added `primitive.diag_rtu.projection.output_bias` plus + `grad.diag_rtu.projection.output_weight/output_bias` table slots, and the + reverse device loop now accumulates the diagonal output projection + weight/bias gradients while it owns the exact direct-plus-carried public + gradient. The Python `transition_public_grad_window` bridge and the diagonal + output projection receiver-affine param-gradient step were deleted for this + path. The fallback `diagonal_preproj_window` recompute/stacking in + `TemporalPhysicalBackwardScanExecutor` was also removed because the table + now recomputes the primitive preprojection internally from generic diag-RTU + tensor rows. +- Validation: + `git diff --check`, targeted `ruff`, and `py_compile` passed. + `tests/test_fabric_backend_boundaries.py` passed (`2 passed`). + Low-level CUDA reverse-table parity passed for + `test_fabric_cuda_transition_message_reverse_table_window_matches_mixed_step_loop` + (`1 passed in 101.67s`) and now checks the table-emitted diagonal output + projection weight/bias gradients. High-level CUDA parity passed for terminal + K>1 final-emission gradients, provided-state/reset propagation, and K=128 + mixed backward (`3 passed in 269.34s`) on GPU2 with private caches. +- High-level audit evidence: + `/tmp/redo_fixmaass_diag_output_table_t512_k128_h64` ran through + `python -m benchmarks.fabric.run_audit` using the high-level Fabric API for + mixed population, `T=512`, `K=128`, `H=64`, terminal loss, no resets, + `B=1`, `h=8`, planner checkpoints, and private caches on GPU2. Result: + `15.790 tok/s`, `32.426 s`, peak `0.270731 GiB`. The matched current-code + T=1 training baseline was `64.643 tok/s`; the K-adjusted floor was + `0.505 tok/s`; the gate passed at `31.27x`. The row records + `temporal_backward_glue:cuda_transition_diagonal_output_projection_param_grad_table`. + A current-code T=1 mixed `B=1024` regression probe + `/tmp/redo_fixmaass_diag_output_table_t1_b1024_mixed_iter5` reached + `65,597.334 tok/s`, peak `0.691230 GiB`, and `3.480x` the matched mixed + stack row while passing the April 21 reference gate. Short one-iteration + rows at `58-59k tok/s` were treated as warmup/noise, not closure evidence. +- Decision: + accept this substage as a real backend-owner reduction: the active path no + longer materializes the transition-public H-window or runs the diagonal + output projection param-gradient reducer through Python. R4 still remains + open because the H-window replay/reverse shell and plan metadata still report + `python_autograd_scan`; next work should continue moving the remaining + replay/window orchestration and transition input-projection param-gradient + bridge into the shared table-owned temporal engine. + +### 2026-04-29 UTC - R4 input projection param grads in reverse table + +Status: ACCEPTED for this substage. R4 remains open until the remaining +reverse replay/window shell and `python_autograd_scan` ownership are removed. + +Owner: active shared flat-bucket reverse table bridge / transition input +projection parameter-gradient reduction. + +- Current measured design issue: + K=128/H=64 now records the diagonal output projection gradients as + table-owned, but `temporal_backward_glue:cuda_transition_input_projection_param_grad_window` + still remains. 
Python builds flattened gated/diagonal recurrent-message + windows and `TransitionInputProjectionParamGradStep` objects, then invokes the + receiver-affine reducer for gated gate affine, gated static recurrent input + projection, and diagonal recurrent input projection parameter gradients. +- Boundary review before edit: + the reverse table may consume the forward recurrent-message window and + primitive projection bias tensors as tensor-table roles and emit primitive + parameter-gradient tensors. It must not introduce cell-kind selectors, + population-name selectors, benchmark-row selectors, hidden-size policy keys, + separate single/mixed routes, or cell-family parameter bundles. The host may + still map table-emitted primitive grads to existing parameter bindings, but + it must not own the temporal/window reduction. +- Implementation: + extended the transition/message reverse table ABI with generic tensor roles + for the gated transition input window, gated gate bias, gated/diagonal input + projection biases, and the forward recurrent-message backend-order window. + The extension now emits primitive parameter-gradient tensors for gated gate + affine weight/bias, gated static recurrent input projection weight/bias, and + diagonal recurrent input projection weight/bias. `TemporalPhysicalBackwardScanExecutor` + maps those table-emitted primitive grads to existing materialized/static + parameter bindings and no longer builds `TransitionInputProjectionParamGradStep` + objects for this path. +- Rejected probes: + an initial cooperative-kernel serial post-scan reduction was parity clean but + regressed the T=1 mixed `B=1024` warmed probe to `50.16k tok/s`. A parallel + atomic version recovered only to `56.92k tok/s`, and a block/warp reduction + version reached `63.49k tok/s` on GPU2 / `63.75k tok/s` on GPU3 while the + device-matched previous commit measured `68.68k tok/s` on GPU2 and + `65.45-67.92k tok/s` on GPU3. Those probes were rejected as T=1 regressions. + The accepted version keeps Python out of the reducer path but uses + table-owned C++/ATen primitive matrix reductions inside the same extension + call after the cooperative reverse scan has emitted the gradient windows. +- Validation: + `git diff --check`, targeted `ruff`, and `py_compile` passed. + `tests/test_fabric_backend_boundaries.py` passed (`2 passed`). Low-level CUDA + reverse-table parity passed for + `test_fabric_cuda_transition_message_reverse_table_window_matches_mixed_step_loop` + and now checks all six table-emitted input-projection/gate parameter-gradient + tensors. High-level CUDA parity passed for terminal K>1 final-emission + gradients, provided-state/reset propagation, and K=128 mixed backward + (`3 passed in 260.10s`) on GPU2 with private caches. +- High-level audit evidence: + `/tmp/redo_fixmaass_input_projection_table_aten_t1_b1024_mixed_iter10` + ran the high-level Fabric API mixed-population T=1 training probe with + `B=1024`, `h=8`, 1M params, no resets, warmup 5 / iterations 10. Result: + `70,953.632 tok/s`, `14.432 ms`, peak `0.673192 GiB`, April 21 gate pass, + and `3.31x` mixed stack. The device-matched previous commit baseline was + `/tmp/redo_fixmaass_prev_a73f833_t1_b1024_mixed_iter10_gpu2` at + `68,679.383 tok/s`, `14.910 ms`, peak `0.691230 GiB`, so the accepted path + removes the Python window reducer without regressing T=1 throughput or memory. 
+ `/tmp/redo_fixmaass_input_projection_table_aten_t512_k128_h64` ran + `T=512`, `K=128`, `H=64`, `B=1`, `h=8`, mixed population, terminal loss, + planner checkpoints, and no resets through the same high-level API. Result: + `15.872 tok/s`, `32.258 s`, peak `0.270731 GiB`; matched T=1 training was + `72.899 tok/s`, K-adjusted floor `0.5695 tok/s`, gate pass at `27.87x`. + The row records `temporal_backward_glue:cuda_transition_input_projection_param_grad_table` + and no longer records + `temporal_backward_glue:cuda_transition_input_projection_param_grad_window`. +- Decision: + close the transition input-projection parameter-gradient bridge substage. + The remaining R4/R13 owner is the broader reverse replay/window orchestration + and `python_autograd_scan` shell; the next backend task should continue + moving that shell into the shared CUDA temporal superop rather than adding + cleanup-only changes. + +### 2026-04-29 UTC - R4 reverse window shell next owner + +Status: ACCEPTED for this substage. R4/R13 remain open until replay/reverse +window orchestration and `python_autograd_scan` ownership are removed. + +Owner: shared multi-pop flat-bucket temporal reverse window shell. + +- Current invariant: + reverse replay and backward must be table/superop owned over flat bucket + identity, tensor slots, op rows, reset policy, checkpoint policy, and + materialization policy. No new path may branch on cell family, population + name, benchmark row, hidden-size policy, or single-vs-mixed cardinality. +- Current issue: + the transition/message reverse table now owns the hot primitive backward + scan and the recent diagonal/input-projection parameter-gradient bridges, but + the H-window shell still records planner-level `python_autograd_scan`. + Remaining active glue includes reverse-window table assembly, sparse output + message recompute for emitted steps, boundary public K/V projection backward, + query/readout parameter binding, boundary-gradient accumulation into outer + time, and final recurrent-state gradient materialization. +- Profiling target: + run a current-code, high-level Fabric API T*K owner-timed probe with + `CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1`, GPU 0-4 only, and a private + extension cache. Use it only to choose the next backend-owned slice; do not + treat stale profile rows as closure evidence. +- Manual boundary review for the next edit: + acceptable ABI inputs are boundary/recurrent/output windows, flat bucket + orders, generic sender K/V projection weights, message-rule tables, primitive + tensor roles, reset windows, step-index windows, and scalar dimensions + inferred from tensor shapes. The edit must not add cell-specific metadata to + the temporal engine, must not make lattice/config fields execution policy, + and must keep benchmarks as high-level API consumers. +- Planned closure for this substage: + move one remaining reverse-window shell operation into a backend-owned + table/window primitive and prove parity plus a current T=1 non-regression and + T*K H=64 smoke. If the probe is slower, reject it here with numbers and keep + the same design target. + +Update: + +- Warmed owner-timed high-level T*K probe: + `/tmp/redo_fixmaass_owner_timing_t64_k128_h64_warm`, GPU 2, private cache, + `T=64`, `K=128`, `H=64`, `B=1`, mixed population, terminal training, + passed the K-adjusted current-code T=1 reference gate at `15.830 tok/s`, + peak `0.128 GiB`. The matched current-code T=1 training line was + `70.394 tok/s`, so the K-adjusted floor was `0.550 tok/s`. 
+- Warmed timing correction: + cold compile had made the first run unusable. The warmed profile still + reports `python_autograd_scan` for forward/backward owners. Dominant active + reverse owners are `temporal_artifact_recompute` / replay scan and + `transition_message_reverse_table_device_loop`; the hidden-before + materialization launch is not the warmed bottleneck, but it is a removable + reverse-shell launch. +- Accepted implementation slice in progress: + the active reverse-only artifact path now passes recurrent hidden source + banks (`initial_recurrent_hidden_backend_order`, + `recurrent_hidden_after_window`, and source offset) into the reverse table, + and the transition/message reverse device-loop ABI adds generic tensor-table + slots for those banks. The kernel selects hidden-before values from the + source banks inside the reverse table and applies generic message-reset + policy there, so the active path can elide the separate + `cuda_recurrent_hidden_before_window` materialization glue launch. +- Manual boundary review: + ABI inputs are tensor-table slots, flat recurrent bank dimensions, reset + windows, source offset, sender tables, message primitive dimensions inferred + from tensor shapes, and primitive transition tensors already present in the + reverse table. The edit adds no cell-kind selector, population-name selector, + benchmark-row selector, hidden-size policy key, single/mixed route selector, + lattice/config field, or cell-family parameter bundle. +- Validation: + Python compile, Ruff check, and Ruff format are green for the touched Python + wrappers. Low-level CUDA reverse-table parity passed for + `test_fabric_cuda_transition_message_reverse_table_window_matches_mixed_step_loop` + after preserving the legacy materialized hidden-before ABI. High-level K>1 + mixed backward parity passed for terminal final-emission gradients, + provided-state gradients, and K=128 outer-emission gradients + (`3 passed`). Reset parity passed for single and mixed temporal horizon rows, + both absent and present reset modes (`4 passed`). T=1/K=1 mixed training + shared reverse-device-loop tests passed (`2 passed`). +- High-level audit evidence: + `/tmp/redo_fixmaass_hidden_source_t1_b1024_mixed_notiming` ran the + high-level Fabric API mixed-population T=1 training probe with `B=1024`, + `h=8`, 1M params, no resets, warmup 5 / iterations 10. Result: + `70,776.037 tok/s`, `14.468 ms`, peak `0.673192 GiB`, April 21 gate pass, + and `3.58x` mixed stack. The owner-timed version of this row was lower + (`59,892 tok/s`) because timing instrumentation adds overhead, so it is not + used as non-regression evidence. + `/tmp/redo_fixmaass_hidden_source_t64_k128_h64_notiming` ran + `T=64`, `K=128`, `H=64`, `B=1`, `h=8`, mixed population, terminal loss, + planner checkpoints, and no resets through the same high-level API. Result: + `15.571 tok/s`, `4.110 s`, peak `0.126713 GiB`; matched current-code T=1 + training was `77.654 tok/s`, K-adjusted floor `0.607 tok/s`, gate pass at + `25.67x`. + Owner-timed `/tmp/redo_fixmaass_hidden_source_t64_k128_h64` recorded + `transition_message_reverse_table_device_loop:ms=1768.188;count=128` and + `temporal_artifact_recompute:ms=1182.885;count=128` while still reporting + `python_autograd_scan` for forward/backward temporal owners. +- Decision: + close this hidden-before materialization bridge substage. 
The active + launch list no longer contains the old + `temporal_backward_glue:cuda_recurrent_hidden_before_window`; it records + `temporal_backward_glue:cuda_recurrent_hidden_before_window_elided_by_reverse_table` + instead. R4/R13 remain open because replay/reverse is still sequenced by the + Python autograd/H-window shell and the warmed dominant owners are still + artifact replay and the reverse table device loop. The next backend owner + should target replay/window orchestration or another measured shell + operation inside the shared CUDA temporal superop, not cleanup-only work. + +### 2026-04-29 UTC - R4 recurrent K/V before-window source tables + +Status: REJECTED and backed out. + +Owner: shared multi-pop flat-bucket temporal reverse window shell / recurrent +message source banks. + +- Current invariant: + recurrent sender K/V banks are generic message primitive tensors. The reverse + engine may consume tensor-table source banks and reset policy, but it must not + branch on cell family, population name, hidden-size policy, benchmark row, + lattice/config fields, or single-vs-mixed cardinality. +- Current measured reason: + the accepted hidden-before slice removed one materialization launch, but the + active reverse-only H-window table assembly still materializes recurrent K/V + "before" windows from initial/checkpoint banks and applies reset with a + window-level `torch.where` before calling the reverse table. This belongs in + the same table-owned recurrent-message primitive path as hidden-before. +- Planned implementation: + add optional recurrent K/V initial/checkpoint source tensor slots and source + offset to the transition/message reverse table ABI. In active reverse-only + mode, pass source banks instead of materialized recurrent K/V before windows; + the device loop selects initial or checkpoint rows per local step and applies + message-reset zeroing only in source mode. The legacy materialized ABI must + remain byte-for-byte semantic-compatible for existing parity tests. +- Manual boundary review: + ABI additions are generic message tensor slots (`state.recurrent_k_initial`, + `state.recurrent_v_initial`, `state.recurrent_k_after_window`, + `state.recurrent_v_after_window`) plus source offset and boolean source mode. + No cell-kind selector, population-name selector, benchmark-row selector, + hidden-size policy key, separate single/mixed route, lattice/config field, or + cell-family parameter bundle is allowed. + +Outcome: REJECTED and backed out. + +- Implementation tried: + optional recurrent K/V initial/checkpoint source slots were added to the + transition/message reverse table, active reverse-only windows passed source + banks instead of materialized K/V-before windows, and the CUDA device loop + selected/reset those values per sender. A pointer-hoist optimization was then + tried to avoid recomputing source/reset offsets per head/value element. +- Correctness: + both implementations passed low-level materialized-ABI parity + (`test_fabric_cuda_transition_message_reverse_table_window_matches_mixed_step_loop`) + and the high-level K/reset suite + (`mixed_population_k_gt1_terminal_loss_maps_final_outer_emission_gradient`, + `mixed_population_k_gt1_terminal_loss_propagates_provided_state_gradients`, + `mixed_population_k128_backward_maps_outer_emission_gradients`, and single / + mixed temporal reset parity; `7 passed`). 
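+- Illustrative sketch (hypothetical, not repository code):
+  a host-side Python reference of the per-step source selection that the
+  rejected probe moved into the device loop. The names, the `source_offset`
+  selection rule, and the mask layout are guesses for illustration only; the
+  accepted path keeps materializing the recurrent K/V "before" window and
+  applying reset with a window-level `torch.where`.
+
+  ```python
+  import torch
+
+
+  def kv_before_at_step(
+      step: int,                      # local step inside the H window
+      source_offset: int,             # first local step served by the checkpoint bank
+      kv_initial: torch.Tensor,       # [B, R, d] bank at the window-initial state
+      kv_after_window: torch.Tensor,  # [B, R, d] bank from the replay checkpoint
+      reset_mask: torch.Tensor,       # [B, R] bool, true where this step resets messages
+  ) -> torch.Tensor:
+      source = kv_initial if step < source_offset else kv_after_window
+      # Message-reset zeroing: a reset receiver sees zero recurrent K/V "before".
+      return source.masked_fill(reset_mask.unsqueeze(-1), 0.0)
+  ```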
+- Performance rejection evidence: + first source attempt: + `/tmp/redo_fixmaass_kv_source_t64_k128_h64_notiming` reached + `15.431 tok/s`, peak `0.123784 GiB`, and + `/tmp/redo_fixmaass_kv_source_t1_b1024_mixed_notiming` reached + `69,444 tok/s`. Pointer-hoisted attempt: + `/tmp/redo_fixmaass_kv_source_opt_t64_k128_h64_notiming` reached + `15.462 tok/s`, peak `0.123784 GiB`, and + `/tmp/redo_fixmaass_kv_source_opt_t1_b1024_mixed_notiming` regressed T=1 to + `67,678 tok/s`. The previous accepted hidden-before slice had + `/tmp/redo_fixmaass_hidden_source_t64_k128_h64_notiming` at + `15.571 tok/s`, peak `0.126713 GiB`, and + `/tmp/redo_fixmaass_hidden_source_t1_b1024_mixed_notiming` at + `70,776 tok/s`. The memory reduction was too small to justify a T=1/T*K + throughput regression. +- Decision: + reject the recurrent K/V source-bank active path and keep the current + materialized recurrent K/V-before windows for now. The next R4/R13 owner + should target a larger replay/window orchestration reduction or a fused path + that does not add per-edge hot-loop overhead. + +### 2026-04-29 UTC - R4 active reverse replay artifact sparsening + +Status: ACCEPTED for this substage. R4/R13 remain open until the reverse +replay/window shell and `python_autograd_scan` ownership are removed. + +Owner: shared multi-pop flat-bucket temporal reverse replay/window shell. + +- Current invariant: + active reverse-only windows should be table-owned over flat bucket identity + and tensor windows. The Python shell may pass high-level tensors and window + descriptors while this stage is transitional, but it must not materialize + per-physical-step replay tensors that the reverse table never reads. +- Current measured reason: + after hidden-before elision and rejected recurrent K/V source tables, the + dominant open owners remain `temporal_artifact_recompute`, + `artifact.recompute.cuda_temporal_replay_scan`, and + `transition_message_reverse_table_device_loop`. The recompute path still + iterates over every physical step in an H window and pulls recurrent hidden, + recurrent K/V, and recurrent-message per-step views even for inactive + non-emission K microsteps. In active reverse-only terminal/per-timestep rows, + the reverse table consumes `TemporalReverseWindowTables`; only emitted output + steps need full per-step output-message artifacts for output backward. +- Planned implementation: + keep the active reverse-only table ABI unchanged, but sparsify replay + artifact materialization inside `_try_cuda_mixed_flat_bucket_recompute_artifact_window`. + Non-emitted active reverse-only physical steps should carry minimal placeholder + tensors for recurrent hidden/K/V/message fields, while emitted output steps + keep the real tensors needed by `run_temporal_output_backward_sequence`. + The reverse engine must continue consuming the same window tables, not the + per-step placeholders. +- Manual boundary review: + this edit changes only replay artifact materialization policy. It adds no + cell-kind selector, population-name selector, benchmark-row selector, + hidden-size policy key, single/mixed route selector, lattice/config field, or + cell-family parameter bundle. The decision is driven by temporal emission + metadata and the existing active reverse-only table ownership. +- Implementation: + active reverse-only recompute now records + `temporal_backward_glue:cuda_recompute_active_reverse_sparse_step_artifacts` + and carries minimal placeholder recurrent hidden/K/V/message tensors for + non-emitted K microsteps. 
Emitted output steps still materialize the real + recurrent hidden/K/V tensors needed by output backward. The reverse table + continues to consume `TemporalReverseWindowTables`; per-step placeholder + tensors are not execution inputs to the reverse engine. +- Validation: + `py_compile`, Ruff check, Ruff format, `git diff --check`, and + `tests/test_fabric_backend_boundaries.py` passed. CUDA parity passed for the + high-level K>1/K=128 mixed backward suite (`3 passed`), single/mixed reset + parity (`4 passed`), and T=1/K=1 mixed reverse-device-loop tests (`2 passed`) + using GPUs 0-2 and private extension/cache directories. +- High-level audit evidence: + `/tmp/redo_fixmaass_sparse_artifacts_t64_k128_h64` ran the high-level Fabric + API mixed-population `T=64`, `K=128`, `H=64`, `B=1`, terminal training row on + GPU4. Result: `15.742 tok/s`, `4065.451 ms`, peak `0.124505 GiB`; matched + current-code T=1 training was `71.667 tok/s`, K-adjusted floor was + `0.5599 tok/s`, and the April 21/K-adjusted gate passed at `28.12x`. The + row records the new sparse artifact tag plus the existing reverse-table + ownership tags. + T=1 mixed `B=1024` reruns on GPU3 with the same private cache reached + `69,730 tok/s` and `64,425 tok/s`, both passing the April 21 reference gate + with peak `0.673192 GiB` and mixed stack ratios above `3.55x`. The first + concurrent T=1 run at `57,471 tok/s` is treated as startup/noise and is not + closure evidence. + Owner-timed `/tmp/redo_fixmaass_sparse_artifacts_t64_k128_h64_owner` passed + at `14.947 tok/s`, but still showed the same dominant open owners: + `transition_message_reverse_table_device_loop:ms=1763.981;count=128`, + `temporal_artifact_recompute:ms=1351.822;count=128`, and + `artifact.recompute.cuda_temporal_replay_scan:ms=1132.284;count=128`. +- Decision: + accept this as a small replay-artifact shell reduction because it reduces + active reverse-only per-step artifact materialization, slightly improves the + T*K non-timed row, and lowers T*K peak memory from the prior hidden-source + `0.126713 GiB` to `0.124505 GiB` without breaking T=1 gates. Do not count it + as R4/R13 closure: the warmed dominant owners and planner metadata still + report `python_autograd_scan`. The next owner remains a larger fused + replay/window orchestration or reverse-table/superop reduction that moves + `temporal_artifact_recompute` or `transition_message_reverse_table_device_loop`. + +### 2026-04-29 UTC - R4 compact reverse hidden-carry materialization + +Status: ACCEPTED AS A SMALL GENERIC REVERSE-TABLE MATERIALIZATION TRIM; R4 +REMAINS OPEN. + +Owner: shared multi-pop flat-bucket temporal reverse table output +materialization. + +- Current invariant: + the reverse table owns the H-window dependency scan over flat bucket identity + and tensor-table rows. Runtime glue should ask the table to materialize only + tensors that are consumed by the public autograd boundary; debug/parity + helpers may request fuller windows, but active training rows should not pay + for unused per-step outputs. +- Current measured reason: + the accepted active path still reports + `transition_message_reverse_table_device_loop` as the largest owner. Inside + that table call, `grad_hidden_message_window` is materialized for every + physical step even though the active runtime consumes only + `grad_hidden_message_window[0]` as the carry into the preceding H window. +- Planned implementation: + extend the existing reverse tensor-table ABI with a generic + hidden-carry materialization flag. 
The default remains full-window + materialization for low-level parity tests; active temporal backward calls + request compact carry materialization and the kernel writes only the first + carry slice. +- Manual boundary review: + this ABI change is a table materialization policy. It adds no cell-kind + selector, population-name selector, benchmark-row selector, hidden-size policy + key, single/mixed route selector, lattice/config field, or cell-family + parameter bundle. The inputs remain flat tensor-table roles, op rows, reset + policy, and materialization policy. +- Implementation: + `try_transition_message_reverse_table_window_cuda` now accepts + `materialize_hidden_message_window` with the default set to full-window + materialization. The active reverse-only temporal backward path passes + `False`, records + `temporal_backward_glue:cuda_transition_message_reverse_compact_hidden_carry`, + and the CUDA reverse table allocates/writes only the first hidden-carry slice + that is consumed as the carry into the preceding H window. Low-level parity + callers keep the full `[H, B, R, hidden]` output by default. +- Validation: + `py_compile`, Ruff check, Ruff format check, `git diff --check`, and + `tests/test_fabric_backend_boundaries.py` passed. CUDA validation passed for + `test_fabric_cuda_transition_message_reverse_table_window_matches_mixed_step_loop` + on GPU0, the high-level mixed K>1/K=128 terminal gradient suite on GPU0 + (`3 passed`), reset parity on GPU1 (`4 passed`), and the T=1/K=1 mixed + reverse-device-loop tests on GPU2 (`2 passed`), all using private extension + and Triton cache directories. +- High-level audit evidence: + owner-timed `/tmp/redo_fixmaass_compact_carry_t64_k128_h64_owner` passed the + canonical high-level Fabric API `T=64`, `K=128`, `H=64`, `B=1`, mixed, + terminal training row at `15.389 tok/s`, peak `0.124505 GiB`, with matched + current-code T=1 training `70.854 tok/s` and K-adjusted floor + `0.5535 tok/s`. The target reverse owner moved only slightly from the prior + accepted `transition_message_reverse_table_device_loop:ms=1763.981;count=128` + to `1759.558;count=128`, so this is not a major R4 close. The same timed row + showed `temporal_artifact_recompute:ms=1180.382;count=128` and + `artifact.recompute.cuda_temporal_replay_scan:ms=1128.981;count=128`. + Non-timed `/tmp/redo_fixmaass_compact_carry_t64_k128_h64_rerun2` reached + `15.786 tok/s`, peak `0.124505 GiB`, matched current-code T=1 training + `77.593 tok/s`, K-adjusted floor `0.6062 tok/s`, and gate pass. Earlier + compact-carry non-timed `/tmp/redo_fixmaass_compact_carry_t64_k128_h64` + reached `15.509 tok/s`; keep the warmed rerun as the accepted comparison. + T=1 mixed `B=1024` checks passed at `60,932 tok/s`, `61,813 tok/s`, and the + warmed `/tmp/redo_fixmaass_compact_carry_t1_b1024_mixed_rerun2` at + `65,401 tok/s`, all with peak `0.673192 GiB` and April 21 gates passing. +- Decision: + accept the compact hidden-carry materialization flag as a small generic + reverse-table output trim because it is parity-clean, preserves the high-level + API contract, records the active compact materialization tag, does not add + cell/population/benchmark selectors, and the warmed T*K/T=1 rows stay within + the accepted current-code envelope. 
Do not count this as R4/R13 closure: the + active owner is still the reverse/replay H-window shell, planner metadata + still reports `python_autograd_scan`, and the largest measured owners remain + `transition_message_reverse_table_device_loop`, + `temporal_artifact_recompute`, and + `artifact.recompute.cuda_temporal_replay_scan`. + +### 2026-04-29 UTC - R4 active reverse replay final-state elision + +Status: ACCEPTED AS A SMALL GENERIC REPLAY MATERIALIZATION TRIM; R4/R13 +REMAIN OPEN. + +Owner: shared multi-pop flat-bucket temporal replay materialization policy. + +- Current invariant: + active reverse-only replay should request only the CUDA scan outputs consumed + by the reverse table or by emitted output-step backward. Checkpoint tensors, + recurrent-message artifacts, transition tape tensors, reset policy, and + reverse-window tables remain required; final recurrent state tensors are + required only when the final physical step in the replay window is also an + emitted output step whose output backward needs those tensors. +- Current measured reason: + after compact hidden-carry materialization, the T64/K128/H64 owner profile + still reports `temporal_artifact_recompute` and + `artifact.recompute.cuda_temporal_replay_scan` as major open owners. The scan + currently passes `return_final_state=True` for every active reverse-only + replay window, even when the final recurrent hidden/K/V/state tensors are not + consumed by the reverse table. +- Planned implementation: + compute a generic `replay_needs_final_state` flag from materialization mode + and temporal emission metadata. Preserve `return_final_state=True` for normal + replay paths and for active reverse-only windows whose final physical step is + an emitted output-artifact step. Pass `False` only for active reverse-only + windows where final recurrent state would be dead materialization. +- Manual boundary review: + this edit is a scan materialization policy over temporal emission metadata + and active reverse table ownership. It adds no cell-kind selector, + population-name selector, benchmark-row selector, hidden-size policy key, + single/mixed route selector, lattice/config field, or cell-family parameter + bundle. ABI inputs remain flat graph/message tables, tensor-table roles, + primitive rows, reset policy, checkpoint policy, and materialization policy. +- Acceptance gate: + static checks, backend boundary tests, low-level reverse-table parity, + high-level K>1/K128 terminal-gradient coverage, reset parity, T=1/K=1 mixed + training parity, and current-code high-level audit evidence for both + T64/K128/H64 and T=1/B1024 mixed. Accept only if the warmed current-code rows + stay inside the accepted envelope and the active replay materialization owner + moves or records the final-state-elision tag without a semantic regression. +- Implementation: + `_try_cuda_mixed_flat_bucket_recompute_artifact_window` now computes + `replay_needs_final_state` from the active reverse-only materialization mode + and the existing output-artifact physical-step set. Normal replay paths still + pass `return_final_state=True`. Active reverse-only windows pass + `return_final_state=False` only when the final physical step in the replay + window is not an output-artifact step. Those windows record + `temporal_backward_glue:cuda_recompute_active_reverse_final_state_elided`. +- Validation: + `py_compile`, Ruff check, Ruff format check, `git diff --check`, and + `tests/test_fabric_backend_boundaries.py` passed. 
CUDA validation passed for + low-level mixed reverse-table parity on GPU0 (`1 passed`), the high-level + mixed K>1/K128 terminal-gradient suite on GPU1 (`3 passed`), single/mixed + reset parity on GPU2 (`4 passed`), and T=1/K=1 mixed reverse-device-loop + parity on GPU3 (`2 passed`), all with private extension and Triton cache + directories. +- High-level audit evidence: + owner-timed `/tmp/redo_fixmaass_final_state_elision_t64_k128_h64_owner` + passed the high-level Fabric API `T=64`, `K=128`, `H=64`, `B=1`, mixed, + terminal training row at `15.432 tok/s`, peak `0.124436 GiB`, matched T=1 + training `47.697 tok/s`, and K-adjusted floor `0.3726 tok/s`. It recorded + the new final-state-elision tag. Owner timings stayed essentially unchanged: + `transition_message_reverse_table_device_loop:ms=1757.289;count=128`, + `temporal_artifact_recompute:ms=1181.814;count=128`, and + `artifact.recompute.cuda_temporal_replay_scan:ms=1130.240;count=128`. + The first non-timed T64 row + `/tmp/redo_fixmaass_final_state_elision_t64_k128_h64` reached + `15.359 tok/s`; the same-cache warmed rerun + `/tmp/redo_fixmaass_final_state_elision_t64_k128_h64_rerun` reached + `15.764 tok/s`, peak `0.124436 GiB`, matched T=1 training + `77.177 tok/s`, K-adjusted floor `0.6029 tok/s`, and gate pass. This is + effectively flat with the previous compact-carry warmed comparison + (`15.786 tok/s`) while trimming peak memory from `0.124505 GiB`. + T=1 mixed `B=1024` guard + `/tmp/redo_fixmaass_final_state_elision_t1_b1024_mixed` reached + `67,506.816 tok/s`, peak `0.673192 GiB`, and passed the April 21 reference + gate. +- Decision: + accept this as a small generic scan materialization trim. It is parity-clean, + preserves high-level API behavior, records the active elision tag, keeps + T64/K128/H64 and T=1/B1024 within the accepted current-code envelope, and + slightly lowers peak memory. Do not count it as R4/R13 closure: the measured + replay/reverse owners remain first-order, and planner metadata still reports + `temporal_plan_forward_owners=['python_autograd_scan']` and + `temporal_plan_backward_owners=['python_autograd_scan']`. + +### 2026-04-29 UTC - R4 diagonal trace checkpoint materialization policy + +Status: REJECTED; CODE ROLLED BACK. + +Owner: shared multi-pop flat-bucket temporal replay scan materialization. + +- Current invariant: + `materialize_diagonal_trace_state=False` means the CUDA scan must not allocate + or write diagonal eligibility-trace state surfaces for replay windows that do + not expose or consume those traces. Active reverse-only replay consumes + diagonal `hc1/hc2` state-before windows and transition/message tape tensors, + not diagonal trace checkpoint tensors. +- Current measured reason: + after final-state elision, active reverse-only replay still asks the temporal + scan for checkpoints with stride one. The scan already suppresses trace + tracking/writes when `materialize_diagonal_trace_state=False`, but it still + allocates the eight diagonal trace checkpoint tensors whenever checkpoint + tensors are returned. That is dead materialization for the active reverse + table path. +- Planned implementation: + make the scan allocate diagonal trace checkpoint tensors only when both + checkpoints are requested and `track_diagonal_traces` is true. Preserve the + fixed output ABI by returning the existing empty tensor placeholders when + trace checkpoint materialization is disabled. 
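+
+A minimal sketch of the tried allocation guard, with hypothetical names (the
+real change sat in the CUDA scan binding and was rolled back):
+
+```python
+import torch
+
+
+def allocate_diagonal_trace_checkpoints(
+    return_checkpoints: bool,
+    track_diagonal_traces: bool,
+    checkpoint_count: int,
+    trace_shape: tuple[int, ...],
+    device: torch.device,
+    dtype: torch.dtype,
+    num_trace_surfaces: int = 8,
+) -> list[torch.Tensor]:
+    """Allocate trace checkpoint surfaces only when they can actually be read."""
+    if return_checkpoints and track_diagonal_traces:
+        return [
+            torch.empty((checkpoint_count, *trace_shape), device=device, dtype=dtype)
+            for _ in range(num_trace_surfaces)
+        ]
+    # Fixed output ABI: return the same number of tensors, but as zero-sized
+    # placeholders instead of fully materialized trace checkpoint surfaces.
+    return [
+        torch.empty((0, *trace_shape), device=device, dtype=dtype)
+        for _ in range(num_trace_surfaces)
+    ]
+```
+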
+- Manual boundary review: + this is a primitive extra-state materialization policy inside the shared scan + binding. It adds no cell-kind selector, population-name selector, + benchmark-row selector, hidden-size policy key, single/mixed route selector, + lattice/config field, or cell-family parameter bundle. The policy is driven + by the existing tensor-table materialization flag and checkpoint policy. +- Validation before rejection: + static checks passed (`py_compile`, Ruff check, Ruff format check, + `git diff --check`, and `tests/test_fabric_backend_boundaries.py`). CUDA + parity passed for default final-state temporal-superop rows on GPU0 + (`4 passed`), the high-level K>1/K128 mixed backward suite on GPU1 + (`3 passed`), single/mixed reset parity on GPU2 (`4 passed`), T=1/K=1 mixed + training parity on GPU3 (`2 passed`), and low-level reverse-table parity on + GPU4 (`1 passed`), all with private extension and Triton cache directories. +- Rejection evidence: + the probe did reduce T64/K128/H64 peak memory from `0.124436 GiB` to + `0.121552 GiB`, but throughput did not stay within the accepted active + envelope. Owner-timed + `/tmp/redo_fixmaass_trace_checkpoint_t64_k128_h64_owner` reported + `15.214 tok/s`, `transition_message_reverse_table_device_loop:ms=1759.662`, + `temporal_artifact_recompute:ms=1184.570`, and + `artifact.recompute.cuda_temporal_replay_scan:ms=1130.699`. Non-timed + `/tmp/redo_fixmaass_trace_checkpoint_t64_k128_h64` reported + `15.522 tok/s`, and same-cache rerun + `/tmp/redo_fixmaass_trace_checkpoint_t64_k128_h64_rerun` reported + `15.556 tok/s`. The previous accepted final-state-elision warmed row was + `15.764 tok/s`, and the compact-carry warmed comparison before that was + `15.786 tok/s`. T=1/B1024 mixed remained healthy at + `/tmp/redo_fixmaass_trace_checkpoint_t1_b1024_mixed`, `70,100 tok/s`, peak + `0.673192 GiB`, but the T*K throughput regression is not acceptable. +- Decision: + reject and roll back the kernel allocation change. It was generic and + parity-clean, but memory-only improvement does not justify a repeated + throughput drop in the active T64/K128/H64 row. Do not retry this exact + checkpoint-trace allocation elision unless a larger replay-scan restructuring + can recover or improve throughput while preserving the memory reduction. + +### 2026-04-29 UTC - R4 compact hidden-carry copy-loop elision + +Status: ACCEPTED AS A SMALL GENERIC REVERSE-TABLE LOOP TRIM; R4/R13 REMAIN +OPEN. + +Owner: active shared flat-bucket reverse table device loop. + +- Current invariant: + reverse-table materialization policy should remove active device work, not + only shrink returned tensors. When `materialize_hidden_message_window=False`, + the active runtime consumes only the first hidden-message carry slice for the + preceding H window. +- Current measured reason: + compact hidden-carry materialization shrank the returned + `grad_hidden_message_window`, but the CUDA device loop still iterates over + the `B * R * H` hidden-copy loop at every reverse timestep and branches inside + that loop. For compact carry mode, all timesteps except `t == 0` do no useful + hidden-message materialization work. +- Planned implementation: + keep the existing synchronization boundary, but run the hidden-message copy + loop only when `materialize_hidden_message_window` is true or `t == 0`. + Default full-window parity behavior remains unchanged; compact active + reverse-only windows skip the empty copy loop for nonzero `t`. 
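+
+The planned guard is small; as a host-level restatement with hypothetical
+names (the real edit is inside the CUDA device-loop kernel):
+
+```python
+def should_run_hidden_copy_loop(
+    t: int, materialize_hidden_message_window: bool
+) -> bool:
+    # Full-window parity callers copy the hidden-message slice at every step;
+    # compact carry mode only needs the t == 0 slice consumed by the preceding
+    # H window, so all other steps skip the B * R * hidden copy loop entirely.
+    return materialize_hidden_message_window or t == 0
+```
+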
+- Manual boundary review: + this edit is a generic reverse-table materialization-policy cleanup. It adds + no cell-kind selector, population-name selector, benchmark-row selector, + hidden-size policy key, single/mixed route selector, lattice/config field, or + cell-family parameter bundle. ABI inputs remain the existing flat tensor-table + roles, op rows, reset policy, and materialization flag. +- Implementation: + `transition_message_reverse_table_device_loop_kernel` now enters the + hidden-message carry copy loop only when full hidden-message materialization + is requested or when `t == 0` supplies the compact carry slice. The grid sync + remains in place, so the reverse dependency scan ordering is unchanged. +- Validation: + `git diff --check` and `tests/test_fabric_backend_boundaries.py` passed. + CUDA validation passed for low-level mixed reverse-table parity on GPU0 + (`1 passed`), the high-level mixed K>1/K128 terminal-gradient suite on GPU1 + (`3 passed`), single/mixed reset parity on GPU2 (`4 passed`), and T=1/K=1 + mixed reverse-device-loop parity on GPU3 (`2 passed`), all with private + extension and Triton cache directories. +- High-level audit evidence: + owner-timed `/tmp/redo_fixmaass_hidden_copy_elide_t64_k128_h64_owner` + passed the high-level Fabric API `T=64`, `K=128`, `H=64`, `B=1`, mixed, + terminal training row at `15.444 tok/s`, peak `0.124436 GiB`, matched T=1 + training `71.778 tok/s`, and K-adjusted floor `0.5608 tok/s`. The measured + reverse-table owner moved from the previous accepted final-state-elision + `transition_message_reverse_table_device_loop:ms=1757.289;count=128` to + `1746.779;count=128`; replay stayed comparable at + `artifact.recompute.cuda_temporal_replay_scan:ms=1128.542;count=128`. + First non-timed `/tmp/redo_fixmaass_hidden_copy_elide_t64_k128_h64` was + noisy at `15.400 tok/s`; the same-cache warmed rerun + `/tmp/redo_fixmaass_hidden_copy_elide_t64_k128_h64_rerun` reached + `15.789 tok/s`, peak `0.124436 GiB`, matched T=1 training + `75.873 tok/s`, K-adjusted floor `0.5928 tok/s`, and gate pass. T=1 mixed + `B=1024` guard `/tmp/redo_fixmaass_hidden_copy_elide_t1_b1024_mixed` + reached `68,269.863 tok/s`, peak `0.673192 GiB`, and passed the April 21 + reference gate. +- Decision: + accept this kernel trim because it removes real no-op per-timestep work in + compact carry mode, improves the measured reverse-table owner, and keeps the + warmed T64/K128/H64 and T=1/B1024 rows inside the accepted current-code + envelope. Do not count it as R4/R13 closure: the active path still has + separate replay and reverse owners, and planner metadata still reports + `python_autograd_scan` for forward and backward temporal owners. + +### 2026-04-29 UTC - R4 inactive-window boundary tensor fast path + +Status: ACCEPTED AS GENERIC PYTHON-SHELL REDUCTION; R4/R13 REMAIN OPEN. + +Owner: active shared flat-bucket reverse/window Python shell. + +- Current invariant: + once the reverse table has emitted dense physical-step input K/V gradient + windows, inactive terminal-output windows should not rebuild those tensors as + per-step Python `TemporalBoundaryBackwardStep` objects. Boundary projection + backward is still a generic declared input-adapter primitive; this slice only + changes how the already-table-owned tensor windows are handed to it. +- Current measured reason: + the active R4 owner still includes Python replay/window shell work around the + CUDA reverse table. In terminal T*K rows with `H=64`, many backward windows + have no direct output step. 
Those windows currently allocate a Python object + per physical step and then restack the same input-gradient tensors before the + boundary projection backward. +- Planned implementation: + add a tensor-window boundary backward helper that accepts `[H,B,input,D]` + input K/V gradient windows directly. Use it only for active reverse-table + windows with no output-backward sequence. Output-active windows and fallback + paths keep the existing per-step path. +- Manual boundary review: + this is a host-shell reduction around a generic boundary projection + primitive. It adds no cell-kind selector, population-name selector, + benchmark-row selector, hidden-size policy key, single/mixed route selector, + lattice/config field, or cell-family parameter bundle. The trigger is the + existing active reverse-window table plus absence of output-backward + gradients, not a family or shape condition. +- Implementation: + added `run_temporal_boundary_backward_tensor_window`, which flattens the + reverse table's already-dense `[H,B,input,D]` K/V gradient windows directly + into the generic boundary projection backward primitive. Active reverse-table + windows with no output-backward sequence use this tensor-window handoff and + record `cuda_boundary_backward_tensor_window_inputs`; output-active windows + and non-table fallback paths keep the existing per-step accumulation path. +- Validation: + `py_compile`, Ruff check, Ruff format check, and `git diff --check` passed. + `tests/test_fabric_backend_boundaries.py` passed as `2 passed`. GPU 0 + low-level mixed reverse-table parity passed as `1 passed`; GPU 1 K>1/K128 + parity passed as `3 passed`; GPU 2 reset parity passed as `4 passed`; GPU 3 + T=1/K=1 mixed training parity passed as `2 passed`. All GPU rows used + private Torch extension and Triton cache dirs on GPUs 0-3. +- Audit evidence: + owner-timed T64/K128/H64 terminal mixed row + `/tmp/redo_fixmaass_boundary_tensor_window_t64_k128_h64_owner` passed with + `15.455 tok/s`, peak `0.124436 GiB`, matched T=1 training + `69.777 tok/s`, K-adjusted floor `0.5451 tok/s`, reverse-table owner + `1747.066 ms/count=128`, artifact recompute `1182.398 ms/count=128`, and + CUDA replay scan `1129.930 ms/count=128`. + First non-timed T64/K128/H64 confirmation + `/tmp/redo_fixmaass_boundary_tensor_window_t64_k128_h64` was noisy at + `15.416 tok/s`; same-cache warmed rerun + `/tmp/redo_fixmaass_boundary_tensor_window_t64_k128_h64_rerun` reached + `15.809 tok/s`, peak `0.124436 GiB`, matched T=1 training + `75.783 tok/s`, and K-adjusted floor `0.5921 tok/s`. +- T=1 guard: + patched T=1/B1024 mixed terminal guard + `/tmp/redo_fixmaass_boundary_tensor_window_t1_b1024_mixed_gpu4` passed at + `64,835.109 tok/s`, peak `0.673192 GiB`, and mixed-stack ratio `3.681x`. + This was below the earlier historical accepted `68k` guard, so a temporary + same-session baseline worktree at `ece849c` was run with the same command: + `/tmp/redo_fixmaass_boundary_tensor_window_baseline_t1_b1024_mixed` reached + `65,116.696 tok/s`, peak `0.673192 GiB`. The T=1 path did not record + `cuda_boundary_backward_tensor_window_inputs`, so the lower absolute value is + treated as session noise rather than a patch-specific regression. +- Decision: + accept this substage because it removes Python per-step boundary-step object + construction on inactive terminal-output H windows, keeps the T64/K128/H64 + warmed row at or above the current accepted envelope, and is flat against the + same-session T=1 baseline. 
This does not close R4/R13: planner-level forward + and backward owners still report `python_autograd_scan`, and the larger + replay/reverse window shell remains open. + +### 2026-04-29 UTC - R4 backend-order recurrent carry across H windows + +Status: ACCEPTED AS BACKEND-CARRY SHELL REDUCTION; R4/R13 REMAIN OPEN. + +Owner: active shared flat-bucket reverse/window carry materialization shell. + +- Current invariant: + H-window backward carry between adjacent reverse windows is a backend-order + recurrent-public gradient. The active CUDA reverse table emits that tensor in + backend recurrent order. Materializing it into a full graph-order `cells` + gradient after every H window, then slicing and reordering it back into + backend order for the next window, is host/window-shell ownership that should + be removed. +- Current measured reason: + the active launch list still records + `cuda_recurrent_state_grad_materialize_window_boundary` and the T64/K128/H64 + owner profile continues to spend most time in the reverse table plus artifact + replay. Removing per-window carry materialization should reduce one repeated + shell operation without changing transition/message math. +- Planned implementation: + extend the temporal backward window result with a generic + `grad_carry_recurrent_hidden_backend` tensor. The active reverse-table path + will pass `grad_hidden_message_window[0]` to the previous H window directly. + The scan executor will materialize a full graph-order `cells` gradient only + at the final autograd state boundary or before any host fallback that cannot + consume backend-order carry. +- Manual boundary review: + ABI inputs and outputs are flat recurrent-bank tensors, flat bucket identity, + tensor-table roles, reset/materialization policy, and the existing runtime + recurrent graph/backend order maps. This adds no cell-kind selector, + population-name selector, benchmark-row selector, hidden-size policy key, + single/mixed route selector, lattice/config field, or cell-family parameter + bundle. +- Implementation: + extended the temporal backward window result with + `grad_carry_recurrent_hidden_backend`. Active reverse-table windows now carry + `grad_hidden_message_window[0]` directly in backend recurrent order across H + windows. The next window consumes that tensor without materializing full + graph-order cells. Full `cells` gradients are materialized only at the final + autograd state boundary, or before any host fallback that cannot consume + backend-order carry. +- Validation: + `py_compile`, Ruff check, Ruff format check, and `git diff --check` passed. + `tests/test_fabric_backend_boundaries.py` passed as `2 passed`. GPU 0 + low-level mixed reverse-table parity passed as `1 passed`; GPU 1 K>1/K128 + parity passed as `3 passed`; GPU 2 reset parity passed as `4 passed`; GPU 3 + T=1/K=1 mixed training parity passed as `2 passed`. All GPU rows used + private Torch extension and Triton cache dirs on GPUs 0-3. +- Audit evidence: + owner-timed T64/K128/H64 terminal mixed row + `/tmp/redo_fixmaass_backend_carry_t64_k128_h64_owner` passed at + `15.432 tok/s`, peak `0.124428 GiB`, matched T=1 training + `68.238 tok/s`, K-adjusted floor `0.5331 tok/s`, reverse-table owner + `1751.111 ms/count=128`, artifact recompute `1181.634 ms/count=128`, and + replay scan `1128.687 ms/count=128`. The active launch list records + `cuda_recurrent_state_grad_materialize_final_boundary` and no longer records + the per-window `cuda_recurrent_state_grad_materialize_window_boundary` tag on + this T64 row. 
+ First non-timed T64/K128/H64 confirmation + `/tmp/redo_fixmaass_backend_carry_t64_k128_h64` was low at + `15.418 tok/s`; same-cache reruns reached `15.775 tok/s` + (`/tmp/redo_fixmaass_backend_carry_t64_k128_h64_rerun`) and `15.802 tok/s` + (`/tmp/redo_fixmaass_backend_carry_t64_k128_h64_rerun2`), both with peak + `0.124428 GiB`. +- Current-code baseline check: + because the warmed T64 row was slightly below the prior `15.809 tok/s` + boundary-fast-path run, a temporary same-session baseline worktree at + `1461a9e` was run with the same T64/K128/H64 command: + `/tmp/redo_fixmaass_backend_carry_baseline_t64_k128_h64` reached + `15.809 tok/s`, peak `0.124436 GiB`, and still recorded + `cuda_recurrent_state_grad_materialize_window_boundary`. The current slice is + therefore effectively flat within measurement noise while reducing repeated + shell materialization and peak memory slightly; it is not a material + throughput win. +- T=1 guard: + `/tmp/redo_fixmaass_backend_carry_t1_b1024_mixed` passed at + `64,879.368 tok/s`, peak `0.673192 GiB`, and mixed-stack ratio `3.419x`, + consistent with the same-session T=1 baseline band from the previous + substage. +- Decision: + accept this as a backend-carry ownership reduction because it removes the + repeated full-cells carry materialization between H windows while preserving + parity, high-level API behavior, April 21 gates, and same-session throughput. + Do not count it as R4/R13 closure: dominant measured owners remain the + transition/message reverse table and artifact replay, and planner-level + forward/backward owners still report `python_autograd_scan`. + +### 2026-04-29 UTC - R13 compact boundary projection for T*K windows + +Status: ACCEPTED AS A GENERIC T*K BOUNDARY-ADAPTER REDUCTION; R4/R13 +REMAIN OPEN. + +Owner: active shared temporal backward boundary-adapter shell for T*K rows. + +- Current invariant: + K internal steps are streamed physical steps over the same graph, but external + boundary inputs are outer-step tensors. If multiple physical steps in an + active reverse-only H window refer to the same outer boundary row, the + boundary projection backward should receive the summed K/V gradient for that + outer row once. Repeating the same boundary projection backward per physical + microstep is host/window-shell work, not cell semantics. +- Planned implementation: + add a compact tensor-window boundary backward path for active reverse-table + windows with no output-backward sequence. It groups physical local steps by + `scalar_temporal_scan_step(...).outer_step`, sums `grad_input_k_window` and + `grad_input_v_window` inside each group, runs the existing generic boundary + projection backward on the compact boundary rows, and places the resulting + boundary gradient on one representative physical step for the existing + boundary-gradient accumulator. Output-active and fallback paths keep the + existing full physical-step boundary sequence. +- Manual boundary review: + this edit is a generic boundary input-adapter reduction over temporal + schedule metadata and tensor-window gradients. It adds no cell-kind selector, + population-name selector, benchmark-row selector, hidden-size policy key, + single/mixed route selector, lattice/config field, or cell-family parameter + bundle. The trigger is active reverse-window tensor tables plus repeated + outer boundary identity, not a family or graph-shape branch. +- Implementation: + added `run_temporal_boundary_backward_compact_outer_window`. 
Active + reverse-table windows with no output-backward sequence now sum K/V input + gradients over repeated outer boundary rows before invoking the existing + generic boundary projection backward. The resulting boundary gradient is + assigned to a representative physical step and then consumed by the existing + CUDA boundary-gradient accumulator. K=1 windows fall back to the prior dense + tensor-window path. +- Validation: + `py_compile`, Ruff check, Ruff format check, `git diff --check`, and + `tests/test_fabric_backend_boundaries.py` passed. GPU 1 K>1/K128 parity + passed as `3 passed`; GPU 2 reset parity passed as `4 passed`; GPU 3 T=1/K=1 + mixed training parity passed as `2 passed`. The first GPU 0 low-level + reverse-table parity run missed by one `3.5e-4` recurrent K/V weight entry in + untouched kernel code; a fresh-cache rerun passed as `1 passed`, so no + tolerance was relaxed and the miss is treated as nondeterministic floating + accumulation noise. +- Audit evidence: + owner-timed T64/K128/H64 terminal mixed row + `/tmp/redo_fixmaass_compact_boundary_t64_k128_h64_owner` passed at + `15.368 tok/s`, peak `0.124188 GiB`, matched T=1 training + `71.353 tok/s`, K-adjusted floor `0.5574 tok/s`, reverse-table owner + `1748.061 ms/count=128`, artifact recompute `1183.266 ms/count=128`, and + replay scan `1130.768 ms/count=128`. The active launch list records + `cuda_boundary_backward_compact_outer_window_inputs`. + Same-cache non-timed confirmations + `/tmp/redo_fixmaass_compact_boundary_t64_k128_h64_rerun` and + `/tmp/redo_fixmaass_compact_boundary_t64_k128_h64_rerun2` reached + `15.703 tok/s` and `15.779 tok/s`, both with peak `0.124188 GiB`; the second + rerun matched T=1 training at `77.874 tok/s` and passed the K-adjusted gate. +- T=1 guard: + `/tmp/redo_fixmaass_compact_boundary_t1_b1024_mixed` passed at + `61,167.529 tok/s`, peak `0.673192 GiB`, and mixed-stack ratio `3.468x`; the + same-cache warmed rerun + `/tmp/redo_fixmaass_compact_boundary_t1_b1024_mixed_rerun` reached + `64,636.681 tok/s`, peak `0.673192 GiB`, and mixed-stack ratio `3.507x`. + The compact boundary tag is absent from the T=1 row, as expected. +- Decision: + accept this as a bounded T*K shell reduction because it removes repeated + boundary projection work for K windows, is parity-clean, keeps warmed + T64/K128/H64 and T=1/B1024 throughput inside the current-code envelope, and + slightly reduces T64 peak memory. This does not close R4/R13: the dominant + reverse/replay owners remain, and planner-level owners still report + `python_autograd_scan`. + +### 2026-04-29 UTC - R13 replay input K/V one-window cache + +Status: ACCEPTED AS A BOUNDED GENERIC REPLAY-SHELL REDUCTION; R4/R13 +REMAIN OPEN. + +Owner: active shared temporal artifact replay shell for T*K rows. + +- Current invariant: + input boundary K/V projection is an outer-boundary adapter value. Adjacent H + recompute windows can refer to the same outer input slice when `H < K`; the + replay shell should not reproject that same outer boundary slice for each + physical H window. +- Planned implementation: + add a bounded one-entry replay input K/V cache owned by the temporal backward + executor. `_recompute_temporal_bucket_artifact_window` will look up + `(outer_start, outer_stop)` before invoking the generic input boundary + projection, reuse the cached tensors on adjacent matching windows, and replace + the cache when the outer slice changes. 
This avoids unbounded T materialization + while removing duplicate projection launches for common T*K/H schedules. +- Manual boundary review: + the cache key is temporal schedule boundary identity only. It adds no + cell-kind selector, population-name selector, benchmark-row selector, + hidden-size policy key, single/mixed route selector, lattice/config field, or + cell-family parameter bundle. Projection math remains the existing generic + boundary input adapter. +- Implementation: + `_recompute_temporal_bucket_artifact_window` now accepts a one-entry + `replay_input_kv_cache` owned by `TemporalPhysicalBackwardScanExecutor.run`. + On a matching `(outer_start, outer_stop)` it reuses the projected input K/V + tensors and records `cuda_replay_input_projection_outer_window_cache_hit`; + when the outer slice changes it clears and replaces the cache. +- Validation: + `py_compile`, Ruff check, Ruff format check, `git diff --check`, and + `tests/test_fabric_backend_boundaries.py` passed. GPU 1 K>1/K128 parity + passed as `3 passed`; GPU 2 reset parity passed as `4 passed`; GPU 3 T=1/K=1 + mixed training parity passed as `2 passed`. +- Audit evidence: + owner-timed T64/K128/H64 terminal mixed row + `/tmp/redo_fixmaass_input_cache_t64_k128_h64_owner` passed at + `15.419 tok/s`, peak `0.124180 GiB`, matched T=1 training + `70.919 tok/s`, K-adjusted floor `0.5541 tok/s`, reverse-table owner + `1748.927 ms/count=128`, artifact recompute `1177.220 ms/count=128`, and + replay scan `1130.614 ms/count=128`. The input projection replay owner moved + from the prior compact-boundary row's `18.983 ms/count=128` to + `9.380 ms/count=64`, and the active launch list records + `cuda_replay_input_projection_outer_window_cache_hit`. + Same-cache non-timed confirmations + `/tmp/redo_fixmaass_input_cache_t64_k128_h64_rerun` and + `/tmp/redo_fixmaass_input_cache_t64_k128_h64_rerun2` reached + `15.751 tok/s` and `15.768 tok/s`, both with peak `0.124180 GiB`. +- T=1 guard: + `/tmp/redo_fixmaass_input_cache_t1_b1024_mixed` passed at + `66,096.234 tok/s`, peak `0.673192 GiB`, and mixed-stack ratio `3.444x`. + The replay input-projection cache-hit tag is absent from the T=1 row, as + expected. +- Decision: + accept this bounded replay-shell reduction because it halves duplicate input + K/V projection launches for the active `H=64,K=128` schedule, keeps warmed + T64/K128/H64 and T=1/B1024 throughput inside the current-code envelope, and + slightly lowers peak memory. This does not close R4/R13: the dominant + reverse/replay owners remain, and planner-level owners still report + `python_autograd_scan`. + +### 2026-04-29 UTC - R4 deferred recurrent projection/query param binding + +Status: ACCEPTED AS A BOUNDED TEMPORAL PARAM-BINDING SHELL REDUCTION; +R4/R13 REMAIN OPEN. + +Owner: active shared flat-bucket reverse/window Python shell around table-owned recurrent message gradients. + +- Current invariant: + recurrent query and recurrent sender K/V parameter gradients are additive temporal reductions over lowered message/projection tensor-table rows. The temporal executor may accumulate raw table gradients across H windows and bind them to trainable parameters once, but it must not change message math, cell math, bucket identity, or parameter ownership. +- Current measured reason: + the latest owner-timed T64/K128/H64 row still records `glue.param_grad_binding:ms=24.131;count=1664` plus `message.query_param:ms=2.856;count=129`. 
The dominant owners remain reverse/replay, but this per-window binding is still part of the open Python H-window shell that prevents true CUDA temporal-superop ownership. +- Planned implementation: + extend `TemporalBackwardWindowResult` with deferred raw recurrent-query and recurrent-K/V-weight gradients emitted by the existing reverse table. Active reverse-table windows will return those raw backend-order gradients instead of immediately calling query/projection parameter binders. `TemporalPhysicalBackwardScanExecutor.run` will accumulate the raw tensors across windows, convert backend-order to graph-order once, run the existing generic query/projection parameter binders once after the reverse scan, and record a deferred binding tag. +- Manual boundary review: + the ABI remains flat tensor-table gradients from the existing reverse table: recurrent Q gradient and recurrent K/V projection weight gradient. No cell-kind selector, population-name selector, benchmark-row selector, hidden-size policy key, single/mixed route selector, lattice/config field, or cell-family parameter bundle is added. This is a temporal reduction/binding ownership change only; message/cell primitive math remains in the existing generic lowered paths. +- Implementation: + active reverse-table windows now return deferred raw recurrent-query and + recurrent K/V projection weight gradients. The physical temporal backward + executor accumulates those backend-order raw tensors across all H windows, + converts to graph order once, and calls the existing generic query/projection + parameter binders once after the reverse scan. Boundary, readout, and + transition primitive parameter gradients remain on their existing generic + paths. +- Validation: + `py_compile`, Ruff check, Ruff format check, `git diff --check`, and + `tests/test_fabric_backend_boundaries.py` passed. GPU 0 low-level mixed + reverse-table parity passed as `1 passed`; GPU 1 K>1/K128 parity passed as + `3 passed`; GPU 2 reset parity passed as `4 passed`; GPU 3 T=1/K=1 mixed + training parity passed as `2 passed`. All GPU rows used private Torch + extension and Triton cache dirs on GPUs 0-3. +- Audit evidence: + owner-timed T64/K128/H64 terminal mixed row + `/tmp/redo_fixmaass_defer_param_t64_k128_h64_owner` passed at + `15.236 tok/s`, peak `0.124569 GiB`, matched T=1 training + `63.681 tok/s`, and K-adjusted floor `0.4975 tok/s`. It showed the intended + owner movement for recurrent query binding: `message.query_param` moved from + the previous accepted `2.856 ms/count=129` to `0.235 ms/count=2`. The broader + `glue.param_grad_binding` owner remained effectively unchanged at + `24.373 ms/count=1664`, so this is not a major R4 close. The dominant owners + remain `transition_message_reverse_table_device_loop:ms=1752.452`, + `temporal_artifact_recompute:ms=1180.435`, and + `artifact.recompute.cuda_temporal_replay_scan:ms=1132.132`. + Same-cache non-timed confirmations + `/tmp/redo_fixmaass_defer_param_t64_k128_h64_rerun` and + `/tmp/redo_fixmaass_defer_param_t64_k128_h64_rerun2` reached + `15.720 tok/s` and `15.809 tok/s`, both passing the K-adjusted matched T=1 + gate. The T=1/B1024 mixed guard + `/tmp/redo_fixmaass_defer_param_t1_b1024_mixed_rerun` passed at + `63,502.519 tok/s`, peak `0.673192 GiB`, with the deferred binding tags + present on the high-level API path. 
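+
+For reference, the accumulate-then-bind ownership move described in the
+implementation bullet above reduces to the following executor-side pattern
+(hypothetical names; not the real `TemporalPhysicalBackwardScanExecutor`
+code):
+
+```python
+from collections.abc import Callable, Iterable
+from dataclasses import dataclass
+
+import torch
+
+
+@dataclass
+class WindowResult:
+    # Raw backend-order gradients emitted by the reverse table for one H window.
+    grad_recurrent_query_backend: torch.Tensor | None
+    grad_recurrent_kv_weight_backend: torch.Tensor | None
+
+
+def _accumulate(
+    acc: torch.Tensor | None, value: torch.Tensor | None
+) -> torch.Tensor | None:
+    if value is None:
+        return acc
+    return value.clone() if acc is None else acc.add_(value)
+
+
+def bind_deferred_param_grads(
+    window_results: Iterable[WindowResult],
+    backend_to_graph_order: Callable[[torch.Tensor], torch.Tensor],
+    bind_query_param_grads: Callable[[torch.Tensor], None],
+    bind_projection_param_grads: Callable[[torch.Tensor], None],
+) -> None:
+    grad_query = None
+    grad_kv_weight = None
+    for result in window_results:  # one result per H window, reverse order
+        grad_query = _accumulate(grad_query, result.grad_recurrent_query_backend)
+        grad_kv_weight = _accumulate(
+            grad_kv_weight, result.grad_recurrent_kv_weight_backend
+        )
+    # Convert backend order to graph order once and bind once after the scan,
+    # instead of re-running the parameter binders per H window.
+    if grad_query is not None:
+        bind_query_param_grads(backend_to_graph_order(grad_query))
+    if grad_kv_weight is not None:
+        bind_projection_param_grads(backend_to_graph_order(grad_kv_weight))
+```
+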
+- Decision: + accept this as a bounded temporal parameter-binding shell reduction because + it collapses recurrent query parameter binding across H windows without + changing cell/message math and keeps warmed T64/K128/H64 throughput within + the accepted current-code envelope. Do not count it as R4/R13 closure: the + remaining first-order owners are still replay/reverse H-window orchestration, + the reverse table device loop, and planner-level `python_autograd_scan`. + +### 2026-04-29 UTC - R4 deferred transition primitive param binding + +Status: ACCEPTED AS A BOUNDED TEMPORAL TRANSITION-BINDING SHELL REDUCTION; +R4/R13 REMAIN OPEN. + +Owner: active shared flat-bucket reverse/window Python shell around table-owned +transition primitive gradients. + +- Current invariant: + transition primitive parameter gradients emitted by the reverse tensor-table + path are additive temporal reductions. The temporal executor may accumulate + those primitive named-gradient sequences across H windows and bind them to + trainable parameters once after the reverse scan, but it must not change + primitive math, bucket ownership, population scheduling, or parameter + binding semantics. +- Current measured reason: + after recurrent query/KV binding deferral, the owner-timed T64/K128/H64 row + still records `glue.param_grad_binding:ms=24.373;count=1664`. Query binding + moved, so the next bounded shell owner is the per-window transition primitive + parameter binding that remains inside active reverse-table windows. +- Planned implementation: + return table-emitted transition primitive named-gradient accumulators from + active reverse-table windows, merge those accumulators in + `TemporalPhysicalBackwardScanExecutor.run`, then call the existing generic + `bind_temporal_transition_param_grads` once after all H windows. Fallback + host-loop windows keep their existing per-window binding path. +- Manual boundary review: + the data being deferred is the existing transition primitive gradient map + keyed by lowered primitive parameter names and bucket identities. No + cell-kind selector, population-name route selector, benchmark-row selector, + hidden-size policy key, single/mixed route split, lattice/config field, or + cell-family parameter bundle is added. This is a reduction/binding ownership + move only. +- Implementation: + active reverse-table windows now return their table-emitted transition + primitive named-gradient accumulators instead of binding those gradients to + trainable parameters per H window. The physical temporal backward executor + merges those accumulators across windows and binds them once after the reverse + scan. The first attempt retained every window tensor and raised T64 peak + memory to `0.199288 GiB`; it was corrected before acceptance by accumulating + each named primitive gradient in place as windows complete. +- Validation: + `py_compile`, Ruff check, Ruff format check, `git diff --check`, and + `tests/test_fabric_backend_boundaries.py` passed. GPU 1 K>1/K128 parity + passed as `3 passed`; GPU 2 reset parity passed as `4 passed`; GPU 3 T=1/K=1 + mixed training parity passed as `2 passed`, using private Torch extension and + Triton cache dirs. +- Audit evidence: + the first retained-list owner-timed probe + `/tmp/redo_fixmaass_defer_transition_t64_k128_h64_owner` proved the binding + count movement but was not accepted because peak memory rose to + `0.199288 GiB`. 
The accepted in-place owner-timed row + `/tmp/redo_fixmaass_defer_transition_inplace_t64_k128_h64_owner` passed at + `15.570 tok/s`, peak `0.124541 GiB`, matched T=1 training + `66.381 tok/s`, and K-adjusted floor `0.5186 tok/s`. It moved + `glue.param_grad_binding` from the prior accepted + `24.373 ms/count=1664` to `0.173 ms/count=13`, while `message.query_param` + stayed collapsed at `0.170 ms/count=2`. Dominant owners remain + `transition_message_reverse_table_device_loop:ms=1748.542`, + `temporal_artifact_recompute:ms=1182.183`, and + `artifact.recompute.cuda_temporal_replay_scan:ms=1134.959`. + Same-cache non-timed T64 confirmation + `/tmp/redo_fixmaass_defer_transition_inplace_t64_k128_h64_rerun` reached + `15.885 tok/s`, peak `0.124541 GiB`, and passed the K-adjusted matched T=1 + gate. T=1/B1024 mixed guard + `/tmp/redo_fixmaass_defer_transition_inplace_t1_b1024_mixed` passed at + `63,707.333 tok/s`, peak `0.673192 GiB`. +- Decision: + accept this bounded shell reduction because it removes nearly all repeated + transition parameter-binding work across H windows without retaining a full + window list, preserves parity, and keeps warmed T64/K128/H64 and T=1/B1024 + throughput inside the accepted current-code envelope. This still does not + close R4/R13: the first-order reverse/replay CUDA owners and planner-level + `python_autograd_scan` remain. + +### 2026-04-29 UTC - R13 deferred inactive boundary projection binding + +Status: REJECTED; CODE ROLLED BACK. + +Owner: active shared flat-bucket boundary adapter shell for terminal T*K +windows. + +- Current invariant: + boundary input K/V gradients are additive by outer timestep. For inactive + terminal-output reverse windows, the temporal executor may accumulate + table-emitted K/V boundary gradients by outer input row and run the existing + generic boundary projection backward once per outer row after the reverse + scan. It must not change boundary adapter math, message math, cell math, + reset semantics, or benchmark-visible APIs. +- Current measured reason: + after transition binding deferral, `public_projection` remains at + `count=128` for the T64/K128/H64 terminal row. With `H=64,K=128`, many + inactive reverse windows refer to the same outer boundary row, so the current + compact-in-window projection still repeats boundary backward across adjacent + H windows. +- Planned implementation: + active reverse-table windows with no output-backward sequence will return + compact outer boundary K/V gradients instead of immediately invoking boundary + public-projection backward. `TemporalPhysicalBackwardScanExecutor.run` will + accumulate those gradients by outer timestep, run the existing generic + boundary public-projection backward once for the deferred rows, then add the + resulting gradients into `grad_boundary_seq`. Output-active and fallback + windows keep their current boundary path. +- Manual boundary review: + the deferred data is only outer timestep id, boundary input tensor, and + generic boundary K/V gradients. No cell-kind selector, population-name route + selector, benchmark-row selector, hidden-size policy key, single/mixed route + split, lattice/config field, or cell-family parameter bundle is added. +- Validation before rejection: + static checks passed (`py_compile`, Ruff check, Ruff format after formatting, + `git diff --check`, and `tests/test_fabric_backend_boundaries.py`). 
GPU 1 + K>1/K128 parity passed as `3 passed`; GPU 2 reset parity passed as `4 + passed`; GPU 3 T=1/K=1 mixed training parity passed as `2 passed`, all using + private Torch extension and Triton cache dirs. +- Rejection evidence: + owner-timed `/tmp/redo_fixmaass_defer_boundary_t64_k128_h64_owner` proved the + intended call-count movement: `public_projection` moved from `count=128` to + `count=2`, with T64 peak memory flat at `0.124628 GiB`. However the dominant + owners did not improve (`transition_message_reverse_table_device_loop` + `1749.650 ms`, `temporal_artifact_recompute` `1187.887 ms`, + `artifact.recompute.cuda_temporal_replay_scan` `1137.097 ms`), and warmed + non-timed confirmations regressed versus the accepted transition-binding row: + `/tmp/redo_fixmaass_defer_boundary_t64_k128_h64_rerun` reached + `15.752 tok/s`, and + `/tmp/redo_fixmaass_defer_boundary_t64_k128_h64_rerun2` reached + `15.692 tok/s`, below the prior accepted `15.885 tok/s` envelope. T=1/B1024 + mixed guard `/tmp/redo_fixmaass_defer_boundary_t1_b1024_mixed` passed at + `63,521.762 tok/s`, peak `0.673192 GiB`, but T*K throughput regression is + not acceptable. +- Decision: + reject and roll back the boundary-deferral code. The slice was generic and + parity-clean, but reducing a subdominant public-projection count did not move + the first-order replay/reverse owners and regressed warmed T64/K128/H64 + throughput. Do not retry this exact deferred-boundary path unless it is fused + with a larger replay/reverse superop change that removes enough host/window + work to recover throughput. + +### 2026-04-29 UTC - R4 reverse-table primitive gate-param reduction probe + +Status: REJECTED; CODE ROLLED BACK. + +Owner: `transition_message_reverse_table_device_loop` inside the shared +flat-bucket temporal backward path. + +- Current invariant: + primitive parameter gradients are tensor-table/op-row reductions owned by the + temporal reverse engine. When the reverse table already computes the + per-timestep recurrent-affine gate adjoints, the gate-weight and gate-bias + reductions can be emitted by that same reverse-table kernel instead of + launching separate post-kernel PyTorch reductions. This must remain a + primitive-row implementation detail, not a cell-family or benchmark route. +- Manual boundary review: + ABI inputs remain the existing reverse tensor table, op rows, and scalar + primitive metadata. The probe adds no cell-kind selector, population-name + selector, benchmark-row selector, hidden-size policy key, single/mixed route, + lattice/config field, or cell-family parameter bundle. It only changes where + the existing lowered gated recurrent-affine primitive writes two parameter + reductions. +- Planned implementation: + zero-initialize `grad_gated_gate_weight`, have + `transition_message_reverse_table_device_loop_kernel` accumulate it while it + is already traversing gate adjoints, and remove the post-kernel bmm reduction + for that output. The first attempt also moved gate-bias reduction into the + kernel, but the low-level parity test caught a layout mismatch, so the probe + now leaves the existing gate-bias sum path unchanged. Keep all other + projection/input reductions unchanged. Accept only if parity passes and + warmed T64/K128/H64 does not regress versus the accepted transition-binding + baseline; otherwise roll back and mark rejected. +- Validation before rejection: + `git diff --check` passed. 
The low-level GPU 0 reverse-table parity test + passed after narrowing the probe to gate-weight only: + `tests/test_fabric_runtime.py -k + "transition_message_reverse_table_window_matches_mixed_step_loop"` reported + `1 passed`. GPU 1 K>1/K128 parity reported `3 passed`; GPU 2 reset parity + reported `4 passed`; GPU 3 T=1/K=1 mixed training parity reported + `2 passed`. The first gate-bias version failed the low-level parity test on + the existing gate-bias layout, so gate bias was not kept in the probe. +- Rejection evidence: + owner-timed T64/K128/H64 terminal mixed row + `/tmp/redo_fixmaass_gate_weight_t64_k128_h64_owner` passed the K-adjusted + reference gate but regressed throughput to `14.028 tok/s`, peak + `0.124541 GiB`, matched T=1 training `63.282 tok/s`, and K-adjusted floor + `0.4944 tok/s`. The intended owner moved in the wrong direction: + `transition_message_reverse_table_device_loop` grew to + `2166.513 ms/count=128` versus the accepted transition-binding baseline + `1748.542 ms/count=128`. The atomic in-kernel gate-weight reduction is + therefore slower than the previous post-kernel bmm reduction on the active + row. +- Decision: + reject and roll back the kernel code. Do not retry per-receiver atomic + gate-weight reduction for this owner. The next R4/R13 work must remove a + larger reverse/replay window boundary or introduce a table-owned reverse + superop that reduces the first-order replay/reverse loop rather than moving + efficient small reductions into contended atomics. + +### 2026-04-29 UTC - R13 active reverse no-output artifact payload + +Status: ACCEPTED AS A SMALL GENERIC REPLAY/REVERSE PAYLOAD REDUCTION; R4/R13 +REMAIN OPEN. + +Owner: active shared temporal artifact replay shell between the CUDA replay +scan and CUDA reverse table. + +- Current invariant: + active reverse-only windows whose physical steps do not emit an output + artifact already have the replay-produced tensor windows required by the + reverse table. The temporal backward executor should pass that table-owned + window payload directly into the reverse-table path instead of allocating a + placeholder `TemporalBucketStepArtifacts` object for every physical step. +- Current measured reason: + the accepted path still reports `temporal_artifact_recompute`, + `artifact.recompute.cuda_temporal_replay_scan`, and + `transition_message_reverse_table_device_loop` as first-order owners. The + replay scan itself is still the dominant recompute work, but the Python shell + still builds H per-step artifacts for active reverse-only windows even when + the reverse table consumes `TemporalReverseWindowTables` and the current + window has no output-backward artifact. +- Planned implementation: + add a generic `TemporalReverseWindowPayload` for active reverse-only windows + with no requested output artifact. `_try_cuda_mixed_flat_bucket_recompute_artifact_window` + will return this payload after constructing the replay tensor tables. The + physical backward scan executor will run the reverse table from the payload + using schedule-derived boundary rows and boundary-gradient accumulation, not + per-step placeholder artifacts. Windows with output artifacts, resets, + fallback replay, or stored artifacts keep the existing list path. +- Manual boundary review: + the payload ABI is window start/end, outer boundary sequence, template cells, + and the existing `TemporalReverseWindowTables`. 
It adds no cell-kind + selector, population-name selector, benchmark-row selector, hidden-size + policy key, single/mixed route selector, lattice/config field, or + cell-family parameter bundle. It is a replay/reverse materialization boundary + change only; message and transition math remain in lowered tensor-table + primitives. +- Acceptance gate: + static checks, backend boundary tests, low-level reverse-table parity, K>1 / + K=128 terminal-gradient parity, reset parity, T=1/K=1 mixed training parity, + and high-level current-code T64/K128/H64 plus T=1/B1024 mixed audit evidence. + Accept only if the new payload tag appears on active T*K rows and warmed + throughput/memory do not regress against the accepted transition-binding / + input-cache envelope. +- Implementation: + active reverse-only recompute windows now return `TemporalReverseWindowPayload` + when no output artifact is requested and the replay checkpoint matches the + window start. The physical backward executor calls the existing CUDA reverse + table from that payload and uses schedule-derived boundary rows for compact + outer boundary backward. Windows with output artifacts, stored artifacts, or + fallback replay keep the existing per-step artifact list path. +- Validation: + `py_compile`, Ruff check, Ruff format check, `git diff --check`, and + `tests/test_fabric_backend_boundaries.py` passed. CUDA validation passed for + low-level mixed reverse-table parity on GPU0 (`1 passed`), K>1/K128 terminal + gradient parity on GPU1 (`3 passed`), reset parity on GPU2 (`4 passed`), and + T=1/K=1 mixed reverse-device-loop parity on GPU3 (`2 passed`), all with + private Torch extension and Triton cache directories. +- Audit evidence: + owner-timed high-level T64/K128/H64 terminal mixed row + `/tmp/redo_fixmaass_payload_t64_k128_h64_owner` passed at `15.764 tok/s`, + peak `0.124533 GiB`, matched T=1 training `62.633 tok/s`, and K-adjusted + floor `0.4893 tok/s`. The row records + `cuda_recompute_active_reverse_step_artifacts_elided` and + `cuda_boundary_backward_compact_outer_schedule_inputs`. Dominant owners + remain effectively unchanged: + `transition_message_reverse_table_device_loop:ms=1749.598;count=128`, + `temporal_artifact_recompute:ms=1189.740;count=128`, and + `artifact.recompute.cuda_temporal_replay_scan:ms=1140.738;count=128`. + Same-cache non-timed confirmations + `/tmp/redo_fixmaass_payload_t64_k128_h64_rerun` and + `/tmp/redo_fixmaass_payload_t64_k128_h64_rerun2` reached `15.910 tok/s` and + `15.933 tok/s`, both above the prior accepted transition-binding warmed row + (`15.885 tok/s`) with peak memory flat at `0.124533 GiB`. + T=1/B1024 mixed guard `/tmp/redo_fixmaass_payload_t1_b1024_mixed` passed at + `61,473 tok/s`, peak `0.673192 GiB`, and mixed-stack ratio `3.668x`. +- Decision: + accept this as a small generic replay/reverse boundary reduction because it + removes placeholder per-step artifacts for no-output active reverse windows, + preserves parity and high-level API semantics, records the payload/schedule + tags, and keeps warmed T*K plus T=1 throughput inside the accepted envelope. + This still does not close R4/R13: the first-order CUDA replay/reverse owners + and planner-level `python_autograd_scan` remain. + +### 2026-04-29 UTC - R4 sparse direct-public reverse-table ABI + +Status: ACCEPTED AS A GENERIC OUTPUT-GRAD MATERIALIZATION REDUCTION; R4/R13 +REMAIN OPEN. + +Owner: active shared flat-bucket reverse table ABI for sparse output-active +T*K windows. 
+ +- Current invariant: + direct public gradients from output backward are generic public/recurrent + boundary gradients. In T*K windows only a few physical steps may emit output + while the reverse table still scans the full H window. The temporal executor + should pass sparse direct-public rows plus their physical local-step indices + into the table instead of materializing a mostly zero `[H,B,R,h]` direct + public tensor when the active steps are sparse. +- Manual boundary review: + ABI additions are `grad.direct_public_y_step_indices`, a sparse/dense + materialization flag, and a direct-step count. The data is temporal + schedule/output-gradient metadata and flat recurrent-bank gradients. No + cell-kind selector, population-name selector, benchmark-row selector, + hidden-size policy key, single/mixed route selector, lattice/config field, or + cell-family parameter bundle is added. Dense direct-public windows remain the + default for low-level callers and dense sequence-output rows. +- Implementation: + `try_transition_message_reverse_table_window_cuda` now accepts optional + sparse direct-public step indices. The reverse CUDA table resolves a direct + public gradient for local step `t` by looking up that compact index vector; + missing local steps contribute zero direct public gradient while recurrent + carry still flows through the existing public-carry workspace. The active + temporal backward executor emits the sparse representation only when active + direct-public steps are at most one quarter of the H window; otherwise it + keeps the dense window to avoid hot-loop sparse lookup overhead. +- Validation: + `py_compile`, Ruff check, Ruff format check, `git diff --check`, and + `tests/test_fabric_backend_boundaries.py` passed. CUDA validation passed for + low-level mixed reverse-table parity on GPU0 (`1 passed`), K>1/K128 terminal + gradient parity on GPU1 (`3 passed`), reset parity on GPU2 (`4 passed`), and + T=1/K=1 mixed reverse-device-loop parity on GPU3 (`2 passed`), all with + private Torch extension and Triton cache directories. +- Audit evidence: + owner-timed high-level terminal T64/K128/H64 mixed row + `/tmp/redo_fixmaass_sparse_direct_t64_k128_h64_owner` passed at + `15.803 tok/s`, peak `0.124533 GiB`, matched T=1 training + `67.785 tok/s`, and K-adjusted floor `0.5296 tok/s`. It records + `temporal_backward_glue:cuda_reverse_engine_sparse_direct_public_window` + for the output-active window and + `temporal_backward_glue:cuda_reverse_engine_absent_direct_public_window` + for inactive no-output windows. First-order owners remain + `transition_message_reverse_table_device_loop:ms=1745.074;count=128`, + `temporal_artifact_recompute:ms=1188.464;count=128`, and + `artifact.recompute.cuda_temporal_replay_scan:ms=1139.361;count=128`. + Warmed non-timed confirmation + `/tmp/redo_fixmaass_sparse_direct_t64_k128_h64_rerun` reached + `15.902 tok/s`, peak `0.124533 GiB`, and passed the K-adjusted gate. + T=1/B1024 mixed guard `/tmp/redo_fixmaass_sparse_direct_t1_b1024_mixed` + passed at `62,429.812 tok/s`, peak `0.673192 GiB`, with no sparse-direct tag + on the active T=1 path. A small sequence-loss T*K smoke + `/tmp/redo_fixmaass_sparse_direct_t8_k128_h64_sequence` passed at + `15.112 tok/s`, peak `0.107309 GiB`, and recorded the sparse-direct tag. 
+- Decision: + accept this as a generic reverse-table materialization reduction because it + removes mostly zero direct-public H-window tensors for sparse output-active + T*K windows without changing message/cell primitive math or high-level API + semantics. Do not count it as R4/R13 closure: replay and reverse remain + separate first-order CUDA owners, and planner-level forward/backward owners + still report `python_autograd_scan`. + +### 2026-04-29 UTC - R4/R13 CUDA temporal owner audit guardrail + +Status: ACCEPTED AS A CLOSURE GUARDRAIL; R4/R13 REMAIN OPEN. + +Owner: audit closure criteria for the shared temporal engine. + +- Current invariant: + a CUDA temporal-superop closure claim is only valid when planner metadata and + runtime metadata agree, and when the first-order transitional replay/reverse + shell owners are absent from the measured owner list. A planner relabel alone + must fail the audit even if ordinary throughput and reference gates pass. +- Implementation: + `require_cuda_temporal_owner` now checks the runtime forward scan owner and + implementation, the backward runtime executor, transitional backward owner + timings (`temporal_artifact_recompute`, + `artifact.recompute.cuda_temporal_replay_scan`, + `transition_message_reverse_table_device_loop`, and replay input-projection + variants), transitional recompute markers such as + `physical_recompute_bridge` / `python_step_replay`, and host reverse-scan + aliases. `benchmarks/fabric/suite_common.py` now exports + `backward_workspace_aliases` into the audit planner signature so the gate can + see reverse-scan ownership aliases. +- Validation: + `py_compile`, Ruff check, Ruff format check, and + `tests/test_fabric_audit_runner.py` passed (`18 passed`). The new tests cover + planner relabel with a non-CUDA runtime scan owner and planner relabel with + transitional reverse-table timing still present. +- Decision: + accept this guardrail. It intentionally does not close R4/R13; it prevents a + future fake close until the live backend removes the open replay/reverse + owners rather than merely changing `temporal_plan_forward_owners` or + `temporal_plan_backward_owners`. + +### 2026-04-29 UTC - R4 reverse-table redundant grid-sync owner + +Status: ACCEPTED MICRO-CLEANUP; R4/R13 REMAIN OPEN. + +Owner: active shared flat-bucket transition/message reverse table device loop. + +- Current invariant: + the reverse table owns the H-window reverse scan over flat bucket identity, + tensor-table roles, reset policy, and message/transition primitive rows. + Synchronization inside the CUDA temporal primitive must exist only where a + later phase reads data written by an earlier phase. Extra cooperative-grid + barriers in the per-timestep loop are part of the measured reverse-table + owner and should be removed when they are not guarding a dependency. +- Current measured reason: + the latest accepted T64/K128/H64 row still records + `transition_message_reverse_table_device_loop:ms=1745.074;count=128` as the + largest owner. In active reverse-only windows + `materialize_hidden_message_window=False`, so the kernel stores + `grad_hidden_message_window` only at `t==0`. For every earlier timestep it + currently runs the sender-reverse synchronization and then immediately runs a + second end-of-loop `grid.sync()` with no intervening work. That second barrier + is redundant: the sender/message barrier already makes `public_carry_work` + ready for the next transition step when no store reads it. 
+- Implementation: + keep the existing tensor-table ABI and transition/message math. Compute a + per-step `store_hidden_message` predicate. After the sender-reverse phase, + synchronize before the store only when the store will read + `public_carry_work`; otherwise that synchronization is the only dependency + barrier before the next timestep. After the optional store, synchronize only + when another timestep will follow and the store actually read + `public_carry_work`. This removes the redundant back-to-back barrier on + active reverse-only no-store timesteps and the final post-store barrier at + `t==0`. +- Manual boundary review: + the edit changes no ABI inputs or outputs. It consumes the existing flat + tensor-table roles, op rows, reset/message tables, and scalar materialization + flag. It adds no cell-kind selector, population-name selector, benchmark-row + selector, hidden-size policy key, single/mixed route selector, lattice/config + field, or cell-family parameter bundle. +- Validation: + `git diff --check`; `uv run ruff check tests/test_fabric_backend_boundaries.py + benchmarks/fabric/audit.py benchmarks/fabric/suite_common.py`; GPU0 + low-level reverse-table parity + `transition_message_reverse_table_window_matches_mixed_step_loop`; GPU1 + K>1/K128 terminal-gradient parity; GPU2 reset parity; GPU3 T=1 mixed + reverse-device-loop parity. All passed with private cache dirs and GPUs 0-3. +- Audit evidence: + GPU4 owner-timed high-level T64/K128/H64 mixed terminal run passed the + K-adjusted T=1 training gate at `15.799 tok/s`, but the dominant owner stayed + effectively unchanged: + `transition_message_reverse_table_device_loop:ms=1741.663;count=128`. + GPU4 non-timed warmed rerun passed at `15.945 tok/s`, peak `0.125 GiB`, with + matched T=1 training `75.829 tok/s` and K-adjusted floor `0.592 tok/s`. +- Remaining: + this does not close R4/R13. It only removes a redundant synchronization point. + The next high-priority work remains moving the active reverse scan/recompute + loop itself into the table-owned CUDA temporal superop and eliminating the + `python_autograd_scan` planner owners. + +### 2026-04-29 UTC - R3 forward planner owner aligned with CUDA superop + +Status: ACCEPTED AS R3 OWNER ALIGNMENT; R4 REMAINS OPEN. + +Owner: planner temporal engine ownership metadata for the already-CUDA forward +scan. + +- Current invariant: + planner ownership metadata must match the active physical owner. Metadata is + not closure by itself, but a stale `python_autograd_scan` planner owner is + also wrong once the runtime forward scan is physically executed by the CUDA + temporal superop. +- Current measured/code fact: + accepted high-level rows already report + `launch_temporal_scan_owners=("cuda_temporal_superop",)` and + `launch_scan_implementations=("cuda_temporal_superop",)` for the shared + flat-bucket temporal forward path. The planner still emitted + `temporal_plan_forward_owners=("python_autograd_scan",)`, so R3/R4 appeared + equally open even though the real remaining owner is backward reverse/replay. +- Implementation: + `_plan_temporal_engine_owner` now records + `forward_owner="cuda_temporal_superop"` for CUDA flat-transition-bucket plans. + Inference rows use `status="closed_cuda_superop"`. Training rows keep + `backward_owner="python_autograd_scan"` and use + `status="forward_cuda_backward_transitional"` so the planner remains honest + about R4. 
+- Manual boundary review: + no execution route, tensor ABI, op row, primitive math, message math, reset + policy, checkpoint policy, or materialization policy changed. This adds no + cell-kind selector, population-name selector, benchmark-row selector, + hidden-size policy key, single/mixed route selector, lattice/config field, or + cell-family parameter bundle. It aligns planner metadata with the existing + physical forward owner only. +- Remaining: + R4 is the high-priority blocker. Backward still reports + `backward_owner="python_autograd_scan"` and still runs the active reverse + replay/reverse H-window shell. Do not enter audit closure until that owner is + replaced by the table-owned CUDA temporal reverse superop or fail-closed by + guardrails. + +### 2026-04-29 UTC - R4 table-owned reverse-window sequence controller + +Status: IN PROGRESS. + +Owner: shared multi-pop flat-bucket temporal backward superop. + +- Current invariant: + the backward temporal owner must carry reverse dependencies across H windows + inside the shared flat-bucket temporal backend. Python may pass tensor tables + and receive gradients, but Python must not own the recurrent reverse seed + chain across the streaming horizon. +- Strategic closure order: + 1. introduce a prepared transition/message reverse tensor-table ABI so the + active backward executor can hand multiple already-lowered reverse windows + to the extension without cell-specific arguments; + 2. add a C++ backend controller that runs the existing generic + transition/message reverse table over a sequence of windows and carries + `grad_hidden_message_window[0]`, gated state grads, and diagonal state + grads from one window to the next inside the extension; + 3. wire the active no-output reverse-only H-window path to that controller in + bounded chunks, keeping memory bounded by H while reducing Python-owned + reverse seed orchestration; + 4. only after parity and owner-timed evidence are green, continue moving the + replay/checkpoint H-window orchestration itself under the same backend + controller and then update the planner backward owner. +- Manual boundary review for the planned ABI: + inputs remain flat tensor-table roles, op rows, reset/message tables, + checkpoint/materialization flags, and scalar dimensions inferred from tensor + shapes. The controller adds no cell-kind selector, population-name selector, + benchmark-row selector, hidden-size policy key, single/mixed route selector, + lattice/config field, or cell-family parameter bundle. Message and transition + primitive math remain represented by generic table rows. +- Guardrail: + this substage must not be recorded as R4 closure by itself. R4 remains open + until the active high-level backward path reports CUDA temporal backward + ownership without the transitional replay/reverse timing owners and passes the + T=1, reset, K=128/H=64, and T-scaling parity/performance gates. + +### 2026-04-29 UTC - Corrected April21-shaped T/H audit evidence + +Status: R11/R13 OPEN; current-code evidence rejects the earlier B=1 smoke +interpretation. + +Owner: shared temporal training path, especially temporal backward and bounded +training tape/checkpoint ownership. + +- Correction: + the earlier `B=1`, `1M`, `h=8` T/K rows are smoke rows only. They are not + closure evidence for fixmass throughput. Closure rows must use the + April21-derived batch/params/hidden contracts and matched current-code T1 + baselines. 
+- Exact high-level audit command run: + `CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_${USER}_redo_t1_slstm500m_b512_h32_seq TRITON_CACHE_DIR=/tmp/cortical_triton_${USER}_redo_t1_slstm500m_b512_h32_seq uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir /tmp/redo_fixmaass_t1_slstm500m_b512_h32_sequence --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families slstm --sizes 500m --modes forward_backward --batches 512 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries sequence --population-modes single --reset-modes absent --warmup 1 --iterations 3 --enforce-references` +- Result: + `t1-single-pop_slstm_500m_forward_backward_b512_t1_k1_h32_ghnone_ckplanner_losssequence_popsingle_resetabsent` + completed through the high-level API but failed the enforced gate: + `54.287 tok/s`, peak `125.569 GiB`, `actual_params=500764896`, + `fabric_shape=[48,1024]`, `d_hidden=1536`. +- Owner metadata from that row: + forward reports `launch_scan_implementations=["cuda_temporal_superop"]` and + `temporal_plan_forward_owners=["cuda_temporal_superop"]`, but training still + reports `temporal_plan_backward_owners=["python_autograd_scan"]`, + `flat_bucket_temporal_reverse_scan_owner:python_host_reverse_loop`, + `physical_temporal_bucket_sequence_backward`, and + `cuda_temporal_backward_glue`. +- Scientific implication: + the current backend does not yet have a valid matched T1 training floor for + the `B=512`, `500M`, `h=32`, sequence-loss contract. Any `T=4096,K=1,H=64` + run is evidence about the same open owner, not closure evidence, until the T1 + row is healthy and backward ownership moves into the CUDA temporal superop. +- Running follow-up: + the matching high-level `T=4096,K=1,H=64` row was launched on GPU1 with + private caches at + `/tmp/redo_fixmaass_t4096_k1_h64_slstm500m_b512_h32_sequence`. +- Follow-up result: + `tk-scaling_slstm_500m_forward_backward_b512_t4096_k1_h32_gh64_ckplanner_losssequence_popsingle_resetabsent` + failed with CUDA OOM before producing throughput. The row used the correct + April21 exact streaming reference + `streaming_sequence_loss:slstm:500m:b512:t4096:h32` + (`104115.17 tok/s`, `29.00 GiB`). The live run attempted a `32.00 GiB` + allocation with `129.12 GiB` already in use and `128.34 GiB` allocated by + PyTorch. +- Updated owner: + do not narrow this to only "full-sequence input K/V materialization" without + proof. The current evidence says the actual owner is the shared temporal + training path as a whole: T1 sequence training already fails throughput and + memory, and T4096/H64 OOMs. The next backend work must first make the matched + T1 training row healthy while preserving the CUDA forward superop, then rerun + T4096/H64 against that current-code T1 floor. + +### 2026-04-29 UTC - Strategic reset after corrected audit evidence + +Status: ACTIVE PLAN; R11/R13 are blocked by R3/R4 training-path health. + +Owner: shared temporal backend, not planner metadata and not benchmark probes. + +- Immediate correction: + stop treating K=128/B=1 rows or small smoke rows as meaningful closure + evidence. They can catch semantic regressions, but they cannot close T=1, + T*K, H, memory, or throughput stages. +- Current hard blocker: + the April21-shaped current-code training contract is not healthy even at + `T=1`. 
The representative `sLSTM 500M, B=512, h=32, sequence loss` row + reports `54.287 tok/s` and `125.569 GiB`, while backward metadata still says + `python_autograd_scan`/`python_host_reverse_loop`. This blocks R11 and makes + R13 measurements non-closure evidence. +- Correct closure order from here: + 1. cleanly checkpoint or reject the current uncommitted R4 reverse-window ABI + work so the tree is not carrying an unproven distraction; + 2. fix the matched T1 training path first, with owner timing and memory + accounting enabled, until the active row reports CUDA temporal backward + ownership and no April21 throughput/memory regression; + 3. only after T1 is healthy, rerun T=512/T=4096 K=1 H=64 per-timestep rows + against the matched current-code T1 floor; + 4. then sweep K=1..128 and terminal/per-timestep loss, judging K>1 against + the matched T1 floor divided by K; + 5. only after R11-R14 pass, run R15 legacy/config/message/graph cleanup and + deletion. +- Audit cadence correction: + full closure suites wait until the relevant backend owner is credible, but + selected representative audits must run continuously during implementation. + Do not disappear into kernel or planner work for long stretches without + rerunning the smallest representative high-level rows that can reopen the + owner. These guardrail audits are not closure evidence unless they satisfy + the full matrix, but they are mandatory steering evidence. +- Required representative guardrail rows while rebuilding R3/R4: + - T=1 exact April21-shaped training guard: + `sLSTM 500M, B=512, h=32, sequence loss`, plus at least one April21 + T=1 `h32_t1_bxparams` row with `B=1024` when the path changes. + - T/H guard: + `sLSTM 500M, B=512, h=32, T=512 and T=4096, K=1, H=64, per-timestep + loss`, judged against the matched current-code T1 row and April21 exact + streaming references. + - K guard: + K sweep probes at `K=1,2,8,32,128` on a representative row, judging K>1 + against matched T1/K. Small B=1 rows may be semantic smoke only; they do + not replace April21-shaped representative probes. + - Reset guard: + reset/no-reset parity for T=1 and T>1 after any temporal state, carry, + checkpoint, or reverse-scan edit. + - Mixed-pop guard: + one mixed-pop T=1 training row after any flat-bucket, message, or temporal + owner change, to prevent reintroducing separate single/mixed execution. + - Small-h/shape guard: + at least one h=4 or h=8 representative row after bucket/layout changes. +- Audit interpretation: + a guardrail failure immediately reopens the owning R-stage and changes the + next code target. A guardrail pass only means the current slice did not break + that representative contract; it does not close R11-R14 until the full audit + matrix passes. +- Backend implementation direction: + the next real code must move temporal training ownership, not relabel it. + The CUDA temporal engine needs a generic table-owned backward path over flat + bucket identity, emission gradients, reset policy, checkpoint/recompute + policy, and primitive/message op rows. Python may build tensor tables and call + the extension; it must not own scan loops, K loops, H windows, transition + tape, checkpoint replay, or reverse dependency carry for closure rows. +- Scientific rule: + do not claim a memory root cause from code inspection alone. For every + T/H/K failure, first record the matched current-code T1 baseline, the exact + April21 reference key, owner metadata, peak memory, and whether the failure is + already present at T1. 
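+
+To keep the K guard arithmetic above unambiguous, here is a minimal sketch of
+the K-adjusted floor check (helper and field names are illustrative, not the
+audit runner's actual gate; the example numbers are taken from accepted rows
+earlier in this log):
+
+```python
+# Hedged sketch: judge a K>1 probe against the matched current-code T=1
+# training floor divided by K. Names here are illustrative only.
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class KGuardResult:
+    floor_tok_s: float  # matched T=1 training throughput divided by K
+    measured_tok_s: float
+    passed: bool
+
+
+def k_adjusted_guard(matched_t1_tok_s: float, k: int, measured_tok_s: float) -> KGuardResult:
+    if k < 1:
+        raise ValueError("K must be >= 1")
+    floor = matched_t1_tok_s / k
+    return KGuardResult(floor, measured_tok_s, measured_tok_s >= floor)
+
+
+# With the accepted payload row above: matched T=1 training 62.633 tok/s at
+# K=128 gives a 0.4893 tok/s floor, which the 15.910 tok/s warmed rerun clears.
+print(k_adjusted_guard(62.633, 128, 15.910))
+```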
+ +### 2026-04-29 UTC - Full backend audit dry-run correction + +Status: DRY-RUN ONLY; NOT THROUGHPUT EVIDENCE. + +Owner: audit/performance with reopen ownership to R1-R4/R11-R14 as failures +appear. + +- User correction: + do not rely only on selected probes. Run the full backend audit matrix often + enough to catch wrong owners and reopen older stages. +- Dry-run case lists: + - T1 backend matrix: + `/tmp/redo_fixmaass_dry_full_t1_backend`, 96 cases. + Axes: single+mixed population, sLSTM+Axon, 100M/500M/1B, + forward+forward_backward, B=1024/16384, terminal+sequence loss, h=32. + - T*K/H backend matrix: + `/tmp/redo_fixmaass_dry_full_tk_backend`, 160 cases. + Axes: single+mixed population, sLSTM+Axon, 500M/1B, B=512, + T=512/4096, K=1/2/8/32/128, H=64, terminal+sequence loss, h=32. +- Correction: + these case lists are only argument expansion checks. They are not throughput + evidence, they do not move any REDO owner, and they must not be described as + an audit sweep. The only throughput evidence in the current log is from + measured `cases.jsonl` rows such as the failing T=1 and T=4096 rows above. +- Execution policy: + run measured throughput rows, not manifests. Every measured row must report + `tokens_per_s`, `peak_mem_gib`, April 21 reference key, planner/runtime owner + metadata, and whether `CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1` was enabled. + Passing rows are not closure until rerun with closure repetitions and the + full parity matrix; failing rows immediately reopen their owner. +- GPU policy: + shard across GPUs 0-4 only with private Torch/Triton cache directories. + +### 2026-04-29 UTC - Throughput-first deep-dive closure plan + +Status: ACTIVE; replaces manifest-centered work. + +Owner: shared temporal training backend, then audit/performance. + +Current measured truth: + +- The high-level API path is valid for the failing representative rows: + benchmark iteration is `model(...)`, external MSE loss, `loss.backward()`, + optimizer step. It is not using a private planner/runtime hook to implement + the temporal behavior. +- The representative T=1 training row is the current blocker: + `sLSTM 500M, B=512, h=32, T=1, K=1, sequence loss`, artifact + `/tmp/redo_fixmaass_t1_slstm500m_b512_h32_sequence/cases.jsonl`. + It measured `54.287 tok/s`, peak `125.569 GiB`, against the April 21 + `streaming_per_timestep_sequence_loss` floor of `95,638.46 tok/s` and + `61.21 GiB`. +- That same row reports forward through `cuda_temporal_superop`, but backward + still reports `temporal_plan_backward_owners=["python_autograd_scan"]`, + `flat_bucket_temporal_reverse_scan_owner:python_host_reverse_loop`, + `physical_temporal_bucket_sequence_backward`, and + `cuda_temporal_backward_glue`. +- The matched T=4096 K=1 H=64 per-timestep row, + `/tmp/redo_fixmaass_t4096_k1_h64_slstm500m_b512_h32_sequence/cases.jsonl`, + OOMed before throughput against the exact April 21 row + `streaming_sequence_loss:slstm:500m:b512:t4096:h32` + (`104,115.17 tok/s`, `29.00 GiB`). This is not a separate proven root cause + while T=1 training is already broken. + +Immediate implementation plan: + +1. First throughput probe, not manifest: + rerun the T=1 representative with private caches and + `CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1`. The required output is a ranked + owner table plus tok/s and memory. 
If owner timing is empty, fix + instrumentation before doing more optimization, because we need to know + whether the time is in transition reverse, boundary/readout backward, + artifact recompute, state materialization, parameter binding, or Python + shell overhead. +2. Review the uncommitted reverse-window sequence-controller slice: + it currently only wraps one prepared reverse call in + `transition_message_reverse_table_window_sequence`. That is not enough to + move throughput or close R4. Accept it only if the owner-timed T=1 and + reset/T*K parity prove no regression; otherwise reject or reshape it. +3. Fix T=1 training throughput before chasing long-T or K: + T=1 must use the shared temporal training path with CUDA-owned backward + metadata and no host reverse-loop alias. The first target is eliminating the + `python_autograd_scan`/`python_host_reverse_loop` ownership on the T=1 row, + because a broken T=1 floor makes all T*K ratios meaningless. +4. Move real backward ownership into the generic temporal superop: + replace the host-level reverse scan/window shell with a table-owned CUDA + controller over flat bucket identity, tensor roles, op rows, reset policy, + emission gradients, checkpoint/recompute policy, and boundary gradients. + The ABI must stay generic: no cell-kind selector, no population-name + selector, no benchmark-row selector, no lattice/config field, no hidden-size + policy key, no separate single/mixed route. +5. Reduce materialization/memory only after owner evidence: + do not guess that input K/V, transition tape, or final state materialization + is the root cause without the T=1 owner profile. The acceptable fixes are + backend-owned streaming/checkpoint/tensor-table changes; benchmark tiling or + loss streaming is not allowed. + +Throughput checkpoints during implementation: + +- After each backend slice, run the measured T=1 representative again with + owner timing. It must move toward the April 21 floor and must not increase + peak memory. +- Once T=1 owner metadata and throughput are healthy, immediately run + `T=512,K=1,H=64` and `T=4096,K=1,H=64` per-timestep loss. T scaling must be + flat or better relative to the matched current-code T=1 floor and must also + satisfy the exact April 21 streaming rows where applicable. +- Then run K probes on the same contract: `K=1,2,8,32,128`, judged against + matched T=1 training throughput divided by K. K=128 is the accepted ceiling. +- Then run mixed-pop T=1 training and reset/no-reset parity after every + flat-bucket/message/reverse-scan change. Mixed-pop is the same shared + temporal engine with more flat buckets, not a separate backend identity. +- Then run h=4/h=8 shape guards after any bucket/layout changes. Small hidden + rows are real many-cell stress rows, not optional smoke. + +What remains to close REDO_FIXMASS: + +- R1 remains open until the planner can truthfully mark training backward as + `cuda_temporal_superop` from an actual CUDA-owned path, not by metadata + relabeling. +- R2 remains open until all single/mixed population differences are flat bucket + identity and parameter/state bindings only. +- R3 is closer than R4 for current T=1 because forward reports + `cuda_temporal_superop`, but it still needs T>1, mixed-pop, K, reset, and + materialization closure through the same engine. +- R4 is the primary blocker: backward scan/window/recompute/parameter binding + still has Python autograd/host-shell ownership and fails the T=1 throughput + floor. 
+- R9 audit tooling exists but does not close anything by itself. Its job is to + run high-level API throughput rows and report April 21 references, not to + produce manifests. +- R10 parity must expand around the actual backend path being changed: + T=1/T>1, reset/no-reset, terminal/sequence loss, final state + materialized/unmaterialized, input/state/parameter grads, single/mixed. +- R11 cannot close until all April 21 T=1 single-pop rows meet or exceed + throughput and stay at or below peak memory. +- R12 cannot close until mixed Axon+sLSTM rows use the same temporal engine and + beat matched same-parameter stack/MoE baselines. +- R13 cannot close until T=4096 and frontier T=16K where feasible, H=64, K up + to 128, terminal and per-timestep losses all pass the matched T=1/K criteria + through normal autograd. +- R14 cannot close until h reduction and factorization spread are measured on + the fixed shared backend. +- R15 is broad cleanup after backend proof: remove legacy execution paths, + `Config`/`anatomy` ownership leaks, lattice facts from generic Fabric, + message-rule hardcoding, public Config facade dependence, benchmark wrappers, + hidden-size/cell-family route policy, and direct message/cell math outside + declared `fabric.cuda.nn` primitives. + +Next concrete action: + +- Run the owner-timed T=1 representative throughput row on GPU 0 with private + caches, then choose the next code slice from the largest measured backend + owner. No dry-run case list counts as progress. + +### 2026-04-29 UTC - Owner-timed T=1 throughput probe + +Status: FAILED GATE; ACTIONABLE OWNER FOUND. + +Owner: R4 shared temporal training backward and artifact/recompute policy. + +- Command: + `CORTICAL_FABRIC_BACKWARD_OWNER_TIMING=1 CUDA_VISIBLE_DEVICES=0 TORCH_EXTENSIONS_DIR=/tmp/cortical_torch_ext_${USER}_redo_t1_slstm500m_b512_h32_seq TRITON_CACHE_DIR=/tmp/cortical_triton_${USER}_redo_t1_slstm500m_b512_h32_seq uv run python -m benchmarks.fabric.run_audit --plan t1-single-pop --out-dir /tmp/redo_fixmaass_t1_slstm500m_b512_h32_sequence_owner_timing --baseline-json audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json --families slstm --sizes 500m --modes forward_backward --batches 512 --seq-lens 1 --inner-steps 1 --gradient-horizon-steps none --checkpoint-steps none --hidden-sizes 32 --training-output-boundaries sequence --population-modes single --reset-modes absent --warmup 0 --iterations 1 --enforce-references` +- Result artifact: + `/tmp/redo_fixmaass_t1_slstm500m_b512_h32_sequence_owner_timing/cases.jsonl`. +- Measured result: + `52.781 tok/s`, `9700.488 ms`, peak `125.569 GiB`. Gate failed against + April 21 `streaming_per_timestep_sequence_loss` floor of `95,638.46 tok/s`. +- Owner timing, GPU-event order: + `temporal_artifact_recompute=4096.730 ms`, + `artifact.recompute.cuda_temporal_replay_scan=4027.582 ms`, + `state_epilogue.core=1029.082 ms`, + `state_epilogue.gated_core_recurrent_affine_window=1029.064 ms`, + `message.recurrent_initial_kv_backward=252.110 ms`, + then smaller receiver/message/projection/binding owners. +- Owner timing, wall order: + `temporal_artifact_recompute=4093.185 ms`, + `artifact.recompute.recurrent_hidden_before_window=4065.914 ms`, + `receiver_affine.gate_affine_backward=1033.011 ms`, + `public_projection=253.992 ms`. +- Metadata still reports + `temporal_plan_backward_owners=["python_autograd_scan"]`, + `flat_bucket_temporal_reverse_scan_owner:python_host_reverse_loop`, and + `temporal_artifacts:recompute_step_artifacts`. 
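+
+For orientation, a minimal sketch of a per-owner collector that yields both of
+the rankings reported above (GPU-event order and wall order). The class and
+method names are illustrative, not the backend's actual owner-timing
+collector; it assumes a CUDA device and synchronizes per region, which is
+simpler but more intrusive than the real instrumentation needs to be:
+
+```python
+# Hedged sketch of per-owner GPU-event vs wall-clock timing. Illustrative only.
+import time
+from collections import defaultdict
+from contextlib import contextmanager
+
+import torch
+
+
+class OwnerTimer:
+    def __init__(self) -> None:
+        self.gpu_ms: dict[str, float] = defaultdict(float)
+        self.wall_ms: dict[str, float] = defaultdict(float)
+        self.count: dict[str, int] = defaultdict(int)
+
+    @contextmanager
+    def timed(self, owner: str):
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        wall_start = time.perf_counter()
+        start.record()
+        try:
+            yield
+        finally:
+            end.record()
+            torch.cuda.synchronize()  # make elapsed_time() valid before reading
+            self.gpu_ms[owner] += start.elapsed_time(end)
+            self.wall_ms[owner] += (time.perf_counter() - wall_start) * 1000.0
+            self.count[owner] += 1
+
+    @staticmethod
+    def ranked(table: dict[str, float]) -> list[tuple[str, float]]:
+        return sorted(table.items(), key=lambda item: item[1], reverse=True)
+
+
+timer = OwnerTimer()
+x = torch.randn(4096, 4096, device="cuda")
+with timer.timed("temporal_artifact_recompute"):
+    x = x @ x  # stand-in for the owned GPU work
+print(timer.ranked(timer.gpu_ms))   # GPU-event order
+print(timer.ranked(timer.wall_ms))  # wall order; host overhead can reorder owners
+```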
+ +Deep-dive interpretation: + +- The T=1 base training row is paying a full CUDA replay/recompute bridge even + though T=1 is the base streaming case. This is throughput poison and also + explains why large T/H cannot be interpreted scientifically yet. +- The policy-level helper `_temporal_artifact_store_policy` has an explicit + `time_steps<=1` store-artifacts guard, but the active CUDA temporal superop + path in `_try_cuda_mixed_flat_bucket_temporal_scan` always returns a + `recompute_step_artifacts` `TemporalArtifactStore` when `collect_artifacts` + is true. That bypasses the T=1 store guard and forces the backward executor + through artifact replay. +- The next code slice must not be a planner relabel. It should make the CUDA + temporal forward/training path expose the required backward artifact/tape + tables for T=1 and selected bounded windows, or otherwise let the CUDA + temporal backward consume the forward scan's tensor-table outputs directly, + without a host-owned replay bridge. This must remain generic over flat bucket + tensor/op tables. +- Boundary condition: + do not encode sLSTM/Axon, single-pop, lattice, hidden-size, or benchmark row + logic in this fix. The allowed design surface is tensor roles, op rows, + reset/emission/checkpoint policy, flat bucket identity, and primitive tables. + +Immediate code target: + +1. Add a generic CUDA-scan artifact-store path for `time_steps==1` and for + planner-selected stored windows: + request recurrent message/output message/transition tape artifacts from the + CUDA temporal scan only when planner policy says they are needed and memory + budget allows. +2. Convert those CUDA scan outputs into `TemporalBucketStepArtifacts` or a + slimmer tensor-table payload consumed directly by + `TemporalPhysicalBackwardScanExecutor`, without calling + `_recompute_temporal_bucket_artifact_window` for the T=1 base case. +3. Rerun strict T=1 parity and reset parity, then rerun the owner-timed T=1 + throughput row. The first pass target is to remove the 4.1s + `temporal_artifact_recompute` owner and reduce peak memory, before attacking + the 1.0s state-epilogue backward owner. + +### 2026-04-29 UTC - REDO2 scalability correction + +Status: ACTIVE PLAN MOVED TO `ai_docs/REDO2_FIXMASS.md`. + +Owner: REDO2 planning / audit matrix. + +- User correction: + T/H closure cannot be a few representative rows. Every shape and size used + for T=1 closure must also be run through T/H closure so Fabric proves + scalability instead of overfitting one row. +- Action: + added `ai_docs/REDO2_FIXMASS.md` as the clean throughput-first plan. REDO2 + explicitly makes the T=1 matrix the parent matrix for T/H, K/H, hidden-size, + factorization, single-pop, mixed-pop, reset, terminal, and per-timestep + closure. +- Durable rule: + representative probes steer implementation, but final T/H/K closure inherits + the full T=1 shape/size matrix unless a row is explicitly inapplicable and + documented with a backend-owned reason. 
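+
+As a concrete reading of that durable rule, a minimal sketch of deriving T/H
+closure rows from the T=1 parent matrix (case and field names are
+illustrative, not the REDO2 plan's real matrix; skipped rows must carry a
+documented backend-owned reason):
+
+```python
+# Hedged sketch: expand every T=1 closure case across the T/K/H axes used in
+# this log (T=512/4096, K=1/2/8/32/128, H=64), recording explicit omissions.
+from itertools import product
+
+T_AXIS = (512, 4096)
+K_AXIS = (1, 2, 8, 32, 128)
+H_AXIS = (64,)
+
+t1_parent_cases = [
+    {"family": "slstm", "size": "500m", "batch": 512, "hidden": 32, "loss": "sequence"},
+    {"family": "slstm", "size": "500m", "batch": 512, "hidden": 32, "loss": "terminal"},
+]
+
+
+def expand_th_closure_matrix(parent_cases, inapplicable=()):
+    """Cross the T=1 parent cases with T/K/H; record skips, never drop them silently."""
+    skipped = dict(inapplicable)  # {case_key: backend-owned reason}
+    rows, omissions = [], []
+    for base, t, k, h in product(parent_cases, T_AXIS, K_AXIS, H_AXIS):
+        key = (base["family"], base["size"], base["batch"], base["hidden"], base["loss"], t, k, h)
+        if key in skipped:
+            omissions.append({"case": key, "reason": skipped[key]})
+            continue
+        rows.append({**base, "seq_len": t, "inner_steps": k, "gradient_horizon": h})
+    return rows, omissions
+
+
+rows, omissions = expand_th_closure_matrix(t1_parent_cases)
+print(len(rows), len(omissions))  # 2 parent cases x 2 T x 5 K x 1 H = 20 rows, 0 skips
+```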
diff --git a/ai_docs/recovered_core.py b/ai_docs/recovered_core.py new file mode 100644 index 00000000..1e10fe55 --- /dev/null +++ b/ai_docs/recovered_core.py @@ -0,0 +1,5999 @@ +from __future__ import annotations + +import math +import os +from collections.abc import Callable, Iterator, Mapping +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass, replace +from typing import Any, Literal, Optional, cast + +import torch +import torch.nn as nn +from tensordict import TensorDict, TensorDictBase +from torch.utils.checkpoint import checkpoint + +from cortical.fabric.anatomy import Spec +from cortical.fabric.backend.caps import DeviceCaps, detect_device_caps +from cortical.fabric.backend.cell_backend import build_cell_backend_spec, trace_elision_core_state_names +from cortical.fabric.backend.cuda.projection.receiver_major_gates import ( + ReceiverMajorProjectionGate, + receiver_major_projection_backward_gate, +) +from cortical.fabric.backend.cuda.sequence_surface.runtime.backward_helpers import SharedTemporalBackwardHelperMixin +from cortical.fabric.backend.cuda.sequence_surface.runtime.policy import ( + CudaMemoryBudget, + LayoutBatchTileInputs, + PolicyDecision, + backward_batch_tile_policy, + flat_bucket_sequence_readout_batch_tile_policy, + forward_batch_tile_policy, + forward_layout_batch_tile_policy, + readout_pooled_batch_tile_policy, + tape_checkpoint_policy, + tape_memory_chunk_len, +) +from cortical.fabric.backend.cuda.sequence_surface.runtime.support import ( + _BackendOwnerTimingCollector, + _transition_supports_receiver_local_dependency_window, +) +from cortical.fabric.backend.cuda.sequence_surface.runtime.executor import ( + _active_region_record_mode, + _flat_bucket_active_output_region_for_inner_steps, + _resolve_temporal_plan_active_region, + execute_temporal_bucket_active_output_window, + execute_temporal_bucket_sequence, + record_temporal_bucket_sequence_surface_execution, +) +from cortical.fabric.backend.graph_regions import ( + ClosedRecurrentRegion, + close_recurrent_region_from_sender_tables, + contiguous_recurrent_region, + recurrent_sender_seed_from_table, +) +from cortical.fabric.backend.ir import compile_fabric_ir +from cortical.fabric.backend.plan_cache import FabricGraphCaptureCache, GraphCaptureCacheKey +from cortical.fabric.backend.planner import ( + FabricExecutionPlanner, + PlannedFabricBackwardExecution, + PlannedFabricExecution, + TemporalExecutionPlan, +) +from cortical.fabric.backend.pytorch.readout import ( + ReadoutConfig, +) +from cortical.fabric.backend.pytorch.readout import ( + pool_output_ports as backend_pool_output_ports, +) +from cortical.fabric.backend.pytorch.readout import ( + readout_output_cells as backend_readout_output_cells, +) +from cortical.fabric.backend.pytorch.readout import ( + select_output_cells as backend_select_output_cells, +) +from cortical.fabric.backend.runtime_dispatch import BackendRuntimeDispatchMixin +from cortical.fabric.backend.selector import select_fabric_backend +from cortical.fabric.backend.surfaces import BackendExecutionRecord +from cortical.fabric.backend.tape import TapeMode, TapePolicy, default_tape_policy +from cortical.fabric.backend.workspace import GraphCaptureWorkspace +from cortical.fabric.cells import build_cell_population_module +from cortical.fabric.contracts.cells import reset_backend_state_rows, reset_backend_tensor_rows +from cortical.fabric.registry.cells import get_cell_spec +from cortical.fabric.runtime.model_temporal import ( + ModelTemporalMixin, + 
_expand_resets_for_time, + _flatten_tensordict, + _ModelOutputChunkConsumer, + _slice_sequence_k, + _unflatten_tensordict, +) +from cortical.fabric.runtime.state import ( + flatten_backend_packed_state as _flatten_backend_packed_state, +) +from cortical.fabric.runtime.state import ( + unflatten_backend_packed_state as _unflatten_backend_packed_state, +) +from cortical.types import MaybeState, ResetMask, Tensor + +_RuntimeOutputChunkConsumer = Callable[[torch.Tensor, int, int], None] +_BACKWARD_OWNER_TIMING_ENV = "cortical_FABRIC_BACKWARD_OWNER_TIMING" + + +@contextmanager +def _preserve_backend_execution_record(runtime: "Runtime"): + runtime._preserve_backend_execution_record_depth += 1 + try: + yield + finally: + runtime._preserve_backend_execution_record_depth -= 1 + + +def _population_display_name(population_name: str) -> str: + return population_name + + +@dataclass(frozen=True) +class _BackendGraphInputLayout: + input_names: tuple[str, ...] + packed_state_keys: tuple[str, ...] | None + packed_state_input_names: tuple[str, ...] + packed_state_shapes: tuple[tuple[int, ...], ...] + packed_state_is_fresh: bool = False + + +class Runtime(BackendRuntimeDispatchMixin, SharedTemporalBackwardHelperMixin, nn.Module): + def __init__(self, spec: Spec) -> None: + super().__init__() + self.spec = spec + self.config = spec.config + self.hidden_size = int(spec.config.hidden_size) + self.num_heads = int(spec.config.num_heads) + self.head_dim = int(spec.config.head_dim) + self.value_dim = int(spec.config.head_dim) + self._has_edge_delay = spec.anatomy.edge_delay is not None + if spec.config.readout_pool == "mean": + self.readout_slots = 1 + elif spec.config.readout_pool == "flatten": + self.readout_slots = int(spec.output_cell_idx.numel()) + else: + self.readout_slots = int(spec.config.readout_slots) + self._population_names = spec.population_names + self._population_name_to_idx = {name: idx for idx, name in enumerate(self._population_names)} + self._population_cell_types = { + name: self.config.cell_populations[name].cell_type for name in self._population_names + } + + self.register_buffer("cell_layout", spec.anatomy.cell_layout.clone()) + self.register_buffer("neighbor_idx", spec.anatomy.neighbor_idx.clone()) + self.register_buffer("neighbor_idx_flat", spec.anatomy.neighbor_idx.reshape(-1).clone()) + self.register_buffer("neighbor_valid", spec.anatomy.neighbor_valid.clone()) + self.register_buffer("edge_type", spec.anatomy.edge_type.clone()) + self.register_buffer("edge_distance", spec.anatomy.edge_distance.clone()) + self.register_buffer( + "edge_delay", + spec.anatomy.edge_delay.clone() + if spec.anatomy.edge_delay is not None + else torch.ones_like(spec.anatomy.edge_type), + ) + self.register_buffer("kv_group_id", spec.kv_group_id.clone()) + self.register_buffer("recurrent_cell_idx", spec.recurrent_cell_idx.clone()) + self.register_buffer("input_cell_idx", spec.input_cell_idx.clone()) + self.register_buffer("output_cell_idx", spec.output_cell_idx.clone()) + self.register_buffer("coords", spec.anatomy.coords.clone()) + self.register_buffer("local_offsets", spec.anatomy.local_offsets.clone()) + self.register_buffer("local_valid", spec.anatomy.local_valid.clone()) + self.register_buffer("local_distance", spec.anatomy.local_distance.clone()) + self.register_buffer( + "local_delay", + spec.anatomy.local_delay.clone() + if spec.anatomy.local_delay is not None + else torch.zeros(spec.anatomy.local_distance.shape[0], dtype=torch.int32), + ) + self._num_neighbors = int(spec.anatomy.neighbor_idx.shape[1]) + 
sender_mask = torch.ones(spec.anatomy.num_cells, dtype=torch.bool) + sender_mask[spec.output_cell_idx] = False + sender_cell_idx = torch.nonzero(sender_mask, as_tuple=False).reshape(-1) + sender_lookup = torch.full((spec.anatomy.num_cells,), -1, dtype=torch.long) + sender_lookup[sender_cell_idx] = torch.arange(sender_cell_idx.numel(), dtype=torch.long) + self.register_buffer("sender_cell_idx", sender_cell_idx) + self.register_buffer("sender_lookup", sender_lookup) + self.register_buffer("input_sender_idx", sender_lookup[self.input_cell_idx].clone()) + recurrent_lookup = torch.full((spec.anatomy.num_cells,), -1, dtype=torch.long) + recurrent_lookup[self.recurrent_cell_idx] = torch.arange(self.recurrent_cell_idx.numel(), dtype=torch.long) + self.register_buffer("recurrent_lookup", recurrent_lookup) + self.register_buffer("recurrent_sender_idx", sender_lookup[self.recurrent_cell_idx].clone()) + full_local_sender_idx = _build_local_sender_table( + receiver_coords=spec.anatomy.coords, + sender_lookup=sender_lookup, + local_offsets=spec.anatomy.local_offsets, + local_valid=spec.anatomy.local_valid, + coord_shape=tuple(int(size) for size in spec.config.coord_shape), + wrap=bool(spec.config.wrap), + ) + self.register_buffer("full_local_sender_idx", full_local_sender_idx.to(torch.int32)) + self.register_buffer( + "full_local_receiver_idx_by_sender", + _build_sender_reverse_table(int(sender_cell_idx.numel()), full_local_sender_idx, self.local_valid), + ) + full_kv_group_ids, self._full_kv_group_size = _detect_uniform_contiguous_groups(spec.kv_group_id) + self._full_kv_group_range = _contiguous_cpu_index_range(full_kv_group_ids) + self.register_buffer( + "full_kv_group_ids", + full_kv_group_ids if full_kv_group_ids is not None else torch.empty(0, dtype=torch.long), + ) + sender_kv_group_ids, self._sender_kv_group_size = _detect_uniform_contiguous_groups( + spec.kv_group_id.index_select(0, sender_cell_idx) + ) + self._sender_kv_group_range = _contiguous_cpu_index_range(sender_kv_group_ids) + self.register_buffer( + "sender_kv_group_ids", + sender_kv_group_ids if sender_kv_group_ids is not None else torch.empty(0, dtype=torch.long), + ) + recurrent_sender_group_ids, self._recurrent_sender_kv_group_size = _detect_uniform_contiguous_groups( + spec.kv_group_id.index_select(0, self.recurrent_cell_idx) + ) + self._recurrent_sender_kv_group_range = _contiguous_cpu_index_range(recurrent_sender_group_ids) + self.register_buffer( + "recurrent_sender_kv_group_ids", + recurrent_sender_group_ids if recurrent_sender_group_ids is not None else torch.empty(0, dtype=torch.long), + ) + input_sender_group_ids, self._input_sender_kv_group_size = _detect_uniform_contiguous_groups( + spec.kv_group_id.index_select(0, self.input_cell_idx) + ) + self._input_sender_kv_group_range = _contiguous_cpu_index_range(input_sender_group_ids) + self.register_buffer( + "input_sender_kv_group_ids", + input_sender_group_ids if input_sender_group_ids is not None else torch.empty(0, dtype=torch.long), + ) + self._num_input_cells = int(self.input_cell_idx.numel()) + self._num_recurrent_cells = int(self.recurrent_cell_idx.numel()) + self._num_output_cells = int(self.output_cell_idx.numel()) + num_senders = int(sender_cell_idx.numel()) + self._partitioned_layout = bool( + torch.equal(self.input_cell_idx, torch.arange(self._num_input_cells, dtype=torch.long)) + and torch.equal( + self.recurrent_cell_idx, + torch.arange( + self._num_input_cells, + self._num_input_cells + self._num_recurrent_cells, + dtype=torch.long, + ), + ) + and torch.equal( + 
self.output_cell_idx, + torch.arange(num_senders, num_senders + self._num_output_cells, dtype=torch.long), + ) + ) + self._input_slice = slice(0, self._num_input_cells) + self._recurrent_slice = slice(self._num_input_cells, self._num_input_cells + self._num_recurrent_cells) + self._output_slice = slice(num_senders, num_senders + self._num_output_cells) + ( + recurrent_neighbor_idx, + recurrent_neighbor_valid, + recurrent_edge_distance, + recurrent_edge_delay, + ) = _select_receiver_tables( + self.neighbor_idx, + self.neighbor_valid, + self.edge_distance, + self.edge_delay, + self.recurrent_cell_idx, + self.sender_lookup, + ) + self.register_buffer("recurrent_neighbor_idx", recurrent_neighbor_idx) + self.register_buffer("recurrent_neighbor_valid", recurrent_neighbor_valid) + self.register_buffer("recurrent_edge_distance", recurrent_edge_distance) + self.register_buffer("recurrent_edge_delay", recurrent_edge_delay) + recurrent_sparse_receiver_order, recurrent_sparse_degree_ptr, recurrent_sparse_positive_degree_buckets = ( + _build_sparse_degree_grouping(recurrent_neighbor_valid) + ) + self.register_buffer("recurrent_sparse_receiver_order", recurrent_sparse_receiver_order) + self.register_buffer("recurrent_sparse_degree_ptr", recurrent_sparse_degree_ptr) + self._recurrent_sparse_positive_degree_buckets = recurrent_sparse_positive_degree_buckets + self.register_buffer("recurrent_local_valid", self.local_valid.index_select(0, self.recurrent_cell_idx)) + recurrent_local_sender_idx = _build_local_sender_table( + receiver_coords=spec.anatomy.coords.index_select(0, self.recurrent_cell_idx), + sender_lookup=sender_lookup, + local_offsets=spec.anatomy.local_offsets, + local_valid=self.recurrent_local_valid, + coord_shape=tuple(int(size) for size in spec.config.coord_shape), + wrap=bool(spec.config.wrap), + ) + self.register_buffer("recurrent_local_sender_idx", recurrent_local_sender_idx.to(torch.int32)) + self.register_buffer( + "recurrent_local_receiver_idx_by_sender", + _build_sender_reverse_table( + int(sender_cell_idx.numel()), + recurrent_local_sender_idx, + self.recurrent_local_valid, + ), + ) + ( + output_neighbor_idx, + output_neighbor_valid, + output_edge_distance, + output_edge_delay, + ) = _select_receiver_tables( + self.neighbor_idx, + self.neighbor_valid, + self.edge_distance, + self.edge_delay, + self.output_cell_idx, + self.sender_lookup, + ) + self.register_buffer("output_neighbor_idx", output_neighbor_idx) + self.register_buffer("output_neighbor_valid", output_neighbor_valid) + self.register_buffer("output_edge_distance", output_edge_distance) + self.register_buffer("output_edge_delay", output_edge_delay) + self.register_buffer("output_local_valid", self.local_valid.index_select(0, self.output_cell_idx)) + output_local_sender_idx = _build_local_sender_table( + receiver_coords=spec.anatomy.coords.index_select(0, self.output_cell_idx), + sender_lookup=sender_lookup, + local_offsets=spec.anatomy.local_offsets, + local_valid=self.output_local_valid, + coord_shape=tuple(int(size) for size in spec.config.coord_shape), + wrap=bool(spec.config.wrap), + ) + self.register_buffer("output_local_sender_idx", output_local_sender_idx.to(torch.int32)) + self.register_buffer( + "output_local_receiver_idx_by_sender", + _build_sender_reverse_table( + int(sender_cell_idx.numel()), + output_local_sender_idx, + self.output_local_valid, + ), + ) + ( + self._output_local_recurrent_window_start, + self._output_local_recurrent_window_count, + self._output_local_recurrent_window_contiguous, + ) = 
_contiguous_recurrent_sender_window( + num_senders=int(sender_cell_idx.numel()), + recurrent_sender_idx=self.recurrent_sender_idx, + receiver_sender_idx=output_local_sender_idx, + receiver_valid=self.output_local_valid, + ) + ( + self._output_sparse_recurrent_window_start, + self._output_sparse_recurrent_window_count, + self._output_sparse_recurrent_window_contiguous, + ) = _contiguous_recurrent_sender_window( + num_senders=int(sender_cell_idx.numel()), + recurrent_sender_idx=self.recurrent_sender_idx, + receiver_sender_idx=output_neighbor_idx, + receiver_valid=output_neighbor_valid, + ) + self._local_message_step_enabled = bool( + spec.config.patch_edges_per_cell == 0 + and int(self.local_offsets.shape[0]) > 0 + and int(self.coords.shape[1]) <= 3 + ) + self._uses_sparse_message_backend = bool( + not self._local_message_step_enabled + or bool((spec.anatomy.edge_type[spec.anatomy.neighbor_valid] != 0).any().item()) + ) + self._coord_shape = tuple(int(size) for size in spec.config.coord_shape) + + self.slot_embed = nn.Parameter(spec.slot_init.clone()) + + self.public_proj = nn.Linear(self.hidden_size, int(self.config.d_public), bias=False) + self.input_proj = nn.Linear(self.hidden_size, int(self.config.d_msg), bias=False) + self.msg_to_cell = nn.Linear(int(self.config.d_msg), self.hidden_size, bias=False) + self.cell_bias_proj = nn.Linear(int(self.config.d_slot), self.hidden_size, bias=False) + self.q_proj = nn.Linear(int(self.config.d_slot), self.num_heads * self.head_dim, bias=False) + self.k_weight = nn.Parameter( + torch.empty(spec.num_kv_groups, int(self.config.d_public), self.num_heads * self.head_dim) + ) + self.v_weight = nn.Parameter( + torch.empty(spec.num_kv_groups, int(self.config.d_public), self.num_heads * self.value_dim) + ) + self.msg_out = nn.Linear(self.num_heads * self.value_dim, int(self.config.d_msg), bias=False) + self.output_cell_weight = nn.Parameter( + torch.empty(int(self.output_cell_idx.numel()), int(self.config.d_msg), self.hidden_size) + ) + self.output_cell_bias = nn.Parameter(torch.empty(int(self.output_cell_idx.numel()), self.hidden_size)) + self.readout_query = nn.Parameter(torch.empty(self.readout_slots, self.hidden_size)) + self.readout_out = nn.Linear(self.readout_slots * self.hidden_size, self.hidden_size) + + self.population_modules = nn.ModuleDict() + self._full_recurrent_population_name: str | None = None + for name in self._population_names: + indices = self._build_population_indices(name) + self.register_buffer(_population_buffer_name(name), indices) + population_recurrent_idx = self.recurrent_lookup.index_select(0, indices) + population_recurrent_idx = population_recurrent_idx[population_recurrent_idx >= 0] + self.register_buffer(_population_recurrent_buffer_name(name), population_recurrent_idx) + if population_recurrent_idx.numel() == self._num_recurrent_cells and torch.equal( + population_recurrent_idx, torch.arange(self._num_recurrent_cells, dtype=torch.long) + ): + self._full_recurrent_population_name = name + self.population_modules[name] = build_cell_population_module( + self.config.cell_populations[name], + self.hidden_size, + num_cells=int(indices.numel()), + init_noise_std=float(self.config.population_init_noise_std), + ) + self._register_population_backend_order_buffers() + self._backend_ir = compile_fabric_ir( + spec, + hidden_size=self.hidden_size, + d_public=int(self.config.d_public), + d_msg=int(self.config.d_msg), + head_dim=self.head_dim, + value_dim=self.value_dim, + ) + self._backend_population_specs = { + name: 
build_cell_backend_spec( + cell_type=self._population_cell_types[name], + hidden_size=self.hidden_size, + d_public=int(self.config.d_public), + d_msg=int(self.config.d_msg), + head_dim=self.head_dim, + value_dim=self.value_dim, + ) + for name in self._population_names + } + self._backend_planner = FabricExecutionPlanner( + ir=self._backend_ir, + population_specs=self._backend_population_specs, + active_output_region=ClosedRecurrentRegion( + indices=tuple(int(index) for index in self._flat_bucket_active_output_region_indices), + full_count=int(self._num_recurrent_cells), + ), + output_closure_region=ClosedRecurrentRegion( + indices=tuple(int(index) for index in self._flat_bucket_output_recurrent_closure_indices), + full_count=int(self._num_recurrent_cells), + ), + output_sender_table=( + self.output_neighbor_idx + if bool(getattr(self, "_uses_sparse_message_backend", False)) + else self.output_local_sender_idx + ), + output_sender_valid=( + self.output_neighbor_valid + if bool(getattr(self, "_uses_sparse_message_backend", False)) + else self.output_local_valid + ), + recurrent_sender_table=( + self.recurrent_neighbor_idx + if bool(getattr(self, "_uses_sparse_message_backend", False)) + else self.recurrent_local_sender_idx + ), + recurrent_sender_valid=( + self.recurrent_neighbor_valid + if bool(getattr(self, "_uses_sparse_message_backend", False)) + else self.recurrent_local_valid + ), + num_input_senders=int(self._num_input_cells), + ) + self._backend_device_caps_cache: dict[tuple[str, int], DeviceCaps] = {} + self._backend_graph_capture_cache = FabricGraphCaptureCache() + self._last_backend_execution: BackendExecutionRecord | None = None + self._last_backend_launch_metadata: dict[str, tuple[Any, ...]] | None = None + self._last_backend_tape_chunk_len: int | None = None + self._last_backend_tape_chunk_reason: str | None = None + self._last_backend_tape_artifact_mode: str | None = None + self._last_backend_recompute_artifact_window_len: int | None = None + self._last_backend_recompute_artifact_window_reason: str | None = None + self._last_backend_recompute_checkpoint_stride: int | None = None + self._last_backend_recompute_checkpoint_count: int | None = None + self._last_backend_recompute_checkpoint_reason: str | None = None + self._last_backend_recompute_predecessor_cache_mode: str | None = None + self._last_backend_recompute_transition_tape_mode: str | None = None + self._last_backend_recompute_transition_tape_reason: str | None = None + self._last_backend_backward_batch_tile_len: int | None = None + self._last_backend_backward_batch_tile_reason: str | None = None + self._last_backend_forward_batch_tile_len: int | None = None + self._last_backend_forward_batch_tile_reason: str | None = None + self._preserve_backend_execution_record_depth: int = 0 + self._active_backend_name: str = "pytorch" + self._constant_step_flat_cache: dict[tuple[str, int, int, int, str, int], torch.Tensor] = {} + self._inference_static_cache: dict[tuple[object, ...], dict[str, object]] = {} + self._training_static_cache: dict[tuple[object, ...], dict[str, object]] = {} + + self._reset_parameters() + + def _clear_execution_caches(self) -> None: + self._clear_inference_static_cache() + self._training_static_cache.clear() + self._constant_step_flat_cache.clear() + self._backend_graph_capture_cache.clear() + self._clear_active_output_temporal_graph_caches() + direct_sender_kv_group_id_cache = getattr(self, "_direct_sender_kv_group_id_cache", None) + if direct_sender_kv_group_id_cache is not None: + 
direct_sender_kv_group_id_cache.clear() + + def _clear_active_output_temporal_graph_caches(self) -> None: + self._active_output_temporal_graph_static_parameter_versions = None + active_output_adjoint_graph_cache = getattr( + self, + "_active_output_recurrent_adjoint_cuda_graph_cache", + None, + ) + if active_output_adjoint_graph_cache is not None: + active_output_adjoint_graph_cache.clear() + active_output_forward_graph_cache = getattr( + self, + "_active_output_forward_recompute_cuda_graph_cache", + None, + ) + if active_output_forward_graph_cache is not None: + active_output_forward_graph_cache.clear() + + def _clear_training_materialization_caches(self) -> None: + self._training_static_cache.clear() + self._constant_step_flat_cache.clear() + self._clear_active_output_temporal_graph_caches() + + def _mark_active_output_temporal_graph_static_parameter_versions( + self, + parameter_versions: tuple[int, ...], + ) -> None: + previous_versions = getattr( + self, + "_active_output_temporal_graph_static_parameter_versions", + None, + ) + if previous_versions is not None and tuple(previous_versions) != parameter_versions: + self._clear_active_output_temporal_graph_caches() + self._active_output_temporal_graph_static_parameter_versions = parameter_versions + + def _reset_parameters(self) -> None: + self._clear_execution_caches() + nn.init.xavier_uniform_(self.public_proj.weight) + nn.init.xavier_uniform_(self.input_proj.weight) + nn.init.xavier_uniform_(self.msg_to_cell.weight) + nn.init.normal_(self.cell_bias_proj.weight, mean=0.0, std=0.02) + nn.init.xavier_uniform_(self.q_proj.weight) + nn.init.xavier_uniform_(self.k_weight) + nn.init.xavier_uniform_(self.v_weight) + nn.init.xavier_uniform_(self.msg_out.weight) + nn.init.xavier_uniform_(self.output_cell_weight) + nn.init.zeros_(self.output_cell_bias) + nn.init.normal_(self.readout_query, mean=0.0, std=1.0 / math.sqrt(max(1, self.hidden_size))) + nn.init.xavier_uniform_(self.readout_out.weight) + if self.readout_out.bias is not None: + nn.init.zeros_(self.readout_out.bias) + + def _clear_inference_static_cache(self) -> None: + self._inference_static_cache.clear() + + @property + def backend_ir(self): + return self._backend_ir + + @property + def backend_population_specs(self): + return self._backend_population_specs + + @property + def last_backend_execution(self) -> BackendExecutionRecord | None: + return self._last_backend_execution + + @property + def graph_capture_cache_stats(self) -> dict[str, int]: + return self._backend_graph_capture_cache.stats() + + def _cell_spec_for_population(self, population_name: str): + return get_cell_spec(self._population_cell_types[population_name]) + + def _transition_core_state_names_for_population(self, population_name: str) -> tuple[str, ...] 
| None: + return trace_elision_core_state_names(self._backend_population_specs[population_name]) + + def _backend_spec_for_cell_type(self, cell_type: str): + for population_name, spec in self._backend_population_specs.items(): + if self._population_cell_types[population_name] == cell_type: + return spec + raise ValueError(f"Fabric cell type {cell_type} is not present in this runtime") + + def describe_backend(self) -> dict[str, object]: + return { + "num_cells": self._backend_ir.num_cells, + "num_recurrent_cells": self._backend_ir.num_recurrent_cells, + "num_input_ports": self._backend_ir.num_input_ports, + "num_output_ports": self._backend_ir.num_output_ports, + "bucket_count": self._backend_ir.bucket_count, + "delay_depth": self._backend_ir.delay_depth, + "kv_group_count": self._backend_ir.kv_group_count, + "population_names": self._backend_ir.population_names, + } + + def plan_backend_execution( + self, + *, + batch_size: int, + time_steps: int, + inner_steps: int, + training: bool, + tape_policy: TapePolicy | None = None, + device: torch.device | None = None, + surface_key: str | None = None, + ) -> PlannedFabricExecution: + if tape_policy is None: + tape_policy = default_tape_policy(training) + plan_device = self.coords.device if device is None else torch.device(device) + device_caps = self._get_backend_device_caps(plan_device) + if surface_key is not None: + raise ValueError(f"Unknown backend surface {surface_key}") + return self._backend_planner.plan_execution( + batch_size=batch_size, + time_steps=time_steps, + inner_steps=inner_steps, + training=training, + device_caps=device_caps, + tape_policy=tape_policy, + supported_variants=None, + ) + + def plan_backend_backward_execution( + self, + *, + batch_size: int, + time_steps: int, + inner_steps: int, + training: bool, + tape_policy: TapePolicy | None = None, + device: torch.device | None = None, + surface_key: str | None = None, + ) -> PlannedFabricBackwardExecution: + if tape_policy is None: + tape_policy = default_tape_policy(training) + plan_device = self.coords.device if device is None else torch.device(device) + device_caps = self._get_backend_device_caps(plan_device) + if surface_key is not None: + raise ValueError(f"Unknown backend surface {surface_key}") + return self._backend_planner.plan_backward_execution( + batch_size=batch_size, + time_steps=time_steps, + inner_steps=inner_steps, + training=training, + device_caps=device_caps, + tape_policy=tape_policy, + supported_variants=None, + ) + + @staticmethod + def _tape_policy_from_bin(tape_policy_bin: str) -> TapePolicy | None: + if tape_policy_bin == "none": + return None + return TapePolicy(mode=TapeMode(tape_policy_bin)) + + def _should_use_backend_graph_capture( + self, + *, + plan: PlannedFabricExecution, + device: torch.device, + grad_path: bool, + time_steps: int, + ) -> bool: + if not ( + device.type == "cuda" + and not torch.cuda.is_current_stream_capturing() + and plan.workspace_plan.graph_capture_stable + and all(bucket_plan.graph_capture_enabled for bucket_plan in plan.bucket_plans) + ): + return False + if not grad_path: + return bool(time_steps > 1) + return bool(time_steps > 1 and plan.tape_policy_bin in {"full_save", "checkpoint"}) + + def _should_use_backend_tape_policy( + self, + *, + plan: PlannedFabricExecution, + grad_path: bool, + time_steps: int, + ) -> bool: + return bool(grad_path and plan.tape_policy_bin in {"full_save", "checkpoint"} and time_steps > 1) + + def _backend_tape_checkpoint_chunk_len( + self, + *, + plan: PlannedFabricExecution, + 
time_steps: int, + output_boundary: Literal["sequence", "terminal"] = "sequence", + boundary_seq: torch.Tensor | None = None, + packed_state: Any | None = None, + initial_hidden: torch.Tensor | None = None, + initial_recurrent_k: torch.Tensor | None = None, + initial_recurrent_v: torch.Tensor | None = None, + ) -> int: + checkpoint_t = plan.workspace_plan.checkpoint_t + if boundary_seq is None or initial_hidden is None: + policy = tape_checkpoint_policy( + time_steps=int(time_steps), + checkpoint_t=checkpoint_t, + output_boundary=output_boundary, + estimated_step_bytes=None, + memory=None, + ) + return int(policy.chunk_len) + estimated_step_bytes = self._estimate_backend_tape_step_bytes( + boundary_seq=boundary_seq, + packed_state=packed_state, + initial_hidden=initial_hidden, + initial_recurrent_k=initial_recurrent_k, + initial_recurrent_v=initial_recurrent_v, + ) + policy = tape_checkpoint_policy( + time_steps=int(time_steps), + checkpoint_t=checkpoint_t, + output_boundary=output_boundary, + estimated_step_bytes=estimated_step_bytes, + memory=self._cuda_memory_budget(boundary_seq.device), + ) + self._last_backend_tape_chunk_len = int(policy.chunk_len) + self._last_backend_tape_artifact_mode = policy.artifact_mode + self._last_backend_tape_chunk_reason = policy.reason + return int(policy.chunk_len) + + def _backend_tape_memory_chunk_len( + self, + *, + boundary_seq: torch.Tensor, + packed_state: Any | None, + initial_hidden: torch.Tensor, + initial_recurrent_k: torch.Tensor | None, + initial_recurrent_v: torch.Tensor | None, + ) -> tuple[int, str]: + estimated_step_bytes = self._estimate_backend_tape_step_bytes( + boundary_seq=boundary_seq, + packed_state=packed_state, + initial_hidden=initial_hidden, + initial_recurrent_k=initial_recurrent_k, + initial_recurrent_v=initial_recurrent_v, + ) + decision = tape_memory_chunk_len( + time_steps=int(boundary_seq.shape[1]), + estimated_step_bytes=estimated_step_bytes, + memory=self._cuda_memory_budget(boundary_seq.device), + ) + return int(decision.value), decision.reason + + def _estimate_backend_tape_step_bytes( + self, + *, + boundary_seq: torch.Tensor, + packed_state: Any | None, + initial_hidden: torch.Tensor, + initial_recurrent_k: torch.Tensor | None, + initial_recurrent_v: torch.Tensor | None, + ) -> int: + state_bytes = 0 + if packed_state is not None: + _state_keys, state_tensors = _flatten_backend_packed_state(packed_state) + state_bytes = sum(self._tensor_storage_bytes(tensor) for tensor in state_tensors) + boundary_step_bytes = self._tensor_storage_bytes(boundary_seq[:, 0]) if boundary_seq.shape[1] > 0 else 0 + hidden_bytes = self._tensor_storage_bytes(initial_hidden) + recurrent_k_bytes = self._tensor_storage_bytes(initial_recurrent_k) + recurrent_v_bytes = self._tensor_storage_bytes(initial_recurrent_v) + step_bytes = state_bytes + boundary_step_bytes + hidden_bytes + recurrent_k_bytes + recurrent_v_bytes + return int(math.ceil(float(step_bytes) * 1.15)) + + @staticmethod + def _tensor_storage_bytes(tensor: torch.Tensor | None) -> int: + if tensor is None: + return 0 + return int(tensor.numel()) * int(tensor.element_size()) + + @staticmethod + def _cuda_usable_memory_info(device: torch.device) -> tuple[int, int, int, int]: + free_bytes, total_bytes = torch.cuda.mem_get_info(device) + reserved_bytes = int(torch.cuda.memory_reserved(device)) + allocated_bytes = int(torch.cuda.memory_allocated(device)) + reusable_reserved_bytes = max(0, reserved_bytes - allocated_bytes) + usable_bytes = int(free_bytes) + int(reusable_reserved_bytes) + 
return int(usable_bytes), int(total_bytes), int(free_bytes), int(reusable_reserved_bytes) + + def _cuda_memory_budget(self, device: torch.device) -> CudaMemoryBudget | None: + if device.type != "cuda": + return None + usable_bytes, total_bytes, free_bytes, reusable_reserved_bytes = self._cuda_usable_memory_info(device) + return CudaMemoryBudget( + usable_bytes=int(usable_bytes), + total_bytes=int(total_bytes), + free_bytes=int(free_bytes), + reusable_reserved_bytes=int(reusable_reserved_bytes), + ) + + def _backend_owner_timing_enabled(self, device: torch.device) -> bool: + if device.type != "cuda": + return False + return os.environ.get(_BACKWARD_OWNER_TIMING_ENV, "").lower() in {"1", "true", "yes", "on"} + + def _begin_backend_owner_timing(self, device: torch.device) -> None: + if not self._backend_owner_timing_enabled(device): + self._active_backend_owner_timing = None + return + self._active_backend_owner_timing = _BackendOwnerTimingCollector(device=device, events=[]) + + def _finish_backend_owner_timing(self) -> None: + collector = getattr(self, "_active_backend_owner_timing", None) + self._active_backend_owner_timing = None + if collector is None: + return + timing_summary = collector.summary() + wall_summary = collector.wall_summary() + if not timing_summary and not wall_summary: + return + record = getattr(self, "_last_backend_execution", None) + if record is None: + self._last_backend_owner_timing_ms = timing_summary + self._last_backend_owner_wall_ms = wall_summary + return + self._last_backend_execution = replace( + record, + backward_owner_timing_ms=timing_summary, + backward_owner_wall_ms=wall_summary, + ) + self._last_backend_owner_timing_ms = timing_summary + self._last_backend_owner_wall_ms = wall_summary + + @contextmanager + def _backend_owner_timing(self, name: str) -> Iterator[None]: + collector = getattr(self, "_active_backend_owner_timing", None) + if collector is None: + yield + return + with collector.record(name): + yield + + @staticmethod + def _sender_reverse_table_from_receiver_table(receiver_sender_idx: torch.Tensor, num_senders: int) -> torch.Tensor: + receiver_valid = receiver_sender_idx >= 0 + reverse = torch.full( + (int(num_senders), int(receiver_sender_idx.shape[1])), + -1, + dtype=torch.int32, + device=receiver_sender_idx.device, + ) + receiver_idx, offset_idx = torch.nonzero(receiver_valid, as_tuple=True) + sender_idx = receiver_sender_idx[receiver_idx, offset_idx].to(dtype=torch.long) + if bool((sender_idx < 0).any()): + raise ValueError("receiver_sender_idx must be non-negative on valid local edges") + if bool((reverse[sender_idx, offset_idx] >= 0).any()): + raise ValueError("local sender reverse table expects unique receiver per sender/offset") + reverse[sender_idx, offset_idx] = receiver_idx.to(dtype=torch.int32) + return reverse + + def _run_backend_sender_kv_projection_backward_raw_phase( + self, + *, + role: Literal["input", "recurrent"], + sender_cells: torch.Tensor, + grad_k: torch.Tensor | None, + grad_v: torch.Tensor | None, + sequence_static_tensors: Mapping[str, object], + active_receiver_window: Any | None = None, + boundary_requires_grad: bool = True, + owner: str = "public_projection", + slice_active_receiver_window_static_tensors: bool = True, + ) -> tuple[torch.Tensor | None, Any | None]: + return super()._run_backend_sender_kv_projection_backward_raw_phase( + role=role, + sender_cells=sender_cells, + grad_k=grad_k, + grad_v=grad_v, + sequence_static_tensors=sequence_static_tensors, + active_receiver_window=active_receiver_window, + 
boundary_requires_grad=boundary_requires_grad, + owner=owner, + slice_active_receiver_window_static_tensors=slice_active_receiver_window_static_tensors, + ) + + def _run_backend_sender_kv_projection_step_adjoint_raw_phase( + self, + *, + role: Literal["input", "recurrent"], + sender_cells: torch.Tensor, + grad_k: torch.Tensor | None, + grad_v: torch.Tensor | None, + sequence_static_tensors: Mapping[str, object], + active_receiver_window: Any | None = None, + slice_active_receiver_window_static_tensors: bool = True, + owner: str = "grouped_projection", + ) -> tuple[torch.Tensor | None, Any | None]: + grad_sender, raw_grad = self._run_backend_sender_kv_projection_backward_raw_phase( + role=role, + sender_cells=sender_cells, + grad_k=grad_k, + grad_v=grad_v, + sequence_static_tensors=sequence_static_tensors, + active_receiver_window=active_receiver_window, + slice_active_receiver_window_static_tensors=slice_active_receiver_window_static_tensors, + owner=owner, + ) + if raw_grad is not None: + self._last_active_output_recurrent_projection_adjoint_step_executor = ( + "shared_temporal_grouped_recurrent_projection_adjoint_step_cuda" + if bool(getattr(raw_grad, "grouped", False)) + else "shared_temporal_recurrent_projection_adjoint_step_cuda" + ) + return grad_sender, raw_grad + + def _run_backend_sender_kv_projection_sender_grad_phase( + self, + *, + role: Literal["input", "recurrent"], + grad_k: torch.Tensor | None, + grad_v: torch.Tensor | None, + sequence_static_tensors: Mapping[str, object], + active_receiver_window: Any | None = None, + slice_active_receiver_window_static_tensors: bool = True, + owner: str = "grouped_projection", + ) -> torch.Tensor | None: + grad_output = self._concat_kv_grads( + grad_k, + grad_v, + head_dim=self.head_dim, + value_dim=self.value_dim, + ) + if grad_output is None: + return None + projection_static_tensors = ( + self._cached_receiver_window_static_tensors(sequence_static_tensors, active_receiver_window) + if role == "recurrent" and slice_active_receiver_window_static_tensors + else sequence_static_tensors + ) + direct_key = "input_sender_input_to_kv_weight" if role == "input" else "recurrent_sender_input_to_kv_weight" + grouped_key = "input_group_input_to_kv_weight" if role == "input" else "recurrent_group_input_to_kv_weight" + weight = cast(torch.Tensor | None, projection_static_tensors[grouped_key]) + if weight is None: + weight = cast(torch.Tensor, projection_static_tensors[direct_key]) + sender_cells = grad_output.new_zeros((*tuple(grad_output.shape[:-1]), int(weight.shape[-1]))) + grad_sender, _raw_grad = self._run_backend_sender_kv_projection_backward_raw_phase( + role=role, + sender_cells=sender_cells, + grad_k=grad_k, + grad_v=grad_v, + sequence_static_tensors=projection_static_tensors, + active_receiver_window=None, + boundary_requires_grad=True, + owner=owner, + ) + return grad_sender + + def _run_backend_sender_kv_projection_weight_grad_raw_phase( + self, + *, + role: Literal["input", "recurrent"], + sender_cells: torch.Tensor, + grad_k: torch.Tensor | None, + grad_v: torch.Tensor | None, + sequence_static_tensors: Mapping[str, object], + active_receiver_window: Any | None = None, + slice_active_receiver_window_static_tensors: bool = True, + owner: str = "grouped_projection", + ) -> Any | None: + _grad_sender, raw_grad = self._run_backend_sender_kv_projection_backward_raw_phase( + role=role, + sender_cells=sender_cells, + grad_k=grad_k, + grad_v=grad_v, + sequence_static_tensors=sequence_static_tensors, + 
active_receiver_window=active_receiver_window, + boundary_requires_grad=False, + owner=owner, + slice_active_receiver_window_static_tensors=slice_active_receiver_window_static_tensors, + ) + return raw_grad + + def _backend_backward_batch_tile_len( + self, + *, + boundary_seq: torch.Tensor, + packed_state: Any | None, + initial_hidden: torch.Tensor, + initial_recurrent_k: torch.Tensor | None, + initial_recurrent_v: torch.Tensor | None, + output_boundary: Literal["sequence", "terminal"] = "sequence", + sequence_surface_per_batch_bytes: int | None = None, + ) -> int: + batch_size = int(boundary_seq.shape[0]) + if batch_size <= 1 or boundary_seq.device.type != "cuda": + self._last_backend_backward_batch_tile_len = batch_size + self._last_backend_backward_batch_tile_reason = "batch_tiling=disabled" + return batch_size + estimated_step_bytes = self._estimate_backend_tape_step_bytes( + boundary_seq=boundary_seq, + packed_state=packed_state, + initial_hidden=initial_hidden, + initial_recurrent_k=initial_recurrent_k, + initial_recurrent_v=initial_recurrent_v, + ) + decision = backward_batch_tile_policy( + batch_size=batch_size, + time_steps=int(boundary_seq.shape[1]), + estimated_step_bytes=estimated_step_bytes, + memory=self._cuda_memory_budget(boundary_seq.device), + output_boundary=output_boundary, + sequence_surface_per_batch_bytes=sequence_surface_per_batch_bytes, + ) + self._last_backend_backward_batch_tile_len = int(decision.value) + self._last_backend_backward_batch_tile_reason = decision.reason + return int(decision.value) + + def _backend_forward_batch_tile_len( + self, + *, + boundary_seq: torch.Tensor, + packed_state: Any | None, + initial_hidden: torch.Tensor, + initial_recurrent_k: torch.Tensor | None, + initial_recurrent_v: torch.Tensor | None, + ) -> int: + batch_size = int(boundary_seq.shape[0]) + if batch_size <= 1 or boundary_seq.device.type != "cuda": + self._last_backend_forward_batch_tile_len = batch_size + self._last_backend_forward_batch_tile_reason = "batch_tiling=disabled" + return batch_size + estimated_step_bytes = self._estimate_backend_tape_step_bytes( + boundary_seq=boundary_seq, + packed_state=packed_state, + initial_hidden=initial_hidden, + initial_recurrent_k=initial_recurrent_k, + initial_recurrent_v=initial_recurrent_v, + ) + decision = forward_batch_tile_policy( + batch_size=batch_size, + estimated_step_bytes=estimated_step_bytes, + memory=self._cuda_memory_budget(boundary_seq.device), + ) + self._last_backend_forward_batch_tile_len = int(decision.value) + self._last_backend_forward_batch_tile_reason = decision.reason + return int(decision.value) + + def _backend_forward_batch_tile_len_for_layout( + self, + *, + population_name: str, + batch_size: int, + time_steps: int, + boundary_seq: torch.Tensor, + materialize_final_state: bool, + training: bool = False, + fresh_state_virtualized: bool = False, + fresh_output_dependency_receiver_count: int | None = None, + projected_message_dim: int | None = None, + raw_public_dim: int | None = None, + output_boundary: Literal["sequence", "terminal"] = "sequence", + readout_output_boundary: Literal["cells", "pooled"] = "cells", + ) -> int: + if output_boundary not in {"sequence", "terminal"}: + raise ValueError(f"Unsupported Fabric output boundary {output_boundary!r}") + if readout_output_boundary not in {"cells", "pooled"}: + raise ValueError(f"Unsupported Fabric readout output boundary {readout_output_boundary!r}") + batch_size = int(batch_size) + if batch_size <= 1 or boundary_seq.device.type != "cuda": + 
self._last_backend_forward_batch_tile_len = batch_size + self._last_backend_forward_batch_tile_reason = "batch_tiling=disabled" + return batch_size + recurrent_cells = int(self._population_num_cells(population_name)) + state_leaf_count = len(self._cell_spec_for_population(population_name).state_schema.keys) + input_cells = int(boundary_seq.shape[2]) if boundary_seq.dim() >= 4 else int(self._num_input_cells) + output_cells = int(self._num_output_cells) + dtype_bytes = int(torch.empty((), dtype=boundary_seq.dtype).element_size()) + projected_message_dim = int(projected_message_dim if projected_message_dim is not None else self.hidden_size) + raw_public_dim = int(raw_public_dim if raw_public_dim is not None else self.hidden_size) + active_receiver_count = fresh_output_dependency_receiver_count + if active_receiver_count is None: + active_receiver_count = self._fresh_output_dependency_receiver_count( + population_name=population_name, + time_steps=int(time_steps), + fresh_state_virtualized=bool(fresh_state_virtualized), + ) + core_state_names = ( + self._transition_core_state_names_for_population(population_name) + if active_receiver_count is not None + else () + ) + decision = forward_layout_batch_tile_policy( + inputs=LayoutBatchTileInputs( + population_name=population_name, + batch_size=batch_size, + time_steps=int(time_steps), + dtype_bytes=dtype_bytes, + recurrent_cells=recurrent_cells, + state_leaf_count=state_leaf_count, + input_cells=input_cells, + output_cells=output_cells, + readout_slots=int(self.readout_slots), + sender_cells=int(self.sender_cell_idx.numel()), + hidden_size=int(self.hidden_size), + head_dim=int(self.head_dim), + value_dim=int(self.value_dim), + projected_message_dim=projected_message_dim, + raw_public_dim=raw_public_dim, + materialize_final_state=materialize_final_state, + training=training, + fresh_state_virtualized=fresh_state_virtualized, + fresh_output_dependency_receiver_count=active_receiver_count, + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + core_state_leaf_count=len(core_state_names) if core_state_names else None, + ), + memory=self._cuda_memory_budget(boundary_seq.device), + ) + self._last_backend_forward_batch_tile_len = int(decision.value) + self._last_backend_forward_batch_tile_reason = decision.reason + return int(decision.value) + + def _flat_bucket_sequence_readout_batch_tile_len( + self, + *, + batch_size: int, + time_steps: int, + dtype_bytes: int, + input_cells: int, + output_cells: int, + readout_slots: int, + projected_output_dim: int, + materialize_final_state: bool, + output_boundary: Literal["sequence", "terminal"], + readout_output_boundary: Literal["cells", "pooled"], + training: bool, + memory: CudaMemoryBudget | None, + ) -> PolicyDecision: + active_populations = [ + name for name in self._population_names if int(self._population_recurrent_indices(name).numel()) > 0 + ] + transition_workspace_elements = 0 + for population_name in active_populations: + receiver_count = int(self._population_recurrent_indices(population_name).numel()) + core_state_names = self._transition_core_state_names_for_population(population_name) + if core_state_names: + state_like_surfaces = len(core_state_names) + else: + state_like_surfaces = len(self._cell_spec_for_population(population_name).state_schema.keys) + transition_workspace_elements += receiver_count * max(2, int(state_like_surfaces)) * int(self.hidden_size) + decision = flat_bucket_sequence_readout_batch_tile_policy( + batch_size=int(batch_size), + 
time_steps=int(time_steps), + dtype_bytes=int(dtype_bytes), + total_cells=int(self.coords.shape[0]), + recurrent_cells=int(self._num_recurrent_cells), + sender_cells=int(self.sender_cell_idx.numel()), + input_cells=int(input_cells), + output_cells=int(output_cells), + readout_slots=int(readout_slots), + hidden_size=int(self.hidden_size), + head_dim=int(self.head_dim), + value_dim=int(self.value_dim), + transition_workspace_elements_per_batch=int(transition_workspace_elements), + projected_output_dim=int(projected_output_dim), + materialize_final_state=bool(materialize_final_state), + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + training=bool(training), + memory=memory, + ) + self._last_backend_forward_batch_tile_len = int(decision.value) + self._last_backend_forward_batch_tile_reason = decision.reason + return decision + + def _fresh_output_dependency_receiver_count( + self, + *, + population_name: str, + time_steps: int, + fresh_state_virtualized: bool, + ) -> int | None: + del time_steps, fresh_state_virtualized + if not bool(self._local_message_step_enabled) or self._has_edge_delay: + return None + if bool(getattr(self, "_uses_sparse_message_backend", False)): + return None + population_spec = self._backend_population_specs.get(population_name) + if population_spec is None or not _transition_supports_receiver_local_dependency_window( + population_spec.transition_ir + ): + return None + recurrent_cells = int(self._population_num_cells(population_name)) + if not bool(getattr(self, "_output_local_recurrent_window_contiguous", False)): + return None + window = self._closed_fixed_output_dependency_receiver_window( + reason="streaming_output_active_region:fresh_output_dependency_receiver_count" + ) + if window is None or not window.active: + return None + if int(window.full_count) != recurrent_cells: + return None + return int(window.count) + + def _backend_graph_capture_key( + self, + *, + surface: Any, + plan: PlannedFabricExecution, + shape_signature: tuple[tuple[int, ...], ...], + ) -> GraphCaptureCacheKey: + return GraphCaptureCacheKey( + surface=surface.key, + execution_families=tuple(bucket_plan.execution_family.value for bucket_plan in plan.bucket_plans), + math_backends=tuple(bucket_plan.math_backend.value for bucket_plan in plan.bucket_plans), + shape_bin=plan.shape_bin, + shape_signature=shape_signature, + dtype=plan.dtype, + device_caps=plan.device_caps_key, + tape_policy_bin=plan.tape_policy_bin, + ) + + def _build_backend_graph_inputs( + self, + *, + boundary_seq: torch.Tensor | None, + packed_state: Any, + initial_hidden: torch.Tensor, + population_resets: torch.Tensor, + initial_recurrent_k: torch.Tensor | None, + initial_recurrent_v: torch.Tensor | None, + packed_state_is_fresh: bool = False, + projected_boundary_source_seq: torch.Tensor | None = None, + projected_boundary_weight: torch.Tensor | None = None, + projected_boundary_bias: torch.Tensor | None = None, + ) -> tuple[_BackendGraphInputLayout, dict[str, torch.Tensor]]: + projected_boundary_active = projected_boundary_source_seq is not None or projected_boundary_weight is not None + if projected_boundary_active and (projected_boundary_source_seq is None or projected_boundary_weight is None): + raise RuntimeError("Projected Fabric boundary input requires both source sequence and projection weight") + if packed_state is None: + if not packed_state_is_fresh: + raise RuntimeError("Fabric backend graph inputs require a packed state unless the state is fresh") + packed_state_keys, 
packed_state_inputs = (), () + else: + packed_state_keys, packed_state_inputs = _flatten_backend_packed_state(packed_state) + graph_inputs: dict[str, torch.Tensor] = {} + if boundary_seq is not None: + graph_inputs["boundary_seq"] = boundary_seq + if projected_boundary_active: + assert projected_boundary_source_seq is not None and projected_boundary_weight is not None + graph_inputs["projected_boundary_source_seq"] = projected_boundary_source_seq + graph_inputs["projected_boundary_weight"] = projected_boundary_weight + if projected_boundary_bias is not None: + graph_inputs["projected_boundary_bias"] = projected_boundary_bias + graph_inputs["initial_hidden"] = initial_hidden + graph_inputs["population_resets"] = population_resets + packed_state_input_names: list[str] = [] + if not packed_state_is_fresh: + for index, tensor in enumerate(packed_state_inputs): + name = f"packed_state_{index}" + graph_inputs[name] = tensor + packed_state_input_names.append(name) + tensor_ref = ( + boundary_seq + if boundary_seq is not None + else projected_boundary_source_seq + if projected_boundary_source_seq is not None + else initial_hidden + ) + graph_inputs["initial_recurrent_k"] = ( + tensor_ref.new_empty(0) if initial_recurrent_k is None else initial_recurrent_k + ) + graph_inputs["initial_recurrent_v"] = ( + tensor_ref.new_empty(0) if initial_recurrent_v is None else initial_recurrent_v + ) + return ( + _BackendGraphInputLayout( + input_names=tuple(graph_inputs.keys()), + packed_state_keys=packed_state_keys, + packed_state_input_names=tuple(packed_state_input_names), + packed_state_shapes=tuple(tuple(tensor.shape) for tensor in packed_state_inputs), + packed_state_is_fresh=packed_state_is_fresh, + ), + graph_inputs, + ) + + def _unpack_backend_graph_inputs( + self, + *, + input_layout: _BackendGraphInputLayout, + graph_inputs: dict[str, torch.Tensor], + ) -> tuple[Any, torch.Tensor | None, torch.Tensor | None]: + if input_layout.packed_state_is_fresh and not input_layout.packed_state_shapes: + packed_state = None + elif input_layout.packed_state_is_fresh: + packed_state = _unflatten_backend_packed_state( + input_layout.packed_state_keys, + tuple(graph_inputs["initial_hidden"].new_zeros(shape) for shape in input_layout.packed_state_shapes), + ) + else: + packed_state = _unflatten_backend_packed_state( + input_layout.packed_state_keys, + tuple(graph_inputs[name] for name in input_layout.packed_state_input_names), + ) + recurrent_k = graph_inputs["initial_recurrent_k"] + recurrent_v = graph_inputs["initial_recurrent_v"] + return ( + packed_state, + None if recurrent_k.numel() == 0 else recurrent_k, + None if recurrent_v.numel() == 0 else recurrent_v, + ) + + def _copy_graph_workspace_inputs( + self, + workspace: GraphCaptureWorkspace, + *, + graph_inputs: dict[str, torch.Tensor], + ) -> None: + for name, tensor in graph_inputs.items(): + workspace.tensors[name].copy_(tensor) + + @staticmethod + def _graph_shape_signature( + *, + graph_inputs: dict[str, torch.Tensor], + ) -> tuple[tuple[int, ...], ...]: + return tuple(tuple(tensor.shape) for tensor in graph_inputs.values()) + + def _get_backend_device_caps(self, device: torch.device) -> DeviceCaps: + key = (device.type, -1 if device.index is None else int(device.index)) + cached = self._backend_device_caps_cache.get(key) + if cached is None: + cached = detect_device_caps(device) + self._backend_device_caps_cache[key] = cached + return cached + + def _apply(self, fn): + self._clear_execution_caches() + self._backend_device_caps_cache.clear() + return 
super()._apply(fn) + + def train(self, mode: bool = True): + self._clear_execution_caches() + return super().train(mode) + + def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False): + self._clear_execution_caches() + return super().load_state_dict(state_dict, strict=strict, assign=assign) + + def _record_backend_execution( + self, + *, + surface: Any, + plan: PlannedFabricExecution, + backward_plan: PlannedFabricBackwardExecution | None = None, + batch_size: int, + time_steps: int, + inner_steps: int, + training: bool, + graph_capture_replayed: bool = False, + graph_capture_cache_hit: bool = False, + ) -> None: + requested_receiver_tiles = tuple(bucket_plan.receiver_tile for bucket_plan in plan.bucket_plans) + requested_batch_tiles = tuple(bucket_plan.batch_tile for bucket_plan in plan.bucket_plans) + requested_edge_tiles = tuple(bucket_plan.edge_tile for bucket_plan in plan.bucket_plans) + requested_hidden_chunks = tuple(bucket_plan.hidden_chunk for bucket_plan in plan.bucket_plans) + requested_state_receiver_tiles = tuple(bucket_plan.state_receiver_tile for bucket_plan in plan.bucket_plans) + requested_state_batch_tiles = tuple(bucket_plan.state_batch_tile for bucket_plan in plan.bucket_plans) + requested_state_hidden_chunks = tuple(bucket_plan.state_hidden_chunk for bucket_plan in plan.bucket_plans) + requested_state_static_stage_modes = tuple( + bucket_plan.state_static_stage_mode for bucket_plan in plan.bucket_plans + ) + requested_emit_receiver_tiles = tuple(bucket_plan.emit_receiver_tile for bucket_plan in plan.bucket_plans) + requested_emit_batch_tiles = tuple(bucket_plan.emit_batch_tile for bucket_plan in plan.bucket_plans) + requested_emit_hidden_chunks = tuple(bucket_plan.emit_hidden_chunk for bucket_plan in plan.bucket_plans) + requested_emit_static_stage_modes = tuple( + bucket_plan.emit_static_stage_mode for bucket_plan in plan.bucket_plans + ) + requested_public_receiver_tiles = tuple(bucket_plan.public_receiver_tile for bucket_plan in plan.bucket_plans) + requested_public_batch_tiles = tuple(bucket_plan.public_batch_tile for bucket_plan in plan.bucket_plans) + requested_replication_factors = tuple(bucket_plan.replication_factor for bucket_plan in plan.bucket_plans) + requested_cell_static_stage_modes = tuple( + bucket_plan.cell_static_stage_mode for bucket_plan in plan.bucket_plans + ) + requested_readout_modes = tuple(bucket_plan.readout_mode for bucket_plan in plan.bucket_plans) + launch_metadata = self._last_backend_launch_metadata if self._active_backend_name == "cuda" else None + message_rule = self.backend_ir.message_rule + + def actual_tuple(key: str, requested: tuple[Any, ...]) -> tuple[Any, ...]: + if launch_metadata is None: + return requested + actual = launch_metadata[key] + if len(actual) == 1 and len(requested) > 1: + return actual * len(requested) + return actual + + def actual_tuple_optional(key: str, requested: tuple[Any, ...] 
= ()) -> tuple[Any, ...]: + if launch_metadata is None or key not in launch_metadata: + return requested + actual = launch_metadata[key] + if len(actual) == 1 and len(requested) > 1: + return actual * len(requested) + return actual + + actual_receiver_tiles = actual_tuple("receiver_tiles", requested_receiver_tiles) + actual_batch_tiles = actual_tuple("batch_tiles", requested_batch_tiles) + actual_edge_tiles = actual_tuple("edge_tiles", requested_edge_tiles) + actual_hidden_chunks = actual_tuple("hidden_chunks", requested_hidden_chunks) + actual_state_receiver_tiles = actual_tuple("state_receiver_tiles", requested_state_receiver_tiles) + actual_state_batch_tiles = actual_tuple("state_batch_tiles", requested_state_batch_tiles) + actual_state_hidden_chunks = actual_tuple("state_hidden_chunks", requested_state_hidden_chunks) + actual_state_static_stage_modes = actual_tuple( + "state_static_stage_modes", + requested_state_static_stage_modes, + ) + actual_emit_receiver_tiles = actual_tuple("emit_receiver_tiles", requested_emit_receiver_tiles) + actual_emit_batch_tiles = actual_tuple("emit_batch_tiles", requested_emit_batch_tiles) + actual_emit_hidden_chunks = actual_tuple("emit_hidden_chunks", requested_emit_hidden_chunks) + actual_emit_static_stage_modes = actual_tuple( + "emit_static_stage_modes", + requested_emit_static_stage_modes, + ) + actual_public_receiver_tiles = actual_tuple("public_receiver_tiles", requested_public_receiver_tiles) + actual_public_batch_tiles = actual_tuple("public_batch_tiles", requested_public_batch_tiles) + actual_replication_factors = actual_tuple("replication_factors", requested_replication_factors) + actual_cell_static_stage_modes = actual_tuple( + "cell_static_stage_modes", + requested_cell_static_stage_modes, + ) + actual_readout_modes = actual_tuple("readout_modes", requested_readout_modes) + actual_public_projection_hidden_backends = actual_tuple("public_projection_hidden_backends", ()) + actual_public_projection_kv_backends = actual_tuple("public_projection_kv_backends", ()) + actual_readout_projection_backends = actual_tuple("readout_projection_backends", ()) + public_hidden_projection_gate = receiver_major_projection_backward_gate( + batch_size=batch_size, + receivers=int(self._num_recurrent_cells), + input_dim=int(self.hidden_size), + output_dim=int(self.hidden_size), + biased=True, + ) + public_kv_projection_gate = receiver_major_projection_backward_gate( + batch_size=batch_size, + receivers=int(self._num_recurrent_cells), + input_dim=int(self.hidden_size), + output_dim=int(self.head_dim + self.value_dim), + biased=True, + ) + readout_receiver_major_projection_gate = receiver_major_projection_backward_gate( + batch_size=batch_size, + receivers=int(self._num_output_cells), + input_dim=int(self.config.d_msg), + output_dim=int(self.hidden_size), + biased=True, + ) + + def coalesced_receiver_major_projection_gate(phase: str) -> ReceiverMajorProjectionGate: + if phase.startswith("readout_projection"): + return readout_receiver_major_projection_gate + candidates = (public_hidden_projection_gate, public_kv_projection_gate) + for candidate in candidates: + if candidate.enabled and candidate.mode == "receiver_major_projection_small_batch_cuda": + return candidate + for candidate in candidates: + if candidate.enabled: + return candidate + return public_hidden_projection_gate + + backward_affine_bucket_signatures: list[str] = [] + backward_affine_forward_backends: list[str] = [] + backward_affine_input_grad_backends: list[str] = [] + 
backward_affine_weight_grad_backends: list[str] = [] + backward_affine_bias_grad_backends: list[str] = [] + backward_affine_demotion_reasons: list[str] = [] + backward_affine_execution_modes: list[str] = [] + + def add_backward_affine_bucket(phase: str, signature: str, backend: object) -> None: + forward_backend = str(backend) + if forward_backend in {"", "none", "unrun", "skip"}: + return + if forward_backend == "fused_into_tiny_message": + return + dense_backends = {"large_gemm", "batched_gemm", "grouped_gemm"} + copy_backends = {"copy", "copy_or_pad", "split_last_dim", "copy_fused"} + if forward_backend in dense_backends: + input_grad_backend = forward_backend + weight_grad_backend = forward_backend + bias_grad_backend = "dense_reduction" + demotion_reason = "none" + elif forward_backend == "receiver_affine_superop": + input_grad_backend = "receiver_affine_superop_backward_cuda" + weight_grad_backend = "receiver_affine_superop_backward_cuda" + bias_grad_backend = "receiver_affine_superop_backward_cuda" + demotion_reason = "none" + elif forward_backend in { + "direct_biased_receiver_affine", + "direct_biased_receiver_affine_split_outputs", + }: + receiver_major_gate = coalesced_receiver_major_projection_gate(phase) + if receiver_major_gate.enabled: + input_grad_backend = "receiver_major_projection_backward" + weight_grad_backend = "receiver_major_projection_backward" + bias_grad_backend = "receiver_major_projection_backward" + demotion_reason = "none" + else: + input_grad_backend = "explicit_demoted" + weight_grad_backend = "explicit_demoted" + bias_grad_backend = "explicit_demoted" + demotion_reason = receiver_major_gate.demotion_reason + elif forward_backend == "grouped_projection_forward": + input_grad_backend = "grouped_projection_backward_input" + weight_grad_backend = "grouped_projection_backward_weight" + bias_grad_backend = "not_applicable" + demotion_reason = "none" + elif forward_backend in copy_backends: + input_grad_backend = "copy" + weight_grad_backend = "not_applicable" + bias_grad_backend = "not_applicable" + demotion_reason = "none" + else: + input_grad_backend = "explicit_demoted" + weight_grad_backend = "explicit_demoted" + bias_grad_backend = "explicit_demoted" + demotion_reason = f"unsupported_forward_backend:{forward_backend}" + backward_affine_bucket_signatures.append(f"{phase}:{signature}") + backward_affine_forward_backends.append(forward_backend) + backward_affine_input_grad_backends.append(input_grad_backend) + backward_affine_weight_grad_backends.append(weight_grad_backend) + backward_affine_bias_grad_backends.append(bias_grad_backend) + backward_affine_demotion_reasons.append(demotion_reason) + backward_affine_execution_modes.append("fabric_cuda_nn_backward_bucket") + + if backward_plan is not None: + for backend in actual_tuple("input_projection_backends", ()): + add_backward_affine_bucket("input_projection", f"backend={backend}", backend) + for signature, backend in zip( + actual_tuple("state_affine_bucket_signatures", ()), + actual_tuple("state_affine_backends", ()), + strict=False, + ): + add_backward_affine_bucket("state_affine", str(signature), backend) + for projection_name, backends in ( + ("public_projection_hidden", actual_public_projection_hidden_backends), + ("public_projection_kv", actual_public_projection_kv_backends), + ("readout_projection", actual_readout_projection_backends), + ): + for backend in backends: + add_backward_affine_bucket(projection_name, f"backend={backend}", backend) + + backward_physical_op_kinds: tuple[str, ...] 
= () + backward_physical_op_executors: tuple[str, ...] = () + backward_physical_op_demotions: tuple[str, ...] = () + backward_boundary_contracts: tuple[str, ...] = () + backward_layout_mode: tuple[str, ...] = () + backward_workspace_aliases: tuple[str, ...] = () + backward_workspace_peak_bytes: tuple[str, ...] = () + backward_tape_mode: tuple[str, ...] = () + backward_recompute_mode: tuple[str, ...] = () + backward_launch_counts: tuple[str, ...] = () + backward_saved_launch_counts: tuple[str, ...] = () + backward_residual_glue_demotions: tuple[str, ...] = () + if backward_plan is not None: + physical_backward_plan = backward_plan.physical_plan + backward_physical_op_kinds = physical_backward_plan.op_kinds + backward_physical_op_executors = physical_backward_plan.executors + backward_physical_op_demotions = physical_backward_plan.demotions + backward_boundary_contracts = physical_backward_plan.boundary_contracts + backward_layout_mode = actual_tuple_optional("layout_mode") or physical_backward_plan.layout_modes + backward_workspace_aliases = physical_backward_plan.workspace_aliases + backward_workspace_peak_bytes = physical_backward_plan.workspace_peak_bytes + backward_tape_mode = (physical_backward_plan.tape_mode,) + backward_recompute_mode = (physical_backward_plan.recompute_mode,) + backward_launch_counts = ( + f"receiver_buckets:{len(backward_plan.receiver_bucket_plans)}", + f"sender_buckets:{len(backward_plan.sender_bucket_plans)}", + f"affine_buckets:{len(backward_affine_bucket_signatures)}", + *physical_backward_plan.launch_counts, + ) + saved_launch_counts = list(physical_backward_plan.saved_launch_counts) + tape_chunk_len = getattr(self, "_last_backend_tape_chunk_len", None) + tape_chunk_reason = getattr(self, "_last_backend_tape_chunk_reason", None) + tape_artifact_mode = getattr(self, "_last_backend_tape_artifact_mode", None) + backward_batch_tile_len = getattr(self, "_last_backend_backward_batch_tile_len", None) + backward_batch_tile_reason = getattr(self, "_last_backend_backward_batch_tile_reason", None) + if training and tape_chunk_len is not None: + saved_launch_counts.append(f"training_tape_window:t={tape_chunk_len}") + if tape_artifact_mode: + saved_launch_counts.append(f"training_tape_artifacts:{tape_artifact_mode}") + recompute_artifact_window_len = getattr( + self, + "_last_backend_recompute_artifact_window_len", + None, + ) + if recompute_artifact_window_len is not None: + saved_launch_counts.append(f"training_recompute_artifact_window:t={recompute_artifact_window_len}") + recompute_artifact_window_reason = getattr( + self, + "_last_backend_recompute_artifact_window_reason", + None, + ) + if recompute_artifact_window_reason: + saved_launch_counts.append( + f"training_recompute_artifact_window_reason:{recompute_artifact_window_reason}" + ) + recompute_checkpoint_stride = getattr( + self, + "_last_backend_recompute_checkpoint_stride", + None, + ) + if recompute_checkpoint_stride is not None: + saved_launch_counts.append(f"training_recompute_checkpoint_stride:t={recompute_checkpoint_stride}") + recompute_checkpoint_count = getattr( + self, + "_last_backend_recompute_checkpoint_count", + None, + ) + if recompute_checkpoint_count is not None: + saved_launch_counts.append(f"training_recompute_checkpoint_count:n={recompute_checkpoint_count}") + recompute_checkpoint_reason = getattr( + self, + "_last_backend_recompute_checkpoint_reason", + None, + ) + if recompute_checkpoint_reason: + 
saved_launch_counts.append(f"training_recompute_checkpoint_reason:{recompute_checkpoint_reason}") + recompute_predecessor_cache_mode = getattr( + self, + "_last_backend_recompute_predecessor_cache_mode", + None, + ) + if recompute_predecessor_cache_mode: + saved_launch_counts.append( + f"training_recompute_predecessor_cache:{recompute_predecessor_cache_mode}" + ) + if tape_chunk_reason: + saved_launch_counts.append(f"training_tape_window_reason:{tape_chunk_reason}") + if training and backward_batch_tile_len is not None: + saved_launch_counts.append(f"training_backward_batch_tile:b={backward_batch_tile_len}") + if backward_batch_tile_reason: + saved_launch_counts.append(f"training_backward_batch_tile_reason:{backward_batch_tile_reason}") + if training and tape_chunk_len is not None and time_steps <= tape_chunk_len: + saved_launch_counts.append("training_static_materialization:single_chunk_graph_reuse") + if training and getattr(self, "_last_training_static_prepack_mode", None) == "views": + saved_launch_counts.append("training_static_prepack:receiver_major_views") + if training and getattr(self, "_last_training_static_tape_mode", None) == "detached_shared_values": + saved_launch_counts.append("training_static_tape:detached_shared_values") + if training and getattr(self, "_last_backward_projection_mode", None) == "factorized_recurrent_input": + saved_launch_counts.append("training_static_projection:factorized_receiver_input") + backward_saved_launch_counts = tuple(saved_launch_counts) + backward_residual_glue_demotions = physical_backward_plan.residual_glue_demotions + self._last_backend_execution = BackendExecutionRecord( + backend_name=self._active_backend_name, + surface_key=surface.key, + cell_type=surface.cell_type, + regime=surface.regime, + training=training, + batch_size=batch_size, + time_steps=time_steps, + inner_steps=inner_steps, + bucket_ids=tuple(bucket_plan.bucket_id for bucket_plan in plan.bucket_plans), + execution_families=tuple(bucket_plan.execution_family.value for bucket_plan in plan.bucket_plans), + math_backends=tuple(bucket_plan.math_backend.value for bucket_plan in plan.bucket_plans), + requested_launch_receiver_tiles=requested_receiver_tiles, + requested_launch_batch_tiles=requested_batch_tiles, + requested_launch_edge_tiles=requested_edge_tiles, + requested_launch_hidden_chunks=requested_hidden_chunks, + requested_launch_state_receiver_tiles=requested_state_receiver_tiles, + requested_launch_state_batch_tiles=requested_state_batch_tiles, + requested_launch_state_hidden_chunks=requested_state_hidden_chunks, + requested_launch_state_static_stage_modes=requested_state_static_stage_modes, + requested_launch_emit_receiver_tiles=requested_emit_receiver_tiles, + requested_launch_emit_batch_tiles=requested_emit_batch_tiles, + requested_launch_emit_hidden_chunks=requested_emit_hidden_chunks, + requested_launch_emit_static_stage_modes=requested_emit_static_stage_modes, + requested_launch_public_receiver_tiles=requested_public_receiver_tiles, + requested_launch_public_batch_tiles=requested_public_batch_tiles, + requested_launch_replication_factors=requested_replication_factors, + requested_launch_cell_static_stage_modes=requested_cell_static_stage_modes, + requested_launch_readout_modes=requested_readout_modes, + actual_launch_receiver_tiles=actual_receiver_tiles, + actual_launch_batch_tiles=actual_batch_tiles, + actual_launch_edge_tiles=actual_edge_tiles, + actual_launch_hidden_chunks=actual_hidden_chunks, + actual_launch_state_receiver_tiles=actual_state_receiver_tiles, + 
actual_launch_state_batch_tiles=actual_state_batch_tiles, + actual_launch_state_hidden_chunks=actual_state_hidden_chunks, + actual_launch_state_static_stage_modes=actual_state_static_stage_modes, + actual_launch_emit_receiver_tiles=actual_emit_receiver_tiles, + actual_launch_emit_batch_tiles=actual_emit_batch_tiles, + actual_launch_emit_hidden_chunks=actual_emit_hidden_chunks, + actual_launch_emit_static_stage_modes=actual_emit_static_stage_modes, + actual_launch_public_receiver_tiles=actual_public_receiver_tiles, + actual_launch_public_batch_tiles=actual_public_batch_tiles, + actual_launch_replication_factors=actual_replication_factors, + actual_launch_cell_static_stage_modes=actual_cell_static_stage_modes, + actual_launch_readout_modes=actual_readout_modes, + launch_receiver_tiles=actual_receiver_tiles, + launch_batch_tiles=actual_batch_tiles, + launch_edge_tiles=actual_edge_tiles, + launch_hidden_chunks=actual_hidden_chunks, + launch_state_receiver_tiles=actual_state_receiver_tiles, + launch_state_batch_tiles=actual_state_batch_tiles, + launch_state_hidden_chunks=actual_state_hidden_chunks, + launch_state_static_stage_modes=actual_state_static_stage_modes, + launch_emit_receiver_tiles=actual_emit_receiver_tiles, + launch_emit_batch_tiles=actual_emit_batch_tiles, + launch_emit_hidden_chunks=actual_emit_hidden_chunks, + launch_emit_static_stage_modes=actual_emit_static_stage_modes, + launch_public_receiver_tiles=actual_public_receiver_tiles, + launch_public_batch_tiles=actual_public_batch_tiles, + launch_replication_factors=actual_replication_factors, + launch_cell_static_stage_modes=actual_cell_static_stage_modes, + launch_readout_modes=actual_readout_modes, + launch_temporal_executions=actual_tuple("temporal_executions", ()), + launch_scan_implementations=actual_tuple("scan_implementations", ()), + launch_phases=actual_tuple("phases", ()), + active_receiver_window_modes=actual_tuple("active_receiver_window_modes", ()), + active_receiver_window_offsets=actual_tuple("active_receiver_window_offsets", ()), + active_receiver_window_counts=actual_tuple("active_receiver_window_counts", ()), + input_projection_backends=actual_tuple("input_projection_backends", ()), + input_projection_notes=actual_tuple("input_projection_notes", ()), + message_projection_boundaries=actual_tuple("message_projection_boundaries", ()), + message_projection_bucket_kinds=actual_tuple("message_projection_bucket_kinds", ()), + message_bucket_count=actual_tuple("message_bucket_count", ()), + message_regular_local_bucket_count=actual_tuple("message_regular_local_bucket_count", ()), + message_sparse_bucket_count=actual_tuple("message_sparse_bucket_count", ()), + message_batched_backend_count=actual_tuple("message_batched_backend_count", ()), + message_grouped_backend_count=actual_tuple("message_grouped_backend_count", ()), + message_reset_aware_bucket_count=actual_tuple("message_reset_aware_bucket_count", ()), + message_degree_uniform_bucket_count=actual_tuple("message_degree_uniform_bucket_count", ()), + message_ragged_grouped_bucket_count=actual_tuple("message_ragged_grouped_bucket_count", ()), + message_demoted_bucket_count=actual_tuple("message_demoted_bucket_count", ()), + message_bucket_signatures=actual_tuple("message_bucket_signatures", ()), + message_bucket_kinds=actual_tuple("message_bucket_kinds", ()), + message_topology_kinds=actual_tuple("message_topology_kinds", ()), + message_spatial_ownership=actual_tuple("message_spatial_ownership", ()), + 
message_degree_bucket_lists=actual_tuple("message_degree_bucket_lists", ()), + message_logit_backends=actual_tuple("message_logit_backends", ()), + message_softmax_backends=actual_tuple("message_softmax_backends", ()), + message_weighted_value_backends=actual_tuple("message_weighted_value_backends", ()), + message_physical_mode=actual_tuple("message_physical_mode", ()), + message_execution_mode=actual_tuple("message_execution_mode", ()), + message_output_boundary=actual_tuple("message_output_boundary", ()), + message_degree=actual_tuple("message_degree", ()), + message_k=actual_tuple("message_k", ()), + message_v=actual_tuple("message_v", ()), + message_projected_n=actual_tuple("message_projected_n", ()), + message_reset_policies=actual_tuple("message_reset_policies", ()), + message_reset_scopes=actual_tuple("message_reset_scopes", ()), + message_use_delay=actual_tuple("message_use_delay", ()), + message_distance_penalty_kinds=actual_tuple("message_distance_penalty_kinds", ()), + message_epilogue_kinds=actual_tuple("message_epilogue_kinds", ()), + message_rule_names=actual_tuple_optional("message_rule_names", (message_rule.name,)), + message_rule_lowering_kinds=actual_tuple_optional( + "message_rule_lowering_kinds", + (message_rule.lowering_kind,), + ), + message_rule_expression_signatures=actual_tuple_optional( + "message_rule_expression_signatures", + (message_rule.expression_signature,), + ), + message_rule_source_signatures=actual_tuple_optional( + "message_rule_source_signatures", + (";".join(message_rule.source_signature),), + ), + message_rule_parameter_sharing_signatures=actual_tuple_optional( + "message_rule_parameter_sharing_signatures", + (";".join(message_rule.parameter_sharing_signature),), + ), + message_rule_output_boundaries=actual_tuple_optional( + "message_rule_output_boundaries", + (message_rule.output_boundary,), + ), + message_packed_source_reuse_count=actual_tuple("message_packed_source_reuse_count", ()), + message_demotions=actual_tuple("message_demotions", ()), + message_workspace_buffers=actual_tuple("message_workspace_buffers", ()), + message_workspace_buffer_bytes=actual_tuple("message_workspace_buffer_bytes", ()), + message_workspace_peak_bytes=actual_tuple("message_workspace_peak_bytes", ()), + message_workspace_mode=actual_tuple("message_workspace_mode", ()), + message_workspace_aliases=actual_tuple("message_workspace_aliases", ()), + message_per_bucket_workspace_bytes=actual_tuple("message_per_bucket_workspace_bytes", ()), + phase_launch_counts=actual_tuple("phase_launch_counts", ()), + small_cublas_launch_counts=actual_tuple("small_cublas_launch_counts", ()), + copy_glue_launch_counts=actual_tuple("copy_glue_launch_counts", ()), + copy_glue_saved_launch_counts=actual_tuple("copy_glue_saved_launch_counts", ()), + bias_glue_launch_counts=actual_tuple("bias_glue_launch_counts", ()), + bias_glue_saved_launch_counts=actual_tuple("bias_glue_saved_launch_counts", ()), + state_epilogue_modes=actual_tuple("state_epilogue_modes", ()), + state_epilogue_saved_launch_counts=actual_tuple("state_epilogue_saved_launch_counts", ()), + launch_coalescing_modes=actual_tuple("launch_coalescing_modes", ()), + generic_glue_fusion_modes=actual_tuple("generic_glue_fusion_modes", ()), + launch_granularity_modes=actual_tuple("launch_granularity_modes", ()), + physical_op_kinds=actual_tuple("physical_op_kinds", ()), + physical_layout_contracts=actual_tuple("physical_layout_contracts", ()), + layout_mode=actual_tuple("layout_mode", ()), + 
copy_elision_mode=actual_tuple("copy_elision_mode", ()), + bias_fusion_mode=actual_tuple("bias_fusion_mode", ()), + physical_op_executors=actual_tuple("physical_op_executors", ()), + physical_op_demotions=actual_tuple("physical_op_demotions", ()), + physical_boundary_contracts=actual_tuple("physical_boundary_contracts", ()), + physical_applicability_predicates=actual_tuple("physical_applicability_predicates", ()), + physical_workspace_aliases=actual_tuple("physical_workspace_aliases", ()), + physical_workspace_peak_bytes=actual_tuple("physical_workspace_peak_bytes", ()), + physical_op_launch_counts=actual_tuple("physical_op_launch_counts", ()), + physical_op_saved_launch_counts=actual_tuple("physical_op_saved_launch_counts", ()), + standalone_copy_kernel_count=actual_tuple("standalone_copy_kernel_count", ()), + standalone_bias_kernel_count=actual_tuple("standalone_bias_kernel_count", ()), + receiver_affine_superop_surface_count=actual_tuple("receiver_affine_superop_surface_count", ()), + receiver_affine_superop_receivers=actual_tuple("receiver_affine_superop_receivers", ()), + receiver_affine_superop_k=actual_tuple("receiver_affine_superop_k", ()), + receiver_affine_superop_n=actual_tuple("receiver_affine_superop_n", ()), + receiver_affine_superop_source_layout=actual_tuple("receiver_affine_superop_source_layout", ()), + receiver_affine_superop_reset_policy=actual_tuple("receiver_affine_superop_reset_policy", ()), + receiver_affine_superop_executor=actual_tuple("receiver_affine_superop_executor", ()), + receiver_affine_superop_physical_mode=actual_tuple("receiver_affine_superop_physical_mode", ()), + receiver_affine_superop_demotion_reason=actual_tuple("receiver_affine_superop_demotion_reason", ()), + diagonal_recurrence_superop_surface_count=actual_tuple("diagonal_recurrence_superop_surface_count", ()), + diagonal_recurrence_kind=actual_tuple("diagonal_recurrence_kind", ()), + diagonal_recurrence_executor=actual_tuple("diagonal_recurrence_executor", ()), + diagonal_recurrence_physical_mode=actual_tuple("diagonal_recurrence_physical_mode", ()), + diagonal_recurrence_coeff_cache_mode=actual_tuple("diagonal_recurrence_coeff_cache_mode", ()), + diagonal_recurrence_coeff_cache_hit=actual_tuple("diagonal_recurrence_coeff_cache_hit", ()), + diagonal_recurrence_coeff_cache_bytes=actual_tuple("diagonal_recurrence_coeff_cache_bytes", ()), + diagonal_recurrence_coeff_cache_version_source=actual_tuple( + "diagonal_recurrence_coeff_cache_version_source", () + ), + diagonal_recurrence_reset_policy=actual_tuple("diagonal_recurrence_reset_policy", ()), + diagonal_recurrence_reset_scope=actual_tuple("diagonal_recurrence_reset_scope", ()), + diagonal_recurrence_output_boundary=actual_tuple("diagonal_recurrence_output_boundary", ()), + diagonal_recurrence_workspace_mode=actual_tuple("diagonal_recurrence_workspace_mode", ()), + diagonal_recurrence_workspace_peak_bytes=actual_tuple("diagonal_recurrence_workspace_peak_bytes", ()), + diagonal_recurrence_demotion_reason=actual_tuple("diagonal_recurrence_demotion_reason", ()), + diagonal_recurrence_launch_count=actual_tuple("diagonal_recurrence_launch_count", ()), + state_affine_backends=actual_tuple("state_affine_backends", ()), + state_affine_sources=actual_tuple("state_affine_sources", ()), + state_affine_bucket_signatures=actual_tuple("state_affine_bucket_signatures", ()), + state_affine_output_modes=actual_tuple("state_affine_output_modes", ()), + state_affine_reset_policies=actual_tuple("state_affine_reset_policies", ()), + 
state_affine_reset_mode=actual_tuple("state_affine_reset_mode", ()), + state_affine_reset_scope=actual_tuple("state_affine_reset_scope", ()), + state_affine_workspace_mode=actual_tuple("state_affine_workspace_mode", ()), + state_affine_receiver_chunk_size=actual_tuple("state_affine_receiver_chunk_size", ()), + state_affine_receiver_chunks=actual_tuple("state_affine_receiver_chunks", ()), + state_affine_workspace_buffers=actual_tuple("state_affine_workspace_buffers", ()), + state_affine_workspace_buffer_bytes=actual_tuple("state_affine_workspace_buffer_bytes", ()), + state_affine_workspace_bytes=actual_tuple("state_affine_workspace_bytes", ()), + state_affine_reset_rows_present=actual_tuple("state_affine_reset_rows_present", ()), + state_affine_packed_source_reused=actual_tuple("state_affine_packed_source_reused", ()), + public_projection_hidden_backends=actual_tuple("public_projection_hidden_backends", ()), + public_projection_kv_backends=actual_tuple("public_projection_kv_backends", ()), + readout_projection_backends=actual_tuple("readout_projection_backends", ()), + workspace_buffers=actual_tuple("workspace_buffers", ()), + workspace_buffer_bytes=actual_tuple("workspace_buffer_bytes", ()), + workspace_peak_bytes=actual_tuple("workspace_peak_bytes", ()), + workspace_aliases=actual_tuple("workspace_aliases", ()), + backward_receiver_execution_families=() + if backward_plan is None + else tuple(bucket_plan.execution_family.value for bucket_plan in backward_plan.receiver_bucket_plans), + backward_receiver_math_backends=() + if backward_plan is None + else tuple(bucket_plan.math_backend.value for bucket_plan in backward_plan.receiver_bucket_plans), + backward_sender_execution_families=() + if backward_plan is None + else tuple(bucket_plan.execution_family.value for bucket_plan in backward_plan.sender_bucket_plans), + backward_sender_math_backends=() + if backward_plan is None + else tuple(bucket_plan.math_backend.value for bucket_plan in backward_plan.sender_bucket_plans), + backward_affine_bucket_signatures=tuple(backward_affine_bucket_signatures), + backward_affine_forward_backends=tuple(backward_affine_forward_backends), + backward_affine_input_grad_backends=tuple(backward_affine_input_grad_backends), + backward_affine_weight_grad_backends=tuple(backward_affine_weight_grad_backends), + backward_affine_bias_grad_backends=tuple(backward_affine_bias_grad_backends), + backward_affine_demotion_reasons=tuple(backward_affine_demotion_reasons), + backward_affine_execution_modes=tuple(backward_affine_execution_modes), + backward_physical_op_kinds=backward_physical_op_kinds, + backward_physical_op_executors=backward_physical_op_executors, + backward_physical_op_demotions=backward_physical_op_demotions, + backward_boundary_contracts=backward_boundary_contracts, + backward_layout_mode=backward_layout_mode, + backward_workspace_aliases=backward_workspace_aliases, + backward_workspace_peak_bytes=backward_workspace_peak_bytes, + backward_tape_mode=backward_tape_mode, + backward_recompute_mode=backward_recompute_mode, + backward_launch_counts=backward_launch_counts, + backward_saved_launch_counts=backward_saved_launch_counts, + backward_residual_glue_demotions=backward_residual_glue_demotions, + tape_policy_bin=plan.tape_policy_bin, + graph_capture_enabled=bool(graph_capture_replayed or graph_capture_cache_hit), + capability_variants=tuple(bucket_plan.capability_variant for bucket_plan in plan.bucket_plans), + large_r_safety_modes=tuple(bucket_plan.large_r_safety_mode for bucket_plan in 
plan.bucket_plans), + active_cell_tiling_plans=tuple( + bucket_plan.active_cell_tiling_plan.summary for bucket_plan in plan.bucket_plans + ), + large_r_diagnostics=tuple(bucket_plan.large_r_diagnostics for bucket_plan in plan.bucket_plans), + graph_capture_replayed=graph_capture_replayed, + graph_capture_cache_hit=graph_capture_cache_hit, + ) + + def _record_pytorch_backend_execution( + self, + *, + batch_size: int, + time_steps: int, + inner_steps: int, + training: bool, + regime: str, + ) -> None: + if self._last_backend_execution is not None: + return + population_name = self._full_recurrent_population_name + if population_name is not None: + cell_type = self._population_cell_types[population_name] + else: + cell_type = "bucketed" + self._last_backend_execution = BackendExecutionRecord( + backend_name="pytorch", + surface_key=None, + cell_type=cell_type, + regime=regime, + training=training, + batch_size=batch_size, + time_steps=time_steps, + inner_steps=inner_steps, + bucket_ids=(), + execution_families=(), + math_backends=(), + tape_policy_bin="none", + graph_capture_enabled=False, + capability_variants=(), + ) + + def _materialize_inference_static_tensors( + self, + *, + device: torch.device, + dtype: torch.dtype, + include_backend_cell_tensors: bool = True, + include_backend_prepack: Optional[bool] = None, + include_population_materialized: bool = True, + include_full_cell_kv_weight: bool = True, + ) -> dict[str, object]: + def static_contiguous(tensor: torch.Tensor) -> torch.Tensor: + with torch.profiler.record_function("fabric.glue.static_tensor_contiguous"): + return tensor.contiguous() + + cell_bias = self.cell_bias_proj(self.slot_embed).view(1, 1, self.coords.shape[0], self.hidden_size) + recurrent_cell_bias = cell_bias[:, :, self.recurrent_cell_idx, :].squeeze(1) + q = self.q_proj(self.slot_embed).view(self.coords.shape[0], self.head_dim) + recurrent_q = q.index_select(0, self.recurrent_cell_idx) + output_q = q.index_select(0, self.output_cell_idx) + sender_group_kv_weight = self._grouped_kv_weight(self.sender_kv_group_ids) + recurrent_group_kv_weight = self._grouped_kv_weight(self.recurrent_sender_kv_group_ids) + input_group_kv_weight = self._grouped_kv_weight(self.input_sender_kv_group_ids) + use_grouped_sender_weights = ( + sender_group_kv_weight is not None + and recurrent_group_kv_weight is not None + and input_group_kv_weight is not None + ) + gathered_kv_weight = None + if include_full_cell_kv_weight or not use_grouped_sender_weights: + with torch.profiler.record_function("fabric.glue.static_tensor_cat"): + gathered_kv_weight = torch.cat( + ( + self.k_weight.index_select(0, self.kv_group_id), + self.v_weight.index_select(0, self.kv_group_id), + ), + dim=-1, + ) + sender_kv_weight = ( + gathered_kv_weight.index_select(0, self.sender_cell_idx) + if gathered_kv_weight is not None and sender_group_kv_weight is None + else None + ) + sender_input_to_kv_weight = ( + torch.einsum("dh,sdm->shm", self.public_proj.weight, sender_kv_weight) + if sender_kv_weight is not None + else None + ) + input_sender_input_to_kv_weight = ( + sender_input_to_kv_weight.index_select(0, self.input_sender_idx) + if sender_input_to_kv_weight is not None + else None + ) + sender_group_input_to_kv_weight = ( + torch.einsum("dh,gdm->ghm", self.public_proj.weight, sender_group_kv_weight) + if sender_group_kv_weight is not None + else None + ) + recurrent_sender_input_to_kv_weight = ( + sender_input_to_kv_weight.index_select(0, self.recurrent_sender_idx) + if sender_input_to_kv_weight is not None + else 
None + ) + recurrent_group_input_to_kv_weight = ( + torch.einsum("dh,gdm->ghm", self.public_proj.weight, recurrent_group_kv_weight) + if recurrent_group_kv_weight is not None + else None + ) + input_group_input_to_kv_weight = ( + torch.einsum("dh,gdm->ghm", self.public_proj.weight, input_group_kv_weight) + if input_group_kv_weight is not None + else None + ) + population_materialized: dict[str, object | None] = {} + should_include_backend_prepack = ( + include_backend_cell_tensors if include_backend_prepack is None else include_backend_prepack + ) + if include_population_materialized: + for name in self._population_names: + if hasattr(self.population_modules[name], "materialize_params"): + with torch.profiler.record_function("fabric.glue.population_param_materialization"): + population_materialized[name] = self.population_modules[name].materialize_params( + include_backend_prepack=should_include_backend_prepack + ) + else: + population_materialized[name] = None + else: + population_materialized = {name: None for name in self._population_names} + value_to_cell_weight = self.msg_to_cell.weight @ self.msg_out.weight + fused_recurrent_value_to_cell_weight = None + fused_recurrent_cell_bias = recurrent_cell_bias + fused_recurrent_population_input = False + recurrent_bias_2d = recurrent_cell_bias.squeeze(0) if recurrent_cell_bias.dim() == 3 else recurrent_cell_bias + if self._full_recurrent_population_name is not None: + full_population_name = self._full_recurrent_population_name + full_population_params = population_materialized.get(full_population_name) + input_proj_weight_t = None + out_proj_weight_t = None + out_proj_bias = None + if isinstance(full_population_params, dict): + input_proj_weight_t = full_population_params.get("input_proj_weight_t") + out_proj_weight_t = full_population_params.get("out_proj_weight_t") + out_proj_bias = full_population_params.get("out_proj_bias") + if torch.is_tensor(input_proj_weight_t): + fused_recurrent_value_to_cell_weight = torch.matmul( + value_to_cell_weight.transpose(0, 1).unsqueeze(0), + input_proj_weight_t, + ) + fused_recurrent_cell_bias = ( + torch.bmm(recurrent_cell_bias.squeeze(0).unsqueeze(1), input_proj_weight_t).squeeze(1).unsqueeze(0) + ) + fused_recurrent_population_input = True + empty_float_tensor = torch.empty(0, device=device, dtype=dtype) + sender_group_size_tensor = torch.tensor( + [int(self._recurrent_sender_kv_group_size)], + device=device, + dtype=torch.int32, + ) + recurrent_input_to_kv_weight = recurrent_sender_input_to_kv_weight + if recurrent_input_to_kv_weight is None and torch.is_tensor(recurrent_group_input_to_kv_weight): + recurrent_input_to_kv_weight = recurrent_group_input_to_kv_weight.repeat_interleave( + max(1, self._recurrent_sender_kv_group_size), + dim=0, + ) + backend_cell_tensors: dict[str, dict[str, torch.Tensor]] = {} + if include_backend_cell_tensors: + for population_name in self._population_names: + cell_spec = self._cell_spec_for_population(population_name) + if cell_spec.public_schema.kind == "hidden": + backend_cell_tensors[population_name] = { + "recurrent_message_to_state_weight": static_contiguous(value_to_cell_weight), + "recurrent_message_to_state_bias": static_contiguous(recurrent_bias_2d), + "sender_input_to_kv_weight": ( + static_contiguous(recurrent_sender_input_to_kv_weight) + if torch.is_tensor(recurrent_sender_input_to_kv_weight) + else empty_float_tensor + ), + "grouped_sender_input_to_kv_weight": ( + static_contiguous(recurrent_group_input_to_kv_weight) + if 
torch.is_tensor(recurrent_group_input_to_kv_weight) + else empty_float_tensor + ), + "sender_group_size": sender_group_size_tensor, + } + continue + if cell_spec.public_schema.kind != "preproj": + continue + population_params = population_materialized.get(population_name) + if not isinstance(population_params, dict): + continue + if recurrent_input_to_kv_weight is None: + raise RuntimeError( + f"Fabric cell population {population_name} is missing recurrent sender KV projection weights" + ) + population_recurrent_idx = self._population_recurrent_indices(population_name) + population_recurrent_count = int(population_recurrent_idx.numel()) + population_recurrent_input_to_kv_weight = recurrent_input_to_kv_weight.index_select( + 0, + population_recurrent_idx, + ) + sequence_population_input_weight = cast( + torch.Tensor, + fused_recurrent_value_to_cell_weight + if fused_recurrent_population_input and torch.is_tensor(fused_recurrent_value_to_cell_weight) + else static_contiguous( + value_to_cell_weight.transpose(0, 1).unsqueeze(0).expand(population_recurrent_count, -1, -1) + ), + ) + sequence_population_input_bias = cast( + torch.Tensor, + ( + fused_recurrent_cell_bias + if fused_recurrent_population_input + else recurrent_cell_bias.index_select(1, population_recurrent_idx) + ), + ) + out_proj_weight_t = cast(torch.Tensor, population_params["out_proj_weight_t"]) + out_proj_bias = cast(torch.Tensor, population_params["out_proj_bias"]) + backend_cell_tensors[population_name] = { + "input_projection_weight": static_contiguous(sequence_population_input_weight), + "input_projection_bias": static_contiguous(sequence_population_input_bias.squeeze(0)), + "recurrent_hidden_projection_weight": cast( + torch.Tensor, population_params["recurrent_hidden_projection_weight"] + ), + "recurrent_hidden_projection_bias": cast( + torch.Tensor, population_params["recurrent_hidden_projection_bias"] + ), + "recurrent_kv_projection_weight": static_contiguous( + torch.bmm( + out_proj_weight_t, + population_recurrent_input_to_kv_weight, + ) + ), + "recurrent_kv_projection_bias": static_contiguous( + torch.bmm( + out_proj_bias.unsqueeze(1), + population_recurrent_input_to_kv_weight, + ).squeeze(1) + ), + } + backend_cell_tensors[population_name]["recurrent_hidden_projection_weight"] = static_contiguous( + backend_cell_tensors[population_name]["recurrent_hidden_projection_weight"] + ) + backend_cell_tensors[population_name]["recurrent_hidden_projection_bias"] = static_contiguous( + backend_cell_tensors[population_name]["recurrent_hidden_projection_bias"] + ) + return { + "_parameter_versions": tuple(int(parameter._version) for parameter in self.parameters()), + "cell_bias": cell_bias, + "recurrent_cell_bias": recurrent_cell_bias, + "q": q, + "recurrent_q": recurrent_q, + "recurrent_q_backend_order": ( + recurrent_q.index_select(0, self.population_backend_recurrent_order) + if self.population_backend_recurrent_order.numel() == recurrent_q.shape[0] + and not self._population_backend_recurrent_order_is_identity + else recurrent_q + ), + "output_q": output_q, + "gathered_kv_weight": gathered_kv_weight, + "sender_kv_weight": sender_kv_weight, + "sender_input_to_kv_weight": sender_input_to_kv_weight, + "input_sender_input_to_kv_weight": input_sender_input_to_kv_weight, + "sender_group_input_to_kv_weight": sender_group_input_to_kv_weight, + "recurrent_sender_input_to_kv_weight": recurrent_sender_input_to_kv_weight, + "recurrent_group_input_to_kv_weight": recurrent_group_input_to_kv_weight, + "input_group_input_to_kv_weight": 
input_group_input_to_kv_weight, + "value_to_cell_weight": value_to_cell_weight, + "fused_recurrent_value_to_cell_weight": fused_recurrent_value_to_cell_weight, + "fused_recurrent_cell_bias": fused_recurrent_cell_bias, + "fused_recurrent_population_input": fused_recurrent_population_input, + "value_to_output_weight": torch.einsum("dv,pdh->pvh", self.msg_out.weight, self.output_cell_weight), + "population_materialized": population_materialized, + "population_materialized_include_backend_prepack": should_include_backend_prepack, + "backend_cell_tensors": backend_cell_tensors, + } + + def _detach_backend_static_tensors(self, value: object) -> object: + if torch.is_tensor(value): + return value.detach() + if isinstance(value, TensorDictBase): + return value.detach() + if isinstance(value, dict): + return {key: self._detach_backend_static_tensors(item) for key, item in value.items()} + if isinstance(value, tuple): + return tuple(self._detach_backend_static_tensors(item) for item in value) + if isinstance(value, list): + return [self._detach_backend_static_tensors(item) for item in value] + return value + + def _get_inference_static_tensors( + self, + *, + device: torch.device, + dtype: torch.dtype, + include_full_cell_kv_weight: bool = True, + ) -> dict[str, object]: + device_index = -1 if device.index is None else int(device.index) + key = ("inference", device.type, device_index, dtype, bool(include_full_cell_kv_weight)) + if key in self._inference_static_cache: + return self._inference_static_cache[key] + cached = self._materialize_inference_static_tensors( + device=device, + dtype=dtype, + include_full_cell_kv_weight=include_full_cell_kv_weight, + ) + self._inference_static_cache[key] = cached + return cached + + def _get_training_static_tensors( + self, + *, + device: torch.device, + dtype: torch.dtype, + include_backend_prepack: bool, + include_full_cell_kv_weight: bool = True, + include_population_materialized: bool = True, + detach_static_tensors: bool = True, + ) -> dict[str, object]: + parameter_versions = tuple(int(parameter._version) for parameter in self.parameters()) + self._mark_active_output_temporal_graph_static_parameter_versions(parameter_versions) + if not detach_static_tensors: + return self._materialize_inference_static_tensors( + device=device, + dtype=dtype, + include_backend_prepack=include_backend_prepack, + include_full_cell_kv_weight=include_full_cell_kv_weight, + include_population_materialized=include_population_materialized, + ) + device_index = -1 if device.index is None else int(device.index) + key = ( + "training", + device.type, + device_index, + dtype, + bool(include_backend_prepack), + bool(include_full_cell_kv_weight), + bool(include_population_materialized), + parameter_versions, + ) + if key in self._training_static_cache: + return self._training_static_cache[key] + if self._training_static_cache: + self._clear_active_output_temporal_graph_caches() + self._training_static_cache.clear() + with torch.no_grad(): + cached = cast( + dict[str, object], + self._materialize_inference_static_tensors( + device=device, + dtype=dtype, + include_backend_prepack=include_backend_prepack, + include_full_cell_kv_weight=include_full_cell_kv_weight, + include_population_materialized=include_population_materialized, + ), + ) + cached = cast(dict[str, object], self._detach_backend_static_tensors(cached)) + self._training_static_cache[key] = cached + return cached + + def init_state(self, batch: int, *, device: torch.device | str = "cpu", dtype: torch.dtype) -> TensorDict: + 
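`_get_training_static_tensors` above keys its cache on the current parameter `_version` counters and clears both the cache and the active-output temporal-graph caches as soon as those versions move, so a prepacked static tensor can never outlive an optimizer step. A minimal sketch of that invalidation pattern; the class name and the `build_payload` callable are assumptions for illustration, not part of this change.

    import torch

    class VersionKeyedStaticCache:
        def __init__(self) -> None:
            self._cache: dict[tuple, object] = {}

        def get(self, module: torch.nn.Module, build_payload):
            versions = tuple(int(p._version) for p in module.parameters())
            key = ("training", versions)
            if key in self._cache:
                return self._cache[key]
            if self._cache:
                # Parameters changed since the last materialization: every cached
                # entry is stale, so drop them all before rebuilding.
                self._cache.clear()
            with torch.no_grad():
                payload = build_payload()
            self._cache[key] = payload
            return payload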
with torch.profiler.record_function("fabric.glue.runtime_init_state"): + state = TensorDict({}, batch_size=[]) + state["cells"] = torch.zeros(batch, self.coords.shape[0], self.hidden_size, device=device, dtype=dtype) + for population_name in self._population_names: + state[population_name] = self.population_modules[population_name].init_state( + batch=batch, + device=device, + dtype=dtype, + ) + return state + + def reset_state(self, state: MaybeState, mask: ResetMask) -> MaybeState: + if state is None or not isinstance(state, TensorDictBase): + return state + cells = state.get("cells") + state_device = cells.device if torch.is_tensor(cells) else torch.device("cpu") + batch_mask = torch.as_tensor(mask, dtype=torch.bool, device=state_device) + if batch_mask.dim() != 1: + raise ValueError(f"Runtime.reset_state expects a 1D mask, got shape {tuple(batch_mask.shape)}") + out = TensorDict({}, batch_size=[]) + if torch.is_tensor(cells): + out["cells"] = reset_backend_tensor_rows(cells, batch_mask) + sender_k = state.get("sender_k") + if torch.is_tensor(sender_k): + out["sender_k"] = reset_backend_tensor_rows(sender_k, batch_mask) + sender_v = state.get("sender_v") + if torch.is_tensor(sender_v): + out["sender_v"] = reset_backend_tensor_rows(sender_v, batch_mask) + for population_name in self._population_names: + population_state = state.get(population_name) + if population_state is None: + continue + out[population_name] = self.population_modules[population_name].reset_state(population_state, batch_mask) + return out + + def readout_output_cells(self, output_cells: torch.Tensor) -> torch.Tensor: + return self._readout(output_cells) + + def forward( + self, + hidden_input: Tensor, + state: MaybeState = None, + *, + resets: Optional[ResetMask] = None, + k: int | torch.Tensor | None = None, + materialize_final_state: bool = True, + ) -> tuple[Tensor, MaybeState]: + y_cells, next_state = self.forward_cells( + hidden_input=hidden_input, + state=state, + resets=resets, + k=k, + materialize_final_state=materialize_final_state, + ) + if y_cells.dim() == 3: + return self._readout(y_cells.unsqueeze(1)).squeeze(1), next_state + return self._readout(y_cells), next_state + + def forward_cells( + self, + hidden_input: Tensor | None = None, + state: MaybeState = None, + *, + resets: Optional[ResetMask] = None, + k: int | torch.Tensor | None = None, + boundary_input: torch.Tensor | None = None, + training_semantics: bool | None = None, + materialize_final_state: bool = True, + ) -> tuple[Tensor, MaybeState]: + """Run the fabric over either per-step hidden vectors or direct boundary-cell values. + + `hidden_input` is a single vector per timestep with shape `[B, H]` or `[B, T, H]`. + The runtime projects it into the boundary input cells internally. + + `boundary_input` is already in boundary-cell space with shape `[B, P, H]` or + `[B, T, P, H]`, where `P` is the number of boundary input cells. 
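A minimal usage sketch of the two input modes described above; the `runtime` instance, batch and step sizes, and the `k` value are illustrative assumptions rather than values from this change:

    import torch

    # Per-step hidden vectors, shape [B, T, H]; the runtime projects them onto the
    # boundary input cells internally.
    x = torch.randn(4, 16, runtime.hidden_size)
    y_cells, state = runtime.forward_cells(hidden_input=x, k=2)

    # Values already in boundary-cell space, shape [B, T, P, H], where P is the
    # number of boundary input cells.
    p = int(runtime.input_cell_idx.numel())
    b = torch.randn(4, 16, p, runtime.hidden_size)
    y_cells, state = runtime.forward_cells(boundary_input=b, state=state, k=2)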
+ """ + if self._preserve_backend_execution_record_depth <= 0: + self._last_backend_execution = None + self._last_backend_launch_metadata = None + if hidden_input is None and boundary_input is None: + raise ValueError("forward_cells requires either hidden_input or boundary_input") + + step_mode = (boundary_input.dim() == 3) if boundary_input is not None else (hidden_input.dim() == 2) + if boundary_input is not None: + boundary_seq = boundary_input.unsqueeze(1) if step_mode else boundary_input + if boundary_seq.dim() != 4: + raise ValueError( + f"Runtime boundary_input expects [B,P,D] or [B,T,P,D], got {tuple(boundary_input.shape)}" + ) + batch_size, time_steps, port_count, msg_dim = boundary_seq.shape + if port_count != self.input_cell_idx.numel(): + raise ValueError( + f"Runtime boundary_input count={port_count} must match input cells={self.input_cell_idx.numel()}" + ) + if msg_dim != self.hidden_size: + raise ValueError(f"Runtime boundary_input dim={msg_dim} must match hidden_size={self.hidden_size}") + hidden_seq = None + else: + assert hidden_input is not None + hidden_seq = hidden_input.unsqueeze(1) if step_mode else hidden_input + if hidden_seq.dim() != 3: + raise ValueError(f"Runtime hidden_input expects [B,H] or [B,T,H], got {tuple(hidden_input.shape)}") + batch_size, time_steps, hidden_size = hidden_seq.shape + if hidden_size != self.hidden_size: + raise ValueError( + f"Runtime hidden_size={self.hidden_size} requires input dim {self.hidden_size}, got {hidden_size}" + ) + boundary_seq = None + + device = boundary_seq.device if boundary_seq is not None else hidden_seq.device + dtype = boundary_seq.dtype if boundary_seq is not None else hidden_seq.dtype + grad_path = torch.is_grad_enabled() if training_semantics is None else bool(training_semantics) + backend_population_state_is_fresh = state is None or not isinstance(state, TensorDictBase) + if grad_path: + self._clear_execution_caches() + population_resets = _expand_resets_for_time(resets, batch_size=batch_size, time_steps=time_steps, device=device) + capture_active = bool(device.type == "cuda" and torch.cuda.is_current_stream_capturing()) + step_reset_flags: list[bool] | None = None + if population_resets is not None: + if capture_active: + step_reset_flags = [True] * time_steps + else: + if population_resets.dim() == 1: + reset_any = population_resets.any().view(1) + else: + reset_any = population_resets.any(dim=0) + step_reset_flags = reset_any.to(device="cpu", dtype=torch.bool).tolist() + temporal_execution_plan = self._plan_temporal_execution( + k=k, + device=device, + dtype=dtype, + output_boundary="deferred", + readout_output_boundary="deferred", + materialize_final_state=materialize_final_state, + has_output_consumer=False, + fresh_state=backend_population_state_is_fresh, + training=grad_path, + ) + backend_population_name = self._select_output_cells_stream_backend_population( + k=k, + ) + selected_backend_name = select_fabric_backend( + configured_backend=str(self.config.backend), + device=device, + supports_cuda_backend=temporal_execution_plan.supported, + ) + backend_native_static_materialization = bool( + selected_backend_name == "cuda" and temporal_execution_plan.supported + ) + if grad_path: + training_static_prepack = self._training_static_prepack_enabled() + self._last_training_static_prepack_mode = "contiguous_prepack" if training_static_prepack else "views" + self._last_backward_projection_mode = ( + "fused_static_projection" if training_static_prepack else "factorized_recurrent_input" + ) + if 
selected_backend_name == "pytorch": + static_tensors = self._materialize_inference_static_tensors( + device=device, + dtype=dtype, + include_backend_prepack=training_static_prepack, + ) + self._last_training_static_tape_mode = "pytorch_autograd_static_values" + else: + static_tensors = self._get_training_static_tensors( + device=device, + dtype=dtype, + include_backend_prepack=training_static_prepack, + include_full_cell_kv_weight=not backend_native_static_materialization, + detach_static_tensors=True, + ) + self._last_training_static_tape_mode = "detached_shared_values" + else: + self._last_training_static_prepack_mode = "inference_cache" + self._last_backward_projection_mode = "not_training" + static_tensors = self._get_inference_static_tensors( + device=device, + dtype=dtype, + include_full_cell_kv_weight=not backend_native_static_materialization, + ) + cell_bias = static_tensors["cell_bias"] + recurrent_cell_bias = static_tensors["recurrent_cell_bias"] + q = static_tensors["q"] + recurrent_q = static_tensors["recurrent_q"] + output_q = static_tensors["output_q"] + gathered_kv_weight = static_tensors["gathered_kv_weight"] + sender_kv_weight = static_tensors["sender_kv_weight"] + sender_input_to_kv_weight = static_tensors["sender_input_to_kv_weight"] + input_sender_input_to_kv_weight = static_tensors["input_sender_input_to_kv_weight"] + sender_group_input_to_kv_weight = static_tensors["sender_group_input_to_kv_weight"] + recurrent_sender_input_to_kv_weight = static_tensors["recurrent_sender_input_to_kv_weight"] + recurrent_group_input_to_kv_weight = static_tensors["recurrent_group_input_to_kv_weight"] + input_group_input_to_kv_weight = static_tensors["input_group_input_to_kv_weight"] + value_to_cell_weight = static_tensors["value_to_cell_weight"] + fused_recurrent_value_to_cell_weight = static_tensors["fused_recurrent_value_to_cell_weight"] + fused_recurrent_cell_bias = static_tensors["fused_recurrent_cell_bias"] + fused_recurrent_population_input = bool(static_tensors["fused_recurrent_population_input"]) + value_to_output_weight = static_tensors["value_to_output_weight"] + constant_k = self._resolve_constant_k_host(k) + population_materialized = static_tensors["population_materialized"] + self._active_backend_name = selected_backend_name + use_fresh_backend_population_cache = bool( + selected_backend_name == "cuda" + and temporal_execution_plan.carry.carry_policy == "fresh_state_population_cache_elision_candidate" + and not grad_path + ) + current_state = self._ensure_state( + state, + batch=batch_size, + device=device, + dtype=dtype, + include_population_state=not use_fresh_backend_population_cache, + ) + if ( + selected_backend_name == "cuda" + and temporal_execution_plan.executor.selected_implementation == "shared_transition_buckets" + ): + return execute_temporal_bucket_sequence( + self, + hidden_seq=hidden_seq, + boundary_seq=boundary_seq, + state=current_state, + population_resets=population_resets, + step_reset_flags=step_reset_flags, + k=k, + constant_k=constant_k, + batch_size=batch_size, + time_steps=time_steps, + step_mode=step_mode, + capture_active=capture_active, + static_tensors=static_tensors, + grad_path=grad_path, + materialize_final_state=materialize_final_state, + backend_population_state_is_fresh=backend_population_state_is_fresh, + use_fresh_backend_population_cache=use_fresh_backend_population_cache, + tape_policy=None, + output_contract="full_cells", + temporal_execution_plan=temporal_execution_plan, + ) + if step_mode: + if constant_k is None: + k_rows, 
max_steps = self._resolve_step_k( + k, + batch_size=batch_size, + time_steps=time_steps, + step_index=0, + device=device, + ) + all_active = None + else: + k_rows = torch.full((batch_size,), constant_k, device=device, dtype=torch.long) + max_steps = constant_k + all_active = constant_k > 0 + step_resets = population_resets[:, 0] if population_resets is not None else None + step_population_state_cache = None + if use_fresh_backend_population_cache: + step_population_state_cache = self._prepare_fresh_stream_step_population_cache( + batch=batch_size, + device=device, + dtype=dtype, + ) + elif selected_backend_name == "cuda" and backend_population_name is not None and constant_k == 1: + step_population_state_cache = self._prepare_stream_step_population_cache( + current_state, + batch=batch_size, + device=device, + dtype=dtype, + ) + y_step, next_state = self._forward_stream_step( + hidden_step=hidden_seq[:, 0] if hidden_seq is not None else None, + state=current_state, + resets=step_resets, + has_resets=step_reset_flags[0] if step_reset_flags is not None else None, + capture_active=capture_active, + k_rows=k_rows, + max_steps=max_steps, + all_active=all_active if max_steps <= 1 else None, + q=q, + recurrent_q=recurrent_q, + output_q=output_q, + gathered_kv_weight=gathered_kv_weight, + sender_kv_weight=sender_kv_weight, + sender_input_to_kv_weight=sender_input_to_kv_weight, + input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, + sender_group_input_to_kv_weight=sender_group_input_to_kv_weight, + sender_group_size=self._sender_kv_group_size, + recurrent_sender_input_to_kv_weight=recurrent_sender_input_to_kv_weight, + recurrent_group_input_to_kv_weight=recurrent_group_input_to_kv_weight, + recurrent_group_size=self._recurrent_sender_kv_group_size, + value_to_cell_weight=value_to_cell_weight, + fused_recurrent_value_to_cell_weight=fused_recurrent_value_to_cell_weight, + value_to_output_weight=value_to_output_weight, + cell_bias=cell_bias, + recurrent_cell_bias=recurrent_cell_bias, + fused_recurrent_cell_bias=fused_recurrent_cell_bias, + fused_recurrent_population_input=fused_recurrent_population_input, + boundary_step=boundary_seq[:, 0] if boundary_seq is not None else None, + input_group_input_to_kv_weight=input_group_input_to_kv_weight, + population_materialized=population_materialized, + step_population_state_cache=step_population_state_cache, + grad_path=grad_path, + backend_static_tensors=static_tensors, + materialize_population_next_state=materialize_final_state or grad_path, + materialize_cells_state=True, + ) + if step_population_state_cache is not None and (materialize_final_state or grad_path): + self._apply_stream_step_population_cache(next_state, step_population_state_cache) + if selected_backend_name == "pytorch": + self._record_pytorch_backend_execution( + batch_size=batch_size, + time_steps=1, + inner_steps=int(max_steps), + training=grad_path, + regime="stream", + ) + return y_step, next_state + + outputs: list[torch.Tensor] | None = [] if grad_path else None + outputs_buffer: torch.Tensor | None = None + running_state = current_state + constant_k_rows = None + constant_max_steps = None + constant_all_active = None + step_population_state_cache = None + step_sender_cache = None + if constant_k is not None: + constant_k_rows = torch.full((batch_size,), constant_k, device=device, dtype=torch.long) + constant_max_steps = constant_k + constant_all_active = constant_k > 0 + if constant_k > 0: + if use_fresh_backend_population_cache: + step_population_state_cache = 
self._prepare_fresh_stream_step_population_cache( + batch=batch_size, + device=device, + dtype=dtype, + ) + elif selected_backend_name == "cuda": + step_population_state_cache = self._prepare_stream_step_population_cache( + running_state, + batch=batch_size, + device=device, + dtype=dtype, + ) + if constant_k == 1: + step_sender_cache = {} + for step_index in range(time_steps): + if constant_k_rows is None or constant_max_steps is None: + k_rows, max_steps = self._resolve_step_k( + k, + batch_size=batch_size, + time_steps=time_steps, + step_index=step_index, + device=device, + ) + all_active = None + else: + k_rows, max_steps = constant_k_rows, constant_max_steps + all_active = constant_all_active if max_steps <= 1 else None + step_resets = population_resets[:, step_index] if population_resets is not None else None + y_step, running_state = self._forward_stream_step( + hidden_step=hidden_seq[:, step_index] if hidden_seq is not None else None, + state=running_state, + resets=step_resets, + has_resets=step_reset_flags[step_index] if step_reset_flags is not None else None, + capture_active=capture_active, + k_rows=k_rows, + max_steps=max_steps, + all_active=all_active, + q=q, + recurrent_q=recurrent_q, + output_q=output_q, + gathered_kv_weight=gathered_kv_weight, + sender_kv_weight=sender_kv_weight, + sender_input_to_kv_weight=sender_input_to_kv_weight, + input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, + sender_group_input_to_kv_weight=sender_group_input_to_kv_weight, + sender_group_size=self._sender_kv_group_size, + recurrent_sender_input_to_kv_weight=recurrent_sender_input_to_kv_weight, + recurrent_group_input_to_kv_weight=recurrent_group_input_to_kv_weight, + recurrent_group_size=self._recurrent_sender_kv_group_size, + value_to_cell_weight=value_to_cell_weight, + fused_recurrent_value_to_cell_weight=fused_recurrent_value_to_cell_weight, + value_to_output_weight=value_to_output_weight, + cell_bias=cell_bias, + recurrent_cell_bias=recurrent_cell_bias, + fused_recurrent_cell_bias=fused_recurrent_cell_bias, + fused_recurrent_population_input=fused_recurrent_population_input, + boundary_step=boundary_seq[:, step_index] if boundary_seq is not None else None, + input_group_input_to_kv_weight=input_group_input_to_kv_weight, + population_materialized=population_materialized, + step_population_state_cache=step_population_state_cache, + step_sender_cache=step_sender_cache, + grad_path=grad_path, + backend_static_tensors=static_tensors, + materialize_population_next_state=materialize_final_state or grad_path or step_index + 1 < time_steps, + materialize_cells_state=True, + ) + if grad_path: + assert outputs is not None + outputs.append(y_step) + else: + if outputs_buffer is None: + outputs_buffer = y_step.new_empty(batch_size, time_steps, *y_step.shape[1:]) + outputs_buffer[:, step_index].copy_(y_step) + if step_population_state_cache is not None and materialize_final_state: + self._apply_stream_step_population_cache(running_state, step_population_state_cache) + if grad_path: + assert outputs is not None + stacked = torch.stack(outputs, dim=1) + if selected_backend_name == "pytorch": + self._record_pytorch_backend_execution( + batch_size=batch_size, + time_steps=time_steps, + inner_steps=int(constant_k if constant_k is not None else max(1, int(k_rows.max().item()))), + training=grad_path, + regime="stream", + ) + return stacked, running_state if materialize_final_state else TensorDict({}, batch_size=[]) + assert outputs_buffer is not None + if selected_backend_name == "pytorch": + 
self._record_pytorch_backend_execution( + batch_size=batch_size, + time_steps=time_steps, + inner_steps=int(constant_k if constant_k is not None else max(1, int(k_rows.max().item()))), + training=grad_path, + regime="stream", + ) + return outputs_buffer, running_state if materialize_final_state else TensorDict({}, batch_size=[]) + + def forward_output_cells_for_readout( + self, + state: MaybeState = None, + *, + resets: Optional[ResetMask], + k: int | torch.Tensor | None, + boundary_input: torch.Tensor | None = None, + source_hidden_input: torch.Tensor | None = None, + input_projection_weight: torch.Tensor | None = None, + input_projection_bias: torch.Tensor | None = None, + training_semantics: bool | None, + materialize_final_state: bool = True, + tape_policy: TapePolicy | None = None, + output_boundary: Literal["sequence", "terminal"] = "sequence", + readout_output_boundary: Literal["cells", "pooled"] = "cells", + output_chunk_consumer: _RuntimeOutputChunkConsumer | None = None, + detach_internal_carry_after_output_chunk: bool = False, + ) -> tuple[torch.Tensor, TensorDict]: + if output_boundary not in {"sequence", "terminal"}: + raise ValueError(f"Unsupported Fabric output-cell boundary {output_boundary!r}") + if readout_output_boundary not in {"cells", "pooled"}: + raise ValueError(f"Unsupported Fabric readout output boundary {readout_output_boundary!r}") + if self._preserve_backend_execution_record_depth <= 0: + self._last_backend_execution = None + self._last_backend_launch_metadata = None + projected_boundary_active = source_hidden_input is not None or input_projection_weight is not None + source_hidden_seq: torch.Tensor | None = None + if boundary_input is not None and projected_boundary_active: + raise ValueError("Pass either boundary_input or source_hidden_input/input_projection_weight, not both") + if projected_boundary_active: + if source_hidden_input is None or input_projection_weight is None: + raise ValueError( + "Projected Fabric boundary input requires source_hidden_input and input_projection_weight" + ) + source_hidden_seq = ( + source_hidden_input.unsqueeze(1) if source_hidden_input.dim() == 2 else source_hidden_input + ) + if source_hidden_seq.dim() != 3: + raise ValueError( + f"source_hidden_input must be shaped [B,H] or [B,T,H], got {tuple(source_hidden_input.shape)}" + ) + projected_features = int(self.input_cell_idx.numel()) * int(self.hidden_size) + if input_projection_weight.dim() != 2 or tuple(input_projection_weight.shape) != ( + projected_features, + int(source_hidden_seq.shape[-1]), + ): + raise ValueError( + "input_projection_weight must have shape " + f"[{projected_features}, {int(source_hidden_seq.shape[-1])}], " + f"got {tuple(input_projection_weight.shape)}" + ) + if input_projection_bias is not None and tuple(input_projection_bias.shape) != (projected_features,): + raise ValueError( + f"input_projection_bias must have shape [{projected_features}], " + f"got {tuple(input_projection_bias.shape)}" + ) + projected_boundary_weight = input_projection_weight + projected_boundary_bias = input_projection_bias + boundary_seq = None + boundary_input_for_fallback = None + else: + if boundary_input is None: + raise ValueError("forward_output_cells_for_readout requires boundary_input or projected source input") + projected_boundary_weight = None + projected_boundary_bias = None + boundary_seq = boundary_input.unsqueeze(1) if boundary_input.dim() == 3 else boundary_input + boundary_input_for_fallback = boundary_input + if boundary_seq is not None and boundary_seq.dim() 
!= 4: + raise ValueError( + "forward_output_cells_for_readout expects boundary_input shaped [B,P,D] or [B,T,P,D], " + f"got {tuple(boundary_input.shape) if boundary_input is not None else tuple(source_hidden_input.shape)}" + ) + if boundary_seq is not None: + batch_size, time_steps, port_count, msg_dim = boundary_seq.shape + device = boundary_seq.device + dtype = boundary_seq.dtype + else: + assert source_hidden_seq is not None + batch_size = int(source_hidden_seq.shape[0]) + time_steps = int(source_hidden_seq.shape[1]) + port_count = int(self.input_cell_idx.numel()) + msg_dim = int(self.hidden_size) + device = source_hidden_seq.device + dtype = source_hidden_seq.dtype + if port_count != self.input_cell_idx.numel(): + raise ValueError( + f"Runtime boundary_input count={port_count} must match input cells={self.input_cell_idx.numel()}" + ) + if msg_dim != self.hidden_size: + raise ValueError(f"Runtime boundary_input dim={msg_dim} must match hidden_size={self.hidden_size}") + backend_population_state_is_fresh = state is None or not isinstance(state, TensorDictBase) + grad_path = torch.is_grad_enabled() if training_semantics is None else bool(training_semantics) + population_resets = _expand_resets_for_time(resets, batch_size=batch_size, time_steps=time_steps, device=device) + temporal_execution_plan = self._plan_temporal_execution( + k=k, + device=device, + dtype=dtype, + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + materialize_final_state=materialize_final_state, + has_output_consumer=output_chunk_consumer is not None, + fresh_state=backend_population_state_is_fresh, + training=grad_path, + time_steps=time_steps, + ) + selected_backend_name = select_fabric_backend( + configured_backend=str(self.config.backend), + device=device, + supports_cuda_backend=temporal_execution_plan.supported, + ) + constant_k_host = self._resolve_constant_k_host(k) + backend_native_static_materialization = bool( + selected_backend_name == "cuda" and temporal_execution_plan.supported + ) + active_output_window_autograd_static_values = ( + temporal_execution_plan.backward.static_values_mode == "active_output_autograd_static_values" + ) + static_tensors: dict[str, object] | None = None + + def resolve_static_tensors() -> dict[str, object]: + nonlocal static_tensors + if static_tensors is not None: + return static_tensors + if grad_path: + training_static_prepack = bool( + self._training_static_prepack_enabled() or active_output_window_autograd_static_values + ) + self._last_training_static_prepack_mode = "contiguous_prepack" if training_static_prepack else "views" + self._last_backward_projection_mode = ( + "fused_static_projection" if training_static_prepack else "factorized_recurrent_input" + ) + if selected_backend_name == "pytorch": + static_tensors = self._materialize_inference_static_tensors( + device=device, + dtype=dtype, + include_backend_prepack=training_static_prepack, + ) + self._last_training_static_tape_mode = "pytorch_autograd_static_values" + else: + static_tensors = self._get_training_static_tensors( + device=device, + dtype=dtype, + include_backend_prepack=training_static_prepack, + include_full_cell_kv_weight=not backend_native_static_materialization, + include_population_materialized=not active_output_window_autograd_static_values, + detach_static_tensors=not active_output_window_autograd_static_values, + ) + if active_output_window_autograd_static_values: + self._last_training_static_tape_mode = "active_output_autograd_static_values" + else: + 
self._last_training_static_tape_mode = "detached_shared_values" + else: + self._last_training_static_prepack_mode = "inference_cache" + self._last_backward_projection_mode = "not_training" + static_tensors = self._get_inference_static_tensors( + device=device, + dtype=dtype, + include_full_cell_kv_weight=not backend_native_static_materialization, + ) + return static_tensors + + def resolve_temporal_sequence_static_tensors() -> dict[str, object]: + if not grad_path or not active_output_window_autograd_static_values: + return resolve_static_tensors() + training_static_prepack = bool(self._training_static_prepack_enabled()) + self._last_training_static_prepack_mode = "contiguous_prepack" if training_static_prepack else "views" + self._last_backward_projection_mode = ( + "fused_static_projection" if training_static_prepack else "factorized_recurrent_input" + ) + self._last_training_static_tape_mode = "detached_shared_values" + return self._get_training_static_tensors( + device=device, + dtype=dtype, + include_backend_prepack=training_static_prepack, + include_full_cell_kv_weight=not backend_native_static_materialization, + include_population_materialized=True, + detach_static_tensors=True, + ) + + def resolve_active_output_route_static_tensors(*, inner_steps: int) -> dict[str, object]: + horizon_steps = self.config.gradient_horizon_steps + total_steps = int(time_steps) * max(1, int(inner_steps)) + if grad_path and horizon_steps is not None and int(horizon_steps) < total_steps: + return resolve_temporal_sequence_static_tensors() + return resolve_static_tensors() + + flat_temporal_strategy = temporal_execution_plan.executor.temporal_strategy + + def execute_planned_flat_temporal_strategy() -> tuple[torch.Tensor, MaybeState] | None: + nonlocal boundary_seq, boundary_input_for_fallback + if ( + projected_boundary_active + and source_hidden_seq is not None + and boundary_seq is None + and selected_backend_name == "cuda" + and temporal_execution_plan.executor.selected_implementation == "shared_transition_buckets" + and output_chunk_consumer is not None + ): + return self._execute_flat_bucket_projected_source_sequence( + state=state if isinstance(state, TensorDictBase) else None, + projected_boundary_source_seq=source_hidden_seq, + projected_boundary_weight=cast(torch.Tensor, projected_boundary_weight), + projected_boundary_bias=projected_boundary_bias, + static_tensors=resolve_temporal_sequence_static_tensors(), + population_resets=population_resets, + materialize_final_state=materialize_final_state, + grad_path=grad_path, + tape_policy=tape_policy, + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + output_chunk_consumer=output_chunk_consumer, + detach_internal_carry_after_output_chunk=detach_internal_carry_after_output_chunk, + k=k, + temporal_execution_plan=temporal_execution_plan, + ) + if ( + projected_boundary_active + and source_hidden_seq is not None + and boundary_seq is None + and selected_backend_name == "cuda" + and flat_temporal_strategy == "shared_active_output_window" + ): + inner_steps = int(constant_k_host) + route_static_tensors = resolve_active_output_route_static_tensors(inner_steps=inner_steps) + active_region = _resolve_temporal_plan_active_region( + self, + temporal_execution_plan=temporal_execution_plan, + inner_steps=inner_steps, + time_steps=time_steps, + ) or _flat_bucket_active_output_region_for_inner_steps( + self, + inner_steps=inner_steps, + time_steps=time_steps, + ) + active_output_result = 
execute_temporal_bucket_active_output_window( + self, + projected_boundary_source_seq=source_hidden_seq, + projected_boundary_weight=cast(torch.Tensor, projected_boundary_weight), + projected_boundary_bias=projected_boundary_bias, + resets=population_resets, + static_tensors=route_static_tensors, + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + active_region=active_region, + inner_steps=inner_steps, + ) + if active_output_result is not None: + output_cells, next_state = active_output_result + record_temporal_bucket_sequence_surface_execution( + self, + batch_size=batch_size, + time_steps=time_steps, + inner_steps=inner_steps, + training=grad_path, + readout_output_boundary=readout_output_boundary, + active_receiver_window_mode=_active_region_record_mode( + active_region, + inner_steps=inner_steps, + time_steps=time_steps, + temporal_execution_plan=temporal_execution_plan, + ), + active_receiver_window_offset=str(int(active_region.start)), + active_receiver_window_count=str(int(active_region.count)), + ) + return output_cells, next_state + if ( + boundary_seq is not None + and selected_backend_name == "cuda" + and flat_temporal_strategy == "shared_active_output_window" + ): + inner_steps = int(constant_k_host) + route_static_tensors = resolve_active_output_route_static_tensors(inner_steps=inner_steps) + active_region = _resolve_temporal_plan_active_region( + self, + temporal_execution_plan=temporal_execution_plan, + inner_steps=inner_steps, + time_steps=time_steps, + ) or _flat_bucket_active_output_region_for_inner_steps( + self, + inner_steps=inner_steps, + time_steps=time_steps, + ) + active_output_result = execute_temporal_bucket_active_output_window( + self, + boundary_seq=boundary_seq, + resets=population_resets, + static_tensors=route_static_tensors, + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + active_region=active_region, + inner_steps=inner_steps, + ) + if active_output_result is not None: + output_cells, next_state = active_output_result + record_temporal_bucket_sequence_surface_execution( + self, + batch_size=batch_size, + time_steps=time_steps, + inner_steps=inner_steps, + training=grad_path, + readout_output_boundary=readout_output_boundary, + active_receiver_window_mode=_active_region_record_mode( + active_region, + inner_steps=inner_steps, + time_steps=time_steps, + temporal_execution_plan=temporal_execution_plan, + ), + active_receiver_window_offset=str(int(active_region.start)), + active_receiver_window_count=str(int(active_region.count)), + ) + return output_cells, next_state + if projected_boundary_active and source_hidden_seq is not None and boundary_seq is None: + assert input_projection_weight is not None + boundary_seq = self._project_boundary_source_sequence( + source_hidden_seq, + input_projection_weight=input_projection_weight, + input_projection_bias=input_projection_bias, + ) + boundary_input_for_fallback = boundary_seq + if ( + selected_backend_name == "cuda" + and flat_temporal_strategy == "shared_temporal_sequence" + and boundary_seq is not None + ): + current_state = self._ensure_state(state, batch=batch_size, device=device, dtype=dtype) + route_static_tensors = resolve_static_tensors() + output_cells, next_state = execute_temporal_bucket_sequence( + self, + hidden_seq=None, + boundary_seq=boundary_seq, + state=current_state, + population_resets=population_resets, + step_reset_flags=None, + k=k, + constant_k=self._resolve_constant_k_host(k), + batch_size=batch_size, + 
time_steps=time_steps, + step_mode=False, + capture_active=bool(device.type == "cuda" and torch.cuda.is_current_stream_capturing()), + static_tensors=route_static_tensors, + grad_path=grad_path, + materialize_final_state=materialize_final_state, + backend_population_state_is_fresh=backend_population_state_is_fresh, + use_fresh_backend_population_cache=False, + tape_policy=tape_policy, + output_contract="pooled_output_cells" if readout_output_boundary == "pooled" else "output_cells", + output_boundary=output_boundary, + temporal_execution_plan=temporal_execution_plan, + ) + if output_boundary == "terminal" and int(output_cells.shape[1]) > 1: + output_cells = output_cells[:, -1:] + return output_cells, next_state + return None + + self._active_backend_name = selected_backend_name + planned_flat_temporal_result = execute_planned_flat_temporal_strategy() + if planned_flat_temporal_result is not None: + output_cells, next_state = planned_flat_temporal_result + if output_chunk_consumer is None: + return output_cells, next_state + if int(output_cells.shape[1]) == 0: + return output_cells, next_state + output_chunk_consumer(output_cells, 0, int(output_cells.shape[1])) + empty_output = output_cells.new_empty((batch_size, 0, *tuple(output_cells.shape[2:]))) + return empty_output, next_state + internal_materialize_state_for_output = bool(not materialize_final_state and selected_backend_name == "pytorch") + y_cells, next_state = self.forward_cells( + state=state, + resets=resets, + k=k, + boundary_input=boundary_input_for_fallback, + training_semantics=training_semantics, + materialize_final_state=materialize_final_state or internal_materialize_state_for_output, + ) + y_seq = y_cells.unsqueeze(1) if y_cells.dim() == 3 else y_cells + output_cells = self._select_output_cells(y_seq) + if readout_output_boundary == "pooled": + output_cells = self._pool_output_ports(output_cells) + if self._last_backend_execution is not None: + self._last_backend_execution = replace( + self._last_backend_execution, + workspace_aliases=self._last_backend_execution.workspace_aliases + + ("readout_output_boundary:pooled_from_backend_cells",), + ) + if output_boundary == "terminal" and int(output_cells.shape[1]) > 1: + output_cells = output_cells[:, -1:] + if output_chunk_consumer is not None: + output_chunk_consumer(output_cells, 0, int(output_cells.shape[1])) + empty_output = output_cells.new_empty((batch_size, 0, *tuple(output_cells.shape[2:]))) + if detach_internal_carry_after_output_chunk: + return empty_output, TensorDict({}, batch_size=[]) + return empty_output, TensorDict({}, batch_size=[]) if internal_materialize_state_for_output else next_state + return output_cells, TensorDict({}, batch_size=[]) if internal_materialize_state_for_output else next_state + + def _execute_flat_bucket_projected_source_sequence( + self, + *, + state: TensorDictBase | None, + projected_boundary_source_seq: torch.Tensor, + projected_boundary_weight: torch.Tensor, + projected_boundary_bias: torch.Tensor | None, + static_tensors: dict[str, object], + population_resets: torch.Tensor | None, + materialize_final_state: bool, + grad_path: bool, + tape_policy: TapePolicy | None, + output_boundary: Literal["sequence", "terminal"], + readout_output_boundary: Literal["cells", "pooled"], + output_chunk_consumer: _RuntimeOutputChunkConsumer | None, + detach_internal_carry_after_output_chunk: bool, + k: int | torch.Tensor | None, + temporal_execution_plan: TemporalExecutionPlan, + ) -> tuple[torch.Tensor, TensorDict]: + batch_size = 
int(projected_boundary_source_seq.shape[0]) + time_steps = int(projected_boundary_source_seq.shape[1]) + device = projected_boundary_source_seq.device + dtype = projected_boundary_source_seq.dtype + running_state = self._ensure_state(state, batch=batch_size, device=device, dtype=dtype) + chunk_len = self._projected_boundary_time_chunk_len( + projected_boundary_source_seq=projected_boundary_source_seq, + projected_boundary_weight=projected_boundary_weight, + projected_boundary_bias=projected_boundary_bias, + packed_state=None, + initial_hidden=cast(torch.Tensor, running_state["cells"]), + initial_recurrent_k=cast(torch.Tensor | None, running_state.get("sender_k")), + initial_recurrent_v=cast(torch.Tensor | None, running_state.get("sender_v")), + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + ) + outputs: list[torch.Tensor] = [] + for start in range(0, time_steps, chunk_len): + end = min(time_steps, start + chunk_len) + hidden_chunk = projected_boundary_source_seq[:, start:end] + boundary_chunk = self._project_boundary_source_sequence( + hidden_chunk, + input_projection_weight=projected_boundary_weight, + input_projection_bias=projected_boundary_bias, + ) + k_chunk = _slice_sequence_k(k, start=start, end=end, batch_size=batch_size, device=device) + output_chunk, next_state = execute_temporal_bucket_sequence( + self, + hidden_seq=None, + boundary_seq=boundary_chunk, + state=running_state, + population_resets=None if population_resets is None else population_resets[:, start:end], + step_reset_flags=None, + k=k_chunk, + constant_k=self._resolve_constant_k_host(k_chunk), + batch_size=batch_size, + time_steps=end - start, + step_mode=False, + capture_active=bool(device.type == "cuda" and torch.cuda.is_current_stream_capturing()), + static_tensors=static_tensors, + grad_path=grad_path, + materialize_final_state=bool(materialize_final_state or end < time_steps), + backend_population_state_is_fresh=False, + use_fresh_backend_population_cache=False, + tape_policy=tape_policy, + output_contract="pooled_output_cells" if readout_output_boundary == "pooled" else "output_cells", + output_boundary=output_boundary if end == time_steps else "sequence", + temporal_execution_plan=temporal_execution_plan, + ) + running_state = next_state + if output_boundary == "sequence" or end == time_steps: + if output_chunk_consumer is None: + outputs.append(output_chunk) + else: + output_chunk_consumer(output_chunk, start, end) + if detach_internal_carry_after_output_chunk and end < time_steps: + running_state = _detach_tensordict(running_state) + if self._last_backend_execution is not None: + chunk_reason = getattr(self, "_last_projected_boundary_time_chunk_reason", None) + self._last_backend_execution = replace( + self._last_backend_execution, + batch_size=batch_size, + time_steps=time_steps, + workspace_aliases=self._last_backend_execution.workspace_aliases + + ( + f"projected_boundary_time_chunk_len:t={int(chunk_len)}", + "projected_boundary_sequence_executor:flat_bucket_streaming_chunked", + *((f"projected_boundary_time_chunk_reason:{chunk_reason}",) if chunk_reason else ()), + ), + ) + if output_chunk_consumer is not None: + empty_output = projected_boundary_source_seq.new_empty( + (batch_size, 0, int(self._num_output_cells), int(self.hidden_size)) + ) + return empty_output, running_state if materialize_final_state else TensorDict({}, batch_size=[]) + output_seq = ( + torch.cat(outputs, dim=1) + if outputs + else projected_boundary_source_seq.new_empty((batch_size, 0, 
int(projected_boundary_source_seq.shape[-1]))) + ) + return output_seq, running_state if materialize_final_state else TensorDict({}, batch_size=[]) + + def _resolve_k( + self, + k: int | torch.Tensor | None, + *, + batch_size: int, + time_steps: int, + device: torch.device, + ) -> tuple[torch.Tensor, int]: + if k is None: + max_steps = int(self.config.default_k) + k_rows = torch.full((batch_size,), max_steps, device=device, dtype=torch.long) + elif isinstance(k, int): + max_steps = self._clamp_k_int(k) + k_rows = torch.full((batch_size,), max_steps, device=device, dtype=torch.long) + else: + k_tensor = torch.as_tensor(k, device=device, dtype=torch.long) + if k_tensor.dim() == 1 and k_tensor.shape[0] == batch_size: + k_rows = k_tensor + elif k_tensor.dim() == 2 and k_tensor.shape == (batch_size, time_steps): + first = k_tensor[:, :1] + if not bool((k_tensor == first).all()): + raise NotImplementedError( + "Per-timestep varying k within one sequence is not yet supported by " + "the current sequence-kernel runtime" + ) + k_rows = first.reshape(batch_size) + else: + raise ValueError(f"k must be int, [B], or [B,T], got shape {tuple(k_tensor.shape)}") + + k_rows = self._clamp_k_rows(k_rows) + max_steps = int(k_rows.max().item()) if k_rows.numel() > 0 else 0 + return k_rows, max_steps + + def _resolve_constant_k_host(self, k: int | torch.Tensor | None) -> int | None: + if k is None: + return int(self.config.default_k) + if isinstance(k, int): + return self._clamp_k_int(k) + return None + + def _resolve_step_k( + self, + k: int | torch.Tensor | None, + *, + batch_size: int, + time_steps: int, + step_index: int, + device: torch.device, + ) -> tuple[torch.Tensor, int]: + if k is None: + k_rows = torch.full((batch_size,), int(self.config.default_k), device=device, dtype=torch.long) + elif isinstance(k, int): + k_rows = torch.full((batch_size,), int(k), device=device, dtype=torch.long) + else: + k_tensor = torch.as_tensor(k, device=device, dtype=torch.long) + if k_tensor.dim() == 1 and k_tensor.shape[0] == batch_size: + k_rows = k_tensor + elif k_tensor.dim() == 2 and k_tensor.shape == (batch_size, time_steps): + k_rows = k_tensor[:, step_index] + else: + raise ValueError(f"k must be int, [B], or [B,T], got shape {tuple(k_tensor.shape)}") + k_rows = self._clamp_k_rows(k_rows) + max_steps = int(k_rows.max().item()) if k_rows.numel() > 0 else 0 + return k_rows, max_steps + + def _clamp_k_int(self, k: int) -> int: + steps = max(0, int(k)) + if self.config.k_max is not None: + steps = min(int(self.config.k_max), steps) + return steps + + def _clamp_k_rows(self, k_rows: torch.Tensor) -> torch.Tensor: + k_rows = k_rows.clamp(min=0) + if self.config.k_max is not None: + k_rows = k_rows.clamp(max=int(self.config.k_max)) + return k_rows + + def _constant_step_flat( + self, + step_idx: int, + *, + batch_size: int, + time_steps: int, + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + key = ( + device.type, + -1 if device.index is None else int(device.index), + batch_size, + time_steps, + str(dtype), + int(step_idx), + ) + cached = self._constant_step_flat_cache.get(key) + if cached is None or cached.device != device: + cached = torch.full((batch_size * time_steps,), step_idx, device=device, dtype=dtype) + self._constant_step_flat_cache[key] = cached + return cached + + def _ensure_state( + self, + state: MaybeState, + *, + batch: int, + device: torch.device, + dtype: torch.dtype, + include_population_state: bool = True, + ) -> TensorDict: + if state is None or not isinstance(state, 
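A worked example of the `k` handling above (the tensor values are illustrative): a `[B, T]` tensor is accepted only when each row is constant over time, rows are clamped into `[0, config.k_max]`, and `max_steps` is the row maximum.

    import torch

    k = torch.tensor([[3, 3, 3],
                      [1, 1, 1]])   # [B, T] with a constant k per row is accepted
    # _resolve_k collapses this to k_rows == tensor([3, 1]) and max_steps == 3,
    # further clamped by config.k_max when that limit is set.
    # A row such as [3, 2, 3] would raise NotImplementedError, because per-timestep
    # varying k within one sequence is not supported by the sequence-kernel runtime.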
TensorDictBase): + if include_population_state: + return self.init_state(batch=batch, device=device, dtype=dtype) + with torch.profiler.record_function("fabric.glue.runtime_init_state"): + state_out = TensorDict({}, batch_size=[]) + state_out["cells"] = torch.zeros( + batch, + self.coords.shape[0], + self.hidden_size, + device=device, + dtype=dtype, + ) + return state_out + out = TensorDict({}, batch_size=[]) + cells = state.get("cells") + expected_cells = (batch, self.coords.shape[0], self.hidden_size) + if not torch.is_tensor(cells) or tuple(cells.shape) != expected_cells: + out["cells"] = torch.zeros(*expected_cells, device=device, dtype=dtype) + else: + out["cells"] = cells.to(device=device, dtype=dtype) + sender_k = state.get("sender_k") + expected_sender = (batch, int(self.sender_cell_idx.numel()), self.head_dim) + if torch.is_tensor(sender_k) and tuple(sender_k.shape) == expected_sender: + out["sender_k"] = sender_k.to(device=device, dtype=dtype) + sender_v = state.get("sender_v") + expected_value = (batch, int(self.sender_cell_idx.numel()), self.value_dim) + if torch.is_tensor(sender_v) and tuple(sender_v.shape) == expected_value: + out["sender_v"] = sender_v.to(device=device, dtype=dtype) + if include_population_state: + for population_name in self._population_names: + population_state = state.get(population_name) + expected = torch.Size([self._population_num_cells(population_name), batch]) + if population_state is None or population_state.batch_size != expected: + out[population_name] = self.population_modules[population_name].init_state( + batch=batch, + device=device, + dtype=dtype, + ) + elif self._population_state_matches_device_dtype( + population_name, + population_state, + device=device, + dtype=dtype, + ): + out[population_name] = population_state + else: + out[population_name] = population_state.to(device=device, dtype=dtype) + return out + + def _population_state_matches_device_dtype( + self, + population_name: str, + population_state: TensorDictBase, + *, + device: torch.device, + dtype: torch.dtype, + ) -> bool: + for state_name in self._cell_spec_for_population(population_name).state_schema.keys: + leaf = population_state[state_name] + if not torch.is_tensor(leaf) or leaf.device != device or leaf.dtype != dtype: + return False + return True + + def _forward_stream_step( + self, + *, + hidden_step: torch.Tensor | None, + state: TensorDict, + resets: torch.Tensor | None, + has_resets: bool | None, + capture_active: bool, + k_rows: torch.Tensor, + max_steps: int, + all_active: bool | None, + q: torch.Tensor, + recurrent_q: torch.Tensor, + output_q: torch.Tensor, + gathered_kv_weight: torch.Tensor, + value_to_cell_weight: torch.Tensor, + fused_recurrent_value_to_cell_weight: torch.Tensor | None, + value_to_output_weight: torch.Tensor, + cell_bias: torch.Tensor, + recurrent_cell_bias: torch.Tensor, + fused_recurrent_cell_bias: torch.Tensor, + fused_recurrent_population_input: bool, + boundary_step: torch.Tensor | None, + population_materialized: dict[str, object | None], + sender_kv_weight: torch.Tensor | None = None, + sender_input_to_kv_weight: torch.Tensor | None = None, + input_sender_input_to_kv_weight: torch.Tensor | None = None, + sender_group_input_to_kv_weight: torch.Tensor | None = None, + sender_group_size: int = 1, + recurrent_sender_input_to_kv_weight: torch.Tensor | None = None, + recurrent_group_input_to_kv_weight: torch.Tensor | None = None, + recurrent_group_size: int = 1, + input_group_input_to_kv_weight: torch.Tensor | None = None, + 
step_population_state_cache: dict[str, object] | None = None, + step_sender_cache: dict[str, torch.Tensor] | None = None, + grad_path: bool | None = None, + input_k_step: torch.Tensor | None = None, + input_v_step: torch.Tensor | None = None, + backend_static_tensors: dict[str, object] | None = None, + materialize_population_next_state: bool = True, + materialize_cells_state: bool = True, + ) -> tuple[torch.Tensor, TensorDict]: + current_state = state + if resets is not None and (capture_active or has_resets is True): + materialized_population_state = [] + for population_name in self._population_names: + population_state_value = state.get(population_name) + materialized_population_state.append( + isinstance(population_state_value, TensorDictBase) + and any(torch.is_tensor(population_state_value.get(key)) for key in population_state_value.keys()) + ) + cache_has_materialized_state = any(materialized_population_state) + if step_population_state_cache is not None: + self._reset_stream_step_population_cache(step_population_state_cache, resets) + if step_sender_cache is not None: + self._reset_stream_step_sender_cache(step_sender_cache, resets) + if step_population_state_cache is not None and not cache_has_materialized_state: + reset_mask = torch.as_tensor(resets, device=self.coords.device, dtype=torch.bool).view(-1) + current_state = TensorDict(current_state.to_dict(), batch_size=[]) + current_state["cells"] = reset_backend_tensor_rows(current_state["cells"], reset_mask) + sender_k = current_state.get("sender_k") + if torch.is_tensor(sender_k): + current_state["sender_k"] = reset_backend_tensor_rows(sender_k, reset_mask) + sender_v = current_state.get("sender_v") + if torch.is_tensor(sender_v): + current_state["sender_v"] = reset_backend_tensor_rows(sender_v, reset_mask) + else: + reset_state = self.reset_state(current_state, resets) + assert isinstance(reset_state, TensorDictBase) + current_state = TensorDict(reset_state.to_dict(), batch_size=[]) + + cells_prev = current_state["cells"] + population_state = current_state + population_resets = resets.view(-1, 1) if resets is not None else None + + if max_steps <= 1: + backend_step_static_tensors = backend_static_tensors or { + "recurrent_q": recurrent_q, + "output_q": output_q, + "recurrent_sender_input_to_kv_weight": recurrent_sender_input_to_kv_weight, + "recurrent_group_input_to_kv_weight": recurrent_group_input_to_kv_weight, + "value_to_cell_weight": value_to_cell_weight, + "fused_recurrent_value_to_cell_weight": fused_recurrent_value_to_cell_weight, + "fused_recurrent_cell_bias": fused_recurrent_cell_bias, + "fused_recurrent_population_input": fused_recurrent_population_input, + "recurrent_cell_bias": recurrent_cell_bias, + "value_to_output_weight": value_to_output_weight, + "population_materialized": population_materialized, + } + return self._forward_stream_step_k1( + cells_prev=cells_prev, + population_state=population_state, + population_resets=population_resets, + k_rows=k_rows, + all_active=all_active, + recurrent_q=recurrent_q, + output_q=output_q, + sender_input_to_kv_weight=sender_input_to_kv_weight, + input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, + sender_group_input_to_kv_weight=sender_group_input_to_kv_weight, + sender_group_size=self._sender_kv_group_size, + recurrent_sender_input_to_kv_weight=recurrent_sender_input_to_kv_weight, + recurrent_group_input_to_kv_weight=recurrent_group_input_to_kv_weight, + recurrent_group_size=self._recurrent_sender_kv_group_size, + value_to_cell_weight=value_to_cell_weight, + 
fused_recurrent_value_to_cell_weight=fused_recurrent_value_to_cell_weight, + value_to_output_weight=value_to_output_weight, + recurrent_cell_bias=recurrent_cell_bias, + fused_recurrent_cell_bias=fused_recurrent_cell_bias, + fused_recurrent_population_input=fused_recurrent_population_input, + boundary_step=boundary_step, + input_group_input_to_kv_weight=input_group_input_to_kv_weight, + population_materialized=population_materialized, + step_population_state_cache=step_population_state_cache, + step_sender_cache=step_sender_cache, + grad_path=torch.is_grad_enabled() if grad_path is None else grad_path, + input_k_step=input_k_step, + input_v_step=input_v_step, + backend_static_tensors=backend_step_static_tensors, + materialize_population_next_state=materialize_population_next_state, + materialize_cells_state=materialize_cells_state, + ) + if boundary_step is not None and hidden_step is None: + return self._forward_stream_step_boundary_multistep( + cells_prev=cells_prev, + population_state=population_state, + population_resets=population_resets, + k_rows=k_rows, + max_steps=max_steps, + recurrent_q=recurrent_q, + output_q=output_q, + sender_input_to_kv_weight=sender_input_to_kv_weight, + sender_group_input_to_kv_weight=sender_group_input_to_kv_weight, + sender_group_size=self._sender_kv_group_size, + recurrent_sender_input_to_kv_weight=recurrent_sender_input_to_kv_weight, + recurrent_group_input_to_kv_weight=recurrent_group_input_to_kv_weight, + recurrent_group_size=self._recurrent_sender_kv_group_size, + value_to_cell_weight=value_to_cell_weight, + value_to_output_weight=value_to_output_weight, + recurrent_cell_bias=recurrent_cell_bias, + boundary_step=boundary_step, + input_group_input_to_kv_weight=input_group_input_to_kv_weight, + population_materialized=population_materialized, + step_population_state_cache=step_population_state_cache, + ) + + if boundary_step is not None: + cells_prev = cells_prev.clone() + if self._partitioned_layout: + cells_prev[:, self._input_slice, :] = boundary_step + else: + cells_prev[:, self.input_cell_idx, :] = boundary_step + + y_prev = cells_prev.unsqueeze(1) + use_packed_loop_cache = ( + step_population_state_cache is not None + and self._full_recurrent_population_name is not None + and self._full_recurrent_population_name in step_population_state_cache + ) + boundary_step_seq = boundary_step.unsqueeze(1) if boundary_step is not None else None + zero_output_step = None + if use_packed_loop_cache and boundary_step_seq is not None and self._partitioned_layout: + zero_output_step = cells_prev.new_zeros(cells_prev.shape[0], 1, self._num_output_cells, self.hidden_size) + + for step_idx in range(max_steps): + inner_population_resets = population_resets if step_idx == 0 else None + z_prev = self.public_proj(y_prev) + msg = self._compute_messages( + z_prev, + q=q, + gathered_kv_weight=gathered_kv_weight, + step_idx=step_idx + 1, + ) + if hidden_step is not None and (self.config.inject_every_step or step_idx == 0): + msg = self._inject_hidden_inputs(msg, hidden_step.unsqueeze(1)) + population_input = self.msg_to_cell(msg) + cell_bias + if use_packed_loop_cache: + population_y = self._run_population_updates_step_cached( + population_input, + resets=inner_population_resets, + population_materialized=population_materialized, + step_population_state_cache=step_population_state_cache, + ) + if zero_output_step is not None and boundary_step_seq is not None: + y_next = torch.cat((boundary_step_seq, population_y.unsqueeze(1), zero_output_step), dim=2) + else: + y_next = 
population_input.new_zeros(population_input.shape) + population_name = self._full_recurrent_population_name + assert population_name is not None + population_idx = self._population_indices(population_name) + y_next[:, 0, population_idx, :] = population_y.to(dtype=y_next.dtype) + else: + y_next, next_population_state = self._run_population_updates( + population_input, + population_state, + resets=inner_population_resets, + batch_size=y_prev.shape[0], + time_steps=1, + population_materialized=population_materialized, + ) + if boundary_step_seq is not None and not (use_packed_loop_cache and zero_output_step is not None): + y_next[:, :, self.input_cell_idx, :] = boundary_step_seq + active_rows = step_idx < k_rows + y_prev = torch.where(active_rows.view(-1, 1, 1, 1), y_next, y_prev) + if not use_packed_loop_cache: + population_state = self._blend_population_states(population_state, next_population_state, active_rows) + + final_z = self.public_proj(y_prev) + final_msg = self._compute_messages( + final_z, + q=q, + gathered_kv_weight=gathered_kv_weight, + step_idx=k_rows, + ) + output_cells = self._project_output_cells(final_msg[:, :, self.output_cell_idx, :]).to(dtype=y_prev.dtype) + if zero_output_step is not None and boundary_step_seq is not None: + y_out = torch.cat((boundary_step_seq, y_prev[:, :, self._recurrent_slice, :], output_cells), dim=2) + else: + y_out = y_prev.clone() + y_out[:, :, self.output_cell_idx, :] = output_cells + next_state = TensorDict({}, batch_size=[]) + next_state["cells"] = y_out.squeeze(1) + for population_name in self._population_names: + next_state[population_name] = population_state[population_name] + return y_out.squeeze(1), next_state + + def _forward_stream_step_k1( + self, + *, + cells_prev: torch.Tensor, + population_state: TensorDict, + population_resets: torch.Tensor | None, + k_rows: torch.Tensor, + all_active: bool | None, + recurrent_q: torch.Tensor, + output_q: torch.Tensor, + value_to_cell_weight: torch.Tensor, + fused_recurrent_value_to_cell_weight: torch.Tensor | None = None, + value_to_output_weight: torch.Tensor, + recurrent_cell_bias: torch.Tensor, + fused_recurrent_cell_bias: torch.Tensor | None = None, + fused_recurrent_population_input: bool = False, + boundary_step: torch.Tensor | None, + population_materialized: dict[str, object | None], + sender_input_to_kv_weight: torch.Tensor | None = None, + input_sender_input_to_kv_weight: torch.Tensor | None = None, + sender_group_input_to_kv_weight: torch.Tensor | None = None, + sender_group_size: int = 1, + recurrent_sender_input_to_kv_weight: torch.Tensor | None = None, + recurrent_group_input_to_kv_weight: torch.Tensor | None = None, + recurrent_group_size: int = 1, + input_group_input_to_kv_weight: torch.Tensor | None = None, + step_population_state_cache: dict[str, object] | None = None, + step_sender_cache: dict[str, torch.Tensor] | None = None, + grad_path: bool | None = None, + input_k_step: torch.Tensor | None = None, + input_v_step: torch.Tensor | None = None, + backend_static_tensors: dict[str, object] | None = None, + materialize_population_next_state: bool = True, + materialize_cells_state: bool = True, + ) -> tuple[torch.Tensor, TensorDict]: + effective_grad_path = torch.is_grad_enabled() if grad_path is None else grad_path + if fused_recurrent_cell_bias is None: + fused_recurrent_cell_bias = recurrent_cell_bias + inplace_state_update = not effective_grad_path + grad_partitioned_sender_cache = bool( + not inplace_state_update + and step_sender_cache is not None + and 
self._partitioned_layout + and self._local_message_step_enabled + and boundary_step is not None + ) + if inplace_state_update: + if not torch.is_tensor(population_state.get("sender_k")): + population_state["sender_k"] = cells_prev.new_zeros( + cells_prev.shape[0], + int(self.sender_cell_idx.numel()), + self.head_dim, + ) + if not torch.is_tensor(population_state.get("sender_v")): + population_state["sender_v"] = cells_prev.new_zeros( + cells_prev.shape[0], + int(self.sender_cell_idx.numel()), + self.value_dim, + ) + current_sender_k = population_state.get("sender_k") if inplace_state_update else None + current_sender_v = population_state.get("sender_v") if inplace_state_update else None + if step_sender_cache is not None and not grad_partitioned_sender_cache: + current_sender_k = step_sender_cache.get("sender_k") + current_sender_v = step_sender_cache.get("sender_v") + use_sender_cache = ( + torch.is_tensor(current_sender_k) + and torch.is_tensor(current_sender_v) + and tuple(current_sender_k.shape) == (cells_prev.shape[0], int(self.sender_cell_idx.numel()), self.head_dim) + and tuple(current_sender_v.shape) + == (cells_prev.shape[0], int(self.sender_cell_idx.numel()), self.value_dim) + ) + use_partitioned_sender_banks = bool( + self._partitioned_layout + and self._local_message_step_enabled + and (inplace_state_update or grad_partitioned_sender_cache) + ) + input_k = None + input_v = None + recurrent_k_prev = None + recurrent_v_prev = None + if self._partitioned_layout: + if boundary_step is None: + input_prev = cells_prev[:, self._input_slice, :] + sender_cells_prev = cells_prev[:, : self._num_input_cells + self._num_recurrent_cells, :] + else: + input_prev = boundary_step + sender_cells_prev = None + recurrent_slice = self._recurrent_slice + output_slice = self._output_slice + recurrent_prev = cells_prev[:, recurrent_slice, :] + else: + if boundary_step is None: + sender_cells_prev = cells_prev.index_select(1, self.sender_cell_idx) + else: + if not inplace_state_update: + cells_prev = cells_prev.clone() + cells_prev[:, self.input_cell_idx, :] = boundary_step + sender_cells_prev = cells_prev.index_select(1, self.sender_cell_idx) + recurrent_slice = None + output_slice = None + recurrent_prev = cells_prev[:, self.recurrent_cell_idx, :] + if not self._partitioned_layout: + input_prev = None + sender_cells_prev = cells_prev.index_select(1, self.sender_cell_idx) + if self._partitioned_layout and boundary_step is not None: + assert input_prev is not None + if input_sender_input_to_kv_weight is None and sender_input_to_kv_weight is not None: + input_sender_input_to_kv_weight = sender_input_to_kv_weight.index_select(0, self.input_sender_idx) + if torch.is_tensor(input_k_step) and torch.is_tensor(input_v_step): + input_k, input_v = input_k_step, input_v_step + else: + input_k, input_v = self._project_sender_kv_from_cells_step( + input_prev, + sender_input_to_kv_weight=input_sender_input_to_kv_weight, + grouped_sender_input_to_kv_weight=input_group_input_to_kv_weight, + sender_group_size=self._input_sender_kv_group_size, + ) + if grad_partitioned_sender_cache: + recurrent_k_prev = step_sender_cache.get("recurrent_k") + recurrent_v_prev = step_sender_cache.get("recurrent_v") + state_sender_k = population_state.get("sender_k") + state_sender_v = population_state.get("sender_v") + if ( + not ( + torch.is_tensor(recurrent_k_prev) + and tuple(recurrent_k_prev.shape) + == (cells_prev.shape[0], self._num_recurrent_cells, self.head_dim) + ) + and torch.is_tensor(state_sender_k) + and 
tuple(state_sender_k.shape) + == (cells_prev.shape[0], int(self.sender_cell_idx.numel()), self.head_dim) + ): + recurrent_k_prev = state_sender_k[:, self._recurrent_slice, :] + if ( + not ( + torch.is_tensor(recurrent_v_prev) + and tuple(recurrent_v_prev.shape) + == (cells_prev.shape[0], self._num_recurrent_cells, self.value_dim) + ) + and torch.is_tensor(state_sender_v) + and tuple(state_sender_v.shape) + == (cells_prev.shape[0], int(self.sender_cell_idx.numel()), self.value_dim) + ): + recurrent_v_prev = state_sender_v[:, self._recurrent_slice, :] + if not ( + torch.is_tensor(recurrent_k_prev) + and torch.is_tensor(recurrent_v_prev) + and tuple(recurrent_k_prev.shape) == (cells_prev.shape[0], self._num_recurrent_cells, self.head_dim) + and tuple(recurrent_v_prev.shape) + == (cells_prev.shape[0], self._num_recurrent_cells, self.value_dim) + ): + recurrent_k_prev, recurrent_v_prev = self._project_sender_kv_from_cells_step( + recurrent_prev, + sender_input_to_kv_weight=recurrent_sender_input_to_kv_weight, + grouped_sender_input_to_kv_weight=recurrent_group_input_to_kv_weight, + sender_group_size=recurrent_group_size, + contiguous_kv=use_partitioned_sender_banks, + ) + k_all = None + v_all = None + elif use_sender_cache: + assert current_sender_k is not None and current_sender_v is not None + k_all = current_sender_k + v_all = current_sender_v + k_all[:, self._input_slice, :] = input_k + v_all[:, self._input_slice, :] = input_v + recurrent_k_prev = k_all[:, self._recurrent_slice, :] + recurrent_v_prev = v_all[:, self._recurrent_slice, :] + else: + recurrent_k_prev, recurrent_v_prev = self._project_sender_kv_from_cells_step( + recurrent_prev, + sender_input_to_kv_weight=recurrent_sender_input_to_kv_weight, + grouped_sender_input_to_kv_weight=recurrent_group_input_to_kv_weight, + sender_group_size=recurrent_group_size, + contiguous_kv=use_partitioned_sender_banks, + ) + if use_partitioned_sender_banks: + k_all = None + v_all = None + else: + k_all = input_k.new_empty(cells_prev.shape[0], self.sender_cell_idx.numel(), self.head_dim) + v_all = input_v.new_empty(cells_prev.shape[0], self.sender_cell_idx.numel(), self.value_dim) + k_all[:, self._input_slice, :] = input_k + k_all[:, self._recurrent_slice, :] = recurrent_k_prev + v_all[:, self._input_slice, :] = input_v + v_all[:, self._recurrent_slice, :] = recurrent_v_prev + else: + if use_sender_cache: + assert current_sender_k is not None and current_sender_v is not None + k_all = current_sender_k + v_all = current_sender_v + if boundary_step is not None: + assert input_prev is not None + if torch.is_tensor(input_k_step) and torch.is_tensor(input_v_step): + input_k, input_v = input_k_step, input_v_step + else: + input_k, input_v = self._project_sender_kv_from_cells_step( + input_prev, + sender_input_to_kv_weight=input_sender_input_to_kv_weight, + grouped_sender_input_to_kv_weight=input_group_input_to_kv_weight, + sender_group_size=self._input_sender_kv_group_size, + ) + k_all[:, self.input_sender_idx, :] = input_k + v_all[:, self.input_sender_idx, :] = input_v + else: + assert sender_cells_prev is not None + k_all, v_all = self._project_sender_kv_from_cells_step( + sender_cells_prev, + sender_input_to_kv_weight=sender_input_to_kv_weight, + grouped_sender_input_to_kv_weight=sender_group_input_to_kv_weight, + sender_group_size=sender_group_size, + ) + if self._partitioned_layout: + input_k = k_all[:, : self._num_input_cells, :] + input_v = v_all[:, : self._num_input_cells, :] + recurrent_k_prev = None + recurrent_v_prev = None + if 
use_partitioned_sender_banks and self._partitioned_layout: + if input_k is None or input_v is None: + if current_sender_k is not None and current_sender_v is not None: + input_k = current_sender_k[:, self._input_slice, :] + input_v = current_sender_v[:, self._input_slice, :] + else: + assert input_prev is not None + if torch.is_tensor(input_k_step) and torch.is_tensor(input_v_step): + input_k, input_v = input_k_step, input_v_step + else: + input_k, input_v = self._project_sender_kv_from_cells_step( + input_prev, + sender_input_to_kv_weight=input_sender_input_to_kv_weight, + grouped_sender_input_to_kv_weight=input_group_input_to_kv_weight, + sender_group_size=self._input_sender_kv_group_size, + ) + if recurrent_k_prev is None or recurrent_v_prev is None: + if current_sender_k is not None and current_sender_v is not None: + recurrent_k_prev = current_sender_k[:, self._recurrent_slice, :] + recurrent_v_prev = current_sender_v[:, self._recurrent_slice, :] + else: + recurrent_k_prev, recurrent_v_prev = self._project_sender_kv_from_cells_step( + recurrent_prev, + sender_input_to_kv_weight=recurrent_sender_input_to_kv_weight, + grouped_sender_input_to_kv_weight=recurrent_group_input_to_kv_weight, + sender_group_size=recurrent_group_size, + contiguous_kv=use_partitioned_sender_banks, + ) + if all_active is False: + recurrent_mid = recurrent_prev + blended_population_state = population_state + if use_partitioned_sender_banks: + final_k = None + final_v = None + recurrent_k = recurrent_k_prev + recurrent_v = recurrent_v_prev + else: + final_k = k_all + final_v = v_all + else: + use_backend_order_transition_buckets = bool( + self._active_backend_name != "pytorch" + and backend_static_tensors is not None + and self.population_backend_recurrent_order.numel() == self._num_recurrent_cells + and torch.is_tensor(backend_static_tensors.get("recurrent_q_backend_order")) + ) + recurrent_q_for_message = ( + cast(torch.Tensor, backend_static_tensors["recurrent_q_backend_order"]) + if use_backend_order_transition_buckets + else recurrent_q + ) + recurrent_neighbor_idx_for_message = ( + self.recurrent_neighbor_idx_backend_order + if use_backend_order_transition_buckets + else self.recurrent_neighbor_idx + ) + recurrent_neighbor_valid_for_message = ( + self.recurrent_neighbor_valid_backend_order + if use_backend_order_transition_buckets + else self.recurrent_neighbor_valid + ) + recurrent_edge_distance_for_message = ( + self.recurrent_edge_distance_backend_order + if use_backend_order_transition_buckets + else self.recurrent_edge_distance + ) + recurrent_edge_delay_for_message = ( + self.recurrent_edge_delay_backend_order + if use_backend_order_transition_buckets + else self.recurrent_edge_delay + ) + recurrent_local_sender_idx_for_message = ( + self.recurrent_local_sender_idx_backend_order + if use_backend_order_transition_buckets + else self.recurrent_local_sender_idx + ) + recurrent_local_receiver_idx_by_sender_for_message = ( + self.recurrent_local_receiver_idx_by_sender_backend_order + if use_backend_order_transition_buckets + else self.recurrent_local_receiver_idx_by_sender + ) + if use_partitioned_sender_banks: + assert input_k is not None and input_v is not None + assert recurrent_k_prev is not None and recurrent_v_prev is not None + recurrent_msg = self._compute_messages_step_subset_partitioned_raw( + input_k, + input_v, + recurrent_k_prev, + recurrent_v_prev, + q_subset=recurrent_q_for_message, + neighbor_idx=recurrent_neighbor_idx_for_message, + neighbor_valid=recurrent_neighbor_valid_for_message, + 
edge_distance=recurrent_edge_distance_for_message, + edge_delay=recurrent_edge_delay_for_message, + use_delay=self._has_edge_delay, + step_idx=1, + local_sender_idx=recurrent_local_sender_idx_for_message, + local_receiver_idx_by_sender=recurrent_local_receiver_idx_by_sender_for_message, + owner_tag="recurrent", + ) + else: + recurrent_msg = self._compute_messages_step_subset_raw( + k_all, + v_all, + q_subset=recurrent_q_for_message, + neighbor_idx=recurrent_neighbor_idx_for_message, + neighbor_valid=recurrent_neighbor_valid_for_message, + edge_distance=recurrent_edge_distance_for_message, + edge_delay=recurrent_edge_delay_for_message, + use_delay=self._has_edge_delay, + step_idx=1, + local_sender_idx=recurrent_local_sender_idx_for_message, + local_receiver_idx_by_sender=recurrent_local_receiver_idx_by_sender_for_message, + owner_tag="recurrent", + ) + use_cuda_flat_bucket_transition_step = bool( + self._active_backend_name != "pytorch" and recurrent_msg.is_cuda and backend_static_tensors is not None + ) + if use_cuda_flat_bucket_transition_step: + if use_backend_order_transition_buckets: + recurrent_next_backend_order, next_population_state = ( + self._run_backend_order_transition_buckets_step( + recurrent_msg, + population_state, + resets=population_resets, + batch_size=cells_prev.shape[0], + static_tensors=backend_static_tensors, + step_population_state_cache=step_population_state_cache if all_active is True else None, + materialize_next_state=materialize_population_next_state, + ) + ) + recurrent_next = recurrent_next_backend_order.index_select( + 1, + self.population_backend_recurrent_inverse_order, + ) + else: + recurrent_next, next_population_state = self._run_transition_buckets_step( + recurrent_msg, + population_state, + resets=population_resets, + batch_size=cells_prev.shape[0], + static_tensors=backend_static_tensors, + step_population_state_cache=step_population_state_cache if all_active is True else None, + materialize_next_state=materialize_population_next_state, + ) + else: + recurrent_input = self._project_recurrent_message_to_cell_step( + recurrent_msg, + value_to_cell_weight=value_to_cell_weight, + recurrent_cell_bias=recurrent_cell_bias, + fused_recurrent_value_to_cell_weight=fused_recurrent_value_to_cell_weight, + fused_recurrent_cell_bias=fused_recurrent_cell_bias, + fused_recurrent_population_input=fused_recurrent_population_input, + ) + recurrent_next, next_population_state = self._run_population_updates_recurrent_step( + recurrent_input, + population_state, + resets=population_resets, + batch_size=cells_prev.shape[0], + population_materialized=population_materialized, + step_population_state_cache=step_population_state_cache if all_active is True else None, + population_input_already_projected=fused_recurrent_population_input, + ) + if all_active is True: + recurrent_mid = recurrent_next + blended_population_state = next_population_state + else: + active_rows = k_rows > 0 + recurrent_mid = torch.where(active_rows.view(-1, 1, 1), recurrent_next, recurrent_prev) + blended_population_state = self._blend_population_states( + population_state, + next_population_state, + active_rows, + ) + + recurrent_k, recurrent_v = self._project_sender_kv_from_cells_step( + recurrent_mid, + sender_input_to_kv_weight=recurrent_sender_input_to_kv_weight, + grouped_sender_input_to_kv_weight=recurrent_group_input_to_kv_weight, + sender_group_size=recurrent_group_size, + contiguous_kv=use_partitioned_sender_banks, + ) + if use_partitioned_sender_banks: + final_k = None + final_v = None + if 
use_sender_cache: + assert current_sender_k is not None and current_sender_v is not None + current_sender_k[:, self._recurrent_slice, :] = recurrent_k + current_sender_v[:, self._recurrent_slice, :] = recurrent_v + elif self._partitioned_layout: + final_k = k_all if inplace_state_update else k_all.clone() + final_v = v_all if inplace_state_update else v_all.clone() + final_k[:, self._recurrent_slice, :] = recurrent_k + final_v[:, self._recurrent_slice, :] = recurrent_v + else: + final_k = k_all if inplace_state_update else k_all.clone() + final_v = v_all if inplace_state_update else v_all.clone() + final_k[:, self.recurrent_sender_idx, :] = recurrent_k + final_v[:, self.recurrent_sender_idx, :] = recurrent_v + output_step_idx: int | torch.Tensor = 1 if all_active is True else k_rows + if use_partitioned_sender_banks: + assert input_k is not None and input_v is not None + assert recurrent_k is not None and recurrent_v is not None + output_msg = self._compute_messages_step_subset_partitioned_raw( + input_k, + input_v, + recurrent_k, + recurrent_v, + q_subset=output_q, + neighbor_idx=self.output_neighbor_idx, + neighbor_valid=self.output_neighbor_valid, + edge_distance=self.output_edge_distance, + edge_delay=self.output_edge_delay, + use_delay=self._has_edge_delay, + step_idx=output_step_idx, + local_sender_idx=self.output_local_sender_idx, + local_receiver_idx_by_sender=self.output_local_receiver_idx_by_sender, + owner_tag="readout", + ) + else: + output_msg = self._compute_messages_step_subset_raw( + final_k, + final_v, + q_subset=output_q, + neighbor_idx=self.output_neighbor_idx, + neighbor_valid=self.output_neighbor_valid, + edge_distance=self.output_edge_distance, + edge_delay=self.output_edge_delay, + use_delay=self._has_edge_delay, + step_idx=output_step_idx, + local_sender_idx=self.output_local_sender_idx, + local_receiver_idx_by_sender=self.output_local_receiver_idx_by_sender, + owner_tag="readout", + ) + output_cells = self._project_output_cells_step_raw( + output_msg, + value_to_output_weight=value_to_output_weight, + ).to(dtype=cells_prev.dtype) + if not materialize_cells_state and step_sender_cache is not None and boundary_step is not None: + cells_out = cells_prev.clone() + if output_slice is not None: + assert input_prev is not None + cells_out[:, self._input_slice, :] = input_prev + cells_out[:, self._output_slice, :] = output_cells + else: + cells_out[:, self.output_cell_idx, :] = output_cells + elif inplace_state_update: + cells_out = cells_prev + if output_slice is not None: + assert input_prev is not None + cells_out[:, self._input_slice, :] = input_prev + cells_out[:, self._recurrent_slice, :] = recurrent_mid + cells_out[:, self._output_slice, :] = output_cells + else: + cells_out[:, self.recurrent_cell_idx, :] = recurrent_mid + cells_out[:, self.output_cell_idx, :] = output_cells + elif output_slice is not None: + assert input_prev is not None + cells_out = torch.cat((input_prev, recurrent_mid, output_cells), dim=1) + else: + cells_out = cells_prev.clone() + cells_out[:, self.recurrent_cell_idx, :] = recurrent_mid + cells_out[:, self.output_cell_idx, :] = output_cells + next_state = TensorDict({}, batch_size=[]) + next_state["cells"] = cells_out + if inplace_state_update: + next_state["sender_k"] = current_sender_k if use_partitioned_sender_banks else final_k + next_state["sender_v"] = current_sender_v if use_partitioned_sender_banks else final_v + if grad_partitioned_sender_cache and recurrent_k is not None and recurrent_v is not None: + step_sender_cache["recurrent_k"] 
= recurrent_k + step_sender_cache["recurrent_v"] = recurrent_v + if step_sender_cache is not None and final_k is not None and final_v is not None: + step_sender_cache["sender_k"] = final_k + step_sender_cache["sender_v"] = final_v + if materialize_cells_state or step_population_state_cache is None: + for population_name in self._population_names: + next_state[population_name] = blended_population_state[population_name] + return cells_out, next_state + + def _forward_stream_step_boundary_multistep( + self, + *, + cells_prev: torch.Tensor, + population_state: TensorDict, + population_resets: torch.Tensor | None, + k_rows: torch.Tensor, + max_steps: int, + recurrent_q: torch.Tensor, + output_q: torch.Tensor, + value_to_cell_weight: torch.Tensor, + value_to_output_weight: torch.Tensor, + recurrent_cell_bias: torch.Tensor, + boundary_step: torch.Tensor, + population_materialized: dict[str, object | None], + sender_input_to_kv_weight: torch.Tensor | None = None, + sender_group_input_to_kv_weight: torch.Tensor | None = None, + sender_group_size: int = 1, + recurrent_sender_input_to_kv_weight: torch.Tensor | None = None, + recurrent_group_input_to_kv_weight: torch.Tensor | None = None, + recurrent_group_size: int = 1, + input_group_input_to_kv_weight: torch.Tensor | None = None, + step_population_state_cache: dict[str, object] | None = None, + ) -> tuple[torch.Tensor, TensorDict]: + batch_size = cells_prev.shape[0] + recurrent_mid = cells_prev[:, self.recurrent_cell_idx, :] + input_k, input_v = self._project_sender_kv_from_cells_step( + boundary_step, + sender_input_to_kv_weight=( + sender_input_to_kv_weight.index_select(0, self.input_sender_idx) + if sender_input_to_kv_weight is not None + else None + ), + grouped_sender_input_to_kv_weight=input_group_input_to_kv_weight, + sender_group_size=self._input_sender_kv_group_size, + ) + use_packed_cache = step_population_state_cache is not None + running_population_state = population_state + + for step_idx in range(max_steps): + inner_population_resets = population_resets if step_idx == 0 else None + recurrent_k, recurrent_v = self._project_sender_kv_from_cells_step( + recurrent_mid, + sender_input_to_kv_weight=recurrent_sender_input_to_kv_weight, + grouped_sender_input_to_kv_weight=recurrent_group_input_to_kv_weight, + sender_group_size=recurrent_group_size, + ) + if self._partitioned_layout: + k_all = torch.cat((input_k, recurrent_k), dim=1) + v_all = torch.cat((input_v, recurrent_v), dim=1) + else: + k_all = input_k.new_zeros(batch_size, self.sender_cell_idx.numel(), self.head_dim) + v_all = input_v.new_zeros(batch_size, self.sender_cell_idx.numel(), self.value_dim) + k_all[:, self.input_sender_idx, :] = input_k + v_all[:, self.input_sender_idx, :] = input_v + k_all[:, self.recurrent_sender_idx, :] = recurrent_k + v_all[:, self.recurrent_sender_idx, :] = recurrent_v + recurrent_msg = self._compute_messages_step_subset_raw( + k_all, + v_all, + q_subset=recurrent_q, + neighbor_idx=self.recurrent_neighbor_idx, + neighbor_valid=self.recurrent_neighbor_valid, + edge_distance=self.recurrent_edge_distance, + edge_delay=self.recurrent_edge_delay, + use_delay=self._has_edge_delay, + step_idx=step_idx + 1, + local_sender_idx=self.recurrent_local_sender_idx, + local_receiver_idx_by_sender=self.recurrent_local_receiver_idx_by_sender, + ) + recurrent_input = self._project_recurrent_message_to_cell_step( + recurrent_msg, + value_to_cell_weight=value_to_cell_weight, + recurrent_cell_bias=recurrent_cell_bias, + ) + recurrent_next, next_population_state = 
self._run_population_updates_recurrent_step( + recurrent_input, + running_population_state, + resets=inner_population_resets, + batch_size=batch_size, + population_materialized=population_materialized, + step_population_state_cache=step_population_state_cache, + ) + active_rows = step_idx < k_rows + recurrent_mid = torch.where(active_rows.view(-1, 1, 1), recurrent_next, recurrent_mid) + if not use_packed_cache: + running_population_state = self._blend_population_states( + running_population_state, + next_population_state, + active_rows, + ) + + recurrent_k, recurrent_v = self._project_sender_kv_from_cells_step( + recurrent_mid, + sender_input_to_kv_weight=recurrent_sender_input_to_kv_weight, + grouped_sender_input_to_kv_weight=recurrent_group_input_to_kv_weight, + sender_group_size=recurrent_group_size, + ) + if self._partitioned_layout: + final_k = torch.cat((input_k, recurrent_k), dim=1) + final_v = torch.cat((input_v, recurrent_v), dim=1) + output_cells = self._project_output_cells_step_raw( + self._compute_messages_step_subset_raw( + final_k, + final_v, + q_subset=output_q, + neighbor_idx=self.output_neighbor_idx, + neighbor_valid=self.output_neighbor_valid, + edge_distance=self.output_edge_distance, + edge_delay=self.output_edge_delay, + use_delay=self._has_edge_delay, + step_idx=k_rows, + local_sender_idx=self.output_local_sender_idx, + local_receiver_idx_by_sender=self.output_local_receiver_idx_by_sender, + ), + value_to_output_weight=value_to_output_weight, + ).to(dtype=cells_prev.dtype) + cells_out = torch.cat((boundary_step, recurrent_mid, output_cells), dim=1) + else: + final_k = input_k.new_zeros(batch_size, self.sender_cell_idx.numel(), self.head_dim) + final_v = input_v.new_zeros(batch_size, self.sender_cell_idx.numel(), self.value_dim) + final_k[:, self.input_sender_idx, :] = input_k + final_v[:, self.input_sender_idx, :] = input_v + final_k[:, self.recurrent_sender_idx, :] = recurrent_k + final_v[:, self.recurrent_sender_idx, :] = recurrent_v + output_cells = self._project_output_cells_step_raw( + self._compute_messages_step_subset_raw( + final_k, + final_v, + q_subset=output_q, + neighbor_idx=self.output_neighbor_idx, + neighbor_valid=self.output_neighbor_valid, + edge_distance=self.output_edge_distance, + edge_delay=self.output_edge_delay, + use_delay=self._has_edge_delay, + step_idx=k_rows, + local_sender_idx=self.output_local_sender_idx, + local_receiver_idx_by_sender=self.output_local_receiver_idx_by_sender, + ), + value_to_output_weight=value_to_output_weight, + ).to(dtype=cells_prev.dtype) + cells_out = cells_prev.clone() + cells_out[:, self.input_cell_idx, :] = boundary_step + cells_out[:, self.recurrent_cell_idx, :] = recurrent_mid + cells_out[:, self.output_cell_idx, :] = output_cells + next_state = TensorDict({}, batch_size=[]) + next_state["cells"] = cells_out + for population_name in self._population_names: + next_state[population_name] = running_population_state[population_name] + return cells_out, next_state + + def _inject_hidden_inputs(self, msg: torch.Tensor, hidden_seq: torch.Tensor) -> torch.Tensor: + if self.input_cell_idx.numel() == 0: + return msg + out = msg.clone() + projected = self.input_proj(hidden_seq).unsqueeze(2) + out[:, :, self.input_cell_idx, :] = out[:, :, self.input_cell_idx, :] + projected + return out + + def _inject_boundary_inputs(self, msg: torch.Tensor, boundary_seq: torch.Tensor) -> torch.Tensor: + raise NotImplementedError("boundary_input now represents direct boundary cell states") + + def _run_population_updates( + self, + 
population_input: torch.Tensor, + state: TensorDict, + *, + resets: torch.Tensor | None, + batch_size: int, + time_steps: int, + population_materialized: dict[str, object | None], + ) -> tuple[torch.Tensor, TensorDict]: + y_next = torch.zeros_like(population_input) + next_state = TensorDict({}, batch_size=[]) + active_populations = [name for name in self._population_names if self._population_num_cells(name) > 0] + for name in active_populations: + population_y, population_state = self._run_population( + name, + population_input, + state.get(name), + resets, + batch_size, + time_steps, + population_materialized, + ) + idx = self._population_indices(name) + y_next[:, :, idx, :] = population_y.to(dtype=y_next.dtype) + next_state[name] = population_state + return y_next, next_state + + def _run_population_updates_recurrent_step( + self, + recurrent_input: torch.Tensor, + state: TensorDict, + *, + resets: torch.Tensor | None, + batch_size: int, + population_materialized: dict[str, object | None], + step_population_state_cache: dict[str, object] | None = None, + population_input_already_projected: bool = False, + ) -> tuple[torch.Tensor, TensorDict]: + num_recurrent = int(self.recurrent_cell_idx.numel()) + if num_recurrent == 0: + return recurrent_input.new_empty(batch_size, 0, self.hidden_size), TensorDict({}, batch_size=[]) + if self._full_recurrent_population_name is not None: + population_name = self._full_recurrent_population_name + population_y, population_state = self._run_recurrent_population_step( + population_name, + recurrent_input, + state.get(population_name), + resets=resets, + population_materialized=population_materialized, + step_population_state_cache=step_population_state_cache, + population_input_already_projected=population_input_already_projected, + ) + next_state = TensorDict({}, batch_size=[]) + next_state[population_name] = population_state + return population_y.to(dtype=recurrent_input.dtype), next_state + + recurrent_next = recurrent_input.new_empty(batch_size, num_recurrent, self.hidden_size) + next_state = TensorDict({}, batch_size=[]) + active_populations = [ + name for name in self._population_names if self._population_recurrent_indices(name).numel() > 0 + ] + for name in active_populations: + recurrent_idx = self._population_recurrent_indices(name) + population_y, population_state = self._run_recurrent_population_step( + name, + recurrent_input, + state.get(name), + resets=resets, + population_materialized=population_materialized, + step_population_state_cache=step_population_state_cache, + population_input_already_projected=population_input_already_projected, + ) + recurrent_next[:, recurrent_idx, :] = population_y.to(dtype=recurrent_next.dtype) + next_state[name] = population_state + return recurrent_next, next_state + + def _blend_population_states( + self, + prev: TensorDict, + next_state: TensorDict, + active_rows: torch.Tensor, + ) -> TensorDict: + if bool(active_rows.all()): + return next_state + if not bool(active_rows.any()): + return prev + out = TensorDict({}, batch_size=[]) + for population_name in self._population_names: + prev_state = prev.get(population_name) + population_next = next_state.get(population_name) + if prev_state is None: + if population_next is not None: + out[population_name] = population_next + continue + if population_next is None: + out[population_name] = prev_state + continue + population_mask = active_rows.view(1, -1) + out[population_name] = _where_tensordict(population_mask, population_next, prev_state) + return out + + def 
_readout(self, y_final: torch.Tensor) -> torch.Tensor:
+        return backend_readout_output_cells(
+            y_final,
+            config=self._backend_readout_config(),
+            readout_query=self.readout_query,
+            readout_weight=self.readout_out.weight,
+            readout_bias=self.readout_out.bias,
+        )
+
+    def _pool_output_cells(self, y_final: torch.Tensor) -> torch.Tensor:
+        return self._pool_output_ports(self._select_output_cells(y_final))
+
+    def _select_output_cells(self, y_cells: torch.Tensor) -> torch.Tensor:
+        return backend_select_output_cells(y_cells, config=self._backend_readout_config())
+
+    def _pool_output_ports(self, port_y: torch.Tensor) -> torch.Tensor:
+        return backend_pool_output_ports(
+            port_y,
+            readout_pool=self.config.readout_pool,
+            readout_query=self.readout_query,
+        )
+
+    def _backend_readout_config(self) -> ReadoutConfig:
+        return ReadoutConfig(
+            partitioned_layout=bool(self._partitioned_layout),
+            output_slice=self._output_slice,
+            output_cell_idx=self.output_cell_idx,
+            readout_pool=str(self.config.readout_pool),
+        )
+
+    def _select_output_cells_stream_backend_population(
+        self,
+        *,
+        k: int | torch.Tensor | None,
+    ) -> str | None:
+        if self._resolve_constant_k_host(k) != 1:
+            return None
+        if not self._partitioned_layout:
+            return None
+        return self._full_recurrent_population_name
+
+    def _plan_temporal_execution(
+        self,
+        *,
+        k: int | torch.Tensor | None,
+        device: torch.device,
+        dtype: torch.dtype,
+        output_boundary: Literal["sequence", "terminal"] | str = "deferred",
+        readout_output_boundary: Literal["cells", "pooled"] | str = "deferred",
+        materialize_final_state: bool | None = None,
+        has_output_consumer: bool = False,
+        fresh_state: bool = False,
+        training: bool = False,
+        time_steps: int | None = None,
+        active_region_mode_override: str | None = None,
+        compact_input_carry: bool = False,
+        preserve_internal_carry: bool = False,
+        helper_state_present: bool = False,
+        helper_direct_grad_feasible: bool | None = None,
+        helper_input_requires_grad: bool = False,
+        forward_direct_grad_feasible: bool | None = None,
+        forward_input_requires_grad: bool = False,
+        gradient_horizon_steps: int | None = None,
+        checkpoint_steps: tuple[int, ...]
| None = None, + ) -> TemporalExecutionPlan: + planned_gradient_horizon_steps = ( + self.config.gradient_horizon_steps if gradient_horizon_steps is None else gradient_horizon_steps + ) + planned_checkpoint_steps = self.config.checkpoint_steps if checkpoint_steps is None else checkpoint_steps + return self._backend_planner.plan_temporal_execution( + device_type=device.type, + dtype=str(dtype).removeprefix("torch."), + partitioned_layout=bool(self._partitioned_layout), + has_edge_delay=bool(self._has_edge_delay), + constant_k=self._resolve_constant_k_host(k), + output_boundary=str(output_boundary), + readout_output_boundary=str(readout_output_boundary), + materialize_final_state=materialize_final_state, + has_output_consumer=has_output_consumer, + fresh_state=fresh_state, + training=training, + time_steps=time_steps, + active_region_mode_override=active_region_mode_override, + compact_input_carry=compact_input_carry, + preserve_internal_carry=preserve_internal_carry, + helper_state_present=helper_state_present, + helper_direct_grad_feasible=helper_direct_grad_feasible, + helper_input_requires_grad=helper_input_requires_grad, + forward_direct_grad_feasible=forward_direct_grad_feasible, + forward_input_requires_grad=forward_input_requires_grad, + gradient_horizon_steps=planned_gradient_horizon_steps, + checkpoint_steps=planned_checkpoint_steps, + ) + + def _training_static_prepack_enabled(self) -> bool: + population_name = self._full_recurrent_population_name + if population_name is None: + return True + population_spec = self._backend_population_specs.get(population_name) + if population_spec is None: + return True + return bool(population_spec.supports_training_static_prepack) + + def _population_num_cells(self, population_name: str) -> int: + return int(self._population_indices(population_name).numel()) + + def _init_backend_population_state( + self, + population_name: str, + *, + batch: int, + device: torch.device, + dtype: torch.dtype, + ) -> TensorDict: + return self._init_backend_population_state_for_receivers( + population_name, + batch=batch, + receivers=self._population_num_cells(population_name), + device=device, + dtype=dtype, + ) + + def _init_backend_population_state_for_receivers( + self, + population_name: str, + *, + batch: int, + receivers: int, + device: torch.device, + dtype: torch.dtype, + state_names: tuple[str, ...] 
| None = None, + ) -> TensorDict: + state_names = state_names or self._cell_spec_for_population(population_name).state_schema.keys + num_receivers = int(receivers) + with torch.profiler.record_function("fabric.glue.backend_population_state_zero"): + return TensorDict( + { + state_name: torch.zeros(batch, num_receivers, self.hidden_size, device=device, dtype=dtype) + for state_name in state_names + }, + batch_size=[batch, num_receivers], + device=device, + ) + + def _population_state_to_backend_state( + self, + population_name: str, + population_state: TensorDictBase, + ) -> TensorDict: + state_names = self._cell_spec_for_population(population_name).state_schema.keys + first = population_state[state_names[0]] + batch_size = int(first.shape[1]) + num_receivers = int(first.shape[0]) + with torch.profiler.record_function("fabric.glue.population_to_backend_state_layout"): + leaves: dict[str, torch.Tensor] = {} + for state_name in state_names: + backend_layout = population_state[state_name].permute(1, 0, 2) + leaves[state_name] = backend_layout if backend_layout.is_contiguous() else backend_layout.contiguous() + return TensorDict( + leaves, + batch_size=[batch_size, num_receivers], + device=first.device, + ) + + def _backend_state_to_population_state( + self, + population_name: str, + backend_state: Mapping[str, torch.Tensor], + ) -> TensorDict: + state_names = self._cell_spec_for_population(population_name).state_schema.keys + first = backend_state[state_names[0]] + batch_size = int(first.shape[0]) + num_receivers = int(first.shape[1]) + with torch.profiler.record_function("fabric.glue.backend_to_population_state_view"): + leaves: dict[str, torch.Tensor] = {} + for state_name in state_names: + leaves[state_name] = backend_state[state_name].permute(1, 0, 2) + return TensorDict( + leaves, + batch_size=[num_receivers, batch_size], + ) + + def _population_indices(self, population_name: str) -> torch.Tensor: + return getattr(self, _population_buffer_name(population_name)) + + def _population_recurrent_indices(self, population_name: str) -> torch.Tensor: + return getattr(self, _population_recurrent_buffer_name(population_name)) + + def _population_backend_recurrent_slice(self, population_name: str) -> tuple[int, int]: + return self._population_backend_recurrent_slices[population_name] + + def _build_population_indices(self, population_name: str) -> torch.Tensor: + population_idx = self._population_name_to_idx[population_name] + return torch.nonzero(self.cell_layout == population_idx, as_tuple=False).reshape(-1) + + def _register_population_backend_order_buffers(self) -> None: + order_parts: list[torch.Tensor] = [] + slices: dict[str, tuple[int, int]] = {} + offset = 0 + for name in self._population_names: + recurrent_idx = self._population_recurrent_indices(name) + count = int(recurrent_idx.numel()) + slices[name] = (offset, offset + count) + if count: + order_parts.append(recurrent_idx) + offset += count + if order_parts: + backend_order = torch.cat(order_parts, dim=0).to(dtype=torch.long) + else: + backend_order = torch.empty(0, dtype=torch.long) + self._population_backend_recurrent_order_is_identity = bool( + backend_order.numel() == self._num_recurrent_cells + and torch.equal(backend_order, torch.arange(backend_order.numel(), dtype=torch.long)) + ) + inverse_order = torch.empty_like(backend_order) + if backend_order.numel() > 0: + inverse_order[backend_order] = torch.arange(backend_order.numel(), dtype=torch.long) + self._population_backend_recurrent_slices = slices + 
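+        # Note on the buffers registered below: backend order lists recurrent receivers
+        # population by population, so each population's transition bucket reads one
+        # contiguous [start, end) slice of the reordered tensors; the inverse permutation
+        # maps backend-order results back to the model's cell layout (see the
+        # index_select on population_backend_recurrent_inverse_order in
+        # _forward_stream_step_k1).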
self.register_buffer("population_backend_recurrent_order", backend_order) + self.register_buffer("population_backend_recurrent_inverse_order", inverse_order) + if backend_order.numel() == self._num_recurrent_cells: + self.register_buffer( + "recurrent_neighbor_idx_backend_order", + self.recurrent_neighbor_idx.index_select(0, backend_order), + ) + self.register_buffer( + "recurrent_neighbor_valid_backend_order", + self.recurrent_neighbor_valid.index_select(0, backend_order), + ) + self.register_buffer( + "recurrent_edge_distance_backend_order", + self.recurrent_edge_distance.index_select(0, backend_order), + ) + self.register_buffer( + "recurrent_edge_delay_backend_order", + self.recurrent_edge_delay.index_select(0, backend_order), + ) + recurrent_local_valid_backend_order = self.recurrent_local_valid.index_select(0, backend_order) + recurrent_local_sender_idx_backend_order = self.recurrent_local_sender_idx.index_select(0, backend_order) + self.register_buffer("recurrent_local_valid_backend_order", recurrent_local_valid_backend_order) + self.register_buffer("recurrent_local_sender_idx_backend_order", recurrent_local_sender_idx_backend_order) + self.register_buffer( + "recurrent_local_receiver_idx_by_sender_backend_order", + _build_sender_reverse_table( + int(self.sender_cell_idx.numel()), + recurrent_local_sender_idx_backend_order, + recurrent_local_valid_backend_order, + ), + ) + self._register_shared_active_region_buffers() + + def _register_shared_active_region_buffers(self) -> None: + uses_sparse_message_backend = bool(getattr(self, "_uses_sparse_message_backend", False)) + output_sender_table = self.output_neighbor_idx if uses_sparse_message_backend else self.output_local_sender_idx + output_sender_valid = self.output_neighbor_valid if uses_sparse_message_backend else self.output_local_valid + recurrent_sender_table = ( + self.recurrent_neighbor_idx if uses_sparse_message_backend else self.recurrent_local_sender_idx + ) + recurrent_sender_valid = ( + self.recurrent_neighbor_valid if uses_sparse_message_backend else self.recurrent_local_valid + ) + output_window_start = ( + self._output_sparse_recurrent_window_start + if uses_sparse_message_backend + else self._output_local_recurrent_window_start + ) + output_window_count = ( + self._output_sparse_recurrent_window_count + if uses_sparse_message_backend + else self._output_local_recurrent_window_count + ) + output_window_contiguous = ( + self._output_sparse_recurrent_window_contiguous + if uses_sparse_message_backend + else self._output_local_recurrent_window_contiguous + ) + output_seed = recurrent_sender_seed_from_table( + output_sender_table, + num_input_senders=int(self._num_input_cells), + sender_valid=output_sender_valid, + ) + output_dependency_region = contiguous_recurrent_region( + start=int(output_window_start), + count=int(output_window_count) if bool(output_window_contiguous) else 0, + full_count=int(self._num_recurrent_cells), + ) + output_closure = close_recurrent_region_from_sender_tables( + seed_recurrent_receivers=output_seed, + recurrent_sender_table=recurrent_sender_table, + num_input_senders=int(self._num_input_cells), + recurrent_count=int(self._num_recurrent_cells), + recurrent_sender_valid=recurrent_sender_valid, + ) + self._flat_bucket_output_dependency_recurrent_indices = output_dependency_region.indices + self._flat_bucket_output_dependency_recurrent_mode = ( + "output_dependency_window" if output_dependency_region.compact_contiguous else output_dependency_region.mode + ) + 
self._flat_bucket_output_dependency_recurrent_start = int(output_dependency_region.start) + self._flat_bucket_output_dependency_recurrent_count = int(output_dependency_region.count) + self._flat_bucket_output_dependency_recurrent_is_full = bool(output_dependency_region.is_full) + self._flat_bucket_output_dependency_recurrent_compact_contiguous = bool( + output_dependency_region.compact_contiguous + ) + self._flat_bucket_output_recurrent_closure_indices = output_closure.indices + self._flat_bucket_output_recurrent_closure_mode = output_closure.mode + self._flat_bucket_output_recurrent_closure_start = int(output_closure.start) + self._flat_bucket_output_recurrent_closure_count = int(output_closure.count) + self._flat_bucket_output_recurrent_closure_is_full = bool(output_closure.is_full) + self._flat_bucket_output_recurrent_closure_compact_contiguous = bool(output_closure.compact_contiguous) + + if uses_sparse_message_backend and output_closure.is_full: + active_region = output_closure + else: + active_region = output_dependency_region if output_dependency_region.compact_contiguous else output_closure + self._flat_bucket_active_output_region_indices = active_region.indices + self._flat_bucket_active_output_region_mode = ( + "output_dependency_window" if active_region is output_dependency_region else active_region.mode + ) + self._flat_bucket_active_output_region_start = int(active_region.start) + self._flat_bucket_active_output_region_count = int(active_region.count) + self._flat_bucket_active_output_region_is_full = bool(active_region.is_full) + self._flat_bucket_active_output_region_compact_contiguous = bool(active_region.compact_contiguous) + + active_indices = tuple(int(index) for index in active_region.indices) + use_compact_region = bool(active_region.compact_contiguous or active_region.is_full) + if active_region.is_empty or not use_compact_region: + active_indices = tuple(range(int(self._num_recurrent_cells))) + use_compact_region = bool(active_indices) + start = int(active_indices[0]) if active_indices else 0 + count = len(active_indices) + full_count = int(self._num_recurrent_cells) + valid_window = bool(use_compact_region and start >= 0 and count > 0 and start + count <= full_count) + self._shared_active_region_buffer_names: dict[str, tuple[str, str, str]] = {} + if not valid_window: + self.register_buffer("shared_active_region_recurrent_idx", torch.empty(0, dtype=torch.long)) + return + active_recurrent_idx = torch.tensor(active_indices, dtype=torch.long) + self.register_buffer("shared_active_region_recurrent_idx", active_recurrent_idx) + for name in self._population_names: + population_recurrent_idx = self._population_recurrent_indices(name) + in_window = (population_recurrent_idx >= start) & (population_recurrent_idx < start + count) + population_positions = torch.nonzero(in_window, as_tuple=False).reshape(-1).to(dtype=torch.long) + active_full_idx = population_recurrent_idx.index_select(0, population_positions) + active_offsets = (active_full_idx - start).to(dtype=torch.long) + positions_name = _shared_active_region_positions_buffer_name(name) + offsets_name = _shared_active_region_offsets_buffer_name(name) + active_idx_name = _shared_active_region_recurrent_idx_buffer_name(name) + self.register_buffer(positions_name, population_positions) + self.register_buffer(offsets_name, active_offsets) + self.register_buffer(active_idx_name, active_full_idx.to(dtype=torch.long)) + self._shared_active_region_buffer_names[name] = ( + positions_name, + offsets_name, + active_idx_name, + ) + 
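+    # Illustrative note on the shared active-region buffers registered above: for each
+    # population that overlaps the active window, _shared_active_region_buckets (below)
+    # returns population_positions (row positions within that population's recurrent
+    # block), active_offsets (the same rows rebased to the window start), and
+    # active_recurrent_idx (the same rows as full recurrent indices). Worked example,
+    # assuming window start=2, count=3 and population recurrent rows [1, 3, 4]: the
+    # bucket is population_positions=[1, 2], active_offsets=[1, 2],
+    # active_recurrent_idx=[3, 4].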
+ def _shared_active_region_buckets(self) -> dict[str, dict[str, torch.Tensor]]: + buckets: dict[str, dict[str, torch.Tensor]] = {} + for name, buffer_names in self._shared_active_region_buffer_names.items(): + positions_name, offsets_name, active_idx_name = buffer_names + population_positions = getattr(self, positions_name) + if int(population_positions.numel()) == 0: + continue + buckets[name] = { + "population_positions": population_positions, + "active_offsets": getattr(self, offsets_name), + "active_recurrent_idx": getattr(self, active_idx_name), + } + return buckets + + def _project_output_cells(self, output_msg: torch.Tensor) -> torch.Tensor: + return torch.einsum("btpd,pdh->btph", output_msg, self.output_cell_weight) + self.output_cell_bias.view( + 1, 1, -1, self.hidden_size + ) + + def _project_output_cells_step(self, output_msg: torch.Tensor) -> torch.Tensor: + return torch.einsum("bpd,pdh->bph", output_msg, self.output_cell_weight) + self.output_cell_bias.view( + 1, + -1, + self.hidden_size, + ) + + def _prepare_stream_step_population_cache( + self, + state: TensorDict, + *, + batch: int, + device: torch.device, + dtype: torch.dtype, + ) -> dict[str, object] | None: + del batch, device, dtype + cache: dict[str, object] = {} + for population_name in self._population_names: + if self._population_recurrent_indices(population_name).numel() == 0: + continue + cache[population_name] = self._population_state_to_backend_state( + population_name, + cast(TensorDictBase, state[population_name]), + ) + return cache or None + + def _prepare_fresh_stream_step_population_cache( + self, + *, + batch: int, + device: torch.device, + dtype: torch.dtype, + ) -> dict[str, object] | None: + del batch, device, dtype + cache: dict[str, object] = {} + for population_name in self._population_names: + if self._population_recurrent_indices(population_name).numel() == 0: + continue + cache[population_name] = cast(object, None) + return cache or None + + def _reset_stream_step_population_cache( + self, + step_population_state_cache: dict[str, object], + resets: torch.Tensor, + ) -> None: + reset_mask = torch.as_tensor(resets, device=self.coords.device, dtype=torch.bool).view(-1) + for population_name, cached_state in list(step_population_state_cache.items()): + if cached_state is None: + continue + step_population_state_cache[population_name] = reset_backend_state_rows(cached_state, reset_mask) + + def _reset_stream_step_sender_cache( + self, + step_sender_cache: dict[str, torch.Tensor], + resets: torch.Tensor, + ) -> None: + reset_mask = torch.as_tensor(resets, device=self.coords.device, dtype=torch.bool).view(-1, 1, 1) + for key in ("sender_k", "sender_v", "recurrent_k", "recurrent_v"): + cached = step_sender_cache.get(key) + if not torch.is_tensor(cached): + continue + step_sender_cache[key] = reset_backend_tensor_rows(cached, reset_mask.view(-1)) + + def _apply_stream_step_population_cache( + self, + state: TensorDict, + step_population_state_cache: dict[str, object], + ) -> None: + for population_name, cached_state in step_population_state_cache.items(): + state[population_name] = self._backend_state_to_population_state( + population_name, + cast(Mapping[str, torch.Tensor], cached_state), + ) + + +class Model(ModelTemporalMixin, nn.Module): + def __init__(self, spec: Spec) -> None: + super().__init__() + self.spec = spec + self.runtime = Runtime(spec) + self.num_input_cells = int(spec.input_cell_idx.numel()) + if spec.config.input_adapters is None or spec.config.output_adapters is None: + raise 
ValueError("Fabric Model requires blueprint input and output adapters") + self.input_names = tuple(spec.config.input_adapters.keys()) + self.output_names = tuple(spec.config.output_adapters.keys()) + self.input_projections = nn.ModuleDict() + self.output_projections = nn.ModuleDict() + self._input_position_buffers: dict[str, str] = {} + self._output_position_buffers: dict[str, str] = {} + self.num_readout_slots = 0 + self.output_dim_total = 0 + input_position_lookup = _node_position_lookup(spec.input_cell_idx) + output_position_lookup = _node_position_lookup(spec.output_cell_idx) + for idx, (name, adapter) in enumerate(spec.config.input_adapters.items()): + nodes = spec.config.input_cell_groups[adapter.region] # type: ignore[index] + positions = torch.tensor([input_position_lookup[int(node)] for node in nodes], dtype=torch.long) + buffer_name = f"_input_positions_{idx}" + self.register_buffer(buffer_name, positions) + self._input_position_buffers[name] = buffer_name + self.input_projections[name] = nn.Linear(int(adapter.dim), len(nodes) * self.runtime.hidden_size) + for idx, (name, adapter) in enumerate(spec.config.output_adapters.items()): + nodes = spec.config.output_cell_groups[adapter.region] # type: ignore[index] + positions = torch.tensor([output_position_lookup[int(node)] for node in nodes], dtype=torch.long) + buffer_name = f"_output_positions_{idx}" + self.register_buffer(buffer_name, positions) + self._output_position_buffers[name] = buffer_name + readout_slots = 1 if adapter.readout == "mean" else len(nodes) + self.num_readout_slots += int(readout_slots) + self.output_dim_total += int(adapter.dim) + self.output_projections[name] = nn.Linear(readout_slots * self.runtime.hidden_size, int(adapter.dim)) + # Cell-specific defaults live behind the cell plugin; these are generic overrides. 
+ self._sequence_checkpoint_target_bytes: int | None = None + self._sequence_checkpoint_state_overhead_factor: float | None = None + self._sequence_direct_grad_target_bytes: int | None = None + + @property + def backend_ir(self): + return self.runtime.backend_ir + + @property + def backend_population_specs(self): + return self.runtime.backend_population_specs + + @property + def last_backend_execution(self): + return self.runtime.last_backend_execution + + @property + def graph_capture_cache_stats(self): + return self.runtime.graph_capture_cache_stats + + def describe_backend(self) -> dict[str, object]: + return self.runtime.describe_backend() + + def plan_backend_execution( + self, + *, + batch_size: int, + time_steps: int, + inner_steps: int, + training: bool, + tape_policy: TapePolicy | None = None, + device: torch.device | None = None, + surface_key: str | None = None, + ) -> PlannedFabricExecution: + return self.runtime.plan_backend_execution( + batch_size=batch_size, + time_steps=time_steps, + inner_steps=inner_steps, + training=training, + tape_policy=tape_policy, + device=device, + surface_key=surface_key, + ) + + def init_state(self, batch: int, *, device: torch.device | str = "cpu", dtype: torch.dtype) -> TensorDict: + return self.runtime.init_state(batch=batch, device=device, dtype=dtype) + + def reset_state(self, state: MaybeState, mask: ResetMask) -> MaybeState: + return self.runtime.reset_state(state, mask) + + def readout_output_cells(self, output_cells: torch.Tensor) -> torch.Tensor: + return self._readout(output_cells) + + def _pack_boundary_input(self, external_input: Tensor | Mapping[str, Tensor]) -> tuple[torch.Tensor, bool]: + named_inputs = self._named_external_inputs(external_input) + first = next(iter(named_inputs.values())) + step_mode = first.dim() == 2 + first_seq = first.unsqueeze(1) if step_mode else first + if first_seq.dim() != 3: + raise ValueError(f"Fabric input must be shaped [B,H] or [B,T,H], got {tuple(first.shape)}") + batch_size, time_steps = int(first_seq.shape[0]), int(first_seq.shape[1]) + boundary = first_seq.new_zeros((batch_size, time_steps, self.num_input_cells, self.runtime.hidden_size)) + for name, value in named_inputs.items(): + adapter = self.spec.config.input_adapters[name] # type: ignore[index] + seq = value.unsqueeze(1) if value.dim() == 2 else value + if seq.dim() != 3: + raise ValueError(f"Fabric input {name!r} must be shaped [B,H] or [B,T,H], got {tuple(value.shape)}") + if int(seq.shape[0]) != batch_size or int(seq.shape[1]) != time_steps: + raise ValueError(f"Fabric input {name!r} must share batch and time dimensions") + if int(seq.shape[2]) != int(adapter.dim): + raise ValueError(f"Fabric input {name!r} dim={seq.shape[2]} must match declared dim={adapter.dim}") + positions = getattr(self, self._input_position_buffers[name]).to(device=seq.device) + projected = self.input_projections[name](seq).view(batch_size, time_steps, positions.numel(), -1) + boundary.index_copy_(2, positions, projected) + return boundary, step_mode + + def _projected_boundary_source_input( + self, + external_input: Tensor | Mapping[str, Tensor], + ) -> tuple[torch.Tensor, bool, torch.Tensor, torch.Tensor | None] | None: + named_inputs = self._named_external_inputs(external_input) + if len(named_inputs) != 1: + return None + name, value = next(iter(named_inputs.items())) + positions = getattr(self, self._input_position_buffers[name]).to(device=value.device) + expected = torch.arange(self.num_input_cells, device=value.device, dtype=positions.dtype) + if 
int(positions.numel()) != self.num_input_cells or not torch.equal(positions, expected): + return None + seq = value.unsqueeze(1) if value.dim() == 2 else value + if seq.dim() != 3: + raise ValueError(f"Fabric input {name!r} must be shaped [B,H] or [B,T,H], got {tuple(value.shape)}") + adapter = self.spec.config.input_adapters[name] # type: ignore[index] + if int(seq.shape[2]) != int(adapter.dim): + raise ValueError(f"Fabric input {name!r} dim={seq.shape[2]} must match declared dim={adapter.dim}") + projection = self.input_projections[name] + return seq, value.dim() == 2, projection.weight, projection.bias + + def _named_external_inputs(self, external_input: Tensor | Mapping[str, Tensor]) -> dict[str, Tensor]: + if torch.is_tensor(external_input): + if len(self.input_names) != 1: + raise ValueError("Fabric with multiple inputs expects a mapping of input name to tensor") + return {self.input_names[0]: external_input} + provided = {str(name): value for name, value in external_input.items()} + expected = set(self.input_names) + if set(provided) != expected: + raise ValueError(f"Fabric inputs must be exactly {sorted(expected)}, got {sorted(provided)}") + return provided + + def _project_outputs_from_cells(self, output_cells: torch.Tensor, *, squeeze_time: bool = False): + outputs = { + name: self._project_one_output(name, output_cells, squeeze_time=squeeze_time) for name in self.output_names + } + if len(outputs) == 1: + return next(iter(outputs.values())) + return outputs + + def _can_project_outputs_from_backend_pooled_readout(self) -> bool: + if len(self.output_names) != 1: + return False + name = self.output_names[0] + adapter = self.spec.config.output_adapters[name] # type: ignore[index] + if adapter.readout != "mean": + return False + positions = getattr(self, self._output_position_buffers[name]) + expected = torch.arange(int(self.runtime._num_output_cells), device=positions.device, dtype=positions.dtype) + return bool(positions.shape == expected.shape and torch.equal(positions, expected)) + + def _project_outputs_from_backend_pooled_readout( + self, + pooled_cells: torch.Tensor, + *, + squeeze_time: bool = False, + ) -> torch.Tensor: + if not self._can_project_outputs_from_backend_pooled_readout(): + raise RuntimeError("Backend pooled readout projection requires one full-boundary mean output") + name = self.output_names[0] + projected = self._project_output_flat( + name, + pooled_cells.reshape(pooled_cells.shape[0], pooled_cells.shape[1], -1), + ) + return projected.squeeze(1) if squeeze_time else projected + + def _project_outputs_from_backend_readout( + self, + output_cells: torch.Tensor, + *, + readout_output_boundary: Literal["cells", "pooled"], + squeeze_time: bool = False, + ): + if readout_output_boundary == "pooled": + return self._project_outputs_from_backend_pooled_readout(output_cells, squeeze_time=squeeze_time) + return self._project_outputs_from_cells(output_cells, squeeze_time=squeeze_time) + + def _project_single_output_from_cells(self, output_cells: torch.Tensor) -> torch.Tensor: + if len(self.output_names) != 1: + raise ValueError("This Fabric operation requires exactly one declared output") + return self._project_one_output(self.output_names[0], output_cells, squeeze_time=False) + + def _project_one_output(self, name: str, output_cells: torch.Tensor, *, squeeze_time: bool) -> torch.Tensor: + adapter = self.spec.config.output_adapters[name] # type: ignore[index] + positions = getattr(self, self._output_position_buffers[name]).to(device=output_cells.device) + region = 
output_cells.index_select(2, positions) + if adapter.readout == "mean": + pooled = region.mean(dim=2, keepdim=True) + elif adapter.readout == "flatten": + pooled = region + else: + raise ValueError(f"Unsupported Fabric output readout {adapter.readout!r}") + projected = self._project_output_flat( + name, + pooled.reshape(output_cells.shape[0], output_cells.shape[1], -1), + ) + return projected.squeeze(1) if squeeze_time else projected + + def _project_output_flat(self, name: str, flat_input: torch.Tensor) -> torch.Tensor: + projection = self.output_projections[name] + row_tile_len, row_tile_reason = self._output_projection_row_tile_len(flat_input, projection) + rows = int(flat_input.shape[0]) * int(flat_input.shape[1]) + if row_tile_len >= rows: + return projection(flat_input) + flat_2d = flat_input.reshape(rows, int(flat_input.shape[-1])) + output_2d = flat_2d.new_empty((rows, int(projection.out_features))) + for start in range(0, rows, int(row_tile_len)): + end = min(rows, start + int(row_tile_len)) + output_2d[start:end].copy_(projection(flat_2d[start:end])) + self._annotate_output_projection_row_tile( + rows=rows, + row_tile_len=int(row_tile_len), + row_tile_reason=row_tile_reason, + ) + return output_2d.view(int(flat_input.shape[0]), int(flat_input.shape[1]), int(projection.out_features)) + + def _output_projection_row_tile_len( + self, + flat_input: torch.Tensor, + projection: nn.Linear, + ) -> tuple[int, str]: + rows = int(flat_input.shape[0]) * int(flat_input.shape[1]) + if rows <= 1 or flat_input.device.type != "cuda" or not torch.is_grad_enabled(): + return rows, "output_projection_row_tiling=disabled" + dtype_bytes = int(flat_input.element_size()) + output_dim = int(projection.out_features) + estimated_output_bytes = int(rows) * output_dim * dtype_bytes + target_bytes = 512 << 20 + if estimated_output_bytes <= target_bytes: + return rows, ( + f"output_projection_row_tiling=full;estimated_output_bytes={estimated_output_bytes};" + f"target_bytes={target_bytes};output_dim={output_dim}" + ) + row_tile_len = max(1, int(target_bytes // max(1, output_dim * dtype_bytes))) + return int(min(rows, row_tile_len)), ( + f"output_projection_row_tiling=active;estimated_output_bytes={estimated_output_bytes};" + f"target_bytes={target_bytes};output_dim={output_dim}" + ) + + def _annotate_output_projection_row_tile( + self, + *, + rows: int, + row_tile_len: int, + row_tile_reason: str, + ) -> None: + record = self.runtime._last_backend_execution + if record is None: + return + self.runtime._last_backend_execution = replace( + record, + workspace_aliases=record.workspace_aliases + + ( + f"output_projection_row_tile:rows={int(rows)};tile_rows={int(row_tile_len)}", + f"output_projection_row_tile_reason:{row_tile_reason}", + ), + launch_readout_modes=record.launch_readout_modes + ("output_projection_row_tiled",), + actual_launch_readout_modes=record.actual_launch_readout_modes + ("output_projection_row_tiled",), + ) + + def _forward_sequence_with_readout( + self, + boundary_seq: torch.Tensor, + state: TensorDictBase | None, + *, + resets: Optional[ResetMask], + k: int | torch.Tensor | None, + training_semantics: bool | None = None, + materialize_final_state: bool = True, + tape_policy: TapePolicy | None = None, + output_boundary: Literal["sequence", "terminal"] = "sequence", + ) -> tuple[torch.Tensor, TensorDict]: + if output_boundary not in {"sequence", "terminal"}: + raise ValueError(f"Unsupported Fabric sequence output boundary {output_boundary!r}") + readout_batch_tile_len, 
readout_batch_tile_reason = self._readout_pooled_batch_tile_len( + boundary_seq, + k=k, + materialize_final_state=materialize_final_state, + output_boundary=output_boundary, + ) + if readout_batch_tile_len < int(boundary_seq.shape[0]): + return self._forward_sequence_with_readout_batch_tiled( + boundary_seq, + state, + resets=resets, + k=k, + training_semantics=training_semantics, + materialize_final_state=materialize_final_state, + tape_policy=tape_policy, + output_boundary=output_boundary, + batch_tile_len=readout_batch_tile_len, + batch_tile_reason=readout_batch_tile_reason, + ) + output_cells, next_state = self.runtime.forward_output_cells_for_readout( + state=state, + resets=resets, + k=k, + boundary_input=boundary_seq, + training_semantics=training_semantics, + materialize_final_state=materialize_final_state, + tape_policy=tape_policy, + output_boundary=output_boundary, + readout_output_boundary="cells", + ) + return output_cells, next_state + + def _stream_sequence_with_readout( + self, + boundary_seq: torch.Tensor, + state: TensorDictBase | None, + *, + resets: Optional[ResetMask], + k: int | torch.Tensor | None, + training_semantics: bool | None = None, + materialize_final_state: bool = True, + tape_policy: TapePolicy | None = None, + output_boundary: Literal["sequence", "terminal"] = "sequence", + output_consumer: _ModelOutputChunkConsumer, + detach_internal_carry_after_output_chunk: bool = False, + ) -> TensorDict: + if output_boundary not in {"sequence", "terminal"}: + raise ValueError(f"Unsupported Fabric sequence output boundary {output_boundary!r}") + batch_size = int(boundary_seq.shape[0]) + readout_batch_tile_len, readout_batch_tile_reason = self._readout_pooled_batch_tile_len( + boundary_seq, + k=k, + materialize_final_state=materialize_final_state, + output_boundary=output_boundary, + ) + if readout_batch_tile_len < batch_size: + return self._stream_sequence_with_readout_batch_tiled( + boundary_seq, + state, + resets=resets, + k=k, + training_semantics=training_semantics, + materialize_final_state=materialize_final_state, + tape_policy=tape_policy, + output_boundary=output_boundary, + batch_tile_len=readout_batch_tile_len, + batch_tile_reason=readout_batch_tile_reason, + output_consumer=output_consumer, + detach_internal_carry_after_output_chunk=detach_internal_carry_after_output_chunk, + ) + + def consume_output_cells(output_chunk: torch.Tensor, time_start: int, time_end: int) -> None: + output_consumer( + self._project_single_output_from_cells(output_chunk), + 0, + batch_size, + time_start, + time_end, + ) + + _unused_output, next_state = self.runtime.forward_output_cells_for_readout( + state=state, + resets=resets, + k=k, + boundary_input=boundary_seq, + training_semantics=training_semantics, + materialize_final_state=materialize_final_state, + tape_policy=tape_policy, + output_boundary=output_boundary, + readout_output_boundary="cells", + output_chunk_consumer=consume_output_cells, + detach_internal_carry_after_output_chunk=detach_internal_carry_after_output_chunk, + ) + if materialize_final_state: + return next_state + return TensorDict({}, batch_size=[]) + + def _readout_pooled_batch_tile_len( + self, + hidden_seq: torch.Tensor, + *, + k: int | torch.Tensor | None, + materialize_final_state: bool, + output_boundary: Literal["sequence", "terminal"], + readout_output_boundary: Literal["cells", "pooled"] = "cells", + training: bool | None = None, + ) -> tuple[int, str]: + batch_size = int(hidden_seq.shape[0]) + time_steps = int(hidden_seq.shape[1]) + output_cells = 
int(self.spec.output_cell_idx.numel()) + readout_slots = int(self.num_readout_slots) + training_enabled = bool(torch.is_grad_enabled() if training is None else training) + temporal_execution_plan = self.runtime._plan_temporal_execution( + k=k, + device=hidden_seq.device, + dtype=hidden_seq.dtype, + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + materialize_final_state=materialize_final_state, + has_output_consumer=False, + fresh_state=False, + training=training_enabled, + time_steps=time_steps, + ) + if temporal_execution_plan.executor.selected_implementation == "shared_transition_buckets": + decision = self.runtime._flat_bucket_sequence_readout_batch_tile_len( + batch_size=batch_size, + time_steps=time_steps, + dtype_bytes=int(torch.empty((), dtype=hidden_seq.dtype).element_size()), + input_cells=int(self.num_input_cells), + output_cells=output_cells, + readout_slots=readout_slots, + projected_output_dim=int(self.output_dim_total), + materialize_final_state=materialize_final_state, + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + training=training_enabled, + memory=self.runtime._cuda_memory_budget(hidden_seq.device), + ) + return int(decision.value), decision.reason + decision = readout_pooled_batch_tile_policy( + batch_size=batch_size, + time_steps=time_steps, + dtype_bytes=int(torch.empty((), dtype=hidden_seq.dtype).element_size()), + input_cells=int(self.num_input_cells), + output_cells=output_cells, + readout_slots=readout_slots, + hidden_size=int(self.runtime.hidden_size), + projected_output_dim=int(self.output_dim_total), + materialize_final_state=materialize_final_state, + training=training_enabled, + backend_sequence_surface_supported=False, + memory=self.runtime._cuda_memory_budget(hidden_seq.device), + ) + return int(decision.value), decision.reason + + def _forward_projected_source_sequence_batch_tiled( + self, + source_hidden_seq: torch.Tensor, + state: TensorDictBase | None, + *, + resets: Optional[ResetMask], + k: int | torch.Tensor | None, + projection_weight: torch.Tensor, + projection_bias: torch.Tensor | None, + materialize_final_state: bool, + output_boundary: Literal["sequence", "terminal"], + readout_output_boundary: Literal["cells", "pooled"], + batch_tile_len: int, + batch_tile_reason: str, + squeeze_time: bool, + ) -> tuple[Tensor | dict[str, Tensor], TensorDict]: + batch_size = int(source_hidden_seq.shape[0]) + output_accum: Any | None = None + state_chunks: list[TensorDict] = [] + for start in range(0, batch_size, int(batch_tile_len)): + end = min(start + int(batch_tile_len), batch_size) + output_cells, next_state = self.runtime.forward_output_cells_for_readout( + state=cast(TensorDictBase | None, _slice_batch_value(state, start, end)), + resets=_slice_batch_reset(resets, start, end, batch_size=batch_size, device=source_hidden_seq.device), + k=_slice_batch_k(k, start, end, batch_size=batch_size, time_steps=int(source_hidden_seq.shape[1])), + source_hidden_input=source_hidden_seq[start:end], + input_projection_weight=projection_weight, + input_projection_bias=projection_bias, + training_semantics=None, + materialize_final_state=materialize_final_state, + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + ) + output_chunk = self._project_outputs_from_backend_readout( + output_cells, + readout_output_boundary=readout_output_boundary, + squeeze_time=squeeze_time, + ) + if output_accum is None: + output_accum = _new_batch_like(output_chunk, 
batch_size=batch_size) + _copy_batch_output(output_accum, output_chunk, start, end) + if materialize_final_state: + state_chunks.append(next_state) + if output_accum is None: + raise RuntimeError("projected-source sequence batch tiling produced no output chunks") + output = cast(Tensor | dict[str, Tensor], output_accum) + next_full_state = ( + cast(TensorDict, _cat_batch_value(cast(list[Any], state_chunks))) + if materialize_final_state + else TensorDict({}, batch_size=[]) + ) + self._annotate_readout_pooled_batch_tile( + batch_size=batch_size, + batch_tile_len=int(batch_tile_len), + batch_tile_reason=batch_tile_reason, + output_mode="projected_source", + ) + return output, next_full_state + + def _forward_sequence_with_readout_batch_tiled( + self, + boundary_seq: torch.Tensor, + state: TensorDictBase | None, + *, + resets: Optional[ResetMask], + k: int | torch.Tensor | None, + training_semantics: bool | None, + materialize_final_state: bool, + tape_policy: TapePolicy | None, + output_boundary: Literal["sequence", "terminal"], + batch_tile_len: int, + batch_tile_reason: str, + ) -> tuple[torch.Tensor, TensorDict]: + batch_size = int(boundary_seq.shape[0]) + output_chunks: list[torch.Tensor] = [] + state_chunks: list[TensorDict] = [] + output_seq: torch.Tensor | None = None + preallocate_output = not torch.is_grad_enabled() + for start in range(0, batch_size, int(batch_tile_len)): + end = min(start + int(batch_tile_len), batch_size) + boundary_tile = boundary_seq[start:end] + output_cells, next_state = self.runtime.forward_output_cells_for_readout( + state=cast(TensorDictBase | None, _slice_batch_value(state, start, end)), + resets=_slice_batch_reset(resets, start, end, batch_size=batch_size, device=boundary_seq.device), + k=_slice_batch_k(k, start, end, batch_size=batch_size, time_steps=int(boundary_seq.shape[1])), + boundary_input=boundary_tile, + training_semantics=training_semantics, + materialize_final_state=materialize_final_state, + tape_policy=tape_policy, + output_boundary=output_boundary, + readout_output_boundary="cells", + ) + if preallocate_output: + if output_seq is None: + output_seq = output_cells.new_empty((batch_size, *output_cells.shape[1:])) + output_seq[start:end].copy_(output_cells) + else: + output_chunks.append(output_cells) + if materialize_final_state: + state_chunks.append(next_state) + if output_seq is None: + output_seq = torch.cat(output_chunks, dim=0) + next_full_state = ( + cast(TensorDict, _cat_batch_value(cast(list[Any], state_chunks))) + if materialize_final_state + else TensorDict({}, batch_size=[]) + ) + self._annotate_readout_pooled_batch_tile( + batch_size=batch_size, + batch_tile_len=int(batch_tile_len), + batch_tile_reason=batch_tile_reason, + output_mode="preallocated" if preallocate_output else "cat", + ) + return output_seq, next_full_state + + def _stream_sequence_with_readout_batch_tiled( + self, + boundary_seq: torch.Tensor, + state: TensorDictBase | None, + *, + resets: Optional[ResetMask], + k: int | torch.Tensor | None, + training_semantics: bool | None, + materialize_final_state: bool, + tape_policy: TapePolicy | None, + output_boundary: Literal["sequence", "terminal"], + batch_tile_len: int, + batch_tile_reason: str, + output_consumer: _ModelOutputChunkConsumer, + detach_internal_carry_after_output_chunk: bool = False, + ) -> TensorDict: + batch_size = int(boundary_seq.shape[0]) + state_chunks: list[TensorDict] = [] + for start in range(0, batch_size, int(batch_tile_len)): + end = min(start + int(batch_tile_len), batch_size) + 
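+ # start/end for this batch tile are bound into the consumer below through keyword
+ # defaults, so each tile's closure keeps its own bounds instead of late-binding the
+ # loop variables. +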
boundary_tile = boundary_seq[start:end] + + def consume_output_cells( + output_chunk: torch.Tensor, + time_start: int, + time_end: int, + *, + boundary_tile_len: int = int(boundary_tile.shape[0]), + batch_start: int = int(start), + batch_end: int = int(end), + ) -> None: + del boundary_tile_len + output_consumer( + self._project_single_output_from_cells(output_chunk), + batch_start, + batch_end, + time_start, + time_end, + ) + + _unused_output, next_state = self.runtime.forward_output_cells_for_readout( + state=cast(TensorDictBase | None, _slice_batch_value(state, start, end)), + resets=_slice_batch_reset(resets, start, end, batch_size=batch_size, device=boundary_seq.device), + k=_slice_batch_k(k, start, end, batch_size=batch_size, time_steps=int(boundary_seq.shape[1])), + boundary_input=boundary_tile, + training_semantics=training_semantics, + materialize_final_state=materialize_final_state, + tape_policy=tape_policy, + output_boundary=output_boundary, + readout_output_boundary="cells", + output_chunk_consumer=consume_output_cells, + detach_internal_carry_after_output_chunk=detach_internal_carry_after_output_chunk, + ) + if materialize_final_state: + state_chunks.append(next_state) + next_full_state = ( + cast(TensorDict, _cat_batch_value(cast(list[Any], state_chunks))) + if materialize_final_state + else TensorDict({}, batch_size=[]) + ) + self._annotate_readout_pooled_batch_tile( + batch_size=batch_size, + batch_tile_len=int(batch_tile_len), + batch_tile_reason=batch_tile_reason, + output_mode="streaming_consumer", + ) + return next_full_state + + def _annotate_readout_pooled_batch_tile( + self, + *, + batch_size: int, + batch_tile_len: int, + batch_tile_reason: str, + output_mode: str, + ) -> None: + record = self.runtime._last_backend_execution + if record is None: + return + self.runtime._last_backend_execution = replace( + record, + batch_size=int(batch_size), + workspace_aliases=record.workspace_aliases + + ( + f"readout_pooled_batch_tile:b={int(batch_tile_len)}", + f"readout_pooled_batch_tile_reason:{batch_tile_reason}", + f"readout_pooled_batch_tile_output:{output_mode}", + ), + launch_readout_modes=record.launch_readout_modes + ("pooled_batch_tiled",), + actual_launch_readout_modes=record.actual_launch_readout_modes + ("pooled_batch_tiled",), + ) + + def _forward_sequence_checkpointed( + self, + hidden_seq: torch.Tensor, + state: MaybeState, + *, + resets: Optional[ResetMask], + k: int | torch.Tensor | None, + materialize_final_state: bool = True, + output_boundary: Literal["sequence", "terminal"] = "sequence", + ) -> tuple[torch.Tensor, TensorDict]: + batch_size = hidden_seq.shape[0] + current_state = self.runtime._ensure_state( + state, + batch=batch_size, + device=hidden_seq.device, + dtype=hidden_seq.dtype, + ) + chunk_len = self._sequence_checkpoint_chunk_len(hidden_seq, current_state) + resets_bt = _expand_resets_for_time( + resets, + batch_size=batch_size, + time_steps=hidden_seq.shape[1], + device=hidden_seq.device, + ) + outputs: list[torch.Tensor] = [] + running_state = current_state + + for start in range(0, hidden_seq.shape[1], chunk_len): + end = min(start + chunk_len, hidden_seq.shape[1]) + hidden_chunk = hidden_seq[:, start:end] + reset_chunk = None if resets_bt is None else resets_bt[:, start:end] + k_chunk = _slice_sequence_k(k, start=start, end=end, batch_size=batch_size, device=hidden_seq.device) + chunk_materialize_final_state = materialize_final_state or end < hidden_seq.shape[1] + chunk_output_boundary: Literal["sequence", "terminal"] = ( + output_boundary 
if end == hidden_seq.shape[1] else "sequence" + ) + state_paths, state_batch_sizes, state_tensors = _flatten_tensordict(running_state) + grad_marker = hidden_seq.new_zeros((), requires_grad=True) + + def run_sequence( + hidden_piece: torch.Tensor, + *flat_inputs: torch.Tensor, + chunk_paths: tuple[tuple[str, ...], ...] = state_paths, + chunk_batch_sizes: dict[tuple[str, ...], torch.Size] = state_batch_sizes, + chunk_resets: torch.Tensor | None = reset_chunk, + chunk_k: int | torch.Tensor | None = k_chunk, + chunk_materialize: bool = chunk_materialize_final_state, + chunk_output_boundary: Literal["sequence", "terminal"] = chunk_output_boundary, + ) -> tuple[torch.Tensor, ...]: + state_values = flat_inputs[:-1] + next_input_state = _unflatten_tensordict(chunk_paths, chunk_batch_sizes, state_values) + output_cells, next_state = self._forward_sequence_with_readout( + hidden_piece, + next_input_state, + resets=chunk_resets, + k=chunk_k, + training_semantics=True, + materialize_final_state=chunk_materialize, + output_boundary=chunk_output_boundary, + ) + _, _, next_tensors = _flatten_tensordict(next_state) + return (output_cells, *next_tensors) + + def checkpoint_contexts(): + return nullcontext(), _preserve_backend_execution_record(self.runtime) + + checkpoint_outputs = checkpoint( + run_sequence, + hidden_chunk, + *state_tensors, + grad_marker, + use_reentrant=False, + preserve_rng_state=False, + context_fn=checkpoint_contexts, + ) + if output_boundary == "sequence" or end == hidden_seq.shape[1]: + outputs.append(checkpoint_outputs[0]) + if chunk_materialize_final_state: + running_state = _unflatten_tensordict(state_paths, state_batch_sizes, checkpoint_outputs[1:]) + else: + running_state = TensorDict({}, batch_size=[]) + return self._project_outputs_from_cells(torch.cat(outputs, dim=1)), running_state + + def _sequence_checkpoint_chunk_len( + self, + hidden_seq: torch.Tensor, + state: TensorDictBase, + ) -> int: + seq_len = int(hidden_seq.shape[1]) + if seq_len <= 1: + return seq_len + state_bytes = self._estimate_sequence_state_bytes(state, hidden_seq=hidden_seq) + if state_bytes <= 0: + return seq_len + estimated_per_step_bytes = int(math.ceil(state_bytes * self._sequence_checkpoint_overhead_factor())) + target_bytes = max(1, int(self._sequence_checkpoint_target_bytes_for_state())) + chunk_len = max(1, target_bytes // max(1, estimated_per_step_bytes)) + return min(seq_len, chunk_len) + + def _estimate_sequence_state_bytes( + self, + state: TensorDictBase, + *, + hidden_seq: torch.Tensor | None = None, + ) -> int: + _, _, state_tensors = _flatten_tensordict(state) + state_bytes = sum(int(t.numel()) * int(t.element_size()) for t in state_tensors) + if state_bytes > 0 or hidden_seq is None: + return state_bytes + return self._estimate_fresh_sequence_runtime_state_bytes( + batch_size=int(hidden_seq.shape[0]), + dtype=hidden_seq.dtype, + ) + + def _estimate_fresh_sequence_runtime_state_bytes( + self, + *, + batch_size: int, + dtype: torch.dtype, + ) -> int: + dtype_bytes = int(torch.empty((), dtype=dtype).element_size()) + hidden_size = int(self.runtime.hidden_size) + cells = int(self.runtime.coords.shape[0]) + total_elements = int(batch_size) * cells * hidden_size + for population_name in self.runtime._population_names: + population_cells = int(self.runtime._population_num_cells(population_name)) + if population_cells <= 0: + continue + state_leaf_count = len(self.runtime._cell_spec_for_population(population_name).state_schema.keys) + total_elements += int(batch_size) * population_cells * 
hidden_size * int(state_leaf_count) + return total_elements * dtype_bytes + + def _active_cell_sequence_memory_policy(self) -> dict[str, float | int]: + default_checkpoint_target_bytes = 32 << 30 + default_checkpoint_state_overhead_factor = 4.0 + default_direct_grad_target_bytes = 96 << 30 + default_output_overhead_factor = 6.0 + checkpoint_targets: list[int] = [] + checkpoint_overheads: list[float] = [] + direct_grad_targets: list[int] = [] + output_overheads: list[float] = [] + for population_name in self.spec.population_names: + if int(self.runtime._population_num_cells(population_name)) <= 0: + continue + backend_spec = self.runtime._backend_population_specs.get(population_name) + if backend_spec is None: + checkpoint_targets.append(default_checkpoint_target_bytes) + checkpoint_overheads.append(default_checkpoint_state_overhead_factor) + direct_grad_targets.append(default_direct_grad_target_bytes) + output_overheads.append(default_output_overhead_factor) + continue + checkpoint_targets.append(int(backend_spec.sequence_checkpoint_target_bytes)) + checkpoint_overheads.append(float(backend_spec.sequence_checkpoint_state_overhead_factor)) + direct_grad_targets.append(int(backend_spec.sequence_direct_grad_target_bytes)) + output_overheads.append(float(backend_spec.sequence_checkpoint_output_overhead_factor)) + policy: dict[str, float | int] = { + "checkpoint_target_bytes": ( + min(checkpoint_targets) if checkpoint_targets else default_checkpoint_target_bytes + ), + "checkpoint_state_overhead_factor": ( + max(checkpoint_overheads) if checkpoint_overheads else default_checkpoint_state_overhead_factor + ), + "direct_grad_target_bytes": ( + min(direct_grad_targets) if direct_grad_targets else default_direct_grad_target_bytes + ), + "checkpoint_output_overhead_factor": ( + max(output_overheads) if output_overheads else default_output_overhead_factor + ), + } + if self._sequence_checkpoint_target_bytes is not None: + policy["checkpoint_target_bytes"] = int(self._sequence_checkpoint_target_bytes) + if self._sequence_checkpoint_state_overhead_factor is not None: + policy["checkpoint_state_overhead_factor"] = float(self._sequence_checkpoint_state_overhead_factor) + if self._sequence_direct_grad_target_bytes is not None: + policy["direct_grad_target_bytes"] = int(self._sequence_direct_grad_target_bytes) + return policy + + def _active_cell_supports_direct_grad_sequence(self) -> bool: + if len(self.spec.population_names) != 1: + return False + population_name = self.spec.population_names[0] + backend_spec = self.runtime._backend_population_specs.get(population_name) + if backend_spec is None: + return False + return bool(backend_spec.supports_direct_grad_sequence) + + def _sequence_checkpoint_overhead_factor(self) -> float: + policy = self._active_cell_sequence_memory_policy() + return float(policy["checkpoint_state_overhead_factor"]) + + def _sequence_checkpoint_target_bytes_for_state(self) -> int: + policy = self._active_cell_sequence_memory_policy() + return int(policy["checkpoint_target_bytes"]) + + def _sequence_output_overhead_factor(self) -> float: + policy = self._active_cell_sequence_memory_policy() + return float(policy.get("checkpoint_output_overhead_factor", 6.0)) + + def _estimate_sequence_output_window_bytes( + self, + hidden_seq: torch.Tensor, + *, + output_boundary: Literal["sequence", "terminal"], + ) -> int: + time_steps = int(hidden_seq.shape[1]) if output_boundary == "sequence" else 1 + if time_steps <= 0: + return 0 + dtype_bytes = int(torch.empty((), 
dtype=hidden_seq.dtype).element_size()) + output_elements = int(hidden_seq.shape[0]) * int(time_steps) * int(self.output_dim_total) + return int(math.ceil(float(output_elements * dtype_bytes) * self._sequence_output_overhead_factor())) + + def _should_use_direct_grad_sequence( + self, + hidden_seq: torch.Tensor, + state: MaybeState, + *, + materialize_final_state: bool = True, + ) -> tuple[bool, TensorDictBase | None]: + if state is None and not materialize_final_state and int(hidden_seq.shape[1]) == 1: + population_name = self.runtime._full_recurrent_population_name + if ( + population_name is not None + and self.runtime._fresh_output_dependency_receiver_count( + population_name=population_name, + time_steps=int(hidden_seq.shape[1]), + fresh_state_virtualized=True, + ) + is not None + ): + return True, None + current_state = self.runtime._ensure_state( + state, + batch=hidden_seq.shape[0], + device=hidden_seq.device, + dtype=hidden_seq.dtype, + ) + if not self._active_cell_supports_direct_grad_sequence(): + return False, current_state + state_bytes = self._estimate_sequence_state_bytes(current_state) + if state_bytes <= 0: + return True, current_state + estimated_per_step_bytes = int(math.ceil(state_bytes * self._sequence_checkpoint_overhead_factor())) + estimated_window_bytes = int(hidden_seq.shape[1]) * estimated_per_step_bytes + return estimated_window_bytes <= int(self._sequence_direct_grad_target_bytes_for_state()), current_state + + def _should_use_direct_grad_reduced_sequence( + self, + hidden_seq: torch.Tensor, + state: MaybeState, + *, + materialize_final_state: bool = True, + output_boundary: Literal["sequence", "terminal"] = "sequence", + ) -> tuple[bool, TensorDictBase | None]: + current_state = self.runtime._ensure_state( + state, + batch=hidden_seq.shape[0], + device=hidden_seq.device, + dtype=hidden_seq.dtype, + ) + if not self._active_cell_supports_direct_grad_sequence(): + return False, current_state + state_bytes = self._estimate_sequence_state_bytes(current_state) + estimated_state_step_bytes = ( + 0 if state_bytes <= 0 else int(math.ceil(state_bytes * self._sequence_checkpoint_overhead_factor())) + ) + estimated_state_window_bytes = int(hidden_seq.shape[1]) * estimated_state_step_bytes + estimated_output_window_bytes = self._estimate_sequence_output_window_bytes( + hidden_seq, + output_boundary=output_boundary, + ) + direct_target_bytes = int(self._sequence_direct_grad_target_bytes_for_state()) + estimated_window_bytes = estimated_state_window_bytes + estimated_output_window_bytes + return estimated_window_bytes <= direct_target_bytes, current_state + + def _sequence_direct_grad_target_bytes_for_state(self) -> int: + policy = self._active_cell_sequence_memory_policy() + return int(policy["direct_grad_target_bytes"]) + + def _sequence_reduction_checkpoint_chunk_len( + self, + hidden_seq: torch.Tensor, + state: TensorDictBase, + *, + output_boundary: Literal["sequence", "terminal"] = "sequence", + ) -> int: + checkpoint_chunk_len = self._sequence_checkpoint_chunk_len(hidden_seq, state) + if output_boundary == "terminal": + return checkpoint_chunk_len + seq_len = int(hidden_seq.shape[1]) + if seq_len <= 1: + return seq_len + per_step_output_bytes = self._estimate_sequence_output_window_bytes( + hidden_seq[:, :1], + output_boundary="sequence", + ) + if per_step_output_bytes <= 0: + return checkpoint_chunk_len + target_bytes = max(1, int(self._sequence_checkpoint_target_bytes_for_state())) + output_chunk_len = max(1, target_bytes // per_step_output_bytes) + return 
min(seq_len, checkpoint_chunk_len, output_chunk_len) + + +def build(spec: Spec) -> nn.Module: + if spec.config.input_adapters is None and spec.config.output_adapters is None: + return Runtime(spec) + return Model(spec) + + +def _node_position_lookup(nodes: torch.Tensor) -> dict[int, int]: + return {int(node): idx for idx, node in enumerate(nodes.tolist())} + + +def _detach_tensordict(td: TensorDictBase) -> TensorDict: + out = TensorDict({}, batch_size=td.batch_size, device=td.device) + for key, value in td.items(): + if isinstance(value, TensorDictBase): + out[str(key)] = _detach_tensordict(value) + elif torch.is_tensor(value): + out[str(key)] = value.detach() + else: + out[str(key)] = value + return out + + +def _slice_batch_value(value: Any, start: int, end: int) -> Any: + if value is None: + return None + if torch.is_tensor(value): + return value[start:end] + if isinstance(value, TensorDictBase): + return value[start:end] + if isinstance(value, dict): + return {key: _slice_batch_value(item, start, end) for key, item in value.items()} + return value + + +def _cat_batch_value(values: list[Any]) -> Any: + active = [value for value in values if value is not None] + if not active: + return None + first = active[0] + if torch.is_tensor(first): + return torch.cat(cast(list[torch.Tensor], active), dim=0) + if isinstance(first, TensorDictBase): + if len(first.batch_size) == 0: + return TensorDict({}, batch_size=[]) + return torch.cat(cast(list[TensorDictBase], active), dim=0) + if isinstance(first, dict): + return {key: _cat_batch_value([value[key] for value in active]) for key in first} + return first + + +def _new_batch_like(value: Any, *, batch_size: int) -> Any: + if torch.is_tensor(value): + return value.new_empty((int(batch_size), *value.shape[1:])) + if isinstance(value, dict): + return {key: _new_batch_like(item, batch_size=batch_size) for key, item in value.items()} + raise TypeError(f"Cannot allocate batched Fabric output for {type(value)!r}") + + +def _copy_batch_output(destination: Any, source: Any, start: int, end: int) -> None: + if torch.is_tensor(destination) and torch.is_tensor(source): + destination[int(start) : int(end)].copy_(source) + return + if isinstance(destination, dict) and isinstance(source, dict): + for key in destination: + _copy_batch_output(destination[key], source[key], start, end) + return + raise TypeError(f"Cannot copy Fabric output chunk {type(source)!r} into {type(destination)!r}") + + +def _slice_batch_reset( + resets: ResetMask | None, + start: int, + end: int, + *, + batch_size: int, + device: torch.device, +) -> torch.Tensor | None: + if resets is None: + return None + mask = torch.as_tensor(resets, device=device, dtype=torch.bool) + if mask.dim() == 1 and mask.shape[0] == batch_size: + return mask[start:end] + if mask.dim() == 2 and mask.shape[0] == batch_size: + return mask[start:end] + raise ValueError(f"resets must have shape [B] or [B,T], got {tuple(mask.shape)}") + + +def _slice_batch_k( + k: int | torch.Tensor | None, + start: int, + end: int, + *, + batch_size: int, + time_steps: int, +) -> int | torch.Tensor | None: + if k is None or isinstance(k, int): + return k + if k.dim() == 1 and k.shape[0] == batch_size: + return k[start:end] + if k.dim() == 2 and k.shape == (batch_size, time_steps): + return k[start:end] + raise ValueError(f"k must be int, [B], or [B,T], got shape {tuple(k.shape)}") + + +def _flatten_step_idx( + step_idx: int | torch.Tensor, + *, + batch_size: int, + time_steps: int, + device: torch.device, + dtype: torch.dtype, +) -> 
torch.Tensor: + if isinstance(step_idx, int): + return torch.full((batch_size * time_steps,), step_idx, device=device, dtype=dtype) + step_tensor = torch.as_tensor(step_idx, device=device, dtype=dtype) + if step_tensor.dim() == 1 and step_tensor.shape[0] == batch_size: + return step_tensor.view(batch_size, 1).expand(batch_size, time_steps).reshape(batch_size * time_steps) + if step_tensor.dim() == 2 and step_tensor.shape == (batch_size, time_steps): + return step_tensor.reshape(batch_size * time_steps) + raise ValueError(f"step_idx tensor must have shape [B] or [B,T], got {tuple(step_tensor.shape)}") + + +def _where_tensordict(mask: torch.Tensor, new_state: TensorDictBase, old_state: TensorDictBase) -> TensorDict: + out = TensorDict({}, batch_size=new_state.batch_size) + keys = set(new_state.keys()) | set(old_state.keys()) + for key in keys: + new_value = new_state.get(key) + old_value = old_state.get(key) + if isinstance(new_value, TensorDictBase) and isinstance(old_value, TensorDictBase): + out[key] = _where_tensordict(mask, new_value, old_value) + continue + if torch.is_tensor(new_value) and torch.is_tensor(old_value): + shape = tuple(mask.shape) + (1,) * (new_value.dim() - mask.dim()) + out[key] = torch.where(mask.view(shape), new_value, old_value) + continue + out[key] = new_value if new_value is not None else old_value + return out + + +def _population_buffer_name(population_name: str) -> str: + return f"_population_idx__{population_name}" + + +def _population_recurrent_buffer_name(population_name: str) -> str: + return f"_population_recurrent_idx__{population_name}" + + +def _shared_active_region_positions_buffer_name(population_name: str) -> str: + return f"_shared_active_region_positions__{population_name}" + + +def _shared_active_region_offsets_buffer_name(population_name: str) -> str: + return f"_shared_active_region_offsets__{population_name}" + + +def _shared_active_region_recurrent_idx_buffer_name(population_name: str) -> str: + return f"_shared_active_region_recurrent_idx__{population_name}" + + +def _select_receiver_tables( + neighbor_idx: torch.Tensor, + neighbor_valid: torch.Tensor, + edge_distance: torch.Tensor, + edge_delay: torch.Tensor, + receiver_idx: torch.Tensor, + sender_lookup: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + recv_neighbor_idx = neighbor_idx.index_select(0, receiver_idx) + recv_neighbor_valid = neighbor_valid.index_select(0, receiver_idx) + recv_edge_distance = edge_distance.index_select(0, receiver_idx) + recv_edge_delay = edge_delay.index_select(0, receiver_idx) + compact_idx = sender_lookup.index_select(0, recv_neighbor_idx.reshape(-1)).view_as(recv_neighbor_idx) + if bool((compact_idx[recv_neighbor_valid] < 0).any()): + raise ValueError("Receiver subset contains a sender outside the compact sender set") + compact_idx = torch.where(recv_neighbor_valid, compact_idx, torch.zeros_like(compact_idx)) + return compact_idx, recv_neighbor_valid, recv_edge_distance, recv_edge_delay + + +def _build_sparse_degree_grouping(neighbor_valid: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: + degree = neighbor_valid.to(dtype=torch.long).sum(dim=1) + receiver_order = torch.argsort(degree) + max_degree = int(neighbor_valid.shape[1]) + degree_counts = torch.bincount(degree, minlength=max_degree + 1) + degree_ptr = torch.zeros(max_degree + 2, dtype=torch.long) + degree_ptr[1:] = degree_counts.cumsum(dim=0) + positive_degree_buckets = int((degree_counts[1:] > 0).sum().item()) + return receiver_order.to(dtype=torch.long), 
degree_ptr, positive_degree_buckets + + +def _build_sender_reverse_table( + num_senders: int, + receiver_sender_idx: torch.Tensor, + receiver_valid: torch.Tensor, +) -> torch.Tensor: + reverse = torch.full((num_senders, receiver_sender_idx.shape[1]), -1, dtype=torch.int32) + receiver_idx, offset_idx = torch.nonzero(receiver_valid, as_tuple=True) + sender_idx = receiver_sender_idx[receiver_idx, offset_idx] + if bool((sender_idx < 0).any()): + raise ValueError("receiver_sender_idx must be non-negative on valid local edges") + if bool((reverse[sender_idx, offset_idx] >= 0).any()): + raise ValueError("local sender reverse table expects unique receiver per sender/offset") + reverse[sender_idx, offset_idx] = receiver_idx.to(torch.int32) + return reverse + + +def _contiguous_recurrent_sender_window( + *, + num_senders: int, + recurrent_sender_idx: torch.Tensor, + receiver_sender_idx: torch.Tensor, + receiver_valid: torch.Tensor, +) -> tuple[int, int, bool]: + sender_to_recurrent = torch.full((num_senders,), -1, dtype=torch.long) + sender_to_recurrent[recurrent_sender_idx.to(dtype=torch.long)] = torch.arange( + recurrent_sender_idx.numel(), dtype=torch.long + ) + valid_senders = receiver_sender_idx[receiver_valid.to(dtype=torch.bool)].to(dtype=torch.long) + valid_senders = valid_senders[(valid_senders >= 0) & (valid_senders < num_senders)] + if valid_senders.numel() == 0: + return 0, 0, False + recurrent_receivers = sender_to_recurrent.index_select(0, valid_senders) + recurrent_receivers = recurrent_receivers[recurrent_receivers >= 0] + if recurrent_receivers.numel() == 0: + return 0, 0, False + unique_receivers = torch.unique(recurrent_receivers, sorted=True) + start = int(unique_receivers[0].item()) + count = int(unique_receivers.numel()) + contiguous = bool((unique_receivers[-1] - unique_receivers[0] + 1).item() == count) + return start, count, contiguous + + +def _build_local_sender_table( + *, + receiver_coords: torch.Tensor, + sender_lookup: torch.Tensor, + local_offsets: torch.Tensor, + local_valid: torch.Tensor, + coord_shape: tuple[int, ...], + wrap: bool, +) -> torch.Tensor: + receiver_coords_long = receiver_coords.to(torch.long) + local_offsets_long = local_offsets.to(torch.long) + sender_table = torch.full(local_valid.shape, -1, dtype=torch.long) + target_coords = receiver_coords_long[:, None, :] + local_offsets_long[None, :, :] + for dim, size in enumerate(coord_shape): + if wrap: + target_coords[..., dim] = torch.remainder(target_coords[..., dim], size) + else: + target_coords[..., dim].clamp_(0, size - 1) + target_flat = target_coords[..., 0] + for dim, size in enumerate(coord_shape[1:], start=1): + target_flat = target_flat * size + target_coords[..., dim] + sender_table[local_valid] = sender_lookup.index_select(0, target_flat[local_valid]) + if bool((sender_table[local_valid] < 0).any()): + raise ValueError("Local receiver subset contains a sender outside the compact sender set") + return sender_table + + +def _detect_uniform_contiguous_groups(group_ids: torch.Tensor) -> tuple[torch.Tensor | None, int]: + if group_ids.numel() == 0: + return None, 0 + if not bool(torch.all(group_ids[1:] >= group_ids[:-1])): + return None, 0 + unique_ids, counts = torch.unique_consecutive(group_ids, return_counts=True) + if unique_ids.numel() == 0 or not bool(torch.all(counts == counts[0])): + return None, 0 + group_size = int(counts[0].item()) + if group_size <= 1: + return None, 0 + return unique_ids.to(dtype=torch.long), group_size + + +def _contiguous_cpu_index_range(indices: torch.Tensor | 
None) -> tuple[int, int] | None: + if not torch.is_tensor(indices) or indices.dim() != 1: + return None + count = int(indices.numel()) + if count == 0: + return 0, 0 + if indices.is_cuda: + raise ValueError("contiguous group-range detection expects CPU construction-time indices") + start = int(indices[0]) + expected = torch.arange(start, start + count, dtype=torch.long) + if torch.equal(indices.to(dtype=torch.long), expected): + return start, count + return None + + +__all__ = ["Model", "Runtime", "build"] diff --git a/benchmarks/README.md b/benchmarks/README.md index b13fb525..3a37988b 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -13,6 +13,12 @@ uv run python benchmarks/run.py [--device cuda] [--warmup N] [-- Keys: `rtu`, `slstm`, `mlstm`, `lstm`, `conv1d`, `axons`, `linear_vs_axon`, `fabric`. +Fabric audit experiments use the Fabric package entrypoint only: + +``` +uv run python -m benchmarks.fabric.run_audit --help +``` + ### RTU (streaming diag) ```bash diff --git a/benchmarks/fabric/__init__.py b/benchmarks/fabric/__init__.py new file mode 100644 index 00000000..3c736276 --- /dev/null +++ b/benchmarks/fabric/__init__.py @@ -0,0 +1,3 @@ +"""Fabric benchmark and audit entrypoints.""" + +from . import benchmark as _benchmark # noqa: F401 diff --git a/benchmarks/fabric/audit.py b/benchmarks/fabric/audit.py new file mode 100644 index 00000000..b3384047 --- /dev/null +++ b/benchmarks/fabric/audit.py @@ -0,0 +1,1095 @@ +from __future__ import annotations + +import argparse +import json +import sys +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Literal + +import torch + +_REPO_ROOT = Path(__file__).resolve().parents[2] +for _path in (_REPO_ROOT, _REPO_ROOT / "src"): + if str(_path) not in sys.path: + sys.path.insert(0, str(_path)) + +from benchmarks.fabric.suite_common import ( # noqa: E402 + BackboneFamily, + PopulationMode, + SequenceMode, + find_param_matched_backbone, + find_param_matched_mixed_fabric_backbone, + find_param_matched_mixed_stack_backbone, + resolve_backbone_benchmark_dtype, + run_mixed_fabric_sequence_case, + run_mixed_stack_sequence_case, + run_sequence_case, +) + +AuditPlan = Literal["smoke", "t1-single-pop", "tk-scaling", "full"] +TrainingOutputBoundary = Literal["terminal", "sequence"] +ResetMode = Literal["absent", "present"] +MIXED_STACK_PARAM_ERROR_BAND = 0.05 +REGISTERED_TEMPORAL_FORWARD_PLAN_OWNERS = { + "registered_fused_forward_program_cuda", +} +REGISTERED_TEMPORAL_FORWARD_RUNTIME_OWNERS = { + "registered_fused_forward_program_cuda", +} +REGISTERED_TEMPORAL_FORWARD_IMPLEMENTATIONS = { + "registered_temporal_fused_forward_program_cuda", +} +REGISTERED_TEMPORAL_BACKWARD_PLAN_OWNERS = { + "registered_reverse_executor_bindings", +} +REGISTERED_TEMPORAL_BACKWARD_RUNTIME_EXECUTORS = { + "physical_temporal_bucket_sequence_backward", + "registered_reverse_executor_bindings", + "cuda_temporal_backward_glue", +} +FORBIDDEN_TEMPORAL_BACKWARD_TIMING_OWNERS = { + "artifact.recompute.cuda_replay_input_projection", + "artifact.recompute.cuda_temporal_replay_scan", + "artifact.recompute.input_projection_sequence", + "temporal_artifact_recompute", + "transition_message_reverse_table_device_loop", +} +FORBIDDEN_TEMPORAL_RECOMPUTE_MARKERS = ( + "open_cuda_temporal_superop", + "python_autograd_scan", + "python_step_replay", +) + +APRIL21_BASELINE_JSON = _REPO_ROOT / "audits/fabric/2026-04-21/fabric_physical_backend_final_results_2026-04-21.json" + + +@dataclass(frozen=True) +class April21Reference: + 
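+ """One April 21 comparison entry loaded from the baseline JSON; `tokens_per_s` and
+ `peak_mem_gib` may be None when the source row does not report them.""" +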
key: str + source: str + reference_kind: Literal["summary_floor", "exact_streaming_row", "informational"] + tokens_per_s: float | None + peak_mem_gib: float | None + status: str + row: dict[str, Any] + + +@dataclass(frozen=True) +class FabricAuditCase: + case_id: str + plan: AuditPlan + owner_stage: str + prompt_requirements: tuple[str, ...] + family: BackboneFamily + target_params: int + mode: SequenceMode + batch_size: int + seq_len: int + inner_steps: int + gradient_horizon_steps: int | None + checkpoint_steps: int | None + hidden_size: int + training_output_boundary: TrainingOutputBoundary + reset_mode: ResetMode + population_mode: PopulationMode + reference_key: str + high_level_api_contract: str = "model_forward_external_loss_backward_optimizer_step" + + +def load_april21_references(path: Path = APRIL21_BASELINE_JSON) -> dict[str, April21Reference]: + data = json.loads(path.read_text()) + references: dict[str, April21Reference] = {} + for row in data.get("audit_matrix", []): + key = str(row["audit"]) + references[key] = April21Reference( + key=key, + source=str(path), + reference_kind="summary_floor", + tokens_per_s=_optional_float(row.get("fabric_tokens_per_s")), + peak_mem_gib=_optional_float(row.get("fabric_peak_gib")), + status=str(row.get("status", "unknown")), + row=dict(row), + ) + for row in data.get("streaming_sequence_loss", []): + key = _streaming_reference_key( + family=str(row["family"]).lower(), + params_label=str(row["params_label"]).lower(), + batch=int(row["batch"]), + seq_len=int(row["seq_len"]), + hidden_size=int(row["hidden_size"]), + ) + references[key] = April21Reference( + key=key, + source=str(path), + reference_kind="exact_streaming_row", + tokens_per_s=_optional_float(row.get("fabric_tokens_per_s")), + peak_mem_gib=_optional_float(row.get("fabric_peak_gib")), + status="closed", + row=dict(row), + ) + return references + + +def build_case_manifest( + *, + plan: AuditPlan, + families: tuple[BackboneFamily, ...], + target_params: tuple[int, ...], + modes: tuple[SequenceMode, ...], + batches: tuple[int, ...], + seq_lens: tuple[int, ...], + inner_steps: tuple[int, ...], + hidden_sizes: tuple[int, ...], + gradient_horizon_steps: tuple[int | None, ...] = (None,), + checkpoint_steps: tuple[int | None, ...] = (None,), + reset_modes: tuple[ResetMode, ...] = ("absent",), + population_modes: tuple[PopulationMode, ...] = ("single",), + training_output_boundaries: tuple[TrainingOutputBoundary, ...] | None = None, +) -> list[FabricAuditCase]: + if plan == "full": + selected_plans: tuple[AuditPlan, ...] 
= ("t1-single-pop", "tk-scaling") + else: + selected_plans = (plan,) + cases: list[FabricAuditCase] = [] + for selected_plan in selected_plans: + if selected_plan == "smoke": + cases.extend(_smoke_cases(population_modes=population_modes)) + continue + for population_mode in population_modes: + for family in families: + for params in target_params: + for mode in modes: + output_boundaries = _training_output_boundaries_for_case( + selected_plan=selected_plan, + mode=mode, + requested=training_output_boundaries, + ) + for batch_size in batches: + for seq_len in seq_lens: + for k in inner_steps: + for horizon in gradient_horizon_steps: + for checkpoint in checkpoint_steps: + for hidden_size in hidden_sizes: + for reset_mode in reset_modes: + for training_output_boundary in output_boundaries: + if selected_plan == "t1-single-pop" and ( + seq_len != 1 or k != 1 + ): + continue + if selected_plan == "tk-scaling" and seq_len <= 1 and k <= 1: + continue + reference_key = select_reference_key( + family=family, + target_params=params, + mode=mode, + batch_size=batch_size, + seq_len=seq_len, + inner_steps=k, + hidden_size=hidden_size, + training_output_boundary=training_output_boundary, + ) + cases.append( + FabricAuditCase( + case_id=_case_id( + selected_plan, + family, + params, + mode, + batch_size, + seq_len, + k, + hidden_size, + gradient_horizon_steps=horizon, + checkpoint_steps=checkpoint, + training_output_boundary=training_output_boundary, + reset_mode=reset_mode, + population_mode=population_mode, + ), + plan=selected_plan, + owner_stage=_owner_stage_for_case( + selected_plan=selected_plan, + population_mode=population_mode, + ), + prompt_requirements=_prompt_requirements_for_case( + selected_plan=selected_plan, + population_mode=population_mode, + ), + family=family, + target_params=params, + mode=mode, + batch_size=batch_size, + seq_len=seq_len, + inner_steps=k, + gradient_horizon_steps=horizon, + checkpoint_steps=checkpoint, + hidden_size=hidden_size, + training_output_boundary=training_output_boundary, + reset_mode=reset_mode, + population_mode=population_mode, + reference_key=reference_key, + ) + ) + return cases + + +def _training_output_boundaries_for_case( + *, + selected_plan: AuditPlan, + mode: SequenceMode, + requested: tuple[TrainingOutputBoundary, ...] 
| None, +) -> tuple[TrainingOutputBoundary, ...]: + if requested is not None: + return requested + if selected_plan == "tk-scaling": + return ("sequence", "terminal") if mode == "forward_backward" else ("sequence",) + return ("terminal",) + + +def select_reference_key( + *, + family: BackboneFamily, + target_params: int, + mode: SequenceMode, + batch_size: int, + seq_len: int, + inner_steps: int, + hidden_size: int, + training_output_boundary: TrainingOutputBoundary, +) -> str: + params_label = _params_label(target_params) + if ( + mode == "forward_backward" + and training_output_boundary == "sequence" + and inner_steps == 1 + and hidden_size == 32 + and batch_size == 512 + and seq_len in {512, 4096} + and params_label in {"500m", "1b"} + ): + return _streaming_reference_key( + family=family, + params_label=params_label, + batch=batch_size, + seq_len=seq_len, + hidden_size=hidden_size, + ) + if hidden_size == 4: + return "h4_many_cell_stress" + if hidden_size == 8: + return "h8_many_cell_stress_focused_warmed_rerun" if family == "axoncell" else "h8_many_cell_stress_broad" + if hidden_size == 16: + return "h16_many_cell_stress" + if target_params <= 10_000_000: + return "h32_small_params_high_batch" + if seq_len > 1 or inner_steps > 1 or training_output_boundary == "sequence": + return "streaming_per_timestep_sequence_loss" + return "h32_t1_bxparams" + + +def _owner_stage_for_case(*, selected_plan: AuditPlan, population_mode: PopulationMode) -> str: + if selected_plan == "t1-single-pop": + return "R11" if population_mode == "single" else "R12" + return "R13" + + +def _prompt_requirements_for_case( + *, + selected_plan: AuditPlan, + population_mode: PopulationMode, +) -> tuple[str, ...]: + if selected_plan == "t1-single-pop": + if population_mode == "single": + return ("P0", "P3", "P6", "P7", "P11", "P19") + return ("P0", "P2", "P3", "P6", "P7", "P12", "P19") + requirements = ["P0", "P4", "P5", "P7", "P13", "P19"] + if population_mode == "mixed": + requirements.insert(1, "P2") + return tuple(requirements) + + +def run_audit(args: argparse.Namespace) -> int: + cases = build_case_manifest( + plan=args.plan, + families=_parse_families(args.families), + target_params=_parse_sizes(args.sizes), + modes=_parse_modes(args.modes), + batches=_parse_int_csv(args.batches), + seq_lens=_parse_int_csv(args.seq_lens), + inner_steps=_parse_int_csv(args.inner_steps), + gradient_horizon_steps=_parse_optional_int_csv(args.gradient_horizon_steps), + checkpoint_steps=_parse_optional_int_csv(args.checkpoint_steps), + hidden_sizes=_parse_int_csv(args.hidden_sizes), + reset_modes=_parse_reset_modes(args.reset_modes), + population_modes=_parse_population_modes(args.population_modes), + training_output_boundaries=_parse_training_output_boundaries(args.training_output_boundaries), + ) + cases = _slice_cases(cases, index=args.case_index, count=args.case_count) + if args.limit is not None: + cases = cases[: int(args.limit)] + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + _write_json(out_dir / "manifest.json", [asdict(case) for case in cases]) + shared_coverage_gate = _shared_temporal_coverage_gate(cases) + if args.require_shared_temporal_coverage and shared_coverage_gate["status"] == "fail": + _write_summary(out_dir, cases=cases, results=[], references={}, elapsed_s=0.0, args=args) + print( + json.dumps( + { + "status": "failed", + "reason": shared_coverage_gate["reason"], + "cases": len(cases), + "out_dir": str(out_dir), + }, + indent=2, + ) + ) + return 1 + if args.dry_run: + 
_write_summary(out_dir, cases=cases, results=[], references={}, elapsed_s=0.0, args=args)
+        print(json.dumps({"status": "dry_run", "cases": len(cases), "out_dir": str(out_dir)}, indent=2))
+        return 0
+
+    references = load_april21_references(Path(args.baseline_json))
+    started_at = time.time()
+    results = []
+    for case in cases:
+        result = run_case(case, references=references, args=args)
+        results.append(result)
+        _append_jsonl(out_dir / "cases.jsonl", result)
+    _write_summary(
+        out_dir, cases=cases, results=results, references=references, elapsed_s=time.time() - started_at, args=args
+    )
+    failed = [result for result in results if result.get("gate", {}).get("status") == "fail"]
+    print(
+        json.dumps({"status": "failed" if failed else "ok", "cases": len(results), "out_dir": str(out_dir)}, indent=2)
+    )
+    return 1 if failed and (args.enforce_references or args.require_cuda_temporal_owner) else 0
+
+
+def run_case(
+    case: FabricAuditCase,
+    *,
+    references: dict[str, April21Reference],
+    args: argparse.Namespace,
+) -> dict[str, Any]:
+    device = torch.device(args.device)
+    requested_dtype = _parse_dtype(args.dtype)
+    dtype = resolve_backbone_benchmark_dtype(device=device, requested_dtype=requested_dtype)
+    resets = _make_reset_mask(
+        case.reset_mode,
+        batch_size=case.batch_size,
+        seq_len=case.seq_len,
+        device=device,
+    )
+    if case.population_mode == "mixed":
+        match = find_param_matched_mixed_fabric_backbone(
+            target_params=case.target_params,
+            family=case.family,
+            fabric_hidden_grid=(case.hidden_size,),
+        )
+        result = run_mixed_fabric_sequence_case(
+            match=match,
+            mode=case.mode,
+            batch_size=case.batch_size,
+            seq_len=case.seq_len,
+            device=device,
+            dtype=dtype,
+            warmup=int(args.warmup),
+            iterations=int(args.iterations),
+            training_output_boundary=case.training_output_boundary,
+            inner_steps=case.inner_steps,
+            gradient_horizon_steps=case.gradient_horizon_steps,
+            checkpoint_steps=case.checkpoint_steps,
+            resets=resets,
+        )
+        if _case_requires_k_adjusted_t1_baseline(case):
+            baseline_resets = _make_reset_mask(
+                case.reset_mode,
+                batch_size=case.batch_size,
+                seq_len=1,
+                device=device,
+            )
+            t1_result = run_mixed_fabric_sequence_case(
+                match=match,
+                mode="forward_backward",
+                batch_size=case.batch_size,
+                seq_len=1,
+                device=device,
+                dtype=dtype,
+                warmup=int(args.warmup),
+                iterations=int(args.iterations),
+                training_output_boundary=case.training_output_boundary,
+                inner_steps=1,
+                gradient_horizon_steps=None,
+                checkpoint_steps=None,
+                resets=baseline_resets,
+            )
+            _attach_k_adjusted_t1_baseline(result, t1_result, case)
+        if _case_requires_mixed_stack_baseline(case):
+            stack_match = find_param_matched_mixed_stack_backbone(
+                target_params=int(result.get("actual_params", match.actual_params)),
+                family=case.family,
+            )
+            stack_result = run_mixed_stack_sequence_case(
+                match=stack_match,
+                mode=case.mode,
+                batch_size=case.batch_size,
+                seq_len=case.seq_len,
+                device=device,
+                dtype=dtype,
+                warmup=int(args.warmup),
+                iterations=int(args.iterations),
+                training_output_boundary=case.training_output_boundary,
+                forward_output_boundary=case.training_output_boundary,
+                resets=resets,
+            )
+            _attach_mixed_stack_baseline(result, stack_result)
+    else:
+        match = find_param_matched_backbone(
+            target_params=case.target_params,
+            kind="fabric",
+            family=case.family,
+            fabric_hidden_grid=(case.hidden_size,),
+        )
+        result = run_sequence_case(
+            
match=match, + mode=case.mode, + batch_size=case.batch_size, + seq_len=case.seq_len, + family=case.family, + device=device, + dtype=dtype, + warmup=int(args.warmup), + iterations=int(args.iterations), + training_output_boundary=case.training_output_boundary, + inner_steps=case.inner_steps, + gradient_horizon_steps=case.gradient_horizon_steps, + checkpoint_steps=case.checkpoint_steps, + resets=resets, + ) + if _case_requires_k_adjusted_t1_baseline(case): + baseline_resets = _make_reset_mask( + case.reset_mode, + batch_size=case.batch_size, + seq_len=1, + device=device, + ) + t1_result = run_sequence_case( + match=match, + mode="forward_backward", + batch_size=case.batch_size, + seq_len=1, + family=case.family, + device=device, + dtype=dtype, + warmup=int(args.warmup), + iterations=int(args.iterations), + training_output_boundary=case.training_output_boundary, + inner_steps=1, + gradient_horizon_steps=None, + checkpoint_steps=None, + resets=baseline_resets, + ) + _attach_k_adjusted_t1_baseline(result, t1_result, case) + reference = references.get(case.reference_key) + return { + "case": asdict(case), + "reference": asdict(reference) if reference is not None else None, + "result": result, + "gate": compare_to_reference( + result, + reference, + case=case, + enforce=bool(args.enforce_references), + require_cuda_temporal_owner=bool(args.require_cuda_temporal_owner), + ), + "api_proof": { + "high_level_api_contract": case.high_level_api_contract, + "private_runtime_or_planner_calls": False, + "benchmark_owned_temporal_tiling": False, + }, + } + + +def compare_to_reference( + result: dict[str, Any], + reference: April21Reference | None, + *, + case: FabricAuditCase, + enforce: bool, + require_cuda_temporal_owner: bool = False, +) -> dict[str, Any]: + if reference is None: + return {"status": "missing_reference", "enforced": enforce} + if result.get("status") != "ok": + return {"status": "fail" if enforce else "not_run_ok", "reason": result.get("error"), "enforced": enforce} + if require_cuda_temporal_owner: + owner_gate = _cuda_temporal_owner_gate(result) + if owner_gate is not None: + return {"status": "fail", "enforced": True, **owner_gate} + tokens = _optional_float(result.get("tokens_per_s")) + peak = _optional_float(result.get("peak_mem_gib")) + k_adjusted_gate = _k_adjusted_t1_baseline_gate(result, case) + reference_tokens = ( + _optional_float(k_adjusted_gate.get("floor_tokens_per_s")) + if k_adjusted_gate is not None + else reference.tokens_per_s + ) + token_ok = reference_tokens is None or (tokens is not None and tokens >= reference_tokens) + memory_ok = reference.peak_mem_gib is None or (peak is not None and peak <= reference.peak_mem_gib) + mixed_stack_gate = _mixed_stack_baseline_gate(result) + if enforce and k_adjusted_gate is not None and k_adjusted_gate.get("status") == "fail": + return { + "status": "fail", + "reason": k_adjusted_gate.get("reason", "k_adjusted_t1_training_floor_failed"), + "tokens_per_s": tokens, + "reference_tokens_per_s": reference_tokens, + "april21_reference_tokens_per_s": reference.tokens_per_s, + "k_adjusted_t1_gate": k_adjusted_gate, + "enforced": True, + } + if enforce and not token_ok: + return { + "status": "fail", + "reason": "tokens_per_s_below_k_adjusted_t1_training_floor" + if k_adjusted_gate is not None + else "tokens_per_s_below_april21_reference", + "tokens_per_s": tokens, + "reference_tokens_per_s": reference_tokens, + "april21_reference_tokens_per_s": reference.tokens_per_s, + "k_adjusted_t1_gate": k_adjusted_gate, + "enforced": True, + } + if 
enforce and not memory_ok: + return { + "status": "fail", + "reason": "peak_memory_above_april21_reference", + "peak_mem_gib": peak, + "reference_peak_mem_gib": reference.peak_mem_gib, + "enforced": True, + } + if enforce and mixed_stack_gate is not None and mixed_stack_gate["status"] == "fail": + return { + "status": "fail", + "reason": mixed_stack_gate["reason"], + "mixed_stack_gate": mixed_stack_gate, + "enforced": True, + } + return { + "status": "pass" if enforce else "informational", + "tokens_per_s": tokens, + "reference_tokens_per_s": reference_tokens, + "april21_reference_tokens_per_s": reference.tokens_per_s, + "peak_mem_gib": peak, + "reference_peak_mem_gib": reference.peak_mem_gib, + "reference_kind": reference.reference_kind, + "mixed_stack_gate": mixed_stack_gate, + "k_adjusted_t1_gate": k_adjusted_gate, + "enforced": enforce, + } + + +def _attach_mixed_stack_baseline(result: dict[str, Any], stack_result: dict[str, object]) -> None: + result["mixed_stack_baseline"] = stack_result + result["mixed_stack_actual_params"] = stack_result.get("actual_params") + result["mixed_stack_status"] = stack_result.get("status") + fabric_tokens = _optional_float(result.get("tokens_per_s")) + stack_tokens = _optional_float(stack_result.get("tokens_per_s")) + fabric_params = result.get("actual_params") + stack_params = stack_result.get("actual_params") + if isinstance(fabric_params, int | float) and isinstance(stack_params, int | float): + result["mixed_stack_param_error"] = (int(fabric_params) - int(stack_params)) / max(1, int(fabric_params)) + if fabric_tokens is not None and stack_tokens is not None and stack_tokens > 0: + result["mixed_fabric_stack_ratio"] = fabric_tokens / stack_tokens + + +def _case_requires_mixed_stack_baseline(case: FabricAuditCase) -> bool: + return case.population_mode == "mixed" and case.seq_len == 1 and case.inner_steps == 1 + + +def _case_requires_k_adjusted_t1_baseline(case: FabricAuditCase) -> bool: + return case.plan == "tk-scaling" and case.mode == "forward_backward" and int(case.inner_steps) > 1 + + +def _attach_k_adjusted_t1_baseline( + result: dict[str, Any], + t1_result: dict[str, object], + case: FabricAuditCase, +) -> None: + result["matched_t1_training_baseline"] = t1_result + result["matched_t1_training_case"] = { + "population_mode": case.population_mode, + "family": case.family, + "target_params": case.target_params, + "batch_size": case.batch_size, + "seq_len": 1, + "inner_steps": 1, + "hidden_size": case.hidden_size, + "training_output_boundary": case.training_output_boundary, + "reset_mode": case.reset_mode, + } + t1_tokens = _optional_float(t1_result.get("tokens_per_s")) + if t1_tokens is not None: + result["matched_t1_training_tokens_per_s"] = t1_tokens + result["k_adjusted_t1_floor_tokens_per_s"] = t1_tokens / max(1, int(case.inner_steps)) + result["k_adjusted_t1_floor_divisor"] = int(case.inner_steps) + + +def _k_adjusted_t1_baseline_gate(result: dict[str, Any], case: FabricAuditCase) -> dict[str, Any] | None: + if not _case_requires_k_adjusted_t1_baseline(case): + return None + baseline = result.get("matched_t1_training_baseline") + if not isinstance(baseline, dict): + return { + "status": "fail", + "reason": "missing_matched_t1_training_baseline", + "divisor": int(case.inner_steps), + } + if baseline.get("status") != "ok": + return { + "status": "fail", + "reason": "matched_t1_training_baseline_not_ok", + "baseline_status": baseline.get("status"), + "baseline_error": baseline.get("error"), + "divisor": int(case.inner_steps), + } + tokens = 
_optional_float(result.get("tokens_per_s")) + t1_tokens = _optional_float(result.get("matched_t1_training_tokens_per_s")) + floor = _optional_float(result.get("k_adjusted_t1_floor_tokens_per_s")) + if t1_tokens is None or floor is None: + return { + "status": "fail", + "reason": "missing_matched_t1_training_tokens_per_s", + "divisor": int(case.inner_steps), + } + passed = tokens is not None and tokens >= floor + return { + "status": "pass" if passed else "fail", + "reason": None if passed else "tokens_per_s_below_matched_t1_divided_by_k", + "tokens_per_s": tokens, + "matched_t1_training_tokens_per_s": t1_tokens, + "floor_tokens_per_s": floor, + "divisor": int(case.inner_steps), + "ratio_to_floor": None if tokens is None or floor <= 0 else tokens / floor, + } + + +def _mixed_stack_baseline_gate(result: dict[str, Any]) -> dict[str, Any] | None: + if "mixed_stack_baseline" not in result: + return None + stack_status = result.get("mixed_stack_status") + if stack_status != "ok": + return { + "status": "fail", + "reason": "mixed_stack_baseline_not_ok", + "mixed_stack_status": stack_status, + } + ratio = _optional_float(result.get("mixed_fabric_stack_ratio")) + param_error = _optional_float(result.get("mixed_stack_param_error")) + if param_error is None: + return { + "status": "fail", + "reason": "missing_mixed_stack_param_error", + } + if abs(param_error) > MIXED_STACK_PARAM_ERROR_BAND: + return { + "status": "fail", + "reason": "mixed_stack_params_not_matched", + "mixed_stack_param_error": param_error, + "mixed_stack_param_error_band": MIXED_STACK_PARAM_ERROR_BAND, + } + if ratio is None: + return { + "status": "fail", + "reason": "missing_mixed_fabric_stack_ratio", + } + if ratio <= 1.0: + return { + "status": "fail", + "reason": "mixed_fabric_tokens_not_above_stack", + "mixed_fabric_stack_ratio": ratio, + } + return { + "status": "pass", + "mixed_fabric_stack_ratio": ratio, + "mixed_stack_param_error": param_error, + } + + +def _cuda_temporal_owner_gate(result: dict[str, Any]) -> dict[str, Any] | None: + planner_signature = result.get("planner_signature") + if not isinstance(planner_signature, dict): + return {"reason": "missing_planner_signature_for_temporal_owner_gate"} + forward_owners = tuple(str(owner) for owner in planner_signature.get("temporal_plan_forward_owners", ())) + backward_owners = tuple(str(owner) for owner in planner_signature.get("temporal_plan_backward_owners", ())) + if not forward_owners or any(owner not in REGISTERED_TEMPORAL_FORWARD_PLAN_OWNERS for owner in forward_owners): + return { + "reason": "forward_temporal_owner_not_registered_program", + "temporal_plan_forward_owners": forward_owners, + "accepted_forward_owners": tuple(sorted(REGISTERED_TEMPORAL_FORWARD_PLAN_OWNERS)), + } + runtime_forward_owners = tuple(str(owner) for owner in planner_signature.get("launch_temporal_scan_owners", ())) + if not runtime_forward_owners or any( + owner not in REGISTERED_TEMPORAL_FORWARD_RUNTIME_OWNERS for owner in runtime_forward_owners + ): + return { + "reason": "runtime_forward_temporal_owner_not_registered_program", + "launch_temporal_scan_owners": runtime_forward_owners, + "accepted_runtime_forward_owners": tuple(sorted(REGISTERED_TEMPORAL_FORWARD_RUNTIME_OWNERS)), + } + launch_scan_implementations = tuple( + str(implementation) for implementation in planner_signature.get("launch_scan_implementations", ()) + ) + if not any( + implementation in REGISTERED_TEMPORAL_FORWARD_IMPLEMENTATIONS for implementation in launch_scan_implementations + ): + return { + "reason": 
"runtime_forward_temporal_implementation_not_registered_program", + "launch_scan_implementations": launch_scan_implementations, + "accepted_forward_implementations": tuple(sorted(REGISTERED_TEMPORAL_FORWARD_IMPLEMENTATIONS)), + } + primitive_executor_blockers = tuple( + str(blocker) for blocker in planner_signature.get("temporal_primitive_executor_blockers", ()) + ) + if primitive_executor_blockers: + return { + "reason": "temporal_primitive_executor_blockers_present", + "temporal_primitive_executor_blockers": primitive_executor_blockers, + } + mode = str(result.get("mode", "")) + if mode == "forward_backward" and ( + not backward_owners or any(owner not in REGISTERED_TEMPORAL_BACKWARD_PLAN_OWNERS for owner in backward_owners) + ): + return { + "reason": "backward_temporal_owner_not_registered_program", + "temporal_plan_backward_owners": backward_owners, + "accepted_backward_owners": tuple(sorted(REGISTERED_TEMPORAL_BACKWARD_PLAN_OWNERS)), + } + if mode == "forward_backward": + backward_executors = tuple( + str(executor) for executor in planner_signature.get("backward_physical_op_executors", ()) + ) + if not any(executor in REGISTERED_TEMPORAL_BACKWARD_RUNTIME_EXECUTORS for executor in backward_executors): + return { + "reason": "runtime_backward_temporal_executor_not_registered_program", + "backward_physical_op_executors": backward_executors, + "accepted_backward_executors": tuple(sorted(REGISTERED_TEMPORAL_BACKWARD_RUNTIME_EXECUTORS)), + } + timing_entries = tuple(str(entry) for entry in planner_signature.get("backward_owner_timing_ms", ())) + forbidden_timing = tuple( + entry.split(":ms=", 1)[0] + for entry in timing_entries + if entry.split(":ms=", 1)[0] in FORBIDDEN_TEMPORAL_BACKWARD_TIMING_OWNERS + ) + if forbidden_timing: + return { + "reason": "forbidden_backward_temporal_owner_timing_present", + "backward_owner_timing_ms": timing_entries, + "forbidden_backward_temporal_owners": forbidden_timing, + } + recompute_entries = tuple(str(entry) for entry in planner_signature.get("backward_recompute_mode", ())) + forbidden_recompute = tuple( + entry + for entry in recompute_entries + if any(marker in entry for marker in FORBIDDEN_TEMPORAL_RECOMPUTE_MARKERS) + ) + if forbidden_recompute: + return { + "reason": "forbidden_backward_temporal_recompute_mode_present", + "backward_recompute_mode": recompute_entries, + "forbidden_backward_recompute_mode": forbidden_recompute, + } + return None + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Canonical Fabric audit runner using the high-level Fabric API.") + parser.add_argument("--plan", choices=("smoke", "t1-single-pop", "tk-scaling", "full"), default="smoke") + parser.add_argument("--out-dir", default="audits/fabric/redo_fixmaass/current") + parser.add_argument("--baseline-json", default=str(APRIL21_BASELINE_JSON)) + parser.add_argument("--families", default="slstm,axoncell") + parser.add_argument("--sizes", default="100m,500m,1b") + parser.add_argument("--modes", default="forward,forward_backward") + parser.add_argument("--batches", default="1024,16384") + parser.add_argument("--seq-lens", default="1") + parser.add_argument("--inner-steps", default="1") + parser.add_argument("--gradient-horizon-steps", default="none") + parser.add_argument("--checkpoint-steps", default="none") + parser.add_argument("--hidden-sizes", default="32") + parser.add_argument( + "--training-output-boundaries", + default="auto", + help="auto, terminal, sequence, or comma-separated terminal,sequence; auto includes both 
for T*K training", + ) + parser.add_argument("--reset-modes", default="absent") + parser.add_argument("--population-modes", default="single") + parser.add_argument("--device", default="cuda") + parser.add_argument("--dtype", default="float32") + parser.add_argument("--warmup", type=int, default=1) + parser.add_argument("--iterations", type=int, default=3) + parser.add_argument("--limit", type=int) + parser.add_argument("--case-index", type=int, default=0) + parser.add_argument("--case-count", type=int, default=1) + parser.add_argument("--dry-run", action="store_true") + parser.add_argument("--enforce-references", action="store_true") + parser.add_argument("--require-cuda-temporal-owner", action="store_true") + parser.add_argument("--require-shared-temporal-coverage", action="store_true") + return parser + + +def _smoke_cases(*, population_modes: tuple[PopulationMode, ...]) -> list[FabricAuditCase]: + return [ + FabricAuditCase( + case_id=f"smoke_slstm_1m_forward_b1_t1_k1_h32_pop{population_mode}", + plan="smoke", + owner_stage="R8" if population_mode == "single" else "R12", + prompt_requirements=("P0", "P7", "P11") if population_mode == "single" else ("P0", "P2", "P7", "P12"), + family="slstm", + target_params=1_000_000, + mode="forward", + batch_size=1, + seq_len=1, + inner_steps=1, + gradient_horizon_steps=None, + checkpoint_steps=None, + hidden_size=32, + training_output_boundary="terminal", + reset_mode="absent", + population_mode=population_mode, + reference_key="h32_small_params_high_batch", + ) + for population_mode in population_modes + ] + + +def _write_summary( + out_dir: Path, + *, + cases: list[FabricAuditCase], + results: list[dict[str, Any]], + references: dict[str, April21Reference], + elapsed_s: float, + args: argparse.Namespace, +) -> None: + gates = [result.get("gate", {}).get("status") for result in results] + summary = { + "audit": "redo_fixmaass_fabric", + "plan": args.plan, + "baseline_json": str(args.baseline_json), + "case_count": len(cases), + "completed_count": len(results), + "elapsed_s": elapsed_s, + "gate_counts": {str(status): gates.count(status) for status in sorted(set(gates), key=str)}, + "reference_keys": sorted({case.reference_key for case in cases}), + "references_loaded": len(references), + "high_level_api_only": True, + "reset_modes": sorted({case.reset_mode for case in cases}), + "training_output_boundaries": sorted({case.training_output_boundary for case in cases}), + "population_modes": sorted({case.population_mode for case in cases}), + "case_index": int(args.case_index), + "case_count_shards": int(args.case_count), + "enforce_references": bool(args.enforce_references), + "require_cuda_temporal_owner": bool(args.require_cuda_temporal_owner), + "require_shared_temporal_coverage": bool(args.require_shared_temporal_coverage), + "shared_temporal_coverage_gate": _shared_temporal_coverage_gate(cases), + } + _write_json(out_dir / "summary.json", summary) + + +def _shared_temporal_coverage_gate(cases: list[FabricAuditCase]) -> dict[str, Any]: + population_modes = sorted({case.population_mode for case in cases}) + missing = sorted({"single", "mixed"} - set(population_modes)) + if missing: + return { + "status": "fail", + "reason": "shared_temporal_owner_requires_single_and_mixed_population_coverage", + "population_modes": population_modes, + "missing_population_modes": missing, + } + return { + "status": "pass", + "population_modes": population_modes, + "missing_population_modes": [], + } + + +def _slice_cases(cases: list[FabricAuditCase], *, index: int, 
count: int) -> list[FabricAuditCase]: + if count <= 0: + raise ValueError("--case-count must be positive") + if index < 0 or index >= count: + raise ValueError("--case-index must be in [0, case-count)") + return [case for offset, case in enumerate(cases) if offset % count == index] + + +def _case_id( + plan: AuditPlan, + family: BackboneFamily, + target_params: int, + mode: SequenceMode, + batch_size: int, + seq_len: int, + inner_steps: int, + hidden_size: int, + *, + gradient_horizon_steps: int | None, + checkpoint_steps: int | None, + training_output_boundary: TrainingOutputBoundary, + reset_mode: ResetMode, + population_mode: PopulationMode, +) -> str: + horizon_label = "none" if gradient_horizon_steps is None else str(int(gradient_horizon_steps)) + checkpoint_label = "planner" if checkpoint_steps is None else str(int(checkpoint_steps)) + return ( + f"{plan}_{family}_{_params_label(target_params)}_{mode}_b{batch_size}_t{seq_len}_k{inner_steps}" + f"_h{hidden_size}_gh{horizon_label}_ck{checkpoint_label}_loss{training_output_boundary}" + f"_pop{population_mode}_reset{reset_mode}" + ) + + +def _streaming_reference_key( + *, + family: str, + params_label: str, + batch: int, + seq_len: int, + hidden_size: int, +) -> str: + return f"streaming_sequence_loss:{family}:{params_label}:b{batch}:t{seq_len}:h{hidden_size}" + + +def _parse_families(value: str) -> tuple[BackboneFamily, ...]: + families = tuple(item.strip().lower() for item in value.split(",") if item.strip()) + invalid = [family for family in families if family not in {"slstm", "axoncell"}] + if invalid: + raise ValueError(f"Unsupported families: {invalid}") + return families # type: ignore[return-value] + + +def _parse_modes(value: str) -> tuple[SequenceMode, ...]: + modes = tuple(item.strip().lower() for item in value.split(",") if item.strip()) + invalid = [mode for mode in modes if mode not in {"forward", "forward_backward"}] + if invalid: + raise ValueError(f"Unsupported modes: {invalid}") + return modes # type: ignore[return-value] + + +def _parse_sizes(value: str) -> tuple[int, ...]: + return tuple(_parse_size(item.strip()) for item in value.split(",") if item.strip()) + + +def _parse_size(value: str) -> int: + normalized = value.lower().replace("_", "") + if normalized.endswith("b"): + return int(float(normalized[:-1]) * 1_000_000_000) + if normalized.endswith("m"): + return int(float(normalized[:-1]) * 1_000_000) + if normalized.endswith("k"): + return int(float(normalized[:-1]) * 1_000) + return int(normalized) + + +def _params_label(value: int) -> str: + if value % 1_000_000_000 == 0: + return f"{value // 1_000_000_000}b" + if value % 1_000_000 == 0: + return f"{value // 1_000_000}m" + if value % 1_000 == 0: + return f"{value // 1_000}k" + return str(value) + + +def _parse_int_csv(value: str) -> tuple[int, ...]: + return tuple(int(item.strip()) for item in value.split(",") if item.strip()) + + +def _parse_optional_int_csv(value: str) -> tuple[int | None, ...]: + parsed: list[int | None] = [] + for item in value.split(","): + normalized = item.strip().lower() + if not normalized: + continue + parsed.append(None if normalized in {"none", "null", "planner"} else int(normalized)) + return tuple(parsed) or (None,) + + +def _parse_reset_modes(value: str) -> tuple[ResetMode, ...]: + modes = tuple(item.strip().lower() for item in value.split(",") if item.strip()) + invalid = [mode for mode in modes if mode not in {"absent", "present"}] + if invalid: + raise ValueError(f"Unsupported reset modes: {invalid}") + if not modes: + raise 
ValueError("Expected at least one reset mode") + return modes # type: ignore[return-value] + + +def _parse_training_output_boundaries(value: str) -> tuple[TrainingOutputBoundary, ...] | None: + normalized_value = value.strip().lower() + if normalized_value in {"auto", "default"}: + return None + boundaries = tuple(item.strip().lower() for item in value.split(",") if item.strip()) + invalid = [boundary for boundary in boundaries if boundary not in {"terminal", "sequence"}] + if invalid: + raise ValueError(f"Unsupported training output boundaries: {invalid}") + if not boundaries: + raise ValueError("Expected at least one training output boundary") + return boundaries # type: ignore[return-value] + + +def _parse_population_modes(value: str) -> tuple[PopulationMode, ...]: + modes = tuple(item.strip().lower() for item in value.split(",") if item.strip()) + invalid = [mode for mode in modes if mode not in {"single", "mixed"}] + if invalid: + raise ValueError(f"Unsupported population modes: {invalid}") + if not modes: + raise ValueError("Expected at least one population mode") + return modes # type: ignore[return-value] + + +def _make_reset_mask( + reset_mode: ResetMode, + *, + batch_size: int, + seq_len: int, + device: torch.device, +) -> torch.Tensor | None: + if reset_mode == "absent": + return None + if seq_len <= 0: + raise ValueError("Reset-present Fabric audit cases require positive seq_len") + rows = torch.arange(int(batch_size), device=device).view(int(batch_size), 1) + steps = torch.arange(int(seq_len), device=device).view(1, int(seq_len)) + period = max(2, min(7, int(seq_len) + 1)) + resets = (rows + steps) % period == 1 + resets[:, 0] = torch.arange(int(batch_size), device=device) % 3 == 0 + return resets + + +def _parse_dtype(value: str) -> torch.dtype: + normalized = value.strip().lower().replace("torch.", "") + if normalized == "float32": + return torch.float32 + if normalized in {"bfloat16", "bf16"}: + return torch.bfloat16 + if normalized in {"float16", "fp16"}: + return torch.float16 + raise ValueError(f"Unsupported dtype {value!r}") + + +def _optional_float(value: Any) -> float | None: + return None if value is None else float(value) + + +def _write_json(path: Path, payload: Any) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n") + + +def _append_jsonl(path: Path, payload: Any) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("a", encoding="utf-8") as handle: + handle.write(json.dumps(payload, sort_keys=True) + "\n") diff --git a/benchmarks/fabric.py b/benchmarks/fabric/benchmark.py similarity index 99% rename from benchmarks/fabric.py rename to benchmarks/fabric/benchmark.py index 8a7b6a33..da430c09 100644 --- a/benchmarks/fabric.py +++ b/benchmarks/fabric/benchmark.py @@ -6,7 +6,7 @@ import torch from evaluations.stacks import build_axons_preup, build_slstm_postup -from .common import ( +from ..common import ( BenchmarkCase, BenchmarkDefinition, BenchmarkSettings, diff --git a/benchmarks/fabric/run_audit.py b/benchmarks/fabric/run_audit.py new file mode 100644 index 00000000..901edb4e --- /dev/null +++ b/benchmarks/fabric/run_audit.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from benchmarks.fabric.audit import build_parser, run_audit + + +def main() -> int: + return run_audit(build_parser().parse_args()) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/benchmarks/fabric_suite_common.py b/benchmarks/fabric/suite_common.py similarity 
index 62% rename from benchmarks/fabric_suite_common.py rename to benchmarks/fabric/suite_common.py index 6f7ae6d7..cd356c01 100644 --- a/benchmarks/fabric_suite_common.py +++ b/benchmarks/fabric/suite_common.py @@ -1,5 +1,6 @@ from __future__ import annotations +import gc import time from dataclasses import dataclass from functools import lru_cache @@ -17,6 +18,7 @@ BackboneFamily = Literal["slstm", "axoncell"] BackboneKind = Literal["stack", "fabric"] +PopulationMode = Literal["single", "mixed"] SequenceMode = Literal["forward", "forward_backward"] TrainingOutputBoundary = Literal["terminal", "sequence"] RolloutMode = Literal["forward", "online_grad"] @@ -36,6 +38,7 @@ ) D_HIDDEN_GRID = (64, 96, 128, 192, 256, 384, 512, 768, 1024, 1536, 2048) NUM_LAYERS_GRID = (2, 4, 8, 12, 16, 24, 32, 48, 64) +MIXED_STACK_NUM_LAYERS_GRID = tuple(range(1, 65)) def _build_fabric_shape_grid() -> tuple[tuple[int, int], ...]: @@ -110,15 +113,9 @@ def forward( return output, next_state def grad_sequence_strategy(self, x: torch.Tensor) -> str | None: - strategy = getattr(self.backbone, "_should_use_direct_grad_sequence", None) - if callable(strategy): - try: - use_direct, _ = strategy(x, None, materialize_final_state=False) - except RuntimeError as exc: - if "out of memory" in str(exc).lower(): - return "unavailable_oom" - raise - return "direct" if use_direct else "checkpointed" + del x + if hasattr(self.backbone, "_forward_sequence_with_readout"): + return "shared_temporal_engine" return None def prime_rollout_state(self, x: torch.Tensor) -> object | None: @@ -181,6 +178,16 @@ def build_stack_backbone(*, family: BackboneFamily, d_hidden: int, num_layers: i ) +def build_mixed_stack_backbone(*, d_hidden: int, num_layers: int) -> nn.Module: + return build_cortical_auto_stack( + d_hidden=d_hidden, + num_layers=num_layers, + layers=[[AxonCellConfig(), sLSTMCellConfig()] for _ in range(num_layers)], + post_norm=True, + compile_scaffolds=False, + ) + + def build_fabric_backbone( *, family: BackboneFamily, @@ -189,6 +196,8 @@ def build_fabric_backbone( height: int, hidden_size: int, inner_steps: int = 1, + gradient_horizon_steps: int | None = None, + checkpoint_steps: int | None = None, input_cell_indices: tuple[int, ...] | None = None, output_cell_indices: tuple[int, ...] | None = None, graph_edges: tuple[tuple[int, int], ...] | None = None, @@ -223,11 +232,106 @@ def build_fabric_backbone( head_dim=hidden_size, kv_sharing=fabric.message_rules.ShareBySenderTile(tile_shape=(1, 2)), ), - execution=fabric.ExecutionSpec(backend="auto", inner_steps=inner_steps), + execution=fabric.ExecutionSpec( + backend="auto", + inner_steps=inner_steps, + gradient_horizon_steps=gradient_horizon_steps, + checkpoint_steps=checkpoint_steps, + ), ) return fabric.compile(blueprint) +def build_mixed_fabric_backbone( + *, + d_hidden: int, + width: int, + height: int, + hidden_size: int, + inner_steps: int = 1, + gradient_horizon_steps: int | None = None, + checkpoint_steps: int | None = None, + input_cell_indices: tuple[int, ...] | None = None, + output_cell_indices: tuple[int, ...] | None = None, + graph_edges: tuple[tuple[int, int], ...] | None = None, + kv_group_ids: tuple[int, ...] 
| None = None, +) -> nn.Module: + slstm_nodes, axon_nodes = _mixed_population_node_indices( + width=width, + height=height, + input_cell_indices=input_cell_indices, + output_cell_indices=output_cell_indices, + ) + connectivity: list[fabric.graphs.lattice2d.LocalRadius | fabric.graphs.lattice2d.ExplicitEdges] = [ + fabric.graphs.lattice2d.LocalRadius(radius=1.5) + ] + if graph_edges is not None: + connectivity = [fabric.graphs.lattice2d.ExplicitEdges(edges=graph_edges, kv_group_ids=kv_group_ids)] + graph = fabric.graphs.lattice2d.Graph( + width=width, + height=height, + populations={ + "slstm": fabric.Population( + cell=fabric.cells.SLSTM(hidden_dim=hidden_size), + nodes=slstm_nodes, + ), + "axoncell": fabric.Population( + cell=fabric.cells.AxonCell(hidden_dim=hidden_size), + nodes=axon_nodes, + ), + }, + inputs=None if input_cell_indices is None else {"tokens": input_cell_indices}, + outputs=None + if output_cell_indices is None + else {"prediction": fabric.graphs.lattice2d.Output(output_cell_indices)}, + connectivity=connectivity, + ) + blueprint = fabric.Blueprint( + interface=fabric.Interface(public_dim=hidden_size, message_dim=hidden_size), + graph=graph, + inputs={"tokens": fabric.Input(dim=d_hidden)}, + outputs={"prediction": fabric.Output(dim=d_hidden)}, + message_passing=fabric.message_rules.DotProduct( + head_dim=hidden_size, + kv_sharing=fabric.message_rules.ShareBySenderTile(tile_shape=(1, 2)), + ), + execution=fabric.ExecutionSpec( + backend="auto", + inner_steps=inner_steps, + gradient_horizon_steps=gradient_horizon_steps, + checkpoint_steps=checkpoint_steps, + ), + ) + return fabric.compile(blueprint) + + +def _mixed_population_node_indices( + *, + width: int, + height: int, + input_cell_indices: tuple[int, ...] | None = None, + output_cell_indices: tuple[int, ...] 
| None = None, +) -> tuple[tuple[int, ...], tuple[int, ...]]: + if input_cell_indices is None: + input_cell_indices = tuple(_flat_lattice_index(0, y, height=height) for y in range(height)) + if output_cell_indices is None: + output_cell_indices = tuple(_flat_lattice_index(width - 1, y, height=height) for y in range(height)) + boundary_nodes = set(input_cell_indices) | set(output_cell_indices) + recurrent_nodes = tuple( + node + for x in range(width) + for y in range(height) + if (node := _flat_lattice_index(x, y, height=height)) not in boundary_nodes + ) + if len(recurrent_nodes) < 2: + raise ValueError("mixed Fabric benchmark requires at least two recurrent nodes") + return recurrent_nodes[::2], recurrent_nodes[1::2] + + +def _flat_lattice_index(x: int, y: int, *, height: int) -> int: + return int(x) * int(height) + int(y) + + def _fabric_layout_counts(*, width: int, height: int) -> tuple[int, int, int, int, int]: total_cells = width * height input_cells = height @@ -240,7 +344,7 @@ def _fabric_layout_counts(*, width: int, height: int) -> tuple[int, int, int, in def _fabric_family_cell_param_count(*, family: BackboneFamily, hidden_size: int) -> int: if family == "slstm": return (8 * hidden_size * hidden_size) + (5 * hidden_size) - return (3 * hidden_size * hidden_size) + (5 * hidden_size) + return (3 * hidden_size * hidden_size) + (6 * hidden_size) def _count_fabric_runtime_params( @@ -250,15 +354,35 @@ def _count_fabric_runtime_params( height: int, hidden_size: int, ) -> int: - total_cells, _, output_cells, recurrent_cells, kv_groups = _fabric_layout_counts(width=width, height=height) + _, _, _, recurrent_cells, _ = _fabric_layout_counts(width=width, height=height) family_params = (recurrent_cells + 1) * _fabric_family_cell_param_count( family=family, hidden_size=hidden_size, ) - runtime_params = ((output_cells + 6) * hidden_size * hidden_size) + ( - (2 * total_cells) + output_cells + 14 + (8 * kv_groups) + return family_params + _count_fabric_shared_runtime_params(width=width, height=height, hidden_size=hidden_size) + + +def _count_fabric_shared_runtime_params(*, width: int, height: int, hidden_size: int) -> int: + total_cells, _, output_cells, _, kv_groups = _fabric_layout_counts(width=width, height=height) + shared_params = ((2 * kv_groups + output_cells + 9) * hidden_size * hidden_size) + ( + (2 * total_cells) + output_cells + 2 ) * hidden_size - return family_params + runtime_params + fixed_slot_message_params = (4 * hidden_size * hidden_size) + (total_cells * hidden_size) + 1 + return shared_params + fixed_slot_message_params + + +def _count_mixed_fabric_runtime_params(*, width: int, height: int, hidden_size: int) -> int: + _, _, _, recurrent_cells, _ = _fabric_layout_counts(width=width, height=height) + slstm_cells = (recurrent_cells + 1) // 2 + axon_cells = recurrent_cells // 2 + family_params = (slstm_cells + 1) * _fabric_family_cell_param_count( + family="slstm", + hidden_size=hidden_size, + ) + (axon_cells + 1) * _fabric_family_cell_param_count( + family="axoncell", + hidden_size=hidden_size, + ) + return family_params + _count_fabric_shared_runtime_params(width=width, height=height, hidden_size=hidden_size) def _count_fabric_boundary_wrapper_params(*, d_hidden: int, width: int, height: int, hidden_size: int) -> int: @@ -272,11 +396,41 @@ def _count_stack_layer_params(*, family: BackboneFamily, d_hidden: int) -> int: return (18 * d_hidden * d_hidden) + ((23 * d_hidden) // 2) +def _count_mixed_stack_layer_params(*, d_hidden: int) -> int: + return (46 * d_hidden * d_hidden) + 
(39 * d_hidden) + 2 + + @lru_cache(maxsize=None) def _count_stack_backbone_params(*, family: BackboneFamily, d_hidden: int, num_layers: int) -> int: return (num_layers * _count_stack_layer_params(family=family, d_hidden=d_hidden)) + (2 * d_hidden) +@lru_cache(maxsize=None) +def _count_mixed_stack_backbone_params(*, d_hidden: int, num_layers: int) -> int: + return (num_layers * _count_mixed_stack_layer_params(d_hidden=d_hidden)) + (2 * d_hidden) + + +def _count_sequence_head_params(*, d_hidden: int) -> int: + return (d_hidden * d_hidden) + d_hidden + + +@lru_cache(maxsize=None) +def _count_stack_sequence_model_params(*, family: BackboneFamily, d_hidden: int, num_layers: int) -> int: + return _count_stack_backbone_params( + family=family, + d_hidden=d_hidden, + num_layers=num_layers, + ) + _count_sequence_head_params(d_hidden=d_hidden) + + +@lru_cache(maxsize=None) +def _count_mixed_stack_sequence_model_params(*, d_hidden: int, num_layers: int) -> int: + return _count_mixed_stack_backbone_params( + d_hidden=d_hidden, + num_layers=num_layers, + ) + _count_sequence_head_params(d_hidden=d_hidden) + + @lru_cache(maxsize=None) def _count_fabric_backbone_params( *, @@ -299,6 +453,60 @@ def _count_fabric_backbone_params( ) +@lru_cache(maxsize=None) +def _count_mixed_fabric_backbone_params( + *, + d_hidden: int, + width: int, + height: int, + hidden_size: int, +) -> int: + return _count_mixed_fabric_runtime_params( + width=width, + height=height, + hidden_size=hidden_size, + ) + _count_fabric_boundary_wrapper_params( + d_hidden=d_hidden, + width=width, + height=height, + hidden_size=hidden_size, + ) + + +@lru_cache(maxsize=None) +def _count_fabric_sequence_model_params( + *, + family: BackboneFamily, + d_hidden: int, + width: int, + height: int, + hidden_size: int, +) -> int: + return _count_fabric_backbone_params( + family=family, + d_hidden=d_hidden, + width=width, + height=height, + hidden_size=hidden_size, + ) + _count_sequence_head_params(d_hidden=d_hidden) + + +@lru_cache(maxsize=None) +def _count_mixed_fabric_sequence_model_params( + *, + d_hidden: int, + width: int, + height: int, + hidden_size: int, +) -> int: + return _count_mixed_fabric_backbone_params( + d_hidden=d_hidden, + width=width, + height=height, + hidden_size=hidden_size, + ) + _count_sequence_head_params(d_hidden=d_hidden) + + @lru_cache(maxsize=None) def find_param_matched_backbone( *, @@ -341,6 +549,72 @@ def find_param_matched_backbone( raise ValueError(f"Could not find a {kind} backbone with at least {target_params} params") +@lru_cache(maxsize=None) +def find_param_matched_mixed_fabric_backbone( + *, + target_params: int, + family: BackboneFamily = "slstm", + forced_d_hidden: int | None = None, + fabric_hidden_grid: tuple[int, ...] 
| None = None, +) -> MatchedBackbone: + d_hidden_grid = (forced_d_hidden,) if forced_d_hidden is not None else D_HIDDEN_GRID + candidates = _find_matching_mixed_fabric_candidates( + family=family, + d_hidden_grid=d_hidden_grid, + target_params=target_params, + fabric_hidden_grid=fabric_hidden_grid, + ) + if candidates: + return _select_fabric_match(candidates) + raise ValueError(f"Could not find a mixed Fabric backbone with at least {target_params} params") + + +@lru_cache(maxsize=None) +def find_param_matched_mixed_stack_backbone( + *, + target_params: int, + family: BackboneFamily = "slstm", + forced_d_hidden: int | None = None, +) -> MatchedBackbone: + d_hidden_grid = (forced_d_hidden,) if forced_d_hidden is not None else D_HIDDEN_GRID + candidates: list[MatchedBackbone] = [] + for d_hidden in d_hidden_grid: + result = _find_first_matching_value( + values=MIXED_STACK_NUM_LAYERS_GRID, + target_params=target_params, + count_fn=lambda num_layers, d_hidden=d_hidden: _count_mixed_stack_sequence_model_params( + d_hidden=d_hidden, + num_layers=num_layers, + ), + ) + if result is None: + continue + num_layers, actual_params = result + candidates.append( + MatchedBackbone( + kind="stack", + family=family, + target_params=target_params, + actual_params=actual_params, + d_hidden=d_hidden, + num_layers=num_layers, + fabric_shape=None, + fabric_hidden_size=None, + ) + ) + if not candidates: + raise ValueError(f"Could not find a mixed stack backbone with at least {target_params} params") + return min( + candidates, + key=lambda match: ( + _fabric_relative_param_error(match.actual_params, match.target_params), + match.actual_params, + match.d_hidden, + int(match.num_layers or 0), + ), + ) + + def _find_first_stack_match( *, family: BackboneFamily, @@ -350,7 +624,7 @@ def _find_first_stack_match( result = _find_first_matching_value( values=NUM_LAYERS_GRID, target_params=target_params, - count_fn=lambda num_layers: _count_stack_backbone_params( + count_fn=lambda num_layers: _count_stack_sequence_model_params( family=family, d_hidden=d_hidden, num_layers=num_layers, @@ -384,7 +658,7 @@ def _find_first_fabric_match( candidates = [ ( hidden_size, - _count_fabric_backbone_params( + _count_fabric_sequence_model_params( family=family, d_hidden=d_hidden, width=width, @@ -440,7 +714,7 @@ def _find_matching_fabric_candidates( hidden_size=hidden_size, allow_dynamic_shapes=allow_dynamic_shapes, ): - actual_params = _count_fabric_backbone_params( + actual_params = _count_fabric_sequence_model_params( family=family, d_hidden=d_hidden, width=width, @@ -462,6 +736,45 @@ def _find_matching_fabric_candidates( return candidates +def _find_matching_mixed_fabric_candidates( + *, + family: BackboneFamily, + d_hidden_grid: tuple[int, ...], + target_params: int, + fabric_hidden_grid: tuple[int, ...] 
| None = None, +) -> list[MatchedBackbone]: + candidates: list[MatchedBackbone] = [] + hidden_grid = _fabric_hidden_grid_for_target(target_params, fabric_hidden_grid=fabric_hidden_grid) + for d_hidden in d_hidden_grid: + for hidden_size in hidden_grid: + for width, height in _fabric_shape_candidates_for_target( + family=family, + d_hidden=d_hidden, + target_params=target_params, + hidden_size=hidden_size, + allow_dynamic_shapes=True, + ): + actual_params = _count_mixed_fabric_sequence_model_params( + d_hidden=d_hidden, + width=width, + height=height, + hidden_size=hidden_size, + ) + candidates.append( + MatchedBackbone( + kind="fabric", + family=family, + target_params=target_params, + actual_params=actual_params, + d_hidden=d_hidden, + num_layers=None, + fabric_shape=(width, height), + fabric_hidden_size=hidden_size, + ) + ) + return candidates + + def _fabric_shape_candidates_for_target( *, family: BackboneFamily, @@ -507,7 +820,7 @@ def _static_fabric_candidate_grid_is_close_enough( return False best_error = min( _fabric_relative_param_error( - _count_fabric_backbone_params( + _count_fabric_sequence_model_params( family=family, d_hidden=d_hidden, width=width, @@ -534,7 +847,7 @@ def _static_fabric_shapes_are_close_enough( return False best_error = min( _fabric_relative_param_error( - _count_fabric_backbone_params( + _count_fabric_sequence_model_params( family=family, d_hidden=d_hidden, width=width, @@ -613,7 +926,7 @@ def _find_fabric_width_upper_bound( hi = max(4, min(FABRIC_DYNAMIC_SHAPE_MAX_WIDTH, _round_up_to_multiple(max(height, 16), 16))) while ( hi < FABRIC_DYNAMIC_SHAPE_MAX_WIDTH - and _count_fabric_backbone_params( + and _count_fabric_sequence_model_params( family=family, d_hidden=d_hidden, width=hi, @@ -626,7 +939,7 @@ def _find_fabric_width_upper_bound( hi = min(FABRIC_DYNAMIC_SHAPE_MAX_WIDTH, hi * 2) while lo < hi: mid = (lo + hi) // 2 - actual_params = _count_fabric_backbone_params( + actual_params = _count_fabric_sequence_model_params( family=family, d_hidden=d_hidden, width=mid, @@ -810,6 +1123,8 @@ def make_sequence_model( device: torch.device, dtype: torch.dtype, inner_steps: int = 1, + gradient_horizon_steps: int | None = None, + checkpoint_steps: int | None = None, ) -> _BackboneWithHead: if match.kind == "stack": backbone = build_stack_backbone( @@ -826,12 +1141,58 @@ def make_sequence_model( height=height, hidden_size=int(match.fabric_hidden_size), inner_steps=inner_steps, + gradient_horizon_steps=gradient_horizon_steps, + checkpoint_steps=checkpoint_steps, ) return _BackboneWithHead(backbone.to(device=device, dtype=dtype), d_hidden=match.d_hidden).to( device=device, dtype=dtype ) +def make_mixed_fabric_sequence_model( + match: MatchedBackbone, + *, + device: torch.device, + dtype: torch.dtype, + inner_steps: int = 1, + gradient_horizon_steps: int | None = None, + checkpoint_steps: int | None = None, +) -> _BackboneWithHead: + if match.kind != "fabric": + raise ValueError("mixed Fabric sequence model requires a Fabric match") + width, height = match.fabric_shape or (0, 0) + backbone = build_mixed_fabric_backbone( + d_hidden=match.d_hidden, + width=width, + height=height, + hidden_size=int(match.fabric_hidden_size), + inner_steps=inner_steps, + gradient_horizon_steps=gradient_horizon_steps, + checkpoint_steps=checkpoint_steps, + ) + return _BackboneWithHead(backbone.to(device=device, dtype=dtype), d_hidden=match.d_hidden).to( + device=device, dtype=dtype + ) + + +def make_mixed_stack_sequence_model( + match: MatchedBackbone, + *, + device: torch.device, + dtype: 
torch.dtype, +) -> _BackboneWithHead: + if match.kind != "stack": + raise ValueError("mixed stack sequence model requires a stack match") + backbone = build_mixed_stack_backbone( + d_hidden=match.d_hidden, + num_layers=int(match.num_layers), + ) + return _BackboneWithHead(backbone.to(device=device, dtype=dtype), d_hidden=match.d_hidden).to( + device=device, + dtype=dtype, + ) + + def run_synthetic_case( *, match: MatchedBackbone, @@ -877,12 +1238,21 @@ def run_sequence_case( training_output_boundary: TrainingOutputBoundary = "terminal", forward_output_boundary: TrainingOutputBoundary = "sequence", inner_steps: int = 1, + gradient_horizon_steps: int | None = None, + checkpoint_steps: int | None = None, ) -> dict[str, object]: if training_output_boundary not in {"terminal", "sequence"}: raise ValueError(f"Unsupported training output boundary {training_output_boundary!r}") if forward_output_boundary not in {"terminal", "sequence"}: raise ValueError(f"Unsupported forward output boundary {forward_output_boundary!r}") - bench_model = model or make_sequence_model(match, device=device, dtype=dtype, inner_steps=inner_steps) + bench_model = model or make_sequence_model( + match, + device=device, + dtype=dtype, + inner_steps=inner_steps, + gradient_horizon_steps=gradient_horizon_steps, + checkpoint_steps=checkpoint_steps, + ) x = torch.randn(batch_size, seq_len, match.d_hidden, device=device, dtype=dtype) target = None if mode == "forward" else make_sequence_training_target(x, output_boundary=training_output_boundary) output_boundary = training_output_boundary if target is not None else forward_output_boundary @@ -917,7 +1287,11 @@ def run_mode(x: torch.Tensor) -> torch.Tensor: "batch_size": batch_size, "seq_len": seq_len, "inner_steps": inner_steps, + "gradient_horizon_steps": gradient_horizon_steps, + "checkpoint_steps": checkpoint_steps, + "reset_mode": "present" if resets is not None else "absent", "output_boundary": output_boundary, + "population_mode": "single", } if mode == "forward_backward": grad_strategy = bench_model.grad_sequence_strategy(x) @@ -930,9 +1304,16 @@ def run_mode(x: torch.Tensor) -> torch.Tensor: batch_size=batch_size, seq_len=seq_len, inner_steps=inner_steps, + gradient_horizon_steps=gradient_horizon_steps, + checkpoint_steps=checkpoint_steps, ) if planner_signature is not None: result["planner_signature"] = planner_signature + if isinstance(metrics.get("memory_ledger"), dict): + memory_ledger = dict(metrics["memory_ledger"]) + if planner_signature is not None: + _attach_compiler_memory_owner_ledger(memory_ledger, planner_signature) + result["memory_ledger"] = memory_ledger if metrics["status"] != "ok": result["error"] = metrics["error"] return result @@ -944,6 +1325,99 @@ def run_mode(x: torch.Tensor) -> torch.Tensor: } +def run_mixed_fabric_sequence_case( + *, + match: MatchedBackbone, + mode: SequenceMode, + batch_size: int, + seq_len: int, + device: torch.device, + dtype: torch.dtype, + warmup: int, + iterations: int, + resets: torch.Tensor | None = None, + training_output_boundary: TrainingOutputBoundary = "terminal", + forward_output_boundary: TrainingOutputBoundary = "sequence", + inner_steps: int = 1, + gradient_horizon_steps: int | None = None, + checkpoint_steps: int | None = None, +) -> dict[str, object]: + model = make_mixed_fabric_sequence_model( + match, + device=device, + dtype=dtype, + inner_steps=inner_steps, + gradient_horizon_steps=gradient_horizon_steps, + checkpoint_steps=checkpoint_steps, + ) + actual_params = _count_params(model) + result = 
run_sequence_case( + match=match, + mode=mode, + batch_size=batch_size, + seq_len=seq_len, + family=match.family, + device=device, + dtype=dtype, + warmup=warmup, + iterations=iterations, + model=model, + resets=resets, + training_output_boundary=training_output_boundary, + forward_output_boundary=forward_output_boundary, + inner_steps=inner_steps, + gradient_horizon_steps=gradient_horizon_steps, + checkpoint_steps=checkpoint_steps, + ) + result["population_mode"] = "mixed" + result["population_families"] = ("slstm", "axoncell") + result["actual_params"] = actual_params + planner_signature = result.get("planner_signature") + if isinstance(planner_signature, dict): + planner_signature["population_mode"] = "mixed" + planner_signature["population_families"] = ["slstm", "axoncell"] + planner_signature["actual_params"] = actual_params + return result + + +def run_mixed_stack_sequence_case( + *, + match: MatchedBackbone, + mode: SequenceMode, + batch_size: int, + seq_len: int, + device: torch.device, + dtype: torch.dtype, + warmup: int, + iterations: int, + resets: torch.Tensor | None = None, + training_output_boundary: TrainingOutputBoundary = "terminal", + forward_output_boundary: TrainingOutputBoundary = "sequence", +) -> dict[str, object]: + model = make_mixed_stack_sequence_model(match, device=device, dtype=dtype) + actual_params = _count_params(model) + result = run_sequence_case( + match=match, + mode=mode, + batch_size=batch_size, + seq_len=seq_len, + family=match.family, + device=device, + dtype=dtype, + warmup=warmup, + iterations=iterations, + model=model, + resets=resets, + training_output_boundary=training_output_boundary, + forward_output_boundary=forward_output_boundary, + ) + result["population_mode"] = "mixed" + result["population_families"] = ("slstm", "axoncell") + result["baseline_kind"] = "mixed_stack" + result["actual_params"] = actual_params + return result + + def _planner_signature_from_model( *, model: _BackboneWithHead, @@ -952,6 +1426,8 @@ def _planner_signature_from_model( batch_size: int, seq_len: int, inner_steps: int, + gradient_horizon_steps: int | None, + checkpoint_steps: int | None, ) -> dict[str, object] | None: backbone = getattr(model, "backbone", None) record = getattr(backbone, "last_backend_execution", None) @@ -992,6 +1468,8 @@ def tuple_field(name: str) -> list[Any]: "batch_size": int(batch_size), "window_len": int(seq_len), "requested_inner_steps": int(inner_steps), + "requested_gradient_horizon_steps": gradient_horizon_steps, + "requested_checkpoint_steps": checkpoint_steps, "inner_steps": getattr(record, "inner_steps", None), "backend_name": getattr(record, "backend_name", None), "surface_key": getattr(record, "surface_key", None), @@ -1018,6 +1496,7 @@ def tuple_field(name: str) -> list[Any]: "message_physical_modes": tuple_field("message_physical_modes"), "layout_mode": tuple_field("layout_mode"), "workspace_aliases": tuple_field("workspace_aliases"), + "backward_workspace_aliases": tuple_field("backward_workspace_aliases"), "launch_batch_tiles": tuple_field("launch_batch_tiles"), "launch_receiver_tiles": tuple_field("launch_receiver_tiles"), "launch_public_batch_tiles": tuple_field("launch_public_batch_tiles"), @@ -1025,7 +1504,54 @@ def tuple_field(name: str) -> list[Any]: "launch_readout_modes": tuple_field("launch_readout_modes"), "large_r_safety_modes": tuple_field("large_r_safety_modes"), "launch_temporal_executions": tuple_field("launch_temporal_executions"), + "launch_temporal_scan_owners": tuple_field("launch_temporal_scan_owners"), 
"launch_scan_implementations": tuple_field("launch_scan_implementations"), + "temporal_primitive_executor_contracts": tuple_field("temporal_primitive_executor_contracts"), + "temporal_primitive_executor_blockers": tuple_field("temporal_primitive_executor_blockers"), + "temporal_plan_schedule_kinds": tuple_field("temporal_plan_schedule_kinds"), + "temporal_plan_outer_time_steps": tuple_field("temporal_plan_outer_time_steps"), + "temporal_plan_inner_steps": tuple_field("temporal_plan_inner_steps"), + "temporal_plan_total_scan_steps": tuple_field("temporal_plan_total_scan_steps"), + "temporal_plan_per_timestep_k": tuple_field("temporal_plan_per_timestep_k"), + "temporal_plan_substrate_kinds": tuple_field("temporal_plan_substrate_kinds"), + "temporal_plan_bucket_identity": tuple_field("temporal_plan_bucket_identity"), + "temporal_plan_resets": tuple_field("temporal_plan_resets"), + "temporal_plan_output_selectors": tuple_field("temporal_plan_output_selectors"), + "temporal_plan_output_explicit_outer_steps": tuple_field("temporal_plan_output_explicit_outer_steps"), + "temporal_plan_output_first_outer_steps": tuple_field("temporal_plan_output_first_outer_steps"), + "temporal_plan_output_outer_strides": tuple_field("temporal_plan_output_outer_strides"), + "temporal_plan_output_counts": tuple_field("temporal_plan_output_counts"), + "temporal_plan_output_first_physical_steps": tuple_field("temporal_plan_output_first_physical_steps"), + "temporal_plan_output_physical_strides": tuple_field("temporal_plan_output_physical_strides"), + "temporal_plan_output_surfaces": tuple_field("temporal_plan_output_surfaces"), + "temporal_plan_readout_surfaces": tuple_field("temporal_plan_readout_surfaces"), + "temporal_plan_output_materializations": tuple_field("temporal_plan_output_materializations"), + "temporal_plan_autograd_seed_kinds": tuple_field("temporal_plan_autograd_seed_kinds"), + "temporal_plan_required_backward_surfaces": tuple_field("temporal_plan_required_backward_surfaces"), + "temporal_plan_checkpoint_policy_basis": tuple_field("temporal_plan_checkpoint_policy_basis"), + "temporal_plan_fresh_state_population_cache": tuple_field("temporal_plan_fresh_state_population_cache"), + "temporal_plan_fresh_state_population_cache_reasons": tuple_field( + "temporal_plan_fresh_state_population_cache_reasons" + ), + "temporal_plan_gradient_boundaries": tuple_field("temporal_plan_gradient_boundaries"), + "temporal_plan_horizon_steps": tuple_field("temporal_plan_horizon_steps"), + "temporal_plan_checkpoint_kinds": tuple_field("temporal_plan_checkpoint_kinds"), + "temporal_plan_checkpoint_steps": tuple_field("temporal_plan_checkpoint_steps"), + "temporal_plan_reverse_artifact_kinds": tuple_field("temporal_plan_reverse_artifact_kinds"), + "temporal_plan_recompute_window_steps": tuple_field("temporal_plan_recompute_window_steps"), + "temporal_plan_materialization_reasons": tuple_field("temporal_plan_materialization_reasons"), + "temporal_plan_backward_windows": tuple_field("temporal_plan_backward_windows"), + "temporal_plan_static_value_modes": tuple_field("temporal_plan_static_value_modes"), + "temporal_plan_backend_names": tuple_field("temporal_plan_backend_names"), + "temporal_plan_executors": tuple_field("temporal_plan_executors"), + "temporal_plan_selected_implementations": tuple_field("temporal_plan_selected_implementations"), + "temporal_plan_reasons": tuple_field("temporal_plan_reasons"), + "temporal_plan_forward_owners": tuple_field("temporal_plan_forward_owners"), + "temporal_plan_backward_owners": 
tuple_field("temporal_plan_backward_owners"), + "temporal_plan_checkpoint_owners": tuple_field("temporal_plan_checkpoint_owners"), + "temporal_plan_target_owners": tuple_field("temporal_plan_target_owners"), + "temporal_plan_engine_statuses": tuple_field("temporal_plan_engine_statuses"), + "temporal_plan_engine_reasons": tuple_field("temporal_plan_engine_reasons"), "active_cell_tiling_plans": tuple_field("active_cell_tiling_plans"), "backward_physical_op_kinds": tuple_field("backward_physical_op_kinds"), "backward_physical_op_executors": tuple_field("backward_physical_op_executors"), @@ -1039,6 +1565,200 @@ def tuple_field(name: str) -> list[Any]: } +def _summary_value(entries: list[Any], prefix: str, key: str) -> str | None: + marker = f"{prefix}:{key}=" + for entry in entries: + text = str(entry) + if text.startswith(marker): + return text[len(marker) :] + return None + + +def _summary_int(entries: list[Any], prefix: str, key: str) -> int: + value = _summary_value(entries, prefix, key) + if value is None: + return 0 + try: + return int(value) + except ValueError: + return 0 + + +def _summary_int_map(entries: list[Any], prefix: str, key: str) -> dict[str, int]: + value = _summary_value(entries, prefix, key) + if value is None: + return {} + parsed: dict[str, int] = {} + for item in value.split(","): + if ":" not in item: + continue + name, raw_count = item.rsplit(":", 1) + try: + parsed[str(name)] = int(raw_count) + except ValueError: + continue + return parsed + + +def _summary_records(entries: list[Any], prefix: str) -> tuple[dict[str, str], ...]: + marker = f"{prefix}:" + records: list[dict[str, str]] = [] + for entry in entries: + text = str(entry) + if not text.startswith(marker): + continue + fields: dict[str, str] = {} + for item in text[len(marker) :].split(";"): + if "=" not in item: + continue + key, value = item.split("=", 1) + fields[str(key)] = str(value) + if fields: + records.append(fields) + return tuple(records) + + +def _record_int_value(record: dict[str, str], key: str) -> int: + try: + return int(record.get(key, "0")) + except ValueError: + return 0 + + +def _reset_fabric_temporal_memory_stage_ledger(model: nn.Module) -> None: + for module in model.modules(): + runtime = getattr(module, "runtime", None) + if runtime is None: + continue + if hasattr(runtime, "_last_flat_bucket_temporal_registered_backward_memory_stages"): + runtime._last_flat_bucket_temporal_registered_backward_memory_stages = () + if hasattr(runtime, "_last_flat_bucket_temporal_frontend_tensor_bytes"): + runtime._last_flat_bucket_temporal_frontend_tensor_bytes = () + + +def _attach_compiler_memory_owner_ledger(memory_ledger: dict[str, int], planner_signature: dict[str, object]) -> None: + workspace_aliases = list(planner_signature.get("workspace_aliases", ())) + runtime_prefix = "flat_bucket_temporal_memory_runtime_buffer" + artifact_prefix = "flat_bucket_temporal_reverse_artifact_tensor_store" + backward_stage_prefix = "flat_bucket_temporal_registered_backward_memory_stage" + frontend_tensor_prefix = "flat_bucket_temporal_frontend_tensor_bytes" + runtime_buffer_bytes = _summary_int(workspace_aliases, runtime_prefix, "estimated_allocated_buffer_bytes") + reverse_artifact_bytes = _summary_int(workspace_aliases, artifact_prefix, "unique_storage_bytes") + memory_ledger["fabric_compiler_runtime_buffer_bytes"] = int(runtime_buffer_bytes) + memory_ledger["fabric_compiler_reverse_artifact_bytes"] = int(reverse_artifact_bytes) + memory_ledger["fabric_compiler_named_runtime_artifact_bytes"] = 
int(runtime_buffer_bytes + reverse_artifact_bytes) + runtime_by_role = _summary_int_map(workspace_aliases, runtime_prefix, "bytes_by_runtime_role") + artifact_by_role = _summary_int_map(workspace_aliases, artifact_prefix, "logical_bytes_by_role") + for role, byte_count in runtime_by_role.items(): + memory_ledger[f"fabric_compiler_runtime_role_bytes.{role}"] = int(byte_count) + for role, byte_count in artifact_by_role.items(): + memory_ledger[f"fabric_compiler_reverse_artifact_role_bytes.{role}"] = int(byte_count) + backward_stage_records = _summary_records(workspace_aliases, backward_stage_prefix) + memory_ledger["fabric_registered_backward_memory_stage_count"] = int(len(backward_stage_records)) + peak_stage = "" + peak_stage_max_allocated = 0 + first_peak_stage = "" + first_peak_stage_max_allocated = 0 + peak_max_delta_stage = "" + peak_max_delta = 0 + peak_current_stage = "" + peak_current_stage_allocated = 0 + previous_stage_allocated: int | None = None + previous_stage_max_allocated: int | None = None + for stage_record in backward_stage_records: + stage_name = stage_record.get("stage", "unknown").replace(" ", "_") + allocated = _record_int_value(stage_record, "allocated") + reserved = _record_int_value(stage_record, "reserved") + max_allocated = _record_int_value(stage_record, "max_allocated") + max_delta = ( + 0 + if previous_stage_max_allocated is None + else max(0, int(max_allocated) - int(previous_stage_max_allocated)) + ) + existing_allocated = int( + memory_ledger.get(f"fabric_registered_backward_stage_allocated_bytes.{stage_name}", -1) + ) + if allocated >= existing_allocated: + memory_ledger[f"fabric_registered_backward_stage_allocated_bytes.{stage_name}"] = int(allocated) + memory_ledger[f"fabric_registered_backward_stage_reserved_bytes.{stage_name}"] = int(reserved) + memory_ledger[f"fabric_registered_backward_stage_max_allocated_bytes.{stage_name}"] = int(max_allocated) + if previous_stage_allocated is not None: + memory_ledger[f"fabric_registered_backward_stage_allocated_delta_bytes.{stage_name}"] = int( + allocated - previous_stage_allocated + ) + existing_max_delta = int( + memory_ledger.get(f"fabric_registered_backward_stage_max_delta_bytes.{stage_name}", -1) + ) + if max_delta >= existing_max_delta: + memory_ledger[f"fabric_registered_backward_stage_max_delta_bytes.{stage_name}"] = int(max_delta) + previous_stage_allocated = int(allocated) + previous_stage_max_allocated = max( + int(max_allocated), + 0 if previous_stage_max_allocated is None else int(previous_stage_max_allocated), + ) + if max_delta > peak_max_delta: + peak_max_delta_stage = stage_name + peak_max_delta = int(max_delta) + if max_allocated > first_peak_stage_max_allocated: + first_peak_stage = stage_name + first_peak_stage_max_allocated = int(max_allocated) + if max_allocated >= peak_stage_max_allocated: + peak_stage = stage_name + peak_stage_max_allocated = int(max_allocated) + if allocated >= peak_current_stage_allocated: + peak_current_stage = stage_name + peak_current_stage_allocated = int(allocated) + memory_ledger["fabric_registered_backward_peak_stage_max_allocated_bytes"] = int(peak_stage_max_allocated) + if peak_stage: + memory_ledger[f"fabric_registered_backward_peak_stage.{peak_stage}"] = 1 + memory_ledger["fabric_registered_backward_first_peak_stage_max_allocated_bytes"] = int( + first_peak_stage_max_allocated + ) + if first_peak_stage: + memory_ledger[f"fabric_registered_backward_first_peak_stage.{first_peak_stage}"] = 1 + 
memory_ledger["fabric_registered_backward_peak_stage_max_delta_bytes"] = int(peak_max_delta)
+    if peak_max_delta_stage:
+        memory_ledger[f"fabric_registered_backward_peak_stage_by_max_delta.{peak_max_delta_stage}"] = 1
+    memory_ledger["fabric_registered_backward_peak_current_stage_allocated_bytes"] = int(peak_current_stage_allocated)
+    if peak_current_stage:
+        memory_ledger[f"fabric_registered_backward_peak_current_stage.{peak_current_stage}"] = 1
+    frontend_tensor_records = _summary_records(workspace_aliases, frontend_tensor_prefix)
+    memory_ledger["fabric_frontend_tensor_byte_stage_count"] = int(len(frontend_tensor_records))
+    peak_frontend_stage = ""
+    peak_frontend_total_bytes = 0
+    for tensor_record in frontend_tensor_records:
+        stage_name = tensor_record.get("stage", "unknown").replace(" ", "_")
+        total_bytes = _record_int_value(tensor_record, "total_bytes")
+        if total_bytes >= int(memory_ledger.get(f"fabric_frontend_tensor_stage_total_bytes.{stage_name}", -1)):
+            memory_ledger[f"fabric_frontend_tensor_stage_total_bytes.{stage_name}"] = int(total_bytes)
+            for item in tensor_record.get("bytes_by_role", "").split(","):
+                if ":" not in item:
+                    continue
+                role, raw_count = item.rsplit(":", 1)
+                try:
+                    memory_ledger[f"fabric_frontend_tensor_role_bytes.{stage_name}.{role}"] = int(raw_count)
+                except ValueError:
+                    continue
+        if total_bytes >= peak_frontend_total_bytes:
+            peak_frontend_stage = stage_name
+            peak_frontend_total_bytes = int(total_bytes)
+    memory_ledger["fabric_frontend_tensor_peak_stage_total_bytes"] = int(peak_frontend_total_bytes)
+    if peak_frontend_stage:
+        memory_ledger[f"fabric_frontend_tensor_peak_stage.{peak_frontend_stage}"] = 1
+    cuda_peak = int(memory_ledger.get("cuda_max_allocated_bytes", 0))
+    model_bytes = int(memory_ledger.get("model_parameter_bytes", 0))
+    grad_bytes = int(memory_ledger.get("model_parameter_grad_bytes", 0))
+    named_bytes = int(runtime_buffer_bytes + reverse_artifact_bytes)
+    memory_ledger["fabric_unclassified_cuda_peak_bytes"] = max(
+        0,
+        int(cuda_peak - model_bytes - grad_bytes - named_bytes),
+    )
+    memory_ledger["fabric_cuda_reserved_gap_bytes"] = max(
+        0,
+        int(memory_ledger.get("cuda_reserved_bytes", 0)) - int(memory_ledger.get("cuda_allocated_bytes", 0)),
+    )
+
+
 def run_rollout_case(
     *,
     match: MatchedBackbone,
@@ -1431,14 +2151,19 @@ def _measure_model(
     iterations: int,
     device: torch.device,
     iteration_fn=None,
-) -> tuple[float, float | None]:
+) -> tuple[float, float | None, dict[str, int]]:
     optimizer = torch.optim.SGD(model.parameters(), lr=0.0) if target is not None else None
     run_iteration = _run_iteration if iteration_fn is None else iteration_fn
     for _ in range(max(0, warmup)):
         run_iteration(model=model, run_mode=run_mode, optimizer=optimizer, x=x, target=target)
+    if optimizer is not None:
+        optimizer.zero_grad(set_to_none=True)
+    gc.collect()
     if device.type == "cuda":
         torch.cuda.synchronize()
         torch.cuda.reset_peak_memory_stats()
+    _reset_fabric_temporal_memory_stage_ledger(model)
+    measurement_start_memory = _cuda_memory_byte_summary(device)
     start = time.perf_counter()
     for _ in range(max(1, iterations)):
         run_iteration(model=model, run_mode=run_mode, optimizer=optimizer, x=x, target=target)
@@ -1448,7 +2173,7 @@
     else:
         peak_mem = None
     dt = (time.perf_counter() - start) / max(1, iterations)
-    return dt, peak_mem
+    return dt, peak_mem, measurement_start_memory


 def _measure_case(
@@ -1463,7 +2188,7 @@
     iteration_fn=None,
 ) -> dict[str, object]:
     try:
-        dt, peak_mem = _measure_model(
+        dt, peak_mem, measurement_start_memory = _measure_model(
             model=model,
             run_mode=run_mode,
             x=x,
@@ -1474,16 +2199,69 @@
             iteration_fn=iteration_fn,
         )
     except (RuntimeError, ValueError, torch.OutOfMemoryError) as exc:
+        memory_ledger = _benchmark_memory_ledger_on_error(model=model, device=device)
         if device.type == "cuda":
             torch.cuda.empty_cache()
         return {
             "status": _classify_case_error(exc),
             "error": str(exc),
+            "memory_ledger": memory_ledger,
         }
+    memory_ledger = _benchmark_memory_ledger_on_success(model=model, device=device)
+    memory_ledger.update({f"measurement_start_{key}": int(value) for key, value in measurement_start_memory.items()})
     return {
         "status": "ok",
         "dt": dt,
         "peak_mem_gib": peak_mem,
+        "memory_ledger": memory_ledger,
+    }
+
+
+def _module_parameter_byte_summary(model: nn.Module) -> dict[str, int]:
+    parameter_bytes = 0
+    parameter_grad_bytes = 0
+    trainable_parameter_bytes = 0
+    for parameter in model.parameters():
+        bytes_ = int(parameter.numel()) * int(parameter.element_size())
+        parameter_bytes += bytes_
+        if parameter.requires_grad:
+            trainable_parameter_bytes += bytes_
+        if parameter.grad is not None:
+            parameter_grad_bytes += int(parameter.grad.numel()) * int(parameter.grad.element_size())
+    return {
+        "model_parameter_bytes": int(parameter_bytes),
+        "model_trainable_parameter_bytes": int(trainable_parameter_bytes),
+        "model_parameter_grad_bytes": int(parameter_grad_bytes),
+    }
+
+
+def _cuda_memory_byte_summary(device: torch.device) -> dict[str, int]:
+    if device.type != "cuda":
+        return {}
+    try:
+        free_bytes, total_bytes = torch.cuda.mem_get_info(device)
+    except RuntimeError:
+        free_bytes, total_bytes = (0, 0)
+    return {
+        "cuda_allocated_bytes": int(torch.cuda.memory_allocated(device)),
+        "cuda_reserved_bytes": int(torch.cuda.memory_reserved(device)),
+        "cuda_max_allocated_bytes": int(torch.cuda.max_memory_allocated(device)),
+        "cuda_free_bytes": int(free_bytes),
+        "cuda_total_bytes": int(total_bytes),
+    }
+
+
+def _benchmark_memory_ledger_on_error(*, model: nn.Module, device: torch.device) -> dict[str, int]:
+    return {
+        **_module_parameter_byte_summary(model),
+        **_cuda_memory_byte_summary(device),
+    }
+
+
+def _benchmark_memory_ledger_on_success(*, model: nn.Module, device: torch.device) -> dict[str, int]:
+    return {
+        **_module_parameter_byte_summary(model),
+        **_cuda_memory_byte_summary(device),
     }


@@ -1657,23 +2435,32 @@ def _run_rollout_online_grad_iteration(
     "BackboneFamily",
     "ExecutionKind",
     "MatchedBackbone",
+    "PopulationMode",
     "RolloutMode",
     "SequenceMode",
     "build_fabric_backbone",
+    "build_mixed_fabric_backbone",
+    "build_mixed_stack_backbone",
     "build_stack_backbone",
     "dtype_name",
     "find_param_matched_backbone",
+    "find_param_matched_mixed_fabric_backbone",
+    "find_param_matched_mixed_stack_backbone",
     "hf_context_window_tokens",
     "hf_benchmark_configs",
     "hf_rollout_configs",
     "hf_sequence_configs",
     "load_hf_benchmark_models",
+    "make_mixed_fabric_sequence_model",
+    "make_mixed_stack_sequence_model",
     "make_sequence_training_target",
     "make_sequence_model",
     "resolve_backbone_benchmark_dtype",
     "run_hf_case",
     "run_hf_rollout_case",
     "run_hf_sequence_case",
+    "run_mixed_fabric_sequence_case",
+    "run_mixed_stack_sequence_case",
     "run_rollout_case",
     "run_sequence_case",
     "run_synthetic_case",
diff --git a/benchmarks/fabric_hf.py b/benchmarks/fabric_hf.py
index 073cd004..dece8819 100644
--- a/benchmarks/fabric_hf.py
+++ b/benchmarks/fabric_hf.py
@@ -5,7 +5,7 @@ import torch
 from .common import BenchmarkCase, BenchmarkDefinition, BenchmarkSettings, ColumnSpec, 
register -from .fabric_suite_common import hf_sequence_configs, run_hf_sequence_case +from .fabric.suite_common import hf_sequence_configs, run_hf_sequence_case Config = Tuple[str, str, int, int] CONFIGS: Tuple[Config, ...] = hf_sequence_configs() diff --git a/benchmarks/fabric_hf_rollout.py b/benchmarks/fabric_hf_rollout.py index b1fb821d..588c1e1f 100644 --- a/benchmarks/fabric_hf_rollout.py +++ b/benchmarks/fabric_hf_rollout.py @@ -5,7 +5,7 @@ import torch from .common import BenchmarkCase, BenchmarkDefinition, BenchmarkSettings, ColumnSpec, register -from .fabric_suite_common import hf_rollout_configs, run_hf_rollout_case +from .fabric.suite_common import hf_rollout_configs, run_hf_rollout_case Config = Tuple[str, str, int, int] CONFIGS: Tuple[Config, ...] = hf_rollout_configs() diff --git a/benchmarks/fabric_rollout.py b/benchmarks/fabric_rollout.py index 3dd0528e..b8855049 100644 --- a/benchmarks/fabric_rollout.py +++ b/benchmarks/fabric_rollout.py @@ -5,7 +5,7 @@ import torch from .common import BenchmarkCase, BenchmarkDefinition, BenchmarkSettings, ColumnSpec, register -from .fabric_suite_common import ( +from .fabric.suite_common import ( find_param_matched_backbone, run_rollout_case, synthetic_rollout_configs, diff --git a/benchmarks/fabric_scaling.py b/benchmarks/fabric_scaling.py index fd1aaa21..afc10839 100644 --- a/benchmarks/fabric_scaling.py +++ b/benchmarks/fabric_scaling.py @@ -5,7 +5,7 @@ import torch from .common import BenchmarkCase, BenchmarkDefinition, BenchmarkSettings, ColumnSpec, register -from .fabric_suite_common import find_param_matched_backbone, run_sequence_case, synthetic_sequence_configs +from .fabric.suite_common import find_param_matched_backbone, run_sequence_case, synthetic_sequence_configs Config = Tuple[int, str, int, int] CONFIGS: Tuple[Config, ...] = synthetic_sequence_configs() diff --git a/benchmarks/run.py b/benchmarks/run.py index c58b6413..3f3c2e3b 100644 --- a/benchmarks/run.py +++ b/benchmarks/run.py @@ -41,6 +41,16 @@ def _load_benchmarks() -> None: continue fullname = f"{pkg}.{mod_name}" if pkg else mod_name importlib.import_module(fullname) + for entry in os.listdir(here): + package_dir = os.path.join(here, entry) + if not os.path.isdir(package_dir): + continue + if entry.startswith("__"): + continue + if not os.path.exists(os.path.join(package_dir, "benchmark.py")): + continue + fullname = f"{pkg}.{entry}.benchmark" if pkg else f"{entry}.benchmark" + importlib.import_module(fullname) def _format_available(registry: dict[str, BenchmarkDefinition]) -> str: diff --git a/benchmarks/run_fabric_bxt_scaling_audit.py b/benchmarks/run_fabric_bxt_scaling_audit.py deleted file mode 100644 index a2610be7..00000000 --- a/benchmarks/run_fabric_bxt_scaling_audit.py +++ /dev/null @@ -1,1068 +0,0 @@ -from __future__ import annotations - -import argparse -import hashlib -import json -import os -import sqlite3 -import subprocess -import sys -import time -from pathlib import Path -from typing import Any, Literal - -import torch - -ROOT = Path(__file__).resolve().parents[1] -for extra_path in (ROOT, ROOT / "src"): - if str(extra_path) not in sys.path: - sys.path.insert(0, str(extra_path)) - -from benchmarks.fabric_suite_common import ( # noqa: E402 - BT_GRID, - BackboneFamily, - BackboneKind, - SequenceMode, - dtype_name, - find_param_matched_backbone, - resolve_backbone_benchmark_dtype, - run_sequence_case, -) - -Side = Literal["stack", "fabric"] -CaseKey = tuple[str, int, str, int, int, int, str, tuple[int, ...] 
| None] - - -def main() -> int: - args = _build_parser().parse_args() - if args.child_kind == "case": - return _run_child_case(args) - return _run_parent(args) - - -def _build_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser( - description="Run Fabric B x parameter-count scaling audit with a streaming-window axis." - ) - parser.add_argument( - "--out-dir", - default="docs/user/subho/fabric_benchmark/results/profiles/fabric_final_bxt_scaling", - ) - parser.add_argument("--families", default="slstm,axoncell") - parser.add_argument("--sizes", default="100m,1b") - parser.add_argument("--modes", default="forward,forward_backward") - parser.add_argument("--batches", default=",".join(str(value) for value in BT_GRID)) - parser.add_argument("--seq-lens", default=",".join(str(value) for value in BT_GRID)) - parser.add_argument("--windows", help="Alias for --seq-lens; these are streaming/TBPTT window lengths.") - parser.add_argument( - "--inner-steps", - default="1", - help="Comma-separated semantic recurrent microstep counts K. This sets Fabric default_k/k_max.", - ) - parser.add_argument("--device", default="cuda") - parser.add_argument("--dtype", default="float32") - parser.add_argument( - "--fabric-hidden-sizes", - help=( - "Optional comma-separated Fabric cell hidden sizes for stress/sanity sweeps. " - "Omit for the default h=32 closure baseline." - ), - ) - parser.add_argument("--warmup", type=int, default=1) - parser.add_argument("--iterations", type=int, default=3) - parser.add_argument( - "--training-output-boundary", - choices=("terminal", "sequence"), - default="terminal", - help="For forward_backward rows, use terminal streaming loss or dense sequence loss.", - ) - parser.add_argument( - "--resume", - action="store_true", - help="Reuse completed pairs from cases.jsonl in the output dir.", - ) - parser.add_argument("--child-kind", choices=("parent", "case"), default="parent") - parser.add_argument("--family", choices=("slstm", "axoncell")) - parser.add_argument("--target-params", type=int) - parser.add_argument("--mode", choices=("forward", "forward_backward")) - parser.add_argument("--batch-size", type=int) - parser.add_argument("--seq-len", type=int) - parser.add_argument("--inner-step", type=int, default=1) - parser.add_argument("--side", choices=("stack", "fabric")) - parser.add_argument("--output-json") - return parser - - -def _run_parent(args: argparse.Namespace) -> int: - out_dir = Path(args.out_dir) - case_dir = out_dir / "cases" - case_dir.mkdir(parents=True, exist_ok=True) - families = _parse_families(args.families) - sizes = _parse_size_targets(args.sizes) - modes = _parse_modes(args.modes) - batches = _parse_int_csv(args.batches) - seq_lens = _parse_int_csv(args.windows or args.seq_lens) - inner_steps_values = _parse_int_csv(args.inner_steps) - fabric_hidden_sizes = _parse_optional_int_csv(args.fabric_hidden_sizes) - completed_cases = _load_completed_cases(out_dir / "cases.jsonl") if args.resume else {} - cases: list[dict[str, Any]] = [] - started_at = time.time() - for family in families: - for target_params in sizes: - for mode in modes: - for batch_size in batches: - for seq_len in seq_lens: - for inner_steps in inner_steps_values: - key = _case_key( - family=family, - target_params=target_params, - mode=mode, - batch_size=batch_size, - seq_len=seq_len, - inner_steps=inner_steps, - training_output_boundary=args.training_output_boundary, - fabric_hidden_sizes=fabric_hidden_sizes, - ) - if key in completed_cases: - cases.append(completed_cases[key]) - 
_write_summary(out_dir, args=args, cases=cases, elapsed_s=time.time() - started_at) - continue - pair: dict[str, Any] = { - "family": family, - "size_label": _format_size_label(target_params), - "target_params": target_params, - "mode": mode, - "batch_size": batch_size, - "seq_len": seq_len, - "inner_steps": inner_steps, - "training_output_boundary": args.training_output_boundary, - "fabric_hidden_sizes": fabric_hidden_sizes, - } - for side in ("stack", "fabric"): - result = _run_side_subprocess( - args=args, - case_dir=case_dir, - family=family, - target_params=target_params, - mode=mode, - batch_size=batch_size, - seq_len=seq_len, - inner_steps=inner_steps, - side=side, - ) - pair[side] = result - _add_pair_summary(pair) - cases.append(pair) - _write_jsonl(out_dir / "cases.jsonl", pair) - _write_summary(out_dir, args=args, cases=cases, elapsed_s=time.time() - started_at) - _write_summary(out_dir, args=args, cases=cases, elapsed_s=time.time() - started_at) - print( - json.dumps( - {"benchmark": "fabric_batch_param_scaling_audit", "out_dir": str(out_dir), "cases": len(cases)}, - indent=2, - ) - ) - return 0 - - -def _case_key( - *, - family: BackboneFamily, - target_params: int, - mode: SequenceMode, - batch_size: int, - seq_len: int, - inner_steps: int, - training_output_boundary: str, - fabric_hidden_sizes: tuple[int, ...] | None, -) -> CaseKey: - return ( - family, - target_params, - mode, - batch_size, - seq_len, - inner_steps, - training_output_boundary, - fabric_hidden_sizes, - ) - - -def _load_completed_cases(path: Path) -> dict[CaseKey, dict[str, Any]]: - cases: dict[CaseKey, dict[str, Any]] = {} - if not path.exists(): - return cases - for line in path.read_text().splitlines(): - if not line.strip(): - continue - case = json.loads(line) - key = _case_key( - family=case["family"], - target_params=int(case["target_params"]), - mode=case["mode"], - batch_size=int(case["batch_size"]), - seq_len=int(case["seq_len"]), - inner_steps=int(case.get("inner_steps", 1)), - training_output_boundary=str(case.get("training_output_boundary", "terminal")), - fabric_hidden_sizes=tuple(case["fabric_hidden_sizes"]) if case.get("fabric_hidden_sizes") else None, - ) - cases[key] = case - return cases - - -def _run_side_subprocess( - *, - args: argparse.Namespace, - case_dir: Path, - family: BackboneFamily, - target_params: int, - mode: SequenceMode, - batch_size: int, - seq_len: int, - inner_steps: int, - side: Side, -) -> dict[str, Any]: - hidden_suffix = "" - if args.fabric_hidden_sizes: - hidden_suffix = "_h" + "-".join(str(value) for value in _parse_int_csv(args.fabric_hidden_sizes)) - output_json = case_dir / ( - f"{family}_{_format_size_label(target_params)}_{mode}_{args.training_output_boundary}_" - f"b{batch_size}_t{seq_len}_k{inner_steps}{hidden_suffix}_{side}.json" - ) - command = [ - sys.executable, - str(Path(__file__).resolve()), - "--child-kind", - "case", - "--family", - family, - "--target-params", - str(target_params), - "--mode", - mode, - "--batch-size", - str(batch_size), - "--seq-len", - str(seq_len), - "--inner-step", - str(inner_steps), - "--side", - side, - "--device", - args.device, - "--dtype", - args.dtype, - "--warmup", - str(args.warmup), - "--iterations", - str(args.iterations), - "--training-output-boundary", - args.training_output_boundary, - "--output-json", - str(output_json), - ] - if args.fabric_hidden_sizes: - command.extend(["--fabric-hidden-sizes", args.fabric_hidden_sizes]) - completed = subprocess.run(command, cwd=ROOT, env=os.environ.copy(), text=True, 
capture_output=True, check=False) - if completed.returncode != 0: - result: dict[str, Any] = { - "status": "failed", - "side": side, - "returncode": completed.returncode, - "stderr": completed.stderr[-4000:], - "stdout": completed.stdout[-4000:], - } - _write_json(output_json, result) - return result - return json.loads(output_json.read_text()) - - -def _run_child_case(args: argparse.Namespace) -> int: - assert args.family is not None - assert args.target_params is not None - assert args.mode is not None - assert args.batch_size is not None - assert args.seq_len is not None - assert args.side is not None - assert args.output_json is not None - device = torch.device(args.device) - requested_dtype = _resolve_dtype(args.dtype) - dtype = resolve_backbone_benchmark_dtype(device=device, requested_dtype=requested_dtype) - stack_match = find_param_matched_backbone(target_params=args.target_params, kind="stack", family=args.family) - match_kind: BackboneKind = args.side - if match_kind == "fabric": - fabric_hidden_sizes = _parse_optional_int_csv(args.fabric_hidden_sizes) - match = find_param_matched_backbone( - target_params=stack_match.actual_params, - kind="fabric", - family=args.family, - forced_d_hidden=stack_match.d_hidden, - fabric_hidden_grid=fabric_hidden_sizes, - ) - else: - match = stack_match - result = run_sequence_case( - match=match, - mode=args.mode, - batch_size=args.batch_size, - seq_len=args.seq_len, - family=args.family, - device=device, - dtype=dtype, - warmup=args.warmup, - iterations=args.iterations, - training_output_boundary=args.training_output_boundary, - forward_output_boundary=args.training_output_boundary, - inner_steps=args.inner_step if args.side == "fabric" else 1, - ) - result["side"] = args.side - result["size_label"] = _format_size_label(args.target_params) - result["requested_target_params"] = args.target_params - result["requested_dtype"] = args.dtype - result["resolved_dtype"] = dtype_name(dtype) - result["training_output_boundary"] = args.training_output_boundary - result["requested_inner_steps"] = args.inner_step - result["fabric_hidden_sizes"] = _parse_optional_int_csv(args.fabric_hidden_sizes) - _write_json(Path(args.output_json), result) - return 0 - - -def _add_pair_summary(pair: dict[str, Any]) -> None: - stack = pair["stack"] - fabric = pair["fabric"] - stack_ok = stack.get("status") == "ok" - fabric_ok = fabric.get("status") == "ok" - stack_actual = stack.get("actual_params") - fabric_actual = fabric.get("actual_params") - requested_target = pair.get("target_params") - if isinstance(stack_actual, int | float) and isinstance(fabric_actual, int | float): - pair["stack_actual_params"] = int(stack_actual) - pair["fabric_actual_params"] = int(fabric_actual) - pair["fabric_vs_stack_param_error"] = _relative_error(int(fabric_actual), int(stack_actual)) - if isinstance(requested_target, int): - pair["fabric_vs_requested_param_error"] = _relative_error(int(fabric_actual), int(requested_target)) - pair["param_match_note"] = _param_match_note( - requested_target=int(requested_target) if isinstance(requested_target, int) else None, - stack_actual=int(stack_actual), - fabric_actual=int(fabric_actual), - ) - pair["status"] = "ok" if stack_ok and fabric_ok else _merge_status(stack.get("status"), fabric.get("status")) - if stack_ok and fabric_ok: - pair["fabric_stack_ratio"] = float(fabric["tokens_per_s"]) / float(stack["tokens_per_s"]) - pair["fabric_stack_ratio_percent"] = 100.0 * float(pair["fabric_stack_ratio"]) - pair["stack_tokens_per_s"] = stack["tokens_per_s"] 
- pair["fabric_tokens_per_s"] = fabric["tokens_per_s"] - pair["stack_peak_mem_gib"] = stack["peak_mem_gib"] - pair["fabric_peak_mem_gib"] = fabric["peak_mem_gib"] - - -def _write_summary(out_dir: Path, *, args: argparse.Namespace, cases: list[dict[str, Any]], elapsed_s: float) -> None: - planner_signature_rows = _planner_signature_rows(cases) - mode_coverage_rows = _mode_coverage_rows(cases) - summary = { - "benchmark": "fabric_batch_param_scaling_audit", - "device": args.device, - "dtype": args.dtype, - "warmup": args.warmup, - "iterations": args.iterations, - "elapsed_s": elapsed_s, - "cases": cases, - "mode_coverage_rows": mode_coverage_rows, - "planner_signature_rows": planner_signature_rows, - } - _write_json(out_dir / "bxt_scaling_audit.json", summary) - _write_json(out_dir / "batch_param_scaling_audit.json", summary) - _write_jsonl_replace(out_dir / "planner_signature_db.jsonl", planner_signature_rows) - _write_planner_policy_sqlite(out_dir / "planner_policy_db.sqlite", cases, planner_signature_rows) - (out_dir / "bxt_scaling_audit.md").write_text(_format_markdown_summary(summary)) - (out_dir / "batch_param_scaling_audit.md").write_text(_format_markdown_summary(summary)) - - -def _format_markdown_summary(summary: dict[str, Any]) -> str: - lines = [ - "# Fabric B x Params Scaling Audit", - "", - f"- device: `{summary['device']}`", - f"- dtype: `{summary['dtype']}`", - f"- warmup: `{summary['warmup']}`", - f"- iterations: `{summary['iterations']}`", - f"- planner signature rows: `{len(summary.get('planner_signature_rows', []))}`", - "- queryable planner DB: `planner_policy_db.sqlite`", - "- Fabric hidden size override: `{}`".format( - ",".join(str(value) for value in summary["cases"][0].get("fabric_hidden_sizes") or ()) - if summary.get("cases") and summary["cases"][0].get("fabric_hidden_sizes") - else "default h=32" - ), - "- T is reported as the streaming/TBPTT window length.", - "- K is reported as the semantic inner recurrent microstep count.", - "- Final scaling closure requires both `forward` and `forward_backward` coverage for each supported B x params " - "row.", - "", - "## Forward/Backward Coverage", - "", - "| family | size | B | T | K | forward | forward_backward | coverage |", - "|---|---|---:|---:|---:|---|---|---|", - ] - for row in summary.get("mode_coverage_rows", ()): - lines.append( - "| {family} | {size} | {batch} | {seq} | {inner} | {forward} | {forward_backward} | {coverage} |".format( - family=row["family"], - size=row["size_label"], - batch=row["batch_size"], - seq=row["seq_len"], - inner=row["inner_steps"], - forward=row["forward_status"], - forward_backward=row["forward_backward_status"], - coverage=row["coverage_status"], - ) - ) - lines.extend( - [ - "", - "## Measurements", - "", - "| family | size | mode | B | T | K | status | param match | actual params | active cells | h | d_hidden | " - "stack tok/s | " - "Fabric tok/s | ratio % | stack GiB | Fabric GiB |", - "|---|---|---|---:|---:|---:|---|---|---:|---:|---:|---:|---:|---:|---:|---:|---:|", - ] - ) - for case in summary["cases"]: - stack = case["stack"] - fabric = case["fabric"] - fabric_signature = fabric.get("planner_signature") or {} - lines.append( - ( - "| {family} | {size} | {mode} | {batch} | {seq} | {inner} | {status} | {param_match} | {actual_params} | " - "{active_cells} | {cell_hidden} | {d_hidden} | {stack_toks} | {fabric_toks} | {ratio} | " - "{stack_mem} | {fabric_mem} |" - ).format( - family=case["family"], - size=case["size_label"], - mode=case["mode"], - batch=case["batch_size"], - 
seq=case["seq_len"], - inner=case.get("inner_steps", 1), - status=case["status"], - param_match=case.get("param_match_note", "-"), - stack_toks=_fmt_float(stack.get("tokens_per_s")), - actual_params=_fmt_int(fabric.get("actual_params")), - active_cells=_fmt_int(fabric_signature.get("active_receivers")), - cell_hidden=_fmt_int(fabric.get("cell_hidden_size")), - d_hidden=_fmt_int(fabric.get("d_hidden")), - fabric_toks=_fmt_float(fabric.get("tokens_per_s")), - ratio=_fmt_percent(case.get("fabric_stack_ratio")), - stack_mem=_fmt_float(stack.get("peak_mem_gib")), - fabric_mem=_fmt_float(fabric.get("peak_mem_gib")), - ) - ) - return "\n".join(lines) + "\n" - - -def _relative_error(actual: int, target: int) -> float: - if target <= 0: - return 0.0 - return (float(actual) - float(target)) / float(target) - - -def _param_match_note(*, requested_target: int | None, stack_actual: int, fabric_actual: int) -> str: - fabric_vs_stack = _relative_error(fabric_actual, stack_actual) - parts = [f"fabric_vs_stack={fabric_vs_stack:+.1%}"] - if requested_target is not None: - parts.append(f"fabric_vs_requested={_relative_error(fabric_actual, requested_target):+.1%}") - if abs(fabric_vs_stack) <= 0.01: - status = "matched" - elif abs(fabric_vs_stack) <= 0.20: - status = "near" - else: - status = "nearest_supported" - return status + ";" + ";".join(parts) - - -def _mode_coverage_rows(cases: list[dict[str, Any]]) -> list[dict[str, Any]]: - grouped: dict[tuple[str, str, int, int, int, int], dict[str, str]] = {} - labels: dict[tuple[str, str, int, int, int, int], dict[str, object]] = {} - for case in cases: - key = ( - str(case["family"]), - str(case["size_label"]), - int(case["target_params"]), - int(case["batch_size"]), - int(case["seq_len"]), - int(case.get("inner_steps", 1)), - ) - grouped.setdefault(key, {})[str(case["mode"])] = str(case["status"]) - labels[key] = { - "family": case["family"], - "size_label": case["size_label"], - "target_params": case["target_params"], - "batch_size": case["batch_size"], - "seq_len": case["seq_len"], - "inner_steps": case.get("inner_steps", 1), - } - rows: list[dict[str, Any]] = [] - for key in sorted(grouped): - modes = grouped[key] - forward_status = modes.get("forward", "missing") - forward_backward_status = modes.get("forward_backward", "missing") - if forward_status != "missing" and forward_backward_status != "missing": - coverage_status = "attempted" - elif forward_status == "missing": - coverage_status = "missing_forward" - else: - coverage_status = "missing_forward_backward" - rows.append( - { - **labels[key], - "forward_status": forward_status, - "forward_backward_status": forward_backward_status, - "coverage_status": coverage_status, - } - ) - return rows - - -def _planner_signature_rows(cases: list[dict[str, Any]]) -> list[dict[str, Any]]: - rows: list[dict[str, Any]] = [] - for case in cases: - for side in ("stack", "fabric"): - result = case.get(side) - if not isinstance(result, dict): - continue - signature = result.get("planner_signature") - if not isinstance(signature, dict): - continue - rows.append( - { - "family": case["family"], - "size_label": case["size_label"], - "target_params": case["target_params"], - "mode": case["mode"], - "batch_size": case["batch_size"], - "window_len": case["seq_len"], - "inner_steps": case.get("inner_steps", 1), - "side": side, - "status": result.get("status"), - "tokens_per_s": result.get("tokens_per_s"), - "peak_mem_gib": result.get("peak_mem_gib"), - "fabric_stack_ratio": case.get("fabric_stack_ratio"), - "planner_signature": 
signature, - } - ) - return rows - - -def _write_planner_policy_sqlite( - path: Path, - cases: list[dict[str, Any]], - planner_signature_rows: list[dict[str, Any]], -) -> None: - path.unlink(missing_ok=True) - with sqlite3.connect(path) as conn: - conn.execute("PRAGMA journal_mode=OFF") - conn.execute("PRAGMA synchronous=OFF") - conn.execute( - """ - CREATE TABLE cases ( - case_id TEXT PRIMARY KEY, - family TEXT NOT NULL, - size_label TEXT NOT NULL, - target_params INTEGER NOT NULL, - stack_actual_params INTEGER, - fabric_actual_params INTEGER, - fabric_vs_stack_param_error REAL, - fabric_vs_requested_param_error REAL, - param_match_note TEXT, - mode TEXT NOT NULL, - batch_size INTEGER NOT NULL, - window_len INTEGER NOT NULL, - inner_steps INTEGER NOT NULL, - status TEXT NOT NULL, - stack_status TEXT, - fabric_status TEXT, - stack_tokens_per_s REAL, - fabric_tokens_per_s REAL, - fabric_stack_ratio REAL, - stack_peak_mem_gib REAL, - fabric_peak_mem_gib REAL - ) - """ - ) - conn.execute( - """ - CREATE TABLE planner_signatures ( - signature_id TEXT PRIMARY KEY, - family TEXT NOT NULL, - size_label TEXT NOT NULL, - target_params INTEGER NOT NULL, - mode TEXT NOT NULL, - batch_size INTEGER NOT NULL, - window_len INTEGER NOT NULL, - inner_steps INTEGER NOT NULL, - side TEXT NOT NULL, - status TEXT, - tokens_per_s REAL, - peak_mem_gib REAL, - fabric_stack_ratio REAL, - active_receivers INTEGER, - actual_params INTEGER, - d_hidden INTEGER, - cell_hidden_size INTEGER, - signature_json TEXT NOT NULL - ) - """ - ) - conn.execute( - """ - CREATE TABLE case_planner_signatures ( - case_id TEXT NOT NULL, - signature_id TEXT NOT NULL, - side TEXT NOT NULL, - PRIMARY KEY (case_id, signature_id, side), - FOREIGN KEY(case_id) REFERENCES cases(case_id), - FOREIGN KEY(signature_id) REFERENCES planner_signatures(signature_id) - ) - """ - ) - conn.execute( - """ - CREATE TABLE planner_policy_observations ( - observation_id TEXT PRIMARY KEY, - signature_id TEXT NOT NULL, - family TEXT NOT NULL, - mode TEXT NOT NULL, - side TEXT NOT NULL, - status TEXT, - target_params INTEGER NOT NULL, - actual_params INTEGER, - active_receivers INTEGER, - d_hidden INTEGER, - cell_hidden_size INTEGER, - batch_size INTEGER NOT NULL, - window_len INTEGER NOT NULL, - inner_steps INTEGER NOT NULL, - tokens_per_s REAL, - peak_mem_gib REAL, - fabric_stack_ratio REAL, - policy_key TEXT NOT NULL, - policy_json TEXT NOT NULL, - FOREIGN KEY(signature_id) REFERENCES planner_signatures(signature_id) - ) - """ - ) - conn.execute( - """ - CREATE VIEW best_fabric_policy_by_shape AS - SELECT - family, - mode, - target_params, - actual_params, - active_receivers, - d_hidden, - cell_hidden_size, - batch_size, - window_len, - inner_steps, - MAX(tokens_per_s) AS best_tokens_per_s, - MIN(peak_mem_gib) AS min_peak_mem_gib, - MAX(fabric_stack_ratio) AS best_fabric_stack_ratio - FROM planner_policy_observations - WHERE side = 'fabric' AND status = 'ok' - GROUP BY - family, - mode, - target_params, - actual_params, - active_receivers, - d_hidden, - cell_hidden_size, - batch_size, - window_len, - inner_steps - """ - ) - conn.execute( - """ - CREATE VIEW forward_backward_pair_coverage AS - SELECT - family, - size_label, - target_params, - batch_size, - window_len, - inner_steps, - MAX(CASE WHEN mode = 'forward' THEN status END) AS forward_status, - MAX(CASE WHEN mode = 'forward_backward' THEN status END) AS forward_backward_status, - MAX(CASE WHEN mode = 'forward' THEN fabric_tokens_per_s END) AS forward_fabric_tokens_per_s, - MAX(CASE WHEN mode = 
'forward_backward' THEN fabric_tokens_per_s END) - AS forward_backward_fabric_tokens_per_s, - MAX(CASE WHEN mode = 'forward' THEN fabric_stack_ratio END) AS forward_fabric_stack_ratio, - MAX(CASE WHEN mode = 'forward_backward' THEN fabric_stack_ratio END) - AS forward_backward_fabric_stack_ratio, - CASE - WHEN SUM(CASE WHEN mode = 'forward' THEN 1 ELSE 0 END) > 0 - AND SUM(CASE WHEN mode = 'forward_backward' THEN 1 ELSE 0 END) > 0 - THEN 'attempted' - WHEN SUM(CASE WHEN mode = 'forward' THEN 1 ELSE 0 END) = 0 - THEN 'missing_forward' - ELSE 'missing_forward_backward' - END AS coverage_status - FROM cases - GROUP BY family, size_label, target_params, batch_size, window_len, inner_steps - """ - ) - conn.execute("CREATE INDEX idx_cases_family_mode ON cases(family, mode, batch_size, window_len, inner_steps)") - conn.execute( - "CREATE INDEX idx_signatures_generic ON planner_signatures(" - "family, mode, active_receivers, batch_size, window_len, inner_steps, side)" - ) - conn.execute( - "CREATE INDEX idx_policy_shape ON planner_policy_observations(" - "mode, active_receivers, batch_size, window_len, inner_steps, side)" - ) - conn.execute("CREATE INDEX idx_policy_key ON planner_policy_observations(policy_key, side, status)") - conn.executemany( - """ - INSERT INTO cases ( - case_id, - family, - size_label, - target_params, - stack_actual_params, - fabric_actual_params, - fabric_vs_stack_param_error, - fabric_vs_requested_param_error, - param_match_note, - mode, - batch_size, - window_len, - inner_steps, - status, - stack_status, - fabric_status, - stack_tokens_per_s, - fabric_tokens_per_s, - fabric_stack_ratio, - stack_peak_mem_gib, - fabric_peak_mem_gib - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - (_case_sql_row(case) for case in cases), - ) - signature_rows = [] - link_rows = [] - observation_rows = [] - for row in planner_signature_rows: - signature = row["planner_signature"] - signature_json = _canonical_json(signature) - signature_id = _stable_id(signature_json) - policy = _planner_policy_from_signature(signature) - policy_json = _canonical_json(policy) - policy_key = _stable_id(policy_json) - observation_id = _stable_id( - _canonical_json( - { - "signature_id": signature_id, - "side": row["side"], - "batch_size": row["batch_size"], - "window_len": row["window_len"], - "inner_steps": row["inner_steps"], - } - ) - ) - signature_rows.append( - ( - signature_id, - row["family"], - row["size_label"], - int(row["target_params"]), - row["mode"], - int(row["batch_size"]), - int(row["window_len"]), - int(row["inner_steps"]), - row["side"], - row.get("status"), - row.get("tokens_per_s"), - row.get("peak_mem_gib"), - row.get("fabric_stack_ratio"), - signature.get("active_receivers"), - signature.get("actual_params"), - signature.get("d_hidden"), - signature.get("cell_hidden_size"), - signature_json, - ) - ) - observation_rows.append( - ( - observation_id, - signature_id, - row["family"], - row["mode"], - row["side"], - row.get("status"), - int(row["target_params"]), - signature.get("actual_params"), - signature.get("active_receivers"), - signature.get("d_hidden"), - signature.get("cell_hidden_size"), - int(row["batch_size"]), - int(row["window_len"]), - int(row["inner_steps"]), - row.get("tokens_per_s"), - row.get("peak_mem_gib"), - row.get("fabric_stack_ratio"), - policy_key, - policy_json, - ) - ) - link_rows.append( - ( - _case_id( - family=row["family"], - target_params=int(row["target_params"]), - mode=row["mode"], - batch_size=int(row["batch_size"]), - 
window_len=int(row["window_len"]), - inner_steps=int(row["inner_steps"]), - ), - signature_id, - row["side"], - ) - ) - conn.executemany( - """ - INSERT OR REPLACE INTO planner_signatures ( - signature_id, - family, - size_label, - target_params, - mode, - batch_size, - window_len, - inner_steps, - side, - status, - tokens_per_s, - peak_mem_gib, - fabric_stack_ratio, - active_receivers, - actual_params, - d_hidden, - cell_hidden_size, - signature_json - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - signature_rows, - ) - conn.executemany( - "INSERT OR REPLACE INTO case_planner_signatures (case_id, signature_id, side) VALUES (?, ?, ?)", - link_rows, - ) - conn.executemany( - """ - INSERT OR REPLACE INTO planner_policy_observations ( - observation_id, - signature_id, - family, - mode, - side, - status, - target_params, - actual_params, - active_receivers, - d_hidden, - cell_hidden_size, - batch_size, - window_len, - inner_steps, - tokens_per_s, - peak_mem_gib, - fabric_stack_ratio, - policy_key, - policy_json - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - observation_rows, - ) - - -def _case_sql_row(case: dict[str, Any]) -> tuple[object, ...]: - stack = case.get("stack", {}) - fabric = case.get("fabric", {}) - return ( - _case_id( - family=case["family"], - target_params=int(case["target_params"]), - mode=case["mode"], - batch_size=int(case["batch_size"]), - window_len=int(case["seq_len"]), - inner_steps=int(case.get("inner_steps", 1)), - ), - case["family"], - case["size_label"], - int(case["target_params"]), - case.get("stack_actual_params"), - case.get("fabric_actual_params"), - case.get("fabric_vs_stack_param_error"), - case.get("fabric_vs_requested_param_error"), - case.get("param_match_note"), - case["mode"], - int(case["batch_size"]), - int(case["seq_len"]), - int(case.get("inner_steps", 1)), - case.get("status"), - stack.get("status") if isinstance(stack, dict) else None, - fabric.get("status") if isinstance(fabric, dict) else None, - stack.get("tokens_per_s") if isinstance(stack, dict) else None, - fabric.get("tokens_per_s") if isinstance(fabric, dict) else None, - case.get("fabric_stack_ratio"), - stack.get("peak_mem_gib") if isinstance(stack, dict) else None, - fabric.get("peak_mem_gib") if isinstance(fabric, dict) else None, - ) - - -def _case_id( - *, - family: str, - target_params: int, - mode: str, - batch_size: int, - window_len: int, - inner_steps: int, -) -> str: - return f"{family}:{target_params}:{mode}:b{batch_size}:t{window_len}:k{inner_steps}" - - -def _canonical_json(value: Any) -> str: - return json.dumps(value, sort_keys=True, separators=(",", ":"), default=str) - - -def _planner_policy_from_signature(signature: dict[str, Any]) -> dict[str, Any]: - return { - "training": signature.get("training"), - "inner_steps": signature.get("inner_steps"), - "execution_families": signature.get("execution_families", ()), - "math_backends": signature.get("math_backends", ()), - "physical_op_kinds": signature.get("physical_op_kinds", ()), - "physical_op_executors": signature.get("physical_op_executors", ()), - "state_epilogue_execution_mode": signature.get("state_epilogue_execution_mode", ()), - "message_physical_modes": signature.get("message_physical_modes", ()), - "layout_mode": signature.get("layout_mode", ()), - "active_cell_tiling_plans": signature.get("active_cell_tiling_plans", ()), - "large_r_safety_modes": signature.get("large_r_safety_modes", ()), - "backward_physical_op_kinds": 
signature.get("backward_physical_op_kinds", ()), - "backward_physical_op_executors": signature.get("backward_physical_op_executors", ()), - "backward_tape_mode": signature.get("backward_tape_mode", ()), - "backward_recompute_mode": signature.get("backward_recompute_mode", ()), - } - - -def _stable_id(value: str) -> str: - return hashlib.sha256(value.encode("utf-8")).hexdigest()[:24] - - -def _parse_size_targets(value: str) -> tuple[int, ...]: - return tuple(_parse_size_target(part) for part in value.split(",") if part) - - -def _parse_size_target(value: str) -> int: - normalized = value.strip().lower() - if normalized.endswith("k"): - return int(float(normalized[:-1]) * 1_000) - if normalized.endswith("m"): - return int(float(normalized[:-1]) * 1_000_000) - if normalized.endswith("b"): - return int(float(normalized[:-1]) * 1_000_000_000) - return int(normalized) - - -def _parse_int_csv(value: str) -> tuple[int, ...]: - return tuple(int(part) for part in value.split(",") if part) - - -def _parse_optional_int_csv(value: str | None) -> tuple[int, ...] | None: - if value is None or not value.strip(): - return None - return _parse_int_csv(value) - - -def _parse_families(value: str) -> tuple[BackboneFamily, ...]: - families = tuple(part.strip() for part in value.split(",") if part.strip()) - assert all(family in ("slstm", "axoncell") for family in families) - return families # type: ignore[return-value] - - -def _parse_modes(value: str) -> tuple[SequenceMode, ...]: - modes = tuple(part.strip() for part in value.split(",") if part.strip()) - assert all(mode in ("forward", "forward_backward") for mode in modes) - return modes # type: ignore[return-value] - - -def _format_size_label(target_params: int) -> str: - if target_params >= 1_000_000_000 and target_params % 1_000_000_000 == 0: - return f"{target_params // 1_000_000_000}b" - if target_params >= 1_000_000 and target_params % 1_000_000 == 0: - return f"{target_params // 1_000_000}m" - if target_params >= 1_000 and target_params % 1_000 == 0: - return f"{target_params // 1_000}k" - return str(target_params) - - -def _resolve_dtype(value: str) -> torch.dtype: - normalized = value.lower() - if normalized == "float32": - return torch.float32 - if normalized == "float16": - return torch.float16 - if normalized == "bfloat16": - return torch.bfloat16 - raise ValueError(f"Unsupported dtype: {value}") - - -def _merge_status(stack_status: object, fabric_status: object) -> str: - if stack_status == "ok" and fabric_status == "ok": - return "ok" - if stack_status == "oom" or fabric_status == "oom": - return "oom" - if stack_status == "unsupported" or fabric_status == "unsupported": - return "unsupported" - return "failed" - - -def _write_json(path: Path, payload: dict[str, Any]) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n") - - -def _write_jsonl(path: Path, payload: dict[str, Any]) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - with path.open("a") as file: - file.write(json.dumps(payload, sort_keys=True) + "\n") - - -def _write_jsonl_replace(path: Path, payloads: list[dict[str, Any]]) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text("".join(json.dumps(payload, sort_keys=True) + "\n" for payload in payloads)) - - -def _fmt_float(value: object) -> str: - if value is None: - return "-" - return f"{float(value):.1f}" - - -def _fmt_int(value: object) -> str: - if value is None: - return "-" - return str(int(value)) - - -def 
_fmt_percent(value: object) -> str: - if value is None: - return "-" - return f"{100.0 * float(value):.2f}" - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/benchmarks/run_fabric_factorization_invariance_audit.py b/benchmarks/run_fabric_factorization_invariance_audit.py deleted file mode 100644 index dd6c1bcb..00000000 --- a/benchmarks/run_fabric_factorization_invariance_audit.py +++ /dev/null @@ -1,390 +0,0 @@ -from __future__ import annotations - -import argparse -import json -import sys -import time -from pathlib import Path -from typing import Any - -import torch - -ROOT = Path(__file__).resolve().parents[1] -for extra_path in (ROOT, ROOT / "src"): - if str(extra_path) not in sys.path: - sys.path.insert(0, str(extra_path)) - -from benchmarks.fabric_suite_common import ( # noqa: E402 - BackboneFamily, - MatchedBackbone, - SequenceMode, - _BackboneWithHead, - build_fabric_backbone, - dtype_name, - resolve_backbone_benchmark_dtype, - run_sequence_case, -) - - -def main() -> int: - args = _build_parser().parse_args() - started_at = time.time() - out_dir = Path(args.out_dir) - out_dir.mkdir(parents=True, exist_ok=True) - device = torch.device(args.device) - dtype = resolve_backbone_benchmark_dtype(device=device, requested_dtype=_resolve_dtype(args.dtype)) - families = _parse_families(args.families) - modes = _parse_modes(args.modes) - shapes = _parse_shapes(args.shapes) - cases: list[dict[str, Any]] = [] - for family in families: - for mode in modes: - for width, height in shapes: - result = _run_case( - family=family, - mode=mode, - width=width, - height=height, - d_hidden=args.d_hidden, - hidden_size=args.fabric_hidden_size, - boundary_port_count=args.boundary_port_count, - topology_mode=args.topology_mode, - degree=args.degree, - batch_size=args.batch_size, - seq_len=args.seq_len, - device=device, - dtype=dtype, - warmup=args.warmup, - iterations=args.iterations, - ) - cases.append(result) - _write_json(out_dir / "cases.json", cases) - _write_summary( - out_dir, - args=args, - cases=cases, - dtype=dtype, - elapsed_s=time.time() - started_at, - ) - _write_summary(out_dir, args=args, cases=cases, dtype=dtype, elapsed_s=time.time() - started_at) - print(json.dumps({"benchmark": "fabric_factorization_invariance_audit", "out_dir": str(out_dir)}, indent=2)) - return 0 - - -def _build_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="Run a graph-signature-aware Fabric factorization audit.") - parser.add_argument( - "--out-dir", - default="docs/user/subho/fabric_benchmark/results/profiles/fabric_factorization_invariance_audit", - ) - parser.add_argument("--families", default="slstm,axoncell") - parser.add_argument("--modes", default="forward") - parser.add_argument("--shapes", default="128x2048,512x512,2048x128") - parser.add_argument("--d-hidden", type=int, default=1536) - parser.add_argument("--fabric-hidden-size", type=int, default=32) - parser.add_argument("--boundary-port-count", type=int, default=512) - parser.add_argument( - "--topology-mode", - choices=("lattice", "flat_ring"), - default="flat_ring", - help="Use lattice constructor edges or a shape-independent flat ring graph.", - ) - parser.add_argument("--degree", type=int, default=8) - parser.add_argument("--batch-size", type=int, default=1024) - parser.add_argument("--seq-len", type=int, default=1) - parser.add_argument("--device", default="cuda") - parser.add_argument("--dtype", default="float32") - parser.add_argument("--warmup", type=int, default=1) - 
parser.add_argument("--iterations", type=int, default=2) - return parser - - -def _run_case( - *, - family: BackboneFamily, - mode: SequenceMode, - width: int, - height: int, - d_hidden: int, - hidden_size: int, - boundary_port_count: int, - topology_mode: str, - degree: int, - batch_size: int, - seq_len: int, - device: torch.device, - dtype: torch.dtype, - warmup: int, - iterations: int, -) -> dict[str, Any]: - input_indices, output_indices = _flat_boundary_indices(width=width, height=height, count=boundary_port_count) - graph_edges = None - kv_group_ids = None - if topology_mode == "flat_ring": - graph_edges = _flat_ring_edges( - width=width, - height=height, - input_indices=input_indices, - output_indices=output_indices, - degree=degree, - ) - kv_group_ids = tuple(idx // 2 for idx in range(width * height)) - backbone = build_fabric_backbone( - family=family, - d_hidden=d_hidden, - width=width, - height=height, - hidden_size=hidden_size, - input_cell_indices=input_indices, - output_cell_indices=output_indices, - graph_edges=graph_edges, - kv_group_ids=kv_group_ids, - ) - model = _BackboneWithHead(backbone, d_hidden=d_hidden).to(device=device, dtype=dtype) - actual_params = sum(param.numel() for param in model.parameters()) - match = MatchedBackbone( - kind="fabric", - family=family, - target_params=actual_params, - actual_params=actual_params, - d_hidden=d_hidden, - num_layers=None, - fabric_shape=(width, height), - fabric_hidden_size=hidden_size, - ) - result = run_sequence_case( - match=match, - mode=mode, - batch_size=batch_size, - seq_len=seq_len, - family=family, - device=device, - dtype=dtype, - warmup=warmup, - iterations=iterations, - model=model, - training_output_boundary="terminal", - forward_output_boundary="terminal", - ) - result["boundary_port_count"] = boundary_port_count - result["explicit_boundary_indices"] = True - result["topology_mode"] = topology_mode - return result - - -def _flat_boundary_indices(*, width: int, height: int, count: int) -> tuple[tuple[int, ...], tuple[int, ...]]: - node_count = width * height - if count <= 0: - raise ValueError("boundary port count must be positive") - if 2 * count >= node_count: - raise ValueError( - f"boundary port count {count} is too large for {width}x{height}; " - "input and output sets must leave recurrent nodes" - ) - return tuple(range(count)), tuple(range(node_count - count, node_count)) - - -def _flat_ring_edges( - *, - width: int, - height: int, - input_indices: tuple[int, ...], - output_indices: tuple[int, ...], - degree: int, -) -> tuple[tuple[int, int], ...]: - if degree <= 0: - raise ValueError("degree must be positive") - if degree % 2 != 0: - raise ValueError("flat_ring degree must be even") - node_count = width * height - input_nodes = set(input_indices) - output_nodes = set(output_indices) - recurrent_nodes = tuple(idx for idx in range(node_count) if idx not in input_nodes and idx not in output_nodes) - if not recurrent_nodes: - raise ValueError("flat ring topology needs recurrent nodes") - half_degree = degree // 2 - offsets = tuple(range(1, half_degree + 1)) - edges: list[tuple[int, int]] = [] - recurrent_position = {node: pos for pos, node in enumerate(recurrent_nodes)} - for receiver in recurrent_nodes: - pos = recurrent_position[receiver] - for offset in offsets: - edges.append((receiver, recurrent_nodes[(pos - offset) % len(recurrent_nodes)])) - edges.append((receiver, recurrent_nodes[(pos + offset) % len(recurrent_nodes)])) - for output_idx, receiver in enumerate(output_indices): - pos = output_idx % 
len(recurrent_nodes) - for offset in offsets: - edges.append((receiver, recurrent_nodes[(pos - offset) % len(recurrent_nodes)])) - edges.append((receiver, recurrent_nodes[(pos + offset) % len(recurrent_nodes)])) - return tuple(edges) - - -def _write_summary( - out_dir: Path, - *, - args: argparse.Namespace, - cases: list[dict[str, Any]], - dtype: torch.dtype, - elapsed_s: float, -) -> None: - grouped: dict[tuple[str, str], list[dict[str, Any]]] = {} - for case in cases: - grouped.setdefault((str(case["family"]), str(case["mode"])), []).append(case) - spread_rows = [] - for (family, mode), rows in grouped.items(): - ok_rows = [row for row in rows if row.get("status") == "ok"] - tokens = [float(row["tokens_per_s"]) for row in ok_rows] - graph_signatures = [ - row.get("planner_signature", {}).get("graph_signature") - for row in ok_rows - if row.get("planner_signature", {}).get("graph_signature") is not None - ] - graph_equivalent = bool(graph_signatures) and all( - signature == graph_signatures[0] for signature in graph_signatures - ) - spread = None - if tokens: - spread = (max(tokens) - min(tokens)) / max(tokens) - spread_rows.append( - { - "family": family, - "mode": mode, - "ok_cases": len(ok_rows), - "total_cases": len(rows), - "graph_equivalent": graph_equivalent, - "throughput_spread": spread, - "min_tokens_per_s": min(tokens) if tokens else None, - "max_tokens_per_s": max(tokens) if tokens else None, - "reference_graph_signature": graph_signatures[0] if graph_signatures else None, - } - ) - summary = { - "benchmark": "fabric_factorization_invariance_audit", - "device": args.device, - "dtype": dtype_name(dtype), - "warmup": args.warmup, - "iterations": args.iterations, - "elapsed_s": elapsed_s, - "shape_is_constructor_only": True, - "topology_mode": args.topology_mode, - "degree": args.degree, - "boundary_port_count": args.boundary_port_count, - "cases": cases, - "spread_rows": spread_rows, - } - _write_json(out_dir / "factorization_invariance_audit.json", summary) - _write_markdown(out_dir / "factorization_invariance_audit.md", summary) - - -def _write_markdown(path: Path, summary: dict[str, Any]) -> None: - lines = [ - "# Fabric Factorization Invariance Audit", - "", - "This audit treats shape as a graph constructor. 
A row is factorization-equivalent only if its flat graph " - "signature matches the other rows in the group.", - "", - "## Spread", - "", - "| family | mode | graph equivalent | min tok/s | max tok/s | spread |", - "|---|---|---:|---:|---:|---:|", - ] - for row in summary["spread_rows"]: - lines.append( - "| {family} | {mode} | {equiv} | {min_toks} | {max_toks} | {spread} |".format( - family=row["family"], - mode=row["mode"], - equiv="yes" if row["graph_equivalent"] else "no", - min_toks=_fmt(row["min_tokens_per_s"]), - max_toks=_fmt(row["max_tokens_per_s"]), - spread=_fmt_pct(row["throughput_spread"]), - ) - ) - lines.extend( - [ - "", - "## Cases", - "", - "| family | mode | shape | status | tok/s | peak GiB | graph signature |", - "|---|---|---|---|---:|---:|---|", - ] - ) - for case in summary["cases"]: - signature = case.get("planner_signature", {}).get("graph_signature") - signature_text = ( - "none" - if signature is None - else ( - f"nodes={signature['node_count']},in={signature['input_count']},out={signature['output_count']}," - f"recur={signature['recurrent_count']},edges={signature['edge_count']}," - f"degree={signature['degree_histogram']}" - ) - ) - lines.append( - "| {family} | {mode} | {shape} | {status} | {toks} | {mem} | {signature} |".format( - family=case["family"], - mode=case["mode"], - shape=case.get("fabric_shape"), - status=case.get("status"), - toks=_fmt(case.get("tokens_per_s")), - mem=_fmt(case.get("peak_mem_gib")), - signature=signature_text, - ) - ) - path.write_text("\n".join(lines) + "\n") - - -def _parse_families(value: str) -> tuple[BackboneFamily, ...]: - out = tuple(item.strip() for item in value.split(",") if item.strip()) - invalid = [item for item in out if item not in {"slstm", "axoncell"}] - if invalid: - raise ValueError(f"invalid families: {invalid}") - return out # type: ignore[return-value] - - -def _parse_modes(value: str) -> tuple[SequenceMode, ...]: - out = tuple(item.strip() for item in value.split(",") if item.strip()) - invalid = [item for item in out if item not in {"forward", "forward_backward"}] - if invalid: - raise ValueError(f"invalid modes: {invalid}") - return out # type: ignore[return-value] - - -def _parse_shapes(value: str) -> tuple[tuple[int, int], ...]: - shapes = [] - for item in value.split(","): - raw = item.strip().lower() - if not raw: - continue - width, height = raw.split("x", maxsplit=1) - shapes.append((int(width), int(height))) - if not shapes: - raise ValueError("at least one shape is required") - return tuple(shapes) - - -def _resolve_dtype(value: str) -> torch.dtype: - try: - return getattr(torch, value) - except AttributeError as exc: - raise ValueError(f"unknown torch dtype {value!r}") from exc - - -def _write_json(path: Path, payload: object) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(payload, indent=2, sort_keys=True)) - - -def _fmt(value: object) -> str: - if value is None: - return "-" - return f"{float(value):.2f}" - - -def _fmt_pct(value: object) -> str: - if value is None: - return "-" - return f"{100.0 * float(value):.2f}%" - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/benchmarks/run_fabric_mixed_population_audit.py b/benchmarks/run_fabric_mixed_population_audit.py deleted file mode 100644 index 78ba94ad..00000000 --- a/benchmarks/run_fabric_mixed_population_audit.py +++ /dev/null @@ -1,531 +0,0 @@ -from __future__ import annotations - -import argparse -import gc -import json -import time -from dataclasses import dataclass -from pathlib 
import Path
-from typing import Literal
-
-import cortical.fabric as fabric
-import torch
-import torch.nn as nn
-from cortical.stacks import AxonCellConfig, sLSTMCellConfig
-from cortical.stacks.auto import build_cortical_auto_stack
-
-Mode = Literal["forward", "forward_backward"]
-ResetMode = Literal["none", "strided"]
-Side = Literal["fabric", "stack"]
-
-
-@dataclass(frozen=True)
-class ModelSpec:
-    kind: Side
-    model: nn.Module
-    actual_params: int
-    stack_layers: int | None = None
-
-
-class BackboneWithHead(nn.Module):
-    def __init__(self, backbone: nn.Module, *, d_hidden: int) -> None:
-        super().__init__()
-        self.backbone = backbone
-        self.head = nn.Linear(d_hidden, d_hidden)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        state: object | None = None,
-        *,
-        resets: torch.Tensor | None,
-    ) -> tuple[torch.Tensor, object | None]:
-        kwargs: dict[str, object] = {}
-        if resets is not None:
-            kwargs["resets"] = resets
-        if hasattr(self.backbone, "_forward_sequence_with_readout"):
-            kwargs["materialize_final_state"] = False
-            kwargs["output_boundary"] = "sequence"
-        y, next_state = self.backbone(x, state, **kwargs)
-        return self.head(y), next_state
-
-
-def main() -> int:
-    args = _parser().parse_args()
-    device = torch.device(args.device)
-    dtype = _parse_dtype(args.dtype)
-    out_dir = Path(args.out_dir)
-    out_dir.mkdir(parents=True, exist_ok=True)
-    rows: list[dict[str, object]] = []
-    for shape in _parse_shapes(args.shapes):
-        for batch_size in _parse_ints(args.batches):
-            for seq_len in _parse_ints(args.seq_lens):
-                for inner_steps in _parse_ints(args.inner_steps):
-                    fabric = _build_fabric_model(
-                        width=shape[0],
-                        height=shape[1],
-                        hidden_size=args.fabric_hidden_size,
-                        d_hidden=args.d_hidden,
-                        inner_steps=inner_steps,
-                        device=device,
-                        dtype=dtype,
-                    )
-                    stack = _build_nearest_stack_model(
-                        target_params=fabric.actual_params,
-                        max_layers=args.max_stack_layers,
-                        d_hidden=args.d_hidden,
-                        device=device,
-                        dtype=dtype,
-                    )
-                    for mode in _parse_modes(args.modes):
-                        for reset_mode in _parse_reset_modes(args.reset_modes):
-                            rows.append(
-                                _run_pair(
-                                    fabric=fabric,
-                                    stack=stack,
-                                    width=shape[0],
-                                    height=shape[1],
-                                    hidden_size=args.fabric_hidden_size,
-                                    d_hidden=args.d_hidden,
-                                    batch_size=batch_size,
-                                    seq_len=seq_len,
-                                    inner_steps=inner_steps,
-                                    mode=mode,
-                                    reset_mode=reset_mode,
-                                    device=device,
-                                    dtype=dtype,
-                                    warmup=args.warmup,
-                                    iterations=args.iterations,
-                                )
-                            )
-    _write_outputs(out_dir, rows)
-    print(json.dumps({"benchmark": "fabric_mixed_population_audit", "out_dir": str(out_dir), "rows": len(rows)}))
-    return 0
-
-
-def _parser() -> argparse.ArgumentParser:
-    parser = argparse.ArgumentParser(description="Run mixed-population Fabric vs stack baseline audit.")
-    parser.add_argument(
-        "--out-dir",
-        default="docs/user/subho/fabric_benchmark/results/profiles/fabric_mixed_population_audit_current",
-    )
-    parser.add_argument("--shapes", default="16x16,32x32")
-    parser.add_argument("--batches", default="512,2048")
-    parser.add_argument("--seq-lens", default="1,4")
-    parser.add_argument("--inner-steps", default="1")
-    parser.add_argument("--modes", default="forward")
-    parser.add_argument("--reset-modes", default="none,strided")
-    parser.add_argument("--d-hidden", type=int, default=128)
-    parser.add_argument("--fabric-hidden-size", type=int, default=32)
-    parser.add_argument("--max-stack-layers", type=int, default=16)
-    parser.add_argument("--warmup", type=int, default=2)
-    parser.add_argument("--iterations", type=int, default=5)
-    
parser.add_argument("--device", default="cuda") - parser.add_argument("--dtype", default="float32") - return parser - - -def _parse_dtype(name: str) -> torch.dtype: - if name == "float32": - return torch.float32 - if name == "bfloat16": - return torch.bfloat16 - if name == "float16": - return torch.float16 - raise ValueError(f"Unsupported dtype {name!r}") - - -def _parse_ints(value: str) -> tuple[int, ...]: - return tuple(int(item) for item in value.split(",") if item) - - -def _parse_modes(value: str) -> tuple[Mode, ...]: - modes = tuple(item for item in value.split(",") if item) - invalid = [mode for mode in modes if mode not in {"forward", "forward_backward"}] - if invalid: - raise ValueError(f"Unsupported modes: {invalid}") - return modes # type: ignore[return-value] - - -def _parse_reset_modes(value: str) -> tuple[ResetMode, ...]: - modes = tuple(item for item in value.split(",") if item) - invalid = [mode for mode in modes if mode not in {"none", "strided"}] - if invalid: - raise ValueError(f"Unsupported reset modes: {invalid}") - return modes # type: ignore[return-value] - - -def _parse_shapes(value: str) -> tuple[tuple[int, int], ...]: - shapes: list[tuple[int, int]] = [] - for item in value.split(","): - if not item: - continue - width, height = item.lower().split("x", maxsplit=1) - shapes.append((int(width), int(height))) - return tuple(shapes) - - -def _build_fabric_model( - *, - width: int, - height: int, - hidden_size: int, - d_hidden: int, - inner_steps: int, - device: torch.device, - dtype: torch.dtype, -) -> ModelSpec: - graph = fabric.graphs.lattice2d.Graph( - width=width, - height=height, - populations={ - "slstm": fabric.Population( - cell=fabric.cells.SLSTM(hidden_dim=hidden_size), - nodes=fabric.graphs.lattice2d.Region(x=(0.0, 0.5)), - ), - "axoncell": fabric.Population( - cell=fabric.cells.AxonCell(hidden_dim=hidden_size), - nodes=fabric.graphs.lattice2d.Region(x=(0.5, 1.0)), - ), - }, - ) - blueprint = fabric.Blueprint( - interface=fabric.Interface(public_dim=hidden_size, message_dim=hidden_size), - graph=graph, - inputs={"tokens": fabric.Input(dim=d_hidden)}, - outputs={"prediction": fabric.Output(dim=d_hidden)}, - message_passing=fabric.message_rules.DotProduct( - head_dim=hidden_size, - kv_sharing=fabric.message_rules.ShareBySenderTile(tile_shape=(4, 4)), - ), - execution=fabric.ExecutionSpec(backend="auto", inner_steps=inner_steps), - ) - backbone = fabric.compile(blueprint) - model = BackboneWithHead(backbone, d_hidden=d_hidden).to(device=device, dtype=dtype) - return ModelSpec(kind="fabric", model=model, actual_params=_count_params(model)) - - -def _build_nearest_stack_model( - *, - target_params: int, - max_layers: int, - d_hidden: int, - device: torch.device, - dtype: torch.dtype, -) -> ModelSpec: - best_model: nn.Module | None = None - best_params: int | None = None - best_layers: int | None = None - best_delta: int | None = None - for num_layers in range(1, max_layers + 1): - backbone = build_cortical_auto_stack( - d_hidden=d_hidden, - num_layers=num_layers, - layers=[[AxonCellConfig(), sLSTMCellConfig()] for _ in range(num_layers)], - post_norm=True, - compile_scaffolds=False, - ) - model = BackboneWithHead(backbone, d_hidden=d_hidden) - params = _count_params(model) - delta = abs(params - target_params) - if best_delta is None or delta < best_delta: - if best_model is not None: - del best_model - best_model = model - best_params = params - best_layers = num_layers - best_delta = delta - else: - del model - gc.collect() - if best_delta is not None and 
params > target_params and delta > best_delta: - break - if best_model is None or best_params is None or best_layers is None: - raise RuntimeError("No stack candidates were built for parameter matching.") - return ModelSpec( - kind="stack", - model=best_model.to(device=device, dtype=dtype), - actual_params=best_params, - stack_layers=best_layers, - ) - - -def _count_params(model: nn.Module) -> int: - return sum(parameter.numel() for parameter in model.parameters()) - - -def _run_pair( - *, - fabric: ModelSpec, - stack: ModelSpec, - width: int, - height: int, - hidden_size: int, - d_hidden: int, - batch_size: int, - seq_len: int, - inner_steps: int, - mode: Mode, - reset_mode: ResetMode, - device: torch.device, - dtype: torch.dtype, - warmup: int, - iterations: int, -) -> dict[str, object]: - torch.manual_seed(1234) - if device.type == "cuda": - torch.cuda.manual_seed_all(1234) - generator = torch.Generator(device=device).manual_seed(1234) - x = torch.randn(batch_size, seq_len, d_hidden, device=device, dtype=dtype, generator=generator) - target = ( - None - if mode == "forward" - else torch.randn(batch_size, seq_len, d_hidden, device=device, dtype=dtype, generator=generator) - ) - resets = _make_resets(batch_size=batch_size, seq_len=seq_len, mode=reset_mode, device=device) - fabric_result = _measure_model( - fabric.model, - x=x, - target=target, - resets=resets, - warmup=warmup, - iterations=iterations, - device=device, - ) - stack_result = _measure_model( - stack.model, - x=x, - target=target, - resets=resets, - warmup=warmup, - iterations=iterations, - device=device, - ) - fabric_backbone = fabric.model.backbone if isinstance(fabric.model, BackboneWithHead) else fabric.model - fabric_record = getattr(fabric_backbone, "runtime", fabric_backbone) - backend_record = getattr(fabric_record, "last_backend_execution", None) - ratio = None - fabric_tokens = fabric_result.get("tokens_per_s") - stack_tokens = stack_result.get("tokens_per_s") - if isinstance(fabric_tokens, (int, float)) and isinstance(stack_tokens, (int, float)): - ratio = float(fabric_tokens) / float(stack_tokens) - return { - "shape": [width, height], - "fabric_hidden_size": hidden_size, - "d_hidden": d_hidden, - "batch_size": batch_size, - "seq_len": seq_len, - "inner_steps": inner_steps, - "mode": mode, - "reset_mode": reset_mode, - "fabric_actual_params": fabric.actual_params, - "stack_actual_params": stack.actual_params, - "stack_layers": stack.stack_layers, - "param_error_vs_stack": (fabric.actual_params - stack.actual_params) / max(1, stack.actual_params), - "fabric": fabric_result, - "stack": stack_result, - "fabric_stack_ratio": ratio, - "fabric_backend": _backend_record_summary(backend_record), - } - - -def _make_resets(*, batch_size: int, seq_len: int, mode: ResetMode, device: torch.device) -> torch.Tensor | None: - if mode == "none": - return None - resets = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=device) - if seq_len == 1: - resets[::17, 0] = True - else: - resets[::17, seq_len // 2] = True - return resets - - -def _measure_model( - model: nn.Module, - *, - x: torch.Tensor, - target: torch.Tensor | None, - resets: torch.Tensor | None, - warmup: int, - iterations: int, - device: torch.device, -) -> dict[str, object]: - try: - for _ in range(warmup): - _run_iteration(model, x=x, target=target, resets=resets) - if device.type == "cuda": - torch.cuda.synchronize(device) - torch.cuda.reset_peak_memory_stats(device) - start = time.perf_counter() - for _ in range(iterations): - _run_iteration(model, x=x, 
target=target, resets=resets) - if device.type == "cuda": - torch.cuda.synchronize(device) - peak_mem = torch.cuda.max_memory_allocated(device) / (1024**3) - else: - peak_mem = None - dt = (time.perf_counter() - start) / max(1, iterations) - return { - "status": "ok", - "ms": dt * 1000.0, - "tokens_per_s": (x.shape[0] * x.shape[1]) / dt, - "peak_mem_gib": peak_mem, - } - except RuntimeError as exc: - if device.type == "cuda": - torch.cuda.empty_cache() - return {"status": "error", "error": str(exc)} - - -def _run_iteration( - model: nn.Module, - *, - x: torch.Tensor, - target: torch.Tensor | None, - resets: torch.Tensor | None, -) -> None: - model.zero_grad(set_to_none=True) - if target is None: - with torch.no_grad(): - model(x, None, resets=resets) - return - output, state = model(x, None, resets=resets) - del state - loss = torch.nn.functional.mse_loss(output, target) - loss.backward() - - -def _backend_record_summary(record: object | None) -> dict[str, object] | None: - if record is None: - return None - tuple_fields = ( - "launch_temporal_executions", - "launch_scan_implementations", - "physical_op_kinds", - "physical_op_executors", - "physical_op_demotions", - "physical_boundary_contracts", - "active_receiver_window_modes", - "active_receiver_window_offsets", - "active_receiver_window_counts", - "backward_physical_op_kinds", - "backward_physical_op_executors", - "backward_physical_op_demotions", - "backward_boundary_contracts", - "backward_tape_mode", - "backward_recompute_mode", - "backward_launch_counts", - "backward_saved_launch_counts", - "backward_owner_timing_ms", - "backward_owner_wall_ms", - "backward_residual_glue_demotions", - ) - return { - "backend_name": getattr(record, "backend_name", None), - "surface_key": getattr(record, "surface_key", None), - "cell_type": getattr(record, "cell_type", None), - "regime": getattr(record, "regime", None), - "training": getattr(record, "training", None), - "batch_size": getattr(record, "batch_size", None), - "time_steps": getattr(record, "time_steps", None), - "inner_steps": getattr(record, "inner_steps", None), - **{field: list(getattr(record, field, ()) or ()) for field in tuple_fields}, - } - - -def _write_outputs(out_dir: Path, rows: list[dict[str, object]]) -> None: - (out_dir / "mixed_population_audit.json").write_text( - json.dumps({"benchmark": "fabric_mixed_population_audit", "rows": rows}, indent=2), - ) - lines = [ - "# Fabric Mixed-Population Audit", - "", - ( - "| shape | B | T | K | mode | resets | Fabric params | Stack params | stack layers | " - "Fabric tok/s | Stack tok/s | ratio % | Fabric GiB | Stack GiB | backend | backward executors | " - "backward demotions |" - ), - "|---|---:|---:|---:|---|---|---:|---:|---:|---:|---:|---:|---:|---:|---|---|---|", - ] - for row in rows: - fabric = row["fabric"] - stack = row["stack"] - shape = row["shape"] - assert isinstance(fabric, dict) - assert isinstance(stack, dict) - assert isinstance(shape, list) - ratio = row["fabric_stack_ratio"] - ratio_value = float(ratio) if isinstance(ratio, (int, float)) else None - backend = row["fabric_backend"] - backend_name = "" - if isinstance(backend, dict): - executors = backend.get("physical_op_executors", []) - temporal_executions = backend.get("launch_temporal_executions", []) - scan_implementations = backend.get("launch_scan_implementations", []) - executor_text = ",".join(str(item) for item in executors) if isinstance(executors, list) else "" - temporal_text = ( - ",".join(str(item) for item in temporal_executions) if 
isinstance(temporal_executions, list) else "" - ) - scan_text = ( - ",".join(str(item) for item in scan_implementations) if isinstance(scan_implementations, list) else "" - ) - backend_name = "/".join( - str(part) - for part in ( - backend.get("backend_name"), - backend.get("surface_key"), - temporal_text, - scan_text, - executor_text, - ) - if part - ) - backward_executors = backend.get("backward_physical_op_executors", []) - backward_demotions = backend.get("backward_physical_op_demotions", []) - backward_executor_text = ( - ",".join(str(item) for item in backward_executors) if isinstance(backward_executors, list) else "" - ) - backward_demotion_text = ( - ",".join(str(item) for item in backward_demotions) if isinstance(backward_demotions, list) else "" - ) - else: - backward_executor_text = "" - backward_demotion_text = "" - lines.append( - ( - "| {shape} | {B} | {T} | {K} | {mode} | {resets} | {fp} | {sp} | {layers} | " - "{ft} | {st} | {ratio} | {fm} | {sm} | {backend} | {bwd_exec} | {bwd_demo} |" - ).format( - shape="x".join(str(value) for value in shape), - B=row["batch_size"], - T=row["seq_len"], - K=row.get("inner_steps", 1), - mode=row["mode"], - resets=row["reset_mode"], - fp=row["fabric_actual_params"], - sp=row["stack_actual_params"], - layers=row["stack_layers"], - ft=_fmt(fabric.get("tokens_per_s")), - st=_fmt(stack.get("tokens_per_s")), - ratio=_fmt(None if ratio_value is None else 100.0 * ratio_value), - fm=_fmt(fabric.get("peak_mem_gib")), - sm=_fmt(stack.get("peak_mem_gib")), - backend=backend_name, - bwd_exec=backward_executor_text, - bwd_demo=backward_demotion_text, - ) - ) - (out_dir / "mixed_population_audit.md").write_text("\n".join(lines) + "\n") - - -def _fmt(value: object) -> str: - if value is None: - return "" - if isinstance(value, float): - return f"{value:.2f}" - return str(value) - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/benchmarks/run_fabric_scaling_profile.py b/benchmarks/run_fabric_scaling_profile.py deleted file mode 100644 index 0bf72f1c..00000000 --- a/benchmarks/run_fabric_scaling_profile.py +++ /dev/null @@ -1,2764 +0,0 @@ -from __future__ import annotations - -import argparse -import json -import math -import os -import shutil -import sys -from contextlib import contextmanager -from dataclasses import asdict, dataclass -from pathlib import Path -from typing import Literal - -import pynvml -import torch - -_REPO_ROOT = Path(__file__).resolve().parents[1] -_CORTICAL_SRC = _REPO_ROOT / "src" -for _path in (_REPO_ROOT, _CORTICAL_SRC): - if str(_path) not in sys.path: - sys.path.insert(0, str(_path)) - -from benchmarks.fabric_suite_common import ( # noqa: E402 - BackboneFamily, - MatchedBackbone, - SequenceMode, - find_param_matched_backbone, - make_sequence_model, - make_sequence_training_target, - run_sequence_case, - sequence_training_loss, -) -from cortical.fabric.backend.surfaces import BackendExecutionRecord # noqa: E402 - -Side = Literal["stack", "fabric"] -BackwardAttributionMode = Literal[ - "active", - "phase_decomposed_probe", - "state_public_output_probe", - "state_public_state_probe", - "full_replay_boundary_probe", - "full_replay_no_parameter_probe", - "full_replay_without_boundary_inputs_probe", - "full_replay_no_parameter_boundary_probe", -] - -BASE_BATCHES = (256, 512, 1024, 2048, 4096) -LEGACY_KERNEL_PATTERNS = ( - "fabric_slstm_cell_step_kernel", - "fabric_axon_cell_step_kernel", - "generic_cell_step", - "receiver_owned_stepwise_kernel", - "receiver_apply_from_message", - 
"receiver_apply_from_normalized_message", - "forward_sequence_major", - "forward_receiver_major_grouped_gemm", - "forward_edge_major_grouped_gemm", -) -GENERIC_FABRIC_KERNEL_PATTERNS = ( - "diagonal_recurrence_complex_exp_update_emit_kernel", - "diagonal_recurrence_complex_exp_fresh_zero_update_emit_kernel", - "diagonal_recurrence_complex_exp_core_update_emit_tiled_kernel", - "regular_local_tiny_message_projected_kernel", - "regular_local_tiny_message_projected_rowgroup8_kernel", - "pack_regular_local_message_keys_kernel", - "pack_sparse_message_keys_kernel", - "pack_ragged_sparse_message_keys_kernel", - "regular_local_message_softmax_inplace_kernel", - "sparse_message_softmax_inplace_kernel", - "ragged_sparse_message_softmax_inplace_kernel", - "pack_regular_local_message_values_kernel", - "pack_sparse_message_values_kernel", - "pack_ragged_sparse_message_values_kernel", - "receiver_message_aggregate_kernel", - "receiver_state_update_kernel", - "receiver_state_update_emit_kernel", - "receiver_reduce_stats_kernel", - "receiver_emit_raw_public_kernel", - "edge_owned_accumulate", - "readout_message_kernel", - "readout_message_from_raw_public_kernel", - "gemm", -) -REQUIRED_FORWARD_BASE_KERNEL_PATTERNS = ( - "readout_message", - "gemm", -) -REQUIRED_FORWARD_MESSAGE_KERNEL_ALTERNATIVES = ( - "regular_local_tiny_message_projected_kernel", - "regular_local_tiny_message_projected_rowgroup8_kernel", - "pack_regular_local_message_keys_kernel", - "pack_sparse_message_keys_kernel", - "pack_ragged_sparse_message_keys_kernel", - "receiver_message_aggregate_kernel", -) -REQUIRED_SLSTM_FORWARD_PHASE_KERNEL_PATTERNS = ("receiver_state_update",) -REQUIRED_AXON_FORWARD_PHASE_KERNEL_PATTERNS = ("diagonal_recurrence_complex_exp",) -MESSAGE_PROFILE_KERNEL_PATTERNS = ( - "regular_local_tiny_message_projected_kernel", - "regular_local_tiny_message_projected_rowgroup8_kernel", - "receiver_message_aggregate_kernel", - "receiver_message_project_kernel", - "pack_regular_local_message_keys_kernel", - "regular_local_message_softmax_inplace_kernel", - "pack_regular_local_message_values_kernel", - "scatter_receiver_major_message_kernel", - "pack_sparse_message_keys_kernel", - "sparse_message_softmax_inplace_kernel", - "pack_sparse_message_values_kernel", - "pack_ragged_sparse_message_keys_kernel", - "ragged_sparse_message_softmax_inplace_kernel", - "pack_ragged_sparse_message_values_kernel", - "receiver_state_update_emit_kernel", -) -PHYSICAL_OP_PROFILE_PATTERNS = ( - "fabric.physical.receiver_affine", - "fabric.physical.state_affine", - "fabric.physical.diagonal_recurrence", - "fabric.physical.message", - "fabric.physical.input_projection", - "fabric.physical.public_projection", - "fabric.physical.readout", - "fabric.physical.state_epilogue", -) -BACKWARD_OWNER_PROFILE_PATTERNS = ( - "fabric.backward.total", - "fabric.backward.receiver_affine", - "fabric.backward.message.receiver", - "fabric.backward.message.sender", - "fabric.backward.message.query_param", - "fabric.backward.message.autograd", - "fabric.backward.tiny_message_superop", - "fabric.backward.sparse_message_superop", - "fabric.backward.grouped_projection", - "fabric.backward.receiver_major_projection", - "fabric.backward.input_projection.receiver_affine", - "fabric.backward.gate_affine.receiver_affine", - "fabric.backward.recurrent_affine.receiver_affine", - "fabric.backward.diagonal_recurrence", - "fabric.backward.state_epilogue", - "fabric.backward.public_projection", - "fabric.backward.readout", - "fabric.backward.glue.initial_recurrent", - 
"fabric.backward.glue.param_grad_binding", - "fabric.backward.phase_decomposed_probe", - "fabric.backward.replay.output_sequence.boundary_inputs", - "fabric.backward.replay.output_sequence.packed_state_inputs", - "fabric.backward.replay.output_sequence.recurrent_carry_inputs", - "fabric.backward.replay.output_sequence.parameter_inputs", - "fabric.backward.replay.no_parameter_inputs", - "fabric.backward.replay.without_boundary_inputs", - "fabric.backward.replay.no_parameter_output_sequence.boundary_inputs", - "fabric.backward.replay.no_parameter_output_sequence.packed_state_inputs", - "fabric.backward.replay.no_parameter_output_sequence.recurrent_carry_inputs", - "fabric.backward.replay.output_sequence", - "fabric.backward.replay.next_packed_state", - "fabric.backward.replay.recurrent_carry", - "fabric.backward.replay.input_kv_last", - "fabric.backward.full_replay_autograd", - "fabric.backward.reference_replay_autograd", -) -BACKWARD_RECOMPUTE_PROFILE_PATTERNS = PHYSICAL_OP_PROFILE_PATTERNS -BACKWARD_OWNER_SOURCE_PATTERNS = tuple( - dict.fromkeys( - pattern - for pattern in BACKWARD_OWNER_PROFILE_PATTERNS + BACKWARD_RECOMPUTE_PROFILE_PATTERNS - if pattern - not in { - "fabric.backward.total", - "fabric.backward.phase_decomposed_probe", - "fabric.backward.replay.output_sequence.boundary_inputs", - "fabric.backward.replay.output_sequence.packed_state_inputs", - "fabric.backward.replay.output_sequence.recurrent_carry_inputs", - "fabric.backward.replay.output_sequence.parameter_inputs", - "fabric.backward.replay.no_parameter_inputs", - "fabric.backward.replay.without_boundary_inputs", - "fabric.backward.replay.no_parameter_output_sequence.boundary_inputs", - "fabric.backward.replay.no_parameter_output_sequence.packed_state_inputs", - "fabric.backward.replay.no_parameter_output_sequence.recurrent_carry_inputs", - "fabric.backward.replay.output_sequence", - "fabric.backward.replay.next_packed_state", - "fabric.backward.replay.recurrent_carry", - "fabric.backward.replay.input_kv_last", - "fabric.backward.full_replay_autograd", - "fabric.backward.reference_replay_autograd", - } - ) -) -BACKWARD_OWNER_EVENT_SOURCE_PATTERNS = ( - "_FabricLocalMessageCUDABackward", - "_FabricLocalMessagePartitionedCUDABackward", - "_FabricSparseMessageCUDABackward", - "_FabricSparseMessagePartitionedCUDABackward", - "_FabricLocalMessageCUDA", - "_FabricLocalMessagePartitionedCUDA", - "_FabricSparseMessageCUDA", - "_FabricSparseMessagePartitionedCUDA", - "_FabricGroupedProjectionCUDABackward", - "_FabricGroupedProjectionCUDA", - "_ReceiverMajorAffineBmmFunctionBackward", - "_ReceiverMajorAffineBmmFunction", - "_ReceiverMajorAffineSmallBatchFunctionBackward", - "_ReceiverMajorAffineSmallBatchFunction", - "_ReceiverMajorAffineNoBiasSmallBatchFunctionBackward", - "_ReceiverMajorAffineNoBiasSmallBatchFunction", - "_ReceiverMajorAffineBiasFunctionBackward", - "_ReceiverMajorAffineBiasFunction", - "_SharedReceiverBiasLinearFunctionBackward", - "_SharedReceiverBiasLinearFunction", - "_PerReceiverLinearFunctionBackward", - "_PerReceiverLinearFunction", - "_HeadGroupedGateLinearFunctionBackward", - "_HeadGroupedGateLinearFunction", - "_RecurrentMatmulFunctionBackward", - "_RecurrentMatmulFunction", - "_PerCellOutnormFunctionBackward", - "_PerCellOutnormFunction", - "_GatedLogspaceRecurrenceOutnormFunctionBackward", - "_GatedLogspaceRecurrenceOutnormFunction", - "_GatedLogspaceRecurrenceOutnormCUDAFunctionBackward", - "_GatedLogspaceRecurrenceOutnormCUDAFunction", - "_ResetRowsManyFunctionBackward", - "_ResetRowsManyFunction", 
-) -BACKWARD_OWNER_PARENT_SOURCE_PATTERNS = BACKWARD_OWNER_EVENT_SOURCE_PATTERNS -BACKWARD_DERIVED_OWNER_RULES = ( - ( - "fabric.backward.derived.diagonal_recurrence", - (), - ( - "_LinearRTUFunctionDiag", - "DiagonalRecurrence", - "diagonal_recurrence", - ), - ), - ( - "fabric.backward.derived.lowered_projection", - ( - "aten::bmm", - "aten::mm", - "aten::matmul", - "aten::einsum", - ), - ( - "BmmBackward", - "MmBackward", - "AddmmBackward", - "MatmulBackward", - "EinsumBackward", - ), - ), - ( - "fabric.backward.derived.state_public_epilogue", - ( - "aten::where", - "aten::mul", - "aten::add", - "aten::add_", - "aten::sub", - "aten::div", - "aten::eq", - "aten::gt", - "aten::lt", - "aten::minimum", - "aten::maximum", - "aten::masked_fill_", - "aten::log_sigmoid", - "aten::log_sigmoid_backward", - "aten::tanh_backward", - "aten::sigmoid", - "aten::silu", - "aten::exp", - "aten::sqrt", - "aten::clamp", - "aten::cos", - "aten::neg", - ), - ( - "WhereBackward", - "MulBackward", - "AddBackward", - "SubBackward", - "DivBackward", - "MaximumBackward", - "MinimumBackward", - "ExpBackward", - "TanhBackward", - "SigmoidBackward", - "SiluBackward", - "LogSigmoidBackward", - ), - ), - ( - "fabric.backward.derived.boundary_glue", - ( - "aten::copy_", - "aten::clone", - "aten::cat", - "aten::stack", - "aten::fill_", - "aten::zero_", - "aten::zeros", - "aten::zeros_like", - "aten::empty", - "aten::empty_like", - "aten::empty_strided", - "aten::select_backward", - "aten::index_select", - "aten::index_select_backward", - "aten::scalar_tensor", - "aten::_local_scalar_dense", - ), - ( - "CloneBackward", - "SplitWithSizesBackward", - "UnbindBackward", - "SelectBackward", - "IndexBackward", - "IndexSelectBackward", - "ViewBackward", - "TransposeBackward", - "PermuteBackward", - ), - ), -) -BACKWARD_DERIVED_BOUNDARY_GLUE_PARENT_PREFIXES = ("fabric.glue.",) -BACKWARD_SMALL_CUBLAS_PROFILE_PATTERNS = ( - "gemv2N_kernel", - "xmma", - "cublas", - "gemm", -) -BACKWARD_ATTRIBUTION_CONTAINER_PREFIXES = ( - "fabric.backward.", - "fabric.physical.", - "fabric.glue.", -) -GLUE_PROFILE_KERNEL_PATTERNS = ( - "aten::copy_", - "Memcpy DtoD", - "direct_copy_kernel_cuda", - "add_receiver_bias_kernel", - "append_bias_feature_kernel", - "append_bias_weight_kernel", - "receiver_major_copy_or_pad_kernel", - "receiver_major_split_last_dim_kernel", - "aten::cat", - "CatArrayBatchedCopy", - "aten::clone", - "aten::fill_", -) -GLUE_COVERAGE_KERNEL_PATTERNS = ( - "aten::copy_", - "add_receiver_bias_kernel", - "append_bias_feature_kernel", - "append_bias_weight_kernel", - "receiver_major_copy_or_pad_kernel", - "receiver_major_split_last_dim_kernel", - "aten::cat", - "aten::fill_", -) -GLUE_SOURCE_PROFILE_PATTERNS = ( - "fabric.glue.launch_state_tree_contiguous", - "fabric.glue.launch_cell_param_contiguous", - "fabric.glue.launch_input_projection_param_contiguous", - "fabric.glue.launch_public_projection_param_contiguous", - "fabric.glue.launch_initial_public_contiguous", - "fabric.glue.launch_initial_public_project_from_hidden", - "fabric.glue.dense_affine_python_contiguous", - "fabric.glue.tensor_table_pack", - "fabric.glue.empty_tensor_table_pack", - "fabric.glue.build_local_topology", - "fabric.glue.build_sparse_topology", - "fabric.glue.output_readout_arg_contiguous", - "fabric.glue.output_readout_step_index", - "fabric.glue.population_param_materialization", - "fabric.glue.static_tensor_contiguous", - "fabric.glue.static_tensor_cat", - "fabric.glue.runtime_init_state", - "fabric.glue.backend_population_state_zero", - 
"fabric.glue.collect_cell_tensors_contiguous", - "fabric.glue.population_to_backend_state_contiguous", - "fabric.glue.backend_to_population_state_contiguous", - "fabric.glue.launch_initial_public_zero_kv", - "fabric.glue.launch_persistent_scan_initial_copy", - "fabric.glue.launch_message_param_contiguous", - "fabric.glue.launch_input_ports_contiguous", - "fabric.glue.launch_readout_param_contiguous", - "fabric.glue.launch_public_tree_contiguous", - "fabric.glue.launch_routing_tensor_contiguous", - "fabric.glue.launch_topology_contiguous", - "fabric.glue.launch_resets_contiguous", - "fabric.glue.default_resets_zero", - "fabric.glue.reset_mask_to_u8", - "fabric.glue.normalized_population_resets", - "fabric.glue.materialize_next_state_cat", - "fabric.glue.output_sequence_cat", - "fabric.glue.grouped_kv_weight_cat", - "fabric.glue.runtime_sparse_message_clone", - "fabric.glue.dense_affine_bias_add", - "fabric.glue.dense_affine_bias_input_pack", - "fabric.glue.dense_affine_bias_weight_pack", - "fabric.glue.dense_affine_reset_source_pack", - "fabric.glue.backward_replay_leaf_boundary", - "fabric.glue.backward_replay_leaf_packed_state", - "fabric.glue.backward_replay_leaf_recurrent_carry", - "fabric.glue.receiver_major_copy_or_pad", - "fabric.glue.receiver_major_split_last_dim", - "fabric.glue.public_projection_kv_workspace", - "fabric.glue.message_ragged_output_zero", - "fabric.glue.message_degree_ptr_cpu", - "fabric.non_owned.benchmark_output_head", -) -GLUE_COVERAGE_SOURCE_PATTERNS = tuple( - pattern - for pattern in GLUE_SOURCE_PROFILE_PATTERNS - if pattern - not in { - "fabric.glue.launch_initial_public_project_from_hidden", - "fabric.non_owned.benchmark_output_head", - } -) -OLD_MESSAGE_SELF_CUDA_MS = { - "slstm_2m_t32_forward": "431.941 ms receiver_message_aggregate_kernel", - "slstm_100m_t1_forward": "194.856 ms receiver_message_aggregate_kernel", - "axon_2m_t32_forward": "516.2 ms receiver_message_project_kernel", -} -RUNTIME_OP_PATTERNS = ("copy_", "cat", "clone", "zero", "zeros", "empty", "cudaLaunch", "cudaGraphLaunch") -GLUE_TOTAL_ROW_LIMIT = 0.10 -GLUE_SINGLE_SOURCE_ROW_LIMIT = 0.05 -GLUE_ATTRIBUTION_COVERAGE_LIMIT = 0.90 - - -@dataclass(frozen=True) -class RowSpec: - row_id: str - family: BackboneFamily - params_label: str - target_params: int - mode: SequenceMode - seq_len: int - - -@dataclass(frozen=True) -class EventSummary: - name: str - count: int - self_cuda_us: float - self_cpu_us: float - cuda_total_us: float - cpu_total_us: float - - -@dataclass(frozen=True) -class ProfileSummary: - row_id: str - side: Side - batch_size: int - use_resets: bool - status: str - tokens_per_s: float - ms: float - effective_batch_tile: int | None - launch_count_cuda_events: int - launch_count_runtime_events: int - top_cuda: tuple[EventSummary, ...] - top_cpu: tuple[EventSummary, ...] - runtime_ops: tuple[EventSummary, ...] - message_kernel_events: tuple[EventSummary, ...] - physical_op_events: tuple[EventSummary, ...] - backward_owner_events: tuple[EventSummary, ...] - backward_recompute_events: tuple[EventSummary, ...] - backward_derived_owner_events: tuple[EventSummary, ...] - backward_derived_owner_source_events: tuple[EventSummary, ...] - backward_unattributed_events: tuple[EventSummary, ...] - backward_small_cublas_events: tuple[EventSummary, ...] - backward_copy_glue_events: tuple[EventSummary, ...] 
- backward_owner_cuda_total_us: float - backward_owner_attributed_cuda_total_us: float - backward_explicit_owner_attributed_cuda_total_us: float - backward_derived_owner_attributed_cuda_total_us: float - backward_attribution_coverage: float - backward_explicit_attribution_coverage: float - backward_derived_attribution_coverage: float - backward_attribution_mode: str - glue_kernel_events: tuple[EventSummary, ...] - glue_source_events: tuple[EventSummary, ...] - glue_source_cuda_total_us: float - glue_kernel_self_cuda_us: float - glue_attribution_coverage: float - kernel_names: tuple[str, ...] - generic_kernel_patterns_present: tuple[str, ...] - legacy_kernel_patterns_present: tuple[str, ...] - backend_record: dict[str, object] | None - - -DEFAULT_ROWS: tuple[RowSpec, ...] = ( - RowSpec("slstm_2m_t1_forward", "slstm", "2M", 2_000_000, "forward", 1), - RowSpec("slstm_2m_t32_forward", "slstm", "2M", 2_000_000, "forward", 32), - RowSpec("slstm_2m_t32_train", "slstm", "2M", 2_000_000, "forward_backward", 32), - RowSpec("axon_2m_t32_forward", "axoncell", "2M", 2_000_000, "forward", 32), - RowSpec("axon_2m_t32_train", "axoncell", "2M", 2_000_000, "forward_backward", 32), - RowSpec("slstm_100m_t1_forward", "slstm", "100M", 100_000_000, "forward", 1), - RowSpec("slstm_100m_t1_train", "slstm", "100M", 100_000_000, "forward_backward", 1), - RowSpec("axon_100m_t1_forward", "axoncell", "100M", 100_000_000, "forward", 1), - RowSpec("axon_100m_t1_train", "axoncell", "100M", 100_000_000, "forward_backward", 1), - RowSpec("slstm_500m_t1_forward", "slstm", "500M", 500_000_000, "forward", 1), - RowSpec("axon_500m_t1_forward", "axoncell", "500M", 500_000_000, "forward", 1), -) -LARGE_PARAMETER_ROWS: tuple[RowSpec, ...] = ( - RowSpec("slstm_1b_t1_forward", "slstm", "1B", 1_000_000_000, "forward", 1), - RowSpec("slstm_2b_t1_forward", "slstm", "2B", 2_000_000_000, "forward", 1), - RowSpec("slstm_4b_t1_forward", "slstm", "4B", 4_000_000_000, "forward", 1), - RowSpec("slstm_7b_t1_forward", "slstm", "7B", 7_000_000_000, "forward", 1), - RowSpec("axon_1b_t1_forward", "axoncell", "1B", 1_000_000_000, "forward", 1), - RowSpec("axon_2b_t1_forward", "axoncell", "2B", 2_000_000_000, "forward", 1), - RowSpec("axon_4b_t1_forward", "axoncell", "4B", 4_000_000_000, "forward", 1), - RowSpec("axon_7b_t1_forward", "axoncell", "7B", 7_000_000_000, "forward", 1), - RowSpec("slstm_500m_t1_train", "slstm", "500M", 500_000_000, "forward_backward", 1), - RowSpec("slstm_1b_t1_train", "slstm", "1B", 1_000_000_000, "forward_backward", 1), - RowSpec("slstm_2b_t1_train", "slstm", "2B", 2_000_000_000, "forward_backward", 1), - RowSpec("slstm_1b_t32_train", "slstm", "1B", 1_000_000_000, "forward_backward", 32), - RowSpec("axon_500m_t1_train", "axoncell", "500M", 500_000_000, "forward_backward", 1), - RowSpec("axon_1b_t1_train", "axoncell", "1B", 1_000_000_000, "forward_backward", 1), - RowSpec("axon_2b_t1_train", "axoncell", "2B", 2_000_000_000, "forward_backward", 1), - RowSpec("axon_1b_t32_train", "axoncell", "1B", 1_000_000_000, "forward_backward", 32), -) -ROWS: tuple[RowSpec, ...] = DEFAULT_ROWS + LARGE_PARAMETER_ROWS - - -def _parse_batches(value: str) -> tuple[int, ...]: - return tuple(int(part) for part in value.split(",") if part) - - -def _parse_optional_batches(value: str | None) -> tuple[int, ...] 
| None: - if value is None or not value.strip(): - return None - return _parse_batches(value) - - -def _row_by_id(row_id: str) -> RowSpec: - for row in ROWS: - if row.row_id == row_id: - return row - raise ValueError(f"Unknown row_id={row_id}") - - -def _set_seed(seed: int) -> None: - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) - - -def _dtype_name(dtype: torch.dtype) -> str: - return str(dtype).removeprefix("torch.") - - -def _device_name(device: torch.device) -> str: - if device.type != "cuda": - return str(device) - return torch.cuda.get_device_name(device) - - -@contextmanager -def _temporary_env_var(name: str, value: str | None): - old_value = os.environ.get(name) - if value is None: - os.environ.pop(name, None) - else: - os.environ[name] = value - try: - yield - finally: - if old_value is None: - os.environ.pop(name, None) - else: - os.environ[name] = old_value - - -def _required_tool_path(tool: str) -> str: - path = shutil.which(tool) - if path is None: - raise RuntimeError( - f"Required CUDA profiling tool `{tool}` was not found on PATH. " - "Install the CUDA Toolkit / Nsight tools on the machine; this is not a Python package dependency." - ) - return path - - -def _nvml_device_summary(device: torch.device) -> dict[str, object]: - if device.type != "cuda": - raise RuntimeError("Fabric CUDA profiling requires a CUDA device") - pynvml.nvmlInit() - handle = pynvml.nvmlDeviceGetHandleByIndex(device.index or 0) - memory = pynvml.nvmlDeviceGetMemoryInfo(handle) - return { - "driver_version": pynvml.nvmlSystemGetDriverVersion(), - "memory_total_mib": int(memory.total // (1024 * 1024)), - } - - -def _matches( - row: RowSpec, - *, - fabric_hidden_grid: tuple[int, ...] | None = None, -) -> tuple[MatchedBackbone, MatchedBackbone]: - stack_match = find_param_matched_backbone( - target_params=row.target_params, - kind="stack", - family=row.family, - ) - fabric_match = find_param_matched_backbone( - target_params=stack_match.actual_params, - kind="fabric", - family=row.family, - forced_d_hidden=stack_match.d_hidden, - fabric_hidden_grid=fabric_hidden_grid, - ) - return stack_match, fabric_match - - -def _record_to_json(record: BackendExecutionRecord) -> dict[str, object]: - return { - "backend_name": record.backend_name, - "surface_key": record.surface_key, - "cell_type": record.cell_type, - "regime": record.regime, - "training": record.training, - "batch_size": record.batch_size, - "time_steps": record.time_steps, - "inner_steps": record.inner_steps, - "bucket_ids": record.bucket_ids, - "execution_families": record.execution_families, - "math_backends": record.math_backends, - "tape_policy_bin": record.tape_policy_bin, - "graph_capture_enabled": record.graph_capture_enabled, - "graph_capture_cache_hit": record.graph_capture_cache_hit, - "graph_capture_replayed": record.graph_capture_replayed, - "capability_variants": record.capability_variants, - "large_r_safety_modes": record.large_r_safety_modes, - "active_cell_tiling_plans": record.active_cell_tiling_plans, - "large_r_diagnostics": record.large_r_diagnostics, - "requested_launch_receiver_tiles": record.requested_launch_receiver_tiles, - "requested_launch_batch_tiles": record.requested_launch_batch_tiles, - "requested_launch_edge_tiles": record.requested_launch_edge_tiles, - "requested_launch_hidden_chunks": record.requested_launch_hidden_chunks, - "requested_launch_state_receiver_tiles": record.requested_launch_state_receiver_tiles, - "requested_launch_state_batch_tiles": 
record.requested_launch_state_batch_tiles, - "requested_launch_state_hidden_chunks": record.requested_launch_state_hidden_chunks, - "requested_launch_state_static_stage_modes": record.requested_launch_state_static_stage_modes, - "requested_launch_emit_receiver_tiles": record.requested_launch_emit_receiver_tiles, - "requested_launch_emit_batch_tiles": record.requested_launch_emit_batch_tiles, - "requested_launch_emit_hidden_chunks": record.requested_launch_emit_hidden_chunks, - "requested_launch_emit_static_stage_modes": record.requested_launch_emit_static_stage_modes, - "requested_launch_public_receiver_tiles": record.requested_launch_public_receiver_tiles, - "requested_launch_public_batch_tiles": record.requested_launch_public_batch_tiles, - "requested_launch_replication_factors": record.requested_launch_replication_factors, - "requested_launch_cell_static_stage_modes": record.requested_launch_cell_static_stage_modes, - "requested_launch_readout_modes": record.requested_launch_readout_modes, - "actual_launch_receiver_tiles": record.actual_launch_receiver_tiles, - "actual_launch_batch_tiles": record.actual_launch_batch_tiles, - "actual_launch_edge_tiles": record.actual_launch_edge_tiles, - "actual_launch_hidden_chunks": record.actual_launch_hidden_chunks, - "actual_launch_state_receiver_tiles": record.actual_launch_state_receiver_tiles, - "actual_launch_state_batch_tiles": record.actual_launch_state_batch_tiles, - "actual_launch_state_hidden_chunks": record.actual_launch_state_hidden_chunks, - "actual_launch_state_static_stage_modes": record.actual_launch_state_static_stage_modes, - "actual_launch_emit_receiver_tiles": record.actual_launch_emit_receiver_tiles, - "actual_launch_emit_batch_tiles": record.actual_launch_emit_batch_tiles, - "actual_launch_emit_hidden_chunks": record.actual_launch_emit_hidden_chunks, - "actual_launch_emit_static_stage_modes": record.actual_launch_emit_static_stage_modes, - "actual_launch_public_receiver_tiles": record.actual_launch_public_receiver_tiles, - "actual_launch_public_batch_tiles": record.actual_launch_public_batch_tiles, - "actual_launch_replication_factors": record.actual_launch_replication_factors, - "actual_launch_cell_static_stage_modes": record.actual_launch_cell_static_stage_modes, - "actual_launch_readout_modes": record.actual_launch_readout_modes, - "launch_temporal_executions": record.launch_temporal_executions, - "launch_scan_implementations": record.launch_scan_implementations, - "launch_phases": record.launch_phases, - "active_receiver_window_modes": record.active_receiver_window_modes, - "active_receiver_window_offsets": record.active_receiver_window_offsets, - "active_receiver_window_counts": record.active_receiver_window_counts, - "input_projection_backends": record.input_projection_backends, - "input_projection_notes": record.input_projection_notes, - "message_projection_boundaries": record.message_projection_boundaries, - "message_projection_bucket_kinds": record.message_projection_bucket_kinds, - "message_bucket_count": record.message_bucket_count, - "message_regular_local_bucket_count": record.message_regular_local_bucket_count, - "message_sparse_bucket_count": record.message_sparse_bucket_count, - "message_batched_backend_count": record.message_batched_backend_count, - "message_grouped_backend_count": record.message_grouped_backend_count, - "message_reset_aware_bucket_count": record.message_reset_aware_bucket_count, - "message_degree_uniform_bucket_count": record.message_degree_uniform_bucket_count, - 
"message_ragged_grouped_bucket_count": record.message_ragged_grouped_bucket_count, - "message_demoted_bucket_count": record.message_demoted_bucket_count, - "message_bucket_signatures": record.message_bucket_signatures, - "message_bucket_kinds": record.message_bucket_kinds, - "message_topology_kinds": record.message_topology_kinds, - "message_spatial_ownership": record.message_spatial_ownership, - "message_degree_bucket_lists": record.message_degree_bucket_lists, - "message_logit_backends": record.message_logit_backends, - "message_softmax_backends": record.message_softmax_backends, - "message_weighted_value_backends": record.message_weighted_value_backends, - "message_physical_mode": record.message_physical_mode, - "message_execution_mode": record.message_execution_mode, - "message_output_boundary": record.message_output_boundary, - "message_reset_policies": record.message_reset_policies, - "message_reset_scopes": record.message_reset_scopes, - "message_use_delay": record.message_use_delay, - "message_distance_penalty_kinds": record.message_distance_penalty_kinds, - "message_epilogue_kinds": record.message_epilogue_kinds, - "message_packed_source_reuse_count": record.message_packed_source_reuse_count, - "message_demotions": record.message_demotions, - "message_workspace_buffers": record.message_workspace_buffers, - "message_workspace_buffer_bytes": record.message_workspace_buffer_bytes, - "message_workspace_peak_bytes": record.message_workspace_peak_bytes, - "message_workspace_mode": record.message_workspace_mode, - "message_workspace_aliases": record.message_workspace_aliases, - "message_per_bucket_workspace_bytes": record.message_per_bucket_workspace_bytes, - "state_affine_backends": record.state_affine_backends, - "state_affine_sources": record.state_affine_sources, - "state_affine_bucket_signatures": record.state_affine_bucket_signatures, - "state_affine_output_modes": record.state_affine_output_modes, - "state_affine_reset_policies": record.state_affine_reset_policies, - "state_affine_reset_mode": record.state_affine_reset_mode, - "state_affine_reset_scope": record.state_affine_reset_scope, - "state_affine_workspace_mode": record.state_affine_workspace_mode, - "state_affine_receiver_chunk_size": record.state_affine_receiver_chunk_size, - "state_affine_receiver_chunks": record.state_affine_receiver_chunks, - "state_affine_workspace_buffers": record.state_affine_workspace_buffers, - "state_affine_workspace_buffer_bytes": record.state_affine_workspace_buffer_bytes, - "state_affine_workspace_bytes": record.state_affine_workspace_bytes, - "state_affine_reset_rows_present": record.state_affine_reset_rows_present, - "state_affine_packed_source_reused": record.state_affine_packed_source_reused, - "public_projection_hidden_backends": record.public_projection_hidden_backends, - "public_projection_kv_backends": record.public_projection_kv_backends, - "readout_projection_backends": record.readout_projection_backends, - "workspace_buffers": record.workspace_buffers, - "workspace_buffer_bytes": record.workspace_buffer_bytes, - "workspace_peak_bytes": record.workspace_peak_bytes, - "workspace_aliases": record.workspace_aliases, - "phase_launch_counts": record.phase_launch_counts, - "small_cublas_launch_counts": record.small_cublas_launch_counts, - "copy_glue_launch_counts": record.copy_glue_launch_counts, - "copy_glue_saved_launch_counts": record.copy_glue_saved_launch_counts, - "state_epilogue_modes": record.state_epilogue_modes, - "state_epilogue_saved_launch_counts": 
record.state_epilogue_saved_launch_counts, - "launch_coalescing_modes": record.launch_coalescing_modes, - "generic_glue_fusion_modes": record.generic_glue_fusion_modes, - "launch_granularity_modes": record.launch_granularity_modes, - "physical_op_kinds": record.physical_op_kinds, - "physical_layout_contracts": record.physical_layout_contracts, - "layout_mode": record.layout_mode, - "copy_elision_mode": record.copy_elision_mode, - "bias_fusion_mode": record.bias_fusion_mode, - "physical_op_executors": record.physical_op_executors, - "physical_op_demotions": record.physical_op_demotions, - "physical_boundary_contracts": record.physical_boundary_contracts, - "physical_applicability_predicates": record.physical_applicability_predicates, - "physical_workspace_aliases": record.physical_workspace_aliases, - "physical_workspace_peak_bytes": record.physical_workspace_peak_bytes, - "physical_op_launch_counts": record.physical_op_launch_counts, - "physical_op_saved_launch_counts": record.physical_op_saved_launch_counts, - "standalone_copy_kernel_count": record.standalone_copy_kernel_count, - "standalone_bias_kernel_count": record.standalone_bias_kernel_count, - "receiver_affine_superop_surface_count": record.receiver_affine_superop_surface_count, - "receiver_affine_superop_receivers": record.receiver_affine_superop_receivers, - "receiver_affine_superop_k": record.receiver_affine_superop_k, - "receiver_affine_superop_n": record.receiver_affine_superop_n, - "receiver_affine_superop_source_layout": record.receiver_affine_superop_source_layout, - "receiver_affine_superop_reset_policy": record.receiver_affine_superop_reset_policy, - "receiver_affine_superop_executor": record.receiver_affine_superop_executor, - "receiver_affine_superop_physical_mode": record.receiver_affine_superop_physical_mode, - "receiver_affine_superop_demotion_reason": record.receiver_affine_superop_demotion_reason, - "diagonal_recurrence_superop_surface_count": record.diagonal_recurrence_superop_surface_count, - "diagonal_recurrence_kind": record.diagonal_recurrence_kind, - "diagonal_recurrence_executor": record.diagonal_recurrence_executor, - "diagonal_recurrence_physical_mode": record.diagonal_recurrence_physical_mode, - "diagonal_recurrence_coeff_cache_mode": record.diagonal_recurrence_coeff_cache_mode, - "diagonal_recurrence_coeff_cache_hit": record.diagonal_recurrence_coeff_cache_hit, - "diagonal_recurrence_coeff_cache_bytes": record.diagonal_recurrence_coeff_cache_bytes, - "diagonal_recurrence_coeff_cache_version_source": record.diagonal_recurrence_coeff_cache_version_source, - "diagonal_recurrence_reset_policy": record.diagonal_recurrence_reset_policy, - "diagonal_recurrence_reset_scope": record.diagonal_recurrence_reset_scope, - "diagonal_recurrence_output_boundary": record.diagonal_recurrence_output_boundary, - "diagonal_recurrence_workspace_mode": record.diagonal_recurrence_workspace_mode, - "diagonal_recurrence_workspace_peak_bytes": record.diagonal_recurrence_workspace_peak_bytes, - "diagonal_recurrence_demotion_reason": record.diagonal_recurrence_demotion_reason, - "diagonal_recurrence_launch_count": record.diagonal_recurrence_launch_count, - "backward_receiver_execution_families": record.backward_receiver_execution_families, - "backward_receiver_math_backends": record.backward_receiver_math_backends, - "backward_sender_execution_families": record.backward_sender_execution_families, - "backward_sender_math_backends": record.backward_sender_math_backends, - "backward_affine_bucket_signatures": 
record.backward_affine_bucket_signatures, - "backward_affine_forward_backends": record.backward_affine_forward_backends, - "backward_affine_input_grad_backends": record.backward_affine_input_grad_backends, - "backward_affine_weight_grad_backends": record.backward_affine_weight_grad_backends, - "backward_affine_bias_grad_backends": record.backward_affine_bias_grad_backends, - "backward_affine_demotion_reasons": record.backward_affine_demotion_reasons, - "backward_affine_execution_modes": record.backward_affine_execution_modes, - "backward_physical_op_kinds": record.backward_physical_op_kinds, - "backward_physical_op_executors": record.backward_physical_op_executors, - "backward_physical_op_demotions": record.backward_physical_op_demotions, - "backward_boundary_contracts": record.backward_boundary_contracts, - "backward_layout_mode": record.backward_layout_mode, - "backward_workspace_aliases": record.backward_workspace_aliases, - "backward_workspace_peak_bytes": record.backward_workspace_peak_bytes, - "backward_tape_mode": record.backward_tape_mode, - "backward_recompute_mode": record.backward_recompute_mode, - "backward_launch_counts": record.backward_launch_counts, - "backward_saved_launch_counts": record.backward_saved_launch_counts, - "backward_owner_timing_ms": record.backward_owner_timing_ms, - "backward_owner_wall_ms": record.backward_owner_wall_ms, - "backward_residual_glue_demotions": record.backward_residual_glue_demotions, - } - - -def _model_backend_record(model: torch.nn.Module) -> dict[str, object] | None: - backbone = model.backbone - record = getattr(backbone, "last_backend_execution", None) - if record is None: - return None - return _record_to_json(record) - - -def _shape_json(match: MatchedBackbone) -> dict[str, object]: - return { - "kind": match.kind, - "family": match.family, - "target_params": match.target_params, - "actual_params": match.actual_params, - "d_hidden": match.d_hidden, - "num_layers": match.num_layers, - "fabric_shape": match.fabric_shape, - "fabric_hidden_size": match.fabric_hidden_size, - } - - -def _param_gap_pct(*, actual: int, reference: int) -> float: - if reference == 0: - return 0.0 - return 100.0 * (float(actual) - float(reference)) / float(reference) - - -def _match_note(row: RowSpec, stack_match: MatchedBackbone, fabric_match: MatchedBackbone) -> str: - stack_target_gap = _param_gap_pct(actual=int(stack_match.actual_params), reference=int(row.target_params)) - fabric_stack_gap = _param_gap_pct(actual=int(fabric_match.actual_params), reference=int(stack_match.actual_params)) - status = "matched" if abs(fabric_stack_gap) <= 10.0 else "stand_in" - return f"{status};stack_vs_target={stack_target_gap:+.1f}%;fabric_vs_stack={fabric_stack_gap:+.1f}%" - - -def _status_value(result: dict[str, object], key: str) -> object | None: - if result["status"] != "ok": - return None - return result[key] - - -def _error_value(result: dict[str, object]) -> object | None: - if "error" not in result: - return None - return result["error"] - - -def _plateau_detected(rows: list[dict[str, object]], side: Side) -> bool: - values = [ - float(row[f"{side}_tokens_per_s"]) - for row in rows - if row[f"{side}_status"] == "ok" and row[f"{side}_tokens_per_s"] is not None - ] - if len(values) < 3: - return False - gains = [ - (values[-2] - values[-3]) / values[-3] if values[-3] > 0.0 else math.inf, - (values[-1] - values[-2]) / values[-2] if values[-2] > 0.0 else math.inf, - ] - return gains[0] < 0.05 and gains[1] < 0.05 - - -def _plateau_note(rows: list[dict[str, object]], 
-def _plateau_note(rows: list[dict[str, object]], side: Side) -> str:
-    if any(row[f"{side}_status"] != "ok" for row in rows):
-        failed = next(row for row in rows if row[f"{side}_status"] != "ok")
-        return f"{failed[f'{side}_status']} at B={failed['batch_size']}"
-    if _plateau_detected(rows, side):
-        return "plateau: last two gains <5%"
-    return "not proven"
-
-
-def _row_batch_gain(rows: list[dict[str, object]], side: Side, index: int) -> float | None:
-    if index == 0:
-        return None
-    previous = rows[index - 1]
-    current = rows[index]
-    if previous[f"{side}_status"] != "ok" or current[f"{side}_status"] != "ok":
-        return None
-    previous_value = float(previous[f"{side}_tokens_per_s"])
-    current_value = float(current[f"{side}_tokens_per_s"])
-    return (current_value - previous_value) / previous_value if previous_value > 0.0 else math.inf
-
-
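These helpers feed the sweep runner that follows: it walks the requested batch sizes, then keeps doubling the batch until a row fails, throughput plateaus, or `max_batch` is exceeded. A control-flow-only sketch with the measurement stubbed out (the helper names below are assumptions for the sketch, not the script's API):

```python
# Sketch of the batch-doubling schedule; `measure` and `plateaued` are stubs.
from collections.abc import Callable


def batch_schedule(
    batches: tuple[int, ...],
    max_batch: int,
    measure: Callable[[int], float | None],
    plateaued: Callable[[list[float]], bool],
) -> list[int]:
    planned = list(batches)
    next_batch = max(batches) * 2
    visited: list[int] = []
    values: list[float] = []
    while planned:
        batch = planned.pop(0)
        visited.append(batch)
        value = measure(batch)  # None stands in for an OOM / failed row
        if value is not None:
            values.append(value)
        if batch >= max(batches):
            if value is None or plateaued(values):
                break
            if next_batch <= max_batch:
                planned.append(next_batch)
                next_batch *= 2
    return visited


# Doubles 8 -> 16 -> 32 -> 64, then stops because 128 would exceed max_batch.
print(batch_schedule((8, 16), 64, lambda b: float(b), lambda vals: False))
```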
-def _run_sweep_row(
-    *,
-    row: RowSpec,
-    device: torch.device,
-    dtype: torch.dtype,
-    warmup: int,
-    iterations: int,
-    batches: tuple[int, ...],
-    max_batch: int,
-    seed: int,
-    fabric_hidden_grid: tuple[int, ...] | None,
-) -> dict[str, object]:
-    _set_seed(seed)
-    stack_match, fabric_match = _matches(row, fabric_hidden_grid=fabric_hidden_grid)
-    rows: list[dict[str, object]] = []
-    planned_batches = list(batches)
-    next_batch = max(batches) * 2
-    while planned_batches:
-        batch_size = planned_batches.pop(0)
-        if device.type == "cuda":
-            torch.cuda.empty_cache()
-        _set_seed(seed + batch_size + 1)
-        stack_model = make_sequence_model(stack_match, device=device, dtype=dtype).eval()
-        stack_result = run_sequence_case(
-            match=stack_match,
-            mode=row.mode,
-            batch_size=batch_size,
-            seq_len=row.seq_len,
-            family=row.family,
-            device=device,
-            dtype=dtype,
-            warmup=warmup,
-            iterations=iterations,
-            model=stack_model,
-        )
-        del stack_model
-        if device.type == "cuda":
-            torch.cuda.empty_cache()
-        _set_seed(seed + batch_size + 2)
-        fabric_model = make_sequence_model(fabric_match, device=device, dtype=dtype).eval()
-        fabric_result = run_sequence_case(
-            match=fabric_match,
-            mode=row.mode,
-            batch_size=batch_size,
-            seq_len=row.seq_len,
-            family=row.family,
-            device=device,
-            dtype=dtype,
-            warmup=warmup,
-            iterations=iterations,
-            model=fabric_model,
-        )
-        fabric_backend_record = _model_backend_record(fabric_model)
-        del fabric_model
-        if device.type == "cuda":
-            torch.cuda.empty_cache()
-        ratio = None
-        if stack_result["status"] == "ok" and fabric_result["status"] == "ok":
-            ratio = float(fabric_result["tokens_per_s"]) / float(stack_result["tokens_per_s"])
-        sweep_row: dict[str, object] = {
-            "row_id": row.row_id,
-            "family": row.family,
-            "params_label": row.params_label,
-            "mode": row.mode,
-            "seq_len": row.seq_len,
-            "batch_size": batch_size,
-            "stack_status": stack_result["status"],
-            "fabric_status": fabric_result["status"],
-            "stack_error": _error_value(stack_result),
-            "fabric_error": _error_value(fabric_result),
-            "stack_ms": _status_value(stack_result, "ms"),
-            "fabric_ms": _status_value(fabric_result, "ms"),
-            "stack_tokens_per_s": _status_value(stack_result, "tokens_per_s"),
-            "fabric_tokens_per_s": _status_value(fabric_result, "tokens_per_s"),
-            "fabric_stack_ratio": ratio,
-            "stack_peak_mem_gib": _status_value(stack_result, "peak_mem_gib"),
-            "fabric_peak_mem_gib": _status_value(fabric_result, "peak_mem_gib"),
-            "fabric_backend_record": fabric_backend_record,
-        }
-        rows.append(sweep_row)
-        if batch_size >= max(batches):
-            terminal_status = stack_result["status"] != "ok" or fabric_result["status"] != "ok"
-            terminal_plateau = _plateau_detected(rows, "stack") or _plateau_detected(rows, "fabric")
-            if terminal_status or terminal_plateau:
-                break
-            if next_batch <= max_batch:
-                planned_batches.append(next_batch)
-                next_batch *= 2
-    for index, sweep_row in enumerate(rows):
-        sweep_row["stack_gain_from_previous_batch"] = _row_batch_gain(rows, "stack", index)
-        sweep_row["fabric_gain_from_previous_batch"] = _row_batch_gain(rows, "fabric", index)
-    if device.type == "cuda":
-        torch.cuda.empty_cache()
-    return {
-        "row_id": row.row_id,
-        "family": row.family,
-        "params_label": row.params_label,
-        "target_params": row.target_params,
-        "mode": row.mode,
-        "seq_len": row.seq_len,
-        "stack_match": _shape_json(stack_match),
-        "fabric_match": _shape_json(fabric_match),
-        "match_note": _match_note(row, stack_match, fabric_match),
-        "stack_plateau_note": _plateau_note(rows, "stack"),
-        "fabric_plateau_note": _plateau_note(rows, "fabric"),
-        "rows": rows,
-    }
-
-
-def _run_profile_iteration(
-    *,
-    model: torch.nn.Module,
-    x: torch.Tensor,
-    resets: torch.Tensor | None = None,
-    target: torch.Tensor | None,
-    optimizer: torch.optim.Optimizer | None,
-) -> None:
-    if target is None:
-        with torch.inference_mode():
-            kwargs: dict[str, object] = {"materialize_final_state": False}
-            if resets is not None:
-                kwargs["resets"] = resets
-            model(x, **kwargs)
-        return
-    if optimizer is None:
-        raise RuntimeError("Training profile requires an optimizer")
-    output_boundary = "terminal" if target is not None and target.dim() + 1 == x.dim() else "sequence"
-    optimizer.zero_grad(set_to_none=True)
-    kwargs: dict[str, object] = {"resets": resets}
-    if hasattr(getattr(model, "backbone", None), "_forward_sequence_with_readout"):
-        kwargs["output_boundary"] = output_boundary
-    y, _ = model(x, **kwargs)
-    loss = sequence_training_loss(y, target)
-    loss.backward()
-    optimizer.step()
-
-
-def _profile_events(prof: torch.profiler.profile, *, key: str, limit: int = 20) -> tuple[EventSummary, ...]:
-    events = prof.key_averages()
-    if key == "cuda":
-        sorted_events = sorted(events, key=lambda event: event.self_device_time_total, reverse=True)
-    elif key == "cpu":
-        sorted_events = sorted(events, key=lambda event: event.self_cpu_time_total, reverse=True)
-    else:
-        sorted_events = [
-            event
-            for event in events
-            if any(pattern in event.key for pattern in RUNTIME_OP_PATTERNS)
-            and (event.self_cpu_time_total > 0.0 or event.self_device_time_total > 0.0)
-        ]
-        sorted_events = sorted(
-            sorted_events,
-            key=lambda event: event.self_device_time_total + event.self_cpu_time_total,
-            reverse=True,
-        )
-    return tuple(
-        EventSummary(
-            name=event.key,
-            count=int(event.count),
-            self_cuda_us=float(event.self_device_time_total),
-            self_cpu_us=float(event.self_cpu_time_total),
-            cuda_total_us=float(event.device_time_total),
-            cpu_total_us=float(event.cpu_time_total),
-        )
-        for event in sorted_events[:limit]
-        if event.self_cpu_time_total > 0.0 or event.self_device_time_total > 0.0
-    )
-
-
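The per-pattern helpers that follow bucket profiler events by substring match on the event key and sum their self times per pattern. A toy sketch of that aggregation over stand-in event objects (the real code reads `torch.profiler` key averages; the names here are invented for the sketch):

```python
# Rough sketch of per-pattern self-CUDA aggregation over stubbed events.
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyEvent:  # hypothetical stand-in for a profiler event average
    key: str
    count: int
    self_cuda_us: float


def sum_by_pattern(events: list[ToyEvent], patterns: tuple[str, ...]) -> dict[str, float]:
    totals = {pattern: 0.0 for pattern in patterns}
    for event in events:
        for pattern in patterns:
            if pattern in event.key:
                totals[pattern] += event.self_cuda_us
    return totals


events = [
    ToyEvent("fabric.physical.receiver_affine", 12, 480.0),
    ToyEvent("fabric.physical.message", 12, 300.0),
    ToyEvent("aten::copy_", 40, 90.0),
]
print(sum_by_pattern(events, ("fabric.physical.", "aten::copy_")))
# {'fabric.physical.': 780.0, 'aten::copy_': 90.0}
```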
-def _profile_events_for_patterns(
-    prof: torch.profiler.profile,
-    patterns: tuple[str, ...],
-) -> tuple[EventSummary, ...]:
-    events = prof.key_averages()
-    summaries: list[EventSummary] = []
-    for pattern in patterns:
-        matched_events = [event for event in events if pattern in event.key]
-        summaries.append(
-            EventSummary(
-                name=pattern,
-                count=sum(int(event.count) for event in matched_events),
-                self_cuda_us=sum(float(event.self_device_time_total) for event in matched_events),
-                self_cpu_us=sum(float(event.self_cpu_time_total) for event in matched_events),
-                cuda_total_us=sum(float(event.device_time_total) for event in matched_events),
-                cpu_total_us=sum(float(event.cpu_time_total) for event in matched_events),
-            )
-        )
-    return tuple(summaries)
-
-
-def _profile_events_for_patterns_with_parent(
-    prof: torch.profiler.profile,
-    patterns: tuple[str, ...],
-    parent_patterns: tuple[str, ...],
-) -> tuple[EventSummary, ...]:
-    summaries: list[EventSummary] = []
-    for pattern in patterns:
-        matched_events = [
-            event
-            for event in prof.events()
-            if pattern in str(event.key) and any(parent in parent_patterns for parent in _parent_keys(event))
-        ]
-        summaries.append(
-            EventSummary(
-                name=pattern,
-                count=len(matched_events),
-                self_cuda_us=sum(float(event.self_device_time_total) for event in matched_events),
-                self_cpu_us=sum(float(event.self_cpu_time_total) for event in matched_events),
-                cuda_total_us=sum(float(event.device_time_total) for event in matched_events),
-                cpu_total_us=sum(float(event.cpu_time_total) for event in matched_events),
-            )
-        )
-    return tuple(summaries)
-
-
-def _event_matches_patterns(name: str, patterns: tuple[str, ...]) -> bool:
-    return any(pattern in name for pattern in patterns)
-
-
-def _parent_keys(event: object) -> tuple[str, ...]:
-    keys: list[str] = []
-    parent = getattr(event, "cpu_parent", None)
-    while parent is not None:
-        key = getattr(parent, "key", None)
-        if isinstance(key, str):
-            keys.append(key)
-        parent = getattr(parent, "cpu_parent", None)
-    return tuple(keys)
-
-
-def _is_backward_explicit_owner_attributed(event_key: str, parents: tuple[str, ...]) -> bool:
-    if any(parent_key in BACKWARD_OWNER_SOURCE_PATTERNS for parent_key in parents):
-        return True
-    if _event_matches_patterns(event_key, BACKWARD_OWNER_EVENT_SOURCE_PATTERNS):
-        return True
-    if any(_event_matches_patterns(parent_key, BACKWARD_OWNER_PARENT_SOURCE_PATTERNS) for parent_key in parents):
-        return True
-    return False
-
-
-def _is_backward_owner_attributed(event_key: str, parents: tuple[str, ...]) -> bool:
-    return _is_backward_explicit_owner_attributed(event_key, parents) or (
-        _backward_derived_owner(event_key, parents) is not None
-    )
-
-
-def _backward_derived_owner(event_key: str, parents: tuple[str, ...]) -> str | None:
-    if "fabric.backward.total" not in parents:
-        return None
-    if any(parent_key.startswith(BACKWARD_DERIVED_BOUNDARY_GLUE_PARENT_PREFIXES) for parent_key in parents):
-        return "fabric.backward.derived.boundary_glue"
-    for owner, event_patterns, parent_patterns in BACKWARD_DERIVED_OWNER_RULES:
-        if _event_matches_patterns(event_key, event_patterns):
-            return owner
-        if any(_event_matches_patterns(parent_key, parent_patterns) for parent_key in parents):
-            return owner
-    return None
-
-
-def _glue_parent_attribution(prof: torch.profiler.profile) -> tuple[float, float]:
-    denominator_us = 0.0
-    attributed_us = 0.0
-    for event in prof.events():
-        if not _event_matches_patterns(str(event.key), GLUE_COVERAGE_KERNEL_PATTERNS):
-            continue
-        self_cuda_us = float(event.self_device_time_total)
-        if self_cuda_us <= 0.0:
-            continue
-        denominator_us += self_cuda_us
-        if any(parent_key in GLUE_COVERAGE_SOURCE_PATTERNS for parent_key in _parent_keys(event)):
-            attributed_us += self_cuda_us
-    return attributed_us, denominator_us
-
-
-def _backward_parent_attribution(prof: torch.profiler.profile) -> tuple[float, float, float, float]:
-    denominator_us = 0.0
-    attributed_us = 0.0
-    explicit_us = 0.0
-    derived_us = 0.0
-    for event in prof.events():
-        event_key = str(event.key)
-        if event_key.startswith(BACKWARD_ATTRIBUTION_CONTAINER_PREFIXES):
-            continue
-        self_cuda_us = float(event.self_device_time_total)
-        if
self_cuda_us <= 0.0: - continue - parents = _parent_keys(event) - if "fabric.backward.total" not in parents: - continue - denominator_us += self_cuda_us - if _is_backward_explicit_owner_attributed(event_key, parents): - attributed_us += self_cuda_us - explicit_us += self_cuda_us - continue - if _backward_derived_owner(event_key, parents) is not None: - attributed_us += self_cuda_us - derived_us += self_cuda_us - return attributed_us, explicit_us, derived_us, denominator_us - - -def _backward_unattributed_events(prof: torch.profiler.profile, *, limit: int = 12) -> tuple[EventSummary, ...]: - by_key: dict[str, dict[str, float]] = {} - for event in prof.events(): - event_key = str(event.key) - if event_key.startswith(BACKWARD_ATTRIBUTION_CONTAINER_PREFIXES): - continue - self_cuda_us = float(event.self_device_time_total) - if self_cuda_us <= 0.0: - continue - parents = _parent_keys(event) - if "fabric.backward.total" not in parents: - continue - if _is_backward_owner_attributed(event_key, parents): - continue - bucket = by_key.setdefault( - event_key, - { - "count": 0.0, - "self_cuda_us": 0.0, - "self_cpu_us": 0.0, - "cuda_total_us": 0.0, - "cpu_total_us": 0.0, - }, - ) - bucket["count"] += 1.0 - bucket["self_cuda_us"] += self_cuda_us - bucket["self_cpu_us"] += float(event.self_cpu_time_total) - bucket["cuda_total_us"] += float(event.device_time_total) - bucket["cpu_total_us"] += float(event.cpu_time_total) - sorted_items = sorted(by_key.items(), key=lambda item: item[1]["self_cuda_us"], reverse=True) - return tuple( - EventSummary( - name=name, - count=int(values["count"]), - self_cuda_us=values["self_cuda_us"], - self_cpu_us=values["self_cpu_us"], - cuda_total_us=values["cuda_total_us"], - cpu_total_us=values["cpu_total_us"], - ) - for name, values in sorted_items[:limit] - ) - - -def _backward_derived_owner_events(prof: torch.profiler.profile) -> tuple[EventSummary, ...]: - by_owner: dict[str, dict[str, float]] = {} - for event in prof.events(): - event_key = str(event.key) - if event_key.startswith(BACKWARD_ATTRIBUTION_CONTAINER_PREFIXES): - continue - self_cuda_us = float(event.self_device_time_total) - if self_cuda_us <= 0.0: - continue - parents = _parent_keys(event) - if _is_backward_explicit_owner_attributed(event_key, parents): - continue - owner = _backward_derived_owner(event_key, parents) - if owner is None: - continue - bucket = by_owner.setdefault( - owner, - { - "count": 0.0, - "self_cuda_us": 0.0, - "self_cpu_us": 0.0, - "cuda_total_us": 0.0, - "cpu_total_us": 0.0, - }, - ) - bucket["count"] += 1.0 - bucket["self_cuda_us"] += self_cuda_us - bucket["self_cpu_us"] += float(event.self_cpu_time_total) - bucket["cuda_total_us"] += float(event.device_time_total) - bucket["cpu_total_us"] += float(event.cpu_time_total) - sorted_items = sorted(by_owner.items(), key=lambda item: item[1]["self_cuda_us"], reverse=True) - return tuple( - EventSummary( - name=name, - count=int(values["count"]), - self_cuda_us=values["self_cuda_us"], - self_cpu_us=values["self_cpu_us"], - cuda_total_us=values["cuda_total_us"], - cpu_total_us=values["cpu_total_us"], - ) - for name, values in sorted_items - ) - - -def _backward_derived_owner_source_events( - prof: torch.profiler.profile, - *, - limit: int = 20, -) -> tuple[EventSummary, ...]: - by_source: dict[str, dict[str, float]] = {} - for event in prof.events(): - event_key = str(event.key) - if event_key.startswith(BACKWARD_ATTRIBUTION_CONTAINER_PREFIXES): - continue - self_cuda_us = float(event.self_device_time_total) - if self_cuda_us <= 0.0: - 
continue - parents = _parent_keys(event) - if _is_backward_explicit_owner_attributed(event_key, parents): - continue - owner = _backward_derived_owner(event_key, parents) - if owner is None: - continue - source_parent = next( - (parent_key for parent_key in parents if parent_key.startswith("fabric.glue.")), - next( - (parent_key for parent_key in parents if parent_key.startswith("fabric.projection.")), - next((parent_key for parent_key in parents if parent_key.startswith("fabric.backward.")), "unknown"), - ), - ) - source_name = f"{owner}|{source_parent}|{event_key}" - bucket = by_source.setdefault( - source_name, - { - "count": 0.0, - "self_cuda_us": 0.0, - "self_cpu_us": 0.0, - "cuda_total_us": 0.0, - "cpu_total_us": 0.0, - }, - ) - bucket["count"] += 1.0 - bucket["self_cuda_us"] += self_cuda_us - bucket["self_cpu_us"] += float(event.self_cpu_time_total) - bucket["cuda_total_us"] += float(event.device_time_total) - bucket["cpu_total_us"] += float(event.cpu_time_total) - sorted_items = sorted(by_source.items(), key=lambda item: item[1]["self_cuda_us"], reverse=True) - return tuple( - EventSummary( - name=name, - count=int(values["count"]), - self_cuda_us=values["self_cuda_us"], - self_cpu_us=values["self_cpu_us"], - cuda_total_us=values["cuda_total_us"], - cpu_total_us=values["cpu_total_us"], - ) - for name, values in sorted_items[:limit] - ) - - -def _required_forward_phase_kernel_patterns(row: RowSpec) -> tuple[str, ...]: - if row.mode != "forward": - return () - if row.family == "slstm": - return REQUIRED_FORWARD_BASE_KERNEL_PATTERNS + REQUIRED_SLSTM_FORWARD_PHASE_KERNEL_PATTERNS - if row.family == "axoncell": - return REQUIRED_FORWARD_BASE_KERNEL_PATTERNS + REQUIRED_AXON_FORWARD_PHASE_KERNEL_PATTERNS - return REQUIRED_FORWARD_BASE_KERNEL_PATTERNS - - -def _profile_case( - *, - row: RowSpec, - side: Side, - batch_size: int, - device: torch.device, - dtype: torch.dtype, - warmup: int, - iterations: int, - seed: int, - backward_attribution_mode: BackwardAttributionMode, - use_resets: bool, - fabric_hidden_grid: tuple[int, ...] 
| None, -) -> ProfileSummary: - _set_seed(seed) - stack_match, fabric_match = _matches(row, fabric_hidden_grid=fabric_hidden_grid) - match = stack_match if side == "stack" else fabric_match - model = make_sequence_model(match, device=device, dtype=dtype).eval() - x = torch.randn(batch_size, row.seq_len, match.d_hidden, device=device, dtype=dtype) - resets = None - if use_resets: - resets = torch.zeros(batch_size, row.seq_len, device=device, dtype=torch.bool) - resets[::2, 0] = True - if row.seq_len > 1: - resets[1::2, -1] = True - target = None if row.mode == "forward" else make_sequence_training_target(x) - optimizer = None if target is None else torch.optim.SGD(model.parameters(), lr=0.0) - env_value = ( - backward_attribution_mode - if side == "fabric" and row.mode != "forward" and backward_attribution_mode != "active" - else None - ) - with _temporary_env_var("CORTICAL_FABRIC_BACKWARD_ATTRIBUTION_MODE", env_value): - for _ in range(warmup): - _run_profile_iteration(model=model, x=x, resets=resets, target=target, optimizer=optimizer) - if device.type == "cuda": - torch.cuda.synchronize() - with torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], - record_shapes=False, - ) as prof: - for _ in range(iterations): - _run_profile_iteration(model=model, x=x, resets=resets, target=target, optimizer=optimizer) - if device.type == "cuda": - torch.cuda.synchronize() - top_cuda = _profile_events(prof, key="cuda") - top_cpu = _profile_events(prof, key="cpu") - runtime_ops = _profile_events(prof, key="runtime") - message_kernel_events = _profile_events_for_patterns(prof, MESSAGE_PROFILE_KERNEL_PATTERNS) - physical_op_events = _profile_events_for_patterns(prof, PHYSICAL_OP_PROFILE_PATTERNS) - backward_owner_events = _profile_events_for_patterns(prof, BACKWARD_OWNER_PROFILE_PATTERNS) - backward_recompute_events = _profile_events_for_patterns_with_parent( - prof, - BACKWARD_RECOMPUTE_PROFILE_PATTERNS, - ("fabric.backward.total",), - ) - backward_derived_owner_events = _backward_derived_owner_events(prof) - backward_derived_owner_source_events = _backward_derived_owner_source_events(prof) - backward_unattributed_events = _backward_unattributed_events(prof) - backward_small_cublas_events = _profile_events_for_patterns_with_parent( - prof, - BACKWARD_SMALL_CUBLAS_PROFILE_PATTERNS, - ("fabric.backward.total",), - ) - backward_copy_glue_events = _profile_events_for_patterns_with_parent( - prof, - GLUE_PROFILE_KERNEL_PATTERNS, - ("fabric.backward.total",), - ) - ( - backward_owner_attributed_cuda_total_us, - backward_explicit_owner_attributed_cuda_total_us, - backward_derived_owner_attributed_cuda_total_us, - backward_owner_cuda_total_us, - ) = _backward_parent_attribution(prof) - backward_attribution_coverage = ( - min(1.0, backward_owner_attributed_cuda_total_us / backward_owner_cuda_total_us) - if backward_owner_cuda_total_us > 0.0 - else 1.0 - ) - backward_explicit_attribution_coverage = ( - min(1.0, backward_explicit_owner_attributed_cuda_total_us / backward_owner_cuda_total_us) - if backward_owner_cuda_total_us > 0.0 - else 1.0 - ) - backward_derived_attribution_coverage = ( - min(1.0, backward_derived_owner_attributed_cuda_total_us / backward_owner_cuda_total_us) - if backward_owner_cuda_total_us > 0.0 - else 0.0 - ) - glue_kernel_events = _profile_events_for_patterns(prof, GLUE_PROFILE_KERNEL_PATTERNS) - glue_source_events = _profile_events_for_patterns(prof, GLUE_SOURCE_PROFILE_PATTERNS) - glue_source_cuda_total_us, 
glue_kernel_self_cuda_us = _glue_parent_attribution(prof) - glue_attribution_coverage = ( - min(1.0, glue_source_cuda_total_us / glue_kernel_self_cuda_us) if glue_kernel_self_cuda_us > 0.0 else 1.0 - ) - kernel_names = tuple(sorted(event.key for event in prof.key_averages())) - generic_present = tuple( - pattern for pattern in GENERIC_FABRIC_KERNEL_PATTERNS if any(pattern in name for name in kernel_names) - ) - legacy_present = tuple( - pattern for pattern in LEGACY_KERNEL_PATTERNS if any(pattern in name for name in kernel_names) - ) - required_phase_patterns = _required_forward_phase_kernel_patterns(row) - missing_required_phase_patterns = tuple( - pattern - for pattern in required_phase_patterns - if row.mode == "forward" and side == "fabric" and not any(pattern in name for name in kernel_names) - ) - missing_required_message_kernel = ( - row.mode == "forward" - and side == "fabric" - and not any( - pattern in name for pattern in REQUIRED_FORWARD_MESSAGE_KERNEL_ALTERNATIVES for name in kernel_names - ) - ) - if side == "fabric" and not generic_present: - raise RuntimeError(f"Fabric profiler did not observe generic backend kernels for row_id={row.row_id}") - if missing_required_phase_patterns: - raise RuntimeError( - "Fabric profiler did not observe required factorized forward phase kernels " - f"for row_id={row.row_id}: {missing_required_phase_patterns}" - ) - if missing_required_message_kernel: - raise RuntimeError( - "Fabric profiler did not observe any supported message physical kernel " - f"for row_id={row.row_id}: {REQUIRED_FORWARD_MESSAGE_KERNEL_ALTERNATIVES}" - ) - if side == "fabric" and legacy_present: - raise RuntimeError(f"Fabric profiler observed legacy kernels for row_id={row.row_id}: {legacy_present}") - result = run_sequence_case( - match=match, - mode=row.mode, - batch_size=batch_size, - seq_len=row.seq_len, - family=row.family, - device=device, - dtype=dtype, - warmup=0, - iterations=1, - model=model, - resets=resets, - ) - if result["status"] != "ok": - raise RuntimeError(f"Profiled case measurement failed for {row.row_id}/{side}/B={batch_size}: {result}") - summary = ProfileSummary( - row_id=row.row_id, - side=side, - batch_size=batch_size, - use_resets=use_resets, - status="ok", - tokens_per_s=float(result["tokens_per_s"]), - ms=float(result["ms"]), - effective_batch_tile=(int(result["effective_batch_tile"]) if "effective_batch_tile" in result else None), - launch_count_cuda_events=sum(event.count for event in top_cuda), - launch_count_runtime_events=sum(event.count for event in runtime_ops), - top_cuda=top_cuda, - top_cpu=top_cpu, - runtime_ops=runtime_ops, - message_kernel_events=message_kernel_events, - physical_op_events=physical_op_events, - backward_owner_events=backward_owner_events, - backward_recompute_events=backward_recompute_events, - backward_derived_owner_events=backward_derived_owner_events, - backward_derived_owner_source_events=backward_derived_owner_source_events, - backward_unattributed_events=backward_unattributed_events, - backward_small_cublas_events=backward_small_cublas_events, - backward_copy_glue_events=backward_copy_glue_events, - backward_owner_cuda_total_us=backward_owner_cuda_total_us, - backward_owner_attributed_cuda_total_us=backward_owner_attributed_cuda_total_us, - backward_explicit_owner_attributed_cuda_total_us=backward_explicit_owner_attributed_cuda_total_us, - backward_derived_owner_attributed_cuda_total_us=backward_derived_owner_attributed_cuda_total_us, - backward_attribution_coverage=backward_attribution_coverage, - 
backward_explicit_attribution_coverage=backward_explicit_attribution_coverage, - backward_derived_attribution_coverage=backward_derived_attribution_coverage, - backward_attribution_mode=backward_attribution_mode if env_value is not None else "active", - glue_kernel_events=glue_kernel_events, - glue_source_events=glue_source_events, - glue_source_cuda_total_us=glue_source_cuda_total_us, - glue_kernel_self_cuda_us=glue_kernel_self_cuda_us, - glue_attribution_coverage=glue_attribution_coverage, - kernel_names=kernel_names, - generic_kernel_patterns_present=generic_present, - legacy_kernel_patterns_present=legacy_present, - backend_record=_model_backend_record(model) if side == "fabric" else None, - ) - del model, x, target, optimizer - if device.type == "cuda": - torch.cuda.empty_cache() - return summary - - -def _run_profiles( - *, - rows: tuple[RowSpec, ...], - profile_batches: tuple[int, ...], - device: torch.device, - dtype: torch.dtype, - warmup: int, - iterations: int, - seed: int, - backward_attribution_mode: BackwardAttributionMode, - use_resets: bool, - fabric_hidden_grid: tuple[int, ...] | None, -) -> tuple[dict[str, object], ...]: - profiles: list[dict[str, object]] = [] - for row_index, row in enumerate(rows): - for batch_size in profile_batches: - for side in ("fabric", "stack"): - try: - profile = _profile_case( - row=row, - side=side, - batch_size=batch_size, - device=device, - dtype=dtype, - warmup=warmup, - iterations=iterations, - seed=seed + (row_index * 1000) + batch_size + (1 if side == "fabric" else 2), - backward_attribution_mode=backward_attribution_mode, - use_resets=use_resets, - fabric_hidden_grid=fabric_hidden_grid, - ) - profiles.append(asdict(profile)) - except RuntimeError as error: - if device.type == "cuda": - torch.cuda.empty_cache() - profiles.append( - { - "row_id": row.row_id, - "side": side, - "batch_size": batch_size, - "use_resets": use_resets, - "status": "error", - "error": str(error), - "tokens_per_s": 0.0, - "ms": 0.0, - "effective_batch_tile": None, - "launch_count_cuda_events": 0, - "launch_count_runtime_events": 0, - "top_cuda": (), - "top_cpu": (), - "runtime_ops": (), - "message_kernel_events": (), - "physical_op_events": (), - "backward_owner_events": (), - "backward_recompute_events": (), - "backward_derived_owner_events": (), - "backward_derived_owner_source_events": (), - "backward_unattributed_events": (), - "backward_small_cublas_events": (), - "backward_copy_glue_events": (), - "backward_owner_cuda_total_us": 0.0, - "backward_owner_attributed_cuda_total_us": 0.0, - "backward_explicit_owner_attributed_cuda_total_us": 0.0, - "backward_derived_owner_attributed_cuda_total_us": 0.0, - "backward_attribution_coverage": 0.0, - "backward_explicit_attribution_coverage": 0.0, - "backward_derived_attribution_coverage": 0.0, - "backward_attribution_mode": backward_attribution_mode, - "glue_kernel_events": (), - "glue_source_events": (), - "glue_source_cuda_total_us": 0.0, - "glue_kernel_self_cuda_us": 0.0, - "glue_attribution_coverage": 0.0, - "kernel_names": (), - "generic_kernel_patterns_present": (), - "legacy_kernel_patterns_present": (), - "backend_record": None, - } - ) - return tuple(profiles) - - -def _single_run( - *, - row: RowSpec, - side: Side, - batch_size: int, - device: torch.device, - dtype: torch.dtype, - warmup: int, - iterations: int, - seed: int, - use_resets: bool = False, - fabric_hidden_grid: tuple[int, ...] 
| None = None, -) -> None: - _set_seed(seed) - stack_match, fabric_match = _matches(row, fabric_hidden_grid=fabric_hidden_grid) - match = stack_match if side == "stack" else fabric_match - model = make_sequence_model(match, device=device, dtype=dtype).eval() - x = torch.randn(batch_size, row.seq_len, match.d_hidden, device=device, dtype=dtype) - resets = None - if use_resets: - resets = torch.zeros(batch_size, row.seq_len, device=device, dtype=torch.bool) - resets[::2, 0] = True - if row.seq_len > 1: - resets[1::2, -1] = True - target = None if row.mode == "forward" else make_sequence_training_target(x) - optimizer = None if target is None else torch.optim.SGD(model.parameters(), lr=0.0) - for _ in range(warmup): - _run_profile_iteration(model=model, x=x, resets=resets, target=target, optimizer=optimizer) - if device.type == "cuda": - torch.cuda.synchronize() - torch.cuda.nvtx.range_push(f"{row.row_id}:{side}:B{batch_size}") - for _ in range(iterations): - _run_profile_iteration(model=model, x=x, resets=resets, target=target, optimizer=optimizer) - if device.type == "cuda": - torch.cuda.synchronize() - torch.cuda.nvtx.range_pop() - record = _model_backend_record(model) if side == "fabric" else None - print( - json.dumps( - { - "row_id": row.row_id, - "side": side, - "batch_size": batch_size, - "use_resets": use_resets, - "record": record, - }, - indent=2, - ) - ) - - -def _as_table_value(value: object) -> str: - if value is None: - return "-" - if isinstance(value, float): - return f"{value:.1f}" - if isinstance(value, tuple): - return ",".join(str(item) for item in value) - if isinstance(value, list): - return ",".join(str(item) for item in value) - return str(value) - - -def _first_record_value(record: dict[str, object] | None, key: str) -> str: - if record is None: - return "-" - value = record.get(key) - if isinstance(value, (list, tuple)): - return str(value[0]) if value else "-" - if value is None: - return "-" - return str(value) - - -def _record_values(record: dict[str, object] | None, key: str) -> str: - if record is None: - return "-" - value = record.get(key) - if isinstance(value, (list, tuple)): - return ",".join(str(item) for item in value) if value else "-" - if value is None: - return "-" - return str(value) - - -def _message_kernel_ms(profile: dict[str, object], pattern: str) -> str: - for event in profile.get("message_kernel_events", ()): - if event["name"] == pattern: - return f"{float(event['self_cuda_us']) / 1000.0:.3f} ms" - if not any(pattern in name for name in profile.get("kernel_names", ())): - return "0.000 ms" - return "present, not recorded" - - -def _physical_op_ms(profile: dict[str, object], pattern: str) -> str: - for event in profile.get("physical_op_events", ()): - if event["name"] == pattern: - return f"{float(event['self_cuda_us']) / 1000.0:.3f} ms/{int(event['count'])}" - return "0.000 ms/0" - - -def _backward_owner_ms(profile: dict[str, object], pattern: str) -> str: - for event in profile.get("backward_owner_events", ()): - if event["name"] == pattern: - return f"{float(event['cuda_total_us']) / 1000.0:.3f} ms/{int(event['count'])}" - return "0.000 ms/0" - - -def _backward_event_summary(profile: dict[str, object], key: str) -> str: - parts: list[str] = [] - for event in profile.get(key, ()): - count = int(event["count"]) - cuda_ms = float(event["self_cuda_us"]) / 1000.0 - if count > 0 or cuda_ms > 0.0: - name = str(event["name"]).replace("|", " / ") - parts.append(f"{name}={cuda_ms:.3f} ms/{count}") - return "; ".join(parts) if parts else "-" - - -def 
_largest_backward_event(profile: dict[str, object], key: str) -> dict[str, object] | None: - events = [ - event - for event in profile.get(key, ()) - if int(event.get("count", 0)) > 0 or float(event.get("self_cuda_us", 0.0)) > 0.0 - ] - if not events: - return None - return max(events, key=lambda event: float(event.get("self_cuda_us", 0.0))) - - -def _backward_family_for_owner(owner: str) -> str: - if "receiver_major_projection" in owner: - return "public/readout projection backward" - if "receiver_affine" in owner: - return "receiver-affine backward" - if ( - "tiny_message" in owner - or "message.receiver" in owner - or "message.sender" in owner - or "message.query_param" in owner - ): - return "regular-local message backward" - if "sparse_message" in owner: - return "sparse/edge-owned message backward" - if "diagonal_recurrence" in owner: - return "diagonal recurrence backward" - if "boundary_glue" in owner: - return "glue/layout backward" - if "state_public_epilogue" in owner or "state_epilogue" in owner: - return "state/public thin reverse" - if ( - "lowered_projection" in owner - or "grouped_projection" in owner - or "public_projection" in owner - or "readout" in owner - ): - return "public/readout projection backward" - return "integration/demotion cleanup" - - -def _backward_next_owner(profile: dict[str, object]) -> tuple[str, str, str]: - if profile["side"] != "fabric" or not str(profile["row_id"]).endswith("_train"): - return "-", "-", "-" - derived = _largest_backward_event(profile, "backward_derived_owner_events") - if derived is not None: - owner = str(derived["name"]) - owner_ms = f"{float(derived['self_cuda_us']) / 1000.0:.3f} ms/{int(derived['count'])}" - return owner, owner_ms, _backward_family_for_owner(owner) - explicit_events = [ - event - for event in profile.get("backward_owner_events", ()) - if str(event["name"]) not in {"fabric.backward.total", "fabric.backward.full_replay_autograd"} - and (int(event["count"]) > 0 or float(event["self_cuda_us"]) > 0.0) - ] - if not explicit_events: - return "unattributed", "0.000 ms/0", "backward attribution still open" - explicit = max(explicit_events, key=lambda event: float(event["self_cuda_us"])) - owner = str(explicit["name"]) - owner_ms = f"{float(explicit['self_cuda_us']) / 1000.0:.3f} ms/{int(explicit['count'])}" - return owner, owner_ms, _backward_family_for_owner(owner) - - -def _backward_recompute_summary(profile: dict[str, object]) -> str: - parts: list[str] = [] - for event in profile.get("backward_recompute_events", ()): - count = int(event["count"]) - cuda_ms = float(event["cuda_total_us"]) / 1000.0 - if count > 0 or cuda_ms > 0.0: - name = str(event["name"]).removeprefix("fabric.physical.") - parts.append(f"{name}={cuda_ms:.3f} ms/{count}") - return "; ".join(parts) if parts else "-" - - -def _backward_attribution_gate(profile: dict[str, object]) -> tuple[str, str]: - if profile["side"] != "fabric" or not str(profile["row_id"]).endswith("_train"): - return "-", "not a Fabric training row" - if profile.get("backward_attribution_mode") == "phase_decomposed_probe": - return "probe", "profile-only phase decomposition; active path still requires attribution" - coverage = float(profile.get("backward_attribution_coverage", 0.0)) - specific_owner_ms = 0.0 - for event in profile.get("backward_owner_events", ()): - name = str(event["name"]) - if name in {"fabric.backward.total", "fabric.backward.full_replay_autograd"}: - continue - specific_owner_ms += float(event["cuda_total_us"]) / 1000.0 - for event in 
profile.get("backward_recompute_events", ()): - specific_owner_ms += float(event["cuda_total_us"]) / 1000.0 - for event in profile.get("backward_derived_owner_events", ()): - specific_owner_ms += float(event["self_cuda_us"]) / 1000.0 - if coverage < 0.9: - return "open", f"backward attribution coverage {coverage:.1%} < 90%" - if specific_owner_ms <= 0.0: - return "open", "only broad full-replay attribution is visible" - return "attributed", "dominant backward owners are profiler-visible" - - -def _backward_physical_gate(profile: dict[str, object]) -> tuple[str, str]: - if profile["side"] != "fabric" or not str(profile["row_id"]).endswith("_train"): - return "-", "not a Fabric training row" - if profile.get("backward_attribution_mode") == "phase_decomposed_probe": - return "probe", "profile-only phase decomposition; active path does not count" - explicit_coverage = float(profile.get("backward_explicit_attribution_coverage", 0.0)) - total_coverage = float(profile.get("backward_attribution_coverage", 0.0)) - derived_coverage = float(profile.get("backward_derived_attribution_coverage", 0.0)) - derived_ms = float(profile.get("backward_derived_owner_attributed_cuda_total_us", 0.0)) / 1000.0 - derived_owners = _backward_event_summary(profile, "backward_derived_owner_events") - wall_ms = float(profile.get("ms", 0.0)) - largest_derived = _largest_backward_event(profile, "backward_derived_owner_events") - largest_derived_wall_share = 0.0 - if largest_derived is not None and wall_ms > 0.0: - largest_derived_wall_share = (float(largest_derived["self_cuda_us"]) / 1000.0) / wall_ms - if explicit_coverage >= 0.9 and derived_ms <= 0.0: - return "physical_owned", f"explicit owner coverage {explicit_coverage:.1%}" - if total_coverage >= 0.9 and derived_ms > 0.0 and largest_derived_wall_share <= 0.055: - return ( - "physical_owned_with_thin_reverse_residuals", - ( - f"total={total_coverage:.1%}, explicit={explicit_coverage:.1%}, derived={derived_coverage:.1%}; " - f"largest residual derived owner is {largest_derived_wall_share:.1%} of wall and remains a named " - "thin-reverse residual" - ), - ) - if derived_ms > 0.0: - return ( - "open", - ( - f"explicit={explicit_coverage:.1%}, derived={derived_coverage:.1%}; " - f"derived replay owners still active: {derived_owners}" - ), - ) - return "open", f"explicit owner coverage {explicit_coverage:.1%} < 90%" - - -def _glue_kernel_ms(profile: dict[str, object], pattern: str) -> str: - for event in profile.get("glue_kernel_events", ()): - if event["name"] == pattern: - return f"{float(event['self_cuda_us']) / 1000.0:.3f} ms/{int(event['count'])}" - return "0.000 ms/0" - - -def _glue_source_ms(profile: dict[str, object], pattern: str) -> str: - for event in profile.get("glue_source_events", ()): - if event["name"] == pattern: - return f"{float(event['cuda_total_us']) / 1000.0:.3f} ms/{int(event['count'])}" - return "0.000 ms/0" - - -def _glue_coverage(profile: dict[str, object]) -> str: - return f"{100.0 * float(profile.get('glue_attribution_coverage', 0.0)):.1f}%" - - -def _glue_gate(profile: dict[str, object]) -> tuple[str, str]: - if profile.get("side") != "fabric": - return "not_applicable", "stack side" - if profile.get("status") != "ok": - return "open", "profile failed" - wall_us = float(profile.get("ms", 0.0)) * 1000.0 - if wall_us <= 0.0: - return "open", "missing wall time" - glue_events = tuple(profile.get("glue_kernel_events", ())) - total_us = float(profile.get("glue_kernel_self_cuda_us", 0.0)) - max_us = max((float(event.get("self_cuda_us", 0.0)) for 
event in glue_events), default=0.0) - total_ratio = total_us / wall_us - max_ratio = max_us / wall_us - coverage = float(profile.get("glue_attribution_coverage", 0.0)) - if total_ratio < GLUE_TOTAL_ROW_LIMIT and max_ratio < GLUE_SINGLE_SOURCE_ROW_LIMIT: - return ( - "pass_residual_subthreshold", - f"total={100.0 * total_ratio:.1f}%, max={100.0 * max_ratio:.1f}%", - ) - if coverage >= GLUE_ATTRIBUTION_COVERAGE_LIMIT: - return "pass_attributed", f"coverage={100.0 * coverage:.1f}%" - return ( - "open", - (f"coverage={100.0 * coverage:.1f}%, total={100.0 * total_ratio:.1f}%, max={100.0 * max_ratio:.1f}%"), - ) - - -def _launch_shape(record: dict[str, object] | None, prefix: str) -> str: - if record is None: - return "-" - fields = ( - "receiver_tiles", - "batch_tiles", - "edge_tiles", - "hidden_chunks", - "state_receiver_tiles", - "state_batch_tiles", - "state_hidden_chunks", - "state_static_stage_modes", - "emit_receiver_tiles", - "emit_batch_tiles", - "emit_hidden_chunks", - "emit_static_stage_modes", - "public_receiver_tiles", - "public_batch_tiles", - "replication_factors", - "readout_modes", - "cell_static_stage_modes", - ) - return "; ".join(f"{field}={_as_table_value(record[f'{prefix}_launch_{field}'])}" for field in fields) - - -def _dominant_cuda_name(profile: dict[str, object]) -> str: - top_cuda = profile["top_cuda"] - if not top_cuda: - return "-" - first = top_cuda[0] - return str(first["name"]) - - -def _write_markdown(report: dict[str, object], path: Path) -> None: - lines = [ - f"# Fabric Scaling Profile ({report['date']})", - "", - "Fresh current-branch Fabric-vs-stack sweep after the scalable executor rewrite and launch-record tightening.", - "", - f"- Command: `{report['command']}`", - f"- Device: `{report['device']}`", - f"- GPU: `{report['gpu_name']}`", - f"- GPU NVML: `{report['gpu_nvml']}`", - f"- Nsight Systems: `{report['nsys_path']}`", - f"- Nsight Compute: `{report['ncu_path']}`", - f"- Dtype: `{report['dtype']}`", - f"- Warmup: `{report['warmup']}`", - f"- Iterations: `{report['iterations']}`", - f"- Torch profiler warmup/iterations: `{report['profile_warmup']}` / `{report['profile_iterations']}`", - f"- Torch profile resets: `{report.get('profile_use_resets', False)}`", - "", - "## Matched Rows", - "", - "| Row | Family | Params target | Stack params | Fabric params | Match note | Stack shape | Fabric shape |", - "|---|---|---:|---:|---:|---|---|---|", - ] - for sweep in report["sweeps"]: - stack_match = sweep["stack_match"] - fabric_match = sweep["fabric_match"] - stack_shape = f"d_hidden={stack_match['d_hidden']}, layers={stack_match['num_layers']}" - fabric_shape = ( - f"d_hidden={fabric_match['d_hidden']}, shape={tuple(fabric_match['fabric_shape'])}, " - f"cell_hidden={fabric_match['fabric_hidden_size']}" - ) - lines.append( - ( - "| {row} | {family} | {target} | {stack_params} | {fabric_params} | " - "{match_note} | {stack_shape} | {fabric_shape} |" - ).format( - row=sweep["row_id"], - family=sweep["family"], - target=int(sweep["target_params"]), - stack_params=int(stack_match["actual_params"]), - fabric_params=int(fabric_match["actual_params"]), - match_note=sweep["match_note"], - stack_shape=stack_shape, - fabric_shape=fabric_shape, - ) - ) - lines.extend( - [ - "", - "## Benchmark Table", - "", - ( - "| Row | B | Fabric tok/s | Stack tok/s | Fabric/stack | Fabric gain | Stack gain | " - "Requested launch | Actual launch | Temporal | Graph | Plateau note |" - ), - "|---|---:|---:|---:|---:|---:|---:|---|---|---|---|---|", - ] - ) - for sweep in 
report["sweeps"]: - for row in sweep["rows"]: - record = row["fabric_backend_record"] - temporal = "-" if record is None else _as_table_value(record["launch_temporal_executions"]) - graph = "-" if record is None else str(record["graph_capture_enabled"]) - plateau_note = f"fabric={sweep['fabric_plateau_note']}; stack={sweep['stack_plateau_note']}" - lines.append( - ( - "| {row_id} | {batch} | {fabric_tok} | {stack_tok} | {ratio} | {fabric_gain} | " - "{stack_gain} | {requested} | {actual} | {temporal} | {graph} | {plateau} |" - ).format( - row_id=row["row_id"], - batch=row["batch_size"], - fabric_tok=_as_table_value(row["fabric_tokens_per_s"]), - stack_tok=_as_table_value(row["stack_tokens_per_s"]), - ratio="-" - if row["fabric_stack_ratio"] is None - else f"{100.0 * float(row['fabric_stack_ratio']):.1f}%", - fabric_gain="-" - if row["fabric_gain_from_previous_batch"] is None - else f"{100.0 * float(row['fabric_gain_from_previous_batch']):.1f}%", - stack_gain="-" - if row["stack_gain_from_previous_batch"] is None - else f"{100.0 * float(row['stack_gain_from_previous_batch']):.1f}%", - requested=_launch_shape(record, "requested"), - actual=_launch_shape(record, "actual"), - temporal=temporal, - graph=graph, - plateau=plateau_note, - ) - ) - lines.extend( - [ - "", - "## Large-R Diagnostics", - "", - "| Row | B | Safety modes | Tiling plans | Diagnostics |", - "|---|---:|---|---|---|", - ] - ) - for sweep in report["sweeps"]: - for row in sweep["rows"]: - record = row["fabric_backend_record"] - if record is None: - continue - lines.append( - ("| {row_id} | {batch} | {safety} | {tiling} | {diagnostics} |").format( - row_id=row["row_id"], - batch=row["batch_size"], - safety=_record_values(record, "large_r_safety_modes"), - tiling=_record_values(record, "active_cell_tiling_plans"), - diagnostics=_record_values(record, "large_r_diagnostics"), - ) - ) - lines.extend( - [ - "", - "## Hot-Kernel Summary", - "", - ( - "| Row | Side | B | Resets | Dominant CUDA event | CUDA event launches in top 20 | " - "Runtime launches in top runtime ops | Generic kernels present | Legacy kernels present |" - ), - "|---|---|---:|---|---|---:|---:|---|---|", - ] - ) - for profile in report["profiles"]: - lines.append( - ( - "| {row_id} | {side} | {batch} | {resets} | {cuda} | {cuda_launches} | {runtime_launches} | " - "{generic} | {legacy} |" - ).format( - row_id=profile["row_id"], - side=profile["side"], - batch=profile["batch_size"], - resets=profile.get("use_resets", False), - cuda=_dominant_cuda_name(profile), - cuda_launches=profile["launch_count_cuda_events"], - runtime_launches=profile["launch_count_runtime_events"], - generic=_as_table_value(profile["generic_kernel_patterns_present"]), - legacy=_as_table_value(profile["legacy_kernel_patterns_present"]), - ) - ) - lines.extend( - [ - "", - "## Physical Op Attribution", - "", - ( - "| Row | Side | B | Receiver affine | Diagonal recurrence | Message | Input projection | " - "State epilogue | Public projection | Readout | Lowered state affine |" - ), - "|---|---|---:|---:|---:|---:|---:|---:|---:|---:|---:|", - ] - ) - for profile in report["profiles"]: - lines.append( - ( - "| {row_id} | {side} | {batch} | {receiver_affine} | {diagonal_recurrence} | {message} | " - "{input_projection} | {state_epilogue} | {public_projection} | {readout} | {state_affine} |" - ).format( - row_id=profile["row_id"], - side=profile["side"], - batch=profile["batch_size"], - receiver_affine=_physical_op_ms(profile, "fabric.physical.receiver_affine"), - 
diagonal_recurrence=_physical_op_ms(profile, "fabric.physical.diagonal_recurrence"), - message=_physical_op_ms(profile, "fabric.physical.message"), - input_projection=_physical_op_ms(profile, "fabric.physical.input_projection"), - state_epilogue=_physical_op_ms(profile, "fabric.physical.state_epilogue"), - public_projection=_physical_op_ms(profile, "fabric.physical.public_projection"), - readout=_physical_op_ms(profile, "fabric.physical.readout"), - state_affine=_physical_op_ms(profile, "fabric.physical.state_affine"), - ) - ) - lines.extend( - [ - "", - "## Backward Attribution", - "", - ( - "| Row | Side | B | Attribution mode | Total backward | Receiver affine | " - "Message receiver | Message sender | Message params | Message autograd | " - "Tiny message | Sparse message | Grouped projection | " - "Receiver-major projection | Diagonal recurrence | State epilogue | " - "Public projection | Readout | Glue/init recurrent | " - "Full replay | Phase probe | Derived replay owners | Coverage | " - "Explicit coverage | Derived coverage | Attribution gate | Physical ownership gate | Gate reason |" - ), - ( - "|---|---|---:|---|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|" - "---:|---:|---:|---:|---:|---:|---:|---|---:|---:|---:|---|---|---|" - ), - ] - ) - for profile in report["profiles"]: - gate, gate_reason = _backward_attribution_gate(profile) - physical_gate, physical_gate_reason = _backward_physical_gate(profile) - lines.append( - ( - "| {row_id} | {side} | {batch} | {mode} | {total} | {receiver_affine} | {message_receiver} | " - "{message_sender} | {message_params} | {message_autograd} | {tiny_message} | " - "{sparse_message} | {grouped_projection} | {receiver_major_projection} | " - "{diagonal_recurrence} | {state_epilogue} | {public_projection} | {readout} | " - "{glue_initial} | {full_replay} | {phase_probe} | {derived} | " - "{coverage:.1%} | {explicit_coverage:.1%} | " - "{derived_coverage:.1%} | {gate} | {physical_gate} | {gate_reason}; {physical_gate_reason} |" - ).format( - row_id=profile["row_id"], - side=profile["side"], - batch=profile["batch_size"], - mode=profile.get("backward_attribution_mode", "active"), - total=_backward_owner_ms(profile, "fabric.backward.total"), - receiver_affine=_backward_owner_ms(profile, "fabric.backward.receiver_affine"), - message_receiver=_backward_owner_ms(profile, "fabric.backward.message.receiver"), - message_sender=_backward_owner_ms(profile, "fabric.backward.message.sender"), - message_params=_backward_owner_ms(profile, "fabric.backward.message.query_param"), - message_autograd=_backward_owner_ms(profile, "fabric.backward.message.autograd"), - tiny_message=_backward_owner_ms(profile, "fabric.backward.tiny_message_superop"), - sparse_message=_backward_owner_ms(profile, "fabric.backward.sparse_message_superop"), - grouped_projection=_backward_owner_ms(profile, "fabric.backward.grouped_projection"), - receiver_major_projection=_backward_owner_ms(profile, "fabric.backward.receiver_major_projection"), - diagonal_recurrence=_backward_owner_ms(profile, "fabric.backward.diagonal_recurrence"), - state_epilogue=_backward_owner_ms(profile, "fabric.backward.state_epilogue"), - public_projection=_backward_owner_ms(profile, "fabric.backward.public_projection"), - readout=_backward_owner_ms(profile, "fabric.backward.readout"), - glue_initial=_backward_owner_ms(profile, "fabric.backward.glue.initial_recurrent"), - full_replay=_backward_owner_ms(profile, "fabric.backward.full_replay_autograd"), - phase_probe=_backward_owner_ms(profile, 
"fabric.backward.phase_decomposed_probe"), - derived=_backward_event_summary(profile, "backward_derived_owner_events"), - coverage=float(profile.get("backward_attribution_coverage", 0.0)), - explicit_coverage=float(profile.get("backward_explicit_attribution_coverage", 0.0)), - derived_coverage=float(profile.get("backward_derived_attribution_coverage", 0.0)), - gate=gate, - physical_gate=physical_gate, - gate_reason=gate_reason, - physical_gate_reason=physical_gate_reason, - ) - ) - lines.extend( - [ - "", - "## Backward Owner Priority", - "", - ( - "| Row | Side | B | Explicit coverage | Derived coverage | Largest open owner | Owner self-CUDA | " - "Next owner family |" - ), - "|---|---|---:|---:|---:|---|---:|---|", - ] - ) - for profile in report["profiles"]: - owner, owner_ms, family = _backward_next_owner(profile) - lines.append( - ( - "| {row_id} | {side} | {batch} | {explicit:.1%} | {derived:.1%} | {owner} | {owner_ms} | {family} |" - ).format( - row_id=profile["row_id"], - side=profile["side"], - batch=profile["batch_size"], - explicit=float(profile.get("backward_explicit_attribution_coverage", 0.0)), - derived=float(profile.get("backward_derived_attribution_coverage", 0.0)), - owner=owner, - owner_ms=owner_ms, - family=family, - ) - ) - lines.extend( - [ - "", - "## Backward Derived Source Detail", - "", - "| Row | Side | B | Top derived source events |", - "|---|---|---:|---|", - ] - ) - for profile in report["profiles"]: - lines.append( - "| {row_id} | {side} | {batch} | {sources} |".format( - row_id=profile["row_id"], - side=profile["side"], - batch=profile["batch_size"], - sources=_backward_event_summary(profile, "backward_derived_owner_source_events"), - ) - ) - lines.extend( - [ - "", - "## Backward Metadata", - "", - ( - "| Row | B | Backward physical ops | Executors | Demotions | Boundary contracts | Layout | Tape | " - "Recompute | Recompute physical ops | Workspace peak | Workspace aliases | Launch counts | " - "Saved launch counts | Small cuBLAS | Copy glue | Derived replay owners | " - "Unattributed top | Residual glue demotions |" - ), - "|---|---:|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|", - ] - ) - for profile in report["profiles"]: - if profile["side"] != "fabric" or profile["status"] != "ok": - continue - record = profile["backend_record"] - lines.append( - ( - "| {row_id} | {batch} | {ops} | {executors} | {demotions} | {boundaries} | {layout} | {tape} | " - "{recompute} | {recompute_ops} | {workspace_peak} | {workspace_aliases} | {launch_counts} | " - "{saved_launch_counts} | {small_cublas} | {copy_glue} | {derived} | {unattributed} | {residual_glue} |" - ).format( - row_id=profile["row_id"], - batch=profile["batch_size"], - ops=_record_values(record, "backward_physical_op_kinds"), - executors=_record_values(record, "backward_physical_op_executors"), - demotions=_record_values(record, "backward_physical_op_demotions"), - boundaries=_record_values(record, "backward_boundary_contracts"), - layout=_record_values(record, "backward_layout_mode"), - tape=_record_values(record, "backward_tape_mode"), - recompute=_record_values(record, "backward_recompute_mode"), - recompute_ops=_backward_recompute_summary(profile), - workspace_peak=_record_values(record, "backward_workspace_peak_bytes"), - workspace_aliases=_record_values(record, "backward_workspace_aliases"), - launch_counts=_record_values(record, "backward_launch_counts"), - saved_launch_counts=_record_values(record, "backward_saved_launch_counts"), - 
small_cublas=_backward_event_summary(profile, "backward_small_cublas_events"), - copy_glue=_backward_event_summary(profile, "backward_copy_glue_events"), - derived=_backward_event_summary(profile, "backward_derived_owner_events"), - unattributed=_backward_event_summary(profile, "backward_unattributed_events"), - residual_glue=_record_values(record, "backward_residual_glue_demotions"), - ) - ) - lines.extend( - [ - "", - "## Message Summary", - "", - ( - "| Row | B | Buckets | Kind/topology | Regular/sparse | Degree/ragged/demoted | " - "Batched/grouped | Reset-aware | Packed-source reuse | Workspace mode | " - "Message workspace peak bytes | Workspace aliases | Per-bucket workspace | Phase launches | " - "Small cuBLAS launches | Copy glue launches | Saved copy launches | State epilogue | " - "Saved state launches | Coalescing | Glue fusion | " - "Old monolithic message self-CUDA | New aggregate self-CUDA | Dense message self-CUDA |" - ), - "|---|---:|---:|---|---|---|---|---:|---:|---|---:|---|---|---|---|---|---|---|---|---|---:|---|", - ] - ) - for profile in report["profiles"]: - if profile["side"] != "fabric" or profile["status"] != "ok": - continue - record = profile["backend_record"] - dense_message_ms = ( - f"tiny_direct={_message_kernel_ms(profile, 'regular_local_tiny_message_projected_kernel')}; " - f"tiny_rowgroup8={_message_kernel_ms(profile, 'regular_local_tiny_message_projected_rowgroup8_kernel')}; " - f"keys={_message_kernel_ms(profile, 'pack_regular_local_message_keys_kernel')}; " - f"sparse_keys={_message_kernel_ms(profile, 'pack_sparse_message_keys_kernel')}; " - f"ragged_keys={_message_kernel_ms(profile, 'pack_ragged_sparse_message_keys_kernel')}; " - f"softmax={_message_kernel_ms(profile, 'regular_local_message_softmax_inplace_kernel')}; " - f"sparse_softmax={_message_kernel_ms(profile, 'sparse_message_softmax_inplace_kernel')}; " - f"ragged_softmax={_message_kernel_ms(profile, 'ragged_sparse_message_softmax_inplace_kernel')}; " - f"values={_message_kernel_ms(profile, 'pack_regular_local_message_values_kernel')}; " - f"sparse_values={_message_kernel_ms(profile, 'pack_sparse_message_values_kernel')}; " - f"ragged_values={_message_kernel_ms(profile, 'pack_ragged_sparse_message_values_kernel')}; " - f"scatter={_message_kernel_ms(profile, 'scatter_receiver_major_message_kernel')}" - ) - lines.append( - ( - "| {row_id} | {batch} | {bucket_count} | {kind_topology} | {regular_sparse} | " - "{degree_ragged_demoted} | {batched_grouped} | {reset_aware} | {reuse_count} | {workspace_mode} | " - "{workspace_peak} | {aliases} | {per_bucket_workspace} | {phase_launches} | " - "{small_cublas_launches} | {copy_glue_launches} | {copy_glue_saved_launches} | " - "{state_epilogue_modes} | {state_epilogue_saved_launches} | {coalescing_modes} | " - "{glue_fusion_modes} | {old_message} | {new_aggregate} | {dense_message} |" - ).format( - row_id=profile["row_id"], - batch=profile["batch_size"], - bucket_count=_first_record_value(record, "message_bucket_count"), - kind_topology=( - f"{_record_values(record, 'message_bucket_kinds')}/" - f"{_record_values(record, 'message_topology_kinds')}" - ), - regular_sparse=( - f"{_first_record_value(record, 'message_regular_local_bucket_count')}/" - f"{_first_record_value(record, 'message_sparse_bucket_count')}" - ), - degree_ragged_demoted=( - f"{_first_record_value(record, 'message_degree_uniform_bucket_count')}/" - f"{_first_record_value(record, 'message_ragged_grouped_bucket_count')}/" - f"{_first_record_value(record, 'message_demoted_bucket_count')}" - 
), - batched_grouped=( - f"{_first_record_value(record, 'message_batched_backend_count')}/" - f"{_first_record_value(record, 'message_grouped_backend_count')}" - ), - reset_aware=_first_record_value(record, "message_reset_aware_bucket_count"), - reuse_count=_first_record_value(record, "message_packed_source_reuse_count"), - workspace_mode=_first_record_value(record, "message_workspace_mode"), - workspace_peak=_first_record_value(record, "message_workspace_peak_bytes"), - aliases=_record_values(record, "message_workspace_aliases"), - per_bucket_workspace=_record_values(record, "message_per_bucket_workspace_bytes"), - phase_launches=_record_values(record, "phase_launch_counts"), - small_cublas_launches=_record_values(record, "small_cublas_launch_counts"), - copy_glue_launches=_record_values(record, "copy_glue_launch_counts"), - copy_glue_saved_launches=_record_values(record, "copy_glue_saved_launch_counts"), - state_epilogue_modes=_record_values(record, "state_epilogue_modes"), - state_epilogue_saved_launches=_record_values(record, "state_epilogue_saved_launch_counts"), - coalescing_modes=_record_values(record, "launch_coalescing_modes"), - glue_fusion_modes=_record_values(record, "generic_glue_fusion_modes"), - old_message=OLD_MESSAGE_SELF_CUDA_MS.get(str(profile["row_id"]), "-"), - new_aggregate=_message_kernel_ms(profile, "receiver_message_aggregate_kernel"), - dense_message=dense_message_ms, - ) - ) - lines.extend( - [ - "", - "## Glue Summary", - "", - ( - "| Row | Side | B | aten::copy_ | DtoD memcpy | elementwise copy | add bias | append bias feature | " - "append bias weight | copy/pad | split | cat | fill | Metadata copy glue | Metadata bias glue |" - ), - "|---|---|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---|---|", - ] - ) - for profile in report["profiles"]: - record = profile["backend_record"] - lines.append( - ( - "| {row_id} | {side} | {batch} | {copy} | {dto_d} | {elementwise_copy} | {add_bias} | " - "{append_bias_feature} | {append_bias_weight} | {copy_pad} | {split} | {cat} | {fill} | " - "{copy_glue} | {bias_glue} |" - ).format( - row_id=profile["row_id"], - side=profile["side"], - batch=profile["batch_size"], - copy=_glue_kernel_ms(profile, "aten::copy_"), - dto_d=_glue_kernel_ms(profile, "Memcpy DtoD"), - elementwise_copy=_glue_kernel_ms(profile, "direct_copy_kernel_cuda"), - add_bias=_glue_kernel_ms(profile, "add_receiver_bias_kernel"), - append_bias_feature=_glue_kernel_ms(profile, "append_bias_feature_kernel"), - append_bias_weight=_glue_kernel_ms(profile, "append_bias_weight_kernel"), - copy_pad=_glue_kernel_ms(profile, "receiver_major_copy_or_pad_kernel"), - split=_glue_kernel_ms(profile, "receiver_major_split_last_dim_kernel"), - cat=_glue_kernel_ms(profile, "aten::cat"), - fill=_glue_kernel_ms(profile, "aten::fill_"), - copy_glue=_record_values(record, "copy_glue_launch_counts"), - bias_glue=_record_values(record, "standalone_bias_kernel_count"), - ) - ) - lines.extend( - [ - "", - "## Glue Attribution Coverage", - "", - ( - "| Row | Side | B | Deduped glue self-CUDA | Parent-attributed glue self-CUDA | Coverage | " - "Glue gate | Gate reason | Top source 1 | Top source 2 | Top source 3 |" - ), - "|---|---|---:|---:|---:|---:|---|---|---:|---:|---:|", - ] - ) - for profile in report["profiles"]: - gate, gate_reason = _glue_gate(profile) - source_events = sorted( - ( - event - for event in profile.get("glue_source_events", ()) - if event["name"] in GLUE_COVERAGE_SOURCE_PATTERNS - ), - key=lambda event: float(event["cuda_total_us"]), - reverse=True, - 
) - top_sources = [ - f"{event['name']} {float(event['cuda_total_us']) / 1000.0:.3f} ms/{int(event['count'])}" - for event in source_events[:3] - ] - while len(top_sources) < 3: - top_sources.append("-") - lines.append( - ( - "| {row_id} | {side} | {batch} | {kernel_ms:.3f} ms | {source_ms:.3f} ms | {coverage} | " - "{gate} | {gate_reason} | {source0} | {source1} | {source2} |" - ).format( - row_id=profile["row_id"], - side=profile["side"], - batch=profile["batch_size"], - kernel_ms=float(profile.get("glue_kernel_self_cuda_us", 0.0)) / 1000.0, - source_ms=float(profile.get("glue_source_cuda_total_us", 0.0)) / 1000.0, - coverage=_glue_coverage(profile), - gate=gate, - gate_reason=gate_reason, - source0=top_sources[0], - source1=top_sources[1], - source2=top_sources[2], - ) - ) - lines.extend( - [ - "", - "## Glue Source Detail", - "", - ( - "| Row | Side | B | Launch prep state | Tensor table pack | Topology build | " - "Static materialization | State init | Readout prep | " - "Launch prep params | Initial public | " - "Persistent init copy | Message/input/readout prep | Public tree/topology prep | " - "Default resets | Normalized resets | Next-state cat | Output cat | Runtime sparse clone | " - "Dense bias add | Bias input pack | Bias weight pack | Reset source pack | Copy/pad | Split | " - "Public KV workspace | Ragged message zero | Degree ptr CPU |" - ), - ( - "|---|---|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|" - "---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|" - ), - ] - ) - for profile in report["profiles"]: - lines.append( - ( - "| {row_id} | {side} | {batch} | {state} | {tensor_pack} | {topology_build} | " - "{static_materialization} | {state_init} | {readout_prep} | {params} | {initial_public} | " - "{persistent_copy} | {message_prep} | {topology_prep} | {default_resets} | {normalized_resets} | " - "{next_state_cat} | {output_cat} | {sparse_clone} | {bias_add} | {bias_input} | " - "{bias_weight} | {reset_pack} | {copy_pad} | {split} | {public_kv_workspace} | " - "{ragged_zero} | {degree_ptr} |" - ).format( - row_id=profile["row_id"], - side=profile["side"], - batch=profile["batch_size"], - state=_glue_source_ms(profile, "fabric.glue.launch_state_tree_contiguous"), - tensor_pack=( - f"pack={_glue_source_ms(profile, 'fabric.glue.tensor_table_pack')}; " - f"empty={_glue_source_ms(profile, 'fabric.glue.empty_tensor_table_pack')}" - ), - topology_build=( - f"local={_glue_source_ms(profile, 'fabric.glue.build_local_topology')}; " - f"sparse={_glue_source_ms(profile, 'fabric.glue.build_sparse_topology')}" - ), - static_materialization=( - f"population={_glue_source_ms(profile, 'fabric.glue.population_param_materialization')}; " - f"contig={_glue_source_ms(profile, 'fabric.glue.static_tensor_contiguous')}; " - f"cat={_glue_source_ms(profile, 'fabric.glue.static_tensor_cat')}; " - f"collect={_glue_source_ms(profile, 'fabric.glue.collect_cell_tensors_contiguous')}; " - f"to_backend={_glue_source_ms(profile, 'fabric.glue.population_to_backend_state_contiguous')}; " - f"to_pop={_glue_source_ms(profile, 'fabric.glue.backend_to_population_state_contiguous')}" - ), - state_init=( - f"runtime={_glue_source_ms(profile, 'fabric.glue.runtime_init_state')}; " - f"backend={_glue_source_ms(profile, 'fabric.glue.backend_population_state_zero')}" - ), - readout_prep=( - f"args={_glue_source_ms(profile, 'fabric.glue.output_readout_arg_contiguous')}; " - f"step={_glue_source_ms(profile, 'fabric.glue.output_readout_step_index')}" - ), - params=( - f"cell={_glue_source_ms(profile, 
'fabric.glue.launch_cell_param_contiguous')}; " - f"in={_glue_source_ms(profile, 'fabric.glue.launch_input_projection_param_contiguous')}; " - f"pub={_glue_source_ms(profile, 'fabric.glue.launch_public_projection_param_contiguous')}" - ), - initial_public=( - f"contig={_glue_source_ms(profile, 'fabric.glue.launch_initial_public_contiguous')}; " - f"project={_glue_source_ms(profile, 'fabric.glue.launch_initial_public_project_from_hidden')}; " - f"dense_contig={_glue_source_ms(profile, 'fabric.glue.dense_affine_python_contiguous')}; " - f"zero={_glue_source_ms(profile, 'fabric.glue.launch_initial_public_zero_kv')}" - ), - persistent_copy=_glue_source_ms(profile, "fabric.glue.launch_persistent_scan_initial_copy"), - message_prep=( - f"msg={_glue_source_ms(profile, 'fabric.glue.launch_message_param_contiguous')}; " - f"ports={_glue_source_ms(profile, 'fabric.glue.launch_input_ports_contiguous')}; " - f"readout={_glue_source_ms(profile, 'fabric.glue.launch_readout_param_contiguous')}" - ), - topology_prep=( - f"public={_glue_source_ms(profile, 'fabric.glue.launch_public_tree_contiguous')}; " - f"routing={_glue_source_ms(profile, 'fabric.glue.launch_routing_tensor_contiguous')}; " - f"topology={_glue_source_ms(profile, 'fabric.glue.launch_topology_contiguous')}; " - f"resets={_glue_source_ms(profile, 'fabric.glue.launch_resets_contiguous')}" - ), - default_resets=( - f"zero={_glue_source_ms(profile, 'fabric.glue.default_resets_zero')}; " - f"to_u8={_glue_source_ms(profile, 'fabric.glue.reset_mask_to_u8')}" - ), - normalized_resets=_glue_source_ms(profile, "fabric.glue.normalized_population_resets"), - next_state_cat=_glue_source_ms(profile, "fabric.glue.materialize_next_state_cat"), - output_cat=( - f"sequence={_glue_source_ms(profile, 'fabric.glue.output_sequence_cat')}; " - f"grouped_kv={_glue_source_ms(profile, 'fabric.glue.grouped_kv_weight_cat')}" - ), - sparse_clone=_glue_source_ms(profile, "fabric.glue.runtime_sparse_message_clone"), - bias_add=_glue_source_ms(profile, "fabric.glue.dense_affine_bias_add"), - bias_input=_glue_source_ms(profile, "fabric.glue.dense_affine_bias_input_pack"), - bias_weight=_glue_source_ms(profile, "fabric.glue.dense_affine_bias_weight_pack"), - reset_pack=_glue_source_ms(profile, "fabric.glue.dense_affine_reset_source_pack"), - copy_pad=_glue_source_ms(profile, "fabric.glue.receiver_major_copy_or_pad"), - split=_glue_source_ms(profile, "fabric.glue.receiver_major_split_last_dim"), - public_kv_workspace=_glue_source_ms(profile, "fabric.glue.public_projection_kv_workspace"), - ragged_zero=_glue_source_ms(profile, "fabric.glue.message_ragged_output_zero"), - degree_ptr=_glue_source_ms(profile, "fabric.glue.message_degree_ptr_cpu"), - ) - ) - lines.extend(["", "## Profiler Details", ""]) - for profile in report["profiles"]: - lines.extend( - [ - f"### {profile['row_id']} {profile['side']} B={profile['batch_size']}", - "", - "Top self-CUDA events:", - "", - ] - ) - for event in profile["top_cuda"]: - lines.append( - f"- `{event['name']}`: self CUDA `{event['self_cuda_us'] / 1000.0:.3f} ms`, count `{event['count']}`" - ) - lines.extend(["", "Top self-CPU/runtime events:", ""]) - for event in profile["top_cpu"]: - lines.append( - f"- `{event['name']}`: self CPU `{event['self_cpu_us'] / 1000.0:.3f} ms`, count `{event['count']}`" - ) - lines.extend(["", "Copy/cat/clone/zero/launch events:", ""]) - for event in profile["runtime_ops"]: - lines.append( - f"- `{event['name']}`: self CUDA `{event['self_cuda_us'] / 1000.0:.3f} ms`, " - f"self CPU `{event['self_cpu_us'] / 
1000.0:.3f} ms`, count `{event['count']}`" - ) - lines.extend(["", "Glue attribution events:", ""]) - for event in profile.get("glue_kernel_events", ()): - lines.append( - f"- `{event['name']}`: self CUDA `{event['self_cuda_us'] / 1000.0:.3f} ms`, " - f"CUDA total `{event['cuda_total_us'] / 1000.0:.3f} ms`, count `{event['count']}`" - ) - lines.extend(["", "Glue source events:", ""]) - for event in profile.get("glue_source_events", ()): - lines.append( - f"- `{event['name']}`: self CUDA `{event['self_cuda_us'] / 1000.0:.3f} ms`, " - f"CUDA total `{event['cuda_total_us'] / 1000.0:.3f} ms`, count `{event['count']}`" - ) - lines.extend(["", "Physical op attribution events:", ""]) - for event in profile.get("physical_op_events", ()): - lines.append( - f"- `{event['name']}`: self CUDA `{event['self_cuda_us'] / 1000.0:.3f} ms`, " - f"CUDA total `{event['cuda_total_us'] / 1000.0:.3f} ms`, count `{event['count']}`" - ) - lines.append("") - lines.extend( - [ - "## Initial Diagnosis", - "", - "This section is intentionally data-derived and should be tightened after reviewing Nsight outputs.", - "Use the benchmark gains, top CUDA events, and launch records above as the source of truth.", - ] - ) - path.write_text("\n".join(lines) + "\n") - - -def _run_sweep(args: argparse.Namespace) -> int: - device = torch.device(args.device) - if device.type == "cuda": - torch.cuda.set_device(device) - dtype = torch.float32 - batches = _parse_batches(args.batches) - row_ids = tuple(_row_by_id(row_id) for row_id in args.rows.split(",") if row_id) - fabric_hidden_grid = _parse_optional_batches(args.fabric_hidden_sizes) - sweeps = [ - _run_sweep_row( - row=row, - device=device, - dtype=dtype, - warmup=args.warmup, - iterations=args.iterations, - batches=batches, - max_batch=args.max_batch, - seed=args.seed + row_index, - fabric_hidden_grid=fabric_hidden_grid, - ) - for row_index, row in enumerate(row_ids) - ] - profiles = ( - () - if args.skip_torch_profile - else _run_profiles( - rows=row_ids, - profile_batches=_parse_batches(args.profile_batches), - device=device, - dtype=dtype, - warmup=args.profile_warmup, - iterations=args.profile_iterations, - seed=args.seed + 10_000, - backward_attribution_mode=args.backward_attribution_mode, - use_resets=args.profile_use_resets, - fabric_hidden_grid=fabric_hidden_grid, - ) - ) - out_prefix = Path(args.out_prefix) - report: dict[str, object] = { - "date": args.date, - "command": " ".join(sys.argv), - "device": str(device), - "gpu_name": _device_name(device), - "gpu_nvml": _nvml_device_summary(device), - "nsys_path": _required_tool_path("nsys"), - "ncu_path": _required_tool_path("ncu"), - "dtype": _dtype_name(dtype), - "warmup": args.warmup, - "iterations": args.iterations, - "profile_warmup": args.profile_warmup, - "profile_iterations": args.profile_iterations, - "backward_attribution_mode": args.backward_attribution_mode, - "profile_use_resets": args.profile_use_resets, - "fabric_hidden_sizes": fabric_hidden_grid, - "seed": args.seed, - "cuda_visible_devices": os.environ["CUDA_VISIBLE_DEVICES"] if "CUDA_VISIBLE_DEVICES" in os.environ else None, - "sweeps": sweeps, - "profiles": profiles, - "legacy_kernel_patterns": LEGACY_KERNEL_PATTERNS, - "generic_fabric_kernel_patterns": GENERIC_FABRIC_KERNEL_PATTERNS, - } - json_path = out_prefix.with_suffix(".json") - md_path = out_prefix.with_suffix(".md") - json_path.parent.mkdir(parents=True, exist_ok=True) - json_path.write_text(json.dumps(report, indent=2)) - _write_markdown(report, md_path) - print(json_path) - print(md_path) - if 
args.require_glue_gate: - open_rows = [ - f"{profile['row_id']}/{profile['side']}/B={profile['batch_size']}:{reason}" - for profile in profiles - for gate, reason in (_glue_gate(profile),) - if gate == "open" - ] - if open_rows: - print("Glue gate failed: " + "; ".join(open_rows), file=sys.stderr) - return 1 - return 0 - - -def main() -> int: - parser = argparse.ArgumentParser(description="Run current-branch Fabric scaling/profile sweep.") - parser.add_argument("--action", choices=("sweep", "single-run"), default="sweep") - parser.add_argument("--date", default="2026-04-13") - parser.add_argument("--device", default="cuda:0") - parser.add_argument("--seed", type=int, default=1234) - parser.add_argument("--warmup", type=int, default=2) - parser.add_argument("--iterations", type=int, default=5) - parser.add_argument("--profile-warmup", type=int, default=1) - parser.add_argument("--profile-iterations", type=int, default=1) - parser.add_argument("--batches", default=",".join(str(batch) for batch in BASE_BATCHES)) - parser.add_argument("--max-batch", type=int, default=16_384) - parser.add_argument("--profile-batches", default="512,4096") - parser.add_argument( - "--fabric-hidden-sizes", - help=( - "Optional comma-separated Fabric cell hidden sizes for stress/sanity sweeps. " - "Omit for the default h=32 closure baseline." - ), - ) - parser.add_argument( - "--profile-use-resets", - action="store_true", - help="Use deterministic batch-row resets during torch-profiled rows; throughput sweeps remain no-reset.", - ) - parser.add_argument("--skip-torch-profile", action="store_true") - parser.add_argument("--rows", default=",".join(row.row_id for row in DEFAULT_ROWS)) - parser.add_argument( - "--out-prefix", - default="/workspace/metta/docs/user/subho/fabric_benchmark/results/profiles/fabric_scaling_profile_2026-04-13", - ) - parser.add_argument("--single-row", default="slstm_2m_t32_forward") - parser.add_argument("--single-side", choices=("stack", "fabric"), default="fabric") - parser.add_argument("--single-batch", type=int, default=4096) - parser.add_argument("--single-use-resets", action="store_true") - parser.add_argument( - "--backward-attribution-mode", - choices=( - "active", - "phase_decomposed_probe", - "state_public_output_probe", - "state_public_state_probe", - "full_replay_boundary_probe", - "full_replay_no_parameter_probe", - "full_replay_without_boundary_inputs_probe", - "full_replay_no_parameter_boundary_probe", - ), - default="active", - help=( - "Use probe modes only for backward attribution profiling. They do not change the default training " - "measurement path and do not close physical backward ownership by themselves." 
- ), - ) - parser.add_argument("--require-glue-gate", action="store_true") - args = parser.parse_args() - if args.action == "single-run": - _single_run( - row=_row_by_id(args.single_row), - side=args.single_side, - batch_size=args.single_batch, - device=torch.device(args.device), - dtype=torch.float32, - warmup=args.warmup, - iterations=args.iterations, - seed=args.seed, - use_resets=args.single_use_resets, - fabric_hidden_grid=_parse_optional_batches(args.fabric_hidden_sizes), - ) - return 0 - return _run_sweep(args) - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/scripts/validate_fabric_generated_catalogs.py b/scripts/validate_fabric_generated_catalogs.py new file mode 100644 index 00000000..fbde3ebc --- /dev/null +++ b/scripts/validate_fabric_generated_catalogs.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +"""Validate Fabric compiler-generated C++ catalog headers.""" + +from __future__ import annotations + +import argparse +import sys +from dataclasses import dataclass +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[1] +SRC_ROOT = REPO_ROOT / "src" +if str(SRC_ROOT) not in sys.path: + sys.path.insert(0, str(SRC_ROOT)) + +from cortical.fabric.backend.message_rules import ( # noqa: E402 + message_rule_lowering_catalog_header_text, + validate_message_rule_lowering_catalog_header, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.native_callables import ( # noqa: E402 + temporal_native_callable_generated_header_text, + validate_temporal_native_callable_generated_header, +) + + +@dataclass(frozen=True) +class GeneratedCatalog: + label: str + path: Path + expected_text: str + + def validate(self) -> None: + actual = self.path.read_text(encoding="utf-8") + if self.label == "message-rule-lowering": + validate_message_rule_lowering_catalog_header(actual) + elif self.label == "temporal-native-callables": + validate_temporal_native_callable_generated_header(actual) + else: + if actual != self.expected_text: + raise RuntimeError(f"Unknown generated catalog {self.label!r} is out of sync") + + def write(self) -> bool: + previous = self.path.read_text(encoding="utf-8") if self.path.exists() else "" + if previous == self.expected_text: + return False + self.path.write_text(self.expected_text, encoding="utf-8") + return True + + +def _catalogs() -> tuple[GeneratedCatalog, ...]: + return ( + GeneratedCatalog( + label="message-rule-lowering", + path=REPO_ROOT / "src/cortical/fabric/backend/cuda/nn/message_rule_lowering_catalog.cuh", + expected_text=message_rule_lowering_catalog_header_text(), + ), + GeneratedCatalog( + label="temporal-native-callables", + path=REPO_ROOT + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/" + / "flat_bucket_registered_native_callables.cuh", + expected_text=temporal_native_callable_generated_header_text(), + ), + ) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--write", + action="store_true", + help="Rewrite generated catalog headers instead of only checking them.", + ) + return parser.parse_args() + + +def main() -> int: + args = _parse_args() + changed: list[str] = [] + for catalog in _catalogs(): + if args.write: + if catalog.write(): + changed.append(str(catalog.path.relative_to(REPO_ROOT))) + else: + catalog.validate() + if args.write: + if changed: + print("Updated generated Fabric catalog headers:") + for path in changed: + print(f" {path}") + else: + print("Generated Fabric catalog headers are already up 
to date.") + else: + print("Generated Fabric catalog headers are up to date.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/skills/cb.fabric-backend-boundaries/SKILL.md b/skills/cb.fabric-backend-boundaries/SKILL.md index 01cc66f7..991d7e4f 100644 --- a/skills/cb.fabric-backend-boundaries/SKILL.md +++ b/skills/cb.fabric-backend-boundaries/SKILL.md @@ -1,6 +1,6 @@ --- name: cb.fabric-backend-boundaries -description: Use when modifying Cortical Fabric CUDA backend execution, fabric.cuda.nn declarations, physical operators, physical super-ops, backward physical operators, or anything that risks mixing Fabric cell semantics with backend scheduling/execution ownership. +description: Use when modifying Cortical Fabric CUDA backend execution, fabric.cuda.nn declarations, graph/topology declarations, physical operators, physical super-ops, backward physical operators, compiler rows/bindings, or anything that risks mixing Fabric semantics with backend scheduling/execution ownership. --- # Fabric Backend Boundaries @@ -13,6 +13,33 @@ Use this skill before touching Fabric CUDA/backend physical execution. It is an - `src/cortical/fabric/README.md` - Current Fabric CUDA lowering, planner, runtime, tests, and benchmark/profile artifacts relevant to the requested change. +- For ambiguous Fabric requests or the user's analyze -> plan -> proceed loop, first read + `skills/cb.fabric-workflow-router/SKILL.md` to classify the lane before editing. +- For non-trivial compiler/backend changes, also read `skills/cb.fabric-compiler-boundary-audit/SKILL.md` and use it as + the preflight/closeout checklist. +- For user-visible graph, cell, message, readout, primitive, reset/init, or Config/Blueprint semantic changes, also read + `skills/cb.fabric-declaration-onboarding/SKILL.md`. +- For deliberate pre-throughput semantic stress tests or formula perturbations, also read + `skills/cb.fabric-compiler-stress-test/SKILL.md` and keep throughput work deferred. +- For new primitive ops, message/readout rules, transition primitives, tensor bindings, native callables, parameter + reducers, or throughput executor strategies, also read `skills/cb.fabric-compiler-extension/SKILL.md`. +- For parameter reducers, reverse span outputs, runtime-buffer lifetimes, workspace ownership, or native tensors consumed + by reducers, also read `skills/cb.fabric-reducer-liveness/SKILL.md`. +- For new primitive ops specifically, also read `skills/cb.fabric-primitive-op-onboarding/SKILL.md`. +- For readout rules, output routes, output-boundary behavior, pooling, or readout reducers, also read + `skills/cb.fabric-readout-rule-onboarding/SKILL.md`. +- For throughput executor strategies specifically, also read `skills/cb.fabric-throughput-strategy/SKILL.md`. +- For CUDA/Triton/C++ native strategy bodies, fused program kernels, native callable binding schemas, or kernel ABIs, + also read `skills/cb.fabric-native-strategy-onboarding/SKILL.md`. +- For graph/topology declarations, graph constructors, lattice/config cleanup, boundary/port semantics, or graph-driven + planner legality, also read `skills/cb.fabric-graph-onboarding/SKILL.md`. +- For public API, Config, Blueprint, compatibility constructor, or declaration-normalization cleanup, also read + `skills/cb.fabric-public-api-cleanup/SKILL.md`. 
+- For public model/runtime handoff, input/output adapters, boundary tensor prep, state initialization, reset + normalization, sender K/V setup, or pre-registered-entry throughput owners, also read + `skills/cb.fabric-runtime-front-end-handoff/SKILL.md`. +- For source/static guardrails, no-fallback/no-compat checks, or legacy deletion checks, also read + `skills/cb.fabric-boundary-guardrails/SKILL.md`. - Prior Codex chats for durable lessons and rejected approaches: ```bash @@ -24,6 +51,60 @@ Do not use docs, configs, or chat history as proof of current support. Verify cu ## Design Contract +### Fabric Compiler Contract + +- Fabric is a compiler for declared Fabric programs. The public API should feel PyTorch-style to users, but the backend + contract is compiler-style: declarations lower into IR; IR lowers into op/tensor tables and parameter bindings; + primitive executors run the lowered program; the temporal engine schedules those primitive rows over time. +- This compiler contract applies equally to graph construction, message passing, cell transition math, recurrence, + projection, normalization, readout, reset, output materialization, and backward adjoints. There is no special exemption + for message math or currently supported cell families. +- The backend must execute the declared program. It must never substitute, reinterpret, or silently collapse a user + declaration into the current built-in dot-product message rule, gated recurrence, diagonal recurrence, fixed readout, + or any other hidden canonical route. +- A path only counts as generic when supported declarations lower into executable primitive programs. A generic-looking + Python/C++ API, metadata wrapper, or `fabric.cuda.nn` name is a facade unless the active route is: + `declaration -> IR -> primitive op rows -> tensor-table roles -> parameter bindings -> primitive executor`. +- Composite primitives are valid compiler units when the declaration names them explicitly and the active route still + follows the compiler chain. `gated_logspace_recurrence` and `diag_rtu` may be supported as composite recurrence + primitives; they are not required to decompose into every scalar/elementwise op first. The composite implementation is + valid only inside the primitive executor selected by the lowered op row, with an explicit tensor-role ABI, forward + executor, backward executor, recompute/tape policy, parameter-gradient binding, and fail-closed unsupported-op behavior. +- Composite recurrence declarations should remain user-readable and concise: ordinary primitives such as `linear`, + `matmul`, and `norm_or_identity` stay separate, while only the true recurrence formula is a composite like + `gated_logspace_recurrence` or `diag_rtu`. Do not hide projections, normalization, activation selection, state + schemas, or output projection behind a broad cell-family bundle. If a composite needs options such as activation kind + or trace policy, those options must be explicit op inputs/attributes in the lowered row, not side-channel runtime + fields inferred from a cell name. +- Unsupported declarations fail closed at lowering or executor selection with a clear unsupported-op reason. If an op is + missing, add it to `fabric.cuda.nn` and carry it through IR, primitive rows, tensor roles, parameter bindings, and + executors. Do not bypass the compiler by calling private kernels, direct runtime helpers, cell bundles, or + metadata-only wrappers. 
+- The temporal engine owns time, `T*K`, horizon, reset scheduling, checkpoint/recompute, materialization, dependency + ordering, and flat-bucket traversal. It must not own primitive semantics such as Q/K/V, attention logits, gated + recurrence, diagonal recurrence, layernorm, projection formulas, activation formulas, or readout formulas. Those names + may appear only inside the corresponding primitive declaration, lowering, and primitive executor. +- Public runtime/model handoff owns API preparation only. It may prepare adapter outputs, boundary tensors, initial + state, reset masks, and sender K/V inputs for compiler rows, but it must not select strategies, allocate backend + workspace outside liveness rows, run hidden temporal loops, or infer semantics from Config, cell family, message-rule + names, graph constructors, hidden-size constants, or old tensor names. +- Hardcoded primitive equations in shared temporal files are compiler breaches. Code that directly implements formulas + such as diagonal recurrence/traces (`nu`, `theta`, `hc1/hc2`, eligibility traces), gated sLSTM equations + (`iraw/fraw/zraw/oraw`, `c/n/m`), fixed Q/K/V attention, layernorm math, projection/readout formulas, or their + adjoints is valid only inside a primitive executor selected by a lowered op row. If the same math appears in a + temporal scan, reverse scan, planner, runtime bridge, or benchmark helper, mark it open R15/R2.1 debt and remove or + fail-close it before claiming generic Fabric support. +- Changing a user-declared cell or message program must either change the compiled primitive program that executes or + fail closed as unsupported. If changing the declaration still runs the same hidden diagonal/gated/QKV/readout route, + the backend is a facade over one fixed algorithm and cannot close REDO/Fixmass. +- Do not over-decompose a supported composite recurrence into a large eager semantic interpreter just to appear generic. + Prefer a concise declared composite primitive when it is the chosen ABI, then make the row-owned primitive executor and + temporal scheduler real. A PyTorch-style semantic interpreter is reference/debug machinery only, not CUDA closure. +- Parity and throughput evidence through a facade is invalid. Closure evidence must prove the declared program compiled + into the runtime op/tensor tables and primitive executors that produced the values, gradients, and performance. + +### Ownership Contract + - Fabric separates configured cell populations from backend execution. - Cells may own schemas, parameter scopes, local recurrent equations, local backward/recompute recipes, dtype/layout constraints, and per-population parameter materialization. @@ -44,11 +125,51 @@ Do not use docs, configs, or chat history as proof of current support. Verify cu - Treat user corrections as general design principles by default. Do not scope a correction to the last failing row, mixed/single-pop only, one file, one benchmark, or one hidden size unless the user explicitly narrows it. Apply the principle across Fabric backend design, docs, parity, performance, cleanup, and future refactors. +- Backend policy must be derived from generic plan/runtime facts: flat graph tables, bucket identity, tensor/op rows, + reset policy, output materialization, `T*K`, horizon, dtype/device, and workspace pressure. Do not introduce a policy + keyed by an audit row id, family name, parameter label, hidden-size constant, single-vs-mixed population label, or a + fabricated unit-test throughput fixture. 
+- Prioritize the shared temporal engine and throughput-critical backend ownership before cleanup or route-name polish. + Cleanup, metadata relabeling, wrapper deletion, and doc hygiene are important, but they do not close backend stages + until the active forward/backward temporal work physically runs through the shared backend engine and current-code + parity/performance evidence is green. - `fabric.cuda.nn` is the native CUDA semantic declaration layer. Cells declare transition math there; it is not a `torch.nn` eager API, Python helper, or schedule owner. +- Do not fake `fabric.cuda.nn` internally. A declaration only counts if it is the source of truth that lowers into + explicit IR/primitive op rows, tensor roles, parameter bindings, shape metadata, and executor selection. A wrapper + that names `fabric.cuda.nn` while the backend still calls hardcoded gated/diagonal/message formulas, cell-specific + bundles, or private temporal-kernel branches is a closure blocker, not an intermediate success. +- Message passing is a user-declared Fabric semantic surface just like cells. Query/key/value sources, normalization or + public-boundary use, distance/delay terms, aggregation, reset behavior, and future message operators must lower from + Fabric declarations into generic message primitive rows and tensor tables. Do not bake one message rule into the + temporal engine as hidden backend policy. +- Cell math and message-passing math have the same closure standard. A generic-looking declaration is not enough if the + active backend still hardcodes the math elsewhere. `fabric.cuda.nn` only counts when both cells and message rules lower + into IR -> primitive op rows -> tensor roles -> parameter bindings -> primitive executors. The temporal engine must not + name Q/K/V, gated recurrence, diagonal recurrence, layernorm, projections, readout formulas, or any other primitive + math as engine-owned concepts; those names may appear only inside the corresponding primitive declaration/lowering and + primitive executor. +- Do not confuse "we currently support one message rule/cell primitive" with "the backend is generic." If support is + narrow, expose that as a narrow supported primitive and fail other declarations closed. Do not present the route as a + general Fabric CUDA library while internally forcing all programs through that one primitive. - Substantial Fabric backend, API, graph, message-rule, or performance work must maintain a doc-backed progress outline. Pick the relevant design/benchmark doc, add the current checklist and closure order, and update it before returning. Do not rely on chat state to remember cleanup, parity, performance, or stale-artifact status. +- Every REDO/fixmass owner must pair new shared-engine work with deletion, fail-closed guarding, or explicit reopening of + the stale path it makes obsolete. Do not accumulate live legacy siblings, transitional replay routes, compatibility + shims, or reference-only wrappers in the codebase; use commit history for old behavior. +- Keep shared-engine code small and decisive. If an edit introduces another abstraction, metadata layer, or helper + surface, it must immediately delete, fail-close, or expose for removal the legacy path it replaces. Concise owned code + is preferred over multiple live abstractions because parallel routes confuse future backend work and hide throughput + regressions. 
+- If a route is unreferenced or only a dormant sibling of the compiler-owned path and it violates Fabric boundaries, + delete it in the same pass instead of preserving it as a fallback, debug hook, or future option. Commit history is the + reference for old routes; live code should contain the shared engine and explicit fail-closed unsupported paths. +- The compiler target should reduce backend code over time. Once a primitive executor replaces any legacy kernel, delete + the legacy CUDA kernel, pybind export, Python wrapper, and route tests after parity, reset coverage, owner metadata, + and representative throughput evidence pass. This applies across all compiler-owned surfaces: message, transition, + projection, normalization, activation, readout, boundary adjoints, parameter reductions, temporal scheduling, and any + old route-specific kernels. Do not leave the compiler as a permanent layer over the old kernels. - Do not accept symptom-level local edits as Fabric backend design. If the root issue is an abstraction boundary, planner ownership, execution surface, or stale bridge, fix that framework-level invariant and remove the stale path. This applies to every Fabric path: single-population, mixed-population, forward, backward, `B`, params, `h`, graph, @@ -70,6 +191,25 @@ Do not use docs, configs, or chat history as proof of current support. Verify cu or in backend-owned CUDA execution modules. - Once a supported `fabric.cuda.nn` op exists, migrated affine/message/recurrence work must be authored as `fabric.cuda.nn` declarations, not direct cell-side calls or bridge specs against `fabric.cuda.ops`. +- If the active engine needs an op that is missing from `fabric.cuda.nn`, add the missing op/declaration/lowering to + `fabric.cuda.nn` and carry it through IR, primitive rows, tensor roles, and executors. Do not bypass the design by + calling private kernels, direct runtime helpers, cell bundles, or metadata-only wrappers around the missing op. +- A supported `fabric.cuda.nn` op must have an auditable lowering chain: + declaration -> Fabric IR -> primitive op row -> tensor-table roles -> primitive executor. If any step is missing and + the backend infers semantics from cell names, parameter bundle shapes, benchmark rows, or handwritten temporal-kernel + code, the path is fake and must fail closed or remain open. +- The active executor must be selected by the lowered program, not by the existence of legacy tensor names. A tensor role + such as `q`, `k`, `v`, `gate`, `diag`, `norm`, or `readout` is valid only as data owned by a specific lowered + primitive row. It is invalid if the temporal engine expects those names as its fixed ABI for all Fabric programs. +- Do not replace string-based fixed tensor roles with a temporal-side numeric role map. A table such as + `{"primitive.gated_logspace_recurrence.param.gate_weight": 13}` is still a facade if it is owned by the temporal + bridge. Tensor bindings must be produced by declaration/IR lowering from primitive inputs, outputs, and parameter + bindings. If a user signature changes, the compiled binding rows must change with it, or execution must fail closed + before the temporal engine. +- Parameter-gradient and reduction rows must come from compiled parameter bindings, not every tensor schema entry. A + schema can describe private state, public interfaces, reusable scratch, or parameters, but only binding-owned + executable parameters may enter temporal parameter-reduction tables. 
Do not let broad schemas create fake gradient + rows for receiver queries, placeholders, cached artifacts, or other non-bound tensors. - Physical operator selection must come from `CellTransitionIR`, `FabricStepIR`, `LoweredPhaseIR`, `PhysicalOpPlan`, `PhysicalExecutionPlan`, or equivalent semantic lowering. Do not select from `native_cell_kind`, cell names, row names, benchmark ids, parameter target, hidden-size constants, or dispatcher-side family checks. If these type names drift, inspect the live equivalent IR/planner structures rather than preserving stale names. @@ -79,7 +219,89 @@ Do not use docs, configs, or chat history as proof of current support. Verify cu `src/cortical/ops/rtu/**` to close Fabric physical-operator work. If Fabric needs a recurrence executor, add a Fabric-owned generic CUDA op under `src/cortical/fabric/backend/cuda/ops`. - Do not add cell-specific fast paths, hidden-size constants, one-off topology branches, benchmark-row branches, or named-model branches. Cell-family names may describe workload families in docs, but they must not drive executor selection. +- CUDA-native, Triton, PyTorch reference, and future hardware targets are implementation strategies, not Fabric semantics. + Generic compiler/planner rows must not assume that the registered temporal program is inherently CUDA-native. Strategy + selection should be explicit and row-owned, with declared fields such as implementation backend, runtime entrypoint, + supported device, dtype/layout contract, forward/backward support, artifact contract, workspace policy, and fail-closed + unsupported reasons. A Triton kernel may implement a primitive executor, but it must be selected through the same + primitive rows, tensor bindings, legality checks, and memory/liveness rows as any CUDA-native executor. Do not add + Triton, CUDA, family, benchmark, or shape branches in temporal scheduler code. +- Design goals are hard constraints, not preferences. Time pressure, throughput pressure, audit pressure, or a narrower + user request must not justify a shortcut that violates Fabric boundaries. Reject and log the probe instead. +- Temporal engine kernels and superops must consume generic flat-bucket identity plus tensor-table/op-table descriptors + lowered from `fabric.cuda.nn`/Fabric IR. Do not pass bespoke gated/diagonal/sLSTM/Axon argument lists into the shared + temporal engine, and do not branch on cell-family semantics inside temporal-engine kernels. Reusable primitive kernels + may implement generic physical op families; the temporal engine remains a scheduler over tensor slots, op rows, + dependencies, reset policy, checkpoint policy, and materialization policy. +- Tensor tables and op rows are the only accepted ABI for temporal engine work. If a temporal CUDA edit needs transition, + recurrence, normalization, message, projection, or readout math, the edit must add or call a declared primitive + executor selected by the op row. Do not inline primitive formulas such as gated recurrence, diagonal recurrence, + layernorm, attention, or cell-local adjoints inside the temporal scan/reverse engine. +- REDO/Fixmass cannot close with a `fabric.cuda.nn` facade. Closure requires greppable code evidence that temporal + forward/backward reaches primitive math through lowered tensor/op rows, plus audit metadata proving the active row used + those primitive executors rather than legacy hardcoded backend formulas. 
+- A faster temporal-kernel probe is rejected if it moves declared cell/message primitive math into the temporal engine. + Revert that probe immediately, record it as a boundary failure in the progress doc, and continue by adding the missing + `fabric.cuda.nn`/tensor-table primitive executor or by making the temporal scheduler dispatch existing primitive rows. +- Message primitive math follows the same rule as transition primitive math. A temporal kernel may execute a lowered + generic message primitive row, but it must not make dot-product attention, a fixed Q/K/V projection source, a fixed + normalization choice, or a fixed aggregation rule the only engine-owned message semantics. If a message operator is + missing, add a declarative Fabric primitive/lowering path rather than a route-specific backend branch. +- If a backend edit exposes user-declared message concepts such as Q/K/V outside the message primitive executor boundary, + stop and refactor the message-rule lowering/executor boundary before optimizing that path. Treat the exposed role names + as evidence of a facade unless the active code is dispatching a `MessageRuleIR` primitive row. +- `fabric.cuda.nn` primitives are universal operators, not message-only, readout-only, or cell-only operators. A primitive + such as `linear`, `add`, `attention_logits`, `segment_softmax`, `weighted_sum`, `reduction_boundary`, or a composite + recurrence is selected by op rows and can be used on any compatible surface. Surface labels such as message, readout, + transition, and parameter reduction describe where the op row is scheduled; they must not create separate primitive + meanings or hidden surface-specific semantics. Add tests that prove shared primitives keep the same opcode/executor + identity across surfaces. +- Primitive-specific dimensions such as message head width or a recurrent-affine primitive tile width may appear only as + tensor-table/op-row metadata inferred from primitive tensor shapes. Do not conflate unrelated primitive dimensions, and + do not use them as cell-family selectors or route keys. +- Before adding or binding a temporal CUDA kernel/superop, perform a manual boundary review and record it in the live + progress doc. The review must list the ABI inputs and explicitly confirm: no cell-kind selector, no population-name + selector, no benchmark-row selector, no hidden-size policy key, no separate single/mixed route, no cell-family + parameter bundle, no inline cell/message primitive formula, and all primitive math reached through tensor-table/op-row + dispatch. If the ABI cannot pass that review, do not land the kernel. - A single-step call is only a degenerate case of a generic streaming/sequence contract. Operator design must not encode one active time shape as the whole model. +- The shared temporal engine target is one fully owned CUDA temporal superop over flat buckets. Python may construct + declarations, call the high-level model API, and pass tensors/metadata, but must not own temporal scan loops, K + microstep loops, horizon windows, checkpoint/recompute, materialization policy, or backward loops beyond unavoidable + API glue. Transitional Python loops, Python `autograd.Function` scan bodies, and sibling routes must be recorded as + open cleanup and deleted after the CUDA-owned temporal engine has parity/audit proof. 
+- Temporal materialization must follow the April26-style recovered design shape: + the planner records output request, autograd seed surface, finite-H backward + window, checkpoint stride, recompute window, and reverse artifact kind for the + whole `T*K` stream. CUDA executes that plan over generic tensor/op tables. + Do not implement one-step, terminal-only, K-only, or row-specific + materialization as separate backend ownership. +- Training is always a streamed Fabric execution. A public training call + (`model(...)`, external loss, `loss.backward()`) may emit dense user outputs + because the user requested them, but the backend must still run the planned + `T*K` substrate as a streaming scan with planner-owned compact checkpoints, + bounded reverse windows, recompute policy, and segment-local artifact + disposal. Full-sequence state, message, K/V, transition tape, or boundary + materialization is a backend bug unless it is an explicit user-visible output + or a planner-recorded compact checkpoint/recompute artifact. +- Sequence training, terminal training, T=1, T>1, and K>1 are output/emission + schedules over the same streamed temporal engine. Do not add a separate + "training path" that retains the whole sequence, calls a benchmark streaming + helper, tiles time outside the backend, or treats per-timestep loss as a + reason to abandon streaming. +- When the active high-priority backend owner requires a CUDA kernel or physical superop, implement that kernel/superop + before spending more cycles on cleanup, route polish, planner metadata, or benchmark organization. Do not avoid kernel + work by relabeling a Python loop, adding a wrapper, or moving logic to another host-side helper. +- During migration, reduce Python-owned temporal work by moving hot scan, K, reset, emission, checkpoint, and backward + mechanics into the shared flat-bucket temporal engine first. Do not spend repeated cycles on cosmetic cleanup while + the active owner is still a host loop, runtime helper, or per-population fallback. +- R15 cleanup is a broad Fabric surface and legacy-deletion owner, not a narrow config/anatomy rename. It includes stale + execution routes, single/mixed sibling paths, benchmark-side backend/planner logic, direct message or cell math hidden + outside declared Fabric primitives, legacy config truth paths, lattice facts leaking into generic anatomy/backend code, + public Blueprint-to-Config facade removal, message-rule genericity, graph protocol ownership, old public-test cleanup, + population/cardinality cleanup, hidden-size and magic-threshold planner cleanup, cell-family backend-surface cleanup, + and manual/static guardrails that prevent those leaks from returning. Do not narrow R15 to the most recent cleanup + complaint; preserve and close the recovered additional-goals inventory. - Reset behavior is backend-owned and generic. Use explicit reset policy/scope metadata; do not hide reset-sensitive behavior in cells. - Whole-surface `MathBackend` / registry selection is transitional only and must never override a supported `PhysicalOpPlan` or per-bucket backend choice. @@ -104,10 +326,20 @@ Do not use docs, configs, or chat history as proof of current support. Verify cu ## Scaling Contract - Fabric should scale across cell count, topology size, batch/parallel rollout width, streaming or sequence time, and future population/bucket structure. 
-- Stage-style scaling audits must treat the real frontier as `B x params x h x graph x T`: batch and parameter count - should scale without regressing smaller rows; hidden size must not be a policy key; small `h` rows are many-cell stress - tests, not disposable oddities; graph shape/factorization is a user construction detail, not a backend scheduling key; - and `T` is streaming time over the same graph. +- Stage-style scaling audits must treat the real frontier as spatial/parallel axes plus temporal streaming axes. `B`, + params, graph/node count, topology, population buckets, and `h` are parallel/spatial work axes. `T` and `K` are + temporal streaming axes over the same T=1 substrate: `T*K` is repeated T=1 execution under one backend-owned temporal + engine, with different output materialization and checkpoint/recompute policy, not a separate algorithm or route. +- T=1 is the base shared-engine contract. Backend design must not chase K/H/long-T closure through a divergent path + while matched T=1 training is unhealthy; T/H/K work extends the T=1 temporal superop and reopens T=1 owners when it + exposes regressions. +- T=1 backend health cannot be inferred from a fast 1M probe alone. Representative guardrails for shared-engine work + must include April21-shaped 100M/500M/1B rows, and full closure must cover the April21 100M/500M/1B matrix, + high-batch small-param rows, small-hidden stress rows, reset/state axes, and mixed-pop rows through the same + flat-bucket temporal engine. +- Batch and parameter count should scale without regressing smaller rows; hidden size must not be a policy key; small + `h` rows are many-cell stress tests, not disposable oddities; graph shape/factorization is a user construction detail, + not a backend scheduling key. - Mixed-population CUDA support must be judged as a backend planner/executor feature, not just a parity convenience. Benchmark it against matched stack/MoE-style baselines and PyTorch Fabric reference rows, cover reset and no-reset cases, include enough `B x params` coverage to expose bucket scaling, and report whether cost is close to equivalent @@ -141,6 +373,24 @@ Do not use docs, configs, or chat history as proof of current support. Verify cu node sets, adjacency/degree buckets, group buckets, reset density, layout contracts, dtype/device caps, workspace, and tape pressure. Do not use rectangular factorization, x/y/z coordinates, cuboid shape labels, or future arbitrary-shape labels as execution-policy keys. +- Legacy Fabric config fields such as `width`, `height`, `depth`, `coord_shape`, `coord_dim`, `input_band_width`, + `output_band_width`, `projection_region_shape`, and config-level `num_heads` are public graph/message-construction or + compatibility surface only. They must be normalized into explicit graph facts, message primitive rows, and tensor-table + dimensions before backend planning. Treat any direct use of these fields for backend route selection, temporal-engine + admission, launch policy, kernel specialization, mixed/single-pop identity, or performance gates as a design bug and + log/clean it up rather than building on it. +- Do not "fix" legacy config leakage by promoting lattice facts into backend-facing anatomy. Lattice owns lattice facts: + rectangular factorization, coordinate shape, wraparound, band selectors, projection regions, and offset neighborhoods. + Fabric backend/anatomy contracts should expose graph-generic flat tables and metadata only. 
Runtime/backend code should + consume sender/receiver tables, valid masks, edge facts, group ids, flat bucket identity, tensor tables, and op rows + without knowing whether those facts came from a lattice, an arbitrary graph, or a future graph constructor. +- Treat `src/cortical/fabric/config.py` as a mis-owned legacy surface, not a generic Fabric design center. Its fields + must be split by owner: graph constructors own graph/topology/port/grouping fields, message declarations own message + dimensions and distance/delay semantics, cell declarations own transition primitive choices and state schemas, readout + declarations own readout pooling/slots, initialization utilities own seeds/noise/slot features, and planner request + types own horizons/checkpoints/K/backend preferences. Do not add new generic backend behavior to this global config, + and do not add legacy wrappers as the migration strategy; update callsites to the right owners and delete the old + path. - Factorization-invariance claims must hold flat graph invariants constant. If a benchmark changes boundary cardinality, degree histogram, edge buckets, or group structure, it is a different graph workload, not proof of backend factorization sensitivity. @@ -176,6 +426,9 @@ Do not use docs, configs, or chat history as proof of current support. Verify cu - Time indices, sequence chunk ids, and reset timesteps are runtime data, not physical-kernel specialization keys. Triton or CUDA wrappers must not compile a distinct physical kernel for each timestep. For reset masks, prefer passing a one-step reset view or runtime scalar offset so the same family executor is reused across streaming time. +- Temporal message delay indices are owned by the temporal schedule/artifacts. Do not hardcode recurrent or readout + message forward/backward to step `1` in a shared temporal path; derive the message step from the physical scan step + or pass an explicit per-row step tensor into the backend kernel. - Large per-timestep sequence loss must not return one retained full-`T` autograd graph when the intended semantics are streaming/TBPTT. The model/runtime should consume backend-selected output chunks, apply or reduce loss through a backend-owned streaming sink, and detach only the compact carry at the rolling-window boundary. Do not detach by @@ -242,7 +495,7 @@ Route work by ownership: wrapper, replay, Python custom Function, or legacy execution identity remains dominant, stop editing around that owner and implement the planner-owned physical executor that removes it. Delete the old route instead of keeping parallel paths once parity and non-regression gates pass. -- No backwards-compatibility shims in Fabric backend migrations. Update all callsites to the new public/backend surface, +- No backwards-compatibility wrappers in Fabric backend migrations. Update all callsites to the new public/backend surface, remove the old export or bridge, and let unsupported rows fail with explicit demotion rather than silently taking an obsolete path. - Regressions do not authorize changing the architecture target, weakening the refactor, adding fallback paths, or @@ -257,6 +510,10 @@ Route work by ownership: - Attribute backward hot time before implementing a physical family. Choose work by current ownership, not by assumption. - Backward operators must preserve forward boundaries: `projected_message`, `state_affine_output`, `raw_public`, public/readout boundaries, and reduction boundaries. 
+- Projection backward must differentiate the exact tensor consumed by forward projection. If a forward row projects a + normalized/public sender state, weight gradients use that normalized/public sender state, and sender-state adjoints + must flow back through the owning public-boundary/normalization backward before reaching raw recurrent state. Never + bypass the public boundary by treating raw state as the projection input. - Cells, builders, and registration must not declare physical backward executors or executor hints. - No whole-surface `MathBackend` override of backward physical operators. - Specialized recurrence backward selection may key on semantic recurrence kind, coefficient layout, activation id, reset @@ -265,17 +522,23 @@ Route work by ownership: ## Required Workflow 1. Identify the active semantic declaration and physical plan owner before editing. -2. Update the live progress doc with the current cleanup/implementation/verification order before switching tasks. -3. Rerun the relevant current-code row before relying on performance evidence; historical profiles are context only. -4. Run or cite an isolated experiment before structural performance changes. -5. Choose one accountable owner per profile cycle; make the full cohesive structural change needed for that owner. -6. If the profile regresses, keep the same owner and design target; diagnose the mechanism, iterate, and rerun or reject +2. If the change adds or changes semantics or executor strategies, start with the compiler-boundary audit and route it + through the compiler-extension workflow: + declaration/spec -> primitive rows -> tensor bindings -> legality -> registered forward/backward/reducer strategy. + Use `cb.fabric-primitive-op-onboarding` for new primitive semantics, `cb.fabric-graph-onboarding` for graph/topology + semantics, `cb.fabric-readout-rule-onboarding` for readout/output-boundary semantics, and + `cb.fabric-throughput-strategy` for optimization of already-lowered compiler products. +3. Update the live progress doc with the current cleanup/implementation/verification order before switching tasks. +4. Rerun the relevant current-code row before relying on performance evidence; historical profiles are context only. +5. Run or cite an isolated experiment before structural performance changes. +6. Choose one accountable owner per profile cycle; make the full cohesive structural change needed for that owner. +7. If the profile regresses, keep the same owner and design target; diagnose the mechanism, iterate, and rerun or reject the probe with evidence. Do not change the goal to fit the regression. -7. Add explicit demotion reasons for unsupported cases. Never silently fall back. -8. Validate source boundaries before finalizing. -9. Update the progress doc with accepted changes, rejected/aborted runs, remaining cleanup, parity gates, and next owner +8. Add explicit demotion reasons for unsupported cases. Never silently fall back. +9. Validate source boundaries before finalizing. +10. Update the progress doc with accepted changes, rejected/aborted runs, remaining cleanup, parity gates, and next owner before returning control. -10. If user feedback exposes a durable rule, update this skill during the same task when edits are allowed. +11. If user feedback exposes a durable rule, update this skill during the same task when edits are allowed. 
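+
+Steps 2 and 8 above assume a fail-closed demotion rather than a silent fallback. A minimal sketch of that shape, using
+hypothetical names (`UnsupportedDeclarationError`, `select_primitive_executor`, `op_rows`) rather than the real Fabric
+API:
+
+```python
+class UnsupportedDeclarationError(RuntimeError):
+    """Fail-closed demotion carrying an explicit, user-visible reason (illustrative only)."""
+
+    def __init__(self, surface: str, declaration: str, reason: str) -> None:
+        super().__init__(f"{surface}: {declaration!r} unsupported: {reason}")
+        self.surface = surface
+        self.declaration = declaration
+        self.reason = reason
+
+
+def select_primitive_executor(op_rows: dict[str, str], declared_op: str) -> str:
+    """Return the registered executor for a lowered op row, or fail closed.
+
+    ``op_rows`` stands in for compiler-produced primitive rows keyed by declared op;
+    there is deliberately no fallback to a legacy kernel, cell bundle, or family-name
+    branch when the row is missing.
+    """
+    executor = op_rows.get(declared_op)
+    if executor is None:
+        raise UnsupportedDeclarationError(
+            surface="primitive-op lowering",
+            declaration=declared_op,
+            reason="no registered primitive executor row for this declaration",
+        )
+    return executor
+```
+
+When a row demotes, record the reason string in the live progress doc instead of rerouting to an older path.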
## Skill Learning Loop @@ -312,4 +575,14 @@ For performance passes, use `cb.fabric-performance-loop` and report active-path ## Integration -**Pairs with:** `cb.fabric-cell-boundaries` for cell-facing boundary risks, `cb.fabric-cell-onboarding` for adding new cell families, `cb.fabric-performance-loop` for profile-driven backend work. +**Pairs with:** `cb.fabric-compiler-boundary-audit` for preflight/closeout checks, +`cb.fabric-compiler-extension` for compiler products, `cb.fabric-primitive-op-onboarding` for new +primitive ops, `cb.fabric-graph-onboarding` for graph/topology declarations, `cb.fabric-throughput-strategy` for +strategy optimization, `cb.fabric-native-strategy-onboarding` for native/fused strategy implementation, +`cb.fabric-message-rule-onboarding` for message-rule changes, +`cb.fabric-readout-rule-onboarding` for output/readout rules, `cb.fabric-cell-boundaries` for cell-facing boundary +risks, `cb.fabric-cell-onboarding` for adding new cell families, `cb.fabric-compiler-stress-test` for pre-throughput +semantic locality tests, `cb.fabric-boundary-guardrails` for source/static guardrails and deletion checks, +`cb.fabric-runtime-front-end-handoff` for public-call preparation and pre-registered-entry owners, +`cb.fabric-performance-loop` for profile-driven backend +work. diff --git a/skills/cb.fabric-boundary-guardrails/SKILL.md b/skills/cb.fabric-boundary-guardrails/SKILL.md new file mode 100644 index 00000000..8f0f3c87 --- /dev/null +++ b/skills/cb.fabric-boundary-guardrails/SKILL.md @@ -0,0 +1,105 @@ +--- +name: cb.fabric-boundary-guardrails +description: Use when adding, editing, or reviewing Cortical Fabric source guardrails, static tests, legacy-path deletion checks, fallback/compatibility bans, row/binding locality checks, or no-hidden-route assertions for compiler, throughput, cell, message-rule, readout, graph, primitive-op, or public API work. +--- + +# Fabric Boundary Guardrails + +Use this skill when a Fabric change needs tests or static checks that keep compiler boundaries from regressing. Guardrails +are not closure by themselves; they prevent known bad routes from returning after the active path is proven. + +**Announce at start:** "Using Fabric boundary guardrails. I'll add focused source checks that protect compiler products without replacing parity or runtime evidence." + +## First Read + +- `skills/cb.fabric-workflow-router/SKILL.md` +- `skills/cb.fabric-compiler-boundary-audit/SKILL.md` +- The narrow skill for the changed surface: throughput, native strategy, reducer liveness, primitive op, message rule, + cell, readout, graph, public API cleanup, or compiler stress test. 
+- Existing tests around the touched boundary: + +```bash +rg -n "fixed|compat|fallback|primitive_row|executor_row|tensor_binding|artifact_route|output_route|memory_liveness|source guard|guardrail" tests src/cortical/fabric +``` + +## Guardrail Contract + +A good guardrail protects a compiler invariant that has already failed or is easy to regress: + +- supported execution consumes compiler-owned rows, bindings, routes, liveness rows, and strategy records; +- unsupported declarations fail closed through legality before launch; +- primitive formulas stay inside declaration/reference/native executor boundaries; +- old fixed-slot, compatibility, fallback, wrapper, RuntimeError-bridge, family-selector, benchmark-selector, and + hidden-route paths stay deleted; +- throughput patches keep semantic rows stable and move only strategy/runtime/liveness owners; +- semantic patches change rows/bindings/routes or fail closed. + +Do not use guardrails as a substitute for runtime parity or performance evidence. A source check that says +`compiler_owned` exists is not proof that the active path ran through the compiler. + +## What To Assert + +Prefer checks that pair a positive compiler product with a negative stale-route ban: + +```text +positive: registered executor rows / tensor binding rows / route rows / liveness rows are consumed +negative: fixed slot enum / compatibility wrapper / family branch / fallback route is absent +``` + +For semantic work, assert the locality expectation: + +```text +old declaration -> old rows/bindings/routes +new declaration -> changed rows/bindings/routes +unsupported declaration -> typed blocker before launch +``` + +For throughput work, assert the inverse: + +```text +semantic rows stable +strategy/liveness/reducer rows or native stage changed +active owner metadata and measured owner moved +``` + +## What Not To Assert + +- Do not add broad "water is wet" checks that only prove a file contains common words such as `compiler`, `row`, or + `strategy`. +- Do not freeze incidental line structure, formatting, variable names, or comments when the invariant is structural. +- Do not whitelist one benchmark row, one hidden size, one cell family, or one population count as proof of generic + support. +- Do not require legacy names to remain just so a guardrail can find them. +- Do not accept a guardrail that only checks metadata labels while the active call path can still use a fallback. + +## Work Loop + +1. Name the boundary invariant and the old route being kept out. +2. Identify the compiler product that should own the behavior: declaration/spec, primitive rows, tensor bindings, + artifact/output routes, memory/liveness rows, reducer rows, or strategy records. +3. Add the smallest source/static test that checks that product and forbids the stale path. +4. Add or update runtime/parity/perf tests separately when behavior or throughput is affected. +5. If the guardrail exposes missing compiler products, stop and route through `cb.fabric-compiler-extension` or the + narrow semantic skill before adding more assertions. + +## Closeout + +Record in the active progress doc or final summary: + +```text +Invariant guarded: +Positive compiler product checked: +Negative stale route banned: +Runtime/parity/perf evidence paired with it: +Known limits: +``` + +## Integration + +**Uses:** `cb.fabric-workflow-router`, `cb.fabric-compiler-boundary-audit`. 
+ +**Pairs with:** `cb.fabric-throughput-strategy`, `cb.fabric-native-strategy-onboarding`, +`cb.fabric-reducer-liveness`, `cb.fabric-compiler-extension`, `cb.fabric-declaration-onboarding`, +`cb.fabric-primitive-op-onboarding`, `cb.fabric-message-rule-onboarding`, `cb.fabric-cell-onboarding`, +`cb.fabric-readout-rule-onboarding`, `cb.fabric-graph-onboarding`, `cb.fabric-public-api-cleanup`, +`cb.fabric-compiler-stress-test`, `cb.fabric-parity-gate`, and `cb.fabric-performance-loop`. diff --git a/skills/cb.fabric-cell-boundaries/SKILL.md b/skills/cb.fabric-cell-boundaries/SKILL.md index 96b6d926..9ec595c5 100644 --- a/skills/cb.fabric-cell-boundaries/SKILL.md +++ b/skills/cb.fabric-cell-boundaries/SKILL.md @@ -1,6 +1,6 @@ --- name: cb.fabric-cell-boundaries -description: Use when changing Cortical Fabric cell declarations, population config, local cell math, or code that risks mixing cell-local semantics with backend execution ownership. +description: Use when changing Cortical Fabric cell declarations, population config, local cell math, message dependencies, transition semantics, or code that risks mixing cell-local semantics with backend/compiler execution ownership. --- # Fabric Cell Boundaries @@ -13,6 +13,24 @@ Use this skill to keep Fabric cell work on the semantic side of the Fabric bound - Read `src/cortical/fabric/README.md`, especially the current cell authoring section. - Inspect the live cell exports, tests, and callsites before copying any API shape from docs or prior chats. +- If the change crosses into compiler/backend ownership, read `skills/cb.fabric-compiler-boundary-audit/SKILL.md` and + use it as the closeout checklist. +- If the work adds backend support, primitive ops, message rules, tensor bindings, native CUDA strategies, backward + executors, or reducers, read `skills/cb.fabric-compiler-extension/SKILL.md` before editing. +- If the work adds a primitive op, read `skills/cb.fabric-primitive-op-onboarding/SKILL.md`; if it optimizes an existing + compiler product, read `skills/cb.fabric-throughput-strategy/SKILL.md`. +- If the work writes CUDA/Triton/C++ native strategy bodies, fused kernels, or binding schemas, read + `skills/cb.fabric-native-strategy-onboarding/SKILL.md`. +- If the work changes message dependencies, sender/receiver roles, aggregation, distance/delay use, or message + parameters, read `skills/cb.fabric-message-rule-onboarding/SKILL.md`. +- If the work changes population graph construction, lattice/config fields, boundary/port semantics, or topology facts, + read `skills/cb.fabric-graph-onboarding/SKILL.md`. +- If the work changes public output, readout routing, output-boundary semantics, pooling, or readout reducers, read + `skills/cb.fabric-readout-rule-onboarding/SKILL.md`. +- If adding source/static guardrails for cell/message boundary leaks, stale wrappers, fixed routes, or no-hidden-fallback + checks, read `skills/cb.fabric-boundary-guardrails/SKILL.md`. +- If behavior, backend support, CUDA execution, or gradient ownership changes, read + `skills/cb.fabric-parity-gate/SKILL.md` before accepting the result. - Search relevant Codex chat history for durable lessons, not current API truth: ```bash @@ -22,6 +40,13 @@ rg -l -i "fabric|CellTransitionIR|projected_message|hardcode|fallback|rejected|p ## Boundary Rules +- Cell and message declarations use the same compiler-boundary standard: declared semantics must lower into rows, + bindings, legality, reference behavior, optional native strategies, backward/reducer coverage, and parity. 
Neither + surface is defined by a side `.cuh`, family selector, config shortcut, or temporal-engine formula. +- Fabric is a compiler for declared Fabric programs. Cell and message declarations are user-authored semantics that must + lower into executable backend primitive programs: declaration -> IR -> op/tensor rows -> parameter bindings -> + primitive executor. Do not add API surfaces that look generic while the active backend silently executes only one + hardcoded cell/message formula. - Treat `fabric/cells/.py` as cell-local schema, population module construction, local transition contract, and registration only. - Treat `fabric/backend/pytorch/cells/.py` and optional `fabric/backend/cuda/cells/.cuh` as backend-specific local cell math only. - Do not add cell-specific executor kernels, ownership modes, planner rules, graph policy, workspace policy, message aggregation, readout, batching, replication, reuse policy, or whole-sequence orchestration as part of adding a cell. @@ -46,6 +71,20 @@ rg -l -i "fabric|CellTransitionIR|projected_message|hardcode|fallback|rejected|p - If a probe regresses, do not move the cell/backend boundary or change the semantic goal to make it pass. Diagnose within the intended contract, iterate, or reject the implementation while preserving the design target. - Optimize for the correct architecture, not for the smallest diff. If the clean fix requires deleting a bridge, replacing an ownership model, or moving work into backend-owned family modules, do that cohesive refactor while keeping the validation tied to one accountable owner. - Do not add reference fallbacks, compatibility shims, or "temporary" paths to hide failed invariants. Fix the invariant or fail closed. +- If a cell or message declaration needs an op that is missing from `fabric.cuda.nn`, add the missing op to + `fabric.cuda.nn` and carry it through IR, primitive rows, tensor roles, and executors. Do not bypass the design with + private kernels, direct backend helpers, cell-specific bundles, or metadata-only wrappers. +- New cell-local math becomes executable only through the compiler chain. Do not make a new cell CUDA-supported by adding + a cell-name selector, private CUDA helper, or temporal-engine branch; add declaration/spec metadata, primitive rows, + tensor bindings, legality, reference execution, native strategy coverage, backward/tape/reducer ownership, and parity. +- If only one cell or message primitive is currently supported, keep that support explicit and fail unsupported + declarations closed. Do not make broad `fabric.cuda.nn` declarations, configs, or blueprints that secretly collapse + every program into a fixed dot-product message rule, fixed recurrent cell, fixed projection stack, or fixed readout. +- A cell declaration may use a first-class composite recurrence primitive when that is the actual supported compiler + granularity. Keep the declaration style concise: separate projections, recurrent affines, normalization, public + emission, and state emission as declared ops, and make only the recurrence formula a composite such as + `gated_logspace_recurrence` or `diag_rtu`. Composite options such as activation kind, trace policy, and emitted + artifacts must be explicit op inputs/attributes, not hidden cell-family side channels. 
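+
+As a hedged sketch of the intended declaration granularity (plain data, with illustrative op and tensor names rather
+than the real Fabric declaration API): projections, normalization, and emission stay separate declared ops, and only
+the recurrence formula is a composite whose options are explicit attributes:
+
+```python
+# Hypothetical transition program; field and parameter names are placeholders for illustration only.
+transition_program = [
+    {"op": "linear", "in": ["projected_message"], "params": ["input_proj_weight"], "out": "u"},
+    {"op": "linear", "in": ["h_prev"], "params": ["recurrent_weight"], "out": "r"},
+    {
+        "op": "gated_logspace_recurrence",  # the only composite primitive
+        "in": ["u", "r", "h_prev"],
+        "attrs": {"activation": "tanh", "trace_policy": "none"},  # explicit attributes, not a cell-family side channel
+        "out": "h",
+    },
+    {"op": "norm_or_identity", "in": ["h"], "params": ["outnorm_weight", "outnorm_eps"], "out": "public_y"},
+]
+```
+
+Anything the composite needs beyond its declared inputs and attributes is a sign the formula is leaking into a hidden
+route rather than staying a declared op.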
## Skill Learning Loop @@ -70,4 +109,11 @@ rg --files tests | rg "test_fabric" ## Integration -**Pairs with:** `cb.fabric-cell-onboarding` for adding a new cell family, `cb.fabric-backend-boundaries` for backend execution changes, `cb.fabric-performance-loop` for profile-driven changes. +**Pairs with:** `cb.fabric-compiler-boundary-audit` for preflight/closeout checks, +`cb.fabric-cell-onboarding` for adding a new cell family, `cb.fabric-message-rule-onboarding` for +message-rule changes, `cb.fabric-primitive-op-onboarding` for new primitive ops, `cb.fabric-graph-onboarding` for +graph/topology declarations, `cb.fabric-throughput-strategy` for strategy optimization, `cb.fabric-compiler-extension` +for compiler products, `cb.fabric-readout-rule-onboarding` for output/readout rules, +`cb.fabric-boundary-guardrails` for source/static guardrails and deleted cell/message-route checks, +`cb.fabric-parity-gate` for behavior/backend/gradient parity, +`cb.fabric-backend-boundaries` for backend execution changes, `cb.fabric-performance-loop` for profile-driven changes. diff --git a/skills/cb.fabric-cell-onboarding/SKILL.md b/skills/cb.fabric-cell-onboarding/SKILL.md index 6cc56394..ce8bd562 100644 --- a/skills/cb.fabric-cell-onboarding/SKILL.md +++ b/skills/cb.fabric-cell-onboarding/SKILL.md @@ -1,17 +1,20 @@ --- name: cb.fabric-cell-onboarding -description: Use when adding a new Cortical Fabric cell family, porting a recurrent cell into Fabric, or wiring a new Fabric cell through registration, reference parity, and tests. +description: Use when adding a new Cortical Fabric cell family, porting a recurrent cell into Fabric, or wiring a new Fabric cell through declaration/spec registration, compiler primitive lowering, reference parity, CUDA strategy coverage, and tests. --- # Fabric Cell Onboarding Use this skill for the concrete workflow of adding a new Fabric cell. For boundary questions, use `cb.fabric-cell-boundaries`. -**Announce at start:** "Using Fabric cell onboarding. I'll inspect the live cell API, copy the current registration pattern, and keep execution policy out of the new cell." +**Announce at start:** "Using Fabric cell onboarding. I'll inspect the live cell API, extend the semantic registration path, and keep execution policy out of the new cell." ## First Read - Read `src/cortical/fabric/README.md`, then inspect current cell implementations and tests. Do not copy API shapes from stale docs or chats. +- Read `cb.fabric-compiler-boundary-audit` before adding backend/compiler support for the cell. +- Read `cb.fabric-declaration-onboarding` when adding or changing public cell declaration semantics. +- Read `cb.fabric-compiler-stress-test` when changing cell recurrence math as a pre-throughput compiler locality test. ```bash rg --files src/cortical/fabric | rg "cells|backend/.*/cells|registry|runtime" @@ -19,15 +22,90 @@ rg -n "CellTransitionIR|cell_transition|register|population|public|private_state ``` - Read `cb.fabric-cell-boundaries` before editing if the change touches CUDA declarations, local math, resets, public output, or backend-facing state. +- Read `cb.fabric-compiler-extension` before adding a new transition primitive, message dependency, tensor binding, + native CUDA strategy, backward executor, or parameter reducer for the cell. +- Read `cb.fabric-primitive-op-onboarding` before adding a new transition primitive or operator used by the cell. +- Read `cb.fabric-throughput-strategy` before adding or tuning a CUDA/Triton strategy for existing cell primitive rows. 
+- Read `cb.fabric-native-strategy-onboarding` before writing CUDA/Triton/C++ native strategy bodies, fused kernels, or + binding schemas for the cell's existing primitive rows. +- Read `cb.fabric-graph-onboarding` if the cell work also changes graph/population construction, boundary/port sets, or + topology facts. +- Read `cb.fabric-readout-rule-onboarding` if the cell work changes public output semantics, output-boundary behavior, + readout routing, pooling, or readout parameter reducers. +- Read `cb.fabric-boundary-guardrails` when adding source/static checks for cell row delta, no cell-family backend + selectors, deleted wrappers, or no hidden fallback. +- Read `cb.fabric-parity-gate` before accepting CUDA/backend support for the cell. Cell onboarding must prove outputs, + exposed state, input/carry gradients, state gradients when exposed, and parameter gradients against the PyTorch + Fabric reference before performance claims. ## Onboarding Steps -1. Identify the live cell pattern to extend: schema/config, population construction, parameter names/shapes, registration/export, PyTorch reference path, CUDA path, and tests. +1. Identify the live cell pattern to extend: declaration/spec, population construction, parameter names/shapes, + registration/export, PyTorch reference path, primitive rows, tensor bindings, CUDA strategy path, reducers, and tests. 2. Add only cell-owned surfaces: cell schema, parameter materialization, local recurrent equation, local backward/recompute recipe if present, and native `fabric.cuda.nn` semantic declarations. 3. Declare meaning, not execution. The cell may declare private/public state, parameters, affine/message inputs, reset policy, reduction boundaries, diagonal recurrence kind, and emission contract. It must not choose schedules, launch shapes, GEMM family, routing ownership, graph policy, workspace policy, readout, batching, or fallback behavior. -4. Keep backend contracts canonical: `projected_message`, state-affine output, raw public/public projection, readout boundary, reset semantics, dtype/layout constraints, and shape metadata must match the existing engine expectations. -5. Add targeted tests by discovering the current suite: construction/import, parameter shape, PyTorch reference behavior, CUDA parity when available, reset behavior, shape smoke, and serialization/config coverage if the live code supports it. -6. If onboarding exposes missing generic backend support, stop treating it as a cell task. Route that work through `cb.fabric-backend-boundaries` and keep unsupported cases explicit. +4. Lower the declaration through the compiler chain: backend spec -> primitive rows -> tensor roles -> parameter + bindings -> verifier/legality -> registered reference executor -> optional native/fused CUDA strategy. A cell is not + CUDA-supported just because a cell name was registered. New primitive semantics use `cb.fabric-primitive-op-onboarding`; + performance-only strategy work uses `cb.fabric-throughput-strategy`; native kernel work also uses + `cb.fabric-native-strategy-onboarding`. +5. Keep backend contracts canonical: `projected_message`, state-affine output, raw public/public projection, readout boundary, reset semantics, dtype/layout constraints, and shape metadata must match the existing engine expectations. +6. 
Add targeted tests by discovering the current suite: construction/import, parameter shape, lowering rows, typed + fail-closed unsupported cases, PyTorch reference behavior, CUDA parity when available, reset behavior, state/carry + gradients, parameter reducers, shape smoke, and serialization/config coverage if the live code supports it. Use + `cb.fabric-boundary-guardrails` for source/static checks that ban cell-family backend selectors and stale wrappers. +7. If onboarding exposes missing generic backend support, stop treating it as a cell task. Route that work through + `cb.fabric-compiler-extension` and `cb.fabric-backend-boundaries`, and keep unsupported cases explicit. + +Before editing backend/CUDA code for the cell, record: + +```text +Cell declaration/spec: +Transition primitive rows: +Tensor/state/parameter bindings: +Reference executor: +Native strategy, if any: +Backward/tape/reducer contracts: +Unsupported typed blockers: +Files that should not change if the boundary holds: +``` + +If this cannot be filled in, the cell is not ready for CUDA/throughput work. + +## Unified Declaration Standard + +Cells and message rules must feel different only where their semantics are different. Both follow: + +```text +public declaration -> normalized spec -> primitive rows -> tensor bindings -> registered executors +``` + +A cell may own state schema, parameters, transition math, public emission, and reset semantics. It may not own graph +construction, message aggregation, readout routing, workspace/liveness, time scheduling, or benchmark policy. If adding a +cell requires a new message rule or primitive op, add that missing compiler product through the message-rule or +primitive-op skill; do not hide it in the cell implementation. + +The cell registry is a semantic registry, not an executor selector. Adding a cell should register its declaration/spec +and transition primitive program, then rely on compiler lowering and strategy legality. Do not make CUDA support by +adding a cell-name branch, private kernel wrapper, or parameter-bundle shortcut. + +Use the same standard for message rules: both cells and messages are user-declared semantics that lower into rows and +bindings; both may have reference and native strategies; neither is defined by a standalone `.cuh` implementation. + +For a cell formula change, record the same row-delta packet used for message rules: + +```text +Old cell declaration/spec: +New cell declaration/spec: +Transition primitive rows changed: +State/tensor/parameter roles changed: +Backward/tape/reducer contract changed: +Reference parity owner: +Unsupported typed blocker: +``` + +If a cell change requires a new message rule, readout rule, graph fact, or primitive op, add that product through the +matching skill first. Do not hide the missing product in a cell-family branch or native bundle. ## Closeout @@ -39,4 +117,15 @@ rg -n "CellTransitionIR|cell_transition|register|population|public|private_state **Uses:** `cb.fabric-cell-boundaries` for semantic/backend ownership rules. -**Pairs with:** `cb.fabric-backend-boundaries` when adding the cell requires engine support, `cb.fabric-performance-loop` when the new cell changes active-path performance. 
+**Pairs with:** `cb.fabric-declaration-onboarding` for public semantic changes, +`cb.fabric-compiler-boundary-audit` for preflight/closeout checks, +`cb.fabric-compiler-extension` when the cell needs compiler products, +`cb.fabric-primitive-op-onboarding` when it needs new primitive ops, `cb.fabric-graph-onboarding` when it changes graph +or topology declarations, `cb.fabric-throughput-strategy` when it needs CUDA strategy work, +`cb.fabric-native-strategy-onboarding` when that strategy enters native kernels or binding schemas, +`cb.fabric-readout-rule-onboarding` when it changes output/readout semantics, +`cb.fabric-compiler-stress-test` when recurrence math changes are used as compiler locality tests, +`cb.fabric-boundary-guardrails` for source/static guardrails and deleted cell-route checks, +`cb.fabric-parity-gate` for output/state/input/carry/parameter-gradient parity, +`cb.fabric-backend-boundaries` when adding the cell requires engine support, `cb.fabric-performance-loop` when the new +cell changes active-path performance. diff --git a/skills/cb.fabric-compiler-boundary-audit/SKILL.md b/skills/cb.fabric-compiler-boundary-audit/SKILL.md new file mode 100644 index 00000000..03b473e8 --- /dev/null +++ b/skills/cb.fabric-compiler-boundary-audit/SKILL.md @@ -0,0 +1,217 @@ +--- +name: cb.fabric-compiler-boundary-audit +description: Use when reviewing, planning, or implementing Cortical Fabric compiler, throughput, cell, message-rule, readout-rule, primitive-op, graph, CUDA backend, or skill changes to verify they preserve compiler ownership before code, performance claims, or closure. +--- + +# Fabric Compiler Boundary Audit + +Use this skill as a compact preflight and closeout review for Fabric work. It does not replace the narrower skills; it +forces the agent to prove the work is routed through the compiler boundary before editing or accepting results. + +**Announce at start:** "Using Fabric compiler-boundary audit. I'll prove the change is declaration/row/binding owned before implementation or closure." 
+ +## First Read + +- `skills/cb.fabric-backend-boundaries/SKILL.md` +- `skills/cb.fabric-compiler-extension/SKILL.md` +- The narrow skill for the surface being changed: + - public semantic declaration: `skills/cb.fabric-declaration-onboarding/SKILL.md` + - compiler reality/stress test for changed semantics before throughput: + `skills/cb.fabric-compiler-stress-test/SKILL.md` + - throughput strategy: `skills/cb.fabric-throughput-strategy/SKILL.md` + - native strategy/kernel implementation: `skills/cb.fabric-native-strategy-onboarding/SKILL.md` + - primitive op: `skills/cb.fabric-primitive-op-onboarding/SKILL.md` + - message rule: `skills/cb.fabric-message-rule-onboarding/SKILL.md` + - cell: `skills/cb.fabric-cell-onboarding/SKILL.md` + - readout rule: `skills/cb.fabric-readout-rule-onboarding/SKILL.md` + - public API/config/blueprint cleanup: `skills/cb.fabric-public-api-cleanup/SKILL.md` + - runtime/model front-end handoff before registered backend entry: + `skills/cb.fabric-runtime-front-end-handoff/SKILL.md` + - graph/topology: `skills/cb.fabric-graph-onboarding/SKILL.md` + - parity/performance evidence: `skills/cb.fabric-parity-gate/SKILL.md` and + `skills/cb.fabric-performance-loop/SKILL.md` + - source/static guardrails and legacy deletion checks: `skills/cb.fabric-boundary-guardrails/SKILL.md` + - Fabric skill changes: `skills/cb.fabric-skill-maintenance/SKILL.md` + +Inspect the active code, not only docs: + +```bash +rg -n "primitive_row|executor_row|tensor_binding|program_access|artifact_route|output_route|memory_liveness|native_callable|missing_executor|fixed|compat|fallback|python_autograd_scan" src/cortical/fabric tests benchmarks +``` + +## Work Classifier + +Classify the request before planning or editing: + +- **Semantic extension:** user-visible graph, message, cell, readout, primitive-op, tensor role, reset, output, or + formula meaning changes. Use `cb.fabric-declaration-onboarding`, `cb.fabric-compiler-extension`, and the narrow + onboarding skill. Throughput work stops until declaration/spec, rows, bindings, reference behavior, legality, + backward/reducer coverage, and fail-closed blockers exist. +- **Compiler stress test:** deliberate semantic formula/tensor-role change before throughput to prove locality, such as + changed dot-product/message math, normalization, gating, recurrence, readout, or reducer behavior. Use + `cb.fabric-compiler-stress-test`; rows/bindings/routes must change or the declaration must fail closed. Do not + optimize during the stress pass. +- **Throughput strategy:** primitive/graph rows and semantics stay stable; only registered strategy, launch, memory, + artifact, reducer, liveness, or cost policy changes. Use `cb.fabric-throughput-strategy`. +- **Native strategy implementation:** existing compiler rows stay stable while CUDA/Triton/C++ kernels, native callable + schemas, fused program kernels, or pybind ABIs change. Use `cb.fabric-native-strategy-onboarding` after confirming the + row/binding products already exist. +- **Evidence/parity/perf:** no code semantics change; measure the active high-level path and prove owner metadata, + parity, and current-code performance. Use `cb.fabric-performance-loop` and `cb.fabric-parity-gate`. +- **Hypothesis probe:** a small, time-bounded experiment, including synthetic fast experiments, that decides whether a + throughput or liveness direction is worth continuing. 
It must use or faithfully isolate the active compiler-owned + route, record the expected owner movement and keep/narrow/revert rule, and remain steering evidence only until a + representative warmed row passes. +- **Cleanup/deletion:** remove stale routes only after the compiler-owned replacement is active, covered by tests, and + unsupported cases fail closed. +- **Guardrail:** source/static checks that preserve a compiler invariant after a replacement is active. Use + `cb.fabric-boundary-guardrails`; pair positive compiler-product checks with negative stale-route bans. +- **Public API cleanup:** old `Config`, `Blueprint`, compatibility constructors, and declaration normalization change. + Use `cb.fabric-public-api-cleanup`; every field must move to an explicit graph/cell/message/readout/reset/planner + owner or be deleted. +- **Runtime front-end handoff:** public `model(...)` call preparation, input/output adapters, boundary tensors, state + initialization, reset normalization, sender K/V setup, or the call into the registered temporal program changes. Use + `cb.fabric-runtime-front-end-handoff`; if semantics are missing, reclassify as declaration/compiler extension, and if + the owner is throughput-only, keep rows stable and move work behind compiler liveness/strategy rows. + +If a throughput task discovers a missing primitive, tensor role, graph fact, message/readout declaration, or backward +contract, reclassify the work as a semantic extension and start at `cb.fabric-declaration-onboarding`. Do not add that +missing meaning inside a faster kernel. + +For ambiguous Fabric requests or the user's repeated analyze -> plan -> proceed workflow, use +`cb.fabric-workflow-router` first, then return here for the boundary proof. + +## Universal Surface Contract + +Cells, message rules, readout rules, primitive ops, graph constructors, and public declaration cleanup all use the same +compiler boundary. The surface-specific skill may add details, but the required proof is identical: + +```text +public declaration/spec + -> normalized semantic record + -> graph or primitive rows + -> tensor/parameter/artifact/output/reset/liveness bindings + -> verifier legality and typed blockers + -> reference executor + -> optional registered native/fused strategy + -> backward/tape/recompute/reducer coverage + -> parity, active-route metadata, and focused source guardrails +``` + +Adding a new primitive op must require only registry/lowering metadata, a reference executor, optional fused/native +strategy, reducer/liveness rows when needed, and tests. It must not require editing temporal scheduler ownership, fixed +tensor slot enums, cell-family route selectors, graph-constructor selectors, public config defaults, benchmark policy, or +monolithic scan/reverse ABIs. + +For a healthy compiler, the same review question should work for every surface: "If this declaration changed, which rows, +bindings, routes, typed blockers, reference behavior, and optional strategies changed?" If the answer is "none, but a +kernel helper changed," the compiler boundary failed. 
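+
+A minimal sketch of the fail-closed legality step in that chain (names are hypothetical, not the real Fabric verifier
+API): an unsupported declaration should surface a typed blocker before any launch, never a silent fallback:
+
+```python
+class UnsupportedDeclaration(Exception):
+    """Typed blocker raised by legality checks before launch; never a silent fallback."""
+
+    def __init__(self, surface: str, reason: str, blocker_code: str) -> None:
+        super().__init__(f"{surface}: {reason} [{blocker_code}]")
+        self.surface = surface            # e.g. "message_rule"
+        self.reason = reason              # human-readable rejection reason
+        self.blocker_code = blocker_code  # stable id that audits and tests can assert on
+
+
+SUPPORTED_DTYPES = {"float16", "bfloat16", "float32"}
+
+
+def check_legality(surface: str, dtype: str) -> None:
+    """Reject unsupported declarations with a typed blocker instead of routing to a hidden path."""
+    if dtype not in SUPPORTED_DTYPES:
+        raise UnsupportedDeclaration(
+            surface=surface,
+            reason=f"dtype {dtype} has no registered strategy",
+            blocker_code="missing_executor.dtype",
+        )
+```
+
+The review question stays the same: the blocker must come from the verifier/legality product, not from a wrapper that
+quietly falls back to a legacy route.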
+ +## Pre-Implementation Gate + +Before editing, write down the intended compiler products: + +```text +Surface: +Lane: semantic extension | throughput strategy | native implementation | evidence | cleanup +Declaration/spec owner: +Primitive or graph rows: +Tensor/parameter/artifact/output bindings: +Legality/verifier blockers: +Forward owner: +Backward/reducer owner: +Memory/liveness owner: +Old route to delete or fail-close: +Progress doc: +Guardrail invariant: +``` + +If any required row/binding/legality product does not exist, stop. Use the compiler-extension or onboarding skill to add +semantics first. Do not compensate by editing temporal scheduler code, fixed slots, benchmarks, config, private helpers, +or legacy wrappers. + +For throughput/native work, also record the row fingerprint or row group that must remain semantically stable. A +throughput patch whose primitive rows, tensor roles, declaration attributes, parameter bindings, or gradient contracts +change is misclassified semantic work and must be replanned. + +For hypothesis probes, record the smallest row, synthetic fixture, or telemetry check, why it exercises or isolates the +same compiler route/owner as the target, and the stop condition. If the probe requires benchmark-owned scheduling, +row-specific selectors, hidden fallback, or semantic changes, reject the probe before running it. If the probe bypasses +the high-level `model(...)` path, label it mechanism-only and require a high-level representative row before acceptance. + +For semantic work, record the expected row/binding/route delta before implementation. If changing a message rule, cell, +readout, graph, or primitive op leaves the compiled program unchanged, the compiler boundary failed or the declaration +is unsupported and must fail closed. + +For new cells, message rules, readout rules, graph constructors, and primitive ops, record the locality expectation: +which declaration/spec, lowering, reference, native strategy, reducer/liveness, and test files may change, and which +temporal scheduler, fixed-slot, benchmark, config, and family-selector files should not change. If the expected edit set +requires scheduler-owned formulas or route selectors, stop and close that compiler gap first. + +For public API cleanup, also name the old field/path being removed, the new declaration/spec owner, and the source +guardrail that prevents backend execution from reading the old broad config surface again. + +## Boundary Questions + +Answer these before accepting a patch: + +- Did a user-visible declaration change, or is this only an implementation strategy? +- If semantics changed, did the compiled primitive/graph rows, tensor roles, attributes, parameter bindings, or legality + metadata change? +- If only throughput changed, did the primitive rows stay stable while strategy/runtime/memory rows changed? +- Does execution consume compiler-owned rows/bindings/routes/liveness directly? +- Are cell, message, readout, projection, normalization, recurrence, attention, activation, and reducer formulas confined + to primitive/reference/native strategy implementations? +- Are reset, checkpoint, recompute, materialization, aliasing, output routing, and artifact routing compiler/planner + products instead of benchmark or temporal-kernel side channels? +- Are unsupported cases typed fail-closed before launch? +- Was the replaced legacy route deleted or made unreachable for supported rows? 
+- Does each new source/static guardrail check a real compiler product and ban a concrete stale route, rather than only + matching a metadata label? + +## Hard Rejects + +Reject and record the probe if it: + +- copies or wraps old fixed-slot kernels instead of consuming current compiler rows; +- adds Q/K/V, gated, diagonal, layernorm, readout, projection, or message formulas to temporal scheduler code; +- keys selection on cell family, message-rule name, readout name, graph constructor, hidden-size constant, benchmark id, + single/mixed-pop label, or old tensor names; +- moves policy into benchmarks/config instead of graph/message/cell/readout declarations and planner rows; +- claims performance from metadata while the active owner, launch count, peak memory, or kernel path did not move; +- passes parity through a hidden fallback/replay/compat path. + +## Closeout Gate + +Before finalizing: + +1. Run the narrow parity/perf/source tests required by the surface skill. +2. Prove active-route ownership, not just metadata. Runtime/audit output should name the registered compiler owner, + row/binding/route fingerprint, and absence of hidden fallback/replay/compat paths for supported rows. +3. Grep for new boundary leaks in touched areas: + +```bash +rg -n "native_cell_kind|family_type|hidden.?==|benchmark|fallback|compat|fixed_slot|python_autograd_scan|qkv|gated|diagonal|readout" src/cortical/fabric tests benchmarks +``` + +4. If source guardrails changed, run `cb.fabric-boundary-guardrails` and record the invariant, positive compiler product, + negative stale route, and paired runtime/parity/perf evidence. +5. Update the active progress doc with: + - accepted compiler products; + - commands and artifacts; + - rejected shortcuts; + - remaining open owners; + - what did not need to change because the boundary held. + +If an optimization only changed labels/docs or added a wrapper while active owner time, launch path, memory, or compiled +row evidence did not move, mark it rejected and keep the owner open. + +## Integration + +**Uses:** `cb.fabric-backend-boundaries`, `cb.fabric-compiler-extension`. + +**Pairs with:** `cb.fabric-workflow-router`, `cb.fabric-declaration-onboarding`, `cb.fabric-throughput-strategy`, +`cb.fabric-native-strategy-onboarding`, `cb.fabric-primitive-op-onboarding`, `cb.fabric-message-rule-onboarding`, +`cb.fabric-readout-rule-onboarding`, `cb.fabric-cell-onboarding`, `cb.fabric-graph-onboarding`, +`cb.fabric-public-api-cleanup`, `cb.fabric-runtime-front-end-handoff`, `cb.fabric-compiler-stress-test`, +`cb.fabric-boundary-guardrails`, `cb.fabric-parity-gate`, and `cb.fabric-performance-loop`. diff --git a/skills/cb.fabric-compiler-extension/SKILL.md b/skills/cb.fabric-compiler-extension/SKILL.md new file mode 100644 index 00000000..688e1256 --- /dev/null +++ b/skills/cb.fabric-compiler-extension/SKILL.md @@ -0,0 +1,276 @@ +--- +name: cb.fabric-compiler-extension +description: "Use when adding or modifying Cortical Fabric compiler products: graph/topology declarations, primitive ops, message rules, readout rules, transition primitives, native CUDA/Triton strategies, tensor bindings, parameter reducers, compiler rows, executor legality, memory/artifact plans, or throughput executors that must preserve the Fabric compiler boundary." +--- + +# Fabric Compiler Extension + +Use this skill when adding new Fabric semantics or new executor strategies. 
It is the concrete workflow for keeping +throughput work, new cells, message rules, readout rules, primitive ops, and CUDA kernels inside the compiler contract. + +**Announce at start:** "Using Fabric compiler extension. I'll add semantics through declaration/IR/primitive rows and add execution through registered strategies, not scheduler-owned formulas." + +## First Read + +- `skills/cb.fabric-compiler-boundary-audit/SKILL.md` for the preflight/closeout boundary checklist. +- `skills/cb.fabric-backend-boundaries/SKILL.md` +- For any public semantic declaration change, first read `skills/cb.fabric-declaration-onboarding/SKILL.md`. +- For a deliberate pre-throughput semantic stress test, formula perturbation, or proof that adding new math is local, + also read `skills/cb.fabric-compiler-stress-test/SKILL.md`. +- For a new or changed primitive op, also read `skills/cb.fabric-primitive-op-onboarding/SKILL.md`. +- For a new or changed throughput strategy, also read `skills/cb.fabric-throughput-strategy/SKILL.md`. +- For CUDA/Triton/C++ native strategy or fused program kernel implementation, also read + `skills/cb.fabric-native-strategy-onboarding/SKILL.md`. +- For parameter reducers, reverse span outputs, runtime-buffer lifetimes, workspace reuse, or reducer-consumed native + outputs, also read `skills/cb.fabric-reducer-liveness/SKILL.md`. +- For a message-rule change, also read `skills/cb.fabric-message-rule-onboarding/SKILL.md`. +- For a readout-rule or output-boundary change, also read `skills/cb.fabric-readout-rule-onboarding/SKILL.md`. +- For a cell-family change, also read `skills/cb.fabric-cell-onboarding/SKILL.md`. +- For a graph/topology declaration or constructor change, also read `skills/cb.fabric-graph-onboarding/SKILL.md`. +- For public API, Config, Blueprint, or declaration-normalization cleanup, also read + `skills/cb.fabric-public-api-cleanup/SKILL.md`. +- For source/static guardrails, legacy deletion checks, or no-fallback/no-compat checks, also read + `skills/cb.fabric-boundary-guardrails/SKILL.md`. +- Relevant live declarations/specs: + +```bash +rg --files src/cortical/fabric | rg "message_rule|cell_spec|cells|cuda/nn|transition_execution|sequence_surface/compiler|registered_program" +rg -n "MessageRuleBackendSpec|CellBackendSpec|Primitive|executor_row|tensor_binding|native_callable|parameter_reducer|missing_executor" src/cortical/fabric tests +``` + +Use docs and chat history only for intent. Verify current behavior in code and tests. + +## Skill Routing + +- If the request is ambiguous or arrives as the user's analyze -> plan -> proceed loop, first use + `cb.fabric-workflow-router` to classify the phase and the narrow skill set. +- Any user-visible graph/cell/message/readout/primitive/reset/init formula or declaration change: + `cb.fabric-declaration-onboarding` first, then the narrow semantic skill. +- Pre-throughput compiler reality test or semantic formula perturbation: + `cb.fabric-compiler-stress-test` first, then declaration onboarding and the narrow semantic skill. Do not optimize + throughput during the stress pass. +- New primitive op or formula change: `cb.fabric-primitive-op-onboarding`. +- New throughput executor, memory/liveness policy, artifact policy, or fused kernel: `cb.fabric-throughput-strategy`. +- Native CUDA/Triton/C++ strategy body or fused program kernel: `cb.fabric-native-strategy-onboarding`. +- Parameter reducer, reverse span output, runtime-buffer lifetime, or workspace ownership: + `cb.fabric-reducer-liveness`. 
+- New message rule or message formula change: `cb.fabric-message-rule-onboarding`. +- New readout rule, output route, output-boundary, pooling, or readout formula change: + `cb.fabric-readout-rule-onboarding`. +- New cell family or transition declaration: `cb.fabric-cell-onboarding` plus `cb.fabric-cell-boundaries`. +- New graph constructor, topology field, edge fact, boundary/port rule, or lattice/config cleanup: `cb.fabric-graph-onboarding`. +- Public API, Config, Blueprint, or declaration-normalization cleanup: `cb.fabric-public-api-cleanup`. +- Parity evidence for any of the above: `cb.fabric-parity-gate`. +- Current-code profiling or throughput evidence: `cb.fabric-performance-loop`. +- Source/static guardrails or old-route deletion checks: `cb.fabric-boundary-guardrails`. + +Routing is exclusive at the point of implementation: + +- If public meaning, primitive rows, tensor roles, parameter bindings, or backward/reducer semantics change, do semantic + compiler-extension work first. Do not optimize in the same step unless the user explicitly asks after the semantic + path is proven. +- If only the implementation strategy changes, do throughput-strategy work and keep primitive rows stable. +- If a performance owner cannot be improved without inventing a missing role/op/contract, stop the throughput patch and + add the compiler product here first. + +## One Compiler Extension Shape + +Cells, message rules, readout rules, graph/topology declarations, primitive ops, and throughput strategies must use the +same extension shape. For user-visible semantic changes, `cb.fabric-declaration-onboarding` is the common entry point +before the surface-specific skill: + +```text +registered declaration/spec + -> normalized semantic record + -> compiler rows/bindings/routes/liveness + -> verifier and typed blockers + -> reference executor + -> optional native/fused strategy + -> parity/perf evidence +``` + +Do not let one surface be "real compiler" while another is a side file, hardcoded `.cuh`, old `Config` field, route +selector, or benchmark-owned helper. A `.cuh` file, pybind symbol, CUDA kernel, or Python helper can implement a +strategy, but it is never the semantic declaration. + +## Core Rule + +Every supported extension must follow this chain: + +```text +public declaration + -> semantic IR / backend spec + -> primitive op rows + -> tensor role and parameter binding rows + -> verifier / legality / typed blocker metadata + -> registered reference executor + -> optional registered native/fused CUDA strategy + -> forward + backward + reducer coverage + -> parity and source guardrails +``` + +Do not add an extension by editing temporal scheduler math, fixed tensor-slot enums, cell-family route selectors, +benchmark hooks, private runtime helpers, direct pybind wrappers, or metadata-only facades. + +## Boundary Acceptance Gate + +Before editing implementation code, write down the compiler products that will change. If any item is missing, the work +is not ready for throughput or CUDA optimization. + +| Product | Required answer | +| --- | --- | +| Declaration/spec | Which public graph, message, cell, readout, or op field carries the semantics? | +| Primitive rows | Which opcode/row group represents the semantics structurally? | +| Tensor roles | Which inputs, outputs, state, artifacts, and parameters are bound by role? | +| Legality | Which shapes, layouts, dtypes, resets, tape modes, and memory plans are supported or rejected? 
| +| Execution | Which reference executor and optional native strategy implement the same rows? | +| Backward/reducer | Which gradient inputs, saved/recomputed tensors, state grads, and parameter reducers are owned? | +| Audit | Which metadata proves the compiled rows ran, and which typed blockers explain rejected cases? | + +Adding a new primitive, message rule, cell transition, readout, graph behavior, or throughput strategy should require +only declaration/spec metadata, lowering, registry/native strategy code, verifier rules, parity tests, and source +guardrails. It must not require editing temporal scheduler ownership, fixed tensor-slot enums, cell-family route +selectors, graph-name selectors, or monolithic scan/reverse ABIs. If it does, reopen compiler closure before optimizing. + +Before implementing, perform the locality test: describe the smallest files that should change if the compiler boundary +is real. If the planned change needs edits in unrelated temporal scheduler loops, legacy fixed-slot tables, benchmark +harnesses, or family-specific dispatch, stop and close that compiler gap first. + +Use `cb.fabric-compiler-stress-test` for semantic formula changes used as a pre-throughput stress test. A changed +dot-product/message rule, normalization/gating term, cell recurrence, readout, or reducer should be introduced through +declaration/spec and primitive lowering first. +If adding that math requires editing temporal scheduler ownership or fixed ABI slots, the compiler path is not closed +enough for throughput work. + +For every semantic extension, add a row-delta assertion somewhere in tests or review notes: + +```text +Old declaration -> old rows/bindings/routes +New declaration -> changed rows/bindings/routes +Unsupported declaration -> typed blocker before launch +``` + +For every throughput-only extension, add the inverse assertion: declarations and primitive rows stay stable while +strategy/runtime/liveness/reducer rows change. + +## Extension Checklist + +### 1. Classify the work + +- **Primitive op:** universal semantic unit such as `linear`, `matmul`, `norm_or_identity`, activation, segment op, or + a deliberate composite recurrence. Use `cb.fabric-primitive-op-onboarding` for the detailed checklist. +- **Graph/topology behavior:** user-declared graph semantics and edge facts. Use `cb.fabric-graph-onboarding`. + Lattice2D or any other constructor is only one producer of graph facts; backend code consumes lowered graph rows, not + graph-constructor identities. +- **Message rule:** public communication semantics over graph-local sender/receiver/edge facts. +- **Cell/transition:** population-local state schema, parameters, transition primitive program, reset/tape contract. +- **Readout:** output-boundary semantics and output route/merge contract. +- **Throughput strategy:** replaceable implementation of an already verified compiler product. Use + `cb.fabric-throughput-strategy` for the detailed checklist. + +If the change is a throughput strategy, semantics must already be represented by primitive rows. Add no new math to the +temporal engine. + +### 2. Add semantics before execution + +- Add or update the public declaration and normalized spec. +- Lower declaration fields into IR/backend spec fields; do not discard fields into a default rule. +- Emit structural primitive rows with stable opcode/executor identities. +- Emit tensor roles, parameter bindings, output routes, artifact roles, reset/tape policy, and schema versions. 
+- Add verifier/legality checks and typed rejection reasons for unsupported shape/layout/dtype/reset/tape cases. +- Register reference execution before native execution. A fused CUDA/Triton/native strategy is valid only as an + implementation of already-registered semantics and rows. + +Changing a declaration must either change the compiled primitive program or fail closed before launch. + +### 3. Add execution as a strategy + +- Register a reference executor or reference path used for truth/parity. +- Register forward native/fused executor metadata only after legality is explicit. +- Register backward/tape/recompute metadata with required saved tensors, recomputable tensors, state gradients, + parameter-gradient outputs, accumulation semantics, determinism/tolerance class, and failure modes. +- Register parameter reducer rows from binding-owned executable parameters only. +- Register memory/workspace/liveness needs as compiler products; strategies request workspace classes, the planner owns + allocation, aliasing, checkpoint, and recompute. +- Register native outputs as semantic returns, reducer inputs, carry/state inputs, artifacts, workspaces, or metadata. + A tensor with no declared compiler consumer is not an extension point; delete it or fail closed before launch. + +Use `can_implement(plan)` style legality before cost/ranking. A strategy may be slower if the cost model is wrong; it +must not be incorrect. + +### 4. Keep temporal runtime generic + +Temporal scan/reverse code may schedule executor rows over flat buckets, time, reset, checkpoint, materialization, and +artifact routes. It must not own: + +- Q/K/V formulas or fixed dot-product attention +- gated/diagonal recurrence equations +- layernorm/projection/readout formulas +- cell-family names or population-name routing +- benchmark-row, hidden-size, single-vs-mixed, or topology-specific selectors +- direct Python loops, direct pybind wrappers, or fallback replay routes for supported rows + +If a temporal file needs primitive math, stop and add/dispatch a registered primitive executor instead. + +### 5. Prove locality for future additions + +Before closing, answer in the progress doc: + +- What declaration or spec changed? +- Which primitive rows changed? +- Which tensor/parameter/artifact/route rows changed? +- Which executor strategy records changed? +- Which verifier/legality failures were added? +- Which forward, backward, and reducer paths own the new work? +- What files did not need to change because the compiler boundary held? + +A healthy new op/message-rule/strategy should be local to declaration/spec, lowering, registry/native-callable metadata, +strategy implementation, and tests. If adding one op requires editing scheduler ownership, fixed slot enums, or +monolithic scan/reverse ABIs, reopen compiler closure. + +Use this as the reviewer shortcut: if a future PR adding a primitive/message/cell/readout can be reviewed without +opening the temporal scheduler except for generic row dispatch or liveness policy, the compiler boundary is probably +holding. If the reviewer must inspect family-specific scheduler math or fixed tensor slots to understand semantics, it +is not. + +## Required Tests + +Pick the narrowest tests that prove the touched chain: + +- Lowering golden: declaration fields produce distinct IR/spec and primitive rows. +- Legality negative: unsupported op/layout/reset/tape fails closed before launch with typed blocker metadata. 
+- Reference parity: forward values, final state, input/carry gradients, parameter gradients. +- Native/fused parity when CUDA strategy exists. +- Reducer test for every new trainable parameter-gradient output. +- Focused source guardrail through `cb.fabric-boundary-guardrails`, pairing the positive compiler product with a + negative ban on direct wrappers, fixed slot roles, scheduler-owned formulas, or hidden fallbacks. +- Current-code smoke through the high-level user path: `output = model(x, ...)`, external loss, `loss.backward()`. + +## Throughput Additions + +For throughput work, the order is: + +1. Confirm the active semantics already lower into verified primitive rows. +2. Run `cb.fabric-throughput-strategy` and add or improve a registered strategy over those rows. +3. Prove parity first with `cb.fabric-parity-gate`. +4. Profile with `cb.fabric-performance-loop`. +5. Delete or fail-close the replaced old route after parity/current-code non-regression. + +Never accept faster throughput from a strategy that bypasses the compiler chain or moves primitive formulas into the +temporal scheduler. + +Historical April21 code/results are performance targets and comparison context only. Do not copy old fixed-slot kernels, +route structure, or implicit ABIs into the compiler path. Recreate the useful implementation idea as a registered +strategy over current primitive rows, bindings, routes, and memory/liveness plans. + +## Integration + +**Uses:** `cb.fabric-compiler-boundary-audit`, `cb.fabric-backend-boundaries`, `cb.fabric-parity-gate`. + +**Pairs with:** `cb.fabric-declaration-onboarding` for public semantic changes, `cb.fabric-primitive-op-onboarding` for +new primitive ops, `cb.fabric-throughput-strategy` for throughput strategies, `cb.fabric-native-strategy-onboarding` for +native/fused strategy implementations, `cb.fabric-performance-loop` for profiling, `cb.fabric-cell-onboarding` for new +cells, `cb.fabric-message-rule-onboarding` for message rules, `cb.fabric-readout-rule-onboarding` for readout/output +rules, `cb.fabric-graph-onboarding` for graph/topology declarations, `cb.fabric-public-api-cleanup` for +declaration/API cleanup, `cb.fabric-boundary-guardrails` for source/static guardrails and deletion checks, +`cb.fabric-cell-boundaries` for cell semantics, `cb.fabric-scaling-horizon` for B/params/h/graph/T/K/H expansion. diff --git a/skills/cb.fabric-compiler-stress-test/SKILL.md b/skills/cb.fabric-compiler-stress-test/SKILL.md new file mode 100644 index 00000000..424604a5 --- /dev/null +++ b/skills/cb.fabric-compiler-stress-test/SKILL.md @@ -0,0 +1,148 @@ +--- +name: cb.fabric-compiler-stress-test +description: Use when deliberately changing Fabric semantics as a pre-throughput compiler reality test, such as dot-product/message math, normalization, gating, context terms, cell recurrence math, readout formula, primitive-op behavior, tensor roles, or reducer contracts, to prove the active path is truly declaration/row/binding owned. +--- + +# Fabric Compiler Stress Test + +Use this skill when the goal is to prove the compiler path is real by making a +semantic change before throughput work. This is semantic compiler-extension +work, not performance optimization. + +**Announce at start:** "Using Fabric compiler stress test. I'll make the semantic change through declarations, rows, bindings, reference behavior, and typed legality before any throughput work." 
+ +## First Read + +- `skills/cb.fabric-workflow-router/SKILL.md` +- `skills/cb.fabric-compiler-boundary-audit/SKILL.md` +- `skills/cb.fabric-declaration-onboarding/SKILL.md` +- `skills/cb.fabric-compiler-extension/SKILL.md` +- The narrow semantic skill for the surface being stressed: + - message rule: `skills/cb.fabric-message-rule-onboarding/SKILL.md` + - primitive op: `skills/cb.fabric-primitive-op-onboarding/SKILL.md` + - cell/transition: `skills/cb.fabric-cell-onboarding/SKILL.md` + - readout/output: `skills/cb.fabric-readout-rule-onboarding/SKILL.md` + - graph/topology: `skills/cb.fabric-graph-onboarding/SKILL.md` +- `skills/cb.fabric-native-strategy-onboarding/SKILL.md` only after reference semantics and legality exist. +- `skills/cb.fabric-boundary-guardrails/SKILL.md` when adding source/static checks for old-vs-new row delta, no fixed + algorithm, deleted wrappers, or no hidden fallback. +- `skills/cb.fabric-parity-gate/SKILL.md` before accepting behavior. + +Inspect live declarations, lowering, and active strategy records: + +```bash +rg -n "MessageRule|Cell|Readout|Primitive|primitive_row|tensor_binding|parameter_binding|native_callable|missing_executor|unsupported" src/cortical/fabric tests +``` + +## Contract + +A compiler stress test must make the same public call path compile a different +program, or fail closed before launch. It is not enough to patch a `.cuh` +helper, pybind wrapper, temporal scan/reverse kernel, benchmark hook, or +metadata label. + +Required shape: + +```text +old declaration -> old rows/bindings/routes -> old reference behavior +new declaration -> changed rows/bindings/routes -> changed reference behavior +unsupported native coverage -> typed blocker before launch +optional native strategy -> implementation of the new rows +``` + +If the row/binding/route delta is empty, stop. Either the new declaration is +not represented by the compiler yet, or the requested work is really a +throughput strategy over unchanged semantics. + +## Stress Packet + +Before implementation, record: + +```text +Stress target: +Old declaration/spec: +New declaration/spec: +Expected primitive/graph row delta: +Expected tensor/parameter/artifact/output route delta: +Reference forward/backward owner: +Native strategy owner, if any: +Reducer/tape/liveness delta: +Unsupported typed blockers: +Files that should not change if the boundary holds: +Parity matrix: +Throughput work explicitly deferred: +``` + +## Hard Rules + +- Do not optimize throughput in the same pass unless the user explicitly asks + after the semantic stress test is green. +- Do not copy April21 or legacy fixed-slot code. Re-express the semantic idea + through current declarations, rows, bindings, reference behavior, and + registered strategies. +- Do not add formulas to temporal scheduler, fixed tensor slots, benchmark + code, old Config/Blueprint defaults, route selectors, or compatibility + wrappers. +- Q/K/V, context nudge, normalization, gating, aggregation, recurrence, + readout, and reducer changes must live inside the relevant message/cell/ + readout/primitive declaration and executor boundary. +- Native CUDA/Triton/C++ may implement the new rows only after the reference + executor, backward/reducer contract, legality, and typed blockers exist. +- Unsupported combinations must fail closed before launch. Hidden fallback, + replay, or "old math with new metadata" is a failed stress test. + +## Work Loop + +1. Classify the changed public semantics and choose the narrow skill. +2. 
Add or update declaration/spec fields so the user-visible change is explicit. +3. Lower those fields into changed primitive rows, tensor roles, parameter + bindings, route rows, attributes, or legality metadata. +4. Add reference forward and backward behavior first. +5. Add typed blockers for unsupported native/device/layout/reset/tape/reducer + coverage. +6. Add optional registered native strategy only through + `cb.fabric-native-strategy-onboarding`. +7. Delete or fail-close old fixed-slot/direct-wrapper paths the stress exposes. +8. Prove row delta, reference parity, native parity when applicable, and source + guardrails. +9. Record what did not need to change because the compiler boundary held. + +## Required Tests + +- Row-delta/golden test for old vs new declaration. +- Negative legality test for unsupported native or strategy coverage. +- Reference parity for outputs, exposed state, input/carry gradients, and all + touched parameter gradients. +- Native/fused parity if a registered native strategy exists. +- Source guardrail through `cb.fabric-boundary-guardrails`: old/new declarations produce the expected row delta, and no + semantic formula appears in temporal scheduler, benchmark, fixed-slot, Config/Blueprint, or compatibility code. +- High-level call-path smoke: `output = model(x, ...)`, external loss, + `loss.backward()` when training is affected. + +## Closeout + +Record in the active progress doc: + +```text +Stress result: +Rows/bindings/routes changed: +Reference and native parity: +Unsupported blockers: +Old route deleted or fail-closed: +Files that did not change: +Throughput status: not started | still deferred | ready to resume +``` + +Throughput may resume only after the stress test proves the compiler, not a +hidden fixed algorithm, owns the semantic change. + +## Integration + +**Uses:** `cb.fabric-workflow-router`, `cb.fabric-compiler-boundary-audit`, +`cb.fabric-declaration-onboarding`, `cb.fabric-compiler-extension`, +`cb.fabric-parity-gate`. + +**Pairs with:** `cb.fabric-message-rule-onboarding`, `cb.fabric-primitive-op-onboarding`, +`cb.fabric-cell-onboarding`, `cb.fabric-readout-rule-onboarding`, `cb.fabric-graph-onboarding`, +`cb.fabric-native-strategy-onboarding`, `cb.fabric-boundary-guardrails`, and `cb.fabric-throughput-strategy` only after +the stress test is accepted. diff --git a/skills/cb.fabric-declaration-onboarding/SKILL.md b/skills/cb.fabric-declaration-onboarding/SKILL.md new file mode 100644 index 00000000..2d767b3d --- /dev/null +++ b/skills/cb.fabric-declaration-onboarding/SKILL.md @@ -0,0 +1,147 @@ +--- +name: cb.fabric-declaration-onboarding +description: Use when adding, changing, or reviewing Cortical Fabric user-declared semantics across graph/topology, cells, message rules, readout rules, primitive ops, reset/init fields, public Config/Blueprint surfaces, or formula changes that must lower through the compiler before throughput work. +--- + +# Fabric Declaration Onboarding + +Use this skill when the meaning of a Fabric program changes. It is the common front door for graph, cell, message, +readout, primitive-op, reset/init, and public declaration work. It prevents new semantics from entering through native +kernels, throughput strategies, Config defaults, fixed slots, or temporal scheduler code. + +**Announce at start:** "Using Fabric declaration onboarding. I'll add or change user semantics through normalized declarations, rows, bindings, reference behavior, and typed legality." 
+ +## First Read + +- `skills/cb.fabric-workflow-router/SKILL.md` +- `skills/cb.fabric-compiler-boundary-audit/SKILL.md` +- `skills/cb.fabric-compiler-extension/SKILL.md` +- `skills/cb.fabric-compiler-stress-test/SKILL.md` when the declaration change is a pre-throughput compiler reality + test or formula perturbation. +- The narrow semantic skill for the changed surface: + - graph/topology: `skills/cb.fabric-graph-onboarding/SKILL.md` + - cell/transition: `skills/cb.fabric-cell-onboarding/SKILL.md` and `skills/cb.fabric-cell-boundaries/SKILL.md` + - message rule: `skills/cb.fabric-message-rule-onboarding/SKILL.md` + - readout/output: `skills/cb.fabric-readout-rule-onboarding/SKILL.md` + - primitive op: `skills/cb.fabric-primitive-op-onboarding/SKILL.md` + - public Config/Blueprint cleanup: `skills/cb.fabric-public-api-cleanup/SKILL.md` +- `skills/cb.fabric-boundary-guardrails/SKILL.md` when adding source/static checks for row delta, stale direct wrappers, + old Config/Blueprint choke points, fixed slots, or hidden fallbacks. +- `skills/cb.fabric-parity-gate/SKILL.md` before accepting behavior. + +Inspect live declarations and lowering: + +```bash +rg --files src/cortical/fabric tests | rg "config|blueprint|graph|message|readout|cell|primitive|cuda/nn|registry|compiler" +rg -n "Config|Blueprint|Graph|Cell|MessageRule|Readout|Primitive|primitive_row|tensor_binding|output_route|artifact_route|missing_executor" src/cortical/fabric tests +``` + +## Unified Declaration Rule + +Every user-visible Fabric semantic surface follows the same compiler shape: + +```text +public declaration + -> normalized spec + -> semantic IR / graph facts + -> primitive or graph rows + -> tensor, parameter, artifact, output, reset, and liveness bindings + -> verifier legality and typed blockers + -> reference executor + -> optional registered native/fused strategy + -> backward, tape/recompute, and reducer coverage + -> parity and source guardrails +``` + +Cells, message rules, readout rules, graph constructors, and primitive ops should feel similar at the boundary. They own +different semantics, but none of them are defined by a `.cuh` helper, route selector, old Config field, benchmark hook, +fixed tensor slot enum, or temporal scan/reverse branch. + +Adding a new op, message rule, cell, readout, or graph constructor should be local once the compiler boundary is healthy: +registry/spec metadata, lowering, verifier, reference executor, optional native strategy, reducer/liveness rows, tests, +and source guardrails. If the change requires temporal scheduler formula edits, fixed tensor slots, monolithic program +ABI expansion, or benchmark policy, record that as an open compiler boundary before proceeding. + +## Hard Stops + +Stop and use this skill before throughput/native work if any of these change: + +- graph facts, topology, boundary/port ownership, distance, delay, groups, or reset scope; +- message math, Q/K/V source, normalization, gating, aggregation, context terms, or message parameters; +- cell state schema, transition math, recurrence attributes, public emission, reset/tape contract, or cell parameters; +- readout math, pooling, output boundary, output route/merge, output materialization, or readout parameters; +- primitive op attributes, tensor roles, parameter roles, reducer meaning, or backward contract; +- public Config/Blueprint fields that decide graph/cell/message/readout/reset/planner semantics. 
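+
+When any surface above changes, the change must become visible as compiler products before backend work starts. A
+minimal sketch of that representation, using hypothetical record names (`PrimitiveRow`, `TensorBinding`,
+`LoweredDeclaration`) rather than the real Fabric types:
+
+```python
+from dataclasses import dataclass, field
+
+@dataclass(frozen=True)
+class PrimitiveRow:
+    # What is computed, as semantic attributes; never how a backend runs it.
+    op: str
+    attributes: tuple
+
+@dataclass(frozen=True)
+class TensorBinding:
+    # Roles belong to a specific lowered row, not to a scheduler-owned slot enum.
+    role: str
+    owner_row: int
+
+@dataclass
+class LoweredDeclaration:
+    primitive_rows: list = field(default_factory=list)
+    tensor_bindings: list = field(default_factory=list)
+    typed_blockers: list = field(default_factory=list)  # fail-closed reasons
+
+def has_row_delta(old: LoweredDeclaration, new: LoweredDeclaration) -> bool:
+    """If a changed declaration produces no delta here, it never reached the
+    compiler: the change is either unsupported or throughput-only."""
+    return (old.primitive_rows != new.primitive_rows
+            or old.tensor_bindings != new.tensor_bindings)
+```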
+ +A throughput strategy may optimize only after these products already exist and row fingerprints are stable. +If the change is meant to prove compiler reality before throughput, use `cb.fabric-compiler-stress-test` and require an +explicit old-vs-new row/binding/route delta before native or performance work. + +## Declaration Packet + +Before implementation, write this in working notes or the active progress doc: + +```text +Changed public declaration: +Normalized spec owner: +Expected row/binding/route delta: +Reference forward/backward owner: +Typed unsupported blockers: +Optional native strategy owner: +Backward/tape/reducer owner: +Old path to delete or fail-close: +Files that should not change if the compiler boundary holds: +``` + +If the expected row/binding/route delta is empty, do not patch math. Either the declaration is unsupported and should +fail closed, or the change is really a throughput strategy over existing semantics. + +## Work Loop + +1. Assign every changed field to exactly one owner: graph, cell, message, readout, primitive op, reset/init, or planner + request. Delete unowned fields or reject them with typed legality. +2. Add or update the normalized spec and lowering so changed declarations produce changed graph/primitive rows, + tensor/parameter bindings, output/artifact routes, and legality metadata. +3. Add or update reference forward and backward behavior before native strategy work. +4. Add typed blockers for unsupported native coverage, reset/tape/materialization, graph facts, route kinds, layouts, + dtypes, or reducer contracts. +5. Add optional registered native/fused strategies only through `cb.fabric-native-strategy-onboarding`. +6. Delete or fail-close old direct wrappers, fixed-slot compatibility paths, old Config/Blueprint choke points, and + route selectors once the compiler-owned replacement is active. +7. Prove row delta, reference parity, native parity when applicable, and source guardrails through + `cb.fabric-boundary-guardrails`. + +## Locality Test + +For a healthy declaration change, most edits stay near: + +- public declaration/spec and registry; +- lowering and verifier rows; +- reference executor; +- optional registered native strategy; +- reducer/tape/liveness rows; +- tests and source guardrails. + +If the change requires editing temporal scheduler math, fixed tensor slots, monolithic scan/reverse ABI, graph-constructor +route selectors, benchmark policy, or cell-family branches, compiler closure is not holding. Close that boundary before +claiming the declaration is supported. + +## Required Tests + +- Lowering golden: old and new declarations produce different rows/bindings/routes. +- Negative legality: unsupported declarations fail closed before launch with typed blockers. +- Reference parity: outputs, exposed state, input/carry gradients, and parameter gradients. +- Native/fused parity if a registered strategy exists. +- Source guardrail through `cb.fabric-boundary-guardrails`: changed declaration produces changed rows/routes, and no new + semantic formula enters scheduler, benchmark, fixed-slot, or compatibility code. +- High-level smoke: user-style `output = model(x, ...)`, external loss, `loss.backward()` when training is affected. + +## Integration + +**Uses:** `cb.fabric-workflow-router`, `cb.fabric-compiler-boundary-audit`, `cb.fabric-compiler-extension`, +`cb.fabric-parity-gate`. 
+ +**Pairs with:** `cb.fabric-graph-onboarding`, `cb.fabric-cell-onboarding`, `cb.fabric-cell-boundaries`, +`cb.fabric-message-rule-onboarding`, `cb.fabric-readout-rule-onboarding`, `cb.fabric-primitive-op-onboarding`, +`cb.fabric-public-api-cleanup`, `cb.fabric-compiler-stress-test`, `cb.fabric-boundary-guardrails`, +`cb.fabric-native-strategy-onboarding`, and `cb.fabric-throughput-strategy`. diff --git a/skills/cb.fabric-graph-onboarding/SKILL.md b/skills/cb.fabric-graph-onboarding/SKILL.md new file mode 100644 index 00000000..ddb4a923 --- /dev/null +++ b/skills/cb.fabric-graph-onboarding/SKILL.md @@ -0,0 +1,142 @@ +--- +name: cb.fabric-graph-onboarding +description: Use when adding, changing, or porting Cortical Fabric graph/topology declarations, graph constructors, boundary/input/output port semantics, edge facts, distance/delay/group metadata, topology lowering, graph compiler rows, or graph-driven planner legality. +--- + +# Fabric Graph Onboarding + +Use this skill when user graph semantics or topology facts change. Graphs are first-class Fabric declarations, not +Config shortcuts and not backend route selectors. + +**Announce at start:** "Using Fabric graph onboarding. I'll lower graph semantics into graph facts/rows and keep topology constructors out of backend policy." + +## First Read + +- `skills/cb.fabric-backend-boundaries/SKILL.md` +- `skills/cb.fabric-compiler-boundary-audit/SKILL.md` +- `skills/cb.fabric-compiler-extension/SKILL.md` +- `skills/cb.fabric-declaration-onboarding/SKILL.md` +- `skills/cb.fabric-compiler-stress-test/SKILL.md` if graph/topology facts are deliberately changed to prove compiler + locality before throughput work. +- `skills/cb.fabric-public-api-cleanup/SKILL.md` if the graph work removes or rewrites `Config`, `Blueprint`, or public + declaration normalization. +- If graph facts affect throughput strategy, also read `skills/cb.fabric-throughput-strategy/SKILL.md`. +- If graph facts feed a CUDA/Triton/C++ native strategy or fused program kernel, also read + `skills/cb.fabric-native-strategy-onboarding/SKILL.md`. +- If graph facts affect parity coverage, also read `skills/cb.fabric-parity-gate/SKILL.md`. +- If graph facts affect readout/output-boundary semantics or routes, also read + `skills/cb.fabric-readout-rule-onboarding/SKILL.md`. +- If adding source/static checks for graph-row ownership, no lattice/config backend selectors, deleted wrappers, or + no hidden fallback, also read `skills/cb.fabric-boundary-guardrails/SKILL.md`. +- Inspect live graph declarations, constructors, config/blueprint lowering, planner graph tables, and tests: + +```bash +rg --files src/cortical/fabric | rg "graph|topology|lattice|blueprint|config|anatomy|sequence_surface/compiler" +rg -n "Graph|Topology|Lattice|boundary|input.*cell|output.*cell|edge|degree|delay|distance|group" src/cortical/fabric tests +``` + +Use docs for intent only. Verify the active declaration/lowering/runtime path in code. + +## Required Chain + +Every supported graph/topology declaration must follow: + +```text +public graph declaration / constructor + -> normalized graph spec + -> flat graph facts and boundary/port sets + -> graph/topology rows and edge metadata + -> primitive/message/readout tensor bindings that consume those facts + -> verifier / legality / typed blockers + -> registered executor strategies over graph rows + -> parity + source guardrails +``` + +Changing a graph declaration must either change the compiled graph facts/rows or fail closed before launch. 
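+
+A minimal sketch of what flat graph facts could look like, with hypothetical field names rather than the real Fabric
+graph spec. Lattice2D and an arbitrary user graph should both reduce to this shape before any backend or planner code
+reads them:
+
+```python
+from dataclasses import dataclass, field
+
+@dataclass
+class FlatGraphFacts:
+    # Hypothetical illustration: every constructor lowers to the same flat
+    # facts, so backend code never sees lattice names or config fields.
+    num_nodes: int
+    input_nodes: frozenset
+    output_nodes: frozenset
+    boundary_nodes: frozenset
+    # Edge rows as parallel (sender, receiver) lists; legality and compaction
+    # reason over these rows, never over x/y coordinates or band names.
+    edge_senders: list = field(default_factory=list)
+    edge_receivers: list = field(default_factory=list)
+    delay: dict = field(default_factory=dict)        # per-edge metadata
+    distance: dict = field(default_factory=dict)
+    reset_scope: dict = field(default_factory=dict)  # per-node reset grouping
+
+def degree_buckets(facts: FlatGraphFacts) -> dict:
+    """Example derived fact: receivers grouped by in-degree, computed from the
+    flat edge rows rather than from constructor identity."""
+    in_degree: dict = {}
+    for receiver in facts.edge_receivers:
+        in_degree[receiver] = in_degree.get(receiver, 0) + 1
+    buckets: dict = {}
+    for node, degree in in_degree.items():
+        buckets.setdefault(degree, []).append(node)
+    return buckets
+```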
+ +## Boundary Rules + +- Lattice2D is one graph constructor, not the Fabric graph model. Arbitrary user graphs should lower to the same flat + facts: node ids, input/output/boundary sets, edge sender/receiver rows, degree buckets, groups, distance, delay, reset + scope, and port metadata. +- Config and Blueprint may carry user-facing declarations only as front-end inputs. They must not be the backend source + of graph truth or a long-term normalization choke point. Backend code consumes normalized graph facts and compiler rows, + never lattice names or config fields. +- Cleanup that removes old broad `Config`/Blueprint translation must replace it with direct normalized declarations for + graph, cell, message, readout, reset, and planner policy. Do not keep an old flat config object as the hidden source of + semantics under a new declaration wrapper. +- If cleanup crosses graph ownership into population/cell/message/readout/planner fields, switch to + `cb.fabric-public-api-cleanup` and treat graph as one declaration owner, not the whole public API migration. +- Do not key executor selection, memory policy, active-region policy, message routes, or readout routes on graph + constructor identity, lattice shape, benchmark row, hidden size, single/mixed population label, or population name. +- Do not move message math or cell math into graph code. Graphs define topology and facts; message/cell/readout + semantics lower through their own primitive rows. +- If a graph change needs a new message primitive, readout behavior, reset behavior, or throughput strategy, route that + work through the matching skill instead of hiding it in topology lowering. +- Compact active-region or closure optimizations must be legality-proven over flat sender/receiver rows. Never infer + correctness from lattice bands, x/y coordinates, or output adjacency names alone. + +## Implementation Steps + +1. Identify the public graph declaration or constructor fields and the normalized graph spec they produce. +2. Emit explicit flat graph facts: nodes, ports, boundaries, edge rows, degree buckets, groups, distance/delay, and reset + scope where applicable. +3. Thread those facts into primitive/message/readout/tensor binding rows without adding graph-constructor branches. +4. Add verifier rules and typed blockers for unsupported graph features, layouts, delay policies, degree limits, reset + scopes, or active-region policies. +5. Update registered strategy legality/cost only from graph facts and compiler rows. If native code changes, use + `cb.fabric-native-strategy-onboarding` before writing the kernel or binding schema. +6. Delete or fail-close old config/anatomy/lattice-specific backend paths that the graph rows replace. +7. Prove locality: adding a new graph constructor should not require editing temporal scheduler formulas, fixed tensor + slot enums, message/cell primitive math, or monolithic scan/reverse ABIs. + +Before backend/planner edits for graph work, record: + +```text +Graph declaration/constructor: +Normalized graph facts: +Graph/topology rows: +Boundary/input/output port rows: +Edge/distance/delay/group rows: +Primitive/message/readout bindings that consume them: +Legality blockers: +Old config/lattice/anatomy route to delete or fail-close: +Files that should not change if the boundary holds: +``` + +If a graph optimization needs message/cell/readout formulas, switch to that semantic skill instead of hiding the formula +in topology lowering. 
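+
+One illustration of the fail-closed requirement above. The exception type and reason codes are hypothetical, not the
+real Fabric legality API; the property that matters is a typed, machine-readable rejection raised before launch instead
+of a silent fallback to a narrower built-in graph:
+
+```python
+from types import SimpleNamespace
+
+class UnsupportedGraphDeclaration(Exception):
+    """Hypothetical typed blocker: carries a machine-readable reason so audits
+    can tell 'rejected before launch' from a hidden fallback."""
+    def __init__(self, reason_code: str, detail: str):
+        super().__init__(f"{reason_code}: {detail}")
+        self.reason_code = reason_code
+
+def check_graph_legality(facts) -> None:
+    # Sketch only: real checks would read verifier rows, not ad hoc fields.
+    if facts.delay and not facts.distance:
+        raise UnsupportedGraphDeclaration(
+            "delay_without_distance",
+            "edge delays declared but no distance facts were lowered",
+        )
+    if not facts.output_nodes:
+        raise UnsupportedGraphDeclaration(
+            "no_output_boundary",
+            "graph declares no output boundary, so readout routes cannot bind",
+        )
+
+# The rejection fires before any launch and names its reason.
+bad_graph = SimpleNamespace(delay={0: 2}, distance={}, output_nodes=frozenset())
+try:
+    check_graph_legality(bad_graph)
+except UnsupportedGraphDeclaration as blocker:
+    assert blocker.reason_code == "delay_without_distance"
+```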
+ +## Required Tests + +- Lowering golden: declaration fields produce expected flat graph facts/rows. +- Negative legality: unsupported graph features fail closed before launch with typed blockers. +- Source guardrail through `cb.fabric-boundary-guardrails`: graph facts/rows are consumed and no graph-constructor, + lattice-shape, population-name, benchmark-row, or hidden-size selector appears in backend execution policy. +- Parity: at least one supported graph path through high-level user calls, including output/state/input and parameter + gradients when training is supported. +- Scaling guard: if graph shape/factorization changes, compare flat graph invariants before making performance claims. + +## Closeout Questions + +Record these in the active progress doc: + +- Which public graph declaration changed? +- Which normalized graph facts/rows changed? +- Which primitive/message/readout bindings consume the graph facts? +- Which verifier/legality blockers were added? +- Which old config/lattice/anatomy route was deleted or fail-closed? +- What did not need to change because the compiler boundary held? + +## Integration + +**Uses:** `cb.fabric-backend-boundaries`, `cb.fabric-compiler-boundary-audit`, `cb.fabric-compiler-extension`, +`cb.fabric-declaration-onboarding`, `cb.fabric-parity-gate`. + +**Pairs with:** `cb.fabric-message-rule-onboarding` when graph facts feed message semantics, +`cb.fabric-cell-onboarding` when graph/population construction exposes cell declarations, +`cb.fabric-readout-rule-onboarding` when graph facts feed output/readout semantics, +`cb.fabric-throughput-strategy` and `cb.fabric-native-strategy-onboarding` for graph-driven strategy optimization, +`cb.fabric-compiler-stress-test` for pre-throughput graph-fact locality tests, `cb.fabric-boundary-guardrails` for +source/static guardrails and deleted graph-route checks, `cb.fabric-performance-loop` and +`cb.fabric-scaling-horizon` for graph/shape audits. diff --git a/skills/cb.fabric-message-rule-onboarding/SKILL.md b/skills/cb.fabric-message-rule-onboarding/SKILL.md new file mode 100644 index 00000000..09aa7a82 --- /dev/null +++ b/skills/cb.fabric-message-rule-onboarding/SKILL.md @@ -0,0 +1,132 @@ +--- +name: cb.fabric-message-rule-onboarding +description: Use when adding, changing, or porting a Cortical Fabric message rule, message-rule declaration, message-rule backend spec, message primitive lowering, message native CUDA strategy, or message parameter-gradient reducer. +--- + +# Fabric Message Rule Onboarding + +Use this skill for the concrete workflow of adding or changing Fabric message passing. Message rules are semantic +declarations like cells: they are not temporal-engine policies and not fixed Q/K/V shortcuts. + +**Announce at start:** "Using Fabric message-rule onboarding. I'll add message semantics through registered declarations/specs and keep message math out of the temporal scheduler." + +## First Read + +- `skills/cb.fabric-backend-boundaries/SKILL.md` +- `skills/cb.fabric-compiler-boundary-audit/SKILL.md` +- `skills/cb.fabric-compiler-extension/SKILL.md` +- `skills/cb.fabric-declaration-onboarding/SKILL.md` +- `skills/cb.fabric-compiler-stress-test/SKILL.md` when changing message math as a pre-throughput compiler reality + test, such as dot-product, context nudge, normalization, gating, or aggregation changes. +- `skills/cb.fabric-primitive-op-onboarding/SKILL.md` if the message rule introduces or changes a primitive op. 
+- `skills/cb.fabric-throughput-strategy/SKILL.md` if the change is a faster strategy over existing message rows. +- `skills/cb.fabric-native-strategy-onboarding/SKILL.md` before implementing CUDA/Triton/C++ native message strategy + bodies or binding schemas. +- `skills/cb.fabric-graph-onboarding/SKILL.md` if the rule introduces or consumes new graph/edge facts, distance/delay + semantics, boundary sets, or topology grouping. +- `skills/cb.fabric-readout-rule-onboarding/SKILL.md` if the message rule changes output routes, readout-boundary + inputs, readout pooling, or output artifact ownership. +- `skills/cb.fabric-boundary-guardrails/SKILL.md` when adding source/static checks for message row delta, no fixed Q/K/V + slots, no hidden dot-product route, or deleted message wrappers. +- Live message declarations/specs/tests: + +```bash +rg --files src/cortical/fabric | rg "message_rule|message_rules|cuda/nn|sequence_surface/compiler|registered_program" +rg -n "MessageRule|MessageRuleBackendSpec|message_rule|message_strategy|message_param_grad|attention|segment" src/cortical/fabric tests +``` + +## Onboarding Steps + +1. Identify the public declaration fields, registry entry, and normalized message-rule spec. Do not discard declaration + fields into a default dot-product program. +2. Lower the message rule into semantic IR/backend spec, primitive rows, tensor roles, parameter bindings, output/artifact + roles, reset/tape policy, and schema versions. +3. Add verifier/legality checks with typed blocker metadata for unsupported layout, dtype, degree, reset, delay, tape, + or workspace cases. +4. Add a reference executor or PyTorch reference path that defines truth for forward values and backward gradients. +5. Add optional native/fused CUDA strategy records over the lowered primitive rows. Use `cb.fabric-throughput-strategy` + for performance-only implementations and `cb.fabric-native-strategy-onboarding` before writing native kernels. Keep + Q/K/V, normalization, gating, distance/delay terms, and aggregation inside the message primitive/strategy boundary, + not the temporal scheduler. +6. Register any trainable message parameters and reducer outputs from binding-owned executable parameter rows only. +7. Add tests proving declaration fields change the compiled program, unsupported variants fail closed before launch, and + supported variants train through the registered temporal program. + +Before editing registered-program kernels or CUDA helpers for a message rule, record: + +```text +Message declaration/spec: +Message primitive rows: +Sender/receiver/edge tensor roles: +Parameter bindings and reducers: +Reference executor: +Native strategy, if any: +Backward/tape/artifact contract: +Unsupported typed blockers: +Files that should not change if the boundary holds: +``` + +If the implementation plan starts with a `.cuh` math edit and cannot name the declaration/spec and rows, stop and add +the compiler products first. + +For a message formula change, record the row delta before native work: + +```text +Old message declaration: +New message declaration: +Primitive/message rows changed: +Tensor roles changed: +Parameter/reducer rows changed: +Reference backward changed: +Unsupported typed blocker: +``` + +If this formula change is being used to prove the compiler path before throughput, use `cb.fabric-compiler-stress-test` +and keep throughput optimization deferred. If the dot-product/context/normalization/gating change only touches `.cuh` +files, the compiler path is not real enough for the change. 
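+
+A self-contained sketch of that row-delta evidence, with a toy lowering function standing in for the real entry point.
+The declaration fields, row names, and helper are hypothetical; the real test must show that the same public surface
+compiles a different message program when the formula changes:
+
+```python
+from dataclasses import dataclass
+
+@dataclass(frozen=True)
+class MessageRow:
+    # Hypothetical lowered message primitive: op name plus semantic attributes.
+    op: str
+    attributes: tuple
+
+def lower_message_rule(declaration: dict) -> tuple:
+    """Toy stand-in for the real lowering: declaration fields must flow into
+    row attributes instead of being discarded into one default program."""
+    variant = declaration["variant"]
+    rows = [MessageRow("dot_product_scores", (("variant", variant),))]
+    if variant == "context_nudge":
+        rows.append(MessageRow("context_nudge_apply", ()))
+    return tuple(rows)
+
+# Row-delta evidence: the same public surface compiles a different program.
+old_rows = lower_message_rule({"rule": "dot_product", "variant": "dynamic_kv"})
+new_rows = lower_message_rule({"rule": "dot_product", "variant": "context_nudge"})
+assert old_rows != new_rows
+```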
+ +## Boundary Rules + +- Message rules and cells use the same compiler standard: declaration -> IR/spec -> primitive rows -> tensor bindings -> + registered executors. +- Message rules must be registered through a semantic registry analogous to cell/primitive registration. A `.cuh` file or + CUDA helper may implement a registered message strategy, but it must not define message semantics by itself. +- A temporal kernel may schedule message executor rows, but must not own dot-product attention, Q/K/V projection source, + normalization/gating, distance/delay formulas, or aggregation semantics. +- Adding a new message rule should be local to declaration/spec, lowering, native callable/strategy metadata, + implementation, reducers, and tests. If it requires scheduler-owned fixed slots or monolithic scan/reverse ABI edits, + reopen compiler closure. +- If only one message rule is supported, expose that as narrow support and fail other declarations closed. Do not accept a + generic-looking public message API that silently runs one hidden built-in rule. +- A message rule may own sender/receiver/edge semantic inputs, projection roles, normalization/gating/context terms, + aggregation, output roles, and message-parameter reducers. It may not own time scheduling, graph construction, + population routing, readout ownership, workspace/liveness, checkpoint policy, or benchmark policy. +- If dot-product attention math changes, make that visible as message-rule/primitive-op semantics and lowering changes. + Do not patch a `.cuh` helper or registered program kernel in isolation and call the compiler path real. +- Q/K/V-like names are allowed only as tensor roles on a specific lowered message primitive. They are invalid as + temporal-engine global slots or hidden assumptions that every message rule has the same projection structure. +- Message and readout strategies often share sender/receiver/output artifacts, but they must communicate through + artifact/output/reducer route rows. Do not have a message strategy depend on a readout tensor position, one current + producer count, or a role-only lookup when routes are required. + +## Verification + +- Lowering golden for declaration fields and strategy identity. +- Negative legality test for unsupported cases. +- CUDA/reference parity for outputs, input/carry/state gradients where applicable, and every message parameter gradient. +- Source guardrail through `cb.fabric-boundary-guardrails`: message declaration/spec produces owned rows/bindings, and + message math stays out of temporal scheduler files. +- High-level user path smoke: `output = model(x, ...)`, external loss, `loss.backward()`. + +## Integration + +**Uses:** `cb.fabric-compiler-extension`, `cb.fabric-compiler-boundary-audit`, `cb.fabric-backend-boundaries`, +`cb.fabric-parity-gate`. + +**Pairs with:** `cb.fabric-primitive-op-onboarding` for new message primitives, `cb.fabric-graph-onboarding` for +graph/edge facts consumed by message semantics, `cb.fabric-throughput-strategy`, +`cb.fabric-native-strategy-onboarding`, and `cb.fabric-performance-loop` for throughput/native message strategies, +`cb.fabric-readout-rule-onboarding` when output/readout ownership is affected, `cb.fabric-compiler-stress-test` for +pre-throughput formula-change locality tests, `cb.fabric-boundary-guardrails` for source/static guardrails and deleted +message-route checks, +`cb.fabric-scaling-horizon` for topology and T/K/H audits. 
diff --git a/skills/cb.fabric-native-strategy-onboarding/SKILL.md b/skills/cb.fabric-native-strategy-onboarding/SKILL.md new file mode 100644 index 00000000..49af62d2 --- /dev/null +++ b/skills/cb.fabric-native-strategy-onboarding/SKILL.md @@ -0,0 +1,157 @@ +--- +name: cb.fabric-native-strategy-onboarding +description: Use when adding or changing a Cortical Fabric registered CUDA/Triton/native callable strategy, fused forward/reverse program kernel, primitive executor implementation, native binding schema, or kernel ABI for existing compiler rows. +--- + +# Fabric Native Strategy Onboarding + +Use this skill when implementation moves into CUDA, Triton, C++ bindings, native callables, or fused program kernels. A +native strategy is an implementation of compiler products; it is not where new Fabric semantics are invented. + +**Announce at start:** "Using Fabric native strategy onboarding. I'll implement only a registered strategy over existing rows and keep new semantics out of native kernels." + +## First Read + +- `skills/cb.fabric-compiler-boundary-audit/SKILL.md` +- `skills/cb.fabric-throughput-strategy/SKILL.md` +- `skills/cb.fabric-compiler-extension/SKILL.md` +- `skills/cb.fabric-declaration-onboarding/SKILL.md` if the native change needs new public semantics or formula meaning. +- `skills/cb.fabric-compiler-stress-test/SKILL.md` if the native change implements a pre-throughput semantic stress test. +- `skills/cb.fabric-parity-gate/SKILL.md` +- `skills/cb.fabric-reducer-liveness/SKILL.md` when native outputs feed parameter reducers, reverse span groups, + carry/state buffers, artifacts, or workspace/lifetime policy. +- `skills/cb.fabric-boundary-guardrails/SKILL.md` when adding source/static checks for native ABI ownership, deleted + fixed wrappers, or no-fallback/no-compat behavior. +- `skills/cb.fabric-runtime-front-end-handoff/SKILL.md` if the native ABI change crosses the public-call boundary into + registered temporal program handoff, state/boundary preparation, reset normalization, or sender K/V setup. +- The narrow semantic skill if the strategy implements a message, readout, cell, graph, or primitive-op change. + +Inspect the active strategy path: + +```bash +rg -n "native_callable|strategy_id|can_implement|executor_row|tensor_binding|program_access|artifact_route|output_route|memory_liveness|parameter_reducer" src/cortical/fabric tests +``` + +## Hard Gate + +Before writing native code, the following must already exist or be added through the semantic/compiler skill first: + +```text +Declaration/spec: +Primitive or graph rows: +Tensor/parameter/artifact/output bindings: +Verifier/legality blockers: +Reference executor: +Forward/backward/reducer contracts: +Memory/liveness/artifact policy: +Audit metadata: +``` + +If any of these are missing, stop native work. Do not create tensor roles, formulas, reset semantics, reducer meaning, or +output ownership inside a `.cuh`, pybind wrapper, temporal scheduler, benchmark, or compatibility helper. + +Write an ABI boundary note before editing native files: + +```text +Rows consumed: +Bindings/routes/liveness consumed: +Native outputs returned by role: +Workspace ownership: +Unsupported typed blockers: +Legacy wrapper/kernel deleted or fail-closed: +``` + +If the ABI starts from fixed tensor slots, Q/K/V globals, gated/diagonal bundle names, output projection globals, or +single/mixed-pop branches, stop and add the missing compiler binding/route rows first. 
+If the ABI starts from public model/runtime handoff tensors, first classify them with `cb.fabric-runtime-front-end-handoff` +as public adapter outputs, compiler row inputs, planned runtime buffers, artifacts, workspace, or illegal temporaries. +Do not move handoff work into a native kernel unless the consumed tensors and lifetime are compiler-owned rows. + +Native ABI changes must be one of these: + +- a registered implementation of existing semantic rows; +- a compiler-declared output/reducer/artifact/workspace route; +- a metadata-only audit row stripped before semantic consumption. + +Anything else is a semantic compiler-extension. Do not smuggle new tensor roles, parameter meanings, or gradient outputs +through an extra pybind return group, C++ helper argument, fixed slot, or native-only enum. + +## Strategy Contract + +A registered native strategy must declare: + +- stable strategy id/version and row pattern signature; +- input/output/state/artifact role schema and schema version; +- dtype/device/layout/shape/reset/tape/materialization legality; +- forward executor and backward executor coverage, or typed no-grad legality; +- parameter-gradient and reducer output contract; +- workspace/liveness/aliasing requirements; +- deterministic/tolerance class; +- typed rejection reasons; +- audit metadata proving the strategy ran. + +Legality is separate from ranking. `can_implement(plan)` decides correctness; cost/ranking chooses among legal strategies. + +## Kernel Rules + +- Consume compiler-owned rows and bindings directly: primitive rows, executor rows, native callable rows, program access + rows, tensor bindings, artifact/output routes, reset rows, memory/liveness rows, and reducer rows. +- Do not add fixed-slot enums, global Q/K/V assumptions, gated/diagonal/readout formulas, family branches, graph + constructor branches, hidden-size branches, benchmark branches, or separate single/mixed-pop routes. +- Do not copy April21 or legacy fixed-slot kernels. Re-express the useful idea as a registered current strategy over + current rows and bindings. +- If the native strategy needs a new formula, tensor role, artifact route, output route, reducer, tape policy, or reset + behavior, reclassify as compiler-extension work before optimizing. +- If the new formula/role is being introduced to prove compiler locality before throughput, reclassify as + `cb.fabric-compiler-stress-test` before writing native code. +- Delete or fail-close the direct wrapper, fixed ABI, or legacy route the native strategy replaces after parity and + current-code non-regression pass. +- Classify every allocation site: primitive output, primitive workspace, planned runtime buffer, metadata row, or illegal + scheduler allocation. New CPU telemetry or validation tables are metadata rows; CUDA work buffers must be planned + runtime buffers or primitive workspaces, not ad hoc native temporaries. +- Classify every native output: semantic return, reducer input, carry/state input, artifact/tape, workspace, metadata, or + illegal. Local-only tensors should be consumed or released in the native strategy once compiler-declared consumers have + run; do not retain full span groups solely for ABI uniformity. +- If a native kernel currently assumes one message/readout/transition executor, the next strategy change must either + prove that as compiler legality before launch or consume route/merge rows generically. Do not keep singleton + assumptions as hidden ABI. 
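+
+A minimal sketch of the legality-versus-ranking split from the strategy contract above. The record fields, plan shape,
+and strategy id are hypothetical, not the real registry schema; the point is that `can_implement` returns a typed
+rejection or passes, and cost only ranks strategies that are already legal:
+
+```python
+from dataclasses import dataclass
+from typing import Optional
+
+@dataclass(frozen=True)
+class NativeStrategy:
+    # Hypothetical registered-strategy record; a real schema also carries role
+    # schema versions, workspace policy, reducer outputs, and audit metadata.
+    strategy_id: str
+    version: int
+    supported_dtypes: frozenset
+    supports_backward: bool
+
+    def can_implement(self, plan: dict) -> Optional[str]:
+        """Correctness gate: return a typed rejection reason, or None if legal.
+        Never consult family names, benchmark rows, or hidden sizes here."""
+        if plan["dtype"] not in self.supported_dtypes:
+            return f"unsupported_dtype:{plan['dtype']}"
+        if plan["needs_backward"] and not self.supports_backward:
+            return "no_backward_executor"
+        return None
+
+    def cost(self, plan: dict) -> float:
+        """Ranking only: chooses among strategies that already passed legality."""
+        return float(plan["row_count"])
+
+strategy = NativeStrategy("temporal_scan.example", 1, frozenset({"float32", "bfloat16"}), True)
+plan = {"dtype": "float16", "needs_backward": True, "row_count": 128}
+assert strategy.can_implement(plan) == "unsupported_dtype:float16"  # fail closed, with a reason
+```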
+ +## Required Tests + +- Source guardrail through `cb.fabric-boundary-guardrails`: positive native strategy/row/binding ownership plus no fixed + slots, family selectors, benchmark selectors, hidden-size selectors, or scheduler-owned primitive formulas in touched + native/temporal files. +- Legality negative test for unsupported rows/layouts/dtypes/resets/tape/materialization. +- CUDA/reference parity for outputs, exposed states, input/carry gradients, state gradients when exposed, and parameter + gradients. +- Runtime metadata proving the registered strategy id, row fingerprint, and no hidden fallback. +- Warmed performance row only after parity is green. + +## Closeout + +Record in the active progress doc: + +```text +Rows implemented: +Native strategy id/version: +Bindings consumed: +Memory/liveness policy: +Backward/reducer coverage: +Legacy route deleted or fail-closed: +Parity: +Perf: +Remaining owner: +``` + +If the strategy required temporal scheduler formula edits or monolithic ABI expansion, the compiler boundary failed; do +not claim support until that boundary is fixed. + +## Integration + +**Uses:** `cb.fabric-compiler-boundary-audit`, `cb.fabric-throughput-strategy`, `cb.fabric-compiler-extension`, +`cb.fabric-reducer-liveness`, `cb.fabric-parity-gate`. + +**Pairs with:** `cb.fabric-declaration-onboarding`, `cb.fabric-primitive-op-onboarding`, `cb.fabric-message-rule-onboarding`, +`cb.fabric-readout-rule-onboarding`, `cb.fabric-cell-onboarding`, `cb.fabric-graph-onboarding`, +`cb.fabric-public-api-cleanup`, `cb.fabric-compiler-stress-test`, `cb.fabric-boundary-guardrails`, +`cb.fabric-runtime-front-end-handoff`, `cb.fabric-performance-loop`. diff --git a/skills/cb.fabric-parity-gate/SKILL.md b/skills/cb.fabric-parity-gate/SKILL.md index d12833a8..e491c60d 100644 --- a/skills/cb.fabric-parity-gate/SKILL.md +++ b/skills/cb.fabric-parity-gate/SKILL.md @@ -1,6 +1,6 @@ --- name: cb.fabric-parity-gate -description: Use when modifying or closing Cortical Fabric CUDA/backend paths and parity must be proven against the PyTorch Fabric reference across outputs, states, input/carry gradients, parameter gradients, single/mixed populations, small hidden widths, resets, T=1/T>1, direct/chunked paths, or planner-owned tiling. +description: Use when modifying or closing Cortical Fabric CUDA/backend paths or compiler extensions and parity must be proven against the PyTorch Fabric reference across outputs, states, input/carry gradients, parameter gradients, single/mixed populations, small hidden widths, resets, T=1/T>1, direct/chunked paths, or planner-owned tiling. --- # Fabric Parity Gate @@ -16,20 +16,47 @@ Fabric parity is a backend contract, not a smoke test. A path is not ready for p 0. Treat user corrections as general parity methodology unless explicitly narrowed. A missing parity surface found in one row means the matrix must be strengthened for every comparable backend path, not just that row. 1. Identify the active route and touched owners from runtime metadata: surface key, executor, transition buckets, message/readout owners, demotions, chunking/tape policy, and reset policy. -2. Build the smallest matrix that covers every changed execution path. Do not use one row to represent another if the route, population structure, hidden width, reset behavior, materialization mode, or backward owner differs. -3. Compare CUDA/backend against the PyTorch Fabric reference with identical weights, inputs, resets, and requested state materialization. -4. Assert all required parity surfaces: +2. 
For new ops/message rules/cells/readouts/graphs/strategies, first use `cb.fabric-compiler-boundary-audit`; for + semantic changes, also use `cb.fabric-declaration-onboarding`; then + identify the declaration/spec, primitive or graph rows, tensor bindings, verifier legality, forward executor, + backward executor, and reducer outputs under test. If those + compiler products do not exist, use `cb.fabric-primitive-op-onboarding`, `cb.fabric-graph-onboarding`, or + `cb.fabric-compiler-extension` before treating parity as meaningful. For readout/output-boundary changes, also use + `cb.fabric-readout-rule-onboarding`. For public API/Config/Blueprint cleanup, also use + `cb.fabric-public-api-cleanup` and prove the high-level declaration no longer reaches the old normalization choke + point. If they exist but the implementation strategy changed, use + `cb.fabric-throughput-strategy`. + If the change alters reducer inputs, reverse span outputs, native return groups, runtime buffers, or workspace + lifetime, also use `cb.fabric-reducer-liveness` and compare every touched parameter-gradient sink. + If the change alters public model forward preparation, input/output adapters, state initialization, boundary tensors, + reset normalization, sender K/V setup, or the call into registered temporal programs, also use + `cb.fabric-runtime-front-end-handoff` and prove the high-level user call reaches the same compiler-owned route. + If the change is a pre-throughput compiler reality test or formula perturbation, also use + `cb.fabric-compiler-stress-test` and require old-vs-new row/binding/route evidence before parity is meaningful. + If the change adds source/static guardrails or deleted-route checks, also use `cb.fabric-boundary-guardrails`; parity + still must prove active behavior. +3. Build the smallest matrix that covers every changed execution path. Do not use one row to represent another if the route, population structure, hidden width, reset behavior, materialization mode, or backward owner differs. +4. Compare CUDA/backend against the PyTorch Fabric reference with identical weights, inputs, resets, and requested state materialization. +5. Assert all required parity surfaces: - forward outputs - materialized final state and any exposed user state - input gradients - provided initial carry/state gradients where applicable - full parameter gradient key equality - full parameter gradient value closeness for every nonzero reference gradient -5. Fail on missing keys, new zero gradients, unexpected detached tensors, benchmark-owned tiling/detach helpers, or hidden fallback owners. -6. Only after parity is green, run performance profiles. If a performance change alters backend routing, rerun parity before accepting the numbers. -7. When a Fabric test weakness is found, audit the whole Fabric test surface, not just the failing test. At minimum run +6. Fail on missing keys, new zero gradients, unexpected detached tensors, benchmark-owned tiling/detach helpers, hidden fallback owners, or CUDA success where the declaration did not lower into the expected primitive rows. +7. Only after parity is green, run performance profiles. If a performance change alters backend routing, rerun parity before accepting the numbers. +8. When a Fabric test weakness is found, audit the whole Fabric test surface, not just the failing test. At minimum run the full `test_fabric*.py` suite and a static hygiene pass for unreachable assertions after direct returns. 
Stale or unreachable assertions are correctness bugs in the test matrix, not harmless cleanup. +9. For a semantic formula change, first prove the compiler changed: the same public declaration must produce different + primitive rows, tensor roles, attributes, parameter bindings, or strategy legality. Matching values through unchanged + rows usually means the CUDA path is still hiding the old built-in algorithm. +10. For throughput strategies, prove the opposite: the primitive rows and semantics did not change, only the registered + implementation strategy changed. If a throughput patch changes math, tensor roles, or declaration semantics, stop and + reroute it through `cb.fabric-compiler-extension`. +11. For native/fused strategy changes, prove the ABI did not invent semantics: any extra native outputs must be declared + by output/reducer/artifact rows or be metadata stripped before semantic consumption. ## Required Matrix @@ -46,6 +73,10 @@ Fabric parity is a backend contract, not a smoke test. A path is not ready for p | state materialization | materialized final state and no-final-state when the runtime supports both | | backend mode | direct and chunked/compact/streaming paths when planner may select either | +For Fixmass/REDO-style shared temporal engine work, the "large-param" representative surface must include +April21-shaped 100M/500M/1B rows, not only 1M smoke rows. Full T=1 closure also includes the April21 +100M/500M/1B matrix, high-batch small-param rows, small-hidden stress rows, reset/state axes, and mixed-pop rows. + ## Test Pattern Use helpers such as `_build_fabric_model_pair(...)`, `_param_grads(...)`, and runtime metadata assertions in `tests/test_fabric_runtime.py`. @@ -72,6 +103,31 @@ for name in cuda_grads: - A passing h=32 row does not cover small-h many-cell stress paths. - Single-population parity does not cover mixed-population parity while active executors differ, and the reverse is also true. - Full replay/autograd parity is reference evidence only; it does not prove the physical CUDA backward path. +- Parity is not sufficient when the implementation violates Fabric boundaries. For temporal-engine changes, require a + manual code-review note in the active progress doc proving the active CUDA path is generic over flat-bucket + tensor/op-table descriptors and contains no cell-family, population-name, benchmark-row, hidden-size policy, or + separate single/mixed route logic. +- The boundary review must cover both cell math and message-passing math. Passing parity through a facade is not closure: + if Q/K/V, dot-product attention, gated recurrence, diagonal recurrence, projections, normalization, or readout math are + hardcoded outside the corresponding `fabric.cuda.nn` primitive executor, keep the owner open even when values and + gradients match. +- Parity must compare the declared Fabric program, not merely the hidden algorithm that the backend chose to run. If the + public API accepts a generic-looking cell/message declaration but the active CUDA route ignores it or rewrites it into + one fixed built-in algorithm, parity only proves that built-in route and the generic owner remains open. +- Treat Fabric parity as compiler parity: the same declaration must lower into the same semantic program on CUDA and the + PyTorch reference. If the CUDA path bypasses declaration -> IR -> op/tensor rows -> primitive executor lowering, parity + does not close the compiler/backend owner. 
+- For a new executor strategy, parity must prove both semantics and strategy locality: the same primitive rows still + execute, the strategy owns only implementation choices, and temporal scheduler files did not gain primitive formulas or + strategy-specific route selectors. +- Source guardrails are not parity evidence. They must be paired with runtime CUDA/reference parity when behavior, + gradients, state, reducers, routes, or active execution ownership can change. +- For reducer/liveness changes, parity must prove that moving a tensor from returned output to reducer/runtime/workspace + ownership did not drop, duplicate, detach, or reorder gradients. Parameter-gradient keys and values are mandatory. +- For a new primitive/message/cell formula, parity must prove the opposite locality: the declaration and rows changed, + the reference executor defines the new math, and native strategies implement those rows without bypassing lowering. +- For route-aware outputs, parity must exercise producer routing. A single output/readout row does not prove multi-route + correctness unless compiler legality explicitly rejects multiple producers before launch. - If the active path demotes to replay, wrapper execution, or benchmark-owned tiling, record it as open backend work. - A compatibility or legacy route that passes parity is still open backend work if the intended planner-owned physical path exists or is being introduced. Parity proves correctness only; it does not justify retaining stale bridges or @@ -87,6 +143,13 @@ for name in cuda_grads: ## Integration -**Uses:** `cb.fabric-backend-boundaries`, `cb.fabric-performance-loop` +**Uses:** `cb.fabric-backend-boundaries`, `cb.fabric-compiler-boundary-audit`, `cb.fabric-performance-loop` -**Pairs with:** `t.run-tests` for broader test execution after targeted parity is green. +**Pairs with:** `cb.fabric-compiler-extension` and `cb.fabric-declaration-onboarding` for compiler products, +`cb.fabric-primitive-op-onboarding` for new primitive/message/readout/transition ops, +`cb.fabric-graph-onboarding` for graph/topology declarations, +`cb.fabric-readout-rule-onboarding` for output/readout rules, `cb.fabric-throughput-strategy` for throughput +strategies, `cb.fabric-compiler-stress-test` for pre-throughput formula-change locality tests, +`cb.fabric-public-api-cleanup` for public declaration cleanup, `cb.fabric-runtime-front-end-handoff` for public-call +handoff changes, `cb.fabric-boundary-guardrails` for source/static guardrails and deleted-route checks, `t.run-tests` +for broader test execution after targeted parity is green. diff --git a/skills/cb.fabric-performance-loop/SKILL.md b/skills/cb.fabric-performance-loop/SKILL.md index 601fcb01..30a6f6a3 100644 --- a/skills/cb.fabric-performance-loop/SKILL.md +++ b/skills/cb.fabric-performance-loop/SKILL.md @@ -1,6 +1,6 @@ --- name: cb.fabric-performance-loop -description: Use when profiling, benchmarking, optimizing, or investigating Cortical Fabric performance, throughput, scaling, CUDA kernel ownership, backward cost, profile artifacts, or benchmark regressions. +description: Use when profiling, benchmarking, optimizing, or investigating Cortical Fabric performance, throughput, scaling, CUDA kernel ownership, backward cost, profile artifacts, benchmark regressions, or new throughput executor strategies. 
--- # Fabric Performance Loop @@ -12,6 +12,35 @@ Use this skill for Fabric performance work where current evidence matters more t ## First Read - Read `src/cortical/fabric/README.md` for current architecture vocabulary. +- If the work will change compiler products, backend route ownership, CUDA/native strategies, or closure evidence, also + read `skills/cb.fabric-compiler-boundary-audit/SKILL.md` and complete its audit gate before accepting the result. +- If the work adds or changes primitive/message/readout/transition math, CUDA kernels, native callables, parameter + reducers, tensor bindings, or executor strategies, also read `skills/cb.fabric-compiler-extension/SKILL.md` before + editing. +- If profiling exposes missing or changed user-visible graph/cell/message/readout/primitive/reset semantics, also read + `skills/cb.fabric-declaration-onboarding/SKILL.md` and stop optimization until the declaration lowers into rows. +- If the prompt asks for a pre-throughput compiler reality test or formula perturbation, also read + `skills/cb.fabric-compiler-stress-test/SKILL.md`; do not run it as throughput optimization. +- If the work adds or improves a throughput executor, fused program kernel, memory/liveness policy, artifact policy, or + parameter-reduction strategy, also read `skills/cb.fabric-throughput-strategy/SKILL.md`. +- If the owner is a parameter reducer, reverse span output, native return group, workspace/lifetime edge, or runtime + buffer consumed by a reducer, also read `skills/cb.fabric-reducer-liveness/SKILL.md`. +- If the work changes CUDA/Triton/C++ native strategy bodies, native callable binding schemas, fused program kernels, or + kernel ABIs, also read `skills/cb.fabric-native-strategy-onboarding/SKILL.md`. +- If the performance owner needs a new semantic primitive, read `skills/cb.fabric-primitive-op-onboarding/SKILL.md` + before adding execution. +- If the performance owner changes readout/output-boundary semantics, output routes, pooling, or readout reducers, read + `skills/cb.fabric-readout-rule-onboarding/SKILL.md`. +- If the performance owner depends on graph/topology construction, factorization, boundary sets, degree buckets, + distance/delay facts, or lattice/config cleanup, read `skills/cb.fabric-graph-onboarding/SKILL.md`. +- If performance work exposes public API, Config, Blueprint, or declaration-normalization cleanup, read + `skills/cb.fabric-public-api-cleanup/SKILL.md`; cleanup must not become a benchmark-side or backend-policy shortcut. +- If profiling names a pre-temporal or pre-registered-entry owner, such as public model forward preparation, + input/output adapters, boundary tensor prep, state initialization, reset normalization, sender K/V setup, or call + handoff into the registered temporal program, read `skills/cb.fabric-runtime-front-end-handoff/SKILL.md` before + editing. +- If performance work adds source/static guardrails, no-fallback/no-compat checks, or legacy deletion checks, read + `skills/cb.fabric-boundary-guardrails/SKILL.md`; guardrails do not replace warmed current-code measurements. - Search current Fabric benchmark, profile, and perf artifacts: ```bash @@ -27,19 +56,82 @@ rg -l -i "fabric|profile|benchmark|PhysicalOpPlan|rejected|probe|fallback|hardco Use chats for process memory, not current benchmark truth. +## Hypothesis Probes + +Before spending a long pass on an optimization direction, run the cheapest representative experiment that can falsify it. 
+Write the probe in the progress doc before running it: + +```text +Hypothesis: +Expected owner movement: +Smallest representative row or synthetic fixture: +Artifact path: +Keep/narrow/revert rule: +Follow-up representative row if promising: +``` + +Good probes are small but real: fewer iterations, lower batch/params, forward-only when the owner is forward-only, +source instrumentation, allocator-stage telemetry, launch-count checks, or a single warmed row that shares the same +compiler route and owner as the larger target. The probe must use the normal high-level Fabric call path and active +compiler-owned backend route; it must not add benchmark-side tiling, private helper calls, shape/family branches, +fallback routes, or fake metadata. + +Synthetic fast experiments are encouraged when they answer a narrow mechanism question faster than an audit row. Valid +synthetic probes include tiny generated Fabric programs, toy graph/bucket shapes, synthetic tensors/weights/resets, +lower batch/params, isolated allocator/liveness instrumentation, launch-count checks, and one-off scratch scripts under +`tmp/` that exercise the suspected owner. Prefer a high-level `model(...)` synthetic mini row when route behavior matters. +If a lower-level fixture calls compiler/runtime internals directly, mark it as mechanism-only evidence and follow it with +a normal high-level row before accepting the direction. + +Use a probe to decide whether to keep investigating a direction, not to claim closure. If the expected owner does not +move in timing, launch count, storage identity, allocator telemetry, or memory stage evidence, record the probe as +rejected and stop expanding that direction. If it moves, run the matched representative warmed row before accepting the +change or citing performance progress. + ## Evidence Rules - Treat user corrections as general performance methodology unless explicitly narrowed. If the user rejects stale numbers, benchmark-owned tiling, local workaround churn, legacy paths, or narrow shape rules in one context, apply that rule to all Fabric performance work before accepting new evidence. +- For throughput prompt sequencing, do not optimize during the "deep dive/analyze" phase. That phase produces the + current-code owner table, baseline commands, active compiler products, and likely highest-impact owner. The "plan" + phase selects the strategy and gates. Implementation starts only when the user asks to proceed or the task explicitly + asks for code changes. +- Throughput work must optimize a verified compiler product. Before adding a kernel, native callable, launch path, or + parameter reducer, identify the primitive/executor/tensor-binding rows it implements. If those rows do not exist, stop + performance work and use `cb.fabric-compiler-extension` to add semantics and legality first. +- A semantic stress test is not a performance probe. If the task changes dot-product/message math, normalization, + gating, recurrence, readout, tensor roles, or reducer meaning to prove compiler locality, switch to + `cb.fabric-compiler-stress-test` and defer throughput metrics until row-delta and parity gates pass. +- In the "Proceed" phase, do not start by editing kernels. First write the boundary classifier result from + `cb.fabric-compiler-boundary-audit`: semantic extension, throughput strategy, evidence-only, or cleanup/deletion. A + performance owner that needs a new op, tensor role, artifact route, output route, or gradient contract is semantic + extension work, not throughput work. 
+- Treat April21 as a baseline to beat, not a source tree to copy. A performance idea from old code is acceptable only + after it is redesigned as a registered current compiler strategy with explicit legality, memory/liveness, forward, + backward, reducer, and audit contracts. - Do not hardcode benchmark commands, expected winners, profile rows, speedups, or kernel counts in this skill. - Before starting or resuming Fabric performance work, name the live progress doc that will carry the current checklist, commands, artifacts, and status. If no suitable doc exists, create or append a concise progress section in the relevant Fabric benchmark/design doc before running long benchmarks. - Progress belongs in docs, not chat memory. Keep the doc updated with: the current cleanup step, the owner being measured, commands run, artifact paths, accepted/rejected probes, stale or aborted runs, and the next closure gate. +- Prefer a fast hypothesis probe before a long representative audit when the result would decide whether an optimization + direction is worth continuing. The probe must be explicitly marked as steering evidence and paired with the + representative row before any acceptance or closure claim. +- Do not postpone audits until after a large backend rebuild. Full closure suites can wait for a credible owner, but + selected representative high-level audit rows must run throughout implementation to steer the work and reopen older + owners early. Treat these as guardrails, not closure evidence: T=1 April21-shaped training, T/H representative rows, + K sweep probes, reset parity, mixed-pop T=1, and small-hidden/shape guards should be rerun after temporal engine, + checkpoint, reverse-scan, bucket, or message changes that can affect them. +- Fast 1M/high-batch rows are guardrail probes only. They do not represent T=1 closure or the representative audit + frontier by themselves. Representative T=1 guardrails must also include April21-shaped 100M/500M/1B rows, especially + any current-code large-model rows that are failing, plus the matched batch/hidden/loss/reset/state contract. Full T=1 + closure still includes the April21 100M/500M/1B rows, small-hidden stress rows, high-batch small-param rows, + reset/state axes, and mixed-pop rows. - Performance evidence is invalid if the active planner/backend path has not passed `cb.fabric-parity-gate` against the - PyTorch reference. Before accepting a profile-driven change, run targeted parity for every affected execution mode: + PyTorch reference or if the compiler-boundary audit was skipped for a backend/compiler strategy change. Before + accepting a profile-driven change, run targeted parity for every affected execution mode: direct and chunked, reset and no-reset, materialized and unmaterialized final state, terminal and per-timestep sequence loss, `T=1` and `T>1`, small-hidden many-cell stress rows, large batch rows, and large-parameter rows where feasible. - Parity means outputs, exposed states, hidden/input/carry gradients, and parameter gradients all match. Missing gradient @@ -48,6 +140,13 @@ Use chats for process memory, not current benchmark truth. - If existing profiles are missing, stale, or not the active shape, discover the current command from scripts, tests, docs, or `--help`. - Historical Fabric profiles are context only. Before naming a row as the active largest owner, claiming a regression, claiming a fix, or closing a performance gate, rerun the relevant current-code row with a warmed confirmation profile. 
+- For a severe Fabric regression against an April21-style row, an audit summary is not enough. Run deeper current-code + profiles before structural CUDA edits: owner-timed audit JSONL, warmed repeated measurements, a PyTorch profiler or + CUDA-event timeline when available, and a focused kernel-level breakdown for the dominant Fabric kernel. Record every + artifact path in the active progress doc before using the profile to choose the next owner. +- If an owner appears hot only on the first measured backward in a fresh process or cache, run an in-process warmup and + measure the second pass before changing backend code. Cold Triton/CUDA extension specialization is a compile artifact, + not a Fabric throughput owner, unless it remains hot after the warmed pass. - When a current-code rerun contradicts a historical artifact, update the active plan/profile doc with the correction so later work does not loop back to the stale conclusion. - If an accepted current-code path already demonstrates the target scaling behavior, use it as the performance and @@ -59,7 +158,70 @@ Use chats for process memory, not current benchmark truth. - Use warmed repeated measurements before accepting a performance claim. - Do not run two performance/profile commands concurrently on the same GPU and then treat either result as evidence. Concurrent profiles are useful only when the commands are explicitly isolated to different devices. -- A metadata label is not evidence. The active path must physically move, or the named owner must shrink in the profile. +- When the user gives a GPU allocation for Fabric experiments, treat it as binding for the session and record it in the + active progress doc. Use private per-project or per-run `TORCH_EXTENSIONS_DIR` and `TRITON_CACHE_DIR` values for + CUDA/Triton work instead of the global cache; stuck or cross-contaminated extension caches are not valid performance + evidence. +- A metadata label is not evidence. The active owner must physically move in timing, launch count, storage identity, + allocator telemetry, or memory stage evidence, or the named owner must shrink in the profile. +- A source guardrail is not performance evidence. It can keep a stale route from returning, but accepted throughput work + still requires active owner movement in timing, launch count, allocator telemetry, or memory stages. +- For memory/liveness probes, report actual storage identity and lifetime movement, not only metadata. If a no-copy or + alias probe increases peak memory or creates a large unclassified allocator owner, stop adding aliases and identify the + owner first: compiler lifetime bug, autograd saved tensor, CUDA temporary/workspace, or allocator reserve gap. Record + whether the exact diff should be kept, narrowed to forward-only, or reverted before further code changes. +- For native-stage memory work, collect allocator telemetry inside the registered strategy if Python-side before/after + hooks cannot see the peak. The telemetry must be metadata, and the semantic pybind return must remain compiler-row + owned. Use the named native stage to choose the next liveness/reducer/workspace patch. +- For reducer/lifetime memory work, the accepted change must show which tensor moved from returned span output to reducer + input, carry/state buffer, workspace, artifact, metadata, or drop-after-use. If the hot native stage only changes labels + and the allocator owner does not move, reject the patch. 
+- For front-end handoff memory work, separate public adapter tensors from backend runtime tensors. If a pre-registered + owner is really boundary/state/sender-KV setup for the temporal program, move it into compiler-owned runtime buffers, + liveness rows, or a registered strategy; do not hide it with benchmark-side chunking, private helper calls, old Config + defaults, or shape/family branches. +- The first performance priority for fixmass-style Fabric work is the shared temporal engine. Optimize and migrate the + throughput-critical forward/backward temporal owner before spending engineering cycles on cleanup that does not move + active kernels, launch count, memory, or measured tok/s. +- If the current high-priority owner is a Python/host scan loop, K/H loop, backward replay loop, or missing temporal + superop, write the CUDA kernel or backend superop needed to move that owner. Do not defer kernel implementation behind + planner-only edits, metadata relabeling, benchmark cleanup, or route deletion when the kernel is the next bottleneck. +- Kernel priority never overrides Fabric design. A temporal superop is acceptable only after its ABI is generic over + flat-bucket tensor/op tables lowered from `fabric.cuda.nn`/IR. Reject kernels that encode gated/diagonal/sLSTM/Axon + semantics, population names, benchmark rows, or hidden-size policy inside the shared temporal engine, even if they + appear to improve the immediate throughput owner. +- A throughput strategy is a replaceable implementation of lowered primitive rows, not a new semantic authority. It must + declare legality, tensor roles, layouts, workspace/liveness needs, forward executor, backward executor, reducer + outputs, determinism/tolerance class, audit metadata, and typed rejection reasons. Cost/ranking may choose among legal + strategies; it must not decide correctness. +- Do not optimize throughput by copying declared cell/message primitive math into temporal scan or reverse kernels. + Temporal kernels schedule time, horizon, reset, checkpoint, materialization, and tensor/op-row dependencies. Primitive + math belongs in `fabric.cuda.nn`/lowered primitive executors selected by tensor-table/op-row metadata. A probe that + inlines gated recurrence, diagonal recurrence, attention, normalization, projection, or readout formulas in the + temporal engine is invalid performance evidence and must be reverted before further profiling. +- The same facade rule applies to cells and message passing. A message rule that looks generic in `fabric.cuda.nn` is not + performance evidence if the measured backend still hardcodes Q/K/V, dot-product attention, or fixed aggregation inside + the temporal engine. Q/K/V-like names may appear only as roles of a declared message primitive and its executor, not as + temporal scheduler concepts. +- Do not accept a fake `fabric.cuda.nn` path as performance evidence. A row only counts if the declaration lowered into + IR, primitive op rows, tensor-table roles, parameter bindings, and primitive executor dispatch. If `fabric.cuda.nn` + appears only as metadata while throughput comes from hardcoded backend formulas or cell-specific temporal branches, + mark the run invalid and reopen the backend boundary owner. +- Fabric performance claims must measure the declared Fabric program, not a hidden canonical algorithm. 
A PyTorch-style + library cannot claim generic throughput if the benchmark requested one `fabric.cuda.nn` program and the backend ran a + fixed dot-product/gated/diagonal/readout route because that is the only implemented path. Treat that result as an + unsupported-declaration or facade failure, not as performance evidence. +- Treat Fabric as a compiler when profiling. The measured path must prove that the declared program compiled into the + runtime op/tensor tables and primitive executors that produced the throughput. If the compiler boundary is skipped, + the profile is measuring an implementation shortcut, not Fabric. +- Before treating a new temporal kernel as performance progress, do a manual code-review audit and log it in the active + progress doc: ABI shape, tensor-table/op-table ownership, reset/checkpoint/materialization ownership, active-route + metadata, absence of cell-family or single/mixed-pop routing, and proof that `fabric.cuda.nn` declarations truly lower + through tensor-table/op-row primitive executors rather than acting as a facade over hardcoded internals. Passing parity + or faster tok/s does not rescue a boundary violation. +- Final fixmaass-style temporal closure requires runtime/audit metadata to report compiler-owned registered temporal + forward and backward owners for supported training rows. Any `python_autograd_scan` or replay/fallback route is + forbidden legacy code and must fail closed even if parity and ordinary reference gates pass. - A regression does not change the target design or refactor goal. Keep the same architectural objective, diagnose the concrete owner, iterate within that goal, or reject the probe with evidence. - Do not start with a fast throughput patch when a regression contradicts an accepted Fabric invariant. First identify the planner/runtime decision that changed, state the generic invariant it violated, and then fix the shared decision @@ -69,15 +231,33 @@ Use chats for process memory, not current benchmark truth. performance work. If the same bridge/wrapper/replay owner remains dominant after a probe, the next owner is the framework-level planner/executor replacement, followed by deletion of the stale path after parity and current-code non-regression pass. +- Benchmark and audit code is a high-level Fabric API consumer, not a planner/backend design surface. Closure evidence + must come from ordinary user-style calls: `output = model(x, ...)`, an external loss, `loss.backward()`, and optimizer + step when applicable. This applies to T=1, T>1, K>1, T*K, horizon-H, terminal loss, and per-timestep loss rows. + Benchmarks may report backend metadata after the call, but must not call private planner/runtime helpers, implement + temporal chunking, choose workspace/tape/checkpoint policy, detach carry, or run per-chunk backward loops to make a + Fabric row pass. ## Scaling Rules - Fabric scales through cell count, topology size, batch/parallel rollout width, streaming or sequence time, and future population/bucket structure. -- Treat the audit frontier as `B x params x h x graph x T`. Batch and parameter count are the main throughput axes, but - hidden size is part of the correctness/stress space: `h=32` may be the headline baseline while `h=4/8/16` small-hidden - rows remain required many-cell scaling guards. Do not drop or explain away small-hidden failures as nonrepresentative - without a current-code generic owner analysis. -- Include inner recurrent steps `K` when temporal planning changes. 
K>1 should reuse the same backend-owned temporal +- Treat the Fabric parallelism/scaling frontier as spatial axes plus temporal streaming axes. `B`, params, graph/node + count, topology, population buckets, and `h` expose parallel/spatial work. `T` and `K` extend the same streaming + substrate over time; `T*K` is repeated T=1 execution under one backend-owned temporal engine, with different output + materialization and checkpoint/recompute policy, not a separate algorithm. +- T=1 health is the base performance and semantic contract for the shared temporal engine. Do not treat K, H, or long-T + work as independent closure while matched T=1 training is regressed; T/H/K profiles are diagnostic until the T=1 path + is healthy on the same graph/batch/params/hidden/population/loss-boundary contract. +- T=1 throughput closure should start with the mixed-population path when it exposes the real shared-engine owner, while + still keeping single-population rows as same-engine guardrails. Do not create separate single-pop and mixed-pop + performance fixes. +- Do not claim T=1 health from a small 1M row. The April21-shaped 100M/500M/1B rows are part of the minimum + representative surface for backend work because they expose memory/parameter-gradient behavior that small guardrails + miss. +- Batch and parameter count are the main throughput axes, but hidden size is part of the correctness/stress space: + `h=32` may be the headline baseline while `h=4/8/16` small-hidden rows remain required many-cell scaling guards. Do + not drop or explain away small-hidden failures as nonrepresentative without a current-code generic owner analysis. +- Include inner recurrent steps `K` when temporal planning changes. K>1 must reuse the same backend-owned temporal executor as K=1, with `inner_steps=K` recorded in metadata. Performance probes should compare K=1/2/4 where relevant, cover reset/no-reset, and verify throughput/memory movement comes from named temporal/message/transition/readout owners rather than benchmark-side repeated calls. @@ -87,6 +267,11 @@ Use chats for process memory, not current benchmark truth. workspace behavior in the benchmark harness to make a row fit. The harness passes the requested tensors and reports backend/runtime metadata; tiling, chunking, rolling tape, large-R decomposition, and workspace policy must be selected inside the backend. +- Do not infer or declare the root cause of a T/H/K throughput or memory failure from code inspection, metadata, or + small `B=1` smoke probes while the matched April21-shaped T=1 training row has not passed. First run the matched + current-code T=1 baseline on the same graph/batch/params/hidden/population/loss-boundary contract, record the exact + April21 reference key, owner metadata, throughput, and peak memory, then decide whether the owner is already a T=1 + training regression or a larger-T/H-specific regression. - Prefer shape-size agnostic optimizations driven by runtime shape, topology, bucket, layout, reset, dtype, weight-sharing, and workspace metadata. - Treat shape/factorization audits as flat-graph audits. Before claiming factorization sensitivity or invariance, confirm that flat graph invariants match: active nodes, input/output boundary counts, degree histogram, edge buckets, group @@ -114,6 +299,27 @@ Use chats for process memory, not current benchmark truth. The backend/runtime must make that regular path use bounded streaming tape internally. 
If a result required a benchmark-side streaming loss hook, per-chunk backward, or harness-side time tiling, mark it invalid and rerun before citing it. +- Training itself is always streamed in Fabric. Sequence-loss training can + produce a dense user-visible output tensor, but internal state, message, + transition tape, K/V, readout, and boundary adjoints must be bounded by the + planner's emission/checkpoint/recompute/window policy over `T*K`. A profile + that allocates full `[T, B, cells, ...]` internal artifacts for training is a + streaming-liveness regression unless that tensor is the requested user output + or an explicitly recorded compact checkpoint artifact. +- The April26 target shape is: forward streams the planned substrate once, + saves only compact checkpoints selected by `CheckpointPlan`, then ordinary + `loss.backward()` walks `BackwardWindowPlan` right-to-left, injects output + gradients at emission steps, runs physical adjoints, accumulates gradients, + and discards segment-local artifacts. Use this as the default interpretation + when diagnosing T/H/K throughput and memory. +- T*K and finite-horizon H audits follow the same rule. The benchmark can request the public semantic inputs that the + model exposes, but the actual scan schedule, emission plan, rolling horizon, checkpoints, and materialization policy + must be backend-owned and invisible to the measurement path except as recorded metadata. +- For fixmass temporal work, performance changes must preserve the April26-style + planning contract: output request, autograd seed surface, H window, checkpoint + stride, recompute window, and reverse artifact kind are planner decisions over + the whole `T*K` stream. Do not accept throughput from a one-step-only, + terminal-only, K-only, or benchmark-row materialization route. - Compare `T>1` throughput to the current-code Fabric `T=1` per-token line on the same graph/batch/parameter row. Fabric throughput should stay flat or increase as `T` grows. Stack comparison remains useful row context, but it is not the large-`T` pass/fail criterion. A Fabric slowdown is an open owner until warmed evidence attributes it to a named @@ -155,18 +361,31 @@ Use chats for process memory, not current benchmark truth. ## Loop 1. Identify the current performance question and the active execution mode. -2. Update the live progress doc with the current step order before running benchmarks, especially when cleanup work is +2. If the change adds/improves a throughput executor strategy, run the throughput-strategy checklist first: semantics + already lower to primitive rows, legality is explicit, forward/backward/reducer coverage exists, memory/artifact rows + are compiler-owned, and replaced routes are deleted or fail closed. If native code changes, also run the native + strategy checklist before writing kernels. +3. If the profile suggests adding a missing op, tensor role, output route, artifact route, reducer, reset behavior, or + backward contract, stop the performance loop and switch to compiler-extension work. Do not implement the missing + semantic product as a faster kernel patch. +4. Update the live progress doc with the current step order before running benchmarks, especially when cleanup work is in progress and could otherwise be forgotten. -3. Use existing artifacts only to choose the confirmation target, then rerun the relevant current-code row before making +5. 
Use existing artifacts only to choose the confirmation target, then rerun the relevant current-code row before making
   owner or closure claims.
-4. Name the dominant owner in concrete backend terms: message, projection, receiver affine, state epilogue, readout, backward glue, layout/copy, or another live owner found in code/profiles.
-5. Make one focused probe or structural change aimed at that owner.
-6. Re-run the closest warmed active-path profile.
-7. If it regresses, do not redefine success or switch goals. Attribute the regression, adjust the implementation inside the same target design, and rerun; if the approach is wrong, reject the probe and record why.
-8. Accept only if the owner moves without violating Fabric semantics or shape-general scaling goals.
-9. Update the progress doc before returning: record status, artifact paths, numbers that are current, stale corrections,
+6. Name the dominant owner in concrete backend terms: message, projection, receiver affine, state epilogue, readout, backward glue, layout/copy, or another live owner found in code/profiles.
+7. Make one focused probe or structural change aimed at that owner.
+8. Re-run the closest warmed active-path profile.
+9. If it regresses, do not redefine success or switch goals. Attribute the regression, adjust the implementation inside the same target design, and rerun; if the approach is wrong, reject the probe and record why.
+   If a low-level kernel path is parity-clean but slower on the warmed active row, keep it disabled on the active path
+   unless another required row proves it is needed, and record that rejection in the progress doc.
+   If a hot temporal reverse edge already has CUDA kernels but still launches adjacent projection, layout, or
+   materialization kernels per physical step, prefer fusing those backend-owned pieces in the same CUDA window path
+   before chasing planner labels or benchmark harness changes. Count the slice only if parity stays green and the
+   warmed owner moves.
+10. Accept only if the owner moves without violating Fabric semantics or shape-general scaling goals.
+11. Update the progress doc before returning: record status, artifact paths, numbers that are current, stale corrections,
    cleanup items still pending, and the next single closure gate.
-10. Record stale-profile corrections and rejected probes when they are likely to be repeated.
+12. Record stale-profile corrections and rejected probes when they are likely to be repeated.

## Status Format

@@ -191,4 +410,13 @@ Scaling impact:

## Integration

-**Pairs with:** `cb.fabric-backend-boundaries` for backend execution changes, `cb.fabric-cell-boundaries` when a performance issue appears to originate in cell-local math.
+**Pairs with:** `cb.fabric-backend-boundaries` and `cb.fabric-compiler-boundary-audit` for backend execution changes, +`cb.fabric-throughput-strategy` for registered strategy optimization, `cb.fabric-native-strategy-onboarding` for native +strategy implementation, `cb.fabric-compiler-extension` and `cb.fabric-primitive-op-onboarding` for new +primitive/message/readout/transition math, +`cb.fabric-declaration-onboarding` for public semantic changes, `cb.fabric-readout-rule-onboarding` for output/readout +owners, `cb.fabric-graph-onboarding` for graph/topology or lattice/config owners, `cb.fabric-public-api-cleanup` for +declaration/API cleanup, `cb.fabric-runtime-front-end-handoff` for public-call preparation and pre-registered-entry +owners, `cb.fabric-compiler-stress-test` for pre-throughput formula-change locality tests, +`cb.fabric-boundary-guardrails` for source/static guardrails and deletion checks, +`cb.fabric-cell-boundaries` when a performance issue appears to originate in cell-local math. diff --git a/skills/cb.fabric-primitive-op-onboarding/SKILL.md b/skills/cb.fabric-primitive-op-onboarding/SKILL.md new file mode 100644 index 00000000..aa7b5b9b --- /dev/null +++ b/skills/cb.fabric-primitive-op-onboarding/SKILL.md @@ -0,0 +1,179 @@ +--- +name: cb.fabric-primitive-op-onboarding +description: Use when adding, changing, or porting a Cortical Fabric primitive op, fabric.cuda.nn operator, transition/readout/message primitive, primitive lowering row, tensor-role binding, native callable, backward adjoint, tape contract, or parameter-gradient reducer. +--- + +# Fabric Primitive Op Onboarding + +Use this skill when the semantic unit itself changes. A primitive op is a compiler product, not a temporal-engine +shortcut. + +**Announce at start:** "Using Fabric primitive-op onboarding. I'll add the op through declaration, lowering, rows, bindings, legality, and executors before touching performance." + +## First Read + +- `skills/cb.fabric-backend-boundaries/SKILL.md` +- `skills/cb.fabric-compiler-boundary-audit/SKILL.md` +- `skills/cb.fabric-compiler-extension/SKILL.md` +- `skills/cb.fabric-compiler-stress-test/SKILL.md` when changing an existing primitive as a pre-throughput compiler + reality test. +- `skills/cb.fabric-declaration-onboarding/SKILL.md` when the primitive is exposed as user-visible semantics or changes + a declared formula. +- `skills/cb.fabric-graph-onboarding/SKILL.md` if the primitive consumes graph/topology facts such as edge rows, + segments, degree buckets, distance, delay, groups, or boundary/port sets. +- `skills/cb.fabric-readout-rule-onboarding/SKILL.md` if the primitive is introduced for readout/output-boundary + semantics. +- `skills/cb.fabric-native-strategy-onboarding/SKILL.md` before implementing a CUDA/Triton/C++ native strategy for the + primitive. +- `skills/cb.fabric-boundary-guardrails/SKILL.md` when adding source/static checks for primitive row ownership, no + scheduler-owned formulas, deleted wrappers, or no hidden fallback. +- Relevant live primitive registries, lowering, executor, and tests: + +```bash +rg --files src/cortical/fabric | rg "cuda/nn|primitive|native_callable|executor_patterns|executor_bindings|parameter_reducer|transition_execution|sequence_surface/compiler" +rg -n "Primitive|primitive_row|primitive_opcode|native_callable|tensor_binding|parameter_reducer|missing_executor|unsupported" src/cortical/fabric tests +``` + +Use docs for intent only. Verify the active lowering and runtime path in code. 
+ +## Required Chain + +Every supported primitive must have this chain: + +```text +public declaration / fabric.cuda.nn op + -> semantic IR or backend spec + -> primitive registry definition + -> primitive op rows + -> tensor role rows and parameter binding rows + -> verifier / legality / typed blocker metadata + -> reference executor + -> optional native/fused strategy + -> backward/tape/recompute contract + -> reducer rows for executable trainable parameters + -> parity + source guardrails +``` + +Changing the declaration must either change the compiled rows/bindings or fail closed before launch. + +Before writing native or scheduler-adjacent code for a primitive, fill this out: + +```text +Semantic attributes: +Input/output/state roles: +Parameter roles: +Forward reference: +Backward/reference gradient: +Reducer rows: +Native strategy legality: +Rows that prove this is a semantic change: +``` + +If `Rows that prove this is a semantic change` is empty, do not patch native math. Either the declaration did not lower +or this is actually throughput work over existing semantics. + +## Boundary Rules + +- Do not implement primitive formulas inside temporal scan/reverse, scheduler, benchmark, config, or route-selection code. +- Do not add fixed slot enums or top-level temporal ABI names for the new op. Tensor roles come from the primitive's + declared inputs, outputs, attributes, and parameter bindings. +- Do not key legality or executor selection on cell family, message-rule name, benchmark row, hidden-size constants, + single/mixed population labels, or old tensor names. +- Composite primitives are valid only when they are explicit semantic ops with declared attributes, tensor roles, + forward/backward coverage, and fail-closed unsupported cases. +- A schema entry is not automatically a parameter-gradient row. Reducers must be derived from binding-owned executable + trainable parameters. +- Reference/PyTorch execution defines truth; native CUDA/Triton strategies are replaceable implementations of the same + primitive rows. +- The locality test for a new op is strict: after the compiler scaffolding exists, adding or changing op math should be + local to public declaration/spec metadata, primitive lowering, reference execution, optional native strategy, + backward/reducer contracts, verifier rules, and tests. It should not touch temporal scheduler policy, graph + constructors, fixed slot enums, or monolithic scan/reverse program ABIs. + +## Implementation Steps + +1. Define the public op/declaration and normalized semantic attributes. +2. Register the primitive with stable opcode/name/schema version, explicit input/output/state/attribute roles, and a + reference execution contract. +3. Lower the declaration into primitive rows and tensor/parameter binding rows. +4. Add verifier rules for shape, dtype, layout, reset/tape/materialization, aliasing, and unsupported attributes. +5. Add a reference executor and backward/reference gradient behavior. +6. Add optional native/fused executor metadata only after `can_implement` legality is explicit. Use + `cb.fabric-native-strategy-onboarding` before writing native kernels or binding schemas. +7. Add backward/tape/recompute rows and parameter reducers for every trainable binding. +8. Update audit metadata so the active path reports the primitive/executor owner and typed rejection reasons. +9. Delete or fail-close any old direct wrapper/legacy path that the primitive replaces. 
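+
+As a shape-only picture of the chain and of steps 1-2, a primitive registration can be read as one record carrying
+roles, reference executors, reducer coverage, and typed blockers. Every name below (`TensorRole`, `PrimitiveDef`,
+`register_primitive`) is hypothetical; the sketch illustrates the contract, not the actual Fabric registry API:
+
+```python
+from dataclasses import dataclass
+from typing import Callable, Optional
+
+@dataclass(frozen=True)
+class TensorRole:
+    name: str              # e.g. "message_in", "state", "proj_weight" -- illustrative names only
+    kind: str              # "input" | "output" | "state" | "parameter"
+    trainable: bool = False
+
+@dataclass(frozen=True)
+class PrimitiveDef:
+    opcode: str
+    schema_version: int
+    roles: tuple[TensorRole, ...]
+    reference_forward: Optional[Callable] = None     # PyTorch truth for forward
+    reference_backward: Optional[Callable] = None    # PyTorch truth for gradients
+    reducer_roles: tuple[str, ...] = ()              # trainable bindings that own gradient reducers
+    can_implement_native: Optional[Callable] = None  # explicit legality for any native strategy
+
+def register_primitive(registry: dict, op: PrimitiveDef) -> None:
+    """Hypothetical registration: missing coverage is a typed blocker, not a silent default."""
+    if op.reference_forward is None or op.reference_backward is None:
+        raise ValueError(f"unsupported primitive {op.opcode}: missing reference executor")
+    missing = [r.name for r in op.roles if r.trainable and r.name not in op.reducer_roles]
+    if missing:
+        raise ValueError(f"unsupported primitive {op.opcode}: trainable roles without reducer rows: {missing}")
+    registry[(op.opcode, op.schema_version)] = op
+```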
+ +Before native/CUDA work for the op, record: + +```text +Op declaration/spec: +Primitive opcode/schema version: +Input/output/state/attribute roles: +Parameter bindings and reducers: +Reference executor: +Native strategy, if any: +Backward/tape/recompute contract: +Memory/artifact/output routes, if any: +Unsupported typed blockers: +Files that should not change if the boundary holds: +``` + +If this list cannot be completed, do not write a native kernel yet. + +## Formula Change Stress Test + +When an existing formula changes, such as adding normalization, gating, context terms, or a new reduction to a message or +transition op, use `cb.fabric-compiler-stress-test` and treat it as a compiler stress test: + +- the public declaration/spec must carry the new semantic fields; +- lowering must emit different primitive rows, tensor roles, attributes, or parameter bindings; +- the reference executor must define forward and backward truth; +- unsupported CUDA/native coverage must fail closed with typed blockers; +- native/fused strategies may be added only after the reference and legality contracts exist; +- parity must prove the declared program changed, not just that a hidden built-in formula still matches. + +## Required Tests + +- Lowering golden: op attributes and tensor roles produce distinct primitive rows/bindings. +- Negative legality: unsupported attributes/layout/dtype/reset/tape fail closed before launch. +- Reference parity: outputs, state, input/carry gradients, and parameter gradients. +- Native/fused parity if a CUDA/Triton strategy is registered. +- Reducer coverage for every trainable parameter output. +- Source guardrail through `cb.fabric-boundary-guardrails`: primitive rows/bindings are consumed and primitive formulas + do not appear in temporal scheduler/scan/reverse/benchmark files. +- High-level smoke: user-style model forward, external loss, `loss.backward()`. +- Locality/source guard: adding the primitive did not require editing temporal scheduler formulas, fixed slots, + benchmark harness policy, or cell/message family selectors. + +## Closeout Questions + +Record these in the active progress doc: + +- Which declaration/spec changed? +- Which primitive rows and tensor bindings changed? +- Which verifier/legality failures were added? +- Which forward, backward, tape/recompute, and reducer contracts own the op? +- Which old wrappers or fixed slots were deleted or invalidated? +- What files did not need to change because the compiler boundary held? + +If adding the op required editing monolithic scan/reverse ABIs or temporal scheduler formulas, compiler closure is not +holding; reopen the boundary and refactor before claiming support. + +The easiest new-op review is the locality review: a simple op addition should not require touching unrelated graph +constructors, benchmark code, temporal loop ownership, fixed slot enums, or family-specific dispatch. If it does, record +the compiler boundary that failed and close that first. + +## Integration + +**Uses:** `cb.fabric-backend-boundaries`, `cb.fabric-compiler-boundary-audit`, `cb.fabric-compiler-extension`, +`cb.fabric-parity-gate`. 
+ +**Pairs with:** `cb.fabric-declaration-onboarding` for public semantic changes, `cb.fabric-throughput-strategy` when +optimizing the op, `cb.fabric-native-strategy-onboarding` for native/fused strategy implementation, +`cb.fabric-graph-onboarding` when the op consumes graph/topology facts, +`cb.fabric-message-rule-onboarding` when the op is introduced for message semantics, +`cb.fabric-readout-rule-onboarding` when the op is introduced for readout/output semantics, +`cb.fabric-cell-onboarding` when the op is introduced for a cell transition, +`cb.fabric-compiler-stress-test` when the op change is a pre-throughput locality proof, +`cb.fabric-boundary-guardrails` for source/static guardrails and deleted primitive-route checks. diff --git a/skills/cb.fabric-public-api-cleanup/SKILL.md b/skills/cb.fabric-public-api-cleanup/SKILL.md new file mode 100644 index 00000000..c1e592ba --- /dev/null +++ b/skills/cb.fabric-public-api-cleanup/SKILL.md @@ -0,0 +1,118 @@ +--- +name: cb.fabric-public-api-cleanup +description: Use when removing, refactoring, or replacing Cortical Fabric public declaration surfaces such as Config, Blueprint, graph constructors, population/cell/message/readout configuration, planner requests, compatibility constructors, or old normalization choke points. +--- + +# Fabric Public API Cleanup + +Use this skill when the public Fabric API or declaration model changes. Public API cleanup is compiler-front-end work: +it must produce normalized declarations that lower into rows and bindings, not compatibility wrappers around old config. + +**Announce at start:** "Using Fabric public API cleanup. I'll remove old declaration choke points and route user-facing fields to explicit compiler-owned specs." + +## First Read + +- `skills/cb.fabric-workflow-router/SKILL.md` +- `skills/cb.fabric-compiler-boundary-audit/SKILL.md` +- `skills/cb.fabric-compiler-extension/SKILL.md` +- `skills/cb.fabric-declaration-onboarding/SKILL.md` +- `skills/cb.fabric-graph-onboarding/SKILL.md` +- `skills/cb.fabric-cell-onboarding/SKILL.md` and `skills/cb.fabric-cell-boundaries/SKILL.md` for population/cell fields +- `skills/cb.fabric-message-rule-onboarding/SKILL.md` for message fields +- `skills/cb.fabric-readout-rule-onboarding/SKILL.md` for output/readout fields +- `skills/cb.fabric-runtime-front-end-handoff/SKILL.md` if cleanup changes public model forward preparation, state + initialization, boundary tensor construction, reset normalization, sender K/V setup, or the call into registered + temporal programs. +- `skills/cb.fabric-boundary-guardrails/SKILL.md` for source/static checks that old Config/Blueprint/choke-point routes + cannot feed backend execution again. 
+- Live public API, docs, tests, and backend callsites: + +```bash +rg --files src/cortical/fabric tests | rg "config|blueprint|graph|message|readout|cell|planner|runtime" +rg -n "Config|Blueprint|_blueprint_to_config|width|height|coord|input_band|output_band|message|readout|planner" src/cortical/fabric tests +``` + +## Target Shape + +User-facing inputs must normalize into explicit owners: + +```text +GraphSpec / graph facts +PopulationSpec / CellSpec +MessageRuleSpec +ReadoutRuleSpec / output request +ResetSpec +PlannerRequest / backend preferences +``` + +Those specs then lower through compiler products: + +```text +normalized declarations -> semantic IR/specs -> graph/primitive rows -> tensor/parameter/artifact/output bindings +-> verifier blockers -> reference executors -> optional registered native strategies +``` + +## Cleanup Rules + +- Do not keep old `Config` or `_blueprint_to_config` as the hidden source of truth under a new wrapper. Update callsites + to consume the normalized declaration owners directly, then delete the obsolete path. +- Do not replace old config paths with `RuntimeError` bridges. Unsupported public declarations should fail closed through + compiler legality, with typed blockers, before backend launch. +- Do not put backend policy in public API cleanup. Graph constructors declare graph facts; message/cell/readout + declarations declare semantics; planner request types declare user preferences. Strategy selection, liveness, + checkpointing, scheduling, and workspace policy remain backend/compiler products. +- Legacy lattice fields such as width/height/coord/bands may survive only as constructor inputs that immediately lower to + graph facts. Backend execution must not read them for route selection or kernel policy. +- If cleanup exposes a missing primitive op, message rule, readout route, cell transition, graph fact, reducer, or + backward contract, stop cleanup and route that missing semantic product through the matching onboarding skill. +- A public API change must either change the compiled rows/bindings/routes or fail closed. It must not silently collapse + into the old built-in message, cell, readout, or planner route. +- Public cleanup must not turn old fields into hidden compatibility defaults. For each removed or migrated field, add a + test or source guard proving backend execution consumes the new graph/cell/message/readout/reset/planner spec and not + the old broad config object. +- Public cleanup that changes the high-level `model(...)` handoff must preserve the front-end/backend split: public + adapters and input normalization may stay in the front end, but backend runtime buffers, liveness, strategy selection, + reset policy, artifact policy, and sender K/V workspaces must be compiler-owned rows. + +## Work Loop + +1. Inventory each public field and assign it to graph, cell, message, readout, reset, initialization, or planner owner. +2. Add or update the normalized declaration/spec object for each owner. +3. Update lowering so declarations produce compiler rows and typed blockers directly, without old config translation. +4. Migrate callsites and tests to the normalized declarations. +5. Delete the old broad config/choke-point route once no supported path uses it. +6. Add source guardrails so backend execution cannot regain public config route selectors. 
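+
+A minimal sketch of what steps 1-3 produce, with hypothetical spec and field names rather than the real Fabric types:
+the broad legacy config is split once into explicit owners, unknown fields fail closed, and nothing downstream reads
+the legacy object again:
+
+```python
+from dataclasses import dataclass
+
+@dataclass(frozen=True)
+class GraphSpec:
+    node_count: int
+    edges: tuple[tuple[int, int], ...]       # flat graph facts; no lattice width/height policy
+    input_boundary: tuple[int, ...]
+    output_boundary: tuple[int, ...]
+
+@dataclass(frozen=True)
+class MessageRuleSpec:
+    rule: str                                # declared message semantics; placeholder field
+    params: dict
+
+@dataclass(frozen=True)
+class PlannerRequest:
+    device: str = "cuda"
+    dtype: str = "float32"                   # user preference only; strategy choice stays backend-owned
+
+def normalize_legacy_config(cfg) -> tuple[GraphSpec, MessageRuleSpec, PlannerRequest]:
+    """Hypothetical one-way translation: unknown legacy fields are typed blockers, not hidden defaults."""
+    known = {"nodes", "edges", "inputs", "outputs", "message_rule", "message_params", "device", "dtype"}
+    unknown = set(vars(cfg)) - known
+    if unknown:
+        raise ValueError(f"unsupported legacy config fields: {sorted(unknown)}")
+    return (
+        GraphSpec(cfg.nodes, tuple(map(tuple, cfg.edges)), tuple(cfg.inputs), tuple(cfg.outputs)),
+        MessageRuleSpec(cfg.message_rule, dict(cfg.message_params)),
+        PlannerRequest(cfg.device, cfg.dtype),
+    )
+```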
+ +Before editing, record: + +```text +Old public field/path: +New declaration/spec owner: +Rows/bindings/routes produced: +Verifier blocker for unsupported use: +Backend callsites migrated: +Old path deleted: +Tests/source guardrails: +``` + +If an old public field cannot be assigned to an owner, delete it or make its use fail closed. Do not keep it as an +unowned convenience flag. + +## Required Tests + +- Public construction test proving user graph/cell/message/readout declarations lower without the old config choke point. +- Lowering golden for at least one changed declaration owner. +- Negative legality test for an unsupported declaration. +- Source guardrail through `cb.fabric-boundary-guardrails`: backend/runtime strategy code consumes normalized + declaration owners and does not read old Config/Blueprint graph/message/cell/readout fields for execution policy. +- High-level smoke through `model(...)`, external loss, and `loss.backward()` when training behavior is affected. + +## Integration + +**Uses:** `cb.fabric-workflow-router`, `cb.fabric-compiler-boundary-audit`, `cb.fabric-compiler-extension`, +`cb.fabric-declaration-onboarding`. + +**Pairs with:** `cb.fabric-graph-onboarding`, `cb.fabric-cell-onboarding`, `cb.fabric-cell-boundaries`, +`cb.fabric-message-rule-onboarding`, `cb.fabric-readout-rule-onboarding`, `cb.fabric-primitive-op-onboarding`, +`cb.fabric-throughput-strategy`, `cb.fabric-runtime-front-end-handoff`, `cb.fabric-boundary-guardrails`, and +`cb.fabric-parity-gate`. diff --git a/skills/cb.fabric-readout-rule-onboarding/SKILL.md b/skills/cb.fabric-readout-rule-onboarding/SKILL.md new file mode 100644 index 00000000..f74b4c81 --- /dev/null +++ b/skills/cb.fabric-readout-rule-onboarding/SKILL.md @@ -0,0 +1,146 @@ +--- +name: cb.fabric-readout-rule-onboarding +description: Use when adding, changing, or porting a Cortical Fabric readout rule, output-boundary declaration, pooling/output route, readout primitive lowering, readout native CUDA strategy, readout backward adjoint, or readout parameter-gradient reducer. +--- + +# Fabric Readout Rule Onboarding + +Use this skill for Fabric readout and output-boundary semantics. Readout rules are user-declared compiler products, not +fixed output projection shortcuts in the temporal engine. + +**Announce at start:** "Using Fabric readout-rule onboarding. I'll add readout semantics through declarations, rows, routes, and registered executors, not temporal-engine formulas." + +## First Read + +- `skills/cb.fabric-compiler-boundary-audit/SKILL.md` +- `skills/cb.fabric-backend-boundaries/SKILL.md` +- `skills/cb.fabric-compiler-extension/SKILL.md` +- `skills/cb.fabric-declaration-onboarding/SKILL.md` +- `skills/cb.fabric-compiler-stress-test/SKILL.md` when changing readout/output math as a pre-throughput compiler + locality test. +- `skills/cb.fabric-primitive-op-onboarding/SKILL.md` if the readout introduces or changes a primitive op. +- `skills/cb.fabric-throughput-strategy/SKILL.md` if the change is a faster strategy over existing readout/output rows. +- `skills/cb.fabric-native-strategy-onboarding/SKILL.md` before implementing CUDA/Triton/C++ native readout strategy + bodies or binding schemas. +- `skills/cb.fabric-graph-onboarding/SKILL.md` if readout consumes graph/port/boundary facts. +- `skills/cb.fabric-boundary-guardrails/SKILL.md` when adding source/static checks for readout row/route ownership, + no fixed output slots, deleted wrappers, or no hidden fallback. 
+- Live readout declarations/specs/tests: + +```bash +rg --files src/cortical/fabric | rg "readout|output_route|output_boundary|sequence_surface/compiler|registered_program" +rg -n "ReadoutRule|readout_rule|output_route|output_boundary|pooled_output|readout_projection|readout_param_grad" src/cortical/fabric tests +``` + +## Required Chain + +Every supported readout rule follows: + +```text +public readout declaration / output request + -> normalized readout spec + -> primitive rows and output-boundary rows + -> tensor role, parameter binding, output route, and artifact route rows + -> verifier / legality / typed blockers + -> reference executor + -> optional native/fused readout strategy + -> backward/output-gradient and parameter-reducer coverage + -> parity + source guardrails +``` + +Changing the readout declaration must either change the compiled rows/routes/bindings or fail closed before launch. + +## Boundary Rules + +- Readout owns output semantics: output query/source roles, projection/pooling formula, output route/merge policy, + output-boundary materialization, readout parameters, and readout backward contracts. +- Readout does not own time scheduling, graph construction, transition state updates, message aggregation, + workspace/liveness, checkpoint/recompute policy, benchmark tiling, or temporal scan identity. +- Output route semantics must be explicit. If multiple readout producers are supported, route rows must define concat, + sum, select, or another typed merge. If only one is supported, fail multiple-readout declarations closed before launch. +- A `.cuh` file or CUDA helper may implement a registered readout strategy; it must not define readout semantics by + itself. +- Do not add fixed temporal slots for `output_query`, `value_to_output_weight`, `output_cell_bias`, pooled outputs, or + output cells. Those are tensor roles on a lowered readout primitive or output-boundary row. +- Do not infer readout behavior from cell family, graph constructor, output count, benchmark row, hidden size, or legacy + tensor names. +- If a readout formula changes, make that visible in declaration/spec fields, primitive rows, tensor roles, attributes, + parameter bindings, reference behavior, and native strategy legality. Do not patch a registered-program helper alone. + +## Implementation Steps + +1. Identify the public readout declaration/output request and normalized readout spec. +2. Lower readout fields into primitive/output-boundary rows, tensor roles, parameter bindings, output routes, artifact + routes, and schema versions. +3. Add verifier rules and typed blockers for unsupported route kinds, pooling modes, layouts, dtypes, resets, + materialization, and backward coverage. +4. Add or update the reference executor for forward output values and backward gradients. +5. Add optional native/fused strategy records only after `can_implement` legality is explicit. Use + `cb.fabric-native-strategy-onboarding` before writing native kernels or binding schemas. +6. Register readout parameter-gradient reducer rows from binding-owned executable trainable parameters. +7. Delete or fail-close legacy readout wrappers/fixed slots that the registered readout path replaces. 
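+
+One way to picture the explicit route/merge requirement from the boundary rules is a typed route row plus a
+fail-closed legality pass. The record and function names below are hypothetical; the only point is that merges are
+typed and unsupported multi-producer declarations are rejected before launch:
+
+```python
+from dataclasses import dataclass
+from enum import Enum
+
+class MergeKind(Enum):
+    CONCAT = "concat"
+    SUM = "sum"
+    SELECT = "select"
+
+@dataclass(frozen=True)
+class OutputRouteRow:
+    producer_readouts: tuple[str, ...]   # readout row ids feeding this output
+    merge: MergeKind
+    output_name: str
+
+def check_output_routes(route_rows, declared_readouts, single_producer_only=False):
+    """Hypothetical legality pass: every declared readout is routed and every merge is typed."""
+    routed = {r for row in route_rows for r in row.producer_readouts}
+    missing = set(declared_readouts) - routed
+    if missing:
+        raise ValueError(f"unsupported: readouts without an output route: {sorted(missing)}")
+    if single_producer_only:
+        for row in route_rows:
+            if len(row.producer_readouts) > 1:
+                raise ValueError(
+                    f"unsupported: multiple readout producers for {row.output_name!r} "
+                    "on a single-producer backend"
+                )
+```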
+ +Before editing registered-program kernels or CUDA helpers for readout, record: + +```text +Readout declaration/spec: +Readout primitive/output-boundary rows: +Output route/merge rows: +Tensor/artifact roles: +Parameter bindings and reducers: +Reference executor: +Native strategy, if any: +Backward/output-gradient contract: +Unsupported typed blockers: +Files that should not change if the boundary holds: +``` + +If a readout change needs new output ownership or merge semantics, add those route rows first; do not encode the merge +inside a fixed backward helper. + +For route-aware readout changes, add an explicit producer/consumer packet: + +```text +Readout producer rows: +Output route/merge rows: +Artifact route rows: +Backward consumer route rows: +Reducer route rows: +Singleton legality, if only one producer is supported: +``` + +Role-only lookup is valid only for truly global artifacts. Output messages/cells, per-readout parameters, and routed +gradients must resolve through route/merge rows or fail closed before launch. + +## Required Tests + +- Lowering golden: declaration fields produce distinct readout primitive rows, output routes, and parameter bindings. +- Negative legality: unsupported route/pooling/layout/dtype/materialization fails closed before launch. +- Reference parity: output values, output-cell adjoints, input/carry gradients, and readout parameter gradients. +- Native/fused parity if a CUDA/Triton strategy is registered. +- Source guardrail through `cb.fabric-boundary-guardrails`: readout rows/routes are consumed and readout formulas/fixed + output slots do not appear in temporal scheduler/scan/reverse files. +- High-level smoke: user-style `output = model(x, ...)`, external loss, `loss.backward()`. + +## Closeout Questions + +Record in the active progress doc: + +- Which readout declaration/spec changed? +- Which primitive rows, output routes, artifact routes, and tensor bindings changed? +- Which verifier/legality failures were added? +- Which forward, backward, and reducer paths own the readout? +- Which old fixed readout path was deleted or fail-closed? +- What did not need to change because the compiler boundary held? + +## Integration + +**Uses:** `cb.fabric-compiler-boundary-audit`, `cb.fabric-backend-boundaries`, +`cb.fabric-compiler-extension`, `cb.fabric-declaration-onboarding`, `cb.fabric-parity-gate`. + +**Pairs with:** `cb.fabric-primitive-op-onboarding` for new readout primitives, +`cb.fabric-graph-onboarding` for graph/port facts, `cb.fabric-throughput-strategy`, +`cb.fabric-native-strategy-onboarding`, and `cb.fabric-performance-loop` for readout throughput/native strategies, +`cb.fabric-compiler-stress-test` for pre-throughput readout formula locality tests, `cb.fabric-message-rule-onboarding` and +`cb.fabric-cell-onboarding` when readout changes interact with message/cell semantics, and +`cb.fabric-boundary-guardrails` for source/static guardrails and deleted readout-route checks. diff --git a/skills/cb.fabric-reducer-liveness/SKILL.md b/skills/cb.fabric-reducer-liveness/SKILL.md new file mode 100644 index 00000000..3de2a1c0 --- /dev/null +++ b/skills/cb.fabric-reducer-liveness/SKILL.md @@ -0,0 +1,103 @@ +--- +name: cb.fabric-reducer-liveness +description: Use when changing Cortical Fabric parameter reducers, gradient accumulation, reverse span outputs, native strategy return groups, runtime buffers, workspace reuse, or memory/lifetime policy for tensors that are consumed by reducers or carried across registered forward/backward programs. 
+--- + +# Fabric Reducer Liveness + +Use this skill when a Fabric change decides whether a tensor is returned to Python, consumed by a reducer, accumulated +into a runtime buffer, aliased as workspace, saved as an artifact, or dropped after local use. + +**Announce at start:** "Using Fabric reducer liveness. I'll route gradients and temporary tensors through compiler-owned reducer/lifetime rows." + +## First Read + +- `skills/cb.fabric-compiler-boundary-audit/SKILL.md` +- `skills/cb.fabric-throughput-strategy/SKILL.md` +- `skills/cb.fabric-native-strategy-onboarding/SKILL.md` for CUDA/Triton/C++ changes. +- `skills/cb.fabric-parity-gate/SKILL.md` +- `skills/cb.fabric-boundary-guardrails/SKILL.md` when adding source/static checks for reducer rows, liveness rows, + returned native tensors, deleted fixed return groups, or no hidden fallback. +- The active progress doc and the latest owner table for the row being changed. + +Inspect the live reducer/lifetime path: + +```bash +rg -n "parameter_reducer|reducer_route|span_output|runtime_buffer|memory_liveness|artifact_route|workspace|grad_weight|grad_.*bank" \ + src/cortical/fabric tests benchmarks +``` + +## Contract + +Reducer and liveness work is implementation strategy work only when semantics are already represented by compiler rows. +If a parameter gradient, state gradient, artifact role, tensor role, or reducer meaning is missing, stop and use +`cb.fabric-compiler-extension` plus the relevant primitive/message/cell/readout skill. + +Every tensor produced by a native/fused strategy must be classified as exactly one of: + +- semantic return: a declared output row consumed outside the strategy; +- reducer input: consumed by `reverse_parameter_reducer_route_rows` or equivalent compiler reducer rows; +- carry/state input: accumulated into a compiler-owned carry/state runtime buffer; +- artifact/tape: saved by artifact/tape rows with a declared consumer; +- workspace: planned by memory/liveness rows and not observable after its live interval; +- metadata-only: audit/debug row stripped before semantic consumption; +- illegal: no declared consumer or lifetime owner. + +Illegal tensors are deleted or fail closed before launch. Do not keep them because a fixed ABI, old wrapper, or uniform +return group expects them. + +## Hard Rules + +- Parameter-gradient rows come from binding-owned executable parameters, not broad schemas or tensor names. +- Reducer routes own accumulation semantics. Native kernels may compute reducer inputs, but they do not invent parameter + meaning or decide final gradient destinations. +- Reverse span outputs are not a dumping ground. If a span output has a single local consumer, consume it locally or + route it into a runtime/reducer buffer instead of returning a full tensor group. +- Singleton executor/span behavior may be optimized only when compiler rows prove singleton cardinality. Do not use cell + family, benchmark row, hidden size, or single-vs-mixed labels. +- In-place accumulation is valid only when the source tensor has no later consumer except the accumulated value, and the + lifetime edge is named in memory/liveness or reducer rows. +- A memory win is accepted only if the intended lifetime/storage owner moves in allocator telemetry or profile evidence. + If max allocation or unclassified memory grows, stop and classify the owner before expanding the route. + +## Work Loop + +1. Name the owner: reducer route, span output role, runtime buffer role, artifact/tape role, workspace class, or native + stage. +2. 
List all consumers for the tensor by compiler row. If any consumer is inferred from role names or fixed slot order, + stop and add the missing route row first. +3. Choose the narrowest legal action: + - return only if the output is semantic; + - reduce immediately if a reducer route is the only consumer; + - accumulate into carry/state/runtime buffer if that buffer is the declared next owner; + - alias/reuse only through memory/liveness proof; + - drop after local use when no declared downstream consumer remains. +4. Preserve multi-producer/multi-span behavior through route rows. A singleton optimization must leave generic route + semantics intact or fail closed before launch for unsupported rows. +5. Run parity before perf. Include parameter-gradient key/value parity for every reducer touched. +6. Update the active progress doc with accepted/rejected lifetime movement and the next named owner. + +## Required Tests + +- Source guardrail through `cb.fabric-boundary-guardrails` when route ownership changes: positive reducer/liveness rows + consumed, negative stale return group or fixed-slot consumer absent. +- Source/legality test that missing reducer/lifetime rows fail before launch. +- Parity for outputs, state/carry gradients, input gradients, and all touched parameter gradients. +- Runtime metadata proving the active route used compiler-owned reducer/liveness rows. +- Perf or allocator telemetry showing the named owner moved; metadata-only changes are rejected. + +## Closeout + +Record: + +```text +Owner tensor/stage: +Compiler consumers: +Chosen lifetime action: +Rows consumed: +Native outputs retained: +Native outputs dropped/reduced/aliased: +Parity: +Perf/allocator movement: +Next owner: +``` diff --git a/skills/cb.fabric-runtime-front-end-handoff/SKILL.md b/skills/cb.fabric-runtime-front-end-handoff/SKILL.md new file mode 100644 index 00000000..cb7d711e --- /dev/null +++ b/skills/cb.fabric-runtime-front-end-handoff/SKILL.md @@ -0,0 +1,106 @@ +--- +name: cb.fabric-runtime-front-end-handoff +description: "Use when changing, profiling, or optimizing the Cortical Fabric runtime/model front-end handoff: public model forward, input/output adapters, state initialization, boundary tensor preparation, reset normalization, sender K/V setup, or the call boundary into registered temporal programs." +--- + +# Fabric Runtime Front-End Handoff + +Use this skill when work touches the path between a user-style `model(...)` call and the registered compiler-owned +temporal/backend program. This boundary is easy to misuse during throughput work because it can allocate large CUDA +tensors before any temporal stage reports ownership. + +**Announce at start:** "Using Fabric runtime front-end handoff. I'll keep public-call preparation separate from backend strategy ownership." + +## First Read + +- `skills/cb.fabric-workflow-router/SKILL.md` +- `skills/cb.fabric-compiler-boundary-audit/SKILL.md` +- `skills/cb.fabric-declaration-onboarding/SKILL.md` if public graph/cell/message/readout/reset/state semantics change. +- `skills/cb.fabric-public-api-cleanup/SKILL.md` if Config, Blueprint, constructors, or normalized declarations change. +- `skills/cb.fabric-throughput-strategy/SKILL.md` and `skills/cb.fabric-performance-loop/SKILL.md` if this is a + throughput owner. +- `skills/cb.fabric-reducer-liveness/SKILL.md` if tensors are saved, aliased, returned, accumulated, or dropped across + the handoff. 
+- `skills/cb.fabric-boundary-guardrails/SKILL.md` for source/static checks that the front-end does not regain scheduler + or primitive-formula ownership. +- Live handoff code and tests: + +```bash +rg -n "forward\\(|init_state|reset_state|boundary|sender_k|sender_v|run_shared_temporal|registered_temporal|last_backend_execution|materialize_final_state|output_boundary" src/cortical/fabric tests benchmarks +``` + +## Ownership Contract + +The front end may: + +- accept user tensors, state, resets, and output requests; +- run public adapter modules explicitly owned by the model wrapper, such as input/output linear adapters; +- validate and normalize shapes for the already-built declaration/spec; +- create or scatter boundary tensors and initial state tensors only as inputs to compiler-owned rows; +- call the registered temporal/backend program and expose its returned outputs/state. + +The front end must not: + +- select backend strategies, bucket policies, tiling, checkpoint/recompute, workspace, or liveness; +- implement message, recurrence, readout, normalization, projection, activation, or reducer formulas hidden from + primitive rows; +- infer semantics from `Config`, graph constructor names, cell family, message-rule name, benchmark row, hidden-size + constants, or old tensor names; +- allocate large CUDA temporaries that are really backend runtime buffers without memory/liveness rows; +- hide unsupported declarations behind fallback/replay/compat paths. + +If front-end preparation needs a new graph fact, tensor role, reset rule, state schema, artifact, or output route, stop +and route through declaration/compiler extension first. If it needs faster execution for existing rows, move that work +behind a registered strategy or a compiler-owned runtime buffer/liveness row. + +## Throughput Owner Rules + +When profiling names pre-temporal or pre-registered-entry memory/time: + +1. Measure the handoff separately from temporal stages with current-code allocator/timing evidence. +2. Classify each hot tensor as public adapter output, boundary input, initial state, sender K/V state, reset table, + compiler runtime buffer, artifact/tape, workspace, or illegal temporary. +3. If the tensor is a backend runtime buffer or workspace, add/move it to compiler memory/liveness rows before + optimizing. +4. If the tensor is public adapter output, document it as outside the temporal engine and optimize only through ordinary + module/adapter strategy if that is in scope. +5. If the tensor is a semantic field not represented by rows, reclassify as semantic compiler-extension work. + +Do not fix a front-end peak by benchmark-side chunking, private helper calls, shape-specific branches, detaching state, +or changing public API semantics. + +## Required Handoff Packet + +Before editing, record: + +```text +Public call surface: +Declaration/spec already built: +Tensors created before registered entry: +Owner of each tensor: +Rows/bindings/liveness consumed: +Unsupported typed blockers: +Old config/compat/fallback route removed or guarded: +Parity/perf evidence: +``` + +For throughput-only work, primitive rows and public semantics must stay stable. For semantic work, expected row/binding +or typed-blocker deltas must be named before implementation. + +## Tests And Closeout + +- Source guardrail through `cb.fabric-boundary-guardrails`: front-end code may prepare inputs but must not contain + primitive formulas, strategy selection, benchmark policies, fixed slots, or family/shape selectors. 
+- Runtime metadata proving the high-level user call reaches the registered temporal/backend owner with no fallback. +- Parity for outputs, exposed state, input/carry/state gradients, and parameter gradients when training is affected. +- Performance/allocator evidence showing the named front-end owner moved, or a rejected-probe note if it did not. +- Progress doc update naming what remains front-end-owned versus compiler/backend-owned. + +## Integration + +**Uses:** `cb.fabric-workflow-router`, `cb.fabric-compiler-boundary-audit`, `cb.fabric-declaration-onboarding`, +`cb.fabric-public-api-cleanup`. + +**Pairs with:** `cb.fabric-throughput-strategy`, `cb.fabric-performance-loop`, `cb.fabric-reducer-liveness`, +`cb.fabric-boundary-guardrails`, `cb.fabric-parity-gate`, `cb.fabric-compiler-extension`, +`cb.fabric-native-strategy-onboarding`. diff --git a/skills/cb.fabric-scaling-horizon/SKILL.md b/skills/cb.fabric-scaling-horizon/SKILL.md index a269e34a..bed29ff8 100644 --- a/skills/cb.fabric-scaling-horizon/SKILL.md +++ b/skills/cb.fabric-scaling-horizon/SKILL.md @@ -1,6 +1,6 @@ --- name: cb.fabric-scaling-horizon -description: Use when continuing Cortical Fabric scaling work after the physical backend closeout, expanding or regressing the B x params x h x graph structural frontier, validating T streaming-horizon behavior, or updating Fabric benchmark horizons/results. +description: Use when continuing Cortical Fabric scaling work after the physical backend closeout, expanding or regressing the B x params x h x graph structural frontier, validating T streaming-horizon behavior, updating Fabric benchmark horizons/results, or adding scaling-driven compiler/executor strategies. --- # Fabric Scaling Horizon @@ -16,7 +16,17 @@ Use this skill to keep expanding the Fabric backend structural scaling frontier ```bash sed -n '1,220p' src/cortical/fabric/README.md sed -n '1,220p' skills/cb.fabric-backend-boundaries/SKILL.md +sed -n '1,220p' skills/cb.fabric-compiler-boundary-audit/SKILL.md +sed -n '1,220p' skills/cb.fabric-declaration-onboarding/SKILL.md sed -n '1,220p' skills/cb.fabric-performance-loop/SKILL.md +sed -n '1,180p' skills/cb.fabric-parity-gate/SKILL.md +sed -n '1,220p' skills/cb.fabric-compiler-extension/SKILL.md +sed -n '1,180p' skills/cb.fabric-compiler-stress-test/SKILL.md +sed -n '1,220p' skills/cb.fabric-throughput-strategy/SKILL.md +sed -n '1,220p' skills/cb.fabric-native-strategy-onboarding/SKILL.md +sed -n '1,220p' skills/cb.fabric-graph-onboarding/SKILL.md +sed -n '1,180p' skills/cb.fabric-boundary-guardrails/SKILL.md +sed -n '1,180p' skills/cb.fabric-public-api-cleanup/SKILL.md FINAL_RESULTS_JSON=$(find docs/user/subho/fabric_benchmark/results/profiles -maxdepth 1 -name 'fabric_physical_backend_final_results_*.json' | sort | tail -1) python -m json.tool "$FINAL_RESULTS_JSON" >/tmp/fabric_final_results.json ``` @@ -24,7 +34,7 @@ python -m json.tool "$FINAL_RESULTS_JSON" >/tmp/fabric_final_results.json Also inspect current benchmark entrypoints before running anything: ```bash -python benchmarks/run_fabric_scaling_profile.py --help +python benchmarks/fabric/run_audit.py --help python -m pytest -q tests/test_fabric_benchmark_suite_common.py -n0 rg -n "Fabric|fabric" benchmarks tests ``` @@ -62,6 +72,7 @@ digraph fabric_horizon { | `h` | `h=32` is the headline baseline; `h=4/8/16` are required many-cell stress rows. Small `h` failures are real backend owners. | | `graph` | Treat shapes as user graph constructors only. 
Backend planning consumes flat node ids, boundary sets, degree buckets, group ids, reset density, and edge metadata. | | `T` streaming horizon | `T` is not a structural capacity axis. It runs the same planned graph for more streaming steps with carried state. Audit `T` primarily against the same row's Fabric `T=1` per-token line: Fabric throughput should stay flat or increase as `T` grows. Stack is context for the row, not the large-`T` pass/fail criterion. Also verify the backend does not materialize a full `[T, cells, state]` surface unless explicitly requested. | +| `K` internal steps | `K>1` performs K times the temporal work over the same graph. Compare K rows to the matched current-code Fabric `T=1,K=1` training line divided by K, not raw T=1 throughput. K=128/H=64 closure uses this matched T1/K floor plus parity, reset, per-timestep loss, bounded memory, and shared temporal-owner evidence. | | message rules | Message semantics are declared through `fabric.cuda.nn` / message-rule IR; backend owns topology bucketing, source packing, segmented reductions, and backward lowering. | | mixed populations | Must lower into the same backend-owned sequence/physical plan as single-population rows, with population differences as flat graph buckets. Compare to matched stack/MoE-style baselines as well as PyTorch Fabric reference. Mixed-pop throughput should be close to equivalent single-pop buckets plus expected shared-message/scatter/parameter-binding overhead. | @@ -69,16 +80,53 @@ digraph fabric_horizon { - No cell-name, model-size, hidden-size, benchmark-row, or fabric-shape branches. - No benchmark-owned tiling, time chunking, detach policy, workspace policy, or graph expansion. Users pass `B x T`; backend chooses execution. +- Benchmarks and audits must use the same high-level Fabric API path as users: model forward, external loss, + `loss.backward()`, and optimizer step when applicable. This applies to T*K and horizon-H experiments too. Benchmark + code may request public semantic inputs and record metadata after the call, but must not call private runtime/planner + helpers, implement streaming-loss hooks, run per-chunk backward, detach carry, or choose checkpoint/workspace/tape + policy. - If one Fabric path already proves the target scaling or streaming contract, treat it as the shared implementation baseline for comparable paths. First generalize that planner/executor invariant and remove divergent routes; do not tune a separate sibling path unless the active design doc proves the semantics cannot be represented by the shared executor. +- For temporal scaling work, prioritize the shared temporal engine and measured throughput before cleanup. Route deletion + and metadata cleanup are follow-on work; they cannot close a horizon or scaling stage while T, K, H, emission, or + backward are still owned by Python host loops or sibling runtime helpers. +- If T/K/H closure is blocked by a missing CUDA temporal kernel or physical superop, that kernel is the next scaling + owner. Do not spend the next iteration on planner-only, doc-only, or benchmark-organization work unless it is required + to route or validate the kernel path. - Single-population and mixed-population scaling must be audited and fixed as one flat-bucket engine. The number of populations may change the bucket list and parameter bindings, but must not change the route identity, planner ownership, temporal-scan semantics, tape policy, or user/benchmark API. - No stale numbers. 
If a result is not from the current commit, label it historical and rerun before using it. - No metadata-only wins. Active owner time, launch count, memory, or parity must move. +- No T/H/K root-cause claim before the matched current-code T=1,K=1 training row is measured on the same + graph/batch/params/hidden/population/loss-boundary contract. If that T=1 row is already slow, high-memory, or reports + `python_autograd_scan`, the owner is the shared temporal training path and the legacy owner is a fail-closed defect; + larger T/H/K rows are blocked evidence, not closure evidence. `B=1` or 1M-parameter rows are smoke only unless the + closure matrix explicitly names that shape. +- Do not defer representative scaling audits until final closure. During engine work, run selected high-level guardrail + rows often enough to catch wrong owners early: matched T=1 training, T=512/T=4096 K=1 H=64 per-timestep loss, + K=1/2/8/32/128 probes, reset parity, mixed-pop T=1, and small-hidden/shape guards. A guardrail pass is steering + evidence, not closure; a guardrail failure reopens the owning backend/planner stage immediately. +- Final T/H/K closure must inherit the full T=1 shape/size matrix. Representative T/H/K rows are implementation + guardrails only; they do not prove scalable Fabric. For closure, every family, parameter target, batch, hidden-size + stress row, population mode, reset mode, output boundary, and graph/factorization group used for T=1 closure must also + be run through the relevant T/H/K closure rows unless the row is explicitly inapplicable and documented with a + backend-owned reason. Do not optimize or declare success for one showcase T/H/K row while leaving the rest of the T=1 + matrix unmeasured. - No broad replay/autograd ownership on supported training rows where an existing family should own backward. +- Scaling or horizon work that requires a new message rule, transition primitive, readout behavior, native callable, + reducer, or throughput executor must go through `cb.fabric-compiler-extension` first. Use + `cb.fabric-primitive-op-onboarding` for new semantics and `cb.fabric-throughput-strategy` for optimization of existing + compiler rows. Use `cb.fabric-native-strategy-onboarding` for CUDA/Triton/C++ native strategy bodies. Use + `cb.fabric-readout-rule-onboarding` for readout/output-boundary semantics. Do not add + shape-specific scheduler branches or benchmark-owned policies to make a scaling row fit. +- Scaling work that deliberately changes semantics to prove compiler locality before throughput must use + `cb.fabric-compiler-stress-test`; do not treat that row as a throughput tuning pass until row-delta and parity gates + are accepted. +- Scaling work that exposes old Config/Blueprint/public declaration leakage must use `cb.fabric-public-api-cleanup`. + Do not hide scaling policy behind public constructor fields or compatibility config translation. - No mixed-population closure on a separate population-step fallback or benchmark-only helper. If mixed-pop rows still report a wrapper owner, continue the unified flat-bucket sequence/physical-plan refactor before accepting performance. - No raw artifact sprawl. Keep the curated final-results JSON, human summary, and current rollout graphs; leave raw profiles local or ignored. @@ -120,6 +168,20 @@ Next frontier: - For large `T`, the primary throughput question is whether Fabric stays at or above the same-row Fabric `T=1` line. 
Do not mark a large-`T` throughput result bad just because stack improves faster at large `T`; record stack as context, then separately attribute any Fabric memory/tape/full-surface owner. +- For large `K`, the throughput question is whether Fabric stays at or above the matched current-code Fabric `T=1,K=1` + training line divided by K. K extra-work scaling is expected; do not judge K=128/H=64 against raw T=1 tok/s. +- `H` is a rolling TBPTT sequence horizon, not a T=1 closure dimension. Use T=1 rows as the baseline full-step + reference and K-executor probes; claim H=64 closure only on `T>1` sequence rows. +- TBPTT horizon clips dependency length at each loss-emission point and at the actual available physical stream. If + `T*K <= H`, the backward dependency range is the real `T*K` stream; H must not introduce extra scan, replay, + checkpoint, or materialization work beyond the sequence. +- Planner/default checkpoint policy must use the effective clipped horizon, not the raw requested H. A row where H + covers the full `T*K` stream should record a full-horizon backward window and a checkpoint stride no larger than the + available stream. +- Checkpointing and replay are backend planner/workspace choices driven by shape, size, tape pressure, loss/output + materialization, reset policy, `T*K`, H, and device memory. They are not fixed audit routes. Small or `T*K <= H` + rows may directly store/feed generic reverse-table artifacts when that is cheaper; larger rows should use + planner-selected checkpoint/recompute windows. - Stack comparison matters for every publishable row, but stack OOM should be reported as OOM, not converted into a ratio. - Factorization invariance is only meaningful when flat graph invariants match. - Large `T` should not expand active receiver ownership or retain full recurrent state surfaces. @@ -129,6 +191,13 @@ Next frontier: ## Integration -**Uses:** `cb.fabric-backend-boundaries`, `cb.fabric-performance-loop`, `t.run-tests`. +**Uses:** `cb.fabric-backend-boundaries`, `cb.fabric-compiler-boundary-audit`, `cb.fabric-performance-loop`, +`cb.fabric-parity-gate`, `cb.fabric-compiler-extension`, `cb.fabric-compiler-stress-test`, `cb.fabric-throughput-strategy`, +`cb.fabric-native-strategy-onboarding`, `cb.fabric-boundary-guardrails`, `t.run-tests`. -**Pairs with:** `cb.fabric-cell-boundaries` for cell declaration changes, `cb.branch-hygiene` before final PR cleanup, `pr.fix-pr-meaningful` or `db.fix-failing-ci` after publishing. +**Pairs with:** `cb.fabric-cell-boundaries` for cell declaration changes, `cb.fabric-message-rule-onboarding` for message +rule expansion, `cb.fabric-readout-rule-onboarding` for readout/output-boundary expansion, +`cb.fabric-primitive-op-onboarding` for primitive semantics, `cb.fabric-graph-onboarding` for graph/topology frontiers, +`cb.fabric-public-api-cleanup` for public declaration cleanup, `cb.fabric-boundary-guardrails` for source/static +guardrails and deleted-route checks, +`cb.branch-hygiene` before final PR cleanup, `pr.fix-pr-meaningful` or `db.fix-failing-ci` after publishing. 
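+
+A minimal arithmetic sketch of the matched `K` floor and effective clipped horizon described above (a sketch only; the
+function names are illustrative and do not exist as benchmark helpers):
+
+```python
+def k_throughput_floor(t1k1_tok_per_s: float, k: int) -> float:
+    # K>1 performs K times the temporal work, so the acceptance floor is the matched T=1,K=1 line divided by K.
+    return t1k1_tok_per_s / k
+
+
+def effective_horizon(t_steps: int, k: int, requested_h: int) -> int:
+    # TBPTT clips at the real available stream; H must not add scan/replay/checkpoint work beyond T*K.
+    return min(requested_h, t_steps * k)
+
+
+assert k_throughput_floor(100_000.0, 128) == 781.25
+assert effective_horizon(t_steps=1, k=32, requested_h=64) == 32  # T*K <= H: the full stream is the backward window
+```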
diff --git a/skills/cb.fabric-skill-maintenance/SKILL.md b/skills/cb.fabric-skill-maintenance/SKILL.md new file mode 100644 index 00000000..285b8df4 --- /dev/null +++ b/skills/cb.fabric-skill-maintenance/SKILL.md @@ -0,0 +1,142 @@ +--- +name: cb.fabric-skill-maintenance +description: Use when creating, editing, auditing, or reconciling Cortical Fabric skills so future compiler, declaration onboarding, throughput, cell, message-rule, readout-rule, graph, primitive-op, parity, and performance work preserves compiler boundaries. +--- + +# Fabric Skill Maintenance + +Use this skill when the task is to change Fabric skills or add a missing Fabric workflow skill. + +**Announce at start:** "Using Fabric skill maintenance. I'll keep the skills concise, routed, and aligned with the Fabric compiler-boundary contract." + +## First Read + +- `/home/ubuntu/.codex/skills/.system/skill-creator/SKILL.md` +- `skills/cb.fabric-workflow-router/SKILL.md` +- `skills/cb.fabric-compiler-boundary-audit/SKILL.md` +- `skills/cb.fabric-boundary-guardrails/SKILL.md` +- The Fabric skills being edited. + +Inspect the skill set: + +```bash +find skills -maxdepth 2 -name SKILL.md -print | sort +rg -n "compiler|throughput|primitive|message|readout|cell|graph|fallback|compat|fixed|April21" skills +sed -n '1,80p' AGENTS.md +python skills/cb.fabric-skill-maintenance/scripts/audit_fabric_skills.py +``` + +Also check frontmatter trigger coverage. The `description` field is what causes Codex to load a skill in future +sessions; it must name the durable workflows and risk surfaces clearly enough that throughput, native strategies, +runtime front-end handoff, cells, message rules, readouts, graph work, primitive ops, public API cleanup, guardrails, +parity, performance, and skill-maintenance requests trigger the right skill without relying on chat memory. + +## Maintenance Rules + +- Keep skill bodies concise. Add only durable process rules, not transient benchmark numbers, one-off file paths, or + current implementation status. +- Prefer patching the narrow surface skill over bloating `cb.fabric-backend-boundaries`. +- Every Fabric skill that can lead to code changes must route through `cb.fabric-compiler-boundary-audit` when compiler + ownership is at risk. +- Every Fabric skill that asks for source/static guardrails must route through `cb.fabric-boundary-guardrails`. Guardrail + instructions must pair a positive compiler product check with a negative stale-route ban, and must state that guardrails + do not replace parity, active-route metadata, or performance evidence. +- Throughput skills must say that strategies implement existing rows. If semantics, tensor roles, or gradient contracts + are missing, route to `cb.fabric-compiler-extension` first. +- Native-strategy skills must say that CUDA/Triton/C++ kernels consume existing compiler rows and bindings. If a native + kernel needs new meaning, route to the semantic onboarding skill first. +- Skills must preserve the lane split: + semantic extensions change declarations/rows/bindings/reference behavior; + throughput strategies keep semantics stable and change implementation/liveness/cost; + native strategies implement existing rows; evidence work measures only; cleanup deletes stale routes. + If a skill blurs those lanes, patch it before using it for code changes. +- Any skill that can lead to performance claims must require row-fingerprint or route-owner evidence: semantic changes + prove rows changed, throughput changes prove rows stayed stable and active owner moved. 
+- Onboarding skills for cells, message rules, readout rules, graph/topology, and primitive ops must require: + declaration/spec, rows, bindings, legality, reference executor, optional native strategy, backward/reducer coverage, + fail-closed blockers, parity, and source guardrails. +- Public semantic changes must route through `cb.fabric-declaration-onboarding` before the narrow graph/cell/message/ + readout/primitive skill. This keeps cells, message rules, readouts, graph constructors, primitive ops, reset/init + fields, and Config/Blueprint cleanup on one declaration/spec -> rows -> bindings -> reference -> strategy standard. +- Pre-throughput semantic stress tests must route through `cb.fabric-compiler-stress-test`. They prove the compiler path + by changing rows/bindings/routes and reference behavior; they are not throughput optimization passes. +- Public API cleanup skills must force old `Config`/`Blueprint`/compatibility fields into explicit graph, cell, message, + readout, reset, initialization, or planner owners. They must not preserve a broad config object as hidden compiler + truth or replace an old path with a RuntimeError bridge. +- Runtime front-end handoff skills must keep public-call preparation separate from backend/compiler strategy ownership. + They must classify adapter outputs, boundary/state tensors, resets, sender K/V setup, compiler runtime buffers, + artifacts, workspaces, and illegal temporaries before any throughput or liveness patch. +- Add a new skill only for a missing durable workflow. Do not create a new skill for a temporary bug, one benchmark row, + one file cleanup, or a single rejected probe. +- Keep `cb.fabric-workflow-router` as the top-level routing skill for ambiguous Fabric requests, repeated + analyze/plan/proceed prompts, or work that may span compiler closure, throughput, semantic onboarding, cleanup, parity, + and performance. +- Keep `AGENTS.md` pointed at `cb.fabric-workflow-router` for non-trivial Fabric work. Some future sessions may not have + every repository skill prelisted in the tool-provided skill inventory, so the repository router must still force agents + to inspect the local `skills/` directory and load the Fabric workflow router before code changes. +- When adding or editing a narrow skill, make sure the router points to it and the narrow skill still names + `cb.fabric-compiler-boundary-audit` when compiler ownership is at risk. +- When a skill says "add a source guardrail," make sure it names the invariant to guard, the compiler product that owns + the invariant, and the old route that must stay deleted. Avoid broad string-presence checks that only prove the code + contains words like `compiler`, `row`, or `strategy`. +- Check cross-skill routing after edits: router -> boundary audit -> declaration onboarding for semantic changes -> + compiler extension or compiler stress test -> narrow semantic or native strategy skill -> boundary guardrails -> + parity/performance closeout. + A future agent should not be able to start throughput, cells, message rules, readout rules, graph work, public API + cleanup, runtime front-end handoff, semantic stress tests, or native kernels without seeing the compiler-boundary gate. +- Run `python skills/cb.fabric-skill-maintenance/scripts/audit_fabric_skills.py` after skill edits. 
This script is a + routing smoke test only; it catches missing references, absent core phrases, and missing top-level AGENTS routing, but + it does not replace manual review of whether the instructions are precise enough for the current Fabric boundary. +- Add a new skill only when the workflow is not covered by router, boundary audit, compiler extension, throughput, + declaration onboarding, compiler stress test, native strategy, reducer liveness, primitive, graph, cell, message, + readout, source guardrails, public API cleanup, parity, performance, or scaling skills. Prefer strengthening those + skills over adding a near-duplicate. + +## Boundary Coverage Audit + +Before closing a skill-maintenance pass, verify each Fabric workflow can answer: + +```text +Which lane is this? semantic | throughput | native | evidence | cleanup | guardrail +Which declaration/spec owns semantics? +Which rows/bindings/routes/liveness products prove compiler ownership? +For semantic work, what row/binding/route delta is expected? +For throughput work, what semantic row fingerprint must stay stable? +What old route is deleted or fail-closed? +What focused guardrail prevents that old route from returning? +What parity/perf evidence is still required? +``` + +If a skill cannot force those answers for its surface, patch that skill before relying on it for code work. + +## Required Coverage Matrix + +After edits, check that the skill set still has one clear owner for each durable workflow: + +| Workflow | Required owner skill | +| --- | --- | +| Ambiguous Fabric request or analyze -> plan -> proceed loop | `cb.fabric-workflow-router` | +| Boundary proof before code/perf/closure | `cb.fabric-compiler-boundary-audit` | +| Public semantic change | `cb.fabric-declaration-onboarding` plus the narrow surface skill | +| New primitive op or formula change | `cb.fabric-primitive-op-onboarding` | +| New message rule or message math change | `cb.fabric-message-rule-onboarding` | +| New cell or transition math change | `cb.fabric-cell-onboarding` and `cb.fabric-cell-boundaries` | +| New readout/output route | `cb.fabric-readout-rule-onboarding` | +| New graph/topology constructor or graph fact | `cb.fabric-graph-onboarding` | +| Native CUDA/Triton/C++ implementation | `cb.fabric-native-strategy-onboarding` | +| Throughput strategy over stable rows | `cb.fabric-throughput-strategy` and `cb.fabric-performance-loop` | +| Runtime/model front-end handoff before registered backend entry | `cb.fabric-runtime-front-end-handoff` | +| Reducer, workspace, alias, artifact, or liveness owner | `cb.fabric-reducer-liveness` | +| Pre-throughput formula/locality stress test | `cb.fabric-compiler-stress-test` | +| Config/Blueprint/public cleanup | `cb.fabric-public-api-cleanup` | +| Source/static guardrail or deleted-route check | `cb.fabric-boundary-guardrails` | + +If a new durable workflow is not covered by this matrix, add a skill. If it is covered, strengthen the existing skill +instead of adding a duplicate. + +## Closeout + +- Validate frontmatter has `name` and `description`. +- Grep for stale or contradictory guidance and missing router references. +- Run `python skills/cb.fabric-skill-maintenance/scripts/audit_fabric_skills.py`. +- Summarize which skills changed and what future workflow they now enforce. 
diff --git a/skills/cb.fabric-skill-maintenance/scripts/audit_fabric_skills.py b/skills/cb.fabric-skill-maintenance/scripts/audit_fabric_skills.py new file mode 100644 index 00000000..c697c4b0 --- /dev/null +++ b/skills/cb.fabric-skill-maintenance/scripts/audit_fabric_skills.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import re +import sys +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[3] +SKILLS = ROOT / "skills" +AGENTS = ROOT / "AGENTS.md" +FABRIC_SKILL_RE = re.compile(r"cb\.fabric-[a-z0-9-]+") +PATH_REF_RE = re.compile(r"skills/(cb\.fabric-[a-z0-9-]+)/SKILL\.md") + +REQUIRED_SKILLS = { + "cb.fabric-backend-boundaries", + "cb.fabric-boundary-guardrails", + "cb.fabric-cell-boundaries", + "cb.fabric-cell-onboarding", + "cb.fabric-compiler-boundary-audit", + "cb.fabric-compiler-extension", + "cb.fabric-compiler-stress-test", + "cb.fabric-declaration-onboarding", + "cb.fabric-graph-onboarding", + "cb.fabric-message-rule-onboarding", + "cb.fabric-native-strategy-onboarding", + "cb.fabric-parity-gate", + "cb.fabric-performance-loop", + "cb.fabric-primitive-op-onboarding", + "cb.fabric-public-api-cleanup", + "cb.fabric-readout-rule-onboarding", + "cb.fabric-reducer-liveness", + "cb.fabric-runtime-front-end-handoff", + "cb.fabric-scaling-horizon", + "cb.fabric-skill-maintenance", + "cb.fabric-throughput-strategy", + "cb.fabric-workflow-router", +} + +REQUIRED_PHRASES = { + "cb.fabric-workflow-router": ( + "Boundary Manifest Gate", + "Rows expected to change", + "Rows expected to stay stable", + "Skill edits", + ), + "cb.fabric-compiler-boundary-audit": ( + "semantic extension", + "throughput strategy", + "compiler-boundary", + "declaration/spec", + ), + "cb.fabric-compiler-extension": ( + "public declaration", + "primitive op rows", + "tensor role and parameter binding rows", + "Adding a new primitive, message rule, cell transition, readout, graph behavior, or throughput strategy", + ), + "cb.fabric-declaration-onboarding": ( + "Unified Declaration Rule", + "Expected row/binding/route delta", + "old Config/Blueprint choke points", + ), + "cb.fabric-throughput-strategy": ( + "Expected semantic delta: none", + "Do not use a throughput strategy to introduce new math", + "active owner moved", + "narrow liveness improvement", + ), + "cb.fabric-native-strategy-onboarding": ( + "Rows consumed", + "Do not create tensor roles, formulas", + "Classify every native output", + ), + "cb.fabric-reducer-liveness": ( + "Every tensor produced by a native/fused strategy must be classified", + "Parameter-gradient rows come from binding-owned executable parameters", + ), + "cb.fabric-message-rule-onboarding": ( + "Message rules and cells use the same compiler standard", + "not temporal-engine policies", + "no fixed Q/K/V", + ), + "cb.fabric-cell-onboarding": ( + "Unified Declaration Standard", + "cell registry is a semantic registry", + "row-delta packet", + "cb.fabric-parity-gate", + ), + "cb.fabric-cell-boundaries": ( + "Cell and message declarations", + "cb.fabric-parity-gate", + "same compiler-boundary standard", + ), + "cb.fabric-primitive-op-onboarding": ( + "Every supported primitive must have this chain", + "Do not implement primitive formulas inside temporal", + "Rows that prove this is a semantic change", + ), + "cb.fabric-boundary-guardrails": ( + "positive compiler product", + "negative stale-route ban", + "not closure by themselves", + ), + "cb.fabric-parity-gate": ( + "full parameter gradient key equality", + "For throughput strategies, 
prove the opposite", + "Source guardrails are not parity evidence", + ), + "cb.fabric-performance-loop": ( + "Treat April21 as a baseline to beat, not a source tree to copy", + "A metadata label is not evidence", + "active owner must physically move", + ), + "cb.fabric-public-api-cleanup": ( + "Do not replace old config paths with `RuntimeError` bridges", + "normalized declaration owners directly", + "old broad config object", + ), + "cb.fabric-runtime-front-end-handoff": ( + "registered compiler-owned", + "Tensors created before registered entry", + "backend runtime buffers", + "Do not fix a front-end peak by benchmark-side chunking", + ), + "cb.fabric-scaling-horizon": ( + "cb.fabric-parity-gate", + "same backend-owned sequence/physical plan", + "shape-specific scheduler branches", + ), + "cb.fabric-skill-maintenance": ( + "frontmatter trigger coverage", + "negative stale-route ban", + "Required Coverage Matrix", + ), +} + +REQUIRED_REFERENCES = { + "cb.fabric-backend-boundaries": ( + "cb.fabric-workflow-router", + "cb.fabric-compiler-boundary-audit", + "cb.fabric-runtime-front-end-handoff", + "cb.fabric-boundary-guardrails", + ), + "cb.fabric-cell-boundaries": ( + "cb.fabric-compiler-boundary-audit", + "cb.fabric-compiler-extension", + "cb.fabric-parity-gate", + "cb.fabric-boundary-guardrails", + ), + "cb.fabric-cell-onboarding": ( + "cb.fabric-compiler-boundary-audit", + "cb.fabric-declaration-onboarding", + "cb.fabric-compiler-extension", + "cb.fabric-parity-gate", + "cb.fabric-boundary-guardrails", + ), + "cb.fabric-compiler-boundary-audit": ( + "cb.fabric-runtime-front-end-handoff", + "cb.fabric-throughput-strategy", + "cb.fabric-declaration-onboarding", + "cb.fabric-boundary-guardrails", + ), + "cb.fabric-compiler-extension": ( + "cb.fabric-compiler-boundary-audit", + "cb.fabric-declaration-onboarding", + "cb.fabric-throughput-strategy", + "cb.fabric-native-strategy-onboarding", + "cb.fabric-boundary-guardrails", + ), + "cb.fabric-declaration-onboarding": ( + "cb.fabric-compiler-boundary-audit", + "cb.fabric-boundary-guardrails", + "cb.fabric-parity-gate", + ), + "cb.fabric-graph-onboarding": ( + "cb.fabric-compiler-boundary-audit", + "cb.fabric-declaration-onboarding", + "cb.fabric-parity-gate", + "cb.fabric-boundary-guardrails", + ), + "cb.fabric-message-rule-onboarding": ( + "cb.fabric-compiler-boundary-audit", + "cb.fabric-declaration-onboarding", + "cb.fabric-parity-gate", + "cb.fabric-boundary-guardrails", + ), + "cb.fabric-native-strategy-onboarding": ( + "cb.fabric-compiler-boundary-audit", + "cb.fabric-throughput-strategy", + "cb.fabric-runtime-front-end-handoff", + "cb.fabric-parity-gate", + "cb.fabric-boundary-guardrails", + ), + "cb.fabric-performance-loop": ( + "cb.fabric-compiler-boundary-audit", + "cb.fabric-throughput-strategy", + "cb.fabric-runtime-front-end-handoff", + "cb.fabric-parity-gate", + "cb.fabric-boundary-guardrails", + ), + "cb.fabric-primitive-op-onboarding": ( + "cb.fabric-compiler-boundary-audit", + "cb.fabric-declaration-onboarding", + "cb.fabric-parity-gate", + "cb.fabric-boundary-guardrails", + ), + "cb.fabric-public-api-cleanup": ( + "cb.fabric-workflow-router", + "cb.fabric-compiler-boundary-audit", + "cb.fabric-declaration-onboarding", + "cb.fabric-runtime-front-end-handoff", + "cb.fabric-boundary-guardrails", + ), + "cb.fabric-readout-rule-onboarding": ( + "cb.fabric-compiler-boundary-audit", + "cb.fabric-declaration-onboarding", + "cb.fabric-parity-gate", + "cb.fabric-boundary-guardrails", + ), + "cb.fabric-reducer-liveness": ( + 
"cb.fabric-compiler-boundary-audit", + "cb.fabric-throughput-strategy", + "cb.fabric-parity-gate", + "cb.fabric-boundary-guardrails", + ), + "cb.fabric-scaling-horizon": ( + "cb.fabric-compiler-boundary-audit", + "cb.fabric-performance-loop", + "cb.fabric-parity-gate", + "cb.fabric-throughput-strategy", + "cb.fabric-boundary-guardrails", + ), + "cb.fabric-skill-maintenance": ( + "cb.fabric-runtime-front-end-handoff", + "cb.fabric-compiler-boundary-audit", + "cb.fabric-boundary-guardrails", + ), + "cb.fabric-throughput-strategy": ( + "cb.fabric-compiler-boundary-audit", + "cb.fabric-performance-loop", + "cb.fabric-runtime-front-end-handoff", + "cb.fabric-parity-gate", + "cb.fabric-native-strategy-onboarding", + "cb.fabric-boundary-guardrails", + ), + "cb.fabric-runtime-front-end-handoff": ( + "cb.fabric-workflow-router", + "cb.fabric-compiler-boundary-audit", + "cb.fabric-declaration-onboarding", + "cb.fabric-public-api-cleanup", + "cb.fabric-throughput-strategy", + "cb.fabric-performance-loop", + "cb.fabric-boundary-guardrails", + "cb.fabric-parity-gate", + ), + "cb.fabric-workflow-router": ( + "cb.fabric-compiler-boundary-audit", + "cb.fabric-declaration-onboarding", + "cb.fabric-runtime-front-end-handoff", + "cb.fabric-boundary-guardrails", + ), +} + + +def read_skill(path: Path) -> tuple[dict[str, str], str]: + text = path.read_text(encoding="utf-8") + if not text.startswith("---\n"): + raise ValueError(f"{path}: missing YAML frontmatter") + try: + _, frontmatter, body = text.split("---", 2) + except ValueError as exc: + raise ValueError(f"{path}: malformed YAML frontmatter") from exc + data: dict[str, str] = {} + for line in frontmatter.splitlines(): + if ":" not in line: + continue + key, value = line.split(":", 1) + data[key.strip()] = value.strip() + return data, body + + +def main() -> int: + errors: list[str] = [] + paths = sorted(SKILLS.glob("cb.fabric-*/SKILL.md")) + skill_names: set[str] = set() + bodies: dict[str, str] = {} + + for path in paths: + try: + data, body = read_skill(path) + except ValueError as exc: + errors.append(str(exc)) + continue + name = data.get("name", "") + description = data.get("description", "") + expected = path.parent.name + if name != expected: + errors.append(f"{path}: frontmatter name {name!r} does not match directory {expected!r}") + if not description: + errors.append(f"{path}: missing non-empty description") + skill_names.add(name) + bodies[name] = body + + missing_required = sorted(REQUIRED_SKILLS - skill_names) + for name in missing_required: + errors.append(f"missing required Fabric skill: {name}") + + for name, body in bodies.items(): + for ref in PATH_REF_RE.findall(body): + if ref not in skill_names: + errors.append(f"{name}: references missing skill path {ref}") + for ref in FABRIC_SKILL_RE.findall(body): + if ref not in skill_names: + errors.append(f"{name}: references missing skill name {ref}") + + for name, phrases in REQUIRED_PHRASES.items(): + body = bodies.get(name, "") + for phrase in phrases: + if phrase not in body: + errors.append(f"{name}: missing required routing phrase {phrase!r}") + + for name, references in REQUIRED_REFERENCES.items(): + body = bodies.get(name, "") + for reference in references: + if reference not in body: + errors.append(f"{name}: missing required skill reference {reference!r}") + + if not AGENTS.exists(): + errors.append("AGENTS.md is missing; Fabric skill routing has no top-level entry point") + else: + agents_text = AGENTS.read_text(encoding="utf-8") + for required_text in ( + 
"skills/cb.fabric-workflow-router/SKILL.md", + "skills/cb.fabric-compiler-boundary-audit/SKILL.md", + "Cortical Fabric work", + ): + if required_text not in agents_text: + errors.append(f"AGENTS.md: missing required Fabric routing text {required_text!r}") + + if errors: + print("Fabric skill audit failed:", file=sys.stderr) + for error in errors: + print(f"- {error}", file=sys.stderr) + return 1 + + print(f"Fabric skill audit passed: {len(skill_names)} skills checked.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/skills/cb.fabric-throughput-strategy/SKILL.md b/skills/cb.fabric-throughput-strategy/SKILL.md new file mode 100644 index 00000000..35d8b65d --- /dev/null +++ b/skills/cb.fabric-throughput-strategy/SKILL.md @@ -0,0 +1,306 @@ +--- +name: cb.fabric-throughput-strategy +description: Use when adding, changing, profiling, or optimizing a Cortical Fabric throughput executor strategy, fused CUDA/Triton/native callable, registered forward/reverse program kernel, memory/liveness strategy, artifact strategy, parameter-reduction strategy, or cost/legality rule. +--- + +# Fabric Throughput Strategy + +Use this skill when performance work changes how an already-declared compiler product is executed. A throughput strategy +is a replaceable implementation over verified primitive rows; it is not a new semantic authority. + +**Announce at start:** "Using Fabric throughput strategy. I'll optimize only registered compiler-owned strategies over verified rows and keep semantics out of scheduler code." + +## First Read + +- `skills/cb.fabric-backend-boundaries/SKILL.md` +- `skills/cb.fabric-compiler-extension/SKILL.md` +- `skills/cb.fabric-compiler-boundary-audit/SKILL.md` +- `skills/cb.fabric-declaration-onboarding/SKILL.md` if the owner needs new user-visible graph, cell, message, readout, + primitive, reset/init, or public declaration semantics. +- `skills/cb.fabric-compiler-stress-test/SKILL.md` if the requested pass is a pre-throughput semantic stress test or + formula perturbation. In that case, do not optimize throughput in the same pass. +- `skills/cb.fabric-performance-loop/SKILL.md` +- `skills/cb.fabric-parity-gate/SKILL.md` +- `skills/cb.fabric-native-strategy-onboarding/SKILL.md` before editing CUDA/Triton/C++ native strategy bodies, fused + program kernels, native callable binding schemas, or kernel ABIs. +- `skills/cb.fabric-reducer-liveness/SKILL.md` before changing parameter reducers, reverse span outputs, native return + groups, runtime-buffer lifetimes, workspace reuse, or tensors consumed by reducers. +- `skills/cb.fabric-readout-rule-onboarding/SKILL.md` if the strategy changes readout/output-boundary behavior, output + routes, pooling, or readout parameter reducers. +- `skills/cb.fabric-graph-onboarding/SKILL.md` if the strategy depends on graph/topology facts, active-region legality, + factorization, degree buckets, or boundary/port ownership. +- `skills/cb.fabric-public-api-cleanup/SKILL.md` if the optimization exposes old Config/Blueprint/public declaration + paths that still own semantics or policy. +- `skills/cb.fabric-runtime-front-end-handoff/SKILL.md` if the owner appears before the registered temporal/native + program entry, including public model forward preparation, input/output adapters, boundary tensor prep, state + initialization, reset normalization, sender K/V setup, or call handoff. 
+- `skills/cb.fabric-boundary-guardrails/SKILL.md` if the strategy adds source/static checks, legacy deletion checks, or + no-fallback/no-compat guardrails. +- The active throughput progress doc and latest current-code benchmark artifacts. + +Inspect the live strategy and program path: + +```bash +rg -n "RegisteredTemporalExecutorProgram|strategy|can_implement|native_callable|executor_row|tensor_binding|memory_liveness|artifact_route|parameter_reducer|missing_executor|fixed|compat" src/cortical/fabric tests benchmarks +``` + +## Strategy Contract + +A strategy is legal only if it declares: + +- stable strategy id/version and primitive pattern signature +- required primitive rows, tensor roles, parameter bindings, artifact routes, and output routes +- supported dtype/device/layout/shape/reset/tape/materialization policy +- forward executor and backward executor coverage, or explicit no-grad legality +- saved-tensor/recompute/tape contract +- workspace/liveness/aliasing requirements +- parameter-gradient/reducer outputs +- determinism/tolerance class +- typed rejection reasons +- audit metadata and plan fingerprint + +Legality and cost are separate: + +```text +candidate match -> legality filtering -> cost/ranking -> launch plan -> audited execution +``` + +A bad cost model may pick a slower legal strategy. It must never make an illegal strategy run. + +Before implementation, create a boundary manifest in working notes or the progress doc. It must say which semantic rows +are unchanged, which strategy/runtime rows change, which bindings/routes/liveness rows are consumed, and which old route +is deleted or fail-closed. If this cannot be written, the work is not ready for throughput. + +Record a semantic-stability guard: + +```text +Primitive row fingerprint before: +Primitive row fingerprint after: +Tensor-role/binding fingerprint before: +Tensor-role/binding fingerprint after: +Expected semantic delta: none +``` + +If the before/after fingerprints differ because math, roles, parameters, output routes, reset/tape behavior, or gradient +contracts changed, stop. That pass is semantic compiler-extension work, not throughput. +If the user intentionally requested that row delta as a pre-throughput compiler reality test, switch to +`cb.fabric-compiler-stress-test` and keep performance tuning deferred. + +A streaming producer-consumer schedule is a valid throughput strategy, and often the right architecture, only when it is +registered as a physical strategy over compiler products. User semantics stay in graph, cell, message, transition, +readout, and output declarations. The compiler emits primitive rows, tensor bindings, executor rows, liveness rows, +artifact routes, output routes, and reducer rows. The backend may choose streaming schedule, chunking, workspace, tape, +and checkpointing over those rows. It may not hardcode sLSTM/Axon special paths, move time chunking/detach/replay into +benchmarks, or change formulas in scheduler glue. The required framing is: semantic delta `none`; row fingerprints +stable; liveness/artifact/output/reducer rows consumed directly; backend policy selected from compiler rows. + +Before committing to a large strategy implementation, run a small hypothesis probe when it can cheaply answer whether +the direction is useful. 
The probe must still consume the normal compiler-owned route and should be recorded as: + +```text +Hypothesis: +Expected strategy/owner movement: +Probe row, synthetic fixture, or source/telemetry check: +Artifact path: +Keep/narrow/revert rule: +Representative row required before acceptance: +``` + +A valid probe can reduce batch, params, iterations, mode, or instrumentation scope only when the same strategy, +bindings, route rows, and owner are exercised. It may not change semantics, use benchmark-owned scheduling, or key on +cell family, hidden size, row id, or shape. Passing a probe means "continue"; it does not mean the strategy is accepted. + +Synthetic fast experiments are valid for strategy selection when they isolate one question: Does this liveness edge move? +Does this kernel launch count drop? Does this binding table avoid an allocation? Does this route still select the same +strategy? Use tiny generated programs, synthetic tensors, toy bucket shapes, or temporary `tmp/` scripts for that purpose. +If the probe bypasses the high-level Fabric call path, treat it as mechanism-only evidence and require a high-level +representative row before accepting the implementation. + +For throughput patches that add or modify a native/fused program, also record the row-consumption proof: + +```text +Program rows consumed directly: +Executor rows consumed directly: +Tensor binding rows consumed directly: +Artifact/output route rows consumed directly: +Reset/liveness/workspace rows consumed directly: +Legacy slots/wrappers no longer reachable: +``` + +If the implementation still needs fixed tensor positions, role-only producer lookup, singleton message/readout +assumptions, Q/K/V globals, gated/diagonal bundles, or output projection globals that are not compiler rows, the +throughput owner is still compiler closure, not compute tuning. + +## Hard-Stop Preflight + +Before writing a faster kernel, prove the row already exists as a compiler product: + +```text +declaration/spec -> primitive rows -> executor rows -> tensor bindings -> memory/artifact rows -> audit metadata +``` + +If the planned optimization needs a new formula, new tensor role, new reset behavior, new readout/output ownership, or +new graph/message/cell semantics, stop throughput work and use `cb.fabric-compiler-extension` first. A throughput +strategy may fuse, reorder, tile, allocate workspace, recompute, checkpoint, or choose launch geometry for legal rows. +It may not create semantics or infer missing tensor roles from old names. + +Complete the `cb.fabric-compiler-boundary-audit` pre-implementation gate before adding or changing a throughput +strategy. If the audit cannot name the declaration/spec owner, primitive rows, tensor bindings, route rows, backward +owner, reducer owner, and memory/liveness owner, the strategy is not ready for optimization. + +If the strategy depends on topology shape, graph factorization, degree buckets, distance/delay, compact active regions, +or boundary sets, those facts must come from lowered graph rows. Do not inspect lattice constructors, config fields, or +benchmark graph names inside the strategy. + +Before code, record this in the progress doc: + +```text +Unchanged semantic rows: +Changed strategy/runtime rows: +Tensor/route/liveness rows consumed directly: +Old route being deleted or fail-closed: +Parity gate: +Perf row: +``` + +If `Unchanged semantic rows` cannot be filled in, this is not throughput work. Switch to `cb.fabric-compiler-extension`. 
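+
+A minimal sketch of how the semantic-stability guard can be checked mechanically (assumption: the rows are passed in as
+plain dicts; `assert_semantically_stable` is an illustrative helper, not an existing compiler API):
+
+```python
+import hashlib
+import json
+
+
+def row_fingerprint(rows: list[dict]) -> str:
+    # Stable digest over canonicalized rows: semantic fields only, no strategy ids or launch metadata.
+    canonical = json.dumps(sorted(rows, key=lambda row: json.dumps(row, sort_keys=True)), sort_keys=True)
+    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
+
+
+def assert_semantically_stable(rows_before: list[dict], rows_after: list[dict]) -> None:
+    # Throughput pass: expected semantic delta is none, so the before/after fingerprints must match exactly.
+    if row_fingerprint(rows_before) != row_fingerprint(rows_after):
+        raise RuntimeError("primitive rows changed: reclassify as compiler-extension work, not throughput")
+
+
+rows = [{"op": "matmul", "roles": ["message_to_cell_weight", "input_proj_weight"]}]
+assert_semantically_stable(rows, [dict(row) for row in rows])  # identical rows: the guard passes
+```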
+ +For the common three-step workflow, keep the phases distinct: + +- **Deep dive/analyze:** build the owner table and list compiler products; do not optimize. +- **Plan:** choose the highest-impact legal strategy owner and name parity/perf gates; do not optimize. +- **Proceed:** implement only the selected registered strategy and update the progress doc. + +## Non-Negotiables + +- Do not add primitive formulas to temporal scheduler, scan/reverse, benchmark, config, or route glue. +- Do not key strategy selection on cell family, message-rule name, benchmark id, hidden-size constants, single/mixed + population labels, old tensor names, or April21 code shape. +- Do not copy April21 or legacy fixed-slot implementations as code. Use old results only as a baseline target and use + old code, if inspected at all, only to understand a performance idea that must be re-expressed through current + compiler rows, bindings, route rows, reset rows, and liveness rows. +- Do not copy old fixed-slot kernels into a new wrapper. The strategy must consume compiler-owned primitive rows, + executor rows, tensor bindings, artifact routes, output routes, reset rows, and memory/liveness rows directly. +- Do not tune benchmark-side tiling, detach, checkpoint, time chunking, workspace, or graph expansion. +- Do not treat pre-registered-entry runtime glue as a strategy shortcut. Public-call handoff work must classify tensors + as public adapters, boundary/state inputs, reset tables, sender K/V setup, compiler runtime buffers, artifacts, or + workspace, then move backend-owned buffers behind compiler liveness rows or registered strategies. +- Do not accept speedups from hidden fallback/replay/compat paths. Delete or fail-close replaced routes after parity and + current-code non-regression. +- Do not move to compute optimization while a named compiler memory/liveness owner prevents the representative row from + fitting. +- Do not use a throughput strategy to introduce new math, new tensor roles, or new parameter-gradient semantics. A + faster strategy may change layout, fusion, launch shape, allocation, recompute, checkpoint, aliasing, or reducer + implementation only for already verified rows. + +## Memory/Liveness Probes + +Memory, aliasing, and no-copy work must be treated as compiler-owned liveness work, not as hopeful metadata. + +Before an alias/reuse/no-copy patch, record: + +```text +Expected storage identity change: +Lifetime edge being shortened: +Compiler liveness/alias rows consumed: +Autograd saved-tensor impact: +CUDA temporary/workspace impact: +Peak allocated/reserved owner to watch: +Keep, narrow, or revert rule: +``` + +Accept the probe only if the measured active path changes the intended storage identity or lifetime and the named peak +owner moves. If peak memory jumps unexpectedly, especially as unclassified/allocator memory, stop expanding the alias +route. First classify whether the owner is a liveness bug, autograd saved tensor, CUDA temporary/workspace allocation, or +allocator reserve gap. Keep the diff only where that classification says it is correct; narrow it to forward-only or +revert it when backward/tape lifetime grows or storage identity did not actually change. +If current live allocations move but the high-water/max-allocated owner does not, record the result as an accepted +narrow liveness improvement only when parity is green and the shortened lifetime is real. 
Do not call it throughput +closure; the next owner must name the remaining high-water allocator/native stage before adding another alias or +return-slot shortcut. + +## Work Loop + +1. Name the current owner from warmed current-code evidence and build an owner table before selecting an optimization. +2. Verify the semantics already lower into primitive rows/bindings. If not, stop and use + `cb.fabric-primitive-op-onboarding`. +3. If the proposed direction is uncertain, run the smallest valid hypothesis probe first and record the keep/narrow/revert + rule. Reject the direction if the named owner does not move. +4. Define the strategy's legality, ABI, memory/liveness, artifact, forward, backward, and reducer contracts. +5. Add or modify registered strategy records and program kernels only through compiler-owned rows. Program-level kernels + must consume primitive rows, executor rows, tensor bindings, artifact/output routes, reset rows, and memory/liveness + rows directly; fixed role order, singleton executor assumptions, and compatibility slot aliases are open compiler + blockers. + When the implementation enters native code, switch to `cb.fabric-native-strategy-onboarding` for the kernel ABI and + binding-schema checklist. +6. Add C++/Python validation that rejects missing rows, missing buffers, illegal layouts, and stale compatibility ABI. +7. Remove or fail-close the old route the strategy replaces, and use `cb.fabric-boundary-guardrails` for the focused + source/static check that keeps the stale route out. +8. Run targeted parity before using performance numbers. +9. Run warmed targeted perf rows, then update the progress doc with commands, artifacts, accepted/rejected probes, and + the next owner. + +When optimizing memory/liveness, the accepted diff must name the lifetime edge shortened or buffer ownership moved. If +max allocation or unclassified allocator usage grows unexpectedly, stop and classify the owner before expanding the +route. The next patch should target the named compiler-owned stage, not add more metadata. +Throughput acceptance requires evidence that the active owner moved in timing, launch count, storage identity, +allocator telemetry, or memory stage rows. A strategy id, row label, or metadata-only change is rejected. + +When optimizing native/fused programs, prefer output/reducer routing, workspace reuse, and compiler liveness rows over +extra returned tensor groups. Returning extra tensors from a native strategy is a semantic ABI change unless those +tensors are declared by output rows, reducer rows, artifact rows, or metadata rows and stripped before semantic +consumption. + +For reducer/lifetime work, classify every native output as semantic return, reducer input, carry/state input, artifact, +workspace, or metadata-only. Local-only outputs must be consumed, reduced, accumulated, aliased, or dropped according to +compiler rows; do not keep them alive just to preserve a fixed return group. Singleton span optimizations are legal only +when executor/span cardinality comes from compiler rows and the generic multi-span route remains correct. + +T=1 is the first throughput closure surface. Mixed-population T=1 and single-population T=1 must use the same registered +program machinery; the population count changes buckets and parameter bindings, not the route identity. Use single-pop +rows as cheap guardrails when they expose the same owner, but do not let single-pop tuning create a separate path that +mixed-pop cannot use. 
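+
+One way to capture that allocator and storage-identity evidence, as a sketch (assumes a CUDA device; `run_row` and
+`get_watched_tensor` are placeholders for the normal high-level Fabric call path and the tensor named by the probe, not
+real helpers):
+
+```python
+import torch
+
+
+def probe_row(run_row, get_watched_tensor) -> dict[str, int]:
+    # Warmed current-code evidence: reset allocator peaks, run the public call path once, read the high-water marks.
+    torch.cuda.synchronize()
+    torch.cuda.reset_peak_memory_stats()
+    run_row()
+    torch.cuda.synchronize()
+    watched = get_watched_tensor()
+    return {
+        "peak_allocated_bytes": torch.cuda.max_memory_allocated(),
+        "peak_reserved_bytes": torch.cuda.max_memory_reserved(),
+        # Storage identity for an alias/no-copy claim: check the pointer, not just the values.
+        "watched_storage_ptr": watched.untyped_storage().data_ptr(),
+    }
+```
+
+Compare the packet from the before and after runs of the same row; accept the alias or liveness patch only when the
+named peak owner and the storage identity actually move as intended, and record the result against the keep, narrow, or
+revert rule.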
+ +## Required Tests + +- Source guardrail through `cb.fabric-boundary-guardrails`: positive compiler product consumed, negative stale route + absent; no fixed slot, family, benchmark, hidden-size, or primitive formula leaked into temporal scheduler. +- Legality negative: unsupported rows reject before launch with typed reason. +- CUDA/reference parity for outputs, states, input/carry gradients, and parameter gradients on touched surfaces. +- Runtime metadata: active row reports the registered strategy, owner, plan fingerprint, and no primitive blockers. +- Performance: warmed current-code row for the named owner, plus the smallest representative non-regression row sharing + the strategy. + +## Progress Doc Entry + +Use this shape: + +```text +Owner: +Boundary audit: +Primitive/executor rows implemented: +Legality contract: +Memory/artifact contract: +Old route deleted or fail-closed: +Parity: +Perf artifacts: +Accepted/rejected: +Next owner: +``` + +## Integration + +**Uses:** `cb.fabric-backend-boundaries`, `cb.fabric-compiler-extension`, `cb.fabric-compiler-boundary-audit`, +`cb.fabric-native-strategy-onboarding`, `cb.fabric-reducer-liveness`, `cb.fabric-performance-loop`, +`cb.fabric-parity-gate`. + +**Pairs with:** `cb.fabric-declaration-onboarding` and `cb.fabric-primitive-op-onboarding` if semantics are missing, +`cb.fabric-graph-onboarding` when strategy legality depends on graph/topology facts, `cb.fabric-scaling-horizon` for +B/params/h or T/K/H expansion, +`cb.fabric-message-rule-onboarding` and `cb.fabric-cell-onboarding` when the optimized rows came from new user +semantics, `cb.fabric-readout-rule-onboarding` when output/readout rows or reducers are involved, +`cb.fabric-public-api-cleanup` when old public declaration surfaces still feed the strategy, +`cb.fabric-runtime-front-end-handoff` when pre-temporal public-call preparation or state/boundary/sender-KV handoff is +the owner, +`cb.fabric-boundary-guardrails` for source/static guardrails and deletion checks, +`cb.fabric-compiler-stress-test` when changed semantics must be proven before throughput resumes. diff --git a/skills/cb.fabric-workflow-router/SKILL.md b/skills/cb.fabric-workflow-router/SKILL.md new file mode 100644 index 00000000..c42b7eab --- /dev/null +++ b/skills/cb.fabric-workflow-router/SKILL.md @@ -0,0 +1,177 @@ +--- +name: cb.fabric-workflow-router +description: Use when a Cortical Fabric request needs routing across compiler closure, declaration onboarding, throughput strategy, semantic extension, cell/message/readout/graph/primitive onboarding, parity, performance, cleanup, or the user's analyze -> plan -> proceed workflow. +--- + +# Fabric Workflow Router + +Use this skill to choose the right Fabric workflow before code changes. It is a router, not a replacement for the +narrow skills. + +**Announce at start:** "Using Fabric workflow router. I'll classify the Fabric work first, then use the narrow boundary skill for implementation." + +## First Read + +- `skills/cb.fabric-compiler-boundary-audit/SKILL.md` +- `skills/cb.fabric-declaration-onboarding/SKILL.md` when the request changes user-visible graph, cell, message, + readout, primitive-op, reset/init, or public declaration semantics. +- `skills/cb.fabric-boundary-guardrails/SKILL.md` when adding or editing source/static checks, deletion checks, or + no-fallback/no-compat guardrails. +- The narrow skill selected by the classifier below. +- The active progress doc for the work, when one exists. 
+ +Inspect current code rather than relying on chat memory: + +```bash +find skills -maxdepth 2 -name SKILL.md -print | sort +rg -n "primitive_row|executor_row|tensor_binding|artifact_route|output_route|memory_liveness|native_callable|fallback|compat|fixed" src/cortical/fabric tests benchmarks +``` + +## Prompt Phase Classifier + +When the user uses the common three-prompt loop, keep the phases separate: + +- **Deep dive / analyze:** build the owner table, active path map, missing compiler products, and evidence commands. + Do not optimize or refactor. +- **Come up with a plan:** choose one highest-impact legal owner, name the compiler products, tests, and rollback rule. + Do not optimize. +- **Proceed:** implement only the selected owner through the relevant narrow skill, run targeted gates, and update the + progress doc. + +If the newest prompt changes phase, obey the newest prompt. Do not continue an older implementation plan during a new +analysis-only prompt. + +## Boundary Manifest Gate + +Every non-trivial Fabric plan needs a boundary manifest before implementation. This applies to throughput, native +kernels, new cells, message rules, readout rules, primitive ops, graph constructors, public API cleanup, and skill edits +that change those workflows. + +```text +Lane: +Public declaration/spec owner: +Rows expected to change: +Rows expected to stay stable: +Bindings/routes/liveness consumed: +Reference executor: +Native strategy, if any: +Backward/reducer owner: +Unsupported typed blocker: +Old route deleted or fail-closed: +Evidence gates: +``` + +For semantic work, `Rows expected to change` is mandatory. For throughput/native work, `Rows expected to stay stable` is +mandatory. If neither can be named, the next task is analysis or compiler-extension scaffolding, not implementation. 
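+
+A small sketch of that lane-dependent check (illustrative only; the manifest lives in working notes or the progress
+doc, not as a parsed artifact, and `manifest_is_ready` is not an existing tool):
+
+```python
+def manifest_is_ready(lane: str, manifest: dict[str, str]) -> bool:
+    # Semantic work names rows expected to change; throughput/native work names rows expected to stay stable.
+    if lane == "semantic":
+        required = "Rows expected to change"
+    elif lane in {"throughput", "native"}:
+        required = "Rows expected to stay stable"
+    else:
+        # Evidence, cleanup, and guardrail lanes are gated by the narrow skill's own checklist instead.
+        return True
+    return bool(manifest.get(required, "").strip())
+
+
+assert not manifest_is_ready("throughput", {"Rows expected to change": "Axon reducer rows"})
+assert manifest_is_ready("semantic", {"Rows expected to change": "DotProduct primitive rows"})
+```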
+ +## Work Classifier + +Use the narrowest matching skill set: + +| Work | Required skills | +| --- | --- | +| Throughput strategy, fused kernel, memory/liveness, reducer, cost rule | `cb.fabric-compiler-boundary-audit`, `cb.fabric-throughput-strategy`, `cb.fabric-performance-loop`, `cb.fabric-parity-gate` | +| Registered CUDA/Triton/native callable or fused program kernel implementation | `cb.fabric-compiler-boundary-audit`, `cb.fabric-native-strategy-onboarding`, `cb.fabric-throughput-strategy`, `cb.fabric-parity-gate` | +| Parameter reducer, reverse span output, runtime-buffer lifetime, or workspace ownership | `cb.fabric-compiler-boundary-audit`, `cb.fabric-reducer-liveness`, `cb.fabric-throughput-strategy`, `cb.fabric-parity-gate` | +| Runtime/model front-end handoff, input/output adapters, boundary tensor prep, state init, sender K/V setup | `cb.fabric-compiler-boundary-audit`, `cb.fabric-runtime-front-end-handoff`, plus `cb.fabric-throughput-strategy` for performance or `cb.fabric-declaration-onboarding` for semantic changes | +| New or changed public semantics/declaration shape | `cb.fabric-compiler-boundary-audit`, `cb.fabric-declaration-onboarding`, `cb.fabric-compiler-extension`, `cb.fabric-parity-gate` | +| Pre-throughput compiler reality/stress test for changed math or tensor roles | `cb.fabric-compiler-boundary-audit`, `cb.fabric-compiler-stress-test`, `cb.fabric-declaration-onboarding`, `cb.fabric-compiler-extension`, plus the narrow semantic skill | +| New or changed primitive op/formula | `cb.fabric-compiler-boundary-audit`, `cb.fabric-declaration-onboarding`, `cb.fabric-primitive-op-onboarding`, `cb.fabric-compiler-extension`, `cb.fabric-parity-gate` | +| New or changed message rule | `cb.fabric-compiler-boundary-audit`, `cb.fabric-declaration-onboarding`, `cb.fabric-message-rule-onboarding`, `cb.fabric-compiler-extension`, `cb.fabric-parity-gate`, plus primitive/graph/readout skills if it adds those products | +| New or changed cell | `cb.fabric-compiler-boundary-audit`, `cb.fabric-declaration-onboarding`, `cb.fabric-cell-onboarding`, `cb.fabric-cell-boundaries`, `cb.fabric-compiler-extension`, `cb.fabric-parity-gate`, plus primitive/message/readout skills as needed | +| New or changed readout/output route | `cb.fabric-compiler-boundary-audit`, `cb.fabric-declaration-onboarding`, `cb.fabric-readout-rule-onboarding`, `cb.fabric-compiler-extension`, `cb.fabric-parity-gate` | +| Public API, Config, Blueprint, declaration cleanup | `cb.fabric-compiler-boundary-audit`, `cb.fabric-public-api-cleanup`, `cb.fabric-declaration-onboarding`, plus graph/cell/message/readout skills as needed | +| Graph/topology constructor or graph-fact work | `cb.fabric-compiler-boundary-audit`, `cb.fabric-declaration-onboarding`, `cb.fabric-graph-onboarding`, `cb.fabric-backend-boundaries`, `cb.fabric-parity-gate` | +| Boundary review or closure check | `cb.fabric-compiler-boundary-audit` | +| Source/static guardrail, legacy deletion check, no-fallback/no-compat check | `cb.fabric-boundary-guardrails`, plus the narrow skill whose invariant is being guarded | +| Skill edits | `cb.fabric-skill-maintenance`, `skill-creator` | + +Always include `cb.fabric-compiler-boundary-audit` when code changes can affect compiler ownership. + +## Routing Rules + +- If public meaning changes, route to `cb.fabric-declaration-onboarding` and then the narrow semantic onboarding skill + first. 
Public meaning includes graph facts, message math, cell transition math, readout/output semantics, tensor roles, + reset/tape behavior, parameter reducers, and gradient contracts. +- If a future task adds or changes a cell, message rule, readout rule, graph constructor, or primitive op, enforce the + same boundary shape for all of them: public declaration -> normalized spec -> rows/bindings/routes -> verifier -> + reference executor -> optional native strategy -> backward/reducer -> parity/guardrails. Do not let one surface be + defined by a side `.cuh`, config field, route selector, or benchmark helper while another uses the compiler path. +- If the user asks for a compiler stress test, formula perturbation, dot-product/math update before throughput, or proof + that adding new semantics is local, route to `cb.fabric-compiler-stress-test`. Do not treat that pass as throughput + work even if the stress target came from a performance branch. +- If primitive rows and public semantics stay stable, route to throughput strategy. Throughput may change fusion, launch + shape, layout, workspace, liveness, aliasing, checkpoint/recompute, reducers, or cost policy only for already verified + rows. +- Before a "Proceed" implementation, classify the lane and keep it stable for that pass: + - **semantic lane:** declaration/spec, lowering, rows, bindings, reference, legality, backward/reducer; + - **strategy lane:** registered implementation over unchanged rows; + - **native lane:** CUDA/Triton/C++ ABI for existing rows; + - **evidence lane:** owner table, parity, perf, docs only; + - **cleanup lane:** delete stale public/backend routes after the replacement is active. + If the work crosses lanes, finish the current pass or replan before editing. +- If the work enters CUDA/Triton/C++ native strategy code, use `cb.fabric-native-strategy-onboarding` and prove the + strategy consumes existing compiler rows directly before writing kernels. +- If a throughput plan needs a missing op, tensor role, artifact route, output route, backward contract, or reducer, + stop and reclassify as compiler extension. +- If a throughput/native plan changes whether tensors are returned, reduced, accumulated, aliased, saved, or dropped, + route through `cb.fabric-reducer-liveness`; every retained tensor needs a compiler consumer row. +- If a performance owner appears before the registered temporal/native program entry, route through + `cb.fabric-runtime-front-end-handoff`. Public-call preparation may create boundary/state tensors, but backend runtime + buffers, workspace, strategy selection, and primitive formulas must move behind compiler-owned rows or fail closed. +- If cleanup removes a legacy route, prove the compiler-owned replacement is active and covered first. Do not replace an + old path with a RuntimeError bridge; delete the route or make unsupported declarations fail closed through legality. +- If the task adds a source/static guardrail, use `cb.fabric-boundary-guardrails`. Pair each negative stale-route ban + with a positive compiler product check, and do not use the guardrail as a substitute for parity or active-route + evidence. +- If cleanup touches `Config`, `Blueprint`, public constructors, or declaration normalization, route to + `cb.fabric-public-api-cleanup`. The result must assign fields to graph, cell, message, readout, reset, initialization, + or planner owners; it must not keep a broad config object as hidden compiler truth. 
+- If a performance probe does not move active owner time, launch count, storage identity, or peak memory as intended, + record it as rejected and do not keep expanding the same shortcut. + +## Required Preflight + +Before implementation, fill this in the progress doc or working notes: + +```text +Phase: +Surface: +Declaration/spec owner: +Primitive/graph rows: +Tensor/parameter/artifact/output bindings: +Forward owner: +Backward/reducer owner: +Memory/liveness owner: +Old route to delete or fail-close: +Parity/perf gates: +Rollback/narrowing rule: +Guardrail invariant: +``` + +If any required owner cannot be named, the next task is to add that compiler product, not to tune a kernel or wrapper. + +## Required Handoff Packet + +Before handing from analyze -> plan or plan -> proceed, write this in the active progress doc or working notes: + +```text +Lane: +Rows/fingerprints expected to stay stable: +Rows/bindings/routes expected to change: +Native ABI touched: +Semantic stress/locality test: +Parity matrix: +Perf owner row: +Delete/fail-close target: +Source guardrail target: +``` + +For throughput, `Rows/fingerprints expected to stay stable` is mandatory. For semantic work, `Rows/bindings/routes +expected to change` is mandatory. If both are unknown, do not implement. + +## Closeout + +- Update the active progress doc with accepted and rejected work. +- Run the narrow tests required by the selected skill. +- State which skill owns the next step. diff --git a/src/cortical/fabric/README.md b/src/cortical/fabric/README.md index ee220333..1dac40e3 100644 --- a/src/cortical/fabric/README.md +++ b/src/cortical/fabric/README.md @@ -204,7 +204,8 @@ Cell files own cell-local declarations and math: - `fabric/cells/`: user-facing cell declarations and population construction - `fabric/backend/cell_specs.py`: planner-facing transition metadata - `fabric/backend/pytorch/cells/`: PyTorch reference cell math -- `fabric/backend/cuda/cells/`: CUDA-side local cell declarations and local math +- `fabric/backend/cuda/transition_execution/` and `fabric/backend/cuda/sequence_surface/`: + compiler-selected CUDA primitive executors for supported transition rows Cells do not choose schedules, launch shapes, GEMM families, graph policy, message execution, output aggregation, workspace policy, or fallback behavior. @@ -225,6 +226,9 @@ fabric.message_rules.DotProduct( ``` The supported public `message_rules` rule is `DotProduct`, configured on `Blueprint.message_passing`. +Like cells, message rules lower through a backend spec registry. `DotProduct` is the currently registered +`MessageRuleBackendSpec`; adding another message rule should add a semantic declaration, a backend spec builder, primitive +bindings, and executor coverage instead of editing temporal scheduler code. 
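As a rough illustration of the registry shape the README paragraph above describes (this is not the actual Fabric API: `MessageRuleSpec`, `register_rule`, and `GatedSum` below are hypothetical stand-ins for the real `MessageRuleBackendSpec`, `register_message_rule_backend_spec_builder`, and a new public message-rule declaration), adding another message rule would pair a public declaration with a registered spec builder instead of editing temporal scheduler code:

```python
# Illustrative sketch only: NOT the Fabric API. MessageRuleSpec, register_rule,
# and GatedSum are hypothetical stand-ins for MessageRuleBackendSpec,
# register_message_rule_backend_spec_builder, and a new message-rule declaration.
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class MessageRuleSpec:
    primitive_ops: tuple[str, ...]    # primitive rows an executor must cover
    tensor_bindings: tuple[str, ...]  # static/parameter tensors the ops consume


# Registry keyed by the public declaration type; duplicate registration fails loudly.
_RULE_SPEC_BUILDERS: dict[type, Callable[[object], MessageRuleSpec]] = {}


def register_rule(rule_type: type, builder: Callable[[object], MessageRuleSpec]) -> None:
    if rule_type in _RULE_SPEC_BUILDERS:
        raise ValueError(f"message rule {rule_type.__name__} is already registered")
    _RULE_SPEC_BUILDERS[rule_type] = builder


@dataclass(frozen=True)
class GatedSum:
    """Hypothetical new public message rule, analogous to DotProduct."""

    gate_dim: int


def build_gated_sum_spec(rule: object) -> MessageRuleSpec:
    # A real builder would lower the declaration into primitive bindings and
    # executor coverage; the names below are placeholders for that contract.
    assert isinstance(rule, GatedSum)
    return MessageRuleSpec(
        primitive_ops=("linear", "sigmoid", "weighted_sum"),
        tensor_bindings=("gate_weight", "message_to_cell_weight"),
    )


register_rule(GatedSum, build_gated_sum_spec)
```

In the actual compiler the builder would also carry primitive bindings and executor coverage so the new rule lowers through the same path as `DotProduct`.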
`DotProduct` declares: diff --git a/src/cortical/fabric/anatomy.py b/src/cortical/fabric/anatomy.py index d14b6934..3a000a1e 100644 --- a/src/cortical/fabric/anatomy.py +++ b/src/cortical/fabric/anatomy.py @@ -1,14 +1,25 @@ from __future__ import annotations -import math from dataclasses import dataclass -from functools import lru_cache -from itertools import product +from typing import TYPE_CHECKING import torch from cortical.fabric.config import Config -from cortical.fabric.graph import GraphTopology, build_graph_topology, normalize_node_indices +from cortical.fabric.graph import GraphTopology, build_graph_topology +from cortical.fabric.graphs.lattice_anatomy import ( + assign_lattice_cells, + build_lattice_coords, + build_lattice_kv_groups, + build_lattice_local_sender_table, + build_lattice_ports, + build_lattice_slot_init, + build_lattice_sparse_graph, + lattice_anatomy_config_from_graph, +) + +if TYPE_CHECKING: + from cortical.fabric.backend.message_rules import MessageRuleIR @dataclass(frozen=True) @@ -21,6 +32,9 @@ class AnatomySpec: local_valid: torch.Tensor local_distance: torch.Tensor local_delay: torch.Tensor | None + full_local_sender_idx: torch.Tensor + recurrent_local_sender_idx: torch.Tensor + output_local_sender_idx: torch.Tensor neighbor_idx: torch.Tensor neighbor_valid: torch.Tensor edge_type: torch.Tensor @@ -41,17 +55,19 @@ class Spec: input_cell_idx: torch.Tensor output_cell_idx: torch.Tensor slot_init: torch.Tensor + message_rule: MessageRuleIR | None = None def init(config: Config | None = None, **kwargs) -> Spec: cfg = config if config is not None else Config(**kwargs) - coords = _build_coords(cfg) - population_names = tuple(cfg.cell_populations.keys()) - input_cell_idx, output_cell_idx = _build_ports(cfg, coords) + lattice_cfg = lattice_anatomy_config_from_graph(cfg.graph, cfg) + coords = build_lattice_coords(lattice_cfg) + population_names = tuple(cfg.populations.cell_populations.keys()) + input_cell_idx, output_cell_idx = build_lattice_ports(lattice_cfg, coords) recurrent_cell_idx = _build_recurrent_cells(coords.shape[0], input_cell_idx, output_cell_idx) cell_layout = -torch.ones(coords.shape[0], dtype=torch.long) - cell_layout[recurrent_cell_idx] = _assign_cells( - cfg, + cell_layout[recurrent_cell_idx] = assign_lattice_cells( + lattice_cfg, recurrent_cell_idx=recurrent_cell_idx, recurrent_coords=coords.index_select(0, recurrent_cell_idx), population_names=population_names, @@ -66,15 +82,46 @@ def init(config: Config | None = None, **kwargs) -> Spec: edge_type, edge_distance, edge_delay, - ) = _build_sparse_graph( - cfg, + ) = build_lattice_sparse_graph( + lattice_cfg, coords, input_cell_idx=input_cell_idx, output_cell_idx=output_cell_idx, ) - kv_group_id, num_kv_groups = _build_kv_groups(cfg, coords) - slot_init = _build_slot_init( - cfg, + kv_group_id, num_kv_groups = build_lattice_kv_groups(lattice_cfg, coords) + sender_mask = torch.ones(coords.shape[0], dtype=torch.bool) + sender_mask[output_cell_idx] = False + sender_cell_idx = torch.nonzero(sender_mask, as_tuple=False).reshape(-1) + sender_lookup = torch.full((coords.shape[0],), -1, dtype=torch.long) + sender_lookup[sender_cell_idx] = torch.arange(sender_cell_idx.numel(), dtype=torch.long) + full_local_sender_idx = build_lattice_local_sender_table( + receiver_coords=coords, + sender_lookup=sender_lookup, + local_offsets=local_offsets, + local_valid=local_valid, + coord_shape=lattice_cfg.coord_shape, + wrap=lattice_cfg.wrap, + ) + recurrent_local_valid = local_valid.index_select(0, recurrent_cell_idx) + 
recurrent_local_sender_idx = build_lattice_local_sender_table( + receiver_coords=coords.index_select(0, recurrent_cell_idx), + sender_lookup=sender_lookup, + local_offsets=local_offsets, + local_valid=recurrent_local_valid, + coord_shape=lattice_cfg.coord_shape, + wrap=lattice_cfg.wrap, + ) + output_local_valid = local_valid.index_select(0, output_cell_idx) + output_local_sender_idx = build_lattice_local_sender_table( + receiver_coords=coords.index_select(0, output_cell_idx), + sender_lookup=sender_lookup, + local_offsets=local_offsets, + local_valid=output_local_valid, + coord_shape=lattice_cfg.coord_shape, + wrap=lattice_cfg.wrap, + ) + slot_init = build_lattice_slot_init( + lattice_cfg, coords, cell_layout, recurrent_cell_idx, @@ -90,12 +137,18 @@ def init(config: Config | None = None, **kwargs) -> Spec: local_valid=local_valid, local_distance=local_distance, local_delay=local_delay, + full_local_sender_idx=full_local_sender_idx, + recurrent_local_sender_idx=recurrent_local_sender_idx, + output_local_sender_idx=output_local_sender_idx, neighbor_idx=neighbor_idx, neighbor_valid=neighbor_valid, edge_type=edge_type, edge_distance=edge_distance, edge_delay=edge_delay, - metadata={"shape": cfg.coord_shape, "wrap": cfg.wrap}, + metadata={ + "coord_shape": tuple(int(size) for size in lattice_cfg.coord_shape), + "wrap": bool(lattice_cfg.wrap), + }, ) graph_topology = build_graph_topology( node_count=coords.shape[0], @@ -122,467 +175,6 @@ def init(config: Config | None = None, **kwargs) -> Spec: ) -def _build_coords(cfg: Config) -> torch.Tensor: - axes = [torch.arange(size, dtype=torch.float32) for size in cfg.coord_shape] - return torch.cartesian_prod(*axes) - - -def _assign_cells( - cfg: Config, - *, - recurrent_cell_idx: torch.Tensor, - recurrent_coords: torch.Tensor, - population_names: tuple[str, ...], -) -> torch.Tensor: - num_cells = recurrent_coords.shape[0] - if cfg.population_node_indices is not None: - return _assign_explicit_population_nodes( - cfg, - recurrent_cell_idx=recurrent_cell_idx, - population_names=population_names, - ) - weights = torch.tensor([cfg.population_mix[name] for name in population_names], dtype=torch.float64) - weights = weights / weights.sum() - expected = weights * float(num_cells) - counts = expected.floor().to(torch.long) - remainder = int(num_cells - counts.sum().item()) - if remainder > 0: - frac = expected - counts.to(expected.dtype) - order = torch.argsort(frac, descending=True) - counts[order[:remainder]] += 1 - labels = [] - for population_idx, count in enumerate(counts.tolist()): - labels.extend([population_idx] * count) - layout = torch.tensor(labels, dtype=torch.long) - if cfg.cell_arrangement == "x_bands": - order = torch.argsort(_lexsort_key(recurrent_coords)) - arranged = torch.empty(num_cells, dtype=torch.long) - arranged[order] = layout - return arranged - gen = torch.Generator(device="cpu") - gen.manual_seed(cfg.seed) - perm = torch.randperm(num_cells, generator=gen) - return layout.index_select(0, perm) - - -def _assign_explicit_population_nodes( - cfg: Config, - *, - recurrent_cell_idx: torch.Tensor, - population_names: tuple[str, ...], -) -> torch.Tensor: - assert cfg.population_node_indices is not None - recurrent_nodes = [int(idx) for idx in recurrent_cell_idx.tolist()] - global_to_recurrent = {node: local_idx for local_idx, node in enumerate(recurrent_nodes)} - layout = -torch.ones(len(recurrent_nodes), dtype=torch.long) - for population_idx, population_name in enumerate(population_names): - for node in 
cfg.population_node_indices[population_name]: - recurrent_idx = global_to_recurrent.get(int(node)) - if recurrent_idx is None: - raise ValueError( - f"population {population_name!r} targets non-recurrent node {int(node)}; " - "population nodes must exclude input and output boundary nodes" - ) - layout[recurrent_idx] = population_idx - missing = torch.nonzero(layout < 0, as_tuple=False).reshape(-1) - if missing.numel() > 0: - missing_nodes = recurrent_cell_idx.index_select(0, missing[:8]).tolist() - raise ValueError( - "population_node_indices must cover every recurrent node exactly once; " - f"missing recurrent nodes={missing_nodes}" - ) - return layout - - -def _build_sparse_graph( - cfg: Config, - coords: torch.Tensor, - *, - input_cell_idx: torch.Tensor, - output_cell_idx: torch.Tensor, -) -> tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor | None, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor | None, -]: - if cfg.graph_edges is not None: - return _build_explicit_sparse_graph( - cfg, - coords, - input_cell_idx=input_cell_idx, - output_cell_idx=output_cell_idx, - ) - num_cells = coords.shape[0] - coord_dim = coords.shape[1] - coords_long = coords.to(torch.long) - shape = cfg.coord_shape - shape_tensor = torch.tensor(shape, dtype=torch.long) - stride_tensor = torch.tensor(_flat_index_strides(shape), dtype=torch.long) - input_mask = torch.zeros(num_cells, dtype=torch.bool) - input_mask[input_cell_idx] = True - output_mask = torch.zeros(num_cells, dtype=torch.bool) - output_mask[output_cell_idx] = True - recv_idx = torch.arange(num_cells, dtype=torch.long) - local_offsets = _neighbor_offsets(coord_dim=coord_dim, min_distance=0.0, max_distance=cfg.local_radius) - patch_offsets = ( - _neighbor_offsets(coord_dim=coord_dim, min_distance=cfg.patch_min_dist, max_distance=cfg.patch_max_dist) - if cfg.patch_edges_per_cell > 0 - else () - ) - local_offset_tensor = torch.tensor([delta for delta, _distance in local_offsets], dtype=torch.int32) - local_valid = torch.zeros(num_cells, len(local_offsets), dtype=torch.bool) - local_distance = torch.tensor([distance for _delta, distance in local_offsets], dtype=coords.dtype) - local_delay = ( - torch.tensor( - [_edge_delay(distance=distance, cfg=cfg) for _delta, distance in local_offsets], - dtype=torch.int32, - ) - if cfg.max_delay is not None - else None - ) - max_slots = len(local_offsets) + cfg.patch_edges_per_cell - if max_slots == 0: - raise ValueError("fabric graph has no edges; increase local_radius or change anatomy size") - - neighbor_idx = torch.zeros(num_cells, max_slots, dtype=torch.long) - neighbor_valid = torch.zeros(num_cells, max_slots, dtype=torch.bool) - edge_type = torch.zeros(num_cells, max_slots, dtype=torch.long) - edge_distance = torch.zeros(num_cells, max_slots, dtype=coords.dtype) - edge_delay = torch.ones(num_cells, max_slots, dtype=torch.long) if cfg.max_delay is not None else None - neighbor_counts = torch.zeros(num_cells, dtype=torch.long) - patch_counts = torch.zeros(num_cells, dtype=torch.long) if cfg.patch_edges_per_cell > 0 else None - max_neighbors = 0 - for offset_idx, (delta, distance) in enumerate(local_offsets): - max_neighbors = _append_offset_neighbors( - coords=coords_long, - shape_tensor=shape_tensor, - stride_tensor=stride_tensor, - wrap=cfg.wrap, - recv_idx=recv_idx, - input_mask=input_mask, - output_mask=output_mask, - delta=delta, - distance=distance, - edge_kind=0, - patch_limit=None, - neighbor_idx=neighbor_idx, - neighbor_valid=neighbor_valid, - 
edge_type=edge_type, - edge_distance=edge_distance, - edge_delay=edge_delay, - neighbor_counts=neighbor_counts, - patch_counts=patch_counts, - delay_value=_edge_delay(distance=distance, cfg=cfg), - current_max_neighbors=max_neighbors, - offset_valid=local_valid, - offset_slot=offset_idx, - ) - for delta, distance in patch_offsets: - max_neighbors = _append_offset_neighbors( - coords=coords_long, - shape_tensor=shape_tensor, - stride_tensor=stride_tensor, - wrap=cfg.wrap, - recv_idx=recv_idx, - input_mask=input_mask, - output_mask=output_mask, - delta=delta, - distance=distance, - edge_kind=1, - patch_limit=cfg.patch_edges_per_cell, - neighbor_idx=neighbor_idx, - neighbor_valid=neighbor_valid, - edge_type=edge_type, - edge_distance=edge_distance, - edge_delay=edge_delay, - neighbor_counts=neighbor_counts, - patch_counts=patch_counts, - delay_value=_edge_delay(distance=distance, cfg=cfg), - current_max_neighbors=max_neighbors, - ) - - if max_neighbors == 0: - raise ValueError("fabric graph has no edges; increase local_radius or change anatomy size") - edge_delay_out = edge_delay[:, :max_neighbors] if edge_delay is not None else None - return ( - local_offset_tensor, - local_valid, - local_distance, - local_delay, - neighbor_idx[:, :max_neighbors], - neighbor_valid[:, :max_neighbors], - edge_type[:, :max_neighbors], - edge_distance[:, :max_neighbors], - edge_delay_out, - ) - - -def _build_explicit_sparse_graph( - cfg: Config, - coords: torch.Tensor, - *, - input_cell_idx: torch.Tensor, - output_cell_idx: torch.Tensor, -) -> tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor | None, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor | None, -]: - assert cfg.graph_edges is not None - num_cells = coords.shape[0] - coord_dim = coords.shape[1] - input_mask = torch.zeros(num_cells, dtype=torch.bool) - input_mask[input_cell_idx] = True - output_mask = torch.zeros(num_cells, dtype=torch.bool) - output_mask[output_cell_idx] = True - receivers = torch.tensor([receiver for receiver, _sender in cfg.graph_edges], dtype=torch.long) - senders = torch.tensor([sender for _receiver, sender in cfg.graph_edges], dtype=torch.long) - if bool(input_mask.index_select(0, receivers).any()): - raise ValueError("graph_edges must not target input boundary nodes") - if bool(output_mask.index_select(0, senders).any()): - raise ValueError("graph_edges must not use output boundary nodes as senders") - direct_input_to_output = output_mask.index_select(0, receivers) & input_mask.index_select(0, senders) - if bool(direct_input_to_output.any()): - raise ValueError("graph_edges must not connect input boundary nodes directly into output boundary nodes") - degree = torch.bincount(receivers, minlength=num_cells) - max_neighbors = int(degree.max().item()) if degree.numel() > 0 else 0 - if max_neighbors == 0: - raise ValueError("fabric graph has no edges") - neighbor_idx = torch.zeros(num_cells, max_neighbors, dtype=torch.long) - neighbor_valid = torch.zeros(num_cells, max_neighbors, dtype=torch.bool) - edge_type = torch.ones(num_cells, max_neighbors, dtype=torch.long) - edge_distance = torch.ones(num_cells, max_neighbors, dtype=coords.dtype) - edge_delay = torch.ones(num_cells, max_neighbors, dtype=torch.long) if cfg.max_delay is not None else None - write_pos = torch.zeros(num_cells, dtype=torch.long) - for receiver, sender in cfg.graph_edges: - col = int(write_pos[receiver].item()) - neighbor_idx[receiver, col] = int(sender) - neighbor_valid[receiver, col] = True - write_pos[receiver] 
+= 1 - local_offsets = torch.empty(0, coord_dim, dtype=torch.int32) - local_valid = torch.zeros(num_cells, 0, dtype=torch.bool) - local_distance = torch.empty(0, dtype=coords.dtype) - local_delay = torch.empty(0, dtype=torch.int32) if cfg.max_delay is not None else None - return ( - local_offsets, - local_valid, - local_distance, - local_delay, - neighbor_idx, - neighbor_valid, - edge_type, - edge_distance, - edge_delay, - ) - - -def _append_offset_neighbors( - *, - coords: torch.Tensor, - shape_tensor: torch.Tensor, - stride_tensor: torch.Tensor, - wrap: bool, - recv_idx: torch.Tensor, - input_mask: torch.Tensor, - output_mask: torch.Tensor, - delta: tuple[int, ...], - distance: float, - edge_kind: int, - patch_limit: int | None, - neighbor_idx: torch.Tensor, - neighbor_valid: torch.Tensor, - edge_type: torch.Tensor, - edge_distance: torch.Tensor, - edge_delay: torch.Tensor | None, - neighbor_counts: torch.Tensor, - patch_counts: torch.Tensor | None, - delay_value: int, - current_max_neighbors: int, - offset_valid: torch.Tensor | None = None, - offset_slot: int | None = None, -) -> int: - send_idx, valid = _resolve_offset_indices( - coords=coords, - shape_tensor=shape_tensor, - stride_tensor=stride_tensor, - wrap=wrap, - delta=delta, - ) - valid = valid & ~input_mask - valid = valid & (send_idx != recv_idx) - valid = valid & ~output_mask.index_select(0, send_idx) - valid = valid & ~(output_mask & input_mask.index_select(0, send_idx)) - if patch_limit is not None and patch_counts is not None: - valid = valid & (patch_counts < patch_limit) - if current_max_neighbors > 0: - selected_idx = neighbor_idx[:, :current_max_neighbors] - selected_valid = neighbor_valid[:, :current_max_neighbors] - duplicate = ((selected_idx == send_idx.unsqueeze(1)) & selected_valid).any(dim=1) - valid = valid & ~duplicate - rows = torch.nonzero(valid, as_tuple=False).reshape(-1) - if rows.numel() == 0: - return current_max_neighbors - cols = neighbor_counts.index_select(0, rows) - send_rows = send_idx.index_select(0, rows) - neighbor_idx[rows, cols] = send_rows - neighbor_valid[rows, cols] = True - edge_type[rows, cols] = edge_kind - edge_distance[rows, cols] = distance - if offset_valid is not None and offset_slot is not None: - offset_valid[rows, offset_slot] = True - if edge_delay is not None: - edge_delay[rows, cols] = delay_value - neighbor_counts[rows] += 1 - if patch_limit is not None and patch_counts is not None: - patch_counts[rows] += 1 - return max(current_max_neighbors, int(neighbor_counts[rows].max().item())) - - -def _resolve_offset_indices( - *, - coords: torch.Tensor, - shape_tensor: torch.Tensor, - stride_tensor: torch.Tensor, - wrap: bool, - delta: tuple[int, ...], -) -> tuple[torch.Tensor, torch.Tensor]: - shifted = coords + torch.tensor(delta, dtype=torch.long) - if wrap: - shifted = torch.remainder(shifted, shape_tensor) - valid = torch.ones(coords.shape[0], dtype=torch.bool) - else: - valid = ((shifted >= 0) & (shifted < shape_tensor)).all(dim=1) - shifted = torch.where(valid.unsqueeze(1), shifted, torch.zeros_like(shifted)) - return (shifted * stride_tensor).sum(dim=1), valid - - -@lru_cache(maxsize=None) -def _flat_index_strides(shape: tuple[int, ...]) -> tuple[int, ...]: - strides: list[int] = [] - acc = 1 - for size in reversed(shape[1:]): - acc *= int(size) - strides.append(acc) - return tuple(list(reversed(strides)) + [1]) - - -@lru_cache(maxsize=None) -def _neighbor_offsets( - *, - coord_dim: int, - min_distance: float, - max_distance: float, -) -> tuple[tuple[tuple[int, ...], float], 
...]: - radius = math.ceil(max_distance) - tol = 1e-6 - offsets: list[tuple[tuple[int, ...], float]] = [] - for delta in product(range(-radius, radius + 1), repeat=coord_dim): - if all(step == 0 for step in delta): - continue - distance = math.sqrt(sum(step * step for step in delta)) - if distance + tol < min_distance or distance > max_distance + tol: - continue - offsets.append((tuple(int(step) for step in delta), distance)) - offsets.sort(key=lambda item: (item[1], tuple(abs(step) for step in item[0]), item[0])) - return tuple(offsets) - - -def _edge_delay(*, distance: float, cfg: Config) -> int: - if cfg.conduction_speed is None or cfg.max_delay is None: - return 1 - delay = int(math.ceil(distance / cfg.conduction_speed)) - return max(1, min(cfg.max_delay, delay)) - - -def _pairwise_distances(coords: torch.Tensor, shape: tuple[int, ...], *, wrap: bool) -> torch.Tensor: - diffs = (coords[:, None, :] - coords[None, :, :]).abs() - if wrap: - shape_tensor = torch.tensor(shape, dtype=coords.dtype).view(1, 1, -1) - diffs = torch.minimum(diffs, shape_tensor - diffs) - return torch.linalg.vector_norm(diffs, dim=-1) - - -def _lexsort_key(coords: torch.Tensor) -> torch.Tensor: - strides = [] - acc = 1 - max_vals = coords.max(dim=0).values.to(torch.long) + 1 - for size in reversed(max_vals[1:].tolist()): - acc *= int(size) - strides.append(acc) - strides = list(reversed(strides)) + [1] - key = torch.zeros(coords.shape[0], dtype=torch.long) - coords_long = coords.to(torch.long) - for axis, stride in enumerate(strides): - key = key + coords_long[:, axis] * stride - return key - - -def _build_kv_groups(cfg: Config, coords: torch.Tensor) -> tuple[torch.Tensor, int]: - if cfg.kv_group_ids is not None: - kv_group_id = torch.tensor(cfg.kv_group_ids, dtype=torch.long) - num_groups = int(kv_group_id.max().item()) + 1 if kv_group_id.numel() > 0 else 0 - return kv_group_id, num_groups - if cfg.projection_region_shape is None: - region_shape = tuple(max(1, size // 4) for size in cfg.coord_shape) - else: - region_shape = cfg.projection_region_shape - region = torch.div(coords.to(torch.long), torch.tensor(region_shape, dtype=torch.long), rounding_mode="floor") - grid_dims = [(size + tile - 1) // tile for size, tile in zip(cfg.coord_shape, region_shape, strict=True)] - strides = [] - acc = 1 - for size in reversed(grid_dims[1:]): - acc *= size - strides.append(acc) - strides = list(reversed(strides)) + [1] - kv_group_id = torch.zeros(region.shape[0], dtype=torch.long) - for axis in range(region.shape[1]): - kv_group_id = kv_group_id + region[:, axis] * strides[axis] - num_groups = int(kv_group_id.max().item()) + 1 if kv_group_id.numel() > 0 else 0 - return kv_group_id.to(torch.long), num_groups - - -def _build_ports(cfg: Config, coords: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - num_cells = coords.shape[0] - explicit_input = normalize_node_indices( - cfg.input_cell_indices, - node_count=num_cells, - name="input_cell_indices", - ) - explicit_output = normalize_node_indices( - cfg.output_cell_indices, - node_count=num_cells, - name="output_cell_indices", - ) - x_coord = coords[:, 0] - if explicit_input is None: - input_mask = x_coord < float(cfg.input_band_width) - input_idx = torch.nonzero(input_mask, as_tuple=False).reshape(-1).to(torch.long) - else: - input_idx = explicit_input - if explicit_output is None: - output_mask = x_coord >= float(cfg.width - cfg.output_band_width) - output_idx = torch.nonzero(output_mask, as_tuple=False).reshape(-1).to(torch.long) - else: - output_idx = explicit_output - if 
input_idx.numel() == 0 or output_idx.numel() == 0: - raise ValueError("port construction produced an empty input or output port set") - if bool(torch.isin(input_idx, output_idx).any()): - raise ValueError("input and output port cells must be disjoint") - return input_idx, output_idx - - def _build_recurrent_cells(num_cells: int, input_idx: torch.Tensor, output_idx: torch.Tensor) -> torch.Tensor: recurrent_mask = torch.ones(num_cells, dtype=torch.bool) recurrent_mask[input_idx] = False @@ -593,35 +185,4 @@ def _build_recurrent_cells(num_cells: int, input_idx: torch.Tensor, output_idx: return recurrent_idx -def _build_slot_init( - cfg: Config, - coords: torch.Tensor, - cell_layout: torch.Tensor, - recurrent_idx: torch.Tensor, - input_idx: torch.Tensor, - output_idx: torch.Tensor, -) -> torch.Tensor: - shape = torch.tensor(cfg.coord_shape, dtype=coords.dtype) - coords_norm = coords / shape.view(1, -1).clamp_min(1.0) - sin_feat = torch.sin(2.0 * math.pi * coords_norm) - cos_feat = torch.cos(2.0 * math.pi * coords_norm) - num_populations = len(cfg.cell_populations) - population_one_hot = torch.zeros(coords.shape[0], num_populations, dtype=coords.dtype) - population_one_hot[recurrent_idx] = torch.nn.functional.one_hot( - cell_layout[recurrent_idx], num_classes=num_populations - ).to(coords.dtype) - input_mask = torch.zeros(coords.shape[0], 1, dtype=coords.dtype) - input_mask[input_idx] = 1.0 - output_mask = torch.zeros(coords.shape[0], 1, dtype=coords.dtype) - output_mask[output_idx] = 1.0 - base = torch.cat([coords_norm, sin_feat, cos_feat, population_one_hot, input_mask, output_mask], dim=-1) - d_slot = int(cfg.d_slot if cfg.d_slot is not None else 2 * cfg.hidden_size) - repeats = math.ceil(d_slot / base.shape[1]) - slot = base.repeat(1, repeats)[:, :d_slot] - gen = torch.Generator(device="cpu") - gen.manual_seed(cfg.seed + 17) - noise = 0.01 * torch.randn(coords.shape[0], d_slot, generator=gen, dtype=coords.dtype) - return slot + noise - - __all__ = ["AnatomySpec", "Spec", "init"] diff --git a/src/cortical/fabric/backend/__init__.py b/src/cortical/fabric/backend/__init__.py index 3a1973e1..0cd68bd5 100644 --- a/src/cortical/fabric/backend/__init__.py +++ b/src/cortical/fabric/backend/__init__.py @@ -1,15 +1,34 @@ import cortical.fabric.backend.cell_specs as _cell_specs # noqa: F401 +import cortical.fabric.backend.message_rule_specs as _message_rule_specs # noqa: F401 +import cortical.fabric.backend.readout_rule_specs as _readout_rule_specs # noqa: F401 from cortical.fabric.backend.buckets import FabricBucket, ReceiverKind from cortical.fabric.backend.caps import DeviceCaps, detect_device_caps from cortical.fabric.backend.cell_backend import ( CellBackendSpec, CellTransitionIR, + CompiledTransitionParameterBinding, + CompiledTransitionPrimitiveOp, + CompiledTransitionProgram, SurfaceBackendVariant, TensorSchema, TransitionOp, build_cell_backend_spec, + compile_transition_program, ) from cortical.fabric.backend.ir import FabricIR, compile_fabric_ir +from cortical.fabric.backend.message_rules import ( + CompiledMessagePrimitiveOp, + CompiledMessageRule, + MessageOpPrimitiveBinding, + MessageRuleBackendSpec, + MessageRuleIR, + build_message_rule_backend_spec, + build_message_rule_ir, + compile_message_rule, + ordered_message_rule_backend_spec_types, + register_message_rule_backend_spec_builder, + registered_message_rule_backend_spec_types, +) from cortical.fabric.backend.plan_cache import ( FabricGraphCaptureCache, FabricPlanCache, @@ -32,6 +51,20 @@ cuda_nn_primitive_backward_behaviors, ) from 
cortical.fabric.backend.reuse import ExecutionFamily, MathBackend, ReuseScope +from cortical.fabric.backend.readout_rules import ( + CompiledReadoutPrimitiveOp, + CompiledReadoutRule, + ReadoutRuleBackendSpec, + ReadoutRuleIR, + ReadoutRuleNativeExecutorSpec, + ReadoutRuleStaticTensorSpec, + build_readout_rule_backend_spec, + compile_readout_rule, + default_readout_rule_ir, + readout_rule_native_executor, + register_readout_rule_backend_spec_builder, + registered_readout_rule_backend_spec_lowering_kinds, +) from cortical.fabric.backend.selector import select_fabric_backend from cortical.fabric.backend.surfaces import ( SUPPORTED_BACKEND_SURFACES, @@ -41,6 +74,19 @@ supported_surface_for_cell_type, ) from cortical.fabric.backend.tape import TapeMode, TapePolicy, default_tape_policy +from cortical.fabric.backend.temporal_plan import ( + SequenceSurfaceRoute, + TemporalBackwardWindowPlan, + TemporalBoundaryPlan, + TemporalCarryPlan, + TemporalCheckpointPlan, + TemporalExecutionPlan, + TemporalExecutorPlan, + TemporalGradientBoundaryPlan, + TemporalOutputRequestPlan, + TemporalSchedulePlan, + TemporalSubstratePlan, +) from cortical.fabric.backend.types import FabricBackendName, FabricEngineRequest, FabricEngineResult from cortical.fabric.backend.workspace import GraphCaptureWorkspace, WorkspacePlan, WorkspacePlanner @@ -49,6 +95,13 @@ "BackendExecutionRecord", "CellTransitionIR", "CellBackendSpec", + "CompiledTransitionParameterBinding", + "CompiledTransitionPrimitiveOp", + "CompiledTransitionProgram", + "CompiledReadoutPrimitiveOp", + "CompiledReadoutRule", + "CompiledMessagePrimitiveOp", + "CompiledMessageRule", "DeviceCaps", "ExecutionFamily", "FabricBackendName", @@ -64,6 +117,9 @@ "GraphCaptureWorkspace", "GraphCaptureCacheKey", "MathBackend", + "MessageOpPrimitiveBinding", + "MessageRuleBackendSpec", + "MessageRuleIR", "ParamGradBinding", "PhysicalBackwardOpPlan", "PhysicalBackwardPlan", @@ -72,23 +128,51 @@ "PlannedFabricExecution", "PrimitiveBackwardBehavior", "ReceiverKind", + "ReadoutRuleBackendSpec", + "ReadoutRuleIR", + "ReadoutRuleNativeExecutorSpec", + "ReadoutRuleStaticTensorSpec", "ReuseScope", + "SequenceSurfaceRoute", "SurfaceBackendVariant", "SUPPORTED_BACKEND_SURFACES", "SupportedSurface", "TensorSchema", "TapeMode", "TapePolicy", + "TemporalBackwardWindowPlan", + "TemporalBoundaryPlan", + "TemporalCarryPlan", + "TemporalCheckpointPlan", + "TemporalExecutionPlan", + "TemporalExecutorPlan", + "TemporalGradientBoundaryPlan", + "TemporalOutputRequestPlan", + "TemporalSchedulePlan", + "TemporalSubstratePlan", "TransitionOp", "WorkspacePlan", "WorkspacePlanner", "build_cell_backend_spec", + "build_message_rule_backend_spec", + "build_message_rule_ir", + "build_readout_rule_backend_spec", + "compile_transition_program", "compile_fabric_ir", + "compile_message_rule", + "compile_readout_rule", "cuda_nn_callable_primitives", "cuda_nn_primitive_backward_behavior", "cuda_nn_primitive_backward_behaviors", "default_tape_policy", + "default_readout_rule_ir", "detect_device_caps", + "readout_rule_native_executor", + "register_message_rule_backend_spec_builder", + "register_readout_rule_backend_spec_builder", + "ordered_message_rule_backend_spec_types", + "registered_message_rule_backend_spec_types", + "registered_readout_rule_backend_spec_lowering_kinds", "select_fabric_backend", "supported_surface_by_key", "supported_surface_for_cell_type", diff --git a/src/cortical/fabric/backend/buckets.py b/src/cortical/fabric/backend/buckets.py index 79ab60f8..09966b36 100644 --- 
a/src/cortical/fabric/backend/buckets.py +++ b/src/cortical/fabric/backend/buckets.py @@ -17,6 +17,9 @@ class FabricBucket: bucket_id: int receiver_kind: ReceiverKind population_name: str + population_index: int | None + transition_signature: tuple[str, ...] + parameter_binding: str dim_signature: tuple[int, int, int, int, int] receiver_count: int degree_bin: str @@ -31,8 +34,9 @@ class FabricBucket: @property def signature(self) -> tuple[object, ...]: return ( - self.population_name, self.receiver_kind.value, + self.transition_signature, + self.parameter_binding, self.degree_bin, self.dim_signature, self.delay_depth, @@ -40,3 +44,19 @@ def signature(self) -> tuple[object, ...]: self.sharing_pattern, self.has_sparse_overlay, ) + + @property + def planner_signature(self) -> tuple[object, ...]: + return ( + self.receiver_kind.value, + self.transition_signature, + self.receiver_count, + self.degree_bin, + self.degree_min, + self.degree_max, + self.dim_signature, + self.delay_depth, + self.stencil_template_id, + self.sharing_pattern, + self.has_sparse_overlay, + ) diff --git a/src/cortical/fabric/backend/cell_backend.py b/src/cortical/fabric/backend/cell_backend.py index fb42219d..6753d4ed 100644 --- a/src/cortical/fabric/backend/cell_backend.py +++ b/src/cortical/fabric/backend/cell_backend.py @@ -3,6 +3,7 @@ from collections.abc import Callable from dataclasses import dataclass +from cortical.fabric.backend.primitives import is_callable_cuda_nn_primitive from cortical.fabric.backend.reuse import ExecutionFamily, MathBackend, ReuseScope @@ -47,6 +48,53 @@ class TransitionParameterBinding: kind: str = "cell_param" +@dataclass(frozen=True) +class CompiledTransitionParameterBinding: + parameter: str + bindings: tuple[TransitionParameterBinding, ...] + + +@dataclass(frozen=True) +class CompiledTransitionPrimitiveOp: + op_index: int + primitive: str + source_op: str + inputs: tuple[str, ...] + outputs: tuple[str, ...] + attributes: tuple[tuple[str, str], ...] = () + parameter_inputs: tuple[str, ...] = () + + @property + def name(self) -> str: + return self.primitive + + +@dataclass(frozen=True) +class CompiledTransitionProgram: + binding_slot: int + lowering_kind: str + state_inputs: tuple[str, ...] + message_inputs: tuple[str, ...] + parameter_inputs: tuple[str, ...] + state_outputs: tuple[str, ...] + public_outputs: tuple[str, ...] + recompute_outputs: tuple[str, ...] + backward_decomposition: tuple[str, ...] + private_state_schema: tuple[TensorSchema, ...] + public_interface_schema: tuple[TensorSchema, ...] + parameter_schema: tuple[TensorSchema, ...] + parameter_bindings: tuple[CompiledTransitionParameterBinding, ...] + primitive_ops: tuple[CompiledTransitionPrimitiveOp, ...] 
+ + @property + def ops(self) -> tuple[CompiledTransitionPrimitiveOp, ...]: + return self.primitive_ops + + @property + def primitive_names(self) -> tuple[str, ...]: + return tuple(dict.fromkeys(op.primitive for op in self.primitive_ops)) + + @dataclass(frozen=True) class SurfaceBackendVariant: execution_family: ExecutionFamily @@ -101,3 +149,57 @@ def build_cell_backend_spec( head_dim=head_dim, value_dim=value_dim, ) + + +def compile_transition_program( + spec: CellBackendSpec, + *, + binding_slot: int, +) -> CompiledTransitionProgram: + transition_ir = spec.transition_ir + if not transition_ir.ops: + raise ValueError("Unsupported Fabric transition program: no executable primitive rows were lowered") + primitive_ops: list[CompiledTransitionPrimitiveOp] = [] + for op_index, op in enumerate(transition_ir.ops): + primitive = str(op.name) + if not is_callable_cuda_nn_primitive(primitive): + raise ValueError( + f"Unsupported Fabric transition op {primitive!r}; add the op to fabric.cuda.nn lowering before using it" + ) + parameter_inputs = tuple( + str(input_name) for input_name in op.inputs if str(input_name) in spec.transition_parameter_bindings + ) + primitive_ops.append( + CompiledTransitionPrimitiveOp( + op_index=int(op_index), + primitive=primitive, + source_op=primitive, + inputs=tuple(str(input_name) for input_name in op.inputs), + outputs=tuple(str(output_name) for output_name in op.outputs), + attributes=( + ("source_op", primitive), + ("op_index", str(int(op_index))), + ) + + tuple((str(key), str(value)) for key, value in op.attributes), + parameter_inputs=parameter_inputs, + ) + ) + return CompiledTransitionProgram( + binding_slot=int(binding_slot), + lowering_kind="primitive_program:" + ",".join(op.primitive for op in primitive_ops), + state_inputs=tuple(str(item) for item in transition_ir.state_inputs), + message_inputs=tuple(str(item) for item in transition_ir.message_inputs), + parameter_inputs=tuple(str(item) for item in transition_ir.parameter_inputs), + state_outputs=tuple(str(item) for item in transition_ir.state_outputs), + public_outputs=tuple(str(item) for item in transition_ir.public_outputs), + recompute_outputs=tuple(str(item) for item in transition_ir.recompute_outputs), + backward_decomposition=tuple(str(item) for item in transition_ir.backward_decomposition), + private_state_schema=spec.private_state_schema, + public_interface_schema=spec.public_interface_schema, + parameter_schema=spec.parameter_schema, + parameter_bindings=tuple( + CompiledTransitionParameterBinding(str(parameter), tuple(bindings)) + for parameter, bindings in sorted(spec.transition_parameter_bindings.items()) + ), + primitive_ops=tuple(primitive_ops), + ) diff --git a/src/cortical/fabric/backend/cell_specs.py b/src/cortical/fabric/backend/cell_specs.py index 616389cf..b2d4c4ec 100644 --- a/src/cortical/fabric/backend/cell_specs.py +++ b/src/cortical/fabric/backend/cell_specs.py @@ -37,6 +37,8 @@ def build_slstm_cell_backend_spec( ), parameter_schema=( TensorSchema("receiver_query", "parameter", ("receiver", "head_dim"), ReuseScope.RECEIVER_LOCAL), + TensorSchema("value_to_state_weight", "parameter", ("hidden", "message"), ReuseScope.GROUP_SHARED), + TensorSchema("recurrent_bias", "parameter", ("hidden",), ReuseScope.GROUP_SHARED), TensorSchema("gate_weight", "parameter", ("receiver", "transition_in", "gate"), ReuseScope.RECEIVER_LOCAL), TensorSchema("recurrent_kernel", "parameter", ("receiver", "hidden", "gate"), ReuseScope.RECEIVER_LOCAL), TensorSchema("bias", "parameter", ("receiver", "gate"), 
ReuseScope.RECEIVER_LOCAL), @@ -51,16 +53,29 @@ def build_slstm_cell_backend_spec( transition_ir=CellTransitionIR( state_inputs=("y", "c", "n", "m"), message_inputs=("aggregated_message",), - parameter_inputs=("gate_weight", "recurrent_kernel", "bias", "outnorm_weight"), + parameter_inputs=( + "value_to_state_weight", + "recurrent_bias", + "gate_weight", + "recurrent_kernel", + "bias", + "outnorm_weight", + "outnorm_eps", + ), ops=( - TransitionOp("linear", ("aggregated_message", "gate_weight", "bias"), ("gate_logits",)), + TransitionOp( + "linear", + ("aggregated_message", "value_to_state_weight", "recurrent_bias"), + ("transition_input",), + ), + TransitionOp("linear", ("transition_input", "gate_weight", "bias"), ("gate_logits",)), TransitionOp("matmul", ("y", "recurrent_kernel"), ("recurrent_gate_logits",)), TransitionOp( "gated_logspace_recurrence", ("gate_logits", "recurrent_gate_logits", "c", "n", "m"), ("next_y", "next_c", "next_n", "next_m"), ), - TransitionOp("norm_or_identity", ("next_y", "outnorm_weight"), ("public_y",)), + TransitionOp("norm_or_identity", ("next_y", "outnorm_weight", "outnorm_eps"), ("public_y",)), ), state_outputs=("next_y", "next_c", "next_n", "next_m"), public_outputs=("public_y",), @@ -79,6 +94,15 @@ def build_slstm_cell_backend_spec( ), }, transition_parameter_bindings={ + "value_to_state_weight": ( + TransitionParameterBinding("fused_recurrent_value_to_cell_weight", kind="static_tensor"), + TransitionParameterBinding("message_to_cell_weight", kind="static_tensor"), + TransitionParameterBinding("value_to_cell_weight", kind="static_tensor"), + ), + "recurrent_bias": ( + TransitionParameterBinding("fused_recurrent_cell_bias", kind="static_tensor"), + TransitionParameterBinding("recurrent_cell_bias", kind="static_tensor"), + ), "gate_weight": (TransitionParameterBinding("gate_weight"),), "recurrent_kernel": (TransitionParameterBinding("recurrent_kernel"),), "bias": (TransitionParameterBinding("bias"),), @@ -86,6 +110,8 @@ def build_slstm_cell_backend_spec( }, reuse_scopes={ "receiver_query": ReuseScope.RECEIVER_LOCAL, + "value_to_state_weight": ReuseScope.GROUP_SHARED, + "recurrent_bias": ReuseScope.GROUP_SHARED, "gate_weight": ReuseScope.RECEIVER_LOCAL, "recurrent_kernel": ReuseScope.RECEIVER_LOCAL, "bias": ReuseScope.RECEIVER_LOCAL, @@ -136,6 +162,7 @@ def build_axon_cell_backend_spec( TensorSchema("w2", "parameter", ("receiver", "hidden"), ReuseScope.RECEIVER_LOCAL), TensorSchema("out_proj_weight", "parameter", ("receiver", "preproj", "hidden"), ReuseScope.RECEIVER_LOCAL), TensorSchema("out_proj_bias", "parameter", ("receiver", "hidden"), ReuseScope.RECEIVER_LOCAL), + TensorSchema("outnorm_weight", "parameter", ("receiver", "hidden"), ReuseScope.RECEIVER_LOCAL), TensorSchema( "input_proj_weight", "parameter", @@ -167,6 +194,8 @@ def build_axon_cell_backend_spec( "w2", "out_proj_weight", "out_proj_bias", + "outnorm_weight", + "outnorm_eps", ), ops=( TransitionOp( @@ -207,7 +236,8 @@ def build_axon_cell_backend_spec( "next_E_w2_c2", ), ), - TransitionOp("linear", ("preproj", "out_proj_weight", "out_proj_bias"), ("public_y",)), + TransitionOp("linear", ("preproj", "out_proj_weight", "out_proj_bias"), ("public_y_raw",)), + TransitionOp("norm_or_identity", ("public_y_raw", "outnorm_weight", "outnorm_eps"), ("public_y",)), ), state_outputs=( "next_hc1", @@ -234,6 +264,7 @@ def build_axon_cell_backend_spec( transition_parameter_bindings={ "input_proj_weight": ( TransitionParameterBinding("fused_recurrent_value_to_cell_weight", kind="static_tensor"), + 
TransitionParameterBinding("message_to_cell_weight", kind="static_tensor"), TransitionParameterBinding("value_to_cell_weight", kind="expanded_transposed_static_tensor"), ), "recurrent_cell_bias": ( @@ -261,6 +292,7 @@ def build_axon_cell_backend_spec( TransitionParameterBinding("out_proj_weight"), ), "out_proj_bias": (TransitionParameterBinding("out_proj_bias"),), + "outnorm_weight": (TransitionParameterBinding("outnorm_weight"),), }, reuse_scopes={ "nu_log": ReuseScope.RECEIVER_LOCAL, @@ -269,6 +301,7 @@ def build_axon_cell_backend_spec( "w2": ReuseScope.RECEIVER_LOCAL, "out_proj_weight": ReuseScope.RECEIVER_LOCAL, "out_proj_bias": ReuseScope.RECEIVER_LOCAL, + "outnorm_weight": ReuseScope.RECEIVER_LOCAL, "input_proj_weight": ReuseScope.RECEIVER_LOCAL, "recurrent_cell_bias": ReuseScope.RECEIVER_LOCAL, }, diff --git a/src/cortical/fabric/backend/cuda/__init__.py b/src/cortical/fabric/backend/cuda/__init__.py index b0caa198..bdec2fc8 100644 --- a/src/cortical/fabric/backend/cuda/__init__.py +++ b/src/cortical/fabric/backend/cuda/__init__.py @@ -1,51 +1,3 @@ -from cortical.fabric.backend.cuda import cells -from cortical.fabric.backend.cuda.execution import ( - ExecutionVariantSpec, - FabricExecutionRequest, - pack_tensor_tree, - run_registered_execution, -) -from cortical.fabric.backend.cuda.message_passing.local_message_cuda import ( - fabric_local_message_backward_receiver_cuda, - fabric_local_message_backward_sender_cuda, - fabric_local_message_cuda, - fabric_local_message_partitioned_backward_fused_cuda, - fabric_local_message_partitioned_backward_receiver_cuda, - fabric_local_message_partitioned_backward_sender_cuda, - fabric_local_message_partitioned_cuda, -) -from cortical.fabric.backend.cuda.message_passing.sparse_message_cuda import ( - fabric_sparse_message_backward_receiver_cuda, - fabric_sparse_message_backward_sender_cuda, - fabric_sparse_message_cuda, - fabric_sparse_message_partitioned_backward_receiver_cuda, - fabric_sparse_message_partitioned_backward_sender_cuda, - fabric_sparse_message_partitioned_cuda, -) -from cortical.fabric.backend.cuda.projection.grouped_projection_cuda import ( - fabric_grouped_projection_cuda, - fabric_grouped_projection_forward_cuda, -) +from __future__ import annotations -__all__ = [ - "fabric_grouped_projection_cuda", - "fabric_grouped_projection_forward_cuda", - "fabric_local_message_backward_receiver_cuda", - "fabric_local_message_backward_sender_cuda", - "fabric_local_message_cuda", - "fabric_local_message_partitioned_backward_fused_cuda", - "fabric_local_message_partitioned_backward_receiver_cuda", - "fabric_local_message_partitioned_backward_sender_cuda", - "fabric_local_message_partitioned_cuda", - "FabricExecutionRequest", - "ExecutionVariantSpec", - "cells", - "pack_tensor_tree", - "run_registered_execution", - "fabric_sparse_message_backward_receiver_cuda", - "fabric_sparse_message_backward_sender_cuda", - "fabric_sparse_message_cuda", - "fabric_sparse_message_partitioned_backward_receiver_cuda", - "fabric_sparse_message_partitioned_backward_sender_cuda", - "fabric_sparse_message_partitioned_cuda", -] +__all__: list[str] = [] diff --git a/src/cortical/fabric/backend/cuda/cells/__init__.py b/src/cortical/fabric/backend/cuda/cells/__init__.py deleted file mode 100644 index f92dc490..00000000 --- a/src/cortical/fabric/backend/cuda/cells/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from __future__ import annotations - -import cortical.fabric.backend.cuda.cells.axon as _axon # noqa: F401 -import cortical.fabric.backend.cuda.cells.slstm as _slstm 
# noqa: F401 - -__all__: list[str] = [] diff --git a/src/cortical/fabric/backend/cuda/cells/axon.cuh b/src/cortical/fabric/backend/cuda/cells/axon.cuh deleted file mode 100644 index 35171136..00000000 --- a/src/cortical/fabric/backend/cuda/cells/axon.cuh +++ /dev/null @@ -1,296 +0,0 @@ -#pragma once - -#include - -#include -#include - -#include "cortical/fabric/backend/cuda/contracts/cell.cuh" - -namespace fabric { - -struct Axon { - static constexpr int kReductionStatsDim = 0; - static constexpr bool kSupportsEmitOnlyStateUpdate = false; - - __device__ static float act_forward(int activation_id, float z) { - switch (activation_id) { - case 0: { - const float s = 1.0f / (1.0f + expf(-z)); - return z * s; - } - case 1: - return z > 0.0f ? z : 0.0f; - case 2: - return tanhf(z); - default: - return z; - } - } - - static int state_static_bytes_host(const std::vector& params, int receivers) { - (void)receivers; - TORCH_CHECK(params.size() >= 4, "Axon CUDA state staging requires dynamics params"); - const int hidden = static_cast(params[0].size(1)); - return static_cast(static_cast(4 * hidden) * sizeof(float)); - } - - static int emit_static_bytes_host(const std::vector& params, int receivers) { - (void)params; - (void)receivers; - return static_cast(sizeof(int32_t)); - } - - static fabric::cuda::nn::CellTransitionIR cell_transition_ir_host( - const std::vector& params, - int receivers, - int projected_message_dim, - int raw_public_dim) { - (void)receivers; - TORCH_CHECK(params.size() >= 5, "Axon CUDA transition IR requires dynamics and activation params"); - const int64_t hidden = params[0].size(1); - fabric::cuda::nn::Builder builder; - builder.private_state("state", 3); - builder.public_tensor("public", 3); - builder.diagonal_recurrence( - fabric::cuda::nn::DiagonalRecurrenceKind::ComplexExponential2D, - fabric::cuda::nn::DiagonalRecurrenceBinding{ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 0, - 1, - 2, - 3, - 4, - }, - hidden, - projected_message_dim, - raw_public_dim, - fabric::cuda::nn::ResetPolicy::ZeroSourceRows); - return builder.build_cell_transition(); - } - - __device__ static size_t state_static_bytes(int receiver, const TensorTable& params) { - (void)receiver; - const auto nu = tensor_ref(params, 0); - const int hidden = static_cast(nu.size[1]); - return static_cast(4 * hidden) * sizeof(float); - } - - __device__ static void stage_state_static( - int receiver, - const TensorTable& params, - void* smem, - int lane, - int lane_stride) { - const auto nu = tensor_ref(params, 0); - const auto theta = tensor_ref(params, 1); - const auto w1 = tensor_ref(params, 2); - const auto w2 = tensor_ref(params, 3); - float* dst = reinterpret_cast(smem); - const int hidden = static_cast(nu.size[1]); - for (int h = lane; h < hidden; h += lane_stride) { - dst[h] = nu.at(receiver, h); - dst[hidden + h] = theta.at(receiver, h); - dst[2 * hidden + h] = w1.at(receiver, h); - dst[3 * hidden + h] = w2.at(receiver, h); - } - } - - __device__ static size_t emit_static_bytes(int receiver, const TensorTable& params) { - (void)receiver; - (void)params; - return sizeof(int32_t); - } - - __device__ static void stage_emit_static( - int receiver, - const TensorTable& params, - void* smem, - int lane, - int lane_stride) { - (void)receiver; - for (int scalar = lane; scalar < 1; scalar += lane_stride) { - *reinterpret_cast(smem) = tensor_ref(params, 4).at(0); - } - } - - __device__ static void init_reduction_stats(float* stats) { - (void)stats; - } - - __device__ static void combine_reduction_stats(float* dst, 
const float* src) { - (void)dst; - (void)src; - } - - __device__ static void finalize_reduction_stats( - float* stats, - int state_dim, - int lane, - int lane_stride) { - (void)stats; - (void)state_dim; - (void)lane; - (void)lane_stride; - } - - __device__ static int state_dim( - int projected_message_dim, - int raw_public_dim, - const TensorTable& state_prev) { - (void)projected_message_dim; - (void)raw_public_dim; - const auto hc1_prev = tensor_ref(state_prev, 0); - return static_cast(hc1_prev.size[2]); - } - - __device__ static void forward_state_chunk( - int b, - int state_receiver, - int param_receiver, - int aux_receiver, - bool reset_row, - int lane, - int lane_stride, - const void* staged_static, - const TensorTable& params, - const TensorTable& state_prev, - TensorTable& state_next, - const float* projected_message, - int projected_message_dim, - int h0, - int h_count, - float* reduction_stats, - TensorTable* aux) { - (void)aux; - (void)aux_receiver; - (void)projected_message_dim; - (void)reduction_stats; - const auto hc1_prev = tensor_ref(state_prev, 0); - const auto hc2_prev = tensor_ref(state_prev, 1); - const auto E_nu_c1_prev = tensor_ref(state_prev, 2); - const auto E_nu_c2_prev = tensor_ref(state_prev, 3); - const auto E_th_c1_prev = tensor_ref(state_prev, 4); - const auto E_th_c2_prev = tensor_ref(state_prev, 5); - const auto E_w1_c1_prev = tensor_ref(state_prev, 6); - const auto E_w1_c2_prev = tensor_ref(state_prev, 7); - const auto E_w2_c1_prev = tensor_ref(state_prev, 8); - const auto E_w2_c2_prev = tensor_ref(state_prev, 9); - auto next_hc1 = tensor_ref(state_next, 0); - auto next_hc2 = tensor_ref(state_next, 1); - auto next_E_nu_c1 = tensor_ref(state_next, 2); - auto next_E_nu_c2 = tensor_ref(state_next, 3); - auto next_E_th_c1 = tensor_ref(state_next, 4); - auto next_E_th_c2 = tensor_ref(state_next, 5); - auto next_E_w1_c1 = tensor_ref(state_next, 6); - auto next_E_w1_c2 = tensor_ref(state_next, 7); - auto next_E_w2_c1 = tensor_ref(state_next, 8); - auto next_E_w2_c2 = tensor_ref(state_next, 9); - const bool use_staged_static = staged_static != nullptr; - const float* staged = reinterpret_cast(staged_static); - const int hidden = static_cast(hc1_prev.size[2]); - const auto nu_ref = tensor_ref(params, 0); - const auto theta_ref = tensor_ref(params, 1); - const auto w1_ref = tensor_ref(params, 2); - const auto w2_ref = tensor_ref(params, 3); - const float one_minus = reset_row ? 0.0f : 1.0f; - const int h_end = min(hidden, h0 + h_count); - for (int h = h0 + lane; h < h_end; h += lane_stride) { - const float xval = projected_message[h]; - const float nu = use_staged_static ? staged[h] : nu_ref.at(param_receiver, h); - const float thl = use_staged_static ? staged[hidden + h] : theta_ref.at(param_receiver, h); - const float w1 = use_staged_static ? staged[2 * hidden + h] : w1_ref.at(param_receiver, h); - const float w2 = use_staged_static ? 
staged[3 * hidden + h] : w2_ref.at(param_receiver, h); - const float exp_nu = expf(nu); - const float r = expf(-exp_nu); - const float theta = expf(thl); - float sin_theta; - float cos_theta; - sincosf(theta, &sin_theta, &cos_theta); - const float g = r * cos_theta; - const float phi = r * sin_theta; - const float r2 = r * r; - const float gamma = sqrtf(fmaxf(1.0f - r2, 0.0f)); - const float d_g_d_nu = -exp_nu * g; - const float d_phi_d_nu = -exp_nu * phi; - const float d_gamma_d_nu = exp_nu * r2 / fmaxf(gamma, 1e-20f); - const float exp_th = expf(thl); - const float d_g_d_th = -phi * exp_th; - const float d_phi_d_th = g * exp_th; - const float hc1_p = hc1_prev.at(b, state_receiver, h) * one_minus; - const float hc2_p = hc2_prev.at(b, state_receiver, h) * one_minus; - const float E_nu_c1_p = E_nu_c1_prev.at(b, state_receiver, h) * one_minus; - const float E_nu_c2_p = E_nu_c2_prev.at(b, state_receiver, h) * one_minus; - const float E_th_c1_p = E_th_c1_prev.at(b, state_receiver, h) * one_minus; - const float E_th_c2_p = E_th_c2_prev.at(b, state_receiver, h) * one_minus; - const float E_w1_c1_p = E_w1_c1_prev.at(b, state_receiver, h) * one_minus; - const float E_w1_c2_p = E_w1_c2_prev.at(b, state_receiver, h) * one_minus; - const float E_w2_c1_p = E_w2_c1_prev.at(b, state_receiver, h) * one_minus; - const float E_w2_c2_p = E_w2_c2_prev.at(b, state_receiver, h) * one_minus; - const float u1 = w1 * xval; - const float u2 = w2 * xval; - const float c1 = fmaf(gamma, u1, g * hc1_p - phi * hc2_p); - const float c2 = fmaf(gamma, u2, g * hc2_p + phi * hc1_p); - next_hc1.at(b, state_receiver, h) = c1; - next_hc2.at(b, state_receiver, h) = c2; - next_E_w1_c1.at(b, state_receiver, h) = fmaf(gamma, xval, g * E_w1_c1_p - phi * E_w1_c2_p); - next_E_w1_c2.at(b, state_receiver, h) = g * E_w1_c2_p + phi * E_w1_c1_p; - next_E_w2_c2.at(b, state_receiver, h) = fmaf(gamma, xval, g * E_w2_c2_p + phi * E_w2_c1_p); - next_E_w2_c1.at(b, state_receiver, h) = g * E_w2_c1_p - phi * E_w2_c2_p; - next_E_nu_c1.at(b, state_receiver, h) = - d_g_d_nu * hc1_p + g * E_nu_c1_p - d_phi_d_nu * hc2_p - phi * E_nu_c2_p + d_gamma_d_nu * u1; - next_E_nu_c2.at(b, state_receiver, h) = - d_g_d_nu * hc2_p + g * E_nu_c2_p + d_phi_d_nu * hc1_p + phi * E_nu_c1_p + d_gamma_d_nu * u2; - next_E_th_c1.at(b, state_receiver, h) = - d_g_d_th * hc1_p + g * E_th_c1_p - d_phi_d_th * hc2_p - phi * E_th_c2_p; - next_E_th_c2.at(b, state_receiver, h) = - d_g_d_th * hc2_p + g * E_th_c2_p + d_phi_d_th * hc1_p + phi * E_th_c1_p; - } - } - - __device__ static void emit_public_chunk( - int b, - int state_receiver, - int param_receiver, - int lane, - int lane_stride, - const void* staged_static, - const TensorTable& params, - const TensorTable& state_next, - float* raw_public_out, - int raw_public_dim, - int h0, - int h_count, - const float* reduced_stats) { - (void)params; - (void)param_receiver; - (void)reduced_stats; - const auto next_hc1 = tensor_ref(state_next, 0); - const auto next_hc2 = tensor_ref(state_next, 1); - const int hidden = static_cast(next_hc1.size[2]); - const bool use_staged_static = staged_static != nullptr; - const int activation_id = use_staged_static - ? *reinterpret_cast(staged_static) - : tensor_ref(params, 4).at(0); - const int h_end = min(raw_public_dim, h0 + h_count); - for (int out_h = h0 + lane; out_h < h_end; out_h += lane_stride) { - if (out_h < hidden) { - raw_public_out[out_h] = act_forward(activation_id, next_hc1.at(b, state_receiver, out_h)); - } else { - const int h = out_h - hidden; - raw_public_out[out_h] = h < hidden ? 
act_forward(activation_id, next_hc2.at(b, state_receiver, h)) : 0.0f; - } - } - } -}; - -} // namespace fabric diff --git a/src/cortical/fabric/backend/cuda/cells/axon.py b/src/cortical/fabric/backend/cuda/cells/axon.py deleted file mode 100644 index e1cfe201..00000000 --- a/src/cortical/fabric/backend/cuda/cells/axon.py +++ /dev/null @@ -1,10 +0,0 @@ -from cortical.fabric.contracts.cells import CellBackendImplementation -from cortical.fabric.registry.cell_backends import register_cell_backend_implementation - -register_cell_backend_implementation( - CellBackendImplementation( - cell_type="axoncell", - backend_name="cuda", - metadata={"native_cell_kind": 1}, - ) -) diff --git a/src/cortical/fabric/backend/cuda/cells/axon_registration.cu b/src/cortical/fabric/backend/cuda/cells/axon_registration.cu deleted file mode 100644 index f08cae82..00000000 --- a/src/cortical/fabric/backend/cuda/cells/axon_registration.cu +++ /dev/null @@ -1,18 +0,0 @@ -#include "cortical/fabric/backend/cuda/cells/axon.cuh" -#include "cortical/fabric/backend/cuda/registry/cell_registration_helpers.cuh" - -namespace fabric { - -namespace { - -struct AxonDispatchRegistration { - AxonDispatchRegistration() { - register_cell_core_dispatch_entry(1, make_cell_core_dispatch_entry()); - } -}; - -AxonDispatchRegistration kAxonDispatchRegistration; - -} // namespace - -} // namespace fabric diff --git a/src/cortical/fabric/backend/cuda/cells/slstm.cuh b/src/cortical/fabric/backend/cuda/cells/slstm.cuh deleted file mode 100644 index 5fa496fd..00000000 --- a/src/cortical/fabric/backend/cuda/cells/slstm.cuh +++ /dev/null @@ -1,332 +0,0 @@ -#pragma once - -#include - -#include -#include - -#include "cortical/fabric/backend/cuda/contracts/cell.cuh" - -namespace fabric { - -struct SLSTM { - static constexpr int kReductionStatsDim = 2; - static constexpr bool kSupportsEmitOnlyStateUpdate = true; - - __device__ static float sigmoidf_approx(float x) { - return 1.0f / (1.0f + expf(-x)); - } - - __device__ static float log_sigmoidf_approx(float x) { - return -log1pf(expf(-x)); - } - - static int state_static_bytes_host(const std::vector& params, int receivers) { - (void)params; - (void)receivers; - return 0; - } - - static int emit_static_bytes_host(const std::vector& params, int receivers) { - (void)receivers; - TORCH_CHECK(params.size() >= 5, "SLSTM CUDA emit staging requires outnorm params"); - const int hidden = static_cast(params[3].size(1)); - return static_cast(static_cast(hidden + 1) * sizeof(float)); - } - - static fabric::cuda::nn::CellTransitionIR cell_transition_ir_host( - const std::vector& params, - int receivers, - int projected_message_dim, - int raw_public_dim) { - (void)raw_public_dim; - TORCH_CHECK(params.size() >= 7, "SLSTM CUDA dense state affine lowering requires folded params"); - const int hidden = static_cast(params[5].size(1)); - const int gate_dim = static_cast(params[0].size(2)); - fabric::cuda::nn::Builder builder; - builder.private_state("state", 3); - builder.public_tensor("public", 3); - builder.parameter("input_gate_weight", 3); - builder.parameter("recurrent_gate_weight", 3); - builder.parameter("input_gate_bias", 2); - builder.state_affine( - fabric::cuda::nn::StateAffineSourceKind::ProjectedMessage, - -1, - 0, - 6, - receivers, - projected_message_dim, - gate_dim, - 1, - fabric::cuda::nn::ResetPolicy::None); - builder.state_affine( - fabric::cuda::nn::StateAffineSourceKind::StatePrev, - 0, - 5, - -1, - receivers, - hidden, - gate_dim, - 1, - fabric::cuda::nn::ResetPolicy::ZeroSourceRows); - 
builder.reduction_boundary(kReductionStatsDim); - return builder.build_cell_transition(); - } - - __device__ static size_t state_static_bytes(int receiver, const TensorTable& params) { - (void)receiver; - (void)params; - return 0; - } - - __device__ static void stage_state_static( - int receiver, - const TensorTable& params, - void* smem, - int lane, - int lane_stride) { - (void)receiver; - (void)params; - (void)smem; - (void)lane; - (void)lane_stride; - } - - __device__ static size_t emit_static_bytes(int receiver, const TensorTable& params) { - (void)receiver; - const auto outnorm_weight = tensor_ref(params, 3); - const int hidden = static_cast(outnorm_weight.size[1]); - return static_cast(hidden + 1) * sizeof(float); - } - - __device__ static void stage_emit_static( - int receiver, - const TensorTable& params, - void* smem, - int lane, - int lane_stride) { - const auto outnorm_weight = tensor_ref(params, 3); - float* dst = reinterpret_cast(smem); - const int hidden = static_cast(outnorm_weight.size[1]); - for (int h = lane; h < hidden; h += lane_stride) { - dst[h] = outnorm_weight.at(receiver, h); - } - if (hidden % lane_stride == lane) { - dst[hidden] = tensor_ref(params, 4).at(0); - } - } - - __device__ static void init_reduction_stats(float* stats) { - stats[0] = 0.0f; - stats[1] = 0.0f; - } - - __device__ static void accumulate_reduction_stats(float* stats, float y_new) { - stats[0] += y_new; - stats[1] += y_new * y_new; - } - - __device__ static void combine_reduction_stats(float* dst, const float* src) { - dst[0] += src[0]; - dst[1] += src[1]; - } - - __device__ static void finalize_reduction_stats( - float* stats, - int state_dim, - int lane, - int lane_stride) { - (void)stats; - (void)state_dim; - (void)lane; - (void)lane_stride; - } - - __device__ static int state_dim( - int projected_message_dim, - int raw_public_dim, - const TensorTable& state_prev) { - (void)projected_message_dim; - (void)raw_public_dim; - const auto y_prev = tensor_ref(state_prev, 0); - return static_cast(y_prev.size[2]); - } - - __device__ static void forward_state_chunk( - int b, - int state_receiver, - int param_receiver, - int aux_receiver, - bool reset_row, - int lane, - int lane_stride, - const void* staged_static, - const TensorTable& params, - const TensorTable& state_prev, - TensorTable& state_next, - const float* projected_message, - int projected_message_dim, - int h0, - int h_count, - float* reduction_stats, - TensorTable* aux) { - const auto y_prev = tensor_ref(state_prev, 0); - const auto c_prev = tensor_ref(state_prev, 1); - const auto n_prev = tensor_ref(state_prev, 2); - const auto m_prev = tensor_ref(state_prev, 3); - auto next_y = tensor_ref(state_next, 0); - auto next_c = tensor_ref(state_next, 1); - auto next_n = tensor_ref(state_next, 2); - auto next_m = tensor_ref(state_next, 3); - const int hidden = static_cast(y_prev.size[2]); - (void)staged_static; - (void)params; - (void)param_receiver; - (void)projected_message; - (void)projected_message_dim; - const bool use_dense_state_affines = aux != nullptr && aux->count >= 1; - if (!use_dense_state_affines) { - return; - } - const auto gate_affine = tensor_ref(*aux, 0); - init_reduction_stats(reduction_stats); - const int h_end = min(hidden, h0 + h_count); - for (int h = h0 + lane; h < h_end; h += lane_stride) { - const float ibar = gate_affine.at(b, aux_receiver, h); - const float fbar = gate_affine.at(b, aux_receiver, hidden + h); - const float zbar = gate_affine.at(b, aux_receiver, 2 * hidden + h); - const float obar = gate_affine.at(b, 
aux_receiver, 3 * hidden + h); - const float c_prev_v = reset_row ? 0.0f : c_prev.at(b, state_receiver, h); - const float n_prev_v = reset_row ? 0.0f : n_prev.at(b, state_receiver, h); - const float m_prev_v = reset_row ? 0.0f : m_prev.at(b, state_receiver, h); - const float logfplusm = m_prev_v + log_sigmoidf_approx(fbar); - const bool is_first = n_prev_v == 0.0f; - const float m_new = is_first ? ibar : fmaxf(ibar, logfplusm); - const float i = expf(fminf(ibar - m_new, 0.0f)); - const float f = expf(fminf(logfplusm - m_new, 0.0f)); - const float z = tanhf(zbar); - const float o = sigmoidf_approx(obar); - const float c_new = f * c_prev_v + i * z; - const float n_new = f * n_prev_v + i; - const float y_new = o * (c_new / (n_new + 1e-6f)); - next_y.at(b, state_receiver, h) = y_new; - next_c.at(b, state_receiver, h) = c_new; - next_n.at(b, state_receiver, h) = n_new; - next_m.at(b, state_receiver, h) = m_new; - accumulate_reduction_stats(reduction_stats, y_new); - } - } - - __device__ static void forward_state_lane_value( - int b, - int state_receiver, - int param_receiver, - int aux_receiver, - bool reset_row, - int lane, - int lane_stride, - const void* staged_static, - const TensorTable& params, - const TensorTable& state_prev, - const float* projected_message, - int projected_message_dim, - const TensorTable& aux, - int h0, - int h_count, - float* reduction_stats, - int* h_out, - float* y_out) { - (void)lane_stride; - (void)staged_static; - (void)params; - (void)param_receiver; - (void)projected_message; - (void)projected_message_dim; - const auto c_prev = tensor_ref(state_prev, 1); - const auto n_prev = tensor_ref(state_prev, 2); - const auto m_prev = tensor_ref(state_prev, 3); - const auto gate_affine = tensor_ref(aux, 0); - const int hidden = static_cast(c_prev.size[2]); - init_reduction_stats(reduction_stats); - *h_out = -1; - *y_out = 0.0f; - const int h = h0 + lane; - if (h >= min(hidden, h0 + h_count)) { - return; - } - const float ibar = gate_affine.at(b, aux_receiver, h); - const float fbar = gate_affine.at(b, aux_receiver, hidden + h); - const float zbar = gate_affine.at(b, aux_receiver, 2 * hidden + h); - const float obar = gate_affine.at(b, aux_receiver, 3 * hidden + h); - const float c_prev_v = reset_row ? 0.0f : c_prev.at(b, state_receiver, h); - const float n_prev_v = reset_row ? 0.0f : n_prev.at(b, state_receiver, h); - const float m_prev_v = reset_row ? 0.0f : m_prev.at(b, state_receiver, h); - const float logfplusm = m_prev_v + log_sigmoidf_approx(fbar); - const bool is_first = n_prev_v == 0.0f; - const float m_new = is_first ? ibar : fmaxf(ibar, logfplusm); - const float i = expf(fminf(ibar - m_new, 0.0f)); - const float f = expf(fminf(logfplusm - m_new, 0.0f)); - const float z = tanhf(zbar); - const float o = sigmoidf_approx(obar); - const float c_new = f * c_prev_v + i * z; - const float n_new = f * n_prev_v + i; - const float y_new = o * (c_new / (n_new + 1e-6f)); - *h_out = h; - *y_out = y_new; - accumulate_reduction_stats(reduction_stats, y_new); - } - - __device__ static void emit_public_lane_value( - int receiver, - int h, - float y_value, - const void* staged_static, - const TensorTable& params, - float* raw_public_out, - const float* reduced_stats) { - const bool use_staged_static = staged_static != nullptr; - const float* outnorm_weight = reinterpret_cast(staged_static); - const auto outnorm_weight_ref = tensor_ref(params, 3); - const int hidden = static_cast(outnorm_weight_ref.size[1]); - const float outnorm_eps = use_staged_static ? 
outnorm_weight[hidden] : tensor_ref(params, 4).at(0);
-    const float inv_hidden = 1.0f / static_cast<float>(hidden > 0 ? hidden : 1);
-    const float mean = reduced_stats[0] * inv_hidden;
-    const float var = fmaxf(reduced_stats[1] * inv_hidden - mean * mean, 0.0f);
-    const float centered = y_value - mean;
-    const float outnorm = use_staged_static ? outnorm_weight[h] : outnorm_weight_ref.at(receiver, h);
-    raw_public_out[h] = centered * rsqrtf(var + outnorm_eps) * outnorm;
-  }
-
-  __device__ static void emit_public_chunk(
-      int b,
-      int state_receiver,
-      int param_receiver,
-      int lane,
-      int lane_stride,
-      const void* staged_static,
-      const TensorTable& params,
-      const TensorTable& state_next,
-      float* raw_public_out,
-      int raw_public_dim,
-      int h0,
-      int h_count,
-      const float* reduced_stats) {
-    const auto next_y = tensor_ref(state_next, 0);
-    const int hidden = static_cast<int>(next_y.size[2]);
-    const bool use_staged_static = staged_static != nullptr;
-    const float* outnorm_weight = reinterpret_cast<const float*>(staged_static);
-    const auto outnorm_weight_ref = tensor_ref(params, 3);
-    const float outnorm_eps = use_staged_static ? outnorm_weight[hidden] : tensor_ref(params, 4).at(0);
-    const float inv_hidden = 1.0f / static_cast<float>(hidden > 0 ? hidden : 1);
-    const float mean = reduced_stats[0] * inv_hidden;
-    const float var = fmaxf(reduced_stats[1] * inv_hidden - mean * mean, 0.0f);
-    const int h_end = min(min(hidden, raw_public_dim), h0 + h_count);
-    for (int h = h0 + lane; h < h_end; h += lane_stride) {
-      const float centered = next_y.at(b, state_receiver, h) - mean;
-      const float outnorm = use_staged_static ? outnorm_weight[h] : outnorm_weight_ref.at(param_receiver, h);
-      raw_public_out[h] = centered * rsqrtf(var + outnorm_eps) * outnorm;
-    }
-  }
-};
-
-} // namespace fabric
diff --git a/src/cortical/fabric/backend/cuda/cells/slstm.py b/src/cortical/fabric/backend/cuda/cells/slstm.py
deleted file mode 100644
index cf33a1a6..00000000
--- a/src/cortical/fabric/backend/cuda/cells/slstm.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from cortical.fabric.contracts.cells import CellBackendImplementation
-from cortical.fabric.registry.cell_backends import register_cell_backend_implementation
-
-register_cell_backend_implementation(
-    CellBackendImplementation(
-        cell_type="slstm",
-        backend_name="cuda",
-        metadata={"native_cell_kind": 0},
-    )
-)
diff --git a/src/cortical/fabric/backend/cuda/cells/slstm_registration.cu b/src/cortical/fabric/backend/cuda/cells/slstm_registration.cu
deleted file mode 100644
index 2ec03b42..00000000
--- a/src/cortical/fabric/backend/cuda/cells/slstm_registration.cu
+++ /dev/null
@@ -1,18 +0,0 @@
-#include "cortical/fabric/backend/cuda/cells/slstm.cuh"
-#include "cortical/fabric/backend/cuda/registry/cell_registration_helpers.cuh"
-
-namespace fabric {
-
-namespace {
-
-struct SLSTMDispatchRegistration {
-  SLSTMDispatchRegistration() {
-    register_cell_core_dispatch_entry(0, make_cell_core_dispatch_entry());
-  }
-};
-
-SLSTMDispatchRegistration kSLSTMDispatchRegistration;
-
-} // namespace
-
-} // namespace fabric
diff --git a/src/cortical/fabric/backend/cuda/contracts/cell.cuh b/src/cortical/fabric/backend/cuda/contracts/cell.cuh
index cd48a4e2..9d2cad9b 100644
--- a/src/cortical/fabric/backend/cuda/contracts/cell.cuh
+++ b/src/cortical/fabric/backend/cuda/contracts/cell.cuh
@@ -1,6 +1,6 @@
 #pragma once
 
-#include "cortical/fabric/backend/cuda/execution/common.cuh"
+#include "cortical/fabric/backend/cuda/contracts/common.cuh"
 #include "cortical/fabric/backend/cuda/nn/ir.cuh"
 
 namespace fabric {
diff --git a/src/cortical/fabric/backend/cuda/execution/common.cuh b/src/cortical/fabric/backend/cuda/contracts/common.cuh
similarity index 64%
rename from src/cortical/fabric/backend/cuda/execution/common.cuh
rename to src/cortical/fabric/backend/cuda/contracts/common.cuh
index 3d685407..331ba510 100644
--- a/src/cortical/fabric/backend/cuda/execution/common.cuh
+++ b/src/cortical/fabric/backend/cuda/contracts/common.cuh
@@ -22,6 +22,12 @@ enum class TemporalExecution : int {
   PersistentScan = 1,
 };
 
+enum class TemporalScanOwner : int {
+  SingleStep = 0,
+  BackendHostLoop = 1,
+  CudaTemporalSuperOp = 2,
+};
+
 enum class ReadoutMode : int {
   Skip = 0,
   SeparatePortOwned = 1,
@@ -151,4 +157,57 @@ struct ExecutionPlan {
   bool emit_readout;
 };
 
+struct TemporalScanDescriptor {
+  int outer_time_steps;
+  int inner_steps;
+  int physical_time_steps;
+  int emission_count;
+  int first_emission_step;
+  int emission_stride;
+  bool terminal_only;
+  TemporalScanOwner owner;
+};
+
+__host__ __device__ inline TemporalScanDescriptor make_temporal_scan_descriptor(
+    int outer_time_steps,
+    int inner_steps,
+    bool terminal_only,
+    TemporalScanOwner owner) {
+  const int safe_outer_time_steps = outer_time_steps > 0 ? outer_time_steps : 0;
+  const int safe_inner_steps = inner_steps > 0 ? inner_steps : 1;
+  const int physical_time_steps = safe_outer_time_steps * safe_inner_steps;
+  const int emission_count = terminal_only ? (safe_outer_time_steps > 0 ? 1 : 0) : safe_outer_time_steps;
+  return TemporalScanDescriptor{
+      safe_outer_time_steps,
+      safe_inner_steps,
+      physical_time_steps,
+      emission_count,
+      safe_inner_steps - 1,
+      safe_inner_steps,
+      terminal_only,
+      owner,
+  };
+}
+
+__host__ __device__ inline bool temporal_scan_emits_step(
+    const TemporalScanDescriptor& scan,
+    int physical_step) {
+  if (physical_step < 0 || physical_step >= scan.physical_time_steps) {
+    return false;
+  }
+  if (scan.terminal_only) {
+    return physical_step + 1 == scan.physical_time_steps;
+  }
+  return physical_step % scan.emission_stride == scan.first_emission_step;
+}
+
+__host__ __device__ inline int temporal_scan_output_index(
+    const TemporalScanDescriptor& scan,
+    int physical_step) {
+  if (!temporal_scan_emits_step(scan, physical_step)) {
+    return -1;
+  }
+  return scan.terminal_only ? 0 : physical_step / scan.emission_stride;
+}
+
 } // namespace fabric
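The scan helpers added above are `__host__ __device__` and header-only, so the emission schedule can be enumerated from host code. A minimal sketch, assuming the renamed `contracts/common.cuh` is on the include path, the translation unit is compiled as CUDA host code (for example with nvcc), and the step counts (8 outer steps, 4 inner steps) are illustrative only:

```cpp
#include <cstdio>

#include "cortical/fabric/backend/cuda/contracts/common.cuh"

int main() {
  // Eight emitted outer steps, each expanded into four physical inner steps.
  const fabric::TemporalScanDescriptor scan = fabric::make_temporal_scan_descriptor(
      /*outer_time_steps=*/8,
      /*inner_steps=*/4,
      /*terminal_only=*/false,
      fabric::TemporalScanOwner::CudaTemporalSuperOp);
  // physical_time_steps == 32; the last physical step of each inner window emits,
  // so steps 3, 7, ..., 31 map to output indices 0..7.
  for (int step = 0; step < scan.physical_time_steps; ++step) {
    if (fabric::temporal_scan_emits_step(scan, step)) {
      std::printf("physical step %d -> output index %d\n",
                  step,
                  fabric::temporal_scan_output_index(scan, step));
    }
  }
  return 0;
}
```

With `terminal_only=true` the same descriptor collapses to a single emission at the last physical step, so `emission_count` is 1 whenever `outer_time_steps > 0`.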
diff --git a/src/cortical/fabric/backend/cuda/execution/__init__.py b/src/cortical/fabric/backend/cuda/execution/__init__.py
deleted file mode 100644
index 650f0eac..00000000
--- a/src/cortical/fabric/backend/cuda/execution/__init__.py
+++ /dev/null
@@ -1,38 +0,0 @@
-from cortical.fabric.backend.cuda.execution import backend_registrations as _backend_registrations  # noqa: F401
-from cortical.fabric.backend.cuda.execution.dispatcher_cuda import (
-    last_forward_carry_checkpoints,
-    last_launch_metadata,
-    normalize_launch_request,
-)
-from cortical.fabric.backend.cuda.execution.output_readout_cuda import project_output_sequence_from_banks
-from cortical.fabric.backend.cuda.execution.registry import (
-    ExecutionRegistryEntry,
-    ExecutionVariantSpec,
-    FabricExecutionRequest,
-    ForwardCarryCheckpoints,
-    register_execution_backend,
-    run_registered_execution,
-)
-from cortical.fabric.backend.cuda.execution.tensor_pack import (
-    PackedTensorTable,
-    flatten_tensor_tree,
-    pack_tensor_tree,
-    rebuild_tensor_tree,
-)
-
-__all__ = [
-    "ExecutionRegistryEntry",
-    "ExecutionVariantSpec",
-    "FabricExecutionRequest",
-    "ForwardCarryCheckpoints",
-    "PackedTensorTable",
-    "flatten_tensor_tree",
-    "last_forward_carry_checkpoints",
-    "last_launch_metadata",
-    "normalize_launch_request",
-    "pack_tensor_tree",
-    "project_output_sequence_from_banks",
-    "register_execution_backend",
-    "rebuild_tensor_tree",
-    "run_registered_execution",
-]
diff --git a/src/cortical/fabric/backend/cuda/execution/backend_registrations.py b/src/cortical/fabric/backend/cuda/execution/backend_registrations.py
deleted file mode 100644
index 0c0778f3..00000000
--- a/src/cortical/fabric/backend/cuda/execution/backend_registrations.py
+++ /dev/null
@@ -1,88 +0,0 @@
-from __future__ import annotations
-
-from cortical.fabric.backend.cuda.execution.dispatcher_cuda import run_backend_dispatch_forward
-from cortical.fabric.backend.cuda.execution.registry import (
-    ExecutionRegistryEntry,
-    ExecutionVariantSpec,
-    FabricExecutionRequest,
-    register_execution_backend,
-    request_has_generic_dispatch_contract,
-)
-from cortical.fabric.backend.reuse import MathBackend
-
-
-def _supports_generic_dispatch(request: FabricExecutionRequest) -> bool:
-    return request_has_generic_dispatch_contract(request)
-
-
-def _run_receiver_owned_stepwise(request: FabricExecutionRequest) -> tuple[object, ...]:
-    return run_backend_dispatch_forward(
-        request,
-        spatial_ownership="receiver_owned",
-        temporal_execution="stepwise",
-    )
-
-
-def _run_receiver_owned_persistent_scan(request: FabricExecutionRequest) -> tuple[object, ...]:
-    return run_backend_dispatch_forward(
-        request,
-        spatial_ownership="receiver_owned",
-        temporal_execution="persistent_scan",
-    )
-
-
-def _run_edge_owned_persistent_scan(request: FabricExecutionRequest) -> tuple[object, ...]:
-    return run_backend_dispatch_forward(
-        request,
-        spatial_ownership="edge_owned",
-        temporal_execution="persistent_scan",
-    )
-
-
-register_execution_backend(
-    ExecutionRegistryEntry(
-        variant=ExecutionVariantSpec(
-            spatial_ownership="receiver_owned",
-            temporal_execution="stepwise",
-            math_backend=MathBackend.MICROKERNEL,
-        ),
-        runner=_run_receiver_owned_stepwise,
-        supports=_supports_generic_dispatch,
-    )
-)
-
-register_execution_backend(
-    ExecutionRegistryEntry(
-        variant=ExecutionVariantSpec(
-            spatial_ownership="receiver_owned",
-            temporal_execution="persistent_scan",
-            math_backend=MathBackend.MICROKERNEL,
-        ),
runner=_run_receiver_owned_persistent_scan, - supports=_supports_generic_dispatch, - ) -) - -register_execution_backend( - ExecutionRegistryEntry( - variant=ExecutionVariantSpec( - spatial_ownership="receiver_owned", - temporal_execution="persistent_scan", - math_backend=MathBackend.GROUPED_GEMM, - ), - runner=_run_receiver_owned_persistent_scan, - supports=_supports_generic_dispatch, - ) -) - -register_execution_backend( - ExecutionRegistryEntry( - variant=ExecutionVariantSpec( - spatial_ownership="edge_owned", - temporal_execution="persistent_scan", - math_backend=MathBackend.GROUPED_GEMM, - ), - runner=_run_edge_owned_persistent_scan, - supports=_supports_generic_dispatch, - ) -) diff --git a/src/cortical/fabric/backend/cuda/execution/dispatcher.cpp b/src/cortical/fabric/backend/cuda/execution/dispatcher.cpp deleted file mode 100644 index 888c46ab..00000000 --- a/src/cortical/fabric/backend/cuda/execution/dispatcher.cpp +++ /dev/null @@ -1,4927 +0,0 @@ -#include -#include -#include - -#include -#include -#include -#include -#include - -#include "cortical/fabric/backend/cuda/execution/common.cuh" -#include "cortical/fabric/backend/cuda/ops/dense_affine.cuh" -#include "cortical/fabric/backend/cuda/ops/dense_message.cuh" -#include "cortical/fabric/backend/cuda/ops/diagonal_recurrence.cuh" -#include "cortical/fabric/backend/cuda/registry/cell_dispatch_registry.cuh" - -namespace py = pybind11; - -namespace fabric { - -using StateAffineDeclaration = fabric::cuda::nn::StateAffineDeclaration; -using StateAffineSourceKind = fabric::cuda::nn::StateAffineSourceKind; - -void launch_edge_owned_accumulate_stepwise_cuda( - int message_backend_id, - TensorTable public_prev, - TensorTable message_params, - TensorTable input_ports, - MessageTopology topo, - float* max_buffer, - float* msg_buffer, - ExecutionPlan plan, - const at::Tensor& resets_u8, - int t, - cudaStream_t stream); - -void launch_readout_message_cuda( - TensorTable input_ports, - TensorTable public_now, - ReadoutSpec readout, - at::Tensor output_msg, - ExecutionPlan plan, - int head_dim, - int value_dim, - int t, - int recurrent_receiver_offset, - cudaStream_t stream); - -void launch_readout_message_from_raw_public_cuda( - TensorTable input_ports, - const at::Tensor& raw_public, - int public_projection_kind, - TensorTable public_projection_params, - ReadoutSpec readout, - at::Tensor output_msg, - ExecutionPlan plan, - int value_dim, - int t, - int recurrent_receiver_offset, - cudaStream_t stream); - -void launch_receiver_message_aggregate_cuda( - int message_backend_id, - TensorTable public_prev, - TensorTable message_params, - TensorTable input_ports, - MessageTopology topo, - float* message_out, - ExecutionPlan plan, - const at::Tensor& resets_u8, - int t, - cudaStream_t stream); - -void launch_receiver_normalize_accumulated_message_cuda( - const float* msg_buffer, - float* message_out, - ExecutionPlan plan, - cudaStream_t stream); - -namespace { - -TensorTable unpack_table(const py::tuple& packed) { - TORCH_CHECK(packed.size() == 4, "packed tensor table must contain ptrs/sizes/strides/ndims"); - const auto ptrs = packed[0].cast(); - const auto sizes = packed[1].cast(); - const auto strides = packed[2].cast(); - const auto ndims = packed[3].cast(); - return TensorTable{ - ptrs.numel() > 0 ? reinterpret_cast(ptrs.data_ptr()) : nullptr, - sizes.numel() > 0 ? sizes.data_ptr() : nullptr, - strides.numel() > 0 ? strides.data_ptr() : nullptr, - ndims.numel() > 0 ? 
ndims.data_ptr() : nullptr, - static_cast(ptrs.numel()), - }; -} - -MessageTopology unpack_topology( - const at::Tensor& receiver_ptr, - const at::Tensor& sender_idx, - const at::Tensor& edge_delay, - const at::Tensor& edge_weight, - int num_input_ports) { - return MessageTopology{ - receiver_ptr.numel() > 0 ? receiver_ptr.data_ptr() : nullptr, - sender_idx.numel() > 0 ? sender_idx.data_ptr() : nullptr, - edge_delay.numel() > 0 ? edge_delay.data_ptr() : nullptr, - edge_weight.numel() > 0 ? edge_weight.data_ptr() : nullptr, - static_cast(sender_idx.numel()), - num_input_ports, - }; -} - -ExecutionPlan make_plan( - int spatial_id, - int temporal_id, - int B, - int T, - int receivers, - int edges, - int output_ports, - int message_dim, - int public_dim, - int receiver_tile, - int batch_tile, - int edge_tile, - int hidden_chunk, - int state_receiver_tile, - int state_batch_tile, - int state_hidden_chunk, - int state_static_stage_mode, - int emit_receiver_tile, - int emit_batch_tile, - int emit_hidden_chunk, - int emit_static_stage_mode, - int public_receiver_tile, - int public_batch_tile, - int readout_mode, - int readout_port_tile, - int readout_output_chunk, - int cell_static_stage_mode, - int replication_factor, - bool stage_receiver_static, - bool emit_readout) { - TORCH_CHECK( - readout_mode == static_cast(ReadoutMode::Skip) || - readout_mode == static_cast(ReadoutMode::SeparatePortOwned), - "unsupported Fabric readout mode id: ", - readout_mode); - TORCH_CHECK( - cell_static_stage_mode == static_cast(CellStaticStageMode::Disabled) || - cell_static_stage_mode == static_cast(CellStaticStageMode::SharedFull), - "unsupported Fabric cell static stage mode id: ", - cell_static_stage_mode); - TORCH_CHECK( - state_static_stage_mode == static_cast(CellStaticStageMode::Disabled) || - state_static_stage_mode == static_cast(CellStaticStageMode::SharedFull), - "unsupported Fabric state static stage mode id: ", - state_static_stage_mode); - TORCH_CHECK( - emit_static_stage_mode == static_cast(CellStaticStageMode::Disabled) || - emit_static_stage_mode == static_cast(CellStaticStageMode::SharedFull), - "unsupported Fabric emit static stage mode id: ", - emit_static_stage_mode); - return ExecutionPlan{ - static_cast(spatial_id), - static_cast(temporal_id), - B, - T, - receivers, - edges, - output_ports, - message_dim, - public_dim, - receiver_tile, - batch_tile, - edge_tile, - hidden_chunk, - state_receiver_tile, - state_batch_tile, - state_hidden_chunk, - static_cast(state_static_stage_mode), - emit_receiver_tile, - emit_batch_tile, - emit_hidden_chunk, - static_cast(emit_static_stage_mode), - public_receiver_tile, - public_batch_tile, - static_cast(readout_mode), - readout_port_tile, - readout_output_chunk, - static_cast(cell_static_stage_mode), - replication_factor, - stage_receiver_static, - emit_readout, - }; -} - -at::Tensor tuple_tensor(const py::tuple& tensors, int idx) { - TORCH_CHECK(idx >= 0 && idx < tensors.size(), "tuple index out of range"); - return tensors[idx].cast(); -} - -std::vector tuple_tensors(const py::tuple& tensors) { - std::vector out; - out.reserve(tensors.size()); - for (const py::handle tensor : tensors) { - out.push_back(tensor.cast()); - } - return out; -} - -int tuple_source_index_or_identity(const py::tuple& indices, ssize_t logical_index) { - if (indices.size() == 0) { - return static_cast(logical_index); - } - TORCH_CHECK( - logical_index >= 0 && logical_index < indices.size(), - "Fabric carry checkpoint source-index tuple is shorter than checkpoint output tuple"); - 
return indices[logical_index].cast(); -} - -const char* stage_mode_name(CellStaticStageMode mode) { - switch (mode) { - case CellStaticStageMode::Disabled: - return "disabled"; - case CellStaticStageMode::SharedFull: - return "shared_full"; - } - return "unknown"; -} - -const char* readout_mode_name(ReadoutMode mode) { - switch (mode) { - case ReadoutMode::Skip: - return "skip"; - case ReadoutMode::SeparatePortOwned: - return "separate_port_owned"; - } - return "unknown"; -} - -py::dict launch_metadata(const ExecutionPlan& plan) { - py::dict metadata; - metadata["receiver_tiles"] = py::make_tuple(plan.receiver_tile); - metadata["batch_tiles"] = py::make_tuple(plan.batch_tile); - metadata["edge_tiles"] = py::make_tuple(plan.edge_tile); - metadata["hidden_chunks"] = py::make_tuple(plan.hidden_chunk); - metadata["state_receiver_tiles"] = py::make_tuple(plan.state_receiver_tile); - metadata["state_batch_tiles"] = py::make_tuple(plan.state_batch_tile); - metadata["state_hidden_chunks"] = py::make_tuple(plan.state_hidden_chunk); - metadata["state_static_stage_modes"] = py::make_tuple(stage_mode_name(plan.state_static_stage_mode)); - metadata["emit_receiver_tiles"] = py::make_tuple(plan.emit_receiver_tile); - metadata["emit_batch_tiles"] = py::make_tuple(plan.emit_batch_tile); - metadata["emit_hidden_chunks"] = py::make_tuple(plan.emit_hidden_chunk); - metadata["emit_static_stage_modes"] = py::make_tuple(stage_mode_name(plan.emit_static_stage_mode)); - metadata["public_receiver_tiles"] = py::make_tuple(plan.public_receiver_tile); - metadata["public_batch_tiles"] = py::make_tuple(plan.public_batch_tile); - metadata["replication_factors"] = py::make_tuple(plan.replication_factor); - metadata["cell_static_stage_modes"] = py::make_tuple(stage_mode_name(plan.cell_static_stage_mode)); - metadata["readout_modes"] = py::make_tuple(readout_mode_name(plan.readout_mode)); - return metadata; -} - -at::Tensor optional_tuple_tensor(const py::tuple& tensors, int idx, const at::Tensor& like) { - if (idx < 0 || idx >= tensors.size()) { - return at::empty({0}, like.options()); - } - return tensors[idx].cast(); -} - -struct RuntimeTensorTable { - TensorTable table; - std::vector metadata; -}; - -RuntimeTensorTable pack_runtime_tensor_table(const std::vector& tensors, const at::Tensor& like) { - if (tensors.empty()) { - auto ptrs = at::empty({0}, like.options().dtype(at::kLong)); - auto sizes = at::empty({0, kMaxRank}, like.options().dtype(at::kLong)); - auto strides = at::empty({0, kMaxRank}, like.options().dtype(at::kLong)); - auto ndims = at::empty({0}, like.options().dtype(at::kInt)); - return RuntimeTensorTable{ - TensorTable{nullptr, nullptr, nullptr, nullptr, 0}, - {ptrs, sizes, strides, ndims}, - }; - } - std::vector ptr_values; - std::vector size_values; - std::vector stride_values; - std::vector ndim_values; - ptr_values.reserve(tensors.size()); - size_values.reserve(tensors.size() * kMaxRank); - stride_values.reserve(tensors.size() * kMaxRank); - ndim_values.reserve(tensors.size()); - for (const at::Tensor& tensor : tensors) { - TORCH_CHECK(tensor.is_cuda(), "runtime tensor table only supports CUDA tensors"); - ptr_values.push_back(reinterpret_cast(tensor.data_ptr())); - ndim_values.push_back(static_cast(tensor.dim())); - for (int dim = 0; dim < kMaxRank; ++dim) { - size_values.push_back(dim < tensor.dim() ? tensor.size(dim) : 1); - stride_values.push_back(dim < tensor.dim() ? 
tensor.stride(dim) : 1); - } - } - auto ptrs = at::tensor(ptr_values, like.options().dtype(at::kLong)); - auto sizes = - at::tensor(size_values, like.options().dtype(at::kLong)).reshape({static_cast(tensors.size()), kMaxRank}); - auto strides = - at::tensor(stride_values, like.options().dtype(at::kLong)).reshape({static_cast(tensors.size()), kMaxRank}); - auto ndims = at::tensor(ndim_values, like.options().dtype(at::kInt)); - return RuntimeTensorTable{ - TensorTable{ - reinterpret_cast(ptrs.data_ptr()), - sizes.data_ptr(), - strides.data_ptr(), - ndims.data_ptr(), - static_cast(tensors.size()), - }, - {ptrs, sizes, strides, ndims}, - }; -} - -struct PublicProjectionExecution { - const char* hidden_backend; - const char* kv_backend; - const char* hidden_copy_mode; - const char* kv_split_mode; - const char* coalescing_mode; - int64_t launch_count; - int64_t small_cublas_launch_count; - int64_t copy_glue_launch_count; - int64_t copy_glue_saved_launches; - int64_t bias_glue_launch_count; - int64_t bias_glue_saved_launches; -}; - -PublicProjectionExecution readout_narrow_public_projection_execution() { - return PublicProjectionExecution{ - "not_materialized", - "readout_narrow_projected", - "not_materialized", - "readout_narrow_projected", - "readout_narrow_projected", - 0, - 0, - 0, - 2, - 0, - 0}; -} - -bool backend_uses_cublas(const char* backend); - -int64_t tensor_scalar_i64(const at::Tensor& tensor, int64_t default_value) { - if (!tensor.defined() || tensor.numel() == 0) { - return default_value; - } - return tensor.item(); -} - -int64_t previous_public_hidden_copy_launches(const at::Tensor& raw_public, const at::Tensor& hidden_out) { - return raw_public.size(2) == hidden_out.size(2) ? 1 : 2; -} - -bool raw_public_can_alias_public_hidden( - int public_projection_kind, - const at::Tensor& public_hidden_out, - int B, - int receivers, - int raw_public_dim) { - return public_projection_kind == 0 && public_hidden_out.defined() && public_hidden_out.dim() == 3 && - public_hidden_out.size(0) == B && public_hidden_out.size(1) == receivers && - public_hidden_out.size(2) == raw_public_dim && public_hidden_out.stride(2) == 1; -} - -bool can_use_direct_biased_receiver_affine( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& bias, - const at::Tensor& output) { - return input.defined() && weight.defined() && bias.defined() && output.defined() && - input.is_cuda() && weight.is_cuda() && bias.is_cuda() && output.is_cuda() && - input.scalar_type() == at::kFloat && weight.scalar_type() == at::kFloat && - bias.scalar_type() == at::kFloat && output.scalar_type() == at::kFloat && - input.dim() == 3 && weight.dim() == 3 && output.dim() == 3 && - input.size(0) == output.size(0) && input.size(1) == output.size(1) && - weight.size(0) == input.size(1) && weight.size(1) == input.size(2) && - weight.size(2) == output.size(2) && - (bias.dim() == 1 || bias.dim() == 2) && - (bias.dim() == 1 ? 
bias.size(0) == output.size(2) - : bias.size(1) == output.size(2) && (bias.size(0) == 1 || bias.size(0) == input.size(1))) && - input.stride(2) == 1 && weight.stride(2) == 1 && output.stride(2) == 1 && - bias.stride(bias.dim() - 1) == 1 && - output.size(2) > 0; -} - -bool can_use_direct_biased_receiver_affine_split( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& bias, - const at::Tensor& output_a, - const at::Tensor& output_b) { - return input.defined() && weight.defined() && bias.defined() && output_a.defined() && output_b.defined() && - input.is_cuda() && weight.is_cuda() && bias.is_cuda() && output_a.is_cuda() && output_b.is_cuda() && - input.scalar_type() == at::kFloat && weight.scalar_type() == at::kFloat && - bias.scalar_type() == at::kFloat && output_a.scalar_type() == at::kFloat && output_b.scalar_type() == at::kFloat && - input.dim() == 3 && weight.dim() == 3 && output_a.dim() == 3 && output_b.dim() == 3 && - input.size(0) == output_a.size(0) && input.size(0) == output_b.size(0) && - input.size(1) == output_a.size(1) && input.size(1) == output_b.size(1) && - weight.size(0) == input.size(1) && weight.size(1) == input.size(2) && - weight.size(2) == output_a.size(2) + output_b.size(2) && - (bias.dim() == 1 || bias.dim() == 2) && - (bias.dim() == 1 - ? bias.size(0) == weight.size(2) - : bias.size(1) == weight.size(2) && (bias.size(0) == 1 || bias.size(0) == input.size(1))) && - input.stride(2) == 1 && weight.stride(2) == 1 && output_a.stride(2) == 1 && output_b.stride(2) == 1 && - bias.stride(bias.dim() - 1) == 1 && output_a.size(2) > 0 && output_b.size(2) > 0; -} - -void launch_direct_biased_receiver_affine( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& bias, - const at::Tensor& output) { - py::gil_scoped_acquire gil; - static py::object direct_affine = - py::module_::import("cortical.fabric.backend.cuda.ops.receiver_major_affine_triton") - .attr("receiver_major_affine_bias_out_cuda"); - direct_affine(input, weight, bias, output); -} - -void launch_direct_biased_receiver_affine_split( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& bias, - const at::Tensor& output_a, - const at::Tensor& output_b) { - py::gil_scoped_acquire gil; - static py::object direct_affine_split = - py::module_::import("cortical.fabric.backend.cuda.ops.receiver_major_affine_triton") - .attr("receiver_major_affine_bias_split_out_cuda"); - direct_affine_split(input, weight, bias, output_a, output_b); -} - -at::Tensor launch_grouped_public_projection( - const at::Tensor& input, - const at::Tensor& grouped_weight, - int64_t group_size, - int64_t receiver_offset) { - py::gil_scoped_acquire gil; - static py::object grouped_projection = - py::module_::import("cortical.fabric.backend.cuda.projection.grouped_projection_cuda") - .attr("fabric_grouped_projection_forward_cuda"); - return grouped_projection( - input, - grouped_weight, - py::arg("group_size") = group_size, - py::arg("receiver_offset") = receiver_offset) - .cast(); -} - -at::Tensor public_projection_receiver_window( - const at::Tensor& tensor, - int64_t receiver_offset, - int64_t receiver_count) { - if (!tensor.defined() || tensor.numel() == 0 || receiver_offset == 0) { - return tensor; - } - if (tensor.dim() >= 1 && tensor.size(0) >= receiver_offset + receiver_count && - tensor.size(0) != receiver_count) { - return tensor.slice(/*dim=*/0, receiver_offset, receiver_offset + receiver_count); - } - return tensor; -} - -PublicProjectionExecution launch_dense_public_projection( - 
int public_projection_kind, - const at::Tensor& raw_public, - const py::tuple& public_projection_tensors, - const py::tuple& public_next_tensors, - int64_t receiver_offset = 0) { - RECORD_FUNCTION("fabric.physical.public_projection", std::vector()); - auto hidden_out = tuple_tensor(public_next_tensors, 0); - auto recurrent_k_out = tuple_tensor(public_next_tensors, 1); - auto recurrent_v_out = tuple_tensor(public_next_tensors, 2); - const int64_t head_dim = recurrent_k_out.size(2); - const int64_t value_dim = recurrent_v_out.size(2); - const char* hidden_backend = "copy"; - const char* kv_backend = "none"; - if (public_projection_kind == 0) { - const bool hidden_aliases_raw_public = - hidden_out.data_ptr() == raw_public.data_ptr() && hidden_out.sizes() == raw_public.sizes(); - if (hidden_aliases_raw_public) { - hidden_backend = "alias"; - } else { - fabric::cuda::ops::dense_affine_receiver_major_copy_or_pad_out_cuda(raw_public, hidden_out); - } - const at::Tensor direct_weight = public_projection_receiver_window( - optional_tuple_tensor(public_projection_tensors, 0, raw_public), - receiver_offset, - raw_public.size(1)); - const at::Tensor grouped_weight = optional_tuple_tensor(public_projection_tensors, 1, raw_public); - const at::Tensor group_size_tensor = optional_tuple_tensor(public_projection_tensors, 2, raw_public); - const bool use_grouped = grouped_weight.defined() && grouped_weight.numel() > 0 && - tensor_scalar_i64(group_size_tensor, 1) > 1; - const at::Tensor& kv_weight = use_grouped ? grouped_weight : direct_weight; - TORCH_CHECK(kv_weight.defined() && kv_weight.numel() > 0, "dense public projection requires kv weight"); - const int64_t group_size = use_grouped ? tensor_scalar_i64(group_size_tensor, 1) : 1; - at::Tensor kv_all; - if (use_grouped) { - kv_all = launch_grouped_public_projection(raw_public, grouped_weight, group_size, receiver_offset); - kv_backend = "grouped_projection_forward"; - } else { - { - RECORD_FUNCTION("fabric.glue.public_projection_kv_workspace", std::vector()); - kv_all = at::empty({raw_public.size(0), raw_public.size(1), head_dim + value_dim}, raw_public.options()); - } - const auto backend = fabric::cuda::ops::dense_affine_out_cuda( - raw_public, - kv_weight, - at::Tensor(), - kv_all, - fabric::cuda::ops::DenseAffineLayout::ReceiverMajor, - group_size, - fabric::cuda::ops::DenseAffineOutputMode::Overwrite); - kv_backend = fabric::cuda::ops::dense_affine_backend_name(backend); - } - fabric::cuda::ops::dense_affine_receiver_major_split_last_dim_out_cuda(kv_all, recurrent_k_out, recurrent_v_out); - return PublicProjectionExecution{ - hidden_backend, - kv_backend, - hidden_aliases_raw_public ? "aliased_raw_public" : "copy_or_pad_kernel", - "split_last_dim_kernel", - use_grouped ? "grouped_projection_forward" : "no_identical_signature", - hidden_aliases_raw_public ? 2 : 3, - use_grouped ? 0 : (backend_uses_cublas(kv_backend) ? 1 : 0), - hidden_aliases_raw_public ? 
1 : 2, - previous_public_hidden_copy_launches(raw_public, hidden_out), - 0, - 0}; - } - - const at::Tensor hidden_weight = public_projection_receiver_window( - tuple_tensor(public_projection_tensors, 0), receiver_offset, raw_public.size(1)); - const at::Tensor hidden_bias = public_projection_receiver_window( - tuple_tensor(public_projection_tensors, 1), receiver_offset, raw_public.size(1)); - const at::Tensor kv_weight = public_projection_receiver_window( - tuple_tensor(public_projection_tensors, 2), receiver_offset, raw_public.size(1)); - const at::Tensor kv_bias = public_projection_receiver_window( - tuple_tensor(public_projection_tensors, 3), receiver_offset, raw_public.size(1)); - const bool hidden_direct_bias = - can_use_direct_biased_receiver_affine(raw_public, hidden_weight, hidden_bias, hidden_out); - if (hidden_direct_bias) { - launch_direct_biased_receiver_affine(raw_public, hidden_weight, hidden_bias, hidden_out); - hidden_backend = "direct_biased_receiver_affine"; - } else { - const auto hidden_dense_backend = - fabric::cuda::ops::dense_affine_out_cuda( - raw_public, - hidden_weight, - hidden_bias, - hidden_out, - fabric::cuda::ops::DenseAffineLayout::ReceiverMajor, - 1, - fabric::cuda::ops::DenseAffineOutputMode::Overwrite); - hidden_backend = fabric::cuda::ops::dense_affine_backend_name(hidden_dense_backend); - } - const bool kv_direct_bias_split = - can_use_direct_biased_receiver_affine_split(raw_public, kv_weight, kv_bias, recurrent_k_out, recurrent_v_out); - at::Tensor kv_all; - if (kv_direct_bias_split) { - launch_direct_biased_receiver_affine_split(raw_public, kv_weight, kv_bias, recurrent_k_out, recurrent_v_out); - kv_backend = "direct_biased_receiver_affine_split_outputs"; - } else { - { - RECORD_FUNCTION("fabric.glue.public_projection_kv_workspace", std::vector()); - kv_all = at::empty({raw_public.size(0), raw_public.size(1), head_dim + value_dim}, raw_public.options()); - } - } - const bool kv_direct_bias = - !kv_direct_bias_split && can_use_direct_biased_receiver_affine(raw_public, kv_weight, kv_bias, kv_all); - if (kv_direct_bias) { - launch_direct_biased_receiver_affine(raw_public, kv_weight, kv_bias, kv_all); - kv_backend = "direct_biased_receiver_affine"; - } else if (!kv_direct_bias_split) { - const auto kv_dense_backend = - fabric::cuda::ops::dense_affine_out_cuda( - raw_public, - kv_weight, - kv_bias, - kv_all, - fabric::cuda::ops::DenseAffineLayout::ReceiverMajor, - 1, - fabric::cuda::ops::DenseAffineOutputMode::Overwrite); - kv_backend = fabric::cuda::ops::dense_affine_backend_name(kv_dense_backend); - } - if (!kv_direct_bias_split) { - fabric::cuda::ops::dense_affine_receiver_major_split_last_dim_out_cuda(kv_all, recurrent_k_out, recurrent_v_out); - } - const int64_t bias_glue_launches = - (hidden_direct_bias || !hidden_bias.defined() || hidden_bias.numel() == 0 ? 0 : 1) + - (kv_direct_bias || kv_direct_bias_split || !kv_bias.defined() || kv_bias.numel() == 0 ? 0 : 1); - const int64_t previous_bias_glue_launches = - (!hidden_bias.defined() || hidden_bias.numel() == 0 ? 0 : 1) + - (!kv_bias.defined() || kv_bias.numel() == 0 ? 0 : 1); - return PublicProjectionExecution{ - hidden_backend, - kv_backend, - "none", - kv_direct_bias_split ? "direct_split_outputs" : "split_last_dim_kernel", - (hidden_direct_bias || kv_direct_bias || kv_direct_bias_split) ? "direct_biased_receiver_affine" - : "no_identical_signature", - (kv_direct_bias_split ? 2 : 3) + bias_glue_launches, - (backend_uses_cublas(hidden_backend) ? 1 : 0) + (backend_uses_cublas(kv_backend) ? 
1 : 0), - kv_direct_bias_split ? 0 : 1, - 1, - bias_glue_launches, - previous_bias_glue_launches - bias_glue_launches}; -} - -const char* launch_dense_input_projection( - const at::Tensor& message, - const py::tuple& input_projection_tensors, - const at::Tensor& projected_message) { - RECORD_FUNCTION("fabric.physical.input_projection", std::vector()); - if (input_projection_tensors.size() == 0) { - TORCH_CHECK( - message.size(2) == projected_message.size(2), - "identity dense input projection requires message_dim == projected_message_dim"); - fabric::cuda::ops::dense_affine_receiver_major_copy_or_pad_out_cuda(message, projected_message); - return "copy_fused"; - } - const at::Tensor weight = tuple_tensor(input_projection_tensors, 0); - const at::Tensor bias = optional_tuple_tensor(input_projection_tensors, 1, message); - const auto backend = fabric::cuda::ops::dense_affine_out_cuda( - message, - weight, - bias, - projected_message, - fabric::cuda::ops::DenseAffineLayout::ReceiverMajor, - /*group_size=*/1, - fabric::cuda::ops::DenseAffineOutputMode::Overwrite); - return fabric::cuda::ops::dense_affine_backend_name(backend); -} - -at::Tensor state_affine_input_tensor( - const StateAffineDeclaration& spec, - const at::Tensor& projected_message, - const py::tuple& state_prev_tensors) { - if (spec.binding.input_kind == StateAffineSourceKind::ProjectedMessage) { - return projected_message; - } - TORCH_CHECK( - spec.binding.input_kind == StateAffineSourceKind::StatePrev, - "unsupported Fabric state affine input kind: ", - static_cast(spec.binding.input_kind)); - TORCH_CHECK( - spec.binding.input_tensor_index >= 0 && spec.binding.input_tensor_index < state_prev_tensors.size(), - "Fabric state affine input tensor index is out of range"); - return state_prev_tensors[spec.binding.input_tensor_index].cast(); -} - -const char* state_affine_source_name(StateAffineSourceKind input_kind) { - switch (input_kind) { - case StateAffineSourceKind::ProjectedMessage: - return "projected_message"; - case StateAffineSourceKind::StatePrev: - return "state_prev"; - } - return "unknown"; -} - -const char* reset_policy_name(fabric::cuda::nn::ResetPolicy policy) { - switch (policy) { - case fabric::cuda::nn::ResetPolicy::None: - return "none"; - case fabric::cuda::nn::ResetPolicy::ZeroSourceRows: - return "zero_source_rows"; - } - return "unknown"; -} - -std::string state_affine_bucket_signature(const at::Tensor& input, const StateAffineDeclaration& spec) { - return "receiver_major:source=" + std::string(state_affine_source_name(spec.binding.input_kind)) + - ",K=" + std::to_string(input.size(2)) + ",N=" + std::to_string(spec.op.signature.N) + - ",group_size=" + std::to_string(spec.binding.group_size > 0 ? 
spec.binding.group_size : 1); -} - -const char* output_mode_name(fabric::cuda::ops::DenseAffineOutputMode mode) { - switch (mode) { - case fabric::cuda::ops::DenseAffineOutputMode::Overwrite: - return "overwrite"; - case fabric::cuda::ops::DenseAffineOutputMode::Accumulate: - return "accumulate"; - } - return "unknown"; -} - -void validate_state_affine_reset_policy(fabric::cuda::nn::ResetPolicy policy) { - TORCH_CHECK( - policy == fabric::cuda::nn::ResetPolicy::None || policy == fabric::cuda::nn::ResetPolicy::ZeroSourceRows, - "unsupported Fabric state affine reset policy: ", - static_cast(policy)); -} - -int64_t tensor_nbytes(const at::Tensor& tensor) { - if (!tensor.defined()) { - return 0; - } - return static_cast(tensor.numel()) * static_cast(tensor.element_size()); -} - -struct PackedSourceCacheEntry { - StateAffineSourceKind input_kind; - int input_tensor_index; - at::Tensor tensor; -}; - -bool same_state_affine_source(const PackedSourceCacheEntry& entry, const StateAffineDeclaration& spec) { - return entry.input_kind == spec.binding.input_kind && - entry.input_tensor_index == spec.binding.input_tensor_index; -} - -bool state_affine_source_already_planned( - const std::vector& planned_sources, - const StateAffineDeclaration& spec) { - for (const PackedSourceCacheEntry& entry : planned_sources) { - if (same_state_affine_source(entry, spec)) { - return true; - } - } - return false; -} - -std::string state_affine_source_key(const StateAffineDeclaration& spec) { - return std::string(state_affine_source_name(spec.binding.input_kind)) + ":" + - std::to_string(spec.binding.input_tensor_index); -} - -void append_workspace_metadata( - const std::string& name, - const at::Tensor& tensor, - const std::string& lifetime, - const std::string& alias_class, - std::vector* buffers, - std::vector* buffer_bytes, - int64_t* peak_bytes, - bool contributes_to_peak = true) { - const int64_t bytes = tensor_nbytes(tensor); - buffers->push_back(name + ":lifetime=" + lifetime + ",alias_class=" + alias_class); - buffer_bytes->push_back(name + ":" + std::to_string(bytes)); - if (contributes_to_peak) { - *peak_bytes += bytes; - } -} - -bool same_tensor_shape_and_layout(const at::Tensor& lhs, const at::Tensor& rhs) { - return lhs.defined() && rhs.defined() && lhs.device() == rhs.device() && lhs.scalar_type() == rhs.scalar_type() && - lhs.sizes().equals(rhs.sizes()) && lhs.strides().equals(rhs.strides()); -} - -bool can_reuse_projected_message_for_reset_source( - const std::vector& specs, - size_t spec_index, - bool allow_projected_message_reset_source_reuse) { - if (!allow_projected_message_reset_source_reuse) { - return false; - } - if (specs[spec_index].binding.input_kind == StateAffineSourceKind::ProjectedMessage) { - return false; - } - for (size_t idx = spec_index + 1; idx < specs.size(); ++idx) { - if (specs[idx].binding.input_kind == StateAffineSourceKind::ProjectedMessage) { - return false; - } - } - return true; -} - -int64_t state_affine_receiver_alignment( - const std::vector& specs, - const std::vector& cell_param_tensors, - int64_t receivers) { - int64_t alignment = 1; - for (const StateAffineDeclaration& spec : specs) { - const int64_t group_size = spec.binding.group_size > 0 ? 
spec.binding.group_size : 1; - if (spec.binding.weight_param_index >= 0 && - spec.binding.weight_param_index < static_cast(cell_param_tensors.size())) { - const at::Tensor& weight = cell_param_tensors[spec.binding.weight_param_index]; - if (weight.defined() && weight.dim() == 3 && weight.size(0) != receivers && weight.size(0) != 1) { - TORCH_CHECK( - weight.size(0) * group_size == receivers, - "grouped Fabric state affine weight does not cover all receivers"); - alignment = std::max(alignment, group_size); - } - } - if (spec.binding.bias_param_index >= 0 && - spec.binding.bias_param_index < static_cast(cell_param_tensors.size())) { - const at::Tensor& bias = cell_param_tensors[spec.binding.bias_param_index]; - if (bias.defined() && bias.dim() == 2 && bias.size(0) != receivers && bias.size(0) != 1) { - TORCH_CHECK( - bias.size(0) * group_size == receivers, - "grouped Fabric state affine bias does not cover all receivers"); - alignment = std::max(alignment, group_size); - } - } - } - return alignment; -} - -int64_t choose_state_affine_receiver_chunk_size( - const std::vector& specs, - const std::vector& cell_param_tensors, - int64_t batch, - int64_t full_receivers, - int64_t receivers, - int64_t output_dim, - int64_t element_size, - bool include_receiver_affine_superop_pack_workspace, - int64_t target_state_affine_workspace_bytes) { - int64_t bytes_per_receiver = batch * output_dim * element_size; - if (include_receiver_affine_superop_pack_workspace && specs.size() == 2) { - const int64_t packed_k = specs[0].op.signature.K + specs[1].op.signature.K + 1; - bytes_per_receiver += batch * packed_k * element_size; - bytes_per_receiver += packed_k * output_dim * element_size; - } - if (receivers <= 0 || bytes_per_receiver <= 0 || - bytes_per_receiver * receivers <= target_state_affine_workspace_bytes) { - return receivers; - } - int64_t chunk_size = std::max(1, target_state_affine_workspace_bytes / bytes_per_receiver); - chunk_size = std::min(chunk_size, receivers); - const int64_t alignment = state_affine_receiver_alignment(specs, cell_param_tensors, full_receivers); - if (alignment > 1 && chunk_size < receivers) { - chunk_size = std::max(alignment, (chunk_size / alignment) * alignment); - } - return std::max(1, std::min(chunk_size, receivers)); -} - -at::Tensor receiver_window(const at::Tensor& tensor, int64_t receiver_start, int64_t receiver_count) { - if (!tensor.defined() || tensor.numel() == 0) { - return tensor; - } - TORCH_CHECK(tensor.dim() >= 2, "receiver-windowed tensor must have rank >= 2"); - return tensor.slice(/*dim=*/1, receiver_start, receiver_start + receiver_count); -} - -at::Tensor workspace_receiver_window(const at::Tensor& tensor, int64_t receiver_count) { - if (!tensor.defined() || tensor.numel() == 0) { - return tensor; - } - TORCH_CHECK(tensor.dim() >= 2, "receiver-windowed workspace must have rank >= 2"); - return tensor.slice(/*dim=*/1, 0, receiver_count); -} - -at::Tensor workspace_leading_receiver_window(const at::Tensor& tensor, int64_t receiver_count) { - if (!tensor.defined() || tensor.numel() == 0) { - return tensor; - } - TORCH_CHECK(tensor.dim() >= 1, "receiver-leading workspace must have rank >= 1"); - return tensor.slice(/*dim=*/0, 0, receiver_count); -} - -bool state_affine_source_is_compact(StateAffineSourceKind input_kind) { - return input_kind == StateAffineSourceKind::ProjectedMessage; -} - -int64_t state_affine_source_receiver_start( - const StateAffineDeclaration& spec, - int64_t receiver_start, - int64_t receiver_global_offset, - int64_t 
source_receiver_extent, - int64_t compact_receiver_extent) { - if (source_receiver_extent == compact_receiver_extent) { - return receiver_start; - } - return state_affine_source_is_compact(spec.binding.input_kind) ? receiver_start - : receiver_global_offset + receiver_start; -} - -at::Tensor state_affine_input_window( - const StateAffineDeclaration& spec, - const at::Tensor& projected_message, - const py::tuple& state_prev_tensors, - int64_t receiver_start, - int64_t receiver_global_offset, - int64_t receiver_count, - bool state_prev_is_zero = false) { - if (state_prev_is_zero && spec.binding.input_kind == StateAffineSourceKind::StatePrev) { - return projected_message.new_zeros({projected_message.size(0), receiver_count, spec.op.signature.K}); - } - const at::Tensor input = state_affine_input_tensor(spec, projected_message, state_prev_tensors); - TORCH_CHECK(input.dim() >= 2, "Fabric state affine input must be receiver-major rank >= 2"); - return receiver_window( - input, - state_affine_source_receiver_start( - spec, - receiver_start, - receiver_global_offset, - input.size(1), - projected_message.size(1)), - receiver_count); -} - -at::Tensor state_affine_weight_window( - const at::Tensor& weight, - const StateAffineDeclaration& spec, - int64_t receivers, - int64_t receiver_start, - int64_t receiver_count) { - if (!weight.defined() || weight.numel() == 0 || weight.dim() != 3 || weight.size(0) == 1) { - return weight; - } - if (weight.size(0) == receivers) { - return weight.slice(/*dim=*/0, receiver_start, receiver_start + receiver_count); - } - const int64_t group_size = spec.binding.group_size > 0 ? spec.binding.group_size : 1; - TORCH_CHECK( - weight.size(0) * group_size == receivers, - "receiver-windowed grouped Fabric state affine weight does not cover all receivers"); - TORCH_CHECK( - receiver_start % group_size == 0 && receiver_count % group_size == 0, - "receiver-windowed Fabric state affine chunk must align to grouped weight boundaries"); - return weight.slice(/*dim=*/0, receiver_start / group_size, (receiver_start + receiver_count) / group_size); -} - -at::Tensor state_affine_bias_window( - const at::Tensor& bias, - const StateAffineDeclaration& spec, - int64_t receivers, - int64_t receiver_start, - int64_t receiver_count) { - if (!bias.defined() || bias.numel() == 0 || bias.dim() != 2 || bias.size(0) == 1) { - return bias; - } - if (bias.size(0) == receivers) { - return bias.slice(/*dim=*/0, receiver_start, receiver_start + receiver_count); - } - const int64_t group_size = spec.binding.group_size > 0 ? 
spec.binding.group_size : 1; - TORCH_CHECK( - bias.size(0) * group_size == receivers, - "receiver-windowed grouped Fabric state affine bias does not cover all receivers"); - TORCH_CHECK( - receiver_start % group_size == 0 && receiver_count % group_size == 0, - "receiver-windowed Fabric state affine chunk must align to grouped bias boundaries"); - return bias.slice(/*dim=*/0, receiver_start / group_size, (receiver_start + receiver_count) / group_size); -} - -struct DenseStateAffineWorkspace { - std::vector outputs; - std::vector reset_packed_sources; - at::Tensor receiver_affine_packed_input; - at::Tensor receiver_affine_packed_weight; - std::vector buffers; - std::vector buffer_bytes; - std::vector aliases; - int64_t receiver_chunk_size = 0; - int64_t receiver_chunks = 0; - int64_t bytes = 0; -}; - -bool receiver_affine_superop_v1_declaration_family( - const std::vector& specs) { - if (specs.size() != 2) { - return false; - } - return specs[0].binding.input_kind == StateAffineSourceKind::ProjectedMessage && - specs[1].binding.input_kind == StateAffineSourceKind::StatePrev && - specs[0].op.signature.N == specs[1].op.signature.N && - specs[0].op.signature.reset_policy == fabric::cuda::nn::ResetPolicy::None && - (specs[1].op.signature.reset_policy == fabric::cuda::nn::ResetPolicy::None || - specs[1].op.signature.reset_policy == fabric::cuda::nn::ResetPolicy::ZeroSourceRows); -} - -struct ReceiverAffineSuperOpPlan { - bool active = false; - bool direct_persistent = false; - std::string demotion_reason = "receiver_affine_superop_ineligible:unsupported_affine_count"; - std::string applicability_predicate = "state_affine_v1_two_affine_projected_recurrent"; -}; - -struct TinyMessageSuperOpPlan { - bool active = false; - std::string demotion_reason = "tiny_message_superop_pending:executor_not_enabled"; - std::string applicability_predicate = "regular_local_receiver_owned_direct_projected_receiver_major_batch_row_reset"; -}; - -struct DiagonalRecurrenceSuperOpPlan { - bool active = false; - fabric::cuda::nn::DiagonalRecurrenceDeclaration declaration{}; - int64_t receiver_offset = 0; - std::string demotion_reason = "diagonal_recurrence_superop_ineligible:no_declaration"; - std::string applicability_predicate = "diagonal_recurrence_declared_in_fabric_cuda_nn"; -}; - -const char* receiver_affine_superop_physical_mode(const ReceiverAffineSuperOpPlan& plan) { - if (!plan.active) { - return "lowered_phase_executor"; - } - return plan.direct_persistent ? 
"direct_persistent" : "pack_cublas_transitional"; -} - -const char* diagonal_recurrence_kind_name(fabric::cuda::nn::DiagonalRecurrenceKind kind) { - switch (kind) { - case fabric::cuda::nn::DiagonalRecurrenceKind::ComplexExponential2D: - return "complex_exponential_2d"; - } - return "unknown"; -} - -TinyMessageSuperOpPlan select_tiny_message_superop_plan( - int spatial_id, - int message_backend_id, - bool has_sparse_message_topology, - const at::Tensor& q, - const at::Tensor& input_v, - const at::Tensor& recurrent_v, - const at::Tensor& receiver_sender_idx, - int64_t projected_message_dim, - const py::tuple& input_projection_tensors) { - TinyMessageSuperOpPlan plan{}; - if (spatial_id == static_cast(SpatialOwnership::EdgeOwned)) { - plan.demotion_reason = "tiny_message_superop_ineligible:edge_owned"; - return plan; - } - (void)has_sparse_message_topology; - if (message_backend_id != 0) { - plan.demotion_reason = "tiny_message_superop_ineligible:sparse_topology"; - return plan; - } - if (!receiver_sender_idx.defined() || receiver_sender_idx.dim() != 2 || - receiver_sender_idx.size(1) <= 0 || receiver_sender_idx.size(1) > 16) { - plan.demotion_reason = "tiny_message_superop_ineligible:unsupported_degree"; - return plan; - } - if (!q.defined() || q.dim() != 2 || q.size(1) <= 0) { - plan.demotion_reason = "tiny_message_superop_ineligible:unsupported_query_layout"; - return plan; - } - if (!input_v.defined() || input_v.dim() != 4 || input_v.size(3) <= 0 || input_v.size(3) > 16 || - !recurrent_v.defined() || recurrent_v.dim() != 3 || recurrent_v.size(2) != input_v.size(3)) { - plan.demotion_reason = "tiny_message_superop_ineligible:unsupported_value_layout"; - return plan; - } - if (projected_message_dim <= 0) { - plan.demotion_reason = "tiny_message_superop_ineligible:unsupported_projected_dim"; - return plan; - } - if (input_projection_tensors.size() == 0) { - plan.demotion_reason = "tiny_message_superop_ineligible:missing_projection"; - return plan; - } - const at::Tensor projection_weight = tuple_tensor(input_projection_tensors, 0); - if (!projection_weight.defined() || !(projection_weight.dim() == 2 || projection_weight.dim() == 3)) { - plan.demotion_reason = "tiny_message_superop_ineligible:unsupported_projection_layout"; - return plan; - } - const bool receiver_weight = - projection_weight.dim() == 3 && projection_weight.size(0) == q.size(0) && - projection_weight.size(1) == input_v.size(3) && projection_weight.size(2) == projected_message_dim; - const bool shared_weight = - projection_weight.dim() == 2 && - ((projection_weight.size(0) == input_v.size(3) && projection_weight.size(1) == projected_message_dim) || - (projection_weight.size(0) == projected_message_dim && projection_weight.size(1) == input_v.size(3))); - if (!receiver_weight && !shared_weight) { - plan.demotion_reason = "tiny_message_superop_ineligible:unsupported_projection_shape"; - return plan; - } - plan.active = true; - plan.demotion_reason = "none"; - plan.applicability_predicate = "regular_local_receiver_owned_direct_projected_receiver_major_batch_row_reset"; - return plan; -} - -DenseStateAffineWorkspace allocate_dense_state_affine_workspace( - const std::vector& specs, - const at::Tensor& projected_message, - const py::tuple& state_prev_tensors, - const std::vector& cell_param_tensors, - int64_t full_receivers, - int64_t receiver_global_offset, - int64_t receiver_count, - bool include_receiver_affine_superop_pack_workspace, - bool receiver_affine_superop_handles_reset_source, - bool 
allow_projected_message_reset_source_reuse, - bool state_prev_is_zero) { - DenseStateAffineWorkspace workspace; - if (specs.empty()) { - return workspace; - } - workspace.outputs.reserve(1); - const at::Tensor first_input = - state_affine_input_window( - specs.front(), - projected_message, - state_prev_tensors, - 0, - receiver_global_offset, - receiver_count, - state_prev_is_zero); - const int64_t batch = first_input.size(0); - const int output_dim = static_cast(specs.front().op.signature.N); - const int64_t target_state_affine_workspace_bytes = 2LL * 1024LL * 1024LL * 1024LL; - workspace.receiver_chunk_size = choose_state_affine_receiver_chunk_size( - specs, - cell_param_tensors, - batch, - full_receivers, - receiver_count, - output_dim, - first_input.element_size(), - include_receiver_affine_superop_pack_workspace, - target_state_affine_workspace_bytes); - workspace.receiver_chunks = ceil_div(receiver_count, workspace.receiver_chunk_size); - for (size_t spec_index = 0; spec_index < specs.size(); ++spec_index) { - const StateAffineDeclaration& spec = specs[spec_index]; - const at::Tensor input = - state_affine_input_window( - spec, - projected_message, - state_prev_tensors, - 0, - receiver_global_offset, - receiver_count, - state_prev_is_zero); - validate_state_affine_reset_policy(spec.op.signature.reset_policy); - TORCH_CHECK(input.size(0) == batch, "combined state affine inputs must have matching B"); - TORCH_CHECK( - input.size(1) == receiver_count, - "combined state affine inputs must have matching R: source=", - state_affine_source_name(spec.binding.input_kind), - ", input_R=", - input.size(1), - ", receiver_count=", - receiver_count, - ", projected_R=", - projected_message.size(1), - ", full_R=", - full_receivers, - ", receiver_global_offset=", - receiver_global_offset); - TORCH_CHECK(spec.op.signature.N == output_dim, "combined state affine outputs must have matching N"); - if (spec.binding.weight_param_index >= 0 && - spec.binding.weight_param_index < static_cast(cell_param_tensors.size())) { - (void)state_affine_weight_window( - cell_param_tensors[spec.binding.weight_param_index], - spec, - full_receivers, - receiver_global_offset, - receiver_count); - } - if (spec.binding.bias_param_index >= 0 && - spec.binding.bias_param_index < static_cast(cell_param_tensors.size())) { - (void)state_affine_bias_window( - cell_param_tensors[spec.binding.bias_param_index], - spec, - full_receivers, - receiver_global_offset, - receiver_count); - } - if (!receiver_affine_superop_handles_reset_source && - spec.op.signature.reset_policy == fabric::cuda::nn::ResetPolicy::ZeroSourceRows && - !state_affine_source_already_planned(workspace.reset_packed_sources, spec)) { - const bool reuse_projected_message = - can_reuse_projected_message_for_reset_source( - specs, - spec_index, - allow_projected_message_reset_source_reuse) && - same_tensor_shape_and_layout(input, projected_message); - at::Tensor packed_source = reuse_projected_message - ? projected_message - : at::empty({batch, workspace.receiver_chunk_size, input.size(2)}, input.options()); - workspace.reset_packed_sources.push_back( - PackedSourceCacheEntry{spec.binding.input_kind, spec.binding.input_tensor_index, packed_source}); - const std::string buffer_name = - "reset_packed_source_" + std::to_string(workspace.reset_packed_sources.size() - 1) + - "[" + state_affine_source_key(spec) + "]"; - append_workspace_metadata( - buffer_name, - packed_source, - "dense_state_affines", - reuse_projected_message ? 
"phase_reuse" : "reset_packed_source", - &workspace.buffers, - &workspace.buffer_bytes, - &workspace.bytes, - !reuse_projected_message); - if (reuse_projected_message) { - workspace.aliases.push_back("projected_message=" + buffer_name); - } - } - } - at::Tensor combined_output = at::empty({batch, workspace.receiver_chunk_size, output_dim}, first_input.options()); - workspace.outputs.push_back(combined_output); - append_workspace_metadata( - "state_affine_output_0", - combined_output, - "dense_state_affines->receiver_state_update", - "state_affine_contributions", - &workspace.buffers, - &workspace.buffer_bytes, - &workspace.bytes); - workspace.aliases.push_back("state_affine_contributions:combined"); - return workspace; -} - -fabric::cuda::nn::LoweredPhaseIR lower_state_affine_specs_to_phase_ir( - const std::vector& specs) { - std::vector ops; - ops.reserve(specs.size()); - for (const StateAffineDeclaration& spec : specs) { - ops.push_back(spec.op); - } - return fabric::cuda::nn::lower_affine_ops_to_phase_ir(ops); -} - -bool receiver_affine_superop_bias_supported( - const at::Tensor& bias, - int64_t receivers, - int64_t output_dim) { - if (!bias.defined() || bias.numel() == 0) { - return true; - } - if (bias.dim() == 1) { - return bias.size(0) == output_dim; - } - return bias.dim() == 2 && bias.size(1) == output_dim && (bias.size(0) == 1 || bias.size(0) == receivers); -} - -bool receiver_affine_direct_persistent_supported( - const std::vector& specs, - const at::Tensor& projected_message, - const py::tuple& state_prev_tensors, - const std::vector& cell_param_tensors, - int64_t full_receivers, - int64_t receiver_global_offset, - int64_t receiver_count) { - if (specs.size() != 2) { - return false; - } - const at::Tensor first_input = - state_affine_input_window(specs[0], projected_message, state_prev_tensors, 0, receiver_global_offset, receiver_count); - const at::Tensor second_input = - state_affine_input_window(specs[1], projected_message, state_prev_tensors, 0, receiver_global_offset, receiver_count); - const int64_t output_dim = specs[0].op.signature.N; - const int64_t max_input_dim = std::max(first_input.size(2), second_input.size(2)); - if (max_input_dim <= 0 || max_input_dim > 128) { - return false; - } - for (const StateAffineDeclaration& spec : specs) { - if (spec.binding.weight_param_index >= static_cast(cell_param_tensors.size())) { - return false; - } - if (spec.binding.weight_param_index >= 0) { - (void)state_affine_weight_window( - cell_param_tensors[spec.binding.weight_param_index], - spec, - full_receivers, - receiver_global_offset, - receiver_count); - } - if (spec.binding.bias_param_index >= static_cast(cell_param_tensors.size())) { - return false; - } - if (spec.binding.bias_param_index >= 0 && - !receiver_affine_superop_bias_supported( - state_affine_bias_window( - cell_param_tensors[spec.binding.bias_param_index], - spec, - full_receivers, - receiver_global_offset, - receiver_count), - receiver_count, - output_dim)) { - return false; - } - } - return true; -} - -ReceiverAffineSuperOpPlan select_receiver_affine_superop_plan( - const std::vector& specs, - const fabric::cuda::nn::LoweredPhaseIR& lowered_phase_ir, - const at::Tensor& projected_message, - const py::tuple& state_prev_tensors, - const std::vector& cell_param_tensors, - int64_t full_receivers, - int64_t receiver_global_offset, - int64_t receiver_count) { - ReceiverAffineSuperOpPlan plan; - if (specs.empty()) { - plan.demotion_reason = "none"; - plan.applicability_predicate = "state_affine_absent"; - return plan; 
- } - if (specs.size() != 2) { - plan.demotion_reason = "receiver_affine_superop_ineligible:unsupported_affine_count"; - return plan; - } - if (lowered_phase_ir.affine_buckets.size() != specs.size()) { - plan.demotion_reason = "receiver_affine_superop_ineligible:unsupported_affine_count"; - plan.applicability_predicate = "state_affine_lowered_bucket_count_mismatch"; - return plan; - } - const StateAffineDeclaration& first = specs[0]; - const StateAffineDeclaration& second = specs[1]; - for (const StateAffineDeclaration& spec : specs) { - const auto& signature = spec.op.signature; - if (signature.op_kind != fabric::cuda::nn::OpKind::Linear || - signature.phase_kind != fabric::cuda::nn::PhaseKind::RecurrentAffine || - signature.layout != fabric::cuda::nn::TensorLayout::ReceiverMajor) { - plan.demotion_reason = "receiver_affine_superop_ineligible:non_receiver_major_layout"; - return plan; - } - if (!( - (signature.reset_policy == fabric::cuda::nn::ResetPolicy::None && - signature.reset_scope == fabric::cuda::nn::ResetScope::None) || - (signature.reset_policy == fabric::cuda::nn::ResetPolicy::ZeroSourceRows && - signature.reset_scope == fabric::cuda::nn::ResetScope::BatchRow))) { - plan.demotion_reason = "receiver_affine_superop_ineligible:unsupported_reset_scope"; - return plan; - } - } - if (first.binding.input_kind != StateAffineSourceKind::ProjectedMessage || - second.binding.input_kind != StateAffineSourceKind::StatePrev) { - plan.demotion_reason = "receiver_affine_superop_ineligible:unsupported_source_family"; - return plan; - } - if (first.op.signature.reset_policy != fabric::cuda::nn::ResetPolicy::None || - !(second.op.signature.reset_policy == fabric::cuda::nn::ResetPolicy::None || - second.op.signature.reset_policy == fabric::cuda::nn::ResetPolicy::ZeroSourceRows)) { - plan.demotion_reason = "receiver_affine_superop_ineligible:unsupported_reset_scope"; - return plan; - } - if (first.op.signature.N != second.op.signature.N) { - plan.demotion_reason = "receiver_affine_superop_ineligible:mixed_output_dim"; - return plan; - } - if (first.binding.group_size != 1 || second.binding.group_size != 1) { - plan.demotion_reason = "receiver_affine_superop_ineligible:mixed_chunk_family"; - return plan; - } - const int64_t output_dim = first.op.signature.N; - for (const StateAffineDeclaration& spec : specs) { - if (spec.binding.weight_param_index < 0 || - spec.binding.weight_param_index >= static_cast(cell_param_tensors.size())) { - plan.demotion_reason = "receiver_affine_superop_ineligible:non_receiver_major_layout"; - return plan; - } - const at::Tensor weight = state_affine_weight_window( - cell_param_tensors[spec.binding.weight_param_index], - spec, - full_receivers, - receiver_global_offset, - receiver_count); - const at::Tensor input = - state_affine_input_window(spec, projected_message, state_prev_tensors, 0, receiver_global_offset, receiver_count); - if (!weight.defined() || weight.dim() != 3 || weight.size(0) != receiver_count || - weight.size(1) != input.size(2) || weight.size(2) != output_dim) { - plan.demotion_reason = "receiver_affine_superop_ineligible:non_receiver_major_layout"; - return plan; - } - if (spec.binding.bias_param_index >= static_cast(cell_param_tensors.size())) { - plan.demotion_reason = "receiver_affine_superop_ineligible:non_receiver_major_layout"; - return plan; - } - if (spec.binding.bias_param_index >= 0 && - !receiver_affine_superop_bias_supported( - state_affine_bias_window( - cell_param_tensors[spec.binding.bias_param_index], - spec, - full_receivers, - 
receiver_global_offset, - receiver_count), - receiver_count, - output_dim)) { - plan.demotion_reason = "receiver_affine_superop_ineligible:non_receiver_major_layout"; - return plan; - } - } - plan.active = true; - plan.direct_persistent = receiver_affine_direct_persistent_supported( - specs, - projected_message, - state_prev_tensors, - cell_param_tensors, - full_receivers, - receiver_global_offset, - receiver_count); - plan.demotion_reason = "none"; - plan.applicability_predicate = - "v1_two_affine_projected_overwrite_recurrent_accumulate_receiver_major_batch_row_reset"; - return plan; -} - -bool state_affine_receiver_window_supported( - const std::vector& specs, - const std::vector& cell_param_tensors, - int64_t full_receivers, - int64_t receiver_global_offset, - int64_t receiver_count) { - if (specs.empty() || receiver_global_offset < 0 || receiver_count <= 0 || - receiver_global_offset + receiver_count > full_receivers) { - return false; - } - const int64_t output_dim = specs.front().op.signature.N; - for (const StateAffineDeclaration& spec : specs) { - const auto& signature = spec.op.signature; - if (signature.op_kind != fabric::cuda::nn::OpKind::Linear || - signature.phase_kind != fabric::cuda::nn::PhaseKind::RecurrentAffine || - signature.layout != fabric::cuda::nn::TensorLayout::ReceiverMajor || - signature.N != output_dim) { - return false; - } - if (!(spec.binding.input_kind == StateAffineSourceKind::ProjectedMessage || - spec.binding.input_kind == StateAffineSourceKind::StatePrev)) { - return false; - } - if (!( - (signature.reset_policy == fabric::cuda::nn::ResetPolicy::None && - signature.reset_scope == fabric::cuda::nn::ResetScope::None) || - (signature.reset_policy == fabric::cuda::nn::ResetPolicy::ZeroSourceRows && - signature.reset_scope == fabric::cuda::nn::ResetScope::BatchRow))) { - return false; - } - if (spec.binding.weight_param_index < 0 || - spec.binding.weight_param_index >= static_cast(cell_param_tensors.size())) { - return false; - } - const at::Tensor weight = state_affine_weight_window( - cell_param_tensors[spec.binding.weight_param_index], - spec, - full_receivers, - receiver_global_offset, - receiver_count); - if (!weight.defined() || weight.numel() == 0) { - return false; - } - if (spec.binding.bias_param_index >= static_cast(cell_param_tensors.size())) { - return false; - } - if (spec.binding.bias_param_index >= 0) { - const at::Tensor bias = state_affine_bias_window( - cell_param_tensors[spec.binding.bias_param_index], - spec, - full_receivers, - receiver_global_offset, - receiver_count); - if (!receiver_affine_superop_bias_supported(bias, receiver_count, output_dim)) { - return false; - } - } - } - return true; -} - -DiagonalRecurrenceSuperOpPlan select_diagonal_recurrence_superop_plan( - const fabric::cuda::nn::LoweredPhaseIR& lowered_phase_ir, - const std::vector& state_affine_specs, - const std::vector& cell_param_tensors, - const at::Tensor& projected_message, - int64_t raw_public_dim, - int reduction_stats_dim, - int64_t receiver_offset = 0) { - DiagonalRecurrenceSuperOpPlan plan{}; - if (receiver_offset < 0) { - plan.demotion_reason = "diagonal_recurrence_superop_ineligible:unsupported_receiver_window"; - return plan; - } - if (lowered_phase_ir.diagonal_recurrences.empty()) { - return plan; - } - if (lowered_phase_ir.diagonal_recurrences.size() != 1) { - plan.demotion_reason = "diagonal_recurrence_superop_ineligible:unsupported_declaration_count"; - return plan; - } - if (!state_affine_specs.empty()) { - plan.demotion_reason = 
"diagonal_recurrence_superop_ineligible:state_affine_surface"; - return plan; - } - if (reduction_stats_dim != 0) { - plan.demotion_reason = "diagonal_recurrence_superop_ineligible:reduction_boundary"; - return plan; - } - const auto& declaration = lowered_phase_ir.diagonal_recurrences.front(); - if (declaration.kind != fabric::cuda::nn::DiagonalRecurrenceKind::ComplexExponential2D) { - plan.demotion_reason = "diagonal_recurrence_superop_ineligible:unsupported_recurrence_kind"; - return plan; - } - if (declaration.layout != fabric::cuda::nn::TensorLayout::ReceiverMajor || - declaration.input_source_kind != fabric::cuda::nn::SourceKind::ProjectedMessage) { - plan.demotion_reason = "diagonal_recurrence_superop_ineligible:non_receiver_major_layout"; - return plan; - } - if (!( - (declaration.reset_policy == fabric::cuda::nn::ResetPolicy::None && - declaration.reset_scope == fabric::cuda::nn::ResetScope::None) || - (declaration.reset_policy == fabric::cuda::nn::ResetPolicy::ZeroSourceRows && - declaration.reset_scope == fabric::cuda::nn::ResetScope::BatchRow))) { - plan.demotion_reason = "diagonal_recurrence_superop_ineligible:unsupported_reset_scope"; - return plan; - } - if (declaration.hidden_dim <= 0 || declaration.projected_message_dim <= 0 || declaration.raw_public_dim <= 0 || - projected_message.dim() != 3 || projected_message.size(2) < declaration.hidden_dim || - declaration.raw_public_dim != raw_public_dim) { - plan.demotion_reason = "diagonal_recurrence_superop_ineligible:unsupported_shape"; - return plan; - } - const auto& binding = declaration.binding; - const int max_param_index = std::max( - std::max(binding.nu_param_index, binding.theta_param_index), - std::max(std::max(binding.w1_param_index, binding.w2_param_index), binding.activation_param_index)); - if (binding.nu_param_index < 0 || binding.theta_param_index < 0 || binding.w1_param_index < 0 || - binding.w2_param_index < 0 || binding.activation_param_index < 0 || - max_param_index >= static_cast(cell_param_tensors.size())) { - plan.demotion_reason = "diagonal_recurrence_superop_ineligible:missing_parameters"; - return plan; - } - for (const int param_index : {binding.nu_param_index, binding.theta_param_index, binding.w1_param_index, binding.w2_param_index}) { - const at::Tensor& param = cell_param_tensors.at(static_cast(param_index)); - if (!param.defined() || param.dim() != 2 || param.size(0) < receiver_offset + projected_message.size(1) || - param.size(1) != declaration.hidden_dim || param.scalar_type() != at::kFloat) { - plan.demotion_reason = "diagonal_recurrence_superop_ineligible:unsupported_parameter_layout"; - return plan; - } - } - const at::Tensor& activation = cell_param_tensors.at(static_cast(binding.activation_param_index)); - if (!activation.defined() || activation.scalar_type() != at::kInt) { - plan.demotion_reason = "diagonal_recurrence_superop_ineligible:unsupported_activation_layout"; - return plan; - } - plan.active = true; - plan.declaration = declaration; - plan.receiver_offset = receiver_offset; - plan.demotion_reason = "none"; - plan.applicability_predicate = - "complex_exponential_2d_receiver_major_batch_row_reset_no_reduction_boundary"; - return plan; -} - -void allocate_receiver_affine_superop_workspace( - const std::vector& specs, - const at::Tensor& projected_message, - const py::tuple& state_prev_tensors, - const std::vector& cell_param_tensors, - const ReceiverAffineSuperOpPlan& plan, - DenseStateAffineWorkspace* workspace) { - if (!plan.active) { - return; - } - if (plan.direct_persistent) { - 
workspace->aliases.push_back("receiver_affine_superop:direct_persistent_no_pack_workspace"); - return; - } - TORCH_CHECK(specs.size() == 2, "receiver affine super-op workspace expects exactly two state affines"); - TORCH_CHECK(workspace->receiver_chunk_size > 0, "receiver affine super-op workspace requires receiver chunks"); - const at::Tensor first_input = state_affine_input_tensor(specs[0], projected_message, state_prev_tensors); - const at::Tensor second_input = state_affine_input_tensor(specs[1], projected_message, state_prev_tensors); - TORCH_CHECK( - specs[0].binding.weight_param_index >= 0 && - specs[0].binding.weight_param_index < static_cast(cell_param_tensors.size()), - "receiver affine super-op first weight param index is out of range"); - const at::Tensor& first_weight = cell_param_tensors[specs[0].binding.weight_param_index]; - const int64_t packed_k = first_input.size(2) + second_input.size(2) + 1; - const int64_t output_dim = specs[0].op.signature.N; - workspace->receiver_affine_packed_input = - at::empty({first_input.size(0), workspace->receiver_chunk_size, packed_k}, first_input.options()); - workspace->receiver_affine_packed_weight = - at::empty({workspace->receiver_chunk_size, packed_k, output_dim}, first_weight.options()); - append_workspace_metadata( - "receiver_affine_superop_packed_input", - workspace->receiver_affine_packed_input, - "receiver_affine_superop_pack_input->receiver_affine_superop_gemm", - "receiver_affine_superop_pack_input", - &workspace->buffers, - &workspace->buffer_bytes, - &workspace->bytes); - append_workspace_metadata( - "receiver_affine_superop_packed_weight", - workspace->receiver_affine_packed_weight, - "receiver_affine_superop_pack_weight->receiver_affine_superop_gemm", - "receiver_affine_superop_pack_weight", - &workspace->buffers, - &workspace->buffer_bytes, - &workspace->bytes); -} - -at::Tensor reset_aware_state_affine_input_tensor( - const StateAffineDeclaration& spec, - const at::Tensor& projected_message, - const py::tuple& state_prev_tensors, - const at::Tensor& resets_u8, - int t, - int64_t receiver_start, - int64_t receiver_global_offset, - int64_t receiver_count, - const std::vector& reset_packed_source_workspaces, - std::vector* packed_source_cache, - bool* packed_source_reused, - bool state_prev_is_zero = false) { - const at::Tensor input_window = state_affine_input_window( - spec, - projected_message, - state_prev_tensors, - receiver_start, - receiver_global_offset, - receiver_count, - state_prev_is_zero); - if (spec.op.signature.reset_policy == fabric::cuda::nn::ResetPolicy::None) { - return input_window; - } - if (state_prev_is_zero && spec.binding.input_kind == StateAffineSourceKind::StatePrev) { - return input_window; - } - const at::Tensor input = state_affine_input_tensor(spec, projected_message, state_prev_tensors); - TORCH_CHECK( - spec.op.signature.reset_policy == fabric::cuda::nn::ResetPolicy::ZeroSourceRows, - "unsupported Fabric state affine reset policy: ", - static_cast(spec.op.signature.reset_policy)); - for (const PackedSourceCacheEntry& entry : *packed_source_cache) { - if (same_state_affine_source(entry, spec)) { - *packed_source_reused = true; - return entry.tensor; - } - } - for (const PackedSourceCacheEntry& workspace : reset_packed_source_workspaces) { - if (same_state_affine_source(workspace, spec)) { - const at::Tensor packed_window = workspace.tensor.size(1) == input.size(1) - ? 
receiver_window( - workspace.tensor, - state_affine_source_receiver_start( - spec, - receiver_start, - receiver_global_offset, - workspace.tensor.size(1), - projected_message.size(1)), - receiver_count) - : workspace.tensor.size(1) == projected_message.size(1) - ? receiver_window(workspace.tensor, receiver_start, receiver_count) - : workspace_receiver_window(workspace.tensor, receiver_count); - fabric::cuda::ops::dense_affine_pack_reset_source_rows_out_cuda(input_window, packed_window, resets_u8, t); - packed_source_cache->push_back( - PackedSourceCacheEntry{workspace.input_kind, workspace.input_tensor_index, packed_window}); - return packed_window; - } - } - TORCH_CHECK(false, "missing planned reset-packed source workspace for Fabric state affine"); - return input; -} - -void launch_dense_state_affines( - const std::vector& specs, - const at::Tensor& projected_message, - const py::tuple& state_prev_tensors, - const std::vector& cell_param_tensors, - const std::vector& outputs, - const std::vector& reset_packed_source_workspaces, - const at::Tensor& resets_u8, - int t, - int64_t full_receivers, - int64_t receiver_start, - int64_t receiver_global_offset, - int64_t receiver_count, - bool state_prev_is_zero, - std::vector* backends, - std::vector* source_kinds, - std::vector* bucket_signatures, - std::vector* output_modes, - std::vector* reset_policies, - bool* packed_source_reused) { - RECORD_FUNCTION("fabric.physical.state_affine", std::vector()); - TORCH_CHECK( - (specs.empty() && outputs.empty()) || (!specs.empty() && outputs.size() == 1), - "Fabric state affines use one combined output workspace"); - backends->clear(); - source_kinds->clear(); - bucket_signatures->clear(); - output_modes->clear(); - reset_policies->clear(); - if (specs.empty()) { - return; - } - backends->reserve(specs.size()); - source_kinds->reserve(specs.size()); - bucket_signatures->reserve(specs.size()); - output_modes->reserve(specs.size()); - reset_policies->reserve(specs.size()); - std::vector packed_source_cache; - const at::Tensor output_window = workspace_receiver_window(outputs[0], receiver_count); - for (size_t idx = 0; idx < specs.size(); ++idx) { - const StateAffineDeclaration& spec = specs[idx]; - const at::Tensor input = reset_aware_state_affine_input_tensor( - spec, - projected_message, - state_prev_tensors, - resets_u8, - t, - receiver_start, - receiver_global_offset, - receiver_count, - reset_packed_source_workspaces, - &packed_source_cache, - packed_source_reused, - state_prev_is_zero); - TORCH_CHECK( - spec.binding.weight_param_index >= 0 && - spec.binding.weight_param_index < static_cast(cell_param_tensors.size()), - "Fabric state affine weight param index is out of range"); - const at::Tensor weight = state_affine_weight_window( - cell_param_tensors[spec.binding.weight_param_index], - spec, - full_receivers, - receiver_global_offset + receiver_start, - receiver_count); - TORCH_CHECK( - spec.binding.bias_param_index < static_cast(cell_param_tensors.size()), - "Fabric state affine bias param index is out of range"); - const at::Tensor bias = - spec.binding.bias_param_index >= 0 - ? state_affine_bias_window( - cell_param_tensors.at(static_cast(spec.binding.bias_param_index)), - spec, - full_receivers, - receiver_global_offset + receiver_start, - receiver_count) - : at::Tensor(); - const auto output_mode = idx == 0 ? 
fabric::cuda::ops::DenseAffineOutputMode::Overwrite - : fabric::cuda::ops::DenseAffineOutputMode::Accumulate; - const auto backend = fabric::cuda::ops::dense_affine_out_cuda( - input, - weight, - bias, - output_window, - fabric::cuda::ops::DenseAffineLayout::ReceiverMajor, - spec.binding.group_size > 0 ? spec.binding.group_size : 1, - output_mode); - backends->push_back(fabric::cuda::ops::dense_affine_backend_name(backend)); - source_kinds->push_back(state_affine_source_name(spec.binding.input_kind)); - bucket_signatures->push_back(state_affine_bucket_signature(input, spec)); - output_modes->push_back(output_mode_name(output_mode)); - reset_policies->push_back(reset_policy_name(spec.op.signature.reset_policy)); - } -} - -void launch_receiver_affine_superop( - const ReceiverAffineSuperOpPlan& plan, - const std::vector& specs, - const at::Tensor& projected_message, - const py::tuple& state_prev_tensors, - const std::vector& cell_param_tensors, - const std::vector& outputs, - const at::Tensor& receiver_affine_packed_input, - const at::Tensor& receiver_affine_packed_weight, - const at::Tensor& resets_u8, - int t, - int64_t full_receivers, - int64_t receiver_start, - int64_t receiver_global_offset, - int64_t receiver_count, - bool state_prev_is_zero, - std::vector* backends, - std::vector* source_kinds, - std::vector* bucket_signatures, - std::vector* output_modes, - std::vector* reset_policies, - bool* packed_source_reused) { - RECORD_FUNCTION("fabric.physical.receiver_affine", std::vector()); - TORCH_CHECK(specs.size() == 2, "receiver affine super-op v1 expects exactly two affines"); - TORCH_CHECK(outputs.size() == 1, "receiver affine super-op expects one combined output workspace"); - backends->clear(); - source_kinds->clear(); - bucket_signatures->clear(); - output_modes->clear(); - reset_policies->clear(); - *packed_source_reused = false; - TORCH_CHECK( - specs[0].op.signature.reset_policy == fabric::cuda::nn::ResetPolicy::None, - "receiver affine super-op v1 expects the overwrite source to have no reset policy"); - TORCH_CHECK( - specs[1].op.signature.reset_policy == fabric::cuda::nn::ResetPolicy::None || - specs[1].op.signature.reset_policy == fabric::cuda::nn::ResetPolicy::ZeroSourceRows, - "receiver affine super-op v1 only supports none or batch-row zero-source reset policy"); - const at::Tensor first_input = - state_affine_input_window( - specs[0], - projected_message, - state_prev_tensors, - receiver_start, - receiver_global_offset, - receiver_count, - state_prev_is_zero); - const at::Tensor output_window = workspace_receiver_window(outputs[0], receiver_count); - TORCH_CHECK( - specs[0].binding.weight_param_index >= 0 && - specs[0].binding.weight_param_index < static_cast(cell_param_tensors.size()), - "receiver affine super-op first weight param index is out of range"); - TORCH_CHECK( - specs[1].binding.weight_param_index >= 0 && - specs[1].binding.weight_param_index < static_cast(cell_param_tensors.size()), - "receiver affine super-op second weight param index is out of range"); - const at::Tensor first_weight = state_affine_weight_window( - cell_param_tensors[specs[0].binding.weight_param_index], - specs[0], - full_receivers, - receiver_global_offset + receiver_start, - receiver_count); - const at::Tensor second_weight = state_affine_weight_window( - cell_param_tensors[specs[1].binding.weight_param_index], - specs[1], - full_receivers, - receiver_global_offset + receiver_start, - receiver_count); - TORCH_CHECK( - specs[0].binding.bias_param_index < 
static_cast(cell_param_tensors.size()), - "receiver affine super-op first bias param index is out of range"); - TORCH_CHECK( - specs[1].binding.bias_param_index < static_cast(cell_param_tensors.size()), - "receiver affine super-op second bias param index is out of range"); - const at::Tensor first_bias = - specs[0].binding.bias_param_index >= 0 - ? state_affine_bias_window( - cell_param_tensors.at(static_cast(specs[0].binding.bias_param_index)), - specs[0], - full_receivers, - receiver_global_offset + receiver_start, - receiver_count) - : at::Tensor(); - const at::Tensor second_bias = - specs[1].binding.bias_param_index >= 0 - ? state_affine_bias_window( - cell_param_tensors.at(static_cast(specs[1].binding.bias_param_index)), - specs[1], - full_receivers, - receiver_global_offset + receiver_start, - receiver_count) - : at::Tensor(); - const bool zero_second_source_rows = - specs[1].op.signature.reset_policy == fabric::cuda::nn::ResetPolicy::ZeroSourceRows; - const bool second_source_is_zero = - state_prev_is_zero && specs[1].binding.input_kind == StateAffineSourceKind::StatePrev && - !second_bias.defined(); - const at::Tensor second_input = second_source_is_zero - ? first_input - : state_affine_input_window( - specs[1], - projected_message, - state_prev_tensors, - receiver_start, - receiver_global_offset, - receiver_count, - state_prev_is_zero); - if (plan.direct_persistent) { - py::gil_scoped_acquire gil; - static py::object direct_receiver_affine = - py::module_::import("cortical.fabric.backend.cuda.ops.receiver_affine_triton") - .attr("receiver_affine2_direct_persistent_out_cuda"); - py::object first_bias_arg = first_bias.defined() ? py::cast(first_bias) : py::none(); - py::object second_bias_arg = second_bias.defined() ? py::cast(second_bias) : py::none(); - direct_receiver_affine( - first_input, - first_weight, - first_bias_arg, - second_input, - second_weight, - second_bias_arg, - output_window, - resets_u8, - py::arg("t") = t, - py::arg("zero_second_source_rows") = zero_second_source_rows, - py::arg("second_source_is_zero") = second_source_is_zero); - } else { - const at::Tensor packed_input_window = workspace_receiver_window(receiver_affine_packed_input, receiver_count); - const at::Tensor packed_weight_window = - workspace_leading_receiver_window(receiver_affine_packed_weight, receiver_count); - fabric::cuda::ops::receiver_affine2_superop_workspace_out_cuda( - first_input, - first_weight, - first_bias, - second_input, - second_weight, - second_bias, - packed_input_window, - packed_weight_window, - output_window, - resets_u8, - t, - zero_second_source_rows); - } - backends->reserve(specs.size()); - source_kinds->reserve(specs.size()); - bucket_signatures->reserve(specs.size()); - output_modes->reserve(specs.size()); - reset_policies->reserve(specs.size()); - for (size_t idx = 0; idx < specs.size(); ++idx) { - const at::Tensor input = idx == 0 ? first_input : second_input; - backends->push_back("receiver_affine_superop"); - source_kinds->push_back(state_affine_source_name(specs[idx].binding.input_kind)); - bucket_signatures->push_back(state_affine_bucket_signature(input, specs[idx])); - output_modes->push_back(idx == 0 ? 
"overwrite" : "accumulate"); - reset_policies->push_back(reset_policy_name(specs[idx].op.signature.reset_policy)); - } -} - -py::tuple string_tuple(const std::vector& values) { - py::tuple out(values.size()); - for (size_t idx = 0; idx < values.size(); ++idx) { - out[idx] = values[idx]; - } - return out; -} - -void set_workspace_metadata( - py::dict* metadata, - const std::vector& buffers, - const std::vector& buffer_bytes, - int64_t peak_bytes, - const std::vector& aliases) { - (*metadata)["workspace_buffers"] = string_tuple(buffers); - (*metadata)["workspace_buffer_bytes"] = string_tuple(buffer_bytes); - (*metadata)["workspace_peak_bytes"] = py::make_tuple(std::to_string(peak_bytes)); - (*metadata)["workspace_aliases"] = aliases.empty() ? py::make_tuple("none") : string_tuple(aliases); - (*metadata)["physical_workspace_peak_bytes"] = py::make_tuple(std::to_string(peak_bytes)); - (*metadata)["physical_workspace_aliases"] = aliases.empty() ? py::make_tuple("none") : string_tuple(aliases); -} - -void append_dense_message_workspace_metadata( - const fabric::cuda::ops::DenseMessageExecution& execution, - std::vector* buffers, - std::vector* buffer_bytes, - std::vector* aliases, - int64_t* peak_bytes) { - buffers->insert(buffers->end(), execution.workspace_buffers.begin(), execution.workspace_buffers.end()); - buffer_bytes->insert( - buffer_bytes->end(), - execution.workspace_buffer_bytes.begin(), - execution.workspace_buffer_bytes.end()); - aliases->insert(aliases->end(), execution.workspace_aliases.begin(), execution.workspace_aliases.end()); - *peak_bytes += execution.workspace_peak_bytes; -} - -void append_diagonal_recurrence_workspace_metadata( - const fabric::cuda::ops::DiagonalRecurrenceExecution& execution, - std::vector* buffers, - std::vector* buffer_bytes, - std::vector* aliases, - int64_t* peak_bytes) { - buffers->insert(buffers->end(), execution.workspace_buffers.begin(), execution.workspace_buffers.end()); - buffer_bytes->insert( - buffer_bytes->end(), - execution.workspace_buffer_bytes.begin(), - execution.workspace_buffer_bytes.end()); - aliases->insert(aliases->end(), execution.workspace_aliases.begin(), execution.workspace_aliases.end()); - *peak_bytes += execution.workspace_peak_bytes; -} - -const char* message_demoted_reason(int message_backend_id, int spatial_id) { - if (spatial_id == static_cast(SpatialOwnership::EdgeOwned)) { - return "edge_owned_sparse_message_demoted"; - } - if (message_backend_id == 1) { - return "degree_bucketed_sparse_message_demoted"; - } - return "none"; -} - -void set_dense_message_metadata( - py::dict* metadata, - const fabric::cuda::ops::DenseMessageExecution* execution, - int message_backend_id, - int spatial_id) { - if (execution != nullptr) { - (*metadata)["message_bucket_count"] = py::make_tuple(std::to_string(execution->bucket_count)); - (*metadata)["message_regular_local_bucket_count"] = - py::make_tuple(std::to_string(execution->regular_local_bucket_count)); - (*metadata)["message_sparse_bucket_count"] = py::make_tuple(std::to_string(execution->sparse_bucket_count)); - (*metadata)["message_batched_backend_count"] = py::make_tuple(std::to_string(execution->batched_backend_count)); - (*metadata)["message_grouped_backend_count"] = py::make_tuple(std::to_string(execution->grouped_backend_count)); - (*metadata)["message_reset_aware_bucket_count"] = - py::make_tuple(std::to_string(execution->reset_aware_bucket_count)); - (*metadata)["message_degree_uniform_bucket_count"] = - 
py::make_tuple(std::to_string(execution->degree_uniform_bucket_count)); - (*metadata)["message_ragged_grouped_bucket_count"] = - py::make_tuple(std::to_string(execution->ragged_grouped_bucket_count)); - (*metadata)["message_demoted_bucket_count"] = - py::make_tuple(std::to_string(execution->demoted_bucket_count)); - (*metadata)["message_bucket_signatures"] = py::make_tuple(execution->bucket_signature); - (*metadata)["message_bucket_kinds"] = py::make_tuple(execution->bucket_kind); - (*metadata)["message_topology_kinds"] = py::make_tuple(execution->topology_kind); - (*metadata)["message_spatial_ownership"] = py::make_tuple(execution->spatial_ownership); - (*metadata)["message_degree_bucket_lists"] = py::make_tuple(execution->degree_bucket_list); - (*metadata)["message_logit_backends"] = py::make_tuple(execution->logit_backend); - (*metadata)["message_softmax_backends"] = py::make_tuple(execution->softmax_backend); - (*metadata)["message_weighted_value_backends"] = py::make_tuple(execution->weighted_value_backend); - (*metadata)["message_physical_mode"] = py::make_tuple(execution->physical_mode); - (*metadata)["message_execution_mode"] = py::make_tuple(execution->execution_mode); - (*metadata)["message_output_boundary"] = py::make_tuple(execution->output_boundary); - (*metadata)["message_degree"] = py::make_tuple(std::to_string(execution->degree_or_block)); - (*metadata)["message_k"] = py::make_tuple(std::to_string(execution->key_dim)); - (*metadata)["message_v"] = py::make_tuple(std::to_string(execution->value_dim)); - (*metadata)["message_projected_n"] = py::make_tuple(std::to_string(execution->projected_output_dim)); - (*metadata)["message_reset_policies"] = py::make_tuple(execution->reset_policy); - (*metadata)["message_reset_scopes"] = py::make_tuple(execution->reset_scope); - (*metadata)["message_use_delay"] = py::make_tuple(execution->use_delay); - (*metadata)["message_distance_penalty_kinds"] = py::make_tuple(execution->distance_penalty_kind); - (*metadata)["message_epilogue_kinds"] = py::make_tuple(execution->epilogue_kind); - (*metadata)["message_packed_source_reuse_count"] = - py::make_tuple(std::to_string(execution->packed_source_reuse_count)); - (*metadata)["message_demotions"] = py::make_tuple("none"); - (*metadata)["message_workspace_buffers"] = string_tuple(execution->workspace_buffers); - (*metadata)["message_workspace_buffer_bytes"] = string_tuple(execution->workspace_buffer_bytes); - (*metadata)["message_workspace_peak_bytes"] = py::make_tuple(std::to_string(execution->workspace_peak_bytes)); - (*metadata)["message_workspace_mode"] = py::make_tuple(execution->workspace_mode); - (*metadata)["message_workspace_aliases"] = - execution->workspace_aliases.empty() ? py::make_tuple("none") : string_tuple(execution->workspace_aliases); - (*metadata)["message_per_bucket_workspace_bytes"] = - execution->per_bucket_workspace_bytes.empty() - ? py::make_tuple("none") - : string_tuple(execution->per_bucket_workspace_bytes); - return; - } - - const bool sparse_like = - spatial_id == static_cast(SpatialOwnership::EdgeOwned) || message_backend_id == 1; - (*metadata)["message_bucket_count"] = py::make_tuple(sparse_like ? "1" : "0"); - (*metadata)["message_regular_local_bucket_count"] = py::make_tuple("0"); - (*metadata)["message_sparse_bucket_count"] = py::make_tuple(sparse_like ? 
"1" : "0"); - (*metadata)["message_batched_backend_count"] = py::make_tuple("0"); - (*metadata)["message_grouped_backend_count"] = py::make_tuple("0"); - (*metadata)["message_reset_aware_bucket_count"] = py::make_tuple("0"); - (*metadata)["message_degree_uniform_bucket_count"] = py::make_tuple("0"); - (*metadata)["message_ragged_grouped_bucket_count"] = py::make_tuple("0"); - (*metadata)["message_demoted_bucket_count"] = py::make_tuple(sparse_like ? "1" : "0"); - (*metadata)["message_bucket_signatures"] = py::make_tuple("demoted_sparse_message"); - (*metadata)["message_bucket_kinds"] = - py::make_tuple(sparse_like ? "degree_bucketed_sparse" : "none"); - (*metadata)["message_topology_kinds"] = py::make_tuple( - spatial_id == static_cast(SpatialOwnership::EdgeOwned) - ? "edge_owned_sparse" - : message_backend_id == 1 ? "receiver_owned_sparse" : "demoted"); - (*metadata)["message_spatial_ownership"] = - py::make_tuple(spatial_id == static_cast(SpatialOwnership::EdgeOwned) ? "edge_owned" : "receiver_owned"); - (*metadata)["message_degree_bucket_lists"] = py::make_tuple("none"); - (*metadata)["message_logit_backends"] = py::make_tuple("demoted"); - (*metadata)["message_softmax_backends"] = py::make_tuple("demoted"); - (*metadata)["message_weighted_value_backends"] = py::make_tuple("demoted"); - (*metadata)["message_physical_mode"] = py::make_tuple("demoted"); - (*metadata)["message_execution_mode"] = py::make_tuple("demoted"); - (*metadata)["message_output_boundary"] = py::make_tuple("projected_message"); - (*metadata)["message_degree"] = py::make_tuple("0"); - (*metadata)["message_k"] = py::make_tuple("0"); - (*metadata)["message_v"] = py::make_tuple("0"); - (*metadata)["message_projected_n"] = py::make_tuple("0"); - (*metadata)["message_reset_policies"] = py::make_tuple("none"); - (*metadata)["message_reset_scopes"] = py::make_tuple("none"); - (*metadata)["message_use_delay"] = py::make_tuple("false"); - (*metadata)["message_distance_penalty_kinds"] = py::make_tuple("offset_distance"); - (*metadata)["message_epilogue_kinds"] = py::make_tuple("segment_softmax_weighted_sum"); - (*metadata)["message_packed_source_reuse_count"] = py::make_tuple("0"); - (*metadata)["message_demotions"] = py::make_tuple(message_demoted_reason(message_backend_id, spatial_id)); - (*metadata)["message_workspace_buffers"] = py::make_tuple("none"); - (*metadata)["message_workspace_buffer_bytes"] = py::make_tuple("none"); - (*metadata)["message_workspace_peak_bytes"] = py::make_tuple("0"); - (*metadata)["message_workspace_mode"] = py::make_tuple("none"); - (*metadata)["message_workspace_aliases"] = py::make_tuple("none"); - (*metadata)["message_per_bucket_workspace_bytes"] = py::make_tuple("none"); -} - -bool backend_uses_cublas(const char* backend) { - const std::string name(backend); - return name == "large_gemm" || name == "batched_gemm" || name == "grouped_gemm"; -} - -int64_t launch_if_ran(const char* backend) { - const std::string name(backend); - return name == "none" || name == "skip" || name == "unrun" || name == "fused_into_tiny_message" ? 0 : 1; -} - -int64_t dense_message_launch_count( - const fabric::cuda::ops::DenseMessageExecution* execution, - bool execution_recorded, - int message_backend_id, - int spatial_id, - int64_t time_steps) { - if (execution_recorded) { - return execution->launch_count * time_steps; - } - if (spatial_id == static_cast(SpatialOwnership::EdgeOwned)) { - return 3 * time_steps; - } - return (message_backend_id == 0 ? 
1 : 0) * time_steps; -} - -int64_t state_affine_receiver_chunks( - const std::vector& specs, - const DenseStateAffineWorkspace& workspace) { - if (specs.empty()) { - return 1; - } - return std::max(1, workspace.receiver_chunks); -} - -const char* state_epilogue_policy_name(fabric::cuda::nn::StateEpiloguePolicy policy) { - switch (policy) { - case fabric::cuda::nn::StateEpiloguePolicy::Separate: - return "separate_state_update_reduce_emit"; - case fabric::cuda::nn::StateEpiloguePolicy::FusedNoReductionSameChunk: - return "fused_no_reduction_state_update_emit"; - } - return "unknown"; -} - -bool uses_fused_state_epilogue( - const std::vector& state_affine_specs, - fabric::cuda::nn::StateEpiloguePolicy policy) { - return state_affine_specs.empty() && - policy == fabric::cuda::nn::StateEpiloguePolicy::FusedNoReductionSameChunk; -} - -const char* physical_op_kind_name(fabric::cuda::nn::PhysicalOpKind kind) { - switch (kind) { - case fabric::cuda::nn::PhysicalOpKind::LoweredMessagePipeline: - return "lowered_message_pipeline"; - case fabric::cuda::nn::PhysicalOpKind::LoweredInputProjection: - return "lowered_input_projection"; - case fabric::cuda::nn::PhysicalOpKind::LoweredStateAffines: - return "lowered_state_affines"; - case fabric::cuda::nn::PhysicalOpKind::LoweredStateEpilogue: - return "lowered_state_epilogue"; - case fabric::cuda::nn::PhysicalOpKind::LoweredPublicProjection: - return "lowered_public_projection"; - case fabric::cuda::nn::PhysicalOpKind::LoweredReadoutMessage: - return "lowered_readout_message"; - case fabric::cuda::nn::PhysicalOpKind::LoweredReadoutProjection: - return "lowered_readout_projection"; - case fabric::cuda::nn::PhysicalOpKind::ReceiverAffineSuperOp: - return "receiver_affine_superop"; - case fabric::cuda::nn::PhysicalOpKind::TinyMessageSuperOp: - return "tiny_message_superop"; - case fabric::cuda::nn::PhysicalOpKind::SparseMessageSuperOp: - return "sparse_message_superop"; - case fabric::cuda::nn::PhysicalOpKind::DiagonalRecurrenceSuperOp: - return "diagonal_recurrence_superop"; - case fabric::cuda::nn::PhysicalOpKind::ReceiverAffineSuperOpBackward: - return "receiver_affine_superop_backward"; - case fabric::cuda::nn::PhysicalOpKind::TinyMessageSuperOpBackward: - return "tiny_message_superop_backward"; - case fabric::cuda::nn::PhysicalOpKind::SparseMessageSuperOpBackward: - return "sparse_message_superop_backward"; - case fabric::cuda::nn::PhysicalOpKind::DiagonalRecurrenceSuperOpBackward: - return "diagonal_recurrence_superop_backward"; - case fabric::cuda::nn::PhysicalOpKind::BackwardGlue: - return "backward_glue"; - } - return "unknown"; -} - -const char* physical_executor_name(fabric::cuda::nn::PhysicalExecutorKind executor) { - switch (executor) { - case fabric::cuda::nn::PhysicalExecutorKind::LoweredPhaseExecutor: - return "lowered_phase_executor"; - case fabric::cuda::nn::PhysicalExecutorKind::PhysicalSuperOpExecutor: - return "physical_superop_executor"; - case fabric::cuda::nn::PhysicalExecutorKind::ExplicitDemotion: - return "explicit_demotion"; - } - return "unknown"; -} - -const char* physical_layout_mode_name(fabric::cuda::nn::PhysicalLayoutMode mode) { - switch (mode) { - case fabric::cuda::nn::PhysicalLayoutMode::TargetBackendNativeReceiverMajor: - return "target_backend_native_receiver_major_runtime_mixed"; - } - return "unknown"; -} - -const char* copy_elision_mode_name(fabric::cuda::nn::CopyElisionMode mode) { - switch (mode) { - case fabric::cuda::nn::CopyElisionMode::PlannerVisibleNoCopyGlue: - return "planner_counted_no_copy_glue"; - case 
fabric::cuda::nn::CopyElisionMode::PlannerVisibleResidualCopyGlue: - return "planner_counted_residual_copy_glue"; - } - return "unknown"; -} - -const char* bias_fusion_mode_name(fabric::cuda::nn::BiasFusionMode mode) { - switch (mode) { - case fabric::cuda::nn::BiasFusionMode::PlannerVisibleNoStandaloneBias: - return "planner_counted_no_standalone_bias"; - case fabric::cuda::nn::BiasFusionMode::PlannerVisibleResidualBiasGlue: - return "planner_counted_residual_bias_glue"; - } - return "unknown"; -} - -const char* sparse_message_backend_family(const fabric::cuda::ops::DenseMessageExecution& execution) { - if (execution.degree_uniform_bucket_count > 0 && execution.grouped_backend_count == 0) { - return "degree_uniform_batched"; - } - if (execution.ragged_grouped_bucket_count > 0 && execution.batched_backend_count == 0) { - return "ragged_grouped"; - } - if (execution.batched_backend_count > 0 && execution.grouped_backend_count > 0) { - return "mixed_batched_grouped"; - } - return "sparse_bucketed"; -} - -fabric::cuda::nn::PhysicalExecutionPlan make_runtime_physical_execution_plan( - const fabric::cuda::ops::DenseMessageExecution* dense_message_execution, - bool dense_message_execution_recorded, - const TinyMessageSuperOpPlan& tiny_message_superop_plan, - bool has_state_affines, - const ReceiverAffineSuperOpPlan& receiver_affine_superop_plan, - const DiagonalRecurrenceSuperOpPlan& diagonal_recurrence_superop_plan, - int64_t public_copy_glue_launches, - int64_t public_bias_glue_launches) { - fabric::cuda::nn::PhysicalExecutionPlan plan{}; - plan.layout_mode = fabric::cuda::nn::PhysicalLayoutMode::TargetBackendNativeReceiverMajor; - plan.copy_elision_mode = public_copy_glue_launches == 0 - ? fabric::cuda::nn::CopyElisionMode::PlannerVisibleNoCopyGlue - : fabric::cuda::nn::CopyElisionMode::PlannerVisibleResidualCopyGlue; - plan.bias_fusion_mode = public_bias_glue_launches == 0 - ? fabric::cuda::nn::BiasFusionMode::PlannerVisibleNoStandaloneBias - : fabric::cuda::nn::BiasFusionMode::PlannerVisibleResidualBiasGlue; - const bool sparse_message_superop_active = - !tiny_message_superop_plan.active && dense_message_execution_recorded && dense_message_execution != nullptr && - dense_message_execution->sparse_bucket_count > 0 && dense_message_execution->demoted_bucket_count == 0; - fabric::cuda::nn::PhysicalOpKind message_op_kind = fabric::cuda::nn::PhysicalOpKind::LoweredMessagePipeline; - fabric::cuda::nn::PhysicalExecutorKind message_executor = - fabric::cuda::nn::PhysicalExecutorKind::LoweredPhaseExecutor; - std::string message_demotion_reason = tiny_message_superop_plan.demotion_reason; - std::string message_applicability_predicate = - dense_message_execution_recorded ? 
"dense_message_bucket_declared_in_fabric_cuda_nn" - : "legacy_message_path_explicitly_demoted"; - if (tiny_message_superop_plan.active) { - message_op_kind = fabric::cuda::nn::PhysicalOpKind::TinyMessageSuperOp; - message_executor = fabric::cuda::nn::PhysicalExecutorKind::PhysicalSuperOpExecutor; - message_demotion_reason = "none"; - message_applicability_predicate = tiny_message_superop_plan.applicability_predicate; - } else if (sparse_message_superop_active) { - message_op_kind = fabric::cuda::nn::PhysicalOpKind::SparseMessageSuperOp; - message_executor = fabric::cuda::nn::PhysicalExecutorKind::PhysicalSuperOpExecutor; - message_demotion_reason = "none"; - message_applicability_predicate = - std::string(dense_message_execution->spatial_ownership) + "_" + - dense_message_execution->bucket_kind + "_" + - sparse_message_backend_family(*dense_message_execution) + "_receiver_major_" + - dense_message_execution->reset_scope + "_reset"; - } - plan.ops.push_back(fabric::cuda::nn::make_physical_op_plan( - message_op_kind, - {"message_query", "message_key", "message_value"}, - {"projected_message"}, - dense_message_execution_recorded ? fabric::cuda::nn::ResetPolicy::ZeroSourceRows - : fabric::cuda::nn::ResetPolicy::None, - dense_message_execution_recorded ? fabric::cuda::nn::ResetScope::BatchRow - : fabric::cuda::nn::ResetScope::None, - fabric::cuda::nn::WorkspaceAliasClass::PhaseReusable, - message_executor, - message_demotion_reason, - "projected_message", - message_applicability_predicate, - "message")); - if (!tiny_message_superop_plan.active) { - plan.ops.push_back(fabric::cuda::nn::make_physical_op_plan( - fabric::cuda::nn::PhysicalOpKind::LoweredInputProjection, - {"message"}, - {"projected_message"}, - fabric::cuda::nn::ResetPolicy::None, - fabric::cuda::nn::ResetScope::None, - fabric::cuda::nn::WorkspaceAliasClass::PhaseReusable, - fabric::cuda::nn::PhysicalExecutorKind::LoweredPhaseExecutor, - "none", - "projected_message", - "input_projection_declared_in_fabric_cuda_nn", - "input_projection")); - } - if (has_state_affines) { - plan.ops.push_back(fabric::cuda::nn::make_physical_op_plan( - receiver_affine_superop_plan.active ? fabric::cuda::nn::PhysicalOpKind::ReceiverAffineSuperOp - : fabric::cuda::nn::PhysicalOpKind::LoweredStateAffines, - {"projected_message", "state_prev"}, - {"state_affine_output"}, - fabric::cuda::nn::ResetPolicy::ZeroSourceRows, - fabric::cuda::nn::ResetScope::BatchRow, - fabric::cuda::nn::WorkspaceAliasClass::StateAffineContributions, - receiver_affine_superop_plan.active ? fabric::cuda::nn::PhysicalExecutorKind::PhysicalSuperOpExecutor - : fabric::cuda::nn::PhysicalExecutorKind::LoweredPhaseExecutor, - receiver_affine_superop_plan.demotion_reason, - "state_affine_output", - receiver_affine_superop_plan.applicability_predicate, - "state_affine")); - } - const bool diagonal_recurrence_active = diagonal_recurrence_superop_plan.active; - const std::string state_epilogue_demotion = - diagonal_recurrence_active || - diagonal_recurrence_superop_plan.demotion_reason == - "diagonal_recurrence_superop_ineligible:no_declaration" - ? "none" - : diagonal_recurrence_superop_plan.demotion_reason; - plan.ops.push_back(fabric::cuda::nn::make_physical_op_plan( - diagonal_recurrence_active ? fabric::cuda::nn::PhysicalOpKind::DiagonalRecurrenceSuperOp - : fabric::cuda::nn::PhysicalOpKind::LoweredStateEpilogue, - has_state_affines ? std::vector{"state_affine_output"} : std::vector{"projected_message"}, - {"raw_public"}, - diagonal_recurrence_active ? 
diagonal_recurrence_superop_plan.declaration.reset_policy - : fabric::cuda::nn::ResetPolicy::None, - diagonal_recurrence_active ? diagonal_recurrence_superop_plan.declaration.reset_scope - : fabric::cuda::nn::ResetScope::None, - fabric::cuda::nn::WorkspaceAliasClass::PhaseReusable, - diagonal_recurrence_active ? fabric::cuda::nn::PhysicalExecutorKind::PhysicalSuperOpExecutor - : fabric::cuda::nn::PhysicalExecutorKind::LoweredPhaseExecutor, - state_epilogue_demotion, - "raw_public", - diagonal_recurrence_active ? diagonal_recurrence_superop_plan.applicability_predicate - : "state_epilogue_policy_declared_in_fabric_cuda_nn", - "state_epilogue")); - plan.ops.push_back(fabric::cuda::nn::make_physical_op_plan( - fabric::cuda::nn::PhysicalOpKind::LoweredPublicProjection, - {"raw_public"}, - {"public"}, - fabric::cuda::nn::ResetPolicy::None, - fabric::cuda::nn::ResetScope::None, - fabric::cuda::nn::WorkspaceAliasClass::PhaseReusable, - fabric::cuda::nn::PhysicalExecutorKind::LoweredPhaseExecutor, - "none", - "public", - "public_projection_declared_in_fabric_cuda_nn", - "public_projection")); - plan.ops.push_back(fabric::cuda::nn::make_physical_op_plan( - fabric::cuda::nn::PhysicalOpKind::LoweredReadoutMessage, - {"public", "input_ports"}, - {"readout_message"}, - fabric::cuda::nn::ResetPolicy::None, - fabric::cuda::nn::ResetScope::None, - fabric::cuda::nn::WorkspaceAliasClass::PhaseReusable, - fabric::cuda::nn::PhysicalExecutorKind::LoweredPhaseExecutor, - "none", - "readout_message", - "readout_message_routing_glue", - "readout_message")); - plan.ops.push_back(fabric::cuda::nn::make_physical_op_plan( - fabric::cuda::nn::PhysicalOpKind::LoweredReadoutProjection, - {"readout_message"}, - {"readout"}, - fabric::cuda::nn::ResetPolicy::None, - fabric::cuda::nn::ResetScope::None, - fabric::cuda::nn::WorkspaceAliasClass::PhaseReusable, - fabric::cuda::nn::PhysicalExecutorKind::LoweredPhaseExecutor, - "none", - "readout", - "readout_projection_declared_in_fabric_cuda_nn", - "readout_projection")); - return plan; -} - -void set_physical_op_plan_metadata( - py::dict* metadata, - const fabric::cuda::nn::PhysicalExecutionPlan& physical_plan, - const std::vector& physical_op_launch_counts, - const std::vector& physical_op_saved_launch_counts, - const std::vector& standalone_copy_kernel_count, - const std::vector& standalone_bias_kernel_count) { - std::vector kinds; - std::vector executors; - std::vector demotions; - std::vector boundary_contracts; - std::vector applicability_predicates; - kinds.reserve(physical_plan.ops.size()); - executors.reserve(physical_plan.ops.size()); - boundary_contracts.reserve(physical_plan.ops.size()); - applicability_predicates.reserve(physical_plan.ops.size()); - for (const fabric::cuda::nn::PhysicalOpPlan& op : physical_plan.ops) { - const std::string kind = physical_op_kind_name(op.kind); - kinds.push_back(kind); - executors.push_back(physical_executor_name(op.executor)); - boundary_contracts.push_back(kind + ":" + op.boundary_contract); - applicability_predicates.push_back(kind + ":" + op.applicability_predicate); - if (op.demotion_reason != "none") { - demotions.push_back(kind + ":" + op.demotion_reason + "->" + physical_executor_name(op.executor)); - } - } - const std::vector physical_layout_contracts{ - "target_state=backend_native_receiver_major", - "target_public=backend_native_receiver_major", - "target_projected_message=backend_native_receiver_major", - "target_state_affine_output=backend_native_receiver_major", - "target_readout=backend_native_receiver_major", - }; - 
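  // Illustrative note: each field below is emitted as a tuple of strings via
  // string_tuple / py::make_tuple, e.g. a demoted receiver-affine row surfaces as
  //   physical_op_kinds     -> ("lowered_state_affines", ...)
  //   physical_op_demotions -> ("lowered_state_affines:receiver_affine_superop_ineligible:mixed_output_dim->lowered_phase_executor",)
  // while an empty demotion list collapses to ("none",).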
(*metadata)["physical_op_kinds"] = string_tuple(kinds); - (*metadata)["physical_op_executors"] = string_tuple(executors); - (*metadata)["physical_op_demotions"] = demotions.empty() ? py::make_tuple("none") : string_tuple(demotions); - (*metadata)["physical_boundary_contracts"] = string_tuple(boundary_contracts); - (*metadata)["physical_applicability_predicates"] = string_tuple(applicability_predicates); - (*metadata)["physical_layout_contracts"] = string_tuple(physical_layout_contracts); - (*metadata)["layout_mode"] = py::make_tuple(physical_layout_mode_name(physical_plan.layout_mode)); - (*metadata)["copy_elision_mode"] = py::make_tuple(copy_elision_mode_name(physical_plan.copy_elision_mode)); - (*metadata)["bias_fusion_mode"] = py::make_tuple(bias_fusion_mode_name(physical_plan.bias_fusion_mode)); - (*metadata)["physical_op_launch_counts"] = string_tuple(physical_op_launch_counts); - (*metadata)["physical_op_saved_launch_counts"] = string_tuple(physical_op_saved_launch_counts); - (*metadata)["standalone_copy_kernel_count"] = string_tuple(standalone_copy_kernel_count); - (*metadata)["standalone_bias_kernel_count"] = string_tuple(standalone_bias_kernel_count); -} - -void set_launch_granularity_metadata( - py::dict* metadata, - const fabric::cuda::ops::DenseMessageExecution* dense_message_execution, - bool dense_message_execution_recorded, - const TinyMessageSuperOpPlan& tiny_message_superop_plan, - const std::vector& state_affine_specs, - const DenseStateAffineWorkspace& state_affine_workspace, - const ReceiverAffineSuperOpPlan& receiver_affine_superop_plan, - const DiagonalRecurrenceSuperOpPlan& diagonal_recurrence_superop_plan, - const fabric::cuda::ops::DiagonalRecurrenceExecution* diagonal_recurrence_execution, - bool diagonal_recurrence_execution_recorded, - int64_t diagonal_recurrence_launch_count, - int message_backend_id, - int spatial_id, - int64_t time_steps, - int reduction_stats_dim, - int num_hidden_chunks, - fabric::cuda::nn::StateEpiloguePolicy state_epilogue_policy, - const char* input_projection_backend, - const PublicProjectionExecution& public_projection, - const char* readout_projection_backend) { - const int64_t steps = std::max(1, time_steps); - const int64_t message_launches = dense_message_launch_count( - dense_message_execution, - dense_message_execution_recorded, - message_backend_id, - spatial_id, - steps); - const int64_t message_small_cublas = dense_message_execution_recorded - ? dense_message_execution->small_cublas_launch_count * steps - : 0; - const int64_t old_split_message_launches = dense_message_execution_recorded && tiny_message_superop_plan.active - ? 6 * std::max(1, dense_message_execution->receiver_chunk_count) * steps - : message_launches; - const int64_t message_saved_launches = std::max(0, old_split_message_launches - message_launches); - const int64_t state_receiver_chunks = state_affine_receiver_chunks(state_affine_specs, state_affine_workspace); - const int64_t state_affine_old_gemm_launches = - static_cast(state_affine_specs.size()) * state_receiver_chunks * steps; - const int64_t state_affine_superop_launches = - receiver_affine_superop_plan.active ? state_receiver_chunks * steps : 0; - const int64_t state_affine_gemm_launches = - receiver_affine_superop_plan.active ? 
state_affine_superop_launches : state_affine_old_gemm_launches; - const int64_t state_affine_reset_pack_launches = - static_cast(state_affine_workspace.reset_packed_sources.size()) * state_receiver_chunks * steps; - const int64_t state_affine_launches = state_affine_gemm_launches + state_affine_reset_pack_launches; - const int64_t state_affine_saved_launches = - receiver_affine_superop_plan.active ? state_affine_old_gemm_launches - state_affine_superop_launches : 0; - const bool diagonal_recurrence_active = - diagonal_recurrence_superop_plan.active && diagonal_recurrence_execution_recorded && - diagonal_recurrence_execution != nullptr; - const bool fused_state_epilogue = uses_fused_state_epilogue(state_affine_specs, state_epilogue_policy); - const bool single_chunk_reduction_alias = reduction_stats_dim > 0 && num_hidden_chunks <= 1; - const bool single_chunk_reduction_update_emit = single_chunk_reduction_alias; - const int64_t baseline_reduce_emit_launches = 1 + (reduction_stats_dim > 0 ? 1 : 0); - const int64_t actual_reduce_emit_launches = - single_chunk_reduction_update_emit ? 0 : 1 + (reduction_stats_dim > 0 ? 1 : 0); - const int64_t separate_state_epilogue_launches = - (state_affine_specs.empty() ? 1 + actual_reduce_emit_launches - : state_receiver_chunks + actual_reduce_emit_launches) * - steps; - const int64_t baseline_state_epilogue_launches = - (state_affine_specs.empty() ? 1 + baseline_reduce_emit_launches - : state_receiver_chunks + baseline_reduce_emit_launches) * - steps; - const int64_t state_epilogue_launches = - diagonal_recurrence_active ? diagonal_recurrence_launch_count - : fused_state_epilogue ? steps : separate_state_epilogue_launches; - const int64_t state_epilogue_saved_launches = - std::max(0, baseline_state_epilogue_launches - state_epilogue_launches); - const int64_t input_projection_launches = launch_if_ran(input_projection_backend) * steps; - const int64_t input_projection_small_cublas = - (backend_uses_cublas(input_projection_backend) ? 1 : 0) * steps; - const int64_t public_projection_launches = public_projection.launch_count * steps; - const int64_t public_projection_small_cublas = public_projection.small_cublas_launch_count * steps; - const int64_t public_copy_glue_launches = public_projection.copy_glue_launch_count * steps; - const int64_t public_copy_glue_saved_launches = public_projection.copy_glue_saved_launches * steps; - const int64_t public_bias_glue_launches = public_projection.bias_glue_launch_count * steps; - const int64_t public_bias_glue_saved_launches = public_projection.bias_glue_saved_launches * steps; - const int64_t readout_projection_launches = - std::string(readout_projection_backend) == "skip" ? 0 : 2 * steps; - const int64_t readout_projection_small_cublas = - (backend_uses_cublas(readout_projection_backend) ? 1 : 0) * steps; - const int64_t total_launches = - message_launches + input_projection_launches + state_affine_launches + state_epilogue_launches + - public_projection_launches + readout_projection_launches; - const int64_t total_small_cublas = - message_small_cublas + input_projection_small_cublas + - (receiver_affine_superop_plan.active ? 
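A short arithmetic restatement of the state-epilogue launch accounting in this block may help: the separate path pays one update launch per receiver chunk plus the reduce/emit launches, all per time step, while the fused epilogue collapses to one launch per step, and the saved count is the non-negative difference. The numbers below are illustrative only.

```cpp
// Illustrative sketch of separate vs. fused state-epilogue launch counts.
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t steps = 16;
  const int64_t receiver_chunks = 4;
  const int64_t reduce_emit_launches = 2;  // reduce stats + emit raw_public

  const int64_t separate_launches = (receiver_chunks + reduce_emit_launches) * steps;
  const int64_t fused_launches = steps;  // fused update+emit: one launch per step
  const int64_t saved_launches = std::max<int64_t>(0, separate_launches - fused_launches);

  std::cout << "separate:" << separate_launches << " fused:" << fused_launches
            << " saved:" << saved_launches << "\n";  // separate:96 fused:16 saved:80
}
```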
0 : state_affine_gemm_launches) + - public_projection_small_cublas + readout_projection_small_cublas; - const std::vector phase_launch_counts{ - "message:" + std::to_string(message_launches), - "input_projection:" + std::to_string(input_projection_launches), - "state_affine:" + std::to_string(state_affine_launches), - "state_epilogue:" + std::to_string(state_epilogue_launches), - "public_projection:" + std::to_string(public_projection_launches), - "readout:" + std::to_string(readout_projection_launches), - "total:" + std::to_string(total_launches), - }; - const std::vector small_cublas_launch_counts{ - "message:" + std::to_string(message_small_cublas), - "input_projection:" + std::to_string(input_projection_small_cublas), - "state_affine:" + std::to_string(receiver_affine_superop_plan.active ? 0 : state_affine_gemm_launches), - "public_projection:" + std::to_string(public_projection_small_cublas), - "readout:" + std::to_string(readout_projection_small_cublas), - "total:" + std::to_string(total_small_cublas), - }; - const std::vector copy_glue_launch_counts{ - "public_projection:" + std::to_string(public_copy_glue_launches), - "total:" + std::to_string(public_copy_glue_launches), - }; - const std::vector copy_glue_saved_launch_counts{ - "public_projection:" + std::to_string(public_copy_glue_saved_launches), - "total:" + std::to_string(public_copy_glue_saved_launches), - }; - const std::vector bias_glue_launch_counts{ - "public_projection:" + std::to_string(public_bias_glue_launches), - "total:" + std::to_string(public_bias_glue_launches), - }; - const std::vector bias_glue_saved_launch_counts{ - "public_projection:" + std::to_string(public_bias_glue_saved_launches), - "total:" + std::to_string(public_bias_glue_saved_launches), - }; - const std::vector state_epilogue_saved_launch_counts{ - "state_epilogue:" + std::to_string(state_epilogue_saved_launches), - "total:" + std::to_string(state_epilogue_saved_launches), - }; - const std::vector physical_op_launch_counts{ - "message:" + std::to_string(message_launches), - "input_projection:" + std::to_string(input_projection_launches), - "state_affine:" + std::to_string(state_affine_launches), - "state_epilogue:" + std::to_string(state_epilogue_launches), - "public_projection:" + std::to_string(public_projection_launches), - "readout:" + std::to_string(readout_projection_launches), - "total:" + std::to_string(total_launches), - }; - const int64_t total_saved_launches = - message_saved_launches + public_copy_glue_saved_launches + public_bias_glue_saved_launches + - state_epilogue_saved_launches + state_affine_saved_launches; - const std::vector physical_op_saved_launch_counts{ - "message:" + std::to_string(message_saved_launches), - "receiver_affine:" + std::to_string(state_affine_saved_launches), - "copy_glue:" + std::to_string(public_copy_glue_saved_launches), - "bias_glue:" + std::to_string(public_bias_glue_saved_launches), - "state_epilogue:" + std::to_string(state_epilogue_saved_launches), - "total:" + std::to_string(total_saved_launches), - }; - const fabric::cuda::nn::PhysicalExecutionPlan physical_plan = make_runtime_physical_execution_plan( - dense_message_execution, - dense_message_execution_recorded, - tiny_message_superop_plan, - !state_affine_specs.empty(), - receiver_affine_superop_plan, - diagonal_recurrence_superop_plan, - public_copy_glue_launches, - public_bias_glue_launches); - const std::string state_epilogue_mode = diagonal_recurrence_active - ? "diagonal_recurrence_superop" - : fused_state_epilogue ? 
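The per-phase counters above all follow one reporting convention, sketched here in isolation: each counter becomes a `"phase:count"` string and the list always ends with a `"total"` row, so the Python side can consume launch, small-cuBLAS, and glue counters with a single parser. The phase names and counts in this sketch are illustrative, not measurements.

```cpp
// Minimal sketch of the "phase:count" + trailing "total" counter convention.
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

std::vector<std::string> phase_counter_rows(
    const std::vector<std::pair<std::string, int64_t>>& per_phase) {
  std::vector<std::string> rows;
  int64_t total = 0;
  for (const auto& entry : per_phase) {
    rows.push_back(entry.first + ":" + std::to_string(entry.second));
    total += entry.second;
  }
  rows.push_back("total:" + std::to_string(total));
  return rows;
}

int main() {
  for (const std::string& row : phase_counter_rows(
           {{"message", 16}, {"state_affine", 64}, {"state_epilogue", 16}})) {
    std::cout << row << "\n";
  }
}
```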
std::string(state_epilogue_policy_name(state_epilogue_policy)) - : single_chunk_reduction_update_emit - ? "single_chunk_reduction_update_emit" - : std::string(state_epilogue_policy_name(state_epilogue_policy)); - const std::vector launch_coalescing_modes{ - std::string("message:") + - (tiny_message_superop_plan.active ? "tiny_message_direct_projected" - : dense_message_execution_recorded ? "signature_bucketed" - : "not_dense_lowered"), - std::string("state_affine:") + - (receiver_affine_superop_plan.active - ? "receiver_affine_superop_v1" - : state_affine_specs.size() > 1 ? "no_identical_source_bucket" : "single_or_none"), - std::string("state_epilogue:") + - state_epilogue_mode, - std::string("public_projection:") + public_projection.coalescing_mode, - "readout:single_bucket", - }; - const std::vector generic_glue_fusion_modes{ - std::string("public_hidden_copy:") + public_projection.hidden_copy_mode, - std::string("public_kv_split:") + public_projection.kv_split_mode, - std::string("public_bias:") + - (public_bias_glue_launches == 0 ? "fused_or_absent" : "residual_standalone_bias"), - std::string("state_epilogue:") + - state_epilogue_mode, - std::string("segmented_softmax:") + - (tiny_message_superop_plan.active ? "fused_tiny_message" : "explicit_glue"), - "routing:explicit_glue", - }; - (*metadata)["phase_launch_counts"] = string_tuple(phase_launch_counts); - (*metadata)["small_cublas_launch_counts"] = string_tuple(small_cublas_launch_counts); - (*metadata)["copy_glue_launch_counts"] = string_tuple(copy_glue_launch_counts); - (*metadata)["copy_glue_saved_launch_counts"] = string_tuple(copy_glue_saved_launch_counts); - (*metadata)["bias_glue_launch_counts"] = string_tuple(bias_glue_launch_counts); - (*metadata)["bias_glue_saved_launch_counts"] = string_tuple(bias_glue_saved_launch_counts); - (*metadata)["state_epilogue_modes"] = py::make_tuple(state_epilogue_mode); - (*metadata)["state_epilogue_saved_launch_counts"] = string_tuple(state_epilogue_saved_launch_counts); - (*metadata)["launch_coalescing_modes"] = string_tuple(launch_coalescing_modes); - (*metadata)["generic_glue_fusion_modes"] = string_tuple(generic_glue_fusion_modes); - (*metadata)["launch_granularity_modes"] = - py::make_tuple("phase_accounted", "copy_glue_fused", "coalescing_metadata_visible"); - (*metadata)["receiver_affine_superop_surface_count"] = - py::make_tuple(receiver_affine_superop_plan.active ? "1" : "0"); - (*metadata)["receiver_affine_superop_receivers"] = py::make_tuple( - state_affine_specs.empty() ? "0" : std::to_string(state_affine_specs.front().op.signature.M)); - (*metadata)["receiver_affine_superop_k"] = - state_affine_specs.empty() - ? py::make_tuple("none") - : [&state_affine_specs]() { - std::vector values; - values.reserve(state_affine_specs.size()); - for (const StateAffineDeclaration& spec : state_affine_specs) { - values.push_back(std::to_string(spec.op.signature.K)); - } - return string_tuple(values); - }(); - (*metadata)["receiver_affine_superop_n"] = py::make_tuple( - state_affine_specs.empty() ? "0" : std::to_string(state_affine_specs.front().op.signature.N)); - (*metadata)["receiver_affine_superop_source_layout"] = py::make_tuple( - state_affine_specs.empty() ? "none" : "backend_native_receiver_major"); - (*metadata)["receiver_affine_superop_reset_policy"] = - state_affine_specs.empty() - ? 
py::make_tuple("none") - : [&state_affine_specs]() { - std::vector values; - values.reserve(state_affine_specs.size()); - for (const StateAffineDeclaration& spec : state_affine_specs) { - values.push_back(reset_policy_name(spec.op.signature.reset_policy)); - } - return string_tuple(values); - }(); - (*metadata)["receiver_affine_superop_executor"] = py::make_tuple( - state_affine_specs.empty() - ? "none" - : receiver_affine_superop_plan.active ? "physical_superop_executor" : "lowered_phase_executor"); - (*metadata)["receiver_affine_superop_physical_mode"] = py::make_tuple( - state_affine_specs.empty() - ? "none" - : receiver_affine_superop_physical_mode(receiver_affine_superop_plan)); - (*metadata)["receiver_affine_superop_demotion_reason"] = - py::make_tuple(receiver_affine_superop_plan.demotion_reason); - (*metadata)["diagonal_recurrence_superop_surface_count"] = - py::make_tuple(diagonal_recurrence_active ? "1" : "0"); - (*metadata)["diagonal_recurrence_kind"] = py::make_tuple( - diagonal_recurrence_active ? diagonal_recurrence_execution->recurrence_kind - : diagonal_recurrence_superop_plan.active - ? diagonal_recurrence_kind_name(diagonal_recurrence_superop_plan.declaration.kind) - : "none"); - (*metadata)["diagonal_recurrence_executor"] = py::make_tuple( - diagonal_recurrence_active ? diagonal_recurrence_execution->executor : "lowered_phase_executor"); - (*metadata)["diagonal_recurrence_physical_mode"] = py::make_tuple( - diagonal_recurrence_active ? diagonal_recurrence_execution->physical_mode : "none"); - (*metadata)["diagonal_recurrence_coeff_cache_mode"] = py::make_tuple( - diagonal_recurrence_active ? diagonal_recurrence_execution->coeff_cache_mode : "none"); - (*metadata)["diagonal_recurrence_coeff_cache_hit"] = py::make_tuple( - diagonal_recurrence_active ? diagonal_recurrence_execution->coeff_cache_hit : "false"); - (*metadata)["diagonal_recurrence_coeff_cache_bytes"] = py::make_tuple( - diagonal_recurrence_active ? std::to_string(diagonal_recurrence_execution->workspace_peak_bytes) : "0"); - (*metadata)["diagonal_recurrence_coeff_cache_version_source"] = py::make_tuple( - diagonal_recurrence_active ? diagonal_recurrence_execution->coeff_cache_version_source : "none"); - (*metadata)["diagonal_recurrence_reset_policy"] = py::make_tuple( - diagonal_recurrence_active ? diagonal_recurrence_execution->reset_policy : "none"); - (*metadata)["diagonal_recurrence_reset_scope"] = py::make_tuple( - diagonal_recurrence_active ? diagonal_recurrence_execution->reset_scope : "none"); - (*metadata)["diagonal_recurrence_output_boundary"] = py::make_tuple( - diagonal_recurrence_active ? diagonal_recurrence_execution->output_boundary : "none"); - (*metadata)["diagonal_recurrence_workspace_mode"] = py::make_tuple( - diagonal_recurrence_active ? diagonal_recurrence_execution->workspace_mode : "none"); - (*metadata)["diagonal_recurrence_workspace_peak_bytes"] = py::make_tuple( - diagonal_recurrence_active ? std::to_string(diagonal_recurrence_execution->workspace_peak_bytes) : "0"); - (*metadata)["diagonal_recurrence_demotion_reason"] = - py::make_tuple(diagonal_recurrence_superop_plan.demotion_reason); - (*metadata)["diagonal_recurrence_launch_count"] = - py::make_tuple(std::to_string(diagonal_recurrence_active ? 
diagonal_recurrence_launch_count : 0)); - set_physical_op_plan_metadata( - metadata, - physical_plan, - physical_op_launch_counts, - physical_op_saved_launch_counts, - copy_glue_launch_counts, - bias_glue_launch_counts); -} - -const char* state_affine_workspace_mode( - const std::vector& specs, - const DenseStateAffineWorkspace& workspace) { - if (specs.empty()) { - return "none"; - } - return workspace.receiver_chunks > 1 ? "combined_reset_aware_receiver_windowed" : "combined_reset_aware"; -} - -const char* state_affine_reset_mode(const std::vector& specs) { - for (const StateAffineDeclaration& spec : specs) { - if (spec.op.signature.reset_policy == fabric::cuda::nn::ResetPolicy::ZeroSourceRows) { - return "row_mask_pack"; - } - } - return "none"; -} - -void set_state_affine_metadata( - py::dict* metadata, - const std::vector& specs, - const std::vector& backends, - const std::vector& source_kinds, - const std::vector& bucket_signatures, - const std::vector& output_modes, - const std::vector& reset_policies, - const DenseStateAffineWorkspace& state_affine_workspace, - bool packed_source_reused) { - (*metadata)["state_affine_backends"] = string_tuple(backends); - (*metadata)["state_affine_sources"] = string_tuple(source_kinds); - (*metadata)["state_affine_bucket_signatures"] = string_tuple(bucket_signatures); - (*metadata)["state_affine_output_modes"] = string_tuple(output_modes); - (*metadata)["state_affine_reset_policies"] = string_tuple(reset_policies); - (*metadata)["state_affine_reset_mode"] = py::make_tuple(state_affine_reset_mode(specs)); - (*metadata)["state_affine_reset_scope"] = py::make_tuple(specs.empty() ? "none" : "batch_row"); - (*metadata)["state_affine_workspace_mode"] = py::make_tuple(state_affine_workspace_mode(specs, state_affine_workspace)); - (*metadata)["state_affine_receiver_chunk_size"] = - py::make_tuple(std::to_string(state_affine_workspace.receiver_chunk_size)); - (*metadata)["state_affine_receiver_chunks"] = - py::make_tuple(std::to_string(state_affine_workspace.receiver_chunks)); - (*metadata)["state_affine_workspace_buffers"] = string_tuple(state_affine_workspace.buffers); - (*metadata)["state_affine_workspace_buffer_bytes"] = string_tuple(state_affine_workspace.buffer_bytes); - (*metadata)["state_affine_workspace_bytes"] = py::make_tuple(std::to_string(state_affine_workspace.bytes)); - (*metadata)["state_affine_packed_source_reused"] = py::make_tuple(packed_source_reused ? 
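The workspace/reset label selection used by `state_affine_workspace_mode` and `state_affine_reset_mode` above can be restated over plain stand-in types; only the selection logic is mirrored here, and the enum and struct names are illustrative, not the real `fabric::cuda::nn` declarations.

```cpp
// Sketch: workspace-mode and reset-mode label selection for state affines.
#include <iostream>
#include <string>
#include <vector>

enum class ResetPolicy { None, ZeroSourceRows };  // hypothetical stand-in

struct AffineSpec { ResetPolicy reset_policy; };

const char* workspace_mode(const std::vector<AffineSpec>& specs, int receiver_chunks) {
  if (specs.empty()) return "none";
  // Receiver windowing only changes the label once the workspace is actually
  // split into more than one receiver chunk.
  return receiver_chunks > 1 ? "combined_reset_aware_receiver_windowed"
                             : "combined_reset_aware";
}

const char* reset_mode(const std::vector<AffineSpec>& specs) {
  for (const AffineSpec& spec : specs) {
    if (spec.reset_policy == ResetPolicy::ZeroSourceRows) return "row_mask_pack";
  }
  return "none";
}

int main() {
  const std::vector<AffineSpec> specs = {{ResetPolicy::None}, {ResetPolicy::ZeroSourceRows}};
  std::cout << workspace_mode(specs, /*receiver_chunks=*/2) << "\n";  // windowed label
  std::cout << reset_mode(specs) << "\n";                             // row_mask_pack
}
```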
"true" : "false"); -} - -const char* input_projection_note(const char* backend) { - const std::string name(backend); - if (name == "fused_into_tiny_message") { - return "tiny_message_direct_projected_boundary"; - } - if (name == "copy_fused") { - return "identity_projection_fused_copy"; - } - if (name == "copy") { - return "identity_projection_legacy_copy"; - } - return "none"; -} - -const char* message_projection_bucket_kind(int message_backend_id, int spatial_id) { - if (spatial_id == static_cast(SpatialOwnership::EdgeOwned)) { - return "sparse_projected_message_boundary"; - } - if (message_backend_id == 1) { - return "sparse_projected_message_boundary"; - } - return "regular_local_projected_message_boundary"; -} - -fabric::cuda::nn::LoweredMessageBucket lower_regular_local_message_bucket( - int64_t B, - int64_t receivers, - int64_t degree_or_block, - int64_t K, - int64_t V) { - return fabric::cuda::nn::lower_single_message_op(fabric::cuda::nn::make_message_op( - fabric::cuda::nn::regular_local_receiver_owned_message_signature(B * receivers, degree_or_block, K, V), - /*op_index=*/0, - /*same_sized_window=*/true, - /*allow_grouped=*/false)); -} - -fabric::cuda::nn::LoweredMessageBucket lower_sparse_message_bucket( - fabric::cuda::nn::MessageTopologyKind topology_kind, - int64_t B, - int64_t receivers, - int64_t degree_or_block, - int64_t K, - int64_t V) { - return fabric::cuda::nn::lower_single_message_op(fabric::cuda::nn::make_message_op( - fabric::cuda::nn::degree_bucketed_sparse_message_signature(topology_kind, B * receivers, degree_or_block, K, V), - /*op_index=*/0, - /*same_sized_window=*/true, - /*allow_grouped=*/false)); -} - -fabric::cuda::nn::LoweredMessageBucket lower_ragged_sparse_message_bucket( - fabric::cuda::nn::MessageTopologyKind topology_kind, - int64_t B, - int64_t receivers, - int64_t max_degree, - int64_t K, - int64_t V) { - return fabric::cuda::nn::lower_single_message_op( - fabric::cuda::nn::make_message_op( - fabric::cuda::nn::ragged_grouped_sparse_message_signature(topology_kind, B * receivers, max_degree, K, V), - /*op_index=*/0, - /*same_sized_window=*/false, - /*allow_grouped=*/true), - /*allow_grouped=*/true); -} - -const char* launch_dense_readout( - TensorTable input_ports, - TensorTable public_now, - ReadoutSpec readout, - const at::Tensor& readout_output_seq, - const py::tuple& readout_param_tensors, - const at::Tensor& readout_message, - ExecutionPlan plan, - int t, - int output_t, - int recurrent_receiver_offset, - cudaStream_t stream) { - RECORD_FUNCTION("fabric.physical.readout", std::vector()); - if (!readout.enabled || plan.readout_mode == ReadoutMode::Skip || plan.output_ports <= 0) { - return "skip"; - } - const at::Tensor value_to_output_weight = tuple_tensor(readout_param_tensors, 1); - const at::Tensor output_bias = tuple_tensor(readout_param_tensors, 2); - const at::Tensor output_query = tuple_tensor(readout_param_tensors, 0); - const int64_t head_dim = output_query.dim() >= 2 ? output_query.size(1) : output_query.numel(); - const int64_t value_dim = value_to_output_weight.size(1); - launch_readout_message_cuda( - input_ports, - public_now, - readout, - readout_message, - plan, - static_cast(head_dim), - static_cast(value_dim), - t, - recurrent_receiver_offset, - stream); - at::Tensor output_t_tensor = readout_output_seq.select(/*dim=*/1, /*index=*/output_t); - const bool mean_pooled_output = output_t_tensor.size(1) == 1 && readout_message.size(1) > 1; - at::Tensor readout_projection_output = - mean_pooled_output - ? 
at::empty( - {readout_message.size(0), readout_message.size(1), output_t_tensor.size(2)}, - readout_message.options()) - : output_t_tensor; - const auto backend = fabric::cuda::ops::dense_affine_out_cuda( - readout_message, - value_to_output_weight, - output_bias, - readout_projection_output, - fabric::cuda::ops::DenseAffineLayout::ReceiverMajor, - /*group_size=*/1, - fabric::cuda::ops::DenseAffineOutputMode::Overwrite); - if (mean_pooled_output) { - output_t_tensor.copy_(readout_projection_output.mean(/*dim=*/1, /*keepdim=*/true)); - } - return fabric::cuda::ops::dense_affine_backend_name(backend); -} - -bool should_materialize_readout_public_window( - bool output_dependency_receiver_window, - const ReadoutSpec& readout, - const at::Tensor& output_q, - const at::Tensor& value_to_output_weight, - int64_t active_receiver_count, - int64_t raw_public_dim) { - if (!output_dependency_receiver_window || active_receiver_count <= 0 || raw_public_dim <= 0 || !readout.enabled) { - return false; - } - if (!output_q.defined() || !value_to_output_weight.defined()) { - return false; - } - const int64_t readout_edges = static_cast(readout.topology.num_edges); - const int64_t head_dim = output_q.dim() >= 2 ? output_q.size(1) : output_q.numel(); - const int64_t value_dim = value_to_output_weight.dim() >= 2 ? value_to_output_weight.size(1) : value_to_output_weight.numel(); - if (head_dim <= 0 || value_dim <= 0 || readout_edges <= 0) { - return false; - } - const int64_t raw_reproject_work = readout_edges * (2 * head_dim + value_dim) * raw_public_dim; - const int64_t materialized_window_work = active_receiver_count * (head_dim + value_dim) * raw_public_dim; - return raw_reproject_work > materialized_window_work; -} - -PublicProjectionExecution launch_materialized_readout_public_window( - int64_t batch, - int64_t active_receiver_count, - int public_projection_kind, - const at::Tensor& raw_public, - const py::tuple& public_projection_tensors, - const py::tuple& public_template_tensors, - const py::tuple& readout_param_tensors, - const at::Tensor& like, - int64_t active_receiver_offset, - RuntimeTensorTable* public_pack, - std::vector* public_keepalive, - std::vector* workspace_aliases, - std::vector* workspace_buffers, - std::vector* workspace_buffer_bytes, - int64_t* workspace_peak_bytes) { - const at::Tensor readout_query = tuple_tensor(readout_param_tensors, 0); - const at::Tensor value_to_output_weight = tuple_tensor(readout_param_tensors, 1); - const at::Tensor public_hidden_template = tuple_tensor(public_template_tensors, 0); - const int64_t public_hidden_dim = public_hidden_template.size(2); - const int64_t readout_head_dim = readout_query.dim() >= 2 ? readout_query.size(1) : readout_query.numel(); - const int64_t readout_value_dim = - value_to_output_weight.dim() >= 2 ? 
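The materialization heuristic in `should_materialize_readout_public_window` above reduces to a single work comparison: re-projecting `raw_public` per readout edge costs roughly `edges * (2*head_dim + value_dim) * raw_public_dim` multiply-accumulates, while materializing the active receiver window once costs `receivers * (head_dim + value_dim) * raw_public_dim`, and the window is materialized only when the per-edge re-projection would do more work. Below is a standalone restatement; the numbers in `main()` are illustrative.

```cpp
// Sketch of the readout public-window materialization cost heuristic.
#include <cstdint>
#include <iostream>

bool materialize_window(int64_t readout_edges, int64_t active_receivers,
                        int64_t head_dim, int64_t value_dim, int64_t raw_public_dim) {
  if (readout_edges <= 0 || active_receivers <= 0 || head_dim <= 0 ||
      value_dim <= 0 || raw_public_dim <= 0) {
    return false;  // fail closed on degenerate shapes, as in the guarded version
  }
  const int64_t raw_reproject_work =
      readout_edges * (2 * head_dim + value_dim) * raw_public_dim;
  const int64_t materialized_window_work =
      active_receivers * (head_dim + value_dim) * raw_public_dim;
  return raw_reproject_work > materialized_window_work;
}

int main() {
  // 4096 readout edges over 256 active receivers: materializing the window wins.
  std::cout << std::boolalpha << materialize_window(4096, 256, 64, 64, 128) << "\n";
}
```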
value_to_output_weight.size(1) : value_to_output_weight.numel(); - public_keepalive->clear(); - public_keepalive->push_back(at::empty({batch, active_receiver_count, public_hidden_dim}, like.options())); - public_keepalive->push_back(at::empty({batch, active_receiver_count, readout_head_dim}, like.options())); - public_keepalive->push_back(at::empty({batch, active_receiver_count, readout_value_dim}, like.options())); - py::tuple readout_public_tensors = - py::make_tuple((*public_keepalive)[0], (*public_keepalive)[1], (*public_keepalive)[2]); - PublicProjectionExecution public_projection = - launch_dense_public_projection( - public_projection_kind, - raw_public, - public_projection_tensors, - readout_public_tensors, - active_receiver_offset); - *public_pack = pack_runtime_tensor_table(*public_keepalive, like); - workspace_aliases->push_back("readout_public_projection=materialized_active_window"); - append_workspace_metadata( - "readout_public_hidden", - (*public_keepalive)[0], - "raw_public->readout_materialized_public_projection", - "unique", - workspace_buffers, - workspace_buffer_bytes, - workspace_peak_bytes); - append_workspace_metadata( - "readout_public_k", - (*public_keepalive)[1], - "raw_public->readout_materialized_public_projection", - "unique", - workspace_buffers, - workspace_buffer_bytes, - workspace_peak_bytes); - append_workspace_metadata( - "readout_public_v", - (*public_keepalive)[2], - "raw_public->readout_materialized_public_projection", - "unique", - workspace_buffers, - workspace_buffer_bytes, - workspace_peak_bytes); - return public_projection; -} - -const char* launch_dense_readout_from_raw_public( - TensorTable input_ports, - const at::Tensor& raw_public, - int public_projection_kind, - TensorTable public_projection_params, - ReadoutSpec readout, - const at::Tensor& readout_output_seq, - const py::tuple& readout_param_tensors, - const at::Tensor& readout_message, - ExecutionPlan plan, - int t, - int output_t, - int recurrent_receiver_offset, - cudaStream_t stream) { - RECORD_FUNCTION("fabric.physical.readout", std::vector()); - if (!readout.enabled || plan.readout_mode == ReadoutMode::Skip || plan.output_ports <= 0) { - return "skip"; - } - const at::Tensor value_to_output_weight = tuple_tensor(readout_param_tensors, 1); - const at::Tensor output_bias = tuple_tensor(readout_param_tensors, 2); - const int64_t value_dim = value_to_output_weight.size(1); - launch_readout_message_from_raw_public_cuda( - input_ports, - raw_public, - public_projection_kind, - public_projection_params, - readout, - readout_message, - plan, - static_cast(value_dim), - t, - recurrent_receiver_offset, - stream); - at::Tensor output_t_tensor = readout_output_seq.select(/*dim=*/1, /*index=*/output_t); - const bool mean_pooled_output = output_t_tensor.size(1) == 1 && readout_message.size(1) > 1; - at::Tensor readout_projection_output = - mean_pooled_output - ? 
at::empty( - {readout_message.size(0), readout_message.size(1), output_t_tensor.size(2)}, - readout_message.options()) - : output_t_tensor; - const auto backend = fabric::cuda::ops::dense_affine_out_cuda( - readout_message, - value_to_output_weight, - output_bias, - readout_projection_output, - fabric::cuda::ops::DenseAffineLayout::ReceiverMajor, - /*group_size=*/1, - fabric::cuda::ops::DenseAffineOutputMode::Overwrite); - if (mean_pooled_output) { - output_t_tensor.copy_(readout_projection_output.mean(/*dim=*/1, /*keepdim=*/true)); - } - return fabric::cuda::ops::dense_affine_backend_name(backend); -} - -ExecutionPlan apply_phase_staging_limits( - ExecutionPlan plan, - int state_static_bytes, - int emit_static_bytes) { - const auto* props = at::cuda::getCurrentDeviceProperties(); - const size_t optin_limit = static_cast(props->sharedMemPerBlockOptin); - const size_t state_static_smem = - static_cast(plan.state_receiver_tile) * static_cast(state_static_bytes); - const size_t emit_static_smem = - static_cast(plan.emit_receiver_tile) * static_cast(emit_static_bytes); - if (!plan.stage_receiver_static || state_static_bytes <= 0 || state_static_smem > optin_limit) { - plan.state_static_stage_mode = CellStaticStageMode::Disabled; - } - if (!plan.stage_receiver_static || emit_static_bytes <= 0 || emit_static_smem > optin_limit) { - plan.emit_static_stage_mode = CellStaticStageMode::Disabled; - } - if ( - plan.state_static_stage_mode == CellStaticStageMode::Disabled && - plan.emit_static_stage_mode == CellStaticStageMode::Disabled) { - plan.cell_static_stage_mode = CellStaticStageMode::Disabled; - } - return plan; -} - -ExecutionPlan apply_state_epilogue_policy( - ExecutionPlan plan, - fabric::cuda::nn::StateEpiloguePolicy policy) { - if (policy != fabric::cuda::nn::StateEpiloguePolicy::FusedNoReductionSameChunk) { - return plan; - } - plan.emit_receiver_tile = plan.state_receiver_tile; - plan.emit_batch_tile = plan.state_batch_tile; - plan.emit_hidden_chunk = plan.state_hidden_chunk; - return plan; -} - -} // namespace - -void launch_receiver_state_emit_cuda( - int cell_core_id, - const float* projected_message, - TensorTable state_prev, - TensorTable state_next, - TensorTable cell_params, - TensorTable aux, - ExecutionPlan plan, - const at::Tensor& resets_u8, - int t, - int projected_message_dim, - int raw_public_dim, - int state_static_bytes, - int emit_static_bytes, - int state_epilogue_policy, - float* raw_public, - float* partial_stats, - float* reduced_stats, - int num_hidden_chunks, - cudaStream_t stream) { - RECORD_FUNCTION("fabric.physical.state_epilogue", std::vector()); - const auto& entry = lookup_cell_core_dispatch_entry(cell_core_id); - entry.receiver_state_emit( - projected_message, - state_prev, - state_next, - cell_params, - aux, - plan, - resets_u8, - t, - projected_message_dim, - raw_public_dim, - state_static_bytes, - emit_static_bytes, - state_epilogue_policy, - raw_public, - partial_stats, - reduced_stats, - num_hidden_chunks, - stream); -} - -void launch_receiver_state_update_cuda( - int cell_core_id, - const float* projected_message, - TensorTable state_prev, - TensorTable state_next, - TensorTable cell_params, - TensorTable aux, - ExecutionPlan plan, - const at::Tensor& resets_u8, - bool state_prev_is_zero, - int t, - int projected_message_dim, - int raw_public_dim, - int state_static_bytes, - float* partial_stats, - int num_hidden_chunks, - int receiver_offset, - int receiver_global_offset, - int receiver_count, - cudaStream_t stream) { - 
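The staging-limit logic in `apply_phase_staging_limits` above is easiest to see with the CUDA specifics stripped away: staged receiver-static bytes are `tile_rows * static_bytes_per_row` and must fit under the device's opt-in shared-memory-per-block limit, otherwise that stage mode is disabled, and if both the state and emit stages fall back, the combined cell staging mode is disabled as well. The limit below is a hard-coded illustration rather than a query of device properties, and the type names are stand-ins.

```cpp
// Minimal model of the shared-memory opt-in staging limit check.
#include <cstddef>
#include <iostream>

enum class StageMode { Enabled, Disabled };  // stand-in for the stage-mode enum

struct StagingPlan {
  int state_receiver_tile, emit_receiver_tile;
  StageMode state_stage, emit_stage, cell_stage;
};

StagingPlan apply_staging_limits(StagingPlan plan, int state_static_bytes,
                                 int emit_static_bytes, std::size_t optin_limit) {
  const std::size_t state_smem = static_cast<std::size_t>(plan.state_receiver_tile) *
                                 static_cast<std::size_t>(state_static_bytes);
  const std::size_t emit_smem = static_cast<std::size_t>(plan.emit_receiver_tile) *
                                static_cast<std::size_t>(emit_static_bytes);
  if (state_static_bytes <= 0 || state_smem > optin_limit) plan.state_stage = StageMode::Disabled;
  if (emit_static_bytes <= 0 || emit_smem > optin_limit) plan.emit_stage = StageMode::Disabled;
  if (plan.state_stage == StageMode::Disabled && plan.emit_stage == StageMode::Disabled) {
    plan.cell_stage = StageMode::Disabled;
  }
  return plan;
}

int main() {
  StagingPlan plan{128, 128, StageMode::Enabled, StageMode::Enabled, StageMode::Enabled};
  // With an assumed 96 KiB opt-in limit, 128 receivers * 1 KiB of static bytes
  // overflows it, so staging falls back for both phases.
  plan = apply_staging_limits(plan, 1024, 1024, 96 * 1024);
  std::cout << (plan.cell_stage == StageMode::Disabled ? "staging disabled" : "staging kept") << "\n";
}
```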
RECORD_FUNCTION("fabric.physical.state_epilogue", std::vector()); - const auto& entry = lookup_cell_core_dispatch_entry(cell_core_id); - entry.receiver_state_update( - projected_message, - state_prev, - state_next, - cell_params, - aux, - plan, - resets_u8, - state_prev_is_zero, - t, - projected_message_dim, - raw_public_dim, - state_static_bytes, - partial_stats, - num_hidden_chunks, - receiver_offset, - receiver_global_offset, - receiver_count, - stream); -} - -void launch_receiver_state_update_emit_cuda( - int cell_core_id, - const float* projected_message, - TensorTable state_prev, - TensorTable state_next, - TensorTable cell_params, - TensorTable aux, - ExecutionPlan plan, - const at::Tensor& resets_u8, - bool state_prev_is_zero, - bool materialize_state_output, - int t, - int projected_message_dim, - int raw_public_dim, - int state_static_bytes, - int emit_static_bytes, - float* raw_public, - int num_hidden_chunks, - int receiver_offset, - int receiver_global_offset, - int receiver_count, - cudaStream_t stream) { - RECORD_FUNCTION("fabric.physical.state_epilogue", std::vector()); - const auto& entry = lookup_cell_core_dispatch_entry(cell_core_id); - entry.receiver_state_update_emit( - projected_message, - state_prev, - state_next, - cell_params, - aux, - plan, - resets_u8, - state_prev_is_zero, - materialize_state_output, - t, - projected_message_dim, - raw_public_dim, - state_static_bytes, - emit_static_bytes, - raw_public, - num_hidden_chunks, - receiver_offset, - receiver_global_offset, - receiver_count, - stream); -} - -void launch_receiver_reduce_emit_cuda( - int cell_core_id, - TensorTable state_next, - TensorTable cell_params, - ExecutionPlan plan, - int projected_message_dim, - int raw_public_dim, - int emit_static_bytes, - float* raw_public, - float* partial_stats, - float* reduced_stats, - int num_hidden_chunks, - int receiver_global_offset, - cudaStream_t stream) { - RECORD_FUNCTION("fabric.physical.state_epilogue", std::vector()); - const auto& entry = lookup_cell_core_dispatch_entry(cell_core_id); - entry.receiver_reduce_emit( - state_next, - cell_params, - plan, - projected_message_dim, - raw_public_dim, - emit_static_bytes, - raw_public, - partial_stats, - reduced_stats, - num_hidden_chunks, - receiver_global_offset, - stream); -} - -void launch_receiver_state_transition_cuda( - int cell_core_id, - const at::Tensor& projected_message_tensor, - TensorTable state_prev, - TensorTable state_next, - TensorTable cell_params, - TensorTable cell_aux, - TensorTable state_affine_aux, - const py::tuple& state_prev_tensors, - const std::vector& cell_param_vector, - const std::vector& state_affine_specs, - const DenseStateAffineWorkspace& state_affine_workspace, - const ReceiverAffineSuperOpPlan& receiver_affine_superop_plan, - const DiagonalRecurrenceSuperOpPlan& diagonal_recurrence_superop_plan, - fabric::cuda::nn::StateEpiloguePolicy state_epilogue_policy, - ExecutionPlan plan, - const at::Tensor& resets_u8, - int t, - int projected_message_dim, - int raw_public_dim, - int state_static_bytes, - int emit_static_bytes, - const at::Tensor& raw_public_tensor, - float* raw_public, - float* partial_stats, - float* reduced_stats, - int num_hidden_chunks, - cudaStream_t stream, - bool state_prev_is_zero, - bool write_state_next, - bool write_trace_state_next, - int64_t full_receivers, - int64_t receiver_global_offset, - std::vector* state_affine_backends, - std::vector* state_affine_sources, - std::vector* state_affine_bucket_signatures, - std::vector* state_affine_output_modes, - 
std::vector* state_affine_reset_policies, - bool* state_affine_packed_source_reused, - fabric::cuda::ops::DiagonalRecurrenceExecution* diagonal_recurrence_execution, - bool* diagonal_recurrence_execution_recorded, - int64_t* diagonal_recurrence_launch_count) { - if (state_affine_specs.empty()) { - if (diagonal_recurrence_superop_plan.active) { - RECORD_FUNCTION("fabric.physical.state_epilogue", std::vector()); - *diagonal_recurrence_execution = - fabric::cuda::ops::diagonal_recurrence_complex_exp_update_emit_window_out_cuda( - diagonal_recurrence_superop_plan.declaration, - state_prev, - state_next, - cell_params, - cell_param_vector, - projected_message_tensor, - raw_public_tensor, - resets_u8, - state_prev_is_zero, - write_state_next, - write_trace_state_next, - diagonal_recurrence_superop_plan.receiver_offset, - t); - *diagonal_recurrence_execution_recorded = true; - *diagonal_recurrence_launch_count += diagonal_recurrence_execution->launch_count; - return; - } - launch_receiver_state_emit_cuda( - cell_core_id, - projected_message_tensor.data_ptr(), - state_prev, - state_next, - cell_params, - cell_aux, - plan, - resets_u8, - t, - projected_message_dim, - raw_public_dim, - state_static_bytes, - emit_static_bytes, - static_cast(state_epilogue_policy), - raw_public, - partial_stats, - reduced_stats, - num_hidden_chunks, - stream); - return; - } - const bool single_chunk_reduction_update_emit = - lookup_cell_core_dispatch_entry(cell_core_id).reduction_stats_dim > 0 && num_hidden_chunks <= 1; - for (int64_t receiver_start = 0; receiver_start < plan.receivers; - receiver_start += state_affine_workspace.receiver_chunk_size) { - const int64_t receiver_count = - std::min(state_affine_workspace.receiver_chunk_size, plan.receivers - receiver_start); - if (receiver_affine_superop_plan.active) { - launch_receiver_affine_superop( - receiver_affine_superop_plan, - state_affine_specs, - projected_message_tensor, - state_prev_tensors, - cell_param_vector, - state_affine_workspace.outputs, - state_affine_workspace.receiver_affine_packed_input, - state_affine_workspace.receiver_affine_packed_weight, - resets_u8, - t, - full_receivers, - receiver_start, - receiver_global_offset, - receiver_count, - state_prev_is_zero, - state_affine_backends, - state_affine_sources, - state_affine_bucket_signatures, - state_affine_output_modes, - state_affine_reset_policies, - state_affine_packed_source_reused); - } else { - launch_dense_state_affines( - state_affine_specs, - projected_message_tensor, - state_prev_tensors, - cell_param_vector, - state_affine_workspace.outputs, - state_affine_workspace.reset_packed_sources, - resets_u8, - t, - full_receivers, - receiver_start, - receiver_global_offset, - receiver_count, - state_prev_is_zero, - state_affine_backends, - state_affine_sources, - state_affine_bucket_signatures, - state_affine_output_modes, - state_affine_reset_policies, - state_affine_packed_source_reused); - } - if (single_chunk_reduction_update_emit) { - launch_receiver_state_update_emit_cuda( - cell_core_id, - projected_message_tensor.data_ptr(), - state_prev, - state_next, - cell_params, - state_affine_aux, - plan, - resets_u8, - state_prev_is_zero, - write_state_next, - t, - projected_message_dim, - raw_public_dim, - state_static_bytes, - emit_static_bytes, - raw_public, - num_hidden_chunks, - static_cast(receiver_start), - static_cast(receiver_global_offset), - static_cast(receiver_count), - stream); - } else { - launch_receiver_state_update_cuda( - cell_core_id, - projected_message_tensor.data_ptr(), - 
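Structurally, the state transition above walks the receiver axis in workspace-sized windows: each window runs the affine stage (super-op or split) and then either a fused update+emit or a plain update, and a separate reduce+emit pass runs once at the end only in the non-fused case. The sketch below keeps only that loop shape; the lambdas are stand-ins for the real kernel launches.

```cpp
// Structural sketch of the receiver-chunked state transition loop.
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t receivers = 1000;
  const int64_t receiver_chunk_size = 256;
  const bool single_chunk_reduction_update_emit = false;

  auto launch_affine = [](int64_t start, int64_t count) {
    std::cout << "affine [" << start << ", " << start + count << ")\n";
  };
  auto launch_update = [](int64_t start, int64_t count) {
    std::cout << "update [" << start << ", " << start + count << ")\n";
  };
  auto launch_reduce_emit = [] { std::cout << "reduce_emit\n"; };

  for (int64_t receiver_start = 0; receiver_start < receivers;
       receiver_start += receiver_chunk_size) {
    const int64_t receiver_count =
        std::min(receiver_chunk_size, receivers - receiver_start);
    launch_affine(receiver_start, receiver_count);
    launch_update(receiver_start, receiver_count);  // or fused update+emit
  }
  if (!single_chunk_reduction_update_emit) {
    launch_reduce_emit();  // one trailing pass in the non-fused case
  }
}
```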
state_prev, - state_next, - cell_params, - state_affine_aux, - plan, - resets_u8, - state_prev_is_zero, - t, - projected_message_dim, - raw_public_dim, - state_static_bytes, - partial_stats, - num_hidden_chunks, - static_cast(receiver_start), - static_cast(receiver_global_offset), - static_cast(receiver_count), - stream); - } - } - if (!single_chunk_reduction_update_emit) { - launch_receiver_reduce_emit_cuda( - cell_core_id, - state_next, - cell_params, - plan, - projected_message_dim, - raw_public_dim, - emit_static_bytes, - raw_public, - partial_stats, - reduced_stats, - num_hidden_chunks, - static_cast(receiver_global_offset), - stream); - } -} - -bool fabric_dispatch_can_virtualize_fresh_state_cuda( - int cell_core_id, - py::tuple cell_param_tensors, - int receivers, - int projected_message_dim, - int raw_public_dim) { - const auto& dispatch_entry = lookup_cell_core_dispatch_entry(cell_core_id); - const std::vector cell_param_vector = tuple_tensors(cell_param_tensors); - const fabric::cuda::nn::CellTransitionIR cell_transition_ir = - dispatch_entry.cell_transition_ir(cell_param_vector, receivers, projected_message_dim, raw_public_dim); - if (!cell_transition_ir.state_affines.empty() || dispatch_entry.reduction_stats_dim != 0) { - return false; - } - const fabric::cuda::nn::LoweredPhaseIR phase_ir = - fabric::cuda::nn::lower_diagonal_recurrences_to_phase_ir(cell_transition_ir.diagonal_recurrences); - if (phase_ir.diagonal_recurrences.size() != 1) { - return false; - } - const auto& declaration = phase_ir.diagonal_recurrences.front(); - if (declaration.kind != fabric::cuda::nn::DiagonalRecurrenceKind::ComplexExponential2D || - declaration.layout != fabric::cuda::nn::TensorLayout::ReceiverMajor || - declaration.input_source_kind != fabric::cuda::nn::SourceKind::ProjectedMessage || - declaration.hidden_dim <= 0 || - declaration.projected_message_dim != projected_message_dim || - declaration.raw_public_dim != raw_public_dim) { - return false; - } - if (!( - (declaration.reset_policy == fabric::cuda::nn::ResetPolicy::None && - declaration.reset_scope == fabric::cuda::nn::ResetScope::None) || - (declaration.reset_policy == fabric::cuda::nn::ResetPolicy::ZeroSourceRows && - declaration.reset_scope == fabric::cuda::nn::ResetScope::BatchRow))) { - return false; - } - const auto& binding = declaration.binding; - const int max_param_index = std::max( - std::max(binding.nu_param_index, binding.theta_param_index), - std::max(std::max(binding.w1_param_index, binding.w2_param_index), binding.activation_param_index)); - if (binding.nu_param_index < 0 || binding.theta_param_index < 0 || binding.w1_param_index < 0 || - binding.w2_param_index < 0 || binding.activation_param_index < 0 || - max_param_index >= static_cast(cell_param_vector.size())) { - return false; - } - for (const int param_index : {binding.nu_param_index, binding.theta_param_index, binding.w1_param_index, binding.w2_param_index}) { - const at::Tensor& param = cell_param_vector.at(static_cast(param_index)); - if (!param.defined() || param.dim() != 2 || param.size(0) != receivers || - param.size(1) != declaration.hidden_dim || param.scalar_type() != at::kFloat) { - return false; - } - } - const at::Tensor& activation = cell_param_vector.at(static_cast(binding.activation_param_index)); - return activation.defined() && activation.scalar_type() == at::kInt; -} - -void record_forward_carry_checkpoint( - const py::tuple& checkpoint_state_tensors, - const py::tuple& checkpoint_public_tensors, - const py::tuple& state_source_tensors, - const 
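The virtualization check above is a fail-closed checklist: every structural requirement (single diagonal recurrence of the supported kind, receiver-major layout, projected-message source, matching dimensions, in-range parameter bindings) is tested explicitly and any miss returns `false`. The sketch below keeps that shape over plain data; the field names are illustrative stand-ins for the lowered declaration, not the real IR types.

```cpp
// Sketch of a fail-closed legality predicate over a lowered declaration.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

struct DiagonalRecurrenceDecl {  // hypothetical stand-in
  bool complex_exponential_2d;
  bool receiver_major;
  bool projected_message_source;
  int64_t hidden_dim;
  int64_t projected_message_dim;
  std::vector<int> param_indices;  // e.g. nu, theta, w1, w2, activation
};

bool can_virtualize_fresh_state(const DiagonalRecurrenceDecl& d,
                                int64_t expected_projected_message_dim,
                                std::size_t num_params) {
  if (!d.complex_exponential_2d || !d.receiver_major || !d.projected_message_source) return false;
  if (d.hidden_dim <= 0 || d.projected_message_dim != expected_projected_message_dim) return false;
  for (const int idx : d.param_indices) {
    if (idx < 0 || static_cast<std::size_t>(idx) >= num_params) return false;
  }
  return true;  // only reached when every requirement held
}

int main() {
  const DiagonalRecurrenceDecl decl{true, true, true, 64, 128, {0, 1, 2, 3, 4}};
  std::cout << std::boolalpha << can_virtualize_fresh_state(decl, 128, 5) << "\n";
}
```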
py::tuple& public_source_tensors, - const py::tuple& state_source_indices, - const py::tuple& public_source_indices, - int checkpoint_stride, - int step_after_transition, - int time_steps) { - if (checkpoint_stride <= 0 || step_after_transition <= 0 || step_after_transition >= time_steps || - step_after_transition % checkpoint_stride != 0) { - return; - } - const int checkpoint_index = step_after_transition / checkpoint_stride - 1; - const ssize_t state_count = py::len(checkpoint_state_tensors); - TORCH_CHECK( - py::len(state_source_tensors) >= state_count, - "Fabric carry checkpoint state source count is smaller than checkpoint output count"); - for (ssize_t i = 0; i < state_count; ++i) { - at::Tensor dst = tuple_tensor(checkpoint_state_tensors, static_cast(i)); - if (!dst.defined() || dst.numel() == 0) { - continue; - } - TORCH_CHECK( - checkpoint_index >= 0 && checkpoint_index < dst.size(0), - "Fabric carry checkpoint index ", - checkpoint_index, - " out of range for state checkpoint tensor with ", - dst.size(0), - " entries"); - const int source_index = tuple_source_index_or_identity(state_source_indices, i); - TORCH_CHECK( - source_index >= 0 && source_index < state_source_tensors.size(), - "Fabric carry checkpoint state source index ", - source_index, - " out of range for ", - state_source_tensors.size(), - " source tensors"); - at::Tensor src = tuple_tensor(state_source_tensors, source_index); - dst.select(/*dim=*/0, checkpoint_index).copy_(src); - } - const ssize_t public_count = py::len(checkpoint_public_tensors); - TORCH_CHECK( - py::len(public_source_tensors) >= public_count, - "Fabric carry checkpoint public source count is smaller than checkpoint output count"); - for (ssize_t i = 0; i < public_count; ++i) { - at::Tensor dst = tuple_tensor(checkpoint_public_tensors, static_cast(i)); - if (!dst.defined() || dst.numel() == 0) { - continue; - } - TORCH_CHECK( - checkpoint_index >= 0 && checkpoint_index < dst.size(0), - "Fabric carry checkpoint index ", - checkpoint_index, - " out of range for public checkpoint tensor with ", - dst.size(0), - " entries"); - const int source_index = tuple_source_index_or_identity(public_source_indices, i); - TORCH_CHECK( - source_index >= 0 && source_index < public_source_tensors.size(), - "Fabric carry checkpoint public source index ", - source_index, - " out of range for ", - public_source_tensors.size(), - " source tensors"); - at::Tensor src = tuple_tensor(public_source_tensors, source_index); - dst.select(/*dim=*/0, checkpoint_index).copy_(src); - } -} - -py::dict fabric_dispatch_forward_cuda( - int cell_core_id, - int message_backend_id, - int spatial_id, - int temporal_id, - int public_projection_kind, - int projected_message_dim, - int raw_public_dim, - py::tuple state_prev_pack, - py::tuple state_prev_tensors, - py::tuple state_next_pack, - py::tuple state_next_tensors, - py::tuple public_prev_pack, - py::tuple public_next_pack, - py::tuple state_work_pack, - py::tuple public_work_pack, - py::tuple cell_params_pack, - py::tuple cell_param_tensors, - py::tuple input_projection_pack, - py::tuple input_projection_tensors, - py::tuple public_projection_pack, - py::tuple public_projection_tensors, - py::tuple message_params_pack, - py::tuple input_ports_pack, - py::tuple aux_pack, - py::tuple readout_output_pack, - py::tuple readout_output_tensors, - py::tuple readout_params_pack, - py::tuple readout_param_tensors, - py::tuple input_port_tensors, - py::tuple public_prev_tensors, - py::tuple public_next_tensors, - py::tuple public_work_tensors, 
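The index math in `record_forward_carry_checkpoint` above is worth restating on its own: a checkpoint is recorded only after interior steps that land on the stride, and the slot is `step / stride - 1` so the first recordable step maps to slot 0. With stride 4 and 16 time steps this records steps 4, 8, and 12 into slots 0, 1, and 2; step 16 is excluded because it is the final carry, not a checkpoint.

```cpp
// Worked restatement of the carry-checkpoint slot computation and guards.
#include <iostream>
#include <optional>

std::optional<int> checkpoint_slot(int checkpoint_stride, int step_after_transition,
                                   int time_steps) {
  if (checkpoint_stride <= 0 || step_after_transition <= 0 ||
      step_after_transition >= time_steps ||
      step_after_transition % checkpoint_stride != 0) {
    return std::nullopt;  // same guard order as the launcher: no slot, no copy
  }
  return step_after_transition / checkpoint_stride - 1;
}

int main() {
  for (int step = 1; step <= 16; ++step) {
    if (const auto slot = checkpoint_slot(/*stride=*/4, step, /*time_steps=*/16)) {
      std::cout << "step " << step << " -> slot " << *slot << "\n";
    }
  }
}
```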
- at::Tensor recurrent_q_raw, - at::Tensor recurrent_local_sender_idx_raw, - at::Tensor recurrent_local_distance_raw, - at::Tensor recurrent_local_delay_raw, - at::Tensor recurrent_neighbor_idx_raw, - at::Tensor recurrent_neighbor_valid_raw, - at::Tensor recurrent_edge_distance_raw, - at::Tensor recurrent_edge_delay_raw, - at::Tensor recurrent_sparse_receiver_order_raw, - at::Tensor recurrent_sparse_degree_ptr_raw, - int recurrent_sparse_positive_degree_buckets, - double distance_scale, - at::Tensor recurrent_receiver_ptr, - at::Tensor recurrent_sender_idx, - at::Tensor recurrent_edge_delay, - at::Tensor recurrent_edge_weight, - at::Tensor readout_receiver_ptr, - at::Tensor readout_sender_idx, - at::Tensor readout_edge_delay, - at::Tensor readout_edge_weight, - at::Tensor resets_u8, - int output_recurrent_window_start, - int output_recurrent_window_count, - bool output_recurrent_window_contiguous, - py::tuple forward_carry_checkpoint_state_tensors, - py::tuple forward_carry_checkpoint_public_tensors, - py::tuple forward_carry_checkpoint_state_source_indices, - py::tuple forward_carry_checkpoint_public_source_indices, - int forward_carry_checkpoint_stride, - int num_input_ports, - int B, - int T, - int receivers, - int edges, - int output_ports, - int message_dim, - int public_dim, - int receiver_tile, - int batch_tile, - int edge_tile, - int hidden_chunk, - int state_receiver_tile, - int state_batch_tile, - int state_hidden_chunk, - int state_static_stage_mode, - int emit_receiver_tile, - int emit_batch_tile, - int emit_hidden_chunk, - int emit_static_stage_mode, - int public_receiver_tile, - int public_batch_tile, - int readout_mode, - int readout_port_tile, - int readout_output_chunk, - int cell_static_stage_mode, - int replication_factor, - bool stage_receiver_static, - bool initial_state_is_fresh, - bool emit_readout, - bool materialize_final_state, - bool preserve_internal_carry, - bool compact_input_carry) { - const auto stream = at::cuda::getCurrentCUDAStream(); - const TensorTable state_prev = unpack_table(state_prev_pack); - const TensorTable state_next = unpack_table(state_next_pack); - const TensorTable public_prev = unpack_table(public_prev_pack); - const TensorTable public_next = unpack_table(public_next_pack); - const TensorTable cell_params = unpack_table(cell_params_pack); - const TensorTable public_projection_params = unpack_table(public_projection_pack); - const TensorTable message_params = unpack_table(message_params_pack); - const TensorTable input_ports = unpack_table(input_ports_pack); - const TensorTable aux = unpack_table(aux_pack); - const TensorTable readout_output = unpack_table(readout_output_pack); - const TensorTable readout_params = unpack_table(readout_params_pack); - const at::Tensor readout_output_seq = tuple_tensor(readout_output_tensors, 0); - TORCH_CHECK(readout_output_seq.dim() == 4, "Fabric readout output must be [B,T,output_ports,N]"); - const int readout_output_steps = static_cast(readout_output_seq.size(1)); - TORCH_CHECK( - readout_output_steps == T || readout_output_steps == 1, - "Fabric readout output time dimension must be either full sequence T or terminal step, got ", - readout_output_steps, - " for T=", - T); - const bool terminal_readout_boundary = readout_output_steps == 1; - auto emit_readout_for_step = [&](int t) -> bool { - return !terminal_readout_boundary || t + 1 == T; - }; - auto readout_output_index = [&](int t) -> int { - return terminal_readout_boundary ? 
0 : t; - }; - const at::Tensor input_k_ref = tuple_tensor(input_port_tensors, 0); - const at::Tensor input_v_ref = tuple_tensor(input_port_tensors, 1); - (void)state_work_pack; - (void)public_work_pack; - (void)input_projection_pack; - (void)public_work_tensors; - (void)recurrent_local_delay_raw; - (void)recurrent_edge_delay_raw; - const ExecutionPlan requested_plan = make_plan( - spatial_id, - temporal_id, - B, - T, - receivers, - edges, - output_ports, - message_dim, - public_dim, - receiver_tile, - batch_tile, - edge_tile, - hidden_chunk, - state_receiver_tile, - state_batch_tile, - state_hidden_chunk, - state_static_stage_mode, - emit_receiver_tile, - emit_batch_tile, - emit_hidden_chunk, - emit_static_stage_mode, - public_receiver_tile, - public_batch_tile, - readout_mode, - readout_port_tile, - readout_output_chunk, - cell_static_stage_mode, - replication_factor, - stage_receiver_static, - emit_readout); - const MessageTopology recurrent_topology = unpack_topology( - recurrent_receiver_ptr, - recurrent_sender_idx, - recurrent_edge_delay, - recurrent_edge_weight, - num_input_ports); - const ReadoutSpec readout{ - unpack_topology(readout_receiver_ptr, readout_sender_idx, readout_edge_delay, readout_edge_weight, num_input_ports), - readout_output, - readout_params, - 0.0f, - false, - emit_readout, - }; - const auto& dispatch_entry = lookup_cell_core_dispatch_entry(cell_core_id); - const std::vector cell_param_vector = tuple_tensors(cell_param_tensors); - const int state_static_bytes = dispatch_entry.state_static_bytes(cell_param_vector, receivers); - const int emit_static_bytes = dispatch_entry.emit_static_bytes(cell_param_vector, receivers); - const fabric::cuda::nn::CellTransitionIR cell_transition_ir = - dispatch_entry.cell_transition_ir(cell_param_vector, receivers, projected_message_dim, raw_public_dim); - const std::vector& state_affine_specs = cell_transition_ir.state_affines; - ExecutionPlan plan = apply_phase_staging_limits(requested_plan, state_static_bytes, emit_static_bytes); - plan = apply_state_epilogue_policy(plan, cell_transition_ir.state_epilogue_policy); - py::dict metadata = launch_metadata(plan); - metadata["temporal_executions"] = py::make_tuple( - temporal_id == static_cast(TemporalExecution::PersistentScan) ? "persistent_scan" : "stepwise"); - metadata["scan_implementations"] = py::make_tuple( - temporal_id == static_cast(TemporalExecution::PersistentScan) ? 
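The readout boundary selection above collapses into two small decisions: the readout buffer's time dimension is either the full sequence `T` or a single terminal slot, and the two lambdas turn that choice into "when do we emit" and "which slot do we write". A self-contained restatement follows; the values in `main()` are illustrative.

```cpp
// Sketch of terminal-only vs. full-sequence readout emission and indexing.
#include <iostream>

int main() {
  const int T = 6;
  const int readout_output_steps = 1;  // terminal-only buffer in this example
  const bool terminal_readout_boundary = (readout_output_steps == 1);

  auto emit_readout_for_step = [&](int t) {
    return !terminal_readout_boundary || t + 1 == T;
  };
  auto readout_output_index = [&](int t) {
    return terminal_readout_boundary ? 0 : t;
  };

  for (int t = 0; t < T; ++t) {
    if (emit_readout_for_step(t)) {
      std::cout << "step " << t << " writes readout slot " << readout_output_index(t) << "\n";
    }
  }
}
```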
"backend_host_loop" : "single_step"); - const bool has_sparse_message_topology = - recurrent_neighbor_idx_raw.defined() && recurrent_neighbor_idx_raw.numel() > 0 && - recurrent_neighbor_valid_raw.defined() && recurrent_neighbor_valid_raw.numel() > 0 && - recurrent_edge_distance_raw.defined() && recurrent_edge_distance_raw.numel() > 0; - const bool use_ragged_sparse_message = - message_backend_id == 1 && recurrent_sparse_positive_degree_buckets > 1; - const bool uses_dense_message_lowering = - (spatial_id == static_cast(SpatialOwnership::ReceiverOwned) && message_backend_id == 0) || - (message_backend_id == 1 && has_sparse_message_topology); - const TinyMessageSuperOpPlan tiny_message_superop_plan = select_tiny_message_superop_plan( - spatial_id, - message_backend_id, - has_sparse_message_topology, - recurrent_q_raw, - input_v_ref, - tuple_tensor(public_prev_tensors, 2), - recurrent_local_sender_idx_raw, - projected_message_dim, - input_projection_tensors); - const bool fused_state_epilogue = - uses_fused_state_epilogue(state_affine_specs, cell_transition_ir.state_epilogue_policy); - const bool output_dependency_receiver_window_base = - emit_readout && (initial_state_is_fresh || compact_input_carry) && !materialize_final_state && - output_ports > 0 && spatial_id == static_cast(SpatialOwnership::ReceiverOwned) && message_backend_id == 0 && - tiny_message_superop_plan.active && output_recurrent_window_contiguous && output_recurrent_window_start >= 0 && - output_recurrent_window_count > 0 && - static_cast(output_recurrent_window_start) + output_recurrent_window_count <= receivers; - const bool diagonal_recurrence_receiver_window_supported = - state_affine_specs.empty() && dispatch_entry.reduction_stats_dim == 0 && - cell_transition_ir.diagonal_recurrences.size() == 1; - const bool recurrent_affine_receiver_window_supported = - cell_transition_ir.diagonal_recurrences.empty() && - state_affine_receiver_window_supported( - state_affine_specs, - cell_param_vector, - receivers, - static_cast(output_recurrent_window_start), - output_recurrent_window_count); - const bool output_dependency_receiver_window = - output_dependency_receiver_window_base && - (diagonal_recurrence_receiver_window_supported || recurrent_affine_receiver_window_supported); - const int64_t active_receiver_offset = - output_dependency_receiver_window ? static_cast(output_recurrent_window_start) : 0; - const int64_t active_receiver_count = - output_dependency_receiver_window ? static_cast(output_recurrent_window_count) : receivers; - const bool write_internal_carry = materialize_final_state || preserve_internal_carry; - ExecutionPlan active_receiver_plan = plan; - active_receiver_plan.receivers = active_receiver_count; - metadata["active_receiver_window_modes"] = - py::make_tuple(output_dependency_receiver_window ? 
"readout_dependency_cone" : "full_surface"); - metadata["active_receiver_window_offsets"] = py::make_tuple(std::to_string(active_receiver_offset)); - metadata["active_receiver_window_counts"] = py::make_tuple(std::to_string(active_receiver_count)); - std::vector phase_names; - if (tiny_message_superop_plan.active) { - phase_names.push_back("tiny_message_direct_projected"); - } else if (uses_dense_message_lowering) { - phase_names.insert( - phase_names.end(), - {"dense_message_source_pack", - "dense_message_logits", - "message_segment_softmax_glue", - "dense_message_weighted_values"}); - } else if (spatial_id == static_cast(SpatialOwnership::ReceiverOwned)) { - phase_names.push_back("receiver_message_aggregate"); - } else { - phase_names.push_back("edge_owned_accumulate"); - phase_names.push_back("receiver_message_normalize"); - } - if (!tiny_message_superop_plan.active) { - phase_names.push_back("dense_input_projection"); - } - phase_names.push_back("dense_state_affines"); - const int num_hidden_chunks = ceil_div(std::max(1, raw_public_dim), std::max(1, state_hidden_chunk)); - const bool single_chunk_reduction_alias = - dispatch_entry.reduction_stats_dim > 0 && num_hidden_chunks <= 1; - auto projected_message_tensor = at::empty({B, active_receiver_count, projected_message_dim}, input_k_ref.options()); - auto message_tensor = tiny_message_superop_plan.active - ? at::empty({0}, input_k_ref.options()) - : at::empty({B, receivers, message_dim}, input_k_ref.options()); - const fabric::cuda::nn::LoweredPhaseIR state_affine_phase_ir = - lower_state_affine_specs_to_phase_ir(state_affine_specs); - const fabric::cuda::nn::LoweredPhaseIR diagonal_recurrence_phase_ir = - fabric::cuda::nn::lower_diagonal_recurrences_to_phase_ir(cell_transition_ir.diagonal_recurrences); - const ReceiverAffineSuperOpPlan receiver_affine_superop_plan = - select_receiver_affine_superop_plan( - state_affine_specs, - state_affine_phase_ir, - projected_message_tensor, - state_prev_tensors, - cell_param_vector, - receivers, - active_receiver_offset, - active_receiver_count); - const DiagonalRecurrenceSuperOpPlan diagonal_recurrence_superop_plan = - select_diagonal_recurrence_superop_plan( - diagonal_recurrence_phase_ir, - state_affine_specs, - cell_param_vector, - projected_message_tensor, - raw_public_dim, - dispatch_entry.reduction_stats_dim, - active_receiver_offset); - if (diagonal_recurrence_superop_plan.active) { - phase_names.push_back("diagonal_recurrence_superop"); - } else if (fused_state_epilogue || single_chunk_reduction_alias) { - phase_names.push_back("receiver_state_update_emit"); - } else { - phase_names.push_back("receiver_state_update"); - if (dispatch_entry.reduction_stats_dim > 0 && !single_chunk_reduction_alias) { - phase_names.push_back("receiver_reduce_stats"); - } - phase_names.push_back("receiver_emit_raw_public"); - } - phase_names.push_back("dense_public_projection"); - phase_names.push_back("readout_message_aggregate"); - phase_names.push_back("dense_readout_projection"); - metadata["phases"] = string_tuple(phase_names); - const bool receiver_affine_uses_pack_workspace = - receiver_affine_superop_plan.active && !receiver_affine_superop_plan.direct_persistent; - DenseStateAffineWorkspace state_affine_workspace = - allocate_dense_state_affine_workspace( - state_affine_specs, - projected_message_tensor, - state_prev_tensors, - cell_param_vector, - receivers, - active_receiver_offset, - active_receiver_count, - receiver_affine_uses_pack_workspace, - receiver_affine_superop_plan.active, - 
!receiver_affine_superop_plan.active, - initial_state_is_fresh && T == 1); - allocate_receiver_affine_superop_workspace( - state_affine_specs, - projected_message_tensor, - state_prev_tensors, - cell_param_vector, - receiver_affine_superop_plan, - &state_affine_workspace); - RuntimeTensorTable state_affine_pack = pack_runtime_tensor_table(state_affine_workspace.outputs, input_k_ref); - RuntimeTensorTable readout_public_pack = pack_runtime_tensor_table({}, input_k_ref); - std::vector readout_public_window_keepalive; - TensorTable state_aux = state_affine_specs.empty() ? aux : state_affine_pack.table; - const at::Tensor public_hidden_out = tuple_tensor(public_next_tensors, 0); - const bool raw_public_aliases_public_hidden = !output_dependency_receiver_window && T == 1 && - raw_public_can_alias_public_hidden(public_projection_kind, public_hidden_out, B, receivers, raw_public_dim); - const bool raw_public_aliases_projected_message = - !output_dependency_receiver_window && !raw_public_aliases_public_hidden && !state_affine_specs.empty() && - projected_message_tensor.size(2) == raw_public_dim; - auto raw_public_tensor = raw_public_aliases_projected_message - ? projected_message_tensor - : raw_public_aliases_public_hidden - ? public_hidden_out - : at::empty({B, active_receiver_count, raw_public_dim}, input_k_ref.options()); - const int64_t readout_value_dim = tuple_tensor(readout_param_tensors, 1).size(1); - auto readout_message_tensor = at::empty({B, output_ports, readout_value_dim}, input_k_ref.options()); - float* raw_public = raw_public_tensor.data_ptr(); - auto partial_stats = dispatch_entry.reduction_stats_dim > 0 - ? at::empty({B, active_receiver_count, num_hidden_chunks, dispatch_entry.reduction_stats_dim}, input_k_ref.options()) - : at::empty({0}, input_k_ref.options()); - auto reduced_stats = dispatch_entry.reduction_stats_dim > 0 && !single_chunk_reduction_alias - ? at::empty({B, active_receiver_count, dispatch_entry.reduction_stats_dim}, input_k_ref.options()) - : dispatch_entry.reduction_stats_dim > 0 ? partial_stats - : at::empty({0}, input_k_ref.options()); - std::vector workspace_buffers; - std::vector workspace_buffer_bytes; - std::vector workspace_aliases = state_affine_workspace.aliases; - if (raw_public_aliases_projected_message) { - workspace_aliases.push_back("projected_message=raw_public"); - } else if (raw_public_aliases_public_hidden) { - workspace_aliases.push_back("raw_public=public_hidden"); - } else if (output_dependency_receiver_window) { - workspace_aliases.push_back( - "active_receiver_window=readout_dependency_cone[offset=" + std::to_string(active_receiver_offset) + - ",count=" + std::to_string(active_receiver_count) + "]"); - } - int64_t workspace_peak_bytes = 0; - if (!tiny_message_superop_plan.active) { - append_workspace_metadata( - "message_buffer", - message_tensor, - "receiver_message_aggregate->dense_input_projection", - "unique", - &workspace_buffers, - &workspace_buffer_bytes, - &workspace_peak_bytes); - } - append_workspace_metadata( - "projected_message", - projected_message_tensor, - tiny_message_superop_plan.active ? "tiny_message_direct_projected->dense_state_affines" - : "dense_input_projection->dense_state_affines", - "unique", - &workspace_buffers, - &workspace_buffer_bytes, - &workspace_peak_bytes); - append_workspace_metadata( - "raw_public", - raw_public_tensor, - "receiver_emit_raw_public->dense_public_projection", - raw_public_aliases_projected_message ? "phase_reuse" - : raw_public_aliases_public_hidden ? 
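One point worth isolating from the workspace bookkeeping above: aliased buffers (for example `raw_public` reusing the projected-message buffer or the public hidden output) are still listed in the workspace metadata, but they do not add to the peak-byte counter, so the reported peak reflects only memory actually allocated for this dispatch. The sketch below shows that convention in miniature; the names, sizes, and the merged "name:bytes" row format are illustrative stand-ins.

```cpp
// Sketch: workspace rows are always reported, but aliases skip the peak counter.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

void append_workspace(const std::string& name, int64_t bytes, bool counts_toward_peak,
                      std::vector<std::string>* rows, int64_t* peak_bytes) {
  rows->push_back(name + ":" + std::to_string(bytes));
  if (counts_toward_peak) *peak_bytes += bytes;
}

int main() {
  std::vector<std::string> rows;
  int64_t peak_bytes = 0;
  append_workspace("projected_message", 2 << 20, /*counts_toward_peak=*/true, &rows, &peak_bytes);
  // raw_public aliases projected_message in this example, so it is listed but
  // contributes no new bytes to the peak.
  append_workspace("raw_public", 2 << 20, /*counts_toward_peak=*/false, &rows, &peak_bytes);
  for (const std::string& row : rows) std::cout << row << "\n";
  std::cout << "workspace_peak_bytes:" << peak_bytes << "\n";
}
```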
"public_hidden_alias" : "unique", - &workspace_buffers, - &workspace_buffer_bytes, - &workspace_peak_bytes, - !raw_public_aliases_projected_message && !raw_public_aliases_public_hidden); - append_workspace_metadata( - "readout_message", - readout_message_tensor, - "readout_message_aggregate->dense_readout_projection", - "unique", - &workspace_buffers, - &workspace_buffer_bytes, - &workspace_peak_bytes); - append_workspace_metadata( - "partial_stats", - partial_stats, - "receiver_state_update->receiver_reduce_stats", - "unique", - &workspace_buffers, - &workspace_buffer_bytes, - &workspace_peak_bytes); - append_workspace_metadata( - "reduced_stats", - reduced_stats, - single_chunk_reduction_alias ? "receiver_state_update->receiver_emit_raw_public" - : "receiver_reduce_stats->receiver_emit_raw_public", - single_chunk_reduction_alias ? "partial_stats_alias" : "unique", - &workspace_buffers, - &workspace_buffer_bytes, - &workspace_peak_bytes, - !single_chunk_reduction_alias); - if (single_chunk_reduction_alias) { - workspace_aliases.push_back("partial_stats=reduced_stats"); - } - workspace_buffers.insert( - workspace_buffers.end(), state_affine_workspace.buffers.begin(), state_affine_workspace.buffers.end()); - workspace_buffer_bytes.insert( - workspace_buffer_bytes.end(), - state_affine_workspace.buffer_bytes.begin(), - state_affine_workspace.buffer_bytes.end()); - workspace_peak_bytes += state_affine_workspace.bytes; - const char* input_projection_backend = "unrun"; - PublicProjectionExecution public_projection{ - "unrun", - "unrun", - "none", - "none", - "unrun", - 0, - 0, - 0, - 0, - 0, - 0}; - const char* readout_projection_backend = "unrun"; - std::vector state_affine_backends; - std::vector state_affine_sources; - std::vector state_affine_bucket_signatures; - std::vector state_affine_output_modes; - std::vector state_affine_reset_policies; - bool state_affine_packed_source_reused = false; - fabric::cuda::ops::DenseMessageExecution dense_message_execution{}; - bool dense_message_execution_recorded = false; - fabric::cuda::ops::DiagonalRecurrenceExecution diagonal_recurrence_execution{}; - bool diagonal_recurrence_execution_recorded = false; - int64_t diagonal_recurrence_launch_count = 0; - if (spatial_id == static_cast(SpatialOwnership::ReceiverOwned)) { - if (temporal_id == static_cast(TemporalExecution::Stepwise)) { - TORCH_CHECK(T == 1, "receiver-owned stepwise dispatcher expects T == 1"); - { - RECORD_FUNCTION("fabric.physical.message", std::vector()); - if (message_backend_id == 0) { - const fabric::cuda::nn::LoweredMessageBucket message_bucket = lower_regular_local_message_bucket( - B, - tiny_message_superop_plan.active ? 
active_receiver_count : receivers, - recurrent_local_sender_idx_raw.size(1), - recurrent_q_raw.size(1), - message_dim); - if (tiny_message_superop_plan.active) { - const at::Tensor projection_weight = tuple_tensor(input_projection_tensors, 0); - const at::Tensor projection_bias = optional_tuple_tensor(input_projection_tensors, 1, projected_message_tensor); - dense_message_execution = fabric::cuda::ops::dense_regular_local_tiny_message_projected_window_out_cuda( - message_bucket, - recurrent_q_raw, - input_k_ref, - input_v_ref, - tuple_tensor(public_prev_tensors, 1), - tuple_tensor(public_prev_tensors, 2), - recurrent_local_sender_idx_raw, - recurrent_local_distance_raw, - resets_u8, - projection_weight, - projection_bias, - projected_message_tensor, - active_receiver_offset, - active_receiver_count, - num_input_ports, - active_receiver_offset, - distance_scale, - 0); - input_projection_backend = "fused_into_tiny_message"; - } else { - dense_message_execution = fabric::cuda::ops::dense_regular_local_message_out_cuda( - message_bucket, - recurrent_q_raw, - input_k_ref, - input_v_ref, - tuple_tensor(public_prev_tensors, 1), - tuple_tensor(public_prev_tensors, 2), - recurrent_local_sender_idx_raw, - recurrent_local_distance_raw, - resets_u8, - message_tensor, - num_input_ports, - distance_scale, - 0); - } - dense_message_execution_recorded = true; - } else if (message_backend_id == 1 && has_sparse_message_topology) { - if (use_ragged_sparse_message) { - const fabric::cuda::nn::LoweredMessageBucket message_bucket = lower_ragged_sparse_message_bucket( - fabric::cuda::nn::MessageTopologyKind::ReceiverOwnedSparse, - B, - receivers, - recurrent_neighbor_idx_raw.size(1), - recurrent_q_raw.size(1), - message_dim); - dense_message_execution = fabric::cuda::ops::dense_receiver_owned_sparse_ragged_grouped_message_out_cuda( - message_bucket, - recurrent_q_raw, - input_k_ref, - input_v_ref, - tuple_tensor(public_prev_tensors, 1), - tuple_tensor(public_prev_tensors, 2), - recurrent_neighbor_idx_raw, - recurrent_neighbor_valid_raw, - recurrent_edge_distance_raw, - recurrent_sparse_receiver_order_raw, - recurrent_sparse_degree_ptr_raw, - resets_u8, - message_tensor, - num_input_ports, - distance_scale, - 0); - } else { - const fabric::cuda::nn::LoweredMessageBucket message_bucket = lower_sparse_message_bucket( - fabric::cuda::nn::MessageTopologyKind::ReceiverOwnedSparse, - B, - receivers, - recurrent_neighbor_idx_raw.size(1), - recurrent_q_raw.size(1), - message_dim); - dense_message_execution = fabric::cuda::ops::dense_receiver_owned_sparse_degree_bucketed_message_out_cuda( - message_bucket, - recurrent_q_raw, - input_k_ref, - input_v_ref, - tuple_tensor(public_prev_tensors, 1), - tuple_tensor(public_prev_tensors, 2), - recurrent_neighbor_idx_raw, - recurrent_neighbor_valid_raw, - recurrent_edge_distance_raw, - resets_u8, - message_tensor, - num_input_ports, - distance_scale, - 0); - } - dense_message_execution_recorded = true; - } else { - TORCH_CHECK( - message_backend_id != 1, - "sparse message backend requires sparse topology tensors for dense sparse-message lowering"); - launch_receiver_message_aggregate_cuda( - message_backend_id, - public_prev, - message_params, - input_ports, - recurrent_topology, - message_tensor.data_ptr(), - plan, - resets_u8, - 0, - stream.stream()); - } - } - if (!tiny_message_superop_plan.active) { - input_projection_backend = - launch_dense_input_projection(message_tensor, input_projection_tensors, projected_message_tensor); - } - launch_receiver_state_transition_cuda( - 
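The stepwise branch above chooses one of four message lowerings before launching. A compact sketch of that selection, with placeholder names rather than the extension's real entry points:

```python
# Illustrative sketch only: which message kernel family the deleted stepwise branch selects.
LOCAL, SPARSE = 0, 1

def pick_message_kernel(backend_id: int, *, tiny_superop: bool,
                        has_sparse_topology: bool, ragged_sparse: bool) -> str:
    if backend_id == LOCAL:
        # tiny superop fuses the input projection into the message kernel itself
        return ("dense_regular_local_tiny_message_projected_window_out"
                if tiny_superop else "dense_regular_local_message_out")
    if backend_id == SPARSE and has_sparse_topology:
        return ("dense_receiver_owned_sparse_ragged_grouped_message_out"
                if ragged_sparse else "dense_receiver_owned_sparse_degree_bucketed_message_out")
    if backend_id == SPARSE:
        raise RuntimeError("sparse message backend requires sparse topology tensors")
    return "receiver_message_aggregate"  # generic fallback launch
```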
cell_core_id, - projected_message_tensor, - state_prev, - state_next, - cell_params, - aux, - state_aux, - state_prev_tensors, - cell_param_vector, - state_affine_specs, - state_affine_workspace, - receiver_affine_superop_plan, - diagonal_recurrence_superop_plan, - cell_transition_ir.state_epilogue_policy, - output_dependency_receiver_window ? active_receiver_plan : plan, - resets_u8, - 0, - projected_message_dim, - raw_public_dim, - state_static_bytes, - emit_static_bytes, - raw_public_tensor, - raw_public, - partial_stats.data_ptr(), - reduced_stats.data_ptr(), - num_hidden_chunks, - stream.stream(), - initial_state_is_fresh, - write_internal_carry, - write_internal_carry, - receivers, - active_receiver_offset, - &state_affine_backends, - &state_affine_sources, - &state_affine_bucket_signatures, - &state_affine_output_modes, - &state_affine_reset_policies, - &state_affine_packed_source_reused, - &diagonal_recurrence_execution, - &diagonal_recurrence_execution_recorded, - &diagonal_recurrence_launch_count); - const at::Tensor readout_query = tuple_tensor(readout_param_tensors, 0); - const at::Tensor value_to_output_weight = tuple_tensor(readout_param_tensors, 1); - const bool materialize_readout_public_window = - !materialize_final_state && - should_materialize_readout_public_window( - output_dependency_receiver_window, - readout, - readout_query, - value_to_output_weight, - active_receiver_count, - raw_public_dim); - metadata["readout_public_materialization_modes"] = py::make_tuple( - materialize_readout_public_window ? "materialized_active_window" : "raw_public_direct"); - const bool public_next_required = write_internal_carry; - if (public_next_required) { - public_projection = - launch_dense_public_projection( - public_projection_kind, - raw_public_tensor, - public_projection_tensors, - public_next_tensors, - active_receiver_offset); - readout_projection_backend = launch_dense_readout( - input_ports, - public_next, - readout, - readout_output_seq, - readout_param_tensors, - readout_message_tensor, - plan, - 0, - readout_output_index(0), - 0, - stream.stream()); - } else if (materialize_readout_public_window) { - public_projection = launch_materialized_readout_public_window( - B, - active_receiver_count, - public_projection_kind, - raw_public_tensor, - public_projection_tensors, - public_next_tensors, - readout_param_tensors, - input_k_ref, - active_receiver_offset, - &readout_public_pack, - &readout_public_window_keepalive, - &workspace_aliases, - &workspace_buffers, - &workspace_buffer_bytes, - &workspace_peak_bytes); - readout_projection_backend = launch_dense_readout( - input_ports, - readout_public_pack.table, - readout, - readout_output_seq, - readout_param_tensors, - readout_message_tensor, - active_receiver_plan, - 0, - readout_output_index(0), - static_cast(active_receiver_offset), - stream.stream()); - } else { - public_projection = readout_narrow_public_projection_execution(); - readout_projection_backend = launch_dense_readout_from_raw_public( - input_ports, - raw_public_tensor, - public_projection_kind, - public_projection_params, - readout, - readout_output_seq, - readout_param_tensors, - readout_message_tensor, - output_dependency_receiver_window ? 
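The readout logic above picks one of three paths depending on whether the public carry must be written and whether the active readout window is worth materializing. A hedged sketch of that decision, using placeholder path labels:

```python
# Illustrative sketch only: the three readout paths in the deleted stepwise (T == 1) branch.
def choose_readout_path(*, write_internal_carry: bool,
                        materialize_final_state: bool,
                        window_worth_materializing: bool) -> str:
    if write_internal_carry:
        # public next bank is required anyway: project it fully, then read out from it
        return "dense_public_projection -> dense_readout(public_next)"
    if not materialize_final_state and window_worth_materializing:
        # only the active receiver window feeds the readout cone: materialize just that slice
        return "materialized_readout_public_window -> dense_readout(window)"
    # cheapest path: fold the public projection into the readout and consume raw_public directly
    return "readout_narrow_public_projection -> dense_readout_from_raw_public"
```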
active_receiver_plan : plan, - 0, - readout_output_index(0), - static_cast(active_receiver_offset), - stream.stream()); - } - metadata["input_projection_backends"] = py::make_tuple(input_projection_backend); - metadata["input_projection_notes"] = py::make_tuple(input_projection_note(input_projection_backend)); - metadata["message_projection_boundaries"] = py::make_tuple("projected_message"); - metadata["message_projection_bucket_kinds"] = - py::make_tuple(message_projection_bucket_kind(message_backend_id, spatial_id)); - if (dense_message_execution_recorded) { - append_dense_message_workspace_metadata( - dense_message_execution, - &workspace_buffers, - &workspace_buffer_bytes, - &workspace_aliases, - &workspace_peak_bytes); - } - set_dense_message_metadata( - &metadata, - dense_message_execution_recorded ? &dense_message_execution : nullptr, - message_backend_id, - spatial_id); - set_state_affine_metadata( - &metadata, - state_affine_specs, - state_affine_backends, - state_affine_sources, - state_affine_bucket_signatures, - state_affine_output_modes, - state_affine_reset_policies, - state_affine_workspace, - state_affine_packed_source_reused); - if (diagonal_recurrence_execution_recorded) { - append_diagonal_recurrence_workspace_metadata( - diagonal_recurrence_execution, - &workspace_buffers, - &workspace_buffer_bytes, - &workspace_aliases, - &workspace_peak_bytes); - } - set_workspace_metadata(&metadata, workspace_buffers, workspace_buffer_bytes, workspace_peak_bytes, workspace_aliases); - metadata["public_projection_hidden_backends"] = py::make_tuple(public_projection.hidden_backend); - metadata["public_projection_kv_backends"] = py::make_tuple(public_projection.kv_backend); - metadata["readout_projection_backends"] = py::make_tuple(readout_projection_backend); - set_launch_granularity_metadata( - &metadata, - dense_message_execution_recorded ? &dense_message_execution : nullptr, - dense_message_execution_recorded, - tiny_message_superop_plan, - state_affine_specs, - state_affine_workspace, - receiver_affine_superop_plan, - diagonal_recurrence_superop_plan, - diagonal_recurrence_execution_recorded ? &diagonal_recurrence_execution : nullptr, - diagonal_recurrence_execution_recorded, - diagonal_recurrence_launch_count, - message_backend_id, - spatial_id, - T, - dispatch_entry.reduction_stats_dim, - num_hidden_chunks, - cell_transition_ir.state_epilogue_policy, - input_projection_backend, - public_projection, - readout_projection_backend); - return metadata; - } - TORCH_CHECK( - temporal_id == static_cast(TemporalExecution::PersistentScan), - "unsupported receiver-owned temporal execution id: ", - temporal_id); - TensorTable current_state_prev = state_prev; - TensorTable current_state_next = state_next; - TensorTable current_public_prev = public_prev; - TensorTable current_public_next = public_next; - py::tuple current_state_prev_tensors = state_prev_tensors; - py::tuple current_state_next_tensors = state_next_tensors; - py::tuple current_public_prev_tensors = public_prev_tensors; - py::tuple current_public_next_tensors = public_next_tensors; - for (int t = 0; t < T; ++t) { - { - RECORD_FUNCTION("fabric.physical.message", std::vector()); - if (message_backend_id == 0) { - const fabric::cuda::nn::LoweredMessageBucket message_bucket = lower_regular_local_message_bucket( - B, - tiny_message_superop_plan.active ? 
active_receiver_count : receivers, - recurrent_local_sender_idx_raw.size(1), - recurrent_q_raw.size(1), - message_dim); - if (tiny_message_superop_plan.active) { - const at::Tensor projection_weight = tuple_tensor(input_projection_tensors, 0); - const at::Tensor projection_bias = optional_tuple_tensor(input_projection_tensors, 1, projected_message_tensor); - dense_message_execution = fabric::cuda::ops::dense_regular_local_tiny_message_projected_window_out_cuda( - message_bucket, - recurrent_q_raw, - input_k_ref, - input_v_ref, - tuple_tensor(current_public_prev_tensors, 1), - tuple_tensor(current_public_prev_tensors, 2), - recurrent_local_sender_idx_raw, - recurrent_local_distance_raw, - resets_u8, - projection_weight, - projection_bias, - projected_message_tensor, - active_receiver_offset, - active_receiver_count, - num_input_ports, - active_receiver_offset, - distance_scale, - t); - input_projection_backend = "fused_into_tiny_message"; - } else { - dense_message_execution = fabric::cuda::ops::dense_regular_local_message_out_cuda( - message_bucket, - recurrent_q_raw, - input_k_ref, - input_v_ref, - tuple_tensor(current_public_prev_tensors, 1), - tuple_tensor(current_public_prev_tensors, 2), - recurrent_local_sender_idx_raw, - recurrent_local_distance_raw, - resets_u8, - message_tensor, - num_input_ports, - distance_scale, - t); - } - dense_message_execution_recorded = true; - } else if (message_backend_id == 1 && has_sparse_message_topology) { - if (use_ragged_sparse_message) { - const fabric::cuda::nn::LoweredMessageBucket message_bucket = lower_ragged_sparse_message_bucket( - fabric::cuda::nn::MessageTopologyKind::ReceiverOwnedSparse, - B, - receivers, - recurrent_neighbor_idx_raw.size(1), - recurrent_q_raw.size(1), - message_dim); - dense_message_execution = fabric::cuda::ops::dense_receiver_owned_sparse_ragged_grouped_message_out_cuda( - message_bucket, - recurrent_q_raw, - input_k_ref, - input_v_ref, - tuple_tensor(current_public_prev_tensors, 1), - tuple_tensor(current_public_prev_tensors, 2), - recurrent_neighbor_idx_raw, - recurrent_neighbor_valid_raw, - recurrent_edge_distance_raw, - recurrent_sparse_receiver_order_raw, - recurrent_sparse_degree_ptr_raw, - resets_u8, - message_tensor, - num_input_ports, - distance_scale, - t); - } else { - const fabric::cuda::nn::LoweredMessageBucket message_bucket = lower_sparse_message_bucket( - fabric::cuda::nn::MessageTopologyKind::ReceiverOwnedSparse, - B, - receivers, - recurrent_neighbor_idx_raw.size(1), - recurrent_q_raw.size(1), - message_dim); - dense_message_execution = fabric::cuda::ops::dense_receiver_owned_sparse_degree_bucketed_message_out_cuda( - message_bucket, - recurrent_q_raw, - input_k_ref, - input_v_ref, - tuple_tensor(current_public_prev_tensors, 1), - tuple_tensor(current_public_prev_tensors, 2), - recurrent_neighbor_idx_raw, - recurrent_neighbor_valid_raw, - recurrent_edge_distance_raw, - resets_u8, - message_tensor, - num_input_ports, - distance_scale, - t); - } - dense_message_execution_recorded = true; - } else { - TORCH_CHECK( - message_backend_id != 1, - "sparse message backend requires sparse topology tensors for dense sparse-message lowering"); - launch_receiver_message_aggregate_cuda( - message_backend_id, - current_public_prev, - message_params, - input_ports, - recurrent_topology, - message_tensor.data_ptr(), - plan, - resets_u8, - t, - stream.stream()); - } - } - if (!tiny_message_superop_plan.active) { - input_projection_backend = - launch_dense_input_projection(message_tensor, input_projection_tensors, 
projected_message_tensor); - } - launch_receiver_state_transition_cuda( - cell_core_id, - projected_message_tensor, - current_state_prev, - current_state_next, - cell_params, - aux, - state_aux, - current_state_prev_tensors, - cell_param_vector, - state_affine_specs, - state_affine_workspace, - receiver_affine_superop_plan, - diagonal_recurrence_superop_plan, - cell_transition_ir.state_epilogue_policy, - output_dependency_receiver_window ? active_receiver_plan : plan, - resets_u8, - t, - projected_message_dim, - raw_public_dim, - state_static_bytes, - emit_static_bytes, - raw_public_tensor, - raw_public, - partial_stats.data_ptr(), - reduced_stats.data_ptr(), - num_hidden_chunks, - stream.stream(), - initial_state_is_fresh && t == 0, - write_internal_carry || t + 1 < T, - write_internal_carry || t + 1 < T, - receivers, - active_receiver_offset, - &state_affine_backends, - &state_affine_sources, - &state_affine_bucket_signatures, - &state_affine_output_modes, - &state_affine_reset_policies, - &state_affine_packed_source_reused, - &diagonal_recurrence_execution, - &diagonal_recurrence_execution_recorded, - &diagonal_recurrence_launch_count); - const bool public_next_required = write_internal_carry || t + 1 < T; - if (public_next_required) { - public_projection = - launch_dense_public_projection( - public_projection_kind, - raw_public_tensor, - public_projection_tensors, - current_public_next_tensors, - active_receiver_offset); - if (emit_readout_for_step(t)) { - readout_projection_backend = launch_dense_readout( - input_ports, - current_public_next, - readout, - readout_output_seq, - readout_param_tensors, - readout_message_tensor, - output_dependency_receiver_window ? active_receiver_plan : plan, - t, - readout_output_index(t), - static_cast(output_dependency_receiver_window ? active_receiver_offset : 0), - stream.stream()); - } - } else if ( - emit_readout_for_step(t) && - should_materialize_readout_public_window( - output_dependency_receiver_window, - readout, - tuple_tensor(readout_param_tensors, 0), - tuple_tensor(readout_param_tensors, 1), - active_receiver_count, - raw_public_dim)) { - public_projection = launch_materialized_readout_public_window( - B, - active_receiver_count, - public_projection_kind, - raw_public_tensor, - public_projection_tensors, - current_public_next_tensors, - readout_param_tensors, - input_k_ref, - active_receiver_offset, - &readout_public_pack, - &readout_public_window_keepalive, - &workspace_aliases, - &workspace_buffers, - &workspace_buffer_bytes, - &workspace_peak_bytes); - readout_projection_backend = launch_dense_readout( - input_ports, - readout_public_pack.table, - readout, - readout_output_seq, - readout_param_tensors, - readout_message_tensor, - active_receiver_plan, - t, - readout_output_index(t), - static_cast(active_receiver_offset), - stream.stream()); - } else if (emit_readout_for_step(t)) { - public_projection = readout_narrow_public_projection_execution(); - readout_projection_backend = launch_dense_readout_from_raw_public( - input_ports, - raw_public_tensor, - public_projection_kind, - public_projection_params, - readout, - readout_output_seq, - readout_param_tensors, - readout_message_tensor, - output_dependency_receiver_window ? 
active_receiver_plan : plan, - t, - readout_output_index(t), - static_cast(active_receiver_offset), - stream.stream()); - } - record_forward_carry_checkpoint( - forward_carry_checkpoint_state_tensors, - forward_carry_checkpoint_public_tensors, - current_state_next_tensors, - current_public_next_tensors, - forward_carry_checkpoint_state_source_indices, - forward_carry_checkpoint_public_source_indices, - forward_carry_checkpoint_stride, - t + 1, - T); - TensorTable tmp_state = current_state_prev; - current_state_prev = current_state_next; - current_state_next = tmp_state; - py::tuple tmp_state_tensors = current_state_prev_tensors; - current_state_prev_tensors = current_state_next_tensors; - current_state_next_tensors = tmp_state_tensors; - TensorTable tmp_public = current_public_prev; - current_public_prev = current_public_next; - current_public_next = tmp_public; - py::tuple tmp_public_prev_tensors = current_public_prev_tensors; - current_public_prev_tensors = current_public_next_tensors; - current_public_next_tensors = tmp_public_prev_tensors; - } - metadata["input_projection_backends"] = py::make_tuple(input_projection_backend); - metadata["input_projection_notes"] = py::make_tuple(input_projection_note(input_projection_backend)); - metadata["message_projection_boundaries"] = py::make_tuple("projected_message"); - metadata["message_projection_bucket_kinds"] = - py::make_tuple(message_projection_bucket_kind(message_backend_id, spatial_id)); - if (dense_message_execution_recorded) { - append_dense_message_workspace_metadata( - dense_message_execution, - &workspace_buffers, - &workspace_buffer_bytes, - &workspace_aliases, - &workspace_peak_bytes); - } - set_dense_message_metadata( - &metadata, - dense_message_execution_recorded ? &dense_message_execution : nullptr, - message_backend_id, - spatial_id); - set_state_affine_metadata( - &metadata, - state_affine_specs, - state_affine_backends, - state_affine_sources, - state_affine_bucket_signatures, - state_affine_output_modes, - state_affine_reset_policies, - state_affine_workspace, - state_affine_packed_source_reused); - if (diagonal_recurrence_execution_recorded) { - append_diagonal_recurrence_workspace_metadata( - diagonal_recurrence_execution, - &workspace_buffers, - &workspace_buffer_bytes, - &workspace_aliases, - &workspace_peak_bytes); - } - set_workspace_metadata(&metadata, workspace_buffers, workspace_buffer_bytes, workspace_peak_bytes, workspace_aliases); - metadata["public_projection_hidden_backends"] = py::make_tuple(public_projection.hidden_backend); - metadata["public_projection_kv_backends"] = py::make_tuple(public_projection.kv_backend); - metadata["readout_projection_backends"] = py::make_tuple(readout_projection_backend); - set_launch_granularity_metadata( - &metadata, - dense_message_execution_recorded ? &dense_message_execution : nullptr, - dense_message_execution_recorded, - tiny_message_superop_plan, - state_affine_specs, - state_affine_workspace, - receiver_affine_superop_plan, - diagonal_recurrence_superop_plan, - diagonal_recurrence_execution_recorded ? 
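The persistent-scan loop above double-buffers the state and public banks, swapping `prev`/`next` after every step and recording sparse forward-carry checkpoints. A toy sketch of that pattern, with made-up tensors and a hypothetical `persistent_scan` wrapper:

```python
# Illustrative sketch only: ping-pong buffering plus stride-based carry checkpoints.
import torch

def persistent_scan(step, state_a, state_b, T: int, stride: int):
    checkpoints = {}
    prev, nxt = state_a, state_b
    for t in range(T):
        step(t, prev, nxt)                    # write the step-t result into `nxt`
        if stride > 0 and (t + 1) % stride == 0 and t + 1 < T:
            checkpoints[t + 1] = nxt.clone()  # carry checkpoint after step t
        prev, nxt = nxt, prev                 # ping-pong: last output becomes next input
    return prev, checkpoints                  # `prev` holds the final state after the swap

final, ckpts = persistent_scan(
    lambda t, p, n: torch.add(p, 1.0, out=n),  # toy step: state += 1
    torch.zeros(4, 8), torch.empty(4, 8), T=6, stride=2)
```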
&diagonal_recurrence_execution : nullptr, - diagonal_recurrence_execution_recorded, - diagonal_recurrence_launch_count, - message_backend_id, - spatial_id, - T, - dispatch_entry.reduction_stats_dim, - num_hidden_chunks, - cell_transition_ir.state_epilogue_policy, - input_projection_backend, - public_projection, - readout_projection_backend); - return metadata; - } - - TORCH_CHECK(spatial_id == static_cast(SpatialOwnership::EdgeOwned), "unsupported spatial ownership id: ", spatial_id); - if (temporal_id == static_cast(TemporalExecution::Stepwise)) { - TORCH_CHECK(T == 1, "edge-owned stepwise dispatcher expects T == 1"); - } else { - TORCH_CHECK( - temporal_id == static_cast(TemporalExecution::PersistentScan), - "unsupported edge-owned temporal execution id: ", - temporal_id); - } - const bool edge_uses_dense_sparse_message = message_backend_id == 1 && has_sparse_message_topology; - TORCH_CHECK( - message_backend_id != 1 || edge_uses_dense_sparse_message, - "edge-owned sparse message backend requires sparse topology tensors for dense sparse-message lowering"); - at::Tensor msg_buffer; - at::Tensor max_buffer; - if (!edge_uses_dense_sparse_message) { - msg_buffer = at::zeros({B, receivers, message_dim + 1}, recurrent_edge_weight.options()); - max_buffer = at::full( - {B, receivers}, - -std::numeric_limits::infinity(), - recurrent_edge_weight.options()); - append_workspace_metadata( - "edge_message_accum", - msg_buffer, - "edge_owned_accumulate->receiver_message_normalize", - "unique", - &workspace_buffers, - &workspace_buffer_bytes, - &workspace_peak_bytes); - append_workspace_metadata( - "edge_message_max", - max_buffer, - "edge_owned_accumulate", - "unique", - &workspace_buffers, - &workspace_buffer_bytes, - &workspace_peak_bytes); - } - TensorTable current_state_prev = state_prev; - TensorTable current_state_next = state_next; - TensorTable current_public_prev = public_prev; - TensorTable current_public_next = public_next; - py::tuple current_state_prev_tensors = state_prev_tensors; - py::tuple current_state_next_tensors = state_next_tensors; - py::tuple current_public_prev_tensors = public_prev_tensors; - py::tuple current_public_next_tensors = public_next_tensors; - for (int t = 0; t < T; ++t) { - { - RECORD_FUNCTION("fabric.physical.message", std::vector()); - if (edge_uses_dense_sparse_message) { - if (use_ragged_sparse_message) { - const fabric::cuda::nn::LoweredMessageBucket message_bucket = lower_ragged_sparse_message_bucket( - fabric::cuda::nn::MessageTopologyKind::EdgeOwnedSparse, - B, - receivers, - recurrent_neighbor_idx_raw.size(1), - recurrent_q_raw.size(1), - message_dim); - dense_message_execution = fabric::cuda::ops::dense_edge_owned_sparse_ragged_grouped_message_out_cuda( - message_bucket, - recurrent_q_raw, - input_k_ref, - input_v_ref, - tuple_tensor(current_public_prev_tensors, 1), - tuple_tensor(current_public_prev_tensors, 2), - recurrent_neighbor_idx_raw, - recurrent_neighbor_valid_raw, - recurrent_edge_distance_raw, - recurrent_sparse_receiver_order_raw, - recurrent_sparse_degree_ptr_raw, - resets_u8, - message_tensor, - num_input_ports, - distance_scale, - t); - } else { - const fabric::cuda::nn::LoweredMessageBucket message_bucket = lower_sparse_message_bucket( - fabric::cuda::nn::MessageTopologyKind::EdgeOwnedSparse, - B, - receivers, - recurrent_neighbor_idx_raw.size(1), - recurrent_q_raw.size(1), - message_dim); - dense_message_execution = fabric::cuda::ops::dense_edge_owned_sparse_degree_bucketed_message_out_cuda( - message_bucket, - recurrent_q_raw, - 
input_k_ref, - input_v_ref, - tuple_tensor(current_public_prev_tensors, 1), - tuple_tensor(current_public_prev_tensors, 2), - recurrent_neighbor_idx_raw, - recurrent_neighbor_valid_raw, - recurrent_edge_distance_raw, - resets_u8, - message_tensor, - num_input_ports, - distance_scale, - t); - } - dense_message_execution_recorded = true; - } else { - msg_buffer.zero_(); - max_buffer.fill_(-std::numeric_limits::infinity()); - launch_edge_owned_accumulate_stepwise_cuda( - message_backend_id, - current_public_prev, - message_params, - input_ports, - recurrent_topology, - max_buffer.data_ptr(), - msg_buffer.data_ptr(), - plan, - resets_u8, - t, - stream.stream()); - launch_receiver_normalize_accumulated_message_cuda( - msg_buffer.data_ptr(), - message_tensor.data_ptr(), - plan, - stream.stream()); - } - } - input_projection_backend = - launch_dense_input_projection(message_tensor, input_projection_tensors, projected_message_tensor); - launch_receiver_state_transition_cuda( - cell_core_id, - projected_message_tensor, - current_state_prev, - current_state_next, - cell_params, - aux, - state_aux, - current_state_prev_tensors, - cell_param_vector, - state_affine_specs, - state_affine_workspace, - receiver_affine_superop_plan, - diagonal_recurrence_superop_plan, - cell_transition_ir.state_epilogue_policy, - plan, - resets_u8, - t, - projected_message_dim, - raw_public_dim, - state_static_bytes, - emit_static_bytes, - raw_public_tensor, - raw_public, - partial_stats.data_ptr(), - reduced_stats.data_ptr(), - num_hidden_chunks, - stream.stream(), - initial_state_is_fresh && t == 0, - write_internal_carry || t + 1 < T, - write_internal_carry || t + 1 < T, - receivers, - 0, - &state_affine_backends, - &state_affine_sources, - &state_affine_bucket_signatures, - &state_affine_output_modes, - &state_affine_reset_policies, - &state_affine_packed_source_reused, - &diagonal_recurrence_execution, - &diagonal_recurrence_execution_recorded, - &diagonal_recurrence_launch_count); - const bool public_next_required = write_internal_carry || t + 1 < T; - if (public_next_required) { - public_projection = - launch_dense_public_projection( - public_projection_kind, raw_public_tensor, public_projection_tensors, current_public_next_tensors); - if (emit_readout_for_step(t)) { - readout_projection_backend = launch_dense_readout( - input_ports, - current_public_next, - readout, - readout_output_seq, - readout_param_tensors, - readout_message_tensor, - plan, - t, - readout_output_index(t), - 0, - stream.stream()); - } - } else if (emit_readout_for_step(t)) { - public_projection = readout_narrow_public_projection_execution(); - readout_projection_backend = launch_dense_readout_from_raw_public( - input_ports, - raw_public_tensor, - public_projection_kind, - public_projection_params, - readout, - readout_output_seq, - readout_param_tensors, - readout_message_tensor, - plan, - t, - readout_output_index(t), - 0, - stream.stream()); - } - record_forward_carry_checkpoint( - forward_carry_checkpoint_state_tensors, - forward_carry_checkpoint_public_tensors, - current_state_next_tensors, - current_public_next_tensors, - forward_carry_checkpoint_state_source_indices, - forward_carry_checkpoint_public_source_indices, - forward_carry_checkpoint_stride, - t + 1, - T); - TensorTable tmp_state = current_state_prev; - current_state_prev = current_state_next; - current_state_next = tmp_state; - py::tuple tmp_state_tensors = current_state_prev_tensors; - current_state_prev_tensors = current_state_next_tensors; - current_state_next_tensors = 
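The edge-owned fallback above accumulates into a `[B, receivers, message_dim + 1]` buffer (the extra slot holding the normalizer) with a running per-receiver max, then normalizes in a second launch. A dense PyTorch analogue of that softmax-style aggregation is sketched below; shapes, names, and the exact weighting are assumptions, not the kernel's math.

```python
# Illustrative sketch only: per-receiver softmax aggregation with an explicit normalizer slot.
import torch

def edge_owned_softmax_aggregate(logits, values, receiver_idx, num_receivers):
    # logits: [E], values: [E, D], receiver_idx: [E] -> messages: [R, D]
    D = values.shape[1]
    accum = torch.zeros(num_receivers, D + 1)            # last column = sum of weights
    max_buf = torch.full((num_receivers,), float("-inf"))
    max_buf.scatter_reduce_(0, receiver_idx, logits, reduce="amax")    # pass 1: per-receiver max
    w = torch.exp(logits - max_buf[receiver_idx])
    accum[:, :D].index_add_(0, receiver_idx, values * w.unsqueeze(1))  # pass 2: weighted accumulate
    accum[:, D].index_add_(0, receiver_idx, w)
    return accum[:, :D] / accum[:, D:].clamp_min(1e-20)                # normalize

msgs = edge_owned_softmax_aggregate(
    torch.randn(10), torch.randn(10, 4), torch.randint(0, 3, (10,)), num_receivers=3)
```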
tmp_state_tensors; - TensorTable tmp_public = current_public_prev; - current_public_prev = current_public_next; - current_public_next = tmp_public; - py::tuple tmp_public_prev_tensors = current_public_prev_tensors; - current_public_prev_tensors = current_public_next_tensors; - current_public_next_tensors = tmp_public_prev_tensors; - } - metadata["input_projection_backends"] = py::make_tuple(input_projection_backend); - metadata["input_projection_notes"] = py::make_tuple(input_projection_note(input_projection_backend)); - metadata["message_projection_boundaries"] = py::make_tuple("projected_message"); - metadata["message_projection_bucket_kinds"] = - py::make_tuple(message_projection_bucket_kind(message_backend_id, spatial_id)); - if (dense_message_execution_recorded) { - append_dense_message_workspace_metadata( - dense_message_execution, - &workspace_buffers, - &workspace_buffer_bytes, - &workspace_aliases, - &workspace_peak_bytes); - } - set_dense_message_metadata( - &metadata, - dense_message_execution_recorded ? &dense_message_execution : nullptr, - message_backend_id, - spatial_id); - set_state_affine_metadata( - &metadata, - state_affine_specs, - state_affine_backends, - state_affine_sources, - state_affine_bucket_signatures, - state_affine_output_modes, - state_affine_reset_policies, - state_affine_workspace, - state_affine_packed_source_reused); - if (diagonal_recurrence_execution_recorded) { - append_diagonal_recurrence_workspace_metadata( - diagonal_recurrence_execution, - &workspace_buffers, - &workspace_buffer_bytes, - &workspace_aliases, - &workspace_peak_bytes); - } - set_workspace_metadata(&metadata, workspace_buffers, workspace_buffer_bytes, workspace_peak_bytes, workspace_aliases); - metadata["public_projection_hidden_backends"] = py::make_tuple(public_projection.hidden_backend); - metadata["public_projection_kv_backends"] = py::make_tuple(public_projection.kv_backend); - metadata["readout_projection_backends"] = py::make_tuple(readout_projection_backend); - set_launch_granularity_metadata( - &metadata, - dense_message_execution_recorded ? &dense_message_execution : nullptr, - dense_message_execution_recorded, - tiny_message_superop_plan, - state_affine_specs, - state_affine_workspace, - receiver_affine_superop_plan, - diagonal_recurrence_superop_plan, - diagonal_recurrence_execution_recorded ? 
&diagonal_recurrence_execution : nullptr, - diagonal_recurrence_execution_recorded, - diagonal_recurrence_launch_count, - message_backend_id, - spatial_id, - T, - dispatch_entry.reduction_stats_dim, - num_hidden_chunks, - cell_transition_ir.state_epilogue_policy, - input_projection_backend, - public_projection, - readout_projection_backend); - return metadata; -} - -} // namespace fabric - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def( - "can_virtualize_fresh_state", - &fabric::fabric_dispatch_can_virtualize_fresh_state_cuda, - "Return whether the semantic CUDA transition can consume a virtual fresh-zero previous state"); - m.def("forward", &fabric::fabric_dispatch_forward_cuda, "Fabric generic backend forward dispatch (CUDA)"); -} diff --git a/src/cortical/fabric/backend/cuda/execution/dispatcher_cuda.py b/src/cortical/fabric/backend/cuda/execution/dispatcher_cuda.py deleted file mode 100644 index dff67332..00000000 --- a/src/cortical/fabric/backend/cuda/execution/dispatcher_cuda.py +++ /dev/null @@ -1,1104 +0,0 @@ -from __future__ import annotations - -import math -import os -from collections.abc import Sequence -from glob import glob -from typing import Any, Callable - -import torch - -from cortical.fabric.backend.cuda.execution.registry import ( - SUPPORTED_CELL_STATIC_STAGE_MODES, - SUPPORTED_MESSAGE_RULE_LOWERINGS, - SUPPORTED_MESSAGE_RULE_OUTPUT_BOUNDARIES, - SUPPORTED_READOUT_MODES, - FabricExecutionRequest, - ForwardCarryCheckpoints, -) -from cortical.fabric.backend.cuda.execution.tensor_pack import empty_packed_tensor_table, pack_tensor_tree -from cortical.fabric.backend.cuda.ops import dense_affine_cuda -from cortical.native.extension_loader import safe_load_extension - -_MOD_PATH = os.path.dirname(__file__) -_SRC_ROOT = os.path.normpath(os.path.join(_MOD_PATH, "..", "..", "..", "..", "..")) -_EXT = None -_LAST_LAUNCH_METADATA: dict[str, tuple[Any, ...]] = {} -_LAST_FORWARD_CARRY_CHECKPOINTS: ForwardCarryCheckpoints | None = None -_CELL_NATIVE_ROOT = os.path.normpath(os.path.join(_MOD_PATH, "..", "cells")) - -_MESSAGE_BACKEND_ID = { - "local": 0, - "sparse": 1, -} -_PUBLIC_PROJECTION_KIND = { - "hidden": 0, - "preproj": 1, -} -_READOUT_MODE = { - "skip": 0, - "separate_port_owned": 1, -} -_CELL_STATIC_STAGE_MODE = { - "disabled": 0, - "shared_full": 1, -} - - -def _check_supported_launch_modes(request: FabricExecutionRequest) -> None: - if request.readout_mode not in SUPPORTED_READOUT_MODES: - raise RuntimeError( - f"Unsupported Fabric CUDA readout_mode={request.readout_mode!r}; " - f"supported modes are {sorted(SUPPORTED_READOUT_MODES)}" - ) - if request.cell_static_stage_mode not in SUPPORTED_CELL_STATIC_STAGE_MODES: - raise RuntimeError( - f"Unsupported Fabric CUDA cell_static_stage_mode={request.cell_static_stage_mode!r}; " - f"supported modes are {sorted(SUPPORTED_CELL_STATIC_STAGE_MODES)}" - ) - if request.state_static_stage_mode not in SUPPORTED_CELL_STATIC_STAGE_MODES: - raise RuntimeError( - f"Unsupported Fabric CUDA state_static_stage_mode={request.state_static_stage_mode!r}; " - f"supported modes are {sorted(SUPPORTED_CELL_STATIC_STAGE_MODES)}" - ) - if request.emit_static_stage_mode not in SUPPORTED_CELL_STATIC_STAGE_MODES: - raise RuntimeError( - f"Unsupported Fabric CUDA emit_static_stage_mode={request.emit_static_stage_mode!r}; " - f"supported modes are {sorted(SUPPORTED_CELL_STATIC_STAGE_MODES)}" - ) - if request.message_rule_lowering_kind not in SUPPORTED_MESSAGE_RULE_LOWERINGS: - raise RuntimeError( - f"Unsupported Fabric CUDA 
message_rule_lowering_kind={request.message_rule_lowering_kind!r}; " - f"supported lowerings are {sorted(SUPPORTED_MESSAGE_RULE_LOWERINGS)}" - ) - if request.message_rule_output_boundary not in SUPPORTED_MESSAGE_RULE_OUTPUT_BOUNDARIES: - raise RuntimeError( - f"Unsupported Fabric CUDA message_rule_output_boundary={request.message_rule_output_boundary!r}; " - f"supported boundaries are {sorted(SUPPORTED_MESSAGE_RULE_OUTPUT_BOUNDARIES)}" - ) - - -def _message_rule_launch_metadata(request: FabricExecutionRequest) -> dict[str, tuple[str, ...]]: - return { - "message_rule_names": (request.message_rule_name,), - "message_rule_lowering_kinds": (request.message_rule_lowering_kind,), - "message_rule_expression_signatures": (request.message_rule_expression_signature,), - "message_rule_source_signatures": (request.message_rule_source_signature,), - "message_rule_parameter_sharing_signatures": (request.message_rule_parameter_sharing_signature,), - "message_rule_output_boundaries": (request.message_rule_output_boundary,), - } - - -def _load_ext(): - global _EXT - if _EXT is not None: - return _EXT - native_registration_sources = sorted(glob(os.path.join(_CELL_NATIVE_ROOT, "*_registration.cu"))) - _EXT = safe_load_extension( - name="fabric_dispatcher_cuda", - sources=[ - os.path.join(_MOD_PATH, "dispatcher.cpp"), - os.path.join(_MOD_PATH, "receiver_owned_stepwise.cu"), - os.path.join(_MOD_PATH, "edge_owned_accumulate_stepwise.cu"), - os.path.join(_MOD_PATH, "readout_apply.cu"), - os.path.join(_MOD_PATH, "..", "ops", "dense_affine_kernels.cu"), - os.path.join(_MOD_PATH, "..", "ops", "dense_message_kernels.cu"), - os.path.join(_MOD_PATH, "..", "ops", "diagonal_recurrence_kernels.cu"), - os.path.join(_MOD_PATH, "..", "registry", "cell_dispatch_registry.cpp"), - os.path.join(_MOD_PATH, "..", "message_passing", "local_message_kernels.cu"), - os.path.join(_MOD_PATH, "..", "message_passing", "sparse_message_kernels.cu"), - *native_registration_sources, - ], - extra_cflags=["-O3"], - extra_cuda_cflags=["-O3", "-Xptxas", "-O3"], - extra_include_paths=[_SRC_ROOT], - extra_ldflags=["-lcublas"], - verbose=False, - ) - return _EXT - - -def last_launch_metadata() -> dict[str, tuple[Any, ...]]: - return dict(_LAST_LAUNCH_METADATA) - - -def last_forward_carry_checkpoints() -> ForwardCarryCheckpoints | None: - return _LAST_FORWARD_CARRY_CHECKPOINTS - - -def _contiguous_tensor(tensor: torch.Tensor, source: str) -> torch.Tensor: - with torch.profiler.record_function(source): - return tensor.contiguous() - - -def _contiguous_tuple(tensors: Sequence[torch.Tensor], source: str) -> tuple[torch.Tensor, ...]: - with torch.profiler.record_function(source): - out: list[torch.Tensor] = [] - for tensor in tensors: - out.append(tensor.contiguous()) - return tuple(out) - - -def _receiver_leading_param_window( - tensor: torch.Tensor | None, - request: FabricExecutionRequest, - *, - group_size: int = 1, -) -> torch.Tensor | None: - if tensor is None: - return None - start, count, active = _terminal_active_receiver_window(request) - full_receivers = int(request.routing_tensors["recurrent_q"].shape[0]) - if not active or tensor.dim() != 3 or int(tensor.shape[0]) == 1: - return tensor - if int(tensor.shape[0]) == full_receivers: - return tensor.narrow(0, start, count).contiguous() - if group_size > 1 and int(tensor.shape[0]) * group_size == full_receivers: - if start % group_size != 0 or count % group_size != 0: - raise RuntimeError("Compact receiver window must align with grouped receiver-major projection weights") - return 
tensor.narrow(0, start // group_size, count // group_size).contiguous() - return tensor - - -def _project_sender_kv_from_hidden( - hidden: torch.Tensor, - *, - direct_weight: torch.Tensor | None, - grouped_weight: torch.Tensor | None, - group_size: int, - head_dim: int, - value_dim: int, -) -> tuple[torch.Tensor, torch.Tensor]: - if ( - not torch.is_grad_enabled() - and hidden.is_cuda - and hidden.dtype == torch.float32 - and grouped_weight is not None - and group_size > 1 - and grouped_weight.is_cuda - and grouped_weight.dtype == torch.float32 - ): - kv_all = dense_affine_cuda( - hidden, - grouped_weight, - layout="receiver_major", - group_size=group_size, - ).output - return kv_all.split((head_dim, value_dim), dim=-1) - if ( - not torch.is_grad_enabled() - and hidden.is_cuda - and hidden.dtype == torch.float32 - and direct_weight is not None - and direct_weight.is_cuda - and direct_weight.dtype == torch.float32 - ): - kv_all = dense_affine_cuda(hidden, direct_weight, layout="receiver_major").output - return kv_all.split((head_dim, value_dim), dim=-1) - if grouped_weight is not None and group_size > 1: - batch_size, num_cells, hidden_size = hidden.shape - num_groups = int(grouped_weight.shape[0]) - grouped_cells = ( - hidden.reshape(batch_size, num_groups, group_size, hidden_size) - .permute(1, 0, 2, 3) - .reshape(num_groups, batch_size * group_size, hidden_size) - ) - kv_all = torch.bmm(grouped_cells, grouped_weight) - kv_all = ( - kv_all.reshape(num_groups, batch_size, group_size, head_dim + value_dim) - .permute(1, 0, 2, 3) - .reshape(batch_size, num_cells, head_dim + value_dim) - ) - else: - assert direct_weight is not None - if direct_weight.dim() == 3: - kv_all = torch.bmm(hidden.transpose(0, 1), direct_weight).transpose(0, 1) - else: - kv_all = torch.nn.functional.linear(hidden, direct_weight) - return kv_all.split((head_dim, value_dim), dim=-1) - - -def _build_local_csr( - *, - sender_idx: torch.Tensor, - valid: torch.Tensor, - distance: torch.Tensor, - delay: torch.Tensor | None, - distance_scale: float, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - with torch.profiler.record_function("fabric.glue.build_local_topology"): - if valid.dtype != torch.bool: - valid = valid.to(dtype=torch.bool) - counts = valid.to(dtype=torch.int32).sum(dim=1) - receiver_ptr = torch.zeros(valid.shape[0] + 1, device=sender_idx.device, dtype=torch.int32) - receiver_ptr[1:] = counts.cumsum(dim=0) - receiver_idx, slot_idx = torch.nonzero(valid, as_tuple=True) - flat_sender = sender_idx[receiver_idx, slot_idx].to(dtype=torch.int32) - flat_weight = distance.index_select(0, slot_idx).to(dtype=torch.float32) * float(distance_scale) - if delay is None or delay.numel() == 0: - flat_delay = torch.empty(0, device=sender_idx.device, dtype=torch.int32) - else: - flat_delay = delay.index_select(0, slot_idx).to(dtype=torch.int32) - return receiver_ptr, flat_sender, flat_delay, flat_weight - - -def _build_sparse_csr( - *, - neighbor_idx: torch.Tensor, - neighbor_valid: torch.Tensor, - edge_distance: torch.Tensor, - edge_delay: torch.Tensor | None, - distance_scale: float, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - with torch.profiler.record_function("fabric.glue.build_sparse_topology"): - valid = neighbor_valid.to(dtype=torch.bool) - counts = valid.to(dtype=torch.int32).sum(dim=1) - receiver_ptr = torch.zeros(valid.shape[0] + 1, device=neighbor_idx.device, dtype=torch.int32) - receiver_ptr[1:] = counts.cumsum(dim=0) - flat_sender = 
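The pure-PyTorch fallback above projects sender K/V banks from the hidden bank with receiver-grouped weights via a reshape/`bmm`/reshape round trip. A self-contained sketch of that fallback, with example dimensions:

```python
# Illustrative sketch only: grouped receiver-major K/V projection from the hidden bank.
# `grouped_weight` is assumed to be [num_groups, hidden, head_dim + value_dim].
import torch

def grouped_kv_projection(hidden, grouped_weight, group_size, head_dim, value_dim):
    B, R, H = hidden.shape
    G = grouped_weight.shape[0]                      # R == G * group_size
    cells = (hidden.reshape(B, G, group_size, H)
                   .permute(1, 0, 2, 3)
                   .reshape(G, B * group_size, H))
    kv = torch.bmm(cells, grouped_weight)            # [G, B*group_size, head_dim + value_dim]
    kv = (kv.reshape(G, B, group_size, head_dim + value_dim)
            .permute(1, 0, 2, 3)
            .reshape(B, R, head_dim + value_dim))
    return kv.split((head_dim, value_dim), dim=-1)

k, v = grouped_kv_projection(torch.randn(2, 8, 16), torch.randn(4, 16, 6 + 10),
                             group_size=2, head_dim=6, value_dim=10)
```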
neighbor_idx[valid].to(dtype=torch.int32) - flat_weight = edge_distance[valid].to(dtype=torch.float32) * float(distance_scale) - if edge_delay is None or edge_delay.numel() == 0: - flat_delay = torch.empty(0, device=neighbor_idx.device, dtype=torch.int32) - else: - flat_delay = edge_delay[valid].to(dtype=torch.int32) - return receiver_ptr, flat_sender, flat_delay, flat_weight - - -def _build_topologies(request: FabricExecutionRequest) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]: - distance_scale = float(request.static_config.get("distance_scale", 0.0)) - use_delay = bool(request.static_config.get("use_delay", False)) - if request.message_backend_name == "local": - recurrent = _build_local_csr( - sender_idx=request.routing_tensors["recurrent_local_sender_idx"], - valid=request.routing_tensors["recurrent_local_valid"], - distance=request.routing_tensors["local_distance"], - delay=request.routing_tensors.get("local_delay"), - distance_scale=distance_scale, - ) - else: - recurrent = _build_sparse_csr( - neighbor_idx=request.routing_tensors["recurrent_neighbor_idx"], - neighbor_valid=request.routing_tensors["recurrent_neighbor_valid"], - edge_distance=request.routing_tensors["recurrent_edge_distance"], - edge_delay=request.routing_tensors.get("recurrent_edge_delay"), - distance_scale=distance_scale, - ) - if request.readout_backend_name == "output_sequence_from_banks": - readout = _build_local_csr( - sender_idx=request.routing_tensors["output_local_sender_idx"], - valid=request.routing_tensors["output_local_valid"], - distance=request.routing_tensors["local_distance"], - delay=request.routing_tensors.get("local_delay"), - distance_scale=distance_scale, - ) - else: - readout = _build_sparse_csr( - neighbor_idx=request.routing_tensors["output_neighbor_idx"], - neighbor_valid=request.routing_tensors["output_neighbor_valid"], - edge_distance=request.routing_tensors["output_edge_distance"], - edge_delay=request.routing_tensors.get("output_edge_delay"), - distance_scale=distance_scale, - ) - if use_delay: - recurrent_delay = recurrent[2] - readout_delay = readout[2] - if (recurrent_delay.numel() > 0 and bool((recurrent_delay != 0).any())) or ( - readout_delay.numel() > 0 and bool((readout_delay != 0).any()) - ): - raise RuntimeError("Generic Fabric CUDA dispatcher does not support non-zero edge delay yet") - return recurrent, readout - - -def _pack_state_tree( - request: FabricExecutionRequest, -) -> tuple[tuple[torch.Tensor, ...], Callable[[Sequence[torch.Tensor]], Any]]: - state_leaves = tuple( - _slice_active_receiver_bank(tensor, request) - for tensor in _contiguous_tuple( - request.cell_core_spec.state_schema.flatten(request.packed_state), - "fabric.glue.launch_state_tree_contiguous", - ) - ) - - def rebuild(leaves: Sequence[torch.Tensor]) -> Any: - return request.cell_core_spec.state_schema.rebuild(leaves) - - return state_leaves, rebuild - - -def _terminal_active_receiver_window(request: FabricExecutionRequest) -> tuple[int, int, bool]: - output_boundary = str(request.static_config.get("output_boundary", "sequence")) - mode = str(request.static_config.get("terminal_active_receiver_window_mode", "full_surface")) - full_receivers = int(request.routing_tensors["recurrent_q"].shape[0]) - start = int(request.static_config.get("output_local_recurrent_window_start", 0)) - count = int(request.static_config.get("output_local_recurrent_window_count", 0)) - contiguous = bool(request.static_config.get("output_local_recurrent_window_contiguous", False)) - single_step_sequence_boundary = 
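The topology glue above flattens padded `[receivers, max_degree]` neighbor tables into a CSR-style triple of receiver pointers, flat sender indices, and scaled edge weights. A minimal sketch of that conversion, using a 2-D per-edge distance table as in the sparse variant:

```python
# Illustrative sketch only: padded neighbor table -> (receiver_ptr, flat_sender, flat_weight).
import torch

def padded_to_csr(sender_idx, valid, distance, distance_scale):
    valid = valid.to(torch.bool)
    counts = valid.to(torch.int32).sum(dim=1)
    receiver_ptr = torch.zeros(valid.shape[0] + 1, dtype=torch.int32)
    receiver_ptr[1:] = counts.cumsum(dim=0)
    receiver, slot = torch.nonzero(valid, as_tuple=True)
    flat_sender = sender_idx[receiver, slot].to(torch.int32)
    flat_weight = distance[receiver, slot].to(torch.float32) * float(distance_scale)
    return receiver_ptr, flat_sender, flat_weight

ptr, senders, weights = padded_to_csr(
    sender_idx=torch.tensor([[3, 1, 0], [2, 0, 0]]),
    valid=torch.tensor([[1, 1, 0], [1, 0, 0]]),
    distance=torch.tensor([[0.5, 1.0, 0.0], [2.0, 0.0, 0.0]]),
    distance_scale=0.1)
# ptr == [0, 2, 3]; senders == [3, 1, 2]; weights == [0.05, 0.10, 0.20]
```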
output_boundary == "sequence" and int(request.input_k_seq.shape[1]) == 1 - streaming_output_active_region = mode == "streaming_output_active_region" - active = bool( - (output_boundary == "terminal" or single_step_sequence_boundary or streaming_output_active_region) - and mode != "full_surface" - and (request.initial_state_is_fresh or request.compact_input_carry) - and not request.materialize_final_state - and contiguous - and start >= 0 - and count > 0 - and start + count <= full_receivers - and count < full_receivers - ) - return start, count, active - - -def _request_receiver_count(request: FabricExecutionRequest) -> int: - _start, count, active = _terminal_active_receiver_window(request) - return int(count if active else request.routing_tensors["recurrent_q"].shape[0]) - - -def _slice_active_receiver_bank( - tensor: torch.Tensor, - request: FabricExecutionRequest, -) -> torch.Tensor: - start, count, active = _terminal_active_receiver_window(request) - full_receivers = int(request.routing_tensors["recurrent_q"].shape[0]) - if not active or tensor.dim() < 2 or int(tensor.shape[1]) != full_receivers: - return tensor - return tensor.narrow(1, start, count).contiguous() - - -def _empty_backend_state_leaves( - request: FabricExecutionRequest, -) -> tuple[tuple[torch.Tensor, ...], Callable[[Sequence[torch.Tensor]], Any]]: - batch_size = int(request.input_k_seq.shape[0]) - receivers = _request_receiver_count(request) - hidden = int(request.initial_hidden.shape[-1]) - state_keys = request.cell_core_spec.state_schema.keys - with torch.profiler.record_function("fabric.glue.virtual_fresh_state_work_alloc"): - leaves = tuple(request.input_k_seq.new_empty(batch_size, receivers, hidden) for _state_key in state_keys) - - def rebuild(output_leaves: Sequence[torch.Tensor]) -> Any: - return request.cell_core_spec.state_schema.rebuild(output_leaves) - - return leaves, rebuild - - -def _virtual_backend_state_leaves() -> tuple[tuple[torch.Tensor, ...], Callable[[Sequence[torch.Tensor]], Any]]: - def rebuild(_output_leaves: Sequence[torch.Tensor]) -> Any: - return None - - return (), rebuild - - -def _zero_backend_state_leaves( - request: FabricExecutionRequest, -) -> tuple[tuple[torch.Tensor, ...], Callable[[Sequence[torch.Tensor]], Any]]: - batch_size = int(request.input_k_seq.shape[0]) - receivers = _request_receiver_count(request) - hidden = int(request.initial_hidden.shape[-1]) - state_keys = request.cell_core_spec.state_schema.keys - with torch.profiler.record_function("fabric.glue.backend_population_state_zero"): - leaves = tuple(request.input_k_seq.new_zeros(batch_size, receivers, hidden) for _state_key in state_keys) - - def rebuild(output_leaves: Sequence[torch.Tensor]) -> Any: - return request.cell_core_spec.state_schema.rebuild(output_leaves) - - return leaves, rebuild - - -def _forward_carry_checkpoint_state_selection( - request: FabricExecutionRequest, - state_output_leaves: Sequence[torch.Tensor], -) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...], tuple[str, ...], Callable[[Sequence[torch.Tensor]], Any]]: - state_keys = tuple(str(key) for key in request.cell_core_spec.state_schema.keys) - configured_state_names = tuple( - str(name) for name in request.static_config.get("forward_carry_checkpoint_state_names", ()) - ) - checkpoint_state_names = configured_state_names if configured_state_names else state_keys - state_index_by_name = {name: index for index, name in enumerate(state_keys)} - missing_names = tuple(name for name in checkpoint_state_names if name not in state_index_by_name) - if 
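When the terminal active receiver window applies, the helpers above narrow every receiver-major bank to a contiguous `[start, start + count)` slice along the receiver dimension. A short sketch of that narrowing:

```python
# Illustrative sketch only: slicing a [batch, receivers, hidden] bank to the active window.
import torch

def slice_active_receiver_bank(bank, start, count, active, full_receivers):
    if not active or bank.dim() < 2 or bank.shape[1] != full_receivers:
        return bank
    return bank.narrow(1, start, count).contiguous()

bank = torch.randn(2, 64, 32)
window = slice_active_receiver_bank(bank, start=48, count=16, active=True, full_receivers=64)
assert window.shape == (2, 16, 32)
```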
missing_names: - raise RuntimeError( - "Fabric forward carry checkpoint requested unknown state leaves: " + ", ".join(missing_names) - ) - source_indices = tuple(int(state_index_by_name[name]) for name in checkpoint_state_names) - if any(index >= len(state_output_leaves) for index in source_indices): - raise RuntimeError("Fabric forward carry checkpoint state source index exceeds materialized state leaf count") - checkpoint_state_leaves = tuple(state_output_leaves[index] for index in source_indices) - - def rebuild(output_leaves: Sequence[torch.Tensor]) -> Any: - if tuple(checkpoint_state_names) == state_keys: - return request.cell_core_spec.state_schema.rebuild(output_leaves) - return dict(zip(checkpoint_state_names, output_leaves, strict=True)) - - return checkpoint_state_leaves, source_indices, checkpoint_state_names, rebuild - - -def _pack_cell_params( - request: FabricExecutionRequest, -) -> tuple[torch.Tensor, ...]: - return _contiguous_tuple( - request.cell_core_spec.parameter_schema.flatten(request.cell_tensors), - "fabric.glue.launch_cell_param_contiguous", - ) - - -def _pack_input_projection_params( - request: FabricExecutionRequest, -) -> tuple[torch.Tensor, ...]: - return _contiguous_tuple( - request.cell_core_spec.input_projection_schema.flatten(request.cell_tensors), - "fabric.glue.launch_input_projection_param_contiguous", - ) - - -def _pack_public_projection_params( - request: FabricExecutionRequest, -) -> tuple[torch.Tensor, ...]: - return _contiguous_tuple( - request.cell_core_spec.public_projection_schema.flatten(request.cell_tensors), - "fabric.glue.launch_public_projection_param_contiguous", - ) - - -def _project_initial_public_banks( - request: FabricExecutionRequest, - *, - public_projection_params: Sequence[torch.Tensor], -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - hidden = _slice_active_receiver_bank( - _contiguous_tensor(request.initial_hidden, "fabric.glue.launch_initial_public_contiguous"), - request, - ) - head_dim = int(request.routing_tensors["recurrent_q"].shape[-1]) - value_dim = int(request.readout_tensors["output_projection_weight"].shape[1]) - if request.initial_state_is_fresh and int(hidden.shape[1]) == 0: - return ( - hidden, - hidden.new_empty(int(hidden.shape[0]), 0, head_dim), - hidden.new_empty(int(hidden.shape[0]), 0, value_dim), - ) - if request.initial_recurrent_k is not None and request.initial_recurrent_v is not None: - return ( - hidden, - _slice_active_receiver_bank( - _contiguous_tensor(request.initial_recurrent_k, "fabric.glue.launch_initial_public_contiguous"), - request, - ), - _slice_active_receiver_bank( - _contiguous_tensor(request.initial_recurrent_v, "fabric.glue.launch_initial_public_contiguous"), - request, - ), - ) - if request.cell_core_spec.public_schema.kind == "hidden": - direct_weight = public_projection_params[0] if len(public_projection_params) > 0 else None - grouped_weight = public_projection_params[1] if len(public_projection_params) > 1 else None - sender_group_size = int(request.static_config.get("sender_group_size", 1)) - direct_weight = _receiver_leading_param_window( - direct_weight if torch.is_tensor(direct_weight) else None, - request, - group_size=sender_group_size, - ) - grouped_weight = _receiver_leading_param_window( - grouped_weight if torch.is_tensor(grouped_weight) else None, - request, - group_size=sender_group_size, - ) - with torch.profiler.record_function("fabric.glue.launch_initial_public_project_from_hidden"): - recurrent_k, recurrent_v = _project_sender_kv_from_hidden( - hidden, - 
direct_weight=(direct_weight if torch.is_tensor(direct_weight) and direct_weight.numel() > 0 else None), - grouped_weight=( - grouped_weight if torch.is_tensor(grouped_weight) and grouped_weight.numel() > 0 else None - ), - group_size=sender_group_size, - head_dim=head_dim, - value_dim=value_dim, - ) - return ( - hidden, - _contiguous_tensor(recurrent_k, "fabric.glue.launch_initial_public_contiguous"), - _contiguous_tensor(recurrent_v, "fabric.glue.launch_initial_public_contiguous"), - ) - batch_size, receivers, _hidden = hidden.shape - with torch.profiler.record_function("fabric.glue.launch_initial_public_zero_kv"): - recurrent_k = hidden.new_zeros(batch_size, receivers, head_dim) - recurrent_v = hidden.new_zeros(batch_size, receivers, value_dim) - return hidden, recurrent_k, recurrent_v - - -def _public_output_buffers(request: FabricExecutionRequest) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - batch_size = int(request.input_k_seq.shape[0]) - receivers = _request_receiver_count(request) - hidden_dim = int(request.initial_hidden.shape[-1]) - head_dim = int(request.routing_tensors["recurrent_q"].shape[-1]) - value_dim = int(request.readout_tensors["output_projection_weight"].shape[1]) - return ( - request.input_k_seq.new_empty(batch_size, receivers, hidden_dim), - request.input_k_seq.new_empty(batch_size, receivers, head_dim), - request.input_k_seq.new_empty(batch_size, receivers, value_dim), - ) - - -def _empty_public_output_buffers(request: FabricExecutionRequest) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - batch_size = int(request.input_k_seq.shape[0]) - hidden_dim = int(request.initial_hidden.shape[-1]) - head_dim = int(request.routing_tensors["recurrent_q"].shape[-1]) - value_dim = int(request.readout_tensors["output_projection_weight"].shape[1]) - return ( - request.input_k_seq.new_empty(batch_size, 0, hidden_dim), - request.input_k_seq.new_empty(batch_size, 0, head_dim), - request.input_k_seq.new_empty(batch_size, 0, value_dim), - ) - - -def _budgeted_forward_carry_checkpoint_stride( - *, - state_leaves: Sequence[torch.Tensor], - public_hidden: torch.Tensor, - time_steps: int, -) -> tuple[int, str]: - time_steps = int(time_steps) - if time_steps <= 1 or public_hidden.device.type != "cuda": - return 0, "disabled:not_cuda_sequence" - carry_checkpoint_bytes = int(sum(tensor.numel() * tensor.element_size() for tensor in state_leaves)) - carry_checkpoint_bytes += int(public_hidden.numel() * public_hidden.element_size()) - carry_checkpoint_bytes = int(math.ceil(float(carry_checkpoint_bytes) * 1.15)) - if carry_checkpoint_bytes <= 0: - return 0, "disabled:zero_carry_checkpoint_bytes" - free_bytes, total_bytes = torch.cuda.mem_get_info(public_hidden.device) - runtime_reserve_bytes = max(4 << 30, int(total_bytes * 0.04)) - checkpoint_budget = max(0, min(int(total_bytes * 0.20), int(free_bytes) - int(runtime_reserve_bytes))) - compute_cap = max(1, int(math.ceil(math.sqrt(float(max(1, time_steps)))))) - max_checkpoints = min(int(compute_cap), int(checkpoint_budget // max(1, carry_checkpoint_bytes))) - if max_checkpoints <= 0: - return ( - 0, - f"disabled:budget;carry_checkpoint_bytes={carry_checkpoint_bytes};" - f"budget_bytes={checkpoint_budget};free_bytes={int(free_bytes)}", - ) - stride = max(1, int(math.floor(float(time_steps) / float(max_checkpoints + 1)))) - checkpoint_count = max(0, int(math.ceil(float(time_steps) / float(stride))) - 1) - while checkpoint_count > 0 and int(checkpoint_count) * int(carry_checkpoint_bytes) > int(checkpoint_budget): - stride += 1 - 
checkpoint_count = max(0, int(math.ceil(float(time_steps) / float(stride))) - 1) - if checkpoint_count <= 0: - return ( - 0, - f"disabled:post_budget;carry_checkpoint_bytes={carry_checkpoint_bytes};budget_bytes={checkpoint_budget}", - ) - return ( - int(stride), - f"enabled:dispatcher_budgeted;stride={int(stride)};count={int(checkpoint_count)};" - f"carry_checkpoint_bytes={carry_checkpoint_bytes};budget_bytes={checkpoint_budget};" - f"compute_cap={int(compute_cap)}", - ) - - -def normalize_launch_request(request: FabricExecutionRequest) -> FabricExecutionRequest: - _check_supported_launch_modes(request) - return request - - -def _dims( - *, - request: FabricExecutionRequest, - input_projection_params: Sequence[torch.Tensor], - public_projection_params: Sequence[torch.Tensor], - public_kind: str, -) -> tuple[int, int]: - input_weight = input_projection_params[0] - projected_message_dim = int(input_weight.shape[2] if input_weight.dim() == 3 else input_weight.shape[0]) - if public_kind == "hidden": - raw_public_dim = int(request.initial_hidden.shape[-1]) - else: - raw_public_dim = int(public_projection_params[0].shape[1]) - return projected_message_dim, raw_public_dim - - -def run_backend_dispatch_forward( - request: FabricExecutionRequest, - *, - spatial_ownership: str, - temporal_execution: str, -) -> tuple[torch.Tensor, Any, torch.Tensor, torch.Tensor, torch.Tensor]: - global _LAST_FORWARD_CARRY_CHECKPOINTS - _LAST_FORWARD_CARRY_CHECKPOINTS = None - request = normalize_launch_request(request) - device = request.input_k_seq.device - fresh_persistent_scan = temporal_execution == "persistent_scan" and request.initial_state_is_fresh - compact_input_persistent_scan = temporal_execution == "persistent_scan" and request.compact_input_carry - cell_params = _pack_cell_params(request) - input_projection_params = _pack_input_projection_params(request) - public_projection_params = _pack_public_projection_params(request) - projected_message_dim, raw_public_dim = _dims( - request=request, - input_projection_params=input_projection_params, - public_projection_params=public_projection_params, - public_kind=request.cell_core_spec.public_schema.kind, - ) - virtual_fresh_state_prev = False - if fresh_persistent_scan and request.packed_state is None: - virtual_fresh_state_prev = bool( - _load_ext().can_virtualize_fresh_state( - int(request.cell_core_spec.cell_kind), - tuple(cell_params), - int(request.routing_tensors["recurrent_q"].shape[0]), - int(projected_message_dim), - int(raw_public_dim), - ) - ) - state_output_required = request.materialize_final_state or int(request.input_k_seq.shape[1]) > 1 - if virtual_fresh_state_prev and not state_output_required: - state_leaves, rebuild_state = _virtual_backend_state_leaves() - elif virtual_fresh_state_prev: - state_leaves, rebuild_state = _empty_backend_state_leaves(request) - elif fresh_persistent_scan and request.packed_state is None: - state_leaves, rebuild_state = _empty_backend_state_leaves(request) - elif request.packed_state is None: - state_leaves, rebuild_state = _zero_backend_state_leaves(request) - else: - state_leaves, rebuild_state = _pack_state_tree(request) - single_step_nonfresh_persistent_scan = ( - temporal_execution == "persistent_scan" - and not request.initial_state_is_fresh - and request.input_k_seq.shape[1] == 1 - ) - reuse_single_step_inference_banks = single_step_nonfresh_persistent_scan and not request.gradient_enabled - public_output_required = request.materialize_final_state or int(request.input_k_seq.shape[1]) > 1 - 
reuse_single_step_public_banks = reuse_single_step_inference_banks and public_output_required - state_output_leaves = ( - tuple(state_leaves) - if reuse_single_step_inference_banks - else tuple(torch.empty_like(tensor) for tensor in state_leaves) - if state_leaves - else () - if not state_output_required - else _empty_backend_state_leaves(request)[0] - ) - state_work_leaves = ( - () - if fresh_persistent_scan or compact_input_persistent_scan or single_step_nonfresh_persistent_scan - else tuple(torch.empty_like(tensor) for tensor in state_leaves) - ) - hidden_prev, recurrent_k_prev, recurrent_v_prev = _project_initial_public_banks( - request, - public_projection_params=public_projection_params, - ) - virtual_fresh_public_prev = request.initial_state_is_fresh and int(hidden_prev.shape[1]) == 0 - public_init_leaves = (hidden_prev, recurrent_k_prev, recurrent_v_prev) - public_output_leaves = ( - tuple(public_init_leaves) - if reuse_single_step_public_banks - else _empty_public_output_buffers(request) - if not public_output_required - else _public_output_buffers(request) - ) - public_work_leaves = ( - () - if fresh_persistent_scan or compact_input_persistent_scan or single_step_nonfresh_persistent_scan - else tuple(torch.empty_like(tensor) for tensor in public_output_leaves) - ) - output_boundary = str(request.static_config.get("output_boundary", "sequence")) - if output_boundary not in {"sequence", "terminal"}: - raise ValueError(f"Unsupported Fabric CUDA output_boundary={output_boundary!r}") - output_time_steps = 1 if output_boundary == "terminal" else int(request.input_k_seq.shape[1]) - readout_pool = str(request.static_config.get("readout_pool", "flatten")) - readout_output_ports = ( - int(request.static_config.get("readout_slots", 1) or 1) - if readout_pool == "mean" - else int(request.routing_tensors["output_q"].shape[0]) - ) - output_seq = torch.empty( - request.input_k_seq.shape[0], - output_time_steps, - readout_output_ports, - request.readout_tensors["output_projection_bias"].shape[-1], - device=device, - dtype=request.input_k_seq.dtype, - ) - configured_forward_carry_checkpoint_stride = int( - request.static_config.get("forward_carry_checkpoint_stride", 0) or 0 - ) - forward_carry_checkpoint_stride = int(configured_forward_carry_checkpoint_stride) - forward_carry_checkpoint_reason = str(request.static_config.get("forward_carry_checkpoint_reason", "disabled")) - forward_carry_source_state_leaves: tuple[torch.Tensor, ...] = () - forward_carry_state_source_indices: tuple[int, ...] = () - forward_carry_checkpoint_state_names: tuple[str, ...] = () - _unused_forward_carry_state_leaves, forward_carry_rebuild_state = _virtual_backend_state_leaves() - del _unused_forward_carry_state_leaves - forward_carry_public_source_indices: tuple[int, ...] 
= () - forward_carry_checkpoint_possible = bool( - request.gradient_enabled - and temporal_execution == "persistent_scan" - and int(request.input_k_seq.shape[1]) > 1 - and configured_forward_carry_checkpoint_stride != 0 - ) - if forward_carry_checkpoint_possible: - ( - forward_carry_source_state_leaves, - forward_carry_state_source_indices, - forward_carry_checkpoint_state_names, - forward_carry_rebuild_state, - ) = _forward_carry_checkpoint_state_selection(request, state_output_leaves) - forward_carry_public_source_indices = (0,) - if configured_forward_carry_checkpoint_stride < 0: - forward_carry_checkpoint_stride, dispatcher_checkpoint_reason = _budgeted_forward_carry_checkpoint_stride( - state_leaves=forward_carry_source_state_leaves, - public_hidden=public_output_leaves[0], - time_steps=int(request.input_k_seq.shape[1]), - ) - forward_carry_checkpoint_reason = ( - f"{forward_carry_checkpoint_reason};{dispatcher_checkpoint_reason}" - if forward_carry_checkpoint_reason - else dispatcher_checkpoint_reason - ) - forward_carry_checkpoint_steps: tuple[int, ...] = () - forward_carry_state_leaves: tuple[torch.Tensor, ...] = () - forward_carry_public_leaves: tuple[torch.Tensor, ...] = () - if ( - request.gradient_enabled - and temporal_execution == "persistent_scan" - and forward_carry_checkpoint_stride > 0 - and int(request.input_k_seq.shape[1]) > 1 - ): - forward_carry_checkpoint_steps = tuple( - range( - int(forward_carry_checkpoint_stride), - int(request.input_k_seq.shape[1]), - int(forward_carry_checkpoint_stride), - ) - ) - if forward_carry_checkpoint_steps: - checkpoint_count = len(forward_carry_checkpoint_steps) - with torch.profiler.record_function("fabric.glue.forward_carry_checkpoint_alloc"): - forward_carry_state_leaves = tuple( - tensor.new_empty((checkpoint_count, *tuple(tensor.shape))) - for tensor in forward_carry_source_state_leaves - ) - forward_carry_public_leaves = ( - public_output_leaves[0].new_empty((checkpoint_count, *tuple(public_output_leaves[0].shape))), - ) - - state_prev_tree: Any = state_leaves - public_prev_tree: Sequence[torch.Tensor] = public_init_leaves - state_next_tree: Sequence[torch.Tensor] = state_output_leaves - public_next_tree: Sequence[torch.Tensor] = public_output_leaves - final_state_leaves = state_output_leaves - final_public_leaves = public_output_leaves - if temporal_execution == "persistent_scan": - if request.initial_state_is_fresh or request.compact_input_carry: - state_prev_tree = state_leaves - public_prev_tree = public_init_leaves - state_next_tree = state_output_leaves - public_next_tree = public_output_leaves - if request.input_k_seq.shape[1] % 2 == 0: - final_state_leaves = state_leaves - final_public_leaves = public_init_leaves - else: - final_state_leaves = state_output_leaves - final_public_leaves = public_output_leaves - elif request.input_k_seq.shape[1] == 1: - state_prev_tree = state_leaves - public_prev_tree = public_init_leaves - state_next_tree = state_output_leaves - public_next_tree = public_output_leaves - final_state_leaves = state_output_leaves - final_public_leaves = public_output_leaves - else: - with torch.profiler.record_function("fabric.glue.launch_persistent_scan_initial_copy"): - for dst, src in zip(state_output_leaves, state_leaves, strict=True): - dst.copy_(src) - for dst, src in zip(public_output_leaves, public_init_leaves, strict=True): - dst.copy_(src) - state_prev_tree = state_output_leaves - public_prev_tree = public_output_leaves - state_next_tree = state_work_leaves - public_next_tree = public_work_leaves - 
if request.input_k_seq.shape[1] % 2 == 0: - final_state_leaves = state_output_leaves - final_public_leaves = public_output_leaves - else: - final_state_leaves = state_work_leaves - final_public_leaves = public_work_leaves - - if state_prev_tree: - _, _state_prev_leaves, state_prev_pack = pack_tensor_tree(state_prev_tree) - else: - _state_prev_leaves = () - state_prev_pack = empty_packed_tensor_table(device=device) - if state_next_tree: - _, _state_next_leaves, state_next_pack = pack_tensor_tree(state_next_tree) - else: - _state_next_leaves = () - state_next_pack = empty_packed_tensor_table(device=device) - _, _public_prev_leaves, public_prev_pack = pack_tensor_tree(public_prev_tree) - _, _public_next_leaves, public_next_pack = pack_tensor_tree(public_next_tree) - if state_work_leaves: - _, _state_work_leaves, state_work_pack = pack_tensor_tree(state_work_leaves) - else: - _state_work_leaves = () - state_work_pack = empty_packed_tensor_table(device=device) - if public_work_leaves: - _, _public_work_leaves, public_work_pack = pack_tensor_tree(public_work_leaves) - else: - _public_work_leaves = () - public_work_pack = empty_packed_tensor_table(device=device) - _, _, cell_params_pack = pack_tensor_tree(cell_params) - _, _, input_projection_pack = pack_tensor_tree(input_projection_params) - _, _, public_projection_pack = pack_tensor_tree(public_projection_params) - _, message_param_leaves, message_params_pack = pack_tensor_tree( - ( - _contiguous_tensor( - request.routing_tensors["recurrent_q"], - "fabric.glue.launch_message_param_contiguous", - ), - ) - ) - _, input_port_leaves, input_ports_pack = pack_tensor_tree( - _contiguous_tuple( - (request.input_k_seq, request.input_v_seq), - "fabric.glue.launch_input_ports_contiguous", - ) - ) - aux_pack = empty_packed_tensor_table(device=device) - _, _, readout_output_pack = pack_tensor_tree((output_seq,)) - _, _readout_param_leaves, readout_params_pack = pack_tensor_tree( - _contiguous_tuple( - ( - request.routing_tensors["output_q"], - request.readout_tensors["output_projection_weight"], - request.readout_tensors["output_projection_bias"], - ), - "fabric.glue.launch_readout_param_contiguous", - ) - ) - - recurrent_topology, readout_topology = _build_topologies(request) - public_prev_tensors = _contiguous_tuple(public_prev_tree, "fabric.glue.launch_public_tree_contiguous") - public_next_tensors = _contiguous_tuple(public_next_tree, "fabric.glue.launch_public_tree_contiguous") - public_work_tensors = ( - tuple(public_work_leaves) - if not public_work_leaves - else _contiguous_tuple(public_work_leaves, "fabric.glue.launch_public_tree_contiguous") - ) - empty_i32 = torch.empty(0, device=device, dtype=torch.int32) - empty_i64 = torch.empty(0, device=device, dtype=torch.int64) - empty_bool = torch.empty(0, device=device, dtype=torch.bool) - empty_f32 = torch.empty(0, device=device, dtype=torch.float32) - - global _LAST_LAUNCH_METADATA - launch_metadata = _load_ext().forward( - int(request.cell_core_spec.cell_kind), - int(_MESSAGE_BACKEND_ID[request.message_backend_name]), - 0 if spatial_ownership == "receiver_owned" else 1, - 0 if temporal_execution == "stepwise" else 1, - int(_PUBLIC_PROJECTION_KIND[request.cell_core_spec.public_schema.kind]), - int(projected_message_dim), - int(raw_public_dim), - state_prev_pack.as_extension_tuple(), - tuple(_state_prev_leaves), - state_next_pack.as_extension_tuple(), - tuple(_state_next_leaves), - public_prev_pack.as_extension_tuple(), - public_next_pack.as_extension_tuple(), - state_work_pack.as_extension_tuple(), - 
public_work_pack.as_extension_tuple(), - cell_params_pack.as_extension_tuple(), - tuple(cell_params), - input_projection_pack.as_extension_tuple(), - tuple(input_projection_params), - public_projection_pack.as_extension_tuple(), - tuple(public_projection_params), - message_params_pack.as_extension_tuple(), - input_ports_pack.as_extension_tuple(), - aux_pack.as_extension_tuple(), - readout_output_pack.as_extension_tuple(), - tuple((output_seq,)), - readout_params_pack.as_extension_tuple(), - tuple(_readout_param_leaves), - tuple(input_port_leaves), - public_prev_tensors, - public_next_tensors, - public_work_tensors, - message_param_leaves[0], - _contiguous_tensor( - request.routing_tensors["recurrent_local_sender_idx"], - "fabric.glue.launch_routing_tensor_contiguous", - ), - _contiguous_tensor(request.routing_tensors["local_distance"], "fabric.glue.launch_routing_tensor_contiguous"), - _contiguous_tensor( - request.routing_tensors.get("local_delay", empty_i32), - "fabric.glue.launch_routing_tensor_contiguous", - ), - _contiguous_tensor( - request.routing_tensors.get("recurrent_neighbor_idx", empty_i32), - "fabric.glue.launch_routing_tensor_contiguous", - ), - _contiguous_tensor( - request.routing_tensors.get("recurrent_neighbor_valid", empty_bool), - "fabric.glue.launch_routing_tensor_contiguous", - ), - _contiguous_tensor( - request.routing_tensors.get("recurrent_edge_distance", empty_f32), - "fabric.glue.launch_routing_tensor_contiguous", - ), - _contiguous_tensor( - request.routing_tensors.get("recurrent_edge_delay", empty_i32), - "fabric.glue.launch_routing_tensor_contiguous", - ), - _contiguous_tensor( - request.routing_tensors.get("recurrent_sparse_receiver_order", empty_i64), - "fabric.glue.launch_routing_tensor_contiguous", - ), - _contiguous_tensor( - request.routing_tensors.get("recurrent_sparse_degree_ptr", empty_i64), - "fabric.glue.launch_routing_tensor_contiguous", - ), - int(request.static_config.get("recurrent_sparse_positive_degree_buckets", 0)), - float(request.static_config.get("distance_scale", 0.0)), - _contiguous_tensor(recurrent_topology[0], "fabric.glue.launch_topology_contiguous"), - _contiguous_tensor(recurrent_topology[1], "fabric.glue.launch_topology_contiguous"), - _contiguous_tensor(recurrent_topology[2], "fabric.glue.launch_topology_contiguous"), - _contiguous_tensor(recurrent_topology[3], "fabric.glue.launch_topology_contiguous"), - _contiguous_tensor(readout_topology[0], "fabric.glue.launch_topology_contiguous"), - _contiguous_tensor(readout_topology[1], "fabric.glue.launch_topology_contiguous"), - _contiguous_tensor(readout_topology[2], "fabric.glue.launch_topology_contiguous"), - _contiguous_tensor(readout_topology[3], "fabric.glue.launch_topology_contiguous"), - _contiguous_tensor(request.resets_u8, "fabric.glue.launch_resets_contiguous"), - int( - request.static_config.get( - "output_local_recurrent_window_start" - if request.readout_backend_name == "output_sequence_from_banks" - else "output_sparse_recurrent_window_start", - 0, - ) - ), - int( - request.static_config.get( - "output_local_recurrent_window_count" - if request.readout_backend_name == "output_sequence_from_banks" - else "output_sparse_recurrent_window_count", - 0, - ) - ), - bool( - request.static_config.get( - "output_local_recurrent_window_contiguous" - if request.readout_backend_name == "output_sequence_from_banks" - else "output_sparse_recurrent_window_contiguous", - False, - ) - ), - tuple(forward_carry_state_leaves), - tuple(forward_carry_public_leaves), - 
tuple(forward_carry_state_source_indices), - tuple(forward_carry_public_source_indices), - int(forward_carry_checkpoint_stride), - int(request.input_k_seq.shape[2]), - int(request.input_k_seq.shape[0]), - int(request.input_k_seq.shape[1]), - int(request.routing_tensors["recurrent_q"].shape[0]), - int(recurrent_topology[1].numel()), - int(request.routing_tensors["output_q"].shape[0]), - int(request.input_v_seq.shape[-1]), - int(request.initial_hidden.shape[-1]), - int(max(1, request.receiver_tile)), - int(max(1, request.batch_tile)), - int(max(1, request.edge_tile)), - int(max(1, request.hidden_chunk)), - int(max(1, request.state_receiver_tile)), - int(max(1, request.state_batch_tile)), - int(max(1, request.state_hidden_chunk)), - int(_CELL_STATIC_STAGE_MODE[request.state_static_stage_mode]), - int(max(1, request.emit_receiver_tile)), - int(max(1, request.emit_batch_tile)), - int(max(1, request.emit_hidden_chunk)), - int(_CELL_STATIC_STAGE_MODE[request.emit_static_stage_mode]), - int(max(1, request.public_receiver_tile)), - int(max(1, request.public_batch_tile)), - int(_READOUT_MODE[request.readout_mode]), - int(max(1, request.readout_port_tile)), - int(max(1, request.readout_output_chunk)), - int(_CELL_STATIC_STAGE_MODE[request.cell_static_stage_mode]), - int(max(1, request.replication_factor)), - bool(request.stage_receiver_static), - bool(request.initial_state_is_fresh), - True, - bool(request.materialize_final_state), - bool(request.preserve_internal_carry), - bool(request.compact_input_carry), - ) - _LAST_LAUNCH_METADATA = { - str(key): tuple(value) if isinstance(value, tuple) else (value,) for key, value in dict(launch_metadata).items() - } - _LAST_LAUNCH_METADATA.update(_message_rule_launch_metadata(request)) - if forward_carry_checkpoint_steps: - _LAST_FORWARD_CARRY_CHECKPOINTS = ForwardCarryCheckpoints( - stride=int(forward_carry_checkpoint_stride), - steps=forward_carry_checkpoint_steps, - state_tensors=forward_carry_state_leaves, - public_tensors=forward_carry_public_leaves, - rebuild_state=forward_carry_rebuild_state, - state_names=forward_carry_checkpoint_state_names, - ) - _LAST_LAUNCH_METADATA["workspace_aliases"] = tuple(_LAST_LAUNCH_METADATA.get("workspace_aliases", ())) + ( - f"forward_carry_checkpoint_stride:t={int(forward_carry_checkpoint_stride)}", - f"forward_carry_checkpoint_count:n={len(forward_carry_checkpoint_steps)}", - "forward_carry_checkpoint_public:hidden_only_recompute_kv", - "forward_carry_checkpoint_state_names:" + ",".join(forward_carry_checkpoint_state_names), - ) - _LAST_LAUNCH_METADATA["generic_glue_fusion_modes"] = tuple( - _LAST_LAUNCH_METADATA.get("generic_glue_fusion_modes", ()) - ) + ("forward_carry_checkpoint_tape",) - elif request.gradient_enabled and temporal_execution == "persistent_scan": - _LAST_LAUNCH_METADATA["workspace_aliases"] = tuple(_LAST_LAUNCH_METADATA.get("workspace_aliases", ())) + ( - "forward_carry_checkpoint:disabled:" + forward_carry_checkpoint_reason, - ) - if reuse_single_step_inference_banks: - _LAST_LAUNCH_METADATA["generic_glue_fusion_modes"] = tuple( - _LAST_LAUNCH_METADATA.get("generic_glue_fusion_modes", ()) - ) + ("single_step_inference_state_public_bank_reuse",) - reuse_aliases = ("state_next=state_prev[single_step_inference]",) - if reuse_single_step_public_banks: - reuse_aliases = reuse_aliases + ("public_next=public_prev[single_step_inference]",) - _LAST_LAUNCH_METADATA["workspace_aliases"] = ( - tuple(_LAST_LAUNCH_METADATA.get("workspace_aliases", ())) + reuse_aliases - ) - if virtual_fresh_public_prev: - 
_LAST_LAUNCH_METADATA["generic_glue_fusion_modes"] = tuple( - _LAST_LAUNCH_METADATA.get("generic_glue_fusion_modes", ()) - ) + ("fresh_public_prev_virtual_zero",) - _LAST_LAUNCH_METADATA["workspace_aliases"] = tuple(_LAST_LAUNCH_METADATA.get("workspace_aliases", ())) + ( - "public_prev=fresh_zero_contract", - ) - workspace_aliases = tuple(_LAST_LAUNCH_METADATA.get("workspace_aliases", ())) - readout_materialized_public_window = "readout_public_projection=materialized_active_window" in workspace_aliases - if not public_output_required: - readout_mode = ( - "readout_uses_materialized_public_window" - if readout_materialized_public_window - else "readout_uses_raw_public_projection" - ) - public_next_alias = ( - "public_next=not_materialized[readout_materialized_active_window]" - if readout_materialized_public_window - else "public_next=not_materialized[readout_raw_public_projection]" - ) - _LAST_LAUNCH_METADATA["generic_glue_fusion_modes"] = tuple( - _LAST_LAUNCH_METADATA.get("generic_glue_fusion_modes", ()) - ) + ("public_next_not_materialized", readout_mode) - _LAST_LAUNCH_METADATA["workspace_aliases"] = tuple(_LAST_LAUNCH_METADATA.get("workspace_aliases", ())) + ( - public_next_alias, - ) - _LAST_LAUNCH_METADATA["state_affine_reset_rows_present"] = ("true" if request.reset_rows_present else "false",) - next_packed_state = ( - rebuild_state(final_state_leaves) - if request.materialize_final_state or request.preserve_internal_carry - else None - ) - recurrent_hidden = final_public_leaves[0] - recurrent_k = final_public_leaves[1] - recurrent_v = final_public_leaves[2] - return output_seq, next_packed_state, recurrent_hidden, recurrent_k, recurrent_v - - -__all__ = ["last_forward_carry_checkpoints", "last_launch_metadata", "run_backend_dispatch_forward"] diff --git a/src/cortical/fabric/backend/cuda/execution/edge_owned_accumulate_stepwise.cu b/src/cortical/fabric/backend/cuda/execution/edge_owned_accumulate_stepwise.cu deleted file mode 100644 index 1da6f1ab..00000000 --- a/src/cortical/fabric/backend/cuda/execution/edge_owned_accumulate_stepwise.cu +++ /dev/null @@ -1,197 +0,0 @@ -#include -#include -#include -#include - -#include "cortical/fabric/backend/cuda/execution/common.cuh" -#include "cortical/fabric/backend/cuda/message_passing/local_message_backend.cuh" -#include "cortical/fabric/backend/cuda/message_passing/sparse_message_backend.cuh" - -namespace fabric { - -namespace { - -constexpr int kMaxWarpsPerBlock = 32; - -__device__ inline void atomic_max_float(float* address, float value) { - int* address_as_int = reinterpret_cast(address); - int old = *address_as_int; - while (__int_as_float(old) < value) { - const int assumed = old; - old = atomicCAS(address_as_int, assumed, __float_as_int(value)); - if (old == assumed) { - break; - } - } -} - -__device__ inline int elected_edge_lane(int edge) { - return edge & (kWarpSize - 1); -} - -template -__global__ void edge_owned_max_logit_stepwise_kernel( - TensorTable public_prev, - TensorTable message_params, - TensorTable input_ports, - MessageTopology topo, - float* max_buffer, - ExecutionPlan plan, - const uint8_t* resets_u8, - int t) { - const int warp = threadIdx.y; - const int lane = threadIdx.x; - const int edge_tile = max(1, plan.edge_tile); - const int batch_tile = max(1, plan.batch_tile); - const int e_local = warp / batch_tile; - const int b_local = warp % batch_tile; - const int edge = blockIdx.x * edge_tile + e_local; - for (int bt = blockIdx.y; bt < ceil_div(plan.B, batch_tile); bt += gridDim.y) { - const int b = bt * batch_tile 
+ b_local; - if (edge >= topo.num_edges || b >= plan.B) { - continue; - } - const bool reset_row = resets_u8 != nullptr ? resets_u8[static_cast(b) * plan.T + t] != 0 : false; - int receiver = -1; - const float logit = MessageBackend::edge_logit_step_warp( - b, - t, - edge, - reset_row, - topo, - input_ports, - public_prev, - message_params, - &receiver, - lane); - if (lane == elected_edge_lane(edge) && receiver >= 0 && isfinite(logit)) { - atomic_max_float(&max_buffer[static_cast(b) * plan.receivers + receiver], logit); - } - } -} - -template -__global__ void edge_owned_accumulate_stepwise_kernel( - TensorTable public_prev, - TensorTable message_params, - TensorTable input_ports, - MessageTopology topo, - const float* max_buffer, - float* msg_buffer, - ExecutionPlan plan, - const uint8_t* resets_u8, - int t) { - const int warp = threadIdx.y; - const int lane = threadIdx.x; - const int edge_tile = max(1, plan.edge_tile); - const int batch_tile = max(1, plan.batch_tile); - const int e_local = warp / batch_tile; - const int b_local = warp % batch_tile; - const int edge = blockIdx.x * edge_tile + e_local; - for (int bt = blockIdx.y; bt < ceil_div(plan.B, batch_tile); bt += gridDim.y) { - const int b = bt * batch_tile + b_local; - if (edge >= topo.num_edges || b >= plan.B) { - continue; - } - const bool reset_row = resets_u8 != nullptr ? resets_u8[static_cast(b) * plan.T + t] != 0 : false; - int receiver = -1; - const float logit = MessageBackend::edge_logit_step_warp( - b, - t, - edge, - reset_row, - topo, - input_ports, - public_prev, - message_params, - &receiver, - lane); - if (receiver < 0) { - continue; - } - const float max_logit = max_buffer[static_cast(b) * plan.receivers + receiver]; - const float weight = (isfinite(logit) && isfinite(max_logit)) ? expf(logit - max_logit) : 0.0f; - const int64_t base = (static_cast(b) * plan.receivers + receiver) * (plan.message_dim + 1); - for (int d = lane; d < plan.message_dim; d += kWarpSize) { - const float contrib = MessageBackend::edge_value_component_step( - b, t, edge, reset_row, topo, input_ports, public_prev, d); - atomicAdd(&msg_buffer[base + d], weight * contrib); - } - if (lane == elected_edge_lane(edge)) { - atomicAdd(&msg_buffer[base + plan.message_dim], weight); - } - } -} - -inline void validate_edge_owned_launch(const ExecutionPlan& plan) { - TORCH_CHECK(plan.edge_tile > 0, "Fabric edge-owned accumulation requires edge_tile > 0"); - TORCH_CHECK(plan.batch_tile > 0, "Fabric edge-owned accumulation requires batch_tile > 0"); - TORCH_CHECK( - plan.edge_tile * plan.batch_tile <= kMaxWarpsPerBlock, - "Fabric edge-owned accumulation launch requested too many warps per block: ", - plan.edge_tile * plan.batch_tile); -} - -template -void launch_edge_owned_accumulate_stepwise_impl( - TensorTable public_prev, - TensorTable message_params, - TensorTable input_ports, - MessageTopology topo, - float* max_buffer, - float* msg_buffer, - ExecutionPlan plan, - const at::Tensor& resets_u8, - int t, - cudaStream_t stream) { - validate_edge_owned_launch(plan); - const int warps_per_block = plan.edge_tile * plan.batch_tile; - const dim3 block(kWarpSize, warps_per_block); - const dim3 grid( - ceil_div(plan.edges, plan.edge_tile), - plan.replication_factor > 0 ? plan.replication_factor : 1); - edge_owned_max_logit_stepwise_kernel<<>>( - public_prev, - message_params, - input_ports, - topo, - max_buffer, - plan, - resets_u8.defined() ? 
resets_u8.data_ptr() : nullptr, - t); - edge_owned_accumulate_stepwise_kernel<<>>( - public_prev, - message_params, - input_ports, - topo, - max_buffer, - msg_buffer, - plan, - resets_u8.defined() ? resets_u8.data_ptr() : nullptr, - t); -} - -} // namespace - -void launch_edge_owned_accumulate_stepwise_cuda( - int message_backend_id, - TensorTable public_prev, - TensorTable message_params, - TensorTable input_ports, - MessageTopology topo, - float* max_buffer, - float* msg_buffer, - ExecutionPlan plan, - const at::Tensor& resets_u8, - int t, - cudaStream_t stream) { - if (message_backend_id == 1) { - launch_edge_owned_accumulate_stepwise_impl( - public_prev, message_params, input_ports, topo, max_buffer, msg_buffer, plan, resets_u8, t, stream); - return; - } - launch_edge_owned_accumulate_stepwise_impl( - public_prev, message_params, input_ports, topo, max_buffer, msg_buffer, plan, resets_u8, t, stream); -} - -} // namespace fabric diff --git a/src/cortical/fabric/backend/cuda/execution/output_readout_cuda.py b/src/cortical/fabric/backend/cuda/execution/output_readout_cuda.py deleted file mode 100644 index 8b9b7b85..00000000 --- a/src/cortical/fabric/backend/cuda/execution/output_readout_cuda.py +++ /dev/null @@ -1,180 +0,0 @@ -from __future__ import annotations - -import torch - -from cortical.fabric.backend.cuda.message_passing.local_message_cuda import fabric_local_message_partitioned_cuda -from cortical.fabric.backend.cuda.message_passing.sparse_message_cuda import fabric_sparse_message_partitioned_cuda -from cortical.fabric.backend.cuda.ops import dense_affine_cuda -from cortical.fabric.backend.cuda.projection.registry import register_readout_backend - - -def _contiguous_readout_args(tensors: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]: - with torch.profiler.record_function("fabric.glue.output_readout_arg_contiguous"): - return tuple(tensor.contiguous() for tensor in tensors) - - -def _readout_step_index(batch_time: int, *, device: torch.device) -> torch.Tensor: - with torch.profiler.record_function("fabric.glue.output_readout_step_index"): - return torch.ones(batch_time, device=device, dtype=torch.long) - - -def _project_output_message_dense( - output_msg_flat: torch.Tensor, - *, - value_to_output_weight: torch.Tensor, - output_cell_bias: torch.Tensor, -) -> torch.Tensor: - if ( - not torch.is_grad_enabled() - and output_msg_flat.is_cuda - and output_msg_flat.dtype == torch.float32 - and value_to_output_weight.is_cuda - and value_to_output_weight.dtype == torch.float32 - and output_cell_bias.is_cuda - and output_cell_bias.dtype == torch.float32 - ): - return dense_affine_cuda( - output_msg_flat, - value_to_output_weight, - output_cell_bias, - layout="receiver_major", - ).output - output_cells_flat = torch.bmm( - output_msg_flat.transpose(0, 1), - value_to_output_weight, - ).transpose(0, 1) - return output_cells_flat + output_cell_bias.view(1, -1, output_cell_bias.shape[-1]) - - -def project_output_sequence_from_banks( - *, - input_k_seq: torch.Tensor, - input_v_seq: torch.Tensor, - recurrent_k_seq: torch.Tensor, - recurrent_v_seq: torch.Tensor, - output_q: torch.Tensor, - output_local_sender_idx: torch.Tensor, - output_local_receiver_idx_by_sender: torch.Tensor, - local_distance: torch.Tensor, - local_delay: torch.Tensor, - value_to_output_weight: torch.Tensor, - output_cell_bias: torch.Tensor, - distance_scale: float, -) -> torch.Tensor: - batch_size, time_steps = input_k_seq.shape[:2] - input_senders = int(input_k_seq.shape[2]) - ( - output_q_arg, - input_k_arg, - 
input_v_arg, - recurrent_k_arg, - recurrent_v_arg, - output_local_sender_idx_arg, - output_local_receiver_idx_by_sender_arg, - local_distance_arg, - local_delay_arg, - ) = _contiguous_readout_args( - ( - output_q, - input_k_seq.reshape(batch_size * time_steps, input_senders, input_k_seq.shape[-1]), - input_v_seq.reshape(batch_size * time_steps, input_senders, input_v_seq.shape[-1]), - recurrent_k_seq.reshape(batch_size * time_steps, recurrent_k_seq.shape[2], recurrent_k_seq.shape[-1]), - recurrent_v_seq.reshape(batch_size * time_steps, recurrent_v_seq.shape[2], recurrent_v_seq.shape[-1]), - output_local_sender_idx, - output_local_receiver_idx_by_sender, - local_distance, - local_delay, - ) - ) - output_msg_flat = fabric_local_message_partitioned_cuda( - output_q_arg, - input_k_arg, - input_v_arg, - recurrent_k_arg, - recurrent_v_arg, - output_local_sender_idx_arg, - output_local_receiver_idx_by_sender_arg, - local_distance_arg, - local_delay_arg, - _readout_step_index(batch_size * time_steps, device=input_k_seq.device), - num_input_senders=input_senders, - distance_scale=float(distance_scale), - use_delay=False, - ) - output_cells_flat = _project_output_message_dense( - output_msg_flat, - value_to_output_weight=value_to_output_weight, - output_cell_bias=output_cell_bias, - ) - return output_cells_flat.view(batch_size, time_steps, output_cells_flat.shape[1], output_cells_flat.shape[2]) - - -def project_output_sequence_from_sparse_banks( - *, - input_k_seq: torch.Tensor, - input_v_seq: torch.Tensor, - recurrent_k_seq: torch.Tensor, - recurrent_v_seq: torch.Tensor, - output_q: torch.Tensor, - output_neighbor_idx: torch.Tensor, - output_neighbor_valid: torch.Tensor, - output_edge_distance: torch.Tensor, - output_edge_delay: torch.Tensor, - value_to_output_weight: torch.Tensor, - output_cell_bias: torch.Tensor, - distance_scale: float, - use_delay: bool, -) -> torch.Tensor: - batch_size, time_steps = input_k_seq.shape[:2] - ( - output_q_arg, - input_k_arg, - input_v_arg, - recurrent_k_arg, - recurrent_v_arg, - output_neighbor_idx_arg, - output_neighbor_valid_arg, - output_edge_distance_arg, - output_edge_delay_arg, - ) = _contiguous_readout_args( - ( - output_q, - input_k_seq.reshape(batch_size * time_steps, input_k_seq.shape[2], input_k_seq.shape[-1]), - input_v_seq.reshape(batch_size * time_steps, input_v_seq.shape[2], input_v_seq.shape[-1]), - recurrent_k_seq.reshape(batch_size * time_steps, recurrent_k_seq.shape[2], recurrent_k_seq.shape[-1]), - recurrent_v_seq.reshape(batch_size * time_steps, recurrent_v_seq.shape[2], recurrent_v_seq.shape[-1]), - output_neighbor_idx, - output_neighbor_valid, - output_edge_distance, - output_edge_delay, - ) - ) - output_msg_flat = fabric_sparse_message_partitioned_cuda( - output_q_arg, - input_k_arg, - input_v_arg, - recurrent_k_arg, - recurrent_v_arg, - output_neighbor_idx_arg, - output_neighbor_valid_arg, - output_edge_distance_arg, - output_edge_delay_arg, - _readout_step_index(batch_size * time_steps, device=input_k_seq.device), - distance_scale=float(distance_scale), - use_delay=bool(use_delay), - ) - output_cells_flat = _project_output_message_dense( - output_msg_flat, - value_to_output_weight=value_to_output_weight, - output_cell_bias=output_cell_bias, - ) - return output_cells_flat.view(batch_size, time_steps, output_cells_flat.shape[1], output_cells_flat.shape[2]) - - -__all__ = [ - "project_output_sequence_from_banks", - "project_output_sequence_from_sparse_banks", -] - -register_readout_backend("output_sequence_from_banks", 
project_output_sequence_from_banks) -register_readout_backend("output_sequence_from_sparse_banks", project_output_sequence_from_sparse_banks) diff --git a/src/cortical/fabric/backend/cuda/execution/readout_apply.cu b/src/cortical/fabric/backend/cuda/execution/readout_apply.cu deleted file mode 100644 index 652957ff..00000000 --- a/src/cortical/fabric/backend/cuda/execution/readout_apply.cu +++ /dev/null @@ -1,560 +0,0 @@ -#include -#include -#include -#include - -#include - -#include "cortical/fabric/backend/cuda/execution/common.cuh" -#include "cortical/fabric/backend/cuda/projection/linear_readout_backend.cuh" -#include "cortical/fabric/backend/cuda/projection/public_projection_backends.cuh" - -namespace fabric { - -namespace { - -constexpr int kMaxWarpsPerBlock = 32; -constexpr int kMaxGridDimY = 65535; - -int readout_batch_grid_tiles(const ExecutionPlan& plan) { - const int batch_tile = max(1, plan.batch_tile); - const int batch_blocks = ceil_div(plan.B, batch_tile); - if (batch_blocks <= 0) { - return 1; - } - return batch_blocks < kMaxGridDimY ? batch_blocks : kMaxGridDimY; -} - -int readout_subwarp_width(int head_dim, int value_dim) { - const int max_dim = max(head_dim, value_dim); - if (max_dim <= 4) { - return 4; - } - if (max_dim <= 8) { - return 8; - } - if (max_dim <= 16) { - return 16; - } - return kWarpSize; -} - -__device__ inline float subwarp_sum(float value, int width) { - const unsigned mask = __match_any_sync(0xffffffffu, threadIdx.x / width); - for (int delta = width / 2; delta > 0; delta >>= 1) { - value += __shfl_down_sync(mask, value, delta, width); - } - return __shfl_sync(mask, value, 0, width); -} - -__global__ void readout_message_kernel( - TensorTable input_ports, - TensorTable public_now, - ReadoutSpec readout, - float* output_msg, - int64_t output_msg_stride_b, - int64_t output_msg_stride_p, - ExecutionPlan plan, - int value_dim, - int t, - int recurrent_receiver_offset) { - const int warp = threadIdx.y; - const int lane = threadIdx.x; - const int port_tile = max(1, plan.readout_port_tile); - const int batch_tile = max(1, plan.batch_tile); - const int value_chunk = max(1, plan.readout_output_chunk); - const int p_local = warp / batch_tile; - const int b_local = warp % batch_tile; - const int output_port = blockIdx.x * port_tile + p_local; - if (output_port >= plan.output_ports) { - return; - } - - const auto output_q = tensor_ref(readout.params, 0); - const auto input_k = tensor_ref(input_ports, 0); - const auto input_v = tensor_ref(input_ports, 1); - const auto recurrent_k = tensor_ref(public_now, 1); - const auto recurrent_v = tensor_ref(public_now, 2); - const int head_dim = static_cast(output_q.size[1]); - const int value_begin = blockIdx.z * value_chunk; - const int value_end = min(value_dim, value_begin + value_chunk); - if (value_begin >= value_dim) { - return; - } - const float inv_sqrt_dk = rsqrtf(static_cast(head_dim > 0 ? 
head_dim : 1)); - const int edge_begin = readout.topology.receiver_ptr[output_port]; - const int edge_end = readout.topology.receiver_ptr[output_port + 1]; - - for (int bt = blockIdx.y; bt < ceil_div(plan.B, batch_tile); bt += gridDim.y) { - const int b = bt * batch_tile + b_local; - if (b >= plan.B) { - continue; - } - - float max_logit = -INFINITY; - for (int edge = edge_begin; edge < edge_end; ++edge) { - const int sender = readout.topology.sender_idx[edge]; - float dot = 0.0f; - if (sender < readout.topology.num_input_ports) { - for (int d = lane; d < head_dim; d += kWarpSize) { - dot += output_q.at(output_port, d) * - detail::read_input_port_for_readout(input_k, b, t, sender, d); - } - } else { - const int global_recurrent_sender = sender - readout.topology.num_input_ports; - const int recurrent_sender = global_recurrent_sender - recurrent_receiver_offset; - if (recurrent_sender < 0 || recurrent_sender >= plan.receivers) { - continue; - } - for (int d = lane; d < head_dim; d += kWarpSize) { - dot += output_q.at(output_port, d) * recurrent_k.at(b, recurrent_sender, d); - } - } - const float penalty = readout.topology.edge_weight == nullptr ? 0.0f : readout.topology.edge_weight[edge]; - max_logit = fmaxf(max_logit, warp_sum(dot) * inv_sqrt_dk - penalty); - } - - for (int value_index = value_begin + lane; value_index < value_end; value_index += kWarpSize) { - output_msg[b * output_msg_stride_b + output_port * output_msg_stride_p + value_index] = 0.0f; - } - - float norm = 0.0f; - for (int edge = edge_begin; edge < edge_end; ++edge) { - const int sender = readout.topology.sender_idx[edge]; - float dot = 0.0f; - if (sender < readout.topology.num_input_ports) { - for (int d = lane; d < head_dim; d += kWarpSize) { - dot += output_q.at(output_port, d) * - detail::read_input_port_for_readout(input_k, b, t, sender, d); - } - } else { - const int global_recurrent_sender = sender - readout.topology.num_input_ports; - const int recurrent_sender = global_recurrent_sender - recurrent_receiver_offset; - if (recurrent_sender < 0 || recurrent_sender >= plan.receivers) { - continue; - } - for (int d = lane; d < head_dim; d += kWarpSize) { - dot += output_q.at(output_port, d) * recurrent_k.at(b, recurrent_sender, d); - } - } - const float penalty = readout.topology.edge_weight == nullptr ? 
0.0f : readout.topology.edge_weight[edge]; - const float weight = expf(warp_sum(dot) * inv_sqrt_dk - penalty - max_logit); - norm += weight; - for (int value_index = value_begin + lane; value_index < value_end; value_index += kWarpSize) { - float value; - if (sender < readout.topology.num_input_ports) { - value = detail::read_input_port_for_readout(input_v, b, t, sender, value_index); - } else { - const int global_recurrent_sender = sender - readout.topology.num_input_ports; - const int recurrent_sender = global_recurrent_sender - recurrent_receiver_offset; - if (recurrent_sender < 0 || recurrent_sender >= plan.receivers) { - continue; - } - value = recurrent_v.at(b, recurrent_sender, value_index); - } - output_msg[b * output_msg_stride_b + output_port * output_msg_stride_p + value_index] += weight * value; - } - } - - if (norm > 0.0f) { - const float inv_norm = 1.0f / norm; - for (int value_index = value_begin + lane; value_index < value_end; value_index += kWarpSize) { - output_msg[b * output_msg_stride_b + output_port * output_msg_stride_p + value_index] *= inv_norm; - } - } - } -} - -__global__ void readout_message_subwarp_kernel( - TensorTable input_ports, - TensorTable public_now, - ReadoutSpec readout, - float* output_msg, - int64_t output_msg_stride_b, - int64_t output_msg_stride_p, - ExecutionPlan plan, - int value_dim, - int t, - int recurrent_receiver_offset, - int subwarp_width) { - const int lane = threadIdx.x; - const int subwarp_id = lane / subwarp_width; - const int sublane = lane % subwarp_width; - const int subwarps_per_warp = kWarpSize / subwarp_width; - const int pair_local = threadIdx.y * subwarps_per_warp + subwarp_id; - const int pairs_per_block = blockDim.y * subwarps_per_warp; - const int64_t pair = static_cast(blockIdx.x) * pairs_per_block + pair_local; - const int64_t total_pairs = static_cast(plan.B) * max(1, plan.output_ports); - if (pair >= total_pairs) { - return; - } - const int output_port = static_cast(pair % plan.output_ports); - const int b = static_cast(pair / plan.output_ports); - - const auto output_q = tensor_ref(readout.params, 0); - const auto input_k = tensor_ref(input_ports, 0); - const auto input_v = tensor_ref(input_ports, 1); - const auto recurrent_k = tensor_ref(public_now, 1); - const auto recurrent_v = tensor_ref(public_now, 2); - const int head_dim = static_cast(output_q.size[1]); - const int value_chunk = max(1, plan.readout_output_chunk); - const int value_begin = blockIdx.z * value_chunk; - const int value_end = min(value_dim, value_begin + value_chunk); - if (value_begin >= value_dim) { - return; - } - const float inv_sqrt_dk = rsqrtf(static_cast(head_dim > 0 ? 
head_dim : 1)); - const int edge_begin = readout.topology.receiver_ptr[output_port]; - const int edge_end = readout.topology.receiver_ptr[output_port + 1]; - - float max_logit = -INFINITY; - for (int edge = edge_begin; edge < edge_end; ++edge) { - const int sender = readout.topology.sender_idx[edge]; - float dot = 0.0f; - if (sender < readout.topology.num_input_ports) { - for (int d = sublane; d < head_dim; d += subwarp_width) { - dot += output_q.at(output_port, d) * detail::read_input_port_for_readout(input_k, b, t, sender, d); - } - } else { - const int global_recurrent_sender = sender - readout.topology.num_input_ports; - const int recurrent_sender = global_recurrent_sender - recurrent_receiver_offset; - if (recurrent_sender < 0 || recurrent_sender >= plan.receivers) { - continue; - } - for (int d = sublane; d < head_dim; d += subwarp_width) { - dot += output_q.at(output_port, d) * recurrent_k.at(b, recurrent_sender, d); - } - } - const float penalty = readout.topology.edge_weight == nullptr ? 0.0f : readout.topology.edge_weight[edge]; - max_logit = fmaxf(max_logit, subwarp_sum(dot, subwarp_width) * inv_sqrt_dk - penalty); - } - - for (int value_index = value_begin + sublane; value_index < value_end; value_index += subwarp_width) { - output_msg[b * output_msg_stride_b + output_port * output_msg_stride_p + value_index] = 0.0f; - } - - float norm = 0.0f; - for (int edge = edge_begin; edge < edge_end; ++edge) { - const int sender = readout.topology.sender_idx[edge]; - float dot = 0.0f; - if (sender < readout.topology.num_input_ports) { - for (int d = sublane; d < head_dim; d += subwarp_width) { - dot += output_q.at(output_port, d) * detail::read_input_port_for_readout(input_k, b, t, sender, d); - } - } else { - const int global_recurrent_sender = sender - readout.topology.num_input_ports; - const int recurrent_sender = global_recurrent_sender - recurrent_receiver_offset; - if (recurrent_sender < 0 || recurrent_sender >= plan.receivers) { - continue; - } - for (int d = sublane; d < head_dim; d += subwarp_width) { - dot += output_q.at(output_port, d) * recurrent_k.at(b, recurrent_sender, d); - } - } - const float penalty = readout.topology.edge_weight == nullptr ? 
0.0f : readout.topology.edge_weight[edge]; - const float weight = expf(subwarp_sum(dot, subwarp_width) * inv_sqrt_dk - penalty - max_logit); - norm += weight; - for (int value_index = value_begin + sublane; value_index < value_end; value_index += subwarp_width) { - float value; - if (sender < readout.topology.num_input_ports) { - value = detail::read_input_port_for_readout(input_v, b, t, sender, value_index); - } else { - const int global_recurrent_sender = sender - readout.topology.num_input_ports; - const int recurrent_sender = global_recurrent_sender - recurrent_receiver_offset; - if (recurrent_sender < 0 || recurrent_sender >= plan.receivers) { - continue; - } - value = recurrent_v.at(b, recurrent_sender, value_index); - } - output_msg[b * output_msg_stride_b + output_port * output_msg_stride_p + value_index] += weight * value; - } - } - - if (norm > 0.0f) { - const float inv_norm = 1.0f / norm; - for (int value_index = value_begin + sublane; value_index < value_end; value_index += subwarp_width) { - output_msg[b * output_msg_stride_b + output_port * output_msg_stride_p + value_index] *= inv_norm; - } - } -} - -__global__ void readout_message_from_raw_public_kernel( - TensorTable input_ports, - const float* raw_public, - int64_t raw_public_stride_b, - int64_t raw_public_stride_r, - int64_t raw_public_stride_d, - int raw_public_dim, - int recurrent_receiver_offset, - int public_projection_kind, - TensorTable public_projection_params, - ReadoutSpec readout, - float* output_msg, - int64_t output_msg_stride_b, - int64_t output_msg_stride_p, - ExecutionPlan plan, - int value_dim, - int t) { - const int warp = threadIdx.y; - const int lane = threadIdx.x; - const int port_tile = max(1, plan.readout_port_tile); - const int batch_tile = max(1, plan.batch_tile); - const int value_chunk = max(1, plan.readout_output_chunk); - const int p_local = warp / batch_tile; - const int b_local = warp % batch_tile; - const int output_port = blockIdx.x * port_tile + p_local; - if (output_port >= plan.output_ports) { - return; - } - - const auto output_q = tensor_ref(readout.params, 0); - const auto input_k = tensor_ref(input_ports, 0); - const auto input_v = tensor_ref(input_ports, 1); - const int head_dim = static_cast(output_q.size[1]); - const int value_begin = blockIdx.z * value_chunk; - const int value_end = min(value_dim, value_begin + value_chunk); - if (value_begin >= value_dim) { - return; - } - const float inv_sqrt_dk = rsqrtf(static_cast(head_dim > 0 ? 
head_dim : 1)); - const int edge_begin = readout.topology.receiver_ptr[output_port]; - const int edge_end = readout.topology.receiver_ptr[output_port + 1]; - const auto projection_kind = static_cast(public_projection_kind); - - for (int bt = blockIdx.y; bt < ceil_div(plan.B, batch_tile); bt += gridDim.y) { - const int b = bt * batch_tile + b_local; - if (b >= plan.B) { - continue; - } - - float max_logit = -INFINITY; - for (int edge = edge_begin; edge < edge_end; ++edge) { - const int sender = readout.topology.sender_idx[edge]; - float dot = 0.0f; - if (sender < readout.topology.num_input_ports) { - for (int d = lane; d < head_dim; d += kWarpSize) { - dot += output_q.at(output_port, d) * - detail::read_input_port_for_readout(input_k, b, t, sender, d); - } - } else { - const int global_recurrent_sender = sender - readout.topology.num_input_ports; - const int recurrent_sender = global_recurrent_sender - recurrent_receiver_offset; - if (recurrent_sender < 0 || recurrent_sender >= plan.receivers) { - continue; - } - const float* raw_public_row = - raw_public + b * raw_public_stride_b + recurrent_sender * raw_public_stride_r; - for (int d = lane; d < head_dim; d += kWarpSize) { - const float key = project_public_kv_from_raw_public( - projection_kind, - global_recurrent_sender, - d, - raw_public_row, - raw_public_dim, - public_projection_params); - dot += output_q.at(output_port, d) * key; - } - } - const float penalty = readout.topology.edge_weight == nullptr ? 0.0f : readout.topology.edge_weight[edge]; - max_logit = fmaxf(max_logit, warp_sum(dot) * inv_sqrt_dk - penalty); - } - - for (int value_index = value_begin + lane; value_index < value_end; value_index += kWarpSize) { - output_msg[b * output_msg_stride_b + output_port * output_msg_stride_p + value_index] = 0.0f; - } - - float norm = 0.0f; - for (int edge = edge_begin; edge < edge_end; ++edge) { - const int sender = readout.topology.sender_idx[edge]; - float dot = 0.0f; - if (sender < readout.topology.num_input_ports) { - for (int d = lane; d < head_dim; d += kWarpSize) { - dot += output_q.at(output_port, d) * - detail::read_input_port_for_readout(input_k, b, t, sender, d); - } - } else { - const int global_recurrent_sender = sender - readout.topology.num_input_ports; - const int recurrent_sender = global_recurrent_sender - recurrent_receiver_offset; - if (recurrent_sender < 0 || recurrent_sender >= plan.receivers) { - continue; - } - const float* raw_public_row = - raw_public + b * raw_public_stride_b + recurrent_sender * raw_public_stride_r; - for (int d = lane; d < head_dim; d += kWarpSize) { - const float key = project_public_kv_from_raw_public( - projection_kind, - global_recurrent_sender, - d, - raw_public_row, - raw_public_dim, - public_projection_params); - dot += output_q.at(output_port, d) * key; - } - } - const float penalty = readout.topology.edge_weight == nullptr ? 
0.0f : readout.topology.edge_weight[edge]; - const float weight = expf(warp_sum(dot) * inv_sqrt_dk - penalty - max_logit); - norm += weight; - for (int value_index = value_begin + lane; value_index < value_end; value_index += kWarpSize) { - float value; - if (sender < readout.topology.num_input_ports) { - value = detail::read_input_port_for_readout(input_v, b, t, sender, value_index); - } else { - const int global_recurrent_sender = sender - readout.topology.num_input_ports; - const int recurrent_sender = global_recurrent_sender - recurrent_receiver_offset; - if (recurrent_sender < 0 || recurrent_sender >= plan.receivers) { - continue; - } - const float* raw_public_row = - raw_public + b * raw_public_stride_b + recurrent_sender * raw_public_stride_r; - value = project_public_kv_from_raw_public( - projection_kind, - global_recurrent_sender, - head_dim + value_index, - raw_public_row, - raw_public_dim, - public_projection_params); - } - output_msg[b * output_msg_stride_b + output_port * output_msg_stride_p + value_index] += weight * value; - } - } - - if (norm > 0.0f) { - const float inv_norm = 1.0f / norm; - for (int value_index = value_begin + lane; value_index < value_end; value_index += kWarpSize) { - output_msg[b * output_msg_stride_b + output_port * output_msg_stride_p + value_index] *= inv_norm; - } - } - } -} - -} // namespace - -void launch_readout_message_cuda( - TensorTable input_ports, - TensorTable public_now, - ReadoutSpec readout, - at::Tensor output_msg, - ExecutionPlan plan, - int head_dim, - int value_dim, - int t, - int recurrent_receiver_offset, - cudaStream_t stream) { - if (!readout.enabled || plan.readout_mode == ReadoutMode::Skip || plan.output_ports <= 0) { - return; - } - TORCH_CHECK(output_msg.dim() == 3, "Fabric readout message output must be [B, output_ports, value_dim]"); - TORCH_CHECK(output_msg.size(0) == plan.B, "Fabric readout message output B must match plan"); - TORCH_CHECK(output_msg.size(1) == plan.output_ports, "Fabric readout message output ports must match plan"); - TORCH_CHECK(output_msg.size(2) == value_dim, "Fabric readout message output value_dim must match plan"); - TORCH_CHECK(output_msg.stride(2) == 1, "Fabric readout message output last dimension must be contiguous"); - TORCH_CHECK(plan.readout_port_tile > 0, "Fabric readout requires readout_port_tile > 0"); - TORCH_CHECK(plan.readout_output_chunk > 0, "Fabric readout requires readout_output_chunk > 0"); - TORCH_CHECK(plan.batch_tile > 0, "Fabric readout requires batch_tile > 0"); - TORCH_CHECK( - plan.readout_port_tile * plan.batch_tile <= kMaxWarpsPerBlock, - "Fabric readout launch requested too many warps per block: ", - plan.readout_port_tile * plan.batch_tile); - const int subwarp_width = readout_subwarp_width(head_dim, value_dim); - if (subwarp_width < kWarpSize) { - constexpr int kSubwarpWarpsPerBlock = 4; - const int pairs_per_warp = kWarpSize / subwarp_width; - const int pairs_per_block = kSubwarpWarpsPerBlock * pairs_per_warp; - const int64_t total_pairs = static_cast(plan.B) * max(1, plan.output_ports); - const dim3 block(kWarpSize, kSubwarpWarpsPerBlock); - const int64_t grid_x = (total_pairs + pairs_per_block - 1) / pairs_per_block; - TORCH_CHECK(grid_x <= INT_MAX, "Fabric readout subwarp launch grid exceeds supported x dimension"); - const dim3 grid( - static_cast(grid_x), - 1, - ceil_div(value_dim > 0 ? 
value_dim : 1, plan.readout_output_chunk)); - readout_message_subwarp_kernel<<>>( - input_ports, - public_now, - readout, - output_msg.data_ptr(), - output_msg.stride(0), - output_msg.stride(1), - plan, - value_dim, - t, - recurrent_receiver_offset, - subwarp_width); - return; - } - const int warps_per_block = plan.readout_port_tile * plan.batch_tile; - const dim3 block(kWarpSize, warps_per_block); - const dim3 grid( - ceil_div(plan.output_ports, plan.readout_port_tile), - readout_batch_grid_tiles(plan), - ceil_div(value_dim > 0 ? value_dim : 1, plan.readout_output_chunk)); - readout_message_kernel<<>>( - input_ports, - public_now, - readout, - output_msg.data_ptr(), - output_msg.stride(0), - output_msg.stride(1), - plan, - value_dim, - t, - recurrent_receiver_offset); -} - -void launch_readout_message_from_raw_public_cuda( - TensorTable input_ports, - const at::Tensor& raw_public, - int public_projection_kind, - TensorTable public_projection_params, - ReadoutSpec readout, - at::Tensor output_msg, - ExecutionPlan plan, - int value_dim, - int t, - int recurrent_receiver_offset, - cudaStream_t stream) { - if (!readout.enabled || plan.readout_mode == ReadoutMode::Skip || plan.output_ports <= 0) { - return; - } - TORCH_CHECK(raw_public.dim() == 3, "Fabric raw-public readout source must be [B, receivers, raw_public_dim]"); - TORCH_CHECK(raw_public.size(0) == plan.B, "Fabric raw-public readout source B must match plan"); - TORCH_CHECK(raw_public.size(1) == plan.receivers, "Fabric raw-public readout source receivers must match plan"); - TORCH_CHECK(raw_public.stride(2) == 1, "Fabric raw-public readout source last dimension must be contiguous"); - TORCH_CHECK(output_msg.dim() == 3, "Fabric readout message output must be [B, output_ports, value_dim]"); - TORCH_CHECK(output_msg.size(0) == plan.B, "Fabric readout message output B must match plan"); - TORCH_CHECK(output_msg.size(1) == plan.output_ports, "Fabric readout message output ports must match plan"); - TORCH_CHECK(output_msg.size(2) == value_dim, "Fabric readout message output value_dim must match plan"); - TORCH_CHECK(output_msg.stride(2) == 1, "Fabric readout message output last dimension must be contiguous"); - TORCH_CHECK(plan.readout_port_tile > 0, "Fabric readout requires readout_port_tile > 0"); - TORCH_CHECK(plan.readout_output_chunk > 0, "Fabric readout requires readout_output_chunk > 0"); - TORCH_CHECK(plan.batch_tile > 0, "Fabric readout requires batch_tile > 0"); - TORCH_CHECK( - plan.readout_port_tile * plan.batch_tile <= kMaxWarpsPerBlock, - "Fabric readout launch requested too many warps per block: ", - plan.readout_port_tile * plan.batch_tile); - const int warps_per_block = plan.readout_port_tile * plan.batch_tile; - const dim3 block(kWarpSize, warps_per_block); - const dim3 grid( - ceil_div(plan.output_ports, plan.readout_port_tile), - readout_batch_grid_tiles(plan), - ceil_div(value_dim > 0 ? 
value_dim : 1, plan.readout_output_chunk)); - readout_message_from_raw_public_kernel<<>>( - input_ports, - raw_public.data_ptr(), - raw_public.stride(0), - raw_public.stride(1), - raw_public.stride(2), - static_cast(raw_public.size(2)), - recurrent_receiver_offset, - public_projection_kind, - public_projection_params, - readout, - output_msg.data_ptr(), - output_msg.stride(0), - output_msg.stride(1), - plan, - value_dim, - t); -} - -} // namespace fabric diff --git a/src/cortical/fabric/backend/cuda/execution/receiver_owned_stepwise.cu b/src/cortical/fabric/backend/cuda/execution/receiver_owned_stepwise.cu deleted file mode 100644 index c09a8375..00000000 --- a/src/cortical/fabric/backend/cuda/execution/receiver_owned_stepwise.cu +++ /dev/null @@ -1,38 +0,0 @@ -#include "cortical/fabric/backend/cuda/execution/receiver_owned_stepwise.cuh" - -#include "cortical/fabric/backend/cuda/message_passing/local_message_backend.cuh" -#include "cortical/fabric/backend/cuda/message_passing/sparse_message_backend.cuh" - -namespace fabric { - -void launch_receiver_message_aggregate_cuda( - int message_backend_id, - TensorTable public_prev, - TensorTable message_params, - TensorTable input_ports, - MessageTopology topo, - float* message_out, - ExecutionPlan plan, - const at::Tensor& resets_u8, - int t, - cudaStream_t stream) { - stepwise_detail::validate_receiver_launch(plan); - if (message_backend_id == 1) { - stepwise_detail::launch_receiver_message_aggregate_variant( - public_prev, message_params, input_ports, topo, message_out, plan, resets_u8, t, stream); - return; - } - stepwise_detail::launch_receiver_message_aggregate_variant( - public_prev, message_params, input_ports, topo, message_out, plan, resets_u8, t, stream); -} - -void launch_receiver_normalize_accumulated_message_cuda( - const float* msg_buffer, - float* message_out, - ExecutionPlan plan, - cudaStream_t stream) { - stepwise_detail::validate_receiver_launch(plan); - stepwise_detail::launch_receiver_normalize_accumulated_variant(msg_buffer, message_out, plan, stream); -} - -} // namespace fabric diff --git a/src/cortical/fabric/backend/cuda/execution/receiver_owned_stepwise.cuh b/src/cortical/fabric/backend/cuda/execution/receiver_owned_stepwise.cuh deleted file mode 100644 index 0ae2d75a..00000000 --- a/src/cortical/fabric/backend/cuda/execution/receiver_owned_stepwise.cuh +++ /dev/null @@ -1,2080 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -#include -#include - -#include "cortical/fabric/backend/cuda/execution/common.cuh" -#include "cortical/fabric/backend/cuda/nn/ir.cuh" - -namespace fabric { - -namespace stepwise_detail { - -constexpr int kMaxWarpsPerBlock = 32; - -__host__ __device__ inline size_t align_float_bytes(size_t bytes) { - return (bytes + sizeof(float) - 1) & ~(sizeof(float) - 1); -} - -template -constexpr int reduction_storage_dim() { - return CellCore::kReductionStatsDim > 0 ? 
CellCore::kReductionStatsDim : 1; -} - -inline void validate_phase_launch( - const char* phase, - int receiver_tile, - int batch_tile, - int hidden_chunk) { - TORCH_CHECK(receiver_tile > 0, "Fabric ", phase, " launch requires receiver_tile > 0"); - TORCH_CHECK(batch_tile > 0, "Fabric ", phase, " launch requires batch_tile > 0"); - TORCH_CHECK(hidden_chunk > 0, "Fabric ", phase, " launch requires hidden_chunk > 0"); - TORCH_CHECK( - receiver_tile * batch_tile <= kMaxWarpsPerBlock, - "Fabric ", - phase, - " launch requested too many warps per block: ", - receiver_tile * batch_tile); -} - -inline void validate_receiver_launch(const ExecutionPlan& plan) { - validate_phase_launch("receiver message", plan.receiver_tile, plan.batch_tile, plan.hidden_chunk); - validate_phase_launch("receiver state", plan.state_receiver_tile, plan.state_batch_tile, plan.state_hidden_chunk); - validate_phase_launch("receiver emit", plan.emit_receiver_tile, plan.emit_batch_tile, plan.emit_hidden_chunk); -} - -inline void check_stepwise_launch(const char* kernel_name) { - const cudaError_t err = cudaGetLastError(); - TORCH_CHECK(err == cudaSuccess, kernel_name, " launch failed: ", cudaGetErrorString(err)); -} - -inline void fail_unsupported_tile(const char* phase, int receiver_tile, int batch_tile, int hidden_chunk) { - TORCH_CHECK( - false, - "Unsupported Fabric ", - phase, - " tile variant receiver_tile=", - receiver_tile, - " batch_tile=", - batch_tile, - " hidden_chunk=", - hidden_chunk); -} - -inline void fail_unsupported_tile(const char* phase, int receiver_tile, int batch_tile) { - TORCH_CHECK( - false, - "Unsupported Fabric ", - phase, - " tile variant receiver_tile=", - receiver_tile, - " batch_tile=", - batch_tile); -} - -template -inline void set_dynamic_smem_if_needed(Kernel kernel, size_t shared_bytes, const cudaDeviceProp* props) { - if (shared_bytes > static_cast(props->sharedMemPerBlock)) { - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(shared_bytes)); - } -} - -template -__device__ inline void stage_receiver_state_static_warp( - int receiver, - bool receiver_active, - int b_local, - int r_local, - int lane, - TensorTable cell_params, - ExecutionPlan plan, - int state_static_bytes, - unsigned char* static_base, - int lane_stride = kWarpSize) { - const bool should_stage = plan.stage_receiver_static && - plan.state_static_stage_mode != CellStaticStageMode::Disabled && state_static_bytes > 0; - if (should_stage && receiver_active && b_local == 0) { - CellCore::stage_state_static( - receiver, - cell_params, - static_base + static_cast(r_local) * state_static_bytes, - lane, - lane_stride); - } -} - -__device__ inline int state_table_receiver_index( - const TensorTable& table, - int receiver_local, - int receiver_global, - int planned_receivers) { - if (table.count <= 0) { - return receiver_local; - } - const auto first = tensor_ref(table, 0); - if (first.ndim >= 2 && first.size[1] == planned_receivers) { - return receiver_local; - } - return receiver_global; -} - -template -__device__ inline const void* staged_state_static_ptr( - int r_local, - ExecutionPlan plan, - int state_static_bytes, - unsigned char* static_base) { - const bool should_stage = plan.stage_receiver_static && - plan.state_static_stage_mode != CellStaticStageMode::Disabled && state_static_bytes > 0; - return should_stage ? 
static_base + static_cast(r_local) * state_static_bytes : nullptr; -} - -template -__device__ inline void stage_receiver_emit_static_warp( - int receiver, - bool receiver_active, - int b_local, - int r_local, - int lane, - TensorTable cell_params, - ExecutionPlan plan, - int emit_static_bytes, - unsigned char* static_base, - int lane_stride = kWarpSize) { - const bool should_stage = plan.stage_receiver_static && - plan.emit_static_stage_mode != CellStaticStageMode::Disabled && emit_static_bytes > 0; - if (should_stage && receiver_active && b_local == 0) { - CellCore::stage_emit_static( - receiver, - cell_params, - static_base + static_cast(r_local) * emit_static_bytes, - lane, - lane_stride); - } -} - -template -__device__ inline const void* staged_emit_static_ptr( - int r_local, - ExecutionPlan plan, - int emit_static_bytes, - unsigned char* static_base) { - const bool should_stage = plan.stage_receiver_static && - plan.emit_static_stage_mode != CellStaticStageMode::Disabled && emit_static_bytes > 0; - return should_stage ? static_base + static_cast(r_local) * emit_static_bytes : nullptr; -} - -template -__device__ inline void reduce_warp_stats_to_buffer( - float* warp_stats, - const float* lane_stats, - float* out_stats) { - if constexpr (CellCore::kReductionStatsDim > 0) { - constexpr int kStatsDim = CellCore::kReductionStatsDim; - float* lane_slot = warp_stats + static_cast(threadIdx.x) * kStatsDim; - for (int stat = 0; stat < kStatsDim; ++stat) { - lane_slot[stat] = lane_stats[stat]; - } - __syncwarp(); - for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) { - if (threadIdx.x < offset) { - CellCore::combine_reduction_stats(lane_slot, lane_slot + static_cast(offset) * kStatsDim); - } - __syncwarp(); - } - for (int stat = threadIdx.x; stat < kStatsDim; stat += kWarpSize) { - out_stats[stat] = warp_stats[stat]; - } - } -} - -template -__device__ inline void reduce_lane_group_stats_to_buffer( - float* group_stats, - const float* lane_stats, - float* out_stats, - int group_lane) { - if constexpr (CellCore::kReductionStatsDim > 0) { - constexpr int kStatsDim = CellCore::kReductionStatsDim; - float* lane_slot = group_stats + static_cast(group_lane) * kStatsDim; - for (int stat = 0; stat < kStatsDim; ++stat) { - lane_slot[stat] = lane_stats[stat]; - } - __syncwarp(); - for (int offset = GROUP_SIZE / 2; offset > 0; offset >>= 1) { - if (group_lane < offset) { - CellCore::combine_reduction_stats(lane_slot, lane_slot + static_cast(offset) * kStatsDim); - } - __syncwarp(); - } - for (int stat = group_lane; stat < kStatsDim; stat += GROUP_SIZE) { - out_stats[stat] = group_stats[stat]; - } - } -} - -template -__device__ inline void reduce_chunks_to_buffer( - float* warp_stats, - float* scratch_stats, - const float* partial_stats, - float* reduced_stats, - int num_hidden_chunks, - int state_dim) { - if constexpr (CellCore::kReductionStatsDim > 0) { - constexpr int kStatsDim = CellCore::kReductionStatsDim; - float lane_stats[kStatsDim]; - float chunk_stats[kStatsDim]; - CellCore::init_reduction_stats(lane_stats); - for (int chunk = threadIdx.x; chunk < num_hidden_chunks; chunk += kWarpSize) { - const float* src = partial_stats + static_cast(chunk) * kStatsDim; - for (int stat = 0; stat < kStatsDim; ++stat) { - chunk_stats[stat] = src[stat]; - } - CellCore::combine_reduction_stats(lane_stats, chunk_stats); - } - float* lane_slot = warp_stats + static_cast(threadIdx.x) * kStatsDim; - for (int stat = 0; stat < kStatsDim; ++stat) { - lane_slot[stat] = lane_stats[stat]; - } - __syncwarp(); - for 
(int offset = kWarpSize / 2; offset > 0; offset >>= 1) { - if (threadIdx.x < offset) { - CellCore::combine_reduction_stats(lane_slot, lane_slot + static_cast(offset) * kStatsDim); - } - __syncwarp(); - } - for (int stat = threadIdx.x; stat < kStatsDim; stat += kWarpSize) { - scratch_stats[stat] = warp_stats[stat]; - } - __syncwarp(); - CellCore::finalize_reduction_stats(scratch_stats, state_dim, threadIdx.x, kWarpSize); - __syncwarp(); - for (int stat = threadIdx.x; stat < kStatsDim; stat += kWarpSize) { - reduced_stats[stat] = scratch_stats[stat]; - } - } -} - -template -__global__ void receiver_message_aggregate_kernel( - TensorTable public_prev, - TensorTable message_params, - TensorTable input_ports, - MessageTopology topo, - float* message_out, - ExecutionPlan plan, - const uint8_t* resets_u8, - int t) { - const int warp = threadIdx.y; - const int lane = threadIdx.x; - constexpr int receiver_tile = R_TILE; - constexpr int batch_tile = B_TILE; - const int r_local = warp / batch_tile; - const int b_local = warp % batch_tile; - const int receiver = blockIdx.x * receiver_tile + r_local; - float* message = message_out + static_cast(receiver) * plan.message_dim; - - for (int bt = blockIdx.y; bt < ceil_div(plan.B, batch_tile); bt += gridDim.y) { - const int b = bt * batch_tile + b_local; - if (r_local >= receiver_tile || receiver >= plan.receivers || b >= plan.B) { - continue; - } - message = message_out + (static_cast(b) * plan.receivers + receiver) * plan.message_dim; - const bool reset_row = resets_u8 != nullptr ? resets_u8[static_cast(b) * plan.T + t] != 0 : false; - MessageBackend::aggregate_receiver_step_warp( - b, - t, - receiver, - reset_row, - topo, - input_ports, - public_prev, - message_params, - message, - plan.message_dim, - lane); - } -} - -template -__global__ void receiver_normalize_accumulated_message_kernel( - const float* msg_buffer, - float* message_out, - ExecutionPlan plan) { - const int warp = threadIdx.y; - const int lane = threadIdx.x; - constexpr int receiver_tile = R_TILE; - constexpr int batch_tile = B_TILE; - const int r_local = warp / batch_tile; - const int b_local = warp % batch_tile; - const int receiver = blockIdx.x * receiver_tile + r_local; - - for (int bt = blockIdx.y; bt < ceil_div(plan.B, batch_tile); bt += gridDim.y) { - const int b = bt * batch_tile + b_local; - if (r_local >= receiver_tile || receiver >= plan.receivers || b >= plan.B) { - continue; - } - const int64_t accumulated_base = (static_cast(b) * plan.receivers + receiver) * (plan.message_dim + 1); - const int64_t message_base = (static_cast(b) * plan.receivers + receiver) * plan.message_dim; - const float denom = msg_buffer[accumulated_base + plan.message_dim]; - const float inv_norm = denom > 0.0f ? 
1.0f / denom : 0.0f; - for (int d = lane; d < plan.message_dim; d += kWarpSize) { - message_out[message_base + d] = msg_buffer[accumulated_base + d] * inv_norm; - } - } -} - -template -__global__ void receiver_state_update_kernel( - const float* projected_message, - TensorTable state_prev, - TensorTable state_next, - TensorTable cell_params, - TensorTable aux, - float* partial_stats, - ExecutionPlan plan, - const uint8_t* resets_u8, - bool state_prev_is_zero, - int t, - int projected_message_dim, - int raw_public_dim, - int state_static_bytes, - int num_hidden_chunks, - int receiver_offset, - int receiver_global_offset, - int receiver_count) { - const int warp = threadIdx.y; - const int lane = threadIdx.x; - constexpr int receiver_tile = R_TILE; - constexpr int batch_tile = B_TILE; - constexpr int hidden_chunk = H_CHUNK; - const int r_local = warp / batch_tile; - const int b_local = warp % batch_tile; - const int receiver_chunk_local = blockIdx.x * receiver_tile + r_local; - const int receiver_local = receiver_offset + receiver_chunk_local; - const int receiver = receiver_global_offset + receiver_local; - const int h0 = blockIdx.z * hidden_chunk; - extern __shared__ unsigned char smem[]; - unsigned char* static_base = smem; - const size_t static_bytes = plan.stage_receiver_static && STAGE_MODE != CellStaticStageMode::Disabled && - state_static_bytes > 0 - ? static_cast(receiver_tile) * static_cast(state_static_bytes) - : 0; - float* stats_scratch = reinterpret_cast(smem + align_float_bytes(static_bytes)); - - stage_receiver_state_static_warp( - receiver, - receiver_local < plan.receivers, - b_local, - r_local, - lane, - cell_params, - plan, - state_static_bytes, - static_base); - __syncthreads(); - - for (int bt = blockIdx.y; bt < ceil_div(plan.B, batch_tile); bt += gridDim.y) { - const int b = bt * batch_tile + b_local; - if (r_local >= receiver_tile || receiver_chunk_local >= receiver_count || receiver_local >= plan.receivers || - b >= plan.B) { - continue; - } - const int state_dim = CellCore::state_dim(projected_message_dim, raw_public_dim, state_prev); - if (h0 >= state_dim) { - continue; - } - const int h_count = min(hidden_chunk, state_dim - h0); - const bool reset_row = - state_prev_is_zero || (resets_u8 != nullptr ? 
resets_u8[static_cast(b) * plan.T + t] != 0 : false); - const float* projected_in = - projected_message + (static_cast(b) * plan.receivers + receiver_local) * projected_message_dim; - const void* staged_static = staged_state_static_ptr(r_local, plan, state_static_bytes, static_base); - const int state_receiver = state_table_receiver_index(state_prev, receiver_local, receiver, plan.receivers); - constexpr int kStatsStorageDim = reduction_storage_dim(); - float lane_stats[kStatsStorageDim]; - if constexpr (CellCore::kReductionStatsDim > 0) { - CellCore::init_reduction_stats(lane_stats); - } - CellCore::forward_state_chunk( - b, - state_receiver, - receiver, - receiver_chunk_local, - reset_row, - lane, - kWarpSize, - staged_static, - cell_params, - state_prev, - state_next, - projected_in, - projected_message_dim, - h0, - h_count, - lane_stats, - &aux); - if constexpr (CellCore::kReductionStatsDim > 0) { - float* out_stats = partial_stats + - (((static_cast(b) * plan.receivers + receiver_local) * num_hidden_chunks + blockIdx.z) * - CellCore::kReductionStatsDim); - reduce_warp_stats_to_buffer( - stats_scratch + static_cast(warp) * kWarpSize * CellCore::kReductionStatsDim, - lane_stats, - out_stats); - } - } -} - -template -__global__ void receiver_state_update_emit_kernel( - const float* projected_message, - TensorTable state_prev, - TensorTable state_next, - TensorTable cell_params, - TensorTable aux, - float* raw_public, - ExecutionPlan plan, - const uint8_t* resets_u8, - bool state_prev_is_zero, - int t, - int projected_message_dim, - int raw_public_dim, - int state_static_bytes, - int emit_static_bytes, - int num_hidden_chunks, - int receiver_offset, - int receiver_global_offset, - int receiver_count) { - const int warp = threadIdx.y; - const int lane = threadIdx.x; - constexpr int receiver_tile = R_TILE; - constexpr int batch_tile = B_TILE; - constexpr int hidden_chunk = H_CHUNK; - const int r_local = warp / batch_tile; - const int b_local = warp % batch_tile; - const int receiver_chunk_local = blockIdx.x * receiver_tile + r_local; - const int receiver_local = receiver_offset + receiver_chunk_local; - const int receiver = receiver_global_offset + receiver_local; - const int h0 = blockIdx.z * hidden_chunk; - extern __shared__ unsigned char smem[]; - const size_t state_static_smem = - plan.stage_receiver_static && STATE_STAGE_MODE != CellStaticStageMode::Disabled && state_static_bytes > 0 - ? static_cast(receiver_tile) * static_cast(state_static_bytes) - : 0; - const size_t state_static_aligned = align_float_bytes(state_static_smem); - const size_t emit_static_smem = - plan.stage_receiver_static && EMIT_STAGE_MODE != CellStaticStageMode::Disabled && emit_static_bytes > 0 - ? 
static_cast(receiver_tile) * static_cast(emit_static_bytes) - : 0; - const size_t emit_static_aligned = align_float_bytes(emit_static_smem); - const size_t stats_scratch_offset = state_static_aligned + emit_static_aligned; - unsigned char* state_static_base = smem; - unsigned char* emit_static_base = smem + state_static_aligned; - float* stats_scratch = reinterpret_cast(smem + stats_scratch_offset); - - stage_receiver_state_static_warp( - receiver, - receiver_local < plan.receivers, - b_local, - r_local, - lane, - cell_params, - plan, - state_static_bytes, - state_static_base); - stage_receiver_emit_static_warp( - receiver, - receiver_local < plan.receivers, - b_local, - r_local, - lane, - cell_params, - plan, - emit_static_bytes, - emit_static_base); - __syncthreads(); - - for (int bt = blockIdx.y; bt < ceil_div(plan.B, batch_tile); bt += gridDim.y) { - const int b = bt * batch_tile + b_local; - if (r_local >= receiver_tile || receiver_chunk_local >= receiver_count || receiver_local >= plan.receivers || - b >= plan.B) { - continue; - } - const int state_dim = CellCore::state_dim(projected_message_dim, raw_public_dim, state_prev); - if (h0 >= state_dim) { - continue; - } - const int h_count = min(hidden_chunk, state_dim - h0); - const bool reset_row = - state_prev_is_zero || (resets_u8 != nullptr ? resets_u8[static_cast(b) * plan.T + t] != 0 : false); - const float* projected_in = - projected_message + (static_cast(b) * plan.receivers + receiver_local) * projected_message_dim; - const void* staged_state_static = - staged_state_static_ptr(r_local, plan, state_static_bytes, state_static_base); - const int state_receiver = state_table_receiver_index(state_prev, receiver_local, receiver, plan.receivers); - constexpr int kStatsStorageDim = reduction_storage_dim(); - float lane_stats[kStatsStorageDim]; - if constexpr (CellCore::kReductionStatsDim > 0) { - CellCore::init_reduction_stats(lane_stats); - } - CellCore::forward_state_chunk( - b, - state_receiver, - receiver, - receiver_chunk_local, - reset_row, - lane, - kWarpSize, - staged_state_static, - cell_params, - state_prev, - state_next, - projected_in, - projected_message_dim, - h0, - h_count, - lane_stats, - &aux); - if (h0 < raw_public_dim) { - const float* stats = nullptr; - if constexpr (CellCore::kReductionStatsDim > 0) { - float* warp_stats = - stats_scratch + static_cast(warp) * kWarpSize * CellCore::kReductionStatsDim; - reduce_warp_stats_to_buffer(warp_stats, lane_stats, warp_stats); - __syncwarp(); - stats = warp_stats; - } - const int emit_count = min(hidden_chunk, raw_public_dim - h0); - float* raw_public_out = - raw_public + (static_cast(b) * plan.receivers + receiver_local) * raw_public_dim; - const void* staged_emit_static = - staged_emit_static_ptr(r_local, plan, emit_static_bytes, emit_static_base); - CellCore::emit_public_chunk( - b, - state_receiver, - receiver, - lane, - kWarpSize, - staged_emit_static, - cell_params, - state_next, - raw_public_out, - raw_public_dim, - h0, - emit_count, - stats); - } - } -} - -template < - typename CellCore, - int R_TILE, - int B_TILE, - int H_CHUNK, - int SUBWARP_SIZE, - CellStaticStageMode STATE_STAGE_MODE, - CellStaticStageMode EMIT_STAGE_MODE> -__global__ void receiver_state_update_emit_subwarp_kernel( - const float* projected_message, - TensorTable state_prev, - TensorTable state_next, - TensorTable cell_params, - TensorTable aux, - float* raw_public, - ExecutionPlan plan, - const uint8_t* resets_u8, - bool state_prev_is_zero, - int t, - int projected_message_dim, - int 
raw_public_dim, - int state_static_bytes, - int emit_static_bytes, - int num_hidden_chunks, - int receiver_offset, - int receiver_global_offset, - int receiver_count) { - static_assert(SUBWARP_SIZE == 8 || SUBWARP_SIZE == 16, "subwarp state/update emit supports 8- or 16-lane groups"); - static_assert(H_CHUNK <= SUBWARP_SIZE, "subwarp state/update emit requires H_CHUNK <= SUBWARP_SIZE"); - static_assert(CellCore::kReductionStatsDim > 0, "subwarp state/update emit is for reduction-boundary epilogues"); - constexpr int subwarp_size = SUBWARP_SIZE; - constexpr int subwarps_per_warp = kWarpSize / subwarp_size; - const int lane = threadIdx.x; - const int subwarp = lane / subwarp_size; - const int sub_lane = lane - subwarp * subwarp_size; - const int logical_warp = threadIdx.y * subwarps_per_warp + subwarp; - constexpr int receiver_tile = R_TILE; - constexpr int batch_tile = B_TILE; - constexpr int hidden_chunk = H_CHUNK; - const int r_local = logical_warp / batch_tile; - const int b_local = logical_warp % batch_tile; - const int receiver_chunk_local = blockIdx.x * receiver_tile + r_local; - const int receiver_local = receiver_offset + receiver_chunk_local; - const int receiver = receiver_global_offset + receiver_local; - const int h0 = blockIdx.z * hidden_chunk; - extern __shared__ unsigned char smem[]; - const size_t state_static_smem = - plan.stage_receiver_static && STATE_STAGE_MODE != CellStaticStageMode::Disabled && state_static_bytes > 0 - ? static_cast(receiver_tile) * static_cast(state_static_bytes) - : 0; - const size_t state_static_aligned = align_float_bytes(state_static_smem); - const size_t emit_static_smem = - plan.stage_receiver_static && EMIT_STAGE_MODE != CellStaticStageMode::Disabled && emit_static_bytes > 0 - ? static_cast(receiver_tile) * static_cast(emit_static_bytes) - : 0; - const size_t emit_static_aligned = align_float_bytes(emit_static_smem); - const size_t stats_scratch_offset = state_static_aligned + emit_static_aligned; - unsigned char* state_static_base = smem; - unsigned char* emit_static_base = smem + state_static_aligned; - float* stats_scratch = reinterpret_cast(smem + stats_scratch_offset); - - stage_receiver_state_static_warp( - receiver, - receiver_local < plan.receivers, - b_local, - r_local, - sub_lane, - cell_params, - plan, - state_static_bytes, - state_static_base, - subwarp_size); - stage_receiver_emit_static_warp( - receiver, - receiver_local < plan.receivers, - b_local, - r_local, - sub_lane, - cell_params, - plan, - emit_static_bytes, - emit_static_base, - subwarp_size); - __syncthreads(); - - for (int bt = blockIdx.y; bt < ceil_div(plan.B, batch_tile); bt += gridDim.y) { - const int b = bt * batch_tile + b_local; - if (r_local >= receiver_tile || receiver_chunk_local >= receiver_count || receiver_local >= plan.receivers || - b >= plan.B) { - continue; - } - const int state_dim = CellCore::state_dim(projected_message_dim, raw_public_dim, state_prev); - if (h0 >= state_dim) { - continue; - } - const int h_count = min(hidden_chunk, state_dim - h0); - const bool reset_row = - state_prev_is_zero || (resets_u8 != nullptr ? 
resets_u8[static_cast(b) * plan.T + t] != 0 : false); - const float* projected_in = - projected_message + (static_cast(b) * plan.receivers + receiver_local) * projected_message_dim; - const void* staged_state_static = - staged_state_static_ptr(r_local, plan, state_static_bytes, state_static_base); - const int state_receiver = state_table_receiver_index(state_prev, receiver_local, receiver, plan.receivers); - constexpr int kStatsStorageDim = reduction_storage_dim(); - float lane_stats[kStatsStorageDim]; - CellCore::init_reduction_stats(lane_stats); - CellCore::forward_state_chunk( - b, - state_receiver, - receiver, - receiver_chunk_local, - reset_row, - sub_lane, - subwarp_size, - staged_state_static, - cell_params, - state_prev, - state_next, - projected_in, - projected_message_dim, - h0, - h_count, - lane_stats, - &aux); - if (h0 < raw_public_dim) { - float* group_stats = - stats_scratch + static_cast(logical_warp) * subwarp_size * CellCore::kReductionStatsDim; - reduce_lane_group_stats_to_buffer(group_stats, lane_stats, group_stats, sub_lane); - __syncwarp(); - const int emit_count = min(hidden_chunk, raw_public_dim - h0); - float* raw_public_out = - raw_public + (static_cast(b) * plan.receivers + receiver_local) * raw_public_dim; - const void* staged_emit_static = - staged_emit_static_ptr(r_local, plan, emit_static_bytes, emit_static_base); - CellCore::emit_public_chunk( - b, - state_receiver, - receiver, - sub_lane, - subwarp_size, - staged_emit_static, - cell_params, - state_next, - raw_public_out, - raw_public_dim, - h0, - emit_count, - group_stats); - } - } -} - -template < - typename CellCore, - int R_TILE, - int B_TILE, - int H_CHUNK, - int LANE_GROUP_SIZE, - CellStaticStageMode STATE_STAGE_MODE, - CellStaticStageMode EMIT_STAGE_MODE> -__global__ void receiver_state_update_emit_only_kernel( - const float* projected_message, - TensorTable state_prev, - TensorTable cell_params, - TensorTable aux, - float* raw_public, - ExecutionPlan plan, - const uint8_t* resets_u8, - bool state_prev_is_zero, - int t, - int projected_message_dim, - int raw_public_dim, - int state_static_bytes, - int emit_static_bytes, - int receiver_offset, - int receiver_global_offset, - int receiver_count) { - static_assert( - LANE_GROUP_SIZE == 8 || LANE_GROUP_SIZE == 16 || LANE_GROUP_SIZE == 32, - "emit-only state/update supports 8-, 16-, or 32-lane groups"); - static_assert(H_CHUNK <= LANE_GROUP_SIZE, "emit-only state/update requires H_CHUNK <= lane group size"); - static_assert(CellCore::kReductionStatsDim > 0, "emit-only state/update requires reduction stats"); - static_assert(CellCore::kSupportsEmitOnlyStateUpdate, "cell core must declare emit-only state/update support"); - constexpr int lane_group_size = LANE_GROUP_SIZE; - constexpr int lane_groups_per_warp = kWarpSize / lane_group_size; - const int lane = threadIdx.x; - const int lane_group = lane / lane_group_size; - const int group_lane = lane - lane_group * lane_group_size; - const int logical_warp = threadIdx.y * lane_groups_per_warp + lane_group; - constexpr int receiver_tile = R_TILE; - constexpr int batch_tile = B_TILE; - constexpr int hidden_chunk = H_CHUNK; - const int r_local = logical_warp / batch_tile; - const int b_local = logical_warp % batch_tile; - const int receiver_chunk_local = blockIdx.x * receiver_tile + r_local; - const int receiver_local = receiver_offset + receiver_chunk_local; - const int receiver = receiver_global_offset + receiver_local; - const int h0 = blockIdx.z * hidden_chunk; - extern __shared__ unsigned char smem[]; - const 
size_t state_static_smem = - plan.stage_receiver_static && STATE_STAGE_MODE != CellStaticStageMode::Disabled && state_static_bytes > 0 - ? static_cast(receiver_tile) * static_cast(state_static_bytes) - : 0; - const size_t state_static_aligned = align_float_bytes(state_static_smem); - const size_t emit_static_smem = - plan.stage_receiver_static && EMIT_STAGE_MODE != CellStaticStageMode::Disabled && emit_static_bytes > 0 - ? static_cast(receiver_tile) * static_cast(emit_static_bytes) - : 0; - const size_t emit_static_aligned = align_float_bytes(emit_static_smem); - const size_t stats_scratch_offset = state_static_aligned + emit_static_aligned; - unsigned char* state_static_base = smem; - unsigned char* emit_static_base = smem + state_static_aligned; - float* stats_scratch = reinterpret_cast(smem + stats_scratch_offset); - - stage_receiver_state_static_warp( - receiver, - receiver_local < plan.receivers, - b_local, - r_local, - group_lane, - cell_params, - plan, - state_static_bytes, - state_static_base, - lane_group_size); - stage_receiver_emit_static_warp( - receiver, - receiver_local < plan.receivers, - b_local, - r_local, - group_lane, - cell_params, - plan, - emit_static_bytes, - emit_static_base, - lane_group_size); - __syncthreads(); - - for (int bt = blockIdx.y; bt < ceil_div(plan.B, batch_tile); bt += gridDim.y) { - const int b = bt * batch_tile + b_local; - if (r_local >= receiver_tile || receiver_chunk_local >= receiver_count || receiver_local >= plan.receivers || - b >= plan.B) { - continue; - } - const int state_dim = CellCore::state_dim(projected_message_dim, raw_public_dim, state_prev); - if (h0 >= state_dim || h0 >= raw_public_dim) { - continue; - } - const int h_count = min(hidden_chunk, state_dim - h0); - const bool reset_row = - state_prev_is_zero || (resets_u8 != nullptr ? 
resets_u8[static_cast(b) * plan.T + t] != 0 : false); - const float* projected_in = - projected_message + (static_cast(b) * plan.receivers + receiver_local) * projected_message_dim; - const void* staged_state_static = - staged_state_static_ptr(r_local, plan, state_static_bytes, state_static_base); - const void* staged_emit_static = - staged_emit_static_ptr(r_local, plan, emit_static_bytes, emit_static_base); - const int state_receiver = state_table_receiver_index(state_prev, receiver_local, receiver, plan.receivers); - constexpr int kStatsStorageDim = reduction_storage_dim(); - float lane_stats[kStatsStorageDim]; - int h = -1; - float y = 0.0f; - CellCore::forward_state_lane_value( - b, - state_receiver, - receiver, - receiver_chunk_local, - reset_row, - group_lane, - lane_group_size, - staged_state_static, - cell_params, - state_prev, - projected_in, - projected_message_dim, - aux, - h0, - h_count, - lane_stats, - &h, - &y); - float* group_stats = - stats_scratch + static_cast(logical_warp) * lane_group_size * CellCore::kReductionStatsDim; - reduce_lane_group_stats_to_buffer(group_stats, lane_stats, group_stats, group_lane); - __syncwarp(); - if (h >= 0 && h < raw_public_dim) { - float* raw_public_out = - raw_public + (static_cast(b) * plan.receivers + receiver_local) * raw_public_dim; - CellCore::emit_public_lane_value(receiver, h, y, staged_emit_static, cell_params, raw_public_out, group_stats); - } - } -} - -template -__global__ void receiver_reduce_stats_kernel( - TensorTable state_next, - const float* partial_stats, - float* reduced_stats, - ExecutionPlan plan, - int projected_message_dim, - int raw_public_dim, - int num_hidden_chunks) { - if constexpr (CellCore::kReductionStatsDim > 0) { - const int warp = threadIdx.y; - constexpr int receiver_tile = R_TILE; - constexpr int batch_tile = B_TILE; - const int r_local = warp / batch_tile; - const int b_local = warp % batch_tile; - const int receiver = blockIdx.x * receiver_tile + r_local; - extern __shared__ unsigned char smem[]; - float* stats_base = reinterpret_cast(smem); - float* warp_stats = stats_base + static_cast(warp) * kWarpSize * CellCore::kReductionStatsDim; - float* scratch_stats = stats_base + - static_cast(receiver_tile * batch_tile) * kWarpSize * CellCore::kReductionStatsDim + - static_cast(warp) * CellCore::kReductionStatsDim; - - for (int bt = blockIdx.y; bt < ceil_div(plan.B, batch_tile); bt += gridDim.y) { - const int b = bt * batch_tile + b_local; - if (r_local >= receiver_tile || receiver >= plan.receivers || b >= plan.B) { - continue; - } - const int64_t base = static_cast(b) * plan.receivers + receiver; - const float* partial = partial_stats + base * num_hidden_chunks * CellCore::kReductionStatsDim; - float* reduced = reduced_stats + base * CellCore::kReductionStatsDim; - const int state_dim = CellCore::state_dim(projected_message_dim, raw_public_dim, state_next); - reduce_chunks_to_buffer( - warp_stats, - scratch_stats, - partial, - reduced, - num_hidden_chunks, - state_dim); - } - } -} - -template -__global__ void receiver_emit_raw_public_kernel( - TensorTable state_next, - TensorTable cell_params, - float* raw_public, - const float* reduced_stats, - ExecutionPlan plan, - int raw_public_dim, - int emit_static_bytes, - int receiver_global_offset) { - const int warp = threadIdx.y; - const int lane = threadIdx.x; - constexpr int receiver_tile = R_TILE; - constexpr int batch_tile = B_TILE; - constexpr int hidden_chunk = H_CHUNK; - const int r_local = warp / batch_tile; - const int b_local = warp % batch_tile; - 
const int receiver_local = blockIdx.x * receiver_tile + r_local; - const int receiver = receiver_global_offset + receiver_local; - const int h0 = blockIdx.z * hidden_chunk; - extern __shared__ unsigned char smem[]; - unsigned char* static_base = smem; - - stage_receiver_emit_static_warp( - receiver, - receiver_local < plan.receivers, - b_local, - r_local, - lane, - cell_params, - plan, - emit_static_bytes, - static_base); - __syncthreads(); - - for (int bt = blockIdx.y; bt < ceil_div(plan.B, batch_tile); bt += gridDim.y) { - const int b = bt * batch_tile + b_local; - if (r_local >= receiver_tile || receiver_local >= plan.receivers || b >= plan.B || h0 >= raw_public_dim) { - continue; - } - const int h_count = min(hidden_chunk, raw_public_dim - h0); - const void* staged_static = staged_emit_static_ptr(r_local, plan, emit_static_bytes, static_base); - const int state_receiver = state_table_receiver_index(state_next, receiver_local, receiver, plan.receivers); - float* raw_public_out = - raw_public + (static_cast(b) * plan.receivers + receiver_local) * raw_public_dim; - const float* stats = nullptr; - if constexpr (CellCore::kReductionStatsDim > 0) { - stats = - reduced_stats + (static_cast(b) * plan.receivers + receiver_local) * CellCore::kReductionStatsDim; - } - CellCore::emit_public_chunk( - b, - state_receiver, - receiver, - lane, - kWarpSize, - staged_static, - cell_params, - state_next, - raw_public_out, - raw_public_dim, - h0, - h_count, - stats); - } -} - -template -void launch_receiver_message_aggregate_concrete( - TensorTable public_prev, - TensorTable message_params, - TensorTable input_ports, - MessageTopology topo, - float* message_out, - ExecutionPlan plan, - const at::Tensor& resets_u8, - int t, - cudaStream_t stream) { - const dim3 block(kWarpSize, R_TILE * B_TILE); - const dim3 grid(ceil_div(plan.receivers, R_TILE), plan.replication_factor > 0 ? plan.replication_factor : 1); - receiver_message_aggregate_kernel<<>>( - public_prev, - message_params, - input_ports, - topo, - message_out, - plan, - resets_u8.defined() ? resets_u8.data_ptr() : nullptr, - t); -} - -template -void launch_receiver_message_aggregate_variant( - TensorTable public_prev, - TensorTable message_params, - TensorTable input_ports, - MessageTopology topo, - float* message_out, - ExecutionPlan plan, - const at::Tensor& resets_u8, - int t, - cudaStream_t stream) { -#define FABRIC_LAUNCH_MESSAGE_AGG_TILE(R_TILE, B_TILE) \ - if (plan.receiver_tile == R_TILE && plan.batch_tile == B_TILE) { \ - launch_receiver_message_aggregate_concrete( \ - public_prev, message_params, input_ports, topo, message_out, plan, resets_u8, t, stream); \ - return; \ - } - FABRIC_LAUNCH_MESSAGE_AGG_TILE(4, 2) - FABRIC_LAUNCH_MESSAGE_AGG_TILE(2, 4) - FABRIC_LAUNCH_MESSAGE_AGG_TILE(2, 2) - FABRIC_LAUNCH_MESSAGE_AGG_TILE(4, 1) - FABRIC_LAUNCH_MESSAGE_AGG_TILE(8, 1) -#undef FABRIC_LAUNCH_MESSAGE_AGG_TILE - fail_unsupported_tile("receiver message aggregate", plan.receiver_tile, plan.batch_tile); -} - -template -void launch_receiver_normalize_accumulated_concrete( - const float* msg_buffer, - float* message_out, - ExecutionPlan plan, - cudaStream_t stream) { - const dim3 block(kWarpSize, R_TILE * B_TILE); - const dim3 grid(ceil_div(plan.receivers, R_TILE), plan.replication_factor > 0 ? 
plan.replication_factor : 1); - receiver_normalize_accumulated_message_kernel<<>>( - msg_buffer, - message_out, - plan); -} - -inline void launch_receiver_normalize_accumulated_variant( - const float* msg_buffer, - float* message_out, - ExecutionPlan plan, - cudaStream_t stream) { -#define FABRIC_LAUNCH_NORMALIZE_TILE(R_TILE, B_TILE) \ - if (plan.receiver_tile == R_TILE && plan.batch_tile == B_TILE) { \ - launch_receiver_normalize_accumulated_concrete(msg_buffer, message_out, plan, stream); \ - return; \ - } - FABRIC_LAUNCH_NORMALIZE_TILE(4, 2) - FABRIC_LAUNCH_NORMALIZE_TILE(2, 4) - FABRIC_LAUNCH_NORMALIZE_TILE(2, 2) - FABRIC_LAUNCH_NORMALIZE_TILE(4, 1) - FABRIC_LAUNCH_NORMALIZE_TILE(8, 1) -#undef FABRIC_LAUNCH_NORMALIZE_TILE - fail_unsupported_tile("receiver accumulated-message normalize", plan.receiver_tile, plan.batch_tile); -} - -template -void launch_state_update_concrete( - const float* projected_message, - TensorTable state_prev, - TensorTable state_next, - TensorTable cell_params, - TensorTable aux, - float* partial_stats, - ExecutionPlan plan, - const at::Tensor& resets_u8, - bool state_prev_is_zero, - int t, - int projected_message_dim, - int raw_public_dim, - int state_static_bytes, - int num_hidden_chunks, - int receiver_offset, - int receiver_global_offset, - int receiver_count, - cudaStream_t stream) { - const auto* props = at::cuda::getCurrentDeviceProperties(); - ExecutionPlan launch_plan = plan; - launch_plan.state_receiver_tile = R_TILE; - launch_plan.state_batch_tile = B_TILE; - launch_plan.state_hidden_chunk = H_CHUNK; - launch_plan.state_static_stage_mode = STAGE_MODE; - const dim3 block(kWarpSize, R_TILE * B_TILE); - const dim3 grid( - ceil_div(receiver_count, R_TILE), - plan.replication_factor > 0 ? plan.replication_factor : 1, - num_hidden_chunks); - const size_t state_static_smem = - plan.stage_receiver_static && STAGE_MODE != CellStaticStageMode::Disabled && state_static_bytes > 0 - ? static_cast(R_TILE) * static_cast(state_static_bytes) - : 0; - const size_t stats_scratch_bytes = CellCore::kReductionStatsDim > 0 - ? static_cast(R_TILE * B_TILE) * kWarpSize * CellCore::kReductionStatsDim * sizeof(float) - : 0; - const size_t state_update_smem = align_float_bytes(state_static_smem) + stats_scratch_bytes; - set_dynamic_smem_if_needed( - receiver_state_update_kernel, - state_update_smem, - props); - receiver_state_update_kernel - <<>>( - projected_message, - state_prev, - state_next, - cell_params, - aux, - partial_stats, - launch_plan, - resets_u8.defined() ? 
resets_u8.data_ptr() : nullptr, - state_prev_is_zero, - t, - projected_message_dim, - raw_public_dim, - state_static_bytes, - num_hidden_chunks, - receiver_offset, - receiver_global_offset, - receiver_count); - check_stepwise_launch("receiver_state_update_kernel"); -} - -template -void launch_state_update_emit_concrete( - const float* projected_message, - TensorTable state_prev, - TensorTable state_next, - TensorTable cell_params, - TensorTable aux, - float* raw_public, - ExecutionPlan plan, - const at::Tensor& resets_u8, - bool state_prev_is_zero, - bool materialize_state_output, - int t, - int projected_message_dim, - int raw_public_dim, - int state_static_bytes, - int emit_static_bytes, - int num_hidden_chunks, - int receiver_offset, - int receiver_global_offset, - int receiver_count, - cudaStream_t stream) { - if constexpr (CellCore::kReductionStatsDim > 0) { - TORCH_CHECK( - num_hidden_chunks <= 1, - "Fabric fused state/update emit with reduction stats requires a single hidden chunk"); - } - const auto* props = at::cuda::getCurrentDeviceProperties(); - ExecutionPlan launch_plan = plan; - launch_plan.state_receiver_tile = R_TILE; - launch_plan.state_batch_tile = B_TILE; - launch_plan.state_hidden_chunk = H_CHUNK; - launch_plan.state_static_stage_mode = STATE_STAGE_MODE; - launch_plan.emit_receiver_tile = R_TILE; - launch_plan.emit_batch_tile = B_TILE; - launch_plan.emit_hidden_chunk = H_CHUNK; - launch_plan.emit_static_stage_mode = EMIT_STAGE_MODE; - const dim3 block(kWarpSize, R_TILE * B_TILE); - const dim3 grid( - ceil_div(receiver_count, R_TILE), - plan.replication_factor > 0 ? plan.replication_factor : 1, - num_hidden_chunks); - const size_t state_static_smem = - plan.stage_receiver_static && STATE_STAGE_MODE != CellStaticStageMode::Disabled && state_static_bytes > 0 - ? static_cast(R_TILE) * static_cast(state_static_bytes) - : 0; - const size_t emit_static_smem = - plan.stage_receiver_static && EMIT_STAGE_MODE != CellStaticStageMode::Disabled && emit_static_bytes > 0 - ? static_cast(R_TILE) * static_cast(emit_static_bytes) - : 0; - const size_t stats_scratch_bytes = CellCore::kReductionStatsDim > 0 - ? static_cast(R_TILE * B_TILE) * kWarpSize * CellCore::kReductionStatsDim * sizeof(float) - : 0; - const size_t fused_smem = - align_float_bytes(state_static_smem) + align_float_bytes(emit_static_smem) + stats_scratch_bytes; - if constexpr (CellCore::kSupportsEmitOnlyStateUpdate && CellCore::kReductionStatsDim > 0) { - if (!materialize_state_output && num_hidden_chunks <= 1) { - constexpr int lane_group_size = H_CHUNK <= 8 ? 8 : H_CHUNK <= 16 ? 16 : 32; - constexpr int lane_groups_per_warp = kWarpSize / lane_group_size; - const dim3 emit_only_block(kWarpSize, (R_TILE * B_TILE + lane_groups_per_warp - 1) / lane_groups_per_warp); - const size_t emit_only_stats_scratch_bytes = - static_cast(R_TILE * B_TILE) * lane_group_size * CellCore::kReductionStatsDim * sizeof(float); - const size_t emit_only_smem = - align_float_bytes(state_static_smem) + align_float_bytes(emit_static_smem) + emit_only_stats_scratch_bytes; - set_dynamic_smem_if_needed( - receiver_state_update_emit_only_kernel< - CellCore, - R_TILE, - B_TILE, - H_CHUNK, - lane_group_size, - STATE_STAGE_MODE, - EMIT_STAGE_MODE>, - emit_only_smem, - props); - receiver_state_update_emit_only_kernel< - CellCore, - R_TILE, - B_TILE, - H_CHUNK, - lane_group_size, - STATE_STAGE_MODE, - EMIT_STAGE_MODE><<>>( - projected_message, - state_prev, - cell_params, - aux, - raw_public, - launch_plan, - resets_u8.defined() ? 
resets_u8.data_ptr() : nullptr, - state_prev_is_zero, - t, - projected_message_dim, - raw_public_dim, - state_static_bytes, - emit_static_bytes, - receiver_offset, - receiver_global_offset, - receiver_count); - check_stepwise_launch("receiver_state_update_emit_only_kernel"); - return; - } - } - if constexpr (CellCore::kReductionStatsDim > 0 && H_CHUNK <= 16) { - constexpr int subwarp_size = H_CHUNK <= 8 ? 8 : 16; - constexpr int subwarps_per_warp = kWarpSize / subwarp_size; - const dim3 subwarp_block(kWarpSize, (R_TILE * B_TILE + subwarps_per_warp - 1) / subwarps_per_warp); - const size_t subwarp_stats_scratch_bytes = - static_cast(R_TILE * B_TILE) * subwarp_size * CellCore::kReductionStatsDim * sizeof(float); - const size_t subwarp_smem = - align_float_bytes(state_static_smem) + align_float_bytes(emit_static_smem) + subwarp_stats_scratch_bytes; - set_dynamic_smem_if_needed( - receiver_state_update_emit_subwarp_kernel< - CellCore, - R_TILE, - B_TILE, - H_CHUNK, - subwarp_size, - STATE_STAGE_MODE, - EMIT_STAGE_MODE>, - subwarp_smem, - props); - receiver_state_update_emit_subwarp_kernel< - CellCore, - R_TILE, - B_TILE, - H_CHUNK, - subwarp_size, - STATE_STAGE_MODE, - EMIT_STAGE_MODE><<>>( - projected_message, - state_prev, - state_next, - cell_params, - aux, - raw_public, - launch_plan, - resets_u8.defined() ? resets_u8.data_ptr() : nullptr, - state_prev_is_zero, - t, - projected_message_dim, - raw_public_dim, - state_static_bytes, - emit_static_bytes, - num_hidden_chunks, - receiver_offset, - receiver_global_offset, - receiver_count); - check_stepwise_launch("receiver_state_update_emit_subwarp_kernel"); - return; - } - set_dynamic_smem_if_needed( - receiver_state_update_emit_kernel, - fused_smem, - props); - receiver_state_update_emit_kernel - <<>>( - projected_message, - state_prev, - state_next, - cell_params, - aux, - raw_public, - launch_plan, - resets_u8.defined() ? resets_u8.data_ptr() : nullptr, - state_prev_is_zero, - t, - projected_message_dim, - raw_public_dim, - state_static_bytes, - emit_static_bytes, - num_hidden_chunks, - receiver_offset, - receiver_global_offset, - receiver_count); - check_stepwise_launch("receiver_state_update_emit_kernel"); -} - -template -void launch_reduce_stats_concrete( - TensorTable state_next, - const float* partial_stats, - float* reduced_stats, - ExecutionPlan plan, - int projected_message_dim, - int raw_public_dim, - int num_hidden_chunks, - cudaStream_t stream) { - if constexpr (CellCore::kReductionStatsDim > 0) { - const auto* props = at::cuda::getCurrentDeviceProperties(); - const dim3 block(kWarpSize, R_TILE * B_TILE); - const dim3 grid(ceil_div(plan.receivers, R_TILE), plan.replication_factor > 0 ? 
plan.replication_factor : 1); - const size_t reduce_smem = - static_cast(R_TILE * B_TILE) * (kWarpSize + 1) * CellCore::kReductionStatsDim * sizeof(float); - set_dynamic_smem_if_needed(receiver_reduce_stats_kernel, reduce_smem, props); - receiver_reduce_stats_kernel<<>>( - state_next, - partial_stats, - reduced_stats, - plan, - projected_message_dim, - raw_public_dim, - num_hidden_chunks); - } -} - -template -void launch_state_update_variant_for_stage( - const float* projected_message, - TensorTable state_prev, - TensorTable state_next, - TensorTable cell_params, - TensorTable aux, - float* partial_stats, - float* reduced_stats, - ExecutionPlan plan, - const at::Tensor& resets_u8, - bool state_prev_is_zero, - int t, - int projected_message_dim, - int raw_public_dim, - int state_static_bytes, - int num_hidden_chunks, - int receiver_offset, - int receiver_global_offset, - int receiver_count, - cudaStream_t stream) { - (void)reduced_stats; -#define FABRIC_LAUNCH_STATE_TILE(R_TILE, B_TILE, H_CHUNK) \ - if (plan.state_receiver_tile == R_TILE && plan.state_batch_tile == B_TILE && \ - plan.state_hidden_chunk == H_CHUNK) { \ - launch_state_update_concrete( \ - projected_message, \ - state_prev, \ - state_next, \ - cell_params, \ - aux, \ - partial_stats, \ - plan, \ - resets_u8, \ - state_prev_is_zero, \ - t, \ - projected_message_dim, \ - raw_public_dim, \ - state_static_bytes, \ - num_hidden_chunks, \ - receiver_offset, \ - receiver_global_offset, \ - receiver_count, \ - stream); \ - return; \ - } - FABRIC_LAUNCH_STATE_TILE(4, 2, 32) - FABRIC_LAUNCH_STATE_TILE(2, 4, 32) - FABRIC_LAUNCH_STATE_TILE(2, 2, 32) - FABRIC_LAUNCH_STATE_TILE(4, 1, 32) - FABRIC_LAUNCH_STATE_TILE(4, 2, 16) - FABRIC_LAUNCH_STATE_TILE(2, 2, 16) - FABRIC_LAUNCH_STATE_TILE(2, 1, 16) - FABRIC_LAUNCH_STATE_TILE(4, 2, 8) - FABRIC_LAUNCH_STATE_TILE(2, 2, 8) - FABRIC_LAUNCH_STATE_TILE(2, 1, 8) -#undef FABRIC_LAUNCH_STATE_TILE - fail_unsupported_tile("receiver state", plan.state_receiver_tile, plan.state_batch_tile, plan.state_hidden_chunk); -} - -template -void launch_state_update_variant( - const float* projected_message, - TensorTable state_prev, - TensorTable state_next, - TensorTable cell_params, - TensorTable aux, - float* partial_stats, - float* reduced_stats, - ExecutionPlan plan, - const at::Tensor& resets_u8, - bool state_prev_is_zero, - int t, - int projected_message_dim, - int raw_public_dim, - int state_static_bytes, - int num_hidden_chunks, - int receiver_offset, - int receiver_global_offset, - int receiver_count, - cudaStream_t stream) { - if (plan.state_static_stage_mode == CellStaticStageMode::Disabled) { - launch_state_update_variant_for_stage( - projected_message, - state_prev, - state_next, - cell_params, - aux, - partial_stats, - reduced_stats, - plan, - resets_u8, - state_prev_is_zero, - t, - projected_message_dim, - raw_public_dim, - state_static_bytes, - num_hidden_chunks, - receiver_offset, - receiver_global_offset, - receiver_count, - stream); - return; - } - TORCH_CHECK( - plan.state_static_stage_mode == CellStaticStageMode::SharedFull, - "Unsupported Fabric receiver state static stage mode"); - launch_state_update_variant_for_stage( - projected_message, - state_prev, - state_next, - cell_params, - aux, - partial_stats, - reduced_stats, - plan, - resets_u8, - state_prev_is_zero, - t, - projected_message_dim, - raw_public_dim, - state_static_bytes, - num_hidden_chunks, - receiver_offset, - receiver_global_offset, - receiver_count, - stream); -} - -template -void 
launch_state_update_emit_variant_for_stage_pair( - const float* projected_message, - TensorTable state_prev, - TensorTable state_next, - TensorTable cell_params, - TensorTable aux, - float* raw_public, - ExecutionPlan plan, - const at::Tensor& resets_u8, - bool state_prev_is_zero, - bool materialize_state_output, - int t, - int projected_message_dim, - int raw_public_dim, - int state_static_bytes, - int emit_static_bytes, - int num_hidden_chunks, - int receiver_offset, - int receiver_global_offset, - int receiver_count, - cudaStream_t stream) { -#define FABRIC_LAUNCH_STATE_EMIT_TILE(R_TILE, B_TILE, H_CHUNK) \ - if (plan.state_receiver_tile == R_TILE && plan.state_batch_tile == B_TILE && \ - plan.state_hidden_chunk == H_CHUNK) { \ - launch_state_update_emit_concrete( \ - projected_message, \ - state_prev, \ - state_next, \ - cell_params, \ - aux, \ - raw_public, \ - plan, \ - resets_u8, \ - state_prev_is_zero, \ - materialize_state_output, \ - t, \ - projected_message_dim, \ - raw_public_dim, \ - state_static_bytes, \ - emit_static_bytes, \ - num_hidden_chunks, \ - receiver_offset, \ - receiver_global_offset, \ - receiver_count, \ - stream); \ - return; \ - } - FABRIC_LAUNCH_STATE_EMIT_TILE(4, 2, 32) - FABRIC_LAUNCH_STATE_EMIT_TILE(2, 4, 32) - FABRIC_LAUNCH_STATE_EMIT_TILE(2, 2, 32) - FABRIC_LAUNCH_STATE_EMIT_TILE(4, 1, 32) - FABRIC_LAUNCH_STATE_EMIT_TILE(4, 2, 16) - FABRIC_LAUNCH_STATE_EMIT_TILE(2, 2, 16) - FABRIC_LAUNCH_STATE_EMIT_TILE(2, 1, 16) - FABRIC_LAUNCH_STATE_EMIT_TILE(4, 2, 8) - FABRIC_LAUNCH_STATE_EMIT_TILE(2, 2, 8) - FABRIC_LAUNCH_STATE_EMIT_TILE(2, 1, 8) -#undef FABRIC_LAUNCH_STATE_EMIT_TILE - fail_unsupported_tile("receiver state emit", plan.state_receiver_tile, plan.state_batch_tile, plan.state_hidden_chunk); -} - -template -void launch_state_update_emit_variant( - const float* projected_message, - TensorTable state_prev, - TensorTable state_next, - TensorTable cell_params, - TensorTable aux, - float* raw_public, - ExecutionPlan plan, - const at::Tensor& resets_u8, - bool state_prev_is_zero, - bool materialize_state_output, - int t, - int projected_message_dim, - int raw_public_dim, - int state_static_bytes, - int emit_static_bytes, - int num_hidden_chunks, - int receiver_offset, - int receiver_global_offset, - int receiver_count, - cudaStream_t stream) { - if (plan.state_static_stage_mode == CellStaticStageMode::Disabled) { - if (plan.emit_static_stage_mode == CellStaticStageMode::Disabled) { - launch_state_update_emit_variant_for_stage_pair< - CellCore, - CellStaticStageMode::Disabled, - CellStaticStageMode::Disabled>( - projected_message, - state_prev, - state_next, - cell_params, - aux, - raw_public, - plan, - resets_u8, - state_prev_is_zero, - materialize_state_output, - t, - projected_message_dim, - raw_public_dim, - state_static_bytes, - emit_static_bytes, - num_hidden_chunks, - receiver_offset, - receiver_global_offset, - receiver_count, - stream); - return; - } - launch_state_update_emit_variant_for_stage_pair< - CellCore, - CellStaticStageMode::Disabled, - CellStaticStageMode::SharedFull>( - projected_message, - state_prev, - state_next, - cell_params, - aux, - raw_public, - plan, - resets_u8, - state_prev_is_zero, - materialize_state_output, - t, - projected_message_dim, - raw_public_dim, - state_static_bytes, - emit_static_bytes, - num_hidden_chunks, - receiver_offset, - receiver_global_offset, - receiver_count, - stream); - return; - } - TORCH_CHECK( - plan.state_static_stage_mode == CellStaticStageMode::SharedFull, - "Unsupported Fabric fused receiver state 
static stage mode"); - if (plan.emit_static_stage_mode == CellStaticStageMode::Disabled) { - launch_state_update_emit_variant_for_stage_pair< - CellCore, - CellStaticStageMode::SharedFull, - CellStaticStageMode::Disabled>( - projected_message, - state_prev, - state_next, - cell_params, - aux, - raw_public, - plan, - resets_u8, - state_prev_is_zero, - materialize_state_output, - t, - projected_message_dim, - raw_public_dim, - state_static_bytes, - emit_static_bytes, - num_hidden_chunks, - receiver_offset, - receiver_global_offset, - receiver_count, - stream); - return; - } - TORCH_CHECK( - plan.emit_static_stage_mode == CellStaticStageMode::SharedFull, - "Unsupported Fabric fused receiver emit static stage mode"); - launch_state_update_emit_variant_for_stage_pair< - CellCore, - CellStaticStageMode::SharedFull, - CellStaticStageMode::SharedFull>( - projected_message, - state_prev, - state_next, - cell_params, - aux, - raw_public, - plan, - resets_u8, - state_prev_is_zero, - materialize_state_output, - t, - projected_message_dim, - raw_public_dim, - state_static_bytes, - emit_static_bytes, - num_hidden_chunks, - receiver_offset, - receiver_global_offset, - receiver_count, - stream); -} - -template -void launch_emit_raw_public_concrete( - TensorTable state_next, - TensorTable cell_params, - float* raw_public, - const float* reduced_stats, - ExecutionPlan plan, - int raw_public_dim, - int emit_static_bytes, - int receiver_global_offset, - cudaStream_t stream) { - const auto* props = at::cuda::getCurrentDeviceProperties(); - ExecutionPlan launch_plan = plan; - launch_plan.emit_receiver_tile = R_TILE; - launch_plan.emit_batch_tile = B_TILE; - launch_plan.emit_hidden_chunk = H_CHUNK; - launch_plan.emit_static_stage_mode = STAGE_MODE; - const int emit_hidden_chunks = ceil_div(std::max(1, raw_public_dim), H_CHUNK); - const dim3 block(kWarpSize, R_TILE * B_TILE); - const dim3 grid( - ceil_div(plan.receivers, R_TILE), - plan.replication_factor > 0 ? plan.replication_factor : 1, - emit_hidden_chunks); - const size_t emit_static_smem = - plan.stage_receiver_static && STAGE_MODE != CellStaticStageMode::Disabled && emit_static_bytes > 0 - ? 
static_cast(R_TILE) * static_cast(emit_static_bytes) - : 0; - set_dynamic_smem_if_needed( - receiver_emit_raw_public_kernel, - emit_static_smem, - props); - receiver_emit_raw_public_kernel - <<>>( - state_next, - cell_params, - raw_public, - reduced_stats, - launch_plan, - raw_public_dim, - emit_static_bytes, - receiver_global_offset); - check_stepwise_launch("receiver_emit_raw_public_kernel"); -} - -template -void launch_emit_raw_public_variant_for_stage( - TensorTable state_next, - TensorTable cell_params, - float* raw_public, - const float* reduced_stats, - ExecutionPlan plan, - int raw_public_dim, - int emit_static_bytes, - int receiver_global_offset, - cudaStream_t stream) { -#define FABRIC_LAUNCH_EMIT_TILE(R_TILE, B_TILE, H_CHUNK) \ - if (plan.emit_receiver_tile == R_TILE && plan.emit_batch_tile == B_TILE && \ - plan.emit_hidden_chunk == H_CHUNK) { \ - launch_emit_raw_public_concrete( \ - state_next, \ - cell_params, \ - raw_public, \ - reduced_stats, \ - plan, \ - raw_public_dim, \ - emit_static_bytes, \ - receiver_global_offset, \ - stream); \ - return; \ - } - FABRIC_LAUNCH_EMIT_TILE(4, 4, 32) - FABRIC_LAUNCH_EMIT_TILE(4, 2, 32) - FABRIC_LAUNCH_EMIT_TILE(2, 4, 32) -#undef FABRIC_LAUNCH_EMIT_TILE - fail_unsupported_tile("receiver emit", plan.emit_receiver_tile, plan.emit_batch_tile, plan.emit_hidden_chunk); -} - -template -void launch_reduce_stats_variant( - TensorTable state_next, - const float* partial_stats, - float* reduced_stats, - ExecutionPlan plan, - int projected_message_dim, - int raw_public_dim, - int num_hidden_chunks, - cudaStream_t stream) { -#define FABRIC_LAUNCH_REDUCE_TILE(R_TILE, B_TILE, H_CHUNK) \ - if (plan.state_receiver_tile == R_TILE && plan.state_batch_tile == B_TILE && \ - plan.state_hidden_chunk == H_CHUNK) { \ - launch_reduce_stats_concrete( \ - state_next, \ - partial_stats, \ - reduced_stats, \ - plan, \ - projected_message_dim, \ - raw_public_dim, \ - num_hidden_chunks, \ - stream); \ - return; \ - } - FABRIC_LAUNCH_REDUCE_TILE(4, 2, 32) - FABRIC_LAUNCH_REDUCE_TILE(2, 4, 32) - FABRIC_LAUNCH_REDUCE_TILE(2, 2, 32) - FABRIC_LAUNCH_REDUCE_TILE(4, 1, 32) - FABRIC_LAUNCH_REDUCE_TILE(4, 2, 16) - FABRIC_LAUNCH_REDUCE_TILE(2, 2, 16) - FABRIC_LAUNCH_REDUCE_TILE(2, 1, 16) - FABRIC_LAUNCH_REDUCE_TILE(4, 2, 8) - FABRIC_LAUNCH_REDUCE_TILE(2, 2, 8) - FABRIC_LAUNCH_REDUCE_TILE(2, 1, 8) -#undef FABRIC_LAUNCH_REDUCE_TILE - fail_unsupported_tile("receiver reduce", plan.state_receiver_tile, plan.state_batch_tile, plan.state_hidden_chunk); -} - -template -void launch_emit_raw_public_variant( - TensorTable state_next, - TensorTable cell_params, - float* raw_public, - const float* reduced_stats, - ExecutionPlan plan, - int raw_public_dim, - int emit_static_bytes, - int receiver_global_offset, - cudaStream_t stream) { - if (plan.emit_static_stage_mode == CellStaticStageMode::Disabled) { - launch_emit_raw_public_variant_for_stage( - state_next, - cell_params, - raw_public, - reduced_stats, - plan, - raw_public_dim, - emit_static_bytes, - receiver_global_offset, - stream); - return; - } - TORCH_CHECK( - plan.emit_static_stage_mode == CellStaticStageMode::SharedFull, - "Unsupported Fabric receiver emit static stage mode"); - launch_emit_raw_public_variant_for_stage( - state_next, - cell_params, - raw_public, - reduced_stats, - plan, - raw_public_dim, - emit_static_bytes, - receiver_global_offset, - stream); -} - -template -void launch_receiver_state_update_typed( - const float* projected_message, - TensorTable state_prev, - TensorTable state_next, - TensorTable cell_params, - 
TensorTable aux, - ExecutionPlan plan, - const at::Tensor& resets_u8, - bool state_prev_is_zero, - int t, - int projected_message_dim, - int raw_public_dim, - int state_static_bytes, - float* partial_stats, - int num_hidden_chunks, - int receiver_offset, - int receiver_global_offset, - int receiver_count, - cudaStream_t stream) { - validate_receiver_launch(plan); - TORCH_CHECK(receiver_offset >= 0, "receiver state update offset must be non-negative"); - TORCH_CHECK(receiver_count >= 0, "receiver state update count must be non-negative"); - TORCH_CHECK( - receiver_offset + receiver_count <= plan.receivers, - "receiver state update window exceeds receiver count"); - launch_state_update_variant( - projected_message, - state_prev, - state_next, - cell_params, - aux, - partial_stats, - nullptr, - plan, - resets_u8, - state_prev_is_zero, - t, - projected_message_dim, - raw_public_dim, - state_static_bytes, - num_hidden_chunks, - receiver_offset, - receiver_global_offset, - receiver_count, - stream); -} - -template -void launch_receiver_state_update_emit_typed( - const float* projected_message, - TensorTable state_prev, - TensorTable state_next, - TensorTable cell_params, - TensorTable aux, - ExecutionPlan plan, - const at::Tensor& resets_u8, - bool state_prev_is_zero, - bool materialize_state_output, - int t, - int projected_message_dim, - int raw_public_dim, - int state_static_bytes, - int emit_static_bytes, - float* raw_public, - int num_hidden_chunks, - int receiver_offset, - int receiver_global_offset, - int receiver_count, - cudaStream_t stream) { - validate_receiver_launch(plan); - TORCH_CHECK(receiver_offset >= 0, "receiver state update/emit offset must be non-negative"); - TORCH_CHECK(receiver_count >= 0, "receiver state update/emit count must be non-negative"); - TORCH_CHECK( - receiver_offset + receiver_count <= plan.receivers, - "receiver state update/emit window exceeds receiver count"); - launch_state_update_emit_variant( - projected_message, - state_prev, - state_next, - cell_params, - aux, - raw_public, - plan, - resets_u8, - state_prev_is_zero, - materialize_state_output, - t, - projected_message_dim, - raw_public_dim, - state_static_bytes, - emit_static_bytes, - num_hidden_chunks, - receiver_offset, - receiver_global_offset, - receiver_count, - stream); -} - -template -void launch_receiver_reduce_emit_typed( - TensorTable state_next, - TensorTable cell_params, - ExecutionPlan plan, - int projected_message_dim, - int raw_public_dim, - int emit_static_bytes, - float* raw_public, - float* partial_stats, - float* reduced_stats, - int num_hidden_chunks, - int receiver_global_offset, - cudaStream_t stream) { - validate_receiver_launch(plan); - const float* emit_stats = reduced_stats; - if constexpr (CellCore::kReductionStatsDim > 0) { - if (num_hidden_chunks <= 1) { - emit_stats = partial_stats; - } else { - launch_reduce_stats_variant( - state_next, - partial_stats, - reduced_stats, - plan, - projected_message_dim, - raw_public_dim, - num_hidden_chunks, - stream); - } - } - launch_emit_raw_public_variant( - state_next, - cell_params, - raw_public, - emit_stats, - plan, - raw_public_dim, - emit_static_bytes, - receiver_global_offset, - stream); -} - -template -void launch_receiver_state_emit_typed( - const float* projected_message, - TensorTable state_prev, - TensorTable state_next, - TensorTable cell_params, - TensorTable aux, - ExecutionPlan plan, - const at::Tensor& resets_u8, - int t, - int projected_message_dim, - int raw_public_dim, - int state_static_bytes, - int 
emit_static_bytes, - int state_epilogue_policy, - float* raw_public, - float* partial_stats, - float* reduced_stats, - int num_hidden_chunks, - cudaStream_t stream) { - validate_receiver_launch(plan); - if constexpr (CellCore::kReductionStatsDim > 0) { - if (num_hidden_chunks <= 1) { - launch_state_update_emit_variant( - projected_message, - state_prev, - state_next, - cell_params, - aux, - raw_public, - plan, - resets_u8, - false, - true, - t, - projected_message_dim, - raw_public_dim, - state_static_bytes, - emit_static_bytes, - num_hidden_chunks, - 0, - 0, - plan.receivers, - stream); - return; - } - } - if constexpr (CellCore::kReductionStatsDim == 0) { - if (static_cast(state_epilogue_policy) == - fabric::cuda::nn::StateEpiloguePolicy::FusedNoReductionSameChunk) { - launch_state_update_emit_variant( - projected_message, - state_prev, - state_next, - cell_params, - aux, - raw_public, - plan, - resets_u8, - false, - true, - t, - projected_message_dim, - raw_public_dim, - state_static_bytes, - emit_static_bytes, - num_hidden_chunks, - 0, - 0, - plan.receivers, - stream); - return; - } - } - launch_state_update_variant( - projected_message, - state_prev, - state_next, - cell_params, - aux, - partial_stats, - reduced_stats, - plan, - resets_u8, - false, - t, - projected_message_dim, - raw_public_dim, - state_static_bytes, - num_hidden_chunks, - 0, - 0, - plan.receivers, - stream); - const float* emit_stats = reduced_stats; - if constexpr (CellCore::kReductionStatsDim > 0) { - if (num_hidden_chunks <= 1) { - emit_stats = partial_stats; - } else { - launch_reduce_stats_variant( - state_next, - partial_stats, - reduced_stats, - plan, - projected_message_dim, - raw_public_dim, - num_hidden_chunks, - stream); - } - } - launch_emit_raw_public_variant( - state_next, - cell_params, - raw_public, - emit_stats, - plan, - raw_public_dim, - emit_static_bytes, - 0, - stream); -} - -} // namespace stepwise_detail - -} // namespace fabric diff --git a/src/cortical/fabric/backend/cuda/execution/registry.py b/src/cortical/fabric/backend/cuda/execution/registry.py deleted file mode 100644 index 42f24596..00000000 --- a/src/cortical/fabric/backend/cuda/execution/registry.py +++ /dev/null @@ -1,248 +0,0 @@ -from __future__ import annotations - -from collections.abc import Callable, Mapping, Sequence -from dataclasses import dataclass -from typing import Any, Literal - -import torch - -from cortical.fabric.backend.reuse import MathBackend -from cortical.fabric.contracts.cells import CellSpec - -SpatialOwnership = Literal["receiver_owned", "edge_owned"] -TemporalExecution = Literal["stepwise", "persistent_scan"] - -SUPPORTED_READOUT_MODES = frozenset({"skip", "separate_port_owned"}) -SUPPORTED_CELL_STATIC_STAGE_MODES = frozenset({"disabled", "shared_full"}) -SUPPORTED_MESSAGE_RULE_LOWERINGS = frozenset({"dot_product_segment_softmax_weighted_sum"}) -SUPPORTED_MESSAGE_RULE_OUTPUT_BOUNDARIES = frozenset({"projected_message"}) - - -@dataclass(frozen=True) -class ExecutionVariantSpec: - spatial_ownership: SpatialOwnership - temporal_execution: TemporalExecution - math_backend: MathBackend - - -@dataclass(frozen=True) -class FabricExecutionRequest: - population_name: str - cell_core_spec: CellSpec - message_backend_name: str - message_rule_name: str - message_rule_lowering_kind: str - message_rule_expression_signature: str - message_rule_source_signature: str - message_rule_parameter_sharing_signature: str - message_rule_output_boundary: str - readout_backend_name: str - gradient_enabled: bool - input_k_seq: 
torch.Tensor - input_v_seq: torch.Tensor - packed_state: Any - initial_hidden: torch.Tensor - initial_recurrent_k: torch.Tensor | None - initial_recurrent_v: torch.Tensor | None - initial_state_is_fresh: bool - materialize_final_state: bool - resets_u8: torch.Tensor - reset_rows_present: bool - stage_receiver_static: bool - replication_factor: int - receiver_tile: int - batch_tile: int - edge_tile: int - hidden_chunk: int - state_receiver_tile: int - state_batch_tile: int - state_hidden_chunk: int - state_static_stage_mode: str - emit_receiver_tile: int - emit_batch_tile: int - emit_hidden_chunk: int - emit_static_stage_mode: str - public_receiver_tile: int - public_batch_tile: int - readout_mode: str - readout_port_tile: int - readout_output_chunk: int - cell_static_stage_mode: str - routing_tensors: Mapping[str, torch.Tensor] - cell_tensors: Mapping[str, torch.Tensor] - readout_tensors: Mapping[str, torch.Tensor] - static_config: Mapping[str, Any] - compact_input_carry: bool = False - preserve_internal_carry: bool = False - - def __post_init__(self) -> None: - if self.readout_mode not in SUPPORTED_READOUT_MODES: - raise ValueError( - f"Unsupported Fabric CUDA readout_mode={self.readout_mode!r}; " - f"supported modes are {sorted(SUPPORTED_READOUT_MODES)}" - ) - if self.cell_static_stage_mode not in SUPPORTED_CELL_STATIC_STAGE_MODES: - raise ValueError( - f"Unsupported Fabric CUDA cell_static_stage_mode={self.cell_static_stage_mode!r}; " - f"supported modes are {sorted(SUPPORTED_CELL_STATIC_STAGE_MODES)}" - ) - if self.state_static_stage_mode not in SUPPORTED_CELL_STATIC_STAGE_MODES: - raise ValueError( - f"Unsupported Fabric CUDA state_static_stage_mode={self.state_static_stage_mode!r}; " - f"supported modes are {sorted(SUPPORTED_CELL_STATIC_STAGE_MODES)}" - ) - if self.emit_static_stage_mode not in SUPPORTED_CELL_STATIC_STAGE_MODES: - raise ValueError( - f"Unsupported Fabric CUDA emit_static_stage_mode={self.emit_static_stage_mode!r}; " - f"supported modes are {sorted(SUPPORTED_CELL_STATIC_STAGE_MODES)}" - ) - if self.message_rule_lowering_kind not in SUPPORTED_MESSAGE_RULE_LOWERINGS: - raise ValueError( - f"Unsupported Fabric CUDA message_rule_lowering_kind={self.message_rule_lowering_kind!r}; " - f"supported lowerings are {sorted(SUPPORTED_MESSAGE_RULE_LOWERINGS)}" - ) - if self.message_rule_output_boundary not in SUPPORTED_MESSAGE_RULE_OUTPUT_BOUNDARIES: - raise ValueError( - f"Unsupported Fabric CUDA message_rule_output_boundary={self.message_rule_output_boundary!r}; " - f"supported boundaries are {sorted(SUPPORTED_MESSAGE_RULE_OUTPUT_BOUNDARIES)}" - ) - - -@dataclass(frozen=True) -class ForwardCarryCheckpoints: - stride: int - steps: tuple[int, ...] - state_tensors: tuple[torch.Tensor, ...] - public_tensors: tuple[torch.Tensor, ...] - rebuild_state: Callable[[Sequence[torch.Tensor]], Any] - state_names: tuple[str, ...] 
= () - - -@dataclass(frozen=True) -class ExecutionRegistryEntry: - variant: ExecutionVariantSpec - runner: Callable[[FabricExecutionRequest], tuple[Any, ...]] - supports: Callable[[FabricExecutionRequest], bool] | None = None - - -_EXECUTION_REGISTRY: list[ExecutionRegistryEntry] = [] - -_LOCAL_MESSAGE_ROUTING_KEYS = frozenset( - { - "recurrent_q", - "recurrent_local_sender_idx", - "recurrent_local_valid", - "local_distance", - "local_delay", - } -) -_SPARSE_MESSAGE_ROUTING_KEYS = frozenset( - { - "recurrent_q", - "recurrent_neighbor_idx", - "recurrent_neighbor_valid", - "recurrent_edge_distance", - "recurrent_edge_delay", - "recurrent_sparse_receiver_order", - "recurrent_sparse_degree_ptr", - } -) -_LOCAL_READOUT_ROUTING_KEYS = frozenset( - { - "output_q", - "output_local_sender_idx", - "output_local_valid", - "local_distance", - "local_delay", - } -) -_SPARSE_READOUT_ROUTING_KEYS = frozenset( - { - "output_q", - "output_neighbor_idx", - "output_neighbor_valid", - "output_edge_distance", - "output_edge_delay", - } -) -_READOUT_TENSOR_KEYS = frozenset({"output_projection_weight", "output_projection_bias"}) - - -def register_execution_backend(entry: ExecutionRegistryEntry) -> None: - _EXECUTION_REGISTRY.append(entry) - - -def required_cell_tensor_keys(request: FabricExecutionRequest) -> frozenset[str]: - return frozenset( - tuple(request.cell_core_spec.parameter_schema.keys) - + tuple(request.cell_core_spec.input_projection_schema.keys) - + tuple(request.cell_core_spec.public_projection_schema.keys) - ) - - -def request_has_generic_dispatch_contract(request: FabricExecutionRequest) -> bool: - if request.message_rule_lowering_kind not in SUPPORTED_MESSAGE_RULE_LOWERINGS: - return False - if request.message_rule_output_boundary not in SUPPORTED_MESSAGE_RULE_OUTPUT_BOUNDARIES: - return False - if not required_cell_tensor_keys(request).issubset(request.cell_tensors.keys()): - return False - if not _READOUT_TENSOR_KEYS.issubset(request.readout_tensors.keys()): - return False - if request.message_backend_name == "local": - message_keys = _LOCAL_MESSAGE_ROUTING_KEYS - elif request.message_backend_name == "sparse": - message_keys = _SPARSE_MESSAGE_ROUTING_KEYS - else: - return False - if request.readout_backend_name == "output_sequence_from_banks": - readout_keys = _LOCAL_READOUT_ROUTING_KEYS - elif request.readout_backend_name == "output_sequence_from_sparse_banks": - readout_keys = _SPARSE_READOUT_ROUTING_KEYS - else: - return False - return message_keys.issubset(request.routing_tensors.keys()) and readout_keys.issubset( - request.routing_tensors.keys() - ) - - -def run_registered_execution( - *, - spatial_ownership: SpatialOwnership, - temporal_execution: TemporalExecution, - math_backend: MathBackend, - request: FabricExecutionRequest, -) -> tuple[Any, ...]: - for entry in _EXECUTION_REGISTRY: - if entry.variant.spatial_ownership != spatial_ownership: - continue - if entry.variant.temporal_execution != temporal_execution: - continue - if entry.variant.math_backend != math_backend: - continue - if entry.supports is not None and not entry.supports(request): - continue - return entry.runner(request) - raise RuntimeError( - "No registered Fabric execution backend for " - f"population={request.population_name} spatial={spatial_ownership} " - f"temporal={temporal_execution} math={math_backend.value}" - ) - - -__all__ = [ - "ExecutionVariantSpec", - "ExecutionRegistryEntry", - "FabricExecutionRequest", - "ForwardCarryCheckpoints", - "SUPPORTED_CELL_STATIC_STAGE_MODES", - 
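The registry above keeps an ordered list of ExecutionRegistryEntry records, matches on the (spatial_ownership, temporal_execution, math_backend) variant plus an optional supports() predicate, and raises if nothing matches. A self-contained sketch of that registration/dispatch pattern; the names below are local stand-ins, not the removed module's API.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class Variant:
    spatial_ownership: str   # e.g. "receiver_owned" | "edge_owned"
    temporal_execution: str  # e.g. "stepwise" | "persistent_scan"
    math_backend: str        # e.g. "cuda_native" | "triton"


@dataclass(frozen=True)
class Entry:
    variant: Variant
    runner: Callable[[Any], tuple]
    supports: Callable[[Any], bool] | None = None


_REGISTRY: list[Entry] = []


def register(entry: Entry) -> None:
    _REGISTRY.append(entry)


def run(variant: Variant, request: Any) -> tuple:
    for entry in _REGISTRY:
        if entry.variant != variant:
            continue
        if entry.supports is not None and not entry.supports(request):
            continue
        return entry.runner(request)
    raise RuntimeError(f"No registered Fabric execution backend for {variant}")


# Register a stepwise CUDA-native runner and dispatch to it; unmatched variants fail closed.
register(Entry(Variant("receiver_owned", "stepwise", "cuda_native"), runner=lambda req: (req,)))
result = run(Variant("receiver_owned", "stepwise", "cuda_native"), {"population": "axon"})
```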
"SUPPORTED_MESSAGE_RULE_LOWERINGS", - "SUPPORTED_MESSAGE_RULE_OUTPUT_BOUNDARIES", - "SUPPORTED_READOUT_MODES", - "SpatialOwnership", - "TemporalExecution", - "request_has_generic_dispatch_contract", - "required_cell_tensor_keys", - "register_execution_backend", - "run_registered_execution", -] diff --git a/src/cortical/fabric/backend/cuda/execution/tensor_pack.py b/src/cortical/fabric/backend/cuda/execution/tensor_pack.py deleted file mode 100644 index f90a339e..00000000 --- a/src/cortical/fabric/backend/cuda/execution/tensor_pack.py +++ /dev/null @@ -1,99 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from dataclasses import dataclass -from typing import Any - -import torch -from tensordict import TensorDictBase - - -@dataclass(frozen=True) -class PackedTensorTable: - ptrs: torch.Tensor - sizes: torch.Tensor - strides: torch.Tensor - ndims: torch.Tensor - count: int - - def as_extension_tuple(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - return (self.ptrs, self.sizes, self.strides, self.ndims) - - -def flatten_tensor_tree(tree: Any) -> tuple[tuple[str, ...] | None, tuple[torch.Tensor, ...]]: - if torch.is_tensor(tree): - return None, (tree,) - if isinstance(tree, TensorDictBase): - keys = tuple(str(key) for key in tree.keys()) - return keys, tuple(tree[key] for key in keys) - if isinstance(tree, Mapping): - keys = tuple(str(key) for key in tree.keys()) - return keys, tuple(tree[key] for key in keys) - if isinstance(tree, Sequence) and not isinstance(tree, (str, bytes)): - return None, tuple(tree) - raise TypeError(f"Unsupported tensor tree type {type(tree).__name__}") - - -def pack_tensor_tree(tree: Any) -> tuple[tuple[str, ...] | None, tuple[torch.Tensor, ...], PackedTensorTable]: - keys, leaves = flatten_tensor_tree(tree) - if not leaves: - raise ValueError("Tensor pack requires at least one tensor leaf") - device = leaves[0].device - if any(not torch.is_tensor(tensor) for tensor in leaves): - raise TypeError("Tensor pack only supports tensor leaves") - if any(tensor.device != device for tensor in leaves): - raise ValueError("Tensor pack requires all leaves on the same device") - with torch.profiler.record_function("fabric.glue.tensor_table_pack"): - ptrs = torch.tensor([int(tensor.data_ptr()) for tensor in leaves], device=device, dtype=torch.int64) - sizes = torch.tensor( - [[*(list(tensor.shape) + [1] * (4 - tensor.dim()))][:4] for tensor in leaves], - device=device, - dtype=torch.int64, - ) - strides = torch.tensor( - [[*(list(tensor.stride()) + [1] * (4 - tensor.dim()))][:4] for tensor in leaves], - device=device, - dtype=torch.int64, - ) - ndims = torch.tensor([tensor.dim() for tensor in leaves], device=device, dtype=torch.int32) - return ( - keys, - leaves, - PackedTensorTable( - ptrs=ptrs, - sizes=sizes, - strides=strides, - ndims=ndims, - count=len(leaves), - ), - ) - - -def empty_packed_tensor_table(*, device: torch.device) -> PackedTensorTable: - with torch.profiler.record_function("fabric.glue.empty_tensor_table_pack"): - return PackedTensorTable( - ptrs=torch.empty(0, device=device, dtype=torch.int64), - sizes=torch.empty(0, 4, device=device, dtype=torch.int64), - strides=torch.empty(0, 4, device=device, dtype=torch.int64), - ndims=torch.empty(0, device=device, dtype=torch.int32), - count=0, - ) - - -def rebuild_tensor_tree(keys: tuple[str, ...] 
| None, leaves: Sequence[torch.Tensor]) -> Any: - if keys is None: - if len(leaves) == 1: - return leaves[0] - return tuple(leaves) - if len(keys) != len(leaves): - raise ValueError("Tensor-tree rebuild requires matching key and leaf counts") - return {key: leaf for key, leaf in zip(keys, leaves, strict=True)} - - -__all__ = [ - "PackedTensorTable", - "empty_packed_tensor_table", - "flatten_tensor_tree", - "pack_tensor_tree", - "rebuild_tensor_tree", -] diff --git a/src/cortical/fabric/backend/cuda/message_passing/__init__.py b/src/cortical/fabric/backend/cuda/message_passing/__init__.py deleted file mode 100644 index 0309101d..00000000 --- a/src/cortical/fabric/backend/cuda/message_passing/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -from cortical.fabric.backend.cuda.message_passing.local_message_cuda import ( - fabric_local_message_backward_receiver_cuda, - fabric_local_message_backward_sender_cuda, - fabric_local_message_cuda, - fabric_local_message_partitioned_backward_fused_cuda, - fabric_local_message_partitioned_backward_receiver_cuda, - fabric_local_message_partitioned_backward_sender_cuda, - fabric_local_message_partitioned_cuda, -) -from cortical.fabric.backend.cuda.message_passing.registry import get_message_backend, register_message_backend -from cortical.fabric.backend.cuda.message_passing.sparse_message_cuda import ( - fabric_sparse_message_backward_receiver_cuda, - fabric_sparse_message_backward_sender_cuda, - fabric_sparse_message_cuda, - fabric_sparse_message_partitioned_backward_receiver_cuda, - fabric_sparse_message_partitioned_backward_sender_cuda, - fabric_sparse_message_partitioned_cuda, -) - -register_message_backend("local_partitioned", fabric_local_message_partitioned_cuda) -register_message_backend("sparse_partitioned", fabric_sparse_message_partitioned_cuda) - -__all__ = [ - "fabric_local_message_backward_receiver_cuda", - "fabric_local_message_backward_sender_cuda", - "fabric_local_message_cuda", - "fabric_local_message_partitioned_backward_fused_cuda", - "fabric_local_message_partitioned_backward_receiver_cuda", - "fabric_local_message_partitioned_backward_sender_cuda", - "fabric_local_message_partitioned_cuda", - "get_message_backend", - "register_message_backend", - "fabric_sparse_message_backward_receiver_cuda", - "fabric_sparse_message_backward_sender_cuda", - "fabric_sparse_message_cuda", - "fabric_sparse_message_partitioned_backward_receiver_cuda", - "fabric_sparse_message_partitioned_backward_sender_cuda", - "fabric_sparse_message_partitioned_cuda", -] diff --git a/src/cortical/fabric/backend/cuda/message_passing/local_message_backend.cuh b/src/cortical/fabric/backend/cuda/message_passing/local_message_backend.cuh deleted file mode 100644 index 8a3726cc..00000000 --- a/src/cortical/fabric/backend/cuda/message_passing/local_message_backend.cuh +++ /dev/null @@ -1,267 +0,0 @@ -#pragma once - -#include - -#include "cortical/fabric/backend/cuda/execution/common.cuh" - -namespace fabric { - -namespace detail { - -__device__ inline float read_input_port_value( - const TensorRef& tensor, - int b, - int t, - int sender, - int d) { - if (tensor.ndim == 4) { - return tensor.at(b, t, sender, d); - } - return tensor.at(b, sender, d); -} - -__device__ inline int find_receiver_for_edge( - const MessageTopology& topo, - int receiver_count, - int edge_idx) { - int receiver = 0; - while (receiver + 1 < receiver_count && topo.receiver_ptr[receiver + 1] <= edge_idx) { - ++receiver; - } - return receiver; -} - -} // namespace detail - -struct LocalMessageBackend { - 
__device__ static void aggregate_receiver_step_warp( - int b, - int t, - int receiver, - bool reset_row, - const MessageTopology& topo, - const TensorTable& input_ports, - const TensorTable& public_prev, - const TensorTable& msg_params, - float* msg_out, - int msg_dim, - int lane) { - const auto input_k = tensor_ref(input_ports, 0); - const auto input_v = tensor_ref(input_ports, 1); - const auto recurrent_k = tensor_ref(public_prev, 1); - const auto recurrent_v = tensor_ref(public_prev, 2); - const auto q = tensor_ref(msg_params, 0); - const int head_dim = static_cast(q.size[1]); - const int edge_begin = topo.receiver_ptr[receiver]; - const int edge_end = topo.receiver_ptr[receiver + 1]; - const float inv_sqrt_dk = rsqrtf(static_cast(head_dim > 0 ? head_dim : 1)); - float max_logit = -INFINITY; - for (int edge = edge_begin; edge < edge_end; ++edge) { - const int sender = topo.sender_idx[edge]; - float dot = 0.0f; - if (sender < topo.num_input_ports) { - for (int d = lane; d < head_dim; d += kWarpSize) { - dot += q.at(receiver, d) * detail::read_input_port_value(input_k, b, t, sender, d); - } - } else { - const int recurrent_sender = sender - topo.num_input_ports; - for (int d = lane; d < head_dim; d += kWarpSize) { - dot += q.at(receiver, d) * (reset_row ? 0.0f : recurrent_k.at(b, recurrent_sender, d)); - } - } - const float penalty = topo.edge_weight == nullptr ? 0.0f : topo.edge_weight[edge]; - max_logit = fmaxf(max_logit, warp_sum(dot) * inv_sqrt_dk - penalty); - } - - for (int d = lane; d < msg_dim; d += kWarpSize) { - msg_out[d] = 0.0f; - } - float norm = 0.0f; - for (int edge = edge_begin; edge < edge_end; ++edge) { - const int sender = topo.sender_idx[edge]; - float dot = 0.0f; - if (sender < topo.num_input_ports) { - for (int d = lane; d < head_dim; d += kWarpSize) { - dot += q.at(receiver, d) * detail::read_input_port_value(input_k, b, t, sender, d); - } - } else { - const int recurrent_sender = sender - topo.num_input_ports; - for (int d = lane; d < head_dim; d += kWarpSize) { - dot += q.at(receiver, d) * (reset_row ? 0.0f : recurrent_k.at(b, recurrent_sender, d)); - } - } - const float penalty = topo.edge_weight == nullptr ? 0.0f : topo.edge_weight[edge]; - const float weight = expf(warp_sum(dot) * inv_sqrt_dk - penalty - max_logit); - norm += weight; - if (sender < topo.num_input_ports) { - for (int d = lane; d < msg_dim; d += kWarpSize) { - msg_out[d] += weight * detail::read_input_port_value(input_v, b, t, sender, d); - } - } else { - const int recurrent_sender = sender - topo.num_input_ports; - for (int d = lane; d < msg_dim; d += kWarpSize) { - msg_out[d] += weight * (reset_row ? 0.0f : recurrent_v.at(b, recurrent_sender, d)); - } - } - } - if (norm > 0.0f) { - const float inv_norm = 1.0f / norm; - for (int d = lane; d < msg_dim; d += kWarpSize) { - msg_out[d] *= inv_norm; - } - } - } - - __device__ static void edge_contrib_step( - int b, - int t, - int edge_idx, - bool reset_row, - const MessageTopology& topo, - const TensorTable& input_ports, - const TensorTable& public_prev, - const TensorTable& msg_params, - float* contrib_out, - int msg_dim, - int* receiver_out) { - const float logit = edge_logit_step( - b, - t, - edge_idx, - reset_row, - topo, - input_ports, - public_prev, - msg_params, - receiver_out); - edge_value_step( - b, - t, - edge_idx, - reset_row, - topo, - input_ports, - public_prev, - contrib_out, - msg_dim); - const float weight = isfinite(logit) ? 
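aggregate_receiver_step_warp above computes, per receiver, scaled dot-product logits against each edge's sender key, subtracts the optional edge penalty, takes a max-stabilized softmax over that receiver's edges, and emits the weighted sum of sender values. A dense PyTorch restatement of that per-receiver math; reset handling and the input/recurrent sender split are left out.

```python
import torch


def aggregate_receiver(q_r: torch.Tensor, k_edges: torch.Tensor, v_edges: torch.Tensor,
                       edge_penalty: torch.Tensor | None = None) -> torch.Tensor:
    """q_r: [d_k]; k_edges: [E, d_k]; v_edges: [E, d_v]; edge_penalty: [E] or None."""
    d_k = q_r.shape[0]
    logits = (k_edges @ q_r) / (d_k ** 0.5)          # q . k / sqrt(d_k) per edge
    if edge_penalty is not None:
        logits = logits - edge_penalty               # topology/distance penalty
    weights = torch.softmax(logits, dim=0)           # max-stabilized, as in the kernel
    return weights @ v_edges                         # softmax-weighted sum of values, [d_v]
```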
expf(logit) : 0.0f; - for (int d = 0; d < msg_dim; ++d) { - contrib_out[d] *= weight; - } - } - - __device__ static float edge_logit_step( - int b, - int t, - int edge_idx, - bool reset_row, - const MessageTopology& topo, - const TensorTable& input_ports, - const TensorTable& public_prev, - const TensorTable& msg_params, - int* receiver_out) { - const auto q = tensor_ref(msg_params, 0); - const int receiver_count = static_cast(q.size[0]); - const int receiver = detail::find_receiver_for_edge(topo, receiver_count, edge_idx); - *receiver_out = receiver; - const auto input_k = tensor_ref(input_ports, 0); - const auto recurrent_k = tensor_ref(public_prev, 1); - const int sender = topo.sender_idx[edge_idx]; - const int head_dim = static_cast(q.size[1]); - const float inv_sqrt_dk = rsqrtf(static_cast(head_dim > 0 ? head_dim : 1)); - float dot = 0.0f; - if (sender < topo.num_input_ports) { - for (int d = 0; d < head_dim; ++d) { - dot += q.at(receiver, d) * detail::read_input_port_value(input_k, b, t, sender, d); - } - } else { - const int recurrent_sender = sender - topo.num_input_ports; - for (int d = 0; d < head_dim; ++d) { - dot += q.at(receiver, d) * (reset_row ? 0.0f : recurrent_k.at(b, recurrent_sender, d)); - } - } - const float penalty = topo.edge_weight == nullptr ? 0.0f : topo.edge_weight[edge_idx]; - return dot * inv_sqrt_dk - penalty; - } - - __device__ static float edge_logit_step_warp( - int b, - int t, - int edge_idx, - bool reset_row, - const MessageTopology& topo, - const TensorTable& input_ports, - const TensorTable& public_prev, - const TensorTable& msg_params, - int* receiver_out, - int lane) { - const auto q = tensor_ref(msg_params, 0); - const int receiver_count = static_cast(q.size[0]); - const int receiver = detail::find_receiver_for_edge(topo, receiver_count, edge_idx); - *receiver_out = receiver; - const auto input_k = tensor_ref(input_ports, 0); - const auto recurrent_k = tensor_ref(public_prev, 1); - const int sender = topo.sender_idx[edge_idx]; - const int head_dim = static_cast(q.size[1]); - const float inv_sqrt_dk = rsqrtf(static_cast(head_dim > 0 ? head_dim : 1)); - float dot = 0.0f; - if (sender < topo.num_input_ports) { - for (int d = lane; d < head_dim; d += kWarpSize) { - dot += q.at(receiver, d) * detail::read_input_port_value(input_k, b, t, sender, d); - } - } else { - const int recurrent_sender = sender - topo.num_input_ports; - for (int d = lane; d < head_dim; d += kWarpSize) { - dot += q.at(receiver, d) * (reset_row ? 0.0f : recurrent_k.at(b, recurrent_sender, d)); - } - } - const float penalty = topo.edge_weight == nullptr ? 0.0f : topo.edge_weight[edge_idx]; - return warp_sum(dot) * inv_sqrt_dk - penalty; - } - - __device__ static void edge_value_step( - int b, - int t, - int edge_idx, - bool reset_row, - const MessageTopology& topo, - const TensorTable& input_ports, - const TensorTable& public_prev, - float* value_out, - int msg_dim) { - const auto input_v = tensor_ref(input_ports, 1); - const auto recurrent_v = tensor_ref(public_prev, 2); - const int sender = topo.sender_idx[edge_idx]; - if (sender < topo.num_input_ports) { - for (int d = 0; d < msg_dim; ++d) { - value_out[d] = detail::read_input_port_value(input_v, b, t, sender, d); - } - } else { - const int recurrent_sender = sender - topo.num_input_ports; - for (int d = 0; d < msg_dim; ++d) { - value_out[d] = reset_row ? 
0.0f : recurrent_v.at(b, recurrent_sender, d); - } - } - } - - __device__ static float edge_value_component_step( - int b, - int t, - int edge_idx, - bool reset_row, - const MessageTopology& topo, - const TensorTable& input_ports, - const TensorTable& public_prev, - int d) { - const auto input_v = tensor_ref(input_ports, 1); - const auto recurrent_v = tensor_ref(public_prev, 2); - const int sender = topo.sender_idx[edge_idx]; - if (sender < topo.num_input_ports) { - return detail::read_input_port_value(input_v, b, t, sender, d); - } - const int recurrent_sender = sender - topo.num_input_ports; - return reset_row ? 0.0f : recurrent_v.at(b, recurrent_sender, d); - } -}; - -} // namespace fabric diff --git a/src/cortical/fabric/backend/cuda/message_passing/local_message_binding.cpp b/src/cortical/fabric/backend/cuda/message_passing/local_message_binding.cpp deleted file mode 100644 index 7b15701f..00000000 --- a/src/cortical/fabric/backend/cuda/message_passing/local_message_binding.cpp +++ /dev/null @@ -1,107 +0,0 @@ -#include - -std::vector fabric_local_message_forward_cuda( - at::Tensor q, at::Tensor k_all, at::Tensor v_all, - at::Tensor receiver_sender_idx, at::Tensor sender_receiver_idx, - at::Tensor offset_distance, at::Tensor offset_delay, at::Tensor step_flat, - double distance_scale, bool use_delay); - -std::vector fabric_local_message_forward_partitioned_cuda( - at::Tensor q, - at::Tensor input_k, - at::Tensor input_v, - at::Tensor recurrent_k, - at::Tensor recurrent_v, - at::Tensor receiver_sender_idx, - at::Tensor offset_distance, - at::Tensor offset_delay, - at::Tensor step_flat, - int64_t num_input_senders, - double distance_scale, - bool use_delay); - -std::vector fabric_local_message_backward_cuda( - at::Tensor grad_msg, at::Tensor q, at::Tensor k_all, at::Tensor v_all, - at::Tensor receiver_sender_idx, at::Tensor sender_receiver_idx, - at::Tensor offset_distance, at::Tensor offset_delay, at::Tensor step_flat, - double distance_scale, bool use_delay); - -std::vector fabric_local_message_backward_receiver_cuda( - at::Tensor grad_msg, at::Tensor q, at::Tensor k_all, at::Tensor v_all, - at::Tensor receiver_sender_idx, at::Tensor offset_distance, - at::Tensor offset_delay, at::Tensor step_flat, double distance_scale, - bool use_delay); - -std::vector fabric_local_message_backward_receiver_partitioned_cuda( - at::Tensor grad_msg, - at::Tensor q, - at::Tensor input_k, - at::Tensor input_v, - at::Tensor recurrent_k, - at::Tensor recurrent_v, - at::Tensor receiver_sender_idx, - at::Tensor offset_distance, - at::Tensor offset_delay, - at::Tensor step_flat, - double distance_scale, - bool use_delay); - -std::vector fabric_local_message_backward_sender_cuda( - at::Tensor grad_msg, at::Tensor q, at::Tensor k_all, at::Tensor v_all, - at::Tensor sender_receiver_idx, at::Tensor offset_distance, - at::Tensor offset_delay, at::Tensor step_flat, - at::Tensor receiver_max_logit, at::Tensor receiver_sumexp, - at::Tensor receiver_weighted_sum, double distance_scale, bool use_delay); - -std::vector fabric_local_message_backward_sender_partitioned_cuda( - at::Tensor grad_msg, - at::Tensor q, - at::Tensor input_k, - at::Tensor input_v, - at::Tensor recurrent_k, - at::Tensor recurrent_v, - at::Tensor sender_receiver_idx, - at::Tensor offset_distance, - at::Tensor offset_delay, - at::Tensor step_flat, - at::Tensor receiver_max_logit, - at::Tensor receiver_sumexp, - at::Tensor receiver_weighted_sum, - double distance_scale, - bool use_delay); - -std::vector 
fabric_local_message_backward_partitioned_fused_cuda( - at::Tensor grad_msg, - at::Tensor q, - at::Tensor input_k, - at::Tensor input_v, - at::Tensor recurrent_k, - at::Tensor recurrent_v, - at::Tensor receiver_sender_idx, - at::Tensor offset_distance, - at::Tensor offset_delay, - at::Tensor step_flat, - double distance_scale, - bool use_delay); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &fabric_local_message_forward_cuda, - "Fabric local message forward (CUDA)"); - m.def("forward_partitioned", &fabric_local_message_forward_partitioned_cuda, - "Fabric local message forward for partitioned sender banks (CUDA)"); - m.def("backward", &fabric_local_message_backward_cuda, - "Fabric local message backward (CUDA)"); - m.def("backward_receiver", &fabric_local_message_backward_receiver_cuda, - "Fabric local message receiver-phase backward (CUDA)"); - m.def("backward_receiver_partitioned", - &fabric_local_message_backward_receiver_partitioned_cuda, - "Fabric local message receiver-phase backward for partitioned sender banks (CUDA)"); - m.def("backward_sender", &fabric_local_message_backward_sender_cuda, - "Fabric local message sender-phase backward (CUDA)"); - m.def("backward_sender_partitioned", - &fabric_local_message_backward_sender_partitioned_cuda, - "Fabric local message sender-phase backward for partitioned sender banks (CUDA)"); - m.def("backward_partitioned_fused", - &fabric_local_message_backward_partitioned_fused_cuda, - "Fabric local message fused backward for partitioned sender banks (CUDA)"); -} diff --git a/src/cortical/fabric/backend/cuda/message_passing/local_message_cuda.py b/src/cortical/fabric/backend/cuda/message_passing/local_message_cuda.py deleted file mode 100644 index 43104320..00000000 --- a/src/cortical/fabric/backend/cuda/message_passing/local_message_cuda.py +++ /dev/null @@ -1,616 +0,0 @@ -from __future__ import annotations - -import os - -import torch -from torch.autograd import Function - -from cortical.native.extension_loader import safe_load_extension - -_mod_path = os.path.dirname(__file__) -_ext = None - - -def _merge_partitioned_sender_banks( - input_k: torch.Tensor, - input_v: torch.Tensor, - recurrent_k: torch.Tensor, - recurrent_v: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - num_input_senders = int(input_k.shape[1]) - num_recurrent_senders = int(recurrent_k.shape[1]) - if num_input_senders == 0: - return recurrent_k, recurrent_v - if num_recurrent_senders == 0: - return input_k, input_v - k_all = torch.nn.functional.pad(input_k, (0, 0, 0, num_recurrent_senders)) + torch.nn.functional.pad( - recurrent_k, - (0, 0, num_input_senders, 0), - ) - v_all = torch.nn.functional.pad(input_v, (0, 0, 0, num_recurrent_senders)) + torch.nn.functional.pad( - recurrent_v, - (0, 0, num_input_senders, 0), - ) - return k_all, v_all - - -def _load_ext(): - global _ext - if _ext is not None: - return _ext - _ext = safe_load_extension( - name="fabric_local_message_cuda", - sources=[ - os.path.join(_mod_path, "local_message_binding.cpp"), - os.path.join(_mod_path, "local_message_kernels.cu"), - ], - extra_cflags=["-O3"], - extra_cuda_cflags=["-O3", "-Xptxas", "-O3"], - verbose=False, - ) - return _ext - - -class _FabricLocalMessageCUDA(Function): - @staticmethod - def forward( - q: torch.Tensor, - k_all: torch.Tensor, - v_all: torch.Tensor, - receiver_sender_idx: torch.Tensor, - sender_receiver_idx: torch.Tensor, - offset_distance: torch.Tensor, - offset_delay: torch.Tensor, - step_flat: torch.Tensor, - distance_scale: float, - use_delay: bool, - 
owner_tag: str, - ) -> torch.Tensor: - del owner_tag - (msg,) = _load_ext().forward( - q.contiguous(), - k_all.contiguous(), - v_all.contiguous(), - receiver_sender_idx.contiguous(), - sender_receiver_idx.contiguous(), - offset_distance.contiguous(), - offset_delay.contiguous(), - step_flat.contiguous(), - float(distance_scale), - bool(use_delay), - ) - return msg - - @staticmethod - def setup_context(ctx, inputs, output): - del output - ( - q, - k_all, - v_all, - receiver_sender_idx, - sender_receiver_idx, - offset_distance, - offset_delay, - step_flat, - distance_scale, - use_delay, - owner_tag, - ) = inputs - ctx.save_for_backward( - q, - k_all, - v_all, - receiver_sender_idx, - sender_receiver_idx, - offset_distance, - offset_delay, - step_flat, - ) - ctx.distance_scale = float(distance_scale) - ctx.use_delay = bool(use_delay) - ctx.profile_name = ( - f"fabric.backward.tiny_message_superop.{owner_tag}" - if os.environ.get("CORTICAL_FABRIC_BACKWARD_ATTRIBUTION_MODE") == "message_owner_probe" - else "fabric.backward.tiny_message_superop" - ) - - @staticmethod - def backward(ctx, grad_msg: torch.Tensor): - ( - q, - k_all, - v_all, - receiver_sender_idx, - sender_receiver_idx, - offset_distance, - offset_delay, - step_flat, - ) = ctx.saved_tensors - with torch.profiler.record_function(ctx.profile_name): - grad_q, grad_k, grad_v = _load_ext().backward( - grad_msg.contiguous(), - q.contiguous(), - k_all.contiguous(), - v_all.contiguous(), - receiver_sender_idx.contiguous(), - sender_receiver_idx.contiguous(), - offset_distance.contiguous(), - offset_delay.contiguous(), - step_flat.contiguous(), - ctx.distance_scale, - ctx.use_delay, - ) - return ( - grad_q, - grad_k, - grad_v, - None, - None, - None, - None, - None, - None, - None, - None, - ) - - -class _FabricLocalMessagePartitionedCUDA(Function): - @staticmethod - def forward( - q: torch.Tensor, - input_k: torch.Tensor, - input_v: torch.Tensor, - recurrent_k: torch.Tensor, - recurrent_v: torch.Tensor, - receiver_sender_idx: torch.Tensor, - sender_receiver_idx: torch.Tensor, - offset_distance: torch.Tensor, - offset_delay: torch.Tensor, - step_flat: torch.Tensor, - num_input_senders: int, - distance_scale: float, - use_delay: bool, - owner_tag: str, - ) -> torch.Tensor: - del owner_tag - (msg,) = _load_ext().forward_partitioned( - q.contiguous(), - input_k.contiguous(), - input_v.contiguous(), - recurrent_k.contiguous(), - recurrent_v.contiguous(), - receiver_sender_idx.contiguous(), - offset_distance.contiguous(), - offset_delay.contiguous(), - step_flat.contiguous(), - int(num_input_senders), - float(distance_scale), - bool(use_delay), - ) - return msg - - @staticmethod - def setup_context(ctx, inputs, output): - del output - ( - q, - input_k, - input_v, - recurrent_k, - recurrent_v, - receiver_sender_idx, - _sender_receiver_idx, - offset_distance, - offset_delay, - step_flat, - num_input_senders, - distance_scale, - use_delay, - owner_tag, - ) = inputs - ctx.save_for_backward( - q, - input_k, - input_v, - recurrent_k, - recurrent_v, - receiver_sender_idx, - offset_distance, - offset_delay, - step_flat, - ) - ctx.distance_scale = float(distance_scale) - ctx.use_delay = bool(use_delay) - ctx.profile_name = ( - f"fabric.backward.tiny_message_superop.{owner_tag}" - if os.environ.get("CORTICAL_FABRIC_BACKWARD_ATTRIBUTION_MODE") == "message_owner_probe" - else "fabric.backward.tiny_message_superop" - ) - - @staticmethod - def backward(ctx, grad_msg: torch.Tensor): - ( - q, - input_k, - input_v, - recurrent_k, - recurrent_v, - 
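The wrappers above use the two-phase torch.autograd.Function style: a ctx-free forward, tensor saving in setup_context, and a backward wrapped in torch.profiler.record_function so the owner shows up in warmed traces. A minimal sketch of that pattern on a stand-in op (y = x * w), not the Fabric message rule.

```python
import torch
from torch.autograd import Function


class ScaleFn(Function):
    @staticmethod
    def forward(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        return x * w

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, w = inputs
        ctx.save_for_backward(x, w)

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        x, w = ctx.saved_tensors
        with torch.profiler.record_function("fabric.backward.scale_example"):
            return grad_out * w, grad_out * x


x = torch.randn(3, requires_grad=True)
w = torch.randn(3, requires_grad=True)
ScaleFn.apply(x, w).sum().backward()
```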
receiver_sender_idx, - offset_distance, - offset_delay, - step_flat, - ) = ctx.saved_tensors - with torch.profiler.record_function(ctx.profile_name): - grad_q, grad_input_k, grad_input_v, grad_recurrent_k, grad_recurrent_v = ( - _load_ext().backward_partitioned_fused( - grad_msg.contiguous(), - q.contiguous(), - input_k.contiguous(), - input_v.contiguous(), - recurrent_k.contiguous(), - recurrent_v.contiguous(), - receiver_sender_idx.contiguous(), - offset_distance.contiguous(), - offset_delay.contiguous(), - step_flat.contiguous(), - ctx.distance_scale, - ctx.use_delay, - ) - ) - return ( - grad_q, - grad_input_k, - grad_input_v, - grad_recurrent_k, - grad_recurrent_v, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - - -def fabric_local_message_cuda( - q: torch.Tensor, - k_all: torch.Tensor, - v_all: torch.Tensor, - receiver_sender_idx: torch.Tensor, - sender_receiver_idx: torch.Tensor, - offset_distance: torch.Tensor, - offset_delay: torch.Tensor, - step_flat: torch.Tensor, - *, - distance_scale: float, - use_delay: bool, - owner_tag: str = "generic", -) -> torch.Tensor: - return _FabricLocalMessageCUDA.apply( - q, - k_all, - v_all, - receiver_sender_idx, - sender_receiver_idx, - offset_distance, - offset_delay, - step_flat, - distance_scale, - use_delay, - owner_tag, - ) - - -def fabric_local_message_partitioned_cuda( - q: torch.Tensor, - input_k: torch.Tensor, - input_v: torch.Tensor, - recurrent_k: torch.Tensor, - recurrent_v: torch.Tensor, - receiver_sender_idx: torch.Tensor, - sender_receiver_idx: torch.Tensor, - offset_distance: torch.Tensor, - offset_delay: torch.Tensor, - step_flat: torch.Tensor, - *, - num_input_senders: int, - distance_scale: float, - use_delay: bool, - owner_tag: str = "generic", -) -> torch.Tensor: - return _FabricLocalMessagePartitionedCUDA.apply( - q.contiguous(), - input_k.contiguous(), - input_v.contiguous(), - recurrent_k.contiguous(), - recurrent_v.contiguous(), - receiver_sender_idx.contiguous(), - sender_receiver_idx.contiguous(), - offset_distance.contiguous(), - offset_delay.contiguous(), - step_flat.contiguous(), - int(num_input_senders), - float(distance_scale), - bool(use_delay), - owner_tag, - ) - - -def fabric_local_message_backward_receiver_cuda( - grad_msg: torch.Tensor, - q: torch.Tensor, - k_all: torch.Tensor, - v_all: torch.Tensor, - receiver_sender_idx: torch.Tensor, - offset_distance: torch.Tensor, - offset_delay: torch.Tensor, - step_flat: torch.Tensor, - *, - distance_scale: float, - use_delay: bool, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - grad_q, receiver_max_logit, receiver_sumexp, receiver_weighted_sum = _load_ext().backward_receiver( - grad_msg.contiguous(), - q.contiguous(), - k_all.contiguous(), - v_all.contiguous(), - receiver_sender_idx.contiguous(), - offset_distance.contiguous(), - offset_delay.contiguous(), - step_flat.contiguous(), - float(distance_scale), - bool(use_delay), - ) - return grad_q, receiver_max_logit, receiver_sumexp, receiver_weighted_sum - - -def fabric_local_message_backward_sender_cuda( - grad_msg: torch.Tensor, - q: torch.Tensor, - k_all: torch.Tensor, - v_all: torch.Tensor, - sender_receiver_idx: torch.Tensor, - offset_distance: torch.Tensor, - offset_delay: torch.Tensor, - step_flat: torch.Tensor, - receiver_max_logit: torch.Tensor, - receiver_sumexp: torch.Tensor, - receiver_weighted_sum: torch.Tensor, - *, - distance_scale: float, - use_delay: bool, -) -> tuple[torch.Tensor, torch.Tensor]: - grad_k, grad_v = _load_ext().backward_sender( - 
grad_msg.contiguous(), - q.contiguous(), - k_all.contiguous(), - v_all.contiguous(), - sender_receiver_idx.contiguous(), - offset_distance.contiguous(), - offset_delay.contiguous(), - step_flat.contiguous(), - receiver_max_logit.contiguous(), - receiver_sumexp.contiguous(), - receiver_weighted_sum.contiguous(), - float(distance_scale), - bool(use_delay), - ) - return grad_k, grad_v - - -def fabric_local_message_partitioned_backward_receiver_cuda( - grad_msg: torch.Tensor, - q: torch.Tensor, - input_k: torch.Tensor, - input_v: torch.Tensor, - recurrent_k: torch.Tensor, - recurrent_v: torch.Tensor, - receiver_sender_idx: torch.Tensor, - offset_distance: torch.Tensor, - offset_delay: torch.Tensor, - step_flat: torch.Tensor, - *, - distance_scale: float, - use_delay: bool, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - if not q.is_cuda: - k_all, v_all = _merge_partitioned_sender_banks(input_k, input_v, recurrent_k, recurrent_v) - return fabric_local_message_backward_receiver_cuda( - grad_msg, - q, - k_all, - v_all, - receiver_sender_idx, - offset_distance, - offset_delay, - step_flat, - distance_scale=distance_scale, - use_delay=use_delay, - ) - grad_q, receiver_max_logit, receiver_sumexp, receiver_weighted_sum = _load_ext().backward_receiver_partitioned( - grad_msg.contiguous(), - q.contiguous(), - input_k.contiguous(), - input_v.contiguous(), - recurrent_k.contiguous(), - recurrent_v.contiguous(), - receiver_sender_idx.contiguous(), - offset_distance.contiguous(), - offset_delay.contiguous(), - step_flat.contiguous(), - float(distance_scale), - bool(use_delay), - ) - return grad_q, receiver_max_logit, receiver_sumexp, receiver_weighted_sum - - -def fabric_local_message_partitioned_backward_sender_cuda( - grad_msg: torch.Tensor, - q: torch.Tensor, - input_k: torch.Tensor, - input_v: torch.Tensor, - recurrent_k: torch.Tensor, - recurrent_v: torch.Tensor, - sender_receiver_idx: torch.Tensor, - offset_distance: torch.Tensor, - offset_delay: torch.Tensor, - step_flat: torch.Tensor, - receiver_max_logit: torch.Tensor, - receiver_sumexp: torch.Tensor, - receiver_weighted_sum: torch.Tensor, - *, - distance_scale: float, - use_delay: bool, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - if not q.is_cuda: - k_all, v_all = _merge_partitioned_sender_banks(input_k, input_v, recurrent_k, recurrent_v) - grad_k_all, grad_v_all = fabric_local_message_backward_sender_cuda( - grad_msg, - q, - k_all, - v_all, - sender_receiver_idx, - offset_distance, - offset_delay, - step_flat, - receiver_max_logit, - receiver_sumexp, - receiver_weighted_sum, - distance_scale=distance_scale, - use_delay=use_delay, - ) - num_input_senders = int(input_k.shape[1]) - grad_input_k, grad_recurrent_k = grad_k_all.split( - (num_input_senders, grad_k_all.shape[1] - num_input_senders), - dim=1, - ) - grad_input_v, grad_recurrent_v = grad_v_all.split( - (num_input_senders, grad_v_all.shape[1] - num_input_senders), - dim=1, - ) - return grad_input_k, grad_input_v, grad_recurrent_k, grad_recurrent_v - return _load_ext().backward_sender_partitioned( - grad_msg.contiguous(), - q.contiguous(), - input_k.contiguous(), - input_v.contiguous(), - recurrent_k.contiguous(), - recurrent_v.contiguous(), - sender_receiver_idx.contiguous(), - offset_distance.contiguous(), - offset_delay.contiguous(), - step_flat.contiguous(), - receiver_max_logit.contiguous(), - receiver_sumexp.contiguous(), - receiver_weighted_sum.contiguous(), - float(distance_scale), - bool(use_delay), - ) - - -def 
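The non-CUDA fallbacks above merge the input and recurrent sender banks, run the merged-bank path, and split the gradients back by num_input_senders; _merge_partitioned_sender_banks performs the merge with pad-and-add, which for [BT, senders, dim] banks is numerically the same as a concat along the sender axis. A small sketch of that merge/split pair.

```python
import torch


def merge_banks(input_k: torch.Tensor, recurrent_k: torch.Tensor) -> torch.Tensor:
    """[BT, S_in, d] + [BT, S_rec, d] -> [BT, S_in + S_rec, d] along the sender axis."""
    return torch.cat((input_k, recurrent_k), dim=1)


def split_bank_grads(grad_all: torch.Tensor, num_input_senders: int):
    """Undo the merge on gradients: the first S_in sender rows belong to the input bank."""
    return grad_all.split((num_input_senders, grad_all.shape[1] - num_input_senders), dim=1)
```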
fabric_local_message_partitioned_backward_fused_cuda( - grad_msg: torch.Tensor, - q: torch.Tensor, - input_k: torch.Tensor, - input_v: torch.Tensor, - recurrent_k: torch.Tensor, - recurrent_v: torch.Tensor, - receiver_sender_idx: torch.Tensor, - offset_distance: torch.Tensor, - offset_delay: torch.Tensor, - step_flat: torch.Tensor, - *, - distance_scale: float, - use_delay: bool, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - if not q.is_cuda: - grad_q, receiver_max_logit, receiver_sumexp, receiver_weighted_sum = ( - fabric_local_message_partitioned_backward_receiver_cuda( - grad_msg, - q, - input_k, - input_v, - recurrent_k, - recurrent_v, - receiver_sender_idx, - offset_distance, - offset_delay, - step_flat, - distance_scale=distance_scale, - use_delay=use_delay, - ) - ) - sender_receiver_idx = _receiver_sender_to_sender_receiver( - receiver_sender_idx, - input_k.shape[1] + recurrent_k.shape[1], - ) - grad_input_k, grad_input_v, grad_recurrent_k, grad_recurrent_v = ( - fabric_local_message_partitioned_backward_sender_cuda( - grad_msg, - q, - input_k, - input_v, - recurrent_k, - recurrent_v, - sender_receiver_idx, - offset_distance, - offset_delay, - step_flat, - receiver_max_logit, - receiver_sumexp, - receiver_weighted_sum, - distance_scale=distance_scale, - use_delay=use_delay, - ) - ) - return grad_q, grad_input_k, grad_input_v, grad_recurrent_k, grad_recurrent_v - return _load_ext().backward_partitioned_fused( - grad_msg.contiguous(), - q.contiguous(), - input_k.contiguous(), - input_v.contiguous(), - recurrent_k.contiguous(), - recurrent_v.contiguous(), - receiver_sender_idx.contiguous(), - offset_distance.contiguous(), - offset_delay.contiguous(), - step_flat.contiguous(), - float(distance_scale), - bool(use_delay), - ) - - -def _receiver_sender_to_sender_receiver(receiver_sender_idx: torch.Tensor, sender_count: int) -> torch.Tensor: - sender_receiver_idx = torch.full( - (int(sender_count), int(receiver_sender_idx.shape[1])), - -1, - device=receiver_sender_idx.device, - dtype=receiver_sender_idx.dtype, - ) - write_count = torch.zeros(int(sender_count), device=receiver_sender_idx.device, dtype=torch.long) - for receiver in range(int(receiver_sender_idx.shape[0])): - for offset in range(int(receiver_sender_idx.shape[1])): - sender = int(receiver_sender_idx[receiver, offset].item()) - if sender < 0: - continue - position = int(write_count[sender].item()) - if position < int(receiver_sender_idx.shape[1]): - sender_receiver_idx[sender, position] = receiver - write_count[sender] += 1 - return sender_receiver_idx - - -__all__ = [ - "fabric_local_message_backward_receiver_cuda", - "fabric_local_message_backward_sender_cuda", - "fabric_local_message_cuda", - "fabric_local_message_partitioned_backward_fused_cuda", - "fabric_local_message_partitioned_backward_receiver_cuda", - "fabric_local_message_partitioned_backward_sender_cuda", - "fabric_local_message_partitioned_cuda", -] diff --git a/src/cortical/fabric/backend/cuda/message_passing/local_message_kernels.cu b/src/cortical/fabric/backend/cuda/message_passing/local_message_kernels.cu deleted file mode 100644 index 657d544d..00000000 --- a/src/cortical/fabric/backend/cuda/message_passing/local_message_kernels.cu +++ /dev/null @@ -1,1721 +0,0 @@ -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace { - -constexpr int kWarpSize = 32; -constexpr int kMaxLocalOffsets = 32; -constexpr int kMaxSubgroupsPerWarp = kWarpSize; -constexpr int kWarpsPerBlock = 4; -constexpr 
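_receiver_sender_to_sender_receiver above inverts the receiver-to-sender offset table into a sender-to-receiver table, padding unused slots with -1. A tiny worked example of the mapping it produces, with illustrative indices.

```python
import torch

receiver_sender_idx = torch.tensor([[0, 2],
                                    [2, -1],
                                    [1, -1]], dtype=torch.int32)   # [R=3, O=2]

# Inverting for S=3 senders, row s lists the receivers that read sender s:
expected_sender_receiver_idx = torch.tensor([[0, -1],   # sender 0 <- receiver 0
                                             [2, -1],   # sender 1 <- receiver 2
                                             [0, 1]],   # sender 2 <- receivers 0 and 1
                                            dtype=torch.int32)
```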
int kThreadsPerBlock = kWarpSize * kWarpsPerBlock; - -__host__ __device__ inline int next_power_of_two(int value) { - int result = 1; - while (result < value) { - result <<= 1; - } - return result; -} - -inline void check_cuda_tensor(const at::Tensor &tensor, const char *name) { - TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA tensor"); - TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous"); -} - -inline void check_launch(const char *name) { - const cudaError_t err = cudaGetLastError(); - TORCH_CHECK(err == cudaSuccess, name, " launch failed: ", - cudaGetErrorString(err)); -} - -__device__ inline float warp_reduce_sum(float value, unsigned mask, int width) { - for (int delta = width / 2; delta > 0; delta >>= 1) { - value += __shfl_down_sync(mask, value, delta, width); - } - return __shfl_sync(mask, value, 0, width); -} - -__device__ inline float warp_reduce_max(float value, unsigned mask, int width) { - for (int delta = width / 2; delta > 0; delta >>= 1) { - value = fmaxf(value, __shfl_down_sync(mask, value, delta, width)); - } - return __shfl_sync(mask, value, 0, width); -} - -__global__ void fabric_local_message_forward_kernel( - const float *__restrict__ q, const float *__restrict__ k_all, - const float *__restrict__ v_all, - const int32_t *__restrict__ receiver_sender_idx, - const float *__restrict__ offset_distance, - const int32_t *__restrict__ offset_delay, - const int64_t *__restrict__ step_flat, float *__restrict__ out, int BT, - int R, int S, int O, int d_k, int d_v, float inv_sqrt_dk, - float distance_scale, bool use_delay) { - const int warp = threadIdx.x / kWarpSize; - const int lane = threadIdx.x % kWarpSize; - const int subgroup_width = next_power_of_two(O); - const int groups_per_warp = kWarpSize / subgroup_width; - const int subgroup = lane / subgroup_width; - const int subgroup_lane = lane % subgroup_width; - const int64_t linear_idx = - static_cast(blockIdx.x) * kWarpsPerBlock * groups_per_warp + - warp * groups_per_warp + subgroup; - const int64_t total = static_cast(BT) * R; - if (linear_idx >= total) { - return; - } - __shared__ int shared_sender[kWarpsPerBlock][kMaxSubgroupsPerWarp][kMaxLocalOffsets]; - __shared__ float shared_weight[kWarpsPerBlock][kMaxSubgroupsPerWarp][kMaxLocalOffsets]; - - const int bt = static_cast(linear_idx / R); - const int recv = static_cast(linear_idx % R); - const int64_t step_value = use_delay ? step_flat[bt] : 0; - const bool lane_active = subgroup_lane < O; - const unsigned subgroup_mask = - subgroup_width == kWarpSize - ? 0xffffffffu - : (((1u << subgroup_width) - 1u) << (subgroup * subgroup_width)); - int sender_idx = -1; - float weight = 0.0f; - if (subgroup < groups_per_warp) { - float logit = -std::numeric_limits::infinity(); - if (lane_active && - (!use_delay || - static_cast(offset_delay[subgroup_lane]) <= step_value)) { - sender_idx = receiver_sender_idx[recv * O + subgroup_lane]; - if (sender_idx >= 0 && sender_idx < S) { - const int64_t q_offset = static_cast(recv) * d_k; - const int64_t k_offset = - (static_cast(bt) * S + sender_idx) * d_k; - float dot = 0.0f; - for (int d = 0; d < d_k; ++d) { - dot += q[q_offset + d] * k_all[k_offset + d]; - } - logit = - dot * inv_sqrt_dk - distance_scale * offset_distance[subgroup_lane]; - } - } - const float max_logit = - warp_reduce_max(logit, subgroup_mask, subgroup_width); - if (sender_idx >= 0) { - weight = expf(logit - max_logit); - } - const float sum = warp_reduce_sum(weight, subgroup_mask, subgroup_width); - weight = sum > 0.0f ? 
weight / sum : 0.0f; - if (lane_active) { - shared_sender[warp][subgroup][subgroup_lane] = sender_idx; - shared_weight[warp][subgroup][subgroup_lane] = weight; - } - } - __syncwarp(subgroup_mask); - - const int64_t out_offset = - (static_cast(bt) * R + recv) * d_v; - for (int base = 0; base < d_v; base += subgroup_width) { - const int d = base + subgroup_lane; - if (d < d_v) { - float acc = 0.0f; - for (int src = 0; src < O; ++src) { - const int src_sender = shared_sender[warp][subgroup][src]; - const float src_weight = shared_weight[warp][subgroup][src]; - if (src_sender >= 0 && src_weight != 0.0f) { - const int64_t v_offset = - (static_cast(bt) * S + src_sender) * d_v + d; - acc += src_weight * v_all[v_offset]; - } - } - out[out_offset + d] = acc; - } - } -} - -__global__ void fabric_local_message_forward_partitioned_kernel( - const float *__restrict__ q, - const float *__restrict__ input_k, - const float *__restrict__ input_v, - const float *__restrict__ recurrent_k, - const float *__restrict__ recurrent_v, - const int32_t *__restrict__ receiver_sender_idx, - const float *__restrict__ offset_distance, - const int32_t *__restrict__ offset_delay, - const int64_t *__restrict__ step_flat, - float *__restrict__ out, - int BT, - int R, - int input_senders, - int recurrent_senders, - int O, - int d_k, - int d_v, - float inv_sqrt_dk, - float distance_scale, - bool use_delay) { - const int warp = threadIdx.x / kWarpSize; - const int lane = threadIdx.x % kWarpSize; - const int subgroup_width = next_power_of_two(O); - const int groups_per_warp = kWarpSize / subgroup_width; - const int subgroup = lane / subgroup_width; - const int subgroup_lane = lane % subgroup_width; - const int64_t linear_idx = - static_cast(blockIdx.x) * kWarpsPerBlock * groups_per_warp + - warp * groups_per_warp + subgroup; - const int64_t total = static_cast(BT) * R; - if (linear_idx >= total) { - return; - } - __shared__ int shared_sender[kWarpsPerBlock][kMaxSubgroupsPerWarp][kMaxLocalOffsets]; - __shared__ float shared_weight[kWarpsPerBlock][kMaxSubgroupsPerWarp][kMaxLocalOffsets]; - - const int bt = static_cast(linear_idx / R); - const int recv = static_cast(linear_idx % R); - const int64_t step_value = use_delay ? step_flat[bt] : 0; - const bool lane_active = subgroup_lane < O; - const unsigned subgroup_mask = - subgroup_width == kWarpSize - ? 0xffffffffu - : (((1u << subgroup_width) - 1u) << (subgroup * subgroup_width)); - int sender_idx = -1; - float weight = 0.0f; - if (subgroup < groups_per_warp) { - float logit = -std::numeric_limits::infinity(); - if (lane_active && - (!use_delay || - static_cast(offset_delay[subgroup_lane]) <= step_value)) { - sender_idx = receiver_sender_idx[recv * O + subgroup_lane]; - if (sender_idx >= 0 && sender_idx < input_senders + recurrent_senders) { - const bool is_input = sender_idx < input_senders; - const int bank_sender_idx = is_input ? sender_idx : sender_idx - input_senders; - const float* k_bank = is_input ? input_k : recurrent_k; - const int bank_sender_count = is_input ? 
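fabric_local_message_forward_kernel above scores each receiver's O candidate offsets with q·k/sqrt(d_k) minus a scaled distance penalty, softmaxes over the valid offsets, and accumulates the weighted sender values. A dense PyTorch reference of that forward (delay masking omitted), usable as a parity oracle at small shapes; it is not the performance path.

```python
import torch


def local_message_reference(q, k_all, v_all, receiver_sender_idx, offset_distance,
                            distance_scale: float) -> torch.Tensor:
    """q: [R, d_k]; k_all: [BT, S, d_k]; v_all: [BT, S, d_v];
    receiver_sender_idx: [R, O] with -1 for missing edges; offset_distance: [O]."""
    d_k = k_all.shape[-1]
    valid = receiver_sender_idx >= 0                                   # [R, O]
    sender = receiver_sender_idx.clamp(min=0).long()                   # [R, O]
    k_sel = k_all[:, sender]                                           # [BT, R, O, d_k]
    v_sel = v_all[:, sender]                                           # [BT, R, O, d_v]
    logits = torch.einsum("rd,brod->bro", q, k_sel) / d_k ** 0.5
    logits = logits - distance_scale * offset_distance                 # per-offset penalty
    logits = logits.masked_fill(~valid, float("-inf"))
    weights = torch.softmax(logits, dim=-1).nan_to_num(0.0)            # all-invalid rows -> 0
    return torch.einsum("bro,brod->brd", weights, v_sel)               # [BT, R, d_v]
```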
input_senders : recurrent_senders; - const int64_t q_offset = static_cast(recv) * d_k; - const int64_t k_offset = - (static_cast(bt) * bank_sender_count + bank_sender_idx) * d_k; - float dot = 0.0f; - for (int d = 0; d < d_k; ++d) { - dot += q[q_offset + d] * k_bank[k_offset + d]; - } - logit = - dot * inv_sqrt_dk - distance_scale * offset_distance[subgroup_lane]; - } - } - const float max_logit = - warp_reduce_max(logit, subgroup_mask, subgroup_width); - if (sender_idx >= 0) { - weight = expf(logit - max_logit); - } - const float sum = warp_reduce_sum(weight, subgroup_mask, subgroup_width); - weight = sum > 0.0f ? weight / sum : 0.0f; - if (lane_active) { - shared_sender[warp][subgroup][subgroup_lane] = sender_idx; - shared_weight[warp][subgroup][subgroup_lane] = weight; - } - } - __syncwarp(subgroup_mask); - - const int64_t out_offset = - (static_cast(bt) * R + recv) * d_v; - for (int base = 0; base < d_v; base += subgroup_width) { - const int d = base + subgroup_lane; - if (d < d_v) { - float acc = 0.0f; - for (int src = 0; src < O; ++src) { - const int src_sender = shared_sender[warp][subgroup][src]; - const float src_weight = shared_weight[warp][subgroup][src]; - if (src_sender >= 0 && src_weight != 0.0f) { - const bool is_input = src_sender < input_senders; - const int bank_sender_idx = is_input ? src_sender : src_sender - input_senders; - const float* v_bank = is_input ? input_v : recurrent_v; - const int bank_sender_count = is_input ? input_senders : recurrent_senders; - const int64_t v_offset = - (static_cast(bt) * bank_sender_count + bank_sender_idx) * d_v + d; - acc += src_weight * v_bank[v_offset]; - } - } - out[out_offset + d] = acc; - } - } -} - -__global__ void fabric_local_message_backward_receiver_kernel( - const float *__restrict__ grad_msg, const float *__restrict__ q, - const float *__restrict__ k_all, const float *__restrict__ v_all, - const int32_t *__restrict__ receiver_sender_idx, - const float *__restrict__ offset_distance, - const int32_t *__restrict__ offset_delay, - const int64_t *__restrict__ step_flat, float *__restrict__ grad_q, - float *__restrict__ receiver_max_logit, - float *__restrict__ receiver_sumexp, - float *__restrict__ receiver_weighted_sum, int BT, int R, int S, int O, - int d_k, int d_v, float inv_sqrt_dk, float distance_scale, - bool use_delay) { - const int warp = threadIdx.x / kWarpSize; - const int lane = threadIdx.x % kWarpSize; - const int subgroup_width = next_power_of_two(O); - const int groups_per_warp = kWarpSize / subgroup_width; - const int subgroup = lane / subgroup_width; - const int subgroup_lane = lane % subgroup_width; - const int64_t linear_idx = - static_cast(blockIdx.x) * kWarpsPerBlock * groups_per_warp + - warp * groups_per_warp + subgroup; - const int64_t total = static_cast(BT) * R; - if (linear_idx >= total) { - return; - } - __shared__ int shared_sender[kWarpsPerBlock][kMaxSubgroupsPerWarp][kMaxLocalOffsets]; - __shared__ float shared_dlogit[kWarpsPerBlock][kMaxSubgroupsPerWarp][kMaxLocalOffsets]; - - const int bt = static_cast(linear_idx / R); - const int recv = static_cast(linear_idx % R); - const int64_t receiver_idx = static_cast(bt) * R + recv; - const int64_t step_value = use_delay ? step_flat[bt] : 0; - const int64_t grad_offset = receiver_idx * d_v; - const bool lane_active = subgroup_lane < O; - const unsigned subgroup_mask = - subgroup_width == kWarpSize - ? 
0xffffffffu - : (((1u << subgroup_width) - 1u) << (subgroup * subgroup_width)); - int sender_idx = -1; - float weight = 0.0f; - float dlogit = 0.0f; - if (subgroup < groups_per_warp) { - float logit = -std::numeric_limits::infinity(); - if (lane_active && - (!use_delay || - static_cast(offset_delay[subgroup_lane]) <= step_value)) { - sender_idx = receiver_sender_idx[recv * O + subgroup_lane]; - if (sender_idx >= 0 && sender_idx < S) { - const int64_t q_offset = static_cast(recv) * d_k; - const int64_t k_offset = - (static_cast(bt) * S + sender_idx) * d_k; - float dot = 0.0f; - for (int d = 0; d < d_k; ++d) { - dot += q[q_offset + d] * k_all[k_offset + d]; - } - logit = - dot * inv_sqrt_dk - distance_scale * offset_distance[subgroup_lane]; - } - } - const float max_logit = - warp_reduce_max(logit, subgroup_mask, subgroup_width); - float raw_weight = 0.0f; - if (sender_idx >= 0) { - raw_weight = expf(logit - max_logit); - } - const float sumexp = - warp_reduce_sum(raw_weight, subgroup_mask, subgroup_width); - weight = sumexp > 0.0f ? raw_weight / sumexp : 0.0f; - - float dweight = 0.0f; - if (sender_idx >= 0) { - const int64_t v_offset = - (static_cast(bt) * S + sender_idx) * d_v; - for (int d = 0; d < d_v; ++d) { - dweight += grad_msg[grad_offset + d] * v_all[v_offset + d]; - } - } - const float weighted_sum = - warp_reduce_sum(weight * dweight, subgroup_mask, subgroup_width); - if (sender_idx >= 0) { - dlogit = weight * (dweight - weighted_sum); - } - const int stats_lane = static_cast(receiver_idx & (subgroup_width - 1)); - if (subgroup_lane == stats_lane) { - receiver_max_logit[receiver_idx] = max_logit; - receiver_sumexp[receiver_idx] = sumexp; - receiver_weighted_sum[receiver_idx] = weighted_sum; - } - if (lane_active) { - shared_sender[warp][subgroup][subgroup_lane] = sender_idx; - shared_dlogit[warp][subgroup][subgroup_lane] = dlogit; - } - } - __syncwarp(subgroup_mask); - - for (int base = 0; base < d_k; base += subgroup_width) { - const int d = base + subgroup_lane; - if (d < d_k) { - float recv_grad = 0.0f; - for (int src = 0; src < O; ++src) { - const int src_sender = shared_sender[warp][subgroup][src]; - const float src_dlogit = shared_dlogit[warp][subgroup][src]; - if (src_sender >= 0 && src_dlogit != 0.0f) { - recv_grad += - src_dlogit * - k_all[(static_cast(bt) * S + src_sender) * d_k + d] * - inv_sqrt_dk; - } - } - atomicAdd(&grad_q[recv * d_k + d], recv_grad); - } - } -} - -__global__ void fabric_local_message_backward_receiver_partitioned_kernel( - const float *__restrict__ grad_msg, - const float *__restrict__ q, - const float *__restrict__ input_k, - const float *__restrict__ input_v, - const float *__restrict__ recurrent_k, - const float *__restrict__ recurrent_v, - const int32_t *__restrict__ receiver_sender_idx, - const float *__restrict__ offset_distance, - const int32_t *__restrict__ offset_delay, - const int64_t *__restrict__ step_flat, - float *__restrict__ grad_q, - float *__restrict__ receiver_max_logit, - float *__restrict__ receiver_sumexp, - float *__restrict__ receiver_weighted_sum, - int BT, - int R, - int input_senders, - int recurrent_senders, - int O, - int d_k, - int d_v, - float inv_sqrt_dk, - float distance_scale, - bool use_delay) { - const int warp = threadIdx.x / kWarpSize; - const int lane = threadIdx.x % kWarpSize; - const int subgroup_width = next_power_of_two(O); - const int groups_per_warp = kWarpSize / subgroup_width; - const int subgroup = lane / subgroup_width; - const int subgroup_lane = lane % subgroup_width; - const int64_t linear_idx = 
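The receiver-phase backward above uses the closed form dlogit_o = w_o * (g·v_o − Σ_j w_j (g·v_j)), which is the standard softmax Jacobian-vector product, and folds it into grad_q with the 1/sqrt(d_k) scale. A quick autograd cross-check against the dense reference from the earlier sketch (local_message_reference is assumed to be in scope); the kernel's accumulated grad_q is expected to agree with q.grad here.

```python
import torch

BT, R, S, O, d_k, d_v = 2, 3, 5, 4, 8, 6
q = torch.randn(R, d_k, requires_grad=True)
k_all, v_all = torch.randn(BT, S, d_k), torch.randn(BT, S, d_v)
idx = torch.randint(0, S, (R, O))
dist = torch.rand(O)

out = local_message_reference(q, k_all, v_all, idx, dist, distance_scale=0.5)
grad_msg = torch.randn_like(out)
(out * grad_msg).sum().backward()
# q.grad is the autograd gradient of the softmax-weighted aggregation; the
# kernel's closed-form grad_q should match it for these shapes.
```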
- static_cast(blockIdx.x) * kWarpsPerBlock * groups_per_warp + - warp * groups_per_warp + subgroup; - const int64_t total = static_cast(BT) * R; - if (linear_idx >= total) { - return; - } - __shared__ int shared_sender[kWarpsPerBlock][kMaxSubgroupsPerWarp][kMaxLocalOffsets]; - __shared__ float shared_dlogit[kWarpsPerBlock][kMaxSubgroupsPerWarp][kMaxLocalOffsets]; - - const int bt = static_cast(linear_idx / R); - const int recv = static_cast(linear_idx % R); - const int64_t receiver_idx = static_cast(bt) * R + recv; - const int64_t step_value = use_delay ? step_flat[bt] : 0; - const int64_t grad_offset = receiver_idx * d_v; - const bool lane_active = subgroup_lane < O; - const unsigned subgroup_mask = - subgroup_width == kWarpSize - ? 0xffffffffu - : (((1u << subgroup_width) - 1u) << (subgroup * subgroup_width)); - int sender_idx = -1; - float weight = 0.0f; - float dlogit = 0.0f; - if (subgroup < groups_per_warp) { - float logit = -std::numeric_limits::infinity(); - if (lane_active && - (!use_delay || - static_cast(offset_delay[subgroup_lane]) <= step_value)) { - sender_idx = receiver_sender_idx[recv * O + subgroup_lane]; - if (sender_idx >= 0 && sender_idx < input_senders + recurrent_senders) { - const bool is_input = sender_idx < input_senders; - const int bank_sender_idx = is_input ? sender_idx : sender_idx - input_senders; - const float* k_bank = is_input ? input_k : recurrent_k; - const int bank_sender_count = is_input ? input_senders : recurrent_senders; - const int64_t q_offset = static_cast(recv) * d_k; - const int64_t k_offset = - (static_cast(bt) * bank_sender_count + bank_sender_idx) * d_k; - float dot = 0.0f; - for (int d = 0; d < d_k; ++d) { - dot += q[q_offset + d] * k_bank[k_offset + d]; - } - logit = - dot * inv_sqrt_dk - distance_scale * offset_distance[subgroup_lane]; - } - } - const float max_logit = - warp_reduce_max(logit, subgroup_mask, subgroup_width); - float raw_weight = 0.0f; - if (sender_idx >= 0) { - raw_weight = expf(logit - max_logit); - } - const float sumexp = - warp_reduce_sum(raw_weight, subgroup_mask, subgroup_width); - weight = sumexp > 0.0f ? raw_weight / sumexp : 0.0f; - - float dweight = 0.0f; - if (sender_idx >= 0) { - const bool is_input = sender_idx < input_senders; - const int bank_sender_idx = is_input ? sender_idx : sender_idx - input_senders; - const float* v_bank = is_input ? input_v : recurrent_v; - const int bank_sender_count = is_input ? 
input_senders : recurrent_senders; - const int64_t v_offset = - (static_cast(bt) * bank_sender_count + bank_sender_idx) * d_v; - for (int d = 0; d < d_v; ++d) { - dweight += grad_msg[grad_offset + d] * v_bank[v_offset + d]; - } - } - const float weighted_sum = - warp_reduce_sum(weight * dweight, subgroup_mask, subgroup_width); - if (sender_idx >= 0) { - dlogit = weight * (dweight - weighted_sum); - } - const int stats_lane = static_cast(receiver_idx & (subgroup_width - 1)); - if (subgroup_lane == stats_lane) { - receiver_max_logit[receiver_idx] = max_logit; - receiver_sumexp[receiver_idx] = sumexp; - receiver_weighted_sum[receiver_idx] = weighted_sum; - } - if (lane_active) { - shared_sender[warp][subgroup][subgroup_lane] = sender_idx; - shared_dlogit[warp][subgroup][subgroup_lane] = dlogit; - } - } - __syncwarp(subgroup_mask); - - for (int base = 0; base < d_k; base += subgroup_width) { - const int d = base + subgroup_lane; - if (d < d_k) { - float recv_grad = 0.0f; - for (int src = 0; src < O; ++src) { - const int src_sender = shared_sender[warp][subgroup][src]; - const float src_dlogit = shared_dlogit[warp][subgroup][src]; - if (src_sender >= 0 && src_dlogit != 0.0f) { - const bool is_input = src_sender < input_senders; - const int bank_sender_idx = is_input ? src_sender : src_sender - input_senders; - const float* k_bank = is_input ? input_k : recurrent_k; - const int bank_sender_count = is_input ? input_senders : recurrent_senders; - recv_grad += - src_dlogit * - k_bank[(static_cast(bt) * bank_sender_count + bank_sender_idx) * d_k + d] * - inv_sqrt_dk; - } - } - atomicAdd(&grad_q[recv * d_k + d], recv_grad); - } - } -} - -__global__ void fabric_local_message_backward_sender_kernel( - const float *__restrict__ grad_msg, const float *__restrict__ q, - const float *__restrict__ k_all, const float *__restrict__ v_all, - const int32_t *__restrict__ sender_receiver_idx, - const float *__restrict__ offset_distance, - const int32_t *__restrict__ offset_delay, - const int64_t *__restrict__ step_flat, - const float *__restrict__ receiver_max_logit, - const float *__restrict__ receiver_sumexp, - const float *__restrict__ receiver_weighted_sum, float *__restrict__ grad_k, - float *__restrict__ grad_v, int BT, int R, int S, int O, int d_k, int d_v, - float inv_sqrt_dk, float distance_scale, bool use_delay) { - const int warp = threadIdx.x / kWarpSize; - const int lane = threadIdx.x % kWarpSize; - const int subgroup_width = - min(kWarpSize, next_power_of_two(d_k > d_v ? d_k : d_v)); - const int groups_per_warp = kWarpSize / subgroup_width; - const int subgroup = lane / subgroup_width; - const int subgroup_lane = lane % subgroup_width; - const unsigned subgroup_mask = - subgroup_width == kWarpSize - ? 0xffffffffu - : (((1u << subgroup_width) - 1u) << (subgroup * subgroup_width)); - const int64_t linear_idx = - static_cast(blockIdx.x) * kWarpsPerBlock * groups_per_warp + - warp * groups_per_warp + subgroup; - const int64_t total = static_cast(BT) * S; - if (linear_idx >= total) { - return; - } - - const int bt = static_cast(linear_idx / S); - const int sender = static_cast(linear_idx % S); - const int64_t step_value = use_delay ? step_flat[bt] : 0; - const int max_dim = d_k > d_v ? 
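The partitioned variants never materialize a concatenated sender bank: each flat sender index is split on the fly into an input-bank or recurrent-bank row, and the same split drives both the key/value reads and the gradient writes. A minimal sketch of that mapping, with hypothetical helper names:

```python
# Illustrative sketch only: how the partitioned kernels map a flat sender
# index onto the (input, recurrent) sender banks. Helper names are hypothetical.
import torch


def split_sender_index(sender_idx: int, num_input_senders: int) -> tuple[bool, int]:
    """Return (is_input, index local to the selected bank)."""
    if sender_idx < num_input_senders:
        return True, sender_idx
    return False, sender_idx - num_input_senders


def gather_sender_row(
    bt: int,
    sender_idx: int,
    input_bank: torch.Tensor,      # [BT, Si, d]
    recurrent_bank: torch.Tensor,  # [BT, Sr, d]
) -> torch.Tensor:
    is_input, local = split_sender_index(sender_idx, input_bank.shape[1])
    bank = input_bank if is_input else recurrent_bank
    return bank[bt, local]
```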
d_k : d_v; - for (int base = 0; base < max_dim; base += subgroup_width) { - const int d_v_idx = base + subgroup_lane; - const int d_k_idx = base + subgroup_lane; - float grad_v_value = 0.0f; - float grad_k_value = 0.0f; - for (int src = 0; src < O; ++src) { - const int stats_lane = src & (subgroup_width - 1); - int receiver = -1; - float max_logit = -std::numeric_limits::infinity(); - float sumexp = 0.0f; - float weighted_sum = 0.0f; - if (subgroup_lane == stats_lane) { - const int recv = sender_receiver_idx[sender * O + src]; - if (recv >= 0 && recv < R && - (!use_delay || - static_cast(offset_delay[src]) <= step_value)) { - const int64_t receiver_idx = static_cast(bt) * R + recv; - max_logit = receiver_max_logit[receiver_idx]; - sumexp = receiver_sumexp[receiver_idx]; - weighted_sum = receiver_weighted_sum[receiver_idx]; - if (isfinite(max_logit) && sumexp > 0.0f) { - receiver = recv; - } - } - } - receiver = __shfl_sync(subgroup_mask, receiver, stats_lane, subgroup_width); - if (receiver < 0) { - continue; - } - max_logit = __shfl_sync(subgroup_mask, max_logit, stats_lane, subgroup_width); - sumexp = __shfl_sync(subgroup_mask, sumexp, stats_lane, subgroup_width); - weighted_sum = - __shfl_sync(subgroup_mask, weighted_sum, stats_lane, subgroup_width); - float dot_part = 0.0f; - for (int dk = subgroup_lane; dk < d_k; dk += subgroup_width) { - dot_part += - q[static_cast(receiver) * d_k + dk] * - k_all[(static_cast(bt) * S + sender) * d_k + dk]; - } - const float dot = warp_reduce_sum(dot_part, subgroup_mask, subgroup_width); - float dweight_part = 0.0f; - for (int dv = subgroup_lane; dv < d_v; dv += subgroup_width) { - dweight_part += - grad_msg[(static_cast(bt) * R + receiver) * d_v + dv] * - v_all[(static_cast(bt) * S + sender) * d_v + dv]; - } - const float dweight = - warp_reduce_sum(dweight_part, subgroup_mask, subgroup_width); - float dlogit = 0.0f; - float weight = 0.0f; - if (subgroup_lane == stats_lane) { - const float logit = - dot * inv_sqrt_dk - distance_scale * offset_distance[src]; - weight = expf(logit - max_logit) / sumexp; - dlogit = weight * (dweight - weighted_sum); - } - weight = __shfl_sync(subgroup_mask, weight, stats_lane, subgroup_width); - dlogit = __shfl_sync(subgroup_mask, dlogit, stats_lane, subgroup_width); - if (d_v_idx < d_v && weight > 0.0f) { - grad_v_value += - weight * grad_msg[(static_cast(bt) * R + receiver) * d_v + d_v_idx]; - } - if (d_k_idx < d_k && dlogit != 0.0f) { - grad_k_value += dlogit * q[receiver * d_k + d_k_idx] * inv_sqrt_dk; - } - } - if (d_v_idx < d_v) { - grad_v[(static_cast(bt) * S + sender) * d_v + d_v_idx] = grad_v_value; - } - if (d_k_idx < d_k) { - grad_k[(static_cast(bt) * S + sender) * d_k + d_k_idx] = grad_k_value; - } - } -} - -__global__ void fabric_local_message_backward_sender_partitioned_kernel( - const float *__restrict__ grad_msg, - const float *__restrict__ q, - const float *__restrict__ input_k, - const float *__restrict__ input_v, - const float *__restrict__ recurrent_k, - const float *__restrict__ recurrent_v, - const int32_t *__restrict__ sender_receiver_idx, - const float *__restrict__ offset_distance, - const int32_t *__restrict__ offset_delay, - const int64_t *__restrict__ step_flat, - const float *__restrict__ receiver_max_logit, - const float *__restrict__ receiver_sumexp, - const float *__restrict__ receiver_weighted_sum, - float *__restrict__ grad_input_k, - float *__restrict__ grad_input_v, - float *__restrict__ grad_recurrent_k, - float *__restrict__ grad_recurrent_v, - int BT, - int R, - int input_senders, 
- int recurrent_senders, - int O, - int d_k, - int d_v, - float inv_sqrt_dk, - float distance_scale, - bool use_delay) { - const int warp = threadIdx.x / kWarpSize; - const int lane = threadIdx.x % kWarpSize; - const int subgroup_width = - min(kWarpSize, next_power_of_two(d_k > d_v ? d_k : d_v)); - const int groups_per_warp = kWarpSize / subgroup_width; - const int subgroup = lane / subgroup_width; - const int subgroup_lane = lane % subgroup_width; - const unsigned subgroup_mask = - subgroup_width == kWarpSize - ? 0xffffffffu - : (((1u << subgroup_width) - 1u) << (subgroup * subgroup_width)); - const int64_t linear_idx = - static_cast(blockIdx.x) * kWarpsPerBlock * groups_per_warp + - warp * groups_per_warp + subgroup; - const int total_senders = input_senders + recurrent_senders; - const int64_t total = static_cast(BT) * total_senders; - if (linear_idx >= total) { - return; - } - - const int bt = static_cast(linear_idx / total_senders); - const int sender = static_cast(linear_idx % total_senders); - const bool is_input = sender < input_senders; - const int bank_sender = is_input ? sender : sender - input_senders; - const int bank_sender_count = is_input ? input_senders : recurrent_senders; - const float* k_bank = is_input ? input_k : recurrent_k; - const float* v_bank = is_input ? input_v : recurrent_v; - float* grad_k_bank = is_input ? grad_input_k : grad_recurrent_k; - float* grad_v_bank = is_input ? grad_input_v : grad_recurrent_v; - const int64_t step_value = use_delay ? step_flat[bt] : 0; - const int max_dim = d_k > d_v ? d_k : d_v; - for (int base = 0; base < max_dim; base += subgroup_width) { - const int d_v_idx = base + subgroup_lane; - const int d_k_idx = base + subgroup_lane; - float grad_v_value = 0.0f; - float grad_k_value = 0.0f; - for (int src = 0; src < O; ++src) { - const int stats_lane = src & (subgroup_width - 1); - int receiver = -1; - float max_logit = -std::numeric_limits::infinity(); - float sumexp = 0.0f; - float weighted_sum = 0.0f; - if (subgroup_lane == stats_lane) { - const int recv = sender_receiver_idx[sender * O + src]; - if (recv >= 0 && recv < R && - (!use_delay || - static_cast(offset_delay[src]) <= step_value)) { - const int64_t receiver_idx = static_cast(bt) * R + recv; - max_logit = receiver_max_logit[receiver_idx]; - sumexp = receiver_sumexp[receiver_idx]; - weighted_sum = receiver_weighted_sum[receiver_idx]; - if (isfinite(max_logit) && sumexp > 0.0f) { - receiver = recv; - } - } - } - receiver = __shfl_sync(subgroup_mask, receiver, stats_lane, subgroup_width); - if (receiver < 0) { - continue; - } - max_logit = __shfl_sync(subgroup_mask, max_logit, stats_lane, subgroup_width); - sumexp = __shfl_sync(subgroup_mask, sumexp, stats_lane, subgroup_width); - weighted_sum = - __shfl_sync(subgroup_mask, weighted_sum, stats_lane, subgroup_width); - float dot_part = 0.0f; - for (int dk = subgroup_lane; dk < d_k; dk += subgroup_width) { - dot_part += - q[static_cast(receiver) * d_k + dk] * - k_bank[(static_cast(bt) * bank_sender_count + bank_sender) * d_k + dk]; - } - const float dot = warp_reduce_sum(dot_part, subgroup_mask, subgroup_width); - float dweight_part = 0.0f; - for (int dv = subgroup_lane; dv < d_v; dv += subgroup_width) { - dweight_part += - grad_msg[(static_cast(bt) * R + receiver) * d_v + dv] * - v_bank[(static_cast(bt) * bank_sender_count + bank_sender) * d_v + dv]; - } - const float dweight = - warp_reduce_sum(dweight_part, subgroup_mask, subgroup_width); - float dlogit = 0.0f; - float weight = 0.0f; - if (subgroup_lane == stats_lane) { - const 
float logit = - dot * inv_sqrt_dk - distance_scale * offset_distance[src]; - weight = expf(logit - max_logit) / sumexp; - dlogit = weight * (dweight - weighted_sum); - } - weight = __shfl_sync(subgroup_mask, weight, stats_lane, subgroup_width); - dlogit = __shfl_sync(subgroup_mask, dlogit, stats_lane, subgroup_width); - if (d_v_idx < d_v && weight > 0.0f) { - grad_v_value += - weight * grad_msg[(static_cast(bt) * R + receiver) * d_v + d_v_idx]; - } - if (d_k_idx < d_k && dlogit != 0.0f) { - grad_k_value += dlogit * q[receiver * d_k + d_k_idx] * inv_sqrt_dk; - } - } - if (d_v_idx < d_v) { - grad_v_bank[(static_cast(bt) * bank_sender_count + bank_sender) * d_v + d_v_idx] = grad_v_value; - } - if (d_k_idx < d_k) { - grad_k_bank[(static_cast(bt) * bank_sender_count + bank_sender) * d_k + d_k_idx] = grad_k_value; - } - } -} - -__global__ void fabric_local_message_backward_partitioned_fused_kernel( - const float *__restrict__ grad_msg, - const float *__restrict__ q, - const float *__restrict__ input_k, - const float *__restrict__ input_v, - const float *__restrict__ recurrent_k, - const float *__restrict__ recurrent_v, - const int32_t *__restrict__ receiver_sender_idx, - const float *__restrict__ offset_distance, - const int32_t *__restrict__ offset_delay, - const int64_t *__restrict__ step_flat, - float *__restrict__ grad_q, - float *__restrict__ grad_input_k, - float *__restrict__ grad_input_v, - float *__restrict__ grad_recurrent_k, - float *__restrict__ grad_recurrent_v, - int BT, - int R, - int input_senders, - int recurrent_senders, - int O, - int d_k, - int d_v, - float inv_sqrt_dk, - float distance_scale, - bool use_delay) { - const int warp = threadIdx.x / kWarpSize; - const int lane = threadIdx.x % kWarpSize; - const int subgroup_width = next_power_of_two(O); - const int groups_per_warp = kWarpSize / subgroup_width; - const int subgroup = lane / subgroup_width; - const int subgroup_lane = lane % subgroup_width; - const int64_t linear_idx = - static_cast(blockIdx.x) * kWarpsPerBlock * groups_per_warp + - warp * groups_per_warp + subgroup; - const int64_t total = static_cast(BT) * R; - if (linear_idx >= total) { - return; - } - __shared__ int shared_sender[kWarpsPerBlock][kMaxSubgroupsPerWarp][kMaxLocalOffsets]; - __shared__ float shared_weight[kWarpsPerBlock][kMaxSubgroupsPerWarp][kMaxLocalOffsets]; - __shared__ float shared_dlogit[kWarpsPerBlock][kMaxSubgroupsPerWarp][kMaxLocalOffsets]; - - const int bt = static_cast(linear_idx / R); - const int recv = static_cast(linear_idx % R); - const int64_t step_value = use_delay ? step_flat[bt] : 0; - const int64_t grad_offset = - (static_cast(bt) * R + recv) * d_v; - const bool lane_active = subgroup_lane < O; - const unsigned subgroup_mask = - subgroup_width == kWarpSize - ? 0xffffffffu - : (((1u << subgroup_width) - 1u) << (subgroup * subgroup_width)); - int sender_idx = -1; - float weight = 0.0f; - float dlogit = 0.0f; - if (subgroup < groups_per_warp) { - float logit = -std::numeric_limits::infinity(); - if (lane_active && - (!use_delay || - static_cast(offset_delay[subgroup_lane]) <= step_value)) { - sender_idx = receiver_sender_idx[recv * O + subgroup_lane]; - if (sender_idx >= 0 && sender_idx < input_senders + recurrent_senders) { - const bool is_input = sender_idx < input_senders; - const int bank_sender_idx = is_input ? sender_idx : sender_idx - input_senders; - const int bank_sender_count = is_input ? input_senders : recurrent_senders; - const float *k_bank = is_input ? 
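The sender-phase kernels avoid a second softmax reduction by reusing the per-receiver statistics cached during the receiver phase (`receiver_max_logit`, `receiver_sumexp`, `receiver_weighted_sum`): each sender rebuilds its own weight and `dlogit` from one dot product and one `dweight`. A hedged sketch of that recomputation; the function is illustrative, argument names mirror the kernel locals.

```python
# Hedged sketch of the sender-side recomputation from cached receiver stats.
import torch


def sender_weight_and_dlogit(
    dot: torch.Tensor,           # q[receiver] . k[sender], scalar
    dweight: torch.Tensor,       # grad_msg[receiver] . v[sender], scalar
    offset_distance: float,
    inv_sqrt_dk: float,
    distance_scale: float,
    max_logit: torch.Tensor,     # cached per (bt, receiver) by the receiver phase
    sumexp: torch.Tensor,
    weighted_sum: torch.Tensor,
):
    logit = dot * inv_sqrt_dk - distance_scale * offset_distance
    weight = torch.exp(logit - max_logit) / sumexp
    dlogit = weight * (dweight - weighted_sum)
    return weight, dlogit
```

`grad_v` then accumulates `weight * grad_msg[receiver]` per value dimension and `grad_k` accumulates `dlogit * q[receiver] * inv_sqrt_dk`, exactly as the per-dimension loops above do.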
input_k : recurrent_k; - const int64_t q_offset = static_cast(recv) * d_k; - const int64_t k_offset = - (static_cast(bt) * bank_sender_count + bank_sender_idx) * d_k; - float dot = 0.0f; - for (int d = 0; d < d_k; ++d) { - dot += q[q_offset + d] * k_bank[k_offset + d]; - } - logit = - dot * inv_sqrt_dk - distance_scale * offset_distance[subgroup_lane]; - } - } - const float max_logit = - warp_reduce_max(logit, subgroup_mask, subgroup_width); - float raw_weight = 0.0f; - if (sender_idx >= 0) { - raw_weight = expf(logit - max_logit); - } - const float sumexp = - warp_reduce_sum(raw_weight, subgroup_mask, subgroup_width); - weight = sumexp > 0.0f ? raw_weight / sumexp : 0.0f; - - float dweight = 0.0f; - if (sender_idx >= 0) { - const bool is_input = sender_idx < input_senders; - const int bank_sender_idx = is_input ? sender_idx : sender_idx - input_senders; - const int bank_sender_count = is_input ? input_senders : recurrent_senders; - const float *v_bank = is_input ? input_v : recurrent_v; - const int64_t v_offset = - (static_cast(bt) * bank_sender_count + bank_sender_idx) * d_v; - for (int d = 0; d < d_v; ++d) { - dweight += grad_msg[grad_offset + d] * v_bank[v_offset + d]; - } - } - const float weighted_sum = - warp_reduce_sum(weight * dweight, subgroup_mask, subgroup_width); - if (sender_idx >= 0) { - dlogit = weight * (dweight - weighted_sum); - } - if (lane_active) { - shared_sender[warp][subgroup][subgroup_lane] = sender_idx; - shared_weight[warp][subgroup][subgroup_lane] = weight; - shared_dlogit[warp][subgroup][subgroup_lane] = dlogit; - } - } - __syncwarp(subgroup_mask); - - for (int base = 0; base < d_k; base += subgroup_width) { - const int d = base + subgroup_lane; - if (d < d_k) { - float recv_grad = 0.0f; - for (int src = 0; src < O; ++src) { - const int src_sender = shared_sender[warp][subgroup][src]; - const float src_dlogit = shared_dlogit[warp][subgroup][src]; - if (src_sender >= 0 && src_dlogit != 0.0f) { - const bool is_input = src_sender < input_senders; - const int bank_sender_idx = is_input ? src_sender : src_sender - input_senders; - const int bank_sender_count = is_input ? input_senders : recurrent_senders; - const float *k_bank = is_input ? input_k : recurrent_k; - recv_grad += - src_dlogit * - k_bank[(static_cast(bt) * bank_sender_count + bank_sender_idx) * d_k + d] * - inv_sqrt_dk; - } - } - atomicAdd(&grad_q[recv * d_k + d], recv_grad); - } - } - - const int max_dim = d_k > d_v ? d_k : d_v; - for (int src = 0; src < O; ++src) { - const int src_sender = shared_sender[warp][subgroup][src]; - if (src_sender < 0) { - continue; - } - const bool is_input = src_sender < input_senders; - const int bank_sender = is_input ? src_sender : src_sender - input_senders; - const int bank_sender_count = is_input ? input_senders : recurrent_senders; - float *grad_k_bank = is_input ? grad_input_k : grad_recurrent_k; - float *grad_v_bank = is_input ? 
grad_input_v : grad_recurrent_v; - const float src_weight = shared_weight[warp][subgroup][src]; - const float src_dlogit = shared_dlogit[warp][subgroup][src]; - for (int base = 0; base < max_dim; base += subgroup_width) { - const int d = base + subgroup_lane; - if (d < d_v && src_weight != 0.0f) { - atomicAdd( - &grad_v_bank[(static_cast(bt) * bank_sender_count + bank_sender) * d_v + d], - src_weight * grad_msg[grad_offset + d]); - } - if (d < d_k && src_dlogit != 0.0f) { - atomicAdd( - &grad_k_bank[(static_cast(bt) * bank_sender_count + bank_sender) * d_k + d], - src_dlogit * q[recv * d_k + d] * inv_sqrt_dk); - } - } - } -} - -} // namespace - -std::vector fabric_local_message_forward_cuda( - at::Tensor q, at::Tensor k_all, at::Tensor v_all, - at::Tensor receiver_sender_idx, at::Tensor sender_receiver_idx, - at::Tensor offset_distance, at::Tensor offset_delay, at::Tensor step_flat, - double distance_scale, bool use_delay) { - check_cuda_tensor(q, "q"); - check_cuda_tensor(k_all, "k_all"); - check_cuda_tensor(v_all, "v_all"); - check_cuda_tensor(receiver_sender_idx, "receiver_sender_idx"); - check_cuda_tensor(sender_receiver_idx, "sender_receiver_idx"); - check_cuda_tensor(offset_distance, "offset_distance"); - check_cuda_tensor(offset_delay, "offset_delay"); - check_cuda_tensor(step_flat, "step_flat"); - TORCH_CHECK(q.scalar_type() == at::kFloat, "q must be float32"); - TORCH_CHECK(k_all.scalar_type() == at::kFloat, "k_all must be float32"); - TORCH_CHECK(v_all.scalar_type() == at::kFloat, "v_all must be float32"); - TORCH_CHECK(receiver_sender_idx.scalar_type() == at::kInt, - "receiver_sender_idx must be int32"); - TORCH_CHECK(sender_receiver_idx.scalar_type() == at::kInt, - "sender_receiver_idx must be int32"); - TORCH_CHECK(offset_distance.scalar_type() == at::kFloat, - "offset_distance must be float32"); - TORCH_CHECK(offset_delay.scalar_type() == at::kInt, - "offset_delay must be int32"); - TORCH_CHECK(step_flat.scalar_type() == at::kLong, - "step_flat must be int64"); - TORCH_CHECK(q.dim() == 2, "q must have shape [R, d_k]"); - TORCH_CHECK(k_all.dim() == 3, "k_all must have shape [BT, S, d_k]"); - TORCH_CHECK(v_all.dim() == 3, "v_all must have shape [BT, S, d_v]"); - TORCH_CHECK(receiver_sender_idx.dim() == 2, - "receiver_sender_idx must have shape [R, O]"); - TORCH_CHECK(sender_receiver_idx.dim() == 2, - "sender_receiver_idx must have shape [S, O]"); - TORCH_CHECK(offset_distance.dim() == 1, - "offset_distance must have shape [O]"); - TORCH_CHECK(offset_delay.dim() == 1, "offset_delay must have shape [O]"); - TORCH_CHECK(step_flat.dim() == 1, "step_flat must have shape [BT]"); - - const int BT = static_cast(k_all.size(0)); - const int S = static_cast(k_all.size(1)); - const int R = static_cast(q.size(0)); - const int O = static_cast(receiver_sender_idx.size(1)); - const int d_k = static_cast(q.size(1)); - const int d_v = static_cast(v_all.size(2)); - TORCH_CHECK(O <= kMaxLocalOffsets, - "local-message kernel supports at most ", kMaxLocalOffsets, - " offsets"); - TORCH_CHECK(k_all.size(2) == q.size(1), - "k_all and q must agree on d_k"); - TORCH_CHECK(receiver_sender_idx.size(0) == q.size(0), - "receiver_sender_idx and q must agree on receiver count"); - TORCH_CHECK(sender_receiver_idx.size(0) == k_all.size(1) && - sender_receiver_idx.size(1) == receiver_sender_idx.size(1), - "sender_receiver_idx must match [S, O]"); - TORCH_CHECK(offset_distance.size(0) == receiver_sender_idx.size(1), - "offset_distance must match O"); - TORCH_CHECK(offset_delay.size(0) == 
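The fused partitioned backward above folds the sender phase into the receiver pass: each receiver subgroup scatter-accumulates its `weight * grad_msg` and `dlogit * q` contributions directly into the sender banks with `atomicAdd`, trading the second launch (and the atomic-free per-sender writes of the two-phase path) for atomics. A CPU-side analogue of that scatter-accumulation for the `grad_v` term, one `bt` slice at a time; names are illustrative, not the shipped path.

```python
# CPU-side analogue only: scatter-accumulation of the grad_v term that the
# fused kernel performs with atomicAdd.
import torch


def accumulate_grad_v_fused(
    grad_msg: torch.Tensor,       # [R, d_v] for a single bt slice
    edge_receiver: torch.Tensor,  # [E] int64
    edge_sender: torch.Tensor,    # [E] int64
    edge_weight: torch.Tensor,    # [E] softmax weights
    num_senders: int,
) -> torch.Tensor:
    contrib = edge_weight.unsqueeze(-1) * grad_msg[edge_receiver]  # [E, d_v]
    grad_v = torch.zeros(
        num_senders, grad_msg.shape[1], dtype=grad_msg.dtype, device=grad_msg.device
    )
    grad_v.index_add_(0, edge_sender, contrib)  # kernel: atomicAdd per (sender, d)
    return grad_v
```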
receiver_sender_idx.size(1), - "offset_delay must match O"); - TORCH_CHECK(step_flat.size(0) == k_all.size(0), - "step_flat length must equal BT"); - - auto out = at::zeros({BT, R, d_v}, v_all.options()); - const int threads = kThreadsPerBlock; - const int64_t total = static_cast(BT) * R; - const int groups_per_warp = kWarpSize / next_power_of_two(O); - const int groups_per_block = kWarpsPerBlock * groups_per_warp; - const int blocks = - static_cast((total + groups_per_block - 1) / groups_per_block); - auto stream = at::cuda::getCurrentCUDAStream(); - fabric_local_message_forward_kernel<<>>( - q.data_ptr(), k_all.data_ptr(), v_all.data_ptr(), - receiver_sender_idx.data_ptr(), - offset_distance.data_ptr(), offset_delay.data_ptr(), - step_flat.data_ptr(), out.data_ptr(), BT, R, S, O, d_k, - d_v, 1.0f / std::sqrt(static_cast(d_k)), - static_cast(distance_scale), use_delay); - check_launch("fabric_local_message_forward_kernel"); - return {out}; -} - -std::vector fabric_local_message_forward_partitioned_cuda( - at::Tensor q, - at::Tensor input_k, - at::Tensor input_v, - at::Tensor recurrent_k, - at::Tensor recurrent_v, - at::Tensor receiver_sender_idx, - at::Tensor offset_distance, - at::Tensor offset_delay, - at::Tensor step_flat, - int64_t num_input_senders, - double distance_scale, - bool use_delay) { - check_cuda_tensor(q, "q"); - check_cuda_tensor(input_k, "input_k"); - check_cuda_tensor(input_v, "input_v"); - check_cuda_tensor(recurrent_k, "recurrent_k"); - check_cuda_tensor(recurrent_v, "recurrent_v"); - check_cuda_tensor(receiver_sender_idx, "receiver_sender_idx"); - check_cuda_tensor(offset_distance, "offset_distance"); - check_cuda_tensor(offset_delay, "offset_delay"); - check_cuda_tensor(step_flat, "step_flat"); - TORCH_CHECK(q.scalar_type() == at::kFloat, "q must be float32"); - TORCH_CHECK(input_k.scalar_type() == at::kFloat, "input_k must be float32"); - TORCH_CHECK(input_v.scalar_type() == at::kFloat, "input_v must be float32"); - TORCH_CHECK(recurrent_k.scalar_type() == at::kFloat, "recurrent_k must be float32"); - TORCH_CHECK(recurrent_v.scalar_type() == at::kFloat, "recurrent_v must be float32"); - TORCH_CHECK(receiver_sender_idx.scalar_type() == at::kInt, - "receiver_sender_idx must be int32"); - TORCH_CHECK(offset_distance.scalar_type() == at::kFloat, - "offset_distance must be float32"); - TORCH_CHECK(offset_delay.scalar_type() == at::kInt, - "offset_delay must be int32"); - TORCH_CHECK(step_flat.scalar_type() == at::kLong, - "step_flat must be int64"); - TORCH_CHECK(q.dim() == 2, "q must have shape [R, d_k]"); - TORCH_CHECK(input_k.dim() == 3, "input_k must have shape [BT, Si, d_k]"); - TORCH_CHECK(input_v.dim() == 3, "input_v must have shape [BT, Si, d_v]"); - TORCH_CHECK(recurrent_k.dim() == 3, "recurrent_k must have shape [BT, Sr, d_k]"); - TORCH_CHECK(recurrent_v.dim() == 3, "recurrent_v must have shape [BT, Sr, d_v]"); - TORCH_CHECK(receiver_sender_idx.dim() == 2, - "receiver_sender_idx must have shape [R, O]"); - TORCH_CHECK(offset_distance.dim() == 1, "offset_distance must have shape [O]"); - TORCH_CHECK(offset_delay.dim() == 1, "offset_delay must have shape [O]"); - TORCH_CHECK(step_flat.dim() == 1, "step_flat must have shape [BT]"); - - const int BT = static_cast(input_k.size(0)); - const int input_senders = static_cast(input_k.size(1)); - const int recurrent_senders = static_cast(recurrent_k.size(1)); - const int R = static_cast(q.size(0)); - const int O = static_cast(receiver_sender_idx.size(1)); - const int d_k = static_cast(q.size(1)); - const int d_v = 
static_cast(input_v.size(2)); - TORCH_CHECK(O <= kMaxLocalOffsets, - "local-message kernel supports at most ", kMaxLocalOffsets, - " offsets"); - TORCH_CHECK(input_k.size(0) == recurrent_k.size(0) && - input_k.size(0) == input_v.size(0) && - input_k.size(0) == recurrent_v.size(0), - "all sender banks must agree on BT"); - TORCH_CHECK(input_k.size(2) == q.size(1) && recurrent_k.size(2) == q.size(1), - "sender k banks and q must agree on d_k"); - TORCH_CHECK(input_v.size(2) == recurrent_v.size(2), - "sender v banks must agree on d_v"); - TORCH_CHECK(receiver_sender_idx.size(0) == q.size(0), - "receiver_sender_idx and q must agree on receiver count"); - TORCH_CHECK(offset_distance.size(0) == receiver_sender_idx.size(1), - "offset_distance must match O"); - TORCH_CHECK(offset_delay.size(0) == receiver_sender_idx.size(1), - "offset_delay must match O"); - TORCH_CHECK(step_flat.size(0) == input_k.size(0), - "step_flat length must equal BT"); - TORCH_CHECK(num_input_senders == input_k.size(1), - "num_input_senders must match input_k sender dimension"); - - auto out = at::zeros({BT, R, d_v}, input_v.options()); - const int threads = kThreadsPerBlock; - const int64_t total = static_cast(BT) * R; - const int groups_per_warp = kWarpSize / next_power_of_two(O); - const int groups_per_block = kWarpsPerBlock * groups_per_warp; - const int blocks = - static_cast((total + groups_per_block - 1) / groups_per_block); - auto stream = at::cuda::getCurrentCUDAStream(); - fabric_local_message_forward_partitioned_kernel<<>>( - q.data_ptr(), - input_k.data_ptr(), - input_v.data_ptr(), - recurrent_k.data_ptr(), - recurrent_v.data_ptr(), - receiver_sender_idx.data_ptr(), - offset_distance.data_ptr(), - offset_delay.data_ptr(), - step_flat.data_ptr(), - out.data_ptr(), - BT, - R, - input_senders, - recurrent_senders, - O, - d_k, - d_v, - 1.0f / std::sqrt(static_cast(d_k)), - static_cast(distance_scale), - use_delay); - check_launch("fabric_local_message_forward_partitioned_kernel"); - return {out}; -} - -std::vector fabric_local_message_backward_cuda( - at::Tensor grad_msg, at::Tensor q, at::Tensor k_all, at::Tensor v_all, - at::Tensor receiver_sender_idx, at::Tensor sender_receiver_idx, - at::Tensor offset_distance, at::Tensor offset_delay, at::Tensor step_flat, - double distance_scale, bool use_delay) { - check_cuda_tensor(grad_msg, "grad_msg"); - check_cuda_tensor(q, "q"); - check_cuda_tensor(k_all, "k_all"); - check_cuda_tensor(v_all, "v_all"); - check_cuda_tensor(receiver_sender_idx, "receiver_sender_idx"); - check_cuda_tensor(sender_receiver_idx, "sender_receiver_idx"); - check_cuda_tensor(offset_distance, "offset_distance"); - check_cuda_tensor(offset_delay, "offset_delay"); - check_cuda_tensor(step_flat, "step_flat"); - TORCH_CHECK(grad_msg.scalar_type() == at::kFloat, - "grad_msg must be float32"); - TORCH_CHECK(q.scalar_type() == at::kFloat, "q must be float32"); - TORCH_CHECK(k_all.scalar_type() == at::kFloat, "k_all must be float32"); - TORCH_CHECK(v_all.scalar_type() == at::kFloat, "v_all must be float32"); - TORCH_CHECK(receiver_sender_idx.scalar_type() == at::kInt, - "receiver_sender_idx must be int32"); - TORCH_CHECK(sender_receiver_idx.scalar_type() == at::kInt, - "sender_receiver_idx must be int32"); - TORCH_CHECK(offset_distance.scalar_type() == at::kFloat, - "offset_distance must be float32"); - TORCH_CHECK(offset_delay.scalar_type() == at::kInt, - "offset_delay must be int32"); - TORCH_CHECK(step_flat.scalar_type() == at::kLong, - "step_flat must be int64"); - - const int BT = 
static_cast(k_all.size(0)); - const int S = static_cast(k_all.size(1)); - const int R = static_cast(q.size(0)); - const int O = static_cast(receiver_sender_idx.size(1)); - const int d_k = static_cast(q.size(1)); - const int d_v = static_cast(v_all.size(2)); - TORCH_CHECK(O <= kMaxLocalOffsets, - "local-message kernel supports at most ", kMaxLocalOffsets, - " offsets"); - TORCH_CHECK(grad_msg.dim() == 3 && grad_msg.size(0) == k_all.size(0) && - grad_msg.size(1) == q.size(0) && - grad_msg.size(2) == v_all.size(2), - "grad_msg must have shape [BT, R, d_v]"); - TORCH_CHECK(receiver_sender_idx.dim() == 2 && - receiver_sender_idx.size(0) == q.size(0), - "receiver_sender_idx must match [R, O]"); - TORCH_CHECK(sender_receiver_idx.dim() == 2 && - sender_receiver_idx.size(0) == k_all.size(1) && - sender_receiver_idx.size(1) == receiver_sender_idx.size(1), - "sender_receiver_idx must match [S, O]"); - TORCH_CHECK(offset_distance.size(0) == receiver_sender_idx.size(1), - "offset_distance must match O"); - TORCH_CHECK(offset_delay.size(0) == receiver_sender_idx.size(1), - "offset_delay must match O"); - - auto grad_q = at::zeros_like(q); - auto grad_k = at::zeros_like(k_all); - auto grad_v = at::zeros_like(v_all); - auto receiver_max_logit = at::empty({BT, R}, q.options()); - auto receiver_sumexp = at::empty({BT, R}, q.options()); - auto receiver_weighted_sum = at::empty({BT, R}, q.options()); - - const int threads = kThreadsPerBlock; - const int64_t receiver_total = static_cast(BT) * R; - const int groups_per_warp = kWarpSize / next_power_of_two(O); - const int groups_per_block = kWarpsPerBlock * groups_per_warp; - const int receiver_blocks = - static_cast((receiver_total + groups_per_block - 1) / groups_per_block); - auto stream = at::cuda::getCurrentCUDAStream(); - fabric_local_message_backward_receiver_kernel<<>>( - grad_msg.data_ptr(), q.data_ptr(), k_all.data_ptr(), - v_all.data_ptr(), receiver_sender_idx.data_ptr(), - offset_distance.data_ptr(), - offset_delay.data_ptr(), step_flat.data_ptr(), - grad_q.data_ptr(), receiver_max_logit.data_ptr(), - receiver_sumexp.data_ptr(), - receiver_weighted_sum.data_ptr(), BT, R, S, O, d_k, d_v, - 1.0f / std::sqrt(static_cast(d_k)), - static_cast(distance_scale), use_delay); - check_launch("fabric_local_message_backward_receiver_kernel"); - - const int64_t sender_total = static_cast(BT) * S; - const int sender_blocks = - static_cast((sender_total + kWarpsPerBlock - 1) / kWarpsPerBlock); - fabric_local_message_backward_sender_kernel<<>>( - grad_msg.data_ptr(), q.data_ptr(), k_all.data_ptr(), - v_all.data_ptr(), sender_receiver_idx.data_ptr(), - offset_distance.data_ptr(), - offset_delay.data_ptr(), step_flat.data_ptr(), - receiver_max_logit.data_ptr(), - receiver_sumexp.data_ptr(), - receiver_weighted_sum.data_ptr(), grad_k.data_ptr(), - grad_v.data_ptr(), BT, R, S, O, d_k, d_v, - 1.0f / std::sqrt(static_cast(d_k)), - static_cast(distance_scale), use_delay); - check_launch("fabric_local_message_backward_sender_kernel"); - - return {grad_q, grad_k, grad_v}; -} - -std::vector fabric_local_message_backward_receiver_cuda( - at::Tensor grad_msg, at::Tensor q, at::Tensor k_all, at::Tensor v_all, - at::Tensor receiver_sender_idx, at::Tensor offset_distance, - at::Tensor offset_delay, at::Tensor step_flat, double distance_scale, - bool use_delay) { - check_cuda_tensor(grad_msg, "grad_msg"); - check_cuda_tensor(q, "q"); - check_cuda_tensor(k_all, "k_all"); - check_cuda_tensor(v_all, "v_all"); - check_cuda_tensor(receiver_sender_idx, "receiver_sender_idx"); - 
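The backward host wrapper above sizes the receiver grid so that each warp packs `kWarpSize / next_power_of_two(O)` receiver subgroups, while the sender grid assigns one warp per `(bt, sender)` row. A small sketch of that launch arithmetic; the warp size and warps-per-block constants are assumptions standing in for `kWarpSize` and `kWarpsPerBlock`, not values read from the shipped headers.

```python
# Launch-arithmetic sketch under assumed constants (warp size 32, 8 warps per
# block). Assumes O <= warp_size, matching the kMaxLocalOffsets guard.
def next_power_of_two(n: int) -> int:
    p = 1
    while p < n:
        p *= 2
    return p


def receiver_launch(BT: int, R: int, O: int, warp_size: int = 32, warps_per_block: int = 8):
    groups_per_warp = warp_size // next_power_of_two(O)  # receiver subgroups per warp
    groups_per_block = warps_per_block * groups_per_warp
    total = BT * R                                       # one subgroup per (bt, receiver)
    blocks = (total + groups_per_block - 1) // groups_per_block
    return blocks, warps_per_block * warp_size           # (grid.x, threads per block)


def sender_launch(BT: int, S: int, warp_size: int = 32, warps_per_block: int = 8):
    total = BT * S                                       # one warp per (bt, sender) row
    blocks = (total + warps_per_block - 1) // warps_per_block
    return blocks, warps_per_block * warp_size
```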
check_cuda_tensor(offset_distance, "offset_distance"); - check_cuda_tensor(offset_delay, "offset_delay"); - check_cuda_tensor(step_flat, "step_flat"); - TORCH_CHECK(grad_msg.scalar_type() == at::kFloat, - "grad_msg must be float32"); - TORCH_CHECK(q.scalar_type() == at::kFloat, "q must be float32"); - TORCH_CHECK(k_all.scalar_type() == at::kFloat, "k_all must be float32"); - TORCH_CHECK(v_all.scalar_type() == at::kFloat, "v_all must be float32"); - TORCH_CHECK(receiver_sender_idx.scalar_type() == at::kInt, - "receiver_sender_idx must be int32"); - TORCH_CHECK(offset_distance.scalar_type() == at::kFloat, - "offset_distance must be float32"); - TORCH_CHECK(offset_delay.scalar_type() == at::kInt, - "offset_delay must be int32"); - TORCH_CHECK(step_flat.scalar_type() == at::kLong, - "step_flat must be int64"); - - const int BT = static_cast(k_all.size(0)); - const int S = static_cast(k_all.size(1)); - const int R = static_cast(q.size(0)); - const int O = static_cast(receiver_sender_idx.size(1)); - const int d_k = static_cast(q.size(1)); - const int d_v = static_cast(v_all.size(2)); - TORCH_CHECK(O <= kMaxLocalOffsets, - "local-message kernel supports at most ", kMaxLocalOffsets, - " offsets"); - TORCH_CHECK(grad_msg.dim() == 3 && grad_msg.size(0) == k_all.size(0) && - grad_msg.size(1) == q.size(0) && - grad_msg.size(2) == v_all.size(2), - "grad_msg must have shape [BT, R, d_v]"); - TORCH_CHECK(receiver_sender_idx.dim() == 2 && - receiver_sender_idx.size(0) == q.size(0), - "receiver_sender_idx must match [R, O]"); - TORCH_CHECK(offset_distance.size(0) == receiver_sender_idx.size(1), - "offset_distance must match O"); - TORCH_CHECK(offset_delay.size(0) == receiver_sender_idx.size(1), - "offset_delay must match O"); - - auto grad_q = at::zeros_like(q); - auto receiver_max_logit = at::empty({BT, R}, q.options()); - auto receiver_sumexp = at::empty({BT, R}, q.options()); - auto receiver_weighted_sum = at::empty({BT, R}, q.options()); - - const int threads = kThreadsPerBlock; - const int64_t receiver_total = static_cast(BT) * R; - const int groups_per_warp = kWarpSize / next_power_of_two(O); - const int groups_per_block = kWarpsPerBlock * groups_per_warp; - const int receiver_blocks = - static_cast((receiver_total + groups_per_block - 1) / groups_per_block); - auto stream = at::cuda::getCurrentCUDAStream(); - fabric_local_message_backward_receiver_kernel<<>>( - grad_msg.data_ptr(), q.data_ptr(), k_all.data_ptr(), - v_all.data_ptr(), receiver_sender_idx.data_ptr(), - offset_distance.data_ptr(), - offset_delay.data_ptr(), step_flat.data_ptr(), - grad_q.data_ptr(), receiver_max_logit.data_ptr(), - receiver_sumexp.data_ptr(), - receiver_weighted_sum.data_ptr(), BT, R, S, O, d_k, d_v, - 1.0f / std::sqrt(static_cast(d_k)), - static_cast(distance_scale), use_delay); - check_launch("fabric_local_message_backward_receiver_kernel"); - return {grad_q, receiver_max_logit, receiver_sumexp, receiver_weighted_sum}; -} - -std::vector fabric_local_message_backward_receiver_partitioned_cuda( - at::Tensor grad_msg, - at::Tensor q, - at::Tensor input_k, - at::Tensor input_v, - at::Tensor recurrent_k, - at::Tensor recurrent_v, - at::Tensor receiver_sender_idx, - at::Tensor offset_distance, - at::Tensor offset_delay, - at::Tensor step_flat, - double distance_scale, - bool use_delay) { - check_cuda_tensor(grad_msg, "grad_msg"); - check_cuda_tensor(q, "q"); - check_cuda_tensor(input_k, "input_k"); - check_cuda_tensor(input_v, "input_v"); - check_cuda_tensor(recurrent_k, "recurrent_k"); - check_cuda_tensor(recurrent_v, 
"recurrent_v"); - check_cuda_tensor(receiver_sender_idx, "receiver_sender_idx"); - check_cuda_tensor(offset_distance, "offset_distance"); - check_cuda_tensor(offset_delay, "offset_delay"); - check_cuda_tensor(step_flat, "step_flat"); - TORCH_CHECK(grad_msg.scalar_type() == at::kFloat, - "grad_msg must be float32"); - TORCH_CHECK(q.scalar_type() == at::kFloat, "q must be float32"); - TORCH_CHECK(input_k.scalar_type() == at::kFloat, "input_k must be float32"); - TORCH_CHECK(input_v.scalar_type() == at::kFloat, "input_v must be float32"); - TORCH_CHECK(recurrent_k.scalar_type() == at::kFloat, "recurrent_k must be float32"); - TORCH_CHECK(recurrent_v.scalar_type() == at::kFloat, "recurrent_v must be float32"); - TORCH_CHECK(receiver_sender_idx.scalar_type() == at::kInt, - "receiver_sender_idx must be int32"); - TORCH_CHECK(offset_distance.scalar_type() == at::kFloat, - "offset_distance must be float32"); - TORCH_CHECK(offset_delay.scalar_type() == at::kInt, - "offset_delay must be int32"); - TORCH_CHECK(step_flat.scalar_type() == at::kLong, - "step_flat must be int64"); - - const int BT = static_cast(input_k.size(0)); - const int input_senders = static_cast(input_k.size(1)); - const int recurrent_senders = static_cast(recurrent_k.size(1)); - const int R = static_cast(q.size(0)); - const int O = static_cast(receiver_sender_idx.size(1)); - const int d_k = static_cast(q.size(1)); - const int d_v = static_cast(input_v.size(2)); - TORCH_CHECK(O <= kMaxLocalOffsets, - "local-message kernel supports at most ", kMaxLocalOffsets, - " offsets"); - TORCH_CHECK(input_k.dim() == 3 && input_v.dim() == 3 && - recurrent_k.dim() == 3 && recurrent_v.dim() == 3, - "sender banks must have shape [BT, S, d]"); - TORCH_CHECK(input_k.size(0) == input_v.size(0) && - input_k.size(0) == recurrent_k.size(0) && - input_k.size(0) == recurrent_v.size(0), - "all sender banks must agree on BT"); - TORCH_CHECK(input_k.size(2) == q.size(1) && recurrent_k.size(2) == q.size(1), - "sender k banks and q must agree on d_k"); - TORCH_CHECK(input_v.size(2) == recurrent_v.size(2), - "sender v banks must agree on d_v"); - TORCH_CHECK(grad_msg.dim() == 3 && grad_msg.size(0) == input_k.size(0) && - grad_msg.size(1) == q.size(0) && - grad_msg.size(2) == input_v.size(2), - "grad_msg must have shape [BT, R, d_v]"); - TORCH_CHECK(receiver_sender_idx.dim() == 2 && - receiver_sender_idx.size(0) == q.size(0), - "receiver_sender_idx must match [R, O]"); - TORCH_CHECK(offset_distance.size(0) == receiver_sender_idx.size(1), - "offset_distance must match O"); - TORCH_CHECK(offset_delay.size(0) == receiver_sender_idx.size(1), - "offset_delay must match O"); - TORCH_CHECK(step_flat.size(0) == input_k.size(0), - "step_flat length must equal BT"); - - auto grad_q = at::zeros_like(q); - auto receiver_max_logit = at::empty({BT, R}, q.options()); - auto receiver_sumexp = at::empty({BT, R}, q.options()); - auto receiver_weighted_sum = at::empty({BT, R}, q.options()); - - const int threads = kThreadsPerBlock; - const int64_t receiver_total = static_cast(BT) * R; - const int groups_per_warp = kWarpSize / next_power_of_two(O); - const int groups_per_block = kWarpsPerBlock * groups_per_warp; - const int receiver_blocks = - static_cast((receiver_total + groups_per_block - 1) / groups_per_block); - auto stream = at::cuda::getCurrentCUDAStream(); - fabric_local_message_backward_receiver_partitioned_kernel<<>>( - grad_msg.data_ptr(), - q.data_ptr(), - input_k.data_ptr(), - input_v.data_ptr(), - recurrent_k.data_ptr(), - recurrent_v.data_ptr(), - 
receiver_sender_idx.data_ptr(), - offset_distance.data_ptr(), - offset_delay.data_ptr(), - step_flat.data_ptr(), - grad_q.data_ptr(), - receiver_max_logit.data_ptr(), - receiver_sumexp.data_ptr(), - receiver_weighted_sum.data_ptr(), - BT, - R, - input_senders, - recurrent_senders, - O, - d_k, - d_v, - 1.0f / std::sqrt(static_cast(d_k)), - static_cast(distance_scale), - use_delay); - check_launch("fabric_local_message_backward_receiver_partitioned_kernel"); - return {grad_q, receiver_max_logit, receiver_sumexp, receiver_weighted_sum}; -} - -std::vector fabric_local_message_backward_sender_cuda( - at::Tensor grad_msg, at::Tensor q, at::Tensor k_all, at::Tensor v_all, - at::Tensor sender_receiver_idx, at::Tensor offset_distance, - at::Tensor offset_delay, at::Tensor step_flat, - at::Tensor receiver_max_logit, at::Tensor receiver_sumexp, - at::Tensor receiver_weighted_sum, double distance_scale, - bool use_delay) { - check_cuda_tensor(grad_msg, "grad_msg"); - check_cuda_tensor(q, "q"); - check_cuda_tensor(k_all, "k_all"); - check_cuda_tensor(v_all, "v_all"); - check_cuda_tensor(sender_receiver_idx, "sender_receiver_idx"); - check_cuda_tensor(offset_distance, "offset_distance"); - check_cuda_tensor(offset_delay, "offset_delay"); - check_cuda_tensor(step_flat, "step_flat"); - check_cuda_tensor(receiver_max_logit, "receiver_max_logit"); - check_cuda_tensor(receiver_sumexp, "receiver_sumexp"); - check_cuda_tensor(receiver_weighted_sum, "receiver_weighted_sum"); - TORCH_CHECK(grad_msg.scalar_type() == at::kFloat, - "grad_msg must be float32"); - TORCH_CHECK(q.scalar_type() == at::kFloat, "q must be float32"); - TORCH_CHECK(k_all.scalar_type() == at::kFloat, "k_all must be float32"); - TORCH_CHECK(v_all.scalar_type() == at::kFloat, "v_all must be float32"); - TORCH_CHECK(sender_receiver_idx.scalar_type() == at::kInt, - "sender_receiver_idx must be int32"); - TORCH_CHECK(offset_distance.scalar_type() == at::kFloat, - "offset_distance must be float32"); - TORCH_CHECK(offset_delay.scalar_type() == at::kInt, - "offset_delay must be int32"); - TORCH_CHECK(step_flat.scalar_type() == at::kLong, - "step_flat must be int64"); - - const int BT = static_cast(k_all.size(0)); - const int S = static_cast(k_all.size(1)); - const int R = static_cast(q.size(0)); - const int O = static_cast(sender_receiver_idx.size(1)); - const int d_k = static_cast(q.size(1)); - const int d_v = static_cast(v_all.size(2)); - TORCH_CHECK(O <= kMaxLocalOffsets, - "local-message kernel supports at most ", kMaxLocalOffsets, - " offsets"); - TORCH_CHECK(grad_msg.dim() == 3 && grad_msg.size(0) == k_all.size(0) && - grad_msg.size(1) == q.size(0) && - grad_msg.size(2) == v_all.size(2), - "grad_msg must have shape [BT, R, d_v]"); - TORCH_CHECK(sender_receiver_idx.dim() == 2 && - sender_receiver_idx.size(0) == k_all.size(1), - "sender_receiver_idx must match [S, O]"); - TORCH_CHECK(offset_distance.size(0) == sender_receiver_idx.size(1), - "offset_distance must match O"); - TORCH_CHECK(offset_delay.size(0) == sender_receiver_idx.size(1), - "offset_delay must match O"); - TORCH_CHECK(receiver_max_logit.dim() == 2 && - receiver_max_logit.size(0) == BT && - receiver_max_logit.size(1) == R, - "receiver_max_logit must have shape [BT, R]"); - TORCH_CHECK(receiver_sumexp.dim() == 2 && - receiver_sumexp.size(0) == BT && - receiver_sumexp.size(1) == R, - "receiver_sumexp must have shape [BT, R]"); - TORCH_CHECK(receiver_weighted_sum.dim() == 2 && - receiver_weighted_sum.size(0) == BT && - receiver_weighted_sum.size(1) == R, - "receiver_weighted_sum must 
have shape [BT, R]"); - - auto grad_k = at::zeros_like(k_all); - auto grad_v = at::zeros_like(v_all); - const int threads = kThreadsPerBlock; - const int64_t sender_total = static_cast(BT) * S; - const int sender_blocks = - static_cast((sender_total + kWarpsPerBlock - 1) / kWarpsPerBlock); - auto stream = at::cuda::getCurrentCUDAStream(); - fabric_local_message_backward_sender_kernel<<>>( - grad_msg.data_ptr(), q.data_ptr(), k_all.data_ptr(), - v_all.data_ptr(), sender_receiver_idx.data_ptr(), - offset_distance.data_ptr(), - offset_delay.data_ptr(), step_flat.data_ptr(), - receiver_max_logit.data_ptr(), - receiver_sumexp.data_ptr(), - receiver_weighted_sum.data_ptr(), grad_k.data_ptr(), - grad_v.data_ptr(), BT, R, S, O, d_k, d_v, - 1.0f / std::sqrt(static_cast(d_k)), - static_cast(distance_scale), use_delay); - check_launch("fabric_local_message_backward_sender_kernel"); - return {grad_k, grad_v}; -} - -std::vector fabric_local_message_backward_sender_partitioned_cuda( - at::Tensor grad_msg, - at::Tensor q, - at::Tensor input_k, - at::Tensor input_v, - at::Tensor recurrent_k, - at::Tensor recurrent_v, - at::Tensor sender_receiver_idx, - at::Tensor offset_distance, - at::Tensor offset_delay, - at::Tensor step_flat, - at::Tensor receiver_max_logit, - at::Tensor receiver_sumexp, - at::Tensor receiver_weighted_sum, - double distance_scale, - bool use_delay) { - check_cuda_tensor(grad_msg, "grad_msg"); - check_cuda_tensor(q, "q"); - check_cuda_tensor(input_k, "input_k"); - check_cuda_tensor(input_v, "input_v"); - check_cuda_tensor(recurrent_k, "recurrent_k"); - check_cuda_tensor(recurrent_v, "recurrent_v"); - check_cuda_tensor(sender_receiver_idx, "sender_receiver_idx"); - check_cuda_tensor(offset_distance, "offset_distance"); - check_cuda_tensor(offset_delay, "offset_delay"); - check_cuda_tensor(step_flat, "step_flat"); - check_cuda_tensor(receiver_max_logit, "receiver_max_logit"); - check_cuda_tensor(receiver_sumexp, "receiver_sumexp"); - check_cuda_tensor(receiver_weighted_sum, "receiver_weighted_sum"); - TORCH_CHECK(grad_msg.scalar_type() == at::kFloat, - "grad_msg must be float32"); - TORCH_CHECK(q.scalar_type() == at::kFloat, "q must be float32"); - TORCH_CHECK(input_k.scalar_type() == at::kFloat, "input_k must be float32"); - TORCH_CHECK(input_v.scalar_type() == at::kFloat, "input_v must be float32"); - TORCH_CHECK(recurrent_k.scalar_type() == at::kFloat, "recurrent_k must be float32"); - TORCH_CHECK(recurrent_v.scalar_type() == at::kFloat, "recurrent_v must be float32"); - TORCH_CHECK(sender_receiver_idx.scalar_type() == at::kInt, - "sender_receiver_idx must be int32"); - TORCH_CHECK(offset_distance.scalar_type() == at::kFloat, - "offset_distance must be float32"); - TORCH_CHECK(offset_delay.scalar_type() == at::kInt, - "offset_delay must be int32"); - TORCH_CHECK(step_flat.scalar_type() == at::kLong, - "step_flat must be int64"); - - const int BT = static_cast(input_k.size(0)); - const int input_senders = static_cast(input_k.size(1)); - const int recurrent_senders = static_cast(recurrent_k.size(1)); - const int total_senders = input_senders + recurrent_senders; - const int R = static_cast(q.size(0)); - const int O = static_cast(sender_receiver_idx.size(1)); - const int d_k = static_cast(q.size(1)); - const int d_v = static_cast(input_v.size(2)); - TORCH_CHECK(O <= kMaxLocalOffsets, - "local-message kernel supports at most ", kMaxLocalOffsets, - " offsets"); - TORCH_CHECK(input_k.dim() == 3 && input_v.dim() == 3 && - recurrent_k.dim() == 3 && recurrent_v.dim() == 3, - "sender banks must 
have shape [BT, S, d]"); - TORCH_CHECK(input_k.size(0) == input_v.size(0) && - input_k.size(0) == recurrent_k.size(0) && - input_k.size(0) == recurrent_v.size(0), - "all sender banks must agree on BT"); - TORCH_CHECK(input_k.size(2) == q.size(1) && recurrent_k.size(2) == q.size(1), - "sender k banks and q must agree on d_k"); - TORCH_CHECK(input_v.size(2) == recurrent_v.size(2), - "sender v banks must agree on d_v"); - TORCH_CHECK(grad_msg.dim() == 3 && grad_msg.size(0) == input_k.size(0) && - grad_msg.size(1) == q.size(0) && - grad_msg.size(2) == input_v.size(2), - "grad_msg must have shape [BT, R, d_v]"); - TORCH_CHECK(sender_receiver_idx.dim() == 2 && - sender_receiver_idx.size(0) == total_senders, - "sender_receiver_idx must match [input_senders + recurrent_senders, O]"); - TORCH_CHECK(offset_distance.size(0) == sender_receiver_idx.size(1), - "offset_distance must match O"); - TORCH_CHECK(offset_delay.size(0) == sender_receiver_idx.size(1), - "offset_delay must match O"); - TORCH_CHECK(step_flat.size(0) == input_k.size(0), - "step_flat length must equal BT"); - TORCH_CHECK(receiver_max_logit.dim() == 2 && - receiver_max_logit.size(0) == BT && - receiver_max_logit.size(1) == R, - "receiver_max_logit must have shape [BT, R]"); - TORCH_CHECK(receiver_sumexp.dim() == 2 && - receiver_sumexp.size(0) == BT && - receiver_sumexp.size(1) == R, - "receiver_sumexp must have shape [BT, R]"); - TORCH_CHECK(receiver_weighted_sum.dim() == 2 && - receiver_weighted_sum.size(0) == BT && - receiver_weighted_sum.size(1) == R, - "receiver_weighted_sum must have shape [BT, R]"); - - auto grad_input_k = at::zeros_like(input_k); - auto grad_input_v = at::zeros_like(input_v); - auto grad_recurrent_k = at::zeros_like(recurrent_k); - auto grad_recurrent_v = at::zeros_like(recurrent_v); - const int threads = kThreadsPerBlock; - const int64_t sender_total = static_cast(BT) * total_senders; - const int sender_blocks = - static_cast((sender_total + kWarpsPerBlock - 1) / kWarpsPerBlock); - auto stream = at::cuda::getCurrentCUDAStream(); - fabric_local_message_backward_sender_partitioned_kernel<<>>( - grad_msg.data_ptr(), - q.data_ptr(), - input_k.data_ptr(), - input_v.data_ptr(), - recurrent_k.data_ptr(), - recurrent_v.data_ptr(), - sender_receiver_idx.data_ptr(), - offset_distance.data_ptr(), - offset_delay.data_ptr(), - step_flat.data_ptr(), - receiver_max_logit.data_ptr(), - receiver_sumexp.data_ptr(), - receiver_weighted_sum.data_ptr(), - grad_input_k.data_ptr(), - grad_input_v.data_ptr(), - grad_recurrent_k.data_ptr(), - grad_recurrent_v.data_ptr(), - BT, - R, - input_senders, - recurrent_senders, - O, - d_k, - d_v, - 1.0f / std::sqrt(static_cast(d_k)), - static_cast(distance_scale), - use_delay); - check_launch("fabric_local_message_backward_sender_partitioned_kernel"); - return {grad_input_k, grad_input_v, grad_recurrent_k, grad_recurrent_v}; -} - -std::vector fabric_local_message_backward_partitioned_fused_cuda( - at::Tensor grad_msg, - at::Tensor q, - at::Tensor input_k, - at::Tensor input_v, - at::Tensor recurrent_k, - at::Tensor recurrent_v, - at::Tensor receiver_sender_idx, - at::Tensor offset_distance, - at::Tensor offset_delay, - at::Tensor step_flat, - double distance_scale, - bool use_delay) { - check_cuda_tensor(grad_msg, "grad_msg"); - check_cuda_tensor(q, "q"); - check_cuda_tensor(input_k, "input_k"); - check_cuda_tensor(input_v, "input_v"); - check_cuda_tensor(recurrent_k, "recurrent_k"); - check_cuda_tensor(recurrent_v, "recurrent_v"); - check_cuda_tensor(receiver_sender_idx, 
"receiver_sender_idx"); - check_cuda_tensor(offset_distance, "offset_distance"); - check_cuda_tensor(offset_delay, "offset_delay"); - check_cuda_tensor(step_flat, "step_flat"); - TORCH_CHECK(grad_msg.scalar_type() == at::kFloat, - "grad_msg must be float32"); - TORCH_CHECK(q.scalar_type() == at::kFloat, "q must be float32"); - TORCH_CHECK(input_k.scalar_type() == at::kFloat, "input_k must be float32"); - TORCH_CHECK(input_v.scalar_type() == at::kFloat, "input_v must be float32"); - TORCH_CHECK(recurrent_k.scalar_type() == at::kFloat, "recurrent_k must be float32"); - TORCH_CHECK(recurrent_v.scalar_type() == at::kFloat, "recurrent_v must be float32"); - TORCH_CHECK(receiver_sender_idx.scalar_type() == at::kInt, - "receiver_sender_idx must be int32"); - TORCH_CHECK(offset_distance.scalar_type() == at::kFloat, - "offset_distance must be float32"); - TORCH_CHECK(offset_delay.scalar_type() == at::kInt, - "offset_delay must be int32"); - TORCH_CHECK(step_flat.scalar_type() == at::kLong, - "step_flat must be int64"); - - const int BT = static_cast(input_k.size(0)); - const int input_senders = static_cast(input_k.size(1)); - const int recurrent_senders = static_cast(recurrent_k.size(1)); - const int R = static_cast(q.size(0)); - const int O = static_cast(receiver_sender_idx.size(1)); - const int d_k = static_cast(q.size(1)); - const int d_v = static_cast(input_v.size(2)); - TORCH_CHECK(O <= kMaxLocalOffsets, - "local-message kernel supports at most ", kMaxLocalOffsets, - " offsets"); - TORCH_CHECK(input_k.dim() == 3 && input_v.dim() == 3 && - recurrent_k.dim() == 3 && recurrent_v.dim() == 3, - "sender banks must have shape [BT, S, d]"); - TORCH_CHECK(input_k.size(0) == input_v.size(0) && - input_k.size(0) == recurrent_k.size(0) && - input_k.size(0) == recurrent_v.size(0), - "all sender banks must agree on BT"); - TORCH_CHECK(input_k.size(2) == q.size(1) && recurrent_k.size(2) == q.size(1), - "sender k banks and q must agree on d_k"); - TORCH_CHECK(input_v.size(2) == recurrent_v.size(2), - "sender v banks must agree on d_v"); - TORCH_CHECK(grad_msg.dim() == 3 && grad_msg.size(0) == input_k.size(0) && - grad_msg.size(1) == q.size(0) && - grad_msg.size(2) == input_v.size(2), - "grad_msg must have shape [BT, R, d_v]"); - TORCH_CHECK(receiver_sender_idx.dim() == 2 && - receiver_sender_idx.size(0) == q.size(0), - "receiver_sender_idx must match [R, O]"); - TORCH_CHECK(offset_distance.size(0) == receiver_sender_idx.size(1), - "offset_distance must match O"); - TORCH_CHECK(offset_delay.size(0) == receiver_sender_idx.size(1), - "offset_delay must match O"); - TORCH_CHECK(step_flat.size(0) == input_k.size(0), - "step_flat length must equal BT"); - - auto grad_q = at::zeros_like(q); - auto grad_input_k = at::zeros_like(input_k); - auto grad_input_v = at::zeros_like(input_v); - auto grad_recurrent_k = at::zeros_like(recurrent_k); - auto grad_recurrent_v = at::zeros_like(recurrent_v); - const int threads = kThreadsPerBlock; - const int64_t receiver_total = static_cast(BT) * R; - const int groups_per_warp = kWarpSize / next_power_of_two(O); - const int groups_per_block = kWarpsPerBlock * groups_per_warp; - const int receiver_blocks = - static_cast((receiver_total + groups_per_block - 1) / groups_per_block); - auto stream = at::cuda::getCurrentCUDAStream(); - fabric_local_message_backward_partitioned_fused_kernel<<>>( - grad_msg.data_ptr(), - q.data_ptr(), - input_k.data_ptr(), - input_v.data_ptr(), - recurrent_k.data_ptr(), - recurrent_v.data_ptr(), - receiver_sender_idx.data_ptr(), - 
offset_distance.data_ptr(), - offset_delay.data_ptr(), - step_flat.data_ptr(), - grad_q.data_ptr(), - grad_input_k.data_ptr(), - grad_input_v.data_ptr(), - grad_recurrent_k.data_ptr(), - grad_recurrent_v.data_ptr(), - BT, - R, - input_senders, - recurrent_senders, - O, - d_k, - d_v, - 1.0f / std::sqrt(static_cast(d_k)), - static_cast(distance_scale), - use_delay); - check_launch("fabric_local_message_backward_partitioned_fused_kernel"); - return {grad_q, grad_input_k, grad_input_v, grad_recurrent_k, grad_recurrent_v}; -} diff --git a/src/cortical/fabric/backend/cuda/message_passing/message_backend_contract.cuh b/src/cortical/fabric/backend/cuda/message_passing/message_backend_contract.cuh deleted file mode 100644 index d86fa2a0..00000000 --- a/src/cortical/fabric/backend/cuda/message_passing/message_backend_contract.cuh +++ /dev/null @@ -1,57 +0,0 @@ -#pragma once - -#include "cortical/fabric/backend/cuda/execution/common.cuh" - -namespace fabric { - -// MessageBackend concept: -// -// __device__ static void aggregate_receiver_step_warp( -// int b, -// int t, -// int receiver, -// bool reset_row, -// const MessageTopology& topo, -// const TensorTable& input_ports, -// const TensorTable& public_prev, -// const TensorTable& msg_params, -// float* msg_out, -// int msg_dim, -// int lane); -// -// __device__ static void edge_contrib_step( -// int b, -// int t, -// int edge_idx, -// bool reset_row, -// const MessageTopology& topo, -// const TensorTable& input_ports, -// const TensorTable& public_prev, -// const TensorTable& msg_params, -// float* contrib_out, -// int msg_dim, -// int* receiver_out); -// -// __device__ static float edge_logit_step( -// int b, -// int t, -// int edge_idx, -// bool reset_row, -// const MessageTopology& topo, -// const TensorTable& input_ports, -// const TensorTable& public_prev, -// const TensorTable& msg_params, -// int* receiver_out); -// -// __device__ static void edge_value_step( -// int b, -// int t, -// int edge_idx, -// bool reset_row, -// const MessageTopology& topo, -// const TensorTable& input_ports, -// const TensorTable& public_prev, -// float* value_out, -// int msg_dim); - -} // namespace fabric diff --git a/src/cortical/fabric/backend/cuda/message_passing/registry.py b/src/cortical/fabric/backend/cuda/message_passing/registry.py deleted file mode 100644 index b53bdf21..00000000 --- a/src/cortical/fabric/backend/cuda/message_passing/registry.py +++ /dev/null @@ -1,35 +0,0 @@ -from __future__ import annotations - -from collections.abc import Callable -from dataclasses import dataclass - -import torch - -MessageBackend = Callable[..., torch.Tensor] - - -@dataclass(frozen=True) -class MessageBackendSpec: - name: str - backend: MessageBackend - - -_MESSAGE_BACKENDS: dict[str, MessageBackendSpec] = {} - - -def register_message_backend(name: str, backend: MessageBackend) -> None: - _MESSAGE_BACKENDS[name] = MessageBackendSpec(name=name, backend=backend) - - -def get_message_backend(name: str) -> MessageBackendSpec: - try: - return _MESSAGE_BACKENDS[name] - except KeyError as exc: - raise ValueError(f"Unsupported Fabric message backend {name}") from exc - - -__all__ = [ - "get_message_backend", - "MessageBackendSpec", - "register_message_backend", -] diff --git a/src/cortical/fabric/backend/cuda/message_passing/sparse_message_backend.cuh b/src/cortical/fabric/backend/cuda/message_passing/sparse_message_backend.cuh deleted file mode 100644 index 05f0fb95..00000000 --- a/src/cortical/fabric/backend/cuda/message_passing/sparse_message_backend.cuh +++ /dev/null 
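The deleted `registry.py` above kept only a name-to-callable map with no backend metadata. Purely as a hypothetical sketch of what an explicit, fail-closed replacement record could carry; none of these class, field, or function names exist in the tree.

```python
# Hypothetical sketch only: an explicit, fail-closed strategy record that a
# replacement for the deleted name->callable registry could carry.
from dataclasses import dataclass
from typing import Callable, Optional

import torch


@dataclass(frozen=True)
class MessageStrategyRecord:
    name: str
    implementation_backend: str                    # declared, e.g. "cuda_native"
    runtime_entrypoint: Callable[..., torch.Tensor]
    supported_device: str                          # e.g. "cuda"
    dtype_contract: torch.dtype                    # float32 for these kernels
    supports_forward: bool
    supports_backward: bool
    unsupported_reason: Optional[str] = None       # set => selection must fail closed


def select_strategy(records: list[MessageStrategyRecord], name: str) -> MessageStrategyRecord:
    for record in records:
        if record.name != name:
            continue
        if record.unsupported_reason is not None:
            raise ValueError(f"strategy {name!r} unsupported: {record.unsupported_reason}")
        return record
    raise ValueError(f"unknown Fabric message strategy {name!r}")
```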
@@ -1,9 +0,0 @@ -#pragma once - -#include "cortical/fabric/backend/cuda/message_passing/local_message_backend.cuh" - -namespace fabric { - -using SparseMessageBackend = LocalMessageBackend; - -} // namespace fabric diff --git a/src/cortical/fabric/backend/cuda/message_passing/sparse_message_binding.cpp b/src/cortical/fabric/backend/cuda/message_passing/sparse_message_binding.cpp deleted file mode 100644 index df6f4b1a..00000000 --- a/src/cortical/fabric/backend/cuda/message_passing/sparse_message_binding.cpp +++ /dev/null @@ -1,98 +0,0 @@ -#include - -std::vector fabric_sparse_message_forward_cuda( - at::Tensor q, at::Tensor k_all, at::Tensor v_all, at::Tensor neighbor_idx, - at::Tensor neighbor_valid, at::Tensor edge_distance, at::Tensor edge_delay, - at::Tensor step_flat, double distance_scale, bool use_delay); -std::vector fabric_sparse_message_forward_partitioned_cuda( - at::Tensor q, - at::Tensor input_k, - at::Tensor input_v, - at::Tensor recurrent_k, - at::Tensor recurrent_v, - at::Tensor neighbor_idx, - at::Tensor neighbor_valid, - at::Tensor edge_distance, - at::Tensor edge_delay, - at::Tensor step_flat, - double distance_scale, - bool use_delay); - -std::vector fabric_sparse_message_backward_cuda( - at::Tensor grad_msg, at::Tensor q, at::Tensor k_all, at::Tensor v_all, - at::Tensor neighbor_idx, at::Tensor neighbor_valid, at::Tensor edge_distance, - at::Tensor edge_delay, at::Tensor step_flat, double distance_scale, - bool use_delay); -std::vector fabric_sparse_message_backward_partitioned_cuda( - at::Tensor grad_msg, - at::Tensor q, - at::Tensor input_k, - at::Tensor input_v, - at::Tensor recurrent_k, - at::Tensor recurrent_v, - at::Tensor neighbor_idx, - at::Tensor neighbor_valid, - at::Tensor edge_distance, - at::Tensor edge_delay, - at::Tensor step_flat, - double distance_scale, - bool use_delay); -std::vector fabric_sparse_message_backward_receiver_cuda( - at::Tensor grad_msg, at::Tensor q, at::Tensor k_all, at::Tensor v_all, - at::Tensor neighbor_idx, at::Tensor neighbor_valid, at::Tensor edge_distance, - at::Tensor edge_delay, at::Tensor step_flat, double distance_scale, - bool use_delay); -std::vector fabric_sparse_message_backward_receiver_partitioned_cuda( - at::Tensor grad_msg, - at::Tensor q, - at::Tensor input_k, - at::Tensor input_v, - at::Tensor recurrent_k, - at::Tensor recurrent_v, - at::Tensor neighbor_idx, - at::Tensor neighbor_valid, - at::Tensor edge_distance, - at::Tensor edge_delay, - at::Tensor step_flat, - double distance_scale, - bool use_delay); -std::vector fabric_sparse_message_backward_sender_cuda( - at::Tensor grad_msg, at::Tensor q, at::Tensor k_all, at::Tensor v_all, - at::Tensor neighbor_idx, at::Tensor neighbor_valid, at::Tensor edge_distance, - at::Tensor edge_delay, at::Tensor step_flat, double distance_scale, - bool use_delay); -std::vector fabric_sparse_message_backward_sender_partitioned_cuda( - at::Tensor grad_msg, - at::Tensor q, - at::Tensor input_k, - at::Tensor input_v, - at::Tensor recurrent_k, - at::Tensor recurrent_v, - at::Tensor neighbor_idx, - at::Tensor neighbor_valid, - at::Tensor edge_distance, - at::Tensor edge_delay, - at::Tensor step_flat, - double distance_scale, - bool use_delay); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &fabric_sparse_message_forward_cuda, - "Fabric sparse message forward (CUDA)"); - m.def("forward_partitioned", &fabric_sparse_message_forward_partitioned_cuda, - "Fabric sparse message forward for partitioned sender banks (CUDA)"); - m.def("backward", 
&fabric_sparse_message_backward_cuda, - "Fabric sparse message backward (CUDA)"); - m.def("backward_partitioned", &fabric_sparse_message_backward_partitioned_cuda, - "Fabric sparse message backward for partitioned sender banks (CUDA)"); - m.def("backward_receiver", &fabric_sparse_message_backward_receiver_cuda, - "Fabric sparse message backward receiver phase (CUDA)"); - m.def("backward_receiver_partitioned", - &fabric_sparse_message_backward_receiver_partitioned_cuda, - "Fabric sparse message backward receiver phase for partitioned sender banks (CUDA)"); - m.def("backward_sender", &fabric_sparse_message_backward_sender_cuda, - "Fabric sparse message backward sender phase (CUDA)"); - m.def("backward_sender_partitioned", - &fabric_sparse_message_backward_sender_partitioned_cuda, - "Fabric sparse message backward sender phase for partitioned sender banks (CUDA)"); -} diff --git a/src/cortical/fabric/backend/cuda/message_passing/sparse_message_cuda.py b/src/cortical/fabric/backend/cuda/message_passing/sparse_message_cuda.py deleted file mode 100644 index 416bce45..00000000 --- a/src/cortical/fabric/backend/cuda/message_passing/sparse_message_cuda.py +++ /dev/null @@ -1,481 +0,0 @@ -from __future__ import annotations - -import os - -import torch -from torch.autograd import Function - -from cortical.native.extension_loader import safe_load_extension - -_mod_path = os.path.dirname(__file__) -_ext = None - - -def _merge_partitioned_sender_banks( - input_k: torch.Tensor, - input_v: torch.Tensor, - recurrent_k: torch.Tensor, - recurrent_v: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - num_input_senders = int(input_k.shape[1]) - num_recurrent_senders = int(recurrent_k.shape[1]) - if num_input_senders == 0: - return recurrent_k, recurrent_v - if num_recurrent_senders == 0: - return input_k, input_v - k_all = torch.nn.functional.pad(input_k, (0, 0, 0, num_recurrent_senders)) + torch.nn.functional.pad( - recurrent_k, - (0, 0, num_input_senders, 0), - ) - v_all = torch.nn.functional.pad(input_v, (0, 0, 0, num_recurrent_senders)) + torch.nn.functional.pad( - recurrent_v, - (0, 0, num_input_senders, 0), - ) - return k_all, v_all - - -def _load_ext(): - global _ext - if _ext is not None: - return _ext - _ext = safe_load_extension( - name="fabric_sparse_message_cuda", - sources=[ - os.path.join(_mod_path, "sparse_message_binding.cpp"), - os.path.join(_mod_path, "sparse_message_kernels.cu"), - ], - extra_cflags=["-O3"], - extra_cuda_cflags=["-O3", "-Xptxas", "-O3"], - verbose=False, - ) - return _ext - - -class _FabricSparseMessageCUDA(Function): - @staticmethod - def forward( - q: torch.Tensor, - k_all: torch.Tensor, - v_all: torch.Tensor, - neighbor_idx: torch.Tensor, - neighbor_valid: torch.Tensor, - edge_distance: torch.Tensor, - edge_delay: torch.Tensor, - step_flat: torch.Tensor, - distance_scale: float, - use_delay: bool, - ) -> torch.Tensor: - (msg,) = _load_ext().forward( - q.contiguous(), - k_all.contiguous(), - v_all.contiguous(), - neighbor_idx.contiguous(), - neighbor_valid.contiguous(), - edge_distance.contiguous(), - edge_delay.contiguous(), - step_flat.contiguous(), - float(distance_scale), - bool(use_delay), - ) - return msg - - @staticmethod - def setup_context(ctx, inputs, output): - ( - q, - k_all, - v_all, - neighbor_idx, - neighbor_valid, - edge_distance, - edge_delay, - step_flat, - distance_scale, - use_delay, - ) = inputs - ctx.save_for_backward(q, k_all, v_all, neighbor_idx, neighbor_valid, edge_distance, edge_delay, step_flat) - ctx.distance_scale = float(distance_scale) 
- ctx.use_delay = bool(use_delay) - - @staticmethod - def backward(ctx, grad_msg: torch.Tensor): - q, k_all, v_all, neighbor_idx, neighbor_valid, edge_distance, edge_delay, step_flat = ctx.saved_tensors - with torch.profiler.record_function("fabric.backward.sparse_message_superop"): - grad_q, grad_k, grad_v = _load_ext().backward( - grad_msg.contiguous(), - q.contiguous(), - k_all.contiguous(), - v_all.contiguous(), - neighbor_idx.contiguous(), - neighbor_valid.contiguous(), - edge_distance.contiguous(), - edge_delay.contiguous(), - step_flat.contiguous(), - ctx.distance_scale, - ctx.use_delay, - ) - return grad_q, grad_k, grad_v, None, None, None, None, None, None, None - - -class _FabricSparseMessagePartitionedCUDA(Function): - @staticmethod - def forward( - q: torch.Tensor, - input_k: torch.Tensor, - input_v: torch.Tensor, - recurrent_k: torch.Tensor, - recurrent_v: torch.Tensor, - neighbor_idx: torch.Tensor, - neighbor_valid: torch.Tensor, - edge_distance: torch.Tensor, - edge_delay: torch.Tensor, - step_flat: torch.Tensor, - distance_scale: float, - use_delay: bool, - ) -> torch.Tensor: - (msg,) = _load_ext().forward_partitioned( - q.contiguous(), - input_k.contiguous(), - input_v.contiguous(), - recurrent_k.contiguous(), - recurrent_v.contiguous(), - neighbor_idx.contiguous(), - neighbor_valid.contiguous(), - edge_distance.contiguous(), - edge_delay.contiguous(), - step_flat.contiguous(), - float(distance_scale), - bool(use_delay), - ) - return msg - - @staticmethod - def setup_context(ctx, inputs, output): - del output - ( - q, - input_k, - input_v, - recurrent_k, - recurrent_v, - neighbor_idx, - neighbor_valid, - edge_distance, - edge_delay, - step_flat, - distance_scale, - use_delay, - ) = inputs - ctx.save_for_backward( - q, - input_k, - input_v, - recurrent_k, - recurrent_v, - neighbor_idx, - neighbor_valid, - edge_distance, - edge_delay, - step_flat, - ) - ctx.distance_scale = float(distance_scale) - ctx.use_delay = bool(use_delay) - - @staticmethod - def backward(ctx, grad_msg: torch.Tensor): - ( - q, - input_k, - input_v, - recurrent_k, - recurrent_v, - neighbor_idx, - neighbor_valid, - edge_distance, - edge_delay, - step_flat, - ) = ctx.saved_tensors - with torch.profiler.record_function("fabric.backward.sparse_message_superop"): - ( - grad_q, - grad_input_k, - grad_input_v, - grad_recurrent_k, - grad_recurrent_v, - ) = _load_ext().backward_partitioned( - grad_msg.contiguous(), - q.contiguous(), - input_k.contiguous(), - input_v.contiguous(), - recurrent_k.contiguous(), - recurrent_v.contiguous(), - neighbor_idx.contiguous(), - neighbor_valid.contiguous(), - edge_distance.contiguous(), - edge_delay.contiguous(), - step_flat.contiguous(), - float(ctx.distance_scale), - bool(ctx.use_delay), - ) - return ( - grad_q, - grad_input_k, - grad_input_v, - grad_recurrent_k, - grad_recurrent_v, - None, - None, - None, - None, - None, - None, - None, - ) - - -def fabric_sparse_message_cuda( - q: torch.Tensor, - k_all: torch.Tensor, - v_all: torch.Tensor, - neighbor_idx: torch.Tensor, - neighbor_valid: torch.Tensor, - edge_distance: torch.Tensor, - edge_delay: torch.Tensor, - step_flat: torch.Tensor, - *, - distance_scale: float, - use_delay: bool, -) -> torch.Tensor: - return _FabricSparseMessageCUDA.apply( - q, - k_all, - v_all, - neighbor_idx, - neighbor_valid, - edge_distance, - edge_delay, - step_flat, - distance_scale, - use_delay, - ) - - -def fabric_sparse_message_partitioned_cuda( - q: torch.Tensor, - input_k: torch.Tensor, - input_v: torch.Tensor, - recurrent_k: torch.Tensor, - 
recurrent_v: torch.Tensor, - neighbor_idx: torch.Tensor, - neighbor_valid: torch.Tensor, - edge_distance: torch.Tensor, - edge_delay: torch.Tensor, - step_flat: torch.Tensor, - *, - distance_scale: float, - use_delay: bool, -) -> torch.Tensor: - if not q.is_cuda: - k_all, v_all = _merge_partitioned_sender_banks(input_k, input_v, recurrent_k, recurrent_v) - return fabric_sparse_message_cuda( - q, - k_all, - v_all, - neighbor_idx, - neighbor_valid, - edge_distance, - edge_delay, - step_flat, - distance_scale=distance_scale, - use_delay=use_delay, - ) - return _FabricSparseMessagePartitionedCUDA.apply( - q, - input_k, - input_v, - recurrent_k, - recurrent_v, - neighbor_idx, - neighbor_valid, - edge_distance, - edge_delay, - step_flat, - distance_scale, - use_delay, - ) - - -def fabric_sparse_message_backward_receiver_cuda( - grad_msg: torch.Tensor, - q: torch.Tensor, - k_all: torch.Tensor, - v_all: torch.Tensor, - neighbor_idx: torch.Tensor, - neighbor_valid: torch.Tensor, - edge_distance: torch.Tensor, - edge_delay: torch.Tensor, - step_flat: torch.Tensor, - *, - distance_scale: float, - use_delay: bool, -) -> torch.Tensor: - (grad_q,) = _load_ext().backward_receiver( - grad_msg.contiguous(), - q.contiguous(), - k_all.contiguous(), - v_all.contiguous(), - neighbor_idx.contiguous(), - neighbor_valid.contiguous(), - edge_distance.contiguous(), - edge_delay.contiguous(), - step_flat.contiguous(), - float(distance_scale), - bool(use_delay), - ) - return grad_q - - -def fabric_sparse_message_partitioned_backward_receiver_cuda( - grad_msg: torch.Tensor, - q: torch.Tensor, - input_k: torch.Tensor, - input_v: torch.Tensor, - recurrent_k: torch.Tensor, - recurrent_v: torch.Tensor, - neighbor_idx: torch.Tensor, - neighbor_valid: torch.Tensor, - edge_distance: torch.Tensor, - edge_delay: torch.Tensor, - step_flat: torch.Tensor, - *, - distance_scale: float, - use_delay: bool, -) -> torch.Tensor: - if not q.is_cuda: - k_all, v_all = _merge_partitioned_sender_banks(input_k, input_v, recurrent_k, recurrent_v) - return fabric_sparse_message_backward_receiver_cuda( - grad_msg, - q, - k_all, - v_all, - neighbor_idx, - neighbor_valid, - edge_distance, - edge_delay, - step_flat, - distance_scale=distance_scale, - use_delay=use_delay, - ) - (grad_q,) = _load_ext().backward_receiver_partitioned( - grad_msg.contiguous(), - q.contiguous(), - input_k.contiguous(), - input_v.contiguous(), - recurrent_k.contiguous(), - recurrent_v.contiguous(), - neighbor_idx.contiguous(), - neighbor_valid.contiguous(), - edge_distance.contiguous(), - edge_delay.contiguous(), - step_flat.contiguous(), - float(distance_scale), - bool(use_delay), - ) - return grad_q - - -def fabric_sparse_message_backward_sender_cuda( - grad_msg: torch.Tensor, - q: torch.Tensor, - k_all: torch.Tensor, - v_all: torch.Tensor, - neighbor_idx: torch.Tensor, - neighbor_valid: torch.Tensor, - edge_distance: torch.Tensor, - edge_delay: torch.Tensor, - step_flat: torch.Tensor, - *, - distance_scale: float, - use_delay: bool, -) -> tuple[torch.Tensor, torch.Tensor]: - grad_k, grad_v = _load_ext().backward_sender( - grad_msg.contiguous(), - q.contiguous(), - k_all.contiguous(), - v_all.contiguous(), - neighbor_idx.contiguous(), - neighbor_valid.contiguous(), - edge_distance.contiguous(), - edge_delay.contiguous(), - step_flat.contiguous(), - float(distance_scale), - bool(use_delay), - ) - return grad_k, grad_v - - -def fabric_sparse_message_partitioned_backward_sender_cuda( - grad_msg: torch.Tensor, - q: torch.Tensor, - input_k: torch.Tensor, - input_v: 
torch.Tensor, - recurrent_k: torch.Tensor, - recurrent_v: torch.Tensor, - neighbor_idx: torch.Tensor, - neighbor_valid: torch.Tensor, - edge_distance: torch.Tensor, - edge_delay: torch.Tensor, - step_flat: torch.Tensor, - *, - distance_scale: float, - use_delay: bool, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - if not q.is_cuda: - k_all, v_all = _merge_partitioned_sender_banks(input_k, input_v, recurrent_k, recurrent_v) - grad_k_all, grad_v_all = fabric_sparse_message_backward_sender_cuda( - grad_msg, - q, - k_all, - v_all, - neighbor_idx, - neighbor_valid, - edge_distance, - edge_delay, - step_flat, - distance_scale=distance_scale, - use_delay=use_delay, - ) - num_input_senders = int(input_k.shape[1]) - grad_input_k, grad_recurrent_k = grad_k_all.split( - (num_input_senders, grad_k_all.shape[1] - num_input_senders), dim=1 - ) - grad_input_v, grad_recurrent_v = grad_v_all.split( - (num_input_senders, grad_v_all.shape[1] - num_input_senders), dim=1 - ) - return grad_input_k, grad_input_v, grad_recurrent_k, grad_recurrent_v - grad_input_k, grad_input_v, grad_recurrent_k, grad_recurrent_v = _load_ext().backward_sender_partitioned( - grad_msg.contiguous(), - q.contiguous(), - input_k.contiguous(), - input_v.contiguous(), - recurrent_k.contiguous(), - recurrent_v.contiguous(), - neighbor_idx.contiguous(), - neighbor_valid.contiguous(), - edge_distance.contiguous(), - edge_delay.contiguous(), - step_flat.contiguous(), - float(distance_scale), - bool(use_delay), - ) - return grad_input_k, grad_input_v, grad_recurrent_k, grad_recurrent_v - - -__all__ = [ - "fabric_sparse_message_backward_receiver_cuda", - "fabric_sparse_message_backward_sender_cuda", - "fabric_sparse_message_cuda", - "fabric_sparse_message_partitioned_backward_receiver_cuda", - "fabric_sparse_message_partitioned_backward_sender_cuda", - "fabric_sparse_message_partitioned_cuda", -] diff --git a/src/cortical/fabric/backend/cuda/message_passing/sparse_message_kernels.cu b/src/cortical/fabric/backend/cuda/message_passing/sparse_message_kernels.cu deleted file mode 100644 index 28d810d4..00000000 --- a/src/cortical/fabric/backend/cuda/message_passing/sparse_message_kernels.cu +++ /dev/null @@ -1,983 +0,0 @@ -#include -#include -#include -#include - -#include -#include -#include - -namespace { - -constexpr int kMaxGridY = 65535; - -inline void check_cuda_tensor(const at::Tensor &tensor, const char *name) { - TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA tensor"); - TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous"); -} - -inline void check_launch(const char *name) { - const cudaError_t err = cudaGetLastError(); - TORCH_CHECK(err == cudaSuccess, name, " launch failed: ", - cudaGetErrorString(err)); -} - -__global__ void merge_partitioned_sender_bank_kernel( - const float* __restrict__ input_bank, - const float* __restrict__ recurrent_bank, - float* __restrict__ merged_bank, - int BT, - int input_senders, - int recurrent_senders, - int dim) { - const int total_senders = input_senders + recurrent_senders; - const int64_t total = static_cast(BT) * total_senders * dim; - const int64_t stride = static_cast(blockDim.x) * gridDim.x; - for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; - linear < total; - linear += stride) { - const int d = static_cast(linear % dim); - const int sender = static_cast((linear / dim) % total_senders); - const int bt = static_cast(linear / (static_cast(total_senders) * dim)); - if (sender < input_senders) { - merged_bank[linear] = - 
input_bank[((static_cast(bt) * input_senders + sender) * dim) + d]; - } else { - const int recurrent_sender = sender - input_senders; - merged_bank[linear] = - recurrent_bank[((static_cast(bt) * recurrent_senders + recurrent_sender) * dim) + d]; - } - } -} - -__global__ void fabric_sparse_message_forward_kernel( - const float *__restrict__ q, const float *__restrict__ k_all, - const float *__restrict__ v_all, const int64_t *__restrict__ neighbor_idx, - const bool *__restrict__ neighbor_valid, - const float *__restrict__ edge_distance, - const int64_t *__restrict__ edge_delay, const int64_t *__restrict__ step_flat, - float *__restrict__ out, int BT, int R, int S, int M, int d_k, int d_v, - float inv_sqrt_dk, float distance_scale, bool use_delay) { - const int bt = blockIdx.x; - const int recv = static_cast(blockIdx.z) * kMaxGridY + - static_cast(blockIdx.y); - const int tid = threadIdx.x; - if (bt >= BT || recv >= R) { - return; - } - - extern __shared__ float shared[]; - float *logits = shared; - float *weights = shared + M; - - const int neighbor_base = recv * M; - const int step_value = use_delay ? static_cast(step_flat[bt]) : 0; - if (tid == 0) { - float max_logit = -std::numeric_limits::infinity(); - for (int edge = 0; edge < M; ++edge) { - const int edge_idx = neighbor_base + edge; - const bool valid = neighbor_valid[edge_idx] && - (!use_delay || - static_cast(edge_delay[edge_idx]) <= step_value); - if (!valid) { - logits[edge] = -std::numeric_limits::infinity(); - weights[edge] = 0.0f; - continue; - } - const int send = static_cast(neighbor_idx[edge_idx]); - const int q_offset = recv * d_k; - const int k_offset = (bt * S + send) * d_k; - float dot = 0.0f; - for (int d = 0; d < d_k; ++d) { - dot += q[q_offset + d] * k_all[k_offset + d]; - } - const float logit = dot * inv_sqrt_dk - distance_scale * edge_distance[edge_idx]; - logits[edge] = logit; - max_logit = logit > max_logit ? logit : max_logit; - } - - if (!isfinite(max_logit)) { - for (int edge = 0; edge < M; ++edge) { - weights[edge] = 0.0f; - } - } else { - float sum = 0.0f; - for (int edge = 0; edge < M; ++edge) { - const float weight = isfinite(logits[edge]) ? expf(logits[edge] - max_logit) : 0.0f; - weights[edge] = weight; - sum += weight; - } - const float inv_sum = sum > 0.0f ? 
1.0f / sum : 0.0f; - for (int edge = 0; edge < M; ++edge) { - weights[edge] *= inv_sum; - } - } - } - __syncthreads(); - - const int out_offset = (bt * R + recv) * d_v; - for (int d = tid; d < d_v; d += blockDim.x) { - float acc = 0.0f; - for (int edge = 0; edge < M; ++edge) { - const float weight = weights[edge]; - if (weight == 0.0f) { - continue; - } - const int send = static_cast(neighbor_idx[neighbor_base + edge]); - acc += weight * v_all[(bt * S + send) * d_v + d]; - } - out[out_offset + d] = acc; - } -} - -__global__ void fabric_sparse_message_backward_kernel( - const float *__restrict__ grad_msg, const float *__restrict__ q, - const float *__restrict__ k_all, const float *__restrict__ v_all, - const int64_t *__restrict__ neighbor_idx, - const bool *__restrict__ neighbor_valid, - const float *__restrict__ edge_distance, - const int64_t *__restrict__ edge_delay, const int64_t *__restrict__ step_flat, - float *__restrict__ grad_q, float *__restrict__ grad_k, - float *__restrict__ grad_v, int BT, int R, int S, int M, int d_k, int d_v, - float inv_sqrt_dk, float distance_scale, bool use_delay) { - const int bt = blockIdx.x; - const int recv = static_cast(blockIdx.z) * kMaxGridY + - static_cast(blockIdx.y); - const int tid = threadIdx.x; - if (bt >= BT || recv >= R) { - return; - } - - extern __shared__ float shared[]; - float *logits = shared; - float *weights = shared + M; - float *dweight = shared + (2 * M); - float *dlogit = shared + (3 * M); - - const int neighbor_base = recv * M; - const int step_value = use_delay ? static_cast(step_flat[bt]) : 0; - if (tid == 0) { - float max_logit = -std::numeric_limits::infinity(); - for (int edge = 0; edge < M; ++edge) { - const int edge_idx = neighbor_base + edge; - const bool valid = neighbor_valid[edge_idx] && - (!use_delay || - static_cast(edge_delay[edge_idx]) <= step_value); - if (!valid) { - logits[edge] = -std::numeric_limits::infinity(); - weights[edge] = 0.0f; - dweight[edge] = 0.0f; - dlogit[edge] = 0.0f; - continue; - } - const int send = static_cast(neighbor_idx[edge_idx]); - const int q_offset = recv * d_k; - const int k_offset = (bt * S + send) * d_k; - float dot = 0.0f; - for (int d = 0; d < d_k; ++d) { - dot += q[q_offset + d] * k_all[k_offset + d]; - } - const float logit = dot * inv_sqrt_dk - distance_scale * edge_distance[edge_idx]; - logits[edge] = logit; - max_logit = logit > max_logit ? logit : max_logit; - } - - if (!isfinite(max_logit)) { - for (int edge = 0; edge < M; ++edge) { - weights[edge] = 0.0f; - dweight[edge] = 0.0f; - dlogit[edge] = 0.0f; - } - } else { - float sum = 0.0f; - for (int edge = 0; edge < M; ++edge) { - const float weight = isfinite(logits[edge]) ? expf(logits[edge] - max_logit) : 0.0f; - weights[edge] = weight; - sum += weight; - } - const float inv_sum = sum > 0.0f ? 
1.0f / sum : 0.0f; - float weighted_sum = 0.0f; - const int grad_offset = (bt * R + recv) * d_v; - for (int edge = 0; edge < M; ++edge) { - weights[edge] *= inv_sum; - if (weights[edge] == 0.0f) { - dweight[edge] = 0.0f; - continue; - } - const int send = static_cast(neighbor_idx[neighbor_base + edge]); - const int v_offset = (bt * S + send) * d_v; - float dot = 0.0f; - for (int d = 0; d < d_v; ++d) { - dot += grad_msg[grad_offset + d] * v_all[v_offset + d]; - } - dweight[edge] = dot; - weighted_sum += weights[edge] * dot; - } - for (int edge = 0; edge < M; ++edge) { - dlogit[edge] = weights[edge] * (dweight[edge] - weighted_sum); - } - } - } - __syncthreads(); - - const int grad_offset = (bt * R + recv) * d_v; - for (int d = tid; d < d_v; d += blockDim.x) { - const float grad_out = grad_msg[grad_offset + d]; - for (int edge = 0; edge < M; ++edge) { - const float weight = weights[edge]; - if (weight == 0.0f) { - continue; - } - const int send = static_cast(neighbor_idx[neighbor_base + edge]); - atomicAdd(&grad_v[(bt * S + send) * d_v + d], weight * grad_out); - } - } - - for (int d = tid; d < d_k; d += blockDim.x) { - const float q_val = q[recv * d_k + d]; - float recv_grad = 0.0f; - for (int edge = 0; edge < M; ++edge) { - const float dl = dlogit[edge]; - if (dl == 0.0f) { - continue; - } - const int send = static_cast(neighbor_idx[neighbor_base + edge]); - const float scaled = dl * inv_sqrt_dk; - atomicAdd(&grad_k[(bt * S + send) * d_k + d], scaled * q_val); - recv_grad += scaled * k_all[(bt * S + send) * d_k + d]; - } - atomicAdd(&grad_q[recv * d_k + d], recv_grad); - } -} - -__global__ void fabric_sparse_message_backward_receiver_kernel( - const float *__restrict__ grad_msg, const float *__restrict__ q, - const float *__restrict__ k_all, const float *__restrict__ v_all, - const int64_t *__restrict__ neighbor_idx, - const bool *__restrict__ neighbor_valid, - const float *__restrict__ edge_distance, - const int64_t *__restrict__ edge_delay, const int64_t *__restrict__ step_flat, - float *__restrict__ grad_q, int BT, int R, int S, int M, int d_k, int d_v, - float inv_sqrt_dk, float distance_scale, bool use_delay) { - const int bt = blockIdx.x; - const int recv = static_cast(blockIdx.z) * kMaxGridY + - static_cast(blockIdx.y); - const int tid = threadIdx.x; - if (bt >= BT || recv >= R) { - return; - } - - extern __shared__ float shared[]; - float *logits = shared; - float *weights = shared + M; - float *dweight = shared + (2 * M); - float *dlogit = shared + (3 * M); - - const int neighbor_base = recv * M; - const int step_value = use_delay ? static_cast(step_flat[bt]) : 0; - if (tid == 0) { - float max_logit = -std::numeric_limits::infinity(); - for (int edge = 0; edge < M; ++edge) { - const int edge_idx = neighbor_base + edge; - const bool valid = neighbor_valid[edge_idx] && - (!use_delay || - static_cast(edge_delay[edge_idx]) <= step_value); - if (!valid) { - logits[edge] = -std::numeric_limits::infinity(); - weights[edge] = 0.0f; - dweight[edge] = 0.0f; - dlogit[edge] = 0.0f; - continue; - } - const int send = static_cast(neighbor_idx[edge_idx]); - const int q_offset = recv * d_k; - const int k_offset = (bt * S + send) * d_k; - float dot = 0.0f; - for (int d = 0; d < d_k; ++d) { - dot += q[q_offset + d] * k_all[k_offset + d]; - } - const float logit = dot * inv_sqrt_dk - distance_scale * edge_distance[edge_idx]; - logits[edge] = logit; - max_logit = logit > max_logit ? 
logit : max_logit; - } - - if (!isfinite(max_logit)) { - for (int edge = 0; edge < M; ++edge) { - weights[edge] = 0.0f; - dweight[edge] = 0.0f; - dlogit[edge] = 0.0f; - } - } else { - float sum = 0.0f; - for (int edge = 0; edge < M; ++edge) { - const float weight = isfinite(logits[edge]) ? expf(logits[edge] - max_logit) : 0.0f; - weights[edge] = weight; - sum += weight; - } - const float inv_sum = sum > 0.0f ? 1.0f / sum : 0.0f; - float weighted_sum = 0.0f; - const int grad_offset = (bt * R + recv) * d_v; - for (int edge = 0; edge < M; ++edge) { - weights[edge] *= inv_sum; - if (weights[edge] == 0.0f) { - dweight[edge] = 0.0f; - continue; - } - const int send = static_cast(neighbor_idx[neighbor_base + edge]); - const int v_offset = (bt * S + send) * d_v; - float dot = 0.0f; - for (int d = 0; d < d_v; ++d) { - dot += grad_msg[grad_offset + d] * v_all[v_offset + d]; - } - dweight[edge] = dot; - weighted_sum += weights[edge] * dot; - } - for (int edge = 0; edge < M; ++edge) { - dlogit[edge] = weights[edge] * (dweight[edge] - weighted_sum); - } - } - } - __syncthreads(); - - for (int d = tid; d < d_k; d += blockDim.x) { - float recv_grad = 0.0f; - for (int edge = 0; edge < M; ++edge) { - const float dl = dlogit[edge]; - if (dl == 0.0f) { - continue; - } - const int send = static_cast(neighbor_idx[neighbor_base + edge]); - recv_grad += dl * inv_sqrt_dk * k_all[(bt * S + send) * d_k + d]; - } - atomicAdd(&grad_q[recv * d_k + d], recv_grad); - } -} - -__global__ void fabric_sparse_message_backward_sender_kernel( - const float *__restrict__ grad_msg, const float *__restrict__ q, - const float *__restrict__ k_all, const float *__restrict__ v_all, - const int64_t *__restrict__ neighbor_idx, - const bool *__restrict__ neighbor_valid, - const float *__restrict__ edge_distance, - const int64_t *__restrict__ edge_delay, const int64_t *__restrict__ step_flat, - float *__restrict__ grad_k, float *__restrict__ grad_v, int BT, int R, int S, - int M, int d_k, int d_v, float inv_sqrt_dk, float distance_scale, - bool use_delay) { - const int bt = blockIdx.x; - const int recv = static_cast(blockIdx.z) * kMaxGridY + - static_cast(blockIdx.y); - const int tid = threadIdx.x; - if (bt >= BT || recv >= R) { - return; - } - - extern __shared__ float shared[]; - float *logits = shared; - float *weights = shared + M; - float *dweight = shared + (2 * M); - float *dlogit = shared + (3 * M); - - const int neighbor_base = recv * M; - const int step_value = use_delay ? static_cast(step_flat[bt]) : 0; - if (tid == 0) { - float max_logit = -std::numeric_limits::infinity(); - for (int edge = 0; edge < M; ++edge) { - const int edge_idx = neighbor_base + edge; - const bool valid = neighbor_valid[edge_idx] && - (!use_delay || - static_cast(edge_delay[edge_idx]) <= step_value); - if (!valid) { - logits[edge] = -std::numeric_limits::infinity(); - weights[edge] = 0.0f; - dweight[edge] = 0.0f; - dlogit[edge] = 0.0f; - continue; - } - const int send = static_cast(neighbor_idx[edge_idx]); - const int q_offset = recv * d_k; - const int k_offset = (bt * S + send) * d_k; - float dot = 0.0f; - for (int d = 0; d < d_k; ++d) { - dot += q[q_offset + d] * k_all[k_offset + d]; - } - const float logit = dot * inv_sqrt_dk - distance_scale * edge_distance[edge_idx]; - logits[edge] = logit; - max_logit = logit > max_logit ? 
logit : max_logit; - } - - if (!isfinite(max_logit)) { - for (int edge = 0; edge < M; ++edge) { - weights[edge] = 0.0f; - dweight[edge] = 0.0f; - dlogit[edge] = 0.0f; - } - } else { - float sum = 0.0f; - for (int edge = 0; edge < M; ++edge) { - const float weight = isfinite(logits[edge]) ? expf(logits[edge] - max_logit) : 0.0f; - weights[edge] = weight; - sum += weight; - } - const float inv_sum = sum > 0.0f ? 1.0f / sum : 0.0f; - float weighted_sum = 0.0f; - const int grad_offset = (bt * R + recv) * d_v; - for (int edge = 0; edge < M; ++edge) { - weights[edge] *= inv_sum; - if (weights[edge] == 0.0f) { - dweight[edge] = 0.0f; - continue; - } - const int send = static_cast(neighbor_idx[neighbor_base + edge]); - const int v_offset = (bt * S + send) * d_v; - float dot = 0.0f; - for (int d = 0; d < d_v; ++d) { - dot += grad_msg[grad_offset + d] * v_all[v_offset + d]; - } - dweight[edge] = dot; - weighted_sum += weights[edge] * dot; - } - for (int edge = 0; edge < M; ++edge) { - dlogit[edge] = weights[edge] * (dweight[edge] - weighted_sum); - } - } - } - __syncthreads(); - - const int grad_offset = (bt * R + recv) * d_v; - for (int d = tid; d < d_v; d += blockDim.x) { - const float grad_out = grad_msg[grad_offset + d]; - for (int edge = 0; edge < M; ++edge) { - const float weight = weights[edge]; - if (weight == 0.0f) { - continue; - } - const int send = static_cast(neighbor_idx[neighbor_base + edge]); - atomicAdd(&grad_v[(bt * S + send) * d_v + d], weight * grad_out); - } - } - - for (int d = tid; d < d_k; d += blockDim.x) { - const float q_val = q[recv * d_k + d]; - for (int edge = 0; edge < M; ++edge) { - const float dl = dlogit[edge]; - if (dl == 0.0f) { - continue; - } - const int send = static_cast(neighbor_idx[neighbor_base + edge]); - atomicAdd(&grad_k[(bt * S + send) * d_k + d], dl * inv_sqrt_dk * q_val); - } - } -} - -} // namespace - -namespace { - -void check_sparse_message_backward_inputs( - const at::Tensor &grad_msg, const at::Tensor &q, const at::Tensor &k_all, - const at::Tensor &v_all, const at::Tensor &neighbor_idx, - const at::Tensor &neighbor_valid, const at::Tensor &edge_distance, - const at::Tensor &edge_delay, const at::Tensor &step_flat) { - check_cuda_tensor(grad_msg, "grad_msg"); - check_cuda_tensor(q, "q"); - check_cuda_tensor(k_all, "k_all"); - check_cuda_tensor(v_all, "v_all"); - check_cuda_tensor(neighbor_idx, "neighbor_idx"); - check_cuda_tensor(neighbor_valid, "neighbor_valid"); - check_cuda_tensor(edge_distance, "edge_distance"); - check_cuda_tensor(edge_delay, "edge_delay"); - check_cuda_tensor(step_flat, "step_flat"); - TORCH_CHECK(grad_msg.scalar_type() == at::kFloat, - "grad_msg must be float32"); - TORCH_CHECK(q.scalar_type() == at::kFloat, "q must be float32"); - TORCH_CHECK(k_all.scalar_type() == at::kFloat, "k_all must be float32"); - TORCH_CHECK(v_all.scalar_type() == at::kFloat, "v_all must be float32"); - - TORCH_CHECK(q.dim() == 2, "q must have shape [R, d_k]"); - TORCH_CHECK(k_all.dim() == 3, "k_all must have shape [BT, S, d_k]"); - TORCH_CHECK(v_all.dim() == 3, "v_all must have shape [BT, S, d_v]"); - TORCH_CHECK(neighbor_idx.dim() == 2, - "neighbor_idx must have shape [R, M]"); - TORCH_CHECK(neighbor_valid.sizes() == neighbor_idx.sizes(), - "neighbor_valid must match neighbor_idx shape"); - TORCH_CHECK(edge_distance.sizes() == neighbor_idx.sizes(), - "edge_distance must match neighbor_idx shape"); - TORCH_CHECK(edge_delay.sizes() == neighbor_idx.sizes(), - "edge_delay must match neighbor_idx shape"); - TORCH_CHECK(step_flat.dim() == 1, 
"step_flat must have shape [BT]"); - TORCH_CHECK(k_all.size(0) == step_flat.size(0), - "step_flat length must equal BT"); - TORCH_CHECK(q.size(0) == neighbor_idx.size(0), - "q and neighbor_idx must agree on receiver count"); - TORCH_CHECK(k_all.size(1) == v_all.size(1), - "k_all and v_all must agree on sender count"); - TORCH_CHECK(q.size(1) == k_all.size(2), - "q and k_all must agree on d_k"); - TORCH_CHECK(grad_msg.dim() == 3, "grad_msg must have shape [BT, R, d_v]"); - TORCH_CHECK(grad_msg.size(0) == k_all.size(0), - "grad_msg and k_all must agree on BT"); - TORCH_CHECK(grad_msg.size(1) == q.size(0), - "grad_msg and q must agree on receiver count"); - TORCH_CHECK(grad_msg.size(2) == v_all.size(2), - "grad_msg and v_all must agree on d_v"); -} - -void check_sparse_message_partitioned_inputs( - const at::Tensor& q, - const at::Tensor& input_k, - const at::Tensor& input_v, - const at::Tensor& recurrent_k, - const at::Tensor& recurrent_v, - const at::Tensor& neighbor_idx, - const at::Tensor& neighbor_valid, - const at::Tensor& edge_distance, - const at::Tensor& edge_delay, - const at::Tensor& step_flat) { - check_cuda_tensor(q, "q"); - check_cuda_tensor(input_k, "input_k"); - check_cuda_tensor(input_v, "input_v"); - check_cuda_tensor(recurrent_k, "recurrent_k"); - check_cuda_tensor(recurrent_v, "recurrent_v"); - check_cuda_tensor(neighbor_idx, "neighbor_idx"); - check_cuda_tensor(neighbor_valid, "neighbor_valid"); - check_cuda_tensor(edge_distance, "edge_distance"); - check_cuda_tensor(edge_delay, "edge_delay"); - check_cuda_tensor(step_flat, "step_flat"); - TORCH_CHECK(q.scalar_type() == at::kFloat, "q must be float32"); - TORCH_CHECK(input_k.scalar_type() == at::kFloat, "input_k must be float32"); - TORCH_CHECK(input_v.scalar_type() == at::kFloat, "input_v must be float32"); - TORCH_CHECK(recurrent_k.scalar_type() == at::kFloat, "recurrent_k must be float32"); - TORCH_CHECK(recurrent_v.scalar_type() == at::kFloat, "recurrent_v must be float32"); - TORCH_CHECK(q.dim() == 2, "q must have shape [R, d_k]"); - TORCH_CHECK(input_k.dim() == 3, "input_k must have shape [BT, Si, d_k]"); - TORCH_CHECK(input_v.dim() == 3, "input_v must have shape [BT, Si, d_v]"); - TORCH_CHECK(recurrent_k.dim() == 3, "recurrent_k must have shape [BT, Sr, d_k]"); - TORCH_CHECK(recurrent_v.dim() == 3, "recurrent_v must have shape [BT, Sr, d_v]"); - TORCH_CHECK(neighbor_idx.dim() == 2, "neighbor_idx must have shape [R, M]"); - TORCH_CHECK(neighbor_valid.sizes() == neighbor_idx.sizes(), - "neighbor_valid must match neighbor_idx shape"); - TORCH_CHECK(edge_distance.sizes() == neighbor_idx.sizes(), - "edge_distance must match neighbor_idx shape"); - TORCH_CHECK(edge_delay.sizes() == neighbor_idx.sizes(), - "edge_delay must match neighbor_idx shape"); - TORCH_CHECK(step_flat.dim() == 1, "step_flat must have shape [BT]"); - TORCH_CHECK(input_k.size(0) == input_v.size(0) && - input_k.size(0) == recurrent_k.size(0) && - input_k.size(0) == recurrent_v.size(0), - "partitioned sender banks must agree on BT"); - TORCH_CHECK(step_flat.size(0) == input_k.size(0), - "step_flat length must equal BT"); - TORCH_CHECK(input_k.size(2) == q.size(1) && recurrent_k.size(2) == q.size(1), - "partitioned K banks must match q d_k"); - TORCH_CHECK(input_v.size(2) == recurrent_v.size(2), - "partitioned V banks must agree on d_v"); - TORCH_CHECK(q.size(0) == neighbor_idx.size(0), - "q and neighbor_idx must agree on receiver count"); -} - -std::pair merge_partitioned_sender_banks_cuda( - const at::Tensor& input_k, - const at::Tensor& input_v, - 
const at::Tensor& recurrent_k, - const at::Tensor& recurrent_v) { - const int BT = static_cast(input_k.size(0)); - const int input_senders = static_cast(input_k.size(1)); - const int recurrent_senders = static_cast(recurrent_k.size(1)); - const int d_k = static_cast(input_k.size(2)); - const int d_v = static_cast(input_v.size(2)); - auto k_all = at::empty({BT, input_senders + recurrent_senders, d_k}, input_k.options()); - auto v_all = at::empty({BT, input_senders + recurrent_senders, d_v}, input_v.options()); - const int threads = 256; - const int total_k = std::max(1, BT * (input_senders + recurrent_senders) * d_k); - const int total_v = std::max(1, BT * (input_senders + recurrent_senders) * d_v); - const int blocks_k = std::max(1, std::min(65535, (total_k + threads - 1) / threads)); - const int blocks_v = std::max(1, std::min(65535, (total_v + threads - 1) / threads)); - auto stream = at::cuda::getCurrentCUDAStream(); - merge_partitioned_sender_bank_kernel<<>>( - input_k.data_ptr(), - recurrent_k.data_ptr(), - k_all.data_ptr(), - BT, - input_senders, - recurrent_senders, - d_k); - check_launch("merge_partitioned_sender_bank_kernel(k)"); - merge_partitioned_sender_bank_kernel<<>>( - input_v.data_ptr(), - recurrent_v.data_ptr(), - v_all.data_ptr(), - BT, - input_senders, - recurrent_senders, - d_v); - check_launch("merge_partitioned_sender_bank_kernel(v)"); - return {k_all, v_all}; -} - -dim3 sparse_message_grid(int BT, int R) { - const unsigned int recv_tiles = - static_cast((R + kMaxGridY - 1) / kMaxGridY); - TORCH_CHECK(recv_tiles <= 65535, - "receiver tiling exceeded CUDA grid.z limit"); - return dim3(BT, std::min(R, kMaxGridY), recv_tiles); -} - -} // namespace - -std::vector fabric_sparse_message_forward_cuda( - at::Tensor q, at::Tensor k_all, at::Tensor v_all, at::Tensor neighbor_idx, - at::Tensor neighbor_valid, at::Tensor edge_distance, at::Tensor edge_delay, - at::Tensor step_flat, double distance_scale, bool use_delay) { - check_cuda_tensor(q, "q"); - check_cuda_tensor(k_all, "k_all"); - check_cuda_tensor(v_all, "v_all"); - check_cuda_tensor(neighbor_idx, "neighbor_idx"); - check_cuda_tensor(neighbor_valid, "neighbor_valid"); - check_cuda_tensor(edge_distance, "edge_distance"); - check_cuda_tensor(edge_delay, "edge_delay"); - check_cuda_tensor(step_flat, "step_flat"); - TORCH_CHECK(q.scalar_type() == at::kFloat, "q must be float32"); - TORCH_CHECK(k_all.scalar_type() == at::kFloat, "k_all must be float32"); - TORCH_CHECK(v_all.scalar_type() == at::kFloat, "v_all must be float32"); - TORCH_CHECK(edge_distance.scalar_type() == at::kFloat, - "edge_distance must be float32"); - TORCH_CHECK(q.dim() == 2, "q must have shape [N, d_k]"); - TORCH_CHECK(k_all.dim() == 3, "k_all must have shape [BT, N, d_k]"); - TORCH_CHECK(v_all.dim() == 3, "v_all must have shape [BT, N, d_v]"); - TORCH_CHECK(neighbor_idx.dim() == 2, - "neighbor_idx must have shape [N, M]"); - TORCH_CHECK(neighbor_valid.sizes() == neighbor_idx.sizes(), - "neighbor_valid must match neighbor_idx shape"); - TORCH_CHECK(edge_distance.sizes() == neighbor_idx.sizes(), - "edge_distance must match neighbor_idx shape"); - TORCH_CHECK(edge_delay.sizes() == neighbor_idx.sizes(), - "edge_delay must match neighbor_idx shape"); - TORCH_CHECK(step_flat.dim() == 1, "step_flat must have shape [BT]"); - TORCH_CHECK(k_all.size(0) == step_flat.size(0), - "step_flat length must equal BT"); - TORCH_CHECK(q.size(0) == neighbor_idx.size(0), - "q and neighbor_idx must agree on receiver count"); - TORCH_CHECK(k_all.size(1) == v_all.size(1), - 
"k_all and v_all must agree on sender count"); - TORCH_CHECK(q.size(1) == k_all.size(2), - "q and k_all must agree on d_k"); - - const int BT = static_cast(k_all.size(0)); - const int R = static_cast(q.size(0)); - const int S = static_cast(k_all.size(1)); - const int M = static_cast(neighbor_idx.size(1)); - const int d_k = static_cast(q.size(1)); - const int d_v = static_cast(v_all.size(2)); - auto out = at::zeros({BT, R, d_v}, v_all.options()); - - const unsigned int recv_tiles = - static_cast((R + kMaxGridY - 1) / kMaxGridY); - TORCH_CHECK(recv_tiles <= 65535, - "receiver tiling exceeded CUDA grid.z limit"); - const dim3 grid(BT, std::min(R, kMaxGridY), recv_tiles); - const int threads = 32; - const size_t shared_bytes = static_cast(2 * M) * sizeof(float); - auto stream = at::cuda::getCurrentCUDAStream(); - fabric_sparse_message_forward_kernel<<>>( - q.data_ptr(), k_all.data_ptr(), v_all.data_ptr(), - neighbor_idx.data_ptr(), neighbor_valid.data_ptr(), - edge_distance.data_ptr(), edge_delay.data_ptr(), - step_flat.data_ptr(), out.data_ptr(), BT, R, S, M, d_k, d_v, - 1.0f / std::sqrt(static_cast(d_k)), - static_cast(distance_scale), use_delay); - check_launch("fabric_sparse_message_forward_kernel"); - return {out}; -} - -std::vector fabric_sparse_message_forward_partitioned_cuda( - at::Tensor q, - at::Tensor input_k, - at::Tensor input_v, - at::Tensor recurrent_k, - at::Tensor recurrent_v, - at::Tensor neighbor_idx, - at::Tensor neighbor_valid, - at::Tensor edge_distance, - at::Tensor edge_delay, - at::Tensor step_flat, - double distance_scale, - bool use_delay) { - check_sparse_message_partitioned_inputs( - q, - input_k, - input_v, - recurrent_k, - recurrent_v, - neighbor_idx, - neighbor_valid, - edge_distance, - edge_delay, - step_flat); - auto [k_all, v_all] = merge_partitioned_sender_banks_cuda( - input_k, - input_v, - recurrent_k, - recurrent_v); - return fabric_sparse_message_forward_cuda( - q, - k_all, - v_all, - neighbor_idx, - neighbor_valid, - edge_distance, - edge_delay, - step_flat, - distance_scale, - use_delay); -} - -std::vector fabric_sparse_message_backward_cuda( - at::Tensor grad_msg, at::Tensor q, at::Tensor k_all, at::Tensor v_all, - at::Tensor neighbor_idx, at::Tensor neighbor_valid, at::Tensor edge_distance, - at::Tensor edge_delay, at::Tensor step_flat, double distance_scale, - bool use_delay) { - check_sparse_message_backward_inputs( - grad_msg, q, k_all, v_all, neighbor_idx, neighbor_valid, - edge_distance, edge_delay, step_flat); - - const int BT = static_cast(k_all.size(0)); - const int R = static_cast(q.size(0)); - const int S = static_cast(k_all.size(1)); - const int M = static_cast(neighbor_idx.size(1)); - const int d_k = static_cast(q.size(1)); - const int d_v = static_cast(v_all.size(2)); - auto grad_q = at::zeros_like(q); - auto grad_k = at::zeros_like(k_all); - auto grad_v = at::zeros_like(v_all); - - const dim3 grid = sparse_message_grid(BT, R); - const int threads = 32; - const size_t shared_bytes = static_cast(4 * M) * sizeof(float); - auto stream = at::cuda::getCurrentCUDAStream(); - fabric_sparse_message_backward_kernel<<>>( - grad_msg.data_ptr(), q.data_ptr(), k_all.data_ptr(), - v_all.data_ptr(), neighbor_idx.data_ptr(), - neighbor_valid.data_ptr(), edge_distance.data_ptr(), - edge_delay.data_ptr(), step_flat.data_ptr(), - grad_q.data_ptr(), grad_k.data_ptr(), - grad_v.data_ptr(), BT, R, S, M, d_k, d_v, - 1.0f / std::sqrt(static_cast(d_k)), - static_cast(distance_scale), use_delay); - check_launch("fabric_sparse_message_backward_kernel"); - return 
{grad_q, grad_k, grad_v}; -} - -std::vector fabric_sparse_message_backward_partitioned_cuda( - at::Tensor grad_msg, - at::Tensor q, - at::Tensor input_k, - at::Tensor input_v, - at::Tensor recurrent_k, - at::Tensor recurrent_v, - at::Tensor neighbor_idx, - at::Tensor neighbor_valid, - at::Tensor edge_distance, - at::Tensor edge_delay, - at::Tensor step_flat, - double distance_scale, - bool use_delay) { - check_sparse_message_partitioned_inputs( - q, - input_k, - input_v, - recurrent_k, - recurrent_v, - neighbor_idx, - neighbor_valid, - edge_distance, - edge_delay, - step_flat); - auto [k_all, v_all] = merge_partitioned_sender_banks_cuda( - input_k, - input_v, - recurrent_k, - recurrent_v); - auto grads = fabric_sparse_message_backward_cuda( - grad_msg, - q, - k_all, - v_all, - neighbor_idx, - neighbor_valid, - edge_distance, - edge_delay, - step_flat, - distance_scale, - use_delay); - const int input_senders = static_cast(input_k.size(1)); - auto grad_k_split = grads[1].split_with_sizes( - {input_senders, grads[1].size(1) - input_senders}, - 1); - auto grad_v_split = grads[2].split_with_sizes( - {input_senders, grads[2].size(1) - input_senders}, - 1); - return {grads[0], grad_k_split[0], grad_v_split[0], grad_k_split[1], grad_v_split[1]}; -} - -std::vector fabric_sparse_message_backward_receiver_cuda( - at::Tensor grad_msg, at::Tensor q, at::Tensor k_all, at::Tensor v_all, - at::Tensor neighbor_idx, at::Tensor neighbor_valid, at::Tensor edge_distance, - at::Tensor edge_delay, at::Tensor step_flat, double distance_scale, - bool use_delay) { - check_sparse_message_backward_inputs( - grad_msg, q, k_all, v_all, neighbor_idx, neighbor_valid, - edge_distance, edge_delay, step_flat); - - const int BT = static_cast(k_all.size(0)); - const int R = static_cast(q.size(0)); - const int S = static_cast(k_all.size(1)); - const int M = static_cast(neighbor_idx.size(1)); - const int d_k = static_cast(q.size(1)); - const int d_v = static_cast(v_all.size(2)); - auto grad_q = at::zeros_like(q); - - const dim3 grid = sparse_message_grid(BT, R); - const int threads = 32; - const size_t shared_bytes = static_cast(4 * M) * sizeof(float); - auto stream = at::cuda::getCurrentCUDAStream(); - fabric_sparse_message_backward_receiver_kernel<<>>( - grad_msg.data_ptr(), q.data_ptr(), k_all.data_ptr(), - v_all.data_ptr(), neighbor_idx.data_ptr(), - neighbor_valid.data_ptr(), edge_distance.data_ptr(), - edge_delay.data_ptr(), step_flat.data_ptr(), - grad_q.data_ptr(), BT, R, S, M, d_k, d_v, - 1.0f / std::sqrt(static_cast(d_k)), - static_cast(distance_scale), use_delay); - check_launch("fabric_sparse_message_backward_receiver_kernel"); - return {grad_q}; -} - -std::vector fabric_sparse_message_backward_receiver_partitioned_cuda( - at::Tensor grad_msg, - at::Tensor q, - at::Tensor input_k, - at::Tensor input_v, - at::Tensor recurrent_k, - at::Tensor recurrent_v, - at::Tensor neighbor_idx, - at::Tensor neighbor_valid, - at::Tensor edge_distance, - at::Tensor edge_delay, - at::Tensor step_flat, - double distance_scale, - bool use_delay) { - check_sparse_message_partitioned_inputs( - q, - input_k, - input_v, - recurrent_k, - recurrent_v, - neighbor_idx, - neighbor_valid, - edge_distance, - edge_delay, - step_flat); - auto [k_all, v_all] = merge_partitioned_sender_banks_cuda( - input_k, - input_v, - recurrent_k, - recurrent_v); - return fabric_sparse_message_backward_receiver_cuda( - grad_msg, - q, - k_all, - v_all, - neighbor_idx, - neighbor_valid, - edge_distance, - edge_delay, - step_flat, - distance_scale, - use_delay); -} 
- -std::vector fabric_sparse_message_backward_sender_cuda( - at::Tensor grad_msg, at::Tensor q, at::Tensor k_all, at::Tensor v_all, - at::Tensor neighbor_idx, at::Tensor neighbor_valid, at::Tensor edge_distance, - at::Tensor edge_delay, at::Tensor step_flat, double distance_scale, - bool use_delay) { - check_sparse_message_backward_inputs( - grad_msg, q, k_all, v_all, neighbor_idx, neighbor_valid, - edge_distance, edge_delay, step_flat); - - const int BT = static_cast(k_all.size(0)); - const int R = static_cast(q.size(0)); - const int S = static_cast(k_all.size(1)); - const int M = static_cast(neighbor_idx.size(1)); - const int d_k = static_cast(q.size(1)); - const int d_v = static_cast(v_all.size(2)); - auto grad_k = at::zeros_like(k_all); - auto grad_v = at::zeros_like(v_all); - - const dim3 grid = sparse_message_grid(BT, R); - const int threads = 32; - const size_t shared_bytes = static_cast(4 * M) * sizeof(float); - auto stream = at::cuda::getCurrentCUDAStream(); - fabric_sparse_message_backward_sender_kernel<<>>( - grad_msg.data_ptr(), q.data_ptr(), k_all.data_ptr(), - v_all.data_ptr(), neighbor_idx.data_ptr(), - neighbor_valid.data_ptr(), edge_distance.data_ptr(), - edge_delay.data_ptr(), step_flat.data_ptr(), - grad_k.data_ptr(), grad_v.data_ptr(), BT, R, S, M, d_k, d_v, - 1.0f / std::sqrt(static_cast(d_k)), - static_cast(distance_scale), use_delay); - check_launch("fabric_sparse_message_backward_sender_kernel"); - return {grad_k, grad_v}; -} - -std::vector fabric_sparse_message_backward_sender_partitioned_cuda( - at::Tensor grad_msg, - at::Tensor q, - at::Tensor input_k, - at::Tensor input_v, - at::Tensor recurrent_k, - at::Tensor recurrent_v, - at::Tensor neighbor_idx, - at::Tensor neighbor_valid, - at::Tensor edge_distance, - at::Tensor edge_delay, - at::Tensor step_flat, - double distance_scale, - bool use_delay) { - check_sparse_message_partitioned_inputs( - q, - input_k, - input_v, - recurrent_k, - recurrent_v, - neighbor_idx, - neighbor_valid, - edge_distance, - edge_delay, - step_flat); - auto [k_all, v_all] = merge_partitioned_sender_banks_cuda( - input_k, - input_v, - recurrent_k, - recurrent_v); - auto grads = fabric_sparse_message_backward_sender_cuda( - grad_msg, - q, - k_all, - v_all, - neighbor_idx, - neighbor_valid, - edge_distance, - edge_delay, - step_flat, - distance_scale, - use_delay); - const int input_senders = static_cast(input_k.size(1)); - auto grad_k_split = grads[0].split_with_sizes( - {input_senders, grads[0].size(1) - input_senders}, - 1); - auto grad_v_split = grads[1].split_with_sizes( - {input_senders, grads[1].size(1) - input_senders}, - 1); - return {grad_k_split[0], grad_v_split[0], grad_k_split[1], grad_v_split[1]}; -} diff --git a/src/cortical/fabric/backend/cuda/message_rules/dot_product.cuh b/src/cortical/fabric/backend/cuda/message_rules/dot_product.cuh deleted file mode 100644 index 7edf2205..00000000 --- a/src/cortical/fabric/backend/cuda/message_rules/dot_product.cuh +++ /dev/null @@ -1,69 +0,0 @@ -#pragma once - -#include - -#include "cortical/fabric/backend/cuda/nn/ir.cuh" - -namespace fabric::cuda::message_rules { - -struct DotProduct { - static fabric::cuda::nn::MessageRuleIR message_rule_ir_host( - int64_t receiver_slot_dim, - int64_t sender_public_dim, - int64_t key_dim, - int64_t value_dim, - int64_t projected_message_dim) { - if (receiver_slot_dim <= 0 || sender_public_dim <= 0 || key_dim <= 0 || value_dim <= 0 || - projected_message_dim <= 0) { - throw std::invalid_argument("fabric.cuda.nn dot-product message rule dimensions must 
be positive"); - } - fabric::cuda::nn::MessageRuleBuilder builder("dot_product"); - const int receiver_slot = builder.receiver_source(fabric::cuda::nn::MessageSourceKind::ReceiverSlot); - const int sender_public = builder.sender_source( - fabric::cuda::nn::MessageSourceKind::SenderPublicPrev, - fabric::cuda::nn::ResetPolicy::ZeroSourceRows, - fabric::cuda::nn::ResetScope::BatchRow); - const int edge_distance = builder.edge_source(fabric::cuda::nn::MessageSourceKind::EdgeDistance); - const int q_weight = builder.parameter( - "q_weight", - fabric::cuda::nn::MessageParameterRole::Projection, - {receiver_slot_dim, key_dim}, - fabric::cuda::nn::MessageSharingScope::RuleGlobal); - const int k_weight = builder.parameter( - "k_weight", - fabric::cuda::nn::MessageParameterRole::Projection, - {sender_public_dim, key_dim}, - fabric::cuda::nn::MessageSharingScope::SenderGroupShared, - fabric::cuda::nn::MessageIndexMapKind::NamedGroup, - "message_tiles"); - const int v_weight = builder.parameter( - "v_weight", - fabric::cuda::nn::MessageParameterRole::Projection, - {sender_public_dim, value_dim}, - fabric::cuda::nn::MessageSharingScope::SenderGroupShared, - fabric::cuda::nn::MessageIndexMapKind::NamedGroup, - "message_tiles"); - const int out_weight = builder.parameter( - "out_weight", - fabric::cuda::nn::MessageParameterRole::Projection, - {value_dim, projected_message_dim}, - fabric::cuda::nn::MessageSharingScope::RuleGlobal); - const int q = builder.linear(receiver_slot, q_weight); - const int k = builder.linear(sender_public, k_weight); - const int v = builder.linear(sender_public, v_weight); - const int logits = builder.dot(q, k); - const int distance_biased_logits = builder.add(logits, edge_distance); - const int weights = builder.segment_softmax( - distance_biased_logits, - fabric::cuda::nn::MessageSegmentKind::ReceiverNeighborhood); - const int mixed = builder.segment_weighted_sum( - weights, - v, - fabric::cuda::nn::MessageSegmentKind::ReceiverNeighborhood); - const int projected = builder.linear(mixed, out_weight); - builder.emit_projected_message(projected); - return builder.build_message_rule(); - } -}; - -} // namespace fabric::cuda::message_rules diff --git a/src/cortical/fabric/backend/cuda/nn/ir.cuh b/src/cortical/fabric/backend/cuda/nn/ir.cuh index 1a1874ff..cdc1087b 100644 --- a/src/cortical/fabric/backend/cuda/nn/ir.cuh +++ b/src/cortical/fabric/backend/cuda/nn/ir.cuh @@ -175,6 +175,7 @@ enum class MessageSourceKind : int { ReceiverSlot = 0, ReceiverPublicPrev, ReceiverStatePrev, + SenderSlot, SenderPublicPrev, InputPublic, ReceiverCoord, @@ -222,6 +223,7 @@ enum class MessageOpKind : int { Bias, Add, Mul, + Concat, Fma, Dot, CosineKernel, @@ -240,6 +242,7 @@ enum class MessageOpKind : int { SegmentSoftmax, SegmentWeightedSum, EmitProjectedMessage, + Parameter, }; enum class MessageSegmentKind : int { @@ -260,11 +263,6 @@ enum class MessageTapeKind : int { SaveTopology, }; -enum class MessageRuleLoweringKind : int { - DotProductSegmentSoftmaxWeightedSum = 0, - Unsupported, -}; - enum class StateEpiloguePolicy : int { Separate = 0, FusedNoReductionSameChunk, @@ -446,7 +444,7 @@ struct LoweredMessageBucket { }; struct LoweredMessageRule { - MessageRuleLoweringKind lowering_kind; + int lowering_kind; LoweredMessageBucket bucket; }; @@ -475,6 +473,7 @@ struct MessageRuleNode { MessageSegmentKind segment_kind; MessageBackwardPolicy backward_policy; MessageTapeKind tape_kind; + std::vector parameter_indices; }; struct MessageRuleIR { @@ -879,6 +878,7 @@ inline bool 
is_supported_message_rule_source(MessageSourceKind source_kind) { case MessageSourceKind::ReceiverSlot: case MessageSourceKind::ReceiverPublicPrev: case MessageSourceKind::ReceiverStatePrev: + case MessageSourceKind::SenderSlot: case MessageSourceKind::SenderPublicPrev: case MessageSourceKind::InputPublic: case MessageSourceKind::ReceiverCoord: @@ -902,6 +902,7 @@ inline bool is_supported_message_rule_op(MessageOpKind op_kind) { case MessageOpKind::Bias: case MessageOpKind::Add: case MessageOpKind::Mul: + case MessageOpKind::Concat: case MessageOpKind::Fma: case MessageOpKind::Dot: case MessageOpKind::Exp: @@ -915,6 +916,7 @@ inline bool is_supported_message_rule_op(MessageOpKind op_kind) { case MessageOpKind::SegmentSoftmax: case MessageOpKind::SegmentWeightedSum: case MessageOpKind::EmitProjectedMessage: + case MessageOpKind::Parameter: return true; case MessageOpKind::CosineKernel: case MessageOpKind::Lookup: @@ -958,6 +960,20 @@ inline void validate_message_rule_parameter(const MessageRuleParameter& paramete } } +inline int message_rule_node_parameter_count(const MessageRuleNode& node) { + if (!node.parameter_indices.empty()) { + return static_cast(node.parameter_indices.size()); + } + return node.parameter_index >= 0 ? 1 : 0; +} + +inline int message_rule_node_parameter_at(const MessageRuleNode& node, int index) { + if (!node.parameter_indices.empty()) { + return node.parameter_indices[static_cast(index)]; + } + return node.parameter_index; +} + inline void validate_message_rule_node(const MessageRuleNode& node, int source_count, int parameter_count, int node_count) { if (!is_supported_message_rule_op(node.kind)) { throw std::invalid_argument("unsupported fabric.cuda.nn message rule op"); @@ -968,6 +984,16 @@ inline void validate_message_rule_node(const MessageRuleNode& node, int source_c if (node.parameter_index >= parameter_count) { throw std::invalid_argument("fabric.cuda.nn message rule node parameter index is out of range"); } + for (const int parameter_index : node.parameter_indices) { + if (parameter_index < 0 || parameter_index >= parameter_count) { + throw std::invalid_argument("fabric.cuda.nn message rule node parameter index list is out of range"); + } + } + if (!node.parameter_indices.empty() && node.parameter_index >= 0 && + node.parameter_index != node.parameter_indices.front()) { + throw std::invalid_argument( + "fabric.cuda.nn message rule node primary parameter must match the first parameter binding"); + } if (node.lhs >= node_count || node.rhs >= node_count) { throw std::invalid_argument("fabric.cuda.nn message rule node dependency is out of range"); } @@ -977,16 +1003,31 @@ inline void validate_message_rule_node(const MessageRuleNode& node, int source_c if (node.kind != MessageOpKind::Source && node.source_index >= 0) { throw std::invalid_argument("fabric.cuda.nn message non-source nodes cannot bind a source directly"); } - if (node.kind == MessageOpKind::Linear && (node.lhs < 0 || node.parameter_index < 0)) { + const int parameter_binding_count = message_rule_node_parameter_count(node); + if (node.kind == MessageOpKind::Parameter && parameter_binding_count != 1) { + throw std::invalid_argument("fabric.cuda.nn message parameter nodes require a parameter"); + } + if (node.kind == MessageOpKind::Parameter && (node.lhs >= 0 || node.rhs >= 0 || node.source_index >= 0)) { + throw std::invalid_argument("fabric.cuda.nn message parameter nodes cannot bind sources or dependencies"); + } + if (node.kind != MessageOpKind::Linear && node.kind != MessageOpKind::Bias && + 
node.kind != MessageOpKind::Parameter && parameter_binding_count > 0) { + throw std::invalid_argument("fabric.cuda.nn message node cannot bind a parameter directly"); + } + if (node.kind == MessageOpKind::Linear && (node.lhs < 0 || parameter_binding_count <= 0)) { throw std::invalid_argument("fabric.cuda.nn message rule linear nodes require input and parameter"); } + if (node.kind == MessageOpKind::Bias && (node.lhs < 0 || parameter_binding_count <= 0)) { + throw std::invalid_argument("fabric.cuda.nn message rule bias nodes require input and parameter"); + } if ((node.kind == MessageOpKind::Dot || node.kind == MessageOpKind::Add || - node.kind == MessageOpKind::Mul || node.kind == MessageOpKind::Fma || + node.kind == MessageOpKind::Mul || node.kind == MessageOpKind::Concat || node.kind == MessageOpKind::Fma || node.kind == MessageOpKind::SegmentWeightedSum) && (node.lhs < 0 || node.rhs < 0)) { throw std::invalid_argument("fabric.cuda.nn binary message rule nodes require two inputs"); } - if ((node.kind == MessageOpKind::SegmentSoftmax || node.kind == MessageOpKind::EmitProjectedMessage) && + if ((node.kind == MessageOpKind::SegmentSoftmax || node.kind == MessageOpKind::Normalize || + node.kind == MessageOpKind::EmitProjectedMessage) && node.lhs < 0) { throw std::invalid_argument("fabric.cuda.nn unary message rule nodes require an input"); } @@ -1018,91 +1059,128 @@ inline void validate_message_rule_ir(const MessageRuleIR& rule) { } } -inline bool message_rule_source_is( - const MessageRuleIR& rule, - int node_index, - MessageSourceKind source_kind, - ResetPolicy reset_policy = ResetPolicy::None, - ResetScope reset_scope = ResetScope::None) { - if (node_index < 0 || node_index >= static_cast(rule.nodes.size())) { +struct MessageRuleSourcePattern { + MessageSourceKind kind; + ResetPolicy reset_policy; + ResetScope reset_scope; +}; + +struct MessageRuleParameterPattern { + const char* name; + MessageParameterRole role; + MessageSharingScope sharing_scope; + bool require_name; +}; + +struct MessageRuleNodePattern { + MessageOpKind kind; + int lhs; + int rhs; + int parameter_index; + int source_index; + MessageSegmentKind segment_kind; + const int* parameter_indices; + int parameter_count; +}; + +struct MessageRuleLoweringPattern { + int lowering_kind; + const char* rule_type; + const MessageRuleSourcePattern* sources; + int source_count; + const MessageRuleParameterPattern* parameters; + int parameter_count; + const MessageRuleNodePattern* nodes; + int node_count; + int projected_message_node; +}; + +inline bool message_rule_source_matches_pattern( + const MessageRuleSource& source, + const MessageRuleSourcePattern& pattern) { + return source.kind == pattern.kind && source.reset_policy == pattern.reset_policy && + source.reset_scope == pattern.reset_scope; +} + +inline bool message_rule_parameter_matches_pattern( + const MessageRuleParameter& parameter, + const MessageRuleParameterPattern& pattern) { + if (parameter.role != pattern.role || parameter.sharing_scope != pattern.sharing_scope) { + return false; + } + if (!pattern.require_name) { + return true; + } + return pattern.name != nullptr && parameter.name == pattern.name; +} + +inline bool message_rule_node_matches_pattern( + const MessageRuleNode& node, + const MessageRuleNodePattern& pattern) { + if (node.kind != pattern.kind || node.lhs != pattern.lhs || node.rhs != pattern.rhs || + node.parameter_index != pattern.parameter_index || node.source_index != pattern.source_index || + node.segment_kind != pattern.segment_kind) { return false; 
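The validation hunks above extend `validate_message_rule_node` for the new `Parameter` op, the `parameter_indices` multi-binding field, `Concat`, and `Normalize`. A minimal Python sketch of that contract, assuming only what the hunks show (the enum and field names mirror the C++ IR; the Python types are illustrative, and the source/dependency range checks and the `Fma` case are omitted for brevity):

```python
from dataclasses import dataclass, field
from enum import Enum, auto


class Op(Enum):
    SOURCE = auto()
    LINEAR = auto()
    BIAS = auto()
    PARAMETER = auto()
    ADD = auto()
    MUL = auto()
    CONCAT = auto()
    DOT = auto()
    SEGMENT_SOFTMAX = auto()
    SEGMENT_WEIGHTED_SUM = auto()
    NORMALIZE = auto()
    EMIT_PROJECTED_MESSAGE = auto()


@dataclass
class Node:
    kind: Op
    lhs: int = -1
    rhs: int = -1
    parameter_index: int = -1          # legacy single-parameter binding (primary)
    source_index: int = -1
    parameter_indices: list[int] = field(default_factory=list)  # new multi-parameter binding


def parameter_count(node: Node) -> int:
    # Mirrors message_rule_node_parameter_count: the list wins when present.
    if node.parameter_indices:
        return len(node.parameter_indices)
    return 1 if node.parameter_index >= 0 else 0


def validate_node(node: Node, n_params: int) -> None:
    for p in node.parameter_indices:
        if not 0 <= p < n_params:
            raise ValueError("parameter index list out of range")
    if node.parameter_indices and node.parameter_index >= 0 and node.parameter_index != node.parameter_indices[0]:
        raise ValueError("primary parameter must match the first parameter binding")
    binds = parameter_count(node)
    if node.kind is Op.PARAMETER and binds != 1:
        raise ValueError("parameter nodes require exactly one parameter")
    if node.kind is Op.PARAMETER and (node.lhs >= 0 or node.rhs >= 0 or node.source_index >= 0):
        raise ValueError("parameter nodes cannot bind sources or dependencies")
    if node.kind not in (Op.LINEAR, Op.BIAS, Op.PARAMETER) and binds > 0:
        raise ValueError("only linear/bias/parameter nodes may bind parameters")
    if node.kind in (Op.LINEAR, Op.BIAS) and (node.lhs < 0 or binds <= 0):
        raise ValueError("linear/bias nodes require an input and a parameter")
    if node.kind in (Op.DOT, Op.ADD, Op.MUL, Op.CONCAT, Op.SEGMENT_WEIGHTED_SUM) and (node.lhs < 0 or node.rhs < 0):
        raise ValueError("binary nodes require two inputs")
    if node.kind in (Op.SEGMENT_SOFTMAX, Op.NORMALIZE, Op.EMIT_PROJECTED_MESSAGE) and node.lhs < 0:
        raise ValueError("unary nodes require an input")
```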
} - const MessageRuleNode& node = rule.nodes[static_cast(node_index)]; - if (node.kind != MessageOpKind::Source || node.source_index < 0 || - node.source_index >= static_cast(rule.sources.size())) { + if (message_rule_node_parameter_count(node) != pattern.parameter_count) { return false; } - const MessageRuleSource& source = rule.sources[static_cast(node.source_index)]; - return source.kind == source_kind && source.reset_policy == reset_policy && source.reset_scope == reset_scope; + for (int index = 0; index < pattern.parameter_count; ++index) { + if (pattern.parameter_indices == nullptr || + message_rule_node_parameter_at(node, index) != pattern.parameter_indices[index]) { + return false; + } + } + return true; } -inline bool message_rule_parameter_is( +inline bool message_rule_matches_lowering_pattern( const MessageRuleIR& rule, - int parameter_index, - MessageParameterRole role, - MessageSharingScope sharing_scope) { - if (parameter_index < 0 || parameter_index >= static_cast(rule.parameters.size())) { + const MessageRuleLoweringPattern& pattern) { + if (static_cast(rule.sources.size()) != pattern.source_count || + static_cast(rule.parameters.size()) != pattern.parameter_count || + static_cast(rule.nodes.size()) != pattern.node_count || + rule.projected_message_node != pattern.projected_message_node) { return false; } - const MessageRuleParameter& parameter = rule.parameters[static_cast(parameter_index)]; - return parameter.role == role && parameter.sharing_scope == sharing_scope; + for (int index = 0; index < pattern.source_count; ++index) { + if (!message_rule_source_matches_pattern( + rule.sources[static_cast(index)], + pattern.sources[index])) { + return false; + } + } + for (int index = 0; index < pattern.parameter_count; ++index) { + if (!message_rule_parameter_matches_pattern( + rule.parameters[static_cast(index)], + pattern.parameters[index])) { + return false; + } + } + for (int index = 0; index < pattern.node_count; ++index) { + if (!message_rule_node_matches_pattern( + rule.nodes[static_cast(index)], + pattern.nodes[index])) { + return false; + } + } + return true; } -inline MessageRuleLoweringKind classify_message_rule_lowering(const MessageRuleIR& rule) { +#include "message_rule_lowering_catalog.cuh" + +static constexpr int kUnsupportedMessageRuleLowering = -1; + +inline int classify_message_rule_lowering(const MessageRuleIR& rule) { validate_message_rule_ir(rule); - if (rule.nodes.size() != 12 || rule.sources.size() != 3 || rule.parameters.size() != 4) { - return MessageRuleLoweringKind::Unsupported; - } - const MessageRuleNode& receiver = rule.nodes[0]; - const MessageRuleNode& sender = rule.nodes[1]; - const MessageRuleNode& distance = rule.nodes[2]; - const MessageRuleNode& q = rule.nodes[3]; - const MessageRuleNode& k = rule.nodes[4]; - const MessageRuleNode& v = rule.nodes[5]; - const MessageRuleNode& logits = rule.nodes[6]; - const MessageRuleNode& biased_logits = rule.nodes[7]; - const MessageRuleNode& weights = rule.nodes[8]; - const MessageRuleNode& mixed = rule.nodes[9]; - const MessageRuleNode& projected = rule.nodes[10]; - const MessageRuleNode& emit = rule.nodes[static_cast(rule.projected_message_node)]; - const bool matches = - message_rule_source_is(rule, 0, MessageSourceKind::ReceiverSlot) && - message_rule_source_is( - rule, - 1, - MessageSourceKind::SenderPublicPrev, - ResetPolicy::ZeroSourceRows, - ResetScope::BatchRow) && - message_rule_source_is(rule, 2, MessageSourceKind::EdgeDistance) && - receiver.kind == MessageOpKind::Source && sender.kind == 
MessageOpKind::Source && - distance.kind == MessageOpKind::Source && q.kind == MessageOpKind::Linear && q.lhs == 0 && - message_rule_parameter_is(rule, q.parameter_index, MessageParameterRole::Projection, MessageSharingScope::RuleGlobal) && - k.kind == MessageOpKind::Linear && k.lhs == 1 && - message_rule_parameter_is( - rule, - k.parameter_index, - MessageParameterRole::Projection, - MessageSharingScope::SenderGroupShared) && - v.kind == MessageOpKind::Linear && v.lhs == 1 && - message_rule_parameter_is( - rule, - v.parameter_index, - MessageParameterRole::Projection, - MessageSharingScope::SenderGroupShared) && - logits.kind == MessageOpKind::Dot && logits.lhs == 3 && logits.rhs == 4 && - biased_logits.kind == MessageOpKind::Add && biased_logits.lhs == 6 && biased_logits.rhs == 2 && - weights.kind == MessageOpKind::SegmentSoftmax && weights.lhs == 7 && - weights.segment_kind == MessageSegmentKind::ReceiverNeighborhood && - mixed.kind == MessageOpKind::SegmentWeightedSum && mixed.lhs == 8 && mixed.rhs == 5 && - mixed.segment_kind == MessageSegmentKind::ReceiverNeighborhood && - projected.kind == MessageOpKind::Linear && projected.lhs == 9 && - message_rule_parameter_is( - rule, - projected.parameter_index, - MessageParameterRole::Projection, - MessageSharingScope::RuleGlobal) && - emit.kind == MessageOpKind::EmitProjectedMessage && emit.lhs == 10; - return matches ? MessageRuleLoweringKind::DotProductSegmentSoftmaxWeightedSum : MessageRuleLoweringKind::Unsupported; + for (const MessageRuleLoweringPattern* pattern = registered_message_rule_lowering_patterns_begin(); + pattern != registered_message_rule_lowering_patterns_end(); + ++pattern) { + if (message_rule_matches_lowering_pattern(rule, *pattern)) { + return pattern->lowering_kind; + } + } + return kUnsupportedMessageRuleLowering; } inline LoweredMessageRule lower_message_rule_to_bucket( @@ -1114,8 +1192,8 @@ inline LoweredMessageRule lower_message_rule_to_bucket( int64_t V, bool same_sized_window, bool allow_grouped) { - const MessageRuleLoweringKind lowering_kind = classify_message_rule_lowering(rule); - if (lowering_kind != MessageRuleLoweringKind::DotProductSegmentSoftmaxWeightedSum) { + const int lowering_kind = classify_message_rule_lowering(rule); + if (lowering_kind == kUnsupportedMessageRuleLowering) { throw std::invalid_argument("unsupported fabric.cuda.nn message rule lowering"); } MessageBucketSignature signature = regular_local_receiver_owned_message_signature(M, degree_or_block, K, V); @@ -1336,10 +1414,40 @@ class MessageRuleBuilder { return add_node(MessageOpKind::Linear, lhs, -1, parameter_index); } + int linear(int lhs, std::vector parameter_indices) { + return add_node( + MessageOpKind::Linear, + lhs, + -1, + -1, + -1, + MessageSegmentKind::ReceiverNeighborhood, + MessageBackwardPolicy::ExplicitBackward, + MessageTapeKind::Recompute, + std::move(parameter_indices)); + } + int bias(int lhs, int parameter_index) { return add_node(MessageOpKind::Bias, lhs, -1, parameter_index); } + int bias(int lhs, std::vector parameter_indices) { + return add_node( + MessageOpKind::Bias, + lhs, + -1, + -1, + -1, + MessageSegmentKind::ReceiverNeighborhood, + MessageBackwardPolicy::ExplicitBackward, + MessageTapeKind::Recompute, + std::move(parameter_indices)); + } + + int parameter_value(int parameter_index) { + return add_node(MessageOpKind::Parameter, -1, -1, parameter_index); + } + int add(int lhs, int rhs) { return add_node(MessageOpKind::Add, lhs, rhs); } @@ -1348,6 +1456,10 @@ class MessageRuleBuilder { return 
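The hunks above replace the single hard-coded `DotProductSegmentSoftmaxWeightedSum` check with a loop over registered lowering patterns and a fail-closed `kUnsupportedMessageRuleLowering` result. A minimal Python analogue of that dispatch shape, under the assumption that each registered record carries an integer `lowering_kind` and a structural matcher (all names here are illustrative, not the library API):

```python
from dataclasses import dataclass
from typing import Any, Callable

UNSUPPORTED_MESSAGE_RULE_LOWERING = -1  # mirrors kUnsupportedMessageRuleLowering


@dataclass(frozen=True)
class LoweringPattern:
    lowering_kind: int
    rule_type: str
    matches: Callable[[Any], bool]  # structural match over sources, parameters, and nodes


REGISTERED_PATTERNS: list[LoweringPattern] = []


def register_pattern(pattern: LoweringPattern) -> None:
    REGISTERED_PATTERNS.append(pattern)


def classify_message_rule_lowering(rule: Any) -> int:
    # First structural match wins; an unmatched rule fails closed.
    for pattern in REGISTERED_PATTERNS:
        if pattern.matches(rule):
            return pattern.lowering_kind
    return UNSUPPORTED_MESSAGE_RULE_LOWERING
```

The design consequence is that adding a new lowering means registering a new record rather than editing the classifier, which is consistent with the document's rule that optimizations enter as registered strategies.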
add_node(MessageOpKind::Mul, lhs, rhs); } + int concat(int lhs, int rhs) { + return add_node(MessageOpKind::Concat, lhs, rhs); + } + int fma(int lhs, int rhs) { return add_node(MessageOpKind::Fma, lhs, rhs); } @@ -1371,6 +1483,10 @@ class MessageRuleBuilder { return add_node(MessageOpKind::SegmentWeightedSum, weights, values, -1, -1, segment_kind); } + int normalize(int lhs) { + return add_node(MessageOpKind::Normalize, lhs); + } + int emit_projected_message(int lhs) { const int node = add_node(MessageOpKind::EmitProjectedMessage, lhs); rule_.projected_message_node = node; @@ -1406,16 +1522,23 @@ class MessageRuleBuilder { int source_index = -1, MessageSegmentKind segment_kind = MessageSegmentKind::ReceiverNeighborhood, MessageBackwardPolicy backward_policy = MessageBackwardPolicy::ExplicitBackward, - MessageTapeKind tape_kind = MessageTapeKind::Recompute) { + MessageTapeKind tape_kind = MessageTapeKind::Recompute, + std::vector parameter_indices = {}) { + if (parameter_index >= 0 && parameter_indices.empty()) { + parameter_indices.push_back(parameter_index); + } + const int primary_parameter_index = + parameter_indices.empty() ? parameter_index : parameter_indices.front(); MessageRuleNode node{ kind, lhs, rhs, - parameter_index, + primary_parameter_index, source_index, segment_kind, backward_policy, tape_kind, + std::move(parameter_indices), }; validate_message_rule_node( node, diff --git a/src/cortical/fabric/backend/cuda/nn/message_rule_lowering_catalog.cuh b/src/cortical/fabric/backend/cuda/nn/message_rule_lowering_catalog.cuh new file mode 100644 index 00000000..2b341a4e --- /dev/null +++ b/src/cortical/fabric/backend/cuda/nn/message_rule_lowering_catalog.cuh @@ -0,0 +1,188 @@ +// Generated by cortical.fabric.backend.message_rules.message_rule_lowering_catalog_header_text. +// Do not edit by hand; update MessageRuleBackendSpec registrations instead. 
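The catalog header that begins here is machine-generated by `cortical.fabric.backend.message_rules.message_rule_lowering_catalog_header_text` and must not be edited by hand. A sketch of a staleness check under that assumption; the generator's signature (no arguments, returns the full header text) and the standalone-script framing are assumptions, only the dotted path and file path come from the diff:

```python
"""Drift check for the generated message-rule lowering catalog (sketch)."""
from pathlib import Path

from cortical.fabric.backend.message_rules import message_rule_lowering_catalog_header_text

CATALOG = Path("src/cortical/fabric/backend/cuda/nn/message_rule_lowering_catalog.cuh")


def main() -> int:
    # Assumed contract: the generator returns the complete header text.
    expected = message_rule_lowering_catalog_header_text()
    if CATALOG.read_text() != expected:
        print(
            "message_rule_lowering_catalog.cuh is stale; regenerate it from the "
            "MessageRuleBackendSpec registrations instead of editing it by hand"
        )
        return 1
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
```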
+ +#pragma once + +static constexpr int kDotProductLoweringId = 0; + +static const MessageRuleSourcePattern kDotProductSources[] = { + {MessageSourceKind::ReceiverSlot, ResetPolicy::None, ResetScope::None}, + {MessageSourceKind::SenderPublicPrev, ResetPolicy::ZeroSourceRows, ResetScope::BatchRow}, + {MessageSourceKind::EdgeDistance, ResetPolicy::None, ResetScope::None}, +}; + +static const MessageRuleParameterPattern kDotProductParameters[] = { + {nullptr, MessageParameterRole::Projection, MessageSharingScope::RuleGlobal, false}, + {nullptr, MessageParameterRole::Projection, MessageSharingScope::SenderGroupShared, false}, + {nullptr, MessageParameterRole::Projection, MessageSharingScope::SenderGroupShared, false}, + {nullptr, MessageParameterRole::Projection, MessageSharingScope::SenderGroupShared, false}, +}; + +static const int kDotProductNode3ParameterIndices[] = {0}; +static const int kDotProductNode4ParameterIndices[] = {1}; +static const int kDotProductNode5ParameterIndices[] = {2}; +static const int kDotProductNode10ParameterIndices[] = {3}; + +static const MessageRuleNodePattern kDotProductNodes[] = { + {MessageOpKind::Source, -1, -1, -1, 0, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Source, -1, -1, -1, 1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Source, -1, -1, -1, 2, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Linear, 0, -1, 0, -1, MessageSegmentKind::ReceiverNeighborhood, kDotProductNode3ParameterIndices, 1}, + {MessageOpKind::Linear, 1, -1, 1, -1, MessageSegmentKind::ReceiverNeighborhood, kDotProductNode4ParameterIndices, 1}, + {MessageOpKind::Linear, 1, -1, 2, -1, MessageSegmentKind::ReceiverNeighborhood, kDotProductNode5ParameterIndices, 1}, + {MessageOpKind::Dot, 3, 4, -1, -1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Add, 6, 2, -1, -1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::SegmentSoftmax, 7, -1, -1, -1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::SegmentWeightedSum, 8, 5, -1, -1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Linear, 9, -1, 3, -1, MessageSegmentKind::ReceiverNeighborhood, kDotProductNode10ParameterIndices, 1}, + {MessageOpKind::EmitProjectedMessage, 10, -1, -1, -1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, +}; + +static constexpr int kDotProductFixedSlotContextNudgeLoweringId = 1; + +static const MessageRuleSourcePattern kDotProductFixedSlotContextNudgeSources[] = { + {MessageSourceKind::ReceiverSlot, ResetPolicy::None, ResetScope::None}, + {MessageSourceKind::ReceiverPublicPrev, ResetPolicy::ZeroSourceRows, ResetScope::BatchRow}, + {MessageSourceKind::SenderSlot, ResetPolicy::None, ResetScope::None}, + {MessageSourceKind::SenderPublicPrev, ResetPolicy::ZeroSourceRows, ResetScope::BatchRow}, + {MessageSourceKind::EdgeDistance, ResetPolicy::None, ResetScope::None}, +}; + +static const MessageRuleParameterPattern kDotProductFixedSlotContextNudgeParameters[] = { + {"message_query_slot_weight", MessageParameterRole::Projection, MessageSharingScope::RuleGlobal, true}, + {"message_query_nudge_scale", MessageParameterRole::RuleScalar, MessageSharingScope::FabricGlobal, true}, + {"message_sender_slot_key_weight", MessageParameterRole::Projection, MessageSharingScope::RuleGlobal, true}, + {"message_sender_context_key", MessageParameterRole::RuleTable, MessageSharingScope::SenderLocal, true}, + {"input_sender_value_weight", 
MessageParameterRole::Projection, MessageSharingScope::SenderGroupShared, true}, + {"input_group_value_weight", MessageParameterRole::Projection, MessageSharingScope::SenderGroupShared, true}, + {"recurrent_sender_value_weight", MessageParameterRole::Projection, MessageSharingScope::SenderGroupShared, true}, + {"message_output_weight", MessageParameterRole::Projection, MessageSharingScope::RuleGlobal, true}, +}; + +static const int kDotProductFixedSlotContextNudgeNode5ParameterIndices[] = {0}; +static const int kDotProductFixedSlotContextNudgeNode6ParameterIndices[] = {6}; +static const int kDotProductFixedSlotContextNudgeNode7ParameterIndices[] = {1}; +static const int kDotProductFixedSlotContextNudgeNode10ParameterIndices[] = {2}; +static const int kDotProductFixedSlotContextNudgeNode11ParameterIndices[] = {3}; +static const int kDotProductFixedSlotContextNudgeNode13ParameterIndices[] = {4, 5, 6}; +static const int kDotProductFixedSlotContextNudgeNode18ParameterIndices[] = {7}; + +static const MessageRuleNodePattern kDotProductFixedSlotContextNudgeNodes[] = { + {MessageOpKind::Source, -1, -1, -1, 0, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Source, -1, -1, -1, 1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Source, -1, -1, -1, 2, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Source, -1, -1, -1, 3, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Source, -1, -1, -1, 4, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Linear, 0, -1, 0, -1, MessageSegmentKind::ReceiverNeighborhood, kDotProductFixedSlotContextNudgeNode5ParameterIndices, 1}, + {MessageOpKind::Linear, 1, -1, 6, -1, MessageSegmentKind::ReceiverNeighborhood, kDotProductFixedSlotContextNudgeNode6ParameterIndices, 1}, + {MessageOpKind::Parameter, -1, -1, 1, -1, MessageSegmentKind::ReceiverNeighborhood, kDotProductFixedSlotContextNudgeNode7ParameterIndices, 1}, + {MessageOpKind::Mul, 6, 7, -1, -1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Concat, 5, 8, -1, -1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Linear, 2, -1, 2, -1, MessageSegmentKind::ReceiverNeighborhood, kDotProductFixedSlotContextNudgeNode10ParameterIndices, 1}, + {MessageOpKind::Parameter, -1, -1, 3, -1, MessageSegmentKind::ReceiverNeighborhood, kDotProductFixedSlotContextNudgeNode11ParameterIndices, 1}, + {MessageOpKind::Concat, 10, 11, -1, -1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Linear, 3, -1, 4, -1, MessageSegmentKind::ReceiverNeighborhood, kDotProductFixedSlotContextNudgeNode13ParameterIndices, 3}, + {MessageOpKind::Dot, 9, 12, -1, -1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Add, 14, 4, -1, -1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::SegmentSoftmax, 15, -1, -1, -1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::SegmentWeightedSum, 16, 13, -1, -1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Linear, 17, -1, 7, -1, MessageSegmentKind::ReceiverNeighborhood, kDotProductFixedSlotContextNudgeNode18ParameterIndices, 1}, + {MessageOpKind::Normalize, 18, -1, -1, -1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::EmitProjectedMessage, 19, -1, -1, -1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, +}; + +static constexpr int kDotProductFixedSlotContextGateLoweringId = 2; + +static const 
MessageRuleSourcePattern kDotProductFixedSlotContextGateSources[] = { + {MessageSourceKind::ReceiverSlot, ResetPolicy::None, ResetScope::None}, + {MessageSourceKind::ReceiverPublicPrev, ResetPolicy::ZeroSourceRows, ResetScope::BatchRow}, + {MessageSourceKind::SenderSlot, ResetPolicy::None, ResetScope::None}, + {MessageSourceKind::SenderPublicPrev, ResetPolicy::ZeroSourceRows, ResetScope::BatchRow}, + {MessageSourceKind::EdgeDistance, ResetPolicy::None, ResetScope::None}, +}; + +static const MessageRuleParameterPattern kDotProductFixedSlotContextGateParameters[] = { + {"message_query_slot_weight", MessageParameterRole::Projection, MessageSharingScope::RuleGlobal, true}, + {"message_query_context_gate", MessageParameterRole::RuleScalar, MessageSharingScope::FabricGlobal, true}, + {"message_sender_slot_key_weight", MessageParameterRole::Projection, MessageSharingScope::RuleGlobal, true}, + {"message_sender_context_key", MessageParameterRole::RuleTable, MessageSharingScope::SenderLocal, true}, + {"input_sender_value_weight", MessageParameterRole::Projection, MessageSharingScope::SenderGroupShared, true}, + {"input_group_value_weight", MessageParameterRole::Projection, MessageSharingScope::SenderGroupShared, true}, + {"recurrent_sender_value_weight", MessageParameterRole::Projection, MessageSharingScope::SenderGroupShared, true}, + {"message_output_weight", MessageParameterRole::Projection, MessageSharingScope::RuleGlobal, true}, +}; + +static const int kDotProductFixedSlotContextGateNode5ParameterIndices[] = {0}; +static const int kDotProductFixedSlotContextGateNode6ParameterIndices[] = {6}; +static const int kDotProductFixedSlotContextGateNode7ParameterIndices[] = {1}; +static const int kDotProductFixedSlotContextGateNode10ParameterIndices[] = {2}; +static const int kDotProductFixedSlotContextGateNode11ParameterIndices[] = {3}; +static const int kDotProductFixedSlotContextGateNode13ParameterIndices[] = {4, 5, 6}; +static const int kDotProductFixedSlotContextGateNode18ParameterIndices[] = {7}; + +static const MessageRuleNodePattern kDotProductFixedSlotContextGateNodes[] = { + {MessageOpKind::Source, -1, -1, -1, 0, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Source, -1, -1, -1, 1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Source, -1, -1, -1, 2, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Source, -1, -1, -1, 3, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Source, -1, -1, -1, 4, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Linear, 0, -1, 0, -1, MessageSegmentKind::ReceiverNeighborhood, kDotProductFixedSlotContextGateNode5ParameterIndices, 1}, + {MessageOpKind::Linear, 1, -1, 6, -1, MessageSegmentKind::ReceiverNeighborhood, kDotProductFixedSlotContextGateNode6ParameterIndices, 1}, + {MessageOpKind::Parameter, -1, -1, 1, -1, MessageSegmentKind::ReceiverNeighborhood, kDotProductFixedSlotContextGateNode7ParameterIndices, 1}, + {MessageOpKind::Mul, 6, 7, -1, -1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Concat, 5, 8, -1, -1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Linear, 2, -1, 2, -1, MessageSegmentKind::ReceiverNeighborhood, kDotProductFixedSlotContextGateNode10ParameterIndices, 1}, + {MessageOpKind::Parameter, -1, -1, 3, -1, MessageSegmentKind::ReceiverNeighborhood, kDotProductFixedSlotContextGateNode11ParameterIndices, 1}, + {MessageOpKind::Concat, 10, 11, -1, -1, 
MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Linear, 3, -1, 4, -1, MessageSegmentKind::ReceiverNeighborhood, kDotProductFixedSlotContextGateNode13ParameterIndices, 3}, + {MessageOpKind::Dot, 9, 12, -1, -1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Add, 14, 4, -1, -1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::SegmentSoftmax, 15, -1, -1, -1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::SegmentWeightedSum, 16, 13, -1, -1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::Linear, 17, -1, 7, -1, MessageSegmentKind::ReceiverNeighborhood, kDotProductFixedSlotContextGateNode18ParameterIndices, 1}, + {MessageOpKind::Normalize, 18, -1, -1, -1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, + {MessageOpKind::EmitProjectedMessage, 19, -1, -1, -1, MessageSegmentKind::ReceiverNeighborhood, nullptr, 0}, +}; + +inline const MessageRuleLoweringPattern* registered_message_rule_lowering_patterns_begin() { + static const MessageRuleLoweringPattern kRegisteredMessageRuleLoweringPatterns[] = { + { + kDotProductLoweringId, + "dot_product", + kDotProductSources, + static_cast(sizeof(kDotProductSources) / sizeof(kDotProductSources[0])), + kDotProductParameters, + static_cast(sizeof(kDotProductParameters) / sizeof(kDotProductParameters[0])), + kDotProductNodes, + static_cast(sizeof(kDotProductNodes) / sizeof(kDotProductNodes[0])), + 11, + }, + { + kDotProductFixedSlotContextNudgeLoweringId, + "dot_product_fixed_slot_context_nudge", + kDotProductFixedSlotContextNudgeSources, + static_cast(sizeof(kDotProductFixedSlotContextNudgeSources) / sizeof(kDotProductFixedSlotContextNudgeSources[0])), + kDotProductFixedSlotContextNudgeParameters, + static_cast(sizeof(kDotProductFixedSlotContextNudgeParameters) / sizeof(kDotProductFixedSlotContextNudgeParameters[0])), + kDotProductFixedSlotContextNudgeNodes, + static_cast(sizeof(kDotProductFixedSlotContextNudgeNodes) / sizeof(kDotProductFixedSlotContextNudgeNodes[0])), + 20, + }, + { + kDotProductFixedSlotContextGateLoweringId, + "dot_product_fixed_slot_context_gate", + kDotProductFixedSlotContextGateSources, + static_cast(sizeof(kDotProductFixedSlotContextGateSources) / sizeof(kDotProductFixedSlotContextGateSources[0])), + kDotProductFixedSlotContextGateParameters, + static_cast(sizeof(kDotProductFixedSlotContextGateParameters) / sizeof(kDotProductFixedSlotContextGateParameters[0])), + kDotProductFixedSlotContextGateNodes, + static_cast(sizeof(kDotProductFixedSlotContextGateNodes) / sizeof(kDotProductFixedSlotContextGateNodes[0])), + 20, + }, + }; + return kRegisteredMessageRuleLoweringPatterns; +} + +inline const MessageRuleLoweringPattern* registered_message_rule_lowering_patterns_end() { + return registered_message_rule_lowering_patterns_begin() + 3; +} diff --git a/src/cortical/fabric/backend/cuda/ops/__init__.py b/src/cortical/fabric/backend/cuda/ops/__init__.py index cbb52bb5..a2d01400 100644 --- a/src/cortical/fabric/backend/cuda/ops/__init__.py +++ b/src/cortical/fabric/backend/cuda/ops/__init__.py @@ -31,6 +31,7 @@ receiver_major_affine_bias_small_batch_cuda, receiver_major_affine_bias_split_out_cuda, receiver_major_affine_cuda, + receiver_major_affine_input_backward_cuda, receiver_major_affine_out_cuda, receiver_major_affine_small_batch_backward_cuda, receiver_major_affine_small_batch_cuda, @@ -66,6 +67,7 @@ "receiver_major_affine_bias_cuda", "receiver_major_affine_bias_out_cuda", 
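All three registered patterns (`dot_product`, `dot_product_fixed_slot_context_nudge`, `dot_product_fixed_slot_context_gate`) end in the same `SegmentSoftmax` then `SegmentWeightedSum` tail over `MessageSegmentKind::ReceiverNeighborhood`. A minimal torch reference for that tail, assuming a dense regular neighborhood of fixed degree per receiver (tensor names are illustrative):

```python
import torch


def receiver_neighborhood_mix(logits: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    """Reference for the SegmentSoftmax -> SegmentWeightedSum tail.

    logits: [B, R, D]    biased attention logits over each receiver's neighborhood of degree D
    values: [B, R, D, V] per-edge value vectors
    returns: [B, R, V]   mixed message per receiver
    """
    weights = torch.softmax(logits, dim=-1)                 # SegmentSoftmax within the neighborhood
    return torch.einsum("brd,brdv->brv", weights, values)   # SegmentWeightedSum


if __name__ == "__main__":
    B, R, D, V = 2, 4, 3, 8
    mixed = receiver_neighborhood_mix(torch.randn(B, R, D), torch.randn(B, R, D, V))
    assert mixed.shape == (B, R, V)
```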
"receiver_major_affine_bias_split_out_cuda", + "receiver_major_affine_input_backward_cuda", "receiver_major_affine_bias_small_batch_backward_cuda", "receiver_major_affine_bias_small_batch_cuda", "receiver_major_affine_small_batch_backward_cuda", diff --git a/src/cortical/fabric/backend/cuda/ops/diagonal_recurrence.cuh b/src/cortical/fabric/backend/cuda/ops/diagonal_recurrence.cuh index 5b74d2b1..540900ed 100644 --- a/src/cortical/fabric/backend/cuda/ops/diagonal_recurrence.cuh +++ b/src/cortical/fabric/backend/cuda/ops/diagonal_recurrence.cuh @@ -6,7 +6,7 @@ #include #include -#include "cortical/fabric/backend/cuda/execution/common.cuh" +#include "cortical/fabric/backend/cuda/contracts/common.cuh" #include "cortical/fabric/backend/cuda/nn/ir.cuh" namespace fabric::cuda::ops { diff --git a/src/cortical/fabric/backend/cuda/ops/factorized_projection_grads_triton.py b/src/cortical/fabric/backend/cuda/ops/factorized_projection_grads_triton.py index 88bf8337..2d6acf75 100644 --- a/src/cortical/fabric/backend/cuda/ops/factorized_projection_grads_triton.py +++ b/src/cortical/fabric/backend/cuda/ops/factorized_projection_grads_triton.py @@ -437,7 +437,12 @@ def factorized_recurrent_input_projection_grads_cuda( ): raise RuntimeError("factorized recurrent input projection backward requires CUDA float32 tensors") if input_proj_weight_t.dim() != 3 or grad_fused_weight.dim() != 3 or value_to_cell_weight.dim() != 2: - raise RuntimeError("factorized recurrent input projection backward expects [R,H,P], [R,V,P], and [H,V]") + raise RuntimeError( + "factorized recurrent input projection backward expects [R,H,P], [R,V,P], and [H,V]; " + f"input_proj_weight_t={tuple(input_proj_weight_t.shape)}, " + f"grad_fused_weight={tuple(grad_fused_weight.shape)}, " + f"value_to_cell_weight={tuple(value_to_cell_weight.shape)}" + ) receivers, hidden, projected = (int(dim) for dim in input_proj_weight_t.shape) grad_receivers, value_dim, grad_projected = (int(dim) for dim in grad_fused_weight.shape) diff --git a/src/cortical/fabric/backend/cuda/ops/receiver_major_affine_triton.py b/src/cortical/fabric/backend/cuda/ops/receiver_major_affine_triton.py index 156727df..3b160c99 100644 --- a/src/cortical/fabric/backend/cuda/ops/receiver_major_affine_triton.py +++ b/src/cortical/fabric/backend/cuda/ops/receiver_major_affine_triton.py @@ -1198,6 +1198,54 @@ def receiver_major_affine_backward_cuda( return grad_input, grad_weight +def receiver_major_affine_input_backward_cuda( + input: torch.Tensor, + weight: torch.Tensor, + grad_output: torch.Tensor, + *, + block_b: int = 32, + block_k: int | None = None, + block_n: int | None = None, + num_warps: int = 4, +) -> torch.Tensor: + batch, receivers, input_dim, output_dim = _check_receiver_major_affine_backward(input, weight, grad_output) + if block_b not in {16, 32, 64}: + raise RuntimeError("block_b must be 16, 32, or 64") + resolved_block_k = block_k if block_k is not None else _block_power_of_two(input_dim, cap=64) + resolved_block_n = block_n if block_n is not None else _block_power_of_two(output_dim, cap=64) + if resolved_block_k not in {1, 2, 4, 8, 16, 32, 64}: + raise RuntimeError("block_k must be a power of two no larger than 64") + if resolved_block_n not in {1, 2, 4, 8, 16, 32, 64}: + raise RuntimeError("block_n must be a power of two no larger than 64") + + grad_input = torch.empty_like(input) + batch_blocks = triton.cdiv(batch, block_b) + grid_input = (batch_blocks * receivers, triton.cdiv(input_dim, resolved_block_k)) + 
_receiver_major_affine_backward_input_kernel[grid_input]( + grad_output, + weight, + grad_input, + batch, + receivers, + input_dim, + output_dim, + grad_output.stride(0), + grad_output.stride(1), + grad_output.stride(2), + weight.stride(0), + weight.stride(1), + weight.stride(2), + grad_input.stride(0), + grad_input.stride(1), + grad_input.stride(2), + BLOCK_B=int(block_b), + BLOCK_K=int(resolved_block_k), + BLOCK_N=int(resolved_block_n), + num_warps=int(num_warps), + ) + return grad_input + + def receiver_major_affine_bias_small_batch_backward_cuda( input: torch.Tensor, weight: torch.Tensor, @@ -1580,6 +1628,7 @@ def receiver_major_affine_small_batch_cuda( "receiver_major_affine_bias_bmm_cuda", "receiver_major_affine_bias_cuda", "receiver_major_affine_bias_out_cuda", + "receiver_major_affine_input_backward_cuda", "receiver_major_affine_small_batch_backward_cuda", "receiver_major_affine_small_batch_cuda", "receiver_major_affine_bias_small_batch_backward_cuda", diff --git a/src/cortical/fabric/backend/cuda/projection/__init__.py b/src/cortical/fabric/backend/cuda/projection/__init__.py index 35bd4999..bdec2fc8 100644 --- a/src/cortical/fabric/backend/cuda/projection/__init__.py +++ b/src/cortical/fabric/backend/cuda/projection/__init__.py @@ -1,14 +1,3 @@ -from cortical.fabric.backend.cuda.projection.grouped_projection_cuda import ( - fabric_grouped_projection_cuda, - fabric_grouped_projection_forward_cuda, -) -from cortical.fabric.backend.cuda.projection.registry import get_readout_backend, register_readout_backend +from __future__ import annotations -register_readout_backend("local_output_readout", fabric_grouped_projection_forward_cuda) - -__all__ = [ - "fabric_grouped_projection_cuda", - "fabric_grouped_projection_forward_cuda", - "get_readout_backend", - "register_readout_backend", -] +__all__: list[str] = [] diff --git a/src/cortical/fabric/backend/cuda/projection/grouped_projection_binding.cpp b/src/cortical/fabric/backend/cuda/projection/grouped_projection_binding.cpp deleted file mode 100644 index 39612130..00000000 --- a/src/cortical/fabric/backend/cuda/projection/grouped_projection_binding.cpp +++ /dev/null @@ -1,16 +0,0 @@ -#include - -std::vector fabric_grouped_projection_forward_cuda( - at::Tensor sender_cells, at::Tensor grouped_weight, int64_t group_size, - int64_t receiver_offset); - -std::vector fabric_grouped_projection_backward_input_cuda( - at::Tensor grad_output, at::Tensor grouped_weight, int64_t group_size, - int64_t receiver_offset); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &fabric_grouped_projection_forward_cuda, - "Fabric grouped projection forward (CUDA)"); - m.def("backward_input", &fabric_grouped_projection_backward_input_cuda, - "Fabric grouped projection backward-input (CUDA)"); -} diff --git a/src/cortical/fabric/backend/cuda/projection/grouped_projection_cuda.py b/src/cortical/fabric/backend/cuda/projection/grouped_projection_cuda.py deleted file mode 100644 index f2189ab6..00000000 --- a/src/cortical/fabric/backend/cuda/projection/grouped_projection_cuda.py +++ /dev/null @@ -1,262 +0,0 @@ -from __future__ import annotations - -import os - -import torch -import triton -import triton.language as tl -from torch.autograd import Function - -from cortical.native.extension_loader import safe_load_extension - -_mod_path = os.path.dirname(__file__) -_ext = None - - -def _load_ext(): - global _ext - if _ext is not None: - return _ext - _ext = safe_load_extension( - name="fabric_grouped_projection_cuda", - sources=[ - os.path.join(_mod_path, 
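The new `receiver_major_affine_input_backward_cuda` wrapper launches only the grad-input half of the existing receiver-major affine backward. A torch reference for what that kernel should reproduce, assuming the `[B, R, K]` input / `[R, K, N]` weight / `[B, R, N]` grad_output layout implied by `_check_receiver_major_affine_backward`; the comparison helper below is an illustrative test sketch, not part of the library:

```python
import torch


def receiver_major_affine_input_backward_ref(weight: torch.Tensor, grad_output: torch.Tensor) -> torch.Tensor:
    # weight: [R, K, N], grad_output: [B, R, N] -> grad_input: [B, R, K]
    # (the forward input is only needed by the kernel for shape checks, not by the math)
    return torch.einsum("brn,rkn->brk", grad_output, weight)


def check_against_triton(input: torch.Tensor, weight: torch.Tensor, grad_output: torch.Tensor) -> None:
    # Guarded comparison against the Triton kernel when CUDA is available.
    if not torch.cuda.is_available():
        return
    from cortical.fabric.backend.cuda.ops import receiver_major_affine_input_backward_cuda

    got = receiver_major_affine_input_backward_cuda(
        input.cuda().float(), weight.cuda().float(), grad_output.cuda().float()
    )
    want = receiver_major_affine_input_backward_ref(weight.cuda().float(), grad_output.cuda().float())
    torch.testing.assert_close(got, want, rtol=1e-4, atol=1e-4)
```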
"grouped_projection_binding.cpp"), - os.path.join(_mod_path, "grouped_projection_kernels.cu"), - ], - extra_cflags=["-O3"], - extra_cuda_cflags=["-O3", "-Xptxas", "-O3"], - verbose=False, - ) - return _ext - - -@triton.jit -def _grouped_projection_backward_weight_kernel( - sender_cells, - grad_output, - grad_weight, - batch_size: tl.constexpr, - num_cells: tl.constexpr, - hidden_size: tl.constexpr, - out_dim: tl.constexpr, - group_size: tl.constexpr, - BLOCK_K: tl.constexpr, - BLOCK_H: tl.constexpr, - BLOCK_M: tl.constexpr, -): - group_idx = tl.program_id(0) - h_block = tl.program_id(1) - m_block = tl.program_id(2) - offsets_k = tl.arange(0, BLOCK_K) - offsets_h = h_block * BLOCK_H + tl.arange(0, BLOCK_H) - offsets_m = m_block * BLOCK_M + tl.arange(0, BLOCK_M) - total_group_rows: tl.constexpr = batch_size * group_size - acc = tl.zeros((BLOCK_H, BLOCK_M), dtype=tl.float32) - - for k_base in range(0, total_group_rows, BLOCK_K): - group_row = k_base + offsets_k - batch_idx = group_row // group_size - local_sender_idx = group_row - batch_idx * group_size - cell_idx = group_idx * group_size + local_sender_idx - row_mask = group_row < total_group_rows - sender_offsets = (batch_idx[:, None] * num_cells + cell_idx[:, None]) * hidden_size + offsets_h[None, :] - grad_offsets = (batch_idx[:, None] * num_cells + cell_idx[:, None]) * out_dim + offsets_m[None, :] - x = tl.load( - sender_cells + sender_offsets, - mask=row_mask[:, None] & (offsets_h[None, :] < hidden_size), - other=0.0, - ) - grad = tl.load( - grad_output + grad_offsets, - mask=row_mask[:, None] & (offsets_m[None, :] < out_dim), - other=0.0, - ) - acc += tl.dot(tl.trans(x), grad, input_precision="ieee") - - weight_offsets = group_idx * hidden_size * out_dim + offsets_h[:, None] * out_dim + offsets_m[None, :] - tl.store( - grad_weight + weight_offsets, - acc, - mask=(offsets_h[:, None] < hidden_size) & (offsets_m[None, :] < out_dim), - ) - - -def _next_power_of_2(value: int) -> int: - return 1 << max(0, int(value) - 1).bit_length() - - -def _grouped_projection_backward_weight( - sender_cells: torch.Tensor, - grad_output: torch.Tensor, - *, - group_size: int, - receiver_offset: int = 0, - grouped_weight_shape: tuple[int, ...] 
| None = None, -) -> torch.Tensor: - if int(receiver_offset) != 0: - if grouped_weight_shape is None: - raise RuntimeError("receiver-offset grouped projection backward requires grouped_weight_shape") - batch_size, num_cells, hidden_size = sender_cells.shape - out_dim = grad_output.shape[-1] - grad_weight = sender_cells.new_zeros(grouped_weight_shape) - global_cells = torch.arange( - int(receiver_offset), - int(receiver_offset) + int(num_cells), - device=sender_cells.device, - dtype=torch.long, - ) - group_ids = torch.div(global_cells, int(group_size), rounding_mode="floor") - for group_id in torch.unique(group_ids, sorted=True): - mask = group_ids == group_id - group_input = sender_cells[:, mask, :].reshape(batch_size, -1, hidden_size) - group_grad = grad_output[:, mask, :].reshape(batch_size, -1, out_dim) - grad_weight[int(group_id.item())] = torch.einsum("bnh,bnm->hm", group_input, group_grad) - return grad_weight - if ( - not sender_cells.is_cuda - or not grad_output.is_cuda - or sender_cells.dtype != torch.float32 - or grad_output.dtype != torch.float32 - or sender_cells.dim() != 3 - or grad_output.dim() != 3 - or group_size <= 0 - ): - batch_size, num_cells, hidden_size = sender_cells.shape - out_dim = grad_output.shape[-1] - num_groups = num_cells // max(1, group_size) - sender_grouped = sender_cells.reshape(batch_size, num_groups, group_size, hidden_size) - grad_grouped = grad_output.reshape(batch_size, num_groups, group_size, out_dim) - return torch.einsum("bgsh,bgsm->ghm", sender_grouped, grad_grouped) - batch_size, num_cells, hidden_size = (int(dim) for dim in sender_cells.shape) - out_dim = int(grad_output.shape[-1]) - if int(grad_output.shape[0]) != batch_size or int(grad_output.shape[1]) != num_cells: - raise RuntimeError("grouped projection grad_output must match sender_cells batch/cell shape") - if num_cells % int(group_size) != 0: - raise RuntimeError("grouped projection group_size must divide sender cell count") - num_groups = num_cells // int(group_size) - sender_cells = sender_cells.contiguous() - grad_output = grad_output.contiguous() - grad_weight = torch.empty((num_groups, hidden_size, out_dim), device=sender_cells.device, dtype=sender_cells.dtype) - block_h = _next_power_of_2(min(hidden_size, 64)) - block_m = _next_power_of_2(min(out_dim, 64)) - block_k = max(16, _next_power_of_2(min(batch_size * int(group_size), 256))) - _grouped_projection_backward_weight_kernel[ - (num_groups, triton.cdiv(hidden_size, block_h), triton.cdiv(out_dim, block_m)) - ]( - sender_cells, - grad_output, - grad_weight, - batch_size, - num_cells, - hidden_size, - out_dim, - int(group_size), - BLOCK_K=block_k, - BLOCK_H=block_h, - BLOCK_M=block_m, - num_warps=4, - ) - return grad_weight - - -class _FabricGroupedProjectionCUDA(Function): - @staticmethod - def forward( - sender_cells: torch.Tensor, - grouped_weight: torch.Tensor, - group_size: int, - receiver_offset: int = 0, - ) -> torch.Tensor: - (projected,) = _load_ext().forward( - sender_cells.contiguous(), - grouped_weight.contiguous(), - int(group_size), - int(receiver_offset), - ) - return projected - - @staticmethod - def setup_context(ctx, inputs, output): - del output - sender_cells, grouped_weight, group_size, receiver_offset = inputs - ctx.save_for_backward(sender_cells, grouped_weight) - ctx.group_size = int(group_size) - ctx.receiver_offset = int(receiver_offset) - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - sender_cells, grouped_weight = ctx.saved_tensors - group_size = int(ctx.group_size) - with 
torch.profiler.record_function("fabric.backward.grouped_projection"): - grad_input = _load_ext().backward_input( - grad_output.contiguous(), - grouped_weight.contiguous(), - group_size, - int(ctx.receiver_offset), - )[0] - grad_weight = _grouped_projection_backward_weight( - sender_cells, - grad_output, - group_size=group_size, - receiver_offset=int(ctx.receiver_offset), - grouped_weight_shape=tuple(grouped_weight.shape), - ) - return grad_input, grad_weight, None, None - - -def fabric_grouped_projection_cuda( - sender_cells: torch.Tensor, - grouped_weight: torch.Tensor, - *, - group_size: int, - receiver_offset: int = 0, -) -> torch.Tensor: - return _FabricGroupedProjectionCUDA.apply(sender_cells, grouped_weight, int(group_size), int(receiver_offset)) - - -def fabric_grouped_projection_backward_cuda( - sender_cells: torch.Tensor, - grouped_weight: torch.Tensor, - grad_output: torch.Tensor, - *, - group_size: int, - receiver_offset: int = 0, -) -> tuple[torch.Tensor, torch.Tensor]: - with torch.profiler.record_function("fabric.backward.grouped_projection"): - grad_input = _load_ext().backward_input( - grad_output.contiguous(), - grouped_weight.contiguous(), - int(group_size), - int(receiver_offset), - )[0] - grad_weight = _grouped_projection_backward_weight( - sender_cells, - grad_output, - group_size=group_size, - receiver_offset=int(receiver_offset), - grouped_weight_shape=tuple(grouped_weight.shape), - ) - return grad_input, grad_weight - - -def fabric_grouped_projection_forward_cuda( - sender_cells: torch.Tensor, - grouped_weight: torch.Tensor, - *, - group_size: int, - receiver_offset: int = 0, -) -> torch.Tensor: - (projected,) = _load_ext().forward( - sender_cells.contiguous(), - grouped_weight.contiguous(), - int(group_size), - int(receiver_offset), - ) - return projected - - -__all__ = [ - "fabric_grouped_projection_backward_cuda", - "fabric_grouped_projection_cuda", - "fabric_grouped_projection_forward_cuda", -] diff --git a/src/cortical/fabric/backend/cuda/projection/grouped_projection_kernels.cu b/src/cortical/fabric/backend/cuda/projection/grouped_projection_kernels.cu deleted file mode 100644 index 8430b597..00000000 --- a/src/cortical/fabric/backend/cuda/projection/grouped_projection_kernels.cu +++ /dev/null @@ -1,176 +0,0 @@ -#include -#include -#include -#include - -#include - -namespace { - -inline void check_cuda_tensor(const at::Tensor &tensor, const char *name) { - TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA tensor"); - TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous"); -} - -inline void check_launch(const char *name) { - const cudaError_t err = cudaGetLastError(); - TORCH_CHECK(err == cudaSuccess, name, " launch failed: ", - cudaGetErrorString(err)); -} - -inline int lane_group_size(int width) { - int lane_group = 1; - const int target = std::max(1, std::min(width, 32)); - while (lane_group < target) { - lane_group <<= 1; - } - return std::min(lane_group, 32); -} - -__global__ void grouped_projection_forward_kernel( - const float *__restrict__ sender_cells, - const float *__restrict__ grouped_weight, float *__restrict__ projected, - int num_rows, - int num_cells, - int hidden_size, - int out_dim, - int group_size, - int receiver_offset) { - const int row_local = threadIdx.y; - const int lane = threadIdx.x; - const int row = blockIdx.x * blockDim.y + row_local; - if (row >= num_rows) { - return; - } - const int cell_idx = row % num_cells; - const int global_cell_idx = receiver_offset + cell_idx; - const int group_idx = global_cell_idx / 
group_size; - const float *x_row = sender_cells + static_cast(row) * hidden_size; - const float *w_group = - grouped_weight + - (static_cast(group_idx) * hidden_size * out_dim); - float *y_row = projected + static_cast(row) * out_dim; - - for (int out_idx = lane; out_idx < out_dim; out_idx += blockDim.x) { - float acc = 0.0f; - for (int hidden_idx = 0; hidden_idx < hidden_size; ++hidden_idx) { - acc += x_row[hidden_idx] * w_group[hidden_idx * out_dim + out_idx]; - } - y_row[out_idx] = acc; - } -} - -__global__ void grouped_projection_backward_input_kernel( - const float *__restrict__ grad_output, - const float *__restrict__ grouped_weight, - float *__restrict__ grad_input, int num_rows, int num_cells, - int hidden_size, int out_dim, int group_size, int receiver_offset) { - const int row_local = threadIdx.y; - const int lane = threadIdx.x; - const int row = blockIdx.x * blockDim.y + row_local; - if (row >= num_rows) { - return; - } - const int cell_idx = row % num_cells; - const int global_cell_idx = receiver_offset + cell_idx; - const int group_idx = global_cell_idx / group_size; - const float *grad_row = grad_output + static_cast(row) * out_dim; - const float *w_group = - grouped_weight + - (static_cast(group_idx) * hidden_size * out_dim); - float *grad_input_row = grad_input + static_cast(row) * hidden_size; - - for (int hidden_idx = lane; hidden_idx < hidden_size; hidden_idx += blockDim.x) { - float acc = 0.0f; - for (int out_idx = 0; out_idx < out_dim; ++out_idx) { - acc += grad_row[out_idx] * w_group[hidden_idx * out_dim + out_idx]; - } - grad_input_row[hidden_idx] = acc; - } -} - -} // namespace - -std::vector fabric_grouped_projection_forward_cuda( - at::Tensor sender_cells, at::Tensor grouped_weight, int64_t group_size, int64_t receiver_offset) { - check_cuda_tensor(sender_cells, "sender_cells"); - check_cuda_tensor(grouped_weight, "grouped_weight"); - TORCH_CHECK(sender_cells.scalar_type() == at::kFloat, - "sender_cells must be float32"); - TORCH_CHECK(grouped_weight.scalar_type() == at::kFloat, - "grouped_weight must be float32"); - TORCH_CHECK(sender_cells.dim() == 3, - "sender_cells must have shape [B,N,H]"); - TORCH_CHECK(grouped_weight.dim() == 3, - "grouped_weight must have shape [G,H,M]"); - TORCH_CHECK(group_size > 0, "group_size must be positive"); - TORCH_CHECK(receiver_offset >= 0, "receiver_offset must be non-negative"); - - const int batch_size = static_cast(sender_cells.size(0)); - const int num_cells = static_cast(sender_cells.size(1)); - const int hidden_size = static_cast(sender_cells.size(2)); - const int num_groups = static_cast(grouped_weight.size(0)); - const int out_dim = static_cast(grouped_weight.size(2)); - - TORCH_CHECK(grouped_weight.size(1) == hidden_size, - "grouped_weight hidden size must match sender_cells"); - TORCH_CHECK(receiver_offset + num_cells <= num_groups * group_size, - "grouped_weight/group_size must cover sender cell window"); - - auto projected = - at::zeros({batch_size, num_cells, out_dim}, sender_cells.options()); - const int num_rows = batch_size * num_cells; - const int lane_group = lane_group_size(out_dim); - const int rows_per_block = std::max(1, 256 / lane_group); - const dim3 block(lane_group, rows_per_block); - const int blocks = (num_rows + rows_per_block - 1) / rows_per_block; - auto stream = at::cuda::getCurrentCUDAStream(); - grouped_projection_forward_kernel<<>>( - sender_cells.data_ptr(), grouped_weight.data_ptr(), - projected.data_ptr(), num_rows, num_cells, hidden_size, out_dim, - static_cast(group_size), 
static_cast(receiver_offset)); - check_launch("grouped_projection_forward_kernel"); - return {projected}; -} - -std::vector fabric_grouped_projection_backward_input_cuda( - at::Tensor grad_output, at::Tensor grouped_weight, int64_t group_size, int64_t receiver_offset) { - check_cuda_tensor(grad_output, "grad_output"); - check_cuda_tensor(grouped_weight, "grouped_weight"); - TORCH_CHECK(grad_output.scalar_type() == at::kFloat, - "grad_output must be float32"); - TORCH_CHECK(grouped_weight.scalar_type() == at::kFloat, - "grouped_weight must be float32"); - TORCH_CHECK(grad_output.dim() == 3, - "grad_output must have shape [B,N,M]"); - TORCH_CHECK(grouped_weight.dim() == 3, - "grouped_weight must have shape [G,H,M]"); - TORCH_CHECK(group_size > 0, "group_size must be positive"); - TORCH_CHECK(receiver_offset >= 0, "receiver_offset must be non-negative"); - - const int batch_size = static_cast(grad_output.size(0)); - const int num_cells = static_cast(grad_output.size(1)); - const int out_dim = static_cast(grad_output.size(2)); - const int num_groups = static_cast(grouped_weight.size(0)); - const int hidden_size = static_cast(grouped_weight.size(1)); - - TORCH_CHECK(grouped_weight.size(2) == out_dim, - "grouped_weight output dim must match grad_output"); - TORCH_CHECK(receiver_offset + num_cells <= num_groups * group_size, - "grouped_weight/group_size must cover sender cell window"); - - auto grad_input = - at::zeros({batch_size, num_cells, hidden_size}, grad_output.options()); - const int num_rows = batch_size * num_cells; - const int lane_group = lane_group_size(hidden_size); - const int rows_per_block = std::max(1, 256 / lane_group); - const dim3 block(lane_group, rows_per_block); - const int blocks = (num_rows + rows_per_block - 1) / rows_per_block; - auto stream = at::cuda::getCurrentCUDAStream(); - grouped_projection_backward_input_kernel<<>>( - grad_output.data_ptr(), grouped_weight.data_ptr(), - grad_input.data_ptr(), num_rows, num_cells, hidden_size, out_dim, - static_cast(group_size), static_cast(receiver_offset)); - check_launch("grouped_projection_backward_input_kernel"); - return {grad_input}; -} diff --git a/src/cortical/fabric/backend/cuda/projection/input_projection_backends.cuh b/src/cortical/fabric/backend/cuda/projection/input_projection_backends.cuh index 6ea1eabb..7f576b55 100644 --- a/src/cortical/fabric/backend/cuda/projection/input_projection_backends.cuh +++ b/src/cortical/fabric/backend/cuda/projection/input_projection_backends.cuh @@ -1,12 +1,12 @@ #pragma once -#include "cortical/fabric/backend/cuda/execution/common.cuh" +#include "cortical/fabric/backend/cuda/contracts/common.cuh" namespace fabric { -__device__ inline int input_projection_dim(const TensorTable& params, int fallback_dim) { +__device__ inline int input_projection_dim(const TensorTable& params, int default_dim) { if (params.count <= 0) { - return fallback_dim; + return default_dim; } const auto weight = tensor_ref(params, 0); if (weight.ndim == 3) { @@ -15,7 +15,7 @@ __device__ inline int input_projection_dim(const TensorTable& params, int fallba if (weight.ndim == 2) { return static_cast(weight.size[0]); } - return fallback_dim; + return default_dim; } __device__ inline void project_message_to_cell_input( diff --git a/src/cortical/fabric/backend/cuda/projection/linear_readout_backend.cuh b/src/cortical/fabric/backend/cuda/projection/linear_readout_backend.cuh index a4e7cf8e..9d777f03 100644 --- a/src/cortical/fabric/backend/cuda/projection/linear_readout_backend.cuh +++ 
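For reference while the dedicated grouped-projection CUDA/Triton path is deleted: the removed kernels indexed the weight slice by `(receiver_offset + cell) // group_size`, and the removed Python fallback computed the weight gradient with a per-group einsum. A torch sketch of that math for the zero-`receiver_offset` case (any replacement has to re-enter through registered strategy rows, not through a private helper like this):

```python
import torch


def grouped_projection_reference(
    sender_cells: torch.Tensor,    # [B, N, H]
    grouped_weight: torch.Tensor,  # [G, H, M], with N == G * group_size
    group_size: int,
) -> torch.Tensor:
    batch, num_cells, hidden = sender_cells.shape
    groups = num_cells // group_size
    x = sender_cells.reshape(batch, groups, group_size, hidden)
    # Each group of consecutive sender cells shares one [H, M] weight slice.
    y = torch.einsum("bgsh,ghm->bgsm", x, grouped_weight)
    return y.reshape(batch, num_cells, -1)


def grouped_projection_grads_reference(sender_cells, grouped_weight, grad_output, group_size):
    batch, num_cells, hidden = sender_cells.shape
    groups = num_cells // group_size
    x = sender_cells.reshape(batch, groups, group_size, hidden)
    g = grad_output.reshape(batch, groups, group_size, -1)
    grad_weight = torch.einsum("bgsh,bgsm->ghm", x, g)          # matches the removed backward-weight fallback
    grad_input = torch.einsum("bgsm,ghm->bgsh", g, grouped_weight).reshape(batch, num_cells, hidden)
    return grad_input, grad_weight
```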
b/src/cortical/fabric/backend/cuda/projection/linear_readout_backend.cuh @@ -2,7 +2,7 @@ #include -#include "cortical/fabric/backend/cuda/execution/common.cuh" +#include "cortical/fabric/backend/cuda/contracts/common.cuh" namespace fabric { diff --git a/src/cortical/fabric/backend/cuda/projection/public_projection_backends.cuh b/src/cortical/fabric/backend/cuda/projection/public_projection_backends.cuh index bd32203d..d114893a 100644 --- a/src/cortical/fabric/backend/cuda/projection/public_projection_backends.cuh +++ b/src/cortical/fabric/backend/cuda/projection/public_projection_backends.cuh @@ -1,6 +1,6 @@ #pragma once -#include "cortical/fabric/backend/cuda/execution/common.cuh" +#include "cortical/fabric/backend/cuda/contracts/common.cuh" namespace fabric { @@ -49,7 +49,8 @@ __device__ inline float project_public_kv_from_raw_public( __device__ inline void publish_from_raw_public( PublicProjectionKind kind, int b, - int receiver, + int output_receiver, + int param_receiver, const float* raw_public, int raw_public_dim, const TensorTable& projection_params, @@ -62,13 +63,13 @@ __device__ inline void publish_from_raw_public( const int hidden_dim = static_cast(hidden_out.size[2]); if (kind == PublicProjectionKind::HiddenIdentity) { for (int h = 0; h < hidden_dim; ++h) { - hidden_out.at(b, receiver, h) = h < raw_public_dim ? raw_public[h] : 0.0f; + hidden_out.at(b, output_receiver, h) = h < raw_public_dim ? raw_public[h] : 0.0f; } const auto direct_weight = tensor_ref(projection_params, 0); const auto grouped_weight = tensor_ref(projection_params, 1); const auto group_size = tensor_ref(projection_params, 2).at(0); const bool use_grouped = projection_params.count > 1 && grouped_weight.ndim == 3 && grouped_weight.size[0] > 0 && group_size > 1; - const int receiver_group = use_grouped ? receiver / group_size : receiver; + const int receiver_group = use_grouped ? 
param_receiver / group_size : param_receiver; for (int d = 0; d < head_dim + value_dim; ++d) { float acc = 0.0f; if (use_grouped) { @@ -77,13 +78,13 @@ __device__ inline void publish_from_raw_public( } } else { for (int in_d = 0; in_d < raw_public_dim; ++in_d) { - acc += raw_public[in_d] * direct_weight.at(receiver, in_d, d); + acc += raw_public[in_d] * direct_weight.at(param_receiver, in_d, d); } } if (d < head_dim) { - recurrent_k_out.at(b, receiver, d) = acc; + recurrent_k_out.at(b, output_receiver, d) = acc; } else { - recurrent_v_out.at(b, receiver, d - head_dim) = acc; + recurrent_v_out.at(b, output_receiver, d - head_dim) = acc; } } return; @@ -94,21 +95,21 @@ __device__ inline void publish_from_raw_public( const auto kv_weight = tensor_ref(projection_params, 2); const auto kv_bias = tensor_ref(projection_params, 3); for (int h = 0; h < hidden_dim; ++h) { - float acc = hidden_bias.at(receiver, h); + float acc = hidden_bias.at(param_receiver, h); for (int in_d = 0; in_d < raw_public_dim; ++in_d) { - acc += raw_public[in_d] * hidden_weight.at(receiver, in_d, h); + acc += raw_public[in_d] * hidden_weight.at(param_receiver, in_d, h); } - hidden_out.at(b, receiver, h) = acc; + hidden_out.at(b, output_receiver, h) = acc; } for (int d = 0; d < head_dim + value_dim; ++d) { - float acc = kv_bias.at(receiver, d); + float acc = kv_bias.at(param_receiver, d); for (int in_d = 0; in_d < raw_public_dim; ++in_d) { - acc += raw_public[in_d] * kv_weight.at(receiver, in_d, d); + acc += raw_public[in_d] * kv_weight.at(param_receiver, in_d, d); } if (d < head_dim) { - recurrent_k_out.at(b, receiver, d) = acc; + recurrent_k_out.at(b, output_receiver, d) = acc; } else { - recurrent_v_out.at(b, receiver, d - head_dim) = acc; + recurrent_v_out.at(b, output_receiver, d - head_dim) = acc; } } } diff --git a/src/cortical/fabric/backend/cuda/projection/readout_backend_contract.cuh b/src/cortical/fabric/backend/cuda/projection/readout_backend_contract.cuh index d20554f1..abf0ca20 100644 --- a/src/cortical/fabric/backend/cuda/projection/readout_backend_contract.cuh +++ b/src/cortical/fabric/backend/cuda/projection/readout_backend_contract.cuh @@ -1,6 +1,6 @@ #pragma once -#include "cortical/fabric/backend/cuda/execution/common.cuh" +#include "cortical/fabric/backend/cuda/contracts/common.cuh" namespace fabric { diff --git a/src/cortical/fabric/backend/cuda/projection/registry.py b/src/cortical/fabric/backend/cuda/projection/registry.py deleted file mode 100644 index 8ae37682..00000000 --- a/src/cortical/fabric/backend/cuda/projection/registry.py +++ /dev/null @@ -1,35 +0,0 @@ -from __future__ import annotations - -from collections.abc import Callable -from dataclasses import dataclass - -import torch - -ReadoutBackend = Callable[..., torch.Tensor] - - -@dataclass(frozen=True) -class ReadoutBackendSpec: - name: str - backend: ReadoutBackend - - -_READOUT_BACKENDS: dict[str, ReadoutBackendSpec] = {} - - -def register_readout_backend(name: str, backend: ReadoutBackend) -> None: - _READOUT_BACKENDS[name] = ReadoutBackendSpec(name=name, backend=backend) - - -def get_readout_backend(name: str) -> ReadoutBackendSpec: - try: - return _READOUT_BACKENDS[name] - except KeyError as exc: - raise ValueError(f"Unsupported Fabric readout backend {name}") from exc - - -__all__ = [ - "get_readout_backend", - "ReadoutBackendSpec", - "register_readout_backend", -] diff --git a/src/cortical/fabric/backend/cuda/projections.py b/src/cortical/fabric/backend/cuda/projections.py deleted file mode 100644 index 45d90197..00000000 --- 
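`publish_from_raw_public` now separates the receiver index used for parameter lookup (`param_receiver`) from the receiver index used for output writes (`output_receiver`), presumably so a launch over a receiver window can write into window-local output buffers while still reading globally indexed per-receiver projection weights. A torch sketch of the `HiddenIdentity` branch under that windowed reading of the change (all names and the `receiver_offset` framing are illustrative assumptions):

```python
import torch


def publish_hidden_identity_windowed(
    raw_public: torch.Tensor,     # [B, W, P] raw public for a local receiver window of width W
    direct_weight: torch.Tensor,  # [R_total, P, K + V] per-receiver projection, globally indexed
    receiver_offset: int,         # global id of the first receiver in the window
    head_dim: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    batch, window, public_dim = raw_public.shape
    # HiddenIdentity: the hidden output copies raw public (padding/truncation handled in-kernel).
    hidden_out = raw_public.clone()
    # param_receiver = receiver_offset + output_receiver: parameters read globally, outputs written locally.
    window_weight = direct_weight[receiver_offset : receiver_offset + window]  # [W, P, K + V]
    kv = torch.einsum("bwp,wpd->bwd", raw_public, window_weight)
    recurrent_k = kv[..., :head_dim]
    recurrent_v = kv[..., head_dim:]
    return hidden_out, recurrent_k, recurrent_v
```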
a/src/cortical/fabric/backend/cuda/projections.py +++ /dev/null @@ -1,235 +0,0 @@ -from __future__ import annotations - -from typing import Optional - -import torch - -from cortical.fabric.backend.cuda import ( - fabric_grouped_projection_cuda, - fabric_grouped_projection_forward_cuda, -) -from cortical.fabric.backend.cuda.ops import ( - dense_affine_cuda, - receiver_major_affine_small_batch_cuda, -) -from cortical.fabric.backend.cuda.projection.receiver_major_gates import ( - SMALL_BATCH_RECEIVER_MAJOR_NO_BIAS_PROJECTION_MIN_WORK, - SMALL_BATCH_RECEIVER_MAJOR_PROJECTION_MAX_BATCH, -) - - -def _dense_affine_forward_or_none( - input: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - *, - group_size: int = 1, -) -> Optional[torch.Tensor]: - if ( - not torch.is_grad_enabled() - and input.is_cuda - and input.dtype == torch.float32 - and weight.is_cuda - and weight.dtype == torch.float32 - and (bias is None or (bias.is_cuda and bias.dtype == torch.float32)) - ): - return dense_affine_cuda(input, weight, bias, layout="receiver_major", group_size=group_size).output - return None - - -def _small_batch_receiver_major_projection_or_none( - input: torch.Tensor, - weight: torch.Tensor, -) -> Optional[torch.Tensor]: - if ( - torch.is_grad_enabled() - and input.is_cuda - and input.dtype == torch.float32 - and weight.is_cuda - and weight.dtype == torch.float32 - and input.dim() == 3 - and weight.dim() == 3 - and input.shape[1] == weight.shape[0] - and input.shape[2] == weight.shape[1] - and 0 < input.shape[0] <= SMALL_BATCH_RECEIVER_MAJOR_PROJECTION_MAX_BATCH - and input.shape[2] >= 32 - and weight.shape[2] >= 8 - and input.shape[1] * input.shape[2] * weight.shape[2] >= SMALL_BATCH_RECEIVER_MAJOR_NO_BIAS_PROJECTION_MIN_WORK - ): - return receiver_major_affine_small_batch_cuda(input, weight) - return None - - -def project_grouped_sender_cells( - sender_cells_step: torch.Tensor, - grouped_weight: torch.Tensor, - *, - group_size: int, - head_dim: int, - value_dim: int, -) -> torch.Tensor: - batch_size, num_cells, hidden_size = sender_cells_step.shape - num_groups = int(grouped_weight.shape[0]) - expected_cells = num_groups * group_size - if expected_cells != num_cells: - raise ValueError(f"Grouped sender projection expects {expected_cells} cells, got {num_cells}") - use_grouped_cuda = ( - sender_cells_step.is_cuda and sender_cells_step.dtype == torch.float32 and grouped_weight.dtype == torch.float32 - ) - if use_grouped_cuda: - if not torch.is_grad_enabled(): - return fabric_grouped_projection_forward_cuda( - sender_cells_step, - grouped_weight, - group_size=group_size, - ) - return fabric_grouped_projection_cuda( - sender_cells_step, - grouped_weight, - group_size=group_size, - ) - grouped_cells = ( - sender_cells_step.reshape(batch_size, num_groups, group_size, hidden_size) - .permute(1, 0, 2, 3) - .reshape(num_groups, batch_size * group_size, hidden_size) - ) - projected = torch.bmm(grouped_cells, grouped_weight) - return ( - projected.reshape( - num_groups, - batch_size, - group_size, - head_dim + value_dim, - ) - .permute(1, 0, 2, 3) - .reshape(batch_size, num_cells, head_dim + value_dim) - ) - - -def project_sender_kv_from_cells_step( - sender_cells_step: torch.Tensor, - *, - sender_input_to_kv_weight: Optional[torch.Tensor], - grouped_sender_input_to_kv_weight: Optional[torch.Tensor] = None, - sender_group_size: int = 1, - head_dim: int, - value_dim: int, - contiguous_kv: bool = False, -) -> tuple[torch.Tensor, torch.Tensor]: - if grouped_sender_input_to_kv_weight is not 
None and sender_group_size > 1: - kv_all = project_grouped_sender_cells( - sender_cells_step, - grouped_sender_input_to_kv_weight, - group_size=sender_group_size, - head_dim=head_dim, - value_dim=value_dim, - ) - else: - assert sender_input_to_kv_weight is not None - dense = _dense_affine_forward_or_none(sender_cells_step, sender_input_to_kv_weight) - if dense is not None: - kv_all = dense - else: - small_batch = _small_batch_receiver_major_projection_or_none( - sender_cells_step, - sender_input_to_kv_weight, - ) - kv_all = ( - small_batch - if small_batch is not None - else torch.bmm(sender_cells_step.transpose(0, 1), sender_input_to_kv_weight).transpose(0, 1) - ) - k, v = kv_all.split((head_dim, value_dim), dim=-1) - if contiguous_kv: - return k.contiguous(), v.contiguous() - return k, v - - -def project_sender_kv_from_cells_sequence( - sender_cells_seq: torch.Tensor, - *, - sender_input_to_kv_weight: Optional[torch.Tensor], - grouped_sender_input_to_kv_weight: Optional[torch.Tensor] = None, - sender_group_size: int = 1, - head_dim: int, - value_dim: int, -) -> tuple[torch.Tensor, torch.Tensor]: - if sender_cells_seq.dim() != 4: - raise ValueError(f"Sequence sender projection expects [B,T,N,H], got {tuple(sender_cells_seq.shape)}") - batch_size, time_steps, num_cells, hidden_size = sender_cells_seq.shape - sender_flat = sender_cells_seq.reshape(batch_size * time_steps, num_cells, hidden_size) - k_flat, v_flat = project_sender_kv_from_cells_step( - sender_flat, - sender_input_to_kv_weight=sender_input_to_kv_weight, - grouped_sender_input_to_kv_weight=grouped_sender_input_to_kv_weight, - sender_group_size=sender_group_size, - head_dim=head_dim, - value_dim=value_dim, - contiguous_kv=False, - ) - return ( - k_flat.view(batch_size, time_steps, num_cells, head_dim), - v_flat.view(batch_size, time_steps, num_cells, value_dim), - ) - - -def project_recurrent_kv_from_preproj_step( - recurrent_preproj_step: torch.Tensor, - *, - recurrent_preproj_to_kv_weight: torch.Tensor, - recurrent_preproj_to_kv_bias: torch.Tensor, - head_dim: int, - value_dim: int, -) -> tuple[torch.Tensor, torch.Tensor]: - dense = _dense_affine_forward_or_none( - recurrent_preproj_step, - recurrent_preproj_to_kv_weight, - recurrent_preproj_to_kv_bias, - ) - if dense is not None: - kv_all = dense - else: - kv_all = torch.bmm( - recurrent_preproj_step.transpose(0, 1), - recurrent_preproj_to_kv_weight, - ).transpose(0, 1) - kv_all = kv_all + recurrent_preproj_to_kv_bias.view(1, -1, head_dim + value_dim) - return kv_all.split((head_dim, value_dim), dim=-1) - - -def project_recurrent_hidden_from_preproj_step( - recurrent_preproj_step: torch.Tensor, - *, - out_proj_weight_t: torch.Tensor, - out_proj_bias: torch.Tensor, -) -> torch.Tensor: - dense = _dense_affine_forward_or_none( - recurrent_preproj_step, - out_proj_weight_t, - out_proj_bias, - ) - if dense is not None: - return dense - projected = torch.bmm( - recurrent_preproj_step.transpose(0, 1), - out_proj_weight_t, - ).transpose(0, 1) - return projected + out_proj_bias.view(1, -1, out_proj_bias.shape[-1]) - - -def project_output_cells_step_raw( - output_msg: torch.Tensor, - *, - value_to_output_weight: torch.Tensor, - output_cell_bias: torch.Tensor, - hidden_size: int, -) -> torch.Tensor: - dense = _dense_affine_forward_or_none(output_msg, value_to_output_weight, output_cell_bias) - if dense is not None: - return dense - projected = torch.bmm(output_msg.transpose(0, 1), value_to_output_weight).transpose(0, 1) - return projected + output_cell_bias.view( - 1, - -1, - 
hidden_size, - ) diff --git a/src/cortical/fabric/backend/cuda/recurrence_executor.py b/src/cortical/fabric/backend/cuda/recurrence_executor.py deleted file mode 100644 index 7fa38635..00000000 --- a/src/cortical/fabric/backend/cuda/recurrence_executor.py +++ /dev/null @@ -1,242 +0,0 @@ -from __future__ import annotations - -import math -from dataclasses import dataclass -from typing import Any, Callable - -import torch - -from cortical.fabric.backend.cuda.execution import ( - FabricExecutionRequest, - normalize_launch_request, - run_registered_execution, -) -from cortical.fabric.backend.planner import PlannedFabricExecution -from cortical.fabric.backend.reuse import ExecutionFamily, MathBackend, ReuseScope -from cortical.types import ResetMask - -_STAGEABLE_REUSE_SCOPES = frozenset( - { - ReuseScope.FABRIC_GLOBAL, - ReuseScope.GROUP_SHARED, - ReuseScope.RECEIVER_LOCAL, - ReuseScope.PORT_LOCAL, - } -) -_MIN_STAGED_LAUNCH_BLOCKS = 8 - - -@dataclass(frozen=True) -class BackendSequenceRunnerProvider: - supported_execution_families: tuple[ExecutionFamily, ...] - supported_math_backends: tuple[MathBackend, ...] - build_execution_request: Callable[..., FabricExecutionRequest] | None = None - - -def _select_execution_semantics( - *, - execution_family: ExecutionFamily, - time_steps: int, -) -> tuple[str, str]: - del time_steps - if execution_family == ExecutionFamily.SEQUENCE_MAJOR: - return "receiver_owned", "persistent_scan" - if execution_family == ExecutionFamily.RECEIVER_MAJOR: - return "receiver_owned", "persistent_scan" - if execution_family == ExecutionFamily.EDGE_MAJOR: - return "edge_owned", "persistent_scan" - raise RuntimeError(f"Unsupported Fabric execution family {execution_family.value}") - - -def normalize_resets( - resets: ResetMask | None, - *, - batch: int, - seq: int, - device: torch.device, -) -> torch.Tensor | None: - if resets is None: - return None - mask = torch.as_tensor(resets, device=device, dtype=torch.bool) - if mask.dim() == 1: - return mask.view(batch, 1).expand(batch, seq) - if mask.dim() == 2 and mask.shape == (batch, seq): - return mask - raise ValueError(f"Expected resets with shape [B] or [B,T], got {tuple(mask.shape)}") - - -def _launch_replication_factor(plan: PlannedFabricExecution) -> int: - if not plan.bucket_plans: - return 1 - return max(1, max(bucket_plan.replication_factor for bucket_plan in plan.bucket_plans)) - - -def _single_launch_plan(plan: PlannedFabricExecution): - if not plan.bucket_plans: - raise RuntimeError("Fabric CUDA launch requires at least one planned backend bucket") - first = plan.bucket_plans[0] - for bucket_plan in plan.bucket_plans[1:]: - if ( - bucket_plan.receiver_tile != first.receiver_tile - or bucket_plan.batch_tile != first.batch_tile - or bucket_plan.edge_tile != first.edge_tile - or bucket_plan.hidden_chunk != first.hidden_chunk - or bucket_plan.state_receiver_tile != first.state_receiver_tile - or bucket_plan.state_batch_tile != first.state_batch_tile - or bucket_plan.state_hidden_chunk != first.state_hidden_chunk - or bucket_plan.state_static_stage_mode != first.state_static_stage_mode - or bucket_plan.emit_receiver_tile != first.emit_receiver_tile - or bucket_plan.emit_batch_tile != first.emit_batch_tile - or bucket_plan.emit_hidden_chunk != first.emit_hidden_chunk - or bucket_plan.emit_static_stage_mode != first.emit_static_stage_mode - or bucket_plan.public_receiver_tile != first.public_receiver_tile - or bucket_plan.public_batch_tile != first.public_batch_tile - or bucket_plan.readout_mode != first.readout_mode - or 
bucket_plan.readout_port_tile != first.readout_port_tile - or bucket_plan.readout_output_chunk != first.readout_output_chunk - or bucket_plan.cell_static_stage_mode != first.cell_static_stage_mode - ): - raise RuntimeError("Fabric CUDA dispatcher does not support mixed launch tile plans in one request") - return first - - -def backend_surface_launch_policy( - runtime: Any, - *, - population_name: str, - planned_backend_execution: PlannedFabricExecution, -) -> tuple[bool, int]: - cell_spec = runtime._backend_population_specs[population_name] - execution_families = {bucket_plan.execution_family for bucket_plan in planned_backend_execution.bucket_plans} - stageable_scopes = tuple(scope for scope in cell_spec.reuse_scopes.values() if scope in _STAGEABLE_REUSE_SCOPES) - stage_receiver_static = bool(stageable_scopes) and bool( - execution_families & {ExecutionFamily.RECEIVER_MAJOR, ExecutionFamily.SEQUENCE_MAJOR} - ) - replication_factor = _launch_replication_factor(planned_backend_execution) if stage_receiver_static else 1 - return stage_receiver_static, replication_factor - - -def effective_staged_replication_factor( - *, - stage_receiver_static: bool, - replication_factor: int, - batch_size: int, - batch_rows_per_block: int, - min_launch_blocks: int = _MIN_STAGED_LAUNCH_BLOCKS, -) -> int: - if not stage_receiver_static: - return 1 - batch_tiles = max(1, math.ceil(max(1, batch_size) / max(1, batch_rows_per_block))) - return min(batch_tiles, max(replication_factor, min(min_launch_blocks, batch_tiles))) - - -def build_backend_sequence_runner( - *, - runtime: Any, - population_name: str, - provider: BackendSequenceRunnerProvider, -) -> tuple[ - tuple[ExecutionFamily, ...], - tuple[MathBackend, ...], - Callable[..., tuple[Any, ...]], -]: - def run_backend_sequence( - *, - planned_backend_execution: PlannedFabricExecution, - input_k_seq: torch.Tensor, - input_v_seq: torch.Tensor, - packed_state: Any, - initial_hidden: torch.Tensor, - initial_recurrent_k: torch.Tensor | None = None, - initial_recurrent_v: torch.Tensor | None = None, - resets: ResetMask | None, - ) -> tuple[Any, ...]: - execution_families = {bucket_plan.execution_family for bucket_plan in planned_backend_execution.bucket_plans} - math_backends = {bucket_plan.math_backend for bucket_plan in planned_backend_execution.bucket_plans} - if len(execution_families) != 1 or len(math_backends) != 1: - raise RuntimeError( - "Supported Fabric backend received a mixed execution-family/math-backend plan, " - f"execution_families={sorted(family.value for family in execution_families)} " - f"math={sorted(backend.value for backend in math_backends)}" - ) - if provider.build_execution_request is None: - raise RuntimeError(f"Backend recurrence provider for {population_name} does not build execution requests") - stage_receiver_static, replication_factor = backend_surface_launch_policy( - runtime, - population_name=population_name, - planned_backend_execution=planned_backend_execution, - ) - batch_size, time_steps = input_k_seq.shape[:2] - resets_bt = normalize_resets(resets, batch=batch_size, seq=time_steps, device=input_k_seq.device) - if resets_bt is None: - with torch.profiler.record_function("fabric.glue.default_resets_zero"): - resets_u8 = torch.zeros(batch_size, time_steps, device=input_k_seq.device, dtype=torch.uint8) - else: - with torch.profiler.record_function("fabric.glue.reset_mask_to_u8"): - resets_u8 = resets_bt.to(dtype=torch.uint8) - reset_rows_present = False if resets_bt is None else bool(resets_bt.any().item()) - execution_family = 
next(iter(execution_families)) - math_backend = next(iter(math_backends)) - launch_plan = _single_launch_plan(planned_backend_execution) - batch_rows_per_block = launch_plan.state_batch_tile - effective_replication = effective_staged_replication_factor( - stage_receiver_static=stage_receiver_static, - replication_factor=replication_factor, - batch_size=batch_size, - batch_rows_per_block=batch_rows_per_block, - ) - spatial_ownership, temporal_execution = _select_execution_semantics( - execution_family=execution_family, - time_steps=time_steps, - ) - request = provider.build_execution_request( - input_k_seq=input_k_seq, - input_v_seq=input_v_seq, - packed_state=packed_state, - initial_hidden=initial_hidden, - initial_recurrent_k=initial_recurrent_k, - initial_recurrent_v=initial_recurrent_v, - resets_u8=resets_u8, - reset_rows_present=reset_rows_present, - stage_receiver_static=stage_receiver_static, - replication_factor=effective_replication, - receiver_tile=launch_plan.receiver_tile, - batch_tile=launch_plan.batch_tile, - edge_tile=launch_plan.edge_tile, - hidden_chunk=launch_plan.hidden_chunk, - state_receiver_tile=launch_plan.state_receiver_tile, - state_batch_tile=launch_plan.state_batch_tile, - state_hidden_chunk=launch_plan.state_hidden_chunk, - state_static_stage_mode=launch_plan.state_static_stage_mode if stage_receiver_static else "disabled", - emit_receiver_tile=launch_plan.emit_receiver_tile, - emit_batch_tile=launch_plan.emit_batch_tile, - emit_hidden_chunk=launch_plan.emit_hidden_chunk, - emit_static_stage_mode=launch_plan.emit_static_stage_mode if stage_receiver_static else "disabled", - public_receiver_tile=launch_plan.public_receiver_tile, - public_batch_tile=launch_plan.public_batch_tile, - readout_mode=launch_plan.readout_mode, - readout_port_tile=launch_plan.readout_port_tile, - readout_output_chunk=launch_plan.readout_output_chunk, - cell_static_stage_mode=launch_plan.cell_static_stage_mode if stage_receiver_static else "disabled", - ) - request = normalize_launch_request(request) - return run_registered_execution( - spatial_ownership=spatial_ownership, - temporal_execution=temporal_execution, - math_backend=math_backend, - request=request, - ) - - run_backend_sequence.fabric_graph_capture_safe = False - run_backend_sequence.fabric_backend_engine_kind = "generic_dispatcher" - return provider.supported_execution_families, provider.supported_math_backends, run_backend_sequence - - -__all__ = [ - "BackendSequenceRunnerProvider", - "backend_surface_launch_policy", - "build_backend_sequence_runner", - "effective_staged_replication_factor", - "normalize_resets", - "_single_launch_plan", -] diff --git a/src/cortical/fabric/backend/cuda/reference/slstm_parity_reference.py b/src/cortical/fabric/backend/cuda/reference/slstm_parity_reference.py deleted file mode 100644 index 1a298c4c..00000000 --- a/src/cortical/fabric/backend/cuda/reference/slstm_parity_reference.py +++ /dev/null @@ -1,99 +0,0 @@ -from __future__ import annotations - -import torch -from cortical.fabric.backend.cuda.message_passing.local_message_cuda import _load_ext as _load_local_message_ext - - -def _local_message_partitioned_step_backward_manual( - *, - q: torch.Tensor, - input_k: torch.Tensor, - input_v: torch.Tensor, - recurrent_k: torch.Tensor, - recurrent_v: torch.Tensor, - receiver_sender_idx: torch.Tensor, - sender_receiver_idx: torch.Tensor, - offset_distance: torch.Tensor, - offset_delay: torch.Tensor, - grad_msg: torch.Tensor, - msg: torch.Tensor, - num_input_senders: int, - distance_scale: float, - 
use_delay: bool, - reset_mask: torch.Tensor | None = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - del msg - batch_size = input_k.shape[0] - total_senders = input_k.shape[1] + recurrent_k.shape[1] - k_all = torch.cat((input_k, recurrent_k), dim=1) - v_all = torch.cat((input_v, recurrent_v), dim=1) - grad_q = torch.zeros_like(q) - grad_k_all = torch.zeros_like(k_all) - grad_v_all = torch.zeros_like(v_all) - step_flat = torch.ones(batch_size, device=q.device, dtype=torch.long) - ext = _load_local_message_ext() - - def run_subset( - subset_mask: torch.Tensor, - *, - subset_receiver_sender_idx: torch.Tensor, - subset_sender_receiver_idx: torch.Tensor, - recurrent_sender_bank_zero_mask: torch.Tensor | None = None, - recurrent_sender_grad_zero_mask: torch.Tensor | None = None, - ) -> None: - nonlocal grad_q, grad_k_all, grad_v_all - subset_idx = torch.nonzero(subset_mask, as_tuple=False).reshape(-1) - if subset_idx.numel() == 0: - return - subset_input_k = input_k.index_select(0, subset_idx).contiguous() - subset_input_v = input_v.index_select(0, subset_idx).contiguous() - subset_recurrent_k = recurrent_k.index_select(0, subset_idx).contiguous() - subset_recurrent_v = recurrent_v.index_select(0, subset_idx).contiguous() - if recurrent_sender_bank_zero_mask is not None: - subset_bank_zero_mask = recurrent_sender_bank_zero_mask.index_select(0, subset_idx) - if bool(subset_bank_zero_mask.any()): - subset_recurrent_k[subset_bank_zero_mask] = 0 - subset_recurrent_v[subset_bank_zero_mask] = 0 - subset_grad_q, subset_grad_k_all, subset_grad_v_all = ext.backward( - grad_msg.index_select(0, subset_idx).contiguous(), - q.contiguous(), - torch.cat((subset_input_k, subset_recurrent_k), dim=1), - torch.cat((subset_input_v, subset_recurrent_v), dim=1), - subset_receiver_sender_idx.contiguous(), - subset_sender_receiver_idx.contiguous(), - offset_distance.contiguous(), - offset_delay.contiguous(), - step_flat.index_select(0, subset_idx).contiguous(), - float(distance_scale), - bool(use_delay), - ) - if recurrent_sender_grad_zero_mask is not None: - subset_grad_zero_mask = recurrent_sender_grad_zero_mask.index_select(0, subset_idx) - if bool(subset_grad_zero_mask.any()): - subset_grad_k_all[subset_grad_zero_mask, num_input_senders:] = 0 - subset_grad_v_all[subset_grad_zero_mask, num_input_senders:] = 0 - grad_q += subset_grad_q - grad_k_all.index_copy_(0, subset_idx, subset_grad_k_all) - grad_v_all.index_copy_(0, subset_idx, subset_grad_v_all) - - if reset_mask is None or not bool(reset_mask.any()): - run_subset( - torch.ones(batch_size, device=q.device, dtype=torch.bool), - subset_receiver_sender_idx=receiver_sender_idx, - subset_sender_receiver_idx=sender_receiver_idx, - ) - else: - run_subset( - torch.ones(batch_size, device=q.device, dtype=torch.bool), - subset_receiver_sender_idx=receiver_sender_idx, - subset_sender_receiver_idx=sender_receiver_idx, - recurrent_sender_bank_zero_mask=reset_mask, - recurrent_sender_grad_zero_mask=reset_mask, - ) - - grad_input_k, grad_recurrent_k = grad_k_all.split((num_input_senders, total_senders - num_input_senders), dim=1) - grad_input_v, grad_recurrent_v = grad_v_all.split((num_input_senders, total_senders - num_input_senders), dim=1) - return grad_q, grad_input_k, grad_input_v, grad_recurrent_k, grad_recurrent_v - - -__all__ = ["_local_message_partitioned_step_backward_manual"] diff --git a/src/cortical/fabric/backend/cuda/registry/__init__.py b/src/cortical/fabric/backend/cuda/registry/__init__.py deleted file mode 100644 
index c9c2ef67..00000000 --- a/src/cortical/fabric/backend/cuda/registry/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__all__: list[str] = [] diff --git a/src/cortical/fabric/backend/cuda/registry/cell_dispatch_registry.cpp b/src/cortical/fabric/backend/cuda/registry/cell_dispatch_registry.cpp deleted file mode 100644 index 88a4bcc1..00000000 --- a/src/cortical/fabric/backend/cuda/registry/cell_dispatch_registry.cpp +++ /dev/null @@ -1,44 +0,0 @@ -#include "cortical/fabric/backend/cuda/registry/cell_dispatch_registry.cuh" - -#include -#include - -namespace fabric { - -namespace { - -struct CellDispatchRegistryStorage { - std::array entries{}; - std::array registered{}; -}; - -CellDispatchRegistryStorage& registry_storage() { - static CellDispatchRegistryStorage storage; - return storage; -} - -} // namespace - -void register_cell_core_dispatch_entry(int cell_core_id, const CellCoreDispatchEntry& entry) { - auto& storage = registry_storage(); - TORCH_CHECK( - cell_core_id >= 0 && static_cast(cell_core_id) < storage.entries.size(), - "unsupported cell core id: ", - cell_core_id); - const auto index = static_cast(cell_core_id); - storage.entries[index] = entry; - storage.registered[index] = true; -} - -const CellCoreDispatchEntry& lookup_cell_core_dispatch_entry(int cell_core_id) { - auto& storage = registry_storage(); - TORCH_CHECK( - cell_core_id >= 0 && static_cast(cell_core_id) < storage.entries.size(), - "unsupported cell core id: ", - cell_core_id); - const auto index = static_cast(cell_core_id); - TORCH_CHECK(storage.registered[index], "unregistered cell core id: ", cell_core_id); - return storage.entries[index]; -} - -} // namespace fabric diff --git a/src/cortical/fabric/backend/cuda/registry/cell_dispatch_registry.cuh b/src/cortical/fabric/backend/cuda/registry/cell_dispatch_registry.cuh deleted file mode 100644 index ee136276..00000000 --- a/src/cortical/fabric/backend/cuda/registry/cell_dispatch_registry.cuh +++ /dev/null @@ -1,108 +0,0 @@ -#pragma once - -#include - -#include - -#include "cortical/fabric/backend/cuda/contracts/cell.cuh" - -namespace fabric { - -using ReceiverStateEmitFn = void (*)( - const float*, - TensorTable, - TensorTable, - TensorTable, - TensorTable, - ExecutionPlan, - const at::Tensor&, - int, - int, - int, - int, - int, - int, - float*, - float*, - float*, - int, - cudaStream_t); - -using ReceiverStateUpdateFn = void (*)( - const float*, - TensorTable, - TensorTable, - TensorTable, - TensorTable, - ExecutionPlan, - const at::Tensor&, - bool, - int, - int, - int, - int, - float*, - int, - int, - int, - int, - cudaStream_t); - -using ReceiverStateUpdateEmitFn = void (*)( - const float*, - TensorTable, - TensorTable, - TensorTable, - TensorTable, - ExecutionPlan, - const at::Tensor&, - bool, - bool, - int, - int, - int, - int, - int, - float*, - int, - int, - int, - int, - cudaStream_t); - -using ReceiverReduceEmitFn = void (*)( - TensorTable, - TensorTable, - ExecutionPlan, - int, - int, - int, - float*, - float*, - float*, - int, - int, - cudaStream_t); - -using PhaseStaticBytesFn = int (*)(const std::vector&, int); -using CellTransitionIRFn = fabric::cuda::nn::CellTransitionIR (*)( - const std::vector&, - int, - int, - int); - -struct CellCoreDispatchEntry { - ReceiverStateEmitFn receiver_state_emit; - ReceiverStateUpdateFn receiver_state_update; - ReceiverStateUpdateEmitFn receiver_state_update_emit; - ReceiverReduceEmitFn receiver_reduce_emit; - PhaseStaticBytesFn state_static_bytes; - PhaseStaticBytesFn emit_static_bytes; - CellTransitionIRFn 
cell_transition_ir;
-  int reduction_stats_dim;
-};
-
-void register_cell_core_dispatch_entry(int cell_core_id, const CellCoreDispatchEntry& entry);
-const CellCoreDispatchEntry& lookup_cell_core_dispatch_entry(int cell_core_id);
-
-} // namespace fabric
diff --git a/src/cortical/fabric/backend/cuda/registry/cell_registration_helpers.cuh b/src/cortical/fabric/backend/cuda/registry/cell_registration_helpers.cuh
deleted file mode 100644
index 2fd8bd69..00000000
--- a/src/cortical/fabric/backend/cuda/registry/cell_registration_helpers.cuh
+++ /dev/null
@@ -1,197 +0,0 @@
-#pragma once
-
-#include "cortical/fabric/backend/cuda/registry/cell_dispatch_registry.cuh"
-#include "cortical/fabric/backend/cuda/execution/receiver_owned_stepwise.cuh"
-
-namespace fabric {
-
-template <typename CellCore>
-void launch_receiver_state_emit_registered(
-    const float* projected_message,
-    TensorTable state_prev,
-    TensorTable state_next,
-    TensorTable cell_params,
-    TensorTable aux,
-    ExecutionPlan plan,
-    const at::Tensor& resets_u8,
-    int t,
-    int projected_message_dim,
-    int raw_public_dim,
-    int state_static_bytes,
-    int emit_static_bytes,
-    int state_epilogue_policy,
-    float* raw_public,
-    float* partial_stats,
-    float* reduced_stats,
-    int num_hidden_chunks,
-    cudaStream_t stream) {
-  stepwise_detail::launch_receiver_state_emit_typed<CellCore>(
-      projected_message,
-      state_prev,
-      state_next,
-      cell_params,
-      aux,
-      plan,
-      resets_u8,
-      t,
-      projected_message_dim,
-      raw_public_dim,
-      state_static_bytes,
-      emit_static_bytes,
-      state_epilogue_policy,
-      raw_public,
-      partial_stats,
-      reduced_stats,
-      num_hidden_chunks,
-      stream);
-}
-
-template <typename CellCore>
-void launch_receiver_state_update_registered(
-    const float* projected_message,
-    TensorTable state_prev,
-    TensorTable state_next,
-    TensorTable cell_params,
-    TensorTable aux,
-    ExecutionPlan plan,
-    const at::Tensor& resets_u8,
-    bool state_prev_is_zero,
-    int t,
-    int projected_message_dim,
-    int raw_public_dim,
-    int state_static_bytes,
-    float* partial_stats,
-    int num_hidden_chunks,
-    int receiver_offset,
-    int receiver_global_offset,
-    int receiver_count,
-    cudaStream_t stream) {
-  stepwise_detail::launch_receiver_state_update_typed<CellCore>(
-      projected_message,
-      state_prev,
-      state_next,
-      cell_params,
-      aux,
-      plan,
-      resets_u8,
-      state_prev_is_zero,
-      t,
-      projected_message_dim,
-      raw_public_dim,
-      state_static_bytes,
-      partial_stats,
-      num_hidden_chunks,
-      receiver_offset,
-      receiver_global_offset,
-      receiver_count,
-      stream);
-}
-
-template <typename CellCore>
-void launch_receiver_state_update_emit_registered(
-    const float* projected_message,
-    TensorTable state_prev,
-    TensorTable state_next,
-    TensorTable cell_params,
-    TensorTable aux,
-    ExecutionPlan plan,
-    const at::Tensor& resets_u8,
-    bool state_prev_is_zero,
-    bool materialize_state_output,
-    int t,
-    int projected_message_dim,
-    int raw_public_dim,
-    int state_static_bytes,
-    int emit_static_bytes,
-    float* raw_public,
-    int num_hidden_chunks,
-    int receiver_offset,
-    int receiver_global_offset,
-    int receiver_count,
-    cudaStream_t stream) {
-  stepwise_detail::launch_receiver_state_update_emit_typed<CellCore>(
-      projected_message,
-      state_prev,
-      state_next,
-      cell_params,
-      aux,
-      plan,
-      resets_u8,
-      state_prev_is_zero,
-      materialize_state_output,
-      t,
-      projected_message_dim,
-      raw_public_dim,
-      state_static_bytes,
-      emit_static_bytes,
-      raw_public,
-      num_hidden_chunks,
-      receiver_offset,
-      receiver_global_offset,
-      receiver_count,
-      stream);
-}
-
-template <typename CellCore>
-void launch_receiver_reduce_emit_registered(
-    TensorTable state_next,
-    TensorTable cell_params,
-    ExecutionPlan plan,
-    int projected_message_dim,
-    int raw_public_dim,
-    int emit_static_bytes,
-    float* raw_public,
-    float* partial_stats,
-    float* reduced_stats,
-    int num_hidden_chunks,
-    int receiver_global_offset,
-    cudaStream_t stream) {
-  stepwise_detail::launch_receiver_reduce_emit_typed<CellCore>(
-      state_next,
-      cell_params,
-      plan,
-      projected_message_dim,
-      raw_public_dim,
-      emit_static_bytes,
-      raw_public,
-      partial_stats,
-      reduced_stats,
-      num_hidden_chunks,
-      receiver_global_offset,
-      stream);
-}
-
-template <typename CellCore>
-int cell_state_static_bytes_registered(const std::vector& params, int receivers) {
-  return CellCore::state_static_bytes_host(params, receivers);
-}
-
-template <typename CellCore>
-int cell_emit_static_bytes_registered(const std::vector& params, int receivers) {
-  return CellCore::emit_static_bytes_host(params, receivers);
-}
-
-template <typename CellCore>
-fabric::cuda::nn::CellTransitionIR cell_transition_ir_registered(
-    const std::vector& params,
-    int receivers,
-    int projected_message_dim,
-    int raw_public_dim) {
-  return CellCore::cell_transition_ir_host(params, receivers, projected_message_dim, raw_public_dim);
-}
-
-template <typename CellCore>
-constexpr CellCoreDispatchEntry make_cell_core_dispatch_entry() {
-  return CellCoreDispatchEntry{
-      &launch_receiver_state_emit_registered<CellCore>,
-      &launch_receiver_state_update_registered<CellCore>,
-      &launch_receiver_state_update_emit_registered<CellCore>,
-      &launch_receiver_reduce_emit_registered<CellCore>,
-      &cell_state_static_bytes_registered<CellCore>,
-      &cell_emit_static_bytes_registered<CellCore>,
-      &cell_transition_ir_registered<CellCore>,
-      CellCore::kReductionStatsDim,
-  };
-}
-
-} // namespace fabric
diff --git a/src/cortical/fabric/backend/cuda/runtime_ops.py b/src/cortical/fabric/backend/cuda/runtime_ops.py
deleted file mode 100644
index 1e00e360..00000000
--- a/src/cortical/fabric/backend/cuda/runtime_ops.py
+++ /dev/null
@@ -1,368 +0,0 @@
-from __future__ import annotations
-
-from typing import Any
-
-import torch
-
-from cortical.fabric.backend.cuda import (
-    fabric_local_message_cuda,
-    fabric_local_message_partitioned_cuda,
-    fabric_sparse_message_cuda,
-    fabric_sparse_message_partitioned_cuda,
-)
-from cortical.fabric.backend.cuda.projections import (
-    project_grouped_sender_cells,
-    project_output_cells_step_raw,
-    project_recurrent_hidden_from_preproj_step,
-    project_recurrent_kv_from_preproj_step,
-    project_sender_kv_from_cells_sequence,
-    project_sender_kv_from_cells_step,
-)
-
-
-def _flatten_step_idx(
-    step_idx: int | torch.Tensor,
-    *,
-    batch_size: int,
-    time_steps: int,
-    device: torch.device,
-    dtype: torch.dtype,
-) -> torch.Tensor:
-    if isinstance(step_idx, int):
-        return torch.full((batch_size * time_steps,), step_idx, device=device, dtype=dtype)
-    step_tensor = torch.as_tensor(step_idx, device=device, dtype=dtype)
-    if step_tensor.dim() == 1 and step_tensor.shape[0] == batch_size:
-        return step_tensor.view(batch_size, 1).expand(batch_size, time_steps).reshape(batch_size * time_steps)
-    if step_tensor.dim() == 2 and step_tensor.shape == (batch_size, time_steps):
-        return step_tensor.reshape(batch_size * time_steps)
-    raise ValueError(f"step_idx tensor must have shape [B] or [B,T], got {tuple(step_tensor.shape)}")
-
-
-def _step_flat(
-    runtime: Any,
-    step_idx: int | torch.Tensor,
-    *,
-    batch_size: int,
-    time_steps: int,
-    tensor: torch.Tensor,
-) -> torch.Tensor:
-    if isinstance(step_idx, int):
-        return runtime._constant_step_flat(
-            step_idx,
-            batch_size=batch_size,
-            time_steps=time_steps,
-            device=tensor.device,
-            dtype=runtime.edge_delay.dtype,
-        )
-    return _flatten_step_idx(
-        step_idx,
-
batch_size=batch_size, - time_steps=time_steps, - device=tensor.device, - dtype=runtime.edge_delay.dtype, - ) - - -def compute_messages( - runtime: Any, - z_prev: torch.Tensor, - *, - k_all: torch.Tensor, - v_all: torch.Tensor, - q: torch.Tensor, - step_idx: int | torch.Tensor, -) -> torch.Tensor: - if z_prev.dtype != torch.float32: - raise ValueError(f"Fabric CUDA message kernel requires float32 inputs, got {z_prev.dtype}") - batch_size, time_steps, num_cells, _ = z_prev.shape - step_flat = _step_flat(runtime, step_idx, batch_size=batch_size, time_steps=time_steps, tensor=z_prev) - if runtime._local_message_step_enabled: - sender_k_all = k_all.index_select(1, runtime.sender_cell_idx) - sender_v_all = v_all.index_select(1, runtime.sender_cell_idx) - msg_flat = fabric_local_message_cuda( - q, - sender_k_all, - sender_v_all, - runtime.full_local_sender_idx, - runtime.full_local_receiver_idx_by_sender, - runtime.local_distance, - runtime.local_delay, - step_flat, - distance_scale=float(runtime.config.distance_logit_scale), - use_delay=runtime._has_edge_delay, - ).view(batch_size, time_steps, num_cells, runtime.value_dim) - else: - with torch.profiler.record_function("fabric.glue.runtime_sparse_message_clone"): - sparse_k_all = k_all.clone() - sparse_v_all = v_all.clone() - msg_flat = fabric_sparse_message_cuda( - q, - sparse_k_all, - sparse_v_all, - runtime.neighbor_idx, - runtime.neighbor_valid, - runtime.edge_distance, - runtime.edge_delay, - step_flat, - distance_scale=float(runtime.config.distance_logit_scale), - use_delay=runtime._has_edge_delay, - ).view(batch_size, time_steps, num_cells, runtime.value_dim) - return runtime.msg_out(msg_flat) - - -def project_sender_kv_from_cells_step_backend( - runtime: Any, - sender_cells_step: torch.Tensor, - *, - sender_input_to_kv_weight: torch.Tensor | None, - grouped_sender_input_to_kv_weight: torch.Tensor | None = None, - sender_group_size: int = 1, - contiguous_kv: bool = False, -) -> tuple[torch.Tensor, torch.Tensor]: - return project_sender_kv_from_cells_step( - sender_cells_step, - sender_input_to_kv_weight=sender_input_to_kv_weight, - grouped_sender_input_to_kv_weight=grouped_sender_input_to_kv_weight, - sender_group_size=sender_group_size, - head_dim=runtime.head_dim, - value_dim=runtime.value_dim, - contiguous_kv=contiguous_kv, - ) - - -def project_sender_kv_from_cells_sequence_backend( - runtime: Any, - sender_cells_seq: torch.Tensor, - *, - sender_input_to_kv_weight: torch.Tensor | None, - grouped_sender_input_to_kv_weight: torch.Tensor | None = None, - sender_group_size: int = 1, -) -> tuple[torch.Tensor, torch.Tensor]: - return project_sender_kv_from_cells_sequence( - sender_cells_seq, - sender_input_to_kv_weight=sender_input_to_kv_weight, - grouped_sender_input_to_kv_weight=grouped_sender_input_to_kv_weight, - sender_group_size=sender_group_size, - head_dim=runtime.head_dim, - value_dim=runtime.value_dim, - ) - - -def project_boundary_source_sequence_backend( - runtime: Any, - source_hidden_seq: torch.Tensor, - *, - input_projection_weight: torch.Tensor, - input_projection_bias: torch.Tensor | None, -) -> torch.Tensor: - projected = torch.nn.functional.linear(source_hidden_seq, input_projection_weight, input_projection_bias) - return projected.view( - int(source_hidden_seq.shape[0]), - int(source_hidden_seq.shape[1]), - int(runtime._num_input_cells), - int(runtime.hidden_size), - ) - - -def project_recurrent_kv_from_preproj_step_backend( - runtime: Any, - recurrent_preproj_step: torch.Tensor, - *, - recurrent_preproj_to_kv_weight: 
torch.Tensor, - recurrent_preproj_to_kv_bias: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - return project_recurrent_kv_from_preproj_step( - recurrent_preproj_step, - recurrent_preproj_to_kv_weight=recurrent_preproj_to_kv_weight, - recurrent_preproj_to_kv_bias=recurrent_preproj_to_kv_bias, - head_dim=runtime.head_dim, - value_dim=runtime.value_dim, - ) - - -def project_recurrent_hidden_from_preproj_step_backend( - _runtime: Any, - recurrent_preproj_step: torch.Tensor, - *, - out_proj_weight_t: torch.Tensor, - out_proj_bias: torch.Tensor, -) -> torch.Tensor: - return project_recurrent_hidden_from_preproj_step( - recurrent_preproj_step, - out_proj_weight_t=out_proj_weight_t, - out_proj_bias=out_proj_bias, - ) - - -def project_recurrent_message_to_cell_step_backend( - _runtime: Any, - recurrent_msg: torch.Tensor, - *, - value_to_cell_weight: torch.Tensor, - recurrent_cell_bias: torch.Tensor, - fused_recurrent_value_to_cell_weight: torch.Tensor | None = None, - fused_recurrent_cell_bias: torch.Tensor | None = None, - fused_recurrent_population_input: bool = False, -) -> torch.Tensor: - if fused_recurrent_population_input and fused_recurrent_value_to_cell_weight is not None: - if fused_recurrent_cell_bias is None: - raise RuntimeError("fused recurrent projection requires fused recurrent cell bias") - return ( - torch.bmm( - recurrent_msg.transpose(0, 1), - fused_recurrent_value_to_cell_weight, - ).transpose(0, 1) - + fused_recurrent_cell_bias - ) - return torch.nn.functional.linear(recurrent_msg, value_to_cell_weight) + recurrent_cell_bias - - -def project_grouped_sender_cells_backend( - runtime: Any, - sender_cells_step: torch.Tensor, - grouped_weight: torch.Tensor, - *, - group_size: int, -) -> torch.Tensor: - return project_grouped_sender_cells( - sender_cells_step, - grouped_weight, - group_size=group_size, - head_dim=runtime.head_dim, - value_dim=runtime.value_dim, - ) - - -def compute_messages_step_subset_raw_backend( - runtime: Any, - k_all: torch.Tensor, - v_all: torch.Tensor, - *, - q_subset: torch.Tensor, - neighbor_idx: torch.Tensor, - neighbor_valid: torch.Tensor, - edge_distance: torch.Tensor, - edge_delay: torch.Tensor, - use_delay: bool, - step_idx: int | torch.Tensor, - local_sender_idx: torch.Tensor | None = None, - local_receiver_idx_by_sender: torch.Tensor | None = None, - owner_tag: str = "generic", -) -> torch.Tensor: - if k_all.dtype != torch.float32: - raise ValueError(f"Fabric CUDA message kernel requires float32 inputs, got {k_all.dtype}") - batch_size = int(k_all.shape[0]) - step_flat = _step_flat(runtime, step_idx, batch_size=batch_size, time_steps=1, tensor=k_all) - if ( - runtime._local_message_step_enabled - and local_sender_idx is not None - and local_receiver_idx_by_sender is not None - ): - return fabric_local_message_cuda( - q_subset, - k_all, - v_all, - local_sender_idx, - local_receiver_idx_by_sender, - runtime.local_distance, - runtime.local_delay, - step_flat, - distance_scale=float(runtime.config.distance_logit_scale), - use_delay=use_delay, - owner_tag=owner_tag, - ) - with torch.profiler.record_function("fabric.glue.runtime_sparse_message_clone"): - sparse_k_all = k_all.clone() - sparse_v_all = v_all.clone() - return fabric_sparse_message_cuda( - q_subset, - sparse_k_all, - sparse_v_all, - neighbor_idx, - neighbor_valid, - edge_distance, - edge_delay, - step_flat, - distance_scale=float(runtime.config.distance_logit_scale), - use_delay=use_delay, - ) - - -def compute_messages_step_subset_partitioned_raw_backend( - runtime: Any, - input_k: 
torch.Tensor, - input_v: torch.Tensor, - recurrent_k: torch.Tensor, - recurrent_v: torch.Tensor, - *, - q_subset: torch.Tensor, - neighbor_idx: torch.Tensor, - neighbor_valid: torch.Tensor, - edge_distance: torch.Tensor, - edge_delay: torch.Tensor, - use_delay: bool, - step_idx: int | torch.Tensor, - local_sender_idx: torch.Tensor, - local_receiver_idx_by_sender: torch.Tensor, - owner_tag: str = "generic", -) -> torch.Tensor: - step_flat = _step_flat(runtime, step_idx, batch_size=int(input_k.shape[0]), time_steps=1, tensor=input_k) - if runtime._local_message_step_enabled: - return fabric_local_message_partitioned_cuda( - q_subset, - input_k, - input_v, - recurrent_k, - recurrent_v, - local_sender_idx, - local_receiver_idx_by_sender, - runtime.local_distance, - runtime.local_delay, - step_flat, - num_input_senders=runtime._num_input_cells, - distance_scale=float(runtime.config.distance_logit_scale), - use_delay=use_delay, - owner_tag=owner_tag, - ) - return fabric_sparse_message_partitioned_cuda( - q_subset, - input_k, - input_v, - recurrent_k, - recurrent_v, - neighbor_idx, - neighbor_valid, - edge_distance, - edge_delay, - step_flat, - distance_scale=float(runtime.config.distance_logit_scale), - use_delay=use_delay, - ) - - -def project_output_cells_step_raw_backend( - runtime: Any, - output_msg: torch.Tensor, - *, - value_to_output_weight: torch.Tensor, -) -> torch.Tensor: - return project_output_cells_step_raw( - output_msg, - value_to_output_weight=value_to_output_weight, - output_cell_bias=runtime.output_cell_bias, - hidden_size=runtime.hidden_size, - ) - - -__all__ = [ - "compute_messages", - "compute_messages_step_subset_partitioned_raw_backend", - "compute_messages_step_subset_raw_backend", - "project_grouped_sender_cells_backend", - "project_output_cells_step_raw_backend", - "project_recurrent_message_to_cell_step_backend", - "project_recurrent_hidden_from_preproj_step_backend", - "project_recurrent_kv_from_preproj_step_backend", - "project_sender_kv_from_cells_sequence_backend", - "project_sender_kv_from_cells_step_backend", -] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/__init__.py b/src/cortical/fabric/backend/cuda/sequence_surface/__init__.py index a22822e2..fd48a4f2 100644 --- a/src/cortical/fabric/backend/cuda/sequence_surface/__init__.py +++ b/src/cortical/fabric/backend/cuda/sequence_surface/__init__.py @@ -1,3 +1,3 @@ -from cortical.fabric.backend.cuda.sequence_surface.surface import CudaSequenceSurfaceMixin +from cortical.fabric.backend.cuda.sequence_surface.runtime.surface import CudaSequenceSurfaceMixin __all__ = ["CudaSequenceSurfaceMixin"] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/backward.py b/src/cortical/fabric/backend/cuda/sequence_surface/backward.py deleted file mode 100644 index f8c0f50e..00000000 --- a/src/cortical/fabric/backend/cuda/sequence_surface/backward.py +++ /dev/null @@ -1,3847 +0,0 @@ -from __future__ import annotations - -import math -import os -from collections import defaultdict -from collections.abc import Mapping -from dataclasses import dataclass, replace -from typing import Any, Literal, cast - -import torch -from tensordict import TensorDict, TensorDictBase - -from cortical.fabric.backend.cuda.message_passing.local_message_cuda import ( - fabric_local_message_partitioned_backward_fused_cuda, - fabric_local_message_partitioned_backward_receiver_cuda, - fabric_local_message_partitioned_backward_sender_cuda, -) -from cortical.fabric.backend.cuda.message_passing.sparse_message_cuda import ( - 
fabric_sparse_message_partitioned_backward_receiver_cuda, - fabric_sparse_message_partitioned_backward_sender_cuda, -) -from cortical.fabric.backend.cuda.ops import receiver_major_affine_backward_cuda, reset_backend_tensors_rows_cuda -from cortical.fabric.backend.cuda.projection.grouped_projection_cuda import fabric_grouped_projection_backward_cuda -from cortical.fabric.backend.cuda.projection.receiver_major_gates import receiver_major_affine_backward_block_b -from cortical.fabric.backend.cuda.sequence_surface.policy import ( - ActiveOutputBackwardTileInputs, - active_output_backward_batch_tile_policy, - artifact_storage_policy, - checkpoint_stride_alignment_policy, - recompute_artifact_window_policy, - recompute_checkpoint_stride_policy, -) -from cortical.fabric.backend.cuda.sequence_surface.replay import CudaSequenceReplayMixin -from cortical.fabric.backend.cuda.sequence_surface.support import ( - _BACKWARD_ATTRIBUTION_MODE_ENV, - _RECOMPUTE_PAYLOAD_TOTAL_KEYS, - _accumulate_owned_tensor_grad, - _accumulate_tensor_grad, - _artifact_payload_bytes_by_family, - _BackendSequenceStepArtifacts, - _format_payload_bytes, - _param_grad_tuple, - _param_input_specs, - _ReceiverWindowSpec, - _sender_kv_projection_param_specs, - _slice_batch_tensor, - _slice_batch_tree, - _slice_forward_carry_checkpoints, - _state_public_param_specs, - _transition_supports_receiver_local_dependency_window, -) -from cortical.fabric.backend.planner import ( - PlannedFabricBackwardExecution, - PlannedFabricExecution, - cuda_nn_primitive_backward_behavior, -) -from cortical.fabric.backend.reuse import ExecutionFamily -from cortical.fabric.runtime.state import ( - flatten_backend_packed_state as _flatten_backend_packed_state, -) -from cortical.fabric.runtime.state import ( - unflatten_backend_packed_state as _unflatten_backend_packed_state, -) - - -def _align_matrix_grad_tail_to_target( - grad: torch.Tensor, - target_shape: tuple[int, ...], -) -> torch.Tensor: - """Align executor-returned matrix grad layout to the declared parameter layout.""" - if grad.dim() != len(target_shape) or len(target_shape) < 2: - return grad - if tuple(grad.shape) == target_shape: - return grad - grad_prefix = tuple(int(dim) for dim in grad.shape[:-2]) - target_prefix = tuple(int(dim) for dim in target_shape[:-2]) - grad_tail = tuple(int(dim) for dim in grad.shape[-2:]) - target_tail = tuple(int(dim) for dim in target_shape[-2:]) - if grad_prefix == target_prefix and grad_tail == (target_tail[1], target_tail[0]): - return grad.transpose(-1, -2).contiguous() - return grad - - -def _sender_kv_projection_backward_profile_name(owner: str) -> str: - if owner == "public_projection": - return "fabric.backward.public_projection" - if owner == "grouped_projection": - return "fabric.backward.grouped_projection" - return f"fabric.backward.{owner}" - - -@dataclass(frozen=True) -class SenderKVProjectionRawParamGrad: - role: Literal["input", "recurrent"] - grad_weight: torch.Tensor - group_ids: torch.Tensor - grouped: bool - - -class CudaSequenceBackwardMixin(CudaSequenceReplayMixin): - @staticmethod - def _run_named_autograd_phase( - *, - outputs: list[tuple[torch.Tensor, torch.Tensor | None]], - inputs: list[tuple[str, torch.Tensor]], - profile_name: str, - retain_graph: bool = False, - ) -> dict[str, torch.Tensor | None]: - active_outputs = [(tensor, grad) for tensor, grad in outputs if grad is not None and tensor.requires_grad] - if not active_outputs or not inputs: - return {name: None for name, _tensor in inputs} - with 
torch.profiler.record_function(profile_name): - grad_results = torch.autograd.grad( - outputs=tuple(tensor for tensor, _grad in active_outputs), - inputs=tuple(tensor for _name, tensor in inputs), - grad_outputs=tuple(cast(torch.Tensor, grad) for _tensor, grad in active_outputs), - allow_unused=True, - retain_graph=retain_graph, - ) - return {name: grad for (name, _tensor), grad in zip(inputs, grad_results, strict=True)} - - def _active_transition_op_names_for_population(self, population_name: str) -> tuple[str, ...]: - if not population_name: - raise RuntimeError("Fabric CUDA state/public backward requires a recurrent population") - population_spec = self._backend_population_specs.get(population_name) - if population_spec is None: - raise RuntimeError( - f"Fabric CUDA state/public backward missing backend population spec for {population_name!r}" - ) - op_names = tuple(op.name for op in population_spec.transition_ir.ops) - for op_name in op_names: - cuda_nn_primitive_backward_behavior(op_name) - return op_names - - def _active_transition_op_names(self) -> tuple[str, ...]: - population_name = self._full_recurrent_population_name - if population_name is None: - raise RuntimeError("Fabric CUDA state/public backward requires a recurrent population") - return self._active_transition_op_names_for_population(population_name) - - def _state_public_backward_profile_name_for_population(self, population_name: str) -> str: - op_names = set(self._active_transition_op_names_for_population(population_name)) - if op_names & {"diag_rtu", "diagonal_recurrence"}: - return "fabric.backward.diagonal_recurrence" - if op_names & {"gated_logspace_recurrence", "norm_or_identity", "reduction_boundary", "state_epilogue_policy"}: - return "fabric.backward.state_epilogue" - if op_names & {"linear", "matmul", "state_affine"}: - return "fabric.backward.receiver_affine" - raise RuntimeError( - "Fabric CUDA state/public backward has no registered semantic owner for transition IR ops: " - + (", ".join(sorted(op_names)) or "") - ) - - def _state_public_backward_profile_name(self) -> str: - population_name = self._full_recurrent_population_name - if population_name is None: - raise RuntimeError("Fabric CUDA state/public backward requires a recurrent population") - return self._state_public_backward_profile_name_for_population(population_name) - - def _transition_core_state_names_for_population(self, population_name: str) -> tuple[str, ...] 
| None: - population_spec = self._backend_population_specs.get(population_name) - if population_spec is None: - return None - trace_names = { - schema.name - for schema in population_spec.private_state_schema - if schema.semantic_kind == "eligibility_trace" - } - if not trace_names: - return None - core_names = tuple( - state_name for state_name in population_spec.transition_ir.state_inputs if state_name not in trace_names - ) - if not core_names or len(core_names) == len(population_spec.transition_ir.state_inputs): - return None - return core_names - - def _can_elide_transition_trace_state_next( - self, - *, - population_name: str, - active_receiver_window: _ReceiverWindowSpec | None, - ) -> bool: - if active_receiver_window is None or not active_receiver_window.active: - return False - population_spec = self._backend_population_specs.get(population_name) - if population_spec is None: - return False - op_names = {op.name for op in population_spec.transition_ir.ops} - if not (op_names & {"diag_rtu", "diagonal_recurrence"}): - return False - return self._transition_core_state_names_for_population(population_name) is not None - - @staticmethod - def _grad_tree_empty(value: Any) -> bool: - if value is None: - return True - if torch.is_tensor(value): - return int(value.numel()) == 0 - if isinstance(value, (dict, TensorDictBase)): - return all(CudaSequenceBackwardMixin._grad_tree_empty(item) for item in value.values()) - return False - - def _fixed_output_dependency_receiver_window( - self, - *, - reason: str, - ) -> _ReceiverWindowSpec | None: - recurrent_count = int(self._num_recurrent_cells) - if not bool(getattr(self, "_output_local_recurrent_window_contiguous", False)): - self._last_backend_backward_active_receiver_window_reason = f"{reason}:not_contiguous" - return None - start = int(getattr(self, "_output_local_recurrent_window_start", 0)) - count = int(getattr(self, "_output_local_recurrent_window_count", 0)) - if start < 0 or count <= 0 or start + count > recurrent_count or count >= recurrent_count: - self._last_backend_backward_active_receiver_window_reason = ( - f"{reason}:invalid_or_full[start={start},count={count},full={recurrent_count}]" - ) - return None - window = _ReceiverWindowSpec( - mode="streaming_output_active_region", - start=start, - count=count, - full_count=recurrent_count, - ) - self._last_backend_backward_active_receiver_window = ( - f"{window.mode}[offset={window.start},count={window.count},full={window.full_count}]" - ) - self._last_backend_backward_active_receiver_window_reason = reason - return window - - def _receiver_window_is_closed_under_recurrent_senders( - self, - window: _ReceiverWindowSpec, - ) -> bool: - num_input_senders = int(self._num_input_cells) - start = int(window.start) - stop = int(window.start + window.count) - - def table_closed(table: torch.Tensor, *, slice_receivers: bool) -> bool: - if int(table.numel()) == 0: - return True - selected = ( - table.narrow(0, start, int(window.count)) - if slice_receivers and int(table.shape[0]) == int(window.full_count) - else table - ) - recurrent_sender_ids = selected[selected >= num_input_senders] - num_input_senders - if int(recurrent_sender_ids.numel()) == 0: - return True - outside = (recurrent_sender_ids < start) | (recurrent_sender_ids >= stop) - return not bool(outside.any().item()) - - return table_closed(self.recurrent_local_sender_idx, slice_receivers=True) and table_closed( - self.output_local_sender_idx, - slice_receivers=False, - ) - - def _backward_active_receiver_window_for_output_only_step( 
- self, - *, - boundary_seq: torch.Tensor, - output_boundary: Literal["sequence", "terminal"], - uses_sparse_messages: bool, - initial_state_is_fresh: bool, - packed_state: Any, - initial_recurrent_k: torch.Tensor | None, - initial_recurrent_v: torch.Tensor | None, - grad_next_packed_state: Any, - grad_recurrent_hidden: torch.Tensor | None, - grad_recurrent_k: torch.Tensor | None, - grad_recurrent_v: torch.Tensor | None, - grad_input_k_last: torch.Tensor | None, - grad_input_v_last: torch.Tensor | None, - ) -> _ReceiverWindowSpec | None: - self._last_backend_backward_active_receiver_window = "full_surface" - self._last_backend_backward_active_receiver_window_reason = "not_output_boundary_only_streaming_window" - if ( - boundary_seq.device.type != "cuda" - or uses_sparse_messages - or not self._local_message_step_enabled - or self._has_edge_delay - or not self._grad_tree_empty(grad_next_packed_state) - or not self._grad_tree_empty(grad_recurrent_hidden) - or not self._grad_tree_empty(grad_recurrent_k) - or not self._grad_tree_empty(grad_recurrent_v) - or not self._grad_tree_empty(grad_input_k_last) - or not self._grad_tree_empty(grad_input_v_last) - ): - return None - population_name = self._full_recurrent_population_name - if population_name is None: - self._last_backend_backward_active_receiver_window_reason = ( - "streaming_output_active_region:missing_population" - ) - return None - population_spec = self._backend_population_specs.get(population_name) - if population_spec is None or not _transition_supports_receiver_local_dependency_window( - population_spec.transition_ir - ): - self._last_backend_backward_active_receiver_window_reason = ( - "streaming_output_active_region:unsupported_transition" - ) - return None - window = self._fixed_output_dependency_receiver_window( - reason=f"streaming_output_active_region:fixed_window;output_boundary={output_boundary};" - f"time_steps={int(boundary_seq.shape[1])}" - ) - if window is not None and window.active and not self._receiver_window_is_closed_under_recurrent_senders(window): - - def fresh_or_compact_synthetic_recurrent_boundary(value: torch.Tensor | None) -> bool: - if value is None: - return True - return torch.is_tensor(value) and value.dim() >= 2 and int(value.shape[1]) == int(window.count) - - fresh_zero_recurrent_boundary = ( - bool(initial_state_is_fresh) - and packed_state is None - and fresh_or_compact_synthetic_recurrent_boundary(initial_recurrent_k) - and fresh_or_compact_synthetic_recurrent_boundary(initial_recurrent_v) - ) - if fresh_zero_recurrent_boundary: - self._last_backend_backward_active_receiver_window_reason = ( - f"{self._last_backend_backward_active_receiver_window_reason};" - "nonclosed_recurrent_senders=fresh_zero_boundary" - ) - return window - self._last_backend_backward_active_receiver_window = "full_surface" - self._last_backend_backward_active_receiver_window_reason = ( - "streaming_output_active_region:demoted_nonclosed_recurrent_sender_graph" - ) - return None - return window - - def _infer_backward_compact_carry_receiver_window( - self, - *, - packed_state: Any, - initial_hidden: torch.Tensor, - active_receiver_window: _ReceiverWindowSpec | None, - ) -> _ReceiverWindowSpec | None: - if active_receiver_window is not None: - return active_receiver_window - inferred_window = self._fixed_output_dependency_receiver_window( - reason="streaming_output_active_region:checkpoint_compact_carry" - ) - if inferred_window is None or not inferred_window.active: - return active_receiver_window - compact_receiver_count: int 
| None = None - if torch.is_tensor(initial_hidden) and initial_hidden.dim() >= 2: - hidden_receivers = int(initial_hidden.shape[1]) - if hidden_receivers > 0: - compact_receiver_count = hidden_receivers - if compact_receiver_count is None and packed_state is not None: - _packed_state_keys, packed_state_tensors = _flatten_backend_packed_state(packed_state) - for tensor in packed_state_tensors: - if torch.is_tensor(tensor) and tensor.dim() >= 2 and int(tensor.shape[1]) > 0: - compact_receiver_count = int(tensor.shape[1]) - break - if compact_receiver_count == int(inferred_window.count): - self._last_backend_backward_active_receiver_window = ( - f"{inferred_window.mode}[offset={inferred_window.start}," - f"count={inferred_window.count},full={inferred_window.full_count}]" - ) - self._last_backend_backward_active_receiver_window_reason = ( - "streaming_output_active_region:checkpoint_compact_carry;compact_runtime_bank" - ) - return inferred_window - if not self._receiver_window_is_closed_under_recurrent_senders(inferred_window): - self._last_backend_backward_active_receiver_window = "full_surface" - self._last_backend_backward_active_receiver_window_reason = ( - "streaming_output_active_region:checkpoint_compact_carry_demoted_nonclosed_recurrent_sender_graph" - ) - return active_receiver_window - return active_receiver_window - - def _active_output_dependency_backward_batch_tile_len( - self, - *, - population_name: str, - boundary_seq: torch.Tensor, - active_receiver_window: _ReceiverWindowSpec | None, - output_boundary: Literal["sequence", "terminal"] = "sequence", - ) -> int: - batch_size = int(boundary_seq.shape[0]) - if ( - batch_size <= 1 - or boundary_seq.device.type != "cuda" - or active_receiver_window is None - or not active_receiver_window.active - ): - self._last_backend_backward_batch_tile_len = batch_size - self._last_backend_backward_batch_tile_reason = "batch_tiling=disabled" - return batch_size - dtype_bytes = int(torch.empty((), dtype=boundary_seq.dtype).element_size()) - state_leaf_count = len(self._cell_spec_for_population(population_name).state_schema.keys) - state_leaf_mode = "full_state" - if self._can_elide_transition_trace_state_next( - population_name=population_name, - active_receiver_window=active_receiver_window, - ): - core_state_names = self._transition_core_state_names_for_population(population_name) - if core_state_names: - state_leaf_count = len(core_state_names) - state_leaf_mode = "core_state_without_trace" - active_receivers = int(active_receiver_window.count) - recurrent_receivers = int(active_receiver_window.full_count) - recurrent_bank_receivers = ( - active_receivers - if self._receiver_window_compacts_recurrent_senders(active_receiver_window) - else recurrent_receivers - ) - decision = active_output_backward_batch_tile_policy( - inputs=ActiveOutputBackwardTileInputs( - batch_size=batch_size, - time_steps=int(boundary_seq.shape[1]), - dtype_bytes=dtype_bytes, - state_leaf_count=state_leaf_count, - state_leaf_mode=state_leaf_mode, - active_receivers=active_receivers, - recurrent_receivers=recurrent_receivers, - recurrent_bank_receivers=recurrent_bank_receivers, - hidden_size=int(self.hidden_size), - head_dim=int(self.head_dim), - value_dim=int(self.value_dim), - input_cells=int(self._num_input_cells), - output_cells=int(self._num_output_cells), - output_dim=int(boundary_seq.shape[-1]), - output_boundary=output_boundary, - ), - memory=self._cuda_memory_budget(boundary_seq.device), - ) - self._last_backend_backward_batch_tile_len = int(decision.value) - 
self._last_backend_backward_batch_tile_reason = decision.reason - return int(decision.value) - - def _sequence_backward_surface_per_batch_bytes( - self, - *, - population_name: str, - boundary_seq: torch.Tensor, - active_receiver_window: _ReceiverWindowSpec | None, - ) -> int: - dtype_bytes = int(torch.empty((), dtype=boundary_seq.dtype).element_size()) - state_leaf_count = len(self._cell_spec_for_population(population_name).state_schema.keys) - if self._can_elide_transition_trace_state_next( - population_name=population_name, - active_receiver_window=active_receiver_window, - ): - core_state_names = self._transition_core_state_names_for_population(population_name) - if core_state_names: - state_leaf_count = len(core_state_names) - receiver_count = ( - int(active_receiver_window.count) - if active_receiver_window is not None and active_receiver_window.active - else int(self._num_recurrent_cells) - ) - recurrent_bank_receivers = ( - receiver_count - if self._receiver_window_compacts_recurrent_senders(active_receiver_window) - else int(self._num_recurrent_cells) - ) - state_elements = int(receiver_count) * int(state_leaf_count) * int(self.hidden_size) - public_elements = int(receiver_count) * (int(self.hidden_size) + int(self.head_dim) + int(self.value_dim)) - recurrent_bank_elements = int(recurrent_bank_receivers) * (int(self.head_dim) + int(self.value_dim)) - input_elements = int(self._num_input_cells) * (int(self.head_dim) + int(self.value_dim)) - output_elements = int(self._num_output_cells) * int(boundary_seq.shape[-1]) - # Include both forward intermediates and adjoint surfaces for a streaming sequence-loss tile. This is a - # conservative flat-graph estimate; it keys on receiver counts and boundary widths, not cell names or shapes. - per_batch_elements = ( - state_elements + public_elements + recurrent_bank_elements + input_elements + output_elements - ) - return int(math.ceil(float(per_batch_elements) * float(dtype_bytes) * 2.5)) - - @staticmethod - def _slice_receiver_window_rows(tensor: torch.Tensor, window: _ReceiverWindowSpec | None) -> torch.Tensor: - if window is None or not window.active: - return tensor - return tensor.narrow(0, window.start, window.count).contiguous() - - @staticmethod - def _slice_receiver_window_batch_rows(value: Any, window: _ReceiverWindowSpec | None) -> Any: - if window is None or not window.active: - return value - if torch.is_tensor(value): - if value.dim() >= 2 and int(value.shape[1]) == window.full_count: - return value.narrow(1, window.start, window.count).contiguous() - return value - if isinstance(value, TensorDictBase): - return TensorDict( - { - key: CudaSequenceBackwardMixin._slice_receiver_window_batch_rows(item, window) - for key, item in value.items() - }, - batch_size=[int(value.batch_size[0]), window.count] if len(value.batch_size) >= 2 else value.batch_size, - device=value.device, - ) - if isinstance(value, dict): - return { - key: CudaSequenceBackwardMixin._slice_receiver_window_batch_rows(item, window) - for key, item in value.items() - } - return value - - @staticmethod - def _receiver_window_compacts_recurrent_senders(window: _ReceiverWindowSpec | None) -> bool: - return window is not None and window.active - - @staticmethod - def _slice_receiver_window_recurrent_bank( - tensor: torch.Tensor, - window: _ReceiverWindowSpec | None, - ) -> torch.Tensor: - if not CudaSequenceBackwardMixin._receiver_window_compacts_recurrent_senders(window): - return tensor - assert window is not None - if tensor.dim() >= 2 and int(tensor.shape[1]) == 
window.full_count: - return tensor.narrow(1, window.start, window.count).contiguous() - if tensor.dim() >= 2 and int(tensor.shape[1]) == window.count: - return tensor - raise RuntimeError( - "Fabric time-expanded receiver window expected a full or compact recurrent bank, " - f"got shape={tuple(tensor.shape)} window_count={window.count} full_count={window.full_count}" - ) - - @staticmethod - def _scatter_receiver_window_recurrent_bank_grad( - value: torch.Tensor | None, - window: _ReceiverWindowSpec | None, - *, - like: torch.Tensor | None, - ) -> torch.Tensor | None: - if not CudaSequenceBackwardMixin._receiver_window_compacts_recurrent_senders(window) or value is None: - return value - assert window is not None - if like is not None: - if like.dim() >= 2 and int(like.shape[1]) == window.count: - return value - if like.dim() >= 2 and int(like.shape[1]) == window.full_count: - full = torch.zeros_like(like) - full.narrow(1, window.start, window.count).copy_(value) - return full - full_shape = list(value.shape) - full_shape[1] = window.full_count - full = value.new_zeros(tuple(full_shape)) - full.narrow(1, window.start, window.count).copy_(value) - return full - - @staticmethod - def _slice_receiver_window_static_tensor( - tensor: torch.Tensor, - window: _ReceiverWindowSpec, - ) -> torch.Tensor: - if tensor.dim() == 1 and int(tensor.numel()) >= window.full_count: - if int(tensor.numel()) % window.full_count == 0: - receiver_view = tensor.reshape(window.full_count, -1) - return receiver_view.narrow(0, window.start, window.count).contiguous().reshape(-1) - if tensor.dim() >= 1 and int(tensor.shape[0]) == window.full_count: - return tensor.narrow(0, window.start, window.count).contiguous() - if tensor.dim() >= 2 and int(tensor.shape[0]) == 1 and int(tensor.shape[1]) == window.full_count: - return tensor.narrow(1, window.start, window.count).contiguous() - return tensor - - def _slice_receiver_window_static_tensors( - self, - static_tensors: Mapping[str, object], - window: _ReceiverWindowSpec | None, - ) -> dict[str, object]: - if window is None or not window.active: - return dict(static_tensors) - sliced: dict[str, object] = dict(static_tensors) - for key in ( - "recurrent_cell_bias", - "fused_recurrent_value_to_cell_weight", - "fused_recurrent_cell_bias", - "recurrent_sender_input_to_kv_weight", - ): - value = sliced.get(key) - if torch.is_tensor(value): - sliced[key] = self._slice_receiver_window_static_tensor(value, window) - - direct_recurrent_kv = sliced.get("recurrent_sender_input_to_kv_weight") - grouped_recurrent_kv = sliced.get("recurrent_group_input_to_kv_weight") - if torch.is_tensor(direct_recurrent_kv): - sliced["recurrent_group_input_to_kv_weight"] = None - elif torch.is_tensor(grouped_recurrent_kv): - expanded = grouped_recurrent_kv.repeat_interleave( - max(1, int(self._recurrent_sender_kv_group_size)), - dim=0, - ) - if int(expanded.shape[0]) >= window.full_count: - sliced["recurrent_sender_input_to_kv_weight"] = ( - expanded[: window.full_count] - .narrow( - 0, - window.start, - window.count, - ) - .contiguous() - ) - sliced["recurrent_group_input_to_kv_weight"] = None - - population_materialized = sliced.get("population_materialized") - if isinstance(population_materialized, dict): - sliced_population_materialized: dict[str, object | None] = {} - for population_name, params in population_materialized.items(): - if not isinstance(params, dict): - sliced_population_materialized[population_name] = params - continue - sliced_params: dict[str, object] = {} - for name, value in 
params.items(): - sliced_params[name] = ( - self._slice_receiver_window_static_tensor(value, window) if torch.is_tensor(value) else value - ) - sliced_population_materialized[population_name] = sliced_params - sliced["population_materialized"] = sliced_population_materialized - return sliced - - @staticmethod - def _tensor_cache_fingerprint(tensor: torch.Tensor) -> tuple[object, ...]: - return ( - str(tensor.device), - str(tensor.dtype), - int(tensor.data_ptr()), - tuple(int(dim) for dim in tensor.shape), - tuple(int(stride) for stride in tensor.stride()), - int(getattr(tensor, "_version", 0)), - ) - - @classmethod - def _static_tensor_cache_fingerprint(cls, value: object) -> tuple[object, ...]: - if torch.is_tensor(value): - return ("tensor", cls._tensor_cache_fingerprint(value)) - if isinstance(value, Mapping): - return ( - "mapping", - tuple( - (str(key), cls._static_tensor_cache_fingerprint(item)) - for key, item in sorted(value.items(), key=lambda pair: str(pair[0])) - ), - ) - return ("object", type(value).__name__, id(value)) - - def _cached_receiver_window_static_tensors( - self, - static_tensors: Mapping[str, object], - window: _ReceiverWindowSpec | None, - ) -> dict[str, object]: - if window is None or not window.active: - return dict(static_tensors) - cache = getattr(self, "_receiver_window_static_tensor_dict_cache", None) - if cache is None: - cache = {} - self._receiver_window_static_tensor_dict_cache = cache - key = ( - id(static_tensors), - window.mode, - int(window.start), - int(window.count), - int(window.full_count), - self._static_tensor_cache_fingerprint(static_tensors), - ) - cached = cache.get(key) - if cached is None: - cached = self._slice_receiver_window_static_tensors(static_tensors, window) - cache[key] = cached - return dict(cached) - - @staticmethod - def _remap_partitioned_sender_table_for_receiver_window( - table: torch.Tensor, - window: _ReceiverWindowSpec | None, - *, - num_input_senders: int, - slice_receivers: bool, - compact_recurrent_senders: bool, - ) -> torch.Tensor: - adjusted = table - if window is not None and window.active and slice_receivers: - adjusted = adjusted.narrow(0, window.start, window.count) - adjusted = adjusted.contiguous() - if window is None or not window.active or not compact_recurrent_senders: - return adjusted - remapped = adjusted.clone() - recurrent_mask = remapped >= int(num_input_senders) - recurrent_index = remapped - int(num_input_senders) - in_window = ( - recurrent_mask - & (recurrent_index >= int(window.start)) - & (recurrent_index < int(window.start + window.count)) - ) - remapped[recurrent_mask & ~in_window] = -1 - remapped[in_window] = int(num_input_senders) + recurrent_index[in_window] - int(window.start) - return remapped.contiguous() - - @staticmethod - def _sender_reverse_table_from_receiver_table(receiver_sender_idx: torch.Tensor, num_senders: int) -> torch.Tensor: - reverse = torch.full( - (int(num_senders), int(receiver_sender_idx.shape[1])), - -1, - device=receiver_sender_idx.device, - dtype=torch.int32, - ) - receiver_idx, offset_idx = torch.nonzero(receiver_sender_idx >= 0, as_tuple=True) - if receiver_idx.numel() == 0: - return reverse - sender_idx = receiver_sender_idx[receiver_idx, offset_idx].to(dtype=torch.long) - valid = (sender_idx >= 0) & (sender_idx < int(num_senders)) - if bool((reverse[sender_idx[valid], offset_idx[valid]] >= 0).any()): - raise RuntimeError("Fabric local sender reverse table expects unique receiver per sender/offset") - reverse[sender_idx[valid], offset_idx[valid]] = 
receiver_idx[valid].to(dtype=torch.int32) - return reverse - - def _cached_receiver_window_sender_table( - self, - *, - name: str, - table: torch.Tensor, - window: _ReceiverWindowSpec | None, - num_input_senders: int, - slice_receivers: bool, - compact_recurrent_senders: bool, - ) -> torch.Tensor: - if window is None or not window.active: - return self._remap_partitioned_sender_table_for_receiver_window( - table, - window, - num_input_senders=num_input_senders, - slice_receivers=slice_receivers, - compact_recurrent_senders=compact_recurrent_senders, - ) - cache = getattr(self, "_receiver_window_sender_table_cache", None) - if cache is None: - cache = {} - self._receiver_window_sender_table_cache = cache - key = ( - name, - str(table.device), - int(table.data_ptr()), - tuple(table.shape), - window.mode, - int(window.start), - int(window.count), - int(window.full_count), - int(num_input_senders), - bool(slice_receivers), - bool(compact_recurrent_senders), - ) - cached = cache.get(key) - if cached is None: - cached = self._remap_partitioned_sender_table_for_receiver_window( - table, - window, - num_input_senders=num_input_senders, - slice_receivers=slice_receivers, - compact_recurrent_senders=compact_recurrent_senders, - ) - cache[key] = cached - return cached - - def _cached_sender_reverse_table( - self, - *, - name: str, - receiver_sender_idx: torch.Tensor, - num_senders: int, - ) -> torch.Tensor: - cache = getattr(self, "_sender_reverse_table_cache", None) - if cache is None: - cache = {} - self._sender_reverse_table_cache = cache - key = ( - name, - str(receiver_sender_idx.device), - int(receiver_sender_idx.data_ptr()), - tuple(receiver_sender_idx.shape), - int(num_senders), - ) - cached = cache.get(key) - if cached is None: - cached = self._sender_reverse_table_from_receiver_table(receiver_sender_idx, num_senders) - cache[key] = cached - return cached - - def _cached_receiver_window_static_rows( - self, - *, - name: str, - tensor: torch.Tensor, - window: _ReceiverWindowSpec | None, - ) -> torch.Tensor: - if window is None or not window.active: - return tensor - cache = getattr(self, "_receiver_window_static_rows_cache", None) - if cache is None: - cache = {} - self._receiver_window_static_rows_cache = cache - key = ( - name, - str(tensor.device), - int(tensor.data_ptr()), - tuple(tensor.shape), - window.mode, - int(window.start), - int(window.count), - int(window.full_count), - ) - cached = cache.get(key) - if cached is None: - cached = tensor.narrow(0, window.start, window.count).contiguous() - cache[key] = cached - return cached - - def _scatter_receiver_window_batch_rows( - self, - value: Any, - window: _ReceiverWindowSpec | None, - *, - like: Any, - ) -> Any: - if window is None or not window.active or value is None: - return value - if torch.is_tensor(value): - if ( - value.dim() >= 2 - and int(value.shape[1]) == window.count - and torch.is_tensor(like) - and like.dim() >= 2 - and int(like.shape[1]) == window.full_count - ): - full = torch.zeros_like(cast(torch.Tensor, like)) - full.narrow(1, window.start, window.count).copy_(value) - return full - return value - if isinstance(value, (dict, TensorDictBase)) and isinstance(like, (dict, TensorDictBase)): - return { - key: self._scatter_receiver_window_batch_rows(value.get(key), window, like=like.get(key)) - for key in like.keys() - } - return value - - def _finalize_backward_param_grads( - self, - *, - trainable_params: tuple[torch.Tensor, ...], - grad_param_accum: list[torch.Tensor | None], - active_receiver_window: 
_ReceiverWindowSpec | None, - ) -> tuple[torch.Tensor | None, ...]: - final_grad_param_accum: list[torch.Tensor | None] = [] - for parameter, grad_param in zip(trainable_params, grad_param_accum, strict=True): - if grad_param is not None and tuple(grad_param.shape) != tuple(parameter.shape): - if int(grad_param.numel()) == int(parameter.numel()): - grad_param = grad_param.reshape(tuple(parameter.shape)) - if ( - active_receiver_window is not None - and active_receiver_window.active - and grad_param is not None - and grad_param.dim() >= 1 - and int(grad_param.shape[0]) == active_receiver_window.count - and len(parameter.shape) >= 1 - and int(parameter.shape[0]) == active_receiver_window.full_count - ): - compact_shape = (int(active_receiver_window.count), *tuple(parameter.shape[1:])) - if tuple(grad_param.shape) != compact_shape and int(grad_param.numel()) == math.prod(compact_shape): - grad_param = grad_param.reshape(compact_shape) - full_grad = torch.zeros_like(parameter) - full_grad.narrow(0, active_receiver_window.start, active_receiver_window.count).copy_(grad_param) - grad_param = full_grad - final_grad_param_accum.append(grad_param) - return tuple(final_grad_param_accum) - - def _state_public_explicit_param_grad_tuple( - self, - *, - population_name: str, - materialized_param_grads: Mapping[str, torch.Tensor], - static_source_grads: Mapping[str, torch.Tensor], - projection_param_grads: Mapping[str, torch.Tensor | None], - trainable_params: tuple[torch.Tensor, ...], - trainable_param_names: tuple[str, ...], - trainable_param_shapes: tuple[tuple[int, ...], ...] | None = None, - active_receiver_window: _ReceiverWindowSpec | None = None, - ) -> tuple[torch.Tensor | None, ...]: - if not population_name: - raise RuntimeError("Fabric CUDA state/public backward requires a recurrent population") - if trainable_param_shapes is None: - trainable_param_shapes = tuple(tuple(param.shape) for param in trainable_params) - if len(trainable_param_shapes) != len(trainable_param_names): - raise RuntimeError("Fabric CUDA state/public backward requires aligned parameter names and shapes") - by_name = {name: index for index, name in enumerate(trainable_param_names)} - grad_accum: list[torch.Tensor | None] = [None] * len(trainable_param_names) - - def add(name: str, grad: torch.Tensor | None) -> None: - if grad is None: - return - index = by_name.get(name) - if index is None: - return - target_shape = trainable_param_shapes[index] - compact_receiver_grad = ( - active_receiver_window is not None - and active_receiver_window.active - and grad.dim() >= 1 - and int(grad.shape[0]) == active_receiver_window.count - and len(target_shape) >= 1 - and int(target_shape[0]) == active_receiver_window.full_count - ) - if tuple(grad.shape) != target_shape: - if compact_receiver_grad: - compact_shape = (int(active_receiver_window.count), *target_shape[1:]) - grad = _align_matrix_grad_tail_to_target(grad, compact_shape) - if tuple(grad.shape) != compact_shape and int(grad.numel()) == math.prod(compact_shape): - grad = grad.reshape(compact_shape) - elif int(grad.numel()) == math.prod(target_shape): - grad = _align_matrix_grad_tail_to_target(grad, target_shape) - grad = grad.reshape(target_shape) - if tuple(grad.shape) != target_shape and not compact_receiver_grad: - raise RuntimeError( - f"Fabric CUDA state/public backward produced grad for {name} with shape " - f"{tuple(grad.shape)}, expected {target_shape}" - ) - grad_accum[index] = _accumulate_tensor_grad(grad_accum[index], grad) - - def reduce_to_param(name: str, grad: 
torch.Tensor) -> torch.Tensor | None: - with ( - torch.profiler.record_function("fabric.backward.glue.param_grad_binding"), - self._backend_owner_timing("glue.param_grad_binding"), - ): - index = by_name.get(name) - if index is None: - return None - target_shape = trainable_param_shapes[index] - reduced = grad - while reduced.dim() > len(target_shape): - reduced = reduced.sum(dim=0) - if reduced.dim() != len(target_shape): - raise RuntimeError( - f"Fabric CUDA state/public backward cannot reduce grad for {name} from " - f"{tuple(grad.shape)} to {target_shape}" - ) - reduced = _align_matrix_grad_tail_to_target(reduced, target_shape) - for dim, target_dim in enumerate(target_shape): - if reduced.shape[dim] == target_dim: - continue - if target_dim == 1: - reduced = reduced.sum(dim=dim, keepdim=True) - continue - raise RuntimeError( - f"Fabric CUDA state/public backward cannot reduce grad for {name} from " - f"{tuple(grad.shape)} to {target_shape}" - ) - return reduced - - def add_recurrent_bias_grad(grad: torch.Tensor) -> None: - with ( - torch.profiler.record_function("fabric.backward.glue.param_grad_binding"), - self._backend_owner_timing("glue.param_grad_binding"), - ): - recurrent_bias_2d = grad.squeeze(0) if grad.dim() == 3 else grad - recurrent_cell_idx = self.recurrent_cell_idx - if active_receiver_window is not None and active_receiver_window.active: - recurrent_cell_idx = recurrent_cell_idx.narrow( - 0, - active_receiver_window.start, - active_receiver_window.count, - ) - elif int(recurrent_bias_2d.shape[0]) != int(recurrent_cell_idx.numel()): - population_cell_idx = self._population_indices(population_name) - if int(recurrent_bias_2d.shape[0]) == int(population_cell_idx.numel()): - recurrent_cell_idx = population_cell_idx - full_bias_grad = torch.zeros( - self.slot_embed.shape[0], - self.hidden_size, - device=recurrent_bias_2d.device, - dtype=recurrent_bias_2d.dtype, - ) - full_bias_grad.index_add_( - 0, - recurrent_cell_idx.to(device=recurrent_bias_2d.device), - recurrent_bias_2d, - ) - add("slot_embed", full_bias_grad.matmul(self.cell_bias_proj.weight.detach())) - add("cell_bias_proj.weight", full_bias_grad.t().matmul(self.slot_embed.detach())) - - population_prefix = f"population_modules.{population_name}." 
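The compact-to-full scatter used by these helpers follows one recurring pattern: a gradient computed over the active receiver window (`count` rows) is embedded into a zero tensor shaped like the full parameter (`full_count` rows) via `narrow(...).copy_(...)`. The sketch below assumes only `start`/`count`/`full_count` fields; `ReceiverWindow` is a stand-in for the real `_ReceiverWindowSpec`, not the actual type.

```python
# Minimal sketch of the compact-receiver gradient scatter; ReceiverWindow is a
# hypothetical stand-in for _ReceiverWindowSpec (only start/count/full_count).
from typing import NamedTuple

import torch


class ReceiverWindow(NamedTuple):
    start: int
    count: int
    full_count: int


def scatter_window_grad(
    param: torch.Tensor,
    compact_grad: torch.Tensor,
    window: ReceiverWindow,
) -> torch.Tensor:
    """Embed a [count, ...] gradient into a zero [full_count, ...] gradient."""
    if int(compact_grad.shape[0]) != window.count or int(param.shape[0]) != window.full_count:
        raise ValueError("gradient/parameter rows do not match the receiver window")
    full = torch.zeros_like(param)
    full.narrow(0, window.start, window.count).copy_(compact_grad)
    return full


param = torch.zeros(8, 4)
compact = torch.ones(3, 4)
full_grad = scatter_window_grad(param, compact, ReceiverWindow(start=2, count=3, full_count=8))
assert bool(full_grad[2:5].eq(1).all()) and float(full_grad.sum()) == 12.0
```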
- for materialized_name, grad in materialized_param_grads.items(): - base_name = f"{population_prefix}{materialized_name}_base" - reduced_base_grad = reduce_to_param(base_name, grad) - if reduced_base_grad is not None: - add(base_name, reduced_base_grad) - add(f"{population_prefix}{materialized_name}_delta", grad) - - value_to_cell_grad = static_source_grads.get("value_to_cell_weight") - if value_to_cell_grad is not None: - add("msg_to_cell.weight", value_to_cell_grad.matmul(self.msg_out.weight.detach().t())) - add("msg_out.weight", self.msg_to_cell.weight.detach().t().matmul(value_to_cell_grad)) - - recurrent_cell_bias_grad = static_source_grads.get("recurrent_cell_bias") - if recurrent_cell_bias_grad is not None: - add_recurrent_bias_grad(recurrent_cell_bias_grad) - - for parameter_index in range(len(trainable_param_names)): - grad_accum[parameter_index] = _accumulate_tensor_grad( - grad_accum[parameter_index], - projection_param_grads.get(f"param_{parameter_index}"), - ) - return tuple(grad_accum) - - @staticmethod - def _concat_kv_grads( - grad_k: torch.Tensor | None, - grad_v: torch.Tensor | None, - *, - head_dim: int, - value_dim: int, - ) -> torch.Tensor | None: - if grad_k is None and grad_v is None: - return None - if grad_k is None: - assert grad_v is not None - grad_k = grad_v.new_zeros(*grad_v.shape[:-1], int(head_dim)) - if grad_v is None: - grad_v = grad_k.new_zeros(*grad_k.shape[:-1], int(value_dim)) - return torch.cat((grad_k, grad_v), dim=-1) - - def _named_projection_param_grad_tuple( - self, - *, - named_grads: Mapping[str, torch.Tensor | None], - trainable_params: tuple[torch.Tensor, ...], - trainable_param_names: tuple[str, ...], - ) -> tuple[torch.Tensor | None, ...]: - return tuple(named_grads[name] if name in named_grads else None for name in trainable_param_names) - - def _direct_sender_kv_group_ids( - self, - *, - role: Literal["input", "recurrent"], - active_receiver_window: _ReceiverWindowSpec | None, - device: torch.device, - ) -> torch.Tensor: - if role == "input": - cell_idx = self.input_cell_idx - else: - cell_idx = self.recurrent_cell_idx - if active_receiver_window is not None and active_receiver_window.active: - cell_idx = cell_idx.narrow( - 0, - int(active_receiver_window.start), - int(active_receiver_window.count), - ) - return self.kv_group_id.index_select(0, cell_idx.to(device=self.kv_group_id.device)).to(device=device) - - def _grouped_sender_kv_group_ids( - self, - *, - role: Literal["input", "recurrent"], - device: torch.device, - ) -> torch.Tensor: - group_ids = self.input_sender_kv_group_ids if role == "input" else self.recurrent_sender_kv_group_ids - return group_ids.to(device=device) - - def _sender_kv_projection_param_grads_from_weight_grad( - self, - *, - grad_weight: torch.Tensor, - group_ids: torch.Tensor, - grouped: bool, - ) -> dict[str, torch.Tensor]: - if grad_weight.numel() == 0: - return {} - public_weight = self.public_proj.weight.detach() - kv_weight = torch.cat((self.k_weight.detach(), self.v_weight.detach()), dim=-1) - group_ids = group_ids.to(device=grad_weight.device, dtype=torch.long) - if grouped: - selected_kv = kv_weight.index_select(0, group_ids) - grad_public = torch.einsum("gdm,ghm->dh", selected_kv, grad_weight) - grad_selected_kv = torch.einsum("dh,ghm->gdm", public_weight, grad_weight) - else: - selected_kv = kv_weight.index_select(0, group_ids) - grad_public = torch.einsum("ndm,nhm->dh", selected_kv, grad_weight) - grad_selected_kv = torch.einsum("dh,nhm->ndm", public_weight, grad_weight) - grad_k_selected, 
grad_v_selected = grad_selected_kv.split((self.head_dim, self.value_dim), dim=-1) - grad_k = torch.zeros_like(self.k_weight) - grad_v = torch.zeros_like(self.v_weight) - grad_k.index_add_(0, group_ids, grad_k_selected) - grad_v.index_add_(0, group_ids, grad_v_selected) - return { - "public_proj.weight": grad_public, - "k_weight": grad_k, - "v_weight": grad_v, - } - - def _sender_kv_projection_named_param_grads_from_raw( - self, - raw_grad: SenderKVProjectionRawParamGrad, - ) -> dict[str, torch.Tensor]: - return self._sender_kv_projection_param_grads_from_weight_grad( - grad_weight=raw_grad.grad_weight, - group_ids=raw_grad.group_ids, - grouped=raw_grad.grouped, - ) - - def _sender_kv_projection_param_grad_tuple_from_raw_grads( - self, - raw_grads: tuple[SenderKVProjectionRawParamGrad | None, ...], - *, - trainable_params: tuple[torch.Tensor, ...], - trainable_param_names: tuple[str, ...], - ) -> tuple[torch.Tensor | None, ...]: - raw_buckets: dict[tuple[str, bool, tuple[int, ...]], SenderKVProjectionRawParamGrad] = {} - for raw_grad in raw_grads: - if raw_grad is None: - continue - key = ( - raw_grad.role, - bool(raw_grad.grouped), - tuple(int(dim) for dim in raw_grad.grad_weight.shape), - ) - existing = raw_buckets.get(key) - if existing is None: - raw_buckets[key] = raw_grad - else: - raw_buckets[key] = SenderKVProjectionRawParamGrad( - role=existing.role, - grad_weight=existing.grad_weight + raw_grad.grad_weight, - group_ids=existing.group_ids, - grouped=existing.grouped, - ) - named_accum: dict[str, torch.Tensor | None] = {} - for raw_grad in raw_buckets.values(): - named_grads = self._sender_kv_projection_named_param_grads_from_raw(raw_grad) - for name, grad in named_grads.items(): - named_accum[name] = _accumulate_tensor_grad(named_accum.get(name), grad) - return self._named_projection_param_grad_tuple( - named_grads=named_accum, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - ) - - def _run_backend_sender_kv_projection_backward_raw_phase( - self, - *, - role: Literal["input", "recurrent"], - sender_cells: torch.Tensor, - grad_k: torch.Tensor | None, - grad_v: torch.Tensor | None, - sequence_static_tensors: Mapping[str, object], - active_receiver_window: _ReceiverWindowSpec | None = None, - boundary_requires_grad: bool = True, - owner: str = "public_projection", - ) -> tuple[torch.Tensor | None, SenderKVProjectionRawParamGrad | None]: - grad_output = self._concat_kv_grads( - grad_k, - grad_v, - head_dim=self.head_dim, - value_dim=self.value_dim, - ) - if grad_output is None: - return None, None - projection_static_tensors = ( - self._cached_receiver_window_static_tensors(sequence_static_tensors, active_receiver_window) - if role == "recurrent" - else sequence_static_tensors - ) - direct_key = "input_sender_input_to_kv_weight" if role == "input" else "recurrent_sender_input_to_kv_weight" - grouped_key = "input_group_input_to_kv_weight" if role == "input" else "recurrent_group_input_to_kv_weight" - direct_weight = cast(torch.Tensor | None, projection_static_tensors[direct_key]) - grouped_weight = cast(torch.Tensor | None, projection_static_tensors[grouped_key]) - group_size = int(self._input_sender_kv_group_size if role == "input" else self._recurrent_sender_kv_group_size) - profile_name = _sender_kv_projection_backward_profile_name(owner) - with self._backend_owner_timing(owner): - if grouped_weight is not None and group_size > 1: - with torch.profiler.record_function(profile_name): - grad_sender, grad_weight = 
fabric_grouped_projection_backward_cuda( - sender_cells.detach(), - grouped_weight.detach(), - grad_output.contiguous(), - group_size=group_size, - ) - raw_grad = SenderKVProjectionRawParamGrad( - role=role, - grad_weight=grad_weight, - group_ids=self._grouped_sender_kv_group_ids(role=role, device=grad_weight.device), - grouped=True, - ) - else: - if direct_weight is None: - raise RuntimeError(f"Fabric {owner} backward is missing sender KV projection weight") - with torch.profiler.record_function(profile_name): - grad_sender, grad_weight = receiver_major_affine_backward_cuda( - sender_cells.detach(), - direct_weight.detach(), - grad_output.contiguous(), - block_b=receiver_major_affine_backward_block_b( - batch_size=int(sender_cells.shape[0]), - output_dim=int(grad_output.shape[-1]), - ), - ) - raw_grad = SenderKVProjectionRawParamGrad( - role=role, - grad_weight=grad_weight, - group_ids=self._direct_sender_kv_group_ids( - role=role, - active_receiver_window=active_receiver_window, - device=grad_weight.device, - ), - grouped=False, - ) - if not boundary_requires_grad: - grad_sender = None - return grad_sender, raw_grad - - def _run_backend_sender_kv_projection_backward_phase( - self, - *, - role: Literal["input", "recurrent"], - sender_cells: torch.Tensor, - grad_k: torch.Tensor | None, - grad_v: torch.Tensor | None, - sequence_static_tensors: Mapping[str, object], - trainable_params: tuple[torch.Tensor, ...], - trainable_param_names: tuple[str, ...], - active_receiver_window: _ReceiverWindowSpec | None = None, - boundary_requires_grad: bool = True, - owner: str = "public_projection", - ) -> tuple[torch.Tensor | None, tuple[torch.Tensor | None, ...]]: - grad_sender, raw_grad = self._run_backend_sender_kv_projection_backward_raw_phase( - role=role, - sender_cells=sender_cells, - grad_k=grad_k, - grad_v=grad_v, - sequence_static_tensors=sequence_static_tensors, - active_receiver_window=active_receiver_window, - boundary_requires_grad=boundary_requires_grad, - owner=owner, - ) - param_grads = self._sender_kv_projection_param_grad_tuple_from_raw_grads( - (raw_grad,), - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - ) - return grad_sender, param_grads - - def _run_backend_message_receiver_backward_phase( - self, - *, - grad_msg: torch.Tensor | None, - q_subset: torch.Tensor, - input_k: torch.Tensor, - input_v: torch.Tensor, - recurrent_k: torch.Tensor, - recurrent_v: torch.Tensor, - neighbor_idx: torch.Tensor, - neighbor_valid: torch.Tensor, - edge_distance: torch.Tensor, - edge_delay: torch.Tensor, - local_sender_idx: torch.Tensor, - use_sparse_messages: bool, - ) -> tuple[torch.Tensor | None, tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None]: - if grad_msg is None: - return None, None - with ( - torch.profiler.record_function("fabric.backward.message.receiver"), - self._backend_owner_timing("message.receiver"), - ): - step_flat = self._constant_step_flat( - 1, - batch_size=int(grad_msg.shape[0]), - time_steps=1, - device=grad_msg.device, - dtype=edge_delay.dtype, - ) - if not use_sparse_messages and self._local_message_step_enabled: - grad_q, receiver_max_logit, receiver_sumexp, receiver_weighted_sum = ( - fabric_local_message_partitioned_backward_receiver_cuda( - grad_msg, - q_subset, - input_k, - input_v, - recurrent_k, - recurrent_v, - local_sender_idx, - self.local_distance, - self.local_delay, - step_flat, - distance_scale=float(self.config.distance_logit_scale), - use_delay=bool(self._has_edge_delay), - ) - ) - return grad_q, 
(receiver_max_logit, receiver_sumexp, receiver_weighted_sum) - - grad_q = fabric_sparse_message_partitioned_backward_receiver_cuda( - grad_msg, - q_subset, - input_k, - input_v, - recurrent_k, - recurrent_v, - neighbor_idx, - neighbor_valid, - edge_distance, - edge_delay, - step_flat, - distance_scale=float(self.config.distance_logit_scale), - use_delay=bool(self._has_edge_delay), - ) - return grad_q, None - - def _run_backend_message_sender_backward_phase( - self, - *, - grad_msg: torch.Tensor | None, - q_subset: torch.Tensor, - input_k: torch.Tensor, - input_v: torch.Tensor, - recurrent_k: torch.Tensor, - recurrent_v: torch.Tensor, - neighbor_idx: torch.Tensor, - neighbor_valid: torch.Tensor, - edge_distance: torch.Tensor, - edge_delay: torch.Tensor, - local_receiver_idx_by_sender: torch.Tensor, - receiver_phase_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None, - use_sparse_messages: bool, - ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: - if grad_msg is None: - return None, None, None, None - with ( - torch.profiler.record_function("fabric.backward.message.sender"), - self._backend_owner_timing("message.sender"), - ): - step_flat = self._constant_step_flat( - 1, - batch_size=int(grad_msg.shape[0]), - time_steps=1, - device=grad_msg.device, - dtype=edge_delay.dtype, - ) - if not use_sparse_messages and self._local_message_step_enabled: - assert receiver_phase_cache is not None - grad_input_k, grad_input_v, grad_recurrent_k, grad_recurrent_v = ( - fabric_local_message_partitioned_backward_sender_cuda( - grad_msg, - q_subset, - input_k, - input_v, - recurrent_k, - recurrent_v, - local_receiver_idx_by_sender, - self.local_distance, - self.local_delay, - step_flat, - receiver_phase_cache[0], - receiver_phase_cache[1], - receiver_phase_cache[2], - distance_scale=float(self.config.distance_logit_scale), - use_delay=bool(self._has_edge_delay), - ) - ) - return grad_input_k, grad_input_v, grad_recurrent_k, grad_recurrent_v - - ( - grad_input_k, - grad_input_v, - grad_recurrent_k, - grad_recurrent_v, - ) = fabric_sparse_message_partitioned_backward_sender_cuda( - grad_msg, - q_subset, - input_k, - input_v, - recurrent_k, - recurrent_v, - neighbor_idx, - neighbor_valid, - edge_distance, - edge_delay, - step_flat, - distance_scale=float(self.config.distance_logit_scale), - use_delay=bool(self._has_edge_delay), - ) - return grad_input_k, grad_input_v, grad_recurrent_k, grad_recurrent_v - - def _run_backend_message_backward_phase( - self, - *, - grad_msg: torch.Tensor | None, - q_subset: torch.Tensor, - input_k: torch.Tensor, - input_v: torch.Tensor, - recurrent_k: torch.Tensor, - recurrent_v: torch.Tensor, - neighbor_idx: torch.Tensor, - neighbor_valid: torch.Tensor, - edge_distance: torch.Tensor, - edge_delay: torch.Tensor, - local_sender_idx: torch.Tensor, - local_receiver_idx_by_sender: torch.Tensor, - use_sparse_messages: bool, - ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: - if grad_msg is None: - return None, None, None, None, None - if not use_sparse_messages and self._local_message_step_enabled: - with ( - torch.profiler.record_function("fabric.backward.message.fused_receiver_sender"), - self._backend_owner_timing("message.fused_receiver_sender"), - ): - step_flat = self._constant_step_flat( - 1, - batch_size=int(grad_msg.shape[0]), - time_steps=1, - device=grad_msg.device, - dtype=edge_delay.dtype, - ) - return 
fabric_local_message_partitioned_backward_fused_cuda( - grad_msg, - q_subset, - input_k, - input_v, - recurrent_k, - recurrent_v, - local_sender_idx, - self.local_distance, - self.local_delay, - step_flat, - distance_scale=float(self.config.distance_logit_scale), - use_delay=bool(self._has_edge_delay), - ) - grad_q, receiver_cache = self._run_backend_message_receiver_backward_phase( - grad_msg=grad_msg, - q_subset=q_subset, - input_k=input_k, - input_v=input_v, - recurrent_k=recurrent_k, - recurrent_v=recurrent_v, - neighbor_idx=neighbor_idx, - neighbor_valid=neighbor_valid, - edge_distance=edge_distance, - edge_delay=edge_delay, - local_sender_idx=local_sender_idx, - use_sparse_messages=use_sparse_messages, - ) - grad_input_k, grad_input_v, grad_recurrent_k, grad_recurrent_v = ( - self._run_backend_message_sender_backward_phase( - grad_msg=grad_msg, - q_subset=q_subset, - input_k=input_k, - input_v=input_v, - recurrent_k=recurrent_k, - recurrent_v=recurrent_v, - neighbor_idx=neighbor_idx, - neighbor_valid=neighbor_valid, - edge_distance=edge_distance, - edge_delay=edge_delay, - local_receiver_idx_by_sender=local_receiver_idx_by_sender, - receiver_phase_cache=receiver_cache, - use_sparse_messages=use_sparse_messages, - ) - ) - return grad_q, grad_input_k, grad_input_v, grad_recurrent_k, grad_recurrent_v - - def _run_backend_output_projection_backward_phase( - self, - *, - output_msg: torch.Tensor, - grad_output_cells: torch.Tensor | None, - sequence_static_tensors: Mapping[str, object], - trainable_params: tuple[torch.Tensor, ...], - trainable_param_names: tuple[str, ...], - ) -> tuple[torch.Tensor | None, tuple[torch.Tensor | None, ...]]: - if grad_output_cells is None: - return None, tuple(None for _ in trainable_params) - with self._backend_owner_timing("readout"): - with torch.profiler.record_function("fabric.backward.readout"): - value_to_output_weight = cast(torch.Tensor, sequence_static_tensors["value_to_output_weight"]) - expanded_grad_output = grad_output_cells - if ( - grad_output_cells.dim() == output_msg.dim() - and int(grad_output_cells.shape[1]) == 1 - and int(output_msg.shape[1]) > 1 - ): - if str(self.config.readout_pool) != "mean": - raise RuntimeError( - "Fabric readout backward received a pooled gradient for a non-mean readout boundary" - ) - expanded_grad_output = grad_output_cells.expand( - -1, - int(output_msg.shape[1]), - -1, - ) / max(1, int(output_msg.shape[1])) - grad_output_msg = torch.bmm( - expanded_grad_output.transpose(0, 1), - value_to_output_weight.transpose(1, 2), - ).transpose(0, 1) - grad_value_to_output_weight = torch.bmm( - output_msg.detach().transpose(0, 1).transpose(1, 2), - expanded_grad_output.contiguous().transpose(0, 1), - ) - named_grads = { - "msg_out.weight": torch.einsum( - "pdh,pvh->dv", - self.output_cell_weight.detach(), - grad_value_to_output_weight, - ), - "output_cell_weight": torch.einsum( - "dv,pvh->pdh", - self.msg_out.weight.detach(), - grad_value_to_output_weight, - ), - "output_cell_bias": expanded_grad_output.sum(dim=0), - } - return grad_output_msg, self._named_projection_param_grad_tuple( - named_grads=named_grads, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - ) - - def _run_backend_query_param_backward_phase( - self, - *, - grad_recurrent_q: torch.Tensor | None, - grad_output_q: torch.Tensor | None, - sequence_static_tensors: Mapping[str, object], - trainable_params: tuple[torch.Tensor, ...], - trainable_param_names: tuple[str, ...], - device: torch.device, - dtype: torch.dtype, - 
active_receiver_window: _ReceiverWindowSpec | None = None, - ) -> tuple[torch.Tensor | None, ...]: - if grad_recurrent_q is None and grad_output_q is None: - return tuple(None for _ in trainable_params) - with self._backend_owner_timing("message.query_param"): - with torch.profiler.record_function("fabric.backward.message.query_param"): - grad_q_full = torch.zeros( - int(self.coords.shape[0]), - self.head_dim, - device=device, - dtype=dtype, - ) - if grad_recurrent_q is not None: - recurrent_cell_idx = self.recurrent_cell_idx - if active_receiver_window is not None and active_receiver_window.active: - recurrent_cell_idx = recurrent_cell_idx.narrow( - 0, - int(active_receiver_window.start), - int(active_receiver_window.count), - ) - grad_q_full.index_add_( - 0, - recurrent_cell_idx.to(device=device), - grad_recurrent_q, - ) - if grad_output_q is not None: - grad_q_full.index_add_( - 0, - self.output_cell_idx.to(device=device), - grad_output_q, - ) - named_grads = { - "slot_embed": grad_q_full.matmul(self.q_proj.weight.detach()), - "q_proj.weight": grad_q_full.t().matmul(self.slot_embed.detach()), - } - return self._named_projection_param_grad_tuple( - named_grads=named_grads, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - ) - - def _run_backend_state_public_backward_phase( - self, - *, - population_name: str | None = None, - recurrent_msg: torch.Tensor, - recurrent_hidden_tape: torch.Tensor | None, - packed_state_before: Any, - population_reset_step: torch.Tensor | None, - grad_next_packed_state: Any, - grad_recurrent_hidden: torch.Tensor | None, - grad_recurrent_k: torch.Tensor | None, - grad_recurrent_v: torch.Tensor | None, - trainable_params: tuple[torch.Tensor, ...], - trainable_param_names: tuple[str, ...], - trainable_param_shapes: tuple[tuple[int, ...], ...] 
| None = None, - sequence_static_tensors: Mapping[str, object], - param_static_tensors: Mapping[str, object] | None = None, - transition_backward_tape: Any | None, - need_grad_packed_state_before: bool, - device: torch.device, - dtype: torch.dtype, - active_receiver_window: _ReceiverWindowSpec | None = None, - ) -> tuple[torch.Tensor | None, Any, tuple[torch.Tensor | None, ...]]: - population_name = self._full_recurrent_population_name if population_name is None else population_name - if population_name is None: - raise RuntimeError("Fabric CUDA state/public backward requires a recurrent population") - backward_attribution_mode = os.environ.get(_BACKWARD_ATTRIBUTION_MODE_ENV) - state_public_profile_name = self._state_public_backward_profile_name_for_population(population_name) - split_recurrent_kv_projection = backward_attribution_mode != "state_public_output_probe" and ( - grad_recurrent_k is not None or grad_recurrent_v is not None - ) - if backward_attribution_mode not in {"state_public_output_probe", "state_public_state_probe"}: - transition_static_tensors = self._cached_receiver_window_static_tensors( - sequence_static_tensors, - active_receiver_window, - ) - - if not transition_static_tensors: - transition_static_tensors = self._materialize_inference_static_tensors( - device=device, - dtype=dtype, - include_backend_cell_tensors=False, - ) - - projection_grad_result_map: dict[str, torch.Tensor | None] = {} - combined_grad_recurrent_hidden = grad_recurrent_hidden - if split_recurrent_kv_projection: - if recurrent_hidden_tape is None: - raise RuntimeError("Fabric CUDA recurrent K/V projection backward requires recurrent hidden tape") - projection_grad_recurrent_hidden, projection_param_grads = ( - self._run_backend_sender_kv_projection_backward_phase( - role="recurrent", - sender_cells=recurrent_hidden_tape, - grad_k=grad_recurrent_k, - grad_v=grad_recurrent_v, - sequence_static_tensors=sequence_static_tensors, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - active_receiver_window=active_receiver_window, - owner="grouped_projection", - ) - ) - combined_grad_recurrent_hidden = _accumulate_tensor_grad( - combined_grad_recurrent_hidden, - projection_grad_recurrent_hidden, - ) - projection_grad_result_map = { - f"param_{index}": grad for index, grad in enumerate(projection_param_grads) - } - - with ( - torch.no_grad(), - torch.profiler.record_function(state_public_profile_name), - ): - transition_backward = self._lower_backend_population_transition_backward_shared( - population_name=population_name, - recurrent_msg=recurrent_msg, - packed_state_before=packed_state_before, - population_reset_step=population_reset_step, - static_tensors=transition_static_tensors, - grad_next_packed_state=grad_next_packed_state, - grad_recurrent_hidden=combined_grad_recurrent_hidden, - need_grad_packed_state_before=need_grad_packed_state_before, - forward_tape=transition_backward_tape, - ) - return ( - cast(torch.Tensor | None, transition_backward.grad_recurrent_msg), - transition_backward.grad_packed_state_before, - self._state_public_explicit_param_grad_tuple( - population_name=population_name, - materialized_param_grads=transition_backward.materialized_param_grads, - static_source_grads=transition_backward.static_source_grads, - projection_param_grads=projection_grad_result_map, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - trainable_param_shapes=trainable_param_shapes, - active_receiver_window=active_receiver_window, - ), - ) - - 
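The probe/replay branch that follows detaches the carried inputs, recomputes the step under `torch.enable_grad()`, and then resolves gradients by input name. A simplified stand-in for that named-autograd step is sketched below; `run_named_autograd_phase` here is illustrative and drops the profiling hooks of the real `_run_named_autograd_phase`.

```python
# Simplified stand-in for the replay/named-autograd phase: given (output, grad)
# pairs and named inputs, return a name -> grad map via torch.autograd.grad.
from collections.abc import Sequence

import torch


def run_named_autograd_phase(
    outputs: Sequence[tuple[torch.Tensor, torch.Tensor | None]],
    inputs: Sequence[tuple[str, torch.Tensor]],
    retain_graph: bool = False,
) -> dict[str, torch.Tensor | None]:
    live = [(out, grad) for out, grad in outputs if grad is not None]
    if not live:
        return {name: None for name, _ in inputs}
    grads = torch.autograd.grad(
        [out for out, _ in live],
        [tensor for _, tensor in inputs],
        grad_outputs=[grad for _, grad in live],
        retain_graph=retain_graph,
        allow_unused=True,
    )
    return {name: grad for (name, _), grad in zip(inputs, grads)}


# Toy replay: detach the carried state, redo one step, pull grads by name.
state = torch.randn(2, 3)
detached_state = state.detach().requires_grad_(True)
with torch.enable_grad():
    next_state = detached_state.tanh()
grad_map = run_named_autograd_phase(
    outputs=[(next_state, torch.ones_like(next_state))],
    inputs=[("packed_state_0", detached_state)],
)
assert grad_map["packed_state_0"] is not None
```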
packed_state_keys, packed_state_tensors = _flatten_backend_packed_state(packed_state_before) - detached_recurrent_msg = recurrent_msg.detach().requires_grad_(True) - detached_packed_state_tensors = tuple( - tensor.detach().requires_grad_(True) if tensor.is_floating_point() else tensor.detach() - for tensor in packed_state_tensors - ) - detached_packed_state = _unflatten_backend_packed_state(packed_state_keys, detached_packed_state_tensors) - with torch.enable_grad(): - step_static_tensors = self._materialize_inference_static_tensors( - device=device, - dtype=dtype, - include_backend_cell_tensors=False, - ) - replay_next_packed_state, replay_recurrent_hidden, replay_recurrent_k, replay_recurrent_v = ( - self._lower_backend_population_transition_shared( - population_name=population_name, - recurrent_msg=detached_recurrent_msg, - packed_state_before=detached_packed_state, - population_reset_step=population_reset_step, - static_tensors=step_static_tensors, - materialize_recurrent_kv=not split_recurrent_kv_projection, - ) - ) - - recurrent_hidden_output_specs: list[tuple[torch.Tensor, torch.Tensor | None]] = [ - (replay_recurrent_hidden, grad_recurrent_hidden) - ] - recurrent_kv_output_specs: list[tuple[torch.Tensor, torch.Tensor | None]] = [ - (replay_recurrent_k, grad_recurrent_k), - (replay_recurrent_v, grad_recurrent_v), - ] - next_state_output_specs: list[tuple[torch.Tensor, torch.Tensor | None]] = [] - if packed_state_keys is None: - if torch.is_tensor(replay_next_packed_state): - next_state_output_specs.append( - (replay_next_packed_state, cast(torch.Tensor | None, grad_next_packed_state)) - ) - else: - assert isinstance(replay_next_packed_state, (dict, TensorDictBase)) - assert grad_next_packed_state is None or isinstance(grad_next_packed_state, (dict, TensorDictBase)) - for key in packed_state_keys: - next_state_output_specs.append( - ( - replay_next_packed_state[key], - None if grad_next_packed_state is None else grad_next_packed_state.get(key), - ) - ) - output_specs = recurrent_hidden_output_specs + recurrent_kv_output_specs + next_state_output_specs - - input_specs: list[tuple[str, torch.Tensor]] = [("recurrent_msg", detached_recurrent_msg)] - if packed_state_keys is None: - input_specs.append(("packed_state_0", cast(torch.Tensor, detached_packed_state))) - else: - assert isinstance(detached_packed_state, (dict, TensorDictBase)) - for index, key in enumerate(packed_state_keys): - input_specs.append((f"packed_state_{index}", detached_packed_state[key])) - input_specs.extend( - _param_input_specs( - _state_public_param_specs( - trainable_param_names, - trainable_params, - include_sender_kv_projection=not split_recurrent_kv_projection, - ) - ) - ) - if backward_attribution_mode == "state_public_output_probe": - grad_result_map: dict[str, torch.Tensor | None] = {name: None for name, _tensor in input_specs} - for profile_name, probe_output_specs in ( - (f"{state_public_profile_name}.recurrent_hidden", recurrent_hidden_output_specs), - ("fabric.backward.grouped_projection.recurrent_kv", recurrent_kv_output_specs), - (f"{state_public_profile_name}.next_state", next_state_output_specs), - ): - probe_grad_result_map = self._run_named_autograd_phase( - outputs=probe_output_specs, - inputs=input_specs, - profile_name=profile_name, - retain_graph=True, - ) - for name, new_grad in probe_grad_result_map.items(): - if new_grad is None: - continue - grad_result_map[name] = _accumulate_tensor_grad(grad_result_map[name], new_grad) - else: - projection_grad_result_map: dict[str, torch.Tensor | None] 
= {} - if split_recurrent_kv_projection: - detached_recurrent_hidden = replay_recurrent_hidden.detach().requires_grad_(True) - with ( - torch.enable_grad(), - torch.profiler.record_function("fabric.backward.grouped_projection.recompute"), - ): - detached_replay_recurrent_k, detached_replay_recurrent_v = self._project_sender_kv_from_cells_step( - detached_recurrent_hidden, - sender_input_to_kv_weight=cast( - torch.Tensor | None, - step_static_tensors["recurrent_sender_input_to_kv_weight"], - ), - grouped_sender_input_to_kv_weight=cast( - torch.Tensor | None, - step_static_tensors["recurrent_group_input_to_kv_weight"], - ), - sender_group_size=self._recurrent_sender_kv_group_size, - ) - projection_input_specs: list[tuple[str, torch.Tensor]] = [ - ("recurrent_hidden", detached_recurrent_hidden) - ] - projection_input_specs.extend( - _param_input_specs(_sender_kv_projection_param_specs(trainable_param_names, trainable_params)) - ) - with self._backend_owner_timing("grouped_projection"): - projection_grad_result_map = self._run_named_autograd_phase( - outputs=[ - (detached_replay_recurrent_k, grad_recurrent_k), - (detached_replay_recurrent_v, grad_recurrent_v), - ], - inputs=projection_input_specs, - profile_name="fabric.backward.grouped_projection", - ) - recurrent_hidden_output_specs = [ - ( - replay_recurrent_hidden, - _accumulate_tensor_grad( - grad_recurrent_hidden, - cast(torch.Tensor | None, projection_grad_result_map.get("recurrent_hidden")), - ), - ) - ] - output_specs = recurrent_hidden_output_specs + next_state_output_specs - if backward_attribution_mode == "state_public_state_probe": - grad_result_map = {name: None for name, _tensor in input_specs} - active_probe_specs = [ - (f"{state_public_profile_name}.recurrent_hidden", recurrent_hidden_output_specs), - (f"{state_public_profile_name}.next_state", next_state_output_specs), - ] - active_probe_specs = [(name, specs) for name, specs in active_probe_specs if specs] - for probe_index, (profile_name, probe_output_specs) in enumerate(active_probe_specs): - probe_grad_result_map = self._run_named_autograd_phase( - outputs=probe_output_specs, - inputs=input_specs, - profile_name=profile_name, - retain_graph=probe_index < len(active_probe_specs) - 1, - ) - for name, new_grad in probe_grad_result_map.items(): - if new_grad is None: - continue - grad_result_map[name] = _accumulate_tensor_grad(grad_result_map.get(name), new_grad) - else: - with self._backend_owner_timing(state_public_profile_name.removeprefix("fabric.backward.")): - grad_result_map = self._run_named_autograd_phase( - outputs=output_specs, - inputs=input_specs, - profile_name=state_public_profile_name, - ) - for parameter_index in range(len(trainable_params)): - parameter_key = f"param_{parameter_index}" - grad_result_map[parameter_key] = _accumulate_tensor_grad( - grad_result_map.get(parameter_key), - cast(torch.Tensor | None, projection_grad_result_map.get(parameter_key)), - ) - if packed_state_keys is None: - grad_packed_state_before = grad_result_map.get("packed_state_0") - else: - grad_packed_state_before = { - key: cast(torch.Tensor | None, grad_result_map.get(f"packed_state_{index}")) - for index, key in enumerate(packed_state_keys) - } - return ( - cast(torch.Tensor | None, grad_result_map.get("recurrent_msg")), - grad_packed_state_before, - _param_grad_tuple(len(trainable_params), grad_result_map), - ) - - def _run_backend_boundary_public_backward_phase( - self, - *, - boundary_step: torch.Tensor, - grad_input_k: torch.Tensor | None, - grad_input_v: torch.Tensor | 
None, - sequence_static_tensors: Mapping[str, object], - trainable_params: tuple[torch.Tensor, ...], - trainable_param_names: tuple[str, ...], - device: torch.device, - dtype: torch.dtype, - boundary_requires_grad: bool = True, - ) -> tuple[torch.Tensor | None, tuple[torch.Tensor | None, ...]]: - return self._run_backend_sender_kv_projection_backward_phase( - role="input", - sender_cells=boundary_step, - grad_k=grad_input_k, - grad_v=grad_input_v, - sequence_static_tensors=sequence_static_tensors, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - boundary_requires_grad=boundary_requires_grad, - owner="public_projection", - ) - - def _run_backend_initial_recurrent_backward_phase( - self, - *, - hidden_before: torch.Tensor | None, - initial_recurrent_k_before: torch.Tensor | None, - initial_recurrent_v_before: torch.Tensor | None, - population_reset_step: torch.Tensor | None, - grad_resolved_recurrent_k: torch.Tensor | None, - grad_resolved_recurrent_v: torch.Tensor | None, - sequence_static_tensors: Mapping[str, object], - trainable_params: tuple[torch.Tensor, ...], - trainable_param_names: tuple[str, ...], - active_receiver_window: _ReceiverWindowSpec | None, - device: torch.device, - dtype: torch.dtype, - ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, tuple[torch.Tensor | None, ...]]: - grad_hidden_before, grad_initial_k, grad_initial_v, raw_grad = ( - self._run_backend_initial_recurrent_backward_raw_phase( - hidden_before=hidden_before, - initial_recurrent_k_before=initial_recurrent_k_before, - initial_recurrent_v_before=initial_recurrent_v_before, - population_reset_step=population_reset_step, - grad_resolved_recurrent_k=grad_resolved_recurrent_k, - grad_resolved_recurrent_v=grad_resolved_recurrent_v, - sequence_static_tensors=sequence_static_tensors, - active_receiver_window=active_receiver_window, - device=device, - dtype=dtype, - ) - ) - param_grads = self._sender_kv_projection_param_grad_tuple_from_raw_grads( - (raw_grad,), - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - ) - return grad_hidden_before, grad_initial_k, grad_initial_v, param_grads - - def _run_backend_initial_recurrent_backward_raw_phase( - self, - *, - hidden_before: torch.Tensor | None, - initial_recurrent_k_before: torch.Tensor | None, - initial_recurrent_v_before: torch.Tensor | None, - population_reset_step: torch.Tensor | None, - grad_resolved_recurrent_k: torch.Tensor | None, - grad_resolved_recurrent_v: torch.Tensor | None, - sequence_static_tensors: Mapping[str, object], - active_receiver_window: _ReceiverWindowSpec | None, - device: torch.device, - dtype: torch.dtype, - ) -> tuple[ - torch.Tensor | None, - torch.Tensor | None, - torch.Tensor | None, - SenderKVProjectionRawParamGrad | None, - ]: - if grad_resolved_recurrent_k is None and grad_resolved_recurrent_v is None: - return None, None, None, None - if initial_recurrent_k_before is not None and initial_recurrent_v_before is not None: - if population_reset_step is None: - return ( - None, - grad_resolved_recurrent_k, - grad_resolved_recurrent_v, - None, - ) - batch_size = int(initial_recurrent_k_before.shape[0]) - reset_rows = torch.as_tensor(population_reset_step, device=device, dtype=torch.bool).view(batch_size) - grad_initial_recurrent_k_before = ( - None - if grad_resolved_recurrent_k is None - else reset_backend_tensors_rows_cuda((grad_resolved_recurrent_k,), reset_rows)[0] - ) - grad_initial_recurrent_v_before = ( - None - if 
grad_resolved_recurrent_v is None - else reset_backend_tensors_rows_cuda((grad_resolved_recurrent_v,), reset_rows)[0] - ) - return ( - None, - cast(torch.Tensor | None, grad_initial_recurrent_k_before), - cast(torch.Tensor | None, grad_initial_recurrent_v_before), - None, - ) - if hidden_before is None: - raise RuntimeError( - "Fabric recurrent-carry thin reverse requires hidden_before when recurrent K/V are not explicit" - ) - detached_hidden_before = hidden_before.detach().requires_grad_(True) - grad_hidden_before, raw_grad = self._run_backend_sender_kv_projection_backward_raw_phase( - role="recurrent", - sender_cells=detached_hidden_before, - grad_k=grad_resolved_recurrent_k, - grad_v=grad_resolved_recurrent_v, - sequence_static_tensors=sequence_static_tensors, - active_receiver_window=active_receiver_window, - owner="glue.initial_recurrent", - ) - if population_reset_step is not None and grad_hidden_before is not None: - reset_rows = torch.as_tensor(population_reset_step, device=device, dtype=torch.bool).view( - hidden_before.shape[0] - ) - grad_hidden_before = reset_backend_tensors_rows_cuda((grad_hidden_before,), reset_rows)[0] - return (grad_hidden_before, None, None, raw_grad) - - def _run_backend_sequence_surface_backward_batch_tiled( - self, - *, - batch_tile_len: int, - active_receiver_window: _ReceiverWindowSpec | None, - boundary_seq: torch.Tensor, - projected_boundary_source_seq: torch.Tensor | None, - projected_boundary_weight: torch.Tensor | None, - projected_boundary_bias: torch.Tensor | None, - packed_state: Any, - initial_hidden: torch.Tensor, - initial_recurrent_k: torch.Tensor | None, - initial_recurrent_v: torch.Tensor | None, - initial_state_is_fresh: bool, - population_resets: torch.Tensor | None, - planned_backend_execution: PlannedFabricExecution, - planned_backend_backward_execution: PlannedFabricBackwardExecution, - grad_output_seq: torch.Tensor | None, - grad_next_packed_state: Any, - grad_recurrent_hidden: torch.Tensor | None, - grad_recurrent_k: torch.Tensor | None, - grad_recurrent_v: torch.Tensor | None, - grad_input_k_last: torch.Tensor | None, - grad_input_v_last: torch.Tensor | None, - trainable_params: tuple[torch.Tensor, ...], - trainable_param_names: tuple[str, ...], - replay_static_tensors: dict[str, object], - output_boundary: Literal["sequence", "terminal"] = "sequence", - forward_carry_checkpoints: Any | None = None, - ) -> tuple[dict[str, torch.Tensor | None], tuple[torch.Tensor | None, ...]]: - merged_grad_map: dict[str, torch.Tensor | None] = {"population_resets": None} - grad_param_accum: list[torch.Tensor | None] = [None] * len(trainable_params) - batch_size = int(boundary_seq.shape[0]) - - def merge_tile_grad_map( - tile_grad_map: Mapping[str, torch.Tensor | None], - *, - start: int, - end: int, - ) -> None: - for key, tile_grad in tile_grad_map.items(): - if key == "population_resets": - merged_grad_map[key] = None - continue - if tile_grad is None: - continue - if key in {"projected_boundary_weight", "projected_boundary_bias"}: - merged_grad_map[key] = _accumulate_owned_tensor_grad( - cast(torch.Tensor | None, merged_grad_map.get(key)), - tile_grad, - ) - continue - merged = merged_grad_map.get(key) - if merged is None: - merged_shape = (batch_size, *tuple(tile_grad.shape[1:])) - merged = tile_grad.new_zeros(merged_shape) - merged_grad_map[key] = merged - merged[start:end].copy_(tile_grad) - - for start in range(0, int(boundary_seq.shape[0]), int(batch_tile_len)): - end = min(start + int(batch_tile_len), int(boundary_seq.shape[0])) - with 
( - torch.profiler.record_function("fabric.backward.batch_tile"), - self._backend_owner_timing("batch_tile"), - ): - tile_grad_map, tile_param_grads = self._run_backend_sequence_surface_backward_once( - boundary_seq=boundary_seq[start:end], - projected_boundary_source_seq=_slice_batch_tensor(projected_boundary_source_seq, start, end), - projected_boundary_weight=projected_boundary_weight, - projected_boundary_bias=projected_boundary_bias, - packed_state=_slice_batch_tree(packed_state, start, end), - initial_hidden=initial_hidden[start:end], - initial_recurrent_k=_slice_batch_tensor(initial_recurrent_k, start, end), - initial_recurrent_v=_slice_batch_tensor(initial_recurrent_v, start, end), - initial_state_is_fresh=initial_state_is_fresh, - population_resets=_slice_batch_tensor(population_resets, start, end), - planned_backend_execution=planned_backend_execution, - planned_backend_backward_execution=planned_backend_backward_execution, - grad_output_seq=_slice_batch_tensor(grad_output_seq, start, end), - grad_next_packed_state=_slice_batch_tree(grad_next_packed_state, start, end), - grad_recurrent_hidden=_slice_batch_tensor(grad_recurrent_hidden, start, end), - grad_recurrent_k=_slice_batch_tensor(grad_recurrent_k, start, end), - grad_recurrent_v=_slice_batch_tensor(grad_recurrent_v, start, end), - grad_input_k_last=_slice_batch_tensor(grad_input_k_last, start, end), - grad_input_v_last=_slice_batch_tensor(grad_input_v_last, start, end), - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - replay_static_tensors=replay_static_tensors, - output_boundary=output_boundary, - forward_carry_checkpoints=_slice_forward_carry_checkpoints( - forward_carry_checkpoints, - start, - end, - ), - allow_batch_tiling=False, - ) - merge_tile_grad_map(tile_grad_map, start=start, end=end) - for parameter_index, grad_param in enumerate(tile_param_grads): - grad_param_accum[parameter_index] = _accumulate_owned_tensor_grad( - grad_param_accum[parameter_index], - grad_param, - ) - del tile_grad_map, tile_param_grads - self._receiver_window_static_tensor_dict_cache = {} - return ( - merged_grad_map, - self._finalize_backward_param_grads( - trainable_params=trainable_params, - grad_param_accum=grad_param_accum, - active_receiver_window=active_receiver_window, - ), - ) - - def _append_backend_backward_runtime_metadata(self) -> None: - record = getattr(self, "_last_backend_execution", None) - if record is None: - return - additions: list[str] = [] - recompute_artifact_window_len = getattr(self, "_last_backend_recompute_artifact_window_len", None) - if recompute_artifact_window_len is not None: - additions.append(f"training_recompute_artifact_window:t={recompute_artifact_window_len}") - recompute_artifact_window_reason = getattr(self, "_last_backend_recompute_artifact_window_reason", None) - if recompute_artifact_window_reason: - additions.append(f"training_recompute_artifact_window_reason:{recompute_artifact_window_reason}") - recompute_checkpoint_stride = getattr(self, "_last_backend_recompute_checkpoint_stride", None) - if recompute_checkpoint_stride is not None: - additions.append(f"training_recompute_checkpoint_stride:t={recompute_checkpoint_stride}") - recompute_checkpoint_count = getattr(self, "_last_backend_recompute_checkpoint_count", None) - if recompute_checkpoint_count is not None: - additions.append(f"training_recompute_checkpoint_count:n={recompute_checkpoint_count}") - recompute_checkpoint_reason = getattr(self, "_last_backend_recompute_checkpoint_reason", None) - if 
recompute_checkpoint_reason: - additions.append(f"training_recompute_checkpoint_reason:{recompute_checkpoint_reason}") - recompute_checkpoint_source = getattr(self, "_last_backend_recompute_checkpoint_source", None) - if recompute_checkpoint_source: - additions.append(f"training_recompute_checkpoint_source:{recompute_checkpoint_source}") - recompute_checkpoint_hidden_carry = getattr( - self, - "_last_backend_recompute_checkpoint_hidden_carry_mode", - None, - ) - if recompute_checkpoint_hidden_carry: - additions.append(f"training_recompute_checkpoint_hidden_carry:{recompute_checkpoint_hidden_carry}") - recompute_checkpoint_artifact_cache = getattr( - self, - "_last_backend_recompute_checkpoint_artifact_cache_mode", - None, - ) - if recompute_checkpoint_artifact_cache: - additions.append(f"training_recompute_checkpoint_artifact_cache:{recompute_checkpoint_artifact_cache}") - recompute_predecessor_cache_mode = getattr(self, "_last_backend_recompute_predecessor_cache_mode", None) - if recompute_predecessor_cache_mode: - additions.append(f"training_recompute_predecessor_cache:{recompute_predecessor_cache_mode}") - recompute_transition_tape_mode = getattr(self, "_last_backend_recompute_transition_tape_mode", None) - if recompute_transition_tape_mode: - additions.append(f"training_recompute_transition_tape:{recompute_transition_tape_mode}") - recompute_transition_tape_reason = getattr(self, "_last_backend_recompute_transition_tape_reason", None) - if recompute_transition_tape_reason: - additions.append(f"training_recompute_transition_tape_reason:{recompute_transition_tape_reason}") - recompute_payload_max_bytes = getattr(self, "_last_backend_recompute_payload_max_bytes", None) - if isinstance(recompute_payload_max_bytes, Mapping) and recompute_payload_max_bytes: - additions.append( - f"training_recompute_payload_max_bytes:{_format_payload_bytes(recompute_payload_max_bytes)}" - ) - recompute_payload_window_len = getattr(self, "_last_backend_recompute_payload_max_window_len", None) - if recompute_payload_window_len is not None: - additions.append(f"training_recompute_payload_max_window:t={int(recompute_payload_window_len)}") - recompute_payload_mode = getattr(self, "_last_backend_recompute_payload_max_mode", None) - if recompute_payload_mode: - additions.append(f"training_recompute_payload_max_mode:{recompute_payload_mode}") - recompute_payload_samples = getattr(self, "_last_backend_recompute_payload_sample_count", None) - if recompute_payload_samples is not None: - additions.append(f"training_recompute_payload_samples:n={int(recompute_payload_samples)}") - public_kv_materialization_mode = getattr( - self, - "_last_backend_recompute_public_kv_materialization_mode", - None, - ) - if public_kv_materialization_mode: - additions.append(f"training_recompute_public_kv:{public_kv_materialization_mode}") - target_state_materialization_mode = getattr( - self, - "_last_backend_recompute_target_state_materialization_mode", - None, - ) - if target_state_materialization_mode: - additions.append(f"training_recompute_target_state:{target_state_materialization_mode}") - active_receiver_window = getattr(self, "_last_backend_backward_active_receiver_window", None) - if active_receiver_window: - additions.append(f"backward_active_receiver_window:{active_receiver_window}") - active_receiver_window_reason = getattr(self, "_last_backend_backward_active_receiver_window_reason", None) - if active_receiver_window_reason: - additions.append(f"backward_active_receiver_window_reason:{active_receiver_window_reason}") - 
backward_batch_tile_len = getattr(self, "_last_backend_backward_batch_tile_len", None) - if backward_batch_tile_len is not None: - additions.append(f"training_backward_batch_tile:b={backward_batch_tile_len}") - backward_batch_tile_reason = getattr(self, "_last_backend_backward_batch_tile_reason", None) - if backward_batch_tile_reason: - additions.append(f"training_backward_batch_tile_reason:{backward_batch_tile_reason}") - if not additions: - return - existing = tuple(record.backward_saved_launch_counts) - merged = existing + tuple(addition for addition in additions if addition not in existing) - if merged == existing: - return - self._last_backend_execution = replace(record, backward_saved_launch_counts=merged) - - def _run_backend_sequence_surface_backward_once( - self, - *, - boundary_seq: torch.Tensor, - projected_boundary_source_seq: torch.Tensor | None = None, - projected_boundary_weight: torch.Tensor | None = None, - projected_boundary_bias: torch.Tensor | None = None, - packed_state: Any, - initial_hidden: torch.Tensor, - initial_recurrent_k: torch.Tensor | None, - initial_recurrent_v: torch.Tensor | None, - initial_state_is_fresh: bool, - population_resets: torch.Tensor | None, - planned_backend_execution: PlannedFabricExecution, - planned_backend_backward_execution: PlannedFabricBackwardExecution, - grad_output_seq: torch.Tensor | None, - grad_next_packed_state: Any, - grad_recurrent_hidden: torch.Tensor | None, - grad_recurrent_k: torch.Tensor | None, - grad_recurrent_v: torch.Tensor | None, - grad_input_k_last: torch.Tensor | None, - grad_input_v_last: torch.Tensor | None, - trainable_params: tuple[torch.Tensor, ...], - trainable_param_names: tuple[str, ...], - replay_static_tensors: dict[str, object], - output_boundary: Literal["sequence", "terminal"] = "sequence", - forward_carry_checkpoints: Any | None = None, - allow_batch_tiling: bool = True, - ) -> tuple[dict[str, torch.Tensor | None], tuple[torch.Tensor | None, ...]]: - self._receiver_window_static_tensor_dict_cache = {} - if output_boundary not in {"sequence", "terminal"}: - raise ValueError(f"Unsupported Fabric backward output boundary {output_boundary!r}") - projected_boundary_active = ( - projected_boundary_source_seq is not None - or projected_boundary_weight is not None - or projected_boundary_bias is not None - ) - if projected_boundary_active and (projected_boundary_source_seq is None or projected_boundary_weight is None): - raise RuntimeError("Projected Fabric boundary backward requires both source sequence and projection weight") - if projected_boundary_active: - assert projected_boundary_source_seq is not None and projected_boundary_weight is not None - if projected_boundary_source_seq.dim() != 3: - raise RuntimeError("Projected Fabric boundary source must be shaped [B,T,H]") - if tuple(projected_boundary_source_seq.shape[:2]) != tuple(boundary_seq.shape[:2]): - raise RuntimeError("Projected Fabric boundary source must match boundary batch/time axes") - if projected_boundary_weight.dim() != 2: - raise RuntimeError("Projected Fabric boundary weight must be rank-2") - projected_features = int(boundary_seq.shape[2]) * int(boundary_seq.shape[3]) - if tuple(projected_boundary_weight.shape) != ( - projected_features, - int(projected_boundary_source_seq.shape[-1]), - ): - raise RuntimeError( - "Projected Fabric boundary weight must have shape " - f"[{projected_features}, {int(projected_boundary_source_seq.shape[-1])}]" - ) - if projected_boundary_bias is not None and tuple(projected_boundary_bias.shape) != 
(projected_features,): - raise RuntimeError( - "Projected Fabric boundary bias must have shape " - f"[{projected_features}], got {tuple(projected_boundary_bias.shape)}" - ) - if ( - not planned_backend_backward_execution.receiver_bucket_plans - or not planned_backend_backward_execution.sender_bucket_plans - ): - raise RuntimeError("Supported Fabric training surface requires a planned backward execution") - if any( - bucket_plan.execution_family != ExecutionFamily.RECEIVER_MAJOR - for bucket_plan in planned_backend_backward_execution.receiver_bucket_plans - ): - raise RuntimeError("Supported Fabric backward surface requires receiver-major receiver-adjoint execution") - if any( - bucket_plan.execution_family != ExecutionFamily.EDGE_MAJOR - for bucket_plan in planned_backend_backward_execution.sender_bucket_plans - ): - raise RuntimeError("Supported Fabric backward surface requires edge-major sender/public accumulation") - - def _accumulate_grad(current: torch.Tensor | None, new_grad: torch.Tensor | None) -> torch.Tensor | None: - if new_grad is None: - return current - return new_grad if current is None else current + new_grad - - static_tensors = dict(replay_static_tensors) - if not static_tensors: - static_tensors = self._materialize_inference_static_tensors( - device=boundary_seq.device, - dtype=boundary_seq.dtype, - include_backend_cell_tensors=False, - ) - param_static_tensors = static_tensors - backend_population_name = self._select_output_cells_stream_backend_population( - k=1, - ) - if backend_population_name is None: - raise RuntimeError( - f"Supported Fabric {planned_backend_execution.surface_key} backward surface " - "requires the backend-owned CUDA sequence surface" - ) - uses_sparse_messages = any( - bucket_plan.execution_family == ExecutionFamily.EDGE_MAJOR - for bucket_plan in planned_backend_execution.bucket_plans - ) - active_receiver_window = self._backward_active_receiver_window_for_output_only_step( - boundary_seq=boundary_seq, - output_boundary=output_boundary, - uses_sparse_messages=uses_sparse_messages, - initial_state_is_fresh=initial_state_is_fresh, - packed_state=packed_state, - initial_recurrent_k=initial_recurrent_k, - initial_recurrent_v=initial_recurrent_v, - grad_next_packed_state=grad_next_packed_state, - grad_recurrent_hidden=grad_recurrent_hidden, - grad_recurrent_k=grad_recurrent_k, - grad_recurrent_v=grad_recurrent_v, - grad_input_k_last=grad_input_k_last, - grad_input_v_last=grad_input_v_last, - ) - active_receiver_window = self._infer_backward_compact_carry_receiver_window( - packed_state=packed_state, - initial_hidden=initial_hidden, - active_receiver_window=active_receiver_window, - ) - elide_trace_state_next = self._can_elide_transition_trace_state_next( - population_name=backend_population_name, - active_receiver_window=active_receiver_window, - ) - if allow_batch_tiling: - sequence_surface_per_batch_bytes = self._sequence_backward_surface_per_batch_bytes( - population_name=backend_population_name, - boundary_seq=boundary_seq, - active_receiver_window=active_receiver_window, - ) - batch_tile_len = ( - self._active_output_dependency_backward_batch_tile_len( - population_name=backend_population_name, - boundary_seq=boundary_seq, - active_receiver_window=active_receiver_window, - output_boundary=output_boundary, - ) - if active_receiver_window is not None and active_receiver_window.active - else self._backend_backward_batch_tile_len( - boundary_seq=boundary_seq, - packed_state=packed_state, - initial_hidden=initial_hidden, - 
initial_recurrent_k=initial_recurrent_k, - initial_recurrent_v=initial_recurrent_v, - output_boundary=output_boundary, - sequence_surface_per_batch_bytes=sequence_surface_per_batch_bytes, - ) - ) - if 0 < batch_tile_len < int(boundary_seq.shape[0]): - return self._run_backend_sequence_surface_backward_batch_tiled( - batch_tile_len=batch_tile_len, - active_receiver_window=active_receiver_window, - boundary_seq=boundary_seq, - projected_boundary_source_seq=projected_boundary_source_seq, - projected_boundary_weight=projected_boundary_weight, - projected_boundary_bias=projected_boundary_bias, - packed_state=packed_state, - initial_hidden=initial_hidden, - initial_recurrent_k=initial_recurrent_k, - initial_recurrent_v=initial_recurrent_v, - initial_state_is_fresh=initial_state_is_fresh, - population_resets=population_resets, - planned_backend_execution=planned_backend_execution, - planned_backend_backward_execution=planned_backend_backward_execution, - grad_output_seq=grad_output_seq, - grad_next_packed_state=grad_next_packed_state, - grad_recurrent_hidden=grad_recurrent_hidden, - grad_recurrent_k=grad_recurrent_k, - grad_recurrent_v=grad_recurrent_v, - grad_input_k_last=grad_input_k_last, - grad_input_v_last=grad_input_v_last, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - replay_static_tensors=static_tensors, - output_boundary=output_boundary, - forward_carry_checkpoints=forward_carry_checkpoints, - ) - synthetic_initial_hidden = False - synthetic_initial_recurrent_k = False - synthetic_initial_recurrent_v = False - if packed_state is None: - if not initial_state_is_fresh: - raise RuntimeError("Fabric physical backward received no packed state for a non-fresh sequence") - receiver_count = ( - int(active_receiver_window.count) - if active_receiver_window is not None and active_receiver_window.active - else int(self._num_recurrent_cells) - ) - packed_state = self._init_backend_population_state_for_receivers( - backend_population_name, - batch=int(boundary_seq.shape[0]), - receivers=receiver_count, - device=boundary_seq.device, - dtype=boundary_seq.dtype, - state_names=self._transition_core_state_names_for_population(backend_population_name) - if elide_trace_state_next - else None, - ) - if int(initial_hidden.shape[1]) == 0: - synthetic_initial_hidden = True - initial_hidden = boundary_seq.new_zeros( - int(boundary_seq.shape[0]), - receiver_count, - self.hidden_size, - ) - recurrent_bank_receiver_count = ( - int(active_receiver_window.count) - if self._receiver_window_compacts_recurrent_senders(active_receiver_window) - and active_receiver_window is not None - else int(self._num_recurrent_cells) - ) - if initial_recurrent_k is None: - synthetic_initial_recurrent_k = True - initial_recurrent_k = boundary_seq.new_zeros( - int(boundary_seq.shape[0]), - recurrent_bank_receiver_count, - self.head_dim, - ) - if initial_recurrent_v is None: - synthetic_initial_recurrent_v = True - initial_recurrent_v = boundary_seq.new_zeros( - int(boundary_seq.shape[0]), - recurrent_bank_receiver_count, - self.value_dim, - ) - - external_packed_state_like = packed_state - external_initial_hidden_like = None if synthetic_initial_hidden else initial_hidden - external_initial_recurrent_k_like = None if synthetic_initial_recurrent_k else initial_recurrent_k - external_initial_recurrent_v_like = None if synthetic_initial_recurrent_v else initial_recurrent_v - if active_receiver_window is not None and active_receiver_window.active: - packed_state = 
self._slice_receiver_window_batch_rows(packed_state, active_receiver_window) - initial_hidden = self._slice_receiver_window_batch_rows(initial_hidden, active_receiver_window) - initial_recurrent_k = ( - self._slice_receiver_window_recurrent_bank( - initial_recurrent_k, - active_receiver_window, - ) - if initial_recurrent_k is not None - else None - ) - initial_recurrent_v = ( - self._slice_receiver_window_recurrent_bank( - initial_recurrent_v, - active_receiver_window, - ) - if initial_recurrent_v is not None - else None - ) - - packed_state_keys, packed_state_tensors = _flatten_backend_packed_state(packed_state) - packed_state_input_requires_grad = any( - tensor.is_floating_point() and tensor.requires_grad for tensor in packed_state_tensors - ) - artifact_mode = getattr(self, "_last_backend_tape_artifact_mode", None) or "store_step_artifacts" - tape_chunk_len = int( - getattr(self, "_last_backend_tape_chunk_len", boundary_seq.shape[1]) or boundary_seq.shape[1] - ) - if ( - artifact_mode != "recompute_step_artifacts" - and output_boundary == "sequence" - and int(boundary_seq.shape[1]) > int(tape_chunk_len) - ): - artifact_mode = "recompute_step_artifacts" - self._last_backend_tape_artifact_mode = artifact_mode - previous_reason = getattr(self, "_last_backend_tape_chunk_reason", None) - streaming_reason = ( - f"sequence_output_streaming_guard=active;" - f"time_steps={int(boundary_seq.shape[1])};tape_chunk_len={int(tape_chunk_len)}" - ) - self._last_backend_tape_chunk_reason = ( - f"{previous_reason};{streaming_reason}" if previous_reason else streaming_reason - ) - can_elide_checkpoint_hidden_carry = bool( - initial_recurrent_k is not None - and initial_recurrent_v is not None - and torch.is_tensor(initial_hidden) - and initial_hidden.dim() == 3 - ) - checkpoint_hidden_carry_mode = ( - "elided_explicit_recurrent_kv_carry" - if can_elide_checkpoint_hidden_carry - else "materialized_public_hidden_carry" - ) - self._last_backend_recompute_checkpoint_hidden_carry_mode = checkpoint_hidden_carry_mode - empty_checkpoint_hidden: torch.Tensor | None = None - - def checkpoint_hidden_carry( - running_hidden: torch.Tensor, - running_recurrent_k: torch.Tensor | None, - running_recurrent_v: torch.Tensor | None, - ) -> torch.Tensor: - nonlocal empty_checkpoint_hidden - if not can_elide_checkpoint_hidden_carry or running_recurrent_k is None or running_recurrent_v is None: - return running_hidden - if empty_checkpoint_hidden is None: - empty_checkpoint_hidden = running_hidden.new_empty( - int(running_hidden.shape[0]), - 0, - int(running_hidden.shape[-1]), - ) - return empty_checkpoint_hidden - - def effective_batch_split_active() -> bool: - return False - - def checkpoint_state( - running_packed_state: Any, - running_hidden: torch.Tensor, - running_recurrent_k: torch.Tensor | None, - running_recurrent_v: torch.Tensor | None, - ) -> tuple[Any, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: - return ( - running_packed_state, - checkpoint_hidden_carry(running_hidden, running_recurrent_k, running_recurrent_v), - running_recurrent_k, - running_recurrent_v, - ) - - def forward_carry_checkpoint_map( - stride: int, - ) -> dict[int, tuple[Any, torch.Tensor, torch.Tensor | None, torch.Tensor | None]] | None: - if forward_carry_checkpoints is None: - return None - steps = tuple(int(step) for step in getattr(forward_carry_checkpoints, "steps", ())) - state_tensors = tuple(getattr(forward_carry_checkpoints, "state_tensors", ())) - public_tensors = tuple(getattr(forward_carry_checkpoints, "public_tensors", ())) 
- source_stride = getattr(forward_carry_checkpoints, "stride", None) - if source_stride is not None and int(source_stride) > int(stride): - self._last_backend_recompute_checkpoint_source = ( - f"forward_carry_checkpoint_tape_miss:sparse_stride;" - f"checkpoint_stride={int(source_stride)};requested_stride={int(stride)}" - ) - return None - if len(public_tensors) < 1: - self._last_backend_recompute_checkpoint_source = "forward_carry_checkpoint_tape_miss:missing_public" - return None - rebuild_state = getattr(forward_carry_checkpoints, "rebuild_state", None) - if rebuild_state is None: - self._last_backend_recompute_checkpoint_source = ( - "forward_carry_checkpoint_tape_miss:missing_state_schema" - ) - return None - checkpoints: dict[int, tuple[Any, torch.Tensor, torch.Tensor | None, torch.Tensor | None]] = { - 0: checkpoint_state( - packed_state, - initial_hidden, - initial_recurrent_k, - initial_recurrent_v, - ) - } - for checkpoint_index, checkpoint_step in enumerate(steps): - if checkpoint_step <= 0 or checkpoint_step >= int(boundary_seq.shape[1]): - continue - state_leaves = tuple( - tensor[checkpoint_index].to(device=boundary_seq.device, non_blocking=True) - if tensor.device != boundary_seq.device - else tensor[checkpoint_index] - for tensor in state_tensors - ) - public_checkpoint_tensors = tuple( - tensor[checkpoint_index].to(device=boundary_seq.device, non_blocking=True) - if tensor.device != boundary_seq.device - else tensor[checkpoint_index] - for tensor in public_tensors - ) - checkpoints[int(checkpoint_step)] = checkpoint_state( - rebuild_state(state_leaves), - public_checkpoint_tensors[0], - public_checkpoint_tensors[1] if len(public_checkpoint_tensors) > 1 else None, - public_checkpoint_tensors[2] if len(public_checkpoint_tensors) > 2 else None, - ) - actual_count = len(checkpoints) - 1 - if actual_count <= 0: - self._last_backend_recompute_checkpoint_source = "forward_carry_checkpoint_tape_miss:empty" - return None - self._last_backend_recompute_checkpoint_source = ( - f"forward_carry_checkpoint_tape:checkpoint_stride={getattr(forward_carry_checkpoints, 'stride', None)};" - f"requested_stride={int(stride)};count={actual_count}" - ) - return checkpoints - - def nearest_checkpoint_step( - checkpoints: dict[int, tuple[Any, torch.Tensor, torch.Tensor | None, torch.Tensor | None]], - target_step: int, - ) -> int: - eligible_steps = [int(step) for step in checkpoints if int(step) <= int(target_step)] - if not eligible_steps: - return 0 - return max(eligible_steps) - - def estimate_receiver_window_tensor_bytes(tensor: torch.Tensor | None) -> int: - if tensor is None: - return 0 - if ( - active_receiver_window is not None - and active_receiver_window.active - and tensor.dim() >= 2 - and int(tensor.shape[1]) == int(active_receiver_window.full_count) - ): - shape = list(tensor.shape) - shape[1] = int(active_receiver_window.count) - return int(math.prod(shape)) * int(tensor.element_size()) - return int(tensor.numel()) * int(tensor.element_size()) - - def estimate_recurrent_bank_tensor_bytes(tensor: torch.Tensor | None) -> int: - if tensor is None: - return 0 - if ( - self._receiver_window_compacts_recurrent_senders(active_receiver_window) - and active_receiver_window is not None - and tensor.dim() >= 2 - and int(tensor.shape[1]) == int(active_receiver_window.full_count) - ): - shape = list(tensor.shape) - shape[1] = int(active_receiver_window.count) - return int(math.prod(shape)) * int(tensor.element_size()) - return int(tensor.numel()) * int(tensor.element_size()) - - def 
estimate_carry_checkpoint_bytes() -> int: - state_bytes = 0 - if packed_state is not None: - _state_keys, state_tensors = _flatten_backend_packed_state(packed_state) - state_bytes = sum(estimate_receiver_window_tensor_bytes(tensor) for tensor in state_tensors) - carry_bytes = state_bytes - if not can_elide_checkpoint_hidden_carry: - carry_bytes += estimate_receiver_window_tensor_bytes(initial_hidden) - carry_bytes += estimate_recurrent_bank_tensor_bytes(initial_recurrent_k) - carry_bytes += estimate_recurrent_bank_tensor_bytes(initial_recurrent_v) - return int(math.ceil(float(carry_bytes) * 1.15)) - - def estimate_artifact_window_step_bytes() -> int: - batch_size = int(boundary_seq.shape[0]) - dtype_bytes = int(torch.empty((), dtype=boundary_seq.dtype, device="cpu").element_size()) - active_receivers = ( - int(active_receiver_window.count) - if active_receiver_window is not None and active_receiver_window.active - else int(self._num_recurrent_cells) - ) - recurrent_bank_receivers = ( - active_receivers - if self._receiver_window_compacts_recurrent_senders(active_receiver_window) - else int(self._num_recurrent_cells) - ) - state_bytes = 0 - if packed_state is not None: - _state_keys, state_tensors = _flatten_backend_packed_state(packed_state) - state_bytes = sum(estimate_receiver_window_tensor_bytes(tensor) for tensor in state_tensors) - boundary_step_bytes = int(boundary_seq[:, 0].numel()) * dtype_bytes if boundary_seq.shape[1] > 0 else 0 - hidden_before_bytes = ( - 0 - if initial_recurrent_k is not None and initial_recurrent_v is not None - else int(batch_size * active_receivers * int(self.hidden_size) * dtype_bytes) - ) - input_kv_bytes = int( - batch_size * int(self._num_input_cells) * (int(self.head_dim) + int(self.value_dim)) * dtype_bytes - ) - recurrent_bank_bytes = int( - batch_size * recurrent_bank_receivers * (int(self.head_dim) + int(self.value_dim)) * dtype_bytes - ) - recurrent_msg_bytes = int(batch_size * active_receivers * int(self.hidden_size) * dtype_bytes) - recurrent_public_bytes = int(batch_size * active_receivers * int(self.hidden_size) * dtype_bytes) - recurrent_output_kv_bytes = ( - 0 - if output_boundary == "terminal" - else int(batch_size * active_receivers * (int(self.head_dim) + int(self.value_dim)) * dtype_bytes) - ) - output_bytes = 0 - if grad_output_seq is not None and output_boundary != "terminal": - output_bytes = int( - batch_size - * int(self._num_output_cells) - * (int(self.value_dim) + int(boundary_seq.shape[-1])) - * dtype_bytes - ) - window_step_bytes = ( - state_bytes - + boundary_step_bytes - + hidden_before_bytes - + input_kv_bytes - + recurrent_bank_bytes - + recurrent_msg_bytes - + recurrent_public_bytes - + recurrent_output_kv_bytes - + output_bytes - ) - return int(math.ceil(float(window_step_bytes) * 1.40)) - - def trim_non_output_recurrent_kv_artifact( - artifacts: _BackendSequenceStepArtifacts, - step_index: int, - ) -> None: - if not step_needs_output_artifacts(int(step_index)): - artifacts.recurrent_k = None - artifacts.recurrent_v = None - - def estimate_transition_tape_step_bytes() -> int: - batch_size = int(boundary_seq.shape[0]) - dtype_bytes = int(torch.empty((), dtype=boundary_seq.dtype, device="cpu").element_size()) - active_receivers = ( - int(active_receiver_window.count) - if active_receiver_window is not None and active_receiver_window.active - else int(self._num_recurrent_cells) - ) - local_hidden_surface_bytes = int(batch_size * active_receivers * int(self.hidden_size) * dtype_bytes) - return 
int(math.ceil(float(local_hidden_surface_bytes) * 4.0)) - - def estimate_transition_input_tape_step_bytes() -> int: - batch_size = int(boundary_seq.shape[0]) - dtype_bytes = int(torch.empty((), dtype=boundary_seq.dtype, device="cpu").element_size()) - active_receivers = ( - int(active_receiver_window.count) - if active_receiver_window is not None and active_receiver_window.active - else int(self._num_recurrent_cells) - ) - input_projection_bytes = int(batch_size * active_receivers * int(self.hidden_size) * dtype_bytes) - return int(math.ceil(float(input_projection_bytes) * 1.20)) - - if artifact_mode != "recompute_step_artifacts" and boundary_seq.device.type == "cuda": - storage_policy = artifact_storage_policy( - artifact_mode=artifact_mode, - time_steps=int(boundary_seq.shape[1]), - stored_artifact_step_bytes=estimate_artifact_window_step_bytes(), - memory=self._cuda_memory_budget(boundary_seq.device), - ) - if storage_policy.artifact_mode != artifact_mode: - artifact_mode = storage_policy.artifact_mode - self._last_backend_tape_artifact_mode = artifact_mode - previous_reason = getattr(self, "_last_backend_tape_chunk_reason", None) - guard_reason = storage_policy.reason_suffix or "artifact_storage_guard=active" - self._last_backend_tape_chunk_reason = ( - f"{previous_reason};{guard_reason}" if previous_reason else guard_reason - ) - - def step_needs_output_artifacts(step_index: int) -> bool: - if grad_output_seq is None: - return False - if output_boundary == "terminal": - return int(step_index) == int(boundary_seq.shape[1]) - 1 - return True - - def step_needs_recurrent_kv_artifacts(step_index: int, end_step_index: int) -> bool: - return int(step_index) + 1 < int(end_step_index) or step_needs_output_artifacts(step_index) - - def record_recompute_payload( - *, - mode: str, - artifacts: list[_BackendSequenceStepArtifacts], - ) -> None: - if not artifacts or getattr(self, "_active_backend_owner_timing", None) is None: - return - window_payload: dict[str, int] = defaultdict(int) - for artifact in artifacts: - for key, value in _artifact_payload_bytes_by_family(artifact).items(): - window_payload[key] += int(value) - window_payload["total"] = sum(int(window_payload.get(key, 0)) for key in _RECOMPUTE_PAYLOAD_TOTAL_KEYS) - current_max = getattr(self, "_last_backend_recompute_payload_max_bytes", None) - if not isinstance(current_max, dict): - current_max = {} - max_total = int(current_max.get("total", 0)) - if int(window_payload["total"]) >= max_total: - self._last_backend_recompute_payload_max_bytes = dict(window_payload) - self._last_backend_recompute_payload_max_window_len = int(len(artifacts)) - self._last_backend_recompute_payload_max_mode = str(mode) - self._last_backend_recompute_payload_sample_count = ( - int(getattr(self, "_last_backend_recompute_payload_sample_count", 0) or 0) + 1 - ) - - def compute_artifact_range(start_step_index: int, end_step_index: int) -> list[_BackendSequenceStepArtifacts]: - running_packed_state = packed_state - running_hidden = initial_hidden - running_recurrent_k = initial_recurrent_k - running_recurrent_v = initial_recurrent_v - artifact_range: list[_BackendSequenceStepArtifacts] = [] - with torch.no_grad(): - for step_index in range(end_step_index): - artifacts = self._compute_backend_sequence_surface_step_artifacts( - population_name=backend_population_name, - boundary_step=boundary_seq[:, step_index], - packed_state=running_packed_state, - initial_hidden=running_hidden, - initial_recurrent_k=running_recurrent_k, - initial_recurrent_v=running_recurrent_v, 
- population_resets=population_resets[:, step_index] if population_resets is not None else None, - input_sender_input_to_kv_weight=cast( - torch.Tensor | None, static_tensors["input_sender_input_to_kv_weight"] - ), - input_group_input_to_kv_weight=cast( - torch.Tensor | None, static_tensors["input_group_input_to_kv_weight"] - ), - static_tensors=static_tensors, - materialize_output_artifacts=step_index >= start_step_index - and step_needs_output_artifacts(step_index), - materialize_transition_backward_tape=step_index >= start_step_index, - materialize_diagonal_preproj_tape=int(end_step_index) - int(start_step_index) > 1, - materialize_recurrent_kv=step_needs_recurrent_kv_artifacts(step_index, end_step_index), - materialize_next_state=step_index + 1 < end_step_index, - materialize_trace_state_next=not elide_trace_state_next, - active_receiver_window=active_receiver_window, - artifact_owner_scope="artifact.range", - ) - if step_index + 1 < end_step_index: - running_packed_state = artifacts.next_packed_state - running_hidden = artifacts.recurrent_hidden - running_recurrent_k = artifacts.recurrent_k - running_recurrent_v = artifacts.recurrent_v - if step_index >= start_step_index: - artifacts.next_packed_state = None - trim_non_output_recurrent_kv_artifact(artifacts, step_index) - artifact_range.append(artifacts) - if not artifact_range: - raise RuntimeError("Fabric physical backward recompute requested an empty sequence") - record_recompute_payload(mode="range", artifacts=artifact_range) - return artifact_range - - def recompute_artifact_window_len(stride: int) -> int: - time_steps = int(boundary_seq.shape[1]) - if time_steps <= 1 or boundary_seq.device.type != "cuda": - self._last_backend_recompute_artifact_window_len = time_steps - self._last_backend_recompute_artifact_window_reason = "artifact_window=disabled" - return time_steps - estimated_step_bytes = self._estimate_backend_tape_step_bytes( - boundary_seq=boundary_seq, - packed_state=packed_state, - initial_hidden=initial_hidden, - initial_recurrent_k=initial_recurrent_k, - initial_recurrent_v=initial_recurrent_v, - ) - policy = recompute_artifact_window_policy( - time_steps=time_steps, - stride=int(stride), - estimated_step_bytes=estimated_step_bytes, - artifact_window_step_bytes=estimate_artifact_window_step_bytes(), - effective_batch_split_active=effective_batch_split_active(), - memory=self._cuda_memory_budget(boundary_seq.device), - ) - self._last_backend_recompute_artifact_window_len = int(policy.window_len) - self._last_backend_recompute_artifact_window_reason = policy.reason - return int(policy.window_len) - - def recompute_checkpoint_stride() -> int: - time_steps = int(boundary_seq.shape[1]) - if time_steps <= 1 or boundary_seq.device.type != "cuda": - self._last_backend_recompute_checkpoint_stride = time_steps - self._last_backend_recompute_checkpoint_count = 0 - self._last_backend_recompute_checkpoint_reason = "checkpointing=disabled" - return time_steps - estimated_step_bytes = self._estimate_backend_tape_step_bytes( - boundary_seq=boundary_seq, - packed_state=packed_state, - initial_hidden=initial_hidden, - initial_recurrent_k=initial_recurrent_k, - initial_recurrent_v=initial_recurrent_v, - ) - policy = recompute_checkpoint_stride_policy( - time_steps=time_steps, - estimated_step_bytes=estimated_step_bytes, - effective_batch_split_active=effective_batch_split_active(), - memory=self._cuda_memory_budget(boundary_seq.device), - ) - self._last_backend_recompute_checkpoint_stride = int(policy.stride) - 
self._last_backend_recompute_checkpoint_count = int(policy.checkpoint_count) - self._last_backend_recompute_checkpoint_reason = policy.reason - return int(policy.stride) - - def align_checkpoint_stride_to_artifact_window(stride: int, window_len: int) -> int: - time_steps = int(boundary_seq.shape[1]) - policy = checkpoint_stride_alignment_policy( - time_steps=time_steps, - stride=int(stride), - window_len=int(window_len), - carry_checkpoint_bytes=estimate_carry_checkpoint_bytes(), - memory=self._cuda_memory_budget(boundary_seq.device), - ) - if policy.reason_suffix is None: - return int(policy.stride) - previous_reason = getattr(self, "_last_backend_recompute_checkpoint_reason", None) - if int(policy.stride) == int(stride): - self._last_backend_recompute_checkpoint_reason = ( - f"{previous_reason};{policy.reason_suffix}" if previous_reason else policy.reason_suffix - ) - return int(stride) - self._last_backend_recompute_checkpoint_reason = ( - f"{previous_reason};{policy.reason_suffix}" if previous_reason else policy.reason_suffix - ) - self._last_backend_recompute_checkpoint_stride = int(policy.stride) - if policy.checkpoint_count is not None: - self._last_backend_recompute_checkpoint_count = int(policy.checkpoint_count) - return int(policy.stride) - - def build_recompute_checkpoints( - stride: int, - ) -> dict[int, tuple[Any, torch.Tensor, torch.Tensor | None, torch.Tensor | None]]: - time_steps = int(boundary_seq.shape[1]) - checkpoints: dict[int, tuple[Any, torch.Tensor, torch.Tensor | None, torch.Tensor | None]] = { - 0: checkpoint_state( - packed_state, - initial_hidden, - initial_recurrent_k, - initial_recurrent_v, - ) - } - if stride >= time_steps: - return checkpoints - running_packed_state = packed_state - running_hidden = initial_hidden - running_recurrent_k = initial_recurrent_k - running_recurrent_v = initial_recurrent_v - with torch.no_grad(): - last_checkpoint_step = (max(0, time_steps - 1) // int(stride)) * int(stride) - for step_index in range(last_checkpoint_step): - artifacts = self._compute_backend_sequence_surface_step_artifacts( - population_name=backend_population_name, - boundary_step=boundary_seq[:, step_index], - packed_state=running_packed_state, - initial_hidden=running_hidden, - initial_recurrent_k=running_recurrent_k, - initial_recurrent_v=running_recurrent_v, - population_resets=population_resets[:, step_index] if population_resets is not None else None, - input_sender_input_to_kv_weight=cast( - torch.Tensor | None, static_tensors["input_sender_input_to_kv_weight"] - ), - input_group_input_to_kv_weight=cast( - torch.Tensor | None, static_tensors["input_group_input_to_kv_weight"] - ), - static_tensors=static_tensors, - materialize_output_artifacts=bool( - checkpoint_artifact_window_cache_enabled - and step_index >= checkpoint_artifact_cache_start - and step_needs_output_artifacts(step_index) - ), - materialize_trace_state_next=not elide_trace_state_next, - active_receiver_window=active_receiver_window, - artifact_owner_scope="artifact.checkpoint", - ) - running_packed_state = artifacts.next_packed_state - running_hidden = artifacts.recurrent_hidden - running_recurrent_k = artifacts.recurrent_k - running_recurrent_v = artifacts.recurrent_v - checkpoint_step = step_index + 1 - if checkpoint_step < time_steps and checkpoint_step % stride == 0: - checkpoints[checkpoint_step] = checkpoint_state( - running_packed_state, - running_hidden, - running_recurrent_k, - running_recurrent_v, - ) - if checkpoint_artifact_window_cache_enabled and step_index >= 
checkpoint_artifact_cache_start: - cached_artifact = artifacts - cached_artifact.next_packed_state = None - trim_non_output_recurrent_kv_artifact(cached_artifact, step_index) - window_start = (int(step_index) // int(recompute_window_len)) * int(recompute_window_len) - checkpoint_artifact_windows[window_start].append(cached_artifact) - else: - del artifacts - return checkpoints - - def compute_artifact_from_checkpoint( - *, - step_index: int, - stride: int, - checkpoints: dict[int, tuple[Any, torch.Tensor, torch.Tensor | None, torch.Tensor | None]], - predecessor_artifacts: dict[int, _BackendSequenceStepArtifacts] | None = None, - predecessor_checkpoints: dict[int, tuple[Any, torch.Tensor, torch.Tensor | None, torch.Tensor | None]] - | None = None, - predecessor_checkpoint_limit: int = 1, - ) -> _BackendSequenceStepArtifacts: - segment_checkpoint_step = (int(step_index) // int(stride)) * int(stride) - checkpoint_step = nearest_checkpoint_step(checkpoints, segment_checkpoint_step) - predecessor_checkpoint_limit = max(1, int(predecessor_checkpoint_limit)) - if predecessor_artifacts is not None: - for cached_step in tuple(predecessor_artifacts): - if int(cached_step) > int(step_index): - predecessor_artifacts.pop(cached_step, None) - if predecessor_artifacts is not None and int(step_index) in predecessor_artifacts: - return predecessor_artifacts.pop(int(step_index)) - if predecessor_checkpoints is not None: - for cached_step in tuple(predecessor_checkpoints): - if int(cached_step) > int(step_index): - predecessor_checkpoints.pop(cached_step, None) - if predecessor_checkpoints is not None and int(step_index) in predecessor_checkpoints: - checkpoint_step = int(step_index) - running_packed_state, running_hidden, running_recurrent_k, running_recurrent_v = ( - predecessor_checkpoints.pop(int(step_index)) - ) - else: - running_packed_state, running_hidden, running_recurrent_k, running_recurrent_v = checkpoints[ - checkpoint_step - ] - target_artifact: _BackendSequenceStepArtifacts | None = None - with torch.no_grad(): - for current_step in range(checkpoint_step, int(step_index) + 1): - materialize_predecessor_artifact = bool( - predecessor_artifacts is not None - and current_step == int(step_index) - 1 - and current_step >= checkpoint_step - ) - if ( - predecessor_checkpoints is not None - and not materialize_predecessor_artifact - and current_step < int(step_index) - and current_step > checkpoint_step - and int(step_index) - current_step <= predecessor_checkpoint_limit - and current_step >= checkpoint_step - ): - predecessor_checkpoints[current_step] = checkpoint_state( - running_packed_state, - running_hidden, - running_recurrent_k, - running_recurrent_v, - ) - artifacts = self._compute_backend_sequence_surface_step_artifacts( - population_name=backend_population_name, - boundary_step=boundary_seq[:, current_step], - packed_state=running_packed_state, - initial_hidden=running_hidden, - initial_recurrent_k=running_recurrent_k, - initial_recurrent_v=running_recurrent_v, - population_resets=population_resets[:, current_step] if population_resets is not None else None, - input_sender_input_to_kv_weight=cast( - torch.Tensor | None, static_tensors["input_sender_input_to_kv_weight"] - ), - input_group_input_to_kv_weight=cast( - torch.Tensor | None, static_tensors["input_group_input_to_kv_weight"] - ), - static_tensors=static_tensors, - materialize_output_artifacts=( - current_step == int(step_index) or materialize_predecessor_artifact - ) - and step_needs_output_artifacts(current_step), - 
materialize_transition_backward_tape=current_step == int(step_index) - or materialize_predecessor_artifact, - materialize_diagonal_preproj_tape=False, - materialize_recurrent_kv=step_needs_recurrent_kv_artifacts( - current_step, - int(step_index) + 1, - ), - materialize_next_state=current_step < int(step_index), - materialize_trace_state_next=not elide_trace_state_next, - active_receiver_window=active_receiver_window, - artifact_owner_scope="artifact.checkpoint_step", - ) - if current_step < int(step_index): - running_packed_state = artifacts.next_packed_state - running_hidden = artifacts.recurrent_hidden - running_recurrent_k = artifacts.recurrent_k - running_recurrent_v = artifacts.recurrent_v - if materialize_predecessor_artifact: - artifacts.next_packed_state = None - trim_non_output_recurrent_kv_artifact(artifacts, current_step) - if predecessor_artifacts is not None: - predecessor_artifacts[current_step] = artifacts - continue - if current_step == int(step_index): - artifacts.next_packed_state = None - trim_non_output_recurrent_kv_artifact(artifacts, current_step) - target_artifact = artifacts - else: - del artifacts - if target_artifact is None: - raise RuntimeError("Fabric physical backward checkpoint recompute missed the target step") - record_recompute_payload(mode="checkpoint_step", artifacts=[target_artifact]) - return target_artifact - - def compute_artifact_window_from_checkpoint( - *, - start_step_index: int, - end_step_index: int, - stride: int, - checkpoints: dict[int, tuple[Any, torch.Tensor, torch.Tensor | None, torch.Tensor | None]], - predecessor_checkpoints: dict[int, tuple[Any, torch.Tensor, torch.Tensor | None, torch.Tensor | None]] - | None = None, - predecessor_checkpoint_limit: int = 0, - ) -> list[_BackendSequenceStepArtifacts]: - checkpoint_step = 0 - segment_start = 0 - segment_end = int(boundary_seq.shape[1]) - if int(stride) < int(boundary_seq.shape[1]): - segment_start = (int(start_step_index) // int(stride)) * int(stride) - segment_end = min(int(boundary_seq.shape[1]), int(segment_start) + int(stride)) - checkpoint_step = nearest_checkpoint_step(checkpoints, int(segment_start)) - predecessor_checkpoint_limit = max(0, int(predecessor_checkpoint_limit)) - if predecessor_checkpoints is not None: - for cached_step in tuple(predecessor_checkpoints): - if ( - int(cached_step) > int(start_step_index) - or int(cached_step) < int(segment_start) - or int(cached_step) >= int(segment_end) - ): - predecessor_checkpoints.pop(cached_step, None) - if predecessor_checkpoints is not None and int(start_step_index) in predecessor_checkpoints: - checkpoint_step = int(start_step_index) - running_packed_state, running_hidden, running_recurrent_k, running_recurrent_v = ( - predecessor_checkpoints.pop(int(start_step_index)) - ) - else: - running_packed_state, running_hidden, running_recurrent_k, running_recurrent_v = checkpoints[ - checkpoint_step - ] - artifact_range: list[_BackendSequenceStepArtifacts] = [] - with torch.no_grad(): - for current_step in range(checkpoint_step, int(end_step_index)): - artifacts = self._compute_backend_sequence_surface_step_artifacts( - population_name=backend_population_name, - boundary_step=boundary_seq[:, current_step], - packed_state=running_packed_state, - initial_hidden=running_hidden, - initial_recurrent_k=running_recurrent_k, - initial_recurrent_v=running_recurrent_v, - population_resets=population_resets[:, current_step] if population_resets is not None else None, - input_sender_input_to_kv_weight=cast( - torch.Tensor | None, 
static_tensors["input_sender_input_to_kv_weight"] - ), - input_group_input_to_kv_weight=cast( - torch.Tensor | None, static_tensors["input_group_input_to_kv_weight"] - ), - static_tensors=static_tensors, - materialize_output_artifacts=current_step >= int(start_step_index) - and step_needs_output_artifacts(current_step), - materialize_transition_backward_tape=window_transition_tape_mode != "disabled" - and current_step >= int(start_step_index), - materialize_diagonal_preproj_tape=window_transition_tape_mode == "full", - materialize_recurrent_kv=step_needs_recurrent_kv_artifacts(current_step, end_step_index), - materialize_next_state=current_step + 1 < int(end_step_index), - materialize_trace_state_next=not elide_trace_state_next, - active_receiver_window=active_receiver_window, - artifact_owner_scope="artifact.window", - ) - if current_step + 1 < int(end_step_index): - running_packed_state = artifacts.next_packed_state - running_hidden = artifacts.recurrent_hidden - running_recurrent_k = artifacts.recurrent_k - running_recurrent_v = artifacts.recurrent_v - candidate_checkpoint_step = int(current_step) + 1 - if ( - predecessor_checkpoints is not None - and predecessor_checkpoint_limit > 0 - and candidate_checkpoint_step > int(segment_start) - and candidate_checkpoint_step < int(start_step_index) - and (candidate_checkpoint_step - int(segment_start)) % int(recompute_window_len) == 0 - and int(start_step_index) - candidate_checkpoint_step - <= int(predecessor_checkpoint_limit) * int(recompute_window_len) - ): - predecessor_checkpoints[candidate_checkpoint_step] = checkpoint_state( - running_packed_state, - running_hidden, - running_recurrent_k, - running_recurrent_v, - ) - if current_step >= int(start_step_index): - artifacts.next_packed_state = None - trim_non_output_recurrent_kv_artifact(artifacts, current_step) - artifact_range.append(artifacts) - else: - del artifacts - if not artifact_range: - raise RuntimeError("Fabric physical backward checkpoint window recompute produced no artifacts") - record_recompute_payload(mode="checkpoint_window", artifacts=artifact_range) - return artifact_range - - recompute_window_len = int(boundary_seq.shape[1]) - recompute_checkpoint_stride_len = recompute_window_len - recompute_checkpoints: dict[int, tuple[Any, torch.Tensor, torch.Tensor | None, torch.Tensor | None]] = {} - checkpoint_artifact_windows: dict[int, list[_BackendSequenceStepArtifacts]] = defaultdict(list) - checkpoint_artifact_window_cache_enabled = False - checkpoint_artifact_cache_start = int(boundary_seq.shape[1]) - checkpoint_artifact_cache_reason = "disabled" - if artifact_mode == "recompute_step_artifacts": - recompute_checkpoint_stride_len = recompute_checkpoint_stride() - recompute_window_len = recompute_artifact_window_len(recompute_checkpoint_stride_len) - aligned_checkpoint_stride_len = align_checkpoint_stride_to_artifact_window( - recompute_checkpoint_stride_len, - recompute_window_len, - ) - if int(aligned_checkpoint_stride_len) != int(recompute_checkpoint_stride_len): - recompute_checkpoint_stride_len = int(aligned_checkpoint_stride_len) - recompute_window_len = recompute_artifact_window_len(recompute_checkpoint_stride_len) - if int(recompute_window_len) > 1 and int(recompute_checkpoint_stride_len) >= int(recompute_window_len): - _usable_bytes, total_bytes, free_bytes, _reusable_reserved_bytes = self._cuda_usable_memory_info( - boundary_seq.device - ) - runtime_reserve_bytes = max(4 << 30, int(total_bytes * 0.04)) - artifact_window_bytes = max(1, 
estimate_artifact_window_step_bytes() * int(recompute_window_len))
-                artifact_cache_budget = max(
-                    0,
-                    min(int(total_bytes * 0.20), int(free_bytes) - int(runtime_reserve_bytes)),
-                )
-                max_cached_windows = min(2, int(artifact_cache_budget // artifact_window_bytes))
-                last_checkpoint_step = (
-                    max(0, int(boundary_seq.shape[1]) - 1) // int(recompute_checkpoint_stride_len)
-                ) * int(recompute_checkpoint_stride_len)
-                checkpoint_artifact_window_cache_enabled = bool(max_cached_windows > 0 and last_checkpoint_step > 0)
-                checkpoint_artifact_cache_start = max(
-                    0,
-                    int(last_checkpoint_step) - int(max_cached_windows) * int(recompute_window_len),
-                )
-                checkpoint_artifact_cache_reason = (
-                    f"window_artifact_cache={'enabled' if checkpoint_artifact_window_cache_enabled else 'disabled'};"
-                    f"artifact_window_bytes={int(artifact_window_bytes)};free_bytes={int(free_bytes)};"
-                    f"reserve_bytes={int(runtime_reserve_bytes)};budget_bytes={int(artifact_cache_budget)};"
-                    f"max_cached_windows={int(max_cached_windows)};cache_start={int(checkpoint_artifact_cache_start)};"
-                    f"last_checkpoint_step={int(last_checkpoint_step)}"
-                )
-            forward_checkpoints = forward_carry_checkpoint_map(recompute_checkpoint_stride_len)
-            if forward_checkpoints is not None:
-                recompute_checkpoints = forward_checkpoints
-            else:
-                with (
-                    torch.profiler.record_function("fabric.backward.tape_recompute_checkpoints"),
-                    self._backend_owner_timing("tape_recompute_checkpoints"),
-                ):
-                    recompute_checkpoints = build_recompute_checkpoints(recompute_checkpoint_stride_len)
-                previous_checkpoint_source = getattr(self, "_last_backend_recompute_checkpoint_source", None)
-                self._last_backend_recompute_checkpoint_source = (
-                    f"{previous_checkpoint_source};fallback=replayed_transition_checkpoints"
-                    if previous_checkpoint_source
-                    else "replayed_transition_checkpoints"
-                )
-        predecessor_artifacts: dict[int, _BackendSequenceStepArtifacts] = {}
-        predecessor_checkpoints: dict[int, tuple[Any, torch.Tensor, torch.Tensor | None, torch.Tensor | None]] = {}
-        predecessor_cache_candidate = bool(
-            artifact_mode == "recompute_step_artifacts"
-            and recompute_window_len <= 1
-            and recompute_checkpoint_stride_len > 1
-        )
-        predecessor_artifact_cache_enabled = False
-        predecessor_carry_cache_enabled = False
-        window_carry_cache_enabled = False
-        window_transition_tape_mode = "disabled"
-        predecessor_checkpoint_limit = 1
-        window_carry_checkpoint_limit = 0
-        predecessor_cache_reason = "disabled"
-        predecessor_cache_context = "disabled"
-        if predecessor_cache_candidate:
-            estimated_step_bytes = self._estimate_backend_tape_step_bytes(
-                boundary_seq=boundary_seq,
-                packed_state=packed_state,
-                initial_hidden=initial_hidden,
-                initial_recurrent_k=initial_recurrent_k,
-                initial_recurrent_v=initial_recurrent_v,
-            )
-            _usable_bytes, total_bytes, free_bytes, _reusable_reserved_bytes = self._cuda_usable_memory_info(
-                boundary_seq.device
-            )
-            runtime_reserve_bytes = max(4 << 30, int(total_bytes * 0.04))
-            artifact_cache_budget = max(0, int(free_bytes) - int(runtime_reserve_bytes))
-            predecessor_cache_context = (
-                "effective_batch_split" if effective_batch_split_active() else "streaming_window"
-            )
-            if effective_batch_split_active():
-                predecessor_artifact_cache_enabled = bool(artifact_cache_budget >= int(estimated_step_bytes))
-                predecessor_carry_cache_enabled = not predecessor_artifact_cache_enabled
-            else:
-                # Outside an outer effective-batch split, avoid caching full step artifacts: that can increase the
-                # live transition/message surfaces enough to erase the checkpointing benefit. A single predecessor
-                # carry checkpoint is the generic rolling-tape reuse point and is enabled only when raw free memory
-                # still covers the estimated carry state plus scheduler headroom after the runtime reserve, and when
-                # one carry state is not itself a large fraction of device memory.
-                carry_checkpoint_bytes = estimate_carry_checkpoint_bytes()
-                predecessor_carry_cache_enabled = bool(
-                    artifact_cache_budget >= int(carry_checkpoint_bytes) * 2
-                    and int(carry_checkpoint_bytes) <= int(total_bytes * 0.05)
-                )
-                if predecessor_carry_cache_enabled:
-                    predecessor_checkpoint_limit = max(
-                        1,
-                        min(
-                            int(recompute_checkpoint_stride_len) - 1,
-                            int(artifact_cache_budget // max(1, int(carry_checkpoint_bytes))) - 1,
-                        ),
-                    )
-            carry_checkpoint_bytes = estimate_carry_checkpoint_bytes()
-            predecessor_cache_reason = (
-                f"estimated_step_bytes={int(estimated_step_bytes)};free_bytes={int(free_bytes)};"
-                f"reserve_bytes={int(runtime_reserve_bytes)};artifact_cache_budget={artifact_cache_budget};"
-                f"carry_checkpoint_bytes={int(carry_checkpoint_bytes)};"
-                f"streaming_carry_required_bytes={int(carry_checkpoint_bytes) * 2};"
-                f"streaming_carry_max_checkpoint_bytes={int(total_bytes * 0.05)};"
-                f"predecessor_checkpoint_limit={int(predecessor_checkpoint_limit)}"
-            )
-        elif (
-            artifact_mode == "recompute_step_artifacts"
-            and recompute_window_len > 1
-            and recompute_checkpoint_stride_len > recompute_window_len
-        ):
-            _usable_bytes, total_bytes, free_bytes, _reusable_reserved_bytes = self._cuda_usable_memory_info(
-                boundary_seq.device
-            )
-            runtime_reserve_bytes = max(4 << 30, int(total_bytes * 0.04))
-            carry_cache_budget = max(0, int(free_bytes) - int(runtime_reserve_bytes))
-            carry_checkpoint_bytes = estimate_carry_checkpoint_bytes()
-            max_window_carry_checkpoints = max(
-                0,
-                int(math.ceil(float(recompute_checkpoint_stride_len) / float(recompute_window_len))) - 1,
-            )
-            budgeted_window_carry_checkpoints = max(
-                0,
-                int(carry_cache_budget // max(1, int(carry_checkpoint_bytes))),
-            )
-            window_carry_checkpoint_limit = min(
-                int(max_window_carry_checkpoints),
-                int(budgeted_window_carry_checkpoints),
-            )
-            window_carry_cache_enabled = bool(
-                window_carry_checkpoint_limit > 0 and int(carry_checkpoint_bytes) <= int(total_bytes * 0.05)
-            )
-            predecessor_cache_context = "windowed_artifact"
-            predecessor_cache_reason = (
-                f"free_bytes={int(free_bytes)};reserve_bytes={int(runtime_reserve_bytes)};"
-                f"carry_cache_budget={carry_cache_budget};"
-                f"carry_checkpoint_bytes={int(carry_checkpoint_bytes)};"
-                f"window_carry_max_checkpoint_bytes={int(total_bytes * 0.05)};"
-                f"max_window_carry_checkpoints={int(max_window_carry_checkpoints)};"
-                f"budgeted_window_carry_checkpoints={int(budgeted_window_carry_checkpoints)};"
-                f"window_carry_checkpoint_limit={int(window_carry_checkpoint_limit)}"
-            )
-        if artifact_mode == "recompute_step_artifacts" and recompute_window_len > 1:
-            usable_bytes, total_bytes, free_bytes, reusable_reserved_bytes = self._cuda_usable_memory_info(
-                boundary_seq.device
-            )
-            runtime_reserve_bytes = max(4 << 30, int(total_bytes * 0.04))
-            budget_usable_bytes = (
-                int(usable_bytes) if int(reusable_reserved_bytes) > int(runtime_reserve_bytes) else int(free_bytes)
-            )
-            transition_tape_budget = max(
-                0,
-                min(int(total_bytes * 0.08), int(budget_usable_bytes) - int(runtime_reserve_bytes)),
-            )
-            transition_tape_step_bytes = estimate_transition_tape_step_bytes()
-            transition_tape_window_bytes = int(transition_tape_step_bytes) * int(recompute_window_len)
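The removed comment above encodes a concrete budget gate for predecessor carry checkpoints. A minimal illustrative sketch of that rule follows, with hypothetical names and the 4 GiB reserve / 4% / 5% constants read off the surrounding code; it is not the actual helper, only the arithmetic it describes:

```python
def carry_checkpoint_gate(free_bytes: int, total_bytes: int, carry_bytes: int, stride: int) -> tuple[bool, int]:
    """Illustrative only: mirrors the budget rule described in the comment above (assumed simplification)."""
    reserve = max(4 << 30, int(total_bytes * 0.04))  # runtime reserve taken off the top
    budget = max(0, free_bytes - reserve)            # bytes left for cached carry states
    # Enable only if the budget covers at least two carry states and one carry state is <= 5% of device memory.
    if budget < 2 * carry_bytes or carry_bytes > int(total_bytes * 0.05):
        return False, 0
    # Keep at most stride-1 predecessor checkpoints, bounded by how many carry states fit in the budget.
    limit = max(1, min(stride - 1, budget // max(1, carry_bytes) - 1))
    return True, limit
```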
transition_input_tape_step_bytes = estimate_transition_input_tape_step_bytes() - transition_input_tape_window_bytes = int(transition_input_tape_step_bytes) * int(recompute_window_len) - if transition_tape_window_bytes > 0 and transition_tape_window_bytes <= transition_tape_budget: - window_transition_tape_mode = "full" - elif ( - transition_input_tape_window_bytes > 0 and transition_input_tape_window_bytes <= transition_tape_budget - ): - window_transition_tape_mode = "input_projection" - self._last_backend_recompute_transition_tape_reason = ( - f"transition_tape_step_bytes={int(transition_tape_step_bytes)};" - f"transition_tape_window_bytes={int(transition_tape_window_bytes)};" - f"transition_input_tape_step_bytes={int(transition_input_tape_step_bytes)};" - f"transition_input_tape_window_bytes={int(transition_input_tape_window_bytes)};" - f"free_bytes={int(free_bytes)};reserve_bytes={int(runtime_reserve_bytes)};" - f"budget_bytes={int(transition_tape_budget)};" - f"selected={window_transition_tape_mode}" - ) - else: - self._last_backend_recompute_transition_tape_reason = "transition_tape=disabled" - self._last_backend_recompute_predecessor_cache_mode = ( - f"enabled:{predecessor_cache_context}:artifact;{predecessor_cache_reason}" - if predecessor_artifact_cache_enabled - else f"enabled:{predecessor_cache_context}:carry;{predecessor_cache_reason}" - if predecessor_carry_cache_enabled - else f"enabled:{predecessor_cache_context}:window_carry;{predecessor_cache_reason}" - if window_carry_cache_enabled - else "disabled" - ) - if checkpoint_artifact_cache_reason != "disabled": - self._last_backend_recompute_checkpoint_artifact_cache_mode = checkpoint_artifact_cache_reason - self._last_backend_recompute_transition_tape_mode = ( - "target_step" - if artifact_mode == "recompute_step_artifacts" and recompute_window_len <= 1 - else "windowed_input_projection" - if artifact_mode == "recompute_step_artifacts" and window_transition_tape_mode == "input_projection" - else "windowed_artifacts" - if artifact_mode == "recompute_step_artifacts" and window_transition_tape_mode == "full" - else "disabled_windowed_artifacts" - if artifact_mode == "recompute_step_artifacts" - else "disabled" - ) - self._last_backend_recompute_public_kv_materialization_mode = ( - "carry_or_output_boundary_only" if artifact_mode == "recompute_step_artifacts" else "stored_artifacts" - ) - self._last_backend_recompute_target_state_materialization_mode = ( - "emit_only_no_next_state" - if artifact_mode == "recompute_step_artifacts" - else "stored_artifacts_final_step_emit_only" - ) - artifact_window_start = -1 - artifact_window_end = -1 - artifact_window: list[_BackendSequenceStepArtifacts | None] = [] - - def recomputed_artifacts_for_step(step_index: int) -> _BackendSequenceStepArtifacts: - nonlocal artifact_window_start, artifact_window_end, artifact_window - if recompute_window_len <= 1: - if recompute_checkpoint_stride_len < int(boundary_seq.shape[1]): - return compute_artifact_from_checkpoint( - step_index=int(step_index), - stride=recompute_checkpoint_stride_len, - checkpoints=recompute_checkpoints, - predecessor_artifacts=predecessor_artifacts if predecessor_artifact_cache_enabled else None, - predecessor_checkpoints=predecessor_checkpoints if predecessor_carry_cache_enabled else None, - predecessor_checkpoint_limit=predecessor_checkpoint_limit, - ) - return compute_artifact_range(int(step_index), int(step_index) + 1)[0] - segment_start = ( - 0 - if recompute_checkpoint_stride_len >= int(boundary_seq.shape[1]) - else 
(int(step_index) // recompute_checkpoint_stride_len) * recompute_checkpoint_stride_len - ) - segment_end = min(int(boundary_seq.shape[1]), segment_start + recompute_checkpoint_stride_len) - window_start = ( - segment_start + ((int(step_index) - segment_start) // recompute_window_len) * recompute_window_len - ) - window_end = min(segment_end, window_start + recompute_window_len) - if not artifact_window or window_start != artifact_window_start or window_end != artifact_window_end: - cached_window = checkpoint_artifact_windows.pop(window_start, None) - if cached_window is not None and len(cached_window) == int(window_end) - int(window_start): - artifact_window = cached_window - record_recompute_payload(mode="checkpoint_window_cache", artifacts=artifact_window) - elif recompute_checkpoints: - artifact_window = compute_artifact_window_from_checkpoint( - start_step_index=window_start, - end_step_index=window_end, - stride=recompute_checkpoint_stride_len, - checkpoints=recompute_checkpoints, - predecessor_checkpoints=predecessor_checkpoints if window_carry_cache_enabled else None, - predecessor_checkpoint_limit=window_carry_checkpoint_limit, - ) - else: - artifact_window = compute_artifact_range(window_start, window_end) - artifact_window_start = window_start - artifact_window_end = window_end - artifact_offset = int(step_index) - int(artifact_window_start) - artifacts = artifact_window[artifact_offset] - if artifacts is None: - raise RuntimeError("Fabric physical backward recompute window artifact was already consumed") - return artifacts - - step_artifacts: list[_BackendSequenceStepArtifacts] | None = None - if artifact_mode == "recompute_step_artifacts": - step_artifacts = None - else: - step_artifacts = [] - running_packed_state = packed_state - running_hidden = initial_hidden - running_recurrent_k = initial_recurrent_k - running_recurrent_v = initial_recurrent_v - with torch.no_grad(): - time_steps = int(boundary_seq.shape[1]) - for step_index in range(time_steps): - artifacts = self._compute_backend_sequence_surface_step_artifacts( - population_name=backend_population_name, - boundary_step=boundary_seq[:, step_index], - packed_state=running_packed_state, - initial_hidden=running_hidden, - initial_recurrent_k=running_recurrent_k, - initial_recurrent_v=running_recurrent_v, - population_resets=population_resets[:, step_index] if population_resets is not None else None, - input_sender_input_to_kv_weight=cast( - torch.Tensor | None, static_tensors["input_sender_input_to_kv_weight"] - ), - input_group_input_to_kv_weight=cast( - torch.Tensor | None, static_tensors["input_group_input_to_kv_weight"] - ), - static_tensors=static_tensors, - materialize_output_artifacts=step_needs_output_artifacts(step_index), - materialize_recurrent_kv=step_needs_recurrent_kv_artifacts(step_index, time_steps), - materialize_next_state=step_index + 1 < time_steps, - materialize_trace_state_next=not elide_trace_state_next, - active_receiver_window=active_receiver_window, - artifact_owner_scope="artifact.store", - ) - if step_index + 1 < time_steps: - running_packed_state = artifacts.next_packed_state - running_hidden = artifacts.recurrent_hidden - running_recurrent_k = artifacts.recurrent_k - running_recurrent_v = artifacts.recurrent_v - trim_non_output_recurrent_kv_artifact(artifacts, step_index) - step_artifacts.append(artifacts) - - running_grad_packed_state = grad_next_packed_state - running_grad_recurrent_hidden = grad_recurrent_hidden - running_grad_recurrent_k = grad_recurrent_k - running_grad_recurrent_v = 
grad_recurrent_v - grad_boundary_steps: list[torch.Tensor | None] | None = ( - None if projected_boundary_active else [None] * int(boundary_seq.shape[1]) - ) - grad_projected_boundary_source_seq = ( - torch.zeros_like(projected_boundary_source_seq) - if projected_boundary_active - and projected_boundary_source_seq is not None - and projected_boundary_source_seq.requires_grad - else None - ) - grad_projected_boundary_weight = ( - torch.zeros_like(projected_boundary_weight) - if projected_boundary_active - and projected_boundary_weight is not None - and projected_boundary_weight.requires_grad - else None - ) - grad_projected_boundary_bias = ( - torch.zeros_like(projected_boundary_bias) - if projected_boundary_active - and projected_boundary_bias is not None - and projected_boundary_bias.requires_grad - else None - ) - saw_projected_boundary_grad = False - grad_initial_hidden_total: torch.Tensor | None = None - grad_initial_recurrent_k_total: torch.Tensor | None = None - grad_initial_recurrent_v_total: torch.Tensor | None = None - grad_param_accum: list[torch.Tensor | None] = [None] * len(trainable_params) - - for step_index in reversed(range(int(boundary_seq.shape[1]))): - if step_artifacts is None: - with self._backend_owner_timing("tape_recompute_artifacts"): - artifacts = recomputed_artifacts_for_step(step_index) - else: - artifacts = step_artifacts[step_index] - if grad_output_seq is None: - grad_output_step = None - elif output_boundary == "terminal": - grad_output_step = grad_output_seq[:, -1] if step_index == int(boundary_seq.shape[1]) - 1 else None - else: - grad_output_step = grad_output_seq[:, step_index] - artifact_receiver_window = artifacts.active_receiver_window - compact_recurrent_senders = self._receiver_window_compacts_recurrent_senders(artifact_receiver_window) - recurrent_q_for_artifact = self._slice_receiver_window_rows( - cast(torch.Tensor, static_tensors["recurrent_q"]), - artifact_receiver_window, - ) - recurrent_local_sender_idx = self._cached_receiver_window_sender_table( - name="backward_recurrent", - table=self.recurrent_local_sender_idx, - window=artifact_receiver_window, - num_input_senders=int(self._num_input_cells), - slice_receivers=True, - compact_recurrent_senders=compact_recurrent_senders, - ) - recurrent_sender_count = ( - int(self._num_input_cells) + int(artifact_receiver_window.count) - if compact_recurrent_senders and artifact_receiver_window is not None - else int(self.sender_cell_idx.numel()) - ) - recurrent_local_receiver_idx_by_sender = ( - self.recurrent_local_receiver_idx_by_sender - if artifact_receiver_window is None or not artifact_receiver_window.active - else self._cached_sender_reverse_table( - name="backward_recurrent", - receiver_sender_idx=recurrent_local_sender_idx, - num_senders=recurrent_sender_count, - ) - ) - output_local_sender_idx = self._cached_receiver_window_sender_table( - name="backward_output", - table=self.output_local_sender_idx, - window=artifact_receiver_window, - num_input_senders=int(self._num_input_cells), - slice_receivers=False, - compact_recurrent_senders=True, - ) - output_local_receiver_idx_by_sender = ( - self.output_local_receiver_idx_by_sender - if artifact_receiver_window is None or not artifact_receiver_window.active - else self._cached_sender_reverse_table( - name="backward_output", - receiver_sender_idx=output_local_sender_idx, - num_senders=int(self._num_input_cells) + int(artifact_receiver_window.count), - ) - ) - - if grad_output_step is not None: - if artifacts.output_msg is None: - raise 
RuntimeError("Fabric physical backward output-gradient step is missing output message") - grad_output_msg, output_projection_param_grads = self._run_backend_output_projection_backward_phase( - output_msg=artifacts.output_msg, - grad_output_cells=grad_output_step, - sequence_static_tensors=param_static_tensors, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - ) - for parameter_index, grad_param in enumerate(output_projection_param_grads): - grad_param_accum[parameter_index] = _accumulate_owned_tensor_grad( - grad_param_accum[parameter_index], - grad_param, - ) - - ( - grad_output_q, - grad_input_k_from_output, - grad_input_v_from_output, - grad_recurrent_k_from_output, - grad_recurrent_v_from_output, - ) = self._run_backend_message_backward_phase( - grad_msg=grad_output_msg, - q_subset=cast(torch.Tensor, static_tensors["output_q"]), - input_k=artifacts.input_k, - input_v=artifacts.input_v, - recurrent_k=artifacts.recurrent_k, - recurrent_v=artifacts.recurrent_v, - neighbor_idx=self.output_neighbor_idx, - neighbor_valid=self.output_neighbor_valid, - edge_distance=self.output_edge_distance, - edge_delay=self.output_edge_delay, - local_sender_idx=output_local_sender_idx, - local_receiver_idx_by_sender=output_local_receiver_idx_by_sender, - use_sparse_messages=uses_sparse_messages, - ) - else: - grad_output_q = None - grad_input_k_from_output = None - grad_input_v_from_output = None - grad_recurrent_k_from_output = None - grad_recurrent_v_from_output = None - - grad_recurrent_msg, grad_packed_state_before, receiver_phase_param_grads = ( - self._run_backend_state_public_backward_phase( - population_name=backend_population_name, - recurrent_msg=artifacts.recurrent_msg, - recurrent_hidden_tape=artifacts.recurrent_hidden, - packed_state_before=self._slice_receiver_window_batch_rows( - artifacts.packed_state_before, - artifact_receiver_window, - ), - population_reset_step=artifacts.population_reset_step, - grad_next_packed_state=running_grad_packed_state, - grad_recurrent_hidden=running_grad_recurrent_hidden, - grad_recurrent_k=_accumulate_grad(running_grad_recurrent_k, grad_recurrent_k_from_output), - grad_recurrent_v=_accumulate_grad(running_grad_recurrent_v, grad_recurrent_v_from_output), - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - sequence_static_tensors=static_tensors, - param_static_tensors=param_static_tensors, - transition_backward_tape=artifacts.transition_backward_tape, - need_grad_packed_state_before=step_index > 0 or packed_state_input_requires_grad, - device=boundary_seq.device, - dtype=boundary_seq.dtype, - active_receiver_window=artifact_receiver_window, - ) - ) - for parameter_index, grad_param in enumerate(receiver_phase_param_grads): - grad_param_accum[parameter_index] = _accumulate_owned_tensor_grad( - grad_param_accum[parameter_index], - grad_param, - ) - - ( - grad_recurrent_q, - grad_input_k_from_recurrent, - grad_input_v_from_recurrent, - grad_resolved_recurrent_k_before, - grad_resolved_recurrent_v_before, - ) = self._run_backend_message_backward_phase( - grad_msg=grad_recurrent_msg, - q_subset=recurrent_q_for_artifact, - input_k=artifacts.input_k, - input_v=artifacts.input_v, - recurrent_k=artifacts.recurrent_k_bank, - recurrent_v=artifacts.recurrent_v_bank, - neighbor_idx=self.recurrent_neighbor_idx - if artifact_receiver_window is None or not artifact_receiver_window.active - else self._cached_receiver_window_static_rows( - name="backward_recurrent_neighbor_idx", - 
tensor=self.recurrent_neighbor_idx, - window=artifact_receiver_window, - ), - neighbor_valid=self.recurrent_neighbor_valid - if artifact_receiver_window is None or not artifact_receiver_window.active - else self._cached_receiver_window_static_rows( - name="backward_recurrent_neighbor_valid", - tensor=self.recurrent_neighbor_valid, - window=artifact_receiver_window, - ), - edge_distance=self.recurrent_edge_distance - if artifact_receiver_window is None or not artifact_receiver_window.active - else self._cached_receiver_window_static_rows( - name="backward_recurrent_edge_distance", - tensor=self.recurrent_edge_distance, - window=artifact_receiver_window, - ), - edge_delay=self.recurrent_edge_delay - if artifact_receiver_window is None or not artifact_receiver_window.active - else self._cached_receiver_window_static_rows( - name="backward_recurrent_edge_delay", - tensor=self.recurrent_edge_delay, - window=artifact_receiver_window, - ), - local_sender_idx=recurrent_local_sender_idx, - local_receiver_idx_by_sender=recurrent_local_receiver_idx_by_sender, - use_sparse_messages=uses_sparse_messages, - ) - q_param_grads = self._run_backend_query_param_backward_phase( - grad_recurrent_q=grad_recurrent_q, - grad_output_q=grad_output_q, - sequence_static_tensors=param_static_tensors, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - device=boundary_seq.device, - dtype=boundary_seq.dtype, - active_receiver_window=artifact_receiver_window, - ) - for parameter_index, grad_param in enumerate(q_param_grads): - grad_param_accum[parameter_index] = _accumulate_owned_tensor_grad( - grad_param_accum[parameter_index], - grad_param, - ) - - total_grad_input_k = _accumulate_grad(grad_input_k_from_output, grad_input_k_from_recurrent) - total_grad_input_v = _accumulate_grad(grad_input_v_from_output, grad_input_v_from_recurrent) - if step_index == int(boundary_seq.shape[1]) - 1: - total_grad_input_k = _accumulate_grad(total_grad_input_k, grad_input_k_last) - total_grad_input_v = _accumulate_grad(total_grad_input_v, grad_input_v_last) - grad_boundary_step, boundary_param_grads = self._run_backend_boundary_public_backward_phase( - boundary_step=artifacts.boundary_step, - grad_input_k=total_grad_input_k, - grad_input_v=total_grad_input_v, - sequence_static_tensors=param_static_tensors, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - device=boundary_seq.device, - dtype=boundary_seq.dtype, - ) - if projected_boundary_active and grad_boundary_step is not None: - assert projected_boundary_source_seq is not None and projected_boundary_weight is not None - grad_boundary_flat = grad_boundary_step.reshape(int(grad_boundary_step.shape[0]), -1).contiguous() - with ( - torch.profiler.record_function("fabric.backward.input_projection"), - self._backend_owner_timing("input_projection"), - ): - if grad_projected_boundary_source_seq is not None: - grad_projected_boundary_source_seq[:, step_index].copy_( - grad_boundary_flat.matmul(projected_boundary_weight) - ) - if grad_projected_boundary_weight is not None: - grad_projected_boundary_weight.add_( - grad_boundary_flat.transpose(0, 1).matmul(projected_boundary_source_seq[:, step_index]) - ) - if grad_projected_boundary_bias is not None: - grad_projected_boundary_bias.add_(grad_boundary_flat.sum(dim=0)) - saw_projected_boundary_grad = True - elif grad_boundary_steps is not None: - grad_boundary_steps[step_index] = grad_boundary_step - for parameter_index, grad_param in enumerate(boundary_param_grads): - 
grad_param_accum[parameter_index] = _accumulate_owned_tensor_grad( - grad_param_accum[parameter_index], - grad_param, - ) - - ( - grad_hidden_before, - grad_initial_recurrent_k_before, - grad_initial_recurrent_v_before, - initial_recurrent_param_grads, - ) = self._run_backend_initial_recurrent_backward_phase( - hidden_before=artifacts.hidden_before, - initial_recurrent_k_before=artifacts.recurrent_k_before, - initial_recurrent_v_before=artifacts.recurrent_v_before, - population_reset_step=artifacts.population_reset_step, - grad_resolved_recurrent_k=grad_resolved_recurrent_k_before, - grad_resolved_recurrent_v=grad_resolved_recurrent_v_before, - sequence_static_tensors=param_static_tensors, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - active_receiver_window=artifacts.active_receiver_window, - device=boundary_seq.device, - dtype=boundary_seq.dtype, - ) - for parameter_index, grad_param in enumerate(initial_recurrent_param_grads): - grad_param_accum[parameter_index] = _accumulate_owned_tensor_grad( - grad_param_accum[parameter_index], - grad_param, - ) - - running_grad_packed_state = grad_packed_state_before - if step_index == 0: - grad_initial_hidden_total = _accumulate_grad(grad_initial_hidden_total, grad_hidden_before) - grad_initial_recurrent_k_total = _accumulate_grad( - grad_initial_recurrent_k_total, - grad_initial_recurrent_k_before, - ) - grad_initial_recurrent_v_total = _accumulate_grad( - grad_initial_recurrent_v_total, - grad_initial_recurrent_v_before, - ) - else: - running_grad_recurrent_hidden = grad_hidden_before - running_grad_recurrent_k = grad_initial_recurrent_k_before - running_grad_recurrent_v = grad_initial_recurrent_v_before - if step_artifacts is None: - if recompute_window_len > 1 and artifact_window_start <= int(step_index) < artifact_window_end: - artifact_window[int(step_index) - artifact_window_start] = None - del artifacts - - final_grad_packed_state = self._scatter_receiver_window_batch_rows( - running_grad_packed_state, - active_receiver_window, - like=external_packed_state_like, - ) - final_grad_initial_hidden = ( - None - if synthetic_initial_hidden - else self._scatter_receiver_window_batch_rows( - grad_initial_hidden_total, - active_receiver_window, - like=external_initial_hidden_like, - ) - ) - final_grad_initial_recurrent_k = ( - None - if synthetic_initial_recurrent_k - else self._scatter_receiver_window_recurrent_bank_grad( - grad_initial_recurrent_k_total, - active_receiver_window, - like=external_initial_recurrent_k_like, - ) - ) - final_grad_initial_recurrent_v = ( - None - if synthetic_initial_recurrent_v - else self._scatter_receiver_window_recurrent_bank_grad( - grad_initial_recurrent_v_total, - active_receiver_window, - like=external_initial_recurrent_v_like, - ) - ) - - grad_sequence_inputs: dict[str, torch.Tensor | None] = { - "initial_hidden": final_grad_initial_hidden, - "population_resets": None, - "initial_recurrent_k": final_grad_initial_recurrent_k, - "initial_recurrent_v": final_grad_initial_recurrent_v, - } - if projected_boundary_active: - grad_sequence_inputs["projected_boundary_source_seq"] = ( - grad_projected_boundary_source_seq if saw_projected_boundary_grad else None - ) - grad_sequence_inputs["projected_boundary_weight"] = ( - grad_projected_boundary_weight if saw_projected_boundary_grad else None - ) - grad_sequence_inputs["projected_boundary_bias"] = ( - grad_projected_boundary_bias if saw_projected_boundary_grad else None - ) - else: - assert grad_boundary_steps is not None - 
grad_sequence_inputs["boundary_seq"] = ( - torch.stack( - [ - cast(torch.Tensor, grad_step) - if grad_step is not None - else torch.zeros_like(boundary_seq[:, step_index]) - for step_index, grad_step in enumerate(grad_boundary_steps) - ], - dim=1, - ) - if any(grad_step is not None for grad_step in grad_boundary_steps) - else None - ) - if packed_state_keys is None: - grad_sequence_inputs["packed_state_0"] = cast(torch.Tensor | None, final_grad_packed_state) - else: - assert isinstance(final_grad_packed_state, (dict, TensorDictBase)) - for index, key in enumerate(packed_state_keys): - grad_sequence_inputs[f"packed_state_{index}"] = cast( - torch.Tensor | None, final_grad_packed_state.get(key) - ) - - self._receiver_window_static_tensor_dict_cache = {} - return ( - grad_sequence_inputs, - self._finalize_backward_param_grads( - trainable_params=trainable_params, - grad_param_accum=grad_param_accum, - active_receiver_window=active_receiver_window, - ), - ) diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/compiler/__init__.py b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/__init__.py new file mode 100644 index 00000000..69b403a1 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/__init__.py @@ -0,0 +1 @@ +"""Compiler tables and temporal sequence-surface planning helpers.""" diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/compiler/allocation_audit.py b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/allocation_audit.py new file mode 100644 index 00000000..baa314a0 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/allocation_audit.py @@ -0,0 +1,339 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Literal + + +TemporalAllocationOwner = Literal[ + "primitive_output", + "primitive_workspace", + "planned_runtime_buffer", + "metadata_row", + "illegal_scheduler_allocation", +] + + +@dataclass(frozen=True) +class TemporalRegisteredProgramAllocationRule: + relative_path: str + pattern: str + owner: TemporalAllocationOwner + reason: str + + +@dataclass(frozen=True) +class TemporalRegisteredProgramAllocationSite: + relative_path: str + line_number: int + line_text: str + owner: TemporalAllocationOwner | None + reason: str = "" + + @property + def summary(self) -> str: + owner = "unclassified" if self.owner is None else self.owner + return f"{self.relative_path}:{int(self.line_number)}:{owner}:{self.line_text.strip()}" + + +@dataclass(frozen=True) +class TemporalRegisteredProgramAllocationAudit: + sites: tuple[TemporalRegisteredProgramAllocationSite, ...] 
+ + @property + def unclassified(self) -> tuple[TemporalRegisteredProgramAllocationSite, ...]: + return tuple(site for site in self.sites if site.owner is None) + + @property + def illegal(self) -> tuple[TemporalRegisteredProgramAllocationSite, ...]: + return tuple(site for site in self.sites if site.owner == "illegal_scheduler_allocation") + + @property + def summaries(self) -> tuple[str, ...]: + return tuple(site.summary for site in self.sites) + + @property + def review_summary(self) -> tuple[str, ...]: + by_owner: dict[str, int] = {} + for site in self.sites: + owner = "unclassified" if site.owner is None else str(site.owner) + by_owner[owner] = by_owner.get(owner, 0) + 1 + owner_counts = ",".join(f"{owner}:{count}" for owner, count in sorted(by_owner.items())) + return ( + "registered_program_allocation_audit=compiler_owned", + f"site_count={len(self.sites)}", + f"unclassified_count={len(self.unclassified)}", + f"illegal_count={len(self.illegal)}", + f"owner_counts={owner_counts}", + ) + + +_REGISTERED_PROGRAM_ROOT = Path("src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program") + +_ALLOCATION_PATTERN = re.compile( + r"\b(?:at::(?:empty|zeros|full|zeros_like|empty_like)|[A-Za-z0-9_]+\.new_(?:empty|zeros))\s*\(" +) + + +def registered_program_allocation_rules() -> tuple[TemporalRegisteredProgramAllocationRule, ...]: + return _REGISTERED_PROGRAM_ALLOCATION_RULES + + +def audit_registered_program_allocations(repo_root: Path) -> TemporalRegisteredProgramAllocationAudit: + repo_root = Path(repo_root) + sites: list[TemporalRegisteredProgramAllocationSite] = [] + rule_cache = tuple( + ( + rule, + re.compile(rule.pattern, flags=re.MULTILINE), + ) + for rule in registered_program_allocation_rules() + ) + for source_path in sorted((repo_root / _REGISTERED_PROGRAM_ROOT).glob("*.cuh")): + relative_path = source_path.relative_to(repo_root).as_posix() + source_text = source_path.read_text(encoding="utf-8") + lines = source_text.splitlines() + for line_index, line in enumerate(lines): + if _ALLOCATION_PATTERN.search(line) is None: + continue + context = "\n".join(lines[line_index : min(len(lines), line_index + 4)]) + matched_rule = _allocation_rule_for_context( + rule_cache, + relative_path=relative_path, + context=context, + ) + sites.append( + TemporalRegisteredProgramAllocationSite( + relative_path=relative_path, + line_number=line_index + 1, + line_text=line, + owner=None if matched_rule is None else matched_rule.owner, + reason="" if matched_rule is None else matched_rule.reason, + ) + ) + return TemporalRegisteredProgramAllocationAudit(sites=tuple(sites)) + + +def assert_registered_program_allocations_are_classified(repo_root: Path) -> TemporalRegisteredProgramAllocationAudit: + audit = audit_registered_program_allocations(repo_root) + if audit.unclassified or audit.illegal: + details = "\n".join(site.summary for site in (*audit.unclassified, *audit.illegal)) + raise RuntimeError( + "Registered temporal program has unclassified or illegal allocation sites. 
" + "Classify each site as primitive_output, planned_runtime_buffer, metadata_row, " + "or illegal_scheduler_allocation before extending the compiler path.\n" + details + ) + return audit + + +def _allocation_rule_for_context( + rule_cache: tuple[tuple[TemporalRegisteredProgramAllocationRule, re.Pattern[str]], ...], + *, + relative_path: str, + context: str, +) -> TemporalRegisteredProgramAllocationRule | None: + for rule, pattern in rule_cache: + if rule.relative_path != relative_path: + continue + if pattern.search(context): + return rule + return None + + +def _primitive_output_file(relative_name: str) -> TemporalRegisteredProgramAllocationRule: + return TemporalRegisteredProgramAllocationRule( + relative_path=(_REGISTERED_PROGRAM_ROOT / relative_name).as_posix(), + pattern=r"\b(?:at::(?:empty|zeros|full|zeros_like|empty_like)|[A-Za-z0-9_]+\.new_(?:empty|zeros))\s*\(", + owner="primitive_output", + reason="allocations in this registered primitive implementation are returned primitive outputs or adjoint outputs", + ) + + +def _metadata_rule(relative_name: str, pattern: str, reason: str) -> TemporalRegisteredProgramAllocationRule: + return TemporalRegisteredProgramAllocationRule( + relative_path=(_REGISTERED_PROGRAM_ROOT / relative_name).as_posix(), + pattern=pattern, + owner="metadata_row", + reason=reason, + ) + + +def _primitive_output_rule(relative_name: str, pattern: str, reason: str) -> TemporalRegisteredProgramAllocationRule: + return TemporalRegisteredProgramAllocationRule( + relative_path=(_REGISTERED_PROGRAM_ROOT / relative_name).as_posix(), + pattern=pattern, + owner="primitive_output", + reason=reason, + ) + + +_REGISTERED_PROGRAM_ALLOCATION_RULES = ( + _primitive_output_file("operator_exports.cuh"), + _primitive_output_file("parameter_reducer_program.cuh"), + _primitive_output_file("reverse_artifacts_and_resets.cuh"), + _primitive_output_file("transition_forward_program.cuh"), + _primitive_output_file("transition_math_helpers.cuh"), + _primitive_output_file("transition_primitive_forward_ops.cuh"), + _primitive_output_file("transition_reverse_handlers.cuh"), + _primitive_output_file("transition_reverse_program.cuh"), + _metadata_rule( + "backward_program.cuh", + r"grad_output_window\.new_empty\(\{0\}\)", + "zero-size sentinel tensor used to represent absent carry/seed groups", + ), + _metadata_rule( + "backward_program.cuh", + r"empty = tensor\.new_empty\(\{0\}\)", + "zero-size sentinel tensor used when pruning transition output slots by compiler liveness rows", + ), + _metadata_rule( + "backward_program.cuh", + r"at::zeros\(\{1\}, empty\.options\(\)\.dtype\(at::kLong\)\)", + "empty transition seed-row metadata placeholder", + ), + _metadata_rule( + "backward_program.cuh", + r"at::empty\(\{old_rows\.size\(0\) \+ 1, 3\}", + "transition public-y seed metadata row extension", + ), + _metadata_rule( + "backward_program.cuh", + r"at::empty\(\{static_cast\(group_rows\.size\(\) / 3\), 3\}", + "transition next-seed metadata row materialization", + ), + _metadata_rule( + "backward_program.cuh", + r"at::empty\(\s*\{static_cast\(rows\.size\(\) / 5\), 5\}", + "native registered-backward memory telemetry metadata rows", + ), + _metadata_rule( + "backward_program.cuh", + r"new_empty\(\{0, 0, 0\}\)", + "zero-size rank-3 sentinel tensor for reduced fixed-slot key gradients consumed by value-only projection strategies", + ), + _metadata_rule( + "backward_surface_steps.cuh", + r"recurrent_hidden_backend_order\.new_empty\(\{0\}\)", + "zero-size sentinel tensor for absent initial recurrent K/V 
gradient", + ), + _metadata_rule( + "backward_surface_steps.cuh", + r"reference\.new_empty\(\{0\}\)", + "zero-size sentinel tensor for absent fixed-slot context value projection weight", + ), + TemporalRegisteredProgramAllocationRule( + relative_path=(_REGISTERED_PROGRAM_ROOT / "backward_surface_steps.cuh").as_posix(), + pattern=r"at::zeros\(\{value_weight\.size\(0\), value_weight\.size\(1\), head_dim\}", + owner="primitive_workspace", + reason="fixed-slot context local zero-key prefix workspace owned by the registered reverse message primitive", + ), + TemporalRegisteredProgramAllocationRule( + relative_path=(_REGISTERED_PROGRAM_ROOT / "backward_surface_steps.cuh").as_posix(), + pattern=r"at::zeros\(\s*\{grad_cells_out\.size\(0\), input_count \+ recurrent_count \+ route_output_count", + owner="primitive_workspace", + reason="route-split readout backward workspace owned by the registered reverse readout primitive", + ), + _metadata_rule( + "backward_surface_steps.cuh", + r"boundary_step\.new_empty\(\{0\}\)", + "zero-size sentinel tensor for absent fixed-slot context grouped input value weight", + ), + _metadata_rule( + "backward_surface_steps.cuh", + r"at::empty\(\{1\}, at::TensorOptions\(\)\.dtype\(at::kLong\)\.device\(at::kCPU\)\)", + "CPU grouped-flag metadata row returned with recurrent K/V parameter gradients", + ), + _metadata_rule( + "backward_surface_steps.cuh", + r"at::empty_like\(accumulated\)", + "CPU scalar grouped-flag metadata row produced while merging compiler-routed reverse span outputs", + ), + _metadata_rule( + "executor_span_decode.cuh", + r"at::empty\(\{executor_rows\.size\(0\), kFusedProgramSpanColumns\}", + "decoded span metadata rows", + ), + _metadata_rule( + "executor_span_decode.cuh", + r"at::empty\(\{0, kNativeStrategyRowColumns\}", + "empty native strategy metadata row table", + ), + _metadata_rule( + "forward_program.cuh", + r"boundary_seq\.new_empty\(\{0\}\)", + "zero-size sentinel tensor for optional recurrent K/V projection inputs", + ), + _metadata_rule( + "forward_program.cuh", + r"input_v_step\.new_empty\(\{0\}\)", + "zero-size sentinel tensor for compiler-selected message-to-transition streaming rows", + ), + _metadata_rule( + "forward_program.cuh", + r"at::empty\(\s*\{static_cast\(rows\.size\(\) / 5\), 5\}", + "native registered-forward memory telemetry metadata rows", + ), + _primitive_output_rule( + "forward_program.cuh", + r"at::empty\(\{time_steps, batch_size, sender_count, 2 \* key_part_dim\}", + "fixed-slot context sender-key bank emitted by the registered message primitive", + ), + _primitive_output_rule( + "forward_program.cuh", + r"at::empty\(\{T, B, sender_count, value_dim\}", + "fixed-slot context boundary value projection emitted by the registered message primitive", + ), + _primitive_output_rule( + "forward_program.cuh", + r"at::empty\(\{B, sender_count, value_dim\}", + "fixed-slot context recurrent value projection emitted by the registered message primitive", + ), + _metadata_rule( + "forward_program.cuh", + r"at::empty\(\s*\{row_count, kReverseArtifactBindingRowColumns\}", + "reverse artifact binding metadata rows returned by the forward program", + ), + TemporalRegisteredProgramAllocationRule( + relative_path=(_REGISTERED_PROGRAM_ROOT / "memory_runtime_buffers.cuh").as_posix(), + pattern=r"at::empty\(expected_shape, tensor\.options\(\)\)", + owner="planned_runtime_buffer", + reason="deferred local transition output materialized from compiler runtime-buffer rows at first use", + ), + _metadata_rule( + "program_spans_and_handlers.cuh", + 
r"at::empty\(\{7\}", + "fused-program validation summary metadata", + ), + _metadata_rule( + "program_spans_and_handlers.cuh", + r"at::empty\(\{0, kNativeStrategyRowColumns\}", + "empty native strategy metadata row table", + ), + _metadata_rule( + "program_spans_and_handlers.cuh", + r"at::empty\(\{0, kNativeCallableBindingSchemaRowColumns\}", + "empty native callable binding metadata row table", + ), + _metadata_rule( + "program_spans_and_handlers.cuh", + r"at::empty\(\{0, kNativeCallableOutputRowColumns\}", + "empty native callable output metadata row table", + ), + _metadata_rule( + "program_tensor_access.cuh", + r"at::empty\(\{0\}, tensor\.options\(\)\)", + "zero-size sentinel tensor used to compact final program tensor return slots by compiler liveness rows", + ), +) + + +__all__ = [ + "TemporalAllocationOwner", + "TemporalRegisteredProgramAllocationAudit", + "TemporalRegisteredProgramAllocationRule", + "TemporalRegisteredProgramAllocationSite", + "assert_registered_program_allocations_are_classified", + "audit_registered_program_allocations", + "registered_program_allocation_rules", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/compiler/backward_plan.py b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/backward_plan.py new file mode 100644 index 00000000..4e93b6b2 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/backward_plan.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +import torch + +from .executor_bindings import ( + TemporalExecutorBindingPlan, + build_temporal_reverse_executor_binding_plan, +) +from .strategy_selection import ( + TemporalStrategySelectionReport, + build_temporal_strategy_selection_report, +) +from .tables import ( + TemporalPrimitiveTablePlan, + temporal_reverse_executor_rows, + temporal_reverse_executor_rows_tensor, +) + + +@dataclass(frozen=True) +class TemporalBackwardExecutablePlan: + reverse_executor_rows: torch.Tensor + executor_binding_rows: torch.Tensor + executor_summaries: tuple[str, ...] + executor_binding_summaries: tuple[str, ...] + executor_binding_blockers: tuple[str, ...] + strategy_ids: tuple[str, ...] + strategy_candidate_summaries: tuple[str, ...] + strategy_legality_status: Literal["legal", "blocked"] + strategy_legality_reasons: tuple[str, ...] 
+ tape_policy: str + state_gradient_policy: str + parameter_gradient_policy: str + runtime_entrypoint: str = "registered_reverse_executor_bindings" + + @property + def review_summary(self) -> tuple[str, ...]: + return ( + "backward_executable_plan=compiler_owned", + f"runtime_entrypoint={self.runtime_entrypoint}", + f"executor_row_count={int(self.reverse_executor_rows.shape[0])}", + f"executor_binding_row_count={int(self.executor_binding_rows.shape[0])}", + "strategy_ids=" + ",".join(self.strategy_ids), + f"strategy_legality_status={self.strategy_legality_status}", + *self.strategy_legality_reasons, + f"tape_policy={self.tape_policy}", + f"state_gradient_policy={self.state_gradient_policy}", + f"parameter_gradient_policy={self.parameter_gradient_policy}", + *self.executor_binding_blockers, + *self.executor_summaries, + *self.executor_binding_summaries, + ) + + +def build_temporal_backward_executable_plan( + table: TemporalPrimitiveTablePlan, + *, + reverse_binding_plan: TemporalExecutorBindingPlan | None = None, + strategy_report: TemporalStrategySelectionReport | None = None, +) -> TemporalBackwardExecutablePlan: + executor_rows = temporal_reverse_executor_rows(table) + reverse_binding_plan = ( + build_temporal_reverse_executor_binding_plan(table) if reverse_binding_plan is None else reverse_binding_plan + ) + strategy_report = ( + build_temporal_strategy_selection_report( + table, + reverse_binding_plan=reverse_binding_plan, + directions=("reverse",), + ) + if strategy_report is None + else strategy_report + ) + reverse_candidates = tuple( + candidate for candidate in strategy_report.candidates if candidate.direction == "reverse" + ) + strategy_ids = tuple( + candidate.strategy_id for candidate in reverse_candidates if candidate.match_status == "matched" + ) + strategy_legality_reasons = tuple( + reason + for candidate in reverse_candidates + if candidate.match_status == "matched" + for reason in candidate.legality_reasons + ) + strategy_legality_status: Literal["legal", "blocked"] = ( + "blocked" + if reverse_binding_plan.has_blockers + or not strategy_ids + or any( + candidate.legality_status == "blocked" + for candidate in reverse_candidates + if candidate.match_status == "matched" + ) + else "legal" + ) + return TemporalBackwardExecutablePlan( + reverse_executor_rows=temporal_reverse_executor_rows_tensor(table), + executor_binding_rows=reverse_binding_plan.rows, + executor_summaries=tuple(row.summary for row in executor_rows), + executor_binding_summaries=reverse_binding_plan.summaries, + executor_binding_blockers=reverse_binding_plan.blocker_summaries, + strategy_ids=strategy_ids, + strategy_candidate_summaries=tuple(candidate.summary for candidate in reverse_candidates), + strategy_legality_status=strategy_legality_status, + strategy_legality_reasons=strategy_legality_reasons, + tape_policy="compiler_declared_tape_or_recompute", + state_gradient_policy="compiler_binding_owned_state_gradients", + parameter_gradient_policy="compiler_binding_owned_parameter_reductions", + ) + + +__all__ = [ + "TemporalBackwardExecutablePlan", + "build_temporal_backward_executable_plan", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/temporal_buckets.py b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/buckets.py similarity index 83% rename from src/cortical/fabric/backend/cuda/sequence_surface/temporal_buckets.py rename to src/cortical/fabric/backend/cuda/sequence_surface/compiler/buckets.py index d3c23e70..182203df 100644 --- 
a/src/cortical/fabric/backend/cuda/sequence_surface/temporal_buckets.py +++ b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/buckets.py @@ -5,7 +5,8 @@ import torch -from cortical.fabric.backend.cuda import transition_execution +from cortical.fabric.backend.cuda.transition_execution.projection import factorized_recurrent_input_prepack +from cortical.fabric.backend.flat_bucket_identity import transition_flat_bucket_identity POPULATION_STATIC_TENSORS_KEY = "_flat_bucket_population_static_tensors" TRAINABLE_ITEMS_KEY = "_flat_bucket_trainable_items" @@ -32,26 +33,47 @@ @dataclass(frozen=True) -class TemporalPopulationBucket: - name: str +class TemporalFlatBucket: + binding_name: str + binding_slot: int + flat_bucket_identity: tuple[str, ...] backend_start: int backend_stop: int recurrent_indices: torch.Tensor static_tensors: dict[str, object] + @property + def name(self) -> str: + return self.binding_name + @property def count(self) -> int: return int(self.backend_stop) - int(self.backend_start) + @property + def binding_identity(self) -> tuple[str, ...]: + return ( + "flat_bucket_binding", + f"binding_slot={int(self.binding_slot)}", + ) + @dataclass(frozen=True) class TemporalBucketPlan: - population_buckets: tuple[TemporalPopulationBucket, ...] + flat_buckets: tuple[TemporalFlatBucket, ...] executor: str = "temporal_bucket_sequence" + @property + def population_buckets(self) -> tuple[TemporalFlatBucket, ...]: + return self.flat_buckets + @property def population_names(self) -> tuple[str, ...]: - return tuple(bucket.name for bucket in self.population_buckets) + return tuple(bucket.binding_name for bucket in self.flat_buckets) + + @property + def flat_bucket_identities(self) -> tuple[tuple[str, ...], ...]: + return tuple(bucket.flat_bucket_identity for bucket in self.flat_buckets) @dataclass(frozen=True) @@ -97,15 +119,18 @@ def temporal_bucket_plan(runtime: Any, static_tensors: dict[str, object]) -> Tem if isinstance(cached, TemporalBucketPlan): return cached recurrent_count = int(runtime.recurrent_cell_idx.numel()) - buckets: list[TemporalPopulationBucket] = [] + buckets: list[TemporalFlatBucket] = [] for name in active_population_names(runtime): start, stop = runtime._population_backend_recurrent_slice(name) if start == stop: continue recurrent_indices = runtime._population_recurrent_indices(name) + binding_slot = int(runtime._population_name_to_idx[name]) buckets.append( - TemporalPopulationBucket( - name=name, + TemporalFlatBucket( + binding_name=name, + binding_slot=binding_slot, + flat_bucket_identity=_runtime_flat_bucket_identity(runtime, name, binding_slot), backend_start=int(start), backend_stop=int(stop), recurrent_indices=recurrent_indices, @@ -117,16 +142,40 @@ def temporal_bucket_plan(runtime: Any, static_tensors: dict[str, object]) -> Tem ), ) ) - plan = TemporalBucketPlan(population_buckets=tuple(buckets)) + plan = TemporalBucketPlan(flat_buckets=tuple(buckets)) static_tensors[BACKEND_ORDER_BUCKET_PLAN_KEY] = plan return plan +def backend_order_flat_buckets( + runtime: Any, + static_tensors: dict[str, object], +) -> tuple[TemporalFlatBucket, ...]: + return temporal_bucket_plan(runtime, static_tensors).flat_buckets + + def backend_order_population_buckets( runtime: Any, static_tensors: dict[str, object], -) -> tuple[TemporalPopulationBucket, ...]: - return temporal_bucket_plan(runtime, static_tensors).population_buckets +) -> tuple[TemporalFlatBucket, ...]: + return backend_order_flat_buckets(runtime, static_tensors) + + +def _runtime_flat_bucket_identity(runtime: 
Any, binding_name: str, binding_slot: int) -> tuple[str, ...]: + del binding_slot + cell_spec = runtime._cell_spec_for_population(binding_name) + population_spec = runtime._backend_population_specs.get(binding_name) + return transition_flat_bucket_identity( + state_schema_keys=cell_spec.state_schema.keys, + public_schema_kind=cell_spec.public_schema.kind, + parameter_schema_keys=cell_spec.parameter_schema.keys, + input_projection_schema_keys=cell_spec.input_projection_schema.keys, + public_projection_schema_keys=cell_spec.public_projection_schema.keys, + transition_ir=None if population_spec is None else population_spec.transition_ir, + transition_parameter_bindings=None + if population_spec is None + else population_spec.transition_parameter_bindings, + ) def temporal_backward_owner_plan( @@ -311,14 +360,15 @@ def population_static_tensors( input_proj_weight_t = population_params.get("input_proj_weight_t") if torch.is_tensor(input_proj_weight_t): out["input_proj_weight_t"] = input_proj_weight_t - recurrent_weight = out.get("value_to_cell_weight") + recurrent_weight_source = str(out.get("recurrent_message_to_cell_weight_source", "value_to_cell_weight")) + recurrent_weight = out.get(recurrent_weight_source) recurrent_bias = out.get("recurrent_cell_bias") if ( torch.is_tensor(recurrent_weight) and torch.is_tensor(recurrent_bias) and torch.is_tensor(input_proj_weight_t) ): - fused_weight, fused_bias = transition_execution.factorized_recurrent_input_prepack( + fused_weight, fused_bias = factorized_recurrent_input_prepack( value_to_cell_weight=recurrent_weight, recurrent_cell_bias=recurrent_bias, input_proj_weight_t=input_proj_weight_t, @@ -393,8 +443,9 @@ def flat_bucket_trainable_items( "TemporalBucketPlan", "TemporalBackwardOwner", "TemporalBackwardOwnerPlan", - "TemporalPopulationBucket", + "TemporalFlatBucket", "active_population_names", + "backend_order_flat_buckets", "backend_order_population_buckets", "flat_bucket_trainable_items", "population_static_tensors", diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_bindings.py b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_bindings.py new file mode 100644 index 00000000..d3f5b252 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_bindings.py @@ -0,0 +1,729 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +import torch + +from cortical.fabric.backend.cuda.transition_execution.registry import ( + transition_primitive_executor_record_for_lowered_primitive, +) + +from .tables import ( + TemporalForwardExecutorRow, + TemporalPrimitiveTablePlan, + TemporalReverseExecutorRow, + TemporalTensorBindingRow, + temporal_forward_executor_rows, + temporal_reverse_executor_rows, +) + + +TemporalExecutorDirection = Literal["forward", "reverse"] + +_DIRECTION_OPCODE = {"forward": 1, "reverse": 2} +_BINDING_KIND_OPCODE = {"input": 0, "parameter": 1, "output": 2} +_TRANSITION_PARAM_REDUCER_KIND_OPCODE = { + "materialized": 0, + "input_projection_weight": 1, + "input_projection_bias": 2, +} + + +@dataclass(frozen=True) +class TemporalExecutorTensorBinding: + direction: TemporalExecutorDirection + executor_row_index: int + executor_id: int + executor_name: str + surface: str + bucket_ordinal: int + receiver_start: int + receiver_count: int + local_binding_index: int + primitive_row_index: int + primitive: str + binding_index: int + binding_kind: str + logical_name: str + source_bindings: tuple[str, ...] 
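Usage aside (not part of the patch): a small sketch of enumerating the renamed flat buckets from compiler/buckets.py. `runtime` and `static_tensors` stand for the backend runtime object and its static-tensor dict as used elsewhere in this diff.

from cortical.fabric.backend.cuda.sequence_surface.compiler.buckets import (
    backend_order_flat_buckets,
    temporal_bucket_plan,
)

plan = temporal_bucket_plan(runtime, static_tensors)  # cached on the static-tensor dict
for bucket in backend_order_flat_buckets(runtime, static_tensors):
    # Identity comes from schema/IR rows, not from family names or graph shape.
    print(bucket.binding_name, bucket.binding_slot, bucket.count)
    print(bucket.flat_bucket_identity)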
+ + @property + def summary(self) -> str: + sources = ",".join(self.source_bindings) if self.source_bindings else "-" + return ( + f"direction={self.direction}" + f",executor_row={int(self.executor_row_index)}" + f",executor_id={int(self.executor_id)}" + f",executor={self.executor_name}" + f",surface={self.surface}" + f",bucket={int(self.bucket_ordinal)}" + f",local_binding={int(self.local_binding_index)}" + f",primitive_row={int(self.primitive_row_index)}" + f",primitive={self.primitive}" + f",binding={int(self.binding_index)}" + f",kind={self.binding_kind}" + f",logical={self.logical_name}" + f",sources={sources}" + ) + + +@dataclass(frozen=True) +class TemporalExecutorBindingBlocker: + direction: TemporalExecutorDirection + executor_row_index: int + executor_name: str + surface: str + bucket_ordinal: int + code: Literal["MISSING_REQUIRED_BINDING", "UNKNOWN_BINDING_KIND"] + reason: str + + @property + def summary(self) -> str: + return ( + f"direction={self.direction}" + f",executor_row={int(self.executor_row_index)}" + f",executor={self.executor_name}" + f",surface={self.surface}" + f",bucket={int(self.bucket_ordinal)}" + f",code={self.code}" + f",reason={self.reason}" + ) + + +@dataclass(frozen=True) +class TemporalExecutorBindingPlan: + direction: TemporalExecutorDirection + bindings: tuple[TemporalExecutorTensorBinding, ...] + blockers: tuple[TemporalExecutorBindingBlocker, ...] = () + + @property + def rows(self) -> torch.Tensor: + rows = [ + [ + _DIRECTION_OPCODE[self.direction], + int(binding.executor_row_index), + int(binding.executor_id), + int(binding.primitive_row_index), + int(binding.binding_index), + int(binding.bucket_ordinal), + _BINDING_KIND_OPCODE[str(binding.binding_kind)], + int(binding.local_binding_index), + ] + for binding in self.bindings + ] + if not rows: + return torch.empty((0, 8), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + @property + def summaries(self) -> tuple[str, ...]: + return tuple(binding.summary for binding in self.bindings) + + @property + def blocker_summaries(self) -> tuple[str, ...]: + return tuple(blocker.summary for blocker in self.blockers) + + @property + def has_blockers(self) -> bool: + return bool(self.blockers) + + +@dataclass(frozen=True) +class TemporalTransitionParamGradBinding: + executor_row_index: int + executor_id: int + executor_name: str + bucket_ordinal: int + transition_primitive_row_index: int + reduction_primitive_row_index: int + parameter_name: str + grad_logical_name: str + grad_binding_index: int + parameter_binding_index: int + reducer_kind: str + source_bindings: tuple[str, ...] 
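Usage aside (not part of the patch): a hedged sketch of materializing reverse executor binding rows and failing closed on blockers. `table` again stands for an existing TemporalPrimitiveTablePlan; the trailing comment restates the row layout encoded by TemporalExecutorBindingPlan.rows.

from cortical.fabric.backend.cuda.sequence_surface.compiler.executor_bindings import (
    build_temporal_reverse_executor_binding_plan,
)

binding_plan = build_temporal_reverse_executor_binding_plan(table)
if binding_plan.has_blockers:
    # Every blocker carries its executor row, surface, bucket, and reason.
    raise RuntimeError("\n".join(binding_plan.blocker_summaries))
rows = binding_plan.rows
# rows is an [N, 8] long tensor:
# [direction, executor_row_index, executor_id, primitive_row_index,
#  binding_index, bucket_ordinal, binding_kind, local_binding_index]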
+ selected_static_source: str = "" + + @property + def summary(self) -> str: + sources = ",".join(self.source_bindings) if self.source_bindings else "-" + selected_source = self.selected_static_source or "-" + return ( + f"executor_row={int(self.executor_row_index)}" + f",executor_id={int(self.executor_id)}" + f",executor={self.executor_name}" + f",bucket={int(self.bucket_ordinal)}" + f",transition_primitive_row={int(self.transition_primitive_row_index)}" + f",reduction_primitive_row={int(self.reduction_primitive_row_index)}" + f",parameter={self.parameter_name}" + f",grad_logical={self.grad_logical_name}" + f",grad_binding={int(self.grad_binding_index)}" + f",parameter_binding={int(self.parameter_binding_index)}" + f",reducer={self.reducer_kind}" + f",sources={sources}" + f",selected_static_source={selected_source}" + ) + + +@dataclass(frozen=True) +class TemporalTransitionParamGradBindingPlan: + bindings: tuple[TemporalTransitionParamGradBinding, ...] + + @property + def rows(self) -> torch.Tensor: + rows = [ + [ + int(binding.executor_row_index), + int(binding.executor_id), + int(binding.bucket_ordinal), + int(binding.transition_primitive_row_index), + int(binding.reduction_primitive_row_index), + int(binding.grad_binding_index), + int(binding.parameter_binding_index), + int(_TRANSITION_PARAM_REDUCER_KIND_OPCODE[str(binding.reducer_kind)]), + ] + for binding in self.bindings + ] + if not rows: + return torch.empty((0, 8), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + @property + def summaries(self) -> tuple[str, ...]: + return tuple(binding.summary for binding in self.bindings) + + +def build_temporal_forward_executor_binding_plan( + table: TemporalPrimitiveTablePlan, +) -> TemporalExecutorBindingPlan: + return _build_temporal_executor_binding_plan( + table, + direction="forward", + executor_rows=temporal_forward_executor_rows(table), + ) + + +def build_temporal_reverse_executor_binding_plan( + table: TemporalPrimitiveTablePlan, +) -> TemporalExecutorBindingPlan: + return _build_temporal_executor_binding_plan( + table, + direction="reverse", + executor_rows=temporal_reverse_executor_rows(table), + ) + + +def build_temporal_transition_param_grad_binding_plan( + table: TemporalPrimitiveTablePlan, + reverse_binding_plan: TemporalExecutorBindingPlan | None = None, +) -> TemporalTransitionParamGradBindingPlan: + reverse_binding_plan = ( + build_temporal_reverse_executor_binding_plan(table) if reverse_binding_plan is None else reverse_binding_plan + ) + bindings: list[TemporalTransitionParamGradBinding] = [] + for executor_row_index, executor_row in enumerate(temporal_reverse_executor_rows(table)): + if executor_row.surface != "transition": + continue + primitive_row = table.primitive_rows[int(executor_row.primitive_row_start)] + primitive = str(primitive_row.primitive) + primitive_record = transition_primitive_executor_record_for_lowered_primitive(primitive) + grad_contract = _transition_param_grad_contract(primitive_row, primitive_record) + for grad_logical_name, parameter_name, reducer_kind in grad_contract: + grad_binding = _find_executor_binding( + reverse_binding_plan, + executor_row_index=int(executor_row_index), + binding_kind="output", + logical_name=grad_logical_name, + ) + parameter_binding = _find_executor_binding( + reverse_binding_plan, + executor_row_index=int(executor_row_index), + binding_kind="parameter", + logical_name=parameter_name, + ) + reduction_row_index = _find_transition_parameter_reduction_row_index( + table, + 
bucket_ordinal=int(executor_row.bucket_ordinal), + parameter_name=parameter_name, + ) + if grad_binding is None or parameter_binding is None or reduction_row_index is None: + continue + bindings.append( + TemporalTransitionParamGradBinding( + executor_row_index=int(executor_row_index), + executor_id=int(executor_row.executor_id), + executor_name=executor_row.executor_name, + bucket_ordinal=int(executor_row.bucket_ordinal), + transition_primitive_row_index=int(executor_row.primitive_row_start), + reduction_primitive_row_index=int(reduction_row_index), + parameter_name=parameter_name, + grad_logical_name=grad_logical_name, + grad_binding_index=int(grad_binding.binding_index), + parameter_binding_index=int(parameter_binding.binding_index), + reducer_kind=reducer_kind, + source_bindings=tuple(parameter_binding.source_bindings), + selected_static_source=_selected_transition_static_source( + table, + parameter_binding=parameter_binding, + parameter_name=parameter_name, + reducer_kind=reducer_kind, + ), + ) + ) + return TemporalTransitionParamGradBindingPlan(bindings=tuple(bindings)) + + +def _transition_param_grad_contract( + primitive_row: object, + primitive_record: object | None, +) -> tuple[tuple[str, str, str], ...]: + if primitive_record is None: + return () + generic_parameter_bindings = tuple(getattr(primitive_record, "parameter_bindings", ())) + actual_parameters = tuple(str(parameter) for parameter in getattr(primitive_row, "parameter_inputs", ())) + if ( + generic_parameter_bindings + and set(generic_parameter_bindings) <= {"weight", "bias", "eps"} + and actual_parameters + ): + contracts: list[tuple[str, str, str]] = [] + for generic_name, actual_name in zip(generic_parameter_bindings, actual_parameters, strict=False): + if generic_name == "eps": + continue + contracts.append( + ( + f"grad_{actual_name}", + actual_name, + _transition_generic_param_reducer_kind(generic_name, actual_name), + ) + ) + return tuple(contracts) + return tuple(getattr(primitive_record, "param_grad_outputs", ())) + + +def _transition_generic_param_reducer_kind(generic_name: str, actual_name: str) -> str: + if generic_name == "weight" and actual_name in {"value_to_state_weight", "input_proj_weight"}: + return "input_projection_weight" + if generic_name == "bias" and actual_name in {"recurrent_bias", "recurrent_cell_bias"}: + return "input_projection_bias" + return "materialized" + + +def _selected_transition_static_source( + _table: TemporalPrimitiveTablePlan, + *, + parameter_binding: TemporalExecutorTensorBinding, + parameter_name: str, + reducer_kind: str, +) -> str: + if reducer_kind != "input_projection_weight": + return "" + if parameter_name not in {"value_to_state_weight", "input_proj_weight"}: + return "" + sources = set(str(source) for source in parameter_binding.source_bindings) + if ( + "static_tensor:fused_recurrent_value_to_cell_weight" in sources + and "static_tensor:message_to_cell_weight" in sources + ): + return "message_to_cell_weight" + has_value_to_cell_source = any(source.endswith(":value_to_cell_weight") for source in sources) + if has_value_to_cell_source: + return "" + return "message_to_cell_weight" if "static_tensor:message_to_cell_weight" in sources else "" + + +def _build_temporal_executor_binding_plan( + table: TemporalPrimitiveTablePlan, + *, + direction: TemporalExecutorDirection, + executor_rows: tuple[TemporalForwardExecutorRow | TemporalReverseExecutorRow, ...], +) -> TemporalExecutorBindingPlan: + bindings: list[TemporalExecutorTensorBinding] = [] + blockers: 
list[TemporalExecutorBindingBlocker] = [] + next_synthetic_binding_index = 1 + max( + (int(binding.binding_index) for binding in table.tensor_bindings), + default=-1, + ) + reverse_grad_binding_indices: dict[tuple[int, str], int] = {} + for executor_row_index, executor_row in enumerate(executor_rows): + if direction == "reverse" and executor_row.surface == "transition": + local_bindings, next_synthetic_binding_index = _transition_reverse_executor_contract_bindings( + table, + executor_row, + executor_row_index=int(executor_row_index), + next_synthetic_binding_index=int(next_synthetic_binding_index), + grad_binding_indices=reverse_grad_binding_indices, + ) + else: + local_bindings = _executor_bindings_from_tensor_binding_rows( + direction=direction, + executor_row=executor_row, + executor_row_index=int(executor_row_index), + rows=_tensor_bindings_for_executor_row(table, executor_row), + ) + present_parameters = {binding.logical_name for binding in local_bindings if binding.binding_kind == "parameter"} + for parameter in tuple(executor_row.parameter_bindings): + if str(parameter) not in present_parameters: + blockers.append( + TemporalExecutorBindingBlocker( + direction=direction, + executor_row_index=int(executor_row_index), + executor_name=executor_row.executor_name, + surface=executor_row.surface, + bucket_ordinal=int(executor_row.bucket_ordinal), + code="MISSING_REQUIRED_BINDING", + reason=f"executor parameter {parameter!r} has no compiler tensor binding row", + ) + ) + for binding in local_bindings: + if binding.binding_kind not in _BINDING_KIND_OPCODE: + blockers.append( + TemporalExecutorBindingBlocker( + direction=direction, + executor_row_index=int(executor_row_index), + executor_name=executor_row.executor_name, + surface=executor_row.surface, + bucket_ordinal=int(executor_row.bucket_ordinal), + code="UNKNOWN_BINDING_KIND", + reason=f"binding {binding.binding_index} uses unknown kind {binding.binding_kind!r}", + ) + ) + continue + bindings.append(binding) + return TemporalExecutorBindingPlan( + direction=direction, + bindings=tuple(bindings), + blockers=tuple(blockers), + ) + + +def _find_executor_binding( + plan: TemporalExecutorBindingPlan, + *, + executor_row_index: int, + binding_kind: str, + logical_name: str, +) -> TemporalExecutorTensorBinding | None: + matches = tuple( + binding + for binding in plan.bindings + if int(binding.executor_row_index) == int(executor_row_index) + and binding.binding_kind == str(binding_kind) + and binding.logical_name == str(logical_name) + ) + if not matches: + return None + return matches[0] + + +def _find_transition_parameter_reduction_row_index( + table: TemporalPrimitiveTablePlan, + *, + bucket_ordinal: int, + parameter_name: str, +) -> int | None: + for row_index, row in enumerate(table.primitive_rows): + if "surface=parameter_reduction" not in row.flat_bucket_identity: + continue + bucket_attr = next((value for key, value in row.attributes if key == "bucket_ordinal"), None) + parameter_attr = next((value for key, value in row.attributes if key == "parameter"), None) + if bucket_attr == str(int(bucket_ordinal)) and parameter_attr == str(parameter_name): + return int(row_index) + return None + + +def _executor_bindings_from_tensor_binding_rows( + *, + direction: TemporalExecutorDirection, + executor_row: TemporalForwardExecutorRow | TemporalReverseExecutorRow, + executor_row_index: int, + rows: tuple[TemporalTensorBindingRow, ...], +) -> tuple[TemporalExecutorTensorBinding, ...]: + return tuple( + TemporalExecutorTensorBinding( + 
direction=direction, + executor_row_index=int(executor_row_index), + executor_id=int(executor_row.executor_id), + executor_name=executor_row.executor_name, + surface=executor_row.surface, + bucket_ordinal=int(executor_row.bucket_ordinal), + receiver_start=int(executor_row.receiver_start), + receiver_count=int(executor_row.receiver_count), + local_binding_index=int(local_binding_index), + primitive_row_index=int(binding.row_index), + primitive=binding.primitive, + binding_index=int(binding.binding_index), + binding_kind=binding.binding_kind, + logical_name=binding.logical_name, + source_bindings=tuple(binding.source_bindings), + ) + for local_binding_index, binding in enumerate(rows) + ) + + +def _transition_reverse_executor_contract_bindings( + table: TemporalPrimitiveTablePlan, + executor_row: TemporalForwardExecutorRow | TemporalReverseExecutorRow, + *, + executor_row_index: int, + next_synthetic_binding_index: int, + grad_binding_indices: dict[tuple[int, str], int], +) -> tuple[tuple[TemporalExecutorTensorBinding, ...], int]: + primitive_row = table.primitive_rows[int(executor_row.primitive_row_start)] + primitive = str(primitive_row.primitive) + primitive_record = transition_primitive_executor_record_for_lowered_primitive(primitive) + if ( + primitive_record is None + or not primitive_record.reverse_input_bindings + or not primitive_record.reverse_output_bindings + ): + return ( + _executor_bindings_from_tensor_binding_rows( + direction="reverse", + executor_row=executor_row, + executor_row_index=int(executor_row_index), + rows=_tensor_bindings_for_executor_row(table, executor_row), + ), + int(next_synthetic_binding_index), + ) + input_names = primitive_record.reverse_input_bindings + parameter_names = primitive_record.parameter_bindings + output_names = primitive_record.reverse_output_bindings + parameter_names_for_schema = parameter_names + actual_parameter_names = tuple(primitive_row.parameter_inputs) + if parameter_names and set(parameter_names) <= {"weight", "bias", "eps"} and actual_parameter_names: + parameter_names = actual_parameter_names + + def make_binding( + *, + local_binding_index: int, + binding_kind: str, + logical_name: str, + binding_index: int, + source_bindings: tuple[str, ...] = (), + ) -> TemporalExecutorTensorBinding: + return TemporalExecutorTensorBinding( + direction="reverse", + executor_row_index=int(executor_row_index), + executor_id=int(executor_row.executor_id), + executor_name=executor_row.executor_name, + surface=executor_row.surface, + bucket_ordinal=int(executor_row.bucket_ordinal), + receiver_start=int(executor_row.receiver_start), + receiver_count=int(executor_row.receiver_count), + local_binding_index=int(local_binding_index), + primitive_row_index=int(executor_row.primitive_row_start), + primitive=primitive, + binding_index=int(binding_index), + binding_kind=binding_kind, + logical_name=logical_name, + source_bindings=source_bindings, + ) + + synthetic_binding_index = int(next_synthetic_binding_index) + + def transition_binding( + logical_name: str, + *, + binding_kinds: tuple[str, ...], + ) -> TemporalTensorBindingRow | None: + return _find_transition_tensor_binding( + table, + bucket_ordinal=int(executor_row.bucket_ordinal), + logical_name=logical_name, + binding_kinds=binding_kinds, + ) + + def grad_logical_name(name: str) -> str: + return "grad_" + str(name).removeprefix("next_") + + def grad_value_logical_name(name: str) -> str: + return "grad_" + str(name) + + def grad_binding( + logical_name: str, + *, + source_bindings: tuple[str, ...] 
= (), + ) -> tuple[int, tuple[str, ...]]: + nonlocal synthetic_binding_index + key = (int(executor_row.bucket_ordinal), str(logical_name)) + binding_index = grad_binding_indices.get(key) + if binding_index is None: + binding_index = synthetic_binding_index + synthetic_binding_index += 1 + grad_binding_indices[key] = int(binding_index) + return int(binding_index), source_bindings + + def actual_input_name(schema_name: str) -> str: + if schema_name == "input" and primitive_row.inputs: + return str(primitive_row.inputs[0]) + if schema_name == "output" and primitive_row.outputs: + return str(primitive_row.outputs[0]) + if schema_name == "grad_output" and primitive_row.outputs: + output_name = str(primitive_row.outputs[0]) + if output_name.startswith("next_"): + return "grad_next_" + output_name.removeprefix("next_") + return grad_logical_name(output_name) + return str(schema_name) + + def actual_output_name(schema_name: str, local_index: int) -> str: + if schema_name == "grad_input" and primitive_row.inputs: + return grad_value_logical_name(str(primitive_row.inputs[0])) + if schema_name.startswith("grad_"): + parameter_name = schema_name.removeprefix("grad_") + try: + parameter_index = parameter_names_for_schema.index(parameter_name) + except ValueError: + return str(schema_name) + if parameter_index < len(parameter_names): + return grad_logical_name(parameter_names[parameter_index]) + return str(schema_name) + + def seed_source_for_grad(logical_name: str) -> tuple[str, ...]: + if logical_name == "grad_public_y": + return ("reverse_seed:grad_public_y",) + if logical_name.startswith("grad_next_"): + return (f"reverse_seed:{logical_name}",) + return (f"reverse_internal:{logical_name}",) + + result: list[TemporalExecutorTensorBinding] = [] + for local_binding_index, logical_name in enumerate(input_names): + actual_logical_name = actual_input_name(str(logical_name)) + if actual_logical_name.startswith("grad_"): + binding_index, source_bindings = grad_binding( + actual_logical_name, + source_bindings=seed_source_for_grad(actual_logical_name), + ) + else: + row = transition_binding( + actual_logical_name, + binding_kinds=("input", "output"), + ) + if row is None: + binding_index = synthetic_binding_index + synthetic_binding_index += 1 + source_bindings = (f"reverse_runtime:{actual_logical_name}",) + else: + binding_index = int(row.binding_index) + source_bindings = tuple(row.source_bindings) + result.append( + make_binding( + local_binding_index=int(local_binding_index), + binding_kind="input", + logical_name=actual_logical_name, + binding_index=int(binding_index), + source_bindings=source_bindings, + ) + ) + for local_binding_index, logical_name in enumerate(parameter_names): + row = transition_binding( + str(logical_name), + binding_kinds=("parameter",), + ) + if row is None: + continue + binding_index = int(row.binding_index) + source_bindings = tuple(row.source_bindings) + result.append( + make_binding( + local_binding_index=int(local_binding_index), + binding_kind="parameter", + logical_name=logical_name, + binding_index=int(binding_index), + source_bindings=source_bindings, + ) + ) + for local_binding_index, logical_name in enumerate(output_names): + actual_logical_name = actual_output_name(str(logical_name), int(local_binding_index)) + if actual_logical_name.startswith("grad_"): + binding_index, source_bindings = grad_binding(actual_logical_name) + else: + binding_index = synthetic_binding_index + synthetic_binding_index += 1 + source_bindings = () + result.append( + make_binding( + 
local_binding_index=int(local_binding_index), + binding_kind="output", + logical_name=actual_logical_name, + binding_index=int(binding_index), + source_bindings=source_bindings, + ) + ) + return tuple(result), int(synthetic_binding_index) + + +def _find_transition_tensor_binding( + table: TemporalPrimitiveTablePlan, + *, + bucket_ordinal: int, + logical_name: str, + binding_kinds: tuple[str, ...], +) -> TemporalTensorBindingRow | None: + matches = tuple( + binding + for binding in table.tensor_bindings + if binding.surface == "transition" + and int(binding.bucket_ordinal) == int(bucket_ordinal) + and binding.logical_name == str(logical_name) + and binding.binding_kind in binding_kinds + ) + if not matches: + return None + return sorted( + matches, + key=lambda binding: ( + 0 if binding.binding_kind == "output" else 1, + int(binding.row_index), + int(binding.binding_index), + ), + )[0] + + +def _tensor_bindings_for_executor_row( + table: TemporalPrimitiveTablePlan, + executor_row: TemporalForwardExecutorRow | TemporalReverseExecutorRow, +) -> tuple[TemporalTensorBindingRow, ...]: + row_start = int(executor_row.primitive_row_start) + row_stop = row_start + int(executor_row.primitive_row_count) + parameter_names = {str(parameter) for parameter in executor_row.parameter_bindings} + selected: list[TemporalTensorBindingRow] = [] + seen: set[int] = set() + for binding in table.tensor_bindings: + in_row_span = row_start <= int(binding.row_index) < row_stop + is_executor_parameter = ( + binding.binding_kind == "parameter" + and binding.surface == executor_row.surface + and int(binding.bucket_ordinal) == int(executor_row.bucket_ordinal) + and binding.logical_name in parameter_names + ) + if not (in_row_span or is_executor_parameter): + continue + if int(binding.binding_index) in seen: + continue + seen.add(int(binding.binding_index)) + selected.append(binding) + return tuple( + sorted( + selected, + key=lambda binding: ( + 0 if binding.binding_kind == "input" else 1 if binding.binding_kind == "parameter" else 2, + int(binding.row_index), + int(binding.binding_index), + ), + ) + ) + + +__all__ = [ + "TemporalExecutorBindingBlocker", + "TemporalExecutorBindingPlan", + "TemporalExecutorDirection", + "TemporalExecutorTensorBinding", + "TemporalTransitionParamGradBinding", + "TemporalTransitionParamGradBindingPlan", + "build_temporal_forward_executor_binding_plan", + "build_temporal_reverse_executor_binding_plan", + "build_temporal_transition_param_grad_binding_plan", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py new file mode 100644 index 00000000..9a626029 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py @@ -0,0 +1,1040 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +from cortical.fabric.backend.message_rules import ( + MessageRuleNativeExecutorSpec, + build_message_rule_backend_spec, + compile_message_rule, + ordered_message_rule_backend_spec_types, +) +from cortical.fabric.backend.readout_rules import ( + ReadoutRuleNativeExecutorSpec, + build_readout_rule_backend_spec, + compile_readout_rule, + default_readout_rule_ir, + readout_rule_native_executor, + registered_readout_rule_backend_spec_lowering_kinds, +) +from cortical.fabric.backend.cuda.transition_execution.registry import ( + TransitionExecutorStrategySpec, + registered_transition_forward_strategy_specs, + 
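`_find_transition_tensor_binding` above resolves a logical name to a single tensor-binding row with a deterministic preference order. A self-contained sketch of that tie-break over a toy binding type (not the plan's real row class):

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ToyBinding:
    binding_kind: str  # "input", "output", or "parameter"
    row_index: int
    binding_index: int


def pick_preferred(bindings: tuple[ToyBinding, ...]) -> ToyBinding | None:
    # Producer ("output") bindings win over consumers; ties fall back to the
    # earliest primitive row, then the lowest binding index.
    if not bindings:
        return None
    return sorted(
        bindings,
        key=lambda b: (0 if b.binding_kind == "output" else 1, b.row_index, b.binding_index),
    )[0]


assert pick_preferred((ToyBinding("input", 1, 2), ToyBinding("output", 7, 9))).binding_kind == "output"
```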
registered_transition_reverse_strategy_specs, +) + +from .row_groups import ( + TEMPORAL_MESSAGE_BUCKET_ORDINAL, + TEMPORAL_READOUT_BUCKET_ORDINAL, + TemporalRowGroupSchema, + canonical_temporal_row_group, + pattern_temporal_row_group, + surface_for_temporal_row, +) + + +@dataclass(frozen=True) +class TemporalPrimitiveRowPattern: + primitive: str + parameter_inputs: tuple[str, ...] = () + attribute_constraints: tuple[tuple[str, str], ...] = () + + @property + def signature(self) -> tuple[str, tuple[str, ...], tuple[tuple[str, str], ...]]: + return self.primitive, self.parameter_inputs, self.attribute_constraints + + +@dataclass(frozen=True) +class TemporalProgramAccessPattern: + access_name: str + logical_name: str + binding_kind: str = "parameter" + required: bool = True + access_opcode: int = 0 + + @property + def stable_access_opcode(self) -> int: + if int(self.access_opcode) <= 0: + raise RuntimeError(f"Registered temporal program access has no opcode: {self.access_name!r}") + return int(self.access_opcode) + + +TemporalMessageParamGradSource = Literal["recurrent_query_grad", "boundary_extra_output"] + + +@dataclass(frozen=True) +class TemporalMessageParamGradOutputPattern: + logical_name: str + source: TemporalMessageParamGradSource + source_index: int = 0 + + @property + def summary(self) -> str: + return f"{self.logical_name}:{self.source}[{int(self.source_index)}]" + + +@dataclass(frozen=True) +class TemporalForwardExecutorPattern: + executor_id: int + executor_name: str + surface: str + bucket_ordinal: int | None + row_pattern: tuple[TemporalPrimitiveRowPattern, ...] + implementation_contract: str + strategy_id: str = "" + handler_kind_opcode: int | None = None + handler_capabilities: tuple[str, ...] = () + handler_effects: tuple[str, ...] = () + owner: str = "fabric_cuda_sequence_surface" + strategy_version: int = 1 + row_schema_version: int = 1 + tensor_binding_schema_version: int = 1 + metadata_schema_version: int = 1 + cuda_kernel_abi_version: int = 1 + legality_predicate: str = "match_structural_row_signature" + cost_model: str = "registered_executor_static_priority" + runtime_entrypoint: str = "registered_temporal_fused_forward_program_cuda" + native_callable: str = "" + cxx_entrypoints: tuple[str, ...] = () + cxx_entrypoint_phases: tuple[str, ...] = () + required_effects: tuple[str, ...] = () + match_effects: tuple[str, ...] = () + required_layouts: tuple[str, ...] = ("contiguous",) + supported_dtypes: tuple[str, ...] = ("float32",) + supported_devices: tuple[str, ...] = ("cuda",) + workspace: str = "planner_assigned" + aliasing: str = "planner_owned_no_secret_retention" + saved_tensor_policy: str = "planner_owned_store_or_recompute" + gradient_accumulation: str = "compiler_binding_owned" + determinism: str = "parity_guarded" + tolerance_class: str = "fabric_cuda_training_parity" + demotion_policy: str = "typed_reject_or_fail_closed" + audit_metadata_schema: str = "temporal_strategy_metadata_v1" + verified_rewrite_required: bool = False + program_accesses: tuple[TemporalProgramAccessPattern, ...] = () + state_carry_rules: tuple[tuple[str, str], ...] = () + parameter_reducer_kind: str = "" + message_param_grad_outputs: tuple[TemporalMessageParamGradOutputPattern, ...] 
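The row pattern's `signature` is a purely structural key: primitive name, ordered parameter inputs, and any pinned attribute constraints. A small sketch of how two row groups compare under that key; the primitive names below are placeholders, and the real comparison goes through `row_groups.canonical_temporal_row_group`, which is not shown in this diff:

```python
# Structural row signature sketch (placeholder primitive names).
RowSignature = tuple[str, tuple[str, ...], tuple[tuple[str, str], ...]]


def row_signature(
    primitive: str,
    parameter_inputs: tuple[str, ...] = (),
    attribute_constraints: tuple[tuple[str, str], ...] = (),
) -> RowSignature:
    return (primitive, parameter_inputs, attribute_constraints)


pattern_rows = (
    row_signature("demo_projection", ("weight",)),
    row_signature("demo_norm", ("weight", "eps")),
)
candidate_rows = (
    row_signature("demo_projection", ("weight",)),
    row_signature("demo_norm", ("weight", "eps")),
)
# Matching is structural: same primitives, same parameter order, same constraints.
assert pattern_rows == candidate_rows
```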
= () + + @property + def row_signature(self) -> tuple[tuple[str, tuple[str, ...], tuple[tuple[str, str], ...]], ...]: + return tuple(row.signature for row in self.row_pattern) + + @property + def row_group_schema(self) -> TemporalRowGroupSchema: + return pattern_temporal_row_group( + surface=self.surface, + bucket_ordinal=self.bucket_ordinal, + rows=self.row_pattern, + match_effects=self.match_effects, + ) + + @property + def stable_strategy_id(self) -> str: + return self.strategy_id or self.executor_name + + @property + def stable_native_callable_id(self) -> str: + return self.native_callable or self.implementation_contract + + @property + def stable_handler_kind_opcode(self) -> int: + return int(self.executor_id if self.handler_kind_opcode is None else self.handler_kind_opcode) + + @property + def stable_handler_capabilities(self) -> tuple[str, ...]: + if self.handler_capabilities: + return self.handler_capabilities + if self.surface == "message": + return ("message_carrier",) + if self.surface == "readout": + return ("readout",) + if self.surface == "transition": + return ("transition",) + return () + + @property + def stable_handler_effects(self) -> tuple[str, ...]: + if self.handler_effects: + return self.handler_effects + if self.surface == "message": + return "state_read", "parameter_read", "message_emit" + if self.surface == "readout": + return "state_read", "parameter_read", "output_emit" + if self.surface == "transition": + return "state_read", "message_read", "state_write", "tape_policy" + return () + + @property + def summary(self) -> str: + bucket = "*" if self.bucket_ordinal is None else str(int(self.bucket_ordinal)) + primitives = "+".join(row.primitive for row in self.row_pattern) + return ( + f"executor={self.executor_name}" + f",strategy_id={self.stable_strategy_id}" + f",strategy_version={int(self.strategy_version)}" + f",id={int(self.executor_id)}" + f",surface={self.surface}" + f",bucket={bucket}" + f",rows={primitives}" + f",legality={self.legality_predicate}" + f",cost_model={self.cost_model}" + f",runtime={self.runtime_entrypoint}" + f",native_callable={self.stable_native_callable_id}" + f",cxx_entrypoints={'+'.join(self.cxx_entrypoints) if self.cxx_entrypoints else '-'}" + f",cxx_entrypoint_phases={'+'.join(self.cxx_entrypoint_phases) if self.cxx_entrypoint_phases else '-'}" + f",handler_kind={int(self.stable_handler_kind_opcode)}" + f",handler_capabilities={'+'.join(self.stable_handler_capabilities) if self.stable_handler_capabilities else '-'}" + f",handler_effects={'+'.join(self.stable_handler_effects) if self.stable_handler_effects else '-'}" + f",abi=rows{int(self.row_schema_version)}" + f"/tensor{int(self.tensor_binding_schema_version)}" + f"/metadata{int(self.metadata_schema_version)}" + f"/cuda{int(self.cuda_kernel_abi_version)}" + f",effects={'+'.join(self.required_effects) if self.required_effects else '-'}" + f",program_accesses={'+'.join(access.access_name for access in self.program_accesses) if self.program_accesses else '-'}" + f",state_carry_rules={len(self.state_carry_rules)}" + f",parameter_reducer={self.parameter_reducer_kind or '-'}" + f",message_param_grad_outputs={'+'.join(output.summary for output in self.message_param_grad_outputs) if self.message_param_grad_outputs else '-'}" + f",rewrite_required={int(bool(self.verified_rewrite_required))}" + f",implementation_contract={self.implementation_contract}" + ) + + +@dataclass(frozen=True) +class TemporalReverseExecutorPattern: + executor_id: int + executor_name: str + surface: str + 
bucket_ordinal: int | None + row_pattern: tuple[TemporalPrimitiveRowPattern, ...] + implementation_contract: str + strategy_id: str = "" + handler_kind_opcode: int | None = None + handler_capabilities: tuple[str, ...] = () + handler_effects: tuple[str, ...] = () + owner: str = "fabric_cuda_sequence_surface" + strategy_version: int = 1 + row_schema_version: int = 1 + tensor_binding_schema_version: int = 1 + metadata_schema_version: int = 1 + cuda_kernel_abi_version: int = 1 + legality_predicate: str = "match_structural_row_signature" + cost_model: str = "registered_executor_static_priority" + runtime_entrypoint: str = "registered_reverse_executor_bindings" + native_callable: str = "" + cxx_entrypoints: tuple[str, ...] = () + cxx_entrypoint_phases: tuple[str, ...] = () + required_effects: tuple[str, ...] = () + match_effects: tuple[str, ...] = () + required_layouts: tuple[str, ...] = ("contiguous",) + supported_dtypes: tuple[str, ...] = ("float32",) + supported_devices: tuple[str, ...] = ("cuda",) + workspace: str = "planner_assigned" + aliasing: str = "planner_owned_no_secret_retention" + saved_tensor_policy: str = "planner_owned_store_or_recompute" + gradient_accumulation: str = "compiler_binding_owned" + determinism: str = "parity_guarded" + tolerance_class: str = "fabric_cuda_training_parity" + demotion_policy: str = "typed_reject_or_fail_closed" + audit_metadata_schema: str = "temporal_strategy_metadata_v1" + verified_rewrite_required: bool = False + program_accesses: tuple[TemporalProgramAccessPattern, ...] = () + state_carry_rules: tuple[tuple[str, str], ...] = () + parameter_reducer_kind: str = "" + message_param_grad_outputs: tuple[TemporalMessageParamGradOutputPattern, ...] = () + + @property + def row_signature(self) -> tuple[tuple[str, tuple[str, ...], tuple[tuple[str, str], ...]], ...]: + return tuple(row.signature for row in self.row_pattern) + + @property + def row_group_schema(self) -> TemporalRowGroupSchema: + return pattern_temporal_row_group( + surface=self.surface, + bucket_ordinal=self.bucket_ordinal, + rows=self.row_pattern, + match_effects=self.match_effects, + ) + + @property + def stable_strategy_id(self) -> str: + return self.strategy_id or self.executor_name + + @property + def stable_native_callable_id(self) -> str: + return self.native_callable or self.implementation_contract + + @property + def stable_handler_kind_opcode(self) -> int: + return int(self.executor_id if self.handler_kind_opcode is None else self.handler_kind_opcode) + + @property + def stable_handler_capabilities(self) -> tuple[str, ...]: + if self.handler_capabilities: + return self.handler_capabilities + if self.surface == "message": + return ("message_carrier",) + if self.surface == "readout": + return ("readout",) + if self.surface == "transition": + return ("transition",) + return () + + @property + def stable_handler_effects(self) -> tuple[str, ...]: + if self.handler_effects: + return self.handler_effects + if self.surface in {"message", "readout"}: + return "grad_read", "parameter_grad_emit" + if self.surface == "transition": + return "grad_read", "state_write", "parameter_grad_emit", "tape_policy" + return () + + @property + def summary(self) -> str: + bucket = "*" if self.bucket_ordinal is None else str(int(self.bucket_ordinal)) + primitives = "+".join(row.primitive for row in self.row_pattern) + return ( + f"reverse_executor={self.executor_name}" + f",strategy_id={self.stable_strategy_id}" + f",strategy_version={int(self.strategy_version)}" + f",id={int(self.executor_id)}" + 
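`stable_handler_capabilities` and `stable_handler_effects` let a strategy omit handler metadata and fall back to a surface-derived default. A compact sketch of that fallback rule:

```python
def default_handler_capabilities(surface: str, declared: tuple[str, ...] = ()) -> tuple[str, ...]:
    # Explicit declarations win; otherwise the surface implies one capability.
    if declared:
        return declared
    defaults = {
        "message": ("message_carrier",),
        "readout": ("readout",),
        "transition": ("transition",),
    }
    return defaults.get(surface, ())


assert default_handler_capabilities("message") == ("message_carrier",)
assert default_handler_capabilities("readout", declared=("custom_capability",)) == ("custom_capability",)
```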
f",surface={self.surface}" + f",bucket={bucket}" + f",rows={primitives}" + f",legality={self.legality_predicate}" + f",cost_model={self.cost_model}" + f",runtime={self.runtime_entrypoint}" + f",native_callable={self.stable_native_callable_id}" + f",cxx_entrypoints={'+'.join(self.cxx_entrypoints) if self.cxx_entrypoints else '-'}" + f",cxx_entrypoint_phases={'+'.join(self.cxx_entrypoint_phases) if self.cxx_entrypoint_phases else '-'}" + f",handler_kind={int(self.stable_handler_kind_opcode)}" + f",handler_capabilities={'+'.join(self.stable_handler_capabilities) if self.stable_handler_capabilities else '-'}" + f",handler_effects={'+'.join(self.stable_handler_effects) if self.stable_handler_effects else '-'}" + f",abi=rows{int(self.row_schema_version)}" + f"/tensor{int(self.tensor_binding_schema_version)}" + f"/metadata{int(self.metadata_schema_version)}" + f"/cuda{int(self.cuda_kernel_abi_version)}" + f",effects={'+'.join(self.required_effects) if self.required_effects else '-'}" + f",program_accesses={'+'.join(access.access_name for access in self.program_accesses) if self.program_accesses else '-'}" + f",state_carry_rules={len(self.state_carry_rules)}" + f",parameter_reducer={self.parameter_reducer_kind or '-'}" + f",message_param_grad_outputs={'+'.join(output.summary for output in self.message_param_grad_outputs) if self.message_param_grad_outputs else '-'}" + f",rewrite_required={int(bool(self.verified_rewrite_required))}" + f",implementation_contract={self.implementation_contract}" + ) + + +TemporalExecutorDirection = Literal["forward", "reverse"] + + +@dataclass(frozen=True) +class TemporalExecutorStrategyRegistry: + forward_strategies: tuple[TemporalForwardExecutorPattern, ...] + reverse_strategies: tuple[TemporalReverseExecutorPattern, ...] + + def __post_init__(self) -> None: + self._validate_direction("forward", self.forward_strategies) + self._validate_direction("reverse", self.reverse_strategies) + self._validate_access_opcodes() + + def forward_patterns(self) -> tuple[TemporalForwardExecutorPattern, ...]: + return self.forward_strategies + + def reverse_patterns(self) -> tuple[TemporalReverseExecutorPattern, ...]: + return self.reverse_strategies + + def all_patterns(self) -> tuple[TemporalForwardExecutorPattern | TemporalReverseExecutorPattern, ...]: + return (*self.forward_strategies, *self.reverse_strategies) + + def match_forward( + self, + *, + surface: str, + bucket_ordinal: int, + rows: tuple[object, ...], + ) -> TemporalForwardExecutorPattern | None: + return self._match( + direction="forward", + surface=surface, + bucket_ordinal=bucket_ordinal, + rows=rows, + ) + + def match_reverse( + self, + *, + surface: str, + bucket_ordinal: int, + rows: tuple[object, ...], + ) -> TemporalReverseExecutorPattern | None: + return self._match( + direction="reverse", + surface=surface, + bucket_ordinal=bucket_ordinal, + rows=rows, + ) + + def forward_pattern_for_executor( + self, + *, + surface: str, + executor_name: str, + ) -> TemporalForwardExecutorPattern: + matches = tuple( + pattern + for pattern in self.forward_strategies + if pattern.surface == str(surface) and pattern.executor_name == str(executor_name) + ) + if len(matches) != 1: + raise RuntimeError( + "Registered temporal forward strategy registry has no unique executor strategy: " + f"surface={surface!r}; executor={executor_name!r}; matches={len(matches)}" + ) + return matches[0] + + def reverse_pattern_for_executor( + self, + *, + surface: str, + executor_name: str, + ) -> TemporalReverseExecutorPattern: + matches = tuple( 
+ pattern + for pattern in self.reverse_strategies + if pattern.surface == str(surface) and pattern.executor_name == str(executor_name) + ) + if len(matches) != 1: + raise RuntimeError( + "Registered temporal reverse strategy registry has no unique executor strategy: " + f"surface={surface!r}; executor={executor_name!r}; matches={len(matches)}" + ) + return matches[0] + + def strategy_summaries(self) -> tuple[str, ...]: + return tuple(pattern.summary for pattern in self.all_patterns()) + + def _match( + self, + *, + direction: TemporalExecutorDirection, + surface: str, + bucket_ordinal: int, + rows: tuple[object, ...], + ) -> TemporalForwardExecutorPattern | TemporalReverseExecutorPattern | None: + candidate = canonical_temporal_row_group( + surface=surface, + bucket_ordinal=int(bucket_ordinal), + rows=rows, + ) + patterns = self.forward_strategies if direction == "forward" else self.reverse_strategies + for pattern in patterns: + if pattern.surface != surface: + continue + if pattern.bucket_ordinal is not None and int(pattern.bucket_ordinal) != int(bucket_ordinal): + continue + if pattern.row_group_schema.matches(candidate): + return pattern + return None + + @staticmethod + def _validate_direction( + direction: TemporalExecutorDirection, + patterns: tuple[TemporalForwardExecutorPattern | TemporalReverseExecutorPattern, ...], + ) -> None: + seen_strategy_ids: set[str] = set() + seen_executor_keys: set[ + tuple[str, str, int, tuple[tuple[str, tuple[str, ...], tuple[tuple[str, str], ...]], ...]] + ] = set() + for pattern in patterns: + strategy_id = pattern.stable_strategy_id + if not strategy_id.startswith(f"{direction}."): + raise RuntimeError( + "Registered temporal executor strategy id must include its direction prefix: " + f"direction={direction}; strategy_id={strategy_id!r}" + ) + if strategy_id in seen_strategy_ids: + raise RuntimeError( + "Registered temporal executor strategy ids must be unique: " + f"direction={direction}; strategy_id={strategy_id!r}" + ) + seen_strategy_ids.add(strategy_id) + executor_key = (pattern.surface, pattern.executor_name, int(pattern.executor_id), pattern.row_signature) + if executor_key in seen_executor_keys: + raise RuntimeError( + "Registered temporal executor strategy rows must have unique executor keys: " + f"direction={direction}; surface={pattern.surface}; executor={pattern.executor_name}; " + f"executor_id={int(pattern.executor_id)}" + ) + seen_executor_keys.add(executor_key) + if not pattern.row_signature: + raise RuntimeError( + "Registered temporal executor strategy must declare a row signature: " + f"direction={direction}; strategy_id={strategy_id!r}" + ) + if not pattern.implementation_contract.startswith("registered_"): + raise RuntimeError( + "Registered temporal executor strategy must use a compiler-owned implementation contract: " + f"direction={direction}; strategy_id={strategy_id!r}; " + f"contract={pattern.implementation_contract!r}" + ) + if not pattern.stable_native_callable_id.startswith(f"native.{direction}."): + raise RuntimeError( + "Registered temporal executor strategy must declare a direction-scoped native callable id: " + f"direction={direction}; strategy_id={strategy_id!r}; " + f"native_callable={pattern.stable_native_callable_id!r}" + ) + for access in pattern.program_accesses: + _ = access.stable_access_opcode + if pattern.surface in {"message", "readout"} and not pattern.cxx_entrypoints: + raise RuntimeError( + "Registered temporal message/readout strategies must declare native C++ entrypoints: " + 
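`_match` walks the registered strategies in declaration order and returns the first one whose surface, bucket (with `None` as a wildcard), and row group schema all accept the candidate. A toy version that replaces the schema check with plain tuple equality and uses placeholder primitive names:

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ToyStrategy:
    surface: str
    bucket_ordinal: int | None  # None matches any bucket
    row_signature: tuple[str, ...]


def match_first(
    strategies: tuple[ToyStrategy, ...],
    *,
    surface: str,
    bucket_ordinal: int,
    rows: tuple[str, ...],
) -> ToyStrategy | None:
    # Declaration order is the priority order; no name-based or ad hoc routing.
    for strategy in strategies:
        if strategy.surface != surface:
            continue
        if strategy.bucket_ordinal is not None and strategy.bucket_ordinal != bucket_ordinal:
            continue
        if strategy.row_signature == rows:
            return strategy
    return None


registry = (ToyStrategy("transition", None, ("demo_gate", "demo_state_update")),)
assert match_first(registry, surface="transition", bucket_ordinal=5, rows=("demo_gate", "demo_state_update")) is registry[0]
assert match_first(registry, surface="message", bucket_ordinal=0, rows=("demo_gate",)) is None
```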
f"direction={direction}; strategy_id={strategy_id!r}" + ) + if pattern.cxx_entrypoint_phases and len(pattern.cxx_entrypoint_phases) != len(pattern.cxx_entrypoints): + raise RuntimeError( + "Registered temporal strategy C++ entrypoint phases must match entrypoint arity: " + f"direction={direction}; strategy_id={strategy_id!r}; " + f"phases={pattern.cxx_entrypoint_phases!r}; entrypoints={pattern.cxx_entrypoints!r}" + ) + if len(set(pattern.cxx_entrypoint_phases)) != len(pattern.cxx_entrypoint_phases): + raise RuntimeError( + "Registered temporal strategy C++ entrypoint phases must be unique: " + f"direction={direction}; strategy_id={strategy_id!r}; phases={pattern.cxx_entrypoint_phases!r}" + ) + if pattern.surface == "message": + expected_phase_contracts = _allowed_message_cxx_entrypoint_phases(direction) + if tuple(pattern.cxx_entrypoint_phases) not in expected_phase_contracts: + raise RuntimeError( + "Registered temporal message strategies must declare named C++ entrypoint phases: " + f"direction={direction}; strategy_id={strategy_id!r}; " + f"expected={expected_phase_contracts!r}; actual={pattern.cxx_entrypoint_phases!r}" + ) + if pattern.surface == "readout": + expected_phases = _required_readout_cxx_entrypoint_phases(direction) + if tuple(pattern.cxx_entrypoint_phases) != expected_phases: + raise RuntimeError( + "Registered temporal readout strategies must declare named C++ entrypoint phases: " + f"direction={direction}; strategy_id={strategy_id!r}; " + f"expected={expected_phases!r}; actual={pattern.cxx_entrypoint_phases!r}" + ) + if direction == "reverse" and pattern.surface == "transition" and not pattern.cxx_entrypoints: + raise RuntimeError( + "Registered temporal reverse transition strategies must declare native C++ entrypoints: " + f"direction={direction}; strategy_id={strategy_id!r}" + ) + if pattern.parameter_reducer_kind and pattern.surface != "message": + raise RuntimeError( + "Registered temporal parameter reducer strategy metadata is only supported on message strategies: " + f"direction={direction}; strategy_id={strategy_id!r}; surface={pattern.surface!r}" + ) + if pattern.message_param_grad_outputs and not pattern.parameter_reducer_kind: + raise RuntimeError( + "Registered temporal message strategy declares parameter-gradient outputs without a reducer kind: " + f"direction={direction}; strategy_id={strategy_id!r}" + ) + + def _validate_access_opcodes(self) -> None: + access_opcodes: dict[str, int] = {} + opcode_names: dict[int, str] = {} + for pattern in self.all_patterns(): + for access in pattern.program_accesses: + opcode = int(access.stable_access_opcode) + name = str(access.access_name) + existing_opcode = access_opcodes.get(name) + if existing_opcode is not None and int(existing_opcode) != opcode: + raise RuntimeError( + "Registered temporal program access must use one stable opcode: " + f"access={name!r}; opcodes={(int(existing_opcode), opcode)!r}" + ) + existing_name = opcode_names.get(opcode) + if existing_name is not None and existing_name != name: + raise RuntimeError( + "Registered temporal program access opcodes must be unique: " + f"opcode={opcode}; names={(existing_name, name)!r}" + ) + access_opcodes[name] = opcode + opcode_names[opcode] = name + + +def _message_rule_primitive_row_pattern(rule_type: str) -> tuple[TemporalPrimitiveRowPattern, ...]: + spec = build_message_rule_backend_spec( + rule_type=str(rule_type), + kv_group_count=1, + cell_count=2, + ) + compiled = compile_message_rule(spec.to_ir()) + return tuple( + TemporalPrimitiveRowPattern( + 
op.primitive, + tuple(op.parameter_bindings), + ) + for op in compiled.primitive_ops + ) + + +def _required_message_cxx_entrypoint_phases(direction: str) -> tuple[str, ...]: + if direction == "forward": + return "bind", "recurrent_kv", "message" + if direction == "reverse": + return ( + "recurrent_kv_backward", + "recurrent_message_backward", + "initial_recurrent_kv_backward", + "boundary_kv_backward", + "recurrent_kv_forward_recompute", + ) + raise RuntimeError(f"Unknown registered temporal message executor direction {direction!r}") + + +def _allowed_message_cxx_entrypoint_phases(direction: str) -> tuple[tuple[str, ...], ...]: + required = _required_message_cxx_entrypoint_phases(direction) + if direction == "forward": + return ( + required, + (*required, "keyless_readout_message"), + (*required, "keyless_readout_message", "direct_keyless_readout_message"), + ( + *required, + "keyless_readout_message", + "direct_keyless_readout_message", + "stream_readout_message", + ), + ( + *required, + "keyless_readout_message", + "direct_keyless_readout_message", + "stream_readout_message", + "stream_transition_input", + ), + ) + return (required,) + + +def _required_readout_cxx_entrypoint_phases(direction: str) -> tuple[str, ...]: + if direction == "forward": + return "bind", "message", "projection", "projection_into" + if direction == "reverse": + return "readout_backward", "output_message_backward" + raise RuntimeError(f"Unknown registered temporal readout executor direction {direction!r}") + + +def _message_rule_program_accesses(rule_type: str) -> tuple[TemporalProgramAccessPattern, ...]: + spec = build_message_rule_backend_spec( + rule_type=str(rule_type), + kv_group_count=1, + cell_count=2, + ) + accesses: list[TemporalProgramAccessPattern] = [] + for tensor_spec in spec.static_tensors: + access_name = str(tensor_spec.program_access_name) + access_opcode = int(tensor_spec.program_access_opcode) + if not access_name: + continue + accesses.append( + TemporalProgramAccessPattern( + access_name, + str(tensor_spec.name), + access_opcode=access_opcode, + ) + ) + return tuple(accesses) + + +def _message_rule_parameter_reducer_kind(rule_type: str) -> str: + return str( + build_message_rule_backend_spec( + rule_type=str(rule_type), + kv_group_count=1, + cell_count=2, + ).parameter_reducer_kind + ) + + +def _message_rule_param_grad_outputs(rule_type: str) -> tuple[TemporalMessageParamGradOutputPattern, ...]: + spec = build_message_rule_backend_spec( + rule_type=str(rule_type), + kv_group_count=1, + cell_count=2, + ) + return tuple( + TemporalMessageParamGradOutputPattern( + output.logical_name, + output.source, + int(output.source_index), + ) + for output in spec.param_grad_outputs + ) + + +def _message_rule_native_executor( + rule_type: str, + direction: TemporalExecutorDirection, +) -> MessageRuleNativeExecutorSpec: + spec = build_message_rule_backend_spec( + rule_type=str(rule_type), + kv_group_count=1, + cell_count=2, + ) + matches = tuple(executor for executor in spec.native_executors if executor.direction == direction) + if len(matches) != 1: + raise RuntimeError( + "Registered message rule must declare one native executor for each supported direction: " + f"rule_type={rule_type!r}; direction={direction!r}; matches={len(matches)}" + ) + return matches[0] + + +def _message_rule_has_native_executor( + rule_type: str, + direction: TemporalExecutorDirection, +) -> bool: + spec = build_message_rule_backend_spec( + rule_type=str(rule_type), + kv_group_count=1, + cell_count=2, + ) + return 
any(executor.direction == direction for executor in spec.native_executors) + + +def _message_rule_forward_executor_pattern(rule_type: str) -> TemporalForwardExecutorPattern: + native_executor = _message_rule_native_executor(rule_type, "forward") + return TemporalForwardExecutorPattern( + executor_id=int(native_executor.executor_id), + executor_name=str(native_executor.executor_name), + surface="message", + bucket_ordinal=TEMPORAL_MESSAGE_BUCKET_ORDINAL, + row_pattern=_message_rule_primitive_row_pattern(rule_type), + implementation_contract=str(native_executor.implementation_contract), + strategy_id=str(native_executor.strategy_id), + native_callable=str(native_executor.native_callable), + cxx_entrypoints=tuple(str(item) for item in native_executor.cxx_entrypoints), + cxx_entrypoint_phases=tuple(str(item) for item in native_executor.cxx_entrypoint_phases), + strategy_version=int(native_executor.strategy_version), + required_effects=("state_read", "parameter_read", "message_emit"), + match_effects=("state_read", "parameter_read", "message_emit"), + program_accesses=_message_rule_program_accesses(rule_type), + ) + + +def _message_rule_reverse_executor_pattern(rule_type: str) -> TemporalReverseExecutorPattern: + native_executor = _message_rule_native_executor(rule_type, "reverse") + return TemporalReverseExecutorPattern( + executor_id=int(native_executor.executor_id), + executor_name=str(native_executor.executor_name), + surface="message", + bucket_ordinal=TEMPORAL_MESSAGE_BUCKET_ORDINAL, + row_pattern=_message_rule_primitive_row_pattern(rule_type), + implementation_contract=str(native_executor.implementation_contract), + strategy_id=str(native_executor.strategy_id), + native_callable=str(native_executor.native_callable), + cxx_entrypoints=tuple(str(item) for item in native_executor.cxx_entrypoints), + cxx_entrypoint_phases=tuple(str(item) for item in native_executor.cxx_entrypoint_phases), + strategy_version=int(native_executor.strategy_version), + required_effects=("grad_read", "message_grad_emit", "parameter_grad_emit"), + match_effects=("state_read", "parameter_read", "message_emit"), + program_accesses=_message_rule_program_accesses(rule_type), + parameter_reducer_kind=_message_rule_parameter_reducer_kind(rule_type), + message_param_grad_outputs=_message_rule_param_grad_outputs(rule_type), + ) + + +def _registered_message_rule_forward_executor_patterns() -> tuple[TemporalForwardExecutorPattern, ...]: + return tuple( + _message_rule_forward_executor_pattern(rule_type) + for rule_type in ordered_message_rule_backend_spec_types() + if _message_rule_has_native_executor(rule_type, "forward") + ) + + +def _registered_message_rule_reverse_executor_patterns() -> tuple[TemporalReverseExecutorPattern, ...]: + return tuple( + _message_rule_reverse_executor_pattern(rule_type) + for rule_type in ordered_message_rule_backend_spec_types() + if _message_rule_has_native_executor(rule_type, "reverse") + ) + + +def _readout_pool_for_lowering_kind(lowering_kind: str) -> str: + suffix = "_readout_project" + if not str(lowering_kind).endswith(suffix): + raise RuntimeError(f"Registered readout lowering kind has unsupported shape: {lowering_kind!r}") + return str(lowering_kind)[: -len(suffix)] + + +def _readout_rule_primitive_row_pattern(lowering_kind: str) -> tuple[TemporalPrimitiveRowPattern, ...]: + compiled = compile_readout_rule( + default_readout_rule_ir( + readout_pool=_readout_pool_for_lowering_kind(lowering_kind), + readout_slots=1, + ) + ) + return tuple( + TemporalPrimitiveRowPattern( + op.primitive, 
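`_allowed_message_cxx_entrypoint_phases` accepts only the required forward phases plus a fixed ladder of optional trailing phases; anything else fails closed in registry validation. A small sketch of that ladder construction:

```python
def allowed_forward_phase_sets(required: tuple[str, ...]) -> tuple[tuple[str, ...], ...]:
    # The allowed contracts are the required phases alone, or the required phases
    # extended by successively longer prefixes of the optional ladder.
    optional_ladder = (
        "keyless_readout_message",
        "direct_keyless_readout_message",
        "stream_readout_message",
        "stream_transition_input",
    )
    return (required, *((*required, *optional_ladder[:count]) for count in range(1, len(optional_ladder) + 1)))


required = ("bind", "recurrent_kv", "message")
allowed = allowed_forward_phase_sets(required)
assert required in allowed
assert (*required, "keyless_readout_message", "direct_keyless_readout_message") in allowed
assert ("bind", "message") not in allowed
```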
+ tuple(op.parameter_inputs), + tuple( + (str(key), str(value)) + for key, value in op.attributes + if str(key) in {"lowering_kind", "pool", "output_boundary"} + ), + ) + for op in compiled.primitive_ops + ) + + +def _readout_rule_program_accesses(lowering_kind: str) -> tuple[TemporalProgramAccessPattern, ...]: + spec = build_readout_rule_backend_spec(lowering_kind=lowering_kind) + return tuple( + TemporalProgramAccessPattern( + access.access_name, + access.logical_name, + access.binding_kind, + access.required, + access.access_opcode, + ) + for access in ( + _readout_rule_program_access_from_static_tensor(static_tensor) + for static_tensor in spec.static_tensors + if static_tensor.program_access_name + ) + ) + + +def _readout_rule_program_access_from_static_tensor( + static_tensor: object, +) -> TemporalProgramAccessPattern: + return TemporalProgramAccessPattern( + str(getattr(static_tensor, "program_access_name")), + str(getattr(static_tensor, "name")), + access_opcode=int(getattr(static_tensor, "program_access_opcode")), + ) + + +def _readout_rule_native_executor_for_lowering( + lowering_kind: str, + direction: TemporalExecutorDirection, +) -> ReadoutRuleNativeExecutorSpec: + return readout_rule_native_executor( + lowering_kind=lowering_kind, + direction=direction, + ) + + +def _readout_rule_has_native_executor( + lowering_kind: str, + direction: TemporalExecutorDirection, +) -> bool: + spec = build_readout_rule_backend_spec(lowering_kind=lowering_kind) + return any(executor.direction == direction for executor in spec.native_executors) + + +def _readout_rule_forward_executor_pattern(lowering_kind: str) -> TemporalForwardExecutorPattern: + native_executor = _readout_rule_native_executor_for_lowering(lowering_kind, "forward") + return TemporalForwardExecutorPattern( + executor_id=int(native_executor.executor_id), + executor_name=str(native_executor.executor_name), + surface="readout", + bucket_ordinal=TEMPORAL_READOUT_BUCKET_ORDINAL, + row_pattern=_readout_rule_primitive_row_pattern(lowering_kind), + implementation_contract=str(native_executor.implementation_contract), + strategy_id=str(native_executor.strategy_id), + native_callable=str(native_executor.native_callable), + cxx_entrypoints=tuple(str(item) for item in native_executor.cxx_entrypoints), + cxx_entrypoint_phases=tuple(str(item) for item in native_executor.cxx_entrypoint_phases), + strategy_version=int(native_executor.strategy_version), + required_effects=("state_read", "parameter_read", "output_emit", "materialization_boundary"), + match_effects=("state_read", "parameter_read", "output_emit", "materialization_boundary"), + program_accesses=_readout_rule_program_accesses(lowering_kind), + ) + + +def _readout_rule_reverse_executor_pattern(lowering_kind: str) -> TemporalReverseExecutorPattern: + native_executor = _readout_rule_native_executor_for_lowering(lowering_kind, "reverse") + return TemporalReverseExecutorPattern( + executor_id=int(native_executor.executor_id), + executor_name=str(native_executor.executor_name), + surface="readout", + bucket_ordinal=TEMPORAL_READOUT_BUCKET_ORDINAL, + row_pattern=_readout_rule_primitive_row_pattern(lowering_kind), + implementation_contract=str(native_executor.implementation_contract), + strategy_id=str(native_executor.strategy_id), + native_callable=str(native_executor.native_callable), + cxx_entrypoints=tuple(str(item) for item in native_executor.cxx_entrypoints), + cxx_entrypoint_phases=tuple(str(item) for item in native_executor.cxx_entrypoint_phases), + 
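`_readout_pool_for_lowering_kind` derives the readout pool from the lowering-kind name and rejects anything that does not follow the `*_readout_project` shape. A short sketch; the pool name in the example is a placeholder, not a claim about the registered pools:

```python
def readout_pool_for_lowering_kind(lowering_kind: str) -> str:
    # Fail closed on unexpected lowering-kind shapes instead of guessing a pool.
    suffix = "_readout_project"
    if not lowering_kind.endswith(suffix):
        raise RuntimeError(f"unsupported readout lowering kind: {lowering_kind!r}")
    return lowering_kind[: -len(suffix)]


assert readout_pool_for_lowering_kind("demo_pool_readout_project") == "demo_pool"
```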
strategy_version=int(native_executor.strategy_version), + required_effects=("grad_read", "message_grad_emit", "parameter_grad_emit"), + match_effects=("state_read", "parameter_read", "output_emit", "materialization_boundary"), + program_accesses=_readout_rule_program_accesses(lowering_kind), + ) + + +def _registered_readout_rule_forward_executor_patterns() -> tuple[TemporalForwardExecutorPattern, ...]: + return tuple( + _readout_rule_forward_executor_pattern(lowering_kind) + for lowering_kind in registered_readout_rule_backend_spec_lowering_kinds() + if _readout_rule_has_native_executor(lowering_kind, "forward") + ) + + +def _registered_readout_rule_reverse_executor_patterns() -> tuple[TemporalReverseExecutorPattern, ...]: + return tuple( + _readout_rule_reverse_executor_pattern(lowering_kind) + for lowering_kind in registered_readout_rule_backend_spec_lowering_kinds() + if _readout_rule_has_native_executor(lowering_kind, "reverse") + ) + + +def _transition_program_access_patterns( + spec: TransitionExecutorStrategySpec, +) -> tuple[TemporalProgramAccessPattern, ...]: + return tuple( + TemporalProgramAccessPattern( + access.access_name, + access.logical_name, + access.binding_kind, + access.required, + access.access_opcode, + ) + for access in spec.program_accesses + ) + + +def _transition_row_patterns(spec: TransitionExecutorStrategySpec) -> tuple[TemporalPrimitiveRowPattern, ...]: + return tuple(TemporalPrimitiveRowPattern(row.primitive, tuple(row.parameter_inputs)) for row in spec.row_pattern) + + +def _transition_forward_executor_pattern(spec: TransitionExecutorStrategySpec) -> TemporalForwardExecutorPattern: + return TemporalForwardExecutorPattern( + executor_id=int(spec.executor_id), + executor_name=str(spec.executor_name), + surface="transition", + bucket_ordinal=None, + row_pattern=_transition_row_patterns(spec), + implementation_contract=str(spec.implementation_contract), + strategy_id=str(spec.strategy_id), + native_callable=str(spec.native_callable), + cxx_entrypoints=tuple(str(entrypoint) for entrypoint in spec.cxx_entrypoints), + strategy_version=int(spec.strategy_version), + required_effects=tuple(str(effect) for effect in spec.required_effects), + match_effects=tuple(str(effect) for effect in spec.match_effects), + program_accesses=_transition_program_access_patterns(spec), + state_carry_rules=tuple((str(before), str(after)) for before, after in spec.state_carry_rules), + ) + + +def _transition_reverse_executor_pattern(spec: TransitionExecutorStrategySpec) -> TemporalReverseExecutorPattern: + return TemporalReverseExecutorPattern( + executor_id=int(spec.executor_id), + executor_name=str(spec.executor_name), + surface="transition", + bucket_ordinal=None, + row_pattern=_transition_row_patterns(spec), + implementation_contract=str(spec.implementation_contract), + strategy_id=str(spec.strategy_id), + native_callable=str(spec.native_callable), + cxx_entrypoints=tuple(str(entrypoint) for entrypoint in spec.cxx_entrypoints), + strategy_version=int(spec.strategy_version), + required_effects=tuple(str(effect) for effect in spec.required_effects), + match_effects=tuple(str(effect) for effect in spec.match_effects), + program_accesses=_transition_program_access_patterns(spec), + state_carry_rules=tuple((str(before), str(after)) for before, after in spec.state_carry_rules), + ) + + +def _registered_transition_forward_executor_patterns() -> tuple[TemporalForwardExecutorPattern, ...]: + return tuple(_transition_forward_executor_pattern(spec) for spec in 
registered_transition_forward_strategy_specs()) + + +def _registered_transition_reverse_executor_patterns() -> tuple[TemporalReverseExecutorPattern, ...]: + return tuple(_transition_reverse_executor_pattern(spec) for spec in registered_transition_reverse_strategy_specs()) + + +_REGISTERED_TEMPORAL_FORWARD_EXECUTOR_STRATEGIES = ( + *_registered_message_rule_forward_executor_patterns(), + *_registered_readout_rule_forward_executor_patterns(), + *_registered_transition_forward_executor_patterns(), +) + +_REGISTERED_TEMPORAL_REVERSE_EXECUTOR_STRATEGIES = ( + *_registered_message_rule_reverse_executor_patterns(), + *_registered_readout_rule_reverse_executor_patterns(), + *_registered_transition_reverse_executor_patterns(), +) + +_TEMPORAL_EXECUTOR_STRATEGY_REGISTRY = TemporalExecutorStrategyRegistry( + forward_strategies=_REGISTERED_TEMPORAL_FORWARD_EXECUTOR_STRATEGIES, + reverse_strategies=_REGISTERED_TEMPORAL_REVERSE_EXECUTOR_STRATEGIES, +) + + +def temporal_executor_strategy_registry() -> TemporalExecutorStrategyRegistry: + return _TEMPORAL_EXECUTOR_STRATEGY_REGISTRY + + +def temporal_forward_executor_patterns() -> tuple[TemporalForwardExecutorPattern, ...]: + return _TEMPORAL_EXECUTOR_STRATEGY_REGISTRY.forward_patterns() + + +def temporal_reverse_executor_patterns() -> tuple[TemporalReverseExecutorPattern, ...]: + return _TEMPORAL_EXECUTOR_STRATEGY_REGISTRY.reverse_patterns() + + +def temporal_row_signature( + rows: tuple[object, ...], +) -> tuple[tuple[str, tuple[str, ...], tuple[tuple[str, str], ...]], ...]: + return canonical_temporal_row_group(surface="*", bucket_ordinal=None, rows=rows).row_signature + + +def match_temporal_forward_executor_pattern( + *, + surface: str, + bucket_ordinal: int, + rows: tuple[object, ...], +) -> TemporalForwardExecutorPattern | None: + return _TEMPORAL_EXECUTOR_STRATEGY_REGISTRY.match_forward( + surface=surface, + bucket_ordinal=bucket_ordinal, + rows=rows, + ) + + +def match_temporal_reverse_executor_pattern( + *, + surface: str, + bucket_ordinal: int, + rows: tuple[object, ...], +) -> TemporalReverseExecutorPattern | None: + return _TEMPORAL_EXECUTOR_STRATEGY_REGISTRY.match_reverse( + surface=surface, + bucket_ordinal=bucket_ordinal, + rows=rows, + ) + + +def temporal_forward_executor_pattern_summaries() -> tuple[str, ...]: + return tuple(pattern.summary for pattern in temporal_forward_executor_patterns()) + + +def temporal_reverse_executor_pattern_summaries() -> tuple[str, ...]: + return tuple(pattern.summary for pattern in temporal_reverse_executor_patterns()) + + +__all__ = [ + "TemporalForwardExecutorPattern", + "TemporalMessageParamGradOutputPattern", + "TemporalPrimitiveRowPattern", + "TemporalProgramAccessPattern", + "TemporalReverseExecutorPattern", + "TemporalExecutorStrategyRegistry", + "TemporalRowGroupSchema", + "match_temporal_forward_executor_pattern", + "match_temporal_reverse_executor_pattern", + "surface_for_temporal_row", + "temporal_executor_strategy_registry", + "temporal_forward_executor_pattern_summaries", + "temporal_forward_executor_patterns", + "temporal_reverse_executor_pattern_summaries", + "temporal_reverse_executor_patterns", + "temporal_row_signature", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/compiler/forward_plan.py b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/forward_plan.py new file mode 100644 index 00000000..be4a8ecf --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/forward_plan.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from 
dataclasses import dataclass +from typing import Literal + +import torch + +from .executor_bindings import ( + TemporalExecutorBindingPlan, + build_temporal_forward_executor_binding_plan, +) +from .strategy_selection import ( + TemporalStrategySelectionReport, + build_temporal_strategy_selection_report, +) +from .tables import ( + TemporalPrimitiveTablePlan, + temporal_forward_executor_rows, + temporal_forward_executor_rows_tensor, +) + + +@dataclass(frozen=True) +class TemporalForwardExecutablePlan: + forward_executor_rows: torch.Tensor + executor_binding_rows: torch.Tensor + executor_summaries: tuple[str, ...] + executor_binding_summaries: tuple[str, ...] + executor_binding_blockers: tuple[str, ...] + strategy_ids: tuple[str, ...] + strategy_candidate_summaries: tuple[str, ...] + strategy_legality_status: Literal["legal", "blocked"] + strategy_legality_reasons: tuple[str, ...] + runtime_entrypoint: str = "registered_temporal_fused_forward_program_cuda" + + @property + def review_summary(self) -> tuple[str, ...]: + return ( + "forward_executable_plan=compiler_owned", + f"runtime_entrypoint={self.runtime_entrypoint}", + f"executor_row_count={int(self.forward_executor_rows.shape[0])}", + f"executor_binding_row_count={int(self.executor_binding_rows.shape[0])}", + "strategy_ids=" + ",".join(self.strategy_ids), + f"strategy_legality_status={self.strategy_legality_status}", + *self.strategy_legality_reasons, + *self.executor_binding_blockers, + *self.executor_summaries, + *self.executor_binding_summaries, + ) + + +def build_temporal_forward_executable_plan( + table: TemporalPrimitiveTablePlan, + *, + forward_binding_plan: TemporalExecutorBindingPlan | None = None, + strategy_report: TemporalStrategySelectionReport | None = None, +) -> TemporalForwardExecutablePlan: + executor_rows = temporal_forward_executor_rows(table) + forward_binding_plan = ( + build_temporal_forward_executor_binding_plan(table) if forward_binding_plan is None else forward_binding_plan + ) + strategy_report = ( + build_temporal_strategy_selection_report( + table, + forward_binding_plan=forward_binding_plan, + directions=("forward",), + ) + if strategy_report is None + else strategy_report + ) + forward_candidates = tuple( + candidate for candidate in strategy_report.candidates if candidate.direction == "forward" + ) + strategy_ids = tuple( + candidate.strategy_id for candidate in forward_candidates if candidate.match_status == "matched" + ) + strategy_legality_reasons = tuple( + reason + for candidate in forward_candidates + if candidate.match_status == "matched" + for reason in candidate.legality_reasons + ) + strategy_legality_status: Literal["legal", "blocked"] = ( + "blocked" + if forward_binding_plan.has_blockers + or not strategy_ids + or any( + candidate.legality_status == "blocked" + for candidate in forward_candidates + if candidate.match_status == "matched" + ) + else "legal" + ) + return TemporalForwardExecutablePlan( + forward_executor_rows=temporal_forward_executor_rows_tensor(table), + executor_binding_rows=forward_binding_plan.rows, + executor_summaries=tuple(row.summary for row in executor_rows), + executor_binding_summaries=forward_binding_plan.summaries, + executor_binding_blockers=forward_binding_plan.blocker_summaries, + strategy_ids=strategy_ids, + strategy_candidate_summaries=tuple(candidate.summary for candidate in forward_candidates), + strategy_legality_status=strategy_legality_status, + strategy_legality_reasons=strategy_legality_reasons, + ) + + +__all__ = [ + "TemporalForwardExecutablePlan", + 
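`build_temporal_forward_executable_plan` collapses strategy selection into a single legality verdict: blocked when the binding plan has blockers, when no forward strategy matched, or when any matched candidate is itself blocked. A standalone sketch of that rule with a placeholder strategy id:

```python
from typing import Literal


def forward_legality_status(
    *,
    has_binding_blockers: bool,
    matched_strategy_ids: tuple[str, ...],
    matched_candidate_statuses: tuple[str, ...],
) -> Literal["legal", "blocked"]:
    # Any missing or blocked ingredient makes the whole forward plan blocked.
    if has_binding_blockers or not matched_strategy_ids:
        return "blocked"
    if any(status == "blocked" for status in matched_candidate_statuses):
        return "blocked"
    return "legal"


assert forward_legality_status(
    has_binding_blockers=False,
    matched_strategy_ids=("forward.demo_transition",),
    matched_candidate_statuses=("legal",),
) == "legal"
assert forward_legality_status(
    has_binding_blockers=False,
    matched_strategy_ids=(),
    matched_candidate_statuses=(),
) == "blocked"
```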
"build_temporal_forward_executable_plan", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/compiler/forward_program.py b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/forward_program.py new file mode 100644 index 00000000..d8c103f1 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/forward_program.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +from typing import Any + +import torch + +from .executor_patterns import ( + TemporalProgramAccessPattern, + TemporalForwardExecutorPattern, + TemporalReverseExecutorPattern, + temporal_executor_strategy_registry, +) + + +def temporal_program_access_opcode(access_name: str) -> int: + opcodes = { + int(access.stable_access_opcode) + for pattern in temporal_executor_strategy_registry().all_patterns() + for access in pattern.program_accesses + if str(access.access_name) == str(access_name) + } + if len(opcodes) != 1: + raise RuntimeError( + "Registered temporal program access has no unique strategy-owned opcode: " + f"access={access_name!r}; opcodes={tuple(sorted(opcodes))!r}" + ) + return next(iter(opcodes)) + + +def temporal_forward_program_access_rows_tensor( + *, + message_handles: tuple[Any, ...], + readout_handles: tuple[Any, ...], + transition_handles: tuple[Any, ...], +) -> torch.Tensor: + rows: list[list[int]] = [] + ordered_handles = ( + *_surface_handles(message_handles, surface="message"), + *_surface_handles(readout_handles, surface="readout"), + *sorted(transition_handles, key=lambda item: (int(item.bucket_ordinal), int(item.row_index))), + ) + for handle in ordered_handles: + pattern = _forward_pattern_for_handle(handle) + handle_rows = _program_access_rows_for_handle(handle, pattern.program_accesses) + if str(handle.surface) == "transition": + _extend_transition_program_access_rows(handle_rows, handle) + rows.extend(handle_rows) + if not rows: + return torch.empty((0, 6), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def temporal_forward_transition_state_carry_rows_tensor( + *, + transition_handles: tuple[Any, ...], +) -> torch.Tensor: + rows: list[list[int]] = [] + seen_rows: set[tuple[int, int, int]] = set() + handles_by_bucket: dict[int, list[Any]] = {} + for handle in sorted(transition_handles, key=lambda item: (int(item.bucket_ordinal), int(item.row_index))): + handles_by_bucket.setdefault(int(handle.bucket_ordinal), []).append(handle) + + def append_state_carry_row( + *, + bucket_ordinal: int, + input_binding: Any, + output_binding: Any, + ) -> None: + row_key = ( + int(bucket_ordinal), + int(input_binding.binding_index), + int(output_binding.binding_index), + ) + if row_key in seen_rows: + return + seen_rows.add(row_key) + rows.append([row_key[0], row_key[1], row_key[2]]) + + for bucket_ordinal, bucket_handles in handles_by_bucket.items(): + input_by_logical = { + str(binding.logical_name): binding + for handle in bucket_handles + for binding in handle.bindings + if binding.binding_kind == "input" + } + output_by_logical = { + str(binding.logical_name): binding + for handle in bucket_handles + for binding in handle.bindings + if binding.binding_kind == "output" + } + produced_logicals = set(output_by_logical) + for input_name, input_binding in input_by_logical.items(): + if input_name == "aggregated_message" or input_name in produced_logicals: + continue + output_binding = output_by_logical.get(f"next_{input_name}") + if output_binding is None: + continue + append_state_carry_row( + bucket_ordinal=int(bucket_ordinal), + 
input_binding=input_binding, + output_binding=output_binding, + ) + for handle in bucket_handles: + pattern = _forward_pattern_for_handle(handle) + for input_name, output_name in pattern.state_carry_rules: + input_binding = input_by_logical.get(str(input_name)) + output_binding = output_by_logical.get(str(output_name)) + if output_binding is None: + if input_binding is not None: + raise RuntimeError( + "Registered forward program state carry rule has no output binding: " + f"executor={handle.executor_name}; bucket={int(handle.bucket_ordinal)}; " + f"input={input_name!r}; output={output_name!r}" + ) + continue + if input_binding is None: + raise RuntimeError( + "Registered forward program state carry rule has no input binding: " + f"executor={handle.executor_name}; bucket={int(handle.bucket_ordinal)}; " + f"input={input_name!r}; output={output_name!r}" + ) + append_state_carry_row( + bucket_ordinal=int(bucket_ordinal), + input_binding=input_binding, + output_binding=output_binding, + ) + if not rows: + return torch.empty((0, 3), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def temporal_reverse_program_access_rows_tensor( + *, + message_handles: tuple[Any, ...], + readout_handles: tuple[Any, ...], +) -> torch.Tensor: + rows: list[list[int]] = [] + for handle in ( + *_surface_handles(message_handles, surface="message"), + *_surface_handles(readout_handles, surface="readout"), + ): + pattern = _reverse_pattern_for_handle(handle) + rows.extend(_program_access_rows_for_handle(handle, pattern.program_accesses)) + if not rows: + return torch.empty((0, 6), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def _surface_handles(handles: tuple[Any, ...], *, surface: str) -> tuple[Any, ...]: + ordered = tuple(sorted(handles, key=lambda item: int(item.row_index))) + if not ordered: + raise RuntimeError(f"Registered temporal program access requires at least one {surface} executor row") + unexpected = tuple(handle for handle in ordered if str(handle.surface) != str(surface)) + if unexpected: + raise RuntimeError( + "Registered temporal program access received executor rows for the wrong surface: " + f"expected={surface!r}; actual={tuple(str(handle.surface) for handle in unexpected)!r}" + ) + return ordered + + +def _program_access_rows_for_handle( + handle: Any, + accesses: tuple[TemporalProgramAccessPattern, ...], +) -> list[list[int]]: + rows: list[list[int]] = [] + for access_slot, access in enumerate(accesses): + binding = _first_binding( + handle, + logical_name=str(access.logical_name), + binding_kind=str(access.binding_kind), + required=bool(access.required), + ) + if binding is None: + continue + rows.append( + [ + int(access_slot), + int(handle.row_index), + int(handle.bucket_ordinal), + int(binding.binding_index), + 1 if bool(access.required) else 0, + int(access.stable_access_opcode), + ] + ) + return rows + + +def _extend_transition_program_access_rows(rows: list[list[int]], handle: Any) -> None: + existing_opcodes = {int(row[5]) for row in rows} + _append_transition_program_access_row( + rows, + handle, + access_slot=0, + access_name="transition_aggregated_message_input", + logical_name="aggregated_message", + binding_kind="input", + existing_opcodes=existing_opcodes, + ) + _append_transition_program_access_row( + rows, + handle, + access_slot=1, + access_name="transition_public_state_output", + logical_name="public_y", + binding_kind="output", + existing_opcodes=existing_opcodes, + ) + + +def _append_transition_program_access_row( + rows: 
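`temporal_program_access_opcode` treats opcodes as global, strategy-owned constants: every strategy that names a given program access must agree on one opcode, and lookups fail closed otherwise. A toy version over a plain dict, with placeholder access names and opcode values:

```python
def resolve_access_opcode(accesses_by_strategy: dict[str, dict[str, int]], access_name: str) -> int:
    # Collect every opcode registered for this access and require exactly one.
    opcodes = {
        opcode
        for accesses in accesses_by_strategy.values()
        for name, opcode in accesses.items()
        if name == access_name
    }
    if len(opcodes) != 1:
        raise RuntimeError(f"no unique opcode for program access {access_name!r}: {sorted(opcodes)!r}")
    return next(iter(opcodes))


toy_registry = {
    "forward.demo_message": {"demo_static_weight": 7},
    "reverse.demo_message": {"demo_static_weight": 7},
}
assert resolve_access_opcode(toy_registry, "demo_static_weight") == 7
```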
list[list[int]], + handle: Any, + *, + access_slot: int, + access_name: str, + logical_name: str, + binding_kind: str, + existing_opcodes: set[int], +) -> None: + access_opcode = temporal_program_access_opcode(access_name) + if int(access_opcode) in existing_opcodes: + return + binding = _first_binding( + handle, + logical_name=logical_name, + binding_kind=binding_kind, + required=False, + ) + if binding is None: + return + rows.append( + [ + int(access_slot), + int(handle.row_index), + int(handle.bucket_ordinal), + int(binding.binding_index), + 1, + int(access_opcode), + ] + ) + existing_opcodes.add(int(access_opcode)) + + +def _first_binding( + handle: Any, + *, + logical_name: str, + binding_kind: str, + required: bool = True, +) -> Any | None: + matches = tuple( + binding + for binding in handle.bindings + if binding.binding_kind == str(binding_kind) and binding.logical_name == str(logical_name) + ) + if not matches: + if not required: + return None + raise RuntimeError( + "Registered temporal program access has no compiler binding: " + f"surface={handle.surface}; bucket={int(handle.bucket_ordinal)}; " + f"executor={handle.executor_name}; kind={binding_kind}; logical={logical_name!r}" + ) + return sorted( + matches, + key=lambda binding: (int(binding.primitive_row_index), int(binding.binding_index)), + )[0] + + +def _forward_pattern_for_handle(handle: Any) -> TemporalForwardExecutorPattern: + return temporal_executor_strategy_registry().forward_pattern_for_executor( + surface=str(handle.surface), + executor_name=str(handle.executor_name), + ) + + +def _reverse_pattern_for_handle(handle: Any) -> TemporalReverseExecutorPattern: + return temporal_executor_strategy_registry().reverse_pattern_for_executor( + surface=str(handle.surface), + executor_name=str(handle.executor_name), + ) + + +__all__ = [ + "temporal_program_access_opcode", + "temporal_forward_program_access_rows_tensor", + "temporal_forward_transition_state_carry_rows_tensor", + "temporal_reverse_program_access_rows_tensor", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py new file mode 100644 index 00000000..b11d0c88 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py @@ -0,0 +1,2336 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Literal + +import torch + +from .row_groups import surface_for_temporal_row, temporal_effects_for_row +from .primitive_registry import temporal_surface_opcode +from .reverse_artifacts import temporal_reverse_artifact_role_names +from .tables import TemporalPrimitiveTablePlan, TemporalTensorTableSlot + + +TemporalMemoryOwner = Literal["compiler_tensor_role_table", "compiler_primitive_row", "compiler_memory_policy"] +TemporalRuntimeArtifactMode = Literal["none", "store_step_artifacts", "recompute_step_artifacts"] +TemporalPhysicalStrategyKind = Literal["stage_materialized", "streaming_step_producer_consumer"] +TemporalPhysicalStrategyStatus = Literal["active", "candidate", "blocked"] +TemporalPhysicalStrategyResetPolicy = Literal["absent", "present", "unknown"] +TemporalPhysicalStrategyOutputBoundary = Literal["terminal", "sequence"] + +_REQUIRED_MEMORY_POLICY_EFFECTS = ( + "local_seed_policy", + "metadata_policy", + "primitive_output_policy", + "tape_policy", + "alias_policy", + "recompute_window_policy", + "materialization_policy", + "cuda_graph_constraint", +) + + +@dataclass(frozen=True) 
+class TemporalMemoryPlanEntry:
+    row_index: int | None
+    bucket_ordinal: int
+    surface: str
+    tensor_role: str
+    tensor_class: str
+    layout: str
+    lifetime: str
+    workspace_class: str
+    alias_set: str
+    recompute_policy: str
+    effect: str
+    owner: TemporalMemoryOwner
+
+    @property
+    def summary(self) -> str:
+        row = "*" if self.row_index is None else str(int(self.row_index))
+        return (
+            f"row={row}"
+            f",surface={self.surface}"
+            f",bucket={int(self.bucket_ordinal)}"
+            f",role={self.tensor_role}"
+            f",tensor={self.tensor_class}"
+            f",layout={self.layout}"
+            f",lifetime={self.lifetime}"
+            f",workspace={self.workspace_class}"
+            f",alias={self.alias_set}"
+            f",recompute={self.recompute_policy}"
+            f",effect={self.effect}"
+            f",owner={self.owner}"
+        )
+
+
+@dataclass(frozen=True)
+class TemporalMemoryLivenessPlan:
+    entries: tuple[TemporalMemoryPlanEntry, ...]
+    workspace_policy: str
+    layout_policy: str
+    alias_policy: str
+    peak_workspace_estimate_bytes: int | None = None
+
+    @property
+    def summaries(self) -> tuple[str, ...]:
+        return tuple(entry.summary for entry in self.entries)
+
+    @property
+    def fingerprint(self) -> tuple[str, ...]:
+        return (*self.review_summary, *self.summaries)
+
+    @property
+    def review_summary(self) -> tuple[str, ...]:
+        peak = "*" if self.peak_workspace_estimate_bytes is None else str(int(self.peak_workspace_estimate_bytes))
+        return (
+            "memory_liveness_plan=compiler_owned",
+            f"entry_count={len(self.entries)}",
+            f"workspace_policy={self.workspace_policy}",
+            f"layout_policy={self.layout_policy}",
+            f"alias_policy={self.alias_policy}",
+            f"peak_workspace_estimate_bytes={peak}",
+        )
+
+
+@dataclass(frozen=True)
+class TemporalMemoryRuntimePolicy:
+    effect_policies: tuple[tuple[str, str], ...]
+    memory_row_indices: tuple[int, ...]
+    source: str
+    effect_row_indices: tuple[tuple[str, int], ...] = ()
+
+    @property
+    def effect_names(self) -> tuple[str, ...]:
+        return tuple(effect for effect, _policy in self.effect_policies)
+
+    @property
+    def complete(self) -> bool:
+        effects = set(self.effect_names)
+        return all(effect in effects for effect in _REQUIRED_MEMORY_POLICY_EFFECTS)
+
+    @property
+    def alias_allocation_enabled(self) -> bool:
+        return self.policy_for("alias_policy") == "scheduler_alias_policy"
+
+    @property
+    def review_summary(self) -> tuple[str, ...]:
+        return (
+            "memory_runtime_policy=compiler_executable",
+            f"policy_complete={int(self.complete)}",
+            f"policy_effects={_tuple_summary(self.effect_names)}",
+            f"memory_rows={_int_tuple_summary(self.memory_row_indices)}",
+            "policy_rows_by_effect=" + _effect_row_tuple_summary(self.effect_row_indices),
+            f"source={self.source}",
+            *(f"policy:{effect}={policy}" for effect, policy in self.effect_policies),
+        )
+
+    def policy_for(self, effect: str) -> str:
+        for policy_effect, policy in self.effect_policies:
+            if policy_effect == str(effect):
+                return policy
+        return ""
+
+    def memory_row_for(self, effect: str) -> int:
+        for policy_effect, row_index in self.effect_row_indices:
+            if policy_effect == str(effect):
+                return int(row_index)
+        return -1
+
+    def require_complete(self) -> None:
+        missing = tuple(effect for effect in _REQUIRED_MEMORY_POLICY_EFFECTS if not self.policy_for(effect))
+        if missing:
+            raise RuntimeError(
+                "Temporal runtime buffer planning requires compiler-owned memory policy rows; "
+                f"missing={_tuple_summary(missing)}"
+            )
+
+
+@dataclass(frozen=True)
+class TemporalMemoryRuntimeSchedulePlan:
+    mode: TemporalRuntimeArtifactMode
+    physical_time_steps: int
+    checkpoint_stride: int
+    recompute_window_len: int
+    checkpoint_steps: tuple[int, ...]
+    backward_windows: tuple[tuple[int, int], ...]
+    store_step_artifacts: bool
+    checkpoint_owner: str
+    reverse_artifact_kind: str
+    output_materialization: str
+    output_physical_steps: tuple[int, ...]
+    scheduler_owner: str
+    primitive_output_policy: str
+    tape_policy: str
+    alias_policy: str
+    recompute_window_policy: str
+    materialization_policy: str
+    cuda_graph_constraint: str
+    local_seed_policy: str
+    metadata_policy: str
+    runtime_policy: TemporalMemoryRuntimePolicy
+
+    @property
+    def fingerprint(self) -> tuple[str, ...]:
+        return self.review_summary
+
+    @property
+    def review_summary(self) -> tuple[str, ...]:
+        return (
+            "memory_runtime_schedule_plan=compiler_executable",
+            f"mode={self.mode}",
+            f"physical_time_steps={int(self.physical_time_steps)}",
+            f"checkpoint_stride={int(self.checkpoint_stride)}",
+            f"recompute_window_len={int(self.recompute_window_len)}",
+            f"checkpoint_steps={_int_tuple_summary(self.checkpoint_steps)}",
+            f"backward_windows={_window_tuple_summary(self.backward_windows)}",
+            f"store_step_artifacts={int(bool(self.store_step_artifacts))}",
+            f"checkpoint_owner={self.checkpoint_owner}",
+            f"reverse_artifact_kind={self.reverse_artifact_kind}",
+            f"output_materialization={self.output_materialization}",
+            f"output_physical_steps={_int_tuple_summary(self.output_physical_steps)}",
+            f"scheduler_owner={self.scheduler_owner}",
+            f"local_seed_policy={self.local_seed_policy}",
+            f"metadata_policy={self.metadata_policy}",
+            f"primitive_output_policy={self.primitive_output_policy}",
+            f"tape_policy={self.tape_policy}",
+            f"alias_policy={self.alias_policy}",
+            f"recompute_window_policy={self.recompute_window_policy}",
+            f"materialization_policy={self.materialization_policy}",
+            f"cuda_graph_constraint={self.cuda_graph_constraint}",
+            *self.runtime_policy.review_summary,
+        )
+
+
+@dataclass(frozen=True)
+class TemporalMemoryRuntimeArtifactPlan:
+    mode: TemporalRuntimeArtifactMode
+    checkpoint_stride: int
+    recompute_window_len: int
+    checkpoint_steps: tuple[int, ...]
+    backward_windows: tuple[tuple[int, int], ...]
+    store_step_artifacts: bool
+    checkpoint_owner: str
+    reason: str
+    workspace_aliases: tuple[str, ...]
+    runtime_policy: TemporalMemoryRuntimePolicy
+    runtime_schedule_plan: TemporalMemoryRuntimeSchedulePlan
+    reverse_artifact_roles: tuple[str, ...] = ()
+
+    @property
+    def fingerprint(self) -> tuple[str, ...]:
+        return self.review_summary
+
+    @property
+    def review_summary(self) -> tuple[str, ...]:
+        return (
+            "memory_runtime_artifact_plan=compiler_executable",
+            f"mode={self.mode}",
+            f"checkpoint_stride={int(self.checkpoint_stride)}",
+            f"recompute_window_len={int(self.recompute_window_len)}",
+            f"checkpoint_steps={_int_tuple_summary(self.checkpoint_steps)}",
+            f"backward_windows={_window_tuple_summary(self.backward_windows)}",
+            f"store_step_artifacts={int(bool(self.store_step_artifacts))}",
+            f"checkpoint_owner={self.checkpoint_owner}",
+            f"workspace_aliases={_tuple_summary(self.workspace_aliases)}",
+            f"reverse_artifact_roles={_tuple_summary(self.reverse_artifact_roles)}",
+            *self.runtime_schedule_plan.review_summary,
+            *self.runtime_policy.review_summary,
+            self.reason,
+        )
+
+
+@dataclass(frozen=True)
+class TemporalPhysicalStrategyRow:
+    row_index: int
+    strategy: TemporalPhysicalStrategyKind
+    status: TemporalPhysicalStrategyStatus
+    executable: bool
+    physical_time_steps: int
+    inner_steps: int
+    output_boundary: TemporalPhysicalStrategyOutputBoundary
+    reset_policy: TemporalPhysicalStrategyResetPolicy
+    required_surface_mask: int
+    consumed_table_mask: int
+    blocker: str = ""
+    schema_version: int = 1
+
+    @property
+    def summary(self) -> str:
+        return (
+            f"physical_strategy_row={int(self.row_index)}"
+            f",strategy={self.strategy}"
+            f",status={self.status}"
+            f",executable={int(bool(self.executable))}"
+            f",physical_time_steps={int(self.physical_time_steps)}"
+            f",inner_steps={int(self.inner_steps)}"
+            f",output_boundary={self.output_boundary}"
+            f",reset_policy={self.reset_policy}"
+            f",required_surface_mask={int(self.required_surface_mask)}"
+            f",consumed_table_mask={int(self.consumed_table_mask)}"
+            f",blocker={self.blocker or '-'}"
+            f",schema_version={int(self.schema_version)}"
+        )
+
+
+@dataclass(frozen=True)
+class TemporalPhysicalStrategyPlan:
+    rows: tuple[TemporalPhysicalStrategyRow, ...]
+    selected_strategy: str
+    streaming_strategy_status: str
+    reason: str
+
+    @property
+    def review_summary(self) -> tuple[str, ...]:
+        return (
+            "physical_strategy_plan=compiler_executable",
+            f"row_count={len(self.rows)}",
+            f"selected_strategy={self.selected_strategy}",
+            f"streaming_strategy_status={self.streaming_strategy_status}",
+            f"reason={self.reason}",
+            *(row.summary for row in self.rows),
+        )
+
+    @property
+    def fingerprint(self) -> tuple[str, ...]:
+        return self.review_summary
+
+
+@dataclass(frozen=True)
+class TemporalRuntimeBufferSpec:
+    name: str
+    tensor_role: str
+    shape: tuple[int, ...]
+ dtype: str + device: str + workspace_class: str + alias_set: str + init: Literal["empty", "zeros"] + owner: str + memory_row_index: int | None = None + surface: str = "" + bucket_ordinal: int = 0 + effect: str = "" + runtime_role: str = "workspace" + logical_index: int = 0 + allocation: Literal["eager", "deferred_local"] = "eager" + + @property + def summary(self) -> str: + memory_row = "*" if self.memory_row_index is None else str(int(self.memory_row_index)) + return ( + f"buffer={self.name}" + f",memory_row={memory_row}" + f",role={self.tensor_role}" + f",shape={_int_tuple_summary(self.shape)}" + f",dtype={self.dtype}" + f",device={self.device}" + f",workspace={self.workspace_class}" + f",alias={self.alias_set}" + f",surface={self.surface or '-'}" + f",bucket={int(self.bucket_ordinal)}" + f",effect={self.effect or '-'}" + f",runtime_role={self.runtime_role}" + f",logical_index={int(self.logical_index)}" + f",init={self.init}" + f",allocation={self.allocation}" + f",owner={self.owner}" + ) + + @property + def numel(self) -> int: + elements = 1 + for dim in self.shape: + elements *= max(1, int(dim)) + return int(elements) + + @property + def estimated_bytes(self) -> int: + return int(self.numel) * _runtime_dtype_element_size(self.dtype) + + +@dataclass(frozen=True) +class TemporalTransitionForwardRuntimeBufferRequest: + primitive_row_index: int + bucket_ordinal: int + logical_name: str + shape: tuple[int, ...] + runtime_role: str = "transition_forward_linear_output" + logical_index: int | None = None + alias_runtime_role: str = "" + + +@dataclass(frozen=True) +class TemporalTransitionReverseRuntimeBufferRequest: + bucket_ordinal: int + logical_name: str + shape: tuple[int, ...] + runtime_role: str + effect: str + logical_index: int | None = None + + +@dataclass(frozen=True) +class TemporalRuntimeBufferPlan: + specs: tuple[TemporalRuntimeBufferSpec, ...] + runtime_policy: TemporalMemoryRuntimePolicy | None = None + runtime_schedule_fingerprint: tuple[str, ...] 
= () + runtime_schedule_rows: torch.Tensor | None = None + + @property + def review_summary(self) -> tuple[str, ...]: + return ( + "memory_runtime_buffer_plan=compiler_executable", + f"buffer_count={len(self.specs)}", + f"planned_buffer_bytes={int(self.planned_buffer_bytes)}", + f"estimated_allocated_buffer_bytes={int(self.estimated_allocated_buffer_bytes)}", + "bytes_by_workspace=" + _bytes_by_summary(self.bytes_by_workspace_class), + "bytes_by_runtime_role=" + _bytes_by_summary(self.bytes_by_runtime_role), + f"runtime_schedule_attached={int(bool(self.runtime_schedule_fingerprint))}", + *( + (f"runtime_schedule_rows={_schedule_rows_summary(self.runtime_schedule_rows)}",) + if self.runtime_schedule_rows is not None + else () + ), + *((self.runtime_policy.review_summary) if self.runtime_policy is not None else ()), + *(spec.summary for spec in self.specs), + ) + + @property + def fingerprint(self) -> tuple[str, ...]: + return self.review_summary + + @property + def planned_buffer_bytes(self) -> int: + return sum(int(spec.estimated_bytes) for spec in self.specs) + + @property + def estimated_allocated_buffer_bytes(self) -> int: + seen_aliases: set[tuple[str, tuple[int, ...], str, str, str]] = set() + allocated = 0 + for spec in self.specs: + if spec.allocation == "deferred_local": + continue + alias_key = _runtime_buffer_alias_allocation_key(self.runtime_policy, spec) + if alias_key is not None: + if alias_key in seen_aliases: + continue + seen_aliases.add(alias_key) + allocated += int(spec.estimated_bytes) + return int(allocated) + + @property + def bytes_by_workspace_class(self) -> tuple[tuple[str, int], ...]: + totals: dict[str, int] = {} + for spec in self.specs: + totals[spec.workspace_class] = totals.get(spec.workspace_class, 0) + int(spec.estimated_bytes) + return tuple(sorted(totals.items())) + + @property + def bytes_by_runtime_role(self) -> tuple[tuple[str, int], ...]: + totals: dict[str, int] = {} + for spec in self.specs: + totals[spec.runtime_role] = totals.get(spec.runtime_role, 0) + int(spec.estimated_bytes) + return tuple(sorted(totals.items())) + + +_MEMORY_TENSOR_CLASS_OPCODE = { + "backward_input_grad": 1, + "materialized_output_boundary": 2, + "message_activation": 3, + "output_activation": 4, + "parameter_binding": 5, + "parameter_grad_accumulator": 6, + "private_state": 7, + "public_state": 8, + "state_activation": 9, + "state_carry": 10, + "transition_tape": 11, + "unknown_tensor_class": 12, + "runtime_policy": 13, + "metadata_descriptor": 14, + "primitive_output": 15, +} +_MEMORY_LIFETIME_OPCODE = { + "backward_window": 1, + "compiled_program": 2, + "output_window": 3, + "physical_step": 4, + "primitive_scope": 5, + "temporal_carry": 6, +} +_MEMORY_WORKSPACE_CLASS_OPCODE = { + "message_workspace": 1, + "output_workspace": 2, + "parameter_table": 3, + "primitive_workspace": 4, + "reduction_workspace": 5, + "state_carry_workspace": 6, + "tensor_table_workspace": 7, + "transition_workspace": 8, + "policy_table": 9, +} +_MEMORY_EFFECT_OPCODE = { + "grad_read": 1, + "materialization_boundary": 2, + "message_emit": 3, + "message_read": 4, + "output_emit": 5, + "parameter_grad_emit": 6, + "parameter_read": 7, + "state_emit": 8, + "state_read": 9, + "state_read_write": 10, + "state_write": 11, + "tape_policy": 12, + "tensor_role": 13, + "local_seed_policy": 14, + "metadata_policy": 15, + "primitive_output_policy": 16, + "alias_policy": 17, + "recompute_window_policy": 18, + "materialization_policy": 19, + "cuda_graph_constraint": 20, +} +_MEMORY_RECOMPUTE_POLICY_OPCODE = { + 
"accumulate_window_then_bind": 1, + "carry_forward_or_checkpoint": 2, + "checkpoint_boundary_policy": 3, + "materialize_when_requested": 4, + "not_recomputed": 5, + "planner_policy": 6, + "recompute_or_store_by_scheduler": 7, + "scheduler_tape_policy": 8, + "policy_not_recomputed": 9, + "scheduler_alias_policy": 10, + "scheduler_recompute_window_policy": 11, + "scheduler_materialization_policy": 12, + "cuda_graph_guard_policy": 13, +} +_MEMORY_OWNER_OPCODE = { + "compiler_primitive_row": 1, + "compiler_tensor_role_table": 2, + "compiler_memory_policy": 3, +} +_RUNTIME_BUFFER_ROLE_OPCODE = { + "workspace": 0, + "output_seq": 1, + "grad_boundary_seq": 2, + "forward_cells_prev_artifact": 3, + "forward_recurrent_hidden_after": 4, + "reverse_grad_carry_cells": 5, + "reverse_grad_cells_work": 6, + "transition_forward_linear_output": 7, + "transition_forward_matmul_output": 8, + "transition_forward_state_output": 9, + "transition_forward_norm_output": 10, + "transition_forward_diag_output": 11, + "transition_forward_unary_output": 18, + "transition_reverse_recurrent_msg_span": 19, + "transition_reverse_state_before_zero": 20, + "forward_recurrent_msg": 12, + "forward_output_msg": 13, + "forward_output_cells": 14, + "reverse_grad_recurrent_msg": 15, + "forward_message_step_flat": 16, + "reverse_message_step_flat": 17, +} +_RUNTIME_PUBLIC_STATE_ALIAS_PREFIX = "runtime_alias.transition_public_state." +_RUNTIME_SCHEDULE_ROLE_OPCODE = { + "local_seed_policy": 1, + "metadata_policy": 2, + "primitive_output_policy": 3, + "tape_policy": 4, + "alias_policy": 5, + "recompute_window_policy": 6, + "materialization_policy": 7, + "cuda_graph_constraint": 8, + "checkpoint_stride": 20, + "recompute_window_len": 21, + "checkpoint_step": 22, + "backward_window": 23, + "output_physical_step": 24, + "store_step_artifacts": 25, + "physical_time_steps": 26, +} +_PHYSICAL_STRATEGY_OPCODE = { + "stage_materialized": 1, + "streaming_step_producer_consumer": 2, +} +_PHYSICAL_STRATEGY_STATUS_OPCODE = { + "active": 1, + "candidate": 2, + "blocked": 3, +} +_PHYSICAL_STRATEGY_OUTPUT_BOUNDARY_OPCODE = { + "terminal": 1, + "sequence": 2, +} +_PHYSICAL_STRATEGY_RESET_POLICY_OPCODE = { + "absent": 1, + "present": 2, + "unknown": 3, +} +_PHYSICAL_STRATEGY_BLOCKER_OPCODE = { + "": 0, + "pending_registered_streaming_step_program_body": 1, +} +_PHYSICAL_STRATEGY_SURFACE_MASK = { + "message": 1, + "transition": 2, + "readout": 4, + "artifacts": 8, + "reducers": 16, +} +_PHYSICAL_STRATEGY_TABLE_MASK = { + "primitive_rows": 1, + "executor_rows": 2, + "binding_rows": 4, + "memory_liveness_rows": 8, + "artifact_route_rows": 16, + "output_route_rows": 32, + "runtime_schedule_rows": 64, +} + +_TORCH_DTYPE_BY_SPEC = { + "torch.float32": torch.float32, + "torch.float": torch.float32, + "float32": torch.float32, + "torch.int64": torch.int64, + "torch.long": torch.int64, + "int64": torch.int64, +} +_REVERSE_PROGRAM_ARTIFACT_ROLES = temporal_reverse_artifact_role_names() + + +def temporal_runtime_buffer_role_opcode(runtime_role: str) -> int: + try: + return _RUNTIME_BUFFER_ROLE_OPCODE[str(runtime_role)] + except KeyError as error: + raise RuntimeError(f"Unknown temporal runtime buffer role {runtime_role!r}") from error + + +def temporal_memory_runtime_schedule_rows_tensor( + schedule_plan: TemporalMemoryRuntimeSchedulePlan, +) -> torch.Tensor: + schedule_plan.runtime_policy.require_complete() + rows: list[list[int]] = [] + + def append(role: str, memory_row_index: int, value0: int, value1: int = 0) -> None: + rows.append( + [ + len(rows), + 
_memory_opcode(_RUNTIME_SCHEDULE_ROLE_OPCODE, role, "runtime_schedule_role"), + int(memory_row_index), + int(value0), + int(value1), + 1, + ] + ) + + for effect in _REQUIRED_MEMORY_POLICY_EFFECTS: + policy = schedule_plan.runtime_policy.policy_for(effect) + append( + effect, + schedule_plan.runtime_policy.memory_row_for(effect), + _memory_opcode(_MEMORY_RECOMPUTE_POLICY_OPCODE, policy, "runtime_schedule_policy"), + ) + append("physical_time_steps", -1, int(schedule_plan.physical_time_steps)) + append("checkpoint_stride", -1, int(schedule_plan.checkpoint_stride)) + append("recompute_window_len", -1, int(schedule_plan.recompute_window_len)) + append("store_step_artifacts", -1, 1 if schedule_plan.store_step_artifacts else 0) + for step in schedule_plan.checkpoint_steps: + append("checkpoint_step", -1, int(step)) + for start, end in schedule_plan.backward_windows: + append("backward_window", -1, int(start), int(end)) + for step in schedule_plan.output_physical_steps: + append("output_physical_step", -1, int(step)) + return torch.tensor(rows, dtype=torch.long) + + +def build_temporal_physical_strategy_plan( + schedule_plan: TemporalMemoryRuntimeSchedulePlan, + *, + inner_steps: int, + output_boundary: TemporalPhysicalStrategyOutputBoundary, + reset_policy: TemporalPhysicalStrategyResetPolicy, + streaming_step_body_available: bool = False, +) -> TemporalPhysicalStrategyPlan: + physical_time_steps = max(1, int(schedule_plan.physical_time_steps)) + inner_steps = max(1, int(inner_steps)) + required_surface_mask = _physical_strategy_mask( + _PHYSICAL_STRATEGY_SURFACE_MASK, + ("message", "transition", "readout", "artifacts", "reducers"), + "physical_strategy_surface", + ) + consumed_table_mask = _physical_strategy_mask( + _PHYSICAL_STRATEGY_TABLE_MASK, + ( + "primitive_rows", + "executor_rows", + "binding_rows", + "memory_liveness_rows", + "artifact_route_rows", + "output_route_rows", + "runtime_schedule_rows", + ), + "physical_strategy_table", + ) + streaming_selected = bool(streaming_step_body_available) + streaming_blocker = "" if streaming_selected else "pending_registered_streaming_step_program_body" + streaming_status: TemporalPhysicalStrategyStatus = "active" if streaming_selected else "blocked" + stage_status: TemporalPhysicalStrategyStatus = "candidate" if streaming_selected else "active" + rows = ( + TemporalPhysicalStrategyRow( + row_index=0, + strategy="stage_materialized", + status=stage_status, + executable=not streaming_selected, + physical_time_steps=physical_time_steps, + inner_steps=inner_steps, + output_boundary=output_boundary, + reset_policy=reset_policy, + required_surface_mask=required_surface_mask, + consumed_table_mask=consumed_table_mask, + ), + TemporalPhysicalStrategyRow( + row_index=1, + strategy="streaming_step_producer_consumer", + status=streaming_status, + executable=streaming_selected, + physical_time_steps=physical_time_steps, + inner_steps=inner_steps, + output_boundary=output_boundary, + reset_policy=reset_policy, + required_surface_mask=required_surface_mask, + consumed_table_mask=consumed_table_mask, + blocker=streaming_blocker, + ), + ) + selected_strategy = "streaming_step_producer_consumer" if streaming_selected else "stage_materialized" + return TemporalPhysicalStrategyPlan( + rows=rows, + selected_strategy=selected_strategy, + streaming_strategy_status=streaming_status, + reason=( + f"active_strategy={selected_strategy};" + + ( + "streaming_step_strategy=registered_program_body;" + if streaming_selected + else 
"streaming_step_strategy=compiler_product_pending_registered_program_body;" + ) + + "semantics=primitive_rows_stable" + ), + ) + + +def temporal_physical_strategy_rows_tensor(strategy_plan: TemporalPhysicalStrategyPlan) -> torch.Tensor: + rows = [ + [ + int(row.row_index), + int(row.schema_version), + _memory_opcode(_PHYSICAL_STRATEGY_OPCODE, row.strategy, "physical_strategy"), + _memory_opcode(_PHYSICAL_STRATEGY_STATUS_OPCODE, row.status, "physical_strategy_status"), + 1 if row.executable else 0, + int(row.physical_time_steps), + int(row.inner_steps), + _memory_opcode( + _PHYSICAL_STRATEGY_OUTPUT_BOUNDARY_OPCODE, + row.output_boundary, + "physical_strategy_output_boundary", + ), + _memory_opcode( + _PHYSICAL_STRATEGY_RESET_POLICY_OPCODE, + row.reset_policy, + "physical_strategy_reset_policy", + ), + int(row.required_surface_mask), + int(row.consumed_table_mask), + _memory_opcode(_PHYSICAL_STRATEGY_BLOCKER_OPCODE, row.blocker, "physical_strategy_blocker"), + ] + for row in strategy_plan.rows + ] + if not rows: + return torch.empty((0, 12), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def temporal_memory_liveness_rows_tensor(memory_plan: TemporalMemoryLivenessPlan) -> torch.Tensor: + rows: list[list[int]] = [] + for entry_index, entry in enumerate(memory_plan.entries): + rows.append( + [ + int(entry_index), + -1 if entry.row_index is None else int(entry.row_index), + int(entry.bucket_ordinal), + int(temporal_surface_opcode(entry.surface)), + _memory_opcode(_MEMORY_TENSOR_CLASS_OPCODE, entry.tensor_class, "tensor_class"), + _memory_opcode(_MEMORY_LIFETIME_OPCODE, entry.lifetime, "lifetime"), + _memory_opcode(_MEMORY_WORKSPACE_CLASS_OPCODE, entry.workspace_class, "workspace_class"), + _memory_opcode(_MEMORY_EFFECT_OPCODE, entry.effect, "effect"), + _memory_opcode(_MEMORY_RECOMPUTE_POLICY_OPCODE, entry.recompute_policy, "recompute_policy"), + _memory_opcode(_MEMORY_OWNER_OPCODE, entry.owner, "owner"), + ] + ) + if not rows: + return torch.empty((0, 10), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def temporal_memory_runtime_policy(memory_plan: TemporalMemoryLivenessPlan) -> TemporalMemoryRuntimePolicy: + policy_rows = tuple( + (entry_index, entry) for entry_index, entry in enumerate(memory_plan.entries) if _memory_entry_is_policy(entry) + ) + if not policy_rows: + return TemporalMemoryRuntimePolicy( + effect_policies=(), + memory_row_indices=(), + source="missing_compiler_memory_policy_rows", + ) + effect_to_policy: dict[str, str] = {} + row_indices: list[int] = [] + for entry_index, entry in policy_rows: + if entry.row_index is not None: + raise RuntimeError( + "Temporal memory policy rows must be program-level rows, not primitive-owned rows: " + f"memory_row={entry_index}; primitive_row={entry.row_index}" + ) + if entry.owner != "compiler_memory_policy" or entry.workspace_class != "policy_table": + raise RuntimeError( + "Temporal memory policy rows must be owned by the compiler policy table: " + f"memory_row={entry_index}; owner={entry.owner}; workspace={entry.workspace_class}" + ) + if entry.effect in effect_to_policy: + raise RuntimeError(f"Temporal memory policy row effect is duplicated: effect={entry.effect}") + effect_to_policy[entry.effect] = entry.recompute_policy + row_indices.append(int(entry_index)) + ordered_effects = tuple( + (effect, effect_to_policy[effect]) + for effect in (*_REQUIRED_MEMORY_POLICY_EFFECTS, *tuple(sorted(effect_to_policy))) + if effect in effect_to_policy + ) + return TemporalMemoryRuntimePolicy( + 
effect_policies=tuple(dict(ordered_effects).items()), + memory_row_indices=tuple(row_indices), + effect_row_indices=tuple( + (effect, int(row_index)) + for row_index, effect in sorted((row_index, entry.effect) for row_index, entry in policy_rows) + ), + source="compiler_memory_liveness_plan", + ) + + +def build_temporal_memory_liveness_plan(table: TemporalPrimitiveTablePlan) -> TemporalMemoryLivenessPlan: + entries: list[TemporalMemoryPlanEntry] = [] + entries.extend(_slot_entry(slot) for slot in table.tensor_slots) + for row_index, row in enumerate(table.primitive_rows): + surface = surface_for_temporal_row(row) + entries.extend( + _primitive_effect_entry( + row_index=row_index, + surface=surface, + bucket_ordinal=int(row.bucket_ordinal), + effect=effect, + ) + for effect in _temporal_memory_effects_for_row(row, surface=surface) + ) + entries.extend(_transition_reverse_dynamic_buffer_entries(table)) + entries.extend(_memory_policy_entries(table)) + return TemporalMemoryLivenessPlan( + entries=tuple(entries), + workspace_policy="planner_assigns_workspace_lifetimes_and_alias_sets", + layout_policy="compiler_declares_layouts_before_strategy_selection", + alias_policy="executors_request_workspace_classes_planner_assigns_alias_sets", + peak_workspace_estimate_bytes=None, + ) + + +def _temporal_memory_effects_for_row(row: object, *, surface: str) -> tuple[str, ...]: + effects = list(temporal_effects_for_row(row, surface=surface)) + if surface in {"message", "readout"}: + effects.extend(("grad_read", "parameter_grad_emit")) + elif surface == "transition": + effects.extend(("grad_read", "parameter_grad_emit")) + return tuple(dict.fromkeys(effects)) + + +def build_temporal_memory_runtime_artifact_plan( + memory_plan: TemporalMemoryLivenessPlan, + *, + physical_time_steps: int, + collect_artifacts: bool, + scheduler_plan: Any | None, +) -> TemporalMemoryRuntimeArtifactPlan: + physical_time_steps = max(1, int(physical_time_steps)) + runtime_schedule_plan = build_temporal_memory_runtime_schedule_plan( + memory_plan, + physical_time_steps=physical_time_steps, + collect_artifacts=collect_artifacts, + scheduler_plan=scheduler_plan, + ) + runtime_policy = runtime_schedule_plan.runtime_policy + workspace_aliases = tuple( + entry.alias_set + for entry in memory_plan.entries + if entry.lifetime in {"temporal_carry", "backward_window", "output_window"} + ) + if not collect_artifacts: + return TemporalMemoryRuntimeArtifactPlan( + mode="none", + checkpoint_stride=runtime_schedule_plan.checkpoint_stride, + recompute_window_len=runtime_schedule_plan.recompute_window_len, + checkpoint_steps=(), + backward_windows=(), + store_step_artifacts=False, + checkpoint_owner=runtime_schedule_plan.checkpoint_owner, + reason=( + "artifact_policy=none;" + f"reverse_artifacts={runtime_schedule_plan.reverse_artifact_kind};" + "source=compiler_memory_liveness_plan;" + f"{_runtime_policy_reason(runtime_policy)}" + ), + workspace_aliases=workspace_aliases, + runtime_policy=runtime_policy, + runtime_schedule_plan=runtime_schedule_plan, + reverse_artifact_roles=(), + ) + if runtime_schedule_plan.mode == "store_step_artifacts": + return TemporalMemoryRuntimeArtifactPlan( + mode="store_step_artifacts", + checkpoint_stride=runtime_schedule_plan.checkpoint_stride, + recompute_window_len=runtime_schedule_plan.recompute_window_len, + checkpoint_steps=runtime_schedule_plan.checkpoint_steps, + backward_windows=runtime_schedule_plan.backward_windows, + store_step_artifacts=True, + checkpoint_owner=runtime_schedule_plan.checkpoint_owner, + 
reason=( + "artifact_policy=store_step_artifacts;" + f"reverse_artifacts={runtime_schedule_plan.reverse_artifact_kind};" + "source=compiler_memory_liveness_plan;" + f"{_runtime_policy_reason(runtime_policy)}" + ), + workspace_aliases=workspace_aliases, + runtime_policy=runtime_policy, + runtime_schedule_plan=runtime_schedule_plan, + reverse_artifact_roles=_REVERSE_PROGRAM_ARTIFACT_ROLES, + ) + return TemporalMemoryRuntimeArtifactPlan( + mode="recompute_step_artifacts", + checkpoint_stride=runtime_schedule_plan.checkpoint_stride, + recompute_window_len=runtime_schedule_plan.recompute_window_len, + checkpoint_steps=runtime_schedule_plan.checkpoint_steps, + backward_windows=runtime_schedule_plan.backward_windows, + store_step_artifacts=False, + checkpoint_owner=runtime_schedule_plan.checkpoint_owner, + reason=( + "artifact_policy=checkpoint_recompute;" + f"reverse_artifacts={runtime_schedule_plan.reverse_artifact_kind};" + "source=compiler_memory_liveness_plan;" + f"{_runtime_policy_reason(runtime_policy)}" + ), + workspace_aliases=workspace_aliases, + runtime_policy=runtime_policy, + runtime_schedule_plan=runtime_schedule_plan, + reverse_artifact_roles=_REVERSE_PROGRAM_ARTIFACT_ROLES, + ) + + +def build_temporal_memory_runtime_schedule_plan( + memory_plan: TemporalMemoryLivenessPlan, + *, + physical_time_steps: int, + collect_artifacts: bool, + scheduler_plan: Any | None, +) -> TemporalMemoryRuntimeSchedulePlan: + physical_time_steps = max(1, int(physical_time_steps)) + runtime_policy = temporal_memory_runtime_policy(memory_plan) + runtime_policy.require_complete() + reverse_artifact_kind = _scheduler_materialization_value( + scheduler_plan, + "reverse_artifact_kind", + "store_step_artifacts" if collect_artifacts else "none", + ) + checkpoint_steps = _scheduler_int_value( + scheduler_plan, + materialization_name="checkpoint_steps", + checkpoint_name="checkpoint_steps", + ) + recompute_window_steps = _scheduler_int_value( + scheduler_plan, + materialization_name="recompute_window_steps", + checkpoint_name="backward_window_steps", + ) + checkpoint_stride = _bounded_window_value(checkpoint_steps, physical_time_steps) + recompute_window_len = _bounded_window_value(recompute_window_steps, physical_time_steps) + if not collect_artifacts: + mode: TemporalRuntimeArtifactMode = "none" + checkpoint_stride = physical_time_steps + recompute_window_len = physical_time_steps + planned_checkpoint_steps: tuple[int, ...] = () + planned_backward_windows: tuple[tuple[int, int], ...] 
= () + checkpoint_owner = "compiler_memory_schedule_inference_no_backward_artifacts" + elif reverse_artifact_kind in {"forward_reverse_tables", "store_step_artifacts", "stored_step_artifacts"}: + mode = "store_step_artifacts" + checkpoint_stride = physical_time_steps + planned_checkpoint_steps = (0,) + planned_backward_windows = _planned_backward_windows( + mode=mode, + physical_time_steps=physical_time_steps, + checkpoint_stride=checkpoint_stride, + recompute_window_len=recompute_window_len, + ) + checkpoint_owner = "compiler_memory_schedule_store_artifacts" + else: + mode = "recompute_step_artifacts" + planned_checkpoint_steps = _planned_checkpoint_steps( + checkpoint_stride=checkpoint_stride, + physical_time_steps=physical_time_steps, + ) + planned_backward_windows = _planned_backward_windows( + mode=mode, + physical_time_steps=physical_time_steps, + checkpoint_stride=checkpoint_stride, + recompute_window_len=recompute_window_len, + ) + checkpoint_owner = "compiler_memory_schedule_checkpoint_recompute" + return TemporalMemoryRuntimeSchedulePlan( + mode=mode, + physical_time_steps=physical_time_steps, + checkpoint_stride=checkpoint_stride, + recompute_window_len=recompute_window_len, + checkpoint_steps=planned_checkpoint_steps, + backward_windows=planned_backward_windows, + store_step_artifacts=mode == "store_step_artifacts", + checkpoint_owner=checkpoint_owner, + reverse_artifact_kind=reverse_artifact_kind, + output_materialization=_scheduler_materialization_value( + scheduler_plan, + "output_materialization", + "outputs_only", + ), + output_physical_steps=_scheduler_output_physical_steps(scheduler_plan), + scheduler_owner=str(getattr(scheduler_plan, "owner", "missing_scheduler_plan")), + primitive_output_policy=runtime_policy.policy_for("primitive_output_policy"), + tape_policy=runtime_policy.policy_for("tape_policy"), + alias_policy=runtime_policy.policy_for("alias_policy"), + recompute_window_policy=runtime_policy.policy_for("recompute_window_policy"), + materialization_policy=runtime_policy.policy_for("materialization_policy"), + cuda_graph_constraint=runtime_policy.policy_for("cuda_graph_constraint"), + local_seed_policy=runtime_policy.policy_for("local_seed_policy"), + metadata_policy=runtime_policy.policy_for("metadata_policy"), + runtime_policy=runtime_policy, + ) + + +def build_temporal_runtime_buffer_plan( + memory_plan: TemporalMemoryLivenessPlan, + *, + output_seq_shape: tuple[int, ...] | None = None, + grad_boundary_seq_shape: tuple[int, ...] | None = None, + forward_message_step_flat_shape: tuple[int, ...] | None = None, + reverse_message_step_flat_shape: tuple[int, ...] | None = None, + physical_time_steps: int | None = None, + cells_prev_shape: tuple[int, ...] | None = None, + recurrent_hidden_shape: tuple[int, ...] | None = None, + grad_carry_cells_shape: tuple[int, ...] | None = None, + reverse_grad_cells_work_shape: tuple[int, ...] | None = None, + forward_recurrent_msg_shape: tuple[int, ...] | None = None, + forward_output_msg_shape: tuple[int, ...] | None = None, + forward_output_cells_shape: tuple[int, ...] | None = None, + reverse_grad_recurrent_msg_shape: tuple[int, ...] | None = None, + transition_forward_outputs: tuple[TemporalTransitionForwardRuntimeBufferRequest, ...] = (), + transition_reverse_dynamic_buffers: tuple[TemporalTransitionReverseRuntimeBufferRequest, ...] 
= (), + runtime_schedule_plan: TemporalMemoryRuntimeSchedulePlan | None = None, + dtype: str, + device: str, + include_workspace_rows: bool = False, + enable_public_state_runtime_alias: bool = False, + defer_forward_step_buffers: bool = False, + defer_local_transition_outputs: bool = False, +) -> TemporalRuntimeBufferPlan: + runtime_policy = temporal_memory_runtime_policy(memory_plan) + runtime_policy.require_complete() + runtime_schedule_rows = None + runtime_schedule_fingerprint: tuple[str, ...] = () + if runtime_schedule_plan is not None: + _require_runtime_schedule_matches_policy(runtime_policy, runtime_schedule_plan) + runtime_schedule_rows = temporal_memory_runtime_schedule_rows_tensor(runtime_schedule_plan) + runtime_schedule_fingerprint = runtime_schedule_plan.fingerprint + specs: list[TemporalRuntimeBufferSpec] = [] + recurrent_hidden_shape_tuple = ( + None if recurrent_hidden_shape is None else tuple(int(item) for item in recurrent_hidden_shape) + ) + recurrent_hidden_step_count = ( + 0 + if recurrent_hidden_shape_tuple is None + else _runtime_physical_step_count( + physical_time_steps, + role="forward_recurrent_hidden_after", + ) + ) + public_state_runtime_alias_set = ( + _runtime_public_state_alias_set(recurrent_hidden_shape_tuple) + if bool(enable_public_state_runtime_alias) + and any( + _transition_forward_request_aliases_recurrent_hidden( + request, + recurrent_hidden_shape=recurrent_hidden_shape_tuple, + recurrent_hidden_step_count=recurrent_hidden_step_count, + ) + for request in transition_forward_outputs + ) + else None + ) + if output_seq_shape is not None: + memory_row_index, memory_entry = _runtime_memory_entry_for( + memory_plan, + workspace_class="output_workspace", + preferred_effect="output_emit", + ) + specs.append( + TemporalRuntimeBufferSpec( + name="output_seq", + tensor_role="output_activation", + shape=tuple(int(item) for item in output_seq_shape), + dtype=str(dtype), + device=str(device), + workspace_class="output_workspace", + alias_set=memory_entry.alias_set, + init="empty", + owner="compiler_memory_liveness_plan", + memory_row_index=int(memory_row_index), + surface=memory_entry.surface, + bucket_ordinal=int(memory_entry.bucket_ordinal), + effect=memory_entry.effect, + runtime_role="output_seq", + ) + ) + if grad_boundary_seq_shape is not None: + memory_row_index, memory_entry = _runtime_memory_entry_for( + memory_plan, + workspace_class="reduction_workspace", + preferred_effect="grad_read", + ) + specs.append( + TemporalRuntimeBufferSpec( + name="grad_boundary_seq", + tensor_role="grad_boundary_accumulator", + shape=tuple(int(item) for item in grad_boundary_seq_shape), + dtype=str(dtype), + device=str(device), + workspace_class="reduction_workspace", + alias_set=memory_entry.alias_set, + init="zeros", + owner="compiler_memory_liveness_plan", + memory_row_index=int(memory_row_index), + surface=memory_entry.surface, + bucket_ordinal=int(memory_entry.bucket_ordinal), + effect=memory_entry.effect, + runtime_role="grad_boundary_seq", + ) + ) + for role, shape, surface, bucket_ordinal in ( + ( + "forward_message_step_flat", + forward_message_step_flat_shape, + "message", + -1, + ), + ( + "reverse_message_step_flat", + reverse_message_step_flat_shape, + "message", + -1, + ), + ): + if shape is None: + continue + memory_row_index, memory_entry = _runtime_memory_entry_for( + memory_plan, + workspace_class="message_workspace", + preferred_effect="message_read", + ) + specs.append( + TemporalRuntimeBufferSpec( + name=role, + tensor_role=role, + 
shape=tuple(int(item) for item in shape), + dtype="torch.int64", + device=str(device), + workspace_class="message_workspace", + alias_set=f"{surface}.{role}", + init="empty", + owner="compiler_memory_liveness_plan", + memory_row_index=int(memory_row_index), + surface=memory_entry.surface or surface, + bucket_ordinal=( + int(memory_entry.bucket_ordinal) if int(memory_entry.bucket_ordinal) >= 0 else int(bucket_ordinal) + ), + effect=memory_entry.effect, + runtime_role=role, + ) + ) + if cells_prev_shape is not None: + step_count = _runtime_physical_step_count( + physical_time_steps, + role="forward_cells_prev_artifact", + ) + memory_row_index, memory_entry = _runtime_memory_entry_for( + memory_plan, + workspace_class="state_carry_workspace", + preferred_effect="state_read", + ) + for step_index in range(step_count): + specs.append( + TemporalRuntimeBufferSpec( + name=f"forward_cells_prev_artifact_step_{int(step_index)}", + tensor_role="forward_cells_prev_artifact", + shape=tuple(int(item) for item in cells_prev_shape), + dtype=str(dtype), + device=str(device), + workspace_class="state_carry_workspace", + alias_set=memory_entry.alias_set, + init="zeros", + owner="compiler_memory_liveness_plan", + memory_row_index=int(memory_row_index), + surface=memory_entry.surface, + bucket_ordinal=int(memory_entry.bucket_ordinal), + effect=memory_entry.effect, + runtime_role="forward_cells_prev_artifact", + logical_index=int(step_index), + ) + ) + if recurrent_hidden_shape_tuple is not None: + step_count = int(recurrent_hidden_step_count) + memory_row_index, memory_entry = _runtime_memory_entry_for( + memory_plan, + workspace_class="state_carry_workspace", + preferred_effect="state_write", + ) + for step_index in range(step_count): + specs.append( + TemporalRuntimeBufferSpec( + name=f"forward_recurrent_hidden_after_step_{int(step_index)}", + tensor_role="forward_recurrent_hidden_after", + shape=recurrent_hidden_shape_tuple, + dtype=str(dtype), + device=str(device), + workspace_class="state_carry_workspace", + alias_set=public_state_runtime_alias_set or memory_entry.alias_set, + init="empty", + owner="compiler_memory_liveness_plan", + memory_row_index=int(memory_row_index), + surface=memory_entry.surface, + bucket_ordinal=int(memory_entry.bucket_ordinal), + effect=memory_entry.effect, + runtime_role="forward_recurrent_hidden_after", + logical_index=int(step_index), + allocation=( + "deferred_local" + if bool(defer_forward_step_buffers) + and runtime_schedule_plan is not None + and runtime_schedule_plan.mode == "none" + and not _runtime_public_state_alias_enabled(public_state_runtime_alias_set or "") + else "eager" + ), + ) + ) + reverse_grad_cells_work_shape = ( + grad_carry_cells_shape if reverse_grad_cells_work_shape is None else reverse_grad_cells_work_shape + ) + if reverse_grad_cells_work_shape is not None: + memory_row_index, memory_entry = _runtime_memory_entry_for( + memory_plan, + workspace_class="reduction_workspace", + preferred_effect="grad_read", + ) + specs.append( + TemporalRuntimeBufferSpec( + name="reverse_grad_cells_work", + tensor_role="reverse_grad_cells_work", + shape=tuple(int(item) for item in reverse_grad_cells_work_shape), + dtype=str(dtype), + device=str(device), + workspace_class="reduction_workspace", + alias_set=memory_entry.alias_set, + init="zeros", + owner="compiler_memory_liveness_plan", + memory_row_index=int(memory_row_index), + surface=memory_entry.surface, + bucket_ordinal=int(memory_entry.bucket_ordinal), + effect=memory_entry.effect, + 
runtime_role="reverse_grad_cells_work", + logical_index=0, + ) + ) + if grad_carry_cells_shape is not None: + memory_row_index, memory_entry = _runtime_memory_entry_for( + memory_plan, + workspace_class="reduction_workspace", + preferred_effect="grad_read", + ) + specs.append( + TemporalRuntimeBufferSpec( + name="reverse_grad_carry_cells", + tensor_role="reverse_grad_carry_cells", + shape=tuple(int(item) for item in grad_carry_cells_shape), + dtype=str(dtype), + device=str(device), + workspace_class="reduction_workspace", + alias_set=memory_entry.alias_set, + init="zeros", + owner="compiler_memory_liveness_plan", + memory_row_index=int(memory_row_index), + surface=memory_entry.surface, + bucket_ordinal=int(memory_entry.bucket_ordinal), + effect=memory_entry.effect, + runtime_role="reverse_grad_carry_cells", + logical_index=0, + ) + ) + for role, shape, workspace_class, preferred_effect in ( + ( + "forward_recurrent_msg", + forward_recurrent_msg_shape, + "message_workspace", + "message_emit", + ), + ( + "forward_output_msg", + forward_output_msg_shape, + "message_workspace", + "message_emit", + ), + ( + "forward_output_cells", + forward_output_cells_shape, + "output_workspace", + "output_emit", + ), + ): + if shape is None: + continue + step_count = _runtime_physical_step_count( + physical_time_steps, + role=role, + ) + memory_row_index, memory_entry = _runtime_memory_entry_for( + memory_plan, + workspace_class=workspace_class, + preferred_effect=preferred_effect, + ) + for step_index in range(step_count): + specs.append( + TemporalRuntimeBufferSpec( + name=f"{role}_step_{int(step_index)}", + tensor_role=role, + shape=tuple(int(item) for item in shape), + dtype=str(dtype), + device=str(device), + workspace_class=workspace_class, + alias_set=memory_entry.alias_set, + init="empty", + owner="compiler_memory_liveness_plan", + memory_row_index=int(memory_row_index), + surface=memory_entry.surface, + bucket_ordinal=int(memory_entry.bucket_ordinal), + effect=memory_entry.effect, + runtime_role=role, + logical_index=int(step_index), + allocation=( + "deferred_local" + if bool(defer_forward_step_buffers) + and runtime_schedule_plan is not None + and runtime_schedule_plan.mode == "none" + else "eager" + ), + ) + ) + if reverse_grad_recurrent_msg_shape is not None: + memory_row_index, memory_entry = _runtime_memory_entry_for( + memory_plan, + workspace_class="reduction_workspace", + preferred_effect="grad_read", + ) + specs.append( + TemporalRuntimeBufferSpec( + name="reverse_grad_recurrent_msg", + tensor_role="reverse_grad_recurrent_msg", + shape=tuple(int(item) for item in reverse_grad_recurrent_msg_shape), + dtype=str(dtype), + device=str(device), + workspace_class="reduction_workspace", + alias_set=memory_entry.alias_set, + init="zeros", + owner="compiler_memory_liveness_plan", + memory_row_index=int(memory_row_index), + surface=memory_entry.surface, + bucket_ordinal=int(memory_entry.bucket_ordinal), + effect=memory_entry.effect, + runtime_role="reverse_grad_recurrent_msg", + logical_index=0, + ) + ) + for request in transition_forward_outputs: + memory_row_index, memory_entry = _runtime_memory_entry_for( + memory_plan, + workspace_class="transition_workspace", + preferred_effect="tape_policy", + row_index=int(request.primitive_row_index), + ) + logical_name = str(request.logical_name) + runtime_role = str(request.runtime_role) + request_shape = tuple(int(item) for item in request.shape) + alias_set = memory_entry.alias_set + if bool(enable_public_state_runtime_alias) and 
_transition_forward_request_aliases_recurrent_hidden( + request, + recurrent_hidden_shape=recurrent_hidden_shape_tuple, + recurrent_hidden_step_count=recurrent_hidden_step_count, + ): + alias_set = _runtime_public_state_alias_set(request_shape) + allocation: Literal["eager", "deferred_local"] = "eager" + if ( + bool(defer_local_transition_outputs) + and runtime_schedule_plan is not None + and runtime_schedule_plan.mode == "none" + and not _runtime_public_state_alias_enabled(alias_set) + and str(request.runtime_role).startswith("transition_forward_") + ): + allocation = "deferred_local" + specs.append( + TemporalRuntimeBufferSpec( + name=f"{runtime_role}_row_{int(request.primitive_row_index)}_{logical_name}", + tensor_role=runtime_role, + shape=request_shape, + dtype=str(dtype), + device=str(device), + workspace_class="transition_workspace", + alias_set=alias_set, + init="empty", + owner="compiler_memory_liveness_plan", + memory_row_index=int(memory_row_index), + surface=memory_entry.surface, + bucket_ordinal=int(request.bucket_ordinal), + effect=memory_entry.effect, + runtime_role=runtime_role, + logical_index=( + int(request.primitive_row_index) if request.logical_index is None else int(request.logical_index) + ), + allocation=allocation, + ) + ) + for request in transition_reverse_dynamic_buffers: + runtime_role = str(request.runtime_role) + logical_name = str(request.logical_name) + request_shape = tuple(int(item) for item in request.shape) + bucket_matches = tuple( + (entry_index, entry) + for entry_index, entry in enumerate(memory_plan.entries) + if entry.workspace_class == "transition_workspace" + and entry.effect == str(request.effect) + and entry.tensor_role == runtime_role + and int(entry.bucket_ordinal) == int(request.bucket_ordinal) + ) + if not bucket_matches: + raise RuntimeError( + "Temporal runtime buffer allocation requires a compiler memory-plan row for transition reverse " + f"dynamic buffer: bucket={int(request.bucket_ordinal)}; effect={request.effect}; " + f"role={runtime_role}" + ) + memory_row_index, memory_entry = bucket_matches[0] + specs.append( + TemporalRuntimeBufferSpec( + name=f"{runtime_role}_bucket_{int(request.bucket_ordinal)}_{logical_name}", + tensor_role=runtime_role, + shape=request_shape, + dtype=str(dtype), + device=str(device), + workspace_class="transition_workspace", + alias_set=memory_entry.alias_set, + init="empty", + owner="compiler_memory_liveness_plan", + memory_row_index=int(memory_row_index), + surface=memory_entry.surface, + bucket_ordinal=int(request.bucket_ordinal), + effect=memory_entry.effect, + runtime_role=runtime_role, + logical_index=( + int(request.bucket_ordinal) if request.logical_index is None else int(request.logical_index) + ), + ) + ) + if include_workspace_rows: + existing_memory_rows = {int(spec.memory_row_index) for spec in specs if spec.memory_row_index is not None} + for memory_row_index, entry in enumerate(memory_plan.entries): + if int(memory_row_index) in existing_memory_rows: + continue + if not _memory_entry_requires_runtime_buffer(entry): + continue + specs.append( + TemporalRuntimeBufferSpec( + name=f"workspace_row_{int(memory_row_index)}_{entry.effect}", + tensor_role=entry.tensor_role, + shape=(1,), + dtype=str(dtype), + device=str(device), + workspace_class=entry.workspace_class, + alias_set=entry.alias_set, + init="empty", + owner="compiler_memory_liveness_plan", + memory_row_index=int(memory_row_index), + surface=entry.surface, + bucket_ordinal=int(entry.bucket_ordinal), + effect=entry.effect, + 
runtime_role="workspace", + ) + ) + buffer_plan = TemporalRuntimeBufferPlan( + specs=tuple(specs), + runtime_policy=runtime_policy, + runtime_schedule_fingerprint=runtime_schedule_fingerprint, + runtime_schedule_rows=runtime_schedule_rows, + ) + validate_temporal_runtime_buffer_plan( + memory_plan, + buffer_plan, + require_runtime_schedule=runtime_schedule_plan is not None, + require_workspace_coverage=include_workspace_rows, + ) + return buffer_plan + + +def validate_temporal_runtime_buffer_plan( + memory_plan: TemporalMemoryLivenessPlan, + buffer_plan: TemporalRuntimeBufferPlan, + *, + require_runtime_schedule: bool = False, + require_workspace_coverage: bool = False, +) -> None: + runtime_policy = temporal_memory_runtime_policy(memory_plan) + runtime_policy.require_complete() + if buffer_plan.runtime_policy != runtime_policy: + raise RuntimeError("Temporal runtime buffer plan policy does not match compiler memory liveness policy") + if require_runtime_schedule: + if not buffer_plan.runtime_schedule_fingerprint: + raise RuntimeError("Temporal runtime buffer plan is missing compiler runtime schedule fingerprint") + if buffer_plan.runtime_schedule_rows is None: + raise RuntimeError("Temporal runtime buffer plan is missing compiler runtime schedule rows") + memory_entry_count = len(memory_plan.entries) + seen_names: set[str] = set() + seen_runtime_roles: set[tuple[str, int]] = set() + covered_memory_rows: set[int] = set() + runtime_alias_groups: dict[str, list[TemporalRuntimeBufferSpec]] = {} + for spec in buffer_plan.specs: + if spec.name in seen_names: + raise RuntimeError(f"Temporal runtime buffer plan has duplicate buffer name {spec.name!r}") + seen_names.add(spec.name) + if spec.owner != "compiler_memory_liveness_plan": + raise RuntimeError( + "Temporal runtime buffer spec must be owned by the compiler memory liveness plan: " + f"buffer={spec.name}; owner={spec.owner}" + ) + if spec.allocation not in {"eager", "deferred_local"}: + raise RuntimeError( + "Temporal runtime buffer spec has an unsupported allocation mode: " + f"buffer={spec.name}; allocation={spec.allocation}" + ) + if spec.allocation == "deferred_local": + if spec.init != "empty": + raise RuntimeError( + f"Deferred local runtime buffers must use empty init: buffer={spec.name}; init={spec.init}" + ) + if not _runtime_buffer_role_allows_deferred_local(spec.runtime_role): + raise RuntimeError( + "Deferred local runtime buffers are only legal for compiler-routed step-local outputs: " + f"buffer={spec.name}; runtime_role={spec.runtime_role}" + ) + if _runtime_public_state_alias_enabled(spec.alias_set): + raise RuntimeError( + "Deferred local runtime buffers must not overlap public-state alias groups: " + f"buffer={spec.name}; alias={spec.alias_set}" + ) + if spec.memory_row_index is None: + raise RuntimeError(f"Temporal runtime buffer spec must reference a compiler memory row: buffer={spec.name}") + memory_row_index = int(spec.memory_row_index) + if memory_row_index < 0 or memory_row_index >= memory_entry_count: + raise RuntimeError( + "Temporal runtime buffer spec references an invalid compiler memory row: " + f"buffer={spec.name}; memory_row={memory_row_index}; row_count={memory_entry_count}" + ) + memory_entry = memory_plan.entries[memory_row_index] + if not _memory_entry_requires_runtime_buffer(memory_entry): + raise RuntimeError( + "Temporal runtime buffer spec references a non-runtime memory row: " + f"buffer={spec.name}; memory_row={memory_row_index}; workspace={memory_entry.workspace_class}; " + 
f"effect={memory_entry.effect}" + ) + if spec.workspace_class != memory_entry.workspace_class: + raise RuntimeError( + "Temporal runtime buffer spec workspace does not match compiler memory row: " + f"buffer={spec.name}; spec={spec.workspace_class}; memory={memory_entry.workspace_class}" + ) + if spec.effect and spec.effect != memory_entry.effect: + raise RuntimeError( + "Temporal runtime buffer spec effect does not match compiler memory row: " + f"buffer={spec.name}; spec={spec.effect}; memory={memory_entry.effect}" + ) + if spec.surface and spec.surface != memory_entry.surface: + raise RuntimeError( + "Temporal runtime buffer spec surface does not match compiler memory row: " + f"buffer={spec.name}; spec={spec.surface}; memory={memory_entry.surface}" + ) + if spec.runtime_role != "workspace": + role_key = (spec.runtime_role, int(spec.logical_index)) + if role_key in seen_runtime_roles: + raise RuntimeError( + "Temporal runtime buffer plan has duplicate runtime role/logical index: " + f"role={spec.runtime_role}; logical_index={int(spec.logical_index)}" + ) + seen_runtime_roles.add(role_key) + if _runtime_public_state_alias_enabled(spec.alias_set): + runtime_alias_groups.setdefault(spec.alias_set, []).append(spec) + covered_memory_rows.add(memory_row_index) + _validate_runtime_public_state_alias_groups(runtime_alias_groups) + if require_workspace_coverage: + required_rows = { + int(memory_row_index) + for memory_row_index, memory_entry in enumerate(memory_plan.entries) + if _memory_entry_requires_runtime_buffer(memory_entry) + } + missing = tuple(sorted(required_rows - covered_memory_rows)) + if missing: + raise RuntimeError( + "Temporal runtime buffer plan does not cover all executable compiler memory rows: " + f"missing={_int_tuple_summary(missing)}" + ) + + +def temporal_runtime_buffer_rows_tensor(buffer_plan: TemporalRuntimeBufferPlan) -> torch.Tensor: + rows: list[list[int]] = [] + alias_to_index: dict[str, int] = {} + + def alias_index(alias_set: str) -> int: + index = alias_to_index.get(alias_set) + if index is None: + index = len(alias_to_index) + alias_to_index[alias_set] = int(index) + return int(index) + + for buffer_index, spec in enumerate(buffer_plan.specs): + rows.append( + [ + int(buffer_index), + -1 if spec.memory_row_index is None else int(spec.memory_row_index), + _memory_opcode(_MEMORY_WORKSPACE_CLASS_OPCODE, spec.workspace_class, "workspace_class"), + 0 if not spec.surface else int(temporal_surface_opcode(spec.surface)), + int(spec.bucket_ordinal), + 0 if not spec.effect else _memory_opcode(_MEMORY_EFFECT_OPCODE, spec.effect, "effect"), + alias_index(spec.alias_set), + 1 if spec.init == "zeros" else 0, + _memory_opcode(_RUNTIME_BUFFER_ROLE_OPCODE, spec.runtime_role, "runtime_buffer_role"), + int(spec.logical_index), + ] + ) + if not rows: + return torch.empty((0, 10), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def temporal_runtime_buffer_spec( + buffer_plan: TemporalRuntimeBufferPlan, + *, + name: str, +) -> TemporalRuntimeBufferSpec: + matches = tuple(spec for spec in buffer_plan.specs if spec.name == str(name)) + if len(matches) != 1: + raise RuntimeError( + "Temporal runtime buffer plan must provide a unique compiler-owned buffer spec: " + f"name={name!r}; count={len(matches)}" + ) + return matches[0] + + +def allocate_temporal_runtime_buffer( + reference: torch.Tensor, + spec: TemporalRuntimeBufferSpec, +) -> torch.Tensor: + if not torch.is_tensor(reference): + raise RuntimeError("Temporal runtime buffer allocation requires a tensor 
reference") + expected_device = str(reference.device) + if spec.device != expected_device: + raise RuntimeError( + "Temporal runtime buffer spec device does not match allocation reference: " + f"spec={spec.device}; reference={expected_device}; buffer={spec.name}" + ) + dtype = _TORCH_DTYPE_BY_SPEC.get(str(spec.dtype)) + if dtype is None: + raise RuntimeError(f"Temporal runtime buffer spec has unsupported dtype {spec.dtype!r}: buffer={spec.name}") + shape = tuple(int(item) for item in spec.shape) + if spec.allocation == "deferred_local": + if spec.init != "empty": + raise RuntimeError( + f"Deferred local temporal runtime buffers must use empty init: buffer={spec.name}; init={spec.init}" + ) + return torch.empty((0,), device=reference.device, dtype=dtype) + if spec.init == "zeros": + return torch.zeros(shape, device=reference.device, dtype=dtype) + if spec.init == "empty": + return torch.empty(shape, device=reference.device, dtype=dtype) + raise RuntimeError(f"Unsupported temporal runtime buffer init policy {spec.init!r}") + + +def allocate_temporal_runtime_buffers( + reference: torch.Tensor, + buffer_plan: TemporalRuntimeBufferPlan, +) -> tuple[torch.Tensor, ...]: + alias_allocations: dict[tuple[str, tuple[int, ...], str, str, str], torch.Tensor] = {} + buffers: list[torch.Tensor] = [] + for spec in buffer_plan.specs: + if spec.allocation == "deferred_local": + buffers.append(allocate_temporal_runtime_buffer(reference, spec)) + continue + alias_key = _runtime_buffer_alias_allocation_key(buffer_plan.runtime_policy, spec) + if alias_key is not None and alias_key in alias_allocations: + buffers.append(alias_allocations[alias_key]) + continue + buffer = allocate_temporal_runtime_buffer(reference, spec) + if alias_key is not None: + alias_allocations[alias_key] = buffer + buffers.append(buffer) + return tuple(buffers) + + +def _slot_entry(slot: TemporalTensorTableSlot) -> TemporalMemoryPlanEntry: + return TemporalMemoryPlanEntry( + row_index=None, + bucket_ordinal=int(slot.bucket_ordinal), + surface="transition", + tensor_role=f"{slot.table_role}:{slot.key}", + tensor_class=_slot_tensor_class(slot), + layout="contiguous", + lifetime=_slot_lifetime(slot), + workspace_class=_slot_workspace_class(slot), + alias_set=f"bucket{int(slot.bucket_ordinal)}.{slot.table_role}.{slot.slot}", + recompute_policy=_slot_recompute_policy(slot), + effect=_slot_effect(slot), + owner="compiler_tensor_role_table", + ) + + +def _primitive_effect_entry( + *, + row_index: int, + surface: str, + bucket_ordinal: int, + effect: str, +) -> TemporalMemoryPlanEntry: + return TemporalMemoryPlanEntry( + row_index=int(row_index), + bucket_ordinal=int(bucket_ordinal), + surface=surface, + tensor_role=_effect_tensor_role(effect), + tensor_class=_effect_tensor_class(effect), + layout="contiguous", + lifetime=_effect_lifetime(effect), + workspace_class=_effect_workspace_class(effect), + alias_set=f"{surface}.{effect}.row{int(row_index)}", + recompute_policy=_effect_recompute_policy(effect), + effect=effect, + owner="compiler_primitive_row", + ) + + +def _transition_reverse_dynamic_buffer_entries( + table: TemporalPrimitiveTablePlan, +) -> tuple[TemporalMemoryPlanEntry, ...]: + transition_buckets = tuple( + dict.fromkeys( + int(row.bucket_ordinal) for row in table.primitive_rows if surface_for_temporal_row(row) == "transition" + ) + ) + entries: list[TemporalMemoryPlanEntry] = [] + for bucket_ordinal in transition_buckets: + entries.append( + TemporalMemoryPlanEntry( + row_index=None, + bucket_ordinal=int(bucket_ordinal), + 
surface="transition", + tensor_role="transition_reverse_recurrent_msg_span", + tensor_class="message_activation", + layout="contiguous", + lifetime="primitive_scope", + workspace_class="transition_workspace", + alias_set=f"transition.bucket{int(bucket_ordinal)}.reverse_dynamic.recurrent_msg_span", + recompute_policy="not_recomputed", + effect="message_read", + owner="compiler_tensor_role_table", + ) + ) + entries.append( + TemporalMemoryPlanEntry( + row_index=None, + bucket_ordinal=int(bucket_ordinal), + surface="transition", + tensor_role="transition_reverse_state_before_zero", + tensor_class="private_state", + layout="contiguous", + lifetime="primitive_scope", + workspace_class="transition_workspace", + alias_set=f"transition.bucket{int(bucket_ordinal)}.reverse_dynamic.state_before_zero", + recompute_policy="not_recomputed", + effect="state_read", + owner="compiler_tensor_role_table", + ) + ) + return tuple(entries) + + +def _memory_policy_entries(table: TemporalPrimitiveTablePlan) -> tuple[TemporalMemoryPlanEntry, ...]: + del table + policies = ( + ( + "local_seed_row", + "metadata_descriptor", + "local_seed_policy", + "policy_not_recomputed", + "local_seed_row_policy", + ), + ( + "metadata_row", + "metadata_descriptor", + "metadata_policy", + "policy_not_recomputed", + "metadata_row_policy", + ), + ( + "primitive_output", + "primitive_output", + "primitive_output_policy", + "recompute_or_store_by_scheduler", + "primitive_output_policy", + ), + ( + "tape", + "transition_tape", + "tape_policy", + "scheduler_tape_policy", + "transition_tape_policy", + ), + ( + "alias", + "runtime_policy", + "alias_policy", + "scheduler_alias_policy", + "alias_set_policy", + ), + ( + "recompute_window", + "runtime_policy", + "recompute_window_policy", + "scheduler_recompute_window_policy", + "recompute_window_policy", + ), + ( + "materialization", + "runtime_policy", + "materialization_policy", + "scheduler_materialization_policy", + "materialization_policy", + ), + ( + "cuda_graph", + "runtime_policy", + "cuda_graph_constraint", + "cuda_graph_guard_policy", + "cuda_graph_capture_constraint", + ), + ) + return tuple( + TemporalMemoryPlanEntry( + row_index=None, + bucket_ordinal=0, + surface="runtime_policy", + tensor_role=tensor_role, + tensor_class=tensor_class, + layout="metadata", + lifetime="compiled_program", + workspace_class="policy_table", + alias_set=f"runtime_policy.{alias}", + recompute_policy=recompute_policy, + effect=effect, + owner="compiler_memory_policy", + ) + for tensor_role, tensor_class, effect, recompute_policy, alias in policies + ) + + +def _slot_tensor_class(slot: TemporalTensorTableSlot) -> str: + if slot.table_role == "private_state": + return "private_state" + if slot.table_role == "public_state": + return "public_state" + if slot.table_role == "transition_params": + return "parameter_binding" + return str(slot.semantic_kind) + + +def _slot_lifetime(slot: TemporalTensorTableSlot) -> str: + if slot.table_role in {"private_state", "public_state"}: + return "temporal_carry" + if slot.table_role == "transition_params": + return "compiled_program" + return "primitive_scope" + + +def _slot_workspace_class(slot: TemporalTensorTableSlot) -> str: + if slot.table_role in {"private_state", "public_state"}: + return "state_carry_workspace" + if slot.table_role == "transition_params": + return "parameter_table" + return "tensor_table_workspace" + + +def _slot_recompute_policy(slot: TemporalTensorTableSlot) -> str: + if slot.table_role in {"private_state", "public_state"}: + return 
"carry_forward_or_checkpoint" + if slot.table_role == "transition_params": + return "not_recomputed" + return "planner_policy" + + +def _slot_effect(slot: TemporalTensorTableSlot) -> str: + if slot.table_role == "private_state": + return "state_read_write" + if slot.table_role == "public_state": + return "state_emit" + if slot.table_role == "transition_params": + return "parameter_read" + return "tensor_role" + + +def _effect_tensor_role(effect: str) -> str: + if effect == "state_read": + return "state_input" + if effect == "state_write": + return "state_output" + if effect == "message_read": + return "message_input" + if effect == "message_emit": + return "message_output" + if effect == "output_emit": + return "output_activation" + if effect == "materialization_boundary": + return "output_boundary" + if effect == "tape_policy": + return "transition_tape" + if effect == "grad_read": + return "grad_input" + if effect == "parameter_grad_emit": + return "parameter_grad_output" + if effect == "parameter_read": + return "parameter_input" + return effect + + +def _effect_tensor_class(effect: str) -> str: + if effect == "state_read": + return "state_activation" + if effect == "state_write": + return "state_carry" + if effect == "message_read": + return "message_activation" + if effect == "message_emit": + return "message_activation" + if effect == "output_emit": + return "output_activation" + if effect == "materialization_boundary": + return "materialized_output_boundary" + if effect == "tape_policy": + return "transition_tape" + if effect == "grad_read": + return "backward_input_grad" + if effect == "parameter_grad_emit": + return "parameter_grad_accumulator" + if effect == "parameter_read": + return "parameter_binding" + return "unknown_tensor_class" + + +def _effect_lifetime(effect: str) -> str: + if effect in {"message_emit", "message_read", "output_emit"}: + return "physical_step" + if effect in {"state_read", "state_write"}: + return "temporal_carry" + if effect == "materialization_boundary": + return "output_window" + if effect == "tape_policy": + return "backward_window" + if effect in {"grad_read", "parameter_grad_emit"}: + return "backward_window" + if effect == "parameter_read": + return "compiled_program" + return "primitive_scope" + + +def _effect_workspace_class(effect: str) -> str: + if effect in {"message_emit", "message_read"}: + return "message_workspace" + if effect in {"state_read", "state_write"}: + return "state_carry_workspace" + if effect in {"output_emit", "materialization_boundary"}: + return "output_workspace" + if effect == "tape_policy": + return "transition_workspace" + if effect in {"grad_read", "parameter_grad_emit"}: + return "reduction_workspace" + if effect == "parameter_read": + return "parameter_table" + return "primitive_workspace" + + +def _effect_recompute_policy(effect: str) -> str: + if effect in {"message_emit", "message_read", "output_emit"}: + return "recompute_or_store_by_scheduler" + if effect in {"state_read", "state_write"}: + return "checkpoint_boundary_policy" + if effect == "materialization_boundary": + return "materialize_when_requested" + if effect == "tape_policy": + return "scheduler_tape_policy" + if effect in {"grad_read", "parameter_grad_emit"}: + return "accumulate_window_then_bind" + if effect == "parameter_read": + return "not_recomputed" + return "planner_policy" + + +def _scheduler_materialization_value(scheduler_plan: Any | None, name: str, default: str) -> str: + materialization = None if scheduler_plan is None else 
getattr(scheduler_plan, "materialization", None) + return str(default if materialization is None else getattr(materialization, name, default)) + + +def _scheduler_int_value( + scheduler_plan: Any | None, + *, + materialization_name: str, + checkpoint_name: str, +) -> int | None: + if scheduler_plan is None: + return None + materialization = getattr(scheduler_plan, "materialization", None) + checkpoint = getattr(scheduler_plan, "checkpoint", None) + value = None if materialization is None else getattr(materialization, materialization_name, None) + if value is None and checkpoint is not None: + value = getattr(checkpoint, checkpoint_name, None) + return None if value is None else max(1, int(value)) + + +def _scheduler_output_physical_steps(scheduler_plan: Any | None) -> tuple[int, ...]: + output_emissions = None if scheduler_plan is None else getattr(scheduler_plan, "output_emissions", None) + physical_steps = () if output_emissions is None else getattr(output_emissions, "physical_steps", ()) + return tuple(int(step) for step in physical_steps) + + +def _bounded_window_value(value: int | None, physical_time_steps: int) -> int: + if value is None: + return max(1, int(physical_time_steps)) + return max(1, min(int(physical_time_steps), int(value))) + + +def _planned_checkpoint_steps(*, checkpoint_stride: int, physical_time_steps: int) -> tuple[int, ...]: + stride = max(1, int(checkpoint_stride)) + return tuple(range(0, max(1, int(physical_time_steps)), stride)) + + +def _planned_backward_windows( + *, + mode: TemporalRuntimeArtifactMode, + physical_time_steps: int, + checkpoint_stride: int, + recompute_window_len: int, +) -> tuple[tuple[int, int], ...]: + if mode == "none": + return () + segment_stride = int(physical_time_steps) if mode == "store_step_artifacts" else max(1, int(checkpoint_stride)) + tile = max(1, int(recompute_window_len)) + windows: list[tuple[int, int]] = [] + segment_start = 0 + while segment_start < int(physical_time_steps): + segment_end = min(int(physical_time_steps), segment_start + segment_stride) + window_start = segment_start + while window_start < segment_end: + window_end = min(segment_end, window_start + tile) + windows.append((int(window_start), int(window_end))) + window_start = window_end + segment_start = segment_end + return tuple(windows) + + +def _runtime_alias_for( + memory_plan: TemporalMemoryLivenessPlan, + *, + workspace_class: str, + preferred_effect: str, +) -> str: + preferred = tuple( + entry.alias_set + for entry in memory_plan.entries + if entry.workspace_class == workspace_class and entry.effect == preferred_effect + ) + if preferred: + return preferred[0] + candidates = tuple(entry.alias_set for entry in memory_plan.entries if entry.workspace_class == workspace_class) + if candidates: + return candidates[0] + raise RuntimeError( + "Temporal runtime buffer allocation requires a compiler memory-plan workspace entry: " + f"workspace_class={workspace_class}; preferred_effect={preferred_effect}" + ) + + +def _runtime_memory_entry_for( + memory_plan: TemporalMemoryLivenessPlan, + *, + workspace_class: str, + preferred_effect: str, + row_index: int | None = None, +) -> tuple[int, TemporalMemoryPlanEntry]: + preferred = tuple( + (entry_index, entry) + for entry_index, entry in enumerate(memory_plan.entries) + if entry.workspace_class == workspace_class + and entry.effect == preferred_effect + and (row_index is None or entry.row_index == int(row_index)) + ) + if preferred: + return preferred[0] + candidates = tuple( + (entry_index, entry) + for 
entry_index, entry in enumerate(memory_plan.entries) + if entry.workspace_class == workspace_class and (row_index is None or entry.row_index == int(row_index)) + ) + if candidates: + return candidates[0] + raise RuntimeError( + "Temporal runtime buffer allocation requires a compiler memory-plan workspace entry: " + f"workspace_class={workspace_class}; preferred_effect={preferred_effect}" + ) + + +def _memory_entry_requires_runtime_buffer(entry: TemporalMemoryPlanEntry) -> bool: + if entry.workspace_class == "parameter_table": + return False + if entry.lifetime == "compiled_program" and entry.effect == "parameter_read": + return False + return entry.workspace_class in { + "message_workspace", + "output_workspace", + "primitive_workspace", + "reduction_workspace", + "state_carry_workspace", + "tensor_table_workspace", + "transition_workspace", + } + + +def _memory_entry_is_policy(entry: TemporalMemoryPlanEntry) -> bool: + return ( + entry.owner == "compiler_memory_policy" + or entry.workspace_class == "policy_table" + or entry.surface == "runtime_policy" + ) + + +def _runtime_policy_reason(policy: TemporalMemoryRuntimePolicy) -> str: + fields = tuple(f"{effect}={policy.policy_for(effect)}" for effect in _REQUIRED_MEMORY_POLICY_EFFECTS) + return "memory_runtime_policy=" + ",".join(fields) + + +def _runtime_buffer_alias_allocation_key( + policy: TemporalMemoryRuntimePolicy | None, + spec: TemporalRuntimeBufferSpec, +) -> tuple[str, tuple[int, ...], str, str, str] | None: + if policy is None or not policy.alias_allocation_enabled: + return None + if spec.init != "empty": + return None + if spec.runtime_role != "workspace" and not _runtime_public_state_alias_enabled(spec.alias_set): + return None + return ( + spec.alias_set, + tuple(int(item) for item in spec.shape), + str(spec.dtype), + str(spec.device), + str(spec.init), + ) + + +def _runtime_public_state_alias_set(shape: tuple[int, ...] | None) -> str: + if shape is None: + return "" + return _RUNTIME_PUBLIC_STATE_ALIAS_PREFIX + _int_tuple_summary(shape) + + +def _runtime_public_state_alias_enabled(alias_set: str) -> bool: + return str(alias_set).startswith(_RUNTIME_PUBLIC_STATE_ALIAS_PREFIX) + + +def _runtime_buffer_role_allows_deferred_local(runtime_role: str) -> bool: + return str(runtime_role) in { + "forward_recurrent_hidden_after", + "forward_recurrent_msg", + "forward_output_msg", + "forward_output_cells", + "transition_forward_linear_output", + "transition_forward_matmul_output", + "transition_forward_state_output", + "transition_forward_norm_output", + "transition_forward_diag_output", + "transition_forward_unary_output", + } + + +def _transition_forward_request_aliases_recurrent_hidden( + request: TemporalTransitionForwardRuntimeBufferRequest, + *, + recurrent_hidden_shape: tuple[int, ...] 
| None, + recurrent_hidden_step_count: int, +) -> bool: + if request.alias_runtime_role != "forward_recurrent_hidden_after": + return False + if recurrent_hidden_shape is None or int(recurrent_hidden_step_count) != 1: + return False + return tuple(int(item) for item in request.shape) == tuple(int(item) for item in recurrent_hidden_shape) + + +def _validate_runtime_public_state_alias_groups( + groups: dict[str, list[TemporalRuntimeBufferSpec]], +) -> None: + for alias_set, specs in groups.items(): + roles = {str(spec.runtime_role) for spec in specs} + if "forward_recurrent_hidden_after" not in roles: + raise RuntimeError( + "Temporal public-state runtime alias must include the recurrent hidden carry buffer: " + f"alias={alias_set}; roles={_tuple_summary(tuple(sorted(roles)))}" + ) + if sum(1 for spec in specs if spec.runtime_role == "forward_recurrent_hidden_after") != 1: + raise RuntimeError( + "Temporal public-state runtime alias must have exactly one recurrent hidden carry buffer: " + f"alias={alias_set}" + ) + if not any(str(spec.runtime_role).startswith("transition_forward_") for spec in specs): + raise RuntimeError( + "Temporal public-state runtime alias must include a transition forward output buffer: " + f"alias={alias_set}; roles={_tuple_summary(tuple(sorted(roles)))}" + ) + shapes = {tuple(int(item) for item in spec.shape) for spec in specs} + dtypes = {str(spec.dtype) for spec in specs} + devices = {str(spec.device) for spec in specs} + inits = {str(spec.init) for spec in specs} + if len(shapes) != 1 or len(dtypes) != 1 or len(devices) != 1 or inits != {"empty"}: + raise RuntimeError( + "Temporal public-state runtime alias requires exact shape, dtype, device, and empty-init match: " + f"alias={alias_set}" + ) + + +def _require_runtime_schedule_matches_policy( + runtime_policy: TemporalMemoryRuntimePolicy, + runtime_schedule_plan: TemporalMemoryRuntimeSchedulePlan, +) -> None: + for effect in _REQUIRED_MEMORY_POLICY_EFFECTS: + policy = runtime_policy.policy_for(effect) + schedule_policy = getattr(runtime_schedule_plan, effect, "") + if policy != schedule_policy: + raise RuntimeError( + "Temporal runtime buffer planning received a schedule plan whose policy rows do not match " + f"the compiler memory policy: effect={effect}; policy={policy}; schedule={schedule_policy}" + ) + + +def _schedule_rows_summary(rows: torch.Tensor | None) -> str: + if rows is None: + return "none" + if not torch.is_tensor(rows): + return "invalid" + return "x".join(str(int(item)) for item in rows.shape) + + +def _runtime_physical_step_count(value: int | None, *, role: str) -> int: + if value is None: + raise RuntimeError(f"Temporal runtime buffer role {role!r} requires physical_time_steps") + return max(1, int(value)) + + +def _int_tuple_summary(values: tuple[int, ...]) -> str: + return "none" if not values else ",".join(str(int(value)) for value in values) + + +def _window_tuple_summary(values: tuple[tuple[int, int], ...]) -> str: + return "none" if not values else ";".join(f"{int(start)}:{int(end)}" for start, end in values) + + +def _tuple_summary(values: tuple[str, ...]) -> str: + return "none" if not values else ",".join(values) + + +def _effect_row_tuple_summary(values: tuple[tuple[str, int], ...]) -> str: + return "none" if not values else ",".join(f"{effect}:{int(row_index)}" for effect, row_index in values) + + +def _bytes_by_summary(values: tuple[tuple[str, int], ...]) -> str: + return "none" if not values else ",".join(f"{name}:{int(value)}" for name, value in values) + + +def 
_runtime_dtype_element_size(dtype: str) -> int: + torch_dtype = _TORCH_DTYPE_BY_SPEC.get(str(dtype)) + if torch_dtype is None: + return 0 + return int(torch.empty((), dtype=torch_dtype).element_size()) + + +def _memory_opcode(mapping: dict[str, int], value: str, field: str) -> int: + try: + return int(mapping[str(value)]) + except KeyError as error: + raise RuntimeError( + f"Temporal memory liveness row has unregistered {field} value {value!r}; " + "register it before lowering the memory plan into fused-program rows" + ) from error + + +def _physical_strategy_mask(mapping: dict[str, int], values: tuple[str, ...], field: str) -> int: + mask = 0 + for value in values: + mask |= _memory_opcode(mapping, value, field) + return int(mask) + + +__all__ = [ + "allocate_temporal_runtime_buffers", + "allocate_temporal_runtime_buffer", + "build_temporal_physical_strategy_plan", + "TemporalMemoryRuntimeArtifactPlan", + "TemporalMemoryRuntimePolicy", + "TemporalMemoryRuntimeSchedulePlan", + "TemporalMemoryLivenessPlan", + "TemporalMemoryPlanEntry", + "TemporalPhysicalStrategyPlan", + "TemporalPhysicalStrategyRow", + "TemporalRuntimeBufferPlan", + "TemporalRuntimeBufferSpec", + "TemporalTransitionForwardRuntimeBufferRequest", + "TemporalTransitionReverseRuntimeBufferRequest", + "build_temporal_memory_liveness_plan", + "build_temporal_memory_runtime_artifact_plan", + "build_temporal_memory_runtime_schedule_plan", + "build_temporal_runtime_buffer_plan", + "temporal_memory_liveness_rows_tensor", + "temporal_memory_runtime_policy", + "temporal_memory_runtime_schedule_rows_tensor", + "temporal_physical_strategy_rows_tensor", + "temporal_runtime_buffer_role_opcode", + "temporal_runtime_buffer_rows_tensor", + "temporal_runtime_buffer_spec", + "validate_temporal_runtime_buffer_plan", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py new file mode 100644 index 00000000..9be6f2d3 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py @@ -0,0 +1,1400 @@ +from __future__ import annotations + +from dataclasses import dataclass, replace +from typing import Literal, cast + +import torch + +from cortical.fabric.backend.cuda.sequence_surface.compiler.executor_patterns import ( + temporal_executor_strategy_registry, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.memory_plan import ( + temporal_runtime_buffer_role_opcode, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.primitive_registry import ( + temporal_primitive_opcode, + temporal_surface_opcode, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.reducer_patterns import ( + temporal_parameter_reducer_pattern, + temporal_parameter_reducer_patterns, + temporal_transition_trainable_reducer_pattern, + temporal_transition_trainable_reducer_patterns, +) +from cortical.fabric.backend.cuda.transition_execution.registry import ( + registered_transition_primitive_executor_records, + transition_primitive_executor_record_for_lowered_primitive, +) + + +TemporalNativeCallableCategory = Literal[ + "executor_strategy", + "transition_primitive_forward", + "transition_primitive_backward", + "parameter_reducer", + "transition_trainable_reducer", +] +TemporalNativeCallableDirection = Literal["none", "forward", "reverse"] +TemporalNativeCallableOutputShapeKind = Literal["hidden", "gate_logits", "diagonal_preproj"] +TemporalNativeCallableLogicalIndexSource = 
Literal["primitive_row", "binding_index"] +TemporalNativeCallableOutputInit = Literal["empty", "zeros"] +TemporalNativeCallableBindingKind = Literal["input", "parameter", "output"] +TemporalTransitionReverseSeedKind = Literal["public_grad", "state_grad"] + +_NATIVE_CALLABLE_CATEGORY_OPCODE = { + "executor_strategy": 1, + "transition_primitive_forward": 2, + "transition_primitive_backward": 3, + "parameter_reducer": 4, + "transition_trainable_reducer": 5, +} +_NATIVE_CALLABLE_DIRECTION_OPCODE = { + "none": 0, + "forward": 1, + "reverse": 2, +} +_NATIVE_CALLABLE_SURFACE_OPCODE = { + "none": 0, + "message": temporal_surface_opcode("message"), + "transition": temporal_surface_opcode("transition"), + "readout": temporal_surface_opcode("readout"), + "parameter_reduction": temporal_surface_opcode("parameter_reduction"), +} +_NATIVE_CALLABLE_OUTPUT_SHAPE_KIND_OPCODE = { + "hidden": 1, + "gate_logits": 2, + "diagonal_preproj": 3, +} +_NATIVE_CALLABLE_LOGICAL_INDEX_SOURCE_OPCODE = { + "primitive_row": 1, + "binding_index": 2, +} +_NATIVE_CALLABLE_OUTPUT_INIT_OPCODE = { + "empty": 0, + "zeros": 1, +} +_NATIVE_CALLABLE_BINDING_KIND_OPCODE = { + "input": 1, + "parameter": 2, + "output": 3, +} +_TRANSITION_REVERSE_SEED_KIND_OPCODE = { + "public_grad": 1, + "state_grad": 2, +} +_OPTIONAL_TRANSITION_REVERSE_PARAMETERS = frozenset({"outnorm_eps", "eps", "activation_id"}) + + +@dataclass(frozen=True) +class TemporalNativeCallableDefinition: + callable_id: str + category: TemporalNativeCallableCategory + direction: TemporalNativeCallableDirection + surface: str + primitive: str + implementation_symbol: str + version: int = 1 + display_name: str = "" + cxx_entrypoints: tuple[str, ...] = () + cxx_entrypoint_phases: tuple[str, ...] = () + + @property + def callable_hash(self) -> int: + return temporal_strategy_id_hash(self.callable_id) + + @property + def implementation_symbol_hash(self) -> int: + return temporal_strategy_id_hash(self.implementation_symbol) + + @property + def summary(self) -> str: + return ( + f"callable={self.callable_id}" + f",category={self.category}" + f",direction={self.direction}" + f",surface={self.surface}" + f",primitive={self.primitive or '-'}" + f",symbol={self.implementation_symbol}" + f",version={int(self.version)}" + f",cxx={'+'.join(self.cxx_entrypoints) if self.cxx_entrypoints else '-'}" + f",cxx_phases={'+'.join(self.cxx_entrypoint_phases) if self.cxx_entrypoint_phases else '-'}" + ) + + +@dataclass(frozen=True) +class TemporalNativeCallableOutputDefinition: + callable_id: str + primitive: str + output_name: str + output_index: int + runtime_role: str + shape_kind: TemporalNativeCallableOutputShapeKind + logical_index_source: TemporalNativeCallableLogicalIndexSource + init: TemporalNativeCallableOutputInit = "empty" + direction: TemporalNativeCallableDirection = "forward" + surface: str = "transition" + version: int = 1 + + @property + def callable_hash(self) -> int: + return temporal_strategy_id_hash(self.callable_id) + + @property + def output_name_hash(self) -> int: + return temporal_strategy_id_hash(self.output_name) + + @property + def summary(self) -> str: + return ( + f"callable={self.callable_id}" + f",direction={self.direction}" + f",surface={self.surface}" + f",primitive={self.primitive}" + f",output={self.output_name}" + f",output_index={int(self.output_index)}" + f",runtime_role={self.runtime_role}" + f",shape={self.shape_kind}" + f",logical_index_source={self.logical_index_source}" + f",init={self.init}" + f",version={int(self.version)}" + ) + + def shape(self, *, 
batch_size: int, receiver_count: int, hidden_size: int) -> tuple[int, ...]: + if self.shape_kind == "gate_logits": + return (int(batch_size), int(receiver_count), 4, int(hidden_size)) + if self.shape_kind == "diagonal_preproj": + return (int(batch_size), int(receiver_count), 2 * int(hidden_size)) + return (int(batch_size), int(receiver_count), int(hidden_size)) + + def logical_index(self, *, primitive_row_index: int, binding_index: int) -> int: + if self.logical_index_source == "binding_index": + return int(binding_index) + return int(primitive_row_index) + + +@dataclass(frozen=True) +class TemporalNativeCallableBindingSchemaDefinition: + callable_id: str + primitive: str + binding_kind: TemporalNativeCallableBindingKind + logical_name: str + local_binding_index: int + required: bool = True + direction: TemporalNativeCallableDirection = "forward" + surface: str = "transition" + version: int = 1 + + @property + def callable_hash(self) -> int: + return temporal_strategy_id_hash(self.callable_id) + + @property + def logical_name_hash(self) -> int: + return temporal_strategy_id_hash(self.logical_name) + + @property + def summary(self) -> str: + return ( + f"callable={self.callable_id}" + f",direction={self.direction}" + f",surface={self.surface}" + f",primitive={self.primitive}" + f",binding_kind={self.binding_kind}" + f",logical_name={self.logical_name}" + f",local_binding_index={int(self.local_binding_index)}" + f",required={int(self.required)}" + f",version={int(self.version)}" + ) + + +@dataclass(frozen=True) +class TemporalTransitionReverseSeedRoleDefinition: + role_name: str + role_id: int + seed_kind: TemporalTransitionReverseSeedKind + version: int = 1 + + @property + def role_name_hash(self) -> int: + return temporal_strategy_id_hash(self.role_name) + + @property + def summary(self) -> str: + return ( + f"reverse_seed_role={self.role_name}" + f",role_id={int(self.role_id)}" + f",kind={self.seed_kind}" + f",version={int(self.version)}" + ) + + +def temporal_strategy_id_hash(strategy_id: str) -> int: + value = str(strategy_id).encode("utf-8") + checksum = 2166136261 + for byte in value: + checksum ^= int(byte) + checksum = (checksum * 16777619) & 0xFFFFFFFF + checksum &= 0x7FFFFFFF + return int(checksum if checksum > 0 else 1) + + +def parameter_reducer_native_callable_id(reducer_kind: str) -> str: + return temporal_parameter_reducer_pattern(reducer_kind).native_callable + + +def transition_trainable_reducer_native_callable_id(reducer_kind: str) -> str: + return temporal_transition_trainable_reducer_pattern(reducer_kind).native_callable + + +def temporal_native_callable_definitions() -> tuple[TemporalNativeCallableDefinition, ...]: + definitions: list[TemporalNativeCallableDefinition] = [] + registry = temporal_executor_strategy_registry() + for direction, patterns in ( + ("forward", registry.forward_patterns()), + ("reverse", registry.reverse_patterns()), + ): + definitions.extend( + TemporalNativeCallableDefinition( + callable_id=pattern.stable_native_callable_id, + category="executor_strategy", + direction=direction, + surface=pattern.surface, + primitive=pattern.row_pattern[0].primitive if pattern.row_pattern else "", + implementation_symbol=pattern.implementation_contract, + version=int(pattern.strategy_version), + display_name=_strategy_display_name(pattern.stable_strategy_id, int(pattern.strategy_version)), + cxx_entrypoints=_strategy_cxx_entrypoints(pattern), + cxx_entrypoint_phases=_strategy_cxx_entrypoint_phases(pattern), + ) + for pattern in patterns + ) + for record in 
registered_transition_primitive_executor_records(): + if record.program_forward_status == "callable" and record.program_forward_symbol: + definitions.append( + TemporalNativeCallableDefinition( + callable_id=record.program_forward_symbol, + category="transition_primitive_forward", + direction="forward", + surface="transition", + primitive=record.primitive, + implementation_symbol=record.program_forward_symbol, + display_name=f"transition.{record.primitive}.forward", + cxx_entrypoints=(record.program_forward_cxx_entrypoint,), + ) + ) + if record.program_backward_status == "callable" and record.program_backward_symbol: + definitions.append( + TemporalNativeCallableDefinition( + callable_id=record.program_backward_symbol, + category="transition_primitive_backward", + direction="reverse", + surface="transition", + primitive=record.primitive, + implementation_symbol=record.program_backward_symbol, + display_name=f"transition.{record.primitive}.backward", + ) + ) + definitions.extend( + TemporalNativeCallableDefinition( + callable_id=pattern.native_callable, + category="parameter_reducer", + direction="reverse", + surface="parameter_reduction", + primitive="", + implementation_symbol=pattern.implementation_symbol, + display_name=pattern.reducer_kind, + cxx_entrypoints=pattern.stable_cxx_entrypoints, + ) + for pattern in temporal_parameter_reducer_patterns() + ) + definitions.extend( + TemporalNativeCallableDefinition( + callable_id=pattern.native_callable, + category="transition_trainable_reducer", + direction="reverse", + surface="parameter_reduction", + primitive="", + implementation_symbol=pattern.implementation_symbol, + display_name=pattern.reducer_kind, + cxx_entrypoints=pattern.stable_cxx_entrypoints, + ) + for pattern in temporal_transition_trainable_reducer_patterns() + ) + unique: dict[str, TemporalNativeCallableDefinition] = {} + for definition in definitions: + existing = unique.get(definition.callable_id) + if existing is None: + unique[definition.callable_id] = definition + continue + if ( + existing.category != definition.category + or existing.direction != definition.direction + or existing.surface != definition.surface + or existing.version != definition.version + or existing.cxx_entrypoints != definition.cxx_entrypoints + or existing.cxx_entrypoint_phases != definition.cxx_entrypoint_phases + ): + raise RuntimeError( + "Temporal native callable id is shared by incompatible executor definitions: " + f"callable={definition.callable_id!r}" + ) + unique[definition.callable_id] = replace( + existing, + primitive=existing.primitive or definition.primitive, + implementation_symbol=existing.callable_id, + display_name=existing.callable_id, + ) + return tuple(unique.values()) + + +def temporal_native_callable_catalog_rows_tensor() -> torch.Tensor: + rows = [ + [ + index, + int(definition.callable_hash), + int(_NATIVE_CALLABLE_CATEGORY_OPCODE[definition.category]), + int(_NATIVE_CALLABLE_DIRECTION_OPCODE[definition.direction]), + int(_NATIVE_CALLABLE_SURFACE_OPCODE.get(definition.surface, 0)), + _native_callable_primitive_opcode(definition.primitive), + int(definition.implementation_symbol_hash), + int(definition.version), + ] + for index, definition in enumerate(temporal_native_callable_definitions()) + ] + if not rows: + return torch.empty((0, 8), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def temporal_native_callable_output_definitions() -> tuple[TemporalNativeCallableOutputDefinition, ...]: + definitions: list[TemporalNativeCallableOutputDefinition] = [] + 
for record in registered_transition_primitive_executor_records(): + if record.program_forward_status != "callable" or not record.program_forward_symbol: + continue + output_contracts = tuple(record.program_forward_output_contracts) + if not output_contracts: + raise RuntimeError( + "Registered transition forward primitive has no compiler-owned output contract: " + f"primitive={record.primitive!r}; callable={record.program_forward_symbol!r}" + ) + definitions.extend( + TemporalNativeCallableOutputDefinition( + callable_id=record.program_forward_symbol, + primitive=record.primitive, + output_name=str(output_name), + output_index=int(output_index), + runtime_role=str(runtime_role), + shape_kind=cast(TemporalNativeCallableOutputShapeKind, str(shape_kind)), + logical_index_source=cast(TemporalNativeCallableLogicalIndexSource, str(logical_index_source)), + ) + for output_index, (output_name, runtime_role, shape_kind, logical_index_source) in enumerate( + output_contracts + ) + ) + return tuple(definitions) + + +def temporal_native_callable_output_rows_tensor() -> torch.Tensor: + rows = [ + [ + index, + int(definition.callable_hash), + int(_NATIVE_CALLABLE_DIRECTION_OPCODE[definition.direction]), + int(_NATIVE_CALLABLE_SURFACE_OPCODE.get(definition.surface, 0)), + _native_callable_primitive_opcode(definition.primitive), + int(definition.output_name_hash), + int(definition.output_index), + temporal_runtime_buffer_role_opcode(definition.runtime_role), + int(_NATIVE_CALLABLE_OUTPUT_SHAPE_KIND_OPCODE[definition.shape_kind]), + int(_NATIVE_CALLABLE_LOGICAL_INDEX_SOURCE_OPCODE[definition.logical_index_source]), + int(_NATIVE_CALLABLE_OUTPUT_INIT_OPCODE[definition.init]), + int(definition.version), + ] + for index, definition in enumerate(temporal_native_callable_output_definitions()) + ] + if not rows: + return torch.empty((0, 12), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def temporal_native_callable_binding_schema_definitions() -> tuple[ + TemporalNativeCallableBindingSchemaDefinition, + ..., +]: + definitions: list[TemporalNativeCallableBindingSchemaDefinition] = [] + registry = temporal_executor_strategy_registry() + for direction, patterns in ( + ("forward", registry.forward_patterns()), + ("reverse", registry.reverse_patterns()), + ): + for pattern in patterns: + primitive = pattern.row_pattern[0].primitive if pattern.row_pattern else "" + definitions.extend( + TemporalNativeCallableBindingSchemaDefinition( + callable_id=pattern.stable_native_callable_id, + primitive=primitive, + binding_kind=cast(TemporalNativeCallableBindingKind, access.binding_kind), + logical_name=access.access_name, + local_binding_index=local_index, + required=bool(access.required), + direction=cast(TemporalNativeCallableDirection, direction), + surface=pattern.surface, + version=int(pattern.tensor_binding_schema_version), + ) + for local_index, access in enumerate(pattern.program_accesses) + ) + for record in registered_transition_primitive_executor_records(): + if record.program_forward_status == "callable" and record.program_forward_symbol: + callable_id = record.program_forward_symbol + forward_inputs = tuple(record.program_forward_input_bindings) + forward_parameters = tuple(record.program_forward_parameter_bindings) + forward_outputs = tuple(record.program_forward_output_bindings) + if not forward_inputs and not forward_outputs: + raise RuntimeError( + "Registered transition forward primitive has no compiler-owned callable binding schema: " + f"primitive={record.primitive!r}; 
callable={callable_id!r}" + ) + definitions.extend( + TemporalNativeCallableBindingSchemaDefinition( + callable_id=callable_id, + primitive=record.primitive, + binding_kind="input", + logical_name=logical_name, + local_binding_index=local_index, + ) + for local_index, logical_name in enumerate(forward_inputs) + ) + definitions.extend( + TemporalNativeCallableBindingSchemaDefinition( + callable_id=callable_id, + primitive=record.primitive, + binding_kind="parameter", + logical_name=logical_name, + local_binding_index=local_index, + required=required, + ) + for local_index, (logical_name, required) in enumerate(forward_parameters) + ) + definitions.extend( + TemporalNativeCallableBindingSchemaDefinition( + callable_id=callable_id, + primitive=record.primitive, + binding_kind="output", + logical_name=logical_name, + local_binding_index=local_index, + required=required, + ) + for local_index, (logical_name, required) in enumerate(forward_outputs) + ) + if record.program_backward_status == "callable" and record.program_backward_symbol: + callable_id = str(record.program_reverse_native_callable) + if not callable_id: + continue + definitions.extend( + TemporalNativeCallableBindingSchemaDefinition( + callable_id=callable_id, + primitive=record.primitive, + binding_kind="input", + logical_name=logical_name, + local_binding_index=local_index, + direction="reverse", + ) + for local_index, logical_name in enumerate(record.reverse_input_bindings) + ) + definitions.extend( + TemporalNativeCallableBindingSchemaDefinition( + callable_id=callable_id, + primitive=record.primitive, + binding_kind="parameter", + logical_name=logical_name, + local_binding_index=local_index, + direction="reverse", + required=logical_name not in _OPTIONAL_TRANSITION_REVERSE_PARAMETERS, + ) + for local_index, logical_name in enumerate(record.parameter_bindings) + ) + definitions.extend( + TemporalNativeCallableBindingSchemaDefinition( + callable_id=callable_id, + primitive=record.primitive, + binding_kind="output", + logical_name=logical_name, + local_binding_index=local_index, + direction="reverse", + ) + for local_index, logical_name in enumerate(record.reverse_output_bindings) + ) + unique: dict[ + tuple[str, str, str, str, str, int, int], + TemporalNativeCallableBindingSchemaDefinition, + ] = {} + for definition in definitions: + unique.setdefault( + ( + definition.callable_id, + definition.direction, + definition.surface, + definition.binding_kind, + definition.logical_name, + int(definition.local_binding_index), + int(definition.version), + ), + definition, + ) + return tuple(unique.values()) + + +def temporal_native_callable_binding_schema_rows_tensor() -> torch.Tensor: + rows = [ + [ + index, + int(definition.callable_hash), + int(_NATIVE_CALLABLE_DIRECTION_OPCODE[definition.direction]), + int(_NATIVE_CALLABLE_SURFACE_OPCODE.get(definition.surface, 0)), + _native_callable_primitive_opcode(definition.primitive), + int(_NATIVE_CALLABLE_BINDING_KIND_OPCODE[definition.binding_kind]), + int(definition.logical_name_hash), + int(definition.local_binding_index), + int(definition.required), + int(definition.version), + ] + for index, definition in enumerate(temporal_native_callable_binding_schema_definitions()) + ] + if not rows: + return torch.empty((0, 10), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def temporal_native_callable_binding_schema_summaries() -> tuple[str, ...]: + return tuple(definition.summary for definition in temporal_native_callable_binding_schema_definitions()) + + +def 
temporal_native_callable_binding_schema_fingerprint() -> int: + checksum = 2166136261 + for definition in temporal_native_callable_binding_schema_definitions(): + for field in ( + definition.callable_id, + definition.direction, + definition.surface, + definition.primitive, + definition.binding_kind, + definition.logical_name, + str(int(definition.local_binding_index)), + str(int(definition.required)), + str(int(definition.version)), + ): + for byte in str(field).encode("utf-8"): + checksum ^= int(byte) + checksum = (checksum * 16777619) & 0xFFFFFFFF + checksum &= 0x7FFFFFFF + return int(checksum if checksum > 0 else 1) + + +_STATIC_TRANSITION_REVERSE_SEED_ROLE_DEFINITIONS = ( + TemporalTransitionReverseSeedRoleDefinition("grad_public_y", 1, "public_grad"), + TemporalTransitionReverseSeedRoleDefinition("grad_next_y", 2, "state_grad"), + TemporalTransitionReverseSeedRoleDefinition("grad_next_c", 3, "state_grad"), + TemporalTransitionReverseSeedRoleDefinition("grad_next_n", 4, "state_grad"), + TemporalTransitionReverseSeedRoleDefinition("grad_next_m", 5, "state_grad"), + TemporalTransitionReverseSeedRoleDefinition("grad_next_hc1", 6, "state_grad"), + TemporalTransitionReverseSeedRoleDefinition("grad_next_hc2", 7, "state_grad"), +) + + +def temporal_transition_reverse_seed_role_definitions( + extra_role_names: tuple[str, ...] = (), +) -> tuple[TemporalTransitionReverseSeedRoleDefinition, ...]: + dynamic_definitions = tuple( + _dynamic_transition_reverse_seed_role_definition(role_name) + for role_name in tuple(dict.fromkeys(str(role_name) for role_name in extra_role_names)) + if _static_transition_reverse_seed_role_definition(role_name) is None + ) + definitions = _STATIC_TRANSITION_REVERSE_SEED_ROLE_DEFINITIONS + dynamic_definitions + role_ids = [int(definition.role_id) for definition in definitions] + if len(role_ids) != len(set(role_ids)): + raise RuntimeError( + "Registered transition reverse seed role IDs must be unique: " + f"roles={tuple(definition.summary for definition in definitions)!r}" + ) + return definitions + + +def _static_transition_reverse_seed_role_definition( + role_name: str, +) -> TemporalTransitionReverseSeedRoleDefinition | None: + for definition in _STATIC_TRANSITION_REVERSE_SEED_ROLE_DEFINITIONS: + if definition.role_name == str(role_name): + return definition + return None + + +def _dynamic_transition_reverse_seed_role_definition( + role_name: str, +) -> TemporalTransitionReverseSeedRoleDefinition: + role_name = str(role_name) + if not role_name.startswith("grad_next_"): + raise RuntimeError(f"Unknown registered transition reverse seed role {role_name!r}") + return TemporalTransitionReverseSeedRoleDefinition( + role_name, + _dynamic_transition_reverse_seed_role_id(role_name), + "state_grad", + ) + + +def _dynamic_transition_reverse_seed_role_id(role_name: str) -> int: + return 1_000_000_000 + temporal_strategy_id_hash(f"transition_reverse_seed_role:{role_name}") + + +def temporal_transition_reverse_seed_role_id(role_name: str) -> int: + static_definition = _static_transition_reverse_seed_role_definition(role_name) + if static_definition is not None: + return int(static_definition.role_id) + return int(_dynamic_transition_reverse_seed_role_definition(role_name).role_id) + + +def temporal_transition_reverse_seed_role_rows_tensor( + extra_role_names: tuple[str, ...] 
= (), +) -> torch.Tensor: + return _transition_reverse_seed_role_rows_tensor(extra_role_names) + + +def _transition_reverse_seed_role_rows_tensor( + extra_role_names: tuple[str, ...], +) -> torch.Tensor: + rows = [ + [ + int(definition.role_id), + int(definition.role_name_hash), + int(_TRANSITION_REVERSE_SEED_KIND_OPCODE[definition.seed_kind]), + int(definition.version), + ] + for definition in temporal_transition_reverse_seed_role_definitions(extra_role_names) + ] + if not rows: + return torch.empty((0, 4), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def temporal_transition_reverse_seed_role_summaries( + extra_role_names: tuple[str, ...] = (), +) -> tuple[str, ...]: + return tuple( + definition.summary for definition in temporal_transition_reverse_seed_role_definitions(extra_role_names) + ) + + +def temporal_native_callable_output_summaries() -> tuple[str, ...]: + return tuple(definition.summary for definition in temporal_native_callable_output_definitions()) + + +def temporal_native_callable_output_contract_fingerprint() -> int: + checksum = 2166136261 + for definition in temporal_native_callable_output_definitions(): + for field in ( + definition.callable_id, + definition.direction, + definition.surface, + definition.primitive, + definition.output_name, + str(int(definition.output_index)), + definition.runtime_role, + definition.shape_kind, + definition.logical_index_source, + definition.init, + str(int(definition.version)), + ): + for byte in str(field).encode("utf-8"): + checksum ^= int(byte) + checksum = (checksum * 16777619) & 0xFFFFFFFF + checksum &= 0x7FFFFFFF + return int(checksum if checksum > 0 else 1) + + +def temporal_native_callable_transition_forward_output_definition( + *, + primitive: str, + output_name: str, + output_index: int | None = None, +) -> TemporalNativeCallableOutputDefinition: + record = transition_primitive_executor_record_for_lowered_primitive(str(primitive)) + if record is None or record.program_forward_status != "callable" or not record.program_forward_symbol: + raise RuntimeError(f"Transition primitive {primitive!r} has no registered forward native callable") + matches = tuple( + definition + for definition in temporal_native_callable_output_definitions() + if definition.callable_id == record.program_forward_symbol and definition.output_name == str(output_name) + ) + if len(matches) == 1: + return matches[0] + if output_index is not None: + matches = tuple( + definition + for definition in temporal_native_callable_output_definitions() + if definition.callable_id == record.program_forward_symbol + and int(definition.output_index) == int(output_index) + ) + if len(matches) != 1: + raise RuntimeError( + "Transition forward native callable has no unique compiler-owned output contract: " + f"primitive={primitive!r}; callable={record.program_forward_symbol!r}; output={output_name!r}; " + f"count={len(matches)}" + ) + return matches[0] + + +def temporal_native_callable_summaries() -> tuple[str, ...]: + return tuple(definition.summary for definition in temporal_native_callable_definitions()) + + +def temporal_native_callable_catalog_fingerprint() -> int: + checksum = 2166136261 + for definition in temporal_native_callable_definitions(): + for field in ( + definition.callable_id, + definition.category, + definition.direction, + definition.surface, + definition.primitive, + definition.implementation_symbol, + definition.display_name, + str(int(definition.version)), + *definition.cxx_entrypoints, + *definition.cxx_entrypoint_phases, + ): + 
for byte in str(field).encode("utf-8"): + checksum ^= int(byte) + checksum = (checksum * 16777619) & 0xFFFFFFFF + checksum &= 0x7FFFFFFF + return int(checksum if checksum > 0 else 1) + + +def temporal_native_callable_generated_header_text() -> str: + lines = [ + "// Generated by compiler/native_callables.py.", + "// Do not edit this catalog by hand; update the compiler registry and regenerate.", + f"// Catalog fingerprint: {temporal_native_callable_catalog_fingerprint()}", + "", + ] + _emit_forward_transition_catalog(lines) + _emit_forward_message_catalog(lines) + _emit_forward_readout_catalog(lines) + _emit_reverse_message_catalog(lines) + _emit_reverse_readout_catalog(lines) + _emit_reverse_transition_catalog(lines) + _emit_parameter_reducer_catalog(lines) + _emit_transition_trainable_reducer_catalog(lines) + return "\n".join(lines).rstrip() + "\n" + + +def validate_temporal_native_callable_generated_header(header_text: str) -> None: + expected = temporal_native_callable_generated_header_text() + if str(header_text) != expected: + raise RuntimeError( + "Registered temporal native callable header does not match the compiler-owned catalog. " + "Regenerate src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/" + "flat_bucket_registered_native_callables.cuh from temporal_native_callable_generated_header_text()." + ) + + +def validate_temporal_native_callable_catalog_coverage( + *, + native_callable_catalog_rows: torch.Tensor, + native_strategy_rows: torch.Tensor, + transition_primitive_callable_rows: torch.Tensor, +) -> None: + _require_catalog_rows(native_callable_catalog_rows) + catalog_hashes = {int(row[1]) for row in native_callable_catalog_rows.tolist()} + _require_hashes_in_catalog( + tuple(int(row[16]) for row in _rows(native_strategy_rows, columns=17)), + catalog_hashes, + "native_strategy_rows", + ) + transition_hashes: list[int] = [] + for row in _rows(transition_primitive_callable_rows, columns=6): + if int(row[4]): + transition_hashes.append(int(row[1])) + if int(row[5]): + transition_hashes.append(int(row[2])) + _require_hashes_in_catalog(tuple(transition_hashes), catalog_hashes, "transition_primitive_callable_rows") + + +def validate_temporal_native_callable_output_contract_coverage( + *, + native_callable_catalog_rows: torch.Tensor, + native_callable_output_rows: torch.Tensor, +) -> None: + _require_catalog_rows(native_callable_catalog_rows) + output_rows = _rows(native_callable_output_rows, columns=12) + catalog_hashes = {int(row[1]) for row in native_callable_catalog_rows.tolist()} + _require_hashes_in_catalog( + tuple(int(row[1]) for row in output_rows), + catalog_hashes, + "native_callable_output_rows", + ) + + +def validate_temporal_native_callable_binding_schema_coverage( + *, + native_callable_catalog_rows: torch.Tensor, + native_callable_binding_schema_rows: torch.Tensor, +) -> None: + _require_catalog_rows(native_callable_catalog_rows) + binding_rows = _rows(native_callable_binding_schema_rows, columns=10) + catalog_hashes = {int(row[1]) for row in native_callable_catalog_rows.tolist()} + _require_hashes_in_catalog( + tuple(int(row[1]) for row in binding_rows), + catalog_hashes, + "native_callable_binding_schema_rows", + ) + valid_directions = set(_NATIVE_CALLABLE_DIRECTION_OPCODE.values()) + valid_surfaces = set(_NATIVE_CALLABLE_SURFACE_OPCODE.values()) + valid_binding_kinds = set(_NATIVE_CALLABLE_BINDING_KIND_OPCODE.values()) + seen: set[tuple[int, int, int, int]] = set() + for row in binding_rows: + if int(row[2]) not in valid_directions or int(row[3]) not 
in valid_surfaces: + raise RuntimeError("Temporal native callable binding schema contains invalid direction or surface opcode") + if int(row[5]) not in valid_binding_kinds: + raise RuntimeError("Temporal native callable binding schema contains invalid binding kind opcode") + if int(row[6]) <= 0 or int(row[7]) < 0 or int(row[9]) != 1: + raise RuntimeError("Temporal native callable binding schema contains invalid logical name/index/version") + key = (int(row[1]), int(row[5]), int(row[6]), int(row[7])) + if key in seen: + raise RuntimeError("Temporal native callable binding schema contains duplicate callable binding rows") + seen.add(key) + + +def _native_callable_primitive_opcode(primitive: str) -> int: + if not primitive: + return 0 + return temporal_primitive_opcode(primitive) + + +def _strategy_display_name(strategy_id: str, version: int) -> str: + suffix = f".v{int(version)}" + return str(strategy_id)[: -len(suffix)] if str(strategy_id).endswith(suffix) else str(strategy_id) + + +def _strategy_cxx_entrypoints(pattern: object) -> tuple[str, ...]: + return tuple(str(entrypoint) for entrypoint in getattr(pattern, "cxx_entrypoints", ())) + + +def _strategy_cxx_entrypoint_phases(pattern: object) -> tuple[str, ...]: + return tuple(str(phase) for phase in getattr(pattern, "cxx_entrypoint_phases", ())) + + +def _display_name(definition: TemporalNativeCallableDefinition) -> str: + return definition.display_name or definition.callable_id + + +def _definitions( + *, + category: TemporalNativeCallableCategory, + direction: TemporalNativeCallableDirection | None = None, + surface: str | None = None, +) -> tuple[TemporalNativeCallableDefinition, ...]: + return tuple( + definition + for definition in temporal_native_callable_definitions() + if definition.category == category + and (direction is None or definition.direction == direction) + and (surface is None or definition.surface == surface) + ) + + +def _emit_forward_transition_catalog(lines: list[str]) -> None: + definitions = _definitions( + category="transition_primitive_forward", + direction="forward", + surface="transition", + ) + lines.extend( + ( + "#if defined(REGISTERED_TEMPORAL_NATIVE_FORWARD_TRANSITION_CATALOG)", + "inline const RegisteredTransitionForwardPrimitiveExecutor* registered_native_transition_forward_primitive_catalog_begin() {", + " static const RegisteredTransitionForwardPrimitiveExecutor kRegisteredNativeTransitionForwardPrimitiveCatalog[] = {", + ) + ) + for definition in definitions: + run_symbol = _single_cxx_entrypoint(definition) + lines.extend( + ( + " {", + f' registered_temporal_stable_id_hash_constexpr("{definition.callable_id}"),', + f' "{_display_name(definition)}",', + f" {run_symbol},", + " },", + ) + ) + lines.extend( + ( + " };", + " return kRegisteredNativeTransitionForwardPrimitiveCatalog;", + "}", + "", + "inline const RegisteredTransitionForwardPrimitiveExecutor* registered_native_transition_forward_primitive_catalog_end() {", + f" return registered_native_transition_forward_primitive_catalog_begin() + {len(definitions)};", + "}", + "#undef REGISTERED_TEMPORAL_NATIVE_FORWARD_TRANSITION_CATALOG", + "#endif", + "", + ) + ) + + +def _emit_forward_message_catalog(lines: list[str]) -> None: + definitions = _definitions(category="executor_strategy", direction="forward", surface="message") + lines.extend( + ( + "#if defined(REGISTERED_TEMPORAL_NATIVE_FORWARD_MESSAGE_CATALOG)", + "inline const RegisteredForwardMessageCarrierStrategy* registered_native_forward_message_catalog_begin() {", + " static const 
RegisteredForwardMessageCarrierStrategy kRegisteredNativeForwardMessageCatalog[] = {", + ) + ) + for definition in definitions: + phase_entrypoints = dict(zip(definition.cxx_entrypoint_phases, definition.cxx_entrypoints, strict=True)) + required_phases = ("bind", "recurrent_kv", "message") + missing = tuple(phase for phase in required_phases if phase not in phase_entrypoints) + unsupported = tuple( + phase + for phase in phase_entrypoints + if phase + not in ( + *required_phases, + "keyless_readout_message", + "direct_keyless_readout_message", + "stream_readout_message", + "stream_transition_input", + ) + ) + if missing or unsupported: + raise RuntimeError( + "Registered temporal forward message native callable has invalid C++ entrypoint phases: " + f"callable={definition.callable_id!r}; missing={missing!r}; unsupported={unsupported!r}; " + f"phases={definition.cxx_entrypoint_phases!r}" + ) + bind = phase_entrypoints["bind"] + recurrent_kv = phase_entrypoints["recurrent_kv"] + message = phase_entrypoints["message"] + keyless_readout_message = phase_entrypoints.get("keyless_readout_message", "nullptr") + direct_keyless_readout_message = phase_entrypoints.get("direct_keyless_readout_message", "nullptr") + stream_readout_message = phase_entrypoints.get("stream_readout_message", "nullptr") + stream_transition_input = phase_entrypoints.get("stream_transition_input", "nullptr") + lines.extend( + ( + " {", + f' registered_temporal_stable_id_hash_constexpr("{definition.callable_id}"),', + f' "{_display_name(definition)}",', + f" {bind},", + f" {recurrent_kv},", + f" {message},", + f" {keyless_readout_message},", + f" {direct_keyless_readout_message},", + f" {stream_readout_message},", + f" {stream_transition_input},", + " },", + ) + ) + lines.extend( + ( + " };", + " return kRegisteredNativeForwardMessageCatalog;", + "}", + "", + "inline const RegisteredForwardMessageCarrierStrategy* registered_native_forward_message_catalog_end() {", + f" return registered_native_forward_message_catalog_begin() + {len(definitions)};", + "}", + "#undef REGISTERED_TEMPORAL_NATIVE_FORWARD_MESSAGE_CATALOG", + "#endif", + "", + ) + ) + + +def _emit_forward_readout_catalog(lines: list[str]) -> None: + definitions = _definitions(category="executor_strategy", direction="forward", surface="readout") + lines.extend( + ( + "#if defined(REGISTERED_TEMPORAL_NATIVE_FORWARD_READOUT_CATALOG)", + "inline const RegisteredForwardReadoutStrategy* registered_native_forward_readout_catalog_begin() {", + " static const RegisteredForwardReadoutStrategy kRegisteredNativeForwardReadoutCatalog[] = {", + ) + ) + for definition in definitions: + phase_entrypoints = _cxx_entrypoints_by_phase( + definition, + ("bind", "message", "projection", "projection_into"), + ) + bind = phase_entrypoints["bind"] + message = phase_entrypoints["message"] + project = phase_entrypoints["projection"] + project_into = phase_entrypoints["projection_into"] + lines.extend( + ( + " {", + f' registered_temporal_stable_id_hash_constexpr("{definition.callable_id}"),', + f' "{_display_name(definition)}",', + f" {bind},", + f" {message},", + f" {project},", + f" {project_into},", + " },", + ) + ) + lines.extend( + ( + " };", + " return kRegisteredNativeForwardReadoutCatalog;", + "}", + "", + "inline const RegisteredForwardReadoutStrategy* registered_native_forward_readout_catalog_end() {", + f" return registered_native_forward_readout_catalog_begin() + {len(definitions)};", + "}", + "#undef REGISTERED_TEMPORAL_NATIVE_FORWARD_READOUT_CATALOG", + "#endif", + "", + ) + ) + 
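+
+# The remaining catalog emitters follow the same shape as the forward emitters
+# above: each appends one macro-guarded static C++ table to the generated header,
+# with every entry keyed by registered_temporal_stable_id_hash_constexpr over the
+# callable id and carrying its display name and resolved C++ entrypoints, plus
+# *_catalog_begin()/*_catalog_end() accessors sized from the Python-side
+# definitions. Only the record struct and the entrypoint layout differ per catalog.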
+ +def _emit_reverse_message_catalog(lines: list[str]) -> None: + definitions = _definitions(category="executor_strategy", direction="reverse", surface="message") + lines.extend( + ( + "#if defined(REGISTERED_TEMPORAL_NATIVE_REVERSE_MESSAGE_CATALOG)", + "inline const RegisteredReverseMessageStrategy* registered_native_reverse_message_catalog_begin() {", + " static const RegisteredReverseMessageStrategy kRegisteredNativeReverseMessageCatalog[] = {", + ) + ) + for definition in definitions: + phase_entrypoints = _cxx_entrypoints_by_phase( + definition, + ( + "recurrent_kv_backward", + "recurrent_message_backward", + "initial_recurrent_kv_backward", + "boundary_kv_backward", + "recurrent_kv_forward_recompute", + ), + ) + recurrent_kv = phase_entrypoints["recurrent_kv_backward"] + message = phase_entrypoints["recurrent_message_backward"] + initial_recurrent_kv = phase_entrypoints["initial_recurrent_kv_backward"] + boundary_kv = phase_entrypoints["boundary_kv_backward"] + recurrent_kv_forward_recompute = phase_entrypoints["recurrent_kv_forward_recompute"] + lines.extend( + ( + " {", + f' registered_temporal_stable_id_hash_constexpr("{definition.callable_id}"),', + f' "{_display_name(definition)}",', + f" {recurrent_kv},", + f" {message},", + f" {initial_recurrent_kv},", + f" {boundary_kv},", + f" {recurrent_kv_forward_recompute},", + " },", + ) + ) + lines.extend( + ( + " };", + " return kRegisteredNativeReverseMessageCatalog;", + "}", + "", + "inline const RegisteredReverseMessageStrategy* registered_native_reverse_message_catalog_end() {", + f" return registered_native_reverse_message_catalog_begin() + {len(definitions)};", + "}", + "#undef REGISTERED_TEMPORAL_NATIVE_REVERSE_MESSAGE_CATALOG", + "#endif", + "", + ) + ) + + +def _emit_reverse_readout_catalog(lines: list[str]) -> None: + definitions = _definitions(category="executor_strategy", direction="reverse", surface="readout") + lines.extend( + ( + "#if defined(REGISTERED_TEMPORAL_NATIVE_REVERSE_READOUT_CATALOG)", + "inline const RegisteredReverseReadoutStrategy* registered_native_reverse_readout_catalog_begin() {", + " static const RegisteredReverseReadoutStrategy kRegisteredNativeReverseReadoutCatalog[] = {", + ) + ) + for definition in definitions: + phase_entrypoints = _cxx_entrypoints_by_phase( + definition, + ("readout_backward", "output_message_backward"), + ) + readout = phase_entrypoints["readout_backward"] + output_message = phase_entrypoints["output_message_backward"] + lines.extend( + ( + " {", + f' registered_temporal_stable_id_hash_constexpr("{definition.callable_id}"),', + f' "{_display_name(definition)}",', + f" {readout},", + f" {output_message},", + " },", + ) + ) + lines.extend( + ( + " };", + " return kRegisteredNativeReverseReadoutCatalog;", + "}", + "", + "inline const RegisteredReverseReadoutStrategy* registered_native_reverse_readout_catalog_end() {", + f" return registered_native_reverse_readout_catalog_begin() + {len(definitions)};", + "}", + "#undef REGISTERED_TEMPORAL_NATIVE_REVERSE_READOUT_CATALOG", + "#endif", + "", + ) + ) + + +def _emit_reverse_transition_catalog(lines: list[str]) -> None: + definitions = _reverse_transition_strategy_definitions() + lines.extend( + ( + "#if defined(REGISTERED_TEMPORAL_NATIVE_REVERSE_TRANSITION_CATALOG)", + "inline const RegisteredTransitionReversePrimitiveExecutor* registered_native_transition_reverse_primitive_catalog_begin() {", + " static const RegisteredTransitionReversePrimitiveExecutor kRegisteredNativeTransitionReversePrimitiveCatalog[] = {", + ) + ) + for ( + 
definition, + primitive_backward_callable, + input_count, + min_param_count, + max_param_count, + output_count, + ) in definitions: + run_symbol = _single_cxx_entrypoint(definition) + lines.extend( + ( + " {", + f' registered_temporal_stable_id_hash_constexpr("{definition.callable_id}"),', + f' registered_temporal_stable_id_hash_constexpr("{primitive_backward_callable}"),', + f" {input_count},", + f" {min_param_count},", + f" {max_param_count},", + f" {output_count},", + f' "{_display_name(definition)}",', + f" {run_symbol},", + " },", + ) + ) + lines.extend( + ( + " };", + " return kRegisteredNativeTransitionReversePrimitiveCatalog;", + "}", + "", + "inline const RegisteredTransitionReversePrimitiveExecutor* registered_native_transition_reverse_primitive_catalog_end() {", + f" return registered_native_transition_reverse_primitive_catalog_begin() + {len(definitions)};", + "}", + "#undef REGISTERED_TEMPORAL_NATIVE_REVERSE_TRANSITION_CATALOG", + "#endif", + "", + ) + ) + + +def _emit_parameter_reducer_catalog(lines: list[str]) -> None: + definitions = _definitions(category="parameter_reducer", direction="reverse", surface="parameter_reduction") + lines.extend( + ( + "#if defined(REGISTERED_TEMPORAL_NATIVE_PARAMETER_REDUCER_CATALOG)", + "inline const RegisteredParameterReducerHandler* registered_native_parameter_reducer_catalog_begin() {", + " static const RegisteredParameterReducerHandler kRegisteredNativeParameterReducerCatalog[] = {", + ) + ) + for definition in definitions: + run_symbol = _single_cxx_entrypoint(definition) + lines.extend( + ( + " {", + f' registered_temporal_stable_id_hash_constexpr("{definition.callable_id}"),', + f' "{_display_name(definition)}",', + f" {run_symbol},", + " },", + ) + ) + lines.extend( + ( + " };", + " return kRegisteredNativeParameterReducerCatalog;", + "}", + "", + "inline const RegisteredParameterReducerHandler* registered_native_parameter_reducer_catalog_end() {", + f" return registered_native_parameter_reducer_catalog_begin() + {len(definitions)};", + "}", + "#undef REGISTERED_TEMPORAL_NATIVE_PARAMETER_REDUCER_CATALOG", + "#endif", + "", + ) + ) + + +def _emit_transition_trainable_reducer_catalog(lines: list[str]) -> None: + definitions = _definitions( + category="transition_trainable_reducer", + direction="reverse", + surface="parameter_reduction", + ) + lines.extend( + ( + "#if defined(REGISTERED_TEMPORAL_NATIVE_TRANSITION_TRAINABLE_REDUCER_CATALOG)", + "inline const RegisteredTransitionTrainableReducerHandler* registered_native_transition_trainable_reducer_catalog_begin() {", + " static const RegisteredTransitionTrainableReducerHandler kRegisteredNativeTransitionTrainableReducerCatalog[] = {", + ) + ) + for definition in definitions: + run_symbol = _single_cxx_entrypoint(definition) + lines.extend( + ( + " {", + f' registered_temporal_stable_id_hash_constexpr("{definition.callable_id}"),', + f' "{_display_name(definition)}",', + f" {run_symbol},", + " },", + ) + ) + lines.extend( + ( + " };", + " return kRegisteredNativeTransitionTrainableReducerCatalog;", + "}", + "", + "inline const RegisteredTransitionTrainableReducerHandler* registered_native_transition_trainable_reducer_catalog_end() {", + f" return registered_native_transition_trainable_reducer_catalog_begin() + {len(definitions)};", + "}", + "#undef REGISTERED_TEMPORAL_NATIVE_TRANSITION_TRAINABLE_REDUCER_CATALOG", + "#endif", + "", + ) + ) + + +def _reverse_transition_strategy_definitions() -> tuple[ + tuple[TemporalNativeCallableDefinition, str, int, int, int, int], + ..., +]: + rows: 
list[tuple[TemporalNativeCallableDefinition, str, int, int, int, int]] = [] + seen: set[tuple[str, str]] = set() + for definition in _definitions(category="executor_strategy", direction="reverse", surface="transition"): + record = transition_primitive_executor_record_for_lowered_primitive(definition.primitive) + if record is None or record.program_backward_status != "callable" or not record.program_backward_symbol: + continue + key = (definition.callable_id, record.program_backward_symbol) + if key in seen: + continue + seen.add(key) + max_param_count = len(record.parameter_bindings) + min_param_count = max_param_count - sum( + 1 for parameter in record.parameter_bindings if parameter in _OPTIONAL_TRANSITION_REVERSE_PARAMETERS + ) + rows.append( + ( + definition, + record.program_backward_symbol, + len(record.reverse_input_bindings), + min_param_count, + max_param_count, + len(record.reverse_output_bindings), + ) + ) + return tuple(rows) + + +def _cxx_entrypoints(definition: TemporalNativeCallableDefinition, *, count: int) -> tuple[str, ...]: + entrypoints = tuple(definition.cxx_entrypoints) + if len(entrypoints) != int(count): + raise RuntimeError( + "Registered temporal native callable has wrong C++ entrypoint arity: " + f"callable={definition.callable_id!r}; expected={int(count)}; actual={len(entrypoints)}" + ) + return entrypoints + + +def _cxx_entrypoints_by_phase( + definition: TemporalNativeCallableDefinition, + expected_phases: tuple[str, ...], +) -> dict[str, str]: + entrypoints = tuple(definition.cxx_entrypoints) + phases = tuple(definition.cxx_entrypoint_phases) + if len(entrypoints) != len(expected_phases) or len(phases) != len(expected_phases): + raise RuntimeError( + "Registered temporal native callable has wrong named C++ entrypoint contract: " + f"callable={definition.callable_id!r}; expected={expected_phases!r}; " + f"phases={phases!r}; entrypoints={entrypoints!r}" + ) + if set(phases) != set(expected_phases) or len(set(phases)) != len(phases): + raise RuntimeError( + "Registered temporal native callable has invalid C++ entrypoint phases: " + f"callable={definition.callable_id!r}; expected={expected_phases!r}; phases={phases!r}" + ) + return dict(zip(phases, entrypoints, strict=True)) + + +def _single_cxx_entrypoint(definition: TemporalNativeCallableDefinition) -> str: + return _cxx_entrypoints(definition, count=1)[0] + + +def _rows(tensor: torch.Tensor, *, columns: int) -> tuple[tuple[int, ...], ...]: + if ( + tensor.device.type != "cpu" + or tensor.dtype != torch.long + or tensor.dim() != 2 + or int(tensor.shape[1]) != columns + ): + raise RuntimeError(f"Temporal native callable validation requires CPU int64 rows with shape [N,{columns}]") + return tuple(tuple(int(item) for item in row) for row in tensor.tolist()) + + +def _require_catalog_rows(native_callable_catalog_rows: torch.Tensor) -> None: + _rows(native_callable_catalog_rows, columns=8) + hashes = tuple(int(row[1]) for row in native_callable_catalog_rows.tolist()) + if len(set(hashes)) != len(hashes): + raise RuntimeError("Temporal native callable catalog contains duplicate callable hashes") + + +def _require_hashes_in_catalog(hashes: tuple[int, ...], catalog_hashes: set[int], subject: str) -> None: + missing = tuple(sorted({int(item) for item in hashes if int(item) > 0 and int(item) not in catalog_hashes})) + if missing: + raise RuntimeError( + f"Temporal native callable catalog does not cover {subject}: " + ",".join(str(item) for item in missing) + ) + + +__all__ = [ + "TemporalNativeCallableDefinition", + 
"TemporalNativeCallableBindingSchemaDefinition", + "TemporalNativeCallableOutputDefinition", + "TemporalTransitionReverseSeedRoleDefinition", + "parameter_reducer_native_callable_id", + "temporal_native_callable_catalog_fingerprint", + "temporal_native_callable_catalog_rows_tensor", + "temporal_native_callable_binding_schema_definitions", + "temporal_native_callable_binding_schema_fingerprint", + "temporal_native_callable_binding_schema_rows_tensor", + "temporal_native_callable_binding_schema_summaries", + "temporal_native_callable_definitions", + "temporal_native_callable_generated_header_text", + "temporal_native_callable_output_contract_fingerprint", + "temporal_native_callable_output_definitions", + "temporal_native_callable_output_rows_tensor", + "temporal_native_callable_output_summaries", + "temporal_native_callable_summaries", + "temporal_native_callable_transition_forward_output_definition", + "temporal_strategy_id_hash", + "temporal_transition_reverse_seed_role_definitions", + "temporal_transition_reverse_seed_role_id", + "temporal_transition_reverse_seed_role_rows_tensor", + "temporal_transition_reverse_seed_role_summaries", + "transition_trainable_reducer_native_callable_id", + "validate_temporal_native_callable_generated_header", + "validate_temporal_native_callable_binding_schema_coverage", + "validate_temporal_native_callable_catalog_coverage", + "validate_temporal_native_callable_output_contract_coverage", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/compiler/primitive_dispatch.py b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/primitive_dispatch.py new file mode 100644 index 00000000..4e0dc3b6 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/primitive_dispatch.py @@ -0,0 +1,536 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from .executor_patterns import ( + surface_for_temporal_row, + temporal_executor_strategy_registry, +) +from .row_groups import ( + TEMPORAL_MESSAGE_BUCKET_ORDINAL, + TEMPORAL_PARAMETER_REDUCTION_BUCKET_ORDINAL, + TEMPORAL_READOUT_BUCKET_ORDINAL, +) +from .tables import TemporalPrimitiveTablePlan + + +_TRANSITION_AFFINE_PRIMITIVES = frozenset({"linear", "matmul"}) +_TRANSITION_NORMALIZATION_PRIMITIVES = frozenset({"norm", "layer_norm", "rms_norm", "identity", "norm_or_identity"}) + + +@dataclass(frozen=True) +class TemporalPrimitiveExecutorContract: + row_index: int | None + surface: str + primitive: str + bucket_ordinal: int | None + forward_executor: str + backward_executor: str + status: str + reason: str + required_tensor_roles: tuple[str, ...] = () + parameter_bindings: tuple[str, ...] = () + inputs: tuple[str, ...] = () + outputs: tuple[str, ...] 
= () + + @property + def summary(self) -> str: + bucket = "*" if self.bucket_ordinal is None else str(int(self.bucket_ordinal)) + row = "*" if self.row_index is None else str(int(self.row_index)) + params = ",".join(self.parameter_bindings) if self.parameter_bindings else "-" + return ( + f"row={row}" + f",surface={self.surface}" + f",primitive={self.primitive}" + f",bucket={bucket}" + f",forward={self.forward_executor}" + f",backward={self.backward_executor}" + f",status={self.status}" + f",reason={self.reason}" + f",params={params}" + ) + + @property + def blocker(self) -> str | None: + if self.status == "implemented": + return None + bucket = "*" if self.bucket_ordinal is None else str(int(self.bucket_ordinal)) + return f"primitive={self.primitive},bucket={bucket},reason={self.reason}" + + +@dataclass(frozen=True) +class TemporalFusedPrimitiveExecutorGroup: + group_name: str + surface: str + bucket_ordinal: int | None + row_indices: tuple[int, ...] + primitives: tuple[str, ...] + forward_executor: str + backward_executor: str + status: str + reason: str + parameter_bindings: tuple[str, ...] = () + + @property + def summary(self) -> str: + bucket = "*" if self.bucket_ordinal is None else str(int(self.bucket_ordinal)) + params = ",".join(self.parameter_bindings) if self.parameter_bindings else "-" + return ( + f"fusion_group={self.group_name}" + f",surface={self.surface}" + f",bucket={bucket}" + f",rows={','.join(str(row) for row in self.row_indices)}" + f",primitives={'+'.join(self.primitives)}" + f",forward={self.forward_executor}" + f",backward={self.backward_executor}" + f",status={self.status}" + f",reason={self.reason}" + f",params={params}" + ) + + @property + def blocker(self) -> str | None: + if self.status == "implemented": + return None + bucket = "*" if self.bucket_ordinal is None else str(int(self.bucket_ordinal)) + return f"fusion_group={self.group_name},surface={self.surface},bucket={bucket},reason={self.reason}" + + +@dataclass(frozen=True) +class TemporalPrimitiveExecutorPlan: + contracts: tuple[TemporalPrimitiveExecutorContract, ...] + fusion_groups: tuple[TemporalFusedPrimitiveExecutorGroup, ...] 
= () + + @property + def summaries(self) -> tuple[str, ...]: + return tuple(contract.summary for contract in self.contracts) + tuple( + group.summary for group in self.fusion_groups + ) + + @property + def blockers(self) -> tuple[str, ...]: + return tuple( + dict.fromkeys( + ( + *(blocker for contract in self.contracts if (blocker := contract.blocker) is not None), + *(blocker for group in self.fusion_groups if (blocker := group.blocker) is not None), + ) + ) + ) + + @property + def has_blockers(self) -> bool: + return bool(self.blockers) + + +def build_temporal_primitive_executor_plan( + table: TemporalPrimitiveTablePlan, +) -> TemporalPrimitiveExecutorPlan: + fusion_groups = _fusion_groups_for_table(table) + group_by_row_index = {int(row_index): group for group in fusion_groups for row_index in group.row_indices} + contracts: list[TemporalPrimitiveExecutorContract] = [] + for row_index, row in enumerate(table.primitive_rows): + contracts.append( + _contract_for_row( + row, + row_index=int(row_index), + fusion_group=group_by_row_index.get(int(row_index)), + ) + ) + contracts.extend(_missing_surface_contracts(table)) + return TemporalPrimitiveExecutorPlan(contracts=tuple(contracts), fusion_groups=fusion_groups) + + +def _fusion_groups_for_table(table: TemporalPrimitiveTablePlan) -> tuple[TemporalFusedPrimitiveExecutorGroup, ...]: + groups: list[TemporalFusedPrimitiveExecutorGroup] = [] + message = [ + (row_index, row) + for row_index, row in enumerate(table.primitive_rows) + if int(row.bucket_ordinal) == TEMPORAL_MESSAGE_BUCKET_ORDINAL and "surface=message" in row.flat_bucket_identity + ] + if message: + groups.append( + _fusion_group( + group_name="compiled_message_rule", + surface="message", + bucket_ordinal=TEMPORAL_MESSAGE_BUCKET_ORDINAL, + rows=tuple(message), + ) + ) + readout = [ + (row_index, row) + for row_index, row in enumerate(table.primitive_rows) + if int(row.bucket_ordinal) == TEMPORAL_READOUT_BUCKET_ORDINAL + and ("surface=readout" in row.flat_bucket_identity or "surface=readout_boundary" in row.flat_bucket_identity) + ] + if readout: + groups.append( + _fusion_group( + group_name="compiled_readout_boundary", + surface="readout", + bucket_ordinal=TEMPORAL_READOUT_BUCKET_ORDINAL, + rows=tuple(readout), + ) + ) + parameter_rows_by_bucket: dict[str, list[tuple[int, object]]] = {} + for row_index, row in enumerate(table.primitive_rows): + if ( + int(row.bucket_ordinal) != TEMPORAL_PARAMETER_REDUCTION_BUCKET_ORDINAL + or "surface=parameter_reduction" not in row.flat_bucket_identity + ): + continue + bucket_attr = next((value for key, value in row.attributes if key == "bucket_ordinal"), "*") + parameter_rows_by_bucket.setdefault(str(bucket_attr), []).append((row_index, row)) + for bucket_key, rows in sorted(parameter_rows_by_bucket.items(), key=lambda item: item[0]): + bucket_ordinal = None if bucket_key == "*" else int(bucket_key) + groups.append( + _fusion_group( + group_name="compiled_parameter_reduction", + surface="parameter_reduction", + bucket_ordinal=bucket_ordinal, + rows=tuple(rows), + ) + ) + transition_by_bucket: dict[int, list[tuple[int, object]]] = {} + for row_index, row in enumerate(table.primitive_rows): + if int(row.bucket_ordinal) < 0 or "surface=transition" not in row.flat_bucket_identity: + continue + transition_by_bucket.setdefault(int(row.bucket_ordinal), []).append((row_index, row)) + for bucket_ordinal, rows in sorted(transition_by_bucket.items()): + groups.append( + _fusion_group( + group_name="compiled_transition_block", + surface="transition", + 
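+ # one fused transition block per bucket ordinal, covering only that
+ # bucket's primitive rows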
bucket_ordinal=bucket_ordinal, + rows=tuple(rows), + ) + ) + return tuple(groups) + + +def _fusion_group( + *, + group_name: str, + surface: str, + bucket_ordinal: int | None, + rows: tuple[tuple[int, object], ...], +) -> TemporalFusedPrimitiveExecutorGroup: + row_values = tuple(row for _row_index, row in rows) + primitives = tuple(str(getattr(row, "primitive", "")) for _row_index, row in rows) + parameter_bindings = tuple( + dict.fromkeys( + str(parameter) + for _row_index, row in rows + for parameter in tuple(getattr(row, "parameter_inputs", ()) or ()) + ) + ) + bucket_value = -3 if bucket_ordinal is None and surface == "parameter_reduction" else bucket_ordinal + registry = temporal_executor_strategy_registry() + forward_pattern = ( + None + if bucket_value is None + else registry.match_forward( + surface=surface, + bucket_ordinal=int(bucket_value), + rows=row_values, + ) + ) + reverse_pattern = ( + None + if bucket_value is None + else _match_reverse_pattern_for_group( + registry, + surface=surface, + bucket_ordinal=int(bucket_value), + rows=row_values, + ) + ) + if surface == "parameter_reduction": + status = "implemented" + reason = "registered_reverse_executor_binds_parameter_reductions" + forward_executor = "not_applicable" + backward_executor = "registered_parameter_reduction_executor" + elif forward_pattern is not None and reverse_pattern is not None: + status = "implemented" + reason = "registered_executor_binding_group_implemented" + forward_executor = forward_pattern.executor_name + backward_executor = reverse_pattern.executor_name + else: + status = "missing_executor" + if forward_pattern is None and reverse_pattern is None: + reason = "fused_block_not_dispatched_by_primitive_executor" + elif forward_pattern is None: + reason = "registered_forward_executor_required" + else: + reason = "registered_backward_executor_required" + forward_executor = "unregistered" if forward_pattern is None else forward_pattern.executor_name + backward_executor = "unregistered" if reverse_pattern is None else reverse_pattern.executor_name + return TemporalFusedPrimitiveExecutorGroup( + group_name=group_name, + surface=surface, + bucket_ordinal=bucket_ordinal, + row_indices=tuple(int(row_index) for row_index, _row in rows), + primitives=primitives, + parameter_bindings=parameter_bindings, + forward_executor=forward_executor, + backward_executor=backward_executor, + status=status, + reason=reason, + ) + + +def _match_reverse_pattern_for_group( + registry: object, + *, + surface: str, + bucket_ordinal: int, + rows: tuple[object, ...], +) -> object | None: + match_reverse = getattr(registry, "match_reverse") + direct = match_reverse(surface=surface, bucket_ordinal=int(bucket_ordinal), rows=rows) + if direct is not None or surface != "transition": + return direct + for row in rows: + row_match = match_reverse(surface=surface, bucket_ordinal=int(bucket_ordinal), rows=(row,)) + if row_match is not None: + return row_match + return None + + +def _contract_for_row( + row: object, + *, + row_index: int, + fusion_group: TemporalFusedPrimitiveExecutorGroup | None, +) -> TemporalPrimitiveExecutorContract: + primitive = str(getattr(row, "primitive", "")) + bucket_ordinal = int(getattr(row, "bucket_ordinal", 0)) + surface = _surface_for_row(row) + inputs = tuple(str(item) for item in getattr(row, "inputs", ()) or ()) + outputs = tuple(str(item) for item in getattr(row, "outputs", ()) or ()) + parameter_bindings = tuple(str(item) for item in getattr(row, "parameter_inputs", ()) or ()) + if fusion_group is not 
None: + return TemporalPrimitiveExecutorContract( + row_index=row_index, + surface=surface, + primitive=primitive, + bucket_ordinal=bucket_ordinal, + forward_executor=fusion_group.forward_executor, + backward_executor=fusion_group.backward_executor, + status=fusion_group.status, + reason=( + "registered_executor_group_owns_primitive_row" + if fusion_group.status == "implemented" + else fusion_group.reason + ), + required_tensor_roles=_required_tensor_roles_for_surface(surface, primitive), + parameter_bindings=parameter_bindings, + inputs=inputs, + outputs=outputs, + ) + if bucket_ordinal == TEMPORAL_MESSAGE_BUCKET_ORDINAL: + return TemporalPrimitiveExecutorContract( + row_index=row_index, + surface=surface, + primitive=primitive, + bucket_ordinal=bucket_ordinal, + forward_executor="registered_message_executor_required", + backward_executor="registered_message_backward_executor_required", + status="missing_executor", + reason="message_primitive_row_not_covered_by_registered_executor_group", + required_tensor_roles=_required_tensor_roles_for_surface(surface, primitive), + parameter_bindings=parameter_bindings, + inputs=inputs, + outputs=outputs, + ) + if bucket_ordinal == TEMPORAL_READOUT_BUCKET_ORDINAL: + return TemporalPrimitiveExecutorContract( + row_index=row_index, + surface=surface, + primitive=primitive, + bucket_ordinal=bucket_ordinal, + forward_executor="registered_readout_executor_required", + backward_executor="registered_boundary_backward_executor_required", + status="missing_executor", + reason="readout_primitive_row_not_covered_by_registered_executor_group", + required_tensor_roles=_required_tensor_roles_for_surface(surface, primitive), + parameter_bindings=parameter_bindings, + inputs=inputs, + outputs=outputs, + ) + if bucket_ordinal == TEMPORAL_PARAMETER_REDUCTION_BUCKET_ORDINAL: + return TemporalPrimitiveExecutorContract( + row_index=row_index, + surface=surface, + primitive=primitive, + bucket_ordinal=bucket_ordinal, + forward_executor="not_applicable", + backward_executor="registered_parameter_binding_executor_required", + status="missing_executor", + reason="parameter_reduction_row_not_covered_by_registered_executor_group", + required_tensor_roles=_required_tensor_roles_for_surface(surface, primitive), + parameter_bindings=parameter_bindings, + inputs=inputs, + outputs=outputs, + ) + if primitive in _TRANSITION_AFFINE_PRIMITIVES: + return TemporalPrimitiveExecutorContract( + row_index=row_index, + surface=surface, + primitive=primitive, + bucket_ordinal=bucket_ordinal, + forward_executor="registered_transition_affine_executor", + backward_executor="registered_transition_affine_backward_executor", + status="implemented", + reason="registered_transition_executor_dispatches_affine_row", + required_tensor_roles=("input", "weight", "bias", "output"), + parameter_bindings=parameter_bindings, + inputs=inputs, + outputs=outputs, + ) + if primitive in _registered_transition_composite_primitives(): + return TemporalPrimitiveExecutorContract( + row_index=row_index, + surface=surface, + primitive=primitive, + bucket_ordinal=bucket_ordinal, + forward_executor="declared_composite_transition_executor", + backward_executor="declared_composite_transition_backward_executor", + status="implemented", + reason="registered_transition_executor_dispatches_declared_composite_row", + required_tensor_roles=("private_state", "transition_params", "projected_message", "output"), + parameter_bindings=parameter_bindings, + inputs=inputs, + outputs=outputs, + ) + if primitive in 
_TRANSITION_NORMALIZATION_PRIMITIVES: + return TemporalPrimitiveExecutorContract( + row_index=row_index, + surface=surface, + primitive=primitive, + bucket_ordinal=bucket_ordinal, + forward_executor="registered_transition_normalization_executor", + backward_executor="registered_transition_normalization_backward_executor", + status="implemented", + reason="registered_transition_executor_dispatches_normalization_row", + required_tensor_roles=("input", "weight", "bias", "output"), + parameter_bindings=parameter_bindings, + inputs=inputs, + outputs=outputs, + ) + return TemporalPrimitiveExecutorContract( + row_index=row_index, + surface=surface, + primitive=primitive, + bucket_ordinal=bucket_ordinal, + forward_executor="unregistered", + backward_executor="unregistered", + status="missing_executor", + reason="unregistered_temporal_primitive_executor", + parameter_bindings=parameter_bindings, + inputs=inputs, + outputs=outputs, + ) + + +def _required_tensor_roles_for_surface(surface: str, primitive: str) -> tuple[str, ...]: + if surface == "message": + return ("message_source", "message_parameter", "message_output") + if surface == "readout": + return ("public_state", "readout_parameter", "output") + if surface == "parameter_reduction": + return ("primitive_tape", "grad_output", "parameter_grad") + if surface == "transition" and primitive in _TRANSITION_AFFINE_PRIMITIVES: + return ("input", "weight", "bias", "output") + if surface == "transition" and primitive in _registered_transition_composite_primitives(): + return ("private_state", "transition_params", "projected_message", "output") + if surface == "transition" and primitive in _TRANSITION_NORMALIZATION_PRIMITIVES: + return ("input", "weight", "bias", "output") + return () + + +def _surface_for_row(row: object) -> str: + return surface_for_temporal_row(row) + + +def _registered_transition_composite_primitives() -> frozenset[str]: + registered_transition_primitives = frozenset( + row.primitive + for pattern in temporal_executor_strategy_registry().all_patterns() + if pattern.surface == "transition" + for row in pattern.row_pattern + ) + return registered_transition_primitives - _TRANSITION_AFFINE_PRIMITIVES - _TRANSITION_NORMALIZATION_PRIMITIVES + + +def _missing_surface_contracts( + table: TemporalPrimitiveTablePlan, +) -> tuple[TemporalPrimitiveExecutorContract, ...]: + bucket_count = max(1, int(table.bucket_count)) + has_message_rows = any("surface=message" in row.flat_bucket_identity for row in table.primitive_rows) + has_readout_rows = any( + "surface=readout" in row.flat_bucket_identity or "surface=readout_boundary" in row.flat_bucket_identity + for row in table.primitive_rows + ) + has_parameter_reduction_rows = any( + "surface=parameter_reduction" in row.flat_bucket_identity for row in table.primitive_rows + ) + contracts: list[TemporalPrimitiveExecutorContract] = [] + if not has_message_rows: + contracts.append( + TemporalPrimitiveExecutorContract( + row_index=None, + surface="message", + primitive="message", + bucket_ordinal=None, + forward_executor="registered_message_executor_required", + backward_executor="registered_message_backward_executor_required", + status="missing_executor", + reason="message_primitive_rows_missing_from_temporal_table", + required_tensor_roles=("sender_public", "receiver_query", "edge_table", "projected_message"), + ) + ) + if not has_readout_rows: + contracts.append( + TemporalPrimitiveExecutorContract( + row_index=None, + surface="readout", + primitive="readout_boundary", + bucket_ordinal=None, + 
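+ # row_index=None / bucket_ordinal=None mark this as a synthesized
+ # placeholder contract: no readout rows exist in the table, so the plan
+ # reports a missing_executor blocker instead of binding an executor.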
forward_executor="registered_readout_executor_required", + backward_executor="registered_boundary_backward_executor_required", + status="missing_executor", + reason="readout_boundary_rows_missing_from_temporal_table", + required_tensor_roles=( + "public_state", + "output_q", + "value_to_output_weight", + "output_cell_bias", + "output", + ), + ) + ) + if not has_parameter_reduction_rows: + contracts.append( + TemporalPrimitiveExecutorContract( + row_index=None, + surface="parameter_reduction", + primitive="parameter_reduction", + bucket_ordinal=None, + forward_executor="not_applicable", + backward_executor="registered_parameter_binding_executor_required", + status="missing_executor", + reason="parameter_reduction_rows_missing_from_temporal_table", + required_tensor_roles=("primitive_tape", "grad_output", "parameter_grad"), + parameter_bindings=tuple(f"bucket={bucket}" for bucket in range(bucket_count)), + ) + ) + return tuple(contracts) + + +__all__ = [ + "TemporalFusedPrimitiveExecutorGroup", + "TemporalPrimitiveExecutorContract", + "TemporalPrimitiveExecutorPlan", + "build_temporal_primitive_executor_plan", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/compiler/primitive_registry.py b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/primitive_registry.py new file mode 100644 index 00000000..54965930 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/primitive_registry.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class TemporalPrimitiveDefinition: + name: str + opcode: int + transition_tape_kind: str | None = None + full_tape_extra_state_factor: int = 0 + + +@dataclass(frozen=True) +class TemporalTableEnumDefinition: + name: str + opcode: int + + +_TEMPORAL_PRIMITIVE_DEFINITIONS = ( + TemporalPrimitiveDefinition("gated_logspace_recurrence", 1, "gated_logspace", 8), + TemporalPrimitiveDefinition("diag_rtu", 2, "diagonal_recurrence", 1), + TemporalPrimitiveDefinition("diagonal_recurrence", 2, "diagonal_recurrence", 1), + TemporalPrimitiveDefinition("linear", 10), + TemporalPrimitiveDefinition("matmul", 11), + TemporalPrimitiveDefinition("add", 12), + TemporalPrimitiveDefinition("attention_logits", 13), + TemporalPrimitiveDefinition("segment_softmax", 14), + TemporalPrimitiveDefinition("weighted_sum", 15), + TemporalPrimitiveDefinition("tanh", 16), + TemporalPrimitiveDefinition("mul", 18), + TemporalPrimitiveDefinition("concat", 19), + TemporalPrimitiveDefinition("readout_project", 20), + TemporalPrimitiveDefinition("reduction_boundary", 21), + TemporalPrimitiveDefinition("norm_or_identity", 30), + TemporalPrimitiveDefinition("norm", 31), + TemporalPrimitiveDefinition("layer_norm", 32), + TemporalPrimitiveDefinition("rms_norm", 33), + TemporalPrimitiveDefinition("identity", 34), + TemporalPrimitiveDefinition("normalize", 35), +) +_TEMPORAL_SURFACE_DEFINITIONS = ( + TemporalTableEnumDefinition("message", 1), + TemporalTableEnumDefinition("readout", 2), + TemporalTableEnumDefinition("readout_boundary", 3), + TemporalTableEnumDefinition("transition", 4), + TemporalTableEnumDefinition("parameter_reduction", 5), + TemporalTableEnumDefinition("runtime_policy", 6), +) +_TEMPORAL_BINDING_KIND_DEFINITIONS = ( + TemporalTableEnumDefinition("input", 0), + TemporalTableEnumDefinition("parameter", 1), + TemporalTableEnumDefinition("output", 2), +) + + +def temporal_primitive_definitions() -> tuple[TemporalPrimitiveDefinition, ...]: + return 
_TEMPORAL_PRIMITIVE_DEFINITIONS + + +def temporal_surface_definitions() -> tuple[TemporalTableEnumDefinition, ...]: + return _TEMPORAL_SURFACE_DEFINITIONS + + +def temporal_binding_kind_definitions() -> tuple[TemporalTableEnumDefinition, ...]: + return _TEMPORAL_BINDING_KIND_DEFINITIONS + + +def temporal_primitive_opcode(primitive: str) -> int: + definition = _primitive_definition(primitive) + if definition is None: + raise RuntimeError( + f"Temporal primitive registry has no opcode for primitive {primitive!r}; " + "register the primitive before lowering it into temporal rows" + ) + return int(definition.opcode) + + +def temporal_primitive_name_for_opcode(opcode: int) -> str: + matches = tuple( + definition.name for definition in _TEMPORAL_PRIMITIVE_DEFINITIONS if definition.opcode == int(opcode) + ) + if not matches: + raise RuntimeError( + f"Temporal primitive registry has no primitive name for opcode {int(opcode)}; " + "register the primitive opcode before decoding fused program rows" + ) + return matches[0] + + +def temporal_surface_opcode(surface: str) -> int: + definition = _enum_definition(_TEMPORAL_SURFACE_DEFINITIONS, surface) + if definition is None: + raise RuntimeError( + f"Temporal primitive registry has no surface opcode for {surface!r}; " + "register the scheduling surface before lowering tensor bindings" + ) + return int(definition.opcode) + + +def temporal_binding_kind_opcode(binding_kind: str) -> int: + definition = _enum_definition(_TEMPORAL_BINDING_KIND_DEFINITIONS, binding_kind) + if definition is None: + raise RuntimeError( + f"Temporal primitive registry has no binding-kind opcode for {binding_kind!r}; " + "register the binding kind before lowering executor bindings" + ) + return int(definition.opcode) + + +def temporal_transition_tape_kind(primitive: str) -> str | None: + definition = _primitive_definition(primitive) + return None if definition is None else definition.transition_tape_kind + + +def temporal_full_tape_extra_state_factor(primitive: str) -> int: + definition = _primitive_definition(primitive) + return 0 if definition is None else int(definition.full_tape_extra_state_factor) + + +def _primitive_definition(primitive: str) -> TemporalPrimitiveDefinition | None: + for definition in _TEMPORAL_PRIMITIVE_DEFINITIONS: + if definition.name == str(primitive): + return definition + return None + + +def _enum_definition( + definitions: tuple[TemporalTableEnumDefinition, ...], + name: str, +) -> TemporalTableEnumDefinition | None: + for definition in definitions: + if definition.name == str(name): + return definition + return None + + +__all__ = [ + "TemporalPrimitiveDefinition", + "TemporalTableEnumDefinition", + "temporal_binding_kind_definitions", + "temporal_binding_kind_opcode", + "temporal_full_tape_extra_state_factor", + "temporal_primitive_name_for_opcode", + "temporal_primitive_definitions", + "temporal_primitive_opcode", + "temporal_surface_definitions", + "temporal_surface_opcode", + "temporal_transition_tape_kind", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py new file mode 100644 index 00000000..4b6da77d --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py @@ -0,0 +1,2668 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Literal + +import torch + +from cortical.fabric.backend.cuda.transition_execution.registry import ( + 
registered_transition_primitive_executor_records, + transition_program_layer_blocker_codes, + transition_program_layer_missing_symbols, +) + +from cortical.fabric.backend.cuda.sequence_surface.compiler.backward_plan import ( + TemporalBackwardExecutablePlan, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.forward_plan import ( + TemporalForwardExecutablePlan, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.forward_program import ( + temporal_program_access_opcode, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.memory_plan import ( + TemporalMemoryLivenessPlan, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.native_callables import ( + temporal_native_callable_binding_schema_rows_tensor, + temporal_native_callable_catalog_rows_tensor, + temporal_native_callable_output_rows_tensor, + temporal_strategy_id_hash, + temporal_transition_reverse_seed_role_rows_tensor, + validate_temporal_native_callable_binding_schema_coverage, + validate_temporal_native_callable_catalog_coverage, + validate_temporal_native_callable_output_contract_coverage, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.executor_patterns import ( + TemporalForwardExecutorPattern, + TemporalReverseExecutorPattern, + temporal_executor_strategy_registry, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.primitive_registry import ( + temporal_primitive_opcode, + temporal_primitive_name_for_opcode, + temporal_surface_opcode, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.reverse_artifacts import ( + temporal_reverse_artifact_role_id, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.tables import ( + TemporalForwardExecutorRow, + TemporalPrimitiveTablePlan, + TemporalReverseExecutorRow, + temporal_forward_executor_rows, + temporal_reverse_executor_rows, +) + + +TemporalFusedCudaProgramStatus = Literal["legal", "blocked"] +TemporalRegisteredProgramExecutorStatus = Literal["active"] +TemporalReverseProgramStageKind = Literal[ + "output_grad_window", + "readout_message_kv_step", + "transition_step", + "recurrent_message_boundary_initial_kv_step", + "parameter_reducer_step", +] +TemporalReverseSpanOutputGroup = Literal["front", "boundary"] +TemporalReverseOutputRouteKind = Literal[ + "boundary_grad", + "carry_grad", + "readout_parameter_grad", + "query_parameter_grad", + "sender_kv_parameter_grad", + "transition_boundary", + "message_strategy_parameter_grad", + "output_grad", +] +TemporalForwardArtifactSurface = Literal["message", "transition", "readout", "runtime_policy"] +TemporalForwardArtifactMergeKind = Literal["identity_singleton", "concat_or_error", "sum_or_error"] +TemporalForwardOutputRouteKind = Literal[ + "readout_output_cells", + "readout_output_select", + "readout_output_concat", + "readout_output_sum", +] +TemporalReadoutMessageProducerConsumerStrategy = Literal[ + "materialized_recurrent_kv_after", + "stream_readout_from_message_projection", +] +TemporalReadoutMessageProducerConsumerStatus = Literal["active", "candidate", "blocked"] +TemporalMessageTransitionProducerConsumerStrategy = Literal[ + "materialized_recurrent_message", + "stream_message_to_transition_input", +] +TemporalMessageTransitionProducerConsumerStatus = Literal["active", "candidate", "blocked"] + +_REVERSE_STAGE_KIND_OPCODE = { + "output_grad_window": 1, + "readout_message_kv_step": 2, + "transition_step": 3, + "recurrent_message_boundary_initial_kv_step": 4, + "parameter_reducer_step": 5, +} +_REVERSE_STAGE_SURFACE_OPCODE = { + 
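+ # Local opcode space for reverse program stage rows; note these ordinals
+ # differ from temporal_surface_opcode(), which has no "none" entry and
+ # assigns different codes to transition/readout.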
"none": 0, + "message": 1, + "transition": 2, + "readout": 3, + "parameter_reduction": 4, +} +_HANDLER_CAPABILITY_FLAG = { + "message_carrier": 1, + "readout": 2, + "transition": 4, +} +_HANDLER_EFFECT_FLAG = { + "state_read": 1, + "parameter_read": 2, + "message_emit": 4, + "message_read": 8, + "output_emit": 16, + "state_write": 32, + "tape_policy": 64, + "grad_read": 128, + "parameter_grad_emit": 256, +} +_DIRECTION_OPCODE = { + "forward": 1, + "reverse": 2, +} +_REVERSE_SPAN_OUTPUT_GROUP_OPCODE: dict[TemporalReverseSpanOutputGroup, int] = { + "front": 1, + "boundary": 2, +} +_REVERSE_OUTPUT_ROUTE_KIND_OPCODE: dict[TemporalReverseOutputRouteKind, int] = { + "boundary_grad": 1, + "carry_grad": 2, + "readout_parameter_grad": 3, + "query_parameter_grad": 4, + "sender_kv_parameter_grad": 5, + "transition_boundary": 6, + "message_strategy_parameter_grad": 7, + "output_grad": 8, +} +_FORWARD_ARTIFACT_MERGE_KIND_OPCODE: dict[TemporalForwardArtifactMergeKind, int] = { + "identity_singleton": 1, + "concat_or_error": 2, + "sum_or_error": 3, +} +_FORWARD_OUTPUT_ROUTE_KIND_OPCODE: dict[TemporalForwardOutputRouteKind, int] = { + "readout_output_cells": 1, + "readout_output_select": 2, + "readout_output_concat": 3, + "readout_output_sum": 4, +} +_READOUT_MESSAGE_PRODUCER_CONSUMER_STRATEGY_OPCODE: dict[ + TemporalReadoutMessageProducerConsumerStrategy, + int, +] = { + "materialized_recurrent_kv_after": 1, + "stream_readout_from_message_projection": 2, +} +_READOUT_MESSAGE_PRODUCER_CONSUMER_STATUS_OPCODE: dict[ + TemporalReadoutMessageProducerConsumerStatus, + int, +] = { + "active": 1, + "candidate": 2, + "blocked": 3, +} +_READOUT_MESSAGE_PRODUCER_CONSUMER_BLOCKER_OPCODE = { + "": 0, + "pending_registered_readout_streaming_body": 1, + "missing_message_or_readout_executor": 2, + "missing_forward_output_route": 3, + "missing_required_streaming_bindings": 4, + "cost_rejected_current_code_regression": 5, +} +_READOUT_MESSAGE_PRODUCER_CONSUMER_ROLE_MASK = { + "input_k": 1, + "input_v": 2, + "recurrent_k_after": 4, + "recurrent_v_after": 8, + "recurrent_hidden": 16, + "recurrent_kv_weight": 32, + "readout_output_query": 64, + "output_route_rows": 128, + "memory_liveness_rows": 256, +} +_MESSAGE_TRANSITION_PRODUCER_CONSUMER_STRATEGY_OPCODE: dict[ + TemporalMessageTransitionProducerConsumerStrategy, + int, +] = { + "materialized_recurrent_message": 1, + "stream_message_to_transition_input": 2, +} +_MESSAGE_TRANSITION_PRODUCER_CONSUMER_STATUS_OPCODE: dict[ + TemporalMessageTransitionProducerConsumerStatus, + int, +] = { + "active": 1, + "candidate": 2, + "blocked": 3, +} +_MESSAGE_TRANSITION_PRODUCER_CONSUMER_BLOCKER_OPCODE = { + "": 0, + "missing_message_or_transition_executor": 1, + "multiple_transition_consumers_need_merge_rows": 2, + "receiver_count_mismatch": 3, + "pending_direct_chunk_body": 4, + "cost_rejected_current_code_regression": 5, +} +_MESSAGE_TRANSITION_PRODUCER_CONSUMER_ROLE_MASK = { + "input_k": 1, + "input_v": 2, + "recurrent_hidden": 4, + "recurrent_msg": 8, + "transition_aggregate_binding": 16, + "memory_liveness_rows": 32, + "physical_strategy_rows": 64, +} +_REVERSE_FRONT_OUTPUT_NAMES = ( + "grad_boundary_direct", + "grad_recurrent_hidden_backend_direct", + "grad_value_to_output_weight", + "grad_output_cell_bias", + "grad_output_q", + "grad_input_k_from_output", + "grad_input_v_from_output", + "grad_recurrent_hidden_from_kv_graph_order", + "grad_recurrent_kv_weight_graph_order", +) +_REVERSE_LOCAL_ONLY_FRONT_OUTPUT_NAMES = frozenset( + { + "grad_recurrent_hidden_backend_direct", + 
"grad_input_k_from_output", + "grad_input_v_from_output", + "grad_recurrent_hidden_from_kv_graph_order", + } +) +_BASE_REVERSE_OUTPUT_ROUTES: tuple[ + tuple[TemporalReverseOutputRouteKind, str, TemporalReverseSpanOutputGroup, str], + ..., +] = ( + ("boundary_grad", "direct_boundary", "front", "grad_boundary_direct"), + ("carry_grad", "direct_recurrent_hidden_backend", "front", "grad_recurrent_hidden_backend_direct"), + ("readout_parameter_grad", "value_to_output_weight", "front", "grad_value_to_output_weight"), + ("readout_parameter_grad", "output_cell_bias", "front", "grad_output_cell_bias"), + ("query_parameter_grad", "output_query", "front", "grad_output_q"), + ("output_grad", "input_k_from_output", "front", "grad_input_k_from_output"), + ("output_grad", "input_v_from_output", "front", "grad_input_v_from_output"), + ("output_grad", "recurrent_hidden_from_kv", "front", "grad_recurrent_hidden_from_kv_graph_order"), + ("sender_kv_parameter_grad", "recurrent_output_kv_weight", "front", "grad_recurrent_kv_weight_graph_order"), + ("transition_boundary", "recurrent_query", "boundary", "grad_recurrent_q_backend"), + ("transition_boundary", "boundary_projection", "boundary", "grad_boundary_from_projection_raw"), + ("sender_kv_parameter_grad", "boundary_input_kv_weight", "boundary", "grad_input_kv_weight"), + ("sender_kv_parameter_grad", "boundary_input_kv_grouped_flag", "boundary", "input_kv_grouped_flag"), + ("carry_grad", "hidden_graph_order_before", "boundary", "grad_hidden_graph_order"), + ( + "sender_kv_parameter_grad", + "initial_recurrent_kv_weight", + "boundary", + "grad_initial_recurrent_kv_weight_graph_order", + ), +) +_REVERSE_BOUNDARY_OUTPUT_NAMES = ( + "grad_recurrent_q_backend", + "grad_boundary_from_projection_raw", + "grad_input_kv_weight", + "input_kv_grouped_flag", + "grad_hidden_graph_order", + "grad_initial_recurrent_kv_weight_graph_order", +) +_FORWARD_ARTIFACT_ROLES_BY_SURFACE: tuple[tuple[TemporalForwardArtifactSurface, tuple[str, ...]], ...] 
= ( + ( + "runtime_policy", + ( + "boundary_step", + "cells_prev", + ), + ), + ( + "message", + ( + "input_k", + "input_v", + "recurrent_k_before", + "recurrent_v_before", + "recurrent_k", + "recurrent_v", + "recurrent_hidden_before_backend_order", + "recurrent_hidden_backend_order", + "recurrent_msg_backend_order", + ), + ), + ( + "transition", + ("transition_state_before",), + ), + ( + "readout", + ( + "output_msg", + "output_cells", + ), + ), +) +_RECOMPUTABLE_FORWARD_ARTIFACT_ROLES_BY_SURFACE: dict[TemporalForwardArtifactSurface, frozenset[str]] = { + "message": frozenset({"recurrent_k", "recurrent_v", "recurrent_k_before", "recurrent_v_before"}), +} + + +def _forward_artifact_route_required( + *, + surface: TemporalForwardArtifactSurface, + artifact_role: str, +) -> bool: + return str(artifact_role) not in _RECOMPUTABLE_FORWARD_ARTIFACT_ROLES_BY_SURFACE.get(surface, frozenset()) + + +@dataclass(frozen=True) +class TemporalForwardArtifactRouteRow: + row_index: int + surface: TemporalForwardArtifactSurface + executor_row_index: int + executor_id: int + bucket_ordinal: int + artifact_role: str + logical_name: str + required: bool = True + schema_version: int = 1 + + @property + def role_id(self) -> int: + return temporal_reverse_artifact_role_id(self.artifact_role) + + @property + def logical_id(self) -> int: + return temporal_strategy_id_hash(self.logical_name) + + @property + def row(self) -> list[int]: + return [ + int(self.row_index), + int(temporal_surface_opcode(self.surface)), + int(self.executor_row_index), + int(self.executor_id), + int(self.bucket_ordinal), + int(self.role_id), + int(self.logical_id), + int(self.required), + int(self.schema_version), + 0, + ] + + @property + def summary(self) -> str: + return ( + f"row={int(self.row_index)},surface={self.surface},executor_row={int(self.executor_row_index)}," + f"executor_id={int(self.executor_id)},bucket={int(self.bucket_ordinal)}," + f"artifact={self.artifact_role},logical={self.logical_name},required={int(self.required)}" + ) + + +@dataclass(frozen=True) +class TemporalForwardArtifactMergeRow: + row_index: int + surface: TemporalForwardArtifactSurface + bucket_ordinal: int + artifact_role: str + merge_kind: TemporalForwardArtifactMergeKind + output_route: str + producer_route_row_index: int + producer_executor_row_index: int + producer_executor_id: int + required: bool = True + schema_version: int = 1 + + @property + def role_id(self) -> int: + return temporal_reverse_artifact_role_id(self.artifact_role) + + @property + def output_route_id(self) -> int: + return temporal_strategy_id_hash(self.output_route) + + @property + def row(self) -> list[int]: + return [ + int(self.row_index), + int(temporal_surface_opcode(self.surface)), + int(self.bucket_ordinal), + int(self.role_id), + int(_FORWARD_ARTIFACT_MERGE_KIND_OPCODE[self.merge_kind]), + int(self.output_route_id), + int(self.producer_route_row_index), + int(self.producer_executor_row_index), + int(self.producer_executor_id), + int(self.required), + int(self.schema_version), + 0, + ] + + @property + def summary(self) -> str: + return ( + f"row={int(self.row_index)},surface={self.surface},bucket={int(self.bucket_ordinal)}," + f"artifact={self.artifact_role},merge={self.merge_kind},output_route={self.output_route}," + f"producer_route_row={int(self.producer_route_row_index)}," + f"producer_executor_row={int(self.producer_executor_row_index)}," + f"producer_executor_id={int(self.producer_executor_id)},required={int(self.required)}" + ) + + +@dataclass(frozen=True) +class 
TemporalForwardOutputRouteRow: + row_index: int + route_kind: TemporalForwardOutputRouteKind + surface: Literal["readout"] + executor_row_index: int + executor_id: int + bucket_ordinal: int + output_role: str + output_offset: int = 0 + required: bool = True + schema_version: int = 1 + + @property + def output_role_id(self) -> int: + return temporal_strategy_id_hash(self.output_role) + + @property + def row(self) -> list[int]: + return [ + int(self.row_index), + int(_FORWARD_OUTPUT_ROUTE_KIND_OPCODE[self.route_kind]), + int(temporal_surface_opcode(self.surface)), + int(self.executor_row_index), + int(self.executor_id), + int(self.bucket_ordinal), + int(self.output_role_id), + int(self.required), + int(self.schema_version), + int(self.output_offset), + ] + + @property + def summary(self) -> str: + return ( + f"row={int(self.row_index)},kind={self.route_kind},surface={self.surface}," + f"executor_row={int(self.executor_row_index)},executor_id={int(self.executor_id)}," + f"bucket={int(self.bucket_ordinal)},output={self.output_role}," + f"output_offset={int(self.output_offset)},required={int(self.required)}" + ) + + +@dataclass(frozen=True) +class TemporalReadoutMessageProducerConsumerRow: + row_index: int + strategy: TemporalReadoutMessageProducerConsumerStrategy + status: TemporalReadoutMessageProducerConsumerStatus + executable: bool + producer_executor_row_index: int + producer_executor_id: int + producer_bucket_ordinal: int + consumer_executor_row_index: int + consumer_executor_id: int + consumer_bucket_ordinal: int + output_route_row_index: int + required_role_mask: int + blocker: str = "" + schema_version: int = 1 + + @property + def row(self) -> list[int]: + return [ + int(self.row_index), + int(self.schema_version), + int(_READOUT_MESSAGE_PRODUCER_CONSUMER_STRATEGY_OPCODE[self.strategy]), + int(_READOUT_MESSAGE_PRODUCER_CONSUMER_STATUS_OPCODE[self.status]), + int(bool(self.executable)), + int(temporal_surface_opcode("message")), + int(self.producer_executor_row_index), + int(self.producer_executor_id), + int(self.producer_bucket_ordinal), + int(temporal_surface_opcode("readout")), + int(self.consumer_executor_row_index), + int(self.consumer_executor_id), + int(self.consumer_bucket_ordinal), + int(self.output_route_row_index), + int(self.required_role_mask), + int(_READOUT_MESSAGE_PRODUCER_CONSUMER_BLOCKER_OPCODE[self.blocker]), + ] + + @property + def summary(self) -> str: + return ( + f"row={int(self.row_index)},strategy={self.strategy},status={self.status}," + f"executable={int(bool(self.executable))}," + f"producer=message:{int(self.producer_executor_row_index)}:{int(self.producer_executor_id)}:" + f"{int(self.producer_bucket_ordinal)}," + f"consumer=readout:{int(self.consumer_executor_row_index)}:{int(self.consumer_executor_id)}:" + f"{int(self.consumer_bucket_ordinal)}," + f"output_route_row={int(self.output_route_row_index)}," + f"required_role_mask={int(self.required_role_mask)}," + f"blocker={self.blocker or '-'}" + ) + + +@dataclass(frozen=True) +class TemporalReadoutMessageProducerConsumerPlan: + rows: tuple[TemporalReadoutMessageProducerConsumerRow, ...] 
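+ # The selected_strategy/streaming_status/reason labels below are echoed
+ # verbatim into review_summary, and fingerprint is exactly that summary,
+ # so changing a label changes the compiler-owned plan fingerprint.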
+ selected_strategy: str + streaming_status: str + reason: str + + @property + def summaries(self) -> tuple[str, ...]: + return tuple(row.summary for row in self.rows) + + @property + def review_summary(self) -> tuple[str, ...]: + return ( + "readout_message_producer_consumer_plan=compiler_owned", + f"row_count={len(self.rows)}", + f"selected_strategy={self.selected_strategy}", + f"streaming_status={self.streaming_status}", + f"reason={self.reason}", + *self.summaries, + ) + + @property + def fingerprint(self) -> tuple[str, ...]: + return self.review_summary + + +@dataclass(frozen=True) +class TemporalMessageTransitionProducerConsumerRow: + row_index: int + strategy: TemporalMessageTransitionProducerConsumerStrategy + status: TemporalMessageTransitionProducerConsumerStatus + executable: bool + producer_executor_row_index: int + producer_executor_id: int + producer_bucket_ordinal: int + consumer_executor_row_index: int + consumer_executor_id: int + consumer_bucket_ordinal: int + aggregate_access_opcode: int + required_role_mask: int + blocker: str = "" + schema_version: int = 1 + + @property + def row(self) -> list[int]: + return [ + int(self.row_index), + int(self.schema_version), + int(_MESSAGE_TRANSITION_PRODUCER_CONSUMER_STRATEGY_OPCODE[self.strategy]), + int(_MESSAGE_TRANSITION_PRODUCER_CONSUMER_STATUS_OPCODE[self.status]), + int(bool(self.executable)), + int(temporal_surface_opcode("message")), + int(self.producer_executor_row_index), + int(self.producer_executor_id), + int(self.producer_bucket_ordinal), + int(temporal_surface_opcode("transition")), + int(self.consumer_executor_row_index), + int(self.consumer_executor_id), + int(self.consumer_bucket_ordinal), + int(self.aggregate_access_opcode), + int(self.required_role_mask), + int(_MESSAGE_TRANSITION_PRODUCER_CONSUMER_BLOCKER_OPCODE[self.blocker]), + ] + + @property + def summary(self) -> str: + return ( + f"row={int(self.row_index)},strategy={self.strategy},status={self.status}," + f"executable={int(bool(self.executable))}," + f"producer=message:{int(self.producer_executor_row_index)}:{int(self.producer_executor_id)}:" + f"{int(self.producer_bucket_ordinal)}," + f"consumer=transition:{int(self.consumer_executor_row_index)}:{int(self.consumer_executor_id)}:" + f"{int(self.consumer_bucket_ordinal)}," + f"aggregate_access_opcode={int(self.aggregate_access_opcode)}," + f"required_role_mask={int(self.required_role_mask)}," + f"blocker={self.blocker or '-'}" + ) + + +@dataclass(frozen=True) +class TemporalMessageTransitionProducerConsumerPlan: + rows: tuple[TemporalMessageTransitionProducerConsumerRow, ...] 
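+ # Counterpart of the readout plan above for the message -> transition
+ # hand-off: each row records the message-side producer executor, the
+ # transition-side consumer executor, the aggregate access opcode, the
+ # required-role mask, and a blocker opcode.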
+ selected_strategy: str + streaming_status: str + reason: str + + @property + def summaries(self) -> tuple[str, ...]: + return tuple(row.summary for row in self.rows) + + @property + def review_summary(self) -> tuple[str, ...]: + return ( + "message_transition_producer_consumer_plan=compiler_owned", + f"row_count={len(self.rows)}", + f"selected_strategy={self.selected_strategy}", + f"streaming_status={self.streaming_status}", + f"reason={self.reason}", + *self.summaries, + ) + + @property + def fingerprint(self) -> tuple[str, ...]: + return self.review_summary + + +@dataclass(frozen=True) +class TemporalReverseArtifactConsumerRouteRow: + row_index: int + surface: str + reverse_executor_row_index: int + reverse_executor_id: int + bucket_ordinal: int + artifact_role: str + forward_artifact_route_row_index: int + forward_executor_row_index: int + forward_executor_id: int + required: bool = True + schema_version: int = 1 + + @property + def artifact_role_id(self) -> int: + return temporal_reverse_artifact_role_id(self.artifact_role) + + @property + def row(self) -> list[int]: + return [ + int(self.row_index), + int(temporal_surface_opcode(self.surface)), + int(self.reverse_executor_row_index), + int(self.reverse_executor_id), + int(self.bucket_ordinal), + int(self.artifact_role_id), + int(self.forward_artifact_route_row_index), + int(self.forward_executor_row_index), + int(self.forward_executor_id), + int(self.required), + int(self.schema_version), + 0, + ] + + @property + def summary(self) -> str: + return ( + f"row={int(self.row_index)},surface={self.surface}," + f"reverse_executor_row={int(self.reverse_executor_row_index)}," + f"reverse_executor_id={int(self.reverse_executor_id)},bucket={int(self.bucket_ordinal)}," + f"artifact={self.artifact_role},forward_route_row={int(self.forward_artifact_route_row_index)}," + f"forward_executor_row={int(self.forward_executor_row_index)}," + f"forward_executor_id={int(self.forward_executor_id)},required={int(self.required)}" + ) + + +@dataclass(frozen=True) +class TemporalReverseParameterReducerRouteRow: + row_index: int + route_kind: TemporalReverseOutputRouteKind + target_role: str + source_group: TemporalReverseSpanOutputGroup + source_logical_name: str + surface: str + executor_row_index: int + executor_id: int + bucket_ordinal: int + required: bool = True + schema_version: int = 1 + + @property + def target_role_id(self) -> int: + return temporal_strategy_id_hash(self.target_role) + + @property + def source_role_id(self) -> int: + return temporal_reverse_span_output_role_id(self.source_logical_name) + + @property + def row(self) -> list[int]: + return [ + int(self.row_index), + int(_REVERSE_OUTPUT_ROUTE_KIND_OPCODE[self.route_kind]), + int(self.target_role_id), + int(_REVERSE_SPAN_OUTPUT_GROUP_OPCODE[self.source_group]), + int(self.source_role_id), + int(temporal_surface_opcode(self.surface)), + int(self.executor_row_index), + int(self.executor_id), + int(self.bucket_ordinal), + int(self.required), + int(self.schema_version), + 0, + ] + + @property + def summary(self) -> str: + return ( + f"row={int(self.row_index)},kind={self.route_kind},target={self.target_role}," + f"source_group={self.source_group},source={self.source_logical_name},surface={self.surface}," + f"executor_row={int(self.executor_row_index)},executor_id={int(self.executor_id)}," + f"bucket={int(self.bucket_ordinal)},required={int(self.required)}" + ) + + +@dataclass(frozen=True) +class TemporalReverseProgramStageRow: + stage_index: int + stage_kind: TemporalReverseProgramStageKind 
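+ # stage_kind/surface are lowered to integers via _REVERSE_STAGE_KIND_OPCODE
+ # and _REVERSE_STAGE_SURFACE_OPCODE when the stage plan builds its rows
+ # tensor.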
+ surface: str + executor_row_index: int + executor_id: int + primitive_row_start: int + primitive_row_count: int + bucket_ordinal: int + dependency_mask: int + memory_scope: str + + @property + def summary(self) -> str: + return ( + f"stage={int(self.stage_index)}" + f",kind={self.stage_kind}" + f",surface={self.surface}" + f",executor_row={int(self.executor_row_index)}" + f",executor_id={int(self.executor_id)}" + f",primitive_start={int(self.primitive_row_start)}" + f",primitive_count={int(self.primitive_row_count)}" + f",bucket={int(self.bucket_ordinal)}" + f",deps={int(self.dependency_mask)}" + f",memory={self.memory_scope}" + ) + + +@dataclass(frozen=True) +class TemporalReverseProgramStagePlan: + stages: tuple[TemporalReverseProgramStageRow, ...] + + @property + def rows(self) -> torch.Tensor: + rows = [ + [ + int(stage.stage_index), + _REVERSE_STAGE_KIND_OPCODE[stage.stage_kind], + _REVERSE_STAGE_SURFACE_OPCODE.get(stage.surface, 0), + int(stage.executor_row_index), + int(stage.executor_id), + int(stage.primitive_row_start), + int(stage.primitive_row_count), + int(stage.bucket_ordinal), + int(stage.dependency_mask), + 0, + ] + for stage in self.stages + ] + if not rows: + return torch.empty((0, 10), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + @property + def summaries(self) -> tuple[str, ...]: + return tuple(stage.summary for stage in self.stages) + + +def reverse_program_stage_opcode(stage_kind: TemporalReverseProgramStageKind) -> int: + return int(_REVERSE_STAGE_KIND_OPCODE[stage_kind]) + + +@dataclass(frozen=True) +class TemporalReverseSpanOutputRow: + row_index: int + group: TemporalReverseSpanOutputGroup + logical_name: str + local_slot: int + required: bool = True + schema_version: int = 1 + + @property + def role_id(self) -> int: + return temporal_strategy_id_hash(self.logical_name) + + @property + def row(self) -> list[int]: + return [ + int(self.row_index), + _REVERSE_SPAN_OUTPUT_GROUP_OPCODE[self.group], + int(self.role_id), + int(self.local_slot), + int(self.required), + int(self.schema_version), + ] + + @property + def summary(self) -> str: + return ( + f"row={int(self.row_index)},group={self.group},logical={self.logical_name}," + f"role={int(self.role_id)},slot={int(self.local_slot)},required={int(self.required)}" + ) + + +def temporal_reverse_span_output_role_id(logical_name: str) -> int: + return int(temporal_strategy_id_hash(str(logical_name))) + + +def temporal_reverse_span_output_group_opcode(group: TemporalReverseSpanOutputGroup) -> int: + return int(_REVERSE_SPAN_OUTPUT_GROUP_OPCODE[group]) + + +def temporal_reverse_span_output_rows() -> tuple[TemporalReverseSpanOutputRow, ...]: + boundary_output_names = _REVERSE_BOUNDARY_OUTPUT_NAMES + _reverse_message_boundary_extra_output_names() + rows: list[TemporalReverseSpanOutputRow] = [] + for group, names in ( + ("front", _REVERSE_FRONT_OUTPUT_NAMES), + ("boundary", boundary_output_names), + ): + for local_slot, logical_name in enumerate(names): + rows.append( + TemporalReverseSpanOutputRow( + row_index=len(rows), + group=group, + logical_name=logical_name, + local_slot=int(local_slot), + required=logical_name not in _REVERSE_LOCAL_ONLY_FRONT_OUTPUT_NAMES, + ) + ) + return tuple(rows) + + +def _reverse_message_boundary_extra_output_names() -> tuple[str, ...]: + names_by_source_index: dict[int, str] = {} + for pattern in temporal_executor_strategy_registry().reverse_patterns(): + if pattern.surface != "message": + continue + for output in pattern.message_param_grad_outputs: + if output.source != 
"boundary_extra_output": + continue + source_index = int(output.source_index) + existing = names_by_source_index.get(source_index) + if existing is not None and existing != output.logical_name: + raise RuntimeError( + "Temporal reverse span output rows found incompatible message boundary outputs " + f"for source_index={source_index}: {existing!r} vs {output.logical_name!r}" + ) + names_by_source_index[source_index] = output.logical_name + return tuple(name for _index, name in sorted(names_by_source_index.items())) + + +def temporal_reverse_span_output_rows_tensor() -> torch.Tensor: + rows = [row.row for row in temporal_reverse_span_output_rows()] + if not rows: + return torch.empty((0, 6), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def temporal_reverse_span_output_summaries() -> tuple[str, ...]: + return tuple(row.summary for row in temporal_reverse_span_output_rows()) + + +@dataclass(frozen=True) +class TemporalReverseOutputRouteRow: + row_index: int + route_kind: TemporalReverseOutputRouteKind + target_role: str + source_group: TemporalReverseSpanOutputGroup + source_logical_name: str + required: bool = True + schema_version: int = 1 + + @property + def target_role_id(self) -> int: + return temporal_strategy_id_hash(self.target_role) + + @property + def source_role_id(self) -> int: + return temporal_reverse_span_output_role_id(self.source_logical_name) + + @property + def row(self) -> list[int]: + return [ + int(self.row_index), + int(_REVERSE_OUTPUT_ROUTE_KIND_OPCODE[self.route_kind]), + int(self.target_role_id), + int(_REVERSE_SPAN_OUTPUT_GROUP_OPCODE[self.source_group]), + int(self.source_role_id), + int(self.required), + int(self.schema_version), + 0, + ] + + @property + def summary(self) -> str: + return ( + f"row={int(self.row_index)},kind={self.route_kind},target={self.target_role}," + f"source_group={self.source_group},source={self.source_logical_name}," + f"required={int(self.required)}" + ) + + +def temporal_reverse_output_route_kind_opcode(route_kind: str) -> int: + for registered_kind, opcode in _REVERSE_OUTPUT_ROUTE_KIND_OPCODE.items(): + if registered_kind == str(route_kind): + return int(opcode) + raise RuntimeError(f"Unregistered temporal reverse output route kind {route_kind!r}") + + +def temporal_reverse_output_route_target_id(target_role: str) -> int: + return int(temporal_strategy_id_hash(str(target_role))) + + +def temporal_reverse_output_route_rows() -> tuple[TemporalReverseOutputRouteRow, ...]: + routes: list[TemporalReverseOutputRouteRow] = [ + TemporalReverseOutputRouteRow( + row_index=index, + route_kind=route_kind, + target_role=target_role, + source_group=source_group, + source_logical_name=source_logical_name, + ) + for index, (route_kind, target_role, source_group, source_logical_name) in enumerate( + _BASE_REVERSE_OUTPUT_ROUTES + ) + ] + for logical_name in _reverse_message_boundary_extra_output_names(): + routes.append( + TemporalReverseOutputRouteRow( + row_index=len(routes), + route_kind="message_strategy_parameter_grad", + target_role=logical_name, + source_group="boundary", + source_logical_name=logical_name, + ) + ) + return tuple(routes) + + +def temporal_reverse_output_route_rows_tensor() -> torch.Tensor: + rows = [row.row for row in temporal_reverse_output_route_rows()] + if not rows: + return torch.empty((0, 8), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def temporal_reverse_output_route_summaries() -> tuple[str, ...]: + return tuple(row.summary for row in temporal_reverse_output_route_rows()) + 
+ +def build_temporal_forward_artifact_route_rows( + table: TemporalPrimitiveTablePlan, +) -> tuple[TemporalForwardArtifactRouteRow, ...]: + forward_rows = tuple(temporal_forward_executor_rows(table)) + rows: list[TemporalForwardArtifactRouteRow] = [] + + def add_for_executor( + *, + surface: TemporalForwardArtifactSurface, + executor_row_index: int, + executor_id: int, + bucket_ordinal: int, + artifact_roles: tuple[str, ...], + ) -> None: + for artifact_role in artifact_roles: + rows.append( + TemporalForwardArtifactRouteRow( + row_index=len(rows), + surface=surface, + executor_row_index=int(executor_row_index), + executor_id=int(executor_id), + bucket_ordinal=int(bucket_ordinal), + artifact_role=artifact_role, + logical_name=f"{surface}.{artifact_role}", + required=_forward_artifact_route_required(surface=surface, artifact_role=artifact_role), + ) + ) + + for surface, artifact_roles in _FORWARD_ARTIFACT_ROLES_BY_SURFACE: + if surface == "runtime_policy": + add_for_executor( + surface=surface, + executor_row_index=-1, + executor_id=0, + bucket_ordinal=-1, + artifact_roles=artifact_roles, + ) + continue + for executor_row_index, executor_row in enumerate(forward_rows): + if str(executor_row.surface) != surface: + continue + add_for_executor( + surface=surface, + executor_row_index=int(executor_row_index), + executor_id=int(executor_row.executor_id), + bucket_ordinal=int(executor_row.bucket_ordinal), + artifact_roles=artifact_roles, + ) + return tuple(rows) + + +def temporal_forward_artifact_route_rows_tensor(table: TemporalPrimitiveTablePlan) -> torch.Tensor: + rows = [row.row for row in build_temporal_forward_artifact_route_rows(table)] + if not rows: + return torch.empty((0, 10), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def temporal_forward_artifact_route_summaries(table: TemporalPrimitiveTablePlan) -> tuple[str, ...]: + return tuple(row.summary for row in build_temporal_forward_artifact_route_rows(table)) + + +def build_temporal_forward_artifact_merge_rows( + table: TemporalPrimitiveTablePlan, +) -> tuple[TemporalForwardArtifactMergeRow, ...]: + route_rows = build_temporal_forward_artifact_route_rows(table) + grouped: dict[tuple[TemporalForwardArtifactSurface, int, str], list[TemporalForwardArtifactRouteRow]] = {} + for route_row in route_rows: + grouped.setdefault( + ( + route_row.surface, + int(route_row.bucket_ordinal), + route_row.artifact_role, + ), + [], + ).append(route_row) + + rows: list[TemporalForwardArtifactMergeRow] = [] + for (surface, bucket_ordinal, artifact_role), producer_rows in sorted( + grouped.items(), + key=lambda item: ( + int(temporal_surface_opcode(item[0][0])), + int(item[0][1]), + int(temporal_reverse_artifact_role_id(item[0][2])), + ), + ): + if len(producer_rows) == 1: + producer = producer_rows[0] + merge_kind: TemporalForwardArtifactMergeKind = "identity_singleton" + producer_route_row_index = int(producer.row_index) + producer_executor_row_index = int(producer.executor_row_index) + producer_executor_id = int(producer.executor_id) + else: + merge_kind = "concat_or_error" + producer_route_row_index = -1 + producer_executor_row_index = -1 + producer_executor_id = -1 + rows.append( + TemporalForwardArtifactMergeRow( + row_index=len(rows), + surface=surface, + bucket_ordinal=int(bucket_ordinal), + artifact_role=artifact_role, + merge_kind=merge_kind, + output_route=f"{surface}.{int(bucket_ordinal)}.{artifact_role}", + producer_route_row_index=producer_route_row_index, + producer_executor_row_index=producer_executor_row_index, + 
producer_executor_id=producer_executor_id, + required=any(route.required for route in producer_rows), + ) + ) + return tuple(rows) + + +def temporal_forward_artifact_merge_rows_tensor(table: TemporalPrimitiveTablePlan) -> torch.Tensor: + rows = [row.row for row in build_temporal_forward_artifact_merge_rows(table)] + if not rows: + return torch.empty((0, 12), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def temporal_forward_artifact_merge_summaries(table: TemporalPrimitiveTablePlan) -> tuple[str, ...]: + return tuple(row.summary for row in build_temporal_forward_artifact_merge_rows(table)) + + +def build_temporal_forward_output_route_rows( + table: TemporalPrimitiveTablePlan, +) -> tuple[TemporalForwardOutputRouteRow, ...]: + readout_executor_rows = tuple( + (executor_row_index, executor_row) + for executor_row_index, executor_row in enumerate(temporal_forward_executor_rows(table)) + if str(executor_row.surface) == "readout" + ) + route_kind: TemporalForwardOutputRouteKind = ( + "readout_output_cells" if len(readout_executor_rows) <= 1 else "readout_output_concat" + ) + rows: list[TemporalForwardOutputRouteRow] = [] + output_offset = 0 + for executor_row_index, executor_row in readout_executor_rows: + route_output_offset = int(output_offset) if route_kind == "readout_output_concat" else 0 + rows.append( + TemporalForwardOutputRouteRow( + row_index=len(rows), + route_kind=route_kind, + surface="readout", + executor_row_index=int(executor_row_index), + executor_id=int(executor_row.executor_id), + bucket_ordinal=int(executor_row.bucket_ordinal), + output_role="output_cells", + output_offset=route_output_offset, + ) + ) + if route_kind == "readout_output_concat": + output_offset += int(executor_row.receiver_count) + return tuple(rows) + + +def temporal_forward_output_route_rows_tensor(table: TemporalPrimitiveTablePlan) -> torch.Tensor: + rows = [row.row for row in build_temporal_forward_output_route_rows(table)] + if not rows: + return torch.empty((0, 10), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def temporal_forward_output_route_kind_opcode(kind: TemporalForwardOutputRouteKind) -> int: + try: + return int(_FORWARD_OUTPUT_ROUTE_KIND_OPCODE[kind]) + except KeyError as error: + raise RuntimeError(f"Unknown temporal forward output route kind {kind!r}") from error + + +def temporal_forward_output_route_summaries(table: TemporalPrimitiveTablePlan) -> tuple[str, ...]: + return tuple(row.summary for row in build_temporal_forward_output_route_rows(table)) + + +def build_temporal_readout_message_producer_consumer_plan( + table: TemporalPrimitiveTablePlan, + *, + streaming_readout_body_available: bool = False, + streaming_readout_body_profitable: bool = True, +) -> TemporalReadoutMessageProducerConsumerPlan: + message_executor_rows = tuple( + (executor_row_index, executor_row) + for executor_row_index, executor_row in enumerate(temporal_forward_executor_rows(table)) + if str(executor_row.surface) == "message" + ) + readout_executor_rows = tuple( + (executor_row_index, executor_row) + for executor_row_index, executor_row in enumerate(temporal_forward_executor_rows(table)) + if str(executor_row.surface) == "readout" + ) + output_route_by_executor = { + ( + int(route.executor_row_index), + int(route.executor_id), + int(route.bucket_ordinal), + ): int(route.row_index) + for route in build_temporal_forward_output_route_rows(table) + } + materialized_role_mask = _readout_message_required_role_mask( + ( + "input_k", + "input_v", + "recurrent_k_after", + 
"recurrent_v_after", + "readout_output_query", + "output_route_rows", + "memory_liveness_rows", + ) + ) + streaming_role_mask = _readout_message_required_role_mask( + ( + "input_k", + "input_v", + "recurrent_hidden", + "recurrent_kv_weight", + "readout_output_query", + "output_route_rows", + "memory_liveness_rows", + ) + ) + missing_blocker = ( + "missing_message_or_readout_executor" + if not message_executor_rows or not readout_executor_rows + else "missing_forward_output_route" + ) + rows: list[TemporalReadoutMessageProducerConsumerRow] = [] + for message_executor_row_index, message_executor_row in message_executor_rows: + for readout_executor_row_index, readout_executor_row in readout_executor_rows: + output_route_row = output_route_by_executor.get( + ( + int(readout_executor_row_index), + int(readout_executor_row.executor_id), + int(readout_executor_row.bucket_ordinal), + ), + -1, + ) + has_route = int(output_route_row) >= 0 + streaming_supported = _readout_message_streaming_bindings_supported( + message_executor_row, + readout_executor_row, + ) + streaming_selected = bool( + streaming_readout_body_available + and streaming_readout_body_profitable + and has_route + and streaming_supported + ) + streaming_blocker = ( + "" + if streaming_selected + else ( + "missing_required_streaming_bindings" + if has_route and not streaming_supported + else ( + "cost_rejected_current_code_regression" + if has_route and streaming_supported and streaming_readout_body_available + else ("pending_registered_readout_streaming_body" if has_route else missing_blocker) + ) + ) + ) + rows.append( + TemporalReadoutMessageProducerConsumerRow( + row_index=len(rows), + strategy="materialized_recurrent_kv_after", + status="candidate" if streaming_selected else "active", + executable=bool(has_route and not streaming_selected), + producer_executor_row_index=int(message_executor_row_index), + producer_executor_id=int(message_executor_row.executor_id), + producer_bucket_ordinal=int(message_executor_row.bucket_ordinal), + consumer_executor_row_index=int(readout_executor_row_index), + consumer_executor_id=int(readout_executor_row.executor_id), + consumer_bucket_ordinal=int(readout_executor_row.bucket_ordinal), + output_route_row_index=int(output_route_row), + required_role_mask=materialized_role_mask, + blocker="" if has_route else missing_blocker, + ) + ) + rows.append( + TemporalReadoutMessageProducerConsumerRow( + row_index=len(rows), + strategy="stream_readout_from_message_projection", + status="active" if streaming_selected else "blocked", + executable=bool(streaming_selected), + producer_executor_row_index=int(message_executor_row_index), + producer_executor_id=int(message_executor_row.executor_id), + producer_bucket_ordinal=int(message_executor_row.bucket_ordinal), + consumer_executor_row_index=int(readout_executor_row_index), + consumer_executor_id=int(readout_executor_row.executor_id), + consumer_bucket_ordinal=int(readout_executor_row.bucket_ordinal), + output_route_row_index=int(output_route_row), + required_role_mask=streaming_role_mask, + blocker=streaming_blocker, + ) + ) + if not rows: + rows = [ + TemporalReadoutMessageProducerConsumerRow( + row_index=0, + strategy="materialized_recurrent_kv_after", + status="blocked", + executable=False, + producer_executor_row_index=-1, + producer_executor_id=-1, + producer_bucket_ordinal=-1, + consumer_executor_row_index=-1, + consumer_executor_id=-1, + consumer_bucket_ordinal=-1, + output_route_row_index=-1, + required_role_mask=materialized_role_mask, + 
blocker=missing_blocker, + ), + TemporalReadoutMessageProducerConsumerRow( + row_index=1, + strategy="stream_readout_from_message_projection", + status="blocked", + executable=False, + producer_executor_row_index=-1, + producer_executor_id=-1, + producer_bucket_ordinal=-1, + consumer_executor_row_index=-1, + consumer_executor_id=-1, + consumer_bucket_ordinal=-1, + output_route_row_index=-1, + required_role_mask=streaming_role_mask, + blocker=missing_blocker, + ), + ] + selected_strategy = ( + "stream_readout_from_message_projection" + if any(row.strategy == "stream_readout_from_message_projection" and row.status == "active" for row in rows) + else "materialized_recurrent_kv_after" + ) + streaming_status = "active" if selected_strategy == "stream_readout_from_message_projection" else "blocked" + return TemporalReadoutMessageProducerConsumerPlan( + rows=tuple(rows), + selected_strategy=selected_strategy, + streaming_status=streaming_status, + reason=( + f"active_strategy={selected_strategy};" + + ( + "streaming_readout_strategy=registered_program_body;" + if selected_strategy == "stream_readout_from_message_projection" + else ( + "streaming_readout_strategy=cost_rejected_current_code_regression;" + if any( + row.strategy == "stream_readout_from_message_projection" + and row.blocker == "cost_rejected_current_code_regression" + for row in rows + ) + else "streaming_readout_strategy=compiler_product_pending_registered_program_body;" + ) + ) + + "semantics=primitive_rows_stable" + ), + ) + + +def temporal_readout_message_producer_consumer_rows_tensor( + plan: TemporalReadoutMessageProducerConsumerPlan, +) -> torch.Tensor: + rows = [row.row for row in plan.rows] + if not rows: + return torch.empty((0, 16), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def temporal_readout_message_producer_consumer_summaries( + plan: TemporalReadoutMessageProducerConsumerPlan, +) -> tuple[str, ...]: + return plan.summaries + + +def build_temporal_message_transition_producer_consumer_plan( + table: TemporalPrimitiveTablePlan, + *, + streaming_transition_body_available: bool = False, + streaming_transition_body_profitable: bool = True, +) -> TemporalMessageTransitionProducerConsumerPlan: + message_executor_rows = tuple( + (executor_row_index, executor_row) + for executor_row_index, executor_row in enumerate(temporal_forward_executor_rows(table)) + if str(executor_row.surface) == "message" + ) + transition_executor_rows = tuple( + (executor_row_index, executor_row) + for executor_row_index, executor_row in enumerate(temporal_forward_executor_rows(table)) + if str(executor_row.surface) == "transition" + ) + materialized_role_mask = _message_transition_required_role_mask( + ( + "input_k", + "input_v", + "recurrent_hidden", + "recurrent_msg", + "transition_aggregate_binding", + "memory_liveness_rows", + "physical_strategy_rows", + ) + ) + streaming_role_mask = materialized_role_mask + aggregate_access_opcode = temporal_program_access_opcode("transition_aggregated_message_input") + missing_blocker = ( + "missing_message_or_transition_executor" + if not message_executor_rows or not transition_executor_rows + else "multiple_transition_consumers_need_merge_rows" + ) + rows: list[TemporalMessageTransitionProducerConsumerRow] = [] + singleton_route = len(message_executor_rows) == 1 and len(transition_executor_rows) == 1 + for message_executor_row_index, message_executor_row in message_executor_rows: + for transition_executor_row_index, transition_executor_row in transition_executor_rows: + 
receiver_count_matches = int(message_executor_row.receiver_count) == int( + transition_executor_row.receiver_count + ) + streaming_supported = bool(singleton_route and receiver_count_matches) + streaming_selected = bool( + streaming_transition_body_available and streaming_transition_body_profitable and streaming_supported + ) + if streaming_selected: + streaming_blocker = "" + elif not singleton_route: + streaming_blocker = "multiple_transition_consumers_need_merge_rows" + elif not receiver_count_matches: + streaming_blocker = "receiver_count_mismatch" + elif streaming_transition_body_available: + streaming_blocker = "cost_rejected_current_code_regression" + else: + streaming_blocker = "pending_direct_chunk_body" + rows.append( + TemporalMessageTransitionProducerConsumerRow( + row_index=len(rows), + strategy="materialized_recurrent_message", + status="candidate" if streaming_selected else "active", + executable=not streaming_selected, + producer_executor_row_index=int(message_executor_row_index), + producer_executor_id=int(message_executor_row.executor_id), + producer_bucket_ordinal=int(message_executor_row.bucket_ordinal), + consumer_executor_row_index=int(transition_executor_row_index), + consumer_executor_id=int(transition_executor_row.executor_id), + consumer_bucket_ordinal=int(transition_executor_row.bucket_ordinal), + aggregate_access_opcode=aggregate_access_opcode, + required_role_mask=materialized_role_mask, + blocker="", + ) + ) + rows.append( + TemporalMessageTransitionProducerConsumerRow( + row_index=len(rows), + strategy="stream_message_to_transition_input", + status="active" if streaming_selected else "blocked", + executable=bool(streaming_selected), + producer_executor_row_index=int(message_executor_row_index), + producer_executor_id=int(message_executor_row.executor_id), + producer_bucket_ordinal=int(message_executor_row.bucket_ordinal), + consumer_executor_row_index=int(transition_executor_row_index), + consumer_executor_id=int(transition_executor_row.executor_id), + consumer_bucket_ordinal=int(transition_executor_row.bucket_ordinal), + aggregate_access_opcode=aggregate_access_opcode, + required_role_mask=streaming_role_mask, + blocker=streaming_blocker, + ) + ) + if not rows: + rows = [ + TemporalMessageTransitionProducerConsumerRow( + row_index=0, + strategy="materialized_recurrent_message", + status="blocked", + executable=False, + producer_executor_row_index=-1, + producer_executor_id=-1, + producer_bucket_ordinal=-1, + consumer_executor_row_index=-1, + consumer_executor_id=-1, + consumer_bucket_ordinal=-1, + aggregate_access_opcode=aggregate_access_opcode, + required_role_mask=materialized_role_mask, + blocker=missing_blocker, + ), + TemporalMessageTransitionProducerConsumerRow( + row_index=1, + strategy="stream_message_to_transition_input", + status="blocked", + executable=False, + producer_executor_row_index=-1, + producer_executor_id=-1, + producer_bucket_ordinal=-1, + consumer_executor_row_index=-1, + consumer_executor_id=-1, + consumer_bucket_ordinal=-1, + aggregate_access_opcode=aggregate_access_opcode, + required_role_mask=streaming_role_mask, + blocker=missing_blocker, + ), + ] + selected_strategy = ( + "stream_message_to_transition_input" + if any(row.strategy == "stream_message_to_transition_input" and row.status == "active" for row in rows) + else "materialized_recurrent_message" + ) + streaming_status = "active" if selected_strategy == "stream_message_to_transition_input" else "blocked" + return TemporalMessageTransitionProducerConsumerPlan( + 
rows=tuple(rows), + selected_strategy=selected_strategy, + streaming_status=streaming_status, + reason=( + f"active_strategy={selected_strategy};" + + ( + "message_transition_strategy=direct_binding_to_transition_input;" + if selected_strategy == "stream_message_to_transition_input" + else ( + "message_transition_strategy=cost_rejected_current_code_regression;" + if any( + row.strategy == "stream_message_to_transition_input" + and row.blocker == "cost_rejected_current_code_regression" + for row in rows + ) + else "message_transition_strategy=materialized_or_pending_direct_chunk_body;" + ) + ) + + "semantics=primitive_rows_stable" + ), + ) + + +def temporal_message_transition_producer_consumer_rows_tensor( + plan: TemporalMessageTransitionProducerConsumerPlan, +) -> torch.Tensor: + rows = [row.row for row in plan.rows] + if not rows: + return torch.empty((0, 16), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def temporal_message_transition_producer_consumer_summaries( + plan: TemporalMessageTransitionProducerConsumerPlan, +) -> tuple[str, ...]: + return plan.summaries + + +def build_temporal_reverse_artifact_consumer_route_rows( + table: TemporalPrimitiveTablePlan, +) -> tuple[TemporalReverseArtifactConsumerRouteRow, ...]: + forward_routes = build_temporal_forward_artifact_route_rows(table) + reverse_rows = tuple(temporal_reverse_executor_rows(table)) + rows: list[TemporalReverseArtifactConsumerRouteRow] = [] + for reverse_executor_row_index, reverse_executor_row in enumerate(reverse_rows): + surface = str(reverse_executor_row.surface) + if surface not in {"message", "transition", "readout"}: + continue + for forward_route in forward_routes: + if str(forward_route.surface) != surface or int(forward_route.bucket_ordinal) != int( + reverse_executor_row.bucket_ordinal + ): + continue + rows.append( + TemporalReverseArtifactConsumerRouteRow( + row_index=len(rows), + surface=surface, + reverse_executor_row_index=int(reverse_executor_row_index), + reverse_executor_id=int(reverse_executor_row.executor_id), + bucket_ordinal=int(reverse_executor_row.bucket_ordinal), + artifact_role=forward_route.artifact_role, + forward_artifact_route_row_index=int(forward_route.row_index), + forward_executor_row_index=int(forward_route.executor_row_index), + forward_executor_id=int(forward_route.executor_id), + required=bool(forward_route.required), + ) + ) + return tuple(rows) + + +def temporal_reverse_artifact_consumer_route_rows_tensor(table: TemporalPrimitiveTablePlan) -> torch.Tensor: + rows = [row.row for row in build_temporal_reverse_artifact_consumer_route_rows(table)] + if not rows: + return torch.empty((0, 12), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def temporal_reverse_artifact_consumer_route_summaries(table: TemporalPrimitiveTablePlan) -> tuple[str, ...]: + return tuple(row.summary for row in build_temporal_reverse_artifact_consumer_route_rows(table)) + + +def build_temporal_reverse_parameter_reducer_route_rows( + table: TemporalPrimitiveTablePlan, +) -> tuple[TemporalReverseParameterReducerRouteRow, ...]: + reverse_rows = tuple(temporal_reverse_executor_rows(table)) + output_routes = temporal_reverse_output_route_rows() + surface_by_route: dict[tuple[TemporalReverseOutputRouteKind, str], str] = { + ("readout_parameter_grad", "value_to_output_weight"): "readout", + ("readout_parameter_grad", "output_cell_bias"): "readout", + ("query_parameter_grad", "output_query"): "readout", + ("transition_boundary", "recurrent_query"): "message", + 
("sender_kv_parameter_grad", "recurrent_output_kv_weight"): "message", + ("sender_kv_parameter_grad", "boundary_input_kv_weight"): "message", + ("sender_kv_parameter_grad", "boundary_input_kv_grouped_flag"): "message", + ("sender_kv_parameter_grad", "initial_recurrent_kv_weight"): "message", + } + for logical_name in _reverse_message_boundary_extra_output_names(): + surface_by_route[("message_strategy_parameter_grad", logical_name)] = "message" + rows: list[TemporalReverseParameterReducerRouteRow] = [] + for route in output_routes: + surface = surface_by_route.get((route.route_kind, route.target_role)) + if surface is None: + continue + matches = tuple( + (executor_row_index, executor_row) + for executor_row_index, executor_row in enumerate(reverse_rows) + if str(executor_row.surface) == surface + ) + for executor_row_index, executor_row in matches: + rows.append( + TemporalReverseParameterReducerRouteRow( + row_index=len(rows), + route_kind=route.route_kind, + target_role=route.target_role, + source_group=route.source_group, + source_logical_name=route.source_logical_name, + surface=surface, + executor_row_index=int(executor_row_index), + executor_id=int(executor_row.executor_id), + bucket_ordinal=int(executor_row.bucket_ordinal), + required=bool(route.required), + ) + ) + return tuple(rows) + + +def temporal_reverse_parameter_reducer_route_rows_tensor(table: TemporalPrimitiveTablePlan) -> torch.Tensor: + rows = [row.row for row in build_temporal_reverse_parameter_reducer_route_rows(table)] + if not rows: + return torch.empty((0, 12), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def temporal_reverse_parameter_reducer_route_summaries(table: TemporalPrimitiveTablePlan) -> tuple[str, ...]: + return tuple(row.summary for row in build_temporal_reverse_parameter_reducer_route_rows(table)) + + +def build_temporal_reverse_program_stage_plan( + table: TemporalPrimitiveTablePlan, + backward_plan: TemporalBackwardExecutablePlan, +) -> TemporalReverseProgramStagePlan: + rows = temporal_reverse_executor_rows(table) + _require_reverse_executor_row_count(rows, backward_plan) + readout_rows = _reverse_executor_rows_for_surface(rows, surface="readout") + message_rows = _reverse_executor_rows_for_surface(rows, surface="message") + transition_rows = tuple((index, row) for index, row in enumerate(rows) if row.surface == "transition") + stages: list[TemporalReverseProgramStageRow] = [] + + def add( + stage_kind: TemporalReverseProgramStageKind, + *, + row_index: int, + row: TemporalReverseExecutorRow, + dependency_mask: int, + memory_scope: str, + surface: str | None = None, + ) -> None: + stages.append( + TemporalReverseProgramStageRow( + stage_index=len(stages), + stage_kind=stage_kind, + surface=str(row.surface if surface is None else surface), + executor_row_index=int(row_index), + executor_id=int(row.executor_id), + primitive_row_start=int(row.primitive_row_start), + primitive_row_count=int(row.primitive_row_count), + bucket_ordinal=int(row.bucket_ordinal), + dependency_mask=int(dependency_mask), + memory_scope=memory_scope, + ) + ) + + for readout_index, readout_row in readout_rows: + add( + "output_grad_window", + row_index=readout_index, + row=readout_row, + dependency_mask=0, + memory_scope="output_workspace", + ) + for readout_index, readout_row in readout_rows: + add( + "readout_message_kv_step", + row_index=readout_index, + row=readout_row, + dependency_mask=1 << 0, + memory_scope="message_workspace", + ) + for transition_index, transition_row in transition_rows: + add( + 
"transition_step", + row_index=transition_index, + row=transition_row, + dependency_mask=1 << 1, + memory_scope="transition_workspace", + ) + for message_index, message_row in message_rows: + add( + "recurrent_message_boundary_initial_kv_step", + row_index=message_index, + row=message_row, + dependency_mask=(1 << 1) | (1 << 2), + memory_scope="message_workspace", + ) + for reducer_row_index, reducer_row in enumerate(rows): + add( + "parameter_reducer_step", + row_index=reducer_row_index, + row=reducer_row, + dependency_mask=(1 << 1) | (1 << 2) | (1 << 3), + memory_scope="reduction_workspace", + surface="parameter_reduction", + ) + return TemporalReverseProgramStagePlan(stages=tuple(stages)) + + +@dataclass(frozen=True) +class TemporalFusedCudaLaunchContract: + schema_version: int + forward_entrypoint: str + backward_entrypoint: str + required_tables: tuple[str, ...] + primitive_row_count: int + forward_executor_row_count: int + reverse_executor_row_count: int + forward_handler_row_count: int + reverse_handler_row_count: int + native_strategy_row_count: int + native_callable_catalog_row_count: int + native_callable_binding_schema_row_count: int + native_callable_output_row_count: int + transition_reverse_seed_role_row_count: int + reverse_output_route_row_count: int + forward_artifact_route_row_count: int + forward_artifact_merge_row_count: int + forward_output_route_row_count: int + readout_message_producer_consumer_row_count: int + message_transition_producer_consumer_row_count: int + reverse_artifact_consumer_route_row_count: int + reverse_parameter_reducer_route_row_count: int + forward_binding_row_count: int + reverse_binding_row_count: int + memory_entry_count: int + memory_liveness_row_count: int + workspace_policy: str + layout_policy: str + alias_policy: str + demotion_policy: str + unsupported_policy: str + + @property + def review_summary(self) -> tuple[str, ...]: + return ( + "fused_cuda_launch_contract=compiler_owned", + f"schema_version={int(self.schema_version)}", + f"forward_entrypoint={self.forward_entrypoint}", + f"backward_entrypoint={self.backward_entrypoint}", + "required_tables=" + _tuple_summary(self.required_tables), + f"primitive_row_count={int(self.primitive_row_count)}", + f"forward_executor_row_count={int(self.forward_executor_row_count)}", + f"reverse_executor_row_count={int(self.reverse_executor_row_count)}", + f"forward_handler_row_count={int(self.forward_handler_row_count)}", + f"reverse_handler_row_count={int(self.reverse_handler_row_count)}", + f"native_strategy_row_count={int(self.native_strategy_row_count)}", + f"native_callable_catalog_row_count={int(self.native_callable_catalog_row_count)}", + f"native_callable_binding_schema_row_count={int(self.native_callable_binding_schema_row_count)}", + f"native_callable_output_row_count={int(self.native_callable_output_row_count)}", + f"transition_reverse_seed_role_row_count={int(self.transition_reverse_seed_role_row_count)}", + f"reverse_output_route_row_count={int(self.reverse_output_route_row_count)}", + f"forward_artifact_route_row_count={int(self.forward_artifact_route_row_count)}", + f"forward_artifact_merge_row_count={int(self.forward_artifact_merge_row_count)}", + f"forward_output_route_row_count={int(self.forward_output_route_row_count)}", + f"readout_message_producer_consumer_row_count={int(self.readout_message_producer_consumer_row_count)}", + "message_transition_producer_consumer_row_count=" + f"{int(self.message_transition_producer_consumer_row_count)}", + 
f"reverse_artifact_consumer_route_row_count={int(self.reverse_artifact_consumer_route_row_count)}", + f"reverse_parameter_reducer_route_row_count={int(self.reverse_parameter_reducer_route_row_count)}", + f"forward_binding_row_count={int(self.forward_binding_row_count)}", + f"reverse_binding_row_count={int(self.reverse_binding_row_count)}", + f"memory_entry_count={int(self.memory_entry_count)}", + f"memory_liveness_row_count={int(self.memory_liveness_row_count)}", + f"workspace_policy={self.workspace_policy}", + f"layout_policy={self.layout_policy}", + f"alias_policy={self.alias_policy}", + f"demotion_policy={self.demotion_policy}", + f"unsupported_policy={self.unsupported_policy}", + ) + + +@dataclass(frozen=True) +class TemporalFusedCudaProgramPlan: + status: TemporalFusedCudaProgramStatus + forward_entrypoint: str + backward_entrypoint: str + forward_strategy_ids: tuple[str, ...] + backward_strategy_ids: tuple[str, ...] + blocker_code: str + blocker_reason: str + launch_contract: TemporalFusedCudaLaunchContract + + @property + def review_summary(self) -> tuple[str, ...]: + return ( + "fused_cuda_program_plan=compiler_owned", + f"status={self.status}", + f"forward_entrypoint={self.forward_entrypoint}", + f"backward_entrypoint={self.backward_entrypoint}", + "forward_strategy_ids=" + _tuple_summary(self.forward_strategy_ids), + "backward_strategy_ids=" + _tuple_summary(self.backward_strategy_ids), + f"blocker_code={self.blocker_code or '-'}", + f"blocker_reason={self.blocker_reason or '-'}", + *self.launch_contract.review_summary, + ) + + +@dataclass(frozen=True) +class TemporalRegisteredProgramExecutorPlan: + status: TemporalRegisteredProgramExecutorStatus + forward_entrypoint: str + backward_entrypoint: str + demotion_policy: str + fused_cuda_status: TemporalFusedCudaProgramStatus + fused_cuda_blocker_code: str + reason: str + + @property + def review_summary(self) -> tuple[str, ...]: + return ( + "registered_program_executor_plan=compiler_owned", + f"status={self.status}", + f"forward_entrypoint={self.forward_entrypoint}", + f"backward_entrypoint={self.backward_entrypoint}", + f"demotion_policy={self.demotion_policy}", + f"fused_cuda_status={self.fused_cuda_status}", + f"fused_cuda_blocker_code={self.fused_cuda_blocker_code or '-'}", + f"reason={self.reason}", + ) + + +def build_temporal_fused_cuda_program_plan( + *, + primitive_rows: torch.Tensor, + forward_plan: TemporalForwardExecutablePlan, + backward_plan: TemporalBackwardExecutablePlan, + memory_plan: TemporalMemoryLivenessPlan, + memory_liveness_rows: torch.Tensor, + forward_handler_rows: torch.Tensor | None = None, + reverse_handler_rows: torch.Tensor | None = None, + native_strategy_rows: torch.Tensor | None = None, + native_callable_binding_schema_rows: torch.Tensor | None = None, + native_callable_output_rows: torch.Tensor | None = None, + transition_reverse_seed_role_rows: torch.Tensor | None = None, + transition_primitive_callable_rows: torch.Tensor | None = None, + reverse_output_route_rows: torch.Tensor | None = None, + forward_artifact_route_rows: torch.Tensor | None = None, + forward_artifact_merge_rows: torch.Tensor | None = None, + forward_output_route_rows: torch.Tensor | None = None, + readout_message_producer_consumer_rows: torch.Tensor | None = None, + message_transition_producer_consumer_rows: torch.Tensor | None = None, + reverse_artifact_consumer_route_rows: torch.Tensor | None = None, + reverse_parameter_reducer_route_rows: torch.Tensor | None = None, +) -> TemporalFusedCudaProgramPlan: + if 
forward_handler_rows is None or reverse_handler_rows is None: + raise RuntimeError( + "Temporal fused CUDA program planning requires compiler-owned executor handler rows; " + "build them from primitive/executor rows before selecting fused CUDA launch" + ) + native_strategy_rows = ( + temporal_native_executor_strategy_rows_tensor() if native_strategy_rows is None else native_strategy_rows + ) + transition_primitive_callable_rows = ( + temporal_transition_primitive_native_callable_rows_tensor() + if transition_primitive_callable_rows is None + else transition_primitive_callable_rows + ) + native_callable_catalog_rows = temporal_native_callable_catalog_rows_tensor() + native_callable_binding_schema_rows = ( + temporal_native_callable_binding_schema_rows_tensor() + if native_callable_binding_schema_rows is None + else native_callable_binding_schema_rows + ) + native_callable_output_rows = ( + temporal_native_callable_output_rows_tensor() + if native_callable_output_rows is None + else native_callable_output_rows + ) + transition_reverse_seed_role_rows = ( + temporal_transition_reverse_seed_role_rows_tensor() + if transition_reverse_seed_role_rows is None + else transition_reverse_seed_role_rows + ) + reverse_output_route_rows = ( + temporal_reverse_output_route_rows_tensor() if reverse_output_route_rows is None else reverse_output_route_rows + ) + if ( + forward_artifact_route_rows is None + or forward_artifact_merge_rows is None + or forward_output_route_rows is None + or readout_message_producer_consumer_rows is None + or message_transition_producer_consumer_rows is None + or reverse_artifact_consumer_route_rows is None + or reverse_parameter_reducer_route_rows is None + ): + raise RuntimeError( + "Temporal fused CUDA program planning requires compiler-owned artifact route, artifact merge, " + "output route, readout/message producer-consumer, message/transition producer-consumer, " + "reverse artifact consumer route, and reducer route rows" + ) + validate_temporal_native_callable_catalog_coverage( + native_callable_catalog_rows=native_callable_catalog_rows, + native_strategy_rows=native_strategy_rows, + transition_primitive_callable_rows=transition_primitive_callable_rows, + ) + validate_temporal_native_callable_output_contract_coverage( + native_callable_catalog_rows=native_callable_catalog_rows, + native_callable_output_rows=native_callable_output_rows, + ) + validate_temporal_native_callable_binding_schema_coverage( + native_callable_catalog_rows=native_callable_catalog_rows, + native_callable_binding_schema_rows=native_callable_binding_schema_rows, + ) + launch_contract = _build_temporal_fused_cuda_launch_contract( + primitive_rows=primitive_rows, + forward_plan=forward_plan, + backward_plan=backward_plan, + memory_plan=memory_plan, + memory_liveness_rows=memory_liveness_rows, + forward_handler_rows=forward_handler_rows, + reverse_handler_rows=reverse_handler_rows, + native_strategy_rows=native_strategy_rows, + native_callable_catalog_rows=native_callable_catalog_rows, + native_callable_binding_schema_rows=native_callable_binding_schema_rows, + native_callable_output_rows=native_callable_output_rows, + transition_reverse_seed_role_rows=transition_reverse_seed_role_rows, + reverse_output_route_rows=reverse_output_route_rows, + forward_artifact_route_rows=forward_artifact_route_rows, + forward_artifact_merge_rows=forward_artifact_merge_rows, + forward_output_route_rows=forward_output_route_rows, + readout_message_producer_consumer_rows=readout_message_producer_consumer_rows, + 
message_transition_producer_consumer_rows=message_transition_producer_consumer_rows, + reverse_artifact_consumer_route_rows=reverse_artifact_consumer_route_rows, + reverse_parameter_reducer_route_rows=reverse_parameter_reducer_route_rows, + ) + if forward_plan.strategy_legality_status != "legal" or backward_plan.strategy_legality_status != "legal": + return TemporalFusedCudaProgramPlan( + status="blocked", + forward_entrypoint=launch_contract.forward_entrypoint, + backward_entrypoint=launch_contract.backward_entrypoint, + forward_strategy_ids=forward_plan.strategy_ids, + backward_strategy_ids=backward_plan.strategy_ids, + blocker_code="STRATEGY_LEGALITY_BLOCKED", + blocker_reason=("fused_program_requires_legal_forward_and_backward_registered_strategy_sets"), + launch_contract=launch_contract, + ) + aggregate_merge_blocker = _forward_artifact_aggregate_merge_blocker(forward_artifact_merge_rows) + if aggregate_merge_blocker is not None: + return TemporalFusedCudaProgramPlan( + status="blocked", + forward_entrypoint=launch_contract.forward_entrypoint, + backward_entrypoint=launch_contract.backward_entrypoint, + forward_strategy_ids=forward_plan.strategy_ids, + backward_strategy_ids=backward_plan.strategy_ids, + blocker_code="FORWARD_ARTIFACT_MERGE_UNSUPPORTED", + blocker_reason=aggregate_merge_blocker, + launch_contract=launch_contract, + ) + forward_output_route_blocker = _forward_output_route_blocker(forward_output_route_rows) + if forward_output_route_blocker is not None: + return TemporalFusedCudaProgramPlan( + status="blocked", + forward_entrypoint=launch_contract.forward_entrypoint, + backward_entrypoint=launch_contract.backward_entrypoint, + forward_strategy_ids=forward_plan.strategy_ids, + backward_strategy_ids=backward_plan.strategy_ids, + blocker_code="FORWARD_OUTPUT_ROUTE_UNSUPPORTED", + blocker_reason=forward_output_route_blocker, + launch_contract=launch_contract, + ) + transition_primitives = _transition_primitive_names_from_rows(primitive_rows) + transition_blockers = transition_program_layer_blocker_codes(transition_primitives) + transition_missing_symbols = transition_program_layer_missing_symbols(transition_primitives) + if transition_blockers: + return TemporalFusedCudaProgramPlan( + status="blocked", + forward_entrypoint=launch_contract.forward_entrypoint, + backward_entrypoint=launch_contract.backward_entrypoint, + forward_strategy_ids=forward_plan.strategy_ids, + backward_strategy_ids=backward_plan.strategy_ids, + blocker_code=transition_blockers[0], + blocker_reason=( + "registered_fused_program_requires_transition_primitives_callable_from_program_layer_cuda_body;" + "transition_primitives=" + + _tuple_summary(transition_primitives) + + ";missing_symbols=" + + _tuple_summary(transition_missing_symbols) + ), + launch_contract=launch_contract, + ) + return TemporalFusedCudaProgramPlan( + status="legal", + forward_entrypoint=launch_contract.forward_entrypoint, + backward_entrypoint=launch_contract.backward_entrypoint, + forward_strategy_ids=forward_plan.strategy_ids, + backward_strategy_ids=backward_plan.strategy_ids, + blocker_code="", + blocker_reason=( + "registered_fused_program_has_sequence_forward_span_dispatch_body;" + "registered_fused_program_has_reverse_span_dispatch_body;" + "transition_primitives=" + _tuple_summary(transition_primitives) + ), + launch_contract=launch_contract, + ) + + +def build_temporal_registered_program_executor_plan( + fused_cuda_program_plan: TemporalFusedCudaProgramPlan, +) -> TemporalRegisteredProgramExecutorPlan: + reason = ( + 
"registered_fused_cuda_program_blocked_fail_closed" + if fused_cuda_program_plan.status == "blocked" + else "registered_fused_cuda_program_matches_launch_contract" + ) + return TemporalRegisteredProgramExecutorPlan( + status="active", + forward_entrypoint=fused_cuda_program_plan.forward_entrypoint, + backward_entrypoint=fused_cuda_program_plan.backward_entrypoint, + demotion_policy="fail_closed_registered_fused_program_only", + fused_cuda_status=fused_cuda_program_plan.status, + fused_cuda_blocker_code=fused_cuda_program_plan.blocker_code, + reason=reason, + ) + + +def _require_reverse_executor_row_count( + rows: tuple[TemporalReverseExecutorRow, ...], + backward_plan: TemporalBackwardExecutablePlan, +) -> None: + if int(backward_plan.reverse_executor_rows.shape[0]) != len(rows): + raise RuntimeError( + "Temporal reverse program stage plan requires aligned reverse executor rows: " + f"table={len(rows)}; plan={int(backward_plan.reverse_executor_rows.shape[0])}" + ) + + +def _reverse_executor_rows_for_surface( + rows: tuple[TemporalReverseExecutorRow, ...], + *, + surface: str, +) -> tuple[tuple[int, TemporalReverseExecutorRow], ...]: + matches = tuple((index, row) for index, row in enumerate(rows) if row.surface == str(surface)) + if not matches: + raise RuntimeError( + "Temporal reverse program stage plan requires reverse executor rows for surface " + f"{surface!r}; found={len(matches)}" + ) + return matches + + +def _tuple_summary(values: tuple[str, ...]) -> str: + return "none" if not values else ",".join(values) + + +def _transition_primitive_names_from_rows(primitive_rows: torch.Tensor) -> tuple[str, ...]: + if primitive_rows.device.type != "cpu" or primitive_rows.dtype != torch.long or primitive_rows.dim() != 2: + raise RuntimeError("Temporal fused CUDA program requires CPU int64 primitive rows") + names: list[str] = [] + for row in primitive_rows.tolist(): + if int(row[3]) < 0: + continue + names.append(temporal_primitive_name_for_opcode(int(row[0]))) + return tuple(dict.fromkeys(names)) + + +def _forward_artifact_aggregate_merge_blocker(forward_artifact_merge_rows: torch.Tensor) -> str | None: + if ( + forward_artifact_merge_rows.device.type != "cpu" + or forward_artifact_merge_rows.dtype != torch.long + or forward_artifact_merge_rows.dim() != 2 + or int(forward_artifact_merge_rows.shape[1]) != 12 + ): + raise RuntimeError("Temporal fused CUDA program requires CPU int64 forward artifact merge rows") + supported_merge_opcodes = { + int(_FORWARD_ARTIFACT_MERGE_KIND_OPCODE["identity_singleton"]), + int(_FORWARD_ARTIFACT_MERGE_KIND_OPCODE["concat_or_error"]), + int(_FORWARD_ARTIFACT_MERGE_KIND_OPCODE["sum_or_error"]), + } + unsupported_rows = tuple( + int(row[0]) for row in forward_artifact_merge_rows.tolist() if int(row[4]) not in supported_merge_opcodes + ) + if not unsupported_rows: + return None + return "registered_fused_program_requires_supported_forward_artifact_merge_rows;unsupported_merge_rows=" + ",".join( + str(row) for row in unsupported_rows + ) + + +def _forward_output_route_blocker(forward_output_route_rows: torch.Tensor) -> str | None: + if ( + forward_output_route_rows.device.type != "cpu" + or forward_output_route_rows.dtype != torch.long + or forward_output_route_rows.dim() != 2 + or int(forward_output_route_rows.shape[1]) != 10 + ): + raise RuntimeError("Temporal fused CUDA program requires CPU int64 forward output route rows") + route_rows = forward_output_route_rows.tolist() + if not route_rows: + return 
"registered_fused_program_requires_at_least_one_executable_forward_output_route" + supported_route_opcodes = set(_FORWARD_OUTPUT_ROUTE_KIND_OPCODE.values()) + unsupported_rows = tuple(int(row[0]) for row in route_rows if int(row[1]) not in supported_route_opcodes) + if unsupported_rows: + return ( + "registered_fused_program_requires_supported_forward_output_route_rows;" + "unsupported_output_route_rows=" + ",".join(str(row) for row in unsupported_rows) + ) + negative_offset_rows = tuple(int(row[0]) for row in route_rows if int(row[9]) < 0) + if negative_offset_rows: + return ( + "registered_fused_program_requires_valid_forward_output_route_offsets;" + "negative_output_offset_rows=" + ",".join(str(row) for row in negative_offset_rows) + ) + if len(route_rows) == 1: + route_kind = int(route_rows[0][1]) + if route_kind != int(_FORWARD_OUTPUT_ROUTE_KIND_OPCODE["readout_output_concat"]) and int(route_rows[0][9]) != 0: + return ( + "registered_fused_program_requires_zero_offset_for_non_concat_output_routes;" + f"output_route_row={int(route_rows[0][0])};output_offset={int(route_rows[0][9])}" + ) + return None + route_kinds = {int(row[1]) for row in route_rows} + multi_route_kinds = { + int(_FORWARD_OUTPUT_ROUTE_KIND_OPCODE["readout_output_concat"]), + int(_FORWARD_OUTPUT_ROUTE_KIND_OPCODE["readout_output_sum"]), + } + if len(route_kinds) == 1 and next(iter(route_kinds)) in multi_route_kinds: + if next(iter(route_kinds)) == int(_FORWARD_OUTPUT_ROUTE_KIND_OPCODE["readout_output_sum"]): + nonzero_offset_rows = tuple(int(row[0]) for row in route_rows if int(row[9]) != 0) + if nonzero_offset_rows: + return ( + "registered_fused_program_requires_zero_offset_for_sum_output_routes;" + "nonzero_output_offset_rows=" + ",".join(str(row) for row in nonzero_offset_rows) + ) + return None + return ( + "registered_fused_program_requires_explicit_multi_output_route_merge_kind;" + f"route_count={int(forward_output_route_rows.shape[0])};" + "route_kinds=" + ",".join(str(kind) for kind in sorted(route_kinds)) + ) + + +def _readout_message_required_role_mask(roles: tuple[str, ...]) -> int: + mask = 0 + for role in roles: + try: + mask |= int(_READOUT_MESSAGE_PRODUCER_CONSUMER_ROLE_MASK[str(role)]) + except KeyError as error: + raise RuntimeError(f"Unknown readout/message producer-consumer role {role!r}") from error + return int(mask) + + +def _message_transition_required_role_mask(roles: tuple[str, ...]) -> int: + mask = 0 + for role in roles: + try: + mask |= int(_MESSAGE_TRANSITION_PRODUCER_CONSUMER_ROLE_MASK[str(role)]) + except KeyError as error: + raise RuntimeError(f"Unknown message/transition producer-consumer role {role!r}") from error + return int(mask) + + +def _readout_message_streaming_bindings_supported( + message_executor_row: Any, + readout_executor_row: Any, +) -> bool: + message_bindings = set(str(binding) for binding in getattr(message_executor_row, "parameter_bindings", ())) + readout_bindings = set(str(binding) for binding in getattr(readout_executor_row, "parameter_bindings", ())) + return { + "message_sender_slot_key_weight", + "message_sender_context_key", + "recurrent_sender_value_weight", + }.issubset(message_bindings) and { + "output_q", + "value_to_output_weight", + "output_cell_bias", + }.issubset(readout_bindings) + + +def _build_temporal_fused_cuda_launch_contract( + *, + primitive_rows: torch.Tensor, + forward_plan: TemporalForwardExecutablePlan, + backward_plan: TemporalBackwardExecutablePlan, + memory_plan: TemporalMemoryLivenessPlan, + memory_liveness_rows: torch.Tensor, + 
forward_handler_rows: torch.Tensor, + reverse_handler_rows: torch.Tensor, + native_strategy_rows: torch.Tensor, + native_callable_catalog_rows: torch.Tensor, + native_callable_binding_schema_rows: torch.Tensor, + native_callable_output_rows: torch.Tensor, + transition_reverse_seed_role_rows: torch.Tensor, + reverse_output_route_rows: torch.Tensor, + forward_artifact_route_rows: torch.Tensor, + forward_artifact_merge_rows: torch.Tensor, + forward_output_route_rows: torch.Tensor, + readout_message_producer_consumer_rows: torch.Tensor, + message_transition_producer_consumer_rows: torch.Tensor, + reverse_artifact_consumer_route_rows: torch.Tensor, + reverse_parameter_reducer_route_rows: torch.Tensor, +) -> TemporalFusedCudaLaunchContract: + if primitive_rows.device.type != "cpu" or primitive_rows.dtype != torch.long or primitive_rows.dim() != 2: + raise RuntimeError("Temporal fused CUDA launch contract requires CPU int64 primitive rows") + if ( + memory_liveness_rows.device.type != "cpu" + or memory_liveness_rows.dtype != torch.long + or memory_liveness_rows.dim() != 2 + ): + raise RuntimeError("Temporal fused CUDA launch contract requires CPU int64 memory liveness rows") + for name, rows in ( + ("forward_handler_rows", forward_handler_rows), + ("reverse_handler_rows", reverse_handler_rows), + ): + if rows.device.type != "cpu" or rows.dtype != torch.long or rows.dim() != 2 or int(rows.shape[1]) != 11: + raise RuntimeError(f"Temporal fused CUDA launch contract requires CPU int64 {name} with shape [N,11]") + if ( + native_strategy_rows.device.type != "cpu" + or native_strategy_rows.dtype != torch.long + or native_strategy_rows.dim() != 2 + or int(native_strategy_rows.shape[1]) != 17 + ): + raise RuntimeError( + "Temporal fused CUDA launch contract requires CPU int64 native_strategy_rows with shape [N,17]" + ) + if ( + native_callable_catalog_rows.device.type != "cpu" + or native_callable_catalog_rows.dtype != torch.long + or native_callable_catalog_rows.dim() != 2 + or int(native_callable_catalog_rows.shape[1]) != 8 + ): + raise RuntimeError( + "Temporal fused CUDA launch contract requires CPU int64 native_callable_catalog_rows with shape [N,8]" + ) + if ( + native_callable_binding_schema_rows.device.type != "cpu" + or native_callable_binding_schema_rows.dtype != torch.long + or native_callable_binding_schema_rows.dim() != 2 + or int(native_callable_binding_schema_rows.shape[1]) != 10 + ): + raise RuntimeError( + "Temporal fused CUDA launch contract requires CPU int64 native_callable_binding_schema_rows " + "with shape [N,10]" + ) + if ( + native_callable_output_rows.device.type != "cpu" + or native_callable_output_rows.dtype != torch.long + or native_callable_output_rows.dim() != 2 + or int(native_callable_output_rows.shape[1]) != 12 + ): + raise RuntimeError( + "Temporal fused CUDA launch contract requires CPU int64 native_callable_output_rows with shape [N,12]" + ) + if ( + transition_reverse_seed_role_rows.device.type != "cpu" + or transition_reverse_seed_role_rows.dtype != torch.long + or transition_reverse_seed_role_rows.dim() != 2 + or int(transition_reverse_seed_role_rows.shape[1]) != 4 + ): + raise RuntimeError( + "Temporal fused CUDA launch contract requires CPU int64 transition_reverse_seed_role_rows with shape [N,4]" + ) + if ( + reverse_output_route_rows.device.type != "cpu" + or reverse_output_route_rows.dtype != torch.long + or reverse_output_route_rows.dim() != 2 + or int(reverse_output_route_rows.shape[1]) != 8 + ): + raise RuntimeError( + "Temporal fused CUDA launch contract 
requires CPU int64 reverse_output_route_rows with shape [N,8]" + ) + if ( + forward_artifact_route_rows.device.type != "cpu" + or forward_artifact_route_rows.dtype != torch.long + or forward_artifact_route_rows.dim() != 2 + or int(forward_artifact_route_rows.shape[1]) != 10 + ): + raise RuntimeError( + "Temporal fused CUDA launch contract requires CPU int64 forward_artifact_route_rows with shape [N,10]" + ) + if ( + forward_artifact_merge_rows.device.type != "cpu" + or forward_artifact_merge_rows.dtype != torch.long + or forward_artifact_merge_rows.dim() != 2 + or int(forward_artifact_merge_rows.shape[1]) != 12 + ): + raise RuntimeError( + "Temporal fused CUDA launch contract requires CPU int64 forward_artifact_merge_rows with shape [N,12]" + ) + if ( + forward_output_route_rows.device.type != "cpu" + or forward_output_route_rows.dtype != torch.long + or forward_output_route_rows.dim() != 2 + or int(forward_output_route_rows.shape[1]) != 10 + ): + raise RuntimeError( + "Temporal fused CUDA launch contract requires CPU int64 forward_output_route_rows with shape [N,10]" + ) + if ( + reverse_parameter_reducer_route_rows.device.type != "cpu" + or reverse_parameter_reducer_route_rows.dtype != torch.long + or reverse_parameter_reducer_route_rows.dim() != 2 + or int(reverse_parameter_reducer_route_rows.shape[1]) != 12 + ): + raise RuntimeError( + "Temporal fused CUDA launch contract requires CPU int64 reverse_parameter_reducer_route_rows " + "with shape [N,12]" + ) + if ( + readout_message_producer_consumer_rows.device.type != "cpu" + or readout_message_producer_consumer_rows.dtype != torch.long + or readout_message_producer_consumer_rows.dim() != 2 + or int(readout_message_producer_consumer_rows.shape[1]) != 16 + ): + raise RuntimeError( + "Temporal fused CUDA launch contract requires CPU int64 readout_message_producer_consumer_rows " + "with shape [N,16]" + ) + if ( + message_transition_producer_consumer_rows.device.type != "cpu" + or message_transition_producer_consumer_rows.dtype != torch.long + or message_transition_producer_consumer_rows.dim() != 2 + or int(message_transition_producer_consumer_rows.shape[1]) != 16 + ): + raise RuntimeError( + "Temporal fused CUDA launch contract requires CPU int64 message_transition_producer_consumer_rows " + "with shape [N,16]" + ) + if ( + reverse_artifact_consumer_route_rows.device.type != "cpu" + or reverse_artifact_consumer_route_rows.dtype != torch.long + or reverse_artifact_consumer_route_rows.dim() != 2 + or int(reverse_artifact_consumer_route_rows.shape[1]) != 12 + ): + raise RuntimeError( + "Temporal fused CUDA launch contract requires CPU int64 reverse_artifact_consumer_route_rows " + "with shape [N,12]" + ) + return TemporalFusedCudaLaunchContract( + schema_version=1, + forward_entrypoint="registered_temporal_fused_forward_program_cuda", + backward_entrypoint="registered_temporal_fused_backward_program_cuda", + required_tables=( + "primitive_rows", + "forward_executor_rows", + "reverse_executor_rows", + "forward_handler_rows", + "reverse_handler_rows", + "native_strategy_rows", + "native_callable_catalog_rows", + "native_callable_binding_schema_rows", + "native_callable_output_rows", + "transition_reverse_seed_role_rows", + "reverse_output_route_rows", + "forward_artifact_route_rows", + "forward_artifact_merge_rows", + "forward_output_route_rows", + "readout_message_producer_consumer_rows", + "message_transition_producer_consumer_rows", + "reverse_artifact_consumer_route_rows", + "reverse_parameter_reducer_route_rows", + 
"forward_executor_binding_rows", + "reverse_executor_binding_rows", + "memory_liveness_plan", + "memory_liveness_rows", + "physical_strategy_rows", + "forward_program_runtime_rows", + "reverse_program_runtime_rows", + ), + primitive_row_count=int(primitive_rows.shape[0]), + forward_executor_row_count=int(forward_plan.forward_executor_rows.shape[0]), + reverse_executor_row_count=int(backward_plan.reverse_executor_rows.shape[0]), + forward_handler_row_count=int(forward_handler_rows.shape[0]), + reverse_handler_row_count=int(reverse_handler_rows.shape[0]), + native_strategy_row_count=int(native_strategy_rows.shape[0]), + native_callable_catalog_row_count=int(native_callable_catalog_rows.shape[0]), + native_callable_binding_schema_row_count=int(native_callable_binding_schema_rows.shape[0]), + native_callable_output_row_count=int(native_callable_output_rows.shape[0]), + transition_reverse_seed_role_row_count=int(transition_reverse_seed_role_rows.shape[0]), + reverse_output_route_row_count=int(reverse_output_route_rows.shape[0]), + forward_artifact_route_row_count=int(forward_artifact_route_rows.shape[0]), + forward_artifact_merge_row_count=int(forward_artifact_merge_rows.shape[0]), + forward_output_route_row_count=int(forward_output_route_rows.shape[0]), + readout_message_producer_consumer_row_count=int(readout_message_producer_consumer_rows.shape[0]), + message_transition_producer_consumer_row_count=int(message_transition_producer_consumer_rows.shape[0]), + reverse_artifact_consumer_route_row_count=int(reverse_artifact_consumer_route_rows.shape[0]), + reverse_parameter_reducer_route_row_count=int(reverse_parameter_reducer_route_rows.shape[0]), + forward_binding_row_count=int(forward_plan.executor_binding_rows.shape[0]), + reverse_binding_row_count=int(backward_plan.executor_binding_rows.shape[0]), + memory_entry_count=len(memory_plan.entries), + memory_liveness_row_count=int(memory_liveness_rows.shape[0]), + workspace_policy=memory_plan.workspace_policy, + layout_policy=memory_plan.layout_policy, + alias_policy=memory_plan.alias_policy, + demotion_policy="fail_closed_no_unregistered_program_demotion", + unsupported_policy="typed_strategy_and_binding_rejection", + ) + + +def temporal_native_executor_strategy_rows_tensor() -> torch.Tensor: + rows: list[list[int]] = [] + registry = temporal_executor_strategy_registry() + rows.extend(_native_strategy_row(direction="forward", pattern=pattern) for pattern in registry.forward_patterns()) + rows.extend(_native_strategy_row(direction="reverse", pattern=pattern) for pattern in registry.reverse_patterns()) + unique_rows = tuple(dict.fromkeys(tuple(row) for row in rows)) + if not unique_rows: + return torch.empty((0, 17), dtype=torch.long) + return torch.tensor([list(row) for row in unique_rows], dtype=torch.long) + + +def temporal_transition_primitive_native_callable_rows_tensor() -> torch.Tensor: + rows: list[list[int]] = [] + for record in registered_transition_primitive_executor_records(): + if record.program_layer_status != "callable": + continue + rows.append( + [ + int(temporal_primitive_opcode(record.primitive)), + int(temporal_strategy_id_hash(record.program_forward_symbol)) if record.program_forward_symbol else 0, + int(temporal_strategy_id_hash(record.program_backward_symbol)) if record.program_backward_symbol else 0, + int(record.program_layer_status == "callable"), + int(record.program_forward_status == "callable"), + int(record.program_backward_status == "callable"), + ] + ) + unique_rows = tuple(dict.fromkeys(tuple(row) for row in rows)) + 
if not unique_rows: + return torch.empty((0, 6), dtype=torch.long) + return torch.tensor([list(row) for row in unique_rows], dtype=torch.long) + + +def _native_strategy_row( + *, + direction: Literal["forward", "reverse"], + pattern: TemporalForwardExecutorPattern | TemporalReverseExecutorPattern, +) -> list[int]: + if not pattern.row_pattern: + raise RuntimeError( + "Registered temporal native strategy row requires a primitive row schema: " + f"direction={direction}; strategy={pattern.stable_strategy_id}" + ) + primitive = pattern.row_pattern[0].primitive + return [ + int(_DIRECTION_OPCODE[direction]), + int(temporal_surface_opcode(pattern.surface)), + int(pattern.executor_id), + int(pattern.stable_handler_kind_opcode), + int(temporal_primitive_opcode(primitive)), + int(len(pattern.row_pattern)), + int(_capability_mask(pattern.stable_handler_capabilities)), + int(_effect_mask(pattern.stable_handler_effects)), + int(pattern.row_schema_version), + int(pattern.tensor_binding_schema_version), + int(pattern.metadata_schema_version), + int(pattern.cuda_kernel_abi_version), + int(temporal_strategy_id_hash(pattern.stable_strategy_id)), + int(len(pattern.program_accesses)), + int(len(pattern.state_carry_rules)), + int(bool(pattern.verified_rewrite_required)), + int(temporal_strategy_id_hash(pattern.stable_native_callable_id)), + ] + + +def temporal_forward_executor_handler_rows_tensor(table: TemporalPrimitiveTablePlan) -> torch.Tensor: + rows = tuple( + dict.fromkeys(tuple(_forward_handler_row(table, row)) for row in temporal_forward_executor_rows(table)) + ) + if not rows: + return torch.empty((0, 11), dtype=torch.long) + return torch.tensor([list(row) for row in rows], dtype=torch.long) + + +def temporal_reverse_executor_handler_rows_tensor(table: TemporalPrimitiveTablePlan) -> torch.Tensor: + rows = tuple( + dict.fromkeys(tuple(_reverse_handler_row(table, row)) for row in temporal_reverse_executor_rows(table)) + ) + if not rows: + return torch.empty((0, 11), dtype=torch.long) + return torch.tensor([list(row) for row in rows], dtype=torch.long) + + +def _forward_handler_row( + table: TemporalPrimitiveTablePlan, + row: TemporalForwardExecutorRow, +) -> list[int]: + strategy = _forward_handler_strategy(table, row) + return _handler_row( + table, + executor_id=int(row.executor_id), + surface=str(row.surface), + primitive_row_start=int(row.primitive_row_start), + primitive_row_count=int(row.primitive_row_count), + handler_kind=int(strategy.stable_handler_kind_opcode), + capability_flags=_capability_mask(strategy.stable_handler_capabilities), + required_effects=strategy.stable_handler_effects, + strategy_hash=temporal_strategy_id_hash(strategy.stable_strategy_id), + program_access_count=len(strategy.program_accesses), + state_carry_rule_count=len(strategy.state_carry_rules), + verified_rewrite_required=bool(strategy.verified_rewrite_required), + ) + + +def _reverse_handler_row( + table: TemporalPrimitiveTablePlan, + row: TemporalReverseExecutorRow, +) -> list[int]: + strategy = _reverse_handler_strategy(table, row) + return _handler_row( + table, + executor_id=int(row.executor_id), + surface=str(row.surface), + primitive_row_start=int(row.primitive_row_start), + primitive_row_count=int(row.primitive_row_count), + handler_kind=int(strategy.stable_handler_kind_opcode), + capability_flags=_capability_mask(strategy.stable_handler_capabilities), + required_effects=strategy.stable_handler_effects, + strategy_hash=temporal_strategy_id_hash(strategy.stable_strategy_id), + 
program_access_count=len(strategy.program_accesses), + state_carry_rule_count=len(strategy.state_carry_rules), + verified_rewrite_required=bool(strategy.verified_rewrite_required), + ) + + +def _forward_handler_strategy( + table: TemporalPrimitiveTablePlan, + row: TemporalForwardExecutorRow, +) -> TemporalForwardExecutorPattern: + strategy = temporal_executor_strategy_registry().match_forward( + surface=str(row.surface), + bucket_ordinal=int(row.bucket_ordinal), + rows=_executor_primitive_rows(table, int(row.primitive_row_start), int(row.primitive_row_count)), + ) + if ( + strategy is None + or strategy.executor_name != row.executor_name + or int(strategy.executor_id) != int(row.executor_id) + ): + raise RuntimeError( + "Temporal fused CUDA forward executor has no registered handler strategy: " + f"executor={row.executor_name!r}; executor_id={int(row.executor_id)}; " + f"surface={row.surface!r}; bucket={int(row.bucket_ordinal)}" + ) + return strategy + + +def _reverse_handler_strategy( + table: TemporalPrimitiveTablePlan, + row: TemporalReverseExecutorRow, +) -> TemporalReverseExecutorPattern: + strategy = temporal_executor_strategy_registry().match_reverse( + surface=str(row.surface), + bucket_ordinal=int(row.bucket_ordinal), + rows=_executor_primitive_rows(table, int(row.primitive_row_start), int(row.primitive_row_count)), + ) + if ( + strategy is None + or strategy.executor_name != row.executor_name + or int(strategy.executor_id) != int(row.executor_id) + ): + raise RuntimeError( + "Temporal fused CUDA reverse executor has no registered handler strategy: " + f"executor={row.executor_name!r}; executor_id={int(row.executor_id)}; " + f"surface={row.surface!r}; bucket={int(row.bucket_ordinal)}" + ) + return strategy + + +def _executor_primitive_rows( + table: TemporalPrimitiveTablePlan, + primitive_row_start: int, + primitive_row_count: int, +) -> tuple[object, ...]: + if primitive_row_start < 0 or primitive_row_count <= 0: + return () + primitive_row_end = int(primitive_row_start) + int(primitive_row_count) + return tuple(table.primitive_rows[int(primitive_row_start) : primitive_row_end]) + + +def _handler_row( + table: TemporalPrimitiveTablePlan, + *, + executor_id: int, + surface: str, + primitive_row_start: int, + primitive_row_count: int, + handler_kind: int, + capability_flags: int, + required_effects: tuple[str, ...], + strategy_hash: int, + program_access_count: int, + state_carry_rule_count: int, + verified_rewrite_required: bool, +) -> list[int]: + if primitive_row_start < 0 or primitive_row_start >= len(table.primitive_rows): + raise RuntimeError( + "Temporal fused CUDA executor handler references no primitive row: " + f"executor_id={int(executor_id)}; surface={surface!r}; row={int(primitive_row_start)}" + ) + primitive = table.primitive_rows[int(primitive_row_start)].primitive + return [ + int(executor_id), + int(temporal_surface_opcode(surface)), + int(handler_kind), + int(temporal_primitive_opcode(primitive)), + int(primitive_row_count), + int(capability_flags), + int(_effect_mask(required_effects)), + int(strategy_hash), + int(program_access_count), + int(state_carry_rule_count), + int(bool(verified_rewrite_required)), + ] + + +def _effect_mask(effects: tuple[str, ...]) -> int: + mask = 0 + for effect in effects: + flag = _HANDLER_EFFECT_FLAG.get(effect) + if flag is None: + raise RuntimeError(f"Temporal fused CUDA handler has no effect flag for {effect!r}") + mask |= int(flag) + return mask + + +def _capability_mask(capabilities: tuple[str, ...]) -> int: + mask = 0 + for 
capability in capabilities: + flag = _HANDLER_CAPABILITY_FLAG.get(capability) + if flag is None: + raise RuntimeError(f"Temporal fused CUDA handler has no capability flag for {capability!r}") + mask |= int(flag) + if mask <= 0: + raise RuntimeError("Temporal fused CUDA handler strategy declares no runtime capability") + return mask + + +__all__ = [ + "TemporalReverseProgramStageKind", + "TemporalForwardArtifactMergeKind", + "TemporalForwardArtifactMergeRow", + "TemporalForwardArtifactRouteRow", + "TemporalForwardOutputRouteKind", + "TemporalForwardOutputRouteRow", + "TemporalMessageTransitionProducerConsumerPlan", + "TemporalMessageTransitionProducerConsumerRow", + "TemporalMessageTransitionProducerConsumerStatus", + "TemporalMessageTransitionProducerConsumerStrategy", + "TemporalReadoutMessageProducerConsumerPlan", + "TemporalReadoutMessageProducerConsumerRow", + "TemporalReadoutMessageProducerConsumerStatus", + "TemporalReadoutMessageProducerConsumerStrategy", + "TemporalReverseArtifactConsumerRouteRow", + "TemporalReverseProgramStagePlan", + "TemporalReverseProgramStageRow", + "TemporalReverseParameterReducerRouteRow", + "TemporalReverseSpanOutputGroup", + "TemporalReverseSpanOutputRow", + "TemporalReverseOutputRouteKind", + "TemporalReverseOutputRouteRow", + "TemporalFusedCudaLaunchContract", + "TemporalFusedCudaProgramPlan", + "TemporalRegisteredProgramExecutorPlan", + "build_temporal_forward_artifact_merge_rows", + "build_temporal_forward_artifact_route_rows", + "build_temporal_forward_output_route_rows", + "build_temporal_message_transition_producer_consumer_plan", + "build_temporal_readout_message_producer_consumer_plan", + "build_temporal_reverse_artifact_consumer_route_rows", + "build_temporal_reverse_parameter_reducer_route_rows", + "build_temporal_reverse_program_stage_plan", + "build_temporal_fused_cuda_program_plan", + "build_temporal_registered_program_executor_plan", + "reverse_program_stage_opcode", + "temporal_forward_artifact_merge_rows_tensor", + "temporal_forward_artifact_merge_summaries", + "temporal_forward_artifact_route_rows_tensor", + "temporal_forward_artifact_route_summaries", + "temporal_forward_output_route_rows_tensor", + "temporal_forward_output_route_summaries", + "temporal_message_transition_producer_consumer_rows_tensor", + "temporal_message_transition_producer_consumer_summaries", + "temporal_readout_message_producer_consumer_rows_tensor", + "temporal_readout_message_producer_consumer_summaries", + "temporal_forward_executor_handler_rows_tensor", + "temporal_native_callable_binding_schema_rows_tensor", + "temporal_native_callable_catalog_rows_tensor", + "temporal_native_callable_output_rows_tensor", + "temporal_native_executor_strategy_rows_tensor", + "temporal_reverse_artifact_consumer_route_rows_tensor", + "temporal_reverse_artifact_consumer_route_summaries", + "temporal_reverse_parameter_reducer_route_rows_tensor", + "temporal_reverse_parameter_reducer_route_summaries", + "temporal_reverse_span_output_group_opcode", + "temporal_reverse_span_output_role_id", + "temporal_reverse_span_output_rows", + "temporal_reverse_span_output_rows_tensor", + "temporal_reverse_span_output_summaries", + "temporal_reverse_output_route_kind_opcode", + "temporal_reverse_output_route_target_id", + "temporal_reverse_output_route_rows", + "temporal_reverse_output_route_rows_tensor", + "temporal_reverse_output_route_summaries", + "temporal_reverse_executor_handler_rows_tensor", + "temporal_strategy_id_hash", + "temporal_transition_primitive_native_callable_rows_tensor", +] 
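
Not part of the patch: a minimal sanity-check sketch for the two module-level catalog helpers ending above. The import path is an assumption inferred from the `.program_execution` imports that appear later in this diff (runtime_metadata.py), and the expected column widths (17 and 6) come from the empty-tensor fallbacks in the code above.

import torch

# Assumed module path; the file ending above is imported elsewhere in this diff as `.program_execution`.
from cortical.fabric.backend.cuda.sequence_surface.compiler.program_execution import (
    temporal_native_executor_strategy_rows_tensor,
    temporal_transition_primitive_native_callable_rows_tensor,
)

strategy_rows = temporal_native_executor_strategy_rows_tensor()
callable_rows = temporal_transition_primitive_native_callable_rows_tensor()

# Both helpers fail closed to an empty (0, N) long tensor instead of returning ragged rows.
assert strategy_rows.dtype == torch.long and strategy_rows.shape[1] == 17
assert callable_rows.dtype == torch.long and callable_rows.shape[1] == 6

# Rows are deduplicated at build time, so a repeated registration never widens the catalog.
if strategy_rows.numel():
    assert strategy_rows.shape[0] == torch.unique(strategy_rows, dim=0).shape[0]
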
diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_runtime.py b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_runtime.py new file mode 100644 index 00000000..2ee2158b --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_runtime.py @@ -0,0 +1,693 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Literal + +import torch + + +TemporalForwardProgramRuntimeRole = Literal[ + "recurrent_local_sender_idx", + "output_local_sender_idx", + "local_distance", + "local_delay", + "inner_steps", + "output_boundary_terminal", + "distance_scale", + "head_dim", + "value_dim", + "use_delay", +] +TemporalReverseProgramRuntimeRole = Literal[ + "graph_to_backend_order", + "backend_to_graph_inverse_order", + "output_local_sender_idx", + "local_distance", + "local_delay", + "output_neighbor_idx", + "output_neighbor_valid", + "output_edge_distance", + "output_edge_delay", + "recurrent_local_sender_idx", + "recurrent_neighbor_idx", + "recurrent_neighbor_valid", + "recurrent_edge_distance", + "recurrent_edge_delay", + "message_step_indices", + "input_count", + "recurrent_count", + "distance_scale", + "use_sparse_messages", + "use_delay", + "group_size", + "head_dim", + "value_dim", + "return_boundary_grad", +] +TemporalProgramRuntimeSupportDirection = Literal["forward", "reverse"] +TemporalProgramRuntimeRequirement = Literal[ + "output_contract", + "readout_pool", + "final_state_materialization", + "boundary_device_dtype", + "local_message_step", + "artifact_storage", + "window_length", + "grad_output_window", + "grad_carry_device_dtype", + "grad_carry_materialization_policy", + "reverse_artifact_roles", +] + +_FORWARD_RUNTIME_ROLE_OPCODE: dict[TemporalForwardProgramRuntimeRole, int] = { + "recurrent_local_sender_idx": 1, + "output_local_sender_idx": 2, + "local_distance": 3, + "local_delay": 4, + "inner_steps": 5, + "output_boundary_terminal": 6, + "distance_scale": 7, + "head_dim": 8, + "value_dim": 9, + "use_delay": 10, +} +_REVERSE_RUNTIME_ROLE_OPCODE: dict[TemporalReverseProgramRuntimeRole, int] = { + "graph_to_backend_order": 1, + "backend_to_graph_inverse_order": 2, + "output_local_sender_idx": 3, + "local_distance": 4, + "local_delay": 5, + "output_neighbor_idx": 6, + "output_neighbor_valid": 7, + "output_edge_distance": 8, + "output_edge_delay": 9, + "recurrent_local_sender_idx": 10, + "recurrent_neighbor_idx": 11, + "recurrent_neighbor_valid": 12, + "recurrent_edge_distance": 13, + "recurrent_edge_delay": 14, + "message_step_indices": 15, + "input_count": 16, + "recurrent_count": 17, + "distance_scale": 18, + "use_sparse_messages": 19, + "use_delay": 20, + "group_size": 21, + "head_dim": 22, + "value_dim": 23, + "return_boundary_grad": 24, +} +_DTYPE_OPCODE = { + torch.float16: 1, + torch.bfloat16: 2, + torch.float32: 3, + torch.float64: 4, + torch.int32: 5, + torch.int64: 6, + torch.bool: 7, +} +_DEVICE_OPCODE = { + "cpu": 1, + "cuda": 2, +} +_SUPPORT_DIRECTION_OPCODE: dict[TemporalProgramRuntimeSupportDirection, int] = { + "forward": 1, + "reverse": 2, +} +_RUNTIME_REQUIREMENT_OPCODE: dict[TemporalProgramRuntimeRequirement, int] = { + "output_contract": 1, + "readout_pool": 2, + "final_state_materialization": 3, + "boundary_device_dtype": 4, + "local_message_step": 5, + "artifact_storage": 6, + "window_length": 7, + "grad_output_window": 8, + "grad_carry_device_dtype": 9, + "reverse_artifact_roles": 10, + "grad_carry_materialization_policy": 11, +} + + 
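+# A support check row encodes as [direction_opcode, requirement_opcode, legal, reason_hash, 1];
+# for example, a failing reverse "window_length" check serializes as [2, 7, 0, <reason hash>, 1].
+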
+@dataclass(frozen=True) +class TemporalForwardProgramRuntimeFact: + role: TemporalForwardProgramRuntimeRole + tensor: torch.Tensor + + @property + def role_opcode(self) -> int: + return int(_FORWARD_RUNTIME_ROLE_OPCODE[self.role]) + + @property + def dtype_opcode(self) -> int: + dtype_opcode = _DTYPE_OPCODE.get(self.tensor.dtype) + if dtype_opcode is None: + raise RuntimeError( + "Registered forward program runtime fact has unsupported dtype: " + f"role={self.role}; dtype={self.tensor.dtype}" + ) + return int(dtype_opcode) + + @property + def device_opcode(self) -> int: + device_opcode = _DEVICE_OPCODE.get(str(self.tensor.device.type)) + if device_opcode is None: + raise RuntimeError( + "Registered forward program runtime fact has unsupported device: " + f"role={self.role}; device={self.tensor.device}" + ) + return int(device_opcode) + + @property + def summary(self) -> str: + return ( + f"runtime_fact={self.role}" + f",role_opcode={self.role_opcode}" + f",shape={_shape_summary(tuple(int(item) for item in self.tensor.shape))}" + f",dtype={self.tensor.dtype}" + f",device={self.tensor.device.type}" + ) + + +@dataclass(frozen=True) +class TemporalReverseProgramRuntimeFact: + role: TemporalReverseProgramRuntimeRole + tensor: torch.Tensor + + @property + def role_opcode(self) -> int: + return int(_REVERSE_RUNTIME_ROLE_OPCODE[self.role]) + + @property + def dtype_opcode(self) -> int: + dtype_opcode = _DTYPE_OPCODE.get(self.tensor.dtype) + if dtype_opcode is None: + raise RuntimeError( + "Registered reverse program runtime fact has unsupported dtype: " + f"role={self.role}; dtype={self.tensor.dtype}" + ) + return int(dtype_opcode) + + @property + def device_opcode(self) -> int: + device_opcode = _DEVICE_OPCODE.get(str(self.tensor.device.type)) + if device_opcode is None: + raise RuntimeError( + "Registered reverse program runtime fact has unsupported device: " + f"role={self.role}; device={self.tensor.device}" + ) + return int(device_opcode) + + @property + def summary(self) -> str: + return ( + f"runtime_fact={self.role}" + f",role_opcode={self.role_opcode}" + f",shape={_shape_summary(tuple(int(item) for item in self.tensor.shape))}" + f",dtype={self.tensor.dtype}" + f",device={self.tensor.device.type}" + ) + + +@dataclass(frozen=True) +class TemporalForwardProgramRuntimePlan: + facts: tuple[TemporalForwardProgramRuntimeFact, ...] + + @property + def tensors(self) -> tuple[torch.Tensor, ...]: + return tuple(fact.tensor for fact in self.facts) + + @property + def rows(self) -> torch.Tensor: + rows = [ + [ + int(fact.role_opcode), + int(tensor_index), + int(fact.dtype_opcode), + int(fact.tensor.dim()), + int(fact.device_opcode), + 1, + ] + for tensor_index, fact in enumerate(self.facts) + ] + return torch.tensor(rows, dtype=torch.long) + + @property + def review_summary(self) -> tuple[str, ...]: + return ( + "forward_program_runtime_plan=compiler_owned", + f"runtime_fact_count={len(self.facts)}", + *(fact.summary for fact in self.facts), + ) + + +@dataclass(frozen=True) +class TemporalReverseProgramRuntimePlan: + facts: tuple[TemporalReverseProgramRuntimeFact, ...] 
+ + @property + def tensors(self) -> tuple[torch.Tensor, ...]: + return tuple(fact.tensor for fact in self.facts) + + @property + def rows(self) -> torch.Tensor: + rows = [ + [ + int(fact.role_opcode), + int(tensor_index), + int(fact.dtype_opcode), + int(fact.tensor.dim()), + int(fact.device_opcode), + 1, + ] + for tensor_index, fact in enumerate(self.facts) + ] + return torch.tensor(rows, dtype=torch.long) + + @property + def review_summary(self) -> tuple[str, ...]: + return ( + "reverse_program_runtime_plan=compiler_owned", + f"runtime_fact_count={len(self.facts)}", + *(fact.summary for fact in self.facts), + ) + + +@dataclass(frozen=True) +class TemporalProgramRuntimeSupportCheck: + direction: TemporalProgramRuntimeSupportDirection + requirement: TemporalProgramRuntimeRequirement + legal: bool + reason: str + + @property + def row(self) -> list[int]: + return [ + int(_SUPPORT_DIRECTION_OPCODE[self.direction]), + int(_RUNTIME_REQUIREMENT_OPCODE[self.requirement]), + int(bool(self.legal)), + int(_stable_reason_hash(self.reason)), + 1, + ] + + @property + def summary(self) -> str: + return ( + f"direction={self.direction},requirement={self.requirement}," + f"legal={int(bool(self.legal))},reason={self.reason or '-'}" + ) + + +@dataclass(frozen=True) +class TemporalProgramRuntimeSupportPlan: + direction: TemporalProgramRuntimeSupportDirection + checks: tuple[TemporalProgramRuntimeSupportCheck, ...] + + @property + def rows(self) -> torch.Tensor: + if not self.checks: + return torch.empty((0, 5), dtype=torch.long) + return torch.tensor([check.row for check in self.checks], dtype=torch.long) + + @property + def rejection_reason(self) -> str | None: + for check in self.checks: + if not bool(check.legal): + return check.reason + return None + + @property + def review_summary(self) -> tuple[str, ...]: + status = "legal" if self.rejection_reason is None else "blocked" + return ( + f"{self.direction}_program_runtime_support=compiler_owned", + f"status={status}", + f"runtime_requirement_count={len(self.checks)}", + *(check.summary for check in self.checks), + ) + + +def temporal_forward_program_runtime_role_opcode(role: TemporalForwardProgramRuntimeRole) -> int: + return int(_FORWARD_RUNTIME_ROLE_OPCODE[role]) + + +def temporal_reverse_program_runtime_role_opcode(role: TemporalReverseProgramRuntimeRole) -> int: + return int(_REVERSE_RUNTIME_ROLE_OPCODE[role]) + + +def build_temporal_forward_program_runtime_plan( + runtime: Any, + *, + boundary_seq: torch.Tensor, + inner_steps: int, + output_boundary_terminal: bool, +) -> TemporalForwardProgramRuntimePlan: + scalar_device = torch.device("cpu") + facts = ( + TemporalForwardProgramRuntimeFact( + "recurrent_local_sender_idx", + runtime.recurrent_local_sender_idx_flat_bucket_carry_order, + ), + TemporalForwardProgramRuntimeFact( + "output_local_sender_idx", + runtime.output_local_sender_idx_flat_bucket_carry_order, + ), + TemporalForwardProgramRuntimeFact("local_distance", runtime.local_distance), + TemporalForwardProgramRuntimeFact("local_delay", runtime.local_delay), + TemporalForwardProgramRuntimeFact( + "inner_steps", + torch.tensor([int(inner_steps)], dtype=torch.long, device=scalar_device), + ), + TemporalForwardProgramRuntimeFact( + "output_boundary_terminal", + torch.tensor([1 if bool(output_boundary_terminal) else 0], dtype=torch.long, device=scalar_device), + ), + TemporalForwardProgramRuntimeFact( + "distance_scale", + torch.tensor( + [float(runtime.config.message.distance_logit_scale)], + dtype=torch.float64, + device=scalar_device, + ), + ), 
+ TemporalForwardProgramRuntimeFact( + "head_dim", + torch.tensor([int(runtime.head_dim)], dtype=torch.long, device=scalar_device), + ), + TemporalForwardProgramRuntimeFact( + "value_dim", + torch.tensor([int(runtime.value_dim)], dtype=torch.long, device=scalar_device), + ), + TemporalForwardProgramRuntimeFact( + "use_delay", + torch.tensor([1 if bool(runtime._has_edge_delay) else 0], dtype=torch.long, device=scalar_device), + ), + ) + _validate_forward_runtime_fact_devices(boundary_seq, facts) + return TemporalForwardProgramRuntimePlan(facts=facts) + + +def build_temporal_reverse_program_runtime_plan( + runtime: Any, + *, + reference_boundary: torch.Tensor, + message_step_indices: torch.Tensor, + return_boundary_grad: bool, + use_sparse_messages: bool, +) -> TemporalReverseProgramRuntimePlan: + scalar_device = torch.device("cpu") + facts = ( + TemporalReverseProgramRuntimeFact("graph_to_backend_order", runtime.population_backend_recurrent_order), + TemporalReverseProgramRuntimeFact( + "backend_to_graph_inverse_order", + runtime.population_backend_recurrent_inverse_order.to( + device=reference_boundary.device, + dtype=torch.long, + ), + ), + TemporalReverseProgramRuntimeFact( + "output_local_sender_idx", + runtime.output_local_sender_idx_flat_bucket_carry_order, + ), + TemporalReverseProgramRuntimeFact("local_distance", runtime.local_distance), + TemporalReverseProgramRuntimeFact("local_delay", runtime.local_delay), + TemporalReverseProgramRuntimeFact("output_neighbor_idx", runtime.output_neighbor_idx), + TemporalReverseProgramRuntimeFact("output_neighbor_valid", runtime.output_neighbor_valid), + TemporalReverseProgramRuntimeFact("output_edge_distance", runtime.output_edge_distance), + TemporalReverseProgramRuntimeFact("output_edge_delay", runtime.output_edge_delay), + TemporalReverseProgramRuntimeFact( + "recurrent_local_sender_idx", + runtime.recurrent_local_sender_idx_flat_bucket_carry_order, + ), + TemporalReverseProgramRuntimeFact( + "recurrent_neighbor_idx", + runtime.recurrent_neighbor_idx_flat_bucket_carry_order, + ), + TemporalReverseProgramRuntimeFact( + "recurrent_neighbor_valid", + runtime.recurrent_neighbor_valid_backend_order, + ), + TemporalReverseProgramRuntimeFact( + "recurrent_edge_distance", + runtime.recurrent_edge_distance_backend_order, + ), + TemporalReverseProgramRuntimeFact("recurrent_edge_delay", runtime.recurrent_edge_delay_backend_order), + TemporalReverseProgramRuntimeFact("message_step_indices", message_step_indices.contiguous()), + TemporalReverseProgramRuntimeFact( + "input_count", + torch.tensor([int(reference_boundary.shape[1])], dtype=torch.long, device=scalar_device), + ), + TemporalReverseProgramRuntimeFact( + "recurrent_count", + torch.tensor([int(runtime.recurrent_cell_idx.numel())], dtype=torch.long, device=scalar_device), + ), + TemporalReverseProgramRuntimeFact( + "distance_scale", + torch.tensor( + [float(runtime.config.message.distance_logit_scale)], + dtype=torch.float64, + device=scalar_device, + ), + ), + TemporalReverseProgramRuntimeFact( + "use_sparse_messages", + torch.tensor([1 if bool(use_sparse_messages) else 0], dtype=torch.long, device=scalar_device), + ), + TemporalReverseProgramRuntimeFact( + "use_delay", + torch.tensor([1 if bool(runtime._has_edge_delay) else 0], dtype=torch.long, device=scalar_device), + ), + TemporalReverseProgramRuntimeFact( + "group_size", + torch.tensor([int(runtime._input_sender_kv_group_size)], dtype=torch.long, device=scalar_device), + ), + TemporalReverseProgramRuntimeFact( + "head_dim", + 
torch.tensor([int(runtime.head_dim)], dtype=torch.long, device=scalar_device), + ), + TemporalReverseProgramRuntimeFact( + "value_dim", + torch.tensor([int(runtime.value_dim)], dtype=torch.long, device=scalar_device), + ), + TemporalReverseProgramRuntimeFact( + "return_boundary_grad", + torch.tensor([1 if bool(return_boundary_grad) else 0], dtype=torch.long, device=scalar_device), + ), + ) + _validate_reverse_runtime_fact_devices(reference_boundary, facts) + return TemporalReverseProgramRuntimePlan(facts=facts) + + +def build_temporal_forward_program_runtime_support_plan( + runtime: Any, + *, + boundary_seq: torch.Tensor, + output_contract: str, + readout_pool: str, + materialize_final_state: bool, + collect_artifacts: bool, + memory_artifact_plan: Any, +) -> TemporalProgramRuntimeSupportPlan: + checks = ( + _support_check( + "forward", + "output_contract", + output_contract in {"output_cells", "pooled_output_cells"}, + f"unsupported_output_contract:{output_contract}", + ), + _support_check( + "forward", + "readout_pool", + output_contract != "pooled_output_cells" or readout_pool in {"flatten", "mean"}, + f"unsupported_output_contract:{output_contract};readout_pool={readout_pool}", + ), + _support_check( + "forward", + "final_state_materialization", + not (output_contract == "pooled_output_cells" and bool(materialize_final_state)), + "unsupported_pooled_output_cells_with_materialized_final_state", + ), + _support_check( + "forward", + "boundary_device_dtype", + boundary_seq.device.type == "cuda" and boundary_seq.dtype == torch.float32, + "unsupported_boundary_device_or_dtype", + ), + _support_check( + "forward", + "local_message_step", + bool(getattr(runtime, "_local_message_step_enabled", False)), + "unsupported_local_message_step_disabled", + ), + _support_check( + "forward", + "artifact_storage", + ( + not bool(collect_artifacts) + or ( + str(getattr(memory_artifact_plan, "mode", "")) == "store_step_artifacts" + and bool(getattr(memory_artifact_plan, "store_step_artifacts", False)) + ) + ), + "unsupported_artifact_storage_policy", + ), + ) + return TemporalProgramRuntimeSupportPlan(direction="forward", checks=checks) + + +def build_temporal_reverse_program_runtime_support_plan( + *, + reference_boundary: torch.Tensor, + grad_output_window: torch.Tensor | None, + grad_carry_cells: torch.Tensor | None, + materialize_grad_carry_cells: bool, + local_time_steps: int, + output_contract: str, + readout_pool: str, + reverse_artifact_roles: tuple[str, ...], +) -> TemporalProgramRuntimeSupportPlan: + checks = ( + _support_check( + "reverse", + "window_length", + int(local_time_steps) > 0, + "empty_tensor_store_window", + ), + _support_check( + "reverse", + "output_contract", + output_contract in {"output_cells", "pooled_output_cells"}, + f"unsupported_output_contract:{output_contract}", + ), + _support_check( + "reverse", + "readout_pool", + output_contract != "pooled_output_cells" or readout_pool in {"flatten", "mean"}, + f"unsupported_output_contract:{output_contract};readout_pool={readout_pool}", + ), + _support_check( + "reverse", + "grad_output_window", + torch.is_tensor(grad_output_window) and int(grad_output_window.shape[1]) == int(local_time_steps), + "missing_or_mismatched_grad_output_window", + ), + _support_check( + "reverse", + "grad_carry_device_dtype", + not torch.is_tensor(grad_carry_cells) + or (grad_carry_cells.device.type == "cuda" and grad_carry_cells.dtype == torch.float32), + "unsupported_carry_cell_grad_device_or_dtype", + ), + _support_check( + "reverse", + 
"grad_carry_materialization_policy", + True, + f"materialize_grad_carry_cells={int(bool(materialize_grad_carry_cells))}", + ), + _support_check( + "reverse", + "reverse_artifact_roles", + bool(reverse_artifact_roles), + "missing_reverse_artifact_roles", + ), + _support_check( + "reverse", + "boundary_device_dtype", + reference_boundary.device.type == "cuda" and reference_boundary.dtype == torch.float32, + "unsupported_boundary_device_or_dtype", + ), + ) + return TemporalProgramRuntimeSupportPlan(direction="reverse", checks=checks) + + +def _validate_forward_runtime_fact_devices( + boundary_seq: torch.Tensor, + facts: tuple[TemporalForwardProgramRuntimeFact, ...], +) -> None: + device_roles = { + "recurrent_local_sender_idx", + "output_local_sender_idx", + "local_distance", + "local_delay", + } + for fact in facts: + if fact.role in device_roles and fact.tensor.device != boundary_seq.device: + raise RuntimeError( + "Registered forward program runtime fact must live beside boundary_seq: " + f"role={fact.role}; fact_device={fact.tensor.device}; boundary_device={boundary_seq.device}" + ) + + +def _validate_reverse_runtime_fact_devices( + reference_boundary: torch.Tensor, + facts: tuple[TemporalReverseProgramRuntimeFact, ...], +) -> None: + device_roles = { + "backend_to_graph_inverse_order", + "output_local_sender_idx", + "local_distance", + "local_delay", + "output_neighbor_idx", + "output_neighbor_valid", + "output_edge_distance", + "output_edge_delay", + "recurrent_local_sender_idx", + "recurrent_neighbor_idx", + "recurrent_neighbor_valid", + "recurrent_edge_distance", + "recurrent_edge_delay", + } + cpu_roles = {"message_step_indices"} + for fact in facts: + if fact.role in device_roles and fact.tensor.device != reference_boundary.device: + raise RuntimeError( + "Registered reverse program runtime fact must live beside reference boundary: " + f"role={fact.role}; fact_device={fact.tensor.device}; " + f"boundary_device={reference_boundary.device}" + ) + if fact.role in cpu_roles and fact.tensor.device.type != "cpu": + raise RuntimeError( + "Registered reverse program runtime fact must live on CPU: " + f"role={fact.role}; fact_device={fact.tensor.device}" + ) + + +def _shape_summary(shape: tuple[int, ...]) -> str: + return "scalar" if not shape else "x".join(str(int(item)) for item in shape) + + +def _support_check( + direction: TemporalProgramRuntimeSupportDirection, + requirement: TemporalProgramRuntimeRequirement, + legal: bool, + reason: str, +) -> TemporalProgramRuntimeSupportCheck: + return TemporalProgramRuntimeSupportCheck( + direction=direction, + requirement=requirement, + legal=bool(legal), + reason=reason if requirement == "grad_carry_materialization_policy" else "" if bool(legal) else reason, + ) + + +def _stable_reason_hash(reason: str) -> int: + value = 0 + for byte in str(reason).encode("utf-8"): + value = ((int(value) * 131) + int(byte)) & 0x7FFFFFFF + return int(value) + + +__all__ = [ + "TemporalForwardProgramRuntimeFact", + "TemporalForwardProgramRuntimePlan", + "TemporalProgramRuntimeSupportCheck", + "TemporalProgramRuntimeSupportPlan", + "TemporalReverseProgramRuntimeFact", + "TemporalReverseProgramRuntimePlan", + "build_temporal_forward_program_runtime_support_plan", + "build_temporal_forward_program_runtime_plan", + "build_temporal_reverse_program_runtime_support_plan", + "build_temporal_reverse_program_runtime_plan", + "temporal_forward_program_runtime_role_opcode", + "temporal_reverse_program_runtime_role_opcode", +] diff --git 
a/src/cortical/fabric/backend/cuda/sequence_surface/compiler/reducer_patterns.py b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/reducer_patterns.py new file mode 100644 index 00000000..21cfc77c --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/reducer_patterns.py @@ -0,0 +1,325 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from cortical.fabric.backend.message_rules import ( + MessageRuleParameterReducerSpec, + build_message_rule_backend_spec, + registered_message_rule_backend_spec_types, +) + + +@dataclass(frozen=True) +class TemporalParameterReducerPattern: + reducer_kind: str + reducer_kind_opcode: int + native_callable: str + implementation_symbol: str + count_target: str = "none" + count_mode: str = "none" + active_trainable_roles: tuple[str, ...] = () + required_static_logical_groups: tuple[tuple[str, ...], ...] = () + grad_output_roles: tuple[str, ...] = () + strategy_opcode: int = 0 + strategy_version: int = 1 + cxx_entrypoints: tuple[str, ...] = () + + @property + def stable_strategy_opcode(self) -> int: + return int(self.strategy_opcode or self.reducer_kind_opcode) + + @property + def stable_cxx_entrypoints(self) -> tuple[str, ...]: + return self.cxx_entrypoints or (self.implementation_symbol,) + + +@dataclass(frozen=True) +class TemporalTransitionTrainableReducerPattern: + reducer_kind: str + reducer_kind_opcode: int + native_callable: str + implementation_symbol: str + strategy_version: int = 1 + cxx_entrypoints: tuple[str, ...] = () + + @property + def stable_cxx_entrypoints(self) -> tuple[str, ...]: + return self.cxx_entrypoints or (self.implementation_symbol,) + + +_PARAMETER_REDUCER_COUNT_TARGET_OPCODE = { + "none": 0, + "sender": 1, + "readout": 2, + "recurrent_query": 3, + "output_query": 4, + "message_strategy": 5, +} +_PARAMETER_REDUCER_COUNT_MODE_OPCODE = { + "none": 0, + "tensor_count": 1, + "row": 2, +} +_PARAMETER_TRAINABLE_ROLE_TO_NAME = { + "public_proj_weight": "public_proj.weight", + "k_weight": "k_weight", + "v_weight": "v_weight", + "q_proj_weight": "q_proj.weight", + "slot_embed": "slot_embed", + "msg_out_weight": "msg_out.weight", + "output_cell_weight": "output_cell_weight", + "output_cell_bias": "output_cell_bias", + "message_query_slot_proj_weight": "message_query_slot_proj.weight", + "message_sender_slot_key_proj_weight": "message_sender_slot_key_proj.weight", + "message_query_nudge_scale": "message_query_nudge_scale", + "message_sender_context_key": "message_sender_context_key", + "message_query_context_gate": "message_query_context_gate", +} +_PARAMETER_TRAINABLE_ROLE_OPCODE = {role: index + 1 for index, role in enumerate(_PARAMETER_TRAINABLE_ROLE_TO_NAME)} +_PARAMETER_RUNTIME_ROLE_OPCODE = { + "population_backend_recurrent_inverse_order": 1, + "recurrent_cell_idx": 2, + "output_cell_idx": 3, + "input_cell_idx": 4, +} +_MESSAGE_STRATEGY_GRAD_OUTPUT_ROLE_OPCODE = { + "grad_query_slot_backend": 1, + "grad_input_key_bank": 2, + "grad_recurrent_key_bank": 3, + "grad_query_context_scalar": 4, + "grad_output_weight": 5, +} + +_BASE_PARAMETER_REDUCER_PATTERNS = ( + TemporalParameterReducerPattern( + reducer_kind="readout_output", + reducer_kind_opcode=1, + native_callable="native.reverse.parameter_reduction.readout_output.v1", + implementation_symbol="run_registered_readout_output_parameter_reducer_strategy", + count_target="readout", + count_mode="tensor_count", + active_trainable_roles=("msg_out_weight", "output_cell_weight", "output_cell_bias"), + ), + 
TemporalParameterReducerPattern( + reducer_kind="sender_kv_projection", + reducer_kind_opcode=2, + native_callable="native.reverse.parameter_reduction.sender_kv_projection.v1", + implementation_symbol="run_registered_sender_kv_parameter_reducer_strategy", + count_target="sender", + count_mode="tensor_count", + active_trainable_roles=("public_proj_weight", "k_weight", "v_weight"), + ), + TemporalParameterReducerPattern( + reducer_kind="recurrent_query", + reducer_kind_opcode=3, + native_callable="native.reverse.parameter_reduction.recurrent_query.v1", + implementation_symbol="run_registered_recurrent_query_parameter_reducer_strategy", + count_target="recurrent_query", + count_mode="row", + active_trainable_roles=("slot_embed", "q_proj_weight"), + ), + TemporalParameterReducerPattern( + reducer_kind="transition", + reducer_kind_opcode=4, + native_callable="native.reverse.parameter_reduction.transition.v1", + implementation_symbol="run_registered_transition_parameter_reducer_strategy", + ), + TemporalParameterReducerPattern( + reducer_kind="output_query", + reducer_kind_opcode=5, + native_callable="native.reverse.parameter_reduction.output_query.v1", + implementation_symbol="run_registered_output_query_parameter_reducer_strategy", + count_target="output_query", + count_mode="row", + active_trainable_roles=("slot_embed", "q_proj_weight"), + ), +) + + +def _message_rule_parameter_reducer_patterns() -> tuple[TemporalParameterReducerPattern, ...]: + reducers_by_kind: dict[str, MessageRuleParameterReducerSpec] = {} + for rule_type in registered_message_rule_backend_spec_types(): + spec = build_message_rule_backend_spec( + rule_type=rule_type, + kv_group_count=1, + cell_count=2, + ) + if spec.parameter_reducer is None: + continue + existing = reducers_by_kind.get(spec.parameter_reducer.reducer_kind) + if existing is not None and existing != spec.parameter_reducer: + raise RuntimeError( + "Message-rule parameter reducer implementations must be unique by reducer kind: " + f"reducer_kind={spec.parameter_reducer.reducer_kind!r}" + ) + reducers_by_kind[spec.parameter_reducer.reducer_kind] = spec.parameter_reducer + return tuple( + TemporalParameterReducerPattern( + reducer_kind=reducer.reducer_kind, + reducer_kind_opcode=int(reducer.reducer_kind_opcode), + native_callable=reducer.native_callable, + implementation_symbol=reducer.implementation_symbol, + count_target=reducer.count_target, + count_mode=reducer.count_mode, + active_trainable_roles=tuple(reducer.active_trainable_roles), + required_static_logical_groups=tuple(reducer.required_static_logical_groups), + grad_output_roles=tuple(reducer.grad_output_roles), + strategy_opcode=int(reducer.strategy_opcode), + strategy_version=int(reducer.strategy_version), + cxx_entrypoints=tuple(reducer.cxx_entrypoints), + ) + for reducer in reducers_by_kind.values() + ) + + +_PARAMETER_REDUCER_PATTERNS = _BASE_PARAMETER_REDUCER_PATTERNS + _message_rule_parameter_reducer_patterns() + + +_TRANSITION_TRAINABLE_REDUCER_PATTERNS = ( + TemporalTransitionTrainableReducerPattern( + reducer_kind="materialized_base", + reducer_kind_opcode=1, + native_callable="native.reverse.parameter_reduction.transition.materialized_base.v1", + implementation_symbol="run_registered_transition_materialized_base_reducer", + ), + TemporalTransitionTrainableReducerPattern( + reducer_kind="materialized_delta", + reducer_kind_opcode=2, + native_callable="native.reverse.parameter_reduction.transition.materialized_delta.v1", + 
implementation_symbol="run_registered_transition_materialized_delta_reducer", + ), + TemporalTransitionTrainableReducerPattern( + reducer_kind="value_to_cell_msg_to_cell", + reducer_kind_opcode=3, + native_callable="native.reverse.parameter_reduction.transition.value_to_cell_msg_to_cell.v1", + implementation_symbol="run_registered_transition_value_to_cell_msg_to_cell_reducer", + ), + TemporalTransitionTrainableReducerPattern( + reducer_kind="value_to_cell_msg_out", + reducer_kind_opcode=4, + native_callable="native.reverse.parameter_reduction.transition.value_to_cell_msg_out.v1", + implementation_symbol="run_registered_transition_value_to_cell_msg_out_reducer", + ), + TemporalTransitionTrainableReducerPattern( + reducer_kind="recurrent_bias_slot_embed", + reducer_kind_opcode=5, + native_callable="native.reverse.parameter_reduction.transition.recurrent_bias_slot_embed.v1", + implementation_symbol="run_registered_transition_recurrent_bias_slot_embed_reducer", + ), + TemporalTransitionTrainableReducerPattern( + reducer_kind="recurrent_bias_cell_bias_proj", + reducer_kind_opcode=6, + native_callable="native.reverse.parameter_reduction.transition.recurrent_bias_cell_bias_proj.v1", + implementation_symbol="run_registered_transition_recurrent_bias_cell_bias_proj_reducer", + ), +) + + +def temporal_parameter_reducer_patterns() -> tuple[TemporalParameterReducerPattern, ...]: + return _PARAMETER_REDUCER_PATTERNS + + +def temporal_transition_trainable_reducer_patterns() -> tuple[TemporalTransitionTrainableReducerPattern, ...]: + return _TRANSITION_TRAINABLE_REDUCER_PATTERNS + + +def temporal_parameter_reducer_pattern(reducer_kind: str) -> TemporalParameterReducerPattern: + for pattern in _PARAMETER_REDUCER_PATTERNS: + if pattern.reducer_kind == str(reducer_kind): + return pattern + raise RuntimeError(f"Unregistered temporal parameter reducer kind {reducer_kind!r}") + + +def temporal_parameter_reducer_pattern_for_opcode(reducer_kind_opcode: int) -> TemporalParameterReducerPattern: + for pattern in _PARAMETER_REDUCER_PATTERNS: + if int(pattern.reducer_kind_opcode) == int(reducer_kind_opcode): + return pattern + raise RuntimeError(f"Unregistered temporal parameter reducer opcode {int(reducer_kind_opcode)}") + + +def temporal_parameter_reducer_kind_opcode(reducer_kind: str) -> int: + return int(temporal_parameter_reducer_pattern(reducer_kind).reducer_kind_opcode) + + +def temporal_parameter_reducer_strategy_opcode(reducer_kind: str) -> int: + return int(temporal_parameter_reducer_pattern(reducer_kind).stable_strategy_opcode) + + +def temporal_parameter_reducer_count_target_opcode(count_target: str) -> int: + try: + return int(_PARAMETER_REDUCER_COUNT_TARGET_OPCODE[str(count_target)]) + except KeyError as exc: + raise RuntimeError(f"Unregistered parameter reducer count target {count_target!r}") from exc + + +def temporal_parameter_reducer_count_mode_opcode(count_mode: str) -> int: + try: + return int(_PARAMETER_REDUCER_COUNT_MODE_OPCODE[str(count_mode)]) + except KeyError as exc: + raise RuntimeError(f"Unregistered parameter reducer count mode {count_mode!r}") from exc + + +def temporal_parameter_trainable_roles() -> tuple[str, ...]: + return tuple(_PARAMETER_TRAINABLE_ROLE_TO_NAME) + + +def temporal_parameter_trainable_role_name(role: str) -> str: + try: + return _PARAMETER_TRAINABLE_ROLE_TO_NAME[str(role)] + except KeyError as exc: + raise RuntimeError(f"Unregistered parameter trainable role {role!r}") from exc + + +def temporal_parameter_trainable_role_opcode(role: str) -> int: + try: + return 
int(_PARAMETER_TRAINABLE_ROLE_OPCODE[str(role)]) + except KeyError as exc: + raise RuntimeError(f"Unregistered parameter trainable role {role!r}") from exc + + +def temporal_parameter_runtime_role_opcode(role: str) -> int: + try: + return int(_PARAMETER_RUNTIME_ROLE_OPCODE[str(role)]) + except KeyError as exc: + raise RuntimeError(f"Unregistered parameter runtime role {role!r}") from exc + + +def temporal_message_strategy_grad_output_role_opcode(role: str) -> int: + try: + return int(_MESSAGE_STRATEGY_GRAD_OUTPUT_ROLE_OPCODE[str(role)]) + except KeyError as exc: + raise RuntimeError(f"Unregistered message strategy grad output role {role!r}") from exc + + +def temporal_transition_trainable_reducer_pattern( + reducer_kind: str, +) -> TemporalTransitionTrainableReducerPattern: + for pattern in _TRANSITION_TRAINABLE_REDUCER_PATTERNS: + if pattern.reducer_kind == str(reducer_kind): + return pattern + raise RuntimeError(f"Unregistered temporal transition trainable reducer kind {reducer_kind!r}") + + +def temporal_transition_trainable_reducer_kind_opcode(reducer_kind: str) -> int: + return int(temporal_transition_trainable_reducer_pattern(reducer_kind).reducer_kind_opcode) + + +__all__ = [ + "TemporalParameterReducerPattern", + "TemporalTransitionTrainableReducerPattern", + "temporal_message_strategy_grad_output_role_opcode", + "temporal_parameter_reducer_count_mode_opcode", + "temporal_parameter_reducer_count_target_opcode", + "temporal_parameter_reducer_kind_opcode", + "temporal_parameter_reducer_pattern", + "temporal_parameter_reducer_pattern_for_opcode", + "temporal_parameter_reducer_patterns", + "temporal_parameter_reducer_strategy_opcode", + "temporal_parameter_runtime_role_opcode", + "temporal_parameter_trainable_role_name", + "temporal_parameter_trainable_role_opcode", + "temporal_parameter_trainable_roles", + "temporal_transition_trainable_reducer_kind_opcode", + "temporal_transition_trainable_reducer_pattern", + "temporal_transition_trainable_reducer_patterns", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/compiler/reset_plan.py b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/reset_plan.py new file mode 100644 index 00000000..5e5bd4a0 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/reset_plan.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +from typing import Literal + +import torch + + +TemporalForwardResetKind = Literal["message", "transition"] +TemporalReverseResetKind = Literal["message", "transition"] + +_FORWARD_RESET_KIND_IDS: dict[TemporalForwardResetKind, int] = { + "message": 1, + "transition": 2, +} +_REVERSE_RESET_KIND_IDS: dict[TemporalReverseResetKind, int] = { + "message": 1, + "transition": 2, +} +_FORWARD_RESET_POLICY_ZERO_SOURCE_ROWS = 1 +_FORWARD_RESET_SCOPE_BATCH_OUTER_STEP = 1 +_REVERSE_RESET_POLICY_ZERO_SOURCE_ROWS = 1 +_REVERSE_RESET_SCOPE_BATCH_ROW = 1 + + +def temporal_forward_reset_kind_id(kind: TemporalForwardResetKind) -> int: + return _FORWARD_RESET_KIND_IDS[kind] + + +def temporal_reverse_reset_kind_id(kind: TemporalReverseResetKind) -> int: + return _REVERSE_RESET_KIND_IDS[kind] + + +def temporal_forward_reset_tensor_table( + *, + population_resets: torch.Tensor | None, + transition_resets: torch.Tensor | None, +) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]: + tensors: list[torch.Tensor] = [] + rows: list[list[int]] = [] + + def append(kind: TemporalForwardResetKind, reset_seq: torch.Tensor | None) -> None: + if reset_seq is None: + return + if reset_seq.dim() != 2: + raise 
RuntimeError(f"Forward temporal reset tensor for {kind!r} must have shape [B,T]") + rows.append( + [ + temporal_forward_reset_kind_id(kind), + len(tensors), + _FORWARD_RESET_POLICY_ZERO_SOURCE_ROWS, + _FORWARD_RESET_SCOPE_BATCH_OUTER_STEP, + ] + ) + tensors.append(reset_seq.to(device=reset_seq.device, dtype=torch.bool).contiguous()) + + append("message", population_resets) + append("transition", transition_resets) + if not rows: + return tuple(tensors), torch.empty((0, 4), dtype=torch.long) + return tuple(tensors), torch.tensor(rows, dtype=torch.long) + + +def temporal_reverse_reset_tensor_table( + *, + message_reset_step: torch.Tensor | None, + transition_reset_step: torch.Tensor | None, +) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]: + tensors: list[torch.Tensor] = [] + rows: list[list[int]] = [] + + def append(kind: TemporalReverseResetKind, reset_step: torch.Tensor | None) -> None: + if reset_step is None: + return + rows.append( + [ + temporal_reverse_reset_kind_id(kind), + len(tensors), + _REVERSE_RESET_POLICY_ZERO_SOURCE_ROWS, + _REVERSE_RESET_SCOPE_BATCH_ROW, + ] + ) + tensors.append(reset_step.to(device=reset_step.device, dtype=torch.bool).view(-1).contiguous()) + + append("message", message_reset_step) + append("transition", transition_reset_step) + if not rows: + return tuple(tensors), torch.empty((0, 4), dtype=torch.long) + return tuple(tensors), torch.tensor(rows, dtype=torch.long) + + +def temporal_reverse_transition_state_reset_rows_tensor( + *, + group_logical_slots: tuple[dict[str, int], ...], +) -> torch.Tensor: + rows: list[list[int]] = [] + for group_index, logical_to_slot in enumerate(group_logical_slots): + for state_name in ("grad_y", "grad_c", "grad_n", "grad_m", "grad_hc1", "grad_hc2"): + slot = logical_to_slot.get(state_name) + if slot is not None: + rows.append([int(group_index), int(slot)]) + if not rows: + return torch.empty((0, 2), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +__all__ = [ + "TemporalForwardResetKind", + "TemporalReverseResetKind", + "temporal_forward_reset_kind_id", + "temporal_forward_reset_tensor_table", + "temporal_reverse_reset_kind_id", + "temporal_reverse_reset_tensor_table", + "temporal_reverse_transition_state_reset_rows_tensor", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/compiler/reverse_artifacts.py b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/reverse_artifacts.py new file mode 100644 index 00000000..a13d3b27 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/reverse_artifacts.py @@ -0,0 +1,223 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import torch + + +@dataclass(frozen=True) +class TemporalReverseArtifactRole: + role_id: int + name: str + tensor_required: bool + + @property + def summary(self) -> str: + return f"reverse_artifact_role={int(self.role_id)},name={self.name},tensor={int(self.tensor_required)}" + + +@dataclass(frozen=True) +class TemporalReverseArtifactAccess: + access_id: int + name: str + role_name: str + required: bool + + @property + def summary(self) -> str: + return ( + f"reverse_artifact_access={int(self.access_id)},name={self.name}," + f"role={self.role_name},required={int(self.required)}" + ) + + +_REVERSE_ARTIFACT_ROLES = ( + TemporalReverseArtifactRole(1, "boundary_step", True), + TemporalReverseArtifactRole(2, "cells_prev", True), + TemporalReverseArtifactRole(3, "input_k", True), + TemporalReverseArtifactRole(4, "input_v", True), + TemporalReverseArtifactRole(5, 
"recurrent_k_before", True), + TemporalReverseArtifactRole(6, "recurrent_v_before", True), + TemporalReverseArtifactRole(7, "recurrent_k", True), + TemporalReverseArtifactRole(8, "recurrent_v", True), + TemporalReverseArtifactRole(9, "recurrent_hidden_before_backend_order", True), + TemporalReverseArtifactRole(10, "recurrent_hidden_backend_order", True), + TemporalReverseArtifactRole(11, "recurrent_msg_backend_order", True), + TemporalReverseArtifactRole(12, "output_msg", True), + TemporalReverseArtifactRole(13, "output_cells", True), + TemporalReverseArtifactRole(14, "transition_state_before", True), +) +_ROLE_BY_NAME = {role.name: role for role in _REVERSE_ARTIFACT_ROLES} +_REVERSE_ARTIFACT_ACCESSES = ( + TemporalReverseArtifactAccess(1, "boundary_step", "boundary_step", True), + TemporalReverseArtifactAccess(2, "cells_prev", "cells_prev", False), + TemporalReverseArtifactAccess(3, "input_k", "input_k", True), + TemporalReverseArtifactAccess(4, "input_v", "input_v", True), + TemporalReverseArtifactAccess(5, "recurrent_k_before", "recurrent_k_before", True), + TemporalReverseArtifactAccess(6, "recurrent_v_before", "recurrent_v_before", True), + TemporalReverseArtifactAccess(7, "recurrent_k", "recurrent_k", False), + TemporalReverseArtifactAccess(8, "recurrent_v", "recurrent_v", False), + TemporalReverseArtifactAccess( + 9, + "recurrent_hidden_before_backend_order", + "recurrent_hidden_before_backend_order", + True, + ), + TemporalReverseArtifactAccess( + 10, + "recurrent_hidden_backend_order", + "recurrent_hidden_backend_order", + True, + ), + TemporalReverseArtifactAccess(11, "recurrent_msg_backend_order", "recurrent_msg_backend_order", True), + TemporalReverseArtifactAccess(12, "output_msg", "output_msg", True), + TemporalReverseArtifactAccess(13, "output_cells", "output_cells", True), + TemporalReverseArtifactAccess(14, "transition_state_before", "transition_state_before", True), +) +_ACCESS_BY_NAME = {access.name: access for access in _REVERSE_ARTIFACT_ACCESSES} +_TRANSITION_STATE_ARTIFACT_FLAG_STRIDE = 1_000_000 + + +def temporal_reverse_artifact_roles() -> tuple[TemporalReverseArtifactRole, ...]: + return _REVERSE_ARTIFACT_ROLES + + +def temporal_reverse_artifact_accesses() -> tuple[TemporalReverseArtifactAccess, ...]: + return _REVERSE_ARTIFACT_ACCESSES + + +def temporal_reverse_artifact_access_names() -> tuple[str, ...]: + return tuple(access.name for access in _REVERSE_ARTIFACT_ACCESSES) + + +def temporal_reverse_artifact_role_names() -> tuple[str, ...]: + return tuple(role.name for role in _REVERSE_ARTIFACT_ROLES) + + +def temporal_reverse_tensor_artifact_role_names() -> tuple[str, ...]: + return tuple(role.name for role in _REVERSE_ARTIFACT_ROLES if role.tensor_required) + + +def temporal_reverse_artifact_role_id(name: str) -> int: + try: + return int(_ROLE_BY_NAME[str(name)].role_id) + except KeyError as error: + raise RuntimeError(f"Unknown temporal reverse artifact role {name!r}") from error + + +def temporal_reverse_artifact_role_is_tensor(name: str) -> bool: + try: + return bool(_ROLE_BY_NAME[str(name)].tensor_required) + except KeyError as error: + raise RuntimeError(f"Unknown temporal reverse artifact role {name!r}") from error + + +def temporal_reverse_artifact_access_id(name: str) -> int: + try: + return int(_ACCESS_BY_NAME[str(name)].access_id) + except KeyError as error: + raise RuntimeError(f"Unknown temporal reverse artifact access {name!r}") from error + + +def temporal_reverse_artifact_access_role_name(name: str) -> str: + try: + return 
str(_ACCESS_BY_NAME[str(name)].role_name) + except KeyError as error: + raise RuntimeError(f"Unknown temporal reverse artifact access {name!r}") from error + + +def temporal_reverse_artifact_access_is_required(name: str) -> bool: + try: + return bool(_ACCESS_BY_NAME[str(name)].required) + except KeyError as error: + raise RuntimeError(f"Unknown temporal reverse artifact access {name!r}") from error + + +def temporal_reverse_artifact_role_rows_tensor(roles: tuple[str, ...]) -> torch.Tensor: + rows = [ + [ + index, + temporal_reverse_artifact_role_id(role), + int(temporal_reverse_artifact_role_is_tensor(role)), + ] + for index, role in enumerate(roles) + ] + if not rows: + return torch.empty((0, 3), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def temporal_reverse_artifact_access_rows_tensor( + roles: tuple[str, ...], + accesses: tuple[str, ...] | None = None, +) -> torch.Tensor: + available_roles = set(roles) + access_names = ( + temporal_reverse_artifact_access_names() if accesses is None else tuple(str(name) for name in accesses) + ) + rows: list[list[int]] = [] + for access_name in access_names: + access = _ACCESS_BY_NAME.get(access_name) + if access is None: + raise RuntimeError(f"Unknown temporal reverse artifact access {access_name!r}") + if access.role_name not in available_roles: + if access.required: + raise RuntimeError( + "Temporal reverse artifact access requires a missing role: " + f"access={access.name!r}; role={access.role_name!r}" + ) + continue + rows.append( + [ + int(access.access_id), + temporal_reverse_artifact_role_id(access.role_name), + int(access.required), + ] + ) + if not rows: + return torch.empty((0, 3), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def encode_temporal_reverse_transition_state_artifact_flags( + *, + bucket_ordinal: int, + binding_index: int, +) -> int: + if int(bucket_ordinal) < 0 or int(binding_index) < 0: + raise RuntimeError("Transition state-before reverse artifact flags require non-negative ids") + if int(binding_index) >= _TRANSITION_STATE_ARTIFACT_FLAG_STRIDE: + raise RuntimeError( + "Transition state-before reverse artifact binding index exceeds flag encoding stride: " + f"binding={int(binding_index)}; stride={int(_TRANSITION_STATE_ARTIFACT_FLAG_STRIDE)}" + ) + return int(bucket_ordinal) * int(_TRANSITION_STATE_ARTIFACT_FLAG_STRIDE) + int(binding_index) + + +def decode_temporal_reverse_transition_state_artifact_flags(flags: int) -> tuple[int, int]: + if int(flags) < 0: + raise RuntimeError("Transition state-before reverse artifact flags must be non-negative") + return ( + int(flags) // int(_TRANSITION_STATE_ARTIFACT_FLAG_STRIDE), + int(flags) % int(_TRANSITION_STATE_ARTIFACT_FLAG_STRIDE), + ) + + +__all__ = [ + "TemporalReverseArtifactAccess", + "TemporalReverseArtifactRole", + "decode_temporal_reverse_transition_state_artifact_flags", + "encode_temporal_reverse_transition_state_artifact_flags", + "temporal_reverse_artifact_access_id", + "temporal_reverse_artifact_access_is_required", + "temporal_reverse_artifact_access_names", + "temporal_reverse_artifact_access_role_name", + "temporal_reverse_artifact_access_rows_tensor", + "temporal_reverse_artifact_accesses", + "temporal_reverse_artifact_role_id", + "temporal_reverse_artifact_role_is_tensor", + "temporal_reverse_artifact_role_names", + "temporal_reverse_artifact_role_rows_tensor", + "temporal_reverse_artifact_roles", + "temporal_reverse_tensor_artifact_role_names", +] diff --git 
a/src/cortical/fabric/backend/cuda/sequence_surface/compiler/row_groups.py b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/row_groups.py new file mode 100644 index 00000000..6fd7c2b9 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/row_groups.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +TEMPORAL_MESSAGE_BUCKET_ORDINAL = -1 +TEMPORAL_READOUT_BUCKET_ORDINAL = -2 +TEMPORAL_PARAMETER_REDUCTION_BUCKET_ORDINAL = -3 + + +@dataclass(frozen=True) +class TemporalRowSchema: + primitive: str + parameter_inputs: tuple[str, ...] = () + inputs: tuple[str, ...] = () + outputs: tuple[str, ...] = () + attributes: tuple[tuple[str, str], ...] = () + + @property + def semantic_signature(self) -> tuple[str, tuple[str, ...], tuple[tuple[str, str], ...]]: + return self.primitive, self.parameter_inputs, self.attributes + + @property + def summary(self) -> str: + params = ",".join(self.parameter_inputs) if self.parameter_inputs else "-" + inputs = ",".join(self.inputs) if self.inputs else "-" + outputs = ",".join(self.outputs) if self.outputs else "-" + attrs = ",".join(f"{key}={value}" for key, value in self.attributes) if self.attributes else "-" + return f"primitive={self.primitive},params={params},inputs={inputs},outputs={outputs},attrs={attrs}" + + +@dataclass(frozen=True) +class TemporalRowGroupSchema: + surface: str + bucket_ordinal: int | None + rows: tuple[TemporalRowSchema, ...] + effects: tuple[str, ...] = () + + @property + def row_signature(self) -> tuple[tuple[str, tuple[str, ...], tuple[tuple[str, str], ...]], ...]: + return tuple(row.semantic_signature for row in self.rows) + + @property + def primitive_signature(self) -> tuple[str, ...]: + return tuple(row.primitive for row in self.rows) + + @property + def summary(self) -> str: + bucket = "*" if self.bucket_ordinal is None else str(int(self.bucket_ordinal)) + return ( + f"surface={self.surface}" + f",bucket={bucket}" + f",rows={'+'.join(self.primitive_signature)}" + f",effects={'+'.join(self.effects) if self.effects else '-'}" + ) + + def matches(self, candidate: TemporalRowGroupSchema) -> bool: + if self.surface != candidate.surface: + return False + if self.bucket_ordinal is not None and int(self.bucket_ordinal) != int(candidate.bucket_ordinal or 0): + return False + if len(self.rows) != len(candidate.rows): + return False + for expected, actual in zip(self.rows, candidate.rows, strict=True): + if expected.primitive != actual.primitive: + return False + if expected.parameter_inputs != ("*",) and expected.parameter_inputs != actual.parameter_inputs: + return False + if expected.attributes: + actual_attributes = set(actual.attributes) + if not set(expected.attributes).issubset(actual_attributes): + return False + return set(self.effects).issubset(set(candidate.effects)) + + +def surface_for_temporal_row(row: object) -> str: + for item in tuple(str(item) for item in getattr(row, "flat_bucket_identity", ()) or ()): + if item.startswith("surface="): + return item.removeprefix("surface=") + return "transition" if int(getattr(row, "bucket_ordinal", 0)) >= 0 else "unknown" + + +def temporal_effects_for_row(row: object, *, surface: str | None = None) -> tuple[str, ...]: + row_surface = surface_for_temporal_row(row) if surface is None else surface + primitive = str(getattr(row, "primitive", "")) + if row_surface == "message": + return "state_read", "parameter_read", "message_emit" + if row_surface == "readout": + return "state_read", 
"parameter_read", "output_emit" + if row_surface == "readout_boundary": + return "materialization_boundary", "output_emit" + if row_surface == "parameter_reduction": + return "grad_read", "parameter_grad_emit" + if row_surface == "transition": + return "state_read", "message_read", "state_write", "tape_policy" + return ("unknown_effect:" + primitive,) + + +def canonical_temporal_row_group( + *, + surface: str, + bucket_ordinal: int | None, + rows: tuple[object, ...], +) -> TemporalRowGroupSchema: + schemas = tuple(_row_schema(row) for row in rows) + effects = tuple(dict.fromkeys(effect for row in rows for effect in temporal_effects_for_row(row))) + return TemporalRowGroupSchema( + surface=surface, + bucket_ordinal=bucket_ordinal, + rows=schemas, + effects=effects, + ) + + +def pattern_temporal_row_group( + *, + surface: str, + bucket_ordinal: int | None, + rows: tuple[Any, ...], + match_effects: tuple[str, ...] = (), +) -> TemporalRowGroupSchema: + return TemporalRowGroupSchema( + surface=surface, + bucket_ordinal=bucket_ordinal, + rows=tuple( + TemporalRowSchema( + primitive=str(getattr(row, "primitive", "")), + parameter_inputs=tuple(str(parameter) for parameter in getattr(row, "parameter_inputs", ()) or ()), + attributes=tuple( + (str(key), str(value)) + for key, value in getattr(row, "attribute_constraints", ()) or getattr(row, "attributes", ()) or () + ), + ) + for row in rows + ), + effects=match_effects, + ) + + +def _row_schema(row: object) -> TemporalRowSchema: + return TemporalRowSchema( + primitive=str(getattr(row, "primitive", "")), + parameter_inputs=tuple(str(parameter) for parameter in getattr(row, "parameter_inputs", ()) or ()), + inputs=tuple(str(item) for item in getattr(row, "inputs", ()) or ()), + outputs=tuple(str(item) for item in getattr(row, "outputs", ()) or ()), + attributes=tuple((str(key), str(value)) for key, value in getattr(row, "attributes", ()) or ()), + ) + + +__all__ = [ + "TEMPORAL_MESSAGE_BUCKET_ORDINAL", + "TEMPORAL_PARAMETER_REDUCTION_BUCKET_ORDINAL", + "TEMPORAL_READOUT_BUCKET_ORDINAL", + "TemporalRowGroupSchema", + "TemporalRowSchema", + "canonical_temporal_row_group", + "pattern_temporal_row_group", + "surface_for_temporal_row", + "temporal_effects_for_row", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py new file mode 100644 index 00000000..65611406 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +from typing import Any + +from .backward_plan import build_temporal_backward_executable_plan +from .executor_bindings import ( + build_temporal_forward_executor_binding_plan, + build_temporal_reverse_executor_binding_plan, + build_temporal_transition_param_grad_binding_plan, +) +from .forward_plan import build_temporal_forward_executable_plan +from .memory_plan import build_temporal_memory_liveness_plan, temporal_memory_liveness_rows_tensor +from .program_execution import ( + build_temporal_fused_cuda_program_plan, + build_temporal_message_transition_producer_consumer_plan, + build_temporal_readout_message_producer_consumer_plan, + build_temporal_registered_program_executor_plan, + build_temporal_reverse_program_stage_plan, + temporal_forward_artifact_merge_rows_tensor, + temporal_forward_artifact_route_rows_tensor, + temporal_forward_executor_handler_rows_tensor, + temporal_forward_output_route_rows_tensor, + 
temporal_native_callable_catalog_rows_tensor, + temporal_native_callable_output_rows_tensor, + temporal_native_executor_strategy_rows_tensor, + temporal_message_transition_producer_consumer_rows_tensor, + temporal_readout_message_producer_consumer_rows_tensor, + temporal_reverse_artifact_consumer_route_rows_tensor, + temporal_reverse_executor_handler_rows_tensor, + temporal_reverse_parameter_reducer_route_rows_tensor, + temporal_transition_primitive_native_callable_rows_tensor, +) +from .native_callables import temporal_native_callable_output_summaries, temporal_native_callable_summaries +from .primitive_dispatch import build_temporal_primitive_executor_plan +from .strategy_selection import build_temporal_strategy_selection_report +from .tables import ( + TemporalPrimitiveTablePlan, + temporal_reverse_executor_summaries, + temporal_primitive_rows_tensor, + temporal_tensor_binding_summaries, + temporal_table_transition_recurrent_bucket_kinds, + validate_temporal_supported_scan_binding_projection, +) +from .verification import ( + temporal_strategy_rejection_codes, + verify_temporal_primitive_table, +) + + +def record_temporal_primitive_table_runtime_metadata( + runtime: Any, + table: TemporalPrimitiveTablePlan, + *, + scheduler_plan: Any | None = None, +) -> None: + if scheduler_plan is not None: + runtime._last_flat_bucket_temporal_scheduler_plan = scheduler_plan.review_summary + executor_plan = build_temporal_primitive_executor_plan(table) + forward_binding_plan = build_temporal_forward_executor_binding_plan(table) + reverse_binding_plan = build_temporal_reverse_executor_binding_plan(table) + transition_param_grad_binding_plan = build_temporal_transition_param_grad_binding_plan( + table, + reverse_binding_plan=reverse_binding_plan, + ) + verification_report = verify_temporal_primitive_table( + table, + executor_plan=executor_plan, + forward_binding_plan=forward_binding_plan, + reverse_binding_plan=reverse_binding_plan, + ) + strategy_report = build_temporal_strategy_selection_report( + table, + forward_binding_plan=forward_binding_plan, + reverse_binding_plan=reverse_binding_plan, + ) + memory_plan = build_temporal_memory_liveness_plan(table) + memory_liveness_rows = temporal_memory_liveness_rows_tensor(memory_plan) + forward_executable_plan = build_temporal_forward_executable_plan( + table, + forward_binding_plan=forward_binding_plan, + strategy_report=strategy_report, + ) + backward_executable_plan = build_temporal_backward_executable_plan( + table, + reverse_binding_plan=reverse_binding_plan, + strategy_report=strategy_report, + ) + reverse_program_stage_plan = build_temporal_reverse_program_stage_plan(table, backward_executable_plan) + forward_handler_rows = temporal_forward_executor_handler_rows_tensor(table) + reverse_handler_rows = temporal_reverse_executor_handler_rows_tensor(table) + native_strategy_rows = temporal_native_executor_strategy_rows_tensor() + native_callable_catalog_rows = temporal_native_callable_catalog_rows_tensor() + native_callable_output_rows = temporal_native_callable_output_rows_tensor() + transition_primitive_callable_rows = temporal_transition_primitive_native_callable_rows_tensor() + fused_cuda_program_plan = build_temporal_fused_cuda_program_plan( + primitive_rows=temporal_primitive_rows_tensor(table), + forward_plan=forward_executable_plan, + backward_plan=backward_executable_plan, + memory_plan=memory_plan, + memory_liveness_rows=memory_liveness_rows, + forward_handler_rows=forward_handler_rows, + reverse_handler_rows=reverse_handler_rows, + 
native_strategy_rows=native_strategy_rows, + native_callable_output_rows=native_callable_output_rows, + transition_primitive_callable_rows=transition_primitive_callable_rows, + forward_artifact_route_rows=temporal_forward_artifact_route_rows_tensor(table), + forward_artifact_merge_rows=temporal_forward_artifact_merge_rows_tensor(table), + forward_output_route_rows=temporal_forward_output_route_rows_tensor(table), + readout_message_producer_consumer_rows=temporal_readout_message_producer_consumer_rows_tensor( + build_temporal_readout_message_producer_consumer_plan(table) + ), + message_transition_producer_consumer_rows=temporal_message_transition_producer_consumer_rows_tensor( + build_temporal_message_transition_producer_consumer_plan(table) + ), + reverse_artifact_consumer_route_rows=temporal_reverse_artifact_consumer_route_rows_tensor(table), + reverse_parameter_reducer_route_rows=temporal_reverse_parameter_reducer_route_rows_tensor(table), + ) + program_executor_plan = build_temporal_registered_program_executor_plan(fused_cuda_program_plan) + runtime._last_flat_bucket_temporal_table_review = table.review_summary + runtime._last_flat_bucket_temporal_registered_transition_bucket_kinds = tuple( + f"bucket={bucket_ordinal},kind={kind}" + for bucket_ordinal, kind in sorted(temporal_table_transition_recurrent_bucket_kinds(table).items()) + ) + runtime._last_flat_bucket_temporal_primitive_names = table.primitive_names + runtime._last_flat_bucket_temporal_primitive_families = table.primitive_families + runtime._last_flat_bucket_temporal_primitive_row_count = len(table.primitive_rows) + runtime._last_flat_bucket_temporal_tensor_binding_row_count = len(table.tensor_bindings) + runtime._last_flat_bucket_temporal_tensor_binding_summaries = temporal_tensor_binding_summaries(table) + runtime._last_flat_bucket_temporal_scan_binding_projection = validate_temporal_supported_scan_binding_projection( + table + ) + runtime._last_flat_bucket_temporal_reverse_executor_summaries = temporal_reverse_executor_summaries(table) + runtime._last_flat_bucket_temporal_primitive_executor_contracts = executor_plan.summaries + runtime._last_flat_bucket_temporal_primitive_executor_blockers = executor_plan.blockers + runtime._last_flat_bucket_temporal_compiler_pass_pipeline = verification_report.pass_pipeline + runtime._last_flat_bucket_temporal_compiler_schema_versions = verification_report.schema_versions + runtime._last_flat_bucket_temporal_strategy_rejection_codes = temporal_strategy_rejection_codes() + runtime._last_flat_bucket_temporal_verifier_status = verification_report.status + runtime._last_flat_bucket_temporal_verifier_issues = verification_report.issue_summaries + runtime._last_flat_bucket_temporal_effect_summaries = verification_report.effect_summaries + runtime._last_flat_bucket_temporal_planner_explain = verification_report.explain_summary + runtime._last_flat_bucket_temporal_strategy_candidates = strategy_report.candidate_summaries + runtime._last_flat_bucket_temporal_legal_strategy_candidates = strategy_report.legal_summaries + runtime._last_flat_bucket_temporal_blocked_strategy_candidates = strategy_report.blocked_summaries + runtime._last_flat_bucket_temporal_forward_executor_binding_rows = forward_binding_plan.rows + runtime._last_flat_bucket_temporal_forward_executor_handler_rows = forward_handler_rows + runtime._last_flat_bucket_temporal_native_strategy_rows = native_strategy_rows + runtime._last_flat_bucket_temporal_native_callable_catalog_rows = native_callable_catalog_rows + 
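# Hypothetical clarifying comment (not in the original patch): human-readable callable summaries are recorded alongside the raw catalog/output row tensors for review tooling. +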
runtime._last_flat_bucket_temporal_native_callable_catalog_summaries = temporal_native_callable_summaries() + runtime._last_flat_bucket_temporal_native_callable_output_rows = native_callable_output_rows + runtime._last_flat_bucket_temporal_native_callable_output_summaries = temporal_native_callable_output_summaries() + runtime._last_flat_bucket_temporal_forward_executor_binding_summaries = forward_binding_plan.summaries + runtime._last_flat_bucket_temporal_forward_executor_binding_blockers = forward_binding_plan.blocker_summaries + runtime._last_flat_bucket_temporal_reverse_executor_binding_rows = reverse_binding_plan.rows + runtime._last_flat_bucket_temporal_reverse_executor_handler_rows = reverse_handler_rows + runtime._last_flat_bucket_temporal_reverse_executor_binding_summaries = reverse_binding_plan.summaries + runtime._last_flat_bucket_temporal_reverse_executor_binding_blockers = reverse_binding_plan.blocker_summaries + runtime._last_flat_bucket_temporal_transition_param_grad_binding_rows = transition_param_grad_binding_plan.rows + runtime._last_flat_bucket_temporal_transition_param_grad_binding_summaries = ( + transition_param_grad_binding_plan.summaries + ) + runtime._last_flat_bucket_temporal_memory_plan_review = memory_plan.review_summary + runtime._last_flat_bucket_temporal_memory_plan_summaries = memory_plan.summaries + runtime._last_flat_bucket_temporal_memory_liveness_rows = memory_liveness_rows + runtime._last_flat_bucket_temporal_reverse_program_stage_rows = reverse_program_stage_plan.rows + runtime._last_flat_bucket_temporal_reverse_program_stage_summaries = reverse_program_stage_plan.summaries + runtime._last_flat_bucket_temporal_workspace_policy = memory_plan.workspace_policy + runtime._last_flat_bucket_temporal_layout_policy = memory_plan.layout_policy + runtime._last_flat_bucket_temporal_alias_policy = memory_plan.alias_policy + runtime._last_flat_bucket_temporal_memory_peak_estimate_bytes = memory_plan.peak_workspace_estimate_bytes + runtime._last_flat_bucket_temporal_forward_executable_plan = forward_executable_plan.review_summary + runtime._last_flat_bucket_temporal_forward_strategy_ids = forward_executable_plan.strategy_ids + runtime._last_flat_bucket_temporal_forward_runtime_entrypoint = forward_executable_plan.runtime_entrypoint + runtime._last_flat_bucket_temporal_forward_strategy_legality_status = ( + forward_executable_plan.strategy_legality_status + ) + runtime._last_flat_bucket_temporal_forward_strategy_legality_reasons = ( + forward_executable_plan.strategy_legality_reasons + ) + runtime._last_flat_bucket_temporal_backward_executable_plan = backward_executable_plan.review_summary + runtime._last_flat_bucket_temporal_backward_strategy_ids = backward_executable_plan.strategy_ids + runtime._last_flat_bucket_temporal_backward_runtime_entrypoint = backward_executable_plan.runtime_entrypoint + runtime._last_flat_bucket_temporal_backward_strategy_legality_status = ( + backward_executable_plan.strategy_legality_status + ) + runtime._last_flat_bucket_temporal_backward_strategy_legality_reasons = ( + backward_executable_plan.strategy_legality_reasons + ) + runtime._last_flat_bucket_temporal_fused_cuda_program_plan = fused_cuda_program_plan.review_summary + runtime._last_flat_bucket_temporal_fused_cuda_launch_contract = ( + fused_cuda_program_plan.launch_contract.review_summary + ) + runtime._last_flat_bucket_temporal_fused_cuda_program_status = fused_cuda_program_plan.status + runtime._last_flat_bucket_temporal_fused_cuda_program_blocker = ( + 
fused_cuda_program_plan.blocker_code, + fused_cuda_program_plan.blocker_reason, + ) + runtime._last_flat_bucket_temporal_registered_program_executor_plan = program_executor_plan.review_summary + runtime._last_flat_bucket_temporal_registered_program_executor_status = program_executor_plan.status + runtime._last_flat_bucket_temporal_registered_program_executor_demotion_policy = ( + program_executor_plan.demotion_policy + ) + + +__all__ = ["record_temporal_primitive_table_runtime_metadata"] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/compiler/scan_schedule.py b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/scan_schedule.py new file mode 100644 index 00000000..c70cebb7 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/scan_schedule.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterator + + +@dataclass(frozen=True) +class TemporalPhysicalScanStep: + physical_step: int + outer_step: int + inner_step: int + emit_output: bool + apply_boundary_reset: bool + apply_transition_reset: bool + + +@dataclass(frozen=True) +class TemporalPhysicalScanSchedule: + outer_time_steps: int + inner_steps: int + + @property + def physical_time_steps(self) -> int: + return int(self.outer_time_steps) * int(self.inner_steps) + + @property + def emission_steps(self) -> tuple[int, ...]: + return tuple(range(int(self.inner_steps) - 1, self.physical_time_steps, int(self.inner_steps))) + + @property + def steps(self) -> tuple[TemporalPhysicalScanStep, ...]: + return tuple(self.iter_steps()) + + def step_at(self, physical_step: int) -> TemporalPhysicalScanStep: + physical_step = int(physical_step) + if physical_step < 0 or physical_step >= self.physical_time_steps: + raise IndexError("Temporal physical scan step is outside the schedule") + return scalar_temporal_scan_step(physical_step=physical_step, inner_steps=int(self.inner_steps)) + + def iter_steps(self, *, start: int = 0, end: int | None = None) -> Iterator[TemporalPhysicalScanStep]: + start = int(start) + stop = self.physical_time_steps if end is None else int(end) + if start < 0 or stop < start or stop > self.physical_time_steps: + raise IndexError("Temporal physical scan step range is outside the schedule") + for physical_step in range(start, stop): + yield scalar_temporal_scan_step(physical_step=physical_step, inner_steps=int(self.inner_steps)) + + +def scalar_temporal_scan_step(*, physical_step: int, inner_steps: int) -> TemporalPhysicalScanStep: + physical_step = int(physical_step) + inner_steps = int(inner_steps) + if physical_step < 0: + raise ValueError("Temporal physical scan step must be non-negative") + if inner_steps <= 0: + raise ValueError("Temporal physical scan requires positive inner_steps") + outer_step, inner_step = divmod(physical_step, inner_steps) + return TemporalPhysicalScanStep( + physical_step=physical_step, + outer_step=outer_step, + inner_step=inner_step, + emit_output=inner_step == inner_steps - 1, + apply_boundary_reset=inner_step == 0, + apply_transition_reset=True, + ) + + +def emitted_output_index_for_scan_step( + scan_step: TemporalPhysicalScanStep, + *, + outer_time_steps: int, + emitted_time_steps: int, +) -> int | None: + if not scan_step.emit_output: + return None + outer_time_steps = int(outer_time_steps) + emitted_time_steps = int(emitted_time_steps) + if emitted_time_steps == outer_time_steps: + return scan_step.outer_step + if emitted_time_steps == 1 and scan_step.outer_step == outer_time_steps - 1: 
+ return 0 + return None + + +def build_scalar_temporal_scan_schedule(*, outer_time_steps: int, inner_steps: int) -> TemporalPhysicalScanSchedule: + outer_time_steps = int(outer_time_steps) + inner_steps = int(inner_steps) + if outer_time_steps <= 0: + raise ValueError("Temporal physical scan requires positive outer_time_steps") + if inner_steps <= 0: + raise ValueError("Temporal physical scan requires positive inner_steps") + return TemporalPhysicalScanSchedule( + outer_time_steps=outer_time_steps, + inner_steps=inner_steps, + ) + + +__all__ = [ + "TemporalPhysicalScanSchedule", + "TemporalPhysicalScanStep", + "build_scalar_temporal_scan_schedule", + "emitted_output_index_for_scan_step", + "scalar_temporal_scan_step", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/compiler/strategy_selection.py b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/strategy_selection.py new file mode 100644 index 00000000..22a96418 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/strategy_selection.py @@ -0,0 +1,393 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +from .executor_bindings import ( + TemporalExecutorBindingPlan, + build_temporal_forward_executor_binding_plan, + build_temporal_reverse_executor_binding_plan, +) +from .executor_patterns import ( + TemporalForwardExecutorPattern, + TemporalReverseExecutorPattern, + temporal_executor_strategy_registry, +) +from .row_groups import ( + TemporalRowGroupSchema, + canonical_temporal_row_group, +) +from .tables import TemporalPrimitiveTablePlan, temporal_forward_executor_rows, temporal_reverse_executor_rows +from .verification import TemporalStrategyRejectionCode + + +@dataclass(frozen=True) +class TemporalStrategyCandidate: + direction: Literal["forward", "reverse"] + strategy_id: str + executor_name: str + surface: str + bucket_ordinal: int | None + row_group_summary: str + match_status: Literal["matched", "not_matched"] + legality_status: Literal["legal", "blocked"] + rejection_code: TemporalStrategyRejectionCode | None + rejection_reason: str + legality_reasons: tuple[str, ...] + binding_row_count: int + binding_blocker_count: int + cost_model: str + estimated_cost_rank: int | None + runtime_entrypoint: str + + @property + def summary(self) -> str: + bucket = "*" if self.bucket_ordinal is None else str(int(self.bucket_ordinal)) + rank = "*" if self.estimated_cost_rank is None else str(int(self.estimated_cost_rank)) + code = "-" if self.rejection_code is None else self.rejection_code + reason = "-" if not self.rejection_reason else self.rejection_reason + legality_reasons = "|".join(self.legality_reasons) if self.legality_reasons else "-" + return ( + f"direction={self.direction}" + f",strategy_id={self.strategy_id}" + f",executor={self.executor_name}" + f",surface={self.surface}" + f",bucket={bucket}" + f",match={self.match_status}" + f",legality={self.legality_status}" + f",rejection={code}" + f",reason={reason}" + f",legality_reasons={legality_reasons}" + f",binding_rows={int(self.binding_row_count)}" + f",binding_blockers={int(self.binding_blocker_count)}" + f",cost_model={self.cost_model}" + f",rank={rank}" + f",runtime={self.runtime_entrypoint}" + f",row_group={self.row_group_summary}" + ) + + +@dataclass(frozen=True) +class TemporalStrategySelectionReport: + candidates: tuple[TemporalStrategyCandidate, ...] 
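+ # Clarifying comment (added for readability, assumption about intent): the summary properties below project the candidate tuple into reviewable strings: all candidates, legal-only, and matched-but-blocked.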
+ + @property + def candidate_summaries(self) -> tuple[str, ...]: + return tuple(candidate.summary for candidate in self.candidates) + + @property + def legal_summaries(self) -> tuple[str, ...]: + return tuple(candidate.summary for candidate in self.candidates if candidate.legality_status == "legal") + + @property + def blocked_summaries(self) -> tuple[str, ...]: + return tuple( + candidate.summary + for candidate in self.candidates + if candidate.match_status == "matched" and candidate.legality_status == "blocked" + ) + + +def build_temporal_strategy_selection_report( + table: TemporalPrimitiveTablePlan, + *, + forward_binding_plan: TemporalExecutorBindingPlan | None = None, + reverse_binding_plan: TemporalExecutorBindingPlan | None = None, + directions: tuple[Literal["forward", "reverse"], ...] = ("forward", "reverse"), +) -> TemporalStrategySelectionReport: + requested_directions = frozenset(directions) + forward_groups = _forward_row_groups(table) if "forward" in requested_directions else () + reverse_groups = _reverse_row_groups(table) if "reverse" in requested_directions else () + binding_counts = {} + if "forward" in requested_directions: + forward_binding_plan = ( + build_temporal_forward_executor_binding_plan(table) + if forward_binding_plan is None + else forward_binding_plan + ) + binding_counts.update(_binding_counts_by_group(forward_binding_plan)) + if "reverse" in requested_directions: + reverse_binding_plan = ( + build_temporal_reverse_executor_binding_plan(table) + if reverse_binding_plan is None + else reverse_binding_plan + ) + binding_counts.update(_binding_counts_by_group(reverse_binding_plan)) + candidates: list[TemporalStrategyCandidate] = [] + strategy_registry = temporal_executor_strategy_registry() + if "forward" in requested_directions: + for pattern in strategy_registry.forward_patterns(): + candidates.extend(_candidates_for_forward_pattern(pattern, forward_groups, binding_counts)) + if "reverse" in requested_directions: + for pattern in strategy_registry.reverse_patterns(): + candidates.extend(_candidates_for_reverse_pattern(pattern, reverse_groups, binding_counts)) + return TemporalStrategySelectionReport(candidates=tuple(candidates)) + + +def _candidates_for_forward_pattern( + pattern: TemporalForwardExecutorPattern, + groups: tuple[TemporalRowGroupSchema, ...], + binding_counts: dict[tuple[Literal["forward", "reverse"], str, int], tuple[int, int]], +) -> tuple[TemporalStrategyCandidate, ...]: + return _candidates_for_pattern( + direction="forward", + pattern=pattern, + groups=groups, + binding_counts=binding_counts, + ) + + +def _candidates_for_reverse_pattern( + pattern: TemporalReverseExecutorPattern, + groups: tuple[TemporalRowGroupSchema, ...], + binding_counts: dict[tuple[Literal["forward", "reverse"], str, int], tuple[int, int]], +) -> tuple[TemporalStrategyCandidate, ...]: + return _candidates_for_pattern( + direction="reverse", + pattern=pattern, + groups=groups, + binding_counts=binding_counts, + ) + + +def _candidates_for_pattern( + *, + direction: Literal["forward", "reverse"], + pattern: TemporalForwardExecutorPattern | TemporalReverseExecutorPattern, + groups: tuple[TemporalRowGroupSchema, ...], + binding_counts: dict[tuple[Literal["forward", "reverse"], str, int], tuple[int, int]], +) -> tuple[TemporalStrategyCandidate, ...]: + matched_groups = tuple(group for group in groups if pattern.row_group_schema.matches(group)) + if not matched_groups: + reason = "strategy_row_group_schema_did_not_match_compiled_rows" + binding_row_count, 
binding_blocker_count = _binding_stats_for_pattern( + direction=direction, + pattern=pattern, + binding_counts=binding_counts, + ) + return ( + TemporalStrategyCandidate( + direction=direction, + strategy_id=pattern.stable_strategy_id, + executor_name=pattern.executor_name, + surface=pattern.surface, + bucket_ordinal=pattern.bucket_ordinal, + row_group_summary=pattern.row_group_schema.summary, + match_status="not_matched", + legality_status="blocked", + rejection_code="UNSUPPORTED_PATTERN", + rejection_reason=reason, + legality_reasons=( + f"strategy_legality_blocker=UNSUPPORTED_PATTERN" + f",strategy_id={pattern.stable_strategy_id}" + f",executor={pattern.executor_name}" + f",surface={pattern.surface}" + ",bucket=*" + f",reason={reason}", + ), + binding_row_count=binding_row_count, + binding_blocker_count=binding_blocker_count, + cost_model=pattern.cost_model, + estimated_cost_rank=None, + runtime_entrypoint=pattern.runtime_entrypoint, + ), + ) + return tuple( + _matched_candidate_for_pattern( + direction=direction, + pattern=pattern, + matched_group=matched_group, + binding_counts=binding_counts, + ) + for matched_group in matched_groups + ) + + +def _matched_candidate_for_pattern( + *, + direction: Literal["forward", "reverse"], + pattern: TemporalForwardExecutorPattern | TemporalReverseExecutorPattern, + matched_group: TemporalRowGroupSchema, + binding_counts: dict[tuple[Literal["forward", "reverse"], str, int], tuple[int, int]], +) -> TemporalStrategyCandidate: + blockers = _legality_reasons_for_pattern(pattern=pattern, matched_group=matched_group) + binding_row_count, binding_blocker_count = _binding_stats_for_group( + direction=direction, + surface=matched_group.surface, + bucket_ordinal=matched_group.bucket_ordinal, + binding_counts=binding_counts, + ) + if blockers: + rejection_code = ( + "UNVERIFIED_REWRITE" + if pattern.verified_rewrite_required + else "ABI_VERSION_MISMATCH" + if any("ABI_VERSION_MISMATCH" in blocker for blocker in blockers) + else "UNSUPPORTED_PATTERN" + ) + return TemporalStrategyCandidate( + direction=direction, + strategy_id=pattern.stable_strategy_id, + executor_name=pattern.executor_name, + surface=pattern.surface, + bucket_ordinal=matched_group.bucket_ordinal, + row_group_summary=matched_group.summary, + match_status="matched", + legality_status="blocked", + rejection_code=rejection_code, + rejection_reason=( + "strategy_matches_but_requires_verified_rewrite_before_cost_selection" + if pattern.verified_rewrite_required + else "strategy_matches_but_failed_legality_filter" + ), + legality_reasons=blockers, + binding_row_count=binding_row_count, + binding_blocker_count=binding_blocker_count, + cost_model=pattern.cost_model, + estimated_cost_rank=None, + runtime_entrypoint=pattern.runtime_entrypoint, + ) + return TemporalStrategyCandidate( + direction=direction, + strategy_id=pattern.stable_strategy_id, + executor_name=pattern.executor_name, + surface=pattern.surface, + bucket_ordinal=matched_group.bucket_ordinal, + row_group_summary=matched_group.summary, + match_status="matched", + legality_status="legal", + rejection_code=None, + rejection_reason="", + legality_reasons=(), + binding_row_count=binding_row_count, + binding_blocker_count=binding_blocker_count, + cost_model=pattern.cost_model, + estimated_cost_rank=0, + runtime_entrypoint=pattern.runtime_entrypoint, + ) + + +def _legality_reasons_for_pattern( + *, + pattern: TemporalForwardExecutorPattern | TemporalReverseExecutorPattern, + matched_group: TemporalRowGroupSchema, +) -> tuple[str, ...]: + 
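# Clarifying comment (added, describes the code below): legality blockers are emitted as flat "key=value" comma strings + # (ABI/schema version mismatches, missing required effects, or a required-but-unverified rewrite) so strategy reports can surface them verbatim. +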
reasons: list[str] = [] + bucket = "*" if matched_group.bucket_ordinal is None else str(int(matched_group.bucket_ordinal)) + if pattern.row_schema_version != 1 or pattern.tensor_binding_schema_version != 1: + reasons.append( + f"strategy_legality_blocker=ABI_VERSION_MISMATCH" + f",strategy_id={pattern.stable_strategy_id}" + f",executor={pattern.executor_name}" + f",surface={pattern.surface}" + f",bucket={bucket}" + ",reason=row_or_tensor_binding_schema_version_mismatch" + ) + if pattern.metadata_schema_version != 1 or pattern.cuda_kernel_abi_version != 1: + reasons.append( + f"strategy_legality_blocker=ABI_VERSION_MISMATCH" + f",strategy_id={pattern.stable_strategy_id}" + f",executor={pattern.executor_name}" + f",surface={pattern.surface}" + f",bucket={bucket}" + ",reason=metadata_or_cuda_kernel_abi_version_mismatch" + ) + missing_match_effects = tuple(effect for effect in pattern.match_effects if effect not in matched_group.effects) + if missing_match_effects: + reasons.append( + f"strategy_legality_blocker=UNSUPPORTED_PATTERN" + f",strategy_id={pattern.stable_strategy_id}" + f",executor={pattern.executor_name}" + f",surface={pattern.surface}" + f",bucket={bucket}" + f",reason=missing_required_effects:{','.join(missing_match_effects)}" + ) + if pattern.verified_rewrite_required: + reasons.append( + f"strategy_legality_blocker=UNVERIFIED_REWRITE" + f",strategy_id={pattern.stable_strategy_id}" + f",executor={pattern.executor_name}" + f",surface={pattern.surface}" + f",bucket={bucket}" + f",reason={pattern.implementation_contract}" + ) + return tuple(reasons) + + +def _binding_counts_by_group( + binding_plan: TemporalExecutorBindingPlan, +) -> dict[tuple[Literal["forward", "reverse"], str, int], tuple[int, int]]: + binding_counts: dict[tuple[Literal["forward", "reverse"], str, int], int] = {} + blocker_counts: dict[tuple[Literal["forward", "reverse"], str, int], int] = {} + for binding in binding_plan.bindings: + key = (binding.direction, binding.surface, int(binding.bucket_ordinal)) + binding_counts[key] = binding_counts.get(key, 0) + 1 + for blocker in binding_plan.blockers: + key = (blocker.direction, blocker.surface, int(blocker.bucket_ordinal)) + blocker_counts[key] = blocker_counts.get(key, 0) + 1 + return { + key: (binding_counts.get(key, 0), blocker_counts.get(key, 0)) + for key in set(binding_counts) | set(blocker_counts) + } + + +def _binding_stats_for_pattern( + *, + direction: Literal["forward", "reverse"], + pattern: TemporalForwardExecutorPattern | TemporalReverseExecutorPattern, + binding_counts: dict[tuple[Literal["forward", "reverse"], str, int], tuple[int, int]], +) -> tuple[int, int]: + if pattern.bucket_ordinal is None: + counts = [ + stats + for (candidate_direction, surface, _bucket), stats in binding_counts.items() + if candidate_direction == direction and surface == pattern.surface + ] + return sum(count for count, _blockers in counts), sum(blockers for _count, blockers in counts) + return _binding_stats_for_group( + direction=direction, + surface=pattern.surface, + bucket_ordinal=int(pattern.bucket_ordinal), + binding_counts=binding_counts, + ) + + +def _binding_stats_for_group( + *, + direction: Literal["forward", "reverse"], + surface: str, + bucket_ordinal: int | None, + binding_counts: dict[tuple[Literal["forward", "reverse"], str, int], tuple[int, int]], +) -> tuple[int, int]: + if bucket_ordinal is None: + return 0, 0 + return binding_counts.get((direction, surface, int(bucket_ordinal)), (0, 0)) + + +def _forward_row_groups(table: TemporalPrimitiveTablePlan) -> 
tuple[TemporalRowGroupSchema, ...]: + return tuple(_executor_row_group(table, row) for row in temporal_forward_executor_rows(table)) + + +def _reverse_row_groups(table: TemporalPrimitiveTablePlan) -> tuple[TemporalRowGroupSchema, ...]: + return tuple(_executor_row_group(table, row) for row in temporal_reverse_executor_rows(table)) + + +def _executor_row_group( + table: TemporalPrimitiveTablePlan, + row: object, +) -> TemporalRowGroupSchema: + primitive_row_start = int(getattr(row, "primitive_row_start")) + primitive_row_count = int(getattr(row, "primitive_row_count")) + primitive_rows = tuple(table.primitive_rows[primitive_row_start : primitive_row_start + primitive_row_count]) + return canonical_temporal_row_group( + surface=str(getattr(row, "surface")), + bucket_ordinal=int(getattr(row, "bucket_ordinal")), + rows=primitive_rows, + ) + + +__all__ = [ + "TemporalStrategyCandidate", + "TemporalStrategySelectionReport", + "build_temporal_strategy_selection_report", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/compiler/tables.py b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/tables.py new file mode 100644 index 00000000..50d19188 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/tables.py @@ -0,0 +1,1219 @@ +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any + +import torch + +from cortical.fabric.backend.cuda.transition_execution.registry import ( + transition_primitive_executor_record_for_lowered_primitive, +) +from cortical.fabric.backend.planner import cuda_nn_primitive_backward_behavior + +from .buckets import backend_order_flat_buckets +from .executor_patterns import ( + temporal_executor_strategy_registry, + surface_for_temporal_row, +) +from .row_groups import ( + TEMPORAL_MESSAGE_BUCKET_ORDINAL, + TEMPORAL_PARAMETER_REDUCTION_BUCKET_ORDINAL, + TEMPORAL_READOUT_BUCKET_ORDINAL, +) +from .primitive_registry import ( + temporal_binding_kind_opcode, + temporal_full_tape_extra_state_factor, + temporal_primitive_opcode, + temporal_surface_opcode, + temporal_transition_tape_kind, +) + +_MESSAGE_RUNTIME_STATIC_TENSOR_KEYS = { + "recurrent_q_weight": ("recurrent_q_backend_order",), + "input_sender_kv_weight": ("input_sender_input_to_kv_weight",), + "input_group_kv_weight": ("input_group_input_to_kv_weight",), + "recurrent_sender_kv_weight": ("recurrent_sender_input_to_kv_weight_backend_order",), + "message_query_slot_weight": ("message_query_slot_weight",), + "message_query_nudge_scale": ("message_query_nudge_scale",), + "message_query_context_gate": ("message_query_context_gate",), + "message_sender_slot_key_weight": ("message_sender_slot_key_weight",), + "message_sender_context_key": ("message_sender_context_key",), + "input_sender_value_weight": ("input_sender_value_weight",), + "input_group_value_weight": ("input_group_value_weight",), + "recurrent_sender_value_weight": ("recurrent_sender_value_weight",), + "message_output_weight": ("message_output_weight",), +} +_READOUT_RUNTIME_STATIC_TENSOR_KEYS = { + "output_q": ("output_q",), + "value_to_output_weight": ("value_to_output_weight",), +} +_READOUT_RUNTIME_ATTR_KEYS = { + "output_cell_bias": ("output_cell_bias",), +} + + +@dataclass(frozen=True) +class TemporalTensorTableSlot: + bucket_ordinal: int + table_role: str + slot: int + key: str + semantic_kind: str + + +@dataclass(frozen=True) +class TemporalPrimitiveRow: + bucket_ordinal: int + receiver_start: int + receiver_stop: int + primitive: 
str + primitive_family: str + backward_behavior: str + inputs: tuple[str, ...] + outputs: tuple[str, ...] + attributes: tuple[tuple[str, str], ...] + flat_bucket_identity: tuple[str, ...] + parameter_inputs: tuple[str, ...] = () + + @property + def receiver_count(self) -> int: + return int(self.receiver_stop) - int(self.receiver_start) + + +@dataclass(frozen=True) +class TemporalTensorBindingRow: + binding_index: int + row_index: int + bucket_ordinal: int + surface: str + primitive: str + binding_kind: str + logical_name: str + source_bindings: tuple[str, ...] + flat_bucket_identity: tuple[str, ...] + + @property + def summary(self) -> str: + sources = ",".join(self.source_bindings) if self.source_bindings else "-" + return ( + f"binding={int(self.binding_index)}" + f",row={int(self.row_index)}" + f",surface={self.surface}" + f",primitive={self.primitive}" + f",bucket={int(self.bucket_ordinal)}" + f",kind={self.binding_kind}" + f",logical={self.logical_name}" + f",sources={sources}" + ) + + +@dataclass(frozen=True) +class TemporalForwardExecutorRow: + executor_id: int + executor_name: str + surface: str + bucket_ordinal: int + primitive_row_start: int + primitive_row_count: int + receiver_start: int + receiver_count: int + parameter_bindings: tuple[str, ...] + + @property + def scan_projection_summary(self) -> str: + if self.surface == "transition": + return f"transition_bucket={int(self.bucket_ordinal)}:{self.executor_name.removesuffix('_transition')}" + return f"{self.surface}={self.executor_name}" + + @property + def summary(self) -> str: + params = ",".join(self.parameter_bindings) if self.parameter_bindings else "-" + return ( + f"executor={self.executor_name}" + f",surface={self.surface}" + f",bucket={int(self.bucket_ordinal)}" + f",rows={int(self.primitive_row_start)}:{int(self.primitive_row_count)}" + f",receivers={int(self.receiver_start)}:{int(self.receiver_count)}" + f",params={params}" + ) + + +@dataclass(frozen=True) +class TemporalReverseExecutorRow: + executor_id: int + executor_name: str + surface: str + bucket_ordinal: int + primitive_row_start: int + primitive_row_count: int + receiver_start: int + receiver_count: int + parameter_bindings: tuple[str, ...] + + @property + def summary(self) -> str: + params = ",".join(self.parameter_bindings) if self.parameter_bindings else "-" + return ( + f"reverse_executor={self.executor_name}" + f",surface={self.surface}" + f",bucket={int(self.bucket_ordinal)}" + f",rows={int(self.primitive_row_start)}:{int(self.primitive_row_count)}" + f",receivers={int(self.receiver_start)}:{int(self.receiver_count)}" + f",params={params}" + ) + + +@dataclass(frozen=True) +class TemporalPrimitiveTablePlan: + bucket_count: int + tensor_slots: tuple[TemporalTensorTableSlot, ...] + primitive_rows: tuple[TemporalPrimitiveRow, ...] + tensor_bindings: tuple[TemporalTensorBindingRow, ...] 
= () + + @property + def primitive_names(self) -> tuple[str, ...]: + return tuple(dict.fromkeys(row.primitive for row in self.primitive_rows)) + + @property + def primitive_families(self) -> tuple[str, ...]: + return tuple(dict.fromkeys(row.primitive_family for row in self.primitive_rows)) + + @property + def review_summary(self) -> tuple[str, ...]: + return ( + "temporal_table_abi=flat_bucket_tensor_tables", + f"bucket_count={int(self.bucket_count)}", + f"primitive_row_count={len(self.primitive_rows)}", + f"tensor_binding_row_count={len(self.tensor_bindings)}", + "primitive_families=" + ",".join(self.primitive_families), + ) + + @property + def fingerprint(self) -> tuple[str, ...]: + return ( + f"bucket_count={int(self.bucket_count)}", + *( + "row=" + f"{row_index}:" + f"bucket={int(row.bucket_ordinal)}:" + f"receivers={int(row.receiver_start)}-{int(row.receiver_stop)}:" + f"primitive={row.primitive}:" + f"family={row.primitive_family}:" + f"backward={row.backward_behavior}:" + f"inputs={','.join(row.inputs)}:" + f"outputs={','.join(row.outputs)}:" + f"params={','.join(row.parameter_inputs)}:" + f"attrs={','.join(f'{key}={value}' for key, value in row.attributes)}" + for row_index, row in enumerate(self.primitive_rows) + ), + *( + "binding=" + f"{int(binding.binding_index)}:" + f"row={int(binding.row_index)}:" + f"bucket={int(binding.bucket_ordinal)}:" + f"surface={binding.surface}:" + f"primitive={binding.primitive}:" + f"kind={binding.binding_kind}:" + f"logical={binding.logical_name}:" + f"sources={','.join(binding.source_bindings)}" + for binding in self.tensor_bindings + ), + ) + + def review_disallowed_terms(self, terms: Sequence[str]) -> tuple[str, ...]: + haystack = "\n".join( + ( + *(slot.table_role for slot in self.tensor_slots), + *(slot.key for slot in self.tensor_slots), + *(slot.semantic_kind for slot in self.tensor_slots), + *(row.primitive for row in self.primitive_rows), + *(row.primitive_family for row in self.primitive_rows), + *(row.backward_behavior for row in self.primitive_rows), + *(item for row in self.primitive_rows for item in row.inputs), + *(item for row in self.primitive_rows for item in row.outputs), + *(item for row in self.primitive_rows for item in row.parameter_inputs), + *(item for row in self.primitive_rows for item in row.flat_bucket_identity), + *(binding.surface for binding in self.tensor_bindings), + *(binding.primitive for binding in self.tensor_bindings), + *(binding.binding_kind for binding in self.tensor_bindings), + *(binding.logical_name for binding in self.tensor_bindings), + *(item for binding in self.tensor_bindings for item in binding.source_bindings), + *(item for binding in self.tensor_bindings for item in binding.flat_bucket_identity), + ) + ) + return tuple(term for term in terms if term and term in haystack) + + +def build_temporal_primitive_table_plan( + runtime: Any, + static_tensors: dict[str, object], +) -> TemporalPrimitiveTablePlan: + tensor_slots: list[TemporalTensorTableSlot] = [] + primitive_rows: list[TemporalPrimitiveRow] = [] + buckets = tuple(backend_order_flat_buckets(runtime, static_tensors)) + _append_message_primitive_rows( + primitive_rows, + runtime=runtime, + receiver_count=sum(max(0, int(bucket.backend_stop) - int(bucket.backend_start)) for bucket in buckets), + ) + _append_readout_primitive_rows(primitive_rows, runtime=runtime) + _append_parameter_reduction_primitive_rows(primitive_rows, runtime=runtime, buckets=buckets) + for bucket_ordinal, bucket in enumerate(buckets): + transition_program = 
_transition_program_for_bucket(runtime, bucket) + _append_schema_slots( + tensor_slots, + bucket_ordinal=bucket_ordinal, + table_role="private_state", + schema=transition_program.private_state_schema, + ) + _append_schema_slots( + tensor_slots, + bucket_ordinal=bucket_ordinal, + table_role="public_state", + schema=transition_program.public_interface_schema, + ) + _append_schema_slots( + tensor_slots, + bucket_ordinal=bucket_ordinal, + table_role="transition_params", + schema=transition_program.parameter_schema, + ) + for op in transition_program.primitive_ops: + primitive = str(op.primitive) + behavior = cuda_nn_primitive_backward_behavior(primitive) + parameter_inputs = tuple(str(item) for item in getattr(op, "parameter_inputs", ())) + if primitive == "norm_or_identity": + parameter_inputs = (*parameter_inputs, "outnorm_eps") + elif primitive in {"diag_rtu", "diagonal_recurrence"}: + parameter_inputs = (*parameter_inputs, "activation_id") + primitive_rows.append( + TemporalPrimitiveRow( + bucket_ordinal=bucket_ordinal, + receiver_start=int(bucket.backend_start), + receiver_stop=int(bucket.backend_stop), + primitive=primitive, + primitive_family=str(behavior.family), + backward_behavior=str(behavior.behavior), + inputs=tuple(str(item) for item in op.inputs), + outputs=tuple(str(item) for item in op.outputs), + attributes=( + *tuple((str(key), str(value)) for key, value in op.attributes), + ("compiled_transition_lowering_kind", str(transition_program.lowering_kind)), + ("compiled_transition_binding_slot", str(int(transition_program.binding_slot))), + ), + flat_bucket_identity=( + "surface=transition", + *tuple(str(item) for item in bucket.flat_bucket_identity), + ), + parameter_inputs=tuple(dict.fromkeys(parameter_inputs)), + ) + ) + tensor_bindings = _build_tensor_binding_rows( + runtime=runtime, + rows=tuple(primitive_rows), + buckets=buckets, + ) + return TemporalPrimitiveTablePlan( + bucket_count=len(buckets), + tensor_slots=tuple(tensor_slots), + primitive_rows=tuple(primitive_rows), + tensor_bindings=tensor_bindings, + ) + + +def _transition_program_for_bucket(runtime: Any, bucket: Any) -> Any: + backend_ir = getattr(runtime, "backend_ir", None) + if backend_ir is None: + raise RuntimeError("Fabric runtime is missing backend_ir; transition programs must compile before execution") + binding_slot = int(getattr(bucket, "binding_slot", -1)) + transition_program = backend_ir.transition_program_for_binding_slot(binding_slot) + if int(getattr(transition_program, "binding_slot", -2)) != binding_slot: + raise RuntimeError( + "Fabric backend IR transition program binding slot mismatch: " + f"bucket={binding_slot}, program={getattr(transition_program, 'binding_slot', None)}" + ) + primitive_ops = tuple(getattr(transition_program, "primitive_ops", ()) or ()) + if not primitive_ops: + raise RuntimeError( + f"Fabric transition program for binding_slot={binding_slot} has no primitive rows; " + "cell transitions must lower before temporal table construction" + ) + return transition_program + + +def _flat_bucket_name(bucket: Any) -> str: + return str(getattr(bucket, "name", getattr(bucket, "binding_name", ""))) + + +def _append_message_primitive_rows( + rows: list[TemporalPrimitiveRow], + *, + runtime: Any, + receiver_count: int, +) -> None: + message_program = getattr(getattr(runtime, "backend_ir", None), "message_program", None) + if message_program is None: + raise RuntimeError( + "Fabric backend IR is missing compiled message_program; " + "message rules must compile before temporal primitive table 
construction" + ) + primitive_ops = tuple(getattr(message_program, "primitive_ops", ()) or ()) + if not primitive_ops: + return + rule_name = str(getattr(message_program, "rule_name", "message_rule")) + lowering_kind = str(getattr(message_program, "lowering_kind", "unsupported")) + for op in primitive_ops: + primitive = str(getattr(op, "primitive", "")) + behavior = cuda_nn_primitive_backward_behavior(primitive) + rows.append( + TemporalPrimitiveRow( + bucket_ordinal=TEMPORAL_MESSAGE_BUCKET_ORDINAL, + receiver_start=0, + receiver_stop=max(0, int(receiver_count)), + primitive=primitive, + primitive_family=f"message/{behavior.family}", + backward_behavior=str(behavior.behavior), + inputs=tuple(str(item) for item in getattr(op, "inputs", ())), + outputs=tuple(str(item) for item in getattr(op, "outputs", ())), + attributes=( + *tuple((str(key), str(value)) for key, value in getattr(op, "attributes", ())), + ("compiled_rule", rule_name), + ("compiled_lowering_kind", lowering_kind), + ), + flat_bucket_identity=( + "surface=message", + f"rule={rule_name}", + f"node={int(getattr(op, 'node_index', -1))}", + ), + parameter_inputs=tuple(str(item) for item in getattr(op, "parameter_bindings", ())), + ) + ) + + +def _append_readout_primitive_rows( + rows: list[TemporalPrimitiveRow], + *, + runtime: Any, +) -> None: + output_count = int(getattr(getattr(runtime, "backend_ir", None), "num_output_ports", 0)) + if output_count <= 0: + return + readout_program = getattr(getattr(runtime, "backend_ir", None), "readout_program", None) + if readout_program is None: + raise RuntimeError( + "Fabric backend IR is missing compiled readout_program; " + "readout rules must compile before temporal primitive table construction" + ) + for op in tuple(getattr(readout_program, "primitive_ops", ()) or ()): + primitive = str(getattr(op, "primitive", "")) + behavior = cuda_nn_primitive_backward_behavior(primitive) + surface = "readout_boundary" if primitive == "reduction_boundary" else "readout" + rows.append( + TemporalPrimitiveRow( + bucket_ordinal=TEMPORAL_READOUT_BUCKET_ORDINAL, + receiver_start=0, + receiver_stop=output_count, + primitive=primitive, + primitive_family=f"{surface}/{behavior.family}", + backward_behavior=str(behavior.behavior), + inputs=tuple(str(item) for item in getattr(op, "inputs", ())), + outputs=tuple(str(item) for item in getattr(op, "outputs", ())), + attributes=( + *tuple((str(key), str(value)) for key, value in getattr(op, "attributes", ())), + ("compiled_readout_rule", str(getattr(readout_program, "rule_name", "readout"))), + ("compiled_readout_lowering_kind", str(getattr(readout_program, "lowering_kind", ""))), + ), + flat_bucket_identity=( + f"surface={surface}", + f"rule={getattr(readout_program, 'rule_name', 'readout')}", + f"pool={getattr(readout_program, 'pool', 'unknown')}", + f"node={int(getattr(op, 'node_index', -1))}", + ), + parameter_inputs=tuple(str(item) for item in getattr(op, "parameter_inputs", ())), + ) + ) + + +def _append_parameter_reduction_primitive_rows( + rows: list[TemporalPrimitiveRow], + *, + runtime: Any, + buckets: tuple[Any, ...], +) -> None: + behavior = cuda_nn_primitive_backward_behavior("reduction_boundary") + for bucket_ordinal, bucket in enumerate(buckets): + transition_program = _transition_program_for_bucket(runtime, bucket) + for parameter in transition_program.parameter_bindings: + parameter_name = str(getattr(parameter, "parameter", "parameter")) + rows.append( + TemporalPrimitiveRow( + bucket_ordinal=TEMPORAL_PARAMETER_REDUCTION_BUCKET_ORDINAL, + 
receiver_start=int(bucket.backend_start), + receiver_stop=int(bucket.backend_stop), + primitive="reduction_boundary", + primitive_family=f"parameter/{behavior.family}", + backward_behavior=str(behavior.behavior), + inputs=(f"grad:{parameter_name}",), + outputs=(f"param_grad:{parameter_name}",), + attributes=(("parameter", parameter_name), ("bucket_ordinal", str(bucket_ordinal))), + flat_bucket_identity=( + "surface=parameter_reduction", + *tuple(str(item) for item in bucket.flat_bucket_identity), + f"parameter={parameter_name}", + ), + parameter_inputs=(parameter_name,), + ) + ) + + +def temporal_table_transition_recurrent_bucket_kinds( + table: TemporalPrimitiveTablePlan, +) -> dict[int, str]: + by_bucket: dict[int, str] = {} + for row in table.primitive_rows: + if temporal_transition_tape_kind(row.primitive) is None: + continue + bucket_ordinal = int(row.bucket_ordinal) + previous = by_bucket.get(bucket_ordinal) + if previous is not None and previous != row.primitive: + by_bucket[bucket_ordinal] = "ambiguous" + continue + by_bucket[bucket_ordinal] = row.primitive + return by_bucket + + +def temporal_table_transition_kind_labels(table: TemporalPrimitiveTablePlan) -> frozenset[str]: + return frozenset( + tape_kind + for row in table.primitive_rows + if (tape_kind := temporal_transition_tape_kind(row.primitive)) is not None + ) + + +def temporal_table_full_tape_extra_state_factors(table: TemporalPrimitiveTablePlan) -> dict[int, int]: + factors: dict[int, int] = {} + for row in table.primitive_rows: + factor = temporal_full_tape_extra_state_factor(row.primitive) + if factor <= 0: + continue + bucket_ordinal = int(row.bucket_ordinal) + factors[bucket_ordinal] = max(int(factor), factors.get(bucket_ordinal, 0)) + return factors + + +def temporal_primitive_rows_tensor(table: TemporalPrimitiveTablePlan) -> torch.Tensor: + rows: list[list[int]] = [] + for row in table.primitive_rows: + opcode = temporal_primitive_opcode(row.primitive) + rows.append( + [ + int(opcode), + int(row.receiver_start), + int(row.receiver_count), + int(row.bucket_ordinal), + ] + ) + if not rows: + return torch.empty((0, 4), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def temporal_reverse_executor_rows(table: TemporalPrimitiveTablePlan) -> tuple[TemporalReverseExecutorRow, ...]: + rows: list[TemporalReverseExecutorRow] = [] + message_rows = tuple( + (row_index, row) + for row_index, row in enumerate(table.primitive_rows) + if int(row.bucket_ordinal) == TEMPORAL_MESSAGE_BUCKET_ORDINAL and surface_for_temporal_row(row) == "message" + ) + if message_rows: + message_values = tuple(row for _row_index, row in message_rows) + message_pattern = temporal_executor_strategy_registry().match_reverse( + surface="message", + bucket_ordinal=TEMPORAL_MESSAGE_BUCKET_ORDINAL, + rows=message_values, + ) + if message_pattern is None: + raise RuntimeError( + "Fabric temporal reverse table has no primitive executor for message binding rows " + f"{tuple((row.primitive, row.parameter_inputs) for row in message_values)!r}; " + "register a reverse message primitive executor for these compiler rows" + ) + receiver_count = max((int(row.receiver_count) for _row_index, row in message_rows), default=0) + parameter_bindings = tuple( + dict.fromkeys( + str(parameter) for _row_index, row in message_rows for parameter in tuple(row.parameter_inputs) + ) + ) + rows.append( + TemporalReverseExecutorRow( + executor_id=message_pattern.executor_id, + executor_name=message_pattern.executor_name, + surface="message", + 
bucket_ordinal=TEMPORAL_MESSAGE_BUCKET_ORDINAL, + primitive_row_start=int(message_rows[0][0]), + primitive_row_count=len(message_rows), + receiver_start=0, + receiver_count=int(receiver_count), + parameter_bindings=parameter_bindings, + ) + ) + readout_rows = tuple( + (row_index, row) + for row_index, row in enumerate(table.primitive_rows) + if int(row.bucket_ordinal) == TEMPORAL_READOUT_BUCKET_ORDINAL + and surface_for_temporal_row(row) in {"readout", "readout_boundary"} + ) + if readout_rows: + readout_values = tuple(row for _row_index, row in readout_rows) + readout_pattern = temporal_executor_strategy_registry().match_reverse( + surface="readout", + bucket_ordinal=TEMPORAL_READOUT_BUCKET_ORDINAL, + rows=readout_values, + ) + if readout_pattern is None: + raise RuntimeError( + "Fabric temporal reverse table has no primitive executor for readout binding rows " + f"{tuple((row.primitive, row.parameter_inputs) for row in readout_values)!r}; " + "add a reverse readout primitive executor instead of borrowing the forward readout executor" + ) + _require_parameter_binding_rows(table, readout_values) + receiver_count = max((int(row.receiver_count) for _row_index, row in readout_rows), default=0) + parameter_bindings = tuple( + dict.fromkeys( + str(parameter) for _row_index, row in readout_rows for parameter in tuple(row.parameter_inputs) + ) + ) + rows.append( + TemporalReverseExecutorRow( + executor_id=readout_pattern.executor_id, + executor_name=readout_pattern.executor_name, + surface="readout", + bucket_ordinal=TEMPORAL_READOUT_BUCKET_ORDINAL, + primitive_row_start=int(readout_rows[0][0]), + primitive_row_count=len(readout_rows), + receiver_start=0, + receiver_count=int(receiver_count), + parameter_bindings=parameter_bindings, + ) + ) + seen_transition_rows: set[tuple[str, int, int]] = set() + covered_transition_rows: set[int] = set() + transition_parameters_by_bucket: dict[int, tuple[str, ...]] = {} + for bucket_ordinal in range(int(table.bucket_count)): + transition_parameters_by_bucket[bucket_ordinal] = tuple( + dict.fromkeys( + str(parameter) + for row in table.primitive_rows + if int(row.bucket_ordinal) == bucket_ordinal and surface_for_temporal_row(row) == "transition" + for parameter in tuple(row.parameter_inputs) + ) + ) + + def append_transition_reverse_row( + row_index: int, + row: TemporalPrimitiveRow, + ) -> None: + nonlocal rows + pattern = temporal_executor_strategy_registry().match_reverse( + surface="transition", + bucket_ordinal=int(row.bucket_ordinal), + rows=(row,), + ) + if pattern is None: + return + key = (pattern.executor_name, int(row.bucket_ordinal), int(row_index)) + if key in seen_transition_rows: + return + seen_transition_rows.add(key) + parameter_bindings = _transition_reverse_row_parameter_bindings( + row, + transition_parameters_by_bucket.get(int(row.bucket_ordinal), ()), + ) + rows.append( + TemporalReverseExecutorRow( + executor_id=int(pattern.executor_id), + executor_name=pattern.executor_name, + surface="transition", + bucket_ordinal=int(row.bucket_ordinal), + primitive_row_start=int(row_index), + primitive_row_count=1, + receiver_start=int(row.receiver_start), + receiver_count=int(row.receiver_count), + parameter_bindings=parameter_bindings, + ) + ) + covered_transition_rows.update( + _transition_rows_covered_by_reverse_row( + table, + bucket_ordinal=int(row.bucket_ordinal), + reverse_primitive_row_index=int(row_index), + ) + ) + + for row_index, row in enumerate(table.primitive_rows): + if surface_for_temporal_row(row) != "transition": + continue + 
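# Clarifying comment (added, describes the two-pass logic): dedicated-executor pass; rows whose primitive record is missing or uses only the + # generic weight/bias/eps parameter schema are skipped here and picked up by the reversed fallback pass below. +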
primitive_record = transition_primitive_executor_record_for_lowered_primitive(row.primitive) + if primitive_record is None or _transition_reverse_uses_generic_parameter_schema(primitive_record): + continue + append_transition_reverse_row(int(row_index), row) + for row_index, row in reversed(tuple(enumerate(table.primitive_rows))): + if surface_for_temporal_row(row) != "transition": + continue + if int(row_index) in covered_transition_rows: + continue + append_transition_reverse_row(int(row_index), row) + _require_transition_reverse_coverage(table, rows) + return tuple(rows) + + +def _transition_reverse_row_parameter_bindings( + row: TemporalPrimitiveRow, + bucket_parameter_bindings: tuple[str, ...], +) -> tuple[str, ...]: + primitive_record = transition_primitive_executor_record_for_lowered_primitive(row.primitive) + if primitive_record is None: + return tuple(row.parameter_inputs) + if not primitive_record.parameter_bindings: + return () + if _transition_reverse_uses_generic_parameter_schema(primitive_record): + return tuple(row.parameter_inputs) + return bucket_parameter_bindings + + +def _transition_reverse_uses_generic_parameter_schema(primitive_record: object) -> bool: + parameter_bindings = tuple(getattr(primitive_record, "parameter_bindings", ())) + return bool(parameter_bindings) and set(parameter_bindings) <= {"weight", "bias", "eps"} + + +def _transition_rows_covered_by_reverse_row( + table: TemporalPrimitiveTablePlan, + *, + bucket_ordinal: int, + reverse_primitive_row_index: int, +) -> set[int]: + covered = {int(reverse_primitive_row_index)} + primitive_row = table.primitive_rows[int(reverse_primitive_row_index)] + primitive_record = transition_primitive_executor_record_for_lowered_primitive(primitive_row.primitive) + if primitive_record is None: + return covered + covered_inputs = set(primitive_record.reverse_input_bindings) + covered_parameters = set(primitive_record.parameter_bindings) + for candidate_index, candidate_row in enumerate(table.primitive_rows): + if ( + int(candidate_row.bucket_ordinal) != int(bucket_ordinal) + or surface_for_temporal_row(candidate_row) != "transition" + ): + continue + if set(candidate_row.outputs) & covered_inputs: + covered.add(int(candidate_index)) + if set(candidate_row.parameter_inputs) & covered_parameters: + covered.add(int(candidate_index)) + return covered + + +def _require_transition_reverse_coverage( + table: TemporalPrimitiveTablePlan, + reverse_rows: list[TemporalReverseExecutorRow], +) -> None: + covered_row_indices: set[int] = set() + transition_rows_by_bucket: dict[int, tuple[tuple[int, TemporalPrimitiveRow], ...]] = {} + for row_index, row in enumerate(table.primitive_rows): + if surface_for_temporal_row(row) != "transition": + continue + transition_rows_by_bucket.setdefault(int(row.bucket_ordinal), ()) + transition_rows_by_bucket[int(row.bucket_ordinal)] = ( + *transition_rows_by_bucket[int(row.bucket_ordinal)], + (int(row_index), row), + ) + for reverse_row in reverse_rows: + if reverse_row.surface != "transition": + continue + primitive_row = table.primitive_rows[int(reverse_row.primitive_row_start)] + primitive_record = transition_primitive_executor_record_for_lowered_primitive(primitive_row.primitive) + covered_row_indices.add(int(reverse_row.primitive_row_start)) + if primitive_record is None: + continue + covered_inputs = set(primitive_record.reverse_input_bindings) + covered_parameters = set(primitive_record.parameter_bindings) + for candidate_index, candidate_row in 
transition_rows_by_bucket.get(int(reverse_row.bucket_ordinal), ()): + if set(candidate_row.outputs) & covered_inputs: + covered_row_indices.add(int(candidate_index)) + if set(candidate_row.parameter_inputs) & covered_parameters: + covered_row_indices.add(int(candidate_index)) + missing = tuple( + (row_index, row) + for row_index, row in enumerate(table.primitive_rows) + if surface_for_temporal_row(row) == "transition" and int(row_index) not in covered_row_indices + ) + if missing: + raise RuntimeError( + "Fabric temporal reverse table has transition primitive rows without compiler-owned adjoint coverage: " + + ", ".join( + f"row={int(row_index)} primitive={row.primitive} bucket={int(row.bucket_ordinal)}" + for row_index, row in missing + ) + + "; register reverse primitive executors or a composite reverse executor that declares the covered rows" + ) + + +def temporal_reverse_executor_rows_tensor(table: TemporalPrimitiveTablePlan) -> torch.Tensor: + rows = [ + [ + int(row.executor_id), + int(row.primitive_row_start), + int(row.primitive_row_count), + int(row.bucket_ordinal), + int(row.receiver_start), + int(row.receiver_count), + ] + for row in temporal_reverse_executor_rows(table) + ] + if not rows: + return torch.empty((0, 6), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def temporal_reverse_executor_summaries(table: TemporalPrimitiveTablePlan) -> tuple[str, ...]: + return tuple(row.summary for row in temporal_reverse_executor_rows(table)) + + +def temporal_tensor_binding_rows_tensor(table: TemporalPrimitiveTablePlan) -> torch.Tensor: + rows: list[list[int]] = [] + for binding in table.tensor_bindings: + surface_opcode = temporal_surface_opcode(binding.surface) + primitive_opcode = temporal_primitive_opcode(binding.primitive) + binding_kind_opcode = temporal_binding_kind_opcode(binding.binding_kind) + rows.append( + [ + int(binding.binding_index), + int(binding.row_index), + int(binding.bucket_ordinal), + int(surface_opcode), + int(primitive_opcode), + int(binding_kind_opcode), + len(binding.source_bindings), + ] + ) + if not rows: + return torch.empty((0, 7), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def validate_temporal_supported_scan_binding_projection(table: TemporalPrimitiveTablePlan) -> tuple[str, ...]: + return tuple(row.scan_projection_summary for row in temporal_forward_executor_rows(table)) + + +def temporal_forward_executor_rows(table: TemporalPrimitiveTablePlan) -> tuple[TemporalForwardExecutorRow, ...]: + rows: list[TemporalForwardExecutorRow] = [] + rows.append(_message_forward_executor_row(table)) + rows.append(_readout_forward_executor_row(table)) + rows.extend(_transition_forward_executor_rows(table)) + return tuple(rows) + + +def temporal_forward_executor_rows_tensor(table: TemporalPrimitiveTablePlan) -> torch.Tensor: + rows = [ + [ + int(row.executor_id), + int(row.primitive_row_start), + int(row.primitive_row_count), + int(row.bucket_ordinal), + int(row.receiver_start), + int(row.receiver_count), + ] + for row in temporal_forward_executor_rows(table) + ] + if not rows: + return torch.empty((0, 6), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def temporal_forward_executor_summaries(table: TemporalPrimitiveTablePlan) -> tuple[str, ...]: + return tuple(row.summary for row in temporal_forward_executor_rows(table)) + + +def temporal_tensor_binding_summaries(table: TemporalPrimitiveTablePlan) -> tuple[str, ...]: + return tuple(binding.summary for binding in table.tensor_bindings) + + +def 
_build_tensor_binding_rows( + *, + runtime: Any, + rows: tuple[TemporalPrimitiveRow, ...], + buckets: tuple[Any, ...], +) -> tuple[TemporalTensorBindingRow, ...]: + transition_sources_by_bucket = _transition_parameter_sources_by_bucket(runtime, buckets) + message_sources = _message_parameter_sources(runtime) + readout_sources = _readout_parameter_sources(runtime) + bindings: list[TemporalTensorBindingRow] = [] + for row_index, row in enumerate(rows): + surface = surface_for_temporal_row(row) + parameter_sources = _parameter_sources_for_row( + row, + surface=surface, + transition_sources_by_bucket=transition_sources_by_bucket, + message_sources=message_sources, + readout_sources=readout_sources, + ) + parameter_names = set(row.parameter_inputs) + for name in row.inputs: + kind = "parameter" if name in parameter_names else "input" + sources = parameter_sources.get(name, ()) + if kind == "parameter" and not sources: + raise RuntimeError( + "Fabric temporal tensor binding row is missing compiled parameter binding " + f"for {surface}:{row.primitive}:{name}" + ) + bindings.append( + TemporalTensorBindingRow( + binding_index=len(bindings), + row_index=int(row_index), + bucket_ordinal=int(row.bucket_ordinal), + surface=surface, + primitive=str(row.primitive), + binding_kind=kind, + logical_name=str(name), + source_bindings=tuple(sources), + flat_bucket_identity=tuple(row.flat_bucket_identity), + ) + ) + for name in row.parameter_inputs: + if name in row.inputs: + continue + sources = parameter_sources.get(name, ()) + if not sources: + raise RuntimeError( + "Fabric temporal tensor binding row is missing compiled parameter binding " + f"for {surface}:{row.primitive}:{name}" + ) + bindings.append( + TemporalTensorBindingRow( + binding_index=len(bindings), + row_index=int(row_index), + bucket_ordinal=int(row.bucket_ordinal), + surface=surface, + primitive=str(row.primitive), + binding_kind="parameter", + logical_name=str(name), + source_bindings=tuple(sources), + flat_bucket_identity=tuple(row.flat_bucket_identity), + ) + ) + for name in row.outputs: + bindings.append( + TemporalTensorBindingRow( + binding_index=len(bindings), + row_index=int(row_index), + bucket_ordinal=int(row.bucket_ordinal), + surface=surface, + primitive=str(row.primitive), + binding_kind="output", + logical_name=str(name), + source_bindings=(), + flat_bucket_identity=tuple(row.flat_bucket_identity), + ) + ) + return tuple(bindings) + + +def _message_forward_executor_row(table: TemporalPrimitiveTablePlan) -> TemporalForwardExecutorRow: + rows = [ + (row_index, row) + for row_index, row in enumerate(table.primitive_rows) + if int(row.bucket_ordinal) == TEMPORAL_MESSAGE_BUCKET_ORDINAL and surface_for_temporal_row(row) == "message" + ] + if not rows: + raise RuntimeError("Fabric temporal executor binding plan has no compiled message rows") + row_values = tuple(row for _row_index, row in rows) + pattern = temporal_executor_strategy_registry().match_forward( + surface="message", bucket_ordinal=TEMPORAL_MESSAGE_BUCKET_ORDINAL, rows=row_values + ) + if pattern is None: + raise RuntimeError( + "Fabric temporal executor binding plan has no primitive executor for message binding rows " + f"{tuple((row.primitive, row.parameter_inputs) for row in row_values)!r}; " + "register a message primitive executor for these compiler rows" + ) + _require_parameter_binding_rows(table, row_values) + return _executor_row_from_primitive_rows( + executor_name=pattern.executor_name, + surface="message", + bucket_ordinal=TEMPORAL_MESSAGE_BUCKET_ORDINAL, + 
rows=tuple(rows), + ) + + +def _readout_forward_executor_row(table: TemporalPrimitiveTablePlan) -> TemporalForwardExecutorRow: + rows = [ + (row_index, row) + for row_index, row in enumerate(table.primitive_rows) + if int(row.bucket_ordinal) == TEMPORAL_READOUT_BUCKET_ORDINAL + and surface_for_temporal_row(row) in {"readout", "readout_boundary"} + ] + if not rows: + raise RuntimeError("Fabric temporal executor binding plan has no compiled readout rows") + row_values = tuple(row for _row_index, row in rows) + pattern = temporal_executor_strategy_registry().match_forward( + surface="readout", bucket_ordinal=TEMPORAL_READOUT_BUCKET_ORDINAL, rows=row_values + ) + if pattern is None: + raise RuntimeError( + "Fabric temporal executor binding plan has no primitive executor for readout binding rows " + f"{tuple((row.primitive, row.parameter_inputs) for row in row_values)!r}; " + "register a readout primitive executor for these compiler rows" + ) + _require_parameter_binding_rows(table, row_values) + return _executor_row_from_primitive_rows( + executor_name=pattern.executor_name, + surface="readout", + bucket_ordinal=TEMPORAL_READOUT_BUCKET_ORDINAL, + rows=tuple(rows), + ) + + +def _transition_forward_executor_rows(table: TemporalPrimitiveTablePlan) -> tuple[TemporalForwardExecutorRow, ...]: + executor_rows: list[TemporalForwardExecutorRow] = [] + for bucket_ordinal in range(int(table.bucket_count)): + rows = [ + (row_index, row) + for row_index, row in enumerate(table.primitive_rows) + if int(row.bucket_ordinal) == bucket_ordinal and surface_for_temporal_row(row) == "transition" + ] + row_values = tuple(row for _row_index, row in rows) + pattern = temporal_executor_strategy_registry().match_forward( + surface="transition", + bucket_ordinal=bucket_ordinal, + rows=row_values, + ) + if pattern is not None: + _require_parameter_binding_rows(table, row_values) + executor_rows.append( + _executor_row_from_primitive_rows( + executor_name=pattern.executor_name, + surface="transition", + bucket_ordinal=bucket_ordinal, + rows=tuple(rows), + ) + ) + continue + for row_index, row in rows: + primitive_pattern = temporal_executor_strategy_registry().match_forward( + surface="transition", + bucket_ordinal=bucket_ordinal, + rows=(row,), + ) + if primitive_pattern is None: + raise RuntimeError( + "Fabric temporal executor binding plan has no primitive executor for transition binding row " + f"bucket={bucket_ordinal} row={(row.primitive, row.parameter_inputs)!r}; " + "register a transition primitive executor for this compiler row" + ) + _require_parameter_binding_rows(table, (row,)) + executor_rows.append( + _executor_row_from_primitive_rows( + executor_name=primitive_pattern.executor_name, + surface="transition", + bucket_ordinal=bucket_ordinal, + rows=((row_index, row),), + ) + ) + return tuple(executor_rows) + + +def _executor_row_from_primitive_rows( + *, + executor_name: str, + surface: str, + bucket_ordinal: int, + rows: tuple[tuple[int, TemporalPrimitiveRow], ...], +) -> TemporalForwardExecutorRow: + if not rows: + raise RuntimeError(f"Fabric temporal forward executor {executor_name} has no primitive rows") + pattern = temporal_executor_strategy_registry().match_forward( + surface=surface, + bucket_ordinal=int(bucket_ordinal), + rows=tuple(row for _row_index, row in rows), + ) + if pattern is None or pattern.executor_name != executor_name: + raise RuntimeError( + f"Fabric temporal forward executor {executor_name} is not registered for surface={surface}, " + f"bucket={int(bucket_ordinal)}" + ) + row_indices = 
tuple(int(row_index) for row_index, _row in rows) + expected_indices = tuple(range(row_indices[0], row_indices[0] + len(row_indices))) + if row_indices != expected_indices: + raise RuntimeError( + f"Fabric temporal forward executor {executor_name} requires contiguous primitive rows; got {row_indices!r}" + ) + receiver_start = int(rows[0][1].receiver_start) + receiver_count = int(rows[0][1].receiver_count) + for _row_index, row in rows: + if int(row.receiver_start) != receiver_start or int(row.receiver_count) != receiver_count: + raise RuntimeError(f"Fabric temporal forward executor {executor_name} has inconsistent receiver ranges") + parameters = tuple(dict.fromkeys(str(parameter) for _row_index, row in rows for parameter in row.parameter_inputs)) + return TemporalForwardExecutorRow( + executor_id=int(pattern.executor_id), + executor_name=executor_name, + surface=surface, + bucket_ordinal=int(bucket_ordinal), + primitive_row_start=int(row_indices[0]), + primitive_row_count=len(row_indices), + receiver_start=receiver_start, + receiver_count=receiver_count, + parameter_bindings=parameters, + ) + + +def _require_parameter_binding_rows( + table: TemporalPrimitiveTablePlan, + rows: Sequence[TemporalPrimitiveRow], +) -> None: + available = { + (int(binding.row_index), str(binding.logical_name)) + for binding in table.tensor_bindings + if binding.binding_kind == "parameter" and binding.source_bindings + } + for row in rows: + row_index = table.primitive_rows.index(row) + for parameter in row.parameter_inputs: + if (int(row_index), str(parameter)) not in available: + raise RuntimeError( + "Fabric temporal executor binding plan is missing compiler-owned tensor binding " + f"for row={row_index}, parameter={parameter!r}" + ) + + +def _transition_parameter_sources_by_bucket( + runtime: Any, + buckets: tuple[Any, ...], +) -> dict[int, dict[str, tuple[str, ...]]]: + sources_by_bucket: dict[int, dict[str, tuple[str, ...]]] = {} + for bucket_ordinal, bucket in enumerate(buckets): + program = _transition_program_for_bucket(runtime, bucket) + parameter_sources: dict[str, tuple[str, ...]] = {} + for item in getattr(program, "parameter_bindings", ()) or (): + parameter = str(getattr(item, "parameter", "")) + if not parameter: + continue + parameter_sources[parameter] = tuple( + f"{getattr(binding, 'kind', '')}:{getattr(binding, 'source', '')}" + for binding in tuple(getattr(item, "bindings", ()) or ()) + ) + population_materialized = bucket.static_tensors.get("population_materialized") + population_params = ( + population_materialized.get(_flat_bucket_name(bucket)) + if isinstance(population_materialized, dict) + else None + ) + if isinstance(population_params, dict): + for scalar_name in ("outnorm_eps", "activation_id"): + if torch.is_tensor(population_params.get(scalar_name)): + parameter_sources.setdefault(scalar_name, (f"cell_param:{scalar_name}",)) + sources_by_bucket[int(bucket_ordinal)] = parameter_sources + return sources_by_bucket + + +def _message_parameter_sources(runtime: Any) -> dict[str, tuple[str, ...]]: + message_rule = getattr(getattr(runtime, "backend_ir", None), "message_rule", None) + out: dict[str, tuple[str, ...]] = {} + for parameter in tuple(getattr(message_rule, "parameters", ()) or ()): + name = str(getattr(parameter, "name", "")) + if not name: + continue + declaration_sources = ( + "message_parameter:" + f"role={getattr(parameter, 'role', '')}:" + f"sharing={getattr(parameter, 'sharing_scope', '')}:" + f"groups={int(getattr(parameter, 'group_count', 1))}", + ) + static_sources = 
tuple(f"static_tensor:{key}" for key in _MESSAGE_RUNTIME_STATIC_TENSOR_KEYS.get(name, ())) + out[name] = declaration_sources + static_sources + return out + + +def _readout_parameter_sources(runtime: Any) -> dict[str, tuple[str, ...]]: + readout_program = getattr(getattr(runtime, "backend_ir", None), "readout_program", None) + if readout_program is None: + return {} + source = ( + "readout_parameter:" + f"rule={getattr(readout_program, 'rule_name', '')}:" + f"pool={getattr(readout_program, 'pool', '')}:" + f"slots={int(getattr(readout_program, 'readout_slots', 0))}" + ) + return { + str(parameter): ( + source, + *(f"static_tensor:{key}" for key in _READOUT_RUNTIME_STATIC_TENSOR_KEYS.get(str(parameter), ())), + *(f"runtime_attr:{key}" for key in _READOUT_RUNTIME_ATTR_KEYS.get(str(parameter), ())), + ) + for op in tuple(getattr(readout_program, "primitive_ops", ()) or ()) + for parameter in tuple(getattr(op, "parameter_inputs", ()) or ()) + } + + +def _parameter_sources_for_row( + row: TemporalPrimitiveRow, + *, + surface: str, + transition_sources_by_bucket: dict[int, dict[str, tuple[str, ...]]], + message_sources: dict[str, tuple[str, ...]], + readout_sources: dict[str, tuple[str, ...]], +) -> dict[str, tuple[str, ...]]: + if surface == "transition": + return transition_sources_by_bucket.get(int(row.bucket_ordinal), {}) + if surface == "parameter_reduction": + bucket_attr = next((value for key, value in row.attributes if key == "bucket_ordinal"), None) + if bucket_attr is None: + return {} + return transition_sources_by_bucket.get(int(bucket_attr), {}) + if surface == "message": + return message_sources + if surface in {"readout", "readout_boundary"}: + return readout_sources + return {} + + +def _append_schema_slots( + slots: list[TemporalTensorTableSlot], + *, + bucket_ordinal: int, + table_role: str, + schema: tuple[Any, ...], +) -> None: + for slot, item in enumerate(schema): + slots.append( + TemporalTensorTableSlot( + bucket_ordinal=int(bucket_ordinal), + table_role=table_role, + slot=int(slot), + key=str(getattr(item, "name", "")), + semantic_kind=str(getattr(item, "semantic_kind", "generic")), + ) + ) + + +__all__ = [ + "TemporalPrimitiveRow", + "TemporalPrimitiveTablePlan", + "TemporalForwardExecutorRow", + "TemporalReverseExecutorRow", + "TemporalTensorBindingRow", + "TemporalTensorTableSlot", + "build_temporal_primitive_table_plan", + "temporal_forward_executor_rows", + "temporal_forward_executor_rows_tensor", + "temporal_forward_executor_summaries", + "temporal_primitive_rows_tensor", + "temporal_reverse_executor_rows", + "temporal_reverse_executor_rows_tensor", + "temporal_reverse_executor_summaries", + "temporal_tensor_binding_rows_tensor", + "temporal_tensor_binding_summaries", + "temporal_table_full_tape_extra_state_factors", + "temporal_table_transition_recurrent_bucket_kinds", + "temporal_table_transition_kind_labels", + "validate_temporal_supported_scan_binding_projection", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/compiler/verification.py b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/verification.py new file mode 100644 index 00000000..5fcbc4a6 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/compiler/verification.py @@ -0,0 +1,425 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +from .executor_bindings import ( + TemporalExecutorBindingPlan, + build_temporal_forward_executor_binding_plan, + build_temporal_reverse_executor_binding_plan, +) +from 
.executor_patterns import surface_for_temporal_row +from .primitive_dispatch import ( + TemporalPrimitiveExecutorPlan, + build_temporal_primitive_executor_plan, +) +from .tables import TemporalPrimitiveTablePlan + + +TemporalCompilerPassName = Literal[ + "semantic_ir", + "normalize_graph_message_cell", + "shape_dtype_device_analysis", + "primitive_row_ir", + "canonicalize_row_groups", + "assign_tensor_roles_layouts", + "plan_active_regions", + "plan_boundary_carry_tape", + "plan_memory_liveness_workspace", + "plan_forward_physical", + "plan_backward_physical", + "match_strategies", + "filter_strategy_legality", + "select_strategy_cost", + "build_executable_launch_plan", + "emit_audit_metadata", +] + +TemporalStrategyRejectionCode = Literal[ + "UNSUPPORTED_PATTERN", + "UNSUPPORTED_DTYPE", + "UNSUPPORTED_LAYOUT", + "INSUFFICIENT_WORKSPACE", + "RESET_POLICY_MISMATCH", + "TAPE_POLICY_MISMATCH", + "DEVICE_CAPABILITY_MISMATCH", + "SHAPE_OUT_OF_RANGE", + "MISSING_REQUIRED_BINDING", + "MISSING_BACKWARD_COVERAGE", + "HIDDEN_FALLBACK_ROUTE", + "ABI_VERSION_MISMATCH", + "UNVERIFIED_REWRITE", +] + + +_TEMPORAL_COMPILER_PASS_PIPELINE: tuple[TemporalCompilerPassName, ...] = ( + "semantic_ir", + "normalize_graph_message_cell", + "shape_dtype_device_analysis", + "primitive_row_ir", + "canonicalize_row_groups", + "assign_tensor_roles_layouts", + "plan_active_regions", + "plan_boundary_carry_tape", + "plan_memory_liveness_workspace", + "plan_forward_physical", + "plan_backward_physical", + "match_strategies", + "filter_strategy_legality", + "select_strategy_cost", + "build_executable_launch_plan", + "emit_audit_metadata", +) + +_TEMPORAL_STRATEGY_REJECTION_CODES: tuple[TemporalStrategyRejectionCode, ...] = ( + "UNSUPPORTED_PATTERN", + "UNSUPPORTED_DTYPE", + "UNSUPPORTED_LAYOUT", + "INSUFFICIENT_WORKSPACE", + "RESET_POLICY_MISMATCH", + "TAPE_POLICY_MISMATCH", + "DEVICE_CAPABILITY_MISMATCH", + "SHAPE_OUT_OF_RANGE", + "MISSING_REQUIRED_BINDING", + "MISSING_BACKWARD_COVERAGE", + "HIDDEN_FALLBACK_ROUTE", + "ABI_VERSION_MISMATCH", + "UNVERIFIED_REWRITE", +) + +_TEMPORAL_COMPILER_SCHEMA_VERSIONS = ( + "PrimitiveRowABI=1", + "TensorRoleTableABI=1", + "ExecutorPlanABI=1", + "BackwardTapeABI=1", + "MetadataSchemaABI=1", +) + + +@dataclass(frozen=True) +class TemporalEffectAnnotation: + row_index: int + surface: str + bucket_ordinal: int + effect: str + target: str + + @property + def summary(self) -> str: + return ( + f"row={int(self.row_index)}" + f",surface={self.surface}" + f",bucket={int(self.bucket_ordinal)}" + f",effect={self.effect}" + f",target={self.target}" + ) + + +@dataclass(frozen=True) +class TemporalVerifierIssue: + code: TemporalStrategyRejectionCode + pass_name: TemporalCompilerPassName + severity: Literal["error", "blocker"] + subject: str + reason: str + row_index: int | None = None + surface: str | None = None + bucket_ordinal: int | None = None + + @property + def summary(self) -> str: + row = "*" if self.row_index is None else str(int(self.row_index)) + bucket = "*" if self.bucket_ordinal is None else str(int(self.bucket_ordinal)) + surface = "-" if self.surface is None else self.surface + return ( + f"severity={self.severity}" + f",code={self.code}" + f",pass={self.pass_name}" + f",row={row}" + f",surface={surface}" + f",bucket={bucket}" + f",subject={self.subject}" + f",reason={self.reason}" + ) + + +@dataclass(frozen=True) +class TemporalCompilerVerificationReport: + status: Literal["ok", "blocked", "error"] + pass_pipeline: tuple[TemporalCompilerPassName, ...] 
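+ # pass_pipeline and schema_versions are recorded verbatim so audit metadata can
+ # pin the compiler pass order and row ABI versions that produced this report.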
+ schema_versions: tuple[str, ...] + effects: tuple[TemporalEffectAnnotation, ...] + issues: tuple[TemporalVerifierIssue, ...] + + @property + def effect_summaries(self) -> tuple[str, ...]: + return tuple(effect.summary for effect in self.effects) + + @property + def issue_summaries(self) -> tuple[str, ...]: + return tuple(issue.summary for issue in self.issues) + + @property + def explain_summary(self) -> tuple[str, ...]: + return ( + f"verification_status={self.status}", + "compiler_pass_pipeline=" + "->".join(self.pass_pipeline), + "schema_versions=" + ",".join(self.schema_versions), + f"effect_count={len(self.effects)}", + f"issue_count={len(self.issues)}", + *self.issue_summaries, + ) + + +def temporal_compiler_pass_pipeline() -> tuple[TemporalCompilerPassName, ...]: + return _TEMPORAL_COMPILER_PASS_PIPELINE + + +def temporal_strategy_rejection_codes() -> tuple[TemporalStrategyRejectionCode, ...]: + return _TEMPORAL_STRATEGY_REJECTION_CODES + + +def temporal_compiler_schema_versions() -> tuple[str, ...]: + return _TEMPORAL_COMPILER_SCHEMA_VERSIONS + + +def verify_temporal_primitive_table( + table: TemporalPrimitiveTablePlan, + *, + executor_plan: TemporalPrimitiveExecutorPlan | None = None, + forward_binding_plan: TemporalExecutorBindingPlan | None = None, + reverse_binding_plan: TemporalExecutorBindingPlan | None = None, +) -> TemporalCompilerVerificationReport: + executor_plan = build_temporal_primitive_executor_plan(table) if executor_plan is None else executor_plan + forward_binding_plan = ( + build_temporal_forward_executor_binding_plan(table) if forward_binding_plan is None else forward_binding_plan + ) + reverse_binding_plan = ( + build_temporal_reverse_executor_binding_plan(table) if reverse_binding_plan is None else reverse_binding_plan + ) + issues: list[TemporalVerifierIssue] = [] + effects: list[TemporalEffectAnnotation] = [] + + if int(table.bucket_count) <= 0: + issues.append( + TemporalVerifierIssue( + code="UNSUPPORTED_PATTERN", + pass_name="primitive_row_ir", + severity="error", + subject="bucket_count", + reason="temporal table must contain at least one transition bucket", + ) + ) + if not table.primitive_rows: + issues.append( + TemporalVerifierIssue( + code="UNSUPPORTED_PATTERN", + pass_name="primitive_row_ir", + severity="error", + subject="primitive_rows", + reason="temporal table has no primitive rows", + ) + ) + + parameter_bindings = { + (int(binding.row_index), str(binding.logical_name)) + for binding in table.tensor_bindings + if str(binding.binding_kind) == "parameter" + } + allowed_surfaces = {"message", "readout", "readout_boundary", "transition", "parameter_reduction"} + for row_index, row in enumerate(table.primitive_rows): + surface = surface_for_temporal_row(row) + bucket_ordinal = int(row.bucket_ordinal) + if not str(row.primitive): + issues.append( + _row_issue( + "UNSUPPORTED_PATTERN", + "primitive_row_ir", + "error", + "primitive", + "primitive row has no primitive name", + row_index, + surface, + bucket_ordinal, + ) + ) + if surface not in allowed_surfaces: + issues.append( + _row_issue( + "UNSUPPORTED_PATTERN", + "assign_tensor_roles_layouts", + "error", + "surface", + f"unsupported temporal surface {surface!r}", + row_index, + surface, + bucket_ordinal, + ) + ) + if int(row.receiver_count) < 0: + issues.append( + _row_issue( + "SHAPE_OUT_OF_RANGE", + "shape_dtype_device_analysis", + "error", + "receiver_count", + "receiver count must be non-negative", + row_index, + surface, + bucket_ordinal, + ) + ) + for parameter in 
tuple(row.parameter_inputs): + if (int(row_index), str(parameter)) not in parameter_bindings: + issues.append( + _row_issue( + "MISSING_REQUIRED_BINDING", + "assign_tensor_roles_layouts", + "error", + str(parameter), + "parameter input has no tensor binding row", + row_index, + surface, + bucket_ordinal, + ) + ) + effects.extend(_effects_for_row(row, row_index=row_index, surface=surface)) + + for contract in executor_plan.contracts: + if contract.status == "implemented": + continue + code: TemporalStrategyRejectionCode = ( + "MISSING_BACKWARD_COVERAGE" if contract.backward_executor in {"", "unregistered"} else "UNSUPPORTED_PATTERN" + ) + issues.append( + TemporalVerifierIssue( + code=code, + pass_name="filter_strategy_legality", + severity="blocker", + subject=f"{contract.surface}:{contract.primitive}", + reason=contract.reason, + row_index=contract.row_index, + surface=contract.surface, + bucket_ordinal=contract.bucket_ordinal, + ) + ) + for group in executor_plan.fusion_groups: + if group.status == "implemented": + continue + issues.append( + TemporalVerifierIssue( + code="UNSUPPORTED_PATTERN", + pass_name="filter_strategy_legality", + severity="blocker", + subject=f"fusion:{group.group_name}", + reason=group.reason, + surface=group.surface, + bucket_ordinal=group.bucket_ordinal, + ) + ) + for binding_plan in (forward_binding_plan, reverse_binding_plan): + for blocker in binding_plan.blockers: + issues.append( + TemporalVerifierIssue( + code=blocker.code, + pass_name="assign_tensor_roles_layouts", + severity="error", + subject=f"{blocker.direction}:{blocker.executor_name}", + reason=blocker.reason, + surface=blocker.surface, + bucket_ordinal=blocker.bucket_ordinal, + ) + ) + + status: Literal["ok", "blocked", "error"] + if any(issue.severity == "error" for issue in issues): + status = "error" + elif issues: + status = "blocked" + else: + status = "ok" + return TemporalCompilerVerificationReport( + status=status, + pass_pipeline=_TEMPORAL_COMPILER_PASS_PIPELINE, + schema_versions=_TEMPORAL_COMPILER_SCHEMA_VERSIONS, + effects=tuple(effects), + issues=tuple(issues), + ) + + +def _row_issue( + code: TemporalStrategyRejectionCode, + pass_name: TemporalCompilerPassName, + severity: Literal["error", "blocker"], + subject: str, + reason: str, + row_index: int, + surface: str, + bucket_ordinal: int, +) -> TemporalVerifierIssue: + return TemporalVerifierIssue( + code=code, + pass_name=pass_name, + severity=severity, + subject=subject, + reason=reason, + row_index=int(row_index), + surface=surface, + bucket_ordinal=int(bucket_ordinal), + ) + + +def _effects_for_row( + row: object, + *, + row_index: int, + surface: str, +) -> tuple[TemporalEffectAnnotation, ...]: + bucket_ordinal = int(getattr(row, "bucket_ordinal", 0)) + primitive = str(getattr(row, "primitive", "")) + if surface == "message": + effects = ( + ("state_read", "sender_public_state"), + ("parameter_read", ",".join(tuple(getattr(row, "parameter_inputs", ()) or ())) or "-"), + ("message_emit", primitive), + ) + elif surface == "readout": + effects = ( + ("state_read", "output_public_state"), + ("parameter_read", ",".join(tuple(getattr(row, "parameter_inputs", ()) or ())) or "-"), + ("output_emit", primitive), + ) + elif surface == "readout_boundary": + effects = (("materialization_boundary", primitive), ("output_emit", primitive)) + elif surface == "parameter_reduction": + effects = (("grad_read", primitive), ("parameter_grad_emit", ",".join(tuple(row.parameter_inputs)) or "-")) + elif surface == "transition": + effects = ( + 
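+ # Transition rows both read and write private state and record the primitive's tape policy.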
("state_read", "private_state"), + ("message_read", "projected_message"), + ("state_write", "private_state"), + ("tape_policy", primitive), + ) + else: + effects = (("unknown_effect", primitive),) + return tuple( + TemporalEffectAnnotation( + row_index=int(row_index), + surface=surface, + bucket_ordinal=bucket_ordinal, + effect=effect, + target=target, + ) + for effect, target in effects + ) + + +__all__ = [ + "TemporalCompilerVerificationReport", + "TemporalEffectAnnotation", + "TemporalVerifierIssue", + "temporal_compiler_pass_pipeline", + "temporal_compiler_schema_versions", + "temporal_strategy_rejection_codes", + "verify_temporal_primitive_table", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/__init__.py b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/__init__.py new file mode 100644 index 00000000..64bae397 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/__init__.py @@ -0,0 +1 @@ +"""Flat-bucket CUDA bindings, kernels, and layout helpers.""" diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_native_callables.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_native_callables.cuh new file mode 100644 index 00000000..94d50e0b --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_native_callables.cuh @@ -0,0 +1,331 @@ +// Generated by compiler/native_callables.py. +// Do not edit this catalog by hand; update the compiler registry and regenerate. +// Catalog fingerprint: 1556325898 + +#if defined(REGISTERED_TEMPORAL_NATIVE_FORWARD_TRANSITION_CATALOG) +inline const RegisteredTransitionForwardPrimitiveExecutor* registered_native_transition_forward_primitive_catalog_begin() { + static const RegisteredTransitionForwardPrimitiveExecutor kRegisteredNativeTransitionForwardPrimitiveCatalog[] = { + { + registered_temporal_stable_id_hash_constexpr("program_transition_linear_forward"), + "transition.linear.forward", + run_registered_transition_linear_forward_primitive, + }, + { + registered_temporal_stable_id_hash_constexpr("program_transition_recurrent_matmul_forward"), + "transition.matmul.forward", + run_registered_transition_matmul_forward_primitive, + }, + { + registered_temporal_stable_id_hash_constexpr("program_transition_gated_logspace_recurrence_forward"), + "transition.gated_logspace_recurrence.forward", + run_registered_transition_gated_logspace_forward_primitive, + }, + { + registered_temporal_stable_id_hash_constexpr("program_transition_norm_or_identity_forward"), + "transition.norm_or_identity.forward", + run_registered_transition_norm_or_identity_forward_primitive, + }, + { + registered_temporal_stable_id_hash_constexpr("program_transition_diag_rtu_forward"), + "transition.diag_rtu.forward", + run_registered_transition_diag_rtu_forward_primitive, + }, + { + registered_temporal_stable_id_hash_constexpr("program_transition_tanh_forward"), + "transition.tanh.forward", + run_registered_transition_tanh_forward_primitive, + }, + }; + return kRegisteredNativeTransitionForwardPrimitiveCatalog; +} + +inline const RegisteredTransitionForwardPrimitiveExecutor* registered_native_transition_forward_primitive_catalog_end() { + return registered_native_transition_forward_primitive_catalog_begin() + 6; +} +#undef REGISTERED_TEMPORAL_NATIVE_FORWARD_TRANSITION_CATALOG +#endif + +#if defined(REGISTERED_TEMPORAL_NATIVE_FORWARD_MESSAGE_CATALOG) +inline const 
RegisteredForwardMessageCarrierStrategy* registered_native_forward_message_catalog_begin() { + static const RegisteredForwardMessageCarrierStrategy kRegisteredNativeForwardMessageCatalog[] = { + { + registered_temporal_stable_id_hash_constexpr("native.forward.msg_attention_project.v1"), + "forward.message.neighborhood_attention_project", + bind_neighborhood_attention_project_message_handler, + run_neighborhood_attention_project_recurrent_kv, + run_neighborhood_attention_project_message, + nullptr, + nullptr, + nullptr, + nullptr, + }, + { + registered_temporal_stable_id_hash_constexpr("native.forward.msg_fixed_slot_context_nudge.v1"), + "forward.message.fixed_slot_context_nudge", + bind_fixed_slot_context_message_handler, + run_fixed_slot_context_recurrent_kv, + run_fixed_slot_context_message, + run_fixed_slot_context_keyless_readout_message, + run_fixed_slot_context_direct_keyless_readout_message, + run_fixed_slot_context_stream_readout_message, + run_fixed_slot_context_stream_transition_input, + }, + { + registered_temporal_stable_id_hash_constexpr("native.forward.msg_fixed_slot_context_gate.v1"), + "forward.message.fixed_slot_context_gate", + bind_fixed_slot_context_message_handler, + run_fixed_slot_context_recurrent_kv, + run_fixed_slot_context_message, + run_fixed_slot_context_keyless_readout_message, + run_fixed_slot_context_direct_keyless_readout_message, + run_fixed_slot_context_stream_readout_message, + run_fixed_slot_context_stream_transition_input, + }, + }; + return kRegisteredNativeForwardMessageCatalog; +} + +inline const RegisteredForwardMessageCarrierStrategy* registered_native_forward_message_catalog_end() { + return registered_native_forward_message_catalog_begin() + 3; +} +#undef REGISTERED_TEMPORAL_NATIVE_FORWARD_MESSAGE_CATALOG +#endif + +#if defined(REGISTERED_TEMPORAL_NATIVE_FORWARD_READOUT_CATALOG) +inline const RegisteredForwardReadoutStrategy* registered_native_forward_readout_catalog_begin() { + static const RegisteredForwardReadoutStrategy kRegisteredNativeForwardReadoutCatalog[] = { + { + registered_temporal_stable_id_hash_constexpr("native.forward.output_projection_reduction_boundary.v1"), + "native.forward.output_projection_reduction_boundary.v1", + bind_projection_reduction_boundary_readout_handler, + run_projection_reduction_boundary_readout_message, + run_projection_reduction_boundary_readout_projection, + run_projection_reduction_boundary_readout_projection_into, + }, + }; + return kRegisteredNativeForwardReadoutCatalog; +} + +inline const RegisteredForwardReadoutStrategy* registered_native_forward_readout_catalog_end() { + return registered_native_forward_readout_catalog_begin() + 1; +} +#undef REGISTERED_TEMPORAL_NATIVE_FORWARD_READOUT_CATALOG +#endif + +#if defined(REGISTERED_TEMPORAL_NATIVE_REVERSE_MESSAGE_CATALOG) +inline const RegisteredReverseMessageStrategy* registered_native_reverse_message_catalog_begin() { + static const RegisteredReverseMessageStrategy kRegisteredNativeReverseMessageCatalog[] = { + { + registered_temporal_stable_id_hash_constexpr("native.reverse.msg_attention_project.v1"), + "reverse.message.neighborhood_attention_project", + run_neighborhood_attention_project_recurrent_kv_backward, + run_neighborhood_attention_project_recurrent_message_backward, + run_neighborhood_attention_project_initial_recurrent_kv_backward, + run_neighborhood_attention_project_boundary_kv_backward, + run_neighborhood_attention_project_recurrent_kv_forward_recompute, + }, + { + 
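+ // This entry and the fixed_slot_context_gate entry below bind the same
+ // run_fixed_slot_context_* handlers; they differ only in registered stable id and strategy name.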
registered_temporal_stable_id_hash_constexpr("native.reverse.msg_fixed_slot_context_nudge.v1"), + "reverse.message.fixed_slot_context_nudge", + run_fixed_slot_context_recurrent_kv_backward, + run_fixed_slot_context_recurrent_message_backward, + run_fixed_slot_context_initial_recurrent_kv_backward, + run_fixed_slot_context_boundary_kv_backward, + run_fixed_slot_context_recurrent_kv_forward_recompute, + }, + { + registered_temporal_stable_id_hash_constexpr("native.reverse.msg_fixed_slot_context_gate.v1"), + "reverse.message.fixed_slot_context_gate", + run_fixed_slot_context_recurrent_kv_backward, + run_fixed_slot_context_recurrent_message_backward, + run_fixed_slot_context_initial_recurrent_kv_backward, + run_fixed_slot_context_boundary_kv_backward, + run_fixed_slot_context_recurrent_kv_forward_recompute, + }, + }; + return kRegisteredNativeReverseMessageCatalog; +} + +inline const RegisteredReverseMessageStrategy* registered_native_reverse_message_catalog_end() { + return registered_native_reverse_message_catalog_begin() + 3; +} +#undef REGISTERED_TEMPORAL_NATIVE_REVERSE_MESSAGE_CATALOG +#endif + +#if defined(REGISTERED_TEMPORAL_NATIVE_REVERSE_READOUT_CATALOG) +inline const RegisteredReverseReadoutStrategy* registered_native_reverse_readout_catalog_begin() { + static const RegisteredReverseReadoutStrategy kRegisteredNativeReverseReadoutCatalog[] = { + { + registered_temporal_stable_id_hash_constexpr("native.reverse.output_projection_reduction_boundary.v1"), + "native.reverse.output_projection_reduction_boundary.v1", + run_projection_reduction_boundary_readout_backward, + run_projection_reduction_boundary_output_message_backward, + }, + }; + return kRegisteredNativeReverseReadoutCatalog; +} + +inline const RegisteredReverseReadoutStrategy* registered_native_reverse_readout_catalog_end() { + return registered_native_reverse_readout_catalog_begin() + 1; +} +#undef REGISTERED_TEMPORAL_NATIVE_REVERSE_READOUT_CATALOG +#endif + +#if defined(REGISTERED_TEMPORAL_NATIVE_REVERSE_TRANSITION_CATALOG) +inline const RegisteredTransitionReversePrimitiveExecutor* registered_native_transition_reverse_primitive_catalog_begin() { + static const RegisteredTransitionReversePrimitiveExecutor kRegisteredNativeTransitionReversePrimitiveCatalog[] = { + { + registered_temporal_stable_id_hash_constexpr("native.reverse.transition_gated_logspace.v1"), + registered_temporal_stable_id_hash_constexpr("program_transition_gated_logspace_recurrence_backward"), + 14, + 6, + 7, + 11, + "reverse.transition.gated_logspace", + run_registered_gated_logspace_reverse_transition_handler, + }, + { + registered_temporal_stable_id_hash_constexpr("native.reverse.transition_diag_rtu.v1"), + registered_temporal_stable_id_hash_constexpr("program_transition_diag_rtu_backward"), + 9, + 9, + 11, + 12, + "native.reverse.transition_diag_rtu.v1", + run_registered_diag_rtu_reverse_transition_handler, + }, + { + registered_temporal_stable_id_hash_constexpr("native.reverse.transition_linear_primitive.v1"), + registered_temporal_stable_id_hash_constexpr("program_transition_linear_backward"), + 2, + 2, + 2, + 3, + "reverse.transition.linear_primitive", + run_registered_linear_reverse_transition_handler, + }, + { + registered_temporal_stable_id_hash_constexpr("native.reverse.transition_matmul_primitive.v1"), + registered_temporal_stable_id_hash_constexpr("program_transition_recurrent_matmul_backward"), + 2, + 1, + 1, + 2, + "reverse.transition.matmul_primitive", + run_registered_matmul_reverse_transition_handler, + }, + { + 
registered_temporal_stable_id_hash_constexpr("native.reverse.transition_norm_or_identity_primitive.v1"), + registered_temporal_stable_id_hash_constexpr("program_transition_norm_or_identity_backward"), + 2, + 1, + 2, + 2, + "reverse.transition.norm_or_identity_primitive", + run_registered_norm_or_identity_reverse_transition_handler, + }, + { + registered_temporal_stable_id_hash_constexpr("native.reverse.transition_tanh.v1"), + registered_temporal_stable_id_hash_constexpr("program_transition_tanh_backward"), + 2, + 0, + 0, + 1, + "reverse.transition.tanh", + run_registered_tanh_reverse_transition_handler, + }, + }; + return kRegisteredNativeTransitionReversePrimitiveCatalog; +} + +inline const RegisteredTransitionReversePrimitiveExecutor* registered_native_transition_reverse_primitive_catalog_end() { + return registered_native_transition_reverse_primitive_catalog_begin() + 6; +} +#undef REGISTERED_TEMPORAL_NATIVE_REVERSE_TRANSITION_CATALOG +#endif + +#if defined(REGISTERED_TEMPORAL_NATIVE_PARAMETER_REDUCER_CATALOG) +inline const RegisteredParameterReducerHandler* registered_native_parameter_reducer_catalog_begin() { + static const RegisteredParameterReducerHandler kRegisteredNativeParameterReducerCatalog[] = { + { + registered_temporal_stable_id_hash_constexpr("native.reverse.parameter_reduction.readout_output.v1"), + "readout_output", + run_registered_readout_output_parameter_reducer_strategy, + }, + { + registered_temporal_stable_id_hash_constexpr("native.reverse.parameter_reduction.sender_kv_projection.v1"), + "sender_kv_projection", + run_registered_sender_kv_parameter_reducer_strategy, + }, + { + registered_temporal_stable_id_hash_constexpr("native.reverse.parameter_reduction.recurrent_query.v1"), + "recurrent_query", + run_registered_recurrent_query_parameter_reducer_strategy, + }, + { + registered_temporal_stable_id_hash_constexpr("native.reverse.parameter_reduction.transition.v1"), + "transition", + run_registered_transition_parameter_reducer_strategy, + }, + { + registered_temporal_stable_id_hash_constexpr("native.reverse.parameter_reduction.output_query.v1"), + "output_query", + run_registered_output_query_parameter_reducer_strategy, + }, + { + registered_temporal_stable_id_hash_constexpr("native.reverse.parameter_reduction.fixed_slot_context_message.v1"), + "fixed_slot_context_message", + run_registered_fixed_slot_context_message_parameter_reducer_strategy, + }, + }; + return kRegisteredNativeParameterReducerCatalog; +} + +inline const RegisteredParameterReducerHandler* registered_native_parameter_reducer_catalog_end() { + return registered_native_parameter_reducer_catalog_begin() + 6; +} +#undef REGISTERED_TEMPORAL_NATIVE_PARAMETER_REDUCER_CATALOG +#endif + +#if defined(REGISTERED_TEMPORAL_NATIVE_TRANSITION_TRAINABLE_REDUCER_CATALOG) +inline const RegisteredTransitionTrainableReducerHandler* registered_native_transition_trainable_reducer_catalog_begin() { + static const RegisteredTransitionTrainableReducerHandler kRegisteredNativeTransitionTrainableReducerCatalog[] = { + { + registered_temporal_stable_id_hash_constexpr("native.reverse.parameter_reduction.transition.materialized_base.v1"), + "materialized_base", + run_registered_transition_materialized_base_reducer, + }, + { + registered_temporal_stable_id_hash_constexpr("native.reverse.parameter_reduction.transition.materialized_delta.v1"), + "materialized_delta", + run_registered_transition_materialized_delta_reducer, + }, + { + 
registered_temporal_stable_id_hash_constexpr("native.reverse.parameter_reduction.transition.value_to_cell_msg_to_cell.v1"), + "value_to_cell_msg_to_cell", + run_registered_transition_value_to_cell_msg_to_cell_reducer, + }, + { + registered_temporal_stable_id_hash_constexpr("native.reverse.parameter_reduction.transition.value_to_cell_msg_out.v1"), + "value_to_cell_msg_out", + run_registered_transition_value_to_cell_msg_out_reducer, + }, + { + registered_temporal_stable_id_hash_constexpr("native.reverse.parameter_reduction.transition.recurrent_bias_slot_embed.v1"), + "recurrent_bias_slot_embed", + run_registered_transition_recurrent_bias_slot_embed_reducer, + }, + { + registered_temporal_stable_id_hash_constexpr("native.reverse.parameter_reduction.transition.recurrent_bias_cell_bias_proj.v1"), + "recurrent_bias_cell_bias_proj", + run_registered_transition_recurrent_bias_cell_bias_proj_reducer, + }, + }; + return kRegisteredNativeTransitionTrainableReducerCatalog; +} + +inline const RegisteredTransitionTrainableReducerHandler* registered_native_transition_trainable_reducer_catalog_end() { + return registered_native_transition_trainable_reducer_catalog_begin() + 6; +} +#undef REGISTERED_TEMPORAL_NATIVE_TRANSITION_TRAINABLE_REDUCER_CATALOG +#endif diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_binding.cpp b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_binding.cpp new file mode 100644 index 00000000..54a61e39 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_binding.cpp @@ -0,0 +1,217 @@ +#include + +#include + +#include + +std::vector flat_bucket_registered_temporal_fused_forward_program_validate_cuda( + const at::Tensor& primitive_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& forward_handler_rows, + const at::Tensor& reverse_handler_rows, + const at::Tensor& native_strategy_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& memory_liveness_rows, + int64_t schema_version); + +std::vector flat_bucket_registered_temporal_fused_backward_program_validate_cuda( + const at::Tensor& primitive_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& forward_handler_rows, + const at::Tensor& reverse_handler_rows, + const at::Tensor& native_strategy_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& memory_liveness_rows, + int64_t schema_version); + +std::vector flat_bucket_registered_temporal_fused_forward_program_cuda( + const at::Tensor& boundary_seq, + const at::Tensor& recurrent_hidden_initial_backend_order, + std::vector program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& forward_program_access_rows, + const at::Tensor& forward_transition_state_carry_rows, + const at::Tensor& forward_artifact_route_rows, + const at::Tensor& forward_artifact_merge_rows, + const at::Tensor& forward_output_route_rows, + const at::Tensor& readout_message_producer_consumer_rows, + const at::Tensor& message_transition_producer_consumer_rows, + 
std::vector<at::Tensor> forward_reset_tensors, + const at::Tensor& forward_reset_rows, + const at::Tensor& primitive_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& forward_handler_rows, + const at::Tensor& reverse_handler_rows, + const at::Tensor& native_strategy_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& transition_primitive_callable_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& memory_liveness_rows, + const at::Tensor& memory_runtime_schedule_rows, + const at::Tensor& physical_strategy_rows, + std::vector<at::Tensor> runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + std::vector<at::Tensor> forward_program_runtime_tensors, + const at::Tensor& forward_program_runtime_rows, + bool return_final_program_tensors, + bool return_reverse_artifacts, + int64_t schema_version); + +std::vector<at::Tensor> flat_bucket_registered_temporal_fused_forward_transition_program_cuda( + std::vector<at::Tensor> program_tensors, + const at::Tensor& program_tensor_binding_rows, + std::vector<at::Tensor> runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& primitive_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_handler_rows, + const at::Tensor& native_strategy_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& transition_primitive_callable_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& memory_liveness_rows, + const at::Tensor& forward_transition_state_carry_rows, + bool release_dead_input_bindings, + bool allow_terminal_local_state_outputs, + int64_t schema_version); + +std::vector<at::Tensor> flat_bucket_registered_temporal_fused_reverse_transition_program_cuda( + std::vector<at::Tensor> program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& primitive_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_handler_rows, + const at::Tensor& native_strategy_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& transition_primitive_callable_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& memory_liveness_rows, + int64_t schema_version, + bool return_state_grads); + +std::vector<at::Tensor> flat_bucket_registered_temporal_parameter_reducer_program_cuda( + const at::Tensor& parameter_reducer_rows, + const at::Tensor& parameter_reducer_strategy_rows, + const at::Tensor& parameter_reducer_trainable_role_rows, + const at::Tensor& parameter_reducer_runtime_metadata_rows, + const at::Tensor& transition_source_rows, + const at::Tensor& transition_trainable_rows, + std::vector<at::Tensor> sender_grad_weight_tensors, + std::vector<at::Tensor> sender_group_id_tensors, + const at::Tensor& sender_grouped_flags, + std::vector<at::Tensor> readout_grad_value_to_output_weight_tensors, + std::vector<at::Tensor> readout_grad_output_cell_bias_tensors, + std::vector<at::Tensor> recurrent_query_grad_tensors, + std::vector<at::Tensor> output_query_grad_tensors, + std::vector<at::Tensor> message_strategy_grad_tensors, + const at::Tensor& message_strategy_grad_rows, + std::vector<at::Tensor> transition_source_tensors, + std::vector<at::Tensor> transition_source_recurrent_cell_idx_tensors, + std::vector<at::Tensor> parameter_output_tensors, + std::vector<at::Tensor> trainable_param_tensors, + std::vector<at::Tensor> runtime_metadata_tensors, + int64_t coord_count, + int64_t head_dim, + int64_t value_dim, + int64_t
schema_version); + +std::vector<std::vector<at::Tensor>> flat_bucket_registered_temporal_fused_backward_program_cuda( + const at::Tensor& grad_output_window, + const at::Tensor& grad_carry_cells, + const at::Tensor& reverse_program_stage_rows, + std::vector<at::Tensor> reverse_artifact_tensors, + const at::Tensor& reverse_artifact_binding_rows, + const at::Tensor& reverse_artifact_role_rows, + const at::Tensor& reverse_artifact_access_rows, + const at::Tensor& forward_artifact_route_rows, + const at::Tensor& forward_artifact_merge_rows, + const at::Tensor& forward_output_route_rows, + const at::Tensor& reverse_artifact_consumer_route_rows, + std::vector<std::vector<at::Tensor>> reverse_reset_tensor_groups, + std::vector<at::Tensor> reverse_reset_row_groups, + std::vector<at::Tensor> program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& primitive_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& forward_handler_rows, + const at::Tensor& reverse_handler_rows, + const at::Tensor& native_strategy_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& reverse_span_output_rows, + const at::Tensor& transition_reverse_seed_role_rows, + const at::Tensor& transition_primitive_callable_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& memory_liveness_rows, + const at::Tensor& memory_runtime_schedule_rows, + const at::Tensor& physical_strategy_rows, + std::vector<at::Tensor> runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + std::vector<at::Tensor> reverse_program_runtime_tensors, + const at::Tensor& reverse_program_runtime_rows, + std::vector<std::vector<at::Tensor>> transition_program_tensor_groups, + std::vector<at::Tensor> transition_program_tensor_binding_row_groups, + std::vector<at::Tensor> transition_forward_executor_row_groups, + std::vector<at::Tensor> transition_reverse_executor_row_groups, + std::vector<at::Tensor> transition_forward_executor_binding_row_groups, + std::vector<at::Tensor> transition_reverse_executor_binding_row_groups, + std::vector<at::Tensor> transition_memory_liveness_row_groups, + std::vector<std::vector<at::Tensor>> transition_seed_tensor_groups, + std::vector<at::Tensor> transition_seed_row_groups, + std::vector<at::Tensor> transition_dynamic_binding_row_groups, + std::vector<at::Tensor> transition_output_keep_slot_row_groups, + std::vector<at::Tensor> transition_parameter_tensors, + const at::Tensor& transition_parameter_rows, + const at::Tensor& transition_recurrent_msg_output_rows, + const at::Tensor& transition_public_y_seed_rows, + const at::Tensor& transition_state_reset_rows, + const at::Tensor& transition_next_seed_output_rows, + bool return_window_start_transition_state_grads, + int64_t schema_version); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "fused_forward_program_validate", + &flat_bucket_registered_temporal_fused_forward_program_validate_cuda, + "Registered Fabric temporal fused forward program row validation over compiler products (CUDA)"); + m.def( + "fused_backward_program_validate", + &flat_bucket_registered_temporal_fused_backward_program_validate_cuda, + "Registered Fabric temporal fused backward program row validation over compiler products (CUDA)"); + m.def( + "fused_forward_program_execute", + &flat_bucket_registered_temporal_fused_forward_program_cuda, + "Registered Fabric temporal fused forward program execution over compiler products (CUDA)"); + m.def( + "fused_forward_transition_program_execute", + &flat_bucket_registered_temporal_fused_forward_transition_program_cuda, + "Registered Fabric 
temporal fused transition program span execution over compiler tensor bindings (CUDA)"); + m.def( + "fused_reverse_transition_program_execute", + &flat_bucket_registered_temporal_fused_reverse_transition_program_cuda, + "Registered Fabric temporal fused reverse transition program span execution over compiler tensor bindings (CUDA)"); + m.def( + "parameter_reducer_program_execute", + &flat_bucket_registered_temporal_parameter_reducer_program_cuda, + "Registered Fabric temporal parameter reducer row program over compiler reducer rows (CUDA)"); + m.def( + "fused_backward_program_execute", + &flat_bucket_registered_temporal_fused_backward_program_cuda, + "Registered Fabric temporal fused backward program span over compiler artifacts, tensor tables, reset rows, transition dynamic binding rows, transition seed rows, and selected executor rows (CUDA)"); +} diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py new file mode 100644 index 00000000..ead169f7 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py @@ -0,0 +1,616 @@ +from __future__ import annotations + +import os + +import torch + +from cortical.native.extension_loader import safe_load_extension + +_MOD_PATH = os.path.dirname(__file__) +_SRC_ROOT = os.path.normpath(os.path.join(_MOD_PATH, "..", "..", "..", "..", "..", "..")) +_EXT = None +_LAST_FUSED_BACKWARD_PROGRAM_STAGE_MEMORY_ROWS: torch.Tensor | None = None +_REVERSE_PROGRAM_STAGE_OPCODE = { + "output_grad_window": 1, + "readout_message_kv_step": 2, + "transition_step": 3, + "recurrent_message_boundary_initial_kv_step": 4, + "parameter_reducer_step": 5, +} + + +def _require_cpu_long_rows(tensor: torch.Tensor, *, name: str, columns: int) -> torch.Tensor: + if ( + tensor.device.type != "cpu" + or tensor.dtype != torch.long + or tensor.dim() != 2 + or int(tensor.shape[1]) != columns + ): + raise RuntimeError( + f"{name} must be a CPU int64 tensor with shape [N,{int(columns)}] before registered CUDA launch" + ) + return tensor.contiguous() + + +def _require_native_strategy_rows(tensor: torch.Tensor) -> torch.Tensor: + rows = _require_cpu_long_rows(tensor, name="native_strategy_rows", columns=17) + if int(rows.shape[0]) <= 0: + raise RuntimeError("native_strategy_rows must contain at least one compiler-owned native strategy row") + return rows + + +def _require_transition_primitive_callable_rows(tensor: torch.Tensor) -> torch.Tensor: + rows = _require_cpu_long_rows(tensor, name="transition_primitive_callable_rows", columns=6) + if int(rows.shape[0]) <= 0: + raise RuntimeError( + "transition_primitive_callable_rows must contain compiler-owned transition primitive callable rows" + ) + return rows + + +def _require_native_callable_output_rows(tensor: torch.Tensor) -> torch.Tensor: + rows = _require_cpu_long_rows(tensor, name="native_callable_output_rows", columns=12) + if int(rows.shape[0]) <= 0: + raise RuntimeError("native_callable_output_rows must contain compiler-owned native callable output rows") + return rows + + +def _require_reverse_span_output_rows(tensor: torch.Tensor) -> torch.Tensor: + rows = _require_cpu_long_rows(tensor, name="reverse_span_output_rows", columns=6) + if int(rows.shape[0]) <= 0: + raise RuntimeError("reverse_span_output_rows must contain compiler-owned reverse span output rows") + return rows + + +def 
_require_native_callable_binding_schema_rows(tensor: torch.Tensor) -> torch.Tensor: + rows = _require_cpu_long_rows(tensor, name="native_callable_binding_schema_rows", columns=10) + if int(rows.shape[0]) <= 0: + raise RuntimeError( + "native_callable_binding_schema_rows must contain compiler-owned native callable binding rows" + ) + return rows + + +def _require_physical_strategy_rows(tensor: torch.Tensor) -> torch.Tensor: + rows = _require_cpu_long_rows(tensor, name="physical_strategy_rows", columns=12) + if int(rows.shape[0]) <= 0: + raise RuntimeError("physical_strategy_rows must contain compiler-owned physical strategy rows") + return rows + + +def _require_reverse_program_stage( + reverse_program_stage_rows: torch.Tensor, + *, + stage_name: str, +) -> torch.Tensor: + rows = _require_cpu_long_rows(reverse_program_stage_rows, name="reverse_program_stage_rows", columns=10) + stage_opcode = _REVERSE_PROGRAM_STAGE_OPCODE[stage_name] + if not any(int(row[1]) == int(stage_opcode) for row in rows.tolist()): + raise RuntimeError( + "Registered temporal reverse program stage dispatch requires a compiler stage row: " + f"stage={stage_name!r}; opcode={int(stage_opcode)}" + ) + return rows + + +def _is_fused_backward_program_stage_memory_rows(tensor: torch.Tensor) -> bool: + return ( + tensor.device.type == "cpu" + and tensor.dtype == torch.long + and tensor.dim() == 2 + and int(tensor.shape[1]) == 5 + and int(tensor.shape[0]) > 0 + ) + + +def registered_temporal_fused_backward_program_stage_memory_rows() -> torch.Tensor | None: + return _LAST_FUSED_BACKWARD_PROGRAM_STAGE_MEMORY_ROWS + + +def _load_ext(): + global _EXT + if _EXT is not None: + return _EXT + _EXT = safe_load_extension( + name="fabric_flat_bucket_registered_program_cuda", + sources=[ + os.path.join(_MOD_PATH, "flat_bucket_registered_program_binding.cpp"), + os.path.join(_MOD_PATH, "flat_bucket_registered_program_kernels.cu"), + os.path.join(_SRC_ROOT, "cortical/fabric/backend/cuda/ops/dense_affine_kernels.cu"), + ], + extra_cflags=["-O3"], + extra_cuda_cflags=["-O3", "-Xptxas", "-O3"], + extra_include_paths=[_SRC_ROOT], + extra_ldflags=["-lcublas"], + verbose=False, + ) + return _EXT + + +def registered_temporal_fused_forward_program_validate_cuda( + *, + primitive_rows: torch.Tensor, + forward_executor_rows: torch.Tensor, + reverse_executor_rows: torch.Tensor, + forward_handler_rows: torch.Tensor, + reverse_handler_rows: torch.Tensor, + native_strategy_rows: torch.Tensor, + native_callable_binding_schema_rows: torch.Tensor, + native_callable_output_rows: torch.Tensor, + forward_executor_binding_rows: torch.Tensor, + reverse_executor_binding_rows: torch.Tensor, + memory_liveness_rows: torch.Tensor, + schema_version: int = 1, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + summary, forward_spans, reverse_spans = _load_ext().fused_forward_program_validate( + _require_cpu_long_rows(primitive_rows, name="primitive_rows", columns=4), + _require_cpu_long_rows(forward_executor_rows, name="forward_executor_rows", columns=6), + _require_cpu_long_rows(reverse_executor_rows, name="reverse_executor_rows", columns=6), + _require_cpu_long_rows(forward_handler_rows, name="forward_handler_rows", columns=11), + _require_cpu_long_rows(reverse_handler_rows, name="reverse_handler_rows", columns=11), + _require_native_strategy_rows(native_strategy_rows), + _require_native_callable_binding_schema_rows(native_callable_binding_schema_rows), + _require_native_callable_output_rows(native_callable_output_rows), + 
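+        # Column widths passed to _require_cpu_long_rows mirror the schema_version=1 row ABI:
+        # executor binding rows are [N, 8] and memory/liveness rows are [N, 10].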
_require_cpu_long_rows(forward_executor_binding_rows, name="forward_executor_binding_rows", columns=8), + _require_cpu_long_rows(reverse_executor_binding_rows, name="reverse_executor_binding_rows", columns=8), + _require_cpu_long_rows(memory_liveness_rows, name="memory_liveness_rows", columns=10), + int(schema_version), + ) + return summary, forward_spans, reverse_spans + + +def registered_temporal_fused_backward_program_validate_cuda( + *, + primitive_rows: torch.Tensor, + forward_executor_rows: torch.Tensor, + reverse_executor_rows: torch.Tensor, + forward_handler_rows: torch.Tensor, + reverse_handler_rows: torch.Tensor, + native_strategy_rows: torch.Tensor, + native_callable_binding_schema_rows: torch.Tensor, + native_callable_output_rows: torch.Tensor, + forward_executor_binding_rows: torch.Tensor, + reverse_executor_binding_rows: torch.Tensor, + memory_liveness_rows: torch.Tensor, + schema_version: int = 1, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + summary, forward_spans, reverse_spans = _load_ext().fused_backward_program_validate( + _require_cpu_long_rows(primitive_rows, name="primitive_rows", columns=4), + _require_cpu_long_rows(forward_executor_rows, name="forward_executor_rows", columns=6), + _require_cpu_long_rows(reverse_executor_rows, name="reverse_executor_rows", columns=6), + _require_cpu_long_rows(forward_handler_rows, name="forward_handler_rows", columns=11), + _require_cpu_long_rows(reverse_handler_rows, name="reverse_handler_rows", columns=11), + _require_native_strategy_rows(native_strategy_rows), + _require_native_callable_binding_schema_rows(native_callable_binding_schema_rows), + _require_native_callable_output_rows(native_callable_output_rows), + _require_cpu_long_rows(forward_executor_binding_rows, name="forward_executor_binding_rows", columns=8), + _require_cpu_long_rows(reverse_executor_binding_rows, name="reverse_executor_binding_rows", columns=8), + _require_cpu_long_rows(memory_liveness_rows, name="memory_liveness_rows", columns=10), + int(schema_version), + ) + return summary, forward_spans, reverse_spans + + +def registered_temporal_fused_forward_program_cuda( + *, + boundary_seq: torch.Tensor, + recurrent_hidden_initial_backend_order: torch.Tensor, + program_tensors: tuple[torch.Tensor, ...] | list[torch.Tensor], + program_tensor_binding_rows: torch.Tensor, + forward_program_access_rows: torch.Tensor, + forward_transition_state_carry_rows: torch.Tensor, + forward_artifact_route_rows: torch.Tensor, + forward_artifact_merge_rows: torch.Tensor, + forward_output_route_rows: torch.Tensor, + readout_message_producer_consumer_rows: torch.Tensor, + message_transition_producer_consumer_rows: torch.Tensor, + forward_reset_tensors: tuple[torch.Tensor, ...] | list[torch.Tensor], + forward_reset_rows: torch.Tensor, + primitive_rows: torch.Tensor, + forward_executor_rows: torch.Tensor, + reverse_executor_rows: torch.Tensor, + forward_handler_rows: torch.Tensor, + reverse_handler_rows: torch.Tensor, + native_strategy_rows: torch.Tensor, + native_callable_binding_schema_rows: torch.Tensor, + native_callable_output_rows: torch.Tensor, + transition_primitive_callable_rows: torch.Tensor, + forward_executor_binding_rows: torch.Tensor, + reverse_executor_binding_rows: torch.Tensor, + memory_liveness_rows: torch.Tensor, + memory_runtime_schedule_rows: torch.Tensor, + physical_strategy_rows: torch.Tensor, + runtime_buffer_tensors: tuple[torch.Tensor, ...] 
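# A sketch of the stage-dispatch precondition enforced by _require_reverse_program_stage
# above: column 1 of reverse_program_stage_rows carries the stage opcode, and a required
# stage must appear at least once before the fused backward program may run. The rows
# here are synthetic placeholders with zeros in the unused columns.
import torch

STAGE_OPCODE = {"transition_step": 3, "recurrent_message_boundary_initial_kv_step": 4}


def _has_stage_sketch(stage_rows: torch.Tensor, stage_name: str) -> bool:
    opcode = STAGE_OPCODE[stage_name]
    return bool((stage_rows[:, 1] == opcode).any())


if __name__ == "__main__":
    rows = torch.zeros((2, 10), dtype=torch.long)
    rows[0, 1] = 3  # transition_step stage row
    rows[1, 1] = 4  # boundary initial-kv stage row
    assert _has_stage_sketch(rows, "transition_step")
    assert _has_stage_sketch(rows, "recurrent_message_boundary_initial_kv_step")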
| list[torch.Tensor], + runtime_buffer_rows: torch.Tensor, + forward_program_runtime_tensors: tuple[torch.Tensor, ...] | list[torch.Tensor], + forward_program_runtime_rows: torch.Tensor, + return_final_program_tensors: bool = True, + return_reverse_artifacts: bool = False, + schema_version: int = 1, +) -> tuple[torch.Tensor, ...]: + outputs = _load_ext().fused_forward_program_execute( + boundary_seq.contiguous(), + recurrent_hidden_initial_backend_order.contiguous(), + [tensor.contiguous() for tensor in program_tensors], + _require_cpu_long_rows(program_tensor_binding_rows, name="program_tensor_binding_rows", columns=4), + _require_cpu_long_rows(forward_program_access_rows, name="forward_program_access_rows", columns=6), + _require_cpu_long_rows( + forward_transition_state_carry_rows, + name="forward_transition_state_carry_rows", + columns=3, + ), + _require_cpu_long_rows(forward_artifact_route_rows, name="forward_artifact_route_rows", columns=10), + _require_cpu_long_rows(forward_artifact_merge_rows, name="forward_artifact_merge_rows", columns=12), + _require_cpu_long_rows(forward_output_route_rows, name="forward_output_route_rows", columns=10), + _require_cpu_long_rows( + readout_message_producer_consumer_rows, + name="readout_message_producer_consumer_rows", + columns=16, + ), + _require_cpu_long_rows( + message_transition_producer_consumer_rows, + name="message_transition_producer_consumer_rows", + columns=16, + ), + [tensor.contiguous() for tensor in forward_reset_tensors], + _require_cpu_long_rows(forward_reset_rows, name="forward_reset_rows", columns=4), + _require_cpu_long_rows(primitive_rows, name="primitive_rows", columns=4), + _require_cpu_long_rows(forward_executor_rows, name="forward_executor_rows", columns=6), + _require_cpu_long_rows(reverse_executor_rows, name="reverse_executor_rows", columns=6), + _require_cpu_long_rows(forward_handler_rows, name="forward_handler_rows", columns=11), + _require_cpu_long_rows(reverse_handler_rows, name="reverse_handler_rows", columns=11), + _require_native_strategy_rows(native_strategy_rows), + _require_native_callable_binding_schema_rows(native_callable_binding_schema_rows), + _require_native_callable_output_rows(native_callable_output_rows), + _require_transition_primitive_callable_rows(transition_primitive_callable_rows), + _require_cpu_long_rows(forward_executor_binding_rows, name="forward_executor_binding_rows", columns=8), + _require_cpu_long_rows(reverse_executor_binding_rows, name="reverse_executor_binding_rows", columns=8), + _require_cpu_long_rows(memory_liveness_rows, name="memory_liveness_rows", columns=10), + _require_cpu_long_rows(memory_runtime_schedule_rows, name="memory_runtime_schedule_rows", columns=6), + _require_physical_strategy_rows(physical_strategy_rows), + [tensor.contiguous() for tensor in runtime_buffer_tensors], + _require_cpu_long_rows(runtime_buffer_rows, name="runtime_buffer_rows", columns=10), + [tensor.contiguous() for tensor in forward_program_runtime_tensors], + _require_cpu_long_rows(forward_program_runtime_rows, name="forward_program_runtime_rows", columns=6), + bool(return_final_program_tensors), + bool(return_reverse_artifacts), + int(schema_version), + ) + return tuple(outputs) + + +def registered_temporal_fused_forward_transition_program_cuda( + *, + program_tensors: tuple[torch.Tensor, ...] | list[torch.Tensor], + program_tensor_binding_rows: torch.Tensor, + runtime_buffer_tensors: tuple[torch.Tensor, ...] 
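# A sketch of the normalization the execute wrapper performs before handing arguments to
# the extension: value tensors become contiguous, row tables are asserted to be CPU int64,
# and Python-level options cross the boundary as plain bool/int. The argument names are
# illustrative; nothing here launches the extension.
import torch


def _prepare_launch_args_sketch(
    program_tensors: list[torch.Tensor],
    binding_rows: torch.Tensor,
    return_reverse_artifacts: bool,
    schema_version: int,
) -> tuple[list[torch.Tensor], torch.Tensor, bool, int]:
    if binding_rows.device.type != "cpu" or binding_rows.dtype != torch.long:
        raise RuntimeError("program_tensor_binding_rows must be a CPU int64 row table")
    return (
        [t.contiguous() for t in program_tensors],
        binding_rows.contiguous(),
        bool(return_reverse_artifacts),
        int(schema_version),
    )


if __name__ == "__main__":
    tensors, rows, flag, version = _prepare_launch_args_sketch(
        [torch.randn(2, 3)],
        torch.zeros((1, 4), dtype=torch.long),
        return_reverse_artifacts=False,
        schema_version=1,
    )
    assert rows.dtype == torch.long and flag is False and version == 1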
| list[torch.Tensor], + runtime_buffer_rows: torch.Tensor, + primitive_rows: torch.Tensor, + forward_executor_rows: torch.Tensor, + forward_handler_rows: torch.Tensor, + native_strategy_rows: torch.Tensor, + native_callable_binding_schema_rows: torch.Tensor, + native_callable_output_rows: torch.Tensor, + transition_primitive_callable_rows: torch.Tensor, + forward_executor_binding_rows: torch.Tensor, + memory_liveness_rows: torch.Tensor, + forward_transition_state_carry_rows: torch.Tensor | None = None, + release_dead_input_bindings: bool = False, + allow_terminal_local_state_outputs: bool = False, + schema_version: int = 1, +) -> tuple[torch.Tensor, ...]: + if forward_transition_state_carry_rows is None: + forward_transition_state_carry_rows = torch.empty((0, 3), dtype=torch.long) + outputs = _load_ext().fused_forward_transition_program_execute( + [tensor.contiguous() for tensor in program_tensors], + _require_cpu_long_rows(program_tensor_binding_rows, name="program_tensor_binding_rows", columns=4), + [tensor.contiguous() for tensor in runtime_buffer_tensors], + _require_cpu_long_rows(runtime_buffer_rows, name="runtime_buffer_rows", columns=10), + _require_cpu_long_rows(primitive_rows, name="primitive_rows", columns=4), + _require_cpu_long_rows(forward_executor_rows, name="forward_executor_rows", columns=6), + _require_cpu_long_rows(forward_handler_rows, name="forward_handler_rows", columns=11), + _require_native_strategy_rows(native_strategy_rows), + _require_native_callable_binding_schema_rows(native_callable_binding_schema_rows), + _require_native_callable_output_rows(native_callable_output_rows), + _require_transition_primitive_callable_rows(transition_primitive_callable_rows), + _require_cpu_long_rows(forward_executor_binding_rows, name="forward_executor_binding_rows", columns=8), + _require_cpu_long_rows(memory_liveness_rows, name="memory_liveness_rows", columns=10), + _require_cpu_long_rows( + forward_transition_state_carry_rows, + name="forward_transition_state_carry_rows", + columns=3, + ), + bool(release_dead_input_bindings), + bool(allow_terminal_local_state_outputs), + int(schema_version), + ) + return tuple(outputs) + + +def registered_temporal_fused_reverse_transition_program_cuda( + *, + program_tensors: tuple[torch.Tensor, ...] 
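# A sketch of the "absent row table" convention used just above: an optional table such as
# forward_transition_state_carry_rows defaults to an empty [0, 3] CPU int64 tensor so the
# downstream column-contract check still passes when no carry rows exist.
import torch


def _carry_rows_or_empty_sketch(rows: torch.Tensor | None) -> torch.Tensor:
    if rows is None:
        rows = torch.empty((0, 3), dtype=torch.long)
    assert rows.dtype == torch.long and rows.shape[1] == 3
    return rows.contiguous()


if __name__ == "__main__":
    empty = _carry_rows_or_empty_sketch(None)
    assert empty.shape == (0, 3)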
| list[torch.Tensor], + program_tensor_binding_rows: torch.Tensor, + primitive_rows: torch.Tensor, + reverse_executor_rows: torch.Tensor, + reverse_handler_rows: torch.Tensor, + native_strategy_rows: torch.Tensor, + native_callable_binding_schema_rows: torch.Tensor, + native_callable_output_rows: torch.Tensor, + transition_primitive_callable_rows: torch.Tensor, + reverse_executor_binding_rows: torch.Tensor, + memory_liveness_rows: torch.Tensor, + schema_version: int = 1, + return_state_grads: bool = True, +) -> tuple[torch.Tensor, ...]: + outputs = _load_ext().fused_reverse_transition_program_execute( + [tensor.contiguous() for tensor in program_tensors], + _require_cpu_long_rows(program_tensor_binding_rows, name="program_tensor_binding_rows", columns=4), + _require_cpu_long_rows(primitive_rows, name="primitive_rows", columns=4), + _require_cpu_long_rows(reverse_executor_rows, name="reverse_executor_rows", columns=6), + _require_cpu_long_rows(reverse_handler_rows, name="reverse_handler_rows", columns=11), + _require_native_strategy_rows(native_strategy_rows), + _require_native_callable_binding_schema_rows(native_callable_binding_schema_rows), + _require_native_callable_output_rows(native_callable_output_rows), + _require_transition_primitive_callable_rows(transition_primitive_callable_rows), + _require_cpu_long_rows(reverse_executor_binding_rows, name="reverse_executor_binding_rows", columns=8), + _require_cpu_long_rows(memory_liveness_rows, name="memory_liveness_rows", columns=10), + int(schema_version), + bool(return_state_grads), + ) + return tuple(outputs) + + +def registered_temporal_parameter_reducer_program_cuda( + *, + parameter_reducer_rows: torch.Tensor, + parameter_reducer_strategy_rows: torch.Tensor, + parameter_reducer_trainable_role_rows: torch.Tensor, + parameter_reducer_runtime_metadata_rows: torch.Tensor, + transition_source_rows: torch.Tensor, + transition_trainable_rows: torch.Tensor, + sender_grad_weight_tensors: tuple[torch.Tensor, ...] | list[torch.Tensor], + sender_group_id_tensors: tuple[torch.Tensor, ...] | list[torch.Tensor], + sender_grouped_flags: torch.Tensor, + readout_grad_value_to_output_weight_tensors: tuple[torch.Tensor, ...] | list[torch.Tensor], + readout_grad_output_cell_bias_tensors: tuple[torch.Tensor, ...] | list[torch.Tensor], + recurrent_query_grad_tensors: tuple[torch.Tensor, ...] | list[torch.Tensor], + output_query_grad_tensors: tuple[torch.Tensor, ...] | list[torch.Tensor], + message_strategy_grad_tensors: tuple[torch.Tensor, ...] | list[torch.Tensor], + message_strategy_grad_rows: torch.Tensor, + transition_source_tensors: tuple[torch.Tensor, ...] | list[torch.Tensor], + transition_source_recurrent_cell_idx_tensors: tuple[torch.Tensor, ...] | list[torch.Tensor], + parameter_output_tensors: tuple[torch.Tensor, ...] | list[torch.Tensor], + trainable_param_tensors: tuple[torch.Tensor, ...] | list[torch.Tensor], + runtime_metadata_tensors: tuple[torch.Tensor, ...] 
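# A sketch of the tuple-or-list convention in the parameter-reducer signature above: each
# tensor-group argument may arrive as a tuple or a list, and the wrapper simply walks the
# group and passes every member through contiguous. Group contents below are placeholders.
import torch


def _contiguous_group_sketch(group: tuple[torch.Tensor, ...] | list[torch.Tensor]) -> list[torch.Tensor]:
    return [tensor.contiguous() for tensor in group]


if __name__ == "__main__":
    as_tuple = _contiguous_group_sketch((torch.randn(2, 3).t(),))
    as_list = _contiguous_group_sketch([torch.randn(4, 4)])
    assert all(t.is_contiguous() for t in as_tuple + as_list)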
| list[torch.Tensor], + coord_count: int, + head_dim: int, + value_dim: int, + schema_version: int = 1, +) -> tuple[torch.Tensor, ...]: + outputs = _load_ext().parameter_reducer_program_execute( + _require_cpu_long_rows(parameter_reducer_rows, name="parameter_reducer_rows", columns=8), + _require_cpu_long_rows( + parameter_reducer_strategy_rows, + name="parameter_reducer_strategy_rows", + columns=9, + ), + _require_cpu_long_rows( + parameter_reducer_trainable_role_rows, + name="parameter_reducer_trainable_role_rows", + columns=6, + ), + _require_cpu_long_rows( + parameter_reducer_runtime_metadata_rows, + name="parameter_reducer_runtime_metadata_rows", + columns=4, + ), + _require_cpu_long_rows(transition_source_rows, name="transition_source_rows", columns=8), + _require_cpu_long_rows(transition_trainable_rows, name="transition_trainable_rows", columns=9), + [tensor.contiguous() for tensor in sender_grad_weight_tensors], + [tensor.contiguous() for tensor in sender_group_id_tensors], + sender_grouped_flags.to(device="cpu", dtype=torch.long).contiguous(), + [tensor.contiguous() for tensor in readout_grad_value_to_output_weight_tensors], + [tensor.contiguous() for tensor in readout_grad_output_cell_bias_tensors], + [tensor.contiguous() for tensor in recurrent_query_grad_tensors], + [tensor.contiguous() for tensor in output_query_grad_tensors], + [tensor.contiguous() for tensor in message_strategy_grad_tensors], + _require_cpu_long_rows(message_strategy_grad_rows, name="message_strategy_grad_rows", columns=5), + [tensor.contiguous() for tensor in transition_source_tensors], + [ + tensor.to(device=tensor.device, dtype=torch.long).contiguous() + for tensor in transition_source_recurrent_cell_idx_tensors + ], + [tensor.contiguous() for tensor in parameter_output_tensors], + [tensor.contiguous() for tensor in trainable_param_tensors], + [tensor.contiguous() for tensor in runtime_metadata_tensors], + int(coord_count), + int(head_dim), + int(value_dim), + int(schema_version), + ) + return tuple(outputs) + + +def registered_temporal_fused_backward_program_cuda( + *, + reverse_program_stage_rows: torch.Tensor, + grad_output_window: torch.Tensor, + grad_carry_cells: torch.Tensor | None, + reverse_artifact_tensors: tuple[torch.Tensor, ...] | list[torch.Tensor], + reverse_artifact_binding_rows: torch.Tensor, + reverse_artifact_role_rows: torch.Tensor, + reverse_artifact_access_rows: torch.Tensor, + forward_artifact_route_rows: torch.Tensor, + forward_artifact_merge_rows: torch.Tensor, + forward_output_route_rows: torch.Tensor, + reverse_artifact_consumer_route_rows: torch.Tensor, + reverse_reset_tensor_groups: tuple[tuple[torch.Tensor, ...], ...] | list[tuple[torch.Tensor, ...]], + reverse_reset_row_groups: tuple[torch.Tensor, ...] | list[torch.Tensor], + program_tensors: tuple[torch.Tensor, ...] 
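# A sketch of the two normalizations applied in the parameter-reducer call above: grouped
# flags are forced onto CPU as int64, while per-population index tensors are only cast to
# int64 on whatever device they already live on. Tensors here are illustrative.
import torch


def _normalize_reducer_inputs_sketch(
    grouped_flags: torch.Tensor,
    index_tensors: list[torch.Tensor],
) -> tuple[torch.Tensor, list[torch.Tensor]]:
    flags = grouped_flags.to(device="cpu", dtype=torch.long).contiguous()
    indices = [t.to(dtype=torch.long).contiguous() for t in index_tensors]
    return flags, indices


if __name__ == "__main__":
    flags, indices = _normalize_reducer_inputs_sketch(
        torch.tensor([True, False]),
        [torch.tensor([0, 2, 1], dtype=torch.int32)],
    )
    assert flags.dtype == torch.long and flags.device.type == "cpu"
    assert indices[0].dtype == torch.long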
| list[torch.Tensor], + program_tensor_binding_rows: torch.Tensor, + reverse_program_access_rows: torch.Tensor, + primitive_rows: torch.Tensor, + forward_executor_rows: torch.Tensor, + reverse_executor_rows: torch.Tensor, + forward_handler_rows: torch.Tensor, + reverse_handler_rows: torch.Tensor, + native_strategy_rows: torch.Tensor, + native_callable_binding_schema_rows: torch.Tensor, + native_callable_output_rows: torch.Tensor, + reverse_span_output_rows: torch.Tensor, + transition_reverse_seed_role_rows: torch.Tensor, + transition_primitive_callable_rows: torch.Tensor, + forward_executor_binding_rows: torch.Tensor, + reverse_executor_binding_rows: torch.Tensor, + memory_liveness_rows: torch.Tensor, + memory_runtime_schedule_rows: torch.Tensor, + physical_strategy_rows: torch.Tensor, + runtime_buffer_tensors: tuple[torch.Tensor, ...] | list[torch.Tensor], + runtime_buffer_rows: torch.Tensor, + reverse_program_runtime_tensors: tuple[torch.Tensor, ...] | list[torch.Tensor], + reverse_program_runtime_rows: torch.Tensor, + transition_program_tensor_groups: tuple[tuple[torch.Tensor, ...], ...] | list[tuple[torch.Tensor, ...]], + transition_program_tensor_binding_row_groups: tuple[torch.Tensor, ...] | list[torch.Tensor], + transition_forward_executor_row_groups: tuple[torch.Tensor, ...] | list[torch.Tensor], + transition_reverse_executor_row_groups: tuple[torch.Tensor, ...] | list[torch.Tensor], + transition_forward_executor_binding_row_groups: tuple[torch.Tensor, ...] | list[torch.Tensor], + transition_reverse_executor_binding_row_groups: tuple[torch.Tensor, ...] | list[torch.Tensor], + transition_memory_liveness_row_groups: tuple[torch.Tensor, ...] | list[torch.Tensor], + transition_seed_tensor_groups: tuple[tuple[torch.Tensor, ...], ...] | list[tuple[torch.Tensor, ...]], + transition_seed_row_groups: tuple[torch.Tensor, ...] | list[torch.Tensor], + transition_dynamic_binding_row_groups: tuple[torch.Tensor, ...] | list[torch.Tensor], + transition_output_keep_slot_row_groups: tuple[torch.Tensor, ...] | list[torch.Tensor], + transition_parameter_tensors: tuple[torch.Tensor, ...] 
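# A sketch of the per-population grouping used by the fused backward arguments above: each
# entry in a *_row_groups list is its own row table and must satisfy the same column
# contract as the ungrouped tables. The column count below mirrors the binding-row tables
# (4 columns); the group contents are placeholders.
import torch


def _check_row_groups_sketch(groups: list[torch.Tensor], *, columns: int) -> list[torch.Tensor]:
    checked = []
    for i, rows in enumerate(groups):
        if rows.device.type != "cpu" or rows.dtype != torch.long or rows.dim() != 2 or rows.shape[1] != columns:
            raise RuntimeError(f"group {i} must be a CPU int64 table with {columns} columns")
        checked.append(rows.contiguous())
    return checked


if __name__ == "__main__":
    groups = [torch.zeros((2, 4), dtype=torch.long), torch.zeros((0, 4), dtype=torch.long)]
    assert len(_check_row_groups_sketch(groups, columns=4)) == 2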
| list[torch.Tensor], + transition_parameter_rows: torch.Tensor, + transition_recurrent_msg_output_rows: torch.Tensor, + transition_public_y_seed_rows: torch.Tensor, + transition_state_reset_rows: torch.Tensor, + transition_next_seed_output_rows: torch.Tensor, + return_window_start_transition_state_grads: bool = True, + schema_version: int = 1, +) -> tuple[tuple[torch.Tensor, ...], ...]: + global _LAST_FUSED_BACKWARD_PROGRAM_STAGE_MEMORY_ROWS + _LAST_FUSED_BACKWARD_PROGRAM_STAGE_MEMORY_ROWS = None + _require_reverse_program_stage(reverse_program_stage_rows, stage_name="transition_step") + _require_reverse_program_stage( + reverse_program_stage_rows, + stage_name="recurrent_message_boundary_initial_kv_step", + ) + carry_tensor = grad_output_window.new_empty(0) if grad_carry_cells is None else grad_carry_cells.contiguous() + outputs = _load_ext().fused_backward_program_execute( + grad_output_window.contiguous(), + carry_tensor, + _require_cpu_long_rows(reverse_program_stage_rows, name="reverse_program_stage_rows", columns=10), + [tensor.contiguous() for tensor in reverse_artifact_tensors], + _require_cpu_long_rows(reverse_artifact_binding_rows, name="reverse_artifact_binding_rows", columns=5), + _require_cpu_long_rows(reverse_artifact_role_rows, name="reverse_artifact_role_rows", columns=3), + _require_cpu_long_rows(reverse_artifact_access_rows, name="reverse_artifact_access_rows", columns=3), + _require_cpu_long_rows(forward_artifact_route_rows, name="forward_artifact_route_rows", columns=10), + _require_cpu_long_rows(forward_artifact_merge_rows, name="forward_artifact_merge_rows", columns=12), + _require_cpu_long_rows(forward_output_route_rows, name="forward_output_route_rows", columns=10), + _require_cpu_long_rows( + reverse_artifact_consumer_route_rows, + name="reverse_artifact_consumer_route_rows", + columns=12, + ), + [ + [tensor.to(device=tensor.device, dtype=torch.bool).contiguous() for tensor in group] + for group in reverse_reset_tensor_groups + ], + [_require_cpu_long_rows(rows, name="reverse_reset_rows", columns=4) for rows in reverse_reset_row_groups], + [tensor.contiguous() for tensor in program_tensors], + _require_cpu_long_rows(program_tensor_binding_rows, name="program_tensor_binding_rows", columns=4), + _require_cpu_long_rows(reverse_program_access_rows, name="reverse_program_access_rows", columns=6), + _require_cpu_long_rows(primitive_rows, name="primitive_rows", columns=4), + _require_cpu_long_rows(forward_executor_rows, name="forward_executor_rows", columns=6), + _require_cpu_long_rows(reverse_executor_rows, name="reverse_executor_rows", columns=6), + _require_cpu_long_rows(forward_handler_rows, name="forward_handler_rows", columns=11), + _require_cpu_long_rows(reverse_handler_rows, name="reverse_handler_rows", columns=11), + _require_native_strategy_rows(native_strategy_rows), + _require_native_callable_binding_schema_rows(native_callable_binding_schema_rows), + _require_native_callable_output_rows(native_callable_output_rows), + _require_reverse_span_output_rows(reverse_span_output_rows), + _require_cpu_long_rows( + transition_reverse_seed_role_rows, + name="transition_reverse_seed_role_rows", + columns=4, + ), + _require_transition_primitive_callable_rows(transition_primitive_callable_rows), + _require_cpu_long_rows(forward_executor_binding_rows, name="forward_executor_binding_rows", columns=8), + _require_cpu_long_rows(reverse_executor_binding_rows, name="reverse_executor_binding_rows", columns=8), + _require_cpu_long_rows(memory_liveness_rows, 
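# A sketch of the optional-carry convention used above: when grad_carry_cells is None the
# wrapper passes grad_output_window.new_empty(0), i.e. a zero-element tensor that inherits
# the window's dtype and device and acts as the "no carry" sentinel for the extension.
import torch


def _carry_or_sentinel_sketch(grad_output_window: torch.Tensor, grad_carry_cells: torch.Tensor | None) -> torch.Tensor:
    return grad_output_window.new_empty(0) if grad_carry_cells is None else grad_carry_cells.contiguous()


if __name__ == "__main__":
    window = torch.zeros(2, 3, 4, 5)
    sentinel = _carry_or_sentinel_sketch(window, None)
    assert sentinel.numel() == 0 and sentinel.dtype == window.dtype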
name="memory_liveness_rows", columns=10), + _require_cpu_long_rows(memory_runtime_schedule_rows, name="memory_runtime_schedule_rows", columns=6), + _require_physical_strategy_rows(physical_strategy_rows), + [tensor.contiguous() for tensor in runtime_buffer_tensors], + _require_cpu_long_rows(runtime_buffer_rows, name="runtime_buffer_rows", columns=10), + [tensor.contiguous() for tensor in reverse_program_runtime_tensors], + _require_cpu_long_rows(reverse_program_runtime_rows, name="reverse_program_runtime_rows", columns=6), + [[tensor.contiguous() for tensor in group] for group in transition_program_tensor_groups], + [ + _require_cpu_long_rows(rows, name="transition_program_tensor_binding_rows", columns=4) + for rows in transition_program_tensor_binding_row_groups + ], + [ + _require_cpu_long_rows(rows, name="transition_forward_executor_rows", columns=6) + for rows in transition_forward_executor_row_groups + ], + [ + _require_cpu_long_rows(rows, name="transition_reverse_executor_rows", columns=6) + for rows in transition_reverse_executor_row_groups + ], + [ + _require_cpu_long_rows(rows, name="transition_forward_executor_binding_rows", columns=8) + for rows in transition_forward_executor_binding_row_groups + ], + [ + _require_cpu_long_rows(rows, name="transition_reverse_executor_binding_rows", columns=8) + for rows in transition_reverse_executor_binding_row_groups + ], + [ + _require_cpu_long_rows(rows, name="transition_memory_liveness_rows", columns=10) + for rows in transition_memory_liveness_row_groups + ], + [[tensor.contiguous() for tensor in group] for group in transition_seed_tensor_groups], + [_require_cpu_long_rows(rows, name="transition_seed_rows", columns=3) for rows in transition_seed_row_groups], + [ + _require_cpu_long_rows(rows, name="transition_dynamic_binding_rows", columns=5) + for rows in transition_dynamic_binding_row_groups + ], + [ + _require_cpu_long_rows(rows, name="transition_output_keep_slot_rows", columns=1) + for rows in transition_output_keep_slot_row_groups + ], + [tensor.contiguous() for tensor in transition_parameter_tensors], + _require_cpu_long_rows(transition_parameter_rows, name="transition_parameter_rows", columns=3), + _require_cpu_long_rows( + transition_recurrent_msg_output_rows, + name="transition_recurrent_msg_output_rows", + columns=4, + ), + _require_cpu_long_rows(transition_public_y_seed_rows, name="transition_public_y_seed_rows", columns=4), + _require_cpu_long_rows(transition_state_reset_rows, name="transition_state_reset_rows", columns=2), + _require_cpu_long_rows(transition_next_seed_output_rows, name="transition_next_seed_output_rows", columns=4), + bool(return_window_start_transition_state_grads), + int(schema_version), + ) + grouped_outputs = tuple(tuple(group) for group in outputs) + if not grouped_outputs: + raise RuntimeError("Fused backward program returned no output groups") + last_group = grouped_outputs[-1] + if len(last_group) == 1 and _is_fused_backward_program_stage_memory_rows(last_group[0]): + _LAST_FUSED_BACKWARD_PROGRAM_STAGE_MEMORY_ROWS = last_group[0].contiguous() + grouped_outputs = grouped_outputs[:-1] + if not grouped_outputs: + raise RuntimeError("Fused backward program returned no semantic output groups") + return grouped_outputs + + +__all__ = [ + "registered_temporal_fused_backward_program_cuda", + "registered_temporal_fused_backward_program_stage_memory_rows", + "registered_temporal_fused_backward_program_validate_cuda", + "registered_temporal_fused_forward_transition_program_cuda", + 
"registered_temporal_fused_forward_program_cuda", + "registered_temporal_fused_forward_program_validate_cuda", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_kernels.cu b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_kernels.cu new file mode 100644 index 00000000..33a48df5 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_kernels.cu @@ -0,0 +1,12 @@ +// Semantic split for the registered temporal CUDA program. +#include "registered_program/common.cuh" +#include "registered_program/operator_declarations.cuh" +#include "registered_program/transition_forward_program.cuh" +#include "registered_program/forward_program.cuh" +#include "registered_program/backward_surface_steps.cuh" +#include "registered_program/transition_primitive_forward_ops.cuh" +#include "registered_program/transition_reverse_handlers.cuh" +#include "registered_program/transition_reverse_program.cuh" +#include "registered_program/parameter_reducer_program.cuh" +#include "registered_program/backward_program.cuh" +#include "registered_program/operator_exports.cuh" diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_buckets.py b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_buckets.py new file mode 100644 index 00000000..327587e0 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_buckets.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + +import torch +from tensordict import TensorDict + +from cortical.fabric.backend.cuda.transition_execution.types import TransitionInputProjectionParamGradStep + + +@dataclass(frozen=True) +class BackendOrderTransitionPopulationParamGrads: + materialized_param_grads: dict[str, torch.Tensor] + static_source_grads: dict[str, torch.Tensor] + input_projection_param_grad_steps: tuple[ + TransitionInputProjectionParamGradStep, + ..., + ] = () + + +@dataclass(frozen=True) +class BackendOrderTransitionParamGrads: + by_population: dict[str, BackendOrderTransitionPopulationParamGrads] + + +def _population_grad_state_to_backend_grad_state( + runtime: Any, + population_name: str, + population_grad_state: Mapping[str, torch.Tensor | None] | None, +) -> dict[str, torch.Tensor | None] | None: + if population_grad_state is None: + return None + state_names = tuple(runtime._compiled_transition_state_names_for_population(population_name)) + return { + state_name: (grad.permute(1, 0, 2).contiguous() if torch.is_tensor(grad) else None) + for state_name in state_names + for grad in (population_grad_state.get(state_name),) + } + + +def _partial_backend_grad_state_to_population_state(backend_state: Mapping[str, object]) -> TensorDict: + leaves: dict[str, torch.Tensor] = {} + batch_size: list[int] | None = None + for state_name, grad in backend_state.items(): + if not torch.is_tensor(grad) or grad.dim() < 3: + continue + leaves[state_name] = grad.permute(1, 0, 2) + if batch_size is None: + batch_size = [int(grad.shape[1]), int(grad.shape[0])] + return TensorDict(leaves, batch_size=batch_size or []) + + +__all__ = [ + "BackendOrderTransitionParamGrads", + "BackendOrderTransitionPopulationParamGrads", + "_partial_backend_grad_state_to_population_state", + "_population_grad_state_to_backend_grad_state", +] diff --git 
a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/artifact_routes.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/artifact_routes.cuh new file mode 100644 index 00000000..073988dd --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/artifact_routes.cuh @@ -0,0 +1,470 @@ +#pragma once + +inline void validate_forward_artifact_route_rows( + const at::Tensor& forward_artifact_route_rows, + int64_t schema_version) { + check_cpu_long_rank2( + forward_artifact_route_rows, + "forward_artifact_route_rows", + kForwardArtifactRouteRowColumns); + const int64_t* rows = forward_artifact_route_rows.data_ptr(); + bool saw_message_artifact = false; + bool saw_readout_artifact = false; + bool saw_transition_artifact = false; + for (int64_t row_index = 0; row_index < forward_artifact_route_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kForwardArtifactRouteRowColumns; + TORCH_CHECK(row[0] == row_index, "forward_artifact_route_rows must be densely indexed"); + TORCH_CHECK(row[1] > 0, "forward_artifact_route_rows row ", row_index, " has invalid surface opcode"); + TORCH_CHECK(row[2] >= -1, "forward_artifact_route_rows row ", row_index, " has invalid executor row index"); + TORCH_CHECK(row[3] >= 0, "forward_artifact_route_rows row ", row_index, " has invalid executor id"); + TORCH_CHECK(row[5] > 0, "forward_artifact_route_rows row ", row_index, " has invalid artifact role"); + TORCH_CHECK(row[6] > 0, "forward_artifact_route_rows row ", row_index, " has invalid logical route id"); + TORCH_CHECK( + row[7] == 0 || row[7] == 1, + "forward_artifact_route_rows row ", + row_index, + " has invalid required flag"); + TORCH_CHECK( + row[8] == schema_version, + "forward_artifact_route_rows row ", + row_index, + " has unsupported schema version"); + saw_message_artifact = saw_message_artifact || row[1] == kMessageSurfaceOpcode; + saw_readout_artifact = saw_readout_artifact || row[1] == kReadoutSurfaceOpcode; + saw_transition_artifact = saw_transition_artifact || row[1] == kTransitionSurfaceOpcode; + } + TORCH_CHECK(saw_message_artifact, "forward_artifact_route_rows are missing message artifact ownership"); + TORCH_CHECK(saw_readout_artifact, "forward_artifact_route_rows are missing readout artifact ownership"); + TORCH_CHECK(saw_transition_artifact, "forward_artifact_route_rows are missing transition artifact ownership"); +} + +inline int64_t forward_artifact_route_row_for( + const at::Tensor& forward_artifact_route_rows, + int64_t surface_opcode, + int64_t executor_row_index, + int64_t executor_id, + int64_t bucket_ordinal, + int64_t role_id, + int64_t schema_version, + const char* subject) { + validate_forward_artifact_route_rows(forward_artifact_route_rows, schema_version); + const int64_t* rows = forward_artifact_route_rows.data_ptr(); + int64_t selected = -1; + for (int64_t row_index = 0; row_index < forward_artifact_route_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kForwardArtifactRouteRowColumns; + if ( + row[1] != surface_opcode || + row[2] != executor_row_index || + row[3] != executor_id || + row[4] != bucket_ordinal || + row[5] != role_id) { + continue; + } + TORCH_CHECK(selected < 0, subject, " has duplicate compiler artifact route rows"); + selected = row[0]; + } + TORCH_CHECK( + selected >= 0, + subject, + " has no compiler artifact route row for surface=", + surface_opcode, + "; executor_row=", + executor_row_index, + "; executor_id=", + 
executor_id, + "; bucket=", + bucket_ordinal, + "; role=", + role_id); + return selected; +} + +inline int64_t forward_artifact_route_row_for_surface_bucket_role( + const at::Tensor& forward_artifact_route_rows, + int64_t surface_opcode, + int64_t bucket_ordinal, + int64_t role_id, + int64_t schema_version, + const char* subject) { + validate_forward_artifact_route_rows(forward_artifact_route_rows, schema_version); + const int64_t* rows = forward_artifact_route_rows.data_ptr(); + int64_t selected = -1; + for (int64_t row_index = 0; row_index < forward_artifact_route_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kForwardArtifactRouteRowColumns; + if (row[1] != surface_opcode || row[4] != bucket_ordinal || row[5] != role_id) { + continue; + } + TORCH_CHECK( + selected < 0, + subject, + " has multiple compiler artifact producer routes for surface=", + surface_opcode, + "; bucket=", + bucket_ordinal, + "; role=", + role_id); + selected = row[0]; + } + TORCH_CHECK( + selected >= 0, + subject, + " has no compiler artifact producer route for surface=", + surface_opcode, + "; bucket=", + bucket_ordinal, + "; role=", + role_id); + return selected; +} + +inline void validate_forward_artifact_merge_rows( + const at::Tensor& forward_artifact_merge_rows, + const at::Tensor& forward_artifact_route_rows, + int64_t schema_version) { + validate_forward_artifact_route_rows(forward_artifact_route_rows, schema_version); + check_cpu_long_rank2( + forward_artifact_merge_rows, + "forward_artifact_merge_rows", + kForwardArtifactMergeRowColumns); + const int64_t* merge_rows = forward_artifact_merge_rows.data_ptr(); + const int64_t* route_rows = forward_artifact_route_rows.data_ptr(); + bool saw_message_artifact = false; + bool saw_readout_artifact = false; + bool saw_transition_artifact = false; + for (int64_t row_index = 0; row_index < forward_artifact_merge_rows.size(0); ++row_index) { + const int64_t* row = merge_rows + row_index * kForwardArtifactMergeRowColumns; + const int64_t surface_opcode = row[1]; + const int64_t role_id = row[3]; + const int64_t merge_kind = row[4]; + const int64_t producer_route_row = row[6]; + TORCH_CHECK(row[0] == row_index, "forward_artifact_merge_rows must be densely indexed"); + TORCH_CHECK(surface_opcode > 0, "forward_artifact_merge_rows row ", row_index, " has invalid surface opcode"); + TORCH_CHECK(role_id > 0, "forward_artifact_merge_rows row ", row_index, " has invalid artifact role"); + TORCH_CHECK(row[5] > 0, "forward_artifact_merge_rows row ", row_index, " has invalid output route id"); + TORCH_CHECK( + merge_kind == kForwardArtifactMergeIdentitySingleton || + merge_kind == kForwardArtifactMergeConcatOrError || + merge_kind == kForwardArtifactMergeSumOrError, + "forward_artifact_merge_rows row ", + row_index, + " has unsupported merge kind"); + TORCH_CHECK( + row[9] == 0 || row[9] == 1, + "forward_artifact_merge_rows row ", + row_index, + " has invalid required flag"); + TORCH_CHECK( + row[10] == schema_version, + "forward_artifact_merge_rows row ", + row_index, + " has unsupported schema version"); + if (merge_kind == kForwardArtifactMergeIdentitySingleton) { + TORCH_CHECK( + producer_route_row >= 0 && producer_route_row < forward_artifact_route_rows.size(0), + "forward_artifact_merge_rows row ", + row_index, + " has invalid producer route row"); + const int64_t* route_row = route_rows + producer_route_row * kForwardArtifactRouteRowColumns; + TORCH_CHECK( + route_row[1] == surface_opcode && route_row[4] == row[2] && route_row[5] == role_id, + 
"forward_artifact_merge_rows row ", + row_index, + " producer route does not match surface/bucket/role"); + TORCH_CHECK( + route_row[2] == row[7] && route_row[3] == row[8], + "forward_artifact_merge_rows row ", + row_index, + " producer executor fields do not match route row"); + } else { + TORCH_CHECK( + producer_route_row < 0 && row[7] < 0 && row[8] < 0, + "forward_artifact_merge_rows row ", + row_index, + " aggregate merge rows must not name a singleton producer"); + } + saw_message_artifact = saw_message_artifact || surface_opcode == kMessageSurfaceOpcode; + saw_readout_artifact = saw_readout_artifact || surface_opcode == kReadoutSurfaceOpcode; + saw_transition_artifact = saw_transition_artifact || surface_opcode == kTransitionSurfaceOpcode; + } + TORCH_CHECK(saw_message_artifact, "forward_artifact_merge_rows are missing message artifact ownership"); + TORCH_CHECK(saw_readout_artifact, "forward_artifact_merge_rows are missing readout artifact ownership"); + TORCH_CHECK(saw_transition_artifact, "forward_artifact_merge_rows are missing transition artifact ownership"); +} + +inline int64_t forward_artifact_merge_row_for_surface_bucket_role( + const at::Tensor& forward_artifact_route_rows, + const at::Tensor& forward_artifact_merge_rows, + int64_t surface_opcode, + int64_t bucket_ordinal, + int64_t role_id, + int64_t schema_version, + const char* subject) { + validate_forward_artifact_merge_rows(forward_artifact_merge_rows, forward_artifact_route_rows, schema_version); + const int64_t* rows = forward_artifact_merge_rows.data_ptr(); + int64_t selected_row = -1; + for (int64_t row_index = 0; row_index < forward_artifact_merge_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kForwardArtifactMergeRowColumns; + if (row[1] != surface_opcode || row[2] != bucket_ordinal || row[3] != role_id) { + continue; + } + TORCH_CHECK( + selected_row < 0, + subject, + " has duplicate compiler artifact merge rows for surface=", + surface_opcode, + "; bucket=", + bucket_ordinal, + "; role=", + role_id); + selected_row = row_index; + } + TORCH_CHECK( + selected_row >= 0, + subject, + " has no compiler artifact merge row for surface=", + surface_opcode, + "; bucket=", + bucket_ordinal, + "; role=", + role_id); + return selected_row; +} + +inline void validate_forward_output_route_rows( + const at::Tensor& forward_output_route_rows, + int64_t schema_version) { + check_cpu_long_rank2( + forward_output_route_rows, + "forward_output_route_rows", + kForwardOutputRouteRowColumns); + const int64_t* rows = forward_output_route_rows.data_ptr(); + for (int64_t row_index = 0; row_index < forward_output_route_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kForwardOutputRouteRowColumns; + TORCH_CHECK(row[0] == row_index, "forward_output_route_rows must be densely indexed"); + TORCH_CHECK( + row[1] == kForwardOutputRouteReadoutOutputCells || + row[1] == kForwardOutputRouteReadoutOutputSelect || + row[1] == kForwardOutputRouteReadoutOutputConcat || + row[1] == kForwardOutputRouteReadoutOutputSum, + "forward_output_route_rows row ", + row_index, + " has unsupported route kind"); + TORCH_CHECK( + row[2] == kReadoutSurfaceOpcode, + "forward_output_route_rows row ", + row_index, + " must be readout-owned"); + TORCH_CHECK(row[3] >= 0, "forward_output_route_rows row ", row_index, " has invalid executor row"); + TORCH_CHECK(row[4] > 0, "forward_output_route_rows row ", row_index, " has invalid executor id"); + TORCH_CHECK( + row[5] >= 0 || row[5] == kTemporalReadoutBucketOrdinal, + 
"forward_output_route_rows row ", + row_index, + " has invalid bucket ordinal"); + TORCH_CHECK(row[6] > 0, "forward_output_route_rows row ", row_index, " has invalid output role"); + TORCH_CHECK( + row[7] == 0 || row[7] == 1, + "forward_output_route_rows row ", + row_index, + " has invalid required flag"); + TORCH_CHECK( + row[8] == schema_version, + "forward_output_route_rows row ", + row_index, + " has unsupported schema version"); + TORCH_CHECK(row[9] >= 0, "forward_output_route_rows row ", row_index, " has invalid output offset"); + } + TORCH_CHECK( + forward_output_route_rows.size(0) > 0, + "forward_output_route_rows must contain at least one executable output route"); + if (forward_output_route_rows.size(0) > 1) { + const int64_t route_kind = rows[1]; + TORCH_CHECK( + route_kind == kForwardOutputRouteReadoutOutputConcat || + route_kind == kForwardOutputRouteReadoutOutputSum, + "forward_output_route_rows with multiple producers must use concat or sum output route semantics"); + for (int64_t row_index = 1; row_index < forward_output_route_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kForwardOutputRouteRowColumns; + TORCH_CHECK( + row[1] == route_kind, + "forward_output_route_rows with multiple producers must use one route kind"); + } + } +} + +inline int64_t forward_output_route_row_for_readout_executor( + const at::Tensor& forward_output_route_rows, + int64_t executor_row_index, + int64_t executor_id, + int64_t bucket_ordinal, + int64_t schema_version, + const char* subject) { + validate_forward_output_route_rows(forward_output_route_rows, schema_version); + const int64_t* rows = forward_output_route_rows.data_ptr(); + int64_t selected = -1; + for (int64_t row_index = 0; row_index < forward_output_route_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kForwardOutputRouteRowColumns; + if (row[3] != executor_row_index || row[4] != executor_id || row[5] != bucket_ordinal) { + continue; + } + TORCH_CHECK(selected < 0, subject, " has duplicate compiler output route rows for readout executor"); + selected = row_index; + } + TORCH_CHECK( + selected >= 0, + subject, + " has no compiler output route row for readout executor: executor_row=", + executor_row_index, + "; executor_id=", + executor_id, + "; bucket=", + bucket_ordinal); + return selected; +} + +inline void validate_reverse_artifact_consumer_route_rows( + const at::Tensor& reverse_artifact_consumer_route_rows, + const at::Tensor& forward_artifact_route_rows, + int64_t schema_version) { + validate_forward_artifact_route_rows(forward_artifact_route_rows, schema_version); + check_cpu_long_rank2( + reverse_artifact_consumer_route_rows, + "reverse_artifact_consumer_route_rows", + kReverseArtifactConsumerRouteRowColumns); + const int64_t* consumer_rows = reverse_artifact_consumer_route_rows.data_ptr(); + const int64_t* forward_rows = forward_artifact_route_rows.data_ptr(); + for (int64_t row_index = 0; row_index < reverse_artifact_consumer_route_rows.size(0); ++row_index) { + const int64_t* row = consumer_rows + row_index * kReverseArtifactConsumerRouteRowColumns; + const int64_t surface_opcode = row[1]; + const int64_t reverse_executor_row = row[2]; + const int64_t reverse_executor_id = row[3]; + const int64_t bucket_ordinal = row[4]; + const int64_t role_id = row[5]; + const int64_t forward_route_row = row[6]; + TORCH_CHECK(row[0] == row_index, "reverse_artifact_consumer_route_rows must be densely indexed"); + TORCH_CHECK( + surface_opcode == kMessageSurfaceOpcode || + surface_opcode == 
kReadoutSurfaceOpcode || + surface_opcode == kTransitionSurfaceOpcode, + "reverse_artifact_consumer_route_rows row ", + row_index, + " has unsupported surface opcode"); + TORCH_CHECK( + reverse_executor_row >= 0, + "reverse_artifact_consumer_route_rows row ", + row_index, + " has invalid reverse executor row"); + TORCH_CHECK( + reverse_executor_id > 0, + "reverse_artifact_consumer_route_rows row ", + row_index, + " has invalid reverse executor id"); + TORCH_CHECK( + bucket_ordinal >= 0 || + bucket_ordinal == kTemporalMessageBucketOrdinal || + bucket_ordinal == kTemporalReadoutBucketOrdinal, + "reverse_artifact_consumer_route_rows row ", + row_index, + " has invalid bucket ordinal"); + TORCH_CHECK(role_id > 0, "reverse_artifact_consumer_route_rows row ", row_index, " has invalid artifact role"); + TORCH_CHECK( + forward_route_row >= 0 && forward_route_row < forward_artifact_route_rows.size(0), + "reverse_artifact_consumer_route_rows row ", + row_index, + " has invalid forward artifact route row"); + const int64_t* forward_row = forward_rows + forward_route_row * kForwardArtifactRouteRowColumns; + TORCH_CHECK( + forward_row[1] == surface_opcode && forward_row[4] == bucket_ordinal && forward_row[5] == role_id, + "reverse_artifact_consumer_route_rows row ", + row_index, + " forward route does not match surface/bucket/role"); + TORCH_CHECK( + forward_row[2] == row[7] && forward_row[3] == row[8], + "reverse_artifact_consumer_route_rows row ", + row_index, + " forward executor fields do not match route row"); + TORCH_CHECK( + row[9] == 0 || row[9] == 1, + "reverse_artifact_consumer_route_rows row ", + row_index, + " has invalid required flag"); + TORCH_CHECK( + row[10] == schema_version, + "reverse_artifact_consumer_route_rows row ", + row_index, + " has unsupported schema version"); + } +} + +inline int64_t try_reverse_artifact_consumer_forward_route_row_for( + const at::Tensor& reverse_artifact_consumer_route_rows, + int64_t surface_opcode, + int64_t reverse_executor_row_index, + int64_t reverse_executor_id, + int64_t bucket_ordinal, + int64_t role_id, + const char* subject); + +inline int64_t reverse_artifact_consumer_forward_route_row_for( + const at::Tensor& reverse_artifact_consumer_route_rows, + const at::Tensor& forward_artifact_route_rows, + int64_t surface_opcode, + int64_t reverse_executor_row_index, + int64_t reverse_executor_id, + int64_t bucket_ordinal, + int64_t role_id, + int64_t schema_version, + const char* subject) { + validate_reverse_artifact_consumer_route_rows( + reverse_artifact_consumer_route_rows, + forward_artifact_route_rows, + schema_version); + const int64_t selected_forward_route_row = try_reverse_artifact_consumer_forward_route_row_for( + reverse_artifact_consumer_route_rows, + surface_opcode, + reverse_executor_row_index, + reverse_executor_id, + bucket_ordinal, + role_id, + subject); + TORCH_CHECK( + selected_forward_route_row >= 0, + subject, + " has no reverse artifact consumer route row for surface=", + surface_opcode, + "; reverse_executor_row=", + reverse_executor_row_index, + "; reverse_executor_id=", + reverse_executor_id, + "; bucket=", + bucket_ordinal, + "; role=", + role_id); + return selected_forward_route_row; +} + +inline int64_t try_reverse_artifact_consumer_forward_route_row_for( + const at::Tensor& reverse_artifact_consumer_route_rows, + int64_t surface_opcode, + int64_t reverse_executor_row_index, + int64_t reverse_executor_id, + int64_t bucket_ordinal, + int64_t role_id, + const char* subject) { + const int64_t* rows = 
reverse_artifact_consumer_route_rows.data_ptr(); + int64_t selected_forward_route_row = -1; + for (int64_t row_index = 0; row_index < reverse_artifact_consumer_route_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kReverseArtifactConsumerRouteRowColumns; + if ( + row[1] != surface_opcode || + row[2] != reverse_executor_row_index || + row[3] != reverse_executor_id || + row[4] != bucket_ordinal || + row[5] != role_id) { + continue; + } + TORCH_CHECK(selected_forward_route_row < 0, subject, " has duplicate reverse artifact consumer route rows"); + selected_forward_route_row = row[6]; + } + return selected_forward_route_row; +} diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/backward_program.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/backward_program.cuh new file mode 100644 index 00000000..1e04663a --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/backward_program.cuh @@ -0,0 +1,1646 @@ +#pragma once + +#include +#include + +constexpr int64_t kRegisteredBackwardMemoryStageEntry = 1; +constexpr int64_t kRegisteredBackwardMemoryStageAfterGradCellsSeed = 2; +constexpr int64_t kRegisteredBackwardMemoryStageAfterReadout = 3; +constexpr int64_t kRegisteredBackwardMemoryStageAfterOutputMessage = 4; +constexpr int64_t kRegisteredBackwardMemoryStageAfterRecurrentKv = 5; +constexpr int64_t kRegisteredBackwardMemoryStageAfterFrontOutputs = 6; +constexpr int64_t kRegisteredBackwardMemoryStageAfterTransition = 7; +constexpr int64_t kRegisteredBackwardMemoryStageAfterRecurrentMsgBuffer = 8; +constexpr int64_t kRegisteredBackwardMemoryStageAfterRecurrentMessage = 9; +constexpr int64_t kRegisteredBackwardMemoryStageAfterBoundaryKv = 10; +constexpr int64_t kRegisteredBackwardMemoryStageAfterInitialRecurrentKv = 11; +constexpr int64_t kRegisteredBackwardMemoryStageAfterBoundaryOutputs = 12; +constexpr int64_t kRegisteredBackwardMemoryStageAfterStepReturn = 13; +constexpr int64_t kRegisteredBackwardMemoryStageAfterSeedUpdate = 14; +constexpr int64_t kRegisteredBackwardMemoryStageAfterCarryUpdate = 15; +constexpr int64_t kRegisteredBackwardMemoryStageAfterStableAppend = 16; +constexpr int64_t kRegisteredBackwardMemoryStageReturn = 17; +constexpr int64_t kRegisteredBackwardMemoryStageAfterTransitionKeepSlots = 18; + +inline void append_registered_backward_memory_stage_row( + std::vector* rows, + const at::Tensor& reference, + int64_t local_step, + int64_t stage_id) { + if (rows == nullptr || !reference.defined() || !reference.is_cuda()) { + return; + } + const auto device_index = static_cast(reference.get_device()); + const auto stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_index); + const size_t aggregate = static_cast(c10::CachingAllocator::StatType::AGGREGATE); + rows->push_back(local_step); + rows->push_back(stage_id); + rows->push_back(stats.allocated_bytes[aggregate].current); + rows->push_back(stats.reserved_bytes[aggregate].current); + rows->push_back(stats.allocated_bytes[aggregate].peak); +} + +inline at::Tensor registered_backward_memory_stage_rows_tensor( + const std::vector& rows) { + at::Tensor tensor = at::empty( + {static_cast(rows.size() / 5), 5}, + at::TensorOptions().dtype(at::kLong).device(at::kCPU)); + if (!rows.empty()) { + std::memcpy( + tensor.data_ptr(), + rows.data(), + rows.size() * sizeof(int64_t)); + } + return tensor; +} + +inline void validate_registered_reverse_span_output_rows( + const at::Tensor& 
reverse_span_output_rows, + int64_t schema_version) { + check_cpu_long_rank2( + reverse_span_output_rows, + "reverse_span_output_rows", + kReverseSpanOutputRowColumns); + const int64_t* rows = reverse_span_output_rows.data_ptr(); + for (int64_t row_index = 0; row_index < reverse_span_output_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kReverseSpanOutputRowColumns; + TORCH_CHECK(row[0] == row_index, "reverse_span_output_rows must be densely indexed"); + TORCH_CHECK( + row[1] == kReverseSpanOutputFrontGroup || row[1] == kReverseSpanOutputBoundaryGroup, + "reverse_span_output_rows row ", + row_index, + " has invalid output group"); + TORCH_CHECK(row[2] > 0, "reverse_span_output_rows row ", row_index, " has no role id"); + TORCH_CHECK(row[3] >= 0, "reverse_span_output_rows row ", row_index, " has invalid local slot"); + TORCH_CHECK( + row[4] == 0 || row[4] == 1, + "reverse_span_output_rows row ", + row_index, + " has invalid required flag"); + TORCH_CHECK( + row[5] == schema_version, + "reverse_span_output_rows row ", + row_index, + " has unsupported schema version"); + } +} + +inline void validate_registered_reverse_program_stage_rows( + const at::Tensor& reverse_program_stage_rows, + int64_t schema_version) { + check_cpu_long_rank2(reverse_program_stage_rows, "reverse_program_stage_rows", 10); + const int64_t* rows = reverse_program_stage_rows.data_ptr(); + bool saw_transition_stage = false; + bool saw_boundary_stage = false; + bool saw_parameter_reducer_stage = false; + for (int64_t row_index = 0; row_index < reverse_program_stage_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 10; + TORCH_CHECK(row[0] == row_index, "reverse_program_stage_rows must be densely indexed"); + TORCH_CHECK(row[1] >= 1 && row[1] <= 5, "reverse_program_stage_rows row ", row_index, " has invalid stage kind"); + TORCH_CHECK(row[2] >= 0, "reverse_program_stage_rows row ", row_index, " has invalid surface opcode"); + TORCH_CHECK(row[3] >= 0, "reverse_program_stage_rows row ", row_index, " has invalid executor row index"); + TORCH_CHECK(row[4] > 0, "reverse_program_stage_rows row ", row_index, " has invalid executor id"); + TORCH_CHECK(row[5] >= 0, "reverse_program_stage_rows row ", row_index, " has invalid primitive start"); + TORCH_CHECK(row[6] > 0, "reverse_program_stage_rows row ", row_index, " has invalid primitive count"); + TORCH_CHECK(row[7] >= -3, "reverse_program_stage_rows row ", row_index, " has invalid bucket ordinal"); + TORCH_CHECK(row[8] >= 0, "reverse_program_stage_rows row ", row_index, " has invalid dependency mask"); + TORCH_CHECK( + row[9] == 0 || row[9] == schema_version, + "reverse_program_stage_rows row ", + row_index, + " has unsupported schema version"); + saw_transition_stage = saw_transition_stage || row[1] == 3; + saw_boundary_stage = saw_boundary_stage || row[1] == 4; + saw_parameter_reducer_stage = saw_parameter_reducer_stage || row[1] == 5; + } + TORCH_CHECK(saw_transition_stage, "reverse_program_stage_rows are missing transition stage ownership"); + TORCH_CHECK(saw_boundary_stage, "reverse_program_stage_rows are missing message-boundary stage ownership"); + TORCH_CHECK(saw_parameter_reducer_stage, "reverse_program_stage_rows are missing parameter-reducer stage ownership"); +} + +inline std::vector declared_reverse_span_output_group( + const at::Tensor& reverse_span_output_rows, + int64_t group_opcode, + const std::vector>& candidates, + const at::Tensor& empty, + int64_t schema_version, + const char* subject) { + 
validate_registered_reverse_span_output_rows(reverse_span_output_rows, schema_version); + const int64_t* rows = reverse_span_output_rows.data_ptr(); + int64_t max_slot = -1; + int64_t group_row_count = 0; + std::set seen_slots; + std::set seen_roles; + for (int64_t row_index = 0; row_index < reverse_span_output_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kReverseSpanOutputRowColumns; + if (row[1] != group_opcode) { + continue; + } + ++group_row_count; + TORCH_CHECK(seen_slots.insert(row[3]).second, subject, " has duplicate compiler output slot"); + TORCH_CHECK(seen_roles.insert(row[2]).second, subject, " has duplicate compiler output role"); + max_slot = std::max(max_slot, row[3]); + } + TORCH_CHECK(group_row_count > 0, subject, " has no compiler-declared output rows"); + std::vector outputs(static_cast(max_slot + 1), empty); + for (int64_t row_index = 0; row_index < reverse_span_output_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kReverseSpanOutputRowColumns; + if (row[1] != group_opcode) { + continue; + } + const int64_t role_id = row[2]; + const int64_t slot = row[3]; + const bool required = row[4] != 0; + const at::Tensor* selected = nullptr; + for (const auto& candidate : candidates) { + if (candidate.first == role_id) { + TORCH_CHECK(selected == nullptr, subject, " has duplicate candidate tensor for compiler output role"); + selected = &candidate.second; + } + } + if (selected == nullptr || !selected->defined()) { + TORCH_CHECK(!required, subject, " is missing required compiler-declared output role ", role_id); + outputs[static_cast(slot)] = empty; + continue; + } + outputs[static_cast(slot)] = *selected; + } + return outputs; +} + +inline const at::Tensor& reverse_span_output_tensor_for_role( + const std::vector& outputs, + const at::Tensor& reverse_span_output_rows, + int64_t group_opcode, + int64_t role_id, + int64_t schema_version, + const char* subject) { + validate_registered_reverse_span_output_rows(reverse_span_output_rows, schema_version); + const int64_t* rows = reverse_span_output_rows.data_ptr(); + const int64_t* selected = nullptr; + for (int64_t row_index = 0; row_index < reverse_span_output_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kReverseSpanOutputRowColumns; + if (row[1] == group_opcode && row[2] == role_id) { + TORCH_CHECK(selected == nullptr, subject, " has duplicate compiler output role"); + selected = row; + } + } + TORCH_CHECK(selected != nullptr, subject, " has no compiler output row for role ", role_id); + TORCH_CHECK( + selected[3] >= 0 && selected[3] < static_cast(outputs.size()), + subject, + " compiler output row points outside output group"); + return outputs[static_cast(selected[3])]; +} + +inline at::Tensor reverse_output_or_empty( + const std::vector& outputs, + size_t slot, + const at::Tensor& empty) { + return slot < outputs.size() ? 
outputs[slot] : empty; +} + +inline at::Tensor stable_reverse_program_output_tensor(const at::Tensor& tensor, bool clone_for_stable_return) { + if (!tensor.defined()) { + return tensor; + } + if (!clone_for_stable_return) { + return tensor.contiguous(); + } + return tensor.contiguous().clone(); +} + +inline std::vector stable_reverse_program_output_group( + const std::vector& outputs, + bool clone_for_stable_return) { + std::vector stable; + stable.reserve(outputs.size()); + for (const at::Tensor& output : outputs) { + stable.push_back(stable_reverse_program_output_tensor(output, clone_for_stable_return)); + } + return stable; +} + +inline void append_stable_reverse_program_output_groups( + std::vector>& destination, + const std::vector>& source, + bool clone_for_stable_return) { + destination.reserve(destination.size() + source.size()); + for (const std::vector& output_group : source) { + destination.push_back(stable_reverse_program_output_group(output_group, clone_for_stable_return)); + } +} + +std::vector> registered_temporal_fused_reverse_program_step_impl( + const at::Tensor& grad_output_window, + const at::Tensor& grad_carry_cells, + std::vector reverse_artifact_tensors, + const at::Tensor& reverse_artifact_binding_rows, + const at::Tensor& reverse_artifact_role_rows, + const at::Tensor& reverse_artifact_access_rows, + const at::Tensor& forward_artifact_route_rows, + const at::Tensor& reverse_artifact_consumer_route_rows, + const at::Tensor& forward_artifact_merge_rows, + const at::Tensor& forward_output_route_rows, + std::vector reverse_reset_tensors, + const at::Tensor& reverse_reset_rows, + std::vector program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& primitive_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& forward_handler_rows, + const at::Tensor& reverse_handler_rows, + const at::Tensor& native_strategy_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& reverse_span_output_rows, + const at::Tensor& transition_primitive_callable_rows, + const at::Tensor& transition_reverse_seed_role_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& memory_liveness_rows, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& graph_to_backend_order, + const at::Tensor& backend_to_graph_inverse_order, + const at::Tensor& output_local_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& output_neighbor_idx, + const at::Tensor& output_neighbor_valid, + const at::Tensor& output_edge_distance, + const at::Tensor& output_edge_delay, + std::vector> transition_program_tensor_groups, + std::vector transition_program_tensor_binding_row_groups, + std::vector transition_forward_executor_row_groups, + std::vector transition_reverse_executor_row_groups, + std::vector transition_forward_executor_binding_row_groups, + std::vector transition_reverse_executor_binding_row_groups, + std::vector transition_memory_liveness_row_groups, + std::vector> transition_seed_tensor_groups, + std::vector transition_seed_row_groups, + std::vector transition_dynamic_binding_row_groups, + std::vector transition_output_keep_slot_row_groups, + std::vector transition_parameter_tensors, + const at::Tensor& transition_parameter_rows, + 
const at::Tensor& transition_recurrent_msg_output_rows, + const at::Tensor& transition_public_y_seed_rows, + const at::Tensor& transition_state_reset_rows, + const at::Tensor& recurrent_local_sender_idx, + const at::Tensor& recurrent_neighbor_idx, + const at::Tensor& recurrent_neighbor_valid, + const at::Tensor& recurrent_edge_distance, + const at::Tensor& recurrent_edge_delay, + int64_t local_step, + int64_t message_step_index, + int64_t input_count, + int64_t recurrent_count, + double distance_scale, + bool use_sparse_messages, + bool use_delay, + int64_t group_size, + int64_t head_dim, + int64_t value_dim, + bool return_boundary_grad, + bool return_transition_state_grads, + bool return_initial_recurrent_hidden_grad, + std::vector* memory_stage_rows, + int64_t schema_version) { + check_cuda_float_rank4(grad_output_window, "fused reverse full step grad_output_window"); + const int64_t B = grad_output_window.size(0); + const int64_t local_time_steps = grad_output_window.size(1); + const int64_t output_count = grad_output_window.size(2); + const int64_t hidden = grad_output_window.size(3); + TORCH_CHECK(local_time_steps > 0, "fused reverse full step requires a non-empty gradient window"); + TORCH_CHECK(local_step >= 0 && local_step < local_time_steps, "fused reverse full step local_step is out of range"); + + std::vector decoded = validate_registered_temporal_fused_program( + primitive_rows, + forward_executor_rows, + reverse_executor_rows, + forward_handler_rows, + reverse_handler_rows, + native_strategy_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + forward_executor_binding_rows, + reverse_executor_binding_rows, + memory_liveness_rows, + schema_version); + validate_registered_fused_backward_output_grad_reverse_span(decoded[2], output_count); + const std::vector tensor_required = + validate_temporal_reverse_artifact_role_rows(reverse_artifact_role_rows); + validate_temporal_reverse_artifact_access_rows(reverse_artifact_access_rows, tensor_required); + validate_temporal_reverse_artifact_binding_rows( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + tensor_required, + local_time_steps); + validate_forward_artifact_route_rows(forward_artifact_route_rows, schema_version); + validate_forward_artifact_merge_rows(forward_artifact_merge_rows, forward_artifact_route_rows, schema_version); + validate_reverse_artifact_consumer_route_rows( + reverse_artifact_consumer_route_rows, + forward_artifact_route_rows, + schema_version); + validate_temporal_reverse_reset_rows(reverse_reset_tensors, reverse_reset_rows, B); + check_reverse_program_access_rows(reverse_program_access_rows); + at::Tensor message_reset = reverse_reset_tensor_for_kind( + reverse_reset_tensors, + reverse_reset_rows, + kReverseResetMessage); + at::Tensor transition_reset = reverse_reset_tensor_for_kind( + reverse_reset_tensors, + reverse_reset_rows, + kReverseResetTransition); + + const int64_t total_cells = input_count + recurrent_count + output_count; + const int64_t output_start = input_count + recurrent_count; + TORCH_CHECK( + output_start >= 0 && output_start + output_count <= total_cells, + "fused reverse full step output slice exceeds full cell bank"); + + at::Tensor grad_cells_out = registered_runtime_buffer_for_role( + runtime_buffer_tensors, + runtime_buffer_rows, + kRuntimeBufferRoleReverseGradCellsWork, + 0, + {B, total_cells, hidden}, + "fused reverse full step grad cells workspace"); + grad_cells_out.zero_(); + grad_cells_out.slice(1, output_start, output_start + output_count) + 
.copy_(grad_output_window.select(1, local_step)); + if (grad_carry_cells.defined() && grad_carry_cells.numel() > 0) { + check_cuda_float_bank(grad_carry_cells, "fused reverse full step grad_carry_cells"); + TORCH_CHECK(grad_carry_cells.sizes() == grad_cells_out.sizes(), "fused reverse full step carry shape mismatch"); + grad_cells_out.add_(grad_carry_cells); + } + append_registered_backward_memory_stage_row( + memory_stage_rows, + grad_output_window, + local_step, + kRegisteredBackwardMemoryStageAfterGradCellsSeed); + + std::vector> readout_span_outputs; + std::vector> output_message_span_outputs; + std::vector> recurrent_kv_span_outputs; + std::vector> recurrent_message_span_outputs; + std::vector> boundary_kv_span_outputs; + std::vector> initial_recurrent_kv_span_outputs; + + std::vector readout = registered_temporal_backward_readout_step_impl( + grad_cells_out, + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_role_rows, + reverse_artifact_access_rows, + forward_artifact_route_rows, + reverse_artifact_consumer_route_rows, + forward_artifact_merge_rows, + forward_output_route_rows, + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + primitive_rows, + forward_executor_rows, + reverse_executor_rows, + forward_handler_rows, + reverse_handler_rows, + native_strategy_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + forward_executor_binding_rows, + reverse_executor_binding_rows, + memory_liveness_rows, + graph_to_backend_order, + local_step, + input_count, + recurrent_count, + schema_version, + &readout_span_outputs); + append_registered_backward_memory_stage_row( + memory_stage_rows, + grad_output_window, + local_step, + kRegisteredBackwardMemoryStageAfterReadout); + std::vector output_message = + registered_temporal_backward_output_message_step_impl( + readout[2], + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_role_rows, + reverse_artifact_access_rows, + forward_artifact_route_rows, + reverse_artifact_consumer_route_rows, + forward_artifact_merge_rows, + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + primitive_rows, + forward_executor_rows, + reverse_executor_rows, + forward_handler_rows, + reverse_handler_rows, + native_strategy_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + forward_executor_binding_rows, + reverse_executor_binding_rows, + memory_liveness_rows, + runtime_buffer_tensors, + runtime_buffer_rows, + output_local_sender_idx, + local_distance, + local_delay, + output_neighbor_idx, + output_neighbor_valid, + output_edge_distance, + output_edge_delay, + local_step, + message_step_index, + distance_scale, + use_sparse_messages, + use_delay, + schema_version, + &output_message_span_outputs); + append_registered_backward_memory_stage_row( + memory_stage_rows, + grad_output_window, + local_step, + kRegisteredBackwardMemoryStageAfterOutputMessage); + std::vector recurrent_kv = + registered_temporal_backward_recurrent_kv_projection_step_impl( + output_message[3], + output_message[4], + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_role_rows, + reverse_artifact_access_rows, + forward_artifact_route_rows, + reverse_artifact_consumer_route_rows, + forward_artifact_merge_rows, + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + primitive_rows, + forward_executor_rows, + reverse_executor_rows, + forward_handler_rows, + reverse_handler_rows, + 
native_strategy_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + forward_executor_binding_rows, + reverse_executor_binding_rows, + memory_liveness_rows, + backend_to_graph_inverse_order, + local_step, + head_dim, + value_dim, + schema_version, + &recurrent_kv_span_outputs); + append_registered_backward_memory_stage_row( + memory_stage_rows, + grad_output_window, + local_step, + kRegisteredBackwardMemoryStageAfterRecurrentKv); + const at::Tensor empty = grad_output_window.new_empty({0}); + std::vector front_outputs = declared_reverse_span_output_group( + reverse_span_output_rows, + kReverseSpanOutputFrontGroup, + { + {registered_temporal_stable_id_hash_constexpr("grad_boundary_direct"), readout[0]}, + {registered_temporal_stable_id_hash_constexpr("grad_recurrent_hidden_backend_direct"), empty}, + {registered_temporal_stable_id_hash_constexpr("grad_value_to_output_weight"), readout[3]}, + {registered_temporal_stable_id_hash_constexpr("grad_output_cell_bias"), readout[4]}, + {registered_temporal_stable_id_hash_constexpr("grad_output_q"), output_message[0]}, + {registered_temporal_stable_id_hash_constexpr("grad_input_k_from_output"), empty}, + {registered_temporal_stable_id_hash_constexpr("grad_input_v_from_output"), empty}, + {registered_temporal_stable_id_hash_constexpr("grad_recurrent_hidden_from_kv_graph_order"), empty}, + {registered_temporal_stable_id_hash_constexpr("grad_recurrent_kv_weight_graph_order"), recurrent_kv[1]}, + }, + empty, + schema_version, + "fused reverse full step front outputs"); + at::Tensor grad_recurrent_hidden_backend_direct = readout[1]; + at::Tensor grad_recurrent_hidden_from_kv_graph_order = recurrent_kv[0]; + at::Tensor grad_input_k_from_output = output_message[1]; + at::Tensor grad_input_v_from_output = output_message[2]; + if (readout_span_outputs.size() <= 1) { + readout_span_outputs.clear(); + output_message_span_outputs.clear(); + } + if (recurrent_kv_span_outputs.size() <= 1) { + recurrent_kv_span_outputs.clear(); + } + readout.clear(); + output_message.clear(); + recurrent_kv.clear(); + append_registered_backward_memory_stage_row( + memory_stage_rows, + grad_output_window, + local_step, + kRegisteredBackwardMemoryStageAfterFrontOutputs); + at::Tensor grad_recurrent_hidden_backend; + if (grad_recurrent_hidden_backend_direct.defined() && grad_recurrent_hidden_backend_direct.numel() > 0) { + grad_recurrent_hidden_backend = grad_recurrent_hidden_backend_direct; + } + if (grad_recurrent_hidden_from_kv_graph_order.defined() && grad_recurrent_hidden_from_kv_graph_order.numel() > 0) { + at::Tensor order = graph_to_backend_order.to(grad_recurrent_hidden_from_kv_graph_order.device(), at::kLong).contiguous(); + at::Tensor grad_from_kv_backend = grad_recurrent_hidden_from_kv_graph_order.index_select(1, order); + if (grad_recurrent_hidden_backend.defined() && grad_recurrent_hidden_backend.numel() > 0) { + grad_recurrent_hidden_backend = grad_recurrent_hidden_backend + grad_from_kv_backend; + } else { + grad_recurrent_hidden_backend = grad_from_kv_backend; + } + } + + bool has_state_seed = false; + for (const at::Tensor& seed_rows : transition_seed_row_groups) { + check_cpu_long_rank2(seed_rows, "transition_seed_rows", 3); + has_state_seed = has_state_seed || seed_rows.size(0) > 0; + } + std::vector> outputs; + outputs.push_back(front_outputs); + auto append_readout_front_span_groups = [&]() { + TORCH_CHECK( + readout_span_outputs.size() == output_message_span_outputs.size(), + "fused reverse full step readout span output groups must 
align"); + if (readout_span_outputs.size() <= 1) { + return; + } + for (size_t span_index = 0; span_index < readout_span_outputs.size(); ++span_index) { + const std::vector& readout_span = readout_span_outputs[span_index]; + const std::vector& output_message_span = output_message_span_outputs[span_index]; + outputs.push_back(declared_reverse_span_output_group( + reverse_span_output_rows, + kReverseSpanOutputFrontGroup, + { + {registered_temporal_stable_id_hash_constexpr("grad_boundary_direct"), reverse_output_or_empty(readout_span, 0, empty)}, + {registered_temporal_stable_id_hash_constexpr("grad_recurrent_hidden_backend_direct"), empty}, + {registered_temporal_stable_id_hash_constexpr("grad_value_to_output_weight"), reverse_output_or_empty(readout_span, 3, empty)}, + {registered_temporal_stable_id_hash_constexpr("grad_output_cell_bias"), reverse_output_or_empty(readout_span, 4, empty)}, + {registered_temporal_stable_id_hash_constexpr("grad_output_q"), reverse_output_or_empty(output_message_span, 0, empty)}, + {registered_temporal_stable_id_hash_constexpr("grad_input_k_from_output"), empty}, + {registered_temporal_stable_id_hash_constexpr("grad_input_v_from_output"), empty}, + {registered_temporal_stable_id_hash_constexpr("grad_recurrent_hidden_from_kv_graph_order"), empty}, + {registered_temporal_stable_id_hash_constexpr("grad_recurrent_kv_weight_graph_order"), empty}, + }, + empty, + schema_version, + "fused reverse full step readout per-span front outputs")); + } + }; + auto append_message_front_span_groups = [&]() { + if (recurrent_kv_span_outputs.size() <= 1) { + return; + } + for (const std::vector& message_kv_span : recurrent_kv_span_outputs) { + outputs.push_back(declared_reverse_span_output_group( + reverse_span_output_rows, + kReverseSpanOutputFrontGroup, + { + {registered_temporal_stable_id_hash_constexpr("grad_boundary_direct"), empty}, + {registered_temporal_stable_id_hash_constexpr("grad_recurrent_hidden_backend_direct"), empty}, + {registered_temporal_stable_id_hash_constexpr("grad_value_to_output_weight"), empty}, + {registered_temporal_stable_id_hash_constexpr("grad_output_cell_bias"), empty}, + {registered_temporal_stable_id_hash_constexpr("grad_output_q"), empty}, + {registered_temporal_stable_id_hash_constexpr("grad_input_k_from_output"), empty}, + {registered_temporal_stable_id_hash_constexpr("grad_input_v_from_output"), empty}, + {registered_temporal_stable_id_hash_constexpr("grad_recurrent_hidden_from_kv_graph_order"), empty}, + {registered_temporal_stable_id_hash_constexpr("grad_recurrent_kv_weight_graph_order"), reverse_output_or_empty(message_kv_span, 1, empty)}, + }, + empty, + schema_version, + "fused reverse full step message per-span front outputs")); + } + }; + auto append_message_boundary_span_groups = [&]() { + if (recurrent_kv_span_outputs.size() <= 1) { + return; + } + TORCH_CHECK( + recurrent_message_span_outputs.empty() || + (recurrent_message_span_outputs.size() == recurrent_kv_span_outputs.size() && + boundary_kv_span_outputs.size() == recurrent_kv_span_outputs.size() && + initial_recurrent_kv_span_outputs.size() == recurrent_kv_span_outputs.size()), + "fused reverse full step message span output groups must align"); + for (size_t span_index = 0; span_index < recurrent_kv_span_outputs.size(); ++span_index) { + const std::vector empty_span; + const std::vector& recurrent_message_span = + recurrent_message_span_outputs.empty() ? 
empty_span : recurrent_message_span_outputs[span_index]; + const std::vector& boundary_kv_span = + boundary_kv_span_outputs.empty() ? empty_span : boundary_kv_span_outputs[span_index]; + const std::vector& initial_kv_span = + initial_recurrent_kv_span_outputs.empty() ? empty_span : initial_recurrent_kv_span_outputs[span_index]; + at::Tensor span_grad_input_k = reverse_output_or_empty(recurrent_message_span, 1, empty); + if (span_grad_input_k.defined() && span_grad_input_k.numel() > 0 && + grad_input_k_from_output.defined() && grad_input_k_from_output.numel() > 0) { + if (span_grad_input_k.dim() == 2) { + TORCH_CHECK( + grad_input_k_from_output.dim() == 3 && + grad_input_k_from_output.size(1) == span_grad_input_k.size(0) && + grad_input_k_from_output.size(2) == span_grad_input_k.size(1), + "fused reverse full step per-span reduced input K gradient shapes do not match"); + span_grad_input_k = span_grad_input_k + grad_input_k_from_output.sum(0).contiguous(); + } else { + span_grad_input_k = span_grad_input_k + grad_input_k_from_output; + } + } else if (grad_input_k_from_output.defined() && grad_input_k_from_output.numel() > 0) { + span_grad_input_k = grad_input_k_from_output; + } + at::Tensor span_initial_hidden = reverse_output_or_empty(initial_kv_span, 0, empty); + if (span_initial_hidden.defined() && span_initial_hidden.numel() > 0) { + span_initial_hidden = zero_batch_rows_for_reset( + span_initial_hidden, + message_reset, + "fused reverse full step per-span initial recurrent hidden grad"); + } + outputs.push_back(declared_reverse_span_output_group( + reverse_span_output_rows, + kReverseSpanOutputBoundaryGroup, + { + {registered_temporal_stable_id_hash_constexpr("grad_recurrent_q_backend"), reverse_output_or_empty(recurrent_message_span, 0, empty)}, + {registered_temporal_stable_id_hash_constexpr("grad_boundary_from_projection_raw"), reverse_output_or_empty(boundary_kv_span, 0, empty)}, + {registered_temporal_stable_id_hash_constexpr("grad_input_kv_weight"), reverse_output_or_empty(boundary_kv_span, 1, empty)}, + {registered_temporal_stable_id_hash_constexpr("input_kv_grouped_flag"), reverse_output_or_empty(boundary_kv_span, 2, at::zeros({1}, empty.options().dtype(at::kLong)))}, + {registered_temporal_stable_id_hash_constexpr("grad_hidden_graph_order"), span_initial_hidden}, + {registered_temporal_stable_id_hash_constexpr("grad_initial_recurrent_kv_weight_graph_order"), reverse_output_or_empty(initial_kv_span, 1, empty)}, + {registered_temporal_stable_id_hash_constexpr("grad_query_context_scalar"), reverse_output_or_empty(recurrent_message_span, 5, empty)}, + {registered_temporal_stable_id_hash_constexpr("grad_output_weight"), reverse_output_or_empty(recurrent_message_span, 6, empty)}, + {registered_temporal_stable_id_hash_constexpr("grad_input_key_bank"), span_grad_input_k}, + {registered_temporal_stable_id_hash_constexpr("grad_recurrent_key_bank"), reverse_output_or_empty(recurrent_message_span, 3, empty)}, + }, + empty, + schema_version, + "fused reverse full step message per-span boundary outputs")); + } + }; + if ((!grad_recurrent_hidden_backend.defined() || grad_recurrent_hidden_backend.numel() == 0) && !has_state_seed) { + outputs.push_back(declared_reverse_span_output_group( + reverse_span_output_rows, + kReverseSpanOutputBoundaryGroup, + { + {registered_temporal_stable_id_hash_constexpr("grad_recurrent_q_backend"), empty}, + {registered_temporal_stable_id_hash_constexpr("grad_boundary_from_projection_raw"), empty}, + 
{registered_temporal_stable_id_hash_constexpr("grad_input_kv_weight"), empty}, + {registered_temporal_stable_id_hash_constexpr("input_kv_grouped_flag"), at::zeros({1}, empty.options().dtype(at::kLong))}, + {registered_temporal_stable_id_hash_constexpr("grad_hidden_graph_order"), empty}, + {registered_temporal_stable_id_hash_constexpr("grad_initial_recurrent_kv_weight_graph_order"), empty}, + {registered_temporal_stable_id_hash_constexpr("grad_query_context_scalar"), empty}, + {registered_temporal_stable_id_hash_constexpr("grad_output_weight"), empty}, + {registered_temporal_stable_id_hash_constexpr("grad_input_key_bank"), empty}, + {registered_temporal_stable_id_hash_constexpr("grad_recurrent_key_bank"), empty}, + }, + empty, + schema_version, + "fused reverse full step boundary outputs")); + append_readout_front_span_groups(); + append_message_front_span_groups(); + append_message_boundary_span_groups(); + return outputs; + } + + check_cpu_long_rank2(transition_public_y_seed_rows, "transition_public_y_seed_rows", 4); + TORCH_CHECK( + transition_public_y_seed_rows.size(0) == static_cast(transition_seed_tensor_groups.size()), + "transition public-y seed rows must align with transition seed groups"); + TORCH_CHECK( + transition_seed_tensor_groups.size() == transition_seed_row_groups.size(), + "transition seed tensor/row groups must align"); + + auto filter_transition_outputs_by_keep_slots = [&]( + std::vector>& transition_outputs) { + TORCH_CHECK( + transition_output_keep_slot_row_groups.size() == transition_outputs.size(), + "fused reverse full step transition output keep-slot row groups must align"); + for (size_t group_index = 0; group_index < transition_outputs.size(); ++group_index) { + std::vector& group = transition_outputs[group_index]; + check_cpu_long_rank2( + transition_output_keep_slot_row_groups[group_index], + "transition_output_keep_slot_rows", + 1); + at::Tensor empty; + for (const at::Tensor& tensor : group) { + if (tensor.defined()) { + empty = tensor.new_empty({0}); + break; + } + } + std::vector kept(group.size(), empty); + const int64_t* keep_rows = transition_output_keep_slot_row_groups[group_index].data_ptr(); + for (int64_t row_index = 0; row_index < transition_output_keep_slot_row_groups[group_index].size(0); ++row_index) { + const int64_t slot = keep_rows[row_index]; + TORCH_CHECK( + 0 <= slot && slot < static_cast(group.size()), + "transition output keep-slot row references an invalid output slot"); + kept[static_cast(slot)] = group[static_cast(slot)]; + } + group.swap(kept); + } + }; + + if (grad_recurrent_hidden_backend.defined() && grad_recurrent_hidden_backend.numel() > 0) { + check_cuda_float_bank(grad_recurrent_hidden_backend, "fused reverse full step grad_recurrent_hidden_backend"); + check_cpu_long_rank2( + transition_reverse_seed_role_rows, + "transition_reverse_seed_role_rows", + kTransitionReverseSeedRoleRowColumns); + const int64_t* seed_role_rows = transition_reverse_seed_role_rows.data_ptr(); + auto seed_role_is_registered = [&](int64_t role_id) { + for (int64_t role_index = 0; role_index < transition_reverse_seed_role_rows.size(0); ++role_index) { + const int64_t* role = seed_role_rows + role_index * kTransitionReverseSeedRoleRowColumns; + if (role[0] == role_id) { + return true; + } + } + return false; + }; + const int64_t* public_rows = transition_public_y_seed_rows.data_ptr(); + for (size_t group_index = 0; group_index < transition_seed_tensor_groups.size(); ++group_index) { + const int64_t* public_row = public_rows + static_cast(group_index) * 4; 
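+ // Each public-y seed row is (bucket_ordinal, bucket_start, bucket_stop, seed_role): the slice of grad_recurrent_hidden_backend over [bucket_start, bucket_stop) becomes an additional registered seed tensor for this transition group, appended alongside a (seed_role, tensor_index, bucket_ordinal) seed row.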
+ const int64_t bucket_ordinal = public_row[0]; + const int64_t bucket_start = public_row[1]; + const int64_t bucket_stop = public_row[2]; + const int64_t public_seed_role = public_row[3]; + TORCH_CHECK( + seed_role_is_registered(public_seed_role), + "transition public-y seed row has unregistered seed role"); + TORCH_CHECK( + 0 <= bucket_start && bucket_start < bucket_stop && + bucket_stop <= grad_recurrent_hidden_backend.size(1), + "transition public-y seed row has invalid recurrent backend range"); + at::Tensor public_seed = + grad_recurrent_hidden_backend.slice(1, bucket_start, bucket_stop).contiguous(); + const at::Tensor& old_rows = transition_seed_row_groups[group_index]; + at::Tensor new_rows = at::empty({old_rows.size(0) + 1, 3}, old_rows.options()); + if (old_rows.size(0) > 0) { + new_rows.slice(0, 0, old_rows.size(0)).copy_(old_rows); + } + auto new_rows_acc = new_rows.accessor(); + new_rows_acc[old_rows.size(0)][0] = public_seed_role; + new_rows_acc[old_rows.size(0)][1] = static_cast(transition_seed_tensor_groups[group_index].size()); + new_rows_acc[old_rows.size(0)][2] = bucket_ordinal; + transition_seed_tensor_groups[group_index].push_back(public_seed); + transition_seed_row_groups[group_index] = new_rows; + } + } + + const size_t group_count = transition_program_tensor_groups.size(); + TORCH_CHECK(group_count > 0, "fused reverse full step requires at least one transition executor group"); + TORCH_CHECK( + transition_program_tensor_binding_row_groups.size() == group_count && + transition_forward_executor_row_groups.size() == group_count && + transition_reverse_executor_row_groups.size() == group_count && + transition_forward_executor_binding_row_groups.size() == group_count && + transition_reverse_executor_binding_row_groups.size() == group_count && + transition_memory_liveness_row_groups.size() == group_count && + transition_seed_tensor_groups.size() == group_count && + transition_seed_row_groups.size() == group_count && + transition_dynamic_binding_row_groups.size() == group_count && + transition_output_keep_slot_row_groups.size() == group_count, + "fused reverse full step transition group table sizes must match"); + std::vector> transition_and_boundary_outputs; + transition_and_boundary_outputs.reserve(group_count + 1); + for (size_t group_index = 0; group_index < group_count; ++group_index) { + transition_and_boundary_outputs.push_back(registered_temporal_backward_transition_group_impl( + transition_program_tensor_groups[group_index], + transition_program_tensor_binding_row_groups[group_index], + primitive_rows, + transition_forward_executor_row_groups[group_index], + transition_reverse_executor_row_groups[group_index], + forward_handler_rows, + reverse_handler_rows, + native_strategy_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + transition_primitive_callable_rows, + transition_forward_executor_binding_row_groups[group_index], + transition_reverse_executor_binding_row_groups[group_index], + transition_memory_liveness_row_groups[group_index], + runtime_buffer_tensors, + runtime_buffer_rows, + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_role_rows, + reverse_artifact_access_rows, + transition_seed_tensor_groups[group_index], + transition_reverse_seed_role_rows, + transition_seed_row_groups[group_index], + transition_dynamic_binding_row_groups[group_index], + transition_parameter_tensors, + transition_parameter_rows, + local_step, + memory_stage_rows, + schema_version, + return_transition_state_grads)); + } + 
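+ // One registered transition reverse group has now run per compiler-declared group. The steps below apply the compiler-declared state resets, keep only the declared output slots, and route each group's aggregated-message gradient into the compiler-owned reverse recurrent-message buffer (or use it directly when a single full-bank group lets that buffer be elided).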
append_registered_backward_memory_stage_row( + memory_stage_rows, + grad_output_window, + local_step, + kRegisteredBackwardMemoryStageAfterTransition); + apply_transition_state_reset_outputs( + transition_and_boundary_outputs, + transition_state_reset_rows, + transition_reset); + filter_transition_outputs_by_keep_slots(transition_and_boundary_outputs); + append_registered_backward_memory_stage_row( + memory_stage_rows, + grad_output_window, + local_step, + kRegisteredBackwardMemoryStageAfterTransitionKeepSlots); + check_cpu_long_rank2(transition_recurrent_msg_output_rows, "transition_recurrent_msg_output_rows", 4); + std::vector recurrent_msg_output_row_by_group(transition_and_boundary_outputs.size(), nullptr); + TORCH_CHECK(recurrent_count > 0, "fused reverse full step requires positive recurrent_count"); + TORCH_CHECK(value_dim > 0, "fused reverse full step requires positive value_dim"); + const int64_t* output_rows = transition_recurrent_msg_output_rows.data_ptr(); + for (int64_t row_index = 0; row_index < transition_recurrent_msg_output_rows.size(0); ++row_index) { + const int64_t* row = output_rows + row_index * 4; + const int64_t group_index = row[0]; + TORCH_CHECK( + group_index >= 0 && group_index < static_cast(transition_and_boundary_outputs.size()), + "fused reverse full step has invalid compiler recurrent-message output group"); + TORCH_CHECK( + recurrent_msg_output_row_by_group[static_cast(group_index)] == nullptr, + "fused reverse full step has duplicate compiler recurrent-message output rows for group ", + group_index); + recurrent_msg_output_row_by_group[static_cast(group_index)] = row; + } + for (size_t group_index = 0; group_index < recurrent_msg_output_row_by_group.size(); ++group_index) { + TORCH_CHECK( + recurrent_msg_output_row_by_group[group_index] != nullptr, + "fused reverse full step is missing compiler recurrent-message output row for group ", + static_cast(group_index)); + } + at::Tensor grad_recurrent_msg; + const bool has_recurrent_msg_runtime_buffer = registered_runtime_buffer_has_role( + runtime_buffer_rows, + kRuntimeBufferRoleReverseGradRecurrentMsg, + 0); + if (has_recurrent_msg_runtime_buffer) { + grad_recurrent_msg = registered_runtime_buffer_for_role_any_shape( + runtime_buffer_tensors, + runtime_buffer_rows, + kRuntimeBufferRoleReverseGradRecurrentMsg, + 0, + "fused reverse full step grad recurrent message"); + TORCH_CHECK( + grad_recurrent_msg.dim() == 3 && grad_recurrent_msg.size(0) == B && + grad_recurrent_msg.size(1) == recurrent_count && grad_recurrent_msg.size(2) > 0, + "fused reverse full step grad recurrent message compiler runtime buffer has invalid shape"); + grad_recurrent_msg.zero_(); + } else { + TORCH_CHECK( + transition_and_boundary_outputs.size() == 1 && transition_recurrent_msg_output_rows.size(0) == 1, + "fused reverse full step can elide reverse_grad_recurrent_msg only for a singleton full transition output"); + } + int64_t recurrent_message_dim = -1; + for (size_t group_index = 0; group_index < transition_and_boundary_outputs.size(); ++group_index) { + const int64_t* row = recurrent_msg_output_row_by_group[group_index]; + const int64_t slot = row[1]; + const int64_t bucket_start = row[2]; + const int64_t bucket_stop = row[3]; + TORCH_CHECK( + slot >= 0 && slot < static_cast(transition_and_boundary_outputs[group_index].size()), + "fused reverse full step has invalid transition output slot"); + TORCH_CHECK( + 0 <= bucket_start && bucket_start < bucket_stop && bucket_stop <= recurrent_count, + "fused reverse full step has invalid 
transition bucket range"); + const at::Tensor& bucket_grad = transition_and_boundary_outputs[group_index][static_cast(slot)]; + check_cuda_float_bank(bucket_grad, "fused reverse full step grad_aggregated_message"); + if (recurrent_message_dim < 0) { + recurrent_message_dim = bucket_grad.size(2); + } + TORCH_CHECK( + bucket_grad.size(2) == recurrent_message_dim, + "fused reverse full step grad_aggregated_message output dimensions must match"); + TORCH_CHECK( + bucket_grad.size(1) == bucket_stop - bucket_start && bucket_grad.size(2) == recurrent_message_dim, + "fused reverse full step grad_aggregated_message shape does not match bucket range"); + TORCH_CHECK( + bucket_grad.size(0) == B, + "fused reverse full step grad_aggregated_message batch mismatch"); + if (has_recurrent_msg_runtime_buffer) { + grad_recurrent_msg.slice(1, bucket_start, bucket_stop).copy_(bucket_grad); + } else { + TORCH_CHECK( + bucket_start == 0 && bucket_stop == recurrent_count, + "fused reverse full step direct recurrent-message gradient must cover the full recurrent bank"); + grad_recurrent_msg = bucket_grad; + } + } + TORCH_CHECK( + grad_recurrent_msg.defined() && recurrent_message_dim > 0, + "fused reverse full step did not materialize a recurrent-message gradient input"); + append_registered_backward_memory_stage_row( + memory_stage_rows, + grad_output_window, + local_step, + kRegisteredBackwardMemoryStageAfterRecurrentMsgBuffer); + std::vector>* recurrent_message_span_output_sink = + recurrent_kv_span_outputs.empty() ? nullptr : &recurrent_message_span_outputs; + std::vector>* boundary_kv_span_output_sink = + recurrent_kv_span_outputs.empty() ? nullptr : &boundary_kv_span_outputs; + std::vector>* initial_recurrent_kv_span_output_sink = + recurrent_kv_span_outputs.empty() ? 
nullptr : &initial_recurrent_kv_span_outputs; + std::vector recurrent_message = registered_temporal_backward_recurrent_message_step_impl( + grad_recurrent_msg, + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_role_rows, + reverse_artifact_access_rows, + forward_artifact_route_rows, + reverse_artifact_consumer_route_rows, + forward_artifact_merge_rows, + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + primitive_rows, + forward_executor_rows, + reverse_executor_rows, + forward_handler_rows, + reverse_handler_rows, + native_strategy_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + forward_executor_binding_rows, + reverse_executor_binding_rows, + memory_liveness_rows, + runtime_buffer_tensors, + runtime_buffer_rows, + recurrent_local_sender_idx, + local_distance, + local_delay, + recurrent_neighbor_idx, + recurrent_neighbor_valid, + recurrent_edge_distance, + recurrent_edge_delay, + local_step, + message_step_index, + distance_scale, + use_sparse_messages, + use_delay, + schema_version, + recurrent_message_span_output_sink); + append_registered_backward_memory_stage_row( + memory_stage_rows, + grad_output_window, + local_step, + kRegisteredBackwardMemoryStageAfterRecurrentMessage); + for (size_t group_index = 0; group_index < recurrent_msg_output_row_by_group.size(); ++group_index) { + const int64_t* row = recurrent_msg_output_row_by_group[group_index]; + const int64_t slot = row[1]; + if (0 <= slot && slot < static_cast(transition_and_boundary_outputs[group_index].size())) { + transition_and_boundary_outputs[group_index][static_cast(slot)] = empty; + } + } + grad_recurrent_msg = at::Tensor(); + const bool materialize_message_key_bank_outputs = recurrent_message.size() > 5; + const bool reduced_message_key_bank_outputs = + materialize_message_key_bank_outputs && + recurrent_message.size() > 3 && + recurrent_message[1].defined() && + recurrent_message[3].defined() && + recurrent_message[1].dim() == 2 && + recurrent_message[3].dim() == 2; + at::Tensor grad_input_k = recurrent_message[1]; + at::Tensor grad_input_key_bank_for_reducer = + materialize_message_key_bank_outputs ? 
recurrent_message[1] : empty; + at::Tensor grad_input_v = recurrent_message[2]; + if (grad_input_k_from_output.defined() && grad_input_k_from_output.numel() > 0) { + if (reduced_message_key_bank_outputs) { + TORCH_CHECK( + grad_input_k_from_output.dim() == 3 && + grad_input_k_from_output.size(1) == grad_input_key_bank_for_reducer.size(0) && + grad_input_k_from_output.size(2) == grad_input_key_bank_for_reducer.size(1), + "fused reverse full step reduced input K gradient shapes do not match"); + grad_input_key_bank_for_reducer = + grad_input_key_bank_for_reducer + grad_input_k_from_output.sum(0).contiguous(); + grad_input_k = grad_input_k_from_output; + } else { + TORCH_CHECK( + grad_input_k.defined() && grad_input_k.sizes() == grad_input_k_from_output.sizes(), + "fused reverse full step input K gradient shapes do not match"); + grad_input_k.add_(grad_input_k_from_output); + grad_input_key_bank_for_reducer = grad_input_k; + } + } + if (grad_input_v_from_output.defined() && grad_input_v_from_output.numel() > 0) { + TORCH_CHECK( + grad_input_v.defined() && grad_input_v.sizes() == grad_input_v_from_output.sizes(), + "fused reverse full step input V gradient shapes do not match"); + grad_input_v.add_(grad_input_v_from_output); + } + if (reduced_message_key_bank_outputs && (!grad_input_k.defined() || grad_input_k.dim() != 3)) { + grad_input_k = grad_input_v.new_empty({0, 0, 0}); + } + std::vector boundary_kv = registered_temporal_backward_boundary_kv_projection_step_impl( + grad_input_k, + grad_input_v, + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_role_rows, + reverse_artifact_access_rows, + forward_artifact_route_rows, + reverse_artifact_consumer_route_rows, + forward_artifact_merge_rows, + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + primitive_rows, + forward_executor_rows, + reverse_executor_rows, + forward_handler_rows, + reverse_handler_rows, + native_strategy_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + forward_executor_binding_rows, + reverse_executor_binding_rows, + memory_liveness_rows, + local_step, + group_size, + head_dim, + value_dim, + return_boundary_grad, + schema_version, + boundary_kv_span_output_sink); + append_registered_backward_memory_stage_row( + memory_stage_rows, + grad_output_window, + local_step, + kRegisteredBackwardMemoryStageAfterBoundaryKv); + if (recurrent_message.size() > 1 && (!materialize_message_key_bank_outputs || reduced_message_key_bank_outputs)) { + recurrent_message[1] = empty; + grad_input_k = at::Tensor(); + } + if (recurrent_message.size() > 2) { + recurrent_message[2] = empty; + } + grad_input_v = at::Tensor(); + at::Tensor grad_recurrent_k_for_initial = recurrent_message[3]; + if (reduced_message_key_bank_outputs) { + grad_recurrent_k_for_initial = recurrent_message[4].new_empty({0, 0, 0}); + } + std::vector initial_recurrent_kv = + registered_temporal_backward_initial_recurrent_kv_projection_step_impl( + grad_recurrent_k_for_initial, + recurrent_message[4], + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_role_rows, + reverse_artifact_access_rows, + forward_artifact_route_rows, + reverse_artifact_consumer_route_rows, + forward_artifact_merge_rows, + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + primitive_rows, + forward_executor_rows, + reverse_executor_rows, + forward_handler_rows, + reverse_handler_rows, + native_strategy_rows, + native_callable_binding_schema_rows, + 
native_callable_output_rows, + forward_executor_binding_rows, + reverse_executor_binding_rows, + memory_liveness_rows, + backend_to_graph_inverse_order, + local_step, + head_dim, + value_dim, + return_initial_recurrent_hidden_grad, + schema_version, + initial_recurrent_kv_span_output_sink); + if (recurrent_message.size() > 4) { + recurrent_message[4] = empty; + } + append_registered_backward_memory_stage_row( + memory_stage_rows, + grad_output_window, + local_step, + kRegisteredBackwardMemoryStageAfterInitialRecurrentKv); + at::Tensor grad_hidden_graph_order = zero_batch_rows_for_reset( + initial_recurrent_kv[0], + message_reset, + "fused reverse full step initial recurrent hidden grad"); + transition_and_boundary_outputs.push_back(declared_reverse_span_output_group( + reverse_span_output_rows, + kReverseSpanOutputBoundaryGroup, + { + {registered_temporal_stable_id_hash_constexpr("grad_recurrent_q_backend"), recurrent_message[0]}, + {registered_temporal_stable_id_hash_constexpr("grad_boundary_from_projection_raw"), boundary_kv[0]}, + {registered_temporal_stable_id_hash_constexpr("grad_input_kv_weight"), boundary_kv[1]}, + {registered_temporal_stable_id_hash_constexpr("input_kv_grouped_flag"), boundary_kv[2]}, + {registered_temporal_stable_id_hash_constexpr("grad_hidden_graph_order"), grad_hidden_graph_order}, + {registered_temporal_stable_id_hash_constexpr("grad_initial_recurrent_kv_weight_graph_order"), initial_recurrent_kv[1]}, + {registered_temporal_stable_id_hash_constexpr("grad_query_context_scalar"), recurrent_message.size() > 5 ? recurrent_message[5] : empty}, + {registered_temporal_stable_id_hash_constexpr("grad_output_weight"), recurrent_message.size() > 6 ? recurrent_message[6] : empty}, + {registered_temporal_stable_id_hash_constexpr("grad_input_key_bank"), + materialize_message_key_bank_outputs ? grad_input_key_bank_for_reducer : empty}, + {registered_temporal_stable_id_hash_constexpr("grad_recurrent_key_bank"), + materialize_message_key_bank_outputs ? 
recurrent_message[3] : empty}, + }, + empty, + schema_version, + "fused reverse full step boundary outputs")); + if (recurrent_kv_span_outputs.empty()) { + recurrent_message_span_outputs.clear(); + boundary_kv_span_outputs.clear(); + initial_recurrent_kv_span_outputs.clear(); + } + recurrent_message.clear(); + boundary_kv.clear(); + initial_recurrent_kv.clear(); + append_registered_backward_memory_stage_row( + memory_stage_rows, + grad_output_window, + local_step, + kRegisteredBackwardMemoryStageAfterBoundaryOutputs); + outputs.insert(outputs.end(), transition_and_boundary_outputs.begin(), transition_and_boundary_outputs.end()); + append_readout_front_span_groups(); + append_message_front_span_groups(); + append_message_boundary_span_groups(); + return outputs; +} + +inline std::pair>, std::vector> +transition_seed_groups_from_reverse_step_outputs( + const std::vector>& transition_outputs, + const at::Tensor& transition_reverse_seed_role_rows, + const at::Tensor& transition_next_seed_output_rows) { + check_cpu_long_rank2( + transition_reverse_seed_role_rows, + "transition_reverse_seed_role_rows", + kTransitionReverseSeedRoleRowColumns); + const int64_t* seed_role_rows = transition_reverse_seed_role_rows.data_ptr(); + auto seed_role_is_registered = [&](int64_t role_id) { + for (int64_t role_index = 0; role_index < transition_reverse_seed_role_rows.size(0); ++role_index) { + const int64_t* role = seed_role_rows + role_index * kTransitionReverseSeedRoleRowColumns; + if (role[0] == role_id) { + return true; + } + } + return false; + }; + check_cpu_long_rank2(transition_next_seed_output_rows, "transition_next_seed_output_rows", 4); + const int64_t group_count = static_cast(transition_outputs.size()); + const int64_t* rows = transition_next_seed_output_rows.data_ptr(); + std::vector> seed_tensor_groups; + std::vector seed_row_groups; + seed_tensor_groups.reserve(transition_outputs.size()); + seed_row_groups.reserve(transition_outputs.size()); + for (int64_t group_index = 0; group_index < group_count; ++group_index) { + std::vector group_tensors; + std::vector group_rows; + for (int64_t row_index = 0; row_index < transition_next_seed_output_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 4; + if (row[0] != group_index) { + continue; + } + const int64_t seed_role = row[1]; + const int64_t output_slot = row[2]; + const int64_t bucket_ordinal = row[3]; + TORCH_CHECK( + seed_role_is_registered(seed_role), + "transition_next_seed_output_rows row ", + row_index, + " has invalid seed role"); + TORCH_CHECK( + output_slot >= 0 && output_slot < static_cast(transition_outputs[static_cast(group_index)].size()), + "transition_next_seed_output_rows row ", + row_index, + " has invalid transition output slot"); + TORCH_CHECK(bucket_ordinal >= 0, "transition_next_seed_output_rows row ", row_index, " has invalid bucket"); + const at::Tensor& seed = transition_outputs[static_cast(group_index)][static_cast(output_slot)]; + if (!seed.defined() || seed.numel() == 0) { + continue; + } + group_rows.push_back(seed_role); + group_rows.push_back(static_cast(group_tensors.size())); + group_rows.push_back(bucket_ordinal); + group_tensors.push_back(stable_reverse_program_output_tensor(seed, true)); + } + at::Tensor group_row_tensor = + at::empty({static_cast(group_rows.size() / 3), 3}, transition_next_seed_output_rows.options()); + if (!group_rows.empty()) { + std::memcpy( + group_row_tensor.data_ptr(), + group_rows.data(), + static_cast(group_rows.size()) * sizeof(int64_t)); + } + 
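+ // The seeds harvested for this group are packed as (seed_role, tensor_index, bucket_ordinal) rows plus cloned tensors; the caller feeds them back in as the transition seed inputs of the next earlier local step.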
seed_tensor_groups.push_back(group_tensors); + seed_row_groups.push_back(group_row_tensor); + } + return {seed_tensor_groups, seed_row_groups}; +} + +std::vector> flat_bucket_registered_temporal_fused_backward_program_cuda( + const at::Tensor& grad_output_window, + const at::Tensor& grad_carry_cells, + const at::Tensor& reverse_program_stage_rows, + std::vector reverse_artifact_tensors, + const at::Tensor& reverse_artifact_binding_rows, + const at::Tensor& reverse_artifact_role_rows, + const at::Tensor& reverse_artifact_access_rows, + const at::Tensor& forward_artifact_route_rows, + const at::Tensor& forward_artifact_merge_rows, + const at::Tensor& forward_output_route_rows, + const at::Tensor& reverse_artifact_consumer_route_rows, + std::vector> reverse_reset_tensor_groups, + std::vector reverse_reset_row_groups, + std::vector program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& primitive_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& forward_handler_rows, + const at::Tensor& reverse_handler_rows, + const at::Tensor& native_strategy_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& reverse_span_output_rows, + const at::Tensor& transition_reverse_seed_role_rows, + const at::Tensor& transition_primitive_callable_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& memory_liveness_rows, + const at::Tensor& memory_runtime_schedule_rows, + const at::Tensor& physical_strategy_rows, + std::vector runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + std::vector reverse_program_runtime_tensors, + const at::Tensor& reverse_program_runtime_rows, + std::vector> transition_program_tensor_groups, + std::vector transition_program_tensor_binding_row_groups, + std::vector transition_forward_executor_row_groups, + std::vector transition_reverse_executor_row_groups, + std::vector transition_forward_executor_binding_row_groups, + std::vector transition_reverse_executor_binding_row_groups, + std::vector transition_memory_liveness_row_groups, + std::vector> transition_seed_tensor_groups, + std::vector transition_seed_row_groups, + std::vector transition_dynamic_binding_row_groups, + std::vector transition_output_keep_slot_row_groups, + std::vector transition_parameter_tensors, + const at::Tensor& transition_parameter_rows, + const at::Tensor& transition_recurrent_msg_output_rows, + const at::Tensor& transition_public_y_seed_rows, + const at::Tensor& transition_state_reset_rows, + const at::Tensor& transition_next_seed_output_rows, + bool return_window_start_transition_state_grads, + int64_t schema_version) { + check_cuda_float_rank4(grad_output_window, "fused backward program grad_output_window"); + const int64_t B = grad_output_window.size(0); + const int64_t local_time_steps = grad_output_window.size(1); + const int64_t output_count = grad_output_window.size(2); + const int64_t hidden = grad_output_window.size(3); + TORCH_CHECK(local_time_steps > 0, "fused backward program requires a non-empty gradient window"); + TORCH_CHECK( + static_cast(reverse_reset_tensor_groups.size()) == local_time_steps && + static_cast(reverse_reset_row_groups.size()) == local_time_steps, + "fused backward program reverse reset groups must cover every local step"); + 
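+ // All runtime inputs below (graph/backend orders, sender and neighbor tables, counts, distance scale, sparse/delay flags, group/head/value dims, return_boundary_grad) are looked up by registered role id from the compiler-owned reverse-program runtime rows rather than being passed as ad hoc arguments.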
check_reverse_program_runtime_rows(reverse_program_runtime_tensors, reverse_program_runtime_rows); + validate_forward_artifact_route_rows(forward_artifact_route_rows, schema_version); + validate_forward_artifact_merge_rows(forward_artifact_merge_rows, forward_artifact_route_rows, schema_version); + validate_forward_output_route_rows(forward_output_route_rows, schema_version); + validate_reverse_artifact_consumer_route_rows( + reverse_artifact_consumer_route_rows, + forward_artifact_route_rows, + schema_version); + const at::Tensor graph_to_backend_order = reverse_program_runtime_tensor_for_role( + reverse_program_runtime_tensors, + reverse_program_runtime_rows, + kReverseRuntimeGraphToBackendOrder, + "registered fused reverse graph_to_backend_order"); + const at::Tensor backend_to_graph_inverse_order = reverse_program_runtime_tensor_for_role( + reverse_program_runtime_tensors, + reverse_program_runtime_rows, + kReverseRuntimeBackendToGraphInverseOrder, + "registered fused reverse backend_to_graph_inverse_order"); + const at::Tensor output_local_sender_idx = reverse_program_runtime_tensor_for_role( + reverse_program_runtime_tensors, + reverse_program_runtime_rows, + kReverseRuntimeOutputLocalSenderIdx, + "registered fused reverse output_local_sender_idx"); + const at::Tensor local_distance = reverse_program_runtime_tensor_for_role( + reverse_program_runtime_tensors, + reverse_program_runtime_rows, + kReverseRuntimeLocalDistance, + "registered fused reverse local_distance"); + const at::Tensor local_delay = reverse_program_runtime_tensor_for_role( + reverse_program_runtime_tensors, + reverse_program_runtime_rows, + kReverseRuntimeLocalDelay, + "registered fused reverse local_delay"); + const at::Tensor output_neighbor_idx = reverse_program_runtime_tensor_for_role( + reverse_program_runtime_tensors, + reverse_program_runtime_rows, + kReverseRuntimeOutputNeighborIdx, + "registered fused reverse output_neighbor_idx"); + const at::Tensor output_neighbor_valid = reverse_program_runtime_tensor_for_role( + reverse_program_runtime_tensors, + reverse_program_runtime_rows, + kReverseRuntimeOutputNeighborValid, + "registered fused reverse output_neighbor_valid"); + const at::Tensor output_edge_distance = reverse_program_runtime_tensor_for_role( + reverse_program_runtime_tensors, + reverse_program_runtime_rows, + kReverseRuntimeOutputEdgeDistance, + "registered fused reverse output_edge_distance"); + const at::Tensor output_edge_delay = reverse_program_runtime_tensor_for_role( + reverse_program_runtime_tensors, + reverse_program_runtime_rows, + kReverseRuntimeOutputEdgeDelay, + "registered fused reverse output_edge_delay"); + const at::Tensor recurrent_local_sender_idx = reverse_program_runtime_tensor_for_role( + reverse_program_runtime_tensors, + reverse_program_runtime_rows, + kReverseRuntimeRecurrentLocalSenderIdx, + "registered fused reverse recurrent_local_sender_idx"); + const at::Tensor recurrent_neighbor_idx = reverse_program_runtime_tensor_for_role( + reverse_program_runtime_tensors, + reverse_program_runtime_rows, + kReverseRuntimeRecurrentNeighborIdx, + "registered fused reverse recurrent_neighbor_idx"); + const at::Tensor recurrent_neighbor_valid = reverse_program_runtime_tensor_for_role( + reverse_program_runtime_tensors, + reverse_program_runtime_rows, + kReverseRuntimeRecurrentNeighborValid, + "registered fused reverse recurrent_neighbor_valid"); + const at::Tensor recurrent_edge_distance = reverse_program_runtime_tensor_for_role( + reverse_program_runtime_tensors, + 
reverse_program_runtime_rows, + kReverseRuntimeRecurrentEdgeDistance, + "registered fused reverse recurrent_edge_distance"); + const at::Tensor recurrent_edge_delay = reverse_program_runtime_tensor_for_role( + reverse_program_runtime_tensors, + reverse_program_runtime_rows, + kReverseRuntimeRecurrentEdgeDelay, + "registered fused reverse recurrent_edge_delay"); + const at::Tensor message_step_indices = reverse_program_runtime_tensor_for_role( + reverse_program_runtime_tensors, + reverse_program_runtime_rows, + kReverseRuntimeMessageStepIndices, + "registered fused reverse message_step_indices"); + const int64_t input_count = reverse_program_runtime_int_for_role( + reverse_program_runtime_tensors, + reverse_program_runtime_rows, + kReverseRuntimeInputCount, + "registered fused reverse input_count"); + const int64_t recurrent_count = reverse_program_runtime_int_for_role( + reverse_program_runtime_tensors, + reverse_program_runtime_rows, + kReverseRuntimeRecurrentCount, + "registered fused reverse recurrent_count"); + const double distance_scale = reverse_program_runtime_double_for_role( + reverse_program_runtime_tensors, + reverse_program_runtime_rows, + kReverseRuntimeDistanceScale, + "registered fused reverse distance_scale"); + const bool use_sparse_messages = reverse_program_runtime_int_for_role( + reverse_program_runtime_tensors, + reverse_program_runtime_rows, + kReverseRuntimeUseSparseMessages, + "registered fused reverse use_sparse_messages") != 0; + const bool use_delay = reverse_program_runtime_int_for_role( + reverse_program_runtime_tensors, + reverse_program_runtime_rows, + kReverseRuntimeUseDelay, + "registered fused reverse use_delay") != 0; + const int64_t group_size = reverse_program_runtime_int_for_role( + reverse_program_runtime_tensors, + reverse_program_runtime_rows, + kReverseRuntimeGroupSize, + "registered fused reverse group_size"); + const int64_t head_dim = reverse_program_runtime_int_for_role( + reverse_program_runtime_tensors, + reverse_program_runtime_rows, + kReverseRuntimeHeadDim, + "registered fused reverse head_dim"); + const int64_t value_dim = reverse_program_runtime_int_for_role( + reverse_program_runtime_tensors, + reverse_program_runtime_rows, + kReverseRuntimeValueDim, + "registered fused reverse value_dim"); + const bool return_boundary_grad = reverse_program_runtime_int_for_role( + reverse_program_runtime_tensors, + reverse_program_runtime_rows, + kReverseRuntimeReturnBoundaryGrad, + "registered fused reverse return_boundary_grad") != 0; + TORCH_CHECK( + !message_step_indices.is_cuda() && message_step_indices.is_contiguous() && + message_step_indices.scalar_type() == at::kLong && message_step_indices.dim() == 1 && + message_step_indices.size(0) == local_time_steps, + "fused backward program message_step_indices must be a CPU int64 vector with one entry per local step"); + validate_registered_runtime_buffer_rows( + memory_liveness_rows, + runtime_buffer_tensors, + runtime_buffer_rows, + "registered fused backward program"); + validate_registered_memory_runtime_schedule_rows( + memory_liveness_rows, + memory_runtime_schedule_rows, + "registered fused backward program"); + validate_registered_physical_strategy_rows( + physical_strategy_rows, + memory_runtime_schedule_rows, + "registered fused backward program"); + validate_registered_native_callable_output_rows(native_callable_output_rows, schema_version); + validate_registered_reverse_span_output_rows(reverse_span_output_rows, schema_version); + 
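+ // The remaining checks fail closed on missing compiler rows; the loop below then replays the gradient window from the last local step to the first, re-seeding each earlier step's transition groups from the declared next-seed outputs and carrying the recurrent-hidden gradient through the compiler-owned grad-carry-cells buffer when one is declared.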
validate_registered_reverse_program_stage_rows(reverse_program_stage_rows, schema_version); + TORCH_CHECK( + native_callable_output_rows.size(0) > 0, + "registered fused backward program requires compiler-owned native callable output rows"); + + std::vector> current_transition_seed_tensor_groups = transition_seed_tensor_groups; + std::vector current_transition_seed_row_groups = transition_seed_row_groups; + at::Tensor current_grad_carry_cells = + (grad_carry_cells.defined() && grad_carry_cells.numel() > 0) ? grad_carry_cells.contiguous() + : grad_output_window.new_empty({0}); + std::vector> span_outputs; + std::vector memory_stage_rows; + const size_t transition_group_count = transition_program_tensor_groups.size(); + const int64_t* message_steps = message_step_indices.data_ptr(); + const bool clone_span_outputs_for_stable_return = local_time_steps > 1; + const bool materialize_grad_carry_cells = registered_runtime_buffer_has_role( + runtime_buffer_rows, + kRuntimeBufferRoleReverseGradCarryCells, + 0); + append_registered_backward_memory_stage_row( + &memory_stage_rows, + grad_output_window, + -1, + kRegisteredBackwardMemoryStageEntry); + + for (int64_t local_step = local_time_steps - 1; local_step >= 0; --local_step) { + std::vector> step_outputs = + registered_temporal_fused_reverse_program_step_impl( + grad_output_window, + current_grad_carry_cells, + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_role_rows, + reverse_artifact_access_rows, + forward_artifact_route_rows, + reverse_artifact_consumer_route_rows, + forward_artifact_merge_rows, + forward_output_route_rows, + reverse_reset_tensor_groups[static_cast(local_step)], + reverse_reset_row_groups[static_cast(local_step)], + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + primitive_rows, + forward_executor_rows, + reverse_executor_rows, + forward_handler_rows, + reverse_handler_rows, + native_strategy_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + reverse_span_output_rows, + transition_primitive_callable_rows, + transition_reverse_seed_role_rows, + forward_executor_binding_rows, + reverse_executor_binding_rows, + memory_liveness_rows, + runtime_buffer_tensors, + runtime_buffer_rows, + graph_to_backend_order, + backend_to_graph_inverse_order, + output_local_sender_idx, + local_distance, + local_delay, + output_neighbor_idx, + output_neighbor_valid, + output_edge_distance, + output_edge_delay, + transition_program_tensor_groups, + transition_program_tensor_binding_row_groups, + transition_forward_executor_row_groups, + transition_reverse_executor_row_groups, + transition_forward_executor_binding_row_groups, + transition_reverse_executor_binding_row_groups, + transition_memory_liveness_row_groups, + current_transition_seed_tensor_groups, + current_transition_seed_row_groups, + transition_dynamic_binding_row_groups, + transition_output_keep_slot_row_groups, + transition_parameter_tensors, + transition_parameter_rows, + transition_recurrent_msg_output_rows, + transition_public_y_seed_rows, + transition_state_reset_rows, + recurrent_local_sender_idx, + recurrent_neighbor_idx, + recurrent_neighbor_valid, + recurrent_edge_distance, + recurrent_edge_delay, + local_step, + message_steps[local_step], + input_count, + recurrent_count, + distance_scale, + use_sparse_messages, + use_delay, + group_size, + head_dim, + value_dim, + return_boundary_grad, + local_step > 0 || return_window_start_transition_state_grads, + materialize_grad_carry_cells, + 
&memory_stage_rows, + schema_version); + append_registered_backward_memory_stage_row( + &memory_stage_rows, + grad_output_window, + local_step, + kRegisteredBackwardMemoryStageAfterStepReturn); + TORCH_CHECK( + step_outputs.size() >= transition_group_count + 2, + "fused backward program full-step output group count mismatch"); + std::vector> transition_outputs; + transition_outputs.reserve(transition_group_count); + for (size_t group_index = 0; group_index < transition_group_count; ++group_index) { + transition_outputs.push_back(step_outputs[group_index + 1]); + } + if (local_step > 0) { + auto next_seed_groups = + transition_seed_groups_from_reverse_step_outputs( + transition_outputs, + transition_reverse_seed_role_rows, + transition_next_seed_output_rows); + current_transition_seed_tensor_groups = next_seed_groups.first; + current_transition_seed_row_groups = next_seed_groups.second; + } + append_registered_backward_memory_stage_row( + &memory_stage_rows, + grad_output_window, + local_step, + kRegisteredBackwardMemoryStageAfterSeedUpdate); + + const std::vector& boundary_group = step_outputs[transition_group_count + 1]; + const int64_t total_cells = input_count + recurrent_count + output_count; + if (materialize_grad_carry_cells) { + current_grad_carry_cells = registered_runtime_buffer_for_role( + runtime_buffer_tensors, + runtime_buffer_rows, + kRuntimeBufferRoleReverseGradCarryCells, + 0, + {B, total_cells, hidden}, + "fused backward program grad carry cells"); + current_grad_carry_cells.zero_(); + } else { + TORCH_CHECK( + local_step == 0, + "fused backward program missing compiler grad-carry cells buffer before an earlier local step"); + current_grad_carry_cells = at::Tensor(); + } + const at::Tensor& grad_hidden_graph = reverse_span_output_tensor_for_role( + boundary_group, + reverse_span_output_rows, + kReverseSpanOutputBoundaryGroup, + registered_temporal_stable_id_hash_constexpr("grad_hidden_graph_order"), + schema_version, + "fused backward program boundary output group"); + if (materialize_grad_carry_cells && grad_hidden_graph.defined() && grad_hidden_graph.numel() > 0 && + recurrent_count > 0) { + check_cuda_float_bank(grad_hidden_graph, "fused backward program recurrent hidden carry"); + TORCH_CHECK( + grad_hidden_graph.size(0) == current_grad_carry_cells.size(0) && + grad_hidden_graph.size(1) == recurrent_count && + grad_hidden_graph.size(2) == current_grad_carry_cells.size(2), + "fused backward program recurrent hidden carry shape mismatch"); + current_grad_carry_cells.slice(1, input_count, input_count + recurrent_count).copy_(grad_hidden_graph); + } + append_registered_backward_memory_stage_row( + &memory_stage_rows, + grad_output_window, + local_step, + kRegisteredBackwardMemoryStageAfterCarryUpdate); + append_stable_reverse_program_output_groups( + span_outputs, + step_outputs, + clone_span_outputs_for_stable_return); + append_registered_backward_memory_stage_row( + &memory_stage_rows, + grad_output_window, + local_step, + kRegisteredBackwardMemoryStageAfterStableAppend); + } + append_registered_backward_memory_stage_row( + &memory_stage_rows, + grad_output_window, + -1, + kRegisteredBackwardMemoryStageReturn); + if (!memory_stage_rows.empty()) { + span_outputs.push_back({registered_backward_memory_stage_rows_tensor(memory_stage_rows)}); + } + return span_outputs; +} diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/backward_surface_steps.cuh 
b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/backward_surface_steps.cuh new file mode 100644 index 00000000..1be923a6 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/backward_surface_steps.cuh @@ -0,0 +1,1287 @@ +#pragma once + +std::vector flat_bucket_registered_temporal_fused_backward_program_validate_cuda( + const at::Tensor& primitive_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& forward_handler_rows, + const at::Tensor& reverse_handler_rows, + const at::Tensor& native_strategy_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& memory_liveness_rows, + int64_t schema_version) { + return validate_registered_temporal_fused_program( + primitive_rows, + forward_executor_rows, + reverse_executor_rows, + forward_handler_rows, + reverse_handler_rows, + native_strategy_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + forward_executor_binding_rows, + reverse_executor_binding_rows, + memory_liveness_rows, + schema_version); +} + +using ReverseReadoutProjectFn = std::vector (*)( + const at::Tensor& grad_cells_out, + const at::Tensor& output_msg, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredNativeStrategyRow& native_strategy, + const at::Tensor& graph_to_backend_order, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t input_count, + int64_t recurrent_count); + +using ReverseReadoutMessageFn = std::vector (*)( + const at::Tensor& grad_output_msg, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredNativeStrategyRow& native_strategy, + const at::Tensor& input_k, + const at::Tensor& input_v, + const at::Tensor& recurrent_k, + const at::Tensor& recurrent_v, + const at::Tensor& output_local_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& output_neighbor_idx, + const at::Tensor& output_neighbor_valid, + const at::Tensor& output_edge_distance, + const at::Tensor& output_edge_delay, + const at::Tensor& step_flat, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + double distance_scale, + bool use_sparse_messages, + bool use_delay); + +using ReverseMessageRecurrentKvFn = std::vector (*)( + const at::Tensor& grad_recurrent_k, + const at::Tensor& grad_recurrent_v, + const at::Tensor& recurrent_hidden_backend_order, + const at::Tensor& backend_to_graph_inverse_order, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredNativeStrategyRow& native_strategy, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t head_dim, + int64_t value_dim, + bool 
return_input_grad); + +using ReverseMessageForwardRecurrentKvFn = std::vector (*)( + const at::Tensor& input_k_reference, + const at::Tensor& recurrent_hidden_backend_order, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredNativeStrategyRow& native_strategy, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t head_dim, + int64_t value_dim); + +using ReverseMessageCarrierFn = std::vector (*)( + const at::Tensor& grad_recurrent_msg, + const at::Tensor& input_k, + const at::Tensor& input_v, + const at::Tensor& recurrent_k_before, + const at::Tensor& recurrent_v_before, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredNativeStrategyRow& native_strategy, + const at::Tensor& recurrent_local_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& recurrent_neighbor_idx, + const at::Tensor& recurrent_neighbor_valid, + const at::Tensor& recurrent_edge_distance, + const at::Tensor& recurrent_edge_delay, + const at::Tensor& step_flat, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + double distance_scale, + bool use_sparse_messages, + bool use_delay); + +using ReverseMessageBoundaryKvFn = std::vector (*)( + const at::Tensor& grad_input_k, + const at::Tensor& grad_input_v, + const at::Tensor& boundary_step, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredNativeStrategyRow& native_strategy, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t group_size, + int64_t head_dim, + int64_t value_dim, + bool return_boundary_grad); + +struct RegisteredReverseReadoutStrategy { + int64_t native_callable_hash; + const char* name; + ReverseReadoutProjectFn readout; + ReverseReadoutMessageFn output_message; +}; + +struct RegisteredReverseMessageStrategy { + int64_t native_callable_hash; + const char* name; + ReverseMessageRecurrentKvFn recurrent_kv; + ReverseMessageCarrierFn recurrent_message; + ReverseMessageRecurrentKvFn initial_recurrent_kv; + ReverseMessageBoundaryKvFn boundary_kv; + ReverseMessageForwardRecurrentKvFn recurrent_kv_forward_recompute; +}; + +#include "native_callables/message_reverse_strategies.cuh" +#include "native_callables/readout_reverse_strategies.cuh" + +#define REGISTERED_TEMPORAL_NATIVE_REVERSE_MESSAGE_CATALOG +#include "../flat_bucket_registered_native_callables.cuh" + +#define REGISTERED_TEMPORAL_NATIVE_REVERSE_READOUT_CATALOG +#include "../flat_bucket_registered_native_callables.cuh" + +inline const RegisteredReverseMessageStrategy& registered_reverse_message_strategy_for_native_row( + const RegisteredNativeStrategyRow& strategy) { + for (const RegisteredReverseMessageStrategy* item = registered_native_reverse_message_catalog_begin(); + item != registered_native_reverse_message_catalog_end(); + ++item) { + if (item->native_callable_hash == strategy.native_callable_hash) { + return *item; + } 
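    // No catalog entry has matched this native_callable_hash yet; keep scanning.
    // A compiler-emitted strategy row with an unknown hash fails closed at the
    // TORCH_CHECK below rather than falling back to the first catalog entry.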
+ } + TORCH_CHECK( + false, + "registered reverse message strategy has no native callable for compiler-emitted strategy row: native_callable_hash=", + strategy.native_callable_hash); + return *registered_native_reverse_message_catalog_begin(); +} + +inline const RegisteredReverseReadoutStrategy& registered_reverse_readout_strategy_for_native_row( + const RegisteredNativeStrategyRow& strategy) { + for (const RegisteredReverseReadoutStrategy* item = registered_native_reverse_readout_catalog_begin(); + item != registered_native_reverse_readout_catalog_end(); + ++item) { + if (item->native_callable_hash == strategy.native_callable_hash) { + return *item; + } + } + TORCH_CHECK( + false, + "registered reverse readout strategy has no native callable for compiler-emitted strategy row: native_callable_hash=", + strategy.native_callable_hash); + return *registered_native_reverse_readout_catalog_begin(); +} + +struct RegisteredReverseSpanStrategy { + RegisteredFusedProgramSpan span; + RegisteredNativeStrategyRow native_strategy; +}; + +inline std::vector registered_reverse_span_strategies_by_capability( + const at::Tensor& reverse_spans, + const at::Tensor& native_strategy_rows, + int64_t surface_opcode, + int64_t capability_flag, + const char* subject) { + const std::vector span_indices = registered_reverse_handler_span_indices_by_capability( + reverse_spans, + surface_opcode, + capability_flag, + subject); + std::vector spans; + spans.reserve(span_indices.size()); + for (const int64_t span_index : span_indices) { + RegisteredFusedProgramSpan span = registered_fused_program_span_at(reverse_spans, span_index); + spans.push_back({ + span, + registered_native_strategy_row_for_span( + native_strategy_rows, + kReverseDirectionOpcode, + span, + subject), + }); + } + return spans; +} + +inline bool registered_reverse_output_is_empty(const at::Tensor& tensor) { + return !tensor.defined() || tensor.numel() == 0; +} + +inline at::Tensor combine_registered_reverse_output_tensor( + const at::Tensor& accumulated, + const at::Tensor& update, + const char* subject, + int64_t slot) { + if (registered_reverse_output_is_empty(accumulated)) { + return update; + } + if (registered_reverse_output_is_empty(update)) { + return accumulated; + } + TORCH_CHECK( + accumulated.sizes() == update.sizes(), + subject, + " produced incompatible compiler-routed output shapes at slot ", + slot, + ": accumulated=", + accumulated.sizes(), + "; update=", + update.sizes()); + TORCH_CHECK( + accumulated.scalar_type() == update.scalar_type(), + subject, + " produced incompatible compiler-routed output dtypes at slot ", + slot); + if ( + accumulated.scalar_type() == at::kLong && + accumulated.device().is_cpu() && + update.device().is_cpu() && + accumulated.numel() == 1) { + at::Tensor merged = at::empty_like(accumulated); + merged.data_ptr()[0] = std::max( + accumulated.reshape({-1}).select(0, 0).item(), + update.reshape({-1}).select(0, 0).item()); + return merged; + } + TORCH_CHECK( + accumulated.device() == update.device(), + subject, + " produced incompatible compiler-routed output devices at slot ", + slot); + return accumulated + update; +} + +inline std::vector combine_registered_reverse_span_outputs( + std::vector accumulated, + const std::vector& update, + const char* subject) { + if (accumulated.empty()) { + return update; + } + TORCH_CHECK( + accumulated.size() == update.size(), + subject, + " produced incompatible compiler-routed output arity: accumulated=", + static_cast(accumulated.size()), + "; update=", + 
static_cast(update.size())); + for (size_t slot = 0; slot < accumulated.size(); ++slot) { + accumulated[slot] = combine_registered_reverse_output_tensor( + accumulated[slot], + update[slot], + subject, + static_cast(slot)); + } + return accumulated; +} + +static std::vector registered_temporal_backward_readout_step_impl( + const at::Tensor& grad_cells_out, + std::vector reverse_artifact_tensors, + const at::Tensor& reverse_artifact_binding_rows, + const at::Tensor& reverse_artifact_role_rows, + const at::Tensor& reverse_artifact_access_rows, + const at::Tensor& forward_artifact_route_rows, + const at::Tensor& reverse_artifact_consumer_route_rows, + const at::Tensor& forward_artifact_merge_rows, + const at::Tensor& forward_output_route_rows, + std::vector program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& primitive_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& forward_handler_rows, + const at::Tensor& reverse_handler_rows, + const at::Tensor& native_strategy_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& memory_liveness_rows, + const at::Tensor& graph_to_backend_order, + int64_t local_step, + int64_t input_count, + int64_t recurrent_count, + int64_t schema_version, + std::vector>* span_output_groups = nullptr) { + check_cuda_float_bank(grad_cells_out, "grad_cells_out"); + check_program_tensor_binding_rows(program_tensor_binding_rows); + check_reverse_program_access_rows(reverse_program_access_rows); + const int64_t window_len = reverse_artifact_binding_window_len(reverse_artifact_binding_rows); + TORCH_CHECK(local_step >= 0 && local_step < window_len, "fused backward readout step has invalid local_step"); + const std::vector tensor_required = + validate_temporal_reverse_artifact_role_rows(reverse_artifact_role_rows); + validate_temporal_reverse_artifact_access_rows(reverse_artifact_access_rows, tensor_required); + validate_temporal_reverse_artifact_binding_rows( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + tensor_required, + window_len); + std::vector decoded = validate_registered_temporal_fused_program( + primitive_rows, + forward_executor_rows, + reverse_executor_rows, + forward_handler_rows, + reverse_handler_rows, + native_strategy_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + forward_executor_binding_rows, + reverse_executor_binding_rows, + memory_liveness_rows, + schema_version); + + const int64_t output_count = grad_cells_out.size(1) - input_count - recurrent_count; + TORCH_CHECK(output_count > 0, "fused backward readout program requires positive output_count"); + validate_registered_fused_backward_output_grad_reverse_span(decoded[2], output_count); + validate_forward_output_route_rows(forward_output_route_rows, schema_version); + + const at::Tensor& reverse_spans = decoded[2]; + const int64_t* forward_artifact_rows = forward_artifact_route_rows.data_ptr(); + const int64_t* output_route_rows = forward_output_route_rows.data_ptr(); + const int64_t output_route_kind = output_route_rows[1]; + bool direct_carry_claimed = false; + std::vector outputs; + for (const RegisteredReverseSpanStrategy& readout_span : registered_reverse_span_strategies_by_capability( + reverse_spans, + native_strategy_rows, + 
kReadoutSurfaceOpcode, + kReverseHandlerReadoutFlag, + "fused backward readout program")) { + const RegisteredReverseReadoutStrategy& strategy = + registered_reverse_readout_strategy_for_native_row(readout_span.native_strategy); + at::Tensor output_msg = reverse_artifact_tensor_for_routed_access_step( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_access_rows, + forward_artifact_route_rows, + reverse_artifact_consumer_route_rows, + kReadoutSurfaceOpcode, + readout_span.span.executor_row_index, + readout_span.span.executor_id, + readout_span.span.bucket_ordinal, + kReverseArtifactAccessOutputMsg, + local_step, + schema_version, + "fused backward readout program output_msg"); + check_cuda_float_bank(output_msg, "fused backward readout output_msg"); + const int64_t output_msg_role_id = reverse_artifact_role_for_access( + reverse_artifact_access_rows, + kReverseArtifactAccessOutputMsg, + "fused backward readout program output route"); + const int64_t forward_artifact_route_row = reverse_artifact_consumer_forward_route_row_for( + reverse_artifact_consumer_route_rows, + forward_artifact_route_rows, + kReadoutSurfaceOpcode, + readout_span.span.executor_row_index, + readout_span.span.executor_id, + readout_span.span.bucket_ordinal, + output_msg_role_id, + schema_version, + "fused backward readout program output route"); + const int64_t* forward_route = + forward_artifact_rows + forward_artifact_route_row * kForwardArtifactRouteRowColumns; + const int64_t output_route_row = forward_output_route_row_for_readout_executor( + forward_output_route_rows, + forward_route[2], + forward_route[3], + forward_route[4], + schema_version, + "fused backward readout program"); + const int64_t* output_route = output_route_rows + output_route_row * kForwardOutputRouteRowColumns; + const int64_t route_output_count = output_msg.size(1); + at::Tensor span_grad_cells_out = at::zeros( + {grad_cells_out.size(0), input_count + recurrent_count + route_output_count, grad_cells_out.size(2)}, + grad_cells_out.options()); + if (!direct_carry_claimed) { + if (input_count + recurrent_count > 0) { + span_grad_cells_out.slice(1, 0, input_count + recurrent_count) + .copy_(grad_cells_out.slice(1, 0, input_count + recurrent_count)); + } + direct_carry_claimed = true; + } + const int64_t aggregate_output_start = input_count + recurrent_count; + const int64_t route_output_offset = output_route[9]; + if (output_route_kind == kForwardOutputRouteReadoutOutputConcat) { + TORCH_CHECK( + route_output_offset + route_output_count <= output_count, + "fused backward readout concat route exceeds aggregate output grad"); + span_grad_cells_out.slice(1, input_count + recurrent_count, input_count + recurrent_count + route_output_count) + .copy_(grad_cells_out.slice( + 1, + aggregate_output_start + route_output_offset, + aggregate_output_start + route_output_offset + route_output_count)); + } else { + TORCH_CHECK( + route_output_offset == 0, + "fused backward readout non-concat route must use zero compiler output offset"); + TORCH_CHECK( + route_output_count == output_count, + "fused backward readout non-concat route output count must match aggregate output grad"); + span_grad_cells_out.slice(1, input_count + recurrent_count, input_count + recurrent_count + route_output_count) + .copy_(grad_cells_out.slice(1, aggregate_output_start, aggregate_output_start + output_count)); + } + std::vector span_outputs = + strategy.readout( + span_grad_cells_out, + output_msg, + program_tensors, + program_tensor_binding_rows, + 
reverse_program_access_rows, + native_callable_binding_schema_rows, + readout_span.native_strategy, + graph_to_backend_order, + reverse_executor_rows, + reverse_executor_binding_rows, + readout_span.span, + input_count, + recurrent_count); + if (span_output_groups != nullptr) { + span_output_groups->push_back(span_outputs); + } + outputs = combine_registered_reverse_span_outputs( + std::move(outputs), + span_outputs, + "fused backward readout program"); + } + return outputs; +} + +static std::vector registered_temporal_backward_output_message_step_impl( + const at::Tensor& grad_output_msg, + std::vector reverse_artifact_tensors, + const at::Tensor& reverse_artifact_binding_rows, + const at::Tensor& reverse_artifact_role_rows, + const at::Tensor& reverse_artifact_access_rows, + const at::Tensor& forward_artifact_route_rows, + const at::Tensor& reverse_artifact_consumer_route_rows, + const at::Tensor& forward_artifact_merge_rows, + std::vector program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& primitive_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& forward_handler_rows, + const at::Tensor& reverse_handler_rows, + const at::Tensor& native_strategy_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& memory_liveness_rows, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& output_local_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& output_neighbor_idx, + const at::Tensor& output_neighbor_valid, + const at::Tensor& output_edge_distance, + const at::Tensor& output_edge_delay, + int64_t local_step, + int64_t message_step_index, + double distance_scale, + bool use_sparse_messages, + bool use_delay, + int64_t schema_version, + std::vector>* span_output_groups = nullptr) { + check_cuda_float_bank(grad_output_msg, "grad_output_msg"); + check_program_tensor_binding_rows(program_tensor_binding_rows); + check_reverse_program_access_rows(reverse_program_access_rows); + const int64_t window_len = reverse_artifact_binding_window_len(reverse_artifact_binding_rows); + TORCH_CHECK(local_step >= 0 && local_step < window_len, "fused backward output-message step has invalid local_step"); + TORCH_CHECK(message_step_index >= 0, "fused backward output-message step requires non-negative message_step_index"); + const std::vector tensor_required = + validate_temporal_reverse_artifact_role_rows(reverse_artifact_role_rows); + validate_temporal_reverse_artifact_access_rows(reverse_artifact_access_rows, tensor_required); + validate_temporal_reverse_artifact_binding_rows( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + tensor_required, + window_len); + std::vector decoded = validate_registered_temporal_fused_program( + primitive_rows, + forward_executor_rows, + reverse_executor_rows, + forward_handler_rows, + reverse_handler_rows, + native_strategy_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + forward_executor_binding_rows, + reverse_executor_binding_rows, + memory_liveness_rows, + schema_version); + + const at::Tensor& reverse_spans = decoded[2]; + const std::vector message_artifact_spans = + registered_reverse_span_strategies_by_capability( + 
reverse_spans, + native_strategy_rows, + kMessageSurfaceOpcode, + kReverseHandlerMessageCarrierFlag, + "fused backward output-message program message artifacts"); + validate_forward_artifact_merge_rows(forward_artifact_merge_rows, forward_artifact_route_rows, schema_version); + validate_reverse_artifact_consumer_route_rows( + reverse_artifact_consumer_route_rows, + forward_artifact_route_rows, + schema_version); + const int64_t message_input_k_role = reverse_artifact_role_for_access( + reverse_artifact_access_rows, + kReverseArtifactAccessInputK, + "fused backward output-message input_k route"); + const RegisteredReverseSpanStrategy* message_artifact_strategy_ptr = nullptr; + for (const RegisteredReverseSpanStrategy& message_span : message_artifact_spans) { + const int64_t forward_route_row = try_reverse_artifact_consumer_forward_route_row_for( + reverse_artifact_consumer_route_rows, + kMessageSurfaceOpcode, + message_span.span.executor_row_index, + message_span.span.executor_id, + message_span.span.bucket_ordinal, + message_input_k_role, + "fused backward output-message program message artifacts"); + if (forward_route_row >= 0) { + TORCH_CHECK( + message_artifact_strategy_ptr == nullptr, + "fused backward output-message program has duplicate reverse message spans for compiler consumer route"); + message_artifact_strategy_ptr = &message_span; + } + } + TORCH_CHECK( + message_artifact_strategy_ptr != nullptr, + "fused backward output-message program has no reverse message span for selected compiler consumer route"); + const RegisteredReverseSpanStrategy& message_artifact_strategy = *message_artifact_strategy_ptr; + const RegisteredFusedProgramSpan& message_artifact_span = message_artifact_strategy.span; + at::Tensor input_k = reverse_artifact_tensor_for_routed_access_step( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_access_rows, + forward_artifact_route_rows, + reverse_artifact_consumer_route_rows, + kMessageSurfaceOpcode, + message_artifact_span.executor_row_index, + message_artifact_span.executor_id, + message_artifact_span.bucket_ordinal, + kReverseArtifactAccessInputK, + local_step, + schema_version, + "fused backward output-message input_k"); + at::Tensor input_v = reverse_artifact_tensor_for_routed_access_step( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_access_rows, + forward_artifact_route_rows, + reverse_artifact_consumer_route_rows, + kMessageSurfaceOpcode, + message_artifact_span.executor_row_index, + message_artifact_span.executor_id, + message_artifact_span.bucket_ordinal, + kReverseArtifactAccessInputV, + local_step, + schema_version, + "fused backward output-message input_v"); + at::Tensor recurrent_k = try_reverse_artifact_tensor_for_routed_access_step( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_access_rows, + forward_artifact_route_rows, + reverse_artifact_consumer_route_rows, + kMessageSurfaceOpcode, + message_artifact_span.executor_row_index, + message_artifact_span.executor_id, + message_artifact_span.bucket_ordinal, + kReverseArtifactAccessRecurrentK, + local_step, + "fused backward output-message recurrent_k"); + at::Tensor recurrent_v = try_reverse_artifact_tensor_for_routed_access_step( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_access_rows, + forward_artifact_route_rows, + reverse_artifact_consumer_route_rows, + kMessageSurfaceOpcode, + message_artifact_span.executor_row_index, + message_artifact_span.executor_id, + 
message_artifact_span.bucket_ordinal, + kReverseArtifactAccessRecurrentV, + local_step, + "fused backward output-message recurrent_v"); + if (!recurrent_k.defined() || !recurrent_v.defined()) { + at::Tensor recurrent_hidden_backend_order = reverse_artifact_tensor_for_routed_access_step( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_access_rows, + forward_artifact_route_rows, + reverse_artifact_consumer_route_rows, + kMessageSurfaceOpcode, + message_artifact_span.executor_row_index, + message_artifact_span.executor_id, + message_artifact_span.bucket_ordinal, + kReverseArtifactAccessRecurrentHiddenBackendOrder, + local_step, + schema_version, + "fused backward output-message recurrent hidden for K/V recompute"); + const RegisteredReverseMessageStrategy& message_strategy = + registered_reverse_message_strategy_for_native_row(message_artifact_strategy.native_strategy); + std::vector recurrent_kv = message_strategy.recurrent_kv_forward_recompute( + input_k, + recurrent_hidden_backend_order, + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + message_artifact_strategy.native_strategy, + forward_executor_rows, + forward_executor_binding_rows, + message_artifact_span, + input_k.size(2), + grad_output_msg.size(2)); + TORCH_CHECK( + recurrent_kv.size() == 2, + "fused backward output-message recurrent K/V recompute returned an invalid output count"); + recurrent_k = recurrent_kv[0]; + recurrent_v = recurrent_kv[1]; + } + at::Tensor step_flat = registered_runtime_buffer_for_role( + runtime_buffer_tensors, + runtime_buffer_rows, + kRuntimeBufferRoleReverseMessageStepFlat, + 0, + {grad_output_msg.size(0)}, + "fused backward output-message step workspace"); + step_flat.fill_(message_step_index); + std::vector outputs; + for (const RegisteredReverseSpanStrategy& readout_span : registered_reverse_span_strategies_by_capability( + reverse_spans, + native_strategy_rows, + kReadoutSurfaceOpcode, + kReverseHandlerReadoutFlag, + "fused backward output-message program")) { + const RegisteredReverseReadoutStrategy& strategy = + registered_reverse_readout_strategy_for_native_row(readout_span.native_strategy); + std::vector span_outputs = + strategy.output_message( + grad_output_msg, + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + readout_span.native_strategy, + input_k, + input_v, + recurrent_k, + recurrent_v, + output_local_sender_idx, + local_distance, + local_delay, + output_neighbor_idx, + output_neighbor_valid, + output_edge_distance, + output_edge_delay, + step_flat, + reverse_executor_rows, + reverse_executor_binding_rows, + readout_span.span, + distance_scale, + use_sparse_messages, + use_delay); + if (span_output_groups != nullptr) { + span_output_groups->push_back(span_outputs); + } + outputs = combine_registered_reverse_span_outputs( + std::move(outputs), + span_outputs, + "fused backward output-message program"); + } + return outputs; +} + +static std::vector registered_temporal_backward_recurrent_kv_projection_step_impl( + const at::Tensor& grad_recurrent_k, + const at::Tensor& grad_recurrent_v, + std::vector reverse_artifact_tensors, + const at::Tensor& reverse_artifact_binding_rows, + const at::Tensor& reverse_artifact_role_rows, + const at::Tensor& reverse_artifact_access_rows, + const at::Tensor& forward_artifact_route_rows, + const at::Tensor& reverse_artifact_consumer_route_rows, + const at::Tensor& 
forward_artifact_merge_rows, + std::vector program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& primitive_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& forward_handler_rows, + const at::Tensor& reverse_handler_rows, + const at::Tensor& native_strategy_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& memory_liveness_rows, + const at::Tensor& backend_to_graph_inverse_order, + int64_t local_step, + int64_t head_dim, + int64_t value_dim, + int64_t schema_version, + std::vector>* span_output_groups = nullptr) { + check_cuda_float_bank(grad_recurrent_k, "grad_recurrent_k"); + check_cuda_float_bank(grad_recurrent_v, "grad_recurrent_v"); + check_cuda_long_rank1(backend_to_graph_inverse_order, "backend_to_graph_inverse_order"); + check_program_tensor_binding_rows(program_tensor_binding_rows); + check_reverse_program_access_rows(reverse_program_access_rows); + TORCH_CHECK(head_dim > 0 && value_dim > 0, "fused recurrent K/V projection requires positive K/V dimensions"); + const int64_t window_len = reverse_artifact_binding_window_len(reverse_artifact_binding_rows); + TORCH_CHECK(local_step >= 0 && local_step < window_len, "fused recurrent K/V projection step has invalid local_step"); + const std::vector tensor_required = + validate_temporal_reverse_artifact_role_rows(reverse_artifact_role_rows); + validate_temporal_reverse_artifact_access_rows(reverse_artifact_access_rows, tensor_required); + validate_temporal_reverse_artifact_binding_rows( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + tensor_required, + window_len); + std::vector decoded = validate_registered_temporal_fused_program( + primitive_rows, + forward_executor_rows, + reverse_executor_rows, + forward_handler_rows, + reverse_handler_rows, + native_strategy_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + forward_executor_binding_rows, + reverse_executor_binding_rows, + memory_liveness_rows, + schema_version); + + const at::Tensor& reverse_spans = decoded[2]; + std::vector outputs; + for (const RegisteredReverseSpanStrategy& message_span : registered_reverse_span_strategies_by_capability( + reverse_spans, + native_strategy_rows, + kMessageSurfaceOpcode, + kReverseHandlerMessageCarrierFlag, + "fused recurrent K/V projection program")) { + const RegisteredReverseMessageStrategy& strategy = + registered_reverse_message_strategy_for_native_row(message_span.native_strategy); + at::Tensor recurrent_hidden_backend_order = reverse_artifact_tensor_for_routed_access_step( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_access_rows, + forward_artifact_route_rows, + reverse_artifact_consumer_route_rows, + kMessageSurfaceOpcode, + message_span.span.executor_row_index, + message_span.span.executor_id, + message_span.span.bucket_ordinal, + kReverseArtifactAccessRecurrentHiddenBackendOrder, + local_step, + schema_version, + "fused recurrent K/V projection recurrent_hidden_backend_order"); + check_cuda_float_bank( + recurrent_hidden_backend_order, + "fused recurrent K/V projection recurrent_hidden_backend_order"); + std::vector span_outputs = + strategy.recurrent_kv( + grad_recurrent_k, + grad_recurrent_v, + recurrent_hidden_backend_order, + 
backend_to_graph_inverse_order, + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + message_span.native_strategy, + reverse_executor_rows, + reverse_executor_binding_rows, + message_span.span, + head_dim, + value_dim, + true); + if (span_output_groups != nullptr) { + span_output_groups->push_back(span_outputs); + } + outputs = combine_registered_reverse_span_outputs( + std::move(outputs), + span_outputs, + "fused recurrent K/V projection program"); + } + return outputs; +} + +static std::vector registered_temporal_backward_recurrent_message_step_impl( + const at::Tensor& grad_recurrent_msg, + std::vector reverse_artifact_tensors, + const at::Tensor& reverse_artifact_binding_rows, + const at::Tensor& reverse_artifact_role_rows, + const at::Tensor& reverse_artifact_access_rows, + const at::Tensor& forward_artifact_route_rows, + const at::Tensor& reverse_artifact_consumer_route_rows, + const at::Tensor& forward_artifact_merge_rows, + std::vector program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& primitive_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& forward_handler_rows, + const at::Tensor& reverse_handler_rows, + const at::Tensor& native_strategy_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& memory_liveness_rows, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& recurrent_local_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& recurrent_neighbor_idx, + const at::Tensor& recurrent_neighbor_valid, + const at::Tensor& recurrent_edge_distance, + const at::Tensor& recurrent_edge_delay, + int64_t local_step, + int64_t message_step_index, + double distance_scale, + bool use_sparse_messages, + bool use_delay, + int64_t schema_version, + std::vector>* span_output_groups = nullptr) { + check_cuda_float_bank(grad_recurrent_msg, "grad_recurrent_msg"); + check_program_tensor_binding_rows(program_tensor_binding_rows); + check_reverse_program_access_rows(reverse_program_access_rows); + const int64_t window_len = reverse_artifact_binding_window_len(reverse_artifact_binding_rows); + TORCH_CHECK(local_step >= 0 && local_step < window_len, "fused backward recurrent-message step has invalid local_step"); + TORCH_CHECK(message_step_index >= 0, "fused backward recurrent-message step requires non-negative message_step_index"); + const std::vector tensor_required = + validate_temporal_reverse_artifact_role_rows(reverse_artifact_role_rows); + validate_temporal_reverse_artifact_access_rows(reverse_artifact_access_rows, tensor_required); + validate_temporal_reverse_artifact_binding_rows( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + tensor_required, + window_len); + std::vector decoded = validate_registered_temporal_fused_program( + primitive_rows, + forward_executor_rows, + reverse_executor_rows, + forward_handler_rows, + reverse_handler_rows, + native_strategy_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + forward_executor_binding_rows, + reverse_executor_binding_rows, + memory_liveness_rows, + schema_version); + + const at::Tensor& reverse_spans = decoded[2]; + 
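  // The per-step reverse message index is staged through the compiler-owned
  // kRuntimeBufferRoleReverseMessageStepFlat workspace: the buffer is looked up by
  // role, sized to one entry per batch row of grad_recurrent_msg, and filled with
  // message_step_index before any reverse message span runs.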
at::Tensor step_flat = registered_runtime_buffer_for_role( + runtime_buffer_tensors, + runtime_buffer_rows, + kRuntimeBufferRoleReverseMessageStepFlat, + 0, + {grad_recurrent_msg.size(0)}, + "fused backward recurrent-message step workspace"); + step_flat.fill_(message_step_index); + std::vector outputs; + for (const RegisteredReverseSpanStrategy& message_span : registered_reverse_span_strategies_by_capability( + reverse_spans, + native_strategy_rows, + kMessageSurfaceOpcode, + kReverseHandlerMessageCarrierFlag, + "fused recurrent-message program")) { + const RegisteredReverseMessageStrategy& strategy = + registered_reverse_message_strategy_for_native_row(message_span.native_strategy); + at::Tensor input_k = reverse_artifact_tensor_for_routed_access_step( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_access_rows, + forward_artifact_route_rows, + reverse_artifact_consumer_route_rows, + kMessageSurfaceOpcode, + message_span.span.executor_row_index, + message_span.span.executor_id, + message_span.span.bucket_ordinal, + kReverseArtifactAccessInputK, + local_step, + schema_version, + "fused recurrent-message input_k"); + at::Tensor input_v = reverse_artifact_tensor_for_routed_access_step( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_access_rows, + forward_artifact_route_rows, + reverse_artifact_consumer_route_rows, + kMessageSurfaceOpcode, + message_span.span.executor_row_index, + message_span.span.executor_id, + message_span.span.bucket_ordinal, + kReverseArtifactAccessInputV, + local_step, + schema_version, + "fused recurrent-message input_v"); + at::Tensor recurrent_k_before = try_reverse_artifact_tensor_for_routed_access_step( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_access_rows, + forward_artifact_route_rows, + reverse_artifact_consumer_route_rows, + kMessageSurfaceOpcode, + message_span.span.executor_row_index, + message_span.span.executor_id, + message_span.span.bucket_ordinal, + kReverseArtifactAccessRecurrentKBefore, + local_step, + "fused recurrent-message recurrent_k_before"); + at::Tensor recurrent_v_before = try_reverse_artifact_tensor_for_routed_access_step( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_access_rows, + forward_artifact_route_rows, + reverse_artifact_consumer_route_rows, + kMessageSurfaceOpcode, + message_span.span.executor_row_index, + message_span.span.executor_id, + message_span.span.bucket_ordinal, + kReverseArtifactAccessRecurrentVBefore, + local_step, + "fused recurrent-message recurrent_v_before"); + if (!recurrent_k_before.defined() || !recurrent_v_before.defined()) { + at::Tensor recurrent_hidden_before_backend_order = reverse_artifact_tensor_for_routed_access_step( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_access_rows, + forward_artifact_route_rows, + reverse_artifact_consumer_route_rows, + kMessageSurfaceOpcode, + message_span.span.executor_row_index, + message_span.span.executor_id, + message_span.span.bucket_ordinal, + kReverseArtifactAccessRecurrentHiddenBeforeBackendOrder, + local_step, + schema_version, + "fused recurrent-message recurrent hidden-before for K/V recompute"); + std::vector recurrent_kv = strategy.recurrent_kv_forward_recompute( + input_k, + recurrent_hidden_before_backend_order, + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + message_span.native_strategy, + forward_executor_rows, + 
forward_executor_binding_rows, + message_span.span, + input_k.size(2), + input_v.size(2)); + TORCH_CHECK( + recurrent_kv.size() == 2, + "fused recurrent-message recurrent K/V-before recompute returned an invalid output count"); + recurrent_k_before = recurrent_kv[0]; + recurrent_v_before = recurrent_kv[1]; + } + std::vector span_outputs = + strategy.recurrent_message( + grad_recurrent_msg, + input_k, + input_v, + recurrent_k_before, + recurrent_v_before, + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + message_span.native_strategy, + recurrent_local_sender_idx, + local_distance, + local_delay, + recurrent_neighbor_idx, + recurrent_neighbor_valid, + recurrent_edge_distance, + recurrent_edge_delay, + step_flat, + reverse_executor_rows, + reverse_executor_binding_rows, + message_span.span, + distance_scale, + use_sparse_messages, + use_delay); + if (span_output_groups != nullptr) { + span_output_groups->push_back(span_outputs); + } + outputs = combine_registered_reverse_span_outputs( + std::move(outputs), + span_outputs, + "fused recurrent-message program"); + } + return outputs; +} + +static std::vector registered_temporal_backward_initial_recurrent_kv_projection_step_impl( + const at::Tensor& grad_recurrent_k_before, + const at::Tensor& grad_recurrent_v_before, + std::vector reverse_artifact_tensors, + const at::Tensor& reverse_artifact_binding_rows, + const at::Tensor& reverse_artifact_role_rows, + const at::Tensor& reverse_artifact_access_rows, + const at::Tensor& forward_artifact_route_rows, + const at::Tensor& reverse_artifact_consumer_route_rows, + const at::Tensor& forward_artifact_merge_rows, + std::vector program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& primitive_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& forward_handler_rows, + const at::Tensor& reverse_handler_rows, + const at::Tensor& native_strategy_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& memory_liveness_rows, + const at::Tensor& backend_to_graph_inverse_order, + int64_t local_step, + int64_t head_dim, + int64_t value_dim, + bool return_hidden_grad, + int64_t schema_version, + std::vector>* span_output_groups = nullptr) { + check_cuda_float_bank(grad_recurrent_k_before, "grad_recurrent_k_before"); + check_cuda_float_bank(grad_recurrent_v_before, "grad_recurrent_v_before"); + check_cuda_long_rank1(backend_to_graph_inverse_order, "backend_to_graph_inverse_order"); + check_program_tensor_binding_rows(program_tensor_binding_rows); + check_reverse_program_access_rows(reverse_program_access_rows); + TORCH_CHECK(head_dim > 0 && value_dim > 0, "fused initial recurrent K/V projection requires positive K/V dimensions"); + const int64_t window_len = reverse_artifact_binding_window_len(reverse_artifact_binding_rows); + TORCH_CHECK( + local_step >= 0 && local_step < window_len, + "fused initial recurrent K/V projection step has invalid local_step"); + const std::vector tensor_required = + validate_temporal_reverse_artifact_role_rows(reverse_artifact_role_rows); + validate_temporal_reverse_artifact_access_rows(reverse_artifact_access_rows, tensor_required); + validate_temporal_reverse_artifact_binding_rows( + 
reverse_artifact_tensors, + reverse_artifact_binding_rows, + tensor_required, + window_len); + std::vector decoded = validate_registered_temporal_fused_program( + primitive_rows, + forward_executor_rows, + reverse_executor_rows, + forward_handler_rows, + reverse_handler_rows, + native_strategy_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + forward_executor_binding_rows, + reverse_executor_binding_rows, + memory_liveness_rows, + schema_version); + + const at::Tensor& reverse_spans = decoded[2]; + std::vector outputs; + for (const RegisteredReverseSpanStrategy& message_span : registered_reverse_span_strategies_by_capability( + reverse_spans, + native_strategy_rows, + kMessageSurfaceOpcode, + kReverseHandlerMessageCarrierFlag, + "fused initial recurrent K/V projection")) { + const RegisteredReverseMessageStrategy& strategy = + registered_reverse_message_strategy_for_native_row(message_span.native_strategy); + at::Tensor recurrent_hidden_before_backend_order = reverse_artifact_tensor_for_routed_access_step( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_access_rows, + forward_artifact_route_rows, + reverse_artifact_consumer_route_rows, + kMessageSurfaceOpcode, + message_span.span.executor_row_index, + message_span.span.executor_id, + message_span.span.bucket_ordinal, + kReverseArtifactAccessRecurrentHiddenBeforeBackendOrder, + local_step, + schema_version, + "fused initial recurrent K/V projection recurrent_hidden_before_backend_order"); + check_cuda_float_bank( + recurrent_hidden_before_backend_order, + "fused initial recurrent K/V projection recurrent_hidden_before_backend_order"); + std::vector span_outputs = + strategy.initial_recurrent_kv( + grad_recurrent_k_before, + grad_recurrent_v_before, + recurrent_hidden_before_backend_order, + backend_to_graph_inverse_order, + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + message_span.native_strategy, + reverse_executor_rows, + reverse_executor_binding_rows, + message_span.span, + head_dim, + value_dim, + return_hidden_grad); + if (span_output_groups != nullptr) { + span_output_groups->push_back(span_outputs); + } + outputs = combine_registered_reverse_span_outputs( + std::move(outputs), + span_outputs, + "fused initial recurrent K/V projection"); + } + return outputs; +} + +static std::vector registered_temporal_backward_boundary_kv_projection_step_impl( + const at::Tensor& grad_input_k, + const at::Tensor& grad_input_v, + std::vector reverse_artifact_tensors, + const at::Tensor& reverse_artifact_binding_rows, + const at::Tensor& reverse_artifact_role_rows, + const at::Tensor& reverse_artifact_access_rows, + const at::Tensor& forward_artifact_route_rows, + const at::Tensor& reverse_artifact_consumer_route_rows, + const at::Tensor& forward_artifact_merge_rows, + std::vector program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& primitive_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& forward_handler_rows, + const at::Tensor& reverse_handler_rows, + const at::Tensor& native_strategy_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& memory_liveness_rows, + int64_t local_step, + int64_t 
group_size, + int64_t head_dim, + int64_t value_dim, + bool return_boundary_grad, + int64_t schema_version, + std::vector>* span_output_groups = nullptr) { + check_cuda_float_bank(grad_input_k, "grad_input_k"); + check_cuda_float_bank(grad_input_v, "grad_input_v"); + check_program_tensor_binding_rows(program_tensor_binding_rows); + check_reverse_program_access_rows(reverse_program_access_rows); + TORCH_CHECK(group_size > 0, "fused boundary K/V projection requires positive group_size"); + TORCH_CHECK(head_dim > 0 && value_dim > 0, "fused boundary K/V projection requires positive K/V dimensions"); + const int64_t window_len = reverse_artifact_binding_window_len(reverse_artifact_binding_rows); + TORCH_CHECK(local_step >= 0 && local_step < window_len, "fused boundary K/V projection step has invalid local_step"); + const std::vector tensor_required = + validate_temporal_reverse_artifact_role_rows(reverse_artifact_role_rows); + validate_temporal_reverse_artifact_access_rows(reverse_artifact_access_rows, tensor_required); + validate_forward_artifact_route_rows(forward_artifact_route_rows, schema_version); + validate_temporal_reverse_artifact_binding_rows( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + tensor_required, + window_len); + std::vector decoded = validate_registered_temporal_fused_program( + primitive_rows, + forward_executor_rows, + reverse_executor_rows, + forward_handler_rows, + reverse_handler_rows, + native_strategy_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + forward_executor_binding_rows, + reverse_executor_binding_rows, + memory_liveness_rows, + schema_version); + + const at::Tensor& reverse_spans = decoded[2]; + at::Tensor boundary_step = reverse_artifact_tensor_for_access_step( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_access_rows, + kReverseArtifactAccessBoundaryStep, + local_step, + "fused boundary K/V projection boundary_step"); + check_cuda_float_bank(boundary_step, "fused boundary K/V projection boundary_step"); + std::vector outputs; + for (const RegisteredReverseSpanStrategy& message_span : registered_reverse_span_strategies_by_capability( + reverse_spans, + native_strategy_rows, + kMessageSurfaceOpcode, + kReverseHandlerMessageCarrierFlag, + "fused boundary K/V projection")) { + const RegisteredReverseMessageStrategy& strategy = + registered_reverse_message_strategy_for_native_row(message_span.native_strategy); + std::vector backward = strategy.boundary_kv( + grad_input_k, + grad_input_v, + boundary_step, + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + message_span.native_strategy, + reverse_executor_rows, + reverse_executor_binding_rows, + message_span.span, + group_size, + head_dim, + value_dim, + return_boundary_grad); + TORCH_CHECK( + backward.size() >= 2, + "fused boundary K/V projection strategy returned too few compiler outputs"); + at::Tensor grouped_flag = at::empty({1}, at::TensorOptions().dtype(at::kLong).device(at::kCPU)); + grouped_flag.data_ptr()[0] = backward.size() > 2 && backward[2].numel() > 0 + ? 
backward[2].reshape({-1}).select(0, 0).item() + : 0; + std::vector span_outputs = {backward[0], backward[1], grouped_flag}; + if (span_output_groups != nullptr) { + span_output_groups->push_back(span_outputs); + } + outputs = combine_registered_reverse_span_outputs( + std::move(outputs), + span_outputs, + "fused boundary K/V projection"); + } + return outputs; +} diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/common.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/common.cuh new file mode 100644 index 00000000..b8839126 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/common.cuh @@ -0,0 +1,32 @@ +#pragma once + +#include +#include + +#include "cortical/fabric/backend/cuda/ops/dense_affine.cuh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +#include "constants_and_checks.cuh" +#include "artifact_routes.cuh" +#include "executor_span_decode.cuh" +#include "memory_runtime_buffers.cuh" +#include "native_callable_bindings.cuh" +#include "program_spans_and_handlers.cuh" +#include "reverse_artifacts_and_resets.cuh" +#include "program_tensor_access.cuh" +#include "transition_device_kernels.cuh" +#include "transition_math_helpers.cuh" +#include "layout_kernels.cuh" + +} // namespace diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/constants_and_checks.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/constants_and_checks.cuh new file mode 100644 index 00000000..784fc4d1 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/constants_and_checks.cuh @@ -0,0 +1,384 @@ +#pragma once + +constexpr int kThreadsPerBlock = 256; +constexpr int64_t kForwardDirectionOpcode = 1; +constexpr int64_t kReverseDirectionOpcode = 2; +constexpr int64_t kInputBindingKindOpcode = 0; +constexpr int64_t kParameterBindingKindOpcode = 1; +constexpr int64_t kOutputBindingKindOpcode = 2; +constexpr int64_t kMessageSurfaceOpcode = 1; +constexpr int64_t kReadoutSurfaceOpcode = 2; +constexpr int64_t kTransitionSurfaceOpcode = 4; +constexpr int64_t kRuntimePolicySurfaceOpcode = 6; +constexpr int64_t kTemporalMessageBucketOrdinal = -1; +constexpr int64_t kTemporalReadoutBucketOrdinal = -2; +constexpr int64_t kTemporalParameterReductionBucketOrdinal = -3; +constexpr int64_t kProgramAccessTransitionAggregatedMessageInput = 8; +constexpr int64_t kProgramAccessTransitionPublicStateOutput = 9; +constexpr int64_t kPrimitiveGatedLogspaceRecurrenceOpcode = 1; +constexpr int64_t kPrimitiveDiagRtuOpcode = 2; +constexpr int64_t kPrimitiveLinearOpcode = 10; +constexpr int64_t kPrimitiveMatmulOpcode = 11; +constexpr int64_t kPrimitiveTanhOpcode = 16; +constexpr int64_t kPrimitiveReadoutProjectOpcode = 20; +constexpr int64_t kPrimitiveNormOrIdentityOpcode = 30; +constexpr int64_t kReverseArtifactBoundaryStep = 1; +constexpr int64_t kReverseArtifactCellsPrev = 2; +constexpr int64_t kReverseArtifactInputK = 3; +constexpr int64_t kReverseArtifactInputV = 4; +constexpr int64_t kReverseArtifactRecurrentKBefore = 5; +constexpr int64_t kReverseArtifactRecurrentVBefore = 6; +constexpr int64_t kReverseArtifactRecurrentK = 7; +constexpr int64_t kReverseArtifactRecurrentV = 8; +constexpr int64_t kReverseArtifactRecurrentHiddenBeforeBackendOrder = 9; +constexpr int64_t kReverseArtifactRecurrentHiddenBackendOrder = 10; +constexpr 
int64_t kReverseArtifactRecurrentMsgBackendOrder = 11; +constexpr int64_t kReverseArtifactOutputMsg = 12; +constexpr int64_t kReverseArtifactOutputCells = 13; +constexpr int64_t kReverseArtifactTransitionStateBefore = 14; +constexpr int64_t kReverseArtifactMaxRole = 14; +constexpr int64_t kReverseArtifactAccessBoundaryStep = 1; +constexpr int64_t kReverseArtifactAccessCellsPrev = 2; +constexpr int64_t kReverseArtifactAccessInputK = 3; +constexpr int64_t kReverseArtifactAccessInputV = 4; +constexpr int64_t kReverseArtifactAccessRecurrentKBefore = 5; +constexpr int64_t kReverseArtifactAccessRecurrentVBefore = 6; +constexpr int64_t kReverseArtifactAccessRecurrentK = 7; +constexpr int64_t kReverseArtifactAccessRecurrentV = 8; +constexpr int64_t kReverseArtifactAccessRecurrentHiddenBeforeBackendOrder = 9; +constexpr int64_t kReverseArtifactAccessRecurrentHiddenBackendOrder = 10; +constexpr int64_t kReverseArtifactAccessRecurrentMsgBackendOrder = 11; +constexpr int64_t kReverseArtifactAccessOutputMsg = 12; +constexpr int64_t kReverseArtifactAccessOutputCells = 13; +constexpr int64_t kReverseArtifactAccessTransitionStateBefore = 14; +constexpr int64_t kReverseArtifactMaxAccess = 14; +constexpr int64_t kTransitionStateArtifactFlagStride = 1000000; +constexpr int64_t kReverseResetMessage = 1; +constexpr int64_t kReverseResetTransition = 2; +constexpr int64_t kReverseResetPolicyZeroSourceRows = 1; +constexpr int64_t kReverseResetScopeBatchRow = 1; +constexpr int64_t kForwardResetMessage = 1; +constexpr int64_t kForwardResetTransition = 2; +constexpr int64_t kForwardResetPolicyZeroSourceRows = 1; +constexpr int64_t kForwardResetScopeBatchOuterStep = 1; +constexpr int64_t kFusedHandlerRowColumns = 11; +constexpr int64_t kNativeStrategyRowColumns = 17; +constexpr int64_t kNativeCallableBindingSchemaRowColumns = 10; +constexpr int64_t kNativeCallableOutputRowColumns = 12; +constexpr int64_t kTransitionReverseSeedRoleRowColumns = 4; +constexpr int64_t kTransitionPrimitiveCallableRowColumns = 6; +constexpr int64_t kReverseSpanOutputRowColumns = 6; +constexpr int64_t kReverseArtifactBindingRowColumns = 5; +constexpr int64_t kForwardArtifactRouteRowColumns = 10; +constexpr int64_t kForwardArtifactMergeRowColumns = 12; +constexpr int64_t kForwardOutputRouteRowColumns = 10; +constexpr int64_t kReadoutMessageProducerConsumerRowColumns = 16; +constexpr int64_t kMessageTransitionProducerConsumerRowColumns = 16; +constexpr int64_t kReverseArtifactConsumerRouteRowColumns = 12; +constexpr int64_t kReverseParameterReducerRouteRowColumns = 12; +constexpr int64_t kPhysicalStrategyRowColumns = 12; +constexpr int64_t kForwardArtifactMergeIdentitySingleton = 1; +constexpr int64_t kForwardArtifactMergeConcatOrError = 2; +constexpr int64_t kForwardArtifactMergeSumOrError = 3; +constexpr int64_t kForwardOutputRouteReadoutOutputCells = 1; +constexpr int64_t kForwardOutputRouteReadoutOutputSelect = 2; +constexpr int64_t kForwardOutputRouteReadoutOutputConcat = 3; +constexpr int64_t kForwardOutputRouteReadoutOutputSum = 4; +constexpr int64_t kPhysicalStrategyStageMaterialized = 1; +constexpr int64_t kPhysicalStrategyStreamingStepProducerConsumer = 2; +constexpr int64_t kPhysicalStrategyStatusActive = 1; +constexpr int64_t kPhysicalStrategyStatusCandidate = 2; +constexpr int64_t kPhysicalStrategyStatusBlocked = 3; +constexpr int64_t kPhysicalStrategyOutputBoundaryTerminal = 1; +constexpr int64_t kPhysicalStrategyOutputBoundarySequence = 2; +constexpr int64_t kPhysicalStrategyResetAbsent = 1; +constexpr int64_t 
kPhysicalStrategyResetPresent = 2; +constexpr int64_t kPhysicalStrategyResetUnknown = 3; +constexpr int64_t kPhysicalStrategyBlockerNone = 0; +constexpr int64_t kPhysicalStrategyBlockerPendingProgramKernel = 1; +constexpr int64_t kReadoutMessageProducerConsumerMaterializedKvAfter = 1; +constexpr int64_t kReadoutMessageProducerConsumerStreamFromMessageProjection = 2; +constexpr int64_t kReadoutMessageProducerConsumerStatusActive = 1; +constexpr int64_t kReadoutMessageProducerConsumerStatusCandidate = 2; +constexpr int64_t kReadoutMessageProducerConsumerStatusBlocked = 3; +constexpr int64_t kReadoutMessageProducerConsumerBlockerNone = 0; +constexpr int64_t kReadoutMessageProducerConsumerBlockerPendingProgramBody = 1; +constexpr int64_t kReadoutMessageProducerConsumerBlockerMissingExecutor = 2; +constexpr int64_t kReadoutMessageProducerConsumerBlockerMissingOutputRoute = 3; +constexpr int64_t kReadoutMessageProducerConsumerBlockerMissingStreamingBindings = 4; +constexpr int64_t kReadoutMessageProducerConsumerBlockerCostRejected = 5; +constexpr int64_t kReadoutMessageProducerConsumerRoleInputK = 1; +constexpr int64_t kReadoutMessageProducerConsumerRoleInputV = 2; +constexpr int64_t kReadoutMessageProducerConsumerRoleRecurrentKAfter = 4; +constexpr int64_t kReadoutMessageProducerConsumerRoleRecurrentVAfter = 8; +constexpr int64_t kReadoutMessageProducerConsumerRoleRecurrentHidden = 16; +constexpr int64_t kReadoutMessageProducerConsumerRoleRecurrentKvWeight = 32; +constexpr int64_t kReadoutMessageProducerConsumerRoleReadoutOutputQuery = 64; +constexpr int64_t kReadoutMessageProducerConsumerRoleOutputRouteRows = 128; +constexpr int64_t kReadoutMessageProducerConsumerRoleMemoryLivenessRows = 256; +constexpr int64_t kMessageTransitionProducerConsumerMaterializedRecurrentMessage = 1; +constexpr int64_t kMessageTransitionProducerConsumerStreamToTransitionInput = 2; +constexpr int64_t kMessageTransitionProducerConsumerStatusActive = 1; +constexpr int64_t kMessageTransitionProducerConsumerStatusCandidate = 2; +constexpr int64_t kMessageTransitionProducerConsumerStatusBlocked = 3; +constexpr int64_t kMessageTransitionProducerConsumerBlockerNone = 0; +constexpr int64_t kMessageTransitionProducerConsumerBlockerMissingExecutor = 1; +constexpr int64_t kMessageTransitionProducerConsumerBlockerMultipleConsumersNeedMergeRows = 2; +constexpr int64_t kMessageTransitionProducerConsumerBlockerReceiverCountMismatch = 3; +constexpr int64_t kMessageTransitionProducerConsumerBlockerPendingDirectChunkBody = 4; +constexpr int64_t kMessageTransitionProducerConsumerBlockerCostRejected = 5; +constexpr int64_t kMessageTransitionProducerConsumerRoleInputK = 1; +constexpr int64_t kMessageTransitionProducerConsumerRoleInputV = 2; +constexpr int64_t kMessageTransitionProducerConsumerRoleRecurrentHidden = 4; +constexpr int64_t kMessageTransitionProducerConsumerRoleRecurrentMsg = 8; +constexpr int64_t kMessageTransitionProducerConsumerRoleTransitionAggregateBinding = 16; +constexpr int64_t kMessageTransitionProducerConsumerRoleMemoryLivenessRows = 32; +constexpr int64_t kMessageTransitionProducerConsumerRolePhysicalStrategyRows = 64; +constexpr int64_t kPhysicalStrategySurfaceMessage = 1; +constexpr int64_t kPhysicalStrategySurfaceTransition = 2; +constexpr int64_t kPhysicalStrategySurfaceReadout = 4; +constexpr int64_t kPhysicalStrategySurfaceArtifacts = 8; +constexpr int64_t kPhysicalStrategySurfaceReducers = 16; +constexpr int64_t kPhysicalStrategyTablePrimitiveRows = 1; +constexpr int64_t kPhysicalStrategyTableExecutorRows = 
2;
+constexpr int64_t kPhysicalStrategyTableBindingRows = 4;
+constexpr int64_t kPhysicalStrategyTableMemoryLivenessRows = 8;
+constexpr int64_t kPhysicalStrategyTableArtifactRouteRows = 16;
+constexpr int64_t kPhysicalStrategyTableOutputRouteRows = 32;
+constexpr int64_t kPhysicalStrategyTableRuntimeScheduleRows = 64;
+constexpr int64_t kReverseSpanOutputFrontGroup = 1;
+constexpr int64_t kReverseSpanOutputBoundaryGroup = 2;
+constexpr int64_t kFusedProgramSpanColumns = 21;
+constexpr int64_t kForwardHandlerMessageCarrierFlag = 1;
+constexpr int64_t kForwardHandlerReadoutFlag = 2;
+constexpr int64_t kForwardHandlerTransitionFlag = 4;
+constexpr int64_t kReverseHandlerMessageCarrierFlag = 1;
+constexpr int64_t kReverseHandlerReadoutFlag = 2;
+constexpr int64_t kReverseHandlerTransitionFlag = 4;
+constexpr int64_t kHandlerEffectStateRead = 1;
+constexpr int64_t kHandlerEffectParameterRead = 2;
+constexpr int64_t kHandlerEffectMessageEmit = 4;
+constexpr int64_t kHandlerEffectMessageRead = 8;
+constexpr int64_t kHandlerEffectOutputEmit = 16;
+constexpr int64_t kHandlerEffectStateWrite = 32;
+constexpr int64_t kHandlerEffectTapePolicy = 64;
+constexpr int64_t kHandlerEffectGradRead = 128;
+constexpr int64_t kHandlerEffectParameterGradEmit = 256;
+
+// Stable-ID helper: folds a 32-bit FNV-1a digest of a name into a strictly
+// positive int64, matching the strategy/handler identity-hash columns that
+// are validated as > 0 elsewhere in the registered program tables.
+constexpr int64_t registered_temporal_stable_id_hash_constexpr(const char* value) {
+  uint32_t checksum = 2166136261u;
+  for (int64_t index = 0; value[index] != '\0'; ++index) {
+    checksum ^= static_cast<uint8_t>(value[index]);
+    checksum = static_cast<uint32_t>((static_cast<uint64_t>(checksum) * 16777619ull) & 0xFFFFFFFFull);
+  }
+  const int64_t result = static_cast<int64_t>(checksum & 0x7FFFFFFFu);
+  return result > 0 ? result : 1;
+}
+constexpr int64_t kTransitionDynamicSourceReverseArtifact = 1;
+constexpr int64_t kTransitionDynamicSourceStateBeforeArtifact = 2;
+constexpr int64_t kTransitionDynamicSourceSeedOrZeros = 3;
+constexpr int64_t kParameterReducerReadoutOutput = 1;
+constexpr int64_t kParameterReducerSenderKvProjection = 2;
+constexpr int64_t kParameterReducerRecurrentQuery = 3;
+constexpr int64_t kParameterReducerTransition = 4;
+constexpr int64_t kParameterReducerOutputQuery = 5;
+constexpr int64_t kParameterReducerFixedSlotContextMessage = 6;
+constexpr int64_t kParameterReducerCountNone = 0;
+constexpr int64_t kParameterReducerCountSender = 1;
+constexpr int64_t kParameterReducerCountReadout = 2;
+constexpr int64_t kParameterReducerCountRecurrentQuery = 3;
+constexpr int64_t kParameterReducerCountOutputQuery = 4;
+constexpr int64_t kParameterReducerCountMessageStrategy = 5;
+constexpr int64_t kParameterReducerCountModeNone = 0;
+constexpr int64_t kParameterReducerCountModeTensorCount = 1;
+constexpr int64_t kParameterReducerCountModeRow = 2;
+constexpr int64_t kParameterTrainableRolePublicProjWeight = 1;
+constexpr int64_t kParameterTrainableRoleKWeight = 2;
+constexpr int64_t kParameterTrainableRoleVWeight = 3;
+constexpr int64_t kParameterTrainableRoleQProjWeight = 4;
+constexpr int64_t kParameterTrainableRoleSlotEmbed = 5;
+constexpr int64_t kParameterTrainableRoleMsgOutWeight = 6;
+constexpr int64_t kParameterTrainableRoleOutputCellWeight = 7;
+constexpr int64_t kParameterTrainableRoleOutputCellBias = 8;
+constexpr int64_t kParameterTrainableRoleMessageQuerySlotProjWeight = 9;
+constexpr int64_t kParameterTrainableRoleMessageSenderSlotKeyProjWeight = 10;
+constexpr int64_t kParameterTrainableRoleMessageQueryNudgeScale = 11;
+constexpr int64_t kParameterTrainableRoleMessageSenderContextKey = 12;
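+
+// Illustrative sketch for the stable-ID helper defined above; the catalog name
+// used here is hypothetical, not an entry introduced by this change. Strategy
+// identity columns can be derived at compile time, e.g.
+//   constexpr int64_t kExampleStrategyIdHash =
+//       registered_temporal_stable_id_hash_constexpr("forward.message.example");
+// so that handler rows and native strategy rows share the same strictly
+// positive identity hash by construction instead of via hand-kept integers.
+constexpr int64_t kParameterTrainableRoleMessageQueryContextGate =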
13; +constexpr int64_t kParameterRuntimeRoleBackendRecurrentInverseOrder = 1; +constexpr int64_t kParameterRuntimeRoleRecurrentCellIdx = 2; +constexpr int64_t kParameterRuntimeRoleOutputCellIdx = 3; +constexpr int64_t kParameterRuntimeRoleInputCellIdx = 4; +constexpr int64_t kMessageStrategyGradQuerySlotBackend = 1; +constexpr int64_t kMessageStrategyGradInputKeyBank = 2; +constexpr int64_t kMessageStrategyGradRecurrentKeyBank = 3; +constexpr int64_t kMessageStrategyGradQueryContextScalar = 4; +constexpr int64_t kMessageStrategyGradOutputWeight = 5; +constexpr int64_t kTransitionSourceMaterialized = 1; +constexpr int64_t kTransitionSourceStaticSource = 2; +constexpr int64_t kMemoryWorkspaceMessage = 1; +constexpr int64_t kMemoryWorkspaceOutput = 2; +constexpr int64_t kMemoryWorkspaceParameterTable = 3; +constexpr int64_t kMemoryWorkspacePrimitive = 4; +constexpr int64_t kMemoryWorkspaceReduction = 5; +constexpr int64_t kMemoryWorkspaceStateCarry = 6; +constexpr int64_t kMemoryWorkspaceTensorTable = 7; +constexpr int64_t kMemoryWorkspaceTransition = 8; +constexpr int64_t kMemoryWorkspacePolicyTable = 9; +constexpr int64_t kMemoryEffectGradRead = 1; +constexpr int64_t kMemoryEffectMessageEmit = 3; +constexpr int64_t kMemoryEffectMessageRead = 4; +constexpr int64_t kMemoryEffectOutputEmit = 5; +constexpr int64_t kMemoryEffectParameterGradEmit = 6; +constexpr int64_t kMemoryEffectParameterRead = 7; +constexpr int64_t kMemoryEffectStateRead = 9; +constexpr int64_t kMemoryEffectStateWrite = 11; +constexpr int64_t kMemoryEffectTapePolicy = 12; +constexpr int64_t kMemoryEffectLocalSeedPolicy = 14; +constexpr int64_t kMemoryEffectMetadataPolicy = 15; +constexpr int64_t kMemoryEffectPrimitiveOutputPolicy = 16; +constexpr int64_t kMemoryEffectAliasPolicy = 17; +constexpr int64_t kMemoryEffectRecomputeWindowPolicy = 18; +constexpr int64_t kMemoryEffectMaterializationPolicy = 19; +constexpr int64_t kMemoryEffectCudaGraphConstraint = 20; +constexpr int64_t kMemoryRecomputePolicyCudaGraphGuardPolicy = 13; +constexpr int64_t kMemoryOwnerCompilerPrimitiveRow = 1; +constexpr int64_t kMemoryOwnerCompilerTensorRoleTable = 2; +constexpr int64_t kMemoryOwnerCompilerMemoryPolicy = 3; +constexpr int64_t kRuntimeBufferRoleWorkspace = 0; +constexpr int64_t kRuntimeBufferRoleOutputSeq = 1; +constexpr int64_t kRuntimeBufferRoleGradBoundarySeq = 2; +constexpr int64_t kRuntimeBufferRoleForwardCellsPrevArtifact = 3; +constexpr int64_t kRuntimeBufferRoleForwardRecurrentHiddenAfter = 4; +constexpr int64_t kRuntimeBufferRoleReverseGradCarryCells = 5; +constexpr int64_t kRuntimeBufferRoleReverseGradCellsWork = 6; +constexpr int64_t kRuntimeBufferRoleTransitionForwardLinearOutput = 7; +constexpr int64_t kRuntimeBufferRoleTransitionForwardMatmulOutput = 8; +constexpr int64_t kRuntimeBufferRoleTransitionForwardStateOutput = 9; +constexpr int64_t kRuntimeBufferRoleTransitionForwardNormOutput = 10; +constexpr int64_t kRuntimeBufferRoleTransitionForwardDiagOutput = 11; +constexpr int64_t kRuntimeBufferRoleForwardRecurrentMsg = 12; +constexpr int64_t kRuntimeBufferRoleForwardOutputMsg = 13; +constexpr int64_t kRuntimeBufferRoleForwardOutputCells = 14; +constexpr int64_t kRuntimeBufferRoleReverseGradRecurrentMsg = 15; +constexpr int64_t kRuntimeBufferRoleForwardMessageStepFlat = 16; +constexpr int64_t kRuntimeBufferRoleReverseMessageStepFlat = 17; +constexpr int64_t kRuntimeBufferRoleTransitionForwardUnaryOutput = 18; +constexpr int64_t kRuntimeBufferRoleTransitionReverseRecurrentMsgSpan = 19; +constexpr int64_t 
kRuntimeBufferRoleTransitionReverseStateBeforeZero = 20; +constexpr int64_t kRuntimeScheduleRoleLocalSeedPolicy = 1; +constexpr int64_t kRuntimeScheduleRoleMetadataPolicy = 2; +constexpr int64_t kRuntimeScheduleRolePrimitiveOutputPolicy = 3; +constexpr int64_t kRuntimeScheduleRoleTapePolicy = 4; +constexpr int64_t kRuntimeScheduleRoleAliasPolicy = 5; +constexpr int64_t kRuntimeScheduleRoleRecomputeWindowPolicy = 6; +constexpr int64_t kRuntimeScheduleRoleMaterializationPolicy = 7; +constexpr int64_t kRuntimeScheduleRoleCudaGraphConstraint = 8; +constexpr int64_t kRuntimeScheduleRoleCheckpointStride = 20; +constexpr int64_t kRuntimeScheduleRoleRecomputeWindowLen = 21; +constexpr int64_t kRuntimeScheduleRoleCheckpointStep = 22; +constexpr int64_t kRuntimeScheduleRoleBackwardWindow = 23; +constexpr int64_t kRuntimeScheduleRoleOutputPhysicalStep = 24; +constexpr int64_t kRuntimeScheduleRoleStoreStepArtifacts = 25; +constexpr int64_t kRuntimeScheduleRolePhysicalTimeSteps = 26; +constexpr int64_t kNativeCallableOutputShapeHidden = 1; +constexpr int64_t kNativeCallableOutputShapeGateLogits = 2; +constexpr int64_t kNativeCallableOutputShapeDiagonalPreproj = 3; +constexpr int64_t kNativeCallableOutputLogicalPrimitiveRow = 1; +constexpr int64_t kNativeCallableOutputLogicalBindingIndex = 2; +constexpr int64_t kNativeCallableOutputInitEmpty = 0; +constexpr int64_t kNativeCallableOutputInitZeros = 1; +constexpr int64_t kNativeCallableBindingInput = 1; +constexpr int64_t kNativeCallableBindingParameter = 2; +constexpr int64_t kNativeCallableBindingOutput = 3; +constexpr int kMaxRegisteredAttentionOffsets = 64; + +inline void check_cuda_float_bank(const at::Tensor& tensor, const char* name) { + TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA tensor"); + TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous"); + TORCH_CHECK(tensor.scalar_type() == at::kFloat, name, " must be float32"); + TORCH_CHECK(tensor.dim() == 3, name, " must be rank-3"); +} + +inline void check_cuda_float_sequence_bank(const at::Tensor& tensor, const char* name) { + TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA tensor"); + TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous"); + TORCH_CHECK(tensor.scalar_type() == at::kFloat, name, " must be float32"); + TORCH_CHECK(tensor.dim() == 4, name, " must be rank-4 [B,T,C,H]"); +} + +inline void check_cuda_float_rank4(const at::Tensor& tensor, const char* name) { + TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA tensor"); + TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous"); + TORCH_CHECK(tensor.scalar_type() == at::kFloat, name, " must be float32"); + TORCH_CHECK(tensor.dim() == 4, name, " must be rank-4"); +} + +inline void check_cuda_float_rank2(const at::Tensor& tensor, const char* name) { + TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA tensor"); + TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous"); + TORCH_CHECK(tensor.scalar_type() == at::kFloat, name, " must be float32"); + TORCH_CHECK(tensor.dim() == 2, name, " must be rank-2"); +} + +inline void check_cuda_int_rank2(const at::Tensor& tensor, const char* name) { + TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA tensor"); + TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous"); + TORCH_CHECK(tensor.scalar_type() == at::kInt, name, " must be int32"); + TORCH_CHECK(tensor.dim() == 2, name, " must be rank-2"); +} + +inline void check_cuda_long_rank2(const at::Tensor& tensor, const char* name) { + TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA 
tensor"); + TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous"); + TORCH_CHECK(tensor.scalar_type() == at::kLong, name, " must be int64"); + TORCH_CHECK(tensor.dim() == 2, name, " must be rank-2"); +} + +inline void check_cuda_bool_rank2(const at::Tensor& tensor, const char* name) { + TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA tensor"); + TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous"); + TORCH_CHECK(tensor.scalar_type() == at::kBool, name, " must be bool"); + TORCH_CHECK(tensor.dim() == 2, name, " must be rank-2"); +} + +inline void check_cuda_bool_rank1(const at::Tensor& tensor, const char* name) { + TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA tensor"); + TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous"); + TORCH_CHECK(tensor.scalar_type() == at::kBool, name, " must be bool"); + TORCH_CHECK(tensor.dim() == 1, name, " must be rank-1"); +} + +inline void check_cuda_float_rank1(const at::Tensor& tensor, const char* name) { + TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA tensor"); + TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous"); + TORCH_CHECK(tensor.scalar_type() == at::kFloat, name, " must be float32"); + TORCH_CHECK(tensor.dim() == 1, name, " must be rank-1"); +} + +inline void check_cuda_int_rank1(const at::Tensor& tensor, const char* name) { + TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA tensor"); + TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous"); + TORCH_CHECK(tensor.scalar_type() == at::kInt, name, " must be int32"); + TORCH_CHECK(tensor.dim() == 1, name, " must be rank-1"); +} + +inline void check_cuda_long_rank1(const at::Tensor& tensor, const char* name) { + TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA tensor"); + TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous"); + TORCH_CHECK(tensor.scalar_type() == at::kLong, name, " must be int64"); + TORCH_CHECK(tensor.dim() == 1, name, " must be rank-1"); +} + +inline void check_cpu_long_rank2(const at::Tensor& tensor, const char* name, int64_t columns) { + TORCH_CHECK(!tensor.is_cuda(), name, " must be a CPU tensor"); + TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous"); + TORCH_CHECK(tensor.scalar_type() == at::kLong, name, " must be int64"); + TORCH_CHECK(tensor.dim() == 2, name, " must be rank-2"); + TORCH_CHECK(tensor.size(1) == columns, name, " must have ", columns, " columns"); +} + +inline void check_launch(const char* name) { + const cudaError_t err = cudaGetLastError(); + TORCH_CHECK(err == cudaSuccess, name, " launch failed: ", cudaGetErrorString(err)); +} diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/executor_span_decode.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/executor_span_decode.cuh new file mode 100644 index 00000000..337295b7 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/executor_span_decode.cuh @@ -0,0 +1,552 @@ +#pragma once + +inline void validate_registered_readout_executor_rows( + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal, + int64_t output_count) { + check_cpu_long_rank2(forward_executor_rows, "forward_executor_rows", 6); + check_cpu_long_rank2(forward_executor_binding_rows, "forward_executor_binding_rows", 8); + const int64_t* executor_rows = forward_executor_rows.data_ptr(); + const int64_t row_count = 
forward_executor_rows.size(0); + bool saw_readout_row = false; + for (int64_t row = 0; row < row_count; ++row) { + const int64_t* item = executor_rows + row * 6; + if (item[0] == executor_id && item[3] == bucket_ordinal) { + TORCH_CHECK(item[5] == output_count, "registered readout executor receiver count mismatch"); + saw_readout_row = true; + } + } + TORCH_CHECK( + saw_readout_row, + "registered temporal forward epilogue requires a compiler-selected readout executor row"); + + const int64_t* binding_rows = forward_executor_binding_rows.data_ptr(); + const int64_t binding_count = forward_executor_binding_rows.size(0); + bool saw_readout_parameter_binding = false; + for (int64_t row = 0; row < binding_count; ++row) { + const int64_t* item = binding_rows + row * 8; + if (item[0] == kForwardDirectionOpcode && item[2] == executor_id && + item[5] == bucket_ordinal && item[6] == kParameterBindingKindOpcode) { + saw_readout_parameter_binding = true; + break; + } + } + TORCH_CHECK( + saw_readout_parameter_binding, + "registered temporal forward epilogue requires compiler-owned readout parameter binding rows"); +} + +inline void validate_registered_reverse_readout_executor_rows( + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal, + int64_t output_count) { + check_cpu_long_rank2(reverse_executor_rows, "reverse_executor_rows", 6); + check_cpu_long_rank2(reverse_executor_binding_rows, "reverse_executor_binding_rows", 8); + const int64_t* executor_rows = reverse_executor_rows.data_ptr(); + const int64_t row_count = reverse_executor_rows.size(0); + bool saw_readout_row = false; + for (int64_t row = 0; row < row_count; ++row) { + const int64_t* item = executor_rows + row * 6; + if (item[0] == executor_id && item[3] == bucket_ordinal) { + TORCH_CHECK(item[5] == output_count, "registered reverse readout executor receiver count mismatch"); + saw_readout_row = true; + } + } + TORCH_CHECK( + saw_readout_row, + "registered temporal reverse epilogue requires a compiler-selected readout backward executor row"); + + const int64_t* binding_rows = reverse_executor_binding_rows.data_ptr(); + const int64_t binding_count = reverse_executor_binding_rows.size(0); + bool saw_readout_parameter_binding = false; + for (int64_t row = 0; row < binding_count; ++row) { + const int64_t* item = binding_rows + row * 8; + if (item[0] == kReverseDirectionOpcode && item[2] == executor_id && + item[5] == bucket_ordinal && item[6] == kParameterBindingKindOpcode) { + saw_readout_parameter_binding = true; + break; + } + } + TORCH_CHECK( + saw_readout_parameter_binding, + "registered temporal reverse epilogue requires compiler-owned readout parameter binding rows"); +} + +inline void validate_registered_partitioned_attention_executor_rows( + const at::Tensor& executor_rows_tensor, + const at::Tensor& executor_binding_rows_tensor, + int64_t direction_opcode, + int64_t executor_id, + int64_t bucket_ordinal, + int64_t receiver_count) { + check_cpu_long_rank2(executor_rows_tensor, "executor_rows", 6); + check_cpu_long_rank2(executor_binding_rows_tensor, "executor_binding_rows", 8); + const int64_t* executor_rows = executor_rows_tensor.data_ptr(); + const int64_t row_count = executor_rows_tensor.size(0); + bool saw_executor_row = false; + for (int64_t row = 0; row < row_count; ++row) { + const int64_t* item = executor_rows + row * 6; + if (item[0] == executor_id && item[3] == bucket_ordinal) { + TORCH_CHECK( + item[5] == receiver_count, + "registered 
partitioned attention executor receiver count mismatch"); + saw_executor_row = true; + } + } + TORCH_CHECK( + saw_executor_row, + "registered partitioned attention requires a compiler-selected executor row"); + + const int64_t* binding_rows = executor_binding_rows_tensor.data_ptr(); + const int64_t binding_count = executor_binding_rows_tensor.size(0); + bool saw_parameter_binding = false; + for (int64_t row = 0; row < binding_count; ++row) { + const int64_t* item = binding_rows + row * 8; + if (item[0] == direction_opcode && item[2] == executor_id && item[5] == bucket_ordinal && + item[6] == kParameterBindingKindOpcode) { + saw_parameter_binding = true; + break; + } + } + TORCH_CHECK( + saw_parameter_binding, + "registered partitioned attention requires compiler-owned parameter binding rows"); +} + +inline void validate_registered_executor_binding_rows( + const at::Tensor& executor_rows_tensor, + const at::Tensor& executor_binding_rows_tensor, + int64_t direction_opcode, + int64_t executor_id, + int64_t bucket_ordinal, + const char* subject, + bool require_parameter_binding = true) { + check_cpu_long_rank2(executor_rows_tensor, "executor_rows", 6); + check_cpu_long_rank2(executor_binding_rows_tensor, "executor_binding_rows", 8); + const int64_t* executor_rows = executor_rows_tensor.data_ptr(); + const int64_t row_count = executor_rows_tensor.size(0); + bool saw_executor_row = false; + for (int64_t row = 0; row < row_count; ++row) { + const int64_t* item = executor_rows + row * 6; + if (item[0] == executor_id && item[3] == bucket_ordinal) { + saw_executor_row = true; + break; + } + } + TORCH_CHECK(saw_executor_row, subject, " requires a compiler-selected executor row"); + const int64_t* binding_rows = executor_binding_rows_tensor.data_ptr(); + const int64_t binding_count = executor_binding_rows_tensor.size(0); + bool saw_parameter_binding = false; + for (int64_t row = 0; row < binding_count; ++row) { + const int64_t* item = binding_rows + row * 8; + if (item[0] == direction_opcode && item[2] == executor_id && item[5] == bucket_ordinal && + item[6] == kParameterBindingKindOpcode) { + saw_parameter_binding = true; + break; + } + } + TORCH_CHECK( + !require_parameter_binding || saw_parameter_binding, + subject, + " requires compiler-owned parameter binding rows"); +} + +inline bool fused_program_has_executor_row( + const at::Tensor& executor_rows_tensor, + int64_t executor_row_index, + int64_t executor_id, + int64_t bucket_ordinal) { + if (executor_row_index < 0 || executor_row_index >= executor_rows_tensor.size(0)) { + return false; + } + const int64_t* row = executor_rows_tensor.data_ptr() + executor_row_index * 6; + return row[0] == executor_id && row[3] == bucket_ordinal; +} + +inline void validate_registered_fused_program_executor_rows( + const at::Tensor& primitive_rows, + const at::Tensor& executor_rows, + const char* name) { + check_cpu_long_rank2(executor_rows, name, 6); + const int64_t primitive_count = primitive_rows.size(0); + const int64_t* rows = executor_rows.data_ptr(); + for (int64_t row_index = 0; row_index < executor_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 6; + const int64_t executor_id = row[0]; + const int64_t primitive_row_start = row[1]; + const int64_t primitive_row_count = row[2]; + const int64_t bucket_ordinal = row[3]; + const int64_t receiver_start = row[4]; + const int64_t receiver_count = row[5]; + TORCH_CHECK(executor_id > 0, name, " row ", row_index, " has an invalid executor id"); + TORCH_CHECK( + bucket_ordinal >= 0 || 
bucket_ordinal == kTemporalMessageBucketOrdinal || + bucket_ordinal == kTemporalReadoutBucketOrdinal, + name, + " row ", + row_index, + " has an invalid bucket ordinal"); + TORCH_CHECK(receiver_start >= 0, name, " row ", row_index, " has an invalid receiver start"); + TORCH_CHECK(receiver_count >= 0, name, " row ", row_index, " has an invalid receiver count"); + TORCH_CHECK(primitive_row_start >= 0, name, " row ", row_index, " has an invalid primitive row start"); + TORCH_CHECK(primitive_row_count > 0, name, " row ", row_index, " has no primitive row coverage"); + TORCH_CHECK( + primitive_row_start + primitive_row_count <= primitive_count, + name, + " row ", + row_index, + " references primitive rows outside the compiler primitive table"); + const int64_t* primitive_table = primitive_rows.data_ptr(); + for (int64_t primitive_offset = 0; primitive_offset < primitive_row_count; ++primitive_offset) { + const int64_t primitive_index = primitive_row_start + primitive_offset; + const int64_t* primitive = primitive_table + primitive_index * 4; + TORCH_CHECK(primitive[0] > 0, name, " row ", row_index, " references an invalid primitive opcode"); + TORCH_CHECK( + primitive[3] == bucket_ordinal, + name, + " row ", + row_index, + " primitive bucket does not match executor bucket"); + } + } +} + +inline void validate_registered_fused_program_binding_rows( + const at::Tensor& primitive_rows, + const at::Tensor& executor_rows, + const at::Tensor& binding_rows, + int64_t expected_direction_opcode, + const char* name) { + check_cpu_long_rank2(binding_rows, name, 8); + const int64_t primitive_count = primitive_rows.size(0); + const int64_t* rows = binding_rows.data_ptr(); + TORCH_CHECK(binding_rows.size(0) > 0, name, " must include compiler-owned tensor binding rows"); + for (int64_t row_index = 0; row_index < binding_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 8; + const int64_t direction_opcode = row[0]; + const int64_t executor_row_index = row[1]; + const int64_t executor_id = row[2]; + const int64_t primitive_row_index = row[3]; + const int64_t binding_index = row[4]; + const int64_t bucket_ordinal = row[5]; + const int64_t binding_kind = row[6]; + const int64_t local_binding_index = row[7]; + TORCH_CHECK( + direction_opcode == expected_direction_opcode, + name, + " row ", + row_index, + " has a direction opcode that does not match its fused program direction"); + TORCH_CHECK( + fused_program_has_executor_row(executor_rows, executor_row_index, executor_id, bucket_ordinal), + name, + " row ", + row_index, + " references no matching compiler executor row"); + TORCH_CHECK( + primitive_row_index >= 0 && primitive_row_index < primitive_count, + name, + " row ", + row_index, + " references no compiler primitive row"); + TORCH_CHECK(binding_index >= 0, name, " row ", row_index, " has an invalid binding index"); + TORCH_CHECK( + binding_kind >= 0 && binding_kind <= 2, + name, + " row ", + row_index, + " has an unregistered binding-kind opcode"); + TORCH_CHECK(local_binding_index >= 0, name, " row ", row_index, " has an invalid local binding index"); + } +} + +inline int64_t surface_opcode_for_executor_bucket(int64_t bucket_ordinal) { + if (bucket_ordinal == kTemporalMessageBucketOrdinal) { + return kMessageSurfaceOpcode; + } + if (bucket_ordinal == kTemporalReadoutBucketOrdinal) { + return kReadoutSurfaceOpcode; + } + TORCH_CHECK(bucket_ordinal >= 0, "registered fused program cannot infer surface for bucket ", bucket_ordinal); + return kTransitionSurfaceOpcode; +} + +inline int64_t 
count_fused_memory_rows_for_executor( + const at::Tensor& memory_liveness_rows, + int64_t primitive_row_start, + int64_t primitive_row_count, + int64_t bucket_ordinal) { + const int64_t* memory_rows = memory_liveness_rows.data_ptr(); + int64_t count = 0; + for (int64_t row_index = 0; row_index < memory_liveness_rows.size(0); ++row_index) { + const int64_t* row = memory_rows + row_index * 10; + const int64_t memory_primitive_row = row[1]; + const int64_t memory_bucket = row[2]; + if (memory_primitive_row >= primitive_row_start && + memory_primitive_row < primitive_row_start + primitive_row_count) { + ++count; + continue; + } + if (memory_primitive_row == -1 && memory_bucket == bucket_ordinal) { + ++count; + } + } + return count; +} + +inline at::Tensor decode_registered_fused_program_spans( + const at::Tensor& primitive_rows, + const at::Tensor& executor_rows, + const at::Tensor& handler_rows, + const at::Tensor& native_strategy_rows, + const at::Tensor& executor_binding_rows, + const at::Tensor& memory_liveness_rows, + int64_t direction_opcode, + int64_t schema_version, + bool require_native_strategy_rows, + const char* name) { + check_cpu_long_rank2( + handler_rows, + direction_opcode == kForwardDirectionOpcode ? "forward_handler_rows" : "reverse_handler_rows", + kFusedHandlerRowColumns); + if (require_native_strategy_rows) { + check_cpu_long_rank2(native_strategy_rows, "native_strategy_rows", kNativeStrategyRowColumns); + } + auto spans = at::empty({executor_rows.size(0), kFusedProgramSpanColumns}, executor_rows.options()); + const int64_t* executor_table = executor_rows.data_ptr(); + const int64_t* handler_table = handler_rows.data_ptr(); + const int64_t* strategy_table = native_strategy_rows.data_ptr(); + const int64_t* binding_table = executor_binding_rows.data_ptr(); + int64_t* span_table = spans.data_ptr(); + for (int64_t executor_row_index = 0; executor_row_index < executor_rows.size(0); ++executor_row_index) { + const int64_t* executor = executor_table + executor_row_index * 6; + const int64_t executor_id = executor[0]; + const int64_t primitive_row_start = executor[1]; + const int64_t primitive_row_count = executor[2]; + const int64_t bucket_ordinal = executor[3]; + const int64_t receiver_start = executor[4]; + const int64_t receiver_count = executor[5]; + const int64_t surface_opcode = surface_opcode_for_executor_bucket(bucket_ordinal); + + const int64_t* selected_handler = nullptr; + for (int64_t handler_row_index = 0; handler_row_index < handler_rows.size(0); ++handler_row_index) { + const int64_t* handler = handler_table + handler_row_index * kFusedHandlerRowColumns; + TORCH_CHECK(handler[0] > 0, name, " handler row ", handler_row_index, " has an invalid executor id"); + TORCH_CHECK(handler[1] > 0, name, " handler row ", handler_row_index, " has an invalid surface opcode"); + TORCH_CHECK(handler[2] > 0, name, " handler row ", handler_row_index, " has an invalid handler kind"); + TORCH_CHECK(handler[3] > 0, name, " handler row ", handler_row_index, " has an invalid primitive opcode"); + TORCH_CHECK(handler[4] > 0, name, " handler row ", handler_row_index, " has an invalid primitive row count"); + TORCH_CHECK(handler[5] > 0, name, " handler row ", handler_row_index, " has no capability flags"); + TORCH_CHECK(handler[6] > 0, name, " handler row ", handler_row_index, " has no required effect mask"); + TORCH_CHECK(handler[7] > 0, name, " handler row ", handler_row_index, " has no strategy identity hash"); + TORCH_CHECK(handler[8] >= 0, name, " handler row ", handler_row_index, " 
has an invalid access count"); + TORCH_CHECK(handler[9] >= 0, name, " handler row ", handler_row_index, " has an invalid carry-rule count"); + TORCH_CHECK( + handler[10] == 0 || handler[10] == 1, + name, + " handler row ", + handler_row_index, + " has an invalid rewrite flag"); + if (handler[0] != executor_id || handler[1] != surface_opcode) { + continue; + } + TORCH_CHECK(selected_handler == nullptr, name, " has duplicate compiler handler rows for executor_id=", executor_id, + ",surface_opcode=", surface_opcode); + selected_handler = handler; + } + TORCH_CHECK( + selected_handler != nullptr, + name, + " has no compiler handler row for executor_id=", + executor_id, + ",surface_opcode=", + surface_opcode, + ",bucket=", + bucket_ordinal); + TORCH_CHECK( + selected_handler[4] == primitive_row_count, + name, + " handler row primitive count does not match executor row for executor_id=", + executor_id, + ",surface_opcode=", + surface_opcode); + const int64_t primitive_opcode = (primitive_rows.data_ptr() + primitive_row_start * 4)[0]; + TORCH_CHECK( + selected_handler[3] == primitive_opcode, + name, + " handler row primitive opcode does not match executor primitive row for executor_id=", + executor_id, + ",surface_opcode=", + surface_opcode, + ",handler_primitive_opcode=", + selected_handler[3], + ",primitive_row_opcode=", + primitive_opcode); + if (require_native_strategy_rows) { + const int64_t* selected_strategy = nullptr; + for (int64_t strategy_row_index = 0; strategy_row_index < native_strategy_rows.size(0); ++strategy_row_index) { + const int64_t* strategy = strategy_table + strategy_row_index * kNativeStrategyRowColumns; + TORCH_CHECK(strategy[0] == kForwardDirectionOpcode || strategy[0] == kReverseDirectionOpcode, + name, " native strategy row ", strategy_row_index, " has invalid direction opcode"); + TORCH_CHECK(strategy[1] > 0, name, " native strategy row ", strategy_row_index, " has invalid surface opcode"); + TORCH_CHECK(strategy[2] > 0, name, " native strategy row ", strategy_row_index, " has invalid executor id"); + TORCH_CHECK(strategy[3] > 0, name, " native strategy row ", strategy_row_index, " has invalid handler kind"); + TORCH_CHECK(strategy[4] > 0, name, " native strategy row ", strategy_row_index, " has invalid primitive opcode"); + TORCH_CHECK(strategy[5] > 0, name, " native strategy row ", strategy_row_index, " has invalid row signature count"); + TORCH_CHECK(strategy[6] > 0, name, " native strategy row ", strategy_row_index, " has no capability flags"); + TORCH_CHECK(strategy[7] > 0, name, " native strategy row ", strategy_row_index, " has no effect flags"); + TORCH_CHECK(strategy[8] == 1 && strategy[9] == 1 && strategy[10] == 1 && strategy[11] == schema_version, + name, " native strategy row ", strategy_row_index, " has unsupported schema versions"); + TORCH_CHECK( + strategy[12] > 0, + name, + " native strategy row ", + strategy_row_index, + " has no strategy identity hash"); + TORCH_CHECK( + strategy[13] >= 0, + name, + " native strategy row ", + strategy_row_index, + " has an invalid access count"); + TORCH_CHECK( + strategy[14] >= 0, + name, + " native strategy row ", + strategy_row_index, + " has an invalid carry-rule count"); + TORCH_CHECK( + strategy[15] == 0 || strategy[15] == 1, + name, + " native strategy row ", + strategy_row_index, + " has an invalid rewrite flag"); + TORCH_CHECK( + strategy[16] > 0, + name, + " native strategy row ", + strategy_row_index, + " has no native callable hash"); + if (strategy[0] != direction_opcode || strategy[1] != 
surface_opcode || strategy[2] != executor_id || + strategy[3] != selected_handler[2] || strategy[4] != selected_handler[3] || + strategy[12] != selected_handler[7]) { + continue; + } + TORCH_CHECK(selected_strategy == nullptr, name, " has duplicate native strategy rows for executor_id=", executor_id, + ",surface_opcode=", surface_opcode, ",handler_kind=", selected_handler[2], + ",primitive_opcode=", selected_handler[3]); + selected_strategy = strategy; + } + TORCH_CHECK( + selected_strategy != nullptr, + name, + " has no native strategy row for executor_id=", + executor_id, + ",surface_opcode=", + surface_opcode, + ",handler_kind=", + selected_handler[2], + ",primitive_opcode=", + selected_handler[3]); + TORCH_CHECK( + selected_strategy[5] == selected_handler[4], + name, + " native strategy row signature count does not match handler primitive count for executor_id=", + executor_id); + TORCH_CHECK( + selected_strategy[6] == selected_handler[5] && selected_strategy[7] == selected_handler[6], + name, + " native strategy row capability/effect masks do not match compiler handler row for executor_id=", + executor_id); + TORCH_CHECK( + selected_strategy[12] == selected_handler[7] && + selected_strategy[13] == selected_handler[8] && + selected_strategy[14] == selected_handler[9] && + selected_strategy[15] == selected_handler[10], + name, + " native strategy row contract does not match compiler handler row for executor_id=", + executor_id); + } + + int64_t binding_start = -1; + int64_t binding_count = 0; + int64_t previous_binding_row = -1; + for (int64_t binding_row_index = 0; binding_row_index < executor_binding_rows.size(0); ++binding_row_index) { + const int64_t* binding = binding_table + binding_row_index * 8; + if (binding[0] != direction_opcode || binding[1] != executor_row_index || + binding[2] != executor_id || binding[5] != bucket_ordinal) { + continue; + } + if (binding_start < 0) { + binding_start = binding_row_index; + } else { + TORCH_CHECK( + binding_row_index == previous_binding_row + 1, + name, + " executor row ", + executor_row_index, + " has non-contiguous compiler binding rows"); + } + previous_binding_row = binding_row_index; + ++binding_count; + } + TORCH_CHECK(binding_count > 0, name, " executor row ", executor_row_index, " has no binding rows"); + + const int64_t memory_count = count_fused_memory_rows_for_executor( + memory_liveness_rows, + primitive_row_start, + primitive_row_count, + bucket_ordinal); + TORCH_CHECK(memory_count > 0, name, " executor row ", executor_row_index, " has no memory-plan rows"); + + int64_t* span = span_table + executor_row_index * kFusedProgramSpanColumns; + span[0] = direction_opcode; + span[1] = executor_row_index; + span[2] = executor_id; + span[3] = surface_opcode; + span[4] = bucket_ordinal; + span[5] = primitive_row_start; + span[6] = primitive_row_count; + span[7] = receiver_start; + span[8] = receiver_count; + span[9] = binding_start; + span[10] = binding_count; + span[11] = memory_count; + span[12] = selected_handler[2]; + span[13] = selected_handler[3]; + span[14] = selected_handler[4]; + span[15] = selected_handler[5]; + span[16] = selected_handler[6]; + span[17] = selected_handler[7]; + span[18] = selected_handler[8]; + span[19] = selected_handler[9]; + span[20] = selected_handler[10]; + } + return spans; +} + +inline at::Tensor decode_registered_fused_program_spans( + const at::Tensor& primitive_rows, + const at::Tensor& executor_rows, + const at::Tensor& handler_rows, + const at::Tensor& executor_binding_rows, + const at::Tensor& 
memory_liveness_rows,
+    int64_t direction_opcode,
+    const char* name) {
+  at::Tensor empty_strategy_rows = at::empty({0, kNativeStrategyRowColumns}, handler_rows.options());
+  return decode_registered_fused_program_spans(
+      primitive_rows,
+      executor_rows,
+      handler_rows,
+      empty_strategy_rows,
+      executor_binding_rows,
+      memory_liveness_rows,
+      direction_opcode,
+      1,
+      false,
+      name);
+}
diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/forward_program.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/forward_program.cuh
new file mode 100644
index 00000000..57e6beb6
--- /dev/null
+++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/forward_program.cuh
@@ -0,0 +1,2468 @@
+#pragma once
+
+#include
+
+namespace {
+
+struct RegisteredForwardStrategyTensor {
+  const char* access_name;
+  at::Tensor tensor;
+};
+
+using RegisteredForwardStrategyTensorList = std::vector<RegisteredForwardStrategyTensor>;
+
+struct RegisteredForwardMessageExecutorState {
+  RegisteredFusedProgramSpan span;
+  RegisteredNativeStrategyRow native_strategy;
+  RegisteredForwardStrategyTensorList tensors;
+  RegisteredForwardStrategyTensorList cached_tensors;
+  int64_t input_group_size;
+  int64_t message_output_dim;
+};
+
+struct RegisteredForwardReadoutExecutorState {
+  RegisteredFusedProgramSpan span;
+  RegisteredNativeStrategyRow native_strategy;
+  RegisteredForwardStrategyTensorList tensors;
+};
+
+struct RegisteredForwardMessageStepState {
+  const RegisteredForwardMessageExecutorState* executor;
+  at::Tensor input_k_step;
+  at::Tensor input_v_step;
+  at::Tensor recurrent_k_before;
+  at::Tensor recurrent_v_before;
+  at::Tensor recurrent_msg;
+  at::Tensor recurrent_k_after;
+  at::Tensor recurrent_v_after;
+};
+
+constexpr int64_t kRegisteredForwardMemoryStageEntry = 201;
+constexpr int64_t kRegisteredForwardMemoryStageAfterInputKv = 202;
+constexpr int64_t kRegisteredForwardMemoryStageAfterRecurrentKvBefore = 203;
+constexpr int64_t kRegisteredForwardMemoryStageAfterRecurrentMessage = 204;
+constexpr int64_t kRegisteredForwardMemoryStageAfterTransition = 205;
+constexpr int64_t kRegisteredForwardMemoryStageAfterRecurrentKvAfter = 206;
+constexpr int64_t kRegisteredForwardMemoryStageAfterReadoutMessage = 207;
+constexpr int64_t kRegisteredForwardMemoryStageAfterReadoutProjection = 208;
+constexpr int64_t kRegisteredForwardMemoryStageAfterOutputRoute = 209;
+constexpr int64_t kRegisteredForwardMemoryStageAfterTensorCompaction = 210;
+constexpr int64_t kRegisteredForwardMemoryStageReturn = 211;
+constexpr int64_t kRegisteredForwardMemoryStageMessageAfterOutputWeight = 212;
+constexpr int64_t kRegisteredForwardMemoryStageMessageAfterWeightedValue = 213;
+constexpr int64_t kRegisteredForwardMemoryStageMessageAfterProjected = 214;
+constexpr int64_t kRegisteredForwardMemoryStageMessageAfterNormalize = 215;
+constexpr int64_t kRegisteredForwardMemoryStageMessageBeforeWeightedValueAlloc = 216;
+constexpr int64_t kRegisteredForwardMemoryStageMessageAfterWeightedValueAlloc = 217;
+constexpr int64_t kRegisteredForwardMemoryStageMessageBeforeOutputWeight = 218;
+constexpr int64_t kRegisteredForwardMemoryStageMessageBeforeProjectedGemm = 219;
+constexpr int64_t kRegisteredForwardMemoryStageMessageAfterProjectedGemm = 220;
+constexpr int64_t kRegisteredForwardMemoryStageMessageAfterProjectedContiguous = 221;
+constexpr int64_t kRegisteredForwardMemoryStageMessageBeforeNormalize = 222;
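+
+// Reading aid for the kRegisteredForwardMemoryStage* opcodes above and the
+// streaming-message-release stage declared just below: each stage tags one row
+// emitted by append_registered_forward_memory_stage_row (rows are appended only
+// when the reference tensor is a defined CUDA tensor), and every row carries
+// five int64 columns read from the CUDA caching allocator at that point:
+//   [local_step, stage_id, allocated_bytes.current, reserved_bytes.current, allocated_bytes.peak]
+// (The column names here are descriptive only; consumers read the rows positionally.)
+constexpr int64_t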
kRegisteredForwardMemoryStageAfterStreamingMessageRelease = 223; + +inline void append_registered_forward_memory_stage_row( + std::vector* rows, + const at::Tensor& reference, + int64_t local_step, + int64_t stage_id) { + if (rows == nullptr || !reference.defined() || !reference.is_cuda()) { + return; + } + const auto device_index = static_cast(reference.get_device()); + const auto stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_index); + const size_t aggregate = static_cast(c10::CachingAllocator::StatType::AGGREGATE); + rows->push_back(local_step); + rows->push_back(stage_id); + rows->push_back(stats.allocated_bytes[aggregate].current); + rows->push_back(stats.reserved_bytes[aggregate].current); + rows->push_back(stats.allocated_bytes[aggregate].peak); +} + +inline at::Tensor registered_forward_memory_stage_rows_tensor( + const std::vector& rows) { + at::Tensor tensor = at::empty( + {static_cast(rows.size() / 5), 5}, + at::TensorOptions().dtype(at::kLong).device(at::kCPU)); + if (!rows.empty()) { + std::memcpy( + tensor.data_ptr(), + rows.data(), + rows.size() * sizeof(int64_t)); + } + return tensor; +} + +using ForwardMessageBindFn = RegisteredForwardMessageExecutorState (*)( + const RegisteredFusedProgramSpan& span, + const RegisteredNativeStrategyRow& native_strategy, + const at::Tensor& boundary_seq, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& forward_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t input_count, + int64_t head_dim, + int64_t value_dim); + +using ForwardMessageRecurrentKvFn = std::vector (*)( + const RegisteredForwardMessageExecutorState& message_executor, + const at::Tensor& recurrent_hidden, + const at::Tensor& empty, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t head_dim, + int64_t value_dim, + bool materialize_key_bank); + +using ForwardMessageCarrierFn = at::Tensor (*)( + const RegisteredForwardMessageExecutorState& message_executor, + const at::Tensor& input_k_step, + const at::Tensor& input_v_step, + const at::Tensor& recurrent_hidden, + const at::Tensor& recurrent_k, + const at::Tensor& recurrent_v, + const at::Tensor& receiver_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& step_flat, + const at::Tensor& recurrent_msg_output_override, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t runtime_buffer_logical_index, + double distance_scale, + bool use_delay, + std::vector* memory_stage_rows, + int64_t memory_stage_local_step); + +using ForwardMessageStreamTransitionInputFn = at::Tensor (*)( + const RegisteredForwardMessageExecutorState& message_executor, + const at::Tensor& input_k_step, + const at::Tensor& input_v_step, + const at::Tensor& recurrent_hidden, + const at::Tensor& recurrent_k, + const at::Tensor& recurrent_v, + const at::Tensor& receiver_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& step_flat, + const RegisteredTransitionInputProjectionTarget& transition_target, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + 
int64_t runtime_buffer_logical_index, + double distance_scale, + bool use_delay, + std::vector* memory_stage_rows, + int64_t memory_stage_local_step); + +using ForwardReadoutBindFn = RegisteredForwardReadoutExecutorState (*)( + const RegisteredFusedProgramSpan& span, + const RegisteredNativeStrategyRow& native_strategy, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& forward_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows); + +using ForwardReadoutMessageFn = at::Tensor (*)( + const RegisteredForwardReadoutExecutorState& readout_executor, + const at::Tensor& input_k_step, + const at::Tensor& input_v_step, + const at::Tensor& recurrent_k, + const at::Tensor& recurrent_v, + const at::Tensor& receiver_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& step_flat, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t runtime_buffer_logical_index, + double distance_scale, + bool use_delay); + +using ForwardReadoutProjectFn = at::Tensor (*)( + const RegisteredForwardReadoutExecutorState& readout_executor, + const at::Tensor& output_msg, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t runtime_buffer_logical_index); + +using ForwardReadoutProjectIntoFn = at::Tensor (*)( + const RegisteredForwardReadoutExecutorState& readout_executor, + const at::Tensor& output_msg, + const at::Tensor& output_cells, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows); + +using ForwardMessageKeylessReadoutFn = at::Tensor (*)( + const RegisteredForwardMessageExecutorState& message_executor, + const RegisteredForwardReadoutExecutorState& readout_executor, + const at::Tensor& input_k_step, + const at::Tensor& input_v_step, + const at::Tensor& recurrent_v, + const at::Tensor& receiver_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& step_flat, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t runtime_buffer_logical_index, + double distance_scale, + bool use_delay); + +using ForwardMessageDirectKeylessReadoutFn = at::Tensor (*)( + const RegisteredForwardMessageExecutorState& message_executor, + const RegisteredForwardReadoutExecutorState& readout_executor, + const at::Tensor& input_k_step, + const at::Tensor& input_v_step, + const at::Tensor& recurrent_hidden, + const at::Tensor& receiver_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& step_flat, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t runtime_buffer_logical_index, + double distance_scale, + bool use_delay); + +using ForwardMessageStreamReadoutFn = at::Tensor (*)( + const RegisteredForwardMessageExecutorState& message_executor, + const RegisteredForwardReadoutExecutorState& readout_executor, + const at::Tensor& input_k_step, + const at::Tensor& input_v_step, + const at::Tensor& recurrent_hidden, + const at::Tensor& receiver_sender_idx, 
+ const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& step_flat, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t runtime_buffer_logical_index, + double distance_scale, + bool use_delay); + +struct RegisteredForwardMessageCarrierStrategy { + int64_t native_callable_hash; + const char* name; + ForwardMessageBindFn bind; + ForwardMessageRecurrentKvFn recurrent_kv; + ForwardMessageCarrierFn message; + ForwardMessageKeylessReadoutFn keyless_readout_message; + ForwardMessageDirectKeylessReadoutFn direct_keyless_readout_message; + ForwardMessageStreamReadoutFn stream_readout_message; + ForwardMessageStreamTransitionInputFn stream_transition_input; +}; + +struct RegisteredForwardReadoutStrategy { + int64_t native_callable_hash; + const char* name; + ForwardReadoutBindFn bind; + ForwardReadoutMessageFn message; + ForwardReadoutProjectFn project; + ForwardReadoutProjectIntoFn project_into; +}; + +inline bool registered_forward_strategy_callable_matches_native_row( + int64_t native_callable_hash, + const RegisteredNativeStrategyRow& strategy) { + return strategy.direction_opcode == kForwardDirectionOpcode && + native_callable_hash == strategy.native_callable_hash; +} + +inline at::Tensor registered_forward_strategy_tensor( + const RegisteredForwardStrategyTensorList& tensors, + const char* access_name, + const char* owner) { + for (const RegisteredForwardStrategyTensor& tensor : tensors) { + if (std::string(tensor.access_name) == std::string(access_name)) { + return tensor.tensor; + } + } + TORCH_CHECK(false, owner, " did not bind required compiler program access: ", access_name); + return at::Tensor(); +} + +inline at::Tensor registered_forward_message_tensor( + const RegisteredForwardMessageExecutorState& state, + const char* access_name) { + return registered_forward_strategy_tensor(state.tensors, access_name, "registered forward message strategy"); +} + +inline at::Tensor registered_forward_message_cache_tensor( + const RegisteredForwardMessageExecutorState& state, + const char* access_name) { + return registered_forward_strategy_tensor(state.cached_tensors, access_name, "registered forward message cache"); +} + +inline bool registered_reverse_transition_input_binding_consumes_forward_binding( + const at::Tensor& reverse_executor_binding_rows, + int64_t bucket_ordinal, + int64_t binding_index) { + check_cpu_long_rank2(reverse_executor_binding_rows, "reverse_executor_binding_rows", 8); + const int64_t* rows = reverse_executor_binding_rows.data_ptr(); + for (int64_t row_index = 0; row_index < reverse_executor_binding_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 8; + if ( + row[0] == kReverseDirectionOpcode && + row[4] == binding_index && + row[5] == bucket_ordinal && + row[6] == kInputBindingKindOpcode) { + return true; + } + } + return false; +} + +inline at::Tensor registered_forward_readout_tensor( + const RegisteredForwardReadoutExecutorState& state, + const char* access_name) { + return registered_forward_strategy_tensor(state.tensors, access_name, "registered forward readout strategy"); +} + +inline at::Tensor run_registered_forward_readout_projection_into( + const at::Tensor& output_msg, + const at::Tensor& value_to_output_weight, + const at::Tensor& output_cell_bias, + const at::Tensor& output_cells, + const at::Tensor& forward_executor_rows, + const at::Tensor& 
forward_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal); + +#include "native_callables/message_forward_strategies.cuh" +#include "native_callables/readout_forward_strategies.cuh" + +#define REGISTERED_TEMPORAL_NATIVE_FORWARD_MESSAGE_CATALOG +#include "../flat_bucket_registered_native_callables.cuh" + +inline const RegisteredForwardMessageCarrierStrategy& registered_forward_message_carrier_strategy_for_native_row( + const RegisteredNativeStrategyRow& native_strategy) { + for (const RegisteredForwardMessageCarrierStrategy* strategy = registered_native_forward_message_catalog_begin(); + strategy != registered_native_forward_message_catalog_end(); + ++strategy) { + if (registered_forward_strategy_callable_matches_native_row(strategy->native_callable_hash, native_strategy)) { + return *strategy; + } + } + TORCH_CHECK( + false, + "registered fused forward message carrier has no native callable for compiler-emitted strategy row: handler_kind=", + native_strategy.handler_kind, + ", executor_id=", + native_strategy.executor_id, + ", primitive_opcode=", + native_strategy.primitive_opcode, + ", strategy_hash=", + native_strategy.strategy_id_hash, + ", access_count=", + native_strategy.program_access_count, + ", carry_count=", + native_strategy.state_carry_rule_count, + ", native_callable_hash=", + native_strategy.native_callable_hash); + return *registered_native_forward_message_catalog_begin(); +} + +inline const RegisteredForwardMessageCarrierStrategy& registered_forward_message_carrier_strategy_for_span( + const at::Tensor& native_strategy_rows, + const RegisteredFusedProgramSpan& span) { + const RegisteredNativeStrategyRow native_strategy = registered_native_strategy_row_for_span( + native_strategy_rows, + kForwardDirectionOpcode, + span, + "registered fused forward message carrier"); + return registered_forward_message_carrier_strategy_for_native_row(native_strategy); +} + +#define REGISTERED_TEMPORAL_NATIVE_FORWARD_READOUT_CATALOG +#include "../flat_bucket_registered_native_callables.cuh" + +inline const RegisteredForwardReadoutStrategy& registered_forward_readout_strategy_for_native_row( + const RegisteredNativeStrategyRow& native_strategy) { + for (const RegisteredForwardReadoutStrategy* strategy = registered_native_forward_readout_catalog_begin(); + strategy != registered_native_forward_readout_catalog_end(); + ++strategy) { + if (registered_forward_strategy_callable_matches_native_row(strategy->native_callable_hash, native_strategy)) { + return *strategy; + } + } + TORCH_CHECK( + false, + "registered fused forward readout has no native callable for compiler-emitted strategy row: handler_kind=", + native_strategy.handler_kind, + ", executor_id=", + native_strategy.executor_id, + ", primitive_opcode=", + native_strategy.primitive_opcode, + ", strategy_hash=", + native_strategy.strategy_id_hash, + ", access_count=", + native_strategy.program_access_count, + ", carry_count=", + native_strategy.state_carry_rule_count, + ", native_callable_hash=", + native_strategy.native_callable_hash); + return *registered_native_forward_readout_catalog_begin(); +} + +inline const RegisteredForwardReadoutStrategy& registered_forward_readout_strategy_for_span( + const at::Tensor& native_strategy_rows, + const RegisteredFusedProgramSpan& span) { + const RegisteredNativeStrategyRow native_strategy = registered_native_strategy_row_for_span( + native_strategy_rows, + kForwardDirectionOpcode, + span, + "registered fused forward readout"); + return 
registered_forward_readout_strategy_for_native_row(native_strategy); +} + +inline std::vector bind_registered_forward_message_executor_handlers( + const at::Tensor& forward_spans, + const at::Tensor& native_strategy_rows, + const at::Tensor& boundary_seq, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& forward_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t input_count, + int64_t head_dim, + int64_t value_dim) { + const std::vector span_indices = registered_forward_handler_span_indices_by_capability( + forward_spans, + kMessageSurfaceOpcode, + kForwardHandlerMessageCarrierFlag, + "registered fused forward temporal message carrier"); + std::vector executors; + executors.reserve(span_indices.size()); + for (const int64_t span_index : span_indices) { + const RegisteredFusedProgramSpan span = registered_fused_program_span_at(forward_spans, span_index); + const RegisteredForwardMessageCarrierStrategy& strategy = + registered_forward_message_carrier_strategy_for_span(native_strategy_rows, span); + const RegisteredNativeStrategyRow native_strategy = registered_native_strategy_row_for_span( + native_strategy_rows, + kForwardDirectionOpcode, + span, + strategy.name); + executors.push_back(strategy.bind( + span, + native_strategy, + boundary_seq, + program_tensors, + program_tensor_binding_rows, + forward_program_access_rows, + native_callable_binding_schema_rows, + forward_executor_rows, + forward_executor_binding_rows, + input_count, + head_dim, + value_dim)); + } + return executors; +} + +inline std::vector bind_registered_forward_readout_executor_handlers( + const at::Tensor& forward_spans, + const at::Tensor& native_strategy_rows, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& forward_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows) { + const std::vector span_indices = registered_forward_handler_span_indices_by_capability( + forward_spans, + kReadoutSurfaceOpcode, + kForwardHandlerReadoutFlag, + "registered fused forward temporal readout"); + std::vector executors; + executors.reserve(span_indices.size()); + for (const int64_t span_index : span_indices) { + const RegisteredFusedProgramSpan span = registered_fused_program_span_at(forward_spans, span_index); + const RegisteredForwardReadoutStrategy& strategy = + registered_forward_readout_strategy_for_span(native_strategy_rows, span); + const RegisteredNativeStrategyRow native_strategy = registered_native_strategy_row_for_span( + native_strategy_rows, + kForwardDirectionOpcode, + span, + strategy.name); + executors.push_back(strategy.bind( + span, + native_strategy, + program_tensors, + program_tensor_binding_rows, + forward_program_access_rows, + native_callable_binding_schema_rows)); + } + return executors; +} + +inline std::vector run_registered_forward_message_recurrent_kv_handler( + const RegisteredForwardMessageExecutorState& message_executor, + const at::Tensor& recurrent_hidden, + const at::Tensor& empty, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t head_dim, + int64_t value_dim, + bool materialize_key_bank) { + const RegisteredForwardMessageCarrierStrategy& strategy = + registered_forward_message_carrier_strategy_for_native_row( + message_executor.native_strategy); + return strategy.recurrent_kv( + 
message_executor, + recurrent_hidden, + empty, + forward_executor_rows, + forward_executor_binding_rows, + head_dim, + value_dim, + materialize_key_bank); +} + +inline at::Tensor run_registered_forward_message_carrier_handler( + const RegisteredForwardMessageExecutorState& message_executor, + const at::Tensor& input_k_step, + const at::Tensor& input_v_step, + const at::Tensor& recurrent_hidden, + const at::Tensor& recurrent_k, + const at::Tensor& recurrent_v, + const at::Tensor& receiver_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& step_flat, + const at::Tensor& recurrent_msg_output_override, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t runtime_buffer_logical_index, + double distance_scale, + bool use_delay, + std::vector* memory_stage_rows, + int64_t memory_stage_local_step) { + const RegisteredForwardMessageCarrierStrategy& strategy = + registered_forward_message_carrier_strategy_for_native_row( + message_executor.native_strategy); + return strategy.message( + message_executor, + input_k_step, + input_v_step, + recurrent_hidden, + recurrent_k, + recurrent_v, + receiver_sender_idx, + local_distance, + local_delay, + step_flat, + recurrent_msg_output_override, + runtime_buffer_tensors, + runtime_buffer_rows, + forward_executor_rows, + forward_executor_binding_rows, + runtime_buffer_logical_index, + distance_scale, + use_delay, + memory_stage_rows, + memory_stage_local_step); +} + +inline at::Tensor run_registered_forward_message_stream_transition_input_handler( + const RegisteredForwardMessageExecutorState& message_executor, + const at::Tensor& input_k_step, + const at::Tensor& input_v_step, + const at::Tensor& recurrent_hidden, + const at::Tensor& recurrent_k, + const at::Tensor& recurrent_v, + const at::Tensor& receiver_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& step_flat, + const RegisteredTransitionInputProjectionTarget& transition_target, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t runtime_buffer_logical_index, + double distance_scale, + bool use_delay, + std::vector* memory_stage_rows, + int64_t memory_stage_local_step) { + const RegisteredForwardMessageCarrierStrategy& strategy = + registered_forward_message_carrier_strategy_for_native_row( + message_executor.native_strategy); + TORCH_CHECK( + strategy.stream_transition_input != nullptr, + "registered message/transition producer-consumer row selected direct transition input, " + "but the message native strategy has no stream_transition_input implementation"); + return strategy.stream_transition_input( + message_executor, + input_k_step, + input_v_step, + recurrent_hidden, + recurrent_k, + recurrent_v, + receiver_sender_idx, + local_distance, + local_delay, + step_flat, + transition_target, + runtime_buffer_tensors, + runtime_buffer_rows, + forward_executor_rows, + forward_executor_binding_rows, + runtime_buffer_logical_index, + distance_scale, + use_delay, + memory_stage_rows, + memory_stage_local_step); +} + +inline at::Tensor run_registered_forward_readout_message_handler( + const RegisteredForwardReadoutExecutorState& readout_executor, + const at::Tensor& input_k_step, + const at::Tensor& input_v_step, + const at::Tensor& 
recurrent_k, + const at::Tensor& recurrent_v, + const at::Tensor& receiver_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& step_flat, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t runtime_buffer_logical_index, + double distance_scale, + bool use_delay) { + const RegisteredForwardReadoutStrategy& strategy = + registered_forward_readout_strategy_for_native_row( + readout_executor.native_strategy); + return strategy.message( + readout_executor, + input_k_step, + input_v_step, + recurrent_k, + recurrent_v, + receiver_sender_idx, + local_distance, + local_delay, + step_flat, + runtime_buffer_tensors, + runtime_buffer_rows, + forward_executor_rows, + forward_executor_binding_rows, + runtime_buffer_logical_index, + distance_scale, + use_delay); +} + +inline at::Tensor run_registered_forward_message_stream_readout_handler( + const RegisteredForwardMessageExecutorState& message_executor, + const RegisteredForwardReadoutExecutorState& readout_executor, + const at::Tensor& input_k_step, + const at::Tensor& input_v_step, + const at::Tensor& recurrent_hidden, + const at::Tensor& receiver_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& step_flat, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t runtime_buffer_logical_index, + double distance_scale, + bool use_delay) { + const RegisteredForwardMessageCarrierStrategy& strategy = + registered_forward_message_carrier_strategy_for_native_row( + message_executor.native_strategy); + TORCH_CHECK( + strategy.stream_readout_message != nullptr, + "registered readout/message producer-consumer row selected streaming readout, " + "but the message native strategy has no stream_readout_message implementation"); + return strategy.stream_readout_message( + message_executor, + readout_executor, + input_k_step, + input_v_step, + recurrent_hidden, + receiver_sender_idx, + local_distance, + local_delay, + step_flat, + runtime_buffer_tensors, + runtime_buffer_rows, + forward_executor_rows, + forward_executor_binding_rows, + runtime_buffer_logical_index, + distance_scale, + use_delay); +} + +inline at::Tensor run_registered_forward_readout_projection_handler( + const RegisteredForwardReadoutExecutorState& readout_executor, + const at::Tensor& output_msg, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t runtime_buffer_logical_index) { + const RegisteredForwardReadoutStrategy& strategy = + registered_forward_readout_strategy_for_native_row( + readout_executor.native_strategy); + return strategy.project( + readout_executor, + output_msg, + runtime_buffer_tensors, + runtime_buffer_rows, + forward_executor_rows, + forward_executor_binding_rows, + runtime_buffer_logical_index); +} + +inline at::Tensor run_registered_forward_readout_projection_into_handler( + const RegisteredForwardReadoutExecutorState& readout_executor, + const at::Tensor& output_msg, + const at::Tensor& output_cells, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows) { + const RegisteredForwardReadoutStrategy& strategy = + 
registered_forward_readout_strategy_for_native_row( + readout_executor.native_strategy); + TORCH_CHECK( + strategy.project_into != nullptr, + "registered forward readout strategy selected output-route streaming projection " + "but has no projection_into implementation"); + return strategy.project_into( + readout_executor, + output_msg, + output_cells, + forward_executor_rows, + forward_executor_binding_rows); +} + +inline void validate_registered_forward_streaming_step_physical_contract( + const at::Tensor& physical_strategy_rows, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + bool return_final_program_tensors, + bool return_reverse_artifacts, + int64_t local_time_steps) { + if (!registered_physical_strategy_active_is_streaming( + physical_strategy_rows, + "registered fused forward program")) { + return; + } + TORCH_CHECK( + !return_reverse_artifacts, + "registered fused forward streaming-step strategy cannot retain reverse artifacts"); + TORCH_CHECK( + !return_final_program_tensors, + "registered fused forward streaming-step strategy cannot retain final program tensors"); + TORCH_CHECK(local_time_steps > 0, "registered fused forward streaming-step strategy requires physical steps"); + TORCH_CHECK( + registered_runtime_buffer_has_deferred_local_role( + runtime_buffer_tensors, + runtime_buffer_rows, + kRuntimeBufferRoleForwardRecurrentMsg, + 0), + "registered fused forward streaming-step strategy requires deferred local recurrent-message buffer ownership"); + TORCH_CHECK( + registered_runtime_buffer_has_deferred_local_role( + runtime_buffer_tensors, + runtime_buffer_rows, + kRuntimeBufferRoleForwardOutputMsg, + 0), + "registered fused forward streaming-step strategy requires deferred local output-message buffer ownership"); + TORCH_CHECK( + registered_runtime_buffer_has_deferred_local_role( + runtime_buffer_tensors, + runtime_buffer_rows, + kRuntimeBufferRoleForwardOutputCells, + 0), + "registered fused forward streaming-step strategy requires deferred local output-cells buffer ownership"); + TORCH_CHECK( + registered_runtime_buffer_has_deferred_local_transition_forward_output( + runtime_buffer_tensors, + runtime_buffer_rows), + "registered fused forward streaming-step strategy requires deferred local transition-output ownership"); +} + +inline void validate_readout_message_producer_consumer_rows( + const at::Tensor& readout_message_producer_consumer_rows, + const at::Tensor& forward_output_route_rows, + int64_t schema_version, + const char* name) { + check_cpu_long_rank2( + readout_message_producer_consumer_rows, + "readout_message_producer_consumer_rows", + kReadoutMessageProducerConsumerRowColumns); + check_cpu_long_rank2(forward_output_route_rows, "forward_output_route_rows", kForwardOutputRouteRowColumns); + TORCH_CHECK( + readout_message_producer_consumer_rows.size(0) > 0, + name, + " requires compiler-owned readout_message_producer_consumer_rows"); + const int64_t* rows = readout_message_producer_consumer_rows.data_ptr(); + const int64_t* output_routes = forward_output_route_rows.data_ptr(); + bool saw_active_materialized = false; + bool saw_active_streaming = false; + bool saw_streaming_candidate = false; + for (int64_t row_index = 0; row_index < readout_message_producer_consumer_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kReadoutMessageProducerConsumerRowColumns; + const int64_t strategy_opcode = row[2]; + const int64_t status_opcode = row[3]; + const int64_t executable = row[4]; + const int64_t 
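+      // Row layout as consumed here and in the route selectors below: [0]=row index, [1]=schema version,
+      // [2]=strategy, [3]=status, [4]=executable, [5..8]=producer (message) route, [9..12]=consumer (readout)
+      // route, [13]=forward output route row, [14]=required role mask, [15]=blocker opcode.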
producer_surface_opcode = row[5]; + const int64_t consumer_surface_opcode = row[9]; + const int64_t consumer_executor_row = row[10]; + const int64_t consumer_executor_id = row[11]; + const int64_t consumer_bucket = row[12]; + const int64_t output_route_row = row[13]; + const int64_t required_role_mask = row[14]; + const int64_t blocker_opcode = row[15]; + TORCH_CHECK(row[0] == row_index, name, " readout_message_producer_consumer_rows must be densely indexed"); + TORCH_CHECK(row[1] == schema_version, name, " readout_message_producer_consumer_rows schema version mismatch"); + TORCH_CHECK( + strategy_opcode == kReadoutMessageProducerConsumerMaterializedKvAfter || + strategy_opcode == kReadoutMessageProducerConsumerStreamFromMessageProjection, + name, + " readout_message_producer_consumer_rows row ", + row_index, + " has unknown strategy opcode ", + strategy_opcode); + TORCH_CHECK( + status_opcode == kReadoutMessageProducerConsumerStatusActive || + status_opcode == kReadoutMessageProducerConsumerStatusCandidate || + status_opcode == kReadoutMessageProducerConsumerStatusBlocked, + name, + " readout_message_producer_consumer_rows row ", + row_index, + " has unknown status opcode ", + status_opcode); + TORCH_CHECK( + executable == 0 || executable == 1, + name, + " readout_message_producer_consumer_rows executable flag must be 0/1"); + TORCH_CHECK( + producer_surface_opcode == kMessageSurfaceOpcode, + name, + " readout_message_producer_consumer_rows producer surface must be message"); + TORCH_CHECK( + consumer_surface_opcode == kReadoutSurfaceOpcode, + name, + " readout_message_producer_consumer_rows consumer surface must be readout"); + TORCH_CHECK( + 0 <= output_route_row && output_route_row < forward_output_route_rows.size(0), + name, + " readout_message_producer_consumer_rows row ", + row_index, + " references invalid forward output route row"); + const int64_t* output_route = output_routes + output_route_row * kForwardOutputRouteRowColumns; + TORCH_CHECK( + output_route[2] == kReadoutSurfaceOpcode && output_route[3] == consumer_executor_row && + output_route[4] == consumer_executor_id && output_route[5] == consumer_bucket, + name, + " readout_message_producer_consumer_rows row ", + row_index, + " output route does not match readout consumer executor"); + TORCH_CHECK( + (required_role_mask & kReadoutMessageProducerConsumerRoleReadoutOutputQuery) != 0 && + (required_role_mask & kReadoutMessageProducerConsumerRoleOutputRouteRows) != 0 && + (required_role_mask & kReadoutMessageProducerConsumerRoleMemoryLivenessRows) != 0, + name, + " readout_message_producer_consumer_rows row ", + row_index, + " is missing required compiler route/liveness roles"); + if (strategy_opcode == kReadoutMessageProducerConsumerMaterializedKvAfter) { + TORCH_CHECK( + (required_role_mask & kReadoutMessageProducerConsumerRoleRecurrentKAfter) != 0 && + (required_role_mask & kReadoutMessageProducerConsumerRoleRecurrentVAfter) != 0, + name, + " materialized readout/message strategy must consume recurrent K/V-after roles"); + } else { + saw_streaming_candidate = true; + TORCH_CHECK( + (required_role_mask & kReadoutMessageProducerConsumerRoleRecurrentHidden) != 0 && + (required_role_mask & kReadoutMessageProducerConsumerRoleRecurrentKvWeight) != 0, + name, + " streaming readout/message strategy must consume recurrent hidden and recurrent K/V weight roles"); + } + TORCH_CHECK( + blocker_opcode == kReadoutMessageProducerConsumerBlockerNone || + blocker_opcode == kReadoutMessageProducerConsumerBlockerPendingProgramBody || + 
blocker_opcode == kReadoutMessageProducerConsumerBlockerMissingExecutor || + blocker_opcode == kReadoutMessageProducerConsumerBlockerMissingOutputRoute || + blocker_opcode == kReadoutMessageProducerConsumerBlockerMissingStreamingBindings || + blocker_opcode == kReadoutMessageProducerConsumerBlockerCostRejected, + name, + " readout_message_producer_consumer_rows row ", + row_index, + " has invalid blocker opcode"); + if (status_opcode == kReadoutMessageProducerConsumerStatusActive) { + TORCH_CHECK(executable == 1, name, " active readout_message_producer_consumer row must be executable"); + TORCH_CHECK(blocker_opcode == kReadoutMessageProducerConsumerBlockerNone, name, " active readout_message_producer_consumer row is blocked"); + if (strategy_opcode == kReadoutMessageProducerConsumerMaterializedKvAfter) { + saw_active_materialized = true; + } else if (strategy_opcode == kReadoutMessageProducerConsumerStreamFromMessageProjection) { + saw_active_streaming = true; + } + } else { + TORCH_CHECK(executable == 0, name, " non-active readout_message_producer_consumer row must not be executable"); + } + } + TORCH_CHECK( + saw_active_materialized || saw_active_streaming, + name, + " readout_message_producer_consumer_rows missing active executable readout/message strategy row"); + TORCH_CHECK( + saw_streaming_candidate, + name, + " readout_message_producer_consumer_rows missing stream_readout_from_message_projection row"); +} + +inline const int64_t* active_readout_message_producer_consumer_row_for_route( + const at::Tensor& readout_message_producer_consumer_rows, + const RegisteredForwardReadoutExecutorState& readout_executor, + const char* name) { + check_cpu_long_rank2( + readout_message_producer_consumer_rows, + "readout_message_producer_consumer_rows", + kReadoutMessageProducerConsumerRowColumns); + const int64_t* rows = readout_message_producer_consumer_rows.data_ptr(); + for (int64_t row_index = 0; row_index < readout_message_producer_consumer_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kReadoutMessageProducerConsumerRowColumns; + if ( + row[3] == kReadoutMessageProducerConsumerStatusActive && + row[4] == 1 && + row[9] == kReadoutSurfaceOpcode && + row[10] == readout_executor.span.executor_row_index && + row[11] == readout_executor.span.executor_id && + row[12] == readout_executor.span.bucket_ordinal) { + return row; + } + } + TORCH_CHECK(false, name, " has no active readout/message producer-consumer row for readout route"); + return nullptr; +} + +inline void validate_message_transition_producer_consumer_rows( + const at::Tensor& message_transition_producer_consumer_rows, + int64_t schema_version, + const char* name) { + check_cpu_long_rank2( + message_transition_producer_consumer_rows, + "message_transition_producer_consumer_rows", + kMessageTransitionProducerConsumerRowColumns); + TORCH_CHECK( + message_transition_producer_consumer_rows.size(0) > 0, + name, + " requires compiler-owned message_transition_producer_consumer_rows"); + const int64_t* rows = message_transition_producer_consumer_rows.data_ptr(); + bool saw_active_materialized = false; + bool saw_active_streaming = false; + bool saw_streaming_candidate = false; + for (int64_t row_index = 0; row_index < message_transition_producer_consumer_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kMessageTransitionProducerConsumerRowColumns; + const int64_t strategy_opcode = row[2]; + const int64_t status_opcode = row[3]; + const int64_t executable = row[4]; + const int64_t 
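+      // Same row convention as the readout/message table: [2]=strategy, [3]=status, [4]=executable,
+      // [5..8]=producer (message) route, [9..12]=consumer (transition) route, [13]=aggregate access opcode,
+      // [14]=required role mask, [15]=blocker opcode.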
producer_surface_opcode = row[5]; + const int64_t consumer_surface_opcode = row[9]; + const int64_t aggregate_access_opcode = row[13]; + const int64_t required_role_mask = row[14]; + const int64_t blocker_opcode = row[15]; + TORCH_CHECK(row[0] == row_index, name, " message_transition_producer_consumer_rows must be densely indexed"); + TORCH_CHECK(row[1] == schema_version, name, " message_transition_producer_consumer_rows schema version mismatch"); + TORCH_CHECK( + strategy_opcode == kMessageTransitionProducerConsumerMaterializedRecurrentMessage || + strategy_opcode == kMessageTransitionProducerConsumerStreamToTransitionInput, + name, + " message_transition_producer_consumer_rows row ", + row_index, + " has unknown strategy opcode ", + strategy_opcode); + TORCH_CHECK( + status_opcode == kMessageTransitionProducerConsumerStatusActive || + status_opcode == kMessageTransitionProducerConsumerStatusCandidate || + status_opcode == kMessageTransitionProducerConsumerStatusBlocked, + name, + " message_transition_producer_consumer_rows row ", + row_index, + " has unknown status opcode ", + status_opcode); + TORCH_CHECK( + executable == 0 || executable == 1, + name, + " message_transition_producer_consumer_rows executable flag must be 0/1"); + TORCH_CHECK( + producer_surface_opcode == kMessageSurfaceOpcode, + name, + " message_transition_producer_consumer_rows producer surface must be message"); + TORCH_CHECK( + consumer_surface_opcode == kTransitionSurfaceOpcode, + name, + " message_transition_producer_consumer_rows consumer surface must be transition"); + TORCH_CHECK( + aggregate_access_opcode == kProgramAccessTransitionAggregatedMessageInput, + name, + " message_transition_producer_consumer_rows must target transition aggregate input access rows"); + TORCH_CHECK( + (required_role_mask & kMessageTransitionProducerConsumerRoleRecurrentMsg) != 0 && + (required_role_mask & kMessageTransitionProducerConsumerRoleTransitionAggregateBinding) != 0 && + (required_role_mask & kMessageTransitionProducerConsumerRoleMemoryLivenessRows) != 0 && + (required_role_mask & kMessageTransitionProducerConsumerRolePhysicalStrategyRows) != 0, + name, + " message_transition_producer_consumer_rows row ", + row_index, + " is missing required compiler route/liveness roles"); + TORCH_CHECK( + blocker_opcode == kMessageTransitionProducerConsumerBlockerNone || + blocker_opcode == kMessageTransitionProducerConsumerBlockerMissingExecutor || + blocker_opcode == kMessageTransitionProducerConsumerBlockerMultipleConsumersNeedMergeRows || + blocker_opcode == kMessageTransitionProducerConsumerBlockerReceiverCountMismatch || + blocker_opcode == kMessageTransitionProducerConsumerBlockerPendingDirectChunkBody || + blocker_opcode == kMessageTransitionProducerConsumerBlockerCostRejected, + name, + " message_transition_producer_consumer_rows row ", + row_index, + " has invalid blocker opcode"); + if (strategy_opcode == kMessageTransitionProducerConsumerStreamToTransitionInput) { + saw_streaming_candidate = true; + } + if (status_opcode == kMessageTransitionProducerConsumerStatusActive) { + TORCH_CHECK(executable == 1, name, " active message_transition_producer_consumer row must be executable"); + TORCH_CHECK( + blocker_opcode == kMessageTransitionProducerConsumerBlockerNone, + name, + " active message_transition_producer_consumer row is blocked"); + if (strategy_opcode == kMessageTransitionProducerConsumerMaterializedRecurrentMessage) { + saw_active_materialized = true; + } else if (strategy_opcode == 
kMessageTransitionProducerConsumerStreamToTransitionInput) { + saw_active_streaming = true; + } + } else { + TORCH_CHECK(executable == 0, name, " non-active message_transition_producer_consumer row must not be executable"); + } + } + TORCH_CHECK( + saw_active_materialized || saw_active_streaming, + name, + " message_transition_producer_consumer_rows missing active executable message/transition strategy row"); + TORCH_CHECK( + saw_streaming_candidate, + name, + " message_transition_producer_consumer_rows missing stream_message_to_transition_input row"); +} + +inline const int64_t* active_message_transition_producer_consumer_row_for_transition( + const at::Tensor& message_transition_producer_consumer_rows, + const RegisteredFusedProgramSpan& transition_span, + const char* name) { + check_cpu_long_rank2( + message_transition_producer_consumer_rows, + "message_transition_producer_consumer_rows", + kMessageTransitionProducerConsumerRowColumns); + const int64_t* rows = message_transition_producer_consumer_rows.data_ptr(); + for (int64_t row_index = 0; row_index < message_transition_producer_consumer_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kMessageTransitionProducerConsumerRowColumns; + if ( + row[3] == kMessageTransitionProducerConsumerStatusActive && + row[4] == 1 && + row[9] == kTransitionSurfaceOpcode && + row[10] == transition_span.executor_row_index && + row[11] == transition_span.executor_id && + row[12] == transition_span.bucket_ordinal) { + return row; + } + } + TORCH_CHECK(false, name, " has no active message/transition producer-consumer row for transition route"); + return nullptr; +} + +inline bool registered_message_transition_streaming_row_targets_message( + const at::Tensor& message_transition_producer_consumer_rows, + const RegisteredForwardMessageExecutorState& message_executor) { + check_cpu_long_rank2( + message_transition_producer_consumer_rows, + "message_transition_producer_consumer_rows", + kMessageTransitionProducerConsumerRowColumns); + const int64_t* rows = message_transition_producer_consumer_rows.data_ptr(); + for (int64_t row_index = 0; row_index < message_transition_producer_consumer_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kMessageTransitionProducerConsumerRowColumns; + if ( + row[2] == kMessageTransitionProducerConsumerStreamToTransitionInput && + row[3] == kMessageTransitionProducerConsumerStatusActive && + row[4] == 1 && + row[5] == kMessageSurfaceOpcode && + row[6] == message_executor.span.executor_row_index && + row[7] == message_executor.span.executor_id && + row[8] == message_executor.span.bucket_ordinal) { + return true; + } + } + return false; +} + +inline bool registered_message_transition_direct_input_supported_for_message( + const at::Tensor& message_transition_producer_consumer_rows, + const at::Tensor& forward_spans, + const at::Tensor& primitive_rows, + const RegisteredForwardMessageExecutorState& message_executor) { + check_cpu_long_rank2( + message_transition_producer_consumer_rows, + "message_transition_producer_consumer_rows", + kMessageTransitionProducerConsumerRowColumns); + const int64_t* rows = message_transition_producer_consumer_rows.data_ptr(); + for (int64_t row_index = 0; row_index < message_transition_producer_consumer_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kMessageTransitionProducerConsumerRowColumns; + if ( + row[2] != kMessageTransitionProducerConsumerStreamToTransitionInput || + row[3] != kMessageTransitionProducerConsumerStatusActive || + 
row[4] != 1 ||
+        row[5] != kMessageSurfaceOpcode ||
+        row[6] != message_executor.span.executor_row_index ||
+        row[7] != message_executor.span.executor_id ||
+        row[8] != message_executor.span.bucket_ordinal) {
+      continue;
+    }
+    for (int64_t span_index = 0; span_index < forward_spans.size(0); ++span_index) {
+      const RegisteredFusedProgramSpan span = registered_fused_program_span_at(forward_spans, span_index);
+      if (
+          span.surface_opcode != kTransitionSurfaceOpcode ||
+          span.executor_row_index != row[10] ||
+          span.executor_id != row[11] ||
+          span.bucket_ordinal != row[12]) {
+        continue;
+      }
+      const int64_t* primitives = primitive_rows.data_ptr<int64_t>();
+      for (int64_t local_primitive = 0; local_primitive < span.primitive_row_count; ++local_primitive) {
+        const int64_t primitive_row_index = span.primitive_row_start + local_primitive;
+        if (primitives[primitive_row_index * 4] == kPrimitiveGatedLogspaceRecurrenceOpcode) {
+          return true;
+        }
+      }
+      return false;
+    }
+  }
+  return false;
+}
+
+// Selects the message producer state whose executor route matches the active message/transition row.
+inline const RegisteredForwardMessageStepState& message_step_state_for_message_transition_row(
+    const std::vector<RegisteredForwardMessageStepState>& message_step_states,
+    const int64_t* producer_consumer_row,
+    const char* name) {
+  TORCH_CHECK(producer_consumer_row != nullptr, name, " missing message/transition producer-consumer row");
+  for (const RegisteredForwardMessageStepState& state : message_step_states) {
+    const RegisteredForwardMessageExecutorState& executor = *state.executor;
+    if (
+        producer_consumer_row[5] == kMessageSurfaceOpcode &&
+        producer_consumer_row[6] == executor.span.executor_row_index &&
+        producer_consumer_row[7] == executor.span.executor_id &&
+        producer_consumer_row[8] == executor.span.bucket_ordinal) {
+      return state;
+    }
+  }
+  TORCH_CHECK(false, name, " has no message producer state matching active message/transition row");
+  return message_step_states.front();
+}
+
+// Same selection, keyed by the active readout/message producer-consumer row.
+inline const RegisteredForwardMessageStepState& message_step_state_for_producer_consumer_row(
+    const std::vector<RegisteredForwardMessageStepState>& message_step_states,
+    const int64_t* producer_consumer_row,
+    const char* name) {
+  TORCH_CHECK(producer_consumer_row != nullptr, name, " missing readout/message producer-consumer row");
+  for (const RegisteredForwardMessageStepState& state : message_step_states) {
+    const RegisteredForwardMessageExecutorState& executor = *state.executor;
+    if (
+        producer_consumer_row[5] == kMessageSurfaceOpcode &&
+        producer_consumer_row[6] == executor.span.executor_row_index &&
+        producer_consumer_row[7] == executor.span.executor_id &&
+        producer_consumer_row[8] == executor.span.bucket_ordinal) {
+      return state;
+    }
+  }
+  TORCH_CHECK(false, name, " has no message producer state matching active readout/message row");
+  return message_step_states.front();
+}
+
+} // namespace
+
+// Forward driver for the registered temporal fused program: validates compiler-owned rows, binds
+// message/readout executors, runs the per-step transition program, and routes readout outputs into
+// the compiler-owned output buffer.
+std::vector<at::Tensor> flat_bucket_registered_temporal_fused_forward_program_cuda(
+    const at::Tensor& boundary_seq,
+    const at::Tensor& recurrent_hidden_initial_backend_order,
+    std::vector<at::Tensor> program_tensors,
+    const at::Tensor& program_tensor_binding_rows,
+    const at::Tensor& forward_program_access_rows,
+    const at::Tensor& forward_transition_state_carry_rows,
+    const at::Tensor& forward_artifact_route_rows,
+    const at::Tensor& forward_artifact_merge_rows,
+    const at::Tensor& forward_output_route_rows,
+    const at::Tensor& readout_message_producer_consumer_rows,
+    const at::Tensor& message_transition_producer_consumer_rows,
+    std::vector<at::Tensor> forward_reset_tensors,
+    const at::Tensor& forward_reset_rows,
+    const at::Tensor& primitive_rows,
+    const at::Tensor& forward_executor_rows,
+    const
at::Tensor& reverse_executor_rows, + const at::Tensor& forward_handler_rows, + const at::Tensor& reverse_handler_rows, + const at::Tensor& native_strategy_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& transition_primitive_callable_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& memory_liveness_rows, + const at::Tensor& memory_runtime_schedule_rows, + const at::Tensor& physical_strategy_rows, + std::vector runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + std::vector forward_program_runtime_tensors, + const at::Tensor& forward_program_runtime_rows, + bool return_final_program_tensors, + bool return_reverse_artifacts, + int64_t schema_version) { + check_cuda_float_sequence_bank(boundary_seq, "boundary_seq"); + check_cuda_float_bank(recurrent_hidden_initial_backend_order, "recurrent_hidden_initial_backend_order"); + check_program_tensor_binding_rows(program_tensor_binding_rows); + check_forward_program_access_rows(forward_program_access_rows); + check_forward_transition_state_carry_rows(forward_transition_state_carry_rows); + validate_forward_artifact_route_rows(forward_artifact_route_rows, schema_version); + validate_forward_artifact_merge_rows(forward_artifact_merge_rows, forward_artifact_route_rows, schema_version); + validate_forward_output_route_rows(forward_output_route_rows, schema_version); + validate_readout_message_producer_consumer_rows( + readout_message_producer_consumer_rows, + forward_output_route_rows, + schema_version, + "registered fused forward program"); + validate_message_transition_producer_consumer_rows( + message_transition_producer_consumer_rows, + schema_version, + "registered fused forward program"); + check_forward_program_runtime_rows(forward_program_runtime_tensors, forward_program_runtime_rows); + const at::Tensor recurrent_local_sender_idx = forward_program_runtime_tensor_for_role( + forward_program_runtime_tensors, + forward_program_runtime_rows, + kForwardRuntimeRecurrentLocalSenderIdx, + "registered fused forward recurrent_local_sender_idx"); + const at::Tensor output_local_sender_idx = forward_program_runtime_tensor_for_role( + forward_program_runtime_tensors, + forward_program_runtime_rows, + kForwardRuntimeOutputLocalSenderIdx, + "registered fused forward output_local_sender_idx"); + const at::Tensor local_distance = forward_program_runtime_tensor_for_role( + forward_program_runtime_tensors, + forward_program_runtime_rows, + kForwardRuntimeLocalDistance, + "registered fused forward local_distance"); + const at::Tensor local_delay = forward_program_runtime_tensor_for_role( + forward_program_runtime_tensors, + forward_program_runtime_rows, + kForwardRuntimeLocalDelay, + "registered fused forward local_delay"); + const int64_t inner_steps = forward_program_runtime_int_for_role( + forward_program_runtime_tensors, + forward_program_runtime_rows, + kForwardRuntimeInnerSteps, + "registered fused forward inner_steps"); + const bool output_boundary_terminal = forward_program_runtime_int_for_role( + forward_program_runtime_tensors, + forward_program_runtime_rows, + kForwardRuntimeOutputBoundaryTerminal, + "registered fused forward output_boundary_terminal") != 0; + const double distance_scale = forward_program_runtime_double_for_role( + forward_program_runtime_tensors, + forward_program_runtime_rows, + kForwardRuntimeDistanceScale, + "registered fused forward distance_scale"); + const 
int64_t head_dim = forward_program_runtime_int_for_role( + forward_program_runtime_tensors, + forward_program_runtime_rows, + kForwardRuntimeHeadDim, + "registered fused forward head_dim"); + const int64_t value_dim = forward_program_runtime_int_for_role( + forward_program_runtime_tensors, + forward_program_runtime_rows, + kForwardRuntimeValueDim, + "registered fused forward value_dim"); + const bool use_delay = forward_program_runtime_int_for_role( + forward_program_runtime_tensors, + forward_program_runtime_rows, + kForwardRuntimeUseDelay, + "registered fused forward use_delay") != 0; + check_cuda_int_rank2(recurrent_local_sender_idx, "recurrent_local_sender_idx"); + check_cuda_int_rank2(output_local_sender_idx, "output_local_sender_idx"); + check_cuda_float_rank1(local_distance, "local_distance"); + check_cuda_int_rank1(local_delay, "local_delay"); + TORCH_CHECK(inner_steps > 0, "registered fused forward program requires positive inner_steps"); + std::vector decoded = validate_registered_temporal_fused_program( + primitive_rows, + forward_executor_rows, + reverse_executor_rows, + forward_handler_rows, + reverse_handler_rows, + native_strategy_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + forward_executor_binding_rows, + reverse_executor_binding_rows, + memory_liveness_rows, + schema_version); + validate_registered_runtime_buffer_rows( + memory_liveness_rows, + runtime_buffer_tensors, + runtime_buffer_rows, + "registered fused forward program"); + validate_registered_memory_runtime_schedule_rows( + memory_liveness_rows, + memory_runtime_schedule_rows, + "registered fused forward program"); + validate_registered_physical_strategy_rows( + physical_strategy_rows, + memory_runtime_schedule_rows, + "registered fused forward program"); + const at::Tensor& forward_spans = decoded[1]; + validate_registered_fused_forward_program_dispatch(forward_spans, forward_program_access_rows); + validate_forward_transition_state_carry_contract(forward_spans, forward_transition_state_carry_rows); + const int64_t B = boundary_seq.size(0); + const int64_t outer_steps = boundary_seq.size(1); + const int64_t local_time_steps = outer_steps * inner_steps; + const int64_t input_count = boundary_seq.size(2); + const int64_t hidden = boundary_seq.size(3); + check_forward_reset_rows(forward_reset_tensors, forward_reset_rows, B, outer_steps); + TORCH_CHECK( + recurrent_hidden_initial_backend_order.size(0) == B && + recurrent_hidden_initial_backend_order.size(2) == hidden, + "recurrent_hidden_initial_backend_order must match boundary batch/hidden dimensions"); + const bool streaming_step_strategy = registered_physical_strategy_active_is_streaming( + physical_strategy_rows, + "registered fused forward program"); + validate_registered_forward_streaming_step_physical_contract( + physical_strategy_rows, + runtime_buffer_tensors, + runtime_buffer_rows, + return_final_program_tensors, + return_reverse_artifacts, + local_time_steps); + + const std::vector message_executors = + bind_registered_forward_message_executor_handlers( + forward_spans, + native_strategy_rows, + boundary_seq, + program_tensors, + program_tensor_binding_rows, + forward_program_access_rows, + native_callable_binding_schema_rows, + forward_executor_rows, + forward_executor_binding_rows, + input_count, + head_dim, + value_dim); + const std::vector readout_executors = + bind_registered_forward_readout_executor_handlers( + forward_spans, + native_strategy_rows, + program_tensors, + program_tensor_binding_rows, + 
forward_program_access_rows,
+          native_callable_binding_schema_rows);
+  // Output route row columns consumed here: [3]=readout executor row, [4]=executor id, [5]=bucket,
+  // [9]=compiler output offset; column [1] of the first row selects the route kind.
+  const int64_t* forward_route_rows = forward_artifact_route_rows.data_ptr<int64_t>();
+  const int64_t* output_route_rows = forward_output_route_rows.data_ptr<int64_t>();
+  std::vector<const RegisteredForwardReadoutExecutorState*> output_route_readouts;
+  std::vector<int64_t> output_route_counts;
+  std::vector<int64_t> output_route_offsets;
+  output_route_readouts.reserve(static_cast<size_t>(forward_output_route_rows.size(0)));
+  output_route_counts.reserve(static_cast<size_t>(forward_output_route_rows.size(0)));
+  output_route_offsets.reserve(static_cast<size_t>(forward_output_route_rows.size(0)));
+  for (int64_t route_index = 0; route_index < forward_output_route_rows.size(0); ++route_index) {
+    const int64_t* route = output_route_rows + route_index * kForwardOutputRouteRowColumns;
+    const RegisteredForwardReadoutExecutorState* selected_readout = nullptr;
+    for (const RegisteredForwardReadoutExecutorState& candidate : readout_executors) {
+      if (
+          candidate.span.executor_row_index == route[3] &&
+          candidate.span.executor_id == route[4] &&
+          candidate.span.bucket_ordinal == route[5]) {
+        TORCH_CHECK(
+            selected_readout == nullptr,
+            "registered fused forward program has duplicate readout executors for compiler output route");
+        selected_readout = &candidate;
+      }
+    }
+    TORCH_CHECK(
+        selected_readout != nullptr,
+        "registered fused forward program has no readout executor for compiler output route row ",
+        route_index);
+    const at::Tensor readout_output_q =
+        registered_forward_readout_tensor(*selected_readout, "readout_output_query");
+    output_route_readouts.push_back(selected_readout);
+    output_route_counts.push_back(readout_output_q.size(0));
+    output_route_offsets.push_back(route[9]);
+  }
+  const int64_t output_route_kind = output_route_rows[1];
+  int64_t output_count = output_route_counts[0];
+  if (output_route_kind == kForwardOutputRouteReadoutOutputConcat) {
+    output_count = 0;
+    for (int64_t route_index = 0; route_index < forward_output_route_rows.size(0); ++route_index) {
+      const int64_t* route = output_route_rows + route_index * kForwardOutputRouteRowColumns;
+      const int64_t route_output_count = output_route_counts[static_cast<size_t>(route_index)];
+      TORCH_CHECK(
+          route[9] == output_count,
+          "registered fused forward concat output route has invalid compiler output offset: row=",
+          route_index,
+          "; expected=",
+          output_count,
+          "; actual=",
+          route[9]);
+      TORCH_CHECK(route_output_count > 0, "registered fused forward concat output route has empty readout output");
+      output_count += route_output_count;
+    }
+  } else {
+    for (int64_t route_index = 0; route_index < forward_output_route_rows.size(0); ++route_index) {
+      const int64_t* route = output_route_rows + route_index * kForwardOutputRouteRowColumns;
+      const int64_t route_output_count = output_route_counts[static_cast<size_t>(route_index)];
+      TORCH_CHECK(
+          route[9] == 0,
+          "registered fused forward non-concat output route must use zero compiler output offset");
+      TORCH_CHECK(
+          route_output_count == output_count,
+          "registered fused forward output route requires matching readout output counts for select/sum semantics");
+    }
+  }
+  const int64_t output_steps = output_boundary_terminal ?
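+      // Terminal-boundary programs emit only the final outer step; otherwise one output row per outer step.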
1 : outer_steps; + const int64_t recurrent_count = recurrent_hidden_initial_backend_order.size(1); + const int64_t full_cell_count = input_count + recurrent_count + output_count; + at::Tensor output_seq = registered_runtime_buffer_for_role_any_shape( + runtime_buffer_tensors, + runtime_buffer_rows, + kRuntimeBufferRoleOutputSeq, + 0, + "registered fused forward program output_seq"); + TORCH_CHECK( + output_seq.dim() == 4 && + output_seq.size(0) == B && + output_seq.size(1) == output_steps && + output_seq.size(3) == hidden && + (output_seq.size(2) == output_count || output_seq.size(2) == 1), + "registered fused forward program output_seq shape does not match compiler output contract: output_seq=", + output_seq.sizes(), + "; output_count=", + output_count, + "; hidden=", + hidden); + bool has_forward_reset_tensor = false; + for (const at::Tensor& reset_tensor : forward_reset_tensors) { + if (reset_tensor.defined() && reset_tensor.numel() > 0) { + has_forward_reset_tensor = true; + break; + } + } + at::Tensor recurrent_hidden = recurrent_hidden_initial_backend_order.contiguous(); + int64_t output_index = 0; + const at::Tensor empty = boundary_seq.new_empty({0}); + const at::Tensor message_reset_seq = + forward_reset_tensor_for_kind(forward_reset_tensors, forward_reset_rows, kForwardResetMessage); + const at::Tensor transition_reset_seq = + forward_reset_tensor_for_kind(forward_reset_tensors, forward_reset_rows, kForwardResetTransition); + std::vector memory_stage_rows; + append_registered_forward_memory_stage_row( + &memory_stage_rows, + boundary_seq, + -1, + kRegisteredForwardMemoryStageEntry); + std::vector reverse_artifact_tensors; + std::vector reverse_artifact_binding_values; + auto reverse_artifact_tensor_for_storage = [&]( + int64_t role_id, + const at::Tensor& tensor) -> at::Tensor { + const at::Tensor contiguous = tensor.is_contiguous() ? 
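+    // State-before artifacts are cloned below, presumably because the transition program overwrites the
+    // underlying binding slot in place later in the step; other artifact roles keep the contiguous view.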
tensor : tensor.contiguous(); + if (role_id == kReverseArtifactTransitionStateBefore) { + return contiguous.clone(); + } + return contiguous; + }; + auto append_reverse_artifact = [&]( + int64_t surface_opcode, + int64_t executor_row_index, + int64_t executor_id, + int64_t bucket_ordinal, + int64_t role_id, + int64_t local_step, + int64_t flags, + const at::Tensor& tensor, + const char* subject) { + if (!return_reverse_artifacts) { + return; + } + TORCH_CHECK(tensor.defined(), "registered fused forward program cannot store undefined reverse artifact"); + const int64_t route_row = forward_artifact_route_row_for( + forward_artifact_route_rows, + surface_opcode, + executor_row_index, + executor_id, + bucket_ordinal, + role_id, + schema_version, + subject); + const int64_t* route = forward_route_rows + route_row * kForwardArtifactRouteRowColumns; + if (route[7] == 0) { + return; + } + const int64_t tensor_index = static_cast(reverse_artifact_tensors.size()); + reverse_artifact_binding_values.push_back(role_id); + reverse_artifact_binding_values.push_back(tensor_index); + reverse_artifact_binding_values.push_back(local_step); + reverse_artifact_binding_values.push_back(flags); + reverse_artifact_binding_values.push_back(route_row); + reverse_artifact_tensors.push_back(reverse_artifact_tensor_for_storage(role_id, tensor)); + }; + auto message_step_tensor_for_role = [&]( + const std::vector& message_step_states, + int64_t role_id, + const char* subject) -> at::Tensor { + const int64_t merge_row_index = forward_artifact_merge_row_for_surface_bucket_role( + forward_artifact_route_rows, + forward_artifact_merge_rows, + kMessageSurfaceOpcode, + kTemporalMessageBucketOrdinal, + role_id, + schema_version, + subject); + const int64_t* merge_rows = forward_artifact_merge_rows.data_ptr(); + const int64_t* merge_row = merge_rows + merge_row_index * kForwardArtifactMergeRowColumns; + const int64_t merge_kind = merge_row[4]; + const int64_t producer_route_row = merge_row[6]; + std::vector producer_tensors; + auto tensor_for_state = [&](const RegisteredForwardMessageStepState& state) -> at::Tensor { + if (role_id == kReverseArtifactInputK) { + return state.input_k_step; + } + if (role_id == kReverseArtifactInputV) { + return state.input_v_step; + } + if (role_id == kReverseArtifactRecurrentKBefore) { + return state.recurrent_k_before; + } + if (role_id == kReverseArtifactRecurrentVBefore) { + return state.recurrent_v_before; + } + if (role_id == kReverseArtifactRecurrentMsgBackendOrder) { + return state.recurrent_msg; + } + if (role_id == kReverseArtifactRecurrentK) { + return state.recurrent_k_after; + } + if (role_id == kReverseArtifactRecurrentV) { + return state.recurrent_v_after; + } + TORCH_CHECK(false, subject, " requested unsupported compiler message artifact role ", role_id); + return at::Tensor(); + }; + for (const RegisteredForwardMessageStepState& state : message_step_states) { + const RegisteredFusedProgramSpan& span = state.executor->span; + bool selected = false; + if (merge_kind == kForwardArtifactMergeIdentitySingleton) { + const int64_t* route = forward_route_rows + producer_route_row * kForwardArtifactRouteRowColumns; + selected = + span.executor_row_index == route[2] && + span.executor_id == route[3] && + span.bucket_ordinal == route[4]; + } else { + for (int64_t route_row_index = 0; route_row_index < forward_artifact_route_rows.size(0); ++route_row_index) { + const int64_t* route = forward_route_rows + route_row_index * kForwardArtifactRouteRowColumns; + if ( + route[1] == 
kMessageSurfaceOpcode && + route[4] == kTemporalMessageBucketOrdinal && + route[5] == role_id && + span.executor_row_index == route[2] && + span.executor_id == route[3] && + span.bucket_ordinal == route[4]) { + selected = true; + break; + } + } + } + if (selected) { + producer_tensors.push_back(tensor_for_state(state)); + } + } + TORCH_CHECK( + !producer_tensors.empty(), + subject, + " has no produced message artifact for compiler merge row ", + merge_row_index, + "; role=", + role_id); + if (merge_kind == kForwardArtifactMergeIdentitySingleton) { + TORCH_CHECK( + producer_tensors.size() == 1, + subject, + " identity_singleton merge selected ", + producer_tensors.size(), + " producers for role ", + role_id); + return producer_tensors[0]; + } + const at::Tensor& reference = producer_tensors[0]; + for (size_t producer_index = 1; producer_index < producer_tensors.size(); ++producer_index) { + const at::Tensor& tensor = producer_tensors[producer_index]; + TORCH_CHECK( + tensor.defined() && tensor.device() == reference.device() && tensor.scalar_type() == reference.scalar_type(), + subject, + " aggregate artifact producers must have matching device and dtype"); + TORCH_CHECK( + tensor.dim() == reference.dim(), + subject, + " aggregate artifact producers must have matching rank"); + for (int64_t dim = 0; dim < reference.dim(); ++dim) { + if (merge_kind == kForwardArtifactMergeConcatOrError && dim == reference.dim() - 1) { + continue; + } + TORCH_CHECK( + tensor.size(dim) == reference.size(dim), + subject, + " aggregate artifact producer shape mismatch at dim ", + dim, + ": reference=", + reference.sizes(), + "; tensor=", + tensor.sizes()); + } + } + if (merge_kind == kForwardArtifactMergeConcatOrError) { + return at::cat(producer_tensors, reference.dim() - 1).contiguous(); + } + if (merge_kind == kForwardArtifactMergeSumOrError) { + at::Tensor merged = reference.clone(); + for (size_t producer_index = 1; producer_index < producer_tensors.size(); ++producer_index) { + merged.add_(producer_tensors[producer_index]); + } + return merged.contiguous(); + } + TORCH_CHECK( + false, + subject, + " selected unsupported compiler artifact merge kind ", + merge_kind, + " for role ", + role_id); + return at::Tensor(); + }; + + for (int64_t physical_step = 0; physical_step < local_time_steps; ++physical_step) { + const int64_t outer_step = physical_step / inner_steps; + const int64_t inner_step = physical_step - outer_step * inner_steps; + const int64_t message_step = inner_step + 1; + const at::Tensor message_reset_for_outer = forward_reset_step_tensor(message_reset_seq, outer_step); + const at::Tensor message_reset = inner_step == 0 ? 
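+    // Resets fire only on the first inner step of each outer step; transition resets fall back to the
+    // message reset when the transition reset tensor is absent.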
message_reset_for_outer : at::Tensor(); + at::Tensor transition_reset = forward_reset_step_tensor(transition_reset_seq, outer_step); + if (!transition_reset.defined() || transition_reset.numel() == 0) { + transition_reset = message_reset_for_outer; + } + if (message_reset.defined() && message_reset.numel() > 0) { + recurrent_hidden = zero_batch_rows_for_reset( + recurrent_hidden, + message_reset, + "registered fused forward message recurrent hidden reset"); + } + at::Tensor boundary_step = boundary_seq.select(1, outer_step).contiguous(); + at::Tensor recurrent_hidden_before = recurrent_hidden.contiguous(); + if (return_reverse_artifacts) { + append_reverse_artifact( + kRuntimePolicySurfaceOpcode, + -1, + 0, + -1, + kReverseArtifactBoundaryStep, + physical_step, + 0, + boundary_step, + "registered fused forward boundary_step artifact"); + if (registered_runtime_buffer_has_role( + runtime_buffer_rows, + kRuntimeBufferRoleForwardCellsPrevArtifact, + physical_step)) { + at::Tensor cells_prev_artifact = registered_runtime_buffer_for_role( + runtime_buffer_tensors, + runtime_buffer_rows, + kRuntimeBufferRoleForwardCellsPrevArtifact, + physical_step, + {B, full_cell_count, hidden}, + "registered fused forward program cells_prev artifact"); + cells_prev_artifact.zero_(); + append_reverse_artifact( + kRuntimePolicySurfaceOpcode, + -1, + 0, + -1, + kReverseArtifactCellsPrev, + physical_step, + 0, + cells_prev_artifact, + "registered fused forward cells_prev artifact"); + } + } + at::Tensor step_flat = registered_runtime_buffer_for_role( + runtime_buffer_tensors, + runtime_buffer_rows, + kRuntimeBufferRoleForwardMessageStepFlat, + 0, + {B}, + "registered fused forward program message-step workspace"); + step_flat.fill_(message_step); + std::vector message_step_states; + message_step_states.reserve(message_executors.size()); + at::Tensor streaming_recurrent_hidden_after_alias; + for (const RegisteredForwardMessageExecutorState& message_executor : message_executors) { + append_reverse_artifact( + kMessageSurfaceOpcode, + message_executor.span.executor_row_index, + message_executor.span.executor_id, + message_executor.span.bucket_ordinal, + kReverseArtifactRecurrentHiddenBeforeBackendOrder, + physical_step, + 0, + recurrent_hidden_before, + "registered fused forward recurrent_hidden_before artifact"); + at::Tensor input_k_seq = registered_forward_message_cache_tensor(message_executor, "input_k_seq"); + at::Tensor input_v_seq = registered_forward_message_cache_tensor(message_executor, "input_v_seq"); + at::Tensor input_k_step = input_k_seq.select(0, outer_step).contiguous(); + at::Tensor input_v_step = input_v_seq.select(0, outer_step).contiguous(); + append_registered_forward_memory_stage_row( + &memory_stage_rows, + boundary_seq, + physical_step, + kRegisteredForwardMemoryStageAfterInputKv); + append_reverse_artifact( + kMessageSurfaceOpcode, + message_executor.span.executor_row_index, + message_executor.span.executor_id, + message_executor.span.bucket_ordinal, + kReverseArtifactInputK, + physical_step, + 0, + input_k_step, + "registered fused forward input_k artifact"); + append_reverse_artifact( + kMessageSurfaceOpcode, + message_executor.span.executor_row_index, + message_executor.span.executor_id, + message_executor.span.bucket_ordinal, + kReverseArtifactInputV, + physical_step, + 0, + input_v_step, + "registered fused forward input_v artifact"); + std::vector recurrent_kv = run_registered_forward_message_recurrent_kv_handler( + message_executor, + recurrent_hidden_before, + empty, + 
forward_executor_rows, + forward_executor_binding_rows, + head_dim, + value_dim, + return_reverse_artifacts); + at::Tensor recurrent_k_before = recurrent_kv[0]; + at::Tensor recurrent_v_before = recurrent_kv[1]; + append_registered_forward_memory_stage_row( + &memory_stage_rows, + boundary_seq, + physical_step, + kRegisteredForwardMemoryStageAfterRecurrentKvBefore); + append_reverse_artifact( + kMessageSurfaceOpcode, + message_executor.span.executor_row_index, + message_executor.span.executor_id, + message_executor.span.bucket_ordinal, + kReverseArtifactRecurrentKBefore, + physical_step, + 0, + recurrent_k_before, + "registered fused forward recurrent_k_before artifact"); + append_reverse_artifact( + kMessageSurfaceOpcode, + message_executor.span.executor_row_index, + message_executor.span.executor_id, + message_executor.span.bucket_ordinal, + kReverseArtifactRecurrentVBefore, + physical_step, + 0, + recurrent_v_before, + "registered fused forward recurrent_v_before artifact"); + at::Tensor recurrent_msg_output_override; + const RegisteredForwardMessageCarrierStrategy& message_strategy = + registered_forward_message_carrier_strategy_for_native_row(message_executor.native_strategy); + const bool message_transition_streaming_row_allowed = + streaming_step_strategy && + !return_reverse_artifacts && + !return_final_program_tensors && + registered_message_transition_streaming_row_targets_message( + message_transition_producer_consumer_rows, + message_executor); + const bool message_transition_direct_input_allowed = + message_transition_streaming_row_allowed && + message_strategy.stream_transition_input != nullptr && + registered_message_transition_direct_input_supported_for_message( + message_transition_producer_consumer_rows, + forward_spans, + primitive_rows, + message_executor); + const bool message_transition_streaming_alias_allowed = + message_transition_streaming_row_allowed && + !message_transition_direct_input_allowed && + recurrent_count > 0 && + recurrent_count == message_executor.span.receiver_count && + message_executor.span.receiver_start == 0 && + message_executor.message_output_dim == hidden; + if (message_transition_streaming_alias_allowed) { + if (!streaming_recurrent_hidden_after_alias.defined()) { + streaming_recurrent_hidden_after_alias = registered_runtime_buffer_for_role( + runtime_buffer_tensors, + runtime_buffer_rows, + kRuntimeBufferRoleForwardRecurrentHiddenAfter, + physical_step, + {B, recurrent_count, hidden}, + "registered fused forward streaming recurrent message/hidden-after alias"); + } + recurrent_msg_output_override = streaming_recurrent_hidden_after_alias; + } + at::Tensor recurrent_msg = message_transition_direct_input_allowed + ? 
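+          // With an active direct message-to-transition streaming row, the materialized per-step recurrent
+          // message is skipped (empty sentinel); the transition consumer calls stream_transition_input instead.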
input_v_step.new_empty({0}) + : run_registered_forward_message_carrier_handler( + message_executor, + input_k_step, + input_v_step, + recurrent_hidden_before, + recurrent_k_before, + recurrent_v_before, + recurrent_local_sender_idx, + local_distance, + local_delay, + step_flat, + recurrent_msg_output_override, + runtime_buffer_tensors, + runtime_buffer_rows, + forward_executor_rows, + forward_executor_binding_rows, + physical_step, + distance_scale, + use_delay, + &memory_stage_rows, + physical_step); + append_registered_forward_memory_stage_row( + &memory_stage_rows, + boundary_seq, + physical_step, + kRegisteredForwardMemoryStageAfterRecurrentMessage); + append_reverse_artifact( + kMessageSurfaceOpcode, + message_executor.span.executor_row_index, + message_executor.span.executor_id, + message_executor.span.bucket_ordinal, + kReverseArtifactRecurrentMsgBackendOrder, + physical_step, + 0, + recurrent_msg, + "registered fused forward recurrent_msg artifact"); + if (!return_reverse_artifacts) { + recurrent_k_before = empty_program_tensor_like(recurrent_k_before); + recurrent_v_before = empty_program_tensor_like(recurrent_v_before); + } + message_step_states.push_back(RegisteredForwardMessageStepState{ + &message_executor, + input_k_step, + input_v_step, + recurrent_k_before, + recurrent_v_before, + recurrent_msg, + at::Tensor(), + at::Tensor(), + }); + } + at::Tensor recurrent_msg = message_step_tensor_for_role( + message_step_states, + kReverseArtifactRecurrentMsgBackendOrder, + "registered fused forward transition recurrent message input"); + + std::set transition_state_before_artifact_buckets; + for (int64_t span_index = 0; span_index < forward_spans.size(0); ++span_index) { + const RegisteredFusedProgramSpan span = registered_fused_program_span_at(forward_spans, span_index); + const RegisteredForwardExecutorHandler& handler = registered_forward_executor_handler_for_span(span); + if (!handler.runs_transition_program) { + continue; + } + const int64_t bucket_start = span.receiver_start; + const int64_t bucket_count = span.receiver_count; + const int64_t aggregate_input_binding = temporal_program_access_binding_by_opcode( + forward_program_access_rows, + kProgramAccessTransitionAggregatedMessageInput, + span.executor_row_index, + span.bucket_ordinal, + false, + "forward", + "registered fused forward transition aggregate input"); + if (aggregate_input_binding >= 0) { + const int64_t* producer_consumer_row = active_message_transition_producer_consumer_row_for_transition( + message_transition_producer_consumer_rows, + span, + "registered fused forward transition aggregate input"); + if (producer_consumer_row[2] == kMessageTransitionProducerConsumerStreamToTransitionInput) { + const RegisteredForwardMessageStepState& producer_state = message_step_state_for_message_transition_row( + message_step_states, + producer_consumer_row, + "registered fused forward transition aggregate input"); + if (!producer_state.recurrent_msg.defined() || producer_state.recurrent_msg.numel() == 0) { + TORCH_CHECK( + bucket_start == 0 && bucket_count == producer_state.executor->span.receiver_count, + "registered fused forward message_transition_producer_consumer_rows selected direct transition input " + "for a non-singleton transition span; explicit merge/chunk rows are required"); + at::Tensor transition_input_output_override; + if (bucket_count == recurrent_count) { + if (!streaming_recurrent_hidden_after_alias.defined()) { + streaming_recurrent_hidden_after_alias = registered_runtime_buffer_for_role( + 
runtime_buffer_tensors, + runtime_buffer_rows, + kRuntimeBufferRoleForwardRecurrentHiddenAfter, + physical_step, + {B, recurrent_count, hidden}, + "registered fused forward streamed transition input/public-state alias"); + } + transition_input_output_override = streaming_recurrent_hidden_after_alias; + } + RegisteredTransitionInputProjectionTarget transition_target = + registered_transition_input_projection_target_for_span( + program_tensors, + program_tensor_binding_rows, + runtime_buffer_tensors, + runtime_buffer_rows, + primitive_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + transition_primitive_callable_rows, + forward_executor_binding_rows, + span, + producer_state.input_v_step.size(0), + producer_state.executor->message_output_dim, + "registered fused forward message-to-transition input projection", + transition_input_output_override); + at::Tensor transition_input = run_registered_forward_message_stream_transition_input_handler( + *producer_state.executor, + producer_state.input_k_step, + producer_state.input_v_step, + recurrent_hidden_before, + producer_state.recurrent_k_before, + producer_state.recurrent_v_before, + recurrent_local_sender_idx, + local_distance, + local_delay, + step_flat, + transition_target, + runtime_buffer_tensors, + runtime_buffer_rows, + forward_executor_rows, + forward_executor_binding_rows, + physical_step, + distance_scale, + use_delay, + &memory_stage_rows, + physical_step); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + transition_target.output_binding, + transition_input, + "registered fused forward message-to-transition input projection"); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + aggregate_input_binding, + empty_program_tensor_like(transition_input), + "registered fused forward streamed transition aggregate sentinel"); + } else { + TORCH_CHECK( + bucket_start == 0 && bucket_count == producer_state.recurrent_msg.size(1), + "registered fused forward message_transition_producer_consumer_rows selected direct transition input " + "for a non-singleton transition span; explicit merge/chunk rows are required"); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + aggregate_input_binding, + producer_state.recurrent_msg, + "registered fused forward transition aggregate input"); + } + } else { + at::Tensor aggregate_input = recurrent_msg.slice(1, bucket_start, bucket_start + bucket_count).contiguous(); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + aggregate_input_binding, + aggregate_input, + "registered fused forward transition aggregate input"); + } + } + zero_forward_transition_state_inputs_for_reset( + program_tensors, + program_tensor_binding_rows, + forward_transition_state_carry_rows, + span.bucket_ordinal, + transition_reset); + if ( + return_reverse_artifacts && + transition_state_before_artifact_buckets.insert(span.bucket_ordinal).second) { + const int64_t* carry_rows = forward_transition_state_carry_rows.data_ptr(); + for (int64_t carry_index = 0; carry_index < forward_transition_state_carry_rows.size(0); ++carry_index) { + const int64_t* carry_row = carry_rows + carry_index * 3; + if (carry_row[0] != span.bucket_ordinal) { + continue; + } + const int64_t input_binding = carry_row[1]; + const bool reverse_consumes_input = + registered_reverse_transition_input_binding_consumes_forward_binding( + reverse_executor_binding_rows, span.bucket_ordinal, input_binding); + if 
(!reverse_consumes_input) { + continue; + } + append_reverse_artifact( + kTransitionSurfaceOpcode, + span.executor_row_index, + span.executor_id, + span.bucket_ordinal, + kReverseArtifactTransitionStateBefore, + physical_step, + span.bucket_ordinal * kTransitionStateArtifactFlagStride + input_binding, + program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding, + "registered fused forward transition state-before artifact"), + "registered fused forward transition_state_before artifact"); + } + } + } + + const bool terminal_local_transition_state = + streaming_step_strategy && + !return_reverse_artifacts && + !return_final_program_tensors && + physical_step + 1 == local_time_steps; + program_tensors = flat_bucket_registered_temporal_fused_forward_transition_program_cuda( + program_tensors, + program_tensor_binding_rows, + runtime_buffer_tensors, + runtime_buffer_rows, + primitive_rows, + forward_executor_rows, + forward_handler_rows, + native_strategy_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + transition_primitive_callable_rows, + forward_executor_binding_rows, + memory_liveness_rows, + forward_transition_state_carry_rows, + !return_reverse_artifacts && !return_final_program_tensors, + terminal_local_transition_state, + schema_version); + append_registered_forward_memory_stage_row( + &memory_stage_rows, + boundary_seq, + physical_step, + kRegisteredForwardMemoryStageAfterTransition); + + at::Tensor next_recurrent_hidden = + streaming_recurrent_hidden_after_alias.defined() + ? streaming_recurrent_hidden_after_alias + : recurrent_count > 0 + ? registered_runtime_buffer_for_role( + runtime_buffer_tensors, + runtime_buffer_rows, + kRuntimeBufferRoleForwardRecurrentHiddenAfter, + physical_step, + {B, recurrent_count, hidden}, + "registered fused forward program recurrent hidden after") + : recurrent_hidden; + std::set transition_public_output_buckets; + std::set transition_required_output_buckets; + for (int64_t span_index = 0; span_index < forward_spans.size(0); ++span_index) { + const RegisteredFusedProgramSpan span = registered_fused_program_span_at(forward_spans, span_index); + const RegisteredForwardExecutorHandler& handler = registered_forward_executor_handler_for_span(span); + if (!handler.runs_transition_program) { + continue; + } + transition_required_output_buckets.insert(span.bucket_ordinal); + const int64_t bucket_start = span.receiver_start; + const int64_t bucket_count = span.receiver_count; + const int64_t public_output_binding = temporal_program_access_binding_by_opcode( + forward_program_access_rows, + kProgramAccessTransitionPublicStateOutput, + span.executor_row_index, + span.bucket_ordinal, + false, + "forward", + "registered fused forward transition public output"); + if (public_output_binding >= 0) { + at::Tensor public_y = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + public_output_binding, + "registered fused forward transition public output"); + at::Tensor public_state_target = next_recurrent_hidden.slice(1, bucket_start, bucket_start + bucket_count); + if (public_state_target.data_ptr() != public_y.data_ptr()) { + public_state_target.copy_(public_y); + } + transition_public_output_buckets.insert(span.bucket_ordinal); + } + if (!terminal_local_transition_state) { + apply_forward_transition_state_carry_rows( + program_tensors, + program_tensor_binding_rows, + forward_executor_binding_rows, + forward_transition_state_carry_rows, + span.bucket_ordinal); + } + } + for 
(const int64_t bucket_ordinal : transition_required_output_buckets) { + TORCH_CHECK( + transition_public_output_buckets.count(bucket_ordinal) > 0, + "registered fused forward transition program produced no public state output for bucket ", + bucket_ordinal); + } + clear_forward_transition_output_binding_slots( + program_tensors, + program_tensor_binding_rows, + forward_executor_binding_rows); + if (streaming_step_strategy && !return_reverse_artifacts) { + for (RegisteredForwardMessageStepState& state : message_step_states) { + state.recurrent_msg = empty; + } + recurrent_msg = empty; + append_registered_forward_memory_stage_row( + &memory_stage_rows, + boundary_seq, + physical_step, + kRegisteredForwardMemoryStageAfterStreamingMessageRelease); + } + recurrent_hidden = next_recurrent_hidden; + bool recurrent_kv_after_materialized = false; + bool recurrent_k_after_materialized = false; + auto materialize_recurrent_kv_after = [&](bool materialize_key_bank) { + if (recurrent_kv_after_materialized && (!materialize_key_bank || recurrent_k_after_materialized)) { + return; + } + for (RegisteredForwardMessageStepState& state : message_step_states) { + const RegisteredForwardMessageExecutorState& message_executor = *state.executor; + append_reverse_artifact( + kMessageSurfaceOpcode, + message_executor.span.executor_row_index, + message_executor.span.executor_id, + message_executor.span.bucket_ordinal, + kReverseArtifactRecurrentHiddenBackendOrder, + physical_step, + 0, + recurrent_hidden, + "registered fused forward recurrent_hidden artifact"); + std::vector recurrent_kv = run_registered_forward_message_recurrent_kv_handler( + message_executor, + recurrent_hidden, + empty, + forward_executor_rows, + forward_executor_binding_rows, + head_dim, + value_dim, + materialize_key_bank); + state.recurrent_k_after = recurrent_kv[0]; + state.recurrent_v_after = recurrent_kv[1]; + append_reverse_artifact( + kMessageSurfaceOpcode, + message_executor.span.executor_row_index, + message_executor.span.executor_id, + message_executor.span.bucket_ordinal, + kReverseArtifactRecurrentK, + physical_step, + 0, + state.recurrent_k_after, + "registered fused forward recurrent_k artifact"); + append_reverse_artifact( + kMessageSurfaceOpcode, + message_executor.span.executor_row_index, + message_executor.span.executor_id, + message_executor.span.bucket_ordinal, + kReverseArtifactRecurrentV, + physical_step, + 0, + state.recurrent_v_after, + "registered fused forward recurrent_v artifact"); + } + recurrent_kv_after_materialized = true; + recurrent_k_after_materialized = materialize_key_bank; + append_registered_forward_memory_stage_row( + &memory_stage_rows, + boundary_seq, + physical_step, + kRegisteredForwardMemoryStageAfterRecurrentKvAfter); + }; + if (return_reverse_artifacts) { + materialize_recurrent_kv_after(true); + } + + const bool emit_output = inner_step + 1 == inner_steps; + if (return_reverse_artifacts || (emit_output && (!output_boundary_terminal || outer_step + 1 == outer_steps))) { + at::Tensor output_step; + if (emit_output && (!output_boundary_terminal || outer_step + 1 == outer_steps)) { + output_step = output_seq.select(1, output_index); + } + const bool output_route_can_stream_into_output_seq = + streaming_step_strategy && + !has_forward_reset_tensor && + !return_reverse_artifacts && + output_step.defined() && + output_step.dim() == 3 && + output_step.size(0) == B && + output_step.size(1) == output_count && + output_step.size(2) == hidden && + ( + output_route_kind == kForwardOutputRouteReadoutOutputConcat 
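// Illustrative sketch only: the output-emission gate used in this step loop,
// isolated from the fused program. Outputs are produced on the last inner
// step, and with a terminal output boundary only on the last outer step. The
// helper name is hypothetical; the condition mirrors the surrounding code.
inline bool example_should_emit_output(
    int64_t inner_step, int64_t inner_steps,
    int64_t outer_step, int64_t outer_steps,
    bool output_boundary_terminal) {
  const bool last_inner_step = inner_step + 1 == inner_steps;
  const bool last_outer_step = outer_step + 1 == outer_steps;
  return last_inner_step && (!output_boundary_terminal || last_outer_step);
}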
|| + forward_output_route_rows.size(0) == 1); + std::vector output_cell_routes; + output_cell_routes.reserve( + output_route_can_stream_into_output_seq ? 0 : output_route_readouts.size()); + std::set> emitted_readout_artifact_routes; + for (size_t route_index = 0; route_index < output_route_readouts.size(); ++route_index) { + const RegisteredForwardReadoutExecutorState* readout_executor_ptr = output_route_readouts[route_index]; + const RegisteredForwardReadoutExecutorState& readout_executor = *readout_executor_ptr; + const int64_t* producer_consumer_row = active_readout_message_producer_consumer_row_for_route( + readout_message_producer_consumer_rows, + readout_executor, + "registered fused forward readout output route"); + const RegisteredForwardMessageStepState& producer_state = message_step_state_for_producer_consumer_row( + message_step_states, + producer_consumer_row, + "registered fused forward readout output route"); + const at::Tensor input_k_step = producer_state.input_k_step; + const at::Tensor input_v_step = producer_state.input_v_step; + at::Tensor output_msg; + at::Tensor output_cells; + if (producer_consumer_row[2] == kReadoutMessageProducerConsumerStreamFromMessageProjection) { + output_msg = run_registered_forward_message_stream_readout_handler( + *producer_state.executor, + readout_executor, + input_k_step, + input_v_step, + recurrent_hidden, + output_local_sender_idx, + local_distance, + local_delay, + step_flat, + runtime_buffer_tensors, + runtime_buffer_rows, + forward_executor_rows, + forward_executor_binding_rows, + physical_step, + distance_scale, + use_delay); + TORCH_CHECK( + output_msg.defined(), + "registered fused forward streaming readout selected an unsupported message/readout row"); + } else { + materialize_recurrent_kv_after(true); + at::Tensor recurrent_k = producer_state.recurrent_k_after; + at::Tensor recurrent_v = producer_state.recurrent_v_after; + output_msg = run_registered_forward_readout_message_handler( + readout_executor, + input_k_step, + input_v_step, + recurrent_k, + recurrent_v, + output_local_sender_idx, + local_distance, + local_delay, + step_flat, + runtime_buffer_tensors, + runtime_buffer_rows, + forward_executor_rows, + forward_executor_binding_rows, + physical_step, + distance_scale, + use_delay); + } + append_registered_forward_memory_stage_row( + &memory_stage_rows, + boundary_seq, + physical_step, + kRegisteredForwardMemoryStageAfterReadoutMessage); + if (output_route_can_stream_into_output_seq) { + const int64_t route_output_count = output_route_counts[route_index]; + const int64_t route_output_offset = output_route_offsets[route_index]; + TORCH_CHECK( + route_output_offset >= 0 && route_output_count >= 0 && + route_output_offset + route_output_count <= output_step.size(1), + "registered fused forward output-route streaming projection has invalid compiler output span"); + at::Tensor output_route_target = output_step.narrow(1, route_output_offset, route_output_count); + output_cells = run_registered_forward_readout_projection_into_handler( + readout_executor, + output_msg, + output_route_target, + forward_executor_rows, + forward_executor_binding_rows); + } else { + output_cells = run_registered_forward_readout_projection_handler( + readout_executor, + output_msg, + runtime_buffer_tensors, + runtime_buffer_rows, + forward_executor_rows, + forward_executor_binding_rows, + physical_step); + output_cell_routes.push_back(output_cells); + } + append_registered_forward_memory_stage_row( + &memory_stage_rows, + boundary_seq, + physical_step, 
+ kRegisteredForwardMemoryStageAfterReadoutProjection); + const auto readout_key = std::make_tuple( + readout_executor.span.executor_row_index, + readout_executor.span.executor_id, + readout_executor.span.bucket_ordinal); + if (emitted_readout_artifact_routes.insert(readout_key).second) { + append_reverse_artifact( + kReadoutSurfaceOpcode, + readout_executor.span.executor_row_index, + readout_executor.span.executor_id, + readout_executor.span.bucket_ordinal, + kReverseArtifactOutputMsg, + physical_step, + 0, + output_msg, + "registered fused forward output_msg artifact"); + append_reverse_artifact( + kReadoutSurfaceOpcode, + readout_executor.span.executor_row_index, + readout_executor.span.executor_id, + readout_executor.span.bucket_ordinal, + kReverseArtifactOutputCells, + physical_step, + 0, + output_cells, + "registered fused forward output_cells artifact"); + } + } + if (output_route_can_stream_into_output_seq) { + ++output_index; + append_registered_forward_memory_stage_row( + &memory_stage_rows, + boundary_seq, + physical_step, + kRegisteredForwardMemoryStageAfterOutputRoute); + continue; + } + TORCH_CHECK(!output_cell_routes.empty(), "registered fused forward program produced no output route tensors"); + at::Tensor output_cells = output_cell_routes[0]; + if (output_route_kind == kForwardOutputRouteReadoutOutputConcat) { + for (size_t route_index = 1; route_index < output_cell_routes.size(); ++route_index) { + const at::Tensor& candidate = output_cell_routes[route_index]; + TORCH_CHECK( + candidate.dim() == output_cells.dim() && + candidate.size(0) == output_cells.size(0) && + candidate.size(2) == output_cells.size(2), + "registered fused forward concat output route has incompatible output cell shapes"); + } + output_cells = at::cat(output_cell_routes, 1).contiguous(); + } else if (output_route_kind == kForwardOutputRouteReadoutOutputSum) { + output_cells = output_cell_routes[0].clone(); + for (size_t route_index = 1; route_index < output_cell_routes.size(); ++route_index) { + TORCH_CHECK( + output_cell_routes[route_index].sizes() == output_cells.sizes(), + "registered fused forward sum output route has incompatible output cell shapes"); + output_cells.add_(output_cell_routes[route_index]); + } + output_cells = output_cells.contiguous(); + } else { + TORCH_CHECK( + output_cell_routes.size() == 1, + "registered fused forward select/output-cells route selected multiple readout producers"); + } + if (emit_output && (!output_boundary_terminal || outer_step + 1 == outer_steps)) { + if (output_step.sizes() == output_cells.sizes()) { + output_step.copy_(output_cells); + } else if ( + output_step.dim() == 3 && + output_step.size(0) == output_cells.size(0) && + output_step.size(1) == 1 && + output_step.size(2) == output_cells.size(2)) { + output_step.copy_(output_cells.mean(1, true)); + } else { + TORCH_CHECK( + false, + "registered fused forward program output contract shape mismatch: output_step=", + output_step.sizes(), + "; output_cells=", + output_cells.sizes()); + } + ++output_index; + } + append_registered_forward_memory_stage_row( + &memory_stage_rows, + boundary_seq, + physical_step, + kRegisteredForwardMemoryStageAfterOutputRoute); + } + } + TORCH_CHECK(output_index == output_steps, "registered fused forward program emitted unexpected output count"); + compact_forward_program_tensor_table_for_return( + program_tensors, + program_tensor_binding_rows, + forward_transition_state_carry_rows, + return_final_program_tensors); + append_registered_forward_memory_stage_row( + 
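// Illustrative sketch only: the three output-route reductions used above
// (concat, sum, single-producer select), isolated from the fused program. The
// route-kind constants are the ones named in this file; the helper itself is
// hypothetical.
inline at::Tensor example_reduce_output_routes(
    const std::vector<at::Tensor>& routes, int64_t route_kind) {
  TORCH_CHECK(!routes.empty(), "example: no output route tensors");
  if (route_kind == kForwardOutputRouteReadoutOutputConcat) {
    return at::cat(routes, 1).contiguous();  // concatenate along the cell axis
  }
  if (route_kind == kForwardOutputRouteReadoutOutputSum) {
    at::Tensor sum = routes[0].clone();  // keep the first route tensor intact
    for (size_t route = 1; route < routes.size(); ++route) {
      sum.add_(routes[route]);
    }
    return sum.contiguous();
  }
  TORCH_CHECK(routes.size() == 1, "example: select route expects one producer");
  return routes[0];
}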
&memory_stage_rows, + boundary_seq, + -1, + kRegisteredForwardMemoryStageAfterTensorCompaction); + std::vector outputs; + outputs.reserve(static_cast(program_tensors.size()) + 3); + outputs.push_back(output_seq); + outputs.push_back( + (return_final_program_tensors || return_reverse_artifacts) + ? recurrent_hidden + : empty_program_tensor_like(recurrent_hidden)); + outputs.insert(outputs.end(), program_tensors.begin(), program_tensors.end()); + append_registered_forward_memory_stage_row( + &memory_stage_rows, + boundary_seq, + -1, + kRegisteredForwardMemoryStageReturn); + outputs.push_back(registered_forward_memory_stage_rows_tensor(memory_stage_rows)); + if (return_reverse_artifacts) { + const int64_t row_count = + static_cast(reverse_artifact_binding_values.size() / kReverseArtifactBindingRowColumns); + at::Tensor binding_rows = at::empty( + {row_count, kReverseArtifactBindingRowColumns}, + at::TensorOptions().dtype(at::kLong).device(at::kCPU)); + if (row_count > 0) { + std::copy( + reverse_artifact_binding_values.begin(), + reverse_artifact_binding_values.end(), + binding_rows.data_ptr()); + } + outputs.push_back(binding_rows); + outputs.insert(outputs.end(), reverse_artifact_tensors.begin(), reverse_artifact_tensors.end()); + } + return outputs; +} diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/layout_kernels.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/layout_kernels.cuh new file mode 100644 index 00000000..44d2f358 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/layout_kernels.cuh @@ -0,0 +1,330 @@ +#pragma once + +__global__ void registered_backward_sender_kv_sender_kernel( + const float* __restrict__ grad_k, + const float* __restrict__ grad_v, + const float* __restrict__ direct_weight, + const float* __restrict__ grouped_weight, + float* __restrict__ grad_sender, + int64_t total_elements, + int sender_count, + int hidden_dim, + int head_dim, + int value_dim, + int kv_dim, + int group_size, + bool use_grouped_weight, + bool has_grad_k, + bool has_grad_v) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int h = static_cast(linear % hidden_dim); + const int sender = static_cast((linear / hidden_dim) % sender_count); + const int b = static_cast(linear / (static_cast(hidden_dim) * sender_count)); + const float* weight = nullptr; + if (use_grouped_weight) { + const int group = sender / group_size; + weight = grouped_weight + (static_cast(group) * hidden_dim * kv_dim); + } else { + weight = direct_weight + (static_cast(sender) * hidden_dim * kv_dim); + } + float acc = 0.0f; + if (has_grad_k) { + for (int d = 0; d < head_dim; ++d) { + acc += grad_k[(static_cast(b) * sender_count + sender) * head_dim + d] * + weight[static_cast(h) * kv_dim + d]; + } + } + if (has_grad_v) { + for (int d = 0; d < value_dim; ++d) { + acc += grad_v[(static_cast(b) * sender_count + sender) * value_dim + d] * + weight[static_cast(h) * kv_dim + head_dim + d]; + } + } + grad_sender[linear] = acc; + } +} + +__global__ void registered_backward_sender_kv_weight_kernel( + const float* __restrict__ sender_cells, + const float* __restrict__ grad_k, + const float* __restrict__ grad_v, + float* __restrict__ grad_weight, + int64_t total_elements, + int batch_size, + int sender_count, + int hidden_dim, + int head_dim, + int value_dim, + int kv_dim, + int group_size, + bool 
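// Illustrative sketch only: the linear-index contract the sender-gradient
// kernels in this header assume. grad_sender is addressed as a row-major
// [batch, sender, hidden] buffer, so the kernel's decomposition (hidden
// fastest, then sender, then batch) is the inverse of this helper. The helper
// is hypothetical and exists only to document the layout.
inline int64_t example_sender_grad_linear_index(
    int64_t b, int64_t sender, int64_t h,
    int64_t sender_count, int64_t hidden_dim) {
  return (b * sender_count + sender) * hidden_dim + h;
}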
use_grouped_weight, + bool has_grad_k, + bool has_grad_v) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int d = static_cast(linear % kv_dim); + const int h = static_cast((linear / kv_dim) % hidden_dim); + const int owner = static_cast(linear / (static_cast(kv_dim) * hidden_dim)); + const int sender_begin = use_grouped_weight ? owner * group_size : owner; + const int sender_end = use_grouped_weight ? sender_begin + group_size : owner + 1; + float acc = 0.0f; + for (int sender = sender_begin; sender < sender_end; ++sender) { + for (int b = 0; b < batch_size; ++b) { + const float source = sender_cells[(static_cast(b) * sender_count + sender) * hidden_dim + h]; + float grad = 0.0f; + if (d < head_dim) { + if (has_grad_k) { + grad = grad_k[(static_cast(b) * sender_count + sender) * head_dim + d]; + } + } else if (has_grad_v) { + grad = grad_v[(static_cast(b) * sender_count + sender) * value_dim + d - head_dim]; + } + acc += source * grad; + } + } + grad_weight[linear] = acc; + } +} + +__global__ void registered_backward_sender_value_sender_kernel( + const float* __restrict__ grad_v, + const float* __restrict__ direct_weight, + const float* __restrict__ grouped_weight, + float* __restrict__ grad_sender, + int64_t total_elements, + int sender_count, + int hidden_dim, + int value_dim, + int group_size, + bool use_grouped_weight) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int h = static_cast(linear % hidden_dim); + const int sender = static_cast((linear / hidden_dim) % sender_count); + const int b = static_cast(linear / (static_cast(hidden_dim) * sender_count)); + const float* weight = nullptr; + if (use_grouped_weight) { + const int group = sender / group_size; + weight = grouped_weight + (static_cast(group) * hidden_dim * value_dim); + } else { + weight = direct_weight + (static_cast(sender) * hidden_dim * value_dim); + } + float acc = 0.0f; + for (int d = 0; d < value_dim; ++d) { + acc += grad_v[(static_cast(b) * sender_count + sender) * value_dim + d] * + weight[static_cast(h) * value_dim + d]; + } + grad_sender[linear] = acc; + } +} + +__global__ void registered_backward_sender_value_weight_kernel( + const float* __restrict__ sender_cells, + const float* __restrict__ grad_v, + float* __restrict__ grad_weight, + int64_t total_elements, + int batch_size, + int sender_count, + int hidden_dim, + int value_dim, + int group_size, + bool use_grouped_weight) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int d = static_cast(linear % value_dim); + const int h = static_cast((linear / value_dim) % hidden_dim); + const int owner = static_cast(linear / (static_cast(value_dim) * hidden_dim)); + const int sender_begin = use_grouped_weight ? owner * group_size : owner; + const int sender_end = use_grouped_weight ? 
sender_begin + group_size : owner + 1; + float acc = 0.0f; + for (int sender = sender_begin; sender < sender_end; ++sender) { + for (int b = 0; b < batch_size; ++b) { + const float source = sender_cells[(static_cast(b) * sender_count + sender) * hidden_dim + h]; + const float grad = grad_v[(static_cast(b) * sender_count + sender) * value_dim + d]; + acc += source * grad; + } + } + grad_weight[linear] = acc; + } +} + +template +void launch_registered_forward_readout_layout_epilogue( + const at::Tensor& boundary, + const at::Tensor& recurrent_hidden_backend_order, + const at::Tensor& input_k, + const at::Tensor& input_v, + const at::Tensor& recurrent_k, + const at::Tensor& recurrent_v, + const at::Tensor& output_q, + const at::Tensor& output_local_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& value_to_output_weight, + const at::Tensor& output_cell_bias, + const at::Tensor& backend_to_graph_inverse_order, + at::Tensor& output_cells, + at::Tensor& recurrent_hidden_graph_order, + at::Tensor& cells_out, + float distance_scale) { + const int B = static_cast(boundary.size(0)); + const int input_count = static_cast(boundary.size(1)); + const int recurrent_count = static_cast(recurrent_hidden_backend_order.size(1)); + const int output_count = static_cast(output_q.size(0)); + const int degree = static_cast(output_local_sender_idx.size(1)); + const int head_dim = static_cast(output_q.size(1)); + const int key_dim = static_cast(input_k.size(2)); + const int value_dim = static_cast(input_v.size(2)); + const int hidden_dim = static_cast(boundary.size(2)); + const int64_t total_elements = + static_cast(B) * (input_count + recurrent_count + output_count) * hidden_dim; + if (total_elements == 0) { + return; + } + const int blocks = static_cast(std::min( + 4096, + (total_elements + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + registered_forward_readout_layout_epilogue_kernel<<>>( + boundary.data_ptr(), + recurrent_hidden_backend_order.data_ptr(), + input_k.data_ptr(), + input_v.data_ptr(), + recurrent_k.data_ptr(), + recurrent_v.data_ptr(), + output_q.data_ptr(), + output_local_sender_idx.data_ptr(), + local_distance.data_ptr(), + value_to_output_weight.data_ptr(), + output_cell_bias.data_ptr(), + backend_to_graph_inverse_order.data_ptr(), + output_cells.data_ptr(), + recurrent_hidden_graph_order.data_ptr(), + cells_out.data_ptr(), + total_elements, + input_count, + recurrent_count, + output_count, + degree, + head_dim, + key_dim, + value_dim, + hidden_dim, + distance_scale); + check_launch("registered_forward_readout_layout_epilogue_kernel"); +} + +template +void launch_registered_forward_cells_layout( + const at::Tensor& boundary, + const at::Tensor& recurrent_hidden_backend_order, + const at::Tensor& output_cells, + const at::Tensor& backend_to_graph_inverse_order, + at::Tensor& recurrent_hidden_graph_order, + at::Tensor& cells_out) { + const int B = static_cast(boundary.size(0)); + const int input_count = static_cast(boundary.size(1)); + const int recurrent_count = static_cast(recurrent_hidden_backend_order.size(1)); + const int output_count = static_cast(output_cells.size(1)); + const int hidden_dim = static_cast(boundary.size(2)); + const int64_t total_elements = + static_cast(B) * (input_count + recurrent_count + output_count) * hidden_dim; + if (total_elements == 0) { + return; + } + const int blocks = static_cast(std::min( + 4096, + (total_elements + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = 
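// Illustrative sketch only: the launch-configuration pattern shared by the
// launchers in this header. The block count is capped at 4096 and every launch
// uses the standard triple-chevron form on the current ATen CUDA stream,
// i.e. kernel<<<blocks, threads, 0, stream>>>(...). The helper and kernel
// names below are hypothetical.
template <int kExampleThreadsPerBlock>
void example_launch_on_current_stream(int64_t total_elements) {
  if (total_elements == 0) {
    return;  // nothing to launch for an empty workload
  }
  const int blocks = static_cast<int>(std::min<int64_t>(
      4096,
      (total_elements + kExampleThreadsPerBlock - 1) / kExampleThreadsPerBlock));
  const auto stream = at::cuda::getCurrentCUDAStream();
  // example_kernel<<<blocks, kExampleThreadsPerBlock, 0, stream>>>(...);  // hypothetical kernel
  (void)blocks;
  (void)stream;
}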
at::cuda::getCurrentCUDAStream(); + registered_forward_cells_layout_kernel<<>>( + boundary.data_ptr(), + recurrent_hidden_backend_order.data_ptr(), + output_cells.data_ptr(), + backend_to_graph_inverse_order.data_ptr(), + recurrent_hidden_graph_order.data_ptr(), + cells_out.data_ptr(), + total_elements, + input_count, + recurrent_count, + output_count, + hidden_dim); + check_launch("registered_forward_cells_layout_kernel"); +} + +template +void launch_registered_backward_layout_split( + const at::Tensor& grad_cells_out, + const at::Tensor& graph_to_backend_order, + at::Tensor& grad_boundary, + at::Tensor& grad_recurrent_hidden_backend, + int input_count, + int recurrent_count, + int output_count, + int hidden_dim) { + const int B = static_cast(grad_cells_out.size(0)); + const int64_t total_elements = static_cast(B) * (input_count + recurrent_count) * hidden_dim; + if (total_elements == 0) { + return; + } + const int blocks = static_cast(std::min( + 4096, + (total_elements + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + registered_backward_layout_split_kernel<<>>( + grad_cells_out.data_ptr(), + graph_to_backend_order.data_ptr(), + grad_boundary.data_ptr(), + grad_recurrent_hidden_backend.data_ptr(), + total_elements, + input_count, + recurrent_count, + output_count, + hidden_dim); + check_launch("registered_backward_layout_split_kernel"); +} + +void launch_registered_backward_readout_projection( + const at::Tensor& grad_cells_out, + const at::Tensor& output_msg, + const at::Tensor& value_to_output_weight, + at::Tensor& grad_output_msg, + at::Tensor& grad_value_to_output_weight, + at::Tensor& grad_output_cell_bias, + int input_count, + int recurrent_count) { + const int B = static_cast(grad_cells_out.size(0)); + const int output_count = static_cast(output_msg.size(1)); + const int value_dim = static_cast(output_msg.size(2)); + const int hidden_dim = static_cast(grad_cells_out.size(2)); + const int64_t max_elements = std::max({ + static_cast(B) * output_count * value_dim, + static_cast(output_count) * value_dim * hidden_dim, + static_cast(output_count) * hidden_dim, + }); + if (max_elements == 0) { + return; + } + const int blocks = static_cast(std::min( + 4096, + (max_elements + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + registered_backward_readout_projection_kernel<<>>( + grad_cells_out.data_ptr(), + output_msg.data_ptr(), + value_to_output_weight.data_ptr(), + grad_output_msg.data_ptr(), + grad_value_to_output_weight.data_ptr(), + grad_output_cell_bias.data_ptr(), + max_elements, + B, + input_count, + recurrent_count, + output_count, + value_dim, + hidden_dim); + check_launch("registered_backward_readout_projection_kernel"); +} diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/memory_runtime_buffers.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/memory_runtime_buffers.cuh new file mode 100644 index 00000000..eae5ca43 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/memory_runtime_buffers.cuh @@ -0,0 +1,769 @@ +#pragma once + +inline void validate_registered_fused_program_memory_rows( + const at::Tensor& primitive_rows, + const at::Tensor& memory_liveness_rows) { + check_cpu_long_rank2(memory_liveness_rows, "memory_liveness_rows", 10); + const int64_t primitive_count = primitive_rows.size(0); + const int64_t* rows = 
memory_liveness_rows.data_ptr(); + bool saw_workspace = false; + bool saw_policy_row = false; + bool saw_local_seed_policy = false; + bool saw_metadata_policy = false; + bool saw_primitive_output_policy = false; + bool saw_tape_policy = false; + bool saw_alias_policy = false; + bool saw_recompute_window_policy = false; + bool saw_materialization_policy = false; + bool saw_cuda_graph_constraint = false; + for (int64_t row_index = 0; row_index < memory_liveness_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 10; + const int64_t entry_index = row[0]; + const int64_t primitive_row_index = row[1]; + const int64_t bucket_ordinal = row[2]; + const int64_t surface_opcode = row[3]; + const int64_t tensor_class_opcode = row[4]; + const int64_t lifetime_opcode = row[5]; + const int64_t workspace_opcode = row[6]; + const int64_t effect_opcode = row[7]; + const int64_t recompute_opcode = row[8]; + const int64_t owner_opcode = row[9]; + TORCH_CHECK(entry_index == row_index, "memory_liveness_rows must be densely indexed"); + TORCH_CHECK( + primitive_row_index == -1 || (primitive_row_index >= 0 && primitive_row_index < primitive_count), + "memory_liveness_rows row ", + row_index, + " references no compiler primitive row"); + TORCH_CHECK( + bucket_ordinal >= 0 || bucket_ordinal == kTemporalMessageBucketOrdinal || + bucket_ordinal == kTemporalReadoutBucketOrdinal || + bucket_ordinal == kTemporalParameterReductionBucketOrdinal, + "memory_liveness_rows row ", + row_index, + " has an invalid bucket ordinal"); + TORCH_CHECK(surface_opcode > 0, "memory_liveness_rows row ", row_index, " has an invalid surface opcode"); + TORCH_CHECK( + tensor_class_opcode > 0, + "memory_liveness_rows row ", + row_index, + " has an invalid tensor-class opcode"); + TORCH_CHECK(lifetime_opcode > 0, "memory_liveness_rows row ", row_index, " has an invalid lifetime opcode"); + TORCH_CHECK(workspace_opcode > 0, "memory_liveness_rows row ", row_index, " has an invalid workspace opcode"); + TORCH_CHECK(effect_opcode > 0, "memory_liveness_rows row ", row_index, " has an invalid effect opcode"); + TORCH_CHECK(recompute_opcode > 0, "memory_liveness_rows row ", row_index, " has an invalid recompute opcode"); + TORCH_CHECK(owner_opcode > 0, "memory_liveness_rows row ", row_index, " has an invalid owner opcode"); + const bool is_policy_row = surface_opcode == kRuntimePolicySurfaceOpcode || + workspace_opcode == kMemoryWorkspacePolicyTable || owner_opcode == kMemoryOwnerCompilerMemoryPolicy; + if (is_policy_row) { + TORCH_CHECK( + primitive_row_index == -1, + "runtime policy memory_liveness_rows row ", + row_index, + " must not claim a primitive row"); + TORCH_CHECK( + surface_opcode == kRuntimePolicySurfaceOpcode, + "runtime policy memory_liveness_rows row ", + row_index, + " must use the runtime_policy surface opcode"); + TORCH_CHECK( + workspace_opcode == kMemoryWorkspacePolicyTable, + "runtime policy memory_liveness_rows row ", + row_index, + " must use the policy_table workspace opcode"); + TORCH_CHECK( + owner_opcode == kMemoryOwnerCompilerMemoryPolicy, + "runtime policy memory_liveness_rows row ", + row_index, + " must be owned by compiler_memory_policy"); + TORCH_CHECK( + effect_opcode == kMemoryEffectLocalSeedPolicy || effect_opcode == kMemoryEffectMetadataPolicy || + effect_opcode == kMemoryEffectPrimitiveOutputPolicy || effect_opcode == kMemoryEffectTapePolicy || + effect_opcode == kMemoryEffectAliasPolicy || effect_opcode == kMemoryEffectRecomputeWindowPolicy || + effect_opcode == 
kMemoryEffectMaterializationPolicy || effect_opcode == kMemoryEffectCudaGraphConstraint, + "runtime policy memory_liveness_rows row ", + row_index, + " has an unsupported policy effect opcode ", + effect_opcode); + saw_policy_row = true; + saw_local_seed_policy = saw_local_seed_policy || effect_opcode == kMemoryEffectLocalSeedPolicy; + saw_metadata_policy = saw_metadata_policy || effect_opcode == kMemoryEffectMetadataPolicy; + saw_primitive_output_policy = saw_primitive_output_policy || effect_opcode == kMemoryEffectPrimitiveOutputPolicy; + saw_tape_policy = saw_tape_policy || effect_opcode == kMemoryEffectTapePolicy; + saw_alias_policy = saw_alias_policy || effect_opcode == kMemoryEffectAliasPolicy; + saw_recompute_window_policy = saw_recompute_window_policy || effect_opcode == kMemoryEffectRecomputeWindowPolicy; + saw_materialization_policy = saw_materialization_policy || effect_opcode == kMemoryEffectMaterializationPolicy; + saw_cuda_graph_constraint = saw_cuda_graph_constraint || effect_opcode == kMemoryEffectCudaGraphConstraint; + } + saw_workspace = true; + } + TORCH_CHECK(saw_workspace, "fused temporal program requires compiler-owned memory liveness rows"); + if (saw_policy_row) { + TORCH_CHECK(saw_local_seed_policy, "runtime policy memory_liveness_rows are missing local_seed_policy"); + TORCH_CHECK(saw_metadata_policy, "runtime policy memory_liveness_rows are missing metadata_policy"); + TORCH_CHECK(saw_primitive_output_policy, "runtime policy memory_liveness_rows are missing primitive_output_policy"); + TORCH_CHECK(saw_tape_policy, "runtime policy memory_liveness_rows are missing tape_policy"); + TORCH_CHECK(saw_alias_policy, "runtime policy memory_liveness_rows are missing alias_policy"); + TORCH_CHECK(saw_recompute_window_policy, "runtime policy memory_liveness_rows are missing recompute_window_policy"); + TORCH_CHECK(saw_materialization_policy, "runtime policy memory_liveness_rows are missing materialization_policy"); + TORCH_CHECK(saw_cuda_graph_constraint, "runtime policy memory_liveness_rows are missing cuda_graph_constraint"); + } +} + +inline bool registered_memory_row_requires_runtime_buffer(const int64_t* row) { + const int64_t surface_opcode = row[3]; + const int64_t workspace_opcode = row[6]; + const int64_t effect_opcode = row[7]; + const int64_t owner_opcode = row[9]; + if (surface_opcode == kRuntimePolicySurfaceOpcode || workspace_opcode == kMemoryWorkspacePolicyTable || + owner_opcode == kMemoryOwnerCompilerMemoryPolicy) { + return false; + } + if (workspace_opcode == kMemoryWorkspaceParameterTable) { + return false; + } + if (effect_opcode == kMemoryEffectParameterRead) { + return false; + } + return workspace_opcode == kMemoryWorkspaceMessage || workspace_opcode == kMemoryWorkspaceOutput || + workspace_opcode == kMemoryWorkspaceTransition || workspace_opcode == kMemoryWorkspaceReduction || + workspace_opcode == kMemoryWorkspaceStateCarry || workspace_opcode == kMemoryWorkspaceTensorTable || + workspace_opcode == kMemoryWorkspacePrimitive; +} + +inline int64_t runtime_schedule_policy_effect_for_role(int64_t role_opcode) { + if (role_opcode == kRuntimeScheduleRoleLocalSeedPolicy) { + return kMemoryEffectLocalSeedPolicy; + } + if (role_opcode == kRuntimeScheduleRoleMetadataPolicy) { + return kMemoryEffectMetadataPolicy; + } + if (role_opcode == kRuntimeScheduleRolePrimitiveOutputPolicy) { + return kMemoryEffectPrimitiveOutputPolicy; + } + if (role_opcode == kRuntimeScheduleRoleTapePolicy) { + return kMemoryEffectTapePolicy; + } + if (role_opcode == 
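// Illustrative sketch only: a plain decoded view of one memory_liveness_rows
// entry, using the ten-column order the validator above reads. The struct and
// helper are hypothetical; only the column order comes from this file.
struct ExampleMemoryLivenessRowView {
  int64_t entry_index;          // column 0: dense row index
  int64_t primitive_row_index;  // column 1: -1 or a compiler primitive row
  int64_t bucket_ordinal;       // column 2: bucket or sentinel ordinal
  int64_t surface_opcode;       // column 3
  int64_t tensor_class_opcode;  // column 4
  int64_t lifetime_opcode;      // column 5
  int64_t workspace_opcode;     // column 6
  int64_t effect_opcode;        // column 7
  int64_t recompute_opcode;     // column 8
  int64_t owner_opcode;         // column 9
};

inline ExampleMemoryLivenessRowView example_decode_memory_liveness_row(
    const at::Tensor& memory_liveness_rows, int64_t row_index) {
  const int64_t* row = memory_liveness_rows.data_ptr<int64_t>() + row_index * 10;
  return {row[0], row[1], row[2], row[3], row[4],
          row[5], row[6], row[7], row[8], row[9]};
}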
kRuntimeScheduleRoleAliasPolicy) { + return kMemoryEffectAliasPolicy; + } + if (role_opcode == kRuntimeScheduleRoleRecomputeWindowPolicy) { + return kMemoryEffectRecomputeWindowPolicy; + } + if (role_opcode == kRuntimeScheduleRoleMaterializationPolicy) { + return kMemoryEffectMaterializationPolicy; + } + if (role_opcode == kRuntimeScheduleRoleCudaGraphConstraint) { + return kMemoryEffectCudaGraphConstraint; + } + return 0; +} + +inline void require_registered_runtime_schedule_policy_row( + const at::Tensor& memory_liveness_rows, + const int64_t schedule_row_index, + const int64_t role_opcode, + const int64_t memory_row_index, + const int64_t policy_opcode) { + const int64_t effect_opcode = runtime_schedule_policy_effect_for_role(role_opcode); + TORCH_CHECK( + effect_opcode > 0, + "memory_runtime_schedule_rows row ", + schedule_row_index, + " has non-policy role in policy validation: ", + role_opcode); + TORCH_CHECK( + 0 <= memory_row_index && memory_row_index < memory_liveness_rows.size(0), + "memory_runtime_schedule_rows row ", + schedule_row_index, + " references invalid policy memory row ", + memory_row_index); + const int64_t* memory = memory_liveness_rows.data_ptr() + memory_row_index * 10; + TORCH_CHECK( + memory[3] == kRuntimePolicySurfaceOpcode && memory[6] == kMemoryWorkspacePolicyTable && + memory[9] == kMemoryOwnerCompilerMemoryPolicy, + "memory_runtime_schedule_rows row ", + schedule_row_index, + " must reference a compiler-owned runtime policy memory row"); + TORCH_CHECK( + memory[7] == effect_opcode, + "memory_runtime_schedule_rows row ", + schedule_row_index, + " references a policy row with the wrong effect"); + TORCH_CHECK( + memory[8] == policy_opcode, + "memory_runtime_schedule_rows row ", + schedule_row_index, + " policy opcode does not match memory_liveness_rows recompute policy"); +} + +inline bool runtime_schedule_role_is_scalar(int64_t role_opcode) { + return role_opcode == kRuntimeScheduleRolePhysicalTimeSteps || role_opcode == kRuntimeScheduleRoleCheckpointStride || + role_opcode == kRuntimeScheduleRoleRecomputeWindowLen || role_opcode == kRuntimeScheduleRoleCheckpointStep || + role_opcode == kRuntimeScheduleRoleBackwardWindow || role_opcode == kRuntimeScheduleRoleOutputPhysicalStep || + role_opcode == kRuntimeScheduleRoleStoreStepArtifacts; +} + +inline int64_t registered_runtime_schedule_value_for_role( + const at::Tensor& memory_runtime_schedule_rows, + const int64_t role_opcode, + const int64_t value_column, + const char* name) { + check_cpu_long_rank2(memory_runtime_schedule_rows, "memory_runtime_schedule_rows", 6); + TORCH_CHECK( + 0 <= value_column && value_column < 6, + name, + " requested invalid memory_runtime_schedule_rows value column ", + value_column); + const int64_t* rows = memory_runtime_schedule_rows.data_ptr(); + for (int64_t row_index = 0; row_index < memory_runtime_schedule_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 6; + if (row[1] == role_opcode) { + return row[value_column]; + } + } + TORCH_CHECK(false, name, " memory_runtime_schedule_rows missing required role ", role_opcode); +} + +inline void enforce_registered_cuda_graph_launch_guard( + const at::Tensor& memory_runtime_schedule_rows, + const char* name) { + const int64_t cuda_graph_policy = registered_runtime_schedule_value_for_role( + memory_runtime_schedule_rows, + kRuntimeScheduleRoleCudaGraphConstraint, + 3, + name); + TORCH_CHECK( + cuda_graph_policy == kMemoryRecomputePolicyCudaGraphGuardPolicy, + name, + " requires compiler-owned cuda_graph_guard_policy 
before fused CUDA launch"); + cudaStreamCaptureStatus capture_status = cudaStreamCaptureStatusNone; + const auto stream = at::cuda::getCurrentCUDAStream(); + const cudaError_t capture_error = cudaStreamIsCapturing(stream.stream(), &capture_status); + TORCH_CHECK( + capture_error == cudaSuccess, + name, + " failed to query CUDA graph capture status: ", + cudaGetErrorString(capture_error)); + TORCH_CHECK( + capture_status != cudaStreamCaptureStatusInvalidated, + name, + " cannot launch with an invalidated CUDA graph capture stream"); +} + +inline void validate_registered_memory_runtime_schedule_rows( + const at::Tensor& memory_liveness_rows, + const at::Tensor& memory_runtime_schedule_rows, + const char* name) { + check_cpu_long_rank2(memory_liveness_rows, "memory_liveness_rows", 10); + check_cpu_long_rank2(memory_runtime_schedule_rows, "memory_runtime_schedule_rows", 6); + const int64_t* rows = memory_runtime_schedule_rows.data_ptr(); + bool saw_local_seed = false; + bool saw_metadata = false; + bool saw_primitive_output = false; + bool saw_tape = false; + bool saw_alias = false; + bool saw_recompute = false; + bool saw_materialization = false; + bool saw_cuda_graph = false; + bool saw_physical_steps = false; + bool saw_checkpoint_stride = false; + bool saw_recompute_window_len = false; + for (int64_t row_index = 0; row_index < memory_runtime_schedule_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 6; + const int64_t role_opcode = row[1]; + TORCH_CHECK(row[0] == row_index, name, " memory_runtime_schedule_rows must be densely indexed"); + TORCH_CHECK(row[5] == 1, name, " memory_runtime_schedule_rows row ", row_index, " must be required"); + const int64_t policy_effect = runtime_schedule_policy_effect_for_role(role_opcode); + if (policy_effect > 0) { + require_registered_runtime_schedule_policy_row(memory_liveness_rows, row_index, role_opcode, row[2], row[3]); + saw_local_seed = saw_local_seed || role_opcode == kRuntimeScheduleRoleLocalSeedPolicy; + saw_metadata = saw_metadata || role_opcode == kRuntimeScheduleRoleMetadataPolicy; + saw_primitive_output = saw_primitive_output || role_opcode == kRuntimeScheduleRolePrimitiveOutputPolicy; + saw_tape = saw_tape || role_opcode == kRuntimeScheduleRoleTapePolicy; + saw_alias = saw_alias || role_opcode == kRuntimeScheduleRoleAliasPolicy; + saw_recompute = saw_recompute || role_opcode == kRuntimeScheduleRoleRecomputeWindowPolicy; + saw_materialization = saw_materialization || role_opcode == kRuntimeScheduleRoleMaterializationPolicy; + saw_cuda_graph = saw_cuda_graph || role_opcode == kRuntimeScheduleRoleCudaGraphConstraint; + continue; + } + TORCH_CHECK( + runtime_schedule_role_is_scalar(role_opcode), + name, + " memory_runtime_schedule_rows row ", + row_index, + " has unknown role opcode ", + role_opcode); + TORCH_CHECK(row[2] == -1, name, " scalar memory_runtime_schedule_rows must not reference a memory row"); + TORCH_CHECK(row[3] >= 0, name, " scalar memory_runtime_schedule_rows must have a non-negative value"); + if (role_opcode == kRuntimeScheduleRolePhysicalTimeSteps) { + saw_physical_steps = true; + TORCH_CHECK(row[3] > 0, name, " physical_time_steps schedule row must be positive"); + } else if (role_opcode == kRuntimeScheduleRoleCheckpointStride) { + saw_checkpoint_stride = true; + TORCH_CHECK(row[3] > 0, name, " checkpoint_stride schedule row must be positive"); + } else if (role_opcode == kRuntimeScheduleRoleRecomputeWindowLen) { + saw_recompute_window_len = true; + TORCH_CHECK(row[3] > 0, name, " recompute_window_len 
schedule row must be positive"); + } else if (role_opcode == kRuntimeScheduleRoleBackwardWindow) { + TORCH_CHECK(row[4] > row[3], name, " backward_window schedule row must have end > start"); + } + } + TORCH_CHECK(saw_local_seed, name, " memory_runtime_schedule_rows missing local_seed_policy"); + TORCH_CHECK(saw_metadata, name, " memory_runtime_schedule_rows missing metadata_policy"); + TORCH_CHECK(saw_primitive_output, name, " memory_runtime_schedule_rows missing primitive_output_policy"); + TORCH_CHECK(saw_tape, name, " memory_runtime_schedule_rows missing tape_policy"); + TORCH_CHECK(saw_alias, name, " memory_runtime_schedule_rows missing alias_policy"); + TORCH_CHECK(saw_recompute, name, " memory_runtime_schedule_rows missing recompute_window_policy"); + TORCH_CHECK(saw_materialization, name, " memory_runtime_schedule_rows missing materialization_policy"); + TORCH_CHECK(saw_cuda_graph, name, " memory_runtime_schedule_rows missing cuda_graph_constraint"); + TORCH_CHECK(saw_physical_steps, name, " memory_runtime_schedule_rows missing physical_time_steps"); + TORCH_CHECK(saw_checkpoint_stride, name, " memory_runtime_schedule_rows missing checkpoint_stride"); + TORCH_CHECK(saw_recompute_window_len, name, " memory_runtime_schedule_rows missing recompute_window_len"); + enforce_registered_cuda_graph_launch_guard(memory_runtime_schedule_rows, name); +} + +inline void validate_registered_physical_strategy_rows( + const at::Tensor& physical_strategy_rows, + const at::Tensor& memory_runtime_schedule_rows, + const char* name) { + check_cpu_long_rank2(physical_strategy_rows, "physical_strategy_rows", kPhysicalStrategyRowColumns); + TORCH_CHECK(physical_strategy_rows.size(0) > 0, name, " requires compiler-owned physical_strategy_rows"); + const int64_t scheduled_physical_steps = registered_runtime_schedule_value_for_role( + memory_runtime_schedule_rows, + kRuntimeScheduleRolePhysicalTimeSteps, + 3, + name); + const int64_t required_surface_mask = + kPhysicalStrategySurfaceMessage | kPhysicalStrategySurfaceTransition | kPhysicalStrategySurfaceReadout | + kPhysicalStrategySurfaceArtifacts | kPhysicalStrategySurfaceReducers; + const int64_t required_table_mask = + kPhysicalStrategyTablePrimitiveRows | kPhysicalStrategyTableExecutorRows | + kPhysicalStrategyTableBindingRows | kPhysicalStrategyTableMemoryLivenessRows | + kPhysicalStrategyTableArtifactRouteRows | kPhysicalStrategyTableOutputRouteRows | + kPhysicalStrategyTableRuntimeScheduleRows; + const int64_t* rows = physical_strategy_rows.data_ptr(); + bool saw_active = false; + bool saw_streaming_strategy = false; + for (int64_t row_index = 0; row_index < physical_strategy_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kPhysicalStrategyRowColumns; + const int64_t strategy_opcode = row[2]; + const int64_t status_opcode = row[3]; + const int64_t executable = row[4]; + const int64_t physical_time_steps = row[5]; + const int64_t inner_steps = row[6]; + const int64_t output_boundary_opcode = row[7]; + const int64_t reset_policy_opcode = row[8]; + const int64_t surface_mask = row[9]; + const int64_t table_mask = row[10]; + const int64_t blocker_opcode = row[11]; + TORCH_CHECK(row[0] == row_index, name, " physical_strategy_rows must be densely indexed"); + TORCH_CHECK(row[1] == 1, name, " physical_strategy_rows schema version mismatch"); + TORCH_CHECK( + strategy_opcode == kPhysicalStrategyStageMaterialized || + strategy_opcode == kPhysicalStrategyStreamingStepProducerConsumer, + name, + " physical_strategy_rows row ", + row_index, 
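// Illustrative sketch only: reading a scalar schedule knob out of the
// six-column memory_runtime_schedule_rows table with the accessor defined
// above. Column 3 carries the scalar value (and the window start) and column 4
// the window end; the wrapper below is hypothetical.
inline int64_t example_checkpoint_stride(
    const at::Tensor& memory_runtime_schedule_rows) {
  return registered_runtime_schedule_value_for_role(
      memory_runtime_schedule_rows,
      kRuntimeScheduleRoleCheckpointStride,
      3,
      "example_checkpoint_stride");
}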
+ " has unknown strategy opcode ", + strategy_opcode); + TORCH_CHECK( + status_opcode == kPhysicalStrategyStatusActive || status_opcode == kPhysicalStrategyStatusCandidate || + status_opcode == kPhysicalStrategyStatusBlocked, + name, + " physical_strategy_rows row ", + row_index, + " has unknown status opcode ", + status_opcode); + TORCH_CHECK(executable == 0 || executable == 1, name, " physical_strategy_rows executable flag must be 0/1"); + TORCH_CHECK( + physical_time_steps == scheduled_physical_steps, + name, + " physical_strategy_rows physical_time_steps do not match memory_runtime_schedule_rows"); + TORCH_CHECK(inner_steps > 0, name, " physical_strategy_rows inner_steps must be positive"); + TORCH_CHECK( + output_boundary_opcode == kPhysicalStrategyOutputBoundaryTerminal || + output_boundary_opcode == kPhysicalStrategyOutputBoundarySequence, + name, + " physical_strategy_rows row ", + row_index, + " has invalid output boundary opcode ", + output_boundary_opcode); + TORCH_CHECK( + reset_policy_opcode == kPhysicalStrategyResetAbsent || + reset_policy_opcode == kPhysicalStrategyResetPresent || + reset_policy_opcode == kPhysicalStrategyResetUnknown, + name, + " physical_strategy_rows row ", + row_index, + " has invalid reset policy opcode ", + reset_policy_opcode); + TORCH_CHECK( + (surface_mask & required_surface_mask) == required_surface_mask, + name, + " physical_strategy_rows row ", + row_index, + " does not declare all required producer-consumer surfaces"); + TORCH_CHECK( + (table_mask & required_table_mask) == required_table_mask, + name, + " physical_strategy_rows row ", + row_index, + " does not consume all required compiler tables"); + TORCH_CHECK( + blocker_opcode == kPhysicalStrategyBlockerNone || + blocker_opcode == kPhysicalStrategyBlockerPendingProgramKernel, + name, + " physical_strategy_rows row ", + row_index, + " has invalid blocker opcode ", + blocker_opcode); + if (status_opcode == kPhysicalStrategyStatusActive) { + TORCH_CHECK(!saw_active, name, " physical_strategy_rows has multiple active strategies"); + TORCH_CHECK(executable == 1, name, " active physical_strategy_rows entry must be executable"); + TORCH_CHECK(blocker_opcode == kPhysicalStrategyBlockerNone, name, " active physical_strategy_rows entry is blocked"); + saw_active = true; + } else { + TORCH_CHECK(executable == 0, name, " non-active physical_strategy_rows entry must not be executable"); + } + if (strategy_opcode == kPhysicalStrategyStreamingStepProducerConsumer) { + saw_streaming_strategy = true; + } + } + TORCH_CHECK(saw_active, name, " physical_strategy_rows missing active executable strategy"); + TORCH_CHECK(saw_streaming_strategy, name, " physical_strategy_rows missing streaming-step producer-consumer strategy row"); +} + +inline int64_t registered_active_physical_strategy_opcode( + const at::Tensor& physical_strategy_rows, + const char* name) { + check_cpu_long_rank2(physical_strategy_rows, "physical_strategy_rows", kPhysicalStrategyRowColumns); + const int64_t* rows = physical_strategy_rows.data_ptr(); + int64_t active_strategy = 0; + for (int64_t row_index = 0; row_index < physical_strategy_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kPhysicalStrategyRowColumns; + if (row[3] != kPhysicalStrategyStatusActive) { + continue; + } + TORCH_CHECK(active_strategy == 0, name, " physical_strategy_rows has multiple active strategies"); + TORCH_CHECK(row[4] == 1, name, " active physical_strategy_rows entry must be executable"); + TORCH_CHECK(row[11] == kPhysicalStrategyBlockerNone, 
name, " active physical_strategy_rows entry is blocked"); + active_strategy = row[2]; + } + TORCH_CHECK(active_strategy != 0, name, " physical_strategy_rows missing active executable strategy"); + return active_strategy; +} + +inline bool registered_physical_strategy_active_is_streaming( + const at::Tensor& physical_strategy_rows, + const char* name) { + return registered_active_physical_strategy_opcode(physical_strategy_rows, name) == + kPhysicalStrategyStreamingStepProducerConsumer; +} + +inline bool registered_runtime_buffer_role_allows_deferred_local(int64_t runtime_role_opcode) { + return runtime_role_opcode == kRuntimeBufferRoleForwardRecurrentHiddenAfter || + runtime_role_opcode == kRuntimeBufferRoleForwardRecurrentMsg || + runtime_role_opcode == kRuntimeBufferRoleForwardOutputMsg || + runtime_role_opcode == kRuntimeBufferRoleForwardOutputCells || + runtime_role_opcode == kRuntimeBufferRoleTransitionForwardLinearOutput || + runtime_role_opcode == kRuntimeBufferRoleTransitionForwardMatmulOutput || + runtime_role_opcode == kRuntimeBufferRoleTransitionForwardStateOutput || + runtime_role_opcode == kRuntimeBufferRoleTransitionForwardNormOutput || + runtime_role_opcode == kRuntimeBufferRoleTransitionForwardDiagOutput || + runtime_role_opcode == kRuntimeBufferRoleTransitionForwardUnaryOutput; +} + +inline bool registered_runtime_buffer_is_deferred_local_placeholder(const at::Tensor& tensor) { + return tensor.defined() && tensor.numel() == 0 && tensor.dim() == 1 && tensor.size(0) == 0; +} + +inline bool registered_runtime_buffer_has_deferred_local_role( + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + int64_t runtime_role_opcode, + int64_t logical_index) { + check_cpu_long_rank2(runtime_buffer_rows, "runtime_buffer_rows", 10); + const int64_t* rows = runtime_buffer_rows.data_ptr(); + for (int64_t row_index = 0; row_index < runtime_buffer_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 10; + if (row[8] != runtime_role_opcode || row[9] != logical_index) { + continue; + } + const int64_t tensor_index = row[0]; + TORCH_CHECK( + 0 <= tensor_index && tensor_index < static_cast(runtime_buffer_tensors.size()), + "runtime buffer row references invalid tensor index while checking deferred local role"); + return registered_runtime_buffer_is_deferred_local_placeholder( + runtime_buffer_tensors[static_cast(tensor_index)]); + } + return false; +} + +inline bool registered_runtime_buffer_has_deferred_local_transition_forward_output( + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows) { + check_cpu_long_rank2(runtime_buffer_rows, "runtime_buffer_rows", 10); + const int64_t* rows = runtime_buffer_rows.data_ptr(); + for (int64_t row_index = 0; row_index < runtime_buffer_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 10; + const int64_t role = row[8]; + if ( + role != kRuntimeBufferRoleTransitionForwardLinearOutput && + role != kRuntimeBufferRoleTransitionForwardMatmulOutput && + role != kRuntimeBufferRoleTransitionForwardStateOutput && + role != kRuntimeBufferRoleTransitionForwardNormOutput && + role != kRuntimeBufferRoleTransitionForwardDiagOutput && + role != kRuntimeBufferRoleTransitionForwardUnaryOutput) { + continue; + } + const int64_t tensor_index = row[0]; + TORCH_CHECK( + 0 <= tensor_index && tensor_index < static_cast(runtime_buffer_tensors.size()), + "runtime buffer row references invalid tensor index while checking deferred transition output"); + if 
(registered_runtime_buffer_is_deferred_local_placeholder( + runtime_buffer_tensors[static_cast(tensor_index)])) { + return true; + } + } + return false; +} + +inline void validate_registered_runtime_buffer_rows( + const at::Tensor& memory_liveness_rows, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const char* name) { + check_cpu_long_rank2(memory_liveness_rows, "memory_liveness_rows", 10); + check_cpu_long_rank2(runtime_buffer_rows, "runtime_buffer_rows", 10); + TORCH_CHECK( + runtime_buffer_rows.size(0) == static_cast(runtime_buffer_tensors.size()), + name, + " runtime buffer row count must match runtime buffer tensor table"); + const int64_t* memory_rows = memory_liveness_rows.data_ptr(); + const int64_t* buffer_rows = runtime_buffer_rows.data_ptr(); + std::vector covered_memory_rows(static_cast(memory_liveness_rows.size(0)), false); + for (int64_t buffer_row_index = 0; buffer_row_index < runtime_buffer_rows.size(0); ++buffer_row_index) { + const int64_t* buffer = buffer_rows + buffer_row_index * 10; + TORCH_CHECK(buffer[0] == buffer_row_index, name, " runtime buffer rows must be densely indexed"); + const int64_t memory_row_index = buffer[1]; + TORCH_CHECK( + 0 <= memory_row_index && memory_row_index < memory_liveness_rows.size(0), + name, + " runtime buffer row references invalid memory row ", + memory_row_index); + const int64_t* memory = memory_rows + memory_row_index * 10; + TORCH_CHECK( + registered_memory_row_requires_runtime_buffer(memory), + name, + " runtime buffer row references a memory row that does not require allocation: ", + memory_row_index); + TORCH_CHECK(buffer[2] == memory[6], name, " runtime buffer workspace does not match memory row"); + TORCH_CHECK(buffer[3] == memory[3], name, " runtime buffer surface does not match memory row"); + TORCH_CHECK(buffer[4] == memory[2], name, " runtime buffer bucket does not match memory row"); + TORCH_CHECK(buffer[5] == memory[7], name, " runtime buffer effect does not match memory row"); + TORCH_CHECK(buffer[6] >= 0, name, " runtime buffer row has invalid alias index"); + TORCH_CHECK(buffer[7] == 0 || buffer[7] == 1, name, " runtime buffer row has invalid init flag"); + TORCH_CHECK(buffer[8] >= 0, name, " runtime buffer row has invalid runtime role"); + TORCH_CHECK(buffer[9] >= 0, name, " runtime buffer row has invalid logical index"); + const at::Tensor& tensor = runtime_buffer_tensors[static_cast(buffer_row_index)]; + TORCH_CHECK(tensor.defined(), name, " runtime buffer tensor is undefined at row ", buffer_row_index); + TORCH_CHECK(tensor.is_cuda(), name, " runtime buffer tensor must be CUDA at row ", buffer_row_index); + TORCH_CHECK(tensor.is_contiguous(), name, " runtime buffer tensor must be contiguous at row ", buffer_row_index); + if (buffer[8] == kRuntimeBufferRoleForwardMessageStepFlat || + buffer[8] == kRuntimeBufferRoleReverseMessageStepFlat) { + TORCH_CHECK( + tensor.scalar_type() == at::kLong, + name, + " message-step runtime buffer tensor must be int64 at row ", + buffer_row_index); + } else { + TORCH_CHECK( + tensor.scalar_type() == at::kFloat, + name, + " runtime buffer tensor must be float32 at row ", + buffer_row_index); + } + if (registered_runtime_buffer_is_deferred_local_placeholder(tensor)) { + TORCH_CHECK( + registered_runtime_buffer_role_allows_deferred_local(buffer[8]), + name, + " deferred local runtime buffer placeholder is only legal for compiler-routed step-local outputs at row ", + buffer_row_index); + TORCH_CHECK( + buffer[7] == 0, + name, + " deferred local 
runtime buffer placeholder must use empty init at row ", + buffer_row_index); + } else { + TORCH_CHECK(tensor.numel() > 0, name, " runtime buffer tensor must allocate storage at row ", buffer_row_index); + } + covered_memory_rows[static_cast(memory_row_index)] = true; + } + for (int64_t memory_row_index = 0; memory_row_index < memory_liveness_rows.size(0); ++memory_row_index) { + const int64_t* memory = memory_rows + memory_row_index * 10; + if (!registered_memory_row_requires_runtime_buffer(memory)) { + continue; + } + TORCH_CHECK( + covered_memory_rows[static_cast(memory_row_index)], + name, + " has no runtime buffer for compiler memory row ", + memory_row_index); + } +} + +inline bool registered_tensor_matches_shape(const at::Tensor& tensor, const std::vector& expected_shape) { + if (tensor.dim() != static_cast(expected_shape.size())) { + return false; + } + for (int64_t dim = 0; dim < tensor.dim(); ++dim) { + if (tensor.size(dim) != expected_shape[static_cast(dim)]) { + return false; + } + } + return true; +} + +inline at::Tensor registered_materialize_deferred_local_runtime_buffer( + const at::Tensor& tensor, + int64_t runtime_role_opcode, + const std::vector& expected_shape, + const char* name) { + if (!registered_runtime_buffer_is_deferred_local_placeholder(tensor)) { + TORCH_CHECK( + registered_tensor_matches_shape(tensor, expected_shape), + name, + " compiler runtime buffer shape mismatch for role=", + runtime_role_opcode); + return tensor; + } + TORCH_CHECK( + registered_runtime_buffer_role_allows_deferred_local(runtime_role_opcode), + name, + " deferred local runtime buffer requested for a non-step-local output role ", + runtime_role_opcode); + return at::empty(expected_shape, tensor.options()); +} + +inline at::Tensor registered_runtime_buffer_for_workspace_effect( + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + int64_t workspace_opcode, + int64_t effect_opcode, + const std::vector& expected_shape, + const char* name) { + check_cpu_long_rank2(runtime_buffer_rows, "runtime_buffer_rows", 10); + const int64_t* rows = runtime_buffer_rows.data_ptr(); + for (int64_t row_index = 0; row_index < runtime_buffer_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 10; + if (row[2] != workspace_opcode || row[5] != effect_opcode) { + continue; + } + const int64_t tensor_index = row[0]; + TORCH_CHECK( + 0 <= tensor_index && tensor_index < static_cast(runtime_buffer_tensors.size()), + name, + " runtime buffer row references invalid tensor index ", + tensor_index); + const at::Tensor& tensor = runtime_buffer_tensors[static_cast(tensor_index)]; + if (!registered_tensor_matches_shape(tensor, expected_shape)) { + continue; + } + return tensor; + } + TORCH_CHECK( + false, + name, + " has no compiler runtime buffer for workspace=", + workspace_opcode, + ", effect=", + effect_opcode, + ", expected_rank=", + static_cast(expected_shape.size())); + return runtime_buffer_tensors.empty() ? 
at::Tensor() : runtime_buffer_tensors[0]; +} + +inline at::Tensor registered_runtime_buffer_for_role( + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + int64_t runtime_role_opcode, + int64_t logical_index, + const std::vector& expected_shape, + const char* name) { + check_cpu_long_rank2(runtime_buffer_rows, "runtime_buffer_rows", 10); + const int64_t* rows = runtime_buffer_rows.data_ptr(); + for (int64_t row_index = 0; row_index < runtime_buffer_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 10; + if (row[8] != runtime_role_opcode || row[9] != logical_index) { + continue; + } + const int64_t tensor_index = row[0]; + TORCH_CHECK( + 0 <= tensor_index && tensor_index < static_cast(runtime_buffer_tensors.size()), + name, + " runtime buffer row references invalid tensor index ", + tensor_index); + const at::Tensor& tensor = runtime_buffer_tensors[static_cast(tensor_index)]; + return registered_materialize_deferred_local_runtime_buffer( + tensor, + runtime_role_opcode, + expected_shape, + name); + } + TORCH_CHECK( + false, + name, + " has no compiler runtime buffer for role=", + runtime_role_opcode, + ", logical_index=", + logical_index, + ", expected_rank=", + static_cast(expected_shape.size())); + return runtime_buffer_tensors.empty() ? at::Tensor() : runtime_buffer_tensors[0]; +} + +inline bool registered_runtime_buffer_has_role( + const at::Tensor& runtime_buffer_rows, + int64_t runtime_role_opcode, + int64_t logical_index) { + check_cpu_long_rank2(runtime_buffer_rows, "runtime_buffer_rows", 10); + const int64_t* rows = runtime_buffer_rows.data_ptr(); + for (int64_t row_index = 0; row_index < runtime_buffer_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 10; + if (row[8] == runtime_role_opcode && row[9] == logical_index) { + return true; + } + } + return false; +} + +inline at::Tensor registered_runtime_buffer_for_role_any_shape( + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + int64_t runtime_role_opcode, + int64_t logical_index, + const char* name) { + check_cpu_long_rank2(runtime_buffer_rows, "runtime_buffer_rows", 10); + const int64_t* rows = runtime_buffer_rows.data_ptr(); + for (int64_t row_index = 0; row_index < runtime_buffer_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 10; + if (row[8] != runtime_role_opcode || row[9] != logical_index) { + continue; + } + const int64_t tensor_index = row[0]; + TORCH_CHECK( + 0 <= tensor_index && tensor_index < static_cast(runtime_buffer_tensors.size()), + name, + " runtime buffer row references invalid tensor index ", + tensor_index); + return runtime_buffer_tensors[static_cast(tensor_index)]; + } + TORCH_CHECK( + false, + name, + " has no compiler runtime buffer for role=", + runtime_role_opcode, + ", logical_index=", + logical_index); + return runtime_buffer_tensors.empty() ? 
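// Illustrative sketch only: requesting a step-local workspace by runtime role.
// The lookup returns the preallocated compiler buffer when one was bound, and
// materializes a fresh tensor only for the deferred-local placeholder roles
// checked above. The call mirrors the recurrent-hidden-after lookup used by
// the fused forward program; the wrapper itself is hypothetical.
inline at::Tensor example_recurrent_hidden_after_buffer(
    const std::vector<at::Tensor>& runtime_buffer_tensors,
    const at::Tensor& runtime_buffer_rows,
    int64_t physical_step,
    int64_t batch, int64_t recurrent_count, int64_t hidden) {
  return registered_runtime_buffer_for_role(
      runtime_buffer_tensors,
      runtime_buffer_rows,
      kRuntimeBufferRoleForwardRecurrentHiddenAfter,
      physical_step,
      {batch, recurrent_count, hidden},
      "example_recurrent_hidden_after_buffer");
}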
at::Tensor() : runtime_buffer_tensors[0]; +} diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/native_callable_bindings.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/native_callable_bindings.cuh new file mode 100644 index 00000000..35d10264 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/native_callable_bindings.cuh @@ -0,0 +1,587 @@ +#pragma once + +struct RegisteredNativeCallableOutputContract { + int64_t runtime_role; + int64_t shape_kind; + int64_t logical_index_source; + int64_t init_kind; +}; + +inline at::Tensor program_tensor_for_binding( + const std::vector& tensor_table, + const at::Tensor& tensor_binding_rows, + int64_t binding_index, + const char* subject); + +inline void validate_registered_native_callable_output_rows( + const at::Tensor& native_callable_output_rows, + int64_t schema_version) { + check_cpu_long_rank2( + native_callable_output_rows, + "native_callable_output_rows", + kNativeCallableOutputRowColumns); + const int64_t* rows = native_callable_output_rows.data_ptr(); + for (int64_t row_index = 0; row_index < native_callable_output_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kNativeCallableOutputRowColumns; + TORCH_CHECK(row[0] == row_index, "native_callable_output_rows must be densely indexed"); + TORCH_CHECK(row[1] > 0, "native_callable_output_rows row ", row_index, " has no callable hash"); + TORCH_CHECK( + row[2] == kForwardDirectionOpcode || row[2] == kReverseDirectionOpcode, + "native_callable_output_rows row ", + row_index, + " has invalid direction opcode"); + TORCH_CHECK(row[3] > 0, "native_callable_output_rows row ", row_index, " has invalid surface opcode"); + TORCH_CHECK(row[4] > 0, "native_callable_output_rows row ", row_index, " has invalid primitive opcode"); + TORCH_CHECK(row[5] > 0, "native_callable_output_rows row ", row_index, " has no output name hash"); + TORCH_CHECK(row[6] >= 0, "native_callable_output_rows row ", row_index, " has invalid output index"); + TORCH_CHECK(row[7] > 0, "native_callable_output_rows row ", row_index, " has invalid runtime role"); + TORCH_CHECK( + row[8] == kNativeCallableOutputShapeHidden || + row[8] == kNativeCallableOutputShapeGateLogits || + row[8] == kNativeCallableOutputShapeDiagonalPreproj, + "native_callable_output_rows row ", + row_index, + " has invalid shape kind"); + TORCH_CHECK( + row[9] == kNativeCallableOutputLogicalPrimitiveRow || + row[9] == kNativeCallableOutputLogicalBindingIndex, + "native_callable_output_rows row ", + row_index, + " has invalid logical-index source"); + TORCH_CHECK( + row[10] == kNativeCallableOutputInitEmpty || row[10] == kNativeCallableOutputInitZeros, + "native_callable_output_rows row ", + row_index, + " has invalid init kind"); + TORCH_CHECK( + row[11] == schema_version, + "native_callable_output_rows row ", + row_index, + " has unsupported schema version"); + } +} + +inline void validate_registered_native_callable_binding_schema_rows( + const at::Tensor& native_callable_binding_schema_rows, + int64_t schema_version) { + check_cpu_long_rank2( + native_callable_binding_schema_rows, + "native_callable_binding_schema_rows", + kNativeCallableBindingSchemaRowColumns); + const int64_t* rows = native_callable_binding_schema_rows.data_ptr(); + for (int64_t row_index = 0; row_index < native_callable_binding_schema_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kNativeCallableBindingSchemaRowColumns; + 
TORCH_CHECK(row[0] == row_index, "native_callable_binding_schema_rows must be densely indexed"); + TORCH_CHECK(row[1] > 0, "native_callable_binding_schema_rows row ", row_index, " has no callable hash"); + TORCH_CHECK( + row[2] == kForwardDirectionOpcode || row[2] == kReverseDirectionOpcode, + "native_callable_binding_schema_rows row ", + row_index, + " has invalid direction opcode"); + TORCH_CHECK(row[3] > 0, "native_callable_binding_schema_rows row ", row_index, " has invalid surface opcode"); + TORCH_CHECK(row[4] > 0, "native_callable_binding_schema_rows row ", row_index, " has invalid primitive opcode"); + TORCH_CHECK( + row[5] == kNativeCallableBindingInput || + row[5] == kNativeCallableBindingParameter || + row[5] == kNativeCallableBindingOutput, + "native_callable_binding_schema_rows row ", + row_index, + " has invalid binding kind"); + TORCH_CHECK(row[6] > 0, "native_callable_binding_schema_rows row ", row_index, " has no logical name hash"); + TORCH_CHECK(row[7] >= 0, "native_callable_binding_schema_rows row ", row_index, " has invalid local index"); + TORCH_CHECK( + row[8] == 0 || row[8] == 1, + "native_callable_binding_schema_rows row ", + row_index, + " has invalid required flag"); + TORCH_CHECK( + row[9] == schema_version, + "native_callable_binding_schema_rows row ", + row_index, + " has unsupported schema version"); + } +} + +struct RegisteredNativeCallableBindingVectorContract { + int64_t min_count; + int64_t max_count; +}; + +inline bool registered_native_callable_binding_schema_row_matches( + const int64_t* row, + int64_t native_callable_hash, + int64_t direction_opcode, + int64_t surface_opcode, + int64_t primitive_opcode, + int64_t binding_kind) { + return row[1] == native_callable_hash && row[2] == direction_opcode && row[3] == surface_opcode && + row[4] == primitive_opcode && row[5] == binding_kind; +} + +inline RegisteredNativeCallableBindingVectorContract registered_native_callable_binding_vector_contract_for( + const at::Tensor& native_callable_binding_schema_rows, + int64_t native_callable_hash, + int64_t direction_opcode, + int64_t surface_opcode, + int64_t primitive_opcode, + int64_t binding_kind, + bool allow_optional, + int64_t schema_version, + const char* subject) { + validate_registered_native_callable_binding_schema_rows(native_callable_binding_schema_rows, schema_version); + const int64_t* rows = native_callable_binding_schema_rows.data_ptr(); + int64_t row_count = 0; + int64_t required_count = 0; + int64_t max_local_index = -1; + for (int64_t row_index = 0; row_index < native_callable_binding_schema_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kNativeCallableBindingSchemaRowColumns; + if (!registered_native_callable_binding_schema_row_matches( + row, + native_callable_hash, + direction_opcode, + surface_opcode, + primitive_opcode, + binding_kind)) { + continue; + } + TORCH_CHECK( + allow_optional || row[8] == 1, + subject, + " compiler native callable schema marks a non-parameter binding optional"); + ++row_count; + required_count += row[8] == 1 ? 
1 : 0; + if (row[7] > max_local_index) { + max_local_index = row[7]; + } + } + if (row_count == 0) { + return RegisteredNativeCallableBindingVectorContract{0, 0}; + } + TORCH_CHECK( + max_local_index + 1 == row_count, + subject, + " compiler native callable binding schema local indices are not dense for callable=", + native_callable_hash, + ",primitive=", + primitive_opcode, + ",kind=", + binding_kind); + for (int64_t local_index = 0; local_index <= max_local_index; ++local_index) { + int64_t rows_at_local_index = 0; + for (int64_t row_index = 0; row_index < native_callable_binding_schema_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kNativeCallableBindingSchemaRowColumns; + if (!registered_native_callable_binding_schema_row_matches( + row, + native_callable_hash, + direction_opcode, + surface_opcode, + primitive_opcode, + binding_kind)) { + continue; + } + rows_at_local_index += row[7] == local_index ? 1 : 0; + } + TORCH_CHECK( + rows_at_local_index == 1, + subject, + " compiler native callable binding schema has duplicate or missing local index ", + local_index, + " for callable=", + native_callable_hash, + ",primitive=", + primitive_opcode, + ",kind=", + binding_kind); + } + return RegisteredNativeCallableBindingVectorContract{ + allow_optional ? required_count : row_count, + row_count, + }; +} + +inline RegisteredNativeCallableBindingVectorContract require_native_callable_binding_vector_contract( + const std::vector& bindings, + const at::Tensor& native_callable_binding_schema_rows, + int64_t native_callable_hash, + int64_t direction_opcode, + int64_t surface_opcode, + int64_t primitive_opcode, + int64_t binding_kind, + bool allow_optional, + int64_t schema_version, + const char* subject) { + const RegisteredNativeCallableBindingVectorContract contract = + registered_native_callable_binding_vector_contract_for( + native_callable_binding_schema_rows, + native_callable_hash, + direction_opcode, + surface_opcode, + primitive_opcode, + binding_kind, + allow_optional, + schema_version, + subject); + const int64_t actual_count = static_cast(bindings.size()); + TORCH_CHECK( + actual_count >= contract.min_count && actual_count <= contract.max_count, + subject, + " received binding count outside compiler native callable schema: expected=[", + contract.min_count, + ",", + contract.max_count, + "], actual=", + actual_count, + ", callable=", + native_callable_hash, + ", primitive=", + primitive_opcode, + ", kind=", + binding_kind); + const int64_t* rows = native_callable_binding_schema_rows.data_ptr(); + for (int64_t row_index = 0; row_index < native_callable_binding_schema_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kNativeCallableBindingSchemaRowColumns; + if (!registered_native_callable_binding_schema_row_matches( + row, + native_callable_hash, + direction_opcode, + surface_opcode, + primitive_opcode, + binding_kind)) { + continue; + } + TORCH_CHECK( + row[8] == 0 || row[7] < actual_count, + subject, + " is missing required compiler native callable binding at local index ", + row[7], + " for callable=", + native_callable_hash, + ", primitive=", + primitive_opcode, + ", kind=", + binding_kind); + } + return contract; +} + +inline int64_t native_callable_local_binding_index_for( + const at::Tensor& native_callable_binding_schema_rows, + int64_t native_callable_hash, + int64_t direction_opcode, + int64_t surface_opcode, + int64_t primitive_opcode, + int64_t binding_kind, + int64_t logical_name_hash, + bool required, + int64_t schema_version, + 
const char* subject) { + validate_registered_native_callable_binding_schema_rows(native_callable_binding_schema_rows, schema_version); + const int64_t* rows = native_callable_binding_schema_rows.data_ptr(); + int64_t result = -1; + for (int64_t row_index = 0; row_index < native_callable_binding_schema_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kNativeCallableBindingSchemaRowColumns; + if (row[1] != native_callable_hash || row[2] != direction_opcode || row[3] != surface_opcode || + row[4] != primitive_opcode || row[5] != binding_kind || row[6] != logical_name_hash) { + continue; + } + TORCH_CHECK( + result < 0, + subject, + " has duplicate compiler native callable binding schema rows for logical binding hash ", + logical_name_hash); + result = row[7]; + if (required) { + TORCH_CHECK(row[8] == 1, subject, " binding schema marks required logical binding optional"); + } + } + if (required) { + TORCH_CHECK( + result >= 0, + subject, + " has no compiler native callable binding schema row for logical binding hash ", + logical_name_hash); + } + return result; +} + +inline int64_t native_callable_program_binding_for( + const std::vector& bindings, + const at::Tensor& native_callable_binding_schema_rows, + int64_t native_callable_hash, + int64_t direction_opcode, + int64_t primitive_opcode, + int64_t binding_kind, + const char* logical_name, + bool required, + int64_t schema_version, + const char* subject) { + const int64_t local_index = native_callable_local_binding_index_for( + native_callable_binding_schema_rows, + native_callable_hash, + direction_opcode, + kTransitionSurfaceOpcode, + primitive_opcode, + binding_kind, + registered_temporal_stable_id_hash_constexpr(logical_name), + required, + schema_version, + subject); + if (local_index < 0) { + return -1; + } + if (!required && local_index >= static_cast(bindings.size())) { + return -1; + } + TORCH_CHECK( + local_index < static_cast(bindings.size()), + subject, + " is missing program binding for logical binding ", + logical_name, + " at local index ", + local_index); + return bindings[static_cast(local_index)]; +} + +inline at::Tensor native_callable_tensor_for_binding( + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const std::vector& bindings, + const at::Tensor& native_callable_binding_schema_rows, + int64_t native_callable_hash, + int64_t direction_opcode, + int64_t primitive_opcode, + int64_t binding_kind, + const char* logical_name, + bool required, + int64_t schema_version, + const char* subject) { + const int64_t binding_index = native_callable_program_binding_for( + bindings, + native_callable_binding_schema_rows, + native_callable_hash, + direction_opcode, + primitive_opcode, + binding_kind, + logical_name, + required, + schema_version, + subject); + if (binding_index < 0) { + return at::Tensor(); + } + return program_tensor_for_binding(program_tensors, program_tensor_binding_rows, binding_index, subject); +} + +inline RegisteredNativeCallableOutputContract registered_native_callable_output_contract_for( + const at::Tensor& native_callable_output_rows, + int64_t native_callable_hash, + int64_t direction_opcode, + int64_t surface_opcode, + int64_t primitive_opcode, + int64_t output_index, + const std::vector& expected_shape, + int64_t schema_version, + const char* subject) { + validate_registered_native_callable_output_rows(native_callable_output_rows, schema_version); + const int64_t* rows = native_callable_output_rows.data_ptr(); + const int64_t* selected = 
nullptr; + for (int64_t row_index = 0; row_index < native_callable_output_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kNativeCallableOutputRowColumns; + if (row[1] != native_callable_hash || row[2] != direction_opcode || row[3] != surface_opcode || + row[4] != primitive_opcode) { + continue; + } + const bool shape_matches = + (row[8] == kNativeCallableOutputShapeGateLogits && expected_shape.size() == 4 && expected_shape[2] == 4) || + (row[8] == kNativeCallableOutputShapeDiagonalPreproj && expected_shape.size() == 3) || + (row[8] == kNativeCallableOutputShapeHidden && expected_shape.size() == 3); + if (!shape_matches) { + continue; + } + if (row[6] == output_index) { + TORCH_CHECK( + selected == nullptr, + subject, + " has duplicate compiler native callable output rows for callable=", + native_callable_hash, + ",primitive=", + primitive_opcode, + ",output_index=", + output_index); + selected = row; + continue; + } + } + const int64_t* compatible = nullptr; + if (selected == nullptr) { + for (int64_t row_index = 0; row_index < native_callable_output_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kNativeCallableOutputRowColumns; + if (row[1] != native_callable_hash || row[2] != direction_opcode || row[3] != surface_opcode || + row[4] != primitive_opcode) { + continue; + } + const bool shape_matches = + (row[8] == kNativeCallableOutputShapeGateLogits && expected_shape.size() == 4 && expected_shape[2] == 4) || + (row[8] == kNativeCallableOutputShapeDiagonalPreproj && expected_shape.size() == 3) || + (row[8] == kNativeCallableOutputShapeHidden && expected_shape.size() == 3); + if (!shape_matches) { + continue; + } + if (compatible == nullptr) { + compatible = row; + continue; + } + TORCH_CHECK( + compatible[7] == row[7] && compatible[8] == row[8] && compatible[9] == row[9] && + compatible[10] == row[10], + subject, + " has ambiguous compiler native callable output rows for callable=", + native_callable_hash, + ",primitive=", + primitive_opcode, + ",shape_kind=", + row[8]); + } + } + if (selected == nullptr) { + selected = compatible; + } + TORCH_CHECK( + selected != nullptr, + subject, + " has no compiler native callable output row for callable=", + native_callable_hash, + ",primitive=", + primitive_opcode, + ",output_index=", + output_index); + return RegisteredNativeCallableOutputContract{ + selected[7], + selected[8], + selected[9], + selected[10], + }; +} + +inline int64_t registered_native_callable_output_logical_index( + const RegisteredNativeCallableOutputContract& contract, + int64_t primitive_row_index, + int64_t output_binding_index) { + if (contract.logical_index_source == kNativeCallableOutputLogicalBindingIndex) { + return output_binding_index; + } + TORCH_CHECK( + contract.logical_index_source == kNativeCallableOutputLogicalPrimitiveRow, + "native callable output contract has invalid logical-index source"); + return primitive_row_index; +} + +inline void validate_registered_native_callable_output_shape_contract( + const RegisteredNativeCallableOutputContract& contract, + const std::vector& expected_shape, + const char* subject) { + if (contract.shape_kind == kNativeCallableOutputShapeGateLogits) { + TORCH_CHECK( + expected_shape.size() == 4 && expected_shape[2] == 4, + subject, + " expected gate-logit output shape [B,R,4,H] from compiler output contract"); + return; + } + if (contract.shape_kind == kNativeCallableOutputShapeDiagonalPreproj) { + TORCH_CHECK( + expected_shape.size() == 3 && expected_shape[2] % 2 == 0, + subject, + " 
expected diagonal preprojection output shape [B,R,2H] from compiler output contract"); + return; + } + TORCH_CHECK( + contract.shape_kind == kNativeCallableOutputShapeHidden, + subject, + " has invalid native callable output shape kind"); + TORCH_CHECK( + expected_shape.size() == 3, + subject, + " expected hidden output shape [B,R,H] from compiler output contract"); +} + +inline at::Tensor registered_runtime_buffer_for_native_callable_output( + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& native_callable_output_rows, + int64_t native_callable_hash, + int64_t primitive_opcode, + int64_t primitive_row_index, + int64_t output_index, + int64_t output_binding_index, + int64_t schema_version, + const std::vector& expected_shape, + const char* subject) { + const RegisteredNativeCallableOutputContract contract = registered_native_callable_output_contract_for( + native_callable_output_rows, + native_callable_hash, + kForwardDirectionOpcode, + kTransitionSurfaceOpcode, + primitive_opcode, + output_index, + expected_shape, + schema_version, + subject); + validate_registered_native_callable_output_shape_contract(contract, expected_shape, subject); + at::Tensor output = registered_runtime_buffer_for_role_any_shape( + runtime_buffer_tensors, + runtime_buffer_rows, + contract.runtime_role, + registered_native_callable_output_logical_index(contract, primitive_row_index, output_binding_index), + subject); + output = registered_materialize_deferred_local_runtime_buffer( + output, + contract.runtime_role, + expected_shape, + subject); + if (contract.init_kind == kNativeCallableOutputInitZeros) { + output.zero_(); + } + return output; +} + +inline at::Tensor registered_runtime_buffer_for_native_callable_output_allow_deferred( + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& native_callable_output_rows, + int64_t native_callable_hash, + int64_t primitive_opcode, + int64_t primitive_row_index, + int64_t output_index, + int64_t output_binding_index, + int64_t schema_version, + const std::vector& expected_shape, + const char* subject) { + const RegisteredNativeCallableOutputContract contract = registered_native_callable_output_contract_for( + native_callable_output_rows, + native_callable_hash, + kForwardDirectionOpcode, + kTransitionSurfaceOpcode, + primitive_opcode, + output_index, + expected_shape, + schema_version, + subject); + validate_registered_native_callable_output_shape_contract(contract, expected_shape, subject); + at::Tensor output = registered_runtime_buffer_for_role_any_shape( + runtime_buffer_tensors, + runtime_buffer_rows, + contract.runtime_role, + registered_native_callable_output_logical_index(contract, primitive_row_index, output_binding_index), + subject); + if (!registered_runtime_buffer_is_deferred_local_placeholder(output)) { + output = registered_materialize_deferred_local_runtime_buffer( + output, + contract.runtime_role, + expected_shape, + subject); + if (contract.init_kind == kNativeCallableOutputInitZeros) { + output.zero_(); + } + } + return output; +} diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/native_callables/message_forward_strategies.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/native_callables/message_forward_strategies.cuh new file mode 100644 index 00000000..cfe9281a --- /dev/null +++ 
b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/native_callables/message_forward_strategies.cuh @@ -0,0 +1,1594 @@ +#pragma once + +inline RegisteredForwardMessageExecutorState bind_neighborhood_attention_project_message_handler( + const RegisteredFusedProgramSpan& span, + const RegisteredNativeStrategyRow& native_strategy, + const at::Tensor& boundary_seq, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& forward_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t input_count, + int64_t head_dim, + int64_t value_dim) { + const at::Tensor recurrent_q = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + forward_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_recurrent_query", + true, + false, + span.executor_row_index, + span.bucket_ordinal, + "forward", + "registered fused forward recurrent query"); + const at::Tensor input_direct_weight = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + forward_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_input_direct_kv_weight", + true, + true, + span.executor_row_index, + span.bucket_ordinal, + "forward", + "registered fused forward input direct K/V weight"); + const at::Tensor input_group_weight = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + forward_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_input_group_kv_weight", + true, + true, + span.executor_row_index, + span.bucket_ordinal, + "forward", + "registered fused forward input group K/V weight"); + const at::Tensor recurrent_kv_weight = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + forward_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_recurrent_kv_weight", + true, + false, + span.executor_row_index, + span.bucket_ordinal, + "forward", + "registered fused forward recurrent K/V weight"); + const int64_t input_group_size = + input_group_weight.defined() && input_group_weight.numel() > 0 ? 
input_count / input_group_weight.size(0) : 1; + std::vector input_kv = flat_bucket_registered_forward_sender_kv_sequence_cuda( + boundary_seq, + input_direct_weight, + input_group_weight, + input_group_size, + head_dim, + value_dim, + forward_executor_rows, + forward_executor_binding_rows, + span.executor_id, + span.bucket_ordinal); + return RegisteredForwardMessageExecutorState{ + span, + native_strategy, + { + {"message_recurrent_query", recurrent_q}, + {"message_input_direct_kv_weight", input_direct_weight}, + {"message_input_group_kv_weight", input_group_weight}, + {"message_recurrent_kv_weight", recurrent_kv_weight}, + }, + { + {"input_k_seq", input_kv[0]}, + {"input_v_seq", input_kv[1]}, + }, + input_group_size, + value_dim, + }; +} + +inline std::vector run_neighborhood_attention_project_recurrent_kv( + const RegisteredForwardMessageExecutorState& message_executor, + const at::Tensor& recurrent_hidden, + const at::Tensor& empty, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t head_dim, + int64_t value_dim, + bool materialize_key_bank) { + (void)materialize_key_bank; + const at::Tensor recurrent_kv_weight = + registered_forward_message_tensor(message_executor, "message_recurrent_kv_weight"); + return flat_bucket_registered_forward_sender_kv_step_cuda( + recurrent_hidden, + recurrent_kv_weight, + empty, + 1, + head_dim, + value_dim, + forward_executor_rows, + forward_executor_binding_rows, + message_executor.span.executor_id, + message_executor.span.bucket_ordinal); +} + +inline at::Tensor run_registered_forward_partitioned_attention_into( + const at::Tensor& q, + const at::Tensor& input_k_step, + const at::Tensor& input_v_step, + const at::Tensor& recurrent_k, + const at::Tensor& recurrent_v, + const at::Tensor& receiver_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& step_flat, + const at::Tensor& out, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal, + double distance_scale, + bool use_delay, + const char* name) { + check_cuda_float_rank2(q, "registered forward attention q"); + check_cuda_float_bank(input_k_step, "registered forward attention input_k"); + check_cuda_float_bank(input_v_step, "registered forward attention input_v"); + check_cuda_float_bank(recurrent_k, "registered forward attention recurrent_k"); + check_cuda_float_bank(recurrent_v, "registered forward attention recurrent_v"); + check_cuda_float_bank(out, name); + check_cuda_int_rank2(receiver_sender_idx, "registered forward attention sender index"); + check_cuda_float_rank1(local_distance, "registered forward attention distance"); + check_cuda_int_rank1(local_delay, "registered forward attention delay"); + check_cuda_long_rank1(step_flat, "registered forward attention step"); + const int B = static_cast(input_k_step.size(0)); + const int input_senders = static_cast(input_k_step.size(1)); + const int recurrent_senders = static_cast(recurrent_k.size(1)); + const int receiver_count = static_cast(q.size(0)); + const int degree = static_cast(receiver_sender_idx.size(1)); + const int head_dim = static_cast(q.size(1)); + const int key_dim = static_cast(input_k_step.size(2)); + const int value_dim = static_cast(input_v_step.size(2)); + validate_registered_partitioned_attention_executor_rows( + forward_executor_rows, + forward_executor_binding_rows, + kForwardDirectionOpcode, + executor_id, + bucket_ordinal, + receiver_count); 
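+  // Launch-sizing sketch for the partitioned-attention kernel call below, following this
+  // file's own convention (assumes the std::min<int64_t> overload, kThreadsPerBlock threads
+  // per block, and the current CUDA stream held in `stream`; all other names are the locals
+  // computed in this function):
+  //   const int blocks = static_cast<int>(std::min<int64_t>(
+  //       4096, (total_elements + kThreadsPerBlock - 1) / kThreadsPerBlock));
+  //   registered_forward_partitioned_attention_kernel
+  //       <<<blocks, kThreadsPerBlock, 0, stream>>>(/* q, K/V banks, indices, out, dims */);
+  // One thread covers one element of `out` [B, receiver_count, value_dim]; the grid is
+  // capped at 4096 blocks, so large outputs are strided inside the kernel.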
+ TORCH_CHECK(degree <= kMaxRegisteredAttentionOffsets, name, " degree exceeds registered attention limit"); + TORCH_CHECK( + registered_tensor_matches_shape(out, {B, receiver_count, value_dim}), + name, + " compiler runtime buffer shape mismatch"); + TORCH_CHECK( + input_v_step.size(0) == B && recurrent_k.size(0) == B && recurrent_v.size(0) == B, + name, + " bank batch mismatch"); + TORCH_CHECK(recurrent_k.size(2) == key_dim && key_dim >= head_dim, name, " key bank width mismatch"); + TORCH_CHECK(input_v_step.size(2) == recurrent_v.size(2), name, " V dimension mismatch"); + TORCH_CHECK(receiver_sender_idx.size(0) == receiver_count, name, " receiver row mismatch"); + TORCH_CHECK(local_distance.size(0) == degree && local_delay.size(0) == degree, name, " offset metadata mismatch"); + TORCH_CHECK(step_flat.size(0) == B, name, " step tensor batch mismatch"); + const int64_t total_elements = static_cast(B) * receiver_count * value_dim; + if (total_elements == 0) { + return out; + } + const int blocks = static_cast(std::min( + 4096, + (total_elements + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + registered_forward_partitioned_attention_kernel<<>>( + q.data_ptr(), + input_k_step.data_ptr(), + input_v_step.data_ptr(), + recurrent_k.data_ptr(), + recurrent_v.data_ptr(), + receiver_sender_idx.data_ptr(), + local_distance.data_ptr(), + local_delay.data_ptr(), + step_flat.data_ptr(), + out.data_ptr(), + total_elements, + B, + receiver_count, + input_senders, + recurrent_senders, + degree, + head_dim, + key_dim, + value_dim, + 1.0f / std::sqrt(static_cast(head_dim > 0 ? head_dim : 1)), + static_cast(distance_scale), + use_delay); + check_launch("registered_forward_partitioned_attention_kernel"); + return out; +} + +inline at::Tensor run_neighborhood_attention_project_message( + const RegisteredForwardMessageExecutorState& message_executor, + const at::Tensor& input_k_step, + const at::Tensor& input_v_step, + const at::Tensor& recurrent_hidden, + const at::Tensor& recurrent_k, + const at::Tensor& recurrent_v, + const at::Tensor& receiver_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& step_flat, + const at::Tensor& recurrent_msg_output_override, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t runtime_buffer_logical_index, + double distance_scale, + bool use_delay, + std::vector* memory_stage_rows, + int64_t memory_stage_local_step) { + (void)memory_stage_rows; + (void)memory_stage_local_step; + (void)recurrent_hidden; + const at::Tensor recurrent_q = registered_forward_message_tensor(message_executor, "message_recurrent_query"); + const std::vector recurrent_msg_shape = { + input_k_step.size(0), + recurrent_q.size(0), + input_v_step.size(2), + }; + at::Tensor recurrent_msg = recurrent_msg_output_override.defined() && recurrent_msg_output_override.numel() > 0 + ? 
recurrent_msg_output_override + : registered_runtime_buffer_for_role( + runtime_buffer_tensors, + runtime_buffer_rows, + kRuntimeBufferRoleForwardRecurrentMsg, + runtime_buffer_logical_index, + recurrent_msg_shape, + "registered fused forward recurrent message"); + TORCH_CHECK( + registered_tensor_matches_shape(recurrent_msg, recurrent_msg_shape), + "registered fused forward recurrent message output override shape mismatch"); + return run_registered_forward_partitioned_attention_into( + recurrent_q, + input_k_step, + input_v_step, + recurrent_k, + recurrent_v, + receiver_sender_idx, + local_distance, + local_delay, + step_flat, + recurrent_msg, + forward_executor_rows, + forward_executor_binding_rows, + message_executor.span.executor_id, + message_executor.span.bucket_ordinal, + distance_scale, + use_delay, + "registered fused forward recurrent message"); +} + +inline at::Tensor run_registered_forward_fixed_slot_context_key_bank( + const at::Tensor& sender_slot_key, + const at::Tensor& sender_context_key, + int64_t batch_size, + int64_t time_steps, + const at::TensorOptions& options, + const char* subject) { + check_cuda_float_rank2(sender_slot_key, subject); + check_cuda_float_rank2(sender_context_key, subject); + TORCH_CHECK(sender_slot_key.sizes() == sender_context_key.sizes(), subject, " key source shape mismatch"); + const int sender_count = static_cast(sender_slot_key.size(0)); + const int key_part_dim = static_cast(sender_slot_key.size(1)); + TORCH_CHECK(key_part_dim > 0, subject, " key dimension must be positive"); + auto sender_k = at::empty({time_steps, batch_size, sender_count, 2 * key_part_dim}, options); + const int64_t total_elements = static_cast(time_steps) * batch_size * sender_count * 2 * key_part_dim; + if (total_elements == 0) { + return sender_k; + } + const int blocks = static_cast(std::min( + 4096, + (total_elements + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + registered_forward_fixed_slot_context_key_sequence_kernel<<>>( + sender_slot_key.data_ptr(), + sender_context_key.data_ptr(), + sender_k.data_ptr(), + total_elements, + static_cast(batch_size), + static_cast(time_steps), + sender_count, + key_part_dim); + check_launch("registered_forward_fixed_slot_context_key_sequence_kernel"); + return sender_k; +} + +inline at::Tensor run_registered_forward_sender_value_sequence( + const at::Tensor& sender_cells_seq, + const at::Tensor& direct_weight, + const at::Tensor& grouped_weight, + int64_t group_size, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal) { + TORCH_CHECK(sender_cells_seq.is_cuda() && sender_cells_seq.is_contiguous(), "sender_cells_seq must be CUDA contiguous"); + TORCH_CHECK(sender_cells_seq.scalar_type() == at::kFloat && sender_cells_seq.dim() == 4, "sender_cells_seq must be float32 [B,T,N,H]"); + const int B = static_cast(sender_cells_seq.size(0)); + const int T = static_cast(sender_cells_seq.size(1)); + const int sender_count = static_cast(sender_cells_seq.size(2)); + const int hidden_dim = static_cast(sender_cells_seq.size(3)); + const bool has_grouped = grouped_weight.defined() && grouped_weight.numel() > 0; + const bool has_direct = direct_weight.defined() && direct_weight.numel() > 0; + TORCH_CHECK(has_direct || has_grouped, "registered sender value projection requires direct or grouped weight"); + int value_dim = 0; + if (has_grouped) { + TORCH_CHECK(group_size > 0, "group_size must be positive for 
grouped sender value projection"); + TORCH_CHECK(sender_count % static_cast(group_size) == 0, "group_size must divide sender count"); + TORCH_CHECK( + grouped_weight.is_cuda() && grouped_weight.is_contiguous() && grouped_weight.scalar_type() == at::kFloat && + grouped_weight.dim() == 3, + "grouped value weight must be float32 [G,H,V]"); + value_dim = static_cast(grouped_weight.size(2)); + TORCH_CHECK( + grouped_weight.size(0) == sender_count / static_cast(group_size) && + grouped_weight.size(1) == hidden_dim, + "grouped value weight shape mismatch"); + } else { + TORCH_CHECK( + direct_weight.is_cuda() && direct_weight.is_contiguous() && direct_weight.scalar_type() == at::kFloat && + direct_weight.dim() == 3, + "direct value weight must be float32 [N,H,V]"); + value_dim = static_cast(direct_weight.size(2)); + TORCH_CHECK( + direct_weight.size(0) == sender_count && direct_weight.size(1) == hidden_dim, + "direct value weight shape mismatch"); + } + validate_registered_executor_binding_rows( + forward_executor_rows, + forward_executor_binding_rows, + kForwardDirectionOpcode, + executor_id, + bucket_ordinal, + "registered sender value sequence projection"); + auto sender_v = at::empty({T, B, sender_count, value_dim}, sender_cells_seq.options()); + const int64_t total_elements = static_cast(B) * T * sender_count * value_dim; + if (total_elements == 0) { + return sender_v; + } + const int blocks = static_cast(std::min( + 4096, + (total_elements + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + registered_forward_sender_value_sequence_kernel<<>>( + sender_cells_seq.data_ptr(), + has_direct ? direct_weight.data_ptr() : nullptr, + has_grouped ? grouped_weight.data_ptr() : nullptr, + sender_v.data_ptr(), + total_elements, + B, + T, + sender_count, + hidden_dim, + value_dim, + static_cast(group_size), + has_grouped); + check_launch("registered_forward_sender_value_sequence_kernel"); + return sender_v; +} + +inline at::Tensor run_registered_forward_sender_value_step( + const at::Tensor& sender_cells, + const at::Tensor& direct_weight, + const at::Tensor& grouped_weight, + int64_t group_size, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal) { + check_cuda_float_bank(sender_cells, "registered sender value cells"); + const int B = static_cast(sender_cells.size(0)); + const int sender_count = static_cast(sender_cells.size(1)); + const int hidden_dim = static_cast(sender_cells.size(2)); + const bool has_grouped = grouped_weight.defined() && grouped_weight.numel() > 0; + const bool has_direct = direct_weight.defined() && direct_weight.numel() > 0; + TORCH_CHECK(has_direct || has_grouped, "registered sender value step projection requires direct or grouped weight"); + int value_dim = 0; + if (has_grouped) { + TORCH_CHECK(group_size > 0, "group_size must be positive for grouped sender value step projection"); + TORCH_CHECK(sender_count % static_cast(group_size) == 0, "group_size must divide sender count"); + TORCH_CHECK( + grouped_weight.is_cuda() && grouped_weight.is_contiguous() && grouped_weight.scalar_type() == at::kFloat && + grouped_weight.dim() == 3, + "grouped value step weight must be float32 [G,H,V]"); + value_dim = static_cast(grouped_weight.size(2)); + TORCH_CHECK( + grouped_weight.size(0) == sender_count / static_cast(group_size) && + grouped_weight.size(1) == hidden_dim, + "grouped value step weight shape mismatch"); + } else { + TORCH_CHECK( + 
direct_weight.is_cuda() && direct_weight.is_contiguous() && direct_weight.scalar_type() == at::kFloat && + direct_weight.dim() == 3, + "direct value step weight must be float32 [N,H,V]"); + value_dim = static_cast(direct_weight.size(2)); + TORCH_CHECK( + direct_weight.size(0) == sender_count && direct_weight.size(1) == hidden_dim, + "direct value step weight shape mismatch"); + } + validate_registered_executor_binding_rows( + forward_executor_rows, + forward_executor_binding_rows, + kForwardDirectionOpcode, + executor_id, + bucket_ordinal, + "registered sender value step projection"); + auto sender_v = at::empty({B, sender_count, value_dim}, sender_cells.options()); + const int64_t total_elements = static_cast(B) * sender_count * value_dim; + if (total_elements == 0) { + return sender_v; + } + const at::Tensor& weight = has_grouped ? grouped_weight : direct_weight; + const at::Tensor empty_bias = sender_cells.new_empty({0}); + fabric::cuda::ops::dense_affine_out_cuda( + sender_cells, + weight, + empty_bias, + sender_v, + fabric::cuda::ops::DenseAffineLayout::ReceiverMajor, + group_size, + fabric::cuda::ops::DenseAffineOutputMode::Overwrite); + return sender_v; +} + +inline RegisteredForwardMessageExecutorState bind_fixed_slot_context_message_handler( + const RegisteredFusedProgramSpan& span, + const RegisteredNativeStrategyRow& native_strategy, + const at::Tensor& boundary_seq, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& forward_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t input_count, + int64_t head_dim, + int64_t value_dim) { + const at::Tensor query_slot_weight = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + forward_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_query_slot_weight", + true, + false, + span.executor_row_index, + span.bucket_ordinal, + "forward", + "registered fixed-slot context query slot weight"); + const at::Tensor query_context_scalar = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + forward_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_query_context_scalar", + true, + false, + span.executor_row_index, + span.bucket_ordinal, + "forward", + "registered fixed-slot context query scalar"); + const at::Tensor sender_slot_key = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + forward_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_sender_slot_key_weight", + true, + false, + span.executor_row_index, + span.bucket_ordinal, + "forward", + "registered fixed-slot context sender slot key"); + const at::Tensor sender_context_key = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + forward_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_sender_context_key", + true, + false, + span.executor_row_index, + span.bucket_ordinal, + "forward", + "registered fixed-slot context sender context key"); + const at::Tensor input_value_weight = program_tensor_for_native_strategy_access( + program_tensors, + 
program_tensor_binding_rows, + forward_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_input_value_weight", + true, + true, + span.executor_row_index, + span.bucket_ordinal, + "forward", + "registered fixed-slot context input value weight"); + const at::Tensor input_group_value_weight = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + forward_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_input_group_value_weight", + true, + true, + span.executor_row_index, + span.bucket_ordinal, + "forward", + "registered fixed-slot context grouped input value weight"); + const at::Tensor recurrent_value_weight = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + forward_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_recurrent_value_weight", + true, + false, + span.executor_row_index, + span.bucket_ordinal, + "forward", + "registered fixed-slot context recurrent value weight"); + const at::Tensor message_output_weight = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + forward_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_output_weight", + true, + false, + span.executor_row_index, + span.bucket_ordinal, + "forward", + "registered fixed-slot context output weight"); + const int64_t input_group_size = input_group_value_weight.defined() && input_group_value_weight.numel() > 0 + ? input_count / input_group_value_weight.size(0) + : 1; + const at::Tensor input_k_seq = run_registered_forward_fixed_slot_context_key_bank( + sender_slot_key.slice(0, 0, input_count), + sender_context_key.slice(0, 0, input_count), + boundary_seq.size(0), + boundary_seq.size(1), + boundary_seq.options(), + "registered fixed-slot context input key"); + const at::Tensor input_v_seq = run_registered_forward_sender_value_sequence( + boundary_seq, + input_value_weight, + input_group_value_weight, + input_group_size, + forward_executor_rows, + forward_executor_binding_rows, + span.executor_id, + span.bucket_ordinal); + return RegisteredForwardMessageExecutorState{ + span, + native_strategy, + { + {"message_query_slot_weight", query_slot_weight}, + {"message_query_context_scalar", query_context_scalar}, + {"message_sender_slot_key_weight", sender_slot_key}, + {"message_sender_context_key", sender_context_key}, + {"message_input_value_weight", input_value_weight}, + {"message_input_group_value_weight", input_group_value_weight}, + {"message_recurrent_value_weight", recurrent_value_weight}, + {"message_output_weight", message_output_weight}, + }, + { + {"input_k_seq", input_k_seq}, + {"input_v_seq", input_v_seq}, + }, + input_group_size, + message_output_weight.size(0), + }; +} + +inline std::vector run_fixed_slot_context_recurrent_kv( + const RegisteredForwardMessageExecutorState& message_executor, + const at::Tensor& recurrent_hidden, + const at::Tensor& empty, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t head_dim, + int64_t value_dim, + bool materialize_key_bank) { + (void)head_dim; + (void)value_dim; + const at::Tensor input_k_seq = registered_forward_message_cache_tensor(message_executor, "input_k_seq"); + const at::Tensor sender_slot_key = + 
registered_forward_message_tensor(message_executor, "message_sender_slot_key_weight"); + const at::Tensor sender_context_key = + registered_forward_message_tensor(message_executor, "message_sender_context_key"); + const at::Tensor recurrent_value_weight = + registered_forward_message_tensor(message_executor, "message_recurrent_value_weight"); + at::Tensor recurrent_k = materialize_key_bank + ? run_registered_forward_fixed_slot_context_key_bank( + sender_slot_key.slice( + 0, + input_k_seq.size(2), + input_k_seq.size(2) + recurrent_hidden.size(1)), + sender_context_key.slice( + 0, + input_k_seq.size(2), + input_k_seq.size(2) + recurrent_hidden.size(1)), + recurrent_hidden.size(0), + 1, + recurrent_hidden.options(), + "registered fixed-slot context recurrent key").select(0, 0).contiguous() + : at::Tensor(); + at::Tensor recurrent_v = materialize_key_bank + ? run_registered_forward_sender_value_step( + recurrent_hidden, + recurrent_value_weight, + empty, + 1, + forward_executor_rows, + forward_executor_binding_rows, + message_executor.span.executor_id, + message_executor.span.bucket_ordinal) + : at::Tensor(); + return {recurrent_k, recurrent_v}; +} + +inline at::Tensor run_fixed_slot_context_message( + const RegisteredForwardMessageExecutorState& message_executor, + const at::Tensor& input_k_step, + const at::Tensor& input_v_step, + const at::Tensor& recurrent_hidden, + const at::Tensor& recurrent_k, + const at::Tensor& recurrent_v, + const at::Tensor& receiver_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& step_flat, + const at::Tensor& recurrent_msg_output_override, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t runtime_buffer_logical_index, + double distance_scale, + bool use_delay, + std::vector* memory_stage_rows, + int64_t memory_stage_local_step) { + const at::Tensor recurrent_q = registered_forward_message_tensor(message_executor, "message_query_slot_weight"); + const at::Tensor query_context_scalar = + registered_forward_message_tensor(message_executor, "message_query_context_scalar"); + const at::Tensor sender_slot_key = + registered_forward_message_tensor(message_executor, "message_sender_slot_key_weight"); + const at::Tensor sender_context_key = + registered_forward_message_tensor(message_executor, "message_sender_context_key"); + const at::Tensor message_output_weight = + registered_forward_message_tensor(message_executor, "message_output_weight"); + const at::Tensor recurrent_value_weight = + registered_forward_message_tensor(message_executor, "message_recurrent_value_weight"); + const std::vector recurrent_msg_shape = { + input_v_step.size(0), + recurrent_q.size(0), + message_executor.message_output_dim, + }; + at::Tensor recurrent_msg = recurrent_msg_output_override.defined() && recurrent_msg_output_override.numel() > 0 + ? 
recurrent_msg_output_override + : registered_runtime_buffer_for_role( + runtime_buffer_tensors, + runtime_buffer_rows, + kRuntimeBufferRoleForwardRecurrentMsg, + runtime_buffer_logical_index, + recurrent_msg_shape, + "registered fixed-slot context recurrent message"); + TORCH_CHECK( + registered_tensor_matches_shape(recurrent_msg, recurrent_msg_shape), + "registered fixed-slot context recurrent message output override shape mismatch"); + check_cuda_float_rank2(recurrent_q, "registered fixed-slot context query slot"); + check_cuda_float_rank1(query_context_scalar, "registered fixed-slot context query scalar"); + check_cuda_float_rank2(sender_slot_key, "registered fixed-slot context sender slot key"); + check_cuda_float_rank2(sender_context_key, "registered fixed-slot context context key"); + check_cuda_float_rank2(message_output_weight, "registered fixed-slot context output weight"); + check_cuda_float_bank(input_k_step, "registered fixed-slot context input key"); + check_cuda_float_bank(input_v_step, "registered fixed-slot context input value"); + if (recurrent_k.defined()) { + check_cuda_float_bank(recurrent_k, "registered fixed-slot context recurrent key"); + } + const bool recurrent_value_deferred = !recurrent_v.defined(); + if (recurrent_value_deferred) { + check_cuda_float_bank(recurrent_hidden, "registered fixed-slot context recurrent hidden value source"); + check_cuda_float_bank(recurrent_value_weight, "registered fixed-slot context recurrent value weight"); + } else { + check_cuda_float_bank(recurrent_v, "registered fixed-slot context recurrent value"); + } + check_cuda_int_rank2(receiver_sender_idx, "registered fixed-slot context sender index"); + check_cuda_float_rank1(local_distance, "registered fixed-slot context distance"); + check_cuda_int_rank1(local_delay, "registered fixed-slot context delay"); + check_cuda_long_rank1(step_flat, "registered fixed-slot context step"); + const int B = static_cast(input_v_step.size(0)); + const int receiver_count = static_cast(recurrent_q.size(0)); + const int input_senders = static_cast(input_v_step.size(1)); + const int recurrent_senders = static_cast( + recurrent_value_deferred ? 
recurrent_hidden.size(1) : recurrent_v.size(1)); + const int degree = static_cast(receiver_sender_idx.size(1)); + const int head = static_cast(recurrent_q.size(1)); + const int value = static_cast(input_v_step.size(2)); + const int message = static_cast(message_executor.message_output_dim); + TORCH_CHECK(degree <= kMaxRegisteredAttentionOffsets, "fixed-slot context degree exceeds registered attention limit"); + if (recurrent_value_deferred) { + TORCH_CHECK( + recurrent_hidden.size(0) == B && recurrent_hidden.size(1) == recurrent_senders, + "fixed-slot context deferred recurrent value source shape mismatch"); + TORCH_CHECK( + recurrent_value_weight.size(0) == recurrent_senders && + recurrent_value_weight.size(1) == recurrent_hidden.size(2) && + recurrent_value_weight.size(2) == value, + "fixed-slot context deferred recurrent value weight shape mismatch"); + } else { + TORCH_CHECK(recurrent_v.size(2) == value, "fixed-slot context recurrent value dimension mismatch"); + } + TORCH_CHECK(value >= head, "fixed-slot context query scalar requires value_dim >= head_dim"); + TORCH_CHECK( + sender_slot_key.size(0) == input_senders + recurrent_senders && + sender_context_key.size(0) == input_senders + recurrent_senders, + "fixed-slot context sender key table shape mismatch"); + TORCH_CHECK( + message_output_weight.size(0) == message && + message_output_weight.size(1) == value, + "fixed-slot context output weight shape mismatch"); + const int64_t row_count = static_cast(B) * receiver_count; + if (row_count == 0 || message == 0) { + return recurrent_msg; + } + const auto stream = at::cuda::getCurrentCUDAStream(); + append_registered_forward_memory_stage_row( + memory_stage_rows, + recurrent_msg, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageBeforeOutputWeight); + const at::Tensor output_weight_t = message_output_weight.transpose(0, 1).contiguous(); + append_registered_forward_memory_stage_row( + memory_stage_rows, + output_weight_t, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageAfterOutputWeight); + auto run_message_chunk = [&]( + const at::Tensor& input_v_chunk, + const at::Tensor& recurrent_v_chunk, + const at::Tensor& step_flat_chunk, + const at::Tensor& recurrent_msg_chunk) { + const int chunk_batch = static_cast(input_v_chunk.size(0)); + const int64_t chunk_row_count = static_cast(chunk_batch) * receiver_count; + if (chunk_row_count == 0) { + return; + } + const int blocks = static_cast(std::min( + 4096, + chunk_row_count)); + append_registered_forward_memory_stage_row( + memory_stage_rows, + recurrent_msg_chunk, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageBeforeWeightedValueAlloc); + at::Tensor weighted_value = at::empty({chunk_batch, receiver_count, value}, recurrent_msg.options()); + append_registered_forward_memory_stage_row( + memory_stage_rows, + weighted_value, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageAfterWeightedValueAlloc); + if (degree <= 32) { + constexpr int kWarpsPerBlock = kThreadsPerBlock / 32; + const int warp_blocks = static_cast(std::min( + 4096, + (chunk_row_count + kWarpsPerBlock - 1) / kWarpsPerBlock)); + registered_forward_fixed_slot_context_weighted_value_warp_kernel<<>>( + recurrent_q.data_ptr(), + query_context_scalar.data_ptr(), + sender_slot_key.data_ptr(), + sender_context_key.data_ptr(), + input_v_chunk.data_ptr(), + recurrent_v_chunk.data_ptr(), + receiver_sender_idx.data_ptr(), + local_distance.data_ptr(), + local_delay.data_ptr(), + step_flat_chunk.data_ptr(), + 
weighted_value.data_ptr(), + chunk_row_count, + chunk_batch, + receiver_count, + input_senders, + recurrent_senders, + degree, + head, + value, + 1.0f / std::sqrt(static_cast(2 * (head > 0 ? head : 1))), + static_cast(distance_scale), + use_delay); + check_launch("registered_forward_fixed_slot_context_weighted_value_warp_kernel"); + } else { + registered_forward_fixed_slot_context_weighted_value_kernel<<>>( + recurrent_q.data_ptr(), + query_context_scalar.data_ptr(), + sender_slot_key.data_ptr(), + sender_context_key.data_ptr(), + input_v_chunk.data_ptr(), + recurrent_v_chunk.data_ptr(), + receiver_sender_idx.data_ptr(), + local_distance.data_ptr(), + local_delay.data_ptr(), + step_flat_chunk.data_ptr(), + weighted_value.data_ptr(), + chunk_row_count, + chunk_batch, + receiver_count, + input_senders, + recurrent_senders, + degree, + head, + value, + 1.0f / std::sqrt(static_cast(2 * (head > 0 ? head : 1))), + static_cast(distance_scale), + use_delay); + check_launch("registered_forward_fixed_slot_context_weighted_value_kernel"); + } + append_registered_forward_memory_stage_row( + memory_stage_rows, + weighted_value, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageBeforeOutputWeight); + append_registered_forward_memory_stage_row( + memory_stage_rows, + output_weight_t, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageAfterOutputWeight); + const at::Tensor weighted_value_rows = weighted_value.reshape({chunk_row_count, value}); + at::Tensor recurrent_msg_rows = recurrent_msg_chunk.reshape({chunk_row_count, message}); + append_registered_forward_memory_stage_row( + memory_stage_rows, + weighted_value, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageAfterWeightedValue); + constexpr int64_t kProjectedMessageChunkBytes = 1024LL * 1024LL * 1024LL; + const int64_t projected_row_bytes = static_cast(message) * static_cast(sizeof(float)); + const int64_t projected_rows_per_chunk = projected_row_bytes <= 0 + ? 
chunk_row_count + : std::max(1, kProjectedMessageChunkBytes / projected_row_bytes); + const bool can_project_into_recurrent_msg = weighted_value.data_ptr() != recurrent_msg_chunk.data_ptr(); + if (can_project_into_recurrent_msg && chunk_row_count <= projected_rows_per_chunk) { + append_registered_forward_memory_stage_row( + memory_stage_rows, + weighted_value, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageBeforeProjectedGemm); + const at::Tensor empty_bias = weighted_value.new_empty({0}); + fabric::cuda::ops::dense_affine_out_cuda( + weighted_value, + output_weight_t, + empty_bias, + recurrent_msg_chunk, + fabric::cuda::ops::DenseAffineLayout::ReceiverMajor, + 1, + fabric::cuda::ops::DenseAffineOutputMode::Overwrite); + append_registered_forward_memory_stage_row( + memory_stage_rows, + recurrent_msg_chunk, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageAfterProjectedGemm); + append_registered_forward_memory_stage_row( + memory_stage_rows, + recurrent_msg_chunk, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageAfterProjectedContiguous); + TORCH_CHECK( + registered_tensor_matches_shape(recurrent_msg_rows, {chunk_row_count, message}), + "fixed-slot context projected message GEMM shape mismatch"); + append_registered_forward_memory_stage_row( + memory_stage_rows, + recurrent_msg_chunk, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageAfterProjected); + append_registered_forward_memory_stage_row( + memory_stage_rows, + recurrent_msg_chunk, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageBeforeNormalize); + registered_forward_fixed_slot_context_normalize_rows_inplace_kernel<<>>( + recurrent_msg_rows.data_ptr(), + chunk_row_count, + message, + 1.0e-5f); + check_launch("registered_forward_fixed_slot_context_normalize_rows_inplace_kernel"); + append_registered_forward_memory_stage_row( + memory_stage_rows, + recurrent_msg_chunk, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageAfterNormalize); + } else if (chunk_row_count <= projected_rows_per_chunk) { + append_registered_forward_memory_stage_row( + memory_stage_rows, + weighted_value, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageBeforeProjectedGemm); + at::Tensor projected = at::matmul( + weighted_value_rows, + output_weight_t); + append_registered_forward_memory_stage_row( + memory_stage_rows, + projected, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageAfterProjectedGemm); + projected = projected.contiguous(); + append_registered_forward_memory_stage_row( + memory_stage_rows, + projected, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageAfterProjectedContiguous); + TORCH_CHECK( + registered_tensor_matches_shape(projected, {chunk_row_count, message}), + "fixed-slot context projected message GEMM shape mismatch"); + append_registered_forward_memory_stage_row( + memory_stage_rows, + projected, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageAfterProjected); + append_registered_forward_memory_stage_row( + memory_stage_rows, + projected, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageBeforeNormalize); + registered_forward_fixed_slot_context_normalize_rows_kernel<<>>( + projected.data_ptr(), + recurrent_msg_rows.data_ptr(), + chunk_row_count, + message, + 1.0e-5f); + check_launch("registered_forward_fixed_slot_context_normalize_rows_kernel"); + append_registered_forward_memory_stage_row( + memory_stage_rows, + recurrent_msg_chunk, + memory_stage_local_step, + 
kRegisteredForwardMemoryStageMessageAfterNormalize); + } else { + for (int64_t row_start = 0; row_start < chunk_row_count; row_start += projected_rows_per_chunk) { + const int64_t chunk_rows = std::min(projected_rows_per_chunk, chunk_row_count - row_start); + append_registered_forward_memory_stage_row( + memory_stage_rows, + weighted_value, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageBeforeProjectedGemm); + const int chunk_blocks = static_cast(std::min( + 4096, + chunk_rows)); + at::Tensor recurrent_msg_rows_chunk = recurrent_msg_rows.narrow(0, row_start, chunk_rows); + if (can_project_into_recurrent_msg) { + at::Tensor projected = at::matmul( + weighted_value_rows.narrow(0, row_start, chunk_rows), + output_weight_t); + recurrent_msg_rows_chunk.copy_(projected); + append_registered_forward_memory_stage_row( + memory_stage_rows, + recurrent_msg_chunk, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageAfterProjectedGemm); + append_registered_forward_memory_stage_row( + memory_stage_rows, + recurrent_msg_chunk, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageAfterProjectedContiguous); + TORCH_CHECK( + registered_tensor_matches_shape(recurrent_msg_rows_chunk, {chunk_rows, message}), + "fixed-slot context projected message chunk GEMM shape mismatch"); + append_registered_forward_memory_stage_row( + memory_stage_rows, + recurrent_msg_chunk, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageAfterProjected); + append_registered_forward_memory_stage_row( + memory_stage_rows, + recurrent_msg_chunk, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageBeforeNormalize); + registered_forward_fixed_slot_context_normalize_rows_inplace_kernel<<>>( + recurrent_msg_rows_chunk.data_ptr(), + chunk_rows, + message, + 1.0e-5f); + check_launch("registered_forward_fixed_slot_context_normalize_rows_inplace_kernel_chunk"); + } else { + at::Tensor projected = at::matmul( + weighted_value_rows.narrow(0, row_start, chunk_rows), + output_weight_t); + append_registered_forward_memory_stage_row( + memory_stage_rows, + projected, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageAfterProjectedGemm); + projected = projected.contiguous(); + append_registered_forward_memory_stage_row( + memory_stage_rows, + projected, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageAfterProjectedContiguous); + TORCH_CHECK( + registered_tensor_matches_shape(projected, {chunk_rows, message}), + "fixed-slot context projected message chunk GEMM shape mismatch"); + append_registered_forward_memory_stage_row( + memory_stage_rows, + projected, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageAfterProjected); + append_registered_forward_memory_stage_row( + memory_stage_rows, + projected, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageBeforeNormalize); + registered_forward_fixed_slot_context_normalize_rows_kernel<<>>( + projected.data_ptr(), + recurrent_msg_rows_chunk.data_ptr(), + chunk_rows, + message, + 1.0e-5f); + check_launch("registered_forward_fixed_slot_context_normalize_rows_kernel_chunk"); + } + append_registered_forward_memory_stage_row( + memory_stage_rows, + recurrent_msg_chunk, + memory_stage_local_step, + kRegisteredForwardMemoryStageMessageAfterNormalize); + } + } + }; + constexpr int64_t kMessageProducerConsumerChunkBytes = 768LL * 1024LL * 1024LL; + const auto bounded_message_batch_chunk = [&](int64_t extra_bytes_per_batch) { + const int64_t weighted_value_bytes_per_batch = std::max( + 1, 
+ static_cast(receiver_count) * value * static_cast(sizeof(float))); + const int64_t bytes_per_batch = std::max( + 1, + weighted_value_bytes_per_batch + std::max(0, extra_bytes_per_batch)); + return std::max( + 1, + std::min(B, kMessageProducerConsumerChunkBytes / bytes_per_batch)); + }; + if (recurrent_value_deferred) { + constexpr int64_t kDeferredRecurrentValueChunkBytes = 512LL * 1024LL * 1024LL; + const int64_t value_bank_bytes_per_batch = std::max( + 1, + static_cast(recurrent_senders) * value * static_cast(sizeof(float))); + const int64_t weighted_value_bytes_per_batch = std::max( + 1, + static_cast(receiver_count) * value * static_cast(sizeof(float))); + const int64_t bytes_per_batch = std::max( + 1, + value_bank_bytes_per_batch + weighted_value_bytes_per_batch); + const int64_t batch_chunk = std::max( + 1, + std::min(B, kDeferredRecurrentValueChunkBytes / bytes_per_batch)); + const at::Tensor empty_weight = recurrent_hidden.new_empty({0}); + for (int64_t batch_start = 0; batch_start < B; batch_start += batch_chunk) { + const int64_t current_batch = std::min(batch_chunk, B - batch_start); + at::Tensor recurrent_hidden_chunk = recurrent_hidden.narrow(0, batch_start, current_batch); + at::Tensor recurrent_v_chunk = run_registered_forward_sender_value_step( + recurrent_hidden_chunk, + recurrent_value_weight, + empty_weight, + 1, + forward_executor_rows, + forward_executor_binding_rows, + message_executor.span.executor_id, + message_executor.span.bucket_ordinal); + run_message_chunk( + input_v_step.narrow(0, batch_start, current_batch), + recurrent_v_chunk, + step_flat.narrow(0, batch_start, current_batch), + recurrent_msg.narrow(0, batch_start, current_batch)); + } + return recurrent_msg; + } + if (value == message) { + const int64_t batch_chunk = bounded_message_batch_chunk(0); + for (int64_t batch_start = 0; batch_start < B; batch_start += batch_chunk) { + const int64_t current_batch = std::min(batch_chunk, B - batch_start); + run_message_chunk( + input_v_step.narrow(0, batch_start, current_batch), + recurrent_v.narrow(0, batch_start, current_batch), + step_flat.narrow(0, batch_start, current_batch), + recurrent_msg.narrow(0, batch_start, current_batch)); + } + return recurrent_msg; + } + run_message_chunk(input_v_step, recurrent_v, step_flat, recurrent_msg); + return recurrent_msg; +} + +inline at::Tensor run_fixed_slot_context_stream_transition_input( + const RegisteredForwardMessageExecutorState& message_executor, + const at::Tensor& input_k_step, + const at::Tensor& input_v_step, + const at::Tensor& recurrent_hidden, + const at::Tensor& recurrent_k, + const at::Tensor& recurrent_v, + const at::Tensor& receiver_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& step_flat, + const RegisteredTransitionInputProjectionTarget& transition_target, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t runtime_buffer_logical_index, + double distance_scale, + bool use_delay, + std::vector* memory_stage_rows, + int64_t memory_stage_local_step) { + TORCH_CHECK( + transition_target.defined, + "registered fixed-slot context stream_transition_input requires a compiler transition input target"); + check_cuda_float_bank(input_v_step, "registered fixed-slot stream_transition input value"); + check_cuda_float_bank(recurrent_hidden, "registered fixed-slot stream_transition recurrent hidden"); + 
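+  // A minimal sketch, assuming the same byte-budget pattern used by the other
+  // streamed message paths in this file: chunk = clamp(budget / per-batch
+  // float32 footprint, 1, B). Illustrative only; `sketch_stream_chunk` is not
+  // a compiler binding and is unused.
+  const auto sketch_stream_chunk = [](int64_t budget_bytes, int64_t bytes_per_batch, int64_t batch) {
+    return std::max<int64_t>(1, std::min(batch, budget_bytes / std::max<int64_t>(1, bytes_per_batch)));
+  };
+  (void)sketch_stream_chunk;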
check_cuda_float_bank(transition_target.output, "registered fixed-slot stream_transition transition output"); + TORCH_CHECK( + input_v_step.size(0) == transition_target.output.size(0), + "registered fixed-slot stream_transition batch mismatch"); + TORCH_CHECK( + message_executor.span.receiver_count == transition_target.receiver_count && + transition_target.output.size(1) == transition_target.receiver_count, + "registered fixed-slot stream_transition receiver span mismatch"); + TORCH_CHECK( + message_executor.message_output_dim == transition_target.message_dim, + "registered fixed-slot stream_transition message width mismatch"); + TORCH_CHECK( + transition_target.output.size(2) == transition_target.hidden, + "registered fixed-slot stream_transition output hidden mismatch"); + const int64_t B = input_v_step.size(0); + const int64_t receivers = transition_target.receiver_count; + const int64_t message_dim = transition_target.message_dim; + const int64_t hidden = transition_target.hidden; + const int64_t value = input_v_step.size(2); + const int64_t message_bytes_per_batch = std::max( + 1, + receivers * message_dim * static_cast(sizeof(float))); + const int64_t transition_bytes_per_batch = std::max( + 1, + receivers * hidden * static_cast(sizeof(float))); + const int64_t value_bank_bytes_per_batch = recurrent_v.defined() && recurrent_v.numel() > 0 + ? std::max(1, recurrent_v.size(1) * value * static_cast(sizeof(float))) + : std::max(1, recurrent_hidden.size(1) * value * static_cast(sizeof(float))); + constexpr int64_t kStreamTransitionChunkBytes = 512LL * 1024LL * 1024LL; + const int64_t bytes_per_batch = std::max( + 1, + message_bytes_per_batch + transition_bytes_per_batch + value_bank_bytes_per_batch); + const int64_t batch_chunk = std::max( + 1, + std::min(B, kStreamTransitionChunkBytes / bytes_per_batch)); + for (int64_t batch_start = 0; batch_start < B; batch_start += batch_chunk) { + const int64_t current_batch = std::min(batch_chunk, B - batch_start); + at::Tensor message_chunk = at::empty({current_batch, receivers, message_dim}, transition_target.output.options()); + at::Tensor recurrent_hidden_chunk = recurrent_hidden.narrow(0, batch_start, current_batch); + at::Tensor recurrent_k_chunk = recurrent_k.defined() && recurrent_k.numel() > 0 + ? recurrent_k.narrow(0, batch_start, current_batch) + : recurrent_k; + at::Tensor recurrent_v_chunk = recurrent_v.defined() && recurrent_v.numel() > 0 + ? 
recurrent_v.narrow(0, batch_start, current_batch) + : recurrent_v; + at::Tensor projected_message_chunk = run_fixed_slot_context_message( + message_executor, + input_k_step.narrow(0, batch_start, current_batch), + input_v_step.narrow(0, batch_start, current_batch), + recurrent_hidden_chunk, + recurrent_k_chunk, + recurrent_v_chunk, + receiver_sender_idx, + local_distance, + local_delay, + step_flat.narrow(0, batch_start, current_batch), + message_chunk, + runtime_buffer_tensors, + runtime_buffer_rows, + forward_executor_rows, + forward_executor_binding_rows, + runtime_buffer_logical_index, + distance_scale, + use_delay, + memory_stage_rows, + memory_stage_local_step); + at::Tensor transition_output_chunk = transition_target.output.narrow(0, batch_start, current_batch); + fabric::cuda::ops::dense_affine_out_cuda( + projected_message_chunk, + transition_target.input_weight, + transition_target.input_bias, + transition_output_chunk, + fabric::cuda::ops::DenseAffineLayout::ReceiverMajor, + 1, + fabric::cuda::ops::DenseAffineOutputMode::Overwrite); + } + return transition_target.output; +} + +inline at::Tensor run_fixed_slot_context_keyless_readout_message( + const RegisteredForwardMessageExecutorState& message_executor, + const RegisteredForwardReadoutExecutorState& readout_executor, + const at::Tensor& input_k_step, + const at::Tensor& input_v_step, + const at::Tensor& recurrent_v, + const at::Tensor& receiver_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& step_flat, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t runtime_buffer_logical_index, + double distance_scale, + bool use_delay) { + (void)input_k_step; + const at::Tensor output_q = registered_forward_readout_tensor(readout_executor, "readout_output_query"); + const at::Tensor sender_slot_key = + registered_forward_message_tensor(message_executor, "message_sender_slot_key_weight"); + check_cuda_float_rank2(output_q, "registered fixed-slot keyless readout output query"); + check_cuda_float_rank2(sender_slot_key, "registered fixed-slot keyless readout sender key"); + check_cuda_float_bank(input_v_step, "registered fixed-slot keyless readout input value"); + check_cuda_float_bank(recurrent_v, "registered fixed-slot keyless readout recurrent value"); + check_cuda_int_rank2(receiver_sender_idx, "registered fixed-slot keyless readout sender index"); + check_cuda_float_rank1(local_distance, "registered fixed-slot keyless readout distance"); + check_cuda_int_rank1(local_delay, "registered fixed-slot keyless readout delay"); + check_cuda_long_rank1(step_flat, "registered fixed-slot keyless readout step"); + const int B = static_cast(input_v_step.size(0)); + const int output_count = static_cast(output_q.size(0)); + const int input_senders = static_cast(input_v_step.size(1)); + const int recurrent_senders = static_cast(recurrent_v.size(1)); + const int degree = static_cast(receiver_sender_idx.size(1)); + const int head = static_cast(output_q.size(1)); + const int value = static_cast(input_v_step.size(2)); + if (degree > 32) { + return at::Tensor(); + } + validate_registered_partitioned_attention_executor_rows( + forward_executor_rows, + forward_executor_binding_rows, + kForwardDirectionOpcode, + readout_executor.span.executor_id, + readout_executor.span.bucket_ordinal, + output_count); + TORCH_CHECK(recurrent_v.size(0) == B, "fixed-slot keyless readout 
recurrent value batch mismatch"); + TORCH_CHECK(recurrent_v.size(2) == value, "fixed-slot keyless readout value dimension mismatch"); + TORCH_CHECK( + sender_slot_key.size(0) == input_senders + recurrent_senders && sender_slot_key.size(1) >= head, + "fixed-slot keyless readout sender key shape mismatch"); + TORCH_CHECK(receiver_sender_idx.size(0) == output_count, "fixed-slot keyless readout receiver row mismatch"); + TORCH_CHECK(local_distance.size(0) == degree && local_delay.size(0) == degree, "fixed-slot keyless readout offset mismatch"); + at::Tensor output_msg = registered_runtime_buffer_for_role( + runtime_buffer_tensors, + runtime_buffer_rows, + kRuntimeBufferRoleForwardOutputMsg, + runtime_buffer_logical_index, + { + input_v_step.size(0), + output_q.size(0), + input_v_step.size(2), + }, + "registered fixed-slot keyless forward output message"); + const int64_t row_count = static_cast(B) * output_count; + if (row_count == 0) { + return output_msg; + } + constexpr int kWarpsPerBlock = kThreadsPerBlock / 32; + const int blocks = static_cast(std::min( + 4096, + (row_count + kWarpsPerBlock - 1) / kWarpsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + registered_forward_fixed_slot_context_keyless_readout_warp_kernel<<>>( + output_q.data_ptr(), + sender_slot_key.data_ptr(), + input_v_step.data_ptr(), + recurrent_v.data_ptr(), + receiver_sender_idx.data_ptr(), + local_distance.data_ptr(), + local_delay.data_ptr(), + step_flat.data_ptr(), + output_msg.data_ptr(), + row_count, + output_count, + input_senders, + recurrent_senders, + degree, + head, + value, + 1.0f / std::sqrt(static_cast(head > 0 ? head : 1)), + static_cast(distance_scale), + use_delay); + check_launch("registered_forward_fixed_slot_context_keyless_readout_warp_kernel"); + return output_msg; +} + +inline at::Tensor run_fixed_slot_context_direct_keyless_readout_message( + const RegisteredForwardMessageExecutorState& message_executor, + const RegisteredForwardReadoutExecutorState& readout_executor, + const at::Tensor& input_k_step, + const at::Tensor& input_v_step, + const at::Tensor& recurrent_hidden, + const at::Tensor& receiver_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& step_flat, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t runtime_buffer_logical_index, + double distance_scale, + bool use_delay) { + (void)input_k_step; + const at::Tensor output_q = registered_forward_readout_tensor(readout_executor, "readout_output_query"); + const at::Tensor sender_slot_key = + registered_forward_message_tensor(message_executor, "message_sender_slot_key_weight"); + const at::Tensor recurrent_value_weight = + registered_forward_message_tensor(message_executor, "message_recurrent_value_weight"); + check_cuda_float_rank2(output_q, "registered fixed-slot direct keyless readout output query"); + check_cuda_float_rank2(sender_slot_key, "registered fixed-slot direct keyless readout sender key"); + check_cuda_float_bank(input_v_step, "registered fixed-slot direct keyless readout input value"); + check_cuda_float_bank(recurrent_hidden, "registered fixed-slot direct keyless readout recurrent hidden"); + check_cuda_int_rank2(receiver_sender_idx, "registered fixed-slot direct keyless readout sender index"); + check_cuda_float_rank1(local_distance, "registered fixed-slot direct keyless readout distance"); + 
check_cuda_int_rank1(local_delay, "registered fixed-slot direct keyless readout delay"); + check_cuda_long_rank1(step_flat, "registered fixed-slot direct keyless readout step"); + if ( + !recurrent_value_weight.defined() || recurrent_value_weight.numel() == 0 || + !recurrent_value_weight.is_cuda() || !recurrent_value_weight.is_contiguous() || + recurrent_value_weight.scalar_type() != at::kFloat || recurrent_value_weight.dim() != 3) { + return at::Tensor(); + } + const int B = static_cast(input_v_step.size(0)); + const int output_count = static_cast(output_q.size(0)); + const int input_senders = static_cast(input_v_step.size(1)); + const int recurrent_senders = static_cast(recurrent_hidden.size(1)); + const int degree = static_cast(receiver_sender_idx.size(1)); + const int head = static_cast(output_q.size(1)); + const int hidden = static_cast(recurrent_hidden.size(2)); + const int value = static_cast(input_v_step.size(2)); + TORCH_CHECK(degree <= 32, "fixed-slot streaming readout degree exceeds streaming warp strategy limit"); + validate_registered_partitioned_attention_executor_rows( + forward_executor_rows, + forward_executor_binding_rows, + kForwardDirectionOpcode, + readout_executor.span.executor_id, + readout_executor.span.bucket_ordinal, + output_count); + TORCH_CHECK( + recurrent_value_weight.size(0) == recurrent_senders && + recurrent_value_weight.size(1) == hidden && + recurrent_value_weight.size(2) == value, + "fixed-slot direct keyless readout recurrent value weight shape mismatch"); + TORCH_CHECK( + sender_slot_key.size(0) == input_senders + recurrent_senders && sender_slot_key.size(1) >= head, + "fixed-slot direct keyless readout sender key shape mismatch"); + TORCH_CHECK(receiver_sender_idx.size(0) == output_count, "fixed-slot direct keyless readout receiver row mismatch"); + TORCH_CHECK( + local_distance.size(0) == degree && local_delay.size(0) == degree, + "fixed-slot direct keyless readout offset mismatch"); + at::Tensor output_msg = registered_runtime_buffer_for_role( + runtime_buffer_tensors, + runtime_buffer_rows, + kRuntimeBufferRoleForwardOutputMsg, + runtime_buffer_logical_index, + { + input_v_step.size(0), + output_q.size(0), + input_v_step.size(2), + }, + "registered fixed-slot direct keyless forward output message"); + const int64_t row_count = static_cast(B) * output_count; + if (row_count == 0) { + return output_msg; + } + constexpr int kWarpsPerBlock = kThreadsPerBlock / 32; + const int blocks = static_cast(std::min( + 4096, + (row_count + kWarpsPerBlock - 1) / kWarpsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + registered_forward_fixed_slot_context_direct_keyless_readout_warp_kernel<<>>( + output_q.data_ptr(), + sender_slot_key.data_ptr(), + input_v_step.data_ptr(), + recurrent_hidden.data_ptr(), + recurrent_value_weight.data_ptr(), + receiver_sender_idx.data_ptr(), + local_distance.data_ptr(), + local_delay.data_ptr(), + step_flat.data_ptr(), + output_msg.data_ptr(), + row_count, + output_count, + input_senders, + recurrent_senders, + degree, + head, + hidden, + value, + 1.0f / std::sqrt(static_cast(head > 0 ? 
head : 1)), + static_cast(distance_scale), + use_delay); + check_launch("registered_forward_fixed_slot_context_direct_keyless_readout_warp_kernel"); + return output_msg; +} + +inline at::Tensor run_fixed_slot_context_stream_readout_message( + const RegisteredForwardMessageExecutorState& message_executor, + const RegisteredForwardReadoutExecutorState& readout_executor, + const at::Tensor& input_k_step, + const at::Tensor& input_v_step, + const at::Tensor& recurrent_hidden, + const at::Tensor& receiver_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& step_flat, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t runtime_buffer_logical_index, + double distance_scale, + bool use_delay) { + const at::Tensor output_q = registered_forward_readout_tensor(readout_executor, "readout_output_query"); + const at::Tensor sender_slot_key = + registered_forward_message_tensor(message_executor, "message_sender_slot_key_weight"); + const at::Tensor sender_context_key = + registered_forward_message_tensor(message_executor, "message_sender_context_key"); + const at::Tensor recurrent_value_weight = + registered_forward_message_tensor(message_executor, "message_recurrent_value_weight"); + check_cuda_float_rank2(output_q, "registered fixed-slot streaming readout output query"); + check_cuda_float_rank2(sender_slot_key, "registered fixed-slot streaming readout sender slot key"); + check_cuda_float_rank2(sender_context_key, "registered fixed-slot streaming readout sender context key"); + check_cuda_float_bank(input_k_step, "registered fixed-slot streaming readout input key"); + check_cuda_float_bank(input_v_step, "registered fixed-slot streaming readout input value"); + check_cuda_float_bank(recurrent_hidden, "registered fixed-slot streaming readout recurrent hidden"); + check_cuda_float_bank(recurrent_value_weight, "registered fixed-slot streaming readout recurrent value weight"); + check_cuda_int_rank2(receiver_sender_idx, "registered fixed-slot streaming readout sender index"); + check_cuda_float_rank1(local_distance, "registered fixed-slot streaming readout distance"); + check_cuda_int_rank1(local_delay, "registered fixed-slot streaming readout delay"); + check_cuda_long_rank1(step_flat, "registered fixed-slot streaming readout step"); + const int B = static_cast(input_v_step.size(0)); + const int output_count = static_cast(output_q.size(0)); + const int input_senders = static_cast(input_v_step.size(1)); + const int recurrent_senders = static_cast(recurrent_hidden.size(1)); + const int degree = static_cast(receiver_sender_idx.size(1)); + const int head = static_cast(output_q.size(1)); + const int key_dim = static_cast(input_k_step.size(2)); + const int sender_key_dim = static_cast(sender_slot_key.size(1)); + const int hidden = static_cast(recurrent_hidden.size(2)); + const int value = static_cast(input_v_step.size(2)); + if (degree > 32) { + return at::Tensor(); + } + validate_registered_partitioned_attention_executor_rows( + forward_executor_rows, + forward_executor_binding_rows, + kForwardDirectionOpcode, + readout_executor.span.executor_id, + readout_executor.span.bucket_ordinal, + output_count); + TORCH_CHECK( + input_k_step.size(0) == B && input_k_step.size(1) == input_senders, + "fixed-slot streaming readout input key shape mismatch"); + TORCH_CHECK( + recurrent_value_weight.size(0) == recurrent_senders && + 
recurrent_value_weight.size(1) == hidden && + recurrent_value_weight.size(2) == value, + "fixed-slot streaming readout recurrent value weight shape mismatch"); + TORCH_CHECK( + sender_context_key.sizes() == sender_slot_key.sizes() && + sender_slot_key.size(0) == input_senders + recurrent_senders, + "fixed-slot streaming readout sender key table shape mismatch"); + TORCH_CHECK( + key_dim >= head && 2 * sender_key_dim >= head, + "fixed-slot streaming readout query/key width mismatch"); + TORCH_CHECK(receiver_sender_idx.size(0) == output_count, "fixed-slot streaming readout receiver row mismatch"); + TORCH_CHECK( + local_distance.size(0) == degree && local_delay.size(0) == degree, + "fixed-slot streaming readout offset mismatch"); + TORCH_CHECK(step_flat.size(0) == B, "fixed-slot streaming readout step batch mismatch"); + at::Tensor output_msg = registered_runtime_buffer_for_role( + runtime_buffer_tensors, + runtime_buffer_rows, + kRuntimeBufferRoleForwardOutputMsg, + runtime_buffer_logical_index, + { + input_v_step.size(0), + output_q.size(0), + input_v_step.size(2), + }, + "registered fixed-slot streaming forward output message"); + const int64_t row_count = static_cast(B) * output_count; + if (row_count == 0) { + return output_msg; + } + const at::Tensor empty_weight = recurrent_hidden.new_empty({0}); + constexpr int64_t kStreamingReadoutValueChunkBytes = 512LL * 1024LL * 1024LL; + const int64_t value_bank_bytes_per_batch = std::max( + 1, + static_cast(recurrent_senders) * value * static_cast(sizeof(float))); + const int64_t output_msg_bytes_per_batch = std::max( + 1, + static_cast(output_count) * value * static_cast(sizeof(float))); + const int64_t bytes_per_batch = std::max( + 1, + value_bank_bytes_per_batch + output_msg_bytes_per_batch); + const int64_t batch_chunk = std::max( + 1, + std::min(B, kStreamingReadoutValueChunkBytes / bytes_per_batch)); + for (int64_t batch_start = 0; batch_start < B; batch_start += batch_chunk) { + const int64_t current_batch = std::min(batch_chunk, B - batch_start); + at::Tensor recurrent_hidden_chunk = recurrent_hidden.narrow(0, batch_start, current_batch); + at::Tensor recurrent_k_chunk = run_registered_forward_fixed_slot_context_key_bank( + sender_slot_key.slice(0, input_senders, input_senders + recurrent_senders), + sender_context_key.slice(0, input_senders, input_senders + recurrent_senders), + current_batch, + 1, + recurrent_hidden.options(), + "registered fixed-slot streaming readout recurrent key chunk").select(0, 0).contiguous(); + at::Tensor recurrent_v_chunk = run_registered_forward_sender_value_step( + recurrent_hidden_chunk, + recurrent_value_weight, + empty_weight, + 1, + forward_executor_rows, + forward_executor_binding_rows, + message_executor.span.executor_id, + message_executor.span.bucket_ordinal); + at::Tensor input_k_chunk = input_k_step.narrow(0, batch_start, current_batch).contiguous(); + at::Tensor input_v_chunk = input_v_step.narrow(0, batch_start, current_batch).contiguous(); + at::Tensor step_flat_chunk = step_flat.narrow(0, batch_start, current_batch).contiguous(); + at::Tensor output_msg_chunk = output_msg.narrow(0, batch_start, current_batch); + run_registered_forward_partitioned_attention_into( + output_q, + input_k_chunk, + input_v_chunk, + recurrent_k_chunk, + recurrent_v_chunk, + receiver_sender_idx, + local_distance, + local_delay, + step_flat_chunk, + output_msg_chunk, + forward_executor_rows, + forward_executor_binding_rows, + readout_executor.span.executor_id, + readout_executor.span.bucket_ordinal, + distance_scale, + 
use_delay, + "registered fixed-slot streaming readout output message chunk"); + } + return output_msg; +} diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/native_callables/message_reverse_strategies.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/native_callables/message_reverse_strategies.cuh new file mode 100644 index 00000000..7f295449 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/native_callables/message_reverse_strategies.cuh @@ -0,0 +1,599 @@ +#pragma once + +inline std::vector run_neighborhood_attention_project_recurrent_kv_backward( + const at::Tensor& grad_recurrent_k, + const at::Tensor& grad_recurrent_v, + const at::Tensor& recurrent_hidden_backend_order, + const at::Tensor& backend_to_graph_inverse_order, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredNativeStrategyRow& native_strategy, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t head_dim, + int64_t value_dim, + bool return_input_grad) { + at::Tensor recurrent_kv_weight = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_recurrent_kv_weight", + true, + false, + span.executor_row_index, + span.bucket_ordinal, + "reverse", + "fused recurrent K/V projection recurrent weight"); + std::vector backward = flat_bucket_registered_backward_sender_kv_projection_cuda( + recurrent_hidden_backend_order, + recurrent_kv_weight, + recurrent_hidden_backend_order.new_empty({0}), + grad_recurrent_k, + grad_recurrent_v, + 1, + head_dim, + value_dim, + return_input_grad, + reverse_executor_rows, + reverse_executor_binding_rows, + span.executor_id, + span.bucket_ordinal); + at::Tensor grad_hidden_graph_order = return_input_grad + ? 
backward[0].index_select(1, backend_to_graph_inverse_order).contiguous() + : recurrent_hidden_backend_order.new_empty({0}); + at::Tensor grad_weight_graph_order = backward[1].index_select(0, backend_to_graph_inverse_order).contiguous(); + return {grad_hidden_graph_order, grad_weight_graph_order}; +} + +inline std::vector run_neighborhood_attention_project_initial_recurrent_kv_backward( + const at::Tensor& grad_recurrent_k, + const at::Tensor& grad_recurrent_v, + const at::Tensor& recurrent_hidden_backend_order, + const at::Tensor& backend_to_graph_inverse_order, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredNativeStrategyRow& native_strategy, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t head_dim, + int64_t value_dim, + bool return_input_grad) { + return run_neighborhood_attention_project_recurrent_kv_backward( + grad_recurrent_k, + grad_recurrent_v, + recurrent_hidden_backend_order, + backend_to_graph_inverse_order, + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + reverse_executor_rows, + reverse_executor_binding_rows, + span, + head_dim, + value_dim, + return_input_grad); +} + +inline std::vector run_neighborhood_attention_project_recurrent_kv_forward_recompute( + const at::Tensor& input_k_reference, + const at::Tensor& recurrent_hidden_backend_order, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredNativeStrategyRow& native_strategy, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t head_dim, + int64_t value_dim) { + static_cast(input_k_reference); + at::Tensor recurrent_kv_weight = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_recurrent_kv_weight", + true, + false, + span.executor_row_index, + span.bucket_ordinal, + "reverse", + "fused output-message recurrent K/V forward recompute recurrent weight"); + return flat_bucket_registered_forward_sender_kv_step_cuda( + recurrent_hidden_backend_order, + recurrent_kv_weight, + recurrent_hidden_backend_order.new_empty({0}), + 1, + head_dim, + value_dim, + forward_executor_rows, + forward_executor_binding_rows, + span.executor_id, + span.bucket_ordinal); +} + +inline std::vector run_neighborhood_attention_project_recurrent_message_backward( + const at::Tensor& grad_recurrent_msg, + const at::Tensor& input_k, + const at::Tensor& input_v, + const at::Tensor& recurrent_k_before, + const at::Tensor& recurrent_v_before, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredNativeStrategyRow& native_strategy, + const at::Tensor& recurrent_local_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& recurrent_neighbor_idx, + const at::Tensor& recurrent_neighbor_valid, + 
const at::Tensor& recurrent_edge_distance, + const at::Tensor& recurrent_edge_delay, + const at::Tensor& step_flat, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + double distance_scale, + bool use_sparse_messages, + bool use_delay) { + at::Tensor recurrent_q = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_recurrent_query", + true, + false, + span.executor_row_index, + span.bucket_ordinal, + "reverse", + "fused recurrent-message recurrent_q"); + if (use_sparse_messages) { + return flat_bucket_registered_backward_sparse_attention_cuda( + grad_recurrent_msg, + recurrent_q, + input_k, + input_v, + recurrent_k_before, + recurrent_v_before, + recurrent_neighbor_idx, + recurrent_neighbor_valid, + recurrent_edge_distance, + recurrent_edge_delay, + step_flat, + reverse_executor_rows, + reverse_executor_binding_rows, + span.executor_id, + span.bucket_ordinal, + distance_scale, + use_delay); + } + return flat_bucket_registered_backward_partitioned_attention_cuda( + grad_recurrent_msg, + recurrent_q, + input_k, + input_v, + recurrent_k_before, + recurrent_v_before, + recurrent_local_sender_idx, + local_distance, + local_delay, + step_flat, + reverse_executor_rows, + reverse_executor_binding_rows, + span.executor_id, + span.bucket_ordinal, + distance_scale, + use_delay); +} + +inline std::vector run_neighborhood_attention_project_boundary_kv_backward( + const at::Tensor& grad_input_k, + const at::Tensor& grad_input_v, + const at::Tensor& boundary_step, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredNativeStrategyRow& native_strategy, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t group_size, + int64_t head_dim, + int64_t value_dim, + bool return_boundary_grad) { + at::Tensor input_sender_weight = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_input_direct_kv_weight", + true, + true, + span.executor_row_index, + span.bucket_ordinal, + "reverse", + "fused boundary K/V projection input direct weight"); + at::Tensor input_group_weight = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_input_group_kv_weight", + true, + true, + span.executor_row_index, + span.bucket_ordinal, + "reverse", + "fused boundary K/V projection input grouped weight"); + std::vector backward = flat_bucket_registered_backward_sender_kv_projection_cuda( + boundary_step, + input_sender_weight, + input_group_weight, + grad_input_k, + grad_input_v, + group_size, + head_dim, + value_dim, + return_boundary_grad, + reverse_executor_rows, + reverse_executor_binding_rows, + span.executor_id, + span.bucket_ordinal); + const bool grouped = input_group_weight.defined() && input_group_weight.numel() > 0 && group_size > 1; + at::Tensor grouped_flag = at::empty({1}, 
at::TensorOptions().dtype(at::kLong).device(at::kCPU)); + grouped_flag.data_ptr()[0] = grouped ? 1 : 0; + return {backward[0], backward[1], grouped_flag}; +} + +inline std::vector run_fixed_slot_context_recurrent_kv_backward( + const at::Tensor& grad_recurrent_k, + const at::Tensor& grad_recurrent_v, + const at::Tensor& recurrent_hidden_backend_order, + const at::Tensor& backend_to_graph_inverse_order, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredNativeStrategyRow& native_strategy, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t head_dim, + int64_t value_dim, + bool return_input_grad) { + static_cast(grad_recurrent_k); + static_cast(head_dim); + at::Tensor recurrent_value_weight = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_recurrent_value_weight", + true, + false, + span.executor_row_index, + span.bucket_ordinal, + "reverse", + "fixed-slot context recurrent value weight"); + std::vector backward = flat_bucket_registered_backward_sender_value_projection_cuda( + recurrent_hidden_backend_order, + recurrent_value_weight, + recurrent_hidden_backend_order.new_empty({0}), + grad_recurrent_v, + 1, + value_dim, + return_input_grad, + reverse_executor_rows, + reverse_executor_binding_rows, + span.executor_id, + span.bucket_ordinal); + at::Tensor grad_hidden_graph_order = return_input_grad + ? backward[0].index_select(1, backend_to_graph_inverse_order).contiguous() + : recurrent_hidden_backend_order.new_empty({0}); + at::Tensor grad_weight_graph_order = backward[1].index_select(0, backend_to_graph_inverse_order).contiguous(); + return {grad_hidden_graph_order, grad_weight_graph_order}; +} + +inline std::vector run_fixed_slot_context_initial_recurrent_kv_backward( + const at::Tensor& grad_recurrent_k, + const at::Tensor& grad_recurrent_v, + const at::Tensor& recurrent_hidden_backend_order, + const at::Tensor& backend_to_graph_inverse_order, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredNativeStrategyRow& native_strategy, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t head_dim, + int64_t value_dim, + bool return_input_grad) { + return run_fixed_slot_context_recurrent_kv_backward( + grad_recurrent_k, + grad_recurrent_v, + recurrent_hidden_backend_order, + backend_to_graph_inverse_order, + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + reverse_executor_rows, + reverse_executor_binding_rows, + span, + head_dim, + value_dim, + return_input_grad); +} + +inline std::vector run_fixed_slot_context_recurrent_kv_forward_recompute( + const at::Tensor& input_k_reference, + const at::Tensor& recurrent_hidden_backend_order, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const 
RegisteredNativeStrategyRow& native_strategy, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t head_dim, + int64_t value_dim) { + static_cast(value_dim); + at::Tensor sender_slot_key = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_sender_slot_key_weight", + true, + false, + span.executor_row_index, + span.bucket_ordinal, + "reverse", + "fixed-slot context recurrent K/V forward recompute sender slot key"); + at::Tensor sender_context_key = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_sender_context_key", + true, + false, + span.executor_row_index, + span.bucket_ordinal, + "reverse", + "fixed-slot context recurrent K/V forward recompute sender context key"); + at::Tensor recurrent_value_weight = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_recurrent_value_weight", + true, + false, + span.executor_row_index, + span.bucket_ordinal, + "reverse", + "fixed-slot context recurrent K/V forward recompute recurrent value weight"); + const int64_t input_sender_count = input_k_reference.size(1); + at::Tensor recurrent_k = run_registered_forward_fixed_slot_context_key_bank( + sender_slot_key.slice( + 0, + input_sender_count, + input_sender_count + recurrent_hidden_backend_order.size(1)), + sender_context_key.slice( + 0, + input_sender_count, + input_sender_count + recurrent_hidden_backend_order.size(1)), + recurrent_hidden_backend_order.size(0), + 1, + recurrent_hidden_backend_order.options(), + "registered fixed-slot context recurrent key recompute").select(0, 0).contiguous(); + at::Tensor recurrent_v = run_registered_forward_sender_value_step( + recurrent_hidden_backend_order, + recurrent_value_weight, + recurrent_hidden_backend_order.new_empty({0}), + 1, + forward_executor_rows, + forward_executor_binding_rows, + span.executor_id, + span.bucket_ordinal); + return {recurrent_k, recurrent_v}; +} + +inline std::vector run_fixed_slot_context_recurrent_message_backward( + const at::Tensor& grad_recurrent_msg, + const at::Tensor& input_k, + const at::Tensor& input_v, + const at::Tensor& recurrent_k_before, + const at::Tensor& recurrent_v_before, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredNativeStrategyRow& native_strategy, + const at::Tensor& recurrent_local_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& recurrent_neighbor_idx, + const at::Tensor& recurrent_neighbor_valid, + const at::Tensor& recurrent_edge_distance, + const at::Tensor& recurrent_edge_delay, + const at::Tensor& step_flat, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + double distance_scale, + bool use_sparse_messages, + bool use_delay) { + TORCH_CHECK(!use_sparse_messages, "fixed-slot context reverse 
message requires compiler partitioned sender rows"); + static_cast(recurrent_neighbor_idx); + static_cast(recurrent_neighbor_valid); + static_cast(recurrent_edge_distance); + static_cast(recurrent_edge_delay); + at::Tensor query_slot = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_query_slot_weight", + true, + false, + span.executor_row_index, + span.bucket_ordinal, + "reverse", + "fixed-slot context query slot weight"); + at::Tensor query_context_scalar = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_query_context_scalar", + true, + false, + span.executor_row_index, + span.bucket_ordinal, + "reverse", + "fixed-slot context query scalar"); + at::Tensor output_weight = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_output_weight", + true, + false, + span.executor_row_index, + span.bucket_ordinal, + "reverse", + "fixed-slot context output weight"); + std::vector backward = flat_bucket_registered_backward_fixed_slot_context_message_cuda( + grad_recurrent_msg, + query_slot, + query_context_scalar, + input_k, + input_v, + recurrent_k_before, + recurrent_v_before, + output_weight, + recurrent_local_sender_idx, + local_distance, + local_delay, + step_flat, + reverse_executor_rows, + reverse_executor_binding_rows, + span.executor_id, + span.bucket_ordinal, + distance_scale, + use_delay); + TORCH_CHECK( + backward.size() == 7, + "fixed-slot context reverse message must return the declared compiler output group"); + backward[1] = backward[1].sum(0).contiguous(); + backward[3] = backward[3].sum(0).contiguous(); + return backward; +} + +inline std::vector run_fixed_slot_context_boundary_kv_backward( + const at::Tensor& grad_input_k, + const at::Tensor& grad_input_v, + const at::Tensor& boundary_step, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredNativeStrategyRow& native_strategy, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t group_size, + int64_t head_dim, + int64_t value_dim, + bool return_boundary_grad) { + static_cast(grad_input_k); + at::Tensor input_value_weight = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_input_value_weight", + true, + true, + span.executor_row_index, + span.bucket_ordinal, + "reverse", + "fixed-slot context input value weight"); + at::Tensor input_group_value_weight = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "message_input_group_value_weight", + true, + true, + span.executor_row_index, + span.bucket_ordinal, + "reverse", + 
"fixed-slot context grouped input value weight"); + const bool grouped = input_group_value_weight.defined() && input_group_value_weight.numel() > 0 && group_size > 1; + static_cast(head_dim); + std::vector backward = flat_bucket_registered_backward_sender_value_projection_cuda( + boundary_step, + input_value_weight, + input_group_value_weight, + grad_input_v, + group_size, + value_dim, + return_boundary_grad, + reverse_executor_rows, + reverse_executor_binding_rows, + span.executor_id, + span.bucket_ordinal); + at::Tensor grouped_flag = at::empty({1}, at::TensorOptions().dtype(at::kLong).device(at::kCPU)); + grouped_flag.data_ptr()[0] = grouped ? 1 : 0; + return {backward[0], backward[1], grouped_flag}; +} diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/native_callables/readout_forward_strategies.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/native_callables/readout_forward_strategies.cuh new file mode 100644 index 00000000..d102f379 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/native_callables/readout_forward_strategies.cuh @@ -0,0 +1,237 @@ +#pragma once + +inline at::Tensor run_registered_forward_readout_projection_into( + const at::Tensor& output_msg, + const at::Tensor& value_to_output_weight, + const at::Tensor& output_cell_bias, + const at::Tensor& output_cells, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal) { + check_cuda_float_bank(output_msg, "registered forward readout output_msg"); + TORCH_CHECK(output_cells.is_cuda(), "registered forward readout output_cells must be a CUDA tensor"); + TORCH_CHECK(output_cells.scalar_type() == at::kFloat, "registered forward readout output_cells must be float32"); + TORCH_CHECK(output_cells.dim() == 3, "registered forward readout output_cells must be rank-3"); + TORCH_CHECK( + value_to_output_weight.is_cuda() && value_to_output_weight.is_contiguous() && + value_to_output_weight.scalar_type() == at::kFloat && value_to_output_weight.dim() == 3, + "registered forward readout value_to_output_weight must be float32 [O,V,H]"); + check_cuda_float_rank2(output_cell_bias, "registered forward readout output_cell_bias"); + const int B = static_cast(output_msg.size(0)); + const int output_count = static_cast(output_msg.size(1)); + const int value_dim = static_cast(output_msg.size(2)); + const int hidden_dim = static_cast(value_to_output_weight.size(2)); + validate_registered_readout_executor_rows( + forward_executor_rows, + forward_executor_binding_rows, + executor_id, + bucket_ordinal, + output_count); + TORCH_CHECK( + registered_tensor_matches_shape(output_cells, {B, output_count, hidden_dim}), + "registered forward readout output_cells compiler runtime buffer shape mismatch"); + TORCH_CHECK( + value_to_output_weight.size(0) == output_count && value_to_output_weight.size(1) == value_dim, + "registered forward readout value_to_output_weight shape mismatch"); + TORCH_CHECK( + output_cell_bias.size(0) == output_count && output_cell_bias.size(1) == hidden_dim, + "registered forward readout output_cell_bias shape mismatch"); + const int64_t total_elements = static_cast(B) * output_count * hidden_dim; + if (total_elements == 0) { + return output_cells; + } + const int blocks = static_cast(std::min( + 4096, + (total_elements + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + 
if (output_cells.is_contiguous()) { + registered_forward_readout_projection_kernel<<>>( + output_msg.data_ptr(), + value_to_output_weight.data_ptr(), + output_cell_bias.data_ptr(), + output_cells.data_ptr(), + total_elements, + output_count, + value_dim, + hidden_dim); + check_launch("registered_forward_readout_projection_kernel"); + } else { + registered_forward_readout_projection_strided_kernel<<>>( + output_msg.data_ptr(), + value_to_output_weight.data_ptr(), + output_cell_bias.data_ptr(), + output_cells.data_ptr(), + total_elements, + output_count, + value_dim, + hidden_dim, + output_cells.stride(0), + output_cells.stride(1), + output_cells.stride(2)); + check_launch("registered_forward_readout_projection_strided_kernel"); + } + return output_cells; +} + +inline RegisteredForwardReadoutExecutorState bind_projection_reduction_boundary_readout_handler( + const RegisteredFusedProgramSpan& span, + const RegisteredNativeStrategyRow& native_strategy, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& forward_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows) { + const at::Tensor output_q = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + forward_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "readout_output_query", + true, + false, + span.executor_row_index, + span.bucket_ordinal, + "forward", + "registered fused forward output query"); + const at::Tensor value_to_output_weight = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + forward_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "readout_value_to_output_weight", + true, + false, + span.executor_row_index, + span.bucket_ordinal, + "forward", + "registered fused forward value_to_output_weight"); + const at::Tensor output_cell_bias = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + forward_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "readout_output_cell_bias", + true, + false, + span.executor_row_index, + span.bucket_ordinal, + "forward", + "registered fused forward output_cell_bias"); + return RegisteredForwardReadoutExecutorState{ + span, + native_strategy, + { + {"readout_output_query", output_q}, + {"readout_value_to_output_weight", value_to_output_weight}, + {"readout_output_cell_bias", output_cell_bias}, + }, + }; +} + +inline at::Tensor run_projection_reduction_boundary_readout_message( + const RegisteredForwardReadoutExecutorState& readout_executor, + const at::Tensor& input_k_step, + const at::Tensor& input_v_step, + const at::Tensor& recurrent_k, + const at::Tensor& recurrent_v, + const at::Tensor& receiver_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& step_flat, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t runtime_buffer_logical_index, + double distance_scale, + bool use_delay) { + const at::Tensor output_q = registered_forward_readout_tensor(readout_executor, "readout_output_query"); + at::Tensor output_msg = registered_runtime_buffer_for_role( + runtime_buffer_tensors, + runtime_buffer_rows, + 
kRuntimeBufferRoleForwardOutputMsg, + runtime_buffer_logical_index, + { + input_k_step.size(0), + output_q.size(0), + input_v_step.size(2), + }, + "registered fused forward output message"); + return run_registered_forward_partitioned_attention_into( + output_q, + input_k_step, + input_v_step, + recurrent_k, + recurrent_v, + receiver_sender_idx, + local_distance, + local_delay, + step_flat, + output_msg, + forward_executor_rows, + forward_executor_binding_rows, + readout_executor.span.executor_id, + readout_executor.span.bucket_ordinal, + distance_scale, + use_delay, + "registered fused forward output message"); +} + +inline at::Tensor run_projection_reduction_boundary_readout_projection( + const RegisteredForwardReadoutExecutorState& readout_executor, + const at::Tensor& output_msg, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t runtime_buffer_logical_index) { + const at::Tensor value_to_output_weight = + registered_forward_readout_tensor(readout_executor, "readout_value_to_output_weight"); + const at::Tensor output_cell_bias = registered_forward_readout_tensor(readout_executor, "readout_output_cell_bias"); + at::Tensor output_cells = registered_runtime_buffer_for_role( + runtime_buffer_tensors, + runtime_buffer_rows, + kRuntimeBufferRoleForwardOutputCells, + runtime_buffer_logical_index, + { + output_msg.size(0), + value_to_output_weight.size(0), + value_to_output_weight.size(2), + }, + "registered fused forward output cells"); + return run_registered_forward_readout_projection_into( + output_msg, + value_to_output_weight, + output_cell_bias, + output_cells, + forward_executor_rows, + forward_executor_binding_rows, + readout_executor.span.executor_id, + readout_executor.span.bucket_ordinal); +} + +inline at::Tensor run_projection_reduction_boundary_readout_projection_into( + const RegisteredForwardReadoutExecutorState& readout_executor, + const at::Tensor& output_msg, + const at::Tensor& output_cells, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows) { + const at::Tensor value_to_output_weight = + registered_forward_readout_tensor(readout_executor, "readout_value_to_output_weight"); + const at::Tensor output_cell_bias = registered_forward_readout_tensor(readout_executor, "readout_output_cell_bias"); + return run_registered_forward_readout_projection_into( + output_msg, + value_to_output_weight, + output_cell_bias, + output_cells, + forward_executor_rows, + forward_executor_binding_rows, + readout_executor.span.executor_id, + readout_executor.span.bucket_ordinal); +} diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/native_callables/readout_reverse_strategies.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/native_callables/readout_reverse_strategies.cuh new file mode 100644 index 00000000..3f52fd88 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/native_callables/readout_reverse_strategies.cuh @@ -0,0 +1,121 @@ +#pragma once + +inline std::vector run_projection_reduction_boundary_readout_backward( + const at::Tensor& grad_cells_out, + const at::Tensor& output_msg, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const 
RegisteredNativeStrategyRow& native_strategy, + const at::Tensor& graph_to_backend_order, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t input_count, + int64_t recurrent_count) { + at::Tensor value_to_output_weight = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "readout_value_to_output_weight", + true, + false, + span.executor_row_index, + span.bucket_ordinal, + "reverse", + "fused backward readout value_to_output_weight"); + return flat_bucket_registered_backward_readout_layout_projection_cuda( + grad_cells_out, + output_msg, + value_to_output_weight, + graph_to_backend_order, + reverse_executor_rows, + reverse_executor_binding_rows, + span.executor_id, + span.bucket_ordinal, + input_count, + recurrent_count); +} + +inline std::vector run_projection_reduction_boundary_output_message_backward( + const at::Tensor& grad_output_msg, + const std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_program_access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredNativeStrategyRow& native_strategy, + const at::Tensor& input_k, + const at::Tensor& input_v, + const at::Tensor& recurrent_k, + const at::Tensor& recurrent_v, + const at::Tensor& output_local_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& local_delay, + const at::Tensor& output_neighbor_idx, + const at::Tensor& output_neighbor_valid, + const at::Tensor& output_edge_distance, + const at::Tensor& output_edge_delay, + const at::Tensor& step_flat, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + double distance_scale, + bool use_sparse_messages, + bool use_delay) { + at::Tensor output_q = program_tensor_for_native_strategy_access( + program_tensors, + program_tensor_binding_rows, + reverse_program_access_rows, + native_callable_binding_schema_rows, + native_strategy, + kNativeCallableBindingParameter, + "readout_output_query", + true, + false, + span.executor_row_index, + span.bucket_ordinal, + "reverse", + "fused backward output-message output_q"); + if (use_sparse_messages) { + return flat_bucket_registered_backward_sparse_attention_cuda( + grad_output_msg, + output_q, + input_k, + input_v, + recurrent_k, + recurrent_v, + output_neighbor_idx, + output_neighbor_valid, + output_edge_distance, + output_edge_delay, + step_flat, + reverse_executor_rows, + reverse_executor_binding_rows, + span.executor_id, + span.bucket_ordinal, + distance_scale, + use_delay); + } + return flat_bucket_registered_backward_partitioned_attention_cuda( + grad_output_msg, + output_q, + input_k, + input_v, + recurrent_k, + recurrent_v, + output_local_sender_idx, + local_distance, + local_delay, + step_flat, + reverse_executor_rows, + reverse_executor_binding_rows, + span.executor_id, + span.bucket_ordinal, + distance_scale, + use_delay); +} + diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/operator_declarations.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/operator_declarations.cuh new file mode 100644 index 00000000..b4865fe5 --- /dev/null +++ 
b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/operator_declarations.cuh
@@ -0,0 +1,148 @@
+#pragma once
+
+std::vector<at::Tensor> flat_bucket_registered_forward_sender_kv_sequence_cuda(
+    const at::Tensor& sender_cells_seq,
+    const at::Tensor& direct_weight,
+    const at::Tensor& grouped_weight,
+    int64_t group_size,
+    int64_t head_dim,
+    int64_t value_dim,
+    const at::Tensor& forward_executor_rows,
+    const at::Tensor& forward_executor_binding_rows,
+    int64_t executor_id,
+    int64_t bucket_ordinal);
+
+std::vector<at::Tensor> flat_bucket_registered_forward_sender_kv_step_cuda(
+    const at::Tensor& sender_cells,
+    const at::Tensor& direct_weight,
+    const at::Tensor& grouped_weight,
+    int64_t group_size,
+    int64_t head_dim,
+    int64_t value_dim,
+    const at::Tensor& forward_executor_rows,
+    const at::Tensor& forward_executor_binding_rows,
+    int64_t executor_id,
+    int64_t bucket_ordinal);
+
+std::vector<at::Tensor> flat_bucket_registered_backward_sender_kv_projection_cuda(
+    const at::Tensor& sender_cells,
+    const at::Tensor& direct_weight,
+    const at::Tensor& grouped_weight,
+    const at::Tensor& grad_k,
+    const at::Tensor& grad_v,
+    int64_t group_size,
+    int64_t head_dim,
+    int64_t value_dim,
+    bool return_input_grad,
+    const at::Tensor& reverse_executor_rows,
+    const at::Tensor& reverse_executor_binding_rows,
+    int64_t executor_id,
+    int64_t bucket_ordinal);
+
+std::vector<at::Tensor> flat_bucket_registered_backward_sender_value_projection_cuda(
+    const at::Tensor& sender_cells,
+    const at::Tensor& direct_weight,
+    const at::Tensor& grouped_weight,
+    const at::Tensor& grad_v,
+    int64_t group_size,
+    int64_t value_dim,
+    bool return_input_grad,
+    const at::Tensor& reverse_executor_rows,
+    const at::Tensor& reverse_executor_binding_rows,
+    int64_t executor_id,
+    int64_t bucket_ordinal);
+
+std::vector<at::Tensor> flat_bucket_registered_forward_partitioned_attention_cuda(
+    const at::Tensor& q,
+    const at::Tensor& input_k,
+    const at::Tensor& input_v,
+    const at::Tensor& recurrent_k,
+    const at::Tensor& recurrent_v,
+    const at::Tensor& receiver_sender_idx,
+    const at::Tensor& offset_distance,
+    const at::Tensor& offset_delay,
+    const at::Tensor& step_flat,
+    const at::Tensor& forward_executor_rows,
+    const at::Tensor& forward_executor_binding_rows,
+    int64_t executor_id,
+    int64_t bucket_ordinal,
+    double distance_scale,
+    bool use_delay);
+
+std::vector<at::Tensor> flat_bucket_registered_forward_readout_projection_cuda(
+    const at::Tensor& output_msg,
+    const at::Tensor& value_to_output_weight,
+    const at::Tensor& output_cell_bias,
+    const at::Tensor& forward_executor_rows,
+    const at::Tensor& forward_executor_binding_rows,
+    int64_t executor_id,
+    int64_t bucket_ordinal);
+
+std::vector<at::Tensor> flat_bucket_registered_backward_partitioned_attention_cuda(
+    const at::Tensor& grad_msg,
+    const at::Tensor& q,
+    const at::Tensor& input_k,
+    const at::Tensor& input_v,
+    const at::Tensor& recurrent_k,
+    const at::Tensor& recurrent_v,
+    const at::Tensor& receiver_sender_idx,
+    const at::Tensor& offset_distance,
+    const at::Tensor& offset_delay,
+    const at::Tensor& step_flat,
+    const at::Tensor& reverse_executor_rows,
+    const at::Tensor& reverse_executor_binding_rows,
+    int64_t executor_id,
+    int64_t bucket_ordinal,
+    double distance_scale,
+    bool use_delay);
+
+std::vector<at::Tensor> flat_bucket_registered_backward_sparse_attention_cuda(
+    const at::Tensor& grad_msg,
+    const at::Tensor& q,
+    const at::Tensor& input_k,
+    const at::Tensor& input_v,
+    const at::Tensor& recurrent_k,
+    const at::Tensor& recurrent_v,
+    const at::Tensor& neighbor_idx,
+    const
at::Tensor& neighbor_valid, + const at::Tensor& edge_distance, + const at::Tensor& edge_delay, + const at::Tensor& step_flat, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal, + double distance_scale, + bool use_delay); + +std::vector flat_bucket_registered_backward_fixed_slot_context_message_cuda( + const at::Tensor& grad_msg, + const at::Tensor& query_slot, + const at::Tensor& query_context_scalar, + const at::Tensor& input_k, + const at::Tensor& input_v, + const at::Tensor& recurrent_k, + const at::Tensor& recurrent_v, + const at::Tensor& output_weight, + const at::Tensor& receiver_sender_idx, + const at::Tensor& offset_distance, + const at::Tensor& offset_delay, + const at::Tensor& step_flat, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal, + double distance_scale, + bool use_delay); + +std::vector flat_bucket_registered_backward_readout_layout_projection_cuda( + const at::Tensor& grad_cells_out, + const at::Tensor& output_msg, + const at::Tensor& value_to_output_weight, + const at::Tensor& graph_to_backend_order, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal, + int64_t input_count, + int64_t recurrent_count); diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/operator_exports.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/operator_exports.cuh new file mode 100644 index 00000000..dfb6c654 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/operator_exports.cuh @@ -0,0 +1,1223 @@ +#pragma once + +std::vector flat_bucket_registered_forward_partitioned_attention_cuda( + const at::Tensor& q, + const at::Tensor& input_k, + const at::Tensor& input_v, + const at::Tensor& recurrent_k, + const at::Tensor& recurrent_v, + const at::Tensor& receiver_sender_idx, + const at::Tensor& offset_distance, + const at::Tensor& offset_delay, + const at::Tensor& step_flat, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal, + double distance_scale, + bool use_delay) { + check_cuda_float_rank2(q, "q"); + check_cuda_float_bank(input_k, "input_k"); + check_cuda_float_bank(input_v, "input_v"); + check_cuda_float_bank(recurrent_k, "recurrent_k"); + check_cuda_float_bank(recurrent_v, "recurrent_v"); + check_cuda_int_rank2(receiver_sender_idx, "receiver_sender_idx"); + check_cuda_float_rank1(offset_distance, "offset_distance"); + check_cuda_int_rank1(offset_delay, "offset_delay"); + check_cuda_long_rank1(step_flat, "step_flat"); + + const int B = static_cast(input_k.size(0)); + const int input_senders = static_cast(input_k.size(1)); + const int recurrent_senders = static_cast(recurrent_k.size(1)); + const int receiver_count = static_cast(q.size(0)); + const int degree = static_cast(receiver_sender_idx.size(1)); + const int head_dim = static_cast(q.size(1)); + const int key_dim = static_cast(input_k.size(2)); + const int value_dim = static_cast(input_v.size(2)); + validate_registered_partitioned_attention_executor_rows( + forward_executor_rows, + forward_executor_binding_rows, + kForwardDirectionOpcode, + executor_id, + bucket_ordinal, + receiver_count); + TORCH_CHECK( + degree <= kMaxRegisteredAttentionOffsets, + 
"registered partitioned attention supports at most ", + kMaxRegisteredAttentionOffsets, + " offsets"); + TORCH_CHECK(input_v.size(0) == B && recurrent_k.size(0) == B && recurrent_v.size(0) == B, "bank batch mismatch"); + TORCH_CHECK(recurrent_k.size(2) == key_dim && key_dim >= head_dim, "K/head dimension mismatch"); + TORCH_CHECK(input_v.size(2) == recurrent_v.size(2), "V dimension mismatch"); + TORCH_CHECK(receiver_sender_idx.size(0) == receiver_count, "receiver_sender_idx receiver count mismatch"); + TORCH_CHECK(offset_distance.size(0) == degree && offset_delay.size(0) == degree, "offset metadata degree mismatch"); + TORCH_CHECK(step_flat.size(0) == B, "step_flat length must match batch/time dimension"); + + auto out = at::empty({B, receiver_count, value_dim}, input_v.options()); + const int64_t total_elements = static_cast(B) * receiver_count * value_dim; + if (total_elements == 0) { + return {out}; + } + const int blocks = static_cast(std::min( + 4096, + (total_elements + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + registered_forward_partitioned_attention_kernel<<>>( + q.data_ptr(), + input_k.data_ptr(), + input_v.data_ptr(), + recurrent_k.data_ptr(), + recurrent_v.data_ptr(), + receiver_sender_idx.data_ptr(), + offset_distance.data_ptr(), + offset_delay.data_ptr(), + step_flat.data_ptr(), + out.data_ptr(), + total_elements, + B, + receiver_count, + input_senders, + recurrent_senders, + degree, + head_dim, + key_dim, + value_dim, + 1.0f / std::sqrt(static_cast(head_dim > 0 ? head_dim : 1)), + static_cast(distance_scale), + use_delay); + check_launch("registered_forward_partitioned_attention_kernel"); + return {out}; +} + +std::vector flat_bucket_registered_backward_partitioned_attention_cuda( + const at::Tensor& grad_msg, + const at::Tensor& q, + const at::Tensor& input_k, + const at::Tensor& input_v, + const at::Tensor& recurrent_k, + const at::Tensor& recurrent_v, + const at::Tensor& receiver_sender_idx, + const at::Tensor& offset_distance, + const at::Tensor& offset_delay, + const at::Tensor& step_flat, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal, + double distance_scale, + bool use_delay) { + check_cuda_float_bank(grad_msg, "grad_msg"); + check_cuda_float_rank2(q, "q"); + check_cuda_float_bank(input_k, "input_k"); + check_cuda_float_bank(input_v, "input_v"); + check_cuda_float_bank(recurrent_k, "recurrent_k"); + check_cuda_float_bank(recurrent_v, "recurrent_v"); + check_cuda_int_rank2(receiver_sender_idx, "receiver_sender_idx"); + check_cuda_float_rank1(offset_distance, "offset_distance"); + check_cuda_int_rank1(offset_delay, "offset_delay"); + check_cuda_long_rank1(step_flat, "step_flat"); + + const int B = static_cast(input_k.size(0)); + const int input_senders = static_cast(input_k.size(1)); + const int recurrent_senders = static_cast(recurrent_k.size(1)); + const int receiver_count = static_cast(q.size(0)); + const int degree = static_cast(receiver_sender_idx.size(1)); + const int head_dim = static_cast(q.size(1)); + const int key_dim = static_cast(input_k.size(2)); + const int value_dim = static_cast(input_v.size(2)); + validate_registered_partitioned_attention_executor_rows( + reverse_executor_rows, + reverse_executor_binding_rows, + kReverseDirectionOpcode, + executor_id, + bucket_ordinal, + receiver_count); + TORCH_CHECK( + degree <= kMaxRegisteredAttentionOffsets, + "registered partitioned attention backward supports at 
most ", + kMaxRegisteredAttentionOffsets, + " offsets"); + TORCH_CHECK(input_v.size(0) == B && recurrent_k.size(0) == B && recurrent_v.size(0) == B, "bank batch mismatch"); + TORCH_CHECK(recurrent_k.size(2) == key_dim && key_dim >= head_dim, "K/head dimension mismatch"); + TORCH_CHECK(input_v.size(2) == recurrent_v.size(2), "V dimension mismatch"); + TORCH_CHECK( + grad_msg.size(0) == B && grad_msg.size(1) == receiver_count && grad_msg.size(2) == value_dim, + "grad_msg shape mismatch"); + TORCH_CHECK(receiver_sender_idx.size(0) == receiver_count, "receiver_sender_idx receiver count mismatch"); + TORCH_CHECK(offset_distance.size(0) == degree && offset_delay.size(0) == degree, "offset metadata degree mismatch"); + TORCH_CHECK(step_flat.size(0) == B, "step_flat length must match batch/time dimension"); + + auto grad_q = at::zeros_like(q); + auto grad_input_k = at::zeros_like(input_k); + auto grad_input_v = at::zeros_like(input_v); + auto grad_recurrent_k = at::zeros_like(recurrent_k); + auto grad_recurrent_v = at::zeros_like(recurrent_v); + const int64_t receiver_total = static_cast(B) * receiver_count; + if (receiver_total == 0) { + return {grad_q, grad_input_k, grad_input_v, grad_recurrent_k, grad_recurrent_v}; + } + const int blocks = static_cast(std::min( + 4096, + (receiver_total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + registered_backward_partitioned_attention_kernel<<>>( + grad_msg.data_ptr(), + q.data_ptr(), + input_k.data_ptr(), + input_v.data_ptr(), + recurrent_k.data_ptr(), + recurrent_v.data_ptr(), + receiver_sender_idx.data_ptr(), + offset_distance.data_ptr(), + offset_delay.data_ptr(), + step_flat.data_ptr(), + grad_q.data_ptr(), + grad_input_k.data_ptr(), + grad_input_v.data_ptr(), + grad_recurrent_k.data_ptr(), + grad_recurrent_v.data_ptr(), + receiver_total, + receiver_count, + input_senders, + recurrent_senders, + degree, + head_dim, + key_dim, + value_dim, + 1.0f / std::sqrt(static_cast(head_dim > 0 ? 
head_dim : 1)), + static_cast(distance_scale), + use_delay); + check_launch("registered_backward_partitioned_attention_kernel"); + return {grad_q, grad_input_k, grad_input_v, grad_recurrent_k, grad_recurrent_v}; +} + +std::vector flat_bucket_registered_forward_sparse_attention_cuda( + const at::Tensor& q, + const at::Tensor& input_k, + const at::Tensor& input_v, + const at::Tensor& recurrent_k, + const at::Tensor& recurrent_v, + const at::Tensor& neighbor_idx, + const at::Tensor& neighbor_valid, + const at::Tensor& edge_distance, + const at::Tensor& edge_delay, + const at::Tensor& step_flat, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal, + double distance_scale, + bool use_delay) { + check_cuda_float_rank2(q, "q"); + check_cuda_float_bank(input_k, "input_k"); + check_cuda_float_bank(input_v, "input_v"); + check_cuda_float_bank(recurrent_k, "recurrent_k"); + check_cuda_float_bank(recurrent_v, "recurrent_v"); + check_cuda_long_rank2(neighbor_idx, "neighbor_idx"); + check_cuda_bool_rank2(neighbor_valid, "neighbor_valid"); + check_cuda_long_rank2(edge_delay, "edge_delay"); + TORCH_CHECK(edge_distance.is_cuda() && edge_distance.is_contiguous(), "edge_distance must be a CUDA tensor"); + TORCH_CHECK(edge_distance.scalar_type() == at::kFloat && edge_distance.dim() == 2, "edge_distance must be [R,M]"); + check_cuda_long_rank1(step_flat, "step_flat"); + + const int B = static_cast(input_k.size(0)); + const int input_senders = static_cast(input_k.size(1)); + const int recurrent_senders = static_cast(recurrent_k.size(1)); + const int receiver_count = static_cast(q.size(0)); + const int degree = static_cast(neighbor_idx.size(1)); + const int head_dim = static_cast(q.size(1)); + const int key_dim = static_cast(input_k.size(2)); + const int value_dim = static_cast(input_v.size(2)); + validate_registered_partitioned_attention_executor_rows( + forward_executor_rows, + forward_executor_binding_rows, + kForwardDirectionOpcode, + executor_id, + bucket_ordinal, + receiver_count); + TORCH_CHECK(input_v.size(0) == B && recurrent_k.size(0) == B && recurrent_v.size(0) == B, "bank batch mismatch"); + TORCH_CHECK(recurrent_k.size(2) == key_dim && key_dim >= head_dim, "K/head dimension mismatch"); + TORCH_CHECK(input_v.size(2) == recurrent_v.size(2), "V dimension mismatch"); + TORCH_CHECK(neighbor_idx.size(0) == receiver_count, "neighbor_idx receiver count mismatch"); + TORCH_CHECK(neighbor_valid.sizes() == neighbor_idx.sizes(), "neighbor_valid shape mismatch"); + TORCH_CHECK(edge_distance.sizes() == neighbor_idx.sizes(), "edge_distance shape mismatch"); + TORCH_CHECK(edge_delay.sizes() == neighbor_idx.sizes(), "edge_delay shape mismatch"); + TORCH_CHECK(step_flat.size(0) == B, "step_flat length must match batch/time dimension"); + + auto out = at::empty({B, receiver_count, value_dim}, input_v.options()); + const int64_t total_elements = static_cast(B) * receiver_count * value_dim; + if (total_elements == 0) { + return {out}; + } + const int blocks = static_cast(std::min( + 4096, + (total_elements + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + registered_forward_sparse_attention_kernel<<>>( + q.data_ptr(), + input_k.data_ptr(), + input_v.data_ptr(), + recurrent_k.data_ptr(), + recurrent_v.data_ptr(), + neighbor_idx.data_ptr(), + neighbor_valid.data_ptr(), + edge_distance.data_ptr(), + edge_delay.data_ptr(), + step_flat.data_ptr(), + out.data_ptr(), + total_elements, 
+ receiver_count, + input_senders, + recurrent_senders, + degree, + head_dim, + key_dim, + value_dim, + 1.0f / std::sqrt(static_cast(head_dim > 0 ? head_dim : 1)), + static_cast(distance_scale), + use_delay); + check_launch("registered_forward_sparse_attention_kernel"); + return {out}; +} + +std::vector flat_bucket_registered_forward_readout_projection_cuda( + const at::Tensor& output_msg, + const at::Tensor& value_to_output_weight, + const at::Tensor& output_cell_bias, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal) { + check_cuda_float_bank(output_msg, "output_msg"); + TORCH_CHECK( + value_to_output_weight.is_cuda() && value_to_output_weight.is_contiguous() && + value_to_output_weight.scalar_type() == at::kFloat && value_to_output_weight.dim() == 3, + "value_to_output_weight must be float32 [O,V,H]"); + check_cuda_float_rank2(output_cell_bias, "output_cell_bias"); + const int B = static_cast(output_msg.size(0)); + const int output_count = static_cast(output_msg.size(1)); + const int value_dim = static_cast(output_msg.size(2)); + const int hidden_dim = static_cast(value_to_output_weight.size(2)); + validate_registered_readout_executor_rows( + forward_executor_rows, + forward_executor_binding_rows, + executor_id, + bucket_ordinal, + output_count); + TORCH_CHECK( + value_to_output_weight.size(0) == output_count && value_to_output_weight.size(1) == value_dim, + "value_to_output_weight shape mismatch"); + TORCH_CHECK( + output_cell_bias.size(0) == output_count && output_cell_bias.size(1) == hidden_dim, + "output_cell_bias shape mismatch"); + + auto output_cells = at::empty({B, output_count, hidden_dim}, output_msg.options()); + const int64_t total_elements = static_cast(B) * output_count * hidden_dim; + if (total_elements == 0) { + return {output_cells}; + } + const int blocks = static_cast(std::min( + 4096, + (total_elements + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + registered_forward_readout_projection_kernel<<>>( + output_msg.data_ptr(), + value_to_output_weight.data_ptr(), + output_cell_bias.data_ptr(), + output_cells.data_ptr(), + total_elements, + output_count, + value_dim, + hidden_dim); + check_launch("registered_forward_readout_projection_kernel"); + return {output_cells}; +} + +std::vector flat_bucket_registered_forward_cells_layout_cuda( + const at::Tensor& boundary, + const at::Tensor& recurrent_hidden_backend_order, + const at::Tensor& output_cells, + const at::Tensor& backend_to_graph_inverse_order, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal) { + check_cuda_float_bank(boundary, "boundary"); + check_cuda_float_bank(recurrent_hidden_backend_order, "recurrent_hidden_backend_order"); + check_cuda_float_bank(output_cells, "output_cells"); + TORCH_CHECK( + backend_to_graph_inverse_order.is_cuda() && backend_to_graph_inverse_order.is_contiguous() && + backend_to_graph_inverse_order.dim() == 1, + "backend_to_graph_inverse_order must be a contiguous CUDA rank-1 tensor"); + TORCH_CHECK( + backend_to_graph_inverse_order.scalar_type() == at::kLong || + backend_to_graph_inverse_order.scalar_type() == at::kInt, + "backend_to_graph_inverse_order must be int64 or int32"); + const int B = static_cast(boundary.size(0)); + const int input_count = static_cast(boundary.size(1)); + const int recurrent_count = 
static_cast(recurrent_hidden_backend_order.size(1)); + const int output_count = static_cast(output_cells.size(1)); + const int hidden_dim = static_cast(boundary.size(2)); + validate_registered_readout_executor_rows( + forward_executor_rows, + forward_executor_binding_rows, + executor_id, + bucket_ordinal, + output_count); + TORCH_CHECK(recurrent_hidden_backend_order.size(0) == B && output_cells.size(0) == B, "layout batch mismatch"); + TORCH_CHECK( + recurrent_hidden_backend_order.size(2) == hidden_dim && output_cells.size(2) == hidden_dim, + "layout hidden dimension mismatch"); + TORCH_CHECK(backend_to_graph_inverse_order.size(0) == recurrent_count, "layout recurrent order size mismatch"); + auto recurrent_hidden_graph_order = at::empty_like(recurrent_hidden_backend_order); + auto cells_out = at::empty({B, input_count + recurrent_count + output_count, hidden_dim}, boundary.options()); + if (backend_to_graph_inverse_order.scalar_type() == at::kLong) { + launch_registered_forward_cells_layout( + boundary, + recurrent_hidden_backend_order, + output_cells, + backend_to_graph_inverse_order, + recurrent_hidden_graph_order, + cells_out); + } else { + launch_registered_forward_cells_layout( + boundary, + recurrent_hidden_backend_order, + output_cells, + backend_to_graph_inverse_order, + recurrent_hidden_graph_order, + cells_out); + } + return {recurrent_hidden_graph_order, cells_out}; +} + +std::vector flat_bucket_registered_forward_sender_kv_sequence_cuda( + const at::Tensor& sender_cells_seq, + const at::Tensor& direct_weight, + const at::Tensor& grouped_weight, + int64_t group_size, + int64_t head_dim, + int64_t value_dim, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal) { + TORCH_CHECK(sender_cells_seq.is_cuda() && sender_cells_seq.is_contiguous(), "sender_cells_seq must be CUDA contiguous"); + TORCH_CHECK(sender_cells_seq.scalar_type() == at::kFloat && sender_cells_seq.dim() == 4, "sender_cells_seq must be float32 [B,T,N,H]"); + const int B = static_cast(sender_cells_seq.size(0)); + const int T = static_cast(sender_cells_seq.size(1)); + const int sender_count = static_cast(sender_cells_seq.size(2)); + const int hidden_dim = static_cast(sender_cells_seq.size(3)); + const int head = static_cast(head_dim); + const int value = static_cast(value_dim); + const int kv_dim = head + value; + const bool has_grouped = grouped_weight.numel() > 0; + const bool has_direct = direct_weight.numel() > 0; + TORCH_CHECK(head > 0 && value > 0, "K/V dimensions must be positive"); + TORCH_CHECK(has_direct || has_grouped, "registered sender K/V projection requires direct or grouped weight"); + if (has_grouped) { + TORCH_CHECK(group_size > 0, "group_size must be positive for grouped sender K/V projection"); + TORCH_CHECK(sender_count % static_cast(group_size) == 0, "group_size must divide sender count"); + TORCH_CHECK( + grouped_weight.is_cuda() && grouped_weight.is_contiguous() && grouped_weight.scalar_type() == at::kFloat && + grouped_weight.dim() == 3, + "grouped_weight must be float32 [G,H,K+V]"); + TORCH_CHECK( + grouped_weight.size(0) == sender_count / static_cast(group_size) && + grouped_weight.size(1) == hidden_dim && grouped_weight.size(2) == kv_dim, + "grouped_weight shape mismatch"); + } else { + TORCH_CHECK( + direct_weight.is_cuda() && direct_weight.is_contiguous() && direct_weight.scalar_type() == at::kFloat && + direct_weight.dim() == 3, + "direct_weight must be float32 [N,H,K+V]"); + TORCH_CHECK( + 
direct_weight.size(0) == sender_count && direct_weight.size(1) == hidden_dim && direct_weight.size(2) == kv_dim, + "direct_weight shape mismatch"); + } + validate_registered_executor_binding_rows( + forward_executor_rows, + forward_executor_binding_rows, + kForwardDirectionOpcode, + executor_id, + bucket_ordinal, + "registered sender K/V sequence projection"); + auto sender_k = at::empty({T, B, sender_count, head}, sender_cells_seq.options()); + auto sender_v = at::empty({T, B, sender_count, value}, sender_cells_seq.options()); + const int64_t total_elements = static_cast(B) * T * sender_count * kv_dim; + if (total_elements == 0) { + return {sender_k, sender_v}; + } + const int blocks = static_cast(std::min( + 4096, + (total_elements + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + registered_forward_sender_kv_sequence_kernel<<>>( + sender_cells_seq.data_ptr(), + has_direct ? direct_weight.data_ptr() : nullptr, + has_grouped ? grouped_weight.data_ptr() : nullptr, + sender_k.data_ptr(), + sender_v.data_ptr(), + total_elements, + B, + T, + sender_count, + hidden_dim, + head, + value, + static_cast(group_size), + has_grouped); + check_launch("registered_forward_sender_kv_sequence_kernel"); + return {sender_k, sender_v}; +} + +std::vector flat_bucket_registered_forward_sender_kv_step_cuda( + const at::Tensor& sender_cells, + const at::Tensor& direct_weight, + const at::Tensor& grouped_weight, + int64_t group_size, + int64_t head_dim, + int64_t value_dim, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal) { + check_cuda_float_bank(sender_cells, "sender_cells"); + const int B = static_cast(sender_cells.size(0)); + const int sender_count = static_cast(sender_cells.size(1)); + const int hidden_dim = static_cast(sender_cells.size(2)); + const int head = static_cast(head_dim); + const int value = static_cast(value_dim); + const int kv_dim = head + value; + const bool has_grouped = grouped_weight.numel() > 0; + const bool has_direct = direct_weight.numel() > 0; + TORCH_CHECK(head > 0 && value > 0, "K/V dimensions must be positive"); + TORCH_CHECK(has_direct || has_grouped, "registered sender K/V step projection requires direct or grouped weight"); + if (has_grouped) { + TORCH_CHECK(group_size > 0, "group_size must be positive for grouped sender K/V step projection"); + TORCH_CHECK(sender_count % static_cast(group_size) == 0, "group_size must divide sender count"); + TORCH_CHECK( + grouped_weight.is_cuda() && grouped_weight.is_contiguous() && grouped_weight.scalar_type() == at::kFloat && + grouped_weight.dim() == 3, + "grouped_weight must be float32 [G,H,K+V]"); + TORCH_CHECK( + grouped_weight.size(0) == sender_count / static_cast(group_size) && + grouped_weight.size(1) == hidden_dim && grouped_weight.size(2) == kv_dim, + "grouped_weight shape mismatch"); + } else { + TORCH_CHECK( + direct_weight.is_cuda() && direct_weight.is_contiguous() && direct_weight.scalar_type() == at::kFloat && + direct_weight.dim() == 3, + "direct_weight must be float32 [N,H,K+V]"); + TORCH_CHECK( + direct_weight.size(0) == sender_count && direct_weight.size(1) == hidden_dim && direct_weight.size(2) == kv_dim, + "direct_weight shape mismatch"); + } + validate_registered_executor_binding_rows( + forward_executor_rows, + forward_executor_binding_rows, + kForwardDirectionOpcode, + executor_id, + bucket_ordinal, + "registered sender K/V step projection"); + auto sender_k = 
at::empty({B, sender_count, head}, sender_cells.options()); + auto sender_v = at::empty({B, sender_count, value}, sender_cells.options()); + const int64_t total_elements = static_cast(B) * sender_count * kv_dim; + if (total_elements == 0) { + return {sender_k, sender_v}; + } + const int blocks = static_cast(std::min( + 4096, + (total_elements + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + registered_forward_sender_kv_step_kernel<<>>( + sender_cells.data_ptr(), + has_direct ? direct_weight.data_ptr() : nullptr, + has_grouped ? grouped_weight.data_ptr() : nullptr, + sender_k.data_ptr(), + sender_v.data_ptr(), + total_elements, + sender_count, + hidden_dim, + head, + value, + static_cast(group_size), + has_grouped); + check_launch("registered_forward_sender_kv_step_kernel"); + return {sender_k, sender_v}; +} + +std::vector flat_bucket_registered_backward_sender_kv_projection_cuda( + const at::Tensor& sender_cells, + const at::Tensor& direct_weight, + const at::Tensor& grouped_weight, + const at::Tensor& grad_k, + const at::Tensor& grad_v, + int64_t group_size, + int64_t head_dim, + int64_t value_dim, + bool return_input_grad, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal) { + check_cuda_float_bank(sender_cells, "sender_cells"); + const int B = static_cast(sender_cells.size(0)); + const int sender_count = static_cast(sender_cells.size(1)); + const int hidden_dim = static_cast(sender_cells.size(2)); + const int head = static_cast(head_dim); + const int value = static_cast(value_dim); + const int kv_dim = head + value; + const bool has_grad_k = grad_k.numel() > 0; + const bool has_grad_v = grad_v.numel() > 0; + const bool has_grouped = grouped_weight.numel() > 0; + const bool has_direct = direct_weight.numel() > 0; + TORCH_CHECK(head > 0 && value > 0, "K/V dimensions must be positive"); + TORCH_CHECK(has_grad_k || has_grad_v, "registered sender K/V backward requires a K or V gradient"); + TORCH_CHECK(has_direct || has_grouped, "registered sender K/V backward requires direct or grouped weight"); + if (has_grad_k) { + check_cuda_float_bank(grad_k, "grad_k"); + TORCH_CHECK( + grad_k.size(0) == B && grad_k.size(1) == sender_count && grad_k.size(2) == head, + "grad_k shape mismatch"); + } + if (has_grad_v) { + check_cuda_float_bank(grad_v, "grad_v"); + TORCH_CHECK( + grad_v.size(0) == B && grad_v.size(1) == sender_count && grad_v.size(2) == value, + "grad_v shape mismatch"); + } + if (has_grouped) { + TORCH_CHECK(group_size > 0, "group_size must be positive for grouped sender K/V backward"); + TORCH_CHECK(sender_count % static_cast(group_size) == 0, "group_size must divide sender count"); + TORCH_CHECK( + grouped_weight.is_cuda() && grouped_weight.is_contiguous() && grouped_weight.scalar_type() == at::kFloat && + grouped_weight.dim() == 3, + "grouped_weight must be float32 [G,H,K+V]"); + TORCH_CHECK( + grouped_weight.size(0) == sender_count / static_cast(group_size) && + grouped_weight.size(1) == hidden_dim && grouped_weight.size(2) == kv_dim, + "grouped_weight shape mismatch"); + } else { + TORCH_CHECK( + direct_weight.is_cuda() && direct_weight.is_contiguous() && direct_weight.scalar_type() == at::kFloat && + direct_weight.dim() == 3, + "direct_weight must be float32 [N,H,K+V]"); + TORCH_CHECK( + direct_weight.size(0) == sender_count && direct_weight.size(1) == hidden_dim && direct_weight.size(2) == kv_dim, + "direct_weight shape mismatch"); + } + 
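// Shape sketch (illustrative numbers, assumed for clarity; not part of the original patch):
// with B=2, sender_count=8, hidden_dim=16, head_dim=4, value_dim=4, and group_size=4,
// kv_dim = head_dim + value_dim = 8, a grouped weight is [2, 16, 8], and grad_k / grad_v are
// each [2, 8, 4]. The sender-gradient launch below then covers
// B * sender_count * hidden_dim = 256 elements, and the weight-gradient launch covers
// (sender_count / group_size) * hidden_dim * kv_dim = 2 * 16 * 8 = 256 elements.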
validate_registered_executor_binding_rows( + reverse_executor_rows, + reverse_executor_binding_rows, + kReverseDirectionOpcode, + executor_id, + bucket_ordinal, + "registered sender K/V projection backward"); + + auto grad_sender = return_input_grad ? at::empty_like(sender_cells) : sender_cells.new_empty({0}); + const at::Tensor& active_weight = has_grouped ? grouped_weight : direct_weight; + auto grad_weight = at::empty_like(active_weight); + const auto stream = at::cuda::getCurrentCUDAStream(); + if (return_input_grad) { + const int64_t sender_elements = static_cast(B) * sender_count * hidden_dim; + if (sender_elements > 0) { + const int blocks = static_cast(std::min( + 4096, + (sender_elements + kThreadsPerBlock - 1) / kThreadsPerBlock)); + registered_backward_sender_kv_sender_kernel<<>>( + has_grad_k ? grad_k.data_ptr() : nullptr, + has_grad_v ? grad_v.data_ptr() : nullptr, + has_direct ? direct_weight.data_ptr() : nullptr, + has_grouped ? grouped_weight.data_ptr() : nullptr, + grad_sender.data_ptr(), + sender_elements, + sender_count, + hidden_dim, + head, + value, + kv_dim, + static_cast(group_size), + has_grouped, + has_grad_k, + has_grad_v); + check_launch("registered_backward_sender_kv_sender_kernel"); + } + } + const int weight_count = has_grouped ? sender_count / static_cast(group_size) : sender_count; + const int64_t weight_elements = static_cast(weight_count) * hidden_dim * kv_dim; + if (weight_elements > 0) { + const int blocks = static_cast(std::min( + 4096, + (weight_elements + kThreadsPerBlock - 1) / kThreadsPerBlock)); + registered_backward_sender_kv_weight_kernel<<>>( + sender_cells.data_ptr(), + has_grad_k ? grad_k.data_ptr() : nullptr, + has_grad_v ? grad_v.data_ptr() : nullptr, + grad_weight.data_ptr(), + weight_elements, + B, + sender_count, + hidden_dim, + head, + value, + kv_dim, + static_cast(group_size), + has_grouped, + has_grad_k, + has_grad_v); + check_launch("registered_backward_sender_kv_weight_kernel"); + } + return {grad_sender, grad_weight}; +} + +std::vector flat_bucket_registered_backward_sender_value_projection_cuda( + const at::Tensor& sender_cells, + const at::Tensor& direct_weight, + const at::Tensor& grouped_weight, + const at::Tensor& grad_v, + int64_t group_size, + int64_t value_dim, + bool return_input_grad, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal) { + check_cuda_float_bank(sender_cells, "sender_value_cells"); + check_cuda_float_bank(grad_v, "sender_value_grad_v"); + const int B = static_cast(sender_cells.size(0)); + const int sender_count = static_cast(sender_cells.size(1)); + const int hidden_dim = static_cast(sender_cells.size(2)); + const int value = static_cast(value_dim); + const bool has_grouped = grouped_weight.numel() > 0; + const bool has_direct = direct_weight.numel() > 0; + TORCH_CHECK(value > 0, "registered sender value backward requires a positive value dimension"); + TORCH_CHECK(has_direct || has_grouped, "registered sender value backward requires direct or grouped value weight"); + TORCH_CHECK( + grad_v.size(0) == B && grad_v.size(1) == sender_count && grad_v.size(2) == value, + "sender value grad_v shape mismatch"); + if (has_grouped) { + TORCH_CHECK(group_size > 0, "group_size must be positive for grouped sender value backward"); + TORCH_CHECK(sender_count % static_cast(group_size) == 0, "group_size must divide sender count"); + TORCH_CHECK( + grouped_weight.is_cuda() && grouped_weight.is_contiguous() && 
grouped_weight.scalar_type() == at::kFloat && + grouped_weight.dim() == 3, + "grouped value weight must be float32 [G,H,V]"); + TORCH_CHECK( + grouped_weight.size(0) == sender_count / static_cast(group_size) && + grouped_weight.size(1) == hidden_dim && grouped_weight.size(2) == value, + "grouped sender value weight shape mismatch"); + } else { + TORCH_CHECK( + direct_weight.is_cuda() && direct_weight.is_contiguous() && direct_weight.scalar_type() == at::kFloat && + direct_weight.dim() == 3, + "direct value weight must be float32 [N,H,V]"); + TORCH_CHECK( + direct_weight.size(0) == sender_count && direct_weight.size(1) == hidden_dim && direct_weight.size(2) == value, + "direct sender value weight shape mismatch"); + } + validate_registered_executor_binding_rows( + reverse_executor_rows, + reverse_executor_binding_rows, + kReverseDirectionOpcode, + executor_id, + bucket_ordinal, + "registered sender value projection backward"); + + auto grad_sender = return_input_grad ? at::empty_like(sender_cells) : sender_cells.new_empty({0}); + const at::Tensor& active_weight = has_grouped ? grouped_weight : direct_weight; + auto grad_weight = at::empty_like(active_weight); + const auto stream = at::cuda::getCurrentCUDAStream(); + if (return_input_grad) { + const int64_t sender_elements = static_cast(B) * sender_count * hidden_dim; + if (sender_elements > 0) { + const int blocks = static_cast(std::min( + 4096, + (sender_elements + kThreadsPerBlock - 1) / kThreadsPerBlock)); + registered_backward_sender_value_sender_kernel<<>>( + grad_v.data_ptr(), + has_direct ? direct_weight.data_ptr() : nullptr, + has_grouped ? grouped_weight.data_ptr() : nullptr, + grad_sender.data_ptr(), + sender_elements, + sender_count, + hidden_dim, + value, + static_cast(group_size), + has_grouped); + check_launch("registered_backward_sender_value_sender_kernel"); + } + } + const int weight_count = has_grouped ? 
sender_count / static_cast(group_size) : sender_count; + const int64_t weight_elements = static_cast(weight_count) * hidden_dim * value; + if (weight_elements > 0) { + const int blocks = static_cast(std::min( + 4096, + (weight_elements + kThreadsPerBlock - 1) / kThreadsPerBlock)); + registered_backward_sender_value_weight_kernel<<>>( + sender_cells.data_ptr(), + grad_v.data_ptr(), + grad_weight.data_ptr(), + weight_elements, + B, + sender_count, + hidden_dim, + value, + static_cast(group_size), + has_grouped); + check_launch("registered_backward_sender_value_weight_kernel"); + } + return {grad_sender, grad_weight}; +} + +std::vector flat_bucket_registered_backward_sparse_attention_cuda( + const at::Tensor& grad_msg, + const at::Tensor& q, + const at::Tensor& input_k, + const at::Tensor& input_v, + const at::Tensor& recurrent_k, + const at::Tensor& recurrent_v, + const at::Tensor& neighbor_idx, + const at::Tensor& neighbor_valid, + const at::Tensor& edge_distance, + const at::Tensor& edge_delay, + const at::Tensor& step_flat, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal, + double distance_scale, + bool use_delay) { + check_cuda_float_bank(grad_msg, "grad_msg"); + check_cuda_float_rank2(q, "q"); + check_cuda_float_bank(input_k, "input_k"); + check_cuda_float_bank(input_v, "input_v"); + check_cuda_float_bank(recurrent_k, "recurrent_k"); + check_cuda_float_bank(recurrent_v, "recurrent_v"); + check_cuda_long_rank2(neighbor_idx, "neighbor_idx"); + check_cuda_bool_rank2(neighbor_valid, "neighbor_valid"); + check_cuda_long_rank2(edge_delay, "edge_delay"); + TORCH_CHECK(edge_distance.is_cuda() && edge_distance.is_contiguous(), "edge_distance must be a CUDA tensor"); + TORCH_CHECK(edge_distance.scalar_type() == at::kFloat && edge_distance.dim() == 2, "edge_distance must be [R,M]"); + check_cuda_long_rank1(step_flat, "step_flat"); + + const int B = static_cast(input_k.size(0)); + const int input_senders = static_cast(input_k.size(1)); + const int recurrent_senders = static_cast(recurrent_k.size(1)); + const int receiver_count = static_cast(q.size(0)); + const int degree = static_cast(neighbor_idx.size(1)); + const int head_dim = static_cast(q.size(1)); + const int key_dim = static_cast(input_k.size(2)); + const int value_dim = static_cast(input_v.size(2)); + validate_registered_partitioned_attention_executor_rows( + reverse_executor_rows, + reverse_executor_binding_rows, + kReverseDirectionOpcode, + executor_id, + bucket_ordinal, + receiver_count); + TORCH_CHECK(input_v.size(0) == B && recurrent_k.size(0) == B && recurrent_v.size(0) == B, "bank batch mismatch"); + TORCH_CHECK(recurrent_k.size(2) == key_dim && key_dim >= head_dim, "K/head dimension mismatch"); + TORCH_CHECK(input_v.size(2) == recurrent_v.size(2), "V dimension mismatch"); + TORCH_CHECK( + grad_msg.size(0) == B && grad_msg.size(1) == receiver_count && grad_msg.size(2) == value_dim, + "grad_msg shape mismatch"); + TORCH_CHECK(neighbor_idx.size(0) == receiver_count, "neighbor_idx receiver count mismatch"); + TORCH_CHECK(neighbor_valid.sizes() == neighbor_idx.sizes(), "neighbor_valid shape mismatch"); + TORCH_CHECK(edge_distance.sizes() == neighbor_idx.sizes(), "edge_distance shape mismatch"); + TORCH_CHECK(edge_delay.sizes() == neighbor_idx.sizes(), "edge_delay shape mismatch"); + TORCH_CHECK(step_flat.size(0) == B, "step_flat length must match batch/time dimension"); + + auto grad_q = at::zeros_like(q); + auto grad_input_k = 
at::zeros_like(input_k); + auto grad_input_v = at::zeros_like(input_v); + auto grad_recurrent_k = at::zeros_like(recurrent_k); + auto grad_recurrent_v = at::zeros_like(recurrent_v); + const int64_t receiver_total = static_cast(B) * receiver_count; + if (receiver_total == 0) { + return {grad_q, grad_input_k, grad_input_v, grad_recurrent_k, grad_recurrent_v}; + } + const int blocks = static_cast(std::min( + 4096, + (receiver_total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + registered_backward_sparse_attention_kernel<<>>( + grad_msg.data_ptr(), + q.data_ptr(), + input_k.data_ptr(), + input_v.data_ptr(), + recurrent_k.data_ptr(), + recurrent_v.data_ptr(), + neighbor_idx.data_ptr(), + neighbor_valid.data_ptr(), + edge_distance.data_ptr(), + edge_delay.data_ptr(), + step_flat.data_ptr(), + grad_q.data_ptr(), + grad_input_k.data_ptr(), + grad_input_v.data_ptr(), + grad_recurrent_k.data_ptr(), + grad_recurrent_v.data_ptr(), + receiver_total, + receiver_count, + input_senders, + recurrent_senders, + degree, + head_dim, + key_dim, + value_dim, + 1.0f / std::sqrt(static_cast(head_dim > 0 ? head_dim : 1)), + static_cast(distance_scale), + use_delay); + check_launch("registered_backward_sparse_attention_kernel"); + return {grad_q, grad_input_k, grad_input_v, grad_recurrent_k, grad_recurrent_v}; +} + +std::vector flat_bucket_registered_backward_fixed_slot_context_message_cuda( + const at::Tensor& grad_msg, + const at::Tensor& query_slot, + const at::Tensor& query_context_scalar, + const at::Tensor& input_k, + const at::Tensor& input_v, + const at::Tensor& recurrent_k, + const at::Tensor& recurrent_v, + const at::Tensor& output_weight, + const at::Tensor& receiver_sender_idx, + const at::Tensor& offset_distance, + const at::Tensor& offset_delay, + const at::Tensor& step_flat, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal, + double distance_scale, + bool use_delay) { + check_cuda_float_bank(grad_msg, "fixed-slot context grad_msg"); + check_cuda_float_rank2(query_slot, "fixed-slot context query_slot"); + check_cuda_float_rank1(query_context_scalar, "fixed-slot context query_context_scalar"); + check_cuda_float_bank(input_k, "fixed-slot context input_k"); + check_cuda_float_bank(input_v, "fixed-slot context input_v"); + check_cuda_float_bank(recurrent_k, "fixed-slot context recurrent_k"); + check_cuda_float_bank(recurrent_v, "fixed-slot context recurrent_v"); + check_cuda_float_rank2(output_weight, "fixed-slot context output_weight"); + check_cuda_int_rank2(receiver_sender_idx, "fixed-slot context receiver_sender_idx"); + check_cuda_float_rank1(offset_distance, "fixed-slot context offset_distance"); + check_cuda_int_rank1(offset_delay, "fixed-slot context offset_delay"); + check_cuda_long_rank1(step_flat, "fixed-slot context step_flat"); + TORCH_CHECK(query_context_scalar.numel() == 1, "fixed-slot context query_context_scalar must be scalar"); + const int B = static_cast(input_v.size(0)); + const int input_senders = static_cast(input_v.size(1)); + const int recurrent_senders = static_cast(recurrent_v.size(1)); + const int receiver_count = static_cast(query_slot.size(0)); + const int degree = static_cast(receiver_sender_idx.size(1)); + const int head_dim = static_cast(query_slot.size(1)); + const int value_dim = static_cast(input_v.size(2)); + const int message_dim = static_cast(grad_msg.size(2)); + 
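// Dimension sketch (illustrative, assumed values; not part of the original patch): the
// fixed-slot context backward expects keys of width 2 * head_dim (slot and context halves),
// so with head_dim = 4 the checks below require input_k / recurrent_k last dims of 8,
// value_dim >= head_dim, and an output_weight of shape [message_dim, value_dim]. The launch
// that follows covers receiver_total = B * receiver_count work items and returns seven
// gradients: query slot, input K/V, recurrent K/V, the context scalar, and the output weight.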
validate_registered_partitioned_attention_executor_rows( + reverse_executor_rows, + reverse_executor_binding_rows, + kReverseDirectionOpcode, + executor_id, + bucket_ordinal, + receiver_count); + TORCH_CHECK( + degree <= kMaxRegisteredAttentionOffsets, + "registered fixed-slot context backward supports at most ", + kMaxRegisteredAttentionOffsets, + " offsets"); + TORCH_CHECK(recurrent_senders == receiver_count, "fixed-slot context recurrent sender/receiver count mismatch"); + TORCH_CHECK(input_k.size(0) == B && recurrent_k.size(0) == B, "fixed-slot context key batch mismatch"); + TORCH_CHECK(input_k.size(1) == input_senders && recurrent_k.size(1) == recurrent_senders, "fixed-slot context key sender mismatch"); + TORCH_CHECK(input_k.size(2) == 2 * head_dim && recurrent_k.size(2) == 2 * head_dim, "fixed-slot context key dim mismatch"); + TORCH_CHECK(recurrent_v.size(0) == B && recurrent_v.size(2) == value_dim, "fixed-slot context recurrent value mismatch"); + TORCH_CHECK(value_dim >= head_dim, "fixed-slot context backward requires value_dim >= head_dim"); + TORCH_CHECK(grad_msg.size(0) == B && grad_msg.size(1) == receiver_count, "fixed-slot context grad_msg shape mismatch"); + TORCH_CHECK(output_weight.size(0) == message_dim && output_weight.size(1) == value_dim, "fixed-slot context output weight mismatch"); + TORCH_CHECK(offset_distance.size(0) == degree && offset_delay.size(0) == degree, "fixed-slot context offset metadata mismatch"); + TORCH_CHECK(step_flat.size(0) == B, "fixed-slot context step length mismatch"); + + at::Tensor grad_query_slot = at::zeros_like(query_slot); + at::Tensor grad_input_k = at::zeros_like(input_k); + at::Tensor grad_input_v = at::zeros_like(input_v); + at::Tensor grad_recurrent_k = at::zeros_like(recurrent_k); + at::Tensor grad_recurrent_v = at::zeros_like(recurrent_v); + at::Tensor grad_query_context_scalar = at::zeros_like(query_context_scalar); + at::Tensor grad_output_weight = at::zeros_like(output_weight); + const int64_t receiver_total = static_cast(B) * receiver_count; + if (receiver_total == 0) { + return { + grad_query_slot, + grad_input_k, + grad_input_v, + grad_recurrent_k, + grad_recurrent_v, + grad_query_context_scalar, + grad_output_weight, + }; + } + const int blocks = static_cast(std::min( + 4096, + (receiver_total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + registered_backward_fixed_slot_context_message_kernel<<>>( + grad_msg.data_ptr(), + query_slot.data_ptr(), + query_context_scalar.data_ptr(), + input_k.data_ptr(), + input_v.data_ptr(), + recurrent_k.data_ptr(), + recurrent_v.data_ptr(), + output_weight.data_ptr(), + receiver_sender_idx.data_ptr(), + offset_distance.data_ptr(), + offset_delay.data_ptr(), + step_flat.data_ptr(), + grad_query_slot.data_ptr(), + grad_input_k.data_ptr(), + grad_input_v.data_ptr(), + grad_recurrent_k.data_ptr(), + grad_recurrent_v.data_ptr(), + grad_query_context_scalar.data_ptr(), + grad_output_weight.data_ptr(), + receiver_total, + receiver_count, + input_senders, + recurrent_senders, + degree, + head_dim, + value_dim, + message_dim, + 1.0f / std::sqrt(static_cast(2 * (head_dim > 0 ? 
head_dim : 1))), + static_cast(distance_scale), + use_delay); + check_launch("registered_backward_fixed_slot_context_message_kernel"); + return { + grad_query_slot, + grad_input_k, + grad_input_v, + grad_recurrent_k, + grad_recurrent_v, + grad_query_context_scalar, + grad_output_weight, + }; +} + +std::vector flat_bucket_registered_forward_readout_layout_epilogue_cuda( + const at::Tensor& boundary, + const at::Tensor& recurrent_hidden_backend_order, + const at::Tensor& input_k, + const at::Tensor& input_v, + const at::Tensor& recurrent_k, + const at::Tensor& recurrent_v, + const at::Tensor& output_q, + const at::Tensor& output_local_sender_idx, + const at::Tensor& local_distance, + const at::Tensor& value_to_output_weight, + const at::Tensor& output_cell_bias, + const at::Tensor& backend_to_graph_inverse_order, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal, + double distance_scale, + int64_t num_input_senders) { + check_cuda_float_bank(boundary, "boundary"); + check_cuda_float_bank(recurrent_hidden_backend_order, "recurrent_hidden_backend_order"); + check_cuda_float_bank(input_k, "input_k"); + check_cuda_float_bank(input_v, "input_v"); + check_cuda_float_bank(recurrent_k, "recurrent_k"); + check_cuda_float_bank(recurrent_v, "recurrent_v"); + check_cuda_float_rank2(output_q, "output_q"); + TORCH_CHECK( + output_local_sender_idx.is_cuda() && output_local_sender_idx.is_contiguous(), + "output_local_sender_idx must be a contiguous CUDA tensor"); + TORCH_CHECK( + output_local_sender_idx.scalar_type() == at::kInt && output_local_sender_idx.dim() == 2, + "output_local_sender_idx must be int32 [O,degree]"); + TORCH_CHECK( + local_distance.is_cuda() && local_distance.is_contiguous() && local_distance.scalar_type() == at::kFloat && + local_distance.dim() == 1, + "local_distance must be float32 [degree]"); + TORCH_CHECK( + value_to_output_weight.is_cuda() && value_to_output_weight.is_contiguous() && + value_to_output_weight.scalar_type() == at::kFloat && value_to_output_weight.dim() == 3, + "value_to_output_weight must be float32 [O,V,H]"); + check_cuda_float_rank2(output_cell_bias, "output_cell_bias"); + TORCH_CHECK( + backend_to_graph_inverse_order.is_cuda() && backend_to_graph_inverse_order.is_contiguous(), + "backend_to_graph_inverse_order must be a contiguous CUDA tensor"); + TORCH_CHECK( + backend_to_graph_inverse_order.scalar_type() == at::kLong || + backend_to_graph_inverse_order.scalar_type() == at::kInt, + "backend_to_graph_inverse_order must be int64 or int32"); + TORCH_CHECK(backend_to_graph_inverse_order.dim() == 1, "backend_to_graph_inverse_order must be rank-1"); + + const int B = static_cast(boundary.size(0)); + const int input_count = static_cast(boundary.size(1)); + const int recurrent_count = static_cast(recurrent_hidden_backend_order.size(1)); + const int hidden_dim = static_cast(boundary.size(2)); + const int head_dim = static_cast(output_q.size(1)); + const int key_dim = static_cast(input_k.size(2)); + const int value_dim = static_cast(input_v.size(2)); + const int output_count = static_cast(output_q.size(0)); + const int degree = static_cast(output_local_sender_idx.size(1)); + validate_registered_readout_executor_rows( + forward_executor_rows, + forward_executor_binding_rows, + executor_id, + bucket_ordinal, + output_count); + TORCH_CHECK(num_input_senders == input_count, "num_input_senders must match input_k sender count"); + TORCH_CHECK(input_k.size(0) == B && input_k.size(1) == 
input_count, "input_k shape mismatch"); + TORCH_CHECK(input_v.size(0) == B && input_v.size(1) == input_count, "input_v shape mismatch"); + TORCH_CHECK(recurrent_k.size(0) == B && recurrent_k.size(1) == recurrent_count, "recurrent_k shape mismatch"); + TORCH_CHECK(recurrent_v.size(0) == B && recurrent_v.size(1) == recurrent_count, "recurrent_v shape mismatch"); + TORCH_CHECK(input_v.size(2) == value_dim && recurrent_v.size(2) == value_dim, "value dimension mismatch"); + TORCH_CHECK(recurrent_k.size(2) == key_dim && key_dim >= head_dim, "recurrent_k key dimension mismatch"); + TORCH_CHECK(output_local_sender_idx.size(0) == output_count, "output sender row mismatch"); + TORCH_CHECK(local_distance.size(0) == degree, "local_distance degree mismatch"); + TORCH_CHECK( + value_to_output_weight.size(0) == output_count && value_to_output_weight.size(1) == value_dim && + value_to_output_weight.size(2) == hidden_dim, + "value_to_output_weight shape mismatch"); + TORCH_CHECK(output_cell_bias.size(0) == output_count && output_cell_bias.size(1) == hidden_dim, "bias shape mismatch"); + TORCH_CHECK( + recurrent_hidden_backend_order.size(0) == B && recurrent_hidden_backend_order.size(2) == hidden_dim, + "recurrent hidden shape mismatch"); + TORCH_CHECK( + backend_to_graph_inverse_order.size(0) == recurrent_count, + "backend_to_graph_inverse_order length mismatch"); + + auto output_cells = at::empty({B, output_count, hidden_dim}, boundary.options()); + auto recurrent_graph_order = at::empty_like(recurrent_hidden_backend_order); + auto cells_out = at::empty({B, input_count + recurrent_count + output_count, hidden_dim}, boundary.options()); + if (backend_to_graph_inverse_order.scalar_type() == at::kLong) { + launch_registered_forward_readout_layout_epilogue( + boundary, + recurrent_hidden_backend_order, + input_k, + input_v, + recurrent_k, + recurrent_v, + output_q, + output_local_sender_idx, + local_distance, + value_to_output_weight, + output_cell_bias, + backend_to_graph_inverse_order, + output_cells, + recurrent_graph_order, + cells_out, + static_cast(distance_scale)); + } else { + launch_registered_forward_readout_layout_epilogue( + boundary, + recurrent_hidden_backend_order, + input_k, + input_v, + recurrent_k, + recurrent_v, + output_q, + output_local_sender_idx, + local_distance, + value_to_output_weight, + output_cell_bias, + backend_to_graph_inverse_order, + output_cells, + recurrent_graph_order, + cells_out, + static_cast(distance_scale)); + } + return {output_cells, recurrent_graph_order, cells_out}; +} + +std::vector flat_bucket_registered_backward_readout_layout_projection_cuda( + const at::Tensor& grad_cells_out, + const at::Tensor& output_msg, + const at::Tensor& value_to_output_weight, + const at::Tensor& graph_to_backend_order, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal, + int64_t input_count, + int64_t recurrent_count) { + check_cuda_float_bank(grad_cells_out, "grad_cells_out"); + check_cuda_float_bank(output_msg, "output_msg"); + TORCH_CHECK( + value_to_output_weight.is_cuda() && value_to_output_weight.is_contiguous() && + value_to_output_weight.scalar_type() == at::kFloat && value_to_output_weight.dim() == 3, + "value_to_output_weight must be float32 [O,V,H]"); + TORCH_CHECK( + graph_to_backend_order.is_cuda() && graph_to_backend_order.is_contiguous(), + "graph_to_backend_order must be a contiguous CUDA tensor"); + TORCH_CHECK( + graph_to_backend_order.scalar_type() == at::kLong || 
graph_to_backend_order.scalar_type() == at::kInt,
+      "graph_to_backend_order must be int64 or int32");
+  TORCH_CHECK(graph_to_backend_order.dim() == 1, "graph_to_backend_order must be rank-1");
+  const int B = static_cast<int>(grad_cells_out.size(0));
+  const int output_count = static_cast<int>(output_msg.size(1));
+  const int value_dim = static_cast<int>(output_msg.size(2));
+  const int hidden_dim = static_cast<int>(grad_cells_out.size(2));
+  const int input_cells = static_cast<int>(input_count);
+  const int recurrent_cells = static_cast<int>(recurrent_count);
+  validate_registered_reverse_readout_executor_rows(
+      reverse_executor_rows,
+      reverse_executor_binding_rows,
+      executor_id,
+      bucket_ordinal,
+      output_count);
+  TORCH_CHECK(input_cells >= 0 && recurrent_cells >= 0, "input/recurrent counts must be non-negative");
+  TORCH_CHECK(
+      grad_cells_out.size(1) == input_cells + recurrent_cells + output_count,
+      "grad_cells_out cell dimension does not match input/recurrent/output counts");
+  TORCH_CHECK(output_msg.size(0) == B, "output_msg batch mismatch");
+  TORCH_CHECK(
+      value_to_output_weight.size(0) == output_count && value_to_output_weight.size(1) == value_dim &&
+          value_to_output_weight.size(2) == hidden_dim,
+      "value_to_output_weight shape mismatch");
+  TORCH_CHECK(graph_to_backend_order.size(0) == recurrent_cells, "graph_to_backend_order length mismatch");
+
+  auto grad_boundary = at::empty({B, input_cells, hidden_dim}, grad_cells_out.options());
+  auto grad_recurrent_hidden_backend = at::empty({B, recurrent_cells, hidden_dim}, grad_cells_out.options());
+  auto grad_output_msg = at::empty_like(output_msg);
+  auto grad_value_to_output_weight = at::empty_like(value_to_output_weight);
+  auto grad_output_cell_bias = at::empty({output_count, hidden_dim}, grad_cells_out.options());
+  if (graph_to_backend_order.scalar_type() == at::kLong) {
+    launch_registered_backward_layout_split<int64_t>(
+        grad_cells_out,
+        graph_to_backend_order,
+        grad_boundary,
+        grad_recurrent_hidden_backend,
+        input_cells,
+        recurrent_cells,
+        output_count,
+        hidden_dim);
+  } else {
+    launch_registered_backward_layout_split<int32_t>(
+        grad_cells_out,
+        graph_to_backend_order,
+        grad_boundary,
+        grad_recurrent_hidden_backend,
+        input_cells,
+        recurrent_cells,
+        output_count,
+        hidden_dim);
+  }
+  launch_registered_backward_readout_projection(
+      grad_cells_out,
+      output_msg,
+      value_to_output_weight,
+      grad_output_msg,
+      grad_value_to_output_weight,
+      grad_output_cell_bias,
+      input_cells,
+      recurrent_cells);
+  return {
+      grad_boundary,
+      grad_recurrent_hidden_backend,
+      grad_output_msg,
+      grad_value_to_output_weight,
+      grad_output_cell_bias,
+  };
+}
diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/parameter_reducer_program.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/parameter_reducer_program.cuh
new file mode 100644
index 00000000..85c691d3
--- /dev/null
+++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/parameter_reducer_program.cuh
@@ -0,0 +1,1198 @@
+#pragma once
+
+struct RegisteredParameterReducerExpectedCount {
+  int64_t target;
+  int64_t count;
+};
+
+struct RegisteredParameterReducerExpectedCounts {
+  std::vector<RegisteredParameterReducerExpectedCount> entries;
+};
+
+struct RegisteredParameterReducerTrainableRole {
+  int64_t role;
+  int64_t parameter_index;
+};
+
+struct RegisteredParameterReducerRuntimeRole {
+  int64_t role;
+  int64_t tensor_index;
+};
+
+struct RegisteredParameterReducerRoleTable {
+  std::vector<RegisteredParameterReducerTrainableRole> trainable_roles;
+  std::vector<RegisteredParameterReducerRuntimeRole> runtime_roles;
+};
+
+struct
+  std::vector<at::Tensor>& param_outputs;
+  const RegisteredParameterReducerRoleTable& role_table;
+  const at::Tensor& transition_source_rows;
+  const at::Tensor& transition_trainable_rows;
+  const std::vector<at::Tensor>& sender_grad_weight_tensors;
+  const std::vector<at::Tensor>& sender_group_id_tensors;
+  const at::Tensor& sender_grouped_flags;
+  const std::vector<at::Tensor>& readout_grad_value_to_output_weight_tensors;
+  const std::vector<at::Tensor>& readout_grad_output_cell_bias_tensors;
+  const std::vector<at::Tensor>& recurrent_query_grad_tensors;
+  const std::vector<at::Tensor>& output_query_grad_tensors;
+  const std::vector<at::Tensor>& message_strategy_grad_tensors;
+  const at::Tensor& message_strategy_grad_rows;
+  const std::vector<at::Tensor>& transition_source_tensors;
+  const std::vector<at::Tensor>& transition_source_recurrent_cell_idx_tensors;
+  const std::vector<at::Tensor>& trainable_param_tensors;
+  const std::vector<at::Tensor>& runtime_metadata_tensors;
+  int64_t coord_count;
+  int64_t head_dim;
+  int64_t value_dim;
+};
+
+using RegisteredParameterReducerRunFn = void (*)(RegisteredParameterReducerRuntimeContext&);
+
+void run_registered_sender_kv_parameter_reducer_strategy(RegisteredParameterReducerRuntimeContext& context);
+void run_registered_recurrent_query_parameter_reducer_strategy(RegisteredParameterReducerRuntimeContext& context);
+void run_registered_output_query_parameter_reducer_strategy(RegisteredParameterReducerRuntimeContext& context);
+void run_registered_readout_output_parameter_reducer_strategy(RegisteredParameterReducerRuntimeContext& context);
+void run_registered_transition_parameter_reducer_strategy(RegisteredParameterReducerRuntimeContext& context);
+void run_registered_fixed_slot_context_message_parameter_reducer_strategy(
+    RegisteredParameterReducerRuntimeContext& context);
+
+struct RegisteredParameterReducerHandler {
+  int64_t native_callable_hash;
+  const char* name;
+  RegisteredParameterReducerRunFn run;
+};
+
+struct RegisteredParameterReducerStrategy {
+  int64_t reducer_kind;
+  int64_t strategy_kind;
+  int64_t expected_count_target;
+  int64_t expected_count_mode;
+  int64_t flags;
+  const RegisteredParameterReducerHandler* handler;
+};
+
+#define REGISTERED_TEMPORAL_NATIVE_PARAMETER_REDUCER_CATALOG
+#include "../flat_bucket_registered_native_callables.cuh"
+
+const RegisteredParameterReducerHandler& registered_parameter_reducer_handler_for_native_callable(
+    int64_t native_callable_hash) {
+  for (const RegisteredParameterReducerHandler* handler = registered_native_parameter_reducer_catalog_begin();
+       handler != registered_native_parameter_reducer_catalog_end();
+       ++handler) {
+    if (handler->native_callable_hash == native_callable_hash) {
+      return *handler;
+    }
+  }
+  TORCH_CHECK(false, "unknown parameter reducer native callable hash: ", native_callable_hash);
+  return *registered_native_parameter_reducer_catalog_begin();
+}
+
+RegisteredParameterReducerRoleTable decode_registered_parameter_reducer_role_table(
+    const at::Tensor& trainable_role_rows,
+    const at::Tensor& runtime_metadata_rows,
+    const std::vector<at::Tensor>& trainable_param_tensors,
+    const std::vector<at::Tensor>& runtime_metadata_tensors) {
+  RegisteredParameterReducerRoleTable table;
+  const int64_t* trainable_rows = trainable_role_rows.data_ptr<int64_t>();
+  table.trainable_roles.reserve(static_cast<size_t>(trainable_role_rows.size(0)));
+  for (int64_t row = 0; row < trainable_role_rows.size(0); ++row) {
+    const int64_t* item = trainable_rows + row * 6;
+    TORCH_CHECK(item[0] == row, "parameter reducer trainable role row index mismatch");
+    TORCH_CHECK(item[1] > 0, "parameter reducer trainable role row 
requires a role opcode"); + TORCH_CHECK( + 0 <= item[2] && item[2] < static_cast(trainable_param_tensors.size()), + "parameter reducer trainable role references tensor outside trainable tensor table"); + for (const RegisteredParameterReducerTrainableRole& existing : table.trainable_roles) { + TORCH_CHECK( + existing.role != item[1], + "parameter reducer trainable role rows contain duplicate role: ", + item[1]); + } + table.trainable_roles.push_back(RegisteredParameterReducerTrainableRole{item[1], item[2]}); + } + + const int64_t* runtime_rows = runtime_metadata_rows.data_ptr(); + table.runtime_roles.reserve(static_cast(runtime_metadata_rows.size(0))); + for (int64_t row = 0; row < runtime_metadata_rows.size(0); ++row) { + const int64_t* item = runtime_rows + row * 4; + TORCH_CHECK(item[0] == row, "parameter reducer runtime metadata row index mismatch"); + TORCH_CHECK(item[1] > 0, "parameter reducer runtime metadata row requires a role opcode"); + TORCH_CHECK( + 0 <= item[2] && item[2] < static_cast(runtime_metadata_tensors.size()), + "parameter reducer runtime metadata row references tensor outside runtime metadata tensor table"); + for (const RegisteredParameterReducerRuntimeRole& existing : table.runtime_roles) { + TORCH_CHECK( + existing.role != item[1], + "parameter reducer runtime metadata rows contain duplicate role: ", + item[1]); + } + table.runtime_roles.push_back(RegisteredParameterReducerRuntimeRole{item[1], item[2]}); + } + return table; +} + +int64_t parameter_index_for_trainable_role( + const RegisteredParameterReducerRoleTable& table, + int64_t role, + const char* role_name) { + for (const RegisteredParameterReducerTrainableRole& item : table.trainable_roles) { + if (item.role == role) { + return item.parameter_index; + } + } + TORCH_CHECK(false, "parameter reducer missing compiler trainable role: ", role_name); + return -1; +} + +bool has_trainable_role( + const RegisteredParameterReducerRoleTable& table, + int64_t role) { + for (const RegisteredParameterReducerTrainableRole& item : table.trainable_roles) { + if (item.role == role) { + return true; + } + } + return false; +} + +const at::Tensor& trainable_tensor_for_role( + const RegisteredParameterReducerRuntimeContext& context, + int64_t role, + const char* role_name) { + const int64_t parameter_index = parameter_index_for_trainable_role(context.role_table, role, role_name); + return context.trainable_param_tensors[static_cast(parameter_index)]; +} + +at::Tensor& output_tensor_for_role( + RegisteredParameterReducerRuntimeContext& context, + int64_t role, + const char* role_name) { + const int64_t parameter_index = parameter_index_for_trainable_role(context.role_table, role, role_name); + at::Tensor& output = context.param_outputs[static_cast(parameter_index)]; + const at::Tensor& target = context.trainable_param_tensors[static_cast(parameter_index)]; + TORCH_CHECK( + output.defined() && output.numel() > 0, + "parameter reducer output for role ", + role_name, + " was not allocated by the compiler output tensor table"); + TORCH_CHECK( + output.is_cuda() && output.is_contiguous() && output.scalar_type() == at::kFloat, + "parameter reducer output tensor must be contiguous CUDA float32 for role ", + role_name); + TORCH_CHECK(output.sizes() == target.sizes(), "parameter reducer output shape mismatch for role ", role_name); + return output; +} + +const at::Tensor& runtime_tensor_for_role( + const RegisteredParameterReducerRuntimeContext& context, + int64_t role, + const char* role_name) { + for (const 
RegisteredParameterReducerRuntimeRole& item : context.role_table.runtime_roles) { + if (item.role == role) { + return context.runtime_metadata_tensors[static_cast(item.tensor_index)]; + } + } + TORCH_CHECK(false, "parameter reducer missing compiler runtime metadata role: ", role_name); + return context.runtime_metadata_tensors[0]; +} + +void accumulate_output_for_role( + RegisteredParameterReducerRuntimeContext& context, + int64_t role, + const char* role_name, + const at::Tensor& grad) { + at::Tensor& output = output_tensor_for_role(context, role, role_name); + output.add_(grad); +} + +void accumulate_output_for_first_available_role( + RegisteredParameterReducerRuntimeContext& context, + int64_t primary_role, + const char* primary_role_name, + int64_t alternate_role, + const char* alternate_role_name, + const at::Tensor& grad) { + if (has_trainable_role(context.role_table, primary_role)) { + accumulate_output_for_role(context, primary_role, primary_role_name, grad); + return; + } + accumulate_output_for_role(context, alternate_role, alternate_role_name, grad); +} + +void add_registered_parameter_reducer_expected_count( + RegisteredParameterReducerExpectedCounts& counts, + int64_t target, + int64_t increment) { + TORCH_CHECK(target > 0, "parameter reducer expected count target must be positive"); + for (RegisteredParameterReducerExpectedCount& entry : counts.entries) { + if (entry.target == target) { + entry.count += increment; + return; + } + } + counts.entries.push_back(RegisteredParameterReducerExpectedCount{target, increment}); +} + +int64_t registered_parameter_reducer_expected_count( + const RegisteredParameterReducerExpectedCounts& counts, + int64_t target) { + for (const RegisteredParameterReducerExpectedCount& entry : counts.entries) { + if (entry.target == target) { + return entry.count; + } + } + return 0; +} + +void accumulate_registered_parameter_reducer_expected_count( + RegisteredParameterReducerExpectedCounts& counts, + const RegisteredParameterReducerStrategy& strategy, + int64_t tensor_count) { + if (strategy.expected_count_target == kParameterReducerCountNone) { + TORCH_CHECK( + strategy.expected_count_mode == kParameterReducerCountModeNone, + "parameter reducer count target none requires count mode none"); + return; + } + const int64_t increment = strategy.expected_count_mode == kParameterReducerCountModeTensorCount + ? tensor_count + : strategy.expected_count_mode == kParameterReducerCountModeRow + ? 
1 + : -1; + TORCH_CHECK(increment >= 0, "unknown parameter reducer count mode: ", strategy.expected_count_mode); + add_registered_parameter_reducer_expected_count(counts, strategy.expected_count_target, increment); +} + +std::vector decode_registered_parameter_reducer_strategy_rows( + const at::Tensor& parameter_reducer_strategy_rows) { + std::vector strategies; + strategies.reserve(static_cast(parameter_reducer_strategy_rows.size(0))); + const int64_t* rows = parameter_reducer_strategy_rows.data_ptr(); + for (int64_t row = 0; row < parameter_reducer_strategy_rows.size(0); ++row) { + const int64_t* item = rows + row * 9; + TORCH_CHECK(item[0] == row, "parameter reducer strategy row index mismatch"); + TORCH_CHECK(item[1] > 0, "parameter reducer strategy row requires reducer kind"); + TORCH_CHECK(item[2] > 0, "parameter reducer strategy row requires strategy kind"); + TORCH_CHECK(item[3] >= 0, "parameter reducer strategy row requires non-negative count target"); + TORCH_CHECK(item[4] >= 0, "parameter reducer strategy row requires non-negative count mode"); + TORCH_CHECK(item[8] > 0, "parameter reducer strategy row requires native callable hash"); + const RegisteredParameterReducerHandler& handler = + registered_parameter_reducer_handler_for_native_callable(item[8]); + for (const RegisteredParameterReducerStrategy& strategy : strategies) { + TORCH_CHECK( + strategy.reducer_kind != item[1], + "parameter reducer strategy rows contain duplicate reducer kind: ", + item[1]); + } + strategies.push_back(RegisteredParameterReducerStrategy{ + item[1], + item[2], + item[3], + item[4], + item[5], + &handler, + }); + } + return strategies; +} + +const RegisteredParameterReducerStrategy& registered_parameter_reducer_strategy_for_kind( + const std::vector& strategies, + int64_t reducer_kind) { + for (const RegisteredParameterReducerStrategy& strategy : strategies) { + if (strategy.reducer_kind == reducer_kind) { + return strategy; + } + } + TORCH_CHECK(false, "parameter reducer row references unregistered strategy kind: ", reducer_kind); + return strategies[0]; +} + +std::vector decode_registered_parameter_reducer_handlers( + const at::Tensor& parameter_reducer_rows, + const at::Tensor& parameter_reducer_strategy_rows, + RegisteredParameterReducerExpectedCounts& counts) { + const std::vector registered_strategies = + decode_registered_parameter_reducer_strategy_rows(parameter_reducer_strategy_rows); + std::vector handlers; + handlers.reserve(static_cast(registered_strategies.size())); + std::vector seen_reducer_kinds; + seen_reducer_kinds.reserve(static_cast(parameter_reducer_rows.size(0))); + const int64_t* rows = parameter_reducer_rows.data_ptr(); + for (int64_t row = 0; row < parameter_reducer_rows.size(0); ++row) { + const int64_t* item = rows + row * 8; + TORCH_CHECK(item[0] == row, "parameter reducer row index mismatch"); + TORCH_CHECK(item[6] >= 0, "parameter reducer tensor_count must be non-negative"); + const RegisteredParameterReducerStrategy& strategy = + registered_parameter_reducer_strategy_for_kind(registered_strategies, item[1]); + accumulate_registered_parameter_reducer_expected_count(counts, strategy, item[6]); + if (std::find(seen_reducer_kinds.begin(), seen_reducer_kinds.end(), strategy.reducer_kind) == + seen_reducer_kinds.end()) { + handlers.push_back(strategy); + seen_reducer_kinds.push_back(strategy.reducer_kind); + } + } + return handlers; +} + +void run_registered_sender_kv_parameter_reducer_handler( + RegisteredParameterReducerRuntimeContext& context) { + const at::Tensor& 
public_proj_weight = + trainable_tensor_for_role(context, kParameterTrainableRolePublicProjWeight, "public_proj_weight"); + const at::Tensor& k_weight = trainable_tensor_for_role(context, kParameterTrainableRoleKWeight, "k_weight"); + const at::Tensor& v_weight = trainable_tensor_for_role(context, kParameterTrainableRoleVWeight, "v_weight"); + const int64_t* grouped_flags = context.sender_grouped_flags.data_ptr(); + for (size_t index = 0; index < context.sender_grad_weight_tensors.size(); ++index) { + const at::Tensor& grad_weight = context.sender_grad_weight_tensors[index]; + const at::Tensor& group_ids = context.sender_group_id_tensors[index]; + if (grad_weight.numel() == 0) { + continue; + } + TORCH_CHECK( + grad_weight.is_cuda() && grad_weight.is_contiguous() && grad_weight.scalar_type() == at::kFloat && + grad_weight.dim() == 3, + "sender K/V grad_weight must be contiguous CUDA float32 [N,H,M]"); + TORCH_CHECK( + group_ids.is_contiguous() && group_ids.scalar_type() == at::kLong && group_ids.dim() == 1, + "sender K/V group_ids must be contiguous int64 rank-1"); + TORCH_CHECK(group_ids.size(0) == grad_weight.size(0), "sender K/V group_ids length mismatch"); + const bool value_only_grad = grad_weight.size(2) == context.value_dim; + const bool merged_kv_grad = grad_weight.size(2) == context.head_dim + context.value_dim; + TORCH_CHECK( + value_only_grad || merged_kv_grad, + "sender K/V grad merged dimension mismatch"); + TORCH_CHECK(grad_weight.size(1) == public_proj_weight.size(1), "sender K/V grad hidden width mismatch"); + const at::Tensor expanded_public = + public_proj_weight.unsqueeze(0).expand({grad_weight.size(0), -1, -1}); + at::Tensor& grad_public = + output_tensor_for_role(context, kParameterTrainableRolePublicProjWeight, "public_proj_weight"); + at::Tensor& grad_v = output_tensor_for_role(context, kParameterTrainableRoleVWeight, "v_weight"); + if (merged_kv_grad) { + const at::Tensor selected_k = k_weight.index_select(0, group_ids); + const at::Tensor selected_v = v_weight.index_select(0, group_ids); + const at::Tensor grad_k_weight = grad_weight.slice(2, 0, context.head_dim); + const at::Tensor grad_v_weight = grad_weight.slice(2, context.head_dim, context.head_dim + context.value_dim); + at::Tensor& grad_k = output_tensor_for_role(context, kParameterTrainableRoleKWeight, "k_weight"); + grad_public.add_(at::sum(at::bmm(selected_k, grad_k_weight.transpose(1, 2)), {0})); + grad_public.add_(at::sum(at::bmm(selected_v, grad_v_weight.transpose(1, 2)), {0})); + grad_k.index_add_(0, group_ids, at::bmm(expanded_public, grad_k_weight)); + grad_v.index_add_(0, group_ids, at::bmm(expanded_public, grad_v_weight)); + } else { + const at::Tensor selected_v = v_weight.index_select(0, group_ids); + grad_public.add_(at::sum(at::bmm(selected_v, grad_weight.transpose(1, 2)), {0})); + grad_v.index_add_(0, group_ids, at::bmm(expanded_public, grad_weight)); + } + TORCH_CHECK( + grouped_flags[index] == 0 || grouped_flags[index] == 1, + "sender K/V grouped flag must be 0 or 1"); + } +} + +void accumulate_registered_query_parameter_grads( + RegisteredParameterReducerRuntimeContext& context, + const at::Tensor& grad_q_full, + const at::Tensor& q_proj_weight, + const at::Tensor& slot_embed) { + at::Tensor& grad_slot_embed = + output_tensor_for_role(context, kParameterTrainableRoleSlotEmbed, "slot_embed"); + at::Tensor& grad_q_proj = + output_tensor_for_role(context, kParameterTrainableRoleQProjWeight, "q_proj_weight"); + grad_slot_embed.add_(grad_q_full.matmul(q_proj_weight)); + 
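+  // Shared query adjoint (added comment; assumes the forward projection
+  // q_full = slot_embed.matmul(q_proj_weight.t())): the gathered query gradient
+  // grad_q_full [coord_count, head_dim] contributes grad_q_full @ q_proj_weight to
+  // slot_embed (above) and grad_q_full^T @ slot_embed to q_proj_weight (below),
+  // accumulated into the compiler-allocated outputs for the slot_embed and
+  // q_proj_weight trainable roles.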
grad_q_proj.add_(grad_q_full.t().matmul(slot_embed)); +} + +void accumulate_registered_slot_linear_parameter_grads( + RegisteredParameterReducerRuntimeContext& context, + const at::Tensor& grad_full, + int64_t weight_role, + const char* weight_role_name) { + const at::Tensor& projection_weight = trainable_tensor_for_role(context, weight_role, weight_role_name); + const at::Tensor& slot_embed = + trainable_tensor_for_role(context, kParameterTrainableRoleSlotEmbed, "slot_embed"); + at::Tensor& grad_slot_embed = + output_tensor_for_role(context, kParameterTrainableRoleSlotEmbed, "slot_embed"); + at::Tensor& grad_projection = output_tensor_for_role(context, weight_role, weight_role_name); + TORCH_CHECK( + grad_full.dim() == 2 && grad_full.size(0) == slot_embed.size(0) && + grad_full.size(1) == projection_weight.size(0), + "slot-linear parameter reducer grad shape mismatch for ", + weight_role_name); + grad_slot_embed.add_(grad_full.matmul(projection_weight)); + grad_projection.add_(grad_full.t().matmul(slot_embed)); +} + +void run_registered_recurrent_query_parameter_reducer_handler( + RegisteredParameterReducerRuntimeContext& context) { + if (context.recurrent_query_grad_tensors.empty()) { + return; + } + const at::Tensor& q_proj_weight = + trainable_tensor_for_role(context, kParameterTrainableRoleQProjWeight, "q_proj_weight"); + const at::Tensor& slot_embed = + trainable_tensor_for_role(context, kParameterTrainableRoleSlotEmbed, "slot_embed"); + const at::Tensor& population_backend_recurrent_inverse_order = runtime_tensor_for_role( + context, + kParameterRuntimeRoleBackendRecurrentInverseOrder, + "population_backend_recurrent_inverse_order"); + const at::Tensor& recurrent_cell_idx = + runtime_tensor_for_role(context, kParameterRuntimeRoleRecurrentCellIdx, "recurrent_cell_idx"); + at::Tensor grad_q_full = at::zeros({context.coord_count, context.head_dim}, slot_embed.options()); + for (const at::Tensor& grad_recurrent_q_backend : context.recurrent_query_grad_tensors) { + check_cuda_float_rank2(grad_recurrent_q_backend, "grad_recurrent_q_backend"); + TORCH_CHECK( + grad_recurrent_q_backend.size(1) == context.head_dim, + "grad_recurrent_q_backend head dimension mismatch"); + at::Tensor inverse_order = + population_backend_recurrent_inverse_order.to(grad_recurrent_q_backend.device(), at::kLong); + at::Tensor recurrent_index = recurrent_cell_idx.to(grad_recurrent_q_backend.device(), at::kLong); + TORCH_CHECK( + grad_recurrent_q_backend.size(0) == inverse_order.size(0), + "grad_recurrent_q_backend recurrent count mismatch"); + at::Tensor grad_recurrent_q = grad_recurrent_q_backend.index_select(0, inverse_order); + grad_q_full.index_add_(0, recurrent_index, grad_recurrent_q); + } + accumulate_registered_query_parameter_grads(context, grad_q_full, q_proj_weight, slot_embed); +} + +void run_registered_output_query_parameter_reducer_handler( + RegisteredParameterReducerRuntimeContext& context) { + if (context.output_query_grad_tensors.empty()) { + return; + } + const at::Tensor& q_proj_weight = + trainable_tensor_for_role(context, kParameterTrainableRoleQProjWeight, "q_proj_weight"); + const at::Tensor& slot_embed = + trainable_tensor_for_role(context, kParameterTrainableRoleSlotEmbed, "slot_embed"); + const at::Tensor& output_cell_idx = + runtime_tensor_for_role(context, kParameterRuntimeRoleOutputCellIdx, "output_cell_idx"); + at::Tensor grad_q_full = at::zeros({context.coord_count, context.head_dim}, slot_embed.options()); + for (const at::Tensor& grad_output_q : context.output_query_grad_tensors) { 
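+    // Each output-query gradient tensor is validated as CUDA float32 with head_dim columns
+    // and scattered into grad_q_full at the rows named by the output_cell_idx runtime-metadata
+    // role; the shared slot_embed/q_proj_weight adjoint then runs once on the accumulated result.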
+ check_cuda_float_rank2(grad_output_q, "grad_output_q"); + TORCH_CHECK(grad_output_q.size(1) == context.head_dim, "grad_output_q head dimension mismatch"); + at::Tensor output_index = output_cell_idx.to(grad_output_q.device(), at::kLong); + TORCH_CHECK(grad_output_q.size(0) == output_index.size(0), "grad_output_q output count mismatch"); + grad_q_full.index_add_(0, output_index, grad_output_q); + } + accumulate_registered_query_parameter_grads(context, grad_q_full, q_proj_weight, slot_embed); +} + +std::vector message_strategy_grad_tensors_for_role( + const RegisteredParameterReducerRuntimeContext& context, + int64_t reducer_kind, + int64_t role) { + check_cpu_long_rank2(context.message_strategy_grad_rows, "message_strategy_grad_rows", 5); + std::vector tensors; + const int64_t* rows = context.message_strategy_grad_rows.data_ptr(); + for (int64_t row = 0; row < context.message_strategy_grad_rows.size(0); ++row) { + const int64_t* item = rows + row * 5; + TORCH_CHECK(item[0] == row, "message strategy grad row index mismatch"); + TORCH_CHECK(item[1] > 0, "message strategy grad row requires reducer kind"); + TORCH_CHECK(item[2] > 0, "message strategy grad row requires output role"); + TORCH_CHECK( + 0 <= item[3] && item[3] < static_cast(context.message_strategy_grad_tensors.size()), + "message strategy grad row tensor index is outside tensor table"); + if (item[1] == reducer_kind && item[2] == role) { + tensors.push_back(context.message_strategy_grad_tensors[static_cast(item[3])]); + } + } + return tensors; +} + +void run_registered_fixed_slot_context_message_parameter_reducer_handler( + RegisteredParameterReducerRuntimeContext& context) { + const std::vector query_slot_grad_tensors = message_strategy_grad_tensors_for_role( + context, + kParameterReducerFixedSlotContextMessage, + kMessageStrategyGradQuerySlotBackend); + const std::vector input_key_grad_tensors = message_strategy_grad_tensors_for_role( + context, + kParameterReducerFixedSlotContextMessage, + kMessageStrategyGradInputKeyBank); + const std::vector recurrent_key_grad_tensors = message_strategy_grad_tensors_for_role( + context, + kParameterReducerFixedSlotContextMessage, + kMessageStrategyGradRecurrentKeyBank); + const std::vector query_context_scalar_grad_tensors = message_strategy_grad_tensors_for_role( + context, + kParameterReducerFixedSlotContextMessage, + kMessageStrategyGradQueryContextScalar); + const std::vector output_weight_grad_tensors = message_strategy_grad_tensors_for_role( + context, + kParameterReducerFixedSlotContextMessage, + kMessageStrategyGradOutputWeight); + if (query_slot_grad_tensors.empty() && input_key_grad_tensors.empty() && recurrent_key_grad_tensors.empty() && + query_context_scalar_grad_tensors.empty() && output_weight_grad_tensors.empty()) { + return; + } + const at::Tensor& slot_embed = + trainable_tensor_for_role(context, kParameterTrainableRoleSlotEmbed, "slot_embed"); + const at::Tensor& context_key_param = + trainable_tensor_for_role(context, kParameterTrainableRoleMessageSenderContextKey, "message_sender_context_key"); + const at::Tensor& population_backend_recurrent_inverse_order = runtime_tensor_for_role( + context, + kParameterRuntimeRoleBackendRecurrentInverseOrder, + "population_backend_recurrent_inverse_order"); + const at::Tensor& recurrent_cell_idx = + runtime_tensor_for_role(context, kParameterRuntimeRoleRecurrentCellIdx, "recurrent_cell_idx"); + const at::Tensor& input_cell_idx = + runtime_tensor_for_role(context, kParameterRuntimeRoleInputCellIdx, "input_cell_idx"); + at::Tensor 
grad_query_slot_full = at::zeros({context.coord_count, context.head_dim}, slot_embed.options()); + for (const at::Tensor& grad_query_slot_backend : query_slot_grad_tensors) { + check_cuda_float_rank2(grad_query_slot_backend, "fixed-slot context grad_query_slot_backend"); + TORCH_CHECK( + grad_query_slot_backend.size(1) == context.head_dim, + "fixed-slot context grad_query_slot_backend head dimension mismatch"); + at::Tensor inverse_order = + population_backend_recurrent_inverse_order.to(grad_query_slot_backend.device(), at::kLong); + at::Tensor recurrent_index = recurrent_cell_idx.to(grad_query_slot_backend.device(), at::kLong); + TORCH_CHECK( + grad_query_slot_backend.size(0) == inverse_order.size(0), + "fixed-slot context grad_query_slot_backend recurrent count mismatch"); + at::Tensor grad_query_slot = grad_query_slot_backend.index_select(0, inverse_order); + grad_query_slot_full.index_add_(0, recurrent_index, grad_query_slot); + } + if (grad_query_slot_full.numel() > 0) { + accumulate_registered_slot_linear_parameter_grads( + context, + grad_query_slot_full, + kParameterTrainableRoleMessageQuerySlotProjWeight, + "message_query_slot_proj_weight"); + } + + at::Tensor grad_sender_slot_key_full = at::zeros({context.coord_count, context.head_dim}, slot_embed.options()); + at::Tensor grad_sender_context_key_full = at::zeros_like(context_key_param); + auto accumulate_key_bank = [&](const at::Tensor& grad_key_bank, const at::Tensor& cell_index) { + TORCH_CHECK( + grad_key_bank.is_cuda() && grad_key_bank.is_contiguous() && grad_key_bank.scalar_type() == at::kFloat && + (grad_key_bank.dim() == 2 || grad_key_bank.dim() == 3), + "fixed-slot context key bank grad must be contiguous CUDA float32 [N,2H] or [B,N,2H]"); + TORCH_CHECK( + grad_key_bank.size(grad_key_bank.dim() - 1) == 2 * context.head_dim, + "fixed-slot context key bank grad last dimension must be 2*head_dim"); + at::Tensor index = cell_index.to(grad_key_bank.device(), at::kLong).contiguous(); + at::Tensor reduced = + grad_key_bank.dim() == 3 ? 
grad_key_bank.sum(0).contiguous() : grad_key_bank.contiguous(); + TORCH_CHECK( + reduced.size(0) == index.size(0), + "fixed-slot context key bank grad sender count does not match runtime cell index"); + grad_sender_slot_key_full.index_add_(0, index, reduced.slice(1, 0, context.head_dim)); + grad_sender_context_key_full.index_add_(0, index, reduced.slice(1, context.head_dim, 2 * context.head_dim)); + }; + at::Tensor input_index = input_cell_idx.to(slot_embed.device(), at::kLong).contiguous(); + at::Tensor recurrent_index = recurrent_cell_idx.to(slot_embed.device(), at::kLong).contiguous(); + for (const at::Tensor& grad_input_key_bank : input_key_grad_tensors) { + accumulate_key_bank(grad_input_key_bank, input_index); + } + for (const at::Tensor& grad_recurrent_key_bank_backend : recurrent_key_grad_tensors) { + TORCH_CHECK( + grad_recurrent_key_bank_backend.is_cuda() && grad_recurrent_key_bank_backend.is_contiguous() && + grad_recurrent_key_bank_backend.scalar_type() == at::kFloat && + (grad_recurrent_key_bank_backend.dim() == 2 || grad_recurrent_key_bank_backend.dim() == 3), + "fixed-slot context recurrent key bank grad must be contiguous CUDA float32 [R,2H] or [B,R,2H]"); + at::Tensor inverse_order = + population_backend_recurrent_inverse_order.to(grad_recurrent_key_bank_backend.device(), at::kLong); + TORCH_CHECK( + grad_recurrent_key_bank_backend.size(grad_recurrent_key_bank_backend.dim() - 2) == inverse_order.size(0), + "fixed-slot context recurrent key bank grad recurrent count mismatch"); + at::Tensor reduced_backend = grad_recurrent_key_bank_backend.dim() == 3 + ? grad_recurrent_key_bank_backend.sum(0).contiguous() + : grad_recurrent_key_bank_backend.contiguous(); + at::Tensor reduced_graph = reduced_backend.index_select(0, inverse_order).contiguous(); + accumulate_key_bank(reduced_graph, recurrent_index); + } + if (!input_key_grad_tensors.empty() || !recurrent_key_grad_tensors.empty()) { + accumulate_registered_slot_linear_parameter_grads( + context, + grad_sender_slot_key_full, + kParameterTrainableRoleMessageSenderSlotKeyProjWeight, + "message_sender_slot_key_proj_weight"); + accumulate_output_for_role( + context, + kParameterTrainableRoleMessageSenderContextKey, + "message_sender_context_key", + grad_sender_context_key_full); + } + for (const at::Tensor& grad_query_context_scalar : query_context_scalar_grad_tensors) { + check_cuda_float_rank1(grad_query_context_scalar, "fixed-slot context grad_query_context_scalar"); + accumulate_output_for_first_available_role( + context, + kParameterTrainableRoleMessageQueryNudgeScale, + "message_query_nudge_scale", + kParameterTrainableRoleMessageQueryContextGate, + "message_query_context_gate", + grad_query_context_scalar); + } + for (const at::Tensor& grad_output_weight : output_weight_grad_tensors) { + check_cuda_float_rank2(grad_output_weight, "fixed-slot context grad_output_weight"); + accumulate_output_for_role( + context, + kParameterTrainableRoleMsgOutWeight, + "msg_out_weight", + grad_output_weight); + } +} + +void run_registered_readout_output_parameter_reducer_handler( + RegisteredParameterReducerRuntimeContext& context) { + const at::Tensor& msg_out_weight = + trainable_tensor_for_role(context, kParameterTrainableRoleMsgOutWeight, "msg_out_weight"); + const at::Tensor& output_cell_weight = + trainable_tensor_for_role(context, kParameterTrainableRoleOutputCellWeight, "output_cell_weight"); + const at::Tensor& output_cell_bias = + trainable_tensor_for_role(context, kParameterTrainableRoleOutputCellBias, "output_cell_bias"); + for 
(const at::Tensor& grad_value_to_output_weight : context.readout_grad_value_to_output_weight_tensors) { + TORCH_CHECK( + grad_value_to_output_weight.is_cuda() && grad_value_to_output_weight.is_contiguous() && + grad_value_to_output_weight.scalar_type() == at::kFloat && + grad_value_to_output_weight.dim() == 3, + "grad_value_to_output_weight must be contiguous CUDA float32 [P,V,H]"); + TORCH_CHECK( + grad_value_to_output_weight.size(0) == output_cell_weight.size(0) && + grad_value_to_output_weight.size(1) == msg_out_weight.size(1) && + grad_value_to_output_weight.size(2) == output_cell_weight.size(2), + "grad_value_to_output_weight shape mismatch"); + const at::Tensor grad_msg_out_step = + at::sum(at::bmm(output_cell_weight, grad_value_to_output_weight.transpose(1, 2)), {0}); + const at::Tensor grad_output_cell_weight_step = + at::bmm(msg_out_weight.unsqueeze(0).expand({grad_value_to_output_weight.size(0), -1, -1}), + grad_value_to_output_weight); + accumulate_output_for_role(context, kParameterTrainableRoleMsgOutWeight, "msg_out_weight", grad_msg_out_step); + accumulate_output_for_role( + context, + kParameterTrainableRoleOutputCellWeight, + "output_cell_weight", + grad_output_cell_weight_step); + } + for (const at::Tensor& grad_bias : context.readout_grad_output_cell_bias_tensors) { + check_cuda_float_rank2(grad_bias, "grad_output_cell_bias"); + TORCH_CHECK( + grad_bias.size(0) == output_cell_bias.size(0) && grad_bias.size(1) == output_cell_bias.size(1), + "grad_output_cell_bias shape mismatch"); + accumulate_output_for_role(context, kParameterTrainableRoleOutputCellBias, "output_cell_bias", grad_bias); + } +} + +using RegisteredTransitionTrainableReducerRunFn = at::Tensor (*)( + const at::Tensor& source_grad, + const at::Tensor& target_param, + int64_t aux_index, + int64_t source_row, + const std::vector& trainable_param_tensors, + const std::vector& transition_source_recurrent_cell_idx_tensors, + int64_t coord_count); + +struct RegisteredTransitionTrainableReducerHandler { + int64_t native_callable_hash; + const char* name; + RegisteredTransitionTrainableReducerRunFn run; +}; + +at::Tensor align_or_scatter_registered_transition_grad_to_target( + const at::Tensor& source_grad, + const at::Tensor& target_param, + const at::Tensor& source_row_indices, + bool reduce_leading_dims) { + if (source_row_indices.defined() && source_row_indices.numel() > 0) { + at::Tensor row_indices = source_row_indices.to(source_grad.device(), at::kLong).contiguous(); + at::Tensor reduced = source_grad; + if (reduce_leading_dims && target_param.dim() > 0 && target_param.size(0) != 1 && + reduced.dim() >= target_param.dim() && reduced.dim() > 0 && reduced.size(0) == row_indices.size(0)) { + while (reduced.dim() > target_param.dim()) { + reduced = at::sum(reduced, {1}); + } + } + if (reduced.dim() == target_param.dim() && reduced.dim() > 0 && reduced.size(0) == row_indices.size(0)) { + bool trailing_shape_matches = true; + for (int64_t dim = 1; dim < reduced.dim(); ++dim) { + trailing_shape_matches = trailing_shape_matches && reduced.size(dim) == target_param.size(dim); + } + const bool row_indices_in_target = + row_indices.numel() > 0 && + at::min(row_indices).item() >= 0 && + at::max(row_indices).item() < target_param.size(0); + if (trailing_shape_matches && row_indices_in_target) { + at::Tensor scattered = at::zeros_like(target_param); + scattered.index_add_(0, row_indices, reduced.contiguous()); + return scattered; + } + } + } + return align_registered_transition_grad_to_target(source_grad, target_param, 
reduce_leading_dims); +} + +at::Tensor run_registered_transition_materialized_base_reducer( + const at::Tensor& source_grad, + const at::Tensor& target_param, + int64_t aux_index, + int64_t source_row, + const std::vector& trainable_param_tensors, + const std::vector& transition_source_recurrent_cell_idx_tensors, + int64_t coord_count) { + (void)aux_index; + (void)source_row; + (void)trainable_param_tensors; + (void)transition_source_recurrent_cell_idx_tensors; + (void)coord_count; + return align_registered_transition_grad_to_target(source_grad, target_param, true); +} + +at::Tensor run_registered_transition_materialized_delta_reducer( + const at::Tensor& source_grad, + const at::Tensor& target_param, + int64_t aux_index, + int64_t source_row, + const std::vector& trainable_param_tensors, + const std::vector& transition_source_recurrent_cell_idx_tensors, + int64_t coord_count) { + (void)aux_index; + (void)trainable_param_tensors; + (void)coord_count; + const at::Tensor& source_row_indices = + transition_source_recurrent_cell_idx_tensors[static_cast(source_row)]; + return align_or_scatter_registered_transition_grad_to_target(source_grad, target_param, source_row_indices, false); +} + +at::Tensor run_registered_transition_value_to_cell_msg_to_cell_reducer( + const at::Tensor& source_grad, + const at::Tensor& target_param, + int64_t aux_index, + int64_t source_row, + const std::vector& trainable_param_tensors, + const std::vector& transition_source_recurrent_cell_idx_tensors, + int64_t coord_count) { + (void)source_row; + (void)transition_source_recurrent_cell_idx_tensors; + (void)coord_count; + TORCH_CHECK(aux_index >= 0, "value_to_cell msg_to_cell reducer requires aux msg_out.weight"); + const at::Tensor& msg_out = trainable_param_tensors[static_cast(aux_index)]; + check_cuda_float_rank2(source_grad, "value_to_cell source grad"); + check_cuda_float_rank2(msg_out, "value_to_cell aux msg_out.weight"); + const at::Tensor& source_row_indices = + transition_source_recurrent_cell_idx_tensors[static_cast(source_row)]; + return align_or_scatter_registered_transition_grad_to_target( + source_grad.matmul(msg_out.t()), + target_param, + source_row_indices, + true); +} + +at::Tensor run_registered_transition_value_to_cell_msg_out_reducer( + const at::Tensor& source_grad, + const at::Tensor& target_param, + int64_t aux_index, + int64_t source_row, + const std::vector& trainable_param_tensors, + const std::vector& transition_source_recurrent_cell_idx_tensors, + int64_t coord_count) { + (void)source_row; + (void)transition_source_recurrent_cell_idx_tensors; + (void)coord_count; + TORCH_CHECK(aux_index >= 0, "value_to_cell msg_out reducer requires aux msg_to_cell.weight"); + const at::Tensor& msg_to_cell = trainable_param_tensors[static_cast(aux_index)]; + check_cuda_float_rank2(source_grad, "value_to_cell source grad"); + check_cuda_float_rank2(msg_to_cell, "value_to_cell aux msg_to_cell.weight"); + const at::Tensor& source_row_indices = + transition_source_recurrent_cell_idx_tensors[static_cast(source_row)]; + return align_or_scatter_registered_transition_grad_to_target( + msg_to_cell.t().matmul(source_grad), + target_param, + source_row_indices, + true); +} + +at::Tensor run_registered_transition_recurrent_bias_slot_embed_reducer( + const at::Tensor& source_grad, + const at::Tensor& target_param, + int64_t aux_index, + int64_t source_row, + const std::vector& trainable_param_tensors, + const std::vector& transition_source_recurrent_cell_idx_tensors, + int64_t coord_count) { + TORCH_CHECK(aux_index >= 0, 
"recurrent bias slot_embed reducer requires aux cell_bias_proj.weight"); + const at::Tensor& cell_bias_proj = trainable_param_tensors[static_cast(aux_index)]; + check_cuda_float_rank2(cell_bias_proj, "recurrent bias aux cell_bias_proj.weight"); + at::Tensor recurrent_index = + transition_source_recurrent_cell_idx_tensors[static_cast(source_row)] + .to(source_grad.device(), at::kLong) + .contiguous(); + at::Tensor full_bias_grad = transition_recurrent_bias_full_grad(source_grad, recurrent_index, coord_count); + return align_registered_transition_grad_to_target(full_bias_grad.matmul(cell_bias_proj), target_param, true); +} + +at::Tensor run_registered_transition_recurrent_bias_cell_bias_proj_reducer( + const at::Tensor& source_grad, + const at::Tensor& target_param, + int64_t aux_index, + int64_t source_row, + const std::vector& trainable_param_tensors, + const std::vector& transition_source_recurrent_cell_idx_tensors, + int64_t coord_count) { + TORCH_CHECK(aux_index >= 0, "recurrent bias cell_bias_proj reducer requires aux slot_embed"); + const at::Tensor& slot_embed_param = trainable_param_tensors[static_cast(aux_index)]; + check_cuda_float_rank2(slot_embed_param, "recurrent bias aux slot_embed"); + at::Tensor recurrent_index = + transition_source_recurrent_cell_idx_tensors[static_cast(source_row)] + .to(source_grad.device(), at::kLong) + .contiguous(); + at::Tensor full_bias_grad = transition_recurrent_bias_full_grad(source_grad, recurrent_index, coord_count); + return align_registered_transition_grad_to_target(full_bias_grad.t().matmul(slot_embed_param), target_param, true); +} + +#define REGISTERED_TEMPORAL_NATIVE_TRANSITION_TRAINABLE_REDUCER_CATALOG +#include "../flat_bucket_registered_native_callables.cuh" + +const RegisteredTransitionTrainableReducerHandler& registered_transition_trainable_reducer_handler_for_native_callable( + int64_t native_callable_hash) { + for (const RegisteredTransitionTrainableReducerHandler* handler = + registered_native_transition_trainable_reducer_catalog_begin(); + handler != registered_native_transition_trainable_reducer_catalog_end(); + ++handler) { + if (handler->native_callable_hash == native_callable_hash) { + return *handler; + } + } + TORCH_CHECK(false, "unknown transition trainable reducer native callable hash: ", native_callable_hash); + return *registered_native_transition_trainable_reducer_catalog_begin(); +} + +at::Tensor run_registered_transition_trainable_reducer_handler( + const RegisteredTransitionTrainableReducerHandler& handler, + const at::Tensor& source_grad, + const at::Tensor& target_param, + int64_t aux_index, + int64_t source_row, + const std::vector& trainable_param_tensors, + const std::vector& transition_source_recurrent_cell_idx_tensors, + int64_t coord_count) { + try { + return handler.run( + source_grad, + target_param, + aux_index, + source_row, + trainable_param_tensors, + transition_source_recurrent_cell_idx_tensors, + coord_count); + } catch (const c10::Error& error) { + TORCH_CHECK( + false, + "transition trainable reducer failed: handler=", + handler.name, + "; source_row=", + source_row, + "; source_shape=", + source_grad.sizes(), + "; target_shape=", + target_param.sizes(), + "; reason=", + error.what_without_backtrace()); + } +} + +void validate_registered_transition_trainable_reducer_rows( + const at::Tensor& transition_trainable_rows, + const std::vector& trainable_param_tensors) { + const int64_t* transition_trainable_row_ptr = transition_trainable_rows.data_ptr(); + for (int64_t row = 0; row < 
transition_trainable_rows.size(0); ++row) { + const int64_t* item = transition_trainable_row_ptr + row * 9; + TORCH_CHECK(item[0] == row, "transition trainable row index mismatch"); + TORCH_CHECK( + 0 <= item[4] && item[4] < static_cast(trainable_param_tensors.size()), + "transition trainable row target index is outside trainable tensor table"); + TORCH_CHECK( + item[5] == -1 || + (0 <= item[5] && item[5] < static_cast(trainable_param_tensors.size())), + "transition trainable row aux index is outside trainable tensor table"); + TORCH_CHECK(item[8] > 0, "transition trainable row requires native callable hash"); + (void)registered_transition_trainable_reducer_handler_for_native_callable(item[8]); + } +} + +void validate_registered_parameter_output_tensors( + const std::vector& trainable_param_tensors, + const std::vector& param_outputs) { + TORCH_CHECK( + param_outputs.size() == trainable_param_tensors.size(), + "parameter reducer output tensor table must align with trainable tensor table"); + for (size_t index = 0; index < trainable_param_tensors.size(); ++index) { + const at::Tensor& trainable_param = trainable_param_tensors[index]; + TORCH_CHECK( + trainable_param.is_cuda() && trainable_param.is_contiguous() && + trainable_param.scalar_type() == at::kFloat, + "transition trainable parameter table tensors must be contiguous CUDA float32"); + const at::Tensor& output = param_outputs[index]; + TORCH_CHECK( + output.defined() && output.is_cuda() && output.is_contiguous() && output.scalar_type() == at::kFloat, + "parameter reducer output tensor table entries must be defined contiguous CUDA float32 tensors"); + TORCH_CHECK( + output.numel() == 0 || output.sizes() == trainable_param.sizes(), + "parameter reducer output tensor must be zero-size sentinel or match trainable shape at index ", + static_cast(index)); + } +} + +void run_registered_transition_parameter_reducer_handler( + std::vector& param_outputs, + const at::Tensor& transition_source_rows, + const at::Tensor& transition_trainable_rows, + const std::vector& transition_source_tensors, + const std::vector& transition_source_recurrent_cell_idx_tensors, + const std::vector& trainable_param_tensors, + int64_t coord_count) { + const int64_t* transition_source_row_ptr = transition_source_rows.data_ptr(); + std::vector transition_reduced_sources; + transition_reduced_sources.reserve(static_cast(transition_source_rows.size(0))); + for (int64_t row = 0; row < transition_source_rows.size(0); ++row) { + const int64_t* item = transition_source_row_ptr + row * 8; + transition_reduced_sources.push_back( + reduce_registered_transition_source_tensors(transition_source_tensors, item[4], item[5])); + } + const int64_t* transition_trainable_row_ptr = transition_trainable_rows.data_ptr(); + for (int64_t row = 0; row < transition_trainable_rows.size(0); ++row) { + const int64_t* item = transition_trainable_row_ptr + row * 9; + const int64_t request_index = item[1]; + const int64_t source_name_index = item[2]; + const int64_t target_index = item[4]; + const int64_t aux_index = item[5]; + const int64_t native_callable_hash = item[8]; + const int64_t source_row = + find_registered_transition_source_row(transition_source_rows, request_index, source_name_index); + TORCH_CHECK(source_row >= 0, "transition trainable row references a missing source row"); + const RegisteredTransitionTrainableReducerHandler& handler = + registered_transition_trainable_reducer_handler_for_native_callable(native_callable_hash); + const at::Tensor& source_grad = 
transition_reduced_sources[static_cast(source_row)]; + const at::Tensor& target_param = trainable_param_tensors[static_cast(target_index)]; + at::Tensor grad_param = run_registered_transition_trainable_reducer_handler( + handler, + source_grad, + target_param, + aux_index, + source_row, + trainable_param_tensors, + transition_source_recurrent_cell_idx_tensors, + coord_count); + at::Tensor& output = param_outputs[static_cast(target_index)]; + TORCH_CHECK( + output.defined() && output.numel() > 0 && output.sizes() == target_param.sizes(), + "transition parameter reducer target output was not allocated by the compiler output tensor table"); + output.add_(grad_param); + } +} + +void run_registered_sender_kv_parameter_reducer_strategy(RegisteredParameterReducerRuntimeContext& context) { + run_registered_sender_kv_parameter_reducer_handler(context); +} + +void run_registered_recurrent_query_parameter_reducer_strategy(RegisteredParameterReducerRuntimeContext& context) { + run_registered_recurrent_query_parameter_reducer_handler(context); +} + +void run_registered_output_query_parameter_reducer_strategy(RegisteredParameterReducerRuntimeContext& context) { + run_registered_output_query_parameter_reducer_handler(context); +} + +void run_registered_readout_output_parameter_reducer_strategy(RegisteredParameterReducerRuntimeContext& context) { + run_registered_readout_output_parameter_reducer_handler(context); +} + +void run_registered_transition_parameter_reducer_strategy(RegisteredParameterReducerRuntimeContext& context) { + run_registered_transition_parameter_reducer_handler( + context.param_outputs, + context.transition_source_rows, + context.transition_trainable_rows, + context.transition_source_tensors, + context.transition_source_recurrent_cell_idx_tensors, + context.trainable_param_tensors, + context.coord_count); +} + +void run_registered_fixed_slot_context_message_parameter_reducer_strategy( + RegisteredParameterReducerRuntimeContext& context) { + run_registered_fixed_slot_context_message_parameter_reducer_handler(context); +} + +void run_registered_parameter_reducer_handler( + const RegisteredParameterReducerStrategy& strategy, + std::vector& param_outputs, + const RegisteredParameterReducerRoleTable& role_table, + const at::Tensor& transition_source_rows, + const at::Tensor& transition_trainable_rows, + const std::vector& sender_grad_weight_tensors, + const std::vector& sender_group_id_tensors, + const at::Tensor& sender_grouped_flags, + const std::vector& readout_grad_value_to_output_weight_tensors, + const std::vector& readout_grad_output_cell_bias_tensors, + const std::vector& recurrent_query_grad_tensors, + const std::vector& output_query_grad_tensors, + const std::vector& message_strategy_grad_tensors, + const at::Tensor& message_strategy_grad_rows, + const std::vector& transition_source_tensors, + const std::vector& transition_source_recurrent_cell_idx_tensors, + const std::vector& trainable_param_tensors, + const std::vector& runtime_metadata_tensors, + int64_t coord_count, + int64_t head_dim, + int64_t value_dim) { + RegisteredParameterReducerRuntimeContext context{ + param_outputs, + role_table, + transition_source_rows, + transition_trainable_rows, + sender_grad_weight_tensors, + sender_group_id_tensors, + sender_grouped_flags, + readout_grad_value_to_output_weight_tensors, + readout_grad_output_cell_bias_tensors, + recurrent_query_grad_tensors, + output_query_grad_tensors, + message_strategy_grad_tensors, + message_strategy_grad_rows, + transition_source_tensors, + 
transition_source_recurrent_cell_idx_tensors, + trainable_param_tensors, + runtime_metadata_tensors, + coord_count, + head_dim, + value_dim, + }; + TORCH_CHECK( + strategy.handler != nullptr && strategy.handler->run != nullptr, + "registered parameter reducer strategy has no run function: ", + strategy.strategy_kind); + strategy.handler->run(context); +} + +std::vector flat_bucket_registered_temporal_parameter_reducer_program_cuda( + const at::Tensor& parameter_reducer_rows, + const at::Tensor& parameter_reducer_strategy_rows, + const at::Tensor& parameter_reducer_trainable_role_rows, + const at::Tensor& parameter_reducer_runtime_metadata_rows, + const at::Tensor& transition_source_rows, + const at::Tensor& transition_trainable_rows, + std::vector sender_grad_weight_tensors, + std::vector sender_group_id_tensors, + const at::Tensor& sender_grouped_flags, + std::vector readout_grad_value_to_output_weight_tensors, + std::vector readout_grad_output_cell_bias_tensors, + std::vector recurrent_query_grad_tensors, + std::vector output_query_grad_tensors, + std::vector message_strategy_grad_tensors, + const at::Tensor& message_strategy_grad_rows, + std::vector transition_source_tensors, + std::vector transition_source_recurrent_cell_idx_tensors, + std::vector parameter_output_tensors, + std::vector trainable_param_tensors, + std::vector runtime_metadata_tensors, + int64_t coord_count, + int64_t head_dim, + int64_t value_dim, + int64_t schema_version) { + TORCH_CHECK(schema_version == 1, "registered parameter reducer program schema version mismatch"); + check_cpu_long_rank2(parameter_reducer_rows, "parameter_reducer_rows", 8); + check_cpu_long_rank2(parameter_reducer_strategy_rows, "parameter_reducer_strategy_rows", 9); + check_cpu_long_rank2(parameter_reducer_trainable_role_rows, "parameter_reducer_trainable_role_rows", 6); + check_cpu_long_rank2(parameter_reducer_runtime_metadata_rows, "parameter_reducer_runtime_metadata_rows", 4); + check_cpu_long_rank2(transition_source_rows, "transition_source_rows", 8); + check_cpu_long_rank2(transition_trainable_rows, "transition_trainable_rows", 9); + check_cpu_long_rank2(message_strategy_grad_rows, "message_strategy_grad_rows", 5); + TORCH_CHECK( + sender_grouped_flags.device().is_cpu() && sender_grouped_flags.scalar_type() == at::kLong && + sender_grouped_flags.dim() == 1, + "sender_grouped_flags must be CPU int64 rank-1"); + + RegisteredParameterReducerExpectedCounts expected_counts; + const std::vector reducer_handlers = + decode_registered_parameter_reducer_handlers( + parameter_reducer_rows, + parameter_reducer_strategy_rows, + expected_counts); + const RegisteredParameterReducerRoleTable role_table = decode_registered_parameter_reducer_role_table( + parameter_reducer_trainable_role_rows, + parameter_reducer_runtime_metadata_rows, + trainable_param_tensors, + runtime_metadata_tensors); + validate_registered_parameter_output_tensors(trainable_param_tensors, parameter_output_tensors); + std::vector param_outputs = parameter_output_tensors; + TORCH_CHECK( + registered_parameter_reducer_expected_count(expected_counts, kParameterReducerCountSender) == + static_cast(sender_grad_weight_tensors.size()), + "sender K/V reducer row tensor_count does not match provided gradient tensors"); + TORCH_CHECK( + sender_grad_weight_tensors.size() == sender_group_id_tensors.size() && + sender_grad_weight_tensors.size() == static_cast(sender_grouped_flags.size(0)), + "sender K/V reducer tensor tables must have aligned weights, group ids, and grouped flags"); + TORCH_CHECK( 
+ registered_parameter_reducer_expected_count(expected_counts, kParameterReducerCountReadout) == + static_cast( + readout_grad_value_to_output_weight_tensors.size() + + readout_grad_output_cell_bias_tensors.size()), + "readout reducer row tensor_count does not match provided gradient tensors"); + TORCH_CHECK( + registered_parameter_reducer_expected_count(expected_counts, kParameterReducerCountRecurrentQuery) == + static_cast(recurrent_query_grad_tensors.size()), + "recurrent query reducer rows do not match provided gradient tensors"); + TORCH_CHECK( + registered_parameter_reducer_expected_count(expected_counts, kParameterReducerCountOutputQuery) == + static_cast(output_query_grad_tensors.size()), + "output query reducer rows do not match provided gradient tensors"); + TORCH_CHECK( + registered_parameter_reducer_expected_count(expected_counts, kParameterReducerCountMessageStrategy) == + static_cast(message_strategy_grad_tensors.size()), + "message strategy reducer row tensor_count does not match provided gradient tensors"); + TORCH_CHECK( + transition_source_recurrent_cell_idx_tensors.size() == + static_cast(transition_source_rows.size(0)), + "transition source rows must have one recurrent index tensor per source row"); + int64_t transition_source_tensor_total = 0; + const int64_t* transition_source_row_ptr = transition_source_rows.data_ptr(); + for (int64_t row = 0; row < transition_source_rows.size(0); ++row) { + const int64_t* item = transition_source_row_ptr + row * 8; + TORCH_CHECK(item[0] == row, "transition source row index mismatch"); + TORCH_CHECK( + item[3] == kTransitionSourceMaterialized || item[3] == kTransitionSourceStaticSource, + "unknown transition source row kind: ", + item[3]); + TORCH_CHECK(item[4] >= 0 && item[5] > 0, "transition source rows require non-empty tensor spans"); + TORCH_CHECK( + item[4] == transition_source_tensor_total, + "transition source tensor spans must be contiguous and compiler ordered"); + transition_source_tensor_total += item[5]; + } + TORCH_CHECK( + transition_source_tensor_total == static_cast(transition_source_tensors.size()), + "transition source row tensor spans do not match provided source tensors"); + validate_registered_transition_trainable_reducer_rows(transition_trainable_rows, trainable_param_tensors); + + for (const RegisteredParameterReducerStrategy& handler : reducer_handlers) { + run_registered_parameter_reducer_handler( + handler, + param_outputs, + role_table, + transition_source_rows, + transition_trainable_rows, + sender_grad_weight_tensors, + sender_group_id_tensors, + sender_grouped_flags, + readout_grad_value_to_output_weight_tensors, + readout_grad_output_cell_bias_tensors, + recurrent_query_grad_tensors, + output_query_grad_tensors, + message_strategy_grad_tensors, + message_strategy_grad_rows, + transition_source_tensors, + transition_source_recurrent_cell_idx_tensors, + trainable_param_tensors, + runtime_metadata_tensors, + coord_count, + head_dim, + value_dim); + } + + std::vector outputs; + outputs.reserve(param_outputs.size()); + for (const at::Tensor& grad : param_outputs) { + outputs.push_back(grad); + } + return outputs; +} diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/program_spans_and_handlers.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/program_spans_and_handlers.cuh new file mode 100644 index 00000000..29e88bee --- /dev/null +++ 
b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/program_spans_and_handlers.cuh @@ -0,0 +1,688 @@ +#pragma once + +struct RegisteredFusedProgramSpan { + int64_t direction_opcode; + int64_t executor_row_index; + int64_t executor_id; + int64_t surface_opcode; + int64_t bucket_ordinal; + int64_t primitive_row_start; + int64_t primitive_row_count; + int64_t receiver_start; + int64_t receiver_count; + int64_t binding_start; + int64_t binding_count; + int64_t memory_count; + int64_t handler_kind; + int64_t handler_primitive_opcode; + int64_t handler_primitive_row_count; + int64_t handler_flags; + int64_t required_effect_mask; + int64_t strategy_id_hash; + int64_t program_access_count; + int64_t state_carry_rule_count; + int64_t verified_rewrite_required; +}; + +struct RegisteredNativeStrategyRow { + int64_t direction_opcode; + int64_t surface_opcode; + int64_t executor_id; + int64_t handler_kind; + int64_t primitive_opcode; + int64_t primitive_row_count; + int64_t capability_flags; + int64_t effect_mask; + int64_t row_schema_version; + int64_t tensor_binding_schema_version; + int64_t metadata_schema_version; + int64_t cuda_kernel_abi_version; + int64_t strategy_id_hash; + int64_t program_access_count; + int64_t state_carry_rule_count; + int64_t verified_rewrite_required; + int64_t native_callable_hash; +}; + +struct RegisteredForwardExecutorHandler { + int64_t handler_kind; + int64_t executor_id; + int64_t surface_opcode; + const char* name; + bool provides_temporal_message_carrier; + bool provides_temporal_readout; + bool runs_transition_program; +}; + +struct RegisteredReverseExecutorHandler { + int64_t handler_kind; + int64_t executor_id; + int64_t surface_opcode; + const char* name; + int64_t primitive_opcode; + int64_t primitive_row_count; + bool consumes_temporal_message_carrier; + bool consumes_temporal_readout; + bool runs_transition_adjoint; +}; + +inline RegisteredFusedProgramSpan registered_fused_program_span_at( + const at::Tensor& spans, + int64_t span_index) { + const int64_t* row = spans.data_ptr() + span_index * kFusedProgramSpanColumns; + return RegisteredFusedProgramSpan{ + row[0], + row[1], + row[2], + row[3], + row[4], + row[5], + row[6], + row[7], + row[8], + row[9], + row[10], + row[11], + row[12], + row[13], + row[14], + row[15], + row[16], + row[17], + row[18], + row[19], + row[20], + }; +} + +inline RegisteredNativeStrategyRow registered_native_strategy_row_at( + const at::Tensor& native_strategy_rows, + int64_t row_index) { + const int64_t* row = native_strategy_rows.data_ptr() + row_index * kNativeStrategyRowColumns; + return RegisteredNativeStrategyRow{ + row[0], + row[1], + row[2], + row[3], + row[4], + row[5], + row[6], + row[7], + row[8], + row[9], + row[10], + row[11], + row[12], + row[13], + row[14], + row[15], + row[16], + }; +} + +inline bool registered_native_strategy_row_matches_span( + const RegisteredNativeStrategyRow& native_strategy, + int64_t direction_opcode, + const RegisteredFusedProgramSpan& span) { + return native_strategy.direction_opcode == direction_opcode && + native_strategy.surface_opcode == span.surface_opcode && + native_strategy.executor_id == span.executor_id && + native_strategy.handler_kind == span.handler_kind && + native_strategy.primitive_opcode == span.handler_primitive_opcode && + native_strategy.primitive_row_count == span.handler_primitive_row_count && + native_strategy.capability_flags == span.handler_flags && + native_strategy.effect_mask == span.required_effect_mask && + 
native_strategy.strategy_id_hash == span.strategy_id_hash && + native_strategy.program_access_count == span.program_access_count && + native_strategy.state_carry_rule_count == span.state_carry_rule_count && + native_strategy.verified_rewrite_required == span.verified_rewrite_required; +} + +inline RegisteredNativeStrategyRow registered_native_strategy_row_for_span( + const at::Tensor& native_strategy_rows, + int64_t direction_opcode, + const RegisteredFusedProgramSpan& span, + const char* name) { + check_cpu_long_rank2(native_strategy_rows, "native_strategy_rows", kNativeStrategyRowColumns); + RegisteredNativeStrategyRow selected{}; + bool found = false; + for (int64_t row_index = 0; row_index < native_strategy_rows.size(0); ++row_index) { + const RegisteredNativeStrategyRow strategy = registered_native_strategy_row_at(native_strategy_rows, row_index); + if (!registered_native_strategy_row_matches_span(strategy, direction_opcode, span)) { + continue; + } + TORCH_CHECK( + !found, + name, + " has duplicate compiler native strategy rows for executor_id=", + span.executor_id, + ",surface_opcode=", + span.surface_opcode, + ",strategy_hash=", + span.strategy_id_hash); + selected = strategy; + found = true; + } + TORCH_CHECK( + found, + name, + " has no compiler native strategy row for executor_id=", + span.executor_id, + ",surface_opcode=", + span.surface_opcode, + ",handler_kind=", + span.handler_kind, + ",primitive_opcode=", + span.handler_primitive_opcode, + ",strategy_hash=", + span.strategy_id_hash); + return selected; +} + +struct RegisteredFusedProgramMemoryFacts { + int64_t row_count = 0; + bool message_workspace = false; + bool output_workspace = false; + bool transition_workspace = false; + bool grad_read = false; + bool message_emit = false; + bool message_read = false; + bool output_emit = false; + bool parameter_grad_emit = false; + bool parameter_read = false; + bool state_read = false; + bool state_write = false; + bool tape_policy = false; +}; + +inline bool registered_memory_row_belongs_to_span( + const int64_t* row, + const RegisteredFusedProgramSpan& span) { + const int64_t primitive_row_index = row[1]; + const int64_t bucket_ordinal = row[2]; + if (primitive_row_index >= span.primitive_row_start && + primitive_row_index < span.primitive_row_start + span.primitive_row_count) { + return true; + } + return primitive_row_index == -1 && bucket_ordinal == span.bucket_ordinal; +} + +inline RegisteredFusedProgramMemoryFacts registered_fused_program_memory_facts_for_span( + const at::Tensor& memory_liveness_rows, + const RegisteredFusedProgramSpan& span) { + check_cpu_long_rank2(memory_liveness_rows, "memory_liveness_rows", 10); + RegisteredFusedProgramMemoryFacts facts; + const int64_t* rows = memory_liveness_rows.data_ptr(); + for (int64_t row_index = 0; row_index < memory_liveness_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 10; + if (!registered_memory_row_belongs_to_span(row, span)) { + continue; + } + ++facts.row_count; + const int64_t workspace_opcode = row[6]; + const int64_t effect_opcode = row[7]; + facts.message_workspace = facts.message_workspace || workspace_opcode == kMemoryWorkspaceMessage; + facts.output_workspace = facts.output_workspace || workspace_opcode == kMemoryWorkspaceOutput; + facts.transition_workspace = facts.transition_workspace || workspace_opcode == kMemoryWorkspaceTransition; + facts.grad_read = facts.grad_read || effect_opcode == kMemoryEffectGradRead; + facts.message_emit = facts.message_emit || effect_opcode == 
kMemoryEffectMessageEmit; + facts.message_read = facts.message_read || effect_opcode == kMemoryEffectMessageRead; + facts.output_emit = facts.output_emit || effect_opcode == kMemoryEffectOutputEmit; + facts.parameter_grad_emit = facts.parameter_grad_emit || effect_opcode == kMemoryEffectParameterGradEmit; + facts.parameter_read = facts.parameter_read || effect_opcode == kMemoryEffectParameterRead; + facts.state_read = facts.state_read || effect_opcode == kMemoryEffectStateRead; + facts.state_write = facts.state_write || effect_opcode == kMemoryEffectStateWrite; + facts.tape_policy = facts.tape_policy || effect_opcode == kMemoryEffectTapePolicy; + } + TORCH_CHECK( + facts.row_count == span.memory_count, + "registered fused program span memory count mismatch for executor row ", + span.executor_row_index); + return facts; +} + +inline void require_registered_memory_fact(bool present, const char* subject, const char* requirement) { + TORCH_CHECK(present, subject, " is missing compiler memory-liveness requirement ", requirement); +} + +inline void require_registered_handler_effect_contract( + const RegisteredFusedProgramMemoryFacts& facts, + int64_t required_effect_mask, + const char* subject) { + if ((required_effect_mask & kHandlerEffectStateRead) != 0) { + require_registered_memory_fact(facts.state_read, subject, "handler_effect:state_read"); + } + if ((required_effect_mask & kHandlerEffectParameterRead) != 0) { + require_registered_memory_fact(facts.parameter_read, subject, "handler_effect:parameter_read"); + } + if ((required_effect_mask & kHandlerEffectMessageEmit) != 0) { + require_registered_memory_fact(facts.message_emit, subject, "handler_effect:message_emit"); + } + if ((required_effect_mask & kHandlerEffectMessageRead) != 0) { + require_registered_memory_fact(facts.message_read, subject, "handler_effect:message_read"); + } + if ((required_effect_mask & kHandlerEffectOutputEmit) != 0) { + require_registered_memory_fact(facts.output_emit, subject, "handler_effect:output_emit"); + } + if ((required_effect_mask & kHandlerEffectStateWrite) != 0) { + require_registered_memory_fact(facts.state_write, subject, "handler_effect:state_write"); + } + if ((required_effect_mask & kHandlerEffectTapePolicy) != 0) { + require_registered_memory_fact(facts.tape_policy, subject, "handler_effect:tape_policy"); + } + if ((required_effect_mask & kHandlerEffectGradRead) != 0) { + require_registered_memory_fact(facts.grad_read, subject, "handler_effect:grad_read"); + } + if ((required_effect_mask & kHandlerEffectParameterGradEmit) != 0) { + require_registered_memory_fact(facts.parameter_grad_emit, subject, "handler_effect:parameter_grad_emit"); + } +} + +inline void require_registered_surface_memory_contract( + const at::Tensor& memory_liveness_rows, + const RegisteredFusedProgramSpan& span, + const char* subject) { + const RegisteredFusedProgramMemoryFacts facts = + registered_fused_program_memory_facts_for_span(memory_liveness_rows, span); + require_registered_handler_effect_contract(facts, span.required_effect_mask, subject); + if (span.surface_opcode == kMessageSurfaceOpcode) { + require_registered_memory_fact(facts.state_read, subject, "state_read"); + require_registered_memory_fact(facts.parameter_read, subject, "parameter_read"); + require_registered_memory_fact(facts.message_emit, subject, "message_emit"); + require_registered_memory_fact(facts.message_workspace, subject, "message_workspace"); + return; + } + if (span.surface_opcode == kReadoutSurfaceOpcode) { + 
require_registered_memory_fact(facts.state_read, subject, "state_read"); + require_registered_memory_fact(facts.parameter_read, subject, "parameter_read"); + require_registered_memory_fact(facts.output_emit, subject, "output_emit"); + require_registered_memory_fact(facts.output_workspace, subject, "output_workspace"); + return; + } + if (span.surface_opcode == kTransitionSurfaceOpcode) { + require_registered_memory_fact(facts.state_read, subject, "state_read"); + require_registered_memory_fact(facts.message_read, subject, "message_read"); + require_registered_memory_fact(facts.state_write, subject, "state_write"); + require_registered_memory_fact(facts.tape_policy, subject, "tape_policy"); + require_registered_memory_fact(facts.transition_workspace, subject, "transition_workspace"); + return; + } + TORCH_CHECK(false, subject, " has unsupported compiler memory surface opcode ", span.surface_opcode); +} + +inline int64_t registered_forward_required_capability_for_surface(int64_t surface_opcode) { + if (surface_opcode == kMessageSurfaceOpcode) { + return kForwardHandlerMessageCarrierFlag; + } + if (surface_opcode == kReadoutSurfaceOpcode) { + return kForwardHandlerReadoutFlag; + } + if (surface_opcode == kTransitionSurfaceOpcode) { + return kForwardHandlerTransitionFlag; + } + TORCH_CHECK(false, "registered fused forward program has unsupported handler surface opcode ", surface_opcode); + return 0; +} + +inline int64_t registered_reverse_required_capability_for_surface(int64_t surface_opcode) { + if (surface_opcode == kMessageSurfaceOpcode) { + return kReverseHandlerMessageCarrierFlag; + } + if (surface_opcode == kReadoutSurfaceOpcode) { + return kReverseHandlerReadoutFlag; + } + if (surface_opcode == kTransitionSurfaceOpcode) { + return kReverseHandlerTransitionFlag; + } + TORCH_CHECK(false, "registered fused reverse program has unsupported handler surface opcode ", surface_opcode); + return 0; +} + +inline RegisteredForwardExecutorHandler registered_forward_executor_handler_for_span( + const RegisteredFusedProgramSpan& span) { + TORCH_CHECK( + span.direction_opcode == kForwardDirectionOpcode, + "registered fused forward program handler span has invalid direction ", + span.direction_opcode); + TORCH_CHECK(span.handler_kind > 0, "registered fused forward program handler kind must be compiler-owned"); + TORCH_CHECK( + span.handler_primitive_opcode >= 0, + "registered fused forward program handler primitive opcode is invalid"); + TORCH_CHECK( + span.handler_primitive_row_count == 0 || span.handler_primitive_row_count == span.primitive_row_count, + "registered fused forward program handler primitive row count does not match span"); + TORCH_CHECK(span.strategy_id_hash > 0, "registered fused forward program span has no strategy identity hash"); + TORCH_CHECK(span.program_access_count >= 0, "registered fused forward program span has invalid access count"); + TORCH_CHECK( + span.state_carry_rule_count >= 0, + "registered fused forward program span has invalid carry-rule count"); + TORCH_CHECK( + span.verified_rewrite_required == 0 || span.verified_rewrite_required == 1, + "registered fused forward program span has invalid rewrite flag"); + const int64_t required_capability = registered_forward_required_capability_for_surface(span.surface_opcode); + TORCH_CHECK( + (span.handler_flags & required_capability) == required_capability, + "registered fused forward program handler row is missing capability flags from compiler row: handler_kind=", + span.handler_kind, + ", executor_id=", + span.executor_id, + 
", surface_opcode=", + span.surface_opcode, + ", bucket=", + span.bucket_ordinal); + return RegisteredForwardExecutorHandler{ + span.handler_kind, + span.executor_id, + span.surface_opcode, + "registered.forward.compiler_handler", + (span.handler_flags & kForwardHandlerMessageCarrierFlag) != 0, + (span.handler_flags & kForwardHandlerReadoutFlag) != 0, + (span.handler_flags & kForwardHandlerTransitionFlag) != 0, + }; +} + +inline RegisteredReverseExecutorHandler registered_reverse_executor_handler_for_span( + const RegisteredFusedProgramSpan& span) { + TORCH_CHECK( + span.direction_opcode == kReverseDirectionOpcode, + "registered fused reverse program handler span has invalid direction ", + span.direction_opcode); + TORCH_CHECK(span.handler_kind > 0, "registered fused reverse program handler kind must be compiler-owned"); + TORCH_CHECK( + span.handler_primitive_opcode >= 0, + "registered fused reverse program handler primitive opcode is invalid"); + TORCH_CHECK( + span.handler_primitive_row_count == 0 || span.handler_primitive_row_count == span.primitive_row_count, + "registered fused reverse program handler primitive row count does not match span"); + TORCH_CHECK(span.strategy_id_hash > 0, "registered fused reverse program span has no strategy identity hash"); + TORCH_CHECK(span.program_access_count >= 0, "registered fused reverse program span has invalid access count"); + TORCH_CHECK( + span.state_carry_rule_count >= 0, + "registered fused reverse program span has invalid carry-rule count"); + TORCH_CHECK( + span.verified_rewrite_required == 0 || span.verified_rewrite_required == 1, + "registered fused reverse program span has invalid rewrite flag"); + const int64_t required_capability = registered_reverse_required_capability_for_surface(span.surface_opcode); + TORCH_CHECK( + (span.handler_flags & required_capability) == required_capability, + "registered fused reverse program handler row is missing capability flags from compiler row: handler_kind=", + span.handler_kind, + ", executor_id=", + span.executor_id, + ", surface_opcode=", + span.surface_opcode, + ", bucket=", + span.bucket_ordinal); + return RegisteredReverseExecutorHandler{ + span.handler_kind, + span.executor_id, + span.surface_opcode, + "registered.reverse.compiler_handler", + span.handler_primitive_opcode, + span.handler_primitive_row_count, + (span.handler_flags & kReverseHandlerMessageCarrierFlag) != 0, + (span.handler_flags & kReverseHandlerReadoutFlag) != 0, + (span.handler_flags & kReverseHandlerTransitionFlag) != 0, + }; +} + +inline void validate_registered_fused_forward_span_memory( + const at::Tensor& forward_spans, + const at::Tensor& memory_liveness_rows) { + check_cpu_long_rank2(forward_spans, "forward_spans", kFusedProgramSpanColumns); + for (int64_t span_index = 0; span_index < forward_spans.size(0); ++span_index) { + const RegisteredFusedProgramSpan span = registered_fused_program_span_at(forward_spans, span_index); + TORCH_CHECK(span.direction_opcode == kForwardDirectionOpcode, "forward_spans contains a non-forward span"); + const RegisteredForwardExecutorHandler& handler = registered_forward_executor_handler_for_span(span); + require_registered_surface_memory_contract(memory_liveness_rows, span, handler.name); + } +} + +inline void validate_registered_fused_reverse_span_memory( + const at::Tensor& reverse_spans, + const at::Tensor& memory_liveness_rows) { + check_cpu_long_rank2(reverse_spans, "reverse_spans", kFusedProgramSpanColumns); + for (int64_t span_index = 0; span_index < reverse_spans.size(0); 
++span_index) { + const RegisteredFusedProgramSpan span = registered_fused_program_span_at(reverse_spans, span_index); + TORCH_CHECK(span.direction_opcode == kReverseDirectionOpcode, "reverse_spans contains a non-reverse span"); + const RegisteredReverseExecutorHandler& handler = registered_reverse_executor_handler_for_span(span); + require_registered_surface_memory_contract(memory_liveness_rows, span, handler.name); + } +} + +inline std::vector registered_forward_handler_span_indices_by_capability( + const at::Tensor& forward_spans, + int64_t surface_opcode, + int64_t capability_flag, + const char* subject) { + check_cpu_long_rank2(forward_spans, "forward_spans", kFusedProgramSpanColumns); + std::vector span_indices; + for (int64_t span_index = 0; span_index < forward_spans.size(0); ++span_index) { + const RegisteredFusedProgramSpan span = registered_fused_program_span_at(forward_spans, span_index); + const RegisteredForwardExecutorHandler& handler = registered_forward_executor_handler_for_span(span); + if (span.surface_opcode != surface_opcode || (span.handler_flags & capability_flag) == 0) { + continue; + } + span_indices.push_back(span_index); + } + TORCH_CHECK(!span_indices.empty(), subject, " requires registered forward handler capability"); + return span_indices; +} + +inline std::vector registered_reverse_handler_span_indices_by_capability( + const at::Tensor& reverse_spans, + int64_t surface_opcode, + int64_t capability_flag, + const char* subject) { + check_cpu_long_rank2(reverse_spans, "reverse_spans", kFusedProgramSpanColumns); + std::vector span_indices; + for (int64_t span_index = 0; span_index < reverse_spans.size(0); ++span_index) { + const RegisteredFusedProgramSpan span = registered_fused_program_span_at(reverse_spans, span_index); + TORCH_CHECK(span.direction_opcode == kReverseDirectionOpcode, "reverse_spans contains a non-reverse span"); + const RegisteredReverseExecutorHandler& handler = registered_reverse_executor_handler_for_span(span); + if (span.surface_opcode != surface_opcode || (span.handler_flags & capability_flag) == 0) { + continue; + } + span_indices.push_back(span_index); + } + TORCH_CHECK(!span_indices.empty(), subject, " requires registered reverse handler capability"); + return span_indices; +} + +inline std::vector validate_registered_temporal_fused_program( + const at::Tensor& primitive_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& forward_handler_rows, + const at::Tensor& reverse_handler_rows, + const at::Tensor& native_strategy_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& memory_liveness_rows, + int64_t schema_version, + bool require_native_callable_output_rows) { + TORCH_CHECK(schema_version == 1, "registered temporal fused program schema version mismatch"); + check_cpu_long_rank2(primitive_rows, "primitive_rows", 4); + TORCH_CHECK(primitive_rows.size(0) > 0, "registered temporal fused program requires primitive rows"); + validate_registered_fused_program_executor_rows(primitive_rows, forward_executor_rows, "forward_executor_rows"); + validate_registered_fused_program_executor_rows(primitive_rows, reverse_executor_rows, "reverse_executor_rows"); + check_cpu_long_rank2(forward_handler_rows, "forward_handler_rows", kFusedHandlerRowColumns); + check_cpu_long_rank2(reverse_handler_rows, 
"reverse_handler_rows", kFusedHandlerRowColumns); + check_cpu_long_rank2(native_strategy_rows, "native_strategy_rows", kNativeStrategyRowColumns); + validate_registered_native_callable_binding_schema_rows(native_callable_binding_schema_rows, schema_version); + validate_registered_native_callable_output_rows(native_callable_output_rows, schema_version); + if (require_native_callable_output_rows) { + TORCH_CHECK( + native_callable_binding_schema_rows.size(0) > 0, + "registered temporal fused program requires compiler-owned native callable binding schema rows"); + TORCH_CHECK( + native_callable_output_rows.size(0) > 0, + "registered temporal fused program requires compiler-owned native callable output rows"); + } + TORCH_CHECK(forward_executor_rows.size(0) > 0, "fused forward program requires executor rows"); + TORCH_CHECK(reverse_executor_rows.size(0) > 0, "fused reverse program requires executor rows"); + validate_registered_fused_program_binding_rows( + primitive_rows, + forward_executor_rows, + forward_executor_binding_rows, + kForwardDirectionOpcode, + "forward_executor_binding_rows"); + validate_registered_fused_program_binding_rows( + primitive_rows, + reverse_executor_rows, + reverse_executor_binding_rows, + kReverseDirectionOpcode, + "reverse_executor_binding_rows"); + validate_registered_fused_program_memory_rows(primitive_rows, memory_liveness_rows); + at::Tensor forward_spans = decode_registered_fused_program_spans( + primitive_rows, + forward_executor_rows, + forward_handler_rows, + native_strategy_rows, + forward_executor_binding_rows, + memory_liveness_rows, + kForwardDirectionOpcode, + schema_version, + native_strategy_rows.size(0) > 0, + "fused_forward_program"); + at::Tensor reverse_spans = decode_registered_fused_program_spans( + primitive_rows, + reverse_executor_rows, + reverse_handler_rows, + native_strategy_rows, + reverse_executor_binding_rows, + memory_liveness_rows, + kReverseDirectionOpcode, + schema_version, + native_strategy_rows.size(0) > 0, + "fused_backward_program"); + validate_registered_fused_forward_span_memory(forward_spans, memory_liveness_rows); + validate_registered_fused_reverse_span_memory(reverse_spans, memory_liveness_rows); + + at::Tensor summary = at::empty({7}, primitive_rows.options()); + int64_t* out = summary.data_ptr(); + out[0] = schema_version; + out[1] = primitive_rows.size(0); + out[2] = forward_executor_rows.size(0); + out[3] = reverse_executor_rows.size(0); + out[4] = forward_executor_binding_rows.size(0); + out[5] = reverse_executor_binding_rows.size(0); + out[6] = memory_liveness_rows.size(0); + return {summary, forward_spans, reverse_spans}; +} + +inline std::vector validate_registered_temporal_fused_program( + const at::Tensor& primitive_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& forward_handler_rows, + const at::Tensor& reverse_handler_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& memory_liveness_rows, + int64_t schema_version) { + at::Tensor empty_strategy_rows = at::empty({0, kNativeStrategyRowColumns}, forward_handler_rows.options()); + return validate_registered_temporal_fused_program( + primitive_rows, + forward_executor_rows, + reverse_executor_rows, + forward_handler_rows, + reverse_handler_rows, + empty_strategy_rows, + at::empty({0, kNativeCallableBindingSchemaRowColumns}, forward_handler_rows.options()), + at::empty({0, kNativeCallableOutputRowColumns}, 
forward_handler_rows.options()), + forward_executor_binding_rows, + reverse_executor_binding_rows, + memory_liveness_rows, + schema_version, + false); +} + +inline std::vector validate_registered_temporal_fused_program( + const at::Tensor& primitive_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& forward_handler_rows, + const at::Tensor& reverse_handler_rows, + const at::Tensor& native_strategy_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& memory_liveness_rows, + int64_t schema_version) { + return validate_registered_temporal_fused_program( + primitive_rows, + forward_executor_rows, + reverse_executor_rows, + forward_handler_rows, + reverse_handler_rows, + native_strategy_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + forward_executor_binding_rows, + reverse_executor_binding_rows, + memory_liveness_rows, + schema_version, + false); +} + +inline int64_t count_temporal_program_access_rows_for_executor( + const at::Tensor& access_rows, + int64_t executor_row_index, + int64_t bucket_ordinal) { + const int64_t* rows = access_rows.data_ptr(); + int64_t count = 0; + for (int64_t row_index = 0; row_index < access_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 6; + if (row[1] == executor_row_index && row[2] == bucket_ordinal) { + ++count; + } + } + return count; +} + +inline void validate_registered_fused_forward_program_dispatch( + const at::Tensor& forward_spans, + const at::Tensor& forward_program_access_rows) { + check_cpu_long_rank2(forward_spans, "forward_spans", kFusedProgramSpanColumns); + int64_t temporal_message_carrier_count = 0; + int64_t temporal_readout_count = 0; + int64_t transition_handler_count = 0; + for (int64_t row_index = 0; row_index < forward_spans.size(0); ++row_index) { + const RegisteredFusedProgramSpan span = registered_fused_program_span_at(forward_spans, row_index); + TORCH_CHECK(span.direction_opcode == kForwardDirectionOpcode, "forward_spans contains a non-forward span"); + const RegisteredForwardExecutorHandler& handler = registered_forward_executor_handler_for_span(span); + temporal_message_carrier_count += handler.provides_temporal_message_carrier ? 1 : 0; + temporal_readout_count += handler.provides_temporal_readout ? 1 : 0; + transition_handler_count += handler.runs_transition_program ? 1 : 0; + const int64_t access_row_count = count_temporal_program_access_rows_for_executor( + forward_program_access_rows, + span.executor_row_index, + span.bucket_ordinal); + const bool access_count_matches = handler.runs_transition_program + ? 
access_row_count >= span.program_access_count + : access_row_count == span.program_access_count; + TORCH_CHECK( + access_count_matches, + "forward_spans row ", + row_index, + " program access row count does not match compiler strategy contract for ", + handler.name, + ": expected=", + span.program_access_count, + ", actual=", + access_row_count); + } + TORCH_CHECK( + temporal_message_carrier_count > 0, + "fused forward program requires at least one compiler-owned temporal message-carrier handler"); + TORCH_CHECK( + temporal_readout_count > 0, + "fused forward program requires at least one compiler-owned temporal readout handler"); + TORCH_CHECK( + transition_handler_count > 0, + "fused forward program requires at least one compiler-owned transition handler"); +} diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/program_tensor_access.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/program_tensor_access.cuh new file mode 100644 index 00000000..cf1e5229 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/program_tensor_access.cuh @@ -0,0 +1,928 @@ +#pragma once + +inline void validate_registered_fused_backward_output_grad_reverse_span( + const at::Tensor& reverse_spans, + int64_t output_count) { + check_cpu_long_rank2(reverse_spans, "reverse_spans", kFusedProgramSpanColumns); + const int64_t* spans = reverse_spans.data_ptr(); + bool saw_readout_span = false; + for (int64_t span_index = 0; span_index < reverse_spans.size(0); ++span_index) { + const int64_t* span = spans + span_index * kFusedProgramSpanColumns; + TORCH_CHECK(span[0] == kReverseDirectionOpcode, "reverse_spans contains a non-reverse span"); + if (span[3] != kReadoutSurfaceOpcode) { + continue; + } + TORCH_CHECK( + span[8] == output_count, + "fused backward output-gradient program readout receiver count mismatch"); + saw_readout_span = true; + } + TORCH_CHECK( + saw_readout_span, + "fused backward output-gradient program requires a compiler-selected readout reverse span"); +} + +inline void check_program_tensor_binding_rows(const at::Tensor& tensor_binding_rows) { + check_cpu_long_rank2(tensor_binding_rows, "program_tensor_binding_rows", 4); + const int64_t* rows = tensor_binding_rows.data_ptr(); + for (int64_t row_index = 0; row_index < tensor_binding_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 4; + TORCH_CHECK(row[0] >= 0, "program_tensor_binding_rows row ", row_index, " has invalid binding index"); + TORCH_CHECK(row[1] >= 0, "program_tensor_binding_rows row ", row_index, " has invalid tensor table index"); + TORCH_CHECK(row[2] >= 0, "program_tensor_binding_rows row ", row_index, " has invalid primitive row index"); + TORCH_CHECK(row[3] >= 0 && row[3] <= 2, "program_tensor_binding_rows row ", row_index, " has invalid binding kind"); + } +} + +inline int64_t program_tensor_index_for_binding( + const at::Tensor& tensor_binding_rows, + int64_t binding_index) { + const int64_t* rows = tensor_binding_rows.data_ptr(); + int64_t result = -1; + for (int64_t row_index = 0; row_index < tensor_binding_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 4; + if (row[0] != binding_index) { + continue; + } + TORCH_CHECK(result < 0 || result == row[1], "program tensor binding index maps to multiple tensor slots"); + result = row[1]; + } + TORCH_CHECK(result >= 0, "program tensor table has no slot for compiler binding index ", binding_index); + return result; +} + 
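+// Editorial sketch (illustrative only; helper and variable names below are
+// hypothetical and not part of the compiler contract): once
+// check_program_tensor_binding_rows has validated the 4-column binding table,
+// a caller resolves a compiler binding index into its tensor-table slot via
+// program_tensor_index_for_binding, e.g.
+//
+//   const int64_t slot =
+//       program_tensor_index_for_binding(program_tensor_binding_rows, binding_index);
+//   const at::Tensor& bound = program_tensors[static_cast<size_t>(slot)];
+//
+// The program_tensor_for_binding helpers below wrap exactly this lookup and
+// add the defined / non-empty checks on the resolved tensor.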
+inline at::Tensor program_tensor_for_binding(
+    const std::vector<at::Tensor>& tensor_table,
+    const at::Tensor& tensor_binding_rows,
+    int64_t binding_index,
+    const char* subject) {
+  const int64_t tensor_index = program_tensor_index_for_binding(tensor_binding_rows, binding_index);
+  TORCH_CHECK(
+      tensor_index < static_cast<int64_t>(tensor_table.size()),
+      subject,
+      " references tensor table index outside program_tensors");
+  const at::Tensor& tensor = tensor_table[static_cast<size_t>(tensor_index)];
+  TORCH_CHECK(tensor.defined() && tensor.numel() > 0, subject, " references an empty program tensor");
+  return tensor;
+}
+
+inline at::Tensor program_tensor_for_binding_allow_empty(
+    const std::vector<at::Tensor>& tensor_table,
+    const at::Tensor& tensor_binding_rows,
+    int64_t binding_index,
+    const char* subject) {
+  const int64_t tensor_index = program_tensor_index_for_binding(tensor_binding_rows, binding_index);
+  TORCH_CHECK(
+      tensor_index < static_cast<int64_t>(tensor_table.size()),
+      subject,
+      " references tensor table index outside program_tensors");
+  const at::Tensor& tensor = tensor_table[static_cast<size_t>(tensor_index)];
+  TORCH_CHECK(tensor.defined(), subject, " references an undefined program tensor");
+  return tensor;
+}
+
+inline double program_scalar_double_for_binding(
+    const std::vector<at::Tensor>& tensor_table,
+    const at::Tensor& tensor_binding_rows,
+    int64_t binding_index,
+    const char* subject) {
+  at::Tensor tensor = program_tensor_for_binding(tensor_table, tensor_binding_rows, binding_index, subject);
+  TORCH_CHECK(tensor.numel() == 1, subject, " must bind exactly one scalar value");
+  return tensor.reshape({-1}).select(0, 0).item<double>();
+}
+
+inline int64_t program_scalar_int_for_binding(
+    const std::vector<at::Tensor>& tensor_table,
+    const at::Tensor& tensor_binding_rows,
+    int64_t binding_index,
+    const char* subject) {
+  at::Tensor tensor = program_tensor_for_binding(tensor_table, tensor_binding_rows, binding_index, subject);
+  TORCH_CHECK(tensor.numel() == 1, subject, " must bind exactly one scalar value");
+  return tensor.reshape({-1}).select(0, 0).item<int64_t>();
+}
+
+constexpr int64_t kForwardRuntimeRecurrentLocalSenderIdx = 1;
+constexpr int64_t kForwardRuntimeOutputLocalSenderIdx = 2;
+constexpr int64_t kForwardRuntimeLocalDistance = 3;
+constexpr int64_t kForwardRuntimeLocalDelay = 4;
+constexpr int64_t kForwardRuntimeInnerSteps = 5;
+constexpr int64_t kForwardRuntimeOutputBoundaryTerminal = 6;
+constexpr int64_t kForwardRuntimeDistanceScale = 7;
+constexpr int64_t kForwardRuntimeHeadDim = 8;
+constexpr int64_t kForwardRuntimeValueDim = 9;
+constexpr int64_t kForwardRuntimeUseDelay = 10;
+constexpr int64_t kReverseRuntimeGraphToBackendOrder = 1;
+constexpr int64_t kReverseRuntimeBackendToGraphInverseOrder = 2;
+constexpr int64_t kReverseRuntimeOutputLocalSenderIdx = 3;
+constexpr int64_t kReverseRuntimeLocalDistance = 4;
+constexpr int64_t kReverseRuntimeLocalDelay = 5;
+constexpr int64_t kReverseRuntimeOutputNeighborIdx = 6;
+constexpr int64_t kReverseRuntimeOutputNeighborValid = 7;
+constexpr int64_t kReverseRuntimeOutputEdgeDistance = 8;
+constexpr int64_t kReverseRuntimeOutputEdgeDelay = 9;
+constexpr int64_t kReverseRuntimeRecurrentLocalSenderIdx = 10;
+constexpr int64_t kReverseRuntimeRecurrentNeighborIdx = 11;
+constexpr int64_t kReverseRuntimeRecurrentNeighborValid = 12;
+constexpr int64_t kReverseRuntimeRecurrentEdgeDistance = 13;
+constexpr int64_t kReverseRuntimeRecurrentEdgeDelay = 14;
+constexpr int64_t kReverseRuntimeMessageStepIndices = 15;
+constexpr int64_t kReverseRuntimeInputCount = 16;
+constexpr int64_t
kReverseRuntimeRecurrentCount = 17; +constexpr int64_t kReverseRuntimeDistanceScale = 18; +constexpr int64_t kReverseRuntimeUseSparseMessages = 19; +constexpr int64_t kReverseRuntimeUseDelay = 20; +constexpr int64_t kReverseRuntimeGroupSize = 21; +constexpr int64_t kReverseRuntimeHeadDim = 22; +constexpr int64_t kReverseRuntimeValueDim = 23; +constexpr int64_t kReverseRuntimeReturnBoundaryGrad = 24; + +inline void check_forward_program_runtime_rows( + const std::vector& runtime_tensors, + const at::Tensor& runtime_rows) { + check_cpu_long_rank2(runtime_rows, "forward_program_runtime_rows", 6); + const int64_t* rows = runtime_rows.data_ptr(); + for (int64_t row_index = 0; row_index < runtime_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 6; + TORCH_CHECK(row[0] > 0, "forward_program_runtime_rows row ", row_index, " has invalid role opcode"); + TORCH_CHECK(row[1] >= 0, "forward_program_runtime_rows row ", row_index, " has invalid tensor index"); + TORCH_CHECK( + row[1] < static_cast(runtime_tensors.size()), + "forward_program_runtime_rows row ", + row_index, + " references a tensor outside forward_program_runtime_tensors"); + TORCH_CHECK(row[2] > 0, "forward_program_runtime_rows row ", row_index, " has invalid dtype opcode"); + TORCH_CHECK(row[3] >= 0, "forward_program_runtime_rows row ", row_index, " has invalid rank"); + TORCH_CHECK(row[4] > 0, "forward_program_runtime_rows row ", row_index, " has invalid device opcode"); + TORCH_CHECK(row[5] == 1, "forward_program_runtime_rows row ", row_index, " must be required"); + const at::Tensor& tensor = runtime_tensors[static_cast(row[1])]; + TORCH_CHECK(tensor.defined(), "forward_program_runtime_tensors contains an undefined tensor"); + TORCH_CHECK( + tensor.dim() == row[3], + "forward_program_runtime_rows row ", + row_index, + " rank does not match runtime tensor"); + } +} + +inline int64_t forward_program_runtime_tensor_index_for_role( + const at::Tensor& runtime_rows, + int64_t role_opcode) { + const int64_t* rows = runtime_rows.data_ptr(); + int64_t result = -1; + for (int64_t row_index = 0; row_index < runtime_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 6; + if (row[0] != role_opcode) { + continue; + } + TORCH_CHECK(result < 0 || result == row[1], "forward program runtime role maps to multiple tensor slots"); + result = row[1]; + } + TORCH_CHECK(result >= 0, "forward program runtime table has no required role opcode ", role_opcode); + return result; +} + +inline at::Tensor forward_program_runtime_tensor_for_role( + const std::vector& runtime_tensors, + const at::Tensor& runtime_rows, + int64_t role_opcode, + const char* subject) { + const int64_t tensor_index = forward_program_runtime_tensor_index_for_role(runtime_rows, role_opcode); + TORCH_CHECK( + tensor_index < static_cast(runtime_tensors.size()), + subject, + " references runtime tensor outside forward_program_runtime_tensors"); + const at::Tensor& tensor = runtime_tensors[static_cast(tensor_index)]; + TORCH_CHECK(tensor.defined(), subject, " references an undefined runtime tensor"); + return tensor; +} + +inline int64_t forward_program_runtime_int_for_role( + const std::vector& runtime_tensors, + const at::Tensor& runtime_rows, + int64_t role_opcode, + const char* subject) { + at::Tensor tensor = forward_program_runtime_tensor_for_role( + runtime_tensors, runtime_rows, role_opcode, subject); + TORCH_CHECK(tensor.numel() == 1, subject, " must bind exactly one scalar value"); + return tensor.reshape({-1}).select(0, 0).item(); +} + 
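+// Editorial sketch (illustrative only; the local variable names are
+// hypothetical): the runtime role constants above are resolved through the
+// *_for_role helpers (scalar int above, scalar double just below) rather than
+// by positional indexing into the runtime tensor table, e.g.
+//
+//   const int64_t head_dim = forward_program_runtime_int_for_role(
+//       forward_program_runtime_tensors,
+//       forward_program_runtime_rows,
+//       kForwardRuntimeHeadDim,
+//       "registered fused forward program");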
+inline double forward_program_runtime_double_for_role( + const std::vector& runtime_tensors, + const at::Tensor& runtime_rows, + int64_t role_opcode, + const char* subject) { + at::Tensor tensor = forward_program_runtime_tensor_for_role( + runtime_tensors, runtime_rows, role_opcode, subject); + TORCH_CHECK(tensor.numel() == 1, subject, " must bind exactly one scalar value"); + return tensor.reshape({-1}).select(0, 0).item(); +} + +inline void check_reverse_program_runtime_rows( + const std::vector& runtime_tensors, + const at::Tensor& runtime_rows) { + check_cpu_long_rank2(runtime_rows, "reverse_program_runtime_rows", 6); + const int64_t* rows = runtime_rows.data_ptr(); + for (int64_t row_index = 0; row_index < runtime_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 6; + TORCH_CHECK(row[0] > 0, "reverse_program_runtime_rows row ", row_index, " has invalid role opcode"); + TORCH_CHECK(row[1] >= 0, "reverse_program_runtime_rows row ", row_index, " has invalid tensor index"); + TORCH_CHECK( + row[1] < static_cast(runtime_tensors.size()), + "reverse_program_runtime_rows row ", + row_index, + " references a tensor outside reverse_program_runtime_tensors"); + TORCH_CHECK(row[2] > 0, "reverse_program_runtime_rows row ", row_index, " has invalid dtype opcode"); + TORCH_CHECK(row[3] >= 0, "reverse_program_runtime_rows row ", row_index, " has invalid rank"); + TORCH_CHECK(row[4] > 0, "reverse_program_runtime_rows row ", row_index, " has invalid device opcode"); + TORCH_CHECK(row[5] == 1, "reverse_program_runtime_rows row ", row_index, " must be required"); + const at::Tensor& tensor = runtime_tensors[static_cast(row[1])]; + TORCH_CHECK(tensor.defined(), "reverse_program_runtime_tensors contains an undefined tensor"); + TORCH_CHECK( + tensor.dim() == row[3], + "reverse_program_runtime_rows row ", + row_index, + " rank does not match runtime tensor"); + } +} + +inline int64_t reverse_program_runtime_tensor_index_for_role( + const at::Tensor& runtime_rows, + int64_t role_opcode) { + const int64_t* rows = runtime_rows.data_ptr(); + int64_t result = -1; + for (int64_t row_index = 0; row_index < runtime_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 6; + if (row[0] != role_opcode) { + continue; + } + TORCH_CHECK(result < 0 || result == row[1], "reverse program runtime role maps to multiple tensor slots"); + result = row[1]; + } + TORCH_CHECK(result >= 0, "reverse program runtime table has no required role opcode ", role_opcode); + return result; +} + +inline at::Tensor reverse_program_runtime_tensor_for_role( + const std::vector& runtime_tensors, + const at::Tensor& runtime_rows, + int64_t role_opcode, + const char* subject) { + const int64_t tensor_index = reverse_program_runtime_tensor_index_for_role(runtime_rows, role_opcode); + TORCH_CHECK( + tensor_index < static_cast(runtime_tensors.size()), + subject, + " references runtime tensor outside reverse_program_runtime_tensors"); + const at::Tensor& tensor = runtime_tensors[static_cast(tensor_index)]; + TORCH_CHECK(tensor.defined(), subject, " references an undefined runtime tensor"); + return tensor; +} + +inline int64_t reverse_program_runtime_int_for_role( + const std::vector& runtime_tensors, + const at::Tensor& runtime_rows, + int64_t role_opcode, + const char* subject) { + at::Tensor tensor = reverse_program_runtime_tensor_for_role( + runtime_tensors, runtime_rows, role_opcode, subject); + TORCH_CHECK(tensor.numel() == 1, subject, " must bind exactly one scalar value"); + return 
tensor.reshape({-1}).select(0, 0).item(); +} + +inline double reverse_program_runtime_double_for_role( + const std::vector& runtime_tensors, + const at::Tensor& runtime_rows, + int64_t role_opcode, + const char* subject) { + at::Tensor tensor = reverse_program_runtime_tensor_for_role( + runtime_tensors, runtime_rows, role_opcode, subject); + TORCH_CHECK(tensor.numel() == 1, subject, " must bind exactly one scalar value"); + return tensor.reshape({-1}).select(0, 0).item(); +} + +inline void set_program_tensor_for_binding( + std::vector& tensor_table, + const at::Tensor& tensor_binding_rows, + int64_t binding_index, + const at::Tensor& value, + const char* subject) { + const int64_t tensor_index = program_tensor_index_for_binding(tensor_binding_rows, binding_index); + TORCH_CHECK( + tensor_index < static_cast(tensor_table.size()), + subject, + " output references tensor table index outside program_tensors"); + tensor_table[static_cast(tensor_index)] = value; +} + +inline at::Tensor empty_program_tensor_like(const at::Tensor& tensor) { + if (!tensor.defined()) { + return at::Tensor(); + } + return at::empty({0}, tensor.options()); +} + +inline void zero_forward_transition_state_inputs_for_reset( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& forward_transition_state_carry_rows, + int64_t bucket_ordinal, + const at::Tensor& transition_reset) { + check_cpu_long_rank2(forward_transition_state_carry_rows, "forward_transition_state_carry_rows", 3); + if (!transition_reset.defined() || transition_reset.numel() == 0) { + return; + } + const int64_t* rows = forward_transition_state_carry_rows.data_ptr(); + for (int64_t row_index = 0; row_index < forward_transition_state_carry_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 3; + if (row[0] != bucket_ordinal) { + continue; + } + const int64_t input_binding = row[1]; + const int64_t tensor_index = program_tensor_index_for_binding(program_tensor_binding_rows, input_binding); + TORCH_CHECK( + tensor_index >= 0 && tensor_index < static_cast(program_tensors.size()), + "forward transition reset references tensor table index outside program_tensors"); + at::Tensor& tensor = program_tensors[static_cast(tensor_index)]; + if (tensor.defined() && tensor.numel() > 0) { + tensor = zero_batch_rows_for_reset(tensor, transition_reset, "forward transition state reset input"); + } + } +} + +inline void clear_forward_transition_output_binding_slots( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& forward_executor_binding_rows) { + check_cpu_long_rank2(forward_executor_binding_rows, "forward_executor_binding_rows", 8); + const int64_t* rows = forward_executor_binding_rows.data_ptr(); + for (int64_t row_index = 0; row_index < forward_executor_binding_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 8; + // Binding rows do not carry a surface column. Message/readout spans use + // negative compiler bucket ordinals; transition spans use concrete bucket + // ordinals from the lowered flat-bucket program. 
+ if ( + row[0] != kForwardDirectionOpcode || + row[5] < 0 || + row[6] != kOutputBindingKindOpcode) { + continue; + } + const int64_t tensor_index = program_tensor_index_for_binding(program_tensor_binding_rows, row[4]); + TORCH_CHECK( + tensor_index >= 0 && tensor_index < static_cast(program_tensors.size()), + "forward transition output binding references tensor table index outside program_tensors"); + at::Tensor& tensor = program_tensors[static_cast(tensor_index)]; + tensor = empty_program_tensor_like(tensor); + } +} + +inline bool forward_transition_binding_has_future_input_use( + const at::Tensor& forward_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t binding_index, + int64_t current_primitive_row_index) { + check_cpu_long_rank2(forward_executor_binding_rows, "forward_executor_binding_rows", 8); + const int64_t primitive_end = span.primitive_row_start + span.primitive_row_count; + const int64_t* rows = forward_executor_binding_rows.data_ptr(); + for (int64_t row_index = 0; row_index < forward_executor_binding_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 8; + if ( + row[0] != kForwardDirectionOpcode || + row[1] != span.executor_row_index || + row[2] != span.executor_id || + row[5] != span.bucket_ordinal || + row[6] != kInputBindingKindOpcode || + row[4] != binding_index) { + continue; + } + if (row[3] > current_primitive_row_index && row[3] < primitive_end) { + return true; + } + } + return false; +} + +inline bool forward_transition_binding_is_current_output( + const at::Tensor& forward_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t binding_index, + int64_t current_primitive_row_index) { + check_cpu_long_rank2(forward_executor_binding_rows, "forward_executor_binding_rows", 8); + const int64_t* rows = forward_executor_binding_rows.data_ptr(); + for (int64_t row_index = 0; row_index < forward_executor_binding_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 8; + if ( + row[0] == kForwardDirectionOpcode && + row[1] == span.executor_row_index && + row[2] == span.executor_id && + row[3] == current_primitive_row_index && + row[4] == binding_index && + row[5] == span.bucket_ordinal && + row[6] == kOutputBindingKindOpcode) { + return true; + } + } + return false; +} + +inline bool forward_transition_binding_is_any_output( + const at::Tensor& forward_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t binding_index) { + check_cpu_long_rank2(forward_executor_binding_rows, "forward_executor_binding_rows", 8); + const int64_t* rows = forward_executor_binding_rows.data_ptr(); + for (int64_t row_index = 0; row_index < forward_executor_binding_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 8; + if ( + row[0] == kForwardDirectionOpcode && + row[1] == span.executor_row_index && + row[2] == span.executor_id && + row[4] == binding_index && + row[5] == span.bucket_ordinal && + row[6] == kOutputBindingKindOpcode) { + return true; + } + } + return false; +} + +inline void check_forward_transition_state_carry_rows(const at::Tensor& carry_rows); + +inline bool forward_executor_output_binding_is_active( + const at::Tensor& forward_executor_binding_rows, + int64_t binding_index) { + check_cpu_long_rank2(forward_executor_binding_rows, "forward_executor_binding_rows", 8); + const int64_t* rows = forward_executor_binding_rows.data_ptr(); + for (int64_t row_index = 0; row_index < forward_executor_binding_rows.size(0); ++row_index) { + const int64_t* row = rows + 
row_index * 8; + if ( + row[0] == kForwardDirectionOpcode && + row[4] == binding_index && + row[6] == kOutputBindingKindOpcode) { + return true; + } + } + return false; +} + +inline bool forward_transition_binding_is_active_state_carry_source( + const at::Tensor& forward_transition_state_carry_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t bucket_ordinal, + int64_t binding_index) { + check_forward_transition_state_carry_rows(forward_transition_state_carry_rows); + const int64_t* rows = forward_transition_state_carry_rows.data_ptr(); + for (int64_t row_index = 0; row_index < forward_transition_state_carry_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 3; + if (row[0] != bucket_ordinal || row[2] != binding_index) { + continue; + } + if (forward_executor_output_binding_is_active(forward_executor_binding_rows, row[2])) { + return true; + } + } + return false; +} + +inline bool forward_transition_binding_slot_aliases_active_state_carry_source( + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& forward_transition_state_carry_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t bucket_ordinal, + int64_t binding_index) { + check_forward_transition_state_carry_rows(forward_transition_state_carry_rows); + const int64_t tensor_index = program_tensor_index_for_binding(program_tensor_binding_rows, binding_index); + TORCH_CHECK( + tensor_index >= 0, + "forward transition input binding has no compiler tensor-table slot"); + const int64_t* rows = forward_transition_state_carry_rows.data_ptr(); + for (int64_t row_index = 0; row_index < forward_transition_state_carry_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 3; + if (row[0] != bucket_ordinal) { + continue; + } + if (!forward_executor_output_binding_is_active(forward_executor_binding_rows, row[2])) { + continue; + } + const int64_t source_tensor_index = program_tensor_index_for_binding(program_tensor_binding_rows, row[2]); + TORCH_CHECK( + source_tensor_index >= 0, + "forward transition state-carry source binding has no compiler tensor-table slot"); + if (source_tensor_index == tensor_index) { + return true; + } + } + return false; +} + +inline bool forward_transition_has_active_state_carry_sources( + const at::Tensor& forward_transition_state_carry_rows, + const at::Tensor& forward_executor_binding_rows) { + check_forward_transition_state_carry_rows(forward_transition_state_carry_rows); + const int64_t* rows = forward_transition_state_carry_rows.data_ptr(); + for (int64_t row_index = 0; row_index < forward_transition_state_carry_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 3; + if (forward_executor_output_binding_is_active(forward_executor_binding_rows, row[2])) { + return true; + } + } + return false; +} + +inline void clear_forward_transition_dead_input_binding_slots_after_primitive( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& forward_transition_state_carry_rows, + const RegisteredFusedProgramSpan& span, + int64_t primitive_row_index, + const std::vector& input_bindings) { + for (const int64_t binding_index : input_bindings) { + if (forward_transition_binding_is_current_output( + forward_executor_binding_rows, + span, + binding_index, + primitive_row_index)) { + continue; + } + if (forward_transition_binding_is_active_state_carry_source( + forward_transition_state_carry_rows, + forward_executor_binding_rows, + 
span.bucket_ordinal, + binding_index)) { + continue; + } + if (forward_transition_binding_slot_aliases_active_state_carry_source( + program_tensor_binding_rows, + forward_transition_state_carry_rows, + forward_executor_binding_rows, + span.bucket_ordinal, + binding_index)) { + continue; + } + if (forward_transition_binding_has_future_input_use( + forward_executor_binding_rows, + span, + binding_index, + primitive_row_index)) { + continue; + } + const int64_t tensor_index = program_tensor_index_for_binding(program_tensor_binding_rows, binding_index); + TORCH_CHECK( + tensor_index >= 0 && tensor_index < static_cast(program_tensors.size()), + "forward transition input binding references tensor table index outside program_tensors"); + at::Tensor& tensor = program_tensors[static_cast(tensor_index)]; + tensor = empty_program_tensor_like(tensor); + } +} + +inline void check_temporal_program_access_rows(const at::Tensor& access_rows, const char* row_name) { + check_cpu_long_rank2(access_rows, row_name, 6); + const int64_t* rows = access_rows.data_ptr(); + for (int64_t row_index = 0; row_index < access_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 6; + TORCH_CHECK(row[0] >= 0, row_name, " row ", row_index, " has invalid executor-local access slot"); + TORCH_CHECK(row[1] >= 0, row_name, " row ", row_index, " has invalid executor row index"); + TORCH_CHECK( + row[2] >= 0 || row[2] == kTemporalMessageBucketOrdinal || row[2] == kTemporalReadoutBucketOrdinal, + row_name, + " row ", + row_index, + " has invalid bucket ordinal"); + TORCH_CHECK(row[3] >= 0, row_name, " row ", row_index, " has invalid binding index"); + TORCH_CHECK(row[4] == 0 || row[4] == 1, row_name, " row ", row_index, " has invalid required flag"); + TORCH_CHECK(row[5] > 0, row_name, " row ", row_index, " has invalid compiler access opcode"); + } +} + +inline void check_forward_program_access_rows(const at::Tensor& access_rows) { + check_temporal_program_access_rows(access_rows, "forward_program_access_rows"); +} + +inline void check_reverse_program_access_rows(const at::Tensor& access_rows) { + check_temporal_program_access_rows(access_rows, "reverse_program_access_rows"); +} + +inline void check_forward_transition_state_carry_rows(const at::Tensor& carry_rows) { + check_cpu_long_rank2(carry_rows, "forward_transition_state_carry_rows", 3); + const int64_t* rows = carry_rows.data_ptr(); + for (int64_t row_index = 0; row_index < carry_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 3; + TORCH_CHECK(row[0] >= 0, "forward_transition_state_carry_rows row ", row_index, " has invalid bucket ordinal"); + TORCH_CHECK(row[1] >= 0, "forward_transition_state_carry_rows row ", row_index, " has invalid input binding"); + TORCH_CHECK(row[2] >= 0, "forward_transition_state_carry_rows row ", row_index, " has invalid output binding"); + } +} + +inline void compact_forward_program_tensor_table_for_return( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& forward_transition_state_carry_rows, + bool return_final_program_tensors) { + check_program_tensor_binding_rows(program_tensor_binding_rows); + check_forward_transition_state_carry_rows(forward_transition_state_carry_rows); + std::vector retained_tensor_indices; + if (return_final_program_tensors) { + const int64_t* rows = forward_transition_state_carry_rows.data_ptr(); + for (int64_t row_index = 0; row_index < forward_transition_state_carry_rows.size(0); ++row_index) { + const int64_t* row = rows + 
row_index * 3; + const int64_t tensor_index = program_tensor_index_for_binding(program_tensor_binding_rows, row[1]); + TORCH_CHECK( + tensor_index >= 0 && tensor_index < static_cast(program_tensors.size()), + "forward final program tensor compaction references tensor table index outside program_tensors"); + if (std::find(retained_tensor_indices.begin(), retained_tensor_indices.end(), tensor_index) == + retained_tensor_indices.end()) { + retained_tensor_indices.push_back(tensor_index); + } + } + } + for (int64_t tensor_index = 0; tensor_index < static_cast(program_tensors.size()); ++tensor_index) { + const bool keep_tensor = + return_final_program_tensors && + std::find(retained_tensor_indices.begin(), retained_tensor_indices.end(), tensor_index) != + retained_tensor_indices.end(); + if (!keep_tensor) { + at::Tensor& tensor = program_tensors[static_cast(tensor_index)]; + tensor = empty_program_tensor_like(tensor); + } + } +} + +inline int64_t count_forward_transition_state_carry_rows_for_bucket( + const at::Tensor& carry_rows, + int64_t bucket_ordinal) { + const int64_t* rows = carry_rows.data_ptr(); + int64_t count = 0; + for (int64_t row_index = 0; row_index < carry_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 3; + if (row[0] == bucket_ordinal) { + ++count; + } + } + return count; +} + +inline void validate_forward_transition_state_carry_contract( + const at::Tensor& forward_spans, + const at::Tensor& forward_transition_state_carry_rows) { + check_cpu_long_rank2(forward_spans, "forward_spans", kFusedProgramSpanColumns); + check_forward_transition_state_carry_rows(forward_transition_state_carry_rows); + for (int64_t span_index = 0; span_index < forward_spans.size(0); ++span_index) { + const RegisteredFusedProgramSpan span = registered_fused_program_span_at(forward_spans, span_index); + if (span.direction_opcode != kForwardDirectionOpcode || span.surface_opcode != kTransitionSurfaceOpcode) { + continue; + } + const int64_t carry_count = + count_forward_transition_state_carry_rows_for_bucket(forward_transition_state_carry_rows, span.bucket_ordinal); + if (span.state_carry_rule_count == 0) { + continue; + } + TORCH_CHECK( + carry_count <= span.state_carry_rule_count, + "forward transition state-carry row count exceeds compiler strategy contract for bucket ", + span.bucket_ordinal, + ": maximum=", + span.state_carry_rule_count, + ", actual=", + carry_count); + } +} + +inline int64_t temporal_program_access_binding_by_local_index( + const at::Tensor& access_rows, + int64_t local_binding_index, + int64_t executor_row_index, + int64_t bucket_ordinal, + const char* direction, + const char* subject) { + const int64_t* rows = access_rows.data_ptr(); + int64_t binding_index = -1; + for (int64_t row_index = 0; row_index < access_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 6; + if (row[0] != local_binding_index || row[1] != executor_row_index || row[2] != bucket_ordinal) { + continue; + } + TORCH_CHECK(binding_index < 0 || binding_index == row[3], subject, " has duplicate access binding rows"); + binding_index = row[3]; + } + TORCH_CHECK(binding_index >= 0, subject, " has no compiler-owned ", direction, " native access binding"); + return binding_index; +} + +inline int64_t temporal_program_access_binding_by_opcode( + const at::Tensor& access_rows, + int64_t access_opcode, + int64_t executor_row_index, + int64_t bucket_ordinal, + bool required, + const char* direction, + const char* subject) { + const int64_t* rows = access_rows.data_ptr(); + int64_t 
binding_index = -1; + for (int64_t row_index = 0; row_index < access_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 6; + if (row[1] != executor_row_index || row[2] != bucket_ordinal || row[5] != access_opcode) { + continue; + } + TORCH_CHECK(binding_index < 0 || binding_index == row[3], subject, " has duplicate compiler access rows"); + binding_index = row[3]; + } + TORCH_CHECK( + binding_index >= 0 || !required, + subject, + " has no compiler-owned ", + direction, + " program access binding for opcode ", + access_opcode); + return binding_index; +} + +inline int64_t program_binding_for_native_strategy_access( + const at::Tensor& access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredNativeStrategyRow& native_strategy, + int64_t binding_kind, + const char* logical_name, + bool required, + int64_t executor_row_index, + int64_t bucket_ordinal, + const char* direction, + const char* subject) { + const int64_t local_binding_index = native_callable_local_binding_index_for( + native_callable_binding_schema_rows, + native_strategy.native_callable_hash, + native_strategy.direction_opcode, + native_strategy.surface_opcode, + native_strategy.primitive_opcode, + binding_kind, + registered_temporal_stable_id_hash_constexpr(logical_name), + required, + native_strategy.tensor_binding_schema_version, + subject); + if (local_binding_index < 0) { + return -1; + } + return temporal_program_access_binding_by_local_index( + access_rows, + local_binding_index, + executor_row_index, + bucket_ordinal, + direction, + subject); +} + +inline at::Tensor program_tensor_for_native_strategy_access( + const std::vector& tensor_table, + const at::Tensor& tensor_binding_rows, + const at::Tensor& access_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredNativeStrategyRow& native_strategy, + int64_t binding_kind, + const char* logical_name, + bool required, + bool allow_empty, + int64_t executor_row_index, + int64_t bucket_ordinal, + const char* direction, + const char* subject) { + const int64_t binding_index = program_binding_for_native_strategy_access( + access_rows, + native_callable_binding_schema_rows, + native_strategy, + binding_kind, + logical_name, + required, + executor_row_index, + bucket_ordinal, + direction, + subject); + if (binding_index < 0) { + return at::Tensor(); + } + return allow_empty + ? 
program_tensor_for_binding_allow_empty(tensor_table, tensor_binding_rows, binding_index, subject) + : program_tensor_for_binding(tensor_table, tensor_binding_rows, binding_index, subject); +} + +inline void apply_forward_transition_state_carry_rows( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& carry_rows, + int64_t bucket_ordinal) { + const int64_t* rows = carry_rows.data_ptr(); + for (int64_t row_index = 0; row_index < carry_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 3; + if (row[0] != bucket_ordinal) { + continue; + } + if (!forward_executor_output_binding_is_active(forward_executor_binding_rows, row[2])) { + continue; + } + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + row[1], + program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + row[2], + "registered fused forward transition state carry source"), + "registered fused forward transition state carry target"); + } +} + +inline std::vector fused_binding_indices_for_primitive( + const at::Tensor& executor_binding_rows, + int64_t primitive_row_index, + int64_t direction_opcode, + int64_t binding_kind) { + const int64_t* rows = executor_binding_rows.data_ptr(); + std::vector> values; + for (int64_t row_index = 0; row_index < executor_binding_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 8; + if (row[0] != direction_opcode || row[3] != primitive_row_index || row[6] != binding_kind) { + continue; + } + values.push_back({row[7], row[4]}); + } + std::sort(values.begin(), values.end(), [](const auto& a, const auto& b) { return a.first < b.first; }); + std::vector bindings; + bindings.reserve(values.size()); + for (const auto& item : values) { + bindings.push_back(item.second); + } + return bindings; +} + +inline std::vector fused_input_bindings_for_primitive( + const at::Tensor& executor_binding_rows, + int64_t primitive_row_index) { + return fused_binding_indices_for_primitive( + executor_binding_rows, + primitive_row_index, + kForwardDirectionOpcode, + 0); +} + +inline std::vector fused_parameter_bindings_for_primitive( + const at::Tensor& executor_binding_rows, + int64_t primitive_row_index) { + return fused_binding_indices_for_primitive( + executor_binding_rows, + primitive_row_index, + kForwardDirectionOpcode, + kParameterBindingKindOpcode); +} + +inline std::vector fused_output_bindings_for_primitive( + const at::Tensor& executor_binding_rows, + int64_t primitive_row_index) { + return fused_binding_indices_for_primitive( + executor_binding_rows, + primitive_row_index, + kForwardDirectionOpcode, + 2); +} + +inline std::vector fused_reverse_input_bindings_for_primitive( + const at::Tensor& executor_binding_rows, + int64_t primitive_row_index) { + return fused_binding_indices_for_primitive( + executor_binding_rows, + primitive_row_index, + kReverseDirectionOpcode, + 0); +} + +inline std::vector fused_reverse_parameter_bindings_for_primitive( + const at::Tensor& executor_binding_rows, + int64_t primitive_row_index) { + return fused_binding_indices_for_primitive( + executor_binding_rows, + primitive_row_index, + kReverseDirectionOpcode, + kParameterBindingKindOpcode); +} + +inline std::vector fused_reverse_output_bindings_for_primitive( + const at::Tensor& executor_binding_rows, + int64_t primitive_row_index) { + return fused_binding_indices_for_primitive( + executor_binding_rows, + primitive_row_index, + 
kReverseDirectionOpcode,
+      2);
+}
+
+inline void require_binding_count(
+    const std::vector<int64_t>& bindings,
+    int64_t expected,
+    const char* subject) {
+  TORCH_CHECK(
+      static_cast<int64_t>(bindings.size()) == expected,
+      subject,
+      " expected ",
+      expected,
+      " compiler tensor bindings, got ",
+      bindings.size());
+}
diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/reverse_artifacts_and_resets.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/reverse_artifacts_and_resets.cuh
new file mode 100644
index 00000000..ab1dd0f0
--- /dev/null
+++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/reverse_artifacts_and_resets.cuh
@@ -0,0 +1,754 @@
+#pragma once
+
+inline bool reverse_artifact_role_is_tensor(int64_t role_id) {
+  switch (role_id) {
+    case kReverseArtifactBoundaryStep:
+    case kReverseArtifactCellsPrev:
+    case kReverseArtifactInputK:
+    case kReverseArtifactInputV:
+    case kReverseArtifactRecurrentKBefore:
+    case kReverseArtifactRecurrentVBefore:
+    case kReverseArtifactRecurrentK:
+    case kReverseArtifactRecurrentV:
+    case kReverseArtifactRecurrentHiddenBeforeBackendOrder:
+    case kReverseArtifactRecurrentHiddenBackendOrder:
+    case kReverseArtifactRecurrentMsgBackendOrder:
+    case kReverseArtifactOutputMsg:
+    case kReverseArtifactOutputCells:
+    case kReverseArtifactTransitionStateBefore:
+      return true;
+    default:
+      TORCH_CHECK(false, "unknown temporal reverse artifact role id ", role_id);
+  }
+}
+
+inline std::vector<int64_t> validate_temporal_reverse_artifact_role_rows(
+    const at::Tensor& reverse_artifact_role_rows) {
+  check_cpu_long_rank2(reverse_artifact_role_rows, "reverse_artifact_role_rows", 3);
+  std::vector<int64_t> tensor_required(kReverseArtifactMaxRole + 1, -1);
+  const int64_t* rows = reverse_artifact_role_rows.data_ptr<int64_t>();
+  for (int64_t row_index = 0; row_index < reverse_artifact_role_rows.size(0); ++row_index) {
+    const int64_t* row = rows + row_index * 3;
+    const int64_t role_index = row[0];
+    const int64_t role_id = row[1];
+    const int64_t required = row[2];
+    TORCH_CHECK(role_index == row_index, "reverse_artifact_role_rows row index mismatch at row ", row_index);
+    TORCH_CHECK(
+        role_id > 0 && role_id <= kReverseArtifactMaxRole,
+        "reverse_artifact_role_rows row ",
+        row_index,
+        " has invalid role id ",
+        role_id);
+    TORCH_CHECK(
+        required == 0 || required == 1,
+        "reverse_artifact_role_rows row ",
+        row_index,
+        " has invalid tensor-required flag");
+    TORCH_CHECK(
+        required == static_cast<int64_t>(reverse_artifact_role_is_tensor(role_id)),
+        "reverse_artifact_role_rows row ",
+        row_index,
+        " has mismatched tensor-required flag");
+    TORCH_CHECK(tensor_required[role_id] < 0, "duplicate temporal reverse artifact role id ", role_id);
+    tensor_required[role_id] = required;
+  }
+  return tensor_required;
+}
+
+inline void validate_temporal_reverse_artifact_binding_rows(
+    const std::vector<at::Tensor>& reverse_artifact_tensors,
+    const at::Tensor& reverse_artifact_binding_rows,
+    const std::vector<int64_t>& tensor_required,
+    int64_t local_time_steps) {
+  check_cpu_long_rank2(
+      reverse_artifact_binding_rows,
+      "reverse_artifact_binding_rows",
+      kReverseArtifactBindingRowColumns);
+  TORCH_CHECK(local_time_steps > 0, "reverse artifact binding validation requires a non-empty window");
+  const int64_t* rows = reverse_artifact_binding_rows.data_ptr<int64_t>();
+  for (int64_t row_index = 0; row_index < reverse_artifact_binding_rows.size(0); ++row_index) {
+    const int64_t* row = rows + row_index * kReverseArtifactBindingRowColumns;
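+    // NOTE (editor, inferred from the field reads just below): each reverse-artifact binding row appears to be
+    // laid out as [role_id, tensor_index, local_step, flags, route_row].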
+    const int64_t role_id = row[0];
+    const int64_t tensor_index = row[1];
+    const int64_t local_step = row[2];
+    const int64_t flags = row[3];
+    const int64_t route_row = row[4];
+    TORCH_CHECK(
+        role_id > 0 && role_id <= kReverseArtifactMaxRole && tensor_required[role_id] >= 0,
+        "reverse_artifact_binding_rows row ",
+        row_index,
+        " references an unregistered reverse artifact role");
+    TORCH_CHECK(
+        tensor_required[role_id] == 1,
+        "reverse_artifact_binding_rows row ",
+        row_index,
+        " binds non-tensor reverse artifact role");
+    TORCH_CHECK(
+        tensor_index >= 0 && tensor_index < static_cast<int64_t>(reverse_artifact_tensors.size()),
+        "reverse_artifact_binding_rows row ",
+        row_index,
+        " has invalid tensor index");
+    TORCH_CHECK(
+        local_step >= 0 && local_step < local_time_steps,
+        "reverse_artifact_binding_rows row ",
+        row_index,
+        " has invalid local step");
+    TORCH_CHECK(flags >= 0, "reverse_artifact_binding_rows row ", row_index, " has invalid flags");
+    TORCH_CHECK(route_row >= 0, "reverse_artifact_binding_rows row ", row_index, " has invalid artifact route row");
+    TORCH_CHECK(
+        reverse_artifact_tensors[static_cast<size_t>(tensor_index)].defined(),
+        "reverse artifact tensor table contains an undefined tensor at index ",
+        tensor_index);
+  }
+}
+
+inline bool reverse_artifact_access_is_known(int64_t access_id) {
+  return access_id > 0 && access_id <= kReverseArtifactMaxAccess;
+}
+
+inline std::vector<int64_t> validate_temporal_reverse_artifact_access_rows(
+    const at::Tensor& reverse_artifact_access_rows,
+    const std::vector<int64_t>& tensor_required) {
+  check_cpu_long_rank2(reverse_artifact_access_rows, "reverse_artifact_access_rows", 3);
+  std::vector<int64_t> role_for_access(kReverseArtifactMaxAccess + 1, -1);
+  const int64_t* rows = reverse_artifact_access_rows.data_ptr<int64_t>();
+  for (int64_t row_index = 0; row_index < reverse_artifact_access_rows.size(0); ++row_index) {
+    const int64_t* row = rows + row_index * 3;
+    const int64_t access_id = row[0];
+    const int64_t role_id = row[1];
+    const int64_t required = row[2];
+    TORCH_CHECK(
+        reverse_artifact_access_is_known(access_id),
+        "reverse_artifact_access_rows row ",
+        row_index,
+        " has invalid access id ",
+        access_id);
+    TORCH_CHECK(
+        role_id > 0 && role_id <= kReverseArtifactMaxRole && tensor_required[role_id] == 1,
+        "reverse_artifact_access_rows row ",
+        row_index,
+        " references an unregistered tensor reverse artifact role");
+    TORCH_CHECK(
+        required == 0 || required == 1,
+        "reverse_artifact_access_rows row ",
+        row_index,
+        " has invalid required flag");
+    TORCH_CHECK(role_for_access[access_id] < 0, "duplicate temporal reverse artifact access id ", access_id);
+    role_for_access[access_id] = role_id;
+  }
+  return role_for_access;
+}
+
+inline int64_t reverse_artifact_binding_window_len(const at::Tensor& reverse_artifact_binding_rows) {
+  check_cpu_long_rank2(
+      reverse_artifact_binding_rows,
+      "reverse_artifact_binding_rows",
+      kReverseArtifactBindingRowColumns);
+  const int64_t* rows = reverse_artifact_binding_rows.data_ptr<int64_t>();
+  int64_t max_step = -1;
+  for (int64_t row_index = 0; row_index < reverse_artifact_binding_rows.size(0); ++row_index) {
+    const int64_t* row = rows + row_index * kReverseArtifactBindingRowColumns;
+    max_step = std::max(max_step, row[2]);
+  }
+  return max_step + 1;
+}
+
+inline int64_t reverse_artifact_role_for_access(
+    const at::Tensor& reverse_artifact_access_rows,
+    int64_t access_id,
+    const char* subject) {
+  check_cpu_long_rank2(reverse_artifact_access_rows, "reverse_artifact_access_rows", 3);
+  const int64_t* rows =
reverse_artifact_access_rows.data_ptr(); + int64_t role_id = -1; + for (int64_t row_index = 0; row_index < reverse_artifact_access_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 3; + if (row[0] != access_id) { + continue; + } + TORCH_CHECK(role_id < 0, subject, " has duplicate reverse artifact access id ", access_id); + role_id = row[1]; + } + TORCH_CHECK(role_id > 0, subject, " is missing reverse artifact access id ", access_id); + return role_id; +} + +inline int64_t try_reverse_artifact_role_for_access( + const at::Tensor& reverse_artifact_access_rows, + int64_t access_id) { + check_cpu_long_rank2(reverse_artifact_access_rows, "reverse_artifact_access_rows", 3); + const int64_t* rows = reverse_artifact_access_rows.data_ptr(); + int64_t role_id = -1; + for (int64_t row_index = 0; row_index < reverse_artifact_access_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 3; + if (row[0] != access_id) { + continue; + } + TORCH_CHECK(role_id < 0, "duplicate reverse artifact access id ", access_id); + role_id = row[1]; + } + return role_id; +} + +inline at::Tensor reverse_artifact_tensor_for_access_step( + const std::vector& reverse_artifact_tensors, + const at::Tensor& reverse_artifact_binding_rows, + const at::Tensor& reverse_artifact_access_rows, + int64_t access_id, + int64_t local_step, + const char* subject) { + const int64_t role_id = reverse_artifact_role_for_access(reverse_artifact_access_rows, access_id, subject); + const int64_t* rows = reverse_artifact_binding_rows.data_ptr(); + int64_t tensor_index = -1; + for (int64_t row_index = 0; row_index < reverse_artifact_binding_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kReverseArtifactBindingRowColumns; + if (row[0] != role_id || row[2] != local_step) { + continue; + } + TORCH_CHECK( + tensor_index < 0, + subject, + " has duplicate reverse artifact binding for role ", + role_id, + " from access ", + access_id, + " at local step ", + local_step); + tensor_index = row[1]; + } + TORCH_CHECK( + tensor_index >= 0, + subject, + " is missing reverse artifact binding for role ", + role_id, + " from access ", + access_id, + " at local step ", + local_step); + TORCH_CHECK( + tensor_index < static_cast(reverse_artifact_tensors.size()), + subject, + " reverse artifact tensor index is out of range"); + return reverse_artifact_tensors[static_cast(tensor_index)]; +} + +inline at::Tensor reverse_artifact_tensor_for_routed_access_step( + const std::vector& reverse_artifact_tensors, + const at::Tensor& reverse_artifact_binding_rows, + const at::Tensor& reverse_artifact_access_rows, + const at::Tensor& forward_artifact_route_rows, + const at::Tensor& reverse_artifact_consumer_route_rows, + int64_t surface_opcode, + int64_t executor_row_index, + int64_t executor_id, + int64_t bucket_ordinal, + int64_t access_id, + int64_t local_step, + int64_t schema_version, + const char* subject) { + const int64_t role_id = reverse_artifact_role_for_access(reverse_artifact_access_rows, access_id, subject); + const int64_t route_row = reverse_artifact_consumer_forward_route_row_for( + reverse_artifact_consumer_route_rows, + forward_artifact_route_rows, + surface_opcode, + executor_row_index, + executor_id, + bucket_ordinal, + role_id, + schema_version, + subject); + const int64_t* rows = reverse_artifact_binding_rows.data_ptr(); + int64_t tensor_index = -1; + for (int64_t row_index = 0; row_index < reverse_artifact_binding_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 
kReverseArtifactBindingRowColumns; + if (row[0] != role_id || row[2] != local_step || row[4] != route_row) { + continue; + } + TORCH_CHECK( + tensor_index < 0, + subject, + " has duplicate reverse artifact binding for role ", + role_id, + " from access ", + access_id, + " at local step ", + local_step, + " and compiler route row ", + route_row); + tensor_index = row[1]; + } + TORCH_CHECK( + tensor_index >= 0, + subject, + " is missing reverse artifact binding for role ", + role_id, + " from access ", + access_id, + " at local step ", + local_step, + " and compiler route row ", + route_row); + TORCH_CHECK( + tensor_index < static_cast(reverse_artifact_tensors.size()), + subject, + " reverse artifact tensor index is out of range"); + return reverse_artifact_tensors[static_cast(tensor_index)]; +} + +inline at::Tensor try_reverse_artifact_tensor_for_routed_access_step( + const std::vector& reverse_artifact_tensors, + const at::Tensor& reverse_artifact_binding_rows, + const at::Tensor& reverse_artifact_access_rows, + const at::Tensor& forward_artifact_route_rows, + const at::Tensor& reverse_artifact_consumer_route_rows, + int64_t surface_opcode, + int64_t executor_row_index, + int64_t executor_id, + int64_t bucket_ordinal, + int64_t access_id, + int64_t local_step, + const char* subject) { + const int64_t role_id = try_reverse_artifact_role_for_access(reverse_artifact_access_rows, access_id); + if (role_id < 0) { + return at::Tensor(); + } + const int64_t route_row = try_reverse_artifact_consumer_forward_route_row_for( + reverse_artifact_consumer_route_rows, + surface_opcode, + executor_row_index, + executor_id, + bucket_ordinal, + role_id, + subject); + if (route_row < 0) { + return at::Tensor(); + } + const int64_t* rows = reverse_artifact_binding_rows.data_ptr(); + int64_t tensor_index = -1; + for (int64_t row_index = 0; row_index < reverse_artifact_binding_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kReverseArtifactBindingRowColumns; + if (row[0] != role_id || row[2] != local_step || row[4] != route_row) { + continue; + } + TORCH_CHECK( + tensor_index < 0, + subject, + " has duplicate optional reverse artifact binding for role ", + role_id, + " at local step ", + local_step, + " and compiler route row ", + route_row); + tensor_index = row[1]; + } + if (tensor_index < 0) { + return at::Tensor(); + } + TORCH_CHECK( + tensor_index < static_cast(reverse_artifact_tensors.size()), + subject, + " optional reverse artifact tensor index is out of range"); + return reverse_artifact_tensors[static_cast(tensor_index)]; +} + +inline void validate_temporal_reverse_reset_rows( + const std::vector& reverse_reset_tensors, + const at::Tensor& reverse_reset_rows, + int64_t batch_size) { + check_cpu_long_rank2(reverse_reset_rows, "reverse_reset_rows", 4); + const int64_t* rows = reverse_reset_rows.data_ptr(); + bool saw_message = false; + bool saw_transition = false; + for (int64_t row_index = 0; row_index < reverse_reset_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 4; + const int64_t reset_kind = row[0]; + const int64_t tensor_index = row[1]; + const int64_t policy = row[2]; + const int64_t scope = row[3]; + TORCH_CHECK( + reset_kind == kReverseResetMessage || reset_kind == kReverseResetTransition, + "reverse_reset_rows row ", + row_index, + " has invalid reset kind ", + reset_kind); + TORCH_CHECK( + tensor_index >= 0 && tensor_index < static_cast(reverse_reset_tensors.size()), + "reverse_reset_rows row ", + row_index, + " has invalid tensor index"); + 
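+      // NOTE (editor, inferred from the checks that follow): only one reset behavior is registered here, zeroing
+      // the source rows per batch row (kReverseResetPolicyZeroSourceRows with kReverseResetScopeBatchRow), and at
+      // most one message-reset row and one transition-reset row may appear.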
TORCH_CHECK( + policy == kReverseResetPolicyZeroSourceRows, + "reverse_reset_rows row ", + row_index, + " has unsupported reset policy"); + TORCH_CHECK( + scope == kReverseResetScopeBatchRow, + "reverse_reset_rows row ", + row_index, + " has unsupported reset scope"); + if (reset_kind == kReverseResetMessage) { + TORCH_CHECK(!saw_message, "reverse_reset_rows has duplicate message reset row"); + saw_message = true; + } else { + TORCH_CHECK(!saw_transition, "reverse_reset_rows has duplicate transition reset row"); + saw_transition = true; + } + const at::Tensor& tensor = reverse_reset_tensors[static_cast(tensor_index)]; + check_cuda_bool_rank1(tensor, "reverse reset tensor"); + TORCH_CHECK(tensor.size(0) == batch_size, "reverse reset tensor batch size mismatch"); + } +} + +inline at::Tensor reverse_reset_tensor_for_kind( + const std::vector& reverse_reset_tensors, + const at::Tensor& reverse_reset_rows, + int64_t reset_kind) { + check_cpu_long_rank2(reverse_reset_rows, "reverse_reset_rows", 4); + const int64_t* rows = reverse_reset_rows.data_ptr(); + int64_t tensor_index = -1; + for (int64_t row_index = 0; row_index < reverse_reset_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 4; + if (row[0] != reset_kind) { + continue; + } + TORCH_CHECK(tensor_index < 0, "reverse_reset_rows has duplicate reset kind ", reset_kind); + tensor_index = row[1]; + } + if (tensor_index < 0) { + return at::Tensor(); + } + TORCH_CHECK( + tensor_index < static_cast(reverse_reset_tensors.size()), + "reverse reset tensor index is out of range"); + return reverse_reset_tensors[static_cast(tensor_index)]; +} + +inline void check_forward_reset_rows( + const std::vector& forward_reset_tensors, + const at::Tensor& forward_reset_rows, + int64_t batch_size, + int64_t outer_steps) { + check_cpu_long_rank2(forward_reset_rows, "forward_reset_rows", 4); + const int64_t* rows = forward_reset_rows.data_ptr(); + bool saw_message = false; + bool saw_transition = false; + for (int64_t row_index = 0; row_index < forward_reset_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 4; + const int64_t reset_kind = row[0]; + const int64_t tensor_index = row[1]; + TORCH_CHECK( + reset_kind == kForwardResetMessage || reset_kind == kForwardResetTransition, + "forward_reset_rows row ", + row_index, + " has unsupported reset kind"); + TORCH_CHECK( + tensor_index >= 0 && tensor_index < static_cast(forward_reset_tensors.size()), + "forward_reset_rows row ", + row_index, + " has reset tensor index out of range"); + TORCH_CHECK( + row[2] == kForwardResetPolicyZeroSourceRows, + "forward_reset_rows row ", + row_index, + " has unsupported reset policy"); + TORCH_CHECK( + row[3] == kForwardResetScopeBatchOuterStep, + "forward_reset_rows row ", + row_index, + " has unsupported reset scope"); + if (reset_kind == kForwardResetMessage) { + TORCH_CHECK(!saw_message, "forward_reset_rows has duplicate message reset row"); + saw_message = true; + } else { + TORCH_CHECK(!saw_transition, "forward_reset_rows has duplicate transition reset row"); + saw_transition = true; + } + const at::Tensor& tensor = forward_reset_tensors[static_cast(tensor_index)]; + check_cuda_bool_rank2(tensor, "forward reset tensor"); + TORCH_CHECK(tensor.size(0) == batch_size, "forward reset tensor batch size mismatch"); + TORCH_CHECK(tensor.size(1) == outer_steps, "forward reset tensor time dimension mismatch"); + } +} + +inline at::Tensor forward_reset_tensor_for_kind( + const std::vector& forward_reset_tensors, + const at::Tensor& 
forward_reset_rows, + int64_t reset_kind) { + check_cpu_long_rank2(forward_reset_rows, "forward_reset_rows", 4); + const int64_t* rows = forward_reset_rows.data_ptr(); + int64_t tensor_index = -1; + for (int64_t row_index = 0; row_index < forward_reset_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 4; + if (row[0] != reset_kind) { + continue; + } + TORCH_CHECK(tensor_index < 0, "forward_reset_rows has duplicate reset kind ", reset_kind); + tensor_index = row[1]; + } + if (tensor_index < 0) { + return at::Tensor(); + } + TORCH_CHECK( + tensor_index < static_cast(forward_reset_tensors.size()), + "forward reset tensor index is out of range"); + return forward_reset_tensors[static_cast(tensor_index)]; +} + +inline at::Tensor forward_reset_step_tensor(const at::Tensor& reset_seq, int64_t outer_step) { + if (!reset_seq.defined() || reset_seq.numel() == 0) { + return at::Tensor(); + } + check_cuda_bool_rank2(reset_seq, "forward reset tensor"); + TORCH_CHECK(outer_step >= 0 && outer_step < reset_seq.size(1), "forward reset outer step is out of range"); + return reset_seq.select(1, outer_step).contiguous(); +} + +inline at::Tensor zero_batch_rows_for_reset( + const at::Tensor& tensor, + const at::Tensor& reset, + const char* subject) { + if (!tensor.defined() || tensor.numel() == 0 || !reset.defined() || reset.numel() == 0) { + return tensor; + } + check_cuda_float_bank(tensor, subject); + check_cuda_bool_rank1(reset, "reverse reset tensor"); + TORCH_CHECK(tensor.size(0) == reset.size(0), subject, " batch size must match reset tensor"); + at::Tensor mask = reset.view({reset.size(0), 1, 1}); + return at::where(mask, at::zeros_like(tensor), tensor); +} + +inline void apply_transition_state_reset_outputs( + std::vector>& transition_outputs, + const at::Tensor& transition_state_reset_rows, + const at::Tensor& transition_reset) { + check_cpu_long_rank2(transition_state_reset_rows, "transition_state_reset_rows", 2); + if (transition_state_reset_rows.size(0) == 0) { + return; + } + if (!transition_reset.defined() || transition_reset.numel() == 0) { + return; + } + const int64_t* rows = transition_state_reset_rows.data_ptr(); + for (int64_t row_index = 0; row_index < transition_state_reset_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 2; + const int64_t group_index = row[0]; + const int64_t output_slot = row[1]; + TORCH_CHECK( + group_index >= 0 && group_index < static_cast(transition_outputs.size()), + "transition_state_reset_rows row ", + row_index, + " has invalid transition group"); + std::vector& group = transition_outputs[static_cast(group_index)]; + TORCH_CHECK( + output_slot >= 0 && output_slot < static_cast(group.size()), + "transition_state_reset_rows row ", + row_index, + " has invalid output slot"); + at::Tensor& value = group[static_cast(output_slot)]; + if (value.defined() && value.numel() > 0) { + value = zero_batch_rows_for_reset(value, transition_reset, "transition state grad reset output"); + } + } +} + +inline at::Tensor optional_reverse_artifact_tensor_for_transition_state_binding( + const std::vector& reverse_artifact_tensors, + const at::Tensor& reverse_artifact_binding_rows, + const at::Tensor& reverse_artifact_access_rows, + int64_t local_step, + int64_t bucket_ordinal, + int64_t binding_index, + const char* subject) { + const int64_t expected_flags = bucket_ordinal * kTransitionStateArtifactFlagStride + binding_index; + const int64_t role_id = reverse_artifact_role_for_access( + reverse_artifact_access_rows, + 
kReverseArtifactAccessTransitionStateBefore, + subject); + const int64_t* rows = reverse_artifact_binding_rows.data_ptr(); + int64_t tensor_index = -1; + for (int64_t row_index = 0; row_index < reverse_artifact_binding_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * kReverseArtifactBindingRowColumns; + if (row[0] != role_id || row[2] != local_step || row[3] != expected_flags) { + continue; + } + TORCH_CHECK(tensor_index < 0, subject, " has duplicate transition state-before reverse artifact binding"); + tensor_index = row[1]; + } + if (tensor_index < 0) { + return at::Tensor(); + } + TORCH_CHECK( + tensor_index < static_cast(reverse_artifact_tensors.size()), + subject, + " transition state-before reverse artifact tensor index is out of range"); + return reverse_artifact_tensors[static_cast(tensor_index)]; +} + +inline at::Tensor reverse_artifact_tensor_for_transition_state_binding( + const std::vector& reverse_artifact_tensors, + const at::Tensor& reverse_artifact_binding_rows, + const at::Tensor& reverse_artifact_access_rows, + int64_t local_step, + int64_t bucket_ordinal, + int64_t binding_index, + const char* subject) { + at::Tensor tensor = optional_reverse_artifact_tensor_for_transition_state_binding( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_access_rows, + local_step, + bucket_ordinal, + binding_index, + subject); + TORCH_CHECK(tensor.defined(), subject, " is missing transition state-before reverse artifact binding"); + return tensor; +} + +inline void check_transition_seed_rows( + const std::vector& transition_seed_tensors, + const at::Tensor& transition_reverse_seed_role_rows, + const at::Tensor& transition_seed_rows) { + check_cpu_long_rank2( + transition_reverse_seed_role_rows, + "transition_reverse_seed_role_rows", + kTransitionReverseSeedRoleRowColumns); + const int64_t* seed_role_rows = transition_reverse_seed_role_rows.data_ptr(); + auto seed_role_is_registered = [&](int64_t role_id) { + for (int64_t role_index = 0; role_index < transition_reverse_seed_role_rows.size(0); ++role_index) { + const int64_t* role = seed_role_rows + role_index * kTransitionReverseSeedRoleRowColumns; + if (role[0] == role_id) { + return true; + } + } + return false; + }; + check_cpu_long_rank2(transition_seed_rows, "transition_seed_rows", 3); + const int64_t* rows = transition_seed_rows.data_ptr(); + for (int64_t row_index = 0; row_index < transition_seed_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 3; + TORCH_CHECK( + seed_role_is_registered(row[0]), + "transition_seed_rows row ", + row_index, + " has invalid seed role"); + TORCH_CHECK( + row[1] >= 0 && row[1] < static_cast(transition_seed_tensors.size()), + "transition_seed_rows row ", + row_index, + " has invalid tensor slot"); + TORCH_CHECK(row[2] >= 0, "transition_seed_rows row ", row_index, " has invalid bucket ordinal"); + TORCH_CHECK( + transition_seed_tensors[static_cast(row[1])].defined(), + "transition seed tensor table contains an undefined tensor at index ", + row[1]); + } +} + +inline void check_transition_dynamic_binding_rows( + const at::Tensor& transition_dynamic_binding_rows, + const at::Tensor& transition_reverse_seed_role_rows) { + check_cpu_long_rank2( + transition_reverse_seed_role_rows, + "transition_reverse_seed_role_rows", + kTransitionReverseSeedRoleRowColumns); + const int64_t* seed_role_rows = transition_reverse_seed_role_rows.data_ptr(); + auto seed_role_is_registered = [&](int64_t role_id) { + for (int64_t role_index = 0; role_index < 
transition_reverse_seed_role_rows.size(0); ++role_index) { + const int64_t* role = seed_role_rows + role_index * kTransitionReverseSeedRoleRowColumns; + if (role[0] == role_id) { + return true; + } + } + return false; + }; + check_cpu_long_rank2(transition_dynamic_binding_rows, "transition_dynamic_binding_rows", 5); + const int64_t* rows = transition_dynamic_binding_rows.data_ptr(); + for (int64_t row_index = 0; row_index < transition_dynamic_binding_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 5; + TORCH_CHECK(row[0] >= 0, "transition_dynamic_binding_rows row ", row_index, " has invalid destination binding"); + TORCH_CHECK( + row[1] == kTransitionDynamicSourceReverseArtifact || + row[1] == kTransitionDynamicSourceStateBeforeArtifact || + row[1] == kTransitionDynamicSourceSeedOrZeros, + "transition_dynamic_binding_rows row ", + row_index, + " has invalid source kind"); + TORCH_CHECK(row[2] >= 0, "transition_dynamic_binding_rows row ", row_index, " has invalid source key"); + TORCH_CHECK( + row[4] == 0 || row[4] == 1, + "transition_dynamic_binding_rows row ", + row_index, + " has invalid required flag"); + if (row[1] == kTransitionDynamicSourceSeedOrZeros) { + TORCH_CHECK( + seed_role_is_registered(row[2]), + "transition_dynamic_binding_rows row ", + row_index, + " has invalid seed role"); + TORCH_CHECK( + row[3] >= 0, + "transition_dynamic_binding_rows row ", + row_index, + " seed row has invalid template binding"); + } + } +} + +inline void check_transition_parameter_rows( + const std::vector& transition_parameter_tensors, + const at::Tensor& transition_parameter_rows) { + check_cpu_long_rank2(transition_parameter_rows, "transition_parameter_rows", 3); + const int64_t* rows = transition_parameter_rows.data_ptr(); + for (int64_t row_index = 0; row_index < transition_parameter_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 3; + TORCH_CHECK(row[0] >= 0, "transition_parameter_rows row ", row_index, " has invalid binding index"); + TORCH_CHECK( + row[1] >= 0 && row[1] < static_cast(transition_parameter_tensors.size()), + "transition_parameter_rows row ", + row_index, + " has invalid tensor slot"); + TORCH_CHECK(row[2] >= 0, "transition_parameter_rows row ", row_index, " has invalid bucket ordinal"); + TORCH_CHECK( + transition_parameter_tensors[static_cast(row[1])].defined(), + "transition parameter tensor table contains an undefined tensor at index ", + row[1]); + } +} + +inline bool transition_seed_zero_cache_matches(const at::Tensor& cached, const at::Tensor& reference) { + return cached.defined() && cached.sizes() == reference.sizes() && cached.scalar_type() == reference.scalar_type() && + cached.device() == reference.device() && cached.layout() == reference.layout(); +} + +inline at::Tensor transition_seed_tensor_or_cached_zeros( + const std::vector& transition_seed_tensors, + const at::Tensor& transition_seed_rows, + int64_t role_id, + int64_t bucket_ordinal, + const at::Tensor& reference, + std::vector& cached_zero_seed_tensors, + const char* subject) { + const int64_t* rows = transition_seed_rows.data_ptr(); + int64_t tensor_index = -1; + for (int64_t row_index = 0; row_index < transition_seed_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 3; + if (row[0] != role_id || row[2] != bucket_ordinal) { + continue; + } + TORCH_CHECK(tensor_index < 0, subject, " has duplicate transition seed binding"); + tensor_index = row[1]; + } + if (tensor_index >= 0) { + TORCH_CHECK(tensor_index < 
static_cast(transition_seed_tensors.size()), subject, " seed slot out of range"); + at::Tensor tensor = transition_seed_tensors[static_cast(tensor_index)]; + TORCH_CHECK(tensor.sizes() == reference.sizes(), subject, " seed tensor shape mismatch"); + return tensor; + } + for (const at::Tensor& cached : cached_zero_seed_tensors) { + if (transition_seed_zero_cache_matches(cached, reference)) { + return cached; + } + } + at::Tensor zero_seed = at::zeros_like(reference); + cached_zero_seed_tensors.push_back(zero_seed); + return zero_seed; +} diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_device_kernels.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_device_kernels.cuh new file mode 100644 index 00000000..242c5cef --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_device_kernels.cuh @@ -0,0 +1,2744 @@ +#pragma once + +template +__device__ inline float read_partitioned_bank( + const float* __restrict__ input_bank, + const float* __restrict__ recurrent_bank, + int b, + int sender, + int d, + int input_senders, + int recurrent_senders, + int dim) { + if (sender < input_senders) { + return input_bank[(static_cast(b) * input_senders + sender) * dim + d]; + } + const int recurrent_sender = sender - input_senders; + if (recurrent_sender < 0 || recurrent_sender >= recurrent_senders) { + return 0.0f; + } + return recurrent_bank[(static_cast(b) * recurrent_senders + recurrent_sender) * dim + d]; +} + +__device__ inline float read_partitioned_projected_recurrent_value( + const float* __restrict__ input_v, + const float* __restrict__ recurrent_hidden, + const float* __restrict__ recurrent_value_weight, + int b, + int sender, + int v, + int input_senders, + int recurrent_senders, + int hidden_dim, + int value_dim) { + if (sender < input_senders) { + return input_v[(static_cast(b) * input_senders + sender) * value_dim + v]; + } + const int recurrent_sender = sender - input_senders; + if (recurrent_sender < 0 || recurrent_sender >= recurrent_senders) { + return 0.0f; + } + const int64_t hidden_base = + (static_cast(b) * recurrent_senders + recurrent_sender) * hidden_dim; + const int64_t weight_base = static_cast(recurrent_sender) * hidden_dim * value_dim + v; + float acc = 0.0f; + for (int h = 0; h < hidden_dim; ++h) { + acc += recurrent_hidden[hidden_base + h] * recurrent_value_weight[weight_base + static_cast(h) * value_dim]; + } + return acc; +} + +template +__device__ inline float readout_value( + int b, + int output_idx, + int hidden, + const float* __restrict__ input_k, + const float* __restrict__ input_v, + const float* __restrict__ recurrent_k, + const float* __restrict__ recurrent_v, + const float* __restrict__ output_q, + const int32_t* __restrict__ output_local_sender_idx, + const float* __restrict__ local_distance, + const float* __restrict__ value_to_output_weight, + const float* __restrict__ output_cell_bias, + int input_senders, + int recurrent_senders, + int output_count, + int degree, + int head_dim, + int key_dim, + int value_dim, + int hidden_dim, + float distance_scale) { + static_cast(output_count); + const float inv_sqrt_dk = rsqrtf(static_cast(head_dim > 0 ? 
head_dim : 1)); + float max_logit = -std::numeric_limits::infinity(); + for (int edge = 0; edge < degree; ++edge) { + const int sender = output_local_sender_idx[static_cast(output_idx) * degree + edge]; + if (sender < 0 || sender >= input_senders + recurrent_senders) { + continue; + } + float dot = 0.0f; + for (int d = 0; d < head_dim; ++d) { + dot += output_q[static_cast(output_idx) * head_dim + d] * + read_partitioned_bank( + input_k, + recurrent_k, + b, + sender, + d, + input_senders, + recurrent_senders, + key_dim); + } + const float penalty = distance_scale * local_distance[edge]; + max_logit = fmaxf(max_logit, dot * inv_sqrt_dk - penalty); + } + float norm = 0.0f; + float projected_acc = 0.0f; + for (int edge = 0; edge < degree; ++edge) { + const int sender = output_local_sender_idx[static_cast(output_idx) * degree + edge]; + if (sender < 0 || sender >= input_senders + recurrent_senders) { + continue; + } + float dot = 0.0f; + for (int d = 0; d < head_dim; ++d) { + dot += output_q[static_cast(output_idx) * head_dim + d] * + read_partitioned_bank( + input_k, + recurrent_k, + b, + sender, + d, + input_senders, + recurrent_senders, + key_dim); + } + const float penalty = distance_scale * local_distance[edge]; + const float weight = expf(dot * inv_sqrt_dk - penalty - max_logit); + norm += weight; + for (int v = 0; v < value_dim; ++v) { + const float value = read_partitioned_bank( + input_v, + recurrent_v, + b, + sender, + v, + input_senders, + recurrent_senders, + value_dim); + projected_acc += weight * + value * + value_to_output_weight[(static_cast(output_idx) * value_dim + v) * hidden_dim + hidden]; + } + } + const float projected = norm > 0.0f ? projected_acc / norm : 0.0f; + return output_cell_bias[static_cast(output_idx) * hidden_dim + hidden] + projected; +} + +__global__ void registered_forward_partitioned_attention_kernel( + const float* __restrict__ q, + const float* __restrict__ input_k, + const float* __restrict__ input_v, + const float* __restrict__ recurrent_k, + const float* __restrict__ recurrent_v, + const int32_t* __restrict__ receiver_sender_idx, + const float* __restrict__ offset_distance, + const int32_t* __restrict__ offset_delay, + const int64_t* __restrict__ step_flat, + float* __restrict__ out, + int64_t total_elements, + int B, + int receiver_count, + int input_senders, + int recurrent_senders, + int degree, + int head_dim, + int key_dim, + int value_dim, + float inv_sqrt_dk, + float distance_scale, + bool use_delay) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int v = static_cast(linear % value_dim); + const int receiver = static_cast((linear / value_dim) % receiver_count); + const int b = static_cast(linear / (static_cast(value_dim) * receiver_count)); + const int64_t step_value = use_delay ? 
step_flat[b] : 0; + float logits[kMaxRegisteredAttentionOffsets]; + int senders[kMaxRegisteredAttentionOffsets]; + float max_logit = -std::numeric_limits::infinity(); + for (int edge = 0; edge < degree; ++edge) { + int sender = -1; + float logit = -std::numeric_limits::infinity(); + if (!use_delay || static_cast(offset_delay[edge]) <= step_value) { + sender = receiver_sender_idx[receiver * degree + edge]; + if (sender >= 0 && sender < input_senders + recurrent_senders) { + float dot = 0.0f; + for (int d = 0; d < head_dim; ++d) { + dot += q[receiver * head_dim + d] * + read_partitioned_bank( + input_k, + recurrent_k, + b, + sender, + d, + input_senders, + recurrent_senders, + key_dim); + } + logit = dot * inv_sqrt_dk - distance_scale * offset_distance[edge]; + max_logit = fmaxf(max_logit, logit); + } + } + senders[edge] = sender; + logits[edge] = logit; + } + float norm = 0.0f; + float acc = 0.0f; + for (int edge = 0; edge < degree; ++edge) { + const int sender = senders[edge]; + if (sender < 0) { + continue; + } + const float weight = expf(logits[edge] - max_logit); + norm += weight; + acc += weight * + read_partitioned_bank( + input_v, + recurrent_v, + b, + sender, + v, + input_senders, + recurrent_senders, + value_dim); + } + out[linear] = norm > 0.0f ? acc / norm : 0.0f; + } +} + +__global__ void registered_backward_partitioned_attention_kernel( + const float* __restrict__ grad_msg, + const float* __restrict__ q, + const float* __restrict__ input_k, + const float* __restrict__ input_v, + const float* __restrict__ recurrent_k, + const float* __restrict__ recurrent_v, + const int32_t* __restrict__ receiver_sender_idx, + const float* __restrict__ offset_distance, + const int32_t* __restrict__ offset_delay, + const int64_t* __restrict__ step_flat, + float* __restrict__ grad_q, + float* __restrict__ grad_input_k, + float* __restrict__ grad_input_v, + float* __restrict__ grad_recurrent_k, + float* __restrict__ grad_recurrent_v, + int64_t receiver_total, + int receiver_count, + int input_senders, + int recurrent_senders, + int degree, + int head_dim, + int key_dim, + int value_dim, + float inv_sqrt_dk, + float distance_scale, + bool use_delay) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < receiver_total; + linear += static_cast(blockDim.x) * gridDim.x) { + const int receiver = static_cast(linear % receiver_count); + const int b = static_cast(linear / receiver_count); + const int64_t step_value = use_delay ? step_flat[b] : 0; + float logits[kMaxRegisteredAttentionOffsets]; + float weights[kMaxRegisteredAttentionOffsets]; + int senders[kMaxRegisteredAttentionOffsets]; + float max_logit = -std::numeric_limits::infinity(); + for (int edge = 0; edge < degree; ++edge) { + int sender = -1; + float logit = -std::numeric_limits::infinity(); + if (!use_delay || static_cast(offset_delay[edge]) <= step_value) { + sender = receiver_sender_idx[receiver * degree + edge]; + if (sender >= 0 && sender < input_senders + recurrent_senders) { + float dot = 0.0f; + for (int d = 0; d < head_dim; ++d) { + dot += q[receiver * head_dim + d] * + read_partitioned_bank( + input_k, + recurrent_k, + b, + sender, + d, + input_senders, + recurrent_senders, + key_dim); + } + logit = dot * inv_sqrt_dk - distance_scale * offset_distance[edge]; + max_logit = fmaxf(max_logit, logit); + } + } + senders[edge] = sender; + logits[edge] = logit; + } + float norm = 0.0f; + for (int edge = 0; edge < degree; ++edge) { + const int sender = senders[edge]; + const float raw_weight = sender >= 0 ? 
expf(logits[edge] - max_logit) : 0.0f; + weights[edge] = raw_weight; + norm += raw_weight; + } + if (norm <= 0.0f) { + continue; + } + float dweights[kMaxRegisteredAttentionOffsets]; + float expected_dweight = 0.0f; + const int64_t grad_offset = (static_cast(b) * receiver_count + receiver) * value_dim; + for (int edge = 0; edge < degree; ++edge) { + const int sender = senders[edge]; + weights[edge] = weights[edge] / norm; + float dweight = 0.0f; + if (sender >= 0) { + for (int v = 0; v < value_dim; ++v) { + dweight += grad_msg[grad_offset + v] * + read_partitioned_bank( + input_v, + recurrent_v, + b, + sender, + v, + input_senders, + recurrent_senders, + value_dim); + } + } + dweights[edge] = dweight; + expected_dweight += weights[edge] * dweight; + } + for (int edge = 0; edge < degree; ++edge) { + const int sender = senders[edge]; + if (sender < 0) { + continue; + } + const bool is_input = sender < input_senders; + const int bank_sender = is_input ? sender : sender - input_senders; + const int bank_sender_count = is_input ? input_senders : recurrent_senders; + float* grad_k_bank = is_input ? grad_input_k : grad_recurrent_k; + float* grad_v_bank = is_input ? grad_input_v : grad_recurrent_v; + const float* k_bank = is_input ? input_k : recurrent_k; + const float dlogit = weights[edge] * (dweights[edge] - expected_dweight); + for (int v = 0; v < value_dim; ++v) { + atomicAdd( + &grad_v_bank[(static_cast(b) * bank_sender_count + bank_sender) * value_dim + v], + weights[edge] * grad_msg[grad_offset + v]); + } + for (int d = 0; d < head_dim; ++d) { + const float q_value = q[receiver * head_dim + d]; + const float k_value = k_bank[(static_cast(b) * bank_sender_count + bank_sender) * key_dim + d]; + atomicAdd( + &grad_k_bank[(static_cast(b) * bank_sender_count + bank_sender) * key_dim + d], + dlogit * q_value * inv_sqrt_dk); + atomicAdd( + &grad_q[receiver * head_dim + d], + dlogit * k_value * inv_sqrt_dk); + } + } + } +} + +__global__ void registered_forward_sparse_attention_kernel( + const float* __restrict__ q, + const float* __restrict__ input_k, + const float* __restrict__ input_v, + const float* __restrict__ recurrent_k, + const float* __restrict__ recurrent_v, + const int64_t* __restrict__ neighbor_idx, + const bool* __restrict__ neighbor_valid, + const float* __restrict__ edge_distance, + const int64_t* __restrict__ edge_delay, + const int64_t* __restrict__ step_flat, + float* __restrict__ out, + int64_t total_elements, + int receiver_count, + int input_senders, + int recurrent_senders, + int degree, + int head_dim, + int key_dim, + int value_dim, + float inv_sqrt_dk, + float distance_scale, + bool use_delay) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int v = static_cast(linear % value_dim); + const int receiver = static_cast((linear / value_dim) % receiver_count); + const int b = static_cast(linear / (static_cast(value_dim) * receiver_count)); + const int64_t step_value = use_delay ? 
step_flat[b] : 0; + float max_logit = -std::numeric_limits::infinity(); + for (int edge = 0; edge < degree; ++edge) { + const int64_t edge_offset = static_cast(receiver) * degree + edge; + if (!neighbor_valid[edge_offset] || (use_delay && edge_delay[edge_offset] > step_value)) { + continue; + } + const int sender = static_cast(neighbor_idx[edge_offset]); + if (sender < 0 || sender >= input_senders + recurrent_senders) { + continue; + } + float dot = 0.0f; + for (int d = 0; d < head_dim; ++d) { + dot += q[receiver * head_dim + d] * + read_partitioned_bank( + input_k, + recurrent_k, + b, + sender, + d, + input_senders, + recurrent_senders, + key_dim); + } + max_logit = fmaxf(max_logit, dot * inv_sqrt_dk - distance_scale * edge_distance[edge_offset]); + } + float norm = 0.0f; + float acc = 0.0f; + for (int edge = 0; edge < degree; ++edge) { + const int64_t edge_offset = static_cast(receiver) * degree + edge; + if (!neighbor_valid[edge_offset] || (use_delay && edge_delay[edge_offset] > step_value)) { + continue; + } + const int sender = static_cast(neighbor_idx[edge_offset]); + if (sender < 0 || sender >= input_senders + recurrent_senders) { + continue; + } + float dot = 0.0f; + for (int d = 0; d < head_dim; ++d) { + dot += q[receiver * head_dim + d] * + read_partitioned_bank( + input_k, + recurrent_k, + b, + sender, + d, + input_senders, + recurrent_senders, + key_dim); + } + const float weight = expf(dot * inv_sqrt_dk - distance_scale * edge_distance[edge_offset] - max_logit); + norm += weight; + acc += weight * + read_partitioned_bank( + input_v, + recurrent_v, + b, + sender, + v, + input_senders, + recurrent_senders, + value_dim); + } + out[linear] = norm > 0.0f ? acc / norm : 0.0f; + } +} + +__global__ void registered_backward_sparse_attention_kernel( + const float* __restrict__ grad_msg, + const float* __restrict__ q, + const float* __restrict__ input_k, + const float* __restrict__ input_v, + const float* __restrict__ recurrent_k, + const float* __restrict__ recurrent_v, + const int64_t* __restrict__ neighbor_idx, + const bool* __restrict__ neighbor_valid, + const float* __restrict__ edge_distance, + const int64_t* __restrict__ edge_delay, + const int64_t* __restrict__ step_flat, + float* __restrict__ grad_q, + float* __restrict__ grad_input_k, + float* __restrict__ grad_input_v, + float* __restrict__ grad_recurrent_k, + float* __restrict__ grad_recurrent_v, + int64_t receiver_total, + int receiver_count, + int input_senders, + int recurrent_senders, + int degree, + int head_dim, + int key_dim, + int value_dim, + float inv_sqrt_dk, + float distance_scale, + bool use_delay) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < receiver_total; + linear += static_cast(blockDim.x) * gridDim.x) { + const int receiver = static_cast(linear % receiver_count); + const int b = static_cast(linear / receiver_count); + const int64_t step_value = use_delay ? 
step_flat[b] : 0; + float max_logit = -std::numeric_limits::infinity(); + for (int edge = 0; edge < degree; ++edge) { + const int64_t edge_offset = static_cast(receiver) * degree + edge; + if (!neighbor_valid[edge_offset] || (use_delay && edge_delay[edge_offset] > step_value)) { + continue; + } + const int sender = static_cast(neighbor_idx[edge_offset]); + if (sender < 0 || sender >= input_senders + recurrent_senders) { + continue; + } + float dot = 0.0f; + for (int d = 0; d < head_dim; ++d) { + dot += q[receiver * head_dim + d] * + read_partitioned_bank( + input_k, + recurrent_k, + b, + sender, + d, + input_senders, + recurrent_senders, + key_dim); + } + max_logit = fmaxf(max_logit, dot * inv_sqrt_dk - distance_scale * edge_distance[edge_offset]); + } + float norm = 0.0f; + float expected_dweight = 0.0f; + const int64_t grad_offset = (static_cast(b) * receiver_count + receiver) * value_dim; + for (int edge = 0; edge < degree; ++edge) { + const int64_t edge_offset = static_cast(receiver) * degree + edge; + if (!neighbor_valid[edge_offset] || (use_delay && edge_delay[edge_offset] > step_value)) { + continue; + } + const int sender = static_cast(neighbor_idx[edge_offset]); + if (sender < 0 || sender >= input_senders + recurrent_senders) { + continue; + } + float dot = 0.0f; + for (int d = 0; d < head_dim; ++d) { + dot += q[receiver * head_dim + d] * + read_partitioned_bank( + input_k, + recurrent_k, + b, + sender, + d, + input_senders, + recurrent_senders, + key_dim); + } + const float raw_weight = expf(dot * inv_sqrt_dk - distance_scale * edge_distance[edge_offset] - max_logit); + float dweight = 0.0f; + for (int v = 0; v < value_dim; ++v) { + dweight += grad_msg[grad_offset + v] * + read_partitioned_bank( + input_v, + recurrent_v, + b, + sender, + v, + input_senders, + recurrent_senders, + value_dim); + } + norm += raw_weight; + expected_dweight += raw_weight * dweight; + } + if (norm <= 0.0f) { + continue; + } + expected_dweight /= norm; + for (int edge = 0; edge < degree; ++edge) { + const int64_t edge_offset = static_cast(receiver) * degree + edge; + if (!neighbor_valid[edge_offset] || (use_delay && edge_delay[edge_offset] > step_value)) { + continue; + } + const int sender = static_cast(neighbor_idx[edge_offset]); + if (sender < 0 || sender >= input_senders + recurrent_senders) { + continue; + } + const bool is_input = sender < input_senders; + const int bank_sender = is_input ? sender : sender - input_senders; + const int bank_sender_count = is_input ? input_senders : recurrent_senders; + float* grad_k_bank = is_input ? grad_input_k : grad_recurrent_k; + float* grad_v_bank = is_input ? grad_input_v : grad_recurrent_v; + const float* k_bank = is_input ? 
input_k : recurrent_k; + float dot = 0.0f; + for (int d = 0; d < head_dim; ++d) { + dot += q[receiver * head_dim + d] * k_bank[(static_cast(b) * bank_sender_count + bank_sender) * key_dim + d]; + } + const float weight = expf(dot * inv_sqrt_dk - distance_scale * edge_distance[edge_offset] - max_logit) / norm; + float dweight = 0.0f; + for (int v = 0; v < value_dim; ++v) { + dweight += grad_msg[grad_offset + v] * + read_partitioned_bank( + input_v, + recurrent_v, + b, + sender, + v, + input_senders, + recurrent_senders, + value_dim); + } + const float dlogit = weight * (dweight - expected_dweight); + for (int v = 0; v < value_dim; ++v) { + atomicAdd( + &grad_v_bank[(static_cast(b) * bank_sender_count + bank_sender) * value_dim + v], + weight * grad_msg[grad_offset + v]); + } + for (int d = 0; d < head_dim; ++d) { + const float q_value = q[receiver * head_dim + d]; + const float k_value = k_bank[(static_cast(b) * bank_sender_count + bank_sender) * key_dim + d]; + atomicAdd( + &grad_k_bank[(static_cast(b) * bank_sender_count + bank_sender) * key_dim + d], + dlogit * q_value * inv_sqrt_dk); + atomicAdd( + &grad_q[receiver * head_dim + d], + dlogit * k_value * inv_sqrt_dk); + } + } + } +} + +template +__global__ void registered_forward_readout_layout_epilogue_kernel( + const float* __restrict__ boundary, + const float* __restrict__ recurrent_hidden_backend_order, + const float* __restrict__ input_k, + const float* __restrict__ input_v, + const float* __restrict__ recurrent_k, + const float* __restrict__ recurrent_v, + const float* __restrict__ output_q, + const int32_t* __restrict__ output_local_sender_idx, + const float* __restrict__ local_distance, + const float* __restrict__ value_to_output_weight, + const float* __restrict__ output_cell_bias, + const index_t* __restrict__ backend_to_graph_inverse_order, + float* __restrict__ output_cells, + float* __restrict__ recurrent_hidden_graph_order, + float* __restrict__ cells_out, + int64_t total_elements, + int input_count, + int recurrent_count, + int output_count, + int degree, + int head_dim, + int key_dim, + int value_dim, + int hidden_dim, + float distance_scale) { + const int graph_cells = input_count + recurrent_count + output_count; + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int h = static_cast(linear % hidden_dim); + const int graph_cell = static_cast((linear / hidden_dim) % graph_cells); + const int b = static_cast(linear / (static_cast(hidden_dim) * graph_cells)); + float value = 0.0f; + if (graph_cell < input_count) { + value = boundary[(static_cast(b) * input_count + graph_cell) * hidden_dim + h]; + } else if (graph_cell < input_count + recurrent_count) { + const int recurrent_graph = graph_cell - input_count; + const int recurrent_backend = static_cast(backend_to_graph_inverse_order[recurrent_graph]); + value = + recurrent_hidden_backend_order[(static_cast(b) * recurrent_count + recurrent_backend) * hidden_dim + h]; + recurrent_hidden_graph_order[(static_cast(b) * recurrent_count + recurrent_graph) * hidden_dim + h] = + value; + } else { + const int output_idx = graph_cell - input_count - recurrent_count; + value = readout_value( + b, + output_idx, + h, + input_k, + input_v, + recurrent_k, + recurrent_v, + output_q, + output_local_sender_idx, + local_distance, + value_to_output_weight, + output_cell_bias, + input_count, + recurrent_count, + output_count, + degree, + head_dim, + key_dim, + value_dim, + hidden_dim, + 
distance_scale); + output_cells[(static_cast(b) * output_count + output_idx) * hidden_dim + h] = value; + } + cells_out[linear] = value; + } +} + +template +__global__ void registered_forward_cells_layout_kernel( + const float* __restrict__ boundary, + const float* __restrict__ recurrent_hidden_backend_order, + const float* __restrict__ output_cells, + const index_t* __restrict__ backend_to_graph_inverse_order, + float* __restrict__ recurrent_hidden_graph_order, + float* __restrict__ cells_out, + int64_t total_elements, + int input_count, + int recurrent_count, + int output_count, + int hidden_dim) { + const int graph_cells = input_count + recurrent_count + output_count; + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int h = static_cast(linear % hidden_dim); + const int graph_cell = static_cast((linear / hidden_dim) % graph_cells); + const int b = static_cast(linear / (static_cast(hidden_dim) * graph_cells)); + float value = 0.0f; + if (graph_cell < input_count) { + value = boundary[(static_cast(b) * input_count + graph_cell) * hidden_dim + h]; + } else if (graph_cell < input_count + recurrent_count) { + const int recurrent_graph = graph_cell - input_count; + const int recurrent_backend = static_cast(backend_to_graph_inverse_order[recurrent_graph]); + value = + recurrent_hidden_backend_order[(static_cast(b) * recurrent_count + recurrent_backend) * hidden_dim + h]; + recurrent_hidden_graph_order[(static_cast(b) * recurrent_count + recurrent_graph) * hidden_dim + h] = + value; + } else { + const int output_idx = graph_cell - input_count - recurrent_count; + value = output_cells[(static_cast(b) * output_count + output_idx) * hidden_dim + h]; + } + cells_out[linear] = value; + } +} + +template +__global__ void registered_backward_layout_split_kernel( + const float* __restrict__ grad_cells_out, + const index_t* __restrict__ graph_to_backend_order, + float* __restrict__ grad_boundary, + float* __restrict__ grad_recurrent_hidden_backend, + int64_t total_elements, + int input_count, + int recurrent_count, + int output_count, + int hidden_dim) { + static_cast(output_count); + const int direct_cells = input_count + recurrent_count; + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int h = static_cast(linear % hidden_dim); + const int graph_cell = static_cast((linear / hidden_dim) % direct_cells); + const int b = static_cast(linear / (static_cast(hidden_dim) * direct_cells)); + const float value = grad_cells_out[(static_cast(b) * (direct_cells + output_count) + graph_cell) * + hidden_dim + + h]; + if (graph_cell < input_count) { + grad_boundary[(static_cast(b) * input_count + graph_cell) * hidden_dim + h] = value; + } else { + const int backend_cell = graph_cell - input_count; + const int graph_index = static_cast(graph_to_backend_order[backend_cell]); + grad_recurrent_hidden_backend[(static_cast(b) * recurrent_count + backend_cell) * hidden_dim + h] = + grad_cells_out[(static_cast(b) * (direct_cells + output_count) + input_count + graph_index) * + hidden_dim + + h]; + } + } +} + +__global__ void registered_backward_readout_projection_kernel( + const float* __restrict__ grad_cells_out, + const float* __restrict__ output_msg, + const float* __restrict__ value_to_output_weight, + float* __restrict__ grad_output_msg, + float* __restrict__ grad_value_to_output_weight, + float* 
__restrict__ grad_output_cell_bias, + int64_t max_elements, + int B, + int input_count, + int recurrent_count, + int output_count, + int value_dim, + int hidden_dim) { + const int output_start = input_count + recurrent_count; + const int total_cells = input_count + recurrent_count + output_count; + const int64_t grad_msg_elements = static_cast(B) * output_count * value_dim; + const int64_t grad_weight_elements = static_cast(output_count) * value_dim * hidden_dim; + const int64_t grad_bias_elements = static_cast(output_count) * hidden_dim; + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < max_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + if (linear < grad_msg_elements) { + const int v = static_cast(linear % value_dim); + const int output_idx = static_cast((linear / value_dim) % output_count); + const int b = static_cast(linear / (static_cast(value_dim) * output_count)); + float acc = 0.0f; + for (int h = 0; h < hidden_dim; ++h) { + const float grad_out = + grad_cells_out[(static_cast(b) * total_cells + output_start + output_idx) * hidden_dim + h]; + acc += grad_out * value_to_output_weight[(static_cast(output_idx) * value_dim + v) * hidden_dim + h]; + } + grad_output_msg[linear] = acc; + } + if (linear < grad_weight_elements) { + const int h = static_cast(linear % hidden_dim); + const int v = static_cast((linear / hidden_dim) % value_dim); + const int output_idx = static_cast(linear / (static_cast(hidden_dim) * value_dim)); + float acc = 0.0f; + for (int b = 0; b < B; ++b) { + const float msg = output_msg[(static_cast(b) * output_count + output_idx) * value_dim + v]; + const float grad_out = + grad_cells_out[(static_cast(b) * total_cells + output_start + output_idx) * hidden_dim + h]; + acc += msg * grad_out; + } + grad_value_to_output_weight[linear] = acc; + } + if (linear < grad_bias_elements) { + const int h = static_cast(linear % hidden_dim); + const int output_idx = static_cast(linear / hidden_dim); + float acc = 0.0f; + for (int b = 0; b < B; ++b) { + acc += grad_cells_out[(static_cast(b) * total_cells + output_start + output_idx) * hidden_dim + h]; + } + grad_output_cell_bias[linear] = acc; + } + } +} + +__global__ void registered_forward_readout_projection_kernel( + const float* __restrict__ output_msg, + const float* __restrict__ value_to_output_weight, + const float* __restrict__ output_cell_bias, + float* __restrict__ output_cells, + int64_t total_elements, + int output_count, + int value_dim, + int hidden_dim) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int h = static_cast(linear % hidden_dim); + const int output_idx = static_cast((linear / hidden_dim) % output_count); + const int b = static_cast(linear / (static_cast(hidden_dim) * output_count)); + float acc = output_cell_bias[static_cast(output_idx) * hidden_dim + h]; + for (int v = 0; v < value_dim; ++v) { + acc += output_msg[(static_cast(b) * output_count + output_idx) * value_dim + v] * + value_to_output_weight[(static_cast(output_idx) * value_dim + v) * hidden_dim + h]; + } + output_cells[linear] = acc; + } +} + +__global__ void registered_forward_readout_projection_strided_kernel( + const float* __restrict__ output_msg, + const float* __restrict__ value_to_output_weight, + const float* __restrict__ output_cell_bias, + float* __restrict__ output_cells, + int64_t total_elements, + int output_count, + int value_dim, + int hidden_dim, + int64_t 
output_stride_b, + int64_t output_stride_o, + int64_t output_stride_h) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int h = static_cast(linear % hidden_dim); + const int output_idx = static_cast((linear / hidden_dim) % output_count); + const int b = static_cast(linear / (static_cast(hidden_dim) * output_count)); + float acc = output_cell_bias[static_cast(output_idx) * hidden_dim + h]; + for (int v = 0; v < value_dim; ++v) { + acc += output_msg[(static_cast(b) * output_count + output_idx) * value_dim + v] * + value_to_output_weight[(static_cast(output_idx) * value_dim + v) * hidden_dim + h]; + } + output_cells[ + static_cast(b) * output_stride_b + + static_cast(output_idx) * output_stride_o + + static_cast(h) * output_stride_h] = acc; + } +} + +__global__ void registered_forward_sender_kv_sequence_kernel( + const float* __restrict__ sender_cells, + const float* __restrict__ direct_weight, + const float* __restrict__ grouped_weight, + float* __restrict__ sender_k, + float* __restrict__ sender_v, + int64_t total_elements, + int batch_size, + int time_steps, + int sender_count, + int hidden_dim, + int head_dim, + int value_dim, + int group_size, + bool use_grouped_weight) { + const int kv_dim = head_dim + value_dim; + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int m = static_cast(linear % kv_dim); + const int sender = static_cast((linear / kv_dim) % sender_count); + const int b = static_cast((linear / (static_cast(kv_dim) * sender_count)) % batch_size); + const int t = static_cast(linear / (static_cast(kv_dim) * sender_count * batch_size)); + const float* weight = nullptr; + if (use_grouped_weight) { + const int group = sender / group_size; + weight = grouped_weight + (static_cast(group) * hidden_dim * kv_dim); + } else { + weight = direct_weight + (static_cast(sender) * hidden_dim * kv_dim); + } + float acc = 0.0f; + for (int h = 0; h < hidden_dim; ++h) { + acc += sender_cells[((static_cast(b) * time_steps + t) * sender_count + sender) * hidden_dim + h] * + weight[static_cast(h) * kv_dim + m]; + } + if (m < head_dim) { + sender_k[((static_cast(t) * batch_size + b) * sender_count + sender) * head_dim + m] = acc; + } else { + const int v = m - head_dim; + sender_v[((static_cast(t) * batch_size + b) * sender_count + sender) * value_dim + v] = acc; + } + } +} + +__global__ void registered_forward_sender_kv_step_kernel( + const float* __restrict__ sender_cells, + const float* __restrict__ direct_weight, + const float* __restrict__ grouped_weight, + float* __restrict__ sender_k, + float* __restrict__ sender_v, + int64_t total_elements, + int sender_count, + int hidden_dim, + int head_dim, + int value_dim, + int group_size, + bool use_grouped_weight) { + const int kv_dim = head_dim + value_dim; + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int m = static_cast(linear % kv_dim); + const int sender = static_cast((linear / kv_dim) % sender_count); + const int b = static_cast(linear / (static_cast(kv_dim) * sender_count)); + const float* weight = nullptr; + if (use_grouped_weight) { + const int group = sender / group_size; + weight = grouped_weight + (static_cast(group) * hidden_dim * kv_dim); + } else { + weight = direct_weight + (static_cast(sender) * 
hidden_dim * kv_dim); + } + float acc = 0.0f; + for (int h = 0; h < hidden_dim; ++h) { + acc += sender_cells[(static_cast(b) * sender_count + sender) * hidden_dim + h] * + weight[static_cast(h) * kv_dim + m]; + } + if (m < head_dim) { + sender_k[(static_cast(b) * sender_count + sender) * head_dim + m] = acc; + } else { + const int v = m - head_dim; + sender_v[(static_cast(b) * sender_count + sender) * value_dim + v] = acc; + } + } +} + +__global__ void registered_forward_sender_value_sequence_kernel( + const float* __restrict__ sender_cells, + const float* __restrict__ direct_weight, + const float* __restrict__ grouped_weight, + float* __restrict__ sender_v, + int64_t total_elements, + int batch_size, + int time_steps, + int sender_count, + int hidden_dim, + int value_dim, + int group_size, + bool use_grouped_weight) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int v = static_cast(linear % value_dim); + const int sender = static_cast((linear / value_dim) % sender_count); + const int b = static_cast((linear / (static_cast(value_dim) * sender_count)) % batch_size); + const int t = static_cast(linear / (static_cast(value_dim) * sender_count * batch_size)); + const float* weight = nullptr; + if (use_grouped_weight) { + const int group = sender / group_size; + weight = grouped_weight + (static_cast(group) * hidden_dim * value_dim); + } else { + weight = direct_weight + (static_cast(sender) * hidden_dim * value_dim); + } + float acc = 0.0f; + for (int h = 0; h < hidden_dim; ++h) { + acc += sender_cells[((static_cast(b) * time_steps + t) * sender_count + sender) * hidden_dim + h] * + weight[static_cast(h) * value_dim + v]; + } + sender_v[((static_cast(t) * batch_size + b) * sender_count + sender) * value_dim + v] = acc; + } +} + +__global__ void registered_forward_sender_value_step_kernel( + const float* __restrict__ sender_cells, + const float* __restrict__ direct_weight, + const float* __restrict__ grouped_weight, + float* __restrict__ sender_v, + int64_t total_elements, + int sender_count, + int hidden_dim, + int value_dim, + int group_size, + bool use_grouped_weight) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int v = static_cast(linear % value_dim); + const int sender = static_cast((linear / value_dim) % sender_count); + const int b = static_cast(linear / (static_cast(value_dim) * sender_count)); + const float* weight = nullptr; + if (use_grouped_weight) { + const int group = sender / group_size; + weight = grouped_weight + (static_cast(group) * hidden_dim * value_dim); + } else { + weight = direct_weight + (static_cast(sender) * hidden_dim * value_dim); + } + float acc = 0.0f; + for (int h = 0; h < hidden_dim; ++h) { + acc += sender_cells[(static_cast(b) * sender_count + sender) * hidden_dim + h] * + weight[static_cast(h) * value_dim + v]; + } + sender_v[(static_cast(b) * sender_count + sender) * value_dim + v] = acc; + } +} + +__global__ void registered_forward_fixed_slot_context_key_sequence_kernel( + const float* __restrict__ sender_slot_key, + const float* __restrict__ sender_context_key, + float* __restrict__ sender_k, + int64_t total_elements, + int batch_size, + int time_steps, + int sender_count, + int key_part_dim) { + const int key_dim = key_part_dim * 2; + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < 
total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int d = static_cast(linear % key_dim); + const int sender = static_cast((linear / key_dim) % sender_count); + const int source_d = d < key_part_dim ? d : d - key_part_dim; + const float* source = d < key_part_dim ? sender_slot_key : sender_context_key; + sender_k[linear] = source[static_cast(sender) * key_part_dim + source_d]; + } +} + +__global__ void registered_forward_fixed_slot_context_weighted_value_kernel( + const float* __restrict__ query_slot, + const float* __restrict__ query_context_scalar, + const float* __restrict__ sender_slot_key, + const float* __restrict__ sender_context_key, + const float* __restrict__ input_v, + const float* __restrict__ recurrent_v, + const int32_t* __restrict__ receiver_sender_idx, + const float* __restrict__ offset_distance, + const int32_t* __restrict__ offset_delay, + const int64_t* __restrict__ step_flat, + float* __restrict__ weighted_value, + int64_t row_count, + int B, + int receiver_count, + int input_senders, + int recurrent_senders, + int degree, + int head_dim, + int value_dim, + float inv_sqrt_dk, + float distance_scale, + bool use_delay) { + __shared__ float logits[kMaxRegisteredAttentionOffsets]; + __shared__ int senders[kMaxRegisteredAttentionOffsets]; + __shared__ float row_stats[2]; + + for (int64_t row = static_cast(blockIdx.x); row < row_count; row += gridDim.x) { + const int receiver = static_cast(row % receiver_count); + const int b = static_cast(row / receiver_count); + const int64_t step_value = use_delay ? step_flat[b] : 0; + const float scale = query_context_scalar[0]; + + for (int edge = threadIdx.x; edge < degree; edge += blockDim.x) { + int sender = -1; + float logit = -std::numeric_limits::infinity(); + if (!use_delay || static_cast(offset_delay[edge]) <= step_value) { + sender = receiver_sender_idx[receiver * degree + edge]; + if (sender >= 0 && sender < input_senders + recurrent_senders) { + float dot = 0.0f; + for (int d = 0; d < head_dim; ++d) { + const float q0 = query_slot[receiver * head_dim + d]; + const float q1 = scale * recurrent_v[(static_cast(b) * receiver_count + receiver) * value_dim + d]; + dot += q0 * sender_slot_key[static_cast(sender) * head_dim + d] + + q1 * sender_context_key[static_cast(sender) * head_dim + d]; + } + logit = dot * inv_sqrt_dk - distance_scale * offset_distance[edge]; + } + } + senders[edge] = sender; + logits[edge] = logit; + } + __syncthreads(); + + if (threadIdx.x == 0) { + float max_logit = -std::numeric_limits::infinity(); + for (int edge = 0; edge < degree; ++edge) { + if (senders[edge] >= 0) { + max_logit = fmaxf(max_logit, logits[edge]); + } + } + float norm = 0.0f; + for (int edge = 0; edge < degree; ++edge) { + if (senders[edge] >= 0) { + norm += expf(logits[edge] - max_logit); + } + } + row_stats[0] = max_logit; + row_stats[1] = norm; + } + __syncthreads(); + + const int64_t out_base = row * value_dim; + if (row_stats[1] <= 0.0f) { + for (int v = threadIdx.x; v < value_dim; v += blockDim.x) { + weighted_value[out_base + v] = 0.0f; + } + __syncthreads(); + continue; + } + + for (int v = threadIdx.x; v < value_dim; v += blockDim.x) { + float weighted_value_v = 0.0f; + for (int edge = 0; edge < degree; ++edge) { + const int sender = senders[edge]; + if (sender < 0) { + continue; + } + const float weight = expf(logits[edge] - row_stats[0]) / row_stats[1]; + weighted_value_v += weight * + read_partitioned_bank( + input_v, + recurrent_v, + b, + sender, + v, + input_senders, + recurrent_senders, + value_dim); 
+ } + weighted_value[out_base + v] = weighted_value_v; + } + __syncthreads(); + } +} + +__global__ void registered_forward_fixed_slot_context_weighted_value_warp_kernel( + const float* __restrict__ query_slot, + const float* __restrict__ query_context_scalar, + const float* __restrict__ sender_slot_key, + const float* __restrict__ sender_context_key, + const float* __restrict__ input_v, + const float* __restrict__ recurrent_v, + const int32_t* __restrict__ receiver_sender_idx, + const float* __restrict__ offset_distance, + const int32_t* __restrict__ offset_delay, + const int64_t* __restrict__ step_flat, + float* __restrict__ weighted_value, + int64_t row_count, + int B, + int receiver_count, + int input_senders, + int recurrent_senders, + int degree, + int head_dim, + int value_dim, + float inv_sqrt_dk, + float distance_scale, + bool use_delay) { + static_cast(B); + constexpr int kWarpSize = 32; + const int lane = threadIdx.x & (kWarpSize - 1); + const int warp_in_block = threadIdx.x / kWarpSize; + const int warps_per_block = blockDim.x / kWarpSize; + const int64_t first_row = static_cast(blockIdx.x) * warps_per_block + warp_in_block; + const int64_t row_stride = static_cast(gridDim.x) * warps_per_block; + const unsigned mask = 0xffffffffu; + + for (int64_t row = first_row; row < row_count; row += row_stride) { + const int receiver = static_cast(row % receiver_count); + const int b = static_cast(row / receiver_count); + const int64_t step_value = use_delay ? step_flat[b] : 0; + const float scale = query_context_scalar[0]; + int sender = -1; + float logit = -std::numeric_limits::infinity(); + if (lane < degree && (!use_delay || static_cast(offset_delay[lane]) <= step_value)) { + sender = receiver_sender_idx[receiver * degree + lane]; + if (sender >= 0 && sender < input_senders + recurrent_senders) { + float dot = 0.0f; + for (int d = 0; d < head_dim; ++d) { + const float q0 = query_slot[receiver * head_dim + d]; + const float q1 = scale * recurrent_v[(static_cast(b) * receiver_count + receiver) * value_dim + d]; + dot += q0 * sender_slot_key[static_cast(sender) * head_dim + d] + + q1 * sender_context_key[static_cast(sender) * head_dim + d]; + } + logit = dot * inv_sqrt_dk - distance_scale * offset_distance[lane]; + } + } + + float max_logit = logit; + for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) { + max_logit = fmaxf(max_logit, __shfl_down_sync(mask, max_logit, offset)); + } + max_logit = __shfl_sync(mask, max_logit, 0); + float norm = sender >= 0 ? 
expf(logit - max_logit) : 0.0f; + for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) { + norm += __shfl_down_sync(mask, norm, offset); + } + norm = __shfl_sync(mask, norm, 0); + + const int64_t out_base = row * value_dim; + for (int v = lane; v < value_dim; v += kWarpSize) { + float weighted_value_v = 0.0f; + if (norm > 0.0f) { + for (int edge = 0; edge < degree; ++edge) { + const int edge_sender = __shfl_sync(mask, sender, edge); + if (edge_sender < 0) { + continue; + } + const float edge_logit = __shfl_sync(mask, logit, edge); + const float weight = expf(edge_logit - max_logit) / norm; + weighted_value_v += weight * + read_partitioned_bank( + input_v, + recurrent_v, + b, + edge_sender, + v, + input_senders, + recurrent_senders, + value_dim); + } + } + weighted_value[out_base + v] = weighted_value_v; + } + } +} + +__global__ void registered_forward_fixed_slot_context_keyless_readout_warp_kernel( + const float* __restrict__ output_q, + const float* __restrict__ sender_slot_key, + const float* __restrict__ input_v, + const float* __restrict__ recurrent_v, + const int32_t* __restrict__ receiver_sender_idx, + const float* __restrict__ offset_distance, + const int32_t* __restrict__ offset_delay, + const int64_t* __restrict__ step_flat, + float* __restrict__ output_msg, + int64_t row_count, + int output_count, + int input_senders, + int recurrent_senders, + int degree, + int head_dim, + int value_dim, + float inv_sqrt_dk, + float distance_scale, + bool use_delay) { + constexpr int kWarpSize = 32; + const int lane = threadIdx.x & (kWarpSize - 1); + const int warp_in_block = threadIdx.x / kWarpSize; + const int warps_per_block = blockDim.x / kWarpSize; + const int64_t first_row = static_cast(blockIdx.x) * warps_per_block + warp_in_block; + const int64_t row_stride = static_cast(gridDim.x) * warps_per_block; + const unsigned mask = 0xffffffffu; + + for (int64_t row = first_row; row < row_count; row += row_stride) { + const int receiver = static_cast(row % output_count); + const int b = static_cast(row / output_count); + const int64_t step_value = use_delay ? step_flat[b] : 0; + int sender = -1; + float logit = -std::numeric_limits::infinity(); + if (lane < degree && (!use_delay || static_cast(offset_delay[lane]) <= step_value)) { + sender = receiver_sender_idx[receiver * degree + lane]; + if (sender >= 0 && sender < input_senders + recurrent_senders) { + float dot = 0.0f; + for (int d = 0; d < head_dim; ++d) { + dot += output_q[receiver * head_dim + d] * + sender_slot_key[static_cast(sender) * head_dim + d]; + } + logit = dot * inv_sqrt_dk - distance_scale * offset_distance[lane]; + } + } + + float max_logit = logit; + for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) { + max_logit = fmaxf(max_logit, __shfl_down_sync(mask, max_logit, offset)); + } + max_logit = __shfl_sync(mask, max_logit, 0); + float norm = sender >= 0 ? 
expf(logit - max_logit) : 0.0f; + for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) { + norm += __shfl_down_sync(mask, norm, offset); + } + norm = __shfl_sync(mask, norm, 0); + + const int64_t out_base = row * value_dim; + for (int v = lane; v < value_dim; v += kWarpSize) { + float value = 0.0f; + if (norm > 0.0f) { + for (int edge = 0; edge < degree; ++edge) { + const int edge_sender = __shfl_sync(mask, sender, edge); + if (edge_sender < 0) { + continue; + } + const float edge_logit = __shfl_sync(mask, logit, edge); + const float weight = expf(edge_logit - max_logit) / norm; + value += weight * + read_partitioned_bank( + input_v, + recurrent_v, + b, + edge_sender, + v, + input_senders, + recurrent_senders, + value_dim); + } + } + output_msg[out_base + v] = value; + } + } +} + +__global__ void registered_forward_fixed_slot_context_direct_keyless_readout_warp_kernel( + const float* __restrict__ output_q, + const float* __restrict__ sender_slot_key, + const float* __restrict__ input_v, + const float* __restrict__ recurrent_hidden, + const float* __restrict__ recurrent_value_weight, + const int32_t* __restrict__ receiver_sender_idx, + const float* __restrict__ offset_distance, + const int32_t* __restrict__ offset_delay, + const int64_t* __restrict__ step_flat, + float* __restrict__ output_msg, + int64_t row_count, + int output_count, + int input_senders, + int recurrent_senders, + int degree, + int head_dim, + int hidden_dim, + int value_dim, + float inv_sqrt_dk, + float distance_scale, + bool use_delay) { + constexpr int kWarpSize = 32; + const int lane = threadIdx.x & (kWarpSize - 1); + const int warp_in_block = threadIdx.x / kWarpSize; + const int warps_per_block = blockDim.x / kWarpSize; + const int64_t first_row = static_cast(blockIdx.x) * warps_per_block + warp_in_block; + const int64_t row_stride = static_cast(gridDim.x) * warps_per_block; + const unsigned mask = 0xffffffffu; + + for (int64_t row = first_row; row < row_count; row += row_stride) { + const int receiver = static_cast(row % output_count); + const int b = static_cast(row / output_count); + const int64_t step_value = use_delay ? step_flat[b] : 0; + int sender = -1; + float logit = -std::numeric_limits::infinity(); + if (lane < degree && (!use_delay || static_cast(offset_delay[lane]) <= step_value)) { + sender = receiver_sender_idx[receiver * degree + lane]; + if (sender >= 0 && sender < input_senders + recurrent_senders) { + float dot = 0.0f; + for (int d = 0; d < head_dim; ++d) { + dot += output_q[receiver * head_dim + d] * + sender_slot_key[static_cast(sender) * head_dim + d]; + } + logit = dot * inv_sqrt_dk - distance_scale * offset_distance[lane]; + } + } + + float max_logit = logit; + for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) { + max_logit = fmaxf(max_logit, __shfl_down_sync(mask, max_logit, offset)); + } + max_logit = __shfl_sync(mask, max_logit, 0); + float norm = sender >= 0 ? 
expf(logit - max_logit) : 0.0f; + for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) { + norm += __shfl_down_sync(mask, norm, offset); + } + norm = __shfl_sync(mask, norm, 0); + + const int64_t out_base = row * value_dim; + for (int v = lane; v < value_dim; v += kWarpSize) { + float value = 0.0f; + if (norm > 0.0f) { + for (int edge = 0; edge < degree; ++edge) { + const int edge_sender = __shfl_sync(mask, sender, edge); + if (edge_sender < 0) { + continue; + } + const float edge_logit = __shfl_sync(mask, logit, edge); + const float weight = expf(edge_logit - max_logit) / norm; + value += weight * + read_partitioned_projected_recurrent_value( + input_v, + recurrent_hidden, + recurrent_value_weight, + b, + edge_sender, + v, + input_senders, + recurrent_senders, + hidden_dim, + value_dim); + } + } + output_msg[out_base + v] = value; + } + } +} + +__global__ void registered_forward_fixed_slot_context_normalize_rows_kernel( + const float* __restrict__ projected, + float* __restrict__ out, + int64_t row_count, + int message_dim, + float eps) { + __shared__ float partial_sum[kThreadsPerBlock]; + __shared__ float partial_square[kThreadsPerBlock]; + __shared__ float row_stats[2]; + + for (int64_t row = static_cast(blockIdx.x); row < row_count; row += gridDim.x) { + const int64_t row_base = row * message_dim; + float thread_sum = 0.0f; + float thread_square = 0.0f; + for (int mm = threadIdx.x; mm < message_dim; mm += blockDim.x) { + const float value = projected[row_base + mm]; + thread_sum += value; + thread_square += value * value; + } + partial_sum[threadIdx.x] = thread_sum; + partial_square[threadIdx.x] = thread_square; + __syncthreads(); + + for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) { + partial_sum[threadIdx.x] += partial_sum[threadIdx.x + stride]; + partial_square[threadIdx.x] += partial_square[threadIdx.x + stride]; + } + __syncthreads(); + } + + if (threadIdx.x == 0) { + const float mean = partial_sum[0] / static_cast(message_dim); + const float variance = fmaxf(partial_square[0] / static_cast(message_dim) - mean * mean, 0.0f); + row_stats[0] = mean; + row_stats[1] = rsqrtf(variance + eps); + } + __syncthreads(); + + for (int mm = threadIdx.x; mm < message_dim; mm += blockDim.x) { + out[row_base + mm] = (projected[row_base + mm] - row_stats[0]) * row_stats[1]; + } + __syncthreads(); + } +} + +__global__ void registered_forward_fixed_slot_context_normalize_rows_inplace_kernel( + float* projected_and_out, + int64_t row_count, + int message_dim, + float eps) { + __shared__ float partial_sum[kThreadsPerBlock]; + __shared__ float partial_square[kThreadsPerBlock]; + __shared__ float row_stats[2]; + + for (int64_t row = static_cast(blockIdx.x); row < row_count; row += gridDim.x) { + const int64_t row_base = row * message_dim; + float thread_sum = 0.0f; + float thread_square = 0.0f; + for (int mm = threadIdx.x; mm < message_dim; mm += blockDim.x) { + const float value = projected_and_out[row_base + mm]; + thread_sum += value; + thread_square += value * value; + } + partial_sum[threadIdx.x] = thread_sum; + partial_square[threadIdx.x] = thread_square; + __syncthreads(); + + for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) { + partial_sum[threadIdx.x] += partial_sum[threadIdx.x + stride]; + partial_square[threadIdx.x] += partial_square[threadIdx.x + stride]; + } + __syncthreads(); + } + + if (threadIdx.x == 0) { + const float mean = partial_sum[0] / static_cast(message_dim); + const float variance = 
fmaxf(partial_square[0] / static_cast(message_dim) - mean * mean, 0.0f); + row_stats[0] = mean; + row_stats[1] = rsqrtf(variance + eps); + } + __syncthreads(); + + for (int mm = threadIdx.x; mm < message_dim; mm += blockDim.x) { + projected_and_out[row_base + mm] = (projected_and_out[row_base + mm] - row_stats[0]) * row_stats[1]; + } + __syncthreads(); + } +} + +__global__ void registered_backward_fixed_slot_context_message_kernel( + const float* __restrict__ grad_msg, + const float* __restrict__ query_slot, + const float* __restrict__ query_context_scalar, + const float* __restrict__ input_k, + const float* __restrict__ input_v, + const float* __restrict__ recurrent_k, + const float* __restrict__ recurrent_v, + const float* __restrict__ output_weight, + const int32_t* __restrict__ receiver_sender_idx, + const float* __restrict__ offset_distance, + const int32_t* __restrict__ offset_delay, + const int64_t* __restrict__ step_flat, + float* __restrict__ grad_query_slot, + float* __restrict__ grad_input_k, + float* __restrict__ grad_input_v, + float* __restrict__ grad_recurrent_k, + float* __restrict__ grad_recurrent_v, + float* __restrict__ grad_query_context_scalar, + float* __restrict__ grad_output_weight, + int64_t receiver_total, + int receiver_count, + int input_senders, + int recurrent_senders, + int degree, + int head_dim, + int value_dim, + int message_dim, + float inv_sqrt_dk, + float distance_scale, + bool use_delay) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < receiver_total; + linear += static_cast(blockDim.x) * gridDim.x) { + const int receiver = static_cast(linear % receiver_count); + const int b = static_cast(linear / receiver_count); + const int64_t step_value = use_delay ? step_flat[b] : 0; + const float scale = query_context_scalar[0]; + float logits[kMaxRegisteredAttentionOffsets]; + int senders[kMaxRegisteredAttentionOffsets]; + float max_logit = -std::numeric_limits::infinity(); + for (int edge = 0; edge < degree; ++edge) { + int sender = -1; + float logit = -std::numeric_limits::infinity(); + if (!use_delay || static_cast(offset_delay[edge]) <= step_value) { + sender = receiver_sender_idx[receiver * degree + edge]; + if (sender >= 0 && sender < input_senders + recurrent_senders) { + float dot = 0.0f; + for (int d = 0; d < head_dim; ++d) { + const float q0 = query_slot[receiver * head_dim + d]; + const float q1 = + scale * recurrent_v[(static_cast(b) * recurrent_senders + receiver) * value_dim + d]; + const float k0 = read_partitioned_bank( + input_k, + recurrent_k, + b, + sender, + d, + input_senders, + recurrent_senders, + 2 * head_dim); + const float k1 = read_partitioned_bank( + input_k, + recurrent_k, + b, + sender, + head_dim + d, + input_senders, + recurrent_senders, + 2 * head_dim); + dot += q0 * k0 + q1 * k1; + } + logit = dot * inv_sqrt_dk - distance_scale * offset_distance[edge]; + max_logit = fmaxf(max_logit, logit); + } + } + senders[edge] = sender; + logits[edge] = logit; + } + float norm = 0.0f; + for (int edge = 0; edge < degree; ++edge) { + const int sender = senders[edge]; + if (sender >= 0) { + norm += expf(logits[edge] - max_logit); + } + } + if (norm <= 0.0f) { + continue; + } + float mean = 0.0f; + float second = 0.0f; + for (int m = 0; m < message_dim; ++m) { + float projected = 0.0f; + for (int edge = 0; edge < degree; ++edge) { + const int sender = senders[edge]; + if (sender < 0) { + continue; + } + const float weight = expf(logits[edge] - max_logit) / norm; + for (int v = 0; v < value_dim; ++v) { + 
projected += weight * + read_partitioned_bank( + input_v, + recurrent_v, + b, + sender, + v, + input_senders, + recurrent_senders, + value_dim) * + output_weight[static_cast(m) * value_dim + v]; + } + } + mean += projected; + second += projected * projected; + } + mean /= static_cast(message_dim); + const float variance = fmaxf(second / static_cast(message_dim) - mean * mean, 0.0f); + const float inv_std = rsqrtf(variance + 1.0e-5f); + float grad_mean = 0.0f; + float grad_xhat_mean = 0.0f; + const int64_t msg_base = (static_cast(b) * receiver_count + receiver) * message_dim; + for (int m = 0; m < message_dim; ++m) { + float projected = 0.0f; + for (int edge = 0; edge < degree; ++edge) { + const int sender = senders[edge]; + if (sender < 0) { + continue; + } + const float weight = expf(logits[edge] - max_logit) / norm; + for (int v = 0; v < value_dim; ++v) { + projected += weight * + read_partitioned_bank( + input_v, + recurrent_v, + b, + sender, + v, + input_senders, + recurrent_senders, + value_dim) * + output_weight[static_cast(m) * value_dim + v]; + } + } + const float xhat = (projected - mean) * inv_std; + const float grad = grad_msg[msg_base + m]; + grad_mean += grad; + grad_xhat_mean += grad * xhat; + } + grad_mean /= static_cast(message_dim); + grad_xhat_mean /= static_cast(message_dim); + for (int m = 0; m < message_dim; ++m) { + float projected = 0.0f; + for (int edge = 0; edge < degree; ++edge) { + const int sender = senders[edge]; + if (sender < 0) { + continue; + } + const float weight = expf(logits[edge] - max_logit) / norm; + for (int v = 0; v < value_dim; ++v) { + projected += weight * + read_partitioned_bank( + input_v, + recurrent_v, + b, + sender, + v, + input_senders, + recurrent_senders, + value_dim) * + output_weight[static_cast(m) * value_dim + v]; + } + } + const float xhat = (projected - mean) * inv_std; + const float grad_projected = + inv_std * (grad_msg[msg_base + m] - grad_mean - xhat * grad_xhat_mean); + for (int v = 0; v < value_dim; ++v) { + float weighted_value = 0.0f; + for (int edge = 0; edge < degree; ++edge) { + const int sender = senders[edge]; + if (sender < 0) { + continue; + } + const float weight = expf(logits[edge] - max_logit) / norm; + weighted_value += weight * + read_partitioned_bank( + input_v, + recurrent_v, + b, + sender, + v, + input_senders, + recurrent_senders, + value_dim); + } + atomicAdd(&grad_output_weight[static_cast(m) * value_dim + v], grad_projected * weighted_value); + } + } + float expected_dweight = 0.0f; + for (int edge = 0; edge < degree; ++edge) { + const int sender = senders[edge]; + if (sender < 0) { + continue; + } + const float weight = expf(logits[edge] - max_logit) / norm; + float dweight = 0.0f; + for (int v = 0; v < value_dim; ++v) { + float grad_weighted_value = 0.0f; + for (int m = 0; m < message_dim; ++m) { + float projected = 0.0f; + for (int inner_edge = 0; inner_edge < degree; ++inner_edge) { + const int inner_sender = senders[inner_edge]; + if (inner_sender < 0) { + continue; + } + const float inner_weight = expf(logits[inner_edge] - max_logit) / norm; + for (int vv = 0; vv < value_dim; ++vv) { + projected += inner_weight * + read_partitioned_bank( + input_v, + recurrent_v, + b, + inner_sender, + vv, + input_senders, + recurrent_senders, + value_dim) * + output_weight[static_cast(m) * value_dim + vv]; + } + } + const float xhat = (projected - mean) * inv_std; + const float grad_projected = + inv_std * (grad_msg[msg_base + m] - grad_mean - xhat * grad_xhat_mean); + grad_weighted_value += grad_projected * 
output_weight[static_cast(m) * value_dim + v]; + } + dweight += grad_weighted_value * + read_partitioned_bank( + input_v, + recurrent_v, + b, + sender, + v, + input_senders, + recurrent_senders, + value_dim); + } + expected_dweight += weight * dweight; + } + for (int edge = 0; edge < degree; ++edge) { + const int sender = senders[edge]; + if (sender < 0) { + continue; + } + const bool is_input = sender < input_senders; + const int bank_sender = is_input ? sender : sender - input_senders; + const int bank_sender_count = is_input ? input_senders : recurrent_senders; + float* grad_k_bank = is_input ? grad_input_k : grad_recurrent_k; + float* grad_v_bank = is_input ? grad_input_v : grad_recurrent_v; + const float* k_bank = is_input ? input_k : recurrent_k; + const float weight = expf(logits[edge] - max_logit) / norm; + float dweight = 0.0f; + for (int v = 0; v < value_dim; ++v) { + float grad_weighted_value = 0.0f; + for (int m = 0; m < message_dim; ++m) { + float projected = 0.0f; + for (int inner_edge = 0; inner_edge < degree; ++inner_edge) { + const int inner_sender = senders[inner_edge]; + if (inner_sender < 0) { + continue; + } + const float inner_weight = expf(logits[inner_edge] - max_logit) / norm; + for (int vv = 0; vv < value_dim; ++vv) { + projected += inner_weight * + read_partitioned_bank( + input_v, + recurrent_v, + b, + inner_sender, + vv, + input_senders, + recurrent_senders, + value_dim) * + output_weight[static_cast(m) * value_dim + vv]; + } + } + const float xhat = (projected - mean) * inv_std; + const float grad_projected = + inv_std * (grad_msg[msg_base + m] - grad_mean - xhat * grad_xhat_mean); + grad_weighted_value += grad_projected * output_weight[static_cast(m) * value_dim + v]; + } + const float sender_value = read_partitioned_bank( + input_v, + recurrent_v, + b, + sender, + v, + input_senders, + recurrent_senders, + value_dim); + dweight += grad_weighted_value * sender_value; + atomicAdd( + &grad_v_bank[(static_cast(b) * bank_sender_count + bank_sender) * value_dim + v], + weight * grad_weighted_value); + } + const float dlogit = weight * (dweight - expected_dweight); + for (int d = 0; d < head_dim; ++d) { + const float q0 = query_slot[receiver * head_dim + d]; + const float q1 = + scale * recurrent_v[(static_cast(b) * recurrent_senders + receiver) * value_dim + d]; + const float k0 = k_bank[(static_cast(b) * bank_sender_count + bank_sender) * (2 * head_dim) + d]; + const float k1 = + k_bank[(static_cast(b) * bank_sender_count + bank_sender) * (2 * head_dim) + head_dim + d]; + const float common = dlogit * inv_sqrt_dk; + atomicAdd(&grad_query_slot[receiver * head_dim + d], common * k0); + atomicAdd( + grad_query_context_scalar, + common * recurrent_v[(static_cast(b) * recurrent_senders + receiver) * value_dim + d] * k1); + atomicAdd( + &grad_recurrent_v[(static_cast(b) * recurrent_senders + receiver) * value_dim + d], + common * k1 * scale); + atomicAdd( + &grad_k_bank[(static_cast(b) * bank_sender_count + bank_sender) * (2 * head_dim) + d], + common * q0); + atomicAdd( + &grad_k_bank[ + (static_cast(b) * bank_sender_count + bank_sender) * (2 * head_dim) + head_dim + d], + common * q1); + } + } + } +} + +__global__ void program_transition_linear_forward_kernel( + const float* __restrict__ input, + const float* __restrict__ weight, + const float* __restrict__ bias, + float* __restrict__ output, + int64_t total_elements, + int receivers, + int input_dim, + int output_dim, + int weight_count, + int bias_count, + int group_size, + bool has_bias, + bool bias_is_shared) { + 
for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int n = static_cast(linear % output_dim); + const int receiver = static_cast((linear / output_dim) % receivers); + const int b = static_cast(linear / (static_cast(output_dim) * receivers)); + const int weight_receiver = + weight_count == 1 ? 0 : weight_count == receivers ? receiver : receiver / group_size; + float value = 0.0f; + for (int k = 0; k < input_dim; ++k) { + value += input[(static_cast(b) * receivers + receiver) * input_dim + k] * + weight[(static_cast(weight_receiver) * input_dim + k) * output_dim + n]; + } + if (has_bias) { + const int bias_receiver = + bias_is_shared || bias_count == 1 ? 0 : bias_count == receivers ? receiver : receiver / group_size; + value += bias_is_shared ? bias[n] : bias[static_cast(bias_receiver) * output_dim + n]; + } + output[linear] = value; + } +} + +__global__ void program_transition_gate_affine_forward_kernel( + const float* __restrict__ input, + const float* __restrict__ weight, + const float* __restrict__ bias, + float* __restrict__ output, + int64_t total_elements, + int batch_size, + int receivers, + int heads, + int head_dim, + bool has_bias) { + const int hidden = heads * head_dim; + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int d = static_cast(linear % hidden); + const int gate = static_cast((linear / hidden) % 4); + const int receiver = static_cast((linear / (static_cast(hidden) * 4)) % receivers); + const int b = static_cast(linear / (static_cast(hidden) * 4 * receivers)); + const int head = d / head_dim; + const int head_offset = d - head * head_dim; + const int output_dim_per_head = 4 * head_dim; + const int n = gate * head_dim + head_offset; + float value = 0.0f; + for (int k = 0; k < head_dim; ++k) { + value += input[(static_cast(b) * receivers + receiver) * hidden + head * head_dim + k] * + weight[((static_cast(receiver) * heads + head) * head_dim + k) * output_dim_per_head + n]; + } + if (has_bias) { + value += bias[((static_cast(receiver) * 4 + gate) * heads + head) * head_dim + head_offset]; + } + output[linear] = value; + } +} + +__global__ void program_transition_linear_input_backward_kernel( + const float* __restrict__ grad_output, + const float* __restrict__ weight, + float* __restrict__ grad_input, + int64_t total_elements, + int receivers, + int input_dim, + int output_dim, + int weight_count, + int group_size) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int k = static_cast(linear % input_dim); + const int receiver = static_cast((linear / input_dim) % receivers); + const int b = static_cast(linear / (static_cast(input_dim) * receivers)); + const int weight_receiver = + weight_count == 1 ? 0 : weight_count == receivers ? 
receiver : receiver / group_size; + float value = 0.0f; + for (int n = 0; n < output_dim; ++n) { + value += grad_output[(static_cast(b) * receivers + receiver) * output_dim + n] * + weight[(static_cast(weight_receiver) * input_dim + k) * output_dim + n]; + } + grad_input[linear] = value; + } +} + +__global__ void program_transition_linear_weight_backward_kernel( + const float* __restrict__ input, + const float* __restrict__ grad_output, + float* __restrict__ grad_weight, + int64_t total_elements, + int batch_size, + int receivers, + int input_dim, + int output_dim, + int weight_count, + int group_size) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int n = static_cast(linear % output_dim); + const int k = static_cast((linear / output_dim) % input_dim); + const int owner = static_cast(linear / (static_cast(output_dim) * input_dim)); + float value = 0.0f; + const int receiver_begin = weight_count == 1 ? 0 : weight_count == receivers ? owner : owner * group_size; + const int receiver_group_stop = (owner + 1) * group_size; + const int receiver_end = + weight_count == 1 ? receivers : weight_count == receivers ? owner + 1 : + (receiver_group_stop < receivers ? receiver_group_stop : receivers); + for (int b = 0; b < batch_size; ++b) { + for (int receiver = receiver_begin; receiver < receiver_end; ++receiver) { + value += input[(static_cast(b) * receivers + receiver) * input_dim + k] * + grad_output[(static_cast(b) * receivers + receiver) * output_dim + n]; + } + } + grad_weight[linear] = value; + } +} + +__global__ void program_transition_linear_bias_backward_kernel( + const float* __restrict__ grad_output, + float* __restrict__ grad_bias, + int64_t total_elements, + int batch_size, + int receivers, + int output_dim, + int bias_count, + int group_size, + bool bias_is_shared) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int n = static_cast(linear % output_dim); + const int owner = bias_is_shared ? 0 : static_cast(linear / output_dim); + const int receiver_begin = bias_is_shared || bias_count == 1 ? 0 : bias_count == receivers ? owner : owner * group_size; + const int receiver_group_stop = (owner + 1) * group_size; + const int receiver_end = bias_is_shared || bias_count == 1 + ? receivers + : bias_count == receivers ? owner + 1 : (receiver_group_stop < receivers ? 
receiver_group_stop : receivers); + float value = 0.0f; + for (int b = 0; b < batch_size; ++b) { + for (int receiver = receiver_begin; receiver < receiver_end; ++receiver) { + value += grad_output[(static_cast(b) * receivers + receiver) * output_dim + n]; + } + } + grad_bias[linear] = value; + } +} + +__global__ void program_transition_recurrent_matmul_forward_kernel( + const float* __restrict__ input, + const float* __restrict__ recurrent_kernel, + float* __restrict__ output, + int64_t total_elements, + int receivers, + int gate_count, + int head_count, + int head_dim) { + const int hidden_dim = head_count * head_dim; + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int hidden_out = static_cast(linear % hidden_dim); + const int gate = static_cast((linear / hidden_dim) % gate_count); + const int receiver = static_cast((linear / (static_cast(hidden_dim) * gate_count)) % receivers); + const int b = static_cast(linear / (static_cast(hidden_dim) * gate_count * receivers)); + const int head = hidden_out / head_dim; + const int out_dim = hidden_out - head * head_dim; + float value = 0.0f; + for (int in_dim = 0; in_dim < head_dim; ++in_dim) { + value += input[(static_cast(b) * receivers + receiver) * hidden_dim + head * head_dim + in_dim] * + recurrent_kernel[((((static_cast(receiver) * gate_count + gate) * head_count + head) * + head_dim + out_dim) * + head_dim) + + in_dim]; + } + output[linear] = value; + } +} + +__global__ void program_transition_recurrent_matmul_input_backward_kernel( + const float* __restrict__ grad_output, + const float* __restrict__ recurrent_kernel, + float* __restrict__ grad_input, + int64_t total_elements, + int receivers, + int gate_count, + int head_count, + int head_dim) { + const int hidden_dim = head_count * head_dim; + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int hidden_in = static_cast(linear % hidden_dim); + const int receiver = static_cast((linear / hidden_dim) % receivers); + const int b = static_cast(linear / (static_cast(hidden_dim) * receivers)); + const int head = hidden_in / head_dim; + const int in_dim = hidden_in - head * head_dim; + float value = 0.0f; + for (int gate = 0; gate < gate_count; ++gate) { + for (int out_dim = 0; out_dim < head_dim; ++out_dim) { + value += grad_output[((static_cast(b) * receivers + receiver) * gate_count + gate) * + hidden_dim + head * head_dim + out_dim] * + recurrent_kernel[((((static_cast(receiver) * gate_count + gate) * head_count + head) * + head_dim + out_dim) * + head_dim) + + in_dim]; + } + } + grad_input[linear] = value; + } +} + +__global__ void program_transition_recurrent_matmul_weight_backward_kernel( + const float* __restrict__ input, + const float* __restrict__ grad_output, + float* __restrict__ grad_kernel, + int64_t total_elements, + int batch_size, + int receivers, + int gate_count, + int head_count, + int head_dim) { + const int hidden_dim = head_count * head_dim; + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int in_dim = static_cast(linear % head_dim); + const int out_dim = static_cast((linear / head_dim) % head_dim); + const int head = static_cast((linear / (static_cast(head_dim) * head_dim)) % head_count); + const int gate = static_cast( + (linear / 
(static_cast(head_dim) * head_dim * head_count)) % gate_count); + const int receiver = static_cast( + linear / (static_cast(head_dim) * head_dim * head_count * gate_count)); + float value = 0.0f; + for (int b = 0; b < batch_size; ++b) { + value += input[(static_cast(b) * receivers + receiver) * hidden_dim + head * head_dim + in_dim] * + grad_output[((static_cast(b) * receivers + receiver) * gate_count + gate) * + hidden_dim + head * head_dim + out_dim]; + } + grad_kernel[linear] = value; + } +} + +__global__ void program_transition_tanh_forward_kernel( + const float* __restrict__ input, + float* __restrict__ output, + int64_t total_elements) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + output[linear] = tanhf(input[linear]); + } +} + +__global__ void program_transition_tanh_backward_kernel( + const float* __restrict__ output, + const float* __restrict__ grad_output, + float* __restrict__ grad_input, + int64_t total_elements) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const float value = output[linear]; + grad_input[linear] = grad_output[linear] * (1.0f - value * value); + } +} + +__device__ inline float sigmoidf_stable(float value) { + return 1.0f / (1.0f + expf(-value)); +} + +__device__ inline float logsigmoidf_stable(float value) { + return -logf(1.0f + expf(-value)); +} + +__device__ inline void gated_logspace_core_values( + float iraw, + float fraw, + float zraw, + float oraw, + float prev_c, + float prev_n, + float prev_m, + float* next_y, + float* next_c, + float* next_n, + float* next_m) { + const float logfplusm = prev_m + logsigmoidf_stable(fraw); + const bool is_first = prev_n == 0.0f; + const float m_out = is_first ? iraw : fmaxf(iraw, logfplusm); + const float i_exp = expf(iraw - m_out); + const float f_exp = expf(logfplusm - m_out); + const float i_gate = fminf(i_exp, 1.0f); + const float f_gate = fminf(f_exp, 1.0f); + const float tanh_z = tanhf(zraw); + const float o_gate = sigmoidf_stable(oraw); + const float c_out = f_gate * prev_c + i_gate * tanh_z; + const float n_out = f_gate * prev_n + i_gate; + *next_y = o_gate * c_out / (n_out + 1.0e-6f); + *next_c = c_out; + *next_n = n_out; + *next_m = m_out; +} + +__global__ void program_transition_gated_logspace_recurrence_forward_kernel( + const float* __restrict__ gate_logits, + const float* __restrict__ recurrent_gate_logits, + const float* __restrict__ c_prev, + const float* __restrict__ n_prev, + const float* __restrict__ m_prev, + float* __restrict__ next_y, + float* __restrict__ next_c, + float* __restrict__ next_n, + float* __restrict__ next_m, + int64_t total_elements, + int receivers, + int hidden, + bool has_recurrent_gate_logits, + bool has_prev_state) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int h = static_cast(linear % hidden); + const int receiver = static_cast((linear / hidden) % receivers); + const int b = static_cast(linear / (static_cast(hidden) * receivers)); + const int64_t state_offset = (static_cast(b) * receivers + receiver) * hidden + h; + const int64_t gate_base = ((static_cast(b) * receivers + receiver) * 4) * hidden + h; + const float recurrent_i = has_recurrent_gate_logits ? 
recurrent_gate_logits[gate_base] : 0.0f; + const float recurrent_f = has_recurrent_gate_logits ? recurrent_gate_logits[gate_base + hidden] : 0.0f; + const float recurrent_z = has_recurrent_gate_logits ? recurrent_gate_logits[gate_base + 2 * hidden] : 0.0f; + const float recurrent_o = has_recurrent_gate_logits ? recurrent_gate_logits[gate_base + 3 * hidden] : 0.0f; + const float iraw = gate_logits[gate_base] + recurrent_i; + const float fraw = gate_logits[gate_base + hidden] + recurrent_f; + const float zraw = gate_logits[gate_base + 2 * hidden] + recurrent_z; + const float oraw = gate_logits[gate_base + 3 * hidden] + recurrent_o; + const float prev_c = has_prev_state ? c_prev[state_offset] : 0.0f; + const float prev_n = has_prev_state ? n_prev[state_offset] : 0.0f; + const float prev_m = has_prev_state ? m_prev[state_offset] : 0.0f; + float y_out = 0.0f; + float c_out = 0.0f; + float n_out = 0.0f; + float m_out = 0.0f; + gated_logspace_core_values( + iraw, + fraw, + zraw, + oraw, + prev_c, + prev_n, + prev_m, + &y_out, + &c_out, + &n_out, + &m_out); + next_y[state_offset] = y_out; + if (next_c != nullptr) { + next_c[state_offset] = c_out; + } + if (next_n != nullptr) { + next_n[state_offset] = n_out; + } + if (next_m != nullptr) { + next_m[state_offset] = m_out; + } + } +} + +__global__ void program_transition_gated_logspace_recurrence_backward_kernel( + const float* __restrict__ gate_logits, + const float* __restrict__ recurrent_gate_logits, + const float* __restrict__ c_prev, + const float* __restrict__ n_prev, + const float* __restrict__ m_prev, + const float* __restrict__ grad_next_y, + const float* __restrict__ grad_next_c, + const float* __restrict__ grad_next_n, + const float* __restrict__ grad_next_m, + float* __restrict__ grad_raw, + float* __restrict__ grad_c_prev, + float* __restrict__ grad_n_prev, + float* __restrict__ grad_m_prev, + int64_t total_elements, + int receivers, + int hidden, + bool has_recurrent_gate_logits, + bool has_prev_state, + bool has_grad_next_y, + bool has_grad_next_c, + bool has_grad_next_n, + bool has_grad_next_m, + bool return_state_grads) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int h = static_cast(linear % hidden); + const int receiver = static_cast((linear / hidden) % receivers); + const int b = static_cast(linear / (static_cast(hidden) * receivers)); + const int64_t state_offset = (static_cast(b) * receivers + receiver) * hidden + h; + const int64_t gate_base = ((static_cast(b) * receivers + receiver) * 4) * hidden + h; + const float recurrent_i = has_recurrent_gate_logits ? recurrent_gate_logits[gate_base] : 0.0f; + const float recurrent_f = has_recurrent_gate_logits ? recurrent_gate_logits[gate_base + hidden] : 0.0f; + const float recurrent_z = has_recurrent_gate_logits ? recurrent_gate_logits[gate_base + 2 * hidden] : 0.0f; + const float recurrent_o = has_recurrent_gate_logits ? recurrent_gate_logits[gate_base + 3 * hidden] : 0.0f; + const float iraw = gate_logits[gate_base] + recurrent_i; + const float fraw = gate_logits[gate_base + hidden] + recurrent_f; + const float zraw = gate_logits[gate_base + 2 * hidden] + recurrent_z; + const float oraw = gate_logits[gate_base + 3 * hidden] + recurrent_o; + const float prev_c = has_prev_state ? c_prev[state_offset] : 0.0f; + const float prev_n = has_prev_state ? n_prev[state_offset] : 0.0f; + const float prev_m = has_prev_state ? 
m_prev[state_offset] : 0.0f; + const float logfplusm = prev_m + logsigmoidf_stable(fraw); + const bool is_first = prev_n == 0.0f; + const float m_out = is_first ? iraw : fmaxf(iraw, logfplusm); + const float i_exp = expf(iraw - m_out); + const float f_exp = expf(logfplusm - m_out); + const float i_gate = fminf(i_exp, 1.0f); + const float f_gate = fminf(f_exp, 1.0f); + const float tanh_z = tanhf(zraw); + const float o_gate = sigmoidf_stable(oraw); + const float c_out = f_gate * prev_c + i_gate * tanh_z; + const float n_out = f_gate * prev_n + i_gate; + const float denom = n_out + 1.0e-6f; + + const float gy = has_grad_next_y ? grad_next_y[state_offset] : 0.0f; + float gc = has_grad_next_c ? grad_next_c[state_offset] : 0.0f; + float gn = has_grad_next_n ? grad_next_n[state_offset] : 0.0f; + float gm = has_grad_next_m ? grad_next_m[state_offset] : 0.0f; + + const float grad_o = gy * c_out / denom; + gc += gy * o_gate / denom; + gn += -gy * o_gate * c_out / (denom * denom); + const float grad_f = gc * prev_c + gn * prev_n; + const float grad_i = gc * tanh_z + gn; + const float grad_zraw = gc * i_gate * (1.0f - tanh_z * tanh_z); + const float grad_oraw = grad_o * o_gate * (1.0f - o_gate); + const float out_grad_c_prev = gc * f_gate; + const float out_grad_n_prev = gn * f_gate; + + const float grad_i_exp = i_exp < 1.0f ? grad_i : 0.0f; + const float grad_f_exp = f_exp < 1.0f ? grad_f : 0.0f; + const float grad_i_arg = grad_i_exp * i_exp; + const float grad_f_arg = grad_f_exp * f_exp; + float grad_iraw = grad_i_arg; + float grad_logfplusm = grad_f_arg; + gm = gm - grad_i_arg - grad_f_arg; + const bool choose_iraw = iraw >= logfplusm; + grad_iraw += (is_first || choose_iraw) ? gm : 0.0f; + grad_logfplusm += ((!is_first) && (!choose_iraw)) ? gm : 0.0f; + const float out_grad_m_prev = grad_logfplusm; + const float grad_fraw = grad_logfplusm * sigmoidf_stable(-fraw); + + grad_raw[gate_base] = grad_iraw; + grad_raw[gate_base + hidden] = grad_fraw; + grad_raw[gate_base + 2 * hidden] = grad_zraw; + grad_raw[gate_base + 3 * hidden] = grad_oraw; + if (return_state_grads) { + grad_c_prev[state_offset] = has_prev_state ? out_grad_c_prev : 0.0f; + grad_n_prev[state_offset] = has_prev_state ? out_grad_n_prev : 0.0f; + grad_m_prev[state_offset] = has_prev_state ? out_grad_m_prev : 0.0f; + } + } +} + +__global__ void program_transition_norm_or_identity_forward_kernel( + const float* __restrict__ input, + const float* __restrict__ weight, + float* __restrict__ output, + int64_t total_elements, + int receivers, + int hidden, + bool has_weight, + bool weight_is_shared, + float eps) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + if (!has_weight) { + output[linear] = input[linear]; + continue; + } + const int h = static_cast(linear % hidden); + const int receiver = static_cast((linear / hidden) % receivers); + const int64_t row_base = (linear / hidden) * hidden; + float mean = 0.0f; + for (int k = 0; k < hidden; ++k) { + mean += input[row_base + k]; + } + mean /= static_cast(hidden); + float second = 0.0f; + for (int k = 0; k < hidden; ++k) { + const float value = input[row_base + k]; + second += value * value; + } + const float var = fmaxf(second / static_cast(hidden) - mean * mean, 0.0f); + const float norm = (input[linear] - mean) * rsqrtf(var + eps); + const float scale = weight[weight_is_shared ? 
h : static_cast<int64_t>(receiver) * hidden + h];
+    output[linear] = norm * scale;
+  }
+}
+
+__global__ void program_transition_norm_or_identity_input_backward_kernel(
+    const float* __restrict__ input,
+    const float* __restrict__ weight,
+    const float* __restrict__ grad_output,
+    float* __restrict__ grad_input,
+    int64_t total_elements,
+    int receivers,
+    int hidden,
+    bool has_weight,
+    bool weight_is_shared,
+    float eps) {
+  for (int64_t linear = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+       linear < total_elements;
+       linear += static_cast<int64_t>(blockDim.x) * gridDim.x) {
+    if (!has_weight) {
+      grad_input[linear] = grad_output[linear];
+      continue;
+    }
+    const int h = static_cast<int>(linear % hidden);
+    const int receiver = static_cast<int>((linear / hidden) % receivers);
+    const int64_t row_base = (linear / hidden) * hidden;
+    float mean = 0.0f;
+    for (int k = 0; k < hidden; ++k) {
+      mean += input[row_base + k];
+    }
+    mean /= static_cast<float>(hidden);
+    float second = 0.0f;
+    for (int k = 0; k < hidden; ++k) {
+      const float value = input[row_base + k];
+      second += value * value;
+    }
+    const float var = fmaxf(second / static_cast<float>(hidden) - mean * mean, 0.0f);
+    const float inv_std = rsqrtf(var + eps);
+    float grad_sum = 0.0f;
+    float grad_xhat_sum = 0.0f;
+    for (int k = 0; k < hidden; ++k) {
+      const float xhat = (input[row_base + k] - mean) * inv_std;
+      const float scale = weight[weight_is_shared ? k : static_cast<int64_t>(receiver) * hidden + k];
+      const float grad_norm = grad_output[row_base + k] * scale;
+      grad_sum += grad_norm;
+      grad_xhat_sum += grad_norm * xhat;
+    }
+    const float xhat_h = (input[linear] - mean) * inv_std;
+    const float scale_h = weight[weight_is_shared ? h : static_cast<int64_t>(receiver) * hidden + h];
+    const float grad_norm_h = grad_output[linear] * scale_h;
+    grad_input[linear] = (inv_std / static_cast<float>(hidden)) *
+        (static_cast<float>(hidden) * grad_norm_h - grad_sum - xhat_h * grad_xhat_sum);
+  }
+}
+
+__global__ void program_transition_norm_or_identity_weight_backward_kernel(
+    const float* __restrict__ input,
+    const float* __restrict__ grad_output,
+    float* __restrict__ grad_weight,
+    int64_t total_elements,
+    int batch_size,
+    int receivers,
+    int hidden,
+    bool weight_is_shared,
+    float eps) {
+  for (int64_t linear = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+       linear < total_elements;
+       linear += static_cast<int64_t>(blockDim.x) * gridDim.x) {
+    const int h = static_cast<int>(linear % hidden);
+    const int owner = weight_is_shared ? 0 : static_cast<int>(linear / hidden);
+    const int receiver_begin = weight_is_shared ? 0 : owner;
+    const int receiver_end = weight_is_shared ?
receivers : owner + 1; + float accum = 0.0f; + for (int b = 0; b < batch_size; ++b) { + for (int receiver = receiver_begin; receiver < receiver_end; ++receiver) { + const int64_t row_base = (static_cast(b) * receivers + receiver) * hidden; + float mean = 0.0f; + for (int k = 0; k < hidden; ++k) { + mean += input[row_base + k]; + } + mean /= static_cast(hidden); + float second = 0.0f; + for (int k = 0; k < hidden; ++k) { + const float value = input[row_base + k]; + second += value * value; + } + const float var = fmaxf(second / static_cast(hidden) - mean * mean, 0.0f); + const float xhat = (input[row_base + h] - mean) * rsqrtf(var + eps); + accum += grad_output[row_base + h] * xhat; + } + } + grad_weight[linear] = accum; + } +} + +__device__ inline float program_diagonal_activation_forward(float value, int activation_id) { + if (activation_id == 0) { + return value * sigmoidf_stable(value); + } + if (activation_id == 1) { + return fmaxf(value, 0.0f); + } + if (activation_id == 2) { + return tanhf(value); + } + return value; +} + +__device__ inline float program_diagonal_activation_grad(float value, int activation_id) { + if (activation_id == 0) { + const float sig = sigmoidf_stable(value); + return sig * (1.0f + value * (1.0f - sig)); + } + if (activation_id == 1) { + return value > 0.0f ? 1.0f : 0.0f; + } + if (activation_id == 2) { + const float tanh_value = tanhf(value); + return 1.0f - tanh_value * tanh_value; + } + return 1.0f; +} + +__device__ inline void program_diagonal_coefficients( + float nu_log, + float theta_log, + float* g, + float* phi, + float* gamma, + float* d_g_d_nu, + float* d_phi_d_nu, + float* d_gamma_d_nu, + float* d_g_d_theta, + float* d_phi_d_theta) { + const float exp_nu = expf(nu_log); + const float radius = expf(-exp_nu); + const float theta = expf(theta_log); + const float sin_theta = sinf(theta); + const float cos_theta = cosf(theta); + *g = radius * cos_theta; + *phi = radius * sin_theta; + const float radius_sq = radius * radius; + *gamma = sqrtf(fmaxf(1.0f - radius_sq, 0.0f)); + const float gamma_safe = fmaxf(*gamma, 1.0e-20f); + *d_g_d_nu = -exp_nu * (*g); + *d_phi_d_nu = -exp_nu * (*phi); + *d_gamma_d_nu = exp_nu * radius_sq / gamma_safe; + *d_g_d_theta = -(*phi) * theta; + *d_phi_d_theta = (*g) * theta; +} + +__global__ void program_transition_diag_rtu_forward_kernel( + const float* __restrict__ cell_input, + const float* __restrict__ hc1, + const float* __restrict__ hc2, + const float* __restrict__ e_nu_c1, + const float* __restrict__ e_nu_c2, + const float* __restrict__ e_th_c1, + const float* __restrict__ e_th_c2, + const float* __restrict__ e_w1_c1, + const float* __restrict__ e_w1_c2, + const float* __restrict__ e_w2_c1, + const float* __restrict__ e_w2_c2, + const float* __restrict__ nu_log, + const float* __restrict__ theta_log, + const float* __restrict__ w1, + const float* __restrict__ w2, + float* __restrict__ preproj, + float* __restrict__ next_hc1, + float* __restrict__ next_hc2, + float* __restrict__ next_e_nu_c1, + float* __restrict__ next_e_nu_c2, + float* __restrict__ next_e_th_c1, + float* __restrict__ next_e_th_c2, + float* __restrict__ next_e_w1_c1, + float* __restrict__ next_e_w1_c2, + float* __restrict__ next_e_w2_c1, + float* __restrict__ next_e_w2_c2, + int64_t total_elements, + int receivers, + int hidden, + int activation_id, + bool has_prev_state, + bool has_trace_state, + bool write_trace_state_next) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += 
static_cast(blockDim.x) * gridDim.x) { + const int h = static_cast(linear % hidden); + const int receiver = static_cast((linear / hidden) % receivers); + const int64_t param_offset = static_cast(receiver) * hidden + h; + const float x = cell_input[linear]; + const float prev_c1 = has_prev_state ? hc1[linear] : 0.0f; + const float prev_c2 = has_prev_state ? hc2[linear] : 0.0f; + const float w1_value = w1[param_offset]; + const float w2_value = w2[param_offset]; + float g = 0.0f; + float phi = 0.0f; + float gamma = 0.0f; + float d_g_d_nu = 0.0f; + float d_phi_d_nu = 0.0f; + float d_gamma_d_nu = 0.0f; + float d_g_d_theta = 0.0f; + float d_phi_d_theta = 0.0f; + program_diagonal_coefficients( + nu_log[param_offset], + theta_log[param_offset], + &g, + &phi, + &gamma, + &d_g_d_nu, + &d_phi_d_nu, + &d_gamma_d_nu, + &d_g_d_theta, + &d_phi_d_theta); + const float c1 = gamma * w1_value * x + g * prev_c1 - phi * prev_c2; + const float c2 = gamma * w2_value * x + g * prev_c2 + phi * prev_c1; + const int64_t preproj_base = (linear / hidden) * (2 * hidden) + h; + preproj[preproj_base] = program_diagonal_activation_forward(c1, activation_id); + preproj[preproj_base + hidden] = program_diagonal_activation_forward(c2, activation_id); + if (next_hc1 != nullptr) { + next_hc1[linear] = c1; + } + if (next_hc2 != nullptr) { + next_hc2[linear] = c2; + } + if (write_trace_state_next) { + const float prev_e_nu_c1 = has_trace_state ? e_nu_c1[linear] : 0.0f; + const float prev_e_nu_c2 = has_trace_state ? e_nu_c2[linear] : 0.0f; + const float prev_e_th_c1 = has_trace_state ? e_th_c1[linear] : 0.0f; + const float prev_e_th_c2 = has_trace_state ? e_th_c2[linear] : 0.0f; + const float prev_e_w1_c1 = has_trace_state ? e_w1_c1[linear] : 0.0f; + const float prev_e_w1_c2 = has_trace_state ? e_w1_c2[linear] : 0.0f; + const float prev_e_w2_c1 = has_trace_state ? e_w2_c1[linear] : 0.0f; + const float prev_e_w2_c2 = has_trace_state ? 
e_w2_c2[linear] : 0.0f; + next_e_w1_c1[linear] = gamma * x + g * prev_e_w1_c1 - phi * prev_e_w1_c2; + next_e_w1_c2[linear] = g * prev_e_w1_c2 + phi * prev_e_w1_c1; + next_e_w2_c2[linear] = gamma * x + g * prev_e_w2_c2 + phi * prev_e_w2_c1; + next_e_w2_c1[linear] = g * prev_e_w2_c1 - phi * prev_e_w2_c2; + next_e_nu_c1[linear] = d_g_d_nu * prev_c1 + g * prev_e_nu_c1 - d_phi_d_nu * prev_c2 - + phi * prev_e_nu_c2 + d_gamma_d_nu * w1_value * x; + next_e_nu_c2[linear] = d_g_d_nu * prev_c2 + g * prev_e_nu_c2 + d_phi_d_nu * prev_c1 + + phi * prev_e_nu_c1 + d_gamma_d_nu * w2_value * x; + next_e_th_c1[linear] = + d_g_d_theta * prev_c1 + g * prev_e_th_c1 - d_phi_d_theta * prev_c2 - phi * prev_e_th_c2; + next_e_th_c2[linear] = + d_g_d_theta * prev_c2 + g * prev_e_th_c2 + d_phi_d_theta * prev_c1 + phi * prev_e_th_c1; + } + } +} + +__global__ void program_transition_diag_rtu_input_backward_kernel( + const float* __restrict__ cell_input, + const float* __restrict__ hc1, + const float* __restrict__ hc2, + const float* __restrict__ nu_log, + const float* __restrict__ theta_log, + const float* __restrict__ w1, + const float* __restrict__ w2, + const float* __restrict__ grad_preproj, + const float* __restrict__ grad_hc1_next, + const float* __restrict__ grad_hc2_next, + float* __restrict__ grad_cell_input, + float* __restrict__ grad_hc1, + float* __restrict__ grad_hc2, + int64_t total_elements, + int receivers, + int hidden, + int activation_id, + bool has_grad_preproj, + bool has_grad_hc1_next, + bool has_grad_hc2_next, + bool return_state_grads) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int h = static_cast(linear % hidden); + const int receiver = static_cast((linear / hidden) % receivers); + const int64_t param_offset = static_cast(receiver) * hidden + h; + const float x = cell_input[linear]; + const float prev_c1 = hc1[linear]; + const float prev_c2 = hc2[linear]; + const float w1_value = w1[param_offset]; + const float w2_value = w2[param_offset]; + float g = 0.0f; + float phi = 0.0f; + float gamma = 0.0f; + float d_g_d_nu = 0.0f; + float d_phi_d_nu = 0.0f; + float d_gamma_d_nu = 0.0f; + float d_g_d_theta = 0.0f; + float d_phi_d_theta = 0.0f; + program_diagonal_coefficients( + nu_log[param_offset], + theta_log[param_offset], + &g, + &phi, + &gamma, + &d_g_d_nu, + &d_phi_d_nu, + &d_gamma_d_nu, + &d_g_d_theta, + &d_phi_d_theta); + const float c1 = gamma * w1_value * x + g * prev_c1 - phi * prev_c2; + const float c2 = gamma * w2_value * x + g * prev_c2 + phi * prev_c1; + float grad_c1 = has_grad_hc1_next ? grad_hc1_next[linear] : 0.0f; + float grad_c2 = has_grad_hc2_next ? 
grad_hc2_next[linear] : 0.0f; + if (has_grad_preproj) { + const int64_t preproj_base = (linear / hidden) * (2 * hidden) + h; + grad_c1 += grad_preproj[preproj_base] * program_diagonal_activation_grad(c1, activation_id); + grad_c2 += grad_preproj[preproj_base + hidden] * program_diagonal_activation_grad(c2, activation_id); + } + grad_cell_input[linear] = grad_c1 * gamma * w1_value + grad_c2 * gamma * w2_value; + if (return_state_grads) { + grad_hc1[linear] = grad_c1 * g + grad_c2 * phi; + grad_hc2[linear] = -grad_c1 * phi + grad_c2 * g; + } + } +} + +__global__ void program_transition_diag_rtu_param_backward_kernel( + const float* __restrict__ cell_input, + const float* __restrict__ hc1, + const float* __restrict__ hc2, + const float* __restrict__ nu_log, + const float* __restrict__ theta_log, + const float* __restrict__ w1, + const float* __restrict__ w2, + const float* __restrict__ grad_preproj, + const float* __restrict__ grad_hc1_next, + const float* __restrict__ grad_hc2_next, + float* __restrict__ grad_nu_log, + float* __restrict__ grad_theta_log, + float* __restrict__ grad_w1, + float* __restrict__ grad_w2, + int64_t total_elements, + int batch_size, + int receivers, + int hidden, + int activation_id, + bool has_grad_preproj, + bool has_grad_hc1_next, + bool has_grad_hc2_next) { + for (int64_t linear = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + linear < total_elements; + linear += static_cast(blockDim.x) * gridDim.x) { + const int h = static_cast(linear % hidden); + const int receiver = static_cast(linear / hidden); + const int64_t param_offset = static_cast(receiver) * hidden + h; + float grad_nu = 0.0f; + float grad_theta = 0.0f; + float out_grad_w1 = 0.0f; + float out_grad_w2 = 0.0f; + for (int b = 0; b < batch_size; ++b) { + const int64_t state_offset = (static_cast(b) * receivers + receiver) * hidden + h; + const float x = cell_input[state_offset]; + const float prev_c1 = hc1[state_offset]; + const float prev_c2 = hc2[state_offset]; + const float w1_value = w1[param_offset]; + const float w2_value = w2[param_offset]; + float g = 0.0f; + float phi = 0.0f; + float gamma = 0.0f; + float d_g_d_nu = 0.0f; + float d_phi_d_nu = 0.0f; + float d_gamma_d_nu = 0.0f; + float d_g_d_theta = 0.0f; + float d_phi_d_theta = 0.0f; + program_diagonal_coefficients( + nu_log[param_offset], + theta_log[param_offset], + &g, + &phi, + &gamma, + &d_g_d_nu, + &d_phi_d_nu, + &d_gamma_d_nu, + &d_g_d_theta, + &d_phi_d_theta); + const float c1 = gamma * w1_value * x + g * prev_c1 - phi * prev_c2; + const float c2 = gamma * w2_value * x + g * prev_c2 + phi * prev_c1; + float grad_c1 = has_grad_hc1_next ? grad_hc1_next[state_offset] : 0.0f; + float grad_c2 = has_grad_hc2_next ? 
grad_hc2_next[state_offset] : 0.0f; + if (has_grad_preproj) { + const int64_t preproj_base = (state_offset / hidden) * (2 * hidden) + h; + grad_c1 += grad_preproj[preproj_base] * program_diagonal_activation_grad(c1, activation_id); + grad_c2 += grad_preproj[preproj_base + hidden] * program_diagonal_activation_grad(c2, activation_id); + } + const float grad_g = grad_c1 * prev_c1 + grad_c2 * prev_c2; + const float grad_phi = -grad_c1 * prev_c2 + grad_c2 * prev_c1; + const float grad_gamma = grad_c1 * w1_value * x + grad_c2 * w2_value * x; + grad_nu += grad_g * d_g_d_nu + grad_phi * d_phi_d_nu + grad_gamma * d_gamma_d_nu; + grad_theta += grad_g * d_g_d_theta + grad_phi * d_phi_d_theta; + out_grad_w1 += grad_c1 * gamma * x; + out_grad_w2 += grad_c2 * gamma * x; + } + grad_nu_log[param_offset] = grad_nu; + grad_theta_log[param_offset] = grad_theta; + grad_w1[param_offset] = out_grad_w1; + grad_w2[param_offset] = out_grad_w2; + } +} diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_forward_program.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_forward_program.cuh new file mode 100644 index 00000000..9bffbd6d --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_forward_program.cuh @@ -0,0 +1,2872 @@ +#pragma once + +std::vector flat_bucket_registered_temporal_fused_forward_program_validate_cuda( + const at::Tensor& primitive_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& forward_handler_rows, + const at::Tensor& reverse_handler_rows, + const at::Tensor& native_strategy_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& memory_liveness_rows, + int64_t schema_version) { + return validate_registered_temporal_fused_program( + primitive_rows, + forward_executor_rows, + reverse_executor_rows, + forward_handler_rows, + reverse_handler_rows, + native_strategy_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + forward_executor_binding_rows, + reverse_executor_binding_rows, + memory_liveness_rows, + schema_version); +} + +using RegisteredTransitionForwardPrimitiveRunFn = void (*)( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t primitive_row_index, + int64_t primitive_opcode, + int64_t native_callable_hash, + const std::vector& inputs, + const std::vector& params, + const std::vector& outputs); + +struct RegisteredTransitionForwardPrimitiveExecutor { + int64_t native_callable_hash; + const char* name; + RegisteredTransitionForwardPrimitiveRunFn run; +}; + +constexpr int64_t kRegisteredTransitionForwardScratchChunkBytes = 256LL * 1024LL * 1024LL; + +inline void run_registered_transition_linear_forward_primitive( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& 
native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t primitive_row_index, + int64_t primitive_opcode, + int64_t native_callable_hash, + const std::vector& inputs, + const std::vector& params, + const std::vector& outputs) { + (void)forward_executor_rows; + (void)forward_executor_binding_rows; + const int64_t input_binding = native_callable_program_binding_for( + inputs, + native_callable_binding_schema_rows, + native_callable_hash, + kForwardDirectionOpcode, + primitive_opcode, + kNativeCallableBindingInput, + "input", + true, + 1, + "registered transition linear forward primitive"); + const int64_t weight_binding = native_callable_program_binding_for( + params, + native_callable_binding_schema_rows, + native_callable_hash, + kForwardDirectionOpcode, + primitive_opcode, + kNativeCallableBindingParameter, + "weight", + true, + 1, + "registered transition linear forward primitive"); + const int64_t bias_binding = native_callable_program_binding_for( + params, + native_callable_binding_schema_rows, + native_callable_hash, + kForwardDirectionOpcode, + primitive_opcode, + kNativeCallableBindingParameter, + "bias", + false, + 1, + "registered transition linear forward primitive"); + const int64_t output_binding = native_callable_program_binding_for( + outputs, + native_callable_binding_schema_rows, + native_callable_hash, + kForwardDirectionOpcode, + primitive_opcode, + kNativeCallableBindingOutput, + "output", + true, + 1, + "registered transition linear forward primitive"); + at::Tensor input = program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, input_binding, "registered transition linear input"); + at::Tensor weight = program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, weight_binding, "registered transition linear weight"); + at::Tensor bias = bias_binding >= 0 + ? program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, bias_binding, "registered transition linear bias") + : input.new_empty({0}); + if (weight.dim() == 4) { + check_cuda_float_bank(input, "registered transition gate affine input"); + TORCH_CHECK(weight.size(0) == input.size(1), "registered transition gate affine weight R must match input"); + TORCH_CHECK(input.size(2) == weight.size(1) * weight.size(2), "registered transition gate affine hidden mismatch"); + at::Tensor output = registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + native_callable_hash, + primitive_opcode, + primitive_row_index, + 0, + output_binding, + 1, + {input.size(0), input.size(1), 4, input.size(2)}, + "registered transition gate affine linear output"); + program_transition_linear_gate_affine_forward_into(input, weight, bias, output); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + output_binding, + output, + "registered transition gate affine linear output"); + return; + } + at::Tensor weight_view = program_transition_linear_weight_view(input, weight, 1); + at::Tensor dense_weight = weight.dim() == 2 ? 
weight : weight_view; + const int64_t B = input.size(0); + const int64_t receivers = input.size(1); + const int64_t output_dim = program_transition_linear_output_dim(weight_view); + validate_program_transition_linear_bias(bias, receivers, output_dim, 1); + at::Tensor output = registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + native_callable_hash, + primitive_opcode, + primitive_row_index, + 0, + output_binding, + 1, + {B, receivers, output_dim}, + "registered transition linear output"); + fabric::cuda::ops::dense_affine_out_cuda( + input, + dense_weight, + bias, + output, + fabric::cuda::ops::DenseAffineLayout::ReceiverMajor, + 1, + fabric::cuda::ops::DenseAffineOutputMode::Overwrite); + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, output_binding, output, "registered transition linear output"); +} + +inline void run_registered_transition_matmul_forward_primitive( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t primitive_row_index, + int64_t primitive_opcode, + int64_t native_callable_hash, + const std::vector& inputs, + const std::vector& params, + const std::vector& outputs) { + (void)forward_executor_rows; + (void)forward_executor_binding_rows; + (void)span; + const int64_t input_binding = native_callable_program_binding_for( + inputs, + native_callable_binding_schema_rows, + native_callable_hash, + kForwardDirectionOpcode, + primitive_opcode, + kNativeCallableBindingInput, + "input", + true, + 1, + "registered transition matmul forward primitive"); + const int64_t weight_binding = native_callable_program_binding_for( + params, + native_callable_binding_schema_rows, + native_callable_hash, + kForwardDirectionOpcode, + primitive_opcode, + kNativeCallableBindingParameter, + "weight", + true, + 1, + "registered transition matmul forward primitive"); + const int64_t output_binding = native_callable_program_binding_for( + outputs, + native_callable_binding_schema_rows, + native_callable_hash, + kForwardDirectionOpcode, + primitive_opcode, + kNativeCallableBindingOutput, + "output", + true, + 1, + "registered transition matmul forward primitive"); + at::Tensor input = program_tensor_for_binding_allow_empty( + program_tensors, program_tensor_binding_rows, input_binding, "registered transition matmul input"); + at::Tensor kernel = program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, weight_binding, "registered transition matmul weight"); + check_cuda_float_bank(input, "registered transition matmul input"); + TORCH_CHECK(kernel.is_cuda() && kernel.is_contiguous(), "registered transition matmul kernel must be CUDA contiguous"); + TORCH_CHECK(kernel.scalar_type() == at::kFloat && kernel.dim() == 5, "registered transition matmul kernel must be [R,G,Heads,D,D]"); + const bool fresh_zero_input = input.numel() == 0; + const int64_t B = input.size(0); + const int64_t receivers = fresh_zero_input ? 
span.receiver_count : input.size(1); + const int64_t hidden = input.size(2); + const int64_t gates = kernel.size(1); + const int64_t heads = kernel.size(2); + const int64_t head_dim = kernel.size(3); + TORCH_CHECK( + !fresh_zero_input || input.size(1) == 0, + "registered transition matmul fresh-zero input sentinel must carry shape [B,0,H]"); + TORCH_CHECK(receivers >= 0, "registered transition matmul receiver count must be non-negative"); + TORCH_CHECK(kernel.size(0) == receivers && kernel.size(4) == head_dim, "registered transition matmul kernel shape mismatch"); + TORCH_CHECK(hidden == heads * head_dim, "registered transition matmul hidden/head shape mismatch"); + at::Tensor output = registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + native_callable_hash, + primitive_opcode, + primitive_row_index, + 0, + output_binding, + 1, + {B, receivers, gates, hidden}, + "registered transition matmul output"); + const int64_t total = output.numel(); + if (fresh_zero_input) { + output.zero_(); + } else if (total > 0) { + const int blocks = static_cast(std::min( + 4096, + (total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + program_transition_recurrent_matmul_forward_kernel<<>>( + input.data_ptr(), + kernel.data_ptr(), + output.data_ptr(), + total, + static_cast(receivers), + static_cast(gates), + static_cast(heads), + static_cast(head_dim)); + check_launch("registered_transition_recurrent_matmul_forward_primitive_kernel"); + } + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, output_binding, output, "registered transition matmul output"); +} + +inline void run_registered_transition_tanh_forward_primitive( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t primitive_row_index, + int64_t primitive_opcode, + int64_t native_callable_hash, + const std::vector& inputs, + const std::vector& params, + const std::vector& outputs) { + (void)forward_executor_rows; + (void)forward_executor_binding_rows; + (void)span; + (void)params; + const int64_t input_binding = native_callable_program_binding_for( + inputs, + native_callable_binding_schema_rows, + native_callable_hash, + kForwardDirectionOpcode, + primitive_opcode, + kNativeCallableBindingInput, + "input", + true, + 1, + "registered transition tanh forward primitive"); + const int64_t output_binding = native_callable_program_binding_for( + outputs, + native_callable_binding_schema_rows, + native_callable_hash, + kForwardDirectionOpcode, + primitive_opcode, + kNativeCallableBindingOutput, + "output", + true, + 1, + "registered transition tanh forward primitive"); + at::Tensor input = program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, input_binding, "registered transition tanh input"); + check_cuda_float_bank(input, "registered transition tanh input"); + at::Tensor output = registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + native_callable_hash, + primitive_opcode, + primitive_row_index, + 0, + output_binding, + 1, + 
{input.size(0), input.size(1), input.size(2)}, + "registered transition tanh output"); + const int64_t total = output.numel(); + if (total > 0) { + const int blocks = static_cast(std::min( + 4096, + (total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + program_transition_tanh_forward_kernel<<>>( + input.data_ptr(), + output.data_ptr(), + total); + check_launch("registered_transition_tanh_forward_primitive_kernel"); + } + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, output_binding, output, "registered transition tanh output"); +} + +inline void run_registered_transition_gated_logspace_forward_primitive( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t primitive_row_index, + int64_t primitive_opcode, + int64_t native_callable_hash, + const std::vector& inputs, + const std::vector& params, + const std::vector& outputs) { + (void)forward_executor_rows; + (void)forward_executor_binding_rows; + (void)span; + (void)params; + const auto input_binding = [&](const char* logical_name) { + return native_callable_program_binding_for( + inputs, + native_callable_binding_schema_rows, + native_callable_hash, + kForwardDirectionOpcode, + primitive_opcode, + kNativeCallableBindingInput, + logical_name, + true, + 1, + "registered transition gated_logspace forward primitive"); + }; + const auto output_binding = [&](const char* logical_name, bool required) { + return native_callable_program_binding_for( + outputs, + native_callable_binding_schema_rows, + native_callable_hash, + kForwardDirectionOpcode, + primitive_opcode, + kNativeCallableBindingOutput, + logical_name, + required, + 1, + "registered transition gated_logspace forward primitive"); + }; + const int64_t gate_logits_binding = input_binding("gate_logits"); + const int64_t recurrent_gate_logits_binding = input_binding("recurrent_gate_logits"); + const int64_t c_prev_binding = input_binding("c_prev"); + const int64_t n_prev_binding = input_binding("n_prev"); + const int64_t m_prev_binding = input_binding("m_prev"); + const int64_t next_y_binding = output_binding("next_y", true); + const int64_t next_c_binding = output_binding("next_c", false); + const int64_t next_n_binding = output_binding("next_n", false); + const int64_t next_m_binding = output_binding("next_m", false); + at::Tensor gate_logits = program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, gate_logits_binding, "registered gated gate_logits"); + at::Tensor recurrent_gate_logits = program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, recurrent_gate_logits_binding, "registered gated recurrent_gate_logits"); + at::Tensor c_prev = program_tensor_for_binding_allow_empty( + program_tensors, program_tensor_binding_rows, c_prev_binding, "registered gated c_prev"); + at::Tensor n_prev = program_tensor_for_binding_allow_empty( + program_tensors, program_tensor_binding_rows, n_prev_binding, "registered gated n_prev"); + at::Tensor m_prev = program_tensor_for_binding_allow_empty( + program_tensors, program_tensor_binding_rows, m_prev_binding, "registered gated m_prev"); + 
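// --- Editor's note (reference sketch; not part of the generated patch). ---
// The gated_logspace primitive below launches
// program_transition_gated_logspace_recurrence_forward_kernel. Its per-element
// math, reconstructed from this forward launch and the backward kernel earlier
// in this diff, is a max-shifted, log-space stabilized gated recurrence: the
// forget path accumulates log(sigmoid(f)) onto a running max state m, both
// gates are renormalized by exp(. - m) and clamped to <= 1, and the output
// divides the cell state c by the normalizer state n. The host-side sketch
// below is illustrative only: the type and function names are assumptions, the
// i/f/z/o pre-activations are assumed to be the already-summed gate logits,
// and plain <cmath> calls stand in for the guarded device helpers such as
// sigmoidf_stable / logsigmoidf_stable.
#include <algorithm>
#include <cmath>

struct GatedLogspaceStateSketch { float c, n, m; };

inline GatedLogspaceStateSketch gated_logspace_step_sketch(
    float iraw, float fraw, float zraw, float oraw,
    const GatedLogspaceStateSketch& prev, bool has_prev_state, float* y_out) {
  const float prev_c = has_prev_state ? prev.c : 0.0f;
  const float prev_n = has_prev_state ? prev.n : 0.0f;
  const float prev_m = has_prev_state ? prev.m : 0.0f;
  // m + log(sigmoid(f)): forget contribution accumulated in log space.
  const float logfplusm = prev_m - std::log1p(std::exp(-fraw));
  const bool is_first = prev_n == 0.0f;
  const float m_out = is_first ? iraw : std::max(iraw, logfplusm);
  const float i_gate = std::min(std::exp(iraw - m_out), 1.0f);
  const float f_gate = std::min(std::exp(logfplusm - m_out), 1.0f);
  const float c_out = f_gate * prev_c + i_gate * std::tanh(zraw);
  const float n_out = f_gate * prev_n + i_gate;
  const float o_gate = 1.0f / (1.0f + std::exp(-oraw));
  // Matches the denom = n_out + 1e-6 guard seen in the backward kernel.
  *y_out = o_gate * c_out / (n_out + 1.0e-6f);
  return GatedLogspaceStateSketch{c_out, n_out, m_out};
}
// --- End editor's note. ---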
check_program_gated_logits(gate_logits, "registered gated gate_logits"); + check_program_gated_logits(recurrent_gate_logits, "registered gated recurrent_gate_logits"); + const int64_t B = gate_logits.size(0); + const int64_t receivers = gate_logits.size(1); + const int64_t hidden = gate_logits.size(3); + const std::vector state_shape = {B, receivers, hidden}; + const bool has_prev_state = c_prev.defined() && c_prev.numel() > 0; + TORCH_CHECK( + has_prev_state == (n_prev.defined() && n_prev.numel() > 0) && + has_prev_state == (m_prev.defined() && m_prev.numel() > 0), + "registered gated fresh-zero state sentinel must cover c/n/m together"); + if (has_prev_state) { + check_program_diag_state_tensor(c_prev, "registered gated c_prev", state_shape); + check_program_diag_state_tensor(n_prev, "registered gated n_prev", state_shape); + check_program_diag_state_tensor(m_prev, "registered gated m_prev", state_shape); + } + at::Tensor next_y = registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + native_callable_hash, + primitive_opcode, + primitive_row_index, + 0, + next_y_binding, + 1, + state_shape, + "registered gated next_y"); + at::Tensor next_c = next_c_binding >= 0 + ? registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + native_callable_hash, + primitive_opcode, + primitive_row_index, + 1, + next_c_binding, + 1, + state_shape, + "registered gated next_c") + : at::Tensor(); + at::Tensor next_n = next_n_binding >= 0 + ? registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + native_callable_hash, + primitive_opcode, + primitive_row_index, + 2, + next_n_binding, + 1, + state_shape, + "registered gated next_n") + : at::Tensor(); + at::Tensor next_m = next_m_binding >= 0 + ? registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + native_callable_hash, + primitive_opcode, + primitive_row_index, + 3, + next_m_binding, + 1, + state_shape, + "registered gated next_m") + : at::Tensor(); + const int64_t total = next_y.numel(); + if (total > 0) { + const int blocks = static_cast(std::min( + 4096, + (total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + program_transition_gated_logspace_recurrence_forward_kernel<<>>( + gate_logits.data_ptr(), + recurrent_gate_logits.data_ptr(), + has_prev_state ? c_prev.data_ptr() : nullptr, + has_prev_state ? n_prev.data_ptr() : nullptr, + has_prev_state ? m_prev.data_ptr() : nullptr, + next_y.data_ptr(), + next_c.defined() ? next_c.data_ptr() : nullptr, + next_n.defined() ? next_n.data_ptr() : nullptr, + next_m.defined() ? 
next_m.data_ptr() : nullptr, + total, + static_cast(receivers), + static_cast(hidden), + true, + has_prev_state); + check_launch("registered_transition_gated_logspace_forward_primitive_kernel"); + } + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, next_y_binding, next_y, "registered gated next_y"); + if (next_c_binding >= 0) { + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, next_c_binding, next_c, "registered gated next_c"); + } + if (next_n_binding >= 0) { + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, next_n_binding, next_n, "registered gated next_n"); + } + if (next_m_binding >= 0) { + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, next_m_binding, next_m, "registered gated next_m"); + } +} + +inline void run_registered_transition_norm_or_identity_forward_primitive( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t primitive_row_index, + int64_t primitive_opcode, + int64_t native_callable_hash, + const std::vector& inputs, + const std::vector& params, + const std::vector& outputs) { + (void)forward_executor_rows; + (void)forward_executor_binding_rows; + (void)span; + const int64_t input_binding = native_callable_program_binding_for( + inputs, + native_callable_binding_schema_rows, + native_callable_hash, + kForwardDirectionOpcode, + primitive_opcode, + kNativeCallableBindingInput, + "input", + true, + 1, + "registered transition norm_or_identity forward primitive"); + const int64_t weight_binding = native_callable_program_binding_for( + params, + native_callable_binding_schema_rows, + native_callable_hash, + kForwardDirectionOpcode, + primitive_opcode, + kNativeCallableBindingParameter, + "weight", + true, + 1, + "registered transition norm_or_identity forward primitive"); + const int64_t eps_binding = native_callable_program_binding_for( + params, + native_callable_binding_schema_rows, + native_callable_hash, + kForwardDirectionOpcode, + primitive_opcode, + kNativeCallableBindingParameter, + "eps", + false, + 1, + "registered transition norm_or_identity forward primitive"); + const int64_t output_binding = native_callable_program_binding_for( + outputs, + native_callable_binding_schema_rows, + native_callable_hash, + kForwardDirectionOpcode, + primitive_opcode, + kNativeCallableBindingOutput, + "output", + true, + 1, + "registered transition norm_or_identity forward primitive"); + at::Tensor input = program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, input_binding, "registered norm input"); + at::Tensor weight = program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, weight_binding, "registered norm weight"); + const double eps = eps_binding >= 0 + ? 
program_scalar_double_for_binding( + program_tensors, + program_tensor_binding_rows, + eps_binding, + "registered norm eps") + : 1.0e-5; + check_cuda_float_bank(input, "registered norm input"); + const bool has_weight = check_optional_program_norm_weight(weight, input.size(1), input.size(2)); + at::Tensor output = registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + native_callable_hash, + primitive_opcode, + primitive_row_index, + 0, + output_binding, + 1, + {input.size(0), input.size(1), input.size(2)}, + "registered transition norm output"); + const int64_t total = output.numel(); + if (total > 0) { + const int blocks = static_cast(std::min( + 4096, + (total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + program_transition_norm_or_identity_forward_kernel<<>>( + input.data_ptr(), + has_weight ? weight.data_ptr() : nullptr, + output.data_ptr(), + total, + static_cast(input.size(1)), + static_cast(input.size(2)), + has_weight, + has_weight && weight.dim() == 1, + static_cast(eps)); + check_launch("registered_transition_norm_or_identity_forward_primitive_kernel"); + } + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, output_binding, output, "registered norm output"); +} + +inline void run_registered_transition_diag_rtu_forward_primitive( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t primitive_row_index, + int64_t primitive_opcode, + int64_t native_callable_hash, + const std::vector& inputs, + const std::vector& params, + const std::vector& outputs) { + (void)forward_executor_rows; + (void)forward_executor_binding_rows; + (void)span; + const auto input_binding = [&](const char* logical_name) { + return native_callable_program_binding_for( + inputs, + native_callable_binding_schema_rows, + native_callable_hash, + kForwardDirectionOpcode, + primitive_opcode, + kNativeCallableBindingInput, + logical_name, + true, + 1, + "registered transition diag_rtu forward primitive"); + }; + const auto parameter_binding = [&](const char* logical_name, bool required) { + return native_callable_program_binding_for( + params, + native_callable_binding_schema_rows, + native_callable_hash, + kForwardDirectionOpcode, + primitive_opcode, + kNativeCallableBindingParameter, + logical_name, + required, + 1, + "registered transition diag_rtu forward primitive"); + }; + const auto output_binding = [&](const char* logical_name, bool required) { + return native_callable_program_binding_for( + outputs, + native_callable_binding_schema_rows, + native_callable_hash, + kForwardDirectionOpcode, + primitive_opcode, + kNativeCallableBindingOutput, + logical_name, + required, + 1, + "registered transition diag_rtu forward primitive"); + }; + const int64_t cell_input_binding = input_binding("cell_input"); + const int64_t hc1_binding = input_binding("hc1"); + const int64_t hc2_binding = input_binding("hc2"); + const int64_t nu_log_binding = parameter_binding("nu_log", true); + const int64_t theta_log_binding = parameter_binding("theta_log", true); + const int64_t w1_binding = 
parameter_binding("w1", true); + const int64_t w2_binding = parameter_binding("w2", true); + const int64_t activation_id_binding = parameter_binding("activation_id", false); + const int64_t preproj_binding = output_binding("preproj", true); + const int64_t next_hc1_binding = output_binding("next_hc1", false); + const int64_t next_hc2_binding = output_binding("next_hc2", false); + TORCH_CHECK( + (next_hc1_binding >= 0) == (next_hc2_binding >= 0), + "registered transition diag_rtu forward primitive has partial state output bindings"); + const int64_t next_e_nu_c1_binding = output_binding("next_E_nu_c1", false); + const int64_t next_e_nu_c2_binding = output_binding("next_E_nu_c2", false); + const int64_t next_e_th_c1_binding = output_binding("next_E_th_c1", false); + const int64_t next_e_th_c2_binding = output_binding("next_E_th_c2", false); + const int64_t next_e_w1_c1_binding = output_binding("next_E_w1_c1", false); + const int64_t next_e_w1_c2_binding = output_binding("next_E_w1_c2", false); + const int64_t next_e_w2_c1_binding = output_binding("next_E_w2_c1", false); + const int64_t next_e_w2_c2_binding = output_binding("next_E_w2_c2", false); + const bool write_trace = next_e_nu_c1_binding >= 0; + TORCH_CHECK( + !write_trace || + (next_e_nu_c2_binding >= 0 && next_e_th_c1_binding >= 0 && next_e_th_c2_binding >= 0 && + next_e_w1_c1_binding >= 0 && next_e_w1_c2_binding >= 0 && next_e_w2_c1_binding >= 0 && + next_e_w2_c2_binding >= 0), + "registered transition diag_rtu forward primitive has partial trace output bindings"); + const int64_t e_nu_c1_binding = write_trace ? input_binding("E_nu_c1") : -1; + const int64_t e_nu_c2_binding = write_trace ? input_binding("E_nu_c2") : -1; + const int64_t e_th_c1_binding = write_trace ? input_binding("E_th_c1") : -1; + const int64_t e_th_c2_binding = write_trace ? input_binding("E_th_c2") : -1; + const int64_t e_w1_c1_binding = write_trace ? input_binding("E_w1_c1") : -1; + const int64_t e_w1_c2_binding = write_trace ? input_binding("E_w1_c2") : -1; + const int64_t e_w2_c1_binding = write_trace ? input_binding("E_w2_c1") : -1; + const int64_t e_w2_c2_binding = write_trace ? input_binding("E_w2_c2") : -1; + at::Tensor cell_input = program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, cell_input_binding, "registered diag cell_input"); + at::Tensor hc1 = program_tensor_for_binding_allow_empty( + program_tensors, program_tensor_binding_rows, hc1_binding, "registered diag hc1"); + at::Tensor hc2 = program_tensor_for_binding_allow_empty( + program_tensors, program_tensor_binding_rows, hc2_binding, "registered diag hc2"); + at::Tensor e_nu_c1 = write_trace + ? program_tensor_for_binding_allow_empty( + program_tensors, program_tensor_binding_rows, e_nu_c1_binding, "registered diag e_nu_c1") + : cell_input.new_empty({0}); + at::Tensor e_nu_c2 = write_trace + ? program_tensor_for_binding_allow_empty( + program_tensors, program_tensor_binding_rows, e_nu_c2_binding, "registered diag e_nu_c2") + : cell_input.new_empty({0}); + at::Tensor e_th_c1 = write_trace + ? program_tensor_for_binding_allow_empty( + program_tensors, program_tensor_binding_rows, e_th_c1_binding, "registered diag e_th_c1") + : cell_input.new_empty({0}); + at::Tensor e_th_c2 = write_trace + ? program_tensor_for_binding_allow_empty( + program_tensors, program_tensor_binding_rows, e_th_c2_binding, "registered diag e_th_c2") + : cell_input.new_empty({0}); + at::Tensor e_w1_c1 = write_trace + ? 
program_tensor_for_binding_allow_empty( + program_tensors, program_tensor_binding_rows, e_w1_c1_binding, "registered diag e_w1_c1") + : cell_input.new_empty({0}); + at::Tensor e_w1_c2 = write_trace + ? program_tensor_for_binding_allow_empty( + program_tensors, program_tensor_binding_rows, e_w1_c2_binding, "registered diag e_w1_c2") + : cell_input.new_empty({0}); + at::Tensor e_w2_c1 = write_trace + ? program_tensor_for_binding_allow_empty( + program_tensors, program_tensor_binding_rows, e_w2_c1_binding, "registered diag e_w2_c1") + : cell_input.new_empty({0}); + at::Tensor e_w2_c2 = write_trace + ? program_tensor_for_binding_allow_empty( + program_tensors, program_tensor_binding_rows, e_w2_c2_binding, "registered diag e_w2_c2") + : cell_input.new_empty({0}); + at::Tensor nu_log = program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, nu_log_binding, "registered diag nu_log"); + at::Tensor theta_log = program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, theta_log_binding, "registered diag theta_log"); + at::Tensor w1 = program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, w1_binding, "registered diag w1"); + at::Tensor w2 = program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, w2_binding, "registered diag w2"); + const int64_t activation_id = activation_id_binding >= 0 + ? program_scalar_int_for_binding( + program_tensors, + program_tensor_binding_rows, + activation_id_binding, + "registered diag activation_id") + : 0; + check_cuda_float_bank(cell_input, "registered diag cell_input"); + const int64_t B = cell_input.size(0); + const int64_t receivers = cell_input.size(1); + const int64_t hidden = cell_input.size(2); + const std::vector state_shape = {B, receivers, hidden}; + const bool has_prev_state = hc1.defined() && hc1.numel() > 0; + TORCH_CHECK( + has_prev_state == (hc2.defined() && hc2.numel() > 0), + "registered diag fresh-zero state sentinel must cover hc1/hc2 together"); + if (has_prev_state) { + check_program_diag_state_tensor(hc1, "registered diag hc1", state_shape); + check_program_diag_state_tensor(hc2, "registered diag hc2", state_shape); + } + const bool has_trace_state = write_trace && e_nu_c1.defined() && e_nu_c1.numel() > 0; + if (write_trace) { + TORCH_CHECK( + has_trace_state == (e_nu_c2.defined() && e_nu_c2.numel() > 0) && + has_trace_state == (e_th_c1.defined() && e_th_c1.numel() > 0) && + has_trace_state == (e_th_c2.defined() && e_th_c2.numel() > 0) && + has_trace_state == (e_w1_c1.defined() && e_w1_c1.numel() > 0) && + has_trace_state == (e_w1_c2.defined() && e_w1_c2.numel() > 0) && + has_trace_state == (e_w2_c1.defined() && e_w2_c1.numel() > 0) && + has_trace_state == (e_w2_c2.defined() && e_w2_c2.numel() > 0), + "registered diag fresh-zero trace sentinel must cover all trace inputs together"); + if (has_trace_state) { + check_program_diag_state_tensor(e_nu_c1, "registered diag e_nu_c1", state_shape); + check_program_diag_state_tensor(e_nu_c2, "registered diag e_nu_c2", state_shape); + check_program_diag_state_tensor(e_th_c1, "registered diag e_th_c1", state_shape); + check_program_diag_state_tensor(e_th_c2, "registered diag e_th_c2", state_shape); + check_program_diag_state_tensor(e_w1_c1, "registered diag e_w1_c1", state_shape); + check_program_diag_state_tensor(e_w1_c2, "registered diag e_w1_c2", state_shape); + check_program_diag_state_tensor(e_w2_c1, "registered diag e_w2_c1", state_shape); + check_program_diag_state_tensor(e_w2_c2, "registered diag e_w2_c2", state_shape); + 
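// --- Editor's note (reference sketch; not part of the generated patch). ---
// The diag_rtu primitive launches program_transition_diag_rtu_forward_kernel,
// whose per-element update is restated here on the host for one channel pair:
// radius = exp(-exp(nu_log)) and theta = exp(theta_log) give the decay and
// rotation coefficients g = radius*cos(theta) and phi = radius*sin(theta),
// gamma = sqrt(max(1 - radius^2, 0)) scales the input drive, and
// activation_id selects silu (0), relu (1), tanh (2), or identity for the
// preproj epilogue. The eligibility-trace tensors follow the same linear
// recurrence driven by the parameter partials and are omitted here. The
// function name below is an assumption; this sketch is illustrative only.
#include <algorithm>
#include <cmath>

inline void diag_rtu_step_sketch(
    float x, float nu_log, float theta_log, float w1, float w2,
    float& c1, float& c2) {
  const float exp_nu = std::exp(nu_log);
  const float radius = std::exp(-exp_nu);     // spectral radius in (0, 1)
  const float theta = std::exp(theta_log);    // rotation angle
  const float g = radius * std::cos(theta);
  const float phi = radius * std::sin(theta);
  const float gamma = std::sqrt(std::max(1.0f - radius * radius, 0.0f));
  const float c1_next = gamma * w1 * x + g * c1 - phi * c2;
  const float c2_next = gamma * w2 * x + g * c2 + phi * c1;
  c1 = c1_next;
  c2 = c2_next;
}
// --- End editor's note. ---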
} + } + check_program_diag_param_tensor(nu_log, "registered diag nu_log", receivers, hidden); + check_program_diag_param_tensor(theta_log, "registered diag theta_log", receivers, hidden); + check_program_diag_param_tensor(w1, "registered diag w1", receivers, hidden); + check_program_diag_param_tensor(w2, "registered diag w2", receivers, hidden); + at::Tensor preproj = registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + native_callable_hash, + primitive_opcode, + primitive_row_index, + 0, + preproj_binding, + 1, + {B, receivers, 2 * hidden}, + "registered diag preproj"); + at::Tensor next_hc1 = next_hc1_binding >= 0 + ? registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + native_callable_hash, + primitive_opcode, + primitive_row_index, + 1, + next_hc1_binding, + 1, + state_shape, + "registered diag next_hc1") + : at::Tensor(); + at::Tensor next_hc2 = next_hc2_binding >= 0 + ? registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + native_callable_hash, + primitive_opcode, + primitive_row_index, + 2, + next_hc2_binding, + 1, + state_shape, + "registered diag next_hc2") + : at::Tensor(); + at::Tensor next_e_nu_c1 = write_trace + ? registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + native_callable_hash, + primitive_opcode, + primitive_row_index, + 3, + next_e_nu_c1_binding, + 1, + state_shape, + "registered diag next_e_nu_c1") + : cell_input.new_empty({0}); + at::Tensor next_e_nu_c2 = write_trace + ? registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + native_callable_hash, + primitive_opcode, + primitive_row_index, + 4, + next_e_nu_c2_binding, + 1, + state_shape, + "registered diag next_e_nu_c2") + : cell_input.new_empty({0}); + at::Tensor next_e_th_c1 = write_trace + ? registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + native_callable_hash, + primitive_opcode, + primitive_row_index, + 5, + next_e_th_c1_binding, + 1, + state_shape, + "registered diag next_e_th_c1") + : cell_input.new_empty({0}); + at::Tensor next_e_th_c2 = write_trace + ? registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + native_callable_hash, + primitive_opcode, + primitive_row_index, + 6, + next_e_th_c2_binding, + 1, + state_shape, + "registered diag next_e_th_c2") + : cell_input.new_empty({0}); + at::Tensor next_e_w1_c1 = write_trace + ? registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + native_callable_hash, + primitive_opcode, + primitive_row_index, + 7, + next_e_w1_c1_binding, + 1, + state_shape, + "registered diag next_e_w1_c1") + : cell_input.new_empty({0}); + at::Tensor next_e_w1_c2 = write_trace + ? registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + native_callable_hash, + primitive_opcode, + primitive_row_index, + 8, + next_e_w1_c2_binding, + 1, + state_shape, + "registered diag next_e_w1_c2") + : cell_input.new_empty({0}); + at::Tensor next_e_w2_c1 = write_trace + ? 
registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + native_callable_hash, + primitive_opcode, + primitive_row_index, + 9, + next_e_w2_c1_binding, + 1, + state_shape, + "registered diag next_e_w2_c1") + : cell_input.new_empty({0}); + at::Tensor next_e_w2_c2 = write_trace + ? registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + native_callable_hash, + primitive_opcode, + primitive_row_index, + 10, + next_e_w2_c2_binding, + 1, + state_shape, + "registered diag next_e_w2_c2") + : cell_input.new_empty({0}); + const int64_t total = cell_input.numel(); + if (total > 0) { + const int blocks = static_cast(std::min( + 4096, + (total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + program_transition_diag_rtu_forward_kernel<<>>( + cell_input.data_ptr(), + has_prev_state ? hc1.data_ptr() : nullptr, + has_prev_state ? hc2.data_ptr() : nullptr, + has_trace_state ? e_nu_c1.data_ptr() : nullptr, + has_trace_state ? e_nu_c2.data_ptr() : nullptr, + has_trace_state ? e_th_c1.data_ptr() : nullptr, + has_trace_state ? e_th_c2.data_ptr() : nullptr, + has_trace_state ? e_w1_c1.data_ptr() : nullptr, + has_trace_state ? e_w1_c2.data_ptr() : nullptr, + has_trace_state ? e_w2_c1.data_ptr() : nullptr, + has_trace_state ? e_w2_c2.data_ptr() : nullptr, + nu_log.data_ptr(), + theta_log.data_ptr(), + w1.data_ptr(), + w2.data_ptr(), + preproj.data_ptr(), + next_hc1.defined() ? next_hc1.data_ptr() : nullptr, + next_hc2.defined() ? next_hc2.data_ptr() : nullptr, + write_trace ? next_e_nu_c1.data_ptr() : nullptr, + write_trace ? next_e_nu_c2.data_ptr() : nullptr, + write_trace ? next_e_th_c1.data_ptr() : nullptr, + write_trace ? next_e_th_c2.data_ptr() : nullptr, + write_trace ? next_e_w1_c1.data_ptr() : nullptr, + write_trace ? next_e_w1_c2.data_ptr() : nullptr, + write_trace ? next_e_w2_c1.data_ptr() : nullptr, + write_trace ? 
next_e_w2_c2.data_ptr() : nullptr, + total, + static_cast(receivers), + static_cast(hidden), + static_cast(activation_id), + has_prev_state, + has_trace_state, + write_trace); + check_launch("registered_transition_diag_rtu_forward_primitive_kernel"); + } + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, preproj_binding, preproj, "registered diag preproj"); + if (next_hc1_binding >= 0) { + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, next_hc1_binding, next_hc1, "registered diag next_hc1"); + } + if (next_hc2_binding >= 0) { + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, next_hc2_binding, next_hc2, "registered diag next_hc2"); + } + if (write_trace) { + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, next_e_nu_c1_binding, next_e_nu_c1, "registered diag next_e_nu_c1"); + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, next_e_nu_c2_binding, next_e_nu_c2, "registered diag next_e_nu_c2"); + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, next_e_th_c1_binding, next_e_th_c1, "registered diag next_e_th_c1"); + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, next_e_th_c2_binding, next_e_th_c2, "registered diag next_e_th_c2"); + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, next_e_w1_c1_binding, next_e_w1_c1, "registered diag next_e_w1_c1"); + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, next_e_w1_c2_binding, next_e_w1_c2, "registered diag next_e_w1_c2"); + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, next_e_w2_c1_binding, next_e_w2_c1, "registered diag next_e_w2_c1"); + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, next_e_w2_c2_binding, next_e_w2_c2, "registered diag next_e_w2_c2"); + } +} + +inline bool program_tensor_bindings_share_slot( + const at::Tensor& program_tensor_binding_rows, + int64_t lhs_binding, + int64_t rhs_binding) { + return program_tensor_index_for_binding(program_tensor_binding_rows, lhs_binding) == + program_tensor_index_for_binding(program_tensor_binding_rows, rhs_binding); +} + +inline int64_t registered_transition_forward_callable_hash_for_primitive( + const at::Tensor& transition_primitive_callable_rows, + int64_t primitive_opcode); + +struct RegisteredTransitionInputProjectionTarget { + bool defined; + int64_t aggregate_binding; + int64_t output_binding; + int64_t primitive_row_index; + int64_t native_callable_hash; + at::Tensor input_weight; + at::Tensor input_bias; + at::Tensor output; + int64_t receiver_count; + int64_t message_dim; + int64_t hidden; +}; + +inline at::Tensor registered_transition_input_projection_weight_view( + const at::Tensor& weight, + int64_t receivers, + int64_t message_dim, + const char* subject) { + TORCH_CHECK(weight.is_cuda(), subject, " weight must be CUDA"); + TORCH_CHECK(weight.is_contiguous(), subject, " weight must be contiguous"); + TORCH_CHECK(weight.scalar_type() == at::kFloat, subject, " weight must be float32"); + if (weight.dim() == 2) { + TORCH_CHECK(weight.size(0) == message_dim, subject, " shared weight K must match message output dim"); + return weight.view({1, message_dim, weight.size(1)}); + } + TORCH_CHECK(weight.dim() == 3, subject, " weight must be [K,H], [R,K,H], or [1,K,H]"); + TORCH_CHECK(weight.size(1) == message_dim, subject, " weight K must match message 
output dim"); + TORCH_CHECK( + weight.size(0) == receivers || weight.size(0) == 1, + subject, + " weight owner count must be 1 or transition receiver count"); + return weight; +} + +inline RegisteredTransitionInputProjectionTarget registered_transition_input_projection_target_for_span( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& primitive_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& transition_primitive_callable_rows, + const at::Tensor& forward_executor_binding_rows, + const RegisteredFusedProgramSpan& span, + int64_t batch_size, + int64_t message_dim, + const char* subject, + const at::Tensor& output_override = at::Tensor()) { + const int64_t p0 = span.primitive_row_start; + const int64_t* primitives = primitive_rows.data_ptr(); + TORCH_CHECK( + primitives[p0 * 4] == kPrimitiveLinearOpcode, + subject, + " requires the transition row group to begin with a compiler linear input projection"); + const std::vector p0_inputs = fused_input_bindings_for_primitive(forward_executor_binding_rows, p0); + const std::vector p0_params = fused_parameter_bindings_for_primitive(forward_executor_binding_rows, p0); + const std::vector p0_outputs = fused_output_bindings_for_primitive(forward_executor_binding_rows, p0); + TORCH_CHECK(p0_outputs.size() == 1, subject, " requires one transition input projection output"); + const int64_t linear_hash = registered_transition_forward_callable_hash_for_primitive( + transition_primitive_callable_rows, + kPrimitiveLinearOpcode); + const auto linear_binding = [&]( + const std::vector& bindings, + int64_t binding_kind, + const char* logical_name, + bool required) { + return native_callable_program_binding_for( + bindings, + native_callable_binding_schema_rows, + linear_hash, + kForwardDirectionOpcode, + kPrimitiveLinearOpcode, + binding_kind, + logical_name, + required, + 1, + subject); + }; + const int64_t aggregate_binding = linear_binding( + p0_inputs, + kNativeCallableBindingInput, + "input", + true); + const int64_t input_weight_binding = linear_binding( + p0_params, + kNativeCallableBindingParameter, + "weight", + true); + const int64_t input_bias_binding = linear_binding( + p0_params, + kNativeCallableBindingParameter, + "bias", + false); + const int64_t output_binding = linear_binding( + p0_outputs, + kNativeCallableBindingOutput, + "output", + true); + at::Tensor input_weight = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_weight_binding, + subject); + const int64_t receivers = span.receiver_count; + at::Tensor input_weight_view = registered_transition_input_projection_weight_view( + input_weight, + receivers, + message_dim, + subject); + const int64_t hidden = program_transition_linear_output_dim(input_weight_view); + at::Tensor input_bias = input_bias_binding >= 0 + ? 
program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_bias_binding, + subject) + : input_weight.new_empty({0}); + validate_program_transition_linear_bias(input_bias, receivers, hidden, 1); + at::Tensor output; + if (output_override.defined() && output_override.numel() > 0) { + check_cuda_float_bank(output_override, subject); + TORCH_CHECK(output_override.is_contiguous(), subject, " output override must be contiguous"); + TORCH_CHECK(output_override.size(0) == batch_size, subject, " output override B mismatch"); + TORCH_CHECK(output_override.size(1) == receivers, subject, " output override receiver mismatch"); + TORCH_CHECK(output_override.size(2) == hidden, subject, " output override hidden mismatch"); + output = output_override; + } else { + output = registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + linear_hash, + kPrimitiveLinearOpcode, + p0, + 0, + output_binding, + 1, + {batch_size, receivers, hidden}, + subject); + } + return RegisteredTransitionInputProjectionTarget{ + true, + aggregate_binding, + output_binding, + p0, + linear_hash, + input_weight.dim() == 2 ? input_weight : input_weight_view, + input_bias, + output, + receivers, + message_dim, + hidden, + }; +} + +inline bool try_run_registered_transition_diag_rtu_forward_row_group( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& primitive_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& transition_primitive_callable_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& forward_transition_state_carry_rows, + const RegisteredFusedProgramSpan& span, + bool release_dead_input_bindings) { + if (!release_dead_input_bindings || span.primitive_row_count != 4) { + return false; + } + const int64_t p0 = span.primitive_row_start; + const int64_t p1 = span.primitive_row_start + 1; + const int64_t p2 = span.primitive_row_start + 2; + const int64_t p3 = span.primitive_row_start + 3; + const int64_t* primitives = primitive_rows.data_ptr(); + const auto primitive_opcode = [&](int64_t primitive_row_index) { + return primitives[primitive_row_index * 4]; + }; + if ( + primitive_opcode(p0) != kPrimitiveLinearOpcode || + primitive_opcode(p1) != kPrimitiveDiagRtuOpcode || + primitive_opcode(p2) != kPrimitiveLinearOpcode || + primitive_opcode(p3) != kPrimitiveNormOrIdentityOpcode) { + return false; + } + + const std::vector p0_inputs = fused_input_bindings_for_primitive(forward_executor_binding_rows, p0); + const std::vector p0_params = fused_parameter_bindings_for_primitive(forward_executor_binding_rows, p0); + const std::vector p0_outputs = fused_output_bindings_for_primitive(forward_executor_binding_rows, p0); + const std::vector p1_inputs = fused_input_bindings_for_primitive(forward_executor_binding_rows, p1); + const std::vector p1_params = fused_parameter_bindings_for_primitive(forward_executor_binding_rows, p1); + const std::vector p1_outputs = fused_output_bindings_for_primitive(forward_executor_binding_rows, p1); + const std::vector p2_inputs = fused_input_bindings_for_primitive(forward_executor_binding_rows, p2); + const std::vector p2_params = fused_parameter_bindings_for_primitive(forward_executor_binding_rows, p2); + const std::vector 
p2_outputs = fused_output_bindings_for_primitive(forward_executor_binding_rows, p2); + const std::vector p3_inputs = fused_input_bindings_for_primitive(forward_executor_binding_rows, p3); + const std::vector p3_params = fused_parameter_bindings_for_primitive(forward_executor_binding_rows, p3); + const std::vector p3_outputs = fused_output_bindings_for_primitive(forward_executor_binding_rows, p3); + if ( + p0_outputs.size() != 1 || + p1_outputs.size() != 1 || + p2_outputs.size() != 1 || + p3_outputs.size() != 1) { + return false; + } + + const int64_t linear_hash = registered_transition_forward_callable_hash_for_primitive( + transition_primitive_callable_rows, + kPrimitiveLinearOpcode); + const int64_t diag_hash = registered_transition_forward_callable_hash_for_primitive( + transition_primitive_callable_rows, + kPrimitiveDiagRtuOpcode); + const int64_t norm_hash = registered_transition_forward_callable_hash_for_primitive( + transition_primitive_callable_rows, + kPrimitiveNormOrIdentityOpcode); + const auto linear_binding = [&]( + const std::vector& bindings, + int64_t binding_kind, + const char* logical_name, + bool required, + const char* subject) { + return native_callable_program_binding_for( + bindings, + native_callable_binding_schema_rows, + linear_hash, + kForwardDirectionOpcode, + kPrimitiveLinearOpcode, + binding_kind, + logical_name, + required, + 1, + subject); + }; + const auto diag_binding = [&]( + const std::vector& bindings, + int64_t binding_kind, + const char* logical_name, + bool required) { + return native_callable_program_binding_for( + bindings, + native_callable_binding_schema_rows, + diag_hash, + kForwardDirectionOpcode, + kPrimitiveDiagRtuOpcode, + binding_kind, + logical_name, + required, + 1, + "registered transition diag_rtu row-group"); + }; + + const int64_t aggregate_binding = linear_binding( + p0_inputs, + kNativeCallableBindingInput, + "input", + true, + "registered transition diag_rtu row-group input projection"); + const int64_t input_weight_binding = linear_binding( + p0_params, + kNativeCallableBindingParameter, + "weight", + true, + "registered transition diag_rtu row-group input projection"); + const int64_t input_bias_binding = linear_binding( + p0_params, + kNativeCallableBindingParameter, + "bias", + false, + "registered transition diag_rtu row-group input projection"); + const int64_t cell_input_output_binding = linear_binding( + p0_outputs, + kNativeCallableBindingOutput, + "output", + true, + "registered transition diag_rtu row-group input projection"); + const int64_t diag_cell_input_binding = diag_binding( + p1_inputs, + kNativeCallableBindingInput, + "cell_input", + true); + const int64_t preproj_binding = diag_binding( + p1_outputs, + kNativeCallableBindingOutput, + "preproj", + true); + const int64_t output_projection_input_binding = linear_binding( + p2_inputs, + kNativeCallableBindingInput, + "input", + true, + "registered transition diag_rtu row-group output projection"); + const int64_t output_weight_binding = linear_binding( + p2_params, + kNativeCallableBindingParameter, + "weight", + true, + "registered transition diag_rtu row-group output projection"); + const int64_t output_bias_binding = linear_binding( + p2_params, + kNativeCallableBindingParameter, + "bias", + false, + "registered transition diag_rtu row-group output projection"); + const int64_t raw_public_y_binding = linear_binding( + p2_outputs, + kNativeCallableBindingOutput, + "output", + true, + "registered transition diag_rtu row-group output projection"); + TORCH_CHECK( + 
program_tensor_bindings_share_slot(program_tensor_binding_rows, cell_input_output_binding, diag_cell_input_binding), + "registered transition diag_rtu row-group requires compiler tensor edge linear.output -> diag.cell_input"); + TORCH_CHECK( + program_tensor_bindings_share_slot(program_tensor_binding_rows, preproj_binding, output_projection_input_binding), + "registered transition diag_rtu row-group requires compiler tensor edge diag.preproj -> linear.input"); + + at::Tensor aggregate_input = program_tensor_for_binding_allow_empty( + program_tensors, + program_tensor_binding_rows, + aggregate_binding, + "registered transition diag_rtu row-group aggregate input"); + at::Tensor input_weight = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_weight_binding, + "registered transition diag_rtu row-group input weight"); + at::Tensor prefilled_cell_input = program_tensor_for_binding_allow_empty( + program_tensors, + program_tensor_binding_rows, + cell_input_output_binding, + "registered transition diag_rtu row-group prefilled cell_input"); + const bool has_prefilled_cell_input = prefilled_cell_input.defined() && prefilled_cell_input.numel() > 0; + TORCH_CHECK( + has_prefilled_cell_input || (aggregate_input.defined() && aggregate_input.numel() > 0), + "registered transition diag_rtu row-group needs either aggregate input or a compiler-produced cell_input"); + const at::Tensor shape_reference = has_prefilled_cell_input ? prefilled_cell_input : aggregate_input; + if (has_prefilled_cell_input) { + check_cuda_float_bank(prefilled_cell_input, "registered transition diag_rtu row-group prefilled cell_input"); + } else { + check_cuda_float_bank(aggregate_input, "registered transition diag_rtu row-group aggregate input"); + } + at::Tensor input_weight_view = has_prefilled_cell_input + ? registered_transition_input_projection_weight_view( + input_weight, + shape_reference.size(1), + static_cast(input_weight.dim() == 2 ? input_weight.size(0) : input_weight.size(1)), + "registered transition diag_rtu row-group prefilled input projection") + : program_transition_linear_weight_view(aggregate_input, input_weight, 1); + at::Tensor input_dense_weight = input_weight.dim() == 2 ? input_weight : input_weight_view; + const int64_t B = shape_reference.size(0); + const int64_t receivers = shape_reference.size(1); + const int64_t hidden = has_prefilled_cell_input + ? prefilled_cell_input.size(2) + : program_transition_linear_output_dim(input_weight_view); + at::Tensor input_bias = input_bias_binding >= 0 + ? program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_bias_binding, + "registered transition diag_rtu row-group input bias") + : shape_reference.new_empty({0}); + validate_program_transition_linear_bias(input_bias, receivers, hidden, 1); + TORCH_CHECK( + !has_prefilled_cell_input || hidden == program_transition_linear_output_dim(input_weight_view), + "registered transition diag_rtu row-group prefilled cell_input H must match compiler input projection output"); + at::Tensor cell_input_buffer = has_prefilled_cell_input + ? 
prefilled_cell_input + : registered_runtime_buffer_for_native_callable_output_allow_deferred( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + linear_hash, + kPrimitiveLinearOpcode, + p0, + 0, + cell_input_output_binding, + 1, + {B, receivers, hidden}, + "registered transition diag_rtu row-group cell_input"); + const bool cell_input_is_deferred_local = + !has_prefilled_cell_input && registered_runtime_buffer_is_deferred_local_placeholder(cell_input_buffer); + at::Tensor cell_input = cell_input_is_deferred_local ? at::Tensor() : cell_input_buffer; + + const int64_t hc1_binding = diag_binding(p1_inputs, kNativeCallableBindingInput, "hc1", true); + const int64_t hc2_binding = diag_binding(p1_inputs, kNativeCallableBindingInput, "hc2", true); + at::Tensor hc1 = program_tensor_for_binding_allow_empty( + program_tensors, + program_tensor_binding_rows, + hc1_binding, + "registered transition diag_rtu row-group hc1"); + at::Tensor hc2 = program_tensor_for_binding_allow_empty( + program_tensors, + program_tensor_binding_rows, + hc2_binding, + "registered transition diag_rtu row-group hc2"); + const std::vector state_shape = {B, receivers, hidden}; + const bool has_prev_state = hc1.defined() && hc1.numel() > 0; + TORCH_CHECK( + has_prev_state == (hc2.defined() && hc2.numel() > 0), + "registered transition diag_rtu row-group fresh-zero state sentinel must cover hc1/hc2 together"); + if (has_prev_state) { + check_program_diag_state_tensor(hc1, "registered transition diag_rtu row-group hc1", state_shape); + check_program_diag_state_tensor(hc2, "registered transition diag_rtu row-group hc2", state_shape); + } + const int64_t nu_log_binding = diag_binding(p1_params, kNativeCallableBindingParameter, "nu_log", true); + const int64_t theta_log_binding = diag_binding(p1_params, kNativeCallableBindingParameter, "theta_log", true); + const int64_t w1_binding = diag_binding(p1_params, kNativeCallableBindingParameter, "w1", true); + const int64_t w2_binding = diag_binding(p1_params, kNativeCallableBindingParameter, "w2", true); + const int64_t activation_id_binding = diag_binding( + p1_params, + kNativeCallableBindingParameter, + "activation_id", + false); + at::Tensor nu_log = program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, nu_log_binding, "registered transition diag_rtu row-group nu_log"); + at::Tensor theta_log = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + theta_log_binding, + "registered transition diag_rtu row-group theta_log"); + at::Tensor w1 = program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, w1_binding, "registered transition diag_rtu row-group w1"); + at::Tensor w2 = program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, w2_binding, "registered transition diag_rtu row-group w2"); + check_program_diag_param_tensor(nu_log, "registered transition diag_rtu row-group nu_log", receivers, hidden); + check_program_diag_param_tensor(theta_log, "registered transition diag_rtu row-group theta_log", receivers, hidden); + check_program_diag_param_tensor(w1, "registered transition diag_rtu row-group w1", receivers, hidden); + check_program_diag_param_tensor(w2, "registered transition diag_rtu row-group w2", receivers, hidden); + const int64_t activation_id = activation_id_binding >= 0 + ? 
program_scalar_int_for_binding( + program_tensors, + program_tensor_binding_rows, + activation_id_binding, + "registered transition diag_rtu row-group activation_id") + : 0; + + at::Tensor output_weight = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + output_weight_binding, + "registered transition diag_rtu row-group output weight"); + TORCH_CHECK(output_weight.is_cuda(), "registered transition diag_rtu row-group output weight must be CUDA"); + TORCH_CHECK(output_weight.is_contiguous(), "registered transition diag_rtu row-group output weight must be contiguous"); + TORCH_CHECK(output_weight.scalar_type() == at::kFloat, "registered transition diag_rtu row-group output weight must be float32"); + if (output_weight.dim() == 2) { + TORCH_CHECK( + output_weight.size(0) == 2 * hidden && output_weight.size(1) == hidden, + "registered transition diag_rtu row-group shared output weight must be [2H,H]"); + } else { + TORCH_CHECK(output_weight.dim() == 3, "registered transition diag_rtu row-group output weight must be [2H,H] or [R,2H,H]"); + TORCH_CHECK( + output_weight.size(1) == 2 * hidden && output_weight.size(2) == hidden, + "registered transition diag_rtu row-group output weight shape mismatch"); + TORCH_CHECK( + output_weight.size(0) == 1 || output_weight.size(0) == receivers, + "registered transition diag_rtu row-group output weight owner count must be 1 or R"); + } + at::Tensor output_bias = output_bias_binding >= 0 + ? program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + output_bias_binding, + "registered transition diag_rtu row-group output bias") + : aggregate_input.new_empty({0}); + validate_program_transition_linear_bias(output_bias, receivers, hidden, 1); + at::Tensor raw_public_y_buffer = registered_runtime_buffer_for_native_callable_output_allow_deferred( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + linear_hash, + kPrimitiveLinearOpcode, + p2, + 0, + raw_public_y_binding, + 1, + {B, receivers, hidden}, + "registered transition diag_rtu row-group raw public_y"); + const bool raw_public_y_is_deferred_local = + registered_runtime_buffer_is_deferred_local_placeholder(raw_public_y_buffer); + const int64_t norm_input_binding = native_callable_program_binding_for( + p3_inputs, + native_callable_binding_schema_rows, + norm_hash, + kForwardDirectionOpcode, + kPrimitiveNormOrIdentityOpcode, + kNativeCallableBindingInput, + "input", + true, + 1, + "registered transition diag_rtu row-group norm"); + const int64_t norm_weight_binding = native_callable_program_binding_for( + p3_params, + native_callable_binding_schema_rows, + norm_hash, + kForwardDirectionOpcode, + kPrimitiveNormOrIdentityOpcode, + kNativeCallableBindingParameter, + "weight", + true, + 1, + "registered transition diag_rtu row-group norm"); + const int64_t norm_eps_binding = native_callable_program_binding_for( + p3_params, + native_callable_binding_schema_rows, + norm_hash, + kForwardDirectionOpcode, + kPrimitiveNormOrIdentityOpcode, + kNativeCallableBindingParameter, + "eps", + false, + 1, + "registered transition diag_rtu row-group norm"); + const int64_t norm_output_binding = native_callable_program_binding_for( + p3_outputs, + native_callable_binding_schema_rows, + norm_hash, + kForwardDirectionOpcode, + kPrimitiveNormOrIdentityOpcode, + kNativeCallableBindingOutput, + "output", + true, + 1, + "registered transition diag_rtu row-group norm"); + TORCH_CHECK( + program_tensor_bindings_share_slot(program_tensor_binding_rows, 
raw_public_y_binding, norm_input_binding), + "registered transition diag_rtu row-group requires compiler tensor edge linear.output -> norm.input"); + const int64_t preproj_bytes_per_batch = std::max(1, receivers * 2 * hidden * 4); + const int64_t batch_chunk = std::max( + 1, + std::min(B, kRegisteredTransitionForwardScratchChunkBytes / preproj_bytes_per_batch)); + if (cell_input_is_deferred_local && raw_public_y_is_deferred_local) { + at::Tensor norm_weight = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + norm_weight_binding, + "registered transition diag_rtu row-group norm weight"); + const double norm_eps = norm_eps_binding >= 0 + ? program_scalar_double_for_binding( + program_tensors, + program_tensor_binding_rows, + norm_eps_binding, + "registered transition diag_rtu row-group norm eps") + : 1.0e-5; + const bool has_norm_weight = check_optional_program_norm_weight(norm_weight, receivers, hidden); + at::Tensor norm_output = registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + norm_hash, + kPrimitiveNormOrIdentityOpcode, + p3, + 0, + norm_output_binding, + 1, + {B, receivers, hidden}, + "registered transition diag_rtu row-group norm output"); + for (int64_t batch_start = 0; batch_start < B; batch_start += batch_chunk) { + const int64_t current_batch = std::min(batch_chunk, B - batch_start); + at::Tensor aggregate_input_chunk = aggregate_input.narrow(0, batch_start, current_batch); + at::Tensor cell_input_chunk = at::empty({current_batch, receivers, hidden}, aggregate_input.options()); + fabric::cuda::ops::dense_affine_out_cuda( + aggregate_input_chunk, + input_dense_weight, + input_bias, + cell_input_chunk, + fabric::cuda::ops::DenseAffineLayout::ReceiverMajor, + 1, + fabric::cuda::ops::DenseAffineOutputMode::Overwrite); + at::Tensor hc1_chunk = has_prev_state ? hc1.narrow(0, batch_start, current_batch) : hc1; + at::Tensor hc2_chunk = has_prev_state ? hc2.narrow(0, batch_start, current_batch) : hc2; + at::Tensor preproj_chunk = at::empty({current_batch, receivers, 2 * hidden}, aggregate_input.options()); + const int64_t diag_total = cell_input_chunk.numel(); + if (diag_total > 0) { + const int blocks = static_cast(std::min( + 4096, + (diag_total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + program_transition_diag_rtu_forward_kernel<<>>( + cell_input_chunk.data_ptr(), + has_prev_state ? hc1_chunk.data_ptr() : nullptr, + has_prev_state ? 
hc2_chunk.data_ptr() : nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nu_log.data_ptr(), + theta_log.data_ptr(), + w1.data_ptr(), + w2.data_ptr(), + preproj_chunk.data_ptr(), + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + diag_total, + static_cast(receivers), + static_cast(hidden), + static_cast(activation_id), + has_prev_state, + false, + false); + check_launch("registered_transition_diag_rtu_row_group_preproj_chunk_kernel"); + } + at::Tensor raw_public_y_chunk = at::empty({current_batch, receivers, hidden}, aggregate_input.options()); + fabric::cuda::ops::dense_affine_out_cuda( + preproj_chunk, + output_weight, + output_bias, + raw_public_y_chunk, + fabric::cuda::ops::DenseAffineLayout::ReceiverMajor, + 1, + fabric::cuda::ops::DenseAffineOutputMode::Overwrite); + at::Tensor norm_output_chunk = norm_output.narrow(0, batch_start, current_batch); + const int64_t norm_total = norm_output_chunk.numel(); + if (norm_total > 0) { + const int blocks = static_cast(std::min( + 4096, + (norm_total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + program_transition_norm_or_identity_forward_kernel<<>>( + raw_public_y_chunk.data_ptr(), + has_norm_weight ? norm_weight.data_ptr() : nullptr, + norm_output_chunk.data_ptr(), + norm_total, + static_cast(receivers), + static_cast(hidden), + has_norm_weight, + has_norm_weight && norm_weight.dim() == 1, + static_cast(norm_eps)); + check_launch("registered_transition_diag_rtu_row_group_norm_chunk_kernel"); + } + } + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + norm_output_binding, + norm_output, + "registered transition diag_rtu row-group norm output"); + clear_forward_transition_dead_input_binding_slots_after_primitive( + program_tensors, + program_tensor_binding_rows, + forward_executor_binding_rows, + forward_transition_state_carry_rows, + span, + p0, + p0_inputs); + clear_forward_transition_dead_input_binding_slots_after_primitive( + program_tensors, + program_tensor_binding_rows, + forward_executor_binding_rows, + forward_transition_state_carry_rows, + span, + p3, + p3_inputs); + return true; + } + if (cell_input_is_deferred_local) { + cell_input = registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + linear_hash, + kPrimitiveLinearOpcode, + p0, + 0, + cell_input_output_binding, + 1, + {B, receivers, hidden}, + "registered transition diag_rtu row-group cell_input"); + } + if (!has_prefilled_cell_input) { + fabric::cuda::ops::dense_affine_out_cuda( + aggregate_input, + input_dense_weight, + input_bias, + cell_input, + fabric::cuda::ops::DenseAffineLayout::ReceiverMajor, + 1, + fabric::cuda::ops::DenseAffineOutputMode::Overwrite); + } + at::Tensor raw_public_y = raw_public_y_is_deferred_local + ? 
registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + linear_hash, + kPrimitiveLinearOpcode, + p2, + 0, + raw_public_y_binding, + 1, + {B, receivers, hidden}, + "registered transition diag_rtu row-group raw public_y") + : raw_public_y_buffer; + for (int64_t batch_start = 0; batch_start < B; batch_start += batch_chunk) { + const int64_t current_batch = std::min(batch_chunk, B - batch_start); + at::Tensor cell_input_chunk = cell_input.narrow(0, batch_start, current_batch); + at::Tensor hc1_chunk = has_prev_state ? hc1.narrow(0, batch_start, current_batch) : hc1; + at::Tensor hc2_chunk = has_prev_state ? hc2.narrow(0, batch_start, current_batch) : hc2; + at::Tensor preproj_chunk = at::empty({current_batch, receivers, 2 * hidden}, cell_input.options()); + const int64_t diag_total = cell_input_chunk.numel(); + if (diag_total > 0) { + const int blocks = static_cast(std::min( + 4096, + (diag_total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + program_transition_diag_rtu_forward_kernel<<>>( + cell_input_chunk.data_ptr(), + has_prev_state ? hc1_chunk.data_ptr() : nullptr, + has_prev_state ? hc2_chunk.data_ptr() : nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nu_log.data_ptr(), + theta_log.data_ptr(), + w1.data_ptr(), + w2.data_ptr(), + preproj_chunk.data_ptr(), + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + diag_total, + static_cast(receivers), + static_cast(hidden), + static_cast(activation_id), + has_prev_state, + false, + false); + check_launch("registered_transition_diag_rtu_row_group_preproj_chunk_kernel"); + } + at::Tensor raw_public_y_chunk = raw_public_y.narrow(0, batch_start, current_batch); + fabric::cuda::ops::dense_affine_out_cuda( + preproj_chunk, + output_weight, + output_bias, + raw_public_y_chunk, + fabric::cuda::ops::DenseAffineLayout::ReceiverMajor, + 1, + fabric::cuda::ops::DenseAffineOutputMode::Overwrite); + } + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + raw_public_y_binding, + raw_public_y, + "registered transition diag_rtu row-group raw public_y"); + run_registered_transition_norm_or_identity_forward_primitive( + program_tensors, + program_tensor_binding_rows, + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + forward_executor_rows, + forward_executor_binding_rows, + span, + p3, + kPrimitiveNormOrIdentityOpcode, + norm_hash, + p3_inputs, + p3_params, + p3_outputs); + clear_forward_transition_dead_input_binding_slots_after_primitive( + program_tensors, + program_tensor_binding_rows, + forward_executor_binding_rows, + forward_transition_state_carry_rows, + span, + p0, + p0_inputs); + clear_forward_transition_dead_input_binding_slots_after_primitive( + program_tensors, + program_tensor_binding_rows, + forward_executor_binding_rows, + forward_transition_state_carry_rows, + span, + p3, + p3_inputs); + return true; +} + +inline bool try_run_registered_transition_gated_logspace_forward_row_group( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + const at::Tensor& primitive_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& 
transition_primitive_callable_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& forward_transition_state_carry_rows, + const RegisteredFusedProgramSpan& span, + bool release_dead_input_bindings, + bool allow_terminal_local_state_outputs) { + if (!release_dead_input_bindings || span.primitive_row_count != 5) { + return false; + } + const int64_t p0 = span.primitive_row_start; + const int64_t p1 = span.primitive_row_start + 1; + const int64_t p2 = span.primitive_row_start + 2; + const int64_t p3 = span.primitive_row_start + 3; + const int64_t p4 = span.primitive_row_start + 4; + const int64_t* primitives = primitive_rows.data_ptr(); + const auto primitive_opcode = [&](int64_t primitive_row_index) { + return primitives[primitive_row_index * 4]; + }; + if ( + primitive_opcode(p0) != kPrimitiveLinearOpcode || + primitive_opcode(p1) != kPrimitiveLinearOpcode || + primitive_opcode(p2) != kPrimitiveMatmulOpcode || + primitive_opcode(p3) != kPrimitiveGatedLogspaceRecurrenceOpcode || + primitive_opcode(p4) != kPrimitiveNormOrIdentityOpcode) { + return false; + } + + const std::vector p0_inputs = fused_input_bindings_for_primitive(forward_executor_binding_rows, p0); + const std::vector p0_params = fused_parameter_bindings_for_primitive(forward_executor_binding_rows, p0); + const std::vector p0_outputs = fused_output_bindings_for_primitive(forward_executor_binding_rows, p0); + const std::vector p1_inputs = fused_input_bindings_for_primitive(forward_executor_binding_rows, p1); + const std::vector p1_params = fused_parameter_bindings_for_primitive(forward_executor_binding_rows, p1); + const std::vector p1_outputs = fused_output_bindings_for_primitive(forward_executor_binding_rows, p1); + const std::vector p2_inputs = fused_input_bindings_for_primitive(forward_executor_binding_rows, p2); + const std::vector p2_params = fused_parameter_bindings_for_primitive(forward_executor_binding_rows, p2); + const std::vector p2_outputs = fused_output_bindings_for_primitive(forward_executor_binding_rows, p2); + const std::vector p3_inputs = fused_input_bindings_for_primitive(forward_executor_binding_rows, p3); + const std::vector p3_params = fused_parameter_bindings_for_primitive(forward_executor_binding_rows, p3); + const std::vector p3_outputs = fused_output_bindings_for_primitive(forward_executor_binding_rows, p3); + const std::vector p4_inputs = fused_input_bindings_for_primitive(forward_executor_binding_rows, p4); + const std::vector p4_params = fused_parameter_bindings_for_primitive(forward_executor_binding_rows, p4); + const std::vector p4_outputs = fused_output_bindings_for_primitive(forward_executor_binding_rows, p4); + if ( + p0_outputs.size() != 1 || + p1_outputs.size() != 1 || + p2_outputs.size() != 1 || + p3_outputs.empty() || + p4_inputs.empty() || + p4_outputs.size() != 1) { + return false; + } + (void)p3_params; + + const int64_t linear_hash = registered_transition_forward_callable_hash_for_primitive( + transition_primitive_callable_rows, + kPrimitiveLinearOpcode); + const int64_t matmul_hash = registered_transition_forward_callable_hash_for_primitive( + transition_primitive_callable_rows, + kPrimitiveMatmulOpcode); + const int64_t gated_hash = registered_transition_forward_callable_hash_for_primitive( + transition_primitive_callable_rows, + kPrimitiveGatedLogspaceRecurrenceOpcode); + const int64_t norm_hash = registered_transition_forward_callable_hash_for_primitive( + transition_primitive_callable_rows, + 
kPrimitiveNormOrIdentityOpcode); + const auto linear_binding = [&]( + const std::vector& bindings, + int64_t binding_kind, + const char* logical_name, + bool required, + const char* subject) { + return native_callable_program_binding_for( + bindings, + native_callable_binding_schema_rows, + linear_hash, + kForwardDirectionOpcode, + kPrimitiveLinearOpcode, + binding_kind, + logical_name, + required, + 1, + subject); + }; + const auto matmul_binding = [&]( + const std::vector& bindings, + int64_t binding_kind, + const char* logical_name, + bool required) { + return native_callable_program_binding_for( + bindings, + native_callable_binding_schema_rows, + matmul_hash, + kForwardDirectionOpcode, + kPrimitiveMatmulOpcode, + binding_kind, + logical_name, + required, + 1, + "registered transition gated row-group matmul"); + }; + const auto gated_binding = [&]( + const std::vector& bindings, + int64_t binding_kind, + const char* logical_name, + bool required) { + return native_callable_program_binding_for( + bindings, + native_callable_binding_schema_rows, + gated_hash, + kForwardDirectionOpcode, + kPrimitiveGatedLogspaceRecurrenceOpcode, + binding_kind, + logical_name, + required, + 1, + "registered transition gated row-group recurrence"); + }; + + const int64_t aggregate_binding = linear_binding( + p0_inputs, + kNativeCallableBindingInput, + "input", + true, + "registered transition gated row-group aggregate input projection"); + const int64_t input_weight_binding = linear_binding( + p0_params, + kNativeCallableBindingParameter, + "weight", + true, + "registered transition gated row-group aggregate input projection"); + const int64_t input_bias_binding = linear_binding( + p0_params, + kNativeCallableBindingParameter, + "bias", + false, + "registered transition gated row-group aggregate input projection"); + const int64_t transition_input_binding = linear_binding( + p0_outputs, + kNativeCallableBindingOutput, + "output", + true, + "registered transition gated row-group aggregate input projection"); + const int64_t gate_affine_input_binding = linear_binding( + p1_inputs, + kNativeCallableBindingInput, + "input", + true, + "registered transition gated row-group gate affine"); + const int64_t gate_weight_binding = linear_binding( + p1_params, + kNativeCallableBindingParameter, + "weight", + true, + "registered transition gated row-group gate affine"); + const int64_t gate_bias_binding = linear_binding( + p1_params, + kNativeCallableBindingParameter, + "bias", + false, + "registered transition gated row-group gate affine"); + const int64_t gate_logits_binding = linear_binding( + p1_outputs, + kNativeCallableBindingOutput, + "output", + true, + "registered transition gated row-group gate affine"); + const int64_t recurrent_input_binding = matmul_binding( + p2_inputs, + kNativeCallableBindingInput, + "input", + true); + const int64_t recurrent_kernel_binding = matmul_binding( + p2_params, + kNativeCallableBindingParameter, + "weight", + true); + const int64_t recurrent_gate_logits_binding = matmul_binding( + p2_outputs, + kNativeCallableBindingOutput, + "output", + true); + const int64_t gated_gate_logits_binding = gated_binding( + p3_inputs, + kNativeCallableBindingInput, + "gate_logits", + true); + const int64_t gated_recurrent_gate_logits_binding = gated_binding( + p3_inputs, + kNativeCallableBindingInput, + "recurrent_gate_logits", + true); + const int64_t c_prev_binding = gated_binding(p3_inputs, kNativeCallableBindingInput, "c_prev", true); + const int64_t n_prev_binding = gated_binding(p3_inputs, 
kNativeCallableBindingInput, "n_prev", true); + const int64_t m_prev_binding = gated_binding(p3_inputs, kNativeCallableBindingInput, "m_prev", true); + const int64_t next_y_binding = gated_binding(p3_outputs, kNativeCallableBindingOutput, "next_y", true); + const int64_t next_c_binding = gated_binding(p3_outputs, kNativeCallableBindingOutput, "next_c", false); + const int64_t next_n_binding = gated_binding(p3_outputs, kNativeCallableBindingOutput, "next_n", false); + const int64_t next_m_binding = gated_binding(p3_outputs, kNativeCallableBindingOutput, "next_m", false); + TORCH_CHECK( + program_tensor_bindings_share_slot(program_tensor_binding_rows, transition_input_binding, gate_affine_input_binding), + "registered transition gated row-group requires compiler tensor edge linear.output -> gate_affine.input"); + TORCH_CHECK( + program_tensor_bindings_share_slot(program_tensor_binding_rows, gate_logits_binding, gated_gate_logits_binding), + "registered transition gated row-group requires compiler tensor edge gate_affine.output -> gated.gate_logits"); + TORCH_CHECK( + program_tensor_bindings_share_slot( + program_tensor_binding_rows, + recurrent_gate_logits_binding, + gated_recurrent_gate_logits_binding), + "registered transition gated row-group requires compiler tensor edge matmul.output -> gated.recurrent_gate_logits"); + TORCH_CHECK( + program_tensor_bindings_share_slot(program_tensor_binding_rows, next_y_binding, p4_inputs[0]), + "registered transition gated row-group requires compiler tensor edge gated.next_y -> norm.input"); + + at::Tensor aggregate_input = program_tensor_for_binding_allow_empty( + program_tensors, + program_tensor_binding_rows, + aggregate_binding, + "registered transition gated row-group aggregate input"); + at::Tensor input_weight = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_weight_binding, + "registered transition gated row-group input weight"); + at::Tensor prefilled_transition_input = program_tensor_for_binding_allow_empty( + program_tensors, + program_tensor_binding_rows, + transition_input_binding, + "registered transition gated row-group prefilled transition_input"); + const bool has_prefilled_transition_input = prefilled_transition_input.defined() && prefilled_transition_input.numel() > 0; + TORCH_CHECK( + has_prefilled_transition_input || (aggregate_input.defined() && aggregate_input.numel() > 0), + "registered transition gated row-group needs either aggregate input or a compiler-produced transition_input"); + const at::Tensor shape_reference = has_prefilled_transition_input ? prefilled_transition_input : aggregate_input; + if (has_prefilled_transition_input) { + check_cuda_float_bank(prefilled_transition_input, "registered transition gated row-group prefilled transition_input"); + } else { + check_cuda_float_bank(aggregate_input, "registered transition gated row-group aggregate input"); + } + at::Tensor input_weight_view = has_prefilled_transition_input + ? registered_transition_input_projection_weight_view( + input_weight, + shape_reference.size(1), + static_cast(input_weight.dim() == 2 ? input_weight.size(0) : input_weight.size(1)), + "registered transition gated row-group prefilled input projection") + : program_transition_linear_weight_view(aggregate_input, input_weight, 1); + at::Tensor input_dense_weight = input_weight.dim() == 2 ? 
input_weight : input_weight_view; + const int64_t B = shape_reference.size(0); + const int64_t receivers = shape_reference.size(1); + const int64_t hidden = has_prefilled_transition_input + ? prefilled_transition_input.size(2) + : program_transition_linear_output_dim(input_weight_view); + at::Tensor input_bias = input_bias_binding >= 0 + ? program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_bias_binding, + "registered transition gated row-group input bias") + : shape_reference.new_empty({0}); + validate_program_transition_linear_bias(input_bias, receivers, hidden, 1); + TORCH_CHECK( + !has_prefilled_transition_input || hidden == program_transition_linear_output_dim(input_weight_view), + "registered transition gated row-group prefilled transition_input H must match compiler input projection output"); + at::Tensor transition_input_buffer = has_prefilled_transition_input + ? prefilled_transition_input + : registered_runtime_buffer_for_native_callable_output_allow_deferred( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + linear_hash, + kPrimitiveLinearOpcode, + p0, + 0, + transition_input_binding, + 1, + {B, receivers, hidden}, + "registered transition gated row-group transition_input"); + const bool transition_input_is_deferred_local = + !has_prefilled_transition_input && registered_runtime_buffer_is_deferred_local_placeholder(transition_input_buffer); + at::Tensor transition_input = transition_input_is_deferred_local ? at::Tensor() : transition_input_buffer; + if (!has_prefilled_transition_input && !transition_input_is_deferred_local) { + fabric::cuda::ops::dense_affine_out_cuda( + aggregate_input, + input_dense_weight, + input_bias, + transition_input, + fabric::cuda::ops::DenseAffineLayout::ReceiverMajor, + 1, + fabric::cuda::ops::DenseAffineOutputMode::Overwrite); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + transition_input_binding, + transition_input, + "registered transition gated row-group transition_input"); + } + + at::Tensor gate_weight = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + gate_weight_binding, + "registered transition gated row-group gate weight"); + at::Tensor gate_bias = gate_bias_binding >= 0 + ? 
program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + gate_bias_binding, + "registered transition gated row-group gate bias") + : aggregate_input.new_empty({0}); + at::Tensor recurrent_input = program_tensor_for_binding_allow_empty( + program_tensors, + program_tensor_binding_rows, + recurrent_input_binding, + "registered transition gated row-group recurrent input"); + at::Tensor recurrent_kernel = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + recurrent_kernel_binding, + "registered transition gated row-group recurrent kernel"); + TORCH_CHECK(gate_weight.dim() == 4, "registered transition gated row-group gate weight must be [R,Heads,D,4D]"); + TORCH_CHECK(recurrent_kernel.dim() == 5, "registered transition gated row-group recurrent kernel must be [R,4,Heads,D,D]"); + const int64_t heads = gate_weight.size(1); + const int64_t head_dim = gate_weight.size(2); + TORCH_CHECK(gate_weight.size(0) == receivers, "registered transition gated row-group gate weight R mismatch"); + TORCH_CHECK(gate_weight.size(3) == 4 * head_dim, "registered transition gated row-group gate weight output mismatch"); + TORCH_CHECK(hidden == heads * head_dim, "registered transition gated row-group hidden/head mismatch"); + TORCH_CHECK(recurrent_kernel.size(0) == receivers, "registered transition gated row-group recurrent kernel R mismatch"); + TORCH_CHECK(recurrent_kernel.size(1) == 4, "registered transition gated row-group recurrent kernel gate mismatch"); + TORCH_CHECK(recurrent_kernel.size(2) == heads, "registered transition gated row-group recurrent kernel head mismatch"); + TORCH_CHECK(recurrent_kernel.size(3) == head_dim && recurrent_kernel.size(4) == head_dim, "registered transition gated row-group recurrent kernel D mismatch"); + const bool has_recurrent_input = recurrent_input.defined() && recurrent_input.numel() > 0; + if (has_recurrent_input) { + check_cuda_float_bank(recurrent_input, "registered transition gated row-group recurrent input"); + TORCH_CHECK(recurrent_input.size(0) == B, "registered transition gated row-group recurrent input B mismatch"); + TORCH_CHECK(recurrent_input.size(1) == receivers, "registered transition gated row-group recurrent input R mismatch"); + TORCH_CHECK(recurrent_input.size(2) == hidden, "registered transition gated row-group recurrent input H mismatch"); + } + + at::Tensor c_prev = program_tensor_for_binding_allow_empty( + program_tensors, + program_tensor_binding_rows, + c_prev_binding, + "registered transition gated row-group c_prev"); + at::Tensor n_prev = program_tensor_for_binding_allow_empty( + program_tensors, + program_tensor_binding_rows, + n_prev_binding, + "registered transition gated row-group n_prev"); + at::Tensor m_prev = program_tensor_for_binding_allow_empty( + program_tensors, + program_tensor_binding_rows, + m_prev_binding, + "registered transition gated row-group m_prev"); + const std::vector state_shape = {B, receivers, hidden}; + const bool has_prev_state = c_prev.defined() && c_prev.numel() > 0; + TORCH_CHECK( + has_prev_state == (n_prev.defined() && n_prev.numel() > 0) && + has_prev_state == (m_prev.defined() && m_prev.numel() > 0), + "registered transition gated row-group fresh-zero state sentinel must cover c/n/m together"); + if (has_prev_state) { + check_program_diag_state_tensor(c_prev, "registered transition gated row-group c_prev", state_shape); + check_program_diag_state_tensor(n_prev, "registered transition gated row-group n_prev", state_shape); + check_program_diag_state_tensor(m_prev, 
"registered transition gated row-group m_prev", state_shape); + } + + at::Tensor next_y = at::Tensor(); + bool next_y_is_deferred_local = false; + if (allow_terminal_local_state_outputs) { + at::Tensor next_y_buffer = registered_runtime_buffer_for_native_callable_output_allow_deferred( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + gated_hash, + kPrimitiveGatedLogspaceRecurrenceOpcode, + p3, + 0, + next_y_binding, + 1, + state_shape, + "registered transition gated row-group next_y"); + next_y_is_deferred_local = registered_runtime_buffer_is_deferred_local_placeholder(next_y_buffer); + next_y = next_y_is_deferred_local ? at::Tensor() : next_y_buffer; + } else { + next_y = registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + gated_hash, + kPrimitiveGatedLogspaceRecurrenceOpcode, + p3, + 0, + next_y_binding, + 1, + state_shape, + "registered transition gated row-group next_y"); + } + at::Tensor next_c = next_c_binding >= 0 + ? registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + gated_hash, + kPrimitiveGatedLogspaceRecurrenceOpcode, + p3, + 1, + next_c_binding, + 1, + state_shape, + "registered transition gated row-group next_c") + : at::Tensor(); + at::Tensor next_n = next_n_binding >= 0 + ? registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + gated_hash, + kPrimitiveGatedLogspaceRecurrenceOpcode, + p3, + 2, + next_n_binding, + 1, + state_shape, + "registered transition gated row-group next_n") + : at::Tensor(); + at::Tensor next_m = next_m_binding >= 0 + ? registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + gated_hash, + kPrimitiveGatedLogspaceRecurrenceOpcode, + p3, + 3, + next_m_binding, + 1, + state_shape, + "registered transition gated row-group next_m") + : at::Tensor(); + + int64_t norm_output_binding = -1; + at::Tensor norm_weight; + double norm_eps = 1.0e-5; + bool has_norm_weight = false; + at::Tensor norm_output; + if (next_y_is_deferred_local) { + const int64_t norm_hash = registered_transition_forward_callable_hash_for_primitive( + transition_primitive_callable_rows, + kPrimitiveNormOrIdentityOpcode); + const int64_t norm_input_binding = native_callable_program_binding_for( + p4_inputs, + native_callable_binding_schema_rows, + norm_hash, + kForwardDirectionOpcode, + kPrimitiveNormOrIdentityOpcode, + kNativeCallableBindingInput, + "input", + true, + 1, + "registered transition gated row-group norm"); + const int64_t norm_weight_binding = native_callable_program_binding_for( + p4_params, + native_callable_binding_schema_rows, + norm_hash, + kForwardDirectionOpcode, + kPrimitiveNormOrIdentityOpcode, + kNativeCallableBindingParameter, + "weight", + true, + 1, + "registered transition gated row-group norm"); + const int64_t norm_eps_binding = native_callable_program_binding_for( + p4_params, + native_callable_binding_schema_rows, + norm_hash, + kForwardDirectionOpcode, + kPrimitiveNormOrIdentityOpcode, + kNativeCallableBindingParameter, + "eps", + false, + 1, + "registered transition gated row-group norm"); + norm_output_binding = native_callable_program_binding_for( + p4_outputs, + native_callable_binding_schema_rows, + norm_hash, + kForwardDirectionOpcode, + kPrimitiveNormOrIdentityOpcode, + kNativeCallableBindingOutput, 
+ "output", + true, + 1, + "registered transition gated row-group norm"); + TORCH_CHECK( + program_tensor_bindings_share_slot(program_tensor_binding_rows, next_y_binding, norm_input_binding), + "registered transition gated row-group requires compiler tensor edge gated.next_y -> norm.input"); + norm_weight = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + norm_weight_binding, + "registered transition gated row-group norm weight"); + norm_eps = norm_eps_binding >= 0 + ? program_scalar_double_for_binding( + program_tensors, + program_tensor_binding_rows, + norm_eps_binding, + "registered transition gated row-group norm eps") + : 1.0e-5; + has_norm_weight = check_optional_program_norm_weight(norm_weight, receivers, hidden); + norm_output = registered_runtime_buffer_for_native_callable_output( + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_output_rows, + norm_hash, + kPrimitiveNormOrIdentityOpcode, + p4, + 0, + norm_output_binding, + 1, + state_shape, + "registered transition gated row-group norm output"); + } + + const int64_t gate_bytes_per_batch = std::max(1, receivers * 4 * hidden * 4); + const int64_t batch_chunk = std::max( + 1, + std::min(B, kRegisteredTransitionForwardScratchChunkBytes / gate_bytes_per_batch)); + for (int64_t batch_start = 0; batch_start < B; batch_start += batch_chunk) { + const int64_t current_batch = std::min(batch_chunk, B - batch_start); + at::Tensor transition_input_chunk = transition_input_is_deferred_local + ? at::empty({current_batch, receivers, hidden}, aggregate_input.options()) + : transition_input.narrow(0, batch_start, current_batch); + if (transition_input_is_deferred_local) { + fabric::cuda::ops::dense_affine_out_cuda( + aggregate_input.narrow(0, batch_start, current_batch), + input_dense_weight, + input_bias, + transition_input_chunk, + fabric::cuda::ops::DenseAffineLayout::ReceiverMajor, + 1, + fabric::cuda::ops::DenseAffineOutputMode::Overwrite); + } + at::Tensor gate_logits_chunk = at::empty({current_batch, receivers, 4, hidden}, aggregate_input.options()); + program_transition_linear_gate_affine_forward_into( + transition_input_chunk, + gate_weight, + gate_bias, + gate_logits_chunk); + at::Tensor recurrent_gate_logits_chunk = at::empty({current_batch, receivers, 4, hidden}, aggregate_input.options()); + if (has_recurrent_input) { + at::Tensor recurrent_input_chunk = recurrent_input.narrow(0, batch_start, current_batch); + const int64_t recurrent_total = recurrent_gate_logits_chunk.numel(); + if (recurrent_total > 0) { + const int blocks = static_cast(std::min( + 4096, + (recurrent_total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + program_transition_recurrent_matmul_forward_kernel<<>>( + recurrent_input_chunk.data_ptr(), + recurrent_kernel.data_ptr(), + recurrent_gate_logits_chunk.data_ptr(), + recurrent_total, + static_cast(receivers), + 4, + static_cast(heads), + static_cast(head_dim)); + check_launch("registered_transition_gated_row_group_recurrent_matmul_chunk_kernel"); + } + } else { + recurrent_gate_logits_chunk.zero_(); + } + at::Tensor next_y_chunk = next_y_is_deferred_local + ? at::empty({current_batch, receivers, hidden}, aggregate_input.options()) + : next_y.narrow(0, batch_start, current_batch); + at::Tensor next_c_chunk = next_c.defined() ? next_c.narrow(0, batch_start, current_batch) : at::Tensor(); + at::Tensor next_n_chunk = next_n.defined() ? 
next_n.narrow(0, batch_start, current_batch) : at::Tensor(); + at::Tensor next_m_chunk = next_m.defined() ? next_m.narrow(0, batch_start, current_batch) : at::Tensor(); + at::Tensor c_prev_chunk = has_prev_state ? c_prev.narrow(0, batch_start, current_batch) : c_prev; + at::Tensor n_prev_chunk = has_prev_state ? n_prev.narrow(0, batch_start, current_batch) : n_prev; + at::Tensor m_prev_chunk = has_prev_state ? m_prev.narrow(0, batch_start, current_batch) : m_prev; + const int64_t state_total = next_y_chunk.numel(); + if (state_total > 0) { + const int blocks = static_cast(std::min( + 4096, + (state_total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + program_transition_gated_logspace_recurrence_forward_kernel<<>>( + gate_logits_chunk.data_ptr(), + recurrent_gate_logits_chunk.data_ptr(), + has_prev_state ? c_prev_chunk.data_ptr() : nullptr, + has_prev_state ? n_prev_chunk.data_ptr() : nullptr, + has_prev_state ? m_prev_chunk.data_ptr() : nullptr, + next_y_chunk.data_ptr(), + next_c_chunk.defined() ? next_c_chunk.data_ptr() : nullptr, + next_n_chunk.defined() ? next_n_chunk.data_ptr() : nullptr, + next_m_chunk.defined() ? next_m_chunk.data_ptr() : nullptr, + state_total, + static_cast(receivers), + static_cast(hidden), + true, + has_prev_state); + check_launch("registered_transition_gated_row_group_recurrence_chunk_kernel"); + } + if (next_y_is_deferred_local) { + at::Tensor norm_output_chunk = norm_output.narrow(0, batch_start, current_batch); + const int64_t norm_total = norm_output_chunk.numel(); + if (norm_total > 0) { + const int blocks = static_cast(std::min( + 4096, + (norm_total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + program_transition_norm_or_identity_forward_kernel<<>>( + next_y_chunk.data_ptr(), + has_norm_weight ? 
norm_weight.data_ptr() : nullptr, + norm_output_chunk.data_ptr(), + norm_total, + static_cast(receivers), + static_cast(hidden), + has_norm_weight, + has_norm_weight && norm_weight.dim() == 1, + static_cast(norm_eps)); + check_launch("registered_transition_gated_row_group_norm_chunk_kernel"); + } + } + } + if (!next_y_is_deferred_local) { + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + next_y_binding, + next_y, + "registered transition gated row-group next_y"); + } + if (next_c_binding >= 0) { + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + next_c_binding, + next_c, + "registered transition gated row-group next_c"); + } + if (next_n_binding >= 0) { + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + next_n_binding, + next_n, + "registered transition gated row-group next_n"); + } + if (next_m_binding >= 0) { + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + next_m_binding, + next_m, + "registered transition gated row-group next_m"); + } + if (next_y_is_deferred_local) { + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + norm_output_binding, + norm_output, + "registered transition gated row-group norm output"); + } else { + run_registered_transition_norm_or_identity_forward_primitive( + program_tensors, + program_tensor_binding_rows, + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + forward_executor_rows, + forward_executor_binding_rows, + span, + p4, + kPrimitiveNormOrIdentityOpcode, + norm_hash, + p4_inputs, + p4_params, + p4_outputs); + } + clear_forward_transition_dead_input_binding_slots_after_primitive( + program_tensors, + program_tensor_binding_rows, + forward_executor_binding_rows, + forward_transition_state_carry_rows, + span, + p0, + p0_inputs); + clear_forward_transition_dead_input_binding_slots_after_primitive( + program_tensors, + program_tensor_binding_rows, + forward_executor_binding_rows, + forward_transition_state_carry_rows, + span, + p1, + p1_inputs); + clear_forward_transition_dead_input_binding_slots_after_primitive( + program_tensors, + program_tensor_binding_rows, + forward_executor_binding_rows, + forward_transition_state_carry_rows, + span, + p2, + p2_inputs); + clear_forward_transition_dead_input_binding_slots_after_primitive( + program_tensors, + program_tensor_binding_rows, + forward_executor_binding_rows, + forward_transition_state_carry_rows, + span, + p3, + p3_inputs); + clear_forward_transition_dead_input_binding_slots_after_primitive( + program_tensors, + program_tensor_binding_rows, + forward_executor_binding_rows, + forward_transition_state_carry_rows, + span, + p4, + p4_inputs); + return true; +} + +inline int64_t registered_transition_callable_hash_for_primitive( + const at::Tensor& transition_primitive_callable_rows, + int64_t primitive_opcode, + int64_t callable_column, + int64_t status_column, + const char* direction_name) { + check_cpu_long_rank2( + transition_primitive_callable_rows, + "transition_primitive_callable_rows", + kTransitionPrimitiveCallableRowColumns); + TORCH_CHECK( + callable_column == 1 || callable_column == 2, + "transition primitive callable lookup has invalid callable column"); + TORCH_CHECK( + status_column == 4 || status_column == 5, + "transition primitive callable lookup has invalid status column"); + const int64_t* rows = transition_primitive_callable_rows.data_ptr(); + int64_t 
selected = 0;
+  for (int64_t row_index = 0; row_index < transition_primitive_callable_rows.size(0); ++row_index) {
+    const int64_t* row = rows + row_index * kTransitionPrimitiveCallableRowColumns;
+    TORCH_CHECK(row[0] > 0, "transition primitive callable row has invalid primitive opcode");
+    TORCH_CHECK(row[3] == 0 || row[3] == 1, "transition primitive callable row has invalid layer status");
+    TORCH_CHECK(row[4] == 0 || row[4] == 1, "transition primitive callable row has invalid forward status");
+    TORCH_CHECK(row[5] == 0 || row[5] == 1, "transition primitive callable row has invalid backward status");
+    if (row[0] != primitive_opcode) {
+      continue;
+    }
+    TORCH_CHECK(
+        selected == 0,
+        "transition primitive callable rows have duplicate primitive opcode ",
+        primitive_opcode,
+        " for ",
+        direction_name);
+    TORCH_CHECK(
+        row[3] == 1 && row[status_column] == 1 && row[callable_column] > 0,
+        "transition primitive ",
+        direction_name,
+        " callable is not registered for primitive opcode ",
+        primitive_opcode);
+    selected = row[callable_column];
+  }
+  TORCH_CHECK(
+      selected > 0,
+      "transition primitive callable rows are missing primitive opcode ",
+      primitive_opcode,
+      " for ",
+      direction_name);
+  return selected;
+}
+
+inline int64_t registered_transition_forward_callable_hash_for_primitive(
+    const at::Tensor& transition_primitive_callable_rows,
+    int64_t primitive_opcode) {
+  return registered_transition_callable_hash_for_primitive(
+      transition_primitive_callable_rows,
+      primitive_opcode,
+      1,
+      4,
+      "forward");
+}
+
+inline int64_t registered_transition_backward_callable_hash_for_primitive(
+    const at::Tensor& transition_primitive_callable_rows,
+    int64_t primitive_opcode) {
+  return registered_transition_callable_hash_for_primitive(
+      transition_primitive_callable_rows,
+      primitive_opcode,
+      2,
+      5,
+      "backward");
+}
+
+#define REGISTERED_TEMPORAL_NATIVE_FORWARD_TRANSITION_CATALOG
+#include "../flat_bucket_registered_native_callables.cuh"
+
+inline const RegisteredTransitionForwardPrimitiveExecutor& registered_transition_forward_primitive_executor_for_callable(
+    int64_t native_callable_hash) {
+  for (const RegisteredTransitionForwardPrimitiveExecutor* executor =
+           registered_native_transition_forward_primitive_catalog_begin();
+       executor != registered_native_transition_forward_primitive_catalog_end();
+       ++executor) {
+    if (executor->native_callable_hash == native_callable_hash) {
+      return *executor;
+    }
+  }
+  TORCH_CHECK(
+      false,
+      "registered transition forward primitive executor is missing native callable hash ",
+      native_callable_hash);
+  return *registered_native_transition_forward_primitive_catalog_begin();
+}
+
+inline void run_registered_transition_forward_primitive_executor(
+    std::vector<at::Tensor>& program_tensors,
+    const at::Tensor& program_tensor_binding_rows,
+    const std::vector<at::Tensor>& runtime_buffer_tensors,
+    const at::Tensor& runtime_buffer_rows,
+    const at::Tensor& native_callable_binding_schema_rows,
+    const at::Tensor& native_callable_output_rows,
+    const at::Tensor& transition_primitive_callable_rows,
+    const at::Tensor& forward_executor_rows,
+    const at::Tensor& forward_executor_binding_rows,
+    const RegisteredFusedProgramSpan& span,
+    int64_t primitive_row_index,
+    int64_t primitive_opcode,
+    const std::vector<int64_t>& inputs,
+    const std::vector<int64_t>& params,
+    const std::vector<int64_t>& outputs) {
+  const int64_t native_callable_hash = registered_transition_forward_callable_hash_for_primitive(
+      transition_primitive_callable_rows,
+      primitive_opcode);
+  const RegisteredTransitionForwardPrimitiveExecutor& executor =
+      registered_transition_forward_primitive_executor_for_callable(native_callable_hash);
+  executor.run(
+      program_tensors,
+      program_tensor_binding_rows,
+      runtime_buffer_tensors,
+      runtime_buffer_rows,
+      native_callable_binding_schema_rows,
+      native_callable_output_rows,
+      forward_executor_rows,
+      forward_executor_binding_rows,
+      span,
+      primitive_row_index,
+      primitive_opcode,
+      native_callable_hash,
+      inputs,
+      params,
+      outputs);
+}
+
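+// Forward entry point for a registered fused transition program. Each
+// transition-surface span is first offered to the fused diag_rtu and
+// gated-logspace row-group fast paths; spans that neither fast path accepts
+// fall back to running every primitive row through its registered forward
+// executor, optionally releasing dead input bindings after each primitive.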
+std::vector<at::Tensor> flat_bucket_registered_temporal_fused_forward_transition_program_cuda(
+    std::vector<at::Tensor> program_tensors,
+    const at::Tensor& program_tensor_binding_rows,
+    std::vector<at::Tensor> runtime_buffer_tensors,
+    const at::Tensor& runtime_buffer_rows,
+    const at::Tensor& primitive_rows,
+    const at::Tensor& forward_executor_rows,
+    const at::Tensor& forward_handler_rows,
+    const at::Tensor& native_strategy_rows,
+    const at::Tensor& native_callable_binding_schema_rows,
+    const at::Tensor& native_callable_output_rows,
+    const at::Tensor& transition_primitive_callable_rows,
+    const at::Tensor& forward_executor_binding_rows,
+    const at::Tensor& memory_liveness_rows,
+    const at::Tensor& forward_transition_state_carry_rows,
+    bool release_dead_input_bindings,
+    bool allow_terminal_local_state_outputs,
+    int64_t schema_version) {
+  TORCH_CHECK(schema_version == 1, "registered fused transition program schema version mismatch");
+  check_program_tensor_binding_rows(program_tensor_binding_rows);
+  check_cpu_long_rank2(runtime_buffer_rows, "runtime_buffer_rows", 10);
+  validate_registered_native_callable_binding_schema_rows(native_callable_binding_schema_rows, schema_version);
+  validate_registered_native_callable_output_rows(native_callable_output_rows, schema_version);
+  check_cpu_long_rank2(primitive_rows, "primitive_rows", 4);
+  validate_registered_fused_program_executor_rows(primitive_rows, forward_executor_rows, "forward_executor_rows");
+  validate_registered_fused_program_binding_rows(
+      primitive_rows,
+      forward_executor_rows,
+      forward_executor_binding_rows,
+      kForwardDirectionOpcode,
+      "forward_executor_binding_rows");
+  validate_registered_fused_program_memory_rows(primitive_rows, memory_liveness_rows);
+  at::Tensor forward_spans = decode_registered_fused_program_spans(
+      primitive_rows,
+      forward_executor_rows,
+      forward_handler_rows,
+      native_strategy_rows,
+      forward_executor_binding_rows,
+      memory_liveness_rows,
+      kForwardDirectionOpcode,
+      schema_version,
+      native_strategy_rows.size(0) > 0,
+      "fused_transition_program");
+  validate_registered_fused_forward_span_memory(forward_spans, memory_liveness_rows);
+  const int64_t* spans = forward_spans.data_ptr<int64_t>();
+  const int64_t* primitive_table = primitive_rows.data_ptr<int64_t>();
+  bool saw_transition_span = false;
+  for (int64_t span_index = 0; span_index < forward_spans.size(0); ++span_index) {
+    const int64_t* span = spans + span_index * kFusedProgramSpanColumns;
+    if (span[3] != kTransitionSurfaceOpcode) {
+      continue;
+    }
+    saw_transition_span = true;
+    const int64_t primitive_start = span[5];
+    const int64_t primitive_count = span[6];
+    const int64_t receiver_count = span[8];
+    const int64_t bucket_ordinal = span[4];
+    const RegisteredFusedProgramSpan decoded_span = registered_fused_program_span_at(forward_spans, span_index);
+    if (try_run_registered_transition_diag_rtu_forward_row_group(
+            program_tensors,
+            program_tensor_binding_rows,
+            runtime_buffer_tensors,
+            runtime_buffer_rows,
+            primitive_rows,
+            native_callable_binding_schema_rows,
+            native_callable_output_rows,
+            transition_primitive_callable_rows,
+            forward_executor_rows,
+            forward_executor_binding_rows,
+            forward_transition_state_carry_rows,
+            decoded_span,
+            release_dead_input_bindings)) {
+      continue;
+    }
+    if (try_run_registered_transition_gated_logspace_forward_row_group(
+            program_tensors,
+            program_tensor_binding_rows,
+            runtime_buffer_tensors,
+            runtime_buffer_rows,
+            primitive_rows,
+            native_callable_binding_schema_rows,
+            native_callable_output_rows,
+            transition_primitive_callable_rows,
+            forward_executor_rows,
+            forward_executor_binding_rows,
+            forward_transition_state_carry_rows,
+            decoded_span,
+            release_dead_input_bindings,
+            allow_terminal_local_state_outputs)) {
+      continue;
+    }
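+    // Generic fallback: neither fused row-group fast path matched this span,
+    // so run each primitive row through its registered forward executor and,
+    // when requested, release dead input bindings after every primitive.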
native_callable_output_rows, + transition_primitive_callable_rows, + forward_executor_rows, + forward_executor_binding_rows, + forward_transition_state_carry_rows, + decoded_span, + release_dead_input_bindings)) { + continue; + } + if (try_run_registered_transition_gated_logspace_forward_row_group( + program_tensors, + program_tensor_binding_rows, + runtime_buffer_tensors, + runtime_buffer_rows, + primitive_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + transition_primitive_callable_rows, + forward_executor_rows, + forward_executor_binding_rows, + forward_transition_state_carry_rows, + decoded_span, + release_dead_input_bindings, + allow_terminal_local_state_outputs)) { + continue; + } + for (int64_t offset = 0; offset < primitive_count; ++offset) { + const int64_t primitive_row_index = primitive_start + offset; + const int64_t* primitive = primitive_table + primitive_row_index * 4; + TORCH_CHECK(primitive[3] == bucket_ordinal, "fused transition primitive bucket mismatch"); + TORCH_CHECK(primitive[2] == receiver_count, "fused transition primitive receiver count mismatch"); + const int64_t opcode = primitive[0]; + const std::vector inputs = + fused_input_bindings_for_primitive(forward_executor_binding_rows, primitive_row_index); + const std::vector params = + fused_parameter_bindings_for_primitive(forward_executor_binding_rows, primitive_row_index); + const std::vector outputs = + fused_output_bindings_for_primitive(forward_executor_binding_rows, primitive_row_index); + run_registered_transition_forward_primitive_executor( + program_tensors, + program_tensor_binding_rows, + runtime_buffer_tensors, + runtime_buffer_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + transition_primitive_callable_rows, + forward_executor_rows, + forward_executor_binding_rows, + decoded_span, + primitive_row_index, + opcode, + inputs, + params, + outputs); + if (release_dead_input_bindings) { + clear_forward_transition_dead_input_binding_slots_after_primitive( + program_tensors, + program_tensor_binding_rows, + forward_executor_binding_rows, + forward_transition_state_carry_rows, + decoded_span, + primitive_row_index, + inputs); + } + } + } + TORCH_CHECK(saw_transition_span, "fused transition program found no transition executor span"); + return program_tensors; +} diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_math_helpers.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_math_helpers.cuh new file mode 100644 index 00000000..b38450c8 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_math_helpers.cuh @@ -0,0 +1,316 @@ +#pragma once + +inline int64_t program_transition_linear_output_dim(const at::Tensor& weight) { + if (weight.dim() == 2) { + return weight.size(1); + } + TORCH_CHECK(weight.dim() == 3, "program transition linear weight must be rank-2 or rank-3"); + return weight.size(2); +} + +inline at::Tensor program_transition_linear_weight_view( + const at::Tensor& input, + const at::Tensor& weight, + int64_t group_size) { + check_cuda_float_bank(input, "input"); + TORCH_CHECK(weight.is_cuda(), "weight must be a CUDA tensor"); + TORCH_CHECK(weight.is_contiguous(), "weight must be contiguous"); + TORCH_CHECK(weight.scalar_type() == at::kFloat, "weight must be float32"); + TORCH_CHECK(group_size > 0, "program transition linear group_size must be positive"); + const int64_t receivers = 
input.size(1); + const int64_t input_dim = input.size(2); + if (weight.dim() == 2) { + TORCH_CHECK(weight.size(0) == input_dim, "program transition linear shared weight K must match input"); + return weight.view({1, input_dim, weight.size(1)}); + } + TORCH_CHECK(weight.dim() == 3, "program transition linear weight must be [K,N], [R,K,N], or [G,K,N]"); + TORCH_CHECK(weight.size(1) == input_dim, "program transition linear receiver weight K must match input"); + TORCH_CHECK( + weight.size(0) == receivers || weight.size(0) == 1 || weight.size(0) * group_size == receivers, + "program transition linear weight owner count must be R, 1, or G with G*group_size == R"); + return weight; +} + +inline void validate_program_transition_linear_bias( + const at::Tensor& bias, + int64_t receivers, + int64_t output_dim, + int64_t group_size) { + if (!bias.defined() || bias.numel() == 0) { + return; + } + TORCH_CHECK(bias.is_cuda(), "bias must be a CUDA tensor"); + TORCH_CHECK(bias.is_contiguous(), "bias must be contiguous"); + TORCH_CHECK(bias.scalar_type() == at::kFloat, "bias must be float32"); + TORCH_CHECK(bias.dim() == 1 || bias.dim() == 2, "program transition linear bias must be [N], [R,N], or [G,N]"); + if (bias.dim() == 1) { + TORCH_CHECK(bias.size(0) == output_dim, "program transition linear shared bias N must match output"); + return; + } + TORCH_CHECK(bias.size(1) == output_dim, "program transition linear receiver bias N must match output"); + TORCH_CHECK( + bias.size(0) == receivers || bias.size(0) == 1 || bias.size(0) * group_size == receivers, + "program transition linear bias owner count must be R, 1, or G with G*group_size == R"); +} + +inline at::Tensor program_transition_linear_gate_bias_view( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& bias) { + if (!bias.defined() || bias.numel() == 0) { + return input.new_empty({0}); + } + TORCH_CHECK(bias.is_cuda(), "program transition gate affine bias must be CUDA"); + TORCH_CHECK(bias.is_contiguous(), "program transition gate affine bias must be contiguous"); + TORCH_CHECK(bias.scalar_type() == at::kFloat, "program transition gate affine bias must be float32"); + TORCH_CHECK(bias.dim() == 4, "program transition gate affine bias must be [R,4,Heads,D]"); + TORCH_CHECK(bias.size(0) == weight.size(0), "program transition gate affine bias R must match weight"); + TORCH_CHECK(bias.size(1) == 4, "program transition gate affine bias gate count must be 4"); + TORCH_CHECK(bias.size(2) == weight.size(1), "program transition gate affine bias head count must match weight"); + TORCH_CHECK(bias.size(3) == weight.size(2), "program transition gate affine bias head dim must match weight"); + return bias.permute({0, 2, 1, 3}).reshape({weight.size(0) * weight.size(1), 4 * weight.size(2)}).contiguous(); +} + +inline void program_transition_linear_gate_affine_forward_into( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& bias, + const at::Tensor& output) { + check_cuda_float_bank(input, "program transition gate affine input"); + TORCH_CHECK(weight.is_cuda(), "program transition gate affine weight must be CUDA"); + TORCH_CHECK(weight.is_contiguous(), "program transition gate affine weight must be contiguous"); + TORCH_CHECK(weight.scalar_type() == at::kFloat, "program transition gate affine weight must be float32"); + TORCH_CHECK(weight.dim() == 4, "program transition gate affine weight must be [R,Heads,D,4D]"); + TORCH_CHECK(output.is_cuda(), "program transition gate affine output must be CUDA"); + 
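+  // Shape contract, restating the checks around this point (descriptive comment only):
+  // input is [B, R, Heads*D], weight is [R, Heads, D, 4*D], the optional bias is
+  // [R, 4, Heads, D], and output is [B, R, 4, Heads*D]; the batched GEMM below runs
+  // over R*Heads batches of [B, D] x [D, 4*D] and permutes the result back.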
TORCH_CHECK(output.is_contiguous(), "program transition gate affine output must be contiguous"); + TORCH_CHECK(output.scalar_type() == at::kFloat, "program transition gate affine output must be float32"); + const int64_t B = input.size(0); + const int64_t receivers = input.size(1); + const int64_t hidden = input.size(2); + const int64_t heads = weight.size(1); + const int64_t head_dim = weight.size(2); + TORCH_CHECK(weight.size(0) == receivers, "program transition gate affine weight R must match input"); + TORCH_CHECK(weight.size(3) == 4 * head_dim, "program transition gate affine weight output dim must be 4D"); + TORCH_CHECK(hidden == heads * head_dim, "program transition gate affine input hidden must equal Heads*D"); + TORCH_CHECK( + output.dim() == 4 && output.size(0) == B && output.size(1) == receivers && + output.size(2) == 4 && output.size(3) == hidden, + "program transition gate affine output must be [B,R,4,H]"); + const bool has_bias = bias.defined() && bias.numel() > 0; + if (has_bias) { + TORCH_CHECK(bias.is_cuda(), "program transition gate affine bias must be CUDA"); + TORCH_CHECK(bias.is_contiguous(), "program transition gate affine bias must be contiguous"); + TORCH_CHECK(bias.scalar_type() == at::kFloat, "program transition gate affine bias must be float32"); + TORCH_CHECK(bias.dim() == 4, "program transition gate affine bias must be [R,4,Heads,D]"); + TORCH_CHECK(bias.size(0) == receivers, "program transition gate affine bias R must match weight"); + TORCH_CHECK(bias.size(1) == 4, "program transition gate affine bias gate count must be 4"); + TORCH_CHECK(bias.size(2) == heads, "program transition gate affine bias head count must match weight"); + TORCH_CHECK(bias.size(3) == head_dim, "program transition gate affine bias head dim must match weight"); + } + const int64_t total = output.numel(); + if (total > 0) { + at::Tensor input_batches = input.reshape({B, receivers, heads, head_dim}) + .permute({1, 2, 0, 3}) + .reshape({receivers * heads, B, head_dim}) + .contiguous(); + at::Tensor weight_batches = weight.reshape({receivers * heads, head_dim, 4 * head_dim}); + at::Tensor projected = at::bmm(input_batches, weight_batches); + if (has_bias) { + at::Tensor bias_batches = bias.permute({0, 2, 1, 3}).reshape({receivers * heads, 4 * head_dim}); + projected.add_(bias_batches.unsqueeze(1)); + } + TORCH_CHECK( + projected.dim() == 3 && + projected.size(0) == receivers * heads && + projected.size(1) == B && + projected.size(2) == 4 * head_dim, + "program transition gate affine GEMM output shape mismatch"); + at::Tensor output_view = projected.reshape({receivers, heads, B, 4, head_dim}) + .permute({2, 0, 3, 1, 4}) + .reshape({B, receivers, 4, hidden}); + output.copy_(output_view); + } +} + +inline at::Tensor program_transition_linear_gate_affine_forward( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& bias) { + const int64_t B = input.size(0); + const int64_t receivers = input.size(1); + const int64_t hidden = input.size(2); + at::Tensor output = at::empty({B, receivers, 4, hidden}, input.options()); + program_transition_linear_gate_affine_forward_into(input, weight, bias, output); + return output; +} + +inline std::vector program_transition_linear_gate_affine_backward( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& bias, + const at::Tensor& grad_output) { + check_cuda_float_bank(input, "program transition gate affine input"); + check_cuda_float_rank4(grad_output, "program transition gate affine grad_output"); + 
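+  // Backward sketch for the gate-affine projection (descriptive comment only): grad_input
+  // and grad_weight reuse the generic program_transition_linear_*_backward kernels over the
+  // flattened R*Heads batch layout, and grad_bias is reduced over the batch dimension when
+  // a bias is bound.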
TORCH_CHECK(weight.is_cuda(), "program transition gate affine weight must be CUDA"); + TORCH_CHECK(weight.is_contiguous(), "program transition gate affine weight must be contiguous"); + TORCH_CHECK(weight.scalar_type() == at::kFloat, "program transition gate affine weight must be float32"); + TORCH_CHECK(weight.dim() == 4, "program transition gate affine weight must be [R,Heads,D,4D]"); + const int64_t B = input.size(0); + const int64_t receivers = input.size(1); + const int64_t hidden = input.size(2); + const int64_t heads = weight.size(1); + const int64_t head_dim = weight.size(2); + TORCH_CHECK(weight.size(0) == receivers, "program transition gate affine weight R must match input"); + TORCH_CHECK(weight.size(3) == 4 * head_dim, "program transition gate affine weight output dim must be 4D"); + TORCH_CHECK(hidden == heads * head_dim, "program transition gate affine input hidden must equal Heads*D"); + TORCH_CHECK(grad_output.size(0) == B, "program transition gate affine grad_output B must match input"); + TORCH_CHECK(grad_output.size(1) == receivers, "program transition gate affine grad_output R must match input"); + TORCH_CHECK(grad_output.size(2) == 4, "program transition gate affine grad_output gate count must be 4"); + TORCH_CHECK(grad_output.size(3) == hidden, "program transition gate affine grad_output H must match input"); + at::Tensor input_flat = input.view({B, receivers, heads, head_dim}).reshape({B, receivers * heads, head_dim}); + at::Tensor weight_flat = weight.reshape({receivers * heads, head_dim, 4 * head_dim}); + at::Tensor grad_flat = grad_output.view({B, receivers, 4, heads, head_dim}) + .permute({0, 1, 3, 2, 4}) + .reshape({B, receivers * heads, 4 * head_dim}) + .contiguous(); + at::Tensor grad_input_flat = at::empty_like(input_flat); + at::Tensor grad_weight_flat = at::empty_like(weight_flat); + const bool has_bias = bias.defined() && bias.numel() > 0; + at::Tensor grad_bias_flat = + has_bias ? 
at::empty({receivers * heads, 4 * head_dim}, input.options()) : input.new_empty({0}); + const auto stream = at::cuda::getCurrentCUDAStream(); + const int64_t input_total = grad_input_flat.numel(); + if (input_total > 0) { + const int blocks = static_cast(std::min( + 4096, + (input_total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + program_transition_linear_input_backward_kernel<<>>( + grad_flat.data_ptr(), + weight_flat.data_ptr(), + grad_input_flat.data_ptr(), + input_total, + static_cast(receivers * heads), + static_cast(head_dim), + static_cast(4 * head_dim), + static_cast(receivers * heads), + 1); + check_launch("program_transition_linear_gate_affine_input_backward_kernel"); + } + const int64_t weight_total = grad_weight_flat.numel(); + if (weight_total > 0) { + const int blocks = static_cast(std::min( + 4096, + (weight_total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + program_transition_linear_weight_backward_kernel<<>>( + input_flat.data_ptr(), + grad_flat.data_ptr(), + grad_weight_flat.data_ptr(), + weight_total, + static_cast(B), + static_cast(receivers * heads), + static_cast(head_dim), + static_cast(4 * head_dim), + static_cast(receivers * heads), + 1); + check_launch("program_transition_linear_gate_affine_weight_backward_kernel"); + } + if (has_bias) { + program_transition_linear_gate_bias_view(input, weight, bias); + const int64_t bias_total = grad_bias_flat.numel(); + const int blocks = static_cast(std::min( + 4096, + (bias_total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + program_transition_linear_bias_backward_kernel<<>>( + grad_flat.data_ptr(), + grad_bias_flat.data_ptr(), + bias_total, + static_cast(B), + static_cast(receivers * heads), + static_cast(4 * head_dim), + static_cast(receivers * heads), + 1, + false); + check_launch("program_transition_linear_gate_affine_bias_backward_kernel"); + } + at::Tensor grad_input = grad_input_flat.view({B, receivers, heads, head_dim}) + .reshape({B, receivers, hidden}) + .contiguous(); + at::Tensor grad_weight = grad_weight_flat.reshape({receivers, heads, head_dim, 4 * head_dim}).contiguous(); + at::Tensor grad_bias = has_bias + ? 
grad_bias_flat.view({receivers, heads, 4, head_dim}).permute({0, 2, 1, 3}).contiguous() + : input.new_empty({0}); + return {grad_input, grad_weight, grad_bias}; +} + +inline void check_program_gated_logits(const at::Tensor& tensor, const char* name) { + check_cuda_float_rank4(tensor, name); + TORCH_CHECK(tensor.size(2) == 4, name, " must have shape [B,R,4,H]"); +} + +inline bool check_optional_program_gated_tensor( + const at::Tensor& tensor, + const char* name, + at::IntArrayRef expected_shape) { + if (!tensor.defined() || tensor.numel() == 0) { + return false; + } + TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA tensor"); + TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous"); + TORCH_CHECK(tensor.scalar_type() == at::kFloat, name, " must be float32"); + TORCH_CHECK(tensor.sizes() == expected_shape, name, " shape mismatch"); + return true; +} + +inline bool check_optional_program_norm_weight( + const at::Tensor& weight, + int64_t receivers, + int64_t hidden) { + if (!weight.defined() || weight.numel() == 0) { + return false; + } + TORCH_CHECK(weight.is_cuda(), "norm_or_identity weight must be a CUDA tensor"); + TORCH_CHECK(weight.is_contiguous(), "norm_or_identity weight must be contiguous"); + TORCH_CHECK(weight.scalar_type() == at::kFloat, "norm_or_identity weight must be float32"); + TORCH_CHECK(weight.dim() == 1 || weight.dim() == 2, "norm_or_identity weight must be [H] or [R,H]"); + if (weight.dim() == 1) { + TORCH_CHECK(weight.size(0) == hidden, "norm_or_identity shared weight H must match input"); + return true; + } + TORCH_CHECK(weight.size(0) == receivers, "norm_or_identity receiver weight R must match input"); + TORCH_CHECK(weight.size(1) == hidden, "norm_or_identity receiver weight H must match input"); + return true; +} + +inline void check_program_diag_state_tensor( + const at::Tensor& tensor, + const char* name, + at::IntArrayRef expected_shape) { + TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA tensor"); + TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous"); + TORCH_CHECK(tensor.scalar_type() == at::kFloat, name, " must be float32"); + TORCH_CHECK(tensor.sizes() == expected_shape, name, " shape mismatch"); +} + +inline bool check_optional_program_diag_tensor( + const at::Tensor& tensor, + const char* name, + at::IntArrayRef expected_shape) { + if (!tensor.defined() || tensor.numel() == 0) { + return false; + } + check_program_diag_state_tensor(tensor, name, expected_shape); + return true; +} + +inline void check_program_diag_param_tensor( + const at::Tensor& tensor, + const char* name, + int64_t receivers, + int64_t hidden) { + TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA tensor"); + TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous"); + TORCH_CHECK(tensor.scalar_type() == at::kFloat, name, " must be float32"); + TORCH_CHECK(tensor.dim() == 2, name, " must be rank-2 [R,H]"); + TORCH_CHECK(tensor.size(0) == receivers, name, " receiver dimension mismatch"); + TORCH_CHECK(tensor.size(1) == hidden, name, " hidden dimension mismatch"); +} diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_primitive_forward_ops.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_primitive_forward_ops.cuh new file mode 100644 index 00000000..6e4bb34c --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_primitive_forward_ops.cuh @@ -0,0 +1,854 @@ +#pragma once + +std::vector 
flat_bucket_registered_program_transition_gated_logspace_recurrence_forward_cuda(
+    const at::Tensor& gate_logits,
+    const at::Tensor& recurrent_gate_logits,
+    const at::Tensor& c_prev,
+    const at::Tensor& n_prev,
+    const at::Tensor& m_prev,
+    const at::Tensor& forward_executor_rows,
+    const at::Tensor& forward_executor_binding_rows,
+    int64_t executor_id,
+    int64_t bucket_ordinal) {
+  TORCH_CHECK(bucket_ordinal >= 0, "program gated recurrence requires a transition bucket ordinal");
+  check_program_gated_logits(gate_logits, "gate_logits");
+  const int64_t B = gate_logits.size(0);
+  const int64_t receivers = gate_logits.size(1);
+  const int64_t hidden = gate_logits.size(3);
+  const auto gate_shape = gate_logits.sizes();
+  const std::vector<int64_t> state_shape = {B, receivers, hidden};
+  const bool has_recurrent_gate_logits = check_optional_program_gated_tensor(
+      recurrent_gate_logits,
+      "recurrent_gate_logits",
+      gate_shape);
+  const bool has_c_prev = check_optional_program_gated_tensor(c_prev, "c_prev", state_shape);
+  const bool has_n_prev = check_optional_program_gated_tensor(n_prev, "n_prev", state_shape);
+  const bool has_m_prev = check_optional_program_gated_tensor(m_prev, "m_prev", state_shape);
+  TORCH_CHECK(
+      has_c_prev == has_n_prev && has_c_prev == has_m_prev,
+      "program gated recurrence previous state tensors must be provided together or omitted together");
+  validate_registered_executor_binding_rows(
+      forward_executor_rows,
+      forward_executor_binding_rows,
+      kForwardDirectionOpcode,
+      executor_id,
+      bucket_ordinal,
+      "program gated recurrence forward");
+  at::Tensor next_y = at::empty({B, receivers, hidden}, gate_logits.options());
+  at::Tensor next_c = at::empty_like(next_y);
+  at::Tensor next_n = at::empty_like(next_y);
+  at::Tensor next_m = at::empty_like(next_y);
+  const int64_t total = next_y.numel();
+  if (total == 0) {
+    return {next_y, next_c, next_n, next_m};
+  }
+  const int blocks = static_cast<int>(std::min<int64_t>(
+      4096,
+      (total + kThreadsPerBlock - 1) / kThreadsPerBlock));
+  const auto stream = at::cuda::getCurrentCUDAStream();
+  program_transition_gated_logspace_recurrence_forward_kernel<<<blocks, kThreadsPerBlock, 0, stream>>>(
+      gate_logits.data_ptr<float>(),
+      has_recurrent_gate_logits ? recurrent_gate_logits.data_ptr<float>() : nullptr,
+      has_c_prev ? c_prev.data_ptr<float>() : nullptr,
+      has_n_prev ? n_prev.data_ptr<float>() : nullptr,
+      has_m_prev ?
m_prev.data_ptr() : nullptr, + next_y.data_ptr(), + next_c.data_ptr(), + next_n.data_ptr(), + next_m.data_ptr(), + total, + static_cast(receivers), + static_cast(hidden), + has_recurrent_gate_logits, + has_c_prev); + check_launch("program_transition_gated_logspace_recurrence_forward_kernel"); + return {next_y, next_c, next_n, next_m}; +} + +std::vector flat_bucket_registered_program_transition_gated_logspace_recurrence_backward_cuda( + const at::Tensor& gate_logits, + const at::Tensor& recurrent_gate_logits, + const at::Tensor& c_prev, + const at::Tensor& n_prev, + const at::Tensor& m_prev, + const at::Tensor& grad_next_y, + const at::Tensor& grad_next_c, + const at::Tensor& grad_next_n, + const at::Tensor& grad_next_m, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal, + bool return_state_grads) { + TORCH_CHECK(bucket_ordinal >= 0, "program gated recurrence backward requires a transition bucket ordinal"); + check_program_gated_logits(gate_logits, "gate_logits"); + const int64_t B = gate_logits.size(0); + const int64_t receivers = gate_logits.size(1); + const int64_t hidden = gate_logits.size(3); + const auto gate_shape = gate_logits.sizes(); + const std::vector state_shape = {B, receivers, hidden}; + const bool has_recurrent_gate_logits = check_optional_program_gated_tensor( + recurrent_gate_logits, + "recurrent_gate_logits", + gate_shape); + const bool has_c_prev = check_optional_program_gated_tensor(c_prev, "c_prev", state_shape); + const bool has_n_prev = check_optional_program_gated_tensor(n_prev, "n_prev", state_shape); + const bool has_m_prev = check_optional_program_gated_tensor(m_prev, "m_prev", state_shape); + TORCH_CHECK( + has_c_prev == has_n_prev && has_c_prev == has_m_prev, + "program gated recurrence previous state tensors must be provided together or omitted together"); + const bool has_grad_next_y = check_optional_program_gated_tensor(grad_next_y, "grad_next_y", state_shape); + const bool has_grad_next_c = check_optional_program_gated_tensor(grad_next_c, "grad_next_c", state_shape); + const bool has_grad_next_n = check_optional_program_gated_tensor(grad_next_n, "grad_next_n", state_shape); + const bool has_grad_next_m = check_optional_program_gated_tensor(grad_next_m, "grad_next_m", state_shape); + validate_registered_executor_binding_rows( + reverse_executor_rows, + reverse_executor_binding_rows, + kReverseDirectionOpcode, + executor_id, + bucket_ordinal, + "program gated recurrence backward"); + at::Tensor grad_raw = at::empty_like(gate_logits); + at::Tensor grad_c_prev = return_state_grads ? at::empty({B, receivers, hidden}, gate_logits.options()) + : at::empty({0}, gate_logits.options()); + at::Tensor grad_n_prev = return_state_grads ? at::empty_like(grad_c_prev) : at::empty({0}, gate_logits.options()); + at::Tensor grad_m_prev = return_state_grads ? at::empty_like(grad_c_prev) : at::empty({0}, gate_logits.options()); + const int64_t total = B * receivers * hidden; + if (total == 0) { + return {grad_raw, grad_c_prev, grad_n_prev, grad_m_prev}; + } + const int blocks = static_cast(std::min( + 4096, + (total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + program_transition_gated_logspace_recurrence_backward_kernel<<>>( + gate_logits.data_ptr(), + has_recurrent_gate_logits ? recurrent_gate_logits.data_ptr() : nullptr, + has_c_prev ? c_prev.data_ptr() : nullptr, + has_n_prev ? n_prev.data_ptr() : nullptr, + has_m_prev ? 
m_prev.data_ptr() : nullptr, + has_grad_next_y ? grad_next_y.data_ptr() : nullptr, + has_grad_next_c ? grad_next_c.data_ptr() : nullptr, + has_grad_next_n ? grad_next_n.data_ptr() : nullptr, + has_grad_next_m ? grad_next_m.data_ptr() : nullptr, + grad_raw.data_ptr(), + return_state_grads ? grad_c_prev.data_ptr() : nullptr, + return_state_grads ? grad_n_prev.data_ptr() : nullptr, + return_state_grads ? grad_m_prev.data_ptr() : nullptr, + total, + static_cast(receivers), + static_cast(hidden), + has_recurrent_gate_logits, + has_c_prev, + has_grad_next_y, + has_grad_next_c, + has_grad_next_n, + has_grad_next_m, + return_state_grads); + check_launch("program_transition_gated_logspace_recurrence_backward_kernel"); + return {grad_raw, grad_c_prev, grad_n_prev, grad_m_prev}; +} + +std::vector flat_bucket_registered_program_transition_norm_or_identity_forward_cuda( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal, + double eps) { + TORCH_CHECK(bucket_ordinal >= 0, "program norm_or_identity requires a transition bucket ordinal"); + check_cuda_float_bank(input, "input"); + const int64_t B = input.size(0); + const int64_t receivers = input.size(1); + const int64_t hidden = input.size(2); + TORCH_CHECK(hidden > 0, "program norm_or_identity hidden size must be positive"); + const bool has_weight = check_optional_program_norm_weight(weight, receivers, hidden); + const bool weight_is_shared = has_weight && weight.dim() == 1; + validate_registered_executor_binding_rows( + forward_executor_rows, + forward_executor_binding_rows, + kForwardDirectionOpcode, + executor_id, + bucket_ordinal, + "program norm_or_identity forward"); + at::Tensor output = at::empty_like(input); + const int64_t total = input.numel(); + if (total == 0) { + return {output}; + } + const int blocks = static_cast(std::min( + 4096, + (total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + program_transition_norm_or_identity_forward_kernel<<>>( + input.data_ptr(), + has_weight ? 
weight.data_ptr() : nullptr, + output.data_ptr(), + total, + static_cast(receivers), + static_cast(hidden), + has_weight, + weight_is_shared, + static_cast(eps)); + check_launch("program_transition_norm_or_identity_forward_kernel"); + return {output}; +} + +std::vector flat_bucket_registered_program_transition_norm_or_identity_backward_cuda( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& grad_output, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal, + double eps, + bool return_input_grad) { + TORCH_CHECK(bucket_ordinal >= 0, "program norm_or_identity backward requires a transition bucket ordinal"); + check_cuda_float_bank(input, "input"); + check_cuda_float_bank(grad_output, "grad_output"); + TORCH_CHECK(grad_output.sizes() == input.sizes(), "program norm_or_identity grad_output shape mismatch"); + const int64_t B = input.size(0); + const int64_t receivers = input.size(1); + const int64_t hidden = input.size(2); + TORCH_CHECK(hidden > 0, "program norm_or_identity hidden size must be positive"); + const bool has_weight = check_optional_program_norm_weight(weight, receivers, hidden); + const bool weight_is_shared = has_weight && weight.dim() == 1; + validate_registered_executor_binding_rows( + reverse_executor_rows, + reverse_executor_binding_rows, + kReverseDirectionOpcode, + executor_id, + bucket_ordinal, + "program norm_or_identity backward"); + at::Tensor grad_input = return_input_grad ? at::empty_like(input) : at::empty({0}, input.options()); + at::Tensor grad_weight = has_weight ? at::empty_like(weight) : at::empty({0}, input.options()); + const int64_t total = input.numel(); + const auto stream = at::cuda::getCurrentCUDAStream(); + if (total > 0 && return_input_grad) { + const int input_blocks = static_cast(std::min( + 4096, + (total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + program_transition_norm_or_identity_input_backward_kernel<<>>( + input.data_ptr(), + has_weight ? 
weight.data_ptr() : nullptr, + grad_output.data_ptr(), + grad_input.data_ptr(), + total, + static_cast(receivers), + static_cast(hidden), + has_weight, + weight_is_shared, + static_cast(eps)); + check_launch("program_transition_norm_or_identity_input_backward_kernel"); + } + if (has_weight && grad_weight.numel() > 0) { + const int weight_blocks = static_cast(std::min( + 4096, + (grad_weight.numel() + kThreadsPerBlock - 1) / kThreadsPerBlock)); + program_transition_norm_or_identity_weight_backward_kernel<<>>( + input.data_ptr(), + grad_output.data_ptr(), + grad_weight.data_ptr(), + grad_weight.numel(), + static_cast(B), + static_cast(receivers), + static_cast(hidden), + weight_is_shared, + static_cast(eps)); + check_launch("program_transition_norm_or_identity_weight_backward_kernel"); + } + return {grad_input, grad_weight}; +} + +std::vector flat_bucket_registered_program_transition_diag_rtu_forward_cuda( + const at::Tensor& cell_input, + const at::Tensor& hc1, + const at::Tensor& hc2, + const at::Tensor& e_nu_c1, + const at::Tensor& e_nu_c2, + const at::Tensor& e_th_c1, + const at::Tensor& e_th_c2, + const at::Tensor& e_w1_c1, + const at::Tensor& e_w1_c2, + const at::Tensor& e_w2_c1, + const at::Tensor& e_w2_c2, + const at::Tensor& nu_log, + const at::Tensor& theta_log, + const at::Tensor& w1, + const at::Tensor& w2, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal, + int64_t activation_id, + bool write_trace_state_next) { + TORCH_CHECK(bucket_ordinal >= 0, "program diag_rtu requires a transition bucket ordinal"); + check_cuda_float_bank(cell_input, "cell_input"); + const int64_t B = cell_input.size(0); + const int64_t receivers = cell_input.size(1); + const int64_t hidden = cell_input.size(2); + TORCH_CHECK(hidden > 0, "program diag_rtu hidden size must be positive"); + const std::vector state_shape = {B, receivers, hidden}; + check_program_diag_state_tensor(hc1, "hc1", state_shape); + check_program_diag_state_tensor(hc2, "hc2", state_shape); + if (write_trace_state_next) { + check_program_diag_state_tensor(e_nu_c1, "e_nu_c1", state_shape); + check_program_diag_state_tensor(e_nu_c2, "e_nu_c2", state_shape); + check_program_diag_state_tensor(e_th_c1, "e_th_c1", state_shape); + check_program_diag_state_tensor(e_th_c2, "e_th_c2", state_shape); + check_program_diag_state_tensor(e_w1_c1, "e_w1_c1", state_shape); + check_program_diag_state_tensor(e_w1_c2, "e_w1_c2", state_shape); + check_program_diag_state_tensor(e_w2_c1, "e_w2_c1", state_shape); + check_program_diag_state_tensor(e_w2_c2, "e_w2_c2", state_shape); + } + check_program_diag_param_tensor(nu_log, "nu_log", receivers, hidden); + check_program_diag_param_tensor(theta_log, "theta_log", receivers, hidden); + check_program_diag_param_tensor(w1, "w1", receivers, hidden); + check_program_diag_param_tensor(w2, "w2", receivers, hidden); + validate_registered_executor_binding_rows( + forward_executor_rows, + forward_executor_binding_rows, + kForwardDirectionOpcode, + executor_id, + bucket_ordinal, + "program diag_rtu forward"); + at::Tensor preproj = at::empty({B, receivers, 2 * hidden}, cell_input.options()); + at::Tensor next_hc1 = at::empty_like(cell_input); + at::Tensor next_hc2 = at::empty_like(cell_input); + at::Tensor empty = at::empty({0}, cell_input.options()); + at::Tensor next_e_nu_c1 = write_trace_state_next ? at::empty_like(cell_input) : empty; + at::Tensor next_e_nu_c2 = write_trace_state_next ? 
at::empty_like(cell_input) : empty; + at::Tensor next_e_th_c1 = write_trace_state_next ? at::empty_like(cell_input) : empty; + at::Tensor next_e_th_c2 = write_trace_state_next ? at::empty_like(cell_input) : empty; + at::Tensor next_e_w1_c1 = write_trace_state_next ? at::empty_like(cell_input) : empty; + at::Tensor next_e_w1_c2 = write_trace_state_next ? at::empty_like(cell_input) : empty; + at::Tensor next_e_w2_c1 = write_trace_state_next ? at::empty_like(cell_input) : empty; + at::Tensor next_e_w2_c2 = write_trace_state_next ? at::empty_like(cell_input) : empty; + const int64_t total = cell_input.numel(); + if (total > 0) { + const int blocks = static_cast(std::min( + 4096, + (total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + program_transition_diag_rtu_forward_kernel<<>>( + cell_input.data_ptr(), + hc1.data_ptr(), + hc2.data_ptr(), + write_trace_state_next ? e_nu_c1.data_ptr() : nullptr, + write_trace_state_next ? e_nu_c2.data_ptr() : nullptr, + write_trace_state_next ? e_th_c1.data_ptr() : nullptr, + write_trace_state_next ? e_th_c2.data_ptr() : nullptr, + write_trace_state_next ? e_w1_c1.data_ptr() : nullptr, + write_trace_state_next ? e_w1_c2.data_ptr() : nullptr, + write_trace_state_next ? e_w2_c1.data_ptr() : nullptr, + write_trace_state_next ? e_w2_c2.data_ptr() : nullptr, + nu_log.data_ptr(), + theta_log.data_ptr(), + w1.data_ptr(), + w2.data_ptr(), + preproj.data_ptr(), + next_hc1.data_ptr(), + next_hc2.data_ptr(), + write_trace_state_next ? next_e_nu_c1.data_ptr() : nullptr, + write_trace_state_next ? next_e_nu_c2.data_ptr() : nullptr, + write_trace_state_next ? next_e_th_c1.data_ptr() : nullptr, + write_trace_state_next ? next_e_th_c2.data_ptr() : nullptr, + write_trace_state_next ? next_e_w1_c1.data_ptr() : nullptr, + write_trace_state_next ? next_e_w1_c2.data_ptr() : nullptr, + write_trace_state_next ? next_e_w2_c1.data_ptr() : nullptr, + write_trace_state_next ? 
next_e_w2_c2.data_ptr() : nullptr, + total, + static_cast(receivers), + static_cast(hidden), + static_cast(activation_id), + true, + write_trace_state_next, + write_trace_state_next); + check_launch("program_transition_diag_rtu_forward_kernel"); + } + if (!write_trace_state_next) { + return {preproj, next_hc1, next_hc2}; + } + return { + preproj, + next_hc1, + next_hc2, + next_e_nu_c1, + next_e_nu_c2, + next_e_th_c1, + next_e_th_c2, + next_e_w1_c1, + next_e_w1_c2, + next_e_w2_c1, + next_e_w2_c2, + }; +} + +std::vector flat_bucket_registered_program_transition_diag_rtu_backward_cuda( + const at::Tensor& cell_input, + const at::Tensor& hc1, + const at::Tensor& hc2, + const at::Tensor& nu_log, + const at::Tensor& theta_log, + const at::Tensor& w1, + const at::Tensor& w2, + const at::Tensor& grad_preproj, + const at::Tensor& grad_hc1_next, + const at::Tensor& grad_hc2_next, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal, + int64_t activation_id, + bool return_state_grads) { + TORCH_CHECK(bucket_ordinal >= 0, "program diag_rtu backward requires a transition bucket ordinal"); + check_cuda_float_bank(cell_input, "cell_input"); + const int64_t B = cell_input.size(0); + const int64_t receivers = cell_input.size(1); + const int64_t hidden = cell_input.size(2); + TORCH_CHECK(hidden > 0, "program diag_rtu hidden size must be positive"); + const std::vector state_shape = {B, receivers, hidden}; + const std::vector preproj_shape = {B, receivers, 2 * hidden}; + check_program_diag_state_tensor(hc1, "hc1", state_shape); + check_program_diag_state_tensor(hc2, "hc2", state_shape); + check_program_diag_param_tensor(nu_log, "nu_log", receivers, hidden); + check_program_diag_param_tensor(theta_log, "theta_log", receivers, hidden); + check_program_diag_param_tensor(w1, "w1", receivers, hidden); + check_program_diag_param_tensor(w2, "w2", receivers, hidden); + const bool has_grad_preproj = check_optional_program_diag_tensor(grad_preproj, "grad_preproj", preproj_shape); + const bool has_grad_hc1_next = check_optional_program_diag_tensor(grad_hc1_next, "grad_hc1_next", state_shape); + const bool has_grad_hc2_next = check_optional_program_diag_tensor(grad_hc2_next, "grad_hc2_next", state_shape); + validate_registered_executor_binding_rows( + reverse_executor_rows, + reverse_executor_binding_rows, + kReverseDirectionOpcode, + executor_id, + bucket_ordinal, + "program diag_rtu backward"); + at::Tensor grad_cell_input = at::empty_like(cell_input); + at::Tensor grad_hc1 = return_state_grads ? at::empty_like(cell_input) : at::empty({0}, cell_input.options()); + at::Tensor grad_hc2 = return_state_grads ? at::empty_like(cell_input) : at::empty({0}, cell_input.options()); + at::Tensor empty = at::empty({0}, cell_input.options()); + at::Tensor grad_nu_log = at::empty_like(nu_log); + at::Tensor grad_theta_log = at::empty_like(theta_log); + at::Tensor grad_w1 = at::empty_like(w1); + at::Tensor grad_w2 = at::empty_like(w2); + const int64_t total = cell_input.numel(); + const auto stream = at::cuda::getCurrentCUDAStream(); + if (total > 0) { + const int blocks = static_cast(std::min( + 4096, + (total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + program_transition_diag_rtu_input_backward_kernel<<>>( + cell_input.data_ptr(), + hc1.data_ptr(), + hc2.data_ptr(), + nu_log.data_ptr(), + theta_log.data_ptr(), + w1.data_ptr(), + w2.data_ptr(), + has_grad_preproj ? grad_preproj.data_ptr() : nullptr, + has_grad_hc1_next ? 
grad_hc1_next.data_ptr() : nullptr, + has_grad_hc2_next ? grad_hc2_next.data_ptr() : nullptr, + grad_cell_input.data_ptr(), + return_state_grads ? grad_hc1.data_ptr() : nullptr, + return_state_grads ? grad_hc2.data_ptr() : nullptr, + total, + static_cast(receivers), + static_cast(hidden), + static_cast(activation_id), + has_grad_preproj, + has_grad_hc1_next, + has_grad_hc2_next, + return_state_grads); + check_launch("program_transition_diag_rtu_input_backward_kernel"); + } + const int64_t param_total = receivers * hidden; + if (param_total > 0) { + const int param_blocks = static_cast(std::min( + 4096, + (param_total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + program_transition_diag_rtu_param_backward_kernel<<>>( + cell_input.data_ptr(), + hc1.data_ptr(), + hc2.data_ptr(), + nu_log.data_ptr(), + theta_log.data_ptr(), + w1.data_ptr(), + w2.data_ptr(), + has_grad_preproj ? grad_preproj.data_ptr() : nullptr, + has_grad_hc1_next ? grad_hc1_next.data_ptr() : nullptr, + has_grad_hc2_next ? grad_hc2_next.data_ptr() : nullptr, + grad_nu_log.data_ptr(), + grad_theta_log.data_ptr(), + grad_w1.data_ptr(), + grad_w2.data_ptr(), + param_total, + static_cast(B), + static_cast(receivers), + static_cast(hidden), + static_cast(activation_id), + has_grad_preproj, + has_grad_hc1_next, + has_grad_hc2_next); + check_launch("program_transition_diag_rtu_param_backward_kernel"); + } + return { + grad_cell_input, + grad_hc1, + grad_hc2, + empty, + empty, + empty, + empty, + empty, + empty, + empty, + empty, + grad_nu_log, + grad_theta_log, + grad_w1, + grad_w2, + }; +} + +std::vector flat_bucket_registered_program_transition_linear_forward_cuda( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& bias, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal, + int64_t group_size) { + TORCH_CHECK(bucket_ordinal >= 0, "program transition linear requires a transition bucket ordinal"); + if (weight.dim() == 4) { + validate_registered_executor_binding_rows( + forward_executor_rows, + forward_executor_binding_rows, + kForwardDirectionOpcode, + executor_id, + bucket_ordinal, + "program transition gate affine forward"); + return {program_transition_linear_gate_affine_forward(input, weight, bias)}; + } + at::Tensor weight_view = program_transition_linear_weight_view(input, weight, group_size); + at::Tensor dense_weight = weight.dim() == 2 ? 
weight : weight_view; + const int64_t B = input.size(0); + const int64_t receivers = input.size(1); + const int64_t output_dim = program_transition_linear_output_dim(weight_view); + validate_program_transition_linear_bias(bias, receivers, output_dim, group_size); + validate_registered_executor_binding_rows( + forward_executor_rows, + forward_executor_binding_rows, + kForwardDirectionOpcode, + executor_id, + bucket_ordinal, + "program transition linear forward"); + at::Tensor output = at::empty({B, receivers, output_dim}, input.options()); + fabric::cuda::ops::dense_affine_out_cuda( + input, + dense_weight, + bias, + output, + fabric::cuda::ops::DenseAffineLayout::ReceiverMajor, + group_size, + fabric::cuda::ops::DenseAffineOutputMode::Overwrite); + return {output}; +} + +std::vector flat_bucket_registered_program_transition_linear_backward_cuda( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& bias, + const at::Tensor& grad_output, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal, + int64_t group_size) { + TORCH_CHECK(bucket_ordinal >= 0, "program transition linear backward requires a transition bucket ordinal"); + if (weight.dim() == 4) { + validate_registered_executor_binding_rows( + reverse_executor_rows, + reverse_executor_binding_rows, + kReverseDirectionOpcode, + executor_id, + bucket_ordinal, + "program transition gate affine backward"); + return program_transition_linear_gate_affine_backward(input, weight, bias, grad_output); + } + at::Tensor weight_view = program_transition_linear_weight_view(input, weight, group_size); + check_cuda_float_bank(grad_output, "grad_output"); + const int64_t B = input.size(0); + const int64_t receivers = input.size(1); + const int64_t input_dim = input.size(2); + const int64_t output_dim = program_transition_linear_output_dim(weight_view); + TORCH_CHECK(grad_output.size(0) == B, "program transition linear grad_output B must match input"); + TORCH_CHECK(grad_output.size(1) == receivers, "program transition linear grad_output R must match input"); + TORCH_CHECK(grad_output.size(2) == output_dim, "program transition linear grad_output N must match weight"); + validate_program_transition_linear_bias(bias, receivers, output_dim, group_size); + validate_registered_executor_binding_rows( + reverse_executor_rows, + reverse_executor_binding_rows, + kReverseDirectionOpcode, + executor_id, + bucket_ordinal, + "program transition linear backward"); + at::Tensor grad_input = at::empty_like(input); + at::Tensor grad_weight = at::empty_like(weight_view); + at::Tensor grad_bias = bias.defined() && bias.numel() > 0 ? 
at::empty_like(bias) : at::empty({0}, input.options()); + const auto stream = at::cuda::getCurrentCUDAStream(); + const int64_t input_total = grad_input.numel(); + if (input_total > 0) { + const int input_blocks = static_cast(std::min( + 4096, + (input_total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + program_transition_linear_input_backward_kernel<<>>( + grad_output.data_ptr(), + weight_view.data_ptr(), + grad_input.data_ptr(), + input_total, + static_cast(receivers), + static_cast(input_dim), + static_cast(output_dim), + static_cast(weight_view.size(0)), + static_cast(group_size)); + check_launch("program_transition_linear_input_backward_kernel"); + } + const int64_t weight_total = grad_weight.numel(); + if (weight_total > 0) { + const int weight_blocks = static_cast(std::min( + 4096, + (weight_total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + program_transition_linear_weight_backward_kernel<<>>( + input.data_ptr(), + grad_output.data_ptr(), + grad_weight.data_ptr(), + weight_total, + static_cast(B), + static_cast(receivers), + static_cast(input_dim), + static_cast(output_dim), + static_cast(weight_view.size(0)), + static_cast(group_size)); + check_launch("program_transition_linear_weight_backward_kernel"); + } + if (bias.defined() && bias.numel() > 0) { + const int64_t bias_total = grad_bias.numel(); + const int bias_blocks = static_cast(std::min( + 4096, + (bias_total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + program_transition_linear_bias_backward_kernel<<>>( + grad_output.data_ptr(), + grad_bias.data_ptr(), + bias_total, + static_cast(B), + static_cast(receivers), + static_cast(output_dim), + bias.dim() == 2 ? static_cast(bias.size(0)) : 1, + static_cast(group_size), + bias.dim() == 1); + check_launch("program_transition_linear_bias_backward_kernel"); + } + if (weight.dim() == 2) { + return {grad_input, grad_weight.view_as(weight), grad_bias}; + } + return {grad_input, grad_weight, grad_bias}; +} + +std::vector flat_bucket_registered_program_transition_recurrent_matmul_forward_cuda( + const at::Tensor& input, + const at::Tensor& recurrent_kernel, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal) { + TORCH_CHECK(bucket_ordinal >= 0, "program transition recurrent matmul requires a transition bucket ordinal"); + check_cuda_float_bank(input, "input"); + TORCH_CHECK(recurrent_kernel.is_cuda(), "recurrent_kernel must be a CUDA tensor"); + TORCH_CHECK(recurrent_kernel.is_contiguous(), "recurrent_kernel must be contiguous"); + TORCH_CHECK(recurrent_kernel.scalar_type() == at::kFloat, "recurrent_kernel must be float32"); + TORCH_CHECK(recurrent_kernel.dim() == 5, "recurrent_kernel must be [R,G,heads,head_dim,head_dim]"); + const int64_t B = input.size(0); + const int64_t receivers = input.size(1); + const int64_t hidden_dim = input.size(2); + const int64_t gate_count = recurrent_kernel.size(1); + const int64_t head_count = recurrent_kernel.size(2); + const int64_t head_dim = recurrent_kernel.size(3); + TORCH_CHECK(recurrent_kernel.size(0) == receivers, "recurrent_kernel R must match input"); + TORCH_CHECK(recurrent_kernel.size(4) == head_dim, "recurrent_kernel must have square head blocks"); + TORCH_CHECK(hidden_dim == head_count * head_dim, "input hidden dim must equal heads*head_dim"); + validate_registered_executor_binding_rows( + forward_executor_rows, + forward_executor_binding_rows, + kForwardDirectionOpcode, + executor_id, + bucket_ordinal, + "program transition recurrent matmul 
forward"); + at::Tensor output = at::empty({B, receivers, gate_count, hidden_dim}, input.options()); + const int64_t total = output.numel(); + if (total == 0) { + return {output}; + } + const int blocks = static_cast(std::min( + 4096, + (total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + program_transition_recurrent_matmul_forward_kernel<<>>( + input.data_ptr(), + recurrent_kernel.data_ptr(), + output.data_ptr(), + total, + static_cast(receivers), + static_cast(gate_count), + static_cast(head_count), + static_cast(head_dim)); + check_launch("program_transition_recurrent_matmul_forward_kernel"); + return {output}; +} + +std::vector flat_bucket_registered_program_transition_recurrent_matmul_backward_cuda( + const at::Tensor& input, + const at::Tensor& recurrent_kernel, + const at::Tensor& grad_output, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal, + bool return_input_grad) { + TORCH_CHECK(bucket_ordinal >= 0, "program transition recurrent matmul backward requires a transition bucket ordinal"); + check_cuda_float_bank(input, "input"); + check_cuda_float_rank4(grad_output, "grad_output"); + TORCH_CHECK(recurrent_kernel.is_cuda(), "recurrent_kernel must be a CUDA tensor"); + TORCH_CHECK(recurrent_kernel.is_contiguous(), "recurrent_kernel must be contiguous"); + TORCH_CHECK(recurrent_kernel.scalar_type() == at::kFloat, "recurrent_kernel must be float32"); + TORCH_CHECK(recurrent_kernel.dim() == 5, "recurrent_kernel must be [R,G,heads,head_dim,head_dim]"); + const int64_t B = input.size(0); + const int64_t receivers = input.size(1); + const int64_t hidden_dim = input.size(2); + const int64_t gate_count = recurrent_kernel.size(1); + const int64_t head_count = recurrent_kernel.size(2); + const int64_t head_dim = recurrent_kernel.size(3); + TORCH_CHECK(recurrent_kernel.size(0) == receivers, "recurrent_kernel R must match input"); + TORCH_CHECK(recurrent_kernel.size(4) == head_dim, "recurrent_kernel must have square head blocks"); + TORCH_CHECK(hidden_dim == head_count * head_dim, "input hidden dim must equal heads*head_dim"); + TORCH_CHECK(grad_output.size(0) == B, "grad_output B must match input"); + TORCH_CHECK(grad_output.size(1) == receivers, "grad_output R must match input"); + TORCH_CHECK(grad_output.size(2) == gate_count, "grad_output gate count must match recurrent_kernel"); + TORCH_CHECK(grad_output.size(3) == hidden_dim, "grad_output hidden dim must match input"); + validate_registered_executor_binding_rows( + reverse_executor_rows, + reverse_executor_binding_rows, + kReverseDirectionOpcode, + executor_id, + bucket_ordinal, + "program transition recurrent matmul backward"); + at::Tensor grad_input = return_input_grad ? 
at::empty_like(input) : at::empty({0}, input.options()); + at::Tensor grad_kernel = at::empty_like(recurrent_kernel); + const auto stream = at::cuda::getCurrentCUDAStream(); + if (return_input_grad && grad_input.numel() > 0) { + const int64_t input_total = grad_input.numel(); + const int input_blocks = static_cast(std::min( + 4096, + (input_total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + program_transition_recurrent_matmul_input_backward_kernel<<>>( + grad_output.data_ptr(), + recurrent_kernel.data_ptr(), + grad_input.data_ptr(), + input_total, + static_cast(receivers), + static_cast(gate_count), + static_cast(head_count), + static_cast(head_dim)); + check_launch("program_transition_recurrent_matmul_input_backward_kernel"); + } + const int64_t weight_total = grad_kernel.numel(); + if (weight_total > 0) { + const int weight_blocks = static_cast(std::min( + 4096, + (weight_total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + program_transition_recurrent_matmul_weight_backward_kernel<<>>( + input.data_ptr(), + grad_output.data_ptr(), + grad_kernel.data_ptr(), + weight_total, + static_cast(B), + static_cast(receivers), + static_cast(gate_count), + static_cast(head_count), + static_cast(head_dim)); + check_launch("program_transition_recurrent_matmul_weight_backward_kernel"); + } + return {grad_input, grad_kernel}; +} + +std::vector flat_bucket_registered_program_transition_tanh_forward_cuda( + const at::Tensor& input, + const at::Tensor& forward_executor_rows, + const at::Tensor& forward_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal) { + TORCH_CHECK(bucket_ordinal >= 0, "program transition tanh requires a transition bucket ordinal"); + check_cuda_float_bank(input, "input"); + validate_registered_executor_binding_rows( + forward_executor_rows, + forward_executor_binding_rows, + kForwardDirectionOpcode, + executor_id, + bucket_ordinal, + "program transition tanh forward", + false); + at::Tensor output = at::empty_like(input); + const int64_t total = output.numel(); + if (total > 0) { + const int blocks = static_cast(std::min( + 4096, + (total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + program_transition_tanh_forward_kernel<<>>( + input.data_ptr(), + output.data_ptr(), + total); + check_launch("program_transition_tanh_forward_kernel"); + } + return {output}; +} + +std::vector flat_bucket_registered_program_transition_tanh_backward_cuda( + const at::Tensor& output, + const at::Tensor& grad_output, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + int64_t executor_id, + int64_t bucket_ordinal) { + TORCH_CHECK(bucket_ordinal >= 0, "program transition tanh backward requires a transition bucket ordinal"); + check_cuda_float_bank(output, "output"); + check_cuda_float_bank(grad_output, "grad_output"); + TORCH_CHECK(grad_output.sizes() == output.sizes(), "program transition tanh grad_output shape mismatch"); + validate_registered_executor_binding_rows( + reverse_executor_rows, + reverse_executor_binding_rows, + kReverseDirectionOpcode, + executor_id, + bucket_ordinal, + "program transition tanh backward", + false); + at::Tensor grad_input = at::empty_like(output); + const int64_t total = grad_input.numel(); + if (total > 0) { + const int blocks = static_cast(std::min( + 4096, + (total + kThreadsPerBlock - 1) / kThreadsPerBlock)); + const auto stream = at::cuda::getCurrentCUDAStream(); + program_transition_tanh_backward_kernel<<>>( + output.data_ptr(), + 
grad_output.data_ptr(), + grad_input.data_ptr(), + total); + check_launch("program_transition_tanh_backward_kernel"); + } + return {grad_input}; +} diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_reverse_handlers.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_reverse_handlers.cuh new file mode 100644 index 00000000..3e21153c --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_reverse_handlers.cuh @@ -0,0 +1,937 @@ +inline void run_registered_tanh_reverse_transition_handler( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredFusedProgramSpan& span, + const std::vector& inputs, + const std::vector& params, + const std::vector& outputs, + int64_t native_callable_hash, + int64_t schema_version, + bool return_state_grads) { + (void)params; + (void)return_state_grads; + const auto input_binding = [&](const char* logical_name) { + return native_callable_program_binding_for( + inputs, + native_callable_binding_schema_rows, + native_callable_hash, + kReverseDirectionOpcode, + span.handler_primitive_opcode, + kNativeCallableBindingInput, + logical_name, + true, + schema_version, + "registered tanh reverse handler"); + }; + const auto output_binding = [&](const char* logical_name) { + return native_callable_program_binding_for( + outputs, + native_callable_binding_schema_rows, + native_callable_hash, + kReverseDirectionOpcode, + span.handler_primitive_opcode, + kNativeCallableBindingOutput, + logical_name, + true, + schema_version, + "registered tanh reverse handler"); + }; + at::Tensor output = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("output"), + "registered tanh reverse handler output"); + at::Tensor grad_output = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("grad_output"), + "registered tanh reverse handler grad_output"); + const std::vector backward = + flat_bucket_registered_program_transition_tanh_backward_cuda( + output, + grad_output, + reverse_executor_rows, + reverse_executor_binding_rows, + span.executor_id, + span.bucket_ordinal); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + output_binding("grad_input"), + backward[0], + "registered tanh reverse handler grad_input"); +} + +inline void run_registered_linear_reverse_transition_handler( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredFusedProgramSpan& span, + const std::vector& inputs, + const std::vector& params, + const std::vector& outputs, + int64_t native_callable_hash, + int64_t schema_version, + bool return_state_grads) { + (void)return_state_grads; + const auto input_binding = [&](const char* logical_name) { + return native_callable_program_binding_for( + inputs, + native_callable_binding_schema_rows, + native_callable_hash, + kReverseDirectionOpcode, + span.handler_primitive_opcode, + kNativeCallableBindingInput, + logical_name, + true, + schema_version, + "registered linear reverse handler"); + }; + const auto parameter_binding = 
[&](const char* logical_name, bool required) { + return native_callable_program_binding_for( + params, + native_callable_binding_schema_rows, + native_callable_hash, + kReverseDirectionOpcode, + span.handler_primitive_opcode, + kNativeCallableBindingParameter, + logical_name, + required, + schema_version, + "registered linear reverse handler"); + }; + const auto output_binding = [&](const char* logical_name) { + return native_callable_program_binding_for( + outputs, + native_callable_binding_schema_rows, + native_callable_hash, + kReverseDirectionOpcode, + span.handler_primitive_opcode, + kNativeCallableBindingOutput, + logical_name, + true, + schema_version, + "registered linear reverse handler"); + }; + at::Tensor input = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("input"), + "registered linear reverse handler input"); + at::Tensor grad_output = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("grad_output"), + "registered linear reverse handler grad_output"); + at::Tensor weight = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + parameter_binding("weight", true), + "registered linear reverse handler weight"); + const int64_t bias_binding = parameter_binding("bias", false); + at::Tensor bias = bias_binding >= 0 + ? program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + bias_binding, + "registered linear reverse handler bias") + : input.new_empty({0}); + const std::vector backward = + flat_bucket_registered_program_transition_linear_backward_cuda( + input, + weight, + bias, + grad_output, + reverse_executor_rows, + reverse_executor_binding_rows, + span.executor_id, + span.bucket_ordinal, + 1); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + output_binding("grad_input"), + backward[0], + "registered linear reverse handler grad_input"); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + output_binding("grad_weight"), + backward[1], + "registered linear reverse handler grad_weight"); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + output_binding("grad_bias"), + backward[2], + "registered linear reverse handler grad_bias"); +} + +inline void run_registered_matmul_reverse_transition_handler( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredFusedProgramSpan& span, + const std::vector& inputs, + const std::vector& params, + const std::vector& outputs, + int64_t native_callable_hash, + int64_t schema_version, + bool return_state_grads) { + const auto input_binding = [&](const char* logical_name) { + return native_callable_program_binding_for( + inputs, + native_callable_binding_schema_rows, + native_callable_hash, + kReverseDirectionOpcode, + span.handler_primitive_opcode, + kNativeCallableBindingInput, + logical_name, + true, + schema_version, + "registered matmul reverse handler"); + }; + const auto parameter_binding = [&](const char* logical_name) { + return native_callable_program_binding_for( + params, + native_callable_binding_schema_rows, + native_callable_hash, + kReverseDirectionOpcode, + span.handler_primitive_opcode, + kNativeCallableBindingParameter, + logical_name, + true, + schema_version, + "registered matmul reverse 
handler"); + }; + const auto output_binding = [&](const char* logical_name) { + return native_callable_program_binding_for( + outputs, + native_callable_binding_schema_rows, + native_callable_hash, + kReverseDirectionOpcode, + span.handler_primitive_opcode, + kNativeCallableBindingOutput, + logical_name, + true, + schema_version, + "registered matmul reverse handler"); + }; + at::Tensor input = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("input"), + "registered matmul reverse handler input"); + at::Tensor grad_output = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("grad_output"), + "registered matmul reverse handler grad_output"); + at::Tensor weight = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + parameter_binding("weight"), + "registered matmul reverse handler weight"); + const std::vector backward = + flat_bucket_registered_program_transition_recurrent_matmul_backward_cuda( + input, + weight, + grad_output, + reverse_executor_rows, + reverse_executor_binding_rows, + span.executor_id, + span.bucket_ordinal, + return_state_grads); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + output_binding("grad_input"), + backward[0], + "registered matmul reverse handler grad_input"); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + output_binding("grad_weight"), + backward[1], + "registered matmul reverse handler grad_weight"); +} + +inline void run_registered_norm_or_identity_reverse_transition_handler( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredFusedProgramSpan& span, + const std::vector& inputs, + const std::vector& params, + const std::vector& outputs, + int64_t native_callable_hash, + int64_t schema_version, + bool return_state_grads) { + (void)return_state_grads; + const auto input_binding = [&](const char* logical_name) { + return native_callable_program_binding_for( + inputs, + native_callable_binding_schema_rows, + native_callable_hash, + kReverseDirectionOpcode, + span.handler_primitive_opcode, + kNativeCallableBindingInput, + logical_name, + true, + schema_version, + "registered norm_or_identity reverse handler"); + }; + const auto parameter_binding = [&](const char* logical_name, bool required) { + return native_callable_program_binding_for( + params, + native_callable_binding_schema_rows, + native_callable_hash, + kReverseDirectionOpcode, + span.handler_primitive_opcode, + kNativeCallableBindingParameter, + logical_name, + required, + schema_version, + "registered norm_or_identity reverse handler"); + }; + const auto output_binding = [&](const char* logical_name) { + return native_callable_program_binding_for( + outputs, + native_callable_binding_schema_rows, + native_callable_hash, + kReverseDirectionOpcode, + span.handler_primitive_opcode, + kNativeCallableBindingOutput, + logical_name, + true, + schema_version, + "registered norm_or_identity reverse handler"); + }; + at::Tensor input = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("input"), + "registered norm_or_identity reverse handler input"); + at::Tensor grad_output = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("grad_output"), + 
"registered norm_or_identity reverse handler grad_output"); + at::Tensor weight = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + parameter_binding("weight", true), + "registered norm_or_identity reverse handler weight"); + const int64_t eps_binding = parameter_binding("eps", false); + const double eps = eps_binding >= 0 + ? program_scalar_double_for_binding( + program_tensors, + program_tensor_binding_rows, + eps_binding, + "registered norm_or_identity reverse handler eps") + : 1.0e-5; + const std::vector backward = + flat_bucket_registered_program_transition_norm_or_identity_backward_cuda( + input, + weight, + grad_output, + reverse_executor_rows, + reverse_executor_binding_rows, + span.executor_id, + span.bucket_ordinal, + eps, + true); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + output_binding("grad_input"), + backward[0], + "registered norm_or_identity reverse handler grad_input"); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + output_binding("grad_weight"), + backward[1], + "registered norm_or_identity reverse handler grad_weight"); +} + +inline int64_t registered_transition_reverse_primitive_row_index_for_handler( + const RegisteredReverseExecutorHandler& handler, + const RegisteredFusedProgramSpan& span, + const at::Tensor& primitive_rows, + const char* subject) { + TORCH_CHECK(handler.runs_transition_adjoint, subject, " received non-transition reverse handler ", handler.name); + TORCH_CHECK(span.surface_opcode == kTransitionSurfaceOpcode, subject, " received non-transition executor span"); + TORCH_CHECK( + span.primitive_row_count == handler.primitive_row_count, + subject, + " handler ", + handler.name, + " primitive row count mismatch"); + const int64_t primitive_row_index = span.primitive_row_start; + const int64_t* primitive = primitive_rows.data_ptr() + primitive_row_index * 4; + TORCH_CHECK( + primitive[0] == handler.primitive_opcode, + subject, + " handler ", + handler.name, + " primitive opcode mismatch"); + TORCH_CHECK( + primitive[3] == span.bucket_ordinal, + subject, + " handler ", + handler.name, + " primitive bucket mismatch"); + TORCH_CHECK( + primitive[2] == span.receiver_count, + subject, + " handler ", + handler.name, + " primitive receiver count mismatch"); + return primitive_row_index; +} + +inline void run_registered_gated_logspace_reverse_transition_handler( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredFusedProgramSpan& span, + const std::vector& inputs, + const std::vector& params, + const std::vector& outputs, + int64_t native_callable_hash, + int64_t schema_version, + bool return_state_grads) { + const auto input_binding = [&](const char* logical_name) { + return native_callable_program_binding_for( + inputs, + native_callable_binding_schema_rows, + native_callable_hash, + kReverseDirectionOpcode, + span.handler_primitive_opcode, + kNativeCallableBindingInput, + logical_name, + true, + schema_version, + "registered gated reverse handler"); + }; + const auto parameter_binding = [&](const char* logical_name, bool required) { + return native_callable_program_binding_for( + params, + native_callable_binding_schema_rows, + native_callable_hash, + kReverseDirectionOpcode, + span.handler_primitive_opcode, + kNativeCallableBindingParameter, + 
logical_name, + required, + schema_version, + "registered gated reverse handler"); + }; + const auto output_binding = [&](const char* logical_name) { + return native_callable_program_binding_for( + outputs, + native_callable_binding_schema_rows, + native_callable_hash, + kReverseDirectionOpcode, + span.handler_primitive_opcode, + kNativeCallableBindingOutput, + logical_name, + true, + schema_version, + "registered gated reverse handler"); + }; + at::Tensor aggregated_message = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("aggregated_message"), + "registered gated reverse handler aggregated_message"); + at::Tensor transition_input = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("transition_input"), + "registered gated reverse handler transition_input"); + at::Tensor gate_logits = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("gate_logits"), + "registered gated reverse handler gate_logits"); + at::Tensor recurrent_gate_logits = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("recurrent_gate_logits"), + "registered gated reverse handler recurrent_gate_logits"); + at::Tensor y_prev = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("y"), + "registered gated reverse handler y"); + at::Tensor c_prev = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("c"), + "registered gated reverse handler c"); + at::Tensor n_prev = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("n"), + "registered gated reverse handler n"); + at::Tensor m_prev = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("m"), + "registered gated reverse handler m"); + at::Tensor next_y = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("next_y"), + "registered gated reverse handler next_y"); + at::Tensor grad_public_y = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("grad_public_y"), + "registered gated reverse handler grad_public_y"); + at::Tensor grad_next_y = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("grad_next_y"), + "registered gated reverse handler grad_next_y"); + at::Tensor grad_next_c = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("grad_next_c"), + "registered gated reverse handler grad_next_c"); + at::Tensor grad_next_n = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("grad_next_n"), + "registered gated reverse handler grad_next_n"); + at::Tensor grad_next_m = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("grad_next_m"), + "registered gated reverse handler grad_next_m"); + at::Tensor value_to_state_weight = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + parameter_binding("value_to_state_weight", true), + "registered gated reverse handler value_to_state_weight"); + at::Tensor recurrent_bias = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + parameter_binding("recurrent_bias", true), + "registered gated reverse handler recurrent_bias"); + at::Tensor gate_weight = program_tensor_for_binding( + program_tensors, + 
program_tensor_binding_rows, + parameter_binding("gate_weight", true), + "registered gated reverse handler gate_weight"); + at::Tensor gate_bias = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + parameter_binding("bias", true), + "registered gated reverse handler gate_bias"); + at::Tensor recurrent_kernel = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + parameter_binding("recurrent_kernel", true), + "registered gated reverse handler recurrent_kernel"); + at::Tensor outnorm_weight = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + parameter_binding("outnorm_weight", true), + "registered gated reverse handler outnorm_weight"); + const int64_t eps_binding = parameter_binding("outnorm_eps", false); + const double eps = eps_binding >= 0 + ? program_scalar_double_for_binding( + program_tensors, + program_tensor_binding_rows, + eps_binding, + "registered gated reverse handler outnorm_eps") + : 1.0e-5; + + std::vector norm_backward = + flat_bucket_registered_program_transition_norm_or_identity_backward_cuda( + next_y, + outnorm_weight, + grad_public_y, + reverse_executor_rows, + reverse_executor_binding_rows, + span.executor_id, + span.bucket_ordinal, + eps, + true); + at::Tensor grad_next_y_total = norm_backward[0] + grad_next_y; + norm_backward[0] = at::Tensor(); + std::vector core_backward = + flat_bucket_registered_program_transition_gated_logspace_recurrence_backward_cuda( + gate_logits, + recurrent_gate_logits, + c_prev, + n_prev, + m_prev, + grad_next_y_total, + grad_next_c, + grad_next_n, + grad_next_m, + reverse_executor_rows, + reverse_executor_binding_rows, + span.executor_id, + span.bucket_ordinal, + return_state_grads); + const at::Tensor& grad_raw = core_backward[0]; + grad_next_y_total = at::Tensor(); + std::vector recurrent_backward = + flat_bucket_registered_program_transition_recurrent_matmul_backward_cuda( + y_prev, + recurrent_kernel, + grad_raw, + reverse_executor_rows, + reverse_executor_binding_rows, + span.executor_id, + span.bucket_ordinal, + return_state_grads); + std::vector gate_backward = + flat_bucket_registered_program_transition_linear_backward_cuda( + transition_input, + gate_weight, + gate_bias, + grad_raw, + reverse_executor_rows, + reverse_executor_binding_rows, + span.executor_id, + span.bucket_ordinal, + 1); + core_backward[0] = at::Tensor(); + std::vector input_backward = + flat_bucket_registered_program_transition_linear_backward_cuda( + aggregated_message, + value_to_state_weight, + recurrent_bias, + gate_backward[0], + reverse_executor_rows, + reverse_executor_binding_rows, + span.executor_id, + span.bucket_ordinal, + 1); + gate_backward[0] = at::Tensor(); + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, output_binding("grad_aggregated_message"), input_backward[0], "registered gated reverse handler grad message"); + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, output_binding("grad_y"), recurrent_backward[0], "registered gated reverse handler grad y"); + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, output_binding("grad_c"), core_backward[1], "registered gated reverse handler grad c"); + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, output_binding("grad_n"), core_backward[2], "registered gated reverse handler grad n"); + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, output_binding("grad_m"), 
core_backward[3], "registered gated reverse handler grad m"); + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, output_binding("grad_value_to_state_weight"), input_backward[1], "registered gated reverse handler grad input weight"); + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, output_binding("grad_recurrent_bias"), input_backward[2], "registered gated reverse handler grad input bias"); + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, output_binding("grad_gate_weight"), gate_backward[1], "registered gated reverse handler grad gate weight"); + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, output_binding("grad_bias"), gate_backward[2], "registered gated reverse handler grad gate bias"); + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, output_binding("grad_recurrent_kernel"), recurrent_backward[1], "registered gated reverse handler grad recurrent kernel"); + set_program_tensor_for_binding( + program_tensors, program_tensor_binding_rows, output_binding("grad_outnorm_weight"), norm_backward[1], "registered gated reverse handler grad outnorm"); +} + +inline void run_registered_diag_rtu_reverse_transition_handler( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredFusedProgramSpan& span, + const std::vector& inputs, + const std::vector& params, + const std::vector& outputs, + int64_t native_callable_hash, + int64_t schema_version, + bool return_state_grads) { + const auto input_binding = [&](const char* logical_name) { + return native_callable_program_binding_for( + inputs, + native_callable_binding_schema_rows, + native_callable_hash, + kReverseDirectionOpcode, + span.handler_primitive_opcode, + kNativeCallableBindingInput, + logical_name, + true, + schema_version, + "registered diag reverse handler"); + }; + const auto parameter_binding = [&](const char* logical_name, bool required) { + return native_callable_program_binding_for( + params, + native_callable_binding_schema_rows, + native_callable_hash, + kReverseDirectionOpcode, + span.handler_primitive_opcode, + kNativeCallableBindingParameter, + logical_name, + required, + schema_version, + "registered diag reverse handler"); + }; + const auto output_binding = [&](const char* logical_name) { + return native_callable_program_binding_for( + outputs, + native_callable_binding_schema_rows, + native_callable_hash, + kReverseDirectionOpcode, + span.handler_primitive_opcode, + kNativeCallableBindingOutput, + logical_name, + true, + schema_version, + "registered diag reverse handler"); + }; + at::Tensor aggregated_message = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("aggregated_message"), + "registered diag reverse handler aggregated_message"); + at::Tensor cell_input = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("cell_input"), + "registered diag reverse handler cell_input"); + at::Tensor hc1 = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("hc1"), + "registered diag reverse handler hc1"); + at::Tensor hc2 = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("hc2"), + "registered diag reverse handler 
hc2"); + at::Tensor preproj = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("preproj"), + "registered diag reverse handler preproj"); + at::Tensor public_y_raw = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("public_y_raw"), + "registered diag reverse handler public_y_raw"); + at::Tensor grad_public_y = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("grad_public_y"), + "registered diag reverse handler grad_public_y"); + at::Tensor grad_next_hc1 = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("grad_next_hc1"), + "registered diag reverse handler grad_next_hc1"); + at::Tensor grad_next_hc2 = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + input_binding("grad_next_hc2"), + "registered diag reverse handler grad_next_hc2"); + at::Tensor input_weight = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + parameter_binding("input_proj_weight", true), + "registered diag reverse handler input weight"); + at::Tensor input_bias = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + parameter_binding("recurrent_cell_bias", true), + "registered diag reverse handler input bias"); + at::Tensor nu_log = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + parameter_binding("nu_log", true), + "registered diag reverse handler nu_log"); + at::Tensor theta_log = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + parameter_binding("theta_log", true), + "registered diag reverse handler theta_log"); + at::Tensor w1 = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + parameter_binding("w1", true), + "registered diag reverse handler w1"); + at::Tensor w2 = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + parameter_binding("w2", true), + "registered diag reverse handler w2"); + at::Tensor output_weight = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + parameter_binding("out_proj_weight", true), + "registered diag reverse handler output weight"); + at::Tensor output_bias = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + parameter_binding("out_proj_bias", true), + "registered diag reverse handler output bias"); + at::Tensor outnorm_weight = program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + parameter_binding("outnorm_weight", true), + "registered diag reverse handler outnorm_weight"); + const int64_t activation_id_binding = parameter_binding("activation_id", false); + const int64_t activation_id = activation_id_binding >= 0 + ? program_scalar_int_for_binding( + program_tensors, + program_tensor_binding_rows, + activation_id_binding, + "registered diag reverse handler activation_id") + : 0; + const int64_t outnorm_eps_binding = parameter_binding("outnorm_eps", false); + const double outnorm_eps = outnorm_eps_binding >= 0 + ? 
program_scalar_double_for_binding( + program_tensors, + program_tensor_binding_rows, + outnorm_eps_binding, + "registered diag reverse handler outnorm_eps") + : 1.0e-6; + std::vector norm_backward = + flat_bucket_registered_program_transition_norm_or_identity_backward_cuda( + public_y_raw, + outnorm_weight, + grad_public_y, + reverse_executor_rows, + reverse_executor_binding_rows, + span.executor_id, + span.bucket_ordinal, + outnorm_eps, + true); + std::vector output_backward = + flat_bucket_registered_program_transition_linear_backward_cuda( + preproj, + output_weight, + output_bias, + norm_backward[0], + reverse_executor_rows, + reverse_executor_binding_rows, + span.executor_id, + span.bucket_ordinal, + 1); + norm_backward[0] = at::Tensor(); + std::vector core_backward = + flat_bucket_registered_program_transition_diag_rtu_backward_cuda( + cell_input, + hc1, + hc2, + nu_log, + theta_log, + w1, + w2, + output_backward[0], + grad_next_hc1, + grad_next_hc2, + reverse_executor_rows, + reverse_executor_binding_rows, + span.executor_id, + span.bucket_ordinal, + activation_id, + return_state_grads); + output_backward[0] = at::Tensor(); + std::vector input_backward = + flat_bucket_registered_program_transition_linear_backward_cuda( + aggregated_message, + input_weight, + input_bias, + core_backward[0], + reverse_executor_rows, + reverse_executor_binding_rows, + span.executor_id, + span.bucket_ordinal, + 1); + core_backward[0] = at::Tensor(); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + output_binding("grad_aggregated_message"), + input_backward[0], + "registered diag reverse handler grad message"); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + output_binding("grad_hc1"), + core_backward[1], + "registered diag reverse handler grad hc1"); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + output_binding("grad_hc2"), + core_backward[2], + "registered diag reverse handler grad hc2"); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + output_binding("grad_input_proj_weight"), + input_backward[1], + "registered diag reverse handler grad input weight"); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + output_binding("grad_recurrent_cell_bias"), + input_backward[2], + "registered diag reverse handler grad input bias"); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + output_binding("grad_nu_log"), + core_backward[11], + "registered diag reverse handler grad nu"); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + output_binding("grad_theta_log"), + core_backward[12], + "registered diag reverse handler grad theta"); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + output_binding("grad_w1"), + core_backward[13], + "registered diag reverse handler grad w1"); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + output_binding("grad_w2"), + core_backward[14], + "registered diag reverse handler grad w2"); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + output_binding("grad_out_proj_weight"), + output_backward[1], + "registered diag reverse handler grad output weight"); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + output_binding("grad_out_proj_bias"), + output_backward[2], + "registered diag reverse handler 
grad output bias"); + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + output_binding("grad_outnorm_weight"), + norm_backward[1], + "registered diag reverse handler grad outnorm"); +} diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_reverse_program.cuh b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_reverse_program.cuh new file mode 100644 index 00000000..7c590eed --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/transition_reverse_program.cuh @@ -0,0 +1,872 @@ +#pragma once + +#include + +constexpr int64_t kRegisteredBackwardMemoryStageTransitionGroupEntry = 101; +constexpr int64_t kRegisteredBackwardMemoryStageTransitionGroupParamsBound = 102; +constexpr int64_t kRegisteredBackwardMemoryStageTransitionGroupDynamicBound = 103; +constexpr int64_t kRegisteredBackwardMemoryStageTransitionGroupAfterForwardRecompute = 104; +constexpr int64_t kRegisteredBackwardMemoryStageTransitionGroupAfterReversePrimitive = 105; + +inline void append_registered_transition_reverse_memory_stage_row( + std::vector* rows, + const at::Tensor& reference, + int64_t local_step, + int64_t stage_id) { + if (rows == nullptr || !reference.defined() || !reference.is_cuda()) { + return; + } + const auto device_index = static_cast(reference.get_device()); + const auto stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_index); + const size_t aggregate = static_cast(c10::CachingAllocator::StatType::AGGREGATE); + rows->push_back(local_step); + rows->push_back(stage_id); + rows->push_back(stats.allocated_bytes[aggregate].current); + rows->push_back(stats.reserved_bytes[aggregate].current); + rows->push_back(stats.allocated_bytes[aggregate].peak); +} + +using RegisteredTransitionReversePrimitiveRunFn = void (*)( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredFusedProgramSpan& span, + const std::vector& inputs, + const std::vector& params, + const std::vector& outputs, + int64_t native_callable_hash, + int64_t schema_version, + bool return_state_grads); + +struct RegisteredTransitionReversePrimitiveExecutor { + int64_t native_callable_hash; + int64_t primitive_backward_callable_hash; + int64_t input_count; + int64_t min_param_count; + int64_t max_param_count; + int64_t output_count; + const char* name; + RegisteredTransitionReversePrimitiveRunFn run; +}; + +inline bool registered_reverse_callable_matches_native_strategy( + const RegisteredTransitionReversePrimitiveExecutor& executor, + const RegisteredReverseExecutorHandler& handler, + const RegisteredNativeStrategyRow& native_strategy, + int64_t primitive_backward_callable_hash) { + (void)handler; + return executor.native_callable_hash == native_strategy.native_callable_hash && + executor.primitive_backward_callable_hash == primitive_backward_callable_hash; +} + +inline void require_registered_reverse_primitive_binding_contract( + const RegisteredTransitionReversePrimitiveExecutor& executor, + const at::Tensor& native_callable_binding_schema_rows, + const RegisteredFusedProgramSpan& span, + const std::vector& inputs, + const std::vector& params, + const std::vector& outputs, + int64_t schema_version) { + const RegisteredNativeCallableBindingVectorContract input_contract = + 
require_native_callable_binding_vector_contract( + inputs, + native_callable_binding_schema_rows, + executor.native_callable_hash, + kReverseDirectionOpcode, + kTransitionSurfaceOpcode, + span.handler_primitive_opcode, + kNativeCallableBindingInput, + false, + schema_version, + executor.name); + TORCH_CHECK( + executor.input_count == input_contract.max_count, + executor.name, + " generated reverse transition input count does not match compiler native callable schema"); + const RegisteredNativeCallableBindingVectorContract parameter_contract = + require_native_callable_binding_vector_contract( + params, + native_callable_binding_schema_rows, + executor.native_callable_hash, + kReverseDirectionOpcode, + kTransitionSurfaceOpcode, + span.handler_primitive_opcode, + kNativeCallableBindingParameter, + true, + schema_version, + executor.name); + TORCH_CHECK( + executor.min_param_count == parameter_contract.min_count && + executor.max_param_count == parameter_contract.max_count, + executor.name, + " generated reverse transition parameter count does not match compiler native callable schema"); + const RegisteredNativeCallableBindingVectorContract output_contract = + require_native_callable_binding_vector_contract( + outputs, + native_callable_binding_schema_rows, + executor.native_callable_hash, + kReverseDirectionOpcode, + kTransitionSurfaceOpcode, + span.handler_primitive_opcode, + kNativeCallableBindingOutput, + false, + schema_version, + executor.name); + TORCH_CHECK( + executor.output_count == output_contract.max_count, + executor.name, + " generated reverse transition output count does not match compiler native callable schema"); +} + +#define REGISTERED_TEMPORAL_NATIVE_REVERSE_TRANSITION_CATALOG +#include "../flat_bucket_registered_native_callables.cuh" + +inline const RegisteredTransitionReversePrimitiveExecutor& registered_transition_reverse_primitive_executor_for_handler( + const RegisteredReverseExecutorHandler& handler, + const at::Tensor& native_strategy_rows, + const at::Tensor& transition_primitive_callable_rows, + const RegisteredFusedProgramSpan& span) { + const RegisteredNativeStrategyRow native_strategy = registered_native_strategy_row_for_span( + native_strategy_rows, + kReverseDirectionOpcode, + span, + "registered transition reverse primitive executor"); + const int64_t primitive_backward_callable_hash = registered_transition_backward_callable_hash_for_primitive( + transition_primitive_callable_rows, + native_strategy.primitive_opcode); + for (const RegisteredTransitionReversePrimitiveExecutor* executor = + registered_native_transition_reverse_primitive_catalog_begin(); + executor != registered_native_transition_reverse_primitive_catalog_end(); + ++executor) { + if (registered_reverse_callable_matches_native_strategy( + *executor, + handler, + native_strategy, + primitive_backward_callable_hash)) { + return *executor; + } + } + TORCH_CHECK( + false, + "registered transition reverse primitive executor is missing for compiler strategy contract: handler=", + handler.name, + ", handler_kind=", + handler.handler_kind, + ", executor_id=", + handler.executor_id, + ", primitive_opcode=", + native_strategy.primitive_opcode, + ", strategy_hash=", + native_strategy.strategy_id_hash, + ", access_count=", + native_strategy.program_access_count, + ", carry_count=", + native_strategy.state_carry_rule_count, + ", native_callable_hash=", + native_strategy.native_callable_hash, + ", primitive_backward_callable_hash=", + primitive_backward_callable_hash); + return 
*registered_native_transition_reverse_primitive_catalog_begin(); +} + +inline void run_registered_transition_reverse_handler( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& native_strategy_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& transition_primitive_callable_rows, + const RegisteredReverseExecutorHandler& handler, + const RegisteredFusedProgramSpan& span, + const std::vector& inputs, + const std::vector& params, + const std::vector& outputs, + int64_t schema_version, + bool return_state_grads) { + const RegisteredTransitionReversePrimitiveExecutor& executor = + registered_transition_reverse_primitive_executor_for_handler( + handler, + native_strategy_rows, + transition_primitive_callable_rows, + span); + require_registered_reverse_primitive_binding_contract( + executor, + native_callable_binding_schema_rows, + span, + inputs, + params, + outputs, + schema_version); + executor.run( + program_tensors, + program_tensor_binding_rows, + reverse_executor_rows, + reverse_executor_binding_rows, + native_callable_binding_schema_rows, + span, + inputs, + params, + outputs, + executor.native_callable_hash, + schema_version, + return_state_grads); +} + +std::vector flat_bucket_registered_temporal_fused_reverse_transition_program_cuda( + std::vector program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& primitive_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& reverse_handler_rows, + const at::Tensor& native_strategy_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& transition_primitive_callable_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& memory_liveness_rows, + int64_t schema_version, + bool return_state_grads) { + TORCH_CHECK(schema_version == 1, "registered fused reverse transition program schema version mismatch"); + (void)native_callable_output_rows; + validate_registered_native_callable_binding_schema_rows(native_callable_binding_schema_rows, schema_version); + check_program_tensor_binding_rows(program_tensor_binding_rows); + check_cpu_long_rank2(primitive_rows, "primitive_rows", 4); + validate_registered_fused_program_executor_rows(primitive_rows, reverse_executor_rows, "reverse_executor_rows"); + validate_registered_fused_program_binding_rows( + primitive_rows, + reverse_executor_rows, + reverse_executor_binding_rows, + kReverseDirectionOpcode, + "reverse_executor_binding_rows"); + validate_registered_fused_program_memory_rows(primitive_rows, memory_liveness_rows); + at::Tensor reverse_spans = decode_registered_fused_program_spans( + primitive_rows, + reverse_executor_rows, + reverse_handler_rows, + native_strategy_rows, + reverse_executor_binding_rows, + memory_liveness_rows, + kReverseDirectionOpcode, + schema_version, + native_strategy_rows.size(0) > 0, + "fused_reverse_transition_program"); + validate_registered_fused_reverse_span_memory(reverse_spans, memory_liveness_rows); + bool saw_transition_span = false; + for (int64_t span_index = 0; span_index < reverse_spans.size(0); ++span_index) { + const RegisteredFusedProgramSpan span = registered_fused_program_span_at(reverse_spans, span_index); + const RegisteredReverseExecutorHandler& handler = registered_reverse_executor_handler_for_span(span); + if (!handler.runs_transition_adjoint) { + 
continue; + } + const int64_t primitive_row_index = registered_transition_reverse_primitive_row_index_for_handler( + handler, + span, + primitive_rows, + "registered fused reverse transition program"); + const std::vector inputs = + fused_reverse_input_bindings_for_primitive(reverse_executor_binding_rows, primitive_row_index); + const std::vector params = + fused_reverse_parameter_bindings_for_primitive(reverse_executor_binding_rows, primitive_row_index); + const std::vector outputs = + fused_reverse_output_bindings_for_primitive(reverse_executor_binding_rows, primitive_row_index); + run_registered_transition_reverse_handler( + program_tensors, + program_tensor_binding_rows, + reverse_executor_rows, + reverse_executor_binding_rows, + native_strategy_rows, + native_callable_binding_schema_rows, + transition_primitive_callable_rows, + handler, + span, + inputs, + params, + outputs, + schema_version, + return_state_grads); + saw_transition_span = true; + } + TORCH_CHECK(saw_transition_span, "fused reverse transition program found no transition executor span"); + return program_tensors; +} + +inline void require_registered_transition_forward_span_for_reverse_span( + const at::Tensor& forward_spans, + const RegisteredFusedProgramSpan& reverse_span) { + check_cpu_long_rank2(forward_spans, "forward_spans", kFusedProgramSpanColumns); + bool saw_match = false; + for (int64_t span_index = 0; span_index < forward_spans.size(0); ++span_index) { + const RegisteredFusedProgramSpan forward_span = registered_fused_program_span_at(forward_spans, span_index); + const RegisteredForwardExecutorHandler& handler = registered_forward_executor_handler_for_span(forward_span); + if (!handler.runs_transition_program || forward_span.bucket_ordinal != reverse_span.bucket_ordinal) { + continue; + } + TORCH_CHECK( + forward_span.receiver_start == reverse_span.receiver_start && + forward_span.receiver_count == reverse_span.receiver_count, + "registered transition reverse handler has mismatched forward transition receiver span"); + saw_match = true; + } + TORCH_CHECK(saw_match, "registered transition reverse handler has no matching forward transition handler span"); +} + +inline std::vector registered_transition_dynamic_tensor_shape(const at::Tensor& tensor) { + std::vector shape; + shape.reserve(static_cast(tensor.dim())); + for (int64_t dim = 0; dim < tensor.dim(); ++dim) { + shape.push_back(tensor.size(dim)); + } + return shape; +} + +inline void bind_transition_dynamic_tensors_for_reverse_handler( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const RegisteredReverseExecutorHandler& handler, + const RegisteredFusedProgramSpan& span, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& transition_dynamic_binding_rows, + const at::Tensor& transition_reverse_seed_role_rows, + const std::vector& reverse_artifact_tensors, + const at::Tensor& reverse_artifact_binding_rows, + const at::Tensor& reverse_artifact_access_rows, + const std::vector& transition_seed_tensors, + const at::Tensor& transition_seed_rows, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + std::vector& cached_zero_seed_tensors, + int64_t local_step) { + check_cpu_long_rank2(forward_executor_binding_rows, "forward_executor_binding_rows", 8); + TORCH_CHECK(handler.runs_transition_adjoint, "registered transition dynamic binder received non-transition handler"); + check_transition_dynamic_binding_rows(transition_dynamic_binding_rows, transition_reverse_seed_role_rows); 
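+  // Layout note (inferred from how the loop below decodes each row, not from a
+  // separate schema source): transition_dynamic_binding_rows carries five int64
+  // columns per row,
+  //   [binding_index, source_kind, source_key, template_binding_index, required],
+  // and source_kind selects one of three compiler-declared sources:
+  //   kTransitionDynamicSourceReverseArtifact     -> saved reverse artifact at this local step
+  //     (the recurrent-message artifact is additionally sliced to the span's receiver
+  //     range, copied into a reusable runtime buffer when the slice is not contiguous),
+  //   kTransitionDynamicSourceStateBeforeArtifact -> optional state-before artifact,
+  //     replaced by a zeroed runtime buffer when absent and not required,
+  //   kTransitionDynamicSourceSeedOrZeros         -> transition seed tensor or cached zeros.
+  // Any other source kind fails closed via TORCH_CHECK.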
+ const int64_t bucket_ordinal = span.bucket_ordinal; + const int64_t* rows = transition_dynamic_binding_rows.data_ptr(); + bool saw_dynamic_binding = false; + bool saw_message_binding = false; + for (int64_t row_index = 0; row_index < transition_dynamic_binding_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 5; + const int64_t binding_index = row[0]; + const int64_t source_kind = row[1]; + const int64_t source_key = row[2]; + const int64_t template_binding_index = row[3]; + const bool required = row[4] != 0; + at::Tensor value; + if (source_kind == kTransitionDynamicSourceReverseArtifact) { + value = reverse_artifact_tensor_for_access_step( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_access_rows, + source_key, + local_step, + "registered transition reverse dynamic artifact"); + if (source_key == kReverseArtifactAccessRecurrentMsgBackendOrder) { + at::Tensor span_view = value.slice(1, span.receiver_start, span.receiver_start + span.receiver_count); + if (span.receiver_start == 0 && span.receiver_count == value.size(1) && value.is_contiguous()) { + value = span_view; + } else { + value = registered_runtime_buffer_for_role( + runtime_buffer_tensors, + runtime_buffer_rows, + kRuntimeBufferRoleTransitionReverseRecurrentMsgSpan, + bucket_ordinal, + registered_transition_dynamic_tensor_shape(span_view), + "registered transition reverse dynamic recurrent-message span"); + value.copy_(span_view); + } + saw_message_binding = true; + } else { + value = value.contiguous(); + } + } else if (source_kind == kTransitionDynamicSourceStateBeforeArtifact) { + value = optional_reverse_artifact_tensor_for_transition_state_binding( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_access_rows, + local_step, + bucket_ordinal, + source_key, + "registered transition reverse dynamic state-before artifact"); + if (!value.defined()) { + TORCH_CHECK( + !required, + "registered transition reverse dynamic state-before artifact is missing for binding ", + binding_index); + at::Tensor reference; + if (template_binding_index >= 0) { + reference = program_tensor_for_binding_allow_empty( + program_tensors, + program_tensor_binding_rows, + template_binding_index, + "registered transition reverse dynamic state-before zero template"); + } + if (!reference.defined() || reference.numel() == 0) { + reference = reverse_artifact_tensor_for_access_step( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_access_rows, + kReverseArtifactAccessRecurrentHiddenBackendOrder, + local_step, + "registered transition reverse dynamic state-before zero template") + .slice(1, span.receiver_start, span.receiver_start + span.receiver_count); + } + value = registered_runtime_buffer_for_role( + runtime_buffer_tensors, + runtime_buffer_rows, + kRuntimeBufferRoleTransitionReverseStateBeforeZero, + bucket_ordinal, + registered_transition_dynamic_tensor_shape(reference), + "registered transition reverse dynamic state-before zero"); + value.zero_(); + } else { + value = value.contiguous(); + } + } else if (source_kind == kTransitionDynamicSourceSeedOrZeros) { + at::Tensor reference; + if (template_binding_index >= 0) { + reference = program_tensor_for_binding_allow_empty( + program_tensors, + program_tensor_binding_rows, + template_binding_index, + "registered transition reverse dynamic seed template"); + } + if (!reference.defined() || reference.numel() == 0) { + reference = reverse_artifact_tensor_for_access_step( + 
reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_access_rows, + kReverseArtifactAccessRecurrentHiddenBackendOrder, + local_step, + "registered transition reverse dynamic public-state seed template") + .slice(1, span.receiver_start, span.receiver_start + span.receiver_count); + } + value = transition_seed_tensor_or_cached_zeros( + transition_seed_tensors, + transition_seed_rows, + source_key, + bucket_ordinal, + reference, + cached_zero_seed_tensors, + "registered transition reverse dynamic seed"); + } else { + TORCH_CHECK(false, "registered transition reverse dynamic binding has unsupported source kind ", source_kind); + } + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + binding_index, + value, + "registered transition reverse dynamic binding"); + saw_dynamic_binding = true; + } + TORCH_CHECK(saw_dynamic_binding, "registered transition reverse dynamic binder had no compiler binding rows"); + TORCH_CHECK(saw_message_binding, "registered transition reverse dynamic binder had no recurrent-message binding row"); +} + +inline void bind_transition_dynamic_tensors_for_handlers( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& primitive_rows, + const at::Tensor& forward_spans, + const at::Tensor& reverse_spans, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& transition_dynamic_binding_rows, + const at::Tensor& transition_reverse_seed_role_rows, + const std::vector& reverse_artifact_tensors, + const at::Tensor& reverse_artifact_binding_rows, + const at::Tensor& reverse_artifact_access_rows, + const std::vector& transition_seed_tensors, + const at::Tensor& transition_seed_rows, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + int64_t local_step) { + (void)primitive_rows; + (void)reverse_executor_binding_rows; + bool saw_transition_handler = false; + std::vector cached_zero_seed_tensors; + for (int64_t span_index = 0; span_index < reverse_spans.size(0); ++span_index) { + const RegisteredFusedProgramSpan span = registered_fused_program_span_at(reverse_spans, span_index); + const RegisteredReverseExecutorHandler& handler = registered_reverse_executor_handler_for_span(span); + if (!handler.runs_transition_adjoint) { + continue; + } + require_registered_transition_forward_span_for_reverse_span(forward_spans, span); + bind_transition_dynamic_tensors_for_reverse_handler( + program_tensors, + program_tensor_binding_rows, + handler, + span, + forward_executor_binding_rows, + transition_dynamic_binding_rows, + transition_reverse_seed_role_rows, + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_access_rows, + transition_seed_tensors, + transition_seed_rows, + runtime_buffer_tensors, + runtime_buffer_rows, + cached_zero_seed_tensors, + local_step); + saw_transition_handler = true; + } + TORCH_CHECK(saw_transition_handler, "registered transition dynamic binder found no reverse transition handler span"); +} + +inline void bind_transition_parameter_tensors( + std::vector& program_tensors, + const at::Tensor& program_tensor_binding_rows, + const std::vector& transition_parameter_tensors, + const at::Tensor& transition_parameter_rows, + int64_t bucket_ordinal) { + const int64_t* rows = transition_parameter_rows.data_ptr(); + for (int64_t row_index = 0; row_index < transition_parameter_rows.size(0); ++row_index) { + const int64_t* row = rows + row_index * 3; + 
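+    // Row layout, as consumed here: three int64 columns per row,
+    //   [program_binding_index, parameter_tensor_index, bucket_ordinal].
+    // Rows whose bucket_ordinal differs from the current bucket are skipped, so each
+    // transition bucket binds only its own compiler-declared parameter tensors.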
if (row[2] != bucket_ordinal) { + continue; + } + const int64_t binding_index = row[0]; + const int64_t tensor_index = row[1]; + set_program_tensor_for_binding( + program_tensors, + program_tensor_binding_rows, + binding_index, + transition_parameter_tensors[static_cast(tensor_index)], + "registered transition parameter"); + } +} + +static std::vector registered_temporal_backward_transition_group_impl( + std::vector program_tensors, + const at::Tensor& program_tensor_binding_rows, + const at::Tensor& primitive_rows, + const at::Tensor& forward_executor_rows, + const at::Tensor& reverse_executor_rows, + const at::Tensor& forward_handler_rows, + const at::Tensor& reverse_handler_rows, + const at::Tensor& native_strategy_rows, + const at::Tensor& native_callable_binding_schema_rows, + const at::Tensor& native_callable_output_rows, + const at::Tensor& transition_primitive_callable_rows, + const at::Tensor& forward_executor_binding_rows, + const at::Tensor& reverse_executor_binding_rows, + const at::Tensor& memory_liveness_rows, + const std::vector& runtime_buffer_tensors, + const at::Tensor& runtime_buffer_rows, + std::vector reverse_artifact_tensors, + const at::Tensor& reverse_artifact_binding_rows, + const at::Tensor& reverse_artifact_role_rows, + const at::Tensor& reverse_artifact_access_rows, + std::vector transition_seed_tensors, + const at::Tensor& transition_reverse_seed_role_rows, + const at::Tensor& transition_seed_rows, + const at::Tensor& transition_dynamic_binding_rows, + std::vector transition_parameter_tensors, + const at::Tensor& transition_parameter_rows, + int64_t local_step, + std::vector* memory_stage_rows, + int64_t schema_version, + bool return_state_grads) { + TORCH_CHECK(schema_version == 1, "registered fused backward transition group schema version mismatch"); + check_program_tensor_binding_rows(program_tensor_binding_rows); + check_transition_seed_rows(transition_seed_tensors, transition_reverse_seed_role_rows, transition_seed_rows); + check_transition_dynamic_binding_rows(transition_dynamic_binding_rows, transition_reverse_seed_role_rows); + check_transition_parameter_rows(transition_parameter_tensors, transition_parameter_rows); + const int64_t window_len = reverse_artifact_binding_window_len(reverse_artifact_binding_rows); + std::vector tensor_required = validate_temporal_reverse_artifact_role_rows(reverse_artifact_role_rows); + validate_temporal_reverse_artifact_access_rows(reverse_artifact_access_rows, tensor_required); + validate_temporal_reverse_artifact_binding_rows( + reverse_artifact_tensors, + reverse_artifact_binding_rows, + tensor_required, + window_len); + TORCH_CHECK(local_step >= 0 && local_step < window_len, "transition group local_step is outside reverse artifact window"); + at::Tensor memory_reference = reverse_artifact_tensors.empty() ? 
at::Tensor() : reverse_artifact_tensors[0]; + append_registered_transition_reverse_memory_stage_row( + memory_stage_rows, + memory_reference, + local_step, + kRegisteredBackwardMemoryStageTransitionGroupEntry); + validate_registered_fused_program_executor_rows(primitive_rows, forward_executor_rows, "forward_executor_rows"); + validate_registered_fused_program_executor_rows(primitive_rows, reverse_executor_rows, "reverse_executor_rows"); + validate_registered_fused_program_binding_rows( + primitive_rows, + forward_executor_rows, + forward_executor_binding_rows, + kForwardDirectionOpcode, + "forward_executor_binding_rows"); + validate_registered_fused_program_binding_rows( + primitive_rows, + reverse_executor_rows, + reverse_executor_binding_rows, + kReverseDirectionOpcode, + "reverse_executor_binding_rows"); + validate_registered_fused_program_memory_rows(primitive_rows, memory_liveness_rows); + at::Tensor forward_spans = decode_registered_fused_program_spans( + primitive_rows, + forward_executor_rows, + forward_handler_rows, + native_strategy_rows, + forward_executor_binding_rows, + memory_liveness_rows, + kForwardDirectionOpcode, + schema_version, + native_strategy_rows.size(0) > 0, + "registered_transition_group_forward"); + at::Tensor reverse_spans = decode_registered_fused_program_spans( + primitive_rows, + reverse_executor_rows, + reverse_handler_rows, + native_strategy_rows, + reverse_executor_binding_rows, + memory_liveness_rows, + kReverseDirectionOpcode, + schema_version, + native_strategy_rows.size(0) > 0, + "registered_transition_group_reverse"); + validate_registered_fused_forward_span_memory(forward_spans, memory_liveness_rows); + validate_registered_fused_reverse_span_memory(reverse_spans, memory_liveness_rows); + bool saw_transition_handler = false; + for (int64_t span_index = 0; span_index < reverse_spans.size(0); ++span_index) { + const RegisteredFusedProgramSpan span = registered_fused_program_span_at(reverse_spans, span_index); + const RegisteredReverseExecutorHandler& handler = registered_reverse_executor_handler_for_span(span); + if (!handler.runs_transition_adjoint) { + continue; + } + bind_transition_parameter_tensors( + program_tensors, + program_tensor_binding_rows, + transition_parameter_tensors, + transition_parameter_rows, + span.bucket_ordinal); + saw_transition_handler = true; + } + TORCH_CHECK(saw_transition_handler, "registered transition group found no reverse transition handler span"); + append_registered_transition_reverse_memory_stage_row( + memory_stage_rows, + memory_reference, + local_step, + kRegisteredBackwardMemoryStageTransitionGroupParamsBound); + bind_transition_dynamic_tensors_for_handlers( + program_tensors, + program_tensor_binding_rows, + primitive_rows, + forward_spans, + reverse_spans, + forward_executor_binding_rows, + reverse_executor_binding_rows, + transition_dynamic_binding_rows, + transition_reverse_seed_role_rows, + reverse_artifact_tensors, + reverse_artifact_binding_rows, + reverse_artifact_access_rows, + transition_seed_tensors, + transition_seed_rows, + runtime_buffer_tensors, + runtime_buffer_rows, + local_step); + append_registered_transition_reverse_memory_stage_row( + memory_stage_rows, + memory_reference, + local_step, + kRegisteredBackwardMemoryStageTransitionGroupDynamicBound); + program_tensors = flat_bucket_registered_temporal_fused_forward_transition_program_cuda( + program_tensors, + program_tensor_binding_rows, + runtime_buffer_tensors, + runtime_buffer_rows, + primitive_rows, + forward_executor_rows, + 
forward_handler_rows, + native_strategy_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + transition_primitive_callable_rows, + forward_executor_binding_rows, + memory_liveness_rows, + at::empty({0, 3}, forward_executor_binding_rows.options()), + false, + false, + schema_version); + append_registered_transition_reverse_memory_stage_row( + memory_stage_rows, + memory_reference, + local_step, + kRegisteredBackwardMemoryStageTransitionGroupAfterForwardRecompute); + program_tensors = flat_bucket_registered_temporal_fused_reverse_transition_program_cuda( + program_tensors, + program_tensor_binding_rows, + primitive_rows, + reverse_executor_rows, + reverse_handler_rows, + native_strategy_rows, + native_callable_binding_schema_rows, + native_callable_output_rows, + transition_primitive_callable_rows, + reverse_executor_binding_rows, + memory_liveness_rows, + schema_version, + return_state_grads); + append_registered_transition_reverse_memory_stage_row( + memory_stage_rows, + memory_reference, + local_step, + kRegisteredBackwardMemoryStageTransitionGroupAfterReversePrimitive); + return program_tensors; +} + +inline std::vector tensor_shape_vector(const at::Tensor& tensor) { + std::vector shape; + shape.reserve(static_cast(tensor.dim())); + for (int64_t dim = 0; dim < tensor.dim(); ++dim) { + shape.push_back(tensor.size(dim)); + } + return shape; +} + +inline bool tensors_have_same_shape(const at::Tensor& lhs, const at::Tensor& rhs) { + if (lhs.dim() != rhs.dim()) { + return false; + } + for (int64_t dim = 0; dim < lhs.dim(); ++dim) { + if (lhs.size(dim) != rhs.size(dim)) { + return false; + } + } + return true; +} + +inline bool can_transpose_last_two_dims_to_target(const at::Tensor& grad, const at::Tensor& target) { + if (grad.dim() != target.dim() || grad.dim() < 2) { + return false; + } + const int64_t dims = grad.dim(); + for (int64_t dim = 0; dim < dims - 2; ++dim) { + if (grad.size(dim) != target.size(dim)) { + return false; + } + } + return grad.size(dims - 2) == target.size(dims - 1) && + grad.size(dims - 1) == target.size(dims - 2); +} + +at::Tensor align_registered_transition_grad_to_target( + at::Tensor grad, + const at::Tensor& target, + bool reduce_leading_dims) { + TORCH_CHECK( + grad.is_cuda() && grad.is_contiguous() && grad.scalar_type() == at::kFloat, + "transition reducer source grad must be contiguous CUDA float32"); + TORCH_CHECK( + target.is_cuda() && target.is_contiguous() && target.scalar_type() == at::kFloat, + "transition reducer target parameter must be contiguous CUDA float32"); + if (tensors_have_same_shape(grad, target)) { + return grad.contiguous(); + } + if (can_transpose_last_two_dims_to_target(grad, target)) { + return grad.transpose(grad.dim() - 2, grad.dim() - 1).contiguous(); + } + if (grad.numel() == target.numel()) { + return grad.reshape(tensor_shape_vector(target)).contiguous(); + } + at::Tensor reduced = grad; + if (reduce_leading_dims) { + while (reduced.dim() > target.dim()) { + reduced = at::sum(reduced, {0}); + } + } + if (tensors_have_same_shape(reduced, target)) { + return reduced.contiguous(); + } + if (can_transpose_last_two_dims_to_target(reduced, target)) { + return reduced.transpose(reduced.dim() - 2, reduced.dim() - 1).contiguous(); + } + if (reduced.dim() == target.dim()) { + for (int64_t dim = 0; dim < reduced.dim(); ++dim) { + if (reduced.size(dim) == target.size(dim)) { + continue; + } + TORCH_CHECK( + target.size(dim) == 1, + "transition reducer cannot align source grad dimension ", + dim, + ": source=", + 
reduced.size(dim), + "; target=", + target.size(dim)); + reduced = at::sum(reduced, {dim}, true); + } + } + if (tensors_have_same_shape(reduced, target)) { + return reduced.contiguous(); + } + if (can_transpose_last_two_dims_to_target(reduced, target)) { + return reduced.transpose(reduced.dim() - 2, reduced.dim() - 1).contiguous(); + } + if (reduced.numel() == target.numel()) { + return reduced.reshape(tensor_shape_vector(target)).contiguous(); + } + TORCH_CHECK( + false, + "transition reducer could not align source gradient to trainable parameter shape"); +} + +at::Tensor reduce_registered_transition_source_tensors( + const std::vector& tensors, + int64_t tensor_start, + int64_t tensor_count) { + TORCH_CHECK(tensor_count > 0, "transition reducer source rows must reference at least one tensor"); + TORCH_CHECK(tensor_start >= 0, "transition reducer source tensor_start must be non-negative"); + TORCH_CHECK( + tensor_start + tensor_count <= static_cast(tensors.size()), + "transition reducer source row references tensors outside the source tensor table"); + const at::Tensor& first = tensors[static_cast(tensor_start)]; + TORCH_CHECK( + first.is_cuda() && first.is_contiguous() && first.scalar_type() == at::kFloat, + "transition reducer source tensors must be contiguous CUDA float32"); + if (tensor_count == 1) { + return first.contiguous(); + } + at::Tensor reduced = at::zeros_like(first); + for (int64_t offset = 0; offset < tensor_count; ++offset) { + const at::Tensor& tensor = tensors[static_cast(tensor_start + offset)]; + TORCH_CHECK( + tensor.is_cuda() && tensor.is_contiguous() && tensor.scalar_type() == at::kFloat, + "transition reducer source tensors must be contiguous CUDA float32"); + TORCH_CHECK(tensor.sizes() == first.sizes(), "transition reducer source tensor shapes must match"); + reduced.add_(tensor); + } + return reduced; +} + +int64_t find_registered_transition_source_row( + const at::Tensor& transition_source_rows, + int64_t request_index, + int64_t source_name_index) { + const int64_t* source_rows = transition_source_rows.data_ptr(); + for (int64_t row = 0; row < transition_source_rows.size(0); ++row) { + const int64_t* item = source_rows + row * 8; + if (item[1] == request_index && item[2] == source_name_index) { + return row; + } + } + return -1; +} + +at::Tensor transition_recurrent_bias_full_grad( + const at::Tensor& source_grad, + const at::Tensor& recurrent_cell_idx, + int64_t coord_count) { + at::Tensor source_rows = source_grad.dim() == 3 ? 
source_grad.squeeze(0) : source_grad; + check_cuda_float_rank2(source_rows, "transition recurrent bias source grad"); + check_cuda_long_rank1(recurrent_cell_idx, "transition recurrent bias recurrent_cell_idx"); + TORCH_CHECK( + recurrent_cell_idx.size(0) == source_rows.size(0), + "transition recurrent bias index length must match source rows"); + at::Tensor full = at::zeros({coord_count, source_rows.size(1)}, source_rows.options()); + full.index_add_(0, recurrent_cell_idx, source_rows); + return full; +} diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/flat_buckets.py b/src/cortical/fabric/backend/cuda/sequence_surface/flat_buckets.py deleted file mode 100644 index c90110b0..00000000 --- a/src/cortical/fabric/backend/cuda/sequence_surface/flat_buckets.py +++ /dev/null @@ -1,1474 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping -from dataclasses import dataclass -from typing import Any, cast - -import torch -from tensordict import TensorDict, TensorDictBase - -from cortical.fabric.backend.cuda import transition_execution -from cortical.fabric.backend.cuda.sequence_surface.temporal_buckets import ( - BACKEND_STATE_CACHE_SPEC_KEY as _BACKEND_STATE_CACHE_SPEC_KEY, -) -from cortical.fabric.backend.cuda.sequence_surface.temporal_buckets import ( - PHYSICAL_TRANSITION_BACKWARD_EXECUTOR as _PHYSICAL_TRANSITION_BACKWARD_EXECUTOR, -) -from cortical.fabric.backend.cuda.sequence_surface.temporal_buckets import ( - POPULATION_STATIC_TENSORS_KEY as _POPULATION_STATIC_TENSORS_KEY, -) -from cortical.fabric.backend.cuda.sequence_surface.temporal_buckets import ( - BackendStateCacheSpec as _BackendStateCacheSpec, -) -from cortical.fabric.backend.cuda.sequence_surface.temporal_buckets import ( - PopulationStateSpec as _PopulationStateSpec, -) -from cortical.fabric.backend.cuda.sequence_surface.temporal_buckets import ( - active_population_names as _active_population_names, -) -from cortical.fabric.backend.cuda.sequence_surface.temporal_buckets import ( - backend_order_population_buckets as _backend_order_population_buckets, -) -from cortical.fabric.backend.cuda.sequence_surface.temporal_buckets import ( - flat_bucket_trainable_items as _flat_bucket_trainable_items, -) -from cortical.fabric.backend.cuda.sequence_surface.temporal_buckets import ( - population_static_tensors_for as _population_static_tensors_for, -) -from cortical.fabric.runtime.state import ( - flatten_backend_packed_state as _flatten_backend_packed_state, -) -from cortical.fabric.runtime.state import ( - unflatten_backend_packed_state as _unflatten_backend_packed_state, -) - - -@dataclass(frozen=True) -class BackendOrderTransitionPopulationParamGrads: - materialized_param_grads: dict[str, torch.Tensor] - static_source_grads: dict[str, torch.Tensor] - - -@dataclass(frozen=True) -class BackendOrderTransitionParamGrads: - by_population: dict[str, BackendOrderTransitionPopulationParamGrads] - - -@dataclass(frozen=True) -class BackendOrderTransitionStepResult: - recurrent_hidden: torch.Tensor - next_state: TensorDict - backward_tape_by_population: dict[str, object] - - -def _backend_grad_tree_has_tensor(value: object) -> bool: - if torch.is_tensor(value): - return True - if value is None: - return False - if isinstance(value, TensorDictBase): - return any(_backend_grad_tree_has_tensor(item) for item in value.values()) - if isinstance(value, Mapping): - return any(_backend_grad_tree_has_tensor(item) for item in value.values()) - return False - - -def run_transition_buckets_step( - runtime: Any, - 
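Just before the deleted Python surface below, the added C++ code closes with `transition_recurrent_bias_full_grad`, which scatters per-recurrent-cell bias gradients back into a full `(coord_count, hidden)` table. A small sketch of the same scatter with `index_add_` (names and shapes are illustrative; the real helper also validates CUDA float32 rows and a long index):

```python
import torch

def recurrent_bias_full_grad(source_grad: torch.Tensor,
                             recurrent_cell_idx: torch.Tensor,
                             coord_count: int) -> torch.Tensor:
    """Scatter per-recurrent-cell bias grads into a (coord_count, hidden) table."""
    # Accept an optional leading singleton batch dim, as the C++ helper does.
    rows = source_grad.squeeze(0) if source_grad.dim() == 3 else source_grad
    assert rows.dim() == 2 and recurrent_cell_idx.dim() == 1
    assert recurrent_cell_idx.numel() == rows.size(0)
    full = torch.zeros(coord_count, rows.size(1), dtype=rows.dtype)
    full.index_add_(0, recurrent_cell_idx, rows)   # duplicate indices accumulate
    return full

grads = torch.ones(3, 4)                  # gradients for 3 recurrent cells
idx = torch.tensor([5, 0, 5])             # cell 5 contributes twice -> summed
print(recurrent_bias_full_grad(grads, idx, coord_count=8)[5])  # tensor([2., 2., 2., 2.])
```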
recurrent_msg: torch.Tensor, - state: TensorDict, - *, - resets: torch.Tensor | None, - batch_size: int, - static_tensors: dict[str, object], - step_population_state_cache: dict[str, object] | None = None, - materialize_next_state: bool = True, -) -> tuple[torch.Tensor, TensorDict]: - num_recurrent = int(runtime.recurrent_cell_idx.numel()) - if num_recurrent == 0: - return recurrent_msg.new_empty(batch_size, 0, runtime.hidden_size), TensorDict({}, batch_size=[]) - recurrent_next = recurrent_msg.new_empty(batch_size, num_recurrent, runtime.hidden_size) - next_state = TensorDict({}, batch_size=[]) - active_populations = _active_population_names(runtime) - for name in active_populations: - recurrent_idx = runtime._population_recurrent_indices(name) - population_y, population_state = run_transition_bucket_step( - runtime, - name, - recurrent_msg, - state.get(name), - resets=resets, - static_tensors=static_tensors, - step_population_state_cache=step_population_state_cache, - materialize_next_state=materialize_next_state, - ) - recurrent_next[:, recurrent_idx, :] = population_y.to(dtype=recurrent_next.dtype) - next_state[name] = population_state - return recurrent_next, next_state - - -def run_backend_order_transition_buckets_step( - runtime: Any, - recurrent_msg: torch.Tensor, - population_state: TensorDict, - *, - resets: torch.Tensor | None, - batch_size: int, - static_tensors: dict[str, object], - step_population_state_cache: dict[str, object] | None = None, - materialize_next_state: bool = True, -) -> tuple[torch.Tensor, TensorDict]: - if ( - step_population_state_cache is not None - and materialize_next_state - and torch.is_grad_enabled() - and recurrent_msg.requires_grad - ): - return _run_backend_order_transition_buckets_step_cached_physical_autograd( - runtime, - recurrent_msg, - step_population_state_cache, - resets=resets, - batch_size=batch_size, - static_tensors=static_tensors, - ) - if ( - step_population_state_cache is None - and materialize_next_state - and torch.is_grad_enabled() - and recurrent_msg.requires_grad - ): - return _run_backend_order_transition_buckets_step_physical_autograd( - runtime, - recurrent_msg, - population_state, - resets=resets, - batch_size=batch_size, - static_tensors=static_tensors, - ) - if step_population_state_cache is not None: - return _run_backend_order_transition_buckets_step_cached_eager( - runtime, - recurrent_msg, - step_population_state_cache, - resets=resets, - batch_size=batch_size, - static_tensors=static_tensors, - materialize_next_state=materialize_next_state, - ) - return _run_backend_order_transition_buckets_step_eager( - runtime, - recurrent_msg, - population_state, - resets=resets, - batch_size=batch_size, - static_tensors=static_tensors, - step_population_state_cache=step_population_state_cache, - materialize_next_state=materialize_next_state, - ) - - -def _run_backend_order_transition_buckets_step_eager( - runtime: Any, - recurrent_msg: torch.Tensor, - population_state: TensorDict, - *, - resets: torch.Tensor | None, - batch_size: int, - static_tensors: dict[str, object], - step_population_state_cache: dict[str, object] | None = None, - materialize_next_state: bool = True, -) -> tuple[torch.Tensor, TensorDict]: - num_recurrent = int(runtime.recurrent_cell_idx.numel()) - if num_recurrent == 0: - return recurrent_msg.new_empty(batch_size, 0, runtime.hidden_size), TensorDict({}, batch_size=[]) - recurrent_next = recurrent_msg.new_empty(batch_size, num_recurrent, runtime.hidden_size) - next_state = TensorDict({}, batch_size=[]) - for 
name in _active_population_names(runtime): - start, stop = runtime._population_backend_recurrent_slice(name) - population_msg = recurrent_msg[:, start:stop, :] - population_y, population_next = run_transition_bucket_step( - runtime, - name, - population_msg, - population_state.get(name), - resets=resets, - static_tensors=static_tensors, - step_population_state_cache=step_population_state_cache, - materialize_next_state=materialize_next_state, - message_already_population_ordered=True, - ) - recurrent_next[:, start:stop, :] = population_y.to(dtype=recurrent_next.dtype) - next_state[name] = population_next - return recurrent_next, next_state - - -def _run_backend_order_transition_buckets_step_cached_eager( - runtime: Any, - recurrent_msg: torch.Tensor, - step_population_state_cache: dict[str, object], - *, - resets: torch.Tensor | None, - batch_size: int, - static_tensors: dict[str, object], - materialize_next_state: bool = True, -) -> tuple[torch.Tensor, TensorDict]: - result = run_backend_order_transition_buckets_step_cached_eager_result( - runtime, - recurrent_msg, - step_population_state_cache, - resets=resets, - batch_size=batch_size, - static_tensors=static_tensors, - materialize_next_state=materialize_next_state, - transition_tape_mode="disabled", - ) - return result.recurrent_hidden, result.next_state - - -def run_backend_order_transition_buckets_step_cached_eager_result( - runtime: Any, - recurrent_msg: torch.Tensor, - step_population_state_cache: dict[str, object], - *, - resets: torch.Tensor | None, - batch_size: int, - static_tensors: dict[str, object], - materialize_next_state: bool = True, - transition_tape_mode: str = "disabled", -) -> BackendOrderTransitionStepResult: - if transition_tape_mode not in {"disabled", "input_projection", "full"}: - raise RuntimeError(f"Unsupported flat-bucket transition tape mode {transition_tape_mode!r}") - num_recurrent = int(runtime.recurrent_cell_idx.numel()) - if num_recurrent == 0: - return BackendOrderTransitionStepResult( - recurrent_hidden=recurrent_msg.new_empty(batch_size, 0, runtime.hidden_size), - next_state=TensorDict({}, batch_size=[]), - backward_tape_by_population={}, - ) - recurrent_next = recurrent_msg.new_empty(batch_size, num_recurrent, runtime.hidden_size) - next_state = TensorDict({}, batch_size=[]) - backward_tape_by_population: dict[str, object] = {} - for bucket in _backend_order_population_buckets(runtime, static_tensors): - result = transition_execution.lower_backend_population_transition_forward_result_shared( - runtime, - population_name=bucket.name, - recurrent_msg=recurrent_msg[:, bucket.backend_start : bucket.backend_stop, :], - packed_state_before=step_population_state_cache.get(bucket.name), - population_reset_step=resets, - static_tensors=bucket.static_tensors, - materialize_recurrent_kv=False, - materialize_backward_tape=transition_tape_mode != "disabled", - materialize_next_state=materialize_next_state, - materialize_diagonal_preproj_tape=transition_tape_mode == "full", - materialize_recurrence_backward_tape=transition_tape_mode == "full", - materialize_trace_state_next=materialize_next_state, - ) - recurrent_next[:, bucket.backend_start : bucket.backend_stop, :] = result.recurrent_hidden.to( - dtype=recurrent_next.dtype - ) - if result.next_packed_state is not None: - step_population_state_cache[bucket.name] = result.next_packed_state - next_state[bucket.name] = TensorDict({}, batch_size=[]) - if result.backward_tape is not None: - backward_tape_by_population[bucket.name] = result.backward_tape - return 
BackendOrderTransitionStepResult( - recurrent_hidden=recurrent_next, - next_state=next_state, - backward_tape_by_population=backward_tape_by_population, - ) - - -def _run_backend_order_transition_buckets_step_physical_autograd( - runtime: Any, - recurrent_msg: torch.Tensor, - population_state: TensorDict, - *, - resets: torch.Tensor | None, - batch_size: int, - static_tensors: dict[str, object], -) -> tuple[torch.Tensor, TensorDict]: - del batch_size - state_specs, state_tensors = _flatten_population_state_inputs(runtime, population_state) - trainable_items = _flat_bucket_trainable_items(runtime, static_tensors) - outputs = _BackendOrderTransitionBucketsStepFunction.apply( - runtime, - static_tensors, - state_specs, - tuple(name for name, _param in trainable_items), - recurrent_msg, - resets, - *state_tensors, - *(param for _name, param in trainable_items), - ) - recurrent_next = cast(torch.Tensor, outputs[0]) - next_state = _unflatten_population_state_outputs( - state_specs, - cast(tuple[torch.Tensor, ...], tuple(outputs[1:])), - ) - runtime._last_flat_bucket_transition_backward_executor = _PHYSICAL_TRANSITION_BACKWARD_EXECUTOR - return recurrent_next, next_state - - -def _run_backend_order_transition_buckets_step_cached_physical_autograd( - runtime: Any, - recurrent_msg: torch.Tensor, - step_population_state_cache: dict[str, object], - *, - resets: torch.Tensor | None, - batch_size: int, - static_tensors: dict[str, object], -) -> tuple[torch.Tensor, TensorDict]: - del batch_size - cache_specs, cache_tensors = _flatten_backend_state_cache_inputs( - runtime, - step_population_state_cache, - static_tensors, - ) - trainable_items = _flat_bucket_trainable_items(runtime, static_tensors) - outputs = _BackendOrderCachedTransitionBucketsStepFunction.apply( - runtime, - static_tensors, - cache_specs, - tuple(name for name, _param in trainable_items), - recurrent_msg, - resets, - *cache_tensors, - *(param for _name, param in trainable_items), - ) - recurrent_next = cast(torch.Tensor, outputs[0]) - next_cache = _unflatten_backend_state_cache_outputs( - cache_specs, - cast(tuple[torch.Tensor, ...], tuple(outputs[1:])), - ) - step_population_state_cache.update(next_cache) - runtime._last_flat_bucket_transition_backward_executor = _PHYSICAL_TRANSITION_BACKWARD_EXECUTOR - return recurrent_next, TensorDict( - {name: TensorDict({}, batch_size=[]) for name, _keys in cache_specs}, - batch_size=[], - ) - - -def _flatten_population_state_inputs( - runtime: Any, - population_state: TensorDict, -) -> tuple[_PopulationStateSpec, tuple[torch.Tensor, ...]]: - specs: list[tuple[str, tuple[str, ...]]] = [] - tensors: list[torch.Tensor] = [] - for population_name in _active_population_names(runtime): - state_value = population_state.get(population_name) - if not isinstance(state_value, TensorDictBase): - raise RuntimeError(f"CUDA flat-bucket physical transition requires TensorDict state for {population_name}") - keys = tuple(runtime._cell_spec_for_population(population_name).state_schema.keys) - specs.append((population_name, keys)) - for key in keys: - tensor = state_value[key] - if not torch.is_tensor(tensor): - raise RuntimeError( - f"CUDA flat-bucket physical transition state {population_name}.{key} is not a tensor" - ) - tensors.append(tensor) - return tuple(specs), tuple(tensors) - - -def _unflatten_population_state_outputs( - specs: _PopulationStateSpec, - tensors: tuple[torch.Tensor, ...], -) -> TensorDict: - state = TensorDict({}, batch_size=[]) - offset = 0 - for population_name, keys in specs: - leaves: 
dict[str, torch.Tensor] = {} - first: torch.Tensor | None = None - for key in keys: - tensor = tensors[offset] - offset += 1 - leaves[key] = tensor - if first is None: - first = tensor - if first is None: - state[population_name] = TensorDict({}, batch_size=[]) - else: - state[population_name] = TensorDict( - leaves, - batch_size=list(first.shape[:2]), - device=first.device, - ) - return state - - -def _unflatten_population_state_grads( - specs: _PopulationStateSpec, - tensors: tuple[torch.Tensor | None, ...], -) -> dict[str, dict[str, torch.Tensor | None]]: - state: dict[str, dict[str, torch.Tensor | None]] = {} - offset = 0 - for population_name, keys in specs: - leaves: dict[str, torch.Tensor | None] = {} - for key in keys: - leaves[key] = tensors[offset] - offset += 1 - state[population_name] = leaves - return state - - -def _population_grad_state_to_backend_grad_state( - runtime: Any, - population_name: str, - population_grad_state: Mapping[str, torch.Tensor | None] | None, -) -> dict[str, torch.Tensor | None] | None: - if population_grad_state is None: - return None - backend_grads: dict[str, torch.Tensor | None] = {} - for state_name in runtime._cell_spec_for_population(population_name).state_schema.keys: - grad = population_grad_state.get(state_name) - backend_grads[state_name] = grad.permute(1, 0, 2).contiguous() if torch.is_tensor(grad) else None - return backend_grads - - -def _flatten_population_state_grad_outputs( - specs: _PopulationStateSpec, - grad_state: TensorDict, -) -> tuple[torch.Tensor | None, ...]: - grads: list[torch.Tensor | None] = [] - for population_name, keys in specs: - population_grad = grad_state.get(population_name) - for key in keys: - if isinstance(population_grad, TensorDictBase): - grad = population_grad.get(key) - grads.append(grad if torch.is_tensor(grad) else None) - else: - grads.append(None) - return tuple(grads) - - -def _flatten_backend_state_grad_outputs( - packed_state_keys: tuple[str, ...] | None, - grad_state: Mapping[str, torch.Tensor | None] | torch.Tensor | None, -) -> tuple[torch.Tensor | None, ...]: - if packed_state_keys is None: - return (cast(torch.Tensor | None, grad_state) if torch.is_tensor(grad_state) else None,) - if not isinstance(grad_state, Mapping): - return tuple(None for _key in packed_state_keys) - return tuple(cast(torch.Tensor | None, grad_state.get(key)) for key in packed_state_keys) - - -def _backend_state_cache_specs( - runtime: Any, - step_population_state_cache: Mapping[str, object], - static_tensors: dict[str, object], -) -> _BackendStateCacheSpec: - cached = static_tensors.get(_BACKEND_STATE_CACHE_SPEC_KEY) - if isinstance(cached, tuple): - return cast(_BackendStateCacheSpec, cached) - specs: list[tuple[str, tuple[str, ...] 
| None]] = [] - for population_name in _active_population_names(runtime): - packed_state = step_population_state_cache.get(population_name) - if packed_state is None: - raise RuntimeError("CUDA physical transition backward requires materialized backend state cache") - keys, _state_tensors = _flatten_backend_packed_state(packed_state) - specs.append((population_name, keys)) - cache_specs = tuple(specs) - static_tensors[_BACKEND_STATE_CACHE_SPEC_KEY] = cache_specs - return cache_specs - - -def _backend_state_cache_tensors( - specs: _BackendStateCacheSpec, - step_population_state_cache: Mapping[str, object], -) -> tuple[torch.Tensor, ...]: - tensors: list[torch.Tensor] = [] - for population_name, expected_keys in specs: - packed_state = step_population_state_cache.get(population_name) - if packed_state is None: - raise RuntimeError("CUDA physical transition backward requires materialized backend state cache") - keys, state_tensors = _flatten_backend_packed_state(packed_state) - if keys != expected_keys: - raise RuntimeError("CUDA physical transition backward requires stable backend cache keys") - tensors.extend(state_tensors) - return tuple(tensors) - - -def _flatten_backend_state_cache_inputs( - runtime: Any, - step_population_state_cache: Mapping[str, object], - static_tensors: dict[str, object], -) -> tuple[_BackendStateCacheSpec, tuple[torch.Tensor, ...]]: - specs = _backend_state_cache_specs(runtime, step_population_state_cache, static_tensors) - return specs, _backend_state_cache_tensors(specs, step_population_state_cache) - - -def _unflatten_backend_state_cache_outputs( - specs: _BackendStateCacheSpec, - tensors: tuple[torch.Tensor, ...], -) -> dict[str, object]: - cache: dict[str, object] = {} - offset = 0 - for population_name, keys in specs: - state_tensor_count = 1 if keys is None else len(keys) - state_tensors = tuple(tensors[offset : offset + state_tensor_count]) - offset += state_tensor_count - cache[population_name] = _unflatten_backend_packed_state(keys, state_tensors) - return cache - - -def _unflatten_backend_state_cache_grads( - specs: _BackendStateCacheSpec, - tensors: tuple[torch.Tensor | None, ...], -) -> dict[str, object]: - cache: dict[str, object] = {} - offset = 0 - for population_name, keys in specs: - state_tensor_count = 1 if keys is None else len(keys) - state_tensors = tuple(tensors[offset : offset + state_tensor_count]) - offset += state_tensor_count - cache[population_name] = _unflatten_backend_packed_state(keys, state_tensors) - return cache - - -class _BackendOrderTransitionBucketsStepFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx: Any, - runtime: Any, - static_tensors: dict[str, object], - state_specs: _PopulationStateSpec, - trainable_param_names: tuple[str, ...], - recurrent_msg: torch.Tensor, - resets: torch.Tensor | None, - *state_tensors_and_params: torch.Tensor, - ) -> tuple[torch.Tensor, ...]: - state_tensor_count = sum(len(keys) for _population_name, keys in state_specs) - state_tensors = tuple(state_tensors_and_params[:state_tensor_count]) - trainable_params = tuple(state_tensors_and_params[state_tensor_count:]) - ctx.runtime = runtime - ctx.static_tensors = static_tensors - ctx.state_specs = state_specs - ctx.trainable_param_names = trainable_param_names - ctx.state_tensor_count = state_tensor_count - ctx.trainable_param_count = len(trainable_params) - ctx.has_resets = torch.is_tensor(resets) - tensors_to_save = ( - recurrent_msg, - *((resets,) if torch.is_tensor(resets) else ()), - *state_tensors, - *trainable_params, - ) 
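The deleted flat-bucket surface feeds nested per-population state through `torch.autograd.Function` by flattening it into `(population, keys)` specs plus a flat tensor tuple, and rebuilding the nested structure on the other side. A minimal round-trip sketch of that pattern using plain dicts instead of `TensorDict` (hypothetical names; the real code also carries batch sizes and devices):

```python
import torch

StateSpecs = tuple[tuple[str, tuple[str, ...]], ...]

def flatten_state(state: dict[str, dict[str, torch.Tensor]]) -> tuple[StateSpecs, tuple[torch.Tensor, ...]]:
    specs, tensors = [], []
    for population, leaves in state.items():
        keys = tuple(leaves.keys())
        specs.append((population, keys))
        tensors.extend(leaves[key] for key in keys)
    return tuple(specs), tuple(tensors)

def unflatten_state(specs: StateSpecs, tensors: tuple[torch.Tensor, ...]) -> dict[str, dict[str, torch.Tensor]]:
    state, offset = {}, 0
    for population, keys in specs:
        state[population] = {key: tensors[offset + i] for i, key in enumerate(keys)}
        offset += len(keys)
    return state

state = {
    "slstm": {"c": torch.zeros(2, 3), "h": torch.zeros(2, 3)},
    "axon": {"trace": torch.ones(2, 3)},
}
specs, flat = flatten_state(state)
assert len(flat) == 3                      # positional tensors for Function.apply(*flat)
rebuilt = unflatten_state(specs, flat)
assert rebuilt["axon"]["trace"].equal(state["axon"]["trace"])
```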
- ctx.save_for_backward(*tensors_to_save) - population_state = _unflatten_population_state_outputs(state_specs, state_tensors) - with torch.no_grad(): - recurrent_next, next_state = _run_backend_order_transition_buckets_step_eager( - runtime, - recurrent_msg, - population_state, - resets=resets, - batch_size=int(recurrent_msg.shape[0]), - static_tensors=static_tensors, - step_population_state_cache=None, - materialize_next_state=True, - ) - state_outputs = _flatten_population_state_grad_outputs(state_specs, next_state) - if any(output is None for output in state_outputs): - raise RuntimeError("CUDA flat-bucket physical transition forward produced incomplete population state") - return (recurrent_next, *cast(tuple[torch.Tensor, ...], state_outputs)) - - @staticmethod - def backward( - ctx: Any, - *grad_outputs: torch.Tensor | None, - ) -> tuple[object, ...]: - saved = ctx.saved_tensors - offset = 0 - recurrent_msg = saved[offset] - offset += 1 - if ctx.has_resets: - resets = saved[offset] - offset += 1 - else: - resets = None - state_tensor_count = int(ctx.state_tensor_count) - state_tensors = tuple(saved[offset : offset + state_tensor_count]) - offset += state_tensor_count - trainable_params = tuple(saved[offset:]) - grad_recurrent_hidden = grad_outputs[0] - grad_next_state = _unflatten_population_state_grads( - ctx.state_specs, - cast(tuple[torch.Tensor | None, ...], tuple(grad_outputs[1:])), - ) - population_state_before = _unflatten_population_state_outputs(ctx.state_specs, state_tensors) - grad_recurrent_msg, grad_state_before, param_grads = run_backend_order_transition_buckets_backward_step( - ctx.runtime, - recurrent_msg, - population_state_before, - grad_recurrent_hidden=grad_recurrent_hidden, - grad_next_population_state=grad_next_state, - resets=resets, - static_tensors=ctx.static_tensors, - trainable_params=cast(tuple[torch.Tensor, ...], trainable_params), - trainable_param_names=ctx.trainable_param_names, - need_grad_state_before=True, - ) - state_grads = _flatten_population_state_grad_outputs(ctx.state_specs, grad_state_before) - return ( - None, - None, - None, - None, - grad_recurrent_msg, - None, - *state_grads, - *param_grads, - ) - - -class _BackendOrderCachedTransitionBucketsStepFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx: Any, - runtime: Any, - static_tensors: dict[str, object], - cache_specs: _BackendStateCacheSpec, - trainable_param_names: tuple[str, ...], - recurrent_msg: torch.Tensor, - resets: torch.Tensor | None, - *state_tensors_and_params: torch.Tensor, - ) -> tuple[torch.Tensor, ...]: - state_tensor_count = sum(1 if keys is None else len(keys) for _population_name, keys in cache_specs) - state_tensors = tuple(state_tensors_and_params[:state_tensor_count]) - trainable_params = tuple(state_tensors_and_params[state_tensor_count:]) - ctx.runtime = runtime - ctx.static_tensors = static_tensors - ctx.cache_specs = cache_specs - ctx.trainable_param_names = trainable_param_names - ctx.trainable_param_shapes = tuple(tuple(param.shape) for param in trainable_params) - ctx.state_tensor_count = state_tensor_count - ctx.has_resets = torch.is_tensor(resets) - ctx.save_for_backward( - recurrent_msg, - *((resets,) if torch.is_tensor(resets) else ()), - *state_tensors, - ) - step_population_state_cache = _unflatten_backend_state_cache_outputs(cache_specs, state_tensors) - with torch.no_grad(): - recurrent_next, _next_state = _run_backend_order_transition_buckets_step_cached_eager( - runtime, - recurrent_msg, - step_population_state_cache, - 
resets=resets, - batch_size=int(recurrent_msg.shape[0]), - static_tensors=static_tensors, - materialize_next_state=True, - ) - next_state_tensors: list[torch.Tensor] = [] - for population_name, keys in cache_specs: - next_keys, tensors = _flatten_backend_packed_state(step_population_state_cache[population_name]) - if next_keys != keys: - raise RuntimeError("CUDA cached transition backward requires stable backend cache keys") - next_state_tensors.extend(tensors) - return (recurrent_next, *next_state_tensors) - - @staticmethod - def backward( - ctx: Any, - *grad_outputs: torch.Tensor | None, - ) -> tuple[object, ...]: - saved = ctx.saved_tensors - offset = 0 - recurrent_msg = saved[offset] - offset += 1 - if ctx.has_resets: - resets = saved[offset] - offset += 1 - else: - resets = None - state_tensor_count = int(ctx.state_tensor_count) - state_tensors = tuple(saved[offset : offset + state_tensor_count]) - offset += state_tensor_count - grad_recurrent_hidden = grad_outputs[0] - state_cache_before = _unflatten_backend_state_cache_outputs(ctx.cache_specs, state_tensors) - grad_next_state_cache = _unflatten_backend_state_cache_grads( - ctx.cache_specs, - cast(tuple[torch.Tensor | None, ...], tuple(grad_outputs[1:])), - ) - grad_recurrent_msg = torch.empty_like(recurrent_msg) - param_grad_accum: list[torch.Tensor | None] = [None] * len(ctx.trainable_param_names) - grad_state_tensors: list[torch.Tensor | None] = [] - buckets = _backend_order_population_buckets(ctx.runtime, ctx.static_tensors) - if len(buckets) != len(ctx.cache_specs): - raise RuntimeError("CUDA cached transition backward requires stable backend-order population buckets") - for (population_name, keys), bucket in zip( - ctx.cache_specs, - buckets, - strict=True, - ): - if bucket.name != population_name: - raise RuntimeError("CUDA cached transition backward bucket order does not match cache spec order") - population_grad_msg, population_grad_state_before, population_param_grads = ( - ctx.runtime._run_backend_state_public_backward_phase( - population_name=population_name, - recurrent_msg=recurrent_msg[:, bucket.backend_start : bucket.backend_stop, :], - recurrent_hidden_tape=None, - packed_state_before=state_cache_before[population_name], - population_reset_step=resets, - grad_next_packed_state=grad_next_state_cache[population_name], - grad_recurrent_hidden=None - if grad_recurrent_hidden is None - else grad_recurrent_hidden[:, bucket.backend_start : bucket.backend_stop, :], - grad_recurrent_k=None, - grad_recurrent_v=None, - trainable_params=(), - trainable_param_names=ctx.trainable_param_names, - trainable_param_shapes=ctx.trainable_param_shapes, - sequence_static_tensors=bucket.static_tensors, - param_static_tensors=ctx.static_tensors, - transition_backward_tape=None, - need_grad_packed_state_before=True, - device=recurrent_msg.device, - dtype=recurrent_msg.dtype, - active_receiver_window=None, - ) - ) - if population_grad_msg is not None: - grad_recurrent_msg[:, bucket.backend_start : bucket.backend_stop, :] = population_grad_msg.to( - dtype=grad_recurrent_msg.dtype - ) - else: - grad_recurrent_msg[:, bucket.backend_start : bucket.backend_stop, :].zero_() - grad_state_tensors.extend(_flatten_backend_state_grad_outputs(keys, population_grad_state_before)) - for index, grad_param in enumerate(population_param_grads): - if grad_param is None: - continue - param_grad_accum[index] = ( - grad_param if param_grad_accum[index] is None else param_grad_accum[index] + grad_param - ) - return ( - None, - None, - None, - None, - 
grad_recurrent_msg, - None, - *grad_state_tensors, - *param_grad_accum, - ) - - -class _CachedBackendTransitionBucketStepFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx: Any, - runtime: Any, - population_name: str, - static_tensors: dict[str, object], - packed_state_keys: tuple[str, ...] | None, - trainable_param_names: tuple[str, ...], - recurrent_msg: torch.Tensor, - resets: torch.Tensor | None, - *state_tensors_and_params: torch.Tensor, - ) -> tuple[torch.Tensor, ...]: - state_tensor_count = 1 if packed_state_keys is None else len(packed_state_keys) - state_tensors = tuple(state_tensors_and_params[:state_tensor_count]) - trainable_params = tuple(state_tensors_and_params[state_tensor_count:]) - ctx.runtime = runtime - ctx.population_name = population_name - ctx.static_tensors = static_tensors - ctx.packed_state_keys = packed_state_keys - ctx.trainable_param_names = trainable_param_names - ctx.state_tensor_count = state_tensor_count - ctx.has_resets = torch.is_tensor(resets) - ctx.save_for_backward( - recurrent_msg, - *((resets,) if torch.is_tensor(resets) else ()), - *state_tensors, - *trainable_params, - ) - packed_state_before = _unflatten_backend_packed_state(packed_state_keys, state_tensors) - with torch.no_grad(): - result = transition_execution.lower_backend_population_transition_forward_result_shared( - runtime, - population_name=population_name, - recurrent_msg=recurrent_msg, - packed_state_before=packed_state_before, - population_reset_step=resets, - static_tensors=static_tensors, - materialize_recurrent_kv=False, - materialize_backward_tape=False, - materialize_next_state=True, - materialize_trace_state_next=True, - ) - next_state_keys, next_state_tensors = _flatten_backend_packed_state(result.next_packed_state) - if next_state_keys != packed_state_keys: - raise RuntimeError("CUDA cached transition backward requires stable packed-state keys") - return (result.recurrent_hidden.to(dtype=recurrent_msg.dtype), *next_state_tensors) - - @staticmethod - def backward( - ctx: Any, - *grad_outputs: torch.Tensor | None, - ) -> tuple[object, ...]: - saved = ctx.saved_tensors - offset = 0 - recurrent_msg = saved[offset] - offset += 1 - if ctx.has_resets: - resets = saved[offset] - offset += 1 - else: - resets = None - state_tensor_count = int(ctx.state_tensor_count) - state_tensors = tuple(saved[offset : offset + state_tensor_count]) - offset += state_tensor_count - trainable_params = tuple(saved[offset:]) - grad_recurrent_hidden = grad_outputs[0] - grad_next_packed_state = _unflatten_backend_packed_state( - ctx.packed_state_keys, - cast(tuple[torch.Tensor | None, ...], tuple(grad_outputs[1:])), - ) - packed_state_before = _unflatten_backend_packed_state(ctx.packed_state_keys, state_tensors) - population_grad_msg, population_grad_state_before, population_param_grads = ( - ctx.runtime._run_backend_state_public_backward_phase( - population_name=ctx.population_name, - recurrent_msg=recurrent_msg, - recurrent_hidden_tape=None, - packed_state_before=packed_state_before, - population_reset_step=resets, - grad_next_packed_state=grad_next_packed_state, - grad_recurrent_hidden=grad_recurrent_hidden, - grad_recurrent_k=None, - grad_recurrent_v=None, - trainable_params=cast(tuple[torch.Tensor, ...], trainable_params), - trainable_param_names=ctx.trainable_param_names, - sequence_static_tensors=ctx.static_tensors, - param_static_tensors=ctx.static_tensors, - transition_backward_tape=None, - need_grad_packed_state_before=True, - device=recurrent_msg.device, - 
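The step Functions above all follow the same shape: forward saves leaf tensors, runs the eager step under `torch.no_grad()`, and backward re-derives gradients with a hand-written backward step instead of the autograd tape. A stripped-down sketch of that pattern for a single linear transition (everything here is illustrative; the real backward delegates to the registered transition backward phase):

```python
import torch

class LinearTransitionStep(torch.autograd.Function):
    """Toy physical-backward step: y = x @ w, gradients computed by hand."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(x, w)
        with torch.no_grad():              # forward runs eagerly, no autograd tape
            return x @ w

    @staticmethod
    def backward(ctx, grad_y: torch.Tensor):
        x, w = ctx.saved_tensors
        grad_x = grad_y @ w.t()            # hand-written input gradient
        grad_w = x.t() @ grad_y            # hand-written parameter gradient
        return grad_x, grad_w

x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, 2, requires_grad=True)
LinearTransitionStep.apply(x, w).sum().backward()
# Matches what autograd would have produced for y = x @ w.
assert torch.allclose(x.grad, torch.ones(4, 2) @ w.t())
```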
dtype=recurrent_msg.dtype, - active_receiver_window=None, - ) - ) - grad_state_tensors = _flatten_backend_state_grad_outputs(ctx.packed_state_keys, population_grad_state_before) - return ( - None, - None, - None, - None, - None, - population_grad_msg, - None, - *grad_state_tensors, - *population_param_grads, - ) - - -def run_backend_order_transition_buckets_backward_step( - runtime: Any, - recurrent_msg: torch.Tensor, - population_state_before: TensorDict, - *, - grad_recurrent_hidden: torch.Tensor | None, - grad_next_population_state: Mapping[str, Mapping[str, torch.Tensor | None]] | None = None, - resets: torch.Tensor | None, - static_tensors: dict[str, object], - trainable_params: tuple[torch.Tensor, ...], - trainable_param_names: tuple[str, ...], - need_grad_state_before: bool = True, -) -> tuple[torch.Tensor | None, TensorDict, tuple[torch.Tensor | None, ...]]: - """Run physical transition backward for backend-order flat population buckets.""" - has_next_population_grad = False - if grad_next_population_state is not None: - for population_grad in grad_next_population_state.values(): - if any(torch.is_tensor(grad) for grad in population_grad.values()): - has_next_population_grad = True - break - if grad_recurrent_hidden is None and not has_next_population_grad: - return None, TensorDict({}, batch_size=[]), tuple(None for _ in trainable_params) - num_recurrent = int(runtime.recurrent_cell_idx.numel()) - if num_recurrent == 0: - return None, TensorDict({}, batch_size=[]), tuple(None for _ in trainable_params) - grad_recurrent_msg = torch.empty_like(recurrent_msg) - grad_population_state = TensorDict({}, batch_size=[]) - param_grad_accum: list[torch.Tensor | None] = [None] * len(trainable_params) - buckets = _backend_order_population_buckets(runtime, static_tensors) - for bucket in buckets: - population_state = population_state_before.get(bucket.name) - if not isinstance(population_state, TensorDictBase): - raise RuntimeError(f"Flat-bucket transition backward requires TensorDict state for {bucket.name}") - packed_state_before = runtime._population_state_to_backend_state(bucket.name, population_state) - population_grad_next_state = ( - None - if grad_next_population_state is None - else _population_grad_state_to_backend_grad_state( - runtime, - bucket.name, - grad_next_population_state.get(bucket.name), - ) - ) - ( - population_grad_msg, - population_grad_state_before, - population_param_grads, - ) = runtime._run_backend_state_public_backward_phase( - population_name=bucket.name, - recurrent_msg=recurrent_msg[:, bucket.backend_start : bucket.backend_stop, :], - recurrent_hidden_tape=None, - packed_state_before=packed_state_before, - population_reset_step=resets, - grad_next_packed_state=population_grad_next_state, - grad_recurrent_hidden=None - if grad_recurrent_hidden is None - else grad_recurrent_hidden[:, bucket.backend_start : bucket.backend_stop, :], - grad_recurrent_k=None, - grad_recurrent_v=None, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - sequence_static_tensors=bucket.static_tensors, - param_static_tensors=static_tensors, - transition_backward_tape=None, - need_grad_packed_state_before=need_grad_state_before, - device=recurrent_msg.device, - dtype=recurrent_msg.dtype, - active_receiver_window=None, - ) - if population_grad_msg is not None: - grad_recurrent_msg[:, bucket.backend_start : bucket.backend_stop, :] = population_grad_msg.to( - dtype=grad_recurrent_msg.dtype - ) - else: - grad_recurrent_msg[:, bucket.backend_start : 
bucket.backend_stop, :].zero_() - if isinstance(population_grad_state_before, Mapping): - grad_population_state[bucket.name] = _partial_backend_grad_state_to_population_state( - population_grad_state_before - ) - else: - grad_population_state[bucket.name] = TensorDict({}, batch_size=[]) - for index, grad_param in enumerate(population_param_grads): - if grad_param is None: - continue - param_grad_accum[index] = ( - grad_param if param_grad_accum[index] is None else param_grad_accum[index] + grad_param - ) - return grad_recurrent_msg, grad_population_state, tuple(param_grad_accum) - - -def run_backend_order_transition_buckets_backward_step_cached( - runtime: Any, - recurrent_msg: torch.Tensor, - backend_state_cache_before: Mapping[str, object], - *, - grad_recurrent_hidden: torch.Tensor | None, - grad_next_backend_state_cache: Mapping[str, object] | None = None, - resets: torch.Tensor | None, - static_tensors: dict[str, object], - trainable_param_names: tuple[str, ...], - trainable_param_shapes: tuple[tuple[int, ...], ...], - need_grad_state_before: bool = True, - forward_tape_by_population: Mapping[str, object] | None = None, -) -> tuple[torch.Tensor | None, dict[str, object], tuple[torch.Tensor | None, ...]]: - """Run physical transition backward from backend-packed temporal tape.""" - has_next_state_grad = False - if grad_next_backend_state_cache is not None: - for grad_state in grad_next_backend_state_cache.values(): - if _backend_grad_tree_has_tensor(grad_state): - has_next_state_grad = True - break - if grad_recurrent_hidden is None and not has_next_state_grad: - return None, {}, tuple(None for _name in trainable_param_names) - num_recurrent = int(runtime.recurrent_cell_idx.numel()) - if num_recurrent == 0: - return None, {}, tuple(None for _name in trainable_param_names) - grad_recurrent_msg = torch.empty_like(recurrent_msg) - grad_state_cache_before: dict[str, object] = {} - param_grad_accum: list[torch.Tensor | None] = [None] * len(trainable_param_names) - buckets = _backend_order_population_buckets(runtime, static_tensors) - for bucket in buckets: - packed_state_before = backend_state_cache_before.get(bucket.name) - if packed_state_before is None: - raise RuntimeError(f"Flat-bucket temporal backward is missing backend state tape for {bucket.name}") - ( - population_grad_msg, - population_grad_state_before, - population_param_grads, - ) = runtime._run_backend_state_public_backward_phase( - population_name=bucket.name, - recurrent_msg=recurrent_msg[:, bucket.backend_start : bucket.backend_stop, :], - recurrent_hidden_tape=None, - packed_state_before=packed_state_before, - population_reset_step=resets, - grad_next_packed_state=None - if grad_next_backend_state_cache is None - else grad_next_backend_state_cache.get(bucket.name), - grad_recurrent_hidden=None - if grad_recurrent_hidden is None - else grad_recurrent_hidden[:, bucket.backend_start : bucket.backend_stop, :], - grad_recurrent_k=None, - grad_recurrent_v=None, - trainable_params=(), - trainable_param_names=trainable_param_names, - trainable_param_shapes=trainable_param_shapes, - sequence_static_tensors=bucket.static_tensors, - param_static_tensors=static_tensors, - transition_backward_tape=None - if forward_tape_by_population is None - else forward_tape_by_population.get(bucket.name), - need_grad_packed_state_before=need_grad_state_before, - device=recurrent_msg.device, - dtype=recurrent_msg.dtype, - active_receiver_window=None, - ) - if population_grad_msg is not None: - grad_recurrent_msg[:, bucket.backend_start : 
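The per-bucket backward loops above share one bookkeeping pattern: write each population's message gradient into its backend slice (zero-filling when a bucket returns `None`) and sum optional parameter gradients across buckets. A small sketch of that accumulation, with hypothetical bucket objects and a fake per-bucket backward:

```python
import torch
from dataclasses import dataclass

@dataclass
class Bucket:
    start: int
    stop: int

def accumulate_bucket_backward(grad_hidden, buckets, run_bucket_backward, n_params):
    grad_msg = torch.empty_like(grad_hidden)
    param_grads: list[torch.Tensor | None] = [None] * n_params
    for bucket in buckets:
        bucket_grad_msg, bucket_param_grads = run_bucket_backward(
            grad_hidden[:, bucket.start : bucket.stop, :]
        )
        if bucket_grad_msg is not None:
            grad_msg[:, bucket.start : bucket.stop, :] = bucket_grad_msg
        else:
            grad_msg[:, bucket.start : bucket.stop, :].zero_()
        for i, g in enumerate(bucket_param_grads):
            if g is None:
                continue
            param_grads[i] = g if param_grads[i] is None else param_grads[i] + g
    return grad_msg, tuple(param_grads)

# Two buckets over 5 receivers; the second bucket produces no message gradient.
grad_hidden = torch.ones(2, 5, 3)
buckets = [Bucket(0, 2), Bucket(2, 5)]
def fake_backward(g):
    return (g * 2 if g.shape[1] == 2 else None), (torch.ones(3), None)
msg, params = accumulate_bucket_backward(grad_hidden, buckets, fake_backward, n_params=2)
assert msg[:, 2:, :].abs().sum() == 0 and params[0].equal(torch.full((3,), 2.0))
```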
bucket.backend_stop, :] = population_grad_msg.to( - dtype=grad_recurrent_msg.dtype - ) - else: - grad_recurrent_msg[:, bucket.backend_start : bucket.backend_stop, :].zero_() - if need_grad_state_before: - grad_state_cache_before[bucket.name] = population_grad_state_before - for index, grad_param in enumerate(population_param_grads): - if grad_param is None: - continue - param_grad_accum[index] = ( - grad_param if param_grad_accum[index] is None else param_grad_accum[index] + grad_param - ) - return grad_recurrent_msg, grad_state_cache_before, tuple(param_grad_accum) - - -def run_backend_order_transition_buckets_backward_step_cached_unbound( - runtime: Any, - recurrent_msg: torch.Tensor, - backend_state_cache_before: Mapping[str, object], - *, - grad_recurrent_hidden: torch.Tensor | None, - grad_next_backend_state_cache: Mapping[str, object] | None = None, - resets: torch.Tensor | None, - static_tensors: dict[str, object], - need_grad_state_before: bool = True, - forward_tape_by_population: Mapping[str, object] | None = None, -) -> tuple[torch.Tensor | None, dict[str, object], BackendOrderTransitionParamGrads]: - """Run transition backward and return family-local parameter grads before public binding. - - The reverse temporal executor can call this per timestep, accumulate the returned - family-local gradients across time, and bind them to public trainable parameters once. - That keeps recurrent state dependencies in the reverse scan while removing - per-step parameter-binding glue from the hot path. - """ - has_next_state_grad = False - if grad_next_backend_state_cache is not None: - for grad_state in grad_next_backend_state_cache.values(): - if _backend_grad_tree_has_tensor(grad_state): - has_next_state_grad = True - break - if grad_recurrent_hidden is None and not has_next_state_grad: - return None, {}, BackendOrderTransitionParamGrads(by_population={}) - num_recurrent = int(runtime.recurrent_cell_idx.numel()) - if num_recurrent == 0: - return None, {}, BackendOrderTransitionParamGrads(by_population={}) - grad_recurrent_msg = torch.empty_like(recurrent_msg) - grad_state_cache_before: dict[str, object] = {} - param_grads_by_population: dict[str, BackendOrderTransitionPopulationParamGrads] = {} - buckets = _backend_order_population_buckets(runtime, static_tensors) - for bucket in buckets: - packed_state_before = backend_state_cache_before.get(bucket.name) - if packed_state_before is None: - raise RuntimeError(f"Flat-bucket temporal backward is missing backend state tape for {bucket.name}") - state_public_profile_name = runtime._state_public_backward_profile_name_for_population(bucket.name) - transition_static_tensors = runtime._cached_receiver_window_static_tensors(bucket.static_tensors, None) - if not transition_static_tensors: - transition_static_tensors = runtime._materialize_inference_static_tensors( - device=recurrent_msg.device, - dtype=recurrent_msg.dtype, - include_backend_cell_tensors=False, - ) - with ( - torch.no_grad(), - torch.profiler.record_function(state_public_profile_name), - ): - transition_backward = runtime._lower_backend_population_transition_backward_shared( - population_name=bucket.name, - recurrent_msg=recurrent_msg[:, bucket.backend_start : bucket.backend_stop, :], - packed_state_before=packed_state_before, - population_reset_step=resets, - static_tensors=transition_static_tensors, - grad_next_packed_state=None - if grad_next_backend_state_cache is None - else grad_next_backend_state_cache.get(bucket.name), - grad_recurrent_hidden=None - if grad_recurrent_hidden 
is None - else grad_recurrent_hidden[:, bucket.backend_start : bucket.backend_stop, :], - need_grad_packed_state_before=need_grad_state_before, - forward_tape=None - if forward_tape_by_population is None - else forward_tape_by_population.get(bucket.name), - ) - population_grad_msg = cast(torch.Tensor | None, transition_backward.grad_recurrent_msg) - if population_grad_msg is not None: - grad_recurrent_msg[:, bucket.backend_start : bucket.backend_stop, :] = population_grad_msg.to( - dtype=grad_recurrent_msg.dtype - ) - else: - grad_recurrent_msg[:, bucket.backend_start : bucket.backend_stop, :].zero_() - if need_grad_state_before: - grad_state_cache_before[bucket.name] = transition_backward.grad_packed_state_before - param_grads_by_population[bucket.name] = BackendOrderTransitionPopulationParamGrads( - materialized_param_grads=dict(transition_backward.materialized_param_grads), - static_source_grads=dict(transition_backward.static_source_grads), - ) - return ( - grad_recurrent_msg, - grad_state_cache_before, - BackendOrderTransitionParamGrads(by_population=param_grads_by_population), - ) - - -def _partial_backend_grad_state_to_population_state(backend_state: Mapping[str, object]) -> TensorDict: - leaves: dict[str, torch.Tensor] = {} - batch_size: list[int] | None = None - for state_name, grad in backend_state.items(): - if not torch.is_tensor(grad): - continue - if grad.dim() < 3: - continue - leaves[state_name] = grad.permute(1, 0, 2) - if batch_size is None: - batch_size = [int(grad.shape[1]), int(grad.shape[0])] - return TensorDict(leaves, batch_size=batch_size or []) - - -def run_active_window_transition_buckets_step( - runtime: Any, - recurrent_msg: torch.Tensor, - *, - active_recurrent_idx: torch.Tensor, - active_window_buckets: dict[str, dict[str, torch.Tensor]] | None = None, - resets: torch.Tensor | None, - batch_size: int, - static_tensors: dict[str, object], - step_population_state_cache: dict[str, object] | None = None, - materialize_next_state: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - active_count = int(active_recurrent_idx.numel()) - if active_count == 0: - empty_hidden = recurrent_msg.new_empty(batch_size, 0, runtime.hidden_size) - empty_k = recurrent_msg.new_empty(batch_size, 0, runtime.head_dim) - empty_v = recurrent_msg.new_empty(batch_size, 0, runtime.value_dim) - return empty_hidden, empty_k, empty_v - recurrent_hidden = recurrent_msg.new_empty(batch_size, active_count, runtime.hidden_size) - recurrent_k = recurrent_msg.new_empty(batch_size, active_count, runtime.head_dim) - recurrent_v = recurrent_msg.new_empty(batch_size, active_count, runtime.value_dim) - for name in _active_population_names(runtime): - bucket = active_window_buckets.get(name) if active_window_buckets is not None else None - if bucket is None: - population_recurrent_idx = runtime._population_recurrent_indices(name) - population_mask = torch.isin(population_recurrent_idx, active_recurrent_idx) - if not bool(population_mask.any().item()): - continue - population_positions = torch.nonzero(population_mask, as_tuple=False).reshape(-1) - population_active_recurrent_idx = population_recurrent_idx.index_select(0, population_positions) - active_offsets = torch.searchsorted(active_recurrent_idx.contiguous(), population_active_recurrent_idx) - else: - population_positions = bucket["population_positions"] - if int(population_positions.numel()) == 0: - continue - population_active_recurrent_idx = bucket["active_recurrent_idx"] - active_offsets = bucket["active_offsets"] - if 
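The `_cached_unbound` variant above exists so the reverse temporal executor can keep family-local parameter gradients per population, accumulate them across timesteps, and bind them to public trainable parameters a single time at the end of the scan instead of once per step. A schematic of that accumulate-then-bind-once loop (helper names are hypothetical):

```python
import torch

def reverse_scan_param_grads(step_backwards, bind_to_public_params):
    """Accumulate family-local grads over a reverse scan; bind to public params once."""
    accumulated: dict[str, dict[str, torch.Tensor]] = {}
    for step_backward in reversed(step_backwards):           # one call per timestep
        per_population = step_backward()                     # {population: {name: grad}}
        for population, grads in per_population.items():
            slot = accumulated.setdefault(population, {})
            for name, grad in grads.items():
                slot[name] = grad if name not in slot else slot[name] + grad
    # Single binding pass instead of per-step parameter-binding glue.
    return bind_to_public_params(accumulated)

steps = [lambda: {"slstm": {"w_in": torch.ones(2, 2)}} for _ in range(3)]
public = reverse_scan_param_grads(
    steps,
    lambda acc: {("slstm", name): grad for name, grad in acc["slstm"].items()},
)
assert public[("slstm", "w_in")].equal(torch.full((2, 2), 3.0))
```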
int(active_offsets.numel()) == 0: - continue - population_msg = recurrent_msg.index_select(1, active_offsets) - population_static_tensors = _cached_population_active_window_static_tensors( - runtime, - static_tensors, - population_name=name, - population_positions=population_positions, - active_recurrent_idx=population_active_recurrent_idx, - ) - _require_materialized_population(population_static_tensors, name) - packed_state_before = ( - step_population_state_cache[name] - if step_population_state_cache is not None and name in step_population_state_cache - else _fresh_zero_backend_packed_state( - runtime, - population_name=name, - batch_size=batch_size, - receiver_count=int(population_msg.shape[1]), - device=recurrent_msg.device, - dtype=recurrent_msg.dtype, - ) - ) - result = transition_execution.lower_backend_population_transition_forward_result_shared( - runtime, - population_name=name, - recurrent_msg=population_msg, - packed_state_before=packed_state_before, - population_reset_step=resets, - static_tensors=population_static_tensors, - materialize_recurrent_kv=True, - materialize_backward_tape=False, - materialize_next_state=materialize_next_state, - materialize_trace_state_next=materialize_next_state, - ) - if step_population_state_cache is not None and result.next_packed_state is not None: - step_population_state_cache[name] = result.next_packed_state - if result.recurrent_k is None or result.recurrent_v is None: - raise RuntimeError("Active-window flat-bucket sequence surface requires recurrent K/V materialization") - recurrent_hidden[:, active_offsets, :] = result.recurrent_hidden.to(dtype=recurrent_hidden.dtype) - recurrent_k[:, active_offsets, :] = result.recurrent_k.to(dtype=recurrent_k.dtype) - recurrent_v[:, active_offsets, :] = result.recurrent_v.to(dtype=recurrent_v.dtype) - return recurrent_hidden, recurrent_k, recurrent_v - - -def _fresh_zero_backend_packed_state( - runtime: Any, - *, - population_name: str, - batch_size: int, - receiver_count: int, - device: torch.device, - dtype: torch.dtype, -) -> TensorDict: - population_spec = runtime._backend_population_specs[population_name] - leaves = { - state_name: torch.zeros(batch_size, receiver_count, runtime.hidden_size, device=device, dtype=dtype) - for state_name in population_spec.transition_ir.state_inputs - } - return TensorDict(leaves, batch_size=[batch_size, receiver_count], device=device) - - -def run_transition_bucket_step( - runtime: Any, - population_name: str, - recurrent_msg: torch.Tensor, - population_state: TensorDict | None, - *, - resets: torch.Tensor | None, - static_tensors: dict[str, object], - step_population_state_cache: dict[str, object] | None = None, - materialize_next_state: bool = True, - message_already_population_ordered: bool = False, -) -> tuple[torch.Tensor, TensorDict]: - recurrent_idx = runtime._population_recurrent_indices(population_name) - population_recurrent_msg = ( - recurrent_msg if message_already_population_ordered else recurrent_msg.index_select(1, recurrent_idx) - ) - population_static_tensors = _population_static_tensors_for( - static_tensors, - population_name, - recurrent_idx=recurrent_idx, - recurrent_count=int(runtime.recurrent_cell_idx.numel()), - ) - _require_materialized_population(population_static_tensors, population_name) - if step_population_state_cache is not None and population_name in step_population_state_cache: - if population_state is None: - population_state = TensorDict({}, batch_size=[]) - if torch.is_grad_enabled() and population_recurrent_msg.requires_grad: 
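When no precomputed active-window bucket is available, the active-window step above derives, per population, which of its recurrent cells are active and where they sit inside the sorted active index using `torch.isin`, `nonzero`, and `searchsorted`. A compact sketch of that index math with made-up indices:

```python
import torch

# Global recurrent indices owned by one population, and the currently active window.
population_recurrent_idx = torch.tensor([3, 7, 11, 19])
active_recurrent_idx = torch.tensor([2, 7, 11, 30])         # assumed sorted

mask = torch.isin(population_recurrent_idx, active_recurrent_idx)
population_positions = torch.nonzero(mask, as_tuple=False).reshape(-1)
population_active_idx = population_recurrent_idx.index_select(0, population_positions)
# Offsets of this population's active cells inside the active-window ordering.
active_offsets = torch.searchsorted(active_recurrent_idx.contiguous(), population_active_idx)

assert population_positions.tolist() == [1, 2]   # cells 7 and 11 are active
assert active_offsets.tolist() == [1, 2]         # they occupy window slots 1 and 2
```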
- packed_state = step_population_state_cache[population_name] - packed_state_keys, packed_state_tensors = _flatten_backend_packed_state(packed_state) - trainable_items = _flat_bucket_trainable_items(runtime, static_tensors) - outputs = _CachedBackendTransitionBucketStepFunction.apply( - runtime, - population_name, - population_static_tensors, - packed_state_keys, - tuple(name for name, _param in trainable_items), - population_recurrent_msg, - resets, - *packed_state_tensors, - *(param for _name, param in trainable_items), - ) - recurrent_hidden = cast(torch.Tensor, outputs[0]).to(dtype=population_recurrent_msg.dtype) - next_packed_state = _unflatten_backend_packed_state( - packed_state_keys, - cast(tuple[torch.Tensor, ...], tuple(outputs[1:])), - ) - step_population_state_cache[population_name] = next_packed_state - runtime._last_flat_bucket_transition_backward_executor = _PHYSICAL_TRANSITION_BACKWARD_EXECUTOR - return recurrent_hidden, population_state - result = transition_execution.lower_backend_population_transition_forward_result_shared( - runtime, - population_name=population_name, - recurrent_msg=population_recurrent_msg, - packed_state_before=step_population_state_cache[population_name], - population_reset_step=resets, - static_tensors=population_static_tensors, - materialize_recurrent_kv=False, - materialize_backward_tape=False, - materialize_next_state=materialize_next_state, - materialize_trace_state_next=materialize_next_state, - ) - if result.next_packed_state is not None: - step_population_state_cache[population_name] = result.next_packed_state - return result.recurrent_hidden.to(dtype=population_recurrent_msg.dtype), population_state - if population_state is None: - raise RuntimeError(f"CUDA flat-bucket sequence surface requires TensorDict state for {population_name}") - packed_state_before = runtime._population_state_to_backend_state( - population_name, - cast(TensorDictBase, population_state), - ) - result = transition_execution.lower_backend_population_transition_forward_result_shared( - runtime, - population_name=population_name, - recurrent_msg=population_recurrent_msg, - packed_state_before=packed_state_before, - population_reset_step=resets, - static_tensors=population_static_tensors, - materialize_recurrent_kv=False, - materialize_backward_tape=False, - ) - next_state = runtime._backend_state_to_population_state( - population_name, - cast(dict[str, torch.Tensor], result.next_packed_state), - ) - return result.recurrent_hidden.to(dtype=population_recurrent_msg.dtype), cast(TensorDict, next_state) - - -def _require_materialized_population(static_tensors: dict[str, object], population_name: str) -> None: - population_materialized = static_tensors.get("population_materialized") - if not isinstance(population_materialized, dict): - raise RuntimeError("CUDA flat-bucket sequence surface requires materialized population parameters") - if not isinstance(population_materialized.get(population_name), dict): - raise RuntimeError(f"CUDA flat-bucket sequence surface has no materialized parameters for {population_name}") - - -def _population_active_window_static_tensors( - runtime: Any, - static_tensors: dict[str, object], - *, - population_name: str, - population_positions: torch.Tensor, - active_recurrent_idx: torch.Tensor, -) -> dict[str, object]: - out = dict(static_tensors) - out.pop(_POPULATION_STATIC_TENSORS_KEY, None) - recurrent_count = int(runtime.recurrent_cell_idx.numel()) - for key in ( - "recurrent_cell_bias", - "fused_recurrent_cell_bias", - ): - value = out.get(key) 
- if torch.is_tensor(value): - if value.dim() >= 2 and int(value.shape[1]) == recurrent_count: - out[key] = value.index_select(1, active_recurrent_idx) - elif value.dim() >= 1 and int(value.shape[0]) == recurrent_count: - out[key] = value.index_select(0, active_recurrent_idx) - for key in ( - "fused_recurrent_value_to_cell_weight", - "recurrent_sender_input_to_kv_weight", - ): - value = out.get(key) - if torch.is_tensor(value) and value.dim() >= 1 and int(value.shape[0]) == recurrent_count: - out[key] = value.index_select(0, active_recurrent_idx) - if torch.is_tensor(out.get("recurrent_sender_input_to_kv_weight")): - out["recurrent_group_input_to_kv_weight"] = None - grouped_recurrent_kv = out.get("recurrent_group_input_to_kv_weight") - if torch.is_tensor(grouped_recurrent_kv) and not torch.is_tensor(out.get("recurrent_sender_input_to_kv_weight")): - expanded = grouped_recurrent_kv.repeat_interleave( - max(1, int(runtime._recurrent_sender_kv_group_size)), - dim=0, - ) - if int(expanded.shape[0]) >= recurrent_count: - out["recurrent_sender_input_to_kv_weight"] = expanded[:recurrent_count].index_select( - 0, - active_recurrent_idx, - ) - out["recurrent_group_input_to_kv_weight"] = None - - population_materialized = out.get("population_materialized") - if isinstance(population_materialized, dict): - full_population_params = population_materialized.get(population_name) - sliced_population_materialized: dict[str, object | None] = {key: None for key in runtime._population_names} - if isinstance(full_population_params, dict): - population_count = int(runtime._population_recurrent_indices(population_name).numel()) - sliced_population_materialized[population_name] = { - key: _slice_population_receiver_tensor(value, population_positions, population_count) - if torch.is_tensor(value) - else value - for key, value in full_population_params.items() - } - input_proj_weight_t = full_population_params.get("input_proj_weight_t") - if torch.is_tensor(input_proj_weight_t): - out["input_proj_weight_t"] = _slice_population_receiver_tensor( - input_proj_weight_t, - population_positions, - population_count, - ) - out["population_materialized"] = sliced_population_materialized - out["fused_recurrent_population_input"] = False - out["unfused_recurrent_input_projection"] = True - return out - - -def _cached_population_active_window_static_tensors( - runtime: Any, - static_tensors: dict[str, object], - *, - population_name: str, - population_positions: torch.Tensor, - active_recurrent_idx: torch.Tensor, -) -> dict[str, object]: - cache = static_tensors.get("_flat_bucket_active_window_static_cache") - if not isinstance(cache, dict): - cache = {} - static_tensors["_flat_bucket_active_window_static_cache"] = cache - key = ( - population_name, - str(population_positions.device), - int(population_positions.data_ptr()), - tuple(population_positions.shape), - str(active_recurrent_idx.device), - int(active_recurrent_idx.data_ptr()), - tuple(active_recurrent_idx.shape), - ) - cached = cache.get(key) - if isinstance(cached, dict): - return cached - cached = _population_active_window_static_tensors( - runtime, - static_tensors, - population_name=population_name, - population_positions=population_positions, - active_recurrent_idx=active_recurrent_idx, - ) - cache[key] = cached - return cached - - -def _slice_population_receiver_tensor( - tensor: torch.Tensor, - positions: torch.Tensor, - population_count: int, -) -> torch.Tensor: - if population_count <= 0 or int(positions.numel()) == population_count: - return tensor - if 
tensor.dim() >= 1 and int(tensor.shape[0]) == population_count: - return tensor.index_select(0, positions).contiguous() - if tensor.dim() >= 1 and int(tensor.shape[0]) % population_count == 0: - factor = int(tensor.shape[0]) // population_count - reshaped = tensor.reshape(population_count, factor, *tuple(tensor.shape[1:])) - selected = reshaped.index_select(0, positions) - return selected.reshape(int(positions.numel()) * factor, *tuple(tensor.shape[1:])).contiguous() - if tensor.dim() >= 2 and int(tensor.shape[1]) == population_count: - return tensor.index_select(1, positions).contiguous() - if tensor.dim() >= 2 and int(tensor.shape[1]) % population_count == 0: - factor = int(tensor.shape[1]) // population_count - reshaped = tensor.reshape(tensor.shape[0], population_count, factor, *tuple(tensor.shape[2:])) - selected = reshaped.index_select(1, positions) - return selected.reshape(tensor.shape[0], int(positions.numel()) * factor, *tuple(tensor.shape[2:])).contiguous() - if tensor.dim() == 1 and int(tensor.numel()) % population_count == 0: - factor = int(tensor.numel()) // population_count - reshaped = tensor.reshape(population_count, factor) - return reshaped.index_select(0, positions).reshape(-1).contiguous() - return tensor - - -__all__ = [ - "_partial_backend_grad_state_to_population_state", - "_population_grad_state_to_backend_grad_state", - "BackendOrderTransitionStepResult", - "run_active_window_transition_buckets_step", - "run_backend_order_transition_buckets_backward_step", - "run_backend_order_transition_buckets_backward_step_cached", - "run_backend_order_transition_buckets_backward_step_cached_unbound", - "run_backend_order_transition_buckets_step_cached_eager_result", - "run_backend_order_transition_buckets_step", - "run_transition_bucket_step", - "run_transition_buckets_step", -] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/replay.py b/src/cortical/fabric/backend/cuda/sequence_surface/replay.py deleted file mode 100644 index 328b8ce2..00000000 --- a/src/cortical/fabric/backend/cuda/sequence_surface/replay.py +++ /dev/null @@ -1,720 +0,0 @@ -from __future__ import annotations - -import os -from typing import Any, Literal, cast - -import torch -from tensordict import TensorDictBase - -from cortical.fabric.backend.cuda.sequence_surface.support import _BACKWARD_ATTRIBUTION_MODE_ENV -from cortical.fabric.backend.planner import ( - PlannedFabricBackwardExecution, - PlannedFabricExecution, -) -from cortical.fabric.backend.reuse import ExecutionFamily -from cortical.fabric.runtime.state import ( - flatten_backend_packed_state as _flatten_backend_packed_state, -) -from cortical.fabric.runtime.state import ( - unflatten_backend_packed_state as _unflatten_backend_packed_state, -) - - -class CudaSequenceReplayMixin: - def _run_backend_sequence_surface_backward_full_replay_once( - self, - *, - boundary_seq: torch.Tensor, - projected_boundary_source_seq: torch.Tensor | None = None, - projected_boundary_weight: torch.Tensor | None = None, - projected_boundary_bias: torch.Tensor | None = None, - packed_state: Any, - initial_hidden: torch.Tensor, - initial_recurrent_k: torch.Tensor | None, - initial_recurrent_v: torch.Tensor | None, - initial_state_is_fresh: bool, - population_resets: torch.Tensor | None, - planned_backend_execution: PlannedFabricExecution, - planned_backend_backward_execution: PlannedFabricBackwardExecution, - grad_output_seq: torch.Tensor | None, - grad_next_packed_state: Any, - grad_recurrent_hidden: torch.Tensor | None, - grad_recurrent_k: torch.Tensor | 
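`_slice_population_receiver_tensor` above slices a materialized parameter along whichever dimension matches the population's receiver count, and handles fused layouts where that dimension is an integer multiple of the receiver count by reshaping, selecting, and flattening back. A small sketch of the multiple-of-count case (shapes are illustrative):

```python
import torch

def slice_receiver_tensor(tensor: torch.Tensor, positions: torch.Tensor, count: int) -> torch.Tensor:
    """Select per-receiver rows, including fused layouts of shape (count * factor, ...)."""
    if tensor.size(0) == count:
        return tensor.index_select(0, positions).contiguous()
    if tensor.size(0) % count == 0:                      # fused: factor rows per receiver
        factor = tensor.size(0) // count
        reshaped = tensor.reshape(count, factor, *tensor.shape[1:])
        selected = reshaped.index_select(0, positions)
        return selected.reshape(positions.numel() * factor, *tensor.shape[1:]).contiguous()
    return tensor                                        # not receiver-shaped; leave untouched

fused = torch.arange(12.0).reshape(6, 2)                 # 3 receivers x factor 2
active = torch.tensor([0, 2])                            # keep receivers 0 and 2
print(slice_receiver_tensor(fused, active, count=3))
# rows 0, 1, 4, 5 of the fused layout, i.e. both factor rows for each kept receiver
```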
None, - grad_recurrent_v: torch.Tensor | None, - grad_input_k_last: torch.Tensor | None, - grad_input_v_last: torch.Tensor | None, - trainable_params: tuple[torch.Tensor, ...], - trainable_param_names: tuple[str, ...], - replay_static_tensors: dict[str, object], - output_boundary: Literal["sequence", "terminal"] = "sequence", - forward_carry_checkpoints: Any | None = None, - ) -> tuple[dict[str, torch.Tensor | None], tuple[torch.Tensor | None, ...]]: - del trainable_param_names, forward_carry_checkpoints - projected_boundary_active = ( - projected_boundary_source_seq is not None - or projected_boundary_weight is not None - or projected_boundary_bias is not None - ) - if ( - not planned_backend_backward_execution.receiver_bucket_plans - or not planned_backend_backward_execution.sender_bucket_plans - ): - raise RuntimeError("Supported Fabric training surface requires a planned backward execution") - if any( - bucket_plan.execution_family != ExecutionFamily.RECEIVER_MAJOR - for bucket_plan in planned_backend_backward_execution.receiver_bucket_plans - ): - raise RuntimeError("Supported Fabric backward surface requires receiver-major receiver-adjoint execution") - if any( - bucket_plan.execution_family != ExecutionFamily.EDGE_MAJOR - for bucket_plan in planned_backend_backward_execution.sender_bucket_plans - ): - raise RuntimeError("Supported Fabric backward surface requires edge-major sender/public accumulation") - packed_state_keys, packed_state_tensors = _flatten_backend_packed_state(packed_state) - with torch.profiler.record_function("fabric.glue.backward_replay_leaf_boundary"): - detached_boundary_steps = tuple( - boundary_seq[:, step_index] - .detach() - .clone() - .requires_grad_(boundary_seq.requires_grad or projected_boundary_active) - for step_index in range(int(boundary_seq.shape[1])) - ) - with torch.profiler.record_function("fabric.glue.backward_replay_leaf_packed_state"): - detached_packed_state_tensors = tuple( - tensor.detach().clone().requires_grad_(True) if tensor.is_floating_point() else tensor.detach().clone() - for tensor in packed_state_tensors - ) - detached_packed_state = _unflatten_backend_packed_state(packed_state_keys, detached_packed_state_tensors) - with torch.profiler.record_function("fabric.glue.backward_replay_leaf_recurrent_carry"): - detached_initial_hidden = initial_hidden.detach().clone().requires_grad_(True) - detached_initial_recurrent_k = ( - None if initial_recurrent_k is None else initial_recurrent_k.detach().clone().requires_grad_(True) - ) - detached_initial_recurrent_v = ( - None if initial_recurrent_v is None else initial_recurrent_v.detach().clone().requires_grad_(True) - ) - - previous_backend_name = self._active_backend_name - self._active_backend_name = "cuda" - try: - with torch.enable_grad(): - static_tensors = replay_static_tensors - if getattr(self, "_last_backward_projection_mode", None) == "factorized_recurrent_input": - static_tensors = dict(replay_static_tensors) - static_tensors["replay_unfused_recurrent_input_projection"] = True - backend_population_name = self._select_output_cells_stream_backend_population( - k=1, - ) - if backend_population_name is None: - raise RuntimeError( - f"Supported Fabric {planned_backend_execution.surface_key} training surface " - "requires a callable backend-owned CUDA sequence surface" - ) - running_packed_state = detached_packed_state - running_hidden = detached_initial_hidden - running_recurrent_k = detached_initial_recurrent_k - running_recurrent_v = detached_initial_recurrent_v - replay_output_steps: 
list[torch.Tensor] = [] - replay_input_k_last: torch.Tensor | None = None - replay_input_v_last: torch.Tensor | None = None - for step_index in range(int(boundary_seq.shape[1])): - ( - replay_output_step, - running_packed_state, - running_hidden, - running_recurrent_k, - running_recurrent_v, - replay_input_k_last, - replay_input_v_last, - ) = self._run_backend_sequence_surface_step_primitives( - population_name=backend_population_name, - boundary_step=detached_boundary_steps[step_index], - packed_state=running_packed_state, - initial_hidden=running_hidden, - initial_recurrent_k=running_recurrent_k, - initial_recurrent_v=running_recurrent_v, - population_resets=population_resets[:, step_index] if population_resets is not None else None, - input_sender_input_to_kv_weight=cast( - torch.Tensor | None, static_tensors["input_sender_input_to_kv_weight"] - ), - input_group_input_to_kv_weight=cast( - torch.Tensor | None, static_tensors["input_group_input_to_kv_weight"] - ), - static_tensors=static_tensors, - ) - replay_output_steps.append(replay_output_step) - replay_next_packed_state = running_packed_state - replay_recurrent_hidden = running_hidden - replay_recurrent_k = running_recurrent_k - replay_recurrent_v = running_recurrent_v - finally: - self._active_backend_name = previous_backend_name - - output_spec_groups: list[tuple[str, list[tuple[torch.Tensor, torch.Tensor]]]] = [] - sequence_output_specs: list[tuple[torch.Tensor, torch.Tensor]] = [] - if grad_output_seq is not None: - if output_boundary == "terminal": - replay_output_step = replay_output_steps[-1] - if replay_output_step.requires_grad: - sequence_output_specs.append((replay_output_step, grad_output_seq[:, -1])) - else: - sequence_output_specs.extend( - (replay_output_step, grad_output_seq[:, step_index]) - for step_index, replay_output_step in enumerate(replay_output_steps) - if replay_output_step.requires_grad - ) - if sequence_output_specs: - output_spec_groups.append(("fabric.backward.replay.output_sequence", sequence_output_specs)) - next_packed_state_keys, next_packed_state_tensors = _flatten_backend_packed_state(replay_next_packed_state) - if next_packed_state_keys != packed_state_keys: - raise RuntimeError("Backend replay backward must preserve packed-state structure") - next_packed_state_specs: list[tuple[torch.Tensor, torch.Tensor]] = [] - if packed_state_keys is None: - grad_next_tensor = cast(torch.Tensor | None, grad_next_packed_state) - next_tensor = next_packed_state_tensors[0] - if grad_next_tensor is not None and next_tensor.requires_grad: - next_packed_state_specs.append((next_tensor, grad_next_tensor)) - else: - assert isinstance(grad_next_packed_state, (dict, TensorDictBase)) - for key, tensor in zip(packed_state_keys, next_packed_state_tensors, strict=True): - grad = cast(torch.Tensor | None, grad_next_packed_state.get(key)) - if grad is not None and tensor.requires_grad: - next_packed_state_specs.append((tensor, grad)) - if next_packed_state_specs: - output_spec_groups.append(("fabric.backward.replay.next_packed_state", next_packed_state_specs)) - recurrent_carry_specs: list[tuple[torch.Tensor, torch.Tensor]] = [] - if grad_recurrent_hidden is not None and replay_recurrent_hidden.requires_grad: - recurrent_carry_specs.append((replay_recurrent_hidden, grad_recurrent_hidden)) - if grad_recurrent_k is not None and replay_recurrent_k is not None and replay_recurrent_k.requires_grad: - recurrent_carry_specs.append((replay_recurrent_k, grad_recurrent_k)) - if grad_recurrent_v is not None and replay_recurrent_v is not 
None and replay_recurrent_v.requires_grad: - recurrent_carry_specs.append((replay_recurrent_v, grad_recurrent_v)) - if recurrent_carry_specs: - output_spec_groups.append(("fabric.backward.replay.recurrent_carry", recurrent_carry_specs)) - input_kv_specs: list[tuple[torch.Tensor, torch.Tensor]] = [] - if grad_input_k_last is not None and replay_input_k_last is not None and replay_input_k_last.requires_grad: - input_kv_specs.append((replay_input_k_last, grad_input_k_last)) - if grad_input_v_last is not None and replay_input_v_last is not None and replay_input_v_last.requires_grad: - input_kv_specs.append((replay_input_v_last, grad_input_v_last)) - if input_kv_specs: - output_spec_groups.append(("fabric.backward.replay.input_kv_last", input_kv_specs)) - - input_specs: list[tuple[str, torch.Tensor]] = [] - boundary_input_specs: list[tuple[str, torch.Tensor]] = [] - if boundary_seq.requires_grad: - boundary_input_specs.extend( - (f"boundary_step_{step_index}", boundary_step) - for step_index, boundary_step in enumerate(detached_boundary_steps) - ) - input_specs.extend(boundary_input_specs) - packed_state_input_specs: list[tuple[str, torch.Tensor]] = [] - for index, tensor in enumerate(detached_packed_state_tensors): - if tensor.requires_grad: - packed_state_input_specs.append((f"packed_state_{index}", tensor)) - input_specs.extend(packed_state_input_specs) - recurrent_carry_input_specs: list[tuple[str, torch.Tensor]] = [] - if detached_initial_hidden.requires_grad: - recurrent_carry_input_specs.append(("initial_hidden", detached_initial_hidden)) - if detached_initial_recurrent_k is not None and detached_initial_recurrent_k.requires_grad: - recurrent_carry_input_specs.append(("initial_recurrent_k", detached_initial_recurrent_k)) - if detached_initial_recurrent_v is not None and detached_initial_recurrent_v.requires_grad: - recurrent_carry_input_specs.append(("initial_recurrent_v", detached_initial_recurrent_v)) - input_specs.extend(recurrent_carry_input_specs) - param_input_specs: list[tuple[str, torch.Tensor]] = [] - for parameter_index, parameter in enumerate(trainable_params): - if parameter.requires_grad: - param_input_specs.append((f"param_{parameter_index}", parameter)) - input_specs.extend(param_input_specs) - - backward_attribution_mode = os.environ.get(_BACKWARD_ATTRIBUTION_MODE_ENV) - if backward_attribution_mode == "full_replay_boundary_probe": - - def accumulate_probe_grad( - existing_grad: torch.Tensor | None, - grad: torch.Tensor | None, - ) -> torch.Tensor | None: - if grad is None: - return existing_grad - if existing_grad is None: - return grad - return existing_grad + grad - - grad_result_map: dict[str, torch.Tensor | None] = {name: None for name, _tensor in input_specs} - input_spec_groups = ( - ("boundary_inputs", boundary_input_specs), - ("packed_state_inputs", packed_state_input_specs), - ("recurrent_carry_inputs", recurrent_carry_input_specs), - ("parameter_inputs", param_input_specs), - ) - active_calls: list[tuple[str, list[tuple[torch.Tensor, torch.Tensor]], list[tuple[str, torch.Tensor]]]] = [] - for profile_name, specs in output_spec_groups: - if not specs: - continue - if profile_name == "fabric.backward.replay.output_sequence": - active_calls.extend( - (f"{profile_name}.{input_group_name}", specs, group_input_specs) - for input_group_name, group_input_specs in input_spec_groups - if group_input_specs - ) - else: - active_calls.append((profile_name, specs, input_specs)) - for group_index, (profile_name, specs, group_input_specs) in enumerate(active_calls): - 
group_grad_map = self._run_named_autograd_phase( - outputs=specs, - inputs=group_input_specs, - profile_name=profile_name, - retain_graph=group_index < len(active_calls) - 1, - ) - for name, grad in group_grad_map.items(): - grad_result_map[name] = accumulate_probe_grad(grad_result_map.get(name), grad) - elif backward_attribution_mode == "full_replay_no_parameter_probe": - output_specs = [spec for _profile_name, specs in output_spec_groups for spec in specs] - non_parameter_input_specs = boundary_input_specs + packed_state_input_specs + recurrent_carry_input_specs - grad_result_map = {name: None for name, _tensor in input_specs} - grad_result_map.update( - self._run_named_autograd_phase( - outputs=output_specs, - inputs=non_parameter_input_specs, - profile_name="fabric.backward.replay.no_parameter_inputs", - ) - ) - elif backward_attribution_mode == "full_replay_without_boundary_inputs_probe": - output_specs = [spec for _profile_name, specs in output_spec_groups for spec in specs] - non_boundary_input_specs = packed_state_input_specs + recurrent_carry_input_specs + param_input_specs - grad_result_map = {name: None for name, _tensor in input_specs} - grad_result_map.update( - self._run_named_autograd_phase( - outputs=output_specs, - inputs=non_boundary_input_specs, - profile_name="fabric.backward.replay.without_boundary_inputs", - ) - ) - elif backward_attribution_mode == "full_replay_no_parameter_boundary_probe": - - def accumulate_probe_grad( - existing_grad: torch.Tensor | None, - grad: torch.Tensor | None, - ) -> torch.Tensor | None: - if grad is None: - return existing_grad - if existing_grad is None: - return grad - return existing_grad + grad - - grad_result_map = {name: None for name, _tensor in input_specs} - output_sequence_specs = next( - ( - specs - for profile_name, specs in output_spec_groups - if profile_name == "fabric.backward.replay.output_sequence" - ), - [], - ) - input_spec_groups = ( - ("boundary_inputs", boundary_input_specs), - ("packed_state_inputs", packed_state_input_specs), - ("recurrent_carry_inputs", recurrent_carry_input_specs), - ) - active_calls = [ - ( - f"fabric.backward.replay.no_parameter_output_sequence.{input_group_name}", - output_sequence_specs, - group_input_specs, - ) - for input_group_name, group_input_specs in input_spec_groups - if output_sequence_specs and group_input_specs - ] - for group_index, (profile_name, specs, group_input_specs) in enumerate(active_calls): - group_grad_map = self._run_named_autograd_phase( - outputs=specs, - inputs=group_input_specs, - profile_name=profile_name, - retain_graph=group_index < len(active_calls) - 1, - ) - for name, grad in group_grad_map.items(): - grad_result_map[name] = accumulate_probe_grad(grad_result_map.get(name), grad) - else: - output_specs = [spec for _profile_name, specs in output_spec_groups for spec in specs] - grad_result_map = self._run_named_autograd_phase( - outputs=output_specs, - inputs=input_specs, - profile_name="fabric.backward.full_replay_autograd", - ) - grad_boundary_steps = ( - tuple( - cast(torch.Tensor | None, grad_result_map.get(f"boundary_step_{step_index}")) - for step_index in range(len(detached_boundary_steps)) - ) - if boundary_seq.requires_grad - else () - ) - grad_boundary_seq: torch.Tensor | None = None - if grad_boundary_steps and any(grad is not None for grad in grad_boundary_steps): - grad_boundary_seq = torch.stack( - tuple( - grad if grad is not None else torch.zeros_like(detached_boundary_steps[step_index]) - for step_index, grad in enumerate(grad_boundary_steps) - 
), - dim=1, - ) - grad_sequence_inputs: dict[str, torch.Tensor | None] = { - "initial_hidden": cast(torch.Tensor | None, grad_result_map.get("initial_hidden")), - "population_resets": None, - "initial_recurrent_k": cast(torch.Tensor | None, grad_result_map.get("initial_recurrent_k")), - "initial_recurrent_v": cast(torch.Tensor | None, grad_result_map.get("initial_recurrent_v")), - } - if projected_boundary_active: - if ( - grad_boundary_seq is not None - and projected_boundary_source_seq is not None - and projected_boundary_weight is not None - ): - grad_boundary_flat = grad_boundary_seq.reshape( - int(grad_boundary_seq.shape[0]), - int(grad_boundary_seq.shape[1]), - -1, - ) - grad_sequence_inputs["projected_boundary_source_seq"] = ( - grad_boundary_flat.matmul(projected_boundary_weight) - if projected_boundary_source_seq.requires_grad - else None - ) - grad_sequence_inputs["projected_boundary_weight"] = ( - torch.einsum("btf,bth->fh", grad_boundary_flat, projected_boundary_source_seq) - if projected_boundary_weight.requires_grad - else None - ) - grad_sequence_inputs["projected_boundary_bias"] = ( - grad_boundary_flat.sum(dim=(0, 1)) - if projected_boundary_bias is not None and projected_boundary_bias.requires_grad - else None - ) - else: - grad_sequence_inputs["projected_boundary_source_seq"] = None - grad_sequence_inputs["projected_boundary_weight"] = None - if projected_boundary_bias is not None: - grad_sequence_inputs["projected_boundary_bias"] = None - else: - grad_sequence_inputs["boundary_seq"] = grad_boundary_seq - for index in range(len(detached_packed_state_tensors)): - grad_sequence_inputs[f"packed_state_{index}"] = cast( - torch.Tensor | None, - grad_result_map.get(f"packed_state_{index}"), - ) - return grad_sequence_inputs, tuple( - cast(torch.Tensor | None, grad_result_map.get(f"param_{parameter_index}")) - for parameter_index in range(len(trainable_params)) - ) - - def _run_backend_sequence_surface_backward_reference_replay_once( - self, - *, - boundary_seq: torch.Tensor, - packed_state: Any, - initial_hidden: torch.Tensor, - initial_recurrent_k: torch.Tensor | None, - initial_recurrent_v: torch.Tensor | None, - initial_state_is_fresh: bool, - population_resets: torch.Tensor | None, - planned_backend_execution: PlannedFabricExecution, - planned_backend_backward_execution: PlannedFabricBackwardExecution, - grad_output_seq: torch.Tensor | None, - grad_next_packed_state: Any, - grad_recurrent_hidden: torch.Tensor | None, - grad_recurrent_k: torch.Tensor | None, - grad_recurrent_v: torch.Tensor | None, - grad_input_k_last: torch.Tensor | None, - grad_input_v_last: torch.Tensor | None, - trainable_params: tuple[torch.Tensor, ...], - trainable_param_names: tuple[str, ...], - output_boundary: Literal["sequence", "terminal"] = "sequence", - ) -> tuple[dict[str, torch.Tensor | None], tuple[torch.Tensor | None, ...]]: - del trainable_param_names - if ( - not planned_backend_backward_execution.receiver_bucket_plans - or not planned_backend_backward_execution.sender_bucket_plans - ): - raise RuntimeError("Supported Fabric training surface requires a planned backward execution") - if any( - bucket_plan.execution_family != ExecutionFamily.RECEIVER_MAJOR - for bucket_plan in planned_backend_backward_execution.receiver_bucket_plans - ): - raise RuntimeError("Supported Fabric backward surface requires receiver-major receiver-adjoint execution") - if any( - bucket_plan.execution_family != ExecutionFamily.EDGE_MAJOR - for bucket_plan in planned_backend_backward_execution.sender_bucket_plans - 
): - raise RuntimeError("Supported Fabric backward surface requires edge-major sender/public accumulation") - - def _packed_state_grad_map( - keys: tuple[str, ...] | None, - grad_state: Any, - ) -> dict[str, torch.Tensor | None]: - if keys is None: - return { - "packed_state_0": cast(torch.Tensor | None, grad_state if torch.is_tensor(grad_state) else None) - } - assert grad_state is None or isinstance(grad_state, dict) - return { - f"packed_state_{index}": cast( - torch.Tensor | None, - None if grad_state is None else grad_state.get(key), - ) - for index, key in enumerate(keys) - } - - def _packed_state_from_grad_map( - keys: tuple[str, ...] | None, - grad_map: dict[str, torch.Tensor | None], - ) -> Any: - if keys is None: - return grad_map.get("packed_state_0") - return {key: grad_map.get(f"packed_state_{index}") for index, key in enumerate(keys)} - - def _accumulate_grad(current: torch.Tensor | None, new_grad: torch.Tensor | None) -> torch.Tensor | None: - if new_grad is None: - return current - return new_grad if current is None else current + new_grad - - static_tensors = self._get_inference_static_tensors(device=boundary_seq.device, dtype=boundary_seq.dtype) - backend_population_name = self._select_output_cells_stream_backend_population( - k=1, - ) - if backend_population_name is None: - raise RuntimeError( - f"Supported Fabric {planned_backend_execution.surface_key} training surface " - "requires the backend-owned CUDA sequence surface" - ) - - packed_state_keys, _ = _flatten_backend_packed_state(packed_state) - step_inputs: list[tuple[Any, torch.Tensor, torch.Tensor | None, torch.Tensor | None]] = [] - running_packed_state = packed_state - running_hidden = initial_hidden - running_recurrent_k = initial_recurrent_k - running_recurrent_v = initial_recurrent_v - with torch.no_grad(): - for step_index in range(int(boundary_seq.shape[1])): - step_inputs.append((running_packed_state, running_hidden, running_recurrent_k, running_recurrent_v)) - ( - _output_cells, - running_packed_state, - running_hidden, - running_recurrent_k, - running_recurrent_v, - _input_k, - _input_v, - ) = self._run_backend_sequence_surface_step_primitives( - population_name=backend_population_name, - boundary_step=boundary_seq[:, step_index], - packed_state=running_packed_state, - initial_hidden=running_hidden, - initial_recurrent_k=running_recurrent_k, - initial_recurrent_v=running_recurrent_v, - population_resets=population_resets[:, step_index] if population_resets is not None else None, - input_sender_input_to_kv_weight=cast( - torch.Tensor | None, static_tensors["input_sender_input_to_kv_weight"] - ), - input_group_input_to_kv_weight=cast( - torch.Tensor | None, static_tensors["input_group_input_to_kv_weight"] - ), - static_tensors=static_tensors, - ) - - running_grad_packed_state = grad_next_packed_state - running_grad_recurrent_hidden = grad_recurrent_hidden - running_grad_recurrent_k = grad_recurrent_k - running_grad_recurrent_v = grad_recurrent_v - grad_boundary_steps: list[torch.Tensor | None] = [None] * int(boundary_seq.shape[1]) - grad_initial_hidden_total: torch.Tensor | None = None - grad_initial_recurrent_k_total: torch.Tensor | None = None - grad_initial_recurrent_v_total: torch.Tensor | None = None - grad_param_accum: list[torch.Tensor | None] = [None] * len(trainable_params) - - for step_index in reversed(range(int(boundary_seq.shape[1]))): - packed_state_before, hidden_before, recurrent_k_before, recurrent_v_before = step_inputs[step_index] - detached_boundary = boundary_seq[:, 
step_index].detach().requires_grad_(boundary_seq.requires_grad) - current_packed_state_keys, current_packed_state_tensors = _flatten_backend_packed_state(packed_state_before) - detached_packed_state_tensors = tuple( - tensor.detach().clone().requires_grad_(True) if tensor.is_floating_point() else tensor.detach().clone() - for tensor in current_packed_state_tensors - ) - detached_packed_state = _unflatten_backend_packed_state( - current_packed_state_keys, detached_packed_state_tensors - ) - detached_hidden = hidden_before.detach().requires_grad_(True) - detached_recurrent_k = ( - None if recurrent_k_before is None else recurrent_k_before.detach().requires_grad_(True) - ) - detached_recurrent_v = ( - None if recurrent_v_before is None else recurrent_v_before.detach().requires_grad_(True) - ) - population_reset_step = population_resets[:, step_index] if population_resets is not None else None - - with torch.enable_grad(): - step_static_tensors = self._materialize_inference_static_tensors( - device=boundary_seq.device, - dtype=boundary_seq.dtype, - include_backend_cell_tensors=False, - ) - backend_population_name = self._select_output_cells_stream_backend_population( - k=1, - ) - if backend_population_name is None: - raise RuntimeError( - f"Supported Fabric {planned_backend_execution.surface_key} backward surface " - "requires the backend-owned CUDA sequence surface" - ) - ( - replay_output_step, - replay_next_packed_state, - replay_recurrent_hidden, - replay_recurrent_k, - replay_recurrent_v, - replay_input_k, - replay_input_v, - ) = self._run_backend_sequence_surface_step_primitives( - population_name=backend_population_name, - boundary_step=detached_boundary, - packed_state=detached_packed_state, - initial_hidden=detached_hidden, - initial_recurrent_k=detached_recurrent_k, - initial_recurrent_v=detached_recurrent_v, - population_resets=population_reset_step, - input_sender_input_to_kv_weight=cast( - torch.Tensor | None, step_static_tensors["input_sender_input_to_kv_weight"] - ), - input_group_input_to_kv_weight=cast( - torch.Tensor | None, step_static_tensors["input_group_input_to_kv_weight"] - ), - static_tensors=step_static_tensors, - ) - - output_specs: list[tuple[torch.Tensor, torch.Tensor]] = [] - if grad_output_seq is not None: - if output_boundary != "terminal" or step_index == int(boundary_seq.shape[1]) - 1: - grad_step_index = -1 if output_boundary == "terminal" else step_index - output_specs.append((replay_output_step, grad_output_seq[:, grad_step_index])) - packed_state_grad_map = _packed_state_grad_map(current_packed_state_keys, running_grad_packed_state) - if current_packed_state_keys is None: - packed_state_grad = packed_state_grad_map["packed_state_0"] - if ( - packed_state_grad is not None - and torch.is_tensor(replay_next_packed_state) - and replay_next_packed_state.requires_grad - ): - output_specs.append((replay_next_packed_state, packed_state_grad)) - else: - assert isinstance(replay_next_packed_state, (dict, TensorDictBase)) - for index, key in enumerate(current_packed_state_keys): - packed_state_grad = packed_state_grad_map[f"packed_state_{index}"] - packed_state_output = replay_next_packed_state[key] - if packed_state_grad is not None and packed_state_output.requires_grad: - output_specs.append((packed_state_output, packed_state_grad)) - if running_grad_recurrent_hidden is not None and replay_recurrent_hidden.requires_grad: - output_specs.append((replay_recurrent_hidden, running_grad_recurrent_hidden)) - if running_grad_recurrent_k is not None and 
replay_recurrent_k.requires_grad: - output_specs.append((replay_recurrent_k, running_grad_recurrent_k)) - if running_grad_recurrent_v is not None and replay_recurrent_v.requires_grad: - output_specs.append((replay_recurrent_v, running_grad_recurrent_v)) - if ( - step_index == int(boundary_seq.shape[1]) - 1 - and grad_input_k_last is not None - and replay_input_k.requires_grad - ): - output_specs.append((replay_input_k, grad_input_k_last)) - if ( - step_index == int(boundary_seq.shape[1]) - 1 - and grad_input_v_last is not None - and replay_input_v.requires_grad - ): - output_specs.append((replay_input_v, grad_input_v_last)) - - input_specs: list[tuple[str, torch.Tensor]] = [] - if detached_boundary.requires_grad: - input_specs.append(("boundary_step", detached_boundary)) - if current_packed_state_keys is None: - packed_state_input = cast(torch.Tensor, detached_packed_state) - if packed_state_input.requires_grad: - input_specs.append(("packed_state_0", packed_state_input)) - else: - assert isinstance(detached_packed_state, (dict, TensorDictBase)) - for index, key in enumerate(current_packed_state_keys): - packed_state_input = detached_packed_state[key] - if packed_state_input.requires_grad: - input_specs.append((f"packed_state_{index}", packed_state_input)) - if detached_hidden.requires_grad: - input_specs.append(("initial_hidden", detached_hidden)) - if detached_recurrent_k is not None and detached_recurrent_k.requires_grad: - input_specs.append(("initial_recurrent_k", detached_recurrent_k)) - if detached_recurrent_v is not None and detached_recurrent_v.requires_grad: - input_specs.append(("initial_recurrent_v", detached_recurrent_v)) - for parameter_index, parameter in enumerate(trainable_params): - if parameter.requires_grad: - input_specs.append((f"param_{parameter_index}", parameter)) - - with torch.profiler.record_function("fabric.backward.reference_replay_autograd"): - grad_results = ( - torch.autograd.grad( - outputs=tuple(tensor for tensor, _grad in output_specs), - inputs=tuple(tensor for _name, tensor in input_specs), - grad_outputs=tuple(grad for _tensor, grad in output_specs), - allow_unused=True, - ) - if output_specs and input_specs - else () - ) - grad_result_map = {name: grad for (name, _tensor), grad in zip(input_specs, grad_results, strict=True)} - grad_boundary_steps[step_index] = cast(torch.Tensor | None, grad_result_map.get("boundary_step")) - packed_state_input_grad_map = { - key: cast(torch.Tensor | None, grad_result_map.get(key)) for key in packed_state_grad_map - } - running_grad_packed_state = _packed_state_from_grad_map( - current_packed_state_keys, packed_state_input_grad_map - ) - hidden_input_grad = cast(torch.Tensor | None, grad_result_map.get("initial_hidden")) - if step_index == 0: - grad_initial_hidden_total = _accumulate_grad(grad_initial_hidden_total, hidden_input_grad) - grad_initial_recurrent_k_total = _accumulate_grad( - grad_initial_recurrent_k_total, - cast(torch.Tensor | None, grad_result_map.get("initial_recurrent_k")), - ) - grad_initial_recurrent_v_total = _accumulate_grad( - grad_initial_recurrent_v_total, - cast(torch.Tensor | None, grad_result_map.get("initial_recurrent_v")), - ) - else: - running_grad_recurrent_hidden = hidden_input_grad - running_grad_recurrent_k = cast(torch.Tensor | None, grad_result_map.get("initial_recurrent_k")) - running_grad_recurrent_v = cast(torch.Tensor | None, grad_result_map.get("initial_recurrent_v")) - for parameter_index in range(len(trainable_params)): - grad_param_accum[parameter_index] = _accumulate_grad( 
- grad_param_accum[parameter_index], - cast(torch.Tensor | None, grad_result_map.get(f"param_{parameter_index}")), - ) - - grad_sequence_inputs: dict[str, torch.Tensor | None] = { - "boundary_seq": torch.stack( - [ - cast(torch.Tensor, grad_step) - if grad_step is not None - else torch.zeros_like(boundary_seq[:, step_index]) - for step_index, grad_step in enumerate(grad_boundary_steps) - ], - dim=1, - ) - if any(grad_step is not None for grad_step in grad_boundary_steps) - else None, - "initial_hidden": grad_initial_hidden_total, - "population_resets": None, - "initial_recurrent_k": grad_initial_recurrent_k_total, - "initial_recurrent_v": grad_initial_recurrent_v_total, - } - if packed_state_keys is None: - grad_sequence_inputs["packed_state_0"] = cast(torch.Tensor | None, running_grad_packed_state) - else: - assert isinstance(running_grad_packed_state, (dict, TensorDictBase)) - for index, key in enumerate(packed_state_keys): - grad_sequence_inputs[f"packed_state_{index}"] = cast( - torch.Tensor | None, running_grad_packed_state.get(key) - ) - return grad_sequence_inputs, tuple(grad_param_accum) diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/runtime/__init__.py b/src/cortical/fabric/backend/cuda/sequence_surface/runtime/__init__.py new file mode 100644 index 00000000..c852d376 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/runtime/__init__.py @@ -0,0 +1 @@ +"""Runtime entry points for CUDA sequence-surface execution.""" diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py b/src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py new file mode 100644 index 00000000..8c372ce4 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py @@ -0,0 +1,1028 @@ +from __future__ import annotations + +from typing import Any, Literal, cast + +import torch +from tensordict import TensorDict + +from cortical.fabric.backend.cuda.sequence_surface.runtime.support import ( + _transition_supports_receiver_local_dependency_window, +) +from cortical.fabric.backend.cuda.sequence_surface.runtime.memory_stages import ( + record_frontend_tensor_bytes, + record_registered_memory_stage, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.forward_scan import ( + run_shared_temporal_bucket_forward_scan as _run_shared_temporal_bucket_forward_scan, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.physical_autograd import ( + run_temporal_bucket_sequence_physical_autograd as _run_temporal_bucket_sequence_physical_autograd, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.types import ( + TemporalOutputContract as _TemporalOutputContract, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.buckets import ( + PHYSICAL_TEMPORAL_BACKWARD_EXECUTOR as _PHYSICAL_TEMPORAL_BACKWARD_EXECUTOR, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.buckets import ( + PHYSICAL_TRANSITION_BACKWARD_EXECUTOR as _PHYSICAL_TRANSITION_BACKWARD_EXECUTOR, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.buckets import ( + temporal_backward_owner_plan as _temporal_backward_owner_plan, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.buckets import ( + with_cached_population_static_tensors as _with_cached_population_static_tensors, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.tables import ( + build_temporal_primitive_table_plan as _build_temporal_primitive_table_plan, +) +from 
cortical.fabric.backend.cuda.sequence_surface.compiler.tables import ( + temporal_tensor_binding_summaries as _temporal_tensor_binding_summaries, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.tables import ( + temporal_reverse_executor_summaries as _temporal_reverse_executor_summaries, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.primitive_dispatch import ( + build_temporal_primitive_executor_plan as _build_temporal_primitive_executor_plan, +) +from cortical.fabric.backend.graph_regions import ClosedRecurrentRegion +from cortical.fabric.backend.surfaces import BackendExecutionRecord +from cortical.fabric.backend.temporal_plan import temporal_execution_record_metadata + + +def execute_temporal_bucket_sequence( + runtime: Any, + *, + hidden_seq: torch.Tensor | None, + boundary_seq: torch.Tensor | None, + state: TensorDict, + population_resets: torch.Tensor | None, + step_reset_flags: list[bool] | None, + k: int | torch.Tensor | None, + constant_k: int | None, + batch_size: int, + time_steps: int, + step_mode: bool, + capture_active: bool, + static_tensors: dict[str, object], + grad_path: bool, + materialize_final_state: bool, + backend_population_state_is_fresh: bool, + use_fresh_backend_population_cache: bool, + tape_policy: Any | None = None, + output_contract: _TemporalOutputContract = "full_cells", + output_boundary: Literal["sequence", "terminal"] = "sequence", +) -> tuple[torch.Tensor, TensorDict]: + stage_reference = boundary_seq if boundary_seq is not None else cast(torch.Tensor, hidden_seq) + record_registered_memory_stage(runtime, stage_reference, "frontend_execute_entry") + record_frontend_tensor_bytes( + runtime, + stage="frontend_execute_entry", + tensors={ + "hidden_seq": hidden_seq, + "boundary_seq": boundary_seq, + "state": state, + "population_resets": population_resets, + }, + ) + runtime._last_flat_bucket_transition_backward_executor = None + runtime._last_flat_bucket_temporal_recurrent_kv_carry_reuse = False + runtime._last_flat_bucket_temporal_artifact_mode = None + runtime._last_flat_bucket_temporal_artifact_reason = None + runtime._last_flat_bucket_temporal_artifact_checkpoint_stride = None + runtime._last_flat_bucket_temporal_artifact_recompute_window_len = None + runtime._last_flat_bucket_temporal_artifact_checkpoint_count = None + runtime._last_flat_bucket_forward_transition_executor = None + runtime._last_flat_bucket_state_cache_mode = None + runtime._last_flat_bucket_state_cache_materialized_steps = None + runtime._last_flat_bucket_state_cache_elided_steps = None + runtime._last_flat_bucket_recurrent_graph_layout_backend = None + runtime._last_flat_bucket_graph_order_layout_backend = None + runtime._last_flat_bucket_public_carry_order = None + runtime._last_flat_bucket_public_projection_backend = None + runtime._last_flat_bucket_readout_backend = None + runtime._last_flat_bucket_temporal_scan_owner = None + runtime._last_flat_bucket_scan_implementation = None + runtime._last_flat_bucket_temporal_table_review = () + runtime._last_flat_bucket_temporal_primitive_names = () + runtime._last_flat_bucket_temporal_primitive_families = () + runtime._last_flat_bucket_temporal_primitive_row_count = None + runtime._last_flat_bucket_temporal_tensor_binding_row_count = None + runtime._last_flat_bucket_temporal_tensor_binding_summaries = () + runtime._last_flat_bucket_temporal_scan_binding_projection = () + runtime._last_flat_bucket_temporal_reverse_executor_summaries = () + runtime._last_flat_bucket_temporal_primitive_executor_contracts = () 
+ runtime._last_flat_bucket_temporal_primitive_executor_blockers = () + runtime._last_flat_bucket_temporal_scan_binding_abi = None + runtime._last_flat_bucket_temporal_scan_primitive_row_source = None + runtime._last_flat_bucket_temporal_backward_binding_abi = None + runtime._last_flat_bucket_temporal_transition_backward_binding_abi = None + runtime._last_flat_bucket_temporal_scheduler_plan = () + device = boundary_seq.device if boundary_seq is not None else cast(torch.Tensor, hidden_seq).device + static_tensors = _with_cached_population_static_tensors(runtime, static_tensors) + record_registered_memory_stage(runtime, stage_reference, "frontend_execute_after_static_cache") + record_frontend_tensor_bytes( + runtime, + stage="frontend_execute_after_static_cache", + tensors={ + "static_tensors": static_tensors, + "state": state, + "boundary_seq": boundary_seq, + }, + ) + + running_state = state + if grad_path and hidden_seq is None and boundary_seq is not None and constant_k is not None: + inner_steps = int(constant_k) + if inner_steps <= 0: + raise RuntimeError("Temporal bucket sequence physical training requires positive constant K") + planned_backward_execution = runtime.plan_backend_backward_execution( + batch_size=batch_size, + time_steps=time_steps, + inner_steps=inner_steps, + training=True, + tape_policy=tape_policy, + device=device, + temporal_plan=getattr(runtime, "_last_temporal_execution_plan", None), + ) + output_seq, next_state = _run_temporal_bucket_sequence_physical_autograd( + runtime, + boundary_seq=boundary_seq, + state=running_state, + population_resets=population_resets, + static_tensors=static_tensors, + planned_backward_execution=planned_backward_execution, + materialize_final_state=materialize_final_state, + output_contract=output_contract, + output_boundary=output_boundary, + inner_steps=inner_steps, + ) + runtime._last_flat_bucket_transition_backward_executor = _PHYSICAL_TEMPORAL_BACKWARD_EXECUTOR + record_temporal_bucket_step_loop_execution( + runtime, + batch_size=batch_size, + time_steps=time_steps, + inner_steps=inner_steps, + training=grad_path, + materialize_final_state=materialize_final_state, + output_boundary=output_boundary, + output_contract=output_contract, + static_tensors=static_tensors, + ) + return (output_seq.squeeze(1) if step_mode else output_seq), next_state + if ( + not grad_path + and hidden_seq is None + and boundary_seq is not None + and constant_k is not None + and int(constant_k) >= 1 + ): + scan_result = _run_shared_temporal_bucket_forward_scan( + runtime, + boundary_seq=boundary_seq, + state=running_state, + population_resets=population_resets, + static_tensors=static_tensors, + inner_steps=int(constant_k), + materialize_final_state=materialize_final_state, + output_contract=output_contract, + output_boundary=output_boundary, + temporal_plan=getattr(runtime, "_last_temporal_execution_plan", None), + collect_artifacts=False, + backend_population_state_is_fresh=backend_population_state_is_fresh, + ) + record_temporal_bucket_step_loop_execution( + runtime, + batch_size=batch_size, + time_steps=time_steps, + inner_steps=int(constant_k), + training=False, + materialize_final_state=materialize_final_state, + output_boundary=output_boundary, + output_contract=output_contract, + static_tensors=static_tensors, + ) + output_seq = scan_result.output_seq + return (output_seq.squeeze(1) if step_mode else output_seq), scan_result.final_state + if hidden_seq is None and boundary_seq is not None and constant_k is None: + raise RuntimeError( + "Shared 
temporal engine requires planner-lowered per-timestep K before CUDA execution. " + "Variable K must lower into compiler-owned temporal schedule rows before CUDA execution." + ) + raise RuntimeError( + "Shared temporal engine requires planner-lowered boundary sequences before CUDA execution. " + "Direct hidden-input temporal execution must lower the public input projection into boundary_seq " + "before entering the CUDA temporal scheduler." + ) + + +def _apply_temporal_output_contract( + runtime: Any, + y_step: torch.Tensor, + output_contract: _TemporalOutputContract, +) -> torch.Tensor: + if output_contract == "full_cells": + return y_step + if output_contract == "output_cells": + return runtime._select_output_cells(y_step.unsqueeze(1)).squeeze(1) + if output_contract == "pooled_output_cells": + output_cells = runtime._select_output_cells(y_step.unsqueeze(1)) + return runtime._pool_output_ports(output_cells).squeeze(1) + raise RuntimeError(f"Unsupported temporal output contract {output_contract!r}") + + +def supports_temporal_bucket_active_output_window(runtime: Any, *, time_steps: int) -> bool: + if time_steps < 1: + return False + active_populations = tuple( + name for name in runtime._population_names if int(runtime._population_recurrent_indices(name).numel()) > 0 + ) + if len(active_populations) <= 1: + return False + if not bool(runtime._partitioned_layout) or not bool(runtime._local_message_step_enabled): + return False + if runtime._has_edge_delay or bool(getattr(runtime, "_uses_sparse_message_backend", False)): + return False + active_recurrent_idx = getattr(runtime, "flat_bucket_active_output_recurrent_idx", None) + if not torch.is_tensor(active_recurrent_idx) or int(active_recurrent_idx.numel()) == 0: + return False + if not bool(getattr(runtime, "_flat_bucket_active_output_region_compact_contiguous", False)): + return False + if not bool(getattr(runtime, "_flat_bucket_active_output_region_is_full", False)): + return False + for population_name in runtime._population_names: + if int(runtime._population_recurrent_indices(population_name).numel()) == 0: + continue + population_spec = runtime._backend_population_specs.get(population_name) + if population_spec is None or not _transition_supports_receiver_local_dependency_window( + population_spec.transition_ir + ): + return False + return True + + +def _flat_bucket_active_output_region_mode(runtime: Any) -> str: + return str( + getattr( + runtime, + "_flat_bucket_active_output_region_mode", + getattr(runtime, "_flat_bucket_output_recurrent_closure_mode", "unknown"), + ) + ) + + +def _flat_bucket_active_output_region(runtime: Any) -> ClosedRecurrentRegion: + indices = getattr( + runtime, + "_flat_bucket_active_output_region_indices", + getattr(runtime, "_flat_bucket_output_recurrent_closure_indices", ()), + ) + return ClosedRecurrentRegion( + indices=tuple(indices), + full_count=int(runtime._num_recurrent_cells), + ) + + +def _flat_bucket_output_recurrent_closure(runtime: Any) -> ClosedRecurrentRegion: + return ClosedRecurrentRegion( + indices=tuple(runtime._flat_bucket_output_recurrent_closure_indices), + full_count=int(runtime._num_recurrent_cells), + ) + + +def _runtime_sparse_message_bucket_kind(runtime: Any) -> str: + return ( + "ragged_grouped_sparse" + if int(getattr(runtime, "_recurrent_sparse_positive_degree_buckets", 0)) > 1 + else "degree_bucketed_sparse" + ) + + +def _runtime_sparse_degree_summary(runtime: Any) -> str: + degree_ptr = getattr(runtime, "recurrent_sparse_degree_ptr", None) + if torch.is_tensor(degree_ptr) and 
int(degree_ptr.numel()) > 1: + degrees: list[str] = [] + degree_ptr_cpu = degree_ptr.detach().cpu() + for degree in range(int(degree_ptr_cpu.numel()) - 1): + count = int(degree_ptr_cpu[degree + 1].item()) - int(degree_ptr_cpu[degree].item()) + if count > 0: + degrees.append(f"{degree}:{count}") + if degrees: + return "degrees=" + ",".join(degrees) + neighbor_valid = getattr(runtime, "recurrent_neighbor_valid", None) + if torch.is_tensor(neighbor_valid) and int(neighbor_valid.numel()) > 0: + degree_counts = torch.bincount(neighbor_valid.to(dtype=torch.long).sum(dim=1).detach().cpu()) + degrees = [ + f"{degree}:{int(count.item())}" for degree, count in enumerate(degree_counts) if int(count.item()) > 0 + ] + if degrees: + return "degrees=" + ",".join(degrees) + return "degrees=unknown" + + +def _runtime_message_record_metadata(runtime: Any) -> dict[str, tuple[str, ...]]: + uses_sparse = bool(getattr(runtime, "_uses_sparse_message_backend", False)) + reset_policy = "zero_source_rows" + reset_scope = "batch_row" + use_delay = bool(getattr(runtime, "_has_edge_delay", False)) + use_delay_str = "true" if use_delay else "false" + if not uses_sparse and bool(getattr(runtime, "_local_message_step_enabled", False)): + degree = int(getattr(runtime, "recurrent_local_sender_idx", torch.empty(0, 0)).shape[1]) + bucket_kind = "regular_local_receiver_owned" + return { + "message_projection_bucket_kinds": ("regular_local_projected_message_boundary",), + "message_bucket_count": ("1",), + "message_regular_local_bucket_count": ("1",), + "message_sparse_bucket_count": ("0",), + "message_batched_backend_count": ("1",), + "message_grouped_backend_count": ("0",), + "message_reset_aware_bucket_count": ("1",), + "message_degree_uniform_bucket_count": ("1",), + "message_ragged_grouped_bucket_count": ("0",), + "message_demoted_bucket_count": ("0",), + "message_bucket_signatures": ( + f"bucket_kind={bucket_kind}|topology_kind=regular_local|degree_or_block={degree}|" + f"K={int(runtime.head_dim)}|V={int(runtime.value_dim)}|reset_policy={reset_policy}|" + f"use_delay={use_delay_str}", + ), + "message_bucket_kinds": (bucket_kind,), + "message_topology_kinds": ("regular_local",), + "message_spatial_ownership": ("receiver_owned",), + "message_degree_bucket_lists": (f"degree={degree}",), + "message_logit_backends": ("direct_fixed_degree",), + "message_softmax_backends": ("custom_fixed_degree_softmax",), + "message_weighted_value_backends": ("direct_fixed_degree",), + "message_physical_mode": ("regular_local_direct_projected",), + "message_execution_mode": ("direct_fixed_degree",), + "message_output_boundary": ("projected_message",), + "message_degree": (str(degree),), + "message_k": (str(int(runtime.head_dim)),), + "message_v": (str(int(runtime.value_dim)),), + "message_projected_n": (str(int(runtime.d_msg)),), + "message_reset_policies": (reset_policy,), + "message_reset_scopes": (reset_scope,), + "message_use_delay": (use_delay_str,), + "message_distance_penalty_kinds": ("offset_distance",), + "message_epilogue_kinds": ("softmax_weighted_sum",), + "message_packed_source_reuse_count": ("1",), + "message_demotions": ("none",), + "message_workspace_mode": ("fixed_degree_direct",), + } + bucket_kind = _runtime_sparse_message_bucket_kind(runtime) + degree = int(getattr(runtime, "recurrent_neighbor_idx", torch.empty(0, 0)).shape[1]) + topology_kind = ( + "edge_owned_sparse" if int(getattr(runtime.config, "patch_edges_per_cell", 0)) > 0 else "receiver_owned_sparse" + ) + spatial_ownership = "edge_owned" if topology_kind == 
"edge_owned_sparse" else "receiver_owned" + degree_uniform = bucket_kind == "degree_bucketed_sparse" + execution_mode = "degree_bucketed_batched" if degree_uniform else "ragged_grouped" + gemm_backend = "batched_gemm" if degree_uniform else "grouped_gemm" + return { + "message_projection_bucket_kinds": ("sparse_projected_message_boundary",), + "message_bucket_count": ("1",), + "message_regular_local_bucket_count": ("0",), + "message_sparse_bucket_count": ("1",), + "message_batched_backend_count": ("1" if degree_uniform else "0",), + "message_grouped_backend_count": ("0" if degree_uniform else "1",), + "message_reset_aware_bucket_count": ("1",), + "message_degree_uniform_bucket_count": ("1" if degree_uniform else "0",), + "message_ragged_grouped_bucket_count": ("0" if degree_uniform else "1",), + "message_demoted_bucket_count": ("0",), + "message_bucket_signatures": ( + f"bucket_kind={bucket_kind}|topology_kind={topology_kind}|degree_or_block={degree}|" + f"K={int(runtime.head_dim)}|V={int(runtime.value_dim)}|reset_policy={reset_policy}|" + f"use_delay={use_delay_str}|distance_penalty_kind=edge_distance", + ), + "message_bucket_kinds": (bucket_kind,), + "message_topology_kinds": (topology_kind,), + "message_spatial_ownership": (spatial_ownership,), + "message_degree_bucket_lists": (_runtime_sparse_degree_summary(runtime),), + "message_logit_backends": (gemm_backend,), + "message_softmax_backends": ("custom_segment_softmax",), + "message_weighted_value_backends": (gemm_backend,), + "message_physical_mode": ( + "sparse_degree_bucketed_projected" if degree_uniform else "sparse_ragged_grouped_projected", + ), + "message_execution_mode": (execution_mode,), + "message_output_boundary": ("projected_message",), + "message_degree": (str(degree),), + "message_k": (str(int(runtime.head_dim)),), + "message_v": (str(int(runtime.value_dim)),), + "message_projected_n": (str(int(runtime.d_msg)),), + "message_reset_policies": (reset_policy,), + "message_reset_scopes": (reset_scope,), + "message_use_delay": (use_delay_str,), + "message_distance_penalty_kinds": ("edge_distance",), + "message_epilogue_kinds": ("segment_softmax_weighted_sum",), + "message_packed_source_reuse_count": ("1",), + "message_demotions": ("none",), + "message_workspace_mode": ("degree_bucketed_sparse" if degree_uniform else "degree_grouped_sparse_ragged",), + } + + +def _runtime_message_backward_metadata(runtime: Any) -> tuple[str, str, str, str]: + if bool(getattr(runtime, "_uses_sparse_message_backend", False)): + return ( + "sparse_message_superop_backward", + "physical_sparse_message_backward_executor", + "sparse_message_superop_backward:partitioned_cuda", + "sparse_message_superop_backward:active_sparse_cuda_owner", + ) + return ( + "tiny_message_superop_backward", + "physical_tiny_message_backward_executor", + "tiny_message_superop_backward:fused_receiver_sender_cuda", + "tiny_message_superop_backward:active_fused_receiver_sender_cuda_owner", + ) + + +def _runtime_transition_backward_record_metadata( + runtime: Any, + static_tensors: dict[str, object], +) -> dict[str, tuple[str, ...]]: + kinds: list[str] = [] + executors: list[str] = [] + boundaries: list[str] = [] + launch_counts: list[str] = [] + saved_launch_counts: list[str] = [] + + def add( + *, + kind: str, + executor: str, + boundary: str, + launch_count: str, + saved_launch_count: str | None = None, + ) -> None: + if kind not in kinds: + kinds.append(kind) + if executor not in executors: + executors.append(executor) + if boundary not in boundaries: + 
boundaries.append(boundary) + if launch_count and launch_count not in launch_counts: + launch_counts.append(launch_count) + if saved_launch_count and saved_launch_count not in saved_launch_counts: + saved_launch_counts.append(saved_launch_count) + + temporal_table = _build_temporal_primitive_table_plan(runtime, static_tensors) + primitive_names = frozenset(temporal_table.primitive_names) + if "matmul" in primitive_names: + add( + kind="receiver_affine_superop_backward", + executor="physical_receiver_affine_backward_executor", + boundary="state_affine_output", + launch_count="receiver_affine_superop_backward:physical_cuda_tiled", + saved_launch_count="receiver_affine_superop_backward:active_cuda_owner", + ) + if "diag_rtu" in primitive_names: + add( + kind="diagonal_recurrence_superop_backward", + executor="physical_diagonal_recurrence_backward_executor", + boundary="raw_public", + launch_count="diagonal_recurrence_superop_backward:triton_cuda", + saved_launch_count="diagonal_recurrence_superop_backward:active_cuda_owner", + ) + if "gated_logspace_recurrence" in primitive_names or "norm_or_identity" in primitive_names: + add( + kind="lowered_state_epilogue_backward", + executor="physical_state_epilogue_backward_executor", + boundary="raw_public", + launch_count="state_epilogue_backward:gated_logspace_cuda_tiled", + saved_launch_count="state_epilogue_backward:active_cuda_owner", + ) + + add( + kind="lowered_public_projection_backward", + executor="registered_sender_kv_projection_backward_executor", + boundary="public_projection", + launch_count="sender_kv_projection_backward:registered_cuda", + ) + add( + kind="lowered_readout_projection_backward", + executor="projection_reduction_boundary_backward", + boundary="readout_boundary", + launch_count="readout_projection_backward:registered_reverse_executor", + ) + return { + "kinds": tuple(kinds), + "executors": tuple(executors), + "boundaries": tuple(boundaries), + "launch_counts": tuple(launch_counts), + "saved_launch_counts": tuple(saved_launch_counts), + "residual_demotions": (), + } + + +def _runtime_state_affine_record_metadata( + runtime: Any, + static_tensors: dict[str, object], +) -> dict[str, tuple[str, ...]]: + temporal_table = _build_temporal_primitive_table_plan(runtime, static_tensors) + has_recurrent_affine = "matmul" in temporal_table.primitive_names + if not has_recurrent_affine: + return { + "backends": (), + "sources": (), + "reset_policies": (), + "reset_mode": (), + "reset_scope": (), + } + return { + "backends": ("receiver_affine_superop",), + "sources": ("projected_message", "previous_state"), + "reset_policies": ("none", "zero_source_rows"), + "reset_mode": ("row_mask_pack",), + "reset_scope": ("batch_row",), + } + + +def record_temporal_bucket_step_loop_execution( + runtime: Any, + *, + batch_size: int, + time_steps: int, + inner_steps: int, + training: bool, + materialize_final_state: bool = True, + output_boundary: Literal["sequence", "terminal"] = "sequence", + output_contract: _TemporalOutputContract = "full_cells", + static_tensors: dict[str, object], +) -> None: + if runtime._last_backend_execution is not None: + return + active_populations = tuple( + name for name in runtime._population_names if runtime._population_recurrent_indices(name).numel() > 0 + ) + recurrent_closure = _flat_bucket_output_recurrent_closure(runtime) + output_only_streaming = bool( + not training + and not materialize_final_state + and output_contract in {"output_cells", "pooled_output_cells"} + and 
supports_temporal_bucket_active_output_window(runtime, time_steps=time_steps) + ) + active_region = _flat_bucket_active_output_region(runtime) if output_only_streaming else recurrent_closure + active_region_mode = ( + _flat_bucket_active_output_region_mode(runtime) if output_only_streaming else recurrent_closure.mode + ) + active_region_demotions = ( + ("active_region_closure_full_surface",) + if time_steps > 1 and recurrent_closure.is_full + else ("active_region_closure_ragged",) + if time_steps > 1 and not recurrent_closure.compact_contiguous + else () + ) + transition_backward_executor = getattr(runtime, "_last_flat_bucket_transition_backward_executor", None) + physical_transition_backward = transition_backward_executor == _PHYSICAL_TRANSITION_BACKWARD_EXECUTOR + physical_temporal_backward = transition_backward_executor == _PHYSICAL_TEMPORAL_BACKWARD_EXECUTOR + transition_tape_mode = getattr(runtime, "_last_flat_bucket_transition_tape_mode", None) + transition_tape_reason = getattr(runtime, "_last_flat_bucket_transition_tape_reason", None) + temporal_artifact_mode = getattr(runtime, "_last_flat_bucket_temporal_artifact_mode", None) + temporal_artifact_reason = getattr(runtime, "_last_flat_bucket_temporal_artifact_reason", None) + temporal_artifact_checkpoint_stride = getattr( + runtime, + "_last_flat_bucket_temporal_artifact_checkpoint_stride", + None, + ) + temporal_artifact_recompute_window_len = getattr( + runtime, + "_last_flat_bucket_temporal_artifact_recompute_window_len", + None, + ) + temporal_artifact_checkpoint_count = getattr( + runtime, + "_last_flat_bucket_temporal_artifact_checkpoint_count", + None, + ) + state_cache_mode = getattr(runtime, "_last_flat_bucket_state_cache_mode", None) + state_cache_materialized_steps = getattr(runtime, "_last_flat_bucket_state_cache_materialized_steps", None) + state_cache_elided_steps = getattr(runtime, "_last_flat_bucket_state_cache_elided_steps", None) + recurrent_graph_layout_backend = getattr(runtime, "_last_flat_bucket_recurrent_graph_layout_backend", None) + graph_order_layout_backend = getattr(runtime, "_last_flat_bucket_graph_order_layout_backend", None) + public_carry_order = getattr(runtime, "_last_flat_bucket_public_carry_order", None) + public_projection_backend = getattr(runtime, "_last_flat_bucket_public_projection_backend", None) + readout_backend = getattr(runtime, "_last_flat_bucket_readout_backend", None) + temporal_table = _build_temporal_primitive_table_plan(runtime, static_tensors) + temporal_primitive_executor_plan = _build_temporal_primitive_executor_plan(temporal_table) + temporal_table_review = tuple( + getattr(runtime, "_last_flat_bucket_temporal_table_review", ()) or temporal_table.review_summary + ) + temporal_scheduler_plan = tuple(getattr(runtime, "_last_flat_bucket_temporal_scheduler_plan", ()) or ()) + temporal_primitive_names = tuple( + getattr(runtime, "_last_flat_bucket_temporal_primitive_names", ()) or temporal_table.primitive_names + ) + temporal_primitive_families = tuple( + getattr(runtime, "_last_flat_bucket_temporal_primitive_families", ()) or temporal_table.primitive_families + ) + temporal_primitive_row_count = ( + getattr(runtime, "_last_flat_bucket_temporal_primitive_row_count", None) + if getattr(runtime, "_last_flat_bucket_temporal_primitive_row_count", None) is not None + else len(temporal_table.primitive_rows) + ) + temporal_tensor_binding_row_count = ( + getattr(runtime, "_last_flat_bucket_temporal_tensor_binding_row_count", None) + if getattr(runtime, 
"_last_flat_bucket_temporal_tensor_binding_row_count", None) is not None + else len(temporal_table.tensor_bindings) + ) + temporal_tensor_binding_summaries = tuple( + getattr(runtime, "_last_flat_bucket_temporal_tensor_binding_summaries", ()) + or _temporal_tensor_binding_summaries(temporal_table) + ) + temporal_scan_binding_projection = tuple( + getattr(runtime, "_last_flat_bucket_temporal_scan_binding_projection", ()) or () + ) + temporal_reverse_executor_summaries = tuple( + getattr(runtime, "_last_flat_bucket_temporal_reverse_executor_summaries", ()) + or _temporal_reverse_executor_summaries(temporal_table) + ) + temporal_primitive_executor_contracts = tuple( + getattr(runtime, "_last_flat_bucket_temporal_primitive_executor_contracts", ()) + or temporal_primitive_executor_plan.summaries + ) + temporal_primitive_executor_blockers = tuple( + getattr(runtime, "_last_flat_bucket_temporal_primitive_executor_blockers", ()) + or temporal_primitive_executor_plan.blockers + ) + temporal_executor_kernel_registry = tuple( + getattr(runtime, "_last_flat_bucket_temporal_executor_kernel_registry", ()) or () + ) + fused_cuda_program_plan = tuple(getattr(runtime, "_last_flat_bucket_temporal_fused_cuda_program_plan", ()) or ()) + fused_cuda_launch_contract = tuple( + getattr(runtime, "_last_flat_bucket_temporal_fused_cuda_launch_contract", ()) or () + ) + fused_cuda_program_blocker = tuple( + getattr(runtime, "_last_flat_bucket_temporal_fused_cuda_program_blocker", ()) or () + ) + registered_program_executor_plan = tuple( + getattr(runtime, "_last_flat_bucket_temporal_registered_program_executor_plan", ()) or () + ) + memory_runtime_buffer_plan = tuple( + getattr(runtime, "_last_flat_bucket_temporal_memory_runtime_buffer_plan", ()) or () + ) + memory_runtime_artifact_plan = tuple( + getattr(runtime, "_last_flat_bucket_temporal_memory_runtime_artifact_plan", ()) or () + ) + physical_strategy_plan = tuple(getattr(runtime, "_last_flat_bucket_temporal_physical_strategy_plan", ()) or ()) + memory_runtime_schedule_rows = getattr( + runtime, + "_last_flat_bucket_temporal_memory_runtime_schedule_rows", + None, + ) + physical_strategy_rows = getattr( + runtime, + "_last_flat_bucket_temporal_physical_strategy_rows", + None, + ) + optional_transition_outputs = getattr( + runtime, + "_last_flat_bucket_temporal_optional_transition_outputs", + None, + ) + reverse_artifact_tensor_store = tuple( + getattr(runtime, "_last_flat_bucket_temporal_reverse_artifact_tensor_store", ()) or () + ) + registered_backward_memory_stages = tuple( + getattr(runtime, "_last_flat_bucket_temporal_registered_backward_memory_stages", ()) or () + ) + frontend_tensor_bytes = tuple(getattr(runtime, "_last_flat_bucket_temporal_frontend_tensor_bytes", ()) or ()) + runtime._last_flat_bucket_temporal_primitive_executor_contracts = temporal_primitive_executor_contracts + runtime._last_flat_bucket_temporal_primitive_executor_blockers = temporal_primitive_executor_blockers + temporal_scan_binding_abi = getattr(runtime, "_last_flat_bucket_temporal_scan_binding_abi", None) + temporal_scan_primitive_row_source = getattr( + runtime, + "_last_flat_bucket_temporal_scan_primitive_row_source", + None, + ) + temporal_backward_binding_abi = getattr(runtime, "_last_flat_bucket_temporal_backward_binding_abi", None) + forward_transition_executor = getattr(runtime, "_last_flat_bucket_forward_transition_executor", None) + message_backward_kind, message_backward_executor, message_backward_launch, message_backward_saved_launch = ( + 
_runtime_message_backward_metadata(runtime) + ) + transition_backward_metadata = _runtime_transition_backward_record_metadata(runtime, static_tensors) + static_saved_launch_counts: tuple[str, ...] = () + if training: + static_saved_items: list[str] = [] + static_tape_mode = getattr(runtime, "_last_training_static_tape_mode", None) + if static_tape_mode: + static_saved_items.append(f"training_static_tape:{static_tape_mode}") + if getattr(runtime, "_last_training_static_prepack_mode", None) == "views": + static_saved_items.append("training_static_prepack:receiver_major_views") + if getattr(runtime, "_last_backward_projection_mode", None) == "factorized_recurrent_input": + static_saved_items.append("training_static_projection:factorized_receiver_input") + static_saved_launch_counts = tuple(static_saved_items) + backward_owner_plan = _temporal_backward_owner_plan( + training=training, + transition_backward_executor=transition_backward_executor, + active_region_demotions=active_region_demotions, + transition_tape_mode=transition_tape_mode, + message_backward_kind=message_backward_kind, + message_backward_executor=message_backward_executor, + ) + temporal_tape_policy_bin = ( + f"physical_temporal_bucket_{transition_tape_mode}_transition_tape" + if transition_tape_mode + else "physical_temporal_bucket_saved_tape" + ) + scan_implementation = ( + "windowed_temporal_physical_scan" + if physical_temporal_backward and temporal_artifact_mode == "recompute_step_artifacts" + else "stored_temporal_physical_scan" + if physical_temporal_backward and temporal_artifact_mode == "store_step_artifacts" + else "flat_bucket_temporal_scan" + if physical_temporal_backward + else "flat_bucket_temporal_scan" + ) + scan_implementation = getattr(runtime, "_last_flat_bucket_scan_implementation", None) or scan_implementation + temporal_scan_owner = ( + getattr(runtime, "_last_flat_bucket_temporal_scan_owner", None) or "registered_fused_forward_program_cuda" + ) + message_metadata = _runtime_message_record_metadata(runtime) + state_affine_metadata = _runtime_state_affine_record_metadata(runtime, static_tensors) + workspace_aliases = ( + f"sequence_output_boundary:{'terminal_step' if output_boundary == 'terminal' else 'all_steps'}", + f"sequence_output_materialization:{'terminal_step_only' if output_boundary == 'terminal' else 'all_steps'}", + f"sequence_output_contract:{output_contract}", + "final_state=materialized" if materialize_final_state else "final_state=not_materialized", + *( + ("flat_bucket_temporal_scan:recurrent_kv_carry_reuse",) + if bool(getattr(runtime, "_last_flat_bucket_temporal_recurrent_kv_carry_reuse", False)) + else () + ), + *( + ( + f"flat_bucket_state_cache:{state_cache_mode}", + f"flat_bucket_state_cache_materialized_steps:{state_cache_materialized_steps}", + f"flat_bucket_state_cache_elided_steps:{state_cache_elided_steps}", + ) + if state_cache_mode is not None + else () + ), + *( + (f"flat_bucket_recurrent_graph_layout:{recurrent_graph_layout_backend}",) + if recurrent_graph_layout_backend is not None + else () + ), + *((f"flat_bucket_public_carry:{public_carry_order}",) if public_carry_order is not None else ()), + *( + (f"flat_bucket_public_projection:{public_projection_backend}",) + if public_projection_backend is not None + else () + ), + *((f"flat_bucket_readout:{readout_backend}",) if readout_backend is not None else ()), + *( + (f"flat_bucket_graph_order_layout:{graph_order_layout_backend}",) + if graph_order_layout_backend is not None + else () + ), + *( + ( + 
f"temporal_artifacts:{temporal_artifact_mode}", + f"temporal_artifact_checkpoint_stride:{temporal_artifact_checkpoint_stride}", + f"temporal_artifact_recompute_window_len:{temporal_artifact_recompute_window_len}", + f"temporal_artifact_checkpoint_count:{temporal_artifact_checkpoint_count}", + ) + if temporal_artifact_mode is not None + else () + ), + *(f"flat_bucket_temporal_table:{item}" for item in temporal_table_review), + *(f"flat_bucket_temporal_scheduler:{item}" for item in temporal_scheduler_plan), + *( + (f"flat_bucket_temporal_table_primitive_rows:{temporal_primitive_row_count}",) + if temporal_primitive_row_count is not None + else () + ), + *( + (f"flat_bucket_temporal_table_tensor_binding_rows:{temporal_tensor_binding_row_count}",) + if temporal_tensor_binding_row_count is not None + else () + ), + *(f"flat_bucket_temporal_tensor_binding:{item}" for item in temporal_tensor_binding_summaries), + *(f"flat_bucket_temporal_scan_binding_projection:{item}" for item in temporal_scan_binding_projection), + *(f"flat_bucket_temporal_reverse_executor:{item}" for item in temporal_reverse_executor_summaries), + *( + ("flat_bucket_temporal_table_primitives:" + ",".join(temporal_primitive_names),) + if temporal_primitive_names + else () + ), + *( + ("flat_bucket_temporal_table_primitive_families:" + ",".join(temporal_primitive_families),) + if temporal_primitive_families + else () + ), + *(f"flat_bucket_temporal_primitive_executor:{item}" for item in temporal_primitive_executor_contracts), + *(f"flat_bucket_temporal_primitive_executor_blocker:{item}" for item in temporal_primitive_executor_blockers), + *(f"flat_bucket_temporal_executor_kernel_registry:{item}" for item in temporal_executor_kernel_registry), + *(f"flat_bucket_temporal_fused_cuda_program:{item}" for item in fused_cuda_program_plan), + *(f"flat_bucket_temporal_fused_cuda_launch_contract:{item}" for item in fused_cuda_launch_contract), + *( + ("flat_bucket_temporal_fused_cuda_program_blocker:" + "|".join(fused_cuda_program_blocker),) + if fused_cuda_program_blocker + else () + ), + *(f"flat_bucket_temporal_registered_program_executor:{item}" for item in registered_program_executor_plan), + *(f"flat_bucket_temporal_memory_runtime_buffer:{item}" for item in memory_runtime_buffer_plan), + *( + (f"flat_bucket_temporal_transition_output_policy:{optional_transition_outputs}",) + if optional_transition_outputs is not None + else () + ), + *(f"flat_bucket_temporal_memory_runtime_artifact:{item}" for item in memory_runtime_artifact_plan), + *(f"flat_bucket_temporal_physical_strategy:{item}" for item in physical_strategy_plan), + *(f"flat_bucket_temporal_reverse_artifact_tensor_store:{item}" for item in reverse_artifact_tensor_store), + *( + f"flat_bucket_temporal_registered_backward_memory_stage:{item}" + for item in registered_backward_memory_stages + ), + *(f"flat_bucket_temporal_frontend_tensor_bytes:{item}" for item in frontend_tensor_bytes), + *( + ( + "flat_bucket_temporal_memory_runtime_schedule_rows:" + + "x".join(str(int(dim)) for dim in memory_runtime_schedule_rows.shape), + ) + if hasattr(memory_runtime_schedule_rows, "shape") + else () + ), + *( + ( + "flat_bucket_temporal_physical_strategy_rows:" + + "x".join(str(int(dim)) for dim in physical_strategy_rows.shape), + ) + if hasattr(physical_strategy_rows, "shape") + else () + ), + *( + (f"flat_bucket_temporal_scan_binding_abi:{temporal_scan_binding_abi}",) + if temporal_scan_binding_abi is not None + else () + ), + *( + 
(f"flat_bucket_temporal_scan_primitive_rows:{temporal_scan_primitive_row_source}",) + if temporal_scan_primitive_row_source is not None + else () + ), + *( + (f"flat_bucket_temporal_backward_binding_abi:{temporal_backward_binding_abi}",) + if temporal_backward_binding_abi is not None + else () + ), + ) + runtime._last_backend_execution = BackendExecutionRecord( + backend_name="cuda", + surface_key="registered_temporal_sequence_surface", + cell_type="bucketed", + regime="stream", + training=training, + batch_size=batch_size, + time_steps=time_steps, + inner_steps=inner_steps, + bucket_ids=tuple(bucket.bucket_id for bucket in runtime.backend_ir.buckets), + execution_families=("message_program", "transition_program", "readout_program"), + math_backends=("cuda_tensor_ops",), + tape_policy_bin=temporal_tape_policy_bin + if training and physical_temporal_backward + else "hybrid_physical_transition" + if training and physical_transition_backward + else "autograd" + if training + else "none", + graph_capture_enabled=False, + capability_variants=("registered_temporal_sequence_surface", "flat_bucket_temporal_scan", scan_implementation), + launch_temporal_executions=("temporal_bucket_sequence",), + launch_scan_implementations=(scan_implementation,), + launch_temporal_scan_owners=(temporal_scan_owner,), + launch_temporal_scan_outer_steps=(str(int(time_steps)),), + launch_temporal_scan_inner_steps=(str(int(inner_steps)),), + launch_temporal_scan_physical_steps=(str(int(time_steps) * int(inner_steps)),), + launch_temporal_scan_emission_counts=("1" if output_boundary == "terminal" else str(int(time_steps)),), + launch_temporal_scan_first_emission_steps=(str(max(0, int(inner_steps) - 1)),), + launch_temporal_scan_emission_strides=(str(max(1, int(inner_steps))),), + launch_temporal_scan_output_boundaries=(output_boundary,), + temporal_primitive_executor_contracts=temporal_primitive_executor_contracts, + temporal_primitive_executor_blockers=temporal_primitive_executor_blockers, + physical_op_kinds=( + "message", + "receiver_affine", + "state_epilogue", + "diagonal_recurrence", + "readout", + "glue/layout", + ), + physical_op_executors=( + "registered_temporal_sequence_surface", + "shared_graph_message", + f"transition_program={','.join(active_populations)}", + scan_implementation, + "readout_projection", + f"transition_tape={transition_tape_mode or 'unknown'}", + *( + (f"forward_transition={forward_transition_executor}",) + if forward_transition_executor is not None + else () + ), + ), + physical_boundary_contracts=( + "shared_public_message_substrate", + "population_local_state_banks", + "projected_message", + "fixed_active_spatial_region", + "readout_boundary", + ), + state_affine_backends=state_affine_metadata["backends"], + state_affine_sources=state_affine_metadata["sources"], + state_affine_reset_policies=state_affine_metadata["reset_policies"], + state_affine_reset_mode=state_affine_metadata["reset_mode"], + state_affine_reset_scope=state_affine_metadata["reset_scope"], + physical_op_demotions=active_region_demotions, + active_receiver_window_modes=(active_region_mode,), + active_receiver_window_offsets=(str(active_region.start),), + active_receiver_window_counts=(str(active_region.count),), + workspace_aliases=workspace_aliases, + message_projection_boundaries=("projected_message",), + message_projection_bucket_kinds=message_metadata["message_projection_bucket_kinds"], + message_bucket_count=message_metadata["message_bucket_count"], + 
message_regular_local_bucket_count=message_metadata["message_regular_local_bucket_count"], + message_sparse_bucket_count=message_metadata["message_sparse_bucket_count"], + message_batched_backend_count=message_metadata["message_batched_backend_count"], + message_grouped_backend_count=message_metadata["message_grouped_backend_count"], + message_reset_aware_bucket_count=message_metadata["message_reset_aware_bucket_count"], + message_degree_uniform_bucket_count=message_metadata["message_degree_uniform_bucket_count"], + message_ragged_grouped_bucket_count=message_metadata["message_ragged_grouped_bucket_count"], + message_demoted_bucket_count=message_metadata["message_demoted_bucket_count"], + message_bucket_signatures=message_metadata["message_bucket_signatures"], + message_bucket_kinds=message_metadata["message_bucket_kinds"], + message_topology_kinds=message_metadata["message_topology_kinds"], + message_spatial_ownership=message_metadata["message_spatial_ownership"], + message_degree_bucket_lists=message_metadata["message_degree_bucket_lists"], + message_logit_backends=message_metadata["message_logit_backends"], + message_softmax_backends=message_metadata["message_softmax_backends"], + message_weighted_value_backends=message_metadata["message_weighted_value_backends"], + message_physical_mode=message_metadata["message_physical_mode"], + message_execution_mode=message_metadata["message_execution_mode"], + message_output_boundary=message_metadata["message_output_boundary"], + message_degree=message_metadata["message_degree"], + message_k=message_metadata["message_k"], + message_v=message_metadata["message_v"], + message_projected_n=message_metadata["message_projected_n"], + message_reset_policies=message_metadata["message_reset_policies"], + message_reset_scopes=message_metadata["message_reset_scopes"], + message_use_delay=message_metadata["message_use_delay"], + message_distance_penalty_kinds=message_metadata["message_distance_penalty_kinds"], + message_epilogue_kinds=message_metadata["message_epilogue_kinds"], + message_packed_source_reuse_count=message_metadata["message_packed_source_reuse_count"], + message_demotions=message_metadata["message_demotions"], + message_workspace_mode=message_metadata["message_workspace_mode"], + backward_physical_op_kinds=( + (message_backward_kind, *transition_backward_metadata["kinds"], "glue/layout") if training else () + ), + backward_physical_op_executors=( + ( + message_backward_executor, + *transition_backward_metadata["executors"], + "physical_temporal_bucket_sequence_backward", + ) + if training + else () + ), + backward_physical_op_demotions=backward_owner_plan.demotions, + backward_boundary_contracts=( + ("projected_message", *transition_backward_metadata["boundaries"], "fixed_active_spatial_region") + if training + else () + ), + backward_tape_mode=backward_owner_plan.tape_modes, + backward_launch_counts=( + (message_backward_launch, *transition_backward_metadata["launch_counts"]) if training else () + ), + backward_saved_launch_counts=( + ( + message_backward_saved_launch, + *transition_backward_metadata["saved_launch_counts"], + *static_saved_launch_counts, + ) + if training + else () + ), + backward_residual_glue_demotions=transition_backward_metadata["residual_demotions"] if training else (), + backward_recompute_mode=( + tuple( + item + for item in ( + f"transition_tape:{transition_tape_mode}" if transition_tape_mode is not None else None, + transition_tape_reason, + f"temporal_artifacts:{temporal_artifact_mode}" if temporal_artifact_mode is 
not None else None, + temporal_artifact_reason, + ) + if item is not None + ) + if training + else () + ), + **temporal_execution_record_metadata(getattr(runtime, "_last_temporal_execution_plan", None)), + ) + + +__all__ = [ + "execute_temporal_bucket_sequence", + "record_temporal_bucket_step_loop_execution", + "supports_temporal_bucket_active_output_window", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/runtime/memory_stages.py b/src/cortical/fabric/backend/cuda/sequence_surface/runtime/memory_stages.py new file mode 100644 index 00000000..573ac0a6 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/runtime/memory_stages.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import replace +from typing import Any + +import torch +from tensordict import TensorDictBase + + +def _append_workspace_summary(runtime: Any, prefix: str, summary: str) -> None: + record = getattr(runtime, "_last_backend_execution", None) + if record is None: + return + try: + runtime._last_backend_execution = replace( + record, + workspace_aliases=( + *tuple(getattr(record, "workspace_aliases", ()) or ()), + f"{prefix}:{summary}", + ), + ) + except TypeError: + return + + +def append_registered_memory_stage_summary(runtime: Any, summary: str) -> None: + previous = tuple(getattr(runtime, "_last_flat_bucket_temporal_registered_backward_memory_stages", ()) or ()) + runtime._last_flat_bucket_temporal_registered_backward_memory_stages = (*previous, summary) + _append_workspace_summary(runtime, "flat_bucket_temporal_registered_backward_memory_stage", summary) + + +def record_registered_memory_stage(runtime: Any, reference: torch.Tensor, stage: str) -> None: + if not torch.is_tensor(reference) or not reference.is_cuda: + return + enabled_fn = getattr(runtime, "_backend_owner_timing_enabled", None) + if not callable(enabled_fn) or not bool(enabled_fn(reference.device)): + return + try: + torch.cuda.synchronize(reference.device) + allocated = int(torch.cuda.memory_allocated(reference.device)) + reserved = int(torch.cuda.memory_reserved(reference.device)) + max_allocated = int(torch.cuda.max_memory_allocated(reference.device)) + except RuntimeError: + return + append_registered_memory_stage_summary( + runtime, + f"stage={stage};allocated={allocated};reserved={reserved};max_allocated={max_allocated}", + ) + + +def _logical_tensor_bytes(value: Any) -> int: + if torch.is_tensor(value): + return int(value.numel()) * int(value.element_size()) + if isinstance(value, TensorDictBase): + return sum(_logical_tensor_bytes(item) for item in value.values()) + if isinstance(value, Mapping): + return sum(_logical_tensor_bytes(item) for item in value.values()) + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + return sum(_logical_tensor_bytes(item) for item in value) + return 0 + + +def _first_tensor(value: Any) -> torch.Tensor | None: + if torch.is_tensor(value): + return value + if isinstance(value, TensorDictBase): + for item in value.values(): + tensor = _first_tensor(item) + if tensor is not None: + return tensor + if isinstance(value, Mapping): + for item in value.values(): + tensor = _first_tensor(item) + if tensor is not None: + return tensor + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + for item in value: + tensor = _first_tensor(item) + if tensor is not None: + return tensor + return None + + +def record_frontend_tensor_bytes( + runtime: Any, + *, + 
stage: str, + tensors: Mapping[str, Any], +) -> None: + reference = _first_tensor(tensors) + if reference is None or not reference.is_cuda: + return + enabled_fn = getattr(runtime, "_backend_owner_timing_enabled", None) + if not callable(enabled_fn) or not bool(enabled_fn(reference.device)): + return + byte_items: list[tuple[str, int]] = [] + for role, value in tensors.items(): + byte_count = _logical_tensor_bytes(value) + if byte_count > 0: + byte_items.append((str(role), int(byte_count))) + if not byte_items: + return + bytes_by_role = ",".join(f"{role}:{byte_count}" for role, byte_count in byte_items) + total_bytes = sum(byte_count for _role, byte_count in byte_items) + summary = f"stage={stage};total_bytes={int(total_bytes)};bytes_by_role={bytes_by_role}" + previous = tuple(getattr(runtime, "_last_flat_bucket_temporal_frontend_tensor_bytes", ()) or ()) + runtime._last_flat_bucket_temporal_frontend_tensor_bytes = (*previous, summary) + _append_workspace_summary(runtime, "flat_bucket_temporal_frontend_tensor_bytes", summary) diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/policy.py b/src/cortical/fabric/backend/cuda/sequence_surface/runtime/policy.py similarity index 93% rename from src/cortical/fabric/backend/cuda/sequence_surface/policy.py rename to src/cortical/fabric/backend/cuda/sequence_surface/runtime/policy.py index bcdacee8..0af5aa6d 100644 --- a/src/cortical/fabric/backend/cuda/sequence_surface/policy.py +++ b/src/cortical/fabric/backend/cuda/sequence_surface/runtime/policy.py @@ -378,58 +378,6 @@ def forward_layout_batch_tile_policy( ) -def projected_boundary_time_chunk_policy( - *, - time_steps: int, - batch_size: int, - dtype_bytes: int, - projected_features: int, - state_bytes: int, - output_ports: int, - hidden_size: int, - output_boundary: OutputBoundary, - memory: CudaMemoryBudget | None, - graph_chunk_target_bytes: int, - graph_chunk_min_steps: int, -) -> PolicyDecision: - if time_steps <= 1 or memory is None: - return PolicyDecision(value=int(time_steps), reason=f"time_steps={time_steps};chunk_len={time_steps}") - boundary_step_bytes = int(batch_size) * int(projected_features) * int(dtype_bytes) - output_step_bytes = int(batch_size) * int(output_ports) * int(hidden_size) * int(dtype_bytes) - sequence_output_multiplier = 3.0 if output_boundary == "sequence" else 1.0 - estimated_per_step_bytes = max( - 1, - int( - math.ceil(float(boundary_step_bytes + state_bytes) + float(output_step_bytes) * sequence_output_multiplier) - ), - ) - runtime_reserve_bytes = max(3 << 30, int(memory.total_bytes * 0.03)) - target_fraction = 0.18 if output_boundary == "sequence" else 0.30 - target_bytes = min( - int(memory.total_bytes * target_fraction), - max(1 << 30, int(memory.usable_bytes) - int(runtime_reserve_bytes)), - ) - target_bytes = max(1 << 30, target_bytes) - raw_chunk_len = max(1, int(target_bytes // estimated_per_step_bytes)) - chunk_len = max(1, min(int(time_steps), raw_chunk_len)) - chunk_len = round_down_power_of_two(chunk_len, time_steps) if chunk_len > 1 else chunk_len - graph_chunk_cap = int(time_steps) - if output_boundary == "sequence": - raw_graph_chunk_cap = max(int(graph_chunk_min_steps), int(memory.total_bytes // graph_chunk_target_bytes)) - graph_chunk_cap = round_down_power_of_two(raw_graph_chunk_cap, time_steps) - graph_chunk_cap = max(1, min(int(time_steps), graph_chunk_cap)) - chunk_len = min(chunk_len, graph_chunk_cap) - return PolicyDecision( - value=int(chunk_len), - reason=( - 
f"time_steps={time_steps};estimated_per_step_bytes={estimated_per_step_bytes};" - f"free_bytes={int(memory.usable_bytes)};target_bytes={int(target_bytes)};" - f"raw_chunk_len={int(raw_chunk_len)};graph_chunk_cap={int(graph_chunk_cap)};" - f"chunk_len={int(chunk_len)}" - ), - ) - - def active_output_backward_batch_tile_policy( *, inputs: ActiveOutputBackwardTileInputs, diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/support.py b/src/cortical/fabric/backend/cuda/sequence_surface/runtime/support.py similarity index 92% rename from src/cortical/fabric/backend/cuda/sequence_surface/support.py rename to src/cortical/fabric/backend/cuda/sequence_surface/runtime/support.py index 94ff1285..368ab654 100644 --- a/src/cortical/fabric/backend/cuda/sequence_surface/support.py +++ b/src/cortical/fabric/backend/cuda/sequence_surface/runtime/support.py @@ -10,10 +10,7 @@ import torch from tensordict import TensorDict, TensorDictBase -from cortical.fabric.backend.planner import PlannedFabricBackwardExecution - _BACKWARD_ATTRIBUTION_MODE_ENV = "CORTICAL_FABRIC_BACKWARD_ATTRIBUTION_MODE" -_BACKWARD_MODE_ENV = "CORTICAL_FABRIC_BACKWARD_MODE" _BACKWARD_OWNER_TIMING_ENV = "CORTICAL_FABRIC_BACKWARD_OWNER_TIMING" _RECOMPUTE_PAYLOAD_TOTAL_KEYS = ( "glue_boundary", @@ -208,7 +205,7 @@ def _population_display_name(population_name: str) -> str: @dataclass(frozen=True) -class _GraphCaptureFallback: +class _GraphCaptureCacheEntry: key: object shape_signature: tuple[tuple[int, ...], ...] @@ -364,25 +361,3 @@ def wall_summary(self) -> tuple[str, ...]: f"{name}:ms={totals[name]:.3f};count={counts[name]}" for name in sorted(totals, key=totals.__getitem__, reverse=True) ) - - -@dataclass(frozen=True) -class _PhysicalBackwardSequenceExecutor: - runtime: Any - plan: PlannedFabricBackwardExecution - - def run(self, **kwargs: Any) -> tuple[dict[str, torch.Tensor | None], tuple[torch.Tensor | None, ...]]: - unsupported = tuple( - behavior.family - for behavior in self.plan.physical_plan.family_behaviors - if behavior.behavior == "unsupported" - ) - if unsupported: - raise RuntimeError( - "Fabric physical backward executor cannot run unsupported families: " + ", ".join(sorted(unsupported)) - ) - self.runtime._begin_backend_owner_timing(kwargs["boundary_seq"].device) - with torch.profiler.record_function("fabric.backward.physical_sequence_executor"): - result = self.runtime._run_backend_sequence_surface_backward_once(**kwargs) - self.runtime._finish_backend_owner_timing() - return result diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/runtime/surface.py b/src/cortical/fabric/backend/cuda/sequence_surface/runtime/surface.py new file mode 100644 index 00000000..62873d53 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/runtime/surface.py @@ -0,0 +1,1527 @@ +from __future__ import annotations + +import os +from collections.abc import Mapping +from contextlib import contextmanager +from dataclasses import replace +from typing import Any, Iterator, Literal, cast + +import torch +from tensordict import TensorDict, TensorDictBase + +from cortical.fabric.backend.cuda.sequence_surface.runtime.executor import ( + execute_temporal_bucket_sequence, +) +from cortical.fabric.backend.cuda.sequence_surface.runtime.support import ( + _BACKWARD_OWNER_TIMING_ENV, + _BackendOwnerTimingCollector, + _cat_batch_tree, + _GraphCaptureCacheEntry, + _ReceiverWindowSpec, + _slice_batch_tensor, + _slice_batch_tree, +) +from cortical.fabric.backend.planner import ( + PlannedFabricBackwardExecution, + 
PlannedFabricExecution, +) +from cortical.fabric.backend.reuse import ExecutionFamily +from cortical.fabric.backend.surfaces import SupportedSurface +from cortical.fabric.contracts.cells import collect_cell_tensors + + +class CudaSequenceSurfaceMixin: + @staticmethod + def _clone_backend_carry_value(value: Any) -> Any: + if value is None: + return None + if torch.is_tensor(value): + return value.clone() + if isinstance(value, TensorDictBase): + return TensorDict( + {key: CudaSequenceSurfaceMixin._clone_backend_carry_value(item) for key, item in value.items()}, + batch_size=value.batch_size, + device=value.device, + ) + if isinstance(value, dict): + return {key: CudaSequenceSurfaceMixin._clone_backend_carry_value(item) for key, item in value.items()} + return value + + @staticmethod + def _cuda_usable_memory_info(device: torch.device) -> tuple[int, int, int, int]: + free_bytes, total_bytes = torch.cuda.mem_get_info(device) + try: + reserved_bytes = int(torch.cuda.memory_reserved(device)) + allocated_bytes = int(torch.cuda.memory_allocated(device)) + except RuntimeError: + reserved_bytes = 0 + allocated_bytes = 0 + reusable_reserved_bytes = max(0, reserved_bytes - allocated_bytes) + usable_bytes = int(free_bytes) + int(reusable_reserved_bytes) + return int(usable_bytes), int(total_bytes), int(free_bytes), int(reusable_reserved_bytes) + + def _backend_owner_timing_enabled(self, device: torch.device) -> bool: + if device.type != "cuda": + return False + return os.environ.get(_BACKWARD_OWNER_TIMING_ENV, "").lower() in {"1", "true", "yes", "on"} + + def _begin_backend_owner_timing(self, device: torch.device) -> None: + if not self._backend_owner_timing_enabled(device): + self._active_backend_owner_timing = None + return + self._active_backend_owner_timing = _BackendOwnerTimingCollector(device=device, events=[]) + + def _finish_backend_owner_timing(self) -> None: + collector = getattr(self, "_active_backend_owner_timing", None) + self._active_backend_owner_timing = None + if collector is None: + return + timing_summary = collector.summary() + wall_summary = collector.wall_summary() + if not timing_summary and not wall_summary: + return + record = getattr(self, "_last_backend_execution", None) + if record is None: + self._last_backend_owner_timing_ms = timing_summary + self._last_backend_owner_wall_ms = wall_summary + return + self._last_backend_execution = replace( + record, + backward_owner_timing_ms=timing_summary, + backward_owner_wall_ms=wall_summary, + ) + self._last_backend_owner_timing_ms = timing_summary + self._last_backend_owner_wall_ms = wall_summary + + @contextmanager + def _backend_owner_timing(self, name: str) -> Iterator[None]: + collector = getattr(self, "_active_backend_owner_timing", None) + if collector is None: + yield + return + with collector.record(name): + yield + + def _backend_sequence_graph_capture_safe(self) -> bool: + # The active CUDA recurrence path runs through the generic dispatcher-backed + # backend executor and is the intended graph-capture surface for supported + # plans. Do not silently disable capture here and fall back to uncaptured + # replay for surfaces the planner has already marked capture-safe. 
+ return True + + def _backend_execution_semantics(self, execution_family: ExecutionFamily) -> tuple[str, str]: + if execution_family in {ExecutionFamily.SEQUENCE_MAJOR, ExecutionFamily.RECEIVER_MAJOR}: + return "receiver_owned", "persistent_scan" + if execution_family == ExecutionFamily.EDGE_MAJOR: + return "edge_owned", "persistent_scan" + raise RuntimeError(f"Unsupported Fabric execution family {execution_family.value}") + + @staticmethod + def _slice_receiver_window_static_tensor( + tensor: torch.Tensor, + window: _ReceiverWindowSpec, + ) -> torch.Tensor: + if tensor.dim() == 1 and int(tensor.numel()) >= window.full_count: + if int(tensor.numel()) % window.full_count == 0: + receiver_view = tensor.reshape(window.full_count, -1) + return receiver_view.narrow(0, window.start, window.count).contiguous().reshape(-1) + if tensor.dim() >= 1 and int(tensor.shape[0]) == window.full_count: + return tensor.narrow(0, window.start, window.count).contiguous() + if tensor.dim() >= 2 and int(tensor.shape[0]) == 1 and int(tensor.shape[1]) == window.full_count: + return tensor.narrow(1, window.start, window.count).contiguous() + return tensor + + def _slice_receiver_window_static_tensors( + self, + static_tensors: Mapping[str, object], + window: _ReceiverWindowSpec | None, + ) -> dict[str, object]: + if window is None or not window.active: + return dict(static_tensors) + sliced: dict[str, object] = dict(static_tensors) + for key in ( + "recurrent_cell_bias", + "fused_recurrent_value_to_cell_weight", + "fused_recurrent_cell_bias", + "recurrent_sender_input_to_kv_weight", + ): + value = sliced.get(key) + if torch.is_tensor(value): + sliced[key] = self._slice_receiver_window_static_tensor(value, window) + + direct_recurrent_kv = sliced.get("recurrent_sender_input_to_kv_weight") + grouped_recurrent_kv = sliced.get("recurrent_group_input_to_kv_weight") + if torch.is_tensor(direct_recurrent_kv): + sliced["recurrent_group_input_to_kv_weight"] = None + elif torch.is_tensor(grouped_recurrent_kv): + expanded = grouped_recurrent_kv.repeat_interleave( + max(1, int(self._recurrent_sender_kv_group_size)), + dim=0, + ) + if int(expanded.shape[0]) >= window.full_count: + sliced["recurrent_sender_input_to_kv_weight"] = ( + expanded[: window.full_count].narrow(0, window.start, window.count).contiguous() + ) + sliced["recurrent_group_input_to_kv_weight"] = None + + population_materialized = sliced.get("population_materialized") + if isinstance(population_materialized, dict): + sliced_population_materialized: dict[str, object | None] = {} + for population_name, params in population_materialized.items(): + if not isinstance(params, dict): + sliced_population_materialized[population_name] = params + continue + sliced_params: dict[str, object] = {} + for name, value in params.items(): + sliced_params[name] = ( + self._slice_receiver_window_static_tensor(value, window) if torch.is_tensor(value) else value + ) + sliced_population_materialized[population_name] = sliced_params + sliced["population_materialized"] = sliced_population_materialized + return sliced + + def _record_cuda_launch_metadata( + self, + request: Any, + *, + spatial_ownership: str, + temporal_execution: str, + actual_launch_metadata: Mapping[str, tuple[Any, ...]] | None = None, + ) -> None: + self._last_backend_launch_metadata = { + "receiver_tiles": (int(request.receiver_tile),), + "batch_tiles": (int(request.batch_tile),), + "edge_tiles": (int(request.edge_tile),), + "hidden_chunks": (int(request.hidden_chunk),), + "state_receiver_tiles": 
(int(request.state_receiver_tile),), + "state_batch_tiles": (int(request.state_batch_tile),), + "state_hidden_chunks": (int(request.state_hidden_chunk),), + "state_static_stage_modes": (str(request.state_static_stage_mode),), + "emit_receiver_tiles": (int(request.emit_receiver_tile),), + "emit_batch_tiles": (int(request.emit_batch_tile),), + "emit_hidden_chunks": (int(request.emit_hidden_chunk),), + "emit_static_stage_modes": (str(request.emit_static_stage_mode),), + "public_receiver_tiles": (int(request.public_receiver_tile),), + "public_batch_tiles": (int(request.public_batch_tile),), + "replication_factors": (int(request.replication_factor),), + "cell_static_stage_modes": (str(request.cell_static_stage_mode),), + "readout_modes": (str(request.readout_mode),), + "workspace_aliases": ("none",), + "temporal_executions": (temporal_execution,), + "scan_implementations": ( + "disabled_sequence_surface" if temporal_execution == "persistent_scan" else "single_step", + ), + "temporal_scan_owners": ( + "disabled_sequence_surface" if temporal_execution == "persistent_scan" else "single_step", + ), + "temporal_scan_outer_steps": (str(int(request.input_k_seq.shape[1])),), + "temporal_scan_inner_steps": ("1",), + "temporal_scan_physical_steps": (str(int(request.input_k_seq.shape[1])),), + "temporal_scan_emission_counts": ( + "1" + if str(request.static_config.get("output_boundary", "sequence")) == "terminal" + else str(int(request.input_k_seq.shape[1])), + ), + "temporal_scan_first_emission_steps": ("0",), + "temporal_scan_emission_strides": ("1",), + "temporal_scan_output_boundaries": (str(request.static_config.get("output_boundary", "sequence")),), + "active_receiver_window_modes": ("full_surface",), + "active_receiver_window_offsets": ("0",), + "active_receiver_window_counts": ("0",), + "phases": ( + "receiver_message_aggregate", + "dense_input_projection", + "dense_state_affines", + "receiver_state_update", + "receiver_reduce_stats", + "receiver_emit_raw_public", + "dense_public_projection", + "readout_message_aggregate", + "dense_readout_projection", + ) + if spatial_ownership == "receiver_owned" + else ( + "edge_owned_accumulate", + "receiver_message_normalize", + "dense_input_projection", + "dense_state_affines", + "receiver_state_update", + "receiver_reduce_stats", + "receiver_emit_raw_public", + "dense_public_projection", + "readout_message_aggregate", + "dense_readout_projection", + ), + "input_projection_backends": ("unrun",), + "input_projection_notes": ("unrun",), + "message_projection_boundaries": ("unrun",), + "message_projection_bucket_kinds": ("unrun",), + "message_bucket_count": ("unrun",), + "message_regular_local_bucket_count": ("unrun",), + "message_sparse_bucket_count": ("unrun",), + "message_batched_backend_count": ("unrun",), + "message_grouped_backend_count": ("unrun",), + "message_reset_aware_bucket_count": ("unrun",), + "message_degree_uniform_bucket_count": ("unrun",), + "message_ragged_grouped_bucket_count": ("unrun",), + "message_demoted_bucket_count": ("unrun",), + "message_bucket_signatures": ("unrun",), + "message_bucket_kinds": ("unrun",), + "message_topology_kinds": ("unrun",), + "message_spatial_ownership": ("unrun",), + "message_degree_bucket_lists": ("unrun",), + "message_logit_backends": ("unrun",), + "message_softmax_backends": ("unrun",), + "message_weighted_value_backends": ("unrun",), + "message_physical_mode": ("unrun",), + "message_execution_mode": ("unrun",), + "message_output_boundary": ("unrun",), + "message_reset_policies": ("unrun",), + 
"message_reset_scopes": ("unrun",), + "message_use_delay": ("unrun",), + "message_distance_penalty_kinds": ("unrun",), + "message_epilogue_kinds": ("unrun",), + "message_packed_source_reuse_count": ("unrun",), + "message_demotions": ("unrun",), + "message_workspace_buffers": ("unrun",), + "message_workspace_buffer_bytes": ("unrun",), + "message_workspace_peak_bytes": ("unrun",), + "message_workspace_mode": ("unrun",), + "message_workspace_aliases": ("unrun",), + "message_per_bucket_workspace_bytes": ("unrun",), + "phase_launch_counts": ("unrun",), + "small_cublas_launch_counts": ("unrun",), + "copy_glue_launch_counts": ("unrun",), + "copy_glue_saved_launch_counts": ("unrun",), + "bias_glue_launch_counts": ("unrun",), + "bias_glue_saved_launch_counts": ("unrun",), + "state_epilogue_modes": ("unrun",), + "state_epilogue_saved_launch_counts": ("unrun",), + "launch_coalescing_modes": ("unrun",), + "generic_glue_fusion_modes": ("unrun",), + "launch_granularity_modes": ("unrun",), + "physical_op_kinds": ("unrun",), + "physical_layout_contracts": ("unrun",), + "layout_mode": ("unrun",), + "copy_elision_mode": ("unrun",), + "bias_fusion_mode": ("unrun",), + "physical_op_executors": ("unrun",), + "physical_op_demotions": ("unrun",), + "physical_boundary_contracts": ("unrun",), + "physical_applicability_predicates": ("unrun",), + "physical_workspace_aliases": ("unrun",), + "physical_workspace_peak_bytes": ("unrun",), + "physical_op_launch_counts": ("unrun",), + "physical_op_saved_launch_counts": ("unrun",), + "standalone_copy_kernel_count": ("unrun",), + "standalone_bias_kernel_count": ("unrun",), + "receiver_affine_superop_surface_count": ("unrun",), + "receiver_affine_superop_receivers": ("unrun",), + "receiver_affine_superop_k": ("unrun",), + "receiver_affine_superop_n": ("unrun",), + "receiver_affine_superop_source_layout": ("unrun",), + "receiver_affine_superop_reset_policy": ("unrun",), + "receiver_affine_superop_executor": ("unrun",), + "receiver_affine_superop_physical_mode": ("unrun",), + "receiver_affine_superop_demotion_reason": ("unrun",), + "diagonal_recurrence_superop_surface_count": ("unrun",), + "diagonal_recurrence_kind": ("unrun",), + "diagonal_recurrence_executor": ("unrun",), + "diagonal_recurrence_physical_mode": ("unrun",), + "diagonal_recurrence_coeff_cache_mode": ("unrun",), + "diagonal_recurrence_coeff_cache_hit": ("unrun",), + "diagonal_recurrence_coeff_cache_bytes": ("unrun",), + "diagonal_recurrence_coeff_cache_version_source": ("unrun",), + "diagonal_recurrence_reset_policy": ("unrun",), + "diagonal_recurrence_reset_scope": ("unrun",), + "diagonal_recurrence_output_boundary": ("unrun",), + "diagonal_recurrence_workspace_mode": ("unrun",), + "diagonal_recurrence_workspace_peak_bytes": ("unrun",), + "diagonal_recurrence_demotion_reason": ("unrun",), + "diagonal_recurrence_launch_count": ("unrun",), + "state_affine_backends": ("unrun",), + "state_affine_sources": ("unrun",), + "state_affine_bucket_signatures": ("unrun",), + "state_affine_output_modes": ("unrun",), + "state_affine_reset_policies": ("unrun",), + "state_affine_reset_mode": ("unrun",), + "state_affine_reset_scope": ("unrun",), + "state_affine_workspace_mode": ("unrun",), + "state_affine_receiver_chunk_size": ("unrun",), + "state_affine_receiver_chunks": ("unrun",), + "state_affine_workspace_buffers": ("unrun",), + "state_affine_workspace_buffer_bytes": ("unrun",), + "state_affine_workspace_bytes": ("unrun",), + "state_affine_reset_rows_present": ("unrun",), + "state_affine_packed_source_reused": ("unrun",), + 
"public_projection_hidden_backends": ("unrun",), + "public_projection_kv_backends": ("unrun",), + "readout_projection_backends": ("unrun",), + "workspace_buffers": ("unrun",), + "workspace_buffer_bytes": ("unrun",), + "workspace_peak_bytes": ("unrun",), + } + if actual_launch_metadata is not None: + self._last_backend_launch_metadata.update(actual_launch_metadata) + + def _resolve_backend_initial_recurrent_kv( + self, + *, + population_name: str, + initial_hidden: torch.Tensor, + initial_recurrent_k: torch.Tensor | None, + initial_recurrent_v: torch.Tensor | None, + static_tensors: dict[str, object], + active_receiver_window: _ReceiverWindowSpec | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if initial_recurrent_k is None and initial_recurrent_v is None: + public_kind = self._cell_spec_for_population(population_name).public_schema.kind + if public_kind == "hidden": + projection_static_tensors = static_tensors + if ( + active_receiver_window is not None + and active_receiver_window.active + and initial_hidden.dim() >= 2 + and int(initial_hidden.shape[1]) == int(active_receiver_window.count) + ): + projection_static_tensors = self._slice_receiver_window_static_tensors( + static_tensors, + active_receiver_window, + ) + return self._project_sender_kv_from_cells_step( + initial_hidden, + sender_input_to_kv_weight=cast( + torch.Tensor | None, projection_static_tensors["recurrent_sender_input_to_kv_weight"] + ), + grouped_sender_input_to_kv_weight=cast( + torch.Tensor | None, projection_static_tensors["recurrent_group_input_to_kv_weight"] + ), + sender_group_size=1 + if projection_static_tensors.get("recurrent_group_input_to_kv_weight") is None + else self._recurrent_sender_kv_group_size, + ) + if public_kind == "preproj": + batch_size, recurrent_cells, _hidden = initial_hidden.shape + return ( + initial_hidden.new_zeros(batch_size, recurrent_cells, self.head_dim), + initial_hidden.new_zeros(batch_size, recurrent_cells, self.value_dim), + ) + raise RuntimeError(f"Unsupported public schema kind {public_kind} for cell population {population_name}") + if initial_recurrent_k is None or initial_recurrent_v is None: + raise ValueError("initial_recurrent_k and initial_recurrent_v must both be provided or both be None") + return initial_recurrent_k, initial_recurrent_v + + def _backend_sequence_surface_projection_dims( + self, + *, + population_name: str, + static_tensors: dict[str, object], + ) -> tuple[int, int, tuple[torch.Tensor, ...]]: + population_materialized = cast(dict[str, object | None], static_tensors["population_materialized"]) + population_params = population_materialized[population_name] + if not isinstance(population_params, dict): + raise RuntimeError(f"Fabric cell population {population_name} is missing materialized parameters") + tensor_population_params = {key: value for key, value in population_params.items() if torch.is_tensor(value)} + backend_cell_tensors = cast(dict[str, dict[str, torch.Tensor]], static_tensors["backend_cell_tensors"]) + extra_cell_tensors = backend_cell_tensors.get(population_name) + if extra_cell_tensors is None: + raise RuntimeError( + f"Fabric cell population {population_name} is missing backend cell tensor materialization" + ) + cell_spec = self._cell_spec_for_population(population_name) + cell_tensors = collect_cell_tensors(cell_spec, extra_cell_tensors, tensor_population_params) + cell_params = tuple(cell_spec.parameter_schema.flatten(cell_tensors)) + input_projection_params = tuple(cell_spec.input_projection_schema.flatten(cell_tensors)) + 
public_projection_params = tuple(cell_spec.public_projection_schema.flatten(cell_tensors)) + input_weight = input_projection_params[0] + projected_message_dim = int(input_weight.shape[2] if input_weight.dim() == 3 else input_weight.shape[0]) + if cell_spec.public_schema.kind == "hidden": + raw_public_dim = int(self.hidden_size) + else: + raw_public_dim = int(public_projection_params[0].shape[1]) + return projected_message_dim, raw_public_dim, cell_params + + def _can_virtualize_fresh_backend_state( + self, + *, + population_name: str, + static_tensors: dict[str, object], + projected_message_dim: int, + raw_public_dim: int, + cell_params: tuple[torch.Tensor, ...], + ) -> bool: + del population_name, static_tensors, projected_message_dim, raw_public_dim, cell_params + return False + + def _supports_cuda_backend_sequence_surface( + self, + *, + k: int | torch.Tensor | None, + device: torch.device, + dtype: torch.dtype, + ) -> bool: + route = self._plan_sequence_surface_route( + k=k, + device=device, + dtype=dtype, + ) + return route.supported + + def _execute_backend_sequence_surface( + self, + *, + state: TensorDict, + boundary_seq: torch.Tensor, + projected_boundary_source_seq: torch.Tensor | None = None, + projected_boundary_weight: torch.Tensor | None = None, + projected_boundary_bias: torch.Tensor | None = None, + static_tensors: dict[str, object], + population_resets: torch.Tensor | None, + input_sender_input_to_kv_weight: torch.Tensor | None, + input_group_input_to_kv_weight: torch.Tensor | None, + backend_population_name: str | None, + backend_population_state_is_fresh: bool, + materialize_final_state: bool, + grad_path: bool, + selected_backend_surface: SupportedSurface, + planned_backend_execution: PlannedFabricExecution, + output_boundary: Literal["sequence", "terminal"] = "sequence", + readout_output_boundary: Literal["cells", "pooled"] = "cells", + ) -> tuple[torch.Tensor, TensorDict]: + if backend_population_name is not None: + return self._execute_backend_supported_sequence_surface( + state=state, + boundary_seq=boundary_seq, + projected_boundary_source_seq=projected_boundary_source_seq, + projected_boundary_weight=projected_boundary_weight, + projected_boundary_bias=projected_boundary_bias, + static_tensors=static_tensors, + population_resets=population_resets, + input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, + input_group_input_to_kv_weight=input_group_input_to_kv_weight, + backend_population_name=backend_population_name, + backend_population_state_is_fresh=backend_population_state_is_fresh, + materialize_final_state=materialize_final_state, + grad_path=grad_path, + selected_backend_surface=selected_backend_surface, + planned_backend_execution=planned_backend_execution, + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + ) + raise ValueError(f"Unsupported backend surface {selected_backend_surface.key}") + + def _compiler_temporal_initial_state( + self, + *, + population_name: str, + boundary_seq: torch.Tensor, + packed_state: Any, + initial_hidden: torch.Tensor, + initial_recurrent_k: torch.Tensor | None, + initial_recurrent_v: torch.Tensor | None, + initial_state_is_fresh: bool, + ) -> TensorDict: + batch_size = int(boundary_seq.shape[0]) + cells = boundary_seq.new_zeros(batch_size, int(self.coords.shape[0]), int(self.hidden_size)) + cells[:, self._recurrent_slice, :] = initial_hidden.to(dtype=cells.dtype) + state_payload: dict[str, object] = {"cells": cells} + if packed_state is not None: + state_payload[population_name] 
= self._backend_state_to_population_state( + population_name, + cast(Mapping[str, torch.Tensor], packed_state), + ) + elif not initial_state_is_fresh: + state_payload[population_name] = self._backend_state_to_population_state( + population_name, + cast( + Mapping[str, torch.Tensor], + self._init_backend_population_state( + population_name, + batch=batch_size, + device=boundary_seq.device, + dtype=boundary_seq.dtype, + ), + ), + ) + if torch.is_tensor(initial_recurrent_k) and torch.is_tensor(initial_recurrent_v): + input_k = boundary_seq.new_zeros(batch_size, int(self._num_input_cells), int(self.head_dim)) + input_v = boundary_seq.new_zeros(batch_size, int(self._num_input_cells), int(self.value_dim)) + state_payload["sender_k"] = torch.cat((input_k, initial_recurrent_k.to(dtype=input_k.dtype)), dim=1) + state_payload["sender_v"] = torch.cat((input_v, initial_recurrent_v.to(dtype=input_v.dtype)), dim=1) + return TensorDict(state_payload, batch_size=[]) + + def _execute_compiler_temporal_sequence_surface( + self, + *, + population_name: str, + boundary_seq: torch.Tensor, + packed_state: Any, + initial_hidden: torch.Tensor, + initial_recurrent_k: torch.Tensor | None, + initial_recurrent_v: torch.Tensor | None, + initial_state_is_fresh: bool, + materialize_final_state: bool, + compact_input_carry: bool = False, + preserve_internal_carry: bool = False, + population_resets: torch.Tensor, + transition_population_resets: torch.Tensor | None = None, + input_sender_input_to_kv_weight: torch.Tensor | None, + input_group_input_to_kv_weight: torch.Tensor | None, + planned_backend_execution: PlannedFabricExecution, + population_materialized: dict[str, object | None], + static_tensors: dict[str, object], + grad_path: bool, + output_boundary: Literal["sequence", "terminal"] = "sequence", + readout_output_boundary: Literal["cells", "pooled"] = "cells", + temporal_outer_time_steps: int | None = None, + temporal_inner_steps: int = 1, + ) -> tuple[torch.Tensor, Any, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + del ( + input_sender_input_to_kv_weight, + input_group_input_to_kv_weight, + population_materialized, + compact_input_carry, + ) + inner_steps = max(1, int(temporal_inner_steps)) + if temporal_outer_time_steps is not None and int(temporal_outer_time_steps) != int(boundary_seq.shape[1]): + raise RuntimeError("Compiler temporal sequence surface received mismatched outer timestep count") + compiler_state = self._compiler_temporal_initial_state( + population_name=population_name, + boundary_seq=boundary_seq, + packed_state=packed_state, + initial_hidden=initial_hidden, + initial_recurrent_k=initial_recurrent_k, + initial_recurrent_v=initial_recurrent_v, + initial_state_is_fresh=initial_state_is_fresh, + ) + materialize_carry = bool(materialize_final_state or preserve_internal_carry) + output_contract: Literal["output_cells", "pooled_output_cells"] = ( + "pooled_output_cells" if readout_output_boundary == "pooled" else "output_cells" + ) + output_seq, next_state = execute_temporal_bucket_sequence( + self, + hidden_seq=None, + boundary_seq=boundary_seq, + state=compiler_state, + population_resets=population_resets, + step_reset_flags=None, + k=inner_steps, + constant_k=inner_steps, + batch_size=int(boundary_seq.shape[0]), + time_steps=int(boundary_seq.shape[1]), + step_mode=False, + capture_active=bool(boundary_seq.device.type == "cuda" and torch.cuda.is_current_stream_capturing()), + static_tensors=static_tensors, + grad_path=grad_path, + materialize_final_state=materialize_carry, 
+ backend_population_state_is_fresh=initial_state_is_fresh, + use_fresh_backend_population_cache=False, + tape_policy=self._tape_policy_from_bin(planned_backend_execution.tape_policy_bin) if grad_path else None, + output_contract=output_contract, + output_boundary=output_boundary, + ) + batch_size = int(boundary_seq.shape[0]) + if not materialize_carry: + recurrent_hidden = boundary_seq.new_empty(batch_size, 0, int(self.hidden_size)) + recurrent_k = boundary_seq.new_empty(batch_size, 0, int(self.head_dim)) + recurrent_v = boundary_seq.new_empty(batch_size, 0, int(self.value_dim)) + input_k_seq = boundary_seq.new_empty(batch_size, 1, int(self._num_input_cells), int(self.head_dim)) + input_v_seq = boundary_seq.new_empty(batch_size, 1, int(self._num_input_cells), int(self.value_dim)) + return output_seq, {}, recurrent_hidden, recurrent_k, recurrent_v, input_k_seq, input_v_seq + final_cells = next_state.get("cells") + if not torch.is_tensor(final_cells): + raise RuntimeError("Compiler temporal sequence surface did not materialize final cells") + sender_k = next_state.get("sender_k") + sender_v = next_state.get("sender_v") + if not torch.is_tensor(sender_k) or not torch.is_tensor(sender_v): + raise RuntimeError("Compiler temporal sequence surface did not materialize final sender K/V") + recurrent_hidden = final_cells[:, self._recurrent_slice, :].contiguous() + input_count = int(self._num_input_cells) + input_k_last = sender_k[:, :input_count, :].contiguous() + input_v_last = sender_v[:, :input_count, :].contiguous() + recurrent_k = sender_k[:, input_count:, :].contiguous() + recurrent_v = sender_v[:, input_count:, :].contiguous() + population_state = next_state.get(population_name) + if not isinstance(population_state, TensorDictBase): + raise RuntimeError("Compiler temporal sequence surface did not materialize population state") + next_packed_state = self._population_state_to_backend_state(population_name, population_state) + return ( + output_seq, + next_packed_state, + recurrent_hidden, + recurrent_k, + recurrent_v, + input_k_last.unsqueeze(1), + input_v_last.unsqueeze(1), + ) + + def _execute_or_capture_backend_sequence_surface( + self, + *, + backend_population_name: str, + boundary_seq: torch.Tensor, + packed_state: Any, + initial_hidden: torch.Tensor, + initial_recurrent_k: torch.Tensor | None, + initial_recurrent_v: torch.Tensor | None, + initial_state_is_fresh: bool, + population_resets: torch.Tensor, + input_sender_input_to_kv_weight: torch.Tensor | None, + input_group_input_to_kv_weight: torch.Tensor | None, + static_tensors: dict[str, object], + selected_backend_surface: SupportedSurface, + planned_backend_execution: PlannedFabricExecution, + materialize_final_state: bool, + compact_input_carry: bool = False, + preserve_internal_carry: bool = False, + output_boundary: Literal["sequence", "terminal"] = "sequence", + readout_output_boundary: Literal["cells", "pooled"] = "cells", + ) -> tuple[torch.Tensor, Any, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, bool, bool]: + if output_boundary not in {"sequence", "terminal"}: + raise ValueError(f"Unsupported Fabric sequence output boundary {output_boundary!r}") + if readout_output_boundary not in {"cells", "pooled"}: + raise ValueError(f"Unsupported Fabric readout output boundary {readout_output_boundary!r}") + input_layout, graph_inputs = self._build_backend_graph_inputs( + boundary_seq=boundary_seq, + packed_state=packed_state, + initial_hidden=initial_hidden, + population_resets=population_resets, + 
initial_recurrent_k=initial_recurrent_k, + initial_recurrent_v=initial_recurrent_v, + packed_state_is_fresh=initial_state_is_fresh, + ) + shape_signature = self._graph_shape_signature( + graph_inputs=graph_inputs, + ) + graph_key = self._backend_graph_capture_key( + surface=selected_backend_surface, + plan=planned_backend_execution, + shape_signature=shape_signature, + ) + cached = self._backend_graph_capture_cache.get(graph_key) + cache_hit = cached is not None + if not cache_hit: + self._backend_graph_capture_cache.put( + graph_key, + _GraphCaptureCacheEntry(key=graph_key, shape_signature=shape_signature), + ) + output_seq, next_packed_state, recurrent_hidden, recurrent_k, recurrent_v, input_k_seq, input_v_seq = ( + self._execute_compiler_temporal_sequence_surface( + population_name=backend_population_name, + boundary_seq=boundary_seq, + packed_state=packed_state, + initial_hidden=initial_hidden, + initial_recurrent_k=initial_recurrent_k, + initial_recurrent_v=initial_recurrent_v, + initial_state_is_fresh=initial_state_is_fresh, + materialize_final_state=materialize_final_state, + compact_input_carry=compact_input_carry, + preserve_internal_carry=preserve_internal_carry, + population_resets=population_resets, + input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, + input_group_input_to_kv_weight=input_group_input_to_kv_weight, + planned_backend_execution=planned_backend_execution, + population_materialized=cast(dict[str, object | None], static_tensors["population_materialized"]), + static_tensors=static_tensors, + grad_path=False, + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + ) + ) + return ( + output_seq, + next_packed_state, + recurrent_hidden, + recurrent_k, + recurrent_v, + input_k_seq, + input_v_seq, + cache_hit, + True, + ) + + def _execute_backend_sequence_with_tape_policy( + self, + *, + backend_population_name: str, + boundary_seq: torch.Tensor, + projected_boundary_source_seq: torch.Tensor | None, + projected_boundary_weight: torch.Tensor | None, + projected_boundary_bias: torch.Tensor | None, + packed_state: Any, + initial_hidden: torch.Tensor, + initial_recurrent_k: torch.Tensor | None, + initial_recurrent_v: torch.Tensor | None, + population_resets: torch.Tensor, + population_resets_active: bool, + initial_state_is_fresh: bool, + input_sender_input_to_kv_weight: torch.Tensor | None, + input_group_input_to_kv_weight: torch.Tensor | None, + static_tensors: dict[str, object], + selected_backend_surface: SupportedSurface, + planned_backend_execution: PlannedFabricExecution, + planned_backend_backward_execution: PlannedFabricBackwardExecution, + materialize_final_state: bool, + output_boundary: Literal["sequence", "terminal"] = "sequence", + readout_output_boundary: Literal["cells", "pooled"] = "cells", + ) -> tuple[torch.Tensor, Any, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, bool, bool]: + chunk_len = self._backend_tape_checkpoint_chunk_len( + plan=planned_backend_execution, + time_steps=int(boundary_seq.shape[1]), + output_boundary=output_boundary, + boundary_seq=boundary_seq, + packed_state=packed_state, + initial_hidden=initial_hidden, + initial_recurrent_k=initial_recurrent_k, + initial_recurrent_v=initial_recurrent_v, + ) + self._backend_backward_batch_tile_len( + boundary_seq=boundary_seq, + packed_state=packed_state, + initial_hidden=initial_hidden, + initial_recurrent_k=initial_recurrent_k, + initial_recurrent_v=initial_recurrent_v, + output_boundary=output_boundary, + ) + 
use_training_graph_capture = ( + self._should_use_backend_graph_capture( + plan=planned_backend_execution, + device=boundary_seq.device, + grad_path=True, + time_steps=int(boundary_seq.shape[1]), + ) + and self._backend_sequence_graph_capture_safe() + ) + del chunk_len + ( + output_seq, + running_packed_state, + running_hidden, + running_recurrent_k, + running_recurrent_v, + last_input_k, + last_input_v, + graph_capture_cache_hit, + graph_capture_replayed, + ) = self._execute_or_capture_backend_training_sequence_surface( + boundary_seq=boundary_seq, + projected_boundary_source_seq=projected_boundary_source_seq, + projected_boundary_weight=projected_boundary_weight, + projected_boundary_bias=projected_boundary_bias, + packed_state=packed_state, + initial_hidden=initial_hidden, + initial_recurrent_k=initial_recurrent_k, + initial_recurrent_v=initial_recurrent_v, + initial_state_is_fresh=initial_state_is_fresh, + population_resets=population_resets, + population_resets_active=population_resets_active, + selected_backend_surface=selected_backend_surface, + planned_backend_execution=planned_backend_execution, + planned_backend_backward_execution=planned_backend_backward_execution, + static_tensors=static_tensors, + enable_graph_capture=use_training_graph_capture, + materialize_final_state=materialize_final_state, + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + ) + assert last_input_k is not None and last_input_v is not None + assert running_recurrent_k is not None and running_recurrent_v is not None + return ( + output_seq, + running_packed_state, + running_hidden, + running_recurrent_k, + running_recurrent_v, + last_input_k, + last_input_v, + graph_capture_cache_hit, + graph_capture_replayed, + ) + + def _execute_or_capture_backend_training_sequence_surface( + self, + *, + boundary_seq: torch.Tensor, + projected_boundary_source_seq: torch.Tensor | None = None, + projected_boundary_weight: torch.Tensor | None = None, + projected_boundary_bias: torch.Tensor | None = None, + packed_state: Any, + initial_hidden: torch.Tensor, + initial_recurrent_k: torch.Tensor | None, + initial_recurrent_v: torch.Tensor | None, + population_resets: torch.Tensor, + population_resets_active: bool = True, + selected_backend_surface: SupportedSurface, + planned_backend_execution: PlannedFabricExecution, + planned_backend_backward_execution: PlannedFabricBackwardExecution, + static_tensors: dict[str, object] | None, + enable_graph_capture: bool, + initial_state_is_fresh: bool = False, + materialize_final_state: bool = True, + compact_input_carry: bool = False, + preserve_internal_carry: bool = False, + output_boundary: Literal["sequence", "terminal"] = "sequence", + readout_output_boundary: Literal["cells", "pooled"] = "cells", + ) -> tuple[torch.Tensor, Any, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, bool, bool]: + if output_boundary not in {"sequence", "terminal"}: + raise ValueError(f"Unsupported Fabric training output boundary {output_boundary!r}") + if readout_output_boundary not in {"cells", "pooled"}: + raise ValueError(f"Unsupported Fabric readout output boundary {readout_output_boundary!r}") + projected_boundary_active = ( + projected_boundary_source_seq is not None + or projected_boundary_weight is not None + or projected_boundary_bias is not None + ) + if projected_boundary_active and (projected_boundary_source_seq is None or projected_boundary_weight is None): + raise RuntimeError("Projected Fabric boundary surface requires both source 
sequence and projection weight") + input_layout, graph_inputs = self._build_backend_graph_inputs( + boundary_seq=boundary_seq, + packed_state=packed_state, + initial_hidden=initial_hidden, + population_resets=population_resets, + initial_recurrent_k=initial_recurrent_k, + initial_recurrent_v=initial_recurrent_v, + packed_state_is_fresh=initial_state_is_fresh and packed_state is None, + projected_boundary_source_seq=projected_boundary_source_seq, + projected_boundary_weight=projected_boundary_weight, + projected_boundary_bias=projected_boundary_bias, + ) + shape_signature = self._graph_shape_signature(graph_inputs=graph_inputs) + graph_key = self._backend_graph_capture_key( + surface=selected_backend_surface, + plan=planned_backend_execution, + shape_signature=shape_signature, + ) + if static_tensors is None: + static_tensors = self._materialize_inference_static_tensors( + device=boundary_seq.device, + dtype=boundary_seq.dtype, + ) + backend_population_name = self._select_output_cells_stream_backend_population( + k=1, + ) + if backend_population_name is None: + raise RuntimeError( + f"Supported Fabric {selected_backend_surface.cell_type} " + "training surface requires a callable backend sequence engine" + ) + + capture_state = {"cache_hit": False, "replayed": False} + if enable_graph_capture: + cached = self._backend_graph_capture_cache.get(graph_key) + capture_state["cache_hit"] = cached is not None + capture_state["replayed"] = True + if cached is None: + self._backend_graph_capture_cache.put( + graph_key, + _GraphCaptureCacheEntry(key=graph_key, shape_signature=shape_signature), + ) + current_packed_state, current_recurrent_k, current_recurrent_v = self._unpack_backend_graph_inputs( + input_layout=input_layout, + graph_inputs=graph_inputs, + ) + current_boundary_seq = graph_inputs.get("boundary_seq") + if current_boundary_seq is None: + current_source_hidden_seq = graph_inputs["projected_boundary_source_seq"] + current_projection_weight = graph_inputs["projected_boundary_weight"] + current_projection_bias = graph_inputs.get("projected_boundary_bias") + current_boundary_seq = self._project_boundary_source_sequence( + current_source_hidden_seq, + input_projection_weight=current_projection_weight, + input_projection_bias=current_projection_bias, + ) + output_seq, next_packed_state, recurrent_hidden, recurrent_k, recurrent_v, input_k_seq, input_v_seq = ( + self._execute_compiler_temporal_sequence_surface( + population_name=backend_population_name, + boundary_seq=current_boundary_seq, + packed_state=current_packed_state, + initial_hidden=graph_inputs["initial_hidden"], + initial_recurrent_k=current_recurrent_k, + initial_recurrent_v=current_recurrent_v, + initial_state_is_fresh=initial_state_is_fresh, + materialize_final_state=materialize_final_state, + compact_input_carry=compact_input_carry, + preserve_internal_carry=preserve_internal_carry, + population_resets=graph_inputs["population_resets"] if population_resets_active else None, + input_sender_input_to_kv_weight=cast( + torch.Tensor | None, static_tensors["input_sender_input_to_kv_weight"] + ), + input_group_input_to_kv_weight=cast( + torch.Tensor | None, static_tensors["input_group_input_to_kv_weight"] + ), + planned_backend_execution=planned_backend_execution, + population_materialized=cast(dict[str, object | None], static_tensors["population_materialized"]), + static_tensors=static_tensors, + grad_path=True, + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + ) + ) + if output_boundary == 
"terminal": + output_seq = output_seq[:, -1:] + return ( + output_seq, + next_packed_state, + recurrent_hidden, + recurrent_k, + recurrent_v, + input_k_seq[:, -1], + input_v_seq[:, -1], + capture_state["cache_hit"], + capture_state["replayed"], + ) + + def _execute_backend_supported_sequence_surface_batch_tiled( + self, + *, + batch_tile_len: int, + batch_tile_reason: str | None, + state: TensorDict, + boundary_seq: torch.Tensor, + projected_boundary_source_seq: torch.Tensor | None, + projected_boundary_weight: torch.Tensor | None, + projected_boundary_bias: torch.Tensor | None, + static_tensors: dict[str, object], + population_resets: torch.Tensor | None, + input_sender_input_to_kv_weight: torch.Tensor | None, + input_group_input_to_kv_weight: torch.Tensor | None, + backend_population_name: str, + backend_population_state_is_fresh: bool, + materialize_final_state: bool, + grad_path: bool, + selected_backend_surface: SupportedSurface, + planned_backend_execution: PlannedFabricExecution, + output_boundary: Literal["sequence", "terminal"] = "sequence", + readout_output_boundary: Literal["cells", "pooled"] = "cells", + ) -> tuple[torch.Tensor, TensorDict]: + if readout_output_boundary not in {"cells", "pooled"}: + raise ValueError(f"Unsupported Fabric readout output boundary {readout_output_boundary!r}") + output_chunks: list[torch.Tensor] = [] + output_seq: torch.Tensor | None = None + preallocate_output = not grad_path + state_chunks: list[TensorDict] = [] + for start in range(0, int(boundary_seq.shape[0]), int(batch_tile_len)): + end = min(start + int(batch_tile_len), int(boundary_seq.shape[0])) + tile_plan = self.plan_backend_execution( + batch_size=end - start, + time_steps=int(boundary_seq.shape[1]), + inner_steps=1, + training=grad_path, + tape_policy=self._tape_policy_from_bin(planned_backend_execution.tape_policy_bin), + device=boundary_seq.device, + surface_key=selected_backend_surface.key, + ) + output_chunk, state_chunk = self._execute_backend_supported_sequence_surface( + state=state + if backend_population_state_is_fresh + else cast(TensorDict, _slice_batch_tree(state, start, end)), + boundary_seq=boundary_seq[start:end], + projected_boundary_source_seq=_slice_batch_tensor(projected_boundary_source_seq, start, end), + projected_boundary_weight=projected_boundary_weight, + projected_boundary_bias=projected_boundary_bias, + static_tensors=static_tensors, + population_resets=_slice_batch_tensor(population_resets, start, end), + input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, + input_group_input_to_kv_weight=input_group_input_to_kv_weight, + backend_population_name=backend_population_name, + backend_population_state_is_fresh=backend_population_state_is_fresh, + materialize_final_state=materialize_final_state, + grad_path=grad_path, + selected_backend_surface=selected_backend_surface, + planned_backend_execution=tile_plan, + allow_batch_tiling=False, + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + ) + if preallocate_output: + if output_seq is None: + output_seq = output_chunk.new_empty((int(boundary_seq.shape[0]), *output_chunk.shape[1:])) + output_seq[start:end].copy_(output_chunk) + else: + output_chunks.append(output_chunk) + if materialize_final_state: + state_chunks.append(state_chunk) + if output_seq is None: + output_seq = torch.cat(output_chunks, dim=0) + next_state = ( + cast(TensorDict, _cat_batch_tree(cast(list[Any], state_chunks))) + if materialize_final_state + else TensorDict({}, batch_size=[]) + ) + if 
self._last_backend_execution is not None: + record = self._last_backend_execution + self._last_backend_execution = replace( + record, + batch_size=int(boundary_seq.shape[0]), + workspace_aliases=record.workspace_aliases + + ( + f"forward_batch_tile:b={int(batch_tile_len)}", + f"forward_batch_tile_reason:{batch_tile_reason or self._last_backend_forward_batch_tile_reason}", + f"forward_batch_tile_output:{'preallocated' if preallocate_output else 'cat'}", + ), + ) + return output_seq, next_state + + def _execute_backend_supported_sequence_surface( + self, + *, + state: TensorDict, + boundary_seq: torch.Tensor, + projected_boundary_source_seq: torch.Tensor | None = None, + projected_boundary_weight: torch.Tensor | None = None, + projected_boundary_bias: torch.Tensor | None = None, + static_tensors: dict[str, object], + population_resets: torch.Tensor | None, + input_sender_input_to_kv_weight: torch.Tensor | None, + input_group_input_to_kv_weight: torch.Tensor | None, + backend_population_name: str | None, + backend_population_state_is_fresh: bool, + materialize_final_state: bool, + grad_path: bool, + selected_backend_surface: SupportedSurface, + planned_backend_execution: PlannedFabricExecution, + allow_batch_tiling: bool = True, + output_boundary: Literal["sequence", "terminal"] = "sequence", + readout_output_boundary: Literal["cells", "pooled"] = "cells", + ) -> tuple[torch.Tensor, TensorDict]: + if output_boundary not in {"sequence", "terminal"}: + raise ValueError(f"Unsupported Fabric sequence output boundary {output_boundary!r}") + if readout_output_boundary not in {"cells", "pooled"}: + raise ValueError(f"Unsupported Fabric readout output boundary {readout_output_boundary!r}") + batch_size = boundary_seq.shape[0] + if backend_population_name is None: + raise RuntimeError("Supported Fabric sequence surface requires a recurrent backend cell population") + planned_backend_backward_execution = ( + self.plan_backend_backward_execution( + batch_size=batch_size, + time_steps=int(boundary_seq.shape[1]), + inner_steps=1, + training=True, + tape_policy=self._tape_policy_from_bin(planned_backend_execution.tape_policy_bin), + device=boundary_seq.device, + surface_key=selected_backend_surface.key, + ) + if grad_path + else None + ) + population_name = backend_population_name + use_backend_tape_policy = self._should_use_backend_tape_policy( + plan=planned_backend_execution, + grad_path=grad_path, + time_steps=int(boundary_seq.shape[1]), + ) + use_backend_graph_capture = ( + self._should_use_backend_graph_capture( + plan=planned_backend_execution, + device=boundary_seq.device, + grad_path=grad_path, + time_steps=int(boundary_seq.shape[1]), + ) + and self._backend_sequence_graph_capture_safe() + ) + time_steps = int(boundary_seq.shape[1]) + projected_message_dim, raw_public_dim, cell_params = self._backend_sequence_surface_projection_dims( + population_name=backend_population_name, + static_tensors=static_tensors, + ) + terminal_dependency_receiver_count = None + if ( + backend_population_state_is_fresh + and not materialize_final_state + and self._local_message_step_enabled + and not self._has_edge_delay + and not bool(getattr(self, "_uses_sparse_message_backend", False)) + ): + terminal_dependency_receiver_count = self._fresh_output_dependency_receiver_count( + population_name=backend_population_name, + time_steps=time_steps, + fresh_state_virtualized=True, + ) + fixed_output_active_region = terminal_dependency_receiver_count is not None + state_buffers_needed_by_sequence = materialize_final_state or 
( + time_steps > 1 and not fixed_output_active_region + ) + fresh_state_virtualized = backend_population_state_is_fresh and self._can_virtualize_fresh_backend_state( + population_name=backend_population_name, + static_tensors=static_tensors, + projected_message_dim=projected_message_dim, + raw_public_dim=raw_public_dim, + cell_params=cell_params, + ) + defer_fresh_backend_state = ( + backend_population_state_is_fresh + and (fresh_state_virtualized or fixed_output_active_region) + and (fixed_output_active_region or (not use_backend_tape_policy and not use_backend_graph_capture)) + ) + fresh_zero_sentinel_prev = ( + backend_population_state_is_fresh and not state_buffers_needed_by_sequence and time_steps <= 1 + ) + if allow_batch_tiling: + forward_batch_tile_len = self._backend_forward_batch_tile_len_for_layout( + population_name=backend_population_name, + batch_size=int(batch_size), + time_steps=time_steps, + boundary_seq=boundary_seq, + materialize_final_state=materialize_final_state, + training=grad_path, + fresh_state_virtualized=fresh_state_virtualized, + fresh_output_dependency_receiver_count=terminal_dependency_receiver_count, + projected_message_dim=projected_message_dim, + raw_public_dim=raw_public_dim, + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + ) + early_batch_tile_reason = self._last_backend_forward_batch_tile_reason + if 0 < forward_batch_tile_len < int(batch_size): + return self._execute_backend_supported_sequence_surface_batch_tiled( + batch_tile_len=forward_batch_tile_len, + batch_tile_reason=early_batch_tile_reason, + state=state, + boundary_seq=boundary_seq, + projected_boundary_source_seq=projected_boundary_source_seq, + projected_boundary_weight=projected_boundary_weight, + projected_boundary_bias=projected_boundary_bias, + static_tensors=static_tensors, + population_resets=population_resets, + input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, + input_group_input_to_kv_weight=input_group_input_to_kv_weight, + backend_population_name=backend_population_name, + backend_population_state_is_fresh=backend_population_state_is_fresh, + materialize_final_state=materialize_final_state, + grad_path=grad_path, + selected_backend_surface=selected_backend_surface, + planned_backend_execution=planned_backend_execution, + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + ) + virtual_fresh_public_prev = ( + backend_population_state_is_fresh + and (defer_fresh_backend_state or fresh_zero_sentinel_prev) + and not state_buffers_needed_by_sequence + and time_steps <= 1 + ) + recurrent_prev = ( + boundary_seq.new_empty(batch_size, 0, self.hidden_size) + if virtual_fresh_public_prev + else boundary_seq.new_zeros(batch_size, int(terminal_dependency_receiver_count), self.hidden_size) + if backend_population_state_is_fresh + and fixed_output_active_region + and terminal_dependency_receiver_count is not None + else boundary_seq.new_zeros(batch_size, self._num_recurrent_cells, self.hidden_size) + if backend_population_state_is_fresh + else state["cells"][:, self._recurrent_slice, :] + ) + if defer_fresh_backend_state or fresh_zero_sentinel_prev: + packed_state = None + elif backend_population_state_is_fresh: + packed_state = self._init_backend_population_state( + population_name, + batch=batch_size, + device=boundary_seq.device, + dtype=boundary_seq.dtype, + ) + else: + packed_state = self._population_state_to_backend_state( + population_name, + cast(TensorDictBase, state[population_name]), + ) + 
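# Illustrative sketch only, not part of the diff: the batch-tiled surface called above
# walks the batch dimension in fixed tiles, preallocating the output buffer on inference
# paths and collecting graph-carrying chunks for a final torch.cat on gradient paths.
# run_tile and tile_len are hypothetical stand-ins for the per-tile surface execution and
# the planner-chosen forward batch tile length.
import torch


def sketch_batch_tiled_forward(x: torch.Tensor, tile_len: int, grad_path: bool, run_tile):
    out: torch.Tensor | None = None
    chunks: list[torch.Tensor] = []
    for start in range(0, int(x.shape[0]), tile_len):
        end = min(start + tile_len, int(x.shape[0]))
        tile_out = run_tile(x[start:end])
        if grad_path:
            # Keep autograd-visible chunks and concatenate once at the end.
            chunks.append(tile_out)
        else:
            if out is None:
                # Preallocate once the per-tile output shape is known, then copy in place.
                out = tile_out.new_empty((int(x.shape[0]), *tile_out.shape[1:]))
            out[start:end].copy_(tile_out)
    return torch.cat(chunks, dim=0) if grad_path else out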
initial_recurrent_k = None + initial_recurrent_v = None + if backend_population_state_is_fresh and not virtual_fresh_public_prev: + initial_receiver_count = int(recurrent_prev.shape[1]) + initial_recurrent_k = recurrent_prev.new_zeros(batch_size, initial_receiver_count, self.head_dim) + initial_recurrent_v = recurrent_prev.new_zeros(batch_size, initial_receiver_count, self.value_dim) + state_sender_k = None if backend_population_state_is_fresh else state.get("sender_k") + state_sender_v = None if backend_population_state_is_fresh else state.get("sender_v") + if torch.is_tensor(state_sender_k) and tuple(state_sender_k.shape) == ( + batch_size, + int(self.sender_cell_idx.numel()), + self.head_dim, + ): + initial_recurrent_k = state_sender_k[:, self._recurrent_slice, :] + if torch.is_tensor(state_sender_v) and tuple(state_sender_v.shape) == ( + batch_size, + int(self.sender_cell_idx.numel()), + self.value_dim, + ): + initial_recurrent_v = state_sender_v[:, self._recurrent_slice, :] + if allow_batch_tiling: + forward_batch_tile_len = self._backend_forward_batch_tile_len( + boundary_seq=boundary_seq, + packed_state=packed_state, + initial_hidden=recurrent_prev, + initial_recurrent_k=initial_recurrent_k, + initial_recurrent_v=initial_recurrent_v, + ) + if 0 < forward_batch_tile_len < int(batch_size): + return self._execute_backend_supported_sequence_surface_batch_tiled( + batch_tile_len=forward_batch_tile_len, + batch_tile_reason=self._last_backend_forward_batch_tile_reason, + state=state, + boundary_seq=boundary_seq, + projected_boundary_source_seq=projected_boundary_source_seq, + projected_boundary_weight=projected_boundary_weight, + projected_boundary_bias=projected_boundary_bias, + static_tensors=static_tensors, + population_resets=population_resets, + input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, + input_group_input_to_kv_weight=input_group_input_to_kv_weight, + backend_population_name=backend_population_name, + backend_population_state_is_fresh=backend_population_state_is_fresh, + materialize_final_state=materialize_final_state, + grad_path=grad_path, + selected_backend_surface=selected_backend_surface, + planned_backend_execution=planned_backend_execution, + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + ) + population_resets_active = population_resets is not None + with torch.profiler.record_function("fabric.glue.normalized_population_resets"): + normalized_population_resets = ( + torch.zeros( + batch_size, + boundary_seq.shape[1], + device=boundary_seq.device, + dtype=torch.bool, + ) + if population_resets is None + else population_resets.to(device=boundary_seq.device, dtype=torch.bool) + ).contiguous() + graph_capture_cache_hit = False + graph_capture_replayed = False + input_k_last: torch.Tensor | None = None + input_v_last: torch.Tensor | None = None + self._last_backend_tape_chunk_len = None + self._last_backend_tape_chunk_reason = None + self._last_backend_tape_artifact_mode = None + self._last_backend_recompute_artifact_window_len = None + self._last_backend_recompute_artifact_window_reason = None + self._last_backend_recompute_checkpoint_stride = None + self._last_backend_recompute_checkpoint_count = None + self._last_backend_recompute_checkpoint_reason = None + self._last_backend_recompute_checkpoint_artifact_cache_mode = None + self._last_backend_recompute_predecessor_cache_mode = None + self._last_backend_recompute_transition_tape_mode = None + self._last_backend_recompute_transition_tape_reason = None + 
self._last_backend_recompute_payload_max_bytes = None + self._last_backend_recompute_payload_max_window_len = None + self._last_backend_recompute_payload_max_mode = None + self._last_backend_recompute_payload_sample_count = None + self._last_backend_recompute_public_kv_materialization_mode = None + self._last_backend_recompute_target_state_materialization_mode = None + self._last_backend_recompute_checkpoint_source = None + self._last_backend_recompute_checkpoint_hidden_carry_mode = None + self._last_backend_forward_carry_checkpoints = None + self._last_backend_backward_batch_tile_len = None + self._last_backend_backward_batch_tile_reason = None + self._last_backend_backward_active_receiver_window = None + self._last_backend_backward_active_receiver_window_reason = None + if use_backend_tape_policy: + ( + output_seq, + next_packed_state, + recurrent_hidden, + recurrent_k, + recurrent_v, + input_k_last, + input_v_last, + graph_capture_cache_hit, + graph_capture_replayed, + ) = self._execute_backend_sequence_with_tape_policy( + boundary_seq=boundary_seq, + projected_boundary_source_seq=projected_boundary_source_seq, + projected_boundary_weight=projected_boundary_weight, + projected_boundary_bias=projected_boundary_bias, + packed_state=packed_state, + initial_hidden=recurrent_prev, + initial_recurrent_k=initial_recurrent_k, + initial_recurrent_v=initial_recurrent_v, + initial_state_is_fresh=backend_population_state_is_fresh, + population_resets=normalized_population_resets, + population_resets_active=population_resets_active, + input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, + input_group_input_to_kv_weight=input_group_input_to_kv_weight, + backend_population_name=backend_population_name, + selected_backend_surface=selected_backend_surface, + planned_backend_execution=planned_backend_execution, + planned_backend_backward_execution=cast( + PlannedFabricBackwardExecution, planned_backend_backward_execution + ), + static_tensors=static_tensors, + materialize_final_state=materialize_final_state, + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + ) + elif use_backend_graph_capture: + if grad_path: + ( + output_seq, + next_packed_state, + recurrent_hidden, + recurrent_k, + recurrent_v, + input_k_last, + input_v_last, + graph_capture_cache_hit, + graph_capture_replayed, + ) = self._execute_or_capture_backend_training_sequence_surface( + boundary_seq=boundary_seq, + projected_boundary_source_seq=projected_boundary_source_seq, + projected_boundary_weight=projected_boundary_weight, + projected_boundary_bias=projected_boundary_bias, + packed_state=packed_state, + initial_hidden=recurrent_prev, + initial_recurrent_k=initial_recurrent_k, + initial_recurrent_v=initial_recurrent_v, + initial_state_is_fresh=backend_population_state_is_fresh, + population_resets=normalized_population_resets, + population_resets_active=population_resets_active, + selected_backend_surface=selected_backend_surface, + planned_backend_execution=planned_backend_execution, + planned_backend_backward_execution=cast( + PlannedFabricBackwardExecution, planned_backend_backward_execution + ), + enable_graph_capture=True, + materialize_final_state=materialize_final_state, + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + ) + else: + ( + output_seq, + next_packed_state, + recurrent_hidden, + recurrent_k, + recurrent_v, + input_k_seq, + input_v_seq, + graph_capture_cache_hit, + graph_capture_replayed, + ) = 
self._execute_or_capture_backend_sequence_surface( + boundary_seq=boundary_seq, + packed_state=packed_state, + initial_hidden=recurrent_prev, + initial_recurrent_k=initial_recurrent_k, + initial_recurrent_v=initial_recurrent_v, + initial_state_is_fresh=backend_population_state_is_fresh, + population_resets=normalized_population_resets, + input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, + input_group_input_to_kv_weight=input_group_input_to_kv_weight, + backend_population_name=backend_population_name, + selected_backend_surface=selected_backend_surface, + planned_backend_execution=planned_backend_execution, + static_tensors=static_tensors, + materialize_final_state=materialize_final_state, + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + ) + if output_boundary == "terminal": + output_seq = output_seq[:, -1:] + input_k_last = input_k_seq[:, -1] + input_v_last = input_v_seq[:, -1] + elif grad_path: + ( + output_seq, + next_packed_state, + recurrent_hidden, + recurrent_k, + recurrent_v, + input_k_last, + input_v_last, + graph_capture_cache_hit, + graph_capture_replayed, + ) = self._execute_or_capture_backend_training_sequence_surface( + boundary_seq=boundary_seq, + projected_boundary_source_seq=projected_boundary_source_seq, + projected_boundary_weight=projected_boundary_weight, + projected_boundary_bias=projected_boundary_bias, + packed_state=packed_state, + initial_hidden=recurrent_prev, + initial_recurrent_k=initial_recurrent_k, + initial_recurrent_v=initial_recurrent_v, + initial_state_is_fresh=backend_population_state_is_fresh, + population_resets=normalized_population_resets, + population_resets_active=population_resets_active, + selected_backend_surface=selected_backend_surface, + planned_backend_execution=planned_backend_execution, + planned_backend_backward_execution=cast( + PlannedFabricBackwardExecution, planned_backend_backward_execution + ), + static_tensors=static_tensors, + enable_graph_capture=False, + materialize_final_state=materialize_final_state, + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + ) + else: + output_seq, next_packed_state, recurrent_hidden, recurrent_k, recurrent_v, input_k_seq, input_v_seq = ( + self._execute_compiler_temporal_sequence_surface( + population_name=backend_population_name, + boundary_seq=boundary_seq, + packed_state=packed_state, + initial_hidden=recurrent_prev, + initial_recurrent_k=initial_recurrent_k, + initial_recurrent_v=initial_recurrent_v, + initial_state_is_fresh=backend_population_state_is_fresh, + materialize_final_state=materialize_final_state, + population_resets=normalized_population_resets, + input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, + input_group_input_to_kv_weight=input_group_input_to_kv_weight, + planned_backend_execution=planned_backend_execution, + population_materialized=cast(dict[str, object | None], static_tensors["population_materialized"]), + static_tensors=static_tensors, + grad_path=grad_path, + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + ) + ) + if output_boundary == "terminal": + output_seq = output_seq[:, -1:] + input_k_last = input_k_seq[:, -1] + input_v_last = input_v_seq[:, -1] + if not materialize_final_state and self._last_backend_launch_metadata is not None: + metadata = dict(self._last_backend_launch_metadata) + metadata["generic_glue_fusion_modes"] = tuple(metadata.get("generic_glue_fusion_modes", ())) + ( + 
"final_state_materialization:skipped_by_request", + ) + metadata["workspace_aliases"] = tuple(metadata.get("workspace_aliases", ())) + ( + "final_state=not_materialized", + ) + self._last_backend_launch_metadata = metadata + self._record_backend_execution( + surface=selected_backend_surface, + plan=planned_backend_execution, + backward_plan=planned_backend_backward_execution, + batch_size=batch_size, + time_steps=boundary_seq.shape[1], + inner_steps=1, + training=grad_path, + graph_capture_replayed=graph_capture_replayed, + graph_capture_cache_hit=graph_capture_cache_hit, + ) + if output_boundary == "terminal" and self._last_backend_execution is not None: + record = self._last_backend_execution + self._last_backend_execution = replace( + record, + workspace_aliases=record.workspace_aliases + + ( + "sequence_output_boundary:terminal_step", + "sequence_output_materialization:terminal_step_only", + ), + ) + next_state = TensorDict({}, batch_size=[]) + if materialize_final_state: + last_boundary_step = boundary_seq[:, -1] + last_output_cells = output_seq[:, -1] + with torch.profiler.record_function("fabric.glue.materialize_next_state_cat"): + next_state["cells"] = torch.cat((last_boundary_step, recurrent_hidden, last_output_cells), dim=1) + assert input_k_last is not None and input_v_last is not None + with torch.profiler.record_function("fabric.glue.materialize_next_state_cat"): + next_state["sender_k"] = torch.cat((input_k_last, recurrent_k), dim=1) + with torch.profiler.record_function("fabric.glue.materialize_next_state_cat"): + next_state["sender_v"] = torch.cat((input_v_last, recurrent_v), dim=1) + next_state[population_name] = self._backend_state_to_population_state( + population_name, + cast(Mapping[str, torch.Tensor], next_packed_state), + ) + return output_seq, next_state diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/surface.py b/src/cortical/fabric/backend/cuda/sequence_surface/surface.py deleted file mode 100644 index 55d858e3..00000000 --- a/src/cortical/fabric/backend/cuda/sequence_surface/surface.py +++ /dev/null @@ -1,2608 +0,0 @@ -from __future__ import annotations - -import importlib -import os -from collections.abc import Callable, Mapping -from contextlib import contextmanager -from dataclasses import replace -from typing import Any, Iterator, Literal, cast - -import torch -from tensordict import TensorDict, TensorDictBase - -from cortical.fabric.backend.cuda.ops import reset_backend_state_rows_cuda, reset_backend_tensors_rows_cuda -from cortical.fabric.backend.cuda.sequence_surface.backward import CudaSequenceBackwardMixin -from cortical.fabric.backend.cuda.sequence_surface.policy import projected_boundary_time_chunk_policy -from cortical.fabric.backend.cuda.sequence_surface.support import ( - _BACKWARD_ATTRIBUTION_MODE_ENV, - _BACKWARD_MODE_ENV, - _BACKWARD_OWNER_TIMING_ENV, - _BackendOwnerTimingCollector, - _BackendSequenceStepArtifacts, - _cat_batch_tree, - _GraphCaptureFallback, - _PhysicalBackwardSequenceExecutor, - _ReceiverWindowSpec, - _slice_batch_tensor, - _slice_batch_tree, - _transition_supports_receiver_local_dependency_window, -) -from cortical.fabric.backend.planner import ( - PlannedFabricBackwardExecution, - PlannedFabricExecution, -) -from cortical.fabric.backend.reuse import ExecutionFamily -from cortical.fabric.backend.surfaces import SupportedSurface -from cortical.fabric.contracts.cells import collect_cell_tensors, reset_backend_tensor_rows -from cortical.fabric.runtime.state import ( - flatten_backend_packed_state as 
_flatten_backend_packed_state, -) -from cortical.fabric.runtime.state import ( - unflatten_backend_packed_state as _unflatten_backend_packed_state, -) - -_STREAMING_OUTPUT_GRAPH_CHUNK_TARGET_BYTES = 128 << 20 -_STREAMING_OUTPUT_GRAPH_CHUNK_MIN_STEPS = 128 - - -class CudaSequenceSurfaceMixin(CudaSequenceBackwardMixin): - @staticmethod - def _clone_backend_carry_value(value: Any) -> Any: - if value is None: - return None - if torch.is_tensor(value): - return value.clone() - if isinstance(value, TensorDictBase): - return TensorDict( - {key: CudaSequenceSurfaceMixin._clone_backend_carry_value(item) for key, item in value.items()}, - batch_size=value.batch_size, - device=value.device, - ) - if isinstance(value, dict): - return {key: CudaSequenceSurfaceMixin._clone_backend_carry_value(item) for key, item in value.items()} - return value - - @staticmethod - def _cuda_usable_memory_info(device: torch.device) -> tuple[int, int, int, int]: - free_bytes, total_bytes = torch.cuda.mem_get_info(device) - try: - reserved_bytes = int(torch.cuda.memory_reserved(device)) - allocated_bytes = int(torch.cuda.memory_allocated(device)) - except RuntimeError: - reserved_bytes = 0 - allocated_bytes = 0 - reusable_reserved_bytes = max(0, reserved_bytes - allocated_bytes) - usable_bytes = int(free_bytes) + int(reusable_reserved_bytes) - return int(usable_bytes), int(total_bytes), int(free_bytes), int(reusable_reserved_bytes) - - def _backend_owner_timing_enabled(self, device: torch.device) -> bool: - if device.type != "cuda": - return False - return os.environ.get(_BACKWARD_OWNER_TIMING_ENV, "").lower() in {"1", "true", "yes", "on"} - - def _begin_backend_owner_timing(self, device: torch.device) -> None: - if not self._backend_owner_timing_enabled(device): - self._active_backend_owner_timing = None - return - self._active_backend_owner_timing = _BackendOwnerTimingCollector(device=device, events=[]) - - def _finish_backend_owner_timing(self) -> None: - collector = getattr(self, "_active_backend_owner_timing", None) - self._active_backend_owner_timing = None - if collector is None: - return - timing_summary = collector.summary() - wall_summary = collector.wall_summary() - if not timing_summary and not wall_summary: - return - record = getattr(self, "_last_backend_execution", None) - if record is None: - self._last_backend_owner_timing_ms = timing_summary - self._last_backend_owner_wall_ms = wall_summary - return - self._last_backend_execution = replace( - record, - backward_owner_timing_ms=timing_summary, - backward_owner_wall_ms=wall_summary, - ) - self._last_backend_owner_timing_ms = timing_summary - self._last_backend_owner_wall_ms = wall_summary - - @contextmanager - def _backend_owner_timing(self, name: str) -> Iterator[None]: - collector = getattr(self, "_active_backend_owner_timing", None) - if collector is None: - yield - return - with collector.record(name): - yield - - def _backend_sequence_graph_capture_safe(self) -> bool: - # The active CUDA recurrence path runs through the generic dispatcher-backed - # backend executor and is the intended graph-capture surface for supported - # plans. Do not silently disable capture here and fall back to uncaptured - # replay for surfaces the planner has already marked capture-safe. 
- return True - - def _backend_execution_semantics(self, execution_family: ExecutionFamily) -> tuple[str, str]: - if execution_family in {ExecutionFamily.SEQUENCE_MAJOR, ExecutionFamily.RECEIVER_MAJOR}: - return "receiver_owned", "persistent_scan" - if execution_family == ExecutionFamily.EDGE_MAJOR: - return "edge_owned", "persistent_scan" - raise RuntimeError(f"Unsupported Fabric execution family {execution_family.value}") - - def _record_cuda_launch_metadata( - self, - request: Any, - *, - spatial_ownership: str, - temporal_execution: str, - actual_launch_metadata: Mapping[str, tuple[Any, ...]] | None = None, - ) -> None: - self._last_backend_launch_metadata = { - "receiver_tiles": (int(request.receiver_tile),), - "batch_tiles": (int(request.batch_tile),), - "edge_tiles": (int(request.edge_tile),), - "hidden_chunks": (int(request.hidden_chunk),), - "state_receiver_tiles": (int(request.state_receiver_tile),), - "state_batch_tiles": (int(request.state_batch_tile),), - "state_hidden_chunks": (int(request.state_hidden_chunk),), - "state_static_stage_modes": (str(request.state_static_stage_mode),), - "emit_receiver_tiles": (int(request.emit_receiver_tile),), - "emit_batch_tiles": (int(request.emit_batch_tile),), - "emit_hidden_chunks": (int(request.emit_hidden_chunk),), - "emit_static_stage_modes": (str(request.emit_static_stage_mode),), - "public_receiver_tiles": (int(request.public_receiver_tile),), - "public_batch_tiles": (int(request.public_batch_tile),), - "replication_factors": (int(request.replication_factor),), - "cell_static_stage_modes": (str(request.cell_static_stage_mode),), - "readout_modes": (str(request.readout_mode),), - "workspace_aliases": ("none",), - "temporal_executions": (temporal_execution,), - "scan_implementations": ( - "backend_host_loop" if temporal_execution == "persistent_scan" else "single_step", - ), - "active_receiver_window_modes": ("full_surface",), - "active_receiver_window_offsets": ("0",), - "active_receiver_window_counts": ("0",), - "phases": ( - "receiver_message_aggregate", - "dense_input_projection", - "dense_state_affines", - "receiver_state_update", - "receiver_reduce_stats", - "receiver_emit_raw_public", - "dense_public_projection", - "readout_message_aggregate", - "dense_readout_projection", - ) - if spatial_ownership == "receiver_owned" - else ( - "edge_owned_accumulate", - "receiver_message_normalize", - "dense_input_projection", - "dense_state_affines", - "receiver_state_update", - "receiver_reduce_stats", - "receiver_emit_raw_public", - "dense_public_projection", - "readout_message_aggregate", - "dense_readout_projection", - ), - "input_projection_backends": ("unrun",), - "input_projection_notes": ("unrun",), - "message_projection_boundaries": ("unrun",), - "message_projection_bucket_kinds": ("unrun",), - "message_bucket_count": ("unrun",), - "message_regular_local_bucket_count": ("unrun",), - "message_sparse_bucket_count": ("unrun",), - "message_batched_backend_count": ("unrun",), - "message_grouped_backend_count": ("unrun",), - "message_reset_aware_bucket_count": ("unrun",), - "message_degree_uniform_bucket_count": ("unrun",), - "message_ragged_grouped_bucket_count": ("unrun",), - "message_demoted_bucket_count": ("unrun",), - "message_bucket_signatures": ("unrun",), - "message_bucket_kinds": ("unrun",), - "message_topology_kinds": ("unrun",), - "message_spatial_ownership": ("unrun",), - "message_degree_bucket_lists": ("unrun",), - "message_logit_backends": ("unrun",), - "message_softmax_backends": ("unrun",), - 
"message_weighted_value_backends": ("unrun",), - "message_physical_mode": ("unrun",), - "message_execution_mode": ("unrun",), - "message_output_boundary": ("unrun",), - "message_reset_policies": ("unrun",), - "message_reset_scopes": ("unrun",), - "message_use_delay": ("unrun",), - "message_distance_penalty_kinds": ("unrun",), - "message_epilogue_kinds": ("unrun",), - "message_packed_source_reuse_count": ("unrun",), - "message_demotions": ("unrun",), - "message_workspace_buffers": ("unrun",), - "message_workspace_buffer_bytes": ("unrun",), - "message_workspace_peak_bytes": ("unrun",), - "message_workspace_mode": ("unrun",), - "message_workspace_aliases": ("unrun",), - "message_per_bucket_workspace_bytes": ("unrun",), - "phase_launch_counts": ("unrun",), - "small_cublas_launch_counts": ("unrun",), - "copy_glue_launch_counts": ("unrun",), - "copy_glue_saved_launch_counts": ("unrun",), - "bias_glue_launch_counts": ("unrun",), - "bias_glue_saved_launch_counts": ("unrun",), - "state_epilogue_modes": ("unrun",), - "state_epilogue_saved_launch_counts": ("unrun",), - "launch_coalescing_modes": ("unrun",), - "generic_glue_fusion_modes": ("unrun",), - "launch_granularity_modes": ("unrun",), - "physical_op_kinds": ("unrun",), - "physical_layout_contracts": ("unrun",), - "layout_mode": ("unrun",), - "copy_elision_mode": ("unrun",), - "bias_fusion_mode": ("unrun",), - "physical_op_executors": ("unrun",), - "physical_op_demotions": ("unrun",), - "physical_boundary_contracts": ("unrun",), - "physical_applicability_predicates": ("unrun",), - "physical_workspace_aliases": ("unrun",), - "physical_workspace_peak_bytes": ("unrun",), - "physical_op_launch_counts": ("unrun",), - "physical_op_saved_launch_counts": ("unrun",), - "standalone_copy_kernel_count": ("unrun",), - "standalone_bias_kernel_count": ("unrun",), - "receiver_affine_superop_surface_count": ("unrun",), - "receiver_affine_superop_receivers": ("unrun",), - "receiver_affine_superop_k": ("unrun",), - "receiver_affine_superop_n": ("unrun",), - "receiver_affine_superop_source_layout": ("unrun",), - "receiver_affine_superop_reset_policy": ("unrun",), - "receiver_affine_superop_executor": ("unrun",), - "receiver_affine_superop_physical_mode": ("unrun",), - "receiver_affine_superop_demotion_reason": ("unrun",), - "diagonal_recurrence_superop_surface_count": ("unrun",), - "diagonal_recurrence_kind": ("unrun",), - "diagonal_recurrence_executor": ("unrun",), - "diagonal_recurrence_physical_mode": ("unrun",), - "diagonal_recurrence_coeff_cache_mode": ("unrun",), - "diagonal_recurrence_coeff_cache_hit": ("unrun",), - "diagonal_recurrence_coeff_cache_bytes": ("unrun",), - "diagonal_recurrence_coeff_cache_version_source": ("unrun",), - "diagonal_recurrence_reset_policy": ("unrun",), - "diagonal_recurrence_reset_scope": ("unrun",), - "diagonal_recurrence_output_boundary": ("unrun",), - "diagonal_recurrence_workspace_mode": ("unrun",), - "diagonal_recurrence_workspace_peak_bytes": ("unrun",), - "diagonal_recurrence_demotion_reason": ("unrun",), - "diagonal_recurrence_launch_count": ("unrun",), - "state_affine_backends": ("unrun",), - "state_affine_sources": ("unrun",), - "state_affine_bucket_signatures": ("unrun",), - "state_affine_output_modes": ("unrun",), - "state_affine_reset_policies": ("unrun",), - "state_affine_reset_mode": ("unrun",), - "state_affine_reset_scope": ("unrun",), - "state_affine_workspace_mode": ("unrun",), - "state_affine_receiver_chunk_size": ("unrun",), - "state_affine_receiver_chunks": ("unrun",), - "state_affine_workspace_buffers": 
("unrun",), - "state_affine_workspace_buffer_bytes": ("unrun",), - "state_affine_workspace_bytes": ("unrun",), - "state_affine_reset_rows_present": ("unrun",), - "state_affine_packed_source_reused": ("unrun",), - "public_projection_hidden_backends": ("unrun",), - "public_projection_kv_backends": ("unrun",), - "readout_projection_backends": ("unrun",), - "workspace_buffers": ("unrun",), - "workspace_buffer_bytes": ("unrun",), - "workspace_peak_bytes": ("unrun",), - } - if actual_launch_metadata is not None: - self._last_backend_launch_metadata.update(actual_launch_metadata) - - def _resolve_backend_initial_recurrent_kv( - self, - *, - population_name: str, - initial_hidden: torch.Tensor, - initial_recurrent_k: torch.Tensor | None, - initial_recurrent_v: torch.Tensor | None, - static_tensors: dict[str, object], - active_receiver_window: _ReceiverWindowSpec | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if initial_recurrent_k is None and initial_recurrent_v is None: - public_kind = self._cell_spec_for_population(population_name).public_schema.kind - if public_kind == "hidden": - projection_static_tensors = static_tensors - if ( - active_receiver_window is not None - and active_receiver_window.active - and initial_hidden.dim() >= 2 - and int(initial_hidden.shape[1]) == int(active_receiver_window.count) - ): - projection_static_tensors = self._slice_receiver_window_static_tensors( - static_tensors, - active_receiver_window, - ) - return self._project_sender_kv_from_cells_step( - initial_hidden, - sender_input_to_kv_weight=cast( - torch.Tensor | None, projection_static_tensors["recurrent_sender_input_to_kv_weight"] - ), - grouped_sender_input_to_kv_weight=cast( - torch.Tensor | None, projection_static_tensors["recurrent_group_input_to_kv_weight"] - ), - sender_group_size=1 - if projection_static_tensors.get("recurrent_group_input_to_kv_weight") is None - else self._recurrent_sender_kv_group_size, - ) - if public_kind == "preproj": - batch_size, recurrent_cells, _hidden = initial_hidden.shape - return ( - initial_hidden.new_zeros(batch_size, recurrent_cells, self.head_dim), - initial_hidden.new_zeros(batch_size, recurrent_cells, self.value_dim), - ) - raise RuntimeError(f"Unsupported public schema kind {public_kind} for cell population {population_name}") - if initial_recurrent_k is None or initial_recurrent_v is None: - raise ValueError("initial_recurrent_k and initial_recurrent_v must both be provided or both be None") - return initial_recurrent_k, initial_recurrent_v - - def _backend_sequence_surface_projection_dims( - self, - *, - population_name: str, - static_tensors: dict[str, object], - ) -> tuple[int, int, tuple[torch.Tensor, ...]]: - population_materialized = cast(dict[str, object | None], static_tensors["population_materialized"]) - population_params = population_materialized[population_name] - if not isinstance(population_params, dict): - raise RuntimeError(f"Fabric cell population {population_name} is missing materialized parameters") - tensor_population_params = {key: value for key, value in population_params.items() if torch.is_tensor(value)} - backend_cell_tensors = cast(dict[str, dict[str, torch.Tensor]], static_tensors["backend_cell_tensors"]) - extra_cell_tensors = backend_cell_tensors.get(population_name) - if extra_cell_tensors is None: - raise RuntimeError( - f"Fabric cell population {population_name} is missing backend cell tensor materialization" - ) - cell_spec = self._cell_spec_for_population(population_name) - cell_tensors = collect_cell_tensors(cell_spec, 
extra_cell_tensors, tensor_population_params) - cell_params = tuple(cell_spec.parameter_schema.flatten(cell_tensors)) - input_projection_params = tuple(cell_spec.input_projection_schema.flatten(cell_tensors)) - public_projection_params = tuple(cell_spec.public_projection_schema.flatten(cell_tensors)) - input_weight = input_projection_params[0] - projected_message_dim = int(input_weight.shape[2] if input_weight.dim() == 3 else input_weight.shape[0]) - if cell_spec.public_schema.kind == "hidden": - raw_public_dim = int(self.hidden_size) - else: - raw_public_dim = int(public_projection_params[0].shape[1]) - return projected_message_dim, raw_public_dim, cell_params - - def _can_virtualize_fresh_backend_state( - self, - *, - population_name: str, - static_tensors: dict[str, object], - projected_message_dim: int, - raw_public_dim: int, - cell_params: tuple[torch.Tensor, ...], - ) -> bool: - cell_spec = self._cell_spec_for_population(population_name) - dispatcher_cuda = importlib.import_module("cortical.fabric.backend.cuda.execution.dispatcher_cuda") - return bool( - dispatcher_cuda._load_ext().can_virtualize_fresh_state( - int(cell_spec.cell_kind), - cell_params, - int(self._population_num_cells(population_name)), - int(projected_message_dim), - int(raw_public_dim), - ) - ) - - def _build_backend_sequence_request( - self, - *, - population_name: str, - input_k_seq: torch.Tensor, - input_v_seq: torch.Tensor, - packed_state: Any, - initial_hidden: torch.Tensor, - initial_recurrent_k: torch.Tensor | None, - initial_recurrent_v: torch.Tensor | None, - initial_state_is_fresh: bool, - materialize_final_state: bool, - compact_input_carry: bool = False, - preserve_internal_carry: bool = False, - resets_u8: torch.Tensor, - reset_rows_present: bool, - stage_receiver_static: bool, - replication_factor: int, - receiver_tile: int, - batch_tile: int, - edge_tile: int, - hidden_chunk: int, - state_receiver_tile: int, - state_batch_tile: int, - state_hidden_chunk: int, - state_static_stage_mode: str, - emit_receiver_tile: int, - emit_batch_tile: int, - emit_hidden_chunk: int, - emit_static_stage_mode: str, - public_receiver_tile: int, - public_batch_tile: int, - readout_mode: str, - readout_port_tile: int, - readout_output_chunk: int, - cell_static_stage_mode: str, - message_rule_name: str, - message_rule_lowering_kind: str, - message_rule_expression_signature: str, - message_rule_source_signature: str, - message_rule_parameter_sharing_signature: str, - message_rule_output_boundary: str, - population_materialized: dict[str, object | None], - static_tensors: dict[str, object], - grad_path: bool, - output_boundary: Literal["sequence", "terminal"] = "sequence", - readout_output_boundary: Literal["cells", "pooled"] = "cells", - ) -> Any: - if output_boundary not in {"sequence", "terminal"}: - raise ValueError(f"Unsupported Fabric sequence output boundary {output_boundary!r}") - if readout_output_boundary not in {"cells", "pooled"}: - raise ValueError(f"Unsupported Fabric readout output boundary {readout_output_boundary!r}") - fabric_execution_request_cls = importlib.import_module( - "cortical.fabric.backend.cuda.execution.registry" - ).FabricExecutionRequest - cell_spec = self._cell_spec_for_population(population_name) - population_params = population_materialized[population_name] - if not isinstance(population_params, dict): - raise RuntimeError(f"Fabric cell population {population_name} is missing materialized parameters") - tensor_population_params = {key: value for key, value in population_params.items() if 
torch.is_tensor(value)} - backend_cell_tensors = cast(dict[str, dict[str, torch.Tensor]], static_tensors["backend_cell_tensors"]) - extra_cell_tensors = backend_cell_tensors.get(population_name) - if extra_cell_tensors is None: - raise RuntimeError( - f"Fabric cell population {population_name} is missing backend cell tensor materialization" - ) - cell_tensors = collect_cell_tensors(cell_spec, extra_cell_tensors, tensor_population_params) - uses_sparse_messages = bool(self._uses_sparse_message_backend) - time_steps = int(input_k_seq.shape[1]) - # Streaming Fabric runs a fixed graph/boundary-selected spatial plan for each recurrent step. The terminal - # boundary controls output materialization only; it must not expand receiver ownership as a function of T. - active_receiver_window = None - population_spec = self._backend_population_specs.get(population_name) - if ( - (initial_state_is_fresh or compact_input_carry) - and not materialize_final_state - and not uses_sparse_messages - and self._local_message_step_enabled - and not self._has_edge_delay - and population_spec is not None - and _transition_supports_receiver_local_dependency_window(population_spec.transition_ir) - ): - active_receiver_window = self._fixed_output_dependency_receiver_window( - reason=f"streaming_output_active_region:forward_fixed_window;output_boundary={output_boundary};" - f"time_steps={time_steps}" - ) - output_local_window_start = 0 - output_local_window_count = 0 - output_local_window_contiguous = False - active_receiver_window_mode = "full_surface" - if active_receiver_window is not None and active_receiver_window.active: - output_local_window_start = int(active_receiver_window.start) - output_local_window_count = int(active_receiver_window.count) - output_local_window_contiguous = True - active_receiver_window_mode = active_receiver_window.mode - forward_carry_checkpoint_stride = -1 if bool(grad_path) and time_steps > 1 else 0 - forward_carry_checkpoint_reason = ( - "planner_budgeted:streaming_fixed_spatial_plan" - if forward_carry_checkpoint_stride < 0 - else "disabled:streaming_fixed_spatial_plan" - ) - forward_carry_checkpoint_state_names: tuple[str, ...] 
= () - core_checkpoint_state_names = self._transition_core_state_names_for_population(population_name) - trace_checkpoint_elision_allowed = bool( - bool(grad_path) - and active_receiver_window is not None - and active_receiver_window.active - and core_checkpoint_state_names is not None - and not compact_input_carry - and not preserve_internal_carry - ) - if trace_checkpoint_elision_allowed: - forward_carry_checkpoint_state_names = tuple(core_checkpoint_state_names) - forward_carry_checkpoint_reason = f"{forward_carry_checkpoint_reason};checkpoint_state=core_without_trace" - elif ( - bool(grad_path) - and active_receiver_window is not None - and active_receiver_window.active - and core_checkpoint_state_names is not None - ): - forward_carry_checkpoint_reason = f"{forward_carry_checkpoint_reason};checkpoint_state=full_with_trace" - return fabric_execution_request_cls( - population_name=population_name, - cell_core_spec=cell_spec, - message_backend_name="sparse" if uses_sparse_messages else "local", - message_rule_name=message_rule_name, - message_rule_lowering_kind=message_rule_lowering_kind, - message_rule_expression_signature=message_rule_expression_signature, - message_rule_source_signature=message_rule_source_signature, - message_rule_parameter_sharing_signature=message_rule_parameter_sharing_signature, - message_rule_output_boundary=message_rule_output_boundary, - readout_backend_name="output_sequence_from_sparse_banks" - if uses_sparse_messages - else "output_sequence_from_banks", - gradient_enabled=bool(grad_path), - input_k_seq=input_k_seq, - input_v_seq=input_v_seq, - packed_state=packed_state, - initial_hidden=initial_hidden, - initial_recurrent_k=initial_recurrent_k, - initial_recurrent_v=initial_recurrent_v, - initial_state_is_fresh=initial_state_is_fresh, - materialize_final_state=materialize_final_state, - resets_u8=resets_u8, - reset_rows_present=reset_rows_present, - stage_receiver_static=stage_receiver_static, - replication_factor=replication_factor, - receiver_tile=receiver_tile, - batch_tile=batch_tile, - edge_tile=edge_tile, - hidden_chunk=hidden_chunk, - state_receiver_tile=state_receiver_tile, - state_batch_tile=state_batch_tile, - state_hidden_chunk=state_hidden_chunk, - state_static_stage_mode=state_static_stage_mode, - emit_receiver_tile=emit_receiver_tile, - emit_batch_tile=emit_batch_tile, - emit_hidden_chunk=emit_hidden_chunk, - emit_static_stage_mode=emit_static_stage_mode, - public_receiver_tile=public_receiver_tile, - public_batch_tile=public_batch_tile, - readout_mode=readout_mode, - readout_port_tile=readout_port_tile, - readout_output_chunk=readout_output_chunk, - cell_static_stage_mode=cell_static_stage_mode, - routing_tensors={ - "recurrent_q": cast(torch.Tensor, static_tensors["recurrent_q"]), - "output_q": cast(torch.Tensor, static_tensors["output_q"]), - "recurrent_local_sender_idx": self.recurrent_local_sender_idx, - "recurrent_local_valid": self.recurrent_local_valid, - "recurrent_local_receiver_idx_by_sender": self.recurrent_local_receiver_idx_by_sender, - "output_local_sender_idx": self.output_local_sender_idx, - "output_local_valid": self.output_local_valid, - "output_local_receiver_idx_by_sender": self.output_local_receiver_idx_by_sender, - "local_distance": self.local_distance, - "local_delay": self.local_delay, - "recurrent_neighbor_idx": self.recurrent_neighbor_idx, - "recurrent_neighbor_valid": self.recurrent_neighbor_valid, - "recurrent_edge_distance": self.recurrent_edge_distance, - "recurrent_edge_delay": self.recurrent_edge_delay, - 
"recurrent_sparse_receiver_order": self.recurrent_sparse_receiver_order, - "recurrent_sparse_degree_ptr": self.recurrent_sparse_degree_ptr, - "output_neighbor_idx": self.output_neighbor_idx, - "output_neighbor_valid": self.output_neighbor_valid, - "output_edge_distance": self.output_edge_distance, - "output_edge_delay": self.output_edge_delay, - }, - cell_tensors=cell_tensors, - readout_tensors={ - "output_projection_weight": cast(torch.Tensor, static_tensors["value_to_output_weight"]), - "output_projection_bias": self.output_cell_bias, - }, - static_config={ - "distance_scale": float(self.config.distance_logit_scale), - "use_delay": bool(self._has_edge_delay), - "sender_group_size": int(self._recurrent_sender_kv_group_size), - "recurrent_sparse_positive_degree_buckets": int(self._recurrent_sparse_positive_degree_buckets), - "output_local_recurrent_window_start": output_local_window_start, - "output_local_recurrent_window_count": output_local_window_count, - "output_local_recurrent_window_contiguous": output_local_window_contiguous, - "output_sparse_recurrent_window_start": int(self._output_sparse_recurrent_window_start), - "output_sparse_recurrent_window_count": int(self._output_sparse_recurrent_window_count), - "output_sparse_recurrent_window_contiguous": bool(self._output_sparse_recurrent_window_contiguous), - "output_boundary": output_boundary, - "readout_pool": str(self.config.readout_pool) if readout_output_boundary == "pooled" else "cells", - "readout_slots": int(self.readout_slots) if readout_output_boundary == "pooled" else 0, - "readout_output_boundary": readout_output_boundary, - "terminal_active_receiver_window_mode": active_receiver_window_mode, - "forward_carry_checkpoint_stride": int(forward_carry_checkpoint_stride), - "forward_carry_checkpoint_reason": forward_carry_checkpoint_reason, - "forward_carry_checkpoint_state_names": forward_carry_checkpoint_state_names, - }, - compact_input_carry=compact_input_carry, - preserve_internal_carry=preserve_internal_carry, - ) - - def _supports_cuda_backend_sequence_surface( - self, - *, - k: int | torch.Tensor | None, - device: torch.device, - dtype: torch.dtype, - ) -> bool: - route = self._plan_sequence_surface_route( - k=k, - device=device, - dtype=dtype, - ) - return route.supported - - def _execute_backend_sequence_surface( - self, - *, - state: TensorDict, - boundary_seq: torch.Tensor, - projected_boundary_source_seq: torch.Tensor | None = None, - projected_boundary_weight: torch.Tensor | None = None, - projected_boundary_bias: torch.Tensor | None = None, - static_tensors: dict[str, object], - population_resets: torch.Tensor | None, - input_sender_input_to_kv_weight: torch.Tensor | None, - input_group_input_to_kv_weight: torch.Tensor | None, - backend_population_name: str | None, - backend_population_state_is_fresh: bool, - materialize_final_state: bool, - grad_path: bool, - selected_backend_surface: SupportedSurface, - planned_backend_execution: PlannedFabricExecution, - output_boundary: Literal["sequence", "terminal"] = "sequence", - readout_output_boundary: Literal["cells", "pooled"] = "cells", - ) -> tuple[torch.Tensor, TensorDict]: - if backend_population_name is not None: - return self._execute_backend_supported_sequence_surface( - state=state, - boundary_seq=boundary_seq, - projected_boundary_source_seq=projected_boundary_source_seq, - projected_boundary_weight=projected_boundary_weight, - projected_boundary_bias=projected_boundary_bias, - static_tensors=static_tensors, - population_resets=population_resets, - 
input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, - input_group_input_to_kv_weight=input_group_input_to_kv_weight, - backend_population_name=backend_population_name, - backend_population_state_is_fresh=backend_population_state_is_fresh, - materialize_final_state=materialize_final_state, - grad_path=grad_path, - selected_backend_surface=selected_backend_surface, - planned_backend_execution=planned_backend_execution, - output_boundary=output_boundary, - readout_output_boundary=readout_output_boundary, - ) - raise ValueError(f"Unsupported backend surface {selected_backend_surface.key}") - - def _projected_boundary_time_chunk_len( - self, - *, - projected_boundary_source_seq: torch.Tensor, - projected_boundary_weight: torch.Tensor, - projected_boundary_bias: torch.Tensor | None = None, - packed_state: Any, - initial_hidden: torch.Tensor, - initial_recurrent_k: torch.Tensor | None, - initial_recurrent_v: torch.Tensor | None, - output_boundary: Literal["sequence", "terminal"], - readout_output_boundary: Literal["cells", "pooled"], - ) -> int: - del projected_boundary_bias - time_steps = int(projected_boundary_source_seq.shape[1]) - if time_steps <= 1 or projected_boundary_source_seq.device.type != "cuda": - self._last_projected_boundary_time_chunk_reason = ( - f"time_steps={time_steps};device={projected_boundary_source_seq.device.type};chunk_len={time_steps}" - ) - return time_steps - dtype_bytes = int(projected_boundary_source_seq.element_size()) - projected_features = int(projected_boundary_weight.shape[0]) - state_bytes = 0 - if packed_state is not None: - _state_keys, state_tensors = _flatten_backend_packed_state(packed_state) - state_bytes += sum(self._tensor_storage_bytes(tensor) for tensor in state_tensors) - state_bytes += self._tensor_storage_bytes(initial_hidden) - state_bytes += self._tensor_storage_bytes(initial_recurrent_k) - state_bytes += self._tensor_storage_bytes(initial_recurrent_v) - output_ports = int(self.readout_slots) if readout_output_boundary == "pooled" else int(self._num_output_cells) - decision = projected_boundary_time_chunk_policy( - time_steps=time_steps, - batch_size=int(projected_boundary_source_seq.shape[0]), - dtype_bytes=dtype_bytes, - projected_features=projected_features, - state_bytes=state_bytes, - output_ports=output_ports, - hidden_size=int(self.hidden_size), - output_boundary=output_boundary, - memory=self._cuda_memory_budget(projected_boundary_source_seq.device), - graph_chunk_target_bytes=_STREAMING_OUTPUT_GRAPH_CHUNK_TARGET_BYTES, - graph_chunk_min_steps=_STREAMING_OUTPUT_GRAPH_CHUNK_MIN_STEPS, - ) - self._last_projected_boundary_time_chunk_reason = decision.reason - return int(decision.value) - - def _execute_backend_projected_source_sequence_surface( - self, - *, - state: TensorDict, - projected_boundary_source_seq: torch.Tensor, - projected_boundary_weight: torch.Tensor, - projected_boundary_bias: torch.Tensor | None = None, - static_tensors: dict[str, object], - population_resets: torch.Tensor | None, - input_sender_input_to_kv_weight: torch.Tensor | None, - input_group_input_to_kv_weight: torch.Tensor | None, - backend_population_name: str | None, - backend_population_state_is_fresh: bool, - materialize_final_state: bool, - grad_path: bool, - selected_backend_surface: SupportedSurface, - planned_backend_execution: PlannedFabricExecution, - output_boundary: Literal["sequence", "terminal"] = "sequence", - readout_output_boundary: Literal["cells", "pooled"] = "cells", - output_chunk_consumer: Callable[[torch.Tensor, int, int], None] 
| None = None, - detach_internal_carry_after_output_chunk: bool = False, - ) -> tuple[torch.Tensor, TensorDict]: - if output_boundary not in {"sequence", "terminal"}: - raise ValueError(f"Unsupported Fabric sequence output boundary {output_boundary!r}") - if readout_output_boundary not in {"cells", "pooled"}: - raise ValueError(f"Unsupported Fabric readout output boundary {readout_output_boundary!r}") - if backend_population_name is None: - raise RuntimeError("Supported Fabric sequence surface requires a recurrent backend cell population") - if projected_boundary_source_seq.dim() != 3: - raise ValueError("Projected Fabric boundary source must be shaped [B,T,H]") - batch_size = int(projected_boundary_source_seq.shape[0]) - time_steps = int(projected_boundary_source_seq.shape[1]) - projected_message_dim, raw_public_dim, cell_params = self._backend_sequence_surface_projection_dims( - population_name=backend_population_name, - static_tensors=static_tensors, - ) - planned_backend_backward_execution = ( - self.plan_backend_backward_execution( - batch_size=batch_size, - time_steps=time_steps, - inner_steps=1, - training=True, - tape_policy=self._tape_policy_from_bin(planned_backend_execution.tape_policy_bin), - device=projected_boundary_source_seq.device, - surface_key=selected_backend_surface.key, - ) - if grad_path - else None - ) - terminal_dependency_receiver_count = None - if ( - backend_population_state_is_fresh - and not materialize_final_state - and self._local_message_step_enabled - and not self._has_edge_delay - and not bool(getattr(self, "_uses_sparse_message_backend", False)) - ): - terminal_dependency_receiver_count = self._fresh_output_dependency_receiver_count( - population_name=backend_population_name, - time_steps=time_steps, - fresh_state_virtualized=True, - ) - fixed_output_active_region = terminal_dependency_receiver_count is not None - state_buffers_needed_by_sequence = materialize_final_state or ( - time_steps > 1 and not fixed_output_active_region - ) - fresh_state_virtualized = backend_population_state_is_fresh and self._can_virtualize_fresh_backend_state( - population_name=backend_population_name, - static_tensors=static_tensors, - projected_message_dim=projected_message_dim, - raw_public_dim=raw_public_dim, - cell_params=cell_params, - ) - defer_fresh_backend_state = backend_population_state_is_fresh and ( - fresh_state_virtualized or fixed_output_active_region - ) - fresh_zero_sentinel_prev = ( - backend_population_state_is_fresh and not state_buffers_needed_by_sequence and time_steps <= 1 - ) - recurrent_prev = ( - projected_boundary_source_seq.new_empty(batch_size, 0, self.hidden_size) - if fresh_zero_sentinel_prev - else projected_boundary_source_seq.new_zeros( - batch_size, - int(terminal_dependency_receiver_count), - self.hidden_size, - ) - if backend_population_state_is_fresh - and fixed_output_active_region - and terminal_dependency_receiver_count is not None - else projected_boundary_source_seq.new_zeros(batch_size, self._num_recurrent_cells, self.hidden_size) - if backend_population_state_is_fresh - else state["cells"][:, self._recurrent_slice, :] - ) - if defer_fresh_backend_state or fresh_zero_sentinel_prev: - packed_state = None - elif backend_population_state_is_fresh: - packed_state = self._init_backend_population_state( - backend_population_name, - batch=batch_size, - device=projected_boundary_source_seq.device, - dtype=projected_boundary_source_seq.dtype, - ) - else: - packed_state = self._population_state_to_backend_state( - backend_population_name, - 
cast(TensorDictBase, state[backend_population_name]), - ) - initial_recurrent_k = None - initial_recurrent_v = None - if backend_population_state_is_fresh and recurrent_prev.numel() > 0: - initial_receiver_count = int(recurrent_prev.shape[1]) - initial_recurrent_k = recurrent_prev.new_zeros(batch_size, initial_receiver_count, self.head_dim) - initial_recurrent_v = recurrent_prev.new_zeros(batch_size, initial_receiver_count, self.value_dim) - state_sender_k = None if backend_population_state_is_fresh else state.get("sender_k") - state_sender_v = None if backend_population_state_is_fresh else state.get("sender_v") - if torch.is_tensor(state_sender_k) and tuple(state_sender_k.shape) == ( - batch_size, - int(self.sender_cell_idx.numel()), - self.head_dim, - ): - initial_recurrent_k = state_sender_k[:, self._recurrent_slice, :] - if torch.is_tensor(state_sender_v) and tuple(state_sender_v.shape) == ( - batch_size, - int(self.sender_cell_idx.numel()), - self.value_dim, - ): - initial_recurrent_v = state_sender_v[:, self._recurrent_slice, :] - normalized_population_resets = ( - torch.zeros( - batch_size, - time_steps, - device=projected_boundary_source_seq.device, - dtype=torch.bool, - ) - if population_resets is None - else population_resets.to(device=projected_boundary_source_seq.device, dtype=torch.bool) - ).contiguous() - chunk_len = self._projected_boundary_time_chunk_len( - projected_boundary_source_seq=projected_boundary_source_seq, - projected_boundary_weight=projected_boundary_weight, - projected_boundary_bias=projected_boundary_bias, - packed_state=packed_state, - initial_hidden=recurrent_prev, - initial_recurrent_k=initial_recurrent_k, - initial_recurrent_v=initial_recurrent_v, - output_boundary=output_boundary, - readout_output_boundary=readout_output_boundary, - ) - tape_policy = self._tape_policy_from_bin(planned_backend_execution.tape_policy_bin) - outputs: list[torch.Tensor] = [] - running_packed_state = packed_state - running_hidden = recurrent_prev - running_recurrent_k = initial_recurrent_k - running_recurrent_v = initial_recurrent_v - running_state_is_fresh = backend_population_state_is_fresh - graph_capture_cache_hit = False - graph_capture_replayed = False - last_input_k: torch.Tensor | None = None - last_input_v: torch.Tensor | None = None - last_boundary_step: torch.Tensor | None = None - last_output_chunk: torch.Tensor | None = None - for start in range(0, time_steps, chunk_len): - end = min(time_steps, start + chunk_len) - source_chunk = projected_boundary_source_seq[:, start:end] - boundary_chunk = self._project_boundary_source_sequence( - source_chunk, - input_projection_weight=projected_boundary_weight, - input_projection_bias=projected_boundary_bias, - ) - chunk_plan = self.plan_backend_execution( - batch_size=batch_size, - time_steps=end - start, - inner_steps=1, - training=grad_path, - tape_policy=tape_policy, - device=projected_boundary_source_seq.device, - surface_key=selected_backend_surface.key, - ) - chunk_backward_plan = ( - self.plan_backend_backward_execution( - batch_size=batch_size, - time_steps=end - start, - inner_steps=1, - training=True, - tape_policy=tape_policy, - device=projected_boundary_source_seq.device, - surface_key=selected_backend_surface.key, - ) - if grad_path - else None - ) - chunk_materialize_final_state = bool(materialize_final_state) - chunk_output_boundary: Literal["sequence", "terminal"] = ( - output_boundary if end == time_steps else "sequence" - ) - if grad_path: - chunk_compact_input_carry = bool(fixed_output_active_region and 
not running_state_is_fresh) - ( - output_chunk, - running_packed_state, - running_hidden, - running_recurrent_k, - running_recurrent_v, - last_input_k, - last_input_v, - chunk_cache_hit, - chunk_replayed, - ) = self._execute_or_capture_backend_training_sequence_surface( - boundary_seq=boundary_chunk, - projected_boundary_source_seq=source_chunk, - projected_boundary_weight=projected_boundary_weight, - projected_boundary_bias=projected_boundary_bias, - packed_state=running_packed_state, - initial_hidden=running_hidden, - initial_recurrent_k=running_recurrent_k, - initial_recurrent_v=running_recurrent_v, - initial_state_is_fresh=running_state_is_fresh, - population_resets=normalized_population_resets[:, start:end], - population_resets_active=population_resets is not None, - selected_backend_surface=selected_backend_surface, - planned_backend_execution=chunk_plan, - planned_backend_backward_execution=cast(PlannedFabricBackwardExecution, chunk_backward_plan), - static_tensors=static_tensors, - enable_graph_capture=False, - materialize_final_state=chunk_materialize_final_state, - compact_input_carry=chunk_compact_input_carry, - preserve_internal_carry=end < time_steps, - output_boundary=chunk_output_boundary, - readout_output_boundary=readout_output_boundary, - ) - else: - chunk_compact_input_carry = bool(fixed_output_active_region and not running_state_is_fresh) - ( - output_chunk, - running_packed_state, - running_hidden, - running_recurrent_k, - running_recurrent_v, - input_k_seq, - input_v_seq, - chunk_cache_hit, - chunk_replayed, - ) = self._execute_or_capture_backend_sequence_surface( - backend_population_name=backend_population_name, - boundary_seq=boundary_chunk, - packed_state=running_packed_state, - initial_hidden=running_hidden, - initial_recurrent_k=running_recurrent_k, - initial_recurrent_v=running_recurrent_v, - initial_state_is_fresh=running_state_is_fresh, - population_resets=normalized_population_resets[:, start:end], - input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, - input_group_input_to_kv_weight=input_group_input_to_kv_weight, - static_tensors=static_tensors, - selected_backend_surface=selected_backend_surface, - planned_backend_execution=chunk_plan, - materialize_final_state=chunk_materialize_final_state, - compact_input_carry=chunk_compact_input_carry, - preserve_internal_carry=end < time_steps, - output_boundary=chunk_output_boundary, - readout_output_boundary=readout_output_boundary, - ) - last_input_k = input_k_seq[:, -1] - last_input_v = input_v_seq[:, -1] - running_state_is_fresh = False - graph_capture_cache_hit = graph_capture_cache_hit or chunk_cache_hit - graph_capture_replayed = graph_capture_replayed or chunk_replayed - last_boundary_step = boundary_chunk[:, -1] - if output_boundary == "sequence" or end == time_steps: - last_output_chunk = output_chunk - if output_chunk_consumer is None: - outputs.append(output_chunk) - else: - output_chunk_consumer(output_chunk, start, end) - if detach_internal_carry_after_output_chunk and end < time_steps: - running_packed_state = self._detach_backend_static_tensors(running_packed_state) - running_hidden = running_hidden.detach() - running_recurrent_k = None if running_recurrent_k is None else running_recurrent_k.detach() - running_recurrent_v = None if running_recurrent_v is None else running_recurrent_v.detach() - if outputs: - output_seq = torch.cat(outputs, dim=1) - elif last_output_chunk is not None: - if output_boundary == "terminal": - output_seq = last_output_chunk[:, -1:] - else: - output_seq = 
last_output_chunk.new_empty((batch_size, 0, *tuple(last_output_chunk.shape[2:]))) - else: - output_ports = ( - int(self.readout_slots) if readout_output_boundary == "pooled" else int(self._num_output_cells) - ) - output_seq = projected_boundary_source_seq.new_empty((batch_size, 0, output_ports, self.hidden_size)) - self._record_backend_execution( - surface=selected_backend_surface, - plan=planned_backend_execution, - backward_plan=planned_backend_backward_execution, - batch_size=batch_size, - time_steps=time_steps, - inner_steps=1, - training=grad_path, - graph_capture_replayed=graph_capture_replayed, - graph_capture_cache_hit=graph_capture_cache_hit, - ) - record = self._last_backend_execution - if record is not None: - chunk_reason = getattr(self, "_last_projected_boundary_time_chunk_reason", None) - self._last_backend_execution = replace( - record, - workspace_aliases=record.workspace_aliases - + ( - f"projected_boundary_time_chunk_len:t={int(chunk_len)}", - "projected_boundary_sequence_executor:streaming_chunked", - *((f"projected_boundary_time_chunk_reason:{chunk_reason}",) if chunk_reason else ()), - ), - ) - if output_boundary == "terminal" and self._last_backend_execution is not None: - record = self._last_backend_execution - self._last_backend_execution = replace( - record, - workspace_aliases=record.workspace_aliases - + ( - "sequence_output_boundary:terminal_step", - "sequence_output_materialization:terminal_step_only", - ), - ) - if not materialize_final_state: - return output_seq, TensorDict({}, batch_size=[]) - if last_boundary_step is None or last_input_k is None or last_input_v is None: - raise RuntimeError("Projected Fabric sequence executor is missing final carry tensors") - next_state = TensorDict({}, batch_size=[]) - last_output_cells = output_seq[:, -1] - with torch.profiler.record_function("fabric.glue.materialize_next_state_cat"): - next_state["cells"] = torch.cat((last_boundary_step, running_hidden, last_output_cells), dim=1) - with torch.profiler.record_function("fabric.glue.materialize_next_state_cat"): - next_state["sender_k"] = torch.cat((last_input_k, running_recurrent_k), dim=1) - with torch.profiler.record_function("fabric.glue.materialize_next_state_cat"): - next_state["sender_v"] = torch.cat((last_input_v, running_recurrent_v), dim=1) - next_state[backend_population_name] = self._backend_state_to_population_state( - backend_population_name, - cast(Mapping[str, torch.Tensor], running_packed_state), - ) - return output_seq, next_state - - def _run_backend_sequence_surface_once( - self, - *, - population_name: str, - boundary_seq: torch.Tensor, - packed_state: Any, - initial_hidden: torch.Tensor, - initial_recurrent_k: torch.Tensor | None, - initial_recurrent_v: torch.Tensor | None, - initial_state_is_fresh: bool, - materialize_final_state: bool, - compact_input_carry: bool = False, - preserve_internal_carry: bool = False, - population_resets: torch.Tensor, - input_sender_input_to_kv_weight: torch.Tensor | None, - input_group_input_to_kv_weight: torch.Tensor | None, - planned_backend_execution: PlannedFabricExecution, - population_materialized: dict[str, object | None], - static_tensors: dict[str, object], - grad_path: bool, - output_boundary: Literal["sequence", "terminal"] = "sequence", - readout_output_boundary: Literal["cells", "pooled"] = "cells", - ) -> tuple[torch.Tensor, Any, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - if output_boundary not in {"sequence", "terminal"}: - raise ValueError(f"Unsupported Fabric sequence 
output boundary {output_boundary!r}") - cuda_execution = importlib.import_module("cortical.fabric.backend.cuda.execution") - cuda_recurrence_executor = importlib.import_module("cortical.fabric.backend.cuda.recurrence_executor") - input_k_seq, input_v_seq = self._project_sender_kv_from_cells_sequence( - boundary_seq, - sender_input_to_kv_weight=input_sender_input_to_kv_weight, - grouped_sender_input_to_kv_weight=input_group_input_to_kv_weight, - sender_group_size=self._input_sender_kv_group_size, - ) - execution_families = {bucket_plan.execution_family for bucket_plan in planned_backend_execution.bucket_plans} - math_backends = {bucket_plan.math_backend for bucket_plan in planned_backend_execution.bucket_plans} - if len(execution_families) != 1 or len(math_backends) != 1: - raise RuntimeError( - "Supported Fabric backend received a mixed execution-family/math-backend plan, " - f"execution_families={sorted(family.value for family in execution_families)} " - f"math={sorted(backend.value for backend in math_backends)}" - ) - execution_family = next(iter(execution_families)) - math_backend = next(iter(math_backends)) - message_rule_names = {bucket_plan.message_rule_name for bucket_plan in planned_backend_execution.bucket_plans} - message_rule_lowerings = { - bucket_plan.message_rule_lowering_kind for bucket_plan in planned_backend_execution.bucket_plans - } - message_rule_expression_signatures = { - bucket_plan.message_rule_expression_signature for bucket_plan in planned_backend_execution.bucket_plans - } - message_rule_source_signatures = { - bucket_plan.message_rule_source_signature for bucket_plan in planned_backend_execution.bucket_plans - } - message_rule_parameter_sharing_signatures = { - bucket_plan.message_rule_parameter_sharing_signature - for bucket_plan in planned_backend_execution.bucket_plans - } - message_rule_output_boundaries = { - bucket_plan.message_rule_output_boundary for bucket_plan in planned_backend_execution.bucket_plans - } - if ( - len(message_rule_names) != 1 - or len(message_rule_lowerings) != 1 - or len(message_rule_expression_signatures) != 1 - or len(message_rule_source_signatures) != 1 - or len(message_rule_parameter_sharing_signatures) != 1 - or len(message_rule_output_boundaries) != 1 - ): - raise RuntimeError("Supported Fabric CUDA launch requires one message-rule contract per launch") - launch_plan = cuda_recurrence_executor._single_launch_plan(planned_backend_execution) - stage_receiver_static, replication_factor = cuda_recurrence_executor.backend_surface_launch_policy( - self, - population_name=population_name, - planned_backend_execution=planned_backend_execution, - ) - batch_size, time_steps = input_k_seq.shape[:2] - batch_rows_per_block = launch_plan.state_batch_tile - effective_replication = cuda_recurrence_executor.effective_staged_replication_factor( - stage_receiver_static=stage_receiver_static, - replication_factor=replication_factor, - batch_size=batch_size, - batch_rows_per_block=batch_rows_per_block, - ) - reset_rows_present = bool(population_resets.any().item()) - request_packed_state = packed_state - request_initial_hidden = initial_hidden - request_initial_recurrent_k = initial_recurrent_k - request_initial_recurrent_v = initial_recurrent_v - if int(time_steps) > 1: - with torch.profiler.record_function("fabric.glue.preserve_sequence_carry_inputs"): - request_packed_state = self._clone_backend_carry_value(packed_state) - request_initial_hidden = cast(torch.Tensor, self._clone_backend_carry_value(initial_hidden)) - request_initial_recurrent_k = 
cast( - torch.Tensor | None, - self._clone_backend_carry_value(initial_recurrent_k), - ) - request_initial_recurrent_v = cast( - torch.Tensor | None, - self._clone_backend_carry_value(initial_recurrent_v), - ) - request = self._build_backend_sequence_request( - population_name=population_name, - input_k_seq=input_k_seq, - input_v_seq=input_v_seq, - packed_state=request_packed_state, - initial_hidden=request_initial_hidden, - initial_recurrent_k=request_initial_recurrent_k, - initial_recurrent_v=request_initial_recurrent_v, - initial_state_is_fresh=initial_state_is_fresh, - materialize_final_state=materialize_final_state, - compact_input_carry=compact_input_carry, - preserve_internal_carry=preserve_internal_carry, - resets_u8=population_resets.to(dtype=torch.uint8), - reset_rows_present=reset_rows_present, - stage_receiver_static=stage_receiver_static, - replication_factor=effective_replication, - receiver_tile=launch_plan.receiver_tile, - batch_tile=launch_plan.batch_tile, - edge_tile=launch_plan.edge_tile, - hidden_chunk=launch_plan.hidden_chunk, - state_receiver_tile=launch_plan.state_receiver_tile, - state_batch_tile=launch_plan.state_batch_tile, - state_hidden_chunk=launch_plan.state_hidden_chunk, - state_static_stage_mode=launch_plan.state_static_stage_mode if stage_receiver_static else "disabled", - emit_receiver_tile=launch_plan.emit_receiver_tile, - emit_batch_tile=launch_plan.emit_batch_tile, - emit_hidden_chunk=launch_plan.emit_hidden_chunk, - emit_static_stage_mode=launch_plan.emit_static_stage_mode if stage_receiver_static else "disabled", - public_receiver_tile=launch_plan.public_receiver_tile, - public_batch_tile=launch_plan.public_batch_tile, - readout_mode=launch_plan.readout_mode, - readout_port_tile=launch_plan.readout_port_tile, - readout_output_chunk=launch_plan.readout_output_chunk, - cell_static_stage_mode=launch_plan.cell_static_stage_mode if stage_receiver_static else "disabled", - message_rule_name=next(iter(message_rule_names)), - message_rule_lowering_kind=next(iter(message_rule_lowerings)), - message_rule_expression_signature=next(iter(message_rule_expression_signatures)), - message_rule_source_signature=next(iter(message_rule_source_signatures)), - message_rule_parameter_sharing_signature=next(iter(message_rule_parameter_sharing_signatures)), - message_rule_output_boundary=next(iter(message_rule_output_boundaries)), - population_materialized=population_materialized, - static_tensors=static_tensors, - grad_path=grad_path, - output_boundary=output_boundary, - readout_output_boundary=readout_output_boundary, - ) - spatial_ownership, temporal_execution = self._backend_execution_semantics(execution_family) - request = cuda_execution.normalize_launch_request(request) - output_seq, next_packed_state, recurrent_hidden, recurrent_k, recurrent_v = ( - cuda_execution.run_registered_execution( - spatial_ownership=cast(Any, spatial_ownership), - temporal_execution=cast(Any, temporal_execution), - math_backend=math_backend, - request=request, - ) - ) - self._record_cuda_launch_metadata( - request, - spatial_ownership=spatial_ownership, - temporal_execution=temporal_execution, - actual_launch_metadata=cuda_execution.last_launch_metadata(), - ) - self._last_backend_forward_carry_checkpoints = cuda_execution.last_forward_carry_checkpoints() - return output_seq, next_packed_state, recurrent_hidden, recurrent_k, recurrent_v, input_k_seq, input_v_seq - - def _run_backend_sequence_surface_step_primitives( - self, - *, - population_name: str, - boundary_step: torch.Tensor, - 
packed_state: Any, - initial_hidden: torch.Tensor, - initial_recurrent_k: torch.Tensor | None, - initial_recurrent_v: torch.Tensor | None, - population_resets: torch.Tensor | None, - input_sender_input_to_kv_weight: torch.Tensor | None, - input_group_input_to_kv_weight: torch.Tensor | None, - static_tensors: dict[str, object], - ) -> tuple[torch.Tensor, Any, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - artifacts = self._compute_backend_sequence_surface_step_artifacts( - population_name=population_name, - boundary_step=boundary_step, - packed_state=packed_state, - initial_hidden=initial_hidden, - initial_recurrent_k=initial_recurrent_k, - initial_recurrent_v=initial_recurrent_v, - population_resets=population_resets, - input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, - input_group_input_to_kv_weight=input_group_input_to_kv_weight, - static_tensors=static_tensors, - ) - return ( - artifacts.output_cells, - artifacts.next_packed_state, - artifacts.recurrent_hidden, - artifacts.recurrent_k, - artifacts.recurrent_v, - artifacts.input_k, - artifacts.input_v, - ) - - def _compute_backend_sequence_surface_step_artifacts( - self, - *, - population_name: str, - boundary_step: torch.Tensor, - packed_state: Any, - initial_hidden: torch.Tensor, - initial_recurrent_k: torch.Tensor | None, - initial_recurrent_v: torch.Tensor | None, - population_resets: torch.Tensor | None, - input_sender_input_to_kv_weight: torch.Tensor | None, - input_group_input_to_kv_weight: torch.Tensor | None, - static_tensors: dict[str, object], - materialize_output_artifacts: bool = True, - materialize_transition_backward_tape: bool = False, - materialize_diagonal_preproj_tape: bool = True, - materialize_recurrence_backward_tape: bool | None = None, - materialize_recurrent_kv: bool = True, - materialize_next_state: bool = True, - materialize_trace_state_next: bool = True, - active_receiver_window: _ReceiverWindowSpec | None = None, - artifact_owner_scope: str = "artifact", - ) -> _BackendSequenceStepArtifacts: - if not self._partitioned_layout: - raise RuntimeError("Supported Fabric backend sequence surfaces require the partitioned sender layout") - artifact_owner_scope = artifact_owner_scope.strip(".") or "artifact" - - def artifact_owner_name(name: str) -> str: - return f"{artifact_owner_scope}.{name}" - - recurrent_q = cast(torch.Tensor, static_tensors["recurrent_q"]) - output_q = cast(torch.Tensor, static_tensors["output_q"]) - value_to_output_weight = cast(torch.Tensor, static_tensors["value_to_output_weight"]) - recurrent_packed_state = packed_state - recurrent_hidden_prev = initial_hidden - if active_receiver_window is None: - inferred_window = self._fixed_output_dependency_receiver_window( - reason="streaming_output_active_region:checkpoint_compact_carry" - ) - compact_receiver_count: int | None = None - if torch.is_tensor(recurrent_hidden_prev) and recurrent_hidden_prev.dim() >= 2: - hidden_receivers = int(recurrent_hidden_prev.shape[1]) - if hidden_receivers > 0: - compact_receiver_count = hidden_receivers - if compact_receiver_count is None and recurrent_packed_state is not None: - _packed_state_keys, packed_state_tensors = _flatten_backend_packed_state(recurrent_packed_state) - for tensor in packed_state_tensors: - if torch.is_tensor(tensor) and tensor.dim() >= 2 and int(tensor.shape[1]) > 0: - compact_receiver_count = int(tensor.shape[1]) - break - if ( - inferred_window is not None - and inferred_window.active - and compact_receiver_count == int(inferred_window.count) - ): 
- active_receiver_window = inferred_window - reset_bank_mask: torch.Tensor | None = None - if population_resets is not None: - reset_bank_mask = torch.as_tensor( - population_resets, - device=boundary_step.device, - dtype=torch.bool, - ).view(boundary_step.shape[0], 1, 1) - recurrent_hidden_prev = reset_backend_tensor_rows( - recurrent_hidden_prev, - reset_bank_mask.view(boundary_step.shape[0]), - ) - recurrent_packed_state = reset_backend_state_rows_cuda( - recurrent_packed_state, - torch.as_tensor(population_resets, device=boundary_step.device, dtype=torch.bool).view( - boundary_step.shape[0] - ), - ) - with self._backend_owner_timing(artifact_owner_name("input_projection")): - input_k, input_v = self._project_sender_kv_from_cells_step( - boundary_step, - sender_input_to_kv_weight=input_sender_input_to_kv_weight, - grouped_sender_input_to_kv_weight=input_group_input_to_kv_weight, - sender_group_size=self._input_sender_kv_group_size, - ) - recurrent_k_prev, recurrent_v_prev = self._resolve_backend_initial_recurrent_kv( - population_name=population_name, - initial_hidden=recurrent_hidden_prev, - initial_recurrent_k=initial_recurrent_k, - initial_recurrent_v=initial_recurrent_v, - static_tensors=static_tensors, - active_receiver_window=active_receiver_window, - ) - if reset_bank_mask is not None: - reset_rows = reset_bank_mask.view(boundary_step.shape[0]) - recurrent_k_prev, recurrent_v_prev = reset_backend_tensors_rows_cuda( - (recurrent_k_prev, recurrent_v_prev), - reset_rows, - ) - compact_recurrent_senders = self._receiver_window_compacts_recurrent_senders(active_receiver_window) - recurrent_q_for_message = self._slice_receiver_window_rows(recurrent_q, active_receiver_window) - recurrent_k_for_message = self._slice_receiver_window_recurrent_bank( - recurrent_k_prev, - active_receiver_window, - ) - recurrent_v_for_message = self._slice_receiver_window_recurrent_bank( - recurrent_v_prev, - active_receiver_window, - ) - recurrent_local_sender_idx = self._cached_receiver_window_sender_table( - name="recurrent", - table=self.recurrent_local_sender_idx, - window=active_receiver_window, - num_input_senders=int(self._num_input_cells), - slice_receivers=True, - compact_recurrent_senders=compact_recurrent_senders, - ) - recurrent_sender_count = ( - int(self._num_input_cells) + int(active_receiver_window.count) - if compact_recurrent_senders and active_receiver_window is not None - else int(self.sender_cell_idx.numel()) - ) - recurrent_local_receiver_idx_by_sender = ( - self.recurrent_local_receiver_idx_by_sender - if active_receiver_window is None or not active_receiver_window.active - else self._cached_sender_reverse_table( - name="recurrent", - receiver_sender_idx=recurrent_local_sender_idx, - num_senders=recurrent_sender_count, - ) - ) - recurrent_neighbor_idx = ( - self.recurrent_neighbor_idx - if active_receiver_window is None or not active_receiver_window.active - else self._cached_receiver_window_static_rows( - name="recurrent_neighbor_idx", - tensor=self.recurrent_neighbor_idx, - window=active_receiver_window, - ) - ) - recurrent_neighbor_valid = ( - self.recurrent_neighbor_valid - if active_receiver_window is None or not active_receiver_window.active - else self._cached_receiver_window_static_rows( - name="recurrent_neighbor_valid", - tensor=self.recurrent_neighbor_valid, - window=active_receiver_window, - ) - ) - recurrent_edge_distance = ( - self.recurrent_edge_distance - if active_receiver_window is None or not active_receiver_window.active - else 
self._cached_receiver_window_static_rows( - name="recurrent_edge_distance", - tensor=self.recurrent_edge_distance, - window=active_receiver_window, - ) - ) - recurrent_edge_delay = ( - self.recurrent_edge_delay - if active_receiver_window is None or not active_receiver_window.active - else self._cached_receiver_window_static_rows( - name="recurrent_edge_delay", - tensor=self.recurrent_edge_delay, - window=active_receiver_window, - ) - ) - with self._backend_owner_timing(artifact_owner_name("recurrent_message")): - recurrent_msg = self._compute_messages_step_subset_partitioned_raw( - input_k, - input_v, - recurrent_k_for_message, - recurrent_v_for_message, - q_subset=recurrent_q_for_message, - neighbor_idx=recurrent_neighbor_idx, - neighbor_valid=recurrent_neighbor_valid, - edge_distance=recurrent_edge_distance, - edge_delay=recurrent_edge_delay, - use_delay=self._has_edge_delay, - step_idx=1, - local_sender_idx=recurrent_local_sender_idx, - local_receiver_idx_by_sender=recurrent_local_receiver_idx_by_sender, - ) - transition_packed_state = self._slice_receiver_window_batch_rows( - recurrent_packed_state, - active_receiver_window, - ) - hidden_before_for_artifact = ( - self._slice_receiver_window_batch_rows( - initial_hidden, - active_receiver_window, - ) - if initial_recurrent_k is None or initial_recurrent_v is None - else None - ) - recurrent_k_before_for_artifact = ( - self._slice_receiver_window_recurrent_bank( - initial_recurrent_k, - active_receiver_window, - ) - if initial_recurrent_k is not None - else None - ) - recurrent_v_before_for_artifact = ( - self._slice_receiver_window_recurrent_bank( - initial_recurrent_v, - active_receiver_window, - ) - if initial_recurrent_v is not None - else None - ) - transition_static_tensors = self._cached_receiver_window_static_tensors( - static_tensors, - active_receiver_window, - ) - with self._backend_owner_timing(artifact_owner_name("transition")): - transition_result = self._lower_backend_population_transition_forward_result_shared( - recurrent_msg=recurrent_msg, - packed_state_before=transition_packed_state, - population_reset_step=None, - static_tensors=transition_static_tensors, - materialize_recurrent_kv=bool(materialize_recurrent_kv or materialize_output_artifacts), - materialize_backward_tape=materialize_transition_backward_tape, - materialize_diagonal_preproj_tape=materialize_diagonal_preproj_tape, - materialize_recurrence_backward_tape=materialize_recurrence_backward_tape, - materialize_next_state=materialize_next_state, - materialize_trace_state_next=materialize_trace_state_next, - ) - next_packed_state = transition_result.next_packed_state - recurrent_hidden = transition_result.recurrent_hidden - recurrent_k = transition_result.recurrent_k - recurrent_v = transition_result.recurrent_v - output_msg = None - output_cells = None - if materialize_output_artifacts: - output_local_sender_idx = self._cached_receiver_window_sender_table( - name="output", - table=self.output_local_sender_idx, - window=active_receiver_window, - num_input_senders=int(self._num_input_cells), - slice_receivers=False, - compact_recurrent_senders=True, - ) - output_local_receiver_idx_by_sender = ( - self.output_local_receiver_idx_by_sender - if active_receiver_window is None or not active_receiver_window.active - else self._cached_sender_reverse_table( - name="output", - receiver_sender_idx=output_local_sender_idx, - num_senders=int(self._num_input_cells) + int(active_receiver_window.count), - ) - ) - with 
self._backend_owner_timing(artifact_owner_name("output_message")): - if recurrent_k is None or recurrent_v is None: - raise RuntimeError("Fabric output-message artifact requires materialized recurrent K/V") - output_msg = self._compute_messages_step_subset_partitioned_raw( - input_k, - input_v, - recurrent_k, - recurrent_v, - q_subset=output_q, - neighbor_idx=self.output_neighbor_idx, - neighbor_valid=self.output_neighbor_valid, - edge_distance=self.output_edge_distance, - edge_delay=self.output_edge_delay, - use_delay=self._has_edge_delay, - step_idx=1, - local_sender_idx=output_local_sender_idx, - local_receiver_idx_by_sender=output_local_receiver_idx_by_sender, - ) - with self._backend_owner_timing(artifact_owner_name("output_projection")): - output_cells = self._project_output_cells_step_raw( - output_msg, - value_to_output_weight=value_to_output_weight, - ).to(dtype=boundary_step.dtype) - return _BackendSequenceStepArtifacts( - boundary_step=boundary_step, - population_reset_step=None - if population_resets is None - else torch.as_tensor(population_resets, device=boundary_step.device), - packed_state_before=transition_packed_state, - hidden_before=hidden_before_for_artifact, - recurrent_k_before=recurrent_k_before_for_artifact, - recurrent_v_before=recurrent_v_before_for_artifact, - input_k=input_k, - input_v=input_v, - recurrent_k_bank=recurrent_k_for_message, - recurrent_v_bank=recurrent_v_for_message, - recurrent_msg=recurrent_msg, - next_packed_state=next_packed_state, - recurrent_hidden=recurrent_hidden, - recurrent_k=recurrent_k, - recurrent_v=recurrent_v, - transition_backward_tape=transition_result.backward_tape, - output_msg=output_msg, - output_cells=output_cells, - active_receiver_window=active_receiver_window, - ) - - def _execute_or_capture_backend_sequence_surface( - self, - *, - backend_population_name: str, - boundary_seq: torch.Tensor, - packed_state: Any, - initial_hidden: torch.Tensor, - initial_recurrent_k: torch.Tensor | None, - initial_recurrent_v: torch.Tensor | None, - initial_state_is_fresh: bool, - population_resets: torch.Tensor, - input_sender_input_to_kv_weight: torch.Tensor | None, - input_group_input_to_kv_weight: torch.Tensor | None, - static_tensors: dict[str, object], - selected_backend_surface: SupportedSurface, - planned_backend_execution: PlannedFabricExecution, - materialize_final_state: bool, - compact_input_carry: bool = False, - preserve_internal_carry: bool = False, - output_boundary: Literal["sequence", "terminal"] = "sequence", - readout_output_boundary: Literal["cells", "pooled"] = "cells", - ) -> tuple[torch.Tensor, Any, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, bool, bool]: - if output_boundary not in {"sequence", "terminal"}: - raise ValueError(f"Unsupported Fabric sequence output boundary {output_boundary!r}") - if readout_output_boundary not in {"cells", "pooled"}: - raise ValueError(f"Unsupported Fabric readout output boundary {readout_output_boundary!r}") - input_layout, graph_inputs = self._build_backend_graph_inputs( - boundary_seq=boundary_seq, - packed_state=packed_state, - initial_hidden=initial_hidden, - population_resets=population_resets, - initial_recurrent_k=initial_recurrent_k, - initial_recurrent_v=initial_recurrent_v, - packed_state_is_fresh=initial_state_is_fresh, - ) - shape_signature = self._graph_shape_signature( - graph_inputs=graph_inputs, - ) - graph_key = self._backend_graph_capture_key( - surface=selected_backend_surface, - plan=planned_backend_execution, - 
shape_signature=shape_signature, - ) - cached = self._backend_graph_capture_cache.get(graph_key) - cache_hit = cached is not None - if not cache_hit: - self._backend_graph_capture_cache.put( - graph_key, - _GraphCaptureFallback(key=graph_key, shape_signature=shape_signature), - ) - output_seq, next_packed_state, recurrent_hidden, recurrent_k, recurrent_v, input_k_seq, input_v_seq = ( - self._run_backend_sequence_surface_once( - population_name=backend_population_name, - boundary_seq=boundary_seq, - packed_state=packed_state, - initial_hidden=initial_hidden, - initial_recurrent_k=initial_recurrent_k, - initial_recurrent_v=initial_recurrent_v, - initial_state_is_fresh=initial_state_is_fresh, - materialize_final_state=materialize_final_state, - compact_input_carry=compact_input_carry, - preserve_internal_carry=preserve_internal_carry, - population_resets=population_resets, - input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, - input_group_input_to_kv_weight=input_group_input_to_kv_weight, - planned_backend_execution=planned_backend_execution, - population_materialized=cast(dict[str, object | None], static_tensors["population_materialized"]), - static_tensors=static_tensors, - grad_path=False, - output_boundary=output_boundary, - readout_output_boundary=readout_output_boundary, - ) - ) - return ( - output_seq, - next_packed_state, - recurrent_hidden, - recurrent_k, - recurrent_v, - input_k_seq, - input_v_seq, - cache_hit, - True, - ) - - def _execute_backend_sequence_with_tape_policy( - self, - *, - backend_population_name: str, - boundary_seq: torch.Tensor, - projected_boundary_source_seq: torch.Tensor | None, - projected_boundary_weight: torch.Tensor | None, - projected_boundary_bias: torch.Tensor | None, - packed_state: Any, - initial_hidden: torch.Tensor, - initial_recurrent_k: torch.Tensor | None, - initial_recurrent_v: torch.Tensor | None, - population_resets: torch.Tensor, - population_resets_active: bool, - initial_state_is_fresh: bool, - input_sender_input_to_kv_weight: torch.Tensor | None, - input_group_input_to_kv_weight: torch.Tensor | None, - static_tensors: dict[str, object], - selected_backend_surface: SupportedSurface, - planned_backend_execution: PlannedFabricExecution, - planned_backend_backward_execution: PlannedFabricBackwardExecution, - materialize_final_state: bool, - output_boundary: Literal["sequence", "terminal"] = "sequence", - readout_output_boundary: Literal["cells", "pooled"] = "cells", - ) -> tuple[torch.Tensor, Any, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, bool, bool]: - chunk_len = self._backend_tape_checkpoint_chunk_len( - plan=planned_backend_execution, - time_steps=int(boundary_seq.shape[1]), - output_boundary=output_boundary, - boundary_seq=boundary_seq, - packed_state=packed_state, - initial_hidden=initial_hidden, - initial_recurrent_k=initial_recurrent_k, - initial_recurrent_v=initial_recurrent_v, - ) - self._backend_backward_batch_tile_len( - boundary_seq=boundary_seq, - packed_state=packed_state, - initial_hidden=initial_hidden, - initial_recurrent_k=initial_recurrent_k, - initial_recurrent_v=initial_recurrent_v, - output_boundary=output_boundary, - ) - use_training_graph_capture = ( - self._should_use_backend_graph_capture( - plan=planned_backend_execution, - device=boundary_seq.device, - grad_path=True, - time_steps=int(boundary_seq.shape[1]), - ) - and self._backend_sequence_graph_capture_safe() - ) - del chunk_len - ( - output_seq, - running_packed_state, - running_hidden, - running_recurrent_k, - 
running_recurrent_v, - last_input_k, - last_input_v, - graph_capture_cache_hit, - graph_capture_replayed, - ) = self._execute_or_capture_backend_training_sequence_surface( - boundary_seq=boundary_seq, - projected_boundary_source_seq=projected_boundary_source_seq, - projected_boundary_weight=projected_boundary_weight, - projected_boundary_bias=projected_boundary_bias, - packed_state=packed_state, - initial_hidden=initial_hidden, - initial_recurrent_k=initial_recurrent_k, - initial_recurrent_v=initial_recurrent_v, - initial_state_is_fresh=initial_state_is_fresh, - population_resets=population_resets, - population_resets_active=population_resets_active, - selected_backend_surface=selected_backend_surface, - planned_backend_execution=planned_backend_execution, - planned_backend_backward_execution=planned_backend_backward_execution, - static_tensors=static_tensors, - enable_graph_capture=use_training_graph_capture, - materialize_final_state=materialize_final_state, - output_boundary=output_boundary, - readout_output_boundary=readout_output_boundary, - ) - assert last_input_k is not None and last_input_v is not None - assert running_recurrent_k is not None and running_recurrent_v is not None - return ( - output_seq, - running_packed_state, - running_hidden, - running_recurrent_k, - running_recurrent_v, - last_input_k, - last_input_v, - graph_capture_cache_hit, - graph_capture_replayed, - ) - - def _execute_or_capture_backend_training_sequence_surface( - self, - *, - boundary_seq: torch.Tensor, - projected_boundary_source_seq: torch.Tensor | None = None, - projected_boundary_weight: torch.Tensor | None = None, - projected_boundary_bias: torch.Tensor | None = None, - packed_state: Any, - initial_hidden: torch.Tensor, - initial_recurrent_k: torch.Tensor | None, - initial_recurrent_v: torch.Tensor | None, - population_resets: torch.Tensor, - population_resets_active: bool = True, - selected_backend_surface: SupportedSurface, - planned_backend_execution: PlannedFabricExecution, - planned_backend_backward_execution: PlannedFabricBackwardExecution, - static_tensors: dict[str, object] | None, - enable_graph_capture: bool, - initial_state_is_fresh: bool = False, - materialize_final_state: bool = True, - compact_input_carry: bool = False, - preserve_internal_carry: bool = False, - output_boundary: Literal["sequence", "terminal"] = "sequence", - readout_output_boundary: Literal["cells", "pooled"] = "cells", - ) -> tuple[torch.Tensor, Any, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, bool, bool]: - if output_boundary not in {"sequence", "terminal"}: - raise ValueError(f"Unsupported Fabric training output boundary {output_boundary!r}") - if readout_output_boundary not in {"cells", "pooled"}: - raise ValueError(f"Unsupported Fabric readout output boundary {readout_output_boundary!r}") - projected_boundary_active = ( - projected_boundary_source_seq is not None - or projected_boundary_weight is not None - or projected_boundary_bias is not None - ) - if projected_boundary_active and (projected_boundary_source_seq is None or projected_boundary_weight is None): - raise RuntimeError("Projected Fabric boundary surface requires both source sequence and projection weight") - input_layout, graph_inputs = self._build_backend_graph_inputs( - boundary_seq=boundary_seq, - packed_state=packed_state, - initial_hidden=initial_hidden, - population_resets=population_resets, - initial_recurrent_k=initial_recurrent_k, - initial_recurrent_v=initial_recurrent_v, - 
packed_state_is_fresh=initial_state_is_fresh and packed_state is None, - projected_boundary_source_seq=projected_boundary_source_seq, - projected_boundary_weight=projected_boundary_weight, - projected_boundary_bias=projected_boundary_bias, - ) - shape_signature = self._graph_shape_signature(graph_inputs=graph_inputs) - graph_key = self._backend_graph_capture_key( - surface=selected_backend_surface, - plan=planned_backend_execution, - shape_signature=shape_signature, - ) - if static_tensors is None: - static_tensors = self._materialize_inference_static_tensors( - device=boundary_seq.device, - dtype=boundary_seq.dtype, - ) - trainable_param_items = tuple( - (name, parameter) for name, parameter in self.named_parameters() if parameter.requires_grad - ) - trainable_param_names = tuple(name for name, _parameter in trainable_param_items) - trainable_params = tuple(parameter for _name, parameter in trainable_param_items) - backend_population_name = self._select_output_cells_stream_backend_population( - k=1, - ) - if backend_population_name is None: - raise RuntimeError( - f"Supported Fabric {selected_backend_surface.cell_type} " - "training surface requires a callable backend sequence engine" - ) - - def packed_state_output_keys() -> tuple[str, ...] | None: - if input_layout.packed_state_is_fresh and not input_layout.packed_state_shapes: - if not materialize_final_state and not preserve_internal_carry: - return () - return self._cell_spec_for_population(backend_population_name).state_schema.keys - return input_layout.packed_state_keys - - def packed_state_output_count() -> int: - keys = packed_state_output_keys() - return 0 if keys is None else len(keys) - - def unpack_outputs( - graph_outputs: tuple[torch.Tensor, ...], - ) -> tuple[ - torch.Tensor, - Any, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - ]: - packed_state_keys = packed_state_output_keys() - packed_state_count = packed_state_output_count() - output_seq = graph_outputs[0] - next_packed_state = ( - None - if packed_state_keys is not None and packed_state_count == 0 - else _unflatten_backend_packed_state( - packed_state_keys, - tuple(graph_outputs[1 : 1 + packed_state_count]), - ) - ) - recurrent_hidden = graph_outputs[1 + packed_state_count] - recurrent_k = graph_outputs[2 + packed_state_count] - recurrent_v = graph_outputs[3 + packed_state_count] - input_k_last = graph_outputs[4 + packed_state_count] - input_v_last = graph_outputs[5 + packed_state_count] - return ( - output_seq, - next_packed_state, - recurrent_hidden, - recurrent_k, - recurrent_v, - input_k_last, - input_v_last, - ) - - def run_training_sequence_surface( - current_graph_inputs: dict[str, torch.Tensor], - ) -> tuple[torch.Tensor, ...]: - current_packed_state, current_recurrent_k, current_recurrent_v = self._unpack_backend_graph_inputs( - input_layout=input_layout, - graph_inputs=current_graph_inputs, - ) - current_boundary_seq = current_graph_inputs.get("boundary_seq") - if current_boundary_seq is None: - current_source_hidden_seq = current_graph_inputs["projected_boundary_source_seq"] - current_projection_weight = current_graph_inputs["projected_boundary_weight"] - current_projection_bias = current_graph_inputs.get("projected_boundary_bias") - current_boundary_seq = self._project_boundary_source_sequence( - current_source_hidden_seq, - input_projection_weight=current_projection_weight, - input_projection_bias=current_projection_bias, - ) - output_seq, next_packed_state, recurrent_hidden, recurrent_k, recurrent_v, input_k_seq, 
input_v_seq = ( - self._run_backend_sequence_surface_once( - population_name=backend_population_name, - boundary_seq=current_boundary_seq, - packed_state=current_packed_state, - initial_hidden=current_graph_inputs["initial_hidden"], - initial_recurrent_k=current_recurrent_k, - initial_recurrent_v=current_recurrent_v, - initial_state_is_fresh=initial_state_is_fresh, - materialize_final_state=materialize_final_state, - compact_input_carry=compact_input_carry, - preserve_internal_carry=preserve_internal_carry, - population_resets=current_graph_inputs["population_resets"], - input_sender_input_to_kv_weight=cast( - torch.Tensor | None, static_tensors["input_sender_input_to_kv_weight"] - ), - input_group_input_to_kv_weight=cast( - torch.Tensor | None, static_tensors["input_group_input_to_kv_weight"] - ), - planned_backend_execution=planned_backend_execution, - population_materialized=cast(dict[str, object | None], static_tensors["population_materialized"]), - static_tensors=static_tensors, - grad_path=True, - output_boundary=output_boundary, - readout_output_boundary=readout_output_boundary, - ) - ) - if output_boundary == "terminal": - output_seq = output_seq[:, -1:] - if not materialize_final_state and not preserve_internal_carry: - dummy = current_boundary_seq.new_empty(0) - return ( - output_seq, - *(dummy for _shape in input_layout.packed_state_shapes), - dummy, - dummy, - dummy, - dummy, - dummy, - ) - next_packed_state_keys, next_packed_state_inputs = _flatten_backend_packed_state(next_packed_state) - if next_packed_state_keys != packed_state_output_keys(): - raise RuntimeError("Backend training graph capture must preserve packed-state structure") - return ( - output_seq, - *next_packed_state_inputs, - recurrent_hidden, - recurrent_k, - recurrent_v, - input_k_seq[:, -1], - input_v_seq[:, -1], - ) - - capture_state = {"cache_hit": False, "replayed": False} - if enable_graph_capture: - cached = self._backend_graph_capture_cache.get(graph_key) - capture_state["cache_hit"] = cached is not None - capture_state["replayed"] = True - if cached is None: - self._backend_graph_capture_cache.put( - graph_key, - _GraphCaptureFallback(key=graph_key, shape_signature=shape_signature), - ) - - class _CapturedTrainingSequenceSurface(torch.autograd.Function): - @staticmethod - def forward(ctx, *forward_inputs: torch.Tensor): # type: ignore[override] - ctx.set_materialize_grads(False) - sequence_inputs = forward_inputs[: len(input_layout.input_names)] - current_graph_inputs = { - name: tensor for name, tensor in zip(input_layout.input_names, sequence_inputs, strict=True) - } - graph_outputs = run_training_sequence_surface(current_graph_inputs) - ctx.sequence_input_requires_grad = tuple( - tensor.requires_grad if tensor.is_floating_point() else False for tensor in sequence_inputs - ) - ctx.forward_carry_checkpoints = getattr(self, "_last_backend_forward_carry_checkpoints", None) - ctx.save_for_backward(*sequence_inputs) - return graph_outputs - - @staticmethod - def backward(ctx, *grad_outputs: torch.Tensor | None): # type: ignore[override] - grad_outputs = tuple( - grad.contiguous() if torch.is_tensor(grad) and not grad.is_contiguous() else grad - for grad in grad_outputs - ) - saved_sequence_inputs = ctx.saved_tensors - graph_inputs = { - name: tensor for name, tensor in zip(input_layout.input_names, saved_sequence_inputs, strict=True) - } - boundary_seq_for_backward = graph_inputs.get("boundary_seq") - if boundary_seq_for_backward is None: - with torch.no_grad(): - source_hidden_seq = 
graph_inputs["projected_boundary_source_seq"] - projection_weight = graph_inputs["projected_boundary_weight"] - projection_bias = graph_inputs.get("projected_boundary_bias") - boundary_seq_for_backward = self._project_boundary_source_sequence( - source_hidden_seq, - input_projection_weight=projection_weight, - input_projection_bias=projection_bias, - ) - packed_state, _unused_recurrent_k, _unused_recurrent_v = self._unpack_backend_graph_inputs( - input_layout=input_layout, - graph_inputs=graph_inputs, - ) - initial_recurrent_k = graph_inputs["initial_recurrent_k"] - initial_recurrent_v = graph_inputs["initial_recurrent_v"] - packed_state_count = packed_state_output_count() - structured_grad_next_packed_state = ( - None - if packed_state_count == 0 - else _unflatten_backend_packed_state( - packed_state_output_keys(), - tuple(grad_outputs[1 : 1 + packed_state_count]), - ) - ) - with torch.profiler.record_function("fabric.backward.total"): - backward_kwargs = { - "boundary_seq": boundary_seq_for_backward, - "projected_boundary_source_seq": graph_inputs.get("projected_boundary_source_seq"), - "projected_boundary_weight": graph_inputs.get("projected_boundary_weight"), - "projected_boundary_bias": graph_inputs.get("projected_boundary_bias"), - "packed_state": packed_state, - "initial_hidden": graph_inputs["initial_hidden"], - "initial_recurrent_k": None if initial_recurrent_k.numel() == 0 else initial_recurrent_k, - "initial_recurrent_v": None if initial_recurrent_v.numel() == 0 else initial_recurrent_v, - "initial_state_is_fresh": initial_state_is_fresh, - "population_resets": graph_inputs["population_resets"] if population_resets_active else None, - "planned_backend_execution": planned_backend_execution, - "planned_backend_backward_execution": planned_backend_backward_execution, - "grad_output_seq": grad_outputs[0], - "grad_next_packed_state": structured_grad_next_packed_state, - "grad_recurrent_hidden": grad_outputs[1 + packed_state_count], - "grad_recurrent_k": grad_outputs[2 + packed_state_count], - "grad_recurrent_v": grad_outputs[3 + packed_state_count], - "grad_input_k_last": grad_outputs[4 + packed_state_count], - "grad_input_v_last": grad_outputs[5 + packed_state_count], - "trainable_params": trainable_params, - "trainable_param_names": trainable_param_names, - "replay_static_tensors": static_tensors, - "output_boundary": output_boundary, - "forward_carry_checkpoints": getattr(ctx, "forward_carry_checkpoints", None), - } - backward_mode = os.environ.get(_BACKWARD_MODE_ENV) - if backward_mode in {None, "", "physical_plan"}: - grad_sequence_map, grad_params = _PhysicalBackwardSequenceExecutor( - runtime=self, - plan=planned_backend_backward_execution, - ).run(**backward_kwargs) - elif os.environ.get(_BACKWARD_ATTRIBUTION_MODE_ENV) == "phase_decomposed_probe": - with torch.profiler.record_function("fabric.backward.phase_decomposed_probe"): - grad_sequence_map, grad_params = self._run_backend_sequence_surface_backward_once( - **backward_kwargs - ) - elif backward_mode in {"full_replay", "reference_replay"}: - grad_sequence_map, grad_params = self._run_backend_sequence_surface_backward_full_replay_once( - **backward_kwargs - ) - else: - raise RuntimeError(f"Unsupported Fabric backward mode {backward_mode!r}") - if backward_mode in {None, "", "physical_plan"}: - self._append_backend_backward_runtime_metadata() - backward_grads: list[torch.Tensor | None] = [] - for name in input_layout.input_names: - backward_grads.append(grad_sequence_map.get(name)) - backward_grads.extend(grad_params) - 
return tuple(backward_grads) - - outputs = _CapturedTrainingSequenceSurface.apply( - *(graph_inputs[name] for name in input_layout.input_names), - *trainable_params, - ) - unpacked = unpack_outputs(cast(tuple[torch.Tensor, ...], outputs)) - return (*unpacked, capture_state["cache_hit"], capture_state["replayed"]) - - def _execute_backend_supported_sequence_surface_batch_tiled( - self, - *, - batch_tile_len: int, - batch_tile_reason: str | None, - state: TensorDict, - boundary_seq: torch.Tensor, - projected_boundary_source_seq: torch.Tensor | None, - projected_boundary_weight: torch.Tensor | None, - projected_boundary_bias: torch.Tensor | None, - static_tensors: dict[str, object], - population_resets: torch.Tensor | None, - input_sender_input_to_kv_weight: torch.Tensor | None, - input_group_input_to_kv_weight: torch.Tensor | None, - backend_population_name: str, - backend_population_state_is_fresh: bool, - materialize_final_state: bool, - grad_path: bool, - selected_backend_surface: SupportedSurface, - planned_backend_execution: PlannedFabricExecution, - output_boundary: Literal["sequence", "terminal"] = "sequence", - readout_output_boundary: Literal["cells", "pooled"] = "cells", - ) -> tuple[torch.Tensor, TensorDict]: - if readout_output_boundary not in {"cells", "pooled"}: - raise ValueError(f"Unsupported Fabric readout output boundary {readout_output_boundary!r}") - output_chunks: list[torch.Tensor] = [] - output_seq: torch.Tensor | None = None - preallocate_output = not grad_path - state_chunks: list[TensorDict] = [] - for start in range(0, int(boundary_seq.shape[0]), int(batch_tile_len)): - end = min(start + int(batch_tile_len), int(boundary_seq.shape[0])) - tile_plan = self.plan_backend_execution( - batch_size=end - start, - time_steps=int(boundary_seq.shape[1]), - inner_steps=1, - training=grad_path, - tape_policy=self._tape_policy_from_bin(planned_backend_execution.tape_policy_bin), - device=boundary_seq.device, - surface_key=selected_backend_surface.key, - ) - output_chunk, state_chunk = self._execute_backend_supported_sequence_surface( - state=state - if backend_population_state_is_fresh - else cast(TensorDict, _slice_batch_tree(state, start, end)), - boundary_seq=boundary_seq[start:end], - projected_boundary_source_seq=_slice_batch_tensor(projected_boundary_source_seq, start, end), - projected_boundary_weight=projected_boundary_weight, - projected_boundary_bias=projected_boundary_bias, - static_tensors=static_tensors, - population_resets=_slice_batch_tensor(population_resets, start, end), - input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, - input_group_input_to_kv_weight=input_group_input_to_kv_weight, - backend_population_name=backend_population_name, - backend_population_state_is_fresh=backend_population_state_is_fresh, - materialize_final_state=materialize_final_state, - grad_path=grad_path, - selected_backend_surface=selected_backend_surface, - planned_backend_execution=tile_plan, - allow_batch_tiling=False, - output_boundary=output_boundary, - readout_output_boundary=readout_output_boundary, - ) - if preallocate_output: - if output_seq is None: - output_seq = output_chunk.new_empty((int(boundary_seq.shape[0]), *output_chunk.shape[1:])) - output_seq[start:end].copy_(output_chunk) - else: - output_chunks.append(output_chunk) - if materialize_final_state: - state_chunks.append(state_chunk) - if output_seq is None: - output_seq = torch.cat(output_chunks, dim=0) - next_state = ( - cast(TensorDict, _cat_batch_tree(cast(list[Any], state_chunks))) - if 
materialize_final_state - else TensorDict({}, batch_size=[]) - ) - if self._last_backend_execution is not None: - record = self._last_backend_execution - self._last_backend_execution = replace( - record, - batch_size=int(boundary_seq.shape[0]), - workspace_aliases=record.workspace_aliases - + ( - f"forward_batch_tile:b={int(batch_tile_len)}", - f"forward_batch_tile_reason:{batch_tile_reason or self._last_backend_forward_batch_tile_reason}", - f"forward_batch_tile_output:{'preallocated' if preallocate_output else 'cat'}", - ), - ) - return output_seq, next_state - - def _execute_backend_supported_sequence_surface( - self, - *, - state: TensorDict, - boundary_seq: torch.Tensor, - projected_boundary_source_seq: torch.Tensor | None = None, - projected_boundary_weight: torch.Tensor | None = None, - projected_boundary_bias: torch.Tensor | None = None, - static_tensors: dict[str, object], - population_resets: torch.Tensor | None, - input_sender_input_to_kv_weight: torch.Tensor | None, - input_group_input_to_kv_weight: torch.Tensor | None, - backend_population_name: str | None, - backend_population_state_is_fresh: bool, - materialize_final_state: bool, - grad_path: bool, - selected_backend_surface: SupportedSurface, - planned_backend_execution: PlannedFabricExecution, - allow_batch_tiling: bool = True, - output_boundary: Literal["sequence", "terminal"] = "sequence", - readout_output_boundary: Literal["cells", "pooled"] = "cells", - ) -> tuple[torch.Tensor, TensorDict]: - if output_boundary not in {"sequence", "terminal"}: - raise ValueError(f"Unsupported Fabric sequence output boundary {output_boundary!r}") - if readout_output_boundary not in {"cells", "pooled"}: - raise ValueError(f"Unsupported Fabric readout output boundary {readout_output_boundary!r}") - batch_size = boundary_seq.shape[0] - if backend_population_name is None: - raise RuntimeError("Supported Fabric sequence surface requires a recurrent backend cell population") - planned_backend_backward_execution = ( - self.plan_backend_backward_execution( - batch_size=batch_size, - time_steps=int(boundary_seq.shape[1]), - inner_steps=1, - training=True, - tape_policy=self._tape_policy_from_bin(planned_backend_execution.tape_policy_bin), - device=boundary_seq.device, - surface_key=selected_backend_surface.key, - ) - if grad_path - else None - ) - population_name = backend_population_name - use_backend_tape_policy = self._should_use_backend_tape_policy( - plan=planned_backend_execution, - grad_path=grad_path, - time_steps=int(boundary_seq.shape[1]), - ) - use_backend_graph_capture = ( - self._should_use_backend_graph_capture( - plan=planned_backend_execution, - device=boundary_seq.device, - grad_path=grad_path, - time_steps=int(boundary_seq.shape[1]), - ) - and self._backend_sequence_graph_capture_safe() - ) - time_steps = int(boundary_seq.shape[1]) - projected_message_dim, raw_public_dim, cell_params = self._backend_sequence_surface_projection_dims( - population_name=backend_population_name, - static_tensors=static_tensors, - ) - terminal_dependency_receiver_count = None - if ( - backend_population_state_is_fresh - and not materialize_final_state - and self._local_message_step_enabled - and not self._has_edge_delay - and not bool(getattr(self, "_uses_sparse_message_backend", False)) - ): - terminal_dependency_receiver_count = self._fresh_output_dependency_receiver_count( - population_name=backend_population_name, - time_steps=time_steps, - fresh_state_virtualized=True, - ) - fixed_output_active_region = terminal_dependency_receiver_count is not 
None - state_buffers_needed_by_sequence = materialize_final_state or ( - time_steps > 1 and not fixed_output_active_region - ) - fresh_state_virtualized = backend_population_state_is_fresh and self._can_virtualize_fresh_backend_state( - population_name=backend_population_name, - static_tensors=static_tensors, - projected_message_dim=projected_message_dim, - raw_public_dim=raw_public_dim, - cell_params=cell_params, - ) - defer_fresh_backend_state = ( - backend_population_state_is_fresh - and (fresh_state_virtualized or fixed_output_active_region) - and (fixed_output_active_region or (not use_backend_tape_policy and not use_backend_graph_capture)) - ) - fresh_zero_sentinel_prev = ( - backend_population_state_is_fresh and not state_buffers_needed_by_sequence and time_steps <= 1 - ) - if allow_batch_tiling: - forward_batch_tile_len = self._backend_forward_batch_tile_len_for_layout( - population_name=backend_population_name, - batch_size=int(batch_size), - time_steps=time_steps, - boundary_seq=boundary_seq, - materialize_final_state=materialize_final_state, - training=grad_path, - fresh_state_virtualized=fresh_state_virtualized, - fresh_output_dependency_receiver_count=terminal_dependency_receiver_count, - projected_message_dim=projected_message_dim, - raw_public_dim=raw_public_dim, - output_boundary=output_boundary, - readout_output_boundary=readout_output_boundary, - ) - early_batch_tile_reason = self._last_backend_forward_batch_tile_reason - if 0 < forward_batch_tile_len < int(batch_size): - return self._execute_backend_supported_sequence_surface_batch_tiled( - batch_tile_len=forward_batch_tile_len, - batch_tile_reason=early_batch_tile_reason, - state=state, - boundary_seq=boundary_seq, - projected_boundary_source_seq=projected_boundary_source_seq, - projected_boundary_weight=projected_boundary_weight, - projected_boundary_bias=projected_boundary_bias, - static_tensors=static_tensors, - population_resets=population_resets, - input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, - input_group_input_to_kv_weight=input_group_input_to_kv_weight, - backend_population_name=backend_population_name, - backend_population_state_is_fresh=backend_population_state_is_fresh, - materialize_final_state=materialize_final_state, - grad_path=grad_path, - selected_backend_surface=selected_backend_surface, - planned_backend_execution=planned_backend_execution, - output_boundary=output_boundary, - readout_output_boundary=readout_output_boundary, - ) - virtual_fresh_public_prev = ( - backend_population_state_is_fresh - and (defer_fresh_backend_state or fresh_zero_sentinel_prev) - and not state_buffers_needed_by_sequence - and time_steps <= 1 - ) - recurrent_prev = ( - boundary_seq.new_empty(batch_size, 0, self.hidden_size) - if virtual_fresh_public_prev - else boundary_seq.new_zeros(batch_size, int(terminal_dependency_receiver_count), self.hidden_size) - if backend_population_state_is_fresh - and fixed_output_active_region - and terminal_dependency_receiver_count is not None - else boundary_seq.new_zeros(batch_size, self._num_recurrent_cells, self.hidden_size) - if backend_population_state_is_fresh - else state["cells"][:, self._recurrent_slice, :] - ) - if defer_fresh_backend_state or fresh_zero_sentinel_prev: - packed_state = None - elif backend_population_state_is_fresh: - packed_state = self._init_backend_population_state( - population_name, - batch=batch_size, - device=boundary_seq.device, - dtype=boundary_seq.dtype, - ) - else: - packed_state = self._population_state_to_backend_state( - 
population_name, - cast(TensorDictBase, state[population_name]), - ) - initial_recurrent_k = None - initial_recurrent_v = None - if backend_population_state_is_fresh and not virtual_fresh_public_prev: - initial_receiver_count = int(recurrent_prev.shape[1]) - initial_recurrent_k = recurrent_prev.new_zeros(batch_size, initial_receiver_count, self.head_dim) - initial_recurrent_v = recurrent_prev.new_zeros(batch_size, initial_receiver_count, self.value_dim) - state_sender_k = None if backend_population_state_is_fresh else state.get("sender_k") - state_sender_v = None if backend_population_state_is_fresh else state.get("sender_v") - if torch.is_tensor(state_sender_k) and tuple(state_sender_k.shape) == ( - batch_size, - int(self.sender_cell_idx.numel()), - self.head_dim, - ): - initial_recurrent_k = state_sender_k[:, self._recurrent_slice, :] - if torch.is_tensor(state_sender_v) and tuple(state_sender_v.shape) == ( - batch_size, - int(self.sender_cell_idx.numel()), - self.value_dim, - ): - initial_recurrent_v = state_sender_v[:, self._recurrent_slice, :] - if allow_batch_tiling: - forward_batch_tile_len = self._backend_forward_batch_tile_len( - boundary_seq=boundary_seq, - packed_state=packed_state, - initial_hidden=recurrent_prev, - initial_recurrent_k=initial_recurrent_k, - initial_recurrent_v=initial_recurrent_v, - ) - if 0 < forward_batch_tile_len < int(batch_size): - return self._execute_backend_supported_sequence_surface_batch_tiled( - batch_tile_len=forward_batch_tile_len, - batch_tile_reason=self._last_backend_forward_batch_tile_reason, - state=state, - boundary_seq=boundary_seq, - projected_boundary_source_seq=projected_boundary_source_seq, - projected_boundary_weight=projected_boundary_weight, - projected_boundary_bias=projected_boundary_bias, - static_tensors=static_tensors, - population_resets=population_resets, - input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, - input_group_input_to_kv_weight=input_group_input_to_kv_weight, - backend_population_name=backend_population_name, - backend_population_state_is_fresh=backend_population_state_is_fresh, - materialize_final_state=materialize_final_state, - grad_path=grad_path, - selected_backend_surface=selected_backend_surface, - planned_backend_execution=planned_backend_execution, - output_boundary=output_boundary, - readout_output_boundary=readout_output_boundary, - ) - population_resets_active = population_resets is not None - with torch.profiler.record_function("fabric.glue.normalized_population_resets"): - normalized_population_resets = ( - torch.zeros( - batch_size, - boundary_seq.shape[1], - device=boundary_seq.device, - dtype=torch.bool, - ) - if population_resets is None - else population_resets.to(device=boundary_seq.device, dtype=torch.bool) - ).contiguous() - graph_capture_cache_hit = False - graph_capture_replayed = False - input_k_last: torch.Tensor | None = None - input_v_last: torch.Tensor | None = None - self._last_backend_tape_chunk_len = None - self._last_backend_tape_chunk_reason = None - self._last_backend_tape_artifact_mode = None - self._last_backend_recompute_artifact_window_len = None - self._last_backend_recompute_artifact_window_reason = None - self._last_backend_recompute_checkpoint_stride = None - self._last_backend_recompute_checkpoint_count = None - self._last_backend_recompute_checkpoint_reason = None - self._last_backend_recompute_checkpoint_artifact_cache_mode = None - self._last_backend_recompute_predecessor_cache_mode = None - self._last_backend_recompute_transition_tape_mode = None - 
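
The glue code above decides whether to recurse into the batch-tiled path whenever `0 < forward_batch_tile_len < batch_size`, runs the surface per batch tile, and stitches the per-tile outputs back together (preallocated writes versus `torch.cat`), recording the tile length and reason in the workspace aliases. As a rough illustration of that control flow only — `run_batch_tiled` and `run_tile` below are hypothetical stand-ins, not the Fabric API — a minimal sketch:

```python
import torch

def run_batch_tiled(x: torch.Tensor, tile_len: int, run_tile, preallocate: bool = True) -> torch.Tensor:
    """Run `run_tile` over batch slices of length `tile_len` and reassemble the outputs."""
    batch = x.shape[0]
    if not (0 < tile_len < batch):
        return run_tile(x)  # no tiling: one launch over the full batch
    if preallocate:
        first = run_tile(x[:tile_len])
        out = first.new_empty((batch, *first.shape[1:]))
        out[:tile_len] = first
        for start in range(tile_len, batch, tile_len):
            out[start:start + tile_len] = run_tile(x[start:start + tile_len])
        return out
    # fallback: collect per-tile outputs and concatenate along the batch dim
    return torch.cat([run_tile(x[s:s + tile_len]) for s in range(0, batch, tile_len)], dim=0)

# usage: y = run_batch_tiled(torch.randn(8, 16, 32), tile_len=3, run_tile=lambda t: t * 2.0)
```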
self._last_backend_recompute_transition_tape_reason = None - self._last_backend_recompute_payload_max_bytes = None - self._last_backend_recompute_payload_max_window_len = None - self._last_backend_recompute_payload_max_mode = None - self._last_backend_recompute_payload_sample_count = None - self._last_backend_recompute_public_kv_materialization_mode = None - self._last_backend_recompute_target_state_materialization_mode = None - self._last_backend_recompute_checkpoint_source = None - self._last_backend_recompute_checkpoint_hidden_carry_mode = None - self._last_backend_forward_carry_checkpoints = None - self._last_backend_backward_batch_tile_len = None - self._last_backend_backward_batch_tile_reason = None - self._last_backend_backward_active_receiver_window = None - self._last_backend_backward_active_receiver_window_reason = None - if use_backend_tape_policy: - ( - output_seq, - next_packed_state, - recurrent_hidden, - recurrent_k, - recurrent_v, - input_k_last, - input_v_last, - graph_capture_cache_hit, - graph_capture_replayed, - ) = self._execute_backend_sequence_with_tape_policy( - boundary_seq=boundary_seq, - projected_boundary_source_seq=projected_boundary_source_seq, - projected_boundary_weight=projected_boundary_weight, - projected_boundary_bias=projected_boundary_bias, - packed_state=packed_state, - initial_hidden=recurrent_prev, - initial_recurrent_k=initial_recurrent_k, - initial_recurrent_v=initial_recurrent_v, - initial_state_is_fresh=backend_population_state_is_fresh, - population_resets=normalized_population_resets, - population_resets_active=population_resets_active, - input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, - input_group_input_to_kv_weight=input_group_input_to_kv_weight, - backend_population_name=backend_population_name, - selected_backend_surface=selected_backend_surface, - planned_backend_execution=planned_backend_execution, - planned_backend_backward_execution=cast( - PlannedFabricBackwardExecution, planned_backend_backward_execution - ), - static_tensors=static_tensors, - materialize_final_state=materialize_final_state, - output_boundary=output_boundary, - readout_output_boundary=readout_output_boundary, - ) - elif use_backend_graph_capture: - if grad_path: - ( - output_seq, - next_packed_state, - recurrent_hidden, - recurrent_k, - recurrent_v, - input_k_last, - input_v_last, - graph_capture_cache_hit, - graph_capture_replayed, - ) = self._execute_or_capture_backend_training_sequence_surface( - boundary_seq=boundary_seq, - projected_boundary_source_seq=projected_boundary_source_seq, - projected_boundary_weight=projected_boundary_weight, - projected_boundary_bias=projected_boundary_bias, - packed_state=packed_state, - initial_hidden=recurrent_prev, - initial_recurrent_k=initial_recurrent_k, - initial_recurrent_v=initial_recurrent_v, - initial_state_is_fresh=backend_population_state_is_fresh, - population_resets=normalized_population_resets, - population_resets_active=population_resets_active, - selected_backend_surface=selected_backend_surface, - planned_backend_execution=planned_backend_execution, - planned_backend_backward_execution=cast( - PlannedFabricBackwardExecution, planned_backend_backward_execution - ), - enable_graph_capture=True, - materialize_final_state=materialize_final_state, - output_boundary=output_boundary, - readout_output_boundary=readout_output_boundary, - ) - else: - ( - output_seq, - next_packed_state, - recurrent_hidden, - recurrent_k, - recurrent_v, - input_k_seq, - input_v_seq, - graph_capture_cache_hit, - 
graph_capture_replayed, - ) = self._execute_or_capture_backend_sequence_surface( - boundary_seq=boundary_seq, - packed_state=packed_state, - initial_hidden=recurrent_prev, - initial_recurrent_k=initial_recurrent_k, - initial_recurrent_v=initial_recurrent_v, - initial_state_is_fresh=backend_population_state_is_fresh, - population_resets=normalized_population_resets, - input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, - input_group_input_to_kv_weight=input_group_input_to_kv_weight, - backend_population_name=backend_population_name, - selected_backend_surface=selected_backend_surface, - planned_backend_execution=planned_backend_execution, - static_tensors=static_tensors, - materialize_final_state=materialize_final_state, - output_boundary=output_boundary, - readout_output_boundary=readout_output_boundary, - ) - if output_boundary == "terminal": - output_seq = output_seq[:, -1:] - input_k_last = input_k_seq[:, -1] - input_v_last = input_v_seq[:, -1] - elif grad_path: - ( - output_seq, - next_packed_state, - recurrent_hidden, - recurrent_k, - recurrent_v, - input_k_last, - input_v_last, - graph_capture_cache_hit, - graph_capture_replayed, - ) = self._execute_or_capture_backend_training_sequence_surface( - boundary_seq=boundary_seq, - projected_boundary_source_seq=projected_boundary_source_seq, - projected_boundary_weight=projected_boundary_weight, - projected_boundary_bias=projected_boundary_bias, - packed_state=packed_state, - initial_hidden=recurrent_prev, - initial_recurrent_k=initial_recurrent_k, - initial_recurrent_v=initial_recurrent_v, - initial_state_is_fresh=backend_population_state_is_fresh, - population_resets=normalized_population_resets, - population_resets_active=population_resets_active, - selected_backend_surface=selected_backend_surface, - planned_backend_execution=planned_backend_execution, - planned_backend_backward_execution=cast( - PlannedFabricBackwardExecution, planned_backend_backward_execution - ), - static_tensors=static_tensors, - enable_graph_capture=False, - materialize_final_state=materialize_final_state, - output_boundary=output_boundary, - readout_output_boundary=readout_output_boundary, - ) - else: - output_seq, next_packed_state, recurrent_hidden, recurrent_k, recurrent_v, input_k_seq, input_v_seq = ( - self._run_backend_sequence_surface_once( - population_name=backend_population_name, - boundary_seq=boundary_seq, - packed_state=packed_state, - initial_hidden=recurrent_prev, - initial_recurrent_k=initial_recurrent_k, - initial_recurrent_v=initial_recurrent_v, - initial_state_is_fresh=backend_population_state_is_fresh, - materialize_final_state=materialize_final_state, - population_resets=normalized_population_resets, - input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, - input_group_input_to_kv_weight=input_group_input_to_kv_weight, - planned_backend_execution=planned_backend_execution, - population_materialized=cast(dict[str, object | None], static_tensors["population_materialized"]), - static_tensors=static_tensors, - grad_path=grad_path, - output_boundary=output_boundary, - readout_output_boundary=readout_output_boundary, - ) - ) - if output_boundary == "terminal": - output_seq = output_seq[:, -1:] - input_k_last = input_k_seq[:, -1] - input_v_last = input_v_seq[:, -1] - if not materialize_final_state and self._last_backend_launch_metadata is not None: - metadata = dict(self._last_backend_launch_metadata) - metadata["generic_glue_fusion_modes"] = tuple(metadata.get("generic_glue_fusion_modes", ())) + ( - 
"final_state_materialization:skipped_by_request", - ) - metadata["workspace_aliases"] = tuple(metadata.get("workspace_aliases", ())) + ( - "final_state=not_materialized", - ) - self._last_backend_launch_metadata = metadata - self._record_backend_execution( - surface=selected_backend_surface, - plan=planned_backend_execution, - backward_plan=planned_backend_backward_execution, - batch_size=batch_size, - time_steps=boundary_seq.shape[1], - inner_steps=1, - training=grad_path, - graph_capture_replayed=graph_capture_replayed, - graph_capture_cache_hit=graph_capture_cache_hit, - ) - if output_boundary == "terminal" and self._last_backend_execution is not None: - record = self._last_backend_execution - self._last_backend_execution = replace( - record, - workspace_aliases=record.workspace_aliases - + ( - "sequence_output_boundary:terminal_step", - "sequence_output_materialization:terminal_step_only", - ), - ) - next_state = TensorDict({}, batch_size=[]) - if materialize_final_state: - last_boundary_step = boundary_seq[:, -1] - last_output_cells = output_seq[:, -1] - with torch.profiler.record_function("fabric.glue.materialize_next_state_cat"): - next_state["cells"] = torch.cat((last_boundary_step, recurrent_hidden, last_output_cells), dim=1) - assert input_k_last is not None and input_v_last is not None - with torch.profiler.record_function("fabric.glue.materialize_next_state_cat"): - next_state["sender_k"] = torch.cat((input_k_last, recurrent_k), dim=1) - with torch.profiler.record_function("fabric.glue.materialize_next_state_cat"): - next_state["sender_v"] = torch.cat((input_v_last, recurrent_v), dim=1) - next_state[population_name] = self._backend_state_to_population_state( - population_name, - cast(Mapping[str, torch.Tensor], next_packed_state), - ) - return output_seq, next_state diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/temporal/__init__.py b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/__init__.py new file mode 100644 index 00000000..15ebda6d --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/__init__.py @@ -0,0 +1 @@ +"""Compiler-owned temporal sequence-surface execution modules.""" diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/temporal/common.py b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/common.py new file mode 100644 index 00000000..b0a55dc2 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/common.py @@ -0,0 +1,525 @@ +from __future__ import annotations + +from dataclasses import replace +from typing import Any + +import torch +from tensordict import TensorDict, TensorDictBase + +from cortical.fabric.backend.cuda.sequence_surface.compiler.tables import ( + TemporalPrimitiveTablePlan, + build_temporal_primitive_table_plan, + temporal_table_transition_recurrent_bucket_kinds, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.types import ( + TemporalBucketStepArtifacts, + TemporalOutputContract, + TemporalPublicCarryOrder, +) +from cortical.fabric.backend.reuse import ExecutionFamily + + +def _validate_temporal_physical_backward_plan(planned_backward_execution: Any | None) -> None: + if planned_backward_execution is None: + raise RuntimeError("Temporal bucket physical backward requires a planned backward execution") + if not planned_backward_execution.receiver_bucket_plans or not planned_backward_execution.sender_bucket_plans: + raise RuntimeError("Temporal bucket physical backward requires receiver and sender bucket plans") + if any( + 
bucket_plan.execution_family != ExecutionFamily.RECEIVER_MAJOR + for bucket_plan in planned_backward_execution.receiver_bucket_plans + ): + raise RuntimeError("Temporal bucket physical backward requires receiver-major receiver-adjoint execution") + if any( + bucket_plan.execution_family != ExecutionFamily.EDGE_MAJOR + for bucket_plan in planned_backward_execution.sender_bucket_plans + ): + raise RuntimeError("Temporal bucket physical backward requires edge-major sender/public accumulation") + try: + family_behaviors = planned_backward_execution.physical_plan.family_behaviors + except KeyError as error: + raise RuntimeError(f"Temporal bucket physical backward has unregistered family {error.args[0]!r}") from error + unsupported = tuple(behavior.family for behavior in family_behaviors if behavior.behavior == "unsupported") + if unsupported: + raise RuntimeError( + "Temporal bucket physical backward cannot run unsupported families: " + ", ".join(sorted(unsupported)) + ) + + +def _pool_output_ports_backward( + runtime: Any, + port_y: torch.Tensor, + grad_pooled: torch.Tensor, +) -> torch.Tensor: + readout_pool = str(runtime.config.readout.pool) + if readout_pool == "mean": + return grad_pooled.expand_as(port_y) / max(1, int(port_y.shape[1])) + if readout_pool == "flatten": + return grad_pooled.reshape_as(port_y) + scores = torch.einsum("bph,qh->bpq", port_y, runtime.readout_query) + weights = torch.softmax(scores.to(dtype=torch.float32), dim=1).to(dtype=port_y.dtype) + direct_grad = torch.einsum("bpq,bqh->bph", weights, grad_pooled) + weighted_dot = torch.einsum("bqh,bph->bpq", grad_pooled, port_y) + expected_dot = (weights * weighted_dot).sum(dim=1, keepdim=True) + grad_scores = weights * (weighted_dot - expected_dot) + score_grad = torch.einsum("bpq,qh->bph", grad_scores, runtime.readout_query) + return direct_grad + score_grad + + +def _grad_output_cells_for_contract( + runtime: Any, + output_cells: torch.Tensor, + grad_output: torch.Tensor | None, + output_contract: TemporalOutputContract, +) -> torch.Tensor | None: + if grad_output is None: + return None + if output_contract == "output_cells": + return grad_output + if output_contract == "pooled_output_cells": + return _pool_output_ports_backward(runtime, output_cells, grad_output) + raise RuntimeError(f"Unsupported temporal output contract {output_contract!r}") + + +def _temporal_forward_output_for_contract( + runtime: Any, + artifacts: TemporalBucketStepArtifacts, + output_contract: TemporalOutputContract, +) -> torch.Tensor: + if output_contract == "full_cells": + return artifacts.cells_out + if output_contract == "output_cells": + return artifacts.output_cells + if output_contract == "pooled_output_cells": + return runtime._pool_output_ports(artifacts.output_cells.unsqueeze(1)).squeeze(1) + raise RuntimeError(f"Unsupported temporal output contract {output_contract!r}") + + +def _flat_bucket_name(bucket: Any) -> str: + return str(getattr(bucket, "name", getattr(bucket, "binding_name", ""))) + + +def temporal_message_output_dim(runtime: Any) -> int: + message_program = getattr(getattr(runtime, "backend_ir", None), "message_program", None) + output_dim_role = str(getattr(message_program, "output_dim_role", "value_dim")) + if output_dim_role == "d_msg": + return int(runtime.d_msg) + if output_dim_role != "value_dim": + raise RuntimeError(f"Unsupported compiled message output dimension role {output_dim_role!r}") + return int(runtime.value_dim) + + +def _flat_bucket_temporal_table_plan( + runtime: Any, + static_tensors: dict[str, object], +) 
-> TemporalPrimitiveTablePlan: + table_plan = build_temporal_primitive_table_plan(runtime, static_tensors) + bucket_kinds = temporal_table_transition_recurrent_bucket_kinds(table_plan) + runtime._last_flat_bucket_temporal_registered_transition_bucket_kinds = tuple( + f"bucket={bucket_ordinal},kind={kind}" for bucket_ordinal, kind in sorted(bucket_kinds.items()) + ) + return table_plan + + +def _receiver_hidden_view(tensor: torch.Tensor, *, receivers: int, hidden: int) -> torch.Tensor: + if tensor.dim() == 3 and int(tensor.shape[0]) == 1: + tensor = tensor.squeeze(0) + return tensor.reshape(int(receivers), int(hidden)) + + +def _initial_backend_state_tensors( + runtime: Any, + bucket: Any, + state: TensorDict, + keys: tuple[str, ...], +) -> tuple[torch.Tensor, ...] | None: + population_state = state.get(_flat_bucket_name(bucket)) + if not isinstance(population_state, TensorDictBase): + return None + backend_state = runtime._population_state_to_backend_state(_flat_bucket_name(bucket), population_state) + tensors: list[torch.Tensor] = [] + for key in keys: + tensor = backend_state.get(key) + if not torch.is_tensor(tensor): + return None + tensors.append(tensor.contiguous()) + return tuple(tensors) + + +def _backend_state_cache_tensors( + backend_state_cache: dict[str, object] | None, + bucket: Any, + keys: tuple[str, ...], +) -> tuple[torch.Tensor, ...] | None: + if backend_state_cache is None: + return None + backend_state = backend_state_cache.get(_flat_bucket_name(bucket)) + if backend_state is None or not hasattr(backend_state, "get"): + return None + tensors: list[torch.Tensor] = [] + for key in keys: + tensor = backend_state.get(key) # type: ignore[attr-defined] + if not torch.is_tensor(tensor): + return None + tensors.append(tensor.contiguous()) + return tuple(tensors) + + +def _gated_public_from_raw_state_window( + y_window: torch.Tensor, + *, + outnorm_weight: torch.Tensor, + eps: float, +) -> torch.Tensor: + if int(y_window.numel()) == 0: + return y_window.contiguous() + if y_window.dim() == 3: + weight = outnorm_weight.view(1, int(outnorm_weight.shape[0]), int(outnorm_weight.shape[1])) + elif y_window.dim() == 4: + weight = outnorm_weight.view(1, 1, int(outnorm_weight.shape[0]), int(outnorm_weight.shape[1])) + else: + raise RuntimeError("Gated public-state projection expects [B,R,H] or [T,B,R,H]") + mean = y_window.mean(dim=-1, keepdim=True) + var = torch.clamp((y_window * y_window).mean(dim=-1, keepdim=True) - mean * mean, min=0.0) + return ((y_window - mean) * torch.rsqrt(var + float(eps)) * weight).contiguous() + + +def _reorder_backend_recurrent_bank_to_graph_order(runtime: Any, tensor: torch.Tensor) -> torch.Tensor: + return tensor.index_select(1, runtime.population_backend_recurrent_inverse_order).contiguous() + + +def _reorder_backend_recurrent_grad_to_graph_order(runtime: Any, tensor: torch.Tensor | None) -> torch.Tensor | None: + if tensor is None: + return None + return tensor.index_select(1, runtime.population_backend_recurrent_inverse_order).contiguous() + + +def _reorder_graph_recurrent_bank_to_backend_order(runtime: Any, tensor: torch.Tensor) -> torch.Tensor: + return tensor.index_select(1, runtime.population_backend_recurrent_order).contiguous() + + +def _materialize_backend_recurrent_hidden_grad_to_cells( + runtime: Any, + grad_hidden_backend_order: torch.Tensor, + *, + template_cells: torch.Tensor, + record_tag: str, +) -> torch.Tensor: + del record_tag + grad_hidden_graph = _reorder_backend_recurrent_bank_to_graph_order(runtime, grad_hidden_backend_order) + 
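
The `_gated_public_from_raw_state_window` helper above applies the weight-only norm epilogue (`public_y_raw -> public_y`) in closed form, computing the variance as E[y^2] - (E[y])^2 instead of calling a norm module. A small numerical sanity check of that identity against the biased per-feature variance, written in plain PyTorch with made-up shapes (it does not import the Fabric helper):

```python
import torch

B, R, H, eps = 2, 3, 8, 1e-5
y = torch.randn(B, R, H, dtype=torch.float64)
weight = torch.randn(R, H, dtype=torch.float64)

# Closed form used above: var = clamp(E[y^2] - mean^2, 0), then (y - mean) * rsqrt(var + eps) * weight.
mean = y.mean(dim=-1, keepdim=True)
var = torch.clamp((y * y).mean(dim=-1, keepdim=True) - mean * mean, min=0.0)
fast = (y - mean) * torch.rsqrt(var + eps) * weight.view(1, R, H)

# Reference: the same bias-free layer norm via the biased (unbiased=False) variance.
ref = (y - mean) / torch.sqrt(y.var(dim=-1, unbiased=False, keepdim=True) + eps) * weight.view(1, R, H)

assert torch.allclose(fast, ref, atol=1e-10)
```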
grad_cells = template_cells.new_zeros( + int(grad_hidden_backend_order.shape[0]), + int(template_cells.shape[1]), + int(grad_hidden_backend_order.shape[2]), + ) + grad_cells[:, runtime._recurrent_slice, :] = grad_hidden_graph.to(dtype=grad_cells.dtype) + return grad_cells + + +def _temporal_recurrent_message_tables( + runtime: Any, + public_carry_order: TemporalPublicCarryOrder, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if public_carry_order == "backend_order": + return ( + runtime.recurrent_neighbor_idx_flat_bucket_carry_order, + runtime.recurrent_local_sender_idx_flat_bucket_carry_order, + runtime.recurrent_local_receiver_idx_by_sender_flat_bucket_carry_order, + ) + return ( + runtime.recurrent_neighbor_idx_backend_order, + runtime.recurrent_local_sender_idx_backend_order, + runtime.recurrent_local_receiver_idx_by_sender_backend_order, + ) + + +def _temporal_recurrent_message_reverse_tables( + runtime: Any, + public_carry_order: TemporalPublicCarryOrder, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None, bool]: + neighbor_idx, receiver_sender_idx, sender_receiver_idx = _temporal_recurrent_message_tables( + runtime, + public_carry_order, + ) + if public_carry_order == "backend_order" and hasattr( + runtime, + "recurrent_local_receiver_idx_by_sender_compact_flat_bucket_carry_order", + ): + return ( + neighbor_idx, + receiver_sender_idx, + runtime.recurrent_local_receiver_idx_by_sender_compact_flat_bucket_carry_order, + runtime.recurrent_local_receiver_slot_idx_by_sender_compact_flat_bucket_carry_order, + True, + ) + if hasattr(runtime, "recurrent_local_receiver_idx_by_sender_compact_backend_order"): + return ( + neighbor_idx, + receiver_sender_idx, + runtime.recurrent_local_receiver_idx_by_sender_compact_backend_order, + runtime.recurrent_local_receiver_slot_idx_by_sender_compact_backend_order, + True, + ) + return neighbor_idx, receiver_sender_idx, sender_receiver_idx, None, False + + +def _temporal_output_message_tables( + runtime: Any, + public_carry_order: TemporalPublicCarryOrder, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if public_carry_order == "backend_order": + return ( + runtime.output_neighbor_idx_flat_bucket_carry_order, + runtime.output_local_sender_idx_flat_bucket_carry_order, + runtime.output_local_receiver_idx_by_sender_flat_bucket_carry_order, + ) + return ( + runtime.output_neighbor_idx, + runtime.output_local_sender_idx, + runtime.output_local_receiver_idx_by_sender, + ) + + +def _bucket_population_params(bucket: Any) -> dict[str, object] | None: + population_materialized = bucket.static_tensors.get("population_materialized") + if not isinstance(population_materialized, dict): + return None + params = population_materialized.get(_flat_bucket_name(bucket)) + return params if isinstance(params, dict) else None + + +def _empty_temporal_population_views(runtime: Any) -> TensorDict: + return TensorDict( + {name: TensorDict({}, batch_size=[]) for name in runtime._population_names}, + batch_size=[], + ) + + +def _temporal_plan_output_request(temporal_plan: Any | None) -> Any | None: + return None if temporal_plan is None else getattr(temporal_plan, "output_request", None) + + +def _temporal_plan_reverse_artifact_kind(temporal_plan: Any | None) -> str: + materialization = None if temporal_plan is None else getattr(temporal_plan, "materialization", None) + return "none" if materialization is None else str(getattr(materialization, "reverse_artifact_kind", "none")) + + +def _initial_cells_for_forward_reverse_tables( + runtime: 
Any, + state: TensorDict, + *, + batch_size: int, + hidden: int, + reference: torch.Tensor, +) -> torch.Tensor: + cells = state.get("cells") + if torch.is_tensor(cells): + return cells + return reference.new_zeros(int(batch_size), int(runtime.coords.shape[0]), int(hidden)) + + +def _zero_initial_population_tuple( + reference: torch.Tensor, + *, + batch_size: int, + count: int, + hidden: int, + entries: int, +) -> tuple[torch.Tensor, ...]: + return tuple(reference.new_zeros(int(batch_size), int(count), int(hidden)) for _ in range(int(entries))) + + +def _gated_affine_param_grads_from_reverse_windows( + *, + gated_raw_window: torch.Tensor, + gated_input_grad_window: torch.Tensor, + gated_input_window: torch.Tensor, + recurrent_msg_window: torch.Tensor, + gated_start: int, + gated_count: int, + value_dim: int, + gated_recurrent_affine_head_dim: int, + gated_input_projection_bias: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + time_steps = int(gated_raw_window.shape[0]) + batch_size = int(gated_raw_window.shape[1]) + hidden = int(gated_raw_window.shape[4]) + time_batch = int(time_steps) * int(batch_size) + num_heads = hidden // int(gated_recurrent_affine_head_dim) + gated_raw_batch = gated_raw_window.reshape(time_batch, int(gated_count), 4, hidden) + gated_gate_proj_grad = ( + gated_raw_batch.reshape( + time_batch, + int(gated_count), + 4, + num_heads, + int(gated_recurrent_affine_head_dim), + ) + .permute(0, 1, 3, 2, 4) + .reshape(time_batch, int(gated_count) * num_heads, 4 * int(gated_recurrent_affine_head_dim)) + .contiguous() + ) + gated_input_heads = ( + gated_input_window.reshape( + time_batch, + int(gated_count), + num_heads, + int(gated_recurrent_affine_head_dim), + ) + .reshape(time_batch, int(gated_count) * num_heads, int(gated_recurrent_affine_head_dim)) + .contiguous() + ) + gated_gate_weight_grad = ( + torch.bmm( + gated_input_heads.permute(1, 2, 0).contiguous(), + gated_gate_proj_grad.permute(1, 0, 2).contiguous(), + ) + .reshape( + int(gated_count), + num_heads, + int(gated_recurrent_affine_head_dim), + 4 * int(gated_recurrent_affine_head_dim), + ) + .contiguous() + ) + gated_gate_bias_grad = ( + gated_raw_batch.sum(dim=0) + .reshape(int(gated_count), 4 * hidden) + .reshape(int(gated_count), num_heads, 4, int(gated_recurrent_affine_head_dim)) + .permute(0, 2, 1, 3) + .contiguous() + ) + gated_recurrent_msg = recurrent_msg_window[:, :, int(gated_start) : int(gated_start) + int(gated_count), :] + gated_input_weight_grad = torch.mm( + gated_input_grad_window.reshape(time_batch * int(gated_count), hidden).transpose(0, 1).contiguous(), + gated_recurrent_msg.reshape(time_batch * int(gated_count), int(value_dim)).contiguous(), + ).contiguous() + gated_input_bias_grad = gated_input_grad_window.sum(dim=(0, 1)).contiguous() + if gated_input_projection_bias.dim() == 3: + gated_input_bias_grad = gated_input_bias_grad.unsqueeze(0).contiguous() + return gated_gate_weight_grad, gated_gate_bias_grad, gated_input_weight_grad, gated_input_bias_grad + + +def _record_temporal_backward_glue_cuda(runtime: Any, tag: str) -> None: + record = getattr(runtime, "_last_backend_execution", None) + if record is None: + return + executor_tag = "cuda_temporal_backward_glue" + launch_tag = f"temporal_backward_glue:{tag}" + workspace_tag = "temporal_backward_glue:cuda_scan_index" + backward_binding_abi = getattr(runtime, "_last_flat_bucket_temporal_backward_binding_abi", None) + backward_binding_tag = ( + None if backward_binding_abi is None else 
f"flat_bucket_temporal_backward_binding_abi:{backward_binding_abi}" + ) + runtime._last_backend_execution = replace( + record, + backward_physical_op_executors=record.backward_physical_op_executors + + (() if executor_tag in record.backward_physical_op_executors else (executor_tag,)), + backward_launch_counts=record.backward_launch_counts + + (() if launch_tag in record.backward_launch_counts else (launch_tag,)), + workspace_aliases=record.workspace_aliases + + (() if workspace_tag in record.workspace_aliases else (workspace_tag,)) + + ( + () + if backward_binding_tag is None or backward_binding_tag in record.workspace_aliases + else (backward_binding_tag,) + ), + ) + + +_TEMPORAL_REVERSE_SCAN_OWNER_PREFIX = "flat_bucket_temporal_reverse_scan_owner:" +_TEMPORAL_REVERSE_SCAN_BINDING_ABI_PREFIX = "flat_bucket_temporal_reverse_scan_binding_abi:" +_TEMPORAL_REVERSE_SCAN_OWNER_BLOCKED_CUDA_SUPEROP = "cuda_temporal_superop" +_TEMPORAL_REVERSE_SCAN_TABLE_ABI_BLOCKED = "flat_bucket_temporal_reverse_table_extension" +_TEMPORAL_REVERSE_SCAN_OWNER_REGISTERED = "registered_reverse_executor_bindings" +_TEMPORAL_REVERSE_SCAN_BINDING_ABI_REGISTERED = "registered_executor_binding_rows" + + +def _append_unique(items: tuple[str, ...], *candidates: str | None) -> tuple[str, ...]: + out = items + for candidate in candidates: + if candidate is not None and candidate not in out: + out = out + (candidate,) + return out + + +def _record_temporal_reverse_scan_owner( + runtime: Any, + owner: str, + *, + binding_abi: str | None = None, +) -> None: + record = getattr(runtime, "_last_backend_execution", None) + if record is None: + return + if ( + owner == _TEMPORAL_REVERSE_SCAN_OWNER_BLOCKED_CUDA_SUPEROP + or binding_abi == _TEMPORAL_REVERSE_SCAN_TABLE_ABI_BLOCKED + ): + raise RuntimeError( + "Temporal backward cannot claim the deleted CUDA reverse-table path; " + "use registered reverse executor bindings" + ) + if ( + owner == _TEMPORAL_REVERSE_SCAN_OWNER_REGISTERED + and binding_abi != _TEMPORAL_REVERSE_SCAN_BINDING_ABI_REGISTERED + ): + raise RuntimeError( + "Temporal backward cannot claim registered reverse-scan ownership without the " + f"{_TEMPORAL_REVERSE_SCAN_BINDING_ABI_REGISTERED} binding ABI" + ) + owner_tag = f"{_TEMPORAL_REVERSE_SCAN_OWNER_PREFIX}{owner}" + binding_tag = None if binding_abi is None else f"{_TEMPORAL_REVERSE_SCAN_BINDING_ABI_PREFIX}{binding_abi}" + backward_binding_tag = None if binding_abi is None else f"flat_bucket_temporal_backward_binding_abi:{binding_abi}" + runtime._last_backend_execution = replace( + record, + workspace_aliases=_append_unique(record.workspace_aliases, binding_tag, backward_binding_tag), + backward_workspace_aliases=_append_unique( + record.backward_workspace_aliases, + owner_tag, + binding_tag, + backward_binding_tag, + ), + ) + + +def _validate_temporal_reverse_scan_claim(runtime: Any) -> None: + record = getattr(runtime, "_last_backend_execution", None) + if record is None: + return + aliases = set(record.workspace_aliases) | set(record.backward_workspace_aliases) + blocked_owner_tag = f"{_TEMPORAL_REVERSE_SCAN_OWNER_PREFIX}{_TEMPORAL_REVERSE_SCAN_OWNER_BLOCKED_CUDA_SUPEROP}" + blocked_binding_tag = f"{_TEMPORAL_REVERSE_SCAN_BINDING_ABI_PREFIX}{_TEMPORAL_REVERSE_SCAN_TABLE_ABI_BLOCKED}" + blocked_claims = ( + _TEMPORAL_REVERSE_SCAN_OWNER_BLOCKED_CUDA_SUPEROP in record.temporal_plan_backward_owners + or any( + executor in {"cuda_temporal_superop", "cuda_temporal_reverse_superop"} + for executor in record.backward_physical_op_executors + ) + or blocked_owner_tag in 
aliases + or blocked_binding_tag in aliases + ) + if blocked_claims: + raise RuntimeError("Temporal backward record claims the deleted CUDA reverse-table path") + registered_owner_claims = _TEMPORAL_REVERSE_SCAN_OWNER_REGISTERED in record.temporal_plan_backward_owners + registered_owner_tag = f"{_TEMPORAL_REVERSE_SCAN_OWNER_PREFIX}{_TEMPORAL_REVERSE_SCAN_OWNER_REGISTERED}" + registered_binding_tag = ( + f"{_TEMPORAL_REVERSE_SCAN_BINDING_ABI_PREFIX}{_TEMPORAL_REVERSE_SCAN_BINDING_ABI_REGISTERED}" + ) + if not (registered_owner_claims or registered_owner_tag in aliases): + return + if registered_binding_tag not in aliases: + raise RuntimeError( + "Temporal backward record claims registered reverse-scan ownership without recording " + f"{registered_binding_tag}" + ) + + +__all__ = [ + name + for name in globals() + if not name.startswith("__") + and name + not in { + "annotations", + } +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/temporal/executor_registry.py b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/executor_registry.py new file mode 100644 index 00000000..f3b8b697 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/executor_registry.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from cortical.fabric.backend.cuda.sequence_surface.compiler.executor_patterns import ( + temporal_executor_strategy_registry, +) + + +@dataclass(frozen=True) +class RegisteredTemporalExecutorKernelRegistry: + forward_message_executors: frozenset[str] + reverse_message_executors: frozenset[str] + forward_readout_executors: frozenset[str] + reverse_readout_executors: frozenset[str] + forward_transition_executors: frozenset[str] + reverse_transition_executors: frozenset[str] + + @property + def review_summary(self) -> tuple[str, ...]: + return ( + "registered_temporal_executor_kernel_registry=compiler_owned", + "forward_surfaces=message,transition,readout,layout", + "reverse_surfaces=message,transition,readout,boundary", + "dispatch_key=selected_executor_rows_and_binding_rows", + "executor_validation=registry_owned", + "forward_message_executors=" + ",".join(sorted(self.forward_message_executors)), + "reverse_message_executors=" + ",".join(sorted(self.reverse_message_executors)), + "forward_readout_executors=" + ",".join(sorted(self.forward_readout_executors)), + "reverse_readout_executors=" + ",".join(sorted(self.reverse_readout_executors)), + "forward_transition_executors=" + ",".join(sorted(self.forward_transition_executors)), + "reverse_transition_executors=" + ",".join(sorted(self.reverse_transition_executors)), + ) + + def require_forward_message(self, executor_handle: Any) -> None: + self._require_handle( + executor_handle, + direction="forward", + surface="message", + executor_names=self.forward_message_executors, + ) + + def require_forward_messages(self, executor_handles: tuple[Any, ...]) -> None: + self._require_surface_handles( + executor_handles, + direction="forward", + surface="message", + ) + for executor_handle in executor_handles: + self.require_forward_message(executor_handle) + + def require_reverse_message(self, executor_handle: Any) -> None: + self._require_handle( + executor_handle, + direction="reverse", + surface="message", + executor_names=self.reverse_message_executors, + ) + + def require_reverse_messages(self, executor_handles: tuple[Any, ...]) -> None: + self._require_surface_handles( + executor_handles, + direction="reverse", + surface="message", + ) 
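
The registry methods above all funnel into the same two checks: every executor row must match the expected direction and surface and carry a registered executor name, and transition rows must cover bucket ordinals 0..max with no gaps. A toy illustration of those membership and coverage checks — the `Handle` type, `REGISTERED` names, and `check_rows` are invented for the example, not the Fabric types:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Handle:  # hypothetical stand-in for a compiler executor row
    direction: str
    surface: str
    executor_name: str
    bucket_ordinal: int

REGISTERED = {"fused_transition_scan", "pointwise_transition"}

def check_rows(rows: tuple[Handle, ...], *, direction: str, surface: str) -> None:
    for row in rows:
        if (row.direction, row.surface) != (direction, surface) or row.executor_name not in REGISTERED:
            raise RuntimeError(f"rejected executor row: {row!r}")
    buckets = {row.bucket_ordinal for row in rows}
    if buckets != set(range(1 + max(buckets, default=-1))):
        raise RuntimeError(f"transition buckets must be contiguous from 0, got {sorted(buckets)}")

check_rows(
    (
        Handle("forward", "transition", "fused_transition_scan", 0),
        Handle("forward", "transition", "pointwise_transition", 1),
    ),
    direction="forward",
    surface="transition",
)
```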
+ for executor_handle in executor_handles: + self.require_reverse_message(executor_handle) + + def require_forward_readout(self, executor_handle: Any) -> None: + self._require_handle( + executor_handle, + direction="forward", + surface="readout", + executor_names=self.forward_readout_executors, + ) + + def require_forward_readouts(self, executor_handles: tuple[Any, ...]) -> None: + self._require_surface_handles( + executor_handles, + direction="forward", + surface="readout", + ) + for executor_handle in executor_handles: + self.require_forward_readout(executor_handle) + + def require_reverse_readout(self, executor_handle: Any) -> None: + self._require_handle( + executor_handle, + direction="reverse", + surface="readout", + executor_names=self.reverse_readout_executors, + ) + + def require_reverse_readouts(self, executor_handles: tuple[Any, ...]) -> None: + self._require_surface_handles( + executor_handles, + direction="reverse", + surface="readout", + ) + for executor_handle in executor_handles: + self.require_reverse_readout(executor_handle) + + def require_forward_transition(self, executor_handle: Any) -> None: + self._require_handle( + executor_handle, + direction="forward", + surface="transition", + executor_names=self.forward_transition_executors, + ) + + def require_reverse_transition(self, executor_handle: Any) -> None: + self._require_handle( + executor_handle, + direction="reverse", + surface="transition", + executor_names=self.reverse_transition_executors, + ) + + def require_forward_transitions(self, executor_handles: tuple[Any, ...]) -> None: + self._require_transition_bucket_coverage(executor_handles) + for executor_handle in executor_handles: + self.require_forward_transition(executor_handle) + + def require_reverse_transitions(self, executor_handles: tuple[Any, ...]) -> None: + self._require_transition_bucket_coverage(executor_handles) + for executor_handle in executor_handles: + self.require_reverse_transition(executor_handle) + + def _require_handle( + self, + executor_handle: Any, + *, + direction: str, + surface: str, + executor_names: frozenset[str], + ) -> None: + actual_direction = str(getattr(executor_handle, "direction", "")) + actual_surface = str(getattr(executor_handle, "surface", "")) + actual_executor = str(getattr(executor_handle, "executor_name", "")) + actual_bucket = int(getattr(executor_handle, "bucket_ordinal", -999999)) + if actual_direction != direction or actual_surface != surface or actual_executor not in executor_names: + raise RuntimeError( + "Registered temporal executor kernel registry rejected executor row: " + f"direction={actual_direction}; surface={actual_surface}; bucket={actual_bucket}; " + f"executor={actual_executor!r}; expected_direction={direction}; " + f"expected_surface={surface}; expected_executors={tuple(sorted(executor_names))!r}" + ) + try: + executor_handle.require_parameter_bindings() + except AttributeError as error: + raise RuntimeError( + "Registered temporal executor kernel registry requires compiler-owned parameter bindings" + ) from error + + @staticmethod + def _require_surface_handles( + executor_handles: tuple[Any, ...], + *, + direction: str, + surface: str, + ) -> None: + if not executor_handles: + raise RuntimeError( + "Registered temporal executor kernel registry requires at least one executor row: " + f"direction={direction}; surface={surface}" + ) + + @staticmethod + def _require_transition_bucket_coverage(executor_handles: tuple[Any, ...]) -> None: + actual_buckets = { + int(getattr(executor_handle, 
"bucket_ordinal", -999999)) for executor_handle in executor_handles + } + expected_buckets = set(range(1 + max(actual_buckets, default=-1))) + if actual_buckets != expected_buckets: + raise RuntimeError( + "Registered temporal transition executor registry rows must cover every transition bucket: " + f"expected={tuple(sorted(expected_buckets))}; actual={tuple(sorted(actual_buckets))}" + ) + + +def _forward_executor_names(*, surface: str) -> frozenset[str]: + return frozenset( + pattern.executor_name + for pattern in temporal_executor_strategy_registry().forward_patterns() + if pattern.surface == str(surface) + ) + + +def _reverse_executor_names(*, surface: str) -> frozenset[str]: + return frozenset( + pattern.executor_name + for pattern in temporal_executor_strategy_registry().reverse_patterns() + if pattern.surface == str(surface) + ) + + +_REGISTERED_TEMPORAL_EXECUTOR_KERNEL_REGISTRY = RegisteredTemporalExecutorKernelRegistry( + forward_message_executors=_forward_executor_names(surface="message"), + reverse_message_executors=_reverse_executor_names(surface="message"), + forward_readout_executors=_forward_executor_names(surface="readout"), + reverse_readout_executors=_reverse_executor_names(surface="readout"), + forward_transition_executors=_forward_executor_names(surface="transition"), + reverse_transition_executors=_reverse_executor_names(surface="transition"), +) + + +def registered_temporal_executor_kernel_registry() -> RegisteredTemporalExecutorKernelRegistry: + return _REGISTERED_TEMPORAL_EXECUTOR_KERNEL_REGISTRY + + +__all__ = [ + "RegisteredTemporalExecutorKernelRegistry", + "registered_temporal_executor_kernel_registry", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py new file mode 100644 index 00000000..57db6b28 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py @@ -0,0 +1,517 @@ +from __future__ import annotations + +from dataclasses import replace +from typing import Any, Literal + +import torch +from tensordict import TensorDict, TensorDictBase + +from cortical.fabric.backend.cuda.sequence_surface.compiler.buckets import ( + backend_order_flat_buckets, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.backward_plan import ( + build_temporal_backward_executable_plan, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.forward_plan import ( + build_temporal_forward_executable_plan, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.memory_plan import ( + build_temporal_memory_liveness_plan, + build_temporal_memory_runtime_artifact_plan, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.runtime_metadata import ( + record_temporal_primitive_table_runtime_metadata, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.tables import ( + build_temporal_primitive_table_plan, + temporal_table_full_tape_extra_state_factors, + temporal_table_transition_kind_labels, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.scan_schedule import ( + build_scalar_temporal_scan_schedule, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.types import ( + SharedTemporalForwardScanResult, + TemporalArtifactCheckpoint, + TemporalArtifactStore, + TemporalOutputContract, + TemporalPublicCarryOrder, + TemporalTransitionTapeMode, + TemporalTransitionTapePolicy, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.registered_executors import ( + 
build_registered_temporal_executor_program, + run_registered_temporal_forward_executor_scan, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.program_parameters import ( + surface_parameter_binding_allows_empty_static_tensor, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.scheduler import ( + TemporalRuntimeSchedulerPlan, + build_temporal_runtime_scheduler_plan, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.common import ( + _flat_bucket_temporal_table_plan, +) + + +def _compiler_bound_surface_parameters_available( + runtime: Any, + table_plan: Any, + static_tensors: dict[str, object], + *, + surface: str, +) -> bool: + for binding in tuple(getattr(table_plan, "tensor_bindings", ()) or ()): + if getattr(binding, "surface", None) != surface or getattr(binding, "binding_kind", None) != "parameter": + continue + resolved = False + for source in tuple(getattr(binding, "source_bindings", ()) or ()): + source_text = str(source) + if source_text.startswith("static_tensor:"): + source_key = source_text.removeprefix("static_tensor:") + if torch.is_tensor(static_tensors.get(source_key)) or ( + source_key in static_tensors + and surface_parameter_binding_allows_empty_static_tensor(binding, source_key) + ): + resolved = True + break + if source_text.startswith("message_parameter:"): + logical_name = str(getattr(binding, "logical_name", "")) + message_parameters = getattr(runtime, "message_rule_parameters", None) + if hasattr(message_parameters, "__contains__") and logical_name in message_parameters: + resolved = True + break + if source_text.startswith("runtime_attr:") and torch.is_tensor( + getattr(runtime, source_text.removeprefix("runtime_attr:"), None) + ): + resolved = True + break + if not resolved: + return False + return True + + +def _attach_temporal_compiler_plans_to_artifact_store( + runtime: Any, + *, + static_tensors: dict[str, object], + artifact_store: TemporalArtifactStore, +) -> TemporalArtifactStore: + if artifact_store.primitive_table_plan is not None: + return artifact_store + temporal_table_plan = _flat_bucket_temporal_table_plan(runtime, static_tensors) + memory_plan = build_temporal_memory_liveness_plan(temporal_table_plan) + return replace( + artifact_store, + memory_plan_fingerprint=artifact_store.memory_plan_fingerprint or memory_plan.fingerprint, + primitive_table_fingerprint=temporal_table_plan.fingerprint, + primitive_table_plan=temporal_table_plan, + ) + + +def _try_registered_temporal_forward_scan( + runtime: Any, + *, + boundary_seq: torch.Tensor, + state: TensorDict, + population_resets: torch.Tensor | None, + transition_resets: torch.Tensor | None, + static_tensors: dict[str, object], + inner_steps: int, + output_contract: TemporalOutputContract, + output_boundary: Literal["sequence", "terminal"], + collect_artifacts: bool, + materialize_final_state: bool, + backend_population_state_is_fresh: bool, + transition_tape_mode: TemporalTransitionTapeMode = "disabled", + artifact_checkpoints: dict[int, TemporalArtifactCheckpoint] | None = None, + initial_population_state_cache: dict[str, object] | None = None, + temporal_plan: Any | None = None, + scheduler_plan: TemporalRuntimeSchedulerPlan | None = None, +) -> SharedTemporalForwardScanResult | None: + if ( + output_contract == "full_cells" + or boundary_seq.device.type != "cuda" + or boundary_seq.dtype != torch.float32 + or bool(getattr(runtime, "_has_edge_delay", False)) + or bool(getattr(runtime, "_uses_sparse_message_backend", False)) + or not bool(getattr(runtime, 
"_partitioned_layout", False)) + or not bool(getattr(runtime, "_local_message_step_enabled", False)) + ): + return None + readout_pool = str(runtime.config.readout.pool) + if output_contract == "pooled_output_cells" and readout_pool not in {"flatten", "mean"}: + return None + if not ( + hasattr(runtime, "recurrent_local_sender_idx_flat_bucket_carry_order") + and hasattr(runtime, "output_local_sender_idx_flat_bucket_carry_order") + ): + return None + temporal_table_plan = _flat_bucket_temporal_table_plan(runtime, static_tensors) + if not _compiler_bound_surface_parameters_available( + runtime, + temporal_table_plan, + static_tensors, + surface="message", + ): + return None + record_temporal_primitive_table_runtime_metadata( + runtime, + temporal_table_plan, + scheduler_plan=scheduler_plan, + ) + forward_executable_plan = build_temporal_forward_executable_plan(temporal_table_plan) + backward_executable_plan = build_temporal_backward_executable_plan(temporal_table_plan) + if ( + forward_executable_plan.strategy_legality_status != "legal" + or backward_executable_plan.strategy_legality_status != "legal" + ): + runtime._last_flat_bucket_temporal_scan_owner = "registered_executor_bindings_unavailable" + runtime._last_flat_bucket_temporal_scan_reject = ( + "registered_executor_bindings_required;" + f"forward_status={forward_executable_plan.strategy_legality_status};" + f"backward_status={backward_executable_plan.strategy_legality_status};" + "registered_program_required=1" + ) + return None + streaming_readout_body_available = ( + not bool(collect_artifacts) + and not bool(materialize_final_state) + and not torch.is_tensor(population_resets) + and not torch.is_tensor(transition_resets) + ) + executor_program = build_registered_temporal_executor_program( + runtime, + static_tensors, + table_plan=temporal_table_plan, + forward_plan=forward_executable_plan, + backward_plan=backward_executable_plan, + streaming_readout_body_available=streaming_readout_body_available, + streaming_readout_body_profitable=True, + ) + scan_schedule = build_scalar_temporal_scan_schedule( + outer_time_steps=int(boundary_seq.shape[1]), + inner_steps=int(inner_steps), + ) + return run_registered_temporal_forward_executor_scan( + runtime, + executor_program=executor_program, + boundary_seq=boundary_seq, + state=state, + population_resets=population_resets, + transition_resets=transition_resets, + static_tensors=static_tensors, + inner_steps=int(inner_steps), + output_contract=output_contract, + output_boundary=output_boundary, + collect_artifacts=collect_artifacts, + materialize_final_state=materialize_final_state, + transition_tape_mode=transition_tape_mode, + memory_artifact_plan=build_temporal_memory_runtime_artifact_plan( + executor_program.memory_plan, + physical_time_steps=int(scan_schedule.physical_time_steps), + collect_artifacts=collect_artifacts, + scheduler_plan=scheduler_plan, + ), + artifact_checkpoints=artifact_checkpoints, + scan_schedule=scan_schedule, + initial_population_state_cache=initial_population_state_cache, + ) + + +def _estimate_temporal_transition_tape_step_bytes( + runtime: Any, + static_tensors: dict[str, object], + *, + batch_size: int, + dtype_bytes: int, + mode: TemporalTransitionTapeMode, +) -> int: + if mode == "disabled": + return 0 + hidden_size = int(runtime.hidden_size) + total = 0 + buckets = tuple(backend_order_flat_buckets(runtime, static_tensors)) + table_plan = build_temporal_primitive_table_plan(runtime, static_tensors) + full_tape_extra_factors = 
temporal_table_full_tape_extra_state_factors(table_plan) + for bucket_ordinal, bucket in enumerate(buckets): + base = int(batch_size) * int(bucket.count) * hidden_size * int(dtype_bytes) + total += base + if mode != "full": + continue + total += int(full_tape_extra_factors.get(int(bucket_ordinal), 0)) * base + return int(total) + + +def _temporal_transition_kinds(runtime: Any, static_tensors: dict[str, object]) -> frozenset[str]: + return temporal_table_transition_kind_labels(build_temporal_primitive_table_plan(runtime, static_tensors)) + + +def temporal_transition_tape_policy( + runtime: Any, + static_tensors: dict[str, object], + *, + batch_size: int, + time_steps: int, + device: torch.device, + dtype_bytes: int, + tape_policy_bin: str | None = None, +) -> TemporalTransitionTapePolicy: + if time_steps <= 1: + return TemporalTransitionTapePolicy(mode="full", reason="transition_tape=full;time_steps<=1") + memory = runtime._cuda_memory_budget(device) if hasattr(runtime, "_cuda_memory_budget") else None + if memory is None: + return TemporalTransitionTapePolicy(mode="disabled", reason="transition_tape=disabled;memory=unknown") + input_step_bytes = _estimate_temporal_transition_tape_step_bytes( + runtime, + static_tensors, + batch_size=batch_size, + dtype_bytes=dtype_bytes, + mode="input_projection", + ) + full_step_bytes = _estimate_temporal_transition_tape_step_bytes( + runtime, + static_tensors, + batch_size=batch_size, + dtype_bytes=dtype_bytes, + mode="full", + ) + input_bytes = int(input_step_bytes) * int(time_steps) + full_bytes = int(full_step_bytes) * int(time_steps) + transition_kinds = _temporal_transition_kinds(runtime, static_tensors) + reserve_bytes = max(4 << 30, int(memory.total_bytes * 0.04)) + budget_usable_bytes = ( + int(memory.usable_bytes) if int(memory.reusable_reserved_bytes) > int(reserve_bytes) else int(memory.free_bytes) + ) + if tape_policy_bin in {"checkpoint", "tbptt"}: + bounded_budget_bytes = max( + 0, min(int(memory.total_bytes * 0.01), int(budget_usable_bytes) - int(reserve_bytes)) + ) + if full_bytes > 0 and full_bytes <= bounded_budget_bytes: + bounded_mode: TemporalTransitionTapeMode = "full" + elif input_bytes > 0 and input_bytes <= bounded_budget_bytes: + bounded_mode = "input_projection" + else: + bounded_mode = "disabled" + return TemporalTransitionTapePolicy( + mode=bounded_mode, + reason=( + f"transition_tape={bounded_mode};planner_tape_policy={tape_policy_bin};" + f"bounded_temporal_recompute=1;input_step_bytes={int(input_step_bytes)};" + f"input_sequence_bytes={int(input_bytes)};full_step_bytes={int(full_step_bytes)};" + f"full_sequence_bytes={int(full_bytes)};bounded_budget_bytes={int(bounded_budget_bytes)};" + f"transition_kinds={','.join(sorted(transition_kinds))};" + f"free_bytes={int(memory.free_bytes)};allocator_reusable_bytes={int(memory.reusable_reserved_bytes)};" + f"reserve_bytes={int(reserve_bytes)}" + ), + ) + budget_bytes = max(0, min(int(memory.total_bytes * 0.08), int(budget_usable_bytes) - int(reserve_bytes))) + if full_bytes > 0 and full_bytes <= budget_bytes: + mode: TemporalTransitionTapeMode = "full" + elif input_bytes > 0 and input_bytes <= budget_bytes: + mode = "input_projection" + else: + mode = "disabled" + return TemporalTransitionTapePolicy( + mode=mode, + reason=( + f"transition_tape={mode};input_step_bytes={int(input_step_bytes)};" + f"input_sequence_bytes={int(input_bytes)};full_step_bytes={int(full_step_bytes)};" + f"full_sequence_bytes={int(full_bytes)};budget_bytes={int(budget_bytes)};" + 
f"transition_kinds={','.join(sorted(transition_kinds))};" + f"free_bytes={int(memory.free_bytes)};allocator_reusable_bytes={int(memory.reusable_reserved_bytes)};" + f"reserve_bytes={int(reserve_bytes)}" + ), + ) + + +def _make_temporal_artifact_checkpoint( + *, + step_index: int, + state: TensorDict, + population_state_cache: dict[str, object] | None, + recurrent_k: torch.Tensor | None, + recurrent_v: torch.Tensor | None, + recurrent_kv_layout: TemporalPublicCarryOrder | None = None, +) -> TemporalArtifactCheckpoint: + resolved_layout = ( + None + if recurrent_k is None or recurrent_v is None + else recurrent_kv_layout + if recurrent_kv_layout is not None + else "graph_order" + ) + return TemporalArtifactCheckpoint( + step_index=int(step_index), + state=TensorDict(state.to_dict(), batch_size=[]), + population_state_cache=dict(population_state_cache) if population_state_cache is not None else None, + recurrent_k=recurrent_k, + recurrent_v=recurrent_v, + recurrent_kv_layout=resolved_layout, + ) + + +def _apply_temporal_recurrent_kv_reset( + *, + reset_step: torch.Tensor | None, + recurrent_k: torch.Tensor | None, + recurrent_v: torch.Tensor | None, +) -> tuple[torch.Tensor | None, torch.Tensor | None]: + if reset_step is None or recurrent_k is None or recurrent_v is None: + return recurrent_k, recurrent_v + reset_mask = reset_step.to(device=recurrent_k.device, dtype=torch.bool).view(int(reset_step.shape[0]), 1, 1) + return ( + torch.where(reset_mask, torch.zeros_like(recurrent_k), recurrent_k), + torch.where(reset_mask, torch.zeros_like(recurrent_v), recurrent_v), + ) + + +def run_shared_temporal_bucket_forward_scan( + runtime: Any, + *, + boundary_seq: torch.Tensor, + state: TensorDict, + population_resets: torch.Tensor | None, + transition_resets: torch.Tensor | None = None, + static_tensors: dict[str, object], + inner_steps: int, + materialize_final_state: bool, + output_contract: TemporalOutputContract, + output_boundary: Literal["sequence", "terminal"], + planned_backward_execution: Any | None = None, + temporal_plan: Any | None = None, + collect_artifacts: bool = False, + backend_population_state_is_fresh: bool = False, +) -> SharedTemporalForwardScanResult: + inner_steps = max(1, int(inner_steps)) + running_state = TensorDict(state.to_dict(), batch_size=[]) + artifact_store: TemporalArtifactStore | None = None + artifact_checkpoints: dict[int, TemporalArtifactCheckpoint] = {} + state_has_population_views = all( + isinstance(running_state.get(name), TensorDictBase) + for name in runtime._population_names + if int(runtime._population_recurrent_indices(name).numel()) > 0 + ) + if backend_population_state_is_fresh: + step_population_state_cache = runtime._prepare_fresh_stream_step_population_cache( + batch=int(boundary_seq.shape[0]), + device=boundary_seq.device, + dtype=boundary_seq.dtype, + ) + elif state_has_population_views: + step_population_state_cache = runtime._prepare_stream_step_population_cache( + running_state, + batch=int(boundary_seq.shape[0]), + device=boundary_seq.device, + dtype=boundary_seq.dtype, + ) + else: + step_population_state_cache = { + name: runtime._init_backend_population_state( + name, + batch=int(boundary_seq.shape[0]), + device=boundary_seq.device, + dtype=boundary_seq.dtype, + ) + for name in runtime._population_names + if int(runtime._population_recurrent_indices(name).numel()) > 0 + } + if not step_population_state_cache: + step_population_state_cache = None + if collect_artifacts: + artifact_checkpoints[0] = _make_temporal_artifact_checkpoint( + 
step_index=0, + state=running_state, + population_state_cache=step_population_state_cache, + recurrent_k=None, + recurrent_v=None, + ) + transition_tape_policy = temporal_transition_tape_policy( + runtime, + static_tensors, + batch_size=int(boundary_seq.shape[0]), + time_steps=int(boundary_seq.shape[1]) * int(inner_steps), + device=boundary_seq.device, + dtype_bytes=int(boundary_seq.element_size()), + tape_policy_bin=getattr(planned_backward_execution, "tape_policy_bin", None), + ) + else: + transition_tape_policy = TemporalTransitionTapePolicy( + mode="disabled", + reason="transition_tape=disabled;shared_temporal_forward_inference", + ) + scan_schedule = build_scalar_temporal_scan_schedule( + outer_time_steps=int(boundary_seq.shape[1]), + inner_steps=int(inner_steps), + ) + physical_time_steps = scan_schedule.physical_time_steps + runtime._last_flat_bucket_temporal_recurrent_kv_carry_reuse = physical_time_steps > 1 + runtime._last_flat_bucket_forward_transition_executor = "registered_fused_forward_program_cuda" + runtime._last_flat_bucket_transition_tape_mode = transition_tape_policy.mode + runtime._last_flat_bucket_transition_tape_reason = transition_tape_policy.reason + runtime._last_flat_bucket_temporal_artifact_mode = None + runtime._last_flat_bucket_temporal_artifact_reason = None + runtime._last_flat_bucket_temporal_artifact_checkpoint_stride = None + runtime._last_flat_bucket_temporal_artifact_recompute_window_len = None + runtime._last_flat_bucket_temporal_artifact_checkpoint_count = None + runtime._last_flat_bucket_state_cache_mode = "registered_fused_program_pending" + runtime._last_flat_bucket_state_cache_materialized_steps = 0 + runtime._last_flat_bucket_state_cache_elided_steps = 0 + temporal_plan = ( + temporal_plan if temporal_plan is not None else getattr(planned_backward_execution, "temporal_plan", None) + ) + scheduler_plan = build_temporal_runtime_scheduler_plan( + temporal_plan=temporal_plan, + outer_time_steps=int(boundary_seq.shape[1]), + inner_steps=inner_steps, + output_boundary=output_boundary, + output_contract=output_contract, + materialize_final_state=materialize_final_state, + collect_artifacts=collect_artifacts, + ) + runtime._last_flat_bucket_temporal_scheduler_plan = scheduler_plan.review_summary + registered_scan_result = _try_registered_temporal_forward_scan( + runtime, + boundary_seq=boundary_seq, + state=running_state, + population_resets=population_resets, + transition_resets=transition_resets, + static_tensors=static_tensors, + inner_steps=inner_steps, + output_contract=output_contract, + output_boundary=output_boundary, + collect_artifacts=collect_artifacts, + materialize_final_state=materialize_final_state, + backend_population_state_is_fresh=backend_population_state_is_fresh, + transition_tape_mode=transition_tape_policy.mode, + artifact_checkpoints=artifact_checkpoints, + initial_population_state_cache=step_population_state_cache, + temporal_plan=temporal_plan, + scheduler_plan=scheduler_plan, + ) + if registered_scan_result is None: + raise RuntimeError( + "Shared temporal engine requires the compiler-owned CUDA temporal table scan. " + "Add the missing fabric.cuda.nn primitive executor or fail the unsupported declaration at lowering." 
+ ) + if collect_artifacts: + artifact_store = registered_scan_result.artifact_store + if artifact_store is None: + raise RuntimeError("Registered temporal forward executor produced no planned artifact store") + artifact_store = _attach_temporal_compiler_plans_to_artifact_store( + runtime, + static_tensors=static_tensors, + artifact_store=artifact_store, + ) + runtime._last_flat_bucket_temporal_artifact_mode = artifact_store.mode + runtime._last_flat_bucket_temporal_artifact_reason = artifact_store.reason + runtime._last_flat_bucket_temporal_artifact_checkpoint_stride = artifact_store.checkpoint_stride + runtime._last_flat_bucket_temporal_artifact_recompute_window_len = artifact_store.recompute_window_len + runtime._last_flat_bucket_temporal_artifact_checkpoint_count = len(artifact_store.checkpoints) + return SharedTemporalForwardScanResult( + output_seq=registered_scan_result.output_seq, + final_state=registered_scan_result.final_state, + artifact_store=artifact_store, + ) + return registered_scan_result diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/temporal/output_backward.py b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/output_backward.py new file mode 100644 index 00000000..c48dd06a --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/output_backward.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +from typing import Any + +import torch + +from cortical.fabric.backend.cuda.sequence_surface.compiler.scan_schedule import ( + emitted_output_index_for_scan_step, + scalar_temporal_scan_step, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.types import ( + TemporalBucketStepArtifacts, + TemporalOutputContract, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.scheduler import ( + TemporalOutputEmissionRuntimePlan, +) + + +def _zero_temporal_output_grad_for_contract( + runtime: Any, + artifacts: TemporalBucketStepArtifacts, + output_contract: TemporalOutputContract, +) -> torch.Tensor: + if output_contract == "full_cells": + return artifacts.cells_out.new_zeros(artifacts.cells_out.shape) + if output_contract == "output_cells": + return artifacts.output_cells.new_zeros(artifacts.output_cells.shape) + if output_contract == "pooled_output_cells": + pooled = runtime._pool_output_ports(artifacts.output_cells.unsqueeze(1)).squeeze(1) + return pooled.new_zeros(pooled.shape) + raise RuntimeError(f"Unsupported temporal output contract {output_contract!r}") + + +def _temporal_output_grad_for_physical_step( + grad_output_seq: torch.Tensor | None, + *, + global_step_index: int, + outer_time_steps: int, + inner_steps: int, + output_emissions: TemporalOutputEmissionRuntimePlan | None = None, +) -> torch.Tensor | None: + if grad_output_seq is None: + return None + if output_emissions is not None: + return output_emissions.output_grad_for_physical_step( + grad_output_seq, + physical_step=int(global_step_index), + ) + scan_step = scalar_temporal_scan_step(physical_step=global_step_index, inner_steps=inner_steps) + grad_time_steps = int(grad_output_seq.shape[1]) + output_index = emitted_output_index_for_scan_step( + scan_step, + outer_time_steps=outer_time_steps, + emitted_time_steps=grad_time_steps, + ) + return None if output_index is None else grad_output_seq[:, output_index] + + +def _temporal_output_grad_physical_window( + runtime: Any, + artifacts_window: list[TemporalBucketStepArtifacts], + *, + grad_output_seq: torch.Tensor | None, + window_start: int, + outer_time_steps: int, + inner_steps: int, + 
output_contract: TemporalOutputContract, + output_emissions: TemporalOutputEmissionRuntimePlan | None = None, +) -> torch.Tensor | None: + grad_steps = [] + for local_step_index, artifacts in enumerate(artifacts_window): + grad_step = _temporal_output_grad_for_physical_step( + grad_output_seq, + global_step_index=int(window_start) + int(local_step_index), + outer_time_steps=outer_time_steps, + inner_steps=inner_steps, + output_emissions=output_emissions, + ) + if grad_step is None: + grad_step = _zero_temporal_output_grad_for_contract(runtime, artifacts, output_contract) + grad_steps.append(grad_step) + return torch.stack(grad_steps, dim=1) + + +def _temporal_output_active_local_steps( + grad_output_seq: torch.Tensor | None, + *, + window_start: int, + window_len: int, + outer_time_steps: int, + inner_steps: int, + output_emissions: TemporalOutputEmissionRuntimePlan | None = None, +) -> tuple[int, ...]: + if grad_output_seq is None: + return () + if output_emissions is not None: + return output_emissions.active_local_steps( + grad_output_seq, + window_start=int(window_start), + window_len=int(window_len), + ) + active_steps: list[int] = [] + for local_step_index in range(int(window_len)): + grad_step = _temporal_output_grad_for_physical_step( + grad_output_seq, + global_step_index=int(window_start) + int(local_step_index), + outer_time_steps=int(outer_time_steps), + inner_steps=int(inner_steps), + ) + if grad_step is not None: + active_steps.append(int(local_step_index)) + return tuple(active_steps) diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py new file mode 100644 index 00000000..b075ce54 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py @@ -0,0 +1,1016 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Literal, TypeAlias, cast + +import torch + +from cortical.fabric.backend.cuda.sequence_surface.flat_bucket.flat_bucket_registered_program_cuda import ( + registered_temporal_parameter_reducer_program_cuda, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.native_callables import ( + parameter_reducer_native_callable_id, + temporal_strategy_id_hash, + transition_trainable_reducer_native_callable_id, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.reducer_patterns import ( + TemporalParameterReducerPattern, + temporal_message_strategy_grad_output_role_opcode, + temporal_parameter_reducer_count_mode_opcode, + temporal_parameter_reducer_count_target_opcode, + temporal_parameter_reducer_kind_opcode, + temporal_parameter_reducer_pattern, + temporal_parameter_reducer_pattern_for_opcode, + temporal_parameter_reducer_strategy_opcode, + temporal_parameter_runtime_role_opcode, + temporal_parameter_trainable_role_name, + temporal_parameter_trainable_role_opcode, + temporal_parameter_trainable_roles, + temporal_transition_trainable_reducer_kind_opcode, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.executor_bindings import ( + TemporalTransitionParamGradBinding, +) +from cortical.fabric.backend.cuda.sequence_surface.runtime.support import ( + _accumulate_owned_tensor_grad, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.types import ( + TemporalSenderKVProjectionRawParamGrad, + _TransitionNamedGradSequence, +) + +_PARAMETER_REDUCER_STAGE_OPCODE = 5 +_TRANSITION_SOURCE_KIND_OPCODE = { + "materialized": 1, + "static_source": 2, 
+} + + +@dataclass(frozen=True) +class TemporalReadoutOutputParamReducerRequest: + kind: Literal["readout_output"] + readout_executor: Any + grad_value_to_output_weight: torch.Tensor | None + grad_output_cell_bias: torch.Tensor | None + + +@dataclass(frozen=True) +class TemporalSenderKVProjectionParamReducerRequest: + kind: Literal["sender_kv_projection"] + message_executor: Any + raw_grads: tuple[TemporalSenderKVProjectionRawParamGrad | None, ...] + + +@dataclass(frozen=True) +class TemporalRecurrentQueryParamReducerRequest: + kind: Literal["recurrent_query"] + message_executor: Any + readout_executor: Any + grad_recurrent_q_backend: torch.Tensor | None + grad_output_q: torch.Tensor | None + device: torch.device + dtype: torch.dtype + + +@dataclass(frozen=True) +class TemporalTransitionParamReducerRequest: + kind: Literal["transition"] + population_name: str + materialized_grad_accum: _TransitionNamedGradSequence + static_source_accum: _TransitionNamedGradSequence + transition_param_grad_bindings: tuple[TemporalTransitionParamGradBinding, ...] + + +@dataclass(frozen=True) +class TemporalMessageStrategyParamReducerRequest: + kind: Literal["message_strategy"] + reducer_kind: str + message_executor: Any + grad_outputs: tuple[tuple[str, torch.Tensor | None], ...] + + +TemporalParameterReducerRequest: TypeAlias = ( + TemporalReadoutOutputParamReducerRequest + | TemporalSenderKVProjectionParamReducerRequest + | TemporalRecurrentQueryParamReducerRequest + | TemporalTransitionParamReducerRequest + | TemporalMessageStrategyParamReducerRequest +) + + +@dataclass(frozen=True) +class TemporalParameterReducerProgram: + requests: tuple[TemporalParameterReducerRequest, ...] + rows: torch.Tensor + strategy_rows: torch.Tensor + trainable_role_rows: torch.Tensor + transition_source_rows: torch.Tensor + transition_trainable_rows: torch.Tensor + transition_source_names: tuple[str, ...] + summaries: tuple[str, ...] + strategy_summaries: tuple[str, ...] + trainable_role_summaries: tuple[str, ...] + transition_source_summaries: tuple[str, ...] + transition_trainable_summaries: tuple[str, ...] 
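# --- Illustrative sketch (annotation, not part of the diff) ----------------
# The TemporalParameterReducerProgram fields above pair integer row tensors
# with parallel human-readable summary strings; the builders later in this
# file populate them from compiler opcodes (reducer_patterns). A minimal
# standalone version of that row-table pattern, with hypothetical opcodes and
# a hypothetical three-column layout, not the real column contract:

import torch

_KIND_OPCODE = {"readout_output": 1, "sender_kv_projection": 2}  # assumed values


def build_row_table(requests: list[tuple[str, int]]) -> tuple[torch.Tensor, tuple[str, ...]]:
    """Encode (kind, tensor_count) requests as [row, kind_opcode, tensor_count] rows."""
    rows: list[list[int]] = []
    summaries: list[str] = []
    for kind, tensor_count in requests:
        row_index = len(rows)
        rows.append([row_index, _KIND_OPCODE[kind], int(tensor_count)])
        summaries.append(f"row={row_index},kind={kind},tensor_count={int(tensor_count)}")
    # Keep the column count stable even when there are no rows.
    table = torch.tensor(rows, dtype=torch.long) if rows else torch.empty((0, 3), dtype=torch.long)
    return table, tuple(summaries)


# Example: two requests produce a [2, 3] long tensor plus two summary strings.
table, summaries = build_row_table([("readout_output", 2), ("sender_kv_projection", 1)])
assert table.shape == (2, 3) and table.dtype == torch.long
# --- end sketch -------------------------------------------------------------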
+ + +def _require_parameter_reducer_stage( + reverse_program_stage_rows: torch.Tensor, + *, + reducer_name: str, + executor_row_index: int, + executor_id: int, +) -> None: + if ( + reverse_program_stage_rows.device.type != "cpu" + or reverse_program_stage_rows.dtype != torch.long + or reverse_program_stage_rows.dim() != 2 + or int(reverse_program_stage_rows.shape[1]) != 10 + ): + raise RuntimeError( + "Registered temporal parameter reducer requires compiler reverse stage rows with shape [N,10]" + ) + for row in reverse_program_stage_rows.tolist(): + if ( + int(row[1]) == _PARAMETER_REDUCER_STAGE_OPCODE + and int(row[3]) == int(executor_row_index) + and int(row[4]) == int(executor_id) + ): + return + raise RuntimeError( + "Registered temporal parameter reducer has no compiler parameter-reducer stage row: " + f"reducer={reducer_name}; executor_row={int(executor_row_index)}; executor_id={int(executor_id)}" + ) + + +def _transition_named_grad_sequence_has_tensor(accumulator: _TransitionNamedGradSequence) -> bool: + return any(torch.is_tensor(grad) for grads in accumulator.values() for grad in grads) + + +def _parameter_reducer_strategy_rows( + reducer_kinds: tuple[str, ...], +) -> tuple[torch.Tensor, tuple[str, ...]]: + rows: list[list[int]] = [] + summaries: list[str] = [] + for reducer_kind in reducer_kinds: + pattern = temporal_parameter_reducer_pattern(reducer_kind) + count_target = pattern.count_target + count_mode = pattern.count_mode + native_callable = parameter_reducer_native_callable_id(reducer_kind) + row_index = len(rows) + rows.append( + [ + int(row_index), + int(temporal_parameter_reducer_kind_opcode(reducer_kind)), + int(temporal_parameter_reducer_strategy_opcode(reducer_kind)), + int(temporal_parameter_reducer_count_target_opcode(count_target)), + int(temporal_parameter_reducer_count_mode_opcode(count_mode)), + 0, + 0, + 0, + int(temporal_strategy_id_hash(native_callable)), + ] + ) + summaries.append( + f"row={int(row_index)},reducer={reducer_kind},strategy={reducer_kind}," + f"count_target={count_target},count_mode={count_mode},native_callable={native_callable}" + ) + return ( + torch.tensor(rows, dtype=torch.long) if rows else torch.empty((0, 9), dtype=torch.long), + tuple(summaries), + ) + + +def _parameter_reducer_trainable_role_rows( + trainable_param_names: tuple[str, ...], +) -> tuple[torch.Tensor, tuple[str, ...]]: + trainable_name_to_index = {name: index for index, name in enumerate(trainable_param_names)} + rows: list[list[int]] = [] + summaries: list[str] = [] + for role in temporal_parameter_trainable_roles(): + trainable_name = temporal_parameter_trainable_role_name(role) + parameter_index = trainable_name_to_index.get(trainable_name) + if parameter_index is None: + continue + row_index = len(rows) + rows.append( + [ + int(row_index), + int(temporal_parameter_trainable_role_opcode(role)), + int(parameter_index), + 0, + 0, + 0, + ] + ) + summaries.append( + f"row={int(row_index)},role={role},parameter_index={int(parameter_index)},parameter={trainable_name}" + ) + return ( + torch.tensor(rows, dtype=torch.long) if rows else torch.empty((0, 6), dtype=torch.long), + tuple(summaries), + ) + + +def _parameter_reducer_runtime_metadata_rows_and_tensors( + runtime: Any, +) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, tuple[str, ...]]: + role_to_tensor = { + "population_backend_recurrent_inverse_order": runtime.population_backend_recurrent_inverse_order, + "recurrent_cell_idx": runtime.recurrent_cell_idx, + "output_cell_idx": runtime.output_cell_idx, + "input_cell_idx": 
runtime.input_cell_idx, + } + rows: list[list[int]] = [] + tensors: list[torch.Tensor] = [] + summaries: list[str] = [] + for role, tensor in role_to_tensor.items(): + row_index = len(rows) + rows.append( + [ + int(row_index), + int(temporal_parameter_runtime_role_opcode(role)), + int(row_index), + 0, + ] + ) + tensors.append(tensor) + summaries.append(f"row={int(row_index)},role={role},tensor_index={int(row_index)}") + return ( + tuple(tensors), + torch.tensor(rows, dtype=torch.long) if rows else torch.empty((0, 4), dtype=torch.long), + tuple(summaries), + ) + + +def _parameter_reducer_active_parameter_indices( + program: TemporalParameterReducerProgram, +) -> frozenset[int]: + role_to_parameter_index = {int(row[1]): int(row[2]) for row in program.trainable_role_rows.tolist()} + active_roles: set[str] = set() + for row in program.rows.tolist(): + active_roles.update(temporal_parameter_reducer_pattern_for_opcode(int(row[1])).active_trainable_roles) + active_indices = { + int(role_to_parameter_index[int(temporal_parameter_trainable_role_opcode(role))]) + for role in active_roles + if int(temporal_parameter_trainable_role_opcode(role)) in role_to_parameter_index + } + for row in program.transition_trainable_rows.tolist(): + active_indices.add(int(row[4])) + return frozenset(active_indices) + + +def _parameter_reducer_output_tensors( + program: TemporalParameterReducerProgram, + trainable_params: tuple[torch.Tensor, ...], +) -> tuple[torch.Tensor, ...]: + active_indices = _parameter_reducer_active_parameter_indices(program) + return tuple( + torch.zeros_like(param.detach()) if parameter_index in active_indices else param.detach().new_empty(0) + for parameter_index, param in enumerate(trainable_params) + ) + + +def build_temporal_parameter_reducer_program( + *, + requests: tuple[TemporalParameterReducerRequest, ...], + reverse_program_stage_rows: torch.Tensor, + trainable_param_names: tuple[str, ...], +) -> TemporalParameterReducerProgram: + rows: list[list[int]] = [] + transition_source_rows: list[list[int]] = [] + transition_trainable_rows: list[list[int]] = [] + summaries: list[str] = [] + transition_source_summaries: list[str] = [] + transition_trainable_summaries: list[str] = [] + source_name_to_index: dict[str, int] = {} + trainable_name_to_index = {name: index for index, name in enumerate(trainable_param_names)} + transition_source_tensor_cursor = 0 + reducer_kind_names: list[str] = [] + + def source_name_index(name: str) -> int: + index = source_name_to_index.get(name) + if index is None: + index = len(source_name_to_index) + source_name_to_index[name] = int(index) + return int(index) + + def add_row( + reducer_kind: str, + *, + request_index: int, + executor_row_index: int, + executor_id: int, + bucket_ordinal: int, + tensor_count: int, + flags: int = 0, + ) -> None: + nonlocal transition_source_tensor_cursor + _require_parameter_reducer_stage( + reverse_program_stage_rows, + reducer_name=reducer_kind, + executor_row_index=int(executor_row_index), + executor_id=int(executor_id), + ) + row_index = len(rows) + if reducer_kind not in reducer_kind_names: + reducer_kind_names.append(reducer_kind) + rows.append( + [ + int(row_index), + int(temporal_parameter_reducer_kind_opcode(reducer_kind)), + int(executor_row_index), + int(executor_id), + int(bucket_ordinal), + int(request_index), + int(tensor_count), + int(flags), + ] + ) + summaries.append( + f"row={int(row_index)},kind={reducer_kind},executor_row={int(executor_row_index)}," + 
f"executor_id={int(executor_id)},bucket={int(bucket_ordinal)},request={int(request_index)}," + f"tensor_count={int(tensor_count)},flags={int(flags)}" + ) + if reducer_kind == "transition": + transition_source_tensor_cursor = _add_transition_source_rows( + request=cast(TemporalTransitionParamReducerRequest, requests[int(request_index)]), + request_index=int(request_index), + source_name_index=source_name_index, + tensor_start=int(transition_source_tensor_cursor), + rows=transition_source_rows, + summaries=transition_source_summaries, + ) + _add_transition_trainable_rows( + request=cast(TemporalTransitionParamReducerRequest, requests[int(request_index)]), + request_index=int(request_index), + trainable_name_to_index=trainable_name_to_index, + source_name_index=source_name_index, + rows=transition_trainable_rows, + summaries=transition_trainable_summaries, + ) + + for request_index, request in enumerate(requests): + if isinstance(request, TemporalReadoutOutputParamReducerRequest): + tensor_count = int(torch.is_tensor(request.grad_value_to_output_weight)) + int( + torch.is_tensor(request.grad_output_cell_bias) + ) + if tensor_count > 0: + add_row( + "readout_output", + request_index=int(request_index), + executor_row_index=int(request.readout_executor.row_index), + executor_id=int(request.readout_executor.row.executor_id), + bucket_ordinal=int(request.readout_executor.bucket_ordinal), + tensor_count=tensor_count, + ) + elif isinstance(request, TemporalSenderKVProjectionParamReducerRequest): + raw_count = sum(raw_grad is not None for raw_grad in request.raw_grads) + if raw_count > 0: + add_row( + "sender_kv_projection", + request_index=int(request_index), + executor_row_index=int(request.message_executor.row_index), + executor_id=int(request.message_executor.row.executor_id), + bucket_ordinal=int(request.message_executor.bucket_ordinal), + tensor_count=int(raw_count), + ) + elif isinstance(request, TemporalRecurrentQueryParamReducerRequest): + if request.grad_recurrent_q_backend is not None: + add_row( + "recurrent_query", + request_index=int(request_index), + executor_row_index=int(request.message_executor.row_index), + executor_id=int(request.message_executor.row.executor_id), + bucket_ordinal=int(request.message_executor.bucket_ordinal), + tensor_count=1, + ) + if request.grad_output_q is not None: + add_row( + "output_query", + request_index=int(request_index), + executor_row_index=int(request.readout_executor.row_index), + executor_id=int(request.readout_executor.row.executor_id), + bucket_ordinal=int(request.readout_executor.bucket_ordinal), + tensor_count=1, + ) + elif isinstance(request, TemporalTransitionParamReducerRequest): + if not ( + _transition_named_grad_sequence_has_tensor(request.materialized_grad_accum) + or _transition_named_grad_sequence_has_tensor(request.static_source_accum) + ): + continue + if not request.transition_param_grad_bindings: + raise RuntimeError("Registered transition parameter reducer requires compiler reducer binding rows") + first_binding = request.transition_param_grad_bindings[0] + add_row( + "transition", + request_index=int(request_index), + executor_row_index=int(first_binding.executor_row_index), + executor_id=int(first_binding.executor_id), + bucket_ordinal=int(first_binding.bucket_ordinal), + tensor_count=sum(len(grads) for grads in request.materialized_grad_accum.values()) + + sum(len(grads) for grads in request.static_source_accum.values()), + ) + elif isinstance(request, TemporalMessageStrategyParamReducerRequest): + tensor_count = 
sum(int(torch.is_tensor(tensor)) for _name, tensor in request.grad_outputs) + if tensor_count > 0: + add_row( + request.reducer_kind, + request_index=int(request_index), + executor_row_index=int(request.message_executor.row_index), + executor_id=int(request.message_executor.row.executor_id), + bucket_ordinal=int(request.message_executor.bucket_ordinal), + tensor_count=tensor_count, + ) + else: + raise RuntimeError(f"Unsupported temporal parameter reducer request {request!r}") + + rows_tensor = torch.tensor(rows, dtype=torch.long) if rows else torch.empty((0, 8), dtype=torch.long) + transition_source_rows_tensor = ( + torch.tensor(transition_source_rows, dtype=torch.long) + if transition_source_rows + else torch.empty((0, 8), dtype=torch.long) + ) + transition_rows_tensor = ( + torch.tensor(transition_trainable_rows, dtype=torch.long) + if transition_trainable_rows + else torch.empty((0, 9), dtype=torch.long) + ) + source_names = tuple(name for name, _index in sorted(source_name_to_index.items(), key=lambda item: int(item[1]))) + strategy_rows_tensor, strategy_summaries = _parameter_reducer_strategy_rows(tuple(reducer_kind_names)) + trainable_role_rows_tensor, trainable_role_summaries = _parameter_reducer_trainable_role_rows( + trainable_param_names, + ) + return TemporalParameterReducerProgram( + requests=requests, + rows=rows_tensor, + strategy_rows=strategy_rows_tensor, + trainable_role_rows=trainable_role_rows_tensor, + transition_source_rows=transition_source_rows_tensor, + transition_trainable_rows=transition_rows_tensor, + transition_source_names=source_names, + summaries=tuple(summaries), + strategy_summaries=strategy_summaries, + trainable_role_summaries=trainable_role_summaries, + transition_source_summaries=tuple(transition_source_summaries), + transition_trainable_summaries=tuple(transition_trainable_summaries), + ) + + +def _add_transition_source_rows( + *, + request: TemporalTransitionParamReducerRequest, + request_index: int, + source_name_index: Any, + tensor_start: int, + rows: list[list[int]], + summaries: list[str], +) -> int: + cursor = int(tensor_start) + + def add_source(source_name: str, source_kind: str, tensor_count: int) -> None: + nonlocal cursor + row_index = len(rows) + rows.append( + [ + int(row_index), + int(request_index), + int(source_name_index(source_name)), + int(_TRANSITION_SOURCE_KIND_OPCODE[source_kind]), + int(cursor), + int(tensor_count), + 0, + 0, + ] + ) + summaries.append( + f"row={int(row_index)},request={int(request_index)},source={source_name}," + f"kind={source_kind},tensor_start={int(cursor)},tensor_count={int(tensor_count)}" + ) + cursor += int(tensor_count) + + for source_name, grads in request.materialized_grad_accum.items(): + if len(grads) <= 0: + continue + add_source(source_name, "materialized", len(grads)) + for source_name, grads in request.static_source_accum.items(): + if len(grads) <= 0: + continue + add_source(source_name, "static_source", len(grads)) + return int(cursor) + + +def _add_transition_trainable_rows( + *, + request: TemporalTransitionParamReducerRequest, + request_index: int, + trainable_name_to_index: dict[str, int], + source_name_index: Any, + rows: list[list[int]], + summaries: list[str], +) -> None: + population_prefix = f"population_modules.{request.population_name}." 
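# --- Illustrative sketch (annotation, not part of the diff) ----------------
# _add_transition_source_rows above flattens per-source gradient tensors into
# one sequence and records, per source row, where that source's slice starts
# and how many tensors it holds. A standalone sketch of that cursor layout,
# with simplified (name, tensor_start, tensor_count) rows instead of the real
# eight-column rows:

import torch


def pack_sources(sources: dict[str, list[torch.Tensor]]) -> tuple[list[torch.Tensor], list[tuple[str, int, int]]]:
    """Flatten per-source tensors and record (name, tensor_start, tensor_count) rows."""
    flat: list[torch.Tensor] = []
    rows: list[tuple[str, int, int]] = []
    cursor = 0
    for name, tensors in sources.items():
        if not tensors:
            continue
        rows.append((name, cursor, len(tensors)))
        flat.extend(tensors)
        cursor += len(tensors)
    return flat, rows


def unpack_source(flat: list[torch.Tensor], row: tuple[str, int, int]) -> list[torch.Tensor]:
    _name, start, count = row
    return flat[start : start + count]


# Round trip: the slice recovered from (start, count) matches the packed input.
grads = {"a_weight": [torch.zeros(2, 2)], "b_bias": [torch.ones(3), torch.ones(3)]}
flat, rows = pack_sources(grads)
assert [t.shape for t in unpack_source(flat, rows[1])] == [torch.Size([3]), torch.Size([3])]
# --- end sketch -------------------------------------------------------------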
+ + def add_row( + *, + source_name: str, + reducer_kind: str, + target_name: str, + aux_name: str = "", + ) -> None: + target_index = trainable_name_to_index.get(target_name) + if target_index is None: + return + aux_index = -1 if not aux_name else trainable_name_to_index.get(aux_name, -1) + if aux_name and aux_index < 0: + return + row_index = len(rows) + native_callable = transition_trainable_reducer_native_callable_id(reducer_kind) + rows.append( + [ + int(row_index), + int(request_index), + int(source_name_index(source_name)), + int(temporal_transition_trainable_reducer_kind_opcode(reducer_kind)), + int(target_index), + int(aux_index), + 0, + 0, + int(temporal_strategy_id_hash(native_callable)), + ] + ) + summaries.append( + f"row={int(row_index)},request={int(request_index)},source={source_name}," + f"kind={reducer_kind},target_index={int(target_index)},target={target_name}," + f"aux_index={int(aux_index)},aux={aux_name or '-'},native_callable={native_callable}" + ) + + for materialized_name, grads in request.materialized_grad_accum.items(): + if len(grads) <= 0: + continue + add_row( + source_name=materialized_name, + reducer_kind="materialized_base", + target_name=f"{population_prefix}{materialized_name}_base", + ) + add_row( + source_name=materialized_name, + reducer_kind="materialized_delta", + target_name=f"{population_prefix}{materialized_name}_delta", + ) + if request.static_source_accum.get("value_to_cell_weight"): + add_row( + source_name="value_to_cell_weight", + reducer_kind="value_to_cell_msg_to_cell", + target_name="msg_to_cell.weight", + aux_name="msg_out.weight", + ) + add_row( + source_name="value_to_cell_weight", + reducer_kind="value_to_cell_msg_out", + target_name="msg_out.weight", + aux_name="msg_to_cell.weight", + ) + if request.static_source_accum.get("message_to_cell_weight"): + add_row( + source_name="message_to_cell_weight", + reducer_kind="materialized_base", + target_name="msg_to_cell.weight", + ) + if request.static_source_accum.get("recurrent_cell_bias"): + add_row( + source_name="recurrent_cell_bias", + reducer_kind="recurrent_bias_slot_embed", + target_name="slot_embed", + aux_name="cell_bias_proj.weight", + ) + add_row( + source_name="recurrent_cell_bias", + reducer_kind="recurrent_bias_cell_bias_proj", + target_name="cell_bias_proj.weight", + aux_name="slot_embed", + ) + + +def _registered_named_param_grad_tuple( + *, + named_grads: dict[str, torch.Tensor | None], + trainable_params: tuple[torch.Tensor, ...], + trainable_param_names: tuple[str, ...], +) -> tuple[torch.Tensor | None, ...]: + return tuple(named_grads.get(name) for name in trainable_param_names) + + +def _require_sender_kv_projection_param_binding( + message_executor: Any, + static_tensors: dict[str, object], + raw_grad: TemporalSenderKVProjectionRawParamGrad, +) -> None: + role = str(raw_grad.role) + if role == "recurrent": + if message_executor.optional_static_tensor(static_tensors, "recurrent_sender_kv_weight") is not None: + return + if message_executor.optional_static_tensor(static_tensors, "recurrent_sender_value_weight") is not None: + return + raise RuntimeError( + "Registered sender projection parameter reducer requires a compiler-bound recurrent sender weight: " + "logical='recurrent_sender_kv_weight' or 'recurrent_sender_value_weight'" + ) + if role == "input": + logical_name = "input_group_kv_weight" if raw_grad.grouped else "input_sender_kv_weight" + if message_executor.optional_static_tensor(static_tensors, logical_name) is not None: + return + context_nudge_logical = 
"input_group_value_weight" if raw_grad.grouped else "input_sender_value_weight" + if message_executor.optional_static_tensor(static_tensors, context_nudge_logical) is not None: + return + raise RuntimeError( + "Registered sender projection parameter reducer requires a compiler-bound input sender weight: " + f"logical={logical_name!r} or {context_nudge_logical!r}" + ) + raise RuntimeError(f"Registered sender K/V parameter reducer received unsupported role {role!r}") + + +def _transition_recurrent_index_for_source( + runtime: Any, + request: TemporalTransitionParamReducerRequest, + *, + source_name: str, + source_tensors: tuple[torch.Tensor, ...], +) -> torch.Tensor: + reference = next((tensor for tensor in source_tensors if torch.is_tensor(tensor)), None) + if reference is None: + return runtime.recurrent_cell_idx.new_empty(0) + if source_name in {"message_to_cell_weight", "value_to_cell_weight"}: + return torch.empty(0, device=reference.device, dtype=torch.long) + recurrent_row_idx = torch.arange( + int(runtime.recurrent_cell_idx.numel()), + device=reference.device, + dtype=torch.long, + ) + population_recurrent_row_idx = runtime._population_recurrent_indices(request.population_name) + population_full_recurrent_idx = runtime.recurrent_cell_idx.index_select( + 0, + population_recurrent_row_idx.to(device=runtime.recurrent_cell_idx.device, dtype=torch.long), + ) + row_axis_candidates = [0] + if reference.dim() >= 4: + row_axis_candidates.append(2) + if reference.dim() >= 3: + row_axis_candidates.append(1) + for axis in row_axis_candidates: + if axis >= reference.dim(): + continue + row_count = int(reference.shape[axis]) + if row_count == int(population_recurrent_row_idx.numel()): + if source_name == "recurrent_cell_bias": + return population_full_recurrent_idx.to(device=reference.device, dtype=torch.long).contiguous() + return population_recurrent_row_idx.to(device=reference.device, dtype=torch.long).contiguous() + if row_count == int(recurrent_row_idx.numel()): + if source_name == "recurrent_cell_bias": + return runtime.recurrent_cell_idx.to(device=reference.device, dtype=torch.long).contiguous() + return recurrent_row_idx.contiguous() + if source_name == "recurrent_cell_bias": + return runtime.recurrent_cell_idx.to(device=reference.device, dtype=torch.long).contiguous() + return torch.empty(0, device=reference.device, dtype=torch.long) + + +def _validate_transition_param_reducer_program( + *, + transition_param_grad_bindings: tuple[TemporalTransitionParamGradBinding, ...], + materialized_names: set[str], + static_source_names: set[str], +) -> None: + if not transition_param_grad_bindings: + raise RuntimeError("Registered transition parameter reducer requires compiler reducer binding rows") + direct_materialized = { + str(binding.parameter_name) + for binding in transition_param_grad_bindings + if binding.reducer_kind == "materialized" + } + input_projection_parameters = { + str(binding.parameter_name) + for binding in transition_param_grad_bindings + if str(binding.reducer_kind).startswith("input_projection_") + } + has_input_projection_weight = any( + binding.reducer_kind == "input_projection_weight" for binding in transition_param_grad_bindings + ) + has_input_projection_bias = any( + binding.reducer_kind == "input_projection_bias" for binding in transition_param_grad_bindings + ) + allowed_input_projection_materialized = {"input_proj_weight"} | input_projection_parameters + for name in materialized_names: + if name in direct_materialized: + continue + if has_input_projection_weight and 
name in allowed_input_projection_materialized: + continue + raise RuntimeError( + "Registered transition parameter reducer received a materialized gradient with no compiler row: " + f"name={name!r}; allowed={tuple(sorted(direct_materialized | allowed_input_projection_materialized))}" + ) + for name in static_source_names: + if name == "value_to_cell_weight" and has_input_projection_weight: + continue + if name == "message_to_cell_weight" and has_input_projection_weight: + continue + if name == "recurrent_cell_bias" and has_input_projection_bias: + continue + raise RuntimeError( + "Registered transition parameter reducer received a static-source gradient with no compiler row: " + f"name={name!r}" + ) + + +def _require_message_strategy_reducer_bindings( + *, + pattern: TemporalParameterReducerPattern, + message_executor: Any, + static_tensors: dict[str, object], +) -> None: + for logical_group in pattern.required_static_logical_groups: + if any( + message_executor.optional_static_tensor(static_tensors, logical_name) is not None + for logical_name in logical_group + ): + continue + raise RuntimeError( + "Registered message strategy parameter reducer is missing compiler-bound static tensor group: " + f"reducer={pattern.reducer_kind!r}; executor={message_executor.executor_name}; " + f"logical_options={tuple(logical_group)!r}" + ) + + +def _run_registered_parameter_reducer_cuda_program( + runtime: Any, + *, + program: TemporalParameterReducerProgram, + static_tensors: dict[str, object], + trainable_params: tuple[torch.Tensor, ...], + trainable_param_names: tuple[str, ...], +) -> tuple[torch.Tensor | None, ...]: + sender_grad_weight_tensors: list[torch.Tensor] = [] + sender_group_id_tensors: list[torch.Tensor] = [] + sender_grouped_flags: list[int] = [] + readout_grad_value_to_output_weight_tensors: list[torch.Tensor] = [] + readout_grad_output_cell_bias_tensors: list[torch.Tensor] = [] + recurrent_query_grad_tensors: list[torch.Tensor] = [] + output_query_grad_tensors: list[torch.Tensor] = [] + message_strategy_grad_tensors: list[torch.Tensor] = [] + message_strategy_grad_rows: list[list[int]] = [] + transition_source_tensors: list[torch.Tensor] = [] + transition_source_recurrent_cell_idx_tensors: list[torch.Tensor] = [] + summaries: list[str] = [] + executed_requests: set[int] = set() + transition_opcode = int(temporal_parameter_reducer_kind_opcode("transition")) + + for row in program.rows.tolist(): + reducer_kind = int(row[1]) + if reducer_kind == transition_opcode: + continue + request_index = int(row[5]) + if request_index in executed_requests: + continue + executed_requests.add(request_index) + request = program.requests[request_index] + if isinstance(request, TemporalReadoutOutputParamReducerRequest): + if request.grad_value_to_output_weight is not None: + request.readout_executor.require_static_tensor(static_tensors, "value_to_output_weight") + readout_grad_value_to_output_weight_tensors.append(request.grad_value_to_output_weight.contiguous()) + if request.grad_output_cell_bias is not None: + request.readout_executor.require_runtime_tensor_attr(runtime, "output_cell_bias") + readout_grad_output_cell_bias_tensors.append(request.grad_output_cell_bias.contiguous()) + summaries.append( + "readout_output:" + f"value_weight={int(torch.is_tensor(request.grad_value_to_output_weight))};" + f"bias={int(torch.is_tensor(request.grad_output_cell_bias))}" + ) + elif isinstance(request, TemporalSenderKVProjectionParamReducerRequest): + raw_count = 0 + for raw_grad in request.raw_grads: + if raw_grad is 
None: + continue + _require_sender_kv_projection_param_binding(request.message_executor, static_tensors, raw_grad) + sender_grad_weight_tensors.append(raw_grad.grad_weight.contiguous()) + sender_group_id_tensors.append( + raw_grad.group_ids.to(device=raw_grad.grad_weight.device, dtype=torch.long).contiguous() + ) + sender_grouped_flags.append(int(bool(raw_grad.grouped))) + raw_count += 1 + summaries.append(f"sender_kv_projection:raw_grads={int(raw_count)}") + elif isinstance(request, TemporalRecurrentQueryParamReducerRequest): + if request.grad_recurrent_q_backend is not None: + request.message_executor.require_static_tensor(static_tensors, "recurrent_q_weight") + recurrent_query_grad_tensors.append(request.grad_recurrent_q_backend.contiguous()) + if request.grad_output_q is not None: + request.readout_executor.require_static_tensor(static_tensors, "output_q") + output_query_grad_tensors.append(request.grad_output_q.contiguous()) + summaries.append( + "query:" + f"recurrent={int(torch.is_tensor(request.grad_recurrent_q_backend))};" + f"output={int(torch.is_tensor(request.grad_output_q))}" + ) + elif isinstance(request, TemporalMessageStrategyParamReducerRequest): + pattern = temporal_parameter_reducer_pattern(request.reducer_kind) + _require_message_strategy_reducer_bindings( + pattern=pattern, + message_executor=request.message_executor, + static_tensors=static_tensors, + ) + grad_output_by_name = dict(request.grad_outputs) + tensor_count = 0 + for grad_name, grad_tensor in grad_output_by_name.items(): + if grad_tensor is None: + continue + if pattern.grad_output_roles and str(grad_name) not in pattern.grad_output_roles: + raise RuntimeError( + "Registered message strategy parameter reducer received an output role outside its " + "compiler reducer contract: " + f"reducer={request.reducer_kind!r}; output={grad_name!r}; " + f"declared={pattern.grad_output_roles!r}" + ) + try: + role_opcode = temporal_message_strategy_grad_output_role_opcode(str(grad_name)) + except RuntimeError as exc: + raise RuntimeError( + "Registered message strategy parameter reducer received an undeclared output role: " + f"reducer={request.reducer_kind!r}; output={grad_name!r}" + ) from exc + tensor_index = len(message_strategy_grad_tensors) + message_strategy_grad_tensors.append(grad_tensor.contiguous()) + message_strategy_grad_rows.append( + [ + len(message_strategy_grad_rows), + int(temporal_parameter_reducer_kind_opcode(request.reducer_kind)), + int(role_opcode), + int(tensor_index), + 0, + ] + ) + tensor_count += 1 + summaries.append(f"message_strategy:{request.reducer_kind}:tensors={int(tensor_count)}") + else: + raise RuntimeError(f"Unsupported common temporal parameter reducer request {request!r}") + + covered_transition_sources = {(int(row[1]), int(row[2])) for row in program.transition_trainable_rows.tolist()} + for row in program.transition_source_rows.tolist(): + request_index = int(row[1]) + source_index = int(row[2]) + source_kind = int(row[3]) + tensor_count = int(row[5]) + if source_index < 0 or source_index >= len(program.transition_source_names): + raise RuntimeError(f"Transition source reducer row has invalid source index {source_index}") + if (request_index, source_index) not in covered_transition_sources: + source_name = program.transition_source_names[source_index] + raise RuntimeError( + "Registered transition parameter reducer has gradient source with no compiler trainable row: " + f"request={int(request_index)}; source={source_name!r}" + ) + request = program.requests[request_index] + if not 
isinstance(request, TemporalTransitionParamReducerRequest): + raise RuntimeError(f"Transition source reducer row references non-transition request {request!r}") + _validate_transition_param_reducer_program( + transition_param_grad_bindings=request.transition_param_grad_bindings, + materialized_names=set(request.materialized_grad_accum), + static_source_names=set(request.static_source_accum), + ) + source_name = program.transition_source_names[source_index] + source_grads = ( + request.materialized_grad_accum.get(source_name, ()) + if source_kind == _TRANSITION_SOURCE_KIND_OPCODE["materialized"] + else request.static_source_accum.get(source_name, ()) + if source_kind == _TRANSITION_SOURCE_KIND_OPCODE["static_source"] + else () + ) + if len(source_grads) != tensor_count: + raise RuntimeError( + "Transition source reducer row tensor count does not match request gradients: " + f"request={int(request_index)}; source={source_name!r}; " + f"row_count={int(tensor_count)}; actual={len(source_grads)}" + ) + source_tuple = tuple(grad.contiguous() for grad in source_grads) + transition_source_tensors.extend(source_tuple) + transition_source_recurrent_cell_idx_tensors.append( + _transition_recurrent_index_for_source( + runtime, + request, + source_name=source_name, + source_tensors=source_tuple, + ) + ) + + if not ( + sender_grad_weight_tensors + or readout_grad_value_to_output_weight_tensors + or readout_grad_output_cell_bias_tensors + or recurrent_query_grad_tensors + or output_query_grad_tensors + or message_strategy_grad_tensors + or transition_source_tensors + ): + return tuple(None for _param in trainable_params) + + runtime_metadata_tensors, runtime_metadata_rows, runtime_metadata_summaries = ( + _parameter_reducer_runtime_metadata_rows_and_tensors(runtime) + ) + parameter_output_tensors = _parameter_reducer_output_tensors(program, trainable_params) + output_tensors = registered_temporal_parameter_reducer_program_cuda( + parameter_reducer_rows=program.rows, + parameter_reducer_strategy_rows=program.strategy_rows, + parameter_reducer_trainable_role_rows=program.trainable_role_rows, + parameter_reducer_runtime_metadata_rows=runtime_metadata_rows, + transition_source_rows=program.transition_source_rows, + transition_trainable_rows=program.transition_trainable_rows, + sender_grad_weight_tensors=tuple(sender_grad_weight_tensors), + sender_group_id_tensors=tuple(sender_group_id_tensors), + sender_grouped_flags=torch.tensor(sender_grouped_flags, dtype=torch.long), + readout_grad_value_to_output_weight_tensors=tuple(readout_grad_value_to_output_weight_tensors), + readout_grad_output_cell_bias_tensors=tuple(readout_grad_output_cell_bias_tensors), + recurrent_query_grad_tensors=tuple(recurrent_query_grad_tensors), + output_query_grad_tensors=tuple(output_query_grad_tensors), + message_strategy_grad_tensors=tuple(message_strategy_grad_tensors), + message_strategy_grad_rows=( + torch.tensor(message_strategy_grad_rows, dtype=torch.long) + if message_strategy_grad_rows + else torch.empty((0, 5), dtype=torch.long) + ), + transition_source_tensors=tuple(transition_source_tensors), + transition_source_recurrent_cell_idx_tensors=tuple(transition_source_recurrent_cell_idx_tensors), + parameter_output_tensors=parameter_output_tensors, + trainable_param_tensors=tuple(param.detach() for param in trainable_params), + runtime_metadata_tensors=runtime_metadata_tensors, + coord_count=int(runtime.coords.shape[0]), + head_dim=int(runtime.head_dim), + value_dim=int(runtime.value_dim), + ) + if len(output_tensors) != 
len(trainable_params): + raise RuntimeError( + "Registered parameter reducer returned mismatched trainable output count: " + f"actual={len(output_tensors)}; expected={len(trainable_params)}" + ) + grad_tuple: list[torch.Tensor | None] = [None for _param in trainable_params] + for parameter_index, grad in enumerate(output_tensors): + if int(grad.numel()) == 0: + continue + grad_tuple[parameter_index] = _accumulate_owned_tensor_grad(grad_tuple[parameter_index], grad) + runtime._last_flat_bucket_common_param_reducer_program = "registered_temporal_parameter_reducer_cuda_row_program" + runtime._last_flat_bucket_common_param_reducer_rows = program.rows + runtime._last_flat_bucket_common_param_reducer_strategy_rows = program.strategy_rows + runtime._last_flat_bucket_common_param_reducer_summaries = tuple(summaries) + runtime._last_flat_bucket_common_param_reducer_strategy_summaries = program.strategy_summaries + runtime._last_flat_bucket_common_param_reducer_trainable_role_rows = program.trainable_role_rows + runtime._last_flat_bucket_common_param_reducer_trainable_role_summaries = program.trainable_role_summaries + runtime._last_flat_bucket_common_param_reducer_runtime_metadata_rows = runtime_metadata_rows + runtime._last_flat_bucket_common_param_reducer_runtime_metadata_summaries = runtime_metadata_summaries + runtime._last_flat_bucket_common_param_reducer_output_plan = "compiler_owned_parameter_output_tensor_table" + runtime._last_flat_bucket_common_param_reducer_output_active_indices = tuple( + sorted(_parameter_reducer_active_parameter_indices(program)) + ) + runtime._last_flat_bucket_transition_param_reducer_program = ( + "registered_transition_cuda_trainable_parameter_row_program" + ) + runtime._last_flat_bucket_transition_param_reducer_rows = program.transition_trainable_rows + runtime._last_flat_bucket_transition_param_reducer_summaries = program.transition_trainable_summaries + runtime._last_flat_bucket_transition_param_reducer_source_rows = program.transition_source_rows + runtime._last_flat_bucket_transition_param_reducer_source_summaries = program.transition_source_summaries + return tuple(grad_tuple) + + +def run_temporal_parameter_reducer_program( + runtime: Any, + *, + program: TemporalParameterReducerProgram, + static_tensors: dict[str, object], + trainable_params: tuple[torch.Tensor, ...], + trainable_param_names: tuple[str, ...], +) -> tuple[torch.Tensor | None, ...]: + grad_accum: list[torch.Tensor | None] = [None] * len(trainable_params) + + def accumulate(grads: tuple[torch.Tensor | None, ...]) -> None: + for parameter_index, grad in enumerate(grads): + grad_accum[parameter_index] = _accumulate_owned_tensor_grad(grad_accum[parameter_index], grad) + + accumulate( + _run_registered_parameter_reducer_cuda_program( + runtime, + program=program, + static_tensors=static_tensors, + trainable_params=trainable_params, + trainable_param_names=trainable_param_names, + ) + ) + + runtime._last_flat_bucket_temporal_parameter_reducer_program = ( + "registered_temporal_parameter_reducer_cuda_row_program" + ) + runtime._last_flat_bucket_temporal_parameter_reducer_rows = program.rows + runtime._last_flat_bucket_temporal_parameter_reducer_strategy_rows = program.strategy_rows + runtime._last_flat_bucket_temporal_parameter_reducer_summaries = program.summaries + runtime._last_flat_bucket_temporal_parameter_reducer_strategy_summaries = program.strategy_summaries + runtime._last_flat_bucket_temporal_parameter_reducer_trainable_role_rows = program.trainable_role_rows + 
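# --- Illustrative sketch (annotation, not part of the diff) ----------------
# The loop just above treats a zero-element output tensor as "this reducer
# produced no gradient for that parameter" and otherwise accumulates into a
# per-parameter slot. A standalone sketch of that convention, using a plain
# add instead of the imported _accumulate_owned_tensor_grad helper, whose
# exact ownership semantics are not reproduced here:

import torch


def accumulate_param_grads(
    accum: list[torch.Tensor | None],
    outputs: tuple[torch.Tensor, ...],
) -> list[torch.Tensor | None]:
    for parameter_index, grad in enumerate(outputs):
        if grad.numel() == 0:
            continue  # empty tensor means no contribution for this parameter
        if accum[parameter_index] is None:
            accum[parameter_index] = grad.clone()
        else:
            accum[parameter_index] = accum[parameter_index] + grad
    return accum


# Parameter 1 receives a gradient twice; parameter 0 is never touched.
acc = accumulate_param_grads([None, None], (torch.empty(0), torch.ones(2)))
acc = accumulate_param_grads(acc, (torch.empty(0), torch.ones(2)))
assert acc[0] is None and torch.equal(acc[1], torch.full((2,), 2.0))
# --- end sketch -------------------------------------------------------------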
runtime._last_flat_bucket_temporal_parameter_reducer_trainable_role_summaries = program.trainable_role_summaries + runtime._last_flat_bucket_temporal_transition_trainable_parameter_rows = program.transition_trainable_rows + runtime._last_flat_bucket_temporal_transition_trainable_parameter_summaries = program.transition_trainable_summaries + runtime._last_flat_bucket_temporal_transition_source_parameter_rows = program.transition_source_rows + runtime._last_flat_bucket_temporal_transition_source_parameter_summaries = program.transition_source_summaries + return tuple(grad_accum) diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/temporal/physical_autograd.py b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/physical_autograd.py new file mode 100644 index 00000000..d7f256f4 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/physical_autograd.py @@ -0,0 +1,300 @@ +from __future__ import annotations + +from typing import Any, Literal, cast + +import torch +from tensordict import TensorDict, TensorDictBase + +from cortical.fabric.backend.cuda.sequence_surface.temporal.types import ( + TemporalArtifactStore, + TemporalOutputContract, +) + +from cortical.fabric.backend.cuda.sequence_surface.temporal.common import ( + _validate_temporal_physical_backward_plan, +) + +from cortical.fabric.backend.cuda.sequence_surface.temporal.forward_scan import ( + run_shared_temporal_bucket_forward_scan, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.reverse_executor import ( + TemporalPhysicalBackwardScanExecutor, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.registered_executors import ( + _record_registered_backward_memory_stage, +) + +_TemporalPopulationStateSpec = tuple[tuple[str, tuple[str, ...]], ...] 
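# --- Illustrative sketch (annotation, not part of the diff) ----------------
# The population spec alias above, together with _TemporalStateSpec defined
# immediately below, describes how a nested state is flattened into a plain
# tuple of tensors: first the top-level keys, then each population's state
# keys in a fixed order. A hypothetical spec value and the flat tensor count
# it implies (population and key names are made up for illustration):

example_spec = (
    ("cells",),                    # top-level keys, one tensor each
    (
        ("pop_a", ("h", "c")),     # population "pop_a": two state keys
        ("pop_b", ("h",)),         # population "pop_b": one state key
    ),
)

top_level_keys, population_specs = example_spec
flat_tensor_count = len(top_level_keys) + sum(len(keys) for _name, keys in population_specs)
assert flat_tensor_count == 4  # "cells", "pop_a.h", "pop_a.c", "pop_b.h"
# --- end sketch -------------------------------------------------------------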
+_TemporalStateSpec = tuple[tuple[str, ...], _TemporalPopulationStateSpec] + + +def _flatten_temporal_state_inputs( + runtime: Any, + state: TensorDict, +) -> tuple[_TemporalStateSpec, tuple[torch.Tensor, ...]]: + top_level_keys = ("cells",) + tensors: list[torch.Tensor] = [state["cells"]] + population_specs: list[tuple[str, tuple[str, ...]]] = [] + for population_name in runtime._population_names: + population_state = state[population_name] + if not isinstance(population_state, TensorDictBase): + raise RuntimeError(f"Temporal bucket sequence requires TensorDict state for {population_name}") + keys = tuple(runtime._compiled_transition_state_names_for_population(population_name)) + population_specs.append((population_name, keys)) + for key in keys: + tensor = population_state[key] + if not torch.is_tensor(tensor): + raise RuntimeError(f"Temporal bucket sequence state {population_name}.{key} is not a tensor") + tensors.append(tensor) + return (top_level_keys, tuple(population_specs)), tuple(tensors) + + +def _unflatten_temporal_state( + specs: _TemporalStateSpec, + tensors: tuple[torch.Tensor, ...], +) -> TensorDict: + top_level_keys, population_specs = specs + state = TensorDict({}, batch_size=[]) + offset = 0 + for key in top_level_keys: + state[key] = tensors[offset] + offset += 1 + for population_name, keys in population_specs: + leaves: dict[str, torch.Tensor] = {} + first: torch.Tensor | None = None + for key in keys: + tensor = tensors[offset] + offset += 1 + leaves[key] = tensor + if first is None: + first = tensor + state[population_name] = TensorDict( + leaves, + batch_size=[] if first is None else list(first.shape[:2]), + device=None if first is None else first.device, + ) + return state + + +def _flatten_temporal_state_grad_outputs( + specs: _TemporalStateSpec, + grad_state: TensorDict, +) -> tuple[torch.Tensor | None, ...]: + top_level_keys, population_specs = specs + grads: list[torch.Tensor | None] = [] + for key in top_level_keys: + grad = grad_state.get(key) + grads.append(grad if torch.is_tensor(grad) else None) + for population_name, keys in population_specs: + population_grad = grad_state.get(population_name) + for key in keys: + if isinstance(population_grad, TensorDictBase): + grad = population_grad.get(key) + grads.append(grad if torch.is_tensor(grad) else None) + else: + grads.append(None) + return tuple(grads) + + +def run_temporal_bucket_sequence_physical_autograd( + runtime: Any, + *, + boundary_seq: torch.Tensor, + state: TensorDict, + population_resets: torch.Tensor | None, + transition_resets: torch.Tensor | None = None, + static_tensors: dict[str, object], + planned_backward_execution: Any | None, + materialize_final_state: bool, + output_contract: TemporalOutputContract = "full_cells", + output_boundary: Literal["sequence", "terminal"] = "sequence", + inner_steps: int = 1, +) -> tuple[torch.Tensor, TensorDict]: + inner_steps = max(1, int(inner_steps)) + state_specs, state_tensors = _flatten_temporal_state_inputs(runtime, state) + trainable_items = tuple((name, param) for name, param in runtime.named_parameters() if param.requires_grad) + outputs = _TemporalBucketSequenceFunction.apply( + runtime, + static_tensors, + planned_backward_execution, + state_specs, + tuple(name for name, _param in trainable_items), + materialize_final_state, + output_contract, + output_boundary, + inner_steps, + boundary_seq, + population_resets, + transition_resets, + *state_tensors, + *(param for _name, param in trainable_items), + ) + output_seq = cast(torch.Tensor, outputs[0]) + 
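# --- Illustrative sketch (annotation, not part of the diff) ----------------
# run_temporal_bucket_sequence_physical_autograd above passes a variable
# number of state tensors plus trainable parameters positionally through
# torch.autograd.Function.apply; the forward/backward below split them again
# using a count stored on ctx. A stripped-down standalone Function showing
# only that packing pattern, with none of the real temporal math:

import torch


class _PackedSum(torch.autograd.Function):
    @staticmethod
    def forward(ctx, state_count: int, *tensors: torch.Tensor) -> torch.Tensor:
        # The first `state_count` tensors are "state", the rest "parameters".
        ctx.state_count = int(state_count)
        ctx.save_for_backward(*tensors)
        return sum(t.sum() for t in tensors)

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        tensors = ctx.saved_tensors
        state = tensors[: ctx.state_count]
        params = tensors[ctx.state_count :]
        state_grads = tuple(grad_out * torch.ones_like(t) for t in state)
        param_grads = tuple(grad_out * torch.ones_like(t) for t in params)
        # One grad slot per forward input: None for the int count, then tensors.
        return (None, *state_grads, *param_grads)


state = (torch.zeros(2, requires_grad=True),)
params = (torch.zeros(3, requires_grad=True), torch.zeros(1, requires_grad=True))
out = _PackedSum.apply(len(state), *state, *params)
out.backward()
assert params[0].grad is not None and params[0].grad.shape == (3,)
# --- end sketch -------------------------------------------------------------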
if not materialize_final_state: + return output_seq, TensorDict({}, batch_size=[]) + state_tensor_count = len(state_tensors) + final_state = _unflatten_temporal_state( + state_specs, + cast(tuple[torch.Tensor, ...], tuple(outputs[1 : 1 + state_tensor_count])), + ) + return output_seq, final_state + + +class _TemporalBucketSequenceFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + runtime: Any, + static_tensors: dict[str, object], + planned_backward_execution: Any | None, + state_specs: _TemporalStateSpec, + trainable_param_names: tuple[str, ...], + materialize_final_state: bool, + output_contract: TemporalOutputContract, + output_boundary: Literal["sequence", "terminal"], + inner_steps: int, + boundary_seq: torch.Tensor, + population_resets: torch.Tensor | None, + transition_resets: torch.Tensor | None, + *state_tensors_and_params: torch.Tensor, + ) -> tuple[torch.Tensor, ...]: + state_tensor_count = 1 + sum(len(keys) for _population_name, keys in state_specs[1]) + state_tensors = tuple(state_tensors_and_params[:state_tensor_count]) + trainable_params = tuple(state_tensors_and_params[state_tensor_count:]) + ctx.runtime = runtime + ctx.static_tensors = static_tensors + ctx.planned_backward_execution = planned_backward_execution + ctx.state_specs = state_specs + ctx.trainable_param_names = trainable_param_names + ctx.materialize_final_state = bool(materialize_final_state) + ctx.output_contract = output_contract + ctx.output_boundary = output_boundary + ctx.inner_steps = max(1, int(inner_steps)) + ctx.state_tensor_count = state_tensor_count + ctx.trainable_param_count = len(trainable_params) + ctx.has_resets = torch.is_tensor(population_resets) + ctx.has_transition_resets = torch.is_tensor(transition_resets) + ctx.save_for_backward( + boundary_seq, + *((population_resets,) if torch.is_tensor(population_resets) else ()), + *((transition_resets,) if torch.is_tensor(transition_resets) else ()), + *state_tensors, + *trainable_params, + ) + scan_result = run_shared_temporal_bucket_forward_scan( + runtime, + boundary_seq=boundary_seq, + state=_unflatten_temporal_state(state_specs, state_tensors), + population_resets=population_resets, + transition_resets=transition_resets, + static_tensors=static_tensors, + inner_steps=int(ctx.inner_steps), + materialize_final_state=materialize_final_state, + output_contract=cast(TemporalOutputContract, output_contract), + output_boundary=output_boundary, + planned_backward_execution=planned_backward_execution, + collect_artifacts=True, + ) + if scan_result.artifact_store is None: + raise RuntimeError("Temporal bucket sequence produced no artifacts") + ctx.artifact_store = scan_result.artifact_store + if not materialize_final_state: + return (scan_result.output_seq,) + final_state_tensors = _flatten_temporal_state_grad_outputs(state_specs, scan_result.final_state) + if any(tensor is None for tensor in final_state_tensors): + raise RuntimeError("Temporal bucket sequence produced incomplete final state") + return (scan_result.output_seq, *cast(tuple[torch.Tensor, ...], final_state_tensors)) + + @staticmethod + def backward( + ctx: Any, + *grad_outputs: torch.Tensor | None, + ) -> tuple[object, ...]: + saved = ctx.saved_tensors + offset = 0 + boundary_seq = saved[offset] + offset += 1 + if not hasattr(ctx.runtime, "_last_flat_bucket_temporal_registered_backward_memory_stages"): + ctx.runtime._last_flat_bucket_temporal_registered_backward_memory_stages = () + _record_registered_backward_memory_stage(ctx.runtime, boundary_seq, 
"autograd_backward_entry") + population_resets = None + if ctx.has_resets: + population_resets = saved[offset] + offset += 1 + transition_resets = None + if ctx.has_transition_resets: + transition_resets = saved[offset] + offset += 1 + state_tensor_count = int(ctx.state_tensor_count) + state_tensors = tuple(saved[offset : offset + state_tensor_count]) + offset += state_tensor_count + trainable_params = tuple(saved[offset:]) + + artifact_store = cast(TemporalArtifactStore, ctx.artifact_store) + _validate_temporal_physical_backward_plan(ctx.planned_backward_execution) + _record_registered_backward_memory_stage(ctx.runtime, boundary_seq, "autograd_plan_validated") + + grad_output_seq = grad_outputs[0] + grad_final_state = ( + _unflatten_temporal_state( + ctx.state_specs, + cast(tuple[torch.Tensor, ...], tuple(grad_outputs[1 : 1 + state_tensor_count])), + ) + if ctx.materialize_final_state + else TensorDict({}, batch_size=[]) + ) + _record_registered_backward_memory_stage(ctx.runtime, boundary_seq, "autograd_grad_outputs_prepared") + scan_result = TemporalPhysicalBackwardScanExecutor( + ctx.runtime, + static_tensors=ctx.static_tensors, + trainable_params=cast(tuple[torch.Tensor, ...], trainable_params), + trainable_param_names=ctx.trainable_param_names, + output_contract=cast(TemporalOutputContract, ctx.output_contract), + output_boundary=cast(Literal["sequence", "terminal"], ctx.output_boundary), + materialize_final_state=ctx.materialize_final_state, + boundary_requires_grad=boundary_seq.requires_grad, + state_requires_grad=any(tensor.requires_grad for tensor in state_tensors), + inner_steps=int(ctx.inner_steps), + temporal_plan=getattr(ctx.planned_backward_execution, "temporal_plan", None), + ).run( + boundary_seq=boundary_seq, + artifact_store=artifact_store, + population_resets=cast(torch.Tensor | None, population_resets), + transition_resets=cast(torch.Tensor | None, transition_resets), + grad_output_seq=grad_output_seq, + grad_final_state=grad_final_state, + ) + + grad_state_payload: dict[str, object] = { + name: TensorDict( + {key: value for key, value in population_grad.items() if torch.is_tensor(value)}, + batch_size=[], + ) + for name, population_grad in scan_result.grad_next_population_state.items() + } + if state_tensors[0].requires_grad: + grad_state_payload["cells"] = ( + scan_result.grad_carry_cells + if torch.is_tensor(scan_result.grad_carry_cells) + else torch.zeros_like(state_tensors[0]) + ) + grad_state_tensors = _flatten_temporal_state_grad_outputs( + ctx.state_specs, + TensorDict(grad_state_payload, batch_size=[]), + ) + return ( + None, + None, + None, + None, + None, + None, + None, + None, + None, + scan_result.grad_boundary_seq, + None, + None, + *grad_state_tensors, + *scan_result.param_grads, + ) diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_parameters.py b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_parameters.py new file mode 100644 index 00000000..79b3cfcb --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_parameters.py @@ -0,0 +1,364 @@ +from __future__ import annotations + +from typing import Any + +import torch + +from cortical.fabric.backend.cuda.sequence_surface.compiler.buckets import ( + backend_order_flat_buckets, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.common import ( + _flat_bucket_name, +) + + +def transition_buckets_for_executors( + runtime: Any, + transition_executors: tuple[Any, ...], + static_tensors: dict[str, object], +) -> 
tuple[tuple[Any, Any], ...]: + buckets = tuple(backend_order_flat_buckets(runtime, static_tensors)) + by_ordinal = {int(index): bucket for index, bucket in enumerate(buckets)} + ordered: list[tuple[Any, Any]] = [] + for executor in sorted(transition_executors, key=lambda item: int(item.bucket_ordinal)): + bucket = by_ordinal.get(int(executor.bucket_ordinal)) + if bucket is None: + raise RuntimeError( + "Registered transition executor row points at a missing compiler bucket: " + f"bucket={int(executor.bucket_ordinal)}; bucket_count={len(buckets)}" + ) + if int(bucket.backend_start) != int(executor.row.receiver_start) or ( + int(bucket.backend_stop) - int(bucket.backend_start) != int(executor.row.receiver_count) + ): + raise RuntimeError( + "Registered transition executor row receiver range does not match compiler bucket: " + f"bucket={int(executor.bucket_ordinal)}; " + f"row_start={int(executor.row.receiver_start)}; row_count={int(executor.row.receiver_count)}; " + f"bucket_start={int(bucket.backend_start)}; " + f"bucket_count={int(bucket.backend_stop) - int(bucket.backend_start)}" + ) + ordered.append((executor, bucket)) + return tuple(ordered) + + +def transition_program_slot_table( + *, + executors: tuple[Any, ...], + reference: torch.Tensor, +) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, dict[str, int]]: + slot_by_bucket_logical: dict[tuple[int, str], int] = {} + logical_to_slot: dict[str, int] = {} + tensors: list[torch.Tensor] = [] + rows: list[list[int]] = [] + + def slot_for_binding(binding: Any) -> int: + logical_name = str(binding.logical_name) + key = (int(binding.bucket_ordinal), logical_name) + slot = slot_by_bucket_logical.get(key) + if slot is None: + slot = len(tensors) + slot_by_bucket_logical[key] = int(slot) + logical_to_slot.setdefault(logical_name, int(slot)) + tensors.append(reference.new_empty(0)) + return int(slot) + + for executor in executors: + for binding in executor.bindings: + rows.append( + [ + int(binding.binding_index), + slot_for_binding(binding), + int(binding.primitive_row_index), + 0 if binding.binding_kind == "input" else 1 if binding.binding_kind == "parameter" else 2, + ] + ) + return tuple(tensors), torch.tensor(rows, dtype=torch.long), logical_to_slot + + +def transition_parameter_tensor_table( + runtime: Any, + *, + transition_executors: tuple[Any, ...], + static_tensors: dict[str, object], + reference: torch.Tensor, + recurrent_msg_dim: int, +) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]: + message_program = getattr(getattr(runtime, "backend_ir", None), "message_program", None) + prefer_projected_message_input = str(getattr(message_program, "output_dim_role", "value_dim")) != "value_dim" + buckets = tuple(backend_order_flat_buckets(runtime, static_tensors)) + by_ordinal = {int(index): bucket for index, bucket in enumerate(buckets)} + tensors: list[torch.Tensor] = [] + rows: list[list[int]] = [] + slot_by_bucket_logical: dict[tuple[int, str], int] = {} + + for executor in transition_executors: + bucket_ordinal = int(executor.bucket_ordinal) + bucket = by_ordinal.get(bucket_ordinal) + if bucket is None: + raise RuntimeError( + "Registered transition parameter table references a missing compiler bucket: " + f"bucket={bucket_ordinal}; bucket_count={len(buckets)}" + ) + population_params = _transition_population_params(bucket) + for binding in executor.bindings: + if binding.binding_kind != "parameter": + continue + logical_name = str(binding.logical_name) + key = (bucket_ordinal, logical_name) + slot = slot_by_bucket_logical.get(key) + if 
slot is None: + slot = len(tensors) + slot_by_bucket_logical[key] = int(slot) + tensors.append( + _resolve_transition_parameter_tensor( + binding, + bucket=bucket, + static_tensors=static_tensors, + population_params=population_params, + recurrent_msg_dim=int(recurrent_msg_dim), + receiver_count=int(executor.row.receiver_count), + prefer_projected_message_input=bool(prefer_projected_message_input), + ) + ) + rows.append([int(binding.binding_index), int(slot), bucket_ordinal]) + if not rows: + return (), torch.empty((0, 3), dtype=torch.long) + return tuple(tensors), torch.tensor(rows, dtype=torch.long) + + +def surface_parameter_tensor_table( + runtime: Any, + *, + handles: tuple[Any, ...], + static_tensors: dict[str, object], + reference: torch.Tensor, +) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]: + tensors: list[torch.Tensor] = [] + rows: list[list[int]] = [] + slot_by_binding_index: dict[int, int] = {} + + for handle in handles: + for binding in handle.bindings: + if binding.binding_kind != "parameter": + continue + binding_index = int(binding.binding_index) + slot = slot_by_binding_index.get(binding_index) + if slot is None: + slot = len(tensors) + slot_by_binding_index[binding_index] = int(slot) + tensors.append( + _resolve_surface_parameter_tensor( + runtime, + binding, + static_tensors=static_tensors, + reference=reference, + ) + ) + rows.append( + [ + binding_index, + int(slot), + int(binding.primitive_row_index), + 1, + ] + ) + if not rows: + raise RuntimeError("Registered program surface parameter tensor table has no compiler parameter bindings") + return tuple(tensors), torch.tensor(rows, dtype=torch.long) + + +def _transition_population_params(bucket: Any) -> dict[str, object]: + population_materialized = bucket.static_tensors.get("population_materialized") + if not isinstance(population_materialized, dict): + raise RuntimeError(f"Registered transition bucket {_flat_bucket_name(bucket)!r} has no materialized parameters") + population_name = _flat_bucket_name(bucket) + population_params = population_materialized.get(population_name) + if not isinstance(population_params, dict): + raise RuntimeError(f"Registered transition bucket {population_name!r} has no population parameter table") + return population_params + + +def _normalize_transition_parameter_tensor( + tensor: torch.Tensor, + *, + logical_name: str, + recurrent_msg_dim: int, + receiver_count: int, +) -> torch.Tensor: + if tensor.dim() == 1 and int(receiver_count) > 0 and int(tensor.numel()) % int(receiver_count) == 0: + tensor = tensor.reshape(int(receiver_count), -1) + if tensor.dim() == 3 and int(tensor.shape[0]) == 1: + tensor = tensor.squeeze(0) + if tensor.dim() == 3 and int(tensor.shape[1]) == 1: + tensor = tensor.squeeze(1) + if logical_name in {"value_to_state_weight", "input_proj_weight"} and tensor.dim() == 2: + if int(tensor.shape[1]) == int(recurrent_msg_dim): + return tensor.transpose(0, 1).contiguous() + return tensor.contiguous() + + +def _transition_parameter_tensor_is_compatible( + tensor: torch.Tensor, + *, + logical_name: str, + recurrent_msg_dim: int, +) -> bool: + if logical_name not in {"value_to_state_weight", "input_proj_weight"}: + return True + if tensor.dim() == 2: + return int(tensor.shape[0]) == int(recurrent_msg_dim) + if tensor.dim() == 3: + return int(tensor.shape[1]) == int(recurrent_msg_dim) + return False + + +def _resolve_surface_parameter_tensor( + runtime: Any, + binding: Any, + *, + static_tensors: dict[str, object], + reference: torch.Tensor, +) -> torch.Tensor: + for source 
in binding.source_bindings: + source_text = str(source) + if source_text.startswith("message_parameter:"): + value = _resolve_message_parameter_tensor( + runtime, + binding, + static_tensors=static_tensors, + ) + if value is not None: + return value.contiguous() + if source_text.startswith("static_tensor:"): + key = source_text.removeprefix("static_tensor:") + value = static_tensors.get(key) + if torch.is_tensor(value): + return value.contiguous() + if key in static_tensors and surface_parameter_binding_allows_empty_static_tensor(binding, key): + return reference.new_empty((0,)) + if source_text.startswith("runtime_attr:"): + key = source_text.removeprefix("runtime_attr:") + value = getattr(runtime, key, None) + if torch.is_tensor(value): + return value.contiguous() + raise RuntimeError( + "Registered program surface parameter table could not resolve compiler parameter binding: " + f"direction={binding.direction}; surface={binding.surface}; bucket={int(binding.bucket_ordinal)}; " + f"executor={binding.executor_name}; primitive={binding.primitive}; logical={binding.logical_name!r}; " + f"sources={binding.source_bindings!r}" + ) + + +def surface_parameter_binding_allows_empty_static_tensor(binding: Any, source_key: str) -> bool: + """Return true for compiler bindings whose selected layout can intentionally omit one tensor side.""" + + if str(getattr(binding, "surface", "")) != "message" or str(getattr(binding, "binding_kind", "")) != "parameter": + return False + logical_name = str(getattr(binding, "logical_name", "")) + optional_message_projection_sources = { + ("input_sender_kv_weight", "input_sender_input_to_kv_weight"), + ("input_group_kv_weight", "input_group_input_to_kv_weight"), + } + return (logical_name, str(source_key)) in optional_message_projection_sources + + +def _resolve_message_parameter_tensor( + runtime: Any, + binding: Any, + *, + static_tensors: dict[str, object], +) -> torch.Tensor | None: + logical_name = str(binding.logical_name) + value = static_tensors.get(logical_name) + if torch.is_tensor(value): + return value + message_parameters = getattr(runtime, "message_rule_parameters", None) + if hasattr(message_parameters, "__contains__") and logical_name in message_parameters: + value = message_parameters[logical_name] + if torch.is_tensor(value): + return value + return None + + +def _resolve_transition_parameter_tensor( + binding: Any, + *, + bucket: Any, + static_tensors: dict[str, object], + population_params: dict[str, object], + recurrent_msg_dim: int, + receiver_count: int, + prefer_projected_message_input: bool, +) -> torch.Tensor: + first_incompatible: torch.Tensor | None = None + + def normalized_if_compatible(value: torch.Tensor, *, logical_name: str) -> torch.Tensor | None: + nonlocal first_incompatible + tensor = _normalize_transition_parameter_tensor( + value, + logical_name=logical_name, + recurrent_msg_dim=recurrent_msg_dim, + receiver_count=receiver_count, + ) + if _transition_parameter_tensor_is_compatible( + tensor, + logical_name=logical_name, + recurrent_msg_dim=recurrent_msg_dim, + ): + return tensor + if first_incompatible is None: + first_incompatible = tensor + return None + + if prefer_projected_message_input and str(binding.logical_name) in {"value_to_state_weight", "input_proj_weight"}: + for key in ("fused_recurrent_value_to_cell_weight", "message_to_cell_weight"): + value = static_tensors.get(key) + if torch.is_tensor(value): + tensor = normalized_if_compatible(value, logical_name=str(binding.logical_name)) + if tensor is not None: + return 
tensor + for source in binding.source_bindings: + source_text = str(source) + if source_text.startswith("static_tensor:"): + key = source_text.removeprefix("static_tensor:") + value = bucket.static_tensors.get(key, static_tensors.get(key)) + if torch.is_tensor(value): + tensor = normalized_if_compatible(value, logical_name=str(binding.logical_name)) + if tensor is not None: + return tensor + if source_text.startswith("expanded_transposed_static_tensor:"): + key = source_text.removeprefix("expanded_transposed_static_tensor:") + value = bucket.static_tensors.get(key, static_tensors.get(key)) + if torch.is_tensor(value): + tensor = value.transpose(0, 1).contiguous() if value.dim() == 2 else value.contiguous() + normalized = normalized_if_compatible(tensor, logical_name=str(binding.logical_name)) + if normalized is not None: + return normalized + if source_text.startswith("cell_param:"): + key = source_text.removeprefix("cell_param:") + value = population_params.get(key) + if torch.is_tensor(value): + tensor = normalized_if_compatible(value, logical_name=str(binding.logical_name)) + if tensor is not None: + return tensor + if first_incompatible is not None: + raise RuntimeError( + "Registered transition parameter table found only incompatible compiler parameter sources: " + f"bucket={int(binding.bucket_ordinal)}; primitive={binding.primitive}; " + f"logical={binding.logical_name!r}; recurrent_msg_dim={int(recurrent_msg_dim)}; " + f"first_shape={tuple(int(dim) for dim in first_incompatible.shape)}; " + f"sources={binding.source_bindings!r}" + ) + raise RuntimeError( + "Registered transition parameter table could not resolve compiler parameter binding: " + f"bucket={int(binding.bucket_ordinal)}; primitive={binding.primitive}; " + f"logical={binding.logical_name!r}; sources={binding.source_bindings!r}" + ) + + +__all__ = [ + "surface_parameter_tensor_table", + "transition_buckets_for_executors", + "transition_parameter_tensor_table", + "transition_program_slot_table", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_tensors.py b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_tensors.py new file mode 100644 index 00000000..2f745143 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_tensors.py @@ -0,0 +1,295 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import torch +from tensordict import TensorDictBase + +from cortical.fabric.backend.cuda.sequence_surface.compiler.forward_program import ( + temporal_forward_program_access_rows_tensor, + temporal_forward_transition_state_carry_rows_tensor, + temporal_reverse_program_access_rows_tensor, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.common import ( + _flat_bucket_name, + temporal_message_output_dim, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.program_parameters import ( + surface_parameter_tensor_table, + transition_buckets_for_executors, + transition_parameter_tensor_table, + transition_program_slot_table, +) + + +@dataclass(frozen=True) +class TemporalExecutableProgramTensorTable: + program_tensors: tuple[torch.Tensor, ...] + program_tensor_binding_rows: torch.Tensor + forward_program_access_rows: torch.Tensor + reverse_program_access_rows: torch.Tensor + forward_transition_state_carry_rows: torch.Tensor + transition_parameter_tensors: tuple[torch.Tensor, ...] + transition_parameter_rows: torch.Tensor + review_summary: tuple[str, ...] 
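+    # Row layout notes, written as a sketch of what the builders later in this
+    # module actually emit (see _append_program_rows, transition_program_slot_table,
+    # and transition_parameter_tensor_table); not an additional contract:
+    #   program_tensor_binding_rows: CPU int64 [N, 4] rows of
+    #     [binding_index, program_tensor_slot, primitive_row_index, binding_kind]
+    #     with binding_kind 0=input, 1=parameter, 2=output.
+    #   transition_parameter_rows: CPU int64 [N, 3] rows of
+    #     [binding_index, parameter_tensor_slot, bucket_ordinal].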
+ + +def build_forward_executable_program_tensor_table( + runtime: Any, + *, + executor_program: Any, + boundary_seq: torch.Tensor, + state: Any, + static_tensors: dict[str, object], + step_population_state_cache: dict[str, object] | None, +) -> TemporalExecutableProgramTensorTable: + tensors: list[torch.Tensor] = [] + rows: list[list[int]] = [] + slot_by_binding_index: dict[int, int] = {} + + surface_parameter_tensors, surface_parameter_rows = surface_parameter_tensor_table( + runtime, + handles=( + *executor_program.forward_surface_handles(surface="message"), + *executor_program.forward_surface_handles(surface="readout"), + ), + static_tensors=static_tensors, + reference=boundary_seq, + ) + _append_program_rows( + tensors, + rows, + slot_by_binding_index, + source_tensors=surface_parameter_tensors, + source_rows=surface_parameter_rows, + ) + + transition_executors = tuple( + handle + for bucket_index in executor_program.transition_bucket_ordinals() + for handle in executor_program.forward_handles_for(surface="transition", bucket_ordinal=int(bucket_index)) + ) + forward_program_access_rows = temporal_forward_program_access_rows_tensor( + message_handles=executor_program.forward_surface_handles(surface="message"), + readout_handles=executor_program.forward_surface_handles(surface="readout"), + transition_handles=transition_executors, + ) + forward_transition_state_carry_rows = temporal_forward_transition_state_carry_rows_tensor( + transition_handles=transition_executors, + ) + transition_parameter_tensors, transition_parameter_rows = transition_parameter_tensor_table( + runtime, + transition_executors=transition_executors, + static_tensors=static_tensors, + reference=boundary_seq, + recurrent_msg_dim=temporal_message_output_dim(runtime), + ) + seen_transition_buckets: set[int] = set() + for transition_executor, bucket in transition_buckets_for_executors(runtime, transition_executors, static_tensors): + if int(transition_executor.bucket_ordinal) in seen_transition_buckets: + continue + seen_transition_buckets.add(int(transition_executor.bucket_ordinal)) + transition_tensors, transition_rows = _forward_transition_program_tensors_for_bucket( + runtime, + transition_executors=tuple( + handle + for handle in transition_executors + if int(handle.bucket_ordinal) == int(transition_executor.bucket_ordinal) + ), + bucket=bucket, + boundary_seq=boundary_seq, + state=state, + step_population_state_cache=step_population_state_cache, + transition_parameter_tensors=transition_parameter_tensors, + transition_parameter_rows=transition_parameter_rows, + ) + _append_program_rows( + tensors, + rows, + slot_by_binding_index, + source_tensors=transition_tensors, + source_rows=transition_rows, + ) + if not rows: + raise RuntimeError("Registered fused forward program tensor table has no compiler bindings") + return TemporalExecutableProgramTensorTable( + program_tensors=tuple(tensors), + program_tensor_binding_rows=torch.tensor(rows, dtype=torch.long), + forward_program_access_rows=forward_program_access_rows, + reverse_program_access_rows=torch.empty((0, 6), dtype=torch.long), + forward_transition_state_carry_rows=forward_transition_state_carry_rows, + transition_parameter_tensors=transition_parameter_tensors, + transition_parameter_rows=transition_parameter_rows, + review_summary=( + "owner=temporal_executable_program_tensor_table", + "direction=forward", + f"program_tensors={len(tensors)}", + f"program_rows={len(rows)}", + f"forward_program_access_rows={int(forward_program_access_rows.shape[0])}", + 
f"forward_transition_state_carry_rows={int(forward_transition_state_carry_rows.shape[0])}", + f"transition_parameter_tensors={len(transition_parameter_tensors)}", + f"transition_parameter_rows={int(transition_parameter_rows.shape[0])}", + ), + ) + + +def build_reverse_executable_program_tensor_table( + runtime: Any, + *, + executor_program: Any, + static_tensors: dict[str, object], + reference: torch.Tensor, +) -> TemporalExecutableProgramTensorTable: + program_tensors, program_tensor_binding_rows = surface_parameter_tensor_table( + runtime, + handles=( + *executor_program.reverse_surface_handles(surface="message"), + *executor_program.reverse_surface_handles(surface="readout"), + ), + static_tensors=static_tensors, + reference=reference, + ) + reverse_program_access_rows = temporal_reverse_program_access_rows_tensor( + message_handles=executor_program.reverse_surface_handles(surface="message"), + readout_handles=executor_program.reverse_surface_handles(surface="readout"), + ) + transition_forward_executors = tuple( + handle + for bucket_index in executor_program.transition_bucket_ordinals() + for handle in executor_program.forward_handles_for(surface="transition", bucket_ordinal=int(bucket_index)) + ) + transition_reverse_executors = tuple( + handle + for bucket_index in executor_program.transition_bucket_ordinals() + for handle in executor_program.reverse_handles_for(surface="transition", bucket_ordinal=int(bucket_index)) + ) + transition_parameter_tensors, transition_parameter_rows = transition_parameter_tensor_table( + runtime, + transition_executors=(*transition_forward_executors, *transition_reverse_executors), + static_tensors=static_tensors, + reference=reference, + recurrent_msg_dim=temporal_message_output_dim(runtime), + ) + return TemporalExecutableProgramTensorTable( + program_tensors=program_tensors, + program_tensor_binding_rows=program_tensor_binding_rows, + forward_program_access_rows=torch.empty((0, 6), dtype=torch.long), + reverse_program_access_rows=reverse_program_access_rows, + forward_transition_state_carry_rows=torch.empty((0, 3), dtype=torch.long), + transition_parameter_tensors=transition_parameter_tensors, + transition_parameter_rows=transition_parameter_rows, + review_summary=( + "owner=temporal_executable_program_tensor_table", + "direction=reverse", + f"program_tensors={len(program_tensors)}", + f"program_rows={int(program_tensor_binding_rows.shape[0])}", + f"reverse_program_access_rows={int(reverse_program_access_rows.shape[0])}", + f"transition_parameter_tensors={len(transition_parameter_tensors)}", + f"transition_parameter_rows={int(transition_parameter_rows.shape[0])}", + ), + ) + + +def _forward_transition_program_tensors_for_bucket( + runtime: Any, + *, + transition_executors: tuple[Any, ...], + bucket: Any, + boundary_seq: torch.Tensor, + state: Any, + step_population_state_cache: dict[str, object] | None, + transition_parameter_tensors: tuple[torch.Tensor, ...], + transition_parameter_rows: torch.Tensor, +) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]: + if not transition_executors: + raise RuntimeError("Registered fused forward program has no transition executor rows for bucket") + transition_executor = transition_executors[0] + population_name = _flat_bucket_name(bucket) + packed_state_before: object | None = None + if step_population_state_cache is not None: + packed_state_before = step_population_state_cache.get(population_name) + if packed_state_before is None: + population_state = state.get(population_name) + if isinstance(population_state, 
TensorDictBase): + packed_state_before = runtime._population_state_to_backend_state(population_name, population_state) + transition_tensors_raw, transition_rows, _logical_to_slot = transition_program_slot_table( + executors=transition_executors, + reference=boundary_seq, + ) + transition_tensors = list(transition_tensors_raw) + local_slot_by_binding_index = {int(row[0]): int(row[1]) for row in transition_rows.cpu().tolist()} + for row in transition_parameter_rows.cpu().tolist(): + if int(row[2]) != int(transition_executor.bucket_ordinal): + continue + local_slot = local_slot_by_binding_index.get(int(row[0])) + if local_slot is not None: + transition_tensors[local_slot] = transition_parameter_tensors[int(row[1])].contiguous() + produced_logical_names = { + str(binding.logical_name) + for executor in transition_executors + for binding in executor.bindings + if binding.binding_kind == "output" + } + for executor in transition_executors: + for binding in executor.bindings: + if binding.binding_kind != "input": + continue + if str(binding.logical_name) == "aggregated_message": + continue + if str(binding.logical_name) in produced_logical_names: + continue + tensor = None + if hasattr(packed_state_before, "get"): + value = packed_state_before.get(str(binding.logical_name)) # type: ignore[attr-defined] + tensor = value if torch.is_tensor(value) else None + if tensor is None and packed_state_before is None: + tensor = boundary_seq.new_empty( + ( + int(boundary_seq.shape[0]), + 0, + int(boundary_seq.shape[-1]), + ) + ) + if tensor is None: + raise RuntimeError( + "Registered fused forward program could not resolve compiler state input binding: " + f"bucket={int(binding.bucket_ordinal)}; executor={binding.executor_name}; " + f"primitive={binding.primitive}; logical={binding.logical_name!r}; " + f"sources={binding.source_bindings!r}" + ) + local_slot = local_slot_by_binding_index.get(int(binding.binding_index)) + if local_slot is not None: + transition_tensors[local_slot] = tensor.contiguous() + return tuple(tensor.contiguous() for tensor in transition_tensors), transition_rows + + +def _append_program_rows( + tensors: list[torch.Tensor], + rows: list[list[int]], + slot_by_binding_index: dict[int, int], + *, + source_tensors: tuple[torch.Tensor, ...], + source_rows: torch.Tensor, +) -> None: + slot_offset = len(tensors) + tensors.extend(tensor.contiguous() for tensor in source_tensors) + for row in source_rows.cpu().tolist(): + binding_index = int(row[0]) + local_slot = int(row[1]) + slot_by_binding_index[binding_index] = slot_offset + local_slot + rows.append( + [ + binding_index, + slot_offset + local_slot, + int(row[2]), + int(row[3]), + ] + ) + + +__all__ = [ + "TemporalExecutableProgramTensorTable", + "build_forward_executable_program_tensor_table", + "build_reverse_executable_program_tensor_table", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py new file mode 100644 index 00000000..b845685f --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py @@ -0,0 +1,3742 @@ +from __future__ import annotations + +from dataclasses import dataclass, replace +from typing import Any, Callable, Literal, cast + +import torch +from tensordict import TensorDict, TensorDictBase + +from cortical.fabric.backend.cuda.sequence_surface.flat_bucket.flat_buckets import ( + _partial_backend_grad_state_to_population_state, +) +from 
cortical.fabric.backend.cuda.sequence_surface.compiler.backward_plan import ( + TemporalBackwardExecutablePlan, + build_temporal_backward_executable_plan, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.executor_bindings import ( + TemporalExecutorBindingPlan, + TemporalExecutorTensorBinding, + TemporalTransitionParamGradBinding, + build_temporal_forward_executor_binding_plan, + build_temporal_reverse_executor_binding_plan, + build_temporal_transition_param_grad_binding_plan, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.executor_patterns import ( + temporal_executor_strategy_registry, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.forward_plan import ( + TemporalForwardExecutablePlan, + build_temporal_forward_executable_plan, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.forward_program import ( + temporal_program_access_opcode, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.memory_plan import ( + TemporalMemoryLivenessPlan, + TemporalMemoryRuntimeArtifactPlan, + TemporalMemoryRuntimeSchedulePlan, + TemporalTransitionForwardRuntimeBufferRequest, + TemporalTransitionReverseRuntimeBufferRequest, + allocate_temporal_runtime_buffers, + build_temporal_memory_liveness_plan, + build_temporal_memory_runtime_schedule_plan, + build_temporal_physical_strategy_plan, + build_temporal_runtime_buffer_plan, + temporal_memory_liveness_rows_tensor, + temporal_memory_runtime_schedule_rows_tensor, + temporal_physical_strategy_rows_tensor, + temporal_runtime_buffer_rows_tensor, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.program_execution import ( + TemporalFusedCudaProgramPlan, + TemporalRegisteredProgramExecutorPlan, + TemporalReverseProgramStagePlan, + build_temporal_fused_cuda_program_plan, + build_temporal_message_transition_producer_consumer_plan, + build_temporal_readout_message_producer_consumer_plan, + build_temporal_registered_program_executor_plan, + build_temporal_reverse_program_stage_plan, + temporal_forward_artifact_merge_rows_tensor, + temporal_forward_artifact_merge_summaries, + temporal_forward_artifact_route_rows_tensor, + temporal_forward_artifact_route_summaries, + temporal_forward_executor_handler_rows_tensor, + temporal_forward_output_route_kind_opcode, + temporal_forward_output_route_rows_tensor, + temporal_forward_output_route_summaries, + temporal_native_callable_binding_schema_rows_tensor, + temporal_native_callable_catalog_rows_tensor, + temporal_native_callable_output_rows_tensor, + temporal_native_executor_strategy_rows_tensor, + temporal_message_transition_producer_consumer_rows_tensor, + temporal_message_transition_producer_consumer_summaries, + temporal_readout_message_producer_consumer_rows_tensor, + temporal_readout_message_producer_consumer_summaries, + temporal_reverse_artifact_consumer_route_rows_tensor, + temporal_reverse_artifact_consumer_route_summaries, + temporal_reverse_executor_handler_rows_tensor, + temporal_reverse_output_route_kind_opcode, + temporal_reverse_output_route_rows_tensor, + temporal_reverse_output_route_target_id, + temporal_reverse_parameter_reducer_route_rows_tensor, + temporal_reverse_parameter_reducer_route_summaries, + temporal_reverse_span_output_group_opcode, + temporal_reverse_span_output_rows_tensor, + temporal_transition_primitive_native_callable_rows_tensor, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.program_runtime import ( + build_temporal_forward_program_runtime_plan, + 
build_temporal_forward_program_runtime_support_plan, + build_temporal_reverse_program_runtime_plan, + build_temporal_reverse_program_runtime_support_plan, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.primitive_registry import ( + temporal_surface_opcode, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.native_callables import ( + temporal_native_callable_transition_forward_output_definition, + temporal_transition_reverse_seed_role_rows_tensor, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.reverse_artifacts import ( + temporal_reverse_artifact_access_id, + temporal_reverse_artifact_access_rows_tensor, + temporal_reverse_artifact_role_id, + temporal_reverse_artifact_role_is_tensor, + temporal_reverse_artifact_role_names, + temporal_reverse_artifact_role_rows_tensor, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.reset_plan import ( + temporal_forward_reset_tensor_table, + temporal_reverse_reset_tensor_table, + temporal_reverse_transition_state_reset_rows_tensor, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.scan_schedule import ( + scalar_temporal_scan_step, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.tables import ( + TemporalForwardExecutorRow, + TemporalPrimitiveTablePlan, + TemporalReverseExecutorRow, + build_temporal_primitive_table_plan, + temporal_forward_executor_rows, + temporal_primitive_rows_tensor, + temporal_reverse_executor_rows, +) +from cortical.fabric.backend.cuda.sequence_surface.runtime.support import ( + _accumulate_tensor_grad, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.surface_executor_runtime import ( + _transition_param_grad_accumulator_from_binding_rows, + _transition_reverse_seed_role_id, + _transition_reverse_state_grad_names, + _transition_reverse_seed_tensor_table, + _transition_tensor_by_logical_optional, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.program_parameters import ( + transition_buckets_for_executors, + transition_program_slot_table, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.program_tensors import ( + build_forward_executable_program_tensor_table, + build_reverse_executable_program_tensor_table, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.executor_registry import ( + RegisteredTemporalExecutorKernelRegistry, + registered_temporal_executor_kernel_registry, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.param_binding import ( + TemporalMessageStrategyParamReducerRequest, + TemporalParameterReducerRequest, + TemporalReadoutOutputParamReducerRequest, + TemporalRecurrentQueryParamReducerRequest, + TemporalSenderKVProjectionParamReducerRequest, + TemporalTransitionParamReducerRequest, + build_temporal_parameter_reducer_program, + run_temporal_parameter_reducer_program, +) +from cortical.fabric.backend.cuda.sequence_surface.flat_bucket.flat_bucket_registered_program_cuda import ( + registered_temporal_fused_backward_program_cuda, + registered_temporal_fused_backward_program_stage_memory_rows, + registered_temporal_fused_forward_program_cuda, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.common import ( + _flat_bucket_name, + _grad_output_cells_for_contract, + _record_temporal_backward_glue_cuda, + _record_temporal_reverse_scan_owner, + temporal_message_output_dim, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.types import ( + SharedTemporalForwardScanResult, + TemporalArtifactCheckpoint, + TemporalArtifactStore, + 
TemporalBackwardWindowResult, + TemporalBucketStepArtifacts, + TemporalOutputContract, + TemporalReverseArtifactTensorStore, + TemporalSenderKVProjectionRawParamGrad, + TemporalTransitionTapeMode, + _TransitionParamGradAccumulator, +) +from cortical.fabric.backend.cuda.transition_execution.registry import ( + transition_primitive_executor_record_for_lowered_primitive, +) + + +@dataclass(frozen=True) +class RegisteredTemporalExecutorHandle: + direction: Literal["forward", "reverse"] + row_index: int + row: TemporalForwardExecutorRow | TemporalReverseExecutorRow + bindings: tuple[TemporalExecutorTensorBinding, ...] + + @property + def executor_name(self) -> str: + return str(self.row.executor_name) + + @property + def surface(self) -> str: + return str(self.row.surface) + + @property + def bucket_ordinal(self) -> int: + return int(self.row.bucket_ordinal) + + @property + def parameter_names(self) -> tuple[str, ...]: + return tuple(binding.logical_name for binding in self.bindings if binding.binding_kind == "parameter") + + def require_parameter_bindings(self) -> None: + missing = tuple( + parameter for parameter in self.row.parameter_bindings if parameter not in set(self.parameter_names) + ) + if missing: + raise RuntimeError( + "Registered temporal executor row is missing compiler-owned parameter bindings: " + f"direction={self.direction}; surface={self.surface}; bucket={self.bucket_ordinal}; " + f"executor={self.executor_name}; missing={missing!r}" + ) + + def parameter_binding(self, logical_name: str) -> TemporalExecutorTensorBinding: + matches = tuple( + binding + for binding in self.bindings + if binding.binding_kind == "parameter" and binding.logical_name == str(logical_name) + ) + if len(matches) != 1: + raise RuntimeError( + "Registered temporal executor row has no unique compiler parameter binding: " + f"direction={self.direction}; surface={self.surface}; bucket={self.bucket_ordinal}; " + f"executor={self.executor_name}; logical={logical_name!r}; count={len(matches)}" + ) + return matches[0] + + def source_keys(self, logical_name: str, *, prefix: str) -> tuple[str, ...]: + matches = tuple( + binding + for binding in self.bindings + if binding.binding_kind == "parameter" and binding.logical_name == str(logical_name) + ) + if not matches: + raise RuntimeError( + "Registered temporal executor row has no compiler parameter binding: " + f"direction={self.direction}; surface={self.surface}; bucket={self.bucket_ordinal}; " + f"executor={self.executor_name}; logical={logical_name!r}" + ) + token_prefix = f"{prefix}:" + return tuple( + dict.fromkeys( + str(source).removeprefix(token_prefix) + for binding in matches + for source in binding.source_bindings + if str(source).startswith(token_prefix) + ) + ) + + def require_static_tensor( + self, + static_tensors: dict[str, object], + logical_name: str, + ) -> torch.Tensor: + compiled_keys = self.source_keys(logical_name, prefix="static_tensor") + for key in compiled_keys: + tensor = static_tensors.get(key) + if torch.is_tensor(tensor): + return tensor + raise RuntimeError( + "Registered temporal executor could not resolve compiler-bound static tensor: " + f"direction={self.direction}; surface={self.surface}; bucket={self.bucket_ordinal}; " + f"executor={self.executor_name}; logical={logical_name!r}; compiled={compiled_keys!r}" + ) + + def optional_static_tensor( + self, + static_tensors: dict[str, object], + logical_name: str, + ) -> torch.Tensor | None: + try: + compiled_keys = self.source_keys(logical_name, prefix="static_tensor") + except 
RuntimeError: + return None + for key in compiled_keys: + tensor = static_tensors.get(key) + if torch.is_tensor(tensor): + return tensor + return None + + def require_runtime_tensor_attr( + self, + runtime: Any, + logical_name: str, + ) -> torch.Tensor: + compiled_names = self.source_keys(logical_name, prefix="runtime_attr") + for name in compiled_names: + tensor = getattr(runtime, name, None) + if torch.is_tensor(tensor): + return tensor + raise RuntimeError( + "Registered temporal executor could not resolve compiler-bound runtime tensor: " + f"direction={self.direction}; surface={self.surface}; bucket={self.bucket_ordinal}; " + f"executor={self.executor_name}; logical={logical_name!r}; compiled={compiled_names!r}" + ) + + +def _message_strategy_parameter_reducer_request( + message_executor: RegisteredTemporalExecutorHandle, + *, + grad_recurrent_query_backend: torch.Tensor | None, + boundary_outputs_by_logical_name: dict[str, torch.Tensor | None], +) -> tuple[TemporalMessageStrategyParamReducerRequest | None, bool]: + strategy = temporal_executor_strategy_registry().reverse_pattern_for_executor( + surface="message", + executor_name=message_executor.executor_name, + ) + if not strategy.parameter_reducer_kind: + return None, False + grad_outputs: list[tuple[str, torch.Tensor | None]] = [] + consumes_recurrent_query_grad = False + for output in strategy.message_param_grad_outputs: + if output.source == "recurrent_query_grad": + tensor = grad_recurrent_query_backend + consumes_recurrent_query_grad = True + elif output.source == "boundary_extra_output": + if output.logical_name not in boundary_outputs_by_logical_name: + raise RuntimeError( + "Registered message reverse strategy did not return a declared parameter-gradient output: " + f"strategy={strategy.stable_strategy_id}; output={output.logical_name!r}; " + "the fused reverse boundary outputs were not declared by compiler role rows" + ) + tensor = boundary_outputs_by_logical_name[output.logical_name] + else: + raise RuntimeError( + "Registered message reverse strategy declares an unsupported parameter-gradient source: " + f"strategy={strategy.stable_strategy_id}; source={output.source!r}" + ) + grad_outputs.append((output.logical_name, tensor if torch.is_tensor(tensor) else None)) + return ( + TemporalMessageStrategyParamReducerRequest( + kind="message_strategy", + reducer_kind=strategy.parameter_reducer_kind, + message_executor=message_executor, + grad_outputs=tuple(grad_outputs), + ), + consumes_recurrent_query_grad, + ) + + +def _reverse_span_output_tensor_by_role_id( + outputs: tuple[torch.Tensor, ...], + output_rows: torch.Tensor, + *, + group_opcode: int, + role_id: int, + route_label: str, +) -> torch.Tensor: + if ( + output_rows.device.type != "cpu" + or output_rows.dtype != torch.long + or output_rows.dim() != 2 + or int(output_rows.shape[1]) != 6 + ): + raise RuntimeError("Registered fused reverse span output rows must be CPU int64 [N,6]") + matches = [ + int(row[3]) for row in output_rows.tolist() if int(row[1]) == int(group_opcode) and int(row[2]) == int(role_id) + ] + if len(matches) != 1: + raise RuntimeError( + "Registered fused reverse span outputs are missing a compiler-declared role: " + f"{route_label}; matches={len(matches)}" + ) + slot = matches[0] + if slot < 0 or slot >= len(outputs): + raise RuntimeError( + "Registered fused reverse span output role points outside returned group: " + f"{route_label}; slot={slot}; group_size={len(outputs)}" + ) + return outputs[slot] + + +def _reverse_routed_span_output_tensor( + *, 
+ front_outputs: tuple[torch.Tensor, ...], + boundary_outputs: tuple[torch.Tensor, ...], + output_rows: torch.Tensor, + route_rows: torch.Tensor, + route_kind: str, + target_role: str, +) -> torch.Tensor: + if ( + route_rows.device.type != "cpu" + or route_rows.dtype != torch.long + or route_rows.dim() != 2 + or int(route_rows.shape[1]) != 8 + ): + raise RuntimeError("Registered fused reverse output route rows must be CPU int64 [N,8]") + route_kind_opcode = temporal_reverse_output_route_kind_opcode(route_kind) + target_role_id = temporal_reverse_output_route_target_id(target_role) + matches = [ + (int(row[3]), int(row[4])) + for row in route_rows.tolist() + if int(row[1]) == int(route_kind_opcode) and int(row[2]) == int(target_role_id) + ] + if len(matches) != 1: + raise RuntimeError( + "Registered fused reverse reducer output route is missing or ambiguous: " + f"route_kind={route_kind!r}; target_role={target_role!r}; matches={len(matches)}" + ) + group_opcode, source_role_id = matches[0] + front_group_opcode = temporal_reverse_span_output_group_opcode("front") + boundary_group_opcode = temporal_reverse_span_output_group_opcode("boundary") + if group_opcode == front_group_opcode: + outputs = front_outputs + elif group_opcode == boundary_group_opcode: + outputs = boundary_outputs + else: + raise RuntimeError( + "Registered fused reverse reducer output route points at an unknown span group: " + f"route_kind={route_kind!r}; target_role={target_role!r}; group_opcode={int(group_opcode)}" + ) + return _reverse_span_output_tensor_by_role_id( + outputs, + output_rows, + group_opcode=int(group_opcode), + role_id=int(source_role_id), + route_label=f"route_kind={route_kind!r}; target_role={target_role!r}", + ) + + +def _reverse_parameter_reducer_routed_span_output_tensor( + *, + front_outputs: tuple[torch.Tensor, ...], + boundary_outputs: tuple[torch.Tensor, ...], + output_rows: torch.Tensor, + reducer_route_rows: torch.Tensor, + executor: RegisteredTemporalExecutorHandle, + route_kind: str, + target_role: str, +) -> torch.Tensor: + if ( + reducer_route_rows.device.type != "cpu" + or reducer_route_rows.dtype != torch.long + or reducer_route_rows.dim() != 2 + or int(reducer_route_rows.shape[1]) != 12 + ): + raise RuntimeError("Registered fused reverse parameter reducer route rows must be CPU int64 [N,12]") + route_kind_opcode = temporal_reverse_output_route_kind_opcode(route_kind) + target_role_id = temporal_reverse_output_route_target_id(target_role) + surface_opcode = temporal_surface_opcode(executor.surface) + matches = [ + (int(row[3]), int(row[4])) + for row in reducer_route_rows.tolist() + if int(row[1]) == int(route_kind_opcode) + and int(row[2]) == int(target_role_id) + and int(row[5]) == int(surface_opcode) + and int(row[6]) == int(executor.row_index) + and int(row[7]) == int(executor.row.executor_id) + and int(row[8]) == int(executor.bucket_ordinal) + ] + if len(matches) != 1: + raise RuntimeError( + "Registered fused reverse parameter reducer route is missing or ambiguous: " + f"route_kind={route_kind!r}; target_role={target_role!r}; surface={executor.surface!r}; " + f"executor_row={int(executor.row_index)}; bucket={int(executor.bucket_ordinal)}; matches={len(matches)}" + ) + group_opcode, source_role_id = matches[0] + front_group_opcode = temporal_reverse_span_output_group_opcode("front") + boundary_group_opcode = temporal_reverse_span_output_group_opcode("boundary") + if group_opcode == front_group_opcode: + outputs = front_outputs + elif group_opcode == boundary_group_opcode: + outputs = 
boundary_outputs + else: + raise RuntimeError( + "Registered fused reverse parameter reducer route points at an unknown span group: " + f"route_kind={route_kind!r}; target_role={target_role!r}; group_opcode={int(group_opcode)}" + ) + return _reverse_span_output_tensor_by_role_id( + outputs, + output_rows, + group_opcode=int(group_opcode), + role_id=int(source_role_id), + route_label=(f"route_kind={route_kind!r}; target_role={target_role!r}; executor_row={int(executor.row_index)}"), + ) + + +@dataclass(frozen=True) +class _TransitionBoundaryReverseStepProgramTables: + program_tensor_groups: tuple[tuple[torch.Tensor, ...], ...] + program_tensor_binding_row_groups: tuple[torch.Tensor, ...] + forward_executor_row_groups: tuple[torch.Tensor, ...] + reverse_executor_row_groups: tuple[torch.Tensor, ...] + forward_executor_binding_row_groups: tuple[torch.Tensor, ...] + reverse_executor_binding_row_groups: tuple[torch.Tensor, ...] + memory_liveness_row_groups: tuple[torch.Tensor, ...] + seed_tensor_groups: tuple[tuple[torch.Tensor, ...], ...] + seed_row_groups: tuple[torch.Tensor, ...] + dynamic_binding_row_groups: tuple[torch.Tensor, ...] + output_keep_slot_row_groups: tuple[torch.Tensor, ...] + recurrent_msg_output_rows: torch.Tensor + public_y_seed_rows: torch.Tensor + transition_state_reset_rows: torch.Tensor + next_seed_output_rows: torch.Tensor + group_metadata: tuple[ + tuple[ + tuple[RegisteredTemporalExecutorHandle, ...], + Any, + dict[str, int], + tuple[TemporalTransitionParamGradBinding, ...], + ], + ..., + ] + + +_TRANSITION_DYNAMIC_SOURCE_REVERSE_ARTIFACT = 1 +_TRANSITION_DYNAMIC_SOURCE_STATE_BEFORE_ARTIFACT = 2 +_TRANSITION_DYNAMIC_SOURCE_SEED_OR_ZEROS = 3 + + +def _transition_input_binding_by_logical( + executor: RegisteredTemporalExecutorHandle, +) -> dict[str, TemporalExecutorTensorBinding]: + return {str(binding.logical_name): binding for binding in executor.bindings if binding.binding_kind == "input"} + + +def _transition_input_binding_by_logical_for_group( + executors: tuple[RegisteredTemporalExecutorHandle, ...], +) -> dict[str, TemporalExecutorTensorBinding]: + bindings: dict[str, TemporalExecutorTensorBinding] = {} + for executor in executors: + for logical_name, binding in _transition_input_binding_by_logical(executor).items(): + bindings.setdefault(str(logical_name), binding) + return bindings + + +def _transition_output_binding_logicals_for_group( + executors: tuple[RegisteredTemporalExecutorHandle, ...], +) -> set[str]: + return { + str(binding.logical_name) + for executor in executors + for binding in executor.bindings + if binding.binding_kind == "output" + } + + +def _transition_reverse_seed_source(binding: TemporalExecutorTensorBinding) -> str | None: + for source in binding.source_bindings: + source = str(source) + if source.startswith("reverse_seed:"): + return source.removeprefix("reverse_seed:") + logical_name = str(binding.logical_name) + if logical_name == "grad_public_y" or logical_name.startswith("grad_next_"): + return logical_name + return None + + +def _transition_reverse_seed_role_names_from_binding_plan( + reverse_binding_plan: TemporalExecutorBindingPlan, +) -> tuple[str, ...]: + role_names: list[str] = [] + for binding in reverse_binding_plan.bindings: + if binding.surface != "transition" or binding.binding_kind != "input": + continue + seed_role = _transition_reverse_seed_source(binding) + if seed_role is not None: + role_names.append(seed_role) + return tuple(dict.fromkeys(role_names)) + + +def _transition_reverse_seed_template_binding( + *, + 
logical_name: str, + state_template_by_logical: dict[str, int], + first_state_template: int | None, +) -> int: + if logical_name == "grad_public_y": + for state_name in ("y", "hc1"): + binding_index = state_template_by_logical.get(state_name) + if binding_index is not None: + return int(binding_index) + elif logical_name.startswith("grad_next_"): + state_name = logical_name.removeprefix("grad_next_") + binding_index = state_template_by_logical.get(state_name) + if binding_index is not None: + return int(binding_index) + if first_state_template is not None: + return int(first_state_template) + raise RuntimeError( + f"Registered transition reverse seed has no compiler-owned state template binding: logical={logical_name!r}" + ) + + +def _transition_dynamic_binding_rows_tensor( + *, + forward_executor: RegisteredTemporalExecutorHandle | tuple[RegisteredTemporalExecutorHandle, ...], + reverse_executor: RegisteredTemporalExecutorHandle | tuple[RegisteredTemporalExecutorHandle, ...], +) -> torch.Tensor: + forward_executors = forward_executor if isinstance(forward_executor, tuple) else (forward_executor,) + reverse_executors = reverse_executor if isinstance(reverse_executor, tuple) else (reverse_executor,) + forward_inputs = _transition_input_binding_by_logical_for_group(forward_executors) + reverse_inputs = _transition_input_binding_by_logical_for_group(reverse_executors) + forward_outputs = _transition_output_binding_logicals_for_group(forward_executors) + reverse_input_logicals = set(reverse_inputs) + rows: list[list[int]] = [] + + aggregated_binding = reverse_inputs.get("aggregated_message") or forward_inputs.get("aggregated_message") + if aggregated_binding is None: + raise RuntimeError( + "Registered transition reverse executor has no aggregated_message binding: " + f"bucket={int(reverse_executors[0].bucket_ordinal)}; " + f"executors={tuple(executor.executor_name for executor in reverse_executors)!r}" + ) + rows.append( + [ + int(aggregated_binding.binding_index), + _TRANSITION_DYNAMIC_SOURCE_REVERSE_ARTIFACT, + temporal_reverse_artifact_access_id("recurrent_msg_backend_order"), + -1, + 1, + ] + ) + + state_template_by_logical: dict[str, int] = {} + first_state_template: int | None = None + for logical_name, forward_binding in forward_inputs.items(): + if logical_name == "aggregated_message": + continue + if logical_name in forward_outputs: + if first_state_template is None: + first_state_template = int(forward_binding.binding_index) + continue + if logical_name not in reverse_input_logicals and f"next_{logical_name}" not in forward_outputs: + continue + binding_index = int(forward_binding.binding_index) + rows.append( + [ + binding_index, + _TRANSITION_DYNAMIC_SOURCE_STATE_BEFORE_ARTIFACT, + binding_index, + -1, + 0, + ] + ) + state_template_by_logical[str(logical_name)] = binding_index + if first_state_template is None: + first_state_template = binding_index + + if first_state_template is None: + first_state_template = int(aggregated_binding.binding_index) + + for _logical_name, reverse_binding in reverse_inputs.items(): + seed_role = _transition_reverse_seed_source(reverse_binding) + if seed_role is None: + continue + rows.append( + [ + int(reverse_binding.binding_index), + _TRANSITION_DYNAMIC_SOURCE_SEED_OR_ZEROS, + _transition_reverse_seed_role_id(seed_role), + _transition_reverse_seed_template_binding( + logical_name=seed_role, + state_template_by_logical=state_template_by_logical, + first_state_template=first_state_template, + ), + 0, + ] + ) + + return torch.tensor(rows, 
dtype=torch.long) if rows else torch.empty((0, 5), dtype=torch.long) + + +def _transition_executor_rows_for_group( + executors: tuple[RegisteredTemporalExecutorHandle, ...], +) -> torch.Tensor: + rows = [ + [ + int(executor.row.executor_id), + int(executor.row.primitive_row_start), + int(executor.row.primitive_row_count), + int(executor.bucket_ordinal), + int(executor.row.receiver_start), + int(executor.row.receiver_count), + ] + for executor in executors + ] + if not rows: + return torch.empty((0, 6), dtype=torch.long) + return torch.tensor(rows, dtype=torch.long) + + +def _transition_executor_binding_rows_for_group( + executors: tuple[RegisteredTemporalExecutorHandle, ...], + *, + direction_opcode: int, + materialize_optional_outputs: bool = True, +) -> torch.Tensor: + rows: list[list[int]] = [] + for local_executor_row, executor in enumerate(executors): + for binding in executor.bindings: + if ( + int(direction_opcode) == 1 + and not materialize_optional_outputs + and binding.binding_kind == "output" + and _is_optional_transition_forward_output(binding.primitive, binding.logical_name) + ): + continue + rows.append( + [ + int(direction_opcode), + int(local_executor_row), + int(binding.executor_id), + int(binding.primitive_row_index), + int(binding.binding_index), + int(binding.bucket_ordinal), + 0 if binding.binding_kind == "input" else 1 if binding.binding_kind == "parameter" else 2, + int(binding.local_binding_index), + ] + ) + if not rows: + raise RuntimeError("Registered fused transition executor group requires compiler tensor bindings") + return torch.tensor(rows, dtype=torch.long) + + +def _transition_memory_liveness_rows_for_group( + executors: tuple[RegisteredTemporalExecutorHandle, ...], + memory_liveness_rows: torch.Tensor, +) -> torch.Tensor: + primitive_ranges = tuple( + ( + int(executor.row.primitive_row_start), + int(executor.row.primitive_row_start) + int(executor.row.primitive_row_count), + int(executor.bucket_ordinal), + ) + for executor in executors + ) + rows: list[list[int]] = [] + for row in memory_liveness_rows.cpu().tolist(): + primitive_row_index = int(row[1]) + bucket_ordinal = int(row[2]) + if primitive_row_index == -1: + include = any(bucket_ordinal == group_bucket for _start, _stop, group_bucket in primitive_ranges) + else: + include = any( + bucket_ordinal == group_bucket and start <= primitive_row_index < stop + for start, stop, group_bucket in primitive_ranges + ) + if not include: + continue + selected = [int(value) for value in row] + selected[0] = len(rows) + rows.append(selected) + if not rows: + raise RuntimeError("Registered fused transition executor group requires compiler memory-liveness rows") + return torch.tensor(rows, dtype=torch.long) + + +@dataclass(frozen=True) +class RegisteredTemporalExecutorProgram: + primitive_table_fingerprint: tuple[str, ...] + bucket_count: int + primitive_rows: torch.Tensor + forward_handler_rows: torch.Tensor + reverse_handler_rows: torch.Tensor + native_strategy_rows: torch.Tensor + native_callable_catalog_rows: torch.Tensor + native_callable_binding_schema_rows: torch.Tensor + native_callable_output_rows: torch.Tensor + reverse_span_output_rows: torch.Tensor + reverse_output_route_rows: torch.Tensor + forward_artifact_route_rows: torch.Tensor + forward_artifact_route_summaries: tuple[str, ...] + forward_artifact_merge_rows: torch.Tensor + forward_artifact_merge_summaries: tuple[str, ...] + forward_output_route_rows: torch.Tensor + forward_output_route_summaries: tuple[str, ...] 
+ readout_message_producer_consumer_rows: torch.Tensor + readout_message_producer_consumer_summaries: tuple[str, ...] + readout_message_producer_consumer_template_rows: torch.Tensor + readout_message_producer_consumer_template_summaries: tuple[str, ...] + message_transition_producer_consumer_rows: torch.Tensor + message_transition_producer_consumer_summaries: tuple[str, ...] + message_transition_producer_consumer_template_rows: torch.Tensor + message_transition_producer_consumer_template_summaries: tuple[str, ...] + reverse_artifact_consumer_route_rows: torch.Tensor + reverse_artifact_consumer_route_summaries: tuple[str, ...] + reverse_parameter_reducer_route_rows: torch.Tensor + reverse_parameter_reducer_route_summaries: tuple[str, ...] + transition_reverse_seed_role_rows: torch.Tensor + transition_primitive_callable_rows: torch.Tensor + forward_plan: TemporalForwardExecutablePlan + backward_plan: TemporalBackwardExecutablePlan + memory_plan: TemporalMemoryLivenessPlan + memory_liveness_rows: torch.Tensor + physical_strategy_template_rows: torch.Tensor + physical_strategy_template_summaries: tuple[str, ...] + reverse_program_stage_plan: TemporalReverseProgramStagePlan + reverse_program_stage_rows: torch.Tensor + reverse_program_stage_summaries: tuple[str, ...] + transition_param_grad_bindings: tuple[TemporalTransitionParamGradBinding, ...] + transition_param_grad_binding_rows: torch.Tensor + transition_param_grad_binding_summaries: tuple[str, ...] + fused_cuda_program_plan: TemporalFusedCudaProgramPlan + program_executor_plan: TemporalRegisteredProgramExecutorPlan + kernel_registry: RegisteredTemporalExecutorKernelRegistry + forward_handles: tuple[RegisteredTemporalExecutorHandle, ...] + reverse_handles: tuple[RegisteredTemporalExecutorHandle, ...] 
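+    # Illustrative construction/query sketch. It assumes `runtime` and
+    # `static_tensors` come from the registered temporal compiler setup that
+    # feeds build_registered_temporal_executor_program; it is not the only
+    # supported entry path:
+    #   program = build_registered_temporal_executor_program(runtime, static_tensors)
+    #   program.require_transition_forward_coverage()
+    #   message_handles = program.forward_surface_handles(surface="message")
+    #   for bucket_ordinal in program.transition_bucket_ordinals():
+    #       transition_handles = program.forward_handles_for(
+    #           surface="transition", bucket_ordinal=int(bucket_ordinal)
+    #       )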
+ + def forward_handle(self, *, surface: str, bucket_ordinal: int) -> RegisteredTemporalExecutorHandle: + return _select_registered_executor_handle( + self.forward_handles, + direction="forward", + surface=surface, + bucket_ordinal=bucket_ordinal, + ) + + def reverse_handle(self, *, surface: str, bucket_ordinal: int) -> RegisteredTemporalExecutorHandle: + return _select_registered_executor_handle( + self.reverse_handles, + direction="reverse", + surface=surface, + bucket_ordinal=bucket_ordinal, + ) + + def forward_handles_for(self, *, surface: str, bucket_ordinal: int) -> tuple[RegisteredTemporalExecutorHandle, ...]: + return _select_registered_executor_handles( + self.forward_handles, + direction="forward", + surface=surface, + bucket_ordinal=bucket_ordinal, + ) + + def reverse_handles_for(self, *, surface: str, bucket_ordinal: int) -> tuple[RegisteredTemporalExecutorHandle, ...]: + return _select_registered_executor_handles( + self.reverse_handles, + direction="reverse", + surface=surface, + bucket_ordinal=bucket_ordinal, + ) + + def forward_surface_handles(self, *, surface: str) -> tuple[RegisteredTemporalExecutorHandle, ...]: + return _select_registered_surface_executor_handles( + self.forward_handles, + direction="forward", + surface=surface, + ) + + def reverse_surface_handles(self, *, surface: str) -> tuple[RegisteredTemporalExecutorHandle, ...]: + return _select_registered_surface_executor_handles( + self.reverse_handles, + direction="reverse", + surface=surface, + ) + + def transition_bucket_ordinals(self) -> tuple[int, ...]: + return tuple(range(int(self.bucket_count))) + + def require_transition_forward_coverage(self) -> None: + transition_buckets = { + int(handle.bucket_ordinal) for handle in self.forward_handles if handle.surface == "transition" + } + expected = set(self.transition_bucket_ordinals()) + if transition_buckets != expected: + raise RuntimeError( + "Registered temporal executor transition coverage is inconsistent: " + f"expected={tuple(sorted(expected))}; actual={tuple(sorted(transition_buckets))}" + ) + + def require_transition_reverse_coverage(self) -> None: + transition_buckets = { + int(handle.bucket_ordinal) for handle in self.reverse_handles if handle.surface == "transition" + } + expected = set(self.transition_bucket_ordinals()) + if transition_buckets != expected: + raise RuntimeError( + "Registered temporal reverse executor transition coverage is inconsistent: " + f"expected={tuple(sorted(expected))}; actual={tuple(sorted(transition_buckets))}" + ) + + +def _producer_consumer_rows_select_streaming_strategy(rows: torch.Tensor) -> bool: + if rows.device.type != "cpu" or rows.dtype != torch.long or rows.dim() != 2 or int(rows.shape[1]) != 16: + return False + if int(rows.shape[0]) == 0: + return False + for row in rows.tolist(): + strategy_opcode = int(row[2]) + status_opcode = int(row[3]) + executable = int(row[4]) + if strategy_opcode == 2 and status_opcode == 1 and executable == 1: + return True + return False + + +def _streaming_step_producer_consumer_body_available( + executor_program: RegisteredTemporalExecutorProgram, +) -> bool: + return _producer_consumer_rows_select_streaming_strategy( + executor_program.readout_message_producer_consumer_rows + ) or _producer_consumer_rows_select_streaming_strategy(executor_program.message_transition_producer_consumer_rows) + + +def build_registered_temporal_executor_program( + runtime: Any, + static_tensors: dict[str, object], + *, + table_plan: TemporalPrimitiveTablePlan | None = None, + forward_plan: 
TemporalForwardExecutablePlan | None = None, + backward_plan: TemporalBackwardExecutablePlan | None = None, + streaming_readout_body_available: bool = False, + streaming_readout_body_profitable: bool = True, +) -> RegisteredTemporalExecutorProgram: + table_plan = build_temporal_primitive_table_plan(runtime, static_tensors) if table_plan is None else table_plan + forward_binding_plan = build_temporal_forward_executor_binding_plan(table_plan) + reverse_binding_plan = build_temporal_reverse_executor_binding_plan(table_plan) + transition_param_grad_binding_plan = build_temporal_transition_param_grad_binding_plan( + table_plan, + reverse_binding_plan=reverse_binding_plan, + ) + forward_plan = ( + build_temporal_forward_executable_plan(table_plan, forward_binding_plan=forward_binding_plan) + if forward_plan is None + else forward_plan + ) + backward_plan = ( + build_temporal_backward_executable_plan(table_plan, reverse_binding_plan=reverse_binding_plan) + if backward_plan is None + else backward_plan + ) + memory_plan = build_temporal_memory_liveness_plan(table_plan) + physical_strategy_template_plan = build_temporal_physical_strategy_plan( + build_temporal_memory_runtime_schedule_plan( + memory_plan, + physical_time_steps=1, + collect_artifacts=True, + scheduler_plan=None, + ), + inner_steps=1, + output_boundary="terminal", + reset_policy="absent", + ) + physical_strategy_template_rows = temporal_physical_strategy_rows_tensor(physical_strategy_template_plan) + primitive_rows = temporal_primitive_rows_tensor(table_plan) + memory_liveness_rows = temporal_memory_liveness_rows_tensor(memory_plan) + forward_handler_rows = temporal_forward_executor_handler_rows_tensor(table_plan) + reverse_handler_rows = temporal_reverse_executor_handler_rows_tensor(table_plan) + native_strategy_rows = temporal_native_executor_strategy_rows_tensor() + native_callable_catalog_rows = temporal_native_callable_catalog_rows_tensor() + native_callable_binding_schema_rows = temporal_native_callable_binding_schema_rows_tensor() + native_callable_output_rows = temporal_native_callable_output_rows_tensor() + reverse_span_output_rows = temporal_reverse_span_output_rows_tensor() + reverse_output_route_rows = temporal_reverse_output_route_rows_tensor() + forward_artifact_route_rows = temporal_forward_artifact_route_rows_tensor(table_plan) + forward_artifact_merge_rows = temporal_forward_artifact_merge_rows_tensor(table_plan) + forward_output_route_rows = temporal_forward_output_route_rows_tensor(table_plan) + readout_message_producer_consumer_plan = build_temporal_readout_message_producer_consumer_plan( + table_plan, + streaming_readout_body_available=bool(streaming_readout_body_available), + streaming_readout_body_profitable=bool(streaming_readout_body_profitable), + ) + readout_message_producer_consumer_rows = temporal_readout_message_producer_consumer_rows_tensor( + readout_message_producer_consumer_plan + ) + readout_message_producer_consumer_template_plan = build_temporal_readout_message_producer_consumer_plan( + table_plan, + streaming_readout_body_available=True, + streaming_readout_body_profitable=True, + ) + readout_message_producer_consumer_template_rows = temporal_readout_message_producer_consumer_rows_tensor( + readout_message_producer_consumer_template_plan + ) + message_transition_producer_consumer_plan = build_temporal_message_transition_producer_consumer_plan( + table_plan, + streaming_transition_body_available=bool(streaming_readout_body_available), + streaming_transition_body_profitable=True, + ) + 
message_transition_producer_consumer_rows = temporal_message_transition_producer_consumer_rows_tensor( + message_transition_producer_consumer_plan + ) + message_transition_producer_consumer_template_plan = build_temporal_message_transition_producer_consumer_plan( + table_plan, + streaming_transition_body_available=True, + streaming_transition_body_profitable=True, + ) + message_transition_producer_consumer_template_rows = temporal_message_transition_producer_consumer_rows_tensor( + message_transition_producer_consumer_template_plan + ) + reverse_artifact_consumer_route_rows = temporal_reverse_artifact_consumer_route_rows_tensor(table_plan) + reverse_parameter_reducer_route_rows = temporal_reverse_parameter_reducer_route_rows_tensor(table_plan) + transition_reverse_seed_role_rows = temporal_transition_reverse_seed_role_rows_tensor( + _transition_reverse_seed_role_names_from_binding_plan(reverse_binding_plan) + ) + transition_primitive_callable_rows = temporal_transition_primitive_native_callable_rows_tensor() + reverse_program_stage_plan = build_temporal_reverse_program_stage_plan(table_plan, backward_plan) + fused_cuda_program_plan = build_temporal_fused_cuda_program_plan( + primitive_rows=primitive_rows, + forward_plan=forward_plan, + backward_plan=backward_plan, + memory_plan=memory_plan, + memory_liveness_rows=memory_liveness_rows, + forward_handler_rows=forward_handler_rows, + reverse_handler_rows=reverse_handler_rows, + native_strategy_rows=native_strategy_rows, + native_callable_binding_schema_rows=native_callable_binding_schema_rows, + native_callable_output_rows=native_callable_output_rows, + transition_reverse_seed_role_rows=transition_reverse_seed_role_rows, + transition_primitive_callable_rows=transition_primitive_callable_rows, + reverse_output_route_rows=reverse_output_route_rows, + forward_artifact_route_rows=forward_artifact_route_rows, + forward_artifact_merge_rows=forward_artifact_merge_rows, + forward_output_route_rows=forward_output_route_rows, + readout_message_producer_consumer_rows=readout_message_producer_consumer_rows, + message_transition_producer_consumer_rows=message_transition_producer_consumer_rows, + reverse_artifact_consumer_route_rows=reverse_artifact_consumer_route_rows, + reverse_parameter_reducer_route_rows=reverse_parameter_reducer_route_rows, + ) + program_executor_plan = build_temporal_registered_program_executor_plan(fused_cuda_program_plan) + _validate_registered_executable_plan( + table_plan=table_plan, + forward_plan=forward_plan, + backward_plan=backward_plan, + forward_binding_plan=forward_binding_plan, + reverse_binding_plan=reverse_binding_plan, + ) + forward_handles = _executor_handles( + direction="forward", + rows=temporal_forward_executor_rows(table_plan), + binding_plan=forward_binding_plan, + ) + reverse_handles = _executor_handles( + direction="reverse", + rows=temporal_reverse_executor_rows(table_plan), + binding_plan=reverse_binding_plan, + ) + kernel_registry = registered_temporal_executor_kernel_registry() + program = RegisteredTemporalExecutorProgram( + primitive_table_fingerprint=tuple(table_plan.fingerprint), + bucket_count=int(table_plan.bucket_count), + primitive_rows=primitive_rows, + forward_handler_rows=forward_handler_rows, + reverse_handler_rows=reverse_handler_rows, + native_strategy_rows=native_strategy_rows, + native_callable_catalog_rows=native_callable_catalog_rows, + native_callable_binding_schema_rows=native_callable_binding_schema_rows, + native_callable_output_rows=native_callable_output_rows, + 
reverse_span_output_rows=reverse_span_output_rows, + reverse_output_route_rows=reverse_output_route_rows, + forward_artifact_route_rows=forward_artifact_route_rows, + forward_artifact_route_summaries=temporal_forward_artifact_route_summaries(table_plan), + forward_artifact_merge_rows=forward_artifact_merge_rows, + forward_artifact_merge_summaries=temporal_forward_artifact_merge_summaries(table_plan), + forward_output_route_rows=forward_output_route_rows, + forward_output_route_summaries=temporal_forward_output_route_summaries(table_plan), + readout_message_producer_consumer_rows=readout_message_producer_consumer_rows, + readout_message_producer_consumer_summaries=temporal_readout_message_producer_consumer_summaries( + readout_message_producer_consumer_plan + ), + readout_message_producer_consumer_template_rows=readout_message_producer_consumer_template_rows, + readout_message_producer_consumer_template_summaries=( + readout_message_producer_consumer_template_plan.review_summary + ), + message_transition_producer_consumer_rows=message_transition_producer_consumer_rows, + message_transition_producer_consumer_summaries=temporal_message_transition_producer_consumer_summaries( + message_transition_producer_consumer_plan + ), + message_transition_producer_consumer_template_rows=message_transition_producer_consumer_template_rows, + message_transition_producer_consumer_template_summaries=( + message_transition_producer_consumer_template_plan.review_summary + ), + reverse_artifact_consumer_route_rows=reverse_artifact_consumer_route_rows, + reverse_artifact_consumer_route_summaries=temporal_reverse_artifact_consumer_route_summaries(table_plan), + reverse_parameter_reducer_route_rows=reverse_parameter_reducer_route_rows, + reverse_parameter_reducer_route_summaries=temporal_reverse_parameter_reducer_route_summaries(table_plan), + transition_reverse_seed_role_rows=transition_reverse_seed_role_rows, + transition_primitive_callable_rows=transition_primitive_callable_rows, + forward_plan=forward_plan, + backward_plan=backward_plan, + memory_plan=memory_plan, + memory_liveness_rows=memory_liveness_rows, + physical_strategy_template_rows=physical_strategy_template_rows, + physical_strategy_template_summaries=physical_strategy_template_plan.review_summary, + reverse_program_stage_plan=reverse_program_stage_plan, + reverse_program_stage_rows=reverse_program_stage_plan.rows, + reverse_program_stage_summaries=reverse_program_stage_plan.summaries, + transition_param_grad_bindings=transition_param_grad_binding_plan.bindings, + transition_param_grad_binding_rows=transition_param_grad_binding_plan.rows, + transition_param_grad_binding_summaries=transition_param_grad_binding_plan.summaries, + fused_cuda_program_plan=fused_cuda_program_plan, + program_executor_plan=program_executor_plan, + kernel_registry=kernel_registry, + forward_handles=forward_handles, + reverse_handles=reverse_handles, + ) + program.kernel_registry.require_forward_messages(program.forward_surface_handles(surface="message")) + program.kernel_registry.require_reverse_messages(program.reverse_surface_handles(surface="message")) + program.kernel_registry.require_forward_readouts(program.forward_surface_handles(surface="readout")) + program.kernel_registry.require_reverse_readouts(program.reverse_surface_handles(surface="readout")) + program.require_transition_forward_coverage() + program.require_transition_reverse_coverage() + program.kernel_registry.require_forward_transitions( + tuple( + handle + for bucket_index in 
program.transition_bucket_ordinals() + for handle in program.forward_handles_for(surface="transition", bucket_ordinal=int(bucket_index)) + ) + ) + program.kernel_registry.require_reverse_transitions( + tuple( + handle + for bucket_index in program.transition_bucket_ordinals() + for handle in program.reverse_handles_for(surface="transition", bucket_ordinal=int(bucket_index)) + ) + ) + return program + + +def _transition_forward_runtime_buffer_requests( + runtime: Any, + executor_program: RegisteredTemporalExecutorProgram, + *, + batch_size: int, + materialize_optional_outputs: bool = True, + forward_program_access_rows: torch.Tensor | None = None, +) -> tuple[TemporalTransitionForwardRuntimeBufferRequest, ...]: + hidden = int(runtime.hidden_size) + public_state_output_binding_indices = _transition_public_state_output_binding_indices(forward_program_access_rows) + requests: list[TemporalTransitionForwardRuntimeBufferRequest] = [] + seen: set[tuple[str, int]] = set() + for bucket_ordinal in executor_program.transition_bucket_ordinals(): + for handle in executor_program.forward_handles_for(surface="transition", bucket_ordinal=int(bucket_ordinal)): + output_index_by_binding = _transition_forward_output_index_by_binding(handle.bindings) + for binding in handle.bindings: + if binding.binding_kind != "output": + continue + if not materialize_optional_outputs and _is_optional_transition_forward_output( + binding.primitive, binding.logical_name + ): + continue + output_contract = temporal_native_callable_transition_forward_output_definition( + primitive=binding.primitive, + output_name=binding.logical_name, + output_index=output_index_by_binding.get(int(binding.binding_index)), + ) + logical_name = str(binding.logical_name) + logical_index = output_contract.logical_index( + primitive_row_index=int(binding.primitive_row_index), + binding_index=int(binding.binding_index), + ) + runtime_role = output_contract.runtime_role + seen_key = (runtime_role, logical_index) + if seen_key in seen: + continue + seen.add(seen_key) + receiver_count = int(binding.receiver_count) + shape = output_contract.shape( + batch_size=int(batch_size), + receiver_count=receiver_count, + hidden_size=hidden, + ) + requests.append( + TemporalTransitionForwardRuntimeBufferRequest( + primitive_row_index=int(binding.primitive_row_index), + bucket_ordinal=int(binding.bucket_ordinal), + logical_name=logical_name, + shape=shape, + runtime_role=runtime_role, + logical_index=logical_index, + alias_runtime_role=( + "forward_recurrent_hidden_after" + if int(binding.binding_index) in public_state_output_binding_indices + else "" + ), + ) + ) + return tuple(requests) + + +def _transition_reverse_dynamic_runtime_buffer_requests( + runtime: Any, + executor_program: RegisteredTemporalExecutorProgram, + *, + batch_size: int, + local_time_steps: int, + transition_step_tables: _TransitionBoundaryReverseStepProgramTables, + reverse_artifact_binding_rows: torch.Tensor, +) -> tuple[TemporalTransitionReverseRuntimeBufferRequest, ...]: + hidden = int(runtime.hidden_size) + message_dim = int(temporal_message_output_dim(runtime)) + recurrent_count = int(runtime.recurrent_cell_idx.numel()) + requests: list[TemporalTransitionReverseRuntimeBufferRequest] = [] + seen: set[tuple[str, int]] = set() + for bucket_ordinal in executor_program.transition_bucket_ordinals(): + handles = executor_program.reverse_handles_for(surface="transition", bucket_ordinal=int(bucket_ordinal)) + receiver_start = min((int(handle.row.receiver_start) for handle in handles), default=0) 
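+        # Receiver span covered by this bucket's reverse transition handles
+        # (receiver_start through receiver_stop below). An aggregated_message
+        # span buffer is requested only when that span does not cover the full
+        # [0, recurrent_count) receiver range.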
+ receiver_stop = max( + (int(handle.row.receiver_start) + int(handle.row.receiver_count) for handle in handles), + default=0, + ) + receiver_count = int(receiver_stop) - int(receiver_start) + if receiver_count <= 0: + continue + if int(receiver_start) != 0 or int(receiver_count) != recurrent_count: + runtime_role = "transition_reverse_recurrent_msg_span" + seen_key = (runtime_role, int(bucket_ordinal)) + if seen_key not in seen: + seen.add(seen_key) + requests.append( + TemporalTransitionReverseRuntimeBufferRequest( + bucket_ordinal=int(bucket_ordinal), + logical_name="aggregated_message", + shape=(int(batch_size), int(receiver_count), message_dim), + runtime_role=runtime_role, + effect="message_read", + logical_index=int(bucket_ordinal), + ) + ) + if _transition_reverse_state_before_zero_buffer_required( + transition_step_tables=transition_step_tables, + reverse_artifact_binding_rows=reverse_artifact_binding_rows, + bucket_ordinal=int(bucket_ordinal), + local_time_steps=int(local_time_steps), + ): + runtime_role = "transition_reverse_state_before_zero" + seen_key = (runtime_role, int(bucket_ordinal)) + if seen_key in seen: + continue + seen.add(seen_key) + requests.append( + TemporalTransitionReverseRuntimeBufferRequest( + bucket_ordinal=int(bucket_ordinal), + logical_name="state_before_zero", + shape=(int(batch_size), int(receiver_count), hidden), + runtime_role=runtime_role, + effect="state_read", + logical_index=int(bucket_ordinal), + ) + ) + return tuple(requests) + + +def _transition_reverse_state_before_zero_buffer_required( + *, + transition_step_tables: _TransitionBoundaryReverseStepProgramTables, + reverse_artifact_binding_rows: torch.Tensor, + bucket_ordinal: int, + local_time_steps: int, +) -> bool: + transition_state_before_role = int(temporal_reverse_artifact_role_id("transition_state_before")) + existing_transition_state_flags = { + (int(row[2]), int(row[3])) + for row in reverse_artifact_binding_rows.cpu().tolist() + if int(row[0]) == transition_state_before_role + } + for group_index, (reverse_executor_group, _bucket, _reverse_logical_to_slot, _param_bindings) in enumerate( + transition_step_tables.group_metadata + ): + if not reverse_executor_group or int(reverse_executor_group[0].bucket_ordinal) != int(bucket_ordinal): + continue + dynamic_rows = transition_step_tables.dynamic_binding_row_groups[int(group_index)].cpu().tolist() + for row in dynamic_rows: + if int(row[1]) != _TRANSITION_DYNAMIC_SOURCE_STATE_BEFORE_ARTIFACT: + continue + if int(row[4]) != 0: + continue + expected_flags = int(bucket_ordinal) * 1_000_000 + int(row[2]) + for local_step in range(max(1, int(local_time_steps))): + if (int(local_step), expected_flags) not in existing_transition_state_flags: + return True + return False + + +def _transition_public_state_output_binding_indices( + forward_program_access_rows: torch.Tensor | None, +) -> set[int]: + if not torch.is_tensor(forward_program_access_rows) or int(forward_program_access_rows.numel()) == 0: + return set() + if forward_program_access_rows.dim() != 2 or int(forward_program_access_rows.shape[1]) != 6: + raise RuntimeError("Registered temporal forward program access rows must be a 6-column tensor") + public_state_opcode = temporal_program_access_opcode("transition_public_state_output") + return { + int(row[3]) + for row in forward_program_access_rows.to(device="cpu", dtype=torch.long).tolist() + if int(row[5]) == int(public_state_opcode) + } + + +def _transition_forward_optional_outputs_required( + *, + physical_time_steps: int, + 
materialize_final_state: bool, + grad_carry_cells: torch.Tensor | None = None, +) -> bool: + if int(physical_time_steps) > 1: + return True + if bool(materialize_final_state): + return True + return torch.is_tensor(grad_carry_cells) + + +def _forward_executor_binding_rows_for_runtime( + executor_program: RegisteredTemporalExecutorProgram, + *, + materialize_optional_transition_outputs: bool, +) -> torch.Tensor: + rows = executor_program.forward_plan.executor_binding_rows + if materialize_optional_transition_outputs: + return rows + omitted_keys = _optional_transition_forward_output_binding_row_keys(executor_program) + if not omitted_keys: + return rows + filtered_rows = [ + [int(value) for value in row] + for row in rows.cpu().tolist() + if ( + int(row[1]), + int(row[2]), + int(row[3]), + int(row[4]), + int(row[5]), + int(row[6]), + int(row[7]), + ) + not in omitted_keys + ] + if len(filtered_rows) == int(rows.shape[0]): + return rows + return torch.tensor(filtered_rows, dtype=torch.long) + + +def _optional_transition_forward_output_binding_indices( + executor_program: RegisteredTemporalExecutorProgram, +) -> set[int]: + return { + int(binding.binding_index) + for handle in executor_program.forward_surface_handles(surface="transition") + for binding in handle.bindings + if binding.binding_kind == "output" + and _is_optional_transition_forward_output(binding.primitive, binding.logical_name) + } + + +def _optional_transition_forward_output_binding_row_keys( + executor_program: RegisteredTemporalExecutorProgram, +) -> set[tuple[int, int, int, int, int, int, int]]: + return { + ( + int(binding.executor_row_index), + int(binding.executor_id), + int(binding.primitive_row_index), + int(binding.binding_index), + int(binding.bucket_ordinal), + 2, + int(binding.local_binding_index), + ) + for handle in executor_program.forward_surface_handles(surface="transition") + for binding in handle.bindings + if binding.binding_kind == "output" + and _is_optional_transition_forward_output(binding.primitive, binding.logical_name) + } + + +def _is_optional_transition_forward_output(primitive: str, logical_name: str) -> bool: + record = transition_primitive_executor_record_for_lowered_primitive(str(primitive)) + if record is None: + return False + return any( + str(output_name) == str(logical_name) and not bool(required) + for output_name, required in record.program_forward_output_bindings + ) + + +def _transition_forward_output_index_by_binding( + bindings: tuple[TemporalExecutorTensorBinding, ...], +) -> dict[int, int]: + output_index_by_binding: dict[int, int] = {} + primitive_rows = tuple( + dict.fromkeys(int(binding.primitive_row_index) for binding in bindings if binding.binding_kind == "output") + ) + for primitive_row_index in primitive_rows: + output_bindings = sorted( + ( + binding + for binding in bindings + if binding.binding_kind == "output" and int(binding.primitive_row_index) == int(primitive_row_index) + ), + key=lambda binding: int(binding.local_binding_index), + ) + for output_index, binding in enumerate(output_bindings): + output_index_by_binding[int(binding.binding_index)] = int(output_index) + return output_index_by_binding + + +def _validate_registered_executable_plan( + *, + table_plan: TemporalPrimitiveTablePlan, + forward_plan: TemporalForwardExecutablePlan, + backward_plan: TemporalBackwardExecutablePlan, + forward_binding_plan: TemporalExecutorBindingPlan, + reverse_binding_plan: TemporalExecutorBindingPlan, +) -> None: + if forward_plan.strategy_legality_status != "legal" or 
backward_plan.strategy_legality_status != "legal": + raise RuntimeError( + "Registered temporal executor requires legal forward and backward executable plans: " + f"forward={forward_plan.strategy_legality_status}; backward={backward_plan.strategy_legality_status}" + ) + if not torch.equal(forward_plan.executor_binding_rows.cpu(), forward_binding_plan.rows.cpu()): + raise RuntimeError("Registered temporal forward executable plan has stale executor binding rows") + if not torch.equal(backward_plan.executor_binding_rows.cpu(), reverse_binding_plan.rows.cpu()): + raise RuntimeError("Registered temporal backward executable plan has stale executor binding rows") + if tuple(table_plan.fingerprint) == (): + raise RuntimeError("Registered temporal executor requires a primitive table fingerprint") + + +def _executor_handles( + *, + direction: Literal["forward", "reverse"], + rows: tuple[TemporalForwardExecutorRow | TemporalReverseExecutorRow, ...], + binding_plan: TemporalExecutorBindingPlan, +) -> tuple[RegisteredTemporalExecutorHandle, ...]: + handles: list[RegisteredTemporalExecutorHandle] = [] + for row_index, row in enumerate(rows): + bindings = tuple( + binding for binding in binding_plan.bindings if int(binding.executor_row_index) == int(row_index) + ) + handle = RegisteredTemporalExecutorHandle( + direction=direction, + row_index=int(row_index), + row=row, + bindings=bindings, + ) + handle.require_parameter_bindings() + handles.append(handle) + return tuple(handles) + + +def _select_registered_executor_handle( + handles: tuple[RegisteredTemporalExecutorHandle, ...], + *, + direction: Literal["forward", "reverse"], + surface: str, + bucket_ordinal: int, +) -> RegisteredTemporalExecutorHandle: + matches = tuple( + handle for handle in handles if handle.surface == surface and int(handle.bucket_ordinal) == int(bucket_ordinal) + ) + if len(matches) != 1: + raise RuntimeError( + "Registered temporal executor program has no unique executor row: " + f"direction={direction}; surface={surface}; bucket={int(bucket_ordinal)}; count={len(matches)}" + ) + return matches[0] + + +def _select_registered_executor_handles( + handles: tuple[RegisteredTemporalExecutorHandle, ...], + *, + direction: Literal["forward", "reverse"], + surface: str, + bucket_ordinal: int, +) -> tuple[RegisteredTemporalExecutorHandle, ...]: + matches = tuple( + sorted( + ( + handle + for handle in handles + if handle.surface == surface and int(handle.bucket_ordinal) == int(bucket_ordinal) + ), + key=lambda handle: int(handle.row_index), + ) + ) + if not matches: + raise RuntimeError( + "Registered temporal executor program has no executor rows: " + f"direction={direction}; surface={surface}; bucket={int(bucket_ordinal)}" + ) + return matches + + +def _select_registered_surface_executor_handles( + handles: tuple[RegisteredTemporalExecutorHandle, ...], + *, + direction: Literal["forward", "reverse"], + surface: str, +) -> tuple[RegisteredTemporalExecutorHandle, ...]: + matches = tuple( + sorted( + (handle for handle in handles if handle.surface == surface), + key=lambda handle: int(handle.row_index), + ) + ) + if not matches: + raise RuntimeError( + "Registered temporal executor program has no surface executor rows: " + f"direction={direction}; surface={surface}; count={len(matches)}" + ) + return matches + + +def _population_grad_dict(grad_state: TensorDict) -> dict[str, dict[str, torch.Tensor | None]]: + population_grads: dict[str, dict[str, torch.Tensor | None]] = {} + for population_name, population_grad in grad_state.items(): + if 
not isinstance(population_grad, TensorDictBase): + continue + population_grads[population_name] = { + key: cast(torch.Tensor | None, value) if torch.is_tensor(value) else None + for key, value in population_grad.items() + } + return population_grads + + +def _bucket_population_params(bucket: Any) -> dict[str, object] | None: + population_materialized = bucket.static_tensors.get("population_materialized") + if not isinstance(population_materialized, dict): + return None + params = population_materialized.get(_flat_bucket_name(bucket)) + return params if isinstance(params, dict) else None + + +def _make_temporal_artifact_checkpoint( + *, + step_index: int, + state: TensorDict, + population_state_cache: dict[str, object] | None, + recurrent_k: torch.Tensor | None, + recurrent_v: torch.Tensor | None, + recurrent_kv_layout: Literal["graph_order", "backend_order"] | None = None, +) -> TemporalArtifactCheckpoint: + resolved_layout = ( + None + if recurrent_k is None or recurrent_v is None + else recurrent_kv_layout + if recurrent_kv_layout is not None + else "graph_order" + ) + return TemporalArtifactCheckpoint( + step_index=int(step_index), + state=TensorDict(state.to_dict(), batch_size=[]), + population_state_cache=dict(population_state_cache) if population_state_cache is not None else None, + recurrent_k=recurrent_k, + recurrent_v=recurrent_v, + recurrent_kv_layout=resolved_layout, + ) + + +def _temporal_forward_output_for_contract( + runtime: Any, + artifacts: TemporalBucketStepArtifacts, + output_contract: TemporalOutputContract, +) -> torch.Tensor: + if output_contract == "full_cells": + return artifacts.cells_out + if output_contract == "output_cells": + return artifacts.output_cells + if output_contract == "pooled_output_cells": + return runtime._pool_output_ports(artifacts.output_cells.unsqueeze(1)).squeeze(1) + raise RuntimeError(f"Unsupported temporal output contract {output_contract!r}") + + +def _temporal_forward_output_step_shape_for_contract( + runtime: Any, + boundary_seq: torch.Tensor, + output_contract: TemporalOutputContract, +) -> tuple[int, ...]: + batch_size = int(boundary_seq.shape[0]) + hidden = int(runtime.hidden_size) + if output_contract == "full_cells": + return (batch_size, int(runtime.coords.shape[0]), hidden) + output_count = int(runtime.output_cell_idx.numel()) + if output_contract == "output_cells": + return (batch_size, output_count, hidden) + if output_contract == "pooled_output_cells": + readout_pool = str(runtime.config.readout.pool) + if readout_pool == "mean": + return (batch_size, 1, hidden) + if readout_pool == "flatten": + return (batch_size, output_count, hidden) + return (batch_size, int(runtime.readout_query.shape[0]), hidden) + raise RuntimeError(f"Unsupported temporal output contract {output_contract!r}") + + +def _initial_recurrent_hidden_backend_order_for_fused_program( + runtime: Any, + boundary_seq: torch.Tensor, + state: TensorDict, +) -> torch.Tensor: + cells = state.get("cells") + recurrent_count = int(runtime.recurrent_cell_idx.numel()) + hidden = int(boundary_seq.shape[-1]) + if not torch.is_tensor(cells): + return boundary_seq.new_zeros(int(boundary_seq.shape[0]), recurrent_count, hidden) + return ( + cells[:, runtime._recurrent_slice, :].index_select(1, runtime.population_backend_recurrent_order).contiguous() + ) + + +def _program_tensor_slot_by_binding_index(program_tensor_binding_rows: torch.Tensor) -> dict[int, int]: + return {int(row[0]): int(row[1]) for row in program_tensor_binding_rows.cpu().tolist()} + + +def 
_tensor_logical_bytes(tensor: torch.Tensor) -> int: + return int(tensor.numel()) * int(tensor.element_size()) + + +def _reverse_artifact_tensor_store_review_summary( + *, + tensors: tuple[torch.Tensor, ...], + binding_rows: torch.Tensor, +) -> tuple[str, ...]: + role_name_by_id = { + int(temporal_reverse_artifact_role_id(role_name)): str(role_name) + for role_name in temporal_reverse_artifact_role_names() + } + counts_by_role: dict[str, int] = {} + logical_bytes_by_role: dict[str, int] = {} + unique_storage_keys: set[tuple[int, int, str]] = set() + unique_storage_bytes = 0 + for role_id, tensor_index, _physical_step, _flags, _route_row in binding_rows.cpu().tolist(): + role_name = role_name_by_id.get(int(role_id), f"role_{int(role_id)}") + if int(tensor_index) < 0 or int(tensor_index) >= len(tensors): + continue + tensor = tensors[int(tensor_index)] + if not torch.is_tensor(tensor): + continue + counts_by_role[role_name] = counts_by_role.get(role_name, 0) + 1 + logical_bytes_by_role[role_name] = logical_bytes_by_role.get(role_name, 0) + _tensor_logical_bytes(tensor) + try: + storage_key = ( + int(tensor.untyped_storage().data_ptr()), + int(tensor.untyped_storage().nbytes()), + str(tensor.device), + ) + except RuntimeError: + storage_key = (int(tensor.data_ptr()), _tensor_logical_bytes(tensor), str(tensor.device)) + if storage_key not in unique_storage_keys: + unique_storage_keys.add(storage_key) + unique_storage_bytes += int(storage_key[1]) + logical_total = sum(int(value) for value in logical_bytes_by_role.values()) + return ( + "reverse_artifact_tensor_store=compiler_executable", + f"tensor_count={len(tensors)}", + f"binding_rows={int(binding_rows.shape[0])}", + f"logical_artifact_bytes={int(logical_total)}", + f"unique_storage_bytes={int(unique_storage_bytes)}", + "tensor_count_by_role=" + _artifact_summary(counts_by_role), + "logical_bytes_by_role=" + _artifact_summary(logical_bytes_by_role), + "storage_policy=role_owned;transition_state_before_binding_pruned;" + "message_recurrent_kv_before_recompute;transition_state_before_requires_copy;" + "immutable_artifacts_reuse_storage", + ) + + +def _artifact_summary(values: dict[str, int]) -> str: + return "none" if not values else ",".join(f"{key}:{int(values[key])}" for key in sorted(values)) + + +def _forward_reverse_artifact_roles_for_runtime( + roles: tuple[str, ...], + *, + output_contract: TemporalOutputContract, + materialize_final_state: bool, +) -> tuple[str, ...]: + if output_contract == "full_cells" or bool(materialize_final_state): + return tuple(roles) + return tuple(role for role in roles if role != "cells_prev") + + +def _materialize_fused_forward_population_state( + runtime: Any, + *, + executor_program: RegisteredTemporalExecutorProgram, + final_program_tensors: tuple[torch.Tensor, ...], + program_tensor_binding_rows: torch.Tensor, + forward_transition_state_carry_rows: torch.Tensor, + static_tensors: dict[str, object], +) -> dict[str, TensorDict]: + slot_by_binding_index = _program_tensor_slot_by_binding_index(program_tensor_binding_rows) + carry_input_bindings_by_bucket: dict[int, set[int]] = {} + for row in forward_transition_state_carry_rows.cpu().tolist(): + carry_input_bindings_by_bucket.setdefault(int(row[0]), set()).add(int(row[1])) + transition_executors = tuple( + handle + for bucket_index in executor_program.transition_bucket_ordinals() + for handle in executor_program.forward_handles_for(surface="transition", bucket_ordinal=int(bucket_index)) + ) + population_state_by_name: dict[str, TensorDict] = {} + 
seen_transition_buckets: set[int] = set() + for transition_executor, bucket in transition_buckets_for_executors( + runtime, + transition_executors, + static_tensors, + ): + bucket_ordinal = int(transition_executor.bucket_ordinal) + if bucket_ordinal in seen_transition_buckets: + continue + seen_transition_buckets.add(bucket_ordinal) + transition_executor_group = executor_program.forward_handles_for( + surface="transition", + bucket_ordinal=bucket_ordinal, + ) + state_input_bindings = carry_input_bindings_by_bucket.get(bucket_ordinal, set()) + if not state_input_bindings: + continue + backend_state_payload: dict[str, torch.Tensor] = {} + for transition_executor_in_group in transition_executor_group: + for binding in transition_executor_in_group.bindings: + if binding.binding_kind != "input" or int(binding.binding_index) not in state_input_bindings: + continue + tensor_slot = slot_by_binding_index.get(int(binding.binding_index)) + if tensor_slot is None or tensor_slot >= len(final_program_tensors): + raise RuntimeError( + "Registered fused forward program final tensor table is missing a state binding: " + f"bucket={bucket_ordinal}; binding={int(binding.binding_index)}; " + f"logical={binding.logical_name!r}" + ) + tensor = final_program_tensors[int(tensor_slot)] + if not torch.is_tensor(tensor) or int(tensor.numel()) == 0: + raise RuntimeError( + "Registered fused forward program final tensor table has an empty state tensor: " + f"bucket={bucket_ordinal}; binding={int(binding.binding_index)}; " + f"logical={binding.logical_name!r}" + ) + backend_state_payload[str(binding.logical_name)] = tensor.contiguous() + if not backend_state_payload: + continue + first_tensor = next(iter(backend_state_payload.values())) + population_name = _flat_bucket_name(bucket) + backend_state = TensorDict( + backend_state_payload, + batch_size=[int(first_tensor.shape[0]), int(first_tensor.shape[1])], + device=first_tensor.device, + ) + population_state_by_name[population_name] = runtime._backend_state_to_population_state( + population_name, + backend_state, + ) + return population_state_by_name + + +def _materialize_fused_forward_final_state( + runtime: Any, + *, + executor_program: RegisteredTemporalExecutorProgram, + boundary_seq: torch.Tensor, + output_seq: torch.Tensor, + final_recurrent_hidden_backend_order: torch.Tensor, + final_program_tensors: tuple[torch.Tensor, ...], + program_tensor_table: Any, + static_tensors: dict[str, object], +) -> TensorDict: + recurrent_hidden_graph_order = final_recurrent_hidden_backend_order.index_select( + 1, + runtime.population_backend_recurrent_inverse_order, + ).contiguous() + cells_out = torch.cat( + ( + boundary_seq[:, -1].contiguous(), + recurrent_hidden_graph_order, + output_seq[:, -1].contiguous(), + ), + dim=1, + ) + population_state_by_name = _materialize_fused_forward_population_state( + runtime, + executor_program=executor_program, + final_program_tensors=final_program_tensors, + program_tensor_binding_rows=program_tensor_table.program_tensor_binding_rows, + forward_transition_state_carry_rows=program_tensor_table.forward_transition_state_carry_rows, + static_tensors=static_tensors, + ) + runtime._last_flat_bucket_recurrent_graph_layout_backend = "registered_executor_backend_to_graph" + runtime._last_flat_bucket_graph_order_layout_backend = "registered_fused_program_tensor_table_layout" + return TensorDict( + { + "cells": cells_out, + **{ + name: population_state_by_name[name] + for name in runtime._population_names + if name in population_state_by_name + }, + }, + 
batch_size=[], + ) + + +def _try_run_registered_temporal_fused_forward_program_scan( + runtime: Any, + *, + executor_program: RegisteredTemporalExecutorProgram, + boundary_seq: torch.Tensor, + state: TensorDict, + population_resets: torch.Tensor | None, + transition_resets: torch.Tensor | None, + static_tensors: dict[str, object], + inner_steps: int, + output_contract: TemporalOutputContract, + output_boundary: Literal["sequence", "terminal"], + collect_artifacts: bool, + materialize_final_state: bool, + transition_tape_mode: TemporalTransitionTapeMode, + memory_artifact_plan: TemporalMemoryRuntimeArtifactPlan, + artifact_checkpoints: dict[int, TemporalArtifactCheckpoint] | None, + initial_population_state_cache: dict[str, object] | None = None, +) -> SharedTemporalForwardScanResult | None: + if not hasattr(runtime, "_last_flat_bucket_temporal_registered_backward_memory_stages"): + runtime._last_flat_bucket_temporal_registered_backward_memory_stages = () + _record_registered_backward_memory_stage(runtime, boundary_seq, "forward_entry") + readout_pool = str(runtime.config.readout.pool) + runtime_support_plan = build_temporal_forward_program_runtime_support_plan( + runtime, + boundary_seq=boundary_seq, + output_contract=output_contract, + readout_pool=readout_pool, + materialize_final_state=bool(materialize_final_state), + collect_artifacts=bool(collect_artifacts), + memory_artifact_plan=memory_artifact_plan, + ) + runtime._last_flat_bucket_temporal_forward_program_runtime_support = runtime_support_plan.review_summary + runtime._last_flat_bucket_temporal_forward_program_runtime_support_rows = runtime_support_plan.rows + _record_registered_backward_memory_stage(runtime, boundary_seq, "forward_runtime_support_built") + if runtime_support_plan.rejection_reason is not None: + runtime._last_flat_bucket_temporal_scan_reject = ( + f"registered_fused_forward_program_reject:{runtime_support_plan.rejection_reason}" + ) + return None + step_population_state_cache = initial_population_state_cache + if step_population_state_cache is None: + state_has_population_views = all( + isinstance(state.get(name), TensorDictBase) + for name in runtime._population_names + if int(runtime._population_recurrent_indices(name).numel()) > 0 + ) + if state_has_population_views: + step_population_state_cache = runtime._prepare_stream_step_population_cache( + state, + batch=int(boundary_seq.shape[0]), + device=boundary_seq.device, + dtype=boundary_seq.dtype, + ) + else: + step_population_state_cache = { + name: runtime._init_backend_population_state( + name, + batch=int(boundary_seq.shape[0]), + device=boundary_seq.device, + dtype=boundary_seq.dtype, + ) + for name in runtime._population_names + if int(runtime._population_recurrent_indices(name).numel()) > 0 + } + if not step_population_state_cache: + step_population_state_cache = None + program_tensor_table = build_forward_executable_program_tensor_table( + runtime, + executor_program=executor_program, + boundary_seq=boundary_seq, + state=state, + static_tensors=static_tensors, + step_population_state_cache=step_population_state_cache, + ) + _record_registered_backward_memory_stage(runtime, boundary_seq, "forward_program_tensor_table_built") + initial_recurrent_hidden = _initial_recurrent_hidden_backend_order_for_fused_program( + runtime, + boundary_seq, + state, + ) + _record_registered_backward_memory_stage(runtime, boundary_seq, "forward_initial_recurrent_hidden_built") + forward_reset_tensors, forward_reset_rows = temporal_forward_reset_tensor_table( + 
population_resets=population_resets, + transition_resets=transition_resets, + ) + _record_registered_backward_memory_stage(runtime, boundary_seq, "forward_reset_table_built") + output_step_shape = _temporal_forward_output_step_shape_for_contract( + runtime, + boundary_seq, + output_contract, + ) + physical_time_steps = int(boundary_seq.shape[1]) * int(inner_steps) + materialize_optional_transition_outputs = _transition_forward_optional_outputs_required( + physical_time_steps=physical_time_steps, + materialize_final_state=bool(materialize_final_state), + ) + streaming_readout_body_available = ( + not bool(collect_artifacts) + and not bool(materialize_final_state) + and not torch.is_tensor(population_resets) + and not torch.is_tensor(transition_resets) + ) + forward_executor_binding_rows = _forward_executor_binding_rows_for_runtime( + executor_program, + materialize_optional_transition_outputs=materialize_optional_transition_outputs, + ) + omitted_optional_transition_output_bindings = ( + set() + if materialize_optional_transition_outputs + else _optional_transition_forward_output_binding_indices(executor_program) + ) + reverse_artifact_roles = _forward_reverse_artifact_roles_for_runtime( + memory_artifact_plan.reverse_artifact_roles, + output_contract=output_contract, + materialize_final_state=materialize_final_state, + ) + _record_registered_backward_memory_stage(runtime, boundary_seq, "forward_reverse_artifact_roles_built") + runtime_buffer_plan = build_temporal_runtime_buffer_plan( + executor_program.memory_plan, + output_seq_shape=( + int(boundary_seq.shape[0]), + 1 if output_boundary == "terminal" else int(boundary_seq.shape[1]), + *tuple(int(dim) for dim in output_step_shape[1:]), + ), + forward_message_step_flat_shape=(int(boundary_seq.shape[0]),), + physical_time_steps=physical_time_steps, + runtime_schedule_plan=memory_artifact_plan.runtime_schedule_plan, + cells_prev_shape=( + int(boundary_seq.shape[0]), + int(boundary_seq.shape[2]) + int(runtime.recurrent_cell_idx.numel()) + int(runtime.output_cell_idx.numel()), + int(runtime.hidden_size), + ) + if collect_artifacts and "cells_prev" in reverse_artifact_roles + else None, + recurrent_hidden_shape=( + int(boundary_seq.shape[0]), + int(runtime.recurrent_cell_idx.numel()), + int(runtime.hidden_size), + ) + if int(runtime.recurrent_cell_idx.numel()) > 0 + else None, + forward_recurrent_msg_shape=( + int(boundary_seq.shape[0]), + int(runtime.recurrent_cell_idx.numel()), + temporal_message_output_dim(runtime), + ), + forward_output_msg_shape=( + int(boundary_seq.shape[0]), + int(runtime.output_cell_idx.numel()), + int(runtime.value_dim), + ), + forward_output_cells_shape=( + int(boundary_seq.shape[0]), + int(runtime.output_cell_idx.numel()), + int(runtime.hidden_size), + ), + transition_forward_outputs=_transition_forward_runtime_buffer_requests( + runtime, + executor_program, + batch_size=int(boundary_seq.shape[0]), + materialize_optional_outputs=materialize_optional_transition_outputs, + forward_program_access_rows=program_tensor_table.forward_program_access_rows, + ), + dtype=str(boundary_seq.dtype), + device=str(boundary_seq.device), + include_workspace_rows=True, + enable_public_state_runtime_alias=not bool(collect_artifacts), + defer_forward_step_buffers=not bool(collect_artifacts) and not bool(materialize_final_state), + defer_local_transition_outputs=not bool(collect_artifacts) and not bool(materialize_final_state), + ) + _record_registered_backward_memory_stage(runtime, boundary_seq, "forward_runtime_buffer_plan_built") + 
runtime_buffer_tensors = allocate_temporal_runtime_buffers(boundary_seq, runtime_buffer_plan) + _record_registered_backward_memory_stage(runtime, boundary_seq, "forward_runtime_buffers_allocated") + runtime_buffer_rows = temporal_runtime_buffer_rows_tensor(runtime_buffer_plan) + forward_runtime_plan = build_temporal_forward_program_runtime_plan( + runtime, + boundary_seq=boundary_seq, + inner_steps=int(inner_steps), + output_boundary_terminal=output_boundary == "terminal", + ) + _record_registered_backward_memory_stage(runtime, boundary_seq, "forward_runtime_plan_built") + runtime._last_flat_bucket_temporal_memory_runtime_buffer_plan = runtime_buffer_plan.review_summary + runtime._last_flat_bucket_temporal_memory_runtime_buffer_rows = runtime_buffer_rows + memory_runtime_schedule_rows = temporal_memory_runtime_schedule_rows_tensor( + memory_artifact_plan.runtime_schedule_plan + ) + streaming_step_body_available = bool(streaming_readout_body_available) and ( + _streaming_step_producer_consumer_body_available(executor_program) + ) + physical_strategy_plan = build_temporal_physical_strategy_plan( + memory_artifact_plan.runtime_schedule_plan, + inner_steps=int(inner_steps), + output_boundary="terminal" if output_boundary == "terminal" else "sequence", + reset_policy=( + "present" if torch.is_tensor(population_resets) or torch.is_tensor(transition_resets) else "absent" + ), + streaming_step_body_available=streaming_step_body_available, + ) + physical_strategy_rows = temporal_physical_strategy_rows_tensor(physical_strategy_plan) + runtime._last_flat_bucket_temporal_memory_runtime_schedule_rows = memory_runtime_schedule_rows + runtime._last_flat_bucket_temporal_physical_strategy_plan = physical_strategy_plan.review_summary + runtime._last_flat_bucket_temporal_physical_strategy_rows = physical_strategy_rows + runtime._last_flat_bucket_temporal_readout_message_producer_consumer_rows = ( + executor_program.readout_message_producer_consumer_rows + ) + runtime._last_flat_bucket_temporal_readout_message_producer_consumer_plan = ( + executor_program.readout_message_producer_consumer_summaries + ) + runtime._last_flat_bucket_temporal_readout_message_producer_consumer_template_rows = ( + executor_program.readout_message_producer_consumer_template_rows + ) + runtime._last_flat_bucket_temporal_readout_message_producer_consumer_template_plan = ( + executor_program.readout_message_producer_consumer_template_summaries + ) + runtime._last_flat_bucket_temporal_message_transition_producer_consumer_rows = ( + executor_program.message_transition_producer_consumer_rows + ) + runtime._last_flat_bucket_temporal_message_transition_producer_consumer_plan = ( + executor_program.message_transition_producer_consumer_summaries + ) + runtime._last_flat_bucket_temporal_message_transition_producer_consumer_template_rows = ( + executor_program.message_transition_producer_consumer_template_rows + ) + runtime._last_flat_bucket_temporal_message_transition_producer_consumer_template_plan = ( + executor_program.message_transition_producer_consumer_template_summaries + ) + runtime._last_flat_bucket_temporal_forward_program_runtime_plan = forward_runtime_plan.review_summary + runtime._last_flat_bucket_temporal_forward_program_runtime_rows = forward_runtime_plan.rows + runtime._last_flat_bucket_temporal_optional_transition_outputs = ( + "transition_optional_outputs=materialized" + if materialize_optional_transition_outputs + else ( + "transition_optional_outputs=elided;" + 
f"omitted_bindings={len(omitted_optional_transition_output_bindings)};" + f"carry_rows={int(program_tensor_table.forward_transition_state_carry_rows.shape[0])}" + ) + ) + _record_registered_backward_memory_stage(runtime, boundary_seq, "before_fused_forward_program") + fused_outputs = registered_temporal_fused_forward_program_cuda( + boundary_seq=boundary_seq, + recurrent_hidden_initial_backend_order=initial_recurrent_hidden, + program_tensors=program_tensor_table.program_tensors, + program_tensor_binding_rows=program_tensor_table.program_tensor_binding_rows, + forward_program_access_rows=program_tensor_table.forward_program_access_rows, + forward_transition_state_carry_rows=program_tensor_table.forward_transition_state_carry_rows, + forward_artifact_route_rows=executor_program.forward_artifact_route_rows, + forward_artifact_merge_rows=executor_program.forward_artifact_merge_rows, + forward_output_route_rows=executor_program.forward_output_route_rows, + readout_message_producer_consumer_rows=executor_program.readout_message_producer_consumer_rows, + message_transition_producer_consumer_rows=executor_program.message_transition_producer_consumer_rows, + forward_reset_tensors=forward_reset_tensors, + forward_reset_rows=forward_reset_rows, + primitive_rows=executor_program.primitive_rows, + forward_executor_rows=executor_program.forward_plan.forward_executor_rows, + reverse_executor_rows=executor_program.backward_plan.reverse_executor_rows, + forward_handler_rows=executor_program.forward_handler_rows, + reverse_handler_rows=executor_program.reverse_handler_rows, + native_strategy_rows=executor_program.native_strategy_rows, + native_callable_binding_schema_rows=executor_program.native_callable_binding_schema_rows, + native_callable_output_rows=executor_program.native_callable_output_rows, + transition_primitive_callable_rows=executor_program.transition_primitive_callable_rows, + forward_executor_binding_rows=forward_executor_binding_rows, + reverse_executor_binding_rows=executor_program.backward_plan.executor_binding_rows, + memory_liveness_rows=executor_program.memory_liveness_rows, + memory_runtime_schedule_rows=memory_runtime_schedule_rows, + physical_strategy_rows=physical_strategy_rows, + runtime_buffer_tensors=runtime_buffer_tensors, + runtime_buffer_rows=runtime_buffer_rows, + forward_program_runtime_tensors=forward_runtime_plan.tensors, + forward_program_runtime_rows=forward_runtime_plan.rows, + return_final_program_tensors=bool(materialize_final_state), + return_reverse_artifacts=bool(collect_artifacts), + ) + if len(fused_outputs) < 2: + raise RuntimeError("Registered fused forward program returned no output/final recurrent state") + output_seq = fused_outputs[0] + final_recurrent_hidden = fused_outputs[1] + program_tensor_count = len(program_tensor_table.program_tensors) + forward_native_memory_stage_offset = 2 + program_tensor_count + if len(fused_outputs) <= forward_native_memory_stage_offset: + raise RuntimeError("Registered fused forward program returned no native memory stage rows") + forward_native_memory_stage_rows = fused_outputs[forward_native_memory_stage_offset] + _record_registered_backward_native_memory_stage_rows(runtime, forward_native_memory_stage_rows) + _record_registered_backward_memory_stage(runtime, boundary_seq, "after_fused_forward_program") + final_program_tensors = tuple(fused_outputs[2 : 2 + program_tensor_count]) + _record_registered_backward_memory_stage(runtime, boundary_seq, "forward_final_tensor_table_sliced") + if len(final_program_tensors) != 
len(program_tensor_table.program_tensors): + raise RuntimeError( + "Registered fused forward program returned a mismatched final tensor table: " + f"expected={len(program_tensor_table.program_tensors)}; actual={len(final_program_tensors)}" + ) + artifact_store = None + runtime._last_flat_bucket_temporal_reverse_artifact_tensor_store = () + if collect_artifacts: + artifact_offset = forward_native_memory_stage_offset + 1 + if len(fused_outputs) <= artifact_offset: + raise RuntimeError("Registered fused forward program returned no compiler reverse artifact tensor table") + reverse_artifact_binding_rows = fused_outputs[artifact_offset] + reverse_artifact_tensors = tuple(fused_outputs[artifact_offset + 1 :]) + _record_registered_backward_memory_stage(runtime, boundary_seq, "forward_reverse_artifacts_sliced") + if ( + not torch.is_tensor(reverse_artifact_binding_rows) + or reverse_artifact_binding_rows.dim() != 2 + or int(reverse_artifact_binding_rows.shape[1]) != 5 + ): + raise RuntimeError("Registered fused forward program returned invalid reverse artifact binding rows") + reverse_artifact_role_rows = temporal_reverse_artifact_role_rows_tensor(reverse_artifact_roles) + reverse_artifact_access_rows = temporal_reverse_artifact_access_rows_tensor(reverse_artifact_roles) + _record_registered_backward_memory_stage(runtime, boundary_seq, "forward_reverse_artifact_rows_built") + reverse_artifact_tensor_store = TemporalReverseArtifactTensorStore( + tensors=reverse_artifact_tensors, + binding_rows=reverse_artifact_binding_rows.to(device="cpu", dtype=torch.long).contiguous(), + role_rows=reverse_artifact_role_rows, + access_rows=reverse_artifact_access_rows, + window_start=0, + window_end=int(boundary_seq.shape[1]) * int(inner_steps), + source="registered_fused_forward_program_cuda", + ) + runtime._last_flat_bucket_temporal_reverse_artifact_tensor_store = ( + _reverse_artifact_tensor_store_review_summary( + tensors=reverse_artifact_tensors, + binding_rows=reverse_artifact_tensor_store.binding_rows, + ) + ) + _record_registered_backward_memory_stage(runtime, boundary_seq, "forward_reverse_artifact_store_built") + artifact_store = TemporalArtifactStore( + mode="store_step_artifacts", + artifacts_by_step=None, + checkpoints={} if artifact_checkpoints is None else dict(artifact_checkpoints), + checkpoint_stride=max(1, int(memory_artifact_plan.checkpoint_stride)), + recompute_window_len=max(1, int(memory_artifact_plan.recompute_window_len)), + transition_tape_mode=transition_tape_mode, + reason=( + "artifact_mode=store_step_artifacts;" + "source=registered_fused_forward_program_cuda;" + f"time_steps={int(boundary_seq.shape[1]) * int(inner_steps)};" + f"{memory_artifact_plan.reason};" + f"checkpoint_owner={memory_artifact_plan.checkpoint_owner};" + "reverse_owner=registered_fused_reverse_program_tensor_table" + ), + stored_artifact_step_bytes=0, + checkpoint_steps=memory_artifact_plan.checkpoint_steps, + backward_windows=memory_artifact_plan.backward_windows, + memory_plan_fingerprint=executor_program.memory_plan.fingerprint, + memory_runtime_artifact_fingerprint=memory_artifact_plan.fingerprint, + memory_runtime_policy_fingerprint=memory_artifact_plan.runtime_policy.review_summary, + memory_runtime_schedule_fingerprint=memory_artifact_plan.runtime_schedule_plan.fingerprint, + memory_runtime_schedule_rows=memory_runtime_schedule_rows, + physical_strategy_fingerprint=physical_strategy_plan.fingerprint, + physical_strategy_rows=physical_strategy_rows, + reverse_artifact_roles=reverse_artifact_roles, + 
reverse_artifact_tensor_store=reverse_artifact_tensor_store, + ) + final_state = ( + _materialize_fused_forward_final_state( + runtime, + executor_program=executor_program, + boundary_seq=boundary_seq, + output_seq=output_seq, + final_recurrent_hidden_backend_order=final_recurrent_hidden, + final_program_tensors=final_program_tensors, + program_tensor_table=program_tensor_table, + static_tensors=static_tensors, + ) + if materialize_final_state + else TensorDict({}, batch_size=[]) + ) + _record_registered_backward_memory_stage(runtime, boundary_seq, "forward_final_state_materialized") + runtime._last_flat_bucket_temporal_memory_runtime_artifact_plan = memory_artifact_plan.review_summary + runtime._last_flat_bucket_temporal_executable_program_tensor_table = program_tensor_table.review_summary + runtime._last_flat_bucket_temporal_memory_workspace_aliases = memory_artifact_plan.workspace_aliases + runtime._last_flat_bucket_temporal_memory_checkpoint_steps = memory_artifact_plan.checkpoint_steps + runtime._last_flat_bucket_temporal_memory_backward_windows = memory_artifact_plan.backward_windows + runtime._last_flat_bucket_temporal_memory_liveness_rows = executor_program.memory_liveness_rows + runtime._last_flat_bucket_temporal_reverse_program_stage_rows = executor_program.reverse_program_stage_rows + runtime._last_flat_bucket_temporal_reverse_program_stage_summaries = ( + executor_program.reverse_program_stage_summaries + ) + runtime._last_flat_bucket_temporal_reverse_output_route_rows = executor_program.reverse_output_route_rows + runtime._last_flat_bucket_temporal_forward_artifact_route_rows = executor_program.forward_artifact_route_rows + runtime._last_flat_bucket_temporal_forward_artifact_merge_rows = executor_program.forward_artifact_merge_rows + runtime._last_flat_bucket_temporal_forward_output_route_rows = executor_program.forward_output_route_rows + runtime._last_flat_bucket_temporal_reverse_artifact_consumer_route_rows = ( + executor_program.reverse_artifact_consumer_route_rows + ) + runtime._last_flat_bucket_temporal_reverse_artifact_consumer_route_summaries = ( + executor_program.reverse_artifact_consumer_route_summaries + ) + runtime._last_flat_bucket_temporal_forward_artifact_route_summaries = ( + executor_program.forward_artifact_route_summaries + ) + runtime._last_flat_bucket_temporal_reverse_parameter_reducer_route_rows = ( + executor_program.reverse_parameter_reducer_route_rows + ) + runtime._last_flat_bucket_temporal_reverse_parameter_reducer_route_summaries = ( + executor_program.reverse_parameter_reducer_route_summaries + ) + runtime._last_flat_bucket_temporal_transition_param_grad_binding_rows = ( + executor_program.transition_param_grad_binding_rows + ) + runtime._last_flat_bucket_temporal_transition_param_grad_binding_summaries = ( + executor_program.transition_param_grad_binding_summaries + ) + runtime._last_flat_bucket_temporal_fused_cuda_program_plan = executor_program.fused_cuda_program_plan.review_summary + runtime._last_flat_bucket_temporal_fused_cuda_launch_contract = ( + executor_program.fused_cuda_program_plan.launch_contract.review_summary + ) + runtime._last_flat_bucket_temporal_fused_cuda_program_status = executor_program.fused_cuda_program_plan.status + runtime._last_flat_bucket_temporal_fused_cuda_program_blocker = ( + executor_program.fused_cuda_program_plan.blocker_code, + executor_program.fused_cuda_program_plan.blocker_reason, + ) + runtime._last_flat_bucket_temporal_registered_program_executor_plan = ( + 
executor_program.program_executor_plan.review_summary + ) + runtime._last_flat_bucket_temporal_registered_program_executor_status = ( + executor_program.program_executor_plan.status + ) + runtime._last_flat_bucket_temporal_registered_program_executor_demotion_policy = ( + executor_program.program_executor_plan.demotion_policy + ) + runtime._last_flat_bucket_temporal_executor_kernel_registry = executor_program.kernel_registry.review_summary + runtime._last_flat_bucket_temporal_scan_owner = "registered_fused_forward_program_cuda" + runtime._last_flat_bucket_scan_implementation = executor_program.fused_cuda_program_plan.forward_entrypoint + runtime._last_flat_bucket_temporal_scan_binding_abi = "registered_executor_binding_rows" + runtime._last_flat_bucket_temporal_scan_primitive_row_source = "compiler_primitive_rows" + runtime._last_flat_bucket_forward_transition_executor = "registered_fused_forward_program_cuda" + runtime._last_flat_bucket_public_projection_backend = "registered_fused_forward_program_cuda" + runtime._last_flat_bucket_readout_backend = "registered_fused_forward_program_cuda" + runtime._last_flat_bucket_state_cache_mode = ( + "registered_fused_program_final_tensor_table" + if materialize_final_state + else "registered_fused_program_internal_state" + ) + return SharedTemporalForwardScanResult( + output_seq=output_seq, + final_state=final_state, + artifact_store=artifact_store, + ) + + +def run_registered_temporal_forward_executor_scan( + runtime: Any, + *, + executor_program: RegisteredTemporalExecutorProgram, + boundary_seq: torch.Tensor, + state: TensorDict, + population_resets: torch.Tensor | None, + transition_resets: torch.Tensor | None, + static_tensors: dict[str, object], + inner_steps: int, + output_contract: TemporalOutputContract, + output_boundary: Literal["sequence", "terminal"], + collect_artifacts: bool, + materialize_final_state: bool, + transition_tape_mode: TemporalTransitionTapeMode, + memory_artifact_plan: TemporalMemoryRuntimeArtifactPlan, + artifact_checkpoints: dict[int, TemporalArtifactCheckpoint] | None, + scan_schedule: Any, + initial_population_state_cache: dict[str, object] | None = None, +) -> SharedTemporalForwardScanResult: + fused_result = _try_run_registered_temporal_fused_forward_program_scan( + runtime, + executor_program=executor_program, + boundary_seq=boundary_seq, + state=state, + population_resets=population_resets, + transition_resets=transition_resets, + static_tensors=static_tensors, + inner_steps=int(inner_steps), + output_contract=output_contract, + output_boundary=output_boundary, + collect_artifacts=collect_artifacts, + materialize_final_state=materialize_final_state, + transition_tape_mode=transition_tape_mode, + memory_artifact_plan=memory_artifact_plan, + artifact_checkpoints=artifact_checkpoints, + initial_population_state_cache=initial_population_state_cache, + ) + if fused_result is not None: + return fused_result + del scan_schedule + runtime._last_flat_bucket_temporal_scan_owner = "registered_fused_forward_program_unavailable" + runtime._last_flat_bucket_scan_implementation = executor_program.fused_cuda_program_plan.forward_entrypoint + runtime._last_flat_bucket_temporal_scan_binding_abi = "registered_executor_binding_rows" + runtime._last_flat_bucket_temporal_scan_primitive_row_source = "compiler_primitive_rows" + runtime._last_flat_bucket_temporal_scan_reject = ( + "registered_fused_forward_program_required;" + f"output_contract={output_contract};" + f"output_boundary={output_boundary};" + 
f"collect_artifacts={int(bool(collect_artifacts))};" + f"artifact_mode={memory_artifact_plan.mode};" + f"store_step_artifacts={int(bool(memory_artifact_plan.store_step_artifacts))};" + "compiled_fused_forward_program_only=1" + ) + raise RuntimeError( + "Registered temporal forward execution must run through the compiler-owned fused CUDA program. " + f"{runtime._last_flat_bucket_temporal_scan_reject}" + ) + + +def _reverse_artifact_tensor_store_window_table( + tensor_store: TemporalReverseArtifactTensorStore, + *, + window_start: int, + window_end: int, +) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]: + if int(window_start) < int(tensor_store.window_start) or int(window_end) > int(tensor_store.window_end): + raise RuntimeError( + "Registered fused forward artifact tensor store cannot serve requested backward window: " + f"store=({int(tensor_store.window_start)}, {int(tensor_store.window_end)}); " + f"request=({int(window_start)}, {int(window_end)})" + ) + rows: list[list[int]] = [] + if ( + tensor_store.binding_rows.device.type != "cpu" + or tensor_store.binding_rows.dtype != torch.long + or tensor_store.binding_rows.dim() != 2 + or int(tensor_store.binding_rows.shape[1]) != 5 + ): + raise RuntimeError("Registered fused forward artifact tensor store requires CPU int64 binding rows [N,5]") + for role_id, tensor_index, physical_step, flags, route_row in tensor_store.binding_rows.cpu().tolist(): + if int(window_start) <= int(physical_step) < int(window_end): + rows.append( + [ + int(role_id), + int(tensor_index), + int(physical_step) - int(window_start), + int(flags), + int(route_row), + ] + ) + if not rows: + raise RuntimeError( + "Registered fused forward artifact tensor store produced an empty backward window table: " + f"window=({int(window_start)}, {int(window_end)})" + ) + return tensor_store.tensors, torch.tensor(rows, dtype=torch.long) + + +def _reverse_artifact_tensor_for_step( + tensor_store: TemporalReverseArtifactTensorStore, + *, + role: str, + physical_step: int, +) -> torch.Tensor: + if role in {"output_msg", "output_cells"}: + raise RuntimeError( + "Registered fused forward output artifacts are route-owned; " + f"use compiler output/artifact route rows for role={role!r}" + ) + role_id = temporal_reverse_artifact_role_id(role) + matches = [ + int(row[1]) + for row in tensor_store.binding_rows.cpu().tolist() + if int(row[0]) == int(role_id) and int(row[2]) == int(physical_step) and int(row[3]) == 0 + ] + if len(matches) != 1: + raise RuntimeError( + "Registered fused forward artifact tensor store has no unique role tensor: " + f"role={role!r}; physical_step={int(physical_step)}; count={len(matches)}" + ) + tensor_index = matches[0] + if tensor_index < 0 or tensor_index >= len(tensor_store.tensors): + raise RuntimeError( + "Registered fused forward artifact tensor store row points outside tensor table: " + f"role={role!r}; tensor_index={int(tensor_index)}" + ) + return tensor_store.tensors[int(tensor_index)] + + +def _reverse_artifact_tensor_for_route_step( + tensor_store: TemporalReverseArtifactTensorStore, + *, + role: str, + physical_step: int, + route_row: int, +) -> torch.Tensor: + role_id = temporal_reverse_artifact_role_id(role) + matches = [ + int(row[1]) + for row in tensor_store.binding_rows.cpu().tolist() + if ( + int(row[0]) == int(role_id) + and int(row[2]) == int(physical_step) + and int(row[3]) == 0 + and int(row[4]) == int(route_row) + ) + ] + if len(matches) != 1: + raise RuntimeError( + "Registered fused forward artifact tensor store has no unique routed 
tensor: " + f"role={role!r}; physical_step={int(physical_step)}; route_row={int(route_row)}; count={len(matches)}" + ) + tensor_index = matches[0] + if tensor_index < 0 or tensor_index >= len(tensor_store.tensors): + raise RuntimeError( + "Registered fused forward artifact tensor store routed row points outside tensor table: " + f"role={role!r}; route_row={int(route_row)}; tensor_index={int(tensor_index)}" + ) + return tensor_store.tensors[int(tensor_index)] + + +def _readout_forward_artifact_route_row_for_output_route( + executor_program: RegisteredTemporalExecutorProgram, + *, + output_route_row: list[int], + artifact_role: str, +) -> int: + role_id = temporal_reverse_artifact_role_id(artifact_role) + readout_surface = temporal_surface_opcode("readout") + matches = [ + int(row[0]) + for row in executor_program.forward_artifact_route_rows.cpu().tolist() + if ( + int(row[1]) == int(readout_surface) + and int(row[2]) == int(output_route_row[3]) + and int(row[3]) == int(output_route_row[4]) + and int(row[4]) == int(output_route_row[5]) + and int(row[5]) == int(role_id) + ) + ] + if len(matches) != 1: + raise RuntimeError( + "Registered output route has no unique readout artifact producer route: " + f"artifact_role={artifact_role!r}; output_route_row={int(output_route_row[0])}; count={len(matches)}" + ) + return matches[0] + + +def reverse_artifact_tensor_store_output_cells_for_step( + tensor_store: TemporalReverseArtifactTensorStore, + executor_program: RegisteredTemporalExecutorProgram, + *, + physical_step: int, +) -> torch.Tensor: + rows = executor_program.forward_output_route_rows.cpu().tolist() + if not rows: + raise RuntimeError("Registered output-cell artifact lookup requires forward output route rows") + route_kind = int(rows[0][1]) + output_cells_by_route: list[torch.Tensor] = [] + for row in rows: + if int(row[1]) != route_kind: + raise RuntimeError("Registered output-cell artifact lookup received mixed output route kinds") + artifact_route_row = _readout_forward_artifact_route_row_for_output_route( + executor_program, + output_route_row=[int(value) for value in row], + artifact_role="output_cells", + ) + output_cells_by_route.append( + _reverse_artifact_tensor_for_route_step( + tensor_store, + role="output_cells", + physical_step=int(physical_step), + route_row=int(artifact_route_row), + ) + ) + if route_kind == temporal_forward_output_route_kind_opcode("readout_output_concat"): + return torch.cat(tuple(output_cells_by_route), dim=1).contiguous() + if route_kind == temporal_forward_output_route_kind_opcode("readout_output_sum"): + output_cells = output_cells_by_route[0].clone() + for candidate in output_cells_by_route[1:]: + if tuple(candidate.shape) != tuple(output_cells.shape): + raise RuntimeError("Registered output-cell sum route has incompatible artifact shapes") + output_cells.add_(candidate) + return output_cells.contiguous() + if len(output_cells_by_route) != 1: + raise RuntimeError("Registered singleton output route resolved multiple output-cell artifacts") + return output_cells_by_route[0] + + +def reverse_artifact_tensor_store_tensor_for_step( + tensor_store: TemporalReverseArtifactTensorStore, + *, + role: str, + physical_step: int, +) -> torch.Tensor: + return _reverse_artifact_tensor_for_step( + tensor_store, + role=role, + physical_step=int(physical_step), + ) + + +def _registered_reverse_program_output_cell_grad_window_from_tensor_store( + runtime: Any, + *, + executor_program: RegisteredTemporalExecutorProgram, + tensor_store: TemporalReverseArtifactTensorStore, + 
window_start: int, + window_end: int, + grad_output_window: torch.Tensor, + output_contract: TemporalOutputContract, +) -> torch.Tensor: + if output_contract == "output_cells": + return grad_output_window + grad_output_cell_steps: list[torch.Tensor] = [] + for local_step, physical_step in enumerate(range(int(window_start), int(window_end))): + output_cells = reverse_artifact_tensor_store_output_cells_for_step( + tensor_store, + executor_program, + physical_step=int(physical_step), + ) + grad_output_cells = _grad_output_cells_for_contract( + runtime, + output_cells, + grad_output_window[:, int(local_step)], + output_contract, + ) + if not torch.is_tensor(grad_output_cells): + raise RuntimeError("Registered reverse program could not materialize output-cell gradient window") + grad_output_cell_steps.append(grad_output_cells.contiguous()) + return torch.stack(tuple(grad_output_cell_steps), dim=1).contiguous() + + +def _reject_registered_reverse_program_window(runtime: Any, reason: str) -> None: + reject = f"registered_reverse_program_window_reject:{reason}" + runtime._last_flat_bucket_temporal_reverse_program_window_reject = reject + runtime._last_flat_bucket_temporal_reverse_engine_reject = reject + + +def _build_transition_boundary_reverse_step_program_tables( + runtime: Any, + executor_program: RegisteredTemporalExecutorProgram, + transition_executors: tuple[RegisteredTemporalExecutorHandle, ...], + *, + forward_transition_executors: tuple[RegisteredTemporalExecutorHandle, ...], + materialize_optional_transition_outputs: bool, + grad_next_backend_state_cache: dict[str, object] | None, + static_tensors: dict[str, object], + reverse_artifact_tensors: tuple[torch.Tensor, ...], +) -> _TransitionBoundaryReverseStepProgramTables: + executor_program.kernel_registry.require_reverse_transitions(transition_executors) + executor_program.kernel_registry.require_forward_transitions(forward_transition_executors) + reference = next( + (tensor for tensor in reverse_artifact_tensors if torch.is_tensor(tensor) and tensor.dim() >= 3), + None, + ) + if reference is None: + raise RuntimeError("Registered transition reverse program requires at least one tensor reverse artifact") + forward_executors_by_bucket: dict[int, tuple[RegisteredTemporalExecutorHandle, ...]] = {} + for executor, _bucket in transition_buckets_for_executors(runtime, forward_transition_executors, static_tensors): + forward_executors_by_bucket.setdefault(int(executor.bucket_ordinal), ()) + forward_executors_by_bucket[int(executor.bucket_ordinal)] = ( + *forward_executors_by_bucket[int(executor.bucket_ordinal)], + executor, + ) + program_tensor_groups: list[tuple[torch.Tensor, ...]] = [] + program_tensor_binding_row_groups: list[torch.Tensor] = [] + forward_executor_row_groups: list[torch.Tensor] = [] + reverse_executor_row_groups: list[torch.Tensor] = [] + forward_executor_binding_row_groups: list[torch.Tensor] = [] + reverse_executor_binding_row_groups: list[torch.Tensor] = [] + memory_liveness_row_groups: list[torch.Tensor] = [] + transition_seed_tensor_groups: list[tuple[torch.Tensor, ...]] = [] + transition_seed_row_groups: list[torch.Tensor] = [] + transition_dynamic_binding_row_groups: list[torch.Tensor] = [] + transition_output_keep_slot_row_groups: list[torch.Tensor] = [] + transition_recurrent_msg_output_rows: list[list[int]] = [] + transition_public_y_seed_rows: list[list[int]] = [] + transition_next_seed_output_rows: list[list[int]] = [] + transition_state_reset_slot_maps: list[dict[str, int]] = [] + transition_group_metadata: list[ 
+ tuple[ + tuple[RegisteredTemporalExecutorHandle, ...], + Any, + dict[str, int], + tuple[TemporalTransitionParamGradBinding, ...], + ] + ] = [] + seen_transition_buckets: set[int] = set() + for reverse_executor, bucket in transition_buckets_for_executors(runtime, transition_executors, static_tensors): + bucket_ordinal = int(reverse_executor.bucket_ordinal) + if bucket_ordinal in seen_transition_buckets: + continue + seen_transition_buckets.add(bucket_ordinal) + forward_executor_group = forward_executors_by_bucket.get(bucket_ordinal) + if not forward_executor_group: + raise RuntimeError( + "Registered reverse transition executor has no matching forward executor for recompute: " + f"bucket={int(reverse_executor.bucket_ordinal)}" + ) + reverse_executor_group = tuple( + executor for executor in transition_executors if int(executor.bucket_ordinal) == bucket_ordinal + ) + population_name = _flat_bucket_name(bucket) + transition_program_tensors, transition_program_tensor_binding_rows, reverse_logical_to_slot = ( + transition_program_slot_table( + executors=(*forward_executor_group, *reverse_executor_group), + reference=reference, + ) + ) + seed_tensors_list: list[torch.Tensor] = [] + seed_rows_list: list[list[int]] = [] + seen_seed_roles: set[int] = set() + for grouped_reverse_executor in reverse_executor_group: + grouped_seed_tensors, grouped_seed_rows = _transition_reverse_seed_tensor_table( + reverse_executor=grouped_reverse_executor, + bucket=bucket, + bucket_ordinal=bucket_ordinal, + population_name=population_name, + grad_recurrent_hidden_backend=None, + grad_next_backend_state_cache=grad_next_backend_state_cache, + ) + for row in grouped_seed_rows.cpu().tolist(): + role_id = int(row[0]) + if role_id in seen_seed_roles: + continue + seen_seed_roles.add(role_id) + seed_rows_list.append([role_id, len(seed_tensors_list), bucket_ordinal]) + seed_tensors_list.append(grouped_seed_tensors[int(row[1])]) + seed_tensors = tuple(seed_tensors_list) + seed_rows = ( + torch.tensor(seed_rows_list, dtype=torch.long) if seed_rows_list else torch.empty((0, 3), dtype=torch.long) + ) + bucket_param_grad_bindings = tuple( + binding + for binding in executor_program.transition_param_grad_bindings + if int(binding.executor_row_index) in {int(executor.row_index) for executor in reverse_executor_group} + ) + grad_message_slot = reverse_logical_to_slot.get("grad_aggregated_message") + if grad_message_slot is None: + raise RuntimeError( + "Registered reverse transition executor did not bind grad_aggregated_message output: " + f"bucket={int(reverse_executor.bucket_ordinal)}" + ) + output_keep_slots = {int(grad_message_slot)} + group_index = len(transition_group_metadata) + for grouped_reverse_executor in reverse_executor_group: + for state_name in _transition_reverse_state_grad_names(grouped_reverse_executor): + state_slot = reverse_logical_to_slot.get(f"grad_{state_name}") + if state_slot is None: + continue + output_keep_slots.add(int(state_slot)) + transition_next_seed_output_rows.append( + [ + int(group_index), + _transition_reverse_seed_role_id(f"grad_next_{state_name}"), + int(state_slot), + bucket_ordinal, + ] + ) + for binding in bucket_param_grad_bindings: + grad_slot = reverse_logical_to_slot.get(str(binding.grad_logical_name)) + if grad_slot is None: + raise RuntimeError( + "Registered reverse transition parameter reducer binding has no compiler output slot: " + f"bucket={int(bucket_ordinal)}; grad={binding.grad_logical_name!r}" + ) + output_keep_slots.add(int(grad_slot)) + 
program_tensor_groups.append(transition_program_tensors) + program_tensor_binding_row_groups.append(transition_program_tensor_binding_rows) + forward_executor_row_groups.append(_transition_executor_rows_for_group(forward_executor_group)) + reverse_executor_row_groups.append(_transition_executor_rows_for_group(reverse_executor_group)) + forward_executor_binding_row_groups.append( + _transition_executor_binding_rows_for_group( + forward_executor_group, + direction_opcode=1, + materialize_optional_outputs=materialize_optional_transition_outputs, + ) + ) + reverse_executor_binding_row_groups.append( + _transition_executor_binding_rows_for_group(reverse_executor_group, direction_opcode=2) + ) + memory_liveness_row_groups.append( + _transition_memory_liveness_rows_for_group( + (*forward_executor_group, *reverse_executor_group), + executor_program.memory_liveness_rows, + ) + ) + transition_seed_tensor_groups.append(seed_tensors) + transition_seed_row_groups.append(seed_rows) + transition_dynamic_binding_row_groups.append( + _transition_dynamic_binding_rows_tensor( + forward_executor=forward_executor_group, + reverse_executor=reverse_executor_group, + ) + ) + transition_output_keep_slot_row_groups.append( + torch.tensor([[int(slot)] for slot in sorted(output_keep_slots)], dtype=torch.long) + ) + transition_recurrent_msg_output_rows.append( + [int(group_index), int(grad_message_slot), int(bucket.backend_start), int(bucket.backend_stop)] + ) + transition_public_y_seed_rows.append( + [ + bucket_ordinal, + int(bucket.backend_start), + int(bucket.backend_stop), + _transition_reverse_seed_role_id("grad_public_y"), + ] + ) + transition_state_reset_slot_maps.append(reverse_logical_to_slot) + transition_group_metadata.append( + (reverse_executor_group, bucket, reverse_logical_to_slot, bucket_param_grad_bindings) + ) + return _TransitionBoundaryReverseStepProgramTables( + program_tensor_groups=tuple(program_tensor_groups), + program_tensor_binding_row_groups=tuple(program_tensor_binding_row_groups), + forward_executor_row_groups=tuple(forward_executor_row_groups), + reverse_executor_row_groups=tuple(reverse_executor_row_groups), + forward_executor_binding_row_groups=tuple(forward_executor_binding_row_groups), + reverse_executor_binding_row_groups=tuple(reverse_executor_binding_row_groups), + memory_liveness_row_groups=tuple(memory_liveness_row_groups), + seed_tensor_groups=tuple(transition_seed_tensor_groups), + seed_row_groups=tuple(transition_seed_row_groups), + dynamic_binding_row_groups=tuple(transition_dynamic_binding_row_groups), + output_keep_slot_row_groups=tuple(transition_output_keep_slot_row_groups), + recurrent_msg_output_rows=torch.tensor(transition_recurrent_msg_output_rows, dtype=torch.long), + public_y_seed_rows=torch.tensor(transition_public_y_seed_rows, dtype=torch.long), + transition_state_reset_rows=temporal_reverse_transition_state_reset_rows_tensor( + group_logical_slots=tuple(transition_state_reset_slot_maps), + ), + next_seed_output_rows=( + torch.tensor(transition_next_seed_output_rows, dtype=torch.long) + if transition_next_seed_output_rows + else torch.empty((0, 4), dtype=torch.long) + ), + group_metadata=tuple(transition_group_metadata), + ) + + +def _consume_transition_boundary_reverse_step_outputs( + runtime: Any, + *, + transition_step_tables: _TransitionBoundaryReverseStepProgramTables, + reverse_output_groups: tuple[tuple[torch.Tensor, ...], ...], +) -> tuple[ + TensorDict, + tuple[TemporalParameterReducerRequest, ...], + dict[str, object] | None, +]: + 
grad_backend_state_cache: dict[str, object] = {} + transition_param_reducer_requests: list[TemporalParameterReducerRequest] = [] + if reverse_output_groups and len(reverse_output_groups) != len(transition_step_tables.group_metadata): + raise RuntimeError( + "Registered fused reverse full step returned mismatched transition output groups: " + f"groups={len(reverse_output_groups)} metadata={len(transition_step_tables.group_metadata)}" + ) + for (_reverse_executors, bucket, reverse_logical_to_slot, bucket_param_grad_bindings), reverse_outputs in zip( + transition_step_tables.group_metadata, + reverse_output_groups, + strict=False, + ): + population_name = _flat_bucket_name(bucket) + state_grad_by_name: dict[str, torch.Tensor | None] = {} + for _reverse_executor in _reverse_executors: + for state_name in _transition_reverse_state_grad_names(_reverse_executor): + grad_state = _transition_tensor_by_logical_optional( + reverse_outputs, + reverse_logical_to_slot, + f"grad_{state_name}", + ) + if grad_state is not None: + state_grad_by_name[state_name] = grad_state.contiguous() + grad_backend_state_cache[population_name] = state_grad_by_name + materialized_grad_accum, static_source_accum = _transition_param_grad_accumulator_from_binding_rows( + bucket_static_tensors=bucket.static_tensors, + reverse_outputs=reverse_outputs, + reverse_logical_to_slot=reverse_logical_to_slot, + transition_param_grad_bindings=bucket_param_grad_bindings, + ) + transition_param_reducer_requests.append( + TemporalTransitionParamReducerRequest( + kind="transition", + population_name=population_name, + materialized_grad_accum=materialized_grad_accum, + static_source_accum=static_source_accum, + transition_param_grad_bindings=bucket_param_grad_bindings, + ) + ) + _record_temporal_backward_glue_cuda(runtime, "registered_fused_backward_program_span_transition_boundary") + runtime._last_flat_bucket_transition_backward_executor = ( + "registered_fused_backward_program_span_transition_boundary" + ) + return ( + TensorDict( + { + name: _partial_backend_grad_state_to_population_state(backend_state) + for name, backend_state in grad_backend_state_cache.items() + if isinstance(backend_state, dict) + }, + batch_size=[], + ), + tuple(transition_param_reducer_requests), + grad_backend_state_cache, + ) + + +def _record_registered_backward_memory_stage( + runtime: Any, + reference: torch.Tensor, + stage: str, +) -> None: + if not torch.is_tensor(reference) or not reference.is_cuda: + return + enabled_fn = getattr(runtime, "_backend_owner_timing_enabled", None) + if not callable(enabled_fn) or not bool(enabled_fn(reference.device)): + return + try: + torch.cuda.synchronize(reference.device) + allocated = int(torch.cuda.memory_allocated(reference.device)) + reserved = int(torch.cuda.memory_reserved(reference.device)) + max_allocated = int(torch.cuda.max_memory_allocated(reference.device)) + except RuntimeError: + return + summary = f"stage={stage};allocated={allocated};reserved={reserved};max_allocated={max_allocated}" + _append_registered_backward_memory_stage_summary(runtime, summary) + + +def _append_registered_backward_memory_stage_summary( + runtime: Any, + summary: str, +) -> None: + previous = tuple(getattr(runtime, "_last_flat_bucket_temporal_registered_backward_memory_stages", ()) or ()) + runtime._last_flat_bucket_temporal_registered_backward_memory_stages = (*previous, summary) + record = getattr(runtime, "_last_backend_execution", None) + if record is None: + return + try: + runtime._last_backend_execution = replace( + record, + 
workspace_aliases=( + *tuple(getattr(record, "workspace_aliases", ()) or ()), + f"flat_bucket_temporal_registered_backward_memory_stage:{summary}", + ), + ) + except TypeError: + return + + +_REGISTERED_BACKWARD_NATIVE_MEMORY_STAGE_NAMES = { + 1: "native_entry", + 2: "native_after_grad_cells_seed", + 3: "native_after_readout", + 4: "native_after_output_message", + 5: "native_after_recurrent_kv", + 6: "native_after_front_outputs", + 7: "native_after_transition", + 8: "native_after_recurrent_msg_buffer", + 9: "native_after_recurrent_message", + 10: "native_after_boundary_kv", + 11: "native_after_initial_recurrent_kv", + 12: "native_after_boundary_outputs", + 13: "native_after_step_return", + 14: "native_after_seed_update", + 15: "native_after_carry_update", + 16: "native_after_stable_append", + 17: "native_return", + 18: "native_after_transition_keep_slots", + 101: "native_transition_group_entry", + 102: "native_transition_group_params_bound", + 103: "native_transition_group_dynamic_bound", + 104: "native_transition_group_after_forward_recompute", + 105: "native_transition_group_after_reverse_primitive", + 201: "native_forward_entry", + 202: "native_forward_after_input_kv", + 203: "native_forward_after_recurrent_kv_before", + 204: "native_forward_after_recurrent_message", + 205: "native_forward_after_transition", + 206: "native_forward_after_recurrent_kv_after", + 207: "native_forward_after_readout_message", + 208: "native_forward_after_readout_projection", + 209: "native_forward_after_output_route", + 210: "native_forward_after_tensor_compaction", + 211: "native_forward_return", + 212: "native_forward_message_after_output_weight", + 213: "native_forward_message_after_weighted_value", + 214: "native_forward_message_after_projected", + 215: "native_forward_message_after_normalize", + 216: "native_forward_message_before_weighted_value_alloc", + 217: "native_forward_message_after_weighted_value_alloc", + 218: "native_forward_message_before_output_weight", + 219: "native_forward_message_before_projected_gemm", + 220: "native_forward_message_after_projected_gemm", + 221: "native_forward_message_after_projected_contiguous", + 222: "native_forward_message_before_normalize", + 223: "native_forward_after_streaming_message_release", +} + + +def _record_registered_backward_native_memory_stage_rows( + runtime: Any, + rows: torch.Tensor | None, +) -> None: + if not torch.is_tensor(rows) or rows.device.type != "cpu" or rows.dtype != torch.long or rows.dim() != 2: + return + if int(rows.shape[1]) != 5: + return + for local_step, stage_id, allocated, reserved, max_allocated in rows.tolist(): + stage_name = _REGISTERED_BACKWARD_NATIVE_MEMORY_STAGE_NAMES.get( + int(stage_id), + f"native_stage_{int(stage_id)}", + ) + if int(local_step) >= 0: + stage_name = f"{stage_name}_local{int(local_step)}" + _append_registered_backward_memory_stage_summary( + runtime, + ( + f"stage={stage_name};" + f"allocated={int(allocated)};" + f"reserved={int(reserved)};" + f"max_allocated={int(max_allocated)}" + ), + ) + + +def _requires_reverse_grad_recurrent_msg_runtime_buffer( + *, + transition_step_tables: _TransitionBoundaryReverseStepProgramTables, + recurrent_count: int, +) -> bool: + rows = transition_step_tables.recurrent_msg_output_rows + if not torch.is_tensor(rows) or rows.device.type != "cpu" or rows.dtype != torch.long or rows.dim() != 2: + return True + if int(rows.shape[1]) != 4 or int(rows.shape[0]) != 1: + return True + if len(transition_step_tables.group_metadata) != 1: + return True + row = [int(item) for item in 
rows[0].tolist()] + return not (row[0] == 0 and row[2] == 0 and row[3] == int(recurrent_count)) + + +def _run_registered_temporal_reverse_program_tensor_table_window( + runtime: Any, + *, + executor_program: RegisteredTemporalExecutorProgram, + window_start: int, + local_time_steps: int, + grad_output_cell_window: torch.Tensor, + grad_carry_cells: torch.Tensor | None, + materialize_grad_carry_cells: bool, + grad_next_backend_state_cache: dict[str, object] | None, + static_tensors: dict[str, object], + trainable_params: tuple[torch.Tensor, ...], + trainable_param_names: tuple[str, ...], + output_contract: TemporalOutputContract, + boundary_requires_grad: bool, + return_window_start_transition_state_grads: bool, + reverse_artifact_roles: tuple[str, ...], + reverse_artifact_tensors: tuple[torch.Tensor, ...], + reverse_artifact_binding_rows: torch.Tensor, + reverse_artifact_role_rows: torch.Tensor, + reverse_artifact_access_rows: torch.Tensor, + transition_executors: tuple[RegisteredTemporalExecutorHandle, ...], + transition_forward_executors: tuple[RegisteredTemporalExecutorHandle, ...], + boundary_step_for_local_index: Callable[[int], torch.Tensor], + cells_prev_for_local_index: Callable[[int], torch.Tensor], + reset_step_for_local_index: Callable[[int], torch.Tensor | None], + transition_reset_step_for_local_index: Callable[[int], torch.Tensor | None], + message_step_index_for_local_index: Callable[[int], int], + runtime_schedule_plan: TemporalMemoryRuntimeSchedulePlan | None, +) -> TemporalBackwardWindowResult: + if not hasattr(runtime, "_last_flat_bucket_temporal_registered_backward_memory_stages"): + runtime._last_flat_bucket_temporal_registered_backward_memory_stages = () + message_executors = executor_program.reverse_surface_handles(surface="message") + readout_executors = executor_program.reverse_surface_handles(surface="readout") + if not message_executors or not readout_executors: + _reject_registered_reverse_program_window( + runtime, + "reverse_surface_reducer_outputs_missing_surface_coverage:" + f"message_spans={len(message_executors)};readout_spans={len(readout_executors)}", + ) + return None + reverse_program_tensor_table = build_reverse_executable_program_tensor_table( + runtime, + executor_program=executor_program, + static_tensors=static_tensors, + reference=boundary_step_for_local_index(0), + ) + runtime._last_flat_bucket_temporal_reverse_executable_program_tensor_table = ( + reverse_program_tensor_table.review_summary + ) + runtime._last_flat_bucket_temporal_fused_backward_program_output_grad = ( + "registered_temporal_fused_backward_program_cuda;" + f"window_start={int(window_start)};window_len={int(local_time_steps)};" + f"artifact_roles={','.join(reverse_artifact_roles)}" + ) + _record_registered_backward_memory_stage( + runtime, + boundary_step_for_local_index(0), + "reverse_tensor_table_built", + ) + _record_temporal_backward_glue_cuda(runtime, "registered_fused_backward_program_span") + + parameter_reducer_requests: list[TemporalParameterReducerRequest] = [] + grad_boundary_steps: list[torch.Tensor | None] = [None for _ in range(int(local_time_steps))] + current_grad_carry_cells: torch.Tensor | None = ( + grad_carry_cells.contiguous() if torch.is_tensor(grad_carry_cells) else None + ) + current_grad_next_backend_state_cache: dict[str, object] | None = grad_next_backend_state_cache + materialize_optional_transition_outputs = _transition_forward_optional_outputs_required( + physical_time_steps=int(local_time_steps), + materialize_final_state=False, + 
grad_carry_cells=grad_carry_cells, + ) or bool(current_grad_next_backend_state_cache) + transition_step_tables = _build_transition_boundary_reverse_step_program_tables( + runtime, + executor_program, + transition_executors, + forward_transition_executors=transition_forward_executors, + materialize_optional_transition_outputs=materialize_optional_transition_outputs, + grad_next_backend_state_cache=current_grad_next_backend_state_cache, + static_tensors=static_tensors, + reverse_artifact_tensors=reverse_artifact_tensors, + ) + reverse_reset_tensor_groups: list[tuple[torch.Tensor, ...]] = [] + reverse_reset_row_groups: list[torch.Tensor] = [] + for local_index in range(int(local_time_steps)): + reverse_reset_tensors, reverse_reset_rows = temporal_reverse_reset_tensor_table( + message_reset_step=reset_step_for_local_index(int(local_index)), + transition_reset_step=transition_reset_step_for_local_index(int(local_index)), + ) + reverse_reset_tensor_groups.append(reverse_reset_tensors) + reverse_reset_row_groups.append(reverse_reset_rows) + message_step_indices = torch.tensor( + [int(message_step_index_for_local_index(int(local_index))) for local_index in range(int(local_time_steps))], + dtype=torch.long, + ) + reference_boundary = boundary_step_for_local_index(0) + forward_executor_binding_rows = _forward_executor_binding_rows_for_runtime( + executor_program, + materialize_optional_transition_outputs=materialize_optional_transition_outputs, + ) + grad_cells_shape = ( + int(reference_boundary.shape[0]), + int(reference_boundary.shape[1]) + + int(runtime.recurrent_cell_idx.numel()) + + int(runtime.output_cell_idx.numel()), + int(reference_boundary.shape[2]), + ) + reverse_grad_recurrent_msg_shape = ( + None + if not _requires_reverse_grad_recurrent_msg_runtime_buffer( + transition_step_tables=transition_step_tables, + recurrent_count=int(runtime.recurrent_cell_idx.numel()), + ) + else ( + int(reference_boundary.shape[0]), + int(runtime.recurrent_cell_idx.numel()), + temporal_message_output_dim(runtime), + ) + ) + runtime_buffer_plan = build_temporal_runtime_buffer_plan( + executor_program.memory_plan, + grad_carry_cells_shape=grad_cells_shape if bool(materialize_grad_carry_cells) else None, + reverse_grad_cells_work_shape=grad_cells_shape, + reverse_message_step_flat_shape=(int(reference_boundary.shape[0]),), + reverse_grad_recurrent_msg_shape=reverse_grad_recurrent_msg_shape, + transition_forward_outputs=_transition_forward_runtime_buffer_requests( + runtime, + executor_program, + batch_size=int(reference_boundary.shape[0]), + materialize_optional_outputs=materialize_optional_transition_outputs, + ), + transition_reverse_dynamic_buffers=_transition_reverse_dynamic_runtime_buffer_requests( + runtime, + executor_program, + batch_size=int(reference_boundary.shape[0]), + local_time_steps=int(local_time_steps), + transition_step_tables=transition_step_tables, + reverse_artifact_binding_rows=reverse_artifact_binding_rows, + ), + runtime_schedule_plan=runtime_schedule_plan, + dtype=str(reference_boundary.dtype), + device=str(reference_boundary.device), + include_workspace_rows=True, + enable_public_state_runtime_alias=False, + ) + _record_registered_backward_memory_stage(runtime, reference_boundary, "runtime_buffer_plan_built") + runtime_buffer_tensors = allocate_temporal_runtime_buffers(reference_boundary, runtime_buffer_plan) + _record_registered_backward_memory_stage(runtime, reference_boundary, "runtime_buffers_allocated") + runtime_buffer_rows = 
temporal_runtime_buffer_rows_tensor(runtime_buffer_plan) + reverse_grad_carry_cells_buffer: torch.Tensor | None = None + if bool(materialize_grad_carry_cells): + reverse_grad_carry_buffer_index = next( + index for index, spec in enumerate(runtime_buffer_plan.specs) if spec.name == "reverse_grad_carry_cells" + ) + reverse_grad_carry_cells_buffer = runtime_buffer_tensors[int(reverse_grad_carry_buffer_index)] + runtime._last_flat_bucket_temporal_memory_runtime_buffer_plan = runtime_buffer_plan.review_summary + runtime._last_flat_bucket_temporal_memory_runtime_buffer_rows = runtime_buffer_rows + if runtime_schedule_plan is None: + raise RuntimeError("Registered fused backward program requires compiler memory runtime schedule rows") + physical_strategy_plan = build_temporal_physical_strategy_plan( + runtime_schedule_plan, + inner_steps=1, + output_boundary="terminal" if int(grad_output_cell_window.shape[1]) == 1 else "sequence", + reset_policy="present" if any(reverse_reset_tensor_groups) else "absent", + ) + physical_strategy_rows = temporal_physical_strategy_rows_tensor(physical_strategy_plan) + runtime._last_flat_bucket_temporal_physical_strategy_plan = physical_strategy_plan.review_summary + runtime._last_flat_bucket_temporal_physical_strategy_rows = physical_strategy_rows + reverse_runtime_plan = build_temporal_reverse_program_runtime_plan( + runtime, + reference_boundary=reference_boundary, + message_step_indices=message_step_indices, + return_boundary_grad=bool(boundary_requires_grad), + use_sparse_messages=bool(getattr(runtime, "_uses_sparse_message_backend", False)), + ) + runtime._last_flat_bucket_temporal_reverse_program_runtime_plan = reverse_runtime_plan.review_summary + runtime._last_flat_bucket_temporal_reverse_program_runtime_rows = reverse_runtime_plan.rows + _record_registered_backward_memory_stage(runtime, reference_boundary, "before_fused_backward_program") + fused_backward_kwargs = dict( + reverse_program_stage_rows=executor_program.reverse_program_stage_rows, + grad_output_window=grad_output_cell_window, + grad_carry_cells=current_grad_carry_cells, + reverse_artifact_tensors=reverse_artifact_tensors, + reverse_artifact_binding_rows=reverse_artifact_binding_rows, + reverse_artifact_role_rows=reverse_artifact_role_rows, + reverse_artifact_access_rows=reverse_artifact_access_rows, + forward_artifact_route_rows=executor_program.forward_artifact_route_rows, + forward_artifact_merge_rows=executor_program.forward_artifact_merge_rows, + forward_output_route_rows=executor_program.forward_output_route_rows, + reverse_artifact_consumer_route_rows=executor_program.reverse_artifact_consumer_route_rows, + reverse_reset_tensor_groups=tuple(reverse_reset_tensor_groups), + reverse_reset_row_groups=tuple(reverse_reset_row_groups), + program_tensors=reverse_program_tensor_table.program_tensors, + program_tensor_binding_rows=reverse_program_tensor_table.program_tensor_binding_rows, + reverse_program_access_rows=reverse_program_tensor_table.reverse_program_access_rows, + primitive_rows=executor_program.primitive_rows, + forward_executor_rows=executor_program.forward_plan.forward_executor_rows, + reverse_executor_rows=executor_program.backward_plan.reverse_executor_rows, + forward_handler_rows=executor_program.forward_handler_rows, + reverse_handler_rows=executor_program.reverse_handler_rows, + native_strategy_rows=executor_program.native_strategy_rows, + native_callable_binding_schema_rows=executor_program.native_callable_binding_schema_rows, + 
native_callable_output_rows=executor_program.native_callable_output_rows, + reverse_span_output_rows=executor_program.reverse_span_output_rows, + transition_reverse_seed_role_rows=executor_program.transition_reverse_seed_role_rows, + transition_primitive_callable_rows=executor_program.transition_primitive_callable_rows, + forward_executor_binding_rows=forward_executor_binding_rows, + reverse_executor_binding_rows=executor_program.backward_plan.executor_binding_rows, + memory_liveness_rows=executor_program.memory_liveness_rows, + memory_runtime_schedule_rows=cast(torch.Tensor, runtime_buffer_plan.runtime_schedule_rows), + physical_strategy_rows=physical_strategy_rows, + runtime_buffer_tensors=runtime_buffer_tensors, + runtime_buffer_rows=runtime_buffer_rows, + reverse_program_runtime_tensors=reverse_runtime_plan.tensors, + reverse_program_runtime_rows=reverse_runtime_plan.rows, + transition_program_tensor_groups=transition_step_tables.program_tensor_groups, + transition_program_tensor_binding_row_groups=transition_step_tables.program_tensor_binding_row_groups, + transition_forward_executor_row_groups=transition_step_tables.forward_executor_row_groups, + transition_reverse_executor_row_groups=transition_step_tables.reverse_executor_row_groups, + transition_forward_executor_binding_row_groups=transition_step_tables.forward_executor_binding_row_groups, + transition_reverse_executor_binding_row_groups=transition_step_tables.reverse_executor_binding_row_groups, + transition_memory_liveness_row_groups=transition_step_tables.memory_liveness_row_groups, + transition_seed_tensor_groups=transition_step_tables.seed_tensor_groups, + transition_seed_row_groups=transition_step_tables.seed_row_groups, + transition_dynamic_binding_row_groups=transition_step_tables.dynamic_binding_row_groups, + transition_output_keep_slot_row_groups=transition_step_tables.output_keep_slot_row_groups, + transition_parameter_tensors=reverse_program_tensor_table.transition_parameter_tensors, + transition_parameter_rows=reverse_program_tensor_table.transition_parameter_rows, + transition_recurrent_msg_output_rows=transition_step_tables.recurrent_msg_output_rows, + transition_public_y_seed_rows=transition_step_tables.public_y_seed_rows, + transition_state_reset_rows=transition_step_tables.transition_state_reset_rows, + transition_next_seed_output_rows=transition_step_tables.next_seed_output_rows, + return_window_start_transition_state_grads=bool(return_window_start_transition_state_grads), + ) + try: + span_output_groups = registered_temporal_fused_backward_program_cuda(**fused_backward_kwargs) + except BaseException: + _record_registered_backward_memory_stage(runtime, reference_boundary, "fused_backward_program_error") + raise + _record_registered_backward_native_memory_stage_rows( + runtime, + registered_temporal_fused_backward_program_stage_memory_rows(), + ) + _record_registered_backward_memory_stage(runtime, reference_boundary, "after_fused_backward_program") + _record_temporal_backward_glue_cuda( + runtime, + "registered_fused_backward_program_local_only_span_output_elision", + ) + transition_group_count = len(transition_step_tables.group_metadata) + readout_span_count = len(readout_executors) + message_span_count = len(message_executors) + returned_readout_span_count = 0 if int(readout_span_count) == 1 else int(readout_span_count) + returned_message_span_count = 0 if int(message_span_count) == 1 else int(message_span_count) + if int(readout_span_count) == 1 or int(message_span_count) == 1: + 
_record_temporal_backward_glue_cuda( + runtime, + "registered_fused_backward_program_single_executor_span_output_elision", + ) + groups_per_step = ( + int(transition_group_count) + 2 + int(returned_readout_span_count) + 2 * int(returned_message_span_count) + ) + expected_output_group_count = int(local_time_steps) * int(groups_per_step) + if len(span_output_groups) != expected_output_group_count: + raise RuntimeError( + "Registered fused backward program returned mismatched output groups: " + f"expected={expected_output_group_count}; actual={len(span_output_groups)}" + ) + + for span_step_index, local_index in enumerate(range(int(local_time_steps) - 1, -1, -1)): + group_offset = int(span_step_index) * int(groups_per_step) + boundary_step = boundary_step_for_local_index(int(local_index)) + front_outputs = span_output_groups[group_offset] + reverse_output_groups = span_output_groups[group_offset + 1 : group_offset + 1 + transition_group_count] + boundary_group = span_output_groups[group_offset + 1 + transition_group_count] + per_span_offset = group_offset + 2 + transition_group_count + readout_front_groups = span_output_groups[per_span_offset : per_span_offset + returned_readout_span_count] + message_front_groups = span_output_groups[ + per_span_offset + returned_readout_span_count : per_span_offset + + returned_readout_span_count + + returned_message_span_count + ] + message_boundary_groups = span_output_groups[ + per_span_offset + returned_readout_span_count + returned_message_span_count : per_span_offset + + returned_readout_span_count + + 2 * returned_message_span_count + ] + if int(readout_span_count) == 1: + readout_front_by_executor = {int(readout_executors[0].row_index): front_outputs} + else: + readout_front_by_executor = { + int(executor.row_index): group + for executor, group in zip(readout_executors, readout_front_groups, strict=True) + } + if int(message_span_count) == 1: + message_front_by_executor = {int(message_executors[0].row_index): front_outputs} + message_boundary_by_executor = {int(message_executors[0].row_index): boundary_group} + else: + message_front_by_executor = { + int(executor.row_index): group + for executor, group in zip(message_executors, message_front_groups, strict=True) + } + message_boundary_by_executor = { + int(executor.row_index): group + for executor, group in zip(message_executors, message_boundary_groups, strict=True) + } + + def routed_output(route_kind: str, target_role: str) -> torch.Tensor: + return _reverse_routed_span_output_tensor( + front_outputs=front_outputs, + boundary_outputs=boundary_group, + output_rows=executor_program.reverse_span_output_rows, + route_rows=executor_program.reverse_output_route_rows, + route_kind=route_kind, + target_role=target_role, + ) + + def reducer_output( + executor: RegisteredTemporalExecutorHandle, + route_kind: str, + target_role: str, + ) -> torch.Tensor: + if executor.surface == "readout": + executor_front_outputs = readout_front_by_executor[int(executor.row_index)] + executor_boundary_outputs = boundary_group + elif executor.surface == "message": + executor_front_outputs = message_front_by_executor[int(executor.row_index)] + executor_boundary_outputs = message_boundary_by_executor[int(executor.row_index)] + else: + executor_front_outputs = front_outputs + executor_boundary_outputs = boundary_group + return _reverse_parameter_reducer_routed_span_output_tensor( + front_outputs=executor_front_outputs, + boundary_outputs=executor_boundary_outputs, + output_rows=executor_program.reverse_span_output_rows, + 
reducer_route_rows=executor_program.reverse_parameter_reducer_route_rows, + executor=executor, + route_kind=route_kind, + target_role=target_role, + ) + + grad_boundary_from_projection_raw = routed_output("transition_boundary", "boundary_projection") + grad_hidden_before = routed_output("carry_grad", "hidden_graph_order_before") + grad_boundary_direct = routed_output("boundary_grad", "direct_boundary") + grad_boundary_from_projection = ( + None if int(grad_boundary_from_projection_raw.numel()) == 0 else grad_boundary_from_projection_raw + ) + for readout_executor in readout_executors: + parameter_reducer_requests.append( + TemporalReadoutOutputParamReducerRequest( + kind="readout_output", + readout_executor=readout_executor, + grad_value_to_output_weight=reducer_output( + readout_executor, + "readout_parameter_grad", + "value_to_output_weight", + ), + grad_output_cell_bias=reducer_output( + readout_executor, + "readout_parameter_grad", + "output_cell_bias", + ), + ) + ) + grad_output_q = reducer_output(readout_executor, "query_parameter_grad", "output_query") + parameter_reducer_requests.append( + TemporalRecurrentQueryParamReducerRequest( + kind="recurrent_query", + message_executor=message_executors[0], + readout_executor=readout_executor, + grad_recurrent_q_backend=None, + grad_output_q=grad_output_q, + device=boundary_step.device, + dtype=boundary_step.dtype, + ) + ) + _record_temporal_backward_glue_cuda( + runtime, + "registered_fused_backward_program_span_readout_message_kv", + ) + ( + grad_population_state, + transition_param_reducer_requests, + grad_backend_state_cache, + ) = _consume_transition_boundary_reverse_step_outputs( + runtime, + transition_step_tables=transition_step_tables, + reverse_output_groups=reverse_output_groups, + ) + parameter_reducer_requests.extend(transition_param_reducer_requests) + _record_temporal_backward_glue_cuda( + runtime, + "registered_fused_backward_program_span_recurrent_message_boundary_initial_kv", + ) + for message_executor in message_executors: + grad_recurrent_q_backend = reducer_output(message_executor, "transition_boundary", "recurrent_query") + grad_input_kv_weight = reducer_output( + message_executor, + "sender_kv_parameter_grad", + "boundary_input_kv_weight", + ) + grouped_flag = reducer_output( + message_executor, + "sender_kv_parameter_grad", + "boundary_input_kv_grouped_flag", + ) + input_kv_grouped = bool(int(grouped_flag.reshape(-1)[0].item())) + grad_initial_recurrent_kv_weight_graph_order = reducer_output( + message_executor, + "sender_kv_parameter_grad", + "initial_recurrent_kv_weight", + ) + grad_recurrent_kv_weight_graph_order = reducer_output( + message_executor, + "sender_kv_parameter_grad", + "recurrent_output_kv_weight", + ) + message_strategy = temporal_executor_strategy_registry().reverse_pattern_for_executor( + surface="message", + executor_name=message_executor.executor_name, + ) + message_boundary_output_patterns = tuple( + output + for output in message_strategy.message_param_grad_outputs + if output.source == "boundary_extra_output" + ) + message_boundary_outputs = { + output.logical_name: reducer_output( + message_executor, + "message_strategy_parameter_grad", + output.logical_name, + ) + for output in message_boundary_output_patterns + } + runtime._last_flat_bucket_temporal_message_strategy_extra_param_grad_roles = tuple( + sorted(message_boundary_outputs) + ) + runtime._last_flat_bucket_temporal_message_strategy_extra_param_grad_slots = ( + 0 + if not message_boundary_output_patterns + else 1 + 
max(int(output.source_index) for output in message_boundary_output_patterns) + ) + message_param_request, message_query_grad_consumed = _message_strategy_parameter_reducer_request( + message_executor, + grad_recurrent_query_backend=grad_recurrent_q_backend, + boundary_outputs_by_logical_name=message_boundary_outputs, + ) + if message_param_request is not None: + parameter_reducer_requests.append(message_param_request) + if not message_query_grad_consumed: + parameter_reducer_requests.append( + TemporalRecurrentQueryParamReducerRequest( + kind="recurrent_query", + message_executor=message_executor, + readout_executor=readout_executors[0], + grad_recurrent_q_backend=grad_recurrent_q_backend, + grad_output_q=None, + device=boundary_step.device, + dtype=boundary_step.dtype, + ) + ) + backend_order_output_raw = TemporalSenderKVProjectionRawParamGrad( + role="recurrent", + grad_weight=grad_recurrent_kv_weight_graph_order, + group_ids=runtime.kv_group_id.index_select( + 0, + runtime.recurrent_cell_idx.to(device=runtime.kv_group_id.device), + ).to(device=grad_recurrent_kv_weight_graph_order.device), + grouped=False, + ) + parameter_reducer_requests.append( + TemporalSenderKVProjectionParamReducerRequest( + kind="sender_kv_projection", + message_executor=message_executor, + raw_grads=(backend_order_output_raw,), + ) + ) + input_group_ids = ( + runtime.input_sender_kv_group_ids.to(device=grad_input_kv_weight.device, dtype=torch.long) + if input_kv_grouped + else runtime.kv_group_id.index_select( + 0, + runtime.input_cell_idx.to(device=runtime.kv_group_id.device), + ).to(device=grad_input_kv_weight.device) + ) + boundary_raw_grad = TemporalSenderKVProjectionRawParamGrad( + role=cast(Any, "input"), + grad_weight=grad_input_kv_weight, + group_ids=input_group_ids, + grouped=bool(input_kv_grouped), + ) + parameter_reducer_requests.append( + TemporalSenderKVProjectionParamReducerRequest( + kind="sender_kv_projection", + message_executor=message_executor, + raw_grads=(boundary_raw_grad,), + ) + ) + backend_order_initial_raw = TemporalSenderKVProjectionRawParamGrad( + role="recurrent", + grad_weight=grad_initial_recurrent_kv_weight_graph_order, + group_ids=runtime.kv_group_id.index_select( + 0, + runtime.recurrent_cell_idx.to(device=runtime.kv_group_id.device), + ).to(device=grad_initial_recurrent_kv_weight_graph_order.device), + grouped=False, + ) + parameter_reducer_requests.append( + TemporalSenderKVProjectionParamReducerRequest( + kind="sender_kv_projection", + message_executor=message_executor, + raw_grads=(backend_order_initial_raw,), + ) + ) + grad_boundary_steps[local_index] = _accumulate_tensor_grad(grad_boundary_direct, grad_boundary_from_projection) + if reverse_grad_carry_cells_buffer is not None: + grad_cells = reverse_grad_carry_cells_buffer + grad_cells.zero_() + if grad_hidden_before is not None and int(grad_cells.shape[1]) > 0: + grad_cells[:, runtime._recurrent_slice, :] = grad_hidden_before.to(dtype=grad_cells.dtype) + current_grad_carry_cells = grad_cells + else: + if int(local_index) != 0: + raise RuntimeError( + "Registered reverse program omitted grad-carry cells buffer before an earlier local step" + ) + current_grad_carry_cells = None + current_grad_next_backend_state_cache = grad_backend_state_cache + del grad_population_state + + _record_registered_backward_memory_stage(runtime, reference_boundary, "after_span_output_consumption") + runtime._last_flat_bucket_temporal_reverse_scan_owner = "registered_reverse_program_window" + 
runtime._last_flat_bucket_temporal_reverse_scan_binding_abi = "registered_executor_binding_rows" + runtime._last_flat_bucket_temporal_backward_binding_abi = "registered_executor_binding_rows" + runtime._last_flat_bucket_temporal_reverse_program_window = ( + f"registered_reverse_program_window;window_start={int(window_start)};" + f"window_len={int(local_time_steps)};output_contract={output_contract}" + ) + _record_temporal_reverse_scan_owner( + runtime, + "registered_reverse_program_window", + binding_abi="registered_executor_binding_rows", + ) + runtime._last_flat_bucket_temporal_reverse_program_window_reject = "" + runtime._last_flat_bucket_temporal_reverse_engine_reject = "" + parameter_reducer_program = build_temporal_parameter_reducer_program( + requests=tuple(parameter_reducer_requests), + reverse_program_stage_rows=executor_program.reverse_program_stage_rows, + trainable_param_names=trainable_param_names, + ) + _record_registered_backward_memory_stage(runtime, reference_boundary, "before_parameter_reducer") + parameter_reducer_param_grads = run_temporal_parameter_reducer_program( + runtime, + program=parameter_reducer_program, + static_tensors=static_tensors, + trainable_params=trainable_params, + trainable_param_names=trainable_param_names, + ) + _record_registered_backward_memory_stage(runtime, reference_boundary, "after_parameter_reducer") + return TemporalBackwardWindowResult( + grad_boundary_steps=tuple(grad_boundary_steps), + grad_carry_cells=cast(torch.Tensor | None, current_grad_carry_cells), + grad_carry_recurrent_hidden_backend=None, + grad_next_population_state=_population_grad_dict( + TensorDict({}, batch_size=[]) + if current_grad_next_backend_state_cache is None + else TensorDict( + { + name: _partial_backend_grad_state_to_population_state(cache) + for name, cache in current_grad_next_backend_state_cache.items() + if isinstance(cache, dict) + }, + batch_size=[], + ) + ), + grad_next_backend_state_cache=current_grad_next_backend_state_cache, + param_grads=parameter_reducer_param_grads, + deferred_grad_recurrent_q_backend=None, + deferred_grad_recurrent_kv_weight_backend=None, + deferred_transition_param_accum=cast(_TransitionParamGradAccumulator, {}), + ) + + +def _try_run_registered_temporal_reverse_program_tensor_store_window( + runtime: Any, + *, + executor_program: RegisteredTemporalExecutorProgram, + tensor_store: TemporalReverseArtifactTensorStore, + window_start: int, + window_end: int, + grad_output_window: torch.Tensor | None, + grad_carry_cells: torch.Tensor | None, + materialize_grad_carry_cells: bool, + grad_next_backend_state_cache: dict[str, object] | None, + static_tensors: dict[str, object], + trainable_params: tuple[torch.Tensor, ...], + trainable_param_names: tuple[str, ...], + output_contract: TemporalOutputContract, + boundary_requires_grad: bool, + return_window_start_transition_state_grads: bool, + reverse_artifact_roles: tuple[str, ...], + population_resets: torch.Tensor | None, + transition_resets: torch.Tensor | None, + inner_steps: int, + runtime_schedule_plan: TemporalMemoryRuntimeSchedulePlan | None, +) -> TemporalBackwardWindowResult | None: + local_time_steps = int(window_end) - int(window_start) + readout_pool = str(runtime.config.readout.pool) + reference_boundary = ( + _reverse_artifact_tensor_for_step( + tensor_store, + role="boundary_step", + physical_step=int(window_start), + ) + if int(local_time_steps) > 0 and bool(reverse_artifact_roles) + else grad_output_window + if torch.is_tensor(grad_output_window) + else torch.empty(0) + ) + 
runtime_support_plan = build_temporal_reverse_program_runtime_support_plan( + reference_boundary=reference_boundary, + grad_output_window=grad_output_window, + grad_carry_cells=grad_carry_cells, + materialize_grad_carry_cells=bool(materialize_grad_carry_cells), + local_time_steps=int(local_time_steps), + output_contract=output_contract, + readout_pool=readout_pool, + reverse_artifact_roles=reverse_artifact_roles, + ) + runtime._last_flat_bucket_temporal_reverse_program_runtime_support = runtime_support_plan.review_summary + runtime._last_flat_bucket_temporal_reverse_program_runtime_support_rows = runtime_support_plan.rows + _record_registered_backward_memory_stage(runtime, reference_boundary, "reverse_runtime_support_built") + if runtime_support_plan.rejection_reason is not None: + _reject_registered_reverse_program_window(runtime, runtime_support_plan.rejection_reason) + return None + grad_output_window = cast(torch.Tensor, grad_output_window) + + transition_executors = tuple( + handle + for bucket_index in executor_program.transition_bucket_ordinals() + for handle in executor_program.reverse_handles_for(surface="transition", bucket_ordinal=int(bucket_index)) + ) + transition_forward_executors = tuple( + handle + for bucket_index in executor_program.transition_bucket_ordinals() + for handle in executor_program.forward_handles_for(surface="transition", bucket_ordinal=int(bucket_index)) + ) + for physical_step in range(int(window_start), int(window_end)): + for role in reverse_artifact_roles: + if not temporal_reverse_artifact_role_is_tensor(role): + continue + if role not in {"boundary_step", "cells_prev"}: + continue + _reverse_artifact_tensor_for_step(tensor_store, role=role, physical_step=int(physical_step)) + boundary_step = _reverse_artifact_tensor_for_step( + tensor_store, + role="boundary_step", + physical_step=int(physical_step), + ) + if boundary_step.device.type != "cuda" or boundary_step.dtype != torch.float32: + _reject_registered_reverse_program_window(runtime, "unsupported_boundary_device_or_dtype") + return None + for output_route_row in executor_program.forward_output_route_rows.cpu().tolist(): + output_msg_route_row = _readout_forward_artifact_route_row_for_output_route( + executor_program, + output_route_row=[int(value) for value in output_route_row], + artifact_role="output_msg", + ) + output_msg = _reverse_artifact_tensor_for_route_step( + tensor_store, + role="output_msg", + physical_step=int(physical_step), + route_row=int(output_msg_route_row), + ) + if int(output_msg.numel()) == 0: + _reject_registered_reverse_program_window( + runtime, + f"missing_routed_output_message:physical_step={int(physical_step)};" + f"route_row={int(output_msg_route_row)}", + ) + return None + _record_registered_backward_memory_stage(runtime, reference_boundary, "reverse_artifact_preflight_done") + + reverse_artifact_tensors, reverse_artifact_binding_rows = _reverse_artifact_tensor_store_window_table( + tensor_store, + window_start=int(window_start), + window_end=int(window_end), + ) + _record_registered_backward_memory_stage(runtime, reference_boundary, "reverse_artifact_window_table_built") + grad_output_cell_window = _registered_reverse_program_output_cell_grad_window_from_tensor_store( + runtime, + executor_program=executor_program, + tensor_store=tensor_store, + window_start=int(window_start), + window_end=int(window_end), + grad_output_window=grad_output_window, + output_contract=output_contract, + ) + _record_registered_backward_memory_stage(runtime, grad_output_cell_window, 
"output_grad_cell_window_built") + + def physical_step_for_local_index(local_index: int) -> int: + return int(window_start) + int(local_index) + + def scan_step_for_local_index(local_index: int) -> Any: + return scalar_temporal_scan_step( + physical_step=physical_step_for_local_index(int(local_index)), + inner_steps=int(inner_steps), + ) + + def reset_step_for_local_index(local_index: int) -> torch.Tensor | None: + scan_step = scan_step_for_local_index(int(local_index)) + if torch.is_tensor(population_resets) and scan_step.apply_boundary_reset: + return population_resets[:, int(scan_step.outer_step)] + return None + + def transition_reset_step_for_local_index(local_index: int) -> torch.Tensor | None: + scan_step = scan_step_for_local_index(int(local_index)) + if torch.is_tensor(transition_resets) and scan_step.apply_transition_reset: + return transition_resets[:, int(scan_step.outer_step)] + if torch.is_tensor(population_resets) and scan_step.apply_transition_reset: + return population_resets[:, int(scan_step.outer_step)] + return None + + return _run_registered_temporal_reverse_program_tensor_table_window( + runtime, + executor_program=executor_program, + window_start=int(window_start), + local_time_steps=int(local_time_steps), + grad_output_cell_window=grad_output_cell_window, + grad_carry_cells=grad_carry_cells, + materialize_grad_carry_cells=bool(materialize_grad_carry_cells), + grad_next_backend_state_cache=grad_next_backend_state_cache, + static_tensors=static_tensors, + trainable_params=trainable_params, + trainable_param_names=trainable_param_names, + output_contract=output_contract, + boundary_requires_grad=boundary_requires_grad, + return_window_start_transition_state_grads=bool(return_window_start_transition_state_grads), + reverse_artifact_roles=reverse_artifact_roles, + reverse_artifact_tensors=reverse_artifact_tensors, + reverse_artifact_binding_rows=reverse_artifact_binding_rows, + reverse_artifact_role_rows=tensor_store.role_rows, + reverse_artifact_access_rows=tensor_store.access_rows, + transition_executors=transition_executors, + transition_forward_executors=transition_forward_executors, + boundary_step_for_local_index=lambda local_index: _reverse_artifact_tensor_for_step( + tensor_store, + role="boundary_step", + physical_step=physical_step_for_local_index(int(local_index)), + ), + cells_prev_for_local_index=lambda local_index: _reverse_artifact_tensor_for_step( + tensor_store, + role="cells_prev", + physical_step=physical_step_for_local_index(int(local_index)), + ), + reset_step_for_local_index=reset_step_for_local_index, + transition_reset_step_for_local_index=transition_reset_step_for_local_index, + message_step_index_for_local_index=lambda local_index: ( + int(scan_step_for_local_index(int(local_index)).inner_step) + 1 + ), + runtime_schedule_plan=runtime_schedule_plan, + ) + + +def run_registered_temporal_reverse_executor_tensor_store_window( + runtime: Any, + *, + executor_program: RegisteredTemporalExecutorProgram, + tensor_store: TemporalReverseArtifactTensorStore, + window_start: int, + window_end: int, + grad_output_window: torch.Tensor | None, + grad_carry_cells: torch.Tensor | None, + materialize_grad_carry_cells: bool, + grad_next_backend_state_cache: dict[str, object] | None, + static_tensors: dict[str, object], + trainable_params: tuple[torch.Tensor, ...], + trainable_param_names: tuple[str, ...], + output_contract: TemporalOutputContract, + boundary_requires_grad: bool, + return_window_start_transition_state_grads: bool, + reverse_artifact_roles: 
tuple[str, ...], + population_resets: torch.Tensor | None, + transition_resets: torch.Tensor | None, + inner_steps: int, + runtime_schedule_plan: TemporalMemoryRuntimeSchedulePlan | None = None, +) -> TemporalBackwardWindowResult: + program_window_result = _try_run_registered_temporal_reverse_program_tensor_store_window( + runtime, + executor_program=executor_program, + tensor_store=tensor_store, + window_start=int(window_start), + window_end=int(window_end), + grad_output_window=grad_output_window, + grad_carry_cells=grad_carry_cells, + materialize_grad_carry_cells=bool(materialize_grad_carry_cells), + grad_next_backend_state_cache=grad_next_backend_state_cache, + static_tensors=static_tensors, + trainable_params=trainable_params, + trainable_param_names=trainable_param_names, + output_contract=output_contract, + boundary_requires_grad=boundary_requires_grad, + return_window_start_transition_state_grads=bool(return_window_start_transition_state_grads), + reverse_artifact_roles=reverse_artifact_roles, + population_resets=population_resets, + transition_resets=transition_resets, + inner_steps=int(inner_steps), + runtime_schedule_plan=runtime_schedule_plan, + ) + if program_window_result is not None: + return program_window_result + reject = str( + getattr( + runtime, + "_last_flat_bucket_temporal_reverse_program_window_reject", + "registered_reverse_program_window_reject:unknown", + ) + ) + runtime._last_flat_bucket_temporal_reverse_scan_owner = "registered_reverse_program_window_rejected" + runtime._last_flat_bucket_temporal_reverse_scan_binding_abi = "registered_executor_binding_rows" + runtime._last_flat_bucket_temporal_backward_binding_abi = "registered_executor_binding_rows" + raise RuntimeError( + "Registered temporal reverse tensor-store windows must run through the compiler-owned fused reverse program. 
" + f"{reject}" + ) + + +__all__ = [ + "RegisteredTemporalExecutorHandle", + "RegisteredTemporalExecutorProgram", + "build_registered_temporal_executor_program", + "reverse_artifact_tensor_store_output_cells_for_step", + "reverse_artifact_tensor_store_tensor_for_step", + "run_registered_temporal_forward_executor_scan", + "run_registered_temporal_reverse_executor_tensor_store_window", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py new file mode 100644 index 00000000..a55822a6 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py @@ -0,0 +1,689 @@ +from __future__ import annotations + +from typing import Any, Literal, cast + +import torch +from tensordict import TensorDict, TensorDictBase + +from cortical.fabric.backend.cuda.sequence_surface.flat_bucket.flat_buckets import ( + _partial_backend_grad_state_to_population_state, + _population_grad_state_to_backend_grad_state, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.backward_plan import ( + build_temporal_backward_executable_plan, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.memory_plan import ( + TemporalMemoryRuntimeSchedulePlan, + allocate_temporal_runtime_buffer, + build_temporal_memory_liveness_plan, + build_temporal_memory_runtime_artifact_plan, + build_temporal_physical_strategy_plan, + build_temporal_runtime_buffer_plan, + temporal_memory_runtime_schedule_rows_tensor, + temporal_physical_strategy_rows_tensor, + temporal_runtime_buffer_spec, +) +from cortical.fabric.backend.cuda.sequence_surface.runtime.support import ( + _accumulate_owned_tensor_grad, + _accumulate_tensor_grad, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.scan_schedule import ( + scalar_temporal_scan_step, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.types import ( + TemporalArtifactStore, + TemporalBackwardWindowResult, + TemporalOutputContract, + TemporalPhysicalBackwardScanResult, + TemporalReverseArtifactTensorStore, +) + +from cortical.fabric.backend.cuda.sequence_surface.temporal.common import ( + _flat_bucket_temporal_table_plan, + _materialize_backend_recurrent_hidden_grad_to_cells, + _validate_temporal_reverse_scan_claim, +) + +from cortical.fabric.backend.cuda.sequence_surface.temporal.output_backward import ( + _temporal_output_grad_for_physical_step, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.scheduler import ( + TemporalRuntimeSchedulerPlan, + build_temporal_runtime_scheduler_plan, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.registered_executors import ( + RegisteredTemporalExecutorProgram, + _record_registered_backward_memory_stage, + build_registered_temporal_executor_program, + reverse_artifact_tensor_store_output_cells_for_step, + reverse_artifact_tensor_store_tensor_for_step, + run_registered_temporal_reverse_executor_tensor_store_window, +) + + +def _population_grad_dict(grad_state: TensorDict) -> dict[str, dict[str, torch.Tensor | None]]: + population_grads: dict[str, dict[str, torch.Tensor | None]] = {} + for population_name, population_grad in grad_state.items(): + if not isinstance(population_grad, TensorDictBase): + continue + population_grads[population_name] = { + key: cast(torch.Tensor | None, value) if torch.is_tensor(value) else None + for key, value in population_grad.items() + } + return population_grads + + +def 
_population_grad_dict_from_backend_cache( + grad_backend_state_cache: dict[str, object] | None, +) -> dict[str, dict[str, torch.Tensor | None]]: + if grad_backend_state_cache is None: + return {} + population_grads: dict[str, dict[str, torch.Tensor | None]] = {} + for population_name, backend_state in grad_backend_state_cache.items(): + if not isinstance(backend_state, dict): + continue + population_grad = _partial_backend_grad_state_to_population_state(backend_state) + if not isinstance(population_grad, TensorDictBase): + continue + population_grads[population_name] = { + key: cast(torch.Tensor | None, value) if torch.is_tensor(value) else None + for key, value in population_grad.items() + } + return population_grads + + +def _zero_temporal_output_grad_from_tensor_store( + runtime: Any, + *, + executor_program: RegisteredTemporalExecutorProgram, + tensor_store: TemporalReverseArtifactTensorStore, + physical_step: int, + output_contract: TemporalOutputContract, +) -> torch.Tensor: + if output_contract == "full_cells": + cells_prev = reverse_artifact_tensor_store_tensor_for_step( + tensor_store, + role="cells_prev", + physical_step=int(physical_step), + ) + return cells_prev.new_zeros(cells_prev.shape) + output_cells = reverse_artifact_tensor_store_output_cells_for_step( + tensor_store, + executor_program, + physical_step=int(physical_step), + ) + if output_contract == "output_cells": + return output_cells.new_zeros(output_cells.shape) + if output_contract == "pooled_output_cells": + pooled = runtime._pool_output_ports(output_cells.unsqueeze(1)).squeeze(1) + return pooled.new_zeros(pooled.shape) + raise RuntimeError(f"Unsupported temporal output contract {output_contract!r}") + + +def _temporal_output_grad_tensor_store_window( + runtime: Any, + tensor_store: TemporalReverseArtifactTensorStore, + *, + executor_program: RegisteredTemporalExecutorProgram, + grad_output_seq: torch.Tensor | None, + window_start: int, + window_len: int, + outer_time_steps: int, + inner_steps: int, + output_contract: TemporalOutputContract, + output_emissions: Any | None = None, +) -> torch.Tensor | None: + grad_steps = [] + for local_step_index in range(int(window_len)): + physical_step = int(window_start) + int(local_step_index) + grad_step = _temporal_output_grad_for_physical_step( + grad_output_seq, + global_step_index=int(physical_step), + outer_time_steps=int(outer_time_steps), + inner_steps=int(inner_steps), + output_emissions=output_emissions, + ) + if grad_step is None: + grad_step = _zero_temporal_output_grad_from_tensor_store( + runtime, + executor_program=executor_program, + tensor_store=tensor_store, + physical_step=int(physical_step), + output_contract=output_contract, + ) + grad_steps.append(grad_step) + return torch.stack(grad_steps, dim=1) + + +class TemporalPhysicalBackwardScanExecutor: + def __init__( + self, + runtime: Any, + *, + static_tensors: dict[str, object], + trainable_params: tuple[torch.Tensor, ...], + trainable_param_names: tuple[str, ...], + output_contract: TemporalOutputContract, + output_boundary: Literal["sequence", "terminal"], + materialize_final_state: bool, + boundary_requires_grad: bool, + state_requires_grad: bool, + inner_steps: int, + temporal_plan: Any | None, + ) -> None: + self.runtime = runtime + self.static_tensors = static_tensors + self.trainable_params = trainable_params + self.trainable_param_names = trainable_param_names + self.output_contract = output_contract + self.output_boundary = output_boundary + self.materialize_final_state = 
bool(materialize_final_state) + self.boundary_requires_grad = bool(boundary_requires_grad) + self.state_requires_grad = bool(state_requires_grad) + self.inner_steps = max(1, int(inner_steps)) + self.temporal_plan = temporal_plan + self._artifact_primitive_table_fingerprint: tuple[str, ...] = () + + def _materialize_grad_carry_cells_for_window( + self, + *, + window_start: int, + window_end: int, + grad_carry_cells: torch.Tensor | None, + ) -> bool: + local_time_steps = int(window_end) - int(window_start) + return ( + bool(self.state_requires_grad) + or torch.is_tensor(grad_carry_cells) + or int(local_time_steps) > 1 + or int(window_start) > 0 + ) + + def _require_registered_reverse_executor_for_table(self, table_plan: Any) -> RegisteredTemporalExecutorProgram: + if self._artifact_primitive_table_fingerprint and ( + tuple(self._artifact_primitive_table_fingerprint) != tuple(table_plan.fingerprint) + ): + self.runtime._last_flat_bucket_temporal_reverse_engine_reject = ( + "registered_reverse_executor_bindings_required;" + "backward_status=artifact_table_fingerprint_mismatch;" + "registered_program_required=1" + ) + raise RuntimeError(self.runtime._last_flat_bucket_temporal_reverse_engine_reject) + backward_executable_plan = build_temporal_backward_executable_plan(table_plan) + if backward_executable_plan.strategy_legality_status != "legal": + self.runtime._last_flat_bucket_temporal_reverse_engine_reject = ( + "registered_reverse_executor_bindings_required;" + f"backward_status={backward_executable_plan.strategy_legality_status};" + "registered_program_required=1" + ) + raise RuntimeError(self.runtime._last_flat_bucket_temporal_reverse_engine_reject) + self.runtime._last_flat_bucket_temporal_reverse_engine_reject = "" + return build_registered_temporal_executor_program( + self.runtime, + self.static_tensors, + table_plan=table_plan, + backward_plan=backward_executable_plan, + ) + + def _run_backward_tensor_store_window( + self, + *, + tensor_store: TemporalReverseArtifactTensorStore, + window_start: int, + window_end: int, + outer_time_steps: int, + population_resets: torch.Tensor | None, + transition_resets: torch.Tensor | None, + grad_output_seq: torch.Tensor | None, + grad_carry_cells: torch.Tensor | None, + grad_carry_recurrent_hidden_backend: torch.Tensor | None, + grad_next_population_state: dict[str, dict[str, torch.Tensor | None]], + grad_next_backend_state_cache: dict[str, object] | None, + scheduler_plan: TemporalRuntimeSchedulerPlan, + reverse_artifact_roles: tuple[str, ...], + runtime_schedule_plan: TemporalMemoryRuntimeSchedulePlan, + ) -> TemporalBackwardWindowResult: + del grad_carry_recurrent_hidden_backend + del grad_next_population_state + local_time_steps = int(window_end) - int(window_start) + if local_time_steps <= 0: + raise RuntimeError("Temporal backward tensor-store window is empty") + stage_reference = ( + tensor_store.tensors[0] + if tensor_store.tensors + else grad_output_seq + if torch.is_tensor(grad_output_seq) + else torch.empty(0) + ) + _record_registered_backward_memory_stage(self.runtime, stage_reference, "tensor_store_window_entry") + table_plan = _flat_bucket_temporal_table_plan(self.runtime, self.static_tensors) + executor_program = self._require_registered_reverse_executor_for_table(table_plan) + _record_registered_backward_memory_stage(self.runtime, stage_reference, "executor_program_resolved") + grad_output_window = _temporal_output_grad_tensor_store_window( + self.runtime, + tensor_store, + executor_program=executor_program, + 
grad_output_seq=grad_output_seq, + window_start=int(window_start), + window_len=local_time_steps, + outer_time_steps=outer_time_steps, + inner_steps=self.inner_steps, + output_contract=self.output_contract, + output_emissions=scheduler_plan.output_emissions, + ) + _record_registered_backward_memory_stage( + self.runtime, + grad_output_window if torch.is_tensor(grad_output_window) else stage_reference, + "output_grad_window_materialized", + ) + window_result = run_registered_temporal_reverse_executor_tensor_store_window( + self.runtime, + executor_program=executor_program, + tensor_store=tensor_store, + window_start=int(window_start), + window_end=int(window_end), + grad_output_window=grad_output_window, + grad_carry_cells=grad_carry_cells, + materialize_grad_carry_cells=self._materialize_grad_carry_cells_for_window( + window_start=int(window_start), + window_end=int(window_end), + grad_carry_cells=grad_carry_cells, + ), + grad_next_backend_state_cache=grad_next_backend_state_cache, + static_tensors=self.static_tensors, + trainable_params=self.trainable_params, + trainable_param_names=self.trainable_param_names, + output_contract=self.output_contract, + boundary_requires_grad=self.boundary_requires_grad, + return_window_start_transition_state_grads=bool(self.state_requires_grad) or int(window_start) > 0, + reverse_artifact_roles=reverse_artifact_roles, + population_resets=population_resets, + transition_resets=transition_resets, + inner_steps=self.inner_steps, + runtime_schedule_plan=runtime_schedule_plan, + ) + _validate_temporal_reverse_scan_claim(self.runtime) + return window_result + + def run( + self, + *, + boundary_seq: torch.Tensor, + artifact_store: TemporalArtifactStore, + population_resets: torch.Tensor | None, + transition_resets: torch.Tensor | None = None, + grad_output_seq: torch.Tensor | None, + grad_final_state: TensorDict, + ) -> TemporalPhysicalBackwardScanResult: + outer_time_steps = int(boundary_seq.shape[1]) + time_steps = outer_time_steps * self.inner_steps + if not hasattr(self.runtime, "_last_flat_bucket_temporal_registered_backward_memory_stages"): + self.runtime._last_flat_bucket_temporal_registered_backward_memory_stages = () + _record_registered_backward_memory_stage(self.runtime, boundary_seq, "physical_backward_run_entry") + self._artifact_primitive_table_fingerprint = tuple(artifact_store.primitive_table_fingerprint) + scheduler_plan = build_temporal_runtime_scheduler_plan( + temporal_plan=self.temporal_plan, + outer_time_steps=outer_time_steps, + inner_steps=self.inner_steps, + output_boundary=self.output_boundary, + output_contract=self.output_contract, + materialize_final_state=self.materialize_final_state, + collect_artifacts=True, + ) + _record_registered_backward_memory_stage(self.runtime, boundary_seq, "scheduler_plan_built") + table_plan = ( + artifact_store.primitive_table_plan + if artifact_store.primitive_table_plan is not None + else _flat_bucket_temporal_table_plan(self.runtime, self.static_tensors) + ) + memory_plan = build_temporal_memory_liveness_plan(table_plan) + _record_registered_backward_memory_stage(self.runtime, boundary_seq, "memory_liveness_plan_built") + memory_artifact_plan = build_temporal_memory_runtime_artifact_plan( + memory_plan, + physical_time_steps=time_steps, + collect_artifacts=True, + scheduler_plan=scheduler_plan, + ) + physical_strategy_plan = build_temporal_physical_strategy_plan( + memory_artifact_plan.runtime_schedule_plan, + inner_steps=int(self.inner_steps), + output_boundary="terminal" if self.output_boundary == 
"terminal" else "sequence", + reset_policy=( + "present" if torch.is_tensor(population_resets) or torch.is_tensor(transition_resets) else "absent" + ), + ) + physical_strategy_rows = temporal_physical_strategy_rows_tensor(physical_strategy_plan) + _record_registered_backward_memory_stage(self.runtime, boundary_seq, "memory_artifact_plan_built") + _require_compiler_memory_artifact_plan( + artifact_store, + memory_plan_fingerprint=memory_plan.fingerprint, + memory_runtime_artifact_fingerprint=memory_artifact_plan.fingerprint, + memory_runtime_policy_fingerprint=memory_artifact_plan.runtime_policy.review_summary, + memory_runtime_schedule_fingerprint=memory_artifact_plan.runtime_schedule_plan.fingerprint, + memory_runtime_schedule_rows=temporal_memory_runtime_schedule_rows_tensor( + memory_artifact_plan.runtime_schedule_plan + ), + expected_mode=memory_artifact_plan.mode, + expected_checkpoint_stride=memory_artifact_plan.checkpoint_stride, + expected_recompute_window_len=memory_artifact_plan.recompute_window_len, + expected_checkpoint_steps=memory_artifact_plan.checkpoint_steps, + expected_backward_windows=memory_artifact_plan.backward_windows, + ) + _require_compiler_physical_strategy_plan( + artifact_store, + physical_strategy_fingerprint=physical_strategy_plan.fingerprint, + physical_strategy_rows=physical_strategy_rows, + ) + self.runtime._last_flat_bucket_temporal_scheduler_plan = scheduler_plan.review_summary + self.runtime._last_flat_bucket_temporal_memory_reverse_validation = ( + "memory_artifact_plan=validated_against_compiler_liveness_fingerprint", + *memory_artifact_plan.review_summary, + *physical_strategy_plan.review_summary, + ) + if artifact_store.reverse_artifact_tensor_store is None: + raise RuntimeError( + "Registered temporal backward requires compiler-owned reverse artifact tensor-store rows; " + "step-object artifact windows are not an active CUDA training path" + ) + grad_carry_cells = grad_final_state.get("cells") + grad_carry_recurrent_hidden_backend: torch.Tensor | None = None + grad_carry_template_cells: torch.Tensor | None = None + grad_next_population_state = _population_grad_dict(grad_final_state) + grad_next_backend_state_cache: dict[str, object] | None = None + if grad_next_population_state: + grad_next_backend_state_cache = { + name: _population_grad_state_to_backend_grad_state(self.runtime, name, population_grad) + for name, population_grad in grad_next_population_state.items() + } + runtime_buffer_plan = build_temporal_runtime_buffer_plan( + memory_plan, + grad_boundary_seq_shape=tuple(int(item) for item in boundary_seq.shape), + runtime_schedule_plan=memory_artifact_plan.runtime_schedule_plan, + dtype=str(boundary_seq.dtype), + device=str(boundary_seq.device), + ) + _record_registered_backward_memory_stage(self.runtime, boundary_seq, "grad_boundary_buffer_plan_built") + self.runtime._last_flat_bucket_temporal_memory_runtime_buffer_plan = runtime_buffer_plan.review_summary + grad_boundary_seq = allocate_temporal_runtime_buffer( + boundary_seq, + temporal_runtime_buffer_spec(runtime_buffer_plan, name="grad_boundary_seq"), + ) + _record_registered_backward_memory_stage(self.runtime, boundary_seq, "grad_boundary_buffer_allocated") + grad_param_accum: list[torch.Tensor | None] = [None] * len(self.trainable_params) + artifact_windows = _require_compiler_planned_artifact_windows( + artifact_store, + time_steps=time_steps, + ) + + def consume_window_result(window_start: int, window_result: TemporalBackwardWindowResult) -> None: + nonlocal grad_carry_cells + nonlocal 
grad_carry_recurrent_hidden_backend + nonlocal grad_next_population_state + nonlocal grad_next_backend_state_cache + if ( + torch.is_tensor(window_result.deferred_grad_recurrent_q_backend) + or torch.is_tensor(window_result.deferred_grad_recurrent_kv_weight_backend) + or window_result.deferred_transition_param_accum + ): + raise RuntimeError( + "Registered temporal reverse executor returned deferred parameter reductions outside " + "the compiler-owned reducer program" + ) + for local_step_index, grad_boundary_step in enumerate(window_result.grad_boundary_steps): + if grad_boundary_step is not None: + scan_step = scalar_temporal_scan_step( + physical_step=int(window_start) + int(local_step_index), + inner_steps=self.inner_steps, + ) + grad_boundary_seq[:, scan_step.outer_step] = _accumulate_tensor_grad( + grad_boundary_seq[:, scan_step.outer_step], + grad_boundary_step.to(dtype=grad_boundary_seq.dtype), + ) + grad_carry_cells = window_result.grad_carry_cells + grad_carry_recurrent_hidden_backend = window_result.grad_carry_recurrent_hidden_backend + grad_next_population_state = window_result.grad_next_population_state + grad_next_backend_state_cache = window_result.grad_next_backend_state_cache + for parameter_index, grad_param in enumerate(window_result.param_grads): + grad_param_accum[parameter_index] = _accumulate_owned_tensor_grad( + grad_param_accum[parameter_index], + grad_param, + ) + + self.runtime._begin_backend_owner_timing(boundary_seq.device) + artifact_windows_reversed = tuple(reversed(artifact_windows)) + with torch.profiler.record_function("fabric.backward.physical_temporal_bucket_sequence"): + window_cursor = 0 + while window_cursor < len(artifact_windows_reversed): + window_start, window_end = artifact_windows_reversed[window_cursor] + _require_compiler_scheduled_recompute_window( + artifact_store, + window_start=int(window_start), + window_end=int(window_end), + ) + self.runtime._last_flat_bucket_temporal_artifact_recompute_owner = ( + "registered_fused_forward_program_tensor_store_direct" + ) + _record_registered_backward_memory_stage(self.runtime, boundary_seq, "before_artifact_window") + if int(window_start) == 0: + if "cells_prev" in artifact_store.reverse_artifact_roles: + grad_carry_template_cells = reverse_artifact_tensor_store_tensor_for_step( + artifact_store.reverse_artifact_tensor_store, + role="cells_prev", + physical_step=int(window_start), + ) + else: + grad_carry_template_cells = boundary_seq.new_empty( + int(boundary_seq.shape[0]), + int(boundary_seq.shape[2]) + + int(self.runtime.recurrent_cell_idx.numel()) + + int(self.runtime.output_cell_idx.numel()), + int(self.runtime.hidden_size), + ) + _record_registered_backward_memory_stage(self.runtime, boundary_seq, "artifact_window_template_ready") + window_result = self._run_backward_tensor_store_window( + tensor_store=artifact_store.reverse_artifact_tensor_store, + window_start=window_start, + window_end=window_end, + outer_time_steps=int(boundary_seq.shape[1]), + population_resets=population_resets, + transition_resets=transition_resets, + grad_output_seq=grad_output_seq, + grad_carry_cells=cast(torch.Tensor | None, grad_carry_cells), + grad_carry_recurrent_hidden_backend=grad_carry_recurrent_hidden_backend, + grad_next_population_state=grad_next_population_state, + grad_next_backend_state_cache=grad_next_backend_state_cache, + scheduler_plan=scheduler_plan, + reverse_artifact_roles=artifact_store.reverse_artifact_roles, + runtime_schedule_plan=memory_artifact_plan.runtime_schedule_plan, + ) + 
consume_window_result(int(window_start), window_result) + window_cursor += 1 + self.runtime._finish_backend_owner_timing() + if grad_carry_cells is None and torch.is_tensor(grad_carry_recurrent_hidden_backend): + if grad_carry_template_cells is None: + raise RuntimeError( + "Temporal backward cannot materialize backend recurrent carry without a state template" + ) + grad_carry_cells = _materialize_backend_recurrent_hidden_grad_to_cells( + self.runtime, + grad_carry_recurrent_hidden_backend, + template_cells=grad_carry_template_cells, + record_tag="registered_recurrent_state_grad_materialize_final_boundary", + ) + return TemporalPhysicalBackwardScanResult( + grad_boundary_seq=grad_boundary_seq, + grad_carry_cells=cast(torch.Tensor | None, grad_carry_cells), + grad_next_population_state=grad_next_population_state, + param_grads=tuple(grad_param_accum), + ) + + +def _require_compiler_planned_artifact_windows( + artifact_store: TemporalArtifactStore, + *, + time_steps: int, +) -> tuple[tuple[int, int], ...]: + windows = tuple((int(start), int(end)) for start, end in artifact_store.backward_windows) + if not windows: + raise RuntimeError( + "Registered temporal backward requires compiler memory-plan backward_windows; " + "checkpoint and recompute window derivation must come from the compiler memory plan" + ) + cursor = 0 + for start, end in windows: + if start != cursor or end <= start or end > int(time_steps): + raise RuntimeError( + "Registered temporal backward memory-plan windows must exactly cover physical time: " + f"time_steps={int(time_steps)}; windows={windows!r}" + ) + cursor = end + if cursor != int(time_steps): + raise RuntimeError( + "Registered temporal backward memory-plan windows do not cover the full physical sequence: " + f"time_steps={int(time_steps)}; covered={int(cursor)}; windows={windows!r}" + ) + return windows + + +def _require_compiler_scheduled_recompute_window( + artifact_store: TemporalArtifactStore, + *, + window_start: int, + window_end: int, +) -> None: + candidate = (int(window_start), int(window_end)) + planned = tuple((int(start), int(end)) for start, end in artifact_store.backward_windows) + if candidate not in planned: + raise RuntimeError( + "Registered temporal backward tried to materialize an artifact window outside the compiler " + f"memory schedule: window={candidate!r}; planned={planned!r}" + ) + if int(window_end) - int(window_start) > int(artifact_store.recompute_window_len): + raise RuntimeError( + "Registered temporal backward artifact window exceeds compiler recompute window length: " + f"window={candidate!r}; recompute_window_len={int(artifact_store.recompute_window_len)}" + ) + + +def _require_compiler_memory_artifact_plan( + artifact_store: TemporalArtifactStore, + *, + memory_plan_fingerprint: tuple[str, ...], + memory_runtime_artifact_fingerprint: tuple[str, ...], + memory_runtime_policy_fingerprint: tuple[str, ...], + memory_runtime_schedule_fingerprint: tuple[str, ...], + memory_runtime_schedule_rows: torch.Tensor, + expected_mode: str, + expected_checkpoint_stride: int, + expected_recompute_window_len: int, + expected_checkpoint_steps: tuple[int, ...], + expected_backward_windows: tuple[tuple[int, int], ...], +) -> None: + if not artifact_store.memory_plan_fingerprint: + raise RuntimeError( + "Registered temporal backward requires artifact_store.memory_plan_fingerprint; " + "memory/liveness policy must be compiler-owned before reverse execution" + ) + if tuple(artifact_store.memory_plan_fingerprint) != tuple(memory_plan_fingerprint): + 
raise RuntimeError( + "Registered temporal backward memory-plan fingerprint mismatch: " + "artifact store was not produced by the current compiler liveness plan" + ) + if not artifact_store.memory_runtime_artifact_fingerprint: + raise RuntimeError( + "Registered temporal backward requires artifact_store.memory_runtime_artifact_fingerprint; " + "checkpoint/recompute windows must be compiler memory-plan products" + ) + if tuple(artifact_store.memory_runtime_artifact_fingerprint) != tuple(memory_runtime_artifact_fingerprint): + raise RuntimeError( + "Registered temporal backward runtime artifact-plan fingerprint mismatch: " + "checkpoint/recompute policy does not match the current compiler memory plan" + ) + if not artifact_store.memory_runtime_policy_fingerprint: + raise RuntimeError( + "Registered temporal backward requires artifact_store.memory_runtime_policy_fingerprint; " + "runtime policy rows must be compiler memory-plan products" + ) + if tuple(artifact_store.memory_runtime_policy_fingerprint) != tuple(memory_runtime_policy_fingerprint): + raise RuntimeError( + "Registered temporal backward runtime policy fingerprint mismatch: " + "runtime memory policy does not match the current compiler memory plan" + ) + if not artifact_store.memory_runtime_schedule_fingerprint: + raise RuntimeError( + "Registered temporal backward requires artifact_store.memory_runtime_schedule_fingerprint; " + "scheduler checkpoint/recompute policy must be a compiler memory-plan product" + ) + if tuple(artifact_store.memory_runtime_schedule_fingerprint) != tuple(memory_runtime_schedule_fingerprint): + raise RuntimeError( + "Registered temporal backward runtime schedule fingerprint mismatch: " + "scheduler checkpoint/recompute policy does not match the current compiler memory plan" + ) + if not torch.is_tensor(artifact_store.memory_runtime_schedule_rows): + raise RuntimeError( + "Registered temporal backward requires artifact_store.memory_runtime_schedule_rows; " + "runtime schedule policy must have concrete compiler table rows" + ) + actual_schedule_rows = artifact_store.memory_runtime_schedule_rows.to(device="cpu", dtype=torch.long).contiguous() + expected_schedule_rows_tensor = memory_runtime_schedule_rows.to(device="cpu", dtype=torch.long).contiguous() + if tuple(actual_schedule_rows.shape) != tuple(expected_schedule_rows_tensor.shape) or not torch.equal( + actual_schedule_rows, + expected_schedule_rows_tensor, + ): + raise RuntimeError( + "Registered temporal backward runtime schedule rows mismatch: " + "scheduler policy rows do not match the current compiler memory plan" + ) + if artifact_store.mode != expected_mode: + raise RuntimeError( + "Registered temporal backward artifact mode does not match compiler memory plan: " + f"actual={artifact_store.mode}; expected={expected_mode}" + ) + if int(artifact_store.checkpoint_stride) != int(expected_checkpoint_stride): + raise RuntimeError( + "Registered temporal backward checkpoint stride does not match compiler memory plan: " + f"actual={int(artifact_store.checkpoint_stride)}; expected={int(expected_checkpoint_stride)}" + ) + if int(artifact_store.recompute_window_len) != int(expected_recompute_window_len): + raise RuntimeError( + "Registered temporal backward recompute window does not match compiler memory plan: " + f"actual={int(artifact_store.recompute_window_len)}; expected={int(expected_recompute_window_len)}" + ) + if tuple(int(step) for step in artifact_store.checkpoint_steps) != tuple( + int(step) for step in expected_checkpoint_steps + ): + raise 
RuntimeError("Registered temporal backward checkpoint steps do not match compiler memory plan") + if tuple((int(start), int(end)) for start, end in artifact_store.backward_windows) != tuple( + (int(start), int(end)) for start, end in expected_backward_windows + ): + raise RuntimeError("Registered temporal backward windows do not match compiler memory plan") + + +def _require_compiler_physical_strategy_plan( + artifact_store: TemporalArtifactStore, + *, + physical_strategy_fingerprint: tuple[str, ...], + physical_strategy_rows: torch.Tensor, +) -> None: + if not artifact_store.physical_strategy_fingerprint: + raise RuntimeError( + "Registered temporal backward requires artifact_store.physical_strategy_fingerprint; " + "physical execution strategy must be a compiler-owned launch product" + ) + if tuple(artifact_store.physical_strategy_fingerprint) != tuple(physical_strategy_fingerprint): + raise RuntimeError( + "Registered temporal backward physical-strategy fingerprint mismatch: " + "artifact store was not produced by the current compiler physical strategy plan" + ) + if not torch.is_tensor(artifact_store.physical_strategy_rows): + raise RuntimeError( + "Registered temporal backward requires artifact_store.physical_strategy_rows; " + "physical execution strategy must have concrete compiler table rows" + ) + actual_rows = artifact_store.physical_strategy_rows.to(device="cpu", dtype=torch.long).contiguous() + expected_rows = physical_strategy_rows.to(device="cpu", dtype=torch.long).contiguous() + if tuple(actual_rows.shape) != tuple(expected_rows.shape) or not torch.equal(actual_rows, expected_rows): + raise RuntimeError( + "Registered temporal backward physical-strategy rows mismatch: " + "launch strategy rows do not match the current compiler physical strategy plan" + ) diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/temporal/scheduler.py b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/scheduler.py new file mode 100644 index 00000000..e24b85c1 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/scheduler.py @@ -0,0 +1,361 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Literal + +import torch + + +@dataclass(frozen=True) +class TemporalOutputEmissionRuntimePlan: + selector_kind: str + outer_time_steps: int + inner_steps: int + emitted_output_count: int + physical_to_output_index: tuple[tuple[int, int], ...] + output_surface: str + materialization: str + autograd_seed_kind: str + required_backward_surfaces: tuple[str, ...] 
+ checkpoint_policy_basis: str + owner: Literal["planner", "inferred"] + + @property + def physical_steps(self) -> tuple[int, ...]: + return tuple(physical_step for physical_step, _output_index in self.physical_to_output_index) + + def output_index_for_physical_step(self, physical_step: int) -> int | None: + physical_step = int(physical_step) + for candidate_step, output_index in self.physical_to_output_index: + if int(candidate_step) == physical_step: + return int(output_index) + return None + + def output_grad_for_physical_step( + self, + grad_output_seq: torch.Tensor | None, + *, + physical_step: int, + ) -> torch.Tensor | None: + if grad_output_seq is None: + return None + output_index = self.output_index_for_physical_step(int(physical_step)) + if output_index is None or int(output_index) >= int(grad_output_seq.shape[1]): + return None + return grad_output_seq[:, int(output_index)] + + def active_local_steps( + self, + grad_output_seq: torch.Tensor | None, + *, + window_start: int, + window_len: int, + ) -> tuple[int, ...]: + if grad_output_seq is None: + return () + return tuple( + local_step + for local_step in range(int(window_len)) + if self.output_grad_for_physical_step( + grad_output_seq, + physical_step=int(window_start) + int(local_step), + ) + is not None + ) + + def cuda_materializer_compatible(self, grad_output_seq: torch.Tensor | None) -> bool: + if grad_output_seq is None: + return False + emitted_time_steps = int(grad_output_seq.shape[1]) + if emitted_time_steps == int(self.outer_time_steps): + expected = tuple( + (outer_step * int(self.inner_steps) + int(self.inner_steps) - 1, outer_step) + for outer_step in range(int(self.outer_time_steps)) + ) + return self.physical_to_output_index == expected + if emitted_time_steps == 1: + terminal_step = (int(self.outer_time_steps) - 1) * int(self.inner_steps) + int(self.inner_steps) - 1 + return self.physical_to_output_index == ((terminal_step, 0),) + return False + + +@dataclass(frozen=True) +class TemporalCheckpointRuntimePlan: + checkpoint_steps: int | None + backward_window_steps: int | None + checkpoint_kind: str + backward_window_kind: str + owner: Literal["planner", "inferred"] + + +@dataclass(frozen=True) +class TemporalMaterializationRuntimePlan: + reverse_artifact_kind: str + checkpoint_steps: int | None + recompute_window_steps: int | None + output_materialization: str + reason: str + owner: Literal["planner", "inferred"] + + +@dataclass(frozen=True) +class TemporalReplayArtifactRequest: + output_message_physical_steps: tuple[int, ...] 
| None + final_state_physical_step: int | None + reason: str + + @property + def materialize_output_messages(self) -> bool: + return self.output_message_physical_steps is None or bool(self.output_message_physical_steps) + + +@dataclass(frozen=True) +class TemporalRuntimeSchedulerPlan: + outer_time_steps: int + inner_steps: int + physical_time_steps: int + output_emissions: TemporalOutputEmissionRuntimePlan + checkpoint: TemporalCheckpointRuntimePlan + materialization: TemporalMaterializationRuntimePlan + owner: Literal["planner", "inferred"] + + @property + def review_summary(self) -> tuple[str, ...]: + return ( + f"scheduler_owner={self.owner}", + f"physical_steps={int(self.physical_time_steps)}", + f"output_selector={self.output_emissions.selector_kind}", + "output_physical_steps=" + _int_tuple_summary(self.output_emissions.physical_steps), + f"autograd_seed={self.output_emissions.autograd_seed_kind}", + f"checkpoint_steps={_optional_int_summary(self.checkpoint.checkpoint_steps)}", + f"recompute_window_steps={_optional_int_summary(self.checkpoint.backward_window_steps)}", + f"reverse_artifacts={self.materialization.reverse_artifact_kind}", + ) + + def replay_request_for_window( + self, + *, + grad_output_seq: torch.Tensor | None, + window_start: int, + window_end: int, + include_final_state_output: bool, + ) -> TemporalReplayArtifactRequest: + window_start = int(window_start) + window_end = int(window_end) + active_steps = { + window_start + local_step + for local_step in self.output_emissions.active_local_steps( + grad_output_seq, + window_start=window_start, + window_len=max(0, window_end - window_start), + ) + } + final_state_step = window_end - 1 if include_final_state_output and window_end > window_start else None + if final_state_step is not None: + active_steps.add(int(final_state_step)) + return TemporalReplayArtifactRequest( + output_message_physical_steps=tuple(sorted(active_steps)), + final_state_physical_step=final_state_step, + reason=( + "replay_artifacts=planned_output_emissions;" + f"autograd_seed={self.output_emissions.autograd_seed_kind};" + f"final_state_step={_optional_int_summary(final_state_step)}" + ), + ) + + +def build_temporal_runtime_scheduler_plan( + *, + temporal_plan: Any | None, + outer_time_steps: int, + inner_steps: int, + output_boundary: str, + output_contract: str, + materialize_final_state: bool, + collect_artifacts: bool, +) -> TemporalRuntimeSchedulerPlan: + outer_time_steps = max(1, int(outer_time_steps)) + inner_steps = max(1, int(inner_steps)) + physical_time_steps = outer_time_steps * inner_steps + owner: Literal["planner", "inferred"] = "planner" if temporal_plan is not None else "inferred" + output_emissions = _output_emission_plan( + temporal_plan=temporal_plan, + outer_time_steps=outer_time_steps, + inner_steps=inner_steps, + output_boundary=output_boundary, + output_contract=output_contract, + materialize_final_state=materialize_final_state, + collect_artifacts=collect_artifacts, + owner=owner, + ) + checkpoint = _checkpoint_plan( + temporal_plan=temporal_plan, + owner=owner, + ) + materialization = _materialization_plan( + temporal_plan=temporal_plan, + collect_artifacts=collect_artifacts, + materialize_final_state=materialize_final_state, + checkpoint=checkpoint, + owner=owner, + ) + return TemporalRuntimeSchedulerPlan( + outer_time_steps=outer_time_steps, + inner_steps=inner_steps, + physical_time_steps=physical_time_steps, + output_emissions=output_emissions, + checkpoint=checkpoint, + materialization=materialization, + owner=owner, + ) 
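+# Illustrative usage sketch (hypothetical values, comments only, not executed):
+# without a planner-owned temporal_plan every sub-plan falls back to inferred
+# defaults, so the scheduler reports owner="inferred" and derives the emission
+# schedule from output_boundary alone.
+#
+#     plan = build_temporal_runtime_scheduler_plan(
+#         temporal_plan=None,
+#         outer_time_steps=8,
+#         inner_steps=2,
+#         output_boundary="sequence",
+#         output_contract="output_cells",
+#         materialize_final_state=False,
+#         collect_artifacts=True,
+#     )
+#     assert plan.owner == "inferred"
+#     assert plan.physical_time_steps == 16
+#     # Every outer step emits at its last inner step: physical steps 1, 3, ..., 15.
+#     assert plan.output_emissions.physical_steps == tuple(range(1, 16, 2))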
+ + +def _output_emission_plan( + *, + temporal_plan: Any | None, + outer_time_steps: int, + inner_steps: int, + output_boundary: str, + output_contract: str, + materialize_final_state: bool, + collect_artifacts: bool, + owner: Literal["planner", "inferred"], +) -> TemporalOutputEmissionRuntimePlan: + request = getattr(temporal_plan, "output_request", None) + selector_kind = str( + getattr( + request, + "selector_kind", + "terminal_outer_step" if output_boundary == "terminal" else "all_outer_steps", + ) + ) + explicit_outer_steps = tuple(int(step) for step in getattr(request, "explicit_outer_steps", ()) or ()) + outer_steps = _selected_outer_steps( + selector_kind=selector_kind, + explicit_outer_steps=explicit_outer_steps, + outer_time_steps=outer_time_steps, + ) + physical_to_output_index = tuple( + (int(outer_step) * int(inner_steps) + int(inner_steps) - 1, output_index) + for output_index, outer_step in enumerate(outer_steps) + ) + emitted_output_count = int(getattr(request, "emitted_output_count", len(physical_to_output_index))) + materialization = str( + getattr( + request, + "materialization", + "outputs_and_final_state" if materialize_final_state else "outputs_only", + ) + ) + autograd_seed_kind = str( + getattr( + request, + "autograd_seed_kind", + "emitted_output_grad" if collect_artifacts else "none", + ) + ) + return TemporalOutputEmissionRuntimePlan( + selector_kind=selector_kind, + outer_time_steps=outer_time_steps, + inner_steps=inner_steps, + emitted_output_count=emitted_output_count, + physical_to_output_index=physical_to_output_index, + output_surface=str(getattr(request, "output_surface", output_contract)), + materialization=materialization, + autograd_seed_kind=autograd_seed_kind, + required_backward_surfaces=tuple( + str(item) for item in getattr(request, "required_backward_surfaces", ()) or () + ), + checkpoint_policy_basis=str( + getattr( + request, + "checkpoint_policy_basis", + "emitted_output_schedule" if collect_artifacts else "inference", + ) + ), + owner=owner, + ) + + +def _selected_outer_steps( + *, + selector_kind: str, + explicit_outer_steps: tuple[int, ...], + outer_time_steps: int, +) -> tuple[int, ...]: + if selector_kind == "terminal_outer_step": + return (max(0, int(outer_time_steps) - 1),) + if selector_kind == "explicit_outer_steps": + return tuple(step for step in explicit_outer_steps if 0 <= int(step) < int(outer_time_steps)) + return tuple(range(int(outer_time_steps))) + + +def _checkpoint_plan( + *, + temporal_plan: Any | None, + owner: Literal["planner", "inferred"], +) -> TemporalCheckpointRuntimePlan: + checkpoint = getattr(temporal_plan, "checkpoint", None) + backward_window = getattr(temporal_plan, "backward_window", None) + checkpoint_steps = getattr(checkpoint, "checkpoint_steps", None) + backward_window_steps = getattr(backward_window, "max_window_steps", None) + return TemporalCheckpointRuntimePlan( + checkpoint_steps=None if checkpoint_steps is None else max(1, int(checkpoint_steps)), + backward_window_steps=None if backward_window_steps is None else max(1, int(backward_window_steps)), + checkpoint_kind=str(getattr(checkpoint, "checkpoint_kind", "none")), + backward_window_kind=str(getattr(backward_window, "window_kind", "none")), + owner=owner, + ) + + +def _materialization_plan( + *, + temporal_plan: Any | None, + collect_artifacts: bool, + materialize_final_state: bool, + checkpoint: TemporalCheckpointRuntimePlan, + owner: Literal["planner", "inferred"], +) -> TemporalMaterializationRuntimePlan: + materialization = 
getattr(temporal_plan, "materialization", None) + reverse_artifact_kind = str( + getattr( + materialization, + "reverse_artifact_kind", + "store_step_artifacts" if collect_artifacts else "none", + ) + ) + output_materialization = str( + getattr( + materialization, + "output_materialization", + "outputs_and_final_state" if materialize_final_state else "outputs_only", + ) + ) + checkpoint_steps = getattr(materialization, "checkpoint_steps", checkpoint.checkpoint_steps) + recompute_window_steps = getattr(materialization, "recompute_window_steps", checkpoint.backward_window_steps) + return TemporalMaterializationRuntimePlan( + reverse_artifact_kind=reverse_artifact_kind, + checkpoint_steps=None if checkpoint_steps is None else max(1, int(checkpoint_steps)), + recompute_window_steps=None if recompute_window_steps is None else max(1, int(recompute_window_steps)), + output_materialization=output_materialization, + reason=str(getattr(materialization, "reason", "materialization=inferred_runtime_scheduler")), + owner=owner, + ) + + +def _optional_int_summary(value: int | None) -> str: + return "none" if value is None else str(int(value)) + + +def _int_tuple_summary(values: tuple[int, ...]) -> str: + return "none" if not values else ",".join(str(int(value)) for value in values) + + +__all__ = [ + "TemporalCheckpointRuntimePlan", + "TemporalMaterializationRuntimePlan", + "TemporalOutputEmissionRuntimePlan", + "TemporalReplayArtifactRequest", + "TemporalRuntimeSchedulerPlan", + "build_temporal_runtime_scheduler_plan", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py new file mode 100644 index 00000000..2080a39c --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py @@ -0,0 +1,246 @@ +from __future__ import annotations + +from typing import Any + +import torch +from tensordict import TensorDictBase + +from cortical.fabric.backend.cuda.sequence_surface.compiler.executor_bindings import ( + TemporalTransitionParamGradBinding, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.native_callables import ( + temporal_transition_reverse_seed_role_id, +) +from cortical.fabric.backend.cuda.sequence_surface.runtime.support import ( + _accumulate_tensor_grad, +) +from cortical.fabric.backend.cuda.transition_execution.projection import ( + _unfuse_recurrent_input_projection_grads, +) + + +def _transition_reverse_seed_role_id(role_name: str) -> int: + return temporal_transition_reverse_seed_role_id(role_name) + + +def _transition_binding_logical_names(executor: Any, *, binding_kind: str) -> tuple[str, ...]: + return tuple( + str(binding.logical_name) + for binding in getattr(executor, "bindings", ()) + if str(binding.binding_kind) == binding_kind + ) + + +def _transition_reverse_seed_state_names(reverse_executor: Any) -> tuple[str, ...]: + state_names: list[str] = [] + for logical_name in _transition_binding_logical_names(reverse_executor, binding_kind="input"): + if not logical_name.startswith("grad_next_"): + continue + _transition_reverse_seed_role_id(logical_name) + state_name = logical_name.removeprefix("grad_next_") + if state_name not in state_names: + state_names.append(state_name) + return tuple(state_names) + + +def _transition_reverse_has_public_seed(reverse_executor: Any) -> bool: + return "grad_public_y" in _transition_binding_logical_names(reverse_executor, binding_kind="input") + + +def 
_transition_reverse_state_grad_names(reverse_executor: Any) -> tuple[str, ...]: + output_names = set(_transition_binding_logical_names(reverse_executor, binding_kind="output")) + return tuple( + state_name + for state_name in _transition_reverse_seed_state_names(reverse_executor) + if f"grad_{state_name}" in output_names + ) + + +def _transition_tensor_by_logical( + tensors: tuple[torch.Tensor, ...], + logical_to_slot: dict[str, int], + logical_name: str, +) -> torch.Tensor: + slot = logical_to_slot.get(logical_name) + if slot is None: + raise RuntimeError(f"Registered fused transition program did not bind logical tensor {logical_name!r}") + tensor = tensors[int(slot)] + if not torch.is_tensor(tensor) or int(tensor.numel()) == 0: + raise RuntimeError(f"Registered fused transition program produced no tensor for {logical_name!r}") + return tensor + + +def _transition_tensor_by_logical_optional( + tensors: tuple[torch.Tensor, ...], + logical_to_slot: dict[str, int], + logical_name: str, +) -> torch.Tensor | None: + slot = logical_to_slot.get(logical_name) + if slot is None: + return None + tensor = tensors[int(slot)] + return tensor if torch.is_tensor(tensor) and int(tensor.numel()) > 0 else None + + +def _transition_grad_state_value( + grad_next_backend_state_cache: dict[str, object] | None, + population_name: str, + state_name: str, +) -> torch.Tensor | None: + if grad_next_backend_state_cache is None: + return None + state = grad_next_backend_state_cache.get(population_name) + if isinstance(state, TensorDictBase): + value = state.get(state_name) + return value if torch.is_tensor(value) else None + if isinstance(state, dict): + value = state.get(state_name) + return value if torch.is_tensor(value) else None + return None + + +def _transition_reverse_seed_tensor_table( + *, + reverse_executor: Any, + bucket: Any, + bucket_ordinal: int, + population_name: str, + grad_recurrent_hidden_backend: torch.Tensor | None, + grad_next_backend_state_cache: dict[str, object] | None, +) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]: + tensors: list[torch.Tensor] = [] + rows: list[list[int]] = [] + + def add_seed(role_name: str, tensor: torch.Tensor | None) -> None: + if tensor is None: + return + role_id = _transition_reverse_seed_role_id(role_name) + rows.append([int(role_id), len(tensors), int(bucket_ordinal)]) + tensors.append(tensor.contiguous()) + + if _transition_reverse_has_public_seed(reverse_executor): + add_seed( + "grad_public_y", + None + if grad_recurrent_hidden_backend is None + else grad_recurrent_hidden_backend[:, int(bucket.backend_start) : int(bucket.backend_stop), :], + ) + for state_name in _transition_reverse_seed_state_names(reverse_executor): + add_seed( + f"grad_next_{state_name}", + _transition_grad_state_value(grad_next_backend_state_cache, population_name, state_name), + ) + return ( + tuple(tensors), + torch.tensor(rows, dtype=torch.long) if rows else torch.empty((0, 3), dtype=torch.long), + ) + + +def _has_factorized_transition_input_projection_tape(static_tensors: dict[str, object]) -> bool: + if torch.is_tensor(static_tensors.get("input_proj_weight_t")): + return True + population_materialized = static_tensors.get("population_materialized") + if isinstance(population_materialized, dict) and len(population_materialized) == 1: + only_params = next(iter(population_materialized.values())) + if isinstance(only_params, dict) and torch.is_tensor(only_params.get("input_proj_weight_t")): + return True + return False + + +def _transition_input_projection_grad_maps( + *, + 
static_tensors: dict[str, object], + grad_fused_weight: torch.Tensor, + grad_fused_bias: torch.Tensor | None, + selected_static_source: str = "", +) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + message_to_cell_weight = static_tensors.get("message_to_cell_weight") + value_to_cell_weight = static_tensors.get("value_to_cell_weight") + grad_matches_message_to_cell = ( + torch.is_tensor(message_to_cell_weight) + and grad_fused_weight.dim() == 2 + and tuple(int(dim) for dim in grad_fused_weight.shape) + == tuple(int(dim) for dim in message_to_cell_weight.shape) + and ( + not torch.is_tensor(value_to_cell_weight) + or tuple(int(dim) for dim in grad_fused_weight.shape) + != tuple(int(dim) for dim in value_to_cell_weight.shape) + ) + ) + if grad_fused_weight.dim() == 2 and ( + selected_static_source == "message_to_cell_weight" or grad_matches_message_to_cell + ): + static_grads = {"message_to_cell_weight": grad_fused_weight} + if grad_fused_bias is not None: + static_grads["recurrent_cell_bias"] = grad_fused_bias + return static_grads, {} + if _has_factorized_transition_input_projection_tape(static_tensors): + return _unfuse_recurrent_input_projection_grads( + static_tensors=static_tensors, + grad_fused_weight=grad_fused_weight, + grad_fused_bias=grad_fused_bias, + selected_static_source=selected_static_source, + ) + grad_value_to_cell_weight = grad_fused_weight + if grad_value_to_cell_weight.dim() == 2: + grad_value_to_cell_weight = grad_value_to_cell_weight.transpose(0, 1).contiguous() + static_grads = {"value_to_cell_weight": grad_value_to_cell_weight} + if grad_fused_bias is not None: + static_grads["recurrent_cell_bias"] = grad_fused_bias + return static_grads, {} + + +def _transition_param_grad_accumulator_from_binding_rows( + *, + bucket_static_tensors: dict[str, object], + reverse_outputs: tuple[torch.Tensor, ...], + reverse_logical_to_slot: dict[str, int], + transition_param_grad_bindings: tuple[TemporalTransitionParamGradBinding, ...], +) -> tuple[dict[str, list[torch.Tensor]], dict[str, list[torch.Tensor]]]: + if not transition_param_grad_bindings: + return {}, {} + materialized_param_grads: dict[str, list[torch.Tensor]] = {} + static_source_grads: dict[str, list[torch.Tensor]] = {} + input_projection_weight_grad: torch.Tensor | None = None + input_projection_bias_grad: torch.Tensor | None = None + input_projection_static_source = "" + for binding in transition_param_grad_bindings: + grad = _transition_tensor_by_logical( + reverse_outputs, + reverse_logical_to_slot, + binding.grad_logical_name, + ) + if binding.reducer_kind == "materialized": + materialized_param_grads.setdefault(binding.parameter_name, []).append(grad) + elif binding.reducer_kind == "input_projection_weight": + input_projection_weight_grad = _accumulate_tensor_grad(input_projection_weight_grad, grad) + if binding.selected_static_source: + input_projection_static_source = binding.selected_static_source + elif binding.reducer_kind == "input_projection_bias": + input_projection_bias_grad = _accumulate_tensor_grad(input_projection_bias_grad, grad) + else: + raise RuntimeError( + "Registered reverse transition executor received unknown compiler parameter-gradient reducer: " + f"{binding.reducer_kind!r}" + ) + if input_projection_weight_grad is not None: + input_static_grads, input_materialized_grads = _transition_input_projection_grad_maps( + static_tensors=bucket_static_tensors, + grad_fused_weight=input_projection_weight_grad, + grad_fused_bias=input_projection_bias_grad, + 
selected_static_source=input_projection_static_source, + ) + for name, grad in input_materialized_grads.items(): + materialized_param_grads.setdefault(name, []).append(grad) + for name, grad in input_static_grads.items(): + static_source_grads.setdefault(name, []).append(grad) + return materialized_param_grads, static_source_grads + + +__all__ = [ + "_transition_param_grad_accumulator_from_binding_rows", + "_transition_reverse_seed_role_id", + "_transition_reverse_seed_tensor_table", + "_transition_reverse_state_grad_names", + "_transition_tensor_by_logical_optional", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/temporal/types.py b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/types.py new file mode 100644 index 00000000..f545e87a --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/types.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Literal + +import torch +from tensordict import TensorDict + +from cortical.fabric.backend.cuda.transition_execution.types import TransitionInputProjectionParamGradStep +from cortical.fabric.backend.cuda.sequence_surface.flat_bucket.flat_buckets import BackendOrderTransitionParamGrads + + +TemporalPublicCarryOrder = Literal["graph_order", "backend_order"] +TemporalOutputContract = Literal["full_cells", "output_cells", "pooled_output_cells"] +TemporalTransitionTapeMode = Literal["disabled", "input_projection", "full"] + +_TransitionNamedGradSequence = dict[str, list[torch.Tensor]] +_TransitionInputProjectionParamGradSequence = list[TransitionInputProjectionParamGradStep] +_TransitionParamGradAccumulator = dict[ + str, + tuple[ + _TransitionNamedGradSequence, + _TransitionNamedGradSequence, + _TransitionInputProjectionParamGradSequence, + ], +] + + +@dataclass(frozen=True) +class TemporalBucketStepArtifacts: + physical_step_index: int + message_step_index: int + boundary_step: torch.Tensor + cells_prev: torch.Tensor + population_state_before: TensorDict + reset_step: torch.Tensor | None + transition_reset_step: torch.Tensor | None + input_k: torch.Tensor + input_v: torch.Tensor + recurrent_k_before: torch.Tensor + recurrent_v_before: torch.Tensor + recurrent_msg_backend_order: torch.Tensor + recurrent_hidden_before_backend_order: torch.Tensor | None + recurrent_hidden_backend_order: torch.Tensor + recurrent_hidden_graph_order: torch.Tensor + population_state_after: TensorDict + recurrent_k: torch.Tensor + recurrent_v: torch.Tensor + output_msg: torch.Tensor + output_cells: torch.Tensor + cells_out: torch.Tensor + public_carry_order: TemporalPublicCarryOrder + + +@dataclass(frozen=True) +class TemporalOutputBackwardStep: + grad_input_k_from_output: torch.Tensor | None + grad_input_v_from_output: torch.Tensor | None + grad_recurrent_hidden_from_output: torch.Tensor | None + + +@dataclass(frozen=True) +class TemporalOutputBackwardSequence: + steps: tuple[TemporalOutputBackwardStep, ...] + param_grads: tuple[torch.Tensor | None, ...] + + +@dataclass(frozen=True) +class TemporalBoundaryBackwardStep: + grad_input_k: torch.Tensor | None + grad_input_v: torch.Tensor | None + + +@dataclass(frozen=True) +class TemporalBoundaryBackwardSequence: + grad_boundary_steps: tuple[torch.Tensor | None, ...] + param_grads: tuple[torch.Tensor | None, ...] 
+ + +@dataclass(frozen=True) +class TemporalRecurrentQueryBackwardStep: + grad_recurrent_q: torch.Tensor | None + grad_output_q: torch.Tensor | None + + +@dataclass(frozen=True) +class TemporalInitialRecurrentBackwardStep: + raw_param_grad: object | None + + +@dataclass(frozen=True) +class TemporalSenderKVProjectionWindowParamGrad: + role: Literal["recurrent"] + recurrent_hidden_backend_order: torch.Tensor + grad_recurrent_k_backend_order: torch.Tensor + grad_recurrent_v_backend_order: torch.Tensor + + +@dataclass(frozen=True) +class TemporalSenderKVProjectionRawParamGrad: + role: Literal["input", "recurrent"] + grad_weight: torch.Tensor + group_ids: torch.Tensor + grouped: bool + + +@dataclass(frozen=True) +class TemporalBucketStepBackwardResult: + grad_boundary: torch.Tensor | None + grad_state: TensorDict + grad_carry_recurrent_hidden_backend: torch.Tensor | None + param_grads: tuple[torch.Tensor | None, ...] + grad_backend_state_cache: dict[str, object] | None + boundary_backward_step: TemporalBoundaryBackwardStep | None + recurrent_query_backward_step: TemporalRecurrentQueryBackwardStep | None + initial_recurrent_backward_step: TemporalInitialRecurrentBackwardStep | None + transition_param_grads: BackendOrderTransitionParamGrads | None + + +@dataclass(frozen=True) +class TemporalPhysicalBackwardScanResult: + grad_boundary_seq: torch.Tensor + grad_carry_cells: torch.Tensor | None + grad_next_population_state: dict[str, dict[str, torch.Tensor | None]] + param_grads: tuple[torch.Tensor | None, ...] + + +@dataclass(frozen=True) +class TemporalBackwardWindowResult: + grad_boundary_steps: tuple[torch.Tensor | None, ...] + grad_carry_cells: torch.Tensor | None + grad_carry_recurrent_hidden_backend: torch.Tensor | None + grad_next_population_state: dict[str, dict[str, torch.Tensor | None]] + grad_next_backend_state_cache: dict[str, object] | None + param_grads: tuple[torch.Tensor | None, ...] + deferred_grad_recurrent_q_backend: torch.Tensor | None = None + deferred_grad_recurrent_kv_weight_backend: torch.Tensor | None = None + deferred_transition_param_accum: _TransitionParamGradAccumulator | None = None + + +@dataclass(frozen=True) +class TemporalArtifactCheckpoint: + step_index: int + state: TensorDict + population_state_cache: dict[str, object] | None + recurrent_k: torch.Tensor | None + recurrent_v: torch.Tensor | None + recurrent_kv_layout: TemporalPublicCarryOrder | None + + +@dataclass(frozen=True) +class TemporalReverseArtifactTensorStore: + tensors: tuple[torch.Tensor, ...] + binding_rows: torch.Tensor + role_rows: torch.Tensor + access_rows: torch.Tensor + window_start: int + window_end: int + source: str + + +@dataclass(frozen=True) +class TemporalArtifactStore: + mode: Literal["store_step_artifacts", "recompute_step_artifacts"] + artifacts_by_step: list[TemporalBucketStepArtifacts] | None + checkpoints: dict[int, TemporalArtifactCheckpoint] + checkpoint_stride: int + recompute_window_len: int + transition_tape_mode: TemporalTransitionTapeMode + reason: str + stored_artifact_step_bytes: int + checkpoint_steps: tuple[int, ...] = () + backward_windows: tuple[tuple[int, int], ...] = () + memory_plan_fingerprint: tuple[str, ...] = () + memory_runtime_artifact_fingerprint: tuple[str, ...] = () + memory_runtime_policy_fingerprint: tuple[str, ...] = () + memory_runtime_schedule_fingerprint: tuple[str, ...] = () + memory_runtime_schedule_rows: torch.Tensor | None = None + physical_strategy_fingerprint: tuple[str, ...] 
= () + physical_strategy_rows: torch.Tensor | None = None + primitive_table_fingerprint: tuple[str, ...] = () + primitive_table_plan: Any | None = None + reverse_artifact_roles: tuple[str, ...] = () + reverse_artifact_tensor_store: TemporalReverseArtifactTensorStore | None = None + + +@dataclass(frozen=True) +class SharedTemporalForwardScanResult: + output_seq: torch.Tensor + final_state: TensorDict + artifact_store: TemporalArtifactStore | None + + +@dataclass(frozen=True) +class TemporalTransitionTapePolicy: + mode: TemporalTransitionTapeMode + reason: str + + +__all__ = [ + "SharedTemporalForwardScanResult", + "TemporalArtifactCheckpoint", + "TemporalArtifactStore", + "TemporalBackwardWindowResult", + "TemporalBoundaryBackwardSequence", + "TemporalBoundaryBackwardStep", + "TemporalBucketStepArtifacts", + "TemporalBucketStepBackwardResult", + "TemporalInitialRecurrentBackwardStep", + "TemporalOutputBackwardSequence", + "TemporalOutputBackwardStep", + "TemporalOutputContract", + "TemporalPhysicalBackwardScanResult", + "TemporalPublicCarryOrder", + "TemporalRecurrentQueryBackwardStep", + "TemporalReverseArtifactTensorStore", + "TemporalSenderKVProjectionRawParamGrad", + "TemporalSenderKVProjectionWindowParamGrad", + "TemporalTransitionTapeMode", + "TemporalTransitionTapePolicy", + "_TransitionInputProjectionParamGradSequence", + "_TransitionNamedGradSequence", + "_TransitionParamGradAccumulator", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/temporal/windows.py b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/windows.py new file mode 100644 index 00000000..d0b4fa7a --- /dev/null +++ b/src/cortical/fabric/backend/cuda/sequence_surface/temporal/windows.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import torch + +from cortical.fabric.backend.cuda.sequence_surface.temporal.types import ( + TemporalArtifactCheckpoint, +) + + +def nearest_temporal_artifact_checkpoint( + checkpoints: dict[int, TemporalArtifactCheckpoint], + step_index: int, +) -> TemporalArtifactCheckpoint: + eligible_steps = [checkpoint_step for checkpoint_step in checkpoints if int(checkpoint_step) <= int(step_index)] + if not eligible_steps: + raise RuntimeError("Temporal artifact recompute is missing an initial checkpoint") + return checkpoints[max(eligible_steps)] + + +def flatten_time_batch(tensor: torch.Tensor) -> torch.Tensor: + return tensor.contiguous().reshape(int(tensor.shape[0]) * int(tensor.shape[1]), *tensor.shape[2:]) + + +def unflatten_time_batch( + tensor: torch.Tensor | None, *, time_steps: int, batch_size: int +) -> tuple[torch.Tensor | None, ...]: + if tensor is None: + return tuple(None for _ in range(time_steps)) + return tuple(tensor.reshape(time_steps, batch_size, *tensor.shape[1:]).unbind(dim=0)) + + +def flatten_optional_time_batch(tensors: tuple[torch.Tensor | None, ...]) -> torch.Tensor | None: + template = next((tensor for tensor in tensors if tensor is not None), None) + if template is None: + return None + return flatten_time_batch( + torch.stack( + [tensor if tensor is not None else torch.zeros_like(template) for tensor in tensors], + dim=0, + ) + ) + + +__all__ = [ + "flatten_optional_time_batch", + "flatten_time_batch", + "nearest_temporal_artifact_checkpoint", + "unflatten_time_batch", +] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/temporal_backward.py b/src/cortical/fabric/backend/cuda/sequence_surface/temporal_backward.py deleted file mode 100644 index c01d89b7..00000000 --- 
a/src/cortical/fabric/backend/cuda/sequence_surface/temporal_backward.py +++ /dev/null @@ -1,2147 +0,0 @@ -from __future__ import annotations - -import logging -import os -from dataclasses import dataclass -from typing import Any, Literal, cast - -import torch -from tensordict import TensorDict, TensorDictBase - -from cortical.fabric.backend.cuda.sequence_surface.flat_buckets import ( - BackendOrderTransitionParamGrads, - _partial_backend_grad_state_to_population_state, - _population_grad_state_to_backend_grad_state, - run_backend_order_transition_buckets_backward_step_cached, - run_backend_order_transition_buckets_backward_step_cached_unbound, - run_backend_order_transition_buckets_step_cached_eager_result, -) -from cortical.fabric.backend.cuda.sequence_surface.policy import ( - artifact_storage_policy, - recompute_artifact_window_policy, - recompute_checkpoint_stride_policy, -) -from cortical.fabric.backend.cuda.sequence_surface.support import ( - _accumulate_owned_tensor_grad, - _accumulate_tensor_grad, -) -from cortical.fabric.backend.cuda.sequence_surface.temporal_buckets import ( - backend_order_population_buckets, -) -from cortical.fabric.backend.reuse import ExecutionFamily - - -def _temporal_backward_debug_enabled() -> bool: - value = os.getenv("CORTICAL_FABRIC_TEMPORAL_BACKWARD_DEBUG", "") - return value.lower() in {"1", "true", "yes"} - - -def _debug_check_finite(label: str, tensor: torch.Tensor | None) -> None: - if not _temporal_backward_debug_enabled() or not torch.is_tensor(tensor): - return - values = tensor.detach().float() - finite = torch.isfinite(values) - if bool(finite.all()): - return - nonfinite = int((~finite).sum().item()) - finite_count = int(finite.sum().item()) - if finite_count > 0: - finite_values = values[finite] - min_value = float(finite_values.min().item()) - max_value = float(finite_values.max().item()) - max_abs = float(finite_values.abs().max().item()) - else: - min_value = float("nan") - max_value = float("nan") - max_abs = float("nan") - logging.error( - "[temporal-backward-debug] label=%s shape=%s nonfinite=%d finite_frac=%.6f " - "finite_min=%.3e finite_max=%.3e finite_max_abs=%.3e", - label, - tuple(tensor.shape), - nonfinite, - finite_count / max(1, int(values.numel())), - min_value, - max_value, - max_abs, - ) - raise RuntimeError(f"nonfinite temporal backward tensor: {label}") - - -def _debug_check_grad_tuple( - label: str, - grads: tuple[torch.Tensor | None, ...], - names: tuple[str, ...] 
| None = None, -) -> None: - if not _temporal_backward_debug_enabled(): - return - for index, grad in enumerate(grads): - name = names[index] if names is not None and index < len(names) else str(index) - _debug_check_finite(f"{label}.{name}", grad) - - -@dataclass(frozen=True) -class TemporalBucketStepArtifacts: - boundary_step: torch.Tensor - cells_prev: torch.Tensor - population_state_before: TensorDict - reset_step: torch.Tensor | None - transition_reset_step: torch.Tensor | None - input_k: torch.Tensor - input_v: torch.Tensor - recurrent_k_before: torch.Tensor - recurrent_v_before: torch.Tensor - recurrent_msg_backend_order: torch.Tensor - recurrent_hidden_backend_order: torch.Tensor - recurrent_hidden_graph_order: torch.Tensor - backend_state_cache_before: dict[str, object] | None - transition_backward_tape_by_population: dict[str, object] | None - population_state_after: TensorDict - recurrent_k: torch.Tensor - recurrent_v: torch.Tensor - output_msg: torch.Tensor - output_cells: torch.Tensor - cells_out: torch.Tensor - - -@dataclass(frozen=True) -class TemporalOutputBackwardStep: - grad_input_k_from_output: torch.Tensor | None - grad_input_v_from_output: torch.Tensor | None - grad_recurrent_hidden_from_output: torch.Tensor | None - - -@dataclass(frozen=True) -class TemporalOutputBackwardSequence: - steps: tuple[TemporalOutputBackwardStep, ...] - param_grads: tuple[torch.Tensor | None, ...] - - -@dataclass(frozen=True) -class TemporalBoundaryBackwardStep: - grad_input_k: torch.Tensor | None - grad_input_v: torch.Tensor | None - - -@dataclass(frozen=True) -class TemporalBoundaryBackwardSequence: - grad_boundary_steps: tuple[torch.Tensor | None, ...] - param_grads: tuple[torch.Tensor | None, ...] - - -@dataclass(frozen=True) -class TemporalRecurrentQueryBackwardStep: - grad_recurrent_q: torch.Tensor | None - - -@dataclass(frozen=True) -class TemporalInitialRecurrentBackwardStep: - raw_param_grad: object | None - - -@dataclass(frozen=True) -class TemporalBucketStepBackwardResult: - grad_boundary: torch.Tensor | None - grad_state: TensorDict - param_grads: tuple[torch.Tensor | None, ...] - grad_backend_state_cache: dict[str, object] | None - boundary_backward_step: TemporalBoundaryBackwardStep | None - recurrent_query_backward_step: TemporalRecurrentQueryBackwardStep | None - initial_recurrent_backward_step: TemporalInitialRecurrentBackwardStep | None - transition_param_grads: BackendOrderTransitionParamGrads | None - - -@dataclass(frozen=True) -class TemporalPhysicalBackwardScanResult: - grad_boundary_seq: torch.Tensor - grad_carry_cells: torch.Tensor | None - grad_next_population_state: dict[str, dict[str, torch.Tensor | None]] - param_grads: tuple[torch.Tensor | None, ...] - - -@dataclass(frozen=True) -class TemporalBackwardWindowResult: - grad_boundary_steps: tuple[torch.Tensor | None, ...] - grad_carry_cells: torch.Tensor | None - grad_next_population_state: dict[str, dict[str, torch.Tensor | None]] - grad_next_backend_state_cache: dict[str, object] | None - param_grads: tuple[torch.Tensor | None, ...] 
- - -@dataclass(frozen=True) -class TemporalArtifactCheckpoint: - step_index: int - state: TensorDict - population_state_cache: dict[str, object] | None - recurrent_k: torch.Tensor | None - recurrent_v: torch.Tensor | None - - -@dataclass(frozen=True) -class TemporalArtifactStore: - mode: Literal["store_step_artifacts", "recompute_step_artifacts"] - artifacts_by_step: list[TemporalBucketStepArtifacts] | None - checkpoints: dict[int, TemporalArtifactCheckpoint] - checkpoint_stride: int - recompute_window_len: int - transition_tape_mode: TemporalTransitionTapeMode - reason: str - stored_artifact_step_bytes: int - - -TemporalOutputContract = Literal["full_cells", "output_cells", "pooled_output_cells"] - - -def _validate_temporal_physical_backward_plan(planned_backward_execution: Any | None) -> None: - if planned_backward_execution is None: - raise RuntimeError("Temporal bucket physical backward requires a planned backward execution") - if not planned_backward_execution.receiver_bucket_plans or not planned_backward_execution.sender_bucket_plans: - raise RuntimeError("Temporal bucket physical backward requires receiver and sender bucket plans") - if any( - bucket_plan.execution_family != ExecutionFamily.RECEIVER_MAJOR - for bucket_plan in planned_backward_execution.receiver_bucket_plans - ): - raise RuntimeError("Temporal bucket physical backward requires receiver-major receiver-adjoint execution") - if any( - bucket_plan.execution_family != ExecutionFamily.EDGE_MAJOR - for bucket_plan in planned_backward_execution.sender_bucket_plans - ): - raise RuntimeError("Temporal bucket physical backward requires edge-major sender/public accumulation") - try: - family_behaviors = planned_backward_execution.physical_plan.family_behaviors - except KeyError as error: - raise RuntimeError(f"Temporal bucket physical backward has unregistered family {error.args[0]!r}") from error - unsupported = tuple(behavior.family for behavior in family_behaviors if behavior.behavior == "unsupported") - if unsupported: - raise RuntimeError( - "Temporal bucket physical backward cannot run unsupported families: " + ", ".join(sorted(unsupported)) - ) - - -def _pool_output_ports_backward( - runtime: Any, - port_y: torch.Tensor, - grad_pooled: torch.Tensor, -) -> torch.Tensor: - readout_pool = str(runtime.config.readout_pool) - if readout_pool == "mean": - return grad_pooled.expand_as(port_y) / max(1, int(port_y.shape[1])) - if readout_pool == "flatten": - return grad_pooled.reshape_as(port_y) - scores = torch.einsum("bph,qh->bpq", port_y, runtime.readout_query) - weights = torch.softmax(scores.to(dtype=torch.float32), dim=1).to(dtype=port_y.dtype) - direct_grad = torch.einsum("bpq,bqh->bph", weights, grad_pooled) - weighted_dot = torch.einsum("bqh,bph->bpq", grad_pooled, port_y) - expected_dot = (weights * weighted_dot).sum(dim=1, keepdim=True) - grad_scores = weights * (weighted_dot - expected_dot) - score_grad = torch.einsum("bpq,qh->bph", grad_scores, runtime.readout_query) - return direct_grad + score_grad - - -def _grad_output_cells_for_contract( - runtime: Any, - output_cells: torch.Tensor, - grad_output: torch.Tensor | None, - output_contract: TemporalOutputContract, -) -> torch.Tensor | None: - if grad_output is None: - return None - if output_contract == "output_cells": - return grad_output - if output_contract == "pooled_output_cells": - return _pool_output_ports_backward(runtime, output_cells, grad_output) - raise RuntimeError(f"Unsupported temporal output contract {output_contract!r}") - - 
-TemporalTransitionTapeMode = Literal["disabled", "input_projection", "full"] - - -@dataclass(frozen=True) -class TemporalTransitionTapePolicy: - mode: TemporalTransitionTapeMode - reason: str - - -def _transition_kind_for_population(runtime: Any, population_name: str) -> str: - population_spec = runtime._backend_population_specs.get(population_name) - if population_spec is None: - return "generic" - op_names = {str(getattr(op, "name", "")) for op in getattr(population_spec.transition_ir, "ops", ())} - if "gated_logspace_recurrence" in op_names: - return "gated_logspace" - if "diag_rtu" in op_names or "diagonal_recurrence" in op_names: - return "diagonal_recurrence" - return "generic" - - -def _estimate_temporal_transition_tape_step_bytes( - runtime: Any, - static_tensors: dict[str, object], - *, - batch_size: int, - dtype_bytes: int, - mode: TemporalTransitionTapeMode, -) -> int: - if mode == "disabled": - return 0 - hidden_size = int(runtime.hidden_size) - total = 0 - for bucket in backend_order_population_buckets(runtime, static_tensors): - base = int(batch_size) * int(bucket.count) * hidden_size * int(dtype_bytes) - total += base - if mode != "full": - continue - transition_kind = _transition_kind_for_population(runtime, bucket.name) - if transition_kind == "gated_logspace": - total += 8 * base - elif transition_kind == "diagonal_recurrence": - total += base - return int(total) - - -def _temporal_transition_kinds(runtime: Any, static_tensors: dict[str, object]) -> frozenset[str]: - return frozenset( - _transition_kind_for_population(runtime, bucket.name) - for bucket in backend_order_population_buckets(runtime, static_tensors) - ) - - -def temporal_transition_tape_policy( - runtime: Any, - static_tensors: dict[str, object], - *, - batch_size: int, - time_steps: int, - device: torch.device, - dtype_bytes: int, - tape_policy_bin: str | None = None, -) -> TemporalTransitionTapePolicy: - if time_steps <= 1: - return TemporalTransitionTapePolicy(mode="full", reason="transition_tape=full;time_steps<=1") - memory = runtime._cuda_memory_budget(device) if hasattr(runtime, "_cuda_memory_budget") else None - if memory is None: - return TemporalTransitionTapePolicy(mode="disabled", reason="transition_tape=disabled;memory=unknown") - input_step_bytes = _estimate_temporal_transition_tape_step_bytes( - runtime, - static_tensors, - batch_size=batch_size, - dtype_bytes=dtype_bytes, - mode="input_projection", - ) - full_step_bytes = _estimate_temporal_transition_tape_step_bytes( - runtime, - static_tensors, - batch_size=batch_size, - dtype_bytes=dtype_bytes, - mode="full", - ) - input_bytes = int(input_step_bytes) * int(time_steps) - full_bytes = int(full_step_bytes) * int(time_steps) - transition_kinds = _temporal_transition_kinds(runtime, static_tensors) - reserve_bytes = max(4 << 30, int(memory.total_bytes * 0.04)) - budget_usable_bytes = ( - int(memory.usable_bytes) if int(memory.reusable_reserved_bytes) > int(reserve_bytes) else int(memory.free_bytes) - ) - if tape_policy_bin in {"checkpoint", "tbptt"}: - bounded_budget_bytes = max( - 0, min(int(memory.total_bytes * 0.01), int(budget_usable_bytes) - int(reserve_bytes)) - ) - if full_bytes > 0 and full_bytes <= bounded_budget_bytes: - bounded_mode: TemporalTransitionTapeMode = "full" - elif input_bytes > 0 and input_bytes <= bounded_budget_bytes: - bounded_mode = "input_projection" - else: - bounded_mode = "disabled" - return TemporalTransitionTapePolicy( - mode=bounded_mode, - reason=( - 
f"transition_tape={bounded_mode};planner_tape_policy={tape_policy_bin};" - f"bounded_temporal_recompute=1;input_step_bytes={int(input_step_bytes)};" - f"input_sequence_bytes={int(input_bytes)};full_step_bytes={int(full_step_bytes)};" - f"full_sequence_bytes={int(full_bytes)};bounded_budget_bytes={int(bounded_budget_bytes)};" - f"transition_kinds={','.join(sorted(transition_kinds))};" - f"free_bytes={int(memory.free_bytes)};allocator_reusable_bytes={int(memory.reusable_reserved_bytes)};" - f"reserve_bytes={int(reserve_bytes)}" - ), - ) - budget_bytes = max(0, min(int(memory.total_bytes * 0.08), int(budget_usable_bytes) - int(reserve_bytes))) - if full_bytes > 0 and full_bytes <= budget_bytes: - mode: TemporalTransitionTapeMode = "full" - elif input_bytes > 0 and input_bytes <= budget_bytes: - mode = "input_projection" - else: - mode = "disabled" - return TemporalTransitionTapePolicy( - mode=mode, - reason=( - f"transition_tape={mode};input_step_bytes={int(input_step_bytes)};" - f"input_sequence_bytes={int(input_bytes)};full_step_bytes={int(full_step_bytes)};" - f"full_sequence_bytes={int(full_bytes)};budget_bytes={int(budget_bytes)};" - f"transition_kinds={','.join(sorted(transition_kinds))};" - f"free_bytes={int(memory.free_bytes)};allocator_reusable_bytes={int(memory.reusable_reserved_bytes)};" - f"reserve_bytes={int(reserve_bytes)}" - ), - ) - - -def _tree_payload_bytes(value: object, seen: set[int] | None = None) -> int: - if seen is None: - seen = set() - if torch.is_tensor(value): - object_id = id(value) - if object_id in seen: - return 0 - seen.add(object_id) - return int(value.numel()) * int(value.element_size()) - if isinstance(value, TensorDictBase): - return sum(_tree_payload_bytes(item, seen) for item in value.values()) - if isinstance(value, dict): - return sum(_tree_payload_bytes(item, seen) for item in value.values()) - if isinstance(value, (tuple, list)): - return sum(_tree_payload_bytes(item, seen) for item in value) - if hasattr(value, "__dataclass_fields__"): - return sum( - _tree_payload_bytes(getattr(value, field_name), seen) - for field_name in getattr(value, "__dataclass_fields__", {}) - ) - return 0 - - -def _make_temporal_artifact_checkpoint( - *, - step_index: int, - state: TensorDict, - population_state_cache: dict[str, object] | None, - recurrent_k: torch.Tensor | None, - recurrent_v: torch.Tensor | None, -) -> TemporalArtifactCheckpoint: - return TemporalArtifactCheckpoint( - step_index=int(step_index), - state=TensorDict(state.to_dict(), batch_size=[]), - population_state_cache=dict(population_state_cache) if population_state_cache is not None else None, - recurrent_k=recurrent_k, - recurrent_v=recurrent_v, - ) - - -def _temporal_artifact_store_policy( - runtime: Any, - *, - first_artifact: TemporalBucketStepArtifacts, - time_steps: int, - device: torch.device, -) -> TemporalArtifactStore: - if int(time_steps) <= 1: - return TemporalArtifactStore( - mode="store_step_artifacts", - artifacts_by_step=[], - checkpoints={}, - checkpoint_stride=int(time_steps), - recompute_window_len=int(time_steps), - transition_tape_mode="disabled", - reason="artifact_mode=store_step_artifacts;artifact_storage_guard=disabled;time_steps<=1", - stored_artifact_step_bytes=0, - ) - memory = runtime._cuda_memory_budget(device) if hasattr(runtime, "_cuda_memory_budget") else None - stored_artifact_step_bytes = max(1, int(_tree_payload_bytes(first_artifact))) - storage = artifact_storage_policy( - artifact_mode="store_step_artifacts", - time_steps=int(time_steps), - 
stored_artifact_step_bytes=int(stored_artifact_step_bytes), - memory=memory, - ) - if storage.artifact_mode != "recompute_step_artifacts": - return TemporalArtifactStore( - mode="store_step_artifacts", - artifacts_by_step=[], - checkpoints={}, - checkpoint_stride=int(time_steps), - recompute_window_len=int(time_steps), - transition_tape_mode="disabled", - reason=( - f"artifact_mode=store_step_artifacts;stored_artifact_step_bytes={stored_artifact_step_bytes};" - f"time_steps={int(time_steps)}" + (f";{storage.reason_suffix}" if storage.reason_suffix else "") - ), - stored_artifact_step_bytes=int(stored_artifact_step_bytes), - ) - checkpoint = recompute_checkpoint_stride_policy( - time_steps=int(time_steps), - estimated_step_bytes=int(stored_artifact_step_bytes), - effective_batch_split_active=False, - memory=memory, - ) - window = recompute_artifact_window_policy( - time_steps=int(time_steps), - stride=int(checkpoint.stride), - estimated_step_bytes=int(stored_artifact_step_bytes), - artifact_window_step_bytes=int(stored_artifact_step_bytes), - effective_batch_split_active=False, - memory=memory, - ) - return TemporalArtifactStore( - mode="recompute_step_artifacts", - artifacts_by_step=None, - checkpoints={}, - checkpoint_stride=max(1, int(checkpoint.stride)), - recompute_window_len=max(1, int(window.window_len)), - transition_tape_mode="disabled", - reason=( - f"artifact_mode=recompute_step_artifacts;stored_artifact_step_bytes={stored_artifact_step_bytes};" - f"time_steps={int(time_steps)};checkpoint_stride={max(1, int(checkpoint.stride))};" - f"recompute_window_len={max(1, int(window.window_len))};" - f"{storage.reason_suffix or 'artifact_storage_guard=active'};" - f"checkpoint_policy={checkpoint.reason};window_policy={window.reason}" - ), - stored_artifact_step_bytes=int(stored_artifact_step_bytes), - ) - - -def _apply_temporal_recurrent_kv_reset( - *, - reset_step: torch.Tensor | None, - recurrent_k: torch.Tensor | None, - recurrent_v: torch.Tensor | None, -) -> tuple[torch.Tensor | None, torch.Tensor | None]: - if reset_step is None or recurrent_k is None or recurrent_v is None: - return recurrent_k, recurrent_v - reset_mask = reset_step.to(device=recurrent_k.device, dtype=torch.bool).view(int(reset_step.shape[0]), 1, 1) - return ( - torch.where(reset_mask, torch.zeros_like(recurrent_k), recurrent_k), - torch.where(reset_mask, torch.zeros_like(recurrent_v), recurrent_v), - ) - - -def compute_temporal_bucket_step_artifacts( - runtime: Any, - *, - boundary_step: torch.Tensor, - state: TensorDict, - reset_step: torch.Tensor | None, - transition_reset_step: torch.Tensor | None = None, - static_tensors: dict[str, object], - step_population_state_cache: dict[str, object] | None = None, - input_k_step: torch.Tensor | None = None, - input_v_step: torch.Tensor | None = None, - recurrent_k_before_step: torch.Tensor | None = None, - recurrent_v_before_step: torch.Tensor | None = None, - transition_tape_mode: TemporalTransitionTapeMode = "disabled", -) -> TemporalBucketStepArtifacts: - if not bool(runtime._partitioned_layout): - raise RuntimeError("Temporal bucket physical backward requires partitioned flat graph layout") - transition_reset = reset_step if transition_reset_step is None else transition_reset_step - current_state = state - if reset_step is not None: - if step_population_state_cache is not None: - runtime._reset_stream_step_population_cache(step_population_state_cache, reset_step) - reset_state = runtime.reset_state(state, reset_step) - if not isinstance(reset_state, TensorDictBase): - 
raise RuntimeError("Temporal bucket physical backward requires TensorDict reset state") - current_state = TensorDict(reset_state.to_dict(), batch_size=[]) - cells_prev = current_state.get("cells") - if not torch.is_tensor(cells_prev): - raise RuntimeError("Temporal bucket physical backward requires materialized cell state") - population_state_before = TensorDict( - { - name: current_state[name] - for name in runtime._population_names - if isinstance(current_state.get(name), TensorDictBase) - }, - batch_size=[], - ) - input_sender_weight = cast(torch.Tensor | None, static_tensors["input_sender_input_to_kv_weight"]) - input_group_weight = cast(torch.Tensor | None, static_tensors["input_group_input_to_kv_weight"]) - recurrent_sender_weight = cast(torch.Tensor | None, static_tensors["recurrent_sender_input_to_kv_weight"]) - recurrent_group_weight = cast(torch.Tensor | None, static_tensors["recurrent_group_input_to_kv_weight"]) - output_q = cast(torch.Tensor, static_tensors["output_q"]) - recurrent_q_backend_order = cast(torch.Tensor, static_tensors["recurrent_q_backend_order"]) - value_to_output_weight = cast(torch.Tensor, static_tensors["value_to_output_weight"]) - recurrent_prev = cells_prev[:, runtime._recurrent_slice, :] - if input_k_step is None or input_v_step is None: - input_k, input_v = runtime._project_sender_kv_from_cells_step( - boundary_step, - sender_input_to_kv_weight=input_sender_weight, - grouped_sender_input_to_kv_weight=input_group_weight, - sender_group_size=runtime._input_sender_kv_group_size, - contiguous_kv=True, - ) - else: - input_k, input_v = input_k_step, input_v_step - if recurrent_k_before_step is not None and recurrent_v_before_step is not None: - recurrent_k_before, recurrent_v_before = recurrent_k_before_step, recurrent_v_before_step - else: - recurrent_k_before, recurrent_v_before = runtime._project_sender_kv_from_cells_step( - recurrent_prev, - sender_input_to_kv_weight=recurrent_sender_weight, - grouped_sender_input_to_kv_weight=recurrent_group_weight, - sender_group_size=runtime._recurrent_sender_kv_group_size, - contiguous_kv=True, - ) - recurrent_msg = runtime._compute_messages_step_subset_partitioned_raw( - input_k, - input_v, - recurrent_k_before, - recurrent_v_before, - q_subset=recurrent_q_backend_order, - neighbor_idx=runtime.recurrent_neighbor_idx_backend_order, - neighbor_valid=runtime.recurrent_neighbor_valid_backend_order, - edge_distance=runtime.recurrent_edge_distance_backend_order, - edge_delay=runtime.recurrent_edge_delay_backend_order, - use_delay=runtime._has_edge_delay, - step_idx=1, - local_sender_idx=runtime.recurrent_local_sender_idx_backend_order, - local_receiver_idx_by_sender=runtime.recurrent_local_receiver_idx_by_sender_backend_order, - owner_tag="temporal_bucket_recurrent", - ) - backend_state_cache_before = dict(step_population_state_cache) if step_population_state_cache is not None else None - transition_backward_tape_by_population: dict[str, object] | None = None - if step_population_state_cache is not None and transition_tape_mode != "disabled": - transition_result = run_backend_order_transition_buckets_step_cached_eager_result( - runtime, - recurrent_msg, - step_population_state_cache, - resets=transition_reset, - batch_size=int(boundary_step.shape[0]), - static_tensors=static_tensors, - materialize_next_state=True, - transition_tape_mode=transition_tape_mode, - ) - recurrent_hidden_backend_order = transition_result.recurrent_hidden - population_state_after = transition_result.next_state - 
transition_backward_tape_by_population = transition_result.backward_tape_by_population - else: - recurrent_hidden_backend_order, population_state_after = runtime._run_backend_order_transition_buckets_step( - recurrent_msg, - population_state_before, - resets=transition_reset, - batch_size=int(boundary_step.shape[0]), - static_tensors=static_tensors, - step_population_state_cache=step_population_state_cache, - materialize_next_state=True, - ) - if step_population_state_cache is not None: - population_state_after = TensorDict( - { - name: runtime._backend_state_to_population_state(name, step_population_state_cache[name]) - for name in runtime._population_names - }, - batch_size=[], - ) - recurrent_hidden_graph_order = recurrent_hidden_backend_order.index_select( - 1, - runtime.population_backend_recurrent_inverse_order, - ) - recurrent_k, recurrent_v = runtime._project_sender_kv_from_cells_step( - recurrent_hidden_graph_order, - sender_input_to_kv_weight=recurrent_sender_weight, - grouped_sender_input_to_kv_weight=recurrent_group_weight, - sender_group_size=runtime._recurrent_sender_kv_group_size, - contiguous_kv=True, - ) - output_msg = runtime._compute_messages_step_subset_partitioned_raw( - input_k, - input_v, - recurrent_k, - recurrent_v, - q_subset=output_q, - neighbor_idx=runtime.output_neighbor_idx, - neighbor_valid=runtime.output_neighbor_valid, - edge_distance=runtime.output_edge_distance, - edge_delay=runtime.output_edge_delay, - use_delay=runtime._has_edge_delay, - step_idx=1, - local_sender_idx=runtime.output_local_sender_idx, - local_receiver_idx_by_sender=runtime.output_local_receiver_idx_by_sender, - owner_tag="temporal_bucket_readout", - ) - output_cells = runtime._project_output_cells_step_raw( - output_msg, - value_to_output_weight=value_to_output_weight, - ).to(dtype=boundary_step.dtype) - cells_out = torch.cat((boundary_step, recurrent_hidden_graph_order, output_cells), dim=1) - return TemporalBucketStepArtifacts( - boundary_step=boundary_step, - cells_prev=cells_prev, - population_state_before=population_state_before, - reset_step=reset_step, - transition_reset_step=transition_reset, - input_k=input_k, - input_v=input_v, - recurrent_k_before=recurrent_k_before, - recurrent_v_before=recurrent_v_before, - recurrent_msg_backend_order=recurrent_msg, - recurrent_hidden_backend_order=recurrent_hidden_backend_order, - recurrent_hidden_graph_order=recurrent_hidden_graph_order, - backend_state_cache_before=backend_state_cache_before, - transition_backward_tape_by_population=transition_backward_tape_by_population, - population_state_after=population_state_after, - recurrent_k=recurrent_k, - recurrent_v=recurrent_v, - output_msg=output_msg, - output_cells=output_cells, - cells_out=cells_out, - ) - - -def _recompute_temporal_bucket_artifact_window( - runtime: Any, - *, - boundary_seq: torch.Tensor, - population_resets: torch.Tensor | None, - transition_resets: torch.Tensor | None = None, - static_tensors: dict[str, object], - checkpoint: TemporalArtifactCheckpoint, - start_step: int, - end_step: int, - transition_tape_mode: TemporalTransitionTapeMode, -) -> list[TemporalBucketStepArtifacts]: - if int(start_step) < int(checkpoint.step_index): - raise RuntimeError("Temporal artifact recompute window starts before its checkpoint") - if int(end_step) <= int(start_step): - return [] - running_state = TensorDict(checkpoint.state.to_dict(), batch_size=[]) - step_population_state_cache = ( - dict(checkpoint.population_state_cache) if checkpoint.population_state_cache is not None else None - 
) - running_recurrent_k = checkpoint.recurrent_k - running_recurrent_v = checkpoint.recurrent_v - artifacts_window: list[TemporalBucketStepArtifacts] = [] - with torch.no_grad(): - for step_index in range(int(checkpoint.step_index), int(end_step)): - reset_step = population_resets[:, step_index] if torch.is_tensor(population_resets) else None - transition_reset_step = ( - transition_resets[:, step_index] if torch.is_tensor(transition_resets) else reset_step - ) - running_recurrent_k, running_recurrent_v = _apply_temporal_recurrent_kv_reset( - reset_step=reset_step, - recurrent_k=running_recurrent_k, - recurrent_v=running_recurrent_v, - ) - artifacts = compute_temporal_bucket_step_artifacts( - runtime, - boundary_step=boundary_seq[:, step_index], - state=running_state, - reset_step=reset_step, - transition_reset_step=transition_reset_step, - static_tensors=static_tensors, - step_population_state_cache=step_population_state_cache, - recurrent_k_before_step=running_recurrent_k, - recurrent_v_before_step=running_recurrent_v, - transition_tape_mode=transition_tape_mode if step_index >= int(start_step) else "disabled", - ) - running_recurrent_k = artifacts.recurrent_k - running_recurrent_v = artifacts.recurrent_v - running_state = TensorDict( - { - "cells": artifacts.cells_out, - **{name: artifacts.population_state_after[name] for name in runtime._population_names}, - }, - batch_size=[], - ) - if step_index >= int(start_step): - artifacts_window.append(artifacts) - else: - del artifacts - return artifacts_window - - -def _nearest_temporal_artifact_checkpoint( - checkpoints: dict[int, TemporalArtifactCheckpoint], - step_index: int, -) -> TemporalArtifactCheckpoint: - eligible_steps = [checkpoint_step for checkpoint_step in checkpoints if int(checkpoint_step) <= int(step_index)] - if not eligible_steps: - raise RuntimeError("Temporal artifact recompute is missing an initial checkpoint") - return checkpoints[max(eligible_steps)] - - -def _temporal_artifact_windows( - *, - time_steps: int, - checkpoint_stride: int, - window_len: int, -) -> tuple[tuple[int, int], ...]: - windows: list[tuple[int, int]] = [] - segment_start = 0 - stride = max(1, int(checkpoint_stride)) - tile = max(1, int(window_len)) - while segment_start < int(time_steps): - segment_end = min(int(time_steps), segment_start + stride) - window_start = segment_start - while window_start < segment_end: - window_end = min(segment_end, window_start + tile) - windows.append((window_start, window_end)) - window_start = window_end - segment_start = segment_end - return tuple(windows) - - -def _flatten_time_batch(tensor: torch.Tensor) -> torch.Tensor: - return tensor.contiguous().reshape(int(tensor.shape[0]) * int(tensor.shape[1]), *tensor.shape[2:]) - - -def _unflatten_time_batch( - tensor: torch.Tensor | None, *, time_steps: int, batch_size: int -) -> tuple[torch.Tensor | None, ...]: - if tensor is None: - return tuple(None for _ in range(time_steps)) - return tuple(tensor.reshape(time_steps, batch_size, *tensor.shape[1:]).unbind(dim=0)) - - -def _flatten_optional_time_batch(tensors: tuple[torch.Tensor | None, ...]) -> torch.Tensor | None: - template = next((tensor for tensor in tensors if tensor is not None), None) - if template is None: - return None - return _flatten_time_batch( - torch.stack( - [tensor if tensor is not None else torch.zeros_like(template) for tensor in tensors], - dim=0, - ) - ) - - -def run_temporal_output_backward_sequence( - runtime: Any, - artifacts_by_step: list[TemporalBucketStepArtifacts], - *, - grad_output_seq: 
torch.Tensor | None, - static_tensors: dict[str, object], - trainable_params: tuple[torch.Tensor, ...], - trainable_param_names: tuple[str, ...], - output_contract: TemporalOutputContract, -) -> TemporalOutputBackwardSequence: - if not artifacts_by_step: - return TemporalOutputBackwardSequence((), tuple(None for _ in trainable_params)) - if grad_output_seq is None: - return TemporalOutputBackwardSequence( - tuple( - TemporalOutputBackwardStep( - grad_input_k_from_output=None, - grad_input_v_from_output=None, - grad_recurrent_hidden_from_output=None, - ) - for _ in artifacts_by_step - ), - tuple(None for _ in trainable_params), - ) - time_steps = len(artifacts_by_step) - batch_size = int(grad_output_seq.shape[0]) - if output_contract == "full_cells": - grad_output_cells = grad_output_seq[:, :, runtime._output_slice, :].transpose(0, 1).contiguous() - elif output_contract == "output_cells": - grad_output_cells = grad_output_seq.transpose(0, 1).contiguous() - elif output_contract == "pooled_output_cells": - grad_output_cells = torch.stack( - [ - cast( - torch.Tensor, - _grad_output_cells_for_contract( - runtime, - artifacts.output_cells, - grad_output_seq[:, step_index], - output_contract, - ), - ) - for step_index, artifacts in enumerate(artifacts_by_step) - ], - dim=0, - ).contiguous() - else: - raise RuntimeError(f"Unsupported temporal output contract {output_contract!r}") - output_msg = _flatten_time_batch(torch.stack([artifacts.output_msg for artifacts in artifacts_by_step], dim=0)) - grad_output_cells_flat = _flatten_time_batch(grad_output_cells) - grad_param_accum: list[torch.Tensor | None] = [None] * len(trainable_params) - - def accumulate_params(grads: tuple[torch.Tensor | None, ...]) -> None: - for index, grad_param in enumerate(grads): - grad_param_accum[index] = _accumulate_owned_tensor_grad(grad_param_accum[index], grad_param) - - grad_output_msg, output_projection_param_grads = runtime._run_backend_output_projection_backward_phase( - output_msg=output_msg, - grad_output_cells=grad_output_cells_flat, - sequence_static_tensors=static_tensors, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - ) - accumulate_params(output_projection_param_grads) - - input_k = _flatten_time_batch(torch.stack([artifacts.input_k for artifacts in artifacts_by_step], dim=0)) - input_v = _flatten_time_batch(torch.stack([artifacts.input_v for artifacts in artifacts_by_step], dim=0)) - recurrent_k = _flatten_time_batch(torch.stack([artifacts.recurrent_k for artifacts in artifacts_by_step], dim=0)) - recurrent_v = _flatten_time_batch(torch.stack([artifacts.recurrent_v for artifacts in artifacts_by_step], dim=0)) - ( - grad_output_q, - grad_input_k_from_output, - grad_input_v_from_output, - grad_recurrent_k_from_output, - grad_recurrent_v_from_output, - ) = runtime._run_backend_message_backward_phase( - grad_msg=grad_output_msg, - q_subset=cast(torch.Tensor, static_tensors["output_q"]), - input_k=input_k, - input_v=input_v, - recurrent_k=recurrent_k, - recurrent_v=recurrent_v, - neighbor_idx=runtime.output_neighbor_idx, - neighbor_valid=runtime.output_neighbor_valid, - edge_distance=runtime.output_edge_distance, - edge_delay=runtime.output_edge_delay, - local_sender_idx=runtime.output_local_sender_idx, - local_receiver_idx_by_sender=runtime.output_local_receiver_idx_by_sender, - use_sparse_messages=bool(getattr(runtime, "_uses_sparse_message_backend", False)), - ) - - recurrent_hidden_graph_order = _flatten_time_batch( - torch.stack([artifacts.recurrent_hidden_graph_order 
for artifacts in artifacts_by_step], dim=0) - ) - grad_recurrent_hidden_from_output, recurrent_projection_param_grads = ( - runtime._run_backend_sender_kv_projection_backward_phase( - role="recurrent", - sender_cells=recurrent_hidden_graph_order, - grad_k=grad_recurrent_k_from_output, - grad_v=grad_recurrent_v_from_output, - sequence_static_tensors=static_tensors, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - active_receiver_window=None, - owner="grouped_projection", - ) - ) - accumulate_params(recurrent_projection_param_grads) - query_param_grads = runtime._run_backend_query_param_backward_phase( - grad_recurrent_q=None, - grad_output_q=grad_output_q, - sequence_static_tensors=static_tensors, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - device=grad_output_seq.device, - dtype=grad_output_seq.dtype, - active_receiver_window=None, - ) - accumulate_params(query_param_grads) - input_k_steps = _unflatten_time_batch(grad_input_k_from_output, time_steps=time_steps, batch_size=batch_size) - input_v_steps = _unflatten_time_batch(grad_input_v_from_output, time_steps=time_steps, batch_size=batch_size) - recurrent_hidden_steps = _unflatten_time_batch( - grad_recurrent_hidden_from_output, - time_steps=time_steps, - batch_size=batch_size, - ) - steps = tuple( - TemporalOutputBackwardStep( - grad_input_k_from_output=input_k_steps[index], - grad_input_v_from_output=input_v_steps[index], - grad_recurrent_hidden_from_output=recurrent_hidden_steps[index], - ) - for index in range(time_steps) - ) - return TemporalOutputBackwardSequence(steps=steps, param_grads=tuple(grad_param_accum)) - - -def run_temporal_boundary_backward_sequence( - runtime: Any, - artifacts_by_step: list[TemporalBucketStepArtifacts], - boundary_backward_steps: tuple[TemporalBoundaryBackwardStep | None, ...], - *, - static_tensors: dict[str, object], - trainable_params: tuple[torch.Tensor, ...], - trainable_param_names: tuple[str, ...], - boundary_requires_grad: bool, -) -> TemporalBoundaryBackwardSequence: - if not artifacts_by_step: - return TemporalBoundaryBackwardSequence((), tuple(None for _ in trainable_params)) - if len(boundary_backward_steps) != len(artifacts_by_step): - raise RuntimeError("Temporal boundary backward received mismatched step count") - time_steps = len(artifacts_by_step) - batch_size = int(artifacts_by_step[0].boundary_step.shape[0]) - grad_input_k = _flatten_optional_time_batch( - tuple(None if step is None else step.grad_input_k for step in boundary_backward_steps) - ) - grad_input_v = _flatten_optional_time_batch( - tuple(None if step is None else step.grad_input_v for step in boundary_backward_steps) - ) - if grad_input_k is None and grad_input_v is None: - return TemporalBoundaryBackwardSequence( - tuple(None for _ in range(time_steps)), - tuple(None for _ in trainable_params), - ) - boundary_seq = _flatten_time_batch(torch.stack([artifacts.boundary_step for artifacts in artifacts_by_step], dim=0)) - grad_boundary, param_grads = runtime._run_backend_boundary_public_backward_phase( - boundary_step=boundary_seq, - grad_input_k=grad_input_k, - grad_input_v=grad_input_v, - sequence_static_tensors=static_tensors, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - device=boundary_seq.device, - dtype=boundary_seq.dtype, - boundary_requires_grad=boundary_requires_grad, - ) - return TemporalBoundaryBackwardSequence( - grad_boundary_steps=_unflatten_time_batch(grad_boundary, time_steps=time_steps, 
batch_size=batch_size), - param_grads=param_grads, - ) - - -def run_temporal_recurrent_query_backward_sequence( - runtime: Any, - query_backward_steps: tuple[TemporalRecurrentQueryBackwardStep | None, ...], - *, - static_tensors: dict[str, object], - trainable_params: tuple[torch.Tensor, ...], - trainable_param_names: tuple[str, ...], - device: torch.device, - dtype: torch.dtype, -) -> tuple[torch.Tensor | None, ...]: - grad_recurrent_q = None - for step in query_backward_steps: - if step is not None: - grad_recurrent_q = _accumulate_tensor_grad(grad_recurrent_q, step.grad_recurrent_q) - return runtime._run_backend_query_param_backward_phase( - grad_recurrent_q=grad_recurrent_q, - grad_output_q=None, - sequence_static_tensors=static_tensors, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - device=device, - dtype=dtype, - active_receiver_window=None, - ) - - -def run_temporal_initial_recurrent_param_binding_sequence( - runtime: Any, - initial_recurrent_backward_steps: tuple[TemporalInitialRecurrentBackwardStep | None, ...], - *, - trainable_params: tuple[torch.Tensor, ...], - trainable_param_names: tuple[str, ...], -) -> tuple[torch.Tensor | None, ...]: - raw_grads = tuple(None if step is None else step.raw_param_grad for step in initial_recurrent_backward_steps) - with ( - torch.profiler.record_function("fabric.backward.glue.initial_recurrent.param_binding"), - runtime._backend_owner_timing("glue.initial_recurrent.param_binding"), - ): - return runtime._sender_kv_projection_param_grad_tuple_from_raw_grads( - raw_grads, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - ) - - -def _accumulate_named_grad_map( - accumulator: dict[str, torch.Tensor], - update: dict[str, torch.Tensor], -) -> None: - for name, grad in update.items(): - accumulator[name] = grad if name not in accumulator else accumulator[name] + grad - - -def _accumulate_transition_param_grads( - accumulator: dict[str, tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]], - update: BackendOrderTransitionParamGrads | None, -) -> None: - if update is None: - return - for population_name, population_grads in update.by_population.items(): - materialized_accum, static_accum = accumulator.setdefault(population_name, ({}, {})) - _accumulate_named_grad_map(materialized_accum, population_grads.materialized_param_grads) - _accumulate_named_grad_map(static_accum, population_grads.static_source_grads) - - -def bind_temporal_transition_param_grads( - runtime: Any, - transition_param_accum: dict[str, tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]], - *, - trainable_params: tuple[torch.Tensor, ...], - trainable_param_names: tuple[str, ...], -) -> tuple[torch.Tensor | None, ...]: - grad_param_accum: list[torch.Tensor | None] = [None] * len(trainable_params) - trainable_param_shapes = tuple(tuple(param.shape) for param in trainable_params) - for population_name, (materialized_param_grads, static_source_grads) in transition_param_accum.items(): - population_param_grads = runtime._state_public_explicit_param_grad_tuple( - population_name=population_name, - materialized_param_grads=materialized_param_grads, - static_source_grads=static_source_grads, - projection_param_grads={}, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - trainable_param_shapes=trainable_param_shapes, - active_receiver_window=None, - ) - for index, grad_param in enumerate(population_param_grads): - grad_param_accum[index] = 
_accumulate_owned_tensor_grad(grad_param_accum[index], grad_param) - return tuple(grad_param_accum) - - -def _run_temporal_bucket_step_backward_result( - runtime: Any, - artifacts: TemporalBucketStepArtifacts, - *, - grad_cells_out: torch.Tensor | None, - static_tensors: dict[str, object], - trainable_params: tuple[torch.Tensor, ...], - trainable_param_names: tuple[str, ...], - need_grad_state_before: bool, - grad_next_population_state: dict[str, dict[str, torch.Tensor | None]] | None = None, - grad_next_backend_state_cache: dict[str, object] | None = None, - output_contract: TemporalOutputContract = "full_cells", - output_backward_step: TemporalOutputBackwardStep | None = None, - defer_boundary_backward: bool = False, - defer_recurrent_query_backward: bool = False, - defer_initial_recurrent_param_binding: bool = False, - defer_transition_param_binding: bool = False, - boundary_requires_grad: bool = True, - debug_label: str = "step", -) -> TemporalBucketStepBackwardResult: - if grad_cells_out is None and grad_next_population_state is None and grad_next_backend_state_cache is None: - return TemporalBucketStepBackwardResult( - grad_boundary=None, - grad_state=TensorDict({}, batch_size=[]), - param_grads=tuple(None for _ in trainable_params), - grad_backend_state_cache=None, - boundary_backward_step=None, - recurrent_query_backward_step=None, - initial_recurrent_backward_step=None, - transition_param_grads=None, - ) - grad_param_accum: list[torch.Tensor | None] = [None] * len(trainable_params) - - def accumulate_params(grads: tuple[torch.Tensor | None, ...]) -> None: - for index, grad_param in enumerate(grads): - grad_param_accum[index] = _accumulate_owned_tensor_grad(grad_param_accum[index], grad_param) - - if output_contract == "full_cells": - grad_boundary_direct = None if grad_cells_out is None else grad_cells_out[:, runtime._input_slice, :] - grad_recurrent_direct_graph = None if grad_cells_out is None else grad_cells_out[:, runtime._recurrent_slice, :] - grad_output_cells = None if grad_cells_out is None else grad_cells_out[:, runtime._output_slice, :] - elif output_contract == "output_cells": - grad_boundary_direct = None - grad_recurrent_direct_graph = None - grad_output_cells = grad_cells_out - elif output_contract == "pooled_output_cells": - grad_boundary_direct = None - grad_recurrent_direct_graph = None - grad_output_cells = _grad_output_cells_for_contract( - runtime, - artifacts.output_cells, - grad_cells_out, - output_contract, - ) - else: - raise RuntimeError(f"Unsupported temporal output contract {output_contract!r}") - - if output_backward_step is None: - grad_output_msg, output_projection_param_grads = runtime._run_backend_output_projection_backward_phase( - output_msg=artifacts.output_msg, - grad_output_cells=grad_output_cells, - sequence_static_tensors=static_tensors, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - ) - accumulate_params(output_projection_param_grads) - - ( - grad_output_q, - grad_input_k_from_output, - grad_input_v_from_output, - grad_recurrent_k_from_output, - grad_recurrent_v_from_output, - ) = runtime._run_backend_message_backward_phase( - grad_msg=grad_output_msg, - q_subset=cast(torch.Tensor, static_tensors["output_q"]), - input_k=artifacts.input_k, - input_v=artifacts.input_v, - recurrent_k=artifacts.recurrent_k, - recurrent_v=artifacts.recurrent_v, - neighbor_idx=runtime.output_neighbor_idx, - neighbor_valid=runtime.output_neighbor_valid, - edge_distance=runtime.output_edge_distance, - 
edge_delay=runtime.output_edge_delay, - local_sender_idx=runtime.output_local_sender_idx, - local_receiver_idx_by_sender=runtime.output_local_receiver_idx_by_sender, - use_sparse_messages=bool(getattr(runtime, "_uses_sparse_message_backend", False)), - ) - - grad_recurrent_hidden_from_kv, recurrent_projection_param_grads = ( - runtime._run_backend_sender_kv_projection_backward_phase( - role="recurrent", - sender_cells=artifacts.recurrent_hidden_graph_order, - grad_k=grad_recurrent_k_from_output, - grad_v=grad_recurrent_v_from_output, - sequence_static_tensors=static_tensors, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - active_receiver_window=None, - owner="grouped_projection", - ) - ) - accumulate_params(recurrent_projection_param_grads) - else: - grad_output_q = None - grad_input_k_from_output = output_backward_step.grad_input_k_from_output - grad_input_v_from_output = output_backward_step.grad_input_v_from_output - grad_recurrent_hidden_from_kv = output_backward_step.grad_recurrent_hidden_from_output - grad_recurrent_hidden_graph = _accumulate_tensor_grad( - grad_recurrent_direct_graph, - grad_recurrent_hidden_from_kv, - ) - grad_recurrent_hidden_backend = None - if grad_recurrent_hidden_graph is not None: - grad_recurrent_hidden_backend = grad_recurrent_hidden_graph.index_select( - 1, - runtime.population_backend_recurrent_order, - ) - _debug_check_finite(f"{debug_label}.grad_recurrent_hidden_graph", grad_recurrent_hidden_graph) - _debug_check_finite(f"{debug_label}.grad_recurrent_hidden_backend", grad_recurrent_hidden_backend) - grad_backend_state_cache: dict[str, object] | None = None - transition_param_grads: BackendOrderTransitionParamGrads | None = None - if artifacts.backend_state_cache_before is not None: - if defer_transition_param_binding: - grad_recurrent_msg, grad_backend_state_cache, transition_param_grads = ( - run_backend_order_transition_buckets_backward_step_cached_unbound( - runtime, - artifacts.recurrent_msg_backend_order, - artifacts.backend_state_cache_before, - grad_recurrent_hidden=grad_recurrent_hidden_backend, - grad_next_backend_state_cache=grad_next_backend_state_cache, - resets=artifacts.transition_reset_step, - static_tensors=static_tensors, - need_grad_state_before=need_grad_state_before, - forward_tape_by_population=artifacts.transition_backward_tape_by_population, - ) - ) - transition_param_tuple = tuple(None for _ in trainable_params) - else: - grad_recurrent_msg, grad_backend_state_cache, transition_param_tuple = ( - run_backend_order_transition_buckets_backward_step_cached( - runtime, - artifacts.recurrent_msg_backend_order, - artifacts.backend_state_cache_before, - grad_recurrent_hidden=grad_recurrent_hidden_backend, - grad_next_backend_state_cache=grad_next_backend_state_cache, - resets=artifacts.transition_reset_step, - static_tensors=static_tensors, - trainable_param_names=trainable_param_names, - trainable_param_shapes=tuple(tuple(param.shape) for param in trainable_params), - need_grad_state_before=need_grad_state_before, - forward_tape_by_population=artifacts.transition_backward_tape_by_population, - ) - ) - grad_population_state = TensorDict( - { - name: _partial_backend_grad_state_to_population_state(backend_state) - for name, backend_state in grad_backend_state_cache.items() - if isinstance(backend_state, dict) - }, - batch_size=[], - ) - else: - grad_recurrent_msg, grad_population_state, transition_param_tuple = ( - runtime._run_backend_order_transition_buckets_backward_step( - 
artifacts.recurrent_msg_backend_order, - artifacts.population_state_before, - grad_recurrent_hidden=grad_recurrent_hidden_backend, - grad_next_population_state=grad_next_population_state, - resets=artifacts.transition_reset_step, - static_tensors=static_tensors, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - need_grad_state_before=need_grad_state_before, - ) - ) - _debug_check_finite(f"{debug_label}.grad_recurrent_msg", grad_recurrent_msg) - _debug_check_grad_tuple(f"{debug_label}.transition_param_tuple", transition_param_tuple, trainable_param_names) - accumulate_params(transition_param_tuple) - - recurrent_q_backend_order = cast(torch.Tensor, static_tensors["recurrent_q_backend_order"]) - ( - grad_recurrent_q_backend, - grad_input_k_from_recurrent, - grad_input_v_from_recurrent, - grad_recurrent_k_before, - grad_recurrent_v_before, - ) = runtime._run_backend_message_backward_phase( - grad_msg=grad_recurrent_msg, - q_subset=recurrent_q_backend_order, - input_k=artifacts.input_k, - input_v=artifacts.input_v, - recurrent_k=artifacts.recurrent_k_before, - recurrent_v=artifacts.recurrent_v_before, - neighbor_idx=runtime.recurrent_neighbor_idx_backend_order, - neighbor_valid=runtime.recurrent_neighbor_valid_backend_order, - edge_distance=runtime.recurrent_edge_distance_backend_order, - edge_delay=runtime.recurrent_edge_delay_backend_order, - local_sender_idx=runtime.recurrent_local_sender_idx_backend_order, - local_receiver_idx_by_sender=runtime.recurrent_local_receiver_idx_by_sender_backend_order, - use_sparse_messages=bool(getattr(runtime, "_uses_sparse_message_backend", False)), - ) - grad_recurrent_q = ( - None - if grad_recurrent_q_backend is None - else grad_recurrent_q_backend.index_select(0, runtime.population_backend_recurrent_inverse_order) - ) - _debug_check_finite(f"{debug_label}.grad_recurrent_q", grad_recurrent_q) - _debug_check_finite(f"{debug_label}.grad_input_k_from_recurrent", grad_input_k_from_recurrent) - _debug_check_finite(f"{debug_label}.grad_input_v_from_recurrent", grad_input_v_from_recurrent) - _debug_check_finite(f"{debug_label}.grad_recurrent_k_before", grad_recurrent_k_before) - _debug_check_finite(f"{debug_label}.grad_recurrent_v_before", grad_recurrent_v_before) - recurrent_query_backward_step = None - if defer_recurrent_query_backward: - if grad_output_q is not None: - raise RuntimeError("Temporal recurrent query deferral requires output query gradients to be pre-batched") - recurrent_query_backward_step = TemporalRecurrentQueryBackwardStep(grad_recurrent_q=grad_recurrent_q) - else: - query_param_grads = runtime._run_backend_query_param_backward_phase( - grad_recurrent_q=grad_recurrent_q, - grad_output_q=grad_output_q, - sequence_static_tensors=static_tensors, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - device=artifacts.boundary_step.device, - dtype=artifacts.boundary_step.dtype, - active_receiver_window=None, - ) - accumulate_params(query_param_grads) - - total_grad_input_k = _accumulate_tensor_grad(grad_input_k_from_output, grad_input_k_from_recurrent) - total_grad_input_v = _accumulate_tensor_grad(grad_input_v_from_output, grad_input_v_from_recurrent) - boundary_backward_step = None - if defer_boundary_backward: - grad_boundary_from_projection = None - boundary_backward_step = TemporalBoundaryBackwardStep( - grad_input_k=total_grad_input_k, - grad_input_v=total_grad_input_v, - ) - else: - grad_boundary_from_projection, boundary_param_grads = 
runtime._run_backend_boundary_public_backward_phase( - boundary_step=artifacts.boundary_step, - grad_input_k=total_grad_input_k, - grad_input_v=total_grad_input_v, - sequence_static_tensors=static_tensors, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - device=artifacts.boundary_step.device, - dtype=artifacts.boundary_step.dtype, - boundary_requires_grad=boundary_requires_grad, - ) - accumulate_params(boundary_param_grads) - - grad_hidden_before, _grad_initial_k, _grad_initial_v, initial_recurrent_raw_param_grad = ( - runtime._run_backend_initial_recurrent_backward_raw_phase( - hidden_before=artifacts.cells_prev[:, runtime._recurrent_slice, :], - initial_recurrent_k_before=None, - initial_recurrent_v_before=None, - population_reset_step=artifacts.reset_step, - grad_resolved_recurrent_k=grad_recurrent_k_before, - grad_resolved_recurrent_v=grad_recurrent_v_before, - sequence_static_tensors=static_tensors, - active_receiver_window=None, - device=artifacts.boundary_step.device, - dtype=artifacts.boundary_step.dtype, - ) - ) - _debug_check_finite(f"{debug_label}.initial_recurrent.grad_hidden_before", grad_hidden_before) - if initial_recurrent_raw_param_grad is not None: - _debug_check_finite( - f"{debug_label}.initial_recurrent.raw_grad_weight", - initial_recurrent_raw_param_grad.grad_weight, - ) - initial_recurrent_backward_step = None - if defer_initial_recurrent_param_binding: - initial_recurrent_backward_step = TemporalInitialRecurrentBackwardStep( - raw_param_grad=initial_recurrent_raw_param_grad, - ) - else: - initial_recurrent_param_grads = runtime._sender_kv_projection_param_grad_tuple_from_raw_grads( - (initial_recurrent_raw_param_grad,), - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - ) - accumulate_params(initial_recurrent_param_grads) - - grad_boundary = _accumulate_tensor_grad(grad_boundary_direct, grad_boundary_from_projection) - grad_state = TensorDict({}, batch_size=[]) - if need_grad_state_before: - grad_cells = torch.zeros_like(artifacts.cells_prev) - if grad_hidden_before is not None: - grad_cells[:, runtime._recurrent_slice, :] = grad_hidden_before.to(dtype=grad_cells.dtype) - grad_state["cells"] = grad_cells - for name in runtime._population_names: - population_grad = grad_population_state.get(name) - if isinstance(population_grad, TensorDictBase): - grad_state[name] = population_grad - return TemporalBucketStepBackwardResult( - grad_boundary=grad_boundary, - grad_state=grad_state, - param_grads=tuple(grad_param_accum), - grad_backend_state_cache=grad_backend_state_cache, - boundary_backward_step=boundary_backward_step, - recurrent_query_backward_step=recurrent_query_backward_step, - initial_recurrent_backward_step=initial_recurrent_backward_step, - transition_param_grads=transition_param_grads, - ) - - -def run_temporal_bucket_step_backward( - runtime: Any, - artifacts: TemporalBucketStepArtifacts, - *, - grad_cells_out: torch.Tensor | None, - static_tensors: dict[str, object], - trainable_params: tuple[torch.Tensor, ...], - trainable_param_names: tuple[str, ...], - need_grad_state_before: bool, - grad_next_population_state: dict[str, dict[str, torch.Tensor | None]] | None = None, - grad_next_backend_state_cache: dict[str, object] | None = None, - output_contract: TemporalOutputContract = "full_cells", - output_backward_step: TemporalOutputBackwardStep | None = None, -) -> tuple[torch.Tensor | None, TensorDict, tuple[torch.Tensor | None, ...], dict[str, object] | None]: - result = 
_run_temporal_bucket_step_backward_result( - runtime, - artifacts, - grad_cells_out=grad_cells_out, - static_tensors=static_tensors, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - need_grad_state_before=need_grad_state_before, - grad_next_population_state=grad_next_population_state, - grad_next_backend_state_cache=grad_next_backend_state_cache, - output_contract=output_contract, - output_backward_step=output_backward_step, - debug_label="step", - ) - return result.grad_boundary, result.grad_state, result.param_grads, result.grad_backend_state_cache - - -_TemporalPopulationStateSpec = tuple[tuple[str, tuple[str, ...]], ...] -_TemporalStateSpec = tuple[tuple[str, ...], _TemporalPopulationStateSpec] - - -def _flatten_temporal_state_inputs( - runtime: Any, - state: TensorDict, -) -> tuple[_TemporalStateSpec, tuple[torch.Tensor, ...]]: - top_level_keys = ("cells",) - tensors: list[torch.Tensor] = [state["cells"]] - population_specs: list[tuple[str, tuple[str, ...]]] = [] - for population_name in runtime._population_names: - population_state = state[population_name] - if not isinstance(population_state, TensorDictBase): - raise RuntimeError(f"Temporal bucket sequence requires TensorDict state for {population_name}") - keys = tuple(runtime._cell_spec_for_population(population_name).state_schema.keys) - population_specs.append((population_name, keys)) - for key in keys: - tensor = population_state[key] - if not torch.is_tensor(tensor): - raise RuntimeError(f"Temporal bucket sequence state {population_name}.{key} is not a tensor") - tensors.append(tensor) - return (top_level_keys, tuple(population_specs)), tuple(tensors) - - -def _unflatten_temporal_state( - specs: _TemporalStateSpec, - tensors: tuple[torch.Tensor, ...], -) -> TensorDict: - top_level_keys, population_specs = specs - state = TensorDict({}, batch_size=[]) - offset = 0 - for key in top_level_keys: - state[key] = tensors[offset] - offset += 1 - for population_name, keys in population_specs: - leaves: dict[str, torch.Tensor] = {} - first: torch.Tensor | None = None - for key in keys: - tensor = tensors[offset] - offset += 1 - leaves[key] = tensor - if first is None: - first = tensor - state[population_name] = TensorDict( - leaves, - batch_size=[] if first is None else list(first.shape[:2]), - device=None if first is None else first.device, - ) - return state - - -def _flatten_temporal_state_grad_outputs( - specs: _TemporalStateSpec, - grad_state: TensorDict, -) -> tuple[torch.Tensor | None, ...]: - top_level_keys, population_specs = specs - grads: list[torch.Tensor | None] = [] - for key in top_level_keys: - grad = grad_state.get(key) - grads.append(grad if torch.is_tensor(grad) else None) - for population_name, keys in population_specs: - population_grad = grad_state.get(population_name) - for key in keys: - if isinstance(population_grad, TensorDictBase): - grad = population_grad.get(key) - grads.append(grad if torch.is_tensor(grad) else None) - else: - grads.append(None) - return tuple(grads) - - -def _population_grad_dict(grad_state: TensorDict) -> dict[str, dict[str, torch.Tensor | None]]: - population_grads: dict[str, dict[str, torch.Tensor | None]] = {} - for population_name, population_grad in grad_state.items(): - if not isinstance(population_grad, TensorDictBase): - continue - population_grads[population_name] = { - key: cast(torch.Tensor | None, value) if torch.is_tensor(value) else None - for key, value in population_grad.items() - } - return population_grads - - -class 
TemporalPhysicalBackwardScanExecutor: - def __init__( - self, - runtime: Any, - *, - static_tensors: dict[str, object], - trainable_params: tuple[torch.Tensor, ...], - trainable_param_names: tuple[str, ...], - output_contract: TemporalOutputContract, - materialize_final_state: bool, - boundary_requires_grad: bool, - ) -> None: - self.runtime = runtime - self.static_tensors = static_tensors - self.trainable_params = trainable_params - self.trainable_param_names = trainable_param_names - self.output_contract = output_contract - self.materialize_final_state = bool(materialize_final_state) - self.boundary_requires_grad = bool(boundary_requires_grad) - - def _run_backward_window( - self, - *, - artifacts_window: list[TemporalBucketStepArtifacts], - window_start: int, - window_end: int, - grad_output_seq: torch.Tensor | None, - grad_carry_cells: torch.Tensor | None, - grad_next_population_state: dict[str, dict[str, torch.Tensor | None]], - grad_next_backend_state_cache: dict[str, object] | None, - ) -> TemporalBackwardWindowResult: - local_time_steps = len(artifacts_window) - if local_time_steps != int(window_end) - int(window_start): - raise RuntimeError("Temporal backward window has mismatched artifact count") - grad_param_accum: list[torch.Tensor | None] = [None] * len(self.trainable_params) - - def accumulate_params(grads: tuple[torch.Tensor | None, ...]) -> None: - for parameter_index, grad_param in enumerate(grads): - grad_param_accum[parameter_index] = _accumulate_owned_tensor_grad( - grad_param_accum[parameter_index], - grad_param, - ) - - grad_output_window = None if grad_output_seq is None else grad_output_seq[:, window_start:window_end] - output_backward_sequence = ( - run_temporal_output_backward_sequence( - self.runtime, - artifacts_window, - grad_output_seq=grad_output_window, - static_tensors=self.static_tensors, - trainable_params=self.trainable_params, - trainable_param_names=self.trainable_param_names, - output_contract=self.output_contract, - ) - if self.output_contract in {"output_cells", "pooled_output_cells"} and not self.materialize_final_state - else None - ) - if output_backward_sequence is not None: - accumulate_params(output_backward_sequence.param_grads) - defer_boundary_backward = True - defer_recurrent_query_backward = output_backward_sequence is not None - defer_initial_recurrent_param_binding = True - defer_transition_param_binding = True - boundary_backward_steps: list[TemporalBoundaryBackwardStep | None] = [None for _ in range(local_time_steps)] - recurrent_query_backward_steps: list[TemporalRecurrentQueryBackwardStep | None] = [ - None for _ in range(local_time_steps) - ] - initial_recurrent_backward_steps: list[TemporalInitialRecurrentBackwardStep | None] = [ - None for _ in range(local_time_steps) - ] - transition_param_accum: dict[str, tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]] = {} - grad_boundary_steps: list[torch.Tensor | None] = [None for _ in range(local_time_steps)] - current_grad_carry_cells = grad_carry_cells - current_grad_next_population_state = grad_next_population_state - current_grad_next_backend_state_cache = grad_next_backend_state_cache - for local_step_index in reversed(range(local_time_steps)): - global_step_index = int(window_start) + int(local_step_index) - artifacts = artifacts_window[local_step_index] - grad_cells_out = None if grad_output_seq is None else grad_output_seq[:, global_step_index] - step_output_contract = self.output_contract - if self.output_contract in {"output_cells", "pooled_output_cells"} and 
torch.is_tensor( - current_grad_carry_cells - ): - full_grad_cells_out = current_grad_carry_cells.clone() - grad_output_cells = _grad_output_cells_for_contract( - self.runtime, - artifacts.output_cells, - grad_cells_out, - self.output_contract, - ) - if grad_output_cells is not None: - full_grad_cells_out[:, self.runtime._output_slice, :] = _accumulate_tensor_grad( - full_grad_cells_out[:, self.runtime._output_slice, :], - grad_output_cells, - ) - grad_cells_out = full_grad_cells_out - step_output_contract = "full_cells" - elif self.output_contract == "full_cells": - grad_cells_out = _accumulate_tensor_grad( - cast(torch.Tensor | None, grad_cells_out), - cast(torch.Tensor | None, current_grad_carry_cells), - ) - step_backward = _run_temporal_bucket_step_backward_result( - self.runtime, - artifacts, - grad_cells_out=grad_cells_out, - static_tensors=self.static_tensors, - trainable_params=self.trainable_params, - trainable_param_names=self.trainable_param_names, - need_grad_state_before=True, - grad_next_population_state=current_grad_next_population_state, - grad_next_backend_state_cache=current_grad_next_backend_state_cache, - output_contract=step_output_contract, - output_backward_step=None - if output_backward_sequence is None - else output_backward_sequence.steps[local_step_index], - defer_boundary_backward=defer_boundary_backward, - defer_recurrent_query_backward=defer_recurrent_query_backward, - defer_initial_recurrent_param_binding=defer_initial_recurrent_param_binding, - defer_transition_param_binding=defer_transition_param_binding, - boundary_requires_grad=self.boundary_requires_grad, - debug_label=f"step={global_step_index}", - ) - _debug_check_finite(f"step={global_step_index}.grad_boundary", step_backward.grad_boundary) - _debug_check_finite(f"step={global_step_index}.grad_state.cells", step_backward.grad_state.get("cells")) - _debug_check_grad_tuple( - f"step={global_step_index}.param_grads", - step_backward.param_grads, - self.trainable_param_names, - ) - grad_boundary_steps[local_step_index] = step_backward.grad_boundary - boundary_backward_steps[local_step_index] = step_backward.boundary_backward_step - recurrent_query_backward_steps[local_step_index] = step_backward.recurrent_query_backward_step - initial_recurrent_backward_steps[local_step_index] = step_backward.initial_recurrent_backward_step - _accumulate_transition_param_grads(transition_param_accum, step_backward.transition_param_grads) - current_grad_carry_cells = step_backward.grad_state.get("cells") - current_grad_next_population_state = _population_grad_dict(step_backward.grad_state) - current_grad_next_backend_state_cache = step_backward.grad_backend_state_cache - accumulate_params(step_backward.param_grads) - boundary_backward_sequence = run_temporal_boundary_backward_sequence( - self.runtime, - artifacts_window, - tuple(boundary_backward_steps), - static_tensors=self.static_tensors, - trainable_params=self.trainable_params, - trainable_param_names=self.trainable_param_names, - boundary_requires_grad=self.boundary_requires_grad, - ) - for local_step_index, grad_boundary_step in enumerate(boundary_backward_sequence.grad_boundary_steps): - _debug_check_finite( - f"boundary_backward_sequence.step={int(window_start) + local_step_index}.grad_boundary", - grad_boundary_step, - ) - if grad_boundary_step is not None: - grad_boundary_steps[local_step_index] = _accumulate_tensor_grad( - grad_boundary_steps[local_step_index], - grad_boundary_step, - ) - _debug_check_finite( - f"window.step={int(window_start) + 
local_step_index}.accumulated_grad_boundary", - grad_boundary_steps[local_step_index], - ) - _debug_check_grad_tuple( - "boundary_backward_sequence.param_grads", - boundary_backward_sequence.param_grads, - self.trainable_param_names, - ) - accumulate_params(boundary_backward_sequence.param_grads) - if defer_recurrent_query_backward: - recurrent_query_param_grads = run_temporal_recurrent_query_backward_sequence( - self.runtime, - tuple(recurrent_query_backward_steps), - static_tensors=self.static_tensors, - trainable_params=self.trainable_params, - trainable_param_names=self.trainable_param_names, - device=artifacts_window[0].boundary_step.device, - dtype=artifacts_window[0].boundary_step.dtype, - ) - _debug_check_grad_tuple( - "recurrent_query_param_grads", - recurrent_query_param_grads, - self.trainable_param_names, - ) - accumulate_params(recurrent_query_param_grads) - initial_recurrent_param_grads = run_temporal_initial_recurrent_param_binding_sequence( - self.runtime, - tuple(initial_recurrent_backward_steps), - trainable_params=self.trainable_params, - trainable_param_names=self.trainable_param_names, - ) - _debug_check_grad_tuple( - "initial_recurrent_param_grads", - initial_recurrent_param_grads, - self.trainable_param_names, - ) - accumulate_params(initial_recurrent_param_grads) - transition_param_grads = bind_temporal_transition_param_grads( - self.runtime, - transition_param_accum, - trainable_params=self.trainable_params, - trainable_param_names=self.trainable_param_names, - ) - _debug_check_grad_tuple( - "transition_param_grads", - transition_param_grads, - self.trainable_param_names, - ) - accumulate_params(transition_param_grads) - for local_step_index, grad_boundary_step in enumerate(grad_boundary_steps): - _debug_check_finite( - f"window.step={int(window_start) + local_step_index}.final_grad_boundary", - grad_boundary_step, - ) - _debug_check_finite("window.grad_carry_cells", cast(torch.Tensor | None, current_grad_carry_cells)) - _debug_check_grad_tuple("window.param_grads", tuple(grad_param_accum), self.trainable_param_names) - return TemporalBackwardWindowResult( - grad_boundary_steps=tuple(grad_boundary_steps), - grad_carry_cells=cast(torch.Tensor | None, current_grad_carry_cells), - grad_next_population_state=current_grad_next_population_state, - grad_next_backend_state_cache=current_grad_next_backend_state_cache, - param_grads=tuple(grad_param_accum), - ) - - def run( - self, - *, - boundary_seq: torch.Tensor, - artifact_store: TemporalArtifactStore, - population_resets: torch.Tensor | None, - transition_resets: torch.Tensor | None = None, - grad_output_seq: torch.Tensor | None, - grad_final_state: TensorDict, - ) -> TemporalPhysicalBackwardScanResult: - time_steps = int(boundary_seq.shape[1]) - artifacts_by_step = artifact_store.artifacts_by_step - stored_artifacts = artifacts_by_step is not None - if stored_artifacts and len(cast(list[TemporalBucketStepArtifacts], artifacts_by_step)) != time_steps: - raise RuntimeError("Temporal artifact store has mismatched step count") - grad_carry_cells = grad_final_state.get("cells") - grad_next_population_state = _population_grad_dict(grad_final_state) - grad_next_backend_state_cache: dict[str, object] | None = None - if grad_next_population_state: - grad_next_backend_state_cache = { - name: _population_grad_state_to_backend_grad_state(self.runtime, name, population_grad) - for name, population_grad in grad_next_population_state.items() - } - grad_boundary_seq = torch.zeros_like(boundary_seq) - grad_param_accum: 
list[torch.Tensor | None] = [None] * len(self.trainable_params) - if stored_artifacts: - artifact_windows = ((0, time_steps),) - else: - artifact_windows = _temporal_artifact_windows( - time_steps=time_steps, - checkpoint_stride=artifact_store.checkpoint_stride, - window_len=artifact_store.recompute_window_len, - ) - - self.runtime._begin_backend_owner_timing(boundary_seq.device) - with torch.profiler.record_function("fabric.backward.physical_temporal_bucket_sequence"): - for window_start, window_end in reversed(artifact_windows): - if artifacts_by_step is not None: - artifacts_window = artifacts_by_step[window_start:window_end] - else: - checkpoint = _nearest_temporal_artifact_checkpoint(artifact_store.checkpoints, window_start) - with self.runtime._backend_owner_timing("temporal_artifact_recompute"): - artifacts_window = _recompute_temporal_bucket_artifact_window( - self.runtime, - boundary_seq=boundary_seq, - population_resets=population_resets, - transition_resets=transition_resets, - static_tensors=self.static_tensors, - checkpoint=checkpoint, - start_step=window_start, - end_step=window_end, - transition_tape_mode=artifact_store.transition_tape_mode, - ) - window_result = self._run_backward_window( - artifacts_window=artifacts_window, - window_start=window_start, - window_end=window_end, - grad_output_seq=grad_output_seq, - grad_carry_cells=cast(torch.Tensor | None, grad_carry_cells), - grad_next_population_state=grad_next_population_state, - grad_next_backend_state_cache=grad_next_backend_state_cache, - ) - for local_step_index, grad_boundary_step in enumerate(window_result.grad_boundary_steps): - _debug_check_finite( - f"scan.window={window_start}:{window_end}.step={window_start + local_step_index}.grad_boundary", - grad_boundary_step, - ) - if grad_boundary_step is not None: - grad_boundary_seq[:, window_start + local_step_index] = grad_boundary_step.to( - dtype=grad_boundary_seq.dtype - ) - _debug_check_finite( - f"scan.grad_boundary_seq.step={window_start + local_step_index}", - grad_boundary_seq[:, window_start + local_step_index], - ) - grad_carry_cells = window_result.grad_carry_cells - _debug_check_finite("scan.grad_carry_cells", cast(torch.Tensor | None, grad_carry_cells)) - grad_next_population_state = window_result.grad_next_population_state - grad_next_backend_state_cache = window_result.grad_next_backend_state_cache - for parameter_index, grad_param in enumerate(window_result.param_grads): - grad_param_accum[parameter_index] = _accumulate_owned_tensor_grad( - grad_param_accum[parameter_index], - grad_param, - ) - _debug_check_grad_tuple("scan.param_grads", tuple(grad_param_accum), self.trainable_param_names) - self.runtime._finish_backend_owner_timing() - _debug_check_finite("scan.grad_boundary_seq", grad_boundary_seq) - return TemporalPhysicalBackwardScanResult( - grad_boundary_seq=grad_boundary_seq, - grad_carry_cells=cast(torch.Tensor | None, grad_carry_cells), - grad_next_population_state=grad_next_population_state, - param_grads=tuple(grad_param_accum), - ) - - -def run_temporal_bucket_sequence_physical_autograd( - runtime: Any, - *, - boundary_seq: torch.Tensor, - state: TensorDict, - population_resets: torch.Tensor | None, - transition_resets: torch.Tensor | None = None, - static_tensors: dict[str, object], - planned_backward_execution: Any | None, - materialize_final_state: bool, - output_contract: TemporalOutputContract = "full_cells", -) -> tuple[torch.Tensor, TensorDict]: - state_specs, state_tensors = _flatten_temporal_state_inputs(runtime, state) - 
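# --- Editorial sketch (not part of the diff) ------------------------------
# Illustrates the checkpoint-stride / recompute-window pattern used by the
# backward scan above on a toy elementwise recurrence: forward runs under
# no_grad keeping only per-window state checkpoints, and backward walks the
# windows in reverse, recomputing each window with autograd enabled before
# backpropagating the carried gradient. Names such as `toy_windows` and
# `window_len` are hypothetical; the real executor separates checkpoint
# stride from recompute window length and recomputes full
# TemporalBucketStepArtifacts windows rather than a single state tensor.
import torch


def toy_windows(time_steps: int, window_len: int) -> list[tuple[int, int]]:
    # Partition [0, time_steps) into contiguous [start, end) windows.
    return [(s, min(s + window_len, time_steps)) for s in range(0, time_steps, window_len)]


def toy_backward_with_recompute(x_seq: torch.Tensor, window_len: int) -> tuple[torch.Tensor, torch.Tensor]:
    batch, time_steps = x_seq.shape[0], x_seq.shape[1]
    checkpoints: dict[int, torch.Tensor] = {}
    state = torch.zeros(batch, x_seq.shape[2])
    # Forward under no_grad: keep one checkpoint of the carried state per window.
    with torch.no_grad():
        for start, end in toy_windows(time_steps, window_len):
            checkpoints[start] = state.clone()
            for t in range(start, end):
                state = torch.tanh(state + x_seq[:, t])
    grad_state = torch.ones(batch, x_seq.shape[2])   # gradient arriving at the final state
    grad_x = torch.zeros_like(x_seq)
    # Backward: reverse window order, recompute from the window's checkpoint.
    for start, end in reversed(toy_windows(time_steps, window_len)):
        carry = checkpoints[start].clone().requires_grad_(True)
        window_inputs = [x_seq[:, t].detach().requires_grad_(True) for t in range(start, end)]
        cur = carry
        for step_input in window_inputs:
            cur = torch.tanh(cur + step_input)
        grads = torch.autograd.grad(cur, [carry, *window_inputs], grad_outputs=grad_state)
        grad_state = grads[0]                         # carried gradient into the previous window
        for local_t, grad_input in enumerate(grads[1:]):
            grad_x[:, start + local_t] = grad_input
    return grad_x, grad_state                         # grad_state is d(loss)/d(initial state)


grad_x, grad_state0 = toy_backward_with_recompute(torch.randn(2, 12, 4), window_len=4)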
trainable_items = tuple((name, param) for name, param in runtime.named_parameters() if param.requires_grad) - outputs = _TemporalBucketSequenceFunction.apply( - runtime, - static_tensors, - planned_backward_execution, - state_specs, - tuple(name for name, _param in trainable_items), - materialize_final_state, - output_contract, - boundary_seq, - population_resets, - transition_resets, - *state_tensors, - *(param for _name, param in trainable_items), - ) - output_seq = cast(torch.Tensor, outputs[0]) - if not materialize_final_state: - return output_seq, TensorDict({}, batch_size=[]) - state_tensor_count = len(state_tensors) - final_state = _unflatten_temporal_state( - state_specs, - cast(tuple[torch.Tensor, ...], tuple(outputs[1 : 1 + state_tensor_count])), - ) - return output_seq, final_state - - -class _TemporalBucketSequenceFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx: Any, - runtime: Any, - static_tensors: dict[str, object], - planned_backward_execution: Any | None, - state_specs: _TemporalStateSpec, - trainable_param_names: tuple[str, ...], - materialize_final_state: bool, - output_contract: TemporalOutputContract, - boundary_seq: torch.Tensor, - population_resets: torch.Tensor | None, - transition_resets: torch.Tensor | None, - *state_tensors_and_params: torch.Tensor, - ) -> tuple[torch.Tensor, ...]: - state_tensor_count = 1 + sum(len(keys) for _population_name, keys in state_specs[1]) - state_tensors = tuple(state_tensors_and_params[:state_tensor_count]) - trainable_params = tuple(state_tensors_and_params[state_tensor_count:]) - ctx.runtime = runtime - ctx.static_tensors = static_tensors - ctx.planned_backward_execution = planned_backward_execution - ctx.state_specs = state_specs - ctx.trainable_param_names = trainable_param_names - ctx.materialize_final_state = bool(materialize_final_state) - ctx.output_contract = output_contract - ctx.state_tensor_count = state_tensor_count - ctx.trainable_param_count = len(trainable_params) - ctx.has_resets = torch.is_tensor(population_resets) - ctx.has_transition_resets = torch.is_tensor(transition_resets) - ctx.save_for_backward( - boundary_seq, - *((population_resets,) if torch.is_tensor(population_resets) else ()), - *((transition_resets,) if torch.is_tensor(transition_resets) else ()), - *state_tensors, - *trainable_params, - ) - running_state = _unflatten_temporal_state(state_specs, state_tensors) - output_steps: list[torch.Tensor] = [] - artifact_store: TemporalArtifactStore | None = None - artifacts_by_step: list[TemporalBucketStepArtifacts] | None = None - artifact_checkpoints: dict[int, TemporalArtifactCheckpoint] = { - 0: _make_temporal_artifact_checkpoint( - step_index=0, - state=running_state, - population_state_cache=None, - recurrent_k=None, - recurrent_v=None, - ) - } - step_population_state_cache = runtime._prepare_stream_step_population_cache( - running_state, - batch=int(boundary_seq.shape[0]), - device=boundary_seq.device, - dtype=boundary_seq.dtype, - ) - artifact_checkpoints[0] = _make_temporal_artifact_checkpoint( - step_index=0, - state=running_state, - population_state_cache=step_population_state_cache, - recurrent_k=None, - recurrent_v=None, - ) - transition_tape_policy = temporal_transition_tape_policy( - runtime, - static_tensors, - batch_size=int(boundary_seq.shape[0]), - time_steps=int(boundary_seq.shape[1]), - device=boundary_seq.device, - dtype_bytes=int(boundary_seq.element_size()), - tape_policy_bin=getattr(planned_backward_execution, "tape_policy_bin", None), - ) - 
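# --- Editorial sketch (not part of the diff) ------------------------------
# Minimal illustration of the flatten-through-autograd.Function pattern used
# above: a nested state mapping is flattened into a positional tuple of
# tensors plus a key spec so it can cross a torch.autograd.Function boundary,
# and backward() returns gradients aligned one-to-one with forward()'s
# arguments (None for non-tensor arguments such as the spec). The names here
# are hypothetical simplifications of the real spec/flatten helpers.
import torch


def flatten_state(state: dict[str, torch.Tensor]) -> tuple[tuple[str, ...], tuple[torch.Tensor, ...]]:
    keys = tuple(sorted(state))
    return keys, tuple(state[key] for key in keys)


def unflatten_state(keys: tuple[str, ...], tensors: tuple[torch.Tensor, ...]) -> dict[str, torch.Tensor]:
    return dict(zip(keys, tensors))


class _ToySequenceFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, keys: tuple[str, ...], boundary: torch.Tensor, *state_tensors: torch.Tensor):
        ctx.keys = keys
        ctx.save_for_backward(boundary, *state_tensors)
        state = unflatten_state(keys, state_tensors)
        # Toy "step": the output depends on the boundary and every state leaf.
        return boundary + sum(state.values())

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        boundary, *state_tensors = ctx.saved_tensors
        # One gradient slot per forward argument: None for `keys`, then the
        # boundary, then one slot per flattened state tensor.
        return (None, grad_out, *[grad_out.clone() for _ in state_tensors])


state = {"cells": torch.randn(2, 4, requires_grad=True), "gates": torch.randn(2, 4, requires_grad=True)}
keys, tensors = flatten_state(state)
out = _ToySequenceFunction.apply(keys, torch.randn(2, 4, requires_grad=True), *tensors)
out.sum().backward()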
runtime._last_flat_bucket_temporal_recurrent_kv_carry_reuse = int(boundary_seq.shape[1]) > 1 - runtime._last_flat_bucket_transition_tape_mode = transition_tape_policy.mode - runtime._last_flat_bucket_transition_tape_reason = transition_tape_policy.reason - runtime._last_flat_bucket_temporal_artifact_mode = None - runtime._last_flat_bucket_temporal_artifact_reason = None - runtime._last_flat_bucket_temporal_artifact_checkpoint_stride = None - runtime._last_flat_bucket_temporal_artifact_recompute_window_len = None - runtime._last_flat_bucket_temporal_artifact_checkpoint_count = None - input_sender_weight = cast(torch.Tensor | None, static_tensors["input_sender_input_to_kv_weight"]) - input_group_weight = cast(torch.Tensor | None, static_tensors["input_group_input_to_kv_weight"]) - with torch.profiler.record_function("fabric.physical.message.input_projection_sequence"): - input_k_seq, input_v_seq = runtime._project_sender_kv_from_cells_sequence( - boundary_seq, - sender_input_to_kv_weight=input_sender_weight, - grouped_sender_input_to_kv_weight=input_group_weight, - sender_group_size=runtime._input_sender_kv_group_size, - ) - input_k_seq = input_k_seq.transpose(0, 1).contiguous() - input_v_seq = input_v_seq.transpose(0, 1).contiguous() - with torch.no_grad(): - running_recurrent_k: torch.Tensor | None = None - running_recurrent_v: torch.Tensor | None = None - for step_index in range(int(boundary_seq.shape[1])): - reset_step = population_resets[:, step_index] if torch.is_tensor(population_resets) else None - transition_reset_step = ( - transition_resets[:, step_index] if torch.is_tensor(transition_resets) else reset_step - ) - if ( - artifact_store is not None - and artifact_store.mode == "recompute_step_artifacts" - and step_index > 0 - and step_index % max(1, int(artifact_store.checkpoint_stride)) == 0 - ): - artifact_checkpoints[step_index] = _make_temporal_artifact_checkpoint( - step_index=step_index, - state=running_state, - population_state_cache=step_population_state_cache, - recurrent_k=running_recurrent_k, - recurrent_v=running_recurrent_v, - ) - running_recurrent_k, running_recurrent_v = _apply_temporal_recurrent_kv_reset( - reset_step=reset_step, - recurrent_k=running_recurrent_k, - recurrent_v=running_recurrent_v, - ) - artifacts = compute_temporal_bucket_step_artifacts( - runtime, - boundary_step=boundary_seq[:, step_index], - state=running_state, - reset_step=reset_step, - transition_reset_step=transition_reset_step, - static_tensors=static_tensors, - step_population_state_cache=step_population_state_cache, - input_k_step=input_k_seq[step_index], - input_v_step=input_v_seq[step_index], - recurrent_k_before_step=running_recurrent_k, - recurrent_v_before_step=running_recurrent_v, - transition_tape_mode=transition_tape_policy.mode, - ) - running_recurrent_k = artifacts.recurrent_k - running_recurrent_v = artifacts.recurrent_v - if artifact_store is None: - artifact_store = _temporal_artifact_store_policy( - runtime, - first_artifact=artifacts, - time_steps=int(boundary_seq.shape[1]), - device=boundary_seq.device, - ) - artifacts_by_step = artifact_store.artifacts_by_step - runtime._last_flat_bucket_temporal_artifact_mode = artifact_store.mode - runtime._last_flat_bucket_temporal_artifact_reason = artifact_store.reason - runtime._last_flat_bucket_temporal_artifact_checkpoint_stride = artifact_store.checkpoint_stride - runtime._last_flat_bucket_temporal_artifact_recompute_window_len = ( - artifact_store.recompute_window_len - ) - 
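# --- Editorial sketch (not part of the diff) ------------------------------
# Shows the carried-KV reset pattern applied each step above: batch rows whose
# reset flag is set have their carried recurrent K/V banks zeroed before the
# step consumes them. A minimal sketch with assumed shapes
# (batch, num_cells, dim); the real helper also handles missing banks.
import torch


def reset_carried_kv(
    reset_step: torch.Tensor | None,     # (batch,) reset flags, or None for "no resets"
    recurrent_k: torch.Tensor | None,    # (batch, cells, head_dim), or None before the first step
    recurrent_v: torch.Tensor | None,    # (batch, cells, value_dim), or None before the first step
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
    if reset_step is None or recurrent_k is None or recurrent_v is None:
        return recurrent_k, recurrent_v
    mask = reset_step.to(device=recurrent_k.device, dtype=torch.bool).view(-1, 1, 1)
    # torch.where keeps non-reset rows untouched and zeroes reset rows,
    # without mutating the carried tensors in place.
    return (
        torch.where(mask, torch.zeros_like(recurrent_k), recurrent_k),
        torch.where(mask, torch.zeros_like(recurrent_v), recurrent_v),
    )


resets = torch.tensor([1, 0])
k, v = reset_carried_kv(resets, torch.randn(2, 3, 4), torch.randn(2, 3, 4))
assert torch.all(k[0] == 0) and torch.any(k[1] != 0)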
runtime._last_flat_bucket_temporal_artifact_checkpoint_count = len(artifact_store.checkpoints) - if artifacts_by_step is not None: - artifacts_by_step.append(artifacts) - if output_contract == "full_cells": - output_steps.append(artifacts.cells_out) - elif output_contract == "output_cells": - output_steps.append(artifacts.output_cells) - elif output_contract == "pooled_output_cells": - output_steps.append(runtime._pool_output_ports(artifacts.output_cells.unsqueeze(1)).squeeze(1)) - else: - raise RuntimeError(f"Unsupported temporal output contract {output_contract!r}") - running_state = TensorDict( - { - "cells": artifacts.cells_out, - **{name: artifacts.population_state_after[name] for name in runtime._population_names}, - }, - batch_size=[], - ) - if artifact_store is None: - raise RuntimeError("Temporal bucket sequence produced no artifacts") - if artifact_store.mode == "recompute_step_artifacts": - artifact_store = TemporalArtifactStore( - mode=artifact_store.mode, - artifacts_by_step=None, - checkpoints=artifact_checkpoints, - checkpoint_stride=artifact_store.checkpoint_stride, - recompute_window_len=artifact_store.recompute_window_len, - transition_tape_mode=artifact_store.transition_tape_mode, - reason=artifact_store.reason, - stored_artifact_step_bytes=artifact_store.stored_artifact_step_bytes, - ) - else: - artifact_store = TemporalArtifactStore( - mode=artifact_store.mode, - artifacts_by_step=artifacts_by_step or [], - checkpoints={}, - checkpoint_stride=artifact_store.checkpoint_stride, - recompute_window_len=artifact_store.recompute_window_len, - transition_tape_mode=transition_tape_policy.mode, - reason=artifact_store.reason, - stored_artifact_step_bytes=artifact_store.stored_artifact_step_bytes, - ) - runtime._last_flat_bucket_temporal_artifact_checkpoint_count = len(artifact_store.checkpoints) - ctx.artifact_store = artifact_store - output_seq = torch.stack(output_steps, dim=1) - if not materialize_final_state: - return (output_seq,) - final_state_tensors = _flatten_temporal_state_grad_outputs(state_specs, running_state) - if any(tensor is None for tensor in final_state_tensors): - raise RuntimeError("Temporal bucket sequence produced incomplete final state") - return (output_seq, *cast(tuple[torch.Tensor, ...], final_state_tensors)) - - @staticmethod - def backward( - ctx: Any, - *grad_outputs: torch.Tensor | None, - ) -> tuple[object, ...]: - saved = ctx.saved_tensors - offset = 0 - boundary_seq = saved[offset] - offset += 1 - population_resets = None - if ctx.has_resets: - population_resets = saved[offset] - offset += 1 - transition_resets = None - if ctx.has_transition_resets: - transition_resets = saved[offset] - offset += 1 - state_tensor_count = int(ctx.state_tensor_count) - state_tensors = tuple(saved[offset : offset + state_tensor_count]) - offset += state_tensor_count - trainable_params = tuple(saved[offset:]) - - artifact_store = cast(TemporalArtifactStore, ctx.artifact_store) - _validate_temporal_physical_backward_plan(ctx.planned_backward_execution) - - grad_output_seq = grad_outputs[0] - grad_final_state = ( - _unflatten_temporal_state( - ctx.state_specs, - cast(tuple[torch.Tensor, ...], tuple(grad_outputs[1 : 1 + state_tensor_count])), - ) - if ctx.materialize_final_state - else TensorDict({}, batch_size=[]) - ) - scan_result = TemporalPhysicalBackwardScanExecutor( - ctx.runtime, - static_tensors=ctx.static_tensors, - trainable_params=cast(tuple[torch.Tensor, ...], trainable_params), - trainable_param_names=ctx.trainable_param_names, - 
output_contract=cast(TemporalOutputContract, ctx.output_contract), - materialize_final_state=ctx.materialize_final_state, - boundary_requires_grad=boundary_seq.requires_grad, - ).run( - boundary_seq=boundary_seq, - artifact_store=artifact_store, - population_resets=cast(torch.Tensor | None, population_resets), - transition_resets=cast(torch.Tensor | None, transition_resets), - grad_output_seq=grad_output_seq, - grad_final_state=grad_final_state, - ) - _debug_check_finite("autograd.grad_boundary_seq", scan_result.grad_boundary_seq) - _debug_check_grad_tuple("autograd.param_grads", scan_result.param_grads, ctx.trainable_param_names) - - grad_state_tensors = _flatten_temporal_state_grad_outputs( - ctx.state_specs, - TensorDict( - { - "cells": scan_result.grad_carry_cells - if torch.is_tensor(scan_result.grad_carry_cells) - else torch.zeros_like(state_tensors[0]), - **{ - name: TensorDict( - {key: value for key, value in population_grad.items() if torch.is_tensor(value)}, - batch_size=[], - ) - for name, population_grad in scan_result.grad_next_population_state.items() - }, - }, - batch_size=[], - ), - ) - return ( - None, - None, - None, - None, - None, - None, - None, - scan_result.grad_boundary_seq, - None, - None, - *grad_state_tensors, - *scan_result.param_grads, - ) - - -__all__ = [ - "TemporalBucketStepArtifacts", - "compute_temporal_bucket_step_artifacts", - "run_temporal_bucket_sequence_physical_autograd", - "run_temporal_bucket_step_backward", - "temporal_transition_tape_policy", -] diff --git a/src/cortical/fabric/backend/cuda/sequence_surface/temporal_executor.py b/src/cortical/fabric/backend/cuda/sequence_surface/temporal_executor.py deleted file mode 100644 index 5dd067dd..00000000 --- a/src/cortical/fabric/backend/cuda/sequence_surface/temporal_executor.py +++ /dev/null @@ -1,1322 +0,0 @@ -from __future__ import annotations - -from typing import Any, Literal, cast - -import torch -from tensordict import TensorDict - -from cortical.fabric.backend.cuda.sequence_surface.support import ( - _ReceiverWindowSpec, - _transition_supports_receiver_local_dependency_window, -) -from cortical.fabric.backend.cuda.sequence_surface.temporal_backward import ( - TemporalOutputContract as _TemporalOutputContract, -) -from cortical.fabric.backend.cuda.sequence_surface.temporal_backward import ( - run_temporal_bucket_sequence_physical_autograd as _run_temporal_bucket_sequence_physical_autograd, -) -from cortical.fabric.backend.cuda.sequence_surface.temporal_buckets import ( - PHYSICAL_TEMPORAL_BACKWARD_EXECUTOR as _PHYSICAL_TEMPORAL_BACKWARD_EXECUTOR, -) -from cortical.fabric.backend.cuda.sequence_surface.temporal_buckets import ( - PHYSICAL_TRANSITION_BACKWARD_EXECUTOR as _PHYSICAL_TRANSITION_BACKWARD_EXECUTOR, -) -from cortical.fabric.backend.cuda.sequence_surface.temporal_buckets import ( - active_population_names as _active_population_names, -) -from cortical.fabric.backend.cuda.sequence_surface.temporal_buckets import ( - temporal_backward_owner_plan as _temporal_backward_owner_plan, -) -from cortical.fabric.backend.cuda.sequence_surface.temporal_buckets import ( - with_cached_population_static_tensors as _with_cached_population_static_tensors, -) -from cortical.fabric.backend.graph_regions import ClosedRecurrentRegion -from cortical.fabric.backend.surfaces import BackendExecutionRecord - - -def _expand_temporal_boundary_for_constant_inner_steps( - *, - boundary_seq: torch.Tensor, - population_resets: torch.Tensor | None, - inner_steps: int, -) -> tuple[torch.Tensor, torch.Tensor | None, 
torch.Tensor | None, torch.Tensor | None]: - if int(inner_steps) <= 1: - return boundary_seq, population_resets, population_resets, None - expanded_boundary_seq = torch.repeat_interleave(boundary_seq, repeats=int(inner_steps), dim=1) - expanded_resets = None - expanded_transition_resets = None - if torch.is_tensor(population_resets): - batch_size, time_steps = int(population_resets.shape[0]), int(population_resets.shape[1]) - expanded_resets = population_resets.new_zeros((batch_size, time_steps * int(inner_steps))) - expanded_resets[:, :: int(inner_steps)] = population_resets - expanded_transition_resets = ( - population_resets.unsqueeze(2) - .expand(batch_size, time_steps, int(inner_steps)) - .reshape(batch_size, time_steps * int(inner_steps)) - ) - output_step_indices = torch.arange( - int(inner_steps) - 1, - int(boundary_seq.shape[1]) * int(inner_steps), - int(inner_steps), - device=boundary_seq.device, - dtype=torch.long, - ) - return expanded_boundary_seq, expanded_resets, expanded_transition_resets, output_step_indices - - -def execute_temporal_bucket_sequence( - runtime: Any, - *, - hidden_seq: torch.Tensor | None, - boundary_seq: torch.Tensor | None, - state: TensorDict, - population_resets: torch.Tensor | None, - step_reset_flags: list[bool] | None, - k: int | torch.Tensor | None, - constant_k: int | None, - batch_size: int, - time_steps: int, - step_mode: bool, - capture_active: bool, - static_tensors: dict[str, object], - grad_path: bool, - materialize_final_state: bool, - backend_population_state_is_fresh: bool, - use_fresh_backend_population_cache: bool, - tape_policy: Any | None = None, - output_contract: _TemporalOutputContract = "full_cells", - output_boundary: Literal["sequence", "terminal"] = "sequence", -) -> tuple[torch.Tensor, TensorDict]: - runtime._last_flat_bucket_transition_backward_executor = None - runtime._last_flat_bucket_temporal_recurrent_kv_carry_reuse = False - runtime._last_flat_bucket_temporal_artifact_mode = None - runtime._last_flat_bucket_temporal_artifact_reason = None - runtime._last_flat_bucket_temporal_artifact_checkpoint_stride = None - runtime._last_flat_bucket_temporal_artifact_recompute_window_len = None - runtime._last_flat_bucket_temporal_artifact_checkpoint_count = None - q = cast(torch.Tensor, static_tensors["q"]) - recurrent_q = cast(torch.Tensor, static_tensors["recurrent_q"]) - output_q = cast(torch.Tensor, static_tensors["output_q"]) - gathered_kv_weight = cast(torch.Tensor, static_tensors["gathered_kv_weight"]) - sender_kv_weight = cast(torch.Tensor | None, static_tensors["sender_kv_weight"]) - sender_input_to_kv_weight = cast(torch.Tensor | None, static_tensors["sender_input_to_kv_weight"]) - input_sender_input_to_kv_weight = cast(torch.Tensor | None, static_tensors["input_sender_input_to_kv_weight"]) - sender_group_input_to_kv_weight = cast(torch.Tensor | None, static_tensors["sender_group_input_to_kv_weight"]) - recurrent_sender_input_to_kv_weight = cast( - torch.Tensor | None, - static_tensors["recurrent_sender_input_to_kv_weight"], - ) - recurrent_group_input_to_kv_weight = cast( - torch.Tensor | None, - static_tensors["recurrent_group_input_to_kv_weight"], - ) - input_group_input_to_kv_weight = cast(torch.Tensor | None, static_tensors["input_group_input_to_kv_weight"]) - value_to_cell_weight = cast(torch.Tensor, static_tensors["value_to_cell_weight"]) - fused_recurrent_value_to_cell_weight = cast( - torch.Tensor | None, - static_tensors["fused_recurrent_value_to_cell_weight"], - ) - value_to_output_weight = cast(torch.Tensor, 
static_tensors["value_to_output_weight"]) - cell_bias = cast(torch.Tensor, static_tensors["cell_bias"]) - recurrent_cell_bias = cast(torch.Tensor, static_tensors["recurrent_cell_bias"]) - fused_recurrent_cell_bias = cast(torch.Tensor, static_tensors["fused_recurrent_cell_bias"]) - fused_recurrent_population_input = bool(static_tensors["fused_recurrent_population_input"]) - population_materialized = cast(dict[str, object | None], static_tensors["population_materialized"]) - device = boundary_seq.device if boundary_seq is not None else cast(torch.Tensor, hidden_seq).device - dtype = boundary_seq.dtype if boundary_seq is not None else cast(torch.Tensor, hidden_seq).dtype - static_tensors = _with_cached_population_static_tensors(runtime, static_tensors) - - running_state = state - last_k_rows: torch.Tensor | None = None - if ( - not step_mode - and hidden_seq is None - and boundary_seq is not None - and constant_k == 1 - and not grad_path - and not materialize_final_state - and backend_population_state_is_fresh - and output_contract in {"output_cells", "pooled_output_cells"} - ): - active_output_result = execute_temporal_bucket_active_output_window( - runtime, - boundary_seq=boundary_seq, - resets=population_resets, - static_tensors=static_tensors, - output_boundary=output_boundary, - ) - if active_output_result is not None: - output_cells, next_state = active_output_result - if output_contract == "pooled_output_cells": - output_cells = runtime._pool_output_ports(output_cells) - record_temporal_bucket_sequence_surface_execution( - runtime, - batch_size=batch_size, - time_steps=time_steps, - inner_steps=1, - training=grad_path, - output_boundary=output_boundary, - active_receiver_window_mode=_flat_bucket_active_output_region_mode(runtime), - active_receiver_window_offset=str(int(_flat_bucket_active_output_region(runtime).start)), - active_receiver_window_count=str(int(_flat_bucket_active_output_region(runtime).count)), - ) - return output_cells, next_state - if grad_path and not step_mode and hidden_seq is None and boundary_seq is not None and constant_k is not None: - inner_steps = int(constant_k) - if inner_steps <= 0: - raise RuntimeError("Temporal bucket sequence physical training requires positive constant K") - physical_boundary_seq, physical_resets, physical_transition_resets, output_step_indices = ( - _expand_temporal_boundary_for_constant_inner_steps( - boundary_seq=boundary_seq, - population_resets=population_resets, - inner_steps=inner_steps, - ) - ) - planned_backward_execution = runtime.plan_backend_backward_execution( - batch_size=batch_size, - time_steps=time_steps, - inner_steps=inner_steps, - training=True, - tape_policy=tape_policy, - device=device, - ) - output_seq, next_state = _run_temporal_bucket_sequence_physical_autograd( - runtime, - boundary_seq=physical_boundary_seq, - state=running_state, - population_resets=physical_resets, - transition_resets=physical_transition_resets, - static_tensors=static_tensors, - planned_backward_execution=planned_backward_execution, - materialize_final_state=materialize_final_state, - output_contract=output_contract, - ) - if output_step_indices is not None: - output_seq = output_seq.index_select(1, output_step_indices) - runtime._last_flat_bucket_transition_backward_executor = _PHYSICAL_TEMPORAL_BACKWARD_EXECUTOR - record_temporal_bucket_step_loop_execution( - runtime, - batch_size=batch_size, - time_steps=time_steps, - inner_steps=inner_steps, - training=grad_path, - materialize_final_state=materialize_final_state, - 
output_boundary=output_boundary, - output_contract=output_contract, - ) - return output_seq, next_state - if step_mode: - input_k_seq = None - input_v_seq = None - if boundary_seq is not None: - with torch.profiler.record_function("fabric.physical.message.input_projection_sequence"): - input_k_seq, input_v_seq = runtime._project_sender_kv_from_cells_sequence( - boundary_seq, - sender_input_to_kv_weight=input_sender_input_to_kv_weight, - grouped_sender_input_to_kv_weight=input_group_input_to_kv_weight, - sender_group_size=runtime._input_sender_kv_group_size, - ) - input_k_seq = input_k_seq.transpose(0, 1).contiguous() - input_v_seq = input_v_seq.transpose(0, 1).contiguous() - if constant_k is None: - k_rows, max_steps = runtime._resolve_step_k( - k, - batch_size=batch_size, - time_steps=time_steps, - step_index=0, - device=device, - ) - all_active = None - else: - k_rows = torch.full((batch_size,), constant_k, device=device, dtype=torch.long) - max_steps = constant_k - all_active = constant_k > 0 - last_k_rows = k_rows - step_resets = population_resets[:, 0] if population_resets is not None else None - step_population_state_cache = None - if use_fresh_backend_population_cache: - step_population_state_cache = runtime._prepare_fresh_stream_step_population_cache( - batch=batch_size, - device=device, - dtype=dtype, - ) - elif constant_k == 1: - step_population_state_cache = runtime._prepare_stream_step_population_cache( - running_state, - batch=batch_size, - device=device, - dtype=dtype, - ) - y_step, next_state = runtime._forward_stream_step( - hidden_step=hidden_seq[:, 0] if hidden_seq is not None else None, - state=running_state, - resets=step_resets, - has_resets=step_reset_flags[0] if step_reset_flags is not None else None, - capture_active=capture_active, - k_rows=k_rows, - max_steps=max_steps, - all_active=all_active if max_steps <= 1 else None, - q=q, - recurrent_q=recurrent_q, - output_q=output_q, - gathered_kv_weight=gathered_kv_weight, - sender_kv_weight=sender_kv_weight, - sender_input_to_kv_weight=sender_input_to_kv_weight, - input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, - sender_group_input_to_kv_weight=sender_group_input_to_kv_weight, - sender_group_size=runtime._sender_kv_group_size, - recurrent_sender_input_to_kv_weight=recurrent_sender_input_to_kv_weight, - recurrent_group_input_to_kv_weight=recurrent_group_input_to_kv_weight, - recurrent_group_size=runtime._recurrent_sender_kv_group_size, - value_to_cell_weight=value_to_cell_weight, - fused_recurrent_value_to_cell_weight=fused_recurrent_value_to_cell_weight, - value_to_output_weight=value_to_output_weight, - cell_bias=cell_bias, - recurrent_cell_bias=recurrent_cell_bias, - fused_recurrent_cell_bias=fused_recurrent_cell_bias, - fused_recurrent_population_input=fused_recurrent_population_input, - boundary_step=boundary_seq[:, 0] if boundary_seq is not None else None, - input_group_input_to_kv_weight=input_group_input_to_kv_weight, - population_materialized=population_materialized, - step_population_state_cache=step_population_state_cache, - grad_path=grad_path, - input_k_step=input_k_seq[0] if input_k_seq is not None else None, - input_v_step=input_v_seq[0] if input_v_seq is not None else None, - backend_static_tensors=static_tensors, - backend_population_name=None, - selected_backend_surface=None, - planned_backend_execution=None, - backend_population_state_is_fresh=backend_population_state_is_fresh, - materialize_population_next_state=materialize_final_state or grad_path, - 
materialize_cells_state=materialize_final_state or output_contract == "full_cells", - ) - if step_population_state_cache is not None and materialize_final_state: - runtime._apply_stream_step_population_cache(next_state, step_population_state_cache) - record_temporal_bucket_step_loop_execution( - runtime, - batch_size=batch_size, - time_steps=1, - inner_steps=int(max_steps), - training=grad_path, - materialize_final_state=materialize_final_state, - output_boundary=output_boundary, - output_contract=output_contract, - ) - return ( - _apply_temporal_output_contract(runtime, y_step, output_contract), - next_state if materialize_final_state else TensorDict({}, batch_size=[]), - ) - - outputs: list[torch.Tensor] | None = [] if grad_path else None - outputs_buffer: torch.Tensor | None = None - constant_k_rows = None - constant_max_steps = None - constant_all_active = None - step_population_state_cache = None - step_sender_cache = None - running_state_is_fresh = backend_population_state_is_fresh - if constant_k is not None: - constant_k_rows = torch.full((batch_size,), constant_k, device=device, dtype=torch.long) - constant_max_steps = constant_k - constant_all_active = constant_k > 0 - if constant_k > 0: - if use_fresh_backend_population_cache: - step_population_state_cache = runtime._prepare_fresh_stream_step_population_cache( - batch=batch_size, - device=device, - dtype=dtype, - ) - else: - step_population_state_cache = runtime._prepare_stream_step_population_cache( - running_state, - batch=batch_size, - device=device, - dtype=dtype, - ) - if constant_k == 1 and (materialize_final_state or grad_path or time_steps > 1): - step_sender_cache = {} - input_k_seq = None - input_v_seq = None - if boundary_seq is not None: - with torch.profiler.record_function("fabric.physical.message.input_projection_sequence"): - input_k_seq, input_v_seq = runtime._project_sender_kv_from_cells_sequence( - boundary_seq, - sender_input_to_kv_weight=input_sender_input_to_kv_weight, - grouped_sender_input_to_kv_weight=input_group_input_to_kv_weight, - sender_group_size=runtime._input_sender_kv_group_size, - ) - input_k_seq = input_k_seq.transpose(0, 1).contiguous() - input_v_seq = input_v_seq.transpose(0, 1).contiguous() - for step_index in range(time_steps): - if constant_k_rows is None or constant_max_steps is None: - k_rows, max_steps = runtime._resolve_step_k( - k, - batch_size=batch_size, - time_steps=time_steps, - step_index=step_index, - device=device, - ) - all_active = None - else: - k_rows, max_steps = constant_k_rows, constant_max_steps - all_active = constant_all_active if max_steps <= 1 else None - last_k_rows = k_rows - step_resets = population_resets[:, step_index] if population_resets is not None else None - y_step, running_state = runtime._forward_stream_step( - hidden_step=hidden_seq[:, step_index] if hidden_seq is not None else None, - state=running_state, - resets=step_resets, - has_resets=step_reset_flags[step_index] if step_reset_flags is not None else None, - capture_active=capture_active, - k_rows=k_rows, - max_steps=max_steps, - all_active=all_active, - q=q, - recurrent_q=recurrent_q, - output_q=output_q, - gathered_kv_weight=gathered_kv_weight, - sender_kv_weight=sender_kv_weight, - sender_input_to_kv_weight=sender_input_to_kv_weight, - input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, - sender_group_input_to_kv_weight=sender_group_input_to_kv_weight, - sender_group_size=runtime._sender_kv_group_size, - recurrent_sender_input_to_kv_weight=recurrent_sender_input_to_kv_weight, - 
recurrent_group_input_to_kv_weight=recurrent_group_input_to_kv_weight, - recurrent_group_size=runtime._recurrent_sender_kv_group_size, - value_to_cell_weight=value_to_cell_weight, - fused_recurrent_value_to_cell_weight=fused_recurrent_value_to_cell_weight, - value_to_output_weight=value_to_output_weight, - cell_bias=cell_bias, - recurrent_cell_bias=recurrent_cell_bias, - fused_recurrent_cell_bias=fused_recurrent_cell_bias, - fused_recurrent_population_input=fused_recurrent_population_input, - boundary_step=boundary_seq[:, step_index] if boundary_seq is not None else None, - input_group_input_to_kv_weight=input_group_input_to_kv_weight, - population_materialized=population_materialized, - step_population_state_cache=step_population_state_cache, - step_sender_cache=step_sender_cache, - grad_path=grad_path, - input_k_step=input_k_seq[step_index] if input_k_seq is not None else None, - input_v_step=input_v_seq[step_index] if input_v_seq is not None else None, - backend_static_tensors=static_tensors, - backend_population_name=None, - selected_backend_surface=None, - planned_backend_execution=None, - backend_population_state_is_fresh=running_state_is_fresh, - materialize_population_next_state=materialize_final_state or grad_path or step_index + 1 < time_steps, - materialize_cells_state=( - materialize_final_state or step_index + 1 < time_steps or output_contract == "full_cells" - ), - ) - running_state_is_fresh = False - if grad_path: - assert outputs is not None - outputs.append(_apply_temporal_output_contract(runtime, y_step, output_contract)) - else: - y_step = _apply_temporal_output_contract(runtime, y_step, output_contract) - if outputs_buffer is None: - outputs_buffer = y_step.new_empty(batch_size, time_steps, *y_step.shape[1:]) - assert outputs_buffer is not None - outputs_buffer[:, step_index].copy_(y_step) - if step_population_state_cache is not None and materialize_final_state: - runtime._apply_stream_step_population_cache(running_state, step_population_state_cache) - inner_steps = int( - constant_k if constant_k is not None else max(1, int(cast(torch.Tensor, last_k_rows).max().item())) - ) - record_temporal_bucket_step_loop_execution( - runtime, - batch_size=batch_size, - time_steps=time_steps, - inner_steps=inner_steps, - training=grad_path, - materialize_final_state=materialize_final_state, - output_boundary=output_boundary, - output_contract=output_contract, - ) - if grad_path: - assert outputs is not None - return torch.stack(outputs, dim=1), running_state if materialize_final_state else TensorDict({}, batch_size=[]) - assert outputs_buffer is not None - return outputs_buffer, running_state if materialize_final_state else TensorDict({}, batch_size=[]) - - -def _apply_temporal_output_contract( - runtime: Any, - y_step: torch.Tensor, - output_contract: _TemporalOutputContract, -) -> torch.Tensor: - if output_contract == "full_cells": - return y_step - if output_contract == "output_cells": - return runtime._select_output_cells(y_step.unsqueeze(1)).squeeze(1) - if output_contract == "pooled_output_cells": - output_cells = runtime._select_output_cells(y_step.unsqueeze(1)) - return runtime._pool_output_ports(output_cells).squeeze(1) - raise RuntimeError(f"Unsupported temporal output contract {output_contract!r}") - - -def supports_temporal_bucket_active_output_window(runtime: Any, *, time_steps: int) -> bool: - if time_steps < 1: - return False - active_populations = tuple( - name for name in runtime._population_names if int(runtime._population_recurrent_indices(name).numel()) > 0 - ) - 
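# --- Editorial sketch (not part of the diff) ------------------------------
# Illustrates the two output-collection paths in the step loop above: on the
# grad path, per-step outputs are kept in a Python list and stacked so
# autograd can track each step; on the no-grad path, a (batch, T, ...) buffer
# is preallocated once and filled in place with copy_, avoiding the extra
# stack allocation. `step_fn` is a hypothetical stand-in for the real
# per-step executor.
import torch
from typing import Callable


def collect_sequence_outputs(
    step_fn: Callable[[int], torch.Tensor],   # returns one (batch, ...) output per step
    time_steps: int,
    grad_path: bool,
) -> torch.Tensor:
    if grad_path:
        outputs = [step_fn(t) for t in range(time_steps)]
        return torch.stack(outputs, dim=1)
    outputs_buffer: torch.Tensor | None = None
    for t in range(time_steps):
        y_step = step_fn(t)
        if outputs_buffer is None:
            outputs_buffer = y_step.new_empty(y_step.shape[0], time_steps, *y_step.shape[1:])
        outputs_buffer[:, t].copy_(y_step)
    assert outputs_buffer is not None
    return outputs_buffer


x = torch.randn(2, 5, 8)
seq = collect_sequence_outputs(lambda t: torch.tanh(x[:, t]), time_steps=5, grad_path=False)
assert seq.shape == (2, 5, 8)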
if len(active_populations) <= 1: - return False - if not bool(runtime._partitioned_layout) or not bool(runtime._local_message_step_enabled): - return False - if runtime._has_edge_delay or bool(getattr(runtime, "_uses_sparse_message_backend", False)): - return False - active_recurrent_idx = getattr(runtime, "flat_bucket_active_output_recurrent_idx", None) - if not torch.is_tensor(active_recurrent_idx) or int(active_recurrent_idx.numel()) == 0: - return False - if not bool(getattr(runtime, "_flat_bucket_active_output_region_compact_contiguous", False)): - return False - if not bool(getattr(runtime, "_flat_bucket_active_output_region_is_full", False)): - return False - for population_name in runtime._population_names: - if int(runtime._population_recurrent_indices(population_name).numel()) == 0: - continue - population_spec = runtime._backend_population_specs.get(population_name) - if population_spec is None or not _transition_supports_receiver_local_dependency_window( - population_spec.transition_ir - ): - return False - return True - - -def _active_output_receiver_window(runtime: Any) -> _ReceiverWindowSpec | None: - full_count = int(runtime._num_recurrent_cells) - start = int(getattr(runtime, "_flat_bucket_active_output_region_start", 0)) - count = int(getattr(runtime, "_flat_bucket_active_output_region_count", 0)) - if count <= 0 or count >= full_count: - return None - if not bool(getattr(runtime, "_flat_bucket_active_output_region_compact_contiguous", False)): - return None - return _ReceiverWindowSpec( - mode=_flat_bucket_active_output_region_mode(runtime), - start=start, - count=count, - full_count=full_count, - ) - - -def _flat_bucket_active_output_region_mode(runtime: Any) -> str: - return str( - getattr( - runtime, - "_flat_bucket_active_output_region_mode", - getattr(runtime, "_flat_bucket_output_recurrent_closure_mode", "unknown"), - ) - ) - - -def _flat_bucket_active_output_region(runtime: Any) -> ClosedRecurrentRegion: - indices = getattr( - runtime, - "_flat_bucket_active_output_region_indices", - getattr(runtime, "_flat_bucket_output_recurrent_closure_indices", ()), - ) - return ClosedRecurrentRegion( - indices=tuple(indices), - full_count=int(runtime._num_recurrent_cells), - ) - - -def _flat_bucket_output_recurrent_closure(runtime: Any) -> ClosedRecurrentRegion: - return ClosedRecurrentRegion( - indices=tuple(runtime._flat_bucket_output_recurrent_closure_indices), - full_count=int(runtime._num_recurrent_cells), - ) - - -def execute_temporal_bucket_active_output_window( - runtime: Any, - *, - boundary_seq: torch.Tensor | None = None, - projected_boundary_source_seq: torch.Tensor | None = None, - projected_boundary_weight: torch.Tensor | None = None, - projected_boundary_bias: torch.Tensor | None = None, - resets: torch.Tensor | None, - static_tensors: dict[str, object], - output_boundary: Literal["sequence", "terminal"], -) -> tuple[torch.Tensor, TensorDict] | None: - if boundary_seq is None: - if projected_boundary_source_seq is None or projected_boundary_weight is None: - raise RuntimeError( - "Active-output flat-bucket sequence requires boundary_seq or projected boundary source inputs" - ) - batch_size = int(projected_boundary_source_seq.shape[0]) - time_steps = int(projected_boundary_source_seq.shape[1]) - output_template = projected_boundary_source_seq - else: - batch_size = int(boundary_seq.shape[0]) - time_steps = int(boundary_seq.shape[1]) - output_template = boundary_seq - if not supports_temporal_bucket_active_output_window(runtime, time_steps=time_steps): - return 
None - static_tensors = _with_cached_population_static_tensors(runtime, static_tensors) - active_recurrent_idx = runtime.flat_bucket_active_output_recurrent_idx - active_count = int(active_recurrent_idx.numel()) - if active_count == 0: - return None - window = _active_output_receiver_window(runtime) - input_sender_input_to_kv_weight = cast( - torch.Tensor | None, - static_tensors.get("input_sender_input_to_kv_weight"), - ) - input_group_input_to_kv_weight = cast( - torch.Tensor | None, - static_tensors.get("input_group_input_to_kv_weight"), - ) - recurrent_sender_input_to_kv_weight = cast( - torch.Tensor | None, - static_tensors.get("recurrent_sender_input_to_kv_weight"), - ) - input_k_seq = None - input_v_seq = None - if boundary_seq is not None: - with torch.profiler.record_function("fabric.physical.message.input_projection_sequence"): - input_k_seq, input_v_seq = runtime._project_sender_kv_from_cells_sequence( - boundary_seq, - sender_input_to_kv_weight=input_sender_input_to_kv_weight, - grouped_sender_input_to_kv_weight=input_group_input_to_kv_weight, - sender_group_size=runtime._input_sender_kv_group_size, - ) - input_k_seq = input_k_seq.transpose(0, 1).contiguous() - input_v_seq = input_v_seq.transpose(0, 1).contiguous() - recurrent_q = cast(torch.Tensor, static_tensors["recurrent_q"]) - active_recurrent_q = recurrent_q.index_select(0, active_recurrent_idx) - full_recurrent_count = int(runtime._num_recurrent_cells) - recurrent_k = output_template.new_zeros(batch_size, full_recurrent_count, runtime.head_dim) - recurrent_v = output_template.new_zeros(batch_size, full_recurrent_count, runtime.value_dim) - recurrent_local_sender_idx = runtime._cached_receiver_window_sender_table( - name="flat_bucket_recurrent", - table=runtime.recurrent_local_sender_idx, - window=window, - num_input_senders=int(runtime._num_input_cells), - slice_receivers=True, - compact_recurrent_senders=False, - ) - recurrent_sender_count = int(runtime._num_input_cells) + full_recurrent_count - recurrent_local_receiver_idx_by_sender = runtime._cached_sender_reverse_table( - name="flat_bucket_recurrent", - receiver_sender_idx=recurrent_local_sender_idx, - num_senders=recurrent_sender_count, - ) - output_q = cast(torch.Tensor, static_tensors["output_q"]) - output_local_sender_idx = runtime._cached_receiver_window_sender_table( - name="flat_bucket_output", - table=runtime.output_local_sender_idx, - window=window, - num_input_senders=int(runtime._num_input_cells), - slice_receivers=False, - compact_recurrent_senders=False, - ) - output_local_receiver_idx_by_sender = runtime._cached_sender_reverse_table( - name="flat_bucket_output", - receiver_sender_idx=output_local_sender_idx, - num_senders=recurrent_sender_count, - ) - value_to_output_weight = cast(torch.Tensor, static_tensors["value_to_output_weight"]) - recurrent_neighbor_idx = runtime.recurrent_neighbor_idx.index_select(0, active_recurrent_idx) - recurrent_neighbor_valid = runtime.recurrent_neighbor_valid.index_select(0, active_recurrent_idx) - recurrent_edge_distance = runtime.recurrent_edge_distance.index_select(0, active_recurrent_idx) - recurrent_edge_delay = runtime.recurrent_edge_delay.index_select(0, active_recurrent_idx) - active_window_buckets = runtime._flat_bucket_active_output_window_buckets() - step_population_state_cache: dict[str, object] = {} - output_steps: list[torch.Tensor] = [] - last_output_cells: torch.Tensor | None = None - - def reset_recurrent_bank(tensor: torch.Tensor, reset_step: torch.Tensor | None) -> torch.Tensor: - if reset_step is None: 
- return tensor - mask = reset_step.to(device=tensor.device, dtype=torch.bool).view(batch_size, 1, 1) - return torch.where(mask, torch.zeros_like(tensor), tensor) - - for step_index in range(time_steps): - reset_step = resets[:, step_index] if resets is not None and resets.dim() == 2 else resets - recurrent_k = reset_recurrent_bank(recurrent_k, reset_step) - recurrent_v = reset_recurrent_bank(recurrent_v, reset_step) - if input_k_seq is not None and input_v_seq is not None: - input_k = input_k_seq[step_index] - input_v = input_v_seq[step_index] - else: - assert projected_boundary_source_seq is not None and projected_boundary_weight is not None - with torch.profiler.record_function("fabric.physical.input_projection.active_window_step"): - boundary_step = torch.nn.functional.linear( - projected_boundary_source_seq[:, step_index], - projected_boundary_weight, - projected_boundary_bias, - ).view(batch_size, int(runtime._num_input_cells), int(runtime.hidden_size)) - with torch.profiler.record_function("fabric.physical.message.input_projection_step"): - input_k, input_v = runtime._project_sender_kv_from_cells_step( - boundary_step, - sender_input_to_kv_weight=input_sender_input_to_kv_weight, - grouped_sender_input_to_kv_weight=input_group_input_to_kv_weight, - sender_group_size=runtime._input_sender_kv_group_size, - ) - with torch.profiler.record_function("fabric.physical.message.active_window_recurrent"): - recurrent_msg = runtime._compute_messages_step_subset_partitioned_raw( - input_k, - input_v, - recurrent_k, - recurrent_v, - q_subset=active_recurrent_q, - neighbor_idx=recurrent_neighbor_idx, - neighbor_valid=recurrent_neighbor_valid, - edge_distance=recurrent_edge_distance, - edge_delay=recurrent_edge_delay, - use_delay=runtime._has_edge_delay, - step_idx=1, - local_sender_idx=recurrent_local_sender_idx, - local_receiver_idx_by_sender=recurrent_local_receiver_idx_by_sender, - owner_tag="active_window_recurrent", - ) - with torch.profiler.record_function("fabric.physical.transition_buckets.active_window"): - active_recurrent_hidden, _transition_recurrent_k, _transition_recurrent_v = ( - runtime._run_active_window_transition_buckets_step( - recurrent_msg, - active_recurrent_idx=active_recurrent_idx, - active_window_buckets=active_window_buckets, - resets=reset_step, - batch_size=batch_size, - static_tensors=static_tensors, - step_population_state_cache=step_population_state_cache, - materialize_next_state=step_index + 1 < time_steps, - ) - ) - active_recurrent_sender_input_to_kv_weight = ( - recurrent_sender_input_to_kv_weight.index_select(0, active_recurrent_idx) - if torch.is_tensor(recurrent_sender_input_to_kv_weight) - else None - ) - active_recurrent_k, active_recurrent_v = runtime._project_sender_kv_from_cells_step( - active_recurrent_hidden, - sender_input_to_kv_weight=active_recurrent_sender_input_to_kv_weight, - grouped_sender_input_to_kv_weight=None, - sender_group_size=runtime._recurrent_sender_kv_group_size, - ) - recurrent_k = recurrent_k.clone() - recurrent_v = recurrent_v.clone() - recurrent_k.index_copy_(1, active_recurrent_idx, active_recurrent_k) - recurrent_v.index_copy_(1, active_recurrent_idx, active_recurrent_v) - with torch.profiler.record_function("fabric.physical.message.active_window_readout"): - output_msg = runtime._compute_messages_step_subset_partitioned_raw( - input_k, - input_v, - recurrent_k, - recurrent_v, - q_subset=output_q, - neighbor_idx=runtime.output_neighbor_idx, - neighbor_valid=runtime.output_neighbor_valid, - 
edge_distance=runtime.output_edge_distance, - edge_delay=runtime.output_edge_delay, - use_delay=runtime._has_edge_delay, - step_idx=1, - local_sender_idx=output_local_sender_idx, - local_receiver_idx_by_sender=output_local_receiver_idx_by_sender, - owner_tag="active_window_readout", - ) - with torch.profiler.record_function("fabric.physical.readout.active_window_projection"): - output_cells = runtime._project_output_cells_step_raw( - output_msg, - value_to_output_weight=value_to_output_weight, - ).to(dtype=output_template.dtype) - last_output_cells = output_cells - if output_boundary == "sequence": - output_steps.append(output_cells) - if output_boundary == "terminal": - if last_output_cells is None: - return None - return last_output_cells.unsqueeze(1), TensorDict({}, batch_size=[]) - return torch.stack(output_steps, dim=1), TensorDict({}, batch_size=[]) - - -def record_temporal_bucket_sequence_surface_execution( - runtime: Any, - *, - batch_size: int, - time_steps: int, - inner_steps: int, - training: bool, - output_boundary: Literal["sequence", "terminal"], - active_receiver_window_mode: str, - active_receiver_window_offset: str, - active_receiver_window_count: str, -) -> None: - if runtime._last_backend_execution is not None: - return - active_populations = tuple( - name for name in runtime._population_names if runtime._population_recurrent_indices(name).numel() > 0 - ) - runtime._last_backend_execution = BackendExecutionRecord( - backend_name="cuda", - surface_key="flat_bucket_sequence_surface", - cell_type="bucketed", - regime="stream", - training=training, - batch_size=batch_size, - time_steps=time_steps, - inner_steps=inner_steps, - bucket_ids=tuple(bucket.bucket_id for bucket in runtime.backend_ir.buckets), - execution_families=("message", "transition_buckets", "readout"), - math_backends=("cuda_tensor_ops",), - tape_policy_bin="autograd" if training else "none", - graph_capture_enabled=False, - capability_variants=("flat_bucket_sequence_surface", "active_output_window"), - launch_temporal_executions=("temporal_bucket_sequence",), - launch_scan_implementations=("active_output_window",), - physical_op_kinds=( - "message", - "receiver_affine", - "state_epilogue", - "diagonal_recurrence", - "readout", - "glue/layout", - ), - physical_op_executors=( - "flat_bucket_sequence_surface", - "shared_graph_message", - f"transition_buckets={','.join(active_populations)}", - "active_window_static_buckets", - "readout_projection", - ), - physical_boundary_contracts=( - "shared_public_message_substrate", - "population_local_state_banks", - "projected_message", - "active_output_dependency_window", - "readout_boundary", - ), - physical_op_demotions=(), - active_receiver_window_modes=(active_receiver_window_mode,), - active_receiver_window_offsets=(active_receiver_window_offset,), - active_receiver_window_counts=(active_receiver_window_count,), - workspace_aliases=( - f"sequence_output_boundary:{'terminal_step' if output_boundary == 'terminal' else 'all_steps'}", - f"sequence_output_materialization:{'terminal_step_only' if output_boundary == 'terminal' else 'all_steps'}", - "sequence_output_contract:output_cells", - "final_state=not_materialized", - ), - ) - - -def _runtime_sparse_message_bucket_kind(runtime: Any) -> str: - return ( - "ragged_grouped_sparse" - if int(getattr(runtime, "_recurrent_sparse_positive_degree_buckets", 0)) > 1 - else "degree_bucketed_sparse" - ) - - -def _runtime_sparse_degree_summary(runtime: Any) -> str: - degree_ptr = getattr(runtime, "recurrent_sparse_degree_ptr", None) 
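# Sketch (assumption-labeled, not part of the diff): the removed
# _runtime_sparse_degree_summary reads recurrent_sparse_degree_ptr as a
# prefix-sum over receiver degrees and reports only non-empty degree buckets.
# The pointer values below are made up; only the "degrees=" output format and
# the counting recipe are taken from the code above.
import torch

# Hypothetical prefix-sum degree pointer: receivers with degree d occupy
# slots degree_ptr[d]:degree_ptr[d + 1] in the bucketed sparse layout.
degree_ptr = torch.tensor([0, 0, 4, 4, 9])

parts: list[str] = []
for degree in range(int(degree_ptr.numel()) - 1):
    count = int(degree_ptr[degree + 1].item()) - int(degree_ptr[degree].item())
    if count > 0:
        parts.append(f"{degree}:{count}")

summary = "degrees=" + ",".join(parts) if parts else "degrees=unknown"
print(summary)  # degrees=1:4,3:5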
- if torch.is_tensor(degree_ptr) and int(degree_ptr.numel()) > 1: - degrees: list[str] = [] - degree_ptr_cpu = degree_ptr.detach().cpu() - for degree in range(int(degree_ptr_cpu.numel()) - 1): - count = int(degree_ptr_cpu[degree + 1].item()) - int(degree_ptr_cpu[degree].item()) - if count > 0: - degrees.append(f"{degree}:{count}") - if degrees: - return "degrees=" + ",".join(degrees) - neighbor_valid = getattr(runtime, "recurrent_neighbor_valid", None) - if torch.is_tensor(neighbor_valid) and int(neighbor_valid.numel()) > 0: - degree_counts = torch.bincount(neighbor_valid.to(dtype=torch.long).sum(dim=1).detach().cpu()) - degrees = [ - f"{degree}:{int(count.item())}" for degree, count in enumerate(degree_counts) if int(count.item()) > 0 - ] - if degrees: - return "degrees=" + ",".join(degrees) - return "degrees=unknown" - - -def _runtime_message_record_metadata(runtime: Any) -> dict[str, tuple[str, ...]]: - uses_sparse = bool(getattr(runtime, "_uses_sparse_message_backend", False)) - reset_policy = "zero_source_rows" - reset_scope = "batch_row" - use_delay = bool(getattr(runtime, "_has_edge_delay", False)) - use_delay_str = "true" if use_delay else "false" - if not uses_sparse and bool(getattr(runtime, "_local_message_step_enabled", False)): - degree = int(getattr(runtime, "recurrent_local_sender_idx", torch.empty(0, 0)).shape[1]) - bucket_kind = "regular_local_receiver_owned" - return { - "message_projection_bucket_kinds": ("regular_local_projected_message_boundary",), - "message_bucket_count": ("1",), - "message_regular_local_bucket_count": ("1",), - "message_sparse_bucket_count": ("0",), - "message_batched_backend_count": ("1",), - "message_grouped_backend_count": ("0",), - "message_reset_aware_bucket_count": ("1",), - "message_degree_uniform_bucket_count": ("1",), - "message_ragged_grouped_bucket_count": ("0",), - "message_demoted_bucket_count": ("0",), - "message_bucket_signatures": ( - f"bucket_kind={bucket_kind}|topology_kind=regular_local|degree_or_block={degree}|" - f"K={int(runtime.head_dim)}|V={int(runtime.value_dim)}|reset_policy={reset_policy}|" - f"use_delay={use_delay_str}", - ), - "message_bucket_kinds": (bucket_kind,), - "message_topology_kinds": ("regular_local",), - "message_spatial_ownership": ("receiver_owned",), - "message_degree_bucket_lists": (f"degree={degree}",), - "message_logit_backends": ("direct_fixed_degree",), - "message_softmax_backends": ("custom_fixed_degree_softmax",), - "message_weighted_value_backends": ("direct_fixed_degree",), - "message_physical_mode": ("regular_local_direct_projected",), - "message_execution_mode": ("direct_fixed_degree",), - "message_output_boundary": ("projected_message",), - "message_degree": (str(degree),), - "message_k": (str(int(runtime.head_dim)),), - "message_v": (str(int(runtime.value_dim)),), - "message_projected_n": (str(int(runtime.config.d_msg)),), - "message_reset_policies": (reset_policy,), - "message_reset_scopes": (reset_scope,), - "message_use_delay": (use_delay_str,), - "message_distance_penalty_kinds": ("offset_distance",), - "message_epilogue_kinds": ("softmax_weighted_sum",), - "message_packed_source_reuse_count": ("1",), - "message_demotions": ("none",), - "message_workspace_mode": ("fixed_degree_direct",), - } - bucket_kind = _runtime_sparse_message_bucket_kind(runtime) - degree = int(getattr(runtime, "recurrent_neighbor_idx", torch.empty(0, 0)).shape[1]) - topology_kind = ( - "edge_owned_sparse" if int(getattr(runtime.config, "patch_edges_per_cell", 0)) > 0 else "receiver_owned_sparse" - ) - 
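# Worked example (assumption-labeled, not part of the diff): the sparse branch
# of _runtime_message_record_metadata assembles message_bucket_signatures from
# bucket kind, topology, degree, and head/value sizes. The pipe-separated field
# layout below mirrors the f-string in the code; the sizes (degree=6, K=32,
# V=64) are made up.
bucket_kind = "degree_bucketed_sparse"
topology_kind = "receiver_owned_sparse"
signature = (
    f"bucket_kind={bucket_kind}|topology_kind={topology_kind}|degree_or_block=6|"
    f"K=32|V=64|reset_policy=zero_source_rows|use_delay=false|"
    f"distance_penalty_kind=edge_distance"
)
print(signature)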
spatial_ownership = "edge_owned" if topology_kind == "edge_owned_sparse" else "receiver_owned" - degree_uniform = bucket_kind == "degree_bucketed_sparse" - execution_mode = "degree_bucketed_batched" if degree_uniform else "ragged_grouped" - gemm_backend = "batched_gemm" if degree_uniform else "grouped_gemm" - return { - "message_projection_bucket_kinds": ("sparse_projected_message_boundary",), - "message_bucket_count": ("1",), - "message_regular_local_bucket_count": ("0",), - "message_sparse_bucket_count": ("1",), - "message_batched_backend_count": ("1" if degree_uniform else "0",), - "message_grouped_backend_count": ("0" if degree_uniform else "1",), - "message_reset_aware_bucket_count": ("1",), - "message_degree_uniform_bucket_count": ("1" if degree_uniform else "0",), - "message_ragged_grouped_bucket_count": ("0" if degree_uniform else "1",), - "message_demoted_bucket_count": ("0",), - "message_bucket_signatures": ( - f"bucket_kind={bucket_kind}|topology_kind={topology_kind}|degree_or_block={degree}|" - f"K={int(runtime.head_dim)}|V={int(runtime.value_dim)}|reset_policy={reset_policy}|" - f"use_delay={use_delay_str}|distance_penalty_kind=edge_distance", - ), - "message_bucket_kinds": (bucket_kind,), - "message_topology_kinds": (topology_kind,), - "message_spatial_ownership": (spatial_ownership,), - "message_degree_bucket_lists": (_runtime_sparse_degree_summary(runtime),), - "message_logit_backends": (gemm_backend,), - "message_softmax_backends": ("custom_segment_softmax",), - "message_weighted_value_backends": (gemm_backend,), - "message_physical_mode": ( - "sparse_degree_bucketed_projected" if degree_uniform else "sparse_ragged_grouped_projected", - ), - "message_execution_mode": (execution_mode,), - "message_output_boundary": ("projected_message",), - "message_degree": (str(degree),), - "message_k": (str(int(runtime.head_dim)),), - "message_v": (str(int(runtime.value_dim)),), - "message_projected_n": (str(int(runtime.config.d_msg)),), - "message_reset_policies": (reset_policy,), - "message_reset_scopes": (reset_scope,), - "message_use_delay": (use_delay_str,), - "message_distance_penalty_kinds": ("edge_distance",), - "message_epilogue_kinds": ("segment_softmax_weighted_sum",), - "message_packed_source_reuse_count": ("1",), - "message_demotions": ("none",), - "message_workspace_mode": ("degree_bucketed_sparse" if degree_uniform else "degree_grouped_sparse_ragged",), - } - - -def _runtime_message_backward_metadata(runtime: Any) -> tuple[str, str, str, str]: - if bool(getattr(runtime, "_uses_sparse_message_backend", False)): - return ( - "sparse_message_superop_backward", - "physical_sparse_message_backward_executor", - "sparse_message_superop_backward:partitioned_cuda", - "sparse_message_superop_backward:active_sparse_cuda_owner", - ) - return ( - "tiny_message_superop_backward", - "physical_tiny_message_backward_executor", - "tiny_message_superop_backward:fused_receiver_sender_cuda", - "tiny_message_superop_backward:active_fused_receiver_sender_cuda_owner", - ) - - -def _runtime_transition_backward_record_metadata(runtime: Any) -> dict[str, tuple[str, ...]]: - kinds: list[str] = [] - executors: list[str] = [] - boundaries: list[str] = [] - launch_counts: list[str] = [] - saved_launch_counts: list[str] = [] - residual_demotions: list[str] = [] - - def add( - *, - kind: str, - executor: str, - boundary: str, - launch_count: str, - saved_launch_count: str | None = None, - residual_demotion: str | None = None, - ) -> None: - if kind not in kinds: - kinds.append(kind) - if executor not in 
executors: - executors.append(executor) - if boundary not in boundaries: - boundaries.append(boundary) - if launch_count and launch_count not in launch_counts: - launch_counts.append(launch_count) - if saved_launch_count and saved_launch_count not in saved_launch_counts: - saved_launch_counts.append(saved_launch_count) - if residual_demotion and residual_demotion not in residual_demotions: - residual_demotions.append(residual_demotion) - - for population_name in _active_population_names(runtime): - population_spec = runtime._backend_population_specs[population_name] - op_names = frozenset(op.name for op in population_spec.transition_ir.ops) - if "matmul" in op_names: - add( - kind="receiver_affine_superop_backward", - executor="physical_receiver_affine_backward_executor", - boundary="state_affine_output", - launch_count="receiver_affine_superop_backward:physical_cuda_tiled", - saved_launch_count="receiver_affine_superop_backward:active_cuda_owner", - ) - if "diag_rtu" in op_names: - add( - kind="diagonal_recurrence_superop_backward", - executor="physical_diagonal_recurrence_backward_executor", - boundary="raw_public", - launch_count="diagonal_recurrence_superop_backward:triton_cuda", - saved_launch_count="diagonal_recurrence_superop_backward:active_cuda_owner", - ) - if "gated_logspace_recurrence" in op_names or "norm_or_identity" in op_names: - add( - kind="lowered_state_epilogue_backward", - executor="physical_state_epilogue_backward_executor", - boundary="raw_public", - launch_count="state_epilogue_backward:gated_logspace_cuda_tiled", - saved_launch_count="state_epilogue_backward:active_cuda_owner", - residual_demotion="lowered_state_epilogue_backward:explicit_cuda_executor", - ) - - add( - kind="lowered_public_projection_backward", - executor="explicit_public_projection_thin_reverse", - boundary="public_projection", - launch_count="public_projection_backward:explicit_thin_reverse", - residual_demotion="lowered_public_projection_backward:thin_reverse_path:explicit_executor", - ) - add( - kind="lowered_readout_projection_backward", - executor="explicit_readout_projection_thin_reverse", - boundary="readout_boundary", - launch_count="readout_projection_backward:explicit_thin_reverse", - residual_demotion="lowered_readout_projection_backward:thin_reverse_path:explicit_executor", - ) - return { - "kinds": tuple(kinds), - "executors": tuple(executors), - "boundaries": tuple(boundaries), - "launch_counts": tuple(launch_counts), - "saved_launch_counts": tuple(saved_launch_counts), - "residual_demotions": tuple(residual_demotions), - } - - -def _runtime_state_affine_record_metadata(runtime: Any) -> dict[str, tuple[str, ...]]: - has_recurrent_affine = False - for population_name in runtime._population_names: - population_spec = runtime._backend_population_specs.get(population_name) - transition_ir = getattr(population_spec, "transition_ir", None) - if transition_ir is None: - continue - if any(getattr(op, "name", None) == "matmul" for op in getattr(transition_ir, "ops", ())): - has_recurrent_affine = True - break - if not has_recurrent_affine: - return { - "backends": (), - "sources": (), - "reset_policies": (), - "reset_mode": (), - "reset_scope": (), - } - return { - "backends": ("receiver_affine_superop",), - "sources": ("projected_message", "previous_state"), - "reset_policies": ("none", "zero_source_rows"), - "reset_mode": ("row_mask_pack",), - "reset_scope": ("batch_row",), - } - - -def record_temporal_bucket_step_loop_execution( - runtime: Any, - *, - batch_size: int, - time_steps: int, - 
inner_steps: int, - training: bool, - materialize_final_state: bool = True, - output_boundary: Literal["sequence", "terminal"] = "sequence", - output_contract: _TemporalOutputContract = "full_cells", -) -> None: - if runtime._last_backend_execution is not None: - return - active_populations = tuple( - name for name in runtime._population_names if runtime._population_recurrent_indices(name).numel() > 0 - ) - recurrent_closure = _flat_bucket_output_recurrent_closure(runtime) - output_only_streaming = bool( - not training - and not materialize_final_state - and output_contract in {"output_cells", "pooled_output_cells"} - and supports_temporal_bucket_active_output_window(runtime, time_steps=time_steps) - ) - active_region = _flat_bucket_active_output_region(runtime) if output_only_streaming else recurrent_closure - active_region_mode = ( - _flat_bucket_active_output_region_mode(runtime) if output_only_streaming else recurrent_closure.mode - ) - active_region_demotions = ( - ("active_region_closure_full_surface",) - if time_steps > 1 and recurrent_closure.is_full - else ("active_region_closure_ragged",) - if time_steps > 1 and not recurrent_closure.compact_contiguous - else () - ) - transition_backward_executor = getattr(runtime, "_last_flat_bucket_transition_backward_executor", None) - physical_transition_backward = transition_backward_executor == _PHYSICAL_TRANSITION_BACKWARD_EXECUTOR - physical_temporal_backward = transition_backward_executor == _PHYSICAL_TEMPORAL_BACKWARD_EXECUTOR - transition_tape_mode = getattr(runtime, "_last_flat_bucket_transition_tape_mode", None) - transition_tape_reason = getattr(runtime, "_last_flat_bucket_transition_tape_reason", None) - temporal_artifact_mode = getattr(runtime, "_last_flat_bucket_temporal_artifact_mode", None) - temporal_artifact_reason = getattr(runtime, "_last_flat_bucket_temporal_artifact_reason", None) - temporal_artifact_checkpoint_stride = getattr( - runtime, - "_last_flat_bucket_temporal_artifact_checkpoint_stride", - None, - ) - temporal_artifact_recompute_window_len = getattr( - runtime, - "_last_flat_bucket_temporal_artifact_recompute_window_len", - None, - ) - temporal_artifact_checkpoint_count = getattr( - runtime, - "_last_flat_bucket_temporal_artifact_checkpoint_count", - None, - ) - message_backward_kind, message_backward_executor, message_backward_launch, message_backward_saved_launch = ( - _runtime_message_backward_metadata(runtime) - ) - transition_backward_metadata = _runtime_transition_backward_record_metadata(runtime) - static_saved_launch_counts: tuple[str, ...] 
= () - if training: - static_saved_items: list[str] = [] - static_tape_mode = getattr(runtime, "_last_training_static_tape_mode", None) - if static_tape_mode: - static_saved_items.append(f"training_static_tape:{static_tape_mode}") - if getattr(runtime, "_last_training_static_prepack_mode", None) == "views": - static_saved_items.append("training_static_prepack:receiver_major_views") - if getattr(runtime, "_last_backward_projection_mode", None) == "factorized_recurrent_input": - static_saved_items.append("training_static_projection:factorized_receiver_input") - static_saved_launch_counts = tuple(static_saved_items) - backward_owner_plan = _temporal_backward_owner_plan( - training=training, - transition_backward_executor=transition_backward_executor, - active_region_demotions=active_region_demotions, - transition_tape_mode=transition_tape_mode, - message_backward_kind=message_backward_kind, - message_backward_executor=message_backward_executor, - ) - temporal_tape_policy_bin = ( - f"physical_temporal_bucket_{transition_tape_mode}_transition_tape" - if transition_tape_mode - else "physical_temporal_bucket_saved_tape" - ) - scan_implementation = ( - "windowed_temporal_physical_scan" - if physical_temporal_backward and temporal_artifact_mode == "recompute_step_artifacts" - else "stored_temporal_physical_scan" - if physical_temporal_backward and temporal_artifact_mode == "store_step_artifacts" - else "flat_bucket_temporal_scan" - if physical_temporal_backward - else "flat_bucket_temporal_scan" - ) - message_metadata = _runtime_message_record_metadata(runtime) - state_affine_metadata = _runtime_state_affine_record_metadata(runtime) - workspace_aliases = ( - f"sequence_output_boundary:{'terminal_step' if output_boundary == 'terminal' else 'all_steps'}", - f"sequence_output_materialization:{'terminal_step_only' if output_boundary == 'terminal' else 'all_steps'}", - f"sequence_output_contract:{output_contract}", - "final_state=materialized" if materialize_final_state else "final_state=not_materialized", - *( - ("flat_bucket_temporal_scan:recurrent_kv_carry_reuse",) - if bool(getattr(runtime, "_last_flat_bucket_temporal_recurrent_kv_carry_reuse", False)) - else () - ), - *( - ( - f"temporal_artifacts:{temporal_artifact_mode}", - f"temporal_artifact_checkpoint_stride:{temporal_artifact_checkpoint_stride}", - f"temporal_artifact_recompute_window_len:{temporal_artifact_recompute_window_len}", - f"temporal_artifact_checkpoint_count:{temporal_artifact_checkpoint_count}", - ) - if temporal_artifact_mode is not None - else () - ), - ) - runtime._last_backend_execution = BackendExecutionRecord( - backend_name="cuda", - surface_key="flat_bucket_sequence_surface", - cell_type="bucketed", - regime="stream", - training=training, - batch_size=batch_size, - time_steps=time_steps, - inner_steps=inner_steps, - bucket_ids=tuple(bucket.bucket_id for bucket in runtime.backend_ir.buckets), - execution_families=("message", "transition_buckets", "readout"), - math_backends=("cuda_tensor_ops",), - tape_policy_bin=temporal_tape_policy_bin - if training and physical_temporal_backward - else "hybrid_physical_transition" - if training and physical_transition_backward - else "autograd" - if training - else "none", - graph_capture_enabled=False, - capability_variants=("flat_bucket_sequence_surface", "flat_bucket_temporal_scan", scan_implementation), - launch_temporal_executions=("temporal_bucket_sequence",), - launch_scan_implementations=(scan_implementation,), - physical_op_kinds=( - "message", - "receiver_affine", - 
"state_epilogue", - "diagonal_recurrence", - "readout", - "glue/layout", - ), - physical_op_executors=( - "flat_bucket_sequence_surface", - "shared_graph_message", - f"transition_buckets={','.join(active_populations)}", - scan_implementation, - "readout_projection", - f"transition_tape={transition_tape_mode or 'unknown'}", - ), - physical_boundary_contracts=( - "shared_public_message_substrate", - "population_local_state_banks", - "projected_message", - "fixed_active_spatial_region", - "readout_boundary", - ), - state_affine_backends=state_affine_metadata["backends"], - state_affine_sources=state_affine_metadata["sources"], - state_affine_reset_policies=state_affine_metadata["reset_policies"], - state_affine_reset_mode=state_affine_metadata["reset_mode"], - state_affine_reset_scope=state_affine_metadata["reset_scope"], - physical_op_demotions=active_region_demotions, - active_receiver_window_modes=(active_region_mode,), - active_receiver_window_offsets=(str(active_region.start),), - active_receiver_window_counts=(str(active_region.count),), - workspace_aliases=workspace_aliases, - message_projection_boundaries=("projected_message",), - message_projection_bucket_kinds=message_metadata["message_projection_bucket_kinds"], - message_bucket_count=message_metadata["message_bucket_count"], - message_regular_local_bucket_count=message_metadata["message_regular_local_bucket_count"], - message_sparse_bucket_count=message_metadata["message_sparse_bucket_count"], - message_batched_backend_count=message_metadata["message_batched_backend_count"], - message_grouped_backend_count=message_metadata["message_grouped_backend_count"], - message_reset_aware_bucket_count=message_metadata["message_reset_aware_bucket_count"], - message_degree_uniform_bucket_count=message_metadata["message_degree_uniform_bucket_count"], - message_ragged_grouped_bucket_count=message_metadata["message_ragged_grouped_bucket_count"], - message_demoted_bucket_count=message_metadata["message_demoted_bucket_count"], - message_bucket_signatures=message_metadata["message_bucket_signatures"], - message_bucket_kinds=message_metadata["message_bucket_kinds"], - message_topology_kinds=message_metadata["message_topology_kinds"], - message_spatial_ownership=message_metadata["message_spatial_ownership"], - message_degree_bucket_lists=message_metadata["message_degree_bucket_lists"], - message_logit_backends=message_metadata["message_logit_backends"], - message_softmax_backends=message_metadata["message_softmax_backends"], - message_weighted_value_backends=message_metadata["message_weighted_value_backends"], - message_physical_mode=message_metadata["message_physical_mode"], - message_execution_mode=message_metadata["message_execution_mode"], - message_output_boundary=message_metadata["message_output_boundary"], - message_degree=message_metadata["message_degree"], - message_k=message_metadata["message_k"], - message_v=message_metadata["message_v"], - message_projected_n=message_metadata["message_projected_n"], - message_reset_policies=message_metadata["message_reset_policies"], - message_reset_scopes=message_metadata["message_reset_scopes"], - message_use_delay=message_metadata["message_use_delay"], - message_distance_penalty_kinds=message_metadata["message_distance_penalty_kinds"], - message_epilogue_kinds=message_metadata["message_epilogue_kinds"], - message_packed_source_reuse_count=message_metadata["message_packed_source_reuse_count"], - message_demotions=message_metadata["message_demotions"], - 
message_workspace_mode=message_metadata["message_workspace_mode"], - backward_physical_op_kinds=( - (message_backward_kind, *transition_backward_metadata["kinds"], "glue/layout") if training else () - ), - backward_physical_op_executors=( - ( - message_backward_executor, - *transition_backward_metadata["executors"], - "physical_temporal_bucket_sequence_backward", - ) - if training - else () - ), - backward_physical_op_demotions=backward_owner_plan.demotions, - backward_boundary_contracts=( - ("projected_message", *transition_backward_metadata["boundaries"], "fixed_active_spatial_region") - if training - else () - ), - backward_tape_mode=backward_owner_plan.tape_modes, - backward_launch_counts=( - (message_backward_launch, *transition_backward_metadata["launch_counts"]) if training else () - ), - backward_saved_launch_counts=( - ( - message_backward_saved_launch, - *transition_backward_metadata["saved_launch_counts"], - *static_saved_launch_counts, - ) - if training - else () - ), - backward_residual_glue_demotions=transition_backward_metadata["residual_demotions"] if training else (), - backward_recompute_mode=( - tuple( - item - for item in ( - f"transition_tape:{transition_tape_mode}" if transition_tape_mode is not None else None, - transition_tape_reason, - f"temporal_artifacts:{temporal_artifact_mode}" if temporal_artifact_mode is not None else None, - temporal_artifact_reason, - ) - if item is not None - ) - if training - else () - ), - ) - - -__all__ = [ - "execute_temporal_bucket_active_output_window", - "execute_temporal_bucket_sequence", - "record_temporal_bucket_sequence_surface_execution", - "record_temporal_bucket_step_loop_execution", - "supports_temporal_bucket_active_output_window", -] diff --git a/src/cortical/fabric/backend/cuda/temporal_param_binding.py b/src/cortical/fabric/backend/cuda/temporal_param_binding.py new file mode 100644 index 00000000..2652464d --- /dev/null +++ b/src/cortical/fabric/backend/cuda/temporal_param_binding.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from typing import Any + +import torch + +from cortical.fabric.backend.cell_backend import TransitionParameterBinding + + +def compiled_transition_program_for_bucket(runtime: Any, bucket: Any) -> Any: + backend_ir = getattr(runtime, "backend_ir", None) + if backend_ir is None: + raise RuntimeError("Fabric CUDA temporal execution requires compiled backend_ir transition programs") + binding_slot = int(getattr(bucket, "binding_slot", -1)) + transition_program = backend_ir.transition_program_for_binding_slot(binding_slot) + if int(getattr(transition_program, "binding_slot", -2)) != binding_slot: + raise RuntimeError( + "Fabric CUDA temporal transition binding slot mismatch: " + f"bucket={binding_slot}, program={getattr(transition_program, 'binding_slot', None)}" + ) + return transition_program + + +def compiled_transition_parameter_bindings( + transition_program: Any, +) -> dict[str, tuple[TransitionParameterBinding, ...]]: + compiled_bindings = getattr(transition_program, "parameter_bindings", ()) + binding_map: dict[str, tuple[TransitionParameterBinding, ...]] = {} + for item in compiled_bindings: + parameter = str(getattr(item, "parameter", "")) + if not parameter: + continue + binding_map[parameter] = tuple(getattr(item, "bindings", ()) or ()) + return binding_map + + +def resolve_transition_parameter( + transition_program: Any, + population_params: dict[str, object], + static_tensors: dict[str, object], + name: str, + *, + num_receivers: int, +) -> torch.Tensor: + binding_map = 
compiled_transition_parameter_bindings(transition_program) + try: + bindings = binding_map[str(name)] + except KeyError as exc: + raise RuntimeError( + f"Fabric CUDA temporal primitive parameter {name!r} has no compiled binding; " + "update the fabric.cuda.nn declaration/lowering instead of reading a runtime tensor by convention" + ) from exc + for binding in bindings: + tensor = resolve_bound_transition_parameter( + population_params, + static_tensors, + binding, + num_receivers=num_receivers, + ) + if tensor is not None: + return tensor + raise RuntimeError(f"Fabric CUDA temporal primitive could not resolve compiled parameter {name!r}") + + +def resolve_bound_transition_parameter( + population_params: dict[str, object], + static_tensors: dict[str, object], + binding: TransitionParameterBinding, + *, + num_receivers: int, +) -> torch.Tensor | None: + if binding.kind == "cell_param": + tensor = population_params.get(binding.source) + return tensor if torch.is_tensor(tensor) else None + if binding.kind == "static_tensor": + tensor = static_tensors.get(binding.source) + return tensor if torch.is_tensor(tensor) else None + if binding.kind == "expanded_transposed_static_tensor": + tensor = static_tensors.get(binding.source) + if torch.is_tensor(tensor): + return tensor.transpose(0, 1).unsqueeze(0).expand(num_receivers, -1, -1) + return None + raise RuntimeError(f"Unsupported transition parameter binding kind {binding.kind}") + + +__all__ = [ + "compiled_transition_parameter_bindings", + "compiled_transition_program_for_bucket", + "resolve_bound_transition_parameter", + "resolve_transition_parameter", +] diff --git a/src/cortical/fabric/backend/cuda/transition_execution/__init__.py b/src/cortical/fabric/backend/cuda/transition_execution/__init__.py new file mode 100644 index 00000000..f1635783 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/transition_execution/__init__.py @@ -0,0 +1,6 @@ +from __future__ import annotations + +# Transition execution is intentionally split into semantic submodules. +# Active code imports those submodules directly; this package root is not a barrel. 
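# Sketch (assumption-labeled, not part of the diff): resolve_bound_transition_parameter
# in temporal_param_binding.py dispatches on three binding kinds: "cell_param"
# (per-population parameter bank), "static_tensor" (compiled static tensor), and
# "expanded_transposed_static_tensor" (a shared weight transposed and broadcast
# per receiver). The Binding namedtuple and tensor contents below are stand-ins
# for the real TransitionParameterBinding rows; only the kind names and the
# transpose/expand recipe are taken from the diff.
from collections import namedtuple

import torch

Binding = namedtuple("Binding", ["kind", "source"])  # stand-in for TransitionParameterBinding

population_params = {"w1": torch.randn(8, 16)}               # per-cell parameter bank
static_tensors = {"input_proj_weight": torch.randn(16, 32)}  # compiled static tensor
num_receivers = 8


def resolve(binding: Binding) -> torch.Tensor:
    if binding.kind == "cell_param":
        return population_params[binding.source]
    if binding.kind == "static_tensor":
        return static_tensors[binding.source]
    if binding.kind == "expanded_transposed_static_tensor":
        # Shared [out, in] weight, transposed and broadcast to one copy per receiver.
        shared = static_tensors[binding.source]
        return shared.transpose(0, 1).unsqueeze(0).expand(num_receivers, -1, -1)
    raise RuntimeError(f"Unsupported transition parameter binding kind {binding.kind}")


print(resolve(Binding("expanded_transposed_static_tensor", "input_proj_weight")).shape)
# torch.Size([8, 32, 16])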
+ +__all__: list[str] = [] diff --git a/src/cortical/fabric/backend/cuda/transition_execution/program.py b/src/cortical/fabric/backend/cuda/transition_execution/program.py new file mode 100644 index 00000000..75da4147 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/transition_execution/program.py @@ -0,0 +1,623 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Literal, cast + +import torch +from tensordict import TensorDictBase + +from cortical.fabric.backend.cuda.ops import reset_backend_tensors_rows_cuda +from cortical.fabric.backend.cuda.transition_execution.registry import ( + TransitionProgramExecutorKind, + TransitionProgramExecutorRecord, + registered_transition_primitive_executor_records, + registered_transition_executor_records, + transition_primitive_executor_record_for_lowered_primitive, + transition_program_layer_blocker_codes, + transition_program_layer_missing_symbols, +) + + +class TransitionProgramExecutorSelectionError(RuntimeError): + def __init__(self, code: str, primitive_names: tuple[str, ...], detail: str) -> None: + self.code = str(code) + self.primitive_names = primitive_names + super().__init__( + "Fabric CUDA transition execution has no registered physical executor " + f"for compiled transition ops: {', '.join(primitive_names) or ''}; " + f"code={self.code}; {detail}" + ) + + +@dataclass(frozen=True) +class TransitionPrimitiveDagOp: + op_index: int + primitive: str + inputs: tuple[str, ...] + outputs: tuple[str, ...] + parameter_inputs: tuple[str, ...] + tensor_role_contract: tuple[str, ...] + tape_policy: str + forward_symbol: str + backward_symbol: str + param_grad_outputs: tuple[tuple[str, str, str], ...] + + @property + def summary(self) -> str: + return ( + f"op={int(self.op_index)}:{self.primitive}" + f",inputs={','.join(self.inputs) or '-'}" + f",outputs={','.join(self.outputs) or '-'}" + f",params={','.join(self.parameter_inputs) or '-'}" + f",roles={','.join(self.tensor_role_contract) or '-'}" + f",tape={self.tape_policy}" + ) + + +@dataclass(frozen=True) +class TransitionPrimitiveDagTensorEdge: + producer_op_index: int + tensor_name: str + consumer_op_index: int + consumer_input_name: str + + @property + def summary(self) -> str: + return ( + f"{int(self.producer_op_index)}:{self.tensor_name}" + f"->{int(self.consumer_op_index)}:{self.consumer_input_name}" + ) + + +@dataclass(frozen=True) +class TransitionPrimitiveDagTapeContract: + op_index: int + primitive: str + tape_policy: str + saved_inputs: tuple[str, ...] + saved_outputs: tuple[str, ...] + recompute_inputs: tuple[str, ...] + recompute_outputs: tuple[str, ...] + reverse_inputs: tuple[str, ...] + + @property + def summary(self) -> str: + return ( + f"op={int(self.op_index)}:{self.primitive}" + f",policy={self.tape_policy}" + f",save_inputs={','.join(self.saved_inputs) or '-'}" + f",save_outputs={','.join(self.saved_outputs) or '-'}" + f",recompute_inputs={','.join(self.recompute_inputs) or '-'}" + f",recompute_outputs={','.join(self.recompute_outputs) or '-'}" + f",reverse_inputs={','.join(self.reverse_inputs) or '-'}" + ) + + +@dataclass(frozen=True) +class TransitionPrimitiveDagExecutorPlan: + registry_id: str + primitive_names: tuple[str, ...] + ops: tuple[TransitionPrimitiveDagOp, ...] + tensor_edges: tuple[TransitionPrimitiveDagTensorEdge, ...] + tape_contracts: tuple[TransitionPrimitiveDagTapeContract, ...] + external_inputs: tuple[str, ...] + state_inputs: tuple[str, ...] + state_outputs: tuple[str, ...] 
+ public_outputs: tuple[str, ...] + message_inputs: tuple[str, ...] + parameter_inputs: tuple[str, ...] + + @property + def forward_symbols(self) -> tuple[str, ...]: + return tuple(dict.fromkeys(op.forward_symbol for op in self.ops if op.forward_symbol)) + + @property + def backward_symbols(self) -> tuple[str, ...]: + return tuple(dict.fromkeys(op.backward_symbol for op in self.ops if op.backward_symbol)) + + @property + def tape_policies(self) -> tuple[str, ...]: + return tuple(dict.fromkeys(op.tape_policy for op in self.ops if op.tape_policy)) + + @property + def param_grad_outputs(self) -> tuple[tuple[str, str, str], ...]: + return tuple(item for op in self.ops for item in op.param_grad_outputs) + + @property + def review_summary(self) -> tuple[str, ...]: + return ( + f"primitive_dag_registry_id={self.registry_id}", + "primitive_dag_primitives=" + ",".join(self.primitive_names), + "primitive_dag_external_inputs=" + (",".join(self.external_inputs) or "-"), + "primitive_dag_tensor_edges=" + (",".join(edge.summary for edge in self.tensor_edges) or "-"), + "primitive_dag_tape_policies=" + (",".join(self.tape_policies) or "-"), + "primitive_dag_tape_contracts=" + (",".join(contract.summary for contract in self.tape_contracts) or "-"), + f"primitive_dag_param_grad_outputs={len(self.param_grad_outputs)}", + ) + + +def _transition_primitive_executor_records_by_name() -> dict[str, object]: + return {record.primitive: record for record in registered_transition_primitive_executor_records()} + + +@dataclass(frozen=True) +class TransitionProgramExecutorPlan: + executor: TransitionProgramExecutorKind + primitive_names: tuple[str, ...] + registry_id: str + forward_strategy_id: str + backward_strategy_id: str + selection_kind: Literal["fused_program_record", "primitive_dag"] = "fused_program_record" + runtime_execution_status: Literal[ + "registered_fused_program", + "registered_primitive_dag_program", + ] = "registered_fused_program" + primitive_dag: TransitionPrimitiveDagExecutorPlan | None = None + rejection_code: str = "" + program_layer_status: str = "blocked" + program_layer_blocker_codes: tuple[str, ...] = () + program_forward_symbols: tuple[str, ...] = () + program_backward_symbols: tuple[str, ...] = () + program_missing_symbols: tuple[str, ...] 
= () + + @property + def review_summary(self) -> tuple[str, ...]: + return ( + f"transition_executor={self.executor}", + f"registry_id={self.registry_id}", + f"selection_kind={self.selection_kind}", + "primitives=" + ",".join(self.primitive_names), + f"forward_strategy_id={self.forward_strategy_id}", + f"backward_strategy_id={self.backward_strategy_id}", + f"runtime_execution_status={self.runtime_execution_status}", + f"rejection_code={self.rejection_code or '-'}", + f"program_layer_status={self.program_layer_status}", + "program_layer_blocker_codes=" + + ("-" if not self.program_layer_blocker_codes else ",".join(self.program_layer_blocker_codes)), + "program_forward_symbols=" + + ("-" if not self.program_forward_symbols else ",".join(self.program_forward_symbols)), + "program_backward_symbols=" + + ("-" if not self.program_backward_symbols else ",".join(self.program_backward_symbols)), + "program_missing_symbols=" + + ("-" if not self.program_missing_symbols else ",".join(self.program_missing_symbols)), + "primitive_dag=" + ("-" if self.primitive_dag is None else ",".join(self.primitive_dag.review_summary)), + ) + + +def _eligibility_trace_state_names(transition_program: Any) -> frozenset[str]: + return frozenset( + schema.name + for schema in getattr(transition_program, "private_state_schema", ()) + if schema.semantic_kind == "eligibility_trace" + ) + + +def _grad_state_value(grad_state: Any, state_name: str) -> torch.Tensor | None: + if grad_state is None: + return None + if isinstance(grad_state, (dict, TensorDictBase)): + return cast(torch.Tensor | None, grad_state.get(state_name)) + if torch.is_tensor(grad_state): + return grad_state + raise RuntimeError(f"Unsupported Fabric CUDA transition grad-state container {type(grad_state).__name__}") + + +def _reset_state_inputs_for_transition( + state_tensors: tuple[torch.Tensor, ...], + *, + population_reset_step: torch.Tensor | None, + batch_size: int, + device: torch.device, +) -> tuple[torch.Tensor, ...]: + if population_reset_step is None: + return state_tensors + reset_rows = torch.as_tensor(population_reset_step, device=device, dtype=torch.bool).view(batch_size) + resettable_indices = tuple( + index + for index, tensor in enumerate(state_tensors) + if tensor.dim() >= 2 and int(tensor.shape[0]) == int(batch_size) + ) + if len(resettable_indices) == len(state_tensors): + return cast(tuple[torch.Tensor, ...], reset_backend_tensors_rows_cuda(state_tensors, reset_rows)) + resettable = tuple(state_tensors[index] for index in resettable_indices) + reset_values = reset_backend_tensors_rows_cuda(resettable, reset_rows) if resettable else () + reset_by_index = dict(zip(resettable_indices, reset_values, strict=True)) + return tuple(cast(torch.Tensor, reset_by_index.get(index, tensor)) for index, tensor in enumerate(state_tensors)) + + +def _zero_reset_row_grad( + grad: torch.Tensor | None, + population_reset_step: torch.Tensor | None, + *, + batch_size: int, +) -> torch.Tensor | None: + if grad is None or population_reset_step is None: + return grad + reset_rows = torch.as_tensor(population_reset_step, device=grad.device, dtype=torch.bool).view(batch_size) + return cast(torch.Tensor, reset_backend_tensors_rows_cuda((grad,), reset_rows)[0]) + + +def _compiled_transition_program_for_population(runtime: Any, population_name: str) -> Any: + backend_ir = getattr(runtime, "backend_ir", None) + if backend_ir is None: + raise RuntimeError("Fabric CUDA transition execution requires compiled backend_ir transition programs") + population_names = 
tuple(str(name) for name in getattr(backend_ir, "population_names", ())) + try: + binding_slot = population_names.index(str(population_name)) + except ValueError as exc: + raise RuntimeError(f"Fabric backend IR has no transition program for population {population_name}") from exc + return backend_ir.transition_program_for_binding_slot(binding_slot) + + +def _transition_program_ops(transition_program: Any) -> tuple[Any, ...]: + ops = tuple(getattr(transition_program, "primitive_ops", ()) or ()) + if not ops: + raise RuntimeError("Fabric CUDA transition execution requires compiled transition primitive rows") + return ops + + +def _record_compiled_transition_executor_selection(runtime: Any, transition_program: Any) -> None: + ops = _transition_program_ops(transition_program) + runtime._last_transition_executor_program_source = "compiled_transition_program" + runtime._last_transition_executor_binding_slot = int(getattr(transition_program, "binding_slot", -1)) + runtime._last_transition_executor_primitives = tuple(str(op.name) for op in ops) + + +def _record_transition_executor_plan(runtime: Any, executor_plan: TransitionProgramExecutorPlan) -> None: + runtime._last_transition_executor_registry_id = executor_plan.registry_id + runtime._last_transition_executor_name = executor_plan.executor + runtime._last_transition_executor_selection_kind = executor_plan.selection_kind + runtime._last_transition_executor_forward_strategy_id = executor_plan.forward_strategy_id + runtime._last_transition_executor_backward_strategy_id = executor_plan.backward_strategy_id + runtime._last_transition_executor_runtime_execution_status = executor_plan.runtime_execution_status + runtime._last_transition_executor_rejection_code = executor_plan.rejection_code + runtime._last_transition_executor_program_layer_status = executor_plan.program_layer_status + runtime._last_transition_executor_program_layer_blocker_codes = executor_plan.program_layer_blocker_codes + runtime._last_transition_executor_program_forward_symbols = executor_plan.program_forward_symbols + runtime._last_transition_executor_program_backward_symbols = executor_plan.program_backward_symbols + runtime._last_transition_executor_program_missing_symbols = executor_plan.program_missing_symbols + runtime._last_transition_executor_primitive_dag = ( + () if executor_plan.primitive_dag is None else executor_plan.primitive_dag.review_summary + ) + + +def _matches_transition_executor_record( + transition_program: Any, + record: TransitionProgramExecutorRecord, +) -> bool: + ops = _transition_program_ops(transition_program) + if tuple(str(op.name) for op in ops) != record.primitive_names: + return False + for arity in record.arities: + op = ops[int(arity.op_index)] + if len(getattr(op, "inputs", ())) < int(arity.min_inputs): + return False + if arity.output_count is not None and len(getattr(op, "outputs", ())) != int(arity.output_count): + return False + message_inputs = set(str(item) for item in getattr(transition_program, "message_inputs", ())) + for message_input in record.message_inputs: + op_inputs = tuple(str(item) for item in getattr(ops[int(message_input.op_index)], "inputs", ())) + if ( + int(message_input.input_index) >= len(op_inputs) + or op_inputs[int(message_input.input_index)] not in message_inputs + ): + return False + for state_slice in record.state_slices: + op_items = tuple(str(item) for item in getattr(ops[int(state_slice.op_index)], state_slice.op_field, ())) + state_items = tuple(str(item) for item in getattr(transition_program, 
state_slice.state_field, ())) + if ( + op_items[int(state_slice.op_start) : state_slice.op_stop] + != state_items[int(state_slice.state_start) : state_slice.state_stop] + ): + return False + for edge in record.tensor_edges: + producer_outputs = tuple(str(item) for item in getattr(ops[int(edge.producer_op_index)], "outputs", ())) + consumer_inputs = tuple(str(item) for item in getattr(ops[int(edge.consumer_op_index)], "inputs", ())) + if int(edge.producer_output_index) >= len(producer_outputs) or int(edge.consumer_input_index) >= len( + consumer_inputs + ): + return False + if producer_outputs[int(edge.producer_output_index)] != consumer_inputs[int(edge.consumer_input_index)]: + return False + return True + + +def _op_parameter_inputs(op: Any, program_parameter_inputs: tuple[str, ...]) -> tuple[str, ...]: + parameter_inputs = getattr(op, "parameter_inputs", None) + if parameter_inputs is not None: + return tuple(str(item) for item in parameter_inputs) + parameter_set = set(program_parameter_inputs) + return tuple(str(item) for item in getattr(op, "inputs", ()) if str(item) in parameter_set) + + +def _build_transition_primitive_dag_executor_plan( + transition_program: Any, +) -> TransitionPrimitiveDagExecutorPlan: + ops = _transition_program_ops(transition_program) + primitive_names = tuple(str(op.name) for op in ops) + state_inputs = tuple(str(item) for item in getattr(transition_program, "state_inputs", ()) or ()) + state_outputs = tuple(str(item) for item in getattr(transition_program, "state_outputs", ()) or ()) + public_outputs = tuple(str(item) for item in getattr(transition_program, "public_outputs", ()) or ()) + message_inputs = tuple(str(item) for item in getattr(transition_program, "message_inputs", ()) or ()) + parameter_inputs = tuple(str(item) for item in getattr(transition_program, "parameter_inputs", ()) or ()) + external_inputs = tuple(dict.fromkeys((*state_inputs, *message_inputs, *parameter_inputs))) + external_input_set = set(external_inputs) + all_outputs = { + str(output) for op in ops for output in tuple(str(item) for item in getattr(op, "outputs", ()) or ()) + } + produced_by: dict[str, int] = {} + tensor_edges: list[TransitionPrimitiveDagTensorEdge] = [] + dag_ops: list[TransitionPrimitiveDagOp] = [] + tape_contracts: list[TransitionPrimitiveDagTapeContract] = [] + for op_index, op in enumerate(ops): + primitive = str(op.name) + record = transition_primitive_executor_record_for_lowered_primitive(primitive) + if record is None: + raise TransitionProgramExecutorSelectionError( + "UNREGISTERED_TRANSITION_PRIMITIVE", + primitive_names, + f"missing primitive executor record: {primitive}", + ) + inputs = tuple(str(item) for item in getattr(op, "inputs", ()) or ()) + outputs = tuple(str(item) for item in getattr(op, "outputs", ()) or ()) + parameter_inputs_for_op = _op_parameter_inputs(op, parameter_inputs) + non_parameter_inputs = _op_non_parameter_inputs(inputs, parameter_inputs_for_op) + for input_name in inputs: + producer_op_index = produced_by.get(input_name) + if producer_op_index is not None: + tensor_edges.append( + TransitionPrimitiveDagTensorEdge( + producer_op_index=int(producer_op_index), + tensor_name=input_name, + consumer_op_index=int(op_index), + consumer_input_name=input_name, + ) + ) + continue + if input_name in all_outputs: + raise TransitionProgramExecutorSelectionError( + "ILLEGAL_TRANSITION_PRIMITIVE_DAG", + primitive_names, + f"primitive op {op_index} consumes future tensor {input_name!r}", + ) + if input_name not in external_input_set: + raise 
TransitionProgramExecutorSelectionError( + "UNBOUND_TRANSITION_PRIMITIVE_INPUT", + primitive_names, + f"primitive op {op_index} input {input_name!r} is neither an external input nor a prior output", + ) + for output_name in outputs: + if output_name in produced_by: + raise TransitionProgramExecutorSelectionError( + "ILLEGAL_TRANSITION_PRIMITIVE_DAG", + primitive_names, + f"transition primitive tensor {output_name!r} is produced by multiple ops", + ) + produced_by[output_name] = int(op_index) + dag_ops.append( + TransitionPrimitiveDagOp( + op_index=int(op_index), + primitive=primitive, + inputs=inputs, + outputs=outputs, + parameter_inputs=parameter_inputs_for_op, + tensor_role_contract=record.tensor_role_contract, + tape_policy=record.tape_policy, + forward_symbol=record.program_forward_symbol, + backward_symbol=record.program_backward_symbol, + param_grad_outputs=record.param_grad_outputs, + ) + ) + tape_contracts.append( + _transition_primitive_dag_tape_contract( + op_index=int(op_index), + primitive=primitive, + record=record, + non_parameter_inputs=non_parameter_inputs, + outputs=outputs, + ) + ) + produced_outputs = set(produced_by) + missing_state_outputs = tuple(output for output in state_outputs if output not in produced_outputs) + missing_public_outputs = tuple(output for output in public_outputs if output not in produced_outputs) + if missing_state_outputs or missing_public_outputs: + raise TransitionProgramExecutorSelectionError( + "UNBOUND_TRANSITION_PROGRAM_OUTPUT", + primitive_names, + "compiled transition program outputs are not produced by primitive rows: " + f"state={missing_state_outputs}; public={missing_public_outputs}", + ) + return TransitionPrimitiveDagExecutorPlan( + registry_id="transition_executor:primitive_dag:v1", + primitive_names=primitive_names, + ops=tuple(dag_ops), + tensor_edges=tuple(tensor_edges), + tape_contracts=tuple(tape_contracts), + external_inputs=external_inputs, + state_inputs=state_inputs, + state_outputs=state_outputs, + public_outputs=public_outputs, + message_inputs=message_inputs, + parameter_inputs=parameter_inputs, + ) + + +def _op_non_parameter_inputs( + inputs: tuple[str, ...], + parameter_inputs: tuple[str, ...], +) -> tuple[str, ...]: + parameter_input_set = set(parameter_inputs) + return tuple(input_name for input_name in inputs if input_name not in parameter_input_set) + + +def _transition_primitive_dag_tape_contract( + *, + op_index: int, + primitive: str, + record: Any, + non_parameter_inputs: tuple[str, ...], + outputs: tuple[str, ...], +) -> TransitionPrimitiveDagTapeContract: + reverse_inputs = tuple(str(item) for item in getattr(record, "reverse_input_bindings", ()) or ()) + forward_input_bindings = tuple(str(item) for item in getattr(record, "program_forward_input_bindings", ()) or ()) + forward_output_bindings = tuple( + str(logical_name) + for logical_name, _required in tuple(getattr(record, "program_forward_output_bindings", ()) or ()) + ) + input_by_logical = dict(zip(forward_input_bindings, non_parameter_inputs, strict=False)) + output_by_logical = dict(zip(forward_output_bindings, outputs, strict=False)) + saved_inputs = _resolve_transition_tape_bindings( + getattr(record, "tape_saved_input_bindings", ()), + input_by_logical, + primitive=primitive, + binding_kind="saved input", + ) + saved_outputs = _resolve_transition_tape_bindings( + getattr(record, "tape_saved_output_bindings", ()), + output_by_logical, + primitive=primitive, + binding_kind="saved output", + ) + recompute_inputs = _resolve_transition_tape_bindings( + 
getattr(record, "tape_recompute_input_bindings", ()), + input_by_logical, + primitive=primitive, + binding_kind="recompute input", + ) + recompute_outputs = _resolve_transition_tape_bindings( + getattr(record, "tape_recompute_output_bindings", ()), + output_by_logical, + primitive=primitive, + binding_kind="recompute output", + ) + tape_policy = str(getattr(record, "tape_policy", "")) + return TransitionPrimitiveDagTapeContract( + op_index=int(op_index), + primitive=primitive, + tape_policy=tape_policy, + saved_inputs=saved_inputs, + saved_outputs=saved_outputs, + recompute_inputs=recompute_inputs, + recompute_outputs=recompute_outputs, + reverse_inputs=reverse_inputs, + ) + + +def _resolve_transition_tape_bindings( + logical_names: tuple[str, ...], + actual_by_logical: dict[str, str], + *, + primitive: str, + binding_kind: str, +) -> tuple[str, ...]: + actual_names: list[str] = [] + for logical_name in tuple(str(item) for item in logical_names): + actual_name = actual_by_logical.get(logical_name) + if actual_name is None: + raise TransitionProgramExecutorSelectionError( + "INVALID_TRANSITION_PRIMITIVE_TAPE_CONTRACT", + (primitive,), + f"primitive {primitive!r} declares unknown tape {binding_kind} binding {logical_name!r}", + ) + actual_names.append(actual_name) + return tuple(actual_names) + + +def select_transition_program_executor(transition_program: Any) -> TransitionProgramExecutorPlan: + ops = _transition_program_ops(transition_program) + primitive_names = tuple(str(op.name) for op in ops) + missing_primitive_records = tuple( + primitive + for primitive in primitive_names + if transition_primitive_executor_record_for_lowered_primitive(primitive) is None + ) + if missing_primitive_records: + raise TransitionProgramExecutorSelectionError( + "UNREGISTERED_TRANSITION_PRIMITIVE", + primitive_names, + "missing primitive executor records: " + ", ".join(missing_primitive_records), + ) + primitive_blockers = transition_program_layer_blocker_codes(primitive_names) + if primitive_blockers: + raise TransitionProgramExecutorSelectionError( + primitive_blockers[0], + primitive_names, + "primitive executor blockers: " + ", ".join(primitive_blockers), + ) + primitive_dag = _build_transition_primitive_dag_executor_plan(transition_program) + for record in registered_transition_executor_records(): + if _matches_transition_executor_record(transition_program, record): + return TransitionProgramExecutorPlan( + executor=record.executor, + primitive_names=primitive_names, + registry_id=record.registry_id, + forward_strategy_id=record.forward_strategy_id, + backward_strategy_id=record.backward_strategy_id, + selection_kind="fused_program_record", + runtime_execution_status="registered_fused_program", + primitive_dag=primitive_dag, + program_layer_status="callable", + program_layer_blocker_codes=(), + program_forward_symbols=primitive_dag.forward_symbols, + program_backward_symbols=primitive_dag.backward_symbols, + program_missing_symbols=transition_program_layer_missing_symbols(primitive_names), + ) + return TransitionProgramExecutorPlan( + executor="primitive_dag", + primitive_names=primitive_names, + registry_id=primitive_dag.registry_id, + forward_strategy_id="forward.transition.primitive_dag.v1", + backward_strategy_id="reverse.transition.primitive_dag.v1", + selection_kind="primitive_dag", + runtime_execution_status="registered_primitive_dag_program", + primitive_dag=primitive_dag, + program_layer_status="callable", + program_layer_blocker_codes=(), + 
program_forward_symbols=primitive_dag.forward_symbols, + program_backward_symbols=primitive_dag.backward_symbols, + program_missing_symbols=transition_program_layer_missing_symbols(primitive_names), + ) + + +def _receiver_hidden_parameter(param: torch.Tensor, *, num_receivers: int, hidden_size: int) -> torch.Tensor: + return param.reshape(num_receivers, hidden_size) + + +def _scalar_parameter(value: object) -> float: + if torch.is_tensor(value): + return float(value.reshape(-1)[0].item()) + return float(value) + + +def _activation_id(value: object) -> int: + if torch.is_tensor(value) and value.numel() > 0: + return int(value.reshape(-1)[0].item()) + return 3 + + +__all__ = [ + "TransitionProgramExecutorKind", + "TransitionProgramExecutorPlan", + "TransitionProgramExecutorRecord", + "TransitionProgramExecutorSelectionError", + "TransitionPrimitiveDagExecutorPlan", + "TransitionPrimitiveDagOp", + "TransitionPrimitiveDagTapeContract", + "TransitionPrimitiveDagTensorEdge", + "_activation_id", + "_build_transition_primitive_dag_executor_plan", + "_compiled_transition_program_for_population", + "_eligibility_trace_state_names", + "_grad_state_value", + "_receiver_hidden_parameter", + "_record_compiled_transition_executor_selection", + "_record_transition_executor_plan", + "_reset_state_inputs_for_transition", + "_scalar_parameter", + "_transition_program_ops", + "_transition_primitive_executor_records_by_name", + "_zero_reset_row_grad", + "registered_transition_primitive_executor_records", + "registered_transition_executor_records", + "select_transition_program_executor", +] diff --git a/src/cortical/fabric/backend/cuda/transition_execution.py b/src/cortical/fabric/backend/cuda/transition_execution/projection.py similarity index 50% rename from src/cortical/fabric/backend/cuda/transition_execution.py rename to src/cortical/fabric/backend/cuda/transition_execution/projection.py index f0ef7903..bf2cd180 100644 --- a/src/cortical/fabric/backend/cuda/transition_execution.py +++ b/src/cortical/fabric/backend/cuda/transition_execution/projection.py @@ -1,32 +1,24 @@ from __future__ import annotations -from collections.abc import Callable, Mapping -from dataclasses import dataclass +from collections.abc import Sequence from typing import Any, cast import torch -from tensordict import TensorDict, TensorDictBase from torch.autograd import Function -from cortical.fabric.backend.cell_backend import CellBackendSpec, TransitionParameterBinding +from cortical.fabric.backend.cuda.temporal_param_binding import ( + resolve_transition_parameter as _resolve_transition_parameter, +) from cortical.fabric.backend.cuda.ops import ( - diagonal_recurrence_backward_cuda, - diagonal_recurrence_forward_autograd_cuda, - diagonal_recurrence_forward_cuda, - diagonal_recurrence_output_projection_backward_cuda, - diagonal_recurrence_preproj_cuda, - gated_logspace_recurrence_outnorm_backward_cuda, - gated_logspace_recurrence_outnorm_cuda, - gated_logspace_recurrence_outnorm_forward_cuda, receiver_major_affine_backward_cuda, receiver_major_affine_bias_backward_cuda, receiver_major_affine_bias_cuda, receiver_major_affine_bias_out_cuda, receiver_major_affine_bias_small_batch_cuda, receiver_major_affine_cuda, + receiver_major_affine_input_backward_cuda, receiver_major_affine_out_cuda, receiver_major_affine_small_batch_cuda, - reset_backend_tensors_rows_cuda, ) from cortical.fabric.backend.cuda.ops.factorized_projection_grads_triton import ( factorized_recurrent_input_base_cuda, @@ -41,40 +33,10 @@ 
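# Sketch (assumption-labeled, not part of the diff): select_transition_program_executor
# in transition_execution/program.py fails closed when any compiled primitive lacks an
# executor record, prefers a fused program record that matches the full op list, and
# otherwise falls back to the per-primitive DAG plan. The choose_executor stand-in,
# its record dicts, and "fused:v1" are illustrative; the rejection code and the
# primitive-DAG registry id string are taken from the diff.
def choose_executor(primitives, primitive_records, fused_records):
    # (1) Every compiled primitive needs a registered primitive executor record.
    missing = [name for name in primitives if name not in primitive_records]
    if missing:
        raise RuntimeError("UNREGISTERED_TRANSITION_PRIMITIVE: " + ", ".join(missing))
    # (2) A fused program record matching the whole op list wins.
    for record in fused_records:
        if record["primitive_names"] == tuple(primitives):
            return "fused_program_record", record["registry_id"]
    # (3) Otherwise fall back to the generic primitive-DAG plan.
    return "primitive_dag", "transition_executor:primitive_dag:v1"


print(choose_executor(
    ["matmul", "diag_rtu", "matmul"],
    {"matmul": object(), "diag_rtu": object()},
    [{"primitive_names": ("matmul", "gated_logspace_recurrence"), "registry_id": "fused:v1"}],
))
# ('primitive_dag', 'transition_executor:primitive_dag:v1')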
receiver_major_affine_backward_block_b, receiver_major_projection_backward_gate, ) - - -@dataclass(frozen=True) -class TransitionBackwardResult: - grad_recurrent_msg: torch.Tensor | None - grad_packed_state_before: dict[str, torch.Tensor | None] - materialized_param_grads: dict[str, torch.Tensor] - static_source_grads: dict[str, torch.Tensor] - - -@dataclass(frozen=True) -class _ProjectionBackwardTape: - output: torch.Tensor - backward: Callable[ - [torch.Tensor | None], - tuple[torch.Tensor | None, dict[str, torch.Tensor], dict[str, torch.Tensor]], - ] - - -@dataclass(frozen=True) -class TransitionBackwardTape: - input_projection: _ProjectionBackwardTape | None = None - diagonal_preproj: torch.Tensor | None = None - gated_gate_logits: torch.Tensor | None = None - gated_recurrent_gate_logits: torch.Tensor | None = None - - -@dataclass(frozen=True) -class TransitionForwardResult: - next_packed_state: Any - recurrent_hidden: torch.Tensor - recurrent_k: torch.Tensor | None - recurrent_v: torch.Tensor | None - backward_tape: TransitionBackwardTape | None = None +from cortical.fabric.backend.cuda.transition_execution.types import ( + TransitionInputProjectionParamGradStep, + _ProjectionBackwardTape, +) class _SharedReceiverBiasLinearFunction(Function): @@ -318,736 +280,10 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.T return grad_y_heads.reshape(batch, num_receivers, hidden_size), grad_kernel -def lower_backend_population_transition_shared( - runtime: Any, - *, - population_name: str | None = None, - recurrent_msg: torch.Tensor, - packed_state_before: Any, - population_reset_step: torch.Tensor | None, - static_tensors: dict[str, object], - materialize_recurrent_kv: bool = True, -) -> tuple[Any, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: - result = lower_backend_population_transition_forward_result_shared( - runtime, - population_name=population_name, - recurrent_msg=recurrent_msg, - packed_state_before=packed_state_before, - population_reset_step=population_reset_step, - static_tensors=static_tensors, - materialize_recurrent_kv=materialize_recurrent_kv, - materialize_backward_tape=False, - ) - return result.next_packed_state, result.recurrent_hidden, result.recurrent_k, result.recurrent_v - - -def lower_backend_population_transition_forward_result_shared( - runtime: Any, - *, - population_name: str | None = None, - recurrent_msg: torch.Tensor, - packed_state_before: Any, - population_reset_step: torch.Tensor | None, - static_tensors: dict[str, object], - materialize_recurrent_kv: bool = True, - materialize_backward_tape: bool = False, - materialize_diagonal_preproj_tape: bool = True, - materialize_recurrence_backward_tape: bool | None = None, - materialize_next_state: bool = True, - materialize_trace_state_next: bool = True, -) -> TransitionForwardResult: - if materialize_recurrence_backward_tape is None: - materialize_recurrence_backward_tape = bool(materialize_diagonal_preproj_tape) - population_name = runtime._full_recurrent_population_name if population_name is None else population_name - if population_name is None: - raise RuntimeError("Fabric CUDA transition execution requires a resolved recurrent population") - population_spec = runtime._backend_population_specs.get(population_name) - if population_spec is None: - raise RuntimeError(f"Fabric CUDA transition execution has no backend spec for population {population_name}") - population_materialized = static_tensors.get("population_materialized") - if not isinstance(population_materialized, 
dict): - raise RuntimeError("Fabric CUDA transition execution requires materialized population parameters") - population_params = population_materialized.get(population_name) - if not isinstance(population_params, dict): - raise RuntimeError(f"Fabric CUDA transition execution has no materialized parameters for {population_name}") - virtual_fresh_zero_state = packed_state_before is None - if virtual_fresh_zero_state and materialize_next_state: - raise RuntimeError("Virtual fresh backend state requires unmaterialized next state on CUDA transition forward") - if not virtual_fresh_zero_state and not isinstance(packed_state_before, Mapping): - raise RuntimeError("Fabric CUDA transition execution requires mapping-packed recurrent state") - if _is_gated_logspace_recurrence_transition(population_spec): - return _lower_gated_logspace_recurrence_transition( - runtime, - population_spec=population_spec, - population_params=population_params, - packed_state_before=None - if virtual_fresh_zero_state - else cast(Mapping[str, torch.Tensor], packed_state_before), - population_reset_step=population_reset_step, - recurrent_msg=recurrent_msg, - static_tensors=static_tensors, - materialize_recurrent_kv=materialize_recurrent_kv, - materialize_backward_tape=materialize_backward_tape, - materialize_recurrence_backward_tape=bool(materialize_recurrence_backward_tape), - materialize_next_state=materialize_next_state, - ) - if not _is_diagonal_recurrence_transition(population_spec): - op_names = ", ".join(op.name for op in population_spec.transition_ir.ops) or "" - raise RuntimeError( - "Fabric CUDA transition execution has no registered backward contract for transition IR ops: " + op_names - ) - - ops = population_spec.transition_ir.ops - input_op = ops[0] - recurrence_op = ops[1] - output_op = ops[2] - batch_size = int(recurrent_msg.shape[0]) - num_receivers = int(recurrent_msg.shape[1]) - hidden_size = int(runtime.hidden_size) - - input_tape: _ProjectionBackwardTape | None = None - with ( - torch.profiler.record_function("fabric.projection.population_input"), - runtime._backend_owner_timing("artifact.transition.input_projection"), - ): - if materialize_backward_tape: - input_tape = _diagonal_recurrence_input_projection_tape( - recurrent_msg, - population_spec=population_spec, - population_params=population_params, - static_tensors=static_tensors, - input_op=input_op, - num_receivers=num_receivers, - retain_factorized_base=bool(materialize_recurrence_backward_tape), - ) - cell_input = input_tape.output - else: - cell_input = _diagonal_recurrence_input_projection( - recurrent_msg, - population_spec=population_spec, - population_params=population_params, - static_tensors=static_tensors, - input_op=input_op, - num_receivers=num_receivers, - ) - - trace_state_names = _eligibility_trace_state_names(population_spec) - if virtual_fresh_zero_state: - state_tensors = tuple( - cell_input.new_zeros((batch_size, num_receivers, hidden_size)) - for _state_name in population_spec.transition_ir.state_inputs - ) - else: - packed_state_mapping = cast(Mapping[str, torch.Tensor], packed_state_before) - state_tensors = tuple( - packed_state_mapping[state_name] for state_name in population_spec.transition_ir.state_inputs - ) - state_tensors = _reset_state_inputs_for_transition( - state_tensors, - population_reset_step=population_reset_step, - batch_size=batch_size, - device=recurrent_msg.device, - ) - - nu_log = _receiver_hidden_parameter( - _resolve_transition_parameter( - population_spec, - population_params, - static_tensors, - 
recurrence_op.inputs[11], - num_receivers=num_receivers, - ), - num_receivers=num_receivers, - hidden_size=hidden_size, - ) - theta_log = _receiver_hidden_parameter( - _resolve_transition_parameter( - population_spec, - population_params, - static_tensors, - recurrence_op.inputs[12], - num_receivers=num_receivers, - ), - num_receivers=num_receivers, - hidden_size=hidden_size, - ) - w1 = _receiver_hidden_parameter( - _resolve_transition_parameter( - population_spec, - population_params, - static_tensors, - recurrence_op.inputs[13], - num_receivers=num_receivers, - ), - num_receivers=num_receivers, - hidden_size=hidden_size, - ) - w2 = _receiver_hidden_parameter( - _resolve_transition_parameter( - population_spec, - population_params, - static_tensors, - recurrence_op.inputs[14], - num_receivers=num_receivers, - ), - num_receivers=num_receivers, - hidden_size=hidden_size, - ) - activation_id = _activation_id(population_params.get("activation_id")) - - recurrence_inputs = ( - cell_input.contiguous(), - *(state.contiguous() for state in state_tensors), - nu_log.contiguous(), - theta_log.contiguous(), - w1.contiguous(), - w2.contiguous(), - ) - with runtime._backend_owner_timing("artifact.transition.diagonal_recurrence"): - recurrence_outputs = ( - diagonal_recurrence_forward_autograd_cuda( - *recurrence_inputs, - activation_id=activation_id, - write_trace_state_next=materialize_trace_state_next, - ) - if torch.is_grad_enabled() - else diagonal_recurrence_forward_cuda( - *recurrence_inputs, - activation_id=activation_id, - write_trace_state_next=materialize_trace_state_next, - ) - ) - preproj = cast(torch.Tensor, recurrence_outputs[0]) - if materialize_next_state: - next_state_values = cast(tuple[torch.Tensor, ...], recurrence_outputs[1:]) - next_state_by_name: dict[str, torch.Tensor] | None = { - state_name: value - for state_name, value in zip( - ( - population_spec.transition_ir.state_inputs - if materialize_trace_state_next - else tuple( - state_name - for state_name in population_spec.transition_ir.state_inputs - if state_name not in trace_state_names - ) - ), - next_state_values, - strict=True, - ) - } - else: - next_state_by_name = None - - output_weight = _resolve_transition_parameter( - population_spec, - population_params, - static_tensors, - output_op.inputs[1], - num_receivers=num_receivers, - ) - output_bias = _resolve_transition_parameter( - population_spec, - population_params, - static_tensors, - output_op.inputs[2], - num_receivers=num_receivers, - ) - with ( - torch.profiler.record_function("fabric.projection.population_output"), - runtime._backend_owner_timing("artifact.transition.output_projection"), - ): - public_y = _receiver_major_linear(preproj, output_weight, output_bias) - next_packed_state = ( - None - if next_state_by_name is None - else TensorDict( - next_state_by_name, - batch_size=[batch_size, num_receivers], - device=recurrent_msg.device, - ) - ) - recurrent_k: torch.Tensor | None = None - recurrent_v: torch.Tensor | None = None - if materialize_recurrent_kv: - with ( - torch.profiler.record_function("fabric.projection.population_public_kv"), - runtime._backend_owner_timing("artifact.transition.public_kv_projection"), - ): - recurrent_k, recurrent_v = runtime._project_sender_kv_from_cells_step( - public_y, - sender_input_to_kv_weight=cast( - torch.Tensor | None, - static_tensors["recurrent_sender_input_to_kv_weight"], - ), - grouped_sender_input_to_kv_weight=cast( - torch.Tensor | None, - static_tensors["recurrent_group_input_to_kv_weight"], - ), - 
sender_group_size=runtime._recurrent_sender_kv_group_size, - ) - backward_tape = None - if materialize_backward_tape: - backward_tape = TransitionBackwardTape( - input_projection=input_tape, - diagonal_preproj=preproj if materialize_recurrence_backward_tape else None, - ) - return TransitionForwardResult(next_packed_state, public_y, recurrent_k, recurrent_v, backward_tape) - - -def lower_backend_population_transition_backward_shared( - runtime: Any, - *, - population_name: str | None = None, - recurrent_msg: torch.Tensor, - packed_state_before: Any, - population_reset_step: torch.Tensor | None, - static_tensors: dict[str, object], - grad_next_packed_state: Any, - grad_recurrent_hidden: torch.Tensor | None, - need_grad_packed_state_before: bool = True, - forward_tape: TransitionBackwardTape | None = None, -) -> TransitionBackwardResult: - population_name = runtime._full_recurrent_population_name if population_name is None else population_name - if population_name is None: - raise RuntimeError("Fabric CUDA transition backward requires a resolved recurrent population") - population_spec = runtime._backend_population_specs.get(population_name) - if population_spec is None: - raise RuntimeError(f"Fabric CUDA transition backward has no backend spec for population {population_name}") - population_materialized = static_tensors.get("population_materialized") - if not isinstance(population_materialized, dict): - raise RuntimeError("Fabric CUDA transition backward requires materialized population parameters") - population_params = population_materialized.get(population_name) - if not isinstance(population_params, dict): - raise RuntimeError(f"Fabric CUDA transition backward has no materialized parameters for {population_name}") - if not isinstance(packed_state_before, Mapping): - raise RuntimeError("Fabric CUDA transition backward requires mapping-packed recurrent state") - if not recurrent_msg.is_cuda: - raise RuntimeError("Fabric CUDA transition backward requires CUDA tensors") - if _is_gated_logspace_recurrence_transition(population_spec): - return _lower_gated_logspace_recurrence_backward( - runtime, - population_spec=population_spec, - population_params=population_params, - packed_state_before=cast(Mapping[str, torch.Tensor], packed_state_before), - population_reset_step=population_reset_step, - recurrent_msg=recurrent_msg, - static_tensors=static_tensors, - grad_next_packed_state=grad_next_packed_state, - grad_recurrent_hidden=grad_recurrent_hidden, - need_grad_packed_state_before=need_grad_packed_state_before, - forward_tape=forward_tape, - ) - if _is_diagonal_recurrence_transition(population_spec): - return _lower_diagonal_recurrence_backward( - runtime, - population_spec=population_spec, - population_params=population_params, - packed_state_before=cast(Mapping[str, torch.Tensor], packed_state_before), - population_reset_step=population_reset_step, - recurrent_msg=recurrent_msg, - static_tensors=static_tensors, - grad_next_packed_state=grad_next_packed_state, - grad_recurrent_hidden=grad_recurrent_hidden, - need_grad_packed_state_before=need_grad_packed_state_before, - forward_tape=forward_tape, - ) - op_names = ", ".join(op.name for op in population_spec.transition_ir.ops) or "" - raise RuntimeError( - "Fabric CUDA transition backward has no registered physical executor for transition IR ops: " + op_names - ) - - -def _lower_gated_logspace_recurrence_backward( - runtime: Any, - *, - population_spec: CellBackendSpec, - population_params: dict[str, object], - packed_state_before: Mapping[str, 
torch.Tensor], - population_reset_step: torch.Tensor | None, - recurrent_msg: torch.Tensor, - static_tensors: dict[str, object], - grad_next_packed_state: Any, - grad_recurrent_hidden: torch.Tensor | None, - need_grad_packed_state_before: bool, - forward_tape: TransitionBackwardTape | None, -) -> TransitionBackwardResult: - ops = population_spec.transition_ir.ops - input_op = ops[0] - recurrent_op = ops[1] - core_op = ops[2] - public_op = ops[3] - batch_size = int(recurrent_msg.shape[0]) - num_receivers = int(recurrent_msg.shape[1]) - hidden_size = int(runtime.hidden_size) - - gate_weight = _resolve_transition_parameter( - population_spec, - population_params, - static_tensors, - input_op.inputs[1], - num_receivers=num_receivers, - ) - gate_bias = _resolve_transition_parameter( - population_spec, - population_params, - static_tensors, - input_op.inputs[2], - num_receivers=num_receivers, - ) - recurrent_kernel = _resolve_transition_parameter( - population_spec, - population_params, - static_tensors, - recurrent_op.inputs[1], - num_receivers=num_receivers, - ) - - if forward_tape is not None and forward_tape.input_projection is not None: - input_tape = forward_tape.input_projection - population_input = input_tape.output - else: - with ( - torch.profiler.record_function("fabric.backward.receiver_affine.input_projection.recompute"), - runtime._backend_owner_timing("receiver_affine.input_projection_recompute"), - ): - input_tape = _gated_logspace_input_projection_tape(recurrent_msg, static_tensors=static_tensors) - population_input = input_tape.output - - if forward_tape is not None and forward_tape.gated_gate_logits is not None: - gate_logits = forward_tape.gated_gate_logits - else: - with runtime._backend_owner_timing("receiver_affine.gate_affine_recompute"): - gate_logits = _transition_linear_no_grad(population_input, gate_weight, gate_bias, hidden_size=hidden_size) - - y_prev_original = packed_state_before[recurrent_op.inputs[0]] - c_prev_original = packed_state_before[core_op.inputs[2]] - n_prev_original = packed_state_before[core_op.inputs[3]] - m_prev_original = packed_state_before[core_op.inputs[4]] - y_prev, c_prev, n_prev, m_prev = _reset_state_inputs_for_transition( - (y_prev_original, c_prev_original, n_prev_original, m_prev_original), - population_reset_step=population_reset_step, - batch_size=batch_size, - device=recurrent_msg.device, - ) - - if forward_tape is not None and forward_tape.gated_recurrent_gate_logits is not None: - recurrent_gate_logits = forward_tape.gated_recurrent_gate_logits - else: - with runtime._backend_owner_timing("receiver_affine.recurrent_affine_recompute"): - recurrent_gate_logits = _gated_recurrent_matmul_no_grad(y_prev, recurrent_kernel) - outnorm_weight = _receiver_hidden_parameter( - _resolve_transition_parameter( - population_spec, - population_params, - static_tensors, - public_op.inputs[1], - num_receivers=num_receivers, - ), - num_receivers=num_receivers, - hidden_size=hidden_size, - ) - eps = _scalar_parameter(population_params["outnorm_eps"]) - grad_outputs = ( - grad_recurrent_hidden, - _grad_state_value(grad_next_packed_state, core_op.outputs[0].removeprefix("next_")), - _grad_state_value(grad_next_packed_state, core_op.outputs[1].removeprefix("next_")), - _grad_state_value(grad_next_packed_state, core_op.outputs[2].removeprefix("next_")), - _grad_state_value(grad_next_packed_state, core_op.outputs[3].removeprefix("next_")), - ) - with ( - torch.profiler.record_function("fabric.backward.state_epilogue"), - 
runtime._backend_owner_timing("state_epilogue.core"), - ): - grad_raw, grad_c_prev, grad_n_prev, grad_m_prev, grad_outnorm_weight = ( - gated_logspace_recurrence_outnorm_backward_cuda( - gate_logits, - recurrent_gate_logits, - c_prev, - n_prev, - m_prev, - outnorm_weight, - grad_outputs, - eps=eps, - return_state_grads=need_grad_packed_state_before, - ) - ) - - with runtime._backend_owner_timing("receiver_affine.gate_affine_backward"): - grad_population_input, grad_gate_weight, grad_gate_bias = _transition_linear_backward_no_grad( - population_input, - gate_weight, - gate_bias, - grad_raw, - hidden_size=hidden_size, - ) - with runtime._backend_owner_timing("receiver_affine.recurrent_affine_backward"): - grad_y_prev, grad_recurrent_kernel = _gated_recurrent_matmul_backward_no_grad_impl( - y_prev, - recurrent_kernel, - grad_raw, - return_input_grad=need_grad_packed_state_before, - ) - with runtime._backend_owner_timing("receiver_affine.input_projection_backward"): - grad_recurrent_msg, static_param_grads, materialized_param_grads_from_input = input_tape.backward( - grad_population_input, - ) - materialized_param_grads = { - "gate_weight": grad_gate_weight, - "recurrent_kernel": grad_recurrent_kernel, - "bias": grad_gate_bias, - "outnorm_weight": grad_outnorm_weight, - **materialized_param_grads_from_input, - } - grad_state_before = { - recurrent_op.inputs[0]: _zero_reset_row_grad(grad_y_prev, population_reset_step, batch_size=batch_size), - core_op.inputs[2]: _zero_reset_row_grad(grad_c_prev, population_reset_step, batch_size=batch_size), - core_op.inputs[3]: _zero_reset_row_grad(grad_n_prev, population_reset_step, batch_size=batch_size), - core_op.inputs[4]: _zero_reset_row_grad(grad_m_prev, population_reset_step, batch_size=batch_size), - } - return TransitionBackwardResult( - grad_recurrent_msg=grad_recurrent_msg, - grad_packed_state_before=grad_state_before, - materialized_param_grads=materialized_param_grads, - static_source_grads=static_param_grads, - ) - - -def _lower_diagonal_recurrence_backward( - runtime: Any, - *, - population_spec: CellBackendSpec, - population_params: dict[str, object], - packed_state_before: Mapping[str, torch.Tensor], - population_reset_step: torch.Tensor | None, - recurrent_msg: torch.Tensor, - static_tensors: dict[str, object], - grad_next_packed_state: Any, - grad_recurrent_hidden: torch.Tensor | None, - need_grad_packed_state_before: bool, - forward_tape: TransitionBackwardTape | None, -) -> TransitionBackwardResult: - ops = population_spec.transition_ir.ops - input_op = ops[0] - recurrence_op = ops[1] - output_op = ops[2] - batch_size = int(recurrent_msg.shape[0]) - num_receivers = int(recurrent_msg.shape[1]) - hidden_size = int(runtime.hidden_size) - - if forward_tape is not None and forward_tape.input_projection is not None: - input_tape = forward_tape.input_projection - cell_input = input_tape.output - else: - with ( - torch.profiler.record_function("fabric.backward.receiver_affine.input_projection.recompute"), - runtime._backend_owner_timing("receiver_affine.input_projection_recompute"), - ): - input_tape = _diagonal_recurrence_input_projection_tape( - recurrent_msg, - population_spec=population_spec, - population_params=population_params, - static_tensors=static_tensors, - input_op=input_op, - num_receivers=num_receivers, - ) - cell_input = input_tape.output - - state_names = tuple(population_spec.transition_ir.state_inputs) - eligibility_trace_states = _eligibility_trace_state_names(population_spec) - propagate_trace_state_grads = not 
eligibility_trace_states - dummy_trace_state = recurrent_msg.new_empty((1,)) - state_tensors = tuple( - packed_state_before[state_name] - if state_name in packed_state_before - else dummy_trace_state - if state_name in eligibility_trace_states and not propagate_trace_state_grads - else packed_state_before[state_name] - for state_name in state_names - ) - reset_state_tensors = _reset_state_inputs_for_transition( - state_tensors, - population_reset_step=population_reset_step, - batch_size=batch_size, - device=recurrent_msg.device, - ) - nu_log = _receiver_hidden_parameter( - _resolve_transition_parameter( - population_spec, - population_params, - static_tensors, - recurrence_op.inputs[11], - num_receivers=num_receivers, - ), - num_receivers=num_receivers, - hidden_size=hidden_size, - ) - theta_log = _receiver_hidden_parameter( - _resolve_transition_parameter( - population_spec, - population_params, - static_tensors, - recurrence_op.inputs[12], - num_receivers=num_receivers, - ), - num_receivers=num_receivers, - hidden_size=hidden_size, - ) - w1 = _receiver_hidden_parameter( - _resolve_transition_parameter( - population_spec, - population_params, - static_tensors, - recurrence_op.inputs[13], - num_receivers=num_receivers, - ), - num_receivers=num_receivers, - hidden_size=hidden_size, - ) - w2 = _receiver_hidden_parameter( - _resolve_transition_parameter( - population_spec, - population_params, - static_tensors, - recurrence_op.inputs[14], - num_receivers=num_receivers, - ), - num_receivers=num_receivers, - hidden_size=hidden_size, - ) - activation_id = _activation_id(population_params.get("activation_id")) - recurrence_inputs = ( - cell_input.contiguous(), - *(state.contiguous() for state in reset_state_tensors), - nu_log.contiguous(), - theta_log.contiguous(), - w1.contiguous(), - w2.contiguous(), - ) - output_weight = _resolve_transition_parameter( - population_spec, - population_params, - static_tensors, - output_op.inputs[1], - num_receivers=num_receivers, - ) - output_bias = _resolve_transition_parameter( - population_spec, - population_params, - static_tensors, - output_op.inputs[2], - num_receivers=num_receivers, - ) - preproj = None if forward_tape is None else forward_tape.diagonal_preproj - grad_output_weight_is_parameter_layout = False - state_grad_outputs = tuple( - None if state_name in eligibility_trace_states else _grad_state_value(grad_next_packed_state, state_name) - for state_name in population_spec.transition_ir.state_inputs - ) - recurrence_grads: tuple[torch.Tensor | None, ...] 
| None = None - if grad_recurrent_hidden is None: - grad_preproj = None - grad_output_weight_t = torch.zeros_like(output_weight) - grad_output_bias = None if output_bias is None else torch.zeros_like(output_bias) - elif preproj is None and output_weight.dim() == 3: - with runtime._backend_owner_timing("diagonal_recurrence.output_projection_backward"): - grad_preproj, grad_output_weight_t, grad_output_bias = diagonal_recurrence_output_projection_backward_cuda( - recurrence_inputs[0], - recurrence_inputs[1], - recurrence_inputs[2], - recurrence_inputs[11], - recurrence_inputs[12], - recurrence_inputs[13], - recurrence_inputs[14], - output_weight, - output_bias, - grad_recurrent_hidden, - activation_id=activation_id, - block_b=receiver_major_affine_backward_block_b( - batch_size=batch_size, - output_dim=int(output_weight.shape[2]), - ), - return_grad_weight_transposed=False, - ) - grad_output_weight_is_parameter_layout = True - else: - if preproj is None: - with runtime._backend_owner_timing("diagonal_recurrence.output_projection_recompute"): - preproj = diagonal_recurrence_preproj_cuda( - recurrence_inputs[0], - recurrence_inputs[1], - recurrence_inputs[2], - recurrence_inputs[11], - recurrence_inputs[12], - recurrence_inputs[13], - recurrence_inputs[14], - activation_id=activation_id, - ) - with runtime._backend_owner_timing("diagonal_recurrence.output_projection_backward"): - grad_preproj, grad_output_weight_t, grad_output_bias = _receiver_major_linear_backward_no_grad( - preproj, - output_weight, - output_bias, - grad_recurrent_hidden, - ) - grad_outputs = ( - grad_preproj, - *state_grad_outputs, - ) - if recurrence_grads is None: - with ( - torch.profiler.record_function("fabric.backward.diagonal_recurrence"), - runtime._backend_owner_timing("diagonal_recurrence.core"), - ): - recurrence_grads = diagonal_recurrence_backward_cuda( - *recurrence_inputs, - grad_outputs, - activation_id=activation_id, - return_state_grads=need_grad_packed_state_before, - return_trace_state_grads=propagate_trace_state_grads, - ) - grad_cell_input = recurrence_grads[0] - grad_state_tensors = cast(tuple[torch.Tensor, ...], recurrence_grads[1:11]) - grad_nu_log, grad_theta_log, grad_w1, grad_w2 = cast(tuple[torch.Tensor, ...], recurrence_grads[11:15]) - with runtime._backend_owner_timing("receiver_affine.input_projection_backward"): - grad_recurrent_msg, static_param_grads, materialized_param_grads_from_input = input_tape.backward( - grad_cell_input - ) - materialized_param_grads = { - "nu_log": grad_nu_log, - "theta_log": grad_theta_log, - "w1": grad_w1, - "w2": grad_w2, - "out_proj_weight": grad_output_weight_t - if grad_output_weight_is_parameter_layout - else grad_output_weight_t.transpose(1, 2).contiguous(), - "out_proj_bias": grad_output_bias, - **materialized_param_grads_from_input, - } - grad_state_before = { - state_name: None - if state_name in eligibility_trace_states - else _zero_reset_row_grad(grad, population_reset_step, batch_size=batch_size) - for state_name, grad in zip(state_names, grad_state_tensors, strict=True) - } - return TransitionBackwardResult( - grad_recurrent_msg=grad_recurrent_msg, - grad_packed_state_before=grad_state_before, - materialized_param_grads=materialized_param_grads, - static_source_grads=static_param_grads, - ) - - -def _eligibility_trace_state_names(population_spec: CellBackendSpec) -> frozenset[str]: - return frozenset( - schema.name for schema in population_spec.private_state_schema if schema.semantic_kind == "eligibility_trace" - ) - - def 
_diagonal_recurrence_input_projection( recurrent_msg: torch.Tensor, *, - population_spec: CellBackendSpec, + transition_program: Any, population_params: dict[str, object], static_tensors: dict[str, object], input_op: Any, @@ -1080,77 +316,249 @@ def _diagonal_recurrence_input_projection( recurrent_cell_bias, input_proj_weight, ) - input_weight = _resolve_transition_parameter( - population_spec, - population_params, - static_tensors, - input_op.inputs[1], - num_receivers=num_receivers, - ) - input_bias = _resolve_transition_parameter( - population_spec, - population_params, - static_tensors, - input_op.inputs[2], - num_receivers=num_receivers, - ) - return _receiver_major_linear(recurrent_msg, input_weight, input_bias) - - -def _grad_state_value(grad_state: Any, state_name: str) -> torch.Tensor | None: - if grad_state is None: - return None - if isinstance(grad_state, (dict, TensorDictBase)): - return cast(torch.Tensor | None, grad_state.get(state_name)) - if torch.is_tensor(grad_state): - return grad_state - raise RuntimeError(f"Unsupported Fabric CUDA transition grad-state container {type(grad_state).__name__}") - - -def _reset_state_inputs_for_transition( - state_tensors: tuple[torch.Tensor, ...], - *, - population_reset_step: torch.Tensor | None, - batch_size: int, - device: torch.device, -) -> tuple[torch.Tensor, ...]: - if population_reset_step is None: - return state_tensors - reset_rows = torch.as_tensor(population_reset_step, device=device, dtype=torch.bool).view(batch_size) - resettable_indices = tuple( - index - for index, tensor in enumerate(state_tensors) - if tensor.dim() >= 2 and int(tensor.shape[0]) == int(batch_size) - ) - if len(resettable_indices) == len(state_tensors): - return cast(tuple[torch.Tensor, ...], reset_backend_tensors_rows_cuda(state_tensors, reset_rows)) - resettable = tuple(state_tensors[index] for index in resettable_indices) - reset_values = reset_backend_tensors_rows_cuda(resettable, reset_rows) if resettable else () - reset_by_index = dict(zip(resettable_indices, reset_values, strict=True)) - return tuple(cast(torch.Tensor, reset_by_index.get(index, tensor)) for index, tensor in enumerate(state_tensors)) - - -def _zero_reset_row_grad( - grad: torch.Tensor | None, - population_reset_step: torch.Tensor | None, - *, - batch_size: int, -) -> torch.Tensor | None: - if grad is None or population_reset_step is None: - return grad - reset_rows = torch.as_tensor(population_reset_step, device=grad.device, dtype=torch.bool).view(batch_size) - return cast(torch.Tensor, reset_backend_tensors_rows_cuda((grad,), reset_rows)[0]) + input_weight = _resolve_transition_parameter( + transition_program, + population_params, + static_tensors, + input_op.inputs[1], + num_receivers=num_receivers, + ) + input_bias = _resolve_transition_parameter( + transition_program, + population_params, + static_tensors, + input_op.inputs[2], + num_receivers=num_receivers, + ) + return _receiver_major_linear(recurrent_msg, input_weight, input_bias) + + +def _accumulate_named_grad( + grads: dict[str, torch.Tensor], + name: str, + grad: torch.Tensor | None, +) -> None: + if grad is None: + return + existing = grads.get(name) + grads[name] = grad if existing is None else existing + grad + + +def _tensor_group_identity(tensor: torch.Tensor | None) -> tuple[object, ...]: + if tensor is None: + return ("none",) + return ( + "tensor", + int(tensor.data_ptr()), + tuple(int(dim) for dim in tensor.shape), + tuple(int(stride) for stride in tensor.stride()), + ) + + +def 
_head_grouped_gate_affine_flat_tensors( + x: torch.Tensor, + weight: torch.Tensor, + grad_output: torch.Tensor, +) -> tuple[int, int, int, int, torch.Tensor, torch.Tensor, torch.Tensor]: + batch, num_receivers, gates, hidden_size = grad_output.shape + receivers, num_heads, head_dim, gate_dim = weight.shape + if receivers != num_receivers or gates != 4 or hidden_size != num_heads * head_dim or gate_dim != 4 * head_dim: + raise RuntimeError("Fabric CUDA head-grouped gate affine backward received incompatible tensor shapes") + x_heads = x.view(batch, num_receivers, num_heads, head_dim) + grad_gate_proj = ( + grad_output.view(batch, num_receivers, 4, num_heads, head_dim) + .permute(0, 1, 3, 2, 4) + .reshape(batch, num_receivers * num_heads, 4 * head_dim) + .contiguous() + ) + return ( + batch, + num_receivers, + num_heads, + head_dim, + x_heads.reshape(batch, num_receivers * num_heads, head_dim), + weight.reshape(num_receivers * num_heads, head_dim, 4 * head_dim), + grad_gate_proj, + ) + + +def _receiver_major_linear_input_backward_no_grad( + x: torch.Tensor, + weight: torch.Tensor, + grad_output: torch.Tensor | None, +) -> torch.Tensor | None: + if grad_output is None: + return None + with torch.profiler.record_function("fabric.backward.receiver_affine.input_only"): + if weight.dim() == 3: + if ( + x.is_cuda + and weight.is_cuda + and grad_output.is_cuda + and x.dtype == torch.float32 + and weight.dtype == torch.float32 + and grad_output.dtype == torch.float32 + ): + return receiver_major_affine_input_backward_cuda( + x, + weight, + grad_output, + block_b=receiver_major_affine_backward_block_b( + batch_size=int(x.shape[0]), + output_dim=int(weight.shape[2]), + ), + ) + return torch.bmm(grad_output.transpose(0, 1), weight.transpose(1, 2)).transpose(0, 1) + if weight.dim() != 2: + raise RuntimeError(f"Unsupported affine weight rank {weight.dim()} for Fabric transition backward") + return grad_output.matmul(weight) + + +def _transition_linear_input_backward_no_grad( + x: torch.Tensor, + weight: torch.Tensor, + grad_output: torch.Tensor | None, +) -> torch.Tensor | None: + if grad_output is None: + return None + if weight.dim() != 4: + return _receiver_major_linear_input_backward_no_grad(x, weight, grad_output) + with torch.profiler.record_function("fabric.backward.receiver_affine.input_only"): + with torch.profiler.record_function("fabric.backward.gate_affine.receiver_affine"): + if not ( + x.is_cuda + and weight.is_cuda + and grad_output.is_cuda + and x.dtype == torch.float32 + and weight.dtype == torch.float32 + and grad_output.dtype == torch.float32 + ): + raise RuntimeError("Fabric CUDA transition gate affine input backward requires CUDA float32 tensors") + batch, num_receivers, num_heads, head_dim, flat_x, flat_weight, flat_grad_output = ( + _head_grouped_gate_affine_flat_tensors(x, weight, grad_output) + ) + grad_x_flat = receiver_major_affine_input_backward_cuda( + flat_x, + flat_weight, + flat_grad_output, + block_b=receiver_major_affine_backward_block_b( + batch_size=batch, + output_dim=4 * head_dim, + ), + ) + return grad_x_flat.reshape(batch, num_receivers, num_heads, head_dim).reshape( + batch, + num_receivers, + num_heads * head_dim, + ) + + +def _receiver_major_linear_param_backward_no_grad( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + grad_output: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor | None]: + with torch.profiler.record_function("fabric.backward.receiver_affine.temporal_param"): + if weight.dim() == 3: + if ( + x.is_cuda + and 
weight.is_cuda + and grad_output.is_cuda + and x.dtype == torch.float32 + and weight.dtype == torch.float32 + and grad_output.dtype == torch.float32 + ): + block_b = receiver_major_affine_backward_block_b( + batch_size=int(x.shape[0]), + output_dim=int(weight.shape[2]), + ) + if bias is not None and bias.is_cuda and bias.dtype == torch.float32: + _grad_input, grad_weight, grad_bias = receiver_major_affine_bias_backward_cuda( + x, + weight, + bias, + grad_output, + block_b=block_b, + return_input_grad=False, + ) + return grad_weight, grad_bias + _grad_input, grad_weight = receiver_major_affine_backward_cuda( + x, + weight, + grad_output, + block_b=block_b, + return_input_grad=False, + ) + else: + grad_weight = torch.bmm(x.transpose(0, 1).transpose(1, 2), grad_output.transpose(0, 1)) + grad_bias = None if bias is None else _receiver_bias_grad(grad_output, bias) + return grad_weight, grad_bias + if weight.dim() == 4: + if not ( + x.is_cuda + and weight.is_cuda + and grad_output.is_cuda + and x.dtype == torch.float32 + and weight.dtype == torch.float32 + and grad_output.dtype == torch.float32 + ): + raise RuntimeError("Fabric CUDA transition gate affine param backward requires CUDA float32 tensors") + batch, num_receivers, num_heads, head_dim, flat_x, flat_weight, flat_grad_output = ( + _head_grouped_gate_affine_flat_tensors(x, weight, grad_output) + ) + _grad_input, grad_weight_flat = receiver_major_affine_backward_cuda( + flat_x, + flat_weight, + flat_grad_output, + block_b=receiver_major_affine_backward_block_b( + batch_size=batch, + output_dim=4 * head_dim, + ), + return_input_grad=False, + ) + grad_weight = grad_weight_flat.reshape(num_receivers, num_heads, head_dim, 4 * head_dim) + grad_bias = None + if bias is not None: + grad_bias = ( + grad_output.sum(dim=0) + .reshape(num_receivers, 4 * num_heads * head_dim) + .view(num_receivers, num_heads, 4, head_dim) + .permute(0, 2, 1, 3) + .contiguous() + ) + return grad_weight, grad_bias + if weight.dim() != 2: + raise RuntimeError(f"Unsupported affine weight rank {weight.dim()} for Fabric transition backward") + grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(x.reshape(-1, x.shape[-1])) + grad_bias = None if bias is None else _receiver_bias_grad(grad_output, bias) + return grad_weight, grad_bias -def _accumulate_named_grad( - grads: dict[str, torch.Tensor], - name: str, - grad: torch.Tensor | None, -) -> None: - if grad is None: - return - existing = grads.get(name) - grads[name] = grad if existing is None else existing + grad +def reduce_transition_input_projection_param_grad_steps( + steps: Sequence[TransitionInputProjectionParamGradStep], +) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + materialized_param_grads: dict[str, torch.Tensor] = {} + static_source_grads: dict[str, torch.Tensor] = {} + grouped_steps: dict[tuple[object, ...], list[TransitionInputProjectionParamGradStep]] = {} + for step in steps: + grouped_steps.setdefault(step.group_key, []).append(step) + for group in grouped_steps.values(): + first = group[0] + input_tensor = torch.cat(tuple(step.input_tensor for step in group), dim=0) + grad_output = torch.cat(tuple(step.grad_output for step in group), dim=0) + grad_weight, grad_bias = _receiver_major_linear_param_backward_no_grad( + input_tensor, + first.weight, + first.bias, + grad_output, + ) + static_update, materialized_update = first.map_param_grads(grad_weight, grad_bias) + for name, grad in static_update.items(): + _accumulate_named_grad(static_source_grads, name, grad) + for name, 
grad in materialized_update.items(): + _accumulate_named_grad(materialized_param_grads, name, grad) + return materialized_param_grads, static_source_grads def _receiver_major_linear_backward_no_grad( @@ -1349,6 +757,129 @@ def _transition_linear_backward_no_grad( return grad_x_heads.reshape(batch, num_receivers, hidden_size), grad_weight, grad_bias +def _transition_gate_affine_param_grad_step( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + grad_output: torch.Tensor, +) -> TransitionInputProjectionParamGradStep: + def map_param_grads( + grad_weight: torch.Tensor, + grad_bias: torch.Tensor | None, + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + materialized_grads: dict[str, torch.Tensor] = {"gate_weight": grad_weight} + _accumulate_named_grad(materialized_grads, "bias", grad_bias) + return {}, materialized_grads + + return TransitionInputProjectionParamGradStep( + input_tensor=x, + weight=weight, + bias=bias, + grad_output=grad_output, + group_key=( + "gated_transition_gate_affine", + _tensor_group_identity(weight), + _tensor_group_identity(bias), + ), + map_param_grads=map_param_grads, + ) + + +def _diagonal_output_projection_param_grad_step( + preproj: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + grad_output: torch.Tensor, +) -> TransitionInputProjectionParamGradStep: + def map_param_grads( + grad_weight: torch.Tensor, + grad_bias: torch.Tensor | None, + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + materialized_grads: dict[str, torch.Tensor] = {"out_proj_weight": grad_weight} + _accumulate_named_grad(materialized_grads, "out_proj_bias", grad_bias) + return {}, materialized_grads + + return TransitionInputProjectionParamGradStep( + input_tensor=preproj, + weight=weight, + bias=bias, + grad_output=grad_output, + group_key=( + "diagonal_output_projection", + _tensor_group_identity(weight), + _tensor_group_identity(bias), + ), + map_param_grads=map_param_grads, + ) + + +def _static_recurrent_input_projection_param_grad_step( + recurrent_msg: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + grad_output: torch.Tensor, +) -> TransitionInputProjectionParamGradStep: + def map_param_grads( + grad_weight: torch.Tensor, + grad_bias: torch.Tensor | None, + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + static_grads: dict[str, torch.Tensor] = {} + _accumulate_named_grad(static_grads, "value_to_cell_weight", grad_weight) + _accumulate_named_grad(static_grads, "recurrent_cell_bias", grad_bias) + return static_grads, {} + + return TransitionInputProjectionParamGradStep( + input_tensor=recurrent_msg, + weight=weight, + bias=bias, + grad_output=grad_output, + group_key=( + "static_recurrent_input_projection", + _tensor_group_identity(weight), + _tensor_group_identity(bias), + ), + map_param_grads=map_param_grads, + ) + + +def _diagonal_input_projection_param_grad_step( + recurrent_msg: torch.Tensor, + input_weight: torch.Tensor, + input_bias: torch.Tensor | None, + grad_output: torch.Tensor, + *, + static_tensors: dict[str, object], +) -> TransitionInputProjectionParamGradStep: + def map_param_grads( + grad_weight: torch.Tensor, + grad_bias: torch.Tensor | None, + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + fused_recurrent_weight = static_tensors.get("fused_recurrent_value_to_cell_weight") + if torch.is_tensor(fused_recurrent_weight) and input_weight.data_ptr() == fused_recurrent_weight.data_ptr(): + return _unfuse_recurrent_input_projection_grads( + 
static_tensors=static_tensors, + grad_fused_weight=grad_weight, + grad_fused_bias=grad_bias, + ) + static_grads: dict[str, torch.Tensor] = {} + _accumulate_named_grad(static_grads, "value_to_cell_weight", grad_weight.sum(dim=0).transpose(0, 1)) + _accumulate_named_grad(static_grads, "recurrent_cell_bias", grad_bias) + return static_grads, {} + + return TransitionInputProjectionParamGradStep( + input_tensor=recurrent_msg, + weight=input_weight, + bias=input_bias, + grad_output=grad_output, + group_key=( + "diagonal_recurrent_input_projection", + _tensor_group_identity(input_weight), + _tensor_group_identity(input_bias), + ), + map_param_grads=map_param_grads, + ) + + def _gated_recurrent_matmul_no_grad(y_prev: torch.Tensor, recurrent_kernel: torch.Tensor) -> torch.Tensor: _, num_receivers, _ = y_prev.shape _, gates, num_heads, head_dim, _ = recurrent_kernel.shape @@ -1455,11 +986,16 @@ def _gated_logspace_input_projection_tape( recurrent_msg: torch.Tensor, *, static_tensors: dict[str, object], + output_override: torch.Tensor | None = None, ) -> _ProjectionBackwardTape: fused_weight = static_tensors["fused_recurrent_value_to_cell_weight"] if bool(static_tensors["fused_recurrent_population_input"]) and torch.is_tensor(fused_weight): fused_bias = cast(torch.Tensor, static_tensors["fused_recurrent_cell_bias"]) - output = _receiver_major_linear(recurrent_msg, fused_weight, fused_bias) + output = ( + _receiver_major_linear(recurrent_msg, fused_weight, fused_bias) + if output_override is None + else output_override + ) def backward( grad_output: torch.Tensor | None, @@ -1477,11 +1013,48 @@ def backward( ) return grad_input, static_grads, materialized_grads - return _ProjectionBackwardTape(output=output, backward=backward) + def map_param_grads( + grad_fused_weight: torch.Tensor, + grad_fused_bias: torch.Tensor | None, + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + if grad_fused_bias is None: + raise RuntimeError("Fused recurrent input projection requires a fused bias gradient") + return _unfuse_recurrent_input_projection_grads( + static_tensors=static_tensors, + grad_fused_weight=grad_fused_weight, + grad_fused_bias=grad_fused_bias, + ) + + def backward_deferred( + grad_output: torch.Tensor | None, + ) -> tuple[ + torch.Tensor | None, + dict[str, torch.Tensor], + dict[str, torch.Tensor], + TransitionInputProjectionParamGradStep | None, + ]: + grad_input = _receiver_major_linear_input_backward_no_grad(recurrent_msg, fused_weight, grad_output) + step = None + if grad_output is not None: + step = TransitionInputProjectionParamGradStep( + input_tensor=recurrent_msg, + weight=fused_weight, + bias=fused_bias, + grad_output=grad_output, + group_key=( + "fused_recurrent_input_projection", + _tensor_group_identity(fused_weight), + _tensor_group_identity(fused_bias), + ), + map_param_grads=map_param_grads, + ) + return grad_input, {}, {}, step + + return _ProjectionBackwardTape(output=output, backward=backward, backward_deferred=backward_deferred) weight = cast(torch.Tensor, static_tensors["value_to_cell_weight"]) bias = cast(torch.Tensor, static_tensors["recurrent_cell_bias"]) - output = torch.nn.functional.linear(recurrent_msg, weight) + bias + output = torch.nn.functional.linear(recurrent_msg, weight) + bias if output_override is None else output_override def backward( grad_output: torch.Tensor | None, @@ -1497,30 +1070,65 @@ def backward( _accumulate_named_grad(static_grads, "recurrent_cell_bias", grad_bias) return grad_input, static_grads, {} - return 
_ProjectionBackwardTape(output=output, backward=backward) + def map_param_grads( + grad_weight: torch.Tensor, + grad_bias: torch.Tensor | None, + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + static_grads: dict[str, torch.Tensor] = {} + _accumulate_named_grad(static_grads, "value_to_cell_weight", grad_weight) + _accumulate_named_grad(static_grads, "recurrent_cell_bias", grad_bias) + return static_grads, {} + + def backward_deferred( + grad_output: torch.Tensor | None, + ) -> tuple[ + torch.Tensor | None, + dict[str, torch.Tensor], + dict[str, torch.Tensor], + TransitionInputProjectionParamGradStep | None, + ]: + grad_input = _receiver_major_linear_input_backward_no_grad(recurrent_msg, weight, grad_output) + step = None + if grad_output is not None: + step = TransitionInputProjectionParamGradStep( + input_tensor=recurrent_msg, + weight=weight, + bias=bias, + grad_output=grad_output, + group_key=( + "static_recurrent_input_projection", + _tensor_group_identity(weight), + _tensor_group_identity(bias), + ), + map_param_grads=map_param_grads, + ) + return grad_input, {}, {}, step + + return _ProjectionBackwardTape(output=output, backward=backward, backward_deferred=backward_deferred) def _diagonal_recurrence_input_projection_tape( recurrent_msg: torch.Tensor, *, - population_spec: CellBackendSpec, + transition_program: Any, population_params: dict[str, object], static_tensors: dict[str, object], input_op: Any, num_receivers: int, retain_factorized_base: bool = True, + output_override: torch.Tensor | None = None, ) -> _ProjectionBackwardTape: if bool(static_tensors.get("replay_unfused_recurrent_input_projection", False)): raise RuntimeError("Fabric CUDA physical transition backward does not use replay input projection modes") input_weight = _resolve_transition_parameter( - population_spec, + transition_program, population_params, static_tensors, input_op.inputs[1], num_receivers=num_receivers, ) input_bias = _resolve_transition_parameter( - population_spec, + transition_program, population_params, static_tensors, input_op.inputs[2], @@ -1538,14 +1146,18 @@ def _diagonal_recurrence_input_projection_tape( ): value_to_cell_weight = cast(torch.Tensor, static_tensors["value_to_cell_weight"]) recurrent_cell_bias = cast(torch.Tensor, static_tensors["recurrent_cell_bias"]) - with torch.profiler.record_function("fabric.backward.projection.factorized_recurrent_input.base_recompute"): - base_input = factorized_recurrent_input_base_cuda( - recurrent_msg=recurrent_msg, - value_to_cell_weight=value_to_cell_weight, - recurrent_cell_bias=recurrent_cell_bias, - ) - output = _receiver_major_linear(base_input, input_proj_weight_t, None) - cached_base_input = base_input if retain_factorized_base else None + if output_override is None: + with torch.profiler.record_function("fabric.backward.projection.factorized_recurrent_input.base_recompute"): + base_input = factorized_recurrent_input_base_cuda( + recurrent_msg=recurrent_msg, + value_to_cell_weight=value_to_cell_weight, + recurrent_cell_bias=recurrent_cell_bias, + ) + output = _receiver_major_linear(base_input, input_proj_weight_t, None) + cached_base_input = base_input if retain_factorized_base else None + else: + output = output_override + cached_base_input = None def backward( grad_output: torch.Tensor | None, @@ -1584,9 +1196,68 @@ def backward( } return grad_recurrent_msg, static_grads, {"input_proj_weight": grad_input_proj_weight} - return _ProjectionBackwardTape(output=output, backward=backward) + def map_param_grads( + 
grad_input_proj_weight_t: torch.Tensor, + _grad_unused_bias: torch.Tensor | None, + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + return {}, {"input_proj_weight": grad_input_proj_weight_t.transpose(1, 2).contiguous()} + + def backward_deferred( + grad_output: torch.Tensor | None, + ) -> tuple[ + torch.Tensor | None, + dict[str, torch.Tensor], + dict[str, torch.Tensor], + TransitionInputProjectionParamGradStep | None, + ]: + local_base_input = cached_base_input + if local_base_input is None: + with torch.profiler.record_function( + "fabric.backward.projection.factorized_recurrent_input.base_recompute" + ): + local_base_input = factorized_recurrent_input_base_cuda( + recurrent_msg=recurrent_msg, + value_to_cell_weight=value_to_cell_weight, + recurrent_cell_bias=recurrent_cell_bias, + ) + grad_base = _receiver_major_linear_input_backward_no_grad( + local_base_input, + input_proj_weight_t, + grad_output, + ) + if grad_base is None or grad_output is None: + return None, {}, {}, None + with torch.profiler.record_function("fabric.backward.projection.factorized_recurrent_input.direct"): + grad_recurrent_msg, grad_value_to_cell_weight, grad_recurrent_cell_bias = ( + factorized_recurrent_input_base_grads_cuda( + recurrent_msg=recurrent_msg, + value_to_cell_weight=value_to_cell_weight, + recurrent_cell_bias=recurrent_cell_bias, + grad_base=grad_base, + ) + ) + static_grads: dict[str, torch.Tensor] = { + "value_to_cell_weight": grad_value_to_cell_weight, + "recurrent_cell_bias": grad_recurrent_cell_bias, + } + step = TransitionInputProjectionParamGradStep( + input_tensor=local_base_input, + weight=input_proj_weight_t, + bias=None, + grad_output=grad_output, + group_key=( + "factorized_input_projection_weight", + _tensor_group_identity(input_proj_weight_t), + ), + map_param_grads=map_param_grads, + ) + return grad_recurrent_msg, static_grads, {}, step + + return _ProjectionBackwardTape(output=output, backward=backward, backward_deferred=backward_deferred) - output = _receiver_major_linear(recurrent_msg, input_weight, input_bias) + output = ( + _receiver_major_linear(recurrent_msg, input_weight, input_bias) if output_override is None else output_override + ) def backward( grad_output: torch.Tensor | None, @@ -1611,7 +1282,49 @@ def backward( _accumulate_named_grad(static_grads, "recurrent_cell_bias", grad_bias) return grad_input, static_grads, materialized_grads - return _ProjectionBackwardTape(output=output, backward=backward) + def map_param_grads( + grad_weight: torch.Tensor, + grad_bias: torch.Tensor | None, + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + static_grads: dict[str, torch.Tensor] = {} + materialized_grads: dict[str, torch.Tensor] = {} + fused_recurrent_weight = static_tensors.get("fused_recurrent_value_to_cell_weight") + if torch.is_tensor(fused_recurrent_weight) and input_weight.data_ptr() == fused_recurrent_weight.data_ptr(): + return _unfuse_recurrent_input_projection_grads( + static_tensors=static_tensors, + grad_fused_weight=grad_weight, + grad_fused_bias=grad_bias, + ) + _accumulate_named_grad(static_grads, "value_to_cell_weight", grad_weight.sum(dim=0).transpose(0, 1)) + _accumulate_named_grad(static_grads, "recurrent_cell_bias", grad_bias) + return static_grads, materialized_grads + + def backward_deferred( + grad_output: torch.Tensor | None, + ) -> tuple[ + torch.Tensor | None, + dict[str, torch.Tensor], + dict[str, torch.Tensor], + TransitionInputProjectionParamGradStep | None, + ]: + grad_input = 
_receiver_major_linear_input_backward_no_grad(recurrent_msg, input_weight, grad_output) + step = None + if grad_output is not None: + step = TransitionInputProjectionParamGradStep( + input_tensor=recurrent_msg, + weight=input_weight, + bias=input_bias, + grad_output=grad_output, + group_key=( + "diagonal_recurrent_input_projection", + _tensor_group_identity(input_weight), + _tensor_group_identity(input_bias), + ), + map_param_grads=map_param_grads, + ) + return grad_input, {}, {}, step + + return _ProjectionBackwardTape(output=output, backward=backward, backward_deferred=backward_deferred) def _unfuse_recurrent_input_projection_grads( @@ -1619,6 +1332,7 @@ def _unfuse_recurrent_input_projection_grads( static_tensors: dict[str, object], grad_fused_weight: torch.Tensor, grad_fused_bias: torch.Tensor | None, + selected_static_source: str = "", ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: input_proj_weight_t = static_tensors.get("input_proj_weight_t") if not torch.is_tensor(input_proj_weight_t): @@ -1629,7 +1343,48 @@ def _unfuse_recurrent_input_projection_grads( input_proj_weight_t = only_params.get("input_proj_weight_t") if not torch.is_tensor(input_proj_weight_t): raise RuntimeError("Fabric CUDA fused recurrent input backward requires input projection weight tape") - value_to_cell_weight = cast(torch.Tensor, static_tensors["value_to_cell_weight"]) + _, hidden, _ = (int(dim) for dim in input_proj_weight_t.shape[:3]) + value_dim = int(grad_fused_weight.shape[1]) if grad_fused_weight.dim() >= 2 else -1 + + candidate_sources: list[str] = [] + if not selected_static_source: + selected_static_source = str(static_tensors.get("recurrent_message_to_cell_weight_source", "")) + if selected_static_source in {"message_to_cell_weight", "value_to_cell_weight"}: + candidate_sources.append(selected_static_source) + candidate_sources.extend(("value_to_cell_weight", "message_to_cell_weight")) + + if grad_fused_weight.dim() == 2 and selected_static_source == "message_to_cell_weight": + message_to_cell_weight = static_tensors.get("message_to_cell_weight") + if torch.is_tensor(message_to_cell_weight) and tuple(int(dim) for dim in grad_fused_weight.shape) == tuple( + int(dim) for dim in message_to_cell_weight.shape + ): + static_grads: dict[str, torch.Tensor] = {"message_to_cell_weight": grad_fused_weight} + if grad_fused_bias is not None: + _accumulate_named_grad(static_grads, "recurrent_cell_bias", grad_fused_bias) + return static_grads, {} + + base_source_name = "" + base_weight: torch.Tensor | None = None + for source_name in candidate_sources: + if source_name == base_source_name: + continue + value = static_tensors.get(source_name) + if ( + torch.is_tensor(value) + and value.dim() == 2 + and tuple(int(dim) for dim in value.shape) == (hidden, value_dim) + ): + base_source_name = source_name + base_weight = value + break + if base_weight is None: + value = static_tensors.get("value_to_cell_weight") + if torch.is_tensor(value): + base_source_name = "value_to_cell_weight" + base_weight = value + if base_weight is None: + raise RuntimeError("Fabric CUDA fused recurrent input backward requires a compiler-bound base weight") + recurrent_cell_bias = cast(torch.Tensor, static_tensors["recurrent_cell_bias"]) recurrent_cell_bias_2d = recurrent_cell_bias.squeeze(0) if recurrent_cell_bias.dim() == 3 else recurrent_cell_bias grad_fused_bias_2d = ( @@ -1641,248 +1396,20 @@ def _unfuse_recurrent_input_projection_grads( grad_value_to_cell_weight, grad_recurrent_cell_bias, grad_input_proj_weight = ( 
factorized_recurrent_input_projection_grads_cuda( input_proj_weight_t=input_proj_weight_t, - value_to_cell_weight=value_to_cell_weight, + value_to_cell_weight=base_weight, recurrent_cell_bias=recurrent_cell_bias_2d, grad_fused_weight=grad_fused_weight, grad_fused_bias=grad_fused_bias_2d, ) ) - static_grads: dict[str, torch.Tensor] = {"value_to_cell_weight": grad_value_to_cell_weight} + static_grads: dict[str, torch.Tensor] = {base_source_name: grad_value_to_cell_weight} if grad_recurrent_cell_bias is not None: _accumulate_named_grad( static_grads, "recurrent_cell_bias", grad_recurrent_cell_bias.unsqueeze(0) if grad_recurrent_cell_bias.dim() == 2 else grad_recurrent_cell_bias, ) - return static_grads, {"input_proj_weight": grad_input_proj_weight} - - -def _lower_gated_logspace_recurrence_transition( - runtime: Any, - *, - population_spec: CellBackendSpec, - population_params: dict[str, object], - packed_state_before: Mapping[str, torch.Tensor] | None, - population_reset_step: torch.Tensor | None, - recurrent_msg: torch.Tensor, - static_tensors: dict[str, object], - materialize_recurrent_kv: bool, - materialize_backward_tape: bool, - materialize_recurrence_backward_tape: bool, - materialize_next_state: bool, -) -> TransitionForwardResult: - ops = population_spec.transition_ir.ops - input_op = ops[0] - recurrent_op = ops[1] - core_op = ops[2] - public_op = ops[3] - batch_size = int(recurrent_msg.shape[0]) - num_receivers = int(recurrent_msg.shape[1]) - hidden_size = int(runtime.hidden_size) - - input_tape: _ProjectionBackwardTape | None = None - with torch.profiler.record_function("fabric.projection.population_input"): - if materialize_backward_tape: - input_tape = _gated_logspace_input_projection_tape(recurrent_msg, static_tensors=static_tensors) - population_input = input_tape.output - else: - population_input = _prepare_gated_logspace_population_input( - recurrent_msg, - static_tensors=static_tensors, - ) - gate_weight = _resolve_transition_parameter( - population_spec, - population_params, - static_tensors, - input_op.inputs[1], - num_receivers=num_receivers, - ) - gate_bias = _resolve_transition_parameter( - population_spec, - population_params, - static_tensors, - input_op.inputs[2], - num_receivers=num_receivers, - ) - gate_logits = _transition_linear(population_input, gate_weight, gate_bias, hidden_size=hidden_size) - - if packed_state_before is None: - y_prev = None - c_prev = None - n_prev = None - m_prev = None - else: - y_prev = packed_state_before[recurrent_op.inputs[0]] - c_prev = packed_state_before[core_op.inputs[2]] - n_prev = packed_state_before[core_op.inputs[3]] - m_prev = packed_state_before[core_op.inputs[4]] - if population_reset_step is not None and y_prev is not None: - reset_mask = torch.as_tensor(population_reset_step, device=recurrent_msg.device, dtype=torch.bool).view( - batch_size, - 1, - 1, - ) - reset_rows = reset_mask.view(batch_size) - y_prev, c_prev, n_prev, m_prev = reset_backend_tensors_rows_cuda( - (y_prev, c_prev, n_prev, m_prev), - reset_rows, - ) - - recurrent_kernel = _resolve_transition_parameter( - population_spec, - population_params, - static_tensors, - recurrent_op.inputs[1], - num_receivers=num_receivers, - ) - recurrent_gate_logits = None if y_prev is None else _gated_recurrent_matmul(y_prev, recurrent_kernel) - outnorm_weight = _receiver_hidden_parameter( - _resolve_transition_parameter( - population_spec, - population_params, - static_tensors, - public_op.inputs[1], - num_receivers=num_receivers, - ), - num_receivers=num_receivers, - 
hidden_size=hidden_size, - ) - eps = _scalar_parameter(population_params["outnorm_eps"]) - runtime._last_gated_logspace_epilogue_backward_mode = "triton_cuda" - if materialize_next_state: - public_y, next_y, next_c, next_n, next_m = gated_logspace_recurrence_outnorm_cuda( - gate_logits, - recurrent_gate_logits, - c_prev, - n_prev, - m_prev, - outnorm_weight, - eps=eps, - ) - else: - public_y, next_y, next_c, next_n, next_m = gated_logspace_recurrence_outnorm_forward_cuda( - gate_logits, - recurrent_gate_logits, - c_prev, - n_prev, - m_prev, - outnorm_weight, - eps=eps, - materialize_next_state=False, - ) - next_packed_state = None - if materialize_next_state: - if next_y is None or next_c is None or next_n is None or next_m is None: - raise RuntimeError("Fabric gated recurrence requested next state but CUDA kernel returned none.") - next_state_by_name = { - core_op.outputs[0].removeprefix("next_"): next_y, - core_op.outputs[1].removeprefix("next_"): next_c, - core_op.outputs[2].removeprefix("next_"): next_n, - core_op.outputs[3].removeprefix("next_"): next_m, - } - next_packed_state = TensorDict( - next_state_by_name, - batch_size=[batch_size, num_receivers], - device=recurrent_msg.device, - ) - recurrent_k: torch.Tensor | None = None - recurrent_v: torch.Tensor | None = None - if materialize_recurrent_kv: - with torch.profiler.record_function("fabric.projection.population_public_kv"): - recurrent_k, recurrent_v = runtime._project_sender_kv_from_cells_step( - public_y, - sender_input_to_kv_weight=cast( - torch.Tensor | None, static_tensors["recurrent_sender_input_to_kv_weight"] - ), - grouped_sender_input_to_kv_weight=cast( - torch.Tensor | None, - static_tensors["recurrent_group_input_to_kv_weight"], - ), - sender_group_size=runtime._recurrent_sender_kv_group_size, - ) - backward_tape = None - if materialize_backward_tape: - backward_tape = TransitionBackwardTape( - input_projection=input_tape, - gated_gate_logits=gate_logits if materialize_recurrence_backward_tape else None, - gated_recurrent_gate_logits=recurrent_gate_logits if materialize_recurrence_backward_tape else None, - ) - return TransitionForwardResult(next_packed_state, public_y, recurrent_k, recurrent_v, backward_tape) - - -def _is_diagonal_recurrence_transition(population_spec: CellBackendSpec) -> bool: - ops = population_spec.transition_ir.ops - if len(ops) != 3: - return False - input_op, recurrence_op, output_op = ops - return ( - input_op.name == "linear" - and recurrence_op.name == "diag_rtu" - and output_op.name == "linear" - and len(recurrence_op.inputs) >= 15 - and len(recurrence_op.outputs) == 11 - and tuple(recurrence_op.inputs[1:11]) == tuple(population_spec.transition_ir.state_inputs) - and tuple(recurrence_op.outputs[1:]) == tuple(population_spec.transition_ir.state_outputs) - and output_op.inputs[0] == recurrence_op.outputs[0] - ) - - -def _is_gated_logspace_recurrence_transition(population_spec: CellBackendSpec) -> bool: - ops = population_spec.transition_ir.ops - if len(ops) != 4: - return False - input_op, recurrent_op, core_op, public_op = ops - return ( - input_op.name == "linear" - and recurrent_op.name == "matmul" - and core_op.name == "gated_logspace_recurrence" - and public_op.name == "norm_or_identity" - and tuple(core_op.inputs[2:]) == tuple(population_spec.transition_ir.state_inputs[1:]) - and tuple(core_op.outputs) == tuple(population_spec.transition_ir.state_outputs) - and public_op.inputs[0] == core_op.outputs[0] - ) - - -def _resolve_transition_parameter( - population_spec: CellBackendSpec, - 
population_params: dict[str, object], - static_tensors: dict[str, object], - name: str, - *, - num_receivers: int, -) -> torch.Tensor: - bindings = population_spec.transition_parameter_bindings.get(name, (TransitionParameterBinding(source=name),)) - for binding in bindings: - tensor = _resolve_bound_transition_parameter( - population_params, - static_tensors, - binding, - num_receivers=num_receivers, - ) - if tensor is not None: - return tensor - raise RuntimeError(f"Fabric CUDA diagonal recurrence could not resolve parameter {name}") - - -def _resolve_bound_transition_parameter( - population_params: dict[str, object], - static_tensors: dict[str, object], - binding: TransitionParameterBinding, - *, - num_receivers: int, -) -> torch.Tensor | None: - if binding.kind == "cell_param": - tensor = population_params.get(binding.source) - return tensor if torch.is_tensor(tensor) else None - if binding.kind == "static_tensor": - tensor = static_tensors.get(binding.source) - return tensor if torch.is_tensor(tensor) else None - if binding.kind == "expanded_transposed_static_tensor": - tensor = static_tensors.get(binding.source) - if torch.is_tensor(tensor): - return tensor.transpose(0, 1).unsqueeze(0).expand(num_receivers, -1, -1) - return None - raise RuntimeError(f"Unsupported transition parameter binding kind {binding.kind}") + return static_grads, {"input_proj_weight": grad_input_proj_weight.transpose(1, 2).contiguous()} def _receiver_major_linear(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None) -> torch.Tensor: @@ -2067,27 +1594,17 @@ def _gated_recurrent_matmul(y_prev: torch.Tensor, recurrent_kernel: torch.Tensor ) -def _receiver_hidden_parameter(param: torch.Tensor, *, num_receivers: int, hidden_size: int) -> torch.Tensor: - return param.reshape(num_receivers, hidden_size) - - -def _scalar_parameter(value: object) -> float: - if torch.is_tensor(value): - return float(value.reshape(-1)[0].item()) - return float(value) - - -def _activation_id(value: object) -> int: - if torch.is_tensor(value) and value.numel() > 0: - return int(value.reshape(-1)[0].item()) - return 3 - - __all__ = [ - "TransitionBackwardTape", - "TransitionBackwardResult", - "TransitionForwardResult", - "lower_backend_population_transition_backward_shared", - "lower_backend_population_transition_forward_result_shared", - "lower_backend_population_transition_shared", + "_HeadGroupedGateLinearFunction", + "_RecurrentMatmulFunction", + "_SharedReceiverBiasLinearFunction", + "_diagonal_recurrence_input_projection", + "_diagonal_recurrence_input_projection_tape", + "_gated_logspace_input_projection_tape", + "_gated_recurrent_matmul_no_grad", + "_receiver_major_linear", + "_transition_linear_no_grad", + "_unfuse_recurrent_input_projection_grads", + "factorized_recurrent_input_prepack", + "reduce_transition_input_projection_param_grad_steps", ] diff --git a/src/cortical/fabric/backend/cuda/transition_execution/registry.py b/src/cortical/fabric/backend/cuda/transition_execution/registry.py new file mode 100644 index 00000000..d2b771d8 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/transition_execution/registry.py @@ -0,0 +1,956 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + + +TransitionProgramExecutorKind = str +TransitionPrimitiveExecutorStatus = Literal["cuda", "reference_only", "blocked"] +TransitionPrimitiveProgramLayerStatus = Literal["callable", "blocked"] + + +@dataclass(frozen=True) +class TransitionProgramTensorEdge: + producer_op_index: int + 
producer_output_index: int + consumer_op_index: int + consumer_input_index: int + + +@dataclass(frozen=True) +class TransitionProgramStateSlice: + op_index: int + op_field: Literal["inputs", "outputs"] + op_start: int + op_stop: int | None + state_field: Literal["state_inputs", "state_outputs"] + state_start: int + state_stop: int | None + + +@dataclass(frozen=True) +class TransitionProgramOpArity: + op_index: int + min_inputs: int = 0 + output_count: int | None = None + + +@dataclass(frozen=True) +class TransitionProgramMessageInput: + op_index: int + input_index: int + + +@dataclass(frozen=True) +class TransitionProgramExecutorRecord: + registry_id: str + executor: TransitionProgramExecutorKind + primitive_names: tuple[str, ...] + forward_strategy_id: str + backward_strategy_id: str + state_slices: tuple[TransitionProgramStateSlice, ...] + tensor_edges: tuple[TransitionProgramTensorEdge, ...] + arities: tuple[TransitionProgramOpArity, ...] = () + message_inputs: tuple[TransitionProgramMessageInput, ...] = () + + @property + def review_summary(self) -> str: + return ( + f"registry_id={self.registry_id}" + f",executor={self.executor}" + f",primitives={'+'.join(self.primitive_names)}" + f",forward_strategy_id={self.forward_strategy_id}" + f",backward_strategy_id={self.backward_strategy_id}" + f",state_slices={len(self.state_slices)}" + f",tensor_edges={len(self.tensor_edges)}" + f",arities={len(self.arities)}" + f",message_inputs={len(self.message_inputs)}" + ) + + +@dataclass(frozen=True) +class TransitionPrimitiveExecutorRecord: + primitive: str + status: TransitionPrimitiveExecutorStatus + forward_executor: str + backward_executor: str + tensor_role_contract: tuple[str, ...] + tape_policy: str + reference_executor: str + aliases: tuple[str, ...] = () + cuda_executor: str | None = None + blocker_code: str = "" + program_layer_status: TransitionPrimitiveProgramLayerStatus = "blocked" + program_forward_status: TransitionPrimitiveProgramLayerStatus = "blocked" + program_backward_status: TransitionPrimitiveProgramLayerStatus = "blocked" + program_forward_symbol: str = "" + program_backward_symbol: str = "" + program_forward_cxx_entrypoint: str = "" + program_forward_input_bindings: tuple[str, ...] = () + program_forward_parameter_bindings: tuple[tuple[str, bool], ...] = () + program_forward_output_bindings: tuple[tuple[str, bool], ...] = () + program_forward_output_contracts: tuple[tuple[str, str, str, str], ...] = () + program_reverse_native_callable: str = "" + program_layer_blocker_code: str = "MISSING_FUSED_TRANSITION_PRIMITIVE_EXECUTOR_BINDING" + param_grad_outputs: tuple[tuple[str, str, str], ...] = () + reverse_input_bindings: tuple[str, ...] = () + parameter_bindings: tuple[str, ...] = () + reverse_output_bindings: tuple[str, ...] = () + tape_saved_input_bindings: tuple[str, ...] = () + tape_saved_output_bindings: tuple[str, ...] = () + tape_recompute_input_bindings: tuple[str, ...] = () + tape_recompute_output_bindings: tuple[str, ...] 
= () + + @property + def review_summary(self) -> str: + cuda_executor = "-" if self.cuda_executor is None else self.cuda_executor + blocker = "-" if not self.blocker_code else self.blocker_code + program_forward = "-" if not self.program_forward_symbol else self.program_forward_symbol + program_backward = "-" if not self.program_backward_symbol else self.program_backward_symbol + program_blocker = "-" if not self.program_layer_blocker_code else self.program_layer_blocker_code + return ( + f"primitive={self.primitive}" + f",aliases={'+'.join(self.aliases) or '-'}" + f",status={self.status}" + f",forward_executor={self.forward_executor}" + f",backward_executor={self.backward_executor}" + f",roles={'+'.join(self.tensor_role_contract)}" + f",tape_policy={self.tape_policy}" + f",reference_executor={self.reference_executor}" + f",cuda_executor={cuda_executor}" + f",blocker_code={blocker}" + f",program_layer_status={self.program_layer_status}" + f",program_forward_status={self.program_forward_status}" + f",program_backward_status={self.program_backward_status}" + f",program_forward_symbol={program_forward}" + f",program_backward_symbol={program_backward}" + f",program_forward_cxx_entrypoint={self.program_forward_cxx_entrypoint or '-'}" + f",program_forward_input_bindings={len(self.program_forward_input_bindings)}" + f",program_forward_parameter_bindings={len(self.program_forward_parameter_bindings)}" + f",program_forward_output_bindings={len(self.program_forward_output_bindings)}" + f",program_forward_output_contracts={len(self.program_forward_output_contracts)}" + f",program_reverse_native_callable={self.program_reverse_native_callable or '-'}" + f",program_layer_blocker_code={program_blocker}" + f",param_grad_outputs={len(self.param_grad_outputs)}" + f",reverse_input_bindings={len(self.reverse_input_bindings)}" + f",parameter_bindings={len(self.parameter_bindings)}" + f",reverse_output_bindings={len(self.reverse_output_bindings)}" + f",tape_saved_inputs={len(self.tape_saved_input_bindings)}" + f",tape_saved_outputs={len(self.tape_saved_output_bindings)}" + f",tape_recompute_inputs={len(self.tape_recompute_input_bindings)}" + f",tape_recompute_outputs={len(self.tape_recompute_output_bindings)}" + ) + + +@dataclass(frozen=True) +class TransitionExecutorPrimitivePatternSpec: + primitive: str + parameter_inputs: tuple[str, ...] = () + + +@dataclass(frozen=True) +class TransitionExecutorProgramAccessSpec: + access_name: str + logical_name: str + binding_kind: str = "parameter" + required: bool = True + access_opcode: int = 0 + + +@dataclass(frozen=True) +class TransitionExecutorStrategySpec: + direction: Literal["forward", "reverse"] + executor_id: int + executor_name: str + row_pattern: tuple[TransitionExecutorPrimitivePatternSpec, ...] + implementation_contract: str + strategy_id: str + native_callable: str + cxx_entrypoints: tuple[str, ...] = () + strategy_version: int = 1 + required_effects: tuple[str, ...] = () + match_effects: tuple[str, ...] = () + program_accesses: tuple[TransitionExecutorProgramAccessSpec, ...] = () + state_carry_rules: tuple[tuple[str, str], ...] 
= () + + +def registered_transition_executor_records() -> tuple[TransitionProgramExecutorRecord, ...]: + return _REGISTERED_TRANSITION_EXECUTOR_RECORDS + + +def registered_transition_primitive_executor_records() -> tuple[TransitionPrimitiveExecutorRecord, ...]: + return _REGISTERED_TRANSITION_PRIMITIVE_EXECUTOR_RECORDS + + +def registered_transition_forward_strategy_specs() -> tuple[TransitionExecutorStrategySpec, ...]: + return _REGISTERED_TRANSITION_FORWARD_STRATEGY_SPECS + + +def registered_transition_reverse_strategy_specs() -> tuple[TransitionExecutorStrategySpec, ...]: + return _REGISTERED_TRANSITION_REVERSE_STRATEGY_SPECS + + +def transition_primitive_executor_record(primitive: str) -> TransitionPrimitiveExecutorRecord | None: + primitive = str(primitive) + for record in registered_transition_primitive_executor_records(): + if record.primitive == primitive: + return record + return None + + +def transition_primitive_executor_record_for_lowered_primitive( + primitive: str, +) -> TransitionPrimitiveExecutorRecord | None: + lowered_primitive = str(primitive) + for record in registered_transition_primitive_executor_records(): + if lowered_primitive == record.primitive or lowered_primitive in record.aliases: + return record + return None + + +def transition_primitive_program_contract_blocker_code(record: TransitionPrimitiveExecutorRecord) -> str: + if record.status != "cuda": + return record.blocker_code or "MISSING_CUDA_TRANSITION_PRIMITIVE_EXECUTOR" + if record.program_layer_status != "callable": + return record.program_layer_blocker_code or "MISSING_FUSED_TRANSITION_PRIMITIVE_EXECUTOR_BINDING" + if record.program_forward_status != "callable": + return record.program_layer_blocker_code or "MISSING_FUSED_TRANSITION_PRIMITIVE_FORWARD_EXECUTOR" + if record.program_backward_status != "callable": + return record.program_layer_blocker_code or "MISSING_FUSED_TRANSITION_PRIMITIVE_REVERSE_EXECUTOR" + if ( + not record.program_forward_symbol + or not record.program_forward_cxx_entrypoint + or not record.program_forward_input_bindings + or not record.program_forward_output_bindings + or not record.program_forward_output_contracts + ): + return "INCOMPLETE_FUSED_TRANSITION_PRIMITIVE_FORWARD_CONTRACT" + if ( + not record.program_backward_symbol + or not record.program_reverse_native_callable + or not record.reverse_input_bindings + or not record.reverse_output_bindings + ): + return "INCOMPLETE_FUSED_TRANSITION_PRIMITIVE_REVERSE_CONTRACT" + if not ( + record.tape_saved_input_bindings + or record.tape_saved_output_bindings + or record.tape_recompute_input_bindings + or record.tape_recompute_output_bindings + ): + return "INCOMPLETE_FUSED_TRANSITION_PRIMITIVE_TAPE_CONTRACT" + forward_inputs = tuple(str(item) for item in record.program_forward_input_bindings) + forward_parameters = tuple(str(name) for name, _required in record.program_forward_parameter_bindings) + forward_outputs = tuple(str(name) for name, _required in record.program_forward_output_bindings) + if _has_duplicates(forward_inputs) or _has_duplicates(forward_parameters) or _has_duplicates(forward_outputs): + return "INVALID_FUSED_TRANSITION_PRIMITIVE_FORWARD_BINDING_CONTRACT" + output_contract_names = tuple( + str(output_name) for output_name, _role, _shape, _index_source in record.program_forward_output_contracts + ) + if _has_duplicates(output_contract_names) or any( + str(shape_kind) not in {"hidden", "gate_logits", "diagonal_preproj"} + or str(index_source) not in {"primitive_row", "binding_index"} + for _output_name, _role, 
shape_kind, index_source in record.program_forward_output_contracts + ): + return "INVALID_FUSED_TRANSITION_PRIMITIVE_OUTPUT_CONTRACT" + if _first_missing( + (*record.tape_saved_input_bindings, *record.tape_recompute_input_bindings), + forward_inputs, + ): + return "INVALID_FUSED_TRANSITION_PRIMITIVE_TAPE_INPUT_CONTRACT" + if _first_missing( + (*record.tape_saved_output_bindings, *record.tape_recompute_output_bindings), + forward_outputs, + ): + return "INVALID_FUSED_TRANSITION_PRIMITIVE_TAPE_OUTPUT_CONTRACT" + reverse_inputs = tuple(str(item) for item in record.reverse_input_bindings) + reverse_parameters = tuple(str(item) for item in record.parameter_bindings) + reverse_outputs = tuple(str(item) for item in record.reverse_output_bindings) + if _has_duplicates(reverse_inputs) or _has_duplicates(reverse_parameters) or _has_duplicates(reverse_outputs): + return "INVALID_FUSED_TRANSITION_PRIMITIVE_REVERSE_BINDING_CONTRACT" + if _first_missing( + tuple(str(parameter_name) for _grad_name, parameter_name, _kind in record.param_grad_outputs), + reverse_parameters, + ): + return "INVALID_FUSED_TRANSITION_PRIMITIVE_PARAM_GRAD_CONTRACT" + return "" + + +def _has_duplicates(names: tuple[str, ...]) -> bool: + return len(set(names)) != len(names) + + +def _first_missing(names: tuple[str, ...], allowed_names: tuple[str, ...]) -> str: + allowed = set(allowed_names) + for name in names: + if str(name) not in allowed: + return str(name) + return "" + + +def transition_program_layer_blocker_codes(primitive_names: tuple[str, ...]) -> tuple[str, ...]: + blocker_codes: list[str] = [] + for primitive in dict.fromkeys(str(name) for name in primitive_names): + record = transition_primitive_executor_record_for_lowered_primitive(primitive) + if record is None: + blocker_codes.append("UNREGISTERED_TRANSITION_PRIMITIVE") + continue + blocker_code = transition_primitive_program_contract_blocker_code(record) + if blocker_code: + blocker_codes.append(blocker_code) + return tuple(dict.fromkeys(blocker_codes)) + + +def transition_program_layer_missing_symbols(primitive_names: tuple[str, ...]) -> tuple[str, ...]: + missing_symbols: list[str] = [] + for primitive in dict.fromkeys(str(name) for name in primitive_names): + record = transition_primitive_executor_record_for_lowered_primitive(primitive) + if record is None or record.program_layer_status == "callable": + continue + if record.program_forward_status != "callable" and record.program_forward_symbol: + missing_symbols.append(record.program_forward_symbol) + if record.program_backward_status != "callable" and record.program_backward_symbol: + missing_symbols.append(record.program_backward_symbol) + return tuple(dict.fromkeys(missing_symbols)) + + +_REGISTERED_TRANSITION_EXECUTOR_RECORDS = ( + TransitionProgramExecutorRecord( + registry_id="transition_executor:gated_logspace_recurrence:v1", + executor="gated_logspace_recurrence", + primitive_names=("linear", "linear", "matmul", "gated_logspace_recurrence", "norm_or_identity"), + forward_strategy_id="forward.transition.gated_logspace.v1", + backward_strategy_id="reverse.transition.gated_logspace.v1", + state_slices=( + TransitionProgramStateSlice( + op_index=3, + op_field="inputs", + op_start=2, + op_stop=None, + state_field="state_inputs", + state_start=1, + state_stop=None, + ), + TransitionProgramStateSlice( + op_index=3, + op_field="outputs", + op_start=0, + op_stop=None, + state_field="state_outputs", + state_start=0, + state_stop=None, + ), + ), + tensor_edges=( + TransitionProgramTensorEdge(0, 0, 1, 0), + 
TransitionProgramTensorEdge(3, 0, 4, 0), + ), + message_inputs=(TransitionProgramMessageInput(0, 0),), + ), + TransitionProgramExecutorRecord( + registry_id="transition_executor:diagonal_rtu:v1", + executor="diagonal_rtu", + primitive_names=("linear", "diag_rtu", "linear", "norm_or_identity"), + forward_strategy_id="forward.transition.diag_rtu.v1", + backward_strategy_id="reverse.transition.diag_rtu.v1", + state_slices=( + TransitionProgramStateSlice( + op_index=1, + op_field="inputs", + op_start=1, + op_stop=11, + state_field="state_inputs", + state_start=0, + state_stop=None, + ), + TransitionProgramStateSlice( + op_index=1, + op_field="outputs", + op_start=1, + op_stop=None, + state_field="state_outputs", + state_start=0, + state_stop=None, + ), + ), + tensor_edges=( + TransitionProgramTensorEdge(0, 0, 1, 0), + TransitionProgramTensorEdge(1, 0, 2, 0), + TransitionProgramTensorEdge(2, 0, 3, 0), + ), + arities=(TransitionProgramOpArity(1, min_inputs=15, output_count=11),), + message_inputs=(TransitionProgramMessageInput(0, 0),), + ), +) + + +_REGISTERED_TRANSITION_PRIMITIVE_EXECUTOR_RECORDS = ( + TransitionPrimitiveExecutorRecord( + primitive="linear", + status="cuda", + forward_executor="transition_linear_forward", + backward_executor="transition_linear_backward", + tensor_role_contract=("input", "weight", "bias", "output"), + tape_policy="input_projection_tape_or_recompute", + reference_executor="torch_linear_reference", + cuda_executor="receiver_major_affine_cuda", + program_layer_status="callable", + program_forward_status="callable", + program_backward_status="callable", + program_forward_symbol="program_transition_linear_forward", + program_backward_symbol="program_transition_linear_backward", + program_forward_cxx_entrypoint="run_registered_transition_linear_forward_primitive", + program_forward_input_bindings=("input",), + program_forward_parameter_bindings=(("weight", True), ("bias", False)), + program_forward_output_bindings=(("output", True),), + program_forward_output_contracts=( + ("transition_input", "transition_forward_linear_output", "hidden", "primitive_row"), + ("gate_logits", "transition_forward_linear_output", "gate_logits", "primitive_row"), + ("cell_input", "transition_forward_linear_output", "hidden", "primitive_row"), + ("public_y", "transition_forward_linear_output", "hidden", "primitive_row"), + ), + program_reverse_native_callable="native.reverse.transition_linear_primitive.v1", + program_layer_blocker_code="", + param_grad_outputs=( + ("grad_weight", "weight", "materialized"), + ("grad_bias", "bias", "materialized"), + ), + reverse_input_bindings=("input", "grad_output"), + parameter_bindings=("weight", "bias"), + reverse_output_bindings=("grad_input", "grad_weight", "grad_bias"), + tape_saved_input_bindings=("input",), + tape_recompute_input_bindings=("input",), + tape_recompute_output_bindings=("output",), + ), + TransitionPrimitiveExecutorRecord( + primitive="matmul", + status="cuda", + forward_executor="transition_recurrent_matmul_forward", + backward_executor="transition_recurrent_matmul_backward", + tensor_role_contract=("input", "weight", "output"), + tape_policy="recompute_or_full_tape", + reference_executor="torch_matmul_reference", + cuda_executor="gated_recurrent_matmul_cuda", + program_layer_status="callable", + program_forward_status="callable", + program_backward_status="callable", + program_forward_symbol="program_transition_recurrent_matmul_forward", + program_backward_symbol="program_transition_recurrent_matmul_backward", + 
program_forward_cxx_entrypoint="run_registered_transition_matmul_forward_primitive", + program_forward_input_bindings=("input",), + program_forward_parameter_bindings=(("weight", True),), + program_forward_output_bindings=(("output", True),), + program_forward_output_contracts=( + ("recurrent_gate_logits", "transition_forward_matmul_output", "gate_logits", "primitive_row"), + ), + program_reverse_native_callable="native.reverse.transition_matmul_primitive.v1", + program_layer_blocker_code="", + param_grad_outputs=(("grad_weight", "weight", "materialized"),), + reverse_input_bindings=("input", "grad_output"), + parameter_bindings=("weight",), + reverse_output_bindings=("grad_input", "grad_weight"), + tape_saved_input_bindings=("input",), + tape_recompute_input_bindings=("input",), + tape_recompute_output_bindings=("output",), + ), + TransitionPrimitiveExecutorRecord( + primitive="gated_logspace_recurrence", + status="cuda", + forward_executor="transition_gated_logspace_recurrence_forward", + backward_executor="transition_gated_logspace_recurrence_backward", + tensor_role_contract=("gate_logits", "recurrent_gate_logits", "state", "next_state"), + tape_policy="full_gate_logits_or_recompute", + reference_executor="torch_gated_logspace_recurrence_reference", + cuda_executor="gated_logspace_recurrence_outnorm_cuda", + program_layer_status="callable", + program_forward_status="callable", + program_backward_status="callable", + program_forward_symbol="program_transition_gated_logspace_recurrence_forward", + program_backward_symbol="program_transition_gated_logspace_recurrence_backward", + program_forward_cxx_entrypoint="run_registered_transition_gated_logspace_forward_primitive", + program_forward_input_bindings=( + "gate_logits", + "recurrent_gate_logits", + "c_prev", + "n_prev", + "m_prev", + ), + program_forward_output_bindings=( + ("next_y", True), + ("next_c", False), + ("next_n", False), + ("next_m", False), + ), + program_forward_output_contracts=( + ("next_y", "transition_forward_state_output", "hidden", "binding_index"), + ("next_c", "transition_forward_state_output", "hidden", "binding_index"), + ("next_n", "transition_forward_state_output", "hidden", "binding_index"), + ("next_m", "transition_forward_state_output", "hidden", "binding_index"), + ), + program_reverse_native_callable="native.reverse.transition_gated_logspace.v1", + program_layer_blocker_code="", + param_grad_outputs=( + ("grad_value_to_state_weight", "value_to_state_weight", "input_projection_weight"), + ("grad_recurrent_bias", "recurrent_bias", "input_projection_bias"), + ("grad_gate_weight", "gate_weight", "materialized"), + ("grad_bias", "bias", "materialized"), + ("grad_recurrent_kernel", "recurrent_kernel", "materialized"), + ("grad_outnorm_weight", "outnorm_weight", "materialized"), + ), + reverse_input_bindings=( + "aggregated_message", + "transition_input", + "gate_logits", + "recurrent_gate_logits", + "y", + "c", + "n", + "m", + "next_y", + "grad_public_y", + "grad_next_y", + "grad_next_c", + "grad_next_n", + "grad_next_m", + ), + parameter_bindings=( + "value_to_state_weight", + "recurrent_bias", + "gate_weight", + "bias", + "recurrent_kernel", + "outnorm_weight", + "outnorm_eps", + ), + reverse_output_bindings=( + "grad_aggregated_message", + "grad_y", + "grad_c", + "grad_n", + "grad_m", + "grad_value_to_state_weight", + "grad_recurrent_bias", + "grad_gate_weight", + "grad_bias", + "grad_recurrent_kernel", + "grad_outnorm_weight", + ), + tape_saved_input_bindings=("gate_logits", "recurrent_gate_logits"), + 
tape_saved_output_bindings=("next_y",), + tape_recompute_input_bindings=("gate_logits", "recurrent_gate_logits", "c_prev", "n_prev", "m_prev"), + tape_recompute_output_bindings=("next_y", "next_c", "next_n", "next_m"), + ), + TransitionPrimitiveExecutorRecord( + primitive="norm_or_identity", + status="cuda", + forward_executor="transition_public_norm_or_identity_forward", + backward_executor="transition_public_norm_or_identity_backward", + tensor_role_contract=("input", "weight", "output"), + tape_policy="recompute_or_full_tape", + reference_executor="torch_norm_or_identity_reference", + cuda_executor="gated_logspace_recurrence_outnorm_cuda", + program_layer_status="callable", + program_forward_status="callable", + program_backward_status="callable", + program_forward_symbol="program_transition_norm_or_identity_forward", + program_backward_symbol="program_transition_norm_or_identity_backward", + program_forward_cxx_entrypoint="run_registered_transition_norm_or_identity_forward_primitive", + program_forward_input_bindings=("input",), + program_forward_parameter_bindings=(("weight", True), ("eps", False)), + program_forward_output_bindings=(("output", True),), + program_forward_output_contracts=(("public_y", "transition_forward_norm_output", "hidden", "primitive_row"),), + program_reverse_native_callable="native.reverse.transition_norm_or_identity_primitive.v1", + program_layer_blocker_code="", + param_grad_outputs=(("grad_weight", "weight", "materialized"),), + reverse_input_bindings=("input", "grad_output"), + parameter_bindings=("weight", "eps"), + reverse_output_bindings=("grad_input", "grad_weight"), + tape_saved_input_bindings=("input",), + tape_recompute_input_bindings=("input",), + tape_recompute_output_bindings=("output",), + ), + TransitionPrimitiveExecutorRecord( + primitive="diag_rtu", + aliases=("diagonal_recurrence",), + status="cuda", + forward_executor="transition_diag_rtu_forward", + backward_executor="transition_diag_rtu_backward", + tensor_role_contract=("cell_input", "state", "params", "preproj", "next_state"), + tape_policy="diagonal_preproj_tape_or_recompute", + reference_executor="torch_diag_rtu_reference", + cuda_executor="diagonal_recurrence_forward_cuda", + program_layer_status="callable", + program_forward_status="callable", + program_backward_status="callable", + program_forward_symbol="program_transition_diag_rtu_forward", + program_backward_symbol="program_transition_diag_rtu_backward", + program_forward_cxx_entrypoint="run_registered_transition_diag_rtu_forward_primitive", + program_forward_input_bindings=( + "cell_input", + "hc1", + "hc2", + "E_nu_c1", + "E_nu_c2", + "E_th_c1", + "E_th_c2", + "E_w1_c1", + "E_w1_c2", + "E_w2_c1", + "E_w2_c2", + ), + program_forward_parameter_bindings=( + ("nu_log", True), + ("theta_log", True), + ("w1", True), + ("w2", True), + ("activation_id", False), + ), + program_forward_output_bindings=( + ("preproj", True), + ("next_hc1", False), + ("next_hc2", False), + ("next_E_nu_c1", False), + ("next_E_nu_c2", False), + ("next_E_th_c1", False), + ("next_E_th_c2", False), + ("next_E_w1_c1", False), + ("next_E_w1_c2", False), + ("next_E_w2_c1", False), + ("next_E_w2_c2", False), + ), + program_forward_output_contracts=( + ("preproj", "transition_forward_diag_output", "diagonal_preproj", "binding_index"), + ("next_hc1", "transition_forward_diag_output", "hidden", "binding_index"), + ("next_hc2", "transition_forward_diag_output", "hidden", "binding_index"), + ("next_E_nu_c1", "transition_forward_diag_output", "hidden", "binding_index"), 
+ ("next_E_nu_c2", "transition_forward_diag_output", "hidden", "binding_index"), + ("next_E_th_c1", "transition_forward_diag_output", "hidden", "binding_index"), + ("next_E_th_c2", "transition_forward_diag_output", "hidden", "binding_index"), + ("next_E_w1_c1", "transition_forward_diag_output", "hidden", "binding_index"), + ("next_E_w1_c2", "transition_forward_diag_output", "hidden", "binding_index"), + ("next_E_w2_c1", "transition_forward_diag_output", "hidden", "binding_index"), + ("next_E_w2_c2", "transition_forward_diag_output", "hidden", "binding_index"), + ), + program_reverse_native_callable="native.reverse.transition_diag_rtu.v1", + program_layer_blocker_code="", + param_grad_outputs=( + ("grad_input_proj_weight", "input_proj_weight", "input_projection_weight"), + ("grad_recurrent_cell_bias", "recurrent_cell_bias", "input_projection_bias"), + ("grad_nu_log", "nu_log", "materialized"), + ("grad_theta_log", "theta_log", "materialized"), + ("grad_w1", "w1", "materialized"), + ("grad_w2", "w2", "materialized"), + ("grad_out_proj_weight", "out_proj_weight", "materialized"), + ("grad_out_proj_bias", "out_proj_bias", "materialized"), + ("grad_outnorm_weight", "outnorm_weight", "materialized"), + ), + reverse_input_bindings=( + "aggregated_message", + "cell_input", + "hc1", + "hc2", + "preproj", + "public_y_raw", + "grad_public_y", + "grad_next_hc1", + "grad_next_hc2", + ), + parameter_bindings=( + "input_proj_weight", + "recurrent_cell_bias", + "nu_log", + "theta_log", + "w1", + "w2", + "out_proj_weight", + "out_proj_bias", + "activation_id", + "outnorm_weight", + "outnorm_eps", + ), + reverse_output_bindings=( + "grad_aggregated_message", + "grad_hc1", + "grad_hc2", + "grad_input_proj_weight", + "grad_recurrent_cell_bias", + "grad_nu_log", + "grad_theta_log", + "grad_w1", + "grad_w2", + "grad_out_proj_weight", + "grad_out_proj_bias", + "grad_outnorm_weight", + ), + tape_saved_input_bindings=("cell_input", "hc1", "hc2"), + tape_saved_output_bindings=("preproj",), + tape_recompute_input_bindings=( + "cell_input", + "hc1", + "hc2", + "E_nu_c1", + "E_nu_c2", + "E_th_c1", + "E_th_c2", + "E_w1_c1", + "E_w1_c2", + "E_w2_c1", + "E_w2_c2", + ), + tape_recompute_output_bindings=( + "preproj", + "next_hc1", + "next_hc2", + "next_E_nu_c1", + "next_E_nu_c2", + "next_E_th_c1", + "next_E_th_c2", + "next_E_w1_c1", + "next_E_w1_c2", + "next_E_w2_c1", + "next_E_w2_c2", + ), + ), + TransitionPrimitiveExecutorRecord( + primitive="tanh", + status="cuda", + forward_executor="transition_tanh_forward", + backward_executor="transition_tanh_backward", + tensor_role_contract=("input", "output"), + tape_policy="input_tape_or_recompute", + reference_executor="torch_tanh_reference", + cuda_executor="elementwise_tanh_cuda", + program_layer_status="callable", + program_forward_status="callable", + program_backward_status="callable", + program_forward_symbol="program_transition_tanh_forward", + program_backward_symbol="program_transition_tanh_backward", + program_forward_cxx_entrypoint="run_registered_transition_tanh_forward_primitive", + program_forward_input_bindings=("input",), + program_forward_output_bindings=(("output", True),), + program_forward_output_contracts=(("output", "transition_forward_unary_output", "hidden", "primitive_row"),), + program_reverse_native_callable="native.reverse.transition_tanh.v1", + program_layer_blocker_code="", + reverse_input_bindings=("output", "grad_output"), + reverse_output_bindings=("grad_input",), + tape_saved_output_bindings=("output",), + 
tape_recompute_input_bindings=("input",), + tape_recompute_output_bindings=("output",), + ), +) + + +_TRANSITION_AGGREGATE_PROGRAM_ACCESSES = ( + TransitionExecutorProgramAccessSpec( + "transition_aggregated_message_input", + "aggregated_message", + "input", + access_opcode=8, + ), + TransitionExecutorProgramAccessSpec( + "transition_public_state_output", + "public_y", + "output", + access_opcode=9, + ), +) + +_TRANSITION_FORWARD_EFFECTS = ("state_read", "message_read", "state_write", "tape_policy") +_TRANSITION_REVERSE_EFFECTS = ("tape_read", "state_grad_emit", "parameter_grad_emit") + +_REGISTERED_TRANSITION_FORWARD_STRATEGY_SPECS = ( + TransitionExecutorStrategySpec( + direction="forward", + executor_id=3, + executor_name="gated_logspace_transition", + row_pattern=( + TransitionExecutorPrimitivePatternSpec("linear", ("value_to_state_weight", "recurrent_bias")), + TransitionExecutorPrimitivePatternSpec("linear", ("gate_weight", "bias")), + TransitionExecutorPrimitivePatternSpec("matmul", ("recurrent_kernel",)), + TransitionExecutorPrimitivePatternSpec("gated_logspace_recurrence"), + TransitionExecutorPrimitivePatternSpec("norm_or_identity", ("outnorm_weight", "outnorm_eps")), + ), + implementation_contract="registered_gated_logspace_transition_executor_binding_rows", + strategy_id="forward.transition.gated_logspace.v1", + native_callable="native.forward.transition_gated_logspace.v1", + required_effects=_TRANSITION_FORWARD_EFFECTS, + match_effects=_TRANSITION_FORWARD_EFFECTS, + program_accesses=_TRANSITION_AGGREGATE_PROGRAM_ACCESSES, + state_carry_rules=(("y", "next_y"), ("c", "next_c"), ("n", "next_n"), ("m", "next_m")), + ), + TransitionExecutorStrategySpec( + direction="forward", + executor_id=4, + executor_name="diag_rtu_transition", + row_pattern=( + TransitionExecutorPrimitivePatternSpec("linear", ("input_proj_weight", "recurrent_cell_bias")), + TransitionExecutorPrimitivePatternSpec("diag_rtu", ("nu_log", "theta_log", "w1", "w2", "activation_id")), + TransitionExecutorPrimitivePatternSpec("linear", ("out_proj_weight", "out_proj_bias")), + TransitionExecutorPrimitivePatternSpec("norm_or_identity", ("outnorm_weight", "outnorm_eps")), + ), + implementation_contract="registered_diag_rtu_transition_executor_binding_rows", + strategy_id="forward.transition.diag_rtu.v1", + native_callable="native.forward.transition_diag_rtu.v1", + required_effects=_TRANSITION_FORWARD_EFFECTS, + match_effects=_TRANSITION_FORWARD_EFFECTS, + program_accesses=_TRANSITION_AGGREGATE_PROGRAM_ACCESSES, + state_carry_rules=( + ("hc1", "next_hc1"), + ("hc2", "next_hc2"), + ("E_nu_c1", "next_E_nu_c1"), + ("E_nu_c2", "next_E_nu_c2"), + ("E_th_c1", "next_E_th_c1"), + ("E_th_c2", "next_E_th_c2"), + ("E_w1_c1", "next_E_w1_c1"), + ("E_w1_c2", "next_E_w1_c2"), + ("E_w2_c1", "next_E_w2_c1"), + ("E_w2_c2", "next_E_w2_c2"), + ), + ), + TransitionExecutorStrategySpec( + direction="forward", + executor_id=7, + executor_name="transition_linear_primitive", + row_pattern=(TransitionExecutorPrimitivePatternSpec("linear", ("*",)),), + implementation_contract="registered_transition_linear_primitive_executor_binding_rows", + strategy_id="forward.transition.linear_primitive.v1", + native_callable="native.forward.transition_linear_primitive.v1", + required_effects=_TRANSITION_FORWARD_EFFECTS, + match_effects=_TRANSITION_FORWARD_EFFECTS, + ), + TransitionExecutorStrategySpec( + direction="forward", + executor_id=8, + executor_name="transition_matmul_primitive", + row_pattern=(TransitionExecutorPrimitivePatternSpec("matmul", ("*",)),), + 
implementation_contract="registered_transition_matmul_primitive_executor_binding_rows", + strategy_id="forward.transition.matmul_primitive.v1", + native_callable="native.forward.transition_matmul_primitive.v1", + required_effects=_TRANSITION_FORWARD_EFFECTS, + match_effects=_TRANSITION_FORWARD_EFFECTS, + ), + TransitionExecutorStrategySpec( + direction="forward", + executor_id=9, + executor_name="transition_norm_or_identity_primitive", + row_pattern=(TransitionExecutorPrimitivePatternSpec("norm_or_identity", ("*",)),), + implementation_contract="registered_transition_norm_or_identity_primitive_executor_binding_rows", + strategy_id="forward.transition.norm_or_identity_primitive.v1", + native_callable="native.forward.transition_norm_or_identity_primitive.v1", + required_effects=_TRANSITION_FORWARD_EFFECTS, + match_effects=_TRANSITION_FORWARD_EFFECTS, + ), + TransitionExecutorStrategySpec( + direction="forward", + executor_id=6, + executor_name="tanh_transition", + row_pattern=(TransitionExecutorPrimitivePatternSpec("tanh"),), + implementation_contract="registered_tanh_transition_executor_binding_rows", + strategy_id="forward.transition.tanh.v1", + native_callable="native.forward.transition_tanh.v1", + required_effects=_TRANSITION_FORWARD_EFFECTS, + match_effects=_TRANSITION_FORWARD_EFFECTS, + ), +) + +_REGISTERED_TRANSITION_REVERSE_STRATEGY_SPECS = ( + TransitionExecutorStrategySpec( + direction="reverse", + executor_id=2, + executor_name="gated_logspace_transition_backward", + row_pattern=(TransitionExecutorPrimitivePatternSpec("gated_logspace_recurrence"),), + implementation_contract="registered_gated_logspace_reverse_executor_binding_rows", + strategy_id="reverse.transition.gated_logspace.v1", + native_callable="native.reverse.transition_gated_logspace.v1", + cxx_entrypoints=("run_registered_gated_logspace_reverse_transition_handler",), + required_effects=_TRANSITION_REVERSE_EFFECTS, + ), + *( + TransitionExecutorStrategySpec( + direction="reverse", + executor_id=3, + executor_name="diag_rtu_transition_backward", + row_pattern=( + TransitionExecutorPrimitivePatternSpec( + primitive, + ("nu_log", "theta_log", "w1", "w2", "activation_id"), + ), + ), + implementation_contract=f"registered_{primitive}_reverse_executor_binding_rows", + strategy_id=f"reverse.transition.{primitive}.v1", + native_callable="native.reverse.transition_diag_rtu.v1", + cxx_entrypoints=("run_registered_diag_rtu_reverse_transition_handler",), + required_effects=_TRANSITION_REVERSE_EFFECTS, + ) + for primitive in ("diag_rtu", "diagonal_recurrence") + ), + TransitionExecutorStrategySpec( + direction="reverse", + executor_id=7, + executor_name="transition_linear_primitive_backward", + row_pattern=(TransitionExecutorPrimitivePatternSpec("linear", ("*",)),), + implementation_contract="registered_transition_linear_primitive_reverse_executor_binding_rows", + strategy_id="reverse.transition.linear_primitive.v1", + native_callable="native.reverse.transition_linear_primitive.v1", + cxx_entrypoints=("run_registered_linear_reverse_transition_handler",), + required_effects=("tape_read", "message_grad_emit", "parameter_grad_emit"), + ), + TransitionExecutorStrategySpec( + direction="reverse", + executor_id=8, + executor_name="transition_matmul_primitive_backward", + row_pattern=(TransitionExecutorPrimitivePatternSpec("matmul", ("*",)),), + implementation_contract="registered_transition_matmul_primitive_reverse_executor_binding_rows", + strategy_id="reverse.transition.matmul_primitive.v1", + 
native_callable="native.reverse.transition_matmul_primitive.v1", + cxx_entrypoints=("run_registered_matmul_reverse_transition_handler",), + required_effects=_TRANSITION_REVERSE_EFFECTS, + ), + TransitionExecutorStrategySpec( + direction="reverse", + executor_id=9, + executor_name="transition_norm_or_identity_primitive_backward", + row_pattern=(TransitionExecutorPrimitivePatternSpec("norm_or_identity", ("*",)),), + implementation_contract="registered_transition_norm_or_identity_primitive_reverse_executor_binding_rows", + strategy_id="reverse.transition.norm_or_identity_primitive.v1", + native_callable="native.reverse.transition_norm_or_identity_primitive.v1", + cxx_entrypoints=("run_registered_norm_or_identity_reverse_transition_handler",), + required_effects=_TRANSITION_REVERSE_EFFECTS, + ), + TransitionExecutorStrategySpec( + direction="reverse", + executor_id=6, + executor_name="tanh_transition_backward", + row_pattern=(TransitionExecutorPrimitivePatternSpec("tanh"),), + implementation_contract="registered_tanh_reverse_executor_binding_rows", + strategy_id="reverse.transition.tanh.v1", + native_callable="native.reverse.transition_tanh.v1", + cxx_entrypoints=("run_registered_tanh_reverse_transition_handler",), + required_effects=("tape_read", "state_grad_emit"), + ), +) + + +__all__ = [ + "TransitionProgramExecutorKind", + "TransitionExecutorPrimitivePatternSpec", + "TransitionExecutorProgramAccessSpec", + "TransitionExecutorStrategySpec", + "TransitionPrimitiveExecutorRecord", + "TransitionPrimitiveProgramLayerStatus", + "TransitionPrimitiveExecutorStatus", + "TransitionProgramExecutorRecord", + "TransitionProgramMessageInput", + "TransitionProgramOpArity", + "TransitionProgramStateSlice", + "TransitionProgramTensorEdge", + "registered_transition_forward_strategy_specs", + "registered_transition_executor_records", + "registered_transition_primitive_executor_records", + "registered_transition_reverse_strategy_specs", + "transition_primitive_executor_record", + "transition_primitive_executor_record_for_lowered_primitive", + "transition_primitive_program_contract_blocker_code", + "transition_program_layer_blocker_codes", + "transition_program_layer_missing_symbols", +] diff --git a/src/cortical/fabric/backend/cuda/transition_execution/types.py b/src/cortical/fabric/backend/cuda/transition_execution/types.py new file mode 100644 index 00000000..5681fe51 --- /dev/null +++ b/src/cortical/fabric/backend/cuda/transition_execution/types.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +import torch + + +@dataclass(frozen=True) +class TransitionBackwardResult: + grad_recurrent_msg: torch.Tensor | None + grad_packed_state_before: dict[str, torch.Tensor | None] + materialized_param_grads: dict[str, torch.Tensor] + static_source_grads: dict[str, torch.Tensor] + input_projection_param_grad_step: TransitionInputProjectionParamGradStep | None = None + deferred_param_grad_steps: tuple[TransitionInputProjectionParamGradStep, ...] = () + + +@dataclass(frozen=True) +class TransitionInputProjectionParamGradStep: + input_tensor: torch.Tensor + weight: torch.Tensor + bias: torch.Tensor | None + grad_output: torch.Tensor + group_key: tuple[object, ...] 
+ map_param_grads: Callable[ + [torch.Tensor, torch.Tensor | None], + tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], + ] + + +@dataclass(frozen=True) +class _ProjectionBackwardTape: + output: torch.Tensor + backward: Callable[ + [torch.Tensor | None], + tuple[torch.Tensor | None, dict[str, torch.Tensor], dict[str, torch.Tensor]], + ] + backward_deferred: ( + Callable[ + [torch.Tensor | None], + tuple[ + torch.Tensor | None, + dict[str, torch.Tensor], + dict[str, torch.Tensor], + TransitionInputProjectionParamGradStep | None, + ], + ] + | None + ) = None + + +@dataclass(frozen=True) +class TransitionBackwardTape: + input_projection: _ProjectionBackwardTape | None = None + diagonal_preproj: torch.Tensor | None = None + gated_gate_logits: torch.Tensor | None = None + gated_recurrent_gate_logits: torch.Tensor | None = None + + +@dataclass(frozen=True) +class TransitionForwardResult: + next_packed_state: Any + recurrent_hidden: torch.Tensor + recurrent_k: torch.Tensor | None + recurrent_v: torch.Tensor | None + backward_tape: TransitionBackwardTape | None = None + + +__all__ = [ + "TransitionBackwardResult", + "TransitionBackwardTape", + "TransitionForwardResult", + "TransitionInputProjectionParamGradStep", +] diff --git a/src/cortical/fabric/backend/flat_bucket_identity.py b/src/cortical/fabric/backend/flat_bucket_identity.py new file mode 100644 index 00000000..9da4c171 --- /dev/null +++ b/src/cortical/fabric/backend/flat_bucket_identity.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from collections.abc import Iterable, Mapping, Sequence +from typing import Any + + +def transition_flat_bucket_identity( + *, + state_schema_keys: Iterable[str], + public_schema_kind: str, + parameter_schema_keys: Iterable[str] = (), + input_projection_schema_keys: Iterable[str] = (), + public_projection_schema_keys: Iterable[str] = (), + transition_ir: Any | None = None, + transition_parameter_bindings: Mapping[str, Sequence[Any]] | None = None, +) -> tuple[str, ...]: + return ( + "flat_bucket_identity", + f"state_schema={_join(state_schema_keys)}", + f"public_schema={public_schema_kind}", + f"parameter_schema={_join(parameter_schema_keys)}", + f"input_projection_schema={_join(input_projection_schema_keys)}", + f"public_projection_schema={_join(public_projection_schema_keys)}", + f"transition_state_inputs={_join(_transition_field(transition_ir, 'state_inputs'))}", + f"transition_message_inputs={_join(_transition_field(transition_ir, 'message_inputs'))}", + f"transition_parameter_inputs={_join(_transition_field(transition_ir, 'parameter_inputs'))}", + f"transition_state_outputs={_join(_transition_field(transition_ir, 'state_outputs'))}", + f"transition_public_outputs={_join(_transition_field(transition_ir, 'public_outputs'))}", + f"transition_recompute_outputs={_join(_transition_field(transition_ir, 'recompute_outputs'))}", + f"transition_backward_decomposition={_join(_transition_field(transition_ir, 'backward_decomposition'))}", + f"transition_ops={_transition_ops_signature(transition_ir)}", + f"transition_parameter_bindings={_transition_parameter_binding_signature(transition_parameter_bindings)}", + ) + + +def _join(items: Iterable[Any]) -> str: + return ",".join(str(item) for item in items) + + +def _transition_field(transition_ir: Any | None, field: str) -> tuple[Any, ...]: + if transition_ir is None: + return () + return tuple(getattr(transition_ir, field, ())) + + +def _transition_ops_signature(transition_ir: Any | None) -> str: + if transition_ir is None: + return "" + parts: 
list[str] = [] + for op in getattr(transition_ir, "ops", ()): + name = str(getattr(op, "name", "")) + inputs = _join(getattr(op, "inputs", ())) + outputs = _join(getattr(op, "outputs", ())) + attributes = tuple(getattr(op, "attributes", ())) + if attributes: + attr_text = ",".join(f"{key}={value}" for key, value in attributes) + parts.append(f"{name}({inputs})->({outputs})[{attr_text}]") + else: + parts.append(f"{name}({inputs})->({outputs})") + return "|".join(parts) + + +def _transition_parameter_binding_signature(bindings: Mapping[str, Sequence[Any]] | None) -> str: + if not bindings: + return "" + entries: list[str] = [] + for logical_name in sorted(bindings): + binding_parts: list[str] = [] + for binding in bindings[logical_name]: + source = str(getattr(binding, "source", "")) + kind = str(getattr(binding, "kind", "")) + binding_parts.append(f"{source}:{kind}") + entries.append(f"{logical_name}=({','.join(binding_parts)})") + return "|".join(entries) diff --git a/src/cortical/fabric/backend/ir.py b/src/cortical/fabric/backend/ir.py index fb1d610f..619876be 100644 --- a/src/cortical/fabric/backend/ir.py +++ b/src/cortical/fabric/backend/ir.py @@ -6,8 +6,27 @@ from cortical.fabric.anatomy import Spec from cortical.fabric.backend.buckets import FabricBucket, ReceiverKind -from cortical.fabric.backend.message_rules import MessageRuleSummary, default_dot_product_message_rule_summary +from cortical.fabric.backend.cell_backend import ( + CellBackendSpec, + CompiledTransitionProgram, + build_cell_backend_spec, + compile_transition_program, +) +from cortical.fabric.backend.flat_bucket_identity import transition_flat_bucket_identity +from cortical.fabric.backend.message_rules import ( + CompiledMessageRule, + MessageRuleIR, + compile_message_rule, + default_dot_product_message_rule_ir, +) +from cortical.fabric.backend.readout_rules import ( + CompiledReadoutRule, + ReadoutRuleIR, + compile_readout_rule, + default_readout_rule_ir, +) from cortical.fabric.graph import GraphTopologySummary +from cortical.fabric.registry.cells import get_cell_spec @dataclass(frozen=True) @@ -26,11 +45,14 @@ class FabricIR: num_input_ports: int num_recurrent_cells: int num_output_ports: int - wrap: bool delay_depth: int kv_group_count: int graph_summary: GraphTopologySummary - message_rule: MessageRuleSummary + message_rule: MessageRuleIR + message_program: CompiledMessageRule + readout_rule: ReadoutRuleIR + readout_program: CompiledReadoutRule + transition_programs: tuple[CompiledTransitionProgram, ...] receiver_sets: tuple[ReceiverSetSummary, ...] buckets: tuple[FabricBucket, ...] 
@@ -38,6 +60,12 @@ class FabricIR: def bucket_count(self) -> int: return len(self.buckets) + def transition_program_for_binding_slot(self, binding_slot: int) -> CompiledTransitionProgram: + try: + return self.transition_programs[int(binding_slot)] + except IndexError as exc: + raise RuntimeError(f"Fabric backend IR has no transition program for binding_slot={binding_slot}") from exc + def compile_fabric_ir( spec: Spec, @@ -48,6 +76,30 @@ def compile_fabric_ir( head_dim: int, value_dim: int, ) -> FabricIR: + message_rule = spec.message_rule + if message_rule is None: + message_rule = default_dot_product_message_rule_ir( + kv_group_count=int(spec.num_kv_groups), + cell_count=int(spec.anatomy.num_cells), + ) + message_program = compile_message_rule(message_rule) + readout_rule = default_readout_rule_ir( + readout_pool=str(spec.config.readout.pool), + readout_slots=int(spec.config.readout.slots), + ) + readout_program = compile_readout_rule(readout_rule) + population_backend_specs = _compile_population_backend_specs( + spec, + hidden_size=hidden_size, + d_public=d_public, + d_msg=d_msg, + head_dim=head_dim, + value_dim=value_dim, + ) + transition_programs = tuple( + compile_transition_program(backend_spec, binding_slot=binding_slot) + for binding_slot, backend_spec in enumerate(population_backend_specs) + ) delay_depth = 1 if spec.anatomy.edge_delay is not None and spec.anatomy.edge_delay.numel() > 0: delay_depth = int(spec.anatomy.edge_delay.max().item()) @@ -83,6 +135,7 @@ def compile_fabric_ir( value_dim=value_dim, delay_depth=delay_depth, bucket_id_start=bucket_id, + population_backend_specs=population_backend_specs, ) ) bucket_id += len(buckets) @@ -98,6 +151,7 @@ def compile_fabric_ir( value_dim=value_dim, delay_depth=delay_depth, bucket_id_start=bucket_id, + population_backend_specs=population_backend_specs, ) ) return FabricIR( @@ -106,19 +160,41 @@ def compile_fabric_ir( num_input_ports=int(spec.input_cell_idx.numel()), num_recurrent_cells=int(spec.recurrent_cell_idx.numel()), num_output_ports=int(spec.output_cell_idx.numel()), - wrap=bool(spec.config.wrap), delay_depth=delay_depth, kv_group_count=int(spec.num_kv_groups), graph_summary=graph_summary, - message_rule=default_dot_product_message_rule_summary( - kv_group_count=int(spec.num_kv_groups), - cell_count=int(spec.anatomy.num_cells), - ), + message_rule=message_rule, + message_program=message_program, + readout_rule=readout_rule, + readout_program=readout_program, + transition_programs=transition_programs, receiver_sets=receiver_sets, buckets=tuple(buckets), ) +def _compile_population_backend_specs( + spec: Spec, + *, + hidden_size: int, + d_public: int, + d_msg: int, + head_dim: int, + value_dim: int, +) -> tuple[CellBackendSpec, ...]: + return tuple( + build_cell_backend_spec( + cell_type=spec.config.populations.cell_populations[population_name].cell_type, + hidden_size=hidden_size, + d_public=d_public, + d_msg=d_msg, + head_dim=head_dim, + value_dim=value_dim, + ) + for population_name in spec.population_names + ) + + def _make_receiver_summary( spec: Spec, *, @@ -152,6 +228,7 @@ def _compile_buckets_for_receivers( value_dim: int, delay_depth: int, bucket_id_start: int, + population_backend_specs: tuple[CellBackendSpec, ...], ) -> list[FabricBucket]: if receiver_indices.numel() == 0: return [] @@ -161,19 +238,29 @@ def _compile_buckets_for_receivers( degree = neighbor_valid.sum(dim=1) template_ids = _template_ids(local_valid) dim_signature = (hidden_size, d_public, d_msg, head_dim, value_dim) - raw_populations = ( + 
raw_population_indices = ( spec.anatomy.cell_layout.index_select(0, receiver_indices) if receiver_kind == ReceiverKind.RECURRENT_CELL else torch.full((receiver_indices.numel(),), -1, dtype=torch.long) ) - population_names = [] - for population_idx in raw_populations.tolist(): - population_names.append("readout" if population_idx < 0 else spec.population_names[population_idx]) + transition_signatures: list[tuple[str, ...]] = [] + parameter_bindings: list[str] = [] + for population_idx in raw_population_indices.tolist(): + transition_signatures.append( + _transition_signature_for_population( + spec, + int(population_idx), + population_backend_specs=population_backend_specs, + ) + ) + parameter_bindings.append(_parameter_binding_for_population(int(population_idx))) grouped = {} sharing_pattern = "grouped_kv" if spec.num_kv_groups < spec.anatomy.num_cells else "per_cell_kv" for local_idx, receiver_idx in enumerate(receiver_indices.tolist()): key = ( - population_names[local_idx], + int(raw_population_indices[local_idx].item()), + transition_signatures[local_idx], + parameter_bindings[local_idx], _degree_bin(int(degree[local_idx].item())), int(template_ids[local_idx].item()), bool((edge_type[local_idx] == 1).any().item()), @@ -181,14 +268,25 @@ def _compile_buckets_for_receivers( grouped.setdefault(key, []).append((local_idx, receiver_idx)) buckets: list[FabricBucket] = [] next_bucket_id = bucket_id_start - for (population_name, degree_bin, template_id, has_sparse_overlay), entries in grouped.items(): + for ( + population_index, + transition_signature, + parameter_binding, + degree_bin, + template_id, + has_sparse_overlay, + ), entries in grouped.items(): degree_values = [int(degree[local_idx].item()) for local_idx, _receiver_idx in entries] entry_indices = torch.tensor([receiver_idx for _local_idx, receiver_idx in entries], dtype=torch.long) + population_name = "readout" if population_index < 0 else spec.population_names[population_index] buckets.append( FabricBucket( bucket_id=next_bucket_id, receiver_kind=receiver_kind, population_name=population_name, + population_index=None if population_index < 0 else population_index, + transition_signature=transition_signature, + parameter_binding=parameter_binding, dim_signature=dim_signature, receiver_count=len(entries), degree_bin=degree_bin, @@ -206,6 +304,35 @@ def _compile_buckets_for_receivers( return buckets +def _transition_signature_for_population( + spec: Spec, + population_idx: int, + *, + population_backend_specs: tuple[CellBackendSpec, ...], +) -> tuple[str, ...]: + if population_idx < 0: + return ("receiver=readout", "transition=readout_projection") + population_name = spec.population_names[population_idx] + cell_type = spec.config.populations.cell_populations[population_name].cell_type + cell_spec = get_cell_spec(cell_type) + backend_spec = population_backend_specs[population_idx] + return transition_flat_bucket_identity( + state_schema_keys=cell_spec.state_schema.keys, + public_schema_kind=cell_spec.public_schema.kind, + parameter_schema_keys=cell_spec.parameter_schema.keys, + input_projection_schema_keys=cell_spec.input_projection_schema.keys, + public_projection_schema_keys=cell_spec.public_projection_schema.keys, + transition_ir=backend_spec.transition_ir, + transition_parameter_bindings=backend_spec.transition_parameter_bindings, + ) + + +def _parameter_binding_for_population(population_idx: int) -> str: + if population_idx < 0: + return "readout_binding" + return f"population_slot:{int(population_idx)}" + + def 
_template_ids(local_valid: torch.Tensor) -> torch.Tensor: templates: dict[tuple[int, ...], int] = {} out = torch.empty(local_valid.shape[0], dtype=torch.long) diff --git a/src/cortical/fabric/backend/message_rule_specs.py b/src/cortical/fabric/backend/message_rule_specs.py new file mode 100644 index 00000000..3c2a743d --- /dev/null +++ b/src/cortical/fabric/backend/message_rule_specs.py @@ -0,0 +1,465 @@ +from __future__ import annotations + +from cortical.fabric.backend.message_rules import ( + DOT_PRODUCT_FIXED_SLOT_CONTEXT_GATE, + DOT_PRODUCT_FIXED_SLOT_CONTEXT_NUDGE, + DOT_PRODUCT_SEGMENT_SOFTMAX_WEIGHTED_SUM, + PROJECTED_MESSAGE_BOUNDARY, + MessageOpPrimitiveBinding, + MessageRuleBackendSpec, + MessageRuleNativeExecutorSpec, + MessageRuleNode, + MessageRuleParamGradOutputSpec, + MessageRuleParameter, + MessageRuleParameterReducerSpec, + MessageRuleRuntimeModule, + MessageRuleRuntimeParameterSpec, + MessageRuleSource, + MessageRuleStaticTensorSpec, + MessageSharingScope, + register_message_rule_backend_spec_builder, +) + + +def _dot_product_native_executors() -> tuple[MessageRuleNativeExecutorSpec, ...]: + return ( + MessageRuleNativeExecutorSpec( + direction="forward", + executor_id=1, + executor_name="neighborhood_attention_project", + strategy_id="forward.message.neighborhood_attention_project.v1", + native_callable="native.forward.msg_attention_project.v1", + implementation_contract="registered_message_executor_binding_rows", + cxx_entrypoints=( + "bind_neighborhood_attention_project_message_handler", + "run_neighborhood_attention_project_recurrent_kv", + "run_neighborhood_attention_project_message", + ), + cxx_entrypoint_phases=("bind", "recurrent_kv", "message"), + ), + MessageRuleNativeExecutorSpec( + direction="reverse", + executor_id=1, + executor_name="neighborhood_attention_project_backward", + strategy_id="reverse.message.neighborhood_attention_project.v1", + native_callable="native.reverse.msg_attention_project.v1", + implementation_contract="registered_message_reverse_executor_binding_rows", + cxx_entrypoints=( + "run_neighborhood_attention_project_recurrent_kv_backward", + "run_neighborhood_attention_project_recurrent_message_backward", + "run_neighborhood_attention_project_initial_recurrent_kv_backward", + "run_neighborhood_attention_project_boundary_kv_backward", + "run_neighborhood_attention_project_recurrent_kv_forward_recompute", + ), + cxx_entrypoint_phases=( + "recurrent_kv_backward", + "recurrent_message_backward", + "initial_recurrent_kv_backward", + "boundary_kv_backward", + "recurrent_kv_forward_recompute", + ), + ), + ) + + +def _fixed_slot_context_native_executors( + *, + variant: str, + executor_id: int, +) -> tuple[MessageRuleNativeExecutorSpec, ...]: + return ( + MessageRuleNativeExecutorSpec( + direction="forward", + executor_id=int(executor_id), + executor_name=f"fixed_slot_context_{variant}_message", + strategy_id=f"forward.message.fixed_slot_context_{variant}.v1", + native_callable=f"native.forward.msg_fixed_slot_context_{variant}.v1", + implementation_contract=f"registered_fixed_slot_context_{variant}_message_native_callable", + cxx_entrypoints=( + "bind_fixed_slot_context_message_handler", + "run_fixed_slot_context_recurrent_kv", + "run_fixed_slot_context_message", + "run_fixed_slot_context_keyless_readout_message", + "run_fixed_slot_context_direct_keyless_readout_message", + "run_fixed_slot_context_stream_readout_message", + "run_fixed_slot_context_stream_transition_input", + ), + cxx_entrypoint_phases=( + "bind", + "recurrent_kv", + "message", + 
"keyless_readout_message", + "direct_keyless_readout_message", + "stream_readout_message", + "stream_transition_input", + ), + ), + MessageRuleNativeExecutorSpec( + direction="reverse", + executor_id=int(executor_id), + executor_name=f"fixed_slot_context_{variant}_message_backward", + strategy_id=f"reverse.message.fixed_slot_context_{variant}.v1", + native_callable=f"native.reverse.msg_fixed_slot_context_{variant}.v1", + implementation_contract=f"registered_fixed_slot_context_{variant}_message_backward_native_callable", + cxx_entrypoints=( + "run_fixed_slot_context_recurrent_kv_backward", + "run_fixed_slot_context_recurrent_message_backward", + "run_fixed_slot_context_initial_recurrent_kv_backward", + "run_fixed_slot_context_boundary_kv_backward", + "run_fixed_slot_context_recurrent_kv_forward_recompute", + ), + cxx_entrypoint_phases=( + "recurrent_kv_backward", + "recurrent_message_backward", + "initial_recurrent_kv_backward", + "boundary_kv_backward", + "recurrent_kv_forward_recompute", + ), + ), + ) + + +def _fixed_slot_context_parameter_reducer() -> MessageRuleParameterReducerSpec: + return MessageRuleParameterReducerSpec( + reducer_kind="fixed_slot_context_message", + reducer_kind_opcode=6, + native_callable="native.reverse.parameter_reduction.fixed_slot_context_message.v1", + implementation_symbol="run_registered_fixed_slot_context_message_parameter_reducer_strategy", + count_target="message_strategy", + count_mode="tensor_count", + active_trainable_roles=( + "slot_embed", + "message_query_slot_proj_weight", + "message_sender_slot_key_proj_weight", + "message_query_nudge_scale", + "message_query_context_gate", + "message_sender_context_key", + "msg_out_weight", + ), + required_static_logical_groups=( + ("message_query_slot_weight",), + ("message_sender_slot_key_weight",), + ("message_sender_context_key",), + ("message_output_weight",), + ("message_query_context_gate", "message_query_nudge_scale"), + ), + grad_output_roles=( + "grad_query_slot_backend", + "grad_query_context_scalar", + "grad_output_weight", + "grad_input_key_bank", + "grad_recurrent_key_bank", + ), + ) + + +def build_dot_product_message_rule_backend_spec( + *, + kv_group_count: int, + cell_count: int, +) -> MessageRuleBackendSpec: + sharing_mode: MessageSharingScope = ( + "sender_group_shared" if int(kv_group_count) < int(cell_count) else "sender_local" + ) + return MessageRuleBackendSpec( + rule_type="dot_product", + default_name="dot_product", + lowering_kind=DOT_PRODUCT_SEGMENT_SOFTMAX_WEIGHTED_SUM, + sources=( + MessageRuleSource("receiver_slot", "receiver_slot"), + MessageRuleSource("sender_public_prev", "sender_public_prev", "zero_source_rows", "batch_row"), + MessageRuleSource("edge_distance", "edge_distance"), + ), + parameters=( + MessageRuleParameter("recurrent_q_weight", "projection", "rule_global"), + MessageRuleParameter("input_sender_kv_weight", "projection", sharing_mode, int(kv_group_count)), + MessageRuleParameter("input_group_kv_weight", "projection", sharing_mode, int(kv_group_count)), + MessageRuleParameter("recurrent_sender_kv_weight", "projection", sharing_mode, int(kv_group_count)), + ), + nodes=( + MessageRuleNode("receiver_source", "source", (), "receiver_slot"), + MessageRuleNode("sender_source", "source", (), "sender_public_prev"), + MessageRuleNode("edge_distance_source", "source", (), "edge_distance"), + MessageRuleNode("q_projection", "linear", ("receiver_slot", "recurrent_q_weight"), "q"), + MessageRuleNode( + "k_projection", + "linear", + ("sender_public_prev", 
"input_sender_kv_weight"), + "k", + ), + MessageRuleNode( + "v_projection", + "linear", + ("sender_public_prev", "input_group_kv_weight"), + "v", + ), + MessageRuleNode("logits", "dot", ("q", "k"), "logits"), + MessageRuleNode("biased_logits", "add", ("logits", "edge_distance"), "biased_logits"), + MessageRuleNode("weights", "segment_softmax", ("biased_logits",), "weights"), + MessageRuleNode("weighted_value", "segment_weighted_sum", ("weights", "v"), "weighted_value"), + MessageRuleNode( + "message_projection", + "linear", + ("weighted_value", "recurrent_sender_kv_weight"), + PROJECTED_MESSAGE_BOUNDARY, + ), + ), + output_boundary=PROJECTED_MESSAGE_BOUNDARY, + primitive_bindings=( + MessageOpPrimitiveBinding("linear", "linear"), + MessageOpPrimitiveBinding("add", "add"), + MessageOpPrimitiveBinding("dot", "attention_logits"), + MessageOpPrimitiveBinding("segment_softmax", "segment_softmax"), + MessageOpPrimitiveBinding("segment_weighted_sum", "weighted_sum"), + ), + static_tensors=( + MessageRuleStaticTensorSpec( + "recurrent_q_weight", + "existing_static_tensor", + program_access_name="message_recurrent_query", + program_access_opcode=1, + ), + MessageRuleStaticTensorSpec( + "input_sender_kv_weight", + "existing_static_tensor", + program_access_name="message_input_direct_kv_weight", + program_access_opcode=2, + ), + MessageRuleStaticTensorSpec( + "input_group_kv_weight", + "existing_static_tensor", + program_access_name="message_input_group_kv_weight", + program_access_opcode=3, + ), + MessageRuleStaticTensorSpec( + "recurrent_sender_kv_weight", + "existing_static_tensor", + program_access_name="message_recurrent_kv_weight", + program_access_opcode=4, + ), + ), + native_executors=_dot_product_native_executors(), + ) + + +def _build_fixed_slot_context_dot_product_message_rule_backend_spec( + *, + kv_group_count: int, + cell_count: int, + rule_type: str, + lowering_kind: str, + context_scalar_name: str, + context_scalar_node_name: str, + variant: str, + executor_id: int, +) -> MessageRuleBackendSpec: + sharing_mode: MessageSharingScope = ( + "sender_group_shared" if int(kv_group_count) < int(cell_count) else "sender_local" + ) + return MessageRuleBackendSpec( + rule_type=rule_type, + default_name="dot_product", + lowering_kind=lowering_kind, + sources=( + MessageRuleSource("receiver_slot", "receiver_slot"), + MessageRuleSource("receiver_public_prev", "receiver_public_prev", "zero_source_rows", "batch_row"), + MessageRuleSource("sender_slot", "sender_slot"), + MessageRuleSource("sender_public_prev", "sender_public_prev", "zero_source_rows", "batch_row"), + MessageRuleSource("edge_distance", "edge_distance"), + ), + parameters=( + MessageRuleParameter("message_query_slot_weight", "projection", "rule_global"), + MessageRuleParameter(context_scalar_name, "rule_scalar", "fabric_global"), + MessageRuleParameter("message_sender_slot_key_weight", "projection", "rule_global"), + MessageRuleParameter("message_sender_context_key", "rule_table", "sender_local", int(cell_count)), + MessageRuleParameter("input_sender_value_weight", "projection", sharing_mode, int(kv_group_count)), + MessageRuleParameter("input_group_value_weight", "projection", sharing_mode, int(kv_group_count)), + MessageRuleParameter("recurrent_sender_value_weight", "projection", sharing_mode, int(kv_group_count)), + MessageRuleParameter("message_output_weight", "projection", "rule_global"), + ), + nodes=( + MessageRuleNode("receiver_source", "source", (), "receiver_slot"), + MessageRuleNode("receiver_public_source", "source", (), 
"receiver_public_prev"), + MessageRuleNode("sender_slot_source", "source", (), "sender_slot"), + MessageRuleNode("sender_public_source", "source", (), "sender_public_prev"), + MessageRuleNode("edge_distance_source", "source", (), "edge_distance"), + MessageRuleNode("query_slot_prefix", "linear", ("receiver_slot", "message_query_slot_weight"), "q_slot"), + MessageRuleNode( + "query_context", + "linear", + ("receiver_public_prev", "recurrent_sender_value_weight"), + "q_context", + ), + MessageRuleNode( + context_scalar_node_name, + "mul", + ("q_context", context_scalar_name), + context_scalar_node_name, + ), + MessageRuleNode("query", "concat", ("q_slot", context_scalar_node_name), "q"), + MessageRuleNode("key_slot_prefix", "linear", ("sender_slot", "message_sender_slot_key_weight"), "k_slot"), + MessageRuleNode("key", "concat", ("k_slot", "message_sender_context_key"), "k"), + MessageRuleNode( + "value_projection", + "linear", + ( + "sender_public_prev", + "input_sender_value_weight", + "input_group_value_weight", + "recurrent_sender_value_weight", + ), + "v", + ), + MessageRuleNode("logits", "dot", ("q", "k"), "logits"), + MessageRuleNode("biased_logits", "add", ("logits", "edge_distance"), "biased_logits"), + MessageRuleNode("weights", "segment_softmax", ("biased_logits",), "weights"), + MessageRuleNode("weighted_value", "segment_weighted_sum", ("weights", "v"), "weighted_value"), + MessageRuleNode( + "message_projection", + "linear", + ("weighted_value", "message_output_weight"), + "message_raw", + ), + MessageRuleNode("message_norm", "normalize", ("message_raw",), PROJECTED_MESSAGE_BOUNDARY), + ), + output_boundary=PROJECTED_MESSAGE_BOUNDARY, + primitive_bindings=( + MessageOpPrimitiveBinding("linear", "linear"), + MessageOpPrimitiveBinding("mul", "mul"), + MessageOpPrimitiveBinding("concat", "concat"), + MessageOpPrimitiveBinding("add", "add"), + MessageOpPrimitiveBinding("dot", "attention_logits"), + MessageOpPrimitiveBinding("segment_softmax", "segment_softmax"), + MessageOpPrimitiveBinding("segment_weighted_sum", "weighted_sum"), + MessageOpPrimitiveBinding("normalize", "normalize"), + ), + output_dim_role="d_msg", + runtime_modules=( + MessageRuleRuntimeModule("message_query_slot_proj", "linear", "d_slot", "head_dim"), + MessageRuleRuntimeModule("message_sender_slot_key_proj", "linear", "d_slot", "head_dim"), + ), + runtime_parameters=( + MessageRuleRuntimeParameterSpec(context_scalar_name, ("1",), "ones"), + MessageRuleRuntimeParameterSpec("message_sender_context_key", ("cell_count", "head_dim"), "normal_head"), + ), + static_tensors=( + MessageRuleStaticTensorSpec( + "message_query_slot_weight", + "slot_linear_backend_recurrent", + "message_query_slot_proj", + "message_query_slot_weight", + 10, + ), + MessageRuleStaticTensorSpec( + context_scalar_name, + "runtime_parameter", + context_scalar_name, + "message_query_context_scalar", + 11, + ), + MessageRuleStaticTensorSpec( + "message_sender_slot_key_weight", + "slot_linear_sender", + "message_sender_slot_key_proj", + "message_sender_slot_key_weight", + 12, + ), + MessageRuleStaticTensorSpec( + "message_sender_context_key", + "runtime_parameter_sender", + "message_sender_context_key", + "message_sender_context_key", + 13, + ), + MessageRuleStaticTensorSpec( + "input_sender_value_weight", + "input_sender_value_weight", + program_access_name="message_input_value_weight", + program_access_opcode=14, + ), + MessageRuleStaticTensorSpec( + "input_group_value_weight", + "input_group_value_weight", + 
program_access_name="message_input_group_value_weight", + program_access_opcode=15, + ), + MessageRuleStaticTensorSpec( + "recurrent_sender_value_weight", + "recurrent_sender_value_weight", + program_access_name="message_recurrent_value_weight", + program_access_opcode=16, + ), + MessageRuleStaticTensorSpec( + "message_output_weight", + "module_weight", + "msg_out", + "message_output_weight", + 17, + ), + ), + parameter_reducer_kind="fixed_slot_context_message", + param_grad_outputs=( + MessageRuleParamGradOutputSpec("grad_query_slot_backend", "recurrent_query_grad"), + MessageRuleParamGradOutputSpec("grad_query_context_scalar", "boundary_extra_output", 0), + MessageRuleParamGradOutputSpec("grad_output_weight", "boundary_extra_output", 1), + MessageRuleParamGradOutputSpec("grad_input_key_bank", "boundary_extra_output", 2), + MessageRuleParamGradOutputSpec("grad_recurrent_key_bank", "boundary_extra_output", 3), + ), + native_executors=_fixed_slot_context_native_executors( + variant=variant, + executor_id=int(executor_id), + ), + parameter_reducer=_fixed_slot_context_parameter_reducer(), + ) + + +def build_fixed_slot_context_nudge_dot_product_message_rule_backend_spec( + *, + kv_group_count: int, + cell_count: int, +) -> MessageRuleBackendSpec: + return _build_fixed_slot_context_dot_product_message_rule_backend_spec( + kv_group_count=kv_group_count, + cell_count=cell_count, + rule_type="dot_product_fixed_slot_context_nudge", + lowering_kind=DOT_PRODUCT_FIXED_SLOT_CONTEXT_NUDGE, + context_scalar_name="message_query_nudge_scale", + context_scalar_node_name="q_context_nudge", + variant="nudge", + executor_id=5, + ) + + +def build_fixed_slot_context_gate_dot_product_message_rule_backend_spec( + *, + kv_group_count: int, + cell_count: int, +) -> MessageRuleBackendSpec: + return _build_fixed_slot_context_dot_product_message_rule_backend_spec( + kv_group_count=kv_group_count, + cell_count=cell_count, + rule_type="dot_product_fixed_slot_context_gate", + lowering_kind=DOT_PRODUCT_FIXED_SLOT_CONTEXT_GATE, + context_scalar_name="message_query_context_gate", + context_scalar_node_name="q_context_gate", + variant="gate", + executor_id=10, + ) + + +register_message_rule_backend_spec_builder("dot_product", build_dot_product_message_rule_backend_spec) +register_message_rule_backend_spec_builder( + "dot_product_fixed_slot_context_nudge", + build_fixed_slot_context_nudge_dot_product_message_rule_backend_spec, +) +register_message_rule_backend_spec_builder( + "dot_product_fixed_slot_context_gate", + build_fixed_slot_context_gate_dot_product_message_rule_backend_spec, +) + + +__all__ = [ + "build_dot_product_message_rule_backend_spec", + "build_fixed_slot_context_gate_dot_product_message_rule_backend_spec", + "build_fixed_slot_context_nudge_dot_product_message_rule_backend_spec", +] diff --git a/src/cortical/fabric/backend/message_rules.py b/src/cortical/fabric/backend/message_rules.py index 140c6b79..19685d3c 100644 --- a/src/cortical/fabric/backend/message_rules.py +++ b/src/cortical/fabric/backend/message_rules.py @@ -1,24 +1,51 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence from dataclasses import dataclass from typing import Literal -MessageSourceKind = Literal["receiver_slot", "sender_public_prev", "edge_distance"] -MessageParameterRole = Literal["projection", "bias"] -MessageSharingScope = Literal["rule_global", "sender_group_shared", "sender_local"] +from cortical.fabric.backend.primitives import is_callable_cuda_nn_primitive + 
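+# Usage sketch (illustrative only, not exercised by this module): the builtin
+# spec builders register lazily, so callers can build and lower the active
+# default rule directly. This assumes the fabric.cuda.nn primitive registry
+# exposes every primitive bound by the spec; compile_message_rule() fails
+# closed otherwise.
+#
+#     from cortical.fabric.backend.message_rules import (
+#         build_message_rule_ir,
+#         compile_message_rule,
+#         registered_message_rule_backend_spec_types,
+#     )
+#
+#     registered_message_rule_backend_spec_types()
+#     # -> ("dot_product", "dot_product_fixed_slot_context_gate",
+#     #     "dot_product_fixed_slot_context_nudge")
+#
+#     ir = build_message_rule_ir(
+#         rule_type="dot_product_fixed_slot_context_nudge",
+#         kv_group_count=1,
+#         cell_count=2,
+#     )
+#     assert ir.lowering_kind == "dot_product_fixed_slot_context_nudge"
+#
+#     compiled = compile_message_rule(ir)
+#     compiled.primitive_names          # deduplicated primitive rows, in node order
+#     compiled.parameter_reducer_kind   # "fixed_slot_context_message"
+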
+MessageSourceKind = Literal[ + "receiver_slot", + "receiver_public_prev", + "sender_slot", + "sender_public_prev", + "edge_distance", +] +MessageParameterRole = Literal["projection", "bias", "rule_scalar", "rule_table"] +MessageSharingScope = Literal["fabric_global", "rule_global", "sender_group_shared", "sender_local"] +MessageParamGradSource = Literal["recurrent_query_grad", "boundary_extra_output"] +MessageNativeExecutorDirection = Literal["forward", "reverse"] +MessageNativeExecutorPhase = Literal[ + "bind", + "recurrent_kv", + "message", + "keyless_readout_message", + "direct_keyless_readout_message", + "recurrent_kv_forward_recompute", + "recurrent_kv_backward", + "recurrent_message_backward", + "initial_recurrent_kv_backward", + "boundary_kv_backward", +] MessageOpKind = Literal[ "source", "parameter", "linear", "bias", "add", + "mul", + "concat", "dot", + "normalize", "segment_softmax", "segment_weighted_sum", ] DOT_PRODUCT_SEGMENT_SOFTMAX_WEIGHTED_SUM = "dot_product_segment_softmax_weighted_sum" +DOT_PRODUCT_FIXED_SLOT_CONTEXT_NUDGE = "dot_product_fixed_slot_context_nudge" +DOT_PRODUCT_FIXED_SLOT_CONTEXT_GATE = "dot_product_fixed_slot_context_gate" PROJECTED_MESSAGE_BOUNDARY = "projected_message" @@ -58,6 +85,126 @@ class MessageRuleNode: output: str +@dataclass(frozen=True) +class MessageOpPrimitiveBinding: + source_op: str + primitive: str + + +@dataclass(frozen=True) +class MessageRuleRuntimeModule: + name: str + module_kind: str + input_dim_role: str + output_dim_role: str + bias: bool = False + + +@dataclass(frozen=True) +class MessageRuleRuntimeParameterSpec: + name: str + shape_roles: tuple[str, ...] + init: str + + +@dataclass(frozen=True) +class MessageRuleStaticTensorSpec: + name: str + source_kind: str + source_name: str = "" + program_access_name: str = "" + program_access_opcode: int = 0 + + +@dataclass(frozen=True) +class MessageRuleParamGradOutputSpec: + logical_name: str + source: MessageParamGradSource + source_index: int = 0 + + +@dataclass(frozen=True) +class MessageRuleNativeExecutorEntrypointSpec: + phase: MessageNativeExecutorPhase + symbol: str + + +@dataclass(frozen=True) +class MessageRuleNativeExecutorSpec: + direction: MessageNativeExecutorDirection + executor_id: int + executor_name: str + strategy_id: str + native_callable: str + implementation_contract: str + cxx_entrypoints: tuple[str, ...] + cxx_entrypoint_phases: tuple[MessageNativeExecutorPhase, ...] = () + strategy_version: int = 1 + + @property + def cxx_entrypoint_contract(self) -> tuple[MessageRuleNativeExecutorEntrypointSpec, ...]: + return tuple( + MessageRuleNativeExecutorEntrypointSpec(phase=phase, symbol=symbol) + for phase, symbol in zip(self.cxx_entrypoint_phases, self.cxx_entrypoints, strict=True) + ) + + +@dataclass(frozen=True) +class MessageRuleParameterReducerSpec: + reducer_kind: str + reducer_kind_opcode: int + native_callable: str + implementation_symbol: str + count_target: str + count_mode: str + active_trainable_roles: tuple[str, ...] + required_static_logical_groups: tuple[tuple[str, ...], ...] + grad_output_roles: tuple[str, ...] + strategy_opcode: int = 0 + strategy_version: int = 1 + cxx_entrypoints: tuple[str, ...] = () + + +@dataclass(frozen=True) +class MessageRuleBackendSpec: + rule_type: str + default_name: str + lowering_kind: str + sources: tuple[MessageRuleSource, ...] + parameters: tuple[MessageRuleParameter, ...] + nodes: tuple[MessageRuleNode, ...] + output_boundary: str + primitive_bindings: tuple[MessageOpPrimitiveBinding, ...] 
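+    # Optional strategy-facing metadata rows; each has a default so minimal specs
+    # can omit them, and to_ir() forwards them unchanged onto the MessageRuleIR:
+    # output dim role, runtime modules/parameters, static tensor bindings, the
+    # parameter-reducer contract, per-parameter grad outputs, and the registered
+    # forward/reverse native executors.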
+ output_dim_role: str = "value_dim" + runtime_modules: tuple[MessageRuleRuntimeModule, ...] = () + runtime_parameters: tuple[MessageRuleRuntimeParameterSpec, ...] = () + static_tensors: tuple[MessageRuleStaticTensorSpec, ...] = () + parameter_reducer_kind: str = "" + param_grad_outputs: tuple[MessageRuleParamGradOutputSpec, ...] = () + native_executors: tuple[MessageRuleNativeExecutorSpec, ...] = () + parameter_reducer: MessageRuleParameterReducerSpec | None = None + + def to_ir(self, *, name: str | None = None) -> MessageRuleIR: + return MessageRuleIR( + name=self.default_name if name is None else str(name), + sources=self.sources, + parameters=self.parameters, + nodes=self.nodes, + output_boundary=self.output_boundary, + rule_type=self.rule_type, + lowering_kind_id=self.lowering_kind, + primitive_bindings=self.primitive_bindings, + output_dim_role=self.output_dim_role, + runtime_modules=self.runtime_modules, + runtime_parameters=self.runtime_parameters, + static_tensors=self.static_tensors, + parameter_reducer_kind=self.parameter_reducer_kind, + param_grad_outputs=self.param_grad_outputs, + native_executors=self.native_executors, + parameter_reducer=self.parameter_reducer, + ) + + @dataclass(frozen=True) class MessageRuleIR: name: str @@ -65,29 +212,66 @@ class MessageRuleIR: parameters: tuple[MessageRuleParameter, ...] nodes: tuple[MessageRuleNode, ...] output_boundary: str + rule_type: str = "" + lowering_kind_id: str = "" + primitive_bindings: tuple[MessageOpPrimitiveBinding, ...] = () + output_dim_role: str = "value_dim" + runtime_modules: tuple[MessageRuleRuntimeModule, ...] = () + runtime_parameters: tuple[MessageRuleRuntimeParameterSpec, ...] = () + static_tensors: tuple[MessageRuleStaticTensorSpec, ...] = () + parameter_reducer_kind: str = "" + param_grad_outputs: tuple[MessageRuleParamGradOutputSpec, ...] = () + native_executors: tuple[MessageRuleNativeExecutorSpec, ...] 
= () + parameter_reducer: MessageRuleParameterReducerSpec | None = None @property def expression_signature(self) -> str: - if self.name == "dot_product": + if self.lowering_kind == DOT_PRODUCT_FIXED_SLOT_CONTEXT_NUDGE: return ( - "source(receiver_slot)->linear(q);" - "source(sender_public_prev,reset=zero_source_rows,batch_row)->linear(k),linear(v);" + "fixed_slot_query(receiver_slot)+context_nudge(receiver_public_prev)->q;" + "fixed_slot_key(sender_slot)->k;" + "dynamic_value(sender_public_prev)->v;" "dot(q,k)+edge_distance->segment_softmax(receiver_neighborhood);" - "segment_weighted_sum(receiver_neighborhood,v)->linear(out)->projected_message" + "segment_weighted_sum(receiver_neighborhood,v)->linear->normalize->projected_message" + ) + if self.lowering_kind == DOT_PRODUCT_FIXED_SLOT_CONTEXT_GATE: + return ( + "fixed_slot_query(receiver_slot)+context_gate(receiver_public_prev)->q;" + "fixed_slot_key(sender_slot)->k;" + "dynamic_value(sender_public_prev)->v;" + "dot(q,k)+edge_distance->segment_softmax(receiver_neighborhood);" + "segment_weighted_sum(receiver_neighborhood,v)->linear->normalize->projected_message" + ) + if self.rule_type == "dot_product" or self.name == "dot_product": + return ( + "source(receiver_slot)->linear(recurrent_q);" + "source(sender_public_prev,reset=zero_source_rows,batch_row)->linear(sender_kv);" + "dot(q,k)+edge_distance->segment_softmax(receiver_neighborhood);" + "segment_weighted_sum(receiver_neighborhood,v)->projected_message" ) return ";".join(f"{node.op}({','.join(node.inputs)})->{node.output}" for node in self.nodes) @property def lowering_kind(self) -> str: + if self.lowering_kind_id: + return self.lowering_kind_id return classify_message_rule(self) + @property + def source_signature(self) -> tuple[str, ...]: + return tuple(source.signature for source in self.sources) + + @property + def parameter_sharing_signature(self) -> tuple[str, ...]: + return tuple(parameter.signature for parameter in self.parameters) + def summary(self) -> MessageRuleSummary: return MessageRuleSummary( name=self.name, lowering_kind=self.lowering_kind, expression_signature=self.expression_signature, - source_signature=tuple(source.signature for source in self.sources), - parameter_sharing_signature=tuple(parameter.signature for parameter in self.parameters), + source_signature=self.source_signature, + parameter_sharing_signature=self.parameter_sharing_signature, output_boundary=self.output_boundary, ) @@ -102,44 +286,535 @@ class MessageRuleSummary: output_boundary: str +@dataclass(frozen=True) +class CompiledMessagePrimitiveOp: + node_index: int + node_name: str + primitive: str + source_op: str + inputs: tuple[str, ...] + outputs: tuple[str, ...] + attributes: tuple[tuple[str, str], ...] = () + parameter_bindings: tuple[str, ...] = () + + +@dataclass(frozen=True) +class CompiledMessageRule: + rule_name: str + lowering_kind: str + output_boundary: str + output_dim_role: str + primitive_ops: tuple[CompiledMessagePrimitiveOp, ...] + runtime_modules: tuple[MessageRuleRuntimeModule, ...] = () + runtime_parameters: tuple[MessageRuleRuntimeParameterSpec, ...] = () + static_tensors: tuple[MessageRuleStaticTensorSpec, ...] = () + parameter_reducer_kind: str = "" + param_grad_outputs: tuple[MessageRuleParamGradOutputSpec, ...] = () + native_executors: tuple[MessageRuleNativeExecutorSpec, ...] 
= () + parameter_reducer: MessageRuleParameterReducerSpec | None = None + + @property + def primitive_names(self) -> tuple[str, ...]: + return tuple(dict.fromkeys(op.primitive for op in self.primitive_ops)) + + +MessageRuleBackendSpecBuilder = Callable[..., MessageRuleBackendSpec] + + +_MESSAGE_RULE_BACKEND_SPEC_BUILDERS: dict[str, MessageRuleBackendSpecBuilder] = {} + + +def register_message_rule_backend_spec_builder( + rule_type: str, + builder: MessageRuleBackendSpecBuilder, +) -> None: + _MESSAGE_RULE_BACKEND_SPEC_BUILDERS[str(rule_type)] = builder + + +def _ensure_builtin_message_rule_backend_specs_registered() -> None: + if "dot_product" in _MESSAGE_RULE_BACKEND_SPEC_BUILDERS: + return + import cortical.fabric.backend.message_rule_specs # noqa: F401 + + +def build_message_rule_backend_spec( + *, + rule_type: str, + kv_group_count: int, + cell_count: int, +) -> MessageRuleBackendSpec: + _ensure_builtin_message_rule_backend_specs_registered() + try: + builder = _MESSAGE_RULE_BACKEND_SPEC_BUILDERS[str(rule_type)] + except KeyError as exc: + raise ValueError(f"Unsupported Fabric message rule backend {rule_type}") from exc + return builder(kv_group_count=int(kv_group_count), cell_count=int(cell_count)) + + +def build_message_rule_ir( + *, + rule_type: str, + kv_group_count: int, + cell_count: int, + name: str | None = None, +) -> MessageRuleIR: + return build_message_rule_backend_spec( + rule_type=rule_type, + kv_group_count=int(kv_group_count), + cell_count=int(cell_count), + ).to_ir(name=name) + + +def registered_message_rule_backend_spec_types() -> tuple[str, ...]: + _ensure_builtin_message_rule_backend_specs_registered() + return tuple(sorted(_MESSAGE_RULE_BACKEND_SPEC_BUILDERS)) + + +def ordered_message_rule_backend_spec_types() -> tuple[str, ...]: + rule_types = registered_message_rule_backend_spec_types() + return tuple( + rule_type for rule_type in _BUILTIN_MESSAGE_RULE_LOWERING_CATALOG_ORDER if rule_type in rule_types + ) + tuple(rule_type for rule_type in rule_types if rule_type not in _BUILTIN_MESSAGE_RULE_LOWERING_CATALOG_ORDER) + + +_BUILTIN_MESSAGE_RULE_LOWERING_CATALOG_ORDER = ( + "dot_product", + "dot_product_fixed_slot_context_nudge", + "dot_product_fixed_slot_context_gate", +) + +_CPP_MESSAGE_SOURCE_KIND = { + "receiver_slot": "MessageSourceKind::ReceiverSlot", + "receiver_public_prev": "MessageSourceKind::ReceiverPublicPrev", + "sender_slot": "MessageSourceKind::SenderSlot", + "sender_public_prev": "MessageSourceKind::SenderPublicPrev", + "edge_distance": "MessageSourceKind::EdgeDistance", +} + +_CPP_MESSAGE_PARAMETER_ROLE = { + "projection": "MessageParameterRole::Projection", + "bias": "MessageParameterRole::EdgeBias", + "rule_scalar": "MessageParameterRole::RuleScalar", + "rule_table": "MessageParameterRole::RuleTable", +} + +_CPP_MESSAGE_SHARING_SCOPE = { + "fabric_global": "MessageSharingScope::FabricGlobal", + "rule_global": "MessageSharingScope::RuleGlobal", + "sender_group_shared": "MessageSharingScope::SenderGroupShared", + "sender_local": "MessageSharingScope::SenderLocal", +} + +_CPP_MESSAGE_OP_KIND = { + "source": "MessageOpKind::Source", + "parameter": "MessageOpKind::Parameter", + "linear": "MessageOpKind::Linear", + "bias": "MessageOpKind::Bias", + "add": "MessageOpKind::Add", + "mul": "MessageOpKind::Mul", + "concat": "MessageOpKind::Concat", + "dot": "MessageOpKind::Dot", + "normalize": "MessageOpKind::Normalize", + "segment_softmax": "MessageOpKind::SegmentSoftmax", + "segment_weighted_sum": "MessageOpKind::SegmentWeightedSum", + 
"emit_projected_message": "MessageOpKind::EmitProjectedMessage", +} + +_CPP_MESSAGE_RESET_POLICY = { + "none": "ResetPolicy::None", + "zero_source_rows": "ResetPolicy::ZeroSourceRows", +} + +_CPP_MESSAGE_RESET_SCOPE = { + "none": "ResetScope::None", + "batch_row": "ResetScope::BatchRow", +} + +_BINARY_MESSAGE_OPS = frozenset(("add", "mul", "concat", "dot", "segment_weighted_sum")) +_UNARY_MESSAGE_OPS = frozenset(("normalize", "segment_softmax")) +_PARAMETER_BINDING_MESSAGE_OPS = frozenset(("bias", "linear")) + + +def _cpp_string_literal(value: str) -> str: + return '"' + str(value).replace("\\", "\\\\").replace('"', '\\"') + '"' + + +def _cpp_pascal_identifier(value: str) -> str: + return "".join(part[:1].upper() + part[1:] for part in str(value).split("_") if part) + + +def _cpp_catalog_symbol(rule_type: str, suffix: str) -> str: + return "k" + _cpp_pascal_identifier(rule_type) + suffix + + +def _cpp_enum(mapping: dict[str, str], value: str, *, label: str) -> str: + try: + return mapping[str(value)] + except KeyError as exc: + raise ValueError(f"Unsupported message-rule {label} for C++ lowering catalog: {value!r}") from exc + + +def _ordered_message_rule_lowering_catalog_specs() -> tuple[MessageRuleBackendSpec, ...]: + return tuple( + build_message_rule_backend_spec( + rule_type=rule_type, + kv_group_count=1, + cell_count=2, + ) + for rule_type in ordered_message_rule_backend_spec_types() + ) + + +def _message_rule_lowering_id_by_kind( + specs: Sequence[MessageRuleBackendSpec], +) -> dict[str, int]: + lowering_id_by_kind: dict[str, int] = {} + for spec in specs: + lowering_kind = str(spec.lowering_kind) + if lowering_kind not in lowering_id_by_kind: + lowering_id_by_kind[lowering_kind] = len(lowering_id_by_kind) + return lowering_id_by_kind + + +def _message_rule_catalog_requires_parameter_names(spec: MessageRuleBackendSpec) -> bool: + return spec.rule_type != "dot_product" + + +def _message_rule_lowering_catalog_nodes( + spec: MessageRuleBackendSpec, +) -> tuple[tuple[str, int, int, int, int, tuple[int, ...]], ...]: + source_index_by_name = {source.name: index for index, source in enumerate(spec.sources)} + parameter_index_by_name = {parameter.name: index for index, parameter in enumerate(spec.parameters)} + node_index_by_output: dict[str, int] = {} + node_rows: list[tuple[str, int, int, int, int, tuple[int, ...]]] = [] + + def add_node( + kind: str, + lhs: int = -1, + rhs: int = -1, + parameter_index: int = -1, + source_index: int = -1, + parameter_indices: tuple[int, ...] 
= (), + ) -> int: + if parameter_indices and parameter_index < 0: + parameter_index = int(parameter_indices[0]) + node_rows.append( + ( + kind, + int(lhs), + int(rhs), + int(parameter_index), + int(source_index), + tuple(int(index) for index in parameter_indices), + ) + ) + return len(node_rows) - 1 + + def ensure_input_node(input_name: str) -> int: + name = str(input_name) + if name in node_index_by_output: + return node_index_by_output[name] + parameter_index = parameter_index_by_name.get(name) + if parameter_index is not None: + node_index = add_node( + "parameter", + parameter_index=parameter_index, + parameter_indices=(parameter_index,), + ) + node_index_by_output[name] = node_index + return node_index + raise ValueError(f"Message rule {spec.rule_type!r} references undeclared C++ catalog input {input_name!r}") + + for node in spec.nodes: + if node.op == "source": + source_index = source_index_by_name.get(node.output) + if source_index is None: + raise ValueError( + f"Message rule {spec.rule_type!r} source node {node.name!r} does not match a declared source" + ) + node_index = add_node("source", source_index=source_index) + elif node.op == "parameter": + if not node.inputs: + raise ValueError(f"Message rule {spec.rule_type!r} parameter node {node.name!r} has no input") + parameter_name = str(node.inputs[0]) + parameter_index = parameter_index_by_name.get(parameter_name) + if parameter_index is None: + raise ValueError( + f"Message rule {spec.rule_type!r} parameter node {node.name!r} " + f"references undeclared parameter {parameter_name!r}" + ) + node_index = add_node( + "parameter", + parameter_index=parameter_index, + parameter_indices=(parameter_index,), + ) + elif node.op in _PARAMETER_BINDING_MESSAGE_OPS: + if len(node.inputs) < 2: + raise ValueError( + f"Message rule {spec.rule_type!r} {node.op} node {node.name!r} " + "must have one tensor input and at least one parameter input" + ) + lhs = ensure_input_node(str(node.inputs[0])) + parameter_inputs = tuple(str(input_name) for input_name in node.inputs[1:]) + parameter_indices = tuple( + parameter_index_by_name[input_name] + for input_name in parameter_inputs + if input_name in parameter_index_by_name + ) + if len(parameter_indices) != len(parameter_inputs): + missing = tuple( + input_name for input_name in parameter_inputs if input_name not in parameter_index_by_name + ) + raise ValueError( + f"Message rule {spec.rule_type!r} {node.op} node {node.name!r} " + f"references undeclared parameters {missing!r}" + ) + node_index = add_node( + node.op, + lhs=lhs, + parameter_indices=parameter_indices, + ) + elif node.op in _UNARY_MESSAGE_OPS: + if len(node.inputs) != 1: + raise ValueError(f"Message rule {spec.rule_type!r} {node.op} node {node.name!r} must be unary") + node_index = add_node(node.op, lhs=ensure_input_node(str(node.inputs[0]))) + elif node.op in _BINARY_MESSAGE_OPS: + if len(node.inputs) != 2: + raise ValueError(f"Message rule {spec.rule_type!r} {node.op} node {node.name!r} must be binary") + node_index = add_node( + node.op, + lhs=ensure_input_node(str(node.inputs[0])), + rhs=ensure_input_node(str(node.inputs[1])), + ) + else: + raise ValueError(f"Message rule {spec.rule_type!r} uses unsupported C++ catalog op {node.op!r}") + node_index_by_output[str(node.output)] = node_index + + if spec.output_boundary not in node_index_by_output: + raise ValueError(f"Message rule {spec.rule_type!r} does not produce output boundary {spec.output_boundary!r}") + add_node("emit_projected_message", 
lhs=node_index_by_output[spec.output_boundary]) + return tuple(node_rows) + + +def message_rule_lowering_catalog_header_text() -> str: + specs = _ordered_message_rule_lowering_catalog_specs() + lowering_id_by_kind = _message_rule_lowering_id_by_kind(specs) + lines: list[str] = [ + "// Generated by cortical.fabric.backend.message_rules.message_rule_lowering_catalog_header_text.", + "// Do not edit by hand; update MessageRuleBackendSpec registrations instead.", + "", + "#pragma once", + "", + ] + pattern_rows: list[tuple[MessageRuleBackendSpec, str, str, str, int]] = [] + for spec in specs: + lowering_id_symbol = _cpp_catalog_symbol(spec.rule_type, "LoweringId") + sources_symbol = _cpp_catalog_symbol(spec.rule_type, "Sources") + parameters_symbol = _cpp_catalog_symbol(spec.rule_type, "Parameters") + nodes_symbol = _cpp_catalog_symbol(spec.rule_type, "Nodes") + lines.append(f"static constexpr int {lowering_id_symbol} = {lowering_id_by_kind[str(spec.lowering_kind)]};") + lines.append("") + lines.append(f"static const MessageRuleSourcePattern {sources_symbol}[] = {{") + for source in spec.sources: + lines.append( + " {" + + ", ".join( + ( + _cpp_enum(_CPP_MESSAGE_SOURCE_KIND, source.kind, label="source kind"), + _cpp_enum(_CPP_MESSAGE_RESET_POLICY, source.reset_policy, label="source reset policy"), + _cpp_enum(_CPP_MESSAGE_RESET_SCOPE, source.reset_scope, label="source reset scope"), + ) + ) + + "}," + ) + lines.extend(("};", "")) + + require_parameter_names = _message_rule_catalog_requires_parameter_names(spec) + lines.append(f"static const MessageRuleParameterPattern {parameters_symbol}[] = {{") + for parameter in spec.parameters: + parameter_name = _cpp_string_literal(parameter.name) if require_parameter_names else "nullptr" + lines.append( + " {" + + ", ".join( + ( + parameter_name, + _cpp_enum(_CPP_MESSAGE_PARAMETER_ROLE, parameter.role, label="parameter role"), + _cpp_enum( + _CPP_MESSAGE_SHARING_SCOPE, + parameter.sharing_scope, + label="parameter sharing scope", + ), + "true" if require_parameter_names else "false", + ) + ) + + "}," + ) + lines.extend(("};", "")) + + node_rows = _message_rule_lowering_catalog_nodes(spec) + for node_index, (_kind, _lhs, _rhs, _parameter_index, _source_index, parameter_indices) in enumerate(node_rows): + if not parameter_indices: + continue + indices_symbol = _cpp_catalog_symbol(spec.rule_type, f"Node{node_index}ParameterIndices") + values = ", ".join(str(index) for index in parameter_indices) + lines.append(f"static const int {indices_symbol}[] = {{{values}}};") + if any(parameter_indices for *_prefix, parameter_indices in node_rows): + lines.append("") + + lines.append(f"static const MessageRuleNodePattern {nodes_symbol}[] = {{") + for node_index, (kind, lhs, rhs, parameter_index, source_index, parameter_indices) in enumerate(node_rows): + indices_symbol = ( + _cpp_catalog_symbol(spec.rule_type, f"Node{node_index}ParameterIndices") + if parameter_indices + else "nullptr" + ) + parameter_count = len(parameter_indices) + lines.append( + " {" + + ", ".join( + ( + _cpp_enum(_CPP_MESSAGE_OP_KIND, kind, label="op kind"), + str(lhs), + str(rhs), + str(parameter_index), + str(source_index), + "MessageSegmentKind::ReceiverNeighborhood", + indices_symbol, + str(parameter_count), + ) + ) + + "}," + ) + lines.extend(("};", "")) + pattern_rows.append((spec, sources_symbol, parameters_symbol, nodes_symbol, len(node_rows) - 1)) + + lines.extend( + ( + "inline const MessageRuleLoweringPattern* registered_message_rule_lowering_patterns_begin() {", + " static const 
MessageRuleLoweringPattern kRegisteredMessageRuleLoweringPatterns[] = {", + ) + ) + for spec, sources_symbol, parameters_symbol, nodes_symbol, projected_message_node in pattern_rows: + lowering_id_symbol = _cpp_catalog_symbol(spec.rule_type, "LoweringId") + lines.extend( + ( + " {", + f" {lowering_id_symbol},", + f" {_cpp_string_literal(spec.rule_type)},", + f" {sources_symbol},", + f" static_cast(sizeof({sources_symbol}) / sizeof({sources_symbol}[0])),", + f" {parameters_symbol},", + f" static_cast(sizeof({parameters_symbol}) / sizeof({parameters_symbol}[0])),", + f" {nodes_symbol},", + f" static_cast(sizeof({nodes_symbol}) / sizeof({nodes_symbol}[0])),", + f" {projected_message_node},", + " },", + ) + ) + lines.extend( + ( + " };", + " return kRegisteredMessageRuleLoweringPatterns;", + "}", + "", + "inline const MessageRuleLoweringPattern* registered_message_rule_lowering_patterns_end() {", + f" return registered_message_rule_lowering_patterns_begin() + {len(pattern_rows)};", + "}", + "", + ) + ) + return "\n".join(lines) + + +def validate_message_rule_lowering_catalog_header(header_text: str) -> None: + expected = message_rule_lowering_catalog_header_text() + actual = str(header_text) + if actual != expected: + raise AssertionError( + "message_rule_lowering_catalog.cuh is out of sync with registered MessageRuleBackendSpec rows; " + "regenerate it from message_rule_lowering_catalog_header_text()" + ) + + def classify_message_rule(rule: MessageRuleIR) -> str: - source_kinds = {source.kind for source in rule.sources} - op_kinds = {node.op for node in rule.nodes} - parameter_roles = {parameter.role for parameter in rule.parameters} - has_dot_product_attention = {"receiver_slot", "sender_public_prev", "edge_distance"}.issubset(source_kinds) and { - "linear", - "dot", - "segment_softmax", - "segment_weighted_sum", - }.issubset(op_kinds) - if ( - has_dot_product_attention - and "projection" in parameter_roles - and rule.output_boundary == PROJECTED_MESSAGE_BOUNDARY - ): - return DOT_PRODUCT_SEGMENT_SOFTMAX_WEIGHTED_SUM + if rule.lowering_kind_id: + return rule.lowering_kind_id return "unsupported" -def default_dot_product_message_rule_ir(*, kv_group_count: int, cell_count: int) -> MessageRuleIR: - sharing_mode: MessageSharingScope = ( - "sender_group_shared" if int(kv_group_count) < int(cell_count) else "sender_local" +def _primitive_by_message_op(rule: MessageRuleIR) -> dict[str, str]: + return {str(binding.source_op): str(binding.primitive) for binding in rule.primitive_bindings} + + +def compile_message_rule(rule: MessageRuleIR) -> CompiledMessageRule: + lowering_kind = rule.lowering_kind + if lowering_kind == "unsupported": + raise ValueError( + f"Unsupported Fabric message rule {rule.name!r}; message rules must lower into supported primitive op rows" + ) + if rule.output_boundary != PROJECTED_MESSAGE_BOUNDARY: + raise ValueError( + f"Unsupported Fabric message rule output_boundary={rule.output_boundary!r}; " + f"expected {PROJECTED_MESSAGE_BOUNDARY!r}" + ) + parameter_names = frozenset(str(parameter.name) for parameter in rule.parameters) + primitive_by_op = _primitive_by_message_op(rule) + primitive_ops: list[CompiledMessagePrimitiveOp] = [] + for node_index, node in enumerate(rule.nodes): + if node.op in {"source", "parameter"}: + continue + primitive = primitive_by_op.get(node.op) + if primitive is None: + raise ValueError( + f"Unsupported Fabric message rule op {node.op!r} in {rule.name!r}; " + "add the op to fabric.cuda.nn lowering before using it" + ) + if not 
is_callable_cuda_nn_primitive(primitive): + raise ValueError( + f"Unsupported Fabric message primitive {primitive!r} in {rule.name!r}; " + "add the primitive to fabric.cuda.nn lowering before using it" + ) + primitive_ops.append( + CompiledMessagePrimitiveOp( + node_index=int(node_index), + node_name=str(node.name), + primitive=str(primitive), + source_op=str(node.op), + inputs=tuple(str(item) for item in node.inputs), + outputs=(str(node.output),), + attributes=( + ("node_name", str(node.name)), + ("source_op", str(node.op)), + ("output_boundary", str(rule.output_boundary)), + ("lowering_kind", str(lowering_kind)), + ("output_dim_role", str(rule.output_dim_role)), + ), + parameter_bindings=tuple(str(item) for item in node.inputs if str(item) in parameter_names), + ) + ) + if not primitive_ops: + raise ValueError(f"Unsupported Fabric message rule {rule.name!r}; no executable primitive rows were lowered") + return CompiledMessageRule( + rule_name=str(rule.name), + lowering_kind=str(lowering_kind), + output_boundary=str(rule.output_boundary), + output_dim_role=str(rule.output_dim_role), + primitive_ops=tuple(primitive_ops), + runtime_modules=tuple(rule.runtime_modules), + runtime_parameters=tuple(rule.runtime_parameters), + static_tensors=tuple(rule.static_tensors), + parameter_reducer_kind=str(rule.parameter_reducer_kind), + param_grad_outputs=tuple(rule.param_grad_outputs), + native_executors=tuple(rule.native_executors), + parameter_reducer=rule.parameter_reducer, ) - return MessageRuleIR( - name="dot_product", - sources=( - MessageRuleSource("receiver_slot", "receiver_slot"), - MessageRuleSource("sender_public_prev", "sender_public_prev", "zero_source_rows", "batch_row"), - MessageRuleSource("edge_distance", "edge_distance"), - ), - parameters=( - MessageRuleParameter("q_weight", "projection", "rule_global"), - MessageRuleParameter("k_weight", "projection", sharing_mode, int(kv_group_count)), - MessageRuleParameter("v_weight", "projection", sharing_mode, int(kv_group_count)), - MessageRuleParameter("out_weight", "projection", "rule_global"), - ), - nodes=_dot_product_nodes(), - output_boundary=PROJECTED_MESSAGE_BOUNDARY, + + +def default_dot_product_message_rule_ir(*, kv_group_count: int, cell_count: int) -> MessageRuleIR: + return build_message_rule_ir( + rule_type="dot_product_fixed_slot_context_nudge", + kv_group_count=kv_group_count, + cell_count=cell_count, ) @@ -147,40 +822,47 @@ def default_dot_product_message_rule_summary(*, kv_group_count: int, cell_count: return default_dot_product_message_rule_ir(kv_group_count=kv_group_count, cell_count=cell_count).summary() -def _dot_product_nodes() -> tuple[MessageRuleNode, ...]: - return ( - MessageRuleNode("receiver_source", "source", (), "receiver_slot"), - MessageRuleNode("sender_source", "source", (), "sender_public_prev"), - MessageRuleNode("edge_distance_source", "source", (), "edge_distance"), - MessageRuleNode("q_projection", "linear", ("receiver_slot", "q_weight"), "q"), - MessageRuleNode("k_projection", "linear", ("sender_public_prev", "k_weight"), "k"), - MessageRuleNode("v_projection", "linear", ("sender_public_prev", "v_weight"), "v"), - MessageRuleNode("logits", "dot", ("q", "k"), "logits"), - MessageRuleNode("biased_logits", "add", ("logits", "edge_distance"), "biased_logits"), - MessageRuleNode("weights", "segment_softmax", ("biased_logits",), "weights"), - MessageRuleNode("weighted_value", "segment_weighted_sum", ("weights", "v"), "weighted_value"), - MessageRuleNode("projected_message", "linear", ("weighted_value", 
"out_weight"), PROJECTED_MESSAGE_BOUNDARY), - ) - - def join_signature(parts: Sequence[str]) -> str: return ";".join(parts) __all__ = [ + "CompiledMessagePrimitiveOp", + "CompiledMessageRule", "DOT_PRODUCT_SEGMENT_SOFTMAX_WEIGHTED_SUM", + "DOT_PRODUCT_FIXED_SLOT_CONTEXT_GATE", + "DOT_PRODUCT_FIXED_SLOT_CONTEXT_NUDGE", + "MessageOpPrimitiveBinding", "PROJECTED_MESSAGE_BOUNDARY", "MessageOpKind", + "MessageNativeExecutorDirection", + "MessageNativeExecutorPhase", + "MessageRuleNativeExecutorEntrypointSpec", "MessageParameterRole", + "MessageRuleBackendSpec", "MessageRuleIR", + "MessageRuleNativeExecutorSpec", "MessageRuleNode", + "MessageRuleParamGradOutputSpec", "MessageRuleParameter", + "MessageRuleParameterReducerSpec", + "MessageRuleRuntimeModule", + "MessageRuleRuntimeParameterSpec", "MessageRuleSource", + "MessageRuleStaticTensorSpec", "MessageRuleSummary", "MessageSharingScope", "MessageSourceKind", + "build_message_rule_backend_spec", + "build_message_rule_ir", "classify_message_rule", + "compile_message_rule", "default_dot_product_message_rule_ir", "default_dot_product_message_rule_summary", "join_signature", + "message_rule_lowering_catalog_header_text", + "ordered_message_rule_backend_spec_types", + "register_message_rule_backend_spec_builder", + "registered_message_rule_backend_spec_types", + "validate_message_rule_lowering_catalog_header", ] diff --git a/src/cortical/fabric/backend/planner.py b/src/cortical/fabric/backend/planner.py index 7385c720..d49c9e3b 100644 --- a/src/cortical/fabric/backend/planner.py +++ b/src/cortical/fabric/backend/planner.py @@ -2,7 +2,7 @@ import math from collections.abc import Mapping -from dataclasses import dataclass +from dataclasses import dataclass, replace from typing import Literal from cortical.fabric.backend.buckets import FabricBucket, ReceiverKind @@ -10,8 +10,25 @@ from cortical.fabric.backend.cell_backend import CellBackendSpec, SurfaceBackendVariant from cortical.fabric.backend.ir import FabricIR from cortical.fabric.backend.plan_cache import FabricPlanCache, PlanCacheKey +from cortical.fabric.backend.primitives import CALLABLE_CUDA_NN_PRIMITIVES from cortical.fabric.backend.reuse import ExecutionFamily, MathBackend from cortical.fabric.backend.tape import TapePolicy +from cortical.fabric.backend.temporal_plan import ( + SequenceSurfaceRoute, + TemporalBackwardWindowPlan, + TemporalBoundaryPlan, + TemporalCarryPlan, + TemporalCheckpointPlan, + TemporalEnginePlan, + TemporalExecutionPlan, + TemporalExecutorPlan, + TemporalGradientBoundaryPlan, + TemporalMaterializationPlan, + TemporalOutputRequestPlan, + TemporalSchedulePlan, + TemporalStaticValuePlan, + TemporalSubstratePlan, +) from cortical.fabric.backend.workspace import WorkspacePlan, WorkspacePlanner @@ -99,31 +116,6 @@ class PlannedFabricExecution: tape_policy_bin: str -SequenceSurfaceRouteKind = Literal["none", "sequence_surface"] -SequenceSurfaceExecutorKind = Literal["none", "temporal_bucket_sequence"] -SequenceSurfaceImplementationKind = Literal["none", "cell_recurrence_surface", "flat_transition_buckets"] - - -@dataclass(frozen=True) -class SequenceSurfaceRoute: - kind: SequenceSurfaceRouteKind - executor: SequenceSurfaceExecutorKind - supported: bool - reason: str - active_populations: tuple[str, ...] 
- surface_key: str | None = None - implementation_executor: SequenceSurfaceImplementationKind = "none" - bucket_count: int = 0 - - @property - def uses_cell_recurrence_surface(self) -> bool: - return self.implementation_executor == "cell_recurrence_surface" - - @property - def uses_flat_transition_buckets(self) -> bool: - return self.implementation_executor == "flat_transition_buckets" - - @dataclass(frozen=True) class GradBoundarySpec: name: str @@ -221,8 +213,7 @@ def saved_launch_counts(self) -> tuple[str, ...]: @property def residual_glue_demotions(self) -> tuple[str, ...]: - residual = tuple(op.residual_glue_demotion for op in self.ops if op.residual_glue_demotion) - return residual or ("none",) + return tuple(op.residual_glue_demotion for op in self.ops if op.residual_glue_demotion) @property def family_behaviors(self) -> tuple[BackwardFamilyBehavior, ...]: @@ -248,6 +239,7 @@ class PlannedFabricBackwardExecution: training_mode: bool device_caps_key: tuple[object, ...] tape_policy_bin: str + temporal_plan: TemporalExecutionPlan | None = None class FabricExecutionPlanner: @@ -290,7 +282,7 @@ def plan_execution( caps_key = _caps_key(device_caps) for bucket in self.ir.buckets: key = PlanCacheKey( - bucket_signature=bucket.signature, + bucket_signature=bucket.planner_signature, shape_bin=shape_bin, dtype="float32", training_mode=training, @@ -300,7 +292,7 @@ def plan_execution( ) cached = self.plan_cache.get(key) if isinstance(cached, FabricExecutionPlan): - plans.append(cached) + plans.append(replace(cached, bucket_id=bucket.bucket_id)) continue plan = self._heuristic_plan_bucket( bucket, @@ -359,13 +351,212 @@ def plan_sequence_surface_route( "sequence_surface", "temporal_bucket_sequence", True, - "planner_selected_sequence_surface_flat_transition_bucket_executor", + "planner_selected_sequence_surface_registered_temporal_program", active_populations, - surface_key="flat_bucket_sequence_surface", - implementation_executor="flat_transition_buckets", + surface_key="registered_temporal_sequence_surface", + implementation_executor="registered_temporal_program", bucket_count=len(active_populations), ) + def plan_temporal_execution( + self, + *, + device_type: str, + dtype: str, + partitioned_layout: bool, + configured_backend: str, + constant_k: int | None, + time_steps: int, + training: bool, + input_boundary: str, + output_boundary: str = "sequence", + readout_output_boundary: str = "cells", + output_contract: str = "full_cells", + materialize_final_state: bool = True, + state_is_fresh: bool = True, + has_resets: bool = False, + gradient_horizon_steps: int | None = None, + checkpoint_steps: int | None = None, + ) -> TemporalExecutionPlan: + route = self.plan_sequence_surface_route( + device_type=device_type, + dtype=dtype, + partitioned_layout=partitioned_layout, + constant_k=constant_k, + ) + outer_time_steps = max(1, int(time_steps)) + inner_steps = None if constant_k is None else max(0, int(constant_k)) + total_scan_steps = None if inner_steps is None else outer_time_steps * inner_steps + schedule = TemporalSchedulePlan( + schedule_kind="runtime_variable_k" if inner_steps is None else "scalar_constant_k", + outer_time_steps=outer_time_steps, + inner_steps=inner_steps, + total_scan_steps=total_scan_steps, + per_timestep_k_semantic="represented_not_lowered" if inner_steps is None else "scalar_schedule", + ) + if route.uses_registered_temporal_program: + substrate_kind = "flat_bucket_temporal_substrate" + else: + substrate_kind = "unsupported" + population_cardinality = ( + "none" if not 
route.active_populations else "single" if len(route.active_populations) == 1 else "multi" + ) + substrate = TemporalSubstratePlan( + substrate_kind=substrate_kind, + active_populations=route.active_populations, + bucket_count=self.ir.bucket_count, + population_cardinality=population_cardinality, + partitioned_layout=bool(partitioned_layout), + bucket_identity=_temporal_bucket_identity(substrate_kind), + ) + boundary = TemporalBoundaryPlan( + input_boundary=str(input_boundary), + readout_output_boundary="pooled" if readout_output_boundary == "pooled" else "cells", + output_contract=str(output_contract), + resets="present" if has_resets else "absent", + ) + backend_name = _select_temporal_backend_name( + configured_backend=configured_backend, + device_type=device_type, + supports_cuda_backend=route.supported, + ) + fresh_state_population_cache, fresh_state_population_cache_reason = _plan_fresh_state_population_cache( + backend_name=backend_name, + route=route, + state_is_fresh=state_is_fresh, + training=training, + materialize_final_state=materialize_final_state, + inner_steps=inner_steps, + ) + carry = TemporalCarryPlan( + initial_state="fresh" if state_is_fresh else "provided", + materialize_final_state=bool(materialize_final_state), + carry_policy="planner_recorded_materialization", + fresh_state_population_cache=fresh_state_population_cache, + fresh_state_population_cache_reason=fresh_state_population_cache_reason, + ) + output_request = _plan_temporal_output_request( + output_boundary=str(output_boundary), + output_contract=boundary.output_contract, + readout_output_boundary=boundary.readout_output_boundary, + outer_time_steps=outer_time_steps, + inner_steps=inner_steps, + training=training, + materialize_final_state=bool(materialize_final_state), + ) + if not training: + gradient_boundary = TemporalGradientBoundaryPlan(mode="inference", horizon_steps=None, owner="planner") + backward_window = TemporalBackwardWindowPlan(window_kind="none", max_window_steps=None, owner="planner") + planned_checkpoint_steps = None + elif gradient_horizon_steps is None: + gradient_boundary = TemporalGradientBoundaryPlan(mode="full_horizon", horizon_steps=None, owner="planner") + backward_window = TemporalBackwardWindowPlan( + window_kind="full_horizon", + max_window_steps=total_scan_steps, + owner="planner", + ) + planned_checkpoint_steps = _default_temporal_checkpoint_steps( + total_scan_steps=total_scan_steps, + horizon_steps=None, + output_request=output_request, + ) + else: + requested_horizon_steps = max(1, int(gradient_horizon_steps)) + horizon_steps = _effective_temporal_horizon_steps( + total_scan_steps=total_scan_steps, + requested_horizon_steps=requested_horizon_steps, + ) + horizon_covers_full_stream = total_scan_steps is not None and int(total_scan_steps) <= int(horizon_steps) + gradient_boundary = TemporalGradientBoundaryPlan( + mode="full_horizon" if horizon_covers_full_stream else "rolling_horizon", + horizon_steps=horizon_steps, + owner="planner", + ) + backward_window = TemporalBackwardWindowPlan( + window_kind="full_horizon" if horizon_covers_full_stream else "rolling_horizon", + max_window_steps=horizon_steps, + owner="planner", + ) + planned_checkpoint_steps = _default_temporal_checkpoint_steps( + total_scan_steps=total_scan_steps, + horizon_steps=horizon_steps, + output_request=output_request, + ) + explicit_checkpoint_steps = None if checkpoint_steps is None else max(1, int(checkpoint_steps)) + checkpoint = TemporalCheckpointPlan( + checkpoint_kind="none" + if not training + else 
"explicit" + if explicit_checkpoint_steps is not None + else "planner_default", + checkpoint_steps=explicit_checkpoint_steps + if explicit_checkpoint_steps is not None + else planned_checkpoint_steps, + owner="planner", + ) + materialization = _plan_temporal_materialization( + training=training, + route=route, + output_request=output_request, + output_contract=boundary.output_contract, + materialize_final_state=bool(materialize_final_state), + has_resets=has_resets, + total_scan_steps=total_scan_steps, + backward_window=backward_window, + checkpoint=checkpoint, + ) + native_static_materialization = bool( + backend_name == "cuda" and route.supported and route.uses_registered_temporal_program + ) + autograd_static_values = bool( + backend_name == "cuda" + and route.implementation_executor == "registered_temporal_program" + and not native_static_materialization + ) + if not training: + static_value_mode = "inference_cache" + elif configured_backend == "pytorch": + static_value_mode = "pytorch_autograd_static_values" + elif autograd_static_values: + static_value_mode = "flat_bucket_autograd_static_values" + else: + static_value_mode = "detached_shared_values" + static_values = TemporalStaticValuePlan( + static_value_mode=static_value_mode, + native_static_materialization=native_static_materialization, + include_full_cell_kv_weight=not native_static_materialization, + detach_training_static_tensors=not autograd_static_values, + owner="planner", + ) + executor = TemporalExecutorPlan( + backend_name=backend_name, + executor=route.executor, + selected_implementation=route.implementation_executor, + surface_key=route.surface_key, + supported=route.supported, + reason=route.reason, + ) + engine = _plan_temporal_engine_owner( + backend_name=backend_name, + route=route, + training=training, + ) + return TemporalExecutionPlan( + sequence_surface_route=route, + schedule=schedule, + substrate=substrate, + boundary=boundary, + carry=carry, + output_request=output_request, + checkpoint=checkpoint, + materialization=materialization, + gradient_boundary=gradient_boundary, + backward_window=backward_window, + static_values=static_values, + executor=executor, + engine=engine, + ) + def plan_backward_execution( self, *, @@ -376,6 +567,7 @@ def plan_backward_execution( device_caps: DeviceCaps, tape_policy: TapePolicy | None, supported_variants: tuple[SurfaceBackendVariant, ...] 
| None = None, + temporal_plan: TemporalExecutionPlan | None = None, ) -> PlannedFabricBackwardExecution: shape_bin = _shape_bin( batch_size=batch_size, @@ -392,7 +584,7 @@ def plan_backward_execution( caps_key = _caps_key(device_caps) for bucket in self.ir.buckets: receiver_key = PlanCacheKey( - bucket_signature=("bwd_receiver", *bucket.signature), + bucket_signature=("bwd_receiver", *bucket.planner_signature), shape_bin=shape_bin, dtype="float32", training_mode=training, @@ -402,7 +594,7 @@ def plan_backward_execution( ) cached_receiver = self.plan_cache.get(receiver_key) if isinstance(cached_receiver, FabricExecutionPlan): - receiver_plans.append(cached_receiver) + receiver_plans.append(replace(cached_receiver, bucket_id=bucket.bucket_id)) else: receiver_plan = self._heuristic_plan_backward_bucket( bucket, @@ -417,7 +609,7 @@ def plan_backward_execution( receiver_plans.append(receiver_plan) sender_key = PlanCacheKey( - bucket_signature=("bwd_sender", *bucket.signature), + bucket_signature=("bwd_sender", *bucket.planner_signature), shape_bin=shape_bin, dtype="float32", training_mode=training, @@ -427,7 +619,7 @@ def plan_backward_execution( ) cached_sender = self.plan_cache.get(sender_key) if isinstance(cached_sender, FabricExecutionPlan): - sender_plans.append(cached_sender) + sender_plans.append(replace(cached_sender, bucket_id=bucket.bucket_id)) continue sender_plan = self._heuristic_plan_backward_bucket( bucket, @@ -458,6 +650,7 @@ def plan_backward_execution( training_mode=training, device_caps_key=caps_key, tape_policy_bin=tape_bin, + temporal_plan=temporal_plan, ) def _build_physical_backward_plan( @@ -518,11 +711,11 @@ def _build_physical_backward_plan( ) ) _validate_backward_family_registry(tuple(ops)) - workspace_aliases = ("thin_reverse:uses_forward_tape_and_autograd_workspaces",) + workspace_aliases = ("registered_reverse_program:uses_compiler_owned_workspaces",) return PhysicalBackwardPlan( ops=tuple(ops), workspace_aliases=workspace_aliases, - workspace_peak_bytes=("thin_reverse:profile_required",), + workspace_peak_bytes=("registered_reverse_program:profile_required",), tape_mode=tape_policy_bin, recompute_mode=_backward_recompute_mode_from_tape(tape_policy_bin), ) @@ -911,6 +1104,300 @@ def _active_recurrent_populations(ir: FabricIR) -> tuple[str, ...]: return tuple(names) +def _select_temporal_backend_name(*, configured_backend: str, device_type: str, supports_cuda_backend: bool) -> str: + if configured_backend == "pytorch": + return "pytorch" + if configured_backend == "cuda": + if device_type != "cuda": + raise RuntimeError("Fabric backend='cuda' requires CUDA tensors") + if not supports_cuda_backend: + raise RuntimeError("Fabric backend='cuda' requested an unsupported CUDA backend surface") + return "cuda" + if device_type == "cuda" and supports_cuda_backend: + return "cuda" + return "pytorch" + + +def _default_temporal_checkpoint_steps( + *, + total_scan_steps: int | None, + horizon_steps: int | None, + output_request: TemporalOutputRequestPlan, +) -> int: + if horizon_steps is not None: + if total_scan_steps is None: + return max(1, int(horizon_steps)) + return max(1, min(int(total_scan_steps), int(horizon_steps))) + if total_scan_steps is None: + return 32 + required_scan_steps = _output_required_physical_span_steps( + total_scan_steps=int(total_scan_steps), + output_request=output_request, + ) + return max(1, min(required_scan_steps, 32)) + + +def _plan_temporal_materialization( + *, + training: bool, + route: SequenceSurfaceRoute, + output_request: 
TemporalOutputRequestPlan, + output_contract: str, + materialize_final_state: bool, + has_resets: bool, + total_scan_steps: int | None, + backward_window: TemporalBackwardWindowPlan, + checkpoint: TemporalCheckpointPlan, +) -> TemporalMaterializationPlan: + if not training: + return TemporalMaterializationPlan( + reverse_artifact_kind="none", + checkpoint_steps=None, + recompute_window_steps=None, + output_materialization="none", + owner="planner", + reason="materialization=inference", + ) + output_materialization = "outputs_and_final_state" if materialize_final_state else "outputs_only" + checkpoint_steps = checkpoint.checkpoint_steps + recompute_window_steps = backward_window.max_window_steps + if not route.uses_registered_temporal_program: + return TemporalMaterializationPlan( + reverse_artifact_kind="checkpoint_recompute", + checkpoint_steps=checkpoint_steps, + recompute_window_steps=recompute_window_steps, + output_materialization=output_materialization, + owner="planner", + reason="materialization=checkpoint_recompute;reason=unsupported_flat_bucket_route", + ) + if output_request.autograd_seed_kind == "none": + return TemporalMaterializationPlan( + reverse_artifact_kind="none", + checkpoint_steps=checkpoint_steps, + recompute_window_steps=recompute_window_steps, + output_materialization=output_materialization, + owner="planner", + reason="materialization=none;reason=no_autograd_seed", + ) + required_surfaces = set(output_request.required_backward_surfaces) + if not {"message_primitive_adjoint", "transition_primitive_adjoint"}.issubset(required_surfaces): + return TemporalMaterializationPlan( + reverse_artifact_kind="checkpoint_recompute", + checkpoint_steps=checkpoint_steps, + recompute_window_steps=recompute_window_steps, + output_materialization=output_materialization, + owner="planner", + reason="materialization=checkpoint_recompute;reason=missing_required_reverse_surfaces", + ) + if output_contract not in {"output_cells", "pooled_output_cells"}: + return TemporalMaterializationPlan( + reverse_artifact_kind="checkpoint_recompute", + checkpoint_steps=checkpoint_steps, + recompute_window_steps=recompute_window_steps, + output_materialization=output_materialization, + owner="planner", + reason="materialization=checkpoint_recompute;reason=full_cell_output_contract", + ) + physical_steps = None if total_scan_steps is None else max(1, int(total_scan_steps)) + return TemporalMaterializationPlan( + reverse_artifact_kind="store_step_artifacts", + checkpoint_steps=checkpoint_steps, + recompute_window_steps=recompute_window_steps, + output_materialization=output_materialization, + owner="planner", + reason=( + "materialization=store_step_artifacts;" + "reason=compiler_owned_fused_forward_artifact_tensor_table;" + f"physical_steps={_optional_int_for_reason(physical_steps)};" + f"window_steps={_optional_int_for_reason(recompute_window_steps)};" + f"checkpoint_steps={_optional_int_for_reason(checkpoint_steps)};" + f"final_state={int(bool(materialize_final_state))};" + f"resets={int(bool(has_resets))}" + ), + ) + + +def _optional_int_for_reason(value: int | None) -> str: + return "none" if value is None else str(int(value)) + + +def _output_required_physical_span_steps( + *, + total_scan_steps: int, + output_request: TemporalOutputRequestPlan, +) -> int: + if output_request.materialize_final_state: + return max(1, int(total_scan_steps)) + if output_request.first_physical_step is None: + return max(1, int(total_scan_steps)) + if output_request.selector_kind == "terminal_outer_step": + return 
max(1, min(int(total_scan_steps), int(output_request.first_physical_step) + 1)) + if output_request.selector_kind == "explicit_outer_steps" and output_request.emitted_output_count == 0: + return 1 + return max(1, int(total_scan_steps)) + + +def _effective_temporal_horizon_steps(*, total_scan_steps: int | None, requested_horizon_steps: int) -> int: + if total_scan_steps is None: + return max(1, int(requested_horizon_steps)) + return max(1, min(int(total_scan_steps), int(requested_horizon_steps))) + + +def _plan_temporal_output_request( + *, + output_boundary: str, + output_contract: str, + readout_output_boundary: Literal["cells", "pooled"], + outer_time_steps: int, + inner_steps: int | None, + training: bool, + materialize_final_state: bool, +) -> TemporalOutputRequestPlan: + outer_steps = max(1, int(outer_time_steps)) + if output_boundary == "terminal": + selector_kind = "terminal_outer_step" + first_outer_step = outer_steps - 1 + outer_stride = None + emitted_output_count = 1 + else: + selector_kind = "all_outer_steps" + first_outer_step = 0 + outer_stride = 1 + emitted_output_count = outer_steps + + if inner_steps is None: + first_physical_step = None + physical_stride = None + else: + physical_stride_value = max(1, int(inner_steps)) + first_physical_step = int(first_outer_step) * physical_stride_value + physical_stride_value - 1 + physical_stride = physical_stride_value if emitted_output_count > 1 else None + + if not training: + autograd_seed_kind = "none" + required_backward_surfaces: tuple[str, ...] = () + checkpoint_policy_basis = "inference" + else: + autograd_seed_kind = ( + "emitted_output_grad_plus_final_state_grad" if materialize_final_state else "emitted_output_grad" + ) + required_backward_surfaces = _required_backward_surfaces_for_output_request( + output_contract=output_contract, + materialize_final_state=materialize_final_state, + ) + checkpoint_policy_basis = ( + "emitted_output_and_final_state_schedule" if materialize_final_state else "emitted_output_schedule" + ) + + return TemporalOutputRequestPlan( + selector_kind=selector_kind, + explicit_outer_steps=(), + first_outer_step=first_outer_step, + outer_stride=outer_stride, + emitted_output_count=emitted_output_count, + first_physical_step=first_physical_step, + physical_stride=physical_stride, + output_surface=str(output_contract), + readout_surface=readout_output_boundary, + materialize_final_state=bool(materialize_final_state), + materialization="outputs_and_final_state" if materialize_final_state else "outputs_only", + autograd_seed_kind=autograd_seed_kind, + required_backward_surfaces=required_backward_surfaces, + checkpoint_policy_basis=checkpoint_policy_basis, + ) + + +def _required_backward_surfaces_for_output_request( + *, + output_contract: str, + materialize_final_state: bool, +) -> tuple[str, ...]: + surfaces = ["emitted_output_adjoint"] + if output_contract == "pooled_output_cells": + surfaces.extend(("readout_pool_adjoint", "output_cell_adjoint")) + elif output_contract == "output_cells": + surfaces.append("output_cell_adjoint") + elif output_contract == "full_cells": + surfaces.append("full_cell_adjoint") + else: + surfaces.append("output_contract_adjoint") + surfaces.extend( + ( + "message_primitive_adjoint", + "transition_primitive_adjoint", + "input_boundary_adjoint", + "parameter_adjoint", + ) + ) + if materialize_final_state: + surfaces.append("final_state_adjoint") + return tuple(surfaces) + + +def _temporal_bucket_identity(substrate_kind: str) -> str: + if substrate_kind == 
"flat_bucket_temporal_substrate": + return "flat_bucket_identity" + return "unsupported" + + +def _plan_temporal_engine_owner( + *, + backend_name: str, + route: SequenceSurfaceRoute, + training: bool, +) -> TemporalEnginePlan: + if backend_name == "cuda" and route.uses_registered_temporal_program: + return TemporalEnginePlan( + forward_owner="registered_fused_forward_program_cuda", + backward_owner="registered_reverse_executor_bindings" if training else "none", + checkpoint_owner="planner", + target_owner="registered_temporal_executor_bindings", + status="registered_executor_bindings", + reason="flat_bucket_temporal_scan_owned_by_registered_fused_forward_program_cuda", + ) + if backend_name == "pytorch": + return TemporalEnginePlan( + forward_owner="pytorch_reference", + backward_owner="pytorch_reference" if training else "none", + checkpoint_owner="planner", + target_owner="registered_temporal_executor_bindings", + status="pytorch_reference", + reason="pytorch_reference_backend", + ) + return TemporalEnginePlan( + forward_owner="unsupported", + backward_owner="unsupported" if training else "none", + checkpoint_owner="planner", + target_owner="registered_temporal_executor_bindings", + status="unsupported", + reason=route.reason, + ) + + +def _plan_fresh_state_population_cache( + *, + backend_name: str, + route: SequenceSurfaceRoute, + state_is_fresh: bool, + training: bool, + materialize_final_state: bool, + inner_steps: int | None, +) -> tuple[bool, str]: + if backend_name != "cuda": + return False, "backend_not_cuda" + if not route.uses_registered_temporal_program: + return False, "requires_registered_temporal_program" + if not state_is_fresh: + return False, "provided_state" + if training: + return False, "training_requires_materialized_state" + if materialize_final_state: + return False, "final_state_materialized" + if inner_steps != 1: + return False, "requires_k1" + return True, "fresh_registered_inference_without_final_state" + + def _pow2_bin(value: int) -> int: value = max(1, int(value)) return 1 << (value - 1).bit_length() @@ -1023,13 +1510,13 @@ def _candidate_capability_variant( ), "public_projection": BackwardFamilyBehavior( family="public_projection", - behavior="thin_reverse", + behavior="explicit_backward", owner="physical_backward_plan", demotion_policy="unsupported_public_projection_demotes", ), "readout": BackwardFamilyBehavior( family="readout", - behavior="thin_reverse", + behavior="explicit_backward", owner="physical_backward_plan", demotion_policy="unsupported_readout_demotes", ), @@ -1042,50 +1529,6 @@ def _candidate_capability_variant( } -_CALLABLE_CUDA_NN_PRIMITIVES: tuple[str, ...] 
= ( - "linear", - "bias", - "add", - "split", - "concat", - "view", - "transpose", - "sigmoid", - "tanh", - "activation", - "mul", - "fma", - "exp", - "rsqrt", - "clamp", - "where", - "reduce_sum", - "reduce_sumsq", - "reduce_max", - "segment_sum", - "segment_softmax", - "normalize", - "gather_input_ports", - "gather_recurrent_public", - "attention_logits", - "weighted_sum", - "emit_state", - "emit_public", - "readout_project", - "state_affine", - "reduction_boundary", - "state_epilogue_policy", - "diagonal_recurrence", - "message_bucket_regular_local_receiver_owned", - "message_bucket_degree_bucketed_sparse", - "message_bucket_ragged_grouped_sparse", - "matmul", - "diag_rtu", - "gated_logspace_recurrence", - "norm_or_identity", -) - - _PRIMITIVE_BACKWARD_REGISTRY: dict[str, PrimitiveBackwardBehavior] = { "linear": PrimitiveBackwardBehavior( primitive="linear", @@ -1341,7 +1784,7 @@ def _candidate_capability_variant( ), "readout_project": PrimitiveBackwardBehavior( primitive="readout_project", - behavior="thin_reverse", + behavior="explicit_backward", family="readout", rule="readout_projection_adjoint", tape="save_inputs", @@ -1451,12 +1894,12 @@ def _candidate_capability_variant( def cuda_nn_callable_primitives() -> tuple[str, ...]: - return _CALLABLE_CUDA_NN_PRIMITIVES + return CALLABLE_CUDA_NN_PRIMITIVES def cuda_nn_primitive_backward_behaviors() -> tuple[PrimitiveBackwardBehavior, ...]: _validate_primitive_backward_registry() - return tuple(_PRIMITIVE_BACKWARD_REGISTRY[name] for name in _CALLABLE_CUDA_NN_PRIMITIVES) + return tuple(_PRIMITIVE_BACKWARD_REGISTRY[name] for name in CALLABLE_CUDA_NN_PRIMITIVES) def cuda_nn_primitive_backward_behavior(primitive: str) -> PrimitiveBackwardBehavior: @@ -1468,8 +1911,8 @@ def cuda_nn_primitive_backward_behavior(primitive: str) -> PrimitiveBackwardBeha def _validate_primitive_backward_registry() -> None: - missing = sorted(set(_CALLABLE_CUDA_NN_PRIMITIVES) - set(_PRIMITIVE_BACKWARD_REGISTRY)) - extra = sorted(set(_PRIMITIVE_BACKWARD_REGISTRY) - set(_CALLABLE_CUDA_NN_PRIMITIVES)) + missing = sorted(set(CALLABLE_CUDA_NN_PRIMITIVES) - set(_PRIMITIVE_BACKWARD_REGISTRY)) + extra = sorted(set(_PRIMITIVE_BACKWARD_REGISTRY) - set(CALLABLE_CUDA_NN_PRIMITIVES)) if missing or extra: fragments: list[str] = [] if missing: @@ -1567,7 +2010,7 @@ def _op_plan( tape_spec="forward_tape_boundary_descriptors", save_recompute_policy=_backward_recompute_mode_from_tape(tape_policy_bin), workspace_aliases=("forward_tape:borrowed",), - workspace_peak_bytes="thin_reverse:profile_required", + workspace_peak_bytes="registered_reverse_program:profile_required", demotion_reason=demotion_reason, launch_count=launch_count, saved_launch_count=saved_launch_count, @@ -1721,7 +2164,6 @@ def _transition_backward_ops( demotion_reason="none", launch_count="state_epilogue_backward:gated_logspace_cuda_tiled", saved_launch_count="state_epilogue_backward:active_cuda_owner", - residual_glue_demotion="lowered_state_epilogue_backward:explicit_cuda_executor", ) ) return tuple(ops) @@ -1744,14 +2186,13 @@ def _public_projection_backward_op( ParamGradBinding("public_proj.weight", "public_proj.weight.grad", "grouped_receiver_accumulate"), ParamGradBinding("public_proj.bias", "public_proj.bias.grad", "dense_reduction"), ), - executor="explicit_public_projection_thin_reverse", - execution_mode="thin_reverse", + executor="registered_sender_kv_projection_backward_executor", + execution_mode="registered_reverse_program", bucket_plan=bucket_plan, tape_policy_bin=tape_policy_bin, - 
demotion_reason="thin_reverse_path:explicit_executor", - launch_count="public_projection_backward:explicit_thin_reverse", + demotion_reason="none", + launch_count="sender_kv_projection_backward:registered_cuda", saved_launch_count="", - residual_glue_demotion="lowered_public_projection_backward:thin_reverse_path:explicit_executor", ) @@ -1778,14 +2219,13 @@ def _readout_message_backward_op( _grad_boundary(name="grad_sender_value", boundary="message_value", gradient="d_value"), ), param_bindings=(), - executor="explicit_readout_message_thin_reverse", - execution_mode="thin_reverse", + executor="projection_reduction_boundary_backward", + execution_mode="registered_reverse_program", bucket_plan=bucket_plan, tape_policy_bin=tape_policy_bin, - demotion_reason="thin_reverse_path:explicit_executor", - launch_count="readout_message_backward:explicit_thin_reverse", + demotion_reason="none", + launch_count="readout_message_backward:registered_reverse_executor", saved_launch_count="", - residual_glue_demotion="lowered_readout_message_backward:thin_reverse_path:explicit_executor", ) @@ -1806,14 +2246,13 @@ def _readout_projection_backward_op( ParamGradBinding("readout_out.weight", "readout_out.weight.grad", "dense_reduction"), ParamGradBinding("readout_out.bias", "readout_out.bias.grad", "dense_reduction"), ), - executor="explicit_readout_projection_thin_reverse", - execution_mode="thin_reverse", + executor="projection_reduction_boundary_backward", + execution_mode="registered_reverse_program", bucket_plan=bucket_plan, tape_policy_bin=tape_policy_bin, - demotion_reason="thin_reverse_path:explicit_executor", - launch_count="readout_projection_backward:explicit_thin_reverse", + demotion_reason="none", + launch_count="readout_projection_backward:registered_reverse_executor", saved_launch_count="", - residual_glue_demotion="lowered_readout_projection_backward:thin_reverse_path:explicit_executor", ) diff --git a/src/cortical/fabric/backend/primitives.py b/src/cortical/fabric/backend/primitives.py new file mode 100644 index 00000000..c716ddf8 --- /dev/null +++ b/src/cortical/fabric/backend/primitives.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +CALLABLE_CUDA_NN_PRIMITIVES: tuple[str, ...] 
= ( + "linear", + "bias", + "add", + "split", + "concat", + "view", + "transpose", + "sigmoid", + "tanh", + "activation", + "mul", + "fma", + "exp", + "rsqrt", + "clamp", + "where", + "reduce_sum", + "reduce_sumsq", + "reduce_max", + "segment_sum", + "segment_softmax", + "normalize", + "gather_input_ports", + "gather_recurrent_public", + "attention_logits", + "weighted_sum", + "emit_state", + "emit_public", + "readout_project", + "state_affine", + "reduction_boundary", + "state_epilogue_policy", + "diagonal_recurrence", + "message_bucket_regular_local_receiver_owned", + "message_bucket_degree_bucketed_sparse", + "message_bucket_ragged_grouped_sparse", + "matmul", + "diag_rtu", + "gated_logspace_recurrence", + "norm_or_identity", +) + +CALLABLE_CUDA_NN_PRIMITIVE_SET = frozenset(CALLABLE_CUDA_NN_PRIMITIVES) + + +def is_callable_cuda_nn_primitive(name: str) -> bool: + return str(name) in CALLABLE_CUDA_NN_PRIMITIVE_SET + + +__all__ = [ + "CALLABLE_CUDA_NN_PRIMITIVES", + "CALLABLE_CUDA_NN_PRIMITIVE_SET", + "is_callable_cuda_nn_primitive", +] diff --git a/src/cortical/fabric/backend/pytorch/message_passing.py b/src/cortical/fabric/backend/pytorch/message_passing.py index bb04d4f1..683adebb 100644 --- a/src/cortical/fabric/backend/pytorch/message_passing.py +++ b/src/cortical/fabric/backend/pytorch/message_passing.py @@ -3,6 +3,7 @@ import math import torch +import torch.nn.functional as F def compute_messages_dense_raw( @@ -55,6 +56,62 @@ def compute_messages_dense_raw( ).view(batch_size, time_steps, num_cells, value_dim) +def compute_fixed_slot_context_messages_dense( + v_all: torch.Tensor, + *, + batch_size: int, + time_steps: int, + q_slot: torch.Tensor, + query_context_scalar: torch.Tensor, + sender_slot_key: torch.Tensor, + sender_context_key: torch.Tensor, + message_output_weight: torch.Tensor, + neighbor_idx: torch.Tensor, + neighbor_valid: torch.Tensor, + edge_distance: torch.Tensor, + edge_delay: torch.Tensor | None, + distance_logit_scale: float, + step_idx: int | torch.Tensor, + head_dim: int, +) -> torch.Tensor: + flat_batch, num_cells, value_dim = v_all.shape + num_neighbors = int(neighbor_idx.shape[1]) + flat_neighbor_idx = neighbor_idx.reshape(-1) + sender_key = torch.cat((sender_slot_key, sender_context_key), dim=-1) + k_neighbors = sender_key.index_select(0, flat_neighbor_idx).view( + num_cells, + num_neighbors, + 2 * head_dim, + ) + v_neighbors = v_all.index_select(1, flat_neighbor_idx).view(flat_batch, num_cells, num_neighbors, value_dim) + q_context = v_all[:, :, :head_dim] * query_context_scalar.reshape(1, 1, 1).to(dtype=v_all.dtype) + q = torch.cat((q_slot.view(1, num_cells, head_dim).expand(flat_batch, -1, -1), q_context), dim=-1) + logits = ( + q.view(flat_batch, num_cells, 1, 2 * head_dim) * k_neighbors.view(1, num_cells, num_neighbors, 2 * head_dim) + ).sum(dim=-1) / math.sqrt(float(2 * head_dim)) + valid_windows = neighbor_valid.view(1, num_cells, num_neighbors) + if distance_logit_scale > 0.0: + logits = logits - distance_logit_scale * edge_distance.view(1, num_cells, num_neighbors) + if edge_delay is not None: + valid_windows = valid_windows & _broadcast_delay_mask( + edge_delay=edge_delay, + step_idx=step_idx, + batch_size=batch_size, + time_steps=time_steps, + num_receivers=num_cells, + num_neighbors=num_neighbors, + device=v_all.device, + ) + return _project_fixed_slot_context_messages( + logits, + v_neighbors, + valid_windows=valid_windows, + message_output_weight=message_output_weight, + batch_size=batch_size, + time_steps=time_steps, + ) + + def 
compute_messages_step_subset_raw( k_all: torch.Tensor, v_all: torch.Tensor, @@ -101,6 +158,76 @@ def compute_messages_step_subset_raw( ).view(batch_size, num_receivers, value_dim) +def compute_fixed_slot_context_messages_step_subset_raw( + sender_v: torch.Tensor, + *, + receiver_context_v: torch.Tensor, + q_slot: torch.Tensor, + query_context_scalar: torch.Tensor, + sender_slot_key: torch.Tensor, + sender_context_key: torch.Tensor, + message_output_weight: torch.Tensor, + neighbor_idx: torch.Tensor, + neighbor_valid: torch.Tensor, + edge_distance: torch.Tensor, + edge_delay: torch.Tensor, + use_delay: bool, + step_idx: int | torch.Tensor, + head_dim: int, + distance_logit_scale: float, +) -> torch.Tensor: + batch_size, _, value_dim = sender_v.shape + num_receivers = int(neighbor_idx.shape[0]) + num_neighbors = int(neighbor_idx.shape[1]) + flat_neighbor_idx = neighbor_idx.reshape(-1) + sender_key = torch.cat((sender_slot_key, sender_context_key), dim=-1) + k_neighbors = sender_key.index_select(0, flat_neighbor_idx).view( + num_receivers, + num_neighbors, + 2 * head_dim, + ) + v_neighbors = sender_v.index_select(1, flat_neighbor_idx).view( + batch_size, + num_receivers, + num_neighbors, + value_dim, + ) + q_context = receiver_context_v[:, :, :head_dim] * query_context_scalar.reshape(1, 1, 1).to( + dtype=receiver_context_v.dtype + ) + q = torch.cat((q_slot.view(1, num_receivers, head_dim).expand(batch_size, -1, -1), q_context), dim=-1) + logits = ( + q.view(batch_size, num_receivers, 1, 2 * head_dim) + * k_neighbors.view( + 1, + num_receivers, + num_neighbors, + 2 * head_dim, + ) + ).sum(dim=-1) / math.sqrt(float(2 * head_dim)) + valid_windows = neighbor_valid.view(1, num_receivers, num_neighbors) + if distance_logit_scale > 0.0: + logits = logits - distance_logit_scale * edge_distance.view(1, num_receivers, num_neighbors) + if use_delay: + valid_windows = valid_windows & _broadcast_delay_mask( + edge_delay=edge_delay, + step_idx=step_idx, + batch_size=batch_size, + time_steps=1, + num_receivers=num_receivers, + num_neighbors=num_neighbors, + device=sender_v.device, + ) + return _project_fixed_slot_context_messages( + logits, + v_neighbors, + valid_windows=valid_windows, + message_output_weight=message_output_weight, + batch_size=batch_size, + time_steps=1, + ).squeeze(1) + + def compute_messages_step_subset_partitioned_raw( input_k: torch.Tensor, input_v: torch.Tensor, @@ -134,6 +261,43 @@ def compute_messages_step_subset_partitioned_raw( ) +def compute_fixed_slot_context_messages_step_subset_partitioned_raw( + input_v: torch.Tensor, + recurrent_v: torch.Tensor, + *, + q_slot: torch.Tensor, + query_context_scalar: torch.Tensor, + sender_slot_key: torch.Tensor, + sender_context_key: torch.Tensor, + message_output_weight: torch.Tensor, + neighbor_idx: torch.Tensor, + neighbor_valid: torch.Tensor, + edge_distance: torch.Tensor, + edge_delay: torch.Tensor, + use_delay: bool, + step_idx: int | torch.Tensor, + head_dim: int, + distance_logit_scale: float, +) -> torch.Tensor: + return compute_fixed_slot_context_messages_step_subset_raw( + torch.cat((input_v, recurrent_v), dim=1), + receiver_context_v=recurrent_v, + q_slot=q_slot, + query_context_scalar=query_context_scalar, + sender_slot_key=sender_slot_key, + sender_context_key=sender_context_key, + message_output_weight=message_output_weight, + neighbor_idx=neighbor_idx, + neighbor_valid=neighbor_valid, + edge_distance=edge_distance, + edge_delay=edge_delay, + use_delay=use_delay, + step_idx=step_idx, + head_dim=head_dim, + 
distance_logit_scale=distance_logit_scale, + ) + + def compute_messages_sequence_subset_partitioned_raw( input_k_seq: torch.Tensor, input_v_seq: torch.Tensor, @@ -211,8 +375,32 @@ def _broadcast_delay_mask( return edge_delay.view(1, num_receivers, num_neighbors) <= step_flat.view(batch_size * time_steps, 1, 1) +def _project_fixed_slot_context_messages( + logits: torch.Tensor, + v_neighbors: torch.Tensor, + *, + valid_windows: torch.Tensor, + message_output_weight: torch.Tensor, + batch_size: int, + time_steps: int, +) -> torch.Tensor: + weights = torch.softmax(logits.masked_fill(~valid_windows, float("-inf")).to(dtype=torch.float32), dim=-1).to( + dtype=v_neighbors.dtype + ) + weights = torch.where(valid_windows, weights, torch.zeros_like(weights)) + has_valid = valid_windows.any(dim=-1, keepdim=True) + weights = torch.where(has_valid, weights, torch.zeros_like(weights)) + weighted_value = torch.matmul(weights.unsqueeze(-2), v_neighbors).squeeze(-2) + projected = F.linear(weighted_value, message_output_weight) + projected = F.layer_norm(projected, (int(projected.shape[-1]),), eps=1.0e-5) + return projected.view(batch_size, time_steps, int(projected.shape[-2]), int(projected.shape[-1])) + + __all__ = [ "compute_messages_dense_raw", + "compute_fixed_slot_context_messages_dense", + "compute_fixed_slot_context_messages_step_subset_partitioned_raw", + "compute_fixed_slot_context_messages_step_subset_raw", "compute_messages_sequence_subset_partitioned_raw", "compute_messages_step_subset_partitioned_raw", "compute_messages_step_subset_raw", diff --git a/src/cortical/fabric/backend/pytorch/runtime_ops.py b/src/cortical/fabric/backend/pytorch/runtime_ops.py index aa0fd910..ea17e429 100644 --- a/src/cortical/fabric/backend/pytorch/runtime_ops.py +++ b/src/cortical/fabric/backend/pytorch/runtime_ops.py @@ -5,6 +5,9 @@ import torch from cortical.fabric.backend.pytorch.message_passing import ( + compute_fixed_slot_context_messages_dense, + compute_fixed_slot_context_messages_step_subset_partitioned_raw, + compute_fixed_slot_context_messages_step_subset_raw, compute_messages_dense_raw, compute_messages_step_subset_partitioned_raw, compute_messages_step_subset_raw, @@ -29,6 +32,25 @@ def compute_messages( step_idx: int | torch.Tensor, ) -> torch.Tensor: batch_size, time_steps, _, _ = z_prev.shape + if _uses_fixed_slot_context_message_rule(runtime): + tensors = _fixed_slot_context_message_tensors(runtime, receiver_scope="all", sender_scope="all") + return compute_fixed_slot_context_messages_dense( + v_all, + batch_size=batch_size, + time_steps=time_steps, + q_slot=tensors["q_slot"], + query_context_scalar=tensors["query_context_scalar"], + sender_slot_key=tensors["sender_slot_key"], + sender_context_key=tensors["sender_context_key"], + message_output_weight=tensors["message_output_weight"], + neighbor_idx=runtime.neighbor_idx, + neighbor_valid=runtime.neighbor_valid, + edge_distance=runtime.edge_distance, + edge_delay=runtime.edge_delay if runtime.spec.anatomy.edge_delay is not None else None, + distance_logit_scale=float(runtime.config.message.distance_logit_scale), + step_idx=step_idx, + head_dim=runtime.head_dim, + ) return runtime.msg_out( compute_messages_dense_raw( k_all, @@ -40,7 +62,7 @@ def compute_messages( neighbor_valid=runtime.neighbor_valid, edge_distance=runtime.edge_distance, edge_delay=runtime.edge_delay if runtime.spec.anatomy.edge_delay is not None else None, - distance_logit_scale=float(runtime.config.distance_logit_scale), + 
distance_logit_scale=float(runtime.config.message.distance_logit_scale), step_idx=step_idx, head_dim=runtime.head_dim, value_dim=runtime.value_dim, @@ -185,7 +207,42 @@ def compute_messages_step_subset_raw_backend( step_idx: int | torch.Tensor, owner_tag: str = "generic", ) -> torch.Tensor: - del owner_tag + if str(owner_tag) == "recurrent" and _uses_fixed_slot_context_message_rule(runtime): + tensors = _fixed_slot_context_message_tensors(runtime, receiver_scope="recurrent", sender_scope="sender") + return compute_fixed_slot_context_messages_step_subset_raw( + v_all, + receiver_context_v=v_all.index_select(1, runtime.recurrent_sender_idx), + q_slot=tensors["q_slot"], + query_context_scalar=tensors["query_context_scalar"], + sender_slot_key=tensors["sender_slot_key"], + sender_context_key=tensors["sender_context_key"], + message_output_weight=tensors["message_output_weight"], + neighbor_idx=neighbor_idx, + neighbor_valid=neighbor_valid, + edge_distance=edge_distance, + edge_delay=edge_delay, + use_delay=use_delay, + step_idx=step_idx, + head_dim=runtime.head_dim, + distance_logit_scale=float(runtime.config.message.distance_logit_scale), + ) + if str(owner_tag) == "readout" and _uses_fixed_slot_context_message_rule(runtime): + tensors = _fixed_slot_context_message_tensors(runtime, receiver_scope="recurrent", sender_scope="sender") + fixed_sender_k = tensors["sender_slot_key"].unsqueeze(0).expand(int(v_all.shape[0]), -1, -1) + return compute_messages_step_subset_raw( + fixed_sender_k, + v_all, + q_subset=q_subset, + neighbor_idx=neighbor_idx, + neighbor_valid=neighbor_valid, + edge_distance=edge_distance, + edge_delay=edge_delay, + use_delay=use_delay, + step_idx=step_idx, + head_dim=runtime.head_dim, + value_dim=runtime.value_dim, + distance_logit_scale=float(runtime.config.message.distance_logit_scale), + ) return compute_messages_step_subset_raw( k_all, v_all, @@ -198,7 +255,7 @@ def compute_messages_step_subset_raw_backend( step_idx=step_idx, head_dim=runtime.head_dim, value_dim=runtime.value_dim, - distance_logit_scale=float(runtime.config.distance_logit_scale), + distance_logit_scale=float(runtime.config.message.distance_logit_scale), ) @@ -218,7 +275,56 @@ def compute_messages_step_subset_partitioned_raw_backend( step_idx: int | torch.Tensor, owner_tag: str = "generic", ) -> torch.Tensor: - del owner_tag + if str(owner_tag) == "recurrent" and _uses_fixed_slot_context_message_rule(runtime): + tensors = _fixed_slot_context_message_tensors(runtime, receiver_scope="recurrent", sender_scope="sender") + return compute_fixed_slot_context_messages_step_subset_partitioned_raw( + input_v, + recurrent_v, + q_slot=tensors["q_slot"], + query_context_scalar=tensors["query_context_scalar"], + sender_slot_key=tensors["sender_slot_key"], + sender_context_key=tensors["sender_context_key"], + message_output_weight=tensors["message_output_weight"], + neighbor_idx=neighbor_idx, + neighbor_valid=neighbor_valid, + edge_distance=edge_distance, + edge_delay=edge_delay, + use_delay=use_delay, + step_idx=step_idx, + head_dim=runtime.head_dim, + distance_logit_scale=float(runtime.config.message.distance_logit_scale), + ) + if str(owner_tag) == "readout" and _uses_fixed_slot_context_message_rule(runtime): + tensors = _fixed_slot_context_message_tensors(runtime, receiver_scope="recurrent", sender_scope="sender") + input_count = int(input_v.shape[1]) + recurrent_count = int(recurrent_v.shape[1]) + fixed_sender_k = tensors["sender_slot_key"] + fixed_input_k = 
fixed_sender_k[:input_count].unsqueeze(0).expand(int(input_v.shape[0]), -1, -1) + fixed_recurrent_k = ( + fixed_sender_k[input_count : input_count + recurrent_count] + .unsqueeze(0) + .expand( + int(input_v.shape[0]), + -1, + -1, + ) + ) + return compute_messages_step_subset_partitioned_raw( + fixed_input_k, + input_v, + fixed_recurrent_k, + recurrent_v, + q_subset=q_subset, + neighbor_idx=neighbor_idx, + neighbor_valid=neighbor_valid, + edge_distance=edge_distance, + edge_delay=edge_delay, + use_delay=use_delay, + step_idx=step_idx, + head_dim=runtime.head_dim, + value_dim=runtime.value_dim, + distance_logit_scale=float(runtime.config.message.distance_logit_scale), + ) return compute_messages_step_subset_partitioned_raw( input_k, input_v, @@ -233,7 +339,7 @@ def compute_messages_step_subset_partitioned_raw_backend( step_idx=step_idx, head_dim=runtime.head_dim, value_dim=runtime.value_dim, - distance_logit_scale=float(runtime.config.distance_logit_scale), + distance_logit_scale=float(runtime.config.message.distance_logit_scale), ) @@ -251,6 +357,55 @@ def project_output_cells_step_raw_backend( ) +def _uses_fixed_slot_context_message_rule(runtime: Any) -> bool: + message_rule = getattr(getattr(runtime, "backend_ir", None), "message_rule", None) + lowering_kind = str(getattr(message_rule, "lowering_kind", "")) + return lowering_kind in { + "dot_product_fixed_slot_context_nudge", + "dot_product_fixed_slot_context_gate", + } + + +def _fixed_slot_context_message_tensors( + runtime: Any, + *, + receiver_scope: str, + sender_scope: str, +) -> dict[str, torch.Tensor]: + query_module = runtime.message_rule_modules.get("message_query_slot_proj") + sender_module = runtime.message_rule_modules.get("message_sender_slot_key_proj") + if query_module is None or sender_module is None: + raise RuntimeError("Fixed-slot context message rule requires installed query and sender slot modules") + q_slot = query_module(runtime.slot_embed).view(int(runtime.coords.shape[0]), int(runtime.head_dim)) + sender_slot_key = sender_module(runtime.slot_embed).view(int(runtime.coords.shape[0]), int(runtime.head_dim)) + if receiver_scope == "recurrent": + q_slot = q_slot.index_select(0, runtime.recurrent_cell_idx) + elif receiver_scope != "all": + raise RuntimeError(f"Unsupported fixed-slot context receiver scope {receiver_scope!r}") + if sender_scope == "sender": + sender_slot_key = sender_slot_key.index_select(0, runtime.sender_cell_idx) + sender_context_key = runtime.message_rule_parameters["message_sender_context_key"].index_select( + 0, + runtime.sender_cell_idx, + ) + elif sender_scope == "all": + sender_context_key = runtime.message_rule_parameters["message_sender_context_key"] + else: + raise RuntimeError(f"Unsupported fixed-slot context sender scope {sender_scope!r}") + query_context_scalar = runtime.message_rule_parameters.get("message_query_nudge_scale") + if query_context_scalar is None: + query_context_scalar = runtime.message_rule_parameters.get("message_query_context_gate") + if query_context_scalar is None: + raise RuntimeError("Fixed-slot context message rule requires a query context scalar parameter") + return { + "q_slot": q_slot, + "query_context_scalar": query_context_scalar, + "sender_slot_key": sender_slot_key, + "sender_context_key": sender_context_key, + "message_output_weight": runtime.msg_out.weight, + } + + __all__ = [ "compute_messages", "compute_messages_step_subset_partitioned_raw_backend", diff --git a/src/cortical/fabric/backend/readout_rule_specs.py 
b/src/cortical/fabric/backend/readout_rule_specs.py new file mode 100644 index 00000000..e5fde2a1 --- /dev/null +++ b/src/cortical/fabric/backend/readout_rule_specs.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from cortical.fabric.backend.readout_rules import ( + ReadoutRuleBackendSpec, + ReadoutRuleNativeExecutorSpec, + ReadoutRuleStaticTensorSpec, + register_readout_rule_backend_spec_builder, +) + + +_SUPPORTED_PROJECTION_READOUT_LOWERING_KINDS = ( + "mean_readout_project", + "flatten_readout_project", + "attn_readout_project", + "attention_readout_project", +) + + +def _projection_readout_static_tensors() -> tuple[ReadoutRuleStaticTensorSpec, ...]: + return ( + ReadoutRuleStaticTensorSpec("output_q", "readout_output_query", 5), + ReadoutRuleStaticTensorSpec("value_to_output_weight", "readout_value_to_output_weight", 6), + ReadoutRuleStaticTensorSpec("output_cell_bias", "readout_output_cell_bias", 7), + ) + + +def _projection_readout_native_executors(lowering_kind: str) -> tuple[ReadoutRuleNativeExecutorSpec, ...]: + strategy_suffix = str(lowering_kind).removesuffix("_readout_project") + return ( + ReadoutRuleNativeExecutorSpec( + direction="forward", + executor_id=2, + executor_name=f"{strategy_suffix}_projection_reduction_boundary", + strategy_id=f"forward.readout.{strategy_suffix}_projection_reduction_boundary.v1", + native_callable="native.forward.output_projection_reduction_boundary.v1", + implementation_contract="registered_readout_executor_binding_rows", + cxx_entrypoints=( + "bind_projection_reduction_boundary_readout_handler", + "run_projection_reduction_boundary_readout_message", + "run_projection_reduction_boundary_readout_projection", + "run_projection_reduction_boundary_readout_projection_into", + ), + cxx_entrypoint_phases=("bind", "message", "projection", "projection_into"), + ), + ReadoutRuleNativeExecutorSpec( + direction="reverse", + executor_id=4, + executor_name=f"{strategy_suffix}_projection_reduction_boundary_backward", + strategy_id=f"reverse.readout.{strategy_suffix}_projection_reduction_boundary.v1", + native_callable="native.reverse.output_projection_reduction_boundary.v1", + implementation_contract="registered_readout_reverse_executor_binding_rows", + cxx_entrypoints=( + "run_projection_reduction_boundary_readout_backward", + "run_projection_reduction_boundary_output_message_backward", + ), + cxx_entrypoint_phases=("readout_backward", "output_message_backward"), + ), + ) + + +def build_projection_readout_backend_spec(*, lowering_kind: str) -> ReadoutRuleBackendSpec: + return ReadoutRuleBackendSpec( + lowering_kind=str(lowering_kind), + static_tensors=_projection_readout_static_tensors(), + native_executors=_projection_readout_native_executors(lowering_kind), + ) + + +for _lowering_kind in _SUPPORTED_PROJECTION_READOUT_LOWERING_KINDS: + register_readout_rule_backend_spec_builder(_lowering_kind, build_projection_readout_backend_spec) + + +__all__ = ["build_projection_readout_backend_spec"] diff --git a/src/cortical/fabric/backend/readout_rules.py b/src/cortical/fabric/backend/readout_rules.py new file mode 100644 index 00000000..863be9fa --- /dev/null +++ b/src/cortical/fabric/backend/readout_rules.py @@ -0,0 +1,221 @@ +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Literal + +READOUT_OUTPUT_BOUNDARY = "model_output" +ReadoutNativeExecutorDirection = Literal["forward", "reverse"] +ReadoutNativeExecutorPhase = Literal[ + "bind", + "message", + "projection", + 
"projection_into", + "readout_backward", + "output_message_backward", +] + + +@dataclass(frozen=True) +class ReadoutRuleIR: + name: str + pool: str + readout_slots: int + output_boundary: str = READOUT_OUTPUT_BOUNDARY + + @property + def lowering_kind(self) -> str: + return classify_readout_rule(self) + + +@dataclass(frozen=True) +class CompiledReadoutPrimitiveOp: + node_index: int + primitive: str + source_op: str + inputs: tuple[str, ...] + outputs: tuple[str, ...] + attributes: tuple[tuple[str, str], ...] = () + parameter_inputs: tuple[str, ...] = () + + +@dataclass(frozen=True) +class CompiledReadoutRule: + rule_name: str + lowering_kind: str + pool: str + readout_slots: int + output_boundary: str + primitive_ops: tuple[CompiledReadoutPrimitiveOp, ...] + static_tensors: tuple[ReadoutRuleStaticTensorSpec, ...] = () + native_executors: tuple[ReadoutRuleNativeExecutorSpec, ...] = () + + @property + def primitive_names(self) -> tuple[str, ...]: + return tuple(dict.fromkeys(op.primitive for op in self.primitive_ops)) + + +@dataclass(frozen=True) +class ReadoutRuleStaticTensorSpec: + name: str + program_access_name: str + program_access_opcode: int + + +@dataclass(frozen=True) +class ReadoutRuleNativeExecutorSpec: + direction: ReadoutNativeExecutorDirection + executor_id: int + executor_name: str + strategy_id: str + native_callable: str + implementation_contract: str + cxx_entrypoints: tuple[str, ...] + cxx_entrypoint_phases: tuple[ReadoutNativeExecutorPhase, ...] = () + strategy_version: int = 1 + + @property + def cxx_entrypoint_contract(self) -> tuple[tuple[ReadoutNativeExecutorPhase, str], ...]: + return tuple(zip(self.cxx_entrypoint_phases, self.cxx_entrypoints, strict=True)) + + +@dataclass(frozen=True) +class ReadoutRuleBackendSpec: + lowering_kind: str + static_tensors: tuple[ReadoutRuleStaticTensorSpec, ...] + native_executors: tuple[ReadoutRuleNativeExecutorSpec, ...] 
+ + +_SUPPORTED_READOUT_POOLS = frozenset({"mean", "flatten", "attn", "attention"}) +ReadoutRuleBackendSpecBuilder = Callable[..., ReadoutRuleBackendSpec] + +_READOUT_RULE_BACKEND_SPEC_BUILDERS: dict[str, ReadoutRuleBackendSpecBuilder] = {} + + +def register_readout_rule_backend_spec_builder( + lowering_kind: str, + builder: ReadoutRuleBackendSpecBuilder, +) -> None: + _READOUT_RULE_BACKEND_SPEC_BUILDERS[str(lowering_kind)] = builder + + +def _ensure_builtin_readout_rule_backend_specs_registered() -> None: + if _READOUT_RULE_BACKEND_SPEC_BUILDERS: + return + import cortical.fabric.backend.readout_rule_specs # noqa: F401 + + +def build_readout_rule_backend_spec(*, lowering_kind: str) -> ReadoutRuleBackendSpec: + _ensure_builtin_readout_rule_backend_specs_registered() + try: + builder = _READOUT_RULE_BACKEND_SPEC_BUILDERS[str(lowering_kind)] + except KeyError as exc: + raise ValueError(f"Unsupported Fabric readout rule backend {lowering_kind}") from exc + return builder(lowering_kind=str(lowering_kind)) + + +def registered_readout_rule_backend_spec_lowering_kinds() -> tuple[str, ...]: + _ensure_builtin_readout_rule_backend_specs_registered() + return tuple(sorted(_READOUT_RULE_BACKEND_SPEC_BUILDERS)) + + +def readout_rule_native_executor( + *, + lowering_kind: str, + direction: ReadoutNativeExecutorDirection, +) -> ReadoutRuleNativeExecutorSpec: + spec = build_readout_rule_backend_spec(lowering_kind=str(lowering_kind)) + matches = tuple(executor for executor in spec.native_executors if executor.direction == direction) + if len(matches) != 1: + raise RuntimeError( + "Readout rule backend spec must declare exactly one native executor for direction " + f"{direction!r}: lowering_kind={lowering_kind!r}; count={len(matches)}" + ) + return matches[0] + + +def classify_readout_rule(rule: ReadoutRuleIR) -> str: + pool = str(rule.pool) + if pool not in _SUPPORTED_READOUT_POOLS: + return "unsupported" + if rule.output_boundary != READOUT_OUTPUT_BOUNDARY: + return "unsupported" + return f"{pool}_readout_project" + + +def default_readout_rule_ir(*, readout_pool: str, readout_slots: int) -> ReadoutRuleIR: + return ReadoutRuleIR( + name="readout", + pool=str(readout_pool), + readout_slots=int(readout_slots), + ) + + +def compile_readout_rule(rule: ReadoutRuleIR) -> CompiledReadoutRule: + lowering_kind = rule.lowering_kind + if lowering_kind == "unsupported": + raise ValueError( + f"Unsupported Fabric readout rule pool={rule.pool!r}, output_boundary={rule.output_boundary!r}; " + "add the readout op to fabric.cuda.nn lowering before using it" + ) + backend_spec = build_readout_rule_backend_spec(lowering_kind=lowering_kind) + if int(rule.readout_slots) <= 0: + raise ValueError(f"Unsupported Fabric readout rule readout_slots={rule.readout_slots}; expected positive") + primitive_ops = ( + CompiledReadoutPrimitiveOp( + node_index=0, + primitive="readout_project", + source_op="readout_project", + inputs=("public_state", "output_q", "value_to_output_weight", "output_cell_bias"), + outputs=("output",), + attributes=( + ("pool", str(rule.pool)), + ("readout_slots", str(int(rule.readout_slots))), + ("output_boundary", str(rule.output_boundary)), + ("lowering_kind", str(lowering_kind)), + ), + parameter_inputs=("output_q", "value_to_output_weight", "output_cell_bias"), + ), + CompiledReadoutPrimitiveOp( + node_index=1, + primitive="reduction_boundary", + source_op="readout_boundary", + inputs=("output_grad",), + outputs=("public_state_grad",), + attributes=( + ("boundary", "readout"), + ("output_boundary", 
str(rule.output_boundary)), + ("lowering_kind", str(lowering_kind)), + ), + ), + ) + return CompiledReadoutRule( + rule_name=str(rule.name), + lowering_kind=str(lowering_kind), + pool=str(rule.pool), + readout_slots=int(rule.readout_slots), + output_boundary=str(rule.output_boundary), + primitive_ops=primitive_ops, + static_tensors=backend_spec.static_tensors, + native_executors=backend_spec.native_executors, + ) + + +__all__ = [ + "CompiledReadoutPrimitiveOp", + "CompiledReadoutRule", + "READOUT_OUTPUT_BOUNDARY", + "ReadoutNativeExecutorDirection", + "ReadoutNativeExecutorPhase", + "ReadoutRuleBackendSpec", + "ReadoutRuleIR", + "ReadoutRuleNativeExecutorSpec", + "ReadoutRuleStaticTensorSpec", + "build_readout_rule_backend_spec", + "classify_readout_rule", + "compile_readout_rule", + "default_readout_rule_ir", + "readout_rule_native_executor", + "register_readout_rule_backend_spec_builder", + "registered_readout_rule_backend_spec_lowering_kinds", +] diff --git a/src/cortical/fabric/backend/runtime_dispatch.py b/src/cortical/fabric/backend/runtime_dispatch.py index 4f3ccdd0..8b0b59bf 100644 --- a/src/cortical/fabric/backend/runtime_dispatch.py +++ b/src/cortical/fabric/backend/runtime_dispatch.py @@ -1,25 +1,11 @@ from __future__ import annotations import importlib -from collections.abc import Mapping -from typing import Any import torch from tensordict import TensorDict -def _cuda_runtime_ops(): - return importlib.import_module("cortical.fabric.backend.cuda.runtime_ops") - - -def _cuda_transition_execution(): - return importlib.import_module("cortical.fabric.backend.cuda.transition_execution") - - -def _cuda_flat_bucket_sequence_surface(): - return importlib.import_module("cortical.fabric.backend.cuda.sequence_surface.flat_buckets") - - def _pytorch_population_execution(): return importlib.import_module("cortical.fabric.backend.pytorch.population_execution") @@ -29,7 +15,7 @@ def _pytorch_runtime_ops(): class BackendRuntimeDispatchMixin: - def _supports_cuda_flat_bucket_sequence_backend( + def _supports_cuda_registered_temporal_program_backend( self, *, device: torch.device, @@ -41,7 +27,7 @@ def _supports_cuda_flat_bucket_sequence_backend( device=device, dtype=dtype, ) - return route.uses_flat_transition_buckets + return route.uses_registered_temporal_program def _compute_messages( self, @@ -58,15 +44,6 @@ def _compute_messages( self.head_dim + self.value_dim, ) k_all, v_all = kv_all.split((self.head_dim, self.value_dim), dim=-1) - if z_prev.is_cuda and self._active_backend_name != "pytorch": - return _cuda_runtime_ops().compute_messages( - self, - z_prev, - k_all=k_all, - v_all=v_all, - q=q, - step_idx=step_idx, - ) return _pytorch_runtime_ops().compute_messages( self, z_prev, @@ -105,16 +82,7 @@ def _project_sender_kv_from_cells_step( sender_group_size: int = 1, contiguous_kv: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: - if self._active_backend_name == "pytorch": - return _pytorch_runtime_ops().project_sender_kv_from_cells_step_backend( - self, - sender_cells_step, - sender_input_to_kv_weight=sender_input_to_kv_weight, - grouped_sender_input_to_kv_weight=grouped_sender_input_to_kv_weight, - sender_group_size=sender_group_size, - contiguous_kv=contiguous_kv, - ) - return _cuda_runtime_ops().project_sender_kv_from_cells_step_backend( + return _pytorch_runtime_ops().project_sender_kv_from_cells_step_backend( self, sender_cells_step, sender_input_to_kv_weight=sender_input_to_kv_weight, @@ -131,15 +99,7 @@ def _project_sender_kv_from_cells_sequence( 
grouped_sender_input_to_kv_weight: torch.Tensor | None = None, sender_group_size: int = 1, ) -> tuple[torch.Tensor, torch.Tensor]: - if self._active_backend_name == "pytorch": - return _pytorch_runtime_ops().project_sender_kv_from_cells_sequence_backend( - self, - sender_cells_seq, - sender_input_to_kv_weight=sender_input_to_kv_weight, - grouped_sender_input_to_kv_weight=grouped_sender_input_to_kv_weight, - sender_group_size=sender_group_size, - ) - return _cuda_runtime_ops().project_sender_kv_from_cells_sequence_backend( + return _pytorch_runtime_ops().project_sender_kv_from_cells_sequence_backend( self, sender_cells_seq, sender_input_to_kv_weight=sender_input_to_kv_weight, @@ -154,14 +114,7 @@ def _project_boundary_source_sequence( input_projection_weight: torch.Tensor, input_projection_bias: torch.Tensor | None, ) -> torch.Tensor: - if self._active_backend_name == "pytorch": - return _pytorch_runtime_ops().project_boundary_source_sequence_backend( - self, - source_hidden_seq, - input_projection_weight=input_projection_weight, - input_projection_bias=input_projection_bias, - ) - return _cuda_runtime_ops().project_boundary_source_sequence_backend( + return _pytorch_runtime_ops().project_boundary_source_sequence_backend( self, source_hidden_seq, input_projection_weight=input_projection_weight, @@ -175,14 +128,7 @@ def _project_recurrent_kv_from_preproj_step( recurrent_preproj_to_kv_weight: torch.Tensor, recurrent_preproj_to_kv_bias: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - if self._active_backend_name == "pytorch": - return _pytorch_runtime_ops().project_recurrent_kv_from_preproj_step_backend( - self, - recurrent_preproj_step, - recurrent_preproj_to_kv_weight=recurrent_preproj_to_kv_weight, - recurrent_preproj_to_kv_bias=recurrent_preproj_to_kv_bias, - ) - return _cuda_runtime_ops().project_recurrent_kv_from_preproj_step_backend( + return _pytorch_runtime_ops().project_recurrent_kv_from_preproj_step_backend( self, recurrent_preproj_step, recurrent_preproj_to_kv_weight=recurrent_preproj_to_kv_weight, @@ -196,14 +142,7 @@ def _project_recurrent_hidden_from_preproj_step( out_proj_weight_t: torch.Tensor, out_proj_bias: torch.Tensor, ) -> torch.Tensor: - if self._active_backend_name == "pytorch": - return _pytorch_runtime_ops().project_recurrent_hidden_from_preproj_step_backend( - self, - recurrent_preproj_step, - out_proj_weight_t=out_proj_weight_t, - out_proj_bias=out_proj_bias, - ) - return _cuda_runtime_ops().project_recurrent_hidden_from_preproj_step_backend( + return _pytorch_runtime_ops().project_recurrent_hidden_from_preproj_step_backend( self, recurrent_preproj_step, out_proj_weight_t=out_proj_weight_t, @@ -220,17 +159,7 @@ def _project_recurrent_message_to_cell_step( fused_recurrent_cell_bias: torch.Tensor | None = None, fused_recurrent_population_input: bool = False, ) -> torch.Tensor: - if self._active_backend_name == "pytorch": - return _pytorch_runtime_ops().project_recurrent_message_to_cell_step_backend( - self, - recurrent_msg, - value_to_cell_weight=value_to_cell_weight, - recurrent_cell_bias=recurrent_cell_bias, - fused_recurrent_value_to_cell_weight=fused_recurrent_value_to_cell_weight, - fused_recurrent_cell_bias=fused_recurrent_cell_bias, - fused_recurrent_population_input=fused_recurrent_population_input, - ) - return _cuda_runtime_ops().project_recurrent_message_to_cell_step_backend( + return _pytorch_runtime_ops().project_recurrent_message_to_cell_step_backend( self, recurrent_msg, value_to_cell_weight=value_to_cell_weight, @@ -240,111 +169,36 @@ def 
_project_recurrent_message_to_cell_step( fused_recurrent_population_input=fused_recurrent_population_input, ) - def _lower_backend_population_transition_shared( + def _project_recurrent_message_to_cell_step_for_message_rule( self, - *, - population_name: str | None = None, recurrent_msg: torch.Tensor, - packed_state_before: Any, - population_reset_step: torch.Tensor | None, - static_tensors: dict[str, object], - materialize_recurrent_kv: bool = True, - ) -> tuple[Any, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: - if recurrent_msg.is_cuda and self._active_backend_name != "pytorch": - return _cuda_transition_execution().lower_backend_population_transition_shared( - self, - population_name=population_name, - recurrent_msg=recurrent_msg, - packed_state_before=packed_state_before, - population_reset_step=population_reset_step, - static_tensors=static_tensors, - materialize_recurrent_kv=materialize_recurrent_kv, - ) - return _pytorch_population_execution().lower_backend_population_transition_shared( - self, - population_name=population_name, - recurrent_msg=recurrent_msg, - packed_state_before=packed_state_before, - population_reset_step=population_reset_step, - static_tensors=static_tensors, - ) - - def _lower_backend_population_transition_forward_result_shared( - self, *, - population_name: str | None = None, - recurrent_msg: torch.Tensor, - packed_state_before: Any, - population_reset_step: torch.Tensor | None, - static_tensors: dict[str, object], - materialize_recurrent_kv: bool = True, - materialize_backward_tape: bool = False, - materialize_diagonal_preproj_tape: bool = True, - materialize_recurrence_backward_tape: bool | None = None, - materialize_next_state: bool = True, - materialize_trace_state_next: bool = True, - ) -> Any: - if recurrent_msg.is_cuda and self._active_backend_name != "pytorch": - return _cuda_transition_execution().lower_backend_population_transition_forward_result_shared( - self, - population_name=population_name, - recurrent_msg=recurrent_msg, - packed_state_before=packed_state_before, - population_reset_step=population_reset_step, - static_tensors=static_tensors, - materialize_recurrent_kv=materialize_recurrent_kv, - materialize_backward_tape=materialize_backward_tape, - materialize_diagonal_preproj_tape=materialize_diagonal_preproj_tape, - materialize_recurrence_backward_tape=materialize_recurrence_backward_tape, - materialize_next_state=materialize_next_state, - materialize_trace_state_next=materialize_trace_state_next, - ) - next_packed_state, recurrent_hidden, recurrent_k, recurrent_v = ( - _pytorch_population_execution().lower_backend_population_transition_shared( - self, - population_name=population_name, - recurrent_msg=recurrent_msg, - packed_state_before=packed_state_before, - population_reset_step=population_reset_step, - static_tensors=static_tensors, + value_to_cell_weight: torch.Tensor, + recurrent_cell_bias: torch.Tensor, + fused_recurrent_value_to_cell_weight: torch.Tensor | None = None, + fused_recurrent_cell_bias: torch.Tensor | None = None, + fused_recurrent_population_input: bool = False, + ) -> tuple[torch.Tensor, bool]: + if self._message_rule_output_dim_role() == "d_msg": + return ( + torch.nn.functional.linear(recurrent_msg, self.msg_to_cell.weight) + recurrent_cell_bias, + False, ) - ) - transition_execution = _cuda_transition_execution() - return transition_execution.TransitionForwardResult( - next_packed_state, - recurrent_hidden, - recurrent_k, - recurrent_v, - None, + return ( + self._project_recurrent_message_to_cell_step( + 
recurrent_msg, + value_to_cell_weight=value_to_cell_weight, + recurrent_cell_bias=recurrent_cell_bias, + fused_recurrent_value_to_cell_weight=fused_recurrent_value_to_cell_weight, + fused_recurrent_cell_bias=fused_recurrent_cell_bias, + fused_recurrent_population_input=fused_recurrent_population_input, + ), + bool(fused_recurrent_population_input), ) - def _lower_backend_population_transition_backward_shared( - self, - *, - population_name: str | None = None, - recurrent_msg: torch.Tensor, - packed_state_before: Any, - population_reset_step: torch.Tensor | None, - static_tensors: dict[str, object], - grad_next_packed_state: Any, - grad_recurrent_hidden: torch.Tensor | None, - need_grad_packed_state_before: bool = True, - forward_tape: Any | None = None, - ) -> Any: - if recurrent_msg.is_cuda and self._active_backend_name != "pytorch": - return _cuda_transition_execution().lower_backend_population_transition_backward_shared( - self, - population_name=population_name, - recurrent_msg=recurrent_msg, - packed_state_before=packed_state_before, - population_reset_step=population_reset_step, - static_tensors=static_tensors, - grad_next_packed_state=grad_next_packed_state, - grad_recurrent_hidden=grad_recurrent_hidden, - need_grad_packed_state_before=need_grad_packed_state_before, - forward_tape=forward_tape, - ) - raise RuntimeError("Fabric physical transition backward is only implemented for CUDA supported rows") + def _message_rule_output_dim_role(self) -> str: + message_program = getattr(getattr(self, "backend_ir", None), "message_program", None) + return str(getattr(message_program, "output_dim_role", "value_dim")) def _grouped_kv_weight(self, group_ids: torch.Tensor) -> torch.Tensor | None: if group_ids.numel() == 0: @@ -365,14 +219,7 @@ def _project_grouped_sender_cells( *, group_size: int, ) -> torch.Tensor: - if self._active_backend_name == "pytorch": - return _pytorch_runtime_ops().project_grouped_sender_cells_backend( - self, - sender_cells_step, - grouped_weight, - group_size=group_size, - ) - return _cuda_runtime_ops().project_grouped_sender_cells_backend( + return _pytorch_runtime_ops().project_grouped_sender_cells_backend( self, sender_cells_step, grouped_weight, @@ -421,22 +268,6 @@ def _compute_messages_step_subset_raw( local_receiver_idx_by_sender: torch.Tensor | None = None, owner_tag: str = "generic", ) -> torch.Tensor: - if k_all.is_cuda and self._active_backend_name != "pytorch": - return _cuda_runtime_ops().compute_messages_step_subset_raw_backend( - self, - k_all, - v_all, - q_subset=q_subset, - neighbor_idx=neighbor_idx, - neighbor_valid=neighbor_valid, - edge_distance=edge_distance, - edge_delay=edge_delay, - use_delay=use_delay, - step_idx=step_idx, - local_sender_idx=local_sender_idx, - local_receiver_idx_by_sender=local_receiver_idx_by_sender, - owner_tag=owner_tag, - ) return _pytorch_runtime_ops().compute_messages_step_subset_raw_backend( self, k_all, @@ -469,24 +300,6 @@ def _compute_messages_step_subset_partitioned_raw( local_receiver_idx_by_sender: torch.Tensor, owner_tag: str = "generic", ) -> torch.Tensor: - if input_k.is_cuda and self._active_backend_name != "pytorch" and input_k.dtype == torch.float32: - return _cuda_runtime_ops().compute_messages_step_subset_partitioned_raw_backend( - self, - input_k, - input_v, - recurrent_k, - recurrent_v, - q_subset=q_subset, - neighbor_idx=neighbor_idx, - neighbor_valid=neighbor_valid, - edge_distance=edge_distance, - edge_delay=edge_delay, - use_delay=use_delay, - step_idx=step_idx, - local_sender_idx=local_sender_idx, - 
local_receiver_idx_by_sender=local_receiver_idx_by_sender, - owner_tag=owner_tag, - ) return _pytorch_runtime_ops().compute_messages_step_subset_partitioned_raw_backend( self, input_k, @@ -525,132 +338,6 @@ def _run_recurrent_population_step( population_input_already_projected=population_input_already_projected, ) - def _run_transition_bucket_step( - self, - population_name: str, - recurrent_msg: torch.Tensor, - population_state: TensorDict | None, - *, - resets: torch.Tensor | None, - static_tensors: dict[str, object], - step_population_state_cache: dict[str, object] | None = None, - materialize_next_state: bool = True, - ) -> tuple[torch.Tensor, TensorDict]: - if not recurrent_msg.is_cuda or self._active_backend_name == "pytorch": - raise RuntimeError("CUDA flat-bucket transition execution requires active CUDA backend tensors") - return _cuda_flat_bucket_sequence_surface().run_transition_bucket_step( - self, - population_name, - recurrent_msg, - population_state, - resets=resets, - static_tensors=static_tensors, - step_population_state_cache=step_population_state_cache, - materialize_next_state=materialize_next_state, - ) - - def _run_transition_buckets_step( - self, - recurrent_msg: torch.Tensor, - population_state: TensorDict, - *, - resets: torch.Tensor | None, - batch_size: int, - static_tensors: dict[str, object], - step_population_state_cache: dict[str, object] | None = None, - materialize_next_state: bool = True, - ) -> tuple[torch.Tensor, TensorDict]: - if not recurrent_msg.is_cuda or self._active_backend_name == "pytorch": - raise RuntimeError("CUDA flat-bucket transition execution requires active CUDA backend tensors") - return _cuda_flat_bucket_sequence_surface().run_transition_buckets_step( - self, - recurrent_msg, - population_state, - resets=resets, - batch_size=batch_size, - static_tensors=static_tensors, - step_population_state_cache=step_population_state_cache, - materialize_next_state=materialize_next_state, - ) - - def _run_backend_order_transition_buckets_step( - self, - recurrent_msg: torch.Tensor, - population_state: TensorDict, - *, - resets: torch.Tensor | None, - batch_size: int, - static_tensors: dict[str, object], - step_population_state_cache: dict[str, object] | None = None, - materialize_next_state: bool = True, - ) -> tuple[torch.Tensor, TensorDict]: - if not recurrent_msg.is_cuda or self._active_backend_name == "pytorch": - raise RuntimeError("CUDA flat-bucket transition execution requires active CUDA backend tensors") - return _cuda_flat_bucket_sequence_surface().run_backend_order_transition_buckets_step( - self, - recurrent_msg, - population_state, - resets=resets, - batch_size=batch_size, - static_tensors=static_tensors, - step_population_state_cache=step_population_state_cache, - materialize_next_state=materialize_next_state, - ) - - def _run_backend_order_transition_buckets_backward_step( - self, - recurrent_msg: torch.Tensor, - population_state_before: TensorDict, - *, - grad_recurrent_hidden: torch.Tensor | None, - grad_next_population_state: Mapping[str, Mapping[str, torch.Tensor | None]] | None = None, - resets: torch.Tensor | None, - static_tensors: dict[str, object], - trainable_params: tuple[torch.Tensor, ...], - trainable_param_names: tuple[str, ...], - need_grad_state_before: bool = True, - ) -> tuple[torch.Tensor | None, TensorDict, tuple[torch.Tensor | None, ...]]: - if not recurrent_msg.is_cuda or self._active_backend_name == "pytorch": - raise RuntimeError("CUDA flat-bucket transition backward requires active CUDA backend tensors") - 
return _cuda_flat_bucket_sequence_surface().run_backend_order_transition_buckets_backward_step( - self, - recurrent_msg, - population_state_before, - grad_recurrent_hidden=grad_recurrent_hidden, - grad_next_population_state=grad_next_population_state, - resets=resets, - static_tensors=static_tensors, - trainable_params=trainable_params, - trainable_param_names=trainable_param_names, - need_grad_state_before=need_grad_state_before, - ) - - def _run_active_window_transition_buckets_step( - self, - recurrent_msg: torch.Tensor, - *, - active_recurrent_idx: torch.Tensor, - active_window_buckets: Mapping[str, Mapping[str, torch.Tensor]] | None = None, - resets: torch.Tensor | None, - batch_size: int, - static_tensors: dict[str, object], - step_population_state_cache: dict[str, object] | None = None, - materialize_next_state: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if not recurrent_msg.is_cuda or self._active_backend_name == "pytorch": - raise RuntimeError("CUDA flat-bucket active-window execution requires active CUDA backend tensors") - return _cuda_flat_bucket_sequence_surface().run_active_window_transition_buckets_step( - self, - recurrent_msg, - active_recurrent_idx=active_recurrent_idx, - active_window_buckets=active_window_buckets, - resets=resets, - batch_size=batch_size, - static_tensors=static_tensors, - step_population_state_cache=step_population_state_cache, - materialize_next_state=materialize_next_state, - ) - def _run_population_updates_step_cached( self, population_input: torch.Tensor, @@ -694,13 +381,7 @@ def _project_output_cells_step_raw( *, value_to_output_weight: torch.Tensor, ) -> torch.Tensor: - if self._active_backend_name == "pytorch": - return _pytorch_runtime_ops().project_output_cells_step_raw_backend( - self, - output_msg, - value_to_output_weight=value_to_output_weight, - ) - return _cuda_runtime_ops().project_output_cells_step_raw_backend( + return _pytorch_runtime_ops().project_output_cells_step_raw_backend( self, output_msg, value_to_output_weight=value_to_output_weight, diff --git a/src/cortical/fabric/backend/surfaces.py b/src/cortical/fabric/backend/surfaces.py index 6918181b..6ebe33be 100644 --- a/src/cortical/fabric/backend/surfaces.py +++ b/src/cortical/fabric/backend/surfaces.py @@ -10,7 +10,7 @@ class SupportedSurface: regime: str training: bool | None eligibility: tuple[str, ...] - disallowed_fallbacks: tuple[str, ...] + forbidden_routes: tuple[str, ...] expected_profile_signature: tuple[str, ...] @@ -86,7 +86,62 @@ class BackendExecutionRecord: launch_readout_modes: tuple[str, ...] = () launch_temporal_executions: tuple[str, ...] = () launch_scan_implementations: tuple[str, ...] = () + launch_temporal_scan_owners: tuple[str, ...] = () + launch_temporal_scan_outer_steps: tuple[str, ...] = () + launch_temporal_scan_inner_steps: tuple[str, ...] = () + launch_temporal_scan_physical_steps: tuple[str, ...] = () + launch_temporal_scan_emission_counts: tuple[str, ...] = () + launch_temporal_scan_first_emission_steps: tuple[str, ...] = () + launch_temporal_scan_emission_strides: tuple[str, ...] = () + launch_temporal_scan_output_boundaries: tuple[str, ...] = () + temporal_primitive_executor_contracts: tuple[str, ...] = () + temporal_primitive_executor_blockers: tuple[str, ...] = () launch_phases: tuple[str, ...] = () + temporal_plan_schedule_kinds: tuple[str, ...] = () + temporal_plan_outer_time_steps: tuple[str, ...] = () + temporal_plan_inner_steps: tuple[str, ...] = () + temporal_plan_total_scan_steps: tuple[str, ...] 
= () + temporal_plan_per_timestep_k: tuple[str, ...] = () + temporal_plan_substrate_kinds: tuple[str, ...] = () + temporal_plan_bucket_identity: tuple[str, ...] = () + temporal_plan_resets: tuple[str, ...] = () + temporal_plan_output_selectors: tuple[str, ...] = () + temporal_plan_output_explicit_outer_steps: tuple[str, ...] = () + temporal_plan_output_first_outer_steps: tuple[str, ...] = () + temporal_plan_output_outer_strides: tuple[str, ...] = () + temporal_plan_output_counts: tuple[str, ...] = () + temporal_plan_output_first_physical_steps: tuple[str, ...] = () + temporal_plan_output_physical_strides: tuple[str, ...] = () + temporal_plan_output_surfaces: tuple[str, ...] = () + temporal_plan_readout_surfaces: tuple[str, ...] = () + temporal_plan_output_materializations: tuple[str, ...] = () + temporal_plan_autograd_seed_kinds: tuple[str, ...] = () + temporal_plan_required_backward_surfaces: tuple[str, ...] = () + temporal_plan_checkpoint_policy_basis: tuple[str, ...] = () + temporal_plan_fresh_state_population_cache: tuple[str, ...] = () + temporal_plan_fresh_state_population_cache_reasons: tuple[str, ...] = () + temporal_plan_gradient_boundaries: tuple[str, ...] = () + temporal_plan_horizon_steps: tuple[str, ...] = () + temporal_plan_checkpoint_kinds: tuple[str, ...] = () + temporal_plan_checkpoint_steps: tuple[str, ...] = () + temporal_plan_reverse_artifact_kinds: tuple[str, ...] = () + temporal_plan_recompute_window_steps: tuple[str, ...] = () + temporal_plan_materialization_reasons: tuple[str, ...] = () + temporal_plan_backward_windows: tuple[str, ...] = () + temporal_plan_static_value_modes: tuple[str, ...] = () + temporal_plan_native_static_materialization: tuple[str, ...] = () + temporal_plan_static_include_full_cell_kv: tuple[str, ...] = () + temporal_plan_static_detach_training: tuple[str, ...] = () + temporal_plan_backend_names: tuple[str, ...] = () + temporal_plan_executors: tuple[str, ...] = () + temporal_plan_selected_implementations: tuple[str, ...] = () + temporal_plan_reasons: tuple[str, ...] = () + temporal_plan_forward_owners: tuple[str, ...] = () + temporal_plan_backward_owners: tuple[str, ...] = () + temporal_plan_checkpoint_owners: tuple[str, ...] = () + temporal_plan_target_owners: tuple[str, ...] = () + temporal_plan_engine_statuses: tuple[str, ...] = () + temporal_plan_engine_reasons: tuple[str, ...] = () active_receiver_window_modes: tuple[str, ...] = () active_receiver_window_offsets: tuple[str, ...] = () active_receiver_window_counts: tuple[str, ...] 
= () @@ -250,8 +305,8 @@ class BackendExecutionRecord: "cell_backend_supported_topology", "cell_backend_supported_math_path", ), - disallowed_fallbacks=( - "runtime/reference fallback", + forbidden_routes=( + "runtime_reference_route", "wrapper-stitched backward", "host_side_O(T)_cell_population_loop", "mandatory output reread pass", @@ -260,7 +315,7 @@ class BackendExecutionRecord: expected_profile_signature=( "cell_owned_recurrence_engine", "policy_only_step_rollout_training_split", - "no_runtime_reference_fallback", + "no_runtime_reference_route", ), ), SupportedSurface( @@ -274,8 +329,8 @@ class BackendExecutionRecord: "local_edges_only", "cell_backend_supported_math_path", ), - disallowed_fallbacks=( - "runtime/reference fallback", + forbidden_routes=( + "runtime_reference_route", "wrapper-stitched backward", "host_side_O(T)_rtu_loop", ), diff --git a/src/cortical/fabric/backend/temporal_plan.py b/src/cortical/fabric/backend/temporal_plan.py new file mode 100644 index 00000000..cbb31aae --- /dev/null +++ b/src/cortical/fabric/backend/temporal_plan.py @@ -0,0 +1,262 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +SequenceSurfaceRouteKind = Literal["none", "sequence_surface"] +SequenceSurfaceExecutorKind = Literal["none", "temporal_bucket_sequence"] +SequenceSurfaceImplementationKind = Literal["none", "registered_temporal_program"] +TemporalOutputSelectorKind = Literal["all_outer_steps", "terminal_outer_step", "explicit_outer_steps"] +TemporalAutogradSeedKind = Literal["none", "emitted_output_grad", "emitted_output_grad_plus_final_state_grad"] +TemporalCheckpointBasisKind = Literal[ + "inference", + "emitted_output_schedule", + "emitted_output_and_final_state_schedule", +] +TemporalEngineOwnerKind = Literal[ + "none", + "registered_fused_forward_program_cuda", + "registered_reverse_executor_bindings", + "registered_temporal_executor_bindings", + "pytorch_reference", + "unsupported", +] +TemporalEngineStatus = Literal[ + "registered_executor_bindings", + "pytorch_reference", + "unsupported", +] + + +@dataclass(frozen=True) +class SequenceSurfaceRoute: + kind: SequenceSurfaceRouteKind + executor: SequenceSurfaceExecutorKind + supported: bool + reason: str + active_populations: tuple[str, ...] + surface_key: str | None = None + implementation_executor: SequenceSurfaceImplementationKind = "none" + bucket_count: int = 0 + + @property + def uses_registered_temporal_program(self) -> bool: + return self.implementation_executor == "registered_temporal_program" + + +@dataclass(frozen=True) +class TemporalSchedulePlan: + schedule_kind: Literal["scalar_constant_k", "runtime_variable_k"] + outer_time_steps: int + inner_steps: int | None + total_scan_steps: int | None + per_timestep_k_semantic: str + + +@dataclass(frozen=True) +class TemporalSubstratePlan: + substrate_kind: str + active_populations: tuple[str, ...] 
+    bucket_count: int
+    population_cardinality: Literal["none", "single", "multi"]
+    partitioned_layout: bool
+    bucket_identity: str
+
+
+@dataclass(frozen=True)
+class TemporalBoundaryPlan:
+    input_boundary: str
+    readout_output_boundary: Literal["cells", "pooled"]
+    output_contract: str
+    resets: Literal["present", "absent"]
+
+
+@dataclass(frozen=True)
+class TemporalCarryPlan:
+    initial_state: Literal["fresh", "provided"]
+    materialize_final_state: bool
+    carry_policy: str
+    fresh_state_population_cache: bool
+    fresh_state_population_cache_reason: str
+
+
+@dataclass(frozen=True)
+class TemporalOutputRequestPlan:
+    selector_kind: TemporalOutputSelectorKind
+    explicit_outer_steps: tuple[int, ...]
+    first_outer_step: int | None
+    outer_stride: int | None
+    emitted_output_count: int
+    first_physical_step: int | None
+    physical_stride: int | None
+    output_surface: str
+    readout_surface: Literal["cells", "pooled"]
+    materialize_final_state: bool
+    materialization: Literal["outputs_only", "outputs_and_final_state"]
+    autograd_seed_kind: TemporalAutogradSeedKind
+    required_backward_surfaces: tuple[str, ...]
+    checkpoint_policy_basis: TemporalCheckpointBasisKind
+
+
+@dataclass(frozen=True)
+class TemporalCheckpointPlan:
+    checkpoint_kind: Literal["none", "explicit", "planner_default"]
+    checkpoint_steps: int | None
+    owner: Literal["planner"]
+
+
+@dataclass(frozen=True)
+class TemporalMaterializationPlan:
+    reverse_artifact_kind: Literal["none", "store_step_artifacts", "forward_reverse_tables", "checkpoint_recompute"]
+    checkpoint_steps: int | None
+    recompute_window_steps: int | None
+    output_materialization: Literal["none", "outputs_only", "outputs_and_final_state"]
+    owner: Literal["planner"]
+    reason: str
+
+
+@dataclass(frozen=True)
+class TemporalGradientBoundaryPlan:
+    mode: Literal["inference", "full_horizon", "rolling_horizon"]
+    horizon_steps: int | None
+    owner: Literal["planner"]
+
+
+@dataclass(frozen=True)
+class TemporalBackwardWindowPlan:
+    window_kind: Literal["none", "full_horizon", "rolling_horizon"]
+    max_window_steps: int | None
+    owner: Literal["planner"]
+
+
+@dataclass(frozen=True)
+class TemporalStaticValuePlan:
+    static_value_mode: Literal[
+        "inference_cache",
+        "pytorch_autograd_static_values",
+        "detached_shared_values",
+        "flat_bucket_autograd_static_values",
+    ]
+    native_static_materialization: bool
+    include_full_cell_kv_weight: bool
+    detach_training_static_tensors: bool
+    owner: Literal["planner"]
+
+
+@dataclass(frozen=True)
+class TemporalExecutorPlan:
+    backend_name: str
+    executor: SequenceSurfaceExecutorKind
+    selected_implementation: SequenceSurfaceImplementationKind
+    surface_key: str | None
+    supported: bool
+    reason: str
+
+
+@dataclass(frozen=True)
+class TemporalEnginePlan:
+    forward_owner: TemporalEngineOwnerKind
+    backward_owner: TemporalEngineOwnerKind
+    checkpoint_owner: Literal["planner"]
+    target_owner: TemporalEngineOwnerKind
+    status: TemporalEngineStatus
+    reason: str
+
+
+@dataclass(frozen=True)
+class TemporalExecutionPlan:
+    sequence_surface_route: SequenceSurfaceRoute
+    schedule: TemporalSchedulePlan
+    substrate: TemporalSubstratePlan
+    boundary: TemporalBoundaryPlan
+    carry: TemporalCarryPlan
+    output_request: TemporalOutputRequestPlan
+    checkpoint: TemporalCheckpointPlan
+    materialization: TemporalMaterializationPlan
+    gradient_boundary: TemporalGradientBoundaryPlan
+    backward_window: TemporalBackwardWindowPlan
+    static_values: TemporalStaticValuePlan
+    executor: TemporalExecutorPlan
+    engine: TemporalEnginePlan
+
+    @property
+    def 
supported(self) -> bool: + return self.executor.supported + + @property + def reason(self) -> str: + return self.executor.reason + + +def temporal_execution_record_metadata(plan: TemporalExecutionPlan | None) -> dict[str, tuple[str, ...]]: + if plan is None: + return {} + return { + "temporal_plan_schedule_kinds": (plan.schedule.schedule_kind,), + "temporal_plan_outer_time_steps": (str(plan.schedule.outer_time_steps),), + "temporal_plan_inner_steps": (_optional_int_metadata(plan.schedule.inner_steps),), + "temporal_plan_total_scan_steps": (_optional_int_metadata(plan.schedule.total_scan_steps),), + "temporal_plan_per_timestep_k": (plan.schedule.per_timestep_k_semantic,), + "temporal_plan_substrate_kinds": (plan.substrate.substrate_kind,), + "temporal_plan_bucket_identity": (plan.substrate.bucket_identity,), + "temporal_plan_resets": (plan.boundary.resets,), + "temporal_plan_output_selectors": (plan.output_request.selector_kind,), + "temporal_plan_output_explicit_outer_steps": (_int_tuple_metadata(plan.output_request.explicit_outer_steps),), + "temporal_plan_output_first_outer_steps": (_optional_int_metadata(plan.output_request.first_outer_step),), + "temporal_plan_output_outer_strides": (_optional_int_metadata(plan.output_request.outer_stride),), + "temporal_plan_output_counts": (str(plan.output_request.emitted_output_count),), + "temporal_plan_output_first_physical_steps": (_optional_int_metadata(plan.output_request.first_physical_step),), + "temporal_plan_output_physical_strides": (_optional_int_metadata(plan.output_request.physical_stride),), + "temporal_plan_output_surfaces": (plan.output_request.output_surface,), + "temporal_plan_readout_surfaces": (plan.output_request.readout_surface,), + "temporal_plan_output_materializations": (plan.output_request.materialization,), + "temporal_plan_autograd_seed_kinds": (plan.output_request.autograd_seed_kind,), + "temporal_plan_required_backward_surfaces": ( + _str_tuple_metadata(plan.output_request.required_backward_surfaces), + ), + "temporal_plan_checkpoint_policy_basis": (plan.output_request.checkpoint_policy_basis,), + "temporal_plan_fresh_state_population_cache": (str(plan.carry.fresh_state_population_cache),), + "temporal_plan_fresh_state_population_cache_reasons": (plan.carry.fresh_state_population_cache_reason,), + "temporal_plan_gradient_boundaries": (plan.gradient_boundary.mode,), + "temporal_plan_horizon_steps": (_optional_int_metadata(plan.gradient_boundary.horizon_steps),), + "temporal_plan_checkpoint_kinds": (plan.checkpoint.checkpoint_kind,), + "temporal_plan_checkpoint_steps": (_optional_int_metadata(plan.checkpoint.checkpoint_steps),), + "temporal_plan_reverse_artifact_kinds": (plan.materialization.reverse_artifact_kind,), + "temporal_plan_recompute_window_steps": (_optional_int_metadata(plan.materialization.recompute_window_steps),), + "temporal_plan_materialization_reasons": (plan.materialization.reason,), + "temporal_plan_backward_windows": (plan.backward_window.window_kind,), + "temporal_plan_static_value_modes": (plan.static_values.static_value_mode,), + "temporal_plan_native_static_materialization": (str(plan.static_values.native_static_materialization),), + "temporal_plan_static_include_full_cell_kv": (str(plan.static_values.include_full_cell_kv_weight),), + "temporal_plan_static_detach_training": (str(plan.static_values.detach_training_static_tensors),), + "temporal_plan_backend_names": (plan.executor.backend_name,), + "temporal_plan_executors": (plan.executor.executor,), + "temporal_plan_selected_implementations": 
(plan.executor.selected_implementation,), + "temporal_plan_reasons": (plan.executor.reason,), + "temporal_plan_forward_owners": (plan.engine.forward_owner,), + "temporal_plan_backward_owners": (plan.engine.backward_owner,), + "temporal_plan_checkpoint_owners": (plan.engine.checkpoint_owner,), + "temporal_plan_target_owners": (plan.engine.target_owner,), + "temporal_plan_engine_statuses": (plan.engine.status,), + "temporal_plan_engine_reasons": (plan.engine.reason,), + } + + +def _optional_int_metadata(value: int | None) -> str: + return "none" if value is None else str(int(value)) + + +def _optional_str_metadata(value: str | None) -> str: + return "none" if value is None else value + + +def _int_tuple_metadata(values: tuple[int, ...]) -> str: + if not values: + return "none" + return ",".join(str(int(value)) for value in values) + + +def _str_tuple_metadata(values: tuple[str, ...]) -> str: + if not values: + return "none" + return "|".join(values) diff --git a/src/cortical/fabric/blueprint.py b/src/cortical/fabric/blueprint.py index 26d75b91..2160fbc1 100644 --- a/src/cortical/fabric/blueprint.py +++ b/src/cortical/fabric/blueprint.py @@ -1,14 +1,23 @@ from __future__ import annotations -from collections.abc import Mapping, Sequence -from typing import Literal +from collections.abc import Mapping +from dataclasses import dataclass, replace +from typing import Any, Literal from pydantic import BaseModel, ConfigDict, Field from cortical.fabric.anatomy import Spec, init from cortical.fabric.cells.declarations import SLSTM, CellDeclaration -from cortical.fabric.config import CellPopulationConfig, Config -from cortical.fabric.graphs import lattice2d +from cortical.fabric.config import ( + CellPopulationConfig, + Config, + FabricInterfaceConfig, + MessageConfig, + PopulationLayoutConfig, + ReadoutConfig, + RuntimeExecutionConfig, +) +from cortical.fabric.graphs import flat, lattice2d from cortical.fabric.message_rules.declarations import DotProduct from cortical.fabric.population import Population from cortical.fabric.runtime import Model @@ -44,6 +53,8 @@ class ExecutionSpec(_FabricDeclaration): backend: Literal["auto", "cuda", "pytorch"] = "auto" inner_steps: int = Field(default=1, ge=0) max_inner_steps: int | None = Field(default=None, ge=0) + gradient_horizon_steps: int | None = Field(default=None, ge=1) + checkpoint_steps: int | None = Field(default=None, ge=1) @property def resolved_max_inner_steps(self) -> int: @@ -53,7 +64,7 @@ def resolved_max_inner_steps(self) -> int: class Blueprint(_FabricDeclaration): interface: Interface - graph: lattice2d.Graph + graph: lattice2d.Graph | flat.Graph | Any inputs: Mapping[str, Input] outputs: Mapping[str, Output] message_passing: DotProduct @@ -78,9 +89,26 @@ def preset(cls, name: str) -> Blueprint: ) +@dataclass(frozen=True) +class _BlueprintLowering: + graph: Any + interface: FabricInterfaceConfig + message: MessageConfig + populations: PopulationLayoutConfig + readout: ReadoutConfig + execution: RuntimeExecutionConfig + message_passing: DotProduct + + def normalize(source: Blueprint) -> Spec: if isinstance(source, Blueprint): - return init(_blueprint_to_config(source)) + lowering = _lower_blueprint_declaration(source) + spec = init(_runtime_section_container(lowering)) + message_rule = lowering.message_passing.to_ir( + kv_group_count=int(spec.num_kv_groups), + cell_count=int(spec.anatomy.num_cells), + ) + return replace(spec, message_rule=message_rule) raise TypeError(f"Fabric normalize expects Blueprint, got {type(source).__name__}") @@ -90,11 
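Note: to make the flattened record fields emitted by `temporal_execution_record_metadata` easier to read back, the sketch below replays the string encoding on one hypothetical schedule slice of a plan. Only the encoding rules ("none" for absent optionals, comma-joined int tuples, pipe-joined string tuples) come from the helpers in this diff; the `per_timestep_k_semantic` value and the numbers are made up, and the underscore-prefixed helpers are imported purely for illustration.

```python
from cortical.fabric.backend.temporal_plan import (
    TemporalSchedulePlan,
    _int_tuple_metadata,
    _optional_int_metadata,
    _str_tuple_metadata,
)

# Hypothetical schedule: 12 outer timesteps with a constant k of 3 inner steps each.
schedule = TemporalSchedulePlan(
    schedule_kind="scalar_constant_k",
    outer_time_steps=12,
    inner_steps=3,
    total_scan_steps=36,
    per_timestep_k_semantic="constant_k",  # free-form string; this value is hypothetical
)

# temporal_execution_record_metadata() stores every plan field as a tuple of strings.
assert _optional_int_metadata(schedule.inner_steps) == "3"
assert _optional_int_metadata(None) == "none"            # absent optional ints become "none"
assert _int_tuple_metadata((0, 4, 8)) == "0,4,8"         # int tuples are comma-joined
assert _str_tuple_metadata(("outputs", "final_state")) == "outputs|final_state"
```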
+118,9 @@ def compile(blueprint: Blueprint) -> Model: return Model(spec, input_dim=input_dim, output_dim=output_dim) -def _blueprint_to_config(blueprint: Blueprint) -> Config: - if not isinstance(blueprint.graph, lattice2d.Graph): - raise ValueError("current Blueprint normalization supports only lattice2d.Graph") +def _lower_blueprint_declaration(blueprint: Blueprint) -> _BlueprintLowering: graph = blueprint.graph - population_specs = dict(graph.populations or {}) + population_specs = _graph_populations(graph) if not population_specs: raise ValueError("Blueprint.graph.populations must not be empty") hidden_size = _resolve_hidden_size(blueprint) @@ -105,77 +131,60 @@ def _blueprint_to_config(blueprint: Blueprint) -> Config: input_nodes = graph.input_nodes() output_nodes = graph.output_nodes() recurrent_nodes = _resolve_recurrent_nodes(graph, input_nodes=input_nodes, output_nodes=output_nodes) - population_node_indices = _resolve_population_node_indices( - graph, - population_specs=population_specs, - recurrent_nodes=recurrent_nodes, - ) + population_node_indices = graph.population_node_indices(recurrent_nodes=recurrent_nodes) population_mix = ( {name: float(len(nodes)) for name, nodes in population_node_indices.items()} if population_node_indices is not None else {next(iter(population_specs)): 1.0} ) - patch_edges = graph.patch_edges() - explicit_edges = graph.explicit_edges() - projection_region_shape = blueprint.message_passing.projection_region_shape(coord_dim=2) - blueprint.message_passing.to_ir( - kv_group_count=_estimated_kv_group_count(graph, projection_region_shape=projection_region_shape), - cell_count=graph.node_count, - ) + projection_region_shape = blueprint.message_passing.projection_region_shape(coord_dim=_graph_coord_dim(graph)) output_groups = graph.output_groups() readout_pool = next(iter(output_groups.values())).aggregate if any(output.aggregate != readout_pool for output in output_groups.values()): raise ValueError("current Blueprint normalization requires one output aggregate kind") - return Config( - width=int(graph.width), - height=int(graph.height), - depth=1, - hidden_size=hidden_size, - d_public=int(blueprint.interface.public_dim), - d_msg=int(blueprint.interface.resolved_message_dim), - d_slot=int(blueprint.interface.resolved_slot_dim), - num_heads=1, - head_dim=int(blueprint.message_passing.head_dim or hidden_size), - local_radius=float(graph.local_radius()), - patch_edges_per_cell=int(patch_edges.per_cell), - patch_min_dist=float(patch_edges.min_distance), - patch_max_dist=float(patch_edges.max_distance), - wrap=bool(graph.wrap), - graph_edges=None if explicit_edges is None else tuple(explicit_edges.edges), - kv_group_ids=None - if explicit_edges is None or explicit_edges.kv_group_ids is None - else tuple(int(group_id) for group_id in explicit_edges.kv_group_ids), - projection_region_shape=projection_region_shape, - cell_populations=cell_populations, - population_mix=population_mix, - population_node_indices=population_node_indices, - input_cell_indices=input_nodes, - output_cell_indices=output_nodes, - readout_pool=readout_pool, - backend=blueprint.execution.backend, - default_k=int(blueprint.execution.inner_steps), - k_max=int(blueprint.execution.resolved_max_inner_steps), + return _BlueprintLowering( + graph=graph, + interface=FabricInterfaceConfig( + hidden_size=hidden_size, + public_dim=int(blueprint.interface.public_dim), + message_dim=int(blueprint.interface.resolved_message_dim), + slot_dim=int(blueprint.interface.resolved_slot_dim), + ), + 
message=MessageConfig( + num_heads=1, + head_dim=int(blueprint.message_passing.head_dim or hidden_size), + projection_region_shape=projection_region_shape, + ), + populations=PopulationLayoutConfig( + cell_populations=cell_populations, + population_mix=population_mix, + population_node_indices=population_node_indices, + ), + readout=ReadoutConfig(pool=readout_pool), + execution=RuntimeExecutionConfig( + backend=blueprint.execution.backend, + gradient_horizon_steps=blueprint.execution.gradient_horizon_steps, + checkpoint_steps=blueprint.execution.checkpoint_steps, + default_k=int(blueprint.execution.inner_steps), + k_max=int(blueprint.execution.resolved_max_inner_steps), + ), + message_passing=blueprint.message_passing, ) -def _estimated_kv_group_count( - graph: lattice2d.Graph, - *, - projection_region_shape: tuple[int, ...] | None, -) -> int: - if projection_region_shape is None: - projection_region_shape = ( - max(1, int(graph.width) // 4), - max(1, int(graph.height) // 4), - ) - x_tile, y_tile = projection_region_shape - x_groups = (int(graph.width) + int(x_tile) - 1) // int(x_tile) - y_groups = (int(graph.height) + int(y_tile) - 1) // int(y_tile) - return int(x_groups * y_groups) +def _runtime_section_container(lowering: _BlueprintLowering) -> Config: + return Config( + graph=lowering.graph, + interface=lowering.interface, + message=lowering.message, + populations=lowering.populations, + readout=lowering.readout, + execution=lowering.execution, + ) def _resolve_hidden_size(blueprint: Blueprint) -> int: - population_specs = blueprint.graph.populations or {} + population_specs = _graph_populations(blueprint.graph) hidden_dims = {int(spec.cell.hidden_dim) for spec in population_specs.values()} if len(hidden_dims) != 1: raise ValueError( @@ -200,8 +209,21 @@ def _population_config_from_declaration(name: str, cell: CellDeclaration) -> Cel return cell.to_population_config() +def _graph_populations(graph: Any) -> dict[str, Population]: + populations = getattr(graph, "populations", None) + if populations is None: + return {} + return {str(name): population for name, population in dict(populations).items()} + + +def _graph_coord_dim(graph: Any) -> int: + if hasattr(graph, "width") and hasattr(graph, "height"): + return 2 + return 1 + + def _resolve_recurrent_nodes( - graph: lattice2d.Graph, + graph: Any, *, input_nodes: tuple[int, ...], output_nodes: tuple[int, ...], @@ -212,60 +234,6 @@ def _resolve_recurrent_nodes( return tuple(node for node in range(graph.node_count) if node not in boundary) -def _resolve_population_node_indices( - graph: lattice2d.Graph, - *, - population_specs: Mapping[str, Population], - recurrent_nodes: tuple[int, ...], -) -> dict[str, tuple[int, ...]] | None: - if len(population_specs) == 1: - _name, spec = next(iter(population_specs.items())) - if spec.nodes is None: - return None - node_sets = graph.named_node_sets() - out: dict[str, tuple[int, ...]] = {} - recurrent_set = set(recurrent_nodes) - for name, spec in population_specs.items(): - selector = spec.nodes - if selector is None: - raise ValueError("multi-population Blueprint requires explicit Population.nodes for every population") - if isinstance(selector, str): - try: - resolved = node_sets[selector] - except KeyError as exc: - raise ValueError(f"Unknown Fabric graph node_set {selector!r} for population {name!r}") from exc - out[name] = tuple(node for node in resolved if node in recurrent_set) - elif isinstance(selector, Sequence) and not isinstance(selector, Mapping): - out[name] = 
graph._resolve_selector(selector) - else: - out[name] = tuple(node for node in graph._resolve_selector(selector) if node in recurrent_set) - _validate_population_coverage(out, recurrent_nodes=recurrent_nodes) - return out - - -def _validate_population_coverage( - population_node_indices: Mapping[str, tuple[int, ...]], - *, - recurrent_nodes: tuple[int, ...], -) -> None: - recurrent_set = set(recurrent_nodes) - seen: set[int] = set() - for population_name, nodes in population_node_indices.items(): - if not nodes: - raise ValueError(f"population {population_name!r} has no recurrent nodes") - node_set = set(nodes) - extra = node_set - recurrent_set - if extra: - raise ValueError(f"population {population_name!r} targets non-recurrent nodes {sorted(extra)[:8]}") - overlap = seen & node_set - if overlap: - raise ValueError(f"population nodes overlap across populations: {sorted(overlap)[:8]}") - seen.update(node_set) - missing = recurrent_set - seen - if missing: - raise ValueError(f"population nodes must cover every recurrent node; missing={sorted(missing)[:8]}") - - def _resolve_external_adapter_dims(blueprint: Blueprint) -> tuple[int, int]: if len(blueprint.inputs) != 1: raise ValueError("current Fabric compile supports exactly one external input adapter") diff --git a/src/cortical/fabric/cells/axon.py b/src/cortical/fabric/cells/axon.py index 96ba12d4..c9e809c9 100644 --- a/src/cortical/fabric/cells/axon.py +++ b/src/cortical/fabric/cells/axon.py @@ -60,6 +60,8 @@ def _axon_activation_id(name: str) -> int: "dynamics_theta_log", "dynamics_w1", "dynamics_w2", + "state_norm_weight", + "outnorm_eps", "activation_id", ) ), @@ -79,6 +81,8 @@ def _axon_activation_id(name: str) -> int: "dynamics_theta_log": SharingScope("receiver_local"), "dynamics_w1": SharingScope("receiver_local"), "dynamics_w2": SharingScope("receiver_local"), + "state_norm_weight": SharingScope("receiver_local"), + "outnorm_eps": SharingScope("fabric_global"), "activation_id": SharingScope("fabric_global"), "recurrent_kv_projection_weight": SharingScope("receiver_local"), "recurrent_kv_projection_bias": SharingScope("receiver_local"), @@ -113,6 +117,7 @@ def _materialize_params( "w2": base_plus_delta(module.w2_base, module.w2_delta), "out_proj_weight": base_plus_delta(module.out_proj_weight_base, module.out_proj_weight_delta), "out_proj_bias": base_plus_delta(module.out_proj_bias_base, module.out_proj_bias_delta), + "outnorm_weight": base_plus_delta(module.outnorm_weight_base, module.outnorm_weight_delta), "input_proj_weight": None, } params["nu_log_flat"] = params["nu_log"].reshape(-1) @@ -125,6 +130,9 @@ def _materialize_params( params["dynamics_theta_log"] = params["theta_log"] params["dynamics_w1"] = params["w1"] params["dynamics_w2"] = params["w2"] + params["state_norm_weight"] = params["outnorm_weight"] + params["outnorm_weight_flat"] = params["outnorm_weight"].reshape(-1) + params["outnorm_eps"] = params["nu_log"].new_tensor([module.outnorm_eps], dtype=torch.float32) params["recurrent_hidden_projection_weight"] = params["out_proj_weight_t"] params["recurrent_hidden_projection_bias"] = params["out_proj_bias"] params["activation_id"] = params["nu_log"].new_tensor( @@ -180,6 +188,7 @@ def build_axon_cell_population( module.activation_name = str(config.activation) module.use_input_proj = True + module.outnorm_eps = 1.0e-6 nu_log = torch.zeros(module.hidden_size) theta_log = torch.zeros(module.hidden_size) w1 = torch.empty(module.hidden_size) @@ -191,6 +200,7 @@ def build_axon_cell_population( 
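Note: for orientation on where the new `outnorm_weight` / `outnorm_eps` Axon parameters plug in, here is a minimal sketch of an output-side norm epilogue. The RMS-style formula is an assumption used only for illustration; the actual norm kernel is not part of this diff, and only the parameter names, the ones-initialization, and the `1.0e-6` epsilon are taken from the changes shown here.

```python
import torch


def axon_outnorm_epilogue_sketch(
    y_raw: torch.Tensor,           # [num_cells, hidden] raw projected output
    outnorm_weight: torch.Tensor,  # [num_cells, hidden] learned scale, initialized to ones
    outnorm_eps: float = 1.0e-6,   # matches module.outnorm_eps in this diff
) -> torch.Tensor:
    # Assumed RMS-style normalization per cell; at the ones initialization the
    # learned weight only rescales the normalized output, it does not shift it.
    rms = y_raw.pow(2).mean(dim=-1, keepdim=True).add(outnorm_eps).sqrt()
    return outnorm_weight * (y_raw / rms)
```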
torch.nn.init.xavier_uniform_(out_proj_weight) torch.nn.init.xavier_uniform_(input_proj_weight) out_proj_bias = torch.zeros(module.hidden_size) + outnorm_weight = torch.ones(module.hidden_size) module.nu_log_base, module.nu_log_delta = init_base_plus_delta(nu_log, module.num_cells, module.init_noise_std) module.theta_log_base, module.theta_log_delta = init_base_plus_delta( @@ -210,6 +220,11 @@ def build_axon_cell_population( module.num_cells, module.init_noise_std, ) + module.outnorm_weight_base, module.outnorm_weight_delta = init_base_plus_delta( + outnorm_weight, + module.num_cells, + module.init_noise_std, + ) module.input_proj_weight_base, module.input_proj_weight_delta = init_base_plus_delta( input_proj_weight, module.num_cells, diff --git a/src/cortical/fabric/config.py b/src/cortical/fabric/config.py index 47b8ec2c..164677d3 100644 --- a/src/cortical/fabric/config.py +++ b/src/cortical/fabric/config.py @@ -1,8 +1,12 @@ from __future__ import annotations -from typing import Literal +from typing import Any, Literal -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator + + +class _ConfigSection(BaseModel): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) class CellPopulationConfig(BaseModel): @@ -11,162 +15,104 @@ class CellPopulationConfig(BaseModel): activation: Literal["silu", "relu", "tanh", "linear"] = "silu" -class Config(BaseModel): - width: int = Field(ge=1) - height: int = Field(ge=1) - depth: int = Field(default=1, ge=1) - +class FabricInterfaceConfig(_ConfigSection): hidden_size: int = Field(default=8, ge=1) - d_public: int | None = Field(default=None, ge=1) - d_msg: int | None = Field(default=None, ge=1) - d_slot: int | None = Field(default=None, ge=1) + public_dim: int | None = Field(default=None, ge=1) + message_dim: int | None = Field(default=None, ge=1) + slot_dim: int | None = Field(default=None, ge=1) + @property + def resolved_public_dim(self) -> int: + return int(self.hidden_size if self.public_dim is None else self.public_dim) + + @property + def resolved_message_dim(self) -> int: + return int(self.hidden_size if self.message_dim is None else self.message_dim) + + @property + def resolved_slot_dim(self) -> int: + return int(2 * self.hidden_size if self.slot_dim is None else self.slot_dim) + + +class MessageConfig(_ConfigSection): num_heads: int | None = Field(default=None, ge=1) head_dim: int = Field(default=4, ge=1) - - local_radius: float = Field(default=1.5, gt=0.0) - patch_edges_per_cell: int = Field(default=0, ge=0) - patch_min_dist: float = Field(default=4.0, ge=0.0) - patch_max_dist: float = Field(default=12.0, ge=0.0) distance_logit_scale: float = Field(default=0.5, ge=0.0) - wrap: bool = Field(default=True) + projection_region_shape: tuple[int, ...] | None = None - conduction_speed: float | None = Field(default=None, gt=0.0) - max_delay: int | None = Field(default=None, ge=1) + @property + def resolved_num_heads(self) -> int: + return 1 if self.num_heads is None else int(self.num_heads) - graph_edges: tuple[tuple[int, int], ...] | None = None - kv_group_ids: tuple[int, ...] | None = None - projection_region_shape: tuple[int, ...] 
| None = None - cell_arrangement: Literal["random", "x_bands"] = Field(default="random") +class PopulationLayoutConfig(_ConfigSection): cell_populations: dict[str, CellPopulationConfig] = Field( default_factory=lambda: {"slstm": CellPopulationConfig(cell_type="slstm")} ) population_mix: dict[str, float] = Field(default_factory=lambda: {"slstm": 1.0}) population_node_indices: dict[str, tuple[int, ...]] | None = None + cell_arrangement: Literal["random", "x_bands"] = "random" - input_band_width: int = Field(default=1, ge=1) - output_band_width: int = Field(default=1, ge=1) - input_cell_indices: tuple[int, ...] | None = None - output_cell_indices: tuple[int, ...] | None = None - readout_pool: Literal["mean", "attn", "flatten"] = Field(default="mean") - readout_slots: int = Field(default=4, ge=1) - backend: Literal["auto", "cuda", "pytorch"] = Field(default="auto") +class ReadoutConfig(_ConfigSection): + pool: Literal["mean", "attn", "flatten"] = Field(default="mean") + slots: int = Field(default=4, ge=1) + + +class RuntimeExecutionConfig(_ConfigSection): + backend: Literal["auto", "cuda", "pytorch"] = Field(default="auto") + gradient_horizon_steps: int | None = Field(default=None, ge=1) + checkpoint_steps: int | None = Field(default=None, ge=1) k_max: int = Field(default=8, ge=0) default_k: int = Field(default=4, ge=0) inject_every_step: bool = Field(default=True) + + @model_validator(mode="after") + def _validate_steps(self) -> RuntimeExecutionConfig: + if self.default_k > self.k_max: + raise ValueError(f"default_k={self.default_k} must be <= k_max={self.k_max}") + return self + + +class InitializationConfig(_ConfigSection): population_init_noise_std: float = Field(default=0.0, ge=0.0) seed: int = Field(default=0) - @property - def coord_shape(self) -> tuple[int, ...]: - if self.depth == 1: - return (self.width, self.height) - return (self.width, self.height, self.depth) - @property - def coord_dim(self) -> int: - return 2 if self.depth == 1 else 3 +class Config(BaseModel): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + graph: Any + interface: FabricInterfaceConfig = Field(default_factory=FabricInterfaceConfig) + message: MessageConfig = Field(default_factory=MessageConfig) + populations: PopulationLayoutConfig = Field(default_factory=PopulationLayoutConfig) + readout: ReadoutConfig = Field(default_factory=ReadoutConfig) + execution: RuntimeExecutionConfig = Field(default_factory=RuntimeExecutionConfig) + initialization: InitializationConfig = Field(default_factory=InitializationConfig) @model_validator(mode="after") - def _apply_defaults_and_validate(self) -> Config: - if self.d_public is None: - self.d_public = self.hidden_size - if self.d_msg is None: - self.d_msg = self.hidden_size - if self.d_slot is None: - self.d_slot = 2 * self.hidden_size - if self.num_heads is None: - self.num_heads = 1 - - if self.d_msg <= 0 or self.d_public <= 0 or self.d_slot <= 0: - raise ValueError("derived dimensions must be positive") - if self.default_k > self.k_max: - raise ValueError(f"default_k={self.default_k} must be <= k_max={self.k_max}") - if not self.cell_populations: - raise ValueError("cell_populations must not be empty") - if len(self.cell_populations) > 2: - raise ValueError("v1 supports at most two fabric cell populations") - population_keys = set(self.cell_populations.keys()) - mix_keys = set(self.population_mix.keys()) + def _validate_population_sections(self) -> Config: + if not self.populations.cell_populations: + raise 
ValueError("populations.cell_populations must not be empty") + population_keys = set(self.populations.cell_populations.keys()) + mix_keys = set(self.populations.population_mix.keys()) if population_keys != mix_keys: raise ValueError( - f"cell_populations keys {sorted(population_keys)} must match population_mix keys {sorted(mix_keys)}" - ) - mix_total = float(sum(self.population_mix.values())) - if mix_total <= 0.0: - raise ValueError("population_mix weights must sum to a positive value") - if self.projection_region_shape is not None and len(self.projection_region_shape) != self.coord_dim: - raise ValueError( - "projection_region_shape length " - f"{len(self.projection_region_shape)} must match coord_dim={self.coord_dim}" + "populations.cell_populations keys " + f"{sorted(population_keys)} must match populations.population_mix keys {sorted(mix_keys)}" ) - if self.population_node_indices is not None: - node_keys = set(self.population_node_indices.keys()) - if node_keys != population_keys: - raise ValueError( - f"population_node_indices keys {sorted(node_keys)} must match " - f"cell_populations keys {sorted(population_keys)}" - ) - if self.patch_edges_per_cell > 0 and self.patch_max_dist < self.patch_min_dist: - raise ValueError(f"patch_max_dist={self.patch_max_dist} must be >= patch_min_dist={self.patch_min_dist}") - total_cells = self.width * self.height * self.depth - if self.population_node_indices is not None: - seen_population_nodes: set[int] = set() - for population_name, indices in self.population_node_indices.items(): - if not indices: - raise ValueError(f"population_node_indices[{population_name!r}] must not be empty") - if len(set(indices)) != len(indices): - raise ValueError(f"population_node_indices[{population_name!r}] must not contain duplicates") - bad = [idx for idx in indices if idx < 0 or idx >= total_cells] - if bad: - raise ValueError( - f"population_node_indices[{population_name!r}] contains node ids outside " - f"[0, {total_cells}): {bad[:8]}" - ) - overlap = seen_population_nodes & set(indices) - if overlap: - raise ValueError( - f"population_node_indices must not overlap across populations; overlap={sorted(overlap)[:8]}" - ) - seen_population_nodes.update(indices) - for name, indices in ( - ("input_cell_indices", self.input_cell_indices), - ("output_cell_indices", self.output_cell_indices), - ): - if indices is None: - continue - if not indices: - raise ValueError(f"{name} must not be empty") - if len(set(indices)) != len(indices): - raise ValueError(f"{name} must not contain duplicate node ids") - bad = [idx for idx in indices if idx < 0 or idx >= total_cells] - if bad: - raise ValueError(f"{name} contains node ids outside [0, {total_cells}): {bad[:8]}") - if self.input_cell_indices is not None and self.output_cell_indices is not None: - overlap = set(self.input_cell_indices) & set(self.output_cell_indices) - if overlap: - overlap_sample = sorted(overlap)[:8] - raise ValueError( - f"input_cell_indices and output_cell_indices must be disjoint; overlap={overlap_sample}" - ) - if self.graph_edges is not None: - if not self.graph_edges: - raise ValueError("graph_edges must not be empty") - if len(set(self.graph_edges)) != len(self.graph_edges): - raise ValueError("graph_edges must not contain duplicate receiver/sender pairs") - for receiver, sender in self.graph_edges: - if receiver < 0 or receiver >= total_cells or sender < 0 or sender >= total_cells: - raise ValueError(f"graph_edges contains node ids outside [0, {total_cells}): {(receiver, sender)}") - if receiver == 
sender: - raise ValueError(f"graph_edges must not contain self edges: {(receiver, sender)}") - if self.kv_group_ids is not None: - if len(self.kv_group_ids) != total_cells: - raise ValueError(f"kv_group_ids length must be {total_cells}, got {len(self.kv_group_ids)}") - if min(self.kv_group_ids) < 0: - raise ValueError("kv_group_ids must be non-negative") + if float(sum(self.populations.population_mix.values())) <= 0.0: + raise ValueError("populations.population_mix weights must sum to a positive value") return self -__all__ = ["CellPopulationConfig", "Config"] +__all__ = [ + "CellPopulationConfig", + "Config", + "FabricInterfaceConfig", + "InitializationConfig", + "MessageConfig", + "PopulationLayoutConfig", + "ReadoutConfig", + "RuntimeExecutionConfig", +] diff --git a/src/cortical/fabric/graphs/__init__.py b/src/cortical/fabric/graphs/__init__.py index 0512ec40..9f76c1cd 100644 --- a/src/cortical/fabric/graphs/__init__.py +++ b/src/cortical/fabric/graphs/__init__.py @@ -1,3 +1,3 @@ -from cortical.fabric.graphs import lattice2d +from cortical.fabric.graphs import flat, lattice2d -__all__ = ["lattice2d"] +__all__ = ["flat", "lattice2d"] diff --git a/src/cortical/fabric/graphs/flat.py b/src/cortical/fabric/graphs/flat.py new file mode 100644 index 00000000..c8ad775d --- /dev/null +++ b/src/cortical/fabric/graphs/flat.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from collections.abc import Iterable, Mapping, Sequence +from dataclasses import dataclass +from typing import Literal + +from cortical.fabric.population import Population + + +@dataclass(frozen=True) +class Output: + nodes: Sequence[int] + aggregate: Literal["mean"] = "mean" + + +@dataclass(frozen=True) +class Graph: + """User-owned flat graph declaration. + + Nodes are integer ids in [0, node_count). Edges are stored as + (receiver, sender) pairs, matching the flat graph topology consumed by + Fabric planning. + """ + + node_count: int + edges: Sequence[tuple[int, int]] + populations: Mapping[str, Population] + inputs: Mapping[str, Sequence[int]] | Sequence[int] + outputs: Mapping[str, Output | Sequence[int]] | Output | Sequence[int] + recurrent: Sequence[int] | None = None + kv_group_ids: Sequence[int] | None = None + + def input_groups(self) -> dict[str, tuple[int, ...]]: + if isinstance(self.inputs, Mapping): + return { + str(name): _normalize_nodes(nodes, node_count=self.node_count) for name, nodes in self.inputs.items() + } + return {"default": _normalize_nodes(self.inputs, node_count=self.node_count)} + + def output_groups(self) -> dict[str, Output]: + raw = self.outputs + if isinstance(raw, Output): + return {"default": raw} + if isinstance(raw, Mapping): + return {str(name): _normalize_output(value) for name, value in raw.items()} + return {"default": Output(raw)} + + def output_node_groups(self) -> dict[str, tuple[int, ...]]: + return { + name: _normalize_nodes(output.nodes, node_count=self.node_count) + for name, output in self.output_groups().items() + } + + def input_nodes(self) -> tuple[int, ...]: + return _unique_sorted(node for nodes in self.input_groups().values() for node in nodes) + + def output_nodes(self) -> tuple[int, ...]: + return _unique_sorted(node for nodes in self.output_node_groups().values() for node in nodes) + + def recurrent_nodes(self) -> tuple[int, ...] 
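Note: since the `Config` restructuring above replaces the flat lattice-centric fields with per-section models, a short construction sketch may help. The dimensions are arbitrary, the default `lattice2d.Graph()` is used only as a placeholder graph object, and the import paths are the ones introduced in this diff.

```python
from cortical.fabric.config import (
    Config,
    FabricInterfaceConfig,
    MessageConfig,
    ReadoutConfig,
    RuntimeExecutionConfig,
)
from cortical.fabric.graphs import lattice2d

cfg = Config(
    graph=lattice2d.Graph(),  # any graph declaration; lattice geometry no longer lives on Config
    interface=FabricInterfaceConfig(hidden_size=16),
    message=MessageConfig(head_dim=8),
    readout=ReadoutConfig(pool="mean"),
    execution=RuntimeExecutionConfig(backend="auto", default_k=2, k_max=4),
)

assert cfg.interface.resolved_slot_dim == 32  # defaults to 2 * hidden_size when slot_dim is None
assert cfg.populations.cell_populations.keys() == cfg.populations.population_mix.keys() == {"slstm"}
```

The population sections fall back to their single-`slstm` defaults here, so the `_validate_population_sections` check passes without any explicit population layout.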
| None: + if self.recurrent is None: + return None + return _normalize_nodes(self.recurrent, node_count=self.node_count) + + def explicit_edges(self) -> tuple[tuple[int, int], ...]: + edges = tuple((int(receiver), int(sender)) for receiver, sender in self.edges) + for receiver, sender in edges: + if receiver < 0 or receiver >= int(self.node_count) or sender < 0 or sender >= int(self.node_count): + raise ValueError( + f"flat.Graph edge ({receiver}, {sender}) contains node outside [0, {int(self.node_count)})" + ) + return edges + + def population_node_indices(self, *, recurrent_nodes: tuple[int, ...]) -> dict[str, tuple[int, ...]] | None: + if len(self.populations) == 1: + _name, population = next(iter(self.populations.items())) + if population.nodes is None: + return None + recurrent_set = set(int(node) for node in recurrent_nodes) + out: dict[str, tuple[int, ...]] = {} + for name, population in self.populations.items(): + if population.nodes is None: + raise ValueError("multi-population flat.Graph requires explicit Population.nodes for every population") + nodes = _normalize_nodes(population.nodes, node_count=self.node_count) + out[str(name)] = tuple(node for node in nodes if node in recurrent_set) + _validate_population_coverage(out, recurrent_nodes=recurrent_nodes) + return out + + +def _normalize_output(output: Output | Sequence[int]) -> Output: + return output if isinstance(output, Output) else Output(output) + + +def _normalize_nodes(nodes: Sequence[int], *, node_count: int) -> tuple[int, ...]: + values = tuple(int(node) for node in nodes) + if not values: + raise ValueError("flat.Graph node set must not be empty") + if len(set(values)) != len(values): + raise ValueError("flat.Graph node set must not contain duplicate nodes") + bad = [node for node in values if node < 0 or node >= int(node_count)] + if bad: + raise ValueError(f"flat.Graph node set contains nodes outside [0, {int(node_count)}): {bad[:8]}") + return values + + +def _unique_sorted(nodes: Iterable[int]) -> tuple[int, ...]: + return tuple(sorted(set(int(node) for node in nodes))) + + +def _validate_population_coverage( + population_node_indices: Mapping[str, tuple[int, ...]], + *, + recurrent_nodes: tuple[int, ...], +) -> None: + recurrent_set = set(recurrent_nodes) + seen: set[int] = set() + for population_name, nodes in population_node_indices.items(): + if not nodes: + raise ValueError(f"population {population_name!r} has no recurrent nodes") + node_set = set(nodes) + extra = node_set - recurrent_set + if extra: + raise ValueError(f"population {population_name!r} targets non-recurrent nodes {sorted(extra)[:8]}") + overlap = seen & node_set + if overlap: + raise ValueError(f"population nodes overlap across populations: {sorted(overlap)[:8]}") + seen.update(node_set) + missing = recurrent_set - seen + if missing: + raise ValueError(f"population nodes must cover every recurrent node; missing={sorted(missing)[:8]}") + + +__all__ = ["Graph", "Output"] diff --git a/src/cortical/fabric/graphs/lattice2d.py b/src/cortical/fabric/graphs/lattice2d.py index d1a8e720..13ff7641 100644 --- a/src/cortical/fabric/graphs/lattice2d.py +++ b/src/cortical/fabric/graphs/lattice2d.py @@ -60,6 +60,8 @@ class Graph: width: int = 8 height: int = 8 wrap: bool = True + conduction_speed: float | None = None + max_delay: int | None = None populations: Mapping[str, Population] | None = None inputs: InputSpec = None outputs: OutputSpec = None @@ -124,6 +126,32 @@ def explicit_edges(self) -> ExplicitEdges | None: return item return None + def 
population_node_indices(self, *, recurrent_nodes: tuple[int, ...]) -> dict[str, tuple[int, ...]] | None: + population_specs = dict(self.populations or {}) + if len(population_specs) == 1: + _name, spec = next(iter(population_specs.items())) + if spec.nodes is None: + return None + node_sets = self.named_node_sets() + out: dict[str, tuple[int, ...]] = {} + recurrent_set = set(recurrent_nodes) + for name, spec in population_specs.items(): + selector = spec.nodes + if selector is None: + raise ValueError( + "multi-population lattice2d.Graph requires explicit Population.nodes for every population" + ) + if isinstance(selector, str): + try: + resolved = node_sets[selector] + except KeyError as exc: + raise ValueError(f"Unknown Fabric graph node_set {selector!r} for population {name!r}") from exc + out[name] = tuple(node for node in resolved if node in recurrent_set) + else: + out[name] = tuple(node for node in self._resolve_selector(selector) if node in recurrent_set) + _validate_population_coverage(out, recurrent_nodes=recurrent_nodes) + return out + def _resolve_selector(self, selector: NodeSelector) -> tuple[int, ...]: if isinstance(selector, Mapping): selector = _selector_from_mapping(selector) @@ -184,6 +212,29 @@ def _normalize_node_sequence(nodes: Sequence[int], *, node_count: int) -> tuple[ return values +def _validate_population_coverage( + population_node_indices: Mapping[str, tuple[int, ...]], + *, + recurrent_nodes: tuple[int, ...], +) -> None: + recurrent_set = set(recurrent_nodes) + seen: set[int] = set() + for population_name, nodes in population_node_indices.items(): + if not nodes: + raise ValueError(f"population {population_name!r} has no recurrent nodes") + node_set = set(nodes) + extra = node_set - recurrent_set + if extra: + raise ValueError(f"population {population_name!r} targets non-recurrent nodes {sorted(extra)[:8]}") + overlap = seen & node_set + if overlap: + raise ValueError(f"population nodes overlap across populations: {sorted(overlap)[:8]}") + seen.update(node_set) + missing = recurrent_set - seen + if missing: + raise ValueError(f"population nodes must cover every recurrent node; missing={sorted(missing)[:8]}") + + def _selector_from_mapping(selector: Mapping[str, object]) -> NodeSelector: keys = set(selector) if keys <= {"x", "y"}: diff --git a/src/cortical/fabric/graphs/lattice_anatomy.py b/src/cortical/fabric/graphs/lattice_anatomy.py new file mode 100644 index 00000000..36862d1b --- /dev/null +++ b/src/cortical/fabric/graphs/lattice_anatomy.py @@ -0,0 +1,634 @@ +from __future__ import annotations + +import math +from dataclasses import dataclass +from functools import lru_cache +from itertools import product +from typing import Any + +import torch + +from cortical.fabric.config import Config +from cortical.fabric.graph import normalize_node_indices + + +@dataclass(frozen=True) +class LatticeAnatomyConfig: + coord_shape: tuple[int, ...] + coord_dim: int + local_radius: float + patch_edges_per_cell: int + patch_min_dist: float + patch_max_dist: float + wrap: bool + conduction_speed: float | None + max_delay: int | None + graph_edges: tuple[tuple[int, int], ...] | None + kv_group_ids: tuple[int, ...] | None + projection_region_shape: tuple[int, ...] | None + cell_arrangement: str + population_mix: dict[str, float] + population_node_indices: dict[str, tuple[int, ...]] | None + input_band_width: int + output_band_width: int + input_cell_indices: tuple[int, ...] | None + output_cell_indices: tuple[int, ...] 
| None + hidden_size: int + d_slot: int + population_count: int + seed: int + + +def lattice_anatomy_config_from_graph(graph: Any, cfg: Config) -> LatticeAnatomyConfig: + input_nodes = tuple(int(node) for node in graph.input_nodes()) + output_nodes = tuple(int(node) for node in graph.output_nodes()) + recurrent_nodes = _resolve_graph_recurrent_nodes(graph, input_nodes=input_nodes, output_nodes=output_nodes) + population_node_indices = cfg.populations.population_node_indices + graph_populations = getattr(graph, "populations", None) + if population_node_indices is None and graph_populations and hasattr(graph, "population_node_indices"): + population_node_indices = graph.population_node_indices(recurrent_nodes=recurrent_nodes) + explicit_edges = graph.explicit_edges() if hasattr(graph, "explicit_edges") else None + graph_edges = None + kv_group_ids = None + if explicit_edges is not None: + graph_edges = ( + tuple((int(receiver), int(sender)) for receiver, sender in explicit_edges) + if isinstance(explicit_edges, tuple) + else tuple((int(receiver), int(sender)) for receiver, sender in explicit_edges.edges) + ) + kv_group_ids = getattr(explicit_edges, "kv_group_ids", None) + if getattr(graph, "kv_group_ids", None) is not None: + kv_group_ids = tuple(int(group_id) for group_id in graph.kv_group_ids) + coord_shape = _graph_coord_shape(graph) + coord_dim = len(coord_shape) + patch_edges = graph.patch_edges() if hasattr(graph, "patch_edges") else None + return LatticeAnatomyConfig( + coord_shape=coord_shape, + coord_dim=coord_dim, + local_radius=float(graph.local_radius() if hasattr(graph, "local_radius") else 0.0), + patch_edges_per_cell=0 if patch_edges is None else int(patch_edges.per_cell), + patch_min_dist=4.0 if patch_edges is None else float(patch_edges.min_distance), + patch_max_dist=12.0 if patch_edges is None else float(patch_edges.max_distance), + wrap=bool(getattr(graph, "wrap", False)), + conduction_speed=None if getattr(graph, "conduction_speed", None) is None else float(graph.conduction_speed), + max_delay=None if getattr(graph, "max_delay", None) is None else int(graph.max_delay), + graph_edges=graph_edges, + kv_group_ids=None if kv_group_ids is None else tuple(int(group_id) for group_id in kv_group_ids), + projection_region_shape=cfg.message.projection_region_shape, + cell_arrangement=str(cfg.populations.cell_arrangement), + population_mix=dict(cfg.populations.population_mix), + population_node_indices=population_node_indices, + input_band_width=1, + output_band_width=1, + input_cell_indices=input_nodes, + output_cell_indices=output_nodes, + hidden_size=int(cfg.interface.hidden_size), + d_slot=int(cfg.interface.resolved_slot_dim), + population_count=len(cfg.populations.cell_populations), + seed=int(cfg.initialization.seed), + ) + + +def _graph_coord_shape(graph: Any) -> tuple[int, ...]: + if hasattr(graph, "width") and hasattr(graph, "height"): + return (int(graph.width), int(graph.height)) + return (int(graph.node_count),) + + +def _resolve_graph_recurrent_nodes( + graph: Any, + *, + input_nodes: tuple[int, ...], + output_nodes: tuple[int, ...], +) -> tuple[int, ...]: + recurrent_nodes = graph.recurrent_nodes() if hasattr(graph, "recurrent_nodes") else None + if recurrent_nodes is not None: + return tuple(int(node) for node in recurrent_nodes) + boundary = set(input_nodes) | set(output_nodes) + return tuple(node for node in range(int(graph.node_count)) if node not in boundary) + + +def build_lattice_coords(cfg: LatticeAnatomyConfig) -> torch.Tensor: + axes = [torch.arange(size, 
dtype=torch.float32) for size in cfg.coord_shape] + if len(axes) == 1: + return axes[0].view(-1, 1) + return torch.cartesian_prod(*axes) + + +def assign_lattice_cells( + cfg: LatticeAnatomyConfig, + *, + recurrent_cell_idx: torch.Tensor, + recurrent_coords: torch.Tensor, + population_names: tuple[str, ...], +) -> torch.Tensor: + num_cells = recurrent_coords.shape[0] + if cfg.population_node_indices is not None: + return _assign_explicit_lattice_population_nodes( + cfg, + recurrent_cell_idx=recurrent_cell_idx, + population_names=population_names, + ) + weights = torch.tensor([cfg.population_mix[name] for name in population_names], dtype=torch.float64) + weights = weights / weights.sum() + expected = weights * float(num_cells) + counts = expected.floor().to(torch.long) + remainder = int(num_cells - counts.sum().item()) + if remainder > 0: + frac = expected - counts.to(expected.dtype) + order = torch.argsort(frac, descending=True) + counts[order[:remainder]] += 1 + labels = [] + for population_idx, count in enumerate(counts.tolist()): + labels.extend([population_idx] * count) + layout = torch.tensor(labels, dtype=torch.long) + if cfg.cell_arrangement == "x_bands": + order = torch.argsort(_lexsort_key(recurrent_coords)) + arranged = torch.empty(num_cells, dtype=torch.long) + arranged[order] = layout + return arranged + gen = torch.Generator(device="cpu") + gen.manual_seed(cfg.seed) + perm = torch.randperm(num_cells, generator=gen) + return layout.index_select(0, perm) + + +def build_lattice_sparse_graph( + cfg: LatticeAnatomyConfig, + coords: torch.Tensor, + *, + input_cell_idx: torch.Tensor, + output_cell_idx: torch.Tensor, +) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor | None, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor | None, +]: + if cfg.graph_edges is not None: + return _build_explicit_sparse_graph( + cfg, + coords, + input_cell_idx=input_cell_idx, + output_cell_idx=output_cell_idx, + ) + num_cells = coords.shape[0] + coord_dim = coords.shape[1] + coords_long = coords.to(torch.long) + shape = cfg.coord_shape + shape_tensor = torch.tensor(shape, dtype=torch.long) + stride_tensor = torch.tensor(_flat_index_strides(shape), dtype=torch.long) + input_mask = torch.zeros(num_cells, dtype=torch.bool) + input_mask[input_cell_idx] = True + output_mask = torch.zeros(num_cells, dtype=torch.bool) + output_mask[output_cell_idx] = True + recv_idx = torch.arange(num_cells, dtype=torch.long) + local_offsets = _neighbor_offsets(coord_dim=coord_dim, min_distance=0.0, max_distance=cfg.local_radius) + patch_offsets = ( + _neighbor_offsets(coord_dim=coord_dim, min_distance=cfg.patch_min_dist, max_distance=cfg.patch_max_dist) + if cfg.patch_edges_per_cell > 0 + else () + ) + local_offset_tensor = torch.tensor([delta for delta, _distance in local_offsets], dtype=torch.int32) + local_valid = torch.zeros(num_cells, len(local_offsets), dtype=torch.bool) + local_distance = torch.tensor([distance for _delta, distance in local_offsets], dtype=coords.dtype) + local_delay = ( + torch.tensor( + [_edge_delay(distance=distance, cfg=cfg) for _delta, distance in local_offsets], + dtype=torch.int32, + ) + if cfg.max_delay is not None + else None + ) + max_slots = len(local_offsets) + cfg.patch_edges_per_cell + if max_slots == 0: + raise ValueError("fabric graph has no edges; increase local_radius or change anatomy size") + + neighbor_idx = torch.zeros(num_cells, max_slots, dtype=torch.long) + neighbor_valid = torch.zeros(num_cells, max_slots, dtype=torch.bool) + 
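Note: the population assignment in `assign_lattice_cells` above rounds `population_mix` into integer per-population counts with a floor-plus-largest-remainder rule. The toy numbers below are hypothetical and simply replay that arithmetic outside the anatomy code.

```python
import torch

# Hypothetical mix of two populations weighted 2:1 over 10 recurrent cells.
num_cells = 10
weights = torch.tensor([2.0, 1.0], dtype=torch.float64)
weights = weights / weights.sum()
expected = weights * float(num_cells)       # tensor([6.6667, 3.3333])
counts = expected.floor().to(torch.long)    # tensor([6, 3]) -> one cell still unassigned
remainder = int(num_cells - counts.sum().item())
if remainder > 0:
    frac = expected - counts.to(expected.dtype)
    order = torch.argsort(frac, descending=True)
    counts[order[:remainder]] += 1          # largest fractional remainder wins the leftover cell
assert counts.tolist() == [7, 3]
assert int(counts.sum().item()) == num_cells
```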
edge_type = torch.zeros(num_cells, max_slots, dtype=torch.long) + edge_distance = torch.zeros(num_cells, max_slots, dtype=coords.dtype) + edge_delay = torch.ones(num_cells, max_slots, dtype=torch.long) if cfg.max_delay is not None else None + neighbor_counts = torch.zeros(num_cells, dtype=torch.long) + patch_counts = torch.zeros(num_cells, dtype=torch.long) if cfg.patch_edges_per_cell > 0 else None + max_neighbors = 0 + for offset_idx, (delta, distance) in enumerate(local_offsets): + max_neighbors = _append_offset_neighbors( + coords=coords_long, + shape_tensor=shape_tensor, + stride_tensor=stride_tensor, + wrap=cfg.wrap, + recv_idx=recv_idx, + input_mask=input_mask, + output_mask=output_mask, + delta=delta, + distance=distance, + edge_kind=0, + patch_limit=None, + neighbor_idx=neighbor_idx, + neighbor_valid=neighbor_valid, + edge_type=edge_type, + edge_distance=edge_distance, + edge_delay=edge_delay, + neighbor_counts=neighbor_counts, + patch_counts=patch_counts, + delay_value=_edge_delay(distance=distance, cfg=cfg), + current_max_neighbors=max_neighbors, + offset_valid=local_valid, + offset_slot=offset_idx, + ) + for delta, distance in patch_offsets: + max_neighbors = _append_offset_neighbors( + coords=coords_long, + shape_tensor=shape_tensor, + stride_tensor=stride_tensor, + wrap=cfg.wrap, + recv_idx=recv_idx, + input_mask=input_mask, + output_mask=output_mask, + delta=delta, + distance=distance, + edge_kind=1, + patch_limit=cfg.patch_edges_per_cell, + neighbor_idx=neighbor_idx, + neighbor_valid=neighbor_valid, + edge_type=edge_type, + edge_distance=edge_distance, + edge_delay=edge_delay, + neighbor_counts=neighbor_counts, + patch_counts=patch_counts, + delay_value=_edge_delay(distance=distance, cfg=cfg), + current_max_neighbors=max_neighbors, + ) + + if max_neighbors == 0: + raise ValueError("fabric graph has no edges; increase local_radius or change anatomy size") + edge_delay_out = edge_delay[:, :max_neighbors] if edge_delay is not None else None + return ( + local_offset_tensor, + local_valid, + local_distance, + local_delay, + neighbor_idx[:, :max_neighbors], + neighbor_valid[:, :max_neighbors], + edge_type[:, :max_neighbors], + edge_distance[:, :max_neighbors], + edge_delay_out, + ) + + +def build_lattice_kv_groups(cfg: LatticeAnatomyConfig, coords: torch.Tensor) -> tuple[torch.Tensor, int]: + if cfg.kv_group_ids is not None: + kv_group_id = torch.tensor(cfg.kv_group_ids, dtype=torch.long) + num_groups = int(kv_group_id.max().item()) + 1 if kv_group_id.numel() > 0 else 0 + return kv_group_id, num_groups + if cfg.projection_region_shape is None: + region_shape = tuple(max(1, size // 4) for size in cfg.coord_shape) + else: + region_shape = cfg.projection_region_shape + region = torch.div(coords.to(torch.long), torch.tensor(region_shape, dtype=torch.long), rounding_mode="floor") + grid_dims = [(size + tile - 1) // tile for size, tile in zip(cfg.coord_shape, region_shape, strict=True)] + strides = [] + acc = 1 + for size in reversed(grid_dims[1:]): + acc *= size + strides.append(acc) + strides = list(reversed(strides)) + [1] + kv_group_id = torch.zeros(region.shape[0], dtype=torch.long) + for axis in range(region.shape[1]): + kv_group_id = kv_group_id + region[:, axis] * strides[axis] + num_groups = int(kv_group_id.max().item()) + 1 if kv_group_id.numel() > 0 else 0 + return kv_group_id.to(torch.long), num_groups + + +def build_lattice_ports(cfg: LatticeAnatomyConfig, coords: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + num_cells = coords.shape[0] + explicit_input = 
normalize_node_indices( + cfg.input_cell_indices, + node_count=num_cells, + name="input_cell_indices", + ) + explicit_output = normalize_node_indices( + cfg.output_cell_indices, + node_count=num_cells, + name="output_cell_indices", + ) + x_coord = coords[:, 0] + if explicit_input is None: + input_mask = x_coord < float(cfg.input_band_width) + input_idx = torch.nonzero(input_mask, as_tuple=False).reshape(-1).to(torch.long) + else: + input_idx = explicit_input + if explicit_output is None: + output_mask = x_coord >= float(int(cfg.coord_shape[0]) - cfg.output_band_width) + output_idx = torch.nonzero(output_mask, as_tuple=False).reshape(-1).to(torch.long) + else: + output_idx = explicit_output + if input_idx.numel() == 0 or output_idx.numel() == 0: + raise ValueError("port construction produced an empty input or output port set") + if bool(torch.isin(input_idx, output_idx).any()): + raise ValueError("input and output port cells must be disjoint") + return input_idx, output_idx + + +def build_lattice_slot_init( + cfg: LatticeAnatomyConfig, + coords: torch.Tensor, + cell_layout: torch.Tensor, + recurrent_idx: torch.Tensor, + input_idx: torch.Tensor, + output_idx: torch.Tensor, +) -> torch.Tensor: + shape = torch.tensor(cfg.coord_shape, dtype=coords.dtype) + coords_norm = coords / shape.view(1, -1).clamp_min(1.0) + sin_feat = torch.sin(2.0 * math.pi * coords_norm) + cos_feat = torch.cos(2.0 * math.pi * coords_norm) + num_populations = int(cfg.population_count) + population_one_hot = torch.zeros(coords.shape[0], num_populations, dtype=coords.dtype) + population_one_hot[recurrent_idx] = torch.nn.functional.one_hot( + cell_layout[recurrent_idx], num_classes=num_populations + ).to(coords.dtype) + input_mask = torch.zeros(coords.shape[0], 1, dtype=coords.dtype) + input_mask[input_idx] = 1.0 + output_mask = torch.zeros(coords.shape[0], 1, dtype=coords.dtype) + output_mask[output_idx] = 1.0 + base = torch.cat([coords_norm, sin_feat, cos_feat, population_one_hot, input_mask, output_mask], dim=-1) + d_slot = int(cfg.d_slot) + repeats = math.ceil(d_slot / base.shape[1]) + slot = base.repeat(1, repeats)[:, :d_slot] + gen = torch.Generator(device="cpu") + gen.manual_seed(cfg.seed + 17) + noise = 0.01 * torch.randn(coords.shape[0], d_slot, generator=gen, dtype=coords.dtype) + return slot + noise + + +def build_lattice_local_sender_table( + *, + receiver_coords: torch.Tensor, + sender_lookup: torch.Tensor, + local_offsets: torch.Tensor, + local_valid: torch.Tensor, + coord_shape: tuple[int, ...], + wrap: bool, +) -> torch.Tensor: + receiver_coords_long = receiver_coords.to(torch.long) + local_offsets_long = local_offsets.to(torch.long) + sender_table = torch.full(local_valid.shape, -1, dtype=torch.long) + target_coords = receiver_coords_long[:, None, :] + local_offsets_long[None, :, :] + for dim, size in enumerate(coord_shape): + if wrap: + target_coords[..., dim] = torch.remainder(target_coords[..., dim], size) + else: + target_coords[..., dim].clamp_(0, size - 1) + target_flat = target_coords[..., 0] + for dim, size in enumerate(coord_shape[1:], start=1): + target_flat = target_flat * size + target_coords[..., dim] + sender_table[local_valid] = sender_lookup.index_select(0, target_flat[local_valid]) + if bool((sender_table[local_valid] < 0).any()): + raise ValueError("Local receiver subset contains a sender outside the compact sender set") + return sender_table + + +def _assign_explicit_lattice_population_nodes( + cfg: LatticeAnatomyConfig, + *, + recurrent_cell_idx: torch.Tensor, + population_names: 
tuple[str, ...], +) -> torch.Tensor: + assert cfg.population_node_indices is not None + recurrent_nodes = [int(idx) for idx in recurrent_cell_idx.tolist()] + global_to_recurrent = {node: local_idx for local_idx, node in enumerate(recurrent_nodes)} + layout = -torch.ones(len(recurrent_nodes), dtype=torch.long) + for population_idx, population_name in enumerate(population_names): + for node in cfg.population_node_indices[population_name]: + recurrent_idx = global_to_recurrent.get(int(node)) + if recurrent_idx is None: + raise ValueError( + f"population {population_name!r} targets non-recurrent node {int(node)}; " + "population nodes must exclude input and output boundary nodes" + ) + layout[recurrent_idx] = population_idx + missing = torch.nonzero(layout < 0, as_tuple=False).reshape(-1) + if missing.numel() > 0: + missing_nodes = recurrent_cell_idx.index_select(0, missing[:8]).tolist() + raise ValueError( + "population_node_indices must cover every recurrent node exactly once; " + f"missing recurrent nodes={missing_nodes}" + ) + return layout + + +def _build_explicit_sparse_graph( + cfg: LatticeAnatomyConfig, + coords: torch.Tensor, + *, + input_cell_idx: torch.Tensor, + output_cell_idx: torch.Tensor, +) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor | None, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor | None, +]: + assert cfg.graph_edges is not None + num_cells = coords.shape[0] + coord_dim = coords.shape[1] + input_mask = torch.zeros(num_cells, dtype=torch.bool) + input_mask[input_cell_idx] = True + output_mask = torch.zeros(num_cells, dtype=torch.bool) + output_mask[output_cell_idx] = True + receivers = torch.tensor([receiver for receiver, _sender in cfg.graph_edges], dtype=torch.long) + senders = torch.tensor([sender for _receiver, sender in cfg.graph_edges], dtype=torch.long) + if bool(input_mask.index_select(0, receivers).any()): + raise ValueError("graph_edges must not target input boundary nodes") + if bool(output_mask.index_select(0, senders).any()): + raise ValueError("graph_edges must not use output boundary nodes as senders") + direct_input_to_output = output_mask.index_select(0, receivers) & input_mask.index_select(0, senders) + if bool(direct_input_to_output.any()): + raise ValueError("graph_edges must not connect input boundary nodes directly into output boundary nodes") + degree = torch.bincount(receivers, minlength=num_cells) + max_neighbors = int(degree.max().item()) if degree.numel() > 0 else 0 + if max_neighbors == 0: + raise ValueError("fabric graph has no edges") + neighbor_idx = torch.zeros(num_cells, max_neighbors, dtype=torch.long) + neighbor_valid = torch.zeros(num_cells, max_neighbors, dtype=torch.bool) + edge_type = torch.ones(num_cells, max_neighbors, dtype=torch.long) + edge_distance = torch.ones(num_cells, max_neighbors, dtype=coords.dtype) + edge_delay = torch.ones(num_cells, max_neighbors, dtype=torch.long) if cfg.max_delay is not None else None + write_pos = torch.zeros(num_cells, dtype=torch.long) + for receiver, sender in cfg.graph_edges: + col = int(write_pos[receiver].item()) + neighbor_idx[receiver, col] = int(sender) + neighbor_valid[receiver, col] = True + write_pos[receiver] += 1 + local_offsets = torch.empty(0, coord_dim, dtype=torch.int32) + local_valid = torch.zeros(num_cells, 0, dtype=torch.bool) + local_distance = torch.empty(0, dtype=coords.dtype) + local_delay = torch.empty(0, dtype=torch.int32) if cfg.max_delay is not None else None + return ( + local_offsets, + local_valid, + 
local_distance, + local_delay, + neighbor_idx, + neighbor_valid, + edge_type, + edge_distance, + edge_delay, + ) + + +def _append_offset_neighbors( + *, + coords: torch.Tensor, + shape_tensor: torch.Tensor, + stride_tensor: torch.Tensor, + wrap: bool, + recv_idx: torch.Tensor, + input_mask: torch.Tensor, + output_mask: torch.Tensor, + delta: tuple[int, ...], + distance: float, + edge_kind: int, + patch_limit: int | None, + neighbor_idx: torch.Tensor, + neighbor_valid: torch.Tensor, + edge_type: torch.Tensor, + edge_distance: torch.Tensor, + edge_delay: torch.Tensor | None, + neighbor_counts: torch.Tensor, + patch_counts: torch.Tensor | None, + delay_value: int, + current_max_neighbors: int, + offset_valid: torch.Tensor | None = None, + offset_slot: int | None = None, +) -> int: + send_idx, valid = _resolve_offset_indices( + coords=coords, + shape_tensor=shape_tensor, + stride_tensor=stride_tensor, + wrap=wrap, + delta=delta, + ) + valid = valid & ~input_mask + valid = valid & (send_idx != recv_idx) + valid = valid & ~output_mask.index_select(0, send_idx) + valid = valid & ~(output_mask & input_mask.index_select(0, send_idx)) + if patch_limit is not None and patch_counts is not None: + valid = valid & (patch_counts < patch_limit) + if current_max_neighbors > 0: + selected_idx = neighbor_idx[:, :current_max_neighbors] + selected_valid = neighbor_valid[:, :current_max_neighbors] + duplicate = ((selected_idx == send_idx.unsqueeze(1)) & selected_valid).any(dim=1) + valid = valid & ~duplicate + rows = torch.nonzero(valid, as_tuple=False).reshape(-1) + if rows.numel() == 0: + return current_max_neighbors + cols = neighbor_counts.index_select(0, rows) + send_rows = send_idx.index_select(0, rows) + neighbor_idx[rows, cols] = send_rows + neighbor_valid[rows, cols] = True + edge_type[rows, cols] = edge_kind + edge_distance[rows, cols] = distance + if offset_valid is not None and offset_slot is not None: + offset_valid[rows, offset_slot] = True + if edge_delay is not None: + edge_delay[rows, cols] = delay_value + neighbor_counts[rows] += 1 + if patch_limit is not None and patch_counts is not None: + patch_counts[rows] += 1 + return max(current_max_neighbors, int(neighbor_counts[rows].max().item())) + + +def _resolve_offset_indices( + *, + coords: torch.Tensor, + shape_tensor: torch.Tensor, + stride_tensor: torch.Tensor, + wrap: bool, + delta: tuple[int, ...], +) -> tuple[torch.Tensor, torch.Tensor]: + shifted = coords + torch.tensor(delta, dtype=torch.long) + if wrap: + shifted = torch.remainder(shifted, shape_tensor) + valid = torch.ones(coords.shape[0], dtype=torch.bool) + else: + valid = ((shifted >= 0) & (shifted < shape_tensor)).all(dim=1) + shifted = torch.where(valid.unsqueeze(1), shifted, torch.zeros_like(shifted)) + return (shifted * stride_tensor).sum(dim=1), valid + + +@lru_cache(maxsize=None) +def _flat_index_strides(shape: tuple[int, ...]) -> tuple[int, ...]: + strides: list[int] = [] + acc = 1 + for size in reversed(shape[1:]): + acc *= int(size) + strides.append(acc) + return tuple(list(reversed(strides)) + [1]) + + +@lru_cache(maxsize=None) +def _neighbor_offsets( + *, + coord_dim: int, + min_distance: float, + max_distance: float, +) -> tuple[tuple[tuple[int, ...], float], ...]: + radius = math.ceil(max_distance) + tol = 1e-6 + offsets: list[tuple[tuple[int, ...], float]] = [] + for delta in product(range(-radius, radius + 1), repeat=coord_dim): + if all(step == 0 for step in delta): + continue + distance = math.sqrt(sum(step * step for step in delta)) + if distance + tol < 
min_distance or distance > max_distance + tol: + continue + offsets.append((tuple(int(step) for step in delta), distance)) + offsets.sort(key=lambda item: (item[1], tuple(abs(step) for step in item[0]), item[0])) + return tuple(offsets) + + +def _edge_delay(*, distance: float, cfg: LatticeAnatomyConfig) -> int: + if cfg.conduction_speed is None or cfg.max_delay is None: + return 1 + delay = int(math.ceil(distance / cfg.conduction_speed)) + return max(1, min(cfg.max_delay, delay)) + + +def _lexsort_key(coords: torch.Tensor) -> torch.Tensor: + strides = [] + acc = 1 + max_vals = coords.max(dim=0).values.to(torch.long) + 1 + for size in reversed(max_vals[1:].tolist()): + acc *= int(size) + strides.append(acc) + strides = list(reversed(strides)) + [1] + key = torch.zeros(coords.shape[0], dtype=torch.long) + coords_long = coords.to(torch.long) + for axis, stride in enumerate(strides): + key = key + coords_long[:, axis] * stride + return key + + +__all__ = [ + "LatticeAnatomyConfig", + "assign_lattice_cells", + "build_lattice_coords", + "build_lattice_kv_groups", + "build_lattice_local_sender_table", + "build_lattice_ports", + "build_lattice_slot_init", + "build_lattice_sparse_graph", + "lattice_anatomy_config_from_graph", +] diff --git a/src/cortical/fabric/message_rules/__init__.py b/src/cortical/fabric/message_rules/__init__.py index 089f7d1a..ab41416f 100644 --- a/src/cortical/fabric/message_rules/__init__.py +++ b/src/cortical/fabric/message_rules/__init__.py @@ -1,3 +1,9 @@ -from cortical.fabric.message_rules.declarations import DotProduct, ReceiverSlot, SenderPublic, ShareBySenderTile +from cortical.fabric.message_rules.declarations import ( + DotProduct, + FixedSlotContextNudge, + ReceiverSlot, + SenderPublic, + ShareBySenderTile, +) -__all__ = ["DotProduct", "ReceiverSlot", "SenderPublic", "ShareBySenderTile"] +__all__ = ["DotProduct", "FixedSlotContextNudge", "ReceiverSlot", "SenderPublic", "ShareBySenderTile"] diff --git a/src/cortical/fabric/message_rules/declarations.py b/src/cortical/fabric/message_rules/declarations.py index 70d14487..a278a14e 100644 --- a/src/cortical/fabric/message_rules/declarations.py +++ b/src/cortical/fabric/message_rules/declarations.py @@ -1,8 +1,9 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Literal -from cortical.fabric.backend.message_rules import MessageRuleIR, default_dot_product_message_rule_ir +from cortical.fabric.backend.message_rules import MessageRuleIR, build_message_rule_ir @dataclass(frozen=True) @@ -15,6 +16,16 @@ class SenderPublic: """Sender public-state source for message rules.""" +@dataclass(frozen=True) +class FixedSlotContextNudge: + """Fixed sender/receiver slot routing with receiver-content query nudge.""" + + +@dataclass(frozen=True) +class FixedSlotContextGate: + """Fixed sender/receiver slot routing with receiver-content query gate.""" + + @dataclass(frozen=True) class ShareBySenderTile: """Semantic sender sharing declaration. 
@@ -45,10 +56,22 @@ class DotProduct: value_dim: int | None = None kv_sharing: ShareBySenderTile | None = None output_dim: str = "message" + math: Literal["dynamic_key_value", "fixed_slot_context_nudge", "fixed_slot_context_gate"] = ( + "fixed_slot_context_nudge" + ) def to_ir(self, *, kv_group_count: int, cell_count: int) -> MessageRuleIR: self._validate_supported_surface() - return default_dot_product_message_rule_ir(kv_group_count=kv_group_count, cell_count=cell_count) + rule_type = { + "dynamic_key_value": "dot_product", + "fixed_slot_context_nudge": "dot_product_fixed_slot_context_nudge", + "fixed_slot_context_gate": "dot_product_fixed_slot_context_gate", + }[self.math] + return build_message_rule_ir( + rule_type=rule_type, + kv_group_count=kv_group_count, + cell_count=cell_count, + ) def projection_region_shape(self, *, coord_dim: int) -> tuple[int, ...] | None: if self.kv_sharing is None: @@ -68,6 +91,15 @@ def _validate_supported_surface(self) -> None: raise ValueError("Fabric v1 DotProduct message rule must emit the canonical projected_message boundary") if self.value_dim is not None and self.head_dim is not None and int(self.value_dim) != int(self.head_dim): raise ValueError("Fabric v1 DotProduct requires value_dim to match head_dim when value_dim is set") - - -__all__ = ["DotProduct", "ReceiverSlot", "SenderPublic", "ShareBySenderTile"] + if self.math not in {"dynamic_key_value", "fixed_slot_context_nudge", "fixed_slot_context_gate"}: + raise ValueError(f"Unsupported DotProduct math variant {self.math!r}") + + +__all__ = [ + "DotProduct", + "FixedSlotContextGate", + "FixedSlotContextNudge", + "ReceiverSlot", + "SenderPublic", + "ShareBySenderTile", +] diff --git a/src/cortical/fabric/runtime/core.py b/src/cortical/fabric/runtime/core.py index 9082aea9..f0bbc48e 100644 --- a/src/cortical/fabric/runtime/core.py +++ b/src/cortical/fabric/runtime/core.py @@ -8,7 +8,6 @@ import torch import torch.nn as nn from tensordict import TensorDict, TensorDictBase -from torch.utils.checkpoint import checkpoint from cortical.fabric.anatomy import Spec from cortical.fabric.backend.caps import DeviceCaps, detect_device_caps @@ -18,7 +17,7 @@ receiver_major_projection_backward_gate, ) from cortical.fabric.backend.cuda.sequence_surface import CudaSequenceSurfaceMixin -from cortical.fabric.backend.cuda.sequence_surface.policy import ( +from cortical.fabric.backend.cuda.sequence_surface.runtime.policy import ( CudaMemoryBudget, LayoutBatchTileInputs, PolicyDecision, @@ -30,11 +29,15 @@ tape_checkpoint_policy, tape_memory_chunk_len, ) -from cortical.fabric.backend.cuda.sequence_surface.support import _transition_supports_receiver_local_dependency_window -from cortical.fabric.backend.cuda.sequence_surface.temporal_executor import ( - execute_temporal_bucket_active_output_window, +from cortical.fabric.backend.cuda.sequence_surface.runtime.support import ( + _transition_supports_receiver_local_dependency_window, +) +from cortical.fabric.backend.cuda.sequence_surface.runtime.executor import ( execute_temporal_bucket_sequence, - record_temporal_bucket_sequence_surface_execution, +) +from cortical.fabric.backend.cuda.sequence_surface.runtime.memory_stages import ( + record_frontend_tensor_bytes, + record_registered_memory_stage, ) from cortical.fabric.backend.graph_regions import ( close_recurrent_region_from_sender_tables, @@ -49,6 +52,7 @@ PlannedFabricExecution, SequenceSurfaceRoute, ) +from cortical.fabric.backend.temporal_plan import TemporalExecutionPlan, temporal_execution_record_metadata from 
cortical.fabric.backend.pytorch.readout import ( ReadoutConfig, ) @@ -62,13 +66,11 @@ select_output_cells as backend_select_output_cells, ) from cortical.fabric.backend.runtime_dispatch import BackendRuntimeDispatchMixin -from cortical.fabric.backend.selector import select_fabric_backend from cortical.fabric.backend.surfaces import ( SUPPORTED_BACKEND_SURFACES, BackendExecutionRecord, SupportedSurface, supported_surface_by_key, - supported_surface_for_cell_type, ) from cortical.fabric.backend.tape import TapeMode, TapePolicy, default_tape_policy from cortical.fabric.backend.workspace import GraphCaptureWorkspace @@ -88,10 +90,6 @@ _ModelOutputReducer = Callable[[torch.Tensor, int, int], torch.Tensor] -def _population_display_name(population_name: str) -> str: - return population_name - - @dataclass(frozen=True) class _BackendGraphInputLayout: input_names: tuple[str, ...] @@ -106,23 +104,27 @@ def __init__(self, spec: Spec) -> None: super().__init__() self.spec = spec self.config = spec.config - self.hidden_size = int(spec.config.hidden_size) - self.num_heads = int(spec.config.num_heads) - self.head_dim = int(spec.config.head_dim) - self.value_dim = int(spec.config.head_dim) + self.hidden_size = int(spec.config.interface.hidden_size) + self.d_public = int(spec.config.interface.resolved_public_dim) + self.d_msg = int(spec.config.interface.resolved_message_dim) + self.d_slot = int(spec.config.interface.resolved_slot_dim) + self.num_heads = int(spec.config.message.resolved_num_heads) + self.head_dim = int(spec.config.message.head_dim) + self.value_dim = int(spec.config.message.head_dim) self._has_edge_delay = spec.anatomy.edge_delay is not None if self.num_heads != 1: raise ValueError("Runtime fast path currently requires single-head fabric attention") - if spec.config.readout_pool == "mean": + readout_pool = str(spec.config.readout.pool) + if readout_pool == "mean": self.readout_slots = 1 - elif spec.config.readout_pool == "flatten": + elif readout_pool == "flatten": self.readout_slots = int(spec.output_cell_idx.numel()) else: - self.readout_slots = int(spec.config.readout_slots) + self.readout_slots = int(spec.config.readout.slots) self._population_names = spec.population_names self._population_name_to_idx = {name: idx for idx, name in enumerate(self._population_names)} self._population_cell_types = { - name: self.config.cell_populations[name].cell_type for name in self._population_names + name: self.config.populations.cell_populations[name].cell_type for name in self._population_names } self.register_buffer("cell_layout", spec.anatomy.cell_layout.clone()) @@ -164,14 +166,7 @@ def __init__(self, spec: Spec) -> None: recurrent_lookup[self.recurrent_cell_idx] = torch.arange(self.recurrent_cell_idx.numel(), dtype=torch.long) self.register_buffer("recurrent_lookup", recurrent_lookup) self.register_buffer("recurrent_sender_idx", sender_lookup[self.recurrent_cell_idx].clone()) - full_local_sender_idx = _build_local_sender_table( - receiver_coords=spec.anatomy.coords, - sender_lookup=sender_lookup, - local_offsets=spec.anatomy.local_offsets, - local_valid=spec.anatomy.local_valid, - coord_shape=tuple(int(size) for size in spec.config.coord_shape), - wrap=bool(spec.config.wrap), - ) + full_local_sender_idx = spec.anatomy.full_local_sender_idx self.register_buffer("full_local_sender_idx", full_local_sender_idx.to(torch.int32)) self.register_buffer( "full_local_receiver_idx_by_sender", @@ -249,14 +244,7 @@ def __init__(self, spec: Spec) -> None: self.register_buffer("recurrent_sparse_degree_ptr", 
recurrent_sparse_degree_ptr) self._recurrent_sparse_positive_degree_buckets = recurrent_sparse_positive_degree_buckets self.register_buffer("recurrent_local_valid", self.local_valid.index_select(0, self.recurrent_cell_idx)) - recurrent_local_sender_idx = _build_local_sender_table( - receiver_coords=spec.anatomy.coords.index_select(0, self.recurrent_cell_idx), - sender_lookup=sender_lookup, - local_offsets=spec.anatomy.local_offsets, - local_valid=self.recurrent_local_valid, - coord_shape=tuple(int(size) for size in spec.config.coord_shape), - wrap=bool(spec.config.wrap), - ) + recurrent_local_sender_idx = spec.anatomy.recurrent_local_sender_idx self.register_buffer("recurrent_local_sender_idx", recurrent_local_sender_idx.to(torch.int32)) self.register_buffer( "recurrent_local_receiver_idx_by_sender", @@ -284,14 +272,7 @@ def __init__(self, spec: Spec) -> None: self.register_buffer("output_edge_distance", output_edge_distance) self.register_buffer("output_edge_delay", output_edge_delay) self.register_buffer("output_local_valid", self.local_valid.index_select(0, self.output_cell_idx)) - output_local_sender_idx = _build_local_sender_table( - receiver_coords=spec.anatomy.coords.index_select(0, self.output_cell_idx), - sender_lookup=sender_lookup, - local_offsets=spec.anatomy.local_offsets, - local_valid=self.output_local_valid, - coord_shape=tuple(int(size) for size in spec.config.coord_shape), - wrap=bool(spec.config.wrap), - ) + output_local_sender_idx = spec.anatomy.output_local_sender_idx self.register_buffer("output_local_sender_idx", output_local_sender_idx.to(torch.int32)) self.register_buffer( "output_local_receiver_idx_by_sender", @@ -321,33 +302,23 @@ def __init__(self, spec: Spec) -> None: receiver_sender_idx=output_neighbor_idx, receiver_valid=output_neighbor_valid, ) - self._local_message_step_enabled = bool( - spec.config.patch_edges_per_cell == 0 - and int(self.local_offsets.shape[0]) > 0 - and int(self.coords.shape[1]) <= 3 - ) + self._local_message_step_enabled = bool(int(self.local_offsets.shape[0]) > 0 and int(self.coords.shape[1]) <= 3) self._uses_sparse_message_backend = bool( not self._local_message_step_enabled or bool((spec.anatomy.edge_type[spec.anatomy.neighbor_valid] != 0).any().item()) ) - self._coord_shape = tuple(int(size) for size in spec.config.coord_shape) - self.slot_embed = nn.Parameter(spec.slot_init.clone()) - self.public_proj = nn.Linear(self.hidden_size, int(self.config.d_public), bias=False) - self.input_proj = nn.Linear(self.hidden_size, int(self.config.d_msg), bias=False) - self.msg_to_cell = nn.Linear(int(self.config.d_msg), self.hidden_size, bias=False) - self.cell_bias_proj = nn.Linear(int(self.config.d_slot), self.hidden_size, bias=False) - self.q_proj = nn.Linear(int(self.config.d_slot), self.num_heads * self.head_dim, bias=False) - self.k_weight = nn.Parameter( - torch.empty(spec.num_kv_groups, int(self.config.d_public), self.num_heads * self.head_dim) - ) - self.v_weight = nn.Parameter( - torch.empty(spec.num_kv_groups, int(self.config.d_public), self.num_heads * self.value_dim) - ) - self.msg_out = nn.Linear(self.num_heads * self.value_dim, int(self.config.d_msg), bias=False) + self.public_proj = nn.Linear(self.hidden_size, self.d_public, bias=False) + self.input_proj = nn.Linear(self.hidden_size, self.d_msg, bias=False) + self.msg_to_cell = nn.Linear(self.d_msg, self.hidden_size, bias=False) + self.cell_bias_proj = nn.Linear(self.d_slot, self.hidden_size, bias=False) + self.q_proj = nn.Linear(self.d_slot, self.num_heads * self.head_dim, 
bias=False) + self.k_weight = nn.Parameter(torch.empty(spec.num_kv_groups, self.d_public, self.num_heads * self.head_dim)) + self.v_weight = nn.Parameter(torch.empty(spec.num_kv_groups, self.d_public, self.num_heads * self.value_dim)) + self.msg_out = nn.Linear(self.num_heads * self.value_dim, self.d_msg, bias=False) self.output_cell_weight = nn.Parameter( - torch.empty(int(self.output_cell_idx.numel()), int(self.config.d_msg), self.hidden_size) + torch.empty(int(self.output_cell_idx.numel()), self.d_msg, self.hidden_size) ) self.output_cell_bias = nn.Parameter(torch.empty(int(self.output_cell_idx.numel()), self.hidden_size)) self.readout_query = nn.Parameter(torch.empty(self.readout_slots, self.hidden_size)) @@ -366,26 +337,29 @@ def __init__(self, spec: Spec) -> None: ): self._full_recurrent_population_name = name self.population_modules[name] = build_cell_population_module( - self.config.cell_populations[name], + self.config.populations.cell_populations[name], self.hidden_size, num_cells=int(indices.numel()), - init_noise_std=float(self.config.population_init_noise_std), + init_noise_std=float(self.config.initialization.population_init_noise_std), ) self._register_population_backend_order_buffers() self._backend_ir = compile_fabric_ir( spec, hidden_size=self.hidden_size, - d_public=int(self.config.d_public), - d_msg=int(self.config.d_msg), + d_public=self.d_public, + d_msg=self.d_msg, head_dim=self.head_dim, value_dim=self.value_dim, ) + self.message_rule_modules: dict[str, nn.Module] = {} + self.message_rule_parameters: dict[str, nn.Parameter] = {} + self._install_message_rule_runtime_state() self._backend_population_specs = { name: build_cell_backend_spec( cell_type=self._population_cell_types[name], hidden_size=self.hidden_size, - d_public=int(self.config.d_public), - d_msg=int(self.config.d_msg), + d_public=self.d_public, + d_msg=self.d_msg, head_dim=self.head_dim, value_dim=self.value_dim, ) @@ -399,6 +373,7 @@ def __init__(self, spec: Spec) -> None: self._backend_device_caps_cache: dict[tuple[str, int], DeviceCaps] = {} self._backend_graph_capture_cache = FabricGraphCaptureCache() self._last_backend_execution: BackendExecutionRecord | None = None + self._last_temporal_execution_plan: TemporalExecutionPlan | None = None self._last_backend_launch_metadata: dict[str, tuple[Any, ...]] | None = None self._last_backend_tape_chunk_len: int | None = None self._last_backend_tape_chunk_reason: str | None = None @@ -432,6 +407,43 @@ def _clear_training_materialization_caches(self) -> None: self._training_static_cache.clear() self._constant_step_flat_cache.clear() + def _message_rule_dimension(self, role: str) -> int: + if role == "1": + return 1 + if role == "cell_count": + return int(self.coords.shape[0]) + if role == "d_slot": + return int(self.d_slot) + if role == "head_dim": + return int(self.head_dim) + if role == "value_dim": + return int(self.value_dim) + if role == "d_msg": + return int(self.d_msg) + if role == "hidden_size": + return int(self.hidden_size) + raise RuntimeError(f"Unsupported message rule runtime dimension role {role!r}") + + def _install_message_rule_runtime_state(self) -> None: + message_program = self._backend_ir.message_program + for module_spec in message_program.runtime_modules: + if module_spec.module_kind != "linear": + raise RuntimeError( + f"Unsupported message rule runtime module kind {module_spec.module_kind!r} for {module_spec.name!r}" + ) + module = nn.Linear( + self._message_rule_dimension(module_spec.input_dim_role), + 
self._message_rule_dimension(module_spec.output_dim_role), + bias=bool(module_spec.bias), + ) + setattr(self, module_spec.name, module) + self.message_rule_modules[str(module_spec.name)] = module + for parameter_spec in message_program.runtime_parameters: + shape = tuple(self._message_rule_dimension(role) for role in parameter_spec.shape_roles) + parameter = nn.Parameter(torch.empty(*shape)) + setattr(self, parameter_spec.name, parameter) + self.message_rule_parameters[str(parameter_spec.name)] = parameter + def _reset_parameters(self) -> None: self._clear_execution_caches() nn.init.xavier_uniform_(self.public_proj.weight) @@ -442,6 +454,26 @@ def _reset_parameters(self) -> None: nn.init.xavier_uniform_(self.k_weight) nn.init.xavier_uniform_(self.v_weight) nn.init.xavier_uniform_(self.msg_out.weight) + for module in self.message_rule_modules.values(): + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + for parameter_spec in self._backend_ir.message_program.runtime_parameters: + parameter = self.message_rule_parameters.get(str(parameter_spec.name)) + if parameter is None: + raise RuntimeError(f"Message rule parameter {parameter_spec.name!r} was not installed") + if parameter_spec.init == "ones": + nn.init.ones_(parameter) + elif parameter_spec.init == "zeros": + nn.init.zeros_(parameter) + elif parameter_spec.init == "normal_head": + nn.init.normal_(parameter, mean=0.0, std=1.0 / math.sqrt(max(1, self.head_dim))) + else: + raise RuntimeError( + f"Unsupported message rule parameter initializer {parameter_spec.init!r} " + f"for {parameter_spec.name!r}" + ) nn.init.xavier_uniform_(self.output_cell_weight) nn.init.zeros_(self.output_cell_bias) nn.init.normal_(self.readout_query, mean=0.0, std=1.0 / math.sqrt(max(1, self.hidden_size))) @@ -464,6 +496,10 @@ def backend_population_specs(self): def last_backend_execution(self) -> BackendExecutionRecord | None: return self._last_backend_execution + @property + def last_temporal_execution_plan(self) -> TemporalExecutionPlan | None: + return self._last_temporal_execution_plan + @property def graph_capture_cache_stats(self) -> dict[str, int]: return self._backend_graph_capture_cache.stats() @@ -479,6 +515,59 @@ def supported_backend_surfaces(self, *, cell_type: str | None = None, population def _cell_spec_for_population(self, population_name: str): return get_cell_spec(self._population_cell_types[population_name]) + def _default_population_state_names_for_population(self, population_name: str) -> tuple[str, ...]: + return tuple(self._cell_spec_for_population(population_name).state_schema.keys) + + def _compiled_transition_state_names_for_population(self, population_name: str) -> tuple[str, ...]: + binding_slot = int(self._population_name_to_idx[population_name]) + transition_program = self._backend_ir.transition_program_for_binding_slot(binding_slot) + return tuple(str(schema.name) for schema in getattr(transition_program, "private_state_schema", ())) + + def _init_population_state_for_names( + self, + population_name: str, + *, + state_names: tuple[str, ...], + batch: int, + device: torch.device, + dtype: torch.dtype, + ) -> TensorDict: + receivers = self._population_num_cells(population_name) + return TensorDict( + { + state_name: torch.zeros( + receivers, + int(batch), + self.hidden_size, + device=device, + dtype=dtype, + ) + for state_name in state_names + }, + batch_size=[receivers, int(batch)], + device=device, + ) + + @staticmethod + def 
_reset_population_state_for_names( + population_state: TensorDictBase, + *, + state_names: tuple[str, ...], + batch_mask: torch.Tensor, + ) -> TensorDict: + reset_mask = batch_mask.view(1, -1, 1) + leaves: dict[str, torch.Tensor] = {} + for state_name in state_names: + tensor = population_state.get(state_name) + if not torch.is_tensor(tensor): + continue + leaves[state_name] = torch.where(reset_mask.to(device=tensor.device), torch.zeros_like(tensor), tensor) + return TensorDict( + leaves, + batch_size=list(population_state.batch_size), + device=population_state.device, + ) + def _backend_spec_for_cell_type(self, cell_type: str): for population_name, spec in self._backend_population_specs.items(): if self._population_cell_types[population_name] == cell_type: @@ -527,6 +616,51 @@ def plan_backend_execution( supported_variants=supported_variants, ) + def plan_temporal_execution( + self, + *, + batch_size: int, + time_steps: int, + k: int | torch.Tensor | None, + training: bool, + device: torch.device | None = None, + dtype: torch.dtype = torch.float32, + input_boundary: str = "boundary", + output_boundary: Literal["sequence", "terminal"] = "sequence", + readout_output_boundary: Literal["cells", "pooled"] = "cells", + output_contract: str = "full_cells", + materialize_final_state: bool = True, + state_is_fresh: bool = True, + has_resets: bool = False, + gradient_horizon_steps: int | None = None, + checkpoint_steps: int | None = None, + ) -> TemporalExecutionPlan: + plan_device = self.coords.device if device is None else torch.device(device) + planned_gradient_horizon_steps = ( + self.config.execution.gradient_horizon_steps if gradient_horizon_steps is None else gradient_horizon_steps + ) + planned_checkpoint_steps = ( + self.config.execution.checkpoint_steps if checkpoint_steps is None else checkpoint_steps + ) + return self._backend_planner.plan_temporal_execution( + device_type=plan_device.type, + dtype=str(dtype).removeprefix("torch."), + partitioned_layout=bool(self._partitioned_layout), + configured_backend=str(self.config.execution.backend), + constant_k=self._resolve_constant_k_host(k), + time_steps=time_steps, + training=training, + input_boundary=input_boundary, + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + output_contract=output_contract, + materialize_final_state=materialize_final_state, + state_is_fresh=state_is_fresh, + has_resets=has_resets, + gradient_horizon_steps=planned_gradient_horizon_steps, + checkpoint_steps=planned_checkpoint_steps, + ) + def plan_backend_backward_execution( self, *, @@ -537,6 +671,7 @@ def plan_backend_backward_execution( tape_policy: TapePolicy | None = None, device: torch.device | None = None, surface_key: str | None = None, + temporal_plan: TemporalExecutionPlan | None = None, ) -> PlannedFabricBackwardExecution: if tape_policy is None: tape_policy = default_tape_policy(training) @@ -555,6 +690,7 @@ def plan_backend_backward_execution( device_caps=device_caps, tape_policy=tape_policy, supported_variants=supported_variants, + temporal_plan=temporal_plan, ) @staticmethod @@ -832,6 +968,24 @@ def _backend_forward_batch_tile_len_for_layout( self._last_backend_forward_batch_tile_reason = decision.reason return int(decision.value) + def _transition_core_state_names_for_population(self, population_name: str) -> tuple[str, ...] 
| None: + population_spec = self._backend_population_specs.get(population_name) + if population_spec is None: + return None + trace_names = { + schema.name + for schema in population_spec.private_state_schema + if schema.semantic_kind == "eligibility_trace" + } + if not trace_names: + return None + core_names = tuple( + state_name for state_name in population_spec.transition_ir.state_inputs if state_name not in trace_names + ) + if not core_names or len(core_names) == len(population_spec.transition_ir.state_inputs): + return None + return core_names + def _flat_bucket_sequence_readout_batch_tile_len( self, *, @@ -1151,7 +1305,7 @@ def actual_tuple_optional(key: str, requested: tuple[Any, ...] = ()) -> tuple[An readout_receiver_major_projection_gate = receiver_major_projection_backward_gate( batch_size=batch_size, receivers=int(self._num_output_cells), - input_dim=int(self.config.d_msg), + input_dim=int(self.d_msg), output_dim=int(self.hidden_size), biased=True, ) @@ -1415,6 +1569,14 @@ def add_backward_affine_bucket(phase: str, signature: str, backend: object) -> N launch_readout_modes=actual_readout_modes, launch_temporal_executions=actual_tuple("temporal_executions", ()), launch_scan_implementations=actual_tuple("scan_implementations", ()), + launch_temporal_scan_owners=actual_tuple_optional("temporal_scan_owners", ()), + launch_temporal_scan_outer_steps=actual_tuple_optional("temporal_scan_outer_steps", ()), + launch_temporal_scan_inner_steps=actual_tuple_optional("temporal_scan_inner_steps", ()), + launch_temporal_scan_physical_steps=actual_tuple_optional("temporal_scan_physical_steps", ()), + launch_temporal_scan_emission_counts=actual_tuple_optional("temporal_scan_emission_counts", ()), + launch_temporal_scan_first_emission_steps=actual_tuple_optional("temporal_scan_first_emission_steps", ()), + launch_temporal_scan_emission_strides=actual_tuple_optional("temporal_scan_emission_strides", ()), + launch_temporal_scan_output_boundaries=actual_tuple_optional("temporal_scan_output_boundaries", ()), launch_phases=actual_tuple("phases", ()), active_receiver_window_modes=actual_tuple("active_receiver_window_modes", ()), active_receiver_window_offsets=actual_tuple("active_receiver_window_offsets", ()), @@ -1596,8 +1758,165 @@ def add_backward_affine_bucket(phase: str, signature: str, backend: object) -> N large_r_diagnostics=tuple(bucket_plan.large_r_diagnostics for bucket_plan in plan.bucket_plans), graph_capture_replayed=graph_capture_replayed, graph_capture_cache_hit=graph_capture_cache_hit, + **temporal_execution_record_metadata(self._last_temporal_execution_plan), ) + def _message_rule_static_tensors( + self, + *, + include_full_cell_kv_weight: bool, + use_grouped_sender_weights: bool, + ) -> dict[str, object]: + static_specs = tuple(self._backend_ir.message_program.static_tensors) + if not static_specs: + return {} + empty_message_weight = self.public_proj.weight.new_empty(0) + needs_direct_recurrent_value_weight = any( + str(spec.source_kind) == "recurrent_sender_value_weight" for spec in static_specs + ) + cached_value_tensors: dict[str, torch.Tensor | None] | None = None + + def value_tensors() -> dict[str, torch.Tensor | None]: + nonlocal cached_value_tensors + if cached_value_tensors is not None: + return cached_value_tensors + gathered_value_weight = ( + self.v_weight.index_select(0, self.kv_group_id) + if include_full_cell_kv_weight or not use_grouped_sender_weights + else None + ) + sender_input_to_value_weight = ( + torch.einsum( + "dh,sdm->shm", + self.public_proj.weight, + 
gathered_value_weight.index_select(0, self.sender_cell_idx), + ) + if torch.is_tensor(gathered_value_weight) + else None + ) + input_sender_value_weight = ( + sender_input_to_value_weight.index_select(0, self.input_sender_idx) + if torch.is_tensor(sender_input_to_value_weight) + else None + ) + recurrent_sender_value_weight = ( + sender_input_to_value_weight.index_select(0, self.recurrent_sender_idx) + if torch.is_tensor(sender_input_to_value_weight) + else None + ) + if recurrent_sender_value_weight is None and needs_direct_recurrent_value_weight: + recurrent_value_weight = self.v_weight.index_select( + 0, + self.kv_group_id.index_select(0, self.recurrent_sender_idx), + ) + recurrent_sender_value_weight = torch.einsum( + "dh,sdm->shm", + self.public_proj.weight, + recurrent_value_weight, + ) + if ( + torch.is_tensor(recurrent_sender_value_weight) + and self.population_backend_recurrent_order.numel() == self._num_recurrent_cells + and int(recurrent_sender_value_weight.shape[0]) >= self._num_recurrent_cells + ): + recurrent_sender_value_weight = recurrent_sender_value_weight[: self._num_recurrent_cells].index_select( + 0, + self.population_backend_recurrent_order, + ) + input_group_value_weight = ( + torch.einsum( + "dh,gdm->ghm", + self.public_proj.weight, + self.v_weight.index_select(0, self.input_sender_kv_group_ids), + ) + if self.input_sender_kv_group_ids.numel() > 0 + else None + ) + cached_value_tensors = { + "input_sender_value_weight": input_sender_value_weight, + "input_group_value_weight": input_group_value_weight, + "recurrent_sender_value_weight": recurrent_sender_value_weight, + } + return cached_value_tensors + + def slot_linear(module_name: str, *, gather: str) -> torch.Tensor: + module = self.message_rule_modules.get(module_name) + if not isinstance(module, nn.Linear): + raise RuntimeError(f"Message rule static tensor source {module_name!r} is not an installed Linear") + tensor = module(self.slot_embed).view(int(self.coords.shape[0]), int(module.out_features)) + if gather == "backend_recurrent": + recurrent = tensor.index_select(0, self.recurrent_cell_idx) + if self.population_backend_recurrent_order.numel() == self.recurrent_cell_idx.numel(): + recurrent = recurrent.index_select(0, self.population_backend_recurrent_order) + return recurrent.contiguous() + if gather == "sender": + return reorder_sender_tensor_for_backend(tensor.index_select(0, self.sender_cell_idx)).contiguous() + raise RuntimeError(f"Unsupported message rule slot-linear gather {gather!r}") + + def runtime_parameter(name: str, *, gather: str = "") -> torch.Tensor: + parameter = self.message_rule_parameters.get(name) + if parameter is None: + raise RuntimeError(f"Message rule static tensor source {name!r} is not an installed parameter") + if gather == "sender": + return reorder_sender_tensor_for_backend(parameter.index_select(0, self.sender_cell_idx)).contiguous() + if gather: + raise RuntimeError(f"Unsupported message rule parameter gather {gather!r}") + return parameter + + def reorder_sender_tensor_for_backend(tensor: torch.Tensor) -> torch.Tensor: + if ( + self.population_backend_recurrent_order.numel() != self._num_recurrent_cells + or int(tensor.shape[0]) < self._num_input_cells + self._num_recurrent_cells + ): + return tensor + input_part = tensor[: self._num_input_cells] + recurrent_part = tensor[ + self._num_input_cells : self._num_input_cells + self._num_recurrent_cells + ].index_select(0, self.population_backend_recurrent_order) + if int(tensor.shape[0]) == self._num_input_cells + 
self._num_recurrent_cells: + return torch.cat((input_part, recurrent_part), dim=0) + return torch.cat( + (input_part, recurrent_part, tensor[self._num_input_cells + self._num_recurrent_cells :]), dim=0 + ) + + def module_weight(name: str) -> torch.Tensor: + module = getattr(self, name, None) + weight = getattr(module, "weight", None) + if not torch.is_tensor(weight): + raise RuntimeError(f"Message rule static tensor source module {name!r} has no tensor weight") + return weight + + tensors: dict[str, object] = {} + for static_spec in static_specs: + source_kind = str(static_spec.source_kind) + source_name = str(static_spec.source_name) + if source_kind == "existing_static_tensor": + continue + if source_kind == "slot_linear_backend_recurrent": + value: torch.Tensor = slot_linear(source_name, gather="backend_recurrent") + elif source_kind == "slot_linear_sender": + value = slot_linear(source_name, gather="sender") + elif source_kind == "runtime_parameter": + value = runtime_parameter(source_name) + elif source_kind == "runtime_parameter_sender": + value = runtime_parameter(source_name, gather="sender") + elif source_kind in { + "input_sender_value_weight", + "input_group_value_weight", + "recurrent_sender_value_weight", + }: + value = value_tensors().get(source_kind) + if not torch.is_tensor(value): + value = empty_message_weight + elif source_kind == "module_weight": + value = module_weight(source_name) + else: + raise RuntimeError( + f"Unsupported message rule static tensor source kind {source_kind!r} for {static_spec.name!r}" + ) + tensors[str(static_spec.name)] = value + return tensors + def _record_pytorch_backend_execution( self, *, @@ -1628,6 +1947,7 @@ def _record_pytorch_backend_execution( tape_policy_bin="none", graph_capture_enabled=False, capability_variants=(), + **temporal_execution_record_metadata(self._last_temporal_execution_plan), ) def _materialize_inference_static_tensors( @@ -1702,6 +2022,10 @@ def static_contiguous(tensor: torch.Tensor) -> torch.Tensor: if input_group_kv_weight is not None else None ) + message_rule_static_tensors = self._message_rule_static_tensors( + include_full_cell_kv_weight=include_full_cell_kv_weight, + use_grouped_sender_weights=use_grouped_sender_weights, + ) population_materialized: dict[str, object | None] = {} should_include_backend_prepack = ( include_backend_cell_tensors if include_backend_prepack is None else include_backend_prepack @@ -1718,6 +2042,15 @@ def static_contiguous(tensor: torch.Tensor) -> torch.Tensor: else: population_materialized = {name: None for name in self._population_names} value_to_cell_weight = self.msg_to_cell.weight @ self.msg_out.weight + message_rule_output_dim_role = self._message_rule_output_dim_role() + recurrent_message_to_cell_weight_source = ( + "message_to_cell_weight" if message_rule_output_dim_role == "d_msg" else "value_to_cell_weight" + ) + recurrent_message_to_cell_weight = ( + self.msg_to_cell.weight + if recurrent_message_to_cell_weight_source == "message_to_cell_weight" + else value_to_cell_weight + ) fused_recurrent_value_to_cell_weight = None fused_recurrent_cell_bias = recurrent_cell_bias fused_recurrent_population_input = False @@ -1734,7 +2067,7 @@ def static_contiguous(tensor: torch.Tensor) -> torch.Tensor: out_proj_bias = full_population_params.get("out_proj_bias") if torch.is_tensor(input_proj_weight_t): fused_recurrent_value_to_cell_weight = torch.matmul( - value_to_cell_weight.transpose(0, 1).unsqueeze(0), + recurrent_message_to_cell_weight.transpose(0, 1).unsqueeze(0), 
input_proj_weight_t, ) fused_recurrent_cell_bias = ( @@ -1753,6 +2086,15 @@ def static_contiguous(tensor: torch.Tensor) -> torch.Tensor: max(1, self._recurrent_sender_kv_group_size), dim=0, ) + recurrent_sender_input_to_kv_weight_backend_order = None + if ( + torch.is_tensor(recurrent_input_to_kv_weight) + and self.population_backend_recurrent_order.numel() == self._num_recurrent_cells + and int(recurrent_input_to_kv_weight.shape[0]) >= self._num_recurrent_cells + ): + recurrent_sender_input_to_kv_weight_backend_order = recurrent_input_to_kv_weight[ + : self._num_recurrent_cells + ].index_select(0, self.population_backend_recurrent_order) backend_cell_tensors: dict[str, dict[str, torch.Tensor]] = {} if include_backend_cell_tensors: for population_name in self._population_names: @@ -1794,7 +2136,9 @@ def static_contiguous(tensor: torch.Tensor) -> torch.Tensor: fused_recurrent_value_to_cell_weight if fused_recurrent_population_input and torch.is_tensor(fused_recurrent_value_to_cell_weight) else static_contiguous( - value_to_cell_weight.transpose(0, 1).unsqueeze(0).expand(population_recurrent_count, -1, -1) + recurrent_message_to_cell_weight.transpose(0, 1) + .unsqueeze(0) + .expand(population_recurrent_count, -1, -1) ), ) sequence_population_input_bias = cast( @@ -1853,12 +2197,16 @@ def static_contiguous(tensor: torch.Tensor) -> torch.Tensor: "sender_group_input_to_kv_weight": sender_group_input_to_kv_weight, "recurrent_sender_input_to_kv_weight": recurrent_sender_input_to_kv_weight, "recurrent_group_input_to_kv_weight": recurrent_group_input_to_kv_weight, + "recurrent_sender_input_to_kv_weight_backend_order": recurrent_sender_input_to_kv_weight_backend_order, "input_group_input_to_kv_weight": input_group_input_to_kv_weight, + "message_to_cell_weight": self.msg_to_cell.weight, "value_to_cell_weight": value_to_cell_weight, + "recurrent_message_to_cell_weight_source": recurrent_message_to_cell_weight_source, "fused_recurrent_value_to_cell_weight": fused_recurrent_value_to_cell_weight, "fused_recurrent_cell_bias": fused_recurrent_cell_bias, "fused_recurrent_population_input": fused_recurrent_population_input, "value_to_output_weight": torch.einsum("dv,pdh->pvh", self.msg_out.weight, self.output_cell_weight), + **message_rule_static_tensors, "population_materialized": population_materialized, "backend_cell_tensors": backend_cell_tensors, } @@ -1944,11 +2292,22 @@ def init_state(self, batch: int, *, device: torch.device | str = "cpu", dtype: t state = TensorDict({}, batch_size=[]) state["cells"] = torch.zeros(batch, self.coords.shape[0], self.hidden_size, device=device, dtype=dtype) for population_name in self._population_names: - state[population_name] = self.population_modules[population_name].init_state( - batch=batch, - device=device, - dtype=dtype, - ) + state_names = self._compiled_transition_state_names_for_population(population_name) + default_state_names = self._default_population_state_names_for_population(population_name) + if state_names == default_state_names: + state[population_name] = self.population_modules[population_name].init_state( + batch=batch, + device=device, + dtype=dtype, + ) + else: + state[population_name] = self._init_population_state_for_names( + population_name, + state_names=state_names, + batch=batch, + device=torch.device(device), + dtype=dtype, + ) return state def reset_state(self, state: MaybeState, mask: ResetMask) -> MaybeState: @@ -1972,7 +2331,17 @@ def reset_state(self, state: MaybeState, mask: ResetMask) -> MaybeState: population_state = 
state.get(population_name) if population_state is None: continue - out[population_name] = self.population_modules[population_name].reset_state(population_state, batch_mask) + state_names = self._compiled_transition_state_names_for_population(population_name) + if state_names == self._default_population_state_names_for_population(population_name): + out[population_name] = self.population_modules[population_name].reset_state( + population_state, batch_mask + ) + elif isinstance(population_state, TensorDictBase): + out[population_name] = self._reset_population_state_for_names( + population_state, + state_names=state_names, + batch_mask=batch_mask, + ) return out def forward( @@ -2015,6 +2384,7 @@ def forward_cells( `[B, T, P, H]`, where `P` is the number of boundary input cells. """ self._last_backend_execution = None + self._last_temporal_execution_plan = None self._last_backend_launch_metadata = None if hidden_input is None and boundary_input is None: raise ValueError("forward_cells requires either hidden_input or boundary_input") @@ -2067,42 +2437,28 @@ def forward_cells( backend_population_name = None selected_backend_recurrence_surface = None planned_backend_recurrence_execution = None - supports_cuda_flat_bucket_sequence = False - sequence_surface_route = self._plan_sequence_surface_route( + supports_cuda_registered_temporal_program = False + temporal_plan = self._plan_temporal_execution( k=k, device=device, dtype=dtype, - ) - backend_population_name = self._select_output_cells_stream_backend_population( - k=k, - ) - selected_backend_recurrence_surface = self._select_backend_sequence_surface( + batch_size=batch_size, + time_steps=time_steps, training=grad_path, - k=k, - device=device, - dtype=dtype, - backend_population_name=backend_population_name, - sequence_surface_route=sequence_surface_route, - ) - supports_cuda_flat_bucket_sequence = sequence_surface_route.uses_flat_transition_buckets - selected_backend_name = select_fabric_backend( - configured_backend=str(self.config.backend), - device=device, - supports_cuda_backend=sequence_surface_route.supported, - ) - uses_flat_bucket_physical_static_contract = bool( - selected_backend_name == "cuda" - and selected_backend_recurrence_surface is None - and supports_cuda_flat_bucket_sequence - ) - backend_native_static_materialization = selected_backend_name == "cuda" and ( - selected_backend_recurrence_surface is not None or uses_flat_bucket_physical_static_contract + input_boundary="boundary" if boundary_seq is not None else "hidden", + output_boundary="sequence", + readout_output_boundary="cells", + output_contract="full_cells", + materialize_final_state=materialize_final_state, + state_is_fresh=backend_population_state_is_fresh, + has_resets=population_resets is not None, ) - flat_bucket_autograd_static_values = bool( - selected_backend_name == "cuda" - and selected_backend_recurrence_surface is None - and supports_cuda_flat_bucket_sequence - and not uses_flat_bucket_physical_static_contract + sequence_surface_route = temporal_plan.sequence_surface_route + supports_cuda_registered_temporal_program = sequence_surface_route.uses_registered_temporal_program + selected_backend_name = temporal_plan.executor.backend_name + use_fresh_backend_population_cache = temporal_plan.carry.fresh_state_population_cache + flat_bucket_autograd_static_values = ( + temporal_plan.static_values.static_value_mode == "flat_bucket_autograd_static_values" ) if grad_path: training_static_prepack = self._training_static_prepack_enabled() @@ -2110,7 +2466,7 @@ def 
forward_cells( self._last_backward_projection_mode = ( "fused_static_projection" if training_static_prepack else "factorized_recurrent_input" ) - if str(self.config.backend) == "pytorch": + if temporal_plan.static_values.static_value_mode == "pytorch_autograd_static_values": static_tensors = self._materialize_inference_static_tensors( device=device, dtype=dtype, @@ -2122,8 +2478,8 @@ def forward_cells( device=device, dtype=dtype, include_backend_prepack=training_static_prepack, - include_full_cell_kv_weight=not backend_native_static_materialization, - detach_static_tensors=not flat_bucket_autograd_static_values, + include_full_cell_kv_weight=temporal_plan.static_values.include_full_cell_kv_weight, + detach_static_tensors=temporal_plan.static_values.detach_training_static_tensors, ) self._last_training_static_tape_mode = ( "flat_bucket_autograd_static_values" @@ -2136,7 +2492,7 @@ def forward_cells( static_tensors = self._get_inference_static_tensors( device=device, dtype=dtype, - include_full_cell_kv_weight=not backend_native_static_materialization, + include_full_cell_kv_weight=temporal_plan.static_values.include_full_cell_kv_weight, ) cell_bias = static_tensors["cell_bias"] recurrent_cell_bias = static_tensors["recurrent_cell_bias"] @@ -2159,28 +2515,6 @@ def forward_cells( constant_k = self._resolve_constant_k_host(k) population_materialized = static_tensors["population_materialized"] self._active_backend_name = selected_backend_name - if selected_backend_name == "cuda" and selected_backend_recurrence_surface is not None: - planned_backend_recurrence_execution = self.plan_backend_execution( - batch_size=batch_size, - time_steps=time_steps, - inner_steps=1, - training=grad_path, - device=device, - surface_key=selected_backend_recurrence_surface.key, - ) - else: - selected_backend_recurrence_surface = None - planned_backend_recurrence_execution = None - use_fresh_backend_population_cache = bool( - backend_population_state_is_fresh - and selected_backend_name == "cuda" - and selected_backend_recurrence_surface is None - and supports_cuda_flat_bucket_sequence - and len(sequence_surface_route.active_populations) != 1 - and not grad_path - and not materialize_final_state - and constant_k == 1 - ) current_state = self._ensure_state( state, batch=batch_size, @@ -2188,11 +2522,7 @@ def forward_cells( dtype=dtype, include_population_state=not use_fresh_backend_population_cache, ) - if ( - selected_backend_name == "cuda" - and selected_backend_recurrence_surface is None - and supports_cuda_flat_bucket_sequence - ): + if selected_backend_name == "cuda" and supports_cuda_registered_temporal_program: return execute_temporal_bucket_sequence( self, hidden_seq=hidden_seq, @@ -2441,6 +2771,7 @@ def forward_output_cells_for_readout( if readout_output_boundary not in {"cells", "pooled"}: raise ValueError(f"Unsupported Fabric readout output boundary {readout_output_boundary!r}") self._last_backend_execution = None + self._last_temporal_execution_plan = None self._last_backend_launch_metadata = None projected_boundary_active = source_hidden_input is not None or input_projection_weight is not None source_hidden_seq: torch.Tensor | None = None @@ -2473,19 +2804,11 @@ def forward_output_cells_for_readout( f"input_projection_bias must have shape [{projected_features}], " f"got {tuple(input_projection_bias.shape)}" ) - projected_boundary_source_seq = source_hidden_seq - projected_boundary_weight = input_projection_weight - projected_boundary_bias = input_projection_bias boundary_seq = None - 
boundary_input_for_fallback = None else: if boundary_input is None: raise ValueError("forward_output_cells_for_readout requires boundary_input or projected source input") - projected_boundary_source_seq = None - projected_boundary_weight = None - projected_boundary_bias = None boundary_seq = boundary_input.unsqueeze(1) if boundary_input.dim() == 3 else boundary_input - boundary_input_for_fallback = boundary_input if boundary_seq is not None and boundary_seq.dim() != 4: raise ValueError( "forward_output_cells_for_readout expects boundary_input shaped [B,P,D] or [B,T,P,D], " @@ -2503,6 +2826,17 @@ def forward_output_cells_for_readout( msg_dim = int(self.hidden_size) device = source_hidden_seq.device dtype = source_hidden_seq.dtype + stage_reference = boundary_seq if boundary_seq is not None else cast(torch.Tensor, source_hidden_seq) + record_registered_memory_stage(self, stage_reference, "frontend_output_cells_entry") + record_frontend_tensor_bytes( + self, + stage="frontend_output_cells_entry", + tensors={ + "source_hidden_seq": source_hidden_seq, + "boundary_seq": boundary_seq, + "state": state, + }, + ) if port_count != self.input_cell_idx.numel(): raise ValueError( f"Runtime boundary_input count={port_count} must match input cells={self.input_cell_idx.numel()}" @@ -2512,41 +2846,38 @@ def forward_output_cells_for_readout( backend_population_state_is_fresh = state is None or not isinstance(state, TensorDictBase) grad_path = torch.is_grad_enabled() if training_semantics is None else bool(training_semantics) population_resets = _expand_resets_for_time(resets, batch_size=batch_size, time_steps=time_steps, device=device) - sequence_surface_route = self._plan_sequence_surface_route( + record_registered_memory_stage(self, stage_reference, "frontend_after_reset_expansion") + record_frontend_tensor_bytes( + self, + stage="frontend_after_reset_expansion", + tensors={ + "population_resets": population_resets, + "source_hidden_seq": source_hidden_seq, + "boundary_seq": boundary_seq, + }, + ) + temporal_plan = self._plan_temporal_execution( k=k, device=device, dtype=dtype, - ) - backend_population_name = self._select_output_cells_stream_backend_population( - k=k, - ) - selected_backend_sequence_surface = self._select_backend_sequence_surface( + batch_size=batch_size, + time_steps=time_steps, training=grad_path, - k=k, - device=device, - dtype=dtype, - backend_population_name=backend_population_name, - sequence_surface_route=sequence_surface_route, - ) - supports_cuda_flat_bucket_sequence = sequence_surface_route.uses_flat_transition_buckets - selected_backend_name = select_fabric_backend( - configured_backend=str(self.config.backend), - device=device, - supports_cuda_backend=sequence_surface_route.supported, - ) - uses_flat_bucket_physical_static_contract = bool( - selected_backend_name == "cuda" - and selected_backend_sequence_surface is None - and supports_cuda_flat_bucket_sequence - ) - backend_native_static_materialization = selected_backend_name == "cuda" and ( - selected_backend_sequence_surface is not None or uses_flat_bucket_physical_static_contract + input_boundary="projected_boundary" if projected_boundary_active else "boundary", + output_boundary=output_boundary, + readout_output_boundary=readout_output_boundary, + output_contract="pooled_output_cells" if readout_output_boundary == "pooled" else "output_cells", + materialize_final_state=materialize_final_state, + state_is_fresh=backend_population_state_is_fresh, + has_resets=population_resets is not None, ) - 
flat_bucket_autograd_static_values = bool( - selected_backend_name == "cuda" - and selected_backend_sequence_surface is None - and supports_cuda_flat_bucket_sequence - and not uses_flat_bucket_physical_static_contract + record_registered_memory_stage(self, stage_reference, "frontend_after_temporal_plan") + sequence_surface_route = temporal_plan.sequence_surface_route + supports_cuda_registered_temporal_program = sequence_surface_route.uses_registered_temporal_program + selected_backend_name = temporal_plan.executor.backend_name + use_fresh_backend_population_cache = temporal_plan.carry.fresh_state_population_cache + flat_bucket_autograd_static_values = ( + temporal_plan.static_values.static_value_mode == "flat_bucket_autograd_static_values" ) static_tensors: dict[str, object] | None = None @@ -2560,7 +2891,7 @@ def resolve_static_tensors() -> dict[str, object]: self._last_backward_projection_mode = ( "fused_static_projection" if training_static_prepack else "factorized_recurrent_input" ) - if str(self.config.backend) == "pytorch": + if temporal_plan.static_values.static_value_mode == "pytorch_autograd_static_values": static_tensors = self._materialize_inference_static_tensors( device=device, dtype=dtype, @@ -2572,8 +2903,8 @@ def resolve_static_tensors() -> dict[str, object]: device=device, dtype=dtype, include_backend_prepack=training_static_prepack, - include_full_cell_kv_weight=not backend_native_static_materialization, - detach_static_tensors=not flat_bucket_autograd_static_values, + include_full_cell_kv_weight=temporal_plan.static_values.include_full_cell_kv_weight, + detach_static_tensors=temporal_plan.static_values.detach_training_static_tensors, ) self._last_training_static_tape_mode = ( "flat_bucket_autograd_static_values" @@ -2586,245 +2917,140 @@ def resolve_static_tensors() -> dict[str, object]: static_tensors = self._get_inference_static_tensors( device=device, dtype=dtype, - include_full_cell_kv_weight=not backend_native_static_materialization, + include_full_cell_kv_weight=temporal_plan.static_values.include_full_cell_kv_weight, ) return static_tensors self._active_backend_name = selected_backend_name - if selected_backend_name == "cuda" and selected_backend_sequence_surface is not None: - route_static_tensors = cast( - dict[str, object], - self._detach_backend_static_tensors(resolve_static_tensors()), + if projected_boundary_active and source_hidden_seq is not None and boundary_seq is None: + assert input_projection_weight is not None + record_registered_memory_stage(self, source_hidden_seq, "frontend_before_boundary_projection") + record_frontend_tensor_bytes( + self, + stage="frontend_before_boundary_projection", + tensors={ + "source_hidden_seq": source_hidden_seq, + "input_projection_weight": input_projection_weight, + "input_projection_bias": input_projection_bias, + }, ) - if grad_path: - self._last_training_static_tape_mode = "detached_shared_values" - current_state = ( - TensorDict({}, batch_size=[]) - if backend_population_state_is_fresh - else self._ensure_state(state, batch=batch_size, device=device, dtype=dtype) + boundary_seq = self._project_boundary_source_sequence( + source_hidden_seq, + input_projection_weight=input_projection_weight, + input_projection_bias=input_projection_bias, ) - planned_backend_sequence_execution = self.plan_backend_execution( - batch_size=batch_size, - time_steps=time_steps, - inner_steps=1, - training=grad_path, - tape_policy=tape_policy, + record_registered_memory_stage(self, boundary_seq, "frontend_after_boundary_projection") + 
record_frontend_tensor_bytes( + self, + stage="frontend_after_boundary_projection", + tensors={ + "source_hidden_seq": source_hidden_seq, + "boundary_seq": boundary_seq, + }, + ) + if selected_backend_name == "cuda" and supports_cuda_registered_temporal_program and boundary_seq is not None: + record_registered_memory_stage(self, boundary_seq, "frontend_before_ensure_state") + current_state = self._ensure_state( + state, + batch=batch_size, device=device, - surface_key=selected_backend_sequence_surface.key, + dtype=dtype, + include_population_state=not use_fresh_backend_population_cache, ) - if boundary_seq is None: - assert projected_boundary_source_seq is not None and projected_boundary_weight is not None - return self._run_backend_projected_sequence_surface( - state=current_state, - projected_boundary_source_seq=projected_boundary_source_seq, - projected_boundary_weight=projected_boundary_weight, - projected_boundary_bias=projected_boundary_bias, - static_tensors=route_static_tensors, - population_resets=population_resets, - input_sender_input_to_kv_weight=route_static_tensors["input_sender_input_to_kv_weight"], - input_group_input_to_kv_weight=route_static_tensors["input_group_input_to_kv_weight"], - backend_population_name=backend_population_name, - backend_population_state_is_fresh=backend_population_state_is_fresh, - materialize_final_state=materialize_final_state, - grad_path=grad_path, - selected_backend_surface=selected_backend_sequence_surface, - planned_backend_execution=planned_backend_sequence_execution, - output_boundary=output_boundary, - readout_output_boundary=readout_output_boundary, - output_chunk_consumer=output_chunk_consumer, - detach_internal_carry_after_output_chunk=detach_internal_carry_after_output_chunk, - ) - output_cells, next_state = self._execute_backend_sequence_surface( - state=current_state, + record_registered_memory_stage(self, boundary_seq, "frontend_after_ensure_state") + record_frontend_tensor_bytes( + self, + stage="frontend_after_ensure_state", + tensors={ + "current_state": current_state, + "boundary_seq": boundary_seq, + }, + ) + record_registered_memory_stage(self, boundary_seq, "frontend_before_static_tensors") + route_static_tensors = resolve_static_tensors() + record_registered_memory_stage(self, boundary_seq, "frontend_after_static_tensors") + record_frontend_tensor_bytes( + self, + stage="frontend_after_static_tensors", + tensors={ + "route_static_tensors": route_static_tensors, + "current_state": current_state, + "boundary_seq": boundary_seq, + }, + ) + record_registered_memory_stage(self, boundary_seq, "frontend_before_registered_execute") + output_cells, next_state = execute_temporal_bucket_sequence( + self, + hidden_seq=None, boundary_seq=boundary_seq, - projected_boundary_source_seq=projected_boundary_source_seq, - projected_boundary_weight=projected_boundary_weight, - projected_boundary_bias=projected_boundary_bias, - static_tensors=route_static_tensors, + state=current_state, population_resets=population_resets, - input_sender_input_to_kv_weight=route_static_tensors["input_sender_input_to_kv_weight"], - input_group_input_to_kv_weight=route_static_tensors["input_group_input_to_kv_weight"], - backend_population_name=backend_population_name, - backend_population_state_is_fresh=backend_population_state_is_fresh, - materialize_final_state=materialize_final_state, + step_reset_flags=None, + k=k, + constant_k=self._resolve_constant_k_host(k), + batch_size=batch_size, + time_steps=time_steps, + step_mode=False, + 
capture_active=bool(device.type == "cuda" and torch.cuda.is_current_stream_capturing()), + static_tensors=route_static_tensors, grad_path=grad_path, - selected_backend_surface=selected_backend_sequence_surface, - planned_backend_execution=planned_backend_sequence_execution, + materialize_final_state=materialize_final_state, + backend_population_state_is_fresh=backend_population_state_is_fresh, + use_fresh_backend_population_cache=use_fresh_backend_population_cache, + tape_policy=tape_policy, + output_contract="pooled_output_cells" if readout_output_boundary == "pooled" else "output_cells", output_boundary=output_boundary, - readout_output_boundary=readout_output_boundary, ) + if output_boundary == "terminal" and int(output_cells.shape[1]) != 1: + raise RuntimeError( + "Flat temporal sequence executor must materialize exactly one terminal output timestep" + ) if output_chunk_consumer is None: return output_cells, next_state output_chunk_consumer(output_cells, 0, int(output_cells.shape[1])) empty_output = output_cells.new_empty((batch_size, 0, *tuple(output_cells.shape[2:]))) return empty_output, next_state - if ( - projected_boundary_active - and source_hidden_seq is not None - and boundary_seq is None - and selected_backend_name == "cuda" - and selected_backend_sequence_surface is None - and supports_cuda_flat_bucket_sequence - and not grad_path - and not materialize_final_state - and backend_population_state_is_fresh - and self._resolve_constant_k_host(k) == 1 - ): - route_static_tensors = resolve_static_tensors() - active_output_result = execute_temporal_bucket_active_output_window( - self, - projected_boundary_source_seq=source_hidden_seq, - projected_boundary_weight=cast(torch.Tensor, projected_boundary_weight), - projected_boundary_bias=projected_boundary_bias, - resets=population_resets, - static_tensors=route_static_tensors, - output_boundary=output_boundary, - ) - if active_output_result is not None: - output_cells, next_state = active_output_result - if readout_output_boundary == "pooled": - output_cells = self._pool_output_ports(output_cells) - record_temporal_bucket_sequence_surface_execution( - self, - batch_size=batch_size, - time_steps=time_steps, - inner_steps=1, - training=grad_path, - output_boundary=output_boundary, - active_receiver_window_mode=str(self._flat_bucket_active_output_region_mode), - active_receiver_window_offset=str(int(self._flat_bucket_active_output_region_start)), - active_receiver_window_count=str(int(self._flat_bucket_active_output_region_count)), - ) - if output_chunk_consumer is None: - return output_cells, next_state - output_chunk_consumer(output_cells, 0, int(output_cells.shape[1])) - empty_output = output_cells.new_empty((batch_size, 0, *tuple(output_cells.shape[2:]))) - return empty_output, next_state - if projected_boundary_active and source_hidden_seq is not None and boundary_seq is None: - assert input_projection_weight is not None - boundary_seq = self._project_boundary_source_sequence( - source_hidden_seq, - input_projection_weight=input_projection_weight, - input_projection_bias=input_projection_bias, + if selected_backend_name == "pytorch" and boundary_seq is not None: + full_cells, next_state = self.forward_cells( + state=state, + resets=resets, + k=k, + boundary_input=boundary_seq, + training_semantics=training_semantics, + materialize_final_state=materialize_final_state, ) - boundary_input_for_fallback = boundary_seq - if ( - selected_backend_name == "cuda" - and selected_backend_sequence_surface is None - and 
supports_cuda_flat_bucket_sequence - and not materialize_final_state - and backend_population_state_is_fresh - ): - if boundary_seq is not None: - current_state = self._ensure_state(state, batch=batch_size, device=device, dtype=dtype) - route_static_tensors = resolve_static_tensors() - output_cells, next_state = execute_temporal_bucket_sequence( - self, - hidden_seq=None, - boundary_seq=boundary_seq, - state=current_state, - population_resets=population_resets, - step_reset_flags=None, - k=k, - constant_k=self._resolve_constant_k_host(k), - batch_size=batch_size, - time_steps=time_steps, - step_mode=False, - capture_active=bool(device.type == "cuda" and torch.cuda.is_current_stream_capturing()), - static_tensors=route_static_tensors, - grad_path=grad_path, - materialize_final_state=materialize_final_state, - backend_population_state_is_fresh=backend_population_state_is_fresh, - use_fresh_backend_population_cache=False, - tape_policy=tape_policy, - output_contract="pooled_output_cells" if readout_output_boundary == "pooled" else "output_cells", - output_boundary=output_boundary, - ) - if output_boundary == "terminal" and int(output_cells.shape[1]) > 1: - output_cells = output_cells[:, -1:] - if output_chunk_consumer is None: - return output_cells, next_state - output_chunk_consumer(output_cells, 0, int(output_cells.shape[1])) - empty_output = output_cells.new_empty((batch_size, 0, *tuple(output_cells.shape[2:]))) - return empty_output, next_state - selected_backend_sequence_surface = None - internal_materialize_state_for_output = bool(not materialize_final_state and selected_backend_name == "pytorch") - y_cells, next_state = self.forward_cells( - state=state, - resets=resets, - k=k, - boundary_input=boundary_input_for_fallback, - training_semantics=training_semantics, - materialize_final_state=materialize_final_state or internal_materialize_state_for_output, - ) - y_seq = y_cells.unsqueeze(1) if y_cells.dim() == 3 else y_cells - output_cells = self._select_output_cells(y_seq) - if output_boundary == "terminal" and int(output_cells.shape[1]) > 1: - output_cells = output_cells[:, -1:] - if output_chunk_consumer is not None: + if output_boundary == "terminal": + full_cells = full_cells[:, -1:, :, :] + output_cells = self._select_output_cells(full_cells) + if readout_output_boundary == "pooled": + output_cells = self._pool_output_ports(output_cells) + if output_chunk_consumer is None: + return output_cells, cast(TensorDict, next_state) output_chunk_consumer(output_cells, 0, int(output_cells.shape[1])) empty_output = output_cells.new_empty((batch_size, 0, *tuple(output_cells.shape[2:]))) - if detach_internal_carry_after_output_chunk: - return empty_output, TensorDict({}, batch_size=[]) - return empty_output, TensorDict({}, batch_size=[]) if internal_materialize_state_for_output else next_state - return output_cells, TensorDict({}, batch_size=[]) if internal_materialize_state_for_output else next_state + return empty_output, cast(TensorDict, next_state) + del detach_internal_carry_after_output_chunk + raise RuntimeError( + "Fabric sequence execution must enter the shared flat-bucket temporal engine; " + "direct Runtime.forward_cells sequence routing is disabled. 
" + f"backend={selected_backend_name}; " + f"uses_registered_temporal_sequence={supports_cuda_registered_temporal_program}; " + f"reason={temporal_plan.reason}" + ) - def _run_backend_projected_sequence_surface( + def _resolve_k( self, - *, - state: TensorDict, - projected_boundary_source_seq: torch.Tensor, - projected_boundary_weight: torch.Tensor, - projected_boundary_bias: torch.Tensor | None, - static_tensors: dict[str, object], - population_resets: torch.Tensor | None, - input_sender_input_to_kv_weight: torch.Tensor | None, - input_group_input_to_kv_weight: torch.Tensor | None, - backend_population_name: str | None, - backend_population_state_is_fresh: bool, - materialize_final_state: bool, - grad_path: bool, - selected_backend_surface: SupportedSurface, - planned_backend_execution: PlannedFabricExecution, - output_boundary: Literal["sequence", "terminal"] = "sequence", - readout_output_boundary: Literal["cells", "pooled"] = "cells", - output_chunk_consumer: _RuntimeOutputChunkConsumer | None = None, - detach_internal_carry_after_output_chunk: bool = False, - ) -> tuple[torch.Tensor, TensorDict]: - return super()._execute_backend_projected_source_sequence_surface( - state=state, - projected_boundary_source_seq=projected_boundary_source_seq, - projected_boundary_weight=projected_boundary_weight, - projected_boundary_bias=projected_boundary_bias, - static_tensors=static_tensors, - population_resets=population_resets, - input_sender_input_to_kv_weight=input_sender_input_to_kv_weight, - input_group_input_to_kv_weight=input_group_input_to_kv_weight, - backend_population_name=backend_population_name, - backend_population_state_is_fresh=backend_population_state_is_fresh, - materialize_final_state=materialize_final_state, - grad_path=grad_path, - selected_backend_surface=selected_backend_surface, - planned_backend_execution=planned_backend_execution, - output_boundary=output_boundary, - readout_output_boundary=readout_output_boundary, - output_chunk_consumer=output_chunk_consumer, - detach_internal_carry_after_output_chunk=detach_internal_carry_after_output_chunk, - ) - - def _resolve_k( - self, - k: int | torch.Tensor | None, + k: int | torch.Tensor | None, *, batch_size: int, time_steps: int, device: torch.device, ) -> tuple[torch.Tensor, int]: if k is None: - max_steps = int(self.config.default_k) + max_steps = int(self.config.execution.default_k) k_rows = torch.full((batch_size,), max_steps, device=device, dtype=torch.long) elif isinstance(k, int): - max_steps = max(0, min(int(self.config.k_max), int(k))) + max_steps = max(0, min(int(self.config.execution.k_max), int(k))) k_rows = torch.full((batch_size,), max_steps, device=device, dtype=torch.long) else: k_tensor = torch.as_tensor(k, device=device, dtype=torch.long) @@ -2841,15 +3067,15 @@ def _resolve_k( else: raise ValueError(f"k must be int, [B], or [B,T], got shape {tuple(k_tensor.shape)}") - k_rows = k_rows.clamp(min=0, max=int(self.config.k_max)) + k_rows = k_rows.clamp(min=0, max=int(self.config.execution.k_max)) max_steps = int(k_rows.max().item()) if k_rows.numel() > 0 else 0 return k_rows, max_steps def _resolve_constant_k_host(self, k: int | torch.Tensor | None) -> int | None: if k is None: - return int(self.config.default_k) + return int(self.config.execution.default_k) if isinstance(k, int): - return max(0, min(int(self.config.k_max), int(k))) + return max(0, min(int(self.config.execution.k_max), int(k))) return None def _resolve_step_k( @@ -2862,7 +3088,7 @@ def _resolve_step_k( device: torch.device, ) -> 
tuple[torch.Tensor, int]: if k is None: - k_rows = torch.full((batch_size,), int(self.config.default_k), device=device, dtype=torch.long) + k_rows = torch.full((batch_size,), int(self.config.execution.default_k), device=device, dtype=torch.long) elif isinstance(k, int): k_rows = torch.full((batch_size,), int(k), device=device, dtype=torch.long) else: @@ -2873,7 +3099,7 @@ def _resolve_step_k( k_rows = k_tensor[:, step_index] else: raise ValueError(f"k must be int, [B], or [B,T], got shape {tuple(k_tensor.shape)}") - k_rows = k_rows.clamp(min=0, max=int(self.config.k_max)) + k_rows = k_rows.clamp(min=0, max=int(self.config.execution.k_max)) max_steps = int(k_rows.max().item()) if k_rows.numel() > 0 else 0 return k_rows, max_steps @@ -2941,12 +3167,22 @@ def _ensure_state( for population_name in self._population_names: population_state = state.get(population_name) expected = torch.Size([self._population_num_cells(population_name), batch]) + state_names = self._compiled_transition_state_names_for_population(population_name) if population_state is None or population_state.batch_size != expected: - out[population_name] = self.population_modules[population_name].init_state( - batch=batch, - device=device, - dtype=dtype, - ) + if state_names == self._default_population_state_names_for_population(population_name): + out[population_name] = self.population_modules[population_name].init_state( + batch=batch, + device=device, + dtype=dtype, + ) + else: + out[population_name] = self._init_population_state_for_names( + population_name, + state_names=state_names, + batch=batch, + device=device, + dtype=dtype, + ) elif self._population_state_matches_device_dtype( population_name, population_state, @@ -2966,8 +3202,8 @@ def _population_state_matches_device_dtype( device: torch.device, dtype: torch.dtype, ) -> bool: - for state_name in self._cell_spec_for_population(population_name).state_schema.keys: - leaf = population_state[state_name] + for state_name in self._compiled_transition_state_names_for_population(population_name): + leaf = population_state.get(state_name) if not torch.is_tensor(leaf) or leaf.device != device or leaf.dtype != dtype: return False return True @@ -3019,7 +3255,8 @@ def _forward_stream_step( materialize_cells_state: bool = True, ) -> tuple[torch.Tensor, TensorDict]: current_state = state - if resets is not None and (capture_active or has_resets is True): + should_apply_reset_mask = resets is not None and (capture_active or has_resets is not False) + if should_apply_reset_mask: materialized_population_state = [] for population_name in self._population_names: population_state_value = state.get(population_name) @@ -3124,6 +3361,7 @@ def _forward_stream_step( input_group_input_to_kv_weight=input_group_input_to_kv_weight, population_materialized=population_materialized, step_population_state_cache=step_population_state_cache, + backend_static_tensors=backend_static_tensors, ) if boundary_step is not None: @@ -3152,7 +3390,7 @@ def _forward_stream_step( gathered_kv_weight=gathered_kv_weight, step_idx=step_idx + 1, ) - if hidden_step is not None and (self.config.inject_every_step or step_idx == 0): + if hidden_step is not None and (self.config.execution.inject_every_step or step_idx == 0): msg = self._inject_hidden_inputs(msg, hidden_step.unsqueeze(1)) population_input = self.msg_to_cell(msg) + cell_bias if use_packed_loop_cache: @@ -3463,7 +3701,7 @@ def _forward_stream_step_k1( final_k = k_all final_v = v_all else: - use_backend_order_transition_buckets = bool( + 
use_backend_order_message_tables = bool( self._active_backend_name != "pytorch" and backend_static_tensors is not None and self.population_backend_recurrent_order.numel() == self._num_recurrent_cells @@ -3471,37 +3709,37 @@ def _forward_stream_step_k1( ) recurrent_q_for_message = ( cast(torch.Tensor, backend_static_tensors["recurrent_q_backend_order"]) - if use_backend_order_transition_buckets + if use_backend_order_message_tables else recurrent_q ) recurrent_neighbor_idx_for_message = ( self.recurrent_neighbor_idx_backend_order - if use_backend_order_transition_buckets + if use_backend_order_message_tables else self.recurrent_neighbor_idx ) recurrent_neighbor_valid_for_message = ( self.recurrent_neighbor_valid_backend_order - if use_backend_order_transition_buckets + if use_backend_order_message_tables else self.recurrent_neighbor_valid ) recurrent_edge_distance_for_message = ( self.recurrent_edge_distance_backend_order - if use_backend_order_transition_buckets + if use_backend_order_message_tables else self.recurrent_edge_distance ) recurrent_edge_delay_for_message = ( self.recurrent_edge_delay_backend_order - if use_backend_order_transition_buckets + if use_backend_order_message_tables else self.recurrent_edge_delay ) recurrent_local_sender_idx_for_message = ( self.recurrent_local_sender_idx_backend_order - if use_backend_order_transition_buckets + if use_backend_order_message_tables else self.recurrent_local_sender_idx ) recurrent_local_receiver_idx_by_sender_for_message = ( self.recurrent_local_receiver_idx_by_sender_backend_order - if use_backend_order_transition_buckets + if use_backend_order_message_tables else self.recurrent_local_receiver_idx_by_sender ) if use_partitioned_sender_banks: @@ -3542,40 +3780,19 @@ def _forward_stream_step_k1( self._active_backend_name != "pytorch" and recurrent_msg.is_cuda and backend_static_tensors is not None ) if use_cuda_flat_bucket_transition_step: - if use_backend_order_transition_buckets: - recurrent_next_backend_order, next_population_state = ( - self._run_backend_order_transition_buckets_step( - recurrent_msg, - population_state, - resets=population_resets, - batch_size=cells_prev.shape[0], - static_tensors=backend_static_tensors, - step_population_state_cache=step_population_state_cache if all_active is True else None, - materialize_next_state=materialize_population_next_state, - ) - ) - recurrent_next = recurrent_next_backend_order.index_select( - 1, - self.population_backend_recurrent_inverse_order, - ) - else: - recurrent_next, next_population_state = self._run_transition_buckets_step( + raise RuntimeError( + "Fabric CUDA stream-step transition execution requires the registered temporal program route" + ) + else: + recurrent_input, recurrent_input_already_projected = ( + self._project_recurrent_message_to_cell_step_for_message_rule( recurrent_msg, - population_state, - resets=population_resets, - batch_size=cells_prev.shape[0], - static_tensors=backend_static_tensors, - step_population_state_cache=step_population_state_cache if all_active is True else None, - materialize_next_state=materialize_population_next_state, + value_to_cell_weight=value_to_cell_weight, + recurrent_cell_bias=recurrent_cell_bias, + fused_recurrent_value_to_cell_weight=fused_recurrent_value_to_cell_weight, + fused_recurrent_cell_bias=fused_recurrent_cell_bias, + fused_recurrent_population_input=fused_recurrent_population_input, ) - else: - recurrent_input = self._project_recurrent_message_to_cell_step( - recurrent_msg, - value_to_cell_weight=value_to_cell_weight, - 
recurrent_cell_bias=recurrent_cell_bias, - fused_recurrent_value_to_cell_weight=fused_recurrent_value_to_cell_weight, - fused_recurrent_cell_bias=fused_recurrent_cell_bias, - fused_recurrent_population_input=fused_recurrent_population_input, ) recurrent_next, next_population_state = self._run_population_updates_recurrent_step( recurrent_input, @@ -3584,7 +3801,7 @@ def _forward_stream_step_k1( batch_size=cells_prev.shape[0], population_materialized=population_materialized, step_population_state_cache=step_population_state_cache if all_active is True else None, - population_input_already_projected=fused_recurrent_population_input, + population_input_already_projected=recurrent_input_already_projected, ) if all_active is True: recurrent_mid = recurrent_next @@ -3773,7 +3990,7 @@ def _execute_backend_recurrence_surface_step( else population_resets.to(device=boundary_step.device, dtype=torch.bool) ).contiguous() output_seq, next_packed_state, recurrent_mid, recurrent_k, recurrent_v, input_k_seq, input_v_seq = ( - self._run_backend_sequence_surface_once( + self._execute_compiler_temporal_sequence_surface( population_name=backend_population_name, boundary_seq=boundary_step.unsqueeze(1), packed_state=packed_state, @@ -3911,6 +4128,7 @@ def _forward_stream_step_boundary_multistep( recurrent_group_size: int = 1, input_group_input_to_kv_weight: torch.Tensor | None = None, step_population_state_cache: dict[str, object] | None = None, + backend_static_tensors: dict[str, object] | None = None, ) -> tuple[torch.Tensor, TensorDict]: batch_size = cells_prev.shape[0] recurrent_mid = cells_prev[:, self.recurrent_cell_idx, :] @@ -3926,6 +4144,18 @@ def _forward_stream_step_boundary_multistep( ) use_packed_cache = step_population_state_cache is not None running_population_state = population_state + use_backend_order_message_tables = bool( + self._active_backend_name != "pytorch" + and self._partitioned_layout + and recurrent_mid.is_cuda + and backend_static_tensors is not None + and self.population_backend_recurrent_order.numel() == self._num_recurrent_cells + and torch.is_tensor(backend_static_tensors.get("recurrent_q_backend_order")) + ) + if use_backend_order_message_tables: + raise RuntimeError( + "Fabric CUDA boundary multistep transition execution requires the registered temporal program route" + ) for step_idx in range(max_steps): recurrent_k, recurrent_v = self._project_sender_kv_from_cells_step( @@ -3956,11 +4186,14 @@ def _forward_stream_step_boundary_multistep( step_idx=step_idx + 1, local_sender_idx=self.recurrent_local_sender_idx, local_receiver_idx_by_sender=self.recurrent_local_receiver_idx_by_sender, + owner_tag="recurrent", ) - recurrent_input = self._project_recurrent_message_to_cell_step( - recurrent_msg, - value_to_cell_weight=value_to_cell_weight, - recurrent_cell_bias=recurrent_cell_bias, + recurrent_input, recurrent_input_already_projected = ( + self._project_recurrent_message_to_cell_step_for_message_rule( + recurrent_msg, + value_to_cell_weight=value_to_cell_weight, + recurrent_cell_bias=recurrent_cell_bias, + ) ) recurrent_next, next_population_state = self._run_population_updates_recurrent_step( recurrent_input, @@ -3969,6 +4202,7 @@ def _forward_stream_step_boundary_multistep( batch_size=batch_size, population_materialized=population_materialized, step_population_state_cache=step_population_state_cache, + population_input_already_projected=recurrent_input_already_projected, ) active_rows = step_idx < k_rows recurrent_mid = torch.where(active_rows.view(-1, 1, 1), recurrent_next, 
recurrent_mid) @@ -4001,6 +4235,7 @@ def _forward_stream_step_boundary_multistep( step_idx=k_rows, local_sender_idx=self.output_local_sender_idx, local_receiver_idx_by_sender=self.output_local_receiver_idx_by_sender, + owner_tag="readout", ), value_to_output_weight=value_to_output_weight, ).to(dtype=cells_prev.dtype) @@ -4025,6 +4260,7 @@ def _forward_stream_step_boundary_multistep( step_idx=k_rows, local_sender_idx=self.output_local_sender_idx, local_receiver_idx_by_sender=self.output_local_receiver_idx_by_sender, + owner_tag="readout", ), value_to_output_weight=value_to_output_weight, ).to(dtype=cells_prev.dtype) @@ -4169,7 +4405,7 @@ def _select_output_cells(self, y_cells: torch.Tensor) -> torch.Tensor: def _pool_output_ports(self, port_y: torch.Tensor) -> torch.Tensor: return backend_pool_output_ports( port_y, - readout_pool=self.config.readout_pool, + readout_pool=self.config.readout.pool, readout_query=self.readout_query, ) @@ -4178,7 +4414,7 @@ def _backend_readout_config(self) -> ReadoutConfig: partitioned_layout=bool(self._partitioned_layout), output_slice=self._output_slice, output_cell_idx=self.output_cell_idx, - readout_pool=str(self.config.readout_pool), + readout_pool=str(self.config.readout.pool), ) def _select_output_cells_stream_backend_population( @@ -4186,11 +4422,10 @@ def _select_output_cells_stream_backend_population( *, k: int | torch.Tensor | None, ) -> str | None: - if self._resolve_constant_k_host(k) != 1: - return None - if not self._partitioned_layout: - return None - return self._full_recurrent_population_name + # Legacy sequence-surface callers may still ask for a population owner; + # the active shared temporal route is flat-bucket owned. + del k + return None def _plan_sequence_surface_route( self, @@ -4199,12 +4434,65 @@ def _plan_sequence_surface_route( device: torch.device, dtype: torch.dtype, ) -> SequenceSurfaceRoute: - return self._backend_planner.plan_sequence_surface_route( + return self._plan_temporal_execution( + k=k, + device=device, + dtype=dtype, + batch_size=1, + time_steps=1, + training=False, + input_boundary="unknown", + state_is_fresh=True, + has_resets=False, + configured_backend="auto", + ).sequence_surface_route + + def _plan_temporal_execution( + self, + *, + k: int | torch.Tensor | None, + device: torch.device, + dtype: torch.dtype, + batch_size: int, + time_steps: int, + training: bool, + input_boundary: str, + output_boundary: Literal["sequence", "terminal"] = "sequence", + readout_output_boundary: Literal["cells", "pooled"] = "cells", + output_contract: str = "full_cells", + materialize_final_state: bool = True, + state_is_fresh: bool = True, + has_resets: bool = False, + gradient_horizon_steps: int | None = None, + checkpoint_steps: int | None = None, + configured_backend: str | None = None, + ) -> TemporalExecutionPlan: + planned_gradient_horizon_steps = ( + self.config.execution.gradient_horizon_steps if gradient_horizon_steps is None else gradient_horizon_steps + ) + planned_checkpoint_steps = ( + self.config.execution.checkpoint_steps if checkpoint_steps is None else checkpoint_steps + ) + plan = self._backend_planner.plan_temporal_execution( device_type=device.type, dtype=str(dtype).removeprefix("torch."), partitioned_layout=bool(self._partitioned_layout), + configured_backend=str(self.config.execution.backend) if configured_backend is None else configured_backend, constant_k=self._resolve_constant_k_host(k), + time_steps=time_steps, + training=training, + input_boundary=input_boundary, + output_boundary=output_boundary, + 
            readout_output_boundary=readout_output_boundary,
+            output_contract=output_contract,
+            materialize_final_state=materialize_final_state,
+            state_is_fresh=state_is_fresh,
+            has_resets=has_resets,
+            gradient_horizon_steps=planned_gradient_horizon_steps,
+            checkpoint_steps=planned_checkpoint_steps,
         )
+        self._last_temporal_execution_plan = plan
+        return plan

     def _training_static_prepack_enabled(self) -> bool:
         population_name = self._full_recurrent_population_name
@@ -4215,36 +4503,6 @@ def _training_static_prepack_enabled(self) -> bool:
             return True
         return not any(op.name == "diag_rtu" for op in population_spec.transition_ir.ops)

-    def _select_backend_sequence_surface(
-        self,
-        *,
-        training: bool,
-        k: int | torch.Tensor | None,
-        device: torch.device,
-        dtype: torch.dtype,
-        backend_population_name: str | None,
-        sequence_surface_route: SequenceSurfaceRoute | None = None,
-    ) -> SupportedSurface | None:
-        del training
-        route = sequence_surface_route or self._plan_sequence_surface_route(
-            k=k,
-            device=device,
-            dtype=dtype,
-        )
-        if not route.uses_cell_recurrence_surface:
-            return None
-        if backend_population_name is None:
-            raise RuntimeError(
-                f"Supported Fabric {_population_display_name(self._full_recurrent_population_name)} recurrence surface "
-                "requires the backend-owned CUDA sequence surface; "
-                "silent fallback to non-recurrence paths is disabled"
-            )
-        surface = supported_surface_for_cell_type(cell_type=self._population_cell_types[backend_population_name])
-        population_spec = self._backend_population_specs[backend_population_name]
-        if surface.key not in population_spec.supported_surface_keys:
-            raise RuntimeError(f"Fabric cell population {backend_population_name} does not own surface {surface.key}")
-        return surface
-
     def _population_num_cells(self, population_name: str) -> int:
         return int(self._population_indices(population_name).numel())

@@ -4262,6 +4520,7 @@ def _init_backend_population_state(
             receivers=self._population_num_cells(population_name),
             device=device,
             dtype=dtype,
+            state_names=self._compiled_transition_state_names_for_population(population_name),
         )

     def _init_backend_population_state_for_receivers(
@@ -4274,7 +4533,7 @@ def _init_backend_population_state_for_receivers(
         dtype: torch.dtype,
         state_names: tuple[str, ...] | None = None,
     ) -> TensorDict:
-        state_names = state_names or self._cell_spec_for_population(population_name).state_schema.keys
+        state_names = state_names or self._compiled_transition_state_names_for_population(population_name)
         num_receivers = int(receivers)
         with torch.profiler.record_function("fabric.glue.backend_population_state_zero"):
             return TensorDict(
@@ -4291,7 +4550,15 @@ def _population_state_to_backend_state(
         population_name: str,
         population_state: TensorDictBase,
     ) -> TensorDict:
-        state_names = self._cell_spec_for_population(population_name).state_schema.keys
+        state_names = self._compiled_transition_state_names_for_population(population_name)
+        if not state_names:
+            if len(population_state.batch_size) == 2:
+                return TensorDict(
+                    {},
+                    batch_size=[int(population_state.batch_size[1]), int(population_state.batch_size[0])],
+                    device=population_state.device,
+                )
+            return TensorDict({}, batch_size=[0, 0], device=population_state.device)
         first = population_state[state_names[0]]
         batch_size = int(first.shape[1])
         num_receivers = int(first.shape[0])
@@ -4311,7 +4578,16 @@ def _backend_state_to_population_state(
         population_name: str,
         backend_state: Mapping[str, torch.Tensor],
     ) -> TensorDict:
-        state_names = self._cell_spec_for_population(population_name).state_schema.keys
+        state_names = self._compiled_transition_state_names_for_population(population_name)
+        if not state_names:
+            backend_batch_size = getattr(backend_state, "batch_size", ())
+            if len(backend_batch_size) == 2:
+                return TensorDict(
+                    {},
+                    batch_size=[int(backend_batch_size[1]), int(backend_batch_size[0])],
+                    device=getattr(backend_state, "device", None),
+                )
+            return TensorDict({}, batch_size=[0, 0], device=getattr(backend_state, "device", None))
         first = backend_state[state_names[0]]
         batch_size = int(first.shape[0])
         num_receivers = int(first.shape[1])
@@ -4387,6 +4663,90 @@ def _register_population_backend_order_buffers(self) -> None:
                 recurrent_local_valid_backend_order,
             ),
         )
+        (
+            recurrent_local_receiver_idx_by_sender_compact_backend_order,
+            recurrent_local_receiver_slot_idx_by_sender_compact_backend_order,
+        ) = _build_compact_sender_reverse_tables(
+            int(self.sender_cell_idx.numel()),
+            recurrent_local_sender_idx_backend_order,
+            recurrent_local_valid_backend_order,
+        )
+        self.register_buffer(
+            "recurrent_local_receiver_idx_by_sender_compact_backend_order",
+            recurrent_local_receiver_idx_by_sender_compact_backend_order,
+        )
+        self.register_buffer(
+            "recurrent_local_receiver_slot_idx_by_sender_compact_backend_order",
+            recurrent_local_receiver_slot_idx_by_sender_compact_backend_order,
+        )
+        recurrent_neighbor_idx_flat_bucket_carry_order = _remap_partitioned_recurrent_sender_indices(
+            self.recurrent_neighbor_idx.index_select(0, backend_order),
+            num_input_senders=int(self._num_input_cells),
+            backend_inverse_order=inverse_order,
+        )
+        recurrent_local_sender_idx_flat_bucket_carry_order = _remap_partitioned_recurrent_sender_indices(
+            recurrent_local_sender_idx_backend_order,
+            num_input_senders=int(self._num_input_cells),
+            backend_inverse_order=inverse_order,
+        )
+        self.register_buffer(
+            "recurrent_neighbor_idx_flat_bucket_carry_order",
+            recurrent_neighbor_idx_flat_bucket_carry_order,
+        )
+        self.register_buffer(
+            "recurrent_local_sender_idx_flat_bucket_carry_order",
+            recurrent_local_sender_idx_flat_bucket_carry_order,
+        )
+        self.register_buffer(
+            "recurrent_local_receiver_idx_by_sender_flat_bucket_carry_order",
+            _build_sender_reverse_table(
+                int(self.sender_cell_idx.numel()),
+                recurrent_local_sender_idx_flat_bucket_carry_order,
+                recurrent_local_valid_backend_order,
+            ),
+        )
+        (
+            recurrent_local_receiver_idx_by_sender_compact_flat_bucket_carry_order,
+            recurrent_local_receiver_slot_idx_by_sender_compact_flat_bucket_carry_order,
+        ) = _build_compact_sender_reverse_tables(
+            int(self.sender_cell_idx.numel()),
+            recurrent_local_sender_idx_flat_bucket_carry_order,
+            recurrent_local_valid_backend_order,
+        )
+        self.register_buffer(
+            "recurrent_local_receiver_idx_by_sender_compact_flat_bucket_carry_order",
+            recurrent_local_receiver_idx_by_sender_compact_flat_bucket_carry_order,
+        )
+        self.register_buffer(
+            "recurrent_local_receiver_slot_idx_by_sender_compact_flat_bucket_carry_order",
+            recurrent_local_receiver_slot_idx_by_sender_compact_flat_bucket_carry_order,
+        )
+        output_neighbor_idx_flat_bucket_carry_order = _remap_partitioned_recurrent_sender_indices(
+            self.output_neighbor_idx,
+            num_input_senders=int(self._num_input_cells),
+            backend_inverse_order=inverse_order,
+        )
+        output_local_sender_idx_flat_bucket_carry_order = _remap_partitioned_recurrent_sender_indices(
+            self.output_local_sender_idx,
+            num_input_senders=int(self._num_input_cells),
+            backend_inverse_order=inverse_order,
+        )
+        self.register_buffer(
+            "output_neighbor_idx_flat_bucket_carry_order",
+            output_neighbor_idx_flat_bucket_carry_order,
+        )
+        self.register_buffer(
+            "output_local_sender_idx_flat_bucket_carry_order",
+            output_local_sender_idx_flat_bucket_carry_order,
+        )
+        self.register_buffer(
+            "output_local_receiver_idx_by_sender_flat_bucket_carry_order",
+            _build_sender_reverse_table(
+                int(self.sender_cell_idx.numel()),
+                output_local_sender_idx_flat_bucket_carry_order,
+                self.output_local_valid,
+            ),
+        )
         self._register_flat_bucket_active_output_window_buffers()

     def _register_flat_bucket_active_output_window_buffers(self) -> None:
@@ -4586,10 +4946,6 @@ def __init__(
         self.num_readout_slots = self.runtime.readout_slots
         self.in_proj = nn.Linear(self.input_dim, self.num_input_cells * self.runtime.hidden_size)
         self.out_proj = nn.Linear(self.num_readout_slots * self.runtime.hidden_size, self.output_dim)
-        # Cell-specific defaults live behind the cell plugin; these are generic overrides.
-        self._sequence_checkpoint_target_bytes: int | None = None
-        self._sequence_checkpoint_state_overhead_factor: float | None = None
-        self._sequence_direct_grad_target_bytes: int | None = None

     @property
     def backend_ir(self):
@@ -4654,12 +5010,23 @@ def _forward_sequence_with_readout(
     ) -> tuple[torch.Tensor, TensorDict]:
         if output_boundary not in {"sequence", "terminal"}:
             raise ValueError(f"Unsupported Fabric sequence output boundary {output_boundary!r}")
+        record_registered_memory_stage(self.runtime, hidden_seq, "frontend_model_sequence_entry")
+        record_frontend_tensor_bytes(
+            self.runtime,
+            stage="frontend_model_sequence_entry",
+            tensors={
+                "hidden_seq": hidden_seq,
+                "state": state,
+                "resets": resets,
+            },
+        )
         readout_batch_tile_len, readout_batch_tile_reason = self._readout_pooled_batch_tile_len(
             hidden_seq,
             k=k,
             materialize_final_state=materialize_final_state,
             output_boundary=output_boundary,
         )
+        record_registered_memory_stage(self.runtime, hidden_seq, "frontend_model_after_tile_decision")
         if readout_batch_tile_len < int(hidden_seq.shape[0]):
             return self._forward_sequence_with_readout_batch_tiled(
                 hidden_seq,
@@ -4706,68 +5073,6 @@ def _stream_sequence_with_readout(
         if output_boundary not in {"sequence", "terminal"}:
             raise ValueError(f"Unsupported Fabric sequence output boundary {output_boundary!r}")
         batch_size = int(hidden_seq.shape[0])
-        time_steps = int(hidden_seq.shape[1])
-        checkpoint_state = state if isinstance(state, TensorDictBase) else TensorDict({}, batch_size=[])
-        time_chunk_len = self._sequence_reduction_checkpoint_chunk_len(
-            hidden_seq,
-            checkpoint_state,
-            output_boundary=output_boundary,
-        )
-        if 0 < time_chunk_len < time_steps:
-            resets_bt = _expand_resets_for_time(
-                resets,
-                batch_size=batch_size,
-                time_steps=time_steps,
-                device=hidden_seq.device,
-            )
-            running_state: TensorDictBase | None = state if isinstance(state, TensorDictBase) else None
-            final_state: TensorDictBase = TensorDict({}, batch_size=[])
-            for start in range(0, time_steps, time_chunk_len):
-                end = min(start + time_chunk_len, time_steps)
-                hidden_chunk = hidden_seq[:, start:end]
-                reset_chunk = None if resets_bt is None else resets_bt[:, start:end]
-                k_chunk = _slice_sequence_k(k, start=start, end=end, batch_size=batch_size, device=hidden_seq.device)
-                chunk_materialize_final_state = materialize_final_state or end < time_steps
-                chunk_output_boundary: Literal["sequence", "terminal"] = (
-                    output_boundary if end == time_steps else "sequence"
-                )
-
-                def consume_time_chunk(
-                    output_chunk: torch.Tensor,
-                    batch_start: int,
-                    batch_end: int,
-                    time_start: int,
-                    time_end: int,
-                    *,
-                    chunk_start: int = start,
-                ) -> None:
-                    output_consumer(
-                        output_chunk,
-                        batch_start,
-                        batch_end,
-                        chunk_start + time_start,
-                        chunk_start + time_end,
-                    )
-
-                next_state = self._stream_sequence_with_readout(
-                    hidden_chunk,
-                    running_state,
-                    resets=reset_chunk,
-                    k=k_chunk,
-                    training_semantics=training_semantics,
-                    materialize_final_state=chunk_materialize_final_state,
-                    tape_policy=tape_policy,
-                    output_boundary=chunk_output_boundary,
-                    output_consumer=consume_time_chunk,
-                    detach_internal_carry_after_output_chunk=False,
-                )
-                if end < time_steps:
-                    running_state = (
-                        _detach_tensordict(next_state) if detach_internal_carry_after_output_chunk else next_state
-                    )
-                else:
-                    final_state = next_state
-            return final_state if materialize_final_state else TensorDict({}, batch_size=[])
         readout_batch_tile_len, readout_batch_tile_reason = self._readout_pooled_batch_tile_len(
             hidden_seq,
             k=k,
@@ -4830,7 +5135,7 @@
def _readout_pooled_batch_tile_len( device=hidden_seq.device, dtype=hidden_seq.dtype, ) - if sequence_surface_route.uses_flat_transition_buckets: + if sequence_surface_route.uses_registered_temporal_program: decision = self.runtime._flat_bucket_sequence_readout_batch_tile_len( batch_size=batch_size, time_steps=time_steps, @@ -4852,7 +5157,7 @@ def _readout_pooled_batch_tile_len( readout_slots=readout_slots, hidden_size=int(self.runtime.hidden_size), materialize_final_state=materialize_final_state, - backend_sequence_surface_supported=sequence_surface_route.uses_cell_recurrence_surface, + backend_sequence_surface_supported=False, memory=self.runtime._cuda_memory_budget(hidden_seq.device), ) return int(decision.value), decision.reason @@ -5008,328 +5313,6 @@ def _annotate_readout_pooled_batch_tile( actual_launch_readout_modes=record.actual_launch_readout_modes + ("pooled_batch_tiled",), ) - def _forward_sequence_checkpointed( - self, - hidden_seq: torch.Tensor, - state: MaybeState, - *, - resets: Optional[ResetMask], - k: int | torch.Tensor | None, - materialize_final_state: bool = True, - output_boundary: Literal["sequence", "terminal"] = "sequence", - ) -> tuple[torch.Tensor, TensorDict]: - batch_size = hidden_seq.shape[0] - current_state = self.runtime._ensure_state( - state, - batch=batch_size, - device=hidden_seq.device, - dtype=hidden_seq.dtype, - ) - chunk_len = self._sequence_checkpoint_chunk_len(hidden_seq, current_state) - resets_bt = _expand_resets_for_time( - resets, - batch_size=batch_size, - time_steps=hidden_seq.shape[1], - device=hidden_seq.device, - ) - outputs: list[torch.Tensor] = [] - running_state = current_state - - for start in range(0, hidden_seq.shape[1], chunk_len): - end = min(start + chunk_len, hidden_seq.shape[1]) - hidden_chunk = hidden_seq[:, start:end] - reset_chunk = None if resets_bt is None else resets_bt[:, start:end] - k_chunk = _slice_sequence_k(k, start=start, end=end, batch_size=batch_size, device=hidden_seq.device) - chunk_materialize_final_state = materialize_final_state or end < hidden_seq.shape[1] - chunk_output_boundary: Literal["sequence", "terminal"] = ( - output_boundary if end == hidden_seq.shape[1] else "sequence" - ) - state_paths, state_batch_sizes, state_tensors = _flatten_tensordict(running_state) - grad_marker = hidden_seq.new_zeros((), requires_grad=True) - - def run_sequence( - hidden_piece: torch.Tensor, - *flat_inputs: torch.Tensor, - chunk_paths: tuple[tuple[str, ...], ...] 
= state_paths, - chunk_batch_sizes: dict[tuple[str, ...], torch.Size] = state_batch_sizes, - chunk_resets: torch.Tensor | None = reset_chunk, - chunk_k: int | torch.Tensor | None = k_chunk, - chunk_materialize: bool = chunk_materialize_final_state, - chunk_output_boundary: Literal["sequence", "terminal"] = chunk_output_boundary, - ) -> tuple[torch.Tensor, ...]: - state_values = flat_inputs[:-1] - next_input_state = _unflatten_tensordict(chunk_paths, chunk_batch_sizes, state_values) - pooled, next_state = self._forward_sequence_with_readout( - hidden_piece, - next_input_state, - resets=chunk_resets, - k=chunk_k, - training_semantics=True, - materialize_final_state=chunk_materialize, - output_boundary=chunk_output_boundary, - ) - _, _, next_tensors = _flatten_tensordict(next_state) - return (pooled, *next_tensors) - - checkpoint_outputs = checkpoint( - run_sequence, - hidden_chunk, - *state_tensors, - grad_marker, - use_reentrant=False, - preserve_rng_state=False, - ) - if output_boundary == "sequence" or end == hidden_seq.shape[1]: - outputs.append(checkpoint_outputs[0]) - if chunk_materialize_final_state: - running_state = _unflatten_tensordict(state_paths, state_batch_sizes, checkpoint_outputs[1:]) - else: - running_state = TensorDict({}, batch_size=[]) - return self.out_proj(torch.cat(outputs, dim=1)), running_state - - def _sequence_checkpoint_chunk_len( - self, - hidden_seq: torch.Tensor, - state: TensorDictBase, - ) -> int: - seq_len = int(hidden_seq.shape[1]) - if seq_len <= 1: - return seq_len - state_bytes = self._estimate_sequence_state_bytes(state, hidden_seq=hidden_seq) - if state_bytes <= 0: - return seq_len - estimated_per_step_bytes = int(math.ceil(state_bytes * self._sequence_checkpoint_overhead_factor())) - target_bytes = max(1, int(self._sequence_checkpoint_target_bytes_for_state())) - chunk_len = max(1, target_bytes // max(1, estimated_per_step_bytes)) - return min(seq_len, chunk_len) - - def _estimate_sequence_state_bytes( - self, - state: TensorDictBase, - *, - hidden_seq: torch.Tensor | None = None, - ) -> int: - _, _, state_tensors = _flatten_tensordict(state) - state_bytes = sum(int(t.numel()) * int(t.element_size()) for t in state_tensors) - if state_bytes > 0 or hidden_seq is None: - return state_bytes - return self._estimate_fresh_sequence_runtime_state_bytes( - batch_size=int(hidden_seq.shape[0]), - dtype=hidden_seq.dtype, - ) - - def _estimate_fresh_sequence_runtime_state_bytes( - self, - *, - batch_size: int, - dtype: torch.dtype, - ) -> int: - dtype_bytes = int(torch.empty((), dtype=dtype).element_size()) - hidden_size = int(self.runtime.hidden_size) - cells = int(self.runtime.coords.shape[0]) - total_elements = int(batch_size) * cells * hidden_size - for population_name in self.runtime._population_names: - population_cells = int(self.runtime._population_num_cells(population_name)) - if population_cells <= 0: - continue - state_leaf_count = len(self.runtime._cell_spec_for_population(population_name).state_schema.keys) - total_elements += int(batch_size) * population_cells * hidden_size * int(state_leaf_count) - return total_elements * dtype_bytes - - def _active_cell_sequence_memory_policy(self) -> dict[str, float | int]: - default_checkpoint_target_bytes = 32 << 30 - default_checkpoint_state_overhead_factor = 4.0 - default_direct_grad_target_bytes = 96 << 30 - default_output_overhead_factor = 6.0 - checkpoint_targets: list[int] = [] - checkpoint_overheads: list[float] = [] - direct_grad_targets: list[int] = [] - output_overheads: list[float] = [] - for 
population_name in self.spec.population_names: - if int(self.runtime._population_num_cells(population_name)) <= 0: - continue - metadata = get_cell_spec(self.spec.config.cell_populations[population_name].cell_type).metadata or {} - sequence_memory_policy = ( - metadata["sequence_memory_policy"] if "sequence_memory_policy" in metadata else None - ) - if isinstance(sequence_memory_policy, Mapping): - checkpoint_targets.append( - int( - sequence_memory_policy["checkpoint_target_bytes"] - if "checkpoint_target_bytes" in sequence_memory_policy - else default_checkpoint_target_bytes - ) - ) - checkpoint_overheads.append( - float( - sequence_memory_policy["checkpoint_state_overhead_factor"] - if "checkpoint_state_overhead_factor" in sequence_memory_policy - else default_checkpoint_state_overhead_factor - ) - ) - direct_grad_targets.append( - int( - sequence_memory_policy["direct_grad_target_bytes"] - if "direct_grad_target_bytes" in sequence_memory_policy - else default_direct_grad_target_bytes - ) - ) - output_overheads.append( - float( - sequence_memory_policy["checkpoint_output_overhead_factor"] - if "checkpoint_output_overhead_factor" in sequence_memory_policy - else default_output_overhead_factor - ) - ) - else: - checkpoint_targets.append(default_checkpoint_target_bytes) - checkpoint_overheads.append(default_checkpoint_state_overhead_factor) - direct_grad_targets.append(default_direct_grad_target_bytes) - output_overheads.append(default_output_overhead_factor) - policy: dict[str, float | int] = { - "checkpoint_target_bytes": ( - min(checkpoint_targets) if checkpoint_targets else default_checkpoint_target_bytes - ), - "checkpoint_state_overhead_factor": ( - max(checkpoint_overheads) if checkpoint_overheads else default_checkpoint_state_overhead_factor - ), - "direct_grad_target_bytes": ( - min(direct_grad_targets) if direct_grad_targets else default_direct_grad_target_bytes - ), - "checkpoint_output_overhead_factor": ( - max(output_overheads) if output_overheads else default_output_overhead_factor - ), - } - if self._sequence_checkpoint_target_bytes is not None: - policy["checkpoint_target_bytes"] = int(self._sequence_checkpoint_target_bytes) - if self._sequence_checkpoint_state_overhead_factor is not None: - policy["checkpoint_state_overhead_factor"] = float(self._sequence_checkpoint_state_overhead_factor) - if self._sequence_direct_grad_target_bytes is not None: - policy["direct_grad_target_bytes"] = int(self._sequence_direct_grad_target_bytes) - return policy - - def _active_cell_supports_direct_grad_sequence(self) -> bool: - if len(self.spec.population_names) != 1: - return False - population_name = self.spec.population_names[0] - metadata = get_cell_spec(self.spec.config.cell_populations[population_name].cell_type).metadata or {} - return bool(metadata.get("supports_direct_grad_sequence", False)) - - def _sequence_checkpoint_overhead_factor(self) -> float: - policy = self._active_cell_sequence_memory_policy() - return float(policy["checkpoint_state_overhead_factor"]) - - def _sequence_checkpoint_target_bytes_for_state(self) -> int: - policy = self._active_cell_sequence_memory_policy() - return int(policy["checkpoint_target_bytes"]) - - def _sequence_output_overhead_factor(self) -> float: - policy = self._active_cell_sequence_memory_policy() - return float(policy.get("checkpoint_output_overhead_factor", 6.0)) - - def _estimate_sequence_output_window_bytes( - self, - hidden_seq: torch.Tensor, - *, - output_boundary: Literal["sequence", "terminal"], - ) -> int: - time_steps = 
int(hidden_seq.shape[1]) if output_boundary == "sequence" else 1 - if time_steps <= 0: - return 0 - dtype_bytes = int(torch.empty((), dtype=hidden_seq.dtype).element_size()) - output_elements = int(hidden_seq.shape[0]) * int(time_steps) * int(self.output_dim) - return int(math.ceil(float(output_elements * dtype_bytes) * self._sequence_output_overhead_factor())) - - def _should_use_direct_grad_sequence( - self, - hidden_seq: torch.Tensor, - state: MaybeState, - *, - materialize_final_state: bool = True, - ) -> tuple[bool, TensorDictBase | None]: - if state is None and not materialize_final_state and int(hidden_seq.shape[1]) == 1: - population_name = self.runtime._full_recurrent_population_name - if ( - population_name is not None - and self.runtime._fresh_output_dependency_receiver_count( - population_name=population_name, - time_steps=int(hidden_seq.shape[1]), - fresh_state_virtualized=True, - ) - is not None - ): - return True, None - current_state = self.runtime._ensure_state( - state, - batch=hidden_seq.shape[0], - device=hidden_seq.device, - dtype=hidden_seq.dtype, - ) - if not self._active_cell_supports_direct_grad_sequence(): - return False, current_state - state_bytes = self._estimate_sequence_state_bytes(current_state) - if state_bytes <= 0: - return True, current_state - estimated_per_step_bytes = int(math.ceil(state_bytes * self._sequence_checkpoint_overhead_factor())) - estimated_window_bytes = int(hidden_seq.shape[1]) * estimated_per_step_bytes - return estimated_window_bytes <= int(self._sequence_direct_grad_target_bytes_for_state()), current_state - - def _should_use_direct_grad_reduced_sequence( - self, - hidden_seq: torch.Tensor, - state: MaybeState, - *, - materialize_final_state: bool = True, - output_boundary: Literal["sequence", "terminal"] = "sequence", - ) -> tuple[bool, TensorDictBase | None]: - current_state = self.runtime._ensure_state( - state, - batch=hidden_seq.shape[0], - device=hidden_seq.device, - dtype=hidden_seq.dtype, - ) - if not self._active_cell_supports_direct_grad_sequence(): - return False, current_state - state_bytes = self._estimate_sequence_state_bytes(current_state) - estimated_state_step_bytes = ( - 0 if state_bytes <= 0 else int(math.ceil(state_bytes * self._sequence_checkpoint_overhead_factor())) - ) - estimated_state_window_bytes = int(hidden_seq.shape[1]) * estimated_state_step_bytes - estimated_output_window_bytes = self._estimate_sequence_output_window_bytes( - hidden_seq, - output_boundary=output_boundary, - ) - direct_target_bytes = int(self._sequence_direct_grad_target_bytes_for_state()) - estimated_window_bytes = estimated_state_window_bytes + estimated_output_window_bytes - return estimated_window_bytes <= direct_target_bytes, current_state - - def _sequence_direct_grad_target_bytes_for_state(self) -> int: - policy = self._active_cell_sequence_memory_policy() - return int(policy["direct_grad_target_bytes"]) - - def _sequence_reduction_checkpoint_chunk_len( - self, - hidden_seq: torch.Tensor, - state: TensorDictBase, - *, - output_boundary: Literal["sequence", "terminal"] = "sequence", - ) -> int: - checkpoint_chunk_len = self._sequence_checkpoint_chunk_len(hidden_seq, state) - if output_boundary == "terminal": - return checkpoint_chunk_len - seq_len = int(hidden_seq.shape[1]) - if seq_len <= 1: - return seq_len - per_step_output_bytes = self._estimate_sequence_output_window_bytes( - hidden_seq[:, :1], - output_boundary="sequence", - ) - if per_step_output_bytes <= 0: - return checkpoint_chunk_len - target_bytes = max(1, 
int(self._sequence_checkpoint_target_bytes_for_state())) - output_chunk_len = max(1, target_bytes // per_step_output_bytes) - return min(seq_len, checkpoint_chunk_len, output_chunk_len) - def _reduce_sequence_outputs_direct( self, hidden_seq: torch.Tensor, @@ -5357,122 +5340,6 @@ def _reduce_sequence_outputs_direct( loss = output_reducer(output_seq, 0, int(hidden_seq.shape[1])) return loss, next_state - def _reduce_sequence_outputs_checkpointed( - self, - hidden_seq: torch.Tensor, - state: TensorDictBase | None, - *, - resets: Optional[ResetMask], - k: int | torch.Tensor | None, - materialize_final_state: bool, - output_boundary: Literal["sequence", "terminal"], - output_reducer: _ModelOutputReducer, - ) -> tuple[torch.Tensor, TensorDict]: - batch_size = hidden_seq.shape[0] - initial_state = ( - self.runtime._ensure_state( - state, - batch=batch_size, - device=hidden_seq.device, - dtype=hidden_seq.dtype, - ) - if isinstance(state, TensorDictBase) - else None - ) - chunk_len = self._sequence_reduction_checkpoint_chunk_len( - hidden_seq, - initial_state if initial_state is not None else TensorDict({}, batch_size=[]), - output_boundary=output_boundary, - ) - resets_bt = _expand_resets_for_time( - resets, - batch_size=batch_size, - time_steps=hidden_seq.shape[1], - device=hidden_seq.device, - ) - running_state: TensorDictBase | None = initial_state - reduced_loss: torch.Tensor | None = None - - for start in range(0, hidden_seq.shape[1], chunk_len): - end = min(start + chunk_len, hidden_seq.shape[1]) - hidden_chunk = hidden_seq[:, start:end] - reset_chunk = None if resets_bt is None else resets_bt[:, start:end] - k_chunk = _slice_sequence_k(k, start=start, end=end, batch_size=batch_size, device=hidden_seq.device) - chunk_materialize_final_state = materialize_final_state or end < hidden_seq.shape[1] - - if running_state is None: - pooled, next_state = self._forward_sequence_with_readout( - hidden_chunk, - None, - resets=reset_chunk, - k=k_chunk, - training_semantics=True, - materialize_final_state=chunk_materialize_final_state, - output_boundary=output_boundary, - ) - output_chunk = self.out_proj(pooled) - if output_boundary == "terminal" and end < int(hidden_seq.shape[1]): - chunk_loss = output_chunk.new_zeros(()) - elif output_boundary == "terminal": - chunk_loss = output_reducer(output_chunk, int(hidden_seq.shape[1]) - 1, int(hidden_seq.shape[1])) - else: - chunk_loss = output_reducer(output_chunk, start, end) - reduced_loss = chunk_loss if reduced_loss is None else reduced_loss + chunk_loss - running_state = next_state if chunk_materialize_final_state else TensorDict({}, batch_size=[]) - continue - - state_paths, state_batch_sizes, state_tensors = _flatten_tensordict(running_state) - grad_marker = hidden_seq.new_zeros((), requires_grad=True) - - def run_sequence_chunk( - hidden_piece: torch.Tensor, - *flat_inputs: torch.Tensor, - chunk_paths: tuple[tuple[str, ...], ...] 
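
The deleted checkpointed reducer follows a standard shape: slice the time axis into chunks, run each chunk under `torch.utils.checkpoint` so activations are recomputed in backward, and sum the per-chunk reduced losses. A stripped-down sketch of that shape; the chunk runner, carry-state threading, and reset/k slicing of the real code are elided, so treat the names as illustrative:

```python
import torch
from torch.utils.checkpoint import checkpoint


def chunked_reduced_loss(hidden_seq: torch.Tensor, run_chunk, chunk_len: int) -> torch.Tensor:
    """Sum per-chunk losses while recomputing chunk activations during backward."""
    total: torch.Tensor | None = None
    # Reentrant checkpointing needs at least one differentiable input, even when a
    # chunk starts from a detached carry state; the real code also threads the
    # flattened state tensors through the checkpointed callable.
    grad_marker = hidden_seq.new_zeros((), requires_grad=True)
    for start in range(0, hidden_seq.shape[1], chunk_len):
        hidden_chunk = hidden_seq[:, start : start + chunk_len]
        chunk_loss = checkpoint(
            run_chunk,
            hidden_chunk,
            grad_marker,
            use_reentrant=True,
            preserve_rng_state=False,
        )
        total = chunk_loss if total is None else total + chunk_loss
    if total is None:
        raise RuntimeError("expected at least one checkpoint chunk")
    return total
```
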
= state_paths, - chunk_batch_sizes: dict[tuple[str, ...], torch.Size] = state_batch_sizes, - chunk_resets: torch.Tensor | None = reset_chunk, - chunk_k: int | torch.Tensor | None = k_chunk, - chunk_materialize: bool = chunk_materialize_final_state, - chunk_start: int = start, - chunk_end: int = end, - ) -> tuple[torch.Tensor, ...]: - state_values = flat_inputs[:-1] - next_input_state = _unflatten_tensordict(chunk_paths, chunk_batch_sizes, state_values) - pooled, next_state = self._forward_sequence_with_readout( - hidden_piece, - next_input_state, - resets=chunk_resets, - k=chunk_k, - training_semantics=True, - materialize_final_state=chunk_materialize, - output_boundary=output_boundary, - ) - output_chunk = self.out_proj(pooled) - if output_boundary == "terminal" and chunk_end < int(hidden_seq.shape[1]): - reduced = output_chunk.new_zeros(()) - elif output_boundary == "terminal": - reduced = output_reducer(output_chunk, int(hidden_seq.shape[1]) - 1, int(hidden_seq.shape[1])) - else: - reduced = output_reducer(output_chunk, chunk_start, chunk_end) - _, _, next_tensors = _flatten_tensordict(next_state) - return (reduced, *next_tensors) - - checkpoint_outputs = checkpoint( - run_sequence_chunk, - hidden_chunk, - *state_tensors, - grad_marker, - use_reentrant=True, - preserve_rng_state=False, - ) - reduced_loss = checkpoint_outputs[0] if reduced_loss is None else reduced_loss + checkpoint_outputs[0] - if chunk_materialize_final_state: - running_state = _unflatten_tensordict(state_paths, state_batch_sizes, checkpoint_outputs[1:]) - else: - running_state = TensorDict({}, batch_size=[]) - if reduced_loss is None: - raise RuntimeError("sequence output reduction expected at least one checkpoint chunk") - return reduced_loss, running_state - def forward( self, hidden_input: Tensor, @@ -5489,92 +5356,7 @@ def forward( hidden_seq = hidden_input.unsqueeze(1) if step_mode else hidden_input if hidden_seq.dim() != 3: raise ValueError(f"Fabric expects hidden_input shaped [B,H] or [B,T,H], got {tuple(hidden_input.shape)}") - if step_mode: - sequence_surface_route = self.runtime._plan_sequence_surface_route( - k=k, - device=hidden_seq.device, - dtype=hidden_seq.dtype, - ) - use_fresh_backend_step = ( - not isinstance(state, TensorDictBase) - and not torch.is_grad_enabled() - and sequence_surface_route.supported - ) - if use_fresh_backend_step: - pooled, next_state = self._forward_sequence_with_readout( - hidden_seq, - None, - resets=resets, - k=k, - training_semantics=False, - materialize_final_state=materialize_final_state, - output_boundary="sequence", - ) - return self.out_proj(pooled.squeeze(1)), next_state - boundary_input = self.in_proj(hidden_seq).view( - hidden_seq.shape[0], hidden_seq.shape[1], self.num_input_cells, self.runtime.hidden_size - ) - y_cells, next_state = self.runtime.forward_cells( - state=state, - resets=resets, - k=k, - boundary_input=boundary_input.squeeze(1), - materialize_final_state=materialize_final_state, - ) - pooled = self.runtime._pool_output_cells(y_cells.unsqueeze(1)).squeeze(1).reshape(hidden_seq.shape[0], -1) - return self.out_proj(pooled), next_state - if torch.is_grad_enabled() and hidden_seq.shape[1] > 1: - sequence_surface_route = self.runtime._plan_sequence_surface_route( - k=k, - device=hidden_seq.device, - dtype=hidden_seq.dtype, - ) - if sequence_surface_route.supported: - pooled, next_state = self._forward_sequence_with_readout( - hidden_seq, - state if isinstance(state, TensorDictBase) else None, - resets=resets, - k=k, - training_semantics=True, - 
materialize_final_state=materialize_final_state, - output_boundary=output_boundary, - ) - return self.out_proj(pooled), next_state - use_direct_grad, prepared_state = self._should_use_direct_grad_sequence( - hidden_seq, - state, - materialize_final_state=materialize_final_state, - ) - if not use_direct_grad: - return self._forward_sequence_checkpointed( - hidden_seq, - prepared_state, - resets=resets, - k=k, - materialize_final_state=materialize_final_state, - output_boundary=output_boundary, - ) - if hidden_seq.requires_grad: - pooled, next_state = self._forward_sequence_with_readout( - hidden_seq, - prepared_state, - resets=resets, - k=k, - training_semantics=None, - materialize_final_state=materialize_final_state, - output_boundary=output_boundary, - ) - return self.out_proj(pooled), next_state - pooled, next_state = self._forward_sequence_with_readout( - hidden_seq, - prepared_state, - resets=resets, - k=k, - training_semantics=True, - materialize_final_state=materialize_final_state, - output_boundary=output_boundary, - ) - return self.out_proj(pooled), next_state + selected_output_boundary: Literal["sequence", "terminal"] = "sequence" if step_mode else output_boundary pooled, next_state = self._forward_sequence_with_readout( hidden_seq, state if isinstance(state, TensorDictBase) else None, @@ -5582,9 +5364,10 @@ def forward( k=k, training_semantics=None, materialize_final_state=materialize_final_state, - output_boundary=output_boundary, + output_boundary=selected_output_boundary, ) - return self.out_proj(pooled), next_state + output = self.out_proj(pooled) + return (output.squeeze(1) if step_mode else output), next_state def stream_sequence_outputs( self, @@ -5652,47 +5435,9 @@ def reduce_sequence_outputs( output_boundary="sequence", ) return output_reducer(y.unsqueeze(1), 0, 1), next_state - if not torch.is_grad_enabled() or hidden_seq.shape[1] <= 1: - return self._reduce_sequence_outputs_direct( - hidden_seq, - state if isinstance(state, TensorDictBase) else None, - resets=resets, - k=k, - materialize_final_state=materialize_final_state, - output_boundary=output_boundary, - output_reducer=output_reducer, - ) - if state is None: - return self._reduce_sequence_outputs_checkpointed( - hidden_seq, - None, - resets=resets, - k=k, - materialize_final_state=materialize_final_state, - output_boundary=output_boundary, - output_reducer=output_reducer, - ) - use_direct_grad, prepared_state = self._should_use_direct_grad_reduced_sequence( - hidden_seq, - state, - materialize_final_state=materialize_final_state, - output_boundary=output_boundary, - ) - if use_direct_grad: - return self._reduce_sequence_outputs_direct( - hidden_seq, - cast(TensorDictBase | None, prepared_state), - resets=resets, - k=k, - materialize_final_state=materialize_final_state, - output_boundary=output_boundary, - output_reducer=output_reducer, - ) - if prepared_state is None: - raise RuntimeError("checkpointed sequence reduction requires a prepared state") - return self._reduce_sequence_outputs_checkpointed( + return self._reduce_sequence_outputs_direct( hidden_seq, - prepared_state, + state if isinstance(state, TensorDictBase) else None, resets=resets, k=k, materialize_final_state=materialize_final_state, @@ -5999,6 +5744,48 @@ def _build_sender_reverse_table( return reverse +def _build_compact_sender_reverse_tables( + num_senders: int, + receiver_sender_idx: torch.Tensor, + receiver_valid: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + degree = int(receiver_sender_idx.shape[1]) + receiver_table = 
torch.full((int(num_senders), degree), -1, dtype=torch.int32) + slot_table = torch.full((int(num_senders), degree), -1, dtype=torch.int32) + write_count = torch.zeros(int(num_senders), dtype=torch.long) + receiver_idx, offset_idx = torch.nonzero(receiver_valid, as_tuple=True) + sender_idx = receiver_sender_idx[receiver_idx, offset_idx] + if bool((sender_idx < 0).any()): + raise ValueError("receiver_sender_idx must be non-negative on valid local edges") + for receiver, offset, sender in zip(receiver_idx.tolist(), offset_idx.tolist(), sender_idx.tolist(), strict=True): + sender_int = int(sender) + position = int(write_count[sender_int].item()) + if position >= degree: + raise ValueError("compact local sender reverse table exceeded receiver degree") + receiver_table[sender_int, position] = int(receiver) + slot_table[sender_int, position] = int(offset) + write_count[sender_int] += 1 + return receiver_table, slot_table + + +def _remap_partitioned_recurrent_sender_indices( + sender_idx: torch.Tensor, + *, + num_input_senders: int, + backend_inverse_order: torch.Tensor, +) -> torch.Tensor: + remapped = sender_idx.clone() + recurrent_mask = remapped >= int(num_input_senders) + if not bool(recurrent_mask.any()): + return remapped + recurrent_graph_idx = remapped[recurrent_mask].to(dtype=torch.long) - int(num_input_senders) + if bool((recurrent_graph_idx < 0).any()) or bool((recurrent_graph_idx >= backend_inverse_order.numel()).any()): + raise ValueError("partitioned recurrent sender index is outside the recurrent sender bank") + recurrent_backend_idx = backend_inverse_order.index_select(0, recurrent_graph_idx) + remapped[recurrent_mask] = (recurrent_backend_idx + int(num_input_senders)).to(dtype=remapped.dtype) + return remapped + + def _contiguous_recurrent_sender_window( *, num_senders: int, @@ -6025,33 +5812,6 @@ def _contiguous_recurrent_sender_window( return start, count, contiguous -def _build_local_sender_table( - *, - receiver_coords: torch.Tensor, - sender_lookup: torch.Tensor, - local_offsets: torch.Tensor, - local_valid: torch.Tensor, - coord_shape: tuple[int, ...], - wrap: bool, -) -> torch.Tensor: - receiver_coords_long = receiver_coords.to(torch.long) - local_offsets_long = local_offsets.to(torch.long) - sender_table = torch.full(local_valid.shape, -1, dtype=torch.long) - target_coords = receiver_coords_long[:, None, :] + local_offsets_long[None, :, :] - for dim, size in enumerate(coord_shape): - if wrap: - target_coords[..., dim] = torch.remainder(target_coords[..., dim], size) - else: - target_coords[..., dim].clamp_(0, size - 1) - target_flat = target_coords[..., 0] - for dim, size in enumerate(coord_shape[1:], start=1): - target_flat = target_flat * size + target_coords[..., dim] - sender_table[local_valid] = sender_lookup.index_select(0, target_flat[local_valid]) - if bool((sender_table[local_valid] < 0).any()): - raise ValueError("Local receiver subset contains a sender outside the compact sender set") - return sender_table - - def _detect_uniform_contiguous_groups(group_ids: torch.Tensor) -> tuple[torch.Tensor | None, int]: if group_ids.numel() == 0: return None, 0 diff --git a/src/cortical/scope/scene.py b/src/cortical/scope/scene.py index affe613a..208a6366 100644 --- a/src/cortical/scope/scene.py +++ b/src/cortical/scope/scene.py @@ -107,20 +107,33 @@ def from_spec(cls, spec: Any, *, title: str | None = None) -> Scene: else: edge_delay = spec.anatomy.edge_delay[neighbor_valid].to(torch.int16) - coord_shape = tuple(int(v) for v in spec.config.coord_shape) + coord_shape = 
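
A toy illustration of what the compact reverse tables encode: for every sender, the (receiver, slot) pairs that read from it, padded to the receiver degree with -1. The plain-dict construction below is only for intuition; the helper above packs the same pairs into fixed-degree int32 tables:

```python
import torch

# Toy receiver -> sender adjacency: 3 receivers, degree 2, -1 marks an empty slot.
receiver_sender_idx = torch.tensor([[0, 1], [1, -1], [0, -1]], dtype=torch.int32)
receiver_valid = receiver_sender_idx >= 0

reverse: dict[int, list[tuple[int, int]]] = {0: [], 1: []}
for receiver in range(receiver_sender_idx.shape[0]):
    for slot in range(receiver_sender_idx.shape[1]):
        if bool(receiver_valid[receiver, slot]):
            reverse[int(receiver_sender_idx[receiver, slot])].append((receiver, slot))

# Sender 0 is read by receiver 0 (slot 0) and receiver 2 (slot 0);
# sender 1 is read by receiver 0 (slot 1) and receiver 1 (slot 0).
assert reverse == {0: [(0, 0), (2, 0)], 1: [(0, 1), (1, 0)]}
```
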
tuple( + int(v) + for v in getattr(spec.anatomy, "metadata", {}).get( + "coord_shape", + (int(spec.anatomy.num_cells),), + ) + ) + config_graph = getattr(spec.config, "graph", None) + wrap = bool( + getattr(spec.anatomy, "metadata", {}).get( + "wrap", + getattr(spec.config, "wrap", getattr(config_graph, "wrap", False)), + ) + ) coords = _topology_coordinate_layout( spec.anatomy.coords.to(torch.float32), coord_shape=coord_shape, - wrap=bool(spec.config.wrap), + wrap=wrap, ) metadata = { "source": "fabric.Spec", "coord_dim": int(spec.anatomy.coord_dim), "coord_shape": coord_shape, - "wrap": bool(spec.config.wrap), - "local_radius_axes": _config_local_radius_axes(spec.config), - "layout": {"kind": "topology", "method": _topology_layout_method(coord_shape, wrap=bool(spec.config.wrap))}, - "hidden_size": int(spec.config.hidden_size), + "wrap": wrap, + "local_radius_axes": _config_local_radius_axes(spec.config, coord_dim=int(spec.anatomy.coord_dim)), + "layout": {"kind": "topology", "method": _topology_layout_method(coord_shape, wrap=wrap)}, + "hidden_size": _config_hidden_size(spec.config), "kv_group_count": int(spec.num_kv_groups), "graph_summary": _dataclass_to_dict(spec.graph_topology.summary()), } @@ -245,12 +258,24 @@ def _edge_offsets_from_degree(degree: torch.Tensor) -> torch.Tensor: return offsets -def _config_local_radius_axes(config: Any) -> tuple[float, ...]: +def _config_local_radius_axes(config: Any, *, coord_dim: int | None = None) -> tuple[float, ...]: if hasattr(config, "local_radius_axes"): return tuple(float(value) for value in config.local_radius_axes) - radius = float(getattr(config, "local_radius", 1.5)) - coord_dim = int(getattr(config, "coord_dim", len(getattr(config, "coord_shape", (0, 1))))) - return tuple(radius for _ in range(max(1, coord_dim))) + graph = getattr(config, "graph", None) + if graph is not None and hasattr(graph, "local_radius"): + radius = float(graph.local_radius()) + else: + radius = float(getattr(config, "local_radius", 1.5)) + if coord_dim is None: + coord_dim = int(getattr(config, "coord_dim", len(getattr(config, "coord_shape", (0, 1))))) + return tuple(radius for _ in range(max(1, int(coord_dim)))) + + +def _config_hidden_size(config: Any) -> int: + interface = getattr(config, "interface", None) + if interface is not None and hasattr(interface, "hidden_size"): + return int(interface.hidden_size) + return int(getattr(config, "hidden_size", 8)) def _looks_like_spec(source: object) -> bool: diff --git a/src/cortical/scope/sculpt_lattice.py b/src/cortical/scope/sculpt_lattice.py index cdacaef9..0e66836c 100644 --- a/src/cortical/scope/sculpt_lattice.py +++ b/src/cortical/scope/sculpt_lattice.py @@ -85,9 +85,19 @@ def lattice_controls() -> list[Control]: def lattice_config(params: dict[str, Any]): - from cortical.fabric.config import CellPopulationConfig, Config + from cortical.fabric.config import ( + CellPopulationConfig, + Config, + FabricInterfaceConfig, + InitializationConfig, + MessageConfig, + PopulationLayoutConfig, + ) + from cortical.fabric.graphs import lattice2d params = normalize_lattice_params(params) + if int(params["z_size"]) > 1: + raise ValueError("Fabric Config currently supports lattice2d graphs; z_size is scope scene metadata only") mode = str(params.get("population_mode", "slstm")) if mode == "mixed": slstm_mix = float(params.get("slstm_mix", 0.65)) @@ -102,17 +112,20 @@ def lattice_config(params: dict[str, Any]): else: cell_populations = {"slstm": CellPopulationConfig(cell_type="slstm")} population_mix = {"slstm": 1.0} + 
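
The scene builder now resolves layout facts through a fallback chain instead of requiring the old flat config fields. A small worked example of the precedence, using `SimpleNamespace` as a stand-in for the real config objects:

```python
from types import SimpleNamespace


def resolve_wrap(anatomy_metadata: dict | None, config) -> bool:
    """Anatomy metadata wins; then a flat config field; then the nested graph config."""
    config_graph = getattr(config, "graph", None)
    fallback = getattr(config, "wrap", getattr(config_graph, "wrap", False))
    return bool((anatomy_metadata or {}).get("wrap", fallback))


assert resolve_wrap({"wrap": True}, SimpleNamespace()) is True                      # metadata wins
assert resolve_wrap({}, SimpleNamespace(wrap=True)) is True                         # flat field next
assert resolve_wrap({}, SimpleNamespace(graph=SimpleNamespace(wrap=True))) is True  # nested graph last
assert resolve_wrap({}, SimpleNamespace()) is False                                 # default
```
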
connectivity: list[object] = [lattice2d.LocalRadius(radius=scalar_local_radius(params))] + if int(params["patch_edges_per_cell"]) > 0: + connectivity.append(lattice2d.PatchEdges(per_cell=int(params["patch_edges_per_cell"]))) return Config( - width=int(params["x_size"]), - height=int(params["y_size"]), - depth=int(params["z_size"]), - hidden_size=int(params["hidden_size"]), - local_radius=scalar_local_radius(params), - patch_edges_per_cell=int(params["patch_edges_per_cell"]), - wrap=bool(params["wrap"]), - cell_populations=cell_populations, - population_mix=population_mix, - seed=int(params["seed"]), + graph=lattice2d.Graph( + width=int(params["x_size"]), + height=int(params["y_size"]), + wrap=bool(params["wrap"]), + connectivity=tuple(connectivity), + ), + interface=FabricInterfaceConfig(hidden_size=int(params["hidden_size"])), + message=MessageConfig(projection_region_shape=(2, 2)), + populations=PopulationLayoutConfig(cell_populations=cell_populations, population_mix=population_mix), + initialization=InitializationConfig(seed=int(params["seed"])), ) @@ -184,26 +197,27 @@ def lattice_scene(params: dict[str, Any]) -> Scene: def export_lattice_python(params: dict[str, Any]) -> str: - try: - cfg = lattice_config(params) - except ModuleNotFoundError: - return _export_lattice_from_raw_params(params) - if cfg.depth > 1 or len(cfg.cell_populations) > 1: - return _export_config_python(cfg) - connectivity = [f" fabric.graphs.lattice2d.LocalRadius(radius={cfg.local_radius}),"] - if cfg.patch_edges_per_cell > 0: - connectivity.append(f" fabric.graphs.lattice2d.PatchEdges(per_cell={cfg.patch_edges_per_cell}),") - populations = ", ".join( - f"{name!r}: fabric.Population(cell=fabric.cells.{_cell_ctor(pop.cell_type)}(hidden_dim={cfg.hidden_size}))" - for name, pop in cfg.cell_populations.items() + params = normalize_lattice_params(params) + if int(params["z_size"]) > 1 or str(params.get("population_mode", "slstm")) == "mixed": + return _export_config_python_from_params(params) + radius = scalar_local_radius(params) + hidden_size = int(params["hidden_size"]) + patch_edges_per_cell = int(params["patch_edges_per_cell"]) + connectivity = [f" fabric.graphs.lattice2d.LocalRadius(radius={radius}),"] + if patch_edges_per_cell > 0: + connectivity.append(f" fabric.graphs.lattice2d.PatchEdges(per_cell={patch_edges_per_cell}),") + population_mode = str(params.get("population_mode", "slstm")) + cell_type = "axoncell" if population_mode == "axoncell" else "slstm" + populations = ( + f"{cell_type!r}: fabric.Population(cell=fabric.cells.{_cell_ctor(cell_type)}(hidden_dim={hidden_size}))" ) connectivity_body = "\n".join(connectivity) return f"""import cortical.fabric as fabric graph = fabric.graphs.lattice2d.Graph( - width={cfg.width}, - height={cfg.height}, - wrap={cfg.wrap}, + width={int(params["x_size"])}, + height={int(params["y_size"])}, + wrap={bool(params["wrap"])}, connectivity=[ {connectivity_body} ], @@ -211,11 +225,11 @@ def export_lattice_python(params: dict[str, Any]) -> str: ) blueprint = fabric.Blueprint( - interface=fabric.Interface(public_dim={cfg.hidden_size}, message_dim={cfg.hidden_size}), + interface=fabric.Interface(public_dim={hidden_size}, message_dim={hidden_size}), graph=graph, - inputs={{"tokens": fabric.Input(dim={cfg.hidden_size})}}, - outputs={{"prediction": fabric.Output(dim={cfg.hidden_size})}}, - message_passing=fabric.message_rules.DotProduct(head_dim={cfg.hidden_size}), + inputs={{"tokens": fabric.Input(dim={hidden_size})}}, + outputs={{"prediction": fabric.Output(dim={hidden_size})}}, 
+ message_passing=fabric.message_rules.DotProduct(head_dim={hidden_size}), ) """ @@ -328,80 +342,59 @@ def _sender_for_delta( def _export_lattice_from_raw_params(params: dict[str, Any]) -> str: - z_size = int(params.get("z_size", 1)) - population_mode = str(params.get("population_mode", "slstm")) - if z_size > 1 or population_mode == "mixed": - if population_mode == "mixed": - cell_populations = "'slstm': CellPopulationConfig(cell_type='slstm'), 'axoncell': CellPopulationConfig(cell_type='axoncell')" - slstm_mix = float(params.get("slstm_mix", 0.65)) - population_mix = f"'slstm': {slstm_mix}, 'axoncell': {1.0 - slstm_mix}" - else: - cell_type = "axoncell" if population_mode == "axoncell" else "slstm" - cell_populations = f"{cell_type!r}: CellPopulationConfig(cell_type={cell_type!r})" - population_mix = f"{cell_type!r}: 1.0" - return f"""from cortical.fabric.anatomy import init -from cortical.fabric.config import CellPopulationConfig, Config - -spec = init(Config( - width={int(params.get("x_size", 32))}, - height={int(params.get("y_size", 24))}, - depth={z_size}, - hidden_size={int(params.get("hidden_size", 8))}, - local_radius={scalar_local_radius(params)}, - patch_edges_per_cell={int(params.get("patch_edges_per_cell", 0))}, - wrap={bool(params.get("wrap", True))}, - cell_populations={{{cell_populations}}}, - population_mix={{{population_mix}}}, - seed={int(params.get("seed", 0))}, -)) -""" - cell_type = "axoncell" if params.get("population_mode") == "axoncell" else "slstm" - population_ctor = _cell_ctor(cell_type) - radius = scalar_local_radius(params) - hidden_size = int(params.get("hidden_size", 8)) - return f"""import cortical.fabric as fabric + return _export_config_python_from_params(normalize_lattice_params(params)) -graph = fabric.graphs.lattice2d.Graph( - width={int(params.get("x_size", 32))}, - height={int(params.get("y_size", 24))}, - wrap={bool(params.get("wrap", True))}, - connectivity=[ - fabric.graphs.lattice2d.LocalRadius(radius={radius}), - fabric.graphs.lattice2d.PatchEdges(per_cell={int(params.get("patch_edges_per_cell", 0))}), - ], - populations={{"core": fabric.Population(cell=fabric.cells.{population_ctor}(hidden_dim={hidden_size}))}}, -) -blueprint = fabric.Blueprint( - interface=fabric.Interface(public_dim={hidden_size}, message_dim={hidden_size}), - graph=graph, - inputs={{"tokens": fabric.Input(dim={hidden_size})}}, - outputs={{"prediction": fabric.Output(dim={hidden_size})}}, - message_passing=fabric.message_rules.DotProduct(head_dim={hidden_size}), -) -""" - - -def _export_config_python(cfg: Any) -> str: - cell_populations = ", ".join( - f"{name!r}: CellPopulationConfig(cell_type={pop.cell_type!r})" for name, pop in cfg.cell_populations.items() +def _export_config_python_from_params(params: dict[str, Any]) -> str: + mode = str(params.get("population_mode", "slstm")) + if mode == "mixed": + slstm_mix = float(params.get("slstm_mix", 0.65)) + cell_populations = ( + "'slstm': CellPopulationConfig(cell_type='slstm'), 'axoncell': CellPopulationConfig(cell_type='axoncell')" + ) + population_mix = f"'slstm': {slstm_mix!r}, 'axoncell': {(1.0 - slstm_mix)!r}" + else: + cell_type = "axoncell" if mode == "axoncell" else "slstm" + cell_populations = f"{cell_type!r}: CellPopulationConfig(cell_type={cell_type!r})" + population_mix = f"{cell_type!r}: 1.0" + patch_edges = "" + if int(params.get("patch_edges_per_cell", 0)) > 0: + patch_edges = f"\n lattice2d.PatchEdges(per_cell={int(params.get('patch_edges_per_cell', 0))})," + z_comment = ( + f"# z_size={int(params.get('z_size', 
1))} is represented by cortical.scope scene metadata; " + "current Fabric runtime Config consumes lattice2d graph facts.\n" + if int(params.get("z_size", 1)) > 1 + else "" ) - population_mix = ", ".join(f"{name!r}: {cfg.population_mix[name]!r}" for name in cfg.cell_populations) return f"""from cortical.fabric.anatomy import init -from cortical.fabric.config import CellPopulationConfig, Config - -spec = init(Config( - width={cfg.width}, - height={cfg.height}, - depth={cfg.depth}, - hidden_size={cfg.hidden_size}, - local_radius={cfg.local_radius}, - patch_edges_per_cell={cfg.patch_edges_per_cell}, - wrap={cfg.wrap}, - cell_populations={{{cell_populations}}}, - population_mix={{{population_mix}}}, - seed={cfg.seed}, -)) +from cortical.fabric.config import ( + CellPopulationConfig, + Config, + FabricInterfaceConfig, + InitializationConfig, + MessageConfig, + PopulationLayoutConfig, +) +from cortical.fabric.graphs import lattice2d + +{z_comment}config = Config( + graph=lattice2d.Graph( + width={int(params.get("x_size", 32))}, + height={int(params.get("y_size", 24))}, + wrap={bool(params.get("wrap", True))}, + connectivity=( + lattice2d.LocalRadius(radius={scalar_local_radius(params)}),{patch_edges} + ), + ), + interface=FabricInterfaceConfig(hidden_size={int(params.get("hidden_size", 8))}), + message=MessageConfig(projection_region_shape=(2, 2)), + populations=PopulationLayoutConfig( + cell_populations={{{cell_populations}}}, + population_mix={{{population_mix}}}, + ), + initialization=InitializationConfig(seed={int(params.get("seed", 0))}), +) +spec = init(config) """ diff --git a/tests/test_evaluation_run.py b/tests/test_evaluation_run.py index 11917251..2d690ff5 100644 --- a/tests/test_evaluation_run.py +++ b/tests/test_evaluation_run.py @@ -257,7 +257,8 @@ def test_fabric_stack_builder_exposes_patch_edges(): patch_max_distance=4.0, ) - assert stack.spec.config.patch_edges_per_cell == 2 - assert stack.spec.config.patch_min_dist == 2.0 - assert stack.spec.config.patch_max_dist == 4.0 + patch_edges = stack.spec.config.graph.patch_edges() + assert patch_edges.per_cell == 2 + assert patch_edges.min_distance == 2.0 + assert patch_edges.max_distance == 4.0 assert stack.spec.graph_topology.edge_count > 0 diff --git a/tests/test_fabric_anatomy.py b/tests/test_fabric_anatomy.py index 3d36f15d..37e3bf73 100644 --- a/tests/test_fabric_anatomy.py +++ b/tests/test_fabric_anatomy.py @@ -1,25 +1,63 @@ from __future__ import annotations -import cortical.fabric.anatomy as anatomy_mod +import pytest import torch + from cortical.fabric.anatomy import init -from cortical.fabric.config import CellPopulationConfig, Config +from cortical.fabric.config import ( + CellPopulationConfig, + Config, + FabricInterfaceConfig, + InitializationConfig, + MessageConfig, + PopulationLayoutConfig, +) +from cortical.fabric.graphs import flat, lattice2d +from cortical.fabric.population import Population +from cortical.fabric.cells import SLSTM + + +def _lattice_config( + *, + width: int, + height: int, + hidden_size: int = 8, + populations: dict[str, CellPopulationConfig] | None = None, + population_mix: dict[str, float] | None = None, + population_node_indices: dict[str, tuple[int, ...]] | None = None, + cell_arrangement: str = "random", + projection_region_shape: tuple[int, ...] 
| None = None, + graph: lattice2d.Graph | None = None, + seed: int = 0, +) -> Config: + populations = populations or {"slstm": CellPopulationConfig(cell_type="slstm")} + population_mix = population_mix or {next(iter(populations)): 1.0} + graph = graph or lattice2d.Graph(width=width, height=height) + return Config( + graph=graph, + interface=FabricInterfaceConfig(hidden_size=hidden_size), + message=MessageConfig(projection_region_shape=projection_region_shape), + populations=PopulationLayoutConfig( + cell_populations=populations, + population_mix=population_mix, + population_node_indices=population_node_indices, + cell_arrangement=cell_arrangement, # type: ignore[arg-type] + ), + initialization=InitializationConfig(seed=seed), + ) def test_init_fabric_builds_2d_anatomy_with_ports_and_neighbors(): spec = init( - Config( + _lattice_config( width=4, height=3, - hidden_size=8, - cell_populations={ + populations={ "slstm": CellPopulationConfig(cell_type="slstm"), "axoncell": CellPopulationConfig(cell_type="axoncell"), }, population_mix={"slstm": 0.5, "axoncell": 0.5}, projection_region_shape=(2, 1), - input_band_width=1, - output_band_width=1, seed=3, ) ) @@ -34,29 +72,41 @@ def test_init_fabric_builds_2d_anatomy_with_ports_and_neighbors(): assert bool((spec.anatomy.edge_distance[spec.anatomy.neighbor_valid] > 0).all()) -def test_init_fabric_builds_3d_anatomy(): +def test_init_fabric_builds_user_defined_flat_graph(): + graph = flat.Graph( + node_count=6, + inputs={"source": (0,)}, + outputs={"sink": flat.Output((5,))}, + recurrent=(1, 2, 3, 4), + edges=((1, 0), (2, 1), (3, 2), (4, 3), (5, 4)), + populations={"core": Population(cell=SLSTM(hidden_dim=4))}, + kv_group_ids=(0, 0, 1, 1, 2, 2), + ) spec = init( Config( - width=3, - height=2, - depth=2, - hidden_size=4, - projection_region_shape=(1, 1, 1), + graph=graph, + interface=FabricInterfaceConfig(hidden_size=4), + populations=PopulationLayoutConfig( + cell_populations={"core": CellPopulationConfig(cell_type="slstm")}, + population_mix={"core": 1.0}, + ), ) ) - assert spec.anatomy.coords.shape == (12, 3) - assert spec.anatomy.coord_dim == 3 - assert spec.anatomy.num_cells == 12 + assert spec.anatomy.coords.shape == (6, 1) + assert spec.anatomy.coord_dim == 1 + assert spec.input_cell_idx.tolist() == [0] + assert spec.output_cell_idx.tolist() == [5] + assert spec.recurrent_cell_idx.tolist() == [1, 2, 3, 4] + assert spec.graph_topology.summary().edge_count == 5 def test_init_fabric_supports_banded_family_arrangement(): spec = init( - Config( + _lattice_config( width=4, height=2, - hidden_size=8, - cell_populations={ + populations={ "axoncell": CellPopulationConfig(cell_type="axoncell"), "slstm": CellPopulationConfig(cell_type="slstm"), }, @@ -77,18 +127,15 @@ def test_init_fabric_supports_banded_family_arrangement(): def test_port_cells_use_source_sink_connectivity(): spec = init( - Config( + _lattice_config( width=4, height=3, - hidden_size=8, - cell_populations={ + populations={ "slstm": CellPopulationConfig(cell_type="slstm"), "axoncell": CellPopulationConfig(cell_type="axoncell"), }, population_mix={"slstm": 0.5, "axoncell": 0.5}, - input_band_width=1, - output_band_width=1, - wrap=False, + graph=lattice2d.Graph(width=4, height=3, wrap=False), seed=9, ) ) @@ -104,16 +151,10 @@ def test_port_cells_use_source_sink_connectivity(): def test_output_cells_do_not_read_directly_from_input_cells(): spec = init( - Config( + _lattice_config( width=3, height=8, - hidden_size=8, - cell_populations={"slstm": CellPopulationConfig(cell_type="slstm")}, - 
population_mix={"slstm": 1.0}, - input_band_width=1, - output_band_width=1, - local_radius=1.5, - wrap=True, + graph=lattice2d.Graph(width=3, height=8, connectivity=(lattice2d.LocalRadius(1.5),)), seed=7, ) ) @@ -128,13 +169,10 @@ def test_output_cells_do_not_read_directly_from_input_cells(): def test_explicit_boundary_indices_are_flat_graph_ports(): spec = init( - Config( + _lattice_config( width=4, height=4, - hidden_size=8, - input_cell_indices=(0, 5), - output_cell_indices=(10, 15), - wrap=False, + graph=lattice2d.Graph(width=4, height=4, inputs=(0, 5), outputs=(10, 15), wrap=False), ) ) @@ -152,28 +190,32 @@ def test_explicit_boundary_indices_are_flat_graph_ports(): def test_explicit_boundary_indices_validate_flat_node_sets(): - for kwargs in ( - {"input_cell_indices": (0, 0), "output_cell_indices": (14, 15)}, - {"input_cell_indices": (0, 1), "output_cell_indices": (1, 15)}, - {"input_cell_indices": (0, 16), "output_cell_indices": (14, 15)}, + for graph in ( + lattice2d.Graph(width=4, height=4, inputs=(0, 0), outputs=(14, 15)), + lattice2d.Graph(width=4, height=4, inputs=(0, 1), outputs=(1, 15)), + lattice2d.Graph(width=4, height=4, inputs=(0, 16), outputs=(14, 15)), ): - try: - Config(width=4, height=4, hidden_size=8, **kwargs) - except ValueError: - continue - raise AssertionError(f"expected invalid boundary indices to fail for {kwargs}") + with pytest.raises(ValueError): + init(_lattice_config(width=4, height=4, graph=graph)) def test_explicit_graph_edges_build_flat_topology(): spec = init( - Config( + _lattice_config( width=4, height=4, - hidden_size=8, - input_cell_indices=(0, 1), - output_cell_indices=(14, 15), - graph_edges=((2, 3), (3, 2), (4, 2), (14, 13), (15, 13)), - kv_group_ids=tuple(idx // 2 for idx in range(16)), + graph=lattice2d.Graph( + width=4, + height=4, + inputs=(0, 1), + outputs=(14, 15), + connectivity=( + lattice2d.ExplicitEdges( + edges=((2, 3), (3, 2), (4, 2), (14, 13), (15, 13)), + kv_group_ids=tuple(idx // 2 for idx in range(16)), + ), + ), + ), ) ) @@ -185,21 +227,12 @@ def test_explicit_graph_edges_build_flat_topology(): assert graph_summary.degree_histogram == ((0, 11), (1, 5)) -def test_init_fabric_large_shape_avoids_dense_pairwise_distance(monkeypatch) -> None: - def fail_pairwise(*args, **kwargs): - raise AssertionError("dense pairwise distances should not be used for anatomy construction") - - monkeypatch.setattr(anatomy_mod, "_pairwise_distances", fail_pairwise) - +def test_init_fabric_large_shape_avoids_dense_pairwise_distance() -> None: spec = init( - Config( + _lattice_config( width=128, height=128, - hidden_size=8, - local_radius=1.5, - patch_edges_per_cell=0, - input_band_width=1, - output_band_width=1, + graph=lattice2d.Graph(width=128, height=128, connectivity=(lattice2d.LocalRadius(1.5),)), ) ) diff --git a/tests/test_fabric_audit_runner.py b/tests/test_fabric_audit_runner.py new file mode 100644 index 00000000..8bc64de7 --- /dev/null +++ b/tests/test_fabric_audit_runner.py @@ -0,0 +1,409 @@ +from __future__ import annotations + +import json +import sys +from pathlib import Path + +_CORTICAL_ROOT = Path(__file__).resolve().parents[1] +if str(_CORTICAL_ROOT) not in sys.path: + sys.path.insert(0, str(_CORTICAL_ROOT)) + +from benchmarks.fabric.audit import ( # noqa: E402 + _case_requires_mixed_stack_baseline, + _cuda_temporal_owner_gate, + _mixed_stack_baseline_gate, + build_case_manifest, + build_parser, + run_audit, + select_reference_key, +) + + +def test_fabric_audit_manifest_uses_april21_reference_keys() -> None: + cases = 
build_case_manifest( + plan="t1-single-pop", + families=("slstm",), + target_params=(100_000_000,), + modes=("forward_backward",), + batches=(1024,), + seq_lens=(1,), + inner_steps=(1,), + hidden_sizes=(32,), + ) + + assert len(cases) == 1 + assert cases[0].reference_key == "h32_t1_bxparams" + assert cases[0].high_level_api_contract == "model_forward_external_loss_backward_optimizer_step" + assert cases[0].owner_stage == "R11" + assert cases[0].reset_mode == "absent" + + +def test_fabric_audit_manifest_expands_reset_present_cases() -> None: + cases = build_case_manifest( + plan="tk-scaling", + families=("slstm",), + target_params=(1_000_000,), + modes=("forward_backward",), + batches=(2,), + seq_lens=(4,), + inner_steps=(1, 2), + hidden_sizes=(8,), + gradient_horizon_steps=(2,), + checkpoint_steps=(None,), + reset_modes=("absent", "present"), + ) + + assert {case.reset_mode for case in cases} == {"absent", "present"} + assert {case.training_output_boundary for case in cases} == {"sequence", "terminal"} + assert all("P19" in case.prompt_requirements for case in cases) + assert all(case.case_id.endswith(f"_reset{case.reset_mode}") for case in cases) + + +def test_fabric_audit_manifest_keeps_k128_sweep_target() -> None: + k_sweep = (1, 2, 4, 8, 16, 32, 64, 128) + cases = build_case_manifest( + plan="tk-scaling", + families=("slstm",), + target_params=(1_000_000,), + modes=("forward_backward",), + batches=(2,), + seq_lens=(4,), + inner_steps=k_sweep, + hidden_sizes=(8,), + gradient_horizon_steps=(64,), + checkpoint_steps=(None,), + reset_modes=("present",), + population_modes=("single", "mixed"), + ) + + assert {case.inner_steps for case in cases} == set(k_sweep) + assert {case.training_output_boundary for case in cases} == {"sequence", "terminal"} + assert any( + case.inner_steps == 128 and case.population_mode == "mixed" and case.training_output_boundary == "terminal" + for case in cases + ) + assert any( + case.inner_steps == 128 and case.population_mode == "mixed" and case.training_output_boundary == "sequence" + for case in cases + ) + assert all(case.gradient_horizon_steps == 64 for case in cases) + assert all(case.checkpoint_steps is None for case in cases) + + +def test_fabric_audit_tk_manifest_includes_terminal_and_sequence_loss_boundaries() -> None: + cases = build_case_manifest( + plan="tk-scaling", + families=("slstm",), + target_params=(1_000_000,), + modes=("forward", "forward_backward"), + batches=(2,), + seq_lens=(4,), + inner_steps=(2,), + hidden_sizes=(8,), + gradient_horizon_steps=(2,), + checkpoint_steps=(None,), + ) + + assert {case.training_output_boundary for case in cases if case.mode == "forward"} == {"sequence"} + assert {case.training_output_boundary for case in cases if case.mode == "forward_backward"} == { + "sequence", + "terminal", + } + assert len({case.case_id for case in cases}) == len(cases) + + +def test_fabric_audit_mixed_stack_baseline_is_t1_k1_only() -> None: + t1_cases = build_case_manifest( + plan="t1-single-pop", + families=("slstm",), + target_params=(1_000_000,), + modes=("forward",), + batches=(2,), + seq_lens=(1,), + inner_steps=(1,), + hidden_sizes=(32,), + population_modes=("mixed",), + ) + tk_cases = build_case_manifest( + plan="tk-scaling", + families=("slstm",), + target_params=(1_000_000,), + modes=("forward_backward",), + batches=(2,), + seq_lens=(4,), + inner_steps=(1, 128), + hidden_sizes=(8,), + gradient_horizon_steps=(64,), + checkpoint_steps=(None,), + population_modes=("mixed",), + ) + + assert 
all(_case_requires_mixed_stack_baseline(case) for case in t1_cases) + assert not any(_case_requires_mixed_stack_baseline(case) for case in tk_cases) + + +def test_fabric_audit_manifest_expands_shared_population_modes() -> None: + cases = build_case_manifest( + plan="t1-single-pop", + families=("slstm",), + target_params=(100_000_000,), + modes=("forward",), + batches=(1024,), + seq_lens=(1,), + inner_steps=(1,), + hidden_sizes=(32,), + population_modes=("single", "mixed"), + ) + + assert {case.population_mode for case in cases} == {"single", "mixed"} + by_mode = {case.population_mode: case for case in cases} + assert by_mode["single"].owner_stage == "R11" + assert by_mode["mixed"].owner_stage == "R12" + assert "P11" in by_mode["single"].prompt_requirements + assert "P12" in by_mode["mixed"].prompt_requirements + assert "_popsingle_" in by_mode["single"].case_id + assert "_popmixed_" in by_mode["mixed"].case_id + + +def test_fabric_audit_streaming_reference_key_prefers_exact_april21_row() -> None: + key = select_reference_key( + family="slstm", + target_params=1_000_000_000, + mode="forward_backward", + batch_size=512, + seq_len=4096, + inner_steps=1, + hidden_size=32, + training_output_boundary="sequence", + ) + + assert key == "streaming_sequence_loss:slstm:1b:b512:t4096:h32" + + +def test_fabric_audit_cuda_temporal_owner_gate_fails_forbidden_python_scan() -> None: + gate = _cuda_temporal_owner_gate( + { + "status": "ok", + "mode": "forward_backward", + "planner_signature": { + "temporal_plan_forward_owners": ["python_autograd_scan"], + "temporal_plan_backward_owners": ["python_autograd_scan"], + }, + } + ) + + assert gate is not None + assert gate["reason"] == "forward_temporal_owner_not_registered_program" + + +def test_fabric_audit_cuda_temporal_owner_gate_passes_registered_program() -> None: + gate = _cuda_temporal_owner_gate( + { + "status": "ok", + "mode": "forward_backward", + "planner_signature": { + "temporal_plan_forward_owners": ["registered_fused_forward_program_cuda"], + "temporal_plan_backward_owners": ["registered_reverse_executor_bindings"], + "launch_temporal_scan_owners": ["registered_fused_forward_program_cuda"], + "launch_scan_implementations": ["registered_temporal_fused_forward_program_cuda"], + "backward_physical_op_executors": ["physical_temporal_bucket_sequence_backward"], + "temporal_primitive_executor_blockers": [], + }, + } + ) + + assert gate is None + + +def test_fabric_audit_cuda_temporal_owner_gate_fails_runtime_forward_relabel() -> None: + gate = _cuda_temporal_owner_gate( + { + "status": "ok", + "mode": "forward_backward", + "planner_signature": { + "temporal_plan_forward_owners": ["registered_fused_forward_program_cuda"], + "temporal_plan_backward_owners": ["registered_reverse_executor_bindings"], + "launch_temporal_scan_owners": ["python_autograd_scan"], + "launch_scan_implementations": ["registered_temporal_fused_forward_program_cuda"], + "backward_physical_op_executors": ["physical_temporal_bucket_sequence_backward"], + }, + } + ) + + assert gate is not None + assert gate["reason"] == "runtime_forward_temporal_owner_not_registered_program" + + +def test_fabric_audit_cuda_temporal_owner_gate_fails_primitive_executor_blockers() -> None: + gate = _cuda_temporal_owner_gate( + { + "status": "ok", + "mode": "forward_backward", + "planner_signature": { + "temporal_plan_forward_owners": ["registered_fused_forward_program_cuda"], + "temporal_plan_backward_owners": ["registered_reverse_executor_bindings"], + "launch_temporal_scan_owners": 
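
The exact-match streaming reference key asserted above has a simple composition. A hedged sketch of how such a key could be assembled; the billion-scale params tag is hard-coded here for illustration, while the real selector presumably normalizes arbitrary parameter targets and falls back to coarser keys when no exact April 21 row exists:

```python
def streaming_sequence_loss_key(
    family: str,
    target_params: int,
    batch_size: int,
    seq_len: int,
    hidden_size: int,
) -> str:
    # Illustrative only: maps 1_000_000_000 -> "1b" to reproduce the asserted key.
    params_tag = f"{target_params // 1_000_000_000}b"
    return f"streaming_sequence_loss:{family}:{params_tag}:b{batch_size}:t{seq_len}:h{hidden_size}"


assert streaming_sequence_loss_key("slstm", 1_000_000_000, 512, 4096, 32) == (
    "streaming_sequence_loss:slstm:1b:b512:t4096:h32"
)
```
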
["registered_fused_forward_program_cuda"], + "launch_scan_implementations": ["registered_temporal_fused_forward_program_cuda"], + "backward_physical_op_executors": ["physical_temporal_bucket_sequence_backward"], + "temporal_primitive_executor_blockers": [ + "primitive=message,bucket=*,reason=message_primitive_rows_missing_from_temporal_table", + ], + }, + } + ) + + assert gate is not None + assert gate["reason"] == "temporal_primitive_executor_blockers_present" + + +def test_fabric_audit_cuda_temporal_owner_gate_fails_forbidden_backward_timing() -> None: + gate = _cuda_temporal_owner_gate( + { + "status": "ok", + "mode": "forward_backward", + "planner_signature": { + "temporal_plan_forward_owners": ["registered_fused_forward_program_cuda"], + "temporal_plan_backward_owners": ["registered_reverse_executor_bindings"], + "launch_temporal_scan_owners": ["registered_fused_forward_program_cuda"], + "launch_scan_implementations": ["registered_temporal_fused_forward_program_cuda"], + "backward_physical_op_executors": ["physical_temporal_bucket_sequence_backward"], + "backward_owner_timing_ms": [ + "transition_message_reverse_table_device_loop:ms=1.000;count=1", + ], + }, + } + ) + + assert gate is not None + assert gate["reason"] == "forbidden_backward_temporal_owner_timing_present" + + +def test_fabric_audit_cuda_temporal_owner_gate_accepts_physical_sequence_executor_but_fails_timing() -> None: + gate = _cuda_temporal_owner_gate( + { + "status": "ok", + "mode": "forward_backward", + "planner_signature": { + "temporal_plan_forward_owners": ["registered_fused_forward_program_cuda"], + "temporal_plan_backward_owners": ["registered_reverse_executor_bindings"], + "launch_temporal_scan_owners": ["registered_fused_forward_program_cuda"], + "launch_scan_implementations": ["registered_temporal_fused_forward_program_cuda"], + "backward_physical_op_executors": ["physical_temporal_bucket_sequence_backward"], + "backward_owner_timing_ms": [ + "temporal_artifact_recompute:ms=1.000;count=1", + ], + }, + } + ) + + assert gate is not None + assert gate["reason"] == "forbidden_backward_temporal_owner_timing_present" + + +def test_fabric_audit_mixed_stack_gate_fails_when_stack_is_faster() -> None: + gate = _mixed_stack_baseline_gate( + { + "status": "ok", + "mixed_stack_baseline": {"status": "ok", "tokens_per_s": 20.0}, + "mixed_stack_status": "ok", + "mixed_stack_param_error": 0.01, + "mixed_fabric_stack_ratio": 0.5, + } + ) + + assert gate is not None + assert gate["status"] == "fail" + assert gate["reason"] == "mixed_fabric_tokens_not_above_stack" + + +def test_fabric_audit_mixed_stack_gate_passes_when_fabric_is_faster() -> None: + gate = _mixed_stack_baseline_gate( + { + "status": "ok", + "mixed_stack_baseline": {"status": "ok", "tokens_per_s": 10.0}, + "mixed_stack_status": "ok", + "mixed_stack_param_error": 0.01, + "mixed_fabric_stack_ratio": 1.25, + } + ) + + assert gate is not None + assert gate["status"] == "pass" + assert gate["mixed_fabric_stack_ratio"] == 1.25 + + +def test_fabric_audit_mixed_stack_gate_requires_matched_params() -> None: + gate = _mixed_stack_baseline_gate( + { + "status": "ok", + "mixed_stack_baseline": {"status": "ok", "tokens_per_s": 10.0}, + "mixed_stack_status": "ok", + "mixed_stack_param_error": 0.20, + "mixed_fabric_stack_ratio": 1.25, + } + ) + + assert gate is not None + assert gate["status"] == "fail" + assert gate["reason"] == "mixed_stack_params_not_matched" + + +def test_fabric_audit_shared_temporal_coverage_gate_rejects_single_only_manifest(tmp_path: Path) -> None: + args 
= build_parser().parse_args( + [ + "--plan", + "smoke", + "--out-dir", + str(tmp_path), + "--baseline-json", + str(tmp_path / "missing-april21.json"), + "--dry-run", + "--require-shared-temporal-coverage", + ] + ) + + assert run_audit(args) == 1 + summary = (tmp_path / "summary.json").read_text() + assert "shared_temporal_owner_requires_single_and_mixed_population_coverage" in summary + + +def test_fabric_audit_shared_temporal_coverage_gate_accepts_shared_manifest(tmp_path: Path) -> None: + args = build_parser().parse_args( + [ + "--plan", + "smoke", + "--out-dir", + str(tmp_path), + "--baseline-json", + str(tmp_path / "missing-april21.json"), + "--population-modes", + "single,mixed", + "--dry-run", + "--require-shared-temporal-coverage", + ] + ) + + assert run_audit(args) == 0 + summary = json.loads((tmp_path / "summary.json").read_text()) + assert summary["shared_temporal_coverage_gate"]["status"] == "pass" + assert summary["population_modes"] == ["mixed", "single"] + + +def test_fabric_audit_dry_run_writes_manifest_and_summary(tmp_path: Path) -> None: + args = build_parser().parse_args( + [ + "--plan", + "smoke", + "--out-dir", + str(tmp_path), + "--baseline-json", + str(tmp_path / "missing-april21.json"), + "--reset-modes", + "absent,present", + "--dry-run", + ] + ) + + assert run_audit(args) == 0 + assert (tmp_path / "manifest.json").exists() + assert (tmp_path / "summary.json").exists() diff --git a/tests/test_fabric_backend_boundaries.py b/tests/test_fabric_backend_boundaries.py new file mode 100644 index 00000000..0e3e11f3 --- /dev/null +++ b/tests/test_fabric_backend_boundaries.py @@ -0,0 +1,3068 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest +from pydantic import ValidationError + +from cortical.fabric.config import Config +from cortical.fabric.backend.cuda.sequence_surface.compiler.allocation_audit import ( + assert_registered_program_allocations_are_classified, +) +from cortical.fabric.backend.message_rules import validate_message_rule_lowering_catalog_header + + +TEMPORAL_ENGINE_SOURCES = ("src/cortical/fabric/backend/cuda/sequence_surface/compiler/tables.py",) + +DISALLOWED_TEMPORAL_ENGINE_SELECTORS = ( + "axon", + "benchmark", + "cell_kind", + "hidden_size_policy", + "mixed_pop", + "native_cell_kind", + "population_name", + "pop_name", + "single_pop", + "slstm", +) + + +def _function_source(source_text: str, function_name: str) -> str: + start = source_text.index(f"def {function_name}(") + next_def = source_text.find("\ndef ", start + 1) + return source_text[start:] if next_def < 0 else source_text[start:next_def] + + +def _registered_program_kernel_source_text(repo_root: Path) -> str: + scan_root = repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket" + source_names = ( + "flat_bucket_registered_program_kernels.cu", + "registered_program/common.cuh", + "registered_program/constants_and_checks.cuh", + "registered_program/executor_span_decode.cuh", + "registered_program/memory_runtime_buffers.cuh", + "registered_program/native_callable_bindings.cuh", + "registered_program/program_spans_and_handlers.cuh", + "registered_program/reverse_artifacts_and_resets.cuh", + "registered_program/program_tensor_access.cuh", + "registered_program/transition_device_kernels.cuh", + "registered_program/transition_math_helpers.cuh", + "registered_program/layout_kernels.cuh", + "registered_program/operator_declarations.cuh", + "registered_program/transition_forward_program.cuh", + "registered_program/forward_program.cuh", + 
"registered_program/native_callables/message_forward_strategies.cuh", + "registered_program/native_callables/readout_forward_strategies.cuh", + "registered_program/backward_surface_steps.cuh", + "registered_program/native_callables/message_reverse_strategies.cuh", + "registered_program/native_callables/readout_reverse_strategies.cuh", + "registered_program/transition_primitive_forward_ops.cuh", + "registered_program/transition_reverse_handlers.cuh", + "registered_program/transition_reverse_program.cuh", + "registered_program/parameter_reducer_program.cuh", + "registered_program/backward_program.cuh", + "registered_program/operator_exports.cuh", + ) + return "\n".join((scan_root / source_name).read_text(encoding="utf-8") for source_name in source_names) + + +def _cxx_inline_function_source(source_text: str, function_name: str) -> str: + start = source_text.index(f" {function_name}(") + next_inline = source_text.find("\ninline ", start + 1) + next_static = source_text.find("\nstatic ", start + 1) + candidates = tuple(index for index in (next_inline, next_static) if index >= 0) + end = min(candidates) if candidates else len(source_text) + return source_text[start:end] + + +def test_temporal_engine_table_sources_do_not_use_cell_or_benchmark_route_selectors() -> None: + repo_root = Path(__file__).resolve().parents[1] + violations: list[str] = [] + for relative_path in TEMPORAL_ENGINE_SOURCES: + source_path = repo_root / relative_path + source_text = source_path.read_text(encoding="utf-8").lower() + matches = [selector for selector in DISALLOWED_TEMPORAL_ENGINE_SELECTORS if selector in source_text] + if matches: + violations.append(f"{relative_path}: {', '.join(matches)}") + assert not violations, "temporal engine source selectors are not flat-bucket generic: " + "; ".join(violations) + table_text = (repo_root / TEMPORAL_ENGINE_SOURCES[0]).read_text(encoding="utf-8") + assert 'row.primitive not in {"gated_logspace_recurrence", "diag_rtu", "diagonal_recurrence"}' not in table_text + assert "temporal_transition_tape_kind(row.primitive) is None" in table_text + + +def test_lattice_config_cleanup_stays_out_of_backend_runtime() -> None: + repo_root = Path(__file__).resolve().parents[1] + fabric_root = repo_root / "src" / "cortical" / "fabric" + cuda_root = fabric_root / "backend" / "cuda" + runtime_core = (fabric_root / "runtime" / "core.py").read_text(encoding="utf-8") + backend_ir = (fabric_root / "backend" / "ir.py").read_text(encoding="utf-8") + anatomy = (fabric_root / "anatomy.py").read_text(encoding="utf-8") + lattice_anatomy = (fabric_root / "graphs" / "lattice_anatomy.py").read_text(encoding="utf-8") + cuda_init = (cuda_root / "__init__.py").read_text(encoding="utf-8") + projection_init = (cuda_root / "projection" / "__init__.py").read_text(encoding="utf-8") + + assert "graphs.lattice_anatomy" in anatomy + assert "def _build_local_sender_table(" not in runtime_core + assert "spec.config.coord_shape" not in runtime_core + assert "spec.config.wrap" not in runtime_core + assert "wrap=bool(spec.config.wrap)" not in backend_ir + assert "build_lattice_local_sender_table(" in lattice_anatomy + assert not (cuda_root / "reference").exists() + assert not (cuda_root / "registry").exists() + assert not (cuda_root / "message_passing").exists() + assert "fabric_local_message_cuda" not in cuda_init + assert "fabric_grouped_projection_cuda" not in cuda_init + assert "register_readout_backend" not in projection_init + with pytest.raises(ValidationError): + Config(width=4, height=4, hidden_size=8) # type: 
ignore[call-arg] + + +def test_blueprint_normalization_is_not_old_config_translation() -> None: + repo_root = Path(__file__).resolve().parents[1] + blueprint_text = (repo_root / "src" / "cortical" / "fabric" / "blueprint.py").read_text(encoding="utf-8") + normalize_body = _function_source(blueprint_text, "normalize") + + assert "_blueprint_to_config" not in blueprint_text + assert "_BlueprintLowering" in blueprint_text + assert "_lower_blueprint_declaration(source)" in normalize_body + assert "init(_runtime_section_container(lowering))" in normalize_body + assert "source.message_passing.to_ir(" not in normalize_body + assert "lowering.message_passing.to_ir(" in normalize_body + + +def test_cuda_message_rule_ir_distinguishes_context_gate_from_nudge() -> None: + repo_root = Path(__file__).resolve().parents[1] + nn_ir = (repo_root / "src/cortical/fabric/backend/cuda/nn/ir.cuh").read_text(encoding="utf-8") + message_rule_generator = (repo_root / "src/cortical/fabric/backend/message_rules.py").read_text(encoding="utf-8") + catalog = (repo_root / "src/cortical/fabric/backend/cuda/nn/message_rule_lowering_catalog.cuh").read_text( + encoding="utf-8" + ) + classifier = _cxx_inline_function_source(nn_ir, "classify_message_rule_lowering") + + validate_message_rule_lowering_catalog_header(catalog) + assert "enum class MessageRuleLoweringKind" not in nn_ir + assert "DotProductFixedSlotContextNudge" not in nn_ir + assert "DotProductFixedSlotContextGate" not in nn_ir + assert "Generated by cortical.fabric.backend.message_rules.message_rule_lowering_catalog_header_text" in catalog + assert "message_rule_matches_fixed_slot_context_dot_product" not in nn_ir + assert "message_rule_matches_dynamic_dot_product" not in nn_ir + assert "lowering_kind !=" not in nn_ir + assert "registered_message_rule_lowering_patterns_begin" in classifier + assert "message_rule_matches_lowering_pattern" in classifier + assert "MessageRuleLoweringKind::" not in classifier + assert '"message_query_nudge_scale"' not in classifier + assert '"message_query_context_gate"' not in classifier + assert '"message_query_nudge_scale"' in catalog + assert '"message_query_context_gate"' in catalog + assert "MessageRuleLoweringKind::" not in catalog + assert "_CPP_MESSAGE_LOWERING_KIND" not in message_rule_generator + assert "kDotProductFixedSlotContextNudgeLoweringId" in catalog + assert "kDotProductFixedSlotContextGateLoweringId" in catalog + assert "parameter_indices[-1]" not in message_rule_generator + assert "std::vector parameter_indices" in nn_ir + assert "pattern.parameter_indices[index]" in nn_ir + assert "kDotProductFixedSlotContextNudgeNode13ParameterIndices[] = {4, 5, 6}" in catalog + assert "kDotProductFixedSlotContextGateNode13ParameterIndices[] = {4, 5, 6}" in catalog + assert "has_receiver_public" not in nn_ir + assert "has_sender_slot" not in nn_ir + + +def test_fixed_temporal_scan_extension_sources_were_deleted() -> None: + repo_root = Path(__file__).resolve().parents[1] + scan_root = repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket" + + assert not (scan_root / "flat_bucket_temporal_scan_cuda.py").exists() + assert not (scan_root / "flat_bucket_temporal_scan_binding.cpp").exists() + assert not (scan_root / "flat_bucket_temporal_scan_kernels.cu").exists() + assert not (scan_root / "flat_bucket_layout_cuda.py").exists() + assert not (scan_root / "flat_bucket_layout_binding.cpp").exists() + assert not (scan_root / "flat_bucket_layout_kernels.cu").exists() + assert not (scan_root / 
"flat_bucket_public_projection_cuda.py").exists() + assert not (scan_root / "flat_bucket_public_projection_binding.cpp").exists() + assert not (scan_root / "flat_bucket_public_projection_kernels.cu").exists() + assert not (scan_root / "flat_bucket_readout_cuda.py").exists() + assert not (scan_root / "flat_bucket_readout_binding.cpp").exists() + assert not (scan_root / "flat_bucket_readout_kernels.cu").exists() + assert not (scan_root / "flat_bucket_temporal_forward_primitives.cuh").exists() + + +def test_fixed_transition_message_reverse_device_loop_was_deleted() -> None: + repo_root = Path(__file__).resolve().parents[1] + temporal_backward_root = repo_root / "src/cortical/fabric/backend/cuda/ops/temporal_backward" + + assert not (temporal_backward_root / "reverse_table.py").exists() + assert not (temporal_backward_root / "flat_bucket_temporal_backward_binding.cpp").exists() + assert not (temporal_backward_root / "flat_bucket_temporal_backward_kernels.cu").exists() + assert not (temporal_backward_root / "flat_bucket_temporal_recurrent_backward_kernels.cu").exists() + assert not (temporal_backward_root / "materialization.py").exists() + assert not (temporal_backward_root / "reductions.py").exists() + assert not (temporal_backward_root / "extension.py").exists() + + +def test_reverse_transition_native_handlers_use_logical_binding_schema() -> None: + repo_root = Path(__file__).resolve().parents[1] + registered_program_kernel_text = _registered_program_kernel_source_text(repo_root) + + for function_name in ( + "run_registered_gated_logspace_reverse_transition_handler", + "run_registered_diag_rtu_reverse_transition_handler", + ): + function_text = _cxx_inline_function_source(registered_program_kernel_text, function_name) + assert "native_callable_program_binding_for(" in function_text + assert "input_binding(" in function_text + assert "parameter_binding(" in function_text + assert "output_binding(" in function_text + assert "inputs[" not in function_text + assert "params[" not in function_text + assert "outputs[" not in function_text + tanh_function_text = _cxx_inline_function_source( + registered_program_kernel_text, + "run_registered_tanh_reverse_transition_handler", + ) + assert "native_callable_program_binding_for(" in tanh_function_text + assert "input_binding(" in tanh_function_text + assert "output_binding(" in tanh_function_text + assert "flat_bucket_registered_program_transition_tanh_backward_cuda(" in tanh_function_text + assert "inputs[" not in tanh_function_text + assert "params[" not in tanh_function_text + assert "outputs[" not in tanh_function_text + + +def test_forward_message_readout_handlers_use_native_strategy_access_schema() -> None: + repo_root = Path(__file__).resolve().parents[1] + registered_program_kernel_text = _registered_program_kernel_source_text(repo_root) + native_callables_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py" + ).read_text(encoding="utf-8") + + assert "logical_name=access.access_name" in native_callables_text + assert "surface=pattern.surface" in native_callables_text + for function_name in ( + "bind_neighborhood_attention_project_message_handler", + "bind_projection_reduction_boundary_readout_handler", + ): + function_text = _cxx_inline_function_source(registered_program_kernel_text, function_name) + assert "program_tensor_for_native_strategy_access(" in function_text + assert "native_callable_binding_schema_rows" in function_text + assert "forward_program_tensor_for_access_opcode" not in 
function_text + assert "kProgramAccessMessage" not in function_text + assert "kProgramAccessReadout" not in function_text + + +def test_forward_transition_access_uses_compiler_program_access_rows() -> None: + repo_root = Path(__file__).resolve().parents[1] + registered_program_kernel_text = _registered_program_kernel_source_text(repo_root) + forward_program_text = ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/forward_program.cuh" + ).read_text(encoding="utf-8") + access_text = ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/program_tensor_access.cuh" + ).read_text(encoding="utf-8") + + aggregate_block = forward_program_text[ + forward_program_text.index("registered fused forward transition aggregate input") + - 500 : forward_program_text.index("registered fused forward transition aggregate input") + 500 + ] + public_block = forward_program_text[ + forward_program_text.index("registered fused forward transition public output") + - 500 : forward_program_text.index("registered fused forward transition public output") + 500 + ] + assert "temporal_program_access_binding_by_opcode(" in aggregate_block + assert "kProgramAccessTransitionAggregatedMessageInput" in aggregate_block + assert "program_binding_for_native_strategy_access(" not in aggregate_block + assert "program_tensor_for_native_strategy_access(" not in aggregate_block + assert "temporal_program_access_binding_by_opcode(" in public_block + assert "kProgramAccessTransitionPublicStateOutput" in public_block + assert "program_binding_for_native_strategy_access(" not in public_block + assert "program_tensor_for_native_strategy_access(" not in public_block + assert "forward_program_tensor_for_access_opcode" not in registered_program_kernel_text + assert "reverse_program_tensor_for_access_opcode" not in registered_program_kernel_text + assert "temporal_program_access_binding_by_opcode" in access_text + assert "kProgramAccessTransitionAggregatedMessageInput" in registered_program_kernel_text + assert "kProgramAccessTransitionPublicStateOutput" in registered_program_kernel_text + clear_transition_outputs_text = _cxx_inline_function_source( + access_text, + "clear_forward_transition_output_binding_slots", + ) + assert "Binding rows do not carry a surface column" in clear_transition_outputs_text + assert "row[5] < 0" in clear_transition_outputs_text + assert "row[2] != kTransitionSurfaceOpcode" not in clear_transition_outputs_text + clear_dead_inputs_text = _cxx_inline_function_source( + access_text, + "clear_forward_transition_dead_input_binding_slots_after_primitive", + ) + carry_slot_alias_text = _cxx_inline_function_source( + access_text, + "forward_transition_binding_slot_aliases_active_state_carry_source", + ) + assert "forward_transition_binding_slot_aliases_active_state_carry_source(" in clear_dead_inputs_text + assert "source_tensor_index == tensor_index" in carry_slot_alias_text + transition_program_call = forward_program_text[ + forward_program_text.index( + "flat_bucket_registered_temporal_fused_forward_transition_program_cuda(" + ) : forward_program_text.index( + "append_registered_forward_memory_stage_row", + forward_program_text.index("flat_bucket_registered_temporal_fused_forward_transition_program_cuda("), + ) + ] + assert "!return_reverse_artifacts && !return_final_program_tensors" in transition_program_call + assert "forward_transition_has_active_state_carry_sources(" not in transition_program_call + + +def 
test_message_rule_runtime_materialization_is_declared_by_message_program() -> None: + repo_root = Path(__file__).resolve().parents[1] + runtime_text = (repo_root / "src/cortical/fabric/runtime/core.py").read_text(encoding="utf-8") + program_parameter_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_parameters.py" + ).read_text(encoding="utf-8") + executor_binding_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_bindings.py" + ).read_text(encoding="utf-8") + + assert "DOT_PRODUCT_FIXED_SLOT_CONTEXT_NUDGE" not in runtime_text + assert "DOT_PRODUCT_FIXED_SLOT_CONTEXT_NUDGE" not in program_parameter_text + assert "dot_product_fixed_slot_context_nudge" not in executor_binding_text + assert "compiled_lowering_kind" not in executor_binding_text + assert "_install_message_rule_runtime_state" in runtime_text + assert "message_program.runtime_modules" in runtime_text + assert "message_program.runtime_parameters" in runtime_text + assert "message_program.static_tensors" in runtime_text + assert "message_query_slot_proj:" not in runtime_text + assert "message_sender_slot_key_proj:" not in runtime_text + assert 'prefer_projected_message_input = str(getattr(message_program, "output_dim_role"' in program_parameter_text + assert 'key == "output_dim_role"' not in executor_binding_text + assert "parameter_binding.source_bindings" in executor_binding_text + + +def test_reverse_message_readout_helpers_use_native_strategy_access_schema() -> None: + repo_root = Path(__file__).resolve().parents[1] + source_text = "\n".join( + path.read_text(encoding="utf-8") + for path in ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/backward_surface_steps.cuh", + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/native_callables/message_reverse_strategies.cuh", + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/native_callables/readout_reverse_strategies.cuh", + ) + ) + broad_source_text = ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/backward_surface_steps.cuh" + ).read_text(encoding="utf-8") + + assert "program_tensor_for_native_strategy_access(" in source_text + assert "registered_native_strategy_row_for_span(" in source_text + assert "reverse_program_tensor_for_access_opcode" not in source_text + assert "kProgramAccessMessage" not in source_text + assert "kProgramAccessReadout" not in source_text + assert "RegisteredReverseMessageStrategy" in source_text + assert "RegisteredReverseReadoutStrategy" in source_text + assert "REGISTERED_TEMPORAL_NATIVE_REVERSE_MESSAGE_CATALOG" in source_text + assert "REGISTERED_TEMPORAL_NATIVE_REVERSE_READOUT_CATALOG" in source_text + assert "registered_reverse_message_strategy_for_native_row(" in source_text + assert "registered_reverse_readout_strategy_for_native_row(" in source_text + + helper_expectations = { + "registered_temporal_backward_readout_step_impl": "strategy.readout(", + "registered_temporal_backward_output_message_step_impl": "strategy.output_message(", + "registered_temporal_backward_recurrent_kv_projection_step_impl": "strategy.recurrent_kv(", + "registered_temporal_backward_recurrent_message_step_impl": "strategy.recurrent_message(", + "registered_temporal_backward_initial_recurrent_kv_projection_step_impl": "strategy.initial_recurrent_kv(", + 
"registered_temporal_backward_boundary_kv_projection_step_impl": "strategy.boundary_kv(", + } + for function_name, expected_dispatch in helper_expectations.items(): + function_text = _cxx_inline_function_source(broad_source_text, function_name) + assert expected_dispatch in function_text + assert "flat_bucket_registered_backward_partitioned_attention_cuda(" not in function_text + assert "flat_bucket_registered_backward_sparse_attention_cuda(" not in function_text + assert "flat_bucket_registered_backward_sender_kv_projection_cuda(" not in function_text + assert "flat_bucket_registered_backward_readout_layout_projection_cuda(" not in function_text + readout_step_text = _cxx_inline_function_source(broad_source_text, "registered_temporal_backward_readout_step_impl") + output_message_step_text = _cxx_inline_function_source( + broad_source_text, + "registered_temporal_backward_output_message_step_impl", + ) + assert "readout_value_to_output_weight" not in readout_step_text + assert "readout_output_query" not in output_message_step_text + readout_strategy_text = _cxx_inline_function_source( + source_text, + "run_projection_reduction_boundary_readout_backward", + ) + output_message_strategy_text = _cxx_inline_function_source( + source_text, + "run_projection_reduction_boundary_output_message_backward", + ) + assert "readout_value_to_output_weight" in readout_strategy_text + assert "readout_output_query" in output_message_strategy_text + + +def test_fused_cuda_launch_contract_is_compiler_owned() -> None: + repo_root = Path(__file__).resolve().parents[1] + program_execution_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_execution.py" + ).read_text(encoding="utf-8") + native_callables_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py" + ).read_text(encoding="utf-8") + executor_patterns_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py" + ).read_text(encoding="utf-8") + runtime_metadata_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py" + ).read_text(encoding="utf-8") + registered_executor_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py" + ).read_text(encoding="utf-8") + registered_program_cuda_text = ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py" + ).read_text(encoding="utf-8") + registered_program_binding_text = ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_binding.cpp" + ).read_text(encoding="utf-8") + registered_program_kernel_text = _registered_program_kernel_source_text(repo_root) + registered_native_catalog_text = ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_native_callables.cuh" + ).read_text(encoding="utf-8") + message_reverse_text = ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/native_callables/message_reverse_strategies.cuh" + ).read_text(encoding="utf-8") + parameter_reducer_program_text = ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/parameter_reducer_program.cuh" + ).read_text(encoding="utf-8") + runtime_executor_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py" + 
).read_text(encoding="utf-8") + + assert "class TemporalFusedCudaLaunchContract" in program_execution_text + assert "transition_program_layer_blocker_codes" in program_execution_text + assert "transition_program_layer_missing_symbols" in program_execution_text + assert "_transition_primitive_names_from_rows(" in program_execution_text + assert "registered_fused_program_requires_transition_primitives_callable_from_program_layer_cuda_body" in ( + program_execution_text + ) + assert "memory_plan: TemporalMemoryLivenessPlan" in program_execution_text + assert "primitive_rows" in program_execution_text + assert "forward_executor_rows" in program_execution_text + assert "reverse_executor_rows" in program_execution_text + assert "temporal_forward_executor_handler_rows_tensor(" in program_execution_text + assert "temporal_reverse_executor_handler_rows_tensor(" in program_execution_text + assert "_FORWARD_HANDLER_KIND_OPCODE" not in program_execution_text + assert "_REVERSE_HANDLER_KIND_OPCODE" not in program_execution_text + assert "temporal_executor_strategy_registry().match_forward(" in program_execution_text + assert "temporal_executor_strategy_registry().match_reverse(" in program_execution_text + assert "handler_kind_opcode:" in executor_patterns_text + assert "handler_capabilities:" in executor_patterns_text + assert "handler_effects:" in executor_patterns_text + assert "cxx_entrypoints:" in executor_patterns_text + assert "cxx_entrypoint_phases:" in executor_patterns_text + assert "_required_message_cxx_entrypoint_phases(" in executor_patterns_text + assert "_FORWARD_MESSAGE_IMPLEMENTATION_SYMBOLS" not in native_callables_text + assert "_FORWARD_READOUT_IMPLEMENTATION_SYMBOLS" not in native_callables_text + assert "_REVERSE_MESSAGE_IMPLEMENTATION_SYMBOLS" not in native_callables_text + assert "_REVERSE_READOUT_IMPLEMENTATION_SYMBOLS" not in native_callables_text + assert "_TRANSITION_REVERSE_IMPLEMENTATION_SYMBOL" not in native_callables_text + assert 'getattr(pattern, "cxx_entrypoints", ())' in native_callables_text + assert 'getattr(pattern, "cxx_entrypoint_phases", ())' in native_callables_text + forward_message_catalog = _function_source(native_callables_text, "_emit_forward_message_catalog") + reverse_message_catalog = _function_source(native_callables_text, "_emit_reverse_message_catalog") + assert "definition.cxx_entrypoint_phases" in forward_message_catalog + assert 'required_phases = ("bind", "recurrent_kv", "message")' in forward_message_catalog + assert '"keyless_readout_message"' in forward_message_catalog + assert '"direct_keyless_readout_message"' in forward_message_catalog + assert "_cxx_entrypoints_by_phase(" in reverse_message_catalog + assert "_cxx_entrypoints(definition, count=3)" not in forward_message_catalog + assert "_cxx_entrypoints(definition, count=4)" not in reverse_message_catalog + assert "forward_handler_rows" in program_execution_text + assert "reverse_handler_rows" in program_execution_text + assert "temporal_native_executor_strategy_rows_tensor(" in program_execution_text + assert "class TemporalNativeCallableDefinition" in native_callables_text + assert "class TemporalNativeCallableOutputDefinition" in native_callables_text + assert "temporal_native_callable_catalog_rows_tensor(" in native_callables_text + assert "temporal_native_callable_output_rows_tensor(" in native_callables_text + assert "class TemporalReverseSpanOutputRow" in program_execution_text + assert "temporal_reverse_span_output_rows_tensor(" in program_execution_text + assert "class 
TemporalReverseOutputRouteRow" in program_execution_text + assert "temporal_reverse_output_route_rows_tensor(" in program_execution_text + assert "class TemporalForwardArtifactRouteRow" in program_execution_text + assert "temporal_forward_artifact_route_rows_tensor(" in program_execution_text + assert "class TemporalForwardArtifactMergeRow" in program_execution_text + assert "temporal_forward_artifact_merge_rows_tensor(" in program_execution_text + assert "class TemporalForwardOutputRouteRow" in program_execution_text + assert "temporal_forward_output_route_rows_tensor(" in program_execution_text + assert "class TemporalReverseArtifactConsumerRouteRow" in program_execution_text + assert "temporal_reverse_artifact_consumer_route_rows_tensor(" in program_execution_text + assert "class TemporalReverseParameterReducerRouteRow" in program_execution_text + assert "temporal_reverse_parameter_reducer_route_rows_tensor(" in program_execution_text + assert "validate_temporal_native_callable_catalog_coverage(" in program_execution_text + assert "validate_temporal_native_callable_output_contract_coverage(" in program_execution_text + assert '"native_callable_catalog_rows"' in program_execution_text + assert '"native_callable_output_rows"' in program_execution_text + assert "native_strategy_rows" in program_execution_text + assert '"native_strategy_rows"' in program_execution_text + assert "native_strategy_rows" in runtime_metadata_text + assert "_last_flat_bucket_temporal_native_strategy_rows" in runtime_metadata_text + assert "_last_flat_bucket_temporal_native_callable_catalog_rows" in runtime_metadata_text + assert "_last_flat_bucket_temporal_native_callable_output_rows" in runtime_metadata_text + assert "native_strategy_rows: torch.Tensor" in registered_executor_text + assert "native_callable_catalog_rows: torch.Tensor" in registered_executor_text + assert "native_callable_output_rows: torch.Tensor" in registered_executor_text + assert "reverse_span_output_rows: torch.Tensor" in registered_executor_text + assert "reverse_output_route_rows: torch.Tensor" in registered_executor_text + assert "forward_artifact_route_rows: torch.Tensor" in registered_executor_text + assert "forward_artifact_merge_rows: torch.Tensor" in registered_executor_text + assert "forward_output_route_rows: torch.Tensor" in registered_executor_text + assert "reverse_artifact_consumer_route_rows: torch.Tensor" in registered_executor_text + assert "reverse_parameter_reducer_route_rows: torch.Tensor" in registered_executor_text + assert "native_strategy_rows=executor_program.native_strategy_rows" in registered_executor_text + assert "reverse_span_output_rows=executor_program.reverse_span_output_rows" in registered_executor_text + assert "route_rows=executor_program.reverse_output_route_rows" in registered_executor_text + assert "_reverse_parameter_reducer_routed_span_output_tensor(" in registered_executor_text + assert "reducer_route_rows=executor_program.reverse_parameter_reducer_route_rows" in registered_executor_text + assert "reverse_surface_parameter_reducers_require_per_span_output_groups" not in registered_executor_text + assert "readout_front_groups" in registered_executor_text + assert "message_front_groups" in registered_executor_text + assert "message_boundary_groups" in registered_executor_text + assert "forward_artifact_route_rows=executor_program.forward_artifact_route_rows" in registered_executor_text + assert "forward_artifact_merge_rows=executor_program.forward_artifact_merge_rows" in registered_executor_text + 
assert "forward_output_route_rows=executor_program.forward_output_route_rows" in registered_executor_text + assert "reverse_artifact_tensor_store_output_cells_for_step(" in registered_executor_text + assert 'if role in {"output_msg", "output_cells"}:' in registered_executor_text + assert "Registered fused forward output artifacts are route-owned" in registered_executor_text + output_grad_window_body = _function_source( + registered_executor_text, + "_registered_reverse_program_output_cell_grad_window_from_tensor_store", + ) + assert '_reverse_artifact_tensor_for_step(\n tensor_store,\n role="output_cells"' not in ( + output_grad_window_body + ) + assert ( + "reverse_artifact_consumer_route_rows=executor_program.reverse_artifact_consumer_route_rows" + in registered_executor_text + ) + assert "forward_executor_binding_rows" in program_execution_text + assert "reverse_executor_binding_rows" in program_execution_text + assert "memory_liveness_plan" in program_execution_text + assert 'demotion_policy="fail_closed_no_unregistered_program_demotion"' in program_execution_text + assert 'unsupported_policy="typed_strategy_and_binding_rejection"' in program_execution_text + assert "_last_flat_bucket_temporal_fused_cuda_launch_contract" in runtime_metadata_text + assert "_last_flat_bucket_temporal_fused_cuda_launch_contract" in registered_executor_text + assert "_last_flat_bucket_temporal_memory_liveness_rows" in runtime_metadata_text + assert "_last_flat_bucket_temporal_memory_liveness_rows" in registered_executor_text + assert "flat_bucket_temporal_fused_cuda_launch_contract" in runtime_executor_text + assert "registered_temporal_fused_forward_program_validate_cuda(" in registered_program_cuda_text + assert "registered_temporal_fused_backward_program_validate_cuda(" in registered_program_cuda_text + assert "_require_native_strategy_rows(native_strategy_rows)" in registered_program_cuda_text + assert "const at::Tensor& native_strategy_rows" in registered_program_binding_text + assert "has no native strategy row for executor_id=" in registered_program_kernel_text + assert "native strategy row capability/effect masks do not match compiler handler row" in ( + registered_program_kernel_text + ) + assert "native strategy row contract does not match compiler handler row" in registered_program_kernel_text + assert "strategy_id_hash" in registered_program_kernel_text + assert "declared_reverse_span_output_group(" in registered_program_kernel_text + assert "reverse_span_output_tensor_for_role(" in registered_program_kernel_text + assert "materialize_message_key_bank_outputs = recurrent_message.size() > 5" in registered_program_kernel_text + assert "reduced_message_key_bank_outputs" in registered_program_kernel_text + assert "grad_input_key_bank_for_reducer" in registered_program_kernel_text + assert "backward[1] = backward[1].sum(0).contiguous()" in message_reverse_text + assert "backward[3] = backward[3].sum(0).contiguous()" in message_reverse_text + assert "[N,2H] or [B,N,2H]" in parameter_reducer_program_text + assert "program_access_count" in registered_program_kernel_text + assert "state_carry_rule_count" in registered_program_kernel_text + assert "native_callable_hash" in registered_program_kernel_text + assert "primitive_backward_callable_hash" in registered_program_kernel_text + assert "stable_native_callable_id" in executor_patterns_text + assert "temporal_strategy_id_hash(" in program_execution_text + assert "strategy[12] != selected_handler[7]" in registered_program_kernel_text + assert "program 
access row count does not match compiler strategy contract" in registered_program_kernel_text + assert "forward transition state-carry row count exceeds compiler strategy contract" in ( + registered_program_kernel_text + ) + assert "registered_native_strategy_row_for_span(" in registered_program_kernel_text + assert "registered_forward_strategy_callable_matches_native_row(" in registered_program_kernel_text + assert "registered_reverse_callable_matches_native_strategy(" in registered_program_kernel_text + assert "registered_transition_backward_callable_hash_for_primitive(" in registered_program_kernel_text + assert "require_registered_reverse_primitive_binding_contract(" in registered_program_kernel_text + assert "require_native_callable_binding_vector_contract(" in registered_program_kernel_text + assert "executor.native_callable_hash" in registered_program_kernel_text + assert "int64_t schema_version" in registered_program_kernel_text + assert ( + "executor.native_callable_hash,\n schema_version,\n return_state_grads" + in registered_program_kernel_text + ) + assert "native.reverse.transition_" not in registered_program_kernel_text + assert "has no native callable for compiler-emitted strategy row" in registered_program_kernel_text + assert 'registered_temporal_stable_id_hash_constexpr("native.forward.msg_' in registered_native_catalog_text + assert "kStrategyHash" not in registered_program_kernel_text + forward_message_callable_struct = registered_program_kernel_text.split( + "struct RegisteredForwardMessageCarrierStrategy", + 1, + )[1].split("};", 1)[0] + forward_readout_callable_struct = registered_program_kernel_text.split( + "struct RegisteredForwardReadoutStrategy", + 1, + )[1].split("};", 1)[0] + reverse_transition_callable_struct = registered_program_kernel_text.split( + "struct RegisteredTransitionReversePrimitiveExecutor", + 1, + )[1].split("};", 1)[0] + assert "strategy_id_hash" not in forward_message_callable_struct + assert "program_access_count" not in forward_message_callable_struct + assert "state_carry_rule_count" not in forward_message_callable_struct + assert "handler_kind" not in forward_message_callable_struct + assert "executor_id" not in forward_message_callable_struct + assert "surface_opcode" not in forward_message_callable_struct + assert "primitive_opcode" not in forward_message_callable_struct + assert "strategy_id_hash" not in forward_readout_callable_struct + assert "program_access_count" not in forward_readout_callable_struct + assert "state_carry_rule_count" not in forward_readout_callable_struct + assert "handler_kind" not in forward_readout_callable_struct + assert "executor_id" not in forward_readout_callable_struct + assert "surface_opcode" not in forward_readout_callable_struct + assert "primitive_opcode" not in forward_readout_callable_struct + assert "strategy_id_hash" not in reverse_transition_callable_struct + assert "program_access_count" not in reverse_transition_callable_struct + assert "state_carry_rule_count" not in reverse_transition_callable_struct + assert "handler_kind" not in reverse_transition_callable_struct + assert "executor_id" not in reverse_transition_callable_struct + assert "surface_opcode" not in reverse_transition_callable_struct + assert "primitive_opcode" not in reverse_transition_callable_struct + assert "primitive_backward_callable_hash" in reverse_transition_callable_struct + assert "strategy_hash=" in registered_program_kernel_text + assert "registered gated reverse transition handler expects params" not in 
registered_program_kernel_text + assert "registered diag reverse transition handler expects params" not in registered_program_kernel_text + assert "def registered_temporal_fused_forward_program_cuda(" in registered_program_cuda_text + assert "def registered_temporal_fused_forward_transition_program_cuda(" in registered_program_cuda_text + assert "def registered_temporal_fused_backward_program_cuda(" in registered_program_cuda_text + assert "def registered_temporal_fused_forward_program_transition_step_cuda(" not in registered_program_cuda_text + assert "def registered_temporal_fused_backward_program_transition_step_cuda(" not in registered_program_cuda_text + assert "def registered_temporal_fused_backward_program_transition_stage_cuda(" not in registered_program_cuda_text + assert "def registered_temporal_fused_backward_program_output_grad_cuda(" not in registered_program_cuda_text + assert "def registered_temporal_fused_backward_program_readout_message_kv_step_cuda(" not in ( + registered_program_cuda_text + ) + assert "def registered_temporal_fused_reverse_program_window_step_cuda(" not in registered_program_cuda_text + assert "def registered_temporal_fused_reverse_program_transition_boundary_step_cuda(" not in ( + registered_program_cuda_text + ) + assert "def registered_temporal_fused_backward_program_recurrent_message_boundary_initial_kv_step_cuda(" not in ( + registered_program_cuda_text + ) + assert "def registered_temporal_fused_backward_program_readout_step_cuda(" not in registered_program_cuda_text + assert ( + "def registered_temporal_fused_backward_program_output_message_step_cuda(" not in registered_program_cuda_text + ) + assert "def registered_temporal_fused_backward_program_recurrent_kv_projection_step_cuda(" not in ( + registered_program_cuda_text + ) + assert "def registered_temporal_fused_backward_program_recurrent_message_step_cuda(" not in ( + registered_program_cuda_text + ) + assert "def registered_temporal_fused_backward_program_boundary_kv_projection_step_cuda(" not in ( + registered_program_cuda_text + ) + assert "def registered_temporal_fused_backward_program_recurrent_message_initial_kv_step_cuda(" not in ( + registered_program_cuda_text + ) + assert "def registered_temporal_fused_backward_program_initial_recurrent_kv_projection_step_cuda(" not in ( + registered_program_cuda_text + ) + assert "fused_forward_program_validate" in registered_program_cuda_text + assert "fused_backward_program_validate" in registered_program_cuda_text + assert "fused_forward_program_execute" in registered_program_cuda_text + assert "fused_forward_transition_program_execute" in registered_program_cuda_text + assert "fused_backward_program_execute" in registered_program_cuda_text + assert "fused_reverse_program_full_step" not in registered_program_cuda_text + assert "fused_forward_program_transition_step" not in registered_program_cuda_text + assert "fused_forward_program_transition_step" not in registered_program_binding_text + assert "fused_backward_program_transition_step" not in registered_program_cuda_text + assert "fused_backward_program_transition_step" not in registered_program_binding_text + assert "fused_backward_program_transition_stage" not in registered_program_cuda_text + assert "fused_backward_program_output_grad_window" not in registered_program_cuda_text + assert "fused_backward_program_readout_message_kv_step" not in registered_program_cuda_text + assert "fused_reverse_program_window_step" not in registered_program_cuda_text + assert 
"fused_reverse_program_transition_boundary_step" not in registered_program_cuda_text + assert "fused_backward_program_recurrent_message_boundary_initial_kv_step" not in registered_program_cuda_text + assert "fused_backward_program_readout_step" not in registered_program_cuda_text + assert "fused_backward_program_output_message_step" not in registered_program_cuda_text + assert "fused_backward_program_recurrent_kv_projection_step" not in registered_program_cuda_text + assert "fused_backward_program_recurrent_message_step" not in registered_program_cuda_text + assert "fused_backward_program_boundary_kv_projection_step" not in registered_program_cuda_text + assert "fused_backward_program_recurrent_message_initial_kv_step" not in registered_program_cuda_text + assert "fused_backward_program_initial_recurrent_kv_projection_step" not in registered_program_cuda_text + assert "def registered_backward_sender_kv_projection_cuda(" not in registered_program_cuda_text + assert "def registered_backward_partitioned_attention_cuda(" not in registered_program_cuda_text + assert "def registered_backward_sparse_attention_cuda(" not in registered_program_cuda_text + assert "def registered_backward_readout_layout_projection_cuda(" not in registered_program_cuda_text + assert "backward_sender_kv_projection" not in registered_program_cuda_text + assert "backward_partitioned_attention" not in registered_program_cuda_text + assert "backward_sparse_attention" not in registered_program_cuda_text + assert "backward_readout_layout_projection" not in registered_program_cuda_text + assert "program_tensor_binding_rows" in registered_program_cuda_text + direct_transition_entrypoints = ( + "registered_program_transition_linear_forward_cuda", + "registered_program_transition_linear_backward_cuda", + "registered_program_transition_diag_rtu_forward_cuda", + "registered_program_transition_diag_rtu_backward_cuda", + "registered_program_transition_gated_logspace_recurrence_forward_cuda", + "registered_program_transition_gated_logspace_recurrence_backward_cuda", + "registered_program_transition_norm_or_identity_forward_cuda", + "registered_program_transition_norm_or_identity_backward_cuda", + "registered_program_transition_tanh_forward_cuda", + "registered_program_transition_tanh_backward_cuda", + "registered_program_transition_recurrent_matmul_forward_cuda", + "registered_program_transition_recurrent_matmul_backward_cuda", + ) + for entrypoint in direct_transition_entrypoints: + assert f"def {entrypoint}(" not in registered_program_cuda_text + assert f"&flat_bucket_{entrypoint}" not in registered_program_binding_text + direct_transition_pybind_names = ( + "program_transition_linear_forward", + "program_transition_linear_backward", + "program_transition_diag_rtu_forward", + "program_transition_diag_rtu_backward", + "program_transition_gated_logspace_recurrence_forward", + "program_transition_gated_logspace_recurrence_backward", + "program_transition_norm_or_identity_forward", + "program_transition_norm_or_identity_backward", + "program_transition_tanh_forward", + "program_transition_tanh_backward", + "program_transition_recurrent_matmul_forward", + "program_transition_recurrent_matmul_backward", + ) + for pybind_name in direct_transition_pybind_names: + assert f'"{pybind_name}"' not in registered_program_binding_text + assert "fused_forward_program_validate" in registered_program_binding_text + assert "fused_backward_program_validate" in registered_program_binding_text + assert "fused_backward_program_execute" in 
registered_program_binding_text + assert "fused_reverse_program_full_step" not in registered_program_binding_text + assert "fused_backward_program_transition_stage" not in registered_program_binding_text + assert "fused_backward_program_output_grad_window" not in registered_program_binding_text + assert "fused_backward_program_readout_message_kv_step" not in registered_program_binding_text + assert "fused_reverse_program_window_step" not in registered_program_binding_text + assert "fused_reverse_program_transition_boundary_step" not in registered_program_binding_text + assert "fused_backward_program_recurrent_message_boundary_initial_kv_step" not in registered_program_binding_text + assert "fused_backward_program_readout_step" not in registered_program_binding_text + assert "fused_backward_program_output_message_step" not in registered_program_binding_text + assert "fused_backward_program_recurrent_kv_projection_step" not in registered_program_binding_text + assert "fused_backward_program_recurrent_message_initial_kv_step" not in registered_program_binding_text + assert "fused_backward_program_recurrent_message_step" not in registered_program_binding_text + assert "fused_backward_program_initial_recurrent_kv_projection_step" not in ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_binding.cpp" + ).read_text(encoding="utf-8") + assert "fused_backward_program_boundary_kv_projection_step" not in ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_binding.cpp" + ).read_text(encoding="utf-8") + assert "backward_sender_kv_projection" not in ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_binding.cpp" + ).read_text(encoding="utf-8") + assert "backward_partitioned_attention" not in ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_binding.cpp" + ).read_text(encoding="utf-8") + assert "backward_sparse_attention" not in ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_binding.cpp" + ).read_text(encoding="utf-8") + assert "backward_readout_layout_projection" not in ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_binding.cpp" + ).read_text(encoding="utf-8") + assert "fused_forward_program_execute" in ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_binding.cpp" + ).read_text(encoding="utf-8") + assert "fused_forward_transition_program_execute" in ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_binding.cpp" + ).read_text(encoding="utf-8") + assert "fused_forward_program_transition_step" not in ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_binding.cpp" + ).read_text(encoding="utf-8") + assert "registered_forward_readout_layout_epilogue_cuda(" not in registered_program_cuda_text + assert "forward_readout_layout_epilogue" not in ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_binding.cpp" + ).read_text(encoding="utf-8") + assert "forward_executor_rows" in registered_program_kernel_text + assert "forward_executor_binding_rows" in registered_program_kernel_text + assert "primitive_rows" 
in registered_program_kernel_text + assert "reverse_executor_rows" in registered_program_kernel_text + assert "reverse_executor_binding_rows" in registered_program_kernel_text + assert "forward_handler_rows" in registered_program_kernel_text + assert "reverse_handler_rows" in registered_program_kernel_text + assert "memory_liveness_rows" in registered_program_kernel_text + assert "validate_registered_temporal_fused_program(" in registered_program_kernel_text + assert "validate_registered_fused_program_binding_rows(" in registered_program_kernel_text + assert "validate_registered_fused_program_memory_rows(" in registered_program_kernel_text + assert "RegisteredFusedProgramMemoryFacts" in registered_program_kernel_text + assert "registered_fused_program_memory_facts_for_span(" in registered_program_kernel_text + assert "require_registered_surface_memory_contract(" in registered_program_kernel_text + assert "validate_registered_fused_forward_span_memory(" in registered_program_kernel_text + assert "validate_registered_fused_reverse_span_memory(" in registered_program_kernel_text + assert "decode_registered_fused_program_spans(" in registered_program_kernel_text + assert "flat_bucket_registered_temporal_fused_forward_transition_program_cuda" in registered_program_kernel_text + assert "flat_bucket_registered_temporal_fused_forward_program_transition_step_cuda" not in ( + registered_program_kernel_text + ) + assert "program_tensor_binding_rows" in registered_program_kernel_text + assert "program_tensor_for_binding(" in registered_program_kernel_text + assert "set_program_tensor_for_binding(" in registered_program_kernel_text + assert "forward_program_access_rows" in registered_program_cuda_text + assert "reverse_program_access_rows" in registered_program_cuda_text + assert "forward_transition_state_carry_rows" in registered_program_cuda_text + assert "forward_reset_rows" in registered_program_cuda_text + assert "forward_program_access_rows" in registered_program_binding_text + assert "reverse_program_access_rows" in registered_program_binding_text + assert "forward_transition_state_carry_rows" in registered_program_binding_text + assert "forward_reset_rows" in registered_program_binding_text + assert "check_forward_program_access_rows(" in registered_program_kernel_text + assert "check_reverse_program_access_rows(" in registered_program_kernel_text + assert "check_forward_reset_rows(" in registered_program_kernel_text + assert "program_binding_for_native_strategy_access(" in registered_program_kernel_text + assert "program_tensor_for_native_strategy_access(" in registered_program_kernel_text + assert "forward_program_tensor_for_access_opcode(" not in registered_program_kernel_text + assert "reverse_program_tensor_for_access_opcode(" not in registered_program_kernel_text + assert "temporal_program_access_binding_by_opcode(" in registered_program_kernel_text + assert "kProgramAccessTransitionAggregatedMessageInput" in registered_program_kernel_text + assert "kProgramAccessTransitionPublicStateOutput" in registered_program_kernel_text + assert "zero_forward_transition_state_inputs_for_reset(" in registered_program_kernel_text + assert "kProgramAccessMessage" not in registered_program_kernel_text + assert "kProgramAccessReadout" not in registered_program_kernel_text + assert "kExecutorAccessSlot" not in registered_program_kernel_text + assert "apply_forward_transition_state_carry_rows(" in registered_program_kernel_text + assert "surface_opcode_for_executor_bucket(" in registered_program_kernel_text 
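+ # Guard (summarizes the assertions below): forward and reverse executor handlers are resolved per span from the registered handler/strategy catalogs; static handler arrays, handler-kind enums, and name-based handler dispatch must stay absent from the kernel source.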
+ assert "RegisteredForwardExecutorHandler" in registered_program_kernel_text + assert "registered_forward_executor_handler_for_span(" in registered_program_kernel_text + assert "RegisteredReverseExecutorHandler" in registered_program_kernel_text + assert "registered_reverse_executor_handler_for_span(" in registered_program_kernel_text + assert "RegisteredForwardExecutorHandlerSpec" not in registered_program_kernel_text + assert "registered_forward_handler_spec_for_span(" not in registered_program_kernel_text + assert "RegisteredReverseExecutorHandlerSpec" not in registered_program_kernel_text + assert "registered_reverse_handler_spec_for_span(" not in registered_program_kernel_text + assert "RegisteredForwardExecutorHandlerKind" not in registered_program_kernel_text + assert "RegisteredReverseExecutorHandlerKind" not in registered_program_kernel_text + assert "registered_forward_required_capability_for_surface(" in registered_program_kernel_text + assert "registered_reverse_required_capability_for_surface(" in registered_program_kernel_text + assert "registered.forward.compiler_handler" in registered_program_kernel_text + assert "registered.reverse.compiler_handler" in registered_program_kernel_text + assert "static constexpr RegisteredForwardExecutorHandler handlers[]" not in registered_program_kernel_text + assert "static constexpr RegisteredReverseExecutorHandler handlers[]" not in registered_program_kernel_text + assert "registered_forward_handler_name(" not in registered_program_kernel_text + assert "registered_reverse_handler_name(" not in registered_program_kernel_text + assert "span.handler_kind" in registered_program_kernel_text + assert "span.handler_flags" in registered_program_kernel_text + assert "required_effect_mask" in registered_program_kernel_text + assert "require_registered_handler_effect_contract(" in registered_program_kernel_text + assert "handler_effect:state_read" in registered_program_kernel_text + assert "validate_registered_fused_forward_program_dispatch(" in registered_program_kernel_text + assert "validate_registered_fused_forward_program_shape(" not in registered_program_kernel_text + assert "registered fused forward program emitted unexpected output count" in registered_program_kernel_text + assert "flat_bucket_registered_temporal_fused_forward_transition_program_cuda(" in registered_program_kernel_text + assert "RegisteredTransitionForwardPrimitiveExecutor" in registered_program_kernel_text + assert "registered_native_transition_forward_primitive_catalog_begin(" in registered_program_kernel_text + assert '#include "../flat_bucket_registered_native_callables.cuh"' in registered_program_kernel_text + assert "kRegisteredNativeTransitionForwardPrimitiveCatalog" in registered_native_catalog_text + assert "kRegisteredNativeForwardMessageCatalog" in registered_native_catalog_text + assert "kRegisteredNativeForwardReadoutCatalog" in registered_native_catalog_text + assert "kRegisteredNativeTransitionReversePrimitiveCatalog" in registered_native_catalog_text + assert "kRegisteredNativeParameterReducerCatalog" in registered_native_catalog_text + assert "kRegisteredNativeTransitionTrainableReducerCatalog" in registered_native_catalog_text + assert "registered_transition_forward_primitive_executor_for_callable(" in registered_program_kernel_text + assert "registered_transition_forward_callable_hash_for_primitive(" in registered_program_kernel_text + assert "registered_transition_backward_callable_hash_for_primitive(" in registered_program_kernel_text + assert 
"run_registered_transition_forward_primitive_executor(" in registered_program_kernel_text + forward_transition_program_body = registered_program_kernel_text[ + registered_program_kernel_text.index( + "std::vector flat_bucket_registered_temporal_fused_forward_transition_program_cuda(" + ) : registered_program_kernel_text.index("namespace {\n\nstruct RegisteredForwardStrategyTensor") + ] + assert "run_registered_transition_forward_primitive_executor(" in forward_transition_program_body + assert "transition_primitive_callable_rows" in forward_transition_program_body + assert "if (opcode ==" not in forward_transition_program_body + assert "else if (opcode ==" not in forward_transition_program_body + forward_message_state = registered_program_kernel_text[ + registered_program_kernel_text.index( + "struct RegisteredForwardMessageExecutorState" + ) : registered_program_kernel_text.index("struct RegisteredForwardReadoutExecutorState") + ] + forward_readout_state = registered_program_kernel_text[ + registered_program_kernel_text.index( + "struct RegisteredForwardReadoutExecutorState" + ) : registered_program_kernel_text.index("using ForwardMessageBindFn") + ] + assert "RegisteredForwardStrategyTensorList tensors;" in forward_message_state + assert "RegisteredForwardStrategyTensorList cached_tensors;" in forward_message_state + assert "RegisteredForwardStrategyTensorList tensors;" in forward_readout_state + assert "at::Tensor recurrent_q;" not in forward_message_state + assert "at::Tensor message_sender_slot_key;" not in forward_message_state + assert "at::Tensor output_q;" not in forward_readout_state + assert "at::Tensor value_to_output_weight;" not in forward_readout_state + assert "registered_transition_forward_primitive_executor_for_opcode(" not in registered_program_kernel_text + assert "static const RegisteredTransitionForwardPrimitiveExecutor kExecutors[]" not in ( + registered_program_kernel_text + ) + assert "static const RegisteredForwardMessageCarrierStrategy kStrategies[]" not in registered_program_kernel_text + assert "static const RegisteredForwardReadoutStrategy kStrategies[]" not in registered_program_kernel_text + assert "static const RegisteredTransitionReversePrimitiveExecutor kExecutors[]" not in ( + registered_program_kernel_text + ) + assert "static const RegisteredParameterReducerHandler kHandlers[]" not in registered_program_kernel_text + assert "static const RegisteredTransitionTrainableReducerHandler kHandlers[]" not in ( + registered_program_kernel_text + ) + assert "registered_native_forward_message_catalog_begin(" in registered_program_kernel_text + assert "registered_native_forward_readout_catalog_begin(" in registered_program_kernel_text + assert "registered_native_transition_reverse_primitive_catalog_begin(" in registered_program_kernel_text + assert "registered_native_parameter_reducer_catalog_begin(" in registered_program_kernel_text + assert "registered_native_transition_trainable_reducer_catalog_begin(" in registered_program_kernel_text + assert "fused transition program has unsupported primitive opcode" not in registered_program_kernel_text + reverse_transition_program_body = registered_program_kernel_text[ + registered_program_kernel_text.index( + "std::vector flat_bucket_registered_temporal_fused_reverse_transition_program_cuda(" + ) : registered_program_kernel_text.index( + "inline void require_registered_transition_forward_span_for_reverse_span(" + ) + ] + assert "transition_primitive_callable_rows" in reverse_transition_program_body + assert 
"run_registered_transition_reverse_handler(" in reverse_transition_program_body + forward_program_body = registered_program_kernel_text[ + registered_program_kernel_text.index( + "std::vector flat_bucket_registered_temporal_fused_forward_program_cuda(" + ) : registered_program_kernel_text.index( + "std::vector flat_bucket_registered_temporal_fused_backward_program_validate_cuda(" + ) + ] + assert "registered_forward_handler_span_indices_by_capability(" in registered_program_kernel_text + assert "require_unique_registered_forward_handler_span_by_capability(" not in registered_program_kernel_text + assert "bind_registered_forward_message_executor_handlers(" in forward_program_body + assert "bind_registered_forward_readout_executor_handlers(" in forward_program_body + assert "validate_forward_artifact_route_rows(" in forward_program_body + assert "validate_forward_artifact_merge_rows(" in forward_program_body + assert "validate_forward_output_route_rows(" in registered_program_kernel_text + assert "forward_output_route_row_for_readout_executor(" in registered_program_kernel_text + assert "route[9] == output_count" in forward_program_body + assert "unique_forward_output_readout_route(" not in registered_program_kernel_text + assert "single_executable_forward_output_readout_route(" not in registered_program_kernel_text + assert "output_route_readouts" in forward_program_body + assert "forward_artifact_merged_route_row_for_surface_role(" not in forward_program_body + assert "forward_artifact_merge_row_for_surface_bucket_role(" in forward_program_body + assert "forward_artifact_merged_route_row_for_surface_bucket_role(" not in registered_program_kernel_text + assert "forward_artifact_route_row_for(" in forward_program_body + assert "if (route[7] == 0)" in forward_program_body + assert "reverse_artifact_binding_values.push_back(route_row)" in forward_program_body + assert "kReverseArtifactBindingRowColumns" in forward_program_body + routed_reverse_artifact_body = registered_program_kernel_text[ + registered_program_kernel_text.index( + "inline at::Tensor reverse_artifact_tensor_for_routed_access_step(" + ) : registered_program_kernel_text.index("inline void validate_temporal_reverse_reset_rows(") + ] + assert "reverse_artifact_consumer_forward_route_row_for(" in routed_reverse_artifact_body + assert "try_reverse_artifact_tensor_for_routed_access_step(" in registered_program_kernel_text + assert "recurrent_kv_forward_recompute(" in registered_program_kernel_text + assert "fused recurrent-message recurrent hidden-before for K/V recompute" in registered_program_kernel_text + assert "registered_forward_transition_output_binding_is_materialized(" not in registered_program_kernel_text + assert "forward_artifact_route_row_for(" not in routed_reverse_artifact_body + assert "forward_artifact_merged_route_row_for_surface_bucket_role(" not in routed_reverse_artifact_body + assert "(void)executor_row_index;" not in routed_reverse_artifact_body + assert "(void)executor_id;" not in routed_reverse_artifact_body + assert "requires compiler artifact merge rows for multiple message spans" not in forward_program_body + assert "requires compiler artifact merge rows for multiple readout spans" not in forward_program_body + assert "message_executors.size() == 1" not in forward_program_body + assert "readout_executors.size() == 1" not in forward_program_body + assert "requires one compiler-owned temporal message-carrier handler" not in registered_program_kernel_text + assert "requires one compiler-owned temporal readout 
handler" not in registered_program_kernel_text + assert "requires one recurrent-message output row per transition group" not in registered_program_kernel_text + assert "recurrent_msg_output_row_by_group" in registered_program_kernel_text + assert "requires at least one compiler-owned temporal message-carrier handler" in registered_program_kernel_text + assert "requires at least one compiler-owned temporal readout handler" in registered_program_kernel_text + assert "run_registered_forward_message_carrier_handler(" in forward_program_body + assert "run_registered_forward_readout_message_handler(" in forward_program_body + assert "registered_forward_message_carrier_strategy_for_span(" in registered_program_kernel_text + assert "registered_forward_readout_strategy_for_span(" in registered_program_kernel_text + assert "registered_forward_strategy_contract_matches_span(" not in registered_program_kernel_text + assert "executor.primitive_opcode == handler.primitive_opcode)" not in registered_program_kernel_text + assert "forward.message.neighborhood_attention_project.v1" not in registered_program_kernel_text + assert "forward.readout.projection_reduction_boundary.v1" not in registered_program_kernel_text + assert "message_span_index" not in forward_program_body + assert "readout_span_index" not in forward_program_body + assert "found multiple message spans" not in forward_program_body + assert "found multiple readout spans" not in forward_program_body + assert "int64_t message_executor_id" not in registered_program_kernel_text + assert "int64_t readout_executor_id" not in registered_program_kernel_text + assert "int64_t message_executor_row" not in registered_program_kernel_text + assert "int64_t readout_executor_row" not in registered_program_kernel_text + assert "requires a message span" not in registered_program_kernel_text + assert "requires a readout span" not in registered_program_kernel_text + assert "registered_reverse_handler_span_indices_by_capability(" in registered_program_kernel_text + assert "require_unique_registered_reverse_handler_span_by_capability(" not in registered_program_kernel_text + assert "registered_reverse_span_strategies_by_capability(" in registered_program_kernel_text + assert "combine_registered_reverse_span_outputs(" in registered_program_kernel_text + assert "span_output_groups->push_back(span_outputs)" in registered_program_kernel_text + assert "append_readout_front_span_groups" in registered_program_kernel_text + assert "append_message_boundary_span_groups" in registered_program_kernel_text + assert "RegisteredTransitionReversePrimitiveExecutor" in registered_program_kernel_text + assert "registered_transition_reverse_primitive_executor_for_handler(" in registered_program_kernel_text + reverse_transition_body = registered_program_kernel_text[ + registered_program_kernel_text.index( + "std::vector flat_bucket_registered_temporal_fused_reverse_transition_program_cuda(" + ) : registered_program_kernel_text.index( + "inline void require_registered_transition_forward_span_for_reverse_span(" + ) + ] + assert "run_registered_transition_reverse_handler(" in reverse_transition_body + assert "registered_transition_reverse_primitive_row_index_for_handler(" in reverse_transition_body + assert "opcode ==" not in reverse_transition_body + assert ( + "fused reverse transition executor must own one recurrence primitive row" not in registered_program_kernel_text + ) + assert "fused reverse transition program received unsupported recurrence primitive opcode" not in ( + 
registered_program_kernel_text + ) + assert "transition-step program expects one forward transition executor row" not in registered_program_kernel_text + assert "transition-step program expects one reverse transition executor row" not in registered_program_kernel_text + assert "bind_transition_dynamic_tensors_for_handlers(" in registered_program_kernel_text + assert "_transition_dynamic_binding_rows_tensor(" in registered_executor_text + assert "transition_dynamic_binding_row_groups" in registered_executor_text + assert "transition_dynamic_binding_row_groups" in registered_program_cuda_text + assert "transition_dynamic_binding_row_groups" in registered_program_binding_text + assert "kTransitionDynamicSourceStateBeforeArtifact" in registered_program_kernel_text + assert "check_transition_dynamic_binding_rows(" in registered_program_kernel_text + dynamic_binder_body = registered_program_kernel_text[ + registered_program_kernel_text.index( + "inline void bind_transition_dynamic_tensors_for_reverse_handler(" + ) : registered_program_kernel_text.index("inline void bind_transition_dynamic_tensors_for_handlers(") + ] + assert "transition_dynamic_binding_rows" in dynamic_binder_body + assert "switch (handler.kind)" not in dynamic_binder_body + assert "RegisteredReverseExecutorHandlerKind::kTransitionGatedLogspaceBackward" not in dynamic_binder_body + assert "RegisteredReverseExecutorHandlerKind::kTransitionDiagRtuBackward" not in dynamic_binder_body + assert "registered transition dynamic binder is not implemented for" not in dynamic_binder_body + assert "message_primitive_start + 1" not in forward_program_body + assert "readout_primitive_start" not in forward_program_body + assert "kPrimitiveGatedLogspaceRecurrenceOpcode" not in forward_program_body + assert "kPrimitiveDiagRtuOpcode" not in forward_program_body + assert "flat_bucket_registered_forward_partitioned_attention_cuda(" in registered_program_kernel_text + assert "flat_bucket_registered_forward_sender_kv_sequence_cuda(" in registered_program_kernel_text + assert "program_transition_linear_forward_kernel" in registered_program_kernel_text + assert "program_transition_linear_input_backward_kernel" in registered_program_kernel_text + assert "program_transition_diag_rtu_forward_kernel" in registered_program_kernel_text + assert "program_transition_diag_rtu_input_backward_kernel" in registered_program_kernel_text + assert "program_transition_gated_logspace_recurrence_forward_kernel" in registered_program_kernel_text + assert "program_transition_gated_logspace_recurrence_backward_kernel" in registered_program_kernel_text + assert "program_transition_norm_or_identity_forward_kernel" in registered_program_kernel_text + assert "program_transition_norm_or_identity_input_backward_kernel" in registered_program_kernel_text + assert "program_transition_tanh_forward_kernel" in registered_program_kernel_text + assert "program_transition_tanh_backward_kernel" in registered_program_kernel_text + assert "program_transition_recurrent_matmul_forward_kernel" in registered_program_kernel_text + assert "program_transition_recurrent_matmul_input_backward_kernel" in registered_program_kernel_text + assert "flat_bucket_registered_program_transition_linear_forward_cuda" in registered_program_kernel_text + assert "flat_bucket_registered_program_transition_linear_backward_cuda" in registered_program_kernel_text + assert "flat_bucket_registered_program_transition_diag_rtu_forward_cuda" in registered_program_kernel_text + assert 
"flat_bucket_registered_program_transition_diag_rtu_backward_cuda" in registered_program_kernel_text + assert "flat_bucket_registered_program_transition_gated_logspace_recurrence_forward_cuda" in ( + registered_program_kernel_text + ) + assert "flat_bucket_registered_program_transition_gated_logspace_recurrence_backward_cuda" in ( + registered_program_kernel_text + ) + assert "flat_bucket_registered_program_transition_norm_or_identity_forward_cuda" in registered_program_kernel_text + assert "flat_bucket_registered_program_transition_norm_or_identity_backward_cuda" in registered_program_kernel_text + assert "flat_bucket_registered_program_transition_tanh_forward_cuda" in registered_program_kernel_text + assert "flat_bucket_registered_program_transition_tanh_backward_cuda" in registered_program_kernel_text + assert "flat_bucket_registered_program_transition_recurrent_matmul_forward_cuda" in registered_program_kernel_text + assert "flat_bucket_registered_program_transition_recurrent_matmul_backward_cuda" in registered_program_kernel_text + assert "run_registered_matmul_reverse_transition_handler" in registered_program_kernel_text + assert "RegisteredTransitionTrainableReducerRunFn" in registered_program_kernel_text + assert "registered_transition_trainable_reducer_handler_for_native_callable(" in registered_program_kernel_text + assert "RegisteredTransitionTrainableReducerHandlerKind" not in registered_program_kernel_text + transition_trainable_reducer_body = registered_program_kernel_text[ + registered_program_kernel_text.index( + "at::Tensor run_registered_transition_trainable_reducer_handler(" + ) : registered_program_kernel_text.index("void validate_registered_transition_trainable_reducer_rows(") + ] + assert "handler.run(" in transition_trainable_reducer_body + assert "switch (handler.kind)" not in transition_trainable_reducer_body + assert "validate_registered_readout_executor_rows(" in registered_program_kernel_text + assert "executor_id=int(readout_executor.row.executor_id)" not in ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py" + ).read_text(encoding="utf-8") + for blocked_term in ( + "tensor_slot", + "InitialGated", + "InitialDiagonal", + "kTf", + "kReadoutExecutorId", + "kReverseReadoutExecutorId", + "kReadoutBucketOrdinal", + "compatibility_launch_plan", + "try_flat_bucket_temporal_scan_cuda", + ): + assert blocked_term not in program_execution_text + assert blocked_term not in registered_program_kernel_text + + +def test_message_readout_native_callable_bodies_are_strategy_local() -> None: + repo_root = Path(__file__).resolve().parents[1] + program_root = repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program" + forward_program = (program_root / "forward_program.cuh").read_text(encoding="utf-8") + backward_steps = (program_root / "backward_surface_steps.cuh").read_text(encoding="utf-8") + message_forward = (program_root / "native_callables/message_forward_strategies.cuh").read_text(encoding="utf-8") + readout_forward = (program_root / "native_callables/readout_forward_strategies.cuh").read_text(encoding="utf-8") + message_reverse = (program_root / "native_callables/message_reverse_strategies.cuh").read_text(encoding="utf-8") + readout_reverse = (program_root / "native_callables/readout_reverse_strategies.cuh").read_text(encoding="utf-8") + + for broad_source in (forward_program, backward_steps): + assert "bind_fixed_slot_context_message_handler(" not in broad_source + assert 
"run_fixed_slot_context_message(" not in broad_source + assert "bind_projection_reduction_boundary_readout_handler(" not in broad_source + assert "run_projection_reduction_boundary_readout_backward(" not in broad_source + + assert "bind_fixed_slot_context_message_handler(" in message_forward + assert "run_fixed_slot_context_direct_keyless_readout_message(" in message_forward + assert "run_neighborhood_attention_project_message(" in message_forward + assert "bind_projection_reduction_boundary_readout_handler(" in readout_forward + assert "run_projection_reduction_boundary_readout_projection(" in readout_forward + assert "run_fixed_slot_context_recurrent_message_backward(" in message_reverse + assert "run_neighborhood_attention_project_recurrent_message_backward(" in message_reverse + assert "run_projection_reduction_boundary_readout_backward(" in readout_reverse + assert "run_projection_reduction_boundary_output_message_backward(" in readout_reverse + + +def test_forward_scan_fails_closed_on_unsupported_primitive_programs() -> None: + repo_root = Path(__file__).resolve().parents[1] + source_text = (repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py").read_text( + encoding="utf-8" + ) + registered_executor_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py" + ).read_text(encoding="utf-8") + assert "registered_executor_bindings_required" in source_text + assert "run_registered_temporal_forward_executor_scan(" in source_text + assert "run_registered_temporal_forward_executor_scan(" in registered_executor_text + assert "_try_run_registered_temporal_fused_forward_program_scan(" in registered_executor_text + assert "registered_temporal_fused_forward_program_cuda(" in registered_executor_text + assert "return_final_program_tensors=bool(materialize_final_state)" in registered_executor_text + assert "return_reverse_artifacts=bool(collect_artifacts)" in registered_executor_text + assert "source=registered_fused_forward_program_cuda" in registered_executor_text + assert "reverse_artifact_tensor_store=reverse_artifact_tensor_store" in registered_executor_text + assert "compute_registered_temporal_bucket_step_artifacts(" not in registered_executor_text + assert "compiled_fused_forward_program_only=1" in registered_executor_text + assert "Registered temporal forward execution must run through the compiler-owned fused CUDA program." 
+    assert "Registered temporal forward execution must run through the compiler-owned fused CUDA program." in (
+        registered_executor_text
+    )
+    forward_executor_body = registered_executor_text[
+        registered_executor_text.index(
+            "def run_registered_temporal_forward_executor_scan("
+        ) : registered_executor_text.index("\ndef _reverse_artifact_tensor_store_window_table(")
+    ]
+    assert "compute_registered_temporal_bucket_step_artifacts(" not in forward_executor_body
+    assert "for scan_step in scan_schedule.iter_steps()" not in forward_executor_body
+    assert "registered_forward_executor_kernel_not_implemented" not in source_text
+    assert "registered_forward_executor_recompute_kernel_not_implemented" not in source_text
+    assert "try_flat_bucket_temporal_scan_cuda(" not in source_text
+
+
+def test_shared_temporal_forward_scan_has_no_python_step_loop_fallback() -> None:
+    repo_root = Path(__file__).resolve().parents[1]
+    source_text = (repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py").read_text(
+        encoding="utf-8"
+    )
+    body = _function_source(source_text, "run_shared_temporal_bucket_forward_scan")
+
+    assert "Add the missing fabric.cuda.nn primitive executor" in body
+    assert "for scan_step in scan_schedule.iter_steps()" not in body
+    assert "compute_temporal_bucket_step_artifacts(" not in body
+
+
+def test_temporal_backward_top_level_module_was_deleted() -> None:
+    repo_root = Path(__file__).resolve().parents[1]
+    source_root = repo_root / "src/cortical/fabric/backend/cuda/sequence_surface"
+
+    assert not (source_root / "temporal_backward.py").exists()
+    assert not (source_root / "temporal_backward_legacy_reference.py").exists()
+    assert (source_root / "temporal/forward_scan.py").exists()
+    assert (source_root / "temporal/reverse_executor.py").exists()
+    assert (source_root / "temporal/physical_autograd.py").exists()
+    assert not (source_root / "temporal/boundary_backward.py").exists()
+
+
+def test_temporal_backward_legacy_reference_is_not_active_runtime_path() -> None:
+    repo_root = Path(__file__).resolve().parents[1]
+    source_root = repo_root / "src/cortical/fabric/backend/cuda/sequence_surface"
+    active_sources = [path for path in source_root.rglob("*.py")]
+
+    for path in active_sources:
+        source_text = path.read_text(encoding="utf-8")
+        assert "temporal_backward_legacy_reference" not in source_text
+        assert "temporal_backward_engine" not in source_text
+
+
+def test_temporal_modules_do_not_use_common_wildcard_imports() -> None:
+    repo_root = Path(__file__).resolve().parents[1]
+    temporal_root = repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal"
+    violations: list[str] = []
+    for path in temporal_root.glob("*.py"):
+        source_text = path.read_text(encoding="utf-8")
+        if "from cortical.fabric.backend.cuda.sequence_surface.temporal.common import *" in source_text:
+            violations.append(path.name)
+        if "F403,F405" in source_text:
+            violations.append(path.name)
+
+    assert not violations, "temporal modules must import shared helpers explicitly: " + ", ".join(sorted(violations))
+
+
+def test_sequence_surface_imports_transition_execution_helpers_explicitly() -> None:
+    repo_root = Path(__file__).resolve().parents[1]
+    source_root = repo_root / "src/cortical/fabric/backend/cuda/sequence_surface"
+    allowed_transition_submodules = (
+        "cortical.fabric.backend.cuda.transition_execution.projection",
+        "cortical.fabric.backend.cuda.transition_execution.registry",
+        "cortical.fabric.backend.cuda.transition_execution.types",
+    )
+    violations: list[str] = []
+    for path in source_root.rglob("*.py"):
+        source_text = path.read_text(encoding="utf-8")
if "from cortical.fabric.backend.cuda import transition_execution" in source_text: + violations.append(path.relative_to(source_root).as_posix()) + if "import cortical.fabric.backend.cuda.transition_execution" in source_text: + violations.append(path.relative_to(source_root).as_posix()) + if "from cortical.fabric.backend.cuda.transition_execution import" in source_text: + violations.append(path.relative_to(source_root).as_posix()) + if "transition_execution." in source_text and not any( + submodule in source_text for submodule in allowed_transition_submodules + ): + violations.append(path.relative_to(source_root).as_posix()) + + assert not violations, "sequence-surface modules must depend on explicit transition helpers: " + ", ".join( + sorted(set(violations)) + ) + + +def test_transition_execution_monolith_was_deleted() -> None: + repo_root = Path(__file__).resolve().parents[1] + cuda_root = repo_root / "src/cortical/fabric/backend/cuda" + transition_root = cuda_root / "transition_execution" + + assert not (cuda_root / "transition_execution.py").exists() + assert (transition_root / "types.py").exists() + assert (transition_root / "projection.py").exists() + assert (transition_root / "program.py").exists() + assert not (transition_root / "lowering.py").exists() + assert not (transition_root / "temporal_fusion.py").exists() + + +def test_active_code_does_not_import_transition_execution_package_root() -> None: + repo_root = Path(__file__).resolve().parents[1] + source_roots = ( + repo_root / "src/cortical/fabric/backend", + repo_root / "tests", + ) + violations: list[str] = [] + forbidden_fragments = ( + "from cortical.fabric.backend.cuda import transition_execution", + "from cortical.fabric.backend.cuda.transition_execution import", + "import cortical.fabric.backend.cuda.transition_execution as", + 'importlib.import_module("cortical.fabric.backend.cuda.transition_execution")', + ) + for source_root in source_roots: + for path in source_root.rglob("*.py"): + if path == Path(__file__).resolve(): + continue + source_text = path.read_text(encoding="utf-8") + if any(fragment in source_text for fragment in forbidden_fragments): + violations.append(path.relative_to(repo_root).as_posix()) + + assert not violations, "active code must import transition_execution semantic submodules: " + ", ".join( + sorted(set(violations)) + ) + + +def test_rejected_transition_temporal_fusion_facades_were_deleted() -> None: + repo_root = Path(__file__).resolve().parents[1] + temporal_fusion_path = repo_root / "src/cortical/fabric/backend/cuda/transition_execution/temporal_fusion.py" + flat_bucket_primitive_executors_path = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/primitive_executors" + ) + transition_program_text = ( + repo_root / "src/cortical/fabric/backend/cuda/transition_execution/program.py" + ).read_text(encoding="utf-8") + transition_registry_text = ( + repo_root / "src/cortical/fabric/backend/cuda/transition_execution/registry.py" + ).read_text(encoding="utf-8") + executor_patterns_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py" + ).read_text(encoding="utf-8") + transition_registry_text = ( + repo_root / "src/cortical/fabric/backend/cuda/transition_execution/registry.py" + ).read_text(encoding="utf-8") + primitive_dispatch_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/primitive_dispatch.py" + ).read_text(encoding="utf-8") + executor_binding_text = ( + repo_root / 
"src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_bindings.py" + ).read_text(encoding="utf-8") + temporal_backward_cuda_path = ( + repo_root / "src/cortical/fabric/backend/cuda/ops/temporal_backward/temporal_backward_cuda.py" + ) + temporal_backward_root = repo_root / "src/cortical/fabric/backend/cuda/ops/temporal_backward" + fusion_gateway_path = repo_root / "src/cortical/fabric/backend/cuda/ops/temporal_backward/fusion_gateway.py" + + assert not temporal_fusion_path.exists() + assert not fusion_gateway_path.exists() + assert not temporal_backward_cuda_path.exists() + assert not (temporal_backward_root / "reverse_table.py").exists() + assert not (temporal_backward_root / "materialization.py").exists() + assert not (temporal_backward_root / "reductions.py").exists() + assert not (temporal_backward_root / "extension.py").exists() + assert not (temporal_backward_root / "flat_bucket_temporal_backward_binding.cpp").exists() + assert not (temporal_backward_root / "flat_bucket_temporal_backward_kernels.cu").exists() + assert not (temporal_backward_root / "flat_bucket_temporal_recurrent_backward_kernels.cu").exists() + assert not flat_bucket_primitive_executors_path.exists() + assert "TransitionProgramStateSlice(" not in transition_program_text + assert "TransitionProgramTensorEdge(" not in transition_program_text + assert "TransitionProgramOpArity(" not in transition_program_text + assert "TransitionProgramMessageInput(" not in transition_program_text + assert "registered_transition_executor_records()" in transition_program_text + assert "class TransitionProgramExecutorRecord" in transition_registry_text + assert "class TransitionPrimitiveExecutorRecord" in transition_registry_text + assert "forward_strategy_id: str" in transition_registry_text + assert "backward_strategy_id: str" in transition_registry_text + assert "def registered_transition_executor_records(" in transition_registry_text + assert "def registered_transition_primitive_executor_records(" in transition_registry_text + assert "program_layer_status: TransitionPrimitiveProgramLayerStatus" in transition_registry_text + assert "program_forward_status: TransitionPrimitiveProgramLayerStatus" in transition_registry_text + assert "program_backward_status: TransitionPrimitiveProgramLayerStatus" in transition_registry_text + assert "program_forward_symbol: str" in transition_registry_text + assert "program_backward_symbol: str" in transition_registry_text + assert "program_forward_cxx_entrypoint: str" in transition_registry_text + assert "program_forward_input_bindings: tuple[str, ...]" in transition_registry_text + assert "program_forward_parameter_bindings: tuple[tuple[str, bool], ...]" in transition_registry_text + assert "program_forward_output_bindings: tuple[tuple[str, bool], ...]" in transition_registry_text + assert "program_forward_output_contracts: tuple[tuple[str, str, str, str], ...]" in transition_registry_text + assert "program_reverse_native_callable: str" in transition_registry_text + assert "tape_saved_input_bindings: tuple[str, ...]" in transition_registry_text + assert "tape_recompute_output_bindings: tuple[str, ...]" in transition_registry_text + assert "program_layer_blocker_code: str" in transition_registry_text + assert "reverse_input_bindings: tuple[str, ...]" in transition_registry_text + assert "parameter_bindings: tuple[str, ...]" in transition_registry_text + assert "reverse_output_bindings: tuple[str, ...]" in transition_registry_text + assert "aliases: tuple[str, ...]" in 
+    assert "aliases: tuple[str, ...]" in transition_registry_text
+    assert "def transition_primitive_executor_record_for_lowered_primitive(" in transition_registry_text
+    assert 'if lowered_primitive == "diagonal_recurrence"' not in transition_registry_text
+    assert 'aliases=("diagonal_recurrence",)' in transition_registry_text
+    assert "def transition_primitive_program_contract_blocker_code(" in transition_registry_text
+    assert "def transition_program_layer_blocker_codes(" in transition_registry_text
+    assert "def transition_program_layer_missing_symbols(" in transition_registry_text
+    assert "program_transition_gated_logspace_recurrence_forward" in transition_registry_text
+    assert "program_transition_norm_or_identity_forward" in transition_registry_text
+    assert "program_transition_tanh_forward" in transition_registry_text
+    assert "program_transition_tanh_backward" in transition_registry_text
+    assert "native.reverse.transition_tanh.v1" in transition_registry_text
+    assert "program_transition_diag_rtu_backward" in transition_registry_text
+    native_callables_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py"
+    ).read_text(encoding="utf-8")
+    assert "_TRANSITION_FORWARD_IMPLEMENTATION_SYMBOL" not in native_callables_text
+    assert "_TRANSITION_FORWARD_OUTPUT_CONTRACTS" not in native_callables_text
+    assert "_TRANSITION_FORWARD_INPUT_CONTRACTS" not in native_callables_text
+    assert "_TRANSITION_FORWARD_PARAMETER_CONTRACTS" not in native_callables_text
+    assert "_TRANSITION_FORWARD_OUTPUT_BINDING_CONTRACTS" not in native_callables_text
+    assert "_TRANSITION_REVERSE_CALLABLE_BY_PRIMITIVE" not in native_callables_text
+    assert "record.program_forward_output_contracts" in native_callables_text
+    assert "record.program_forward_input_bindings" in native_callables_text
+    assert "record.program_reverse_native_callable" in native_callables_text
+    assert "transition_primitive_executor_record_for_lowered_primitive(" in transition_program_text
+    assert "class TransitionPrimitiveDagExecutorPlan" in transition_program_text
+    assert "class TransitionPrimitiveDagTensorEdge" in transition_program_text
+    assert "_build_transition_primitive_dag_executor_plan(" in transition_program_text
+    assert "_resolve_transition_tape_bindings(" in transition_program_text
+    assert '"recompute" in tape_policy' not in transition_program_text
+    assert "NO_REGISTERED_TRANSITION_EXECUTOR" not in transition_program_text
+    assert "selection_kind: Literal" in transition_program_text
+    assert "runtime_execution_status: Literal" in transition_program_text
+    assert "_last_transition_executor_selection_kind" in transition_program_text
+    assert "_last_transition_executor_primitive_dag" in transition_program_text
+    assert 'executor="primitive_dag"' in transition_program_text
+    assert "registry_id=primitive_dag.registry_id" in transition_program_text
+    assert "_lower_transition_primitive_dag_forward(" not in transition_program_text
+    assert "_transition_primitive_dag_forward_executors()" not in transition_program_text
+    assert "primitive_forward_executors.get(str(op.forward_symbol))" not in transition_program_text
+    assert "primitive_dag_eager_forward" not in transition_program_text
+    assert 'if primitive == "linear"' not in transition_program_text
+    assert 'elif primitive == "matmul"' not in transition_program_text
+    assert 'elif primitive == "norm_or_identity"' not in transition_program_text
+    assert "program_layer_status: str" in transition_program_text
+    assert "program_layer_blocker_codes: tuple[str, ...]" in transition_program_text
+    assert "program_missing_symbols: tuple[str, ...]" in transition_program_text
+    assert 'if primitive == "gated_logspace_recurrence"' not in executor_binding_text
+    assert 'elif primitive in {"diag_rtu", "diagonal_recurrence"}' not in executor_binding_text
+    assert "primitive_record.reverse_input_bindings" in executor_binding_text
+    assert "primitive_record.reverse_output_bindings" in executor_binding_text
+    assert "_last_transition_executor_program_layer_status" in transition_program_text
+    assert "UNREGISTERED_TRANSITION_PRIMITIVE" in transition_program_text
+    assert "program_transition_tanh_forward" in transition_registry_text
+    assert "program_transition_tanh_backward" in transition_registry_text
+    assert 'program_reverse_native_callable="native.reverse.transition_matmul_primitive.v1"' in (
+        transition_registry_text
+    )
+    assert 'reverse_input_bindings=("input", "grad_output")' in transition_registry_text
+    assert 'reverse_output_bindings=("grad_input", "grad_weight")' in transition_registry_text
+    tanh_record_body = transition_registry_text[
+        transition_registry_text.index('primitive="tanh"') : transition_registry_text.index("__all__")
+    ]
+    assert "MISSING_CUDA_TRANSITION_PRIMITIVE_EXECUTOR" not in tanh_record_body
+    assert 'program_backward_status="callable"' in tanh_record_body
+    assert 'program_reverse_native_callable="native.reverse.transition_tanh.v1"' in tanh_record_body
+    assert "class TransitionPublicKVProjectionPlan" not in transition_program_text
+    assert "transition_public_kv_projection:sender_kv:v1" not in transition_program_text
+    assert "runtime._project_sender_kv_from_cells_step(" not in transition_program_text
+    assert "project_sender_kv_from_cells_step(" not in transition_program_text
+    assert "build_transition_public_kv_projection_plan(" not in transition_program_text
+    assert "selected unsupported executor" not in transition_program_text
+    assert "_record_gated_temporal_backward_window" not in transition_program_text
+    assert "_record_diagonal_temporal_backward_window" not in transition_program_text
+    assert "cuda_gated_core_recurrent_affine_window" not in transition_program_text
+    assert "cuda_diagonal_recurrence_core_backward_window" not in transition_program_text
+    assert "fixed_composite_abi" not in primitive_dispatch_text
+    assert "message_runtime_path" not in primitive_dispatch_text
+    assert "message_backward_runtime_path" not in primitive_dispatch_text
+    assert "readout_runtime_path" not in primitive_dispatch_text
+    assert "boundary_backward_runtime_path" not in primitive_dispatch_text
+    assert "parameter_binding_runtime_path" not in primitive_dispatch_text
+    assert "runtime_path" not in primitive_dispatch_text
+    assert "registered_executor_group_owns_primitive_row" in primitive_dispatch_text
+    assert "registered_transition_affine_executor" in primitive_dispatch_text
+    assert "registered_parameter_reduction_executor" in primitive_dispatch_text
+    assert "registered_message_executor_required" in primitive_dispatch_text
+    assert "registered_boundary_backward_executor_required" in primitive_dispatch_text
+    assert "_EXPLICIT_COMPOSITE_TRANSITION_PRIMITIVES" not in primitive_dispatch_text
+    assert "def _registered_transition_composite_primitives(" in primitive_dispatch_text
+    assert "temporal_executor_strategy_registry().all_patterns()" in primitive_dispatch_text
+    assert "registered_executor_binding_group_implemented" in primitive_dispatch_text
+    assert "verified_rewrite_required" in executor_patterns_text
+
+
+def test_temporal_backward_registered_executor_owns_reverse_bindings() -> None:
+    repo_root = Path(__file__).resolve().parents[1]
+    backward_plan_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/backward_plan.py"
+    ).read_text(encoding="utf-8")
+    executor_binding_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_bindings.py"
+    ).read_text(encoding="utf-8")
+    registered_executor_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py"
+    ).read_text(encoding="utf-8")
+    executor_registry_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/executor_registry.py"
+    ).read_text(encoding="utf-8")
+    executor_patterns_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py"
+    ).read_text(encoding="utf-8")
+    surface_executor_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py"
+    ).read_text(encoding="utf-8")
+    program_parameters_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_parameters.py"
+    ).read_text(encoding="utf-8")
+    flat_buckets_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_buckets.py"
+    ).read_text(encoding="utf-8")
+    runtime_dispatch_text = (repo_root / "src/cortical/fabric/backend/runtime_dispatch.py").read_text(encoding="utf-8")
+    runtime_core_text = (repo_root / "src/cortical/fabric/runtime/core.py").read_text(encoding="utf-8")
+    runtime_surface_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/runtime/surface.py"
+    ).read_text(encoding="utf-8")
+    runtime_support_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/runtime/support.py"
+    ).read_text(encoding="utf-8")
+    runtime_backward_path = repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/runtime/backward.py"
+    reverse_table_path = repo_root / "src/cortical/fabric/backend/cuda/ops/temporal_backward/reverse_table.py"
+    reverse_executor_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py"
+    ).read_text(encoding="utf-8")
+    sequence_runtime_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py"
+    ).read_text(encoding="utf-8")
+    planner_text = (repo_root / "src/cortical/fabric/backend/planner.py").read_text(encoding="utf-8")
+    output_backward_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/output_backward.py"
+    ).read_text(encoding="utf-8")
+    param_binding_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py"
+    ).read_text(encoding="utf-8")
+    native_callables_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/native_callables.py"
+    ).read_text(encoding="utf-8")
+    reducer_patterns_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/reducer_patterns.py"
+    ).read_text(encoding="utf-8")
+    reverse_artifacts_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/reverse_artifacts.py"
+    ).read_text(encoding="utf-8")
+    message_specs_text = (repo_root / "src/cortical/fabric/backend/message_rule_specs.py").read_text(encoding="utf-8")
+    registered_program_cuda_text = (
+        repo_root
+        / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py"
+    ).read_text(encoding="utf-8")
+    registered_program_binding_text = (
+        repo_root
+        / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_binding.cpp"
+    ).read_text(encoding="utf-8")
+    registered_program_kernel_text = _registered_program_kernel_source_text(repo_root)
+    compiler_root = repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler"
+    transition_stage_body = _function_source(
+        registered_executor_text,
+        "_build_transition_boundary_reverse_step_program_tables",
+    )
+    reverse_window_body = _function_source(
+        registered_executor_text,
+        "_run_registered_temporal_reverse_program_tensor_table_window",
+    )
+
+    assert "class TemporalBackwardExecutablePlan" in backward_plan_text
+    assert "build_temporal_backward_executable_plan(" in backward_plan_text
+    assert "build_temporal_reverse_executor_binding_plan(" in backward_plan_text
+    assert "backward_executable_plan=compiler_owned" in backward_plan_text
+    assert "tensor_slot_rows: torch.Tensor" not in backward_plan_text
+    assert "build_temporal_reverse_executor_binding_plan(" in executor_binding_text
+    assert "TemporalExecutorTensorBinding" in executor_binding_text
+    assert "registered_reverse_executor_bindings_required" in reverse_executor_text
+    assert "build_registered_temporal_executor_program(" in reverse_executor_text
+    assert "executor_program=executor_program" in reverse_executor_text
+    for old_reverse_owner in (
+        "explicit_public_projection_thin_reverse",
+        "explicit_readout_projection_thin_reverse",
+        "public_projection_backward:explicit_thin_reverse",
+        "readout_projection_backward:explicit_thin_reverse",
+        "thin_reverse_path:explicit_executor",
+        "lowered_state_epilogue_backward:explicit_cuda_executor",
+    ):
+        assert old_reverse_owner not in sequence_runtime_text
+        assert old_reverse_owner not in planner_text
+    assert "class RegisteredTemporalExecutorProgram" in registered_executor_text
+    assert "kernel_registry: RegisteredTemporalExecutorKernelRegistry" in registered_executor_text
+    assert "registered_temporal_executor_kernel_registry()" in registered_executor_text
+    assert "class RegisteredTemporalExecutorKernelRegistry" in executor_registry_text
+    assert "program_executor_plan: TemporalRegisteredProgramExecutorPlan" in registered_executor_text
+    assert "def require_static_tensor(" in registered_executor_text
+    assert "def require_executor(" not in registered_executor_text
+    assert "def transition_bucket_ordinals(" in registered_executor_text
+    assert "def require_transition_reverse_coverage(" in registered_executor_text
+    assert "def reverse_surface_handles(" in registered_executor_text
+    assert 'reverse_surface_handles(surface="message")' in registered_executor_text
+    assert "def reverse_surface_handle(" not in registered_executor_text
+    assert 'reverse_handle(surface="message", bucket_ordinal=-1)' not in registered_executor_text
+    assert "_registered_executor_rows_tensor" not in surface_executor_text
+    assert "_registered_executor_binding_rows_tensor" not in surface_executor_text
+    assert "reverse_executor_rows=_registered" not in surface_executor_text
+    assert "reverse_executor_rows=executor_program.backward_plan.reverse_executor_rows" in registered_executor_text
+    assert "reverse_executor_binding_rows=executor_program.backward_plan.executor_binding_rows" in (
+        registered_executor_text
+    )
+    assert "neighborhood_attention_project_backward" not in registered_executor_text
+    assert "projection_reduction_boundary" not in registered_executor_text
"neighborhood_attention_project_backward" not in surface_executor_text + assert "projection_reduction_boundary" not in surface_executor_text + assert "neighborhood_attention_project_backward" not in executor_registry_text + assert "projection_reduction_boundary" not in executor_registry_text + assert '_reverse_executor_names(surface="message")' in executor_registry_text + assert '_forward_executor_names(surface="readout")' in executor_registry_text + assert '"neighborhood_attention_project_backward"' not in executor_patterns_text + assert '"neighborhood_attention_project_backward"' in message_specs_text + assert "_registered_message_rule_reverse_executor_patterns()" in executor_patterns_text + assert "readout_rule_native_executor(" in executor_patterns_text + assert "executor_validation=registry_owned" in executor_registry_text + assert "run_registered_temporal_reverse_executor_window(" not in reverse_executor_text + assert "run_registered_temporal_reverse_executor_window(" not in registered_executor_text + assert "run_registered_temporal_reverse_executor_tensor_store_window(" in reverse_executor_text + assert "run_registered_temporal_reverse_executor_tensor_store_window(" in registered_executor_text + assert "run_registered_temporal_bucket_step_backward(" not in registered_executor_text + assert "step-object artifact windows are not an active CUDA training path" in reverse_executor_text + assert "registered_reverse_program_window_rejected" in registered_executor_text + assert "_record_temporal_reverse_scan_owner(" in registered_executor_text + assert "_last_flat_bucket_temporal_reverse_scan_binding_abi" in registered_executor_text + assert "registered_temporal_executor_loop" not in registered_executor_text + assert "run_registered_output_message_backward_executor(" not in registered_executor_text + assert "run_registered_recurrent_message_backward_executor(" not in registered_executor_text + assert "run_registered_recurrent_kv_projection_backward_executor(" not in registered_executor_text + assert "run_registered_transition_backward_executor(" not in registered_executor_text + assert "run_registered_transition_reverse_program_stage(" not in registered_executor_text + assert "_run_registered_transition_reverse_program_stage_cuda(" not in registered_executor_text + assert "_run_registered_transition_boundary_reverse_program_step_cuda(" not in registered_executor_text + assert "_build_transition_boundary_reverse_step_program_tables(" in registered_executor_text + assert "_consume_transition_boundary_reverse_step_outputs(" in registered_executor_text + assert "executor_program.kernel_registry.transition_backward(" not in registered_executor_text + assert "run_registered_readout_projection_backward_executor(" not in registered_executor_text + assert "run_registered_recurrent_query_param_backward_executor(" not in registered_executor_text + assert "run_registered_boundary_public_backward_executor(" not in registered_executor_text + assert "output_message_backward=run_registered_output_message_backward_executor" not in executor_registry_text + assert "recurrent_message_backward=run_registered_recurrent_message_backward_executor" not in executor_registry_text + assert "_output_message_backward" not in executor_registry_text + assert "_recurrent_message_backward" not in executor_registry_text + assert "_recurrent_kv_projection_backward" not in executor_registry_text + assert "_readout_layout_projection_backward" not in executor_registry_text + assert "_boundary_public_backward" not in 
+    assert "_boundary_public_backward" not in executor_registry_text
+    assert "transition_backward=run_registered_transition_backward_executor" not in executor_registry_text
+    assert "_transition_backward" not in executor_registry_text
+    assert "dispatch_key=selected_executor_rows_and_binding_rows" in executor_registry_text
+    assert "run_registered_grouped_recurrent_kv_projection_backward_executor(" not in registered_executor_text
+    assert "run_registered_initial_recurrent_backward_executor(" not in registered_executor_text
+    assert "runtime._run_backend_message_backward_phase(" not in registered_executor_text
+    assert "run_backend_order_transition_buckets_backward_step_cached(" not in registered_executor_text
+    assert "runtime._run_backend_order_transition_buckets_backward_step(" not in registered_executor_text
+    assert "runtime._run_backend_output_projection_backward_phase(" not in registered_executor_text
+    assert "runtime._run_backend_sender_kv_projection_backward_phase(" not in registered_executor_text
+    assert "runtime._run_backend_query_param_backward_phase(" not in registered_executor_text
+    assert "runtime._run_backend_boundary_public_backward_phase(" not in registered_executor_text
+    assert "runtime._run_backend_initial_recurrent_backward_raw_phase(" not in registered_executor_text
+    assert "runtime._sender_kv_projection_param_grad_tuple_from_raw_grads(" not in registered_executor_text
+    assert "recurrent_kv_projection_param_binding(" not in registered_executor_text
+    assert "recurrent_query_param_backward(" not in registered_executor_text
+    assert "_recurrent_kv_projection_param_binding" not in executor_registry_text
+    assert "_recurrent_query_param_backward" not in executor_registry_text
+    assert "run_registered_recurrent_kv_projection_param_binding_executor" not in surface_executor_text
+    assert "run_registered_recurrent_query_param_backward_executor" not in surface_executor_text
+    assert "build_temporal_parameter_reducer_program(" in registered_executor_text
+    assert "trainable_param_names=trainable_param_names" in registered_executor_text
+    assert "run_temporal_parameter_reducer_program(" in registered_executor_text
+    assert "program=parameter_reducer_program" in registered_executor_text
+    assert "TemporalReadoutOutputParamReducerRequest(" in registered_executor_text
+    assert "TemporalSenderKVProjectionParamReducerRequest(" in registered_executor_text
+    assert "TemporalRecurrentQueryParamReducerRequest(" in registered_executor_text
+    assert "run_temporal_sender_kv_projection_param_reducer_program(" not in registered_executor_text
+    assert "run_temporal_recurrent_query_param_reducer_program(" not in registered_executor_text
+    assert "run_temporal_readout_output_param_reducer_program(" not in registered_executor_text
+    assert "output_projection_param_grads = tuple(" not in registered_executor_text
+    assert "try_project_recurrent_kv_backend_order_backward_cuda(" not in registered_executor_text
+    assert 'expected_names=("gated_logspace_transition_backward", "diag_rtu_transition_backward")' not in (
+        surface_executor_text
+    )
+    assert "runtime._run_backend_message_backward_phase(" not in surface_executor_text
+    assert "run_backend_order_transition_buckets_backward_step_cached(" not in surface_executor_text
+    assert "runtime._run_backend_order_transition_buckets_backward_step(" not in surface_executor_text
+    assert "def _run_transition_bucket_step(" not in runtime_dispatch_text
+    assert "def _run_transition_buckets_step(" not in runtime_dispatch_text
+    assert "def _run_backend_order_transition_buckets_step(" not in runtime_dispatch_text
+    assert "def _run_backend_order_transition_buckets_backward_step(" not in runtime_dispatch_text
+    assert "def _run_active_window_transition_buckets_step(" not in runtime_dispatch_text
+    assert "def _lower_backend_population_transition_shared(" not in runtime_dispatch_text
+    assert "def _lower_backend_population_transition_forward_result_shared(" not in runtime_dispatch_text
+    assert "def _lower_backend_population_transition_backward_shared(" not in runtime_dispatch_text
+    assert "_run_backend_order_transition_buckets_step(" not in runtime_core_text
+    assert "_run_transition_buckets_step(" not in runtime_core_text
+    assert "class _CapturedTrainingSequenceSurface" not in runtime_surface_text
+    assert "runtime.backward" not in runtime_surface_text
+    assert "CudaSequenceBackwardMixin" not in runtime_surface_text
+    assert not runtime_backward_path.exists()
+    assert "_PhysicalBackwardSequenceExecutor" not in runtime_surface_text
+    assert "_PhysicalBackwardSequenceExecutor" not in runtime_support_text
+    assert "physical_sequence_executor" not in runtime_support_text
+    assert "execute_temporal_bucket_sequence(" in runtime_surface_text
+    assert "def run_backend_order_transition_buckets_step(" not in flat_buckets_text
+    assert "def run_backend_order_transition_buckets_backward_step(" not in flat_buckets_text
+    assert "def run_transition_bucket_step(" not in flat_buckets_text
+    assert "def run_transition_buckets_step(" not in flat_buckets_text
+    assert "def run_active_window_transition_buckets_step(" not in flat_buckets_text
+    assert "def run_backend_order_transition_buckets_backward_step_cached(" not in flat_buckets_text
+    assert "def run_backend_order_transition_buckets_backward_step_cached_unbound(" not in flat_buckets_text
+    assert "lower_backend_population_transition_backward_shared(" not in surface_executor_text
+    assert "registered_temporal_fused_reverse_transition_program_cuda(" not in surface_executor_text
+    assert "registered_temporal_fused_backward_program_transition_step_cuda(" not in surface_executor_text
+    assert "class RegisteredTransitionReverseStageResult" not in surface_executor_text
+    assert "def registered_transition_reverse_stage_program(" not in surface_executor_text
+    assert "def run_registered_transition_reverse_program_stage(" not in surface_executor_text
+    assert "_transition_program_tensor_table(" not in transition_stage_body
+    assert "_extend_transition_reverse_program_tensor_table(" not in transition_stage_body
+    assert "_transition_state_before_from_reverse_artifact_table(" not in transition_stage_body
+    assert "_reverse_artifact_tensor_for_role_step(" not in transition_stage_body
+    assert "_transition_stage_parameter_slot_table(" not in surface_executor_text
+    assert "transition_parameter_tensor_table(" not in surface_executor_text
+    assert "_resolve_transition_parameter_tensor(" not in surface_executor_text
+    assert "_normalize_transition_parameter_tensor(" not in surface_executor_text
+    assert "transition_parameter_tensors=reverse_program_tensor_table.transition_parameter_tensors" in (
+        reverse_window_body
+    )
+    assert "transition_parameter_rows=reverse_program_tensor_table.transition_parameter_rows" in reverse_window_body
+    assert "def transition_parameter_tensor_table(" in program_parameters_text
+    assert "transition_parameter_rows" in registered_executor_text
+    assert "transition_parameter_rows" in registered_program_cuda_text
+    assert "bind_transition_parameter_tensors(" in registered_program_kernel_text
"transition_output_keep_slot_row_groups" in registered_executor_text + assert "transition_output_keep_slot_row_groups" in registered_program_cuda_text + assert "filter_transition_outputs_by_keep_slots" in registered_program_kernel_text + assert "transition output keep-slot row references an invalid output slot" in registered_program_kernel_text + assert "_transition_reverse_seed_tensor_table(" in transition_stage_body + assert "_TRANSITION_REVERSE_SEED_ROLE_IDS" not in surface_executor_text + assert "temporal_transition_reverse_seed_role_id(" in surface_executor_text + assert "temporal_transition_reverse_seed_role_rows_tensor(" in native_callables_text + assert "transition_reverse_seed_role_rows" in registered_executor_text + assert "transition_reverse_seed_role_rows" in registered_program_cuda_text + assert "transition_reverse_seed_role_rows" in registered_program_kernel_text + assert "reverse_span_output_rows" in registered_executor_text + assert "reverse_span_output_rows" in registered_program_cuda_text + assert "reverse_span_output_rows" in registered_program_kernel_text + assert "forward_artifact_route_rows=executor_program.forward_artifact_route_rows" in registered_executor_text + assert "forward_artifact_merge_rows=executor_program.forward_artifact_merge_rows" in registered_executor_text + assert "forward_output_route_rows=executor_program.forward_output_route_rows" in registered_executor_text + assert "forward_artifact_route_rows" in registered_program_cuda_text + assert "forward_artifact_merge_rows" in registered_program_cuda_text + assert "forward_output_route_rows" in registered_program_cuda_text + assert "reverse_artifact_tensor_for_routed_access_step(" in registered_program_kernel_text + assert "reverse_output_route_rows" in registered_executor_text + assert "def front_output(" not in registered_executor_text + assert "def boundary_output(" not in registered_executor_text + assert "boundary_group[:6]" not in registered_executor_text + assert "len(front_outputs) != 9" not in registered_executor_text + assert "message_strategy_extra_param_grads" not in registered_executor_text + assert "registered_temporal_fused_forward_transition_program_cuda(" not in transition_stage_body + assert "registered_temporal_reverse_program_stage_cuda(" not in transition_stage_body + assert "registered_temporal_fused_backward_program_cuda(" in reverse_window_body + assert "def registered_temporal_fused_backward_program_cuda(" in registered_program_cuda_text + assert "fused_backward_program_execute" in registered_program_binding_text + assert "fused_reverse_program_full_step" not in registered_program_binding_text + assert "flat_bucket_registered_temporal_fused_backward_program_cuda(" in (registered_program_kernel_text) + assert "flat_bucket_registered_temporal_fused_backward_program_transition_stage_cuda(" not in ( + registered_program_kernel_text + ) + assert "flat_bucket_registered_temporal_fused_backward_program_output_grad_cuda(" not in ( + registered_program_kernel_text + ) + assert "flat_bucket_registered_temporal_fused_backward_program_readout_message_kv_step_cuda(" not in ( + registered_program_kernel_text + ) + assert "flat_bucket_registered_temporal_fused_reverse_program_window_step_cuda(" not in ( + registered_program_kernel_text + ) + assert "flat_bucket_registered_temporal_fused_reverse_program_transition_boundary_step_cuda(" not in ( + registered_program_kernel_text + ) + assert ( + "flat_bucket_registered_temporal_fused_backward_program_recurrent_message_boundary_initial_kv_step_cuda(" 
+    assert (
+        "flat_bucket_registered_temporal_fused_backward_program_recurrent_message_boundary_initial_kv_step_cuda("
+        not in (registered_program_kernel_text)
+    )
+    assert "registered_temporal_fused_reverse_program_transition_boundary_step_cuda(" not in (
+        registered_program_cuda_text
+    )
+    assert "fused_reverse_program_transition_boundary_step" not in registered_program_binding_text
+    assert 'stage_name="output_grad_window"' not in registered_executor_text
+    assert 'stage_name="readout_message_kv_step"' not in registered_executor_text
+    assert "registered_temporal_fused_reverse_program_window_step_cuda(" not in registered_executor_text
+    assert "def registered_temporal_fused_reverse_program_window_step_cuda(" not in registered_program_cuda_text
+    assert "fused_reverse_program_window_step" not in registered_program_binding_text
+    assert "artifacts: Any" not in surface_executor_text
+    assert "artifacts=artifacts" not in registered_executor_text
+    assert "_reverse_artifact_tensor_for_role_step(" not in surface_executor_text
+    assert 'role_name="recurrent_msg_backend_order"' not in surface_executor_text
+    assert "reverse_artifact_tensor_for_role_step(" not in registered_program_kernel_text
+    assert "reverse_artifact_tensor_for_access_step(" in registered_program_kernel_text
+    assert (
+        'kReverseArtifactAccessOutputCells,\n local_step,\n "fused reverse full step output_cells"'
+        not in (registered_program_kernel_text)
+    )
+    assert "forward_output_route_rows" in registered_program_kernel_text
+    assert "span_grad_cells_out" in registered_program_kernel_text
+    assert "reverse_artifact_access_rows=reverse_artifact_access_rows" in registered_executor_text
+    assert "reverse_artifact_access_rows = temporal_reverse_artifact_access_rows_tensor(" in registered_executor_text
+    assert "memory_artifact_plan.reverse_artifact_roles" in registered_executor_text
+    assert "kReverseArtifactAccessRecurrentMsgBackendOrder" in registered_program_kernel_text
+    assert "kReverseArtifactAccessTransitionStateBefore" in registered_program_kernel_text
+    assert "transition_seed_rows" in registered_program_cuda_text
+    assert "reverse_artifact_role_rows=reverse_artifact_role_rows" in registered_executor_text
+    assert "transition_reset=artifacts.transition_reset_step" not in surface_executor_text
+    assert "reset_not_program_owned" not in registered_executor_text
+    assert "temporal_reverse_reset_tensor_table(" in registered_executor_text
+    assert "reverse_reset_tensor_groups=tuple(reverse_reset_tensor_groups)" in registered_executor_text
+    assert "reverse_reset_row_groups=tuple(reverse_reset_row_groups)" in registered_executor_text
+    assert "transition_state_reset_rows=transition_step_tables.transition_state_reset_rows" in registered_executor_text
+    assert "reverse_reset_rows" in registered_program_cuda_text
+    assert "transition_state_reset_rows" in registered_program_cuda_text
+    assert "apply_transition_state_reset_outputs(" in registered_program_kernel_text
+    assert "build_registered_transition_state_before_artifact_table(" not in surface_executor_text
+    assert "transition_state_before_tensors" not in registered_executor_text
+    assert "transition_state_before_binding_rows" not in registered_executor_text
+    assert 'TemporalReverseArtifactRole(14, "transition_state_before", True)' in reverse_artifacts_text
+    assert 'TemporalReverseArtifactAccess(14, "transition_state_before", "transition_state_before", True)' in (
+        reverse_artifacts_text
+    )
+    assert "encode_temporal_reverse_transition_state_artifact_flags(" not in registered_executor_text
+    assert "kTransitionStateArtifactFlagStride" in registered_program_kernel_text
"decode_temporal_reverse_transition_state_artifact_flags(" not in surface_executor_text + assert "artifacts.backend_state_cache_before" not in surface_executor_text + assert "backend_state_cache_before=" not in registered_executor_text + assert "bind_temporal_transition_param_grads(" not in surface_executor_text + assert "bind_temporal_transition_param_grads(" not in param_binding_text + assert "run_temporal_transition_param_reducer_program(" not in surface_executor_text + assert "run_temporal_transition_param_reducer_program(" not in registered_executor_text + assert "TemporalTransitionParamReducerRequest(" in registered_executor_text + assert "reverse_program_stage_rows=executor_program.reverse_program_stage_rows" in registered_executor_text + assert "def _require_parameter_reducer_stage(" in param_binding_text + assert "_PARAMETER_REDUCER_STAGE_OPCODE = 5" in param_binding_text + assert "class TemporalParameterReducerProgram" in param_binding_text + assert "def build_temporal_parameter_reducer_program(" in param_binding_text + assert "def run_temporal_parameter_reducer_program(" in param_binding_text + assert "strategy_rows: torch.Tensor" in param_binding_text + assert "def _parameter_reducer_strategy_rows(" in param_binding_text + assert "parameter_reducer_strategy_rows=program.strategy_rows" in param_binding_text + assert "parameter_reducer_native_callable_id(reducer_kind)" in param_binding_text + assert "transition_trainable_reducer_native_callable_id(reducer_kind)" in param_binding_text + assert '"native.reverse.parameter_reduction.transition.v1"' in reducer_patterns_text + assert '"native.reverse.parameter_reduction.transition.materialized_base.v1"' in reducer_patterns_text + assert "class TemporalReadoutOutputParamReducerRequest" in param_binding_text + assert "class TemporalSenderKVProjectionParamReducerRequest" in param_binding_text + assert "class TemporalRecurrentQueryParamReducerRequest" in param_binding_text + assert "class TemporalTransitionParamReducerRequest" in param_binding_text + assert "def _execute_transition_parameter_reducer_row(" not in param_binding_text + assert "def _execute_transition_parameter_reducer_rows(" not in param_binding_text + assert "def _reduce_transition_named_grad_sequence(" not in param_binding_text + assert "def _transition_recurrent_bias_full_grad(" not in param_binding_text + assert "transition_trainable_rows" in param_binding_text + assert "transition_source_rows" in param_binding_text + assert "transition_source_names" in param_binding_text + assert "registered_transition_cuda_trainable_parameter_row_program" in param_binding_text + assert "_registered_transition_population_param_grad_tuple(" not in param_binding_text + assert "def _execute_sender_kv_projection_parameter_reducer_row(" not in param_binding_text + assert "def _execute_recurrent_query_parameter_reducer_row(" not in param_binding_text + assert "def _execute_readout_output_parameter_reducer_row(" not in param_binding_text + assert "_sender_kv_projection_named_param_grads_from_raw(" not in param_binding_text + assert "registered_temporal_parameter_reducer_program_cuda(" in param_binding_text + assert "def registered_temporal_parameter_reducer_program_cuda(" in registered_program_cuda_text + assert "transition_trainable_rows" in registered_program_cuda_text + assert "transition_source_rows" in registered_program_cuda_text + assert "parameter_reducer_strategy_rows" in registered_program_cuda_text + assert "parameter_reducer_trainable_role_rows" in registered_program_cuda_text + 
assert "parameter_reducer_runtime_metadata_rows" in registered_program_cuda_text + assert "message_strategy_grad_tensors" in param_binding_text + assert "message_strategy_grad_rows" in param_binding_text + assert "message_strategy_grad_tensors" in registered_program_cuda_text + assert "message_strategy_grad_rows" in registered_program_cuda_text + assert "message_strategy_grad_tensors_for_role(" in registered_program_kernel_text + assert 'request.reducer_kind != "fixed_slot_context_message"' not in param_binding_text + assert "scalar_logical_name =" not in param_binding_text + assert "required_static_logical_groups" in reducer_patterns_text + assert "grad_output_roles" in reducer_patterns_text + assert "active_trainable_roles" in reducer_patterns_text + assert '"fixed_slot_context_message": 5' not in reducer_patterns_text + assert '"message_strategy": 5' in reducer_patterns_text + assert 'count_target="message_strategy"' in message_specs_text + assert "context_nudge_query_slot_grad_tensors" not in registered_program_cuda_text + assert "context_nudge_query_slot_grad_tensors" not in registered_program_binding_text + assert "context_nudge_query_slot_grad_tensors" not in registered_program_kernel_text + assert "context_nudge_input_key_grad_tensors" not in registered_program_cuda_text + assert "context_nudge_input_key_grad_tensors" not in registered_program_binding_text + assert "context_nudge_input_key_grad_tensors" not in registered_program_kernel_text + assert "runtime_metadata_tensors" in registered_program_cuda_text + assert "parameter_reducer_program_execute" in registered_program_cuda_text + assert "parameter_reducer_program_execute" in registered_program_binding_text + assert "flat_bucket_registered_temporal_parameter_reducer_program_cuda(" in registered_program_kernel_text + assert "RegisteredParameterReducerHandler" in registered_program_kernel_text + assert "RegisteredParameterReducerRunFn" in registered_program_kernel_text + assert "RegisteredParameterReducerRuntimeContext" in registered_program_kernel_text + assert "RegisteredParameterReducerStrategy" in registered_program_kernel_text + assert "std::vector entries" in registered_program_kernel_text + assert "registered_parameter_reducer_expected_count(expected_counts," in registered_program_kernel_text + assert "int64_t fixed_slot_context_message = 0" not in registered_program_kernel_text + assert "kParameterReducerCountFixedSlotContextMessage" not in registered_program_kernel_text + assert "kParameterReducerCountMessageStrategy" in registered_program_kernel_text + assert "parameter reducer strategy row requires native callable hash" in registered_program_kernel_text + assert "transition trainable row requires native callable hash" in registered_program_kernel_text + assert "decode_registered_parameter_reducer_strategy_rows(" in registered_program_kernel_text + assert "decode_registered_parameter_reducer_role_table(" in registered_program_kernel_text + assert "registered_parameter_reducer_handler_for_native_callable(" in registered_program_kernel_text + assert "run_registered_parameter_reducer_handler(" in registered_program_kernel_text + assert "RegisteredParameterReducerHandlerKind" not in registered_program_kernel_text + assert "RegisteredTransitionTrainableReducerHandler" in registered_program_kernel_text + assert "registered_transition_trainable_reducer_handler_for_native_callable(" in registered_program_kernel_text + common_parameter_reducer_body = registered_program_kernel_text[ + registered_program_kernel_text.index( + 
"const RegisteredParameterReducerHandler&" + ) : registered_program_kernel_text.index("void run_registered_sender_kv_parameter_reducer_handler(") + ] + assert "handler.run(" in registered_program_kernel_text + assert "switch (handler.kind)" not in common_parameter_reducer_body + assert "case RegisteredParameterReducerHandlerKind" not in common_parameter_reducer_body + parameter_reducer_body = registered_program_kernel_text[ + registered_program_kernel_text.index( + "std::vector flat_bucket_registered_temporal_parameter_reducer_program_cuda(" + ) : registered_program_kernel_text.index( + "std::vector> registered_temporal_fused_reverse_program_step_impl(" + ) + ] + for fixed_reducer_argument in ( + "const at::Tensor& public_proj_weight", + "const at::Tensor& k_weight", + "const at::Tensor& v_weight", + "const at::Tensor& q_proj_weight", + "const at::Tensor& slot_embed", + "const at::Tensor& msg_out_weight", + "const at::Tensor& output_cell_weight", + "const at::Tensor& output_cell_bias", + "const at::Tensor& recurrent_cell_idx", + "const at::Tensor& output_cell_idx", + ): + assert fixed_reducer_argument not in parameter_reducer_body + assert "std::vector trainable_param_tensors" in parameter_reducer_body + assert "std::vector runtime_metadata_tensors" in parameter_reducer_body + for blocked_reducer_branch in ( + "case kParameterReducerReadoutOutput", + "case kParameterReducerSenderKvProjection", + "case kParameterReducerRecurrentQuery", + "case kParameterReducerTransition", + "case kParameterReducerOutputQuery", + "reducer_kind == kTransitionTrainable", + "kTransitionTrainableMaterializedBase ||", + "kTransitionTrainableMaterializedDelta ||", + ): + assert blocked_reducer_branch not in parameter_reducer_body + for deleted_transition_reducer_selector in ( + "kTransitionTrainableMaterializedBase", + "kTransitionTrainableMaterializedDelta", + "kTransitionTrainableValueToCellMsgToCell", + "kTransitionTrainableValueToCellMsgOut", + "kTransitionTrainableRecurrentBiasSlotEmbed", + "kTransitionTrainableRecurrentBiasCellBiasProj", + ): + assert deleted_transition_reducer_selector not in registered_program_kernel_text + assert "transition_recurrent_bias_full_grad(" in registered_program_kernel_text + assert "torch.einsum(" not in param_binding_text + assert ".matmul(" not in param_binding_text + assert "def _run_temporal_transition_param_reducer_program(" not in param_binding_text + assert "def _run_temporal_sender_kv_projection_param_reducer_program(" not in param_binding_text + assert "def _run_temporal_recurrent_query_param_reducer_program(" not in param_binding_text + assert "def _run_temporal_readout_output_param_reducer_program(" not in param_binding_text + assert "def run_temporal_transition_param_reducer_program(" not in param_binding_text + assert "def run_temporal_sender_kv_projection_param_reducer_program(" not in param_binding_text + assert "def run_temporal_recurrent_query_param_reducer_program(" not in param_binding_text + assert "def run_temporal_readout_output_param_reducer_program(" not in param_binding_text + assert "registered_temporal_parameter_reducer_cuda_row_program" in param_binding_text + assert "_last_flat_bucket_temporal_parameter_reducer_summaries" in param_binding_text + assert "registered_temporal_reverse_program_stage_cuda(" not in registered_executor_text + assert "registered_temporal_reverse_program_stage_cuda(" not in registered_program_cuda_text + assert "reverse_program_stage_rows=executor_program.reverse_program_stage_rows" in registered_executor_text + assert 
"executor_program.transition_param_grad_bindings" in registered_executor_text + assert "_transition_param_grad_accumulator_from_binding_rows(" in surface_executor_text + assert "build_temporal_transition_param_grad_binding_plan(" in registered_executor_text + assert '"grad_recurrent_kernel" in reverse_logical_to_slot' not in surface_executor_text + assert '"grad_nu_log" in reverse_logical_to_slot' not in surface_executor_text + assert "runtime._run_backend_output_projection_backward_phase(" not in surface_executor_text + assert "runtime._run_backend_query_param_backward_phase(" not in surface_executor_text + assert "runtime._run_backend_boundary_public_backward_phase(" not in surface_executor_text + assert "runtime._run_backend_sender_kv_projection_backward_phase(" not in surface_executor_text + assert "runtime._run_backend_initial_recurrent_backward_raw_phase(" not in surface_executor_text + assert "runtime._direct_sender_kv_group_ids(" not in surface_executor_text + assert "runtime._run_backend_query_param_backward_phase(" not in param_binding_text + assert "runtime._sender_kv_projection_param_grad_tuple_from_raw_grads(" not in param_binding_text + assert "runtime._state_public_explicit_param_grad_tuple(" not in param_binding_text + assert "_registered_transition_population_param_grad_tuple(" not in param_binding_text + assert "_last_flat_bucket_temporal_transition_trainable_parameter_rows" in param_binding_text + assert "run_temporal_recurrent_query_backward_sequence(" not in param_binding_text + assert "run_temporal_initial_recurrent_param_binding_sequence(" not in param_binding_text + assert "_run_registered_partitioned_message_backward_executor(" not in surface_executor_text + assert "_registered_sender_kv_projection_backward_raw(" not in surface_executor_text + assert "registered_backward_readout_layout_projection_cuda(" not in surface_executor_text + assert "run_registered_readout_layout_projection_backward_executor(" not in surface_executor_text + assert "run_registered_readout_projection_backward_executor(" not in surface_executor_text + assert "registered_temporal_fused_backward_program_readout_message_kv_step_cuda(" not in registered_executor_text + assert "registered_temporal_fused_backward_program_readout_step_cuda(" not in registered_executor_text + assert "registered_temporal_fused_backward_program_output_message_step_cuda(" not in registered_executor_text + assert ( + "registered_temporal_fused_backward_program_recurrent_kv_projection_step_cuda(" not in registered_executor_text + ) + assert "registered_temporal_fused_backward_program_recurrent_message_boundary_initial_kv_step_cuda(" not in ( + registered_executor_text + ) + assert 'stage_name="readout_message_kv_step"' not in registered_executor_text + assert "registered_temporal_fused_backward_program_cuda(" in registered_executor_text + assert 'stage_name="recurrent_message_boundary_initial_kv_step"' not in registered_executor_text + assert "registered_temporal_fused_backward_program_recurrent_message_initial_kv_step_cuda(" not in ( + registered_executor_text + ) + assert "registered_temporal_fused_backward_program_recurrent_message_step_cuda(" not in registered_executor_text + assert ( + "registered_temporal_fused_backward_program_boundary_kv_projection_step_cuda(" not in registered_executor_text + ) + assert "registered_temporal_fused_backward_program_initial_recurrent_kv_projection_step_cuda(" not in ( + registered_executor_text + ) + assert "kernel_registry.readout_layout_projection_backward(" not in 
registered_executor_text + assert "kernel_registry.output_message_backward(" not in registered_executor_text + assert "kernel_registry.recurrent_message_backward(" not in registered_executor_text + assert "kernel_registry.boundary_public_backward(" not in registered_executor_text + assert "kernel_registry.recurrent_kv_projection_backward(" not in registered_executor_text + assert "fabric_local_message_partitioned_backward_fused_cuda(" not in surface_executor_text + assert "registered_backward_partitioned_attention_cuda(" not in surface_executor_text + assert "fabric_sparse_message_partitioned_backward_receiver_cuda(" not in surface_executor_text + assert "fabric_sparse_message_partitioned_backward_sender_cuda(" not in surface_executor_text + assert "registered_backward_sparse_attention_cuda(" not in surface_executor_text + assert "receiver_major_affine_backward_cuda(" not in surface_executor_text + assert "fabric_grouped_projection_backward_cuda(" not in surface_executor_text + assert "registered_backward_sender_kv_projection_cuda(" not in surface_executor_text + assert "try_project_recurrent_kv_backend_order_backward_cuda(" not in surface_executor_text + assert "project_recurrent_kv_backend_order_backward_cuda(" not in surface_executor_text + assert "registered_reverse_executor_kernel_not_implemented" not in reverse_executor_text + assert "_require_registered_reverse_executor_for_table" in reverse_executor_text + assert "TemporalReverseWindowPayload" not in reverse_executor_text + assert "build_temporal_backward_compatibility_launch_plan(" not in reverse_executor_text + assert "try_transition_message_reverse_table_window_cuda(" not in reverse_executor_text + assert "tensor_slot_rows=backward_compatibility_launch_plan.tensor_slot_rows" not in reverse_executor_text + assert "cuda_recurrent_query_param_bind_deferred" not in reverse_executor_text + assert "cuda_recurrent_kv_param_bind_deferred" not in reverse_executor_text + assert "cuda_transition_param_bind_deferred" not in reverse_executor_text + assert "run_temporal_output_backward_sequence(" not in output_backward_text + assert "runtime._run_backend_sender_kv_projection_backward_phase(" not in output_backward_text + assert not (compiler_root / "compatibility.py").exists() + assert not (compiler_root / "backward_compatibility.py").exists() + assert not reverse_table_path.exists() + + +def test_temporal_forward_registered_executor_owns_scan_bindings() -> None: + repo_root = Path(__file__).resolve().parents[1] + forward_plan_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/forward_plan.py" + ).read_text(encoding="utf-8") + executor_binding_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_bindings.py" + ).read_text(encoding="utf-8") + scan_root = repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket" + forward_scan_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py" + ).read_text(encoding="utf-8") + registered_executor_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py" + ).read_text(encoding="utf-8") + executor_registry_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/executor_registry.py" + ).read_text(encoding="utf-8") + executor_patterns_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py" + ).read_text(encoding="utf-8") + transition_registry_text = ( + 
repo_root / "src/cortical/fabric/backend/cuda/transition_execution/registry.py" + ).read_text(encoding="utf-8") + surface_executor_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/surface_executor_runtime.py" + ).read_text(encoding="utf-8") + program_parameters_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_parameters.py" + ).read_text(encoding="utf-8") + program_tensors_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/program_tensors.py" + ).read_text(encoding="utf-8") + forward_program_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/forward_program.py" + ).read_text(encoding="utf-8") + registered_program_cuda_text = (scan_root / "flat_bucket_registered_program_cuda.py").read_text(encoding="utf-8") + registered_program_kernel_text = _registered_program_kernel_source_text(repo_root) + temporal_common_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/common.py" + ).read_text(encoding="utf-8") + compiler_root = repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler" + + assert "class TemporalForwardExecutablePlan" in forward_plan_text + assert "build_temporal_forward_executable_plan(" in forward_plan_text + assert "build_temporal_forward_executor_binding_plan(" in forward_plan_text + assert "forward_executable_plan=compiler_owned" in forward_plan_text + assert "tensor_slot_rows: torch.Tensor" not in forward_plan_text + assert "executor_tensor_slot_rows: torch.Tensor" not in forward_plan_text + assert "build_temporal_forward_executor_binding_plan(" in executor_binding_text + assert "TemporalExecutorTensorBinding" in executor_binding_text + assert "registered_executor_bindings_required" in forward_scan_text + assert "build_registered_temporal_executor_program(" in forward_scan_text + assert "executor_program=executor_program" in forward_scan_text + assert "class RegisteredTemporalExecutorProgram" in registered_executor_text + assert "kernel_registry: RegisteredTemporalExecutorKernelRegistry" in registered_executor_text + assert "registered_temporal_executor_kernel_registry()" in registered_executor_text + assert "class RegisteredTemporalExecutorKernelRegistry" in executor_registry_text + assert "program_executor_plan: TemporalRegisteredProgramExecutorPlan" in registered_executor_text + assert "def require_static_tensor(" in registered_executor_text + assert "preferred_keys" not in registered_executor_text + assert "preferred_names" not in registered_executor_text + assert "preferred_keys" not in surface_executor_text + assert "def require_runtime_tensor_attr(" in registered_executor_text + assert "def require_executor(" not in registered_executor_text + assert 'static_tensors["output_q"]' not in registered_executor_text + assert 'static_tensors["recurrent_q_backend_order"]' not in registered_executor_text + assert 'static_tensors["value_to_output_weight"]' not in registered_executor_text + assert "neighborhood_attention_project" not in registered_executor_text + assert "fixed_slot_context_nudge" not in registered_executor_text + assert "context_nudge" not in registered_executor_text + assert "projection_reduction_boundary" not in registered_executor_text + assert "gated_logspace_transition" not in registered_executor_text + assert "diag_rtu_transition" not in registered_executor_text + assert "neighborhood_attention_project" not in surface_executor_text + assert "projection_reduction_boundary" not in 
surface_executor_text + assert "gated_logspace_transition" not in surface_executor_text + assert "diag_rtu_transition" not in surface_executor_text + assert "neighborhood_attention_project" not in executor_registry_text + assert "projection_reduction_boundary" not in executor_registry_text + assert "gated_logspace_transition" not in executor_registry_text + assert "diag_rtu_transition" not in executor_registry_text + assert '_forward_executor_names(surface="message")' in executor_registry_text + assert '_forward_executor_names(surface="readout")' in executor_registry_text + assert '_forward_executor_names(surface="transition")' in executor_registry_text + assert '_reverse_executor_names(surface="readout")' in executor_registry_text + assert "temporal_executor_strategy_registry().forward_patterns()" in executor_registry_text + assert "temporal_executor_strategy_registry().reverse_patterns()" in executor_registry_text + message_specs_text = (repo_root / "src/cortical/fabric/backend/message_rule_specs.py").read_text(encoding="utf-8") + assert '"neighborhood_attention_project"' not in executor_patterns_text + assert '"neighborhood_attention_project"' in message_specs_text + assert "_registered_message_rule_forward_executor_patterns()" in executor_patterns_text + assert "registered_readout_rule_backend_spec_lowering_kinds()" in executor_patterns_text + assert "_registered_readout_rule_reverse_executor_patterns()" in executor_patterns_text + assert "registered_transition_forward_strategy_specs()" in executor_patterns_text + assert "registered_transition_reverse_strategy_specs()" in executor_patterns_text + assert '"gated_logspace_transition"' not in executor_patterns_text + assert '"diag_rtu_transition"' not in executor_patterns_text + assert '"gated_logspace_transition"' in transition_registry_text + assert '"diag_rtu_transition"' in transition_registry_text + assert "executor_validation=registry_owned" in executor_registry_text + assert "def transition_bucket_ordinals(" in registered_executor_text + assert "def require_transition_forward_coverage(" in registered_executor_text + assert "def forward_surface_handles(" in registered_executor_text + assert 'forward_surface_handles(surface="message")' in registered_executor_text + assert 'reverse_surface_handles(surface="readout")' in registered_executor_text + assert "def forward_surface_handle(" not in registered_executor_text + assert "def reverse_surface_handle(" not in registered_executor_text + assert 'forward_handle(surface="message", bucket_ordinal=-1)' not in registered_executor_text + assert 'reverse_handle(surface="readout", bucket_ordinal=-2)' not in registered_executor_text + assert "_registered_executor_rows_tensor" not in surface_executor_text + assert "_registered_executor_binding_rows_tensor" not in surface_executor_text + assert "forward_executor_rows=_registered" not in surface_executor_text + assert "forward_executor_rows=executor_program.forward_plan.forward_executor_rows" in registered_executor_text + assert "def _forward_executor_binding_rows_for_runtime(" in registered_executor_text + assert "forward_executor_binding_rows=forward_executor_binding_rows" in registered_executor_text + forward_scan_body = _function_source(registered_executor_text, "run_registered_temporal_forward_executor_scan") + assert "message_executor=message_executor" not in forward_scan_body + assert "forward_executor_rows=forward_executor_rows" not in surface_executor_text + assert "forward_executor_binding_rows=forward_executor_binding_rows" not in 
surface_executor_text + assert "public_kv_executor_id=int(message_executor.row.executor_id)" not in surface_executor_text + assert "public_kv_bucket_ordinal=int(message_executor.bucket_ordinal)" not in surface_executor_text + assert "Registered transition forward requires the compiled message executor" not in executor_registry_text + assert "run_registered_temporal_forward_executor_scan(" in forward_scan_text + assert "recompute_registered_temporal_artifact_window(" not in forward_scan_text + assert "runtime._project_sender_kv_from_cells_sequence(" not in forward_scan_text + assert "run_registered_temporal_forward_executor_scan(" in registered_executor_text + assert "compute_registered_temporal_bucket_step_artifacts(" not in registered_executor_text + assert "run_registered_cells_layout_executor(" not in registered_executor_text + assert "run_registered_input_kv_projection_sequence_executor(" not in registered_executor_text + assert "run_registered_recurrent_message_forward_executor(" not in registered_executor_text + assert "run_registered_readout_forward_executor(" not in registered_executor_text + assert "run_registered_recurrent_kv_projection_forward_executor(" not in registered_executor_text + assert "run_registered_transition_forward_executor(" not in registered_executor_text + assert "run_registered_cells_layout_executor" not in executor_registry_text + assert "run_registered_input_kv_projection_sequence_executor" not in executor_registry_text + assert "run_registered_recurrent_message_forward_executor" not in executor_registry_text + assert "run_registered_readout_forward_executor" not in executor_registry_text + assert "run_registered_readout_layout_epilogue_executor" not in executor_registry_text + assert "reverse_readout_executors=" in executor_registry_text + assert "run_registered_recurrent_kv_projection_forward_executor" not in executor_registry_text + assert "run_registered_transition_forward_executor" not in executor_registry_text + assert "runtime._compute_messages_step_subset_partitioned_raw(" not in registered_executor_text + assert "runtime._project_sender_kv_from_cells_sequence(" not in registered_executor_text + assert "run_backend_order_transition_buckets_step_cached_eager_result(" not in registered_executor_text + assert "runtime._run_backend_order_transition_buckets_step(" not in registered_executor_text + assert "try_fused_local_readout_cuda(" not in registered_executor_text + assert "def _fused_forward_program_tensor_table(" not in registered_executor_text + assert "build_forward_executable_program_tensor_table(" in registered_executor_text + assert "class TemporalExecutableProgramTensorTable" in program_tensors_text + assert "_resolve_fused_forward_program_surface_parameter(" not in registered_executor_text + assert "surface_parameter_tensor_table(" not in registered_executor_text + assert "surface_parameter_tensor_table(" in program_tensors_text + assert "def surface_parameter_tensor_table(" in program_parameters_text + assert "return reference.new_empty(0)" not in program_parameters_text + assert "_init_backend_population_state(" not in program_tensors_text + assert "dummy_message = boundary_seq.new_zeros" not in program_tensors_text + assert "tensor = boundary_seq.new_empty(\n (" in program_tensors_text + assert ( + 'program_tensor_for_binding_allow_empty(\n program_tensors, program_tensor_binding_rows, input_binding, "registered transition matmul input"' + in (registered_program_kernel_text) + ) + assert "registered transition matmul fresh-zero input 
sentinel must carry shape [B,0,H]" in ( + registered_program_kernel_text + ) + assert "if (fresh_zero_input) {\n output.zero_();" in registered_program_kernel_text + assert "could not resolve compiler state input binding" in program_tensors_text + assert "try_assemble_cells_graph_order_cuda(" not in registered_executor_text + assert "try_reorder_recurrent_graph_order_cuda(" not in registered_executor_text + assert "try_project_recurrent_kv_backend_order_cuda(" not in registered_executor_text + assert 'expected_names=("gated_logspace_transition", "diag_rtu_transition")' not in surface_executor_text + assert "registered_executor_runtime_project_sender_kv" not in surface_executor_text + assert "registered_executor_message_output_projection" not in surface_executor_text + assert "torch_cat" not in surface_executor_text + assert "runtime._compute_messages_step_subset_partitioned_raw(" not in surface_executor_text + assert "runtime._project_sender_kv_from_cells_sequence(" not in surface_executor_text + assert "runtime._direct_sender_kv_group_ids(" not in surface_executor_text + assert "run_backend_order_transition_buckets_step_cached_eager_result(" not in surface_executor_text + assert "runtime._run_backend_order_transition_buckets_step(" not in surface_executor_text + assert "lower_backend_population_transition_forward_result_shared(" not in surface_executor_text + assert "registered_temporal_fused_forward_program_transition_step_cuda(" not in surface_executor_text + assert "registered_temporal_fused_forward_transition_program_cuda(" not in surface_executor_text + assert "_extend_transition_reverse_program_tensor_table(" not in surface_executor_text + assert "reference_state.new_empty(0)" not in surface_executor_text + assert "_transition_program_tensor_table(" not in surface_executor_text + assert "_transition_forward_aggregate_binding(" not in surface_executor_text + assert "_transition_input_binding_index(" not in surface_executor_text + assert "Registered transition forward program has no aggregate input binding" not in surface_executor_text + assert "_transition_forward_state_tensor_table(" not in surface_executor_text + assert "program_tensor_binding_rows=program_tensor_binding_rows" not in surface_executor_text + assert "temporal_forward_program_access_rows_tensor(" in program_tensors_text + assert "temporal_reverse_program_access_rows_tensor(" in program_tensors_text + assert "temporal_forward_transition_state_carry_rows_tensor(" in program_tensors_text + assert "_PROGRAM_ACCESS_OPCODE" not in forward_program_text + assert "stable_access_opcode" in forward_program_text + assert "access_opcode=" in executor_patterns_text + assert "forward_program_access_rows=program_tensor_table.forward_program_access_rows" in registered_executor_text + assert "reverse_program_access_rows=reverse_program_tensor_table.reverse_program_access_rows" in ( + registered_executor_text + ) + assert "forward_transition_state_carry_rows=program_tensor_table.forward_transition_state_carry_rows" in ( + registered_executor_text + ) + assert "def _fused_reverse_program_tensor_table(" not in registered_executor_text + assert "build_reverse_executable_program_tensor_table(" in registered_executor_text + assert "class TemporalExecutableProgramTensorTable" in program_tensors_text + assert "reverse_program_access_rows: torch.Tensor" in program_tensors_text + assert "reverse_program_access_rows=reverse_program_tensor_table.reverse_program_access_rows" in ( + registered_executor_text + ) + assert "readout_primitive_start" not 
in registered_program_kernel_text + assert "message_primitive_start" not in registered_program_kernel_text + assert "message_primitive_start + 1" not in registered_program_kernel_text + assert "registered_forward_handler_span_indices_by_capability(" in registered_program_kernel_text + assert "registered_reverse_handler_span_indices_by_capability(" in registered_program_kernel_text + assert "find_registered_forward_handler_span" not in registered_program_kernel_text + assert "find_registered_reverse_handler_span" not in registered_program_kernel_text + assert "primitive_rows=executor_program.primitive_rows" in registered_executor_text + assert "memory_liveness_rows=executor_program.memory_liveness_rows" in registered_executor_text + assert "_transition_buckets_for_executors(" not in surface_executor_text + assert "_run_registered_partitioned_message_forward_executor(" not in surface_executor_text + assert "project_sender_kv_from_cells_sequence(" not in surface_executor_text + assert "registered_forward_sender_kv_sequence_cuda(" not in surface_executor_text + assert "registered_forward_sender_kv_step_cuda(" not in surface_executor_text + assert "fabric_local_message_partitioned_cuda(" not in surface_executor_text + assert "registered_forward_partitioned_attention_cuda(" not in surface_executor_text + assert "fabric_sparse_message_partitioned_cuda(" not in surface_executor_text + assert "registered_forward_sparse_attention_cuda(" not in surface_executor_text + assert "try_fused_local_readout_cuda(" not in surface_executor_text + assert "try_assemble_cells_graph_order_cuda(" not in surface_executor_text + assert "try_reorder_recurrent_graph_order_cuda(" not in surface_executor_text + assert "try_project_recurrent_kv_backend_order_cuda(" not in surface_executor_text + assert "fused_local_readout_cuda(" not in surface_executor_text + assert "registered_forward_readout_projection_cuda(" not in surface_executor_text + assert "registered_forward_readout_layout_epilogue_cuda(" not in surface_executor_text + assert "registered_forward_cells_layout_cuda(" not in surface_executor_text + assert "return_reverse_artifacts=bool(collect_artifacts)" in registered_executor_text + assert "return_final_program_tensors=bool(materialize_final_state)" in registered_executor_text + assert "clear_forward_transition_output_binding_slots(\n program_tensors" in registered_program_kernel_text + assert "compact_forward_program_tensor_table_for_return(\n program_tensors" in registered_program_kernel_text + assert "fresh-zero state sentinel" in registered_program_kernel_text + assert "outputs.insert(outputs.end(), program_tensors.begin(), program_tensors.end())" in ( + registered_program_kernel_text + ) + compact_call_index = registered_program_kernel_text.rindex("compact_forward_program_tensor_table_for_return(") + output_insert_index = registered_program_kernel_text.index( + "outputs.insert(outputs.end(), program_tensors.begin(), program_tensors.end())" + ) + assert compact_call_index < output_insert_index + assert ( + "materialize_output_message=collect_artifacts" + not in registered_executor_text[ + registered_executor_text.index( + "def run_registered_temporal_forward_executor_scan(" + ) : registered_executor_text.index("\ndef _reverse_artifact_tensor_store_window_table(") + ] + ) + assert "readout_layout_epilogue(" not in registered_executor_text + assert "forward_executor_rows" in registered_program_cuda_text + assert "forward_executor_binding_rows" in registered_program_cuda_text + assert "kTf" not in 
registered_program_kernel_text
+    assert "InitialGated" not in registered_program_kernel_text
+    assert "InitialDiagonal" not in registered_program_kernel_text
+    assert "assemble_cells_graph_order_cuda(" not in surface_executor_text
+    assert "reorder_recurrent_graph_order_cuda(" not in surface_executor_text
+    assert "project_recurrent_kv_backend_order_cuda(" not in surface_executor_text
+    assert "recompute_registered_temporal_artifact_window(" not in registered_executor_text
+    assert "registered_forward_executor_kernel_not_implemented" not in forward_scan_text
+    assert "registered_forward_executor_recompute_kernel_not_implemented" not in forward_scan_text
+    assert (
+        "registered_forward_executor_recompute_requires_stored_artifacts_or_registered_replay" not in forward_scan_text
+    )
+    assert "try_flat_bucket_temporal_scan_cuda(" not in forward_scan_text
+    assert "_flat_bucket_temporal_scan_bucket_selection(" not in forward_scan_text
+    assert "_flat_bucket_temporal_scan_bucket_selection(" not in temporal_common_text
+    assert "_flat_bucket_temporal_scan_buckets(" not in temporal_common_text
+    assert "_transition_tape_by_population_from_cuda_scan(" not in temporal_common_text
+    assert "_initial_recurrent_kv_backend_order(" not in temporal_common_text
+    assert "_temporal_backend_order_recurrent_kv_projection_backward(" not in temporal_common_text
+    assert "DOT_PRODUCT_FIXED_SLOT_CONTEXT_NUDGE" not in temporal_common_text
+    assert "runtime._project_sender_kv_from_cells_step(" not in temporal_common_text
+    assert "runtime._direct_sender_kv_group_ids(" not in temporal_common_text
+    assert "gated_bucket_index" not in temporal_common_text
+    assert "diagonal_bucket_index" not in temporal_common_text
+    assert "_flat_bucket_temporal_table_plan(" in forward_scan_text
+    assert "_gated_temporal_scan_parameter_tensors(" not in forward_scan_text
+    assert "_diagonal_temporal_scan_parameter_tensors(" not in forward_scan_text
+    assert "compiled_transition_program_for_bucket" not in forward_scan_text
+    assert "resolve_transition_parameter" not in forward_scan_text
+    assert "build_temporal_forward_compatibility_launch_plan(" not in forward_scan_text
+    assert "build_temporal_backward_compatibility_launch_plan(" not in forward_scan_text
+    assert "tensor_slot_rows=forward_compatibility_launch_plan.tensor_slot_rows" not in forward_scan_text
+    assert (
+        "executor_tensor_slot_rows=forward_compatibility_launch_plan.executor_tensor_slot_rows" not in forward_scan_text
+    )
+    assert not (compiler_root / "compatibility.py").exists()
+    assert not (compiler_root / "forward_compatibility.py").exists()
+    assert not (scan_root / "flat_bucket_temporal_scan_cuda.py").exists()
+    assert not (scan_root / "flat_bucket_temporal_scan_binding.cpp").exists()
+    assert not (scan_root / "flat_bucket_temporal_scan_kernels.cu").exists()
+
+
+def test_sequence_executor_does_not_keep_active_output_python_route() -> None:
+    repo_root = Path(__file__).resolve().parents[1]
+    source_text = (repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py").read_text(
+        encoding="utf-8"
+    )
+
+    assert "def execute_temporal_bucket_active_output_window(" not in source_text
+    assert "def record_temporal_bucket_sequence_surface_execution(" not in source_text
+    assert "_compute_messages_step_subset_partitioned_raw(" not in source_text
+    assert "_run_active_window_transition_buckets_step(" not in source_text
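+
+
+# Editorial note: the next two tests slice one function body out of executor.py by
+# locating the def of interest and the def that follows it. The helper below is a
+# minimal illustrative sketch of that slicing idiom only; its name and signature are
+# hypothetical and it is not part of the repository's test helpers.
+def _sketch_function_body(module_text: str, def_prefix: str, next_def_prefix: str) -> str:
+    """Return the source between ``def_prefix`` and the following ``next_def_prefix``."""
+    start = module_text.index(def_prefix)
+    end = module_text.index(next_def_prefix, start)
+    return module_text[start:end]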
+
+
+def test_sequence_executor_fails_closed_for_unlowered_boundary_variable_k() -> None:
+    repo_root = Path(__file__).resolve().parents[1]
+    source_text = (repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py").read_text(
+        encoding="utf-8"
+    )
+    start = source_text.index("def execute_temporal_bucket_sequence(")
+    end = source_text.index("\ndef _apply_temporal_output_contract(", start)
+    body = source_text[start:end]
+
+    assert "planner-lowered per-timestep K" in body
+    assert "compiler-owned temporal schedule rows" in body
+
+
+def test_sequence_executor_has_no_direct_hidden_step_loop_after_shared_scan() -> None:
+    repo_root = Path(__file__).resolve().parents[1]
+    source_text = (repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py").read_text(
+        encoding="utf-8"
+    )
+    start = source_text.index("def execute_temporal_bucket_sequence(")
+    end = source_text.index("\ndef _apply_temporal_output_contract(", start)
+    body = source_text[start:end]
+
+    assert "Direct hidden-input temporal execution" in body
+    assert "runtime._forward_stream_step(" not in body
+    assert "for step_index in range(time_steps)" not in body
+
+
+def test_temporal_object_artifact_recompute_branch_was_deleted() -> None:
+    repo_root = Path(__file__).resolve().parents[1]
+    source_text = (repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py").read_text(
+        encoding="utf-8"
+    )
+    registered_executor_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py"
+    ).read_text(encoding="utf-8")
+
+    assert "def _try_registered_temporal_recompute_artifact_window(" not in source_text
+    assert "def _recompute_temporal_bucket_artifact_window(" not in source_text
+    assert "recompute_registered_temporal_artifact_window(" not in source_text
+    assert "def recompute_registered_temporal_artifact_window(" not in registered_executor_text
+    assert '"recompute_registered_temporal_artifact_window"' not in registered_executor_text
+    assert "TemporalReverseArtifactTensorStore" in registered_executor_text
+
+
+def test_temporal_backward_replay_requests_come_from_scheduler_plan() -> None:
+    repo_root = Path(__file__).resolve().parents[1]
+    reverse_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py"
+    ).read_text(encoding="utf-8")
+    forward_text = (repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py").read_text(
+        encoding="utf-8"
+    )
+
+    assert "build_temporal_runtime_scheduler_plan(" in reverse_text
+    assert "replay_request=scheduler_plan.replay_request_for_window(" not in reverse_text
+    assert "_recompute_temporal_bucket_artifact_window(" not in reverse_text
+    assert "nearest_temporal_artifact_checkpoint" not in reverse_text
+    assert "step-object artifact windows are not an active CUDA training path" in reverse_text
+    assert "output_message_physical_steps=tuple(" not in reverse_text
+    assert "replay_request: TemporalReplayArtifactRequest | None" not in forward_text
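+
+
+# Editorial note: the scheduler-plan tests above and below encode a fail-closed rule:
+# the reverse executor may only replay windows the compiler plan already names, never
+# recompute an ad hoc window. The helper below is a hypothetical sketch of that rule
+# (illustration only, not repository code).
+def _sketch_require_planned_window(
+    planned_windows: tuple[tuple[int, int], ...], window: tuple[int, int]
+) -> tuple[int, int]:
+    """Fail closed instead of replaying an artifact window the plan never declared."""
+    if window not in planned_windows:
+        raise RuntimeError(f"window {window} is not in the compiler memory plan {planned_windows}")
+    return window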
+
+
+def test_temporal_scheduler_plan_is_plumbed_through_forward_and_backward() -> None:
+    repo_root = Path(__file__).resolve().parents[1]
+    executor_text = (repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py").read_text(
+        encoding="utf-8"
+    )
+    forward_text = (repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py").read_text(
+        encoding="utf-8"
+    )
+    physical_autograd_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/physical_autograd.py"
+    ).read_text(encoding="utf-8")
+    reverse_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py"
+    ).read_text(encoding="utf-8")
+
+    assert 'temporal_plan=getattr(runtime, "_last_temporal_execution_plan", None)' in executor_text
+    assert "runtime._last_flat_bucket_temporal_scheduler_plan = ()" in executor_text
+    assert "flat_bucket_temporal_scheduler:{item}" in executor_text
+    assert "temporal_plan: Any | None = None" in forward_text
+    assert "build_temporal_runtime_scheduler_plan(" in forward_text
+    assert "output_boundary=cast(" in physical_autograd_text
+    assert "output_boundary=self.output_boundary" in reverse_text
+    assert "grad_output_seq.shape[1]) == 1" not in reverse_text
+
+
+def test_forward_scan_records_temporal_table_metadata_through_compiler_helper() -> None:
+    repo_root = Path(__file__).resolve().parents[1]
+    forward_text = (repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/forward_scan.py").read_text(
+        encoding="utf-8"
+    )
+    metadata_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/runtime_metadata.py"
+    ).read_text(encoding="utf-8")
+
+    assert "record_temporal_primitive_table_runtime_metadata(" in forward_text
+    assert "_last_flat_bucket_temporal_primitive_executor_blockers" in metadata_text
+    assert "_last_flat_bucket_temporal_reverse_executor_summaries" in metadata_text
+    assert "verify_temporal_primitive_table(" in metadata_text
+    assert "_last_flat_bucket_temporal_verifier_status" in metadata_text
+    assert "_last_flat_bucket_temporal_effect_summaries" in metadata_text
+    assert "_last_flat_bucket_temporal_planner_explain" in metadata_text
+    assert "build_temporal_strategy_selection_report(" in metadata_text
+    assert "_last_flat_bucket_temporal_strategy_candidates" in metadata_text
+    assert "_last_flat_bucket_temporal_legal_strategy_candidates" in metadata_text
+    assert "_last_flat_bucket_temporal_blocked_strategy_candidates" in metadata_text
+    assert "build_temporal_primitive_executor_plan(" in metadata_text
+    assert "build_temporal_forward_executor_binding_plan(" in metadata_text
+    assert "build_temporal_reverse_executor_binding_plan(" in metadata_text
+    assert "_last_flat_bucket_temporal_forward_executor_binding_summaries" in metadata_text
+    assert "_last_flat_bucket_temporal_reverse_executor_binding_summaries" in metadata_text
+
+
+def test_temporal_executor_bindings_are_compiler_products_not_compatibility_slots() -> None:
+    repo_root = Path(__file__).resolve().parents[1]
+    binding_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_bindings.py"
+    ).read_text(encoding="utf-8")
+    table_text = (repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/tables.py").read_text(
+        encoding="utf-8"
+    )
+    primitive_registry_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/primitive_registry.py"
+    ).read_text(encoding="utf-8")
+
+    assert "class TemporalExecutorTensorBinding" in binding_text
+    assert "class TemporalExecutorBindingPlan" in binding_text
+    assert "build_temporal_forward_executor_binding_plan(" in binding_text
+    assert "build_temporal_reverse_executor_binding_plan(" in binding_text
+    assert "TemporalTensorBindingRow" in binding_text
+    assert "compatibility" not in binding_text
+    assert "kTf" not in binding_text
+    assert "gated_start" not in binding_text
+    assert "_TEMPORAL_PRIMITIVE_OPCODE_BY_NAME" not in table_text
+    assert "_TEMPORAL_FORWARD_EXECUTOR_OPCODE_BY_NAME" not in
table_text + assert "temporal_primitive_opcode(" in table_text + assert "TemporalPrimitiveDefinition" in primitive_registry_text + assert "register the primitive before lowering it into temporal rows" in primitive_registry_text + + +def test_temporal_memory_plan_drives_checkpoint_and_backward_windows() -> None: + repo_root = Path(__file__).resolve().parents[1] + memory_plan_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/memory_plan.py" + ).read_text(encoding="utf-8") + registered_executor_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/registered_executors.py" + ).read_text(encoding="utf-8") + reverse_executor_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py" + ).read_text(encoding="utf-8") + registered_program_cuda_text = ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py" + ).read_text(encoding="utf-8") + registered_program_binding_text = ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_binding.cpp" + ).read_text(encoding="utf-8") + registered_program_kernel_text = _registered_program_kernel_source_text(repo_root) + temporal_types_text = (repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/types.py").read_text( + encoding="utf-8" + ) + windows_text = (repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/windows.py").read_text( + encoding="utf-8" + ) + program_runtime_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/program_runtime.py" + ).read_text(encoding="utf-8") + suite_common_text = (repo_root / "benchmarks/fabric/suite_common.py").read_text(encoding="utf-8") + + assert "checkpoint_steps: tuple[int, ...]" in memory_plan_text + assert "backward_windows: tuple[tuple[int, int], ...]" in memory_plan_text + assert "memory_plan_fingerprint: tuple[str, ...]" in temporal_types_text + assert "memory_runtime_artifact_fingerprint: tuple[str, ...]" in temporal_types_text + assert "memory_runtime_policy_fingerprint: tuple[str, ...]" in temporal_types_text + assert "memory_runtime_schedule_fingerprint: tuple[str, ...]" in temporal_types_text + assert "memory_runtime_schedule_rows: torch.Tensor | None = None" in temporal_types_text + assert "class TemporalReverseArtifactTensorStore" in temporal_types_text + assert "reverse_artifact_tensor_store: TemporalReverseArtifactTensorStore | None = None" in temporal_types_text + assert "materialize_registered_temporal_artifact_window_from_tensor_store(" not in reverse_executor_text + assert "run_registered_temporal_reverse_executor_tensor_store_window(" in reverse_executor_text + assert "registered_fused_forward_program_tensor_store_direct" in reverse_executor_text + assert "class TemporalRuntimeBufferPlan" in memory_plan_text + assert '"runtime_policy"' in memory_plan_text + assert '"policy_table"' in memory_plan_text + assert '"compiler_memory_policy"' in memory_plan_text + assert '"alias_policy"' in memory_plan_text + assert '"recompute_window_policy"' in memory_plan_text + assert '"materialization_policy"' in memory_plan_text + assert '"cuda_graph_constraint"' in memory_plan_text + assert "build_temporal_runtime_buffer_plan(" in memory_plan_text + assert "class TemporalMemoryRuntimeSchedulePlan" in memory_plan_text + assert "build_temporal_memory_runtime_schedule_plan(" in memory_plan_text + assert 
"temporal_memory_runtime_schedule_rows_tensor(" in memory_plan_text + assert '"local_seed_policy": 1' in memory_plan_text + assert '"metadata_policy": 2' in memory_plan_text + assert '"backward_window": 23' in memory_plan_text + assert "memory_runtime_schedule_plan=compiler_executable" in memory_plan_text + assert "allocate_temporal_runtime_buffer(" in memory_plan_text + assert "allocate_temporal_runtime_buffers(" in memory_plan_text + assert 'allocation: Literal["eager", "deferred_local"] = "eager"' in memory_plan_text + assert "defer_forward_step_buffers: bool = False" in memory_plan_text + assert "defer_local_transition_outputs: bool = False" in memory_plan_text + assert "allocation={self.allocation}" in memory_plan_text + assert "defer_local_transition_outputs=not bool(collect_artifacts) and not bool(materialize_final_state)" in ( + registered_executor_text + ) + assert "defer_forward_step_buffers=not bool(collect_artifacts) and not bool(materialize_final_state)" in ( + registered_executor_text + ) + assert "temporal_runtime_buffer_rows_tensor(" in memory_plan_text + assert "runtime_schedule_fingerprint: tuple[str, ...]" in memory_plan_text + assert "runtime_schedule_rows: torch.Tensor | None = None" in memory_plan_text + assert "_require_runtime_schedule_matches_policy(" in memory_plan_text + assert "def validate_temporal_runtime_buffer_plan(" in memory_plan_text + assert "validate_temporal_runtime_buffer_plan(" in memory_plan_text + assert "require_workspace_coverage=include_workspace_rows" in memory_plan_text + assert "must reference a compiler memory row" in memory_plan_text + assert "does not cover all executable compiler memory rows" in memory_plan_text + assert "temporal_runtime_buffer_spec(" in memory_plan_text + assert "build_temporal_runtime_buffer_plan(" in registered_executor_text + assert "build_temporal_runtime_buffer_plan(" in reverse_executor_text + assert "allocate_temporal_runtime_buffers(" in registered_executor_text + assert "output_seq_shape=(" in registered_executor_text + assert "cells_prev_shape=(" in registered_executor_text + assert "_forward_reverse_artifact_roles_for_runtime(" in registered_executor_text + assert "recurrent_hidden_shape=(" in registered_executor_text + assert "forward_recurrent_msg_shape=(" in registered_executor_text + assert "forward_output_msg_shape=(" in registered_executor_text + assert "forward_output_cells_shape=(" in registered_executor_text + assert "forward_message_step_flat_shape=(" in registered_executor_text + assert "reverse_message_step_flat_shape=(" in registered_executor_text + assert "_requires_reverse_grad_recurrent_msg_runtime_buffer(" in registered_executor_text + assert "reverse_grad_recurrent_msg_shape = (" in registered_executor_text + assert "reverse_grad_recurrent_msg_shape=reverse_grad_recurrent_msg_shape" in registered_executor_text + assert "grad_carry_cells_shape=grad_cells_shape if bool(materialize_grad_carry_cells) else None" in ( + registered_executor_text + ) + assert "reverse_grad_cells_work_shape=grad_cells_shape" in registered_executor_text + assert "materialize_grad_carry_cells" in registered_executor_text + assert '"grad_carry_materialization_policy"' in program_runtime_text + assert "state_requires_grad=any(tensor.requires_grad for tensor in state_tensors)" in ( + (repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/physical_autograd.py").read_text( + encoding="utf-8" + ) + ) + assert "return_window_start_transition_state_grads=bool(self.state_requires_grad) or int(window_start) > 
0" in ( + reverse_executor_text + ) + assert "return_window_start_transition_state_grads=bool(return_window_start_transition_state_grads)" in ( + registered_executor_text + ) + assert "local_step > 0 || return_window_start_transition_state_grads" in registered_program_kernel_text + assert "runtime_buffer_rows=runtime_buffer_rows" in registered_executor_text + assert "runtime_buffer_tensors=runtime_buffer_tensors" in registered_executor_text + assert "memory_runtime_schedule_rows=memory_runtime_schedule_rows" in registered_executor_text + assert "memory_runtime_schedule_rows=cast(torch.Tensor, runtime_buffer_plan.runtime_schedule_rows)" in ( + registered_executor_text + ) + assert "physical_strategy_rows=physical_strategy_rows" in registered_executor_text + assert "physical_strategy_rows: torch.Tensor" in registered_program_cuda_text + assert "readout_message_producer_consumer_rows=executor_program.readout_message_producer_consumer_rows" in ( + registered_executor_text + ) + assert "readout_message_producer_consumer_rows: torch.Tensor" in registered_program_cuda_text + assert "build_temporal_readout_message_producer_consumer_plan(" in registered_executor_text + assert "temporal_readout_message_producer_consumer_rows_tensor(" in registered_executor_text + assert "readout_message_producer_consumer_rows" in registered_program_binding_text + assert "_last_flat_bucket_temporal_readout_message_producer_consumer_rows" in registered_executor_text + assert "message_transition_producer_consumer_rows=executor_program.message_transition_producer_consumer_rows" in ( + registered_executor_text + ) + assert "message_transition_producer_consumer_rows: torch.Tensor" in registered_program_cuda_text + assert "build_temporal_message_transition_producer_consumer_plan(" in registered_executor_text + assert "temporal_message_transition_producer_consumer_rows_tensor(" in registered_executor_text + assert "message_transition_producer_consumer_rows" in registered_program_binding_text + assert "_last_flat_bucket_temporal_message_transition_producer_consumer_rows" in registered_executor_text + assert "build_temporal_physical_strategy_plan(" in registered_executor_text + assert "temporal_physical_strategy_rows_tensor(" in registered_executor_text + assert "physical_strategy_rows" in registered_program_binding_text + assert "runtime_schedule_plan=memory_artifact_plan.runtime_schedule_plan" in registered_executor_text + assert "runtime_schedule_plan=memory_artifact_plan.runtime_schedule_plan" in reverse_executor_text + assert "_last_flat_bucket_temporal_registered_backward_memory_stages" in registered_executor_text + assert "registered_temporal_fused_backward_program_stage_memory_rows" in registered_executor_text + assert "autograd_backward_entry" in ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/physical_autograd.py" + ).read_text(encoding="utf-8") + assert "physical_backward_run_entry" in reverse_executor_text + assert "output_grad_window_materialized" in reverse_executor_text + assert "reverse_artifact_window_table_built" in registered_executor_text + assert "output_grad_cell_window_built" in registered_executor_text + assert "forward_runtime_buffer_plan_built" in registered_executor_text + assert "forward_runtime_buffers_allocated" in registered_executor_text + assert "after_fused_forward_program" in registered_executor_text + assert "forward_reverse_artifact_store_built" in registered_executor_text + assert "native_after_transition" in registered_executor_text + assert 
"native_transition_group_dynamic_bound" in registered_executor_text + assert "native_transition_group_after_reverse_primitive" in registered_executor_text + assert "kRegisteredBackwardMemoryStageAfterTransition" in registered_program_kernel_text + assert "kRegisteredBackwardMemoryStageTransitionGroupDynamicBound" in registered_program_kernel_text + assert "kRegisteredBackwardMemoryStageTransitionGroupAfterReversePrimitive" in registered_program_kernel_text + assert "transition_seed_tensor_or_cached_zeros" in registered_program_kernel_text + assert "cached_zero_seed_tensors" in registered_program_kernel_text + assert "transition_seed_tensor_or_zeros(" not in registered_program_kernel_text + assert "transition_reverse_recurrent_msg_span" in memory_plan_text + assert "transition_reverse_state_before_zero" in memory_plan_text + assert "_transition_reverse_dynamic_runtime_buffer_requests(" in registered_executor_text + assert "transition_reverse_dynamic_buffers=_transition_reverse_dynamic_runtime_buffer_requests(" in ( + registered_executor_text + ) + assert "kRuntimeBufferRoleTransitionReverseRecurrentMsgSpan" in registered_program_kernel_text + assert "kRuntimeBufferRoleTransitionReverseStateBeforeZero" in registered_program_kernel_text + assert "registered transition reverse dynamic recurrent-message span" in registered_program_kernel_text + assert "registered transition reverse dynamic state-before zero" in registered_program_kernel_text + assert "value = at::zeros_like(reference);" not in registered_program_kernel_text + assert "span.receiver_start, span.receiver_start + span.receiver_count).contiguous()" not in ( + registered_program_kernel_text + ) + assert "flat_bucket_temporal_registered_backward_memory_stage" in suite_common_text + assert "allocate_temporal_runtime_buffer(" in reverse_executor_text + assert "validate_registered_runtime_buffer_rows(" in registered_program_kernel_text + assert "registered_runtime_buffer_for_workspace_effect(" in registered_program_kernel_text + assert "registered_runtime_buffer_for_role(" in registered_program_kernel_text + assert "registered_materialize_deferred_local_runtime_buffer(" in registered_program_kernel_text + assert "registered_runtime_buffer_role_allows_deferred_local(" in registered_program_kernel_text + assert "deferred local runtime buffer placeholder is only legal for compiler-routed step-local outputs" in ( + registered_program_kernel_text + ) + assert "registered_runtime_buffer_has_role(" in registered_program_kernel_text + assert "kRuntimeBufferRoleReverseGradCarryCells" in registered_program_kernel_text + assert "missing compiler grad-carry cells buffer before an earlier local step" in registered_program_kernel_text + assert "kRuntimePolicySurfaceOpcode" in registered_program_kernel_text + assert "kMemoryWorkspacePolicyTable" in registered_program_kernel_text + assert "kMemoryOwnerCompilerMemoryPolicy" in registered_program_kernel_text + assert "runtime policy memory_liveness_rows are missing local_seed_policy" in registered_program_kernel_text + assert "runtime policy memory_liveness_rows are missing cuda_graph_constraint" in registered_program_kernel_text + assert "validate_registered_memory_runtime_schedule_rows(" in registered_program_kernel_text + assert "validate_registered_physical_strategy_rows(" in registered_program_kernel_text + assert "registered_active_physical_strategy_opcode(" in registered_program_kernel_text + assert "registered_physical_strategy_active_is_streaming(" in registered_program_kernel_text + assert ( + 
"physical_strategy_rows missing streaming-step producer-consumer strategy row" in registered_program_kernel_text + ) + assert "active physical_strategy_rows entry must be executable" in registered_program_kernel_text + assert "validate_readout_message_producer_consumer_rows(" in registered_program_kernel_text + assert "readout_message_producer_consumer_rows missing active executable readout/message strategy row" in ( + registered_program_kernel_text + ) + assert "run_registered_forward_message_stream_readout_handler(" in registered_program_kernel_text + assert "stream_readout_message" in registered_program_kernel_text + assert "run_fixed_slot_context_stream_readout_message" in ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_native_callables.cuh" + ).read_text(encoding="utf-8") + assert "message native strategy has no stream_readout_message implementation" in (registered_program_kernel_text) + assert "readout_message_producer_consumer_rows missing stream_readout_from_message_projection row" in ( + registered_program_kernel_text + ) + assert "validate_message_transition_producer_consumer_rows(" in registered_program_kernel_text + assert "message_transition_producer_consumer_rows missing active executable message/transition strategy row" in ( + registered_program_kernel_text + ) + assert "message_transition_producer_consumer_rows missing stream_message_to_transition_input row" in ( + registered_program_kernel_text + ) + assert "message_step_state_for_message_transition_row(" in registered_program_kernel_text + assert "aggregate_input = producer_state.recurrent_msg;" not in registered_program_kernel_text + assert "run_registered_forward_message_stream_transition_input_handler(" in registered_program_kernel_text + assert "registered_transition_input_projection_target_for_span(" in registered_program_kernel_text + assert "registered fused forward streamed transition aggregate sentinel" in registered_program_kernel_text + assert "stream_transition_input" in registered_program_kernel_text + assert "registered_message_transition_direct_input_supported_for_message(" in registered_program_kernel_text + assert "kPrimitiveGatedLogspaceRecurrenceOpcode" in registered_program_kernel_text + assert "run_fixed_slot_context_stream_transition_input" in ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_native_callables.cuh" + ).read_text(encoding="utf-8") + assert "registered_message_transition_streaming_row_targets_message(" in registered_program_kernel_text + assert "recurrent_msg_output_override" in registered_program_kernel_text + assert "registered fused forward streaming recurrent message/hidden-after alias" in registered_program_kernel_text + assert "streaming_recurrent_hidden_after_alias" in registered_program_kernel_text + assert "selected direct transition input" in registered_program_kernel_text + assert ( + "for a non-singleton transition span; explicit merge/chunk rows are required" in registered_program_kernel_text + ) + assert "streaming_step_strategy=registered_program_body" in memory_plan_text + assert ( + "registered fused forward streaming-step strategy requires deferred local recurrent-message buffer ownership" + in (registered_program_kernel_text) + ) + assert "registered fused forward streaming-step strategy requires deferred local transition-output ownership" in ( + registered_program_kernel_text + ) + assert "terminal_local_transition_state" in registered_program_kernel_text + 
assert "allow_terminal_local_state_outputs" in registered_program_kernel_text + assert "next_y_is_deferred_local" in registered_program_kernel_text + assert "registered transition gated row-group norm output" in registered_program_kernel_text + assert "if (!terminal_local_transition_state)" in registered_program_kernel_text + assert "native_forward_after_streaming_message_release" in registered_executor_text + assert "kRegisteredForwardMemoryStageAfterStreamingMessageRelease" in registered_program_kernel_text + assert "memory_runtime_schedule_rows missing local_seed_policy" in registered_program_kernel_text + assert "memory_runtime_schedule_rows missing cuda_graph_constraint" in registered_program_kernel_text + assert "policy opcode does not match memory_liveness_rows recompute policy" in registered_program_kernel_text + assert "enforce_registered_cuda_graph_launch_guard(" in registered_program_kernel_text + assert "cudaStreamIsCapturing(" in registered_program_kernel_text + assert "kMemoryRecomputePolicyCudaGraphGuardPolicy" in registered_program_kernel_text + assert "requires compiler-owned cuda_graph_guard_policy before fused CUDA launch" in ( + registered_program_kernel_text + ) + + assert "registered fused forward program output_seq" in registered_program_kernel_text + assert "registered fused forward program cells_prev artifact" in registered_program_kernel_text + assert 'role for role in roles if role != "cells_prev"' in registered_executor_text + assert "registered fused forward program recurrent hidden after" in registered_program_kernel_text + assert "registered fused forward recurrent message" in registered_program_kernel_text + assert "registered fused forward output message" in registered_program_kernel_text + assert "registered fused forward output cells" in registered_program_kernel_text + assert "kRuntimeBufferRoleForwardRecurrentMsg" in registered_program_kernel_text + assert "kRuntimeBufferRoleForwardOutputMsg" in registered_program_kernel_text + assert "kRuntimeBufferRoleForwardOutputCells" in registered_program_kernel_text + assert "kRuntimeBufferRoleReverseGradRecurrentMsg" in registered_program_kernel_text + assert "kRuntimeBufferRoleForwardMessageStepFlat" in registered_program_kernel_text + assert "kRuntimeBufferRoleReverseMessageStepFlat" in registered_program_kernel_text + assert "at::full({B}, message_step" not in registered_program_kernel_text + assert "at::full(\n {grad_output_msg.size(0)}" not in registered_program_kernel_text + assert "at::full(\n {grad_recurrent_msg.size(0)}" not in registered_program_kernel_text + assert "fused backward program grad carry cells" in registered_program_kernel_text + assert "at::Tensor output_seq = at::empty({B, output_steps" not in registered_program_kernel_text + assert "at::zeros({B, full_cell_count, hidden}" not in registered_program_kernel_text + assert "at::empty_like(recurrent_hidden)" not in registered_program_kernel_text + assert "at::zeros_like(cells_prev)" not in registered_program_kernel_text + assert "torch.zeros_like(cells_prev)" not in registered_executor_text + assert "runtime_buffer_rows" in registered_program_cuda_text + assert "_temporal_forward_output_step_shape_for_contract(" in registered_executor_text + assert "if output_seq is None" not in registered_executor_text + assert "output_seq: torch.Tensor | None" not in registered_executor_text + assert "output_steps: list[torch.Tensor]" not in registered_executor_text + assert "grad_boundary_seq = torch.zeros_like(boundary_seq)" not in reverse_executor_text + 
assert "grad_boundary_seq = boundary_seq.new_zeros" not in reverse_executor_text + assert "def _require_compiler_memory_artifact_plan(" in reverse_executor_text + assert "memory_artifact_plan=validated_against_compiler_liveness_fingerprint" in reverse_executor_text + assert "def _planned_checkpoint_steps(" in memory_plan_text + assert "def _planned_backward_windows(" in memory_plan_text + assert "fabric_compiler_named_runtime_artifact_bytes" in suite_common_text + assert "fabric_unclassified_cuda_peak_bytes" in suite_common_text + assert "checkpoint_steps=memory_artifact_plan.checkpoint_steps" in registered_executor_text + assert "backward_windows=memory_artifact_plan.backward_windows" in registered_executor_text + assert ( + "memory_runtime_policy_fingerprint=memory_artifact_plan.runtime_policy.review_summary" + in registered_executor_text + ) + assert ( + "memory_runtime_schedule_fingerprint=memory_artifact_plan.runtime_schedule_plan.fingerprint" + in registered_executor_text + ) + assert "step_index % checkpoint_stride" not in registered_executor_text + assert ( + "memory_runtime_schedule_fingerprint=memory_artifact_plan.runtime_schedule_plan.fingerprint" + in reverse_executor_text + ) + assert "memory_runtime_schedule_rows=memory_runtime_schedule_rows" in registered_executor_text + assert "memory_runtime_schedule_rows=temporal_memory_runtime_schedule_rows_tensor(" in reverse_executor_text + assert "runtime schedule rows mismatch" in reverse_executor_text + assert ( + "Registered temporal backward tried to materialize an artifact window outside the compiler " + in reverse_executor_text + ) + assert "memory schedule: window=" in reverse_executor_text + assert "runtime schedule fingerprint mismatch" in reverse_executor_text + assert "_require_compiler_planned_artifact_windows(" in reverse_executor_text + assert "checkpoint and recompute window derivation must come from the compiler memory plan" in reverse_executor_text + assert "temporal_artifact_windows" not in reverse_executor_text + assert "def temporal_artifact_windows(" not in windows_text + + +def test_frontend_handoff_memory_attribution_enters_compiler_ledger() -> None: + repo_root = Path(__file__).resolve().parents[1] + runtime_text = (repo_root / "src/cortical/fabric/runtime/core.py").read_text(encoding="utf-8") + executor_text = (repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/runtime/executor.py").read_text( + encoding="utf-8" + ) + memory_stage_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/runtime/memory_stages.py" + ).read_text(encoding="utf-8") + suite_common_text = (repo_root / "benchmarks/fabric/suite_common.py").read_text(encoding="utf-8") + + assert "frontend_model_sequence_entry" in runtime_text + assert "frontend_output_cells_entry" in runtime_text + assert "frontend_before_boundary_projection" in runtime_text + assert "frontend_after_boundary_projection" in runtime_text + assert "frontend_before_ensure_state" in runtime_text + assert "frontend_after_ensure_state" in runtime_text + assert "frontend_before_static_tensors" in runtime_text + assert "frontend_after_static_tensors" in runtime_text + assert "frontend_before_registered_execute" in runtime_text + assert "frontend_execute_entry" in executor_text + assert "frontend_execute_after_static_cache" in executor_text + assert "flat_bucket_temporal_frontend_tensor_bytes" in executor_text + assert "flat_bucket_temporal_frontend_tensor_bytes" in suite_common_text + assert "fabric_frontend_tensor_peak_stage_total_bytes" in 
suite_common_text
+    assert "record_registered_memory_stage(" in memory_stage_text
+    assert "record_frontend_tensor_bytes(" in memory_stage_text
+    assert "benchmark" not in memory_stage_text
+    assert "axoncell" not in memory_stage_text
+    assert "slstm" not in memory_stage_text
+
+
+def test_registered_program_allocations_are_compiler_classified() -> None:
+    repo_root = Path(__file__).resolve().parents[1]
+
+    audit = assert_registered_program_allocations_are_classified(repo_root)
+    owners = {site.owner for site in audit.sites}
+    summaries = "\n".join(audit.summaries)
+
+    assert "registered_program_allocation_audit=compiler_owned" in audit.review_summary
+    assert "primitive_output" in owners
+    assert "metadata_row" in owners
+    assert "illegal_scheduler_allocation" not in owners
+    assert "operator_exports.cuh" in summaries
+    assert "transition_primitive_forward_ops.cuh" in summaries
+    assert "transition_reverse_handlers.cuh" in summaries
+    assert "backward_program.cuh" in summaries
+
+
+def test_parameter_reducer_outputs_are_compiler_provided_tensor_table() -> None:
+    repo_root = Path(__file__).resolve().parents[1]
+    reducer_text = (
+        repo_root
+        / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/parameter_reducer_program.cuh"
+    ).read_text(encoding="utf-8")
+    message_reverse_text = (
+        repo_root
+        / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/native_callables/message_reverse_strategies.cuh"
+    ).read_text(encoding="utf-8")
+    operator_declarations_text = (
+        repo_root
+        / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/operator_declarations.cuh"
+    ).read_text(encoding="utf-8")
+    reducer_wrapper_text = (
+        repo_root
+        / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_program_cuda.py"
+    ).read_text(encoding="utf-8")
+    param_binding_text = (
+        repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/temporal/param_binding.py"
+    ).read_text(encoding="utf-8")
+
+    assert "std::vector parameter_output_tensors" in reducer_text
+    assert "validate_registered_parameter_output_tensors(" in reducer_text
+    assert "parameter reducer output for role " in reducer_text
+    assert "parameter_output_tensors:" in reducer_wrapper_text
+    assert "_parameter_reducer_output_tensors(" in param_binding_text
+    assert "parameter_output_tensors=parameter_output_tensors" in param_binding_text
+    assert "output = at::zeros_like" not in reducer_text
+    assert "grad_public = at::zeros_like" not in reducer_text
+    assert "grad_slot_embed = at::zeros_like" not in reducer_text
+    assert "value_only_grad = grad_weight.size(2) == context.value_dim" in reducer_text
+    assert "flat_bucket_registered_backward_sender_value_projection_cuda(" in operator_declarations_text
+    assert "flat_bucket_registered_backward_sender_value_projection_cuda(" in message_reverse_text
+    assert "context_scalar_value_weight_with_zero_key_prefix" not in message_reverse_text
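+
+
+# Editorial note: the test above pins the reducer contract that gradient outputs are
+# written into compiler-provided buffers rather than allocated inside the kernel. The
+# helper below is a hypothetical Python-level sketch of that accumulate-into-provided
+# contract (buffers assumed to behave like torch tensors with in-place add_); it is not
+# repository code.
+def _sketch_accumulate_into_provided(grad_contribution, provided_buffer):
+    """Write into the compiler-provided output buffer; never allocate a fresh one."""
+    provided_buffer.add_(grad_contribution)
+    return provided_buffer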
"src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py" + ).read_text(encoding="utf-8") + message_specs_text = (repo_root / "src/cortical/fabric/backend/message_rule_specs.py").read_text(encoding="utf-8") + message_rules_text = (repo_root / "src/cortical/fabric/backend/message_rules.py").read_text(encoding="utf-8") + registered_native_catalog_text = ( + repo_root + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_native_callables.cuh" + ).read_text(encoding="utf-8") + + for removed_map in ( + "_PARAMETER_REDUCER_NATIVE_CALLABLE", + "_PARAMETER_REDUCER_IMPLEMENTATION_SYMBOL", + "_TRANSITION_TRAINABLE_REDUCER_NATIVE_CALLABLE", + "_TRANSITION_TRAINABLE_REDUCER_IMPLEMENTATION_SYMBOL", + ): + assert removed_map not in native_callables_text + assert "temporal_parameter_reducer_patterns()" in native_callables_text + assert "temporal_transition_trainable_reducer_patterns()" in native_callables_text + assert "TemporalParameterReducerPattern" in reducer_patterns_text + assert "TemporalTransitionTrainableReducerPattern" in reducer_patterns_text + assert "run_registered_fixed_slot_context_message_parameter_reducer_strategy" not in reducer_patterns_text + assert "run_registered_fixed_slot_context_message_parameter_reducer_strategy" in message_specs_text + assert "MessageRuleNativeExecutorEntrypointSpec" in message_rules_text + assert 'cxx_entrypoint_phases=("bind", "recurrent_kv", "message")' in message_specs_text + assert '"recurrent_kv_forward_recompute"' in message_specs_text + assert "recurrent_kv_forward_recompute" in registered_native_catalog_text + assert "bind_fixed_slot_context_message_handler" not in executor_patterns_text + assert "run_fixed_slot_context_message" not in executor_patterns_text + assert "bind_neighborhood_attention_project_message_handler" not in executor_patterns_text + assert "run_neighborhood_attention_project_message" not in executor_patterns_text + assert "bind_fixed_slot_context_message_handler" in message_specs_text + assert "bind_neighborhood_attention_project_message_handler" in message_specs_text + + +def test_temporal_compiler_has_named_verifier_effects_and_typed_rejections() -> None: + repo_root = Path(__file__).resolve().parents[1] + verifier_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/verification.py" + ).read_text(encoding="utf-8") + + assert "TemporalCompilerVerificationReport" in verifier_text + assert "TemporalEffectAnnotation" in verifier_text + assert "verify_temporal_primitive_table(" in verifier_text + assert "temporal_compiler_pass_pipeline(" in verifier_text + assert "filter_strategy_legality" in verifier_text + assert "plan_backward_physical" in verifier_text + assert "plan_memory_liveness_workspace" in verifier_text + assert "UNSUPPORTED_PATTERN" in verifier_text + assert "MISSING_REQUIRED_BINDING" in verifier_text + assert "MISSING_BACKWARD_COVERAGE" in verifier_text + assert "ABI_VERSION_MISMATCH" in verifier_text + assert "fusion_requires_verified_rewrite_before_fixed_abi_can_close" not in verifier_text + + +def test_reverse_executor_rows_are_selected_from_pattern_registry() -> None: + repo_root = Path(__file__).resolve().parents[1] + tables_text = (repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/tables.py").read_text( + encoding="utf-8" + ) + pattern_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py" + ).read_text(encoding="utf-8") + reverse_text = ( + repo_root / 
"src/cortical/fabric/backend/cuda/sequence_surface/temporal/reverse_executor.py" + ).read_text(encoding="utf-8") + backward_plan_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/backward_plan.py" + ).read_text(encoding="utf-8") + executor_binding_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_bindings.py" + ).read_text(encoding="utf-8") + compiler_root = repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler" + + assert "_TEMPORAL_REVERSE_OP_OPCODE_BY_PRIMITIVE" not in tables_text + assert "_TEMPORAL_REVERSE_EXECUTOR_NAME_BY_PRIMITIVE" not in tables_text + assert "def temporal_reverse_primitive_rows_tensor" not in tables_text + assert "temporal_reverse_executor_rows_tensor(" in tables_text + assert "temporal_executor_strategy_registry().match_reverse(" in tables_text + assert "class TemporalReverseExecutorPattern" in pattern_text + assert "_REGISTERED_TEMPORAL_REVERSE_EXECUTOR_STRATEGIES" in pattern_text + assert "_registered_readout_rule_reverse_executor_patterns()" in pattern_text + assert 'surface="readout"' in tables_text + assert "temporal_reverse_primitive_rows_tensor" not in reverse_text + assert "reverse_executor_rows=backward_compatibility_launch_plan.reverse_executor_rows" not in reverse_text + assert "build_temporal_backward_compatibility_launch_plan(" not in reverse_text + assert "try_transition_message_reverse_table_window_cuda(" not in reverse_text + assert "class TemporalBackwardExecutablePlan" in backward_plan_text + assert "build_temporal_backward_executable_plan(" in backward_plan_text + assert "build_temporal_reverse_executor_binding_plan(" in executor_binding_text + assert not (compiler_root / "compatibility.py").exists() + assert not (compiler_root / "backward_compatibility.py").exists() + assert "return try_" not in backward_plan_text + + +def test_temporal_strategy_patterns_declare_legality_cost_runtime_contracts() -> None: + repo_root = Path(__file__).resolve().parents[1] + pattern_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_patterns.py" + ).read_text(encoding="utf-8") + row_group_text = (repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/row_groups.py").read_text( + encoding="utf-8" + ) + + assert "strategy_id:" in pattern_text + assert "strategy_version:" in pattern_text + assert "row_schema_version:" in pattern_text + assert "tensor_binding_schema_version:" in pattern_text + assert "metadata_schema_version:" in pattern_text + assert "cuda_kernel_abi_version:" in pattern_text + assert "legality_predicate:" in pattern_text + assert "cost_model:" in pattern_text + assert "runtime_entrypoint:" in pattern_text + assert "required_effects:" in pattern_text + assert "required_layouts:" in pattern_text + assert "supported_dtypes:" in pattern_text + assert "supported_devices:" in pattern_text + assert "workspace:" in pattern_text + assert "aliasing:" in pattern_text + assert "saved_tensor_policy:" in pattern_text + assert "gradient_accumulation:" in pattern_text + assert "determinism:" in pattern_text + assert "tolerance_class:" in pattern_text + assert "demotion_policy:" in pattern_text + assert "audit_metadata_schema:" in pattern_text + assert "verified_rewrite_required:" in pattern_text + assert "canonical_temporal_row_group(" in pattern_text + assert "pattern_temporal_row_group(" in pattern_text + assert "class TemporalRowGroupSchema" in row_group_text + assert "class TemporalRowSchema" in row_group_text + 
assert "def canonical_temporal_row_group(" in row_group_text + assert "def temporal_effects_for_row(" in row_group_text + + +def test_temporal_strategy_selection_has_separate_legality_and_cost_phase() -> None: + repo_root = Path(__file__).resolve().parents[1] + strategy_text = ( + repo_root / "src/cortical/fabric/backend/cuda/sequence_surface/compiler/strategy_selection.py" + ).read_text(encoding="utf-8") + + assert "class TemporalStrategyCandidate" in strategy_text + assert "class TemporalStrategySelectionReport" in strategy_text + assert "match_status:" in strategy_text + assert "legality_status:" in strategy_text + assert "rejection_code:" in strategy_text + assert "estimated_cost_rank:" in strategy_text + assert "binding_row_count:" in strategy_text + assert "binding_blocker_count:" in strategy_text + assert "build_temporal_strategy_selection_report(" in strategy_text + assert "UNVERIFIED_REWRITE" in strategy_text + assert "strategy_matches_but_requires_verified_rewrite_before_cost_selection" in strategy_text diff --git a/tests/test_fabric_backend_plan.py b/tests/test_fabric_backend_plan.py index 00161e69..bc770bf5 100644 --- a/tests/test_fabric_backend_plan.py +++ b/tests/test_fabric_backend_plan.py @@ -2,166 +2,6583 @@ from dataclasses import replace from pathlib import Path +from types import SimpleNamespace import pytest import torch from cortical.fabric.anatomy import init from cortical.fabric.backend import ( ExecutionFamily, + CellTransitionIR, MathBackend, + ReceiverKind, + ReuseScope, TapeMode, TapePolicy, + TensorSchema, + build_readout_rule_backend_spec, build_cell_backend_spec, + cuda_nn_callable_primitives, + compile_readout_rule, + compile_transition_program, + default_readout_rule_ir, + ReadoutRuleIR, + readout_rule_native_executor, + registered_readout_rule_backend_spec_lowering_kinds, + TransitionOp, ) from cortical.fabric.backend.caps import DeviceCaps -from cortical.fabric.backend.cuda.execution.registry import FabricExecutionRequest -from cortical.fabric.backend.cuda.projections import ( +from cortical.fabric.backend.planner import FabricExecutionPlanner +from cortical.fabric.backend.pytorch.projection import ( project_output_cells_step_raw, project_recurrent_kv_from_preproj_step, project_sender_kv_from_cells_sequence, project_sender_kv_from_cells_step, ) -from cortical.fabric.backend.message_rules import default_dot_product_message_rule_ir +from cortical.fabric.backend.cuda.sequence_surface.compiler.buckets import ( + temporal_bucket_plan, + with_cached_population_static_tensors, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.backward_plan import ( + build_temporal_backward_executable_plan, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.executor_bindings import ( + TemporalExecutorTensorBinding, + build_temporal_forward_executor_binding_plan, + build_temporal_reverse_executor_binding_plan, + build_temporal_transition_param_grad_binding_plan, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.executor_patterns import ( + temporal_executor_strategy_registry, + temporal_forward_executor_pattern_summaries, + temporal_forward_executor_patterns, + temporal_reverse_executor_pattern_summaries, + temporal_reverse_executor_patterns, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.forward_plan import ( + build_temporal_forward_executable_plan, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.forward_program import ( + temporal_forward_program_access_rows_tensor, + 
temporal_forward_transition_state_carry_rows_tensor, + temporal_program_access_opcode, + temporal_reverse_program_access_rows_tensor, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.memory_plan import ( + TemporalRuntimeBufferPlan, + TemporalRuntimeBufferSpec, + TemporalTransitionForwardRuntimeBufferRequest, + allocate_temporal_runtime_buffer, + allocate_temporal_runtime_buffers, + build_temporal_memory_liveness_plan, + build_temporal_memory_runtime_artifact_plan, + build_temporal_memory_runtime_schedule_plan, + build_temporal_physical_strategy_plan, + build_temporal_runtime_buffer_plan, + temporal_memory_liveness_rows_tensor, + temporal_memory_runtime_policy, + temporal_memory_runtime_schedule_rows_tensor, + temporal_physical_strategy_rows_tensor, + temporal_runtime_buffer_role_opcode, + temporal_runtime_buffer_rows_tensor, + temporal_runtime_buffer_spec, + validate_temporal_runtime_buffer_plan, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.native_callables import ( + temporal_native_callable_catalog_fingerprint, + temporal_native_callable_catalog_rows_tensor, + temporal_native_callable_binding_schema_definitions, + temporal_native_callable_binding_schema_fingerprint, + temporal_native_callable_binding_schema_rows_tensor, + temporal_native_callable_binding_schema_summaries, + temporal_native_callable_definitions, + temporal_native_callable_generated_header_text, + temporal_native_callable_output_contract_fingerprint, + temporal_native_callable_output_definitions, + temporal_native_callable_output_rows_tensor, + temporal_native_callable_output_summaries, + temporal_native_callable_transition_forward_output_definition, + temporal_native_callable_summaries, + temporal_transition_reverse_seed_role_id, + temporal_transition_reverse_seed_role_rows_tensor, + validate_temporal_native_callable_generated_header, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.primitive_dispatch import ( + build_temporal_primitive_executor_plan, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.primitive_registry import ( + temporal_primitive_opcode, + temporal_surface_opcode, + temporal_transition_tape_kind, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.reducer_patterns import ( + temporal_parameter_reducer_pattern, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.reverse_artifacts import ( + decode_temporal_reverse_transition_state_artifact_flags, + encode_temporal_reverse_transition_state_artifact_flags, + temporal_reverse_artifact_access_id, + temporal_reverse_artifact_access_role_name, + temporal_reverse_artifact_access_rows_tensor, + temporal_reverse_artifact_role_id, + temporal_reverse_artifact_role_is_tensor, + temporal_reverse_artifact_role_rows_tensor, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.reset_plan import ( + temporal_reverse_reset_kind_id, + temporal_reverse_reset_tensor_table, + temporal_reverse_transition_state_reset_rows_tensor, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.program_execution import ( + build_temporal_fused_cuda_program_plan, + build_temporal_message_transition_producer_consumer_plan, + build_temporal_readout_message_producer_consumer_plan, + build_temporal_registered_program_executor_plan, + build_temporal_reverse_program_stage_plan, + temporal_forward_artifact_merge_rows_tensor, + temporal_forward_artifact_route_rows_tensor, + temporal_forward_output_route_rows_tensor, + temporal_message_transition_producer_consumer_rows_tensor, + 
temporal_native_executor_strategy_rows_tensor, + temporal_readout_message_producer_consumer_rows_tensor, + temporal_forward_executor_handler_rows_tensor, + temporal_reverse_artifact_consumer_route_rows_tensor, + temporal_reverse_executor_handler_rows_tensor, + temporal_reverse_output_route_kind_opcode, + temporal_reverse_output_route_rows_tensor, + temporal_reverse_output_route_target_id, + temporal_reverse_parameter_reducer_route_rows_tensor, + temporal_reverse_span_output_group_opcode, + temporal_reverse_span_output_role_id, + temporal_reverse_span_output_rows_tensor, + temporal_strategy_id_hash, + temporal_transition_primitive_native_callable_rows_tensor, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.program_runtime import ( + build_temporal_forward_program_runtime_plan, + build_temporal_forward_program_runtime_support_plan, + build_temporal_reverse_program_runtime_plan, + build_temporal_reverse_program_runtime_support_plan, + temporal_forward_program_runtime_role_opcode, + temporal_reverse_program_runtime_role_opcode, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.runtime_metadata import ( + record_temporal_primitive_table_runtime_metadata, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.row_groups import ( + canonical_temporal_row_group, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.strategy_selection import ( + build_temporal_strategy_selection_report, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.tables import ( + build_temporal_primitive_table_plan, + temporal_forward_executor_rows, + temporal_forward_executor_rows_tensor, + temporal_forward_executor_summaries, + temporal_primitive_rows_tensor, + temporal_reverse_executor_rows, + temporal_reverse_executor_rows_tensor, + temporal_reverse_executor_summaries, + temporal_tensor_binding_rows_tensor, + temporal_tensor_binding_summaries, + temporal_table_full_tape_extra_state_factors, + temporal_table_transition_recurrent_bucket_kinds, + temporal_table_transition_kind_labels, + validate_temporal_supported_scan_binding_projection, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.registered_executors import ( + build_registered_temporal_executor_program, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.program_parameters import ( + surface_parameter_tensor_table, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.common import ( + temporal_message_output_dim, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.scan_schedule import ( + build_scalar_temporal_scan_schedule, + emitted_output_index_for_scan_step, +) +from cortical.fabric.backend.cuda.sequence_surface.compiler.verification import ( + temporal_compiler_pass_pipeline, + temporal_strategy_rejection_codes, + verify_temporal_primitive_table, +) +from cortical.fabric.backend.cuda.transition_execution.registry import transition_primitive_executor_record +from cortical.fabric.backend.cuda.sequence_surface.temporal.scheduler import ( + build_temporal_runtime_scheduler_plan, +) +from cortical.fabric.backend.cuda.sequence_surface.temporal.reverse_executor import ( + TemporalPhysicalBackwardScanExecutor, + _require_compiler_memory_artifact_plan, + _require_compiler_physical_strategy_plan, + _require_compiler_planned_artifact_windows, +) +from cortical.fabric.backend.cuda.temporal_param_binding import ( + compiled_transition_program_for_bucket, + resolve_transition_parameter, +) +from cortical.fabric.backend.cuda.transition_execution.program import ( + 
TransitionProgramExecutorSelectionError, + select_transition_program_executor, +) +from cortical.fabric.backend.cuda.transition_execution.registry import ( + registered_transition_primitive_executor_records, + registered_transition_executor_records, + transition_primitive_executor_record_for_lowered_primitive, + transition_primitive_program_contract_blocker_code, +) +from cortical.fabric.backend.message_rules import ( + DOT_PRODUCT_FIXED_SLOT_CONTEXT_GATE, + DOT_PRODUCT_FIXED_SLOT_CONTEXT_NUDGE, + DOT_PRODUCT_SEGMENT_SOFTMAX_WEIGHTED_SUM, + MessageRuleIR, + build_message_rule_backend_spec, + build_message_rule_ir, + compile_message_rule, + default_dot_product_message_rule_ir, + registered_message_rule_backend_spec_types, +) from cortical.fabric.backend.pytorch.readout import ( ReadoutConfig, pool_output_ports, readout_output_cells, select_output_cells, ) -from cortical.fabric.config import CellPopulationConfig, Config +from cortical.fabric.config import ( + CellPopulationConfig, + Config, + FabricInterfaceConfig, + InitializationConfig, + MessageConfig, + PopulationLayoutConfig, +) +from cortical.fabric.graphs import lattice2d from cortical.fabric.registry.cell_backends import get_cell_backend_implementation from cortical.fabric.registry.cells import get_cell_spec from cortical.fabric.runtime import build -def _make_slstm_spec( +def _make_slstm_spec( + *, + hidden_size: int = 8, + head_dim: int | None = None, + max_delay: int | None = None, + patch_edges_per_cell: int = 0, +): + connectivity = [lattice2d.LocalRadius(1.5)] + if patch_edges_per_cell > 0: + connectivity.append(lattice2d.PatchEdges(per_cell=patch_edges_per_cell, min_distance=3.0, max_distance=4.0)) + return init( + Config( + graph=lattice2d.Graph(width=8, height=8, connectivity=tuple(connectivity)), + interface=FabricInterfaceConfig(hidden_size=hidden_size), + message=MessageConfig(head_dim=4 if head_dim is None else int(head_dim), projection_region_shape=(2, 2)), + populations=PopulationLayoutConfig( + cell_populations={"slstm": CellPopulationConfig(cell_type="slstm")}, + population_mix={"slstm": 1.0}, + ), + initialization=InitializationConfig(seed=11), + ) + ) + + +def _with_context_nudge_message_rule(spec): + return replace( + spec, + message_rule=build_message_rule_ir( + rule_type="dot_product_fixed_slot_context_nudge", + kv_group_count=int(spec.num_kv_groups), + cell_count=int(spec.anatomy.num_cells), + name="declared_context_nudge_dot_product", + ), + ) + + +def _make_mixed_spec(): + return init( + Config( + graph=lattice2d.Graph( + width=8, + height=8, + connectivity=(lattice2d.LocalRadius(1.5), lattice2d.PatchEdges(per_cell=2)), + ), + interface=FabricInterfaceConfig(hidden_size=8), + message=MessageConfig(projection_region_shape=(2, 2)), + populations=PopulationLayoutConfig( + cell_populations={ + "slstm": CellPopulationConfig(cell_type="slstm"), + "axoncell": CellPopulationConfig(cell_type="axoncell"), + }, + population_mix={"slstm": 0.5, "axoncell": 0.5}, + ), + initialization=InitializationConfig(seed=13), + ) + ) + + +def _stateful_tanh_transition_spec(*, hidden_size: int = 8): + slstm_spec = build_cell_backend_spec( + cell_type="slstm", + hidden_size=hidden_size, + d_public=hidden_size, + d_msg=hidden_size, + head_dim=hidden_size, + value_dim=hidden_size, + ) + return replace( + slstm_spec, + state_tensors=("mem",), + private_state_schema=( + TensorSchema( + "mem", + "private_state", + ("receiver", "hidden"), + ReuseScope.OUTER_TIME_VARYING, + ), + ), + parameter_schema=(), + transition_parameter_bindings={}, + 
reuse_scopes={"mem": ReuseScope.OUTER_TIME_VARYING}, + transition_ir=CellTransitionIR( + state_inputs=("mem",), + message_inputs=("aggregated_message",), + parameter_inputs=(), + ops=( + TransitionOp("tanh", ("mem",), ("next_mem",)), + TransitionOp("tanh", ("aggregated_message",), ("public_y",)), + ), + state_outputs=("next_mem",), + public_outputs=("public_y",), + recompute_outputs=("next_mem", "public_y"), + backward_decomposition=(), + ), + ) + + +def _make_alias_population_spec(): + return init( + Config( + graph=lattice2d.Graph( + width=8, + height=8, + connectivity=(lattice2d.LocalRadius(1.5), lattice2d.PatchEdges(per_cell=2)), + ), + interface=FabricInterfaceConfig(hidden_size=8), + message=MessageConfig(projection_region_shape=(2, 2)), + populations=PopulationLayoutConfig( + cell_populations={ + "left": CellPopulationConfig(cell_type="slstm"), + "right": CellPopulationConfig(cell_type="axoncell"), + }, + population_mix={"left": 0.5, "right": 0.5}, + ), + initialization=InitializationConfig(seed=23), + ) + ) + + +def _make_three_population_spec(): + return init( + Config( + graph=lattice2d.Graph( + width=8, + height=8, + connectivity=(lattice2d.LocalRadius(1.5), lattice2d.PatchEdges(per_cell=2)), + ), + interface=FabricInterfaceConfig(hidden_size=8), + message=MessageConfig(projection_region_shape=(2, 2)), + populations=PopulationLayoutConfig( + cell_populations={ + "left": CellPopulationConfig(cell_type="slstm"), + "middle": CellPopulationConfig(cell_type="axoncell"), + "right": CellPopulationConfig(cell_type="slstm"), + }, + population_mix={"left": 0.34, "middle": 0.33, "right": 0.33}, + ), + initialization=InitializationConfig(seed=41), + ) + ) + + +def _make_axon_spec(*, hidden_size: int = 8): + return init( + Config( + graph=lattice2d.Graph(width=8, height=8), + interface=FabricInterfaceConfig(hidden_size=hidden_size), + message=MessageConfig(projection_region_shape=(2, 2)), + populations=PopulationLayoutConfig( + cell_populations={"axoncell": CellPopulationConfig(cell_type="axoncell")}, + population_mix={"axoncell": 1.0}, + ), + initialization=InitializationConfig(seed=17), + ) + ) + + +def _test_device_caps() -> DeviceCaps: + return DeviceCaps( + device_type="cuda", + device_index=0, + name="test", + capability=(9, 0), + multi_processor_count=120, + max_shared_memory_per_block=99_328, + supports_cuda_graphs=True, + supports_tensor_cores=True, + supports_tma=True, + supports_clusters=True, + ) + + +def _make_explicit_graph_spec(cell_type: str): + return init( + Config( + graph=lattice2d.Graph( + width=4, + height=4, + inputs=(0, 1), + outputs=(14, 15), + connectivity=( + lattice2d.ExplicitEdges( + edges=((2, 3), (3, 2), (4, 2), (14, 13), (15, 13)), + kv_group_ids=tuple(idx // 2 for idx in range(16)), + ), + ), + ), + interface=FabricInterfaceConfig(hidden_size=8), + populations=PopulationLayoutConfig( + cell_populations={cell_type: CellPopulationConfig(cell_type=cell_type)}, + population_mix={cell_type: 1.0}, + ), + initialization=InitializationConfig(seed=19), + ) + ) + + +def test_fabric_backend_ir_compiles_receiver_sets_and_buckets() -> None: + runtime = build(_make_mixed_spec()) + ir = runtime.backend_ir + + assert ir.num_cells == runtime.coords.shape[0] + assert ir.num_input_ports == runtime.input_cell_idx.numel() + assert ir.num_recurrent_cells == runtime.recurrent_cell_idx.numel() + assert ir.num_output_ports == runtime.output_cell_idx.numel() + assert ir.bucket_count > 0 + assert any(bucket.receiver_kind.value == "recurrent_cell" for bucket in ir.buckets) + assert 
any(bucket.receiver_kind.value == "output_port" for bucket in ir.buckets) + assert any(bucket.has_sparse_overlay for bucket in ir.buckets) + assert ir.graph_summary.node_count == ir.num_cells + assert ir.graph_summary.input_count == ir.num_input_ports + assert ir.graph_summary.output_count == ir.num_output_ports + assert ir.graph_summary.recurrent_count == ir.num_recurrent_cells + assert ir.graph_summary.flat_signature.node_count == ir.num_cells + assert ir.graph_summary.flat_signature.degree_histogram == ir.graph_summary.degree_histogram + assert len(ir.transition_programs) == len(ir.population_names) + assert all(program.primitive_ops for program in ir.transition_programs) + assert ir.readout_program.primitive_names == ("readout_project", "reduction_boundary") + assert all(bucket.transition_signature for bucket in ir.buckets) + assert all(bucket.parameter_binding for bucket in ir.buckets) + + +def test_fabric_bucket_identity_is_not_keyed_by_population_name() -> None: + runtime = build(_make_alias_population_spec()) + recurrent_buckets = [ + bucket for bucket in runtime.backend_ir.buckets if bucket.receiver_kind == ReceiverKind.RECURRENT_CELL + ] + + assert recurrent_buckets + flat_signatures = { + str(item) for bucket in recurrent_buckets for item in bucket.signature if not isinstance(item, tuple) + } + nested_signatures = { + str(item) + for bucket in recurrent_buckets + for group in bucket.signature + if isinstance(group, tuple) + for item in group + } + identity_text = "\n".join(sorted(flat_signatures | nested_signatures)) + assert "left" not in identity_text + assert "right" not in identity_text + assert "cell_type=" not in identity_text + assert "cell_kind=" not in identity_text + assert "transition_ops=" in identity_text + assert "gated_logspace_recurrence" in identity_text + assert "diag_rtu" in identity_text + assert "population_slot:0" in identity_text + assert "population_slot:1" in identity_text + + planned = runtime.plan_temporal_execution( + batch_size=2, + time_steps=4, + k=1, + training=False, + device=torch.device("cuda"), + ) + + assert planned.substrate.bucket_identity == "flat_bucket_identity" + + +def test_temporal_bucket_plan_exposes_flat_bucket_identity() -> None: + runtime = build(_make_alias_population_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + + plan = temporal_bucket_plan(runtime, cached_static_tensors) + + assert plan.flat_buckets + assert plan.population_buckets == plan.flat_buckets + identity_text = "\n".join(sorted(item for identity in plan.flat_bucket_identities for item in identity)) + assert "left" not in identity_text + assert "right" not in identity_text + assert "flat_bucket_identity" in identity_text + assert "cell_type=" not in identity_text + assert "cell_kind=" not in identity_text + assert "transition_ops=" in identity_text + assert "transition_state_inputs=" in identity_text + assert "transition_parameter_bindings=" in identity_text + assert "binding_slot=0" not in identity_text + assert "binding_slot=1" not in identity_text + binding_identity_text = "\n".join(sorted(item for bucket in plan.flat_buckets for item in bucket.binding_identity)) + assert "binding_slot=0" in binding_identity_text + assert "binding_slot=1" in binding_identity_text + assert any(bucket.binding_name == "left" for bucket in plan.flat_buckets) + assert 
any(bucket.binding_name == "right" for bucket in plan.flat_buckets) + + +def test_temporal_bucket_plan_accepts_more_than_two_binding_populations() -> None: + runtime = build(_make_three_population_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + + plan = temporal_bucket_plan(runtime, cached_static_tensors) + + assert runtime._population_names == ("left", "middle", "right") + assert tuple(bucket.binding_name for bucket in plan.flat_buckets) == ("left", "middle", "right") + identity_text = "\n".join(sorted(item for identity in plan.flat_bucket_identities for item in identity)) + assert "left" not in identity_text + assert "middle" not in identity_text + assert "right" not in identity_text + assert "cell_type=" not in identity_text + assert "cell_kind=" not in identity_text + assert "binding_slot=0" not in identity_text + assert "binding_slot=1" not in identity_text + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + assert temporal_table_transition_recurrent_bucket_kinds(table) == { + 0: "gated_logspace_recurrence", + 1: "diag_rtu", + 2: "gated_logspace_recurrence", + } + assert build_temporal_forward_executable_plan(table).strategy_legality_status == "legal" + assert build_temporal_backward_executable_plan(table).strategy_legality_status == "legal" + assert "binding_slot=2" not in identity_text + binding_identity_text = "\n".join(sorted(item for bucket in plan.flat_buckets for item in bucket.binding_identity)) + assert "binding_slot=0" in binding_identity_text + assert "binding_slot=1" in binding_identity_text + assert "binding_slot=2" in binding_identity_text + + +def test_temporal_primitive_table_plan_uses_flat_bucket_rows_not_population_names() -> None: + runtime = build(_make_alias_population_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + + assert table.bucket_count == 2 + assert table.primitive_rows + assert temporal_table_transition_recurrent_bucket_kinds(table) == { + 0: "gated_logspace_recurrence", + 1: "diag_rtu", + } + assert "attention_logits" in table.primitive_names + assert "segment_softmax" in table.primitive_names + assert "weighted_sum" in table.primitive_names + assert "readout_project" in table.primitive_names + assert "reduction_boundary" in table.primitive_names + assert "gated_logspace_recurrence" in table.primitive_names + assert "diag_rtu" in table.primitive_names + assert "state_epilogue" in table.primitive_families + assert temporal_table_transition_kind_labels(table) == frozenset({"gated_logspace", "diagonal_recurrence"}) + assert temporal_table_full_tape_extra_state_factors(table) == {0: 8, 1: 1} + primitive_rows = temporal_primitive_rows_tensor(table) + assert primitive_rows.dtype == torch.long + assert tuple(primitive_rows.shape) == (len(table.primitive_rows), 4) + tensor_binding_rows = temporal_tensor_binding_rows_tensor(table) + assert tensor_binding_rows.dtype == torch.long + assert tuple(tensor_binding_rows.shape) == (len(table.tensor_bindings), 7) + assert int(primitive_rows[:, 0].eq(1).sum().item()) == 1 + assert int(primitive_rows[:, 
0].eq(2).sum().item()) == 1 + assert set(int(item) for item in primitive_rows[:, 3].tolist()).issuperset({-3, -2, -1, 0, 1}) + forward_executable_plan = build_temporal_forward_executable_plan(table) + assert torch.equal(forward_executable_plan.forward_executor_rows, temporal_forward_executor_rows_tensor(table)) + assert forward_executable_plan.executor_binding_rows.dtype == torch.long + assert tuple(forward_executable_plan.executor_binding_rows.shape[1:]) == (8,) + assert forward_executable_plan.strategy_ids == ( + "forward.message.fixed_slot_context_nudge.v1", + "forward.readout.mean_projection_reduction_boundary.v1", + "forward.transition.gated_logspace.v1", + "forward.transition.diag_rtu.v1", + ) + assert forward_executable_plan.strategy_legality_status == "legal" + forward_review = "\n".join(forward_executable_plan.review_summary) + assert "forward_executable_plan=compiler_owned" in forward_review + assert "runtime_entrypoint=registered_temporal_fused_forward_program_cuda" in forward_review + assert "executor_binding_row_count=" in forward_review + assert "strategy_legality_blocker=UNVERIFIED_REWRITE" not in forward_review + reverse_rows = temporal_reverse_executor_rows_tensor(table) + assert reverse_rows.dtype == torch.long + assert tuple(reverse_rows.shape) == (len(temporal_reverse_executor_rows(table)), 6) + assert set(int(item) for item in reverse_rows[:, 0].tolist()) == {2, 3, 4, 5} + reverse_executor_rows = temporal_reverse_executor_rows(table) + assert [row.executor_id for row in reverse_executor_rows] == [5, 4, 2, 3] + assert [row.executor_name for row in reverse_executor_rows] == [ + "fixed_slot_context_nudge_message_backward", + "mean_projection_reduction_boundary_backward", + "gated_logspace_transition_backward", + "diag_rtu_transition_backward", + ] + reverse_summary_text = "\n".join(temporal_reverse_executor_summaries(table)) + assert "surface=message" in reverse_summary_text + assert "surface=readout" in reverse_summary_text + assert "surface=transition" in reverse_summary_text + assert ( + "params=message_query_slot_weight,recurrent_sender_value_weight,message_query_nudge_scale," + "message_sender_slot_key_weight,message_sender_context_key,input_sender_value_weight," + "input_group_value_weight,message_output_weight" in reverse_summary_text + ) + backward_executable_plan = build_temporal_backward_executable_plan(table) + assert torch.equal(backward_executable_plan.reverse_executor_rows, reverse_rows) + assert backward_executable_plan.executor_binding_rows.dtype == torch.long + assert tuple(backward_executable_plan.executor_binding_rows.shape[1:]) == (8,) + assert backward_executable_plan.strategy_ids == ( + "reverse.message.fixed_slot_context_nudge.v1", + "reverse.readout.mean_projection_reduction_boundary.v1", + "reverse.transition.gated_logspace.v1", + "reverse.transition.diag_rtu.v1", + ) + backward_review = "\n".join(backward_executable_plan.review_summary) + assert "backward_executable_plan=compiler_owned" in backward_review + assert "runtime_entrypoint=registered_reverse_executor_bindings" in backward_review + assert "executor_binding_row_count=" in backward_review + memory_plan = build_temporal_memory_liveness_plan(table) + memory_liveness_rows = temporal_memory_liveness_rows_tensor(memory_plan) + forward_handler_rows = temporal_forward_executor_handler_rows_tensor(table) + reverse_handler_rows = temporal_reverse_executor_handler_rows_tensor(table) + assert tuple(forward_handler_rows.shape)[1] == 11 + assert tuple(reverse_handler_rows.shape)[1] == 11 + assert 
all(int(item) > 0 for item in forward_handler_rows[:, 7].tolist()) + assert all(int(item) > 0 for item in reverse_handler_rows[:, 7].tolist()) + fused_program_plan = build_temporal_fused_cuda_program_plan( + primitive_rows=temporal_primitive_rows_tensor(table), + forward_plan=forward_executable_plan, + backward_plan=backward_executable_plan, + memory_plan=memory_plan, + memory_liveness_rows=memory_liveness_rows, + forward_handler_rows=forward_handler_rows, + reverse_handler_rows=reverse_handler_rows, + forward_artifact_route_rows=temporal_forward_artifact_route_rows_tensor(table), + forward_artifact_merge_rows=temporal_forward_artifact_merge_rows_tensor(table), + forward_output_route_rows=temporal_forward_output_route_rows_tensor(table), + readout_message_producer_consumer_rows=temporal_readout_message_producer_consumer_rows_tensor( + build_temporal_readout_message_producer_consumer_plan(table) + ), + message_transition_producer_consumer_rows=temporal_message_transition_producer_consumer_rows_tensor( + build_temporal_message_transition_producer_consumer_plan(table) + ), + reverse_artifact_consumer_route_rows=temporal_reverse_artifact_consumer_route_rows_tensor(table), + reverse_parameter_reducer_route_rows=temporal_reverse_parameter_reducer_route_rows_tensor(table), + ) + assert fused_program_plan.status == "legal" + assert fused_program_plan.blocker_code == "" + assert "transition_primitives=linear,matmul,gated_logspace_recurrence,norm_or_identity,diag_rtu" in ( + fused_program_plan.blocker_reason + ) + assert "registered_fused_program_has_sequence_forward_span_dispatch_body" in fused_program_plan.blocker_reason + assert "registered_fused_program_has_reverse_span_dispatch_body" in fused_program_plan.blocker_reason + assert "program_transition_gated_logspace_recurrence_forward" not in fused_program_plan.blocker_reason + assert "program_transition_norm_or_identity_forward" not in fused_program_plan.blocker_reason + assert "program_transition_diag_rtu_forward" not in fused_program_plan.blocker_reason + assert "program_transition_diag_rtu_backward" not in fused_program_plan.blocker_reason + fused_review = "\n".join(fused_program_plan.review_summary) + assert "fused_cuda_program_plan=compiler_owned" in fused_review + assert "fused_cuda_launch_contract=compiler_owned" in fused_review + assert "required_tables=primitive_rows,forward_executor_rows,reverse_executor_rows" in fused_review + assert "forward_handler_rows,reverse_handler_rows,native_strategy_rows" in fused_review + assert "native_callable_catalog_rows,native_callable_binding_schema_rows,native_callable_output_rows" in ( + fused_review + ) + assert "transition_reverse_seed_role_rows" in fused_review + assert "reverse_output_route_rows" in fused_review + assert "forward_artifact_route_rows" in fused_review + assert "forward_artifact_merge_rows" in fused_review + assert "forward_output_route_rows" in fused_review + assert "reverse_parameter_reducer_route_rows" in fused_review + assert "forward_executor_binding_rows,reverse_executor_binding_rows,memory_liveness_plan" in fused_review + assert "memory_liveness_rows" in fused_review + assert "physical_strategy_rows" in fused_review + assert "forward_program_runtime_rows" in fused_review + assert "reverse_program_runtime_rows" in fused_review + assert "forward_handler_row_count=" in fused_review + assert "reverse_handler_row_count=" in fused_review + assert "native_strategy_row_count=" in fused_review + assert "demotion_policy=fail_closed_no_unregistered_program_demotion" in fused_review + 
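+    # Descriptive note (added comment, not part of the original diff): the fused-program review must
+    # also spell out its fail-closed demotion policy and typed strategy/binding rejection policy, and
+    # must not leak binding_slot identity text into the summary.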
assert "unsupported_policy=typed_strategy_and_binding_rejection" in fused_review + assert "binding_slot" not in fused_review.lower() + assert fused_program_plan.launch_contract.primitive_row_count == int(temporal_primitive_rows_tensor(table).shape[0]) + assert fused_program_plan.launch_contract.forward_executor_row_count == int( + forward_executable_plan.forward_executor_rows.shape[0] + ) + assert fused_program_plan.launch_contract.reverse_executor_row_count == int( + backward_executable_plan.reverse_executor_rows.shape[0] + ) + assert fused_program_plan.launch_contract.forward_handler_row_count == int(forward_handler_rows.shape[0]) + assert fused_program_plan.launch_contract.reverse_handler_row_count == int(reverse_handler_rows.shape[0]) + assert fused_program_plan.launch_contract.native_strategy_row_count == int( + temporal_native_executor_strategy_rows_tensor().shape[0] + ) + assert fused_program_plan.launch_contract.native_callable_binding_schema_row_count == int( + temporal_native_callable_binding_schema_rows_tensor().shape[0] + ) + assert fused_program_plan.launch_contract.native_callable_output_row_count == int( + temporal_native_callable_output_rows_tensor().shape[0] + ) + assert fused_program_plan.launch_contract.transition_reverse_seed_role_row_count == int( + temporal_transition_reverse_seed_role_rows_tensor().shape[0] + ) + assert fused_program_plan.launch_contract.reverse_output_route_row_count == int( + temporal_reverse_output_route_rows_tensor().shape[0] + ) + assert fused_program_plan.launch_contract.forward_artifact_route_row_count == int( + temporal_forward_artifact_route_rows_tensor(table).shape[0] + ) + assert fused_program_plan.launch_contract.reverse_parameter_reducer_route_row_count == int( + temporal_reverse_parameter_reducer_route_rows_tensor(table).shape[0] + ) + assert fused_program_plan.launch_contract.memory_entry_count == len(memory_plan.entries) + assert fused_program_plan.launch_contract.memory_liveness_row_count == int(memory_liveness_rows.shape[0]) + program_executor_plan = build_temporal_registered_program_executor_plan(fused_program_plan) + program_executor_review = "\n".join(program_executor_plan.review_summary) + assert program_executor_plan.forward_entrypoint == "registered_temporal_fused_forward_program_cuda" + assert program_executor_plan.demotion_policy == "fail_closed_registered_fused_program_only" + assert "registered_program_executor_plan=compiler_owned" in program_executor_review + assert "fused_cuda_status=legal" in program_executor_review + assert "fused_cuda_blocker_code=-" in program_executor_review + assert table.review_disallowed_terms(("left", "right", "slstm", "axoncell", "single", "mixed")) == () + + +def test_transition_reverse_seed_roles_are_compiler_owned_rows() -> None: + rows = temporal_transition_reverse_seed_role_rows_tensor() + + assert tuple(rows.shape) == (7, 4) + assert int(rows[0, 0].item()) == temporal_transition_reverse_seed_role_id("grad_public_y") + assert {int(row[0]) for row in rows.tolist()} == { + temporal_transition_reverse_seed_role_id("grad_public_y"), + temporal_transition_reverse_seed_role_id("grad_next_y"), + temporal_transition_reverse_seed_role_id("grad_next_c"), + temporal_transition_reverse_seed_role_id("grad_next_n"), + temporal_transition_reverse_seed_role_id("grad_next_m"), + temporal_transition_reverse_seed_role_id("grad_next_hc1"), + temporal_transition_reverse_seed_role_id("grad_next_hc2"), + } + + +def test_reverse_span_outputs_are_compiler_owned_rows() -> None: + rows = 
temporal_reverse_span_output_rows_tensor() + + assert rows.dtype == torch.long + assert tuple(rows.shape[1:]) == (6,) + front_rows = [row for row in rows.tolist() if int(row[1]) == temporal_reverse_span_output_group_opcode("front")] + boundary_rows = [ + row for row in rows.tolist() if int(row[1]) == temporal_reverse_span_output_group_opcode("boundary") + ] + assert {int(row[2]) for row in front_rows} >= { + temporal_reverse_span_output_role_id("grad_boundary_direct"), + temporal_reverse_span_output_role_id("grad_recurrent_kv_weight_graph_order"), + } + front_required_by_role = {int(row[2]): int(row[4]) for row in front_rows} + for local_only_role in ( + "grad_recurrent_hidden_backend_direct", + "grad_input_k_from_output", + "grad_input_v_from_output", + "grad_recurrent_hidden_from_kv_graph_order", + ): + assert front_required_by_role[temporal_reverse_span_output_role_id(local_only_role)] == 0 + for returned_role in ( + "grad_boundary_direct", + "grad_value_to_output_weight", + "grad_output_cell_bias", + "grad_output_q", + "grad_recurrent_kv_weight_graph_order", + ): + assert front_required_by_role[temporal_reverse_span_output_role_id(returned_role)] == 1 + assert {int(row[2]) for row in boundary_rows} >= { + temporal_reverse_span_output_role_id("grad_recurrent_q_backend"), + temporal_reverse_span_output_role_id("grad_initial_recurrent_kv_weight_graph_order"), + temporal_reverse_span_output_role_id("grad_query_context_scalar"), + temporal_reverse_span_output_role_id("grad_output_weight"), + temporal_reverse_span_output_role_id("grad_input_key_bank"), + temporal_reverse_span_output_role_id("grad_recurrent_key_bank"), + } + + +def test_reverse_output_routes_are_compiler_owned_rows() -> None: + rows = temporal_reverse_output_route_rows_tensor() + + assert rows.dtype == torch.long + assert tuple(rows.shape[1:]) == (8,) + route_pairs = {(int(row[1]), int(row[2]), int(row[3])) for row in rows.tolist()} + assert ( + temporal_reverse_output_route_kind_opcode("readout_parameter_grad"), + temporal_reverse_output_route_target_id("value_to_output_weight"), + temporal_reverse_span_output_group_opcode("front"), + ) in route_pairs + assert ( + temporal_reverse_output_route_kind_opcode("sender_kv_parameter_grad"), + temporal_reverse_output_route_target_id("boundary_input_kv_weight"), + temporal_reverse_span_output_group_opcode("boundary"), + ) in route_pairs + assert ( + temporal_reverse_output_route_kind_opcode("transition_boundary"), + temporal_reverse_output_route_target_id("recurrent_query"), + temporal_reverse_span_output_group_opcode("boundary"), + ) in route_pairs + assert ( + temporal_reverse_output_route_kind_opcode("message_strategy_parameter_grad"), + temporal_reverse_output_route_target_id("grad_query_context_scalar"), + temporal_reverse_span_output_group_opcode("boundary"), + ) in route_pairs + + +def test_forward_artifact_routes_are_compiler_owned_rows() -> None: + runtime = build(_make_slstm_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + table = build_temporal_primitive_table_plan(runtime, with_cached_population_static_tensors(runtime, static_tensors)) + rows = temporal_forward_artifact_route_rows_tensor(table) + + assert rows.dtype == torch.long + assert tuple(rows.shape[1:]) == (10,) + role_ids = {int(row[5]) for row in rows.tolist()} + required_by_role = {int(row[5]): int(row[7]) for row in rows.tolist()} + assert role_ids >= { + 
temporal_reverse_artifact_role_id("boundary_step"), + temporal_reverse_artifact_role_id("recurrent_msg_backend_order"), + temporal_reverse_artifact_role_id("output_cells"), + temporal_reverse_artifact_role_id("transition_state_before"), + } + assert required_by_role[temporal_reverse_artifact_role_id("recurrent_k")] == 0 + assert required_by_role[temporal_reverse_artifact_role_id("recurrent_v")] == 0 + assert required_by_role[temporal_reverse_artifact_role_id("recurrent_k_before")] == 0 + assert required_by_role[temporal_reverse_artifact_role_id("recurrent_v_before")] == 0 + producer_rows = {(int(row[2]), int(row[3]), int(row[4])) for row in rows.tolist()} + assert any(executor_row >= 0 and executor_id > 0 for executor_row, executor_id, _bucket in producer_rows) + + +def test_forward_artifact_merge_rows_are_compiler_owned_rows() -> None: + runtime = build(_make_slstm_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + table = build_temporal_primitive_table_plan(runtime, with_cached_population_static_tensors(runtime, static_tensors)) + route_rows = temporal_forward_artifact_route_rows_tensor(table) + merge_rows = temporal_forward_artifact_merge_rows_tensor(table) + + assert merge_rows.dtype == torch.long + assert tuple(merge_rows.shape[1:]) == (12,) + assert int(merge_rows.shape[0]) > 0 + identity_rows = [row for row in merge_rows.tolist() if int(row[4]) == 1] + assert identity_rows + for row in identity_rows: + producer_route_row = int(row[6]) + assert 0 <= producer_route_row < int(route_rows.shape[0]) + route_row = route_rows[producer_route_row].tolist() + assert int(route_row[1]) == int(row[1]) + assert int(route_row[4]) == int(row[2]) + assert int(route_row[5]) == int(row[3]) + role_ids = {int(row[3]) for row in merge_rows.tolist()} + assert role_ids >= { + temporal_reverse_artifact_role_id("boundary_step"), + temporal_reverse_artifact_role_id("recurrent_msg_backend_order"), + temporal_reverse_artifact_role_id("output_cells"), + temporal_reverse_artifact_role_id("transition_state_before"), + } + + +def test_forward_output_routes_are_compiler_owned_rows() -> None: + runtime = build(_make_slstm_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + table = build_temporal_primitive_table_plan(runtime, with_cached_population_static_tensors(runtime, static_tensors)) + rows = temporal_forward_output_route_rows_tensor(table) + + assert rows.dtype == torch.long + assert tuple(rows.shape) == (1, 10) + assert int(rows[0, 1]) == 1 + assert int(rows[0, 2]) > 0 + assert int(rows[0, 3]) >= 0 + assert int(rows[0, 4]) > 0 + assert int(rows[0, 6]) == temporal_strategy_id_hash("output_cells") + assert int(rows[0, 9]) == 0 + + +def test_readout_message_producer_consumer_rows_are_compiler_owned_legality_rows() -> None: + runtime = build(_make_slstm_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + table = build_temporal_primitive_table_plan(runtime, with_cached_population_static_tensors(runtime, static_tensors)) + plan = build_temporal_readout_message_producer_consumer_plan(table) + rows = temporal_readout_message_producer_consumer_rows_tensor(plan) + + assert rows.dtype == torch.long + assert tuple(rows.shape) == (2, 16) + assert plan.selected_strategy == 
"materialized_recurrent_kv_after" + assert int(rows[0, 2]) == 1 + assert int(rows[0, 3]) == 1 + assert int(rows[0, 4]) == 1 + assert int(rows[0, 5]) > 0 + assert int(rows[0, 9]) > 0 + assert int(rows[0, 13]) == 0 + assert int(rows[0, 14]) & 4 + assert int(rows[0, 14]) & 8 + assert int(rows[0, 15]) == 0 + assert int(rows[1, 2]) == 2 + assert int(rows[1, 3]) == 3 + assert int(rows[1, 4]) == 0 + assert int(rows[1, 14]) & 16 + assert int(rows[1, 14]) & 32 + assert int(rows[1, 15]) == 1 + review = "\n".join(plan.review_summary) + assert "readout_message_producer_consumer_plan=compiler_owned" in review + assert "streaming_readout_strategy=compiler_product_pending_registered_program_body" in review + + streaming_plan = build_temporal_readout_message_producer_consumer_plan( + table, + streaming_readout_body_available=True, + ) + streaming_rows = temporal_readout_message_producer_consumer_rows_tensor(streaming_plan) + assert streaming_plan.selected_strategy == "stream_readout_from_message_projection" + assert int(streaming_rows[0, 3]) == 2 + assert int(streaming_rows[0, 4]) == 0 + assert int(streaming_rows[1, 3]) == 1 + assert int(streaming_rows[1, 4]) == 1 + assert int(streaming_rows[1, 15]) == 0 + + +def test_message_transition_producer_consumer_rows_are_compiler_owned_legality_rows() -> None: + runtime = build(_make_slstm_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + table = build_temporal_primitive_table_plan(runtime, with_cached_population_static_tensors(runtime, static_tensors)) + plan = build_temporal_message_transition_producer_consumer_plan(table) + rows = temporal_message_transition_producer_consumer_rows_tensor(plan) + + assert rows.dtype == torch.long + assert tuple(rows.shape) == (2, 16) + assert plan.selected_strategy == "materialized_recurrent_message" + assert int(rows[0, 2]) == 1 + assert int(rows[0, 3]) == 1 + assert int(rows[0, 4]) == 1 + assert int(rows[0, 5]) == temporal_surface_opcode("message") + assert int(rows[0, 9]) == temporal_surface_opcode("transition") + assert int(rows[0, 13]) == temporal_program_access_opcode("transition_aggregated_message_input") + assert int(rows[0, 14]) & 8 + assert int(rows[0, 14]) & 16 + assert int(rows[0, 14]) & 32 + assert int(rows[0, 14]) & 64 + assert int(rows[0, 15]) == 0 + assert int(rows[1, 2]) == 2 + assert int(rows[1, 3]) == 3 + assert int(rows[1, 4]) == 0 + assert int(rows[1, 13]) == temporal_program_access_opcode("transition_aggregated_message_input") + assert int(rows[1, 15]) == 4 + review = "\n".join(plan.review_summary) + assert "message_transition_producer_consumer_plan=compiler_owned" in review + assert "message_transition_strategy=materialized_or_pending_direct_chunk_body" in review + + streaming_plan = build_temporal_message_transition_producer_consumer_plan( + table, + streaming_transition_body_available=True, + ) + streaming_rows = temporal_message_transition_producer_consumer_rows_tensor(streaming_plan) + assert streaming_plan.selected_strategy == "stream_message_to_transition_input" + assert int(streaming_rows[0, 3]) == 2 + assert int(streaming_rows[0, 4]) == 0 + assert int(streaming_rows[1, 3]) == 1 + assert int(streaming_rows[1, 4]) == 1 + assert int(streaming_rows[1, 15]) == 0 + + mixed_runtime = build(_make_mixed_spec()) + mixed_static_tensors = mixed_runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + mixed_table = 
build_temporal_primitive_table_plan( + mixed_runtime, + with_cached_population_static_tensors(mixed_runtime, mixed_static_tensors), + ) + mixed_plan = build_temporal_message_transition_producer_consumer_plan( + mixed_table, + streaming_transition_body_available=True, + ) + mixed_rows = temporal_message_transition_producer_consumer_rows_tensor(mixed_plan) + assert tuple(mixed_rows.shape) == (4, 16) + assert mixed_plan.selected_strategy == "materialized_recurrent_message" + assert {int(row[12]) for row in mixed_rows if int(row[2]) == 1 and int(row[3]) == 1} == {0, 1} + assert all(int(row[15]) == 2 for row in mixed_rows if int(row[2]) == 2) + + +def test_forward_artifact_aggregate_merges_are_executable_compiler_rows() -> None: + runtime = build(_make_slstm_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + table = build_temporal_primitive_table_plan(runtime, with_cached_population_static_tensors(runtime, static_tensors)) + memory_plan = build_temporal_memory_liveness_plan(table) + aggregate_merge_rows = temporal_forward_artifact_merge_rows_tensor(table).clone() + aggregate_merge_rows[0, 4] = 2 + aggregate_merge_rows[0, 6] = -1 + aggregate_merge_rows[0, 7] = -1 + aggregate_merge_rows[0, 8] = -1 + + fused_program_plan = build_temporal_fused_cuda_program_plan( + primitive_rows=temporal_primitive_rows_tensor(table), + forward_plan=build_temporal_forward_executable_plan(table), + backward_plan=build_temporal_backward_executable_plan(table), + memory_plan=memory_plan, + memory_liveness_rows=temporal_memory_liveness_rows_tensor(memory_plan), + forward_handler_rows=temporal_forward_executor_handler_rows_tensor(table), + reverse_handler_rows=temporal_reverse_executor_handler_rows_tensor(table), + forward_artifact_route_rows=temporal_forward_artifact_route_rows_tensor(table), + forward_artifact_merge_rows=aggregate_merge_rows, + forward_output_route_rows=temporal_forward_output_route_rows_tensor(table), + readout_message_producer_consumer_rows=temporal_readout_message_producer_consumer_rows_tensor( + build_temporal_readout_message_producer_consumer_plan(table) + ), + message_transition_producer_consumer_rows=temporal_message_transition_producer_consumer_rows_tensor( + build_temporal_message_transition_producer_consumer_plan(table) + ), + reverse_artifact_consumer_route_rows=temporal_reverse_artifact_consumer_route_rows_tensor(table), + reverse_parameter_reducer_route_rows=temporal_reverse_parameter_reducer_route_rows_tensor(table), + ) + + assert fused_program_plan.status == "legal" + assert fused_program_plan.blocker_code == "" + assert "registered_fused_program_has_sequence_forward_span_dispatch_body" in fused_program_plan.blocker_reason + + +def test_forward_multi_output_concat_routes_are_compiler_owned_rows() -> None: + runtime = build(_make_slstm_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + table = build_temporal_primitive_table_plan(runtime, with_cached_population_static_tensors(runtime, static_tensors)) + memory_plan = build_temporal_memory_liveness_plan(table) + output_route_rows = temporal_forward_output_route_rows_tensor(table) + output_route_rows = torch.cat((output_route_rows, output_route_rows.clone()), dim=0) + output_route_rows[1, 0] = 1 + output_route_rows[:, 1] = 3 + output_route_rows[1, 9] = int(runtime.output_cell_idx.numel()) + + fused_program_plan = 
build_temporal_fused_cuda_program_plan( + primitive_rows=temporal_primitive_rows_tensor(table), + forward_plan=build_temporal_forward_executable_plan(table), + backward_plan=build_temporal_backward_executable_plan(table), + memory_plan=memory_plan, + memory_liveness_rows=temporal_memory_liveness_rows_tensor(memory_plan), + forward_handler_rows=temporal_forward_executor_handler_rows_tensor(table), + reverse_handler_rows=temporal_reverse_executor_handler_rows_tensor(table), + forward_artifact_route_rows=temporal_forward_artifact_route_rows_tensor(table), + forward_artifact_merge_rows=temporal_forward_artifact_merge_rows_tensor(table), + forward_output_route_rows=output_route_rows, + readout_message_producer_consumer_rows=temporal_readout_message_producer_consumer_rows_tensor( + build_temporal_readout_message_producer_consumer_plan(table) + ), + message_transition_producer_consumer_rows=temporal_message_transition_producer_consumer_rows_tensor( + build_temporal_message_transition_producer_consumer_plan(table) + ), + reverse_artifact_consumer_route_rows=temporal_reverse_artifact_consumer_route_rows_tensor(table), + reverse_parameter_reducer_route_rows=temporal_reverse_parameter_reducer_route_rows_tensor(table), + ) + + assert fused_program_plan.status == "legal" + assert fused_program_plan.blocker_code == "" + + +def test_forward_output_routes_reject_non_concat_offsets_before_launch() -> None: + runtime = build(_make_slstm_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + table = build_temporal_primitive_table_plan(runtime, with_cached_population_static_tensors(runtime, static_tensors)) + memory_plan = build_temporal_memory_liveness_plan(table) + output_route_rows = temporal_forward_output_route_rows_tensor(table) + output_route_rows[0, 9] = 1 + + fused_program_plan = build_temporal_fused_cuda_program_plan( + primitive_rows=temporal_primitive_rows_tensor(table), + forward_plan=build_temporal_forward_executable_plan(table), + backward_plan=build_temporal_backward_executable_plan(table), + memory_plan=memory_plan, + memory_liveness_rows=temporal_memory_liveness_rows_tensor(memory_plan), + forward_handler_rows=temporal_forward_executor_handler_rows_tensor(table), + reverse_handler_rows=temporal_reverse_executor_handler_rows_tensor(table), + forward_artifact_route_rows=temporal_forward_artifact_route_rows_tensor(table), + forward_artifact_merge_rows=temporal_forward_artifact_merge_rows_tensor(table), + forward_output_route_rows=output_route_rows, + readout_message_producer_consumer_rows=temporal_readout_message_producer_consumer_rows_tensor( + build_temporal_readout_message_producer_consumer_plan(table) + ), + message_transition_producer_consumer_rows=temporal_message_transition_producer_consumer_rows_tensor( + build_temporal_message_transition_producer_consumer_plan(table) + ), + reverse_artifact_consumer_route_rows=temporal_reverse_artifact_consumer_route_rows_tensor(table), + reverse_parameter_reducer_route_rows=temporal_reverse_parameter_reducer_route_rows_tensor(table), + ) + + assert fused_program_plan.status == "blocked" + assert fused_program_plan.blocker_code == "FORWARD_OUTPUT_ROUTE_UNSUPPORTED" + assert "registered_fused_program_requires_zero_offset_for_non_concat_output_routes" in ( + fused_program_plan.blocker_reason + ) + + +def test_forward_multi_output_routes_require_explicit_merge_semantics() -> None: + runtime = build(_make_slstm_spec()) + static_tensors = 
runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + table = build_temporal_primitive_table_plan(runtime, with_cached_population_static_tensors(runtime, static_tensors)) + memory_plan = build_temporal_memory_liveness_plan(table) + output_route_rows = temporal_forward_output_route_rows_tensor(table) + output_route_rows = torch.cat((output_route_rows, output_route_rows.clone()), dim=0) + output_route_rows[1, 0] = 1 + + fused_program_plan = build_temporal_fused_cuda_program_plan( + primitive_rows=temporal_primitive_rows_tensor(table), + forward_plan=build_temporal_forward_executable_plan(table), + backward_plan=build_temporal_backward_executable_plan(table), + memory_plan=memory_plan, + memory_liveness_rows=temporal_memory_liveness_rows_tensor(memory_plan), + forward_handler_rows=temporal_forward_executor_handler_rows_tensor(table), + reverse_handler_rows=temporal_reverse_executor_handler_rows_tensor(table), + forward_artifact_route_rows=temporal_forward_artifact_route_rows_tensor(table), + forward_artifact_merge_rows=temporal_forward_artifact_merge_rows_tensor(table), + forward_output_route_rows=output_route_rows, + readout_message_producer_consumer_rows=temporal_readout_message_producer_consumer_rows_tensor( + build_temporal_readout_message_producer_consumer_plan(table) + ), + message_transition_producer_consumer_rows=temporal_message_transition_producer_consumer_rows_tensor( + build_temporal_message_transition_producer_consumer_plan(table) + ), + reverse_artifact_consumer_route_rows=temporal_reverse_artifact_consumer_route_rows_tensor(table), + reverse_parameter_reducer_route_rows=temporal_reverse_parameter_reducer_route_rows_tensor(table), + ) + + assert fused_program_plan.status == "blocked" + assert fused_program_plan.blocker_code == "FORWARD_OUTPUT_ROUTE_UNSUPPORTED" + assert "registered_fused_program_requires_explicit_multi_output_route_merge_kind" in ( + fused_program_plan.blocker_reason + ) + + +def test_reverse_artifact_consumer_routes_map_reverse_spans_to_forward_artifacts() -> None: + runtime = build(_make_slstm_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + table = build_temporal_primitive_table_plan(runtime, with_cached_population_static_tensors(runtime, static_tensors)) + rows = temporal_reverse_artifact_consumer_route_rows_tensor(table) + forward_routes = temporal_forward_artifact_route_rows_tensor(table) + + assert rows.dtype == torch.long + assert tuple(rows.shape[1:]) == (12,) + assert int(rows.shape[0]) > 0 + route_rows = {int(row[0]): row for row in forward_routes.tolist()} + for row in rows.tolist(): + forward_route = route_rows[int(row[6])] + assert int(row[1]) == int(forward_route[1]) + assert int(row[4]) == int(forward_route[4]) + assert int(row[5]) == int(forward_route[5]) + assert int(row[7]) == int(forward_route[2]) + assert int(row[8]) == int(forward_route[3]) + readout_consumer = next(row for row in rows.tolist() if int(row[1]) == temporal_surface_opcode("readout")) + readout_forward_route = route_rows[int(readout_consumer[6])] + assert int(readout_consumer[3]) != int(readout_forward_route[3]) + + +def test_reverse_parameter_reducer_routes_are_compiler_owned_rows() -> None: + runtime = build(_make_slstm_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + table = 
build_temporal_primitive_table_plan(runtime, with_cached_population_static_tensors(runtime, static_tensors)) + rows = temporal_reverse_parameter_reducer_route_rows_tensor(table) + + assert rows.dtype == torch.long + assert tuple(rows.shape[1:]) == (12,) + route_pairs = {(int(row[1]), int(row[2]), int(row[6]), int(row[7])) for row in rows.tolist()} + assert any( + kind == temporal_reverse_output_route_kind_opcode("readout_parameter_grad") + and target == temporal_reverse_output_route_target_id("value_to_output_weight") + and executor_row >= 0 + and executor_id > 0 + for kind, target, executor_row, executor_id in route_pairs + ) + assert any( + kind == temporal_reverse_output_route_kind_opcode("sender_kv_parameter_grad") + and target == temporal_reverse_output_route_target_id("boundary_input_kv_weight") + and executor_row >= 0 + and executor_id > 0 + for kind, target, executor_row, executor_id in route_pairs + ) + assert any( + kind == temporal_reverse_output_route_kind_opcode("transition_boundary") + and target == temporal_reverse_output_route_target_id("recurrent_query") + and executor_row >= 0 + and executor_id > 0 + for kind, target, executor_row, executor_id in route_pairs + ) + + +def test_temporal_reverse_executor_rows_allow_one_transition_bucket() -> None: + runtime = build(_make_slstm_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + reverse_rows = temporal_reverse_executor_rows_tensor(table) + reverse_executor_rows = temporal_reverse_executor_rows(table) + + assert table.bucket_count == 1 + assert "gated_logspace_recurrence" in table.primitive_names + assert "diag_rtu" not in table.primitive_names + assert tuple(reverse_rows.shape) == (len(reverse_executor_rows), 6) + assert set(int(item) for item in reverse_rows[:, 0].tolist()) == {2, 4, 5} + assert [row.executor_name for row in reverse_executor_rows] == [ + "fixed_slot_context_nudge_message_backward", + "mean_projection_reduction_boundary_backward", + "gated_logspace_transition_backward", + ] + review_text = "\n".join(table.review_summary) + assert "flat_bucket_tensor_tables" in review_text + assert "primitive_row_count=" in review_text + + +def test_temporal_primitive_executor_plan_fails_closed_for_missing_generic_dispatch() -> None: + runtime = build(_make_alias_population_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + executor_plan = build_temporal_primitive_executor_plan(table) + + assert not executor_plan.has_blockers + blocker_text = "\n".join(executor_plan.blockers) + assert "reason=declared_composite_transition_primitive_not_dispatched_by_temporal_primitive_row" not in blocker_text + assert "reason=recurrence_formula_not_dispatched_by_primitive_executor" not in blocker_text + assert "reason=fabric_nn_primitive_not_dispatched_by_temporal_primitive_row" not in blocker_text + assert "reason=message_primitive_not_dispatched_by_temporal_primitive_row" not in blocker_text + assert "reason=message_primitive_rows_missing_from_temporal_table" not in blocker_text + assert 
"reason=readout_boundary_not_dispatched_by_temporal_primitive_row" not in blocker_text + assert "reason=readout_boundary_rows_missing_from_temporal_table" not in blocker_text + assert "reason=parameter_reduction_not_dispatched_by_temporal_primitive_row" not in blocker_text + assert "reason=parameter_reduction_rows_missing_from_temporal_table" not in blocker_text + assert "fixed_" not in blocker_text + assert "reason=fused_parameter_reduction_block_not_dispatched_by_primitive_executor" not in blocker_text + assert "slstm" not in blocker_text + assert "axoncell" not in blocker_text + summary_text = "\n".join(executor_plan.summaries) + assert "registered_executor_binding_group_implemented" in summary_text + assert "surface=message" in summary_text + assert "params=message_query_slot_weight" in summary_text + assert "params=message_query_nudge_scale" in summary_text + assert "params=message_sender_slot_key_weight" in summary_text + assert "params=message_sender_context_key" in summary_text + assert "params=input_sender_value_weight" in summary_text + assert "input_group_value_weight" in summary_text + assert "params=recurrent_sender_value_weight" in summary_text + assert "params=message_output_weight" in summary_text + assert ( + "params=message_query_slot_weight,recurrent_sender_value_weight,message_query_nudge_scale," + "message_sender_slot_key_weight,message_sender_context_key,input_sender_value_weight," + "input_group_value_weight,message_output_weight" in summary_text + ) + assert "params=output_q,value_to_output_weight,output_cell_bias" in summary_text + assert "params=value_to_state_weight,recurrent_bias" in summary_text + assert "params=gate_weight,bias" in summary_text + assert any(contract.status == "implemented" for contract in executor_plan.contracts) + assert any(group.status == "implemented" for group in executor_plan.fusion_groups) + assert {group.surface for group in executor_plan.fusion_groups}.issuperset( + {"message", "readout", "parameter_reduction", "transition"} + ) + transition_groups = [group for group in executor_plan.fusion_groups if group.surface == "transition"] + assert transition_groups + assert all(group.row_indices for group in transition_groups) + assert {group.status for group in executor_plan.fusion_groups}.issuperset({"implemented"}) + + +def test_temporal_executor_binding_plan_groups_compiled_bindings_by_executor_row() -> None: + runtime = build(_make_mixed_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + table = build_temporal_primitive_table_plan( + runtime, + with_cached_population_static_tensors(runtime, static_tensors), + ) + + forward_plan = build_temporal_forward_executor_binding_plan(table) + reverse_plan = build_temporal_reverse_executor_binding_plan(table) + backward_executable_plan = build_temporal_backward_executable_plan( + table, + reverse_binding_plan=reverse_plan, + ) + reverse_stage_plan = build_temporal_reverse_program_stage_plan(table, backward_executable_plan) + transition_param_grad_plan = build_temporal_transition_param_grad_binding_plan( + table, + reverse_binding_plan=reverse_plan, + ) + + assert not forward_plan.has_blockers + assert not reverse_plan.has_blockers + assert forward_plan.rows.dtype == torch.long + assert reverse_plan.rows.dtype == torch.long + assert reverse_stage_plan.rows.dtype == torch.long + assert transition_param_grad_plan.rows.dtype == torch.long + assert tuple(forward_plan.rows.shape[1:]) == (8,) + 
assert tuple(reverse_plan.rows.shape[1:]) == (8,) + assert tuple(reverse_stage_plan.rows.shape[1:]) == (10,) + assert tuple(transition_param_grad_plan.rows.shape[1:]) == (8,) + forward_text = "\n".join(forward_plan.summaries) + reverse_text = "\n".join(reverse_plan.summaries) + reverse_stage_text = "\n".join(reverse_stage_plan.summaries) + transition_param_text = "\n".join(transition_param_grad_plan.summaries) + assert "direction=forward,executor_row=0,executor_id=5,executor=fixed_slot_context_nudge_message" in forward_text + assert "logical=message_query_slot_weight" in forward_text + assert "logical=message_query_nudge_scale" in forward_text + assert "logical=message_sender_slot_key_weight" in forward_text + assert "logical=message_sender_context_key" in forward_text + assert "logical=input_sender_value_weight" in forward_text + assert "logical=input_group_value_weight" in forward_text + assert "logical=recurrent_sender_value_weight" in forward_text + assert "logical=message_output_weight" in forward_text + assert "logical=output_q" in forward_text + assert "logical=value_to_output_weight" in forward_text + assert "logical=output_cell_bias" in forward_text + assert ( + "direction=reverse,executor_row=1,executor_id=4,executor=mean_projection_reduction_boundary_backward" + in reverse_text + ) + assert "direction=reverse,executor_row=2,executor_id=2,executor=gated_logspace_transition_backward" in reverse_text + assert "logical=output_q" in reverse_text + assert "logical=value_to_output_weight" in reverse_text + assert "logical=output_cell_bias" in reverse_text + assert "logical=value_to_state_weight" in reverse_text + assert "logical=gate_weight" in reverse_text + assert "grad_logical=grad_value_to_state_weight" in transition_param_text + assert "grad_logical=grad_gate_weight" in transition_param_text + assert "grad_logical=grad_outnorm_weight" in transition_param_text + assert "reducer=input_projection_weight" in transition_param_text + assert "expanded_transposed_static_tensor:value_to_cell_weight" in transition_param_text + assert "parameter=input_proj_weight" in transition_param_text + assert "parameter=input_proj_weight" in "\n".join( + summary + for summary in transition_param_grad_plan.summaries + if "selected_static_source=message_to_cell_weight" in summary + ) + assert "reducer=materialized" in transition_param_text + assert "surface=parameter_reduction" not in transition_param_text + gated_primitive_record = transition_primitive_executor_record("gated_logspace_recurrence") + assert gated_primitive_record is not None + assert ("grad_value_to_state_weight", "value_to_state_weight", "input_projection_weight") in ( + gated_primitive_record.param_grad_outputs + ) + assert "grad_public_y" in gated_primitive_record.reverse_input_bindings + assert "value_to_state_weight" in gated_primitive_record.parameter_bindings + assert "grad_gate_weight" in gated_primitive_record.reverse_output_bindings + diag_primitive_record = transition_primitive_executor_record("diag_rtu") + assert diag_primitive_record is not None + assert diag_primitive_record.aliases == ("diagonal_recurrence",) + assert transition_primitive_executor_record_for_lowered_primitive("diagonal_recurrence") is diag_primitive_record + assert "grad_next_hc1" in diag_primitive_record.reverse_input_bindings + assert "public_y_raw" in diag_primitive_record.reverse_input_bindings + assert "activation_id" in diag_primitive_record.parameter_bindings + assert "outnorm_weight" in diag_primitive_record.parameter_bindings + assert "outnorm_eps" 
in diag_primitive_record.parameter_bindings + assert "grad_out_proj_weight" in diag_primitive_record.reverse_output_bindings + assert "grad_outnorm_weight" in diag_primitive_record.reverse_output_bindings + assert "_TRANSITION_REVERSE_PARAM_GRAD_OUTPUTS" not in ( + Path("src/cortical/fabric/backend/cuda/sequence_surface/compiler/executor_bindings.py") + ).read_text(encoding="utf-8") + assert "kind=output_grad_window" in reverse_stage_text + assert "kind=readout_message_kv_step" in reverse_stage_text + assert "kind=transition_step" in reverse_stage_text + assert "kind=recurrent_message_boundary_initial_kv_step" in reverse_stage_text + assert "kind=parameter_reducer_step" in reverse_stage_text + assert reverse_stage_text.count("kind=parameter_reducer_step") == len(temporal_reverse_executor_rows(table)) + assert "surface=parameter_reduction" in reverse_stage_text + assert "compatibility" not in forward_text + assert "compatibility" not in reverse_text + + +def test_temporal_executor_fusion_patterns_are_structured() -> None: + strategy_registry = temporal_executor_strategy_registry() + forward_patterns = temporal_forward_executor_patterns() + forward_names = {pattern.executor_name for pattern in forward_patterns} + native_strategy_rows = temporal_native_executor_strategy_rows_tensor() + native_callable_catalog_rows = temporal_native_callable_catalog_rows_tensor() + native_callable_binding_schema_rows = temporal_native_callable_binding_schema_rows_tensor() + native_callable_output_rows = temporal_native_callable_output_rows_tensor() + transition_callable_rows = temporal_transition_primitive_native_callable_rows_tensor() + native_callable_hashes = {int(row[1]) for row in native_callable_catalog_rows.tolist()} + native_callable_ids = {definition.callable_id for definition in temporal_native_callable_definitions()} + binding_schema_definitions = temporal_native_callable_binding_schema_definitions() + output_definitions = temporal_native_callable_output_definitions() + generated_native_header = temporal_native_callable_generated_header_text() + checked_in_native_header = Path( + "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/flat_bucket_registered_native_callables.cuh" + ).read_text(encoding="utf-8") + + assert forward_patterns == strategy_registry.forward_patterns() + assert temporal_reverse_executor_patterns() == strategy_registry.reverse_patterns() + assert len(strategy_registry.all_patterns()) == len(forward_patterns) + len(temporal_reverse_executor_patterns()) + assert tuple(native_strategy_rows.shape)[1] == 17 + assert tuple(native_callable_catalog_rows.shape)[1] == 8 + assert tuple(native_callable_binding_schema_rows.shape)[1] == 10 + assert tuple(native_callable_output_rows.shape)[1] == 12 + assert tuple(transition_callable_rows.shape)[1] == 6 + assert int(native_strategy_rows.shape[0]) == len(strategy_registry.all_patterns()) + assert int(native_callable_catalog_rows.shape[0]) == len(temporal_native_callable_definitions()) + assert int(native_callable_binding_schema_rows.shape[0]) == len(binding_schema_definitions) + assert int(native_callable_output_rows.shape[0]) == len(output_definitions) + assert len({tuple(int(item) for item in row) for row in native_strategy_rows.tolist()}) == int( + native_strategy_rows.shape[0] + ) + assert len(native_callable_hashes) == int(native_callable_catalog_rows.shape[0]) + assert {int(item) for item in native_strategy_rows[:, 16].tolist()} <= native_callable_hashes + assert {int(row[1]) for row in 
native_callable_binding_schema_rows.tolist()} <= native_callable_hashes + assert {int(row[1]) for row in native_callable_output_rows.tolist()} <= native_callable_hashes + assert {int(row[1]) for row in transition_callable_rows.tolist() if int(row[4])} <= native_callable_hashes + assert {int(row[2]) for row in transition_callable_rows.tolist() if int(row[5])} <= native_callable_hashes + reverse_transition_schema_by_callable = { + (definition.callable_id, definition.binding_kind): tuple( + item + for item in sorted( + binding_schema_definitions, + key=lambda candidate: int(candidate.local_binding_index), + ) + if item.direction == "reverse" + and item.surface == "transition" + and item.callable_id == definition.callable_id + and item.binding_kind == definition.binding_kind + ) + for definition in binding_schema_definitions + if definition.direction == "reverse" and definition.surface == "transition" + } + for record in registered_transition_primitive_executor_records(): + if record.program_backward_status != "callable" or not record.program_reverse_native_callable: + continue + reverse_inputs = reverse_transition_schema_by_callable[(record.program_reverse_native_callable, "input")] + reverse_parameters = reverse_transition_schema_by_callable.get( + (record.program_reverse_native_callable, "parameter"), + (), + ) + reverse_outputs = reverse_transition_schema_by_callable[(record.program_reverse_native_callable, "output")] + assert tuple(definition.logical_name for definition in reverse_inputs) == record.reverse_input_bindings + assert tuple(definition.required for definition in reverse_inputs) == (True,) * len( + record.reverse_input_bindings + ) + assert tuple(definition.logical_name for definition in reverse_parameters) == record.parameter_bindings + assert tuple(definition.required for definition in reverse_parameters) == tuple( + name not in {"outnorm_eps", "eps", "activation_id"} for name in record.parameter_bindings + ) + assert tuple(definition.logical_name for definition in reverse_outputs) == record.reverse_output_bindings + assert tuple(definition.required for definition in reverse_outputs) == (True,) * len( + record.reverse_output_bindings + ) + assert "native.reverse.parameter_reduction.transition.value_to_cell_msg_to_cell.v1" in native_callable_ids + assert "native.reverse.parameter_reduction.fixed_slot_context_message.v1" in native_callable_ids + assert all("callable=" in summary and ",symbol=" in summary for summary in temporal_native_callable_summaries()) + message_callable_definitions = tuple( + definition + for definition in temporal_native_callable_definitions() + if definition.category == "executor_strategy" and definition.surface == "message" + ) + readout_callable_definitions = tuple( + definition + for definition in temporal_native_callable_definitions() + if definition.category == "executor_strategy" and definition.surface == "readout" + ) + assert all(definition.cxx_entrypoint_phases for definition in message_callable_definitions) + assert all(definition.cxx_entrypoint_phases for definition in readout_callable_definitions) + assert any("cxx_phases=bind+recurrent_kv+message" in summary for summary in temporal_native_callable_summaries()) + assert any( + "cxx_phases=recurrent_kv_backward+recurrent_message_backward+initial_recurrent_kv_backward+" + "boundary_kv_backward+recurrent_kv_forward_recompute" in summary + for summary in temporal_native_callable_summaries() + ) + assert any( + "cxx_phases=bind+message+projection+projection_into" in summary + for summary in 
temporal_native_callable_summaries() + ) + assert any( + "cxx_phases=readout_backward+output_message_backward" in summary + for summary in temporal_native_callable_summaries() + ) + assert all( + "binding_kind=" in summary and "logical_name=" in summary + for summary in temporal_native_callable_binding_schema_summaries() + ) + assert all( + "runtime_role=" in summary and "logical_index_source=" in summary + for summary in temporal_native_callable_output_summaries() + ) + assert temporal_native_callable_catalog_fingerprint() > 0 + assert temporal_native_callable_binding_schema_fingerprint() > 0 + assert temporal_native_callable_output_contract_fingerprint() > 0 + assert { + (definition.primitive, definition.output_name, definition.runtime_role, definition.shape_kind) + for definition in output_definitions + } >= { + ("linear", "gate_logits", "transition_forward_linear_output", "gate_logits"), + ("tanh", "output", "transition_forward_unary_output", "hidden"), + ("gated_logspace_recurrence", "next_y", "transition_forward_state_output", "hidden"), + ("diag_rtu", "preproj", "transition_forward_diag_output", "diagonal_preproj"), + } + assert checked_in_native_header == generated_native_header + validate_temporal_native_callable_generated_header(checked_in_native_header) + assert set(int(item) for item in native_strategy_rows[:, 0].tolist()) == {1, 2} + assert set(int(item) for item in native_strategy_rows[:, 8].tolist()) == {1} + assert set(int(item) for item in native_strategy_rows[:, 9].tolist()) == {1} + assert set(int(item) for item in native_strategy_rows[:, 10].tolist()) == {1} + assert set(int(item) for item in native_strategy_rows[:, 11].tolist()) == {1} + assert all(int(item) > 0 for item in native_strategy_rows[:, 1].tolist()) + assert all(int(item) > 0 for item in native_strategy_rows[:, 2].tolist()) + assert all(int(item) > 0 for item in native_strategy_rows[:, 3].tolist()) + assert all(int(item) > 0 for item in native_strategy_rows[:, 4].tolist()) + assert all(int(item) > 0 for item in native_strategy_rows[:, 5].tolist()) + assert all(int(item) > 0 for item in native_strategy_rows[:, 6].tolist()) + assert all(int(item) > 0 for item in native_strategy_rows[:, 7].tolist()) + assert all(int(item) > 0 for item in native_strategy_rows[:, 12].tolist()) + assert all(int(item) >= 0 for item in native_strategy_rows[:, 13].tolist()) + assert all(int(item) >= 0 for item in native_strategy_rows[:, 14].tolist()) + assert set(int(item) for item in native_strategy_rows[:, 15].tolist()) <= {0, 1} + assert all(int(item) > 0 for item in native_strategy_rows[:, 16].tolist()) + assert all(int(item) > 0 for item in transition_callable_rows[:, 0].tolist()) + assert all(int(item) > 0 for item in transition_callable_rows[:, 1].tolist()) + assert all(int(row[2]) > 0 for row in transition_callable_rows.tolist() if int(row[5])) + assert any( + int(row[0]) == temporal_primitive_opcode("tanh") and int(row[2]) > 0 and int(row[5]) == 1 + for row in transition_callable_rows.tolist() + ) + assert set(int(item) for item in transition_callable_rows[:, 3].tolist()) <= {0, 1} + assert set(int(item) for item in transition_callable_rows[:, 4].tolist()) <= {0, 1} + assert set(int(item) for item in transition_callable_rows[:, 5].tolist()) <= {0, 1} + assert len({int(item) for item in native_strategy_rows[:, 12].tolist()}) == len(strategy_registry.all_patterns()) + assert all("strategy_id=" in summary for summary in strategy_registry.strategy_summaries()) + assert all("native_callable=native." 
in summary for summary in strategy_registry.strategy_summaries()) + assert forward_names == { + "neighborhood_attention_project", + "fixed_slot_context_gate_message", + "fixed_slot_context_nudge_message", + "attention_projection_reduction_boundary", + "attn_projection_reduction_boundary", + "flatten_projection_reduction_boundary", + "mean_projection_reduction_boundary", + "gated_logspace_transition", + "diag_rtu_transition", + "transition_linear_primitive", + "transition_matmul_primitive", + "transition_norm_or_identity_primitive", + "tanh_transition", + } + assert all(pattern.row_signature for pattern in forward_patterns) + assert all(pattern.implementation_contract.startswith("registered_") for pattern in forward_patterns) + assert all(pattern.stable_strategy_id.startswith("forward.") for pattern in forward_patterns) + assert all(pattern.strategy_version == 1 for pattern in forward_patterns) + assert all(pattern.legality_predicate == "match_structural_row_signature" for pattern in forward_patterns) + assert all(pattern.cost_model == "registered_executor_static_priority" for pattern in forward_patterns) + assert all( + pattern.runtime_entrypoint == "registered_temporal_fused_forward_program_cuda" for pattern in forward_patterns + ) + assert not {pattern.executor_name for pattern in forward_patterns if pattern.verified_rewrite_required} + assert all(pattern.required_effects for pattern in forward_patterns) + assert all(pattern.cxx_entrypoints for pattern in forward_patterns if pattern.surface in {"message", "readout"}) + assert all(pattern.stable_handler_kind_opcode == pattern.executor_id for pattern in forward_patterns) + assert all(pattern.stable_handler_capabilities for pattern in forward_patterns) + assert all(pattern.stable_handler_effects for pattern in forward_patterns) + + reverse_patterns = temporal_reverse_executor_patterns() + reverse_names = {pattern.executor_name for pattern in reverse_patterns} + assert reverse_names == { + "neighborhood_attention_project_backward", + "fixed_slot_context_gate_message_backward", + "fixed_slot_context_nudge_message_backward", + "attention_projection_reduction_boundary_backward", + "attn_projection_reduction_boundary_backward", + "flatten_projection_reduction_boundary_backward", + "mean_projection_reduction_boundary_backward", + "gated_logspace_transition_backward", + "diag_rtu_transition_backward", + "transition_linear_primitive_backward", + "transition_matmul_primitive_backward", + "transition_norm_or_identity_primitive_backward", + "tanh_transition_backward", + } + assert all(pattern.row_signature for pattern in reverse_patterns) + assert all(pattern.implementation_contract.startswith("registered_") for pattern in reverse_patterns) + assert all(pattern.stable_strategy_id.startswith("reverse.") for pattern in reverse_patterns) + assert all(pattern.strategy_version == 1 for pattern in reverse_patterns) + assert all(pattern.legality_predicate == "match_structural_row_signature" for pattern in reverse_patterns) + assert all(pattern.cost_model == "registered_executor_static_priority" for pattern in reverse_patterns) + assert all(pattern.runtime_entrypoint == "registered_reverse_executor_bindings" for pattern in reverse_patterns) + assert not {pattern.executor_name for pattern in reverse_patterns if pattern.verified_rewrite_required} + assert all(pattern.required_effects for pattern in reverse_patterns) + assert all(pattern.cxx_entrypoints for pattern in reverse_patterns if pattern.surface in {"message", "readout"}) + assert 
all(pattern.cxx_entrypoints for pattern in reverse_patterns if pattern.surface == "transition") + assert all(pattern.stable_handler_kind_opcode == pattern.executor_id for pattern in reverse_patterns) + assert all(pattern.stable_handler_capabilities for pattern in reverse_patterns) + assert all(pattern.stable_handler_effects for pattern in reverse_patterns) + + summary_text = "\n".join( + ( + *temporal_forward_executor_pattern_summaries(), + *temporal_reverse_executor_pattern_summaries(), + ) + ) + assert "surface=message" in summary_text + assert "surface=readout" in summary_text + assert "surface=transition" in summary_text + assert "reverse_executor=neighborhood_attention_project_backward" in summary_text + assert "strategy_id=forward.message.neighborhood_attention_project.v1" in summary_text + assert "strategy_id=reverse.transition.gated_logspace.v1" in summary_text + assert "strategy_id=forward.transition.tanh.v1" in summary_text + assert "strategy_id=reverse.transition.tanh.v1" in summary_text + assert "legality=match_structural_row_signature" in summary_text + assert "cost_model=registered_executor_static_priority" in summary_text + assert "handler_kind=" in summary_text + assert "handler_capabilities=" in summary_text + assert "handler_effects=" in summary_text + assert "rewrite_required=0" in summary_text + assert "rewrite_required=1" not in summary_text + assert "implementation_contract=registered_" in summary_text + + +def test_temporal_strategy_matching_uses_canonical_row_group_schema() -> None: + runtime = build(_make_mixed_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + + message_rows = tuple(row for row in table.primitive_rows if "surface=message" in row.flat_bucket_identity) + message_schema = canonical_temporal_row_group(surface="message", bucket_ordinal=-1, rows=message_rows) + message_pattern = next( + pattern + for pattern in temporal_forward_executor_patterns() + if pattern.executor_name == "fixed_slot_context_nudge_message" + ) + + assert message_pattern.row_group_schema.matches(message_schema) + assert message_schema.primitive_signature == ( + "linear", + "linear", + "mul", + "concat", + "linear", + "concat", + "linear", + "attention_logits", + "add", + "segment_softmax", + "weighted_sum", + "linear", + "normalize", + ) + assert set(message_schema.effects) == {"state_read", "parameter_read", "message_emit"} + + transition_rows = tuple( + row + for row in table.primitive_rows + if row.bucket_ordinal == 0 and "surface=transition" in row.flat_bucket_identity + ) + transition_schema = canonical_temporal_row_group(surface="transition", bucket_ordinal=0, rows=transition_rows) + transition_pattern = next( + pattern + for pattern in temporal_forward_executor_patterns() + if pattern.executor_name == "gated_logspace_transition" + ) + + assert transition_pattern.row_group_schema.matches(transition_schema) + assert set(transition_schema.effects) == {"state_read", "message_read", "state_write", "tape_policy"} + assert "effects=state_read+message_read+state_write+tape_policy" in transition_schema.summary + + +def test_forward_fused_program_runtime_facts_are_compiler_owned_rows() -> None: + runtime = build(_make_mixed_spec()) + boundary_seq = torch.zeros( + 2, + 3, + int(runtime.input_cell_idx.numel()), + 
int(runtime.hidden_size), + ) + + runtime_plan = build_temporal_forward_program_runtime_plan( + runtime, + boundary_seq=boundary_seq, + inner_steps=2, + output_boundary_terminal=True, + ) + + assert tuple(runtime_plan.rows.shape) == (10, 6) + assert {int(row[0]) for row in runtime_plan.rows.tolist()} == { + temporal_forward_program_runtime_role_opcode("recurrent_local_sender_idx"), + temporal_forward_program_runtime_role_opcode("output_local_sender_idx"), + temporal_forward_program_runtime_role_opcode("local_distance"), + temporal_forward_program_runtime_role_opcode("local_delay"), + temporal_forward_program_runtime_role_opcode("inner_steps"), + temporal_forward_program_runtime_role_opcode("output_boundary_terminal"), + temporal_forward_program_runtime_role_opcode("distance_scale"), + temporal_forward_program_runtime_role_opcode("head_dim"), + temporal_forward_program_runtime_role_opcode("value_dim"), + temporal_forward_program_runtime_role_opcode("use_delay"), + } + runtime_text = "\n".join(runtime_plan.review_summary) + assert "forward_program_runtime_plan=compiler_owned" in runtime_text + assert "runtime_fact=recurrent_local_sender_idx" in runtime_text + assert "runtime_fact=output_boundary_terminal" in runtime_text + + +def test_reverse_fused_program_runtime_facts_are_compiler_owned_rows() -> None: + runtime = build(_make_mixed_spec()) + reference_boundary = torch.zeros( + 2, + int(runtime.input_cell_idx.numel()), + int(runtime.hidden_size), + ) + message_step_indices = torch.tensor([1, 1, 1], dtype=torch.long) + + runtime_plan = build_temporal_reverse_program_runtime_plan( + runtime, + reference_boundary=reference_boundary, + message_step_indices=message_step_indices, + return_boundary_grad=True, + use_sparse_messages=False, + ) + + assert tuple(runtime_plan.rows.shape) == (24, 6) + assert {int(row[0]) for row in runtime_plan.rows.tolist()} == { + temporal_reverse_program_runtime_role_opcode("graph_to_backend_order"), + temporal_reverse_program_runtime_role_opcode("backend_to_graph_inverse_order"), + temporal_reverse_program_runtime_role_opcode("output_local_sender_idx"), + temporal_reverse_program_runtime_role_opcode("local_distance"), + temporal_reverse_program_runtime_role_opcode("local_delay"), + temporal_reverse_program_runtime_role_opcode("output_neighbor_idx"), + temporal_reverse_program_runtime_role_opcode("output_neighbor_valid"), + temporal_reverse_program_runtime_role_opcode("output_edge_distance"), + temporal_reverse_program_runtime_role_opcode("output_edge_delay"), + temporal_reverse_program_runtime_role_opcode("recurrent_local_sender_idx"), + temporal_reverse_program_runtime_role_opcode("recurrent_neighbor_idx"), + temporal_reverse_program_runtime_role_opcode("recurrent_neighbor_valid"), + temporal_reverse_program_runtime_role_opcode("recurrent_edge_distance"), + temporal_reverse_program_runtime_role_opcode("recurrent_edge_delay"), + temporal_reverse_program_runtime_role_opcode("message_step_indices"), + temporal_reverse_program_runtime_role_opcode("input_count"), + temporal_reverse_program_runtime_role_opcode("recurrent_count"), + temporal_reverse_program_runtime_role_opcode("distance_scale"), + temporal_reverse_program_runtime_role_opcode("use_sparse_messages"), + temporal_reverse_program_runtime_role_opcode("use_delay"), + temporal_reverse_program_runtime_role_opcode("group_size"), + temporal_reverse_program_runtime_role_opcode("head_dim"), + temporal_reverse_program_runtime_role_opcode("value_dim"), + 
temporal_reverse_program_runtime_role_opcode("return_boundary_grad"), + } + runtime_text = "\n".join(runtime_plan.review_summary) + assert "reverse_program_runtime_plan=compiler_owned" in runtime_text + assert "runtime_fact=message_step_indices" in runtime_text + assert "runtime_fact=return_boundary_grad" in runtime_text + + +def test_fused_program_runtime_support_rejections_are_compiler_owned_rows() -> None: + runtime = SimpleNamespace(_local_message_step_enabled=True) + memory_artifact_plan = SimpleNamespace(mode="store_step_artifacts", store_step_artifacts=True) + + forward_support = build_temporal_forward_program_runtime_support_plan( + runtime, + boundary_seq=torch.zeros(2, 3, 1, 4), + output_contract="output_cells", + readout_pool="flatten", + materialize_final_state=False, + collect_artifacts=True, + memory_artifact_plan=memory_artifact_plan, + ) + assert tuple(forward_support.rows.shape) == (6, 5) + assert forward_support.rejection_reason == "unsupported_boundary_device_or_dtype" + forward_text = "\n".join(forward_support.review_summary) + assert "forward_program_runtime_support=compiler_owned" in forward_text + assert "requirement=boundary_device_dtype,legal=0" in forward_text + + reverse_support = build_temporal_reverse_program_runtime_support_plan( + reference_boundary=torch.zeros(2, 1, 4), + grad_output_window=torch.zeros(2, 3, 1, 4), + grad_carry_cells=None, + materialize_grad_carry_cells=True, + local_time_steps=3, + output_contract="output_cells", + readout_pool="flatten", + reverse_artifact_roles=(), + ) + assert tuple(reverse_support.rows.shape) == (8, 5) + assert reverse_support.rejection_reason == "missing_reverse_artifact_roles" + reverse_text = "\n".join(reverse_support.review_summary) + assert "reverse_program_runtime_support=compiler_owned" in reverse_text + assert "requirement=grad_carry_materialization_policy,legal=1" in reverse_text + assert "materialize_grad_carry_cells=1" in reverse_text + assert "requirement=reverse_artifact_roles,legal=0" in reverse_text + + +def test_temporal_table_runtime_metadata_records_executor_blockers() -> None: + runtime = build(_make_alias_population_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + metadata_runtime = SimpleNamespace() + + record_temporal_primitive_table_runtime_metadata(metadata_runtime, table) + + assert metadata_runtime._last_flat_bucket_temporal_table_review == table.review_summary + assert metadata_runtime._last_flat_bucket_temporal_registered_transition_bucket_kinds == ( + "bucket=0,kind=gated_logspace_recurrence", + "bucket=1,kind=diag_rtu", + ) + assert metadata_runtime._last_flat_bucket_temporal_primitive_row_count == len(table.primitive_rows) + assert metadata_runtime._last_flat_bucket_temporal_tensor_binding_row_count == len(table.tensor_bindings) + assert metadata_runtime._last_flat_bucket_temporal_verifier_status == "ok" + assert metadata_runtime._last_flat_bucket_temporal_compiler_pass_pipeline == temporal_compiler_pass_pipeline() + assert "UNSUPPORTED_PATTERN" in metadata_runtime._last_flat_bucket_temporal_strategy_rejection_codes + assert "MISSING_REQUIRED_BINDING" in metadata_runtime._last_flat_bucket_temporal_strategy_rejection_codes + assert metadata_runtime._last_flat_bucket_temporal_verifier_issues == () + strategy_text = 
"\n".join(metadata_runtime._last_flat_bucket_temporal_strategy_candidates) + assert "strategy_id=forward.message.fixed_slot_context_nudge.v1" in strategy_text + assert "rejection=UNVERIFIED_REWRITE" not in strategy_text + assert "cost_model=registered_executor_static_priority" in strategy_text + assert metadata_runtime._last_flat_bucket_temporal_legal_strategy_candidates + assert metadata_runtime._last_flat_bucket_temporal_blocked_strategy_candidates == () + assert not hasattr(metadata_runtime, "_last_flat_bucket_temporal_compatibility_debt") + assert metadata_runtime._last_flat_bucket_temporal_workspace_policy == ( + "planner_assigns_workspace_lifetimes_and_alias_sets" + ) + assert metadata_runtime._last_flat_bucket_temporal_layout_policy == ( + "compiler_declares_layouts_before_strategy_selection" + ) + assert metadata_runtime._last_flat_bucket_temporal_memory_peak_estimate_bytes is None + memory_text = "\n".join(metadata_runtime._last_flat_bucket_temporal_memory_plan_summaries) + assert "tensor=transition_tape" in memory_text + assert "tensor=parameter_grad_accumulator" in memory_text + forward_plan_text = "\n".join(metadata_runtime._last_flat_bucket_temporal_forward_executable_plan) + assert metadata_runtime._last_flat_bucket_temporal_forward_strategy_ids == ( + "forward.message.fixed_slot_context_nudge.v1", + "forward.readout.mean_projection_reduction_boundary.v1", + "forward.transition.gated_logspace.v1", + "forward.transition.diag_rtu.v1", + ) + assert "forward_executable_plan=compiler_owned" in forward_plan_text + assert "runtime_entrypoint=registered_temporal_fused_forward_program_cuda" in forward_plan_text + assert "executor_binding_row_count=" in forward_plan_text + assert not hasattr(metadata_runtime, "_last_flat_bucket_temporal_forward_compatibility_launch_plan") + assert metadata_runtime._last_flat_bucket_temporal_forward_executor_binding_rows.dtype == torch.long + assert metadata_runtime._last_flat_bucket_temporal_reverse_executor_binding_rows.dtype == torch.long + assert metadata_runtime._last_flat_bucket_temporal_transition_param_grad_binding_rows.dtype == torch.long + assert metadata_runtime._last_flat_bucket_temporal_reverse_program_stage_rows.dtype == torch.long + assert tuple(metadata_runtime._last_flat_bucket_temporal_reverse_program_stage_rows.shape[1:]) == (10,) + assert metadata_runtime._last_flat_bucket_temporal_forward_executor_binding_blockers == () + assert metadata_runtime._last_flat_bucket_temporal_reverse_executor_binding_blockers == () + forward_binding_text = "\n".join(metadata_runtime._last_flat_bucket_temporal_forward_executor_binding_summaries) + reverse_binding_text = "\n".join(metadata_runtime._last_flat_bucket_temporal_reverse_executor_binding_summaries) + transition_param_binding_text = "\n".join( + metadata_runtime._last_flat_bucket_temporal_transition_param_grad_binding_summaries + ) + reverse_stage_text = "\n".join(metadata_runtime._last_flat_bucket_temporal_reverse_program_stage_summaries) + assert "direction=forward" in forward_binding_text + assert "logical=message_query_slot_weight" in forward_binding_text + assert "direction=reverse" in reverse_binding_text + assert "logical=gate_weight" in reverse_binding_text + assert "grad_logical=grad_value_to_state_weight" in transition_param_binding_text + assert "reducer=input_projection_weight" in transition_param_binding_text + assert "kind=readout_message_kv_step" in reverse_stage_text + assert "kind=recurrent_message_boundary_initial_kv_step" in reverse_stage_text + assert 
metadata_runtime._last_flat_bucket_temporal_forward_runtime_entrypoint == ( + "registered_temporal_fused_forward_program_cuda" + ) + assert not hasattr(metadata_runtime, "_last_flat_bucket_temporal_forward_compatibility_runtime_entrypoint") + assert metadata_runtime._last_flat_bucket_temporal_forward_strategy_legality_status == "legal" + assert metadata_runtime._last_flat_bucket_temporal_forward_strategy_legality_reasons == () + backward_plan_text = "\n".join(metadata_runtime._last_flat_bucket_temporal_backward_executable_plan) + assert metadata_runtime._last_flat_bucket_temporal_backward_strategy_ids == ( + "reverse.message.fixed_slot_context_nudge.v1", + "reverse.readout.mean_projection_reduction_boundary.v1", + "reverse.transition.gated_logspace.v1", + "reverse.transition.diag_rtu.v1", + ) + assert "backward_executable_plan=compiler_owned" in backward_plan_text + assert "runtime_entrypoint=registered_reverse_executor_bindings" in backward_plan_text + assert not hasattr(metadata_runtime, "_last_flat_bucket_temporal_backward_compatibility_launch_plan") + assert metadata_runtime._last_flat_bucket_temporal_backward_runtime_entrypoint == ( + "registered_reverse_executor_bindings" + ) + assert not hasattr(metadata_runtime, "_last_flat_bucket_temporal_backward_compatibility_runtime_entrypoint") + assert metadata_runtime._last_flat_bucket_temporal_backward_strategy_legality_status == "legal" + assert metadata_runtime._last_flat_bucket_temporal_backward_strategy_legality_reasons == () + launch_contract_text = "\n".join(metadata_runtime._last_flat_bucket_temporal_fused_cuda_launch_contract) + assert "fused_cuda_launch_contract=compiler_owned" in launch_contract_text + assert "required_tables=primitive_rows,forward_executor_rows,reverse_executor_rows" in launch_contract_text + assert "forward_executor_binding_rows,reverse_executor_binding_rows,memory_liveness_plan" in launch_contract_text + assert "memory_liveness_rows" in launch_contract_text + assert "forward_program_runtime_rows" in launch_contract_text + assert "reverse_program_runtime_rows" in launch_contract_text + assert "primitive_row_count=" in launch_contract_text + assert "memory_liveness_row_count=" in launch_contract_text + assert "demotion_policy=fail_closed_no_unregistered_program_demotion" in launch_contract_text + assert "unsupported_policy=typed_strategy_and_binding_rejection" in launch_contract_text + assert "slot" not in launch_contract_text.lower() + program_executor_text = "\n".join(metadata_runtime._last_flat_bucket_temporal_registered_program_executor_plan) + assert "registered_program_executor_plan=compiler_owned" in program_executor_text + assert "forward_entrypoint=registered_temporal_fused_forward_program_cuda" in program_executor_text + assert "demotion_policy=fail_closed_registered_fused_program_only" in program_executor_text + assert metadata_runtime._last_flat_bucket_temporal_registered_program_executor_status == "active" + assert metadata_runtime._last_flat_bucket_temporal_registered_program_executor_demotion_policy == ( + "fail_closed_registered_fused_program_only" + ) + effect_text = "\n".join(metadata_runtime._last_flat_bucket_temporal_effect_summaries) + assert "effect=state_write" in effect_text + assert "effect=output_emit" in effect_text + explain_text = "\n".join(metadata_runtime._last_flat_bucket_temporal_planner_explain) + assert "compiler_pass_pipeline=semantic_ir->" in explain_text + assert "PrimitiveRowABI=1" in explain_text + reverse_summary_text = 
"\n".join(metadata_runtime._last_flat_bucket_temporal_reverse_executor_summaries) + assert "reverse_executor=fixed_slot_context_nudge_message_backward" in reverse_summary_text + assert "reverse_executor=mean_projection_reduction_boundary_backward" in reverse_summary_text + assert "reverse_executor=gated_logspace_transition_backward" in reverse_summary_text + blocker_text = "\n".join(metadata_runtime._last_flat_bucket_temporal_primitive_executor_blockers) + assert blocker_text == "" + + +def test_temporal_compiler_verifier_reports_effects_and_typed_legality_blockers() -> None: + runtime = build(_make_mixed_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + + report = verify_temporal_primitive_table(table) + + assert report.status == "ok" + assert report.pass_pipeline == temporal_compiler_pass_pipeline() + assert report.schema_versions == ( + "PrimitiveRowABI=1", + "TensorRoleTableABI=1", + "ExecutorPlanABI=1", + "BackwardTapeABI=1", + "MetadataSchemaABI=1", + ) + assert set(temporal_strategy_rejection_codes()).issuperset( + { + "UNSUPPORTED_PATTERN", + "UNSUPPORTED_DTYPE", + "UNSUPPORTED_LAYOUT", + "INSUFFICIENT_WORKSPACE", + "RESET_POLICY_MISMATCH", + "TAPE_POLICY_MISMATCH", + "DEVICE_CAPABILITY_MISMATCH", + "SHAPE_OUT_OF_RANGE", + "MISSING_REQUIRED_BINDING", + "MISSING_BACKWARD_COVERAGE", + "HIDDEN_FALLBACK_ROUTE", + "ABI_VERSION_MISMATCH", + } + ) + effect_text = "\n".join(report.effect_summaries) + assert "surface=message" in effect_text + assert "effect=message_emit" in effect_text + assert "surface=transition" in effect_text + assert "effect=tape_policy" in effect_text + issue_text = "\n".join(report.issue_summaries) + assert issue_text == "" + + +def test_temporal_strategy_selection_separates_match_legality_and_cost() -> None: + runtime = build(_make_mixed_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + + report = build_temporal_strategy_selection_report(table) + candidate_text = "\n".join(report.candidate_summaries) + + assert report.candidates + assert report.legal_summaries + assert report.blocked_summaries == () + assert "match=matched" in candidate_text + assert "legality=legal" in candidate_text + assert "rejection=UNVERIFIED_REWRITE" not in candidate_text + assert "reason=strategy_matches_but_requires_verified_rewrite_before_cost_selection" not in candidate_text + assert "binding_rows=" in candidate_text + assert "binding_blockers=0" in candidate_text + assert "cost_model=registered_executor_static_priority" in candidate_text + assert "rank=0" in candidate_text + + three_pop_runtime = build(_make_three_population_spec()) + three_pop_static_tensors = three_pop_runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + three_pop_table = build_temporal_primitive_table_plan( + three_pop_runtime, + with_cached_population_static_tensors(three_pop_runtime, three_pop_static_tensors), + ) + three_pop_report = build_temporal_strategy_selection_report(three_pop_table) 
+ matched_gated_buckets = { + int(candidate.bucket_ordinal) + for candidate in three_pop_report.candidates + if candidate.match_status == "matched" + and candidate.direction == "forward" + and candidate.strategy_id == "forward.transition.gated_logspace.v1" + and candidate.bucket_ordinal is not None + } + assert matched_gated_buckets == {0, 2} + + +def test_temporal_backward_requires_registered_reverse_binding_plan() -> None: + runtime = build(_make_mixed_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + executor = TemporalPhysicalBackwardScanExecutor( + SimpleNamespace(), + static_tensors={}, + trainable_params=(), + trainable_param_names=(), + output_contract="output_cells", + output_boundary="sequence", + materialize_final_state=False, + boundary_requires_grad=False, + state_requires_grad=False, + inner_steps=1, + temporal_plan=None, + ) + + executor._require_registered_reverse_executor_for_table(table) + assert executor.runtime._last_flat_bucket_temporal_reverse_engine_reject == "" + + executor._artifact_primitive_table_fingerprint = ("stale-table",) + with pytest.raises(RuntimeError, match="artifact_table_fingerprint_mismatch"): + executor._require_registered_reverse_executor_for_table(table) + assert "artifact_table_fingerprint_mismatch" in (executor.runtime._last_flat_bucket_temporal_reverse_engine_reject) + + +def test_temporal_memory_liveness_plan_tracks_compiler_owned_lifetimes() -> None: + runtime = build(_make_mixed_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + + plan = build_temporal_memory_liveness_plan(table) + rows = temporal_memory_liveness_rows_tensor(plan) + runtime_policy = temporal_memory_runtime_policy(plan) + plan_text = "\n".join((*plan.review_summary, *plan.summaries)) + runtime_policy_text = "\n".join(runtime_policy.review_summary) + + assert "memory_liveness_plan=compiler_owned" in plan_text + assert "workspace_policy=planner_assigns_workspace_lifetimes_and_alias_sets" in plan_text + assert "layout=contiguous" in plan_text + assert "tensor=private_state" in plan_text + assert "tensor=message_activation" in plan_text + assert "tensor=transition_tape" in plan_text + assert "tensor=materialized_output_boundary" in plan_text + assert "tensor=parameter_grad_accumulator" in plan_text + assert "surface=runtime_policy" in plan_text + assert "role=local_seed_row" in plan_text + assert "role=metadata_row" in plan_text + assert "role=primitive_output" in plan_text + assert "effect=alias_policy" in plan_text + assert "effect=recompute_window_policy" in plan_text + assert "effect=materialization_policy" in plan_text + assert "effect=cuda_graph_constraint" in plan_text + assert "workspace=policy_table" in plan_text + assert "owner=compiler_memory_policy" in plan_text + assert "recompute=scheduler_tape_policy" in plan_text + assert "recompute=scheduler_alias_policy" in plan_text + assert "recompute=scheduler_recompute_window_policy" in plan_text + assert "recompute=scheduler_materialization_policy" in plan_text + assert 
"recompute=cuda_graph_guard_policy" in plan_text + assert "owner=compiler_tensor_role_table" in plan_text + assert "owner=compiler_primitive_row" in plan_text + assert runtime_policy.complete + assert runtime_policy.alias_allocation_enabled + assert "memory_runtime_policy=compiler_executable" in runtime_policy_text + assert "policy_complete=1" in runtime_policy_text + assert "policy:alias_policy=scheduler_alias_policy" in runtime_policy_text + assert "policy:cuda_graph_constraint=cuda_graph_guard_policy" in runtime_policy_text + assert plan.fingerprint == (*plan.review_summary, *plan.summaries) + assert rows.dtype == torch.long + assert rows.device.type == "cpu" + assert tuple(rows.shape) == (len(plan.entries), 10) + assert torch.equal(rows[:, 0], torch.arange(len(plan.entries), dtype=torch.long)) + assert int((rows[:, 1] >= -1).all().item()) == 1 + assert int((rows[:, 3:] > 0).all().item()) == 1 + transition_linear_row_index = next( + row_index + for row_index, row in enumerate(table.primitive_rows) + if row.primitive == "linear" and "surface=transition" in row.flat_bucket_identity + ) + transition_matmul_row_index = next( + row_index + for row_index, row in enumerate(table.primitive_rows) + if row.primitive == "matmul" and "surface=transition" in row.flat_bucket_identity + ) + transition_gated_row_index = next( + row_index + for row_index, row in enumerate(table.primitive_rows) + if row.primitive == "gated_logspace_recurrence" and "surface=transition" in row.flat_bucket_identity + ) + transition_norm_row_index = next( + row_index + for row_index, row in enumerate(table.primitive_rows) + if row.primitive == "norm_or_identity" and "surface=transition" in row.flat_bucket_identity + ) + + scheduler = build_temporal_runtime_scheduler_plan( + temporal_plan=SimpleNamespace( + materialization=SimpleNamespace( + reverse_artifact_kind="checkpoint_recompute", + checkpoint_steps=2, + recompute_window_steps=3, + ) + ), + outer_time_steps=4, + inner_steps=2, + output_boundary="sequence", + output_contract="output_cells", + materialize_final_state=False, + collect_artifacts=True, + ) + schedule_plan = build_temporal_memory_runtime_schedule_plan( + plan, + physical_time_steps=8, + collect_artifacts=True, + scheduler_plan=scheduler, + ) + runtime_plan = build_temporal_memory_runtime_artifact_plan( + plan, + physical_time_steps=8, + collect_artifacts=True, + scheduler_plan=scheduler, + ) + assert runtime_plan.runtime_schedule_plan == schedule_plan + assert schedule_plan.mode == "recompute_step_artifacts" + assert schedule_plan.checkpoint_stride == 2 + assert schedule_plan.recompute_window_len == 3 + assert schedule_plan.checkpoint_steps == (0, 2, 4, 6) + assert schedule_plan.backward_windows == ((0, 2), (2, 4), (4, 6), (6, 8)) + assert schedule_plan.output_physical_steps == (1, 3, 5, 7) + assert schedule_plan.primitive_output_policy == "recompute_or_store_by_scheduler" + assert schedule_plan.tape_policy == "scheduler_tape_policy" + assert schedule_plan.recompute_window_policy == "scheduler_recompute_window_policy" + assert schedule_plan.materialization_policy == "scheduler_materialization_policy" + assert schedule_plan.cuda_graph_constraint == "cuda_graph_guard_policy" + assert schedule_plan.fingerprint == schedule_plan.review_summary + schedule_rows = temporal_memory_runtime_schedule_rows_tensor(schedule_plan) + assert tuple(schedule_rows.shape) == (24, 6) + assert schedule_rows[:, 0].tolist() == list(range(int(schedule_rows.shape[0]))) + assert int(schedule_rows[0, 2].item()) >= 0 + assert 
int(schedule_rows[1, 2].item()) >= 0 + assert int(schedule_rows[8, 3].item()) == 8 + assert int(schedule_rows[9, 3].item()) == 2 + assert int(schedule_rows[10, 3].item()) == 3 + assert int(schedule_rows[19, 3].item()) == 6 + assert int(schedule_rows[19, 4].item()) == 8 + assert int(schedule_rows[23, 3].item()) == 7 + physical_strategy_plan = build_temporal_physical_strategy_plan( + schedule_plan, + inner_steps=2, + output_boundary="sequence", + reset_policy="absent", + ) + physical_strategy_rows = temporal_physical_strategy_rows_tensor(physical_strategy_plan) + assert physical_strategy_plan.selected_strategy == "stage_materialized" + assert tuple(physical_strategy_rows.shape) == (2, 12) + assert physical_strategy_rows[0, 3].item() == 1 + assert physical_strategy_rows[0, 4].item() == 1 + assert physical_strategy_rows[1, 2].item() == 2 + assert physical_strategy_rows[1, 3].item() == 3 + assert physical_strategy_rows[1, 11].item() == 1 + physical_review = "\n".join(physical_strategy_plan.review_summary) + assert "physical_strategy_plan=compiler_executable" in physical_review + assert "strategy=streaming_step_producer_consumer" in physical_review + assert "blocker=pending_registered_streaming_step_program_body" in physical_review + streaming_strategy_plan = build_temporal_physical_strategy_plan( + schedule_plan, + inner_steps=2, + output_boundary="sequence", + reset_policy="absent", + streaming_step_body_available=True, + ) + streaming_strategy_rows = temporal_physical_strategy_rows_tensor(streaming_strategy_plan) + assert streaming_strategy_plan.selected_strategy == "streaming_step_producer_consumer" + assert tuple(streaming_strategy_rows.shape) == (2, 12) + assert streaming_strategy_rows[0, 2].item() == 1 + assert streaming_strategy_rows[0, 3].item() == 2 + assert streaming_strategy_rows[0, 4].item() == 0 + assert streaming_strategy_rows[1, 2].item() == 2 + assert streaming_strategy_rows[1, 3].item() == 1 + assert streaming_strategy_rows[1, 4].item() == 1 + assert streaming_strategy_rows[1, 11].item() == 0 + streaming_review = "\n".join(streaming_strategy_plan.review_summary) + assert "selected_strategy=streaming_step_producer_consumer" in streaming_review + assert "streaming_step_strategy=registered_program_body" in streaming_review + assert "blocker=-" in streaming_review + assert runtime_plan.mode == "recompute_step_artifacts" + assert runtime_plan.checkpoint_stride == 2 + assert runtime_plan.recompute_window_len == 3 + assert runtime_plan.checkpoint_steps == (0, 2, 4, 6) + assert runtime_plan.backward_windows == ((0, 2), (2, 4), (4, 6), (6, 8)) + assert "recurrent_hidden_before_backend_order" in runtime_plan.reverse_artifact_roles + assert "transition_state_before" in runtime_plan.reverse_artifact_roles + assert "backend_state_cache_before" not in runtime_plan.reverse_artifact_roles + assert "transition_backward_tape_by_population" not in runtime_plan.reverse_artifact_roles + assert not runtime_plan.store_step_artifacts + runtime_plan_text = "\n".join(runtime_plan.review_summary) + assert "memory_runtime_artifact_plan=compiler_executable" in runtime_plan_text + assert "memory_runtime_schedule_plan=compiler_executable" in runtime_plan_text + assert "memory_runtime_policy=local_seed_policy=policy_not_recomputed" in runtime_plan_text + assert "primitive_output_policy=recompute_or_store_by_scheduler" in runtime_plan_text + assert "tape_policy=scheduler_tape_policy" in runtime_plan_text + assert "output_physical_steps=1,3,5,7" in runtime_plan_text + assert 
"alias_policy=scheduler_alias_policy" in runtime_plan_text + assert "cuda_graph_constraint=cuda_graph_guard_policy" in runtime_plan_text + assert "checkpoint_steps=0,2,4,6" in runtime_plan_text + assert "backward_windows=0:2;2:4;4:6;6:8" in runtime_plan_text + assert "reverse_artifact_roles=" in runtime_plan_text + assert runtime_plan.fingerprint == runtime_plan.review_summary + + buffer_plan = build_temporal_runtime_buffer_plan( + plan, + output_seq_shape=(2, 4, 3, 8), + grad_boundary_seq_shape=(2, 4, 5, 8), + forward_message_step_flat_shape=(2,), + reverse_message_step_flat_shape=(2,), + physical_time_steps=4, + cells_prev_shape=(2, 12, 8), + recurrent_hidden_shape=(2, 4, 8), + grad_carry_cells_shape=(2, 12, 8), + forward_recurrent_msg_shape=(2, 4, 8), + forward_output_msg_shape=(2, 3, 8), + forward_output_cells_shape=(2, 3, 8), + reverse_grad_recurrent_msg_shape=(2, 4, 8), + transition_forward_outputs=( + TemporalTransitionForwardRuntimeBufferRequest( + primitive_row_index=transition_linear_row_index, + bucket_ordinal=0, + logical_name="transition_input", + shape=(2, 4, 8), + ), + TemporalTransitionForwardRuntimeBufferRequest( + primitive_row_index=transition_matmul_row_index, + bucket_ordinal=0, + logical_name="recurrent_gate_logits", + shape=(2, 4, 4, 8), + runtime_role="transition_forward_matmul_output", + ), + TemporalTransitionForwardRuntimeBufferRequest( + primitive_row_index=transition_gated_row_index, + bucket_ordinal=0, + logical_name="next_y", + shape=(2, 4, 8), + runtime_role="transition_forward_state_output", + logical_index=101, + ), + TemporalTransitionForwardRuntimeBufferRequest( + primitive_row_index=transition_norm_row_index, + bucket_ordinal=0, + logical_name="public_y", + shape=(2, 4, 8), + runtime_role="transition_forward_norm_output", + ), + ), + runtime_schedule_plan=schedule_plan, + dtype="torch.float32", + device="cuda:0", + include_workspace_rows=True, + ) + assert buffer_plan.runtime_schedule_fingerprint == schedule_plan.fingerprint + assert torch.equal( + buffer_plan.runtime_schedule_rows, + temporal_memory_runtime_schedule_rows_tensor(schedule_plan), + ) + buffer_text = "\n".join(buffer_plan.review_summary) + assert "memory_runtime_buffer_plan=compiler_executable" in buffer_text + assert "runtime_schedule_attached=1" in buffer_text + assert "runtime_schedule_rows=24x6" in buffer_text + assert "memory_runtime_policy=compiler_executable" in buffer_text + assert "policy_complete=1" in buffer_text + assert "buffer=output_seq" in buffer_text + assert "buffer=grad_boundary_seq" in buffer_text + assert "buffer=forward_message_step_flat" in buffer_text + assert "buffer=reverse_message_step_flat" in buffer_text + assert "buffer=forward_cells_prev_artifact_step_0" in buffer_text + assert "buffer=forward_recurrent_hidden_after_step_0" in buffer_text + assert "buffer=forward_recurrent_msg_step_0" in buffer_text + assert "buffer=forward_output_msg_step_0" in buffer_text + assert "buffer=forward_output_cells_step_0" in buffer_text + assert "buffer=reverse_grad_recurrent_msg" in buffer_text + assert "buffer=reverse_grad_cells_work" in buffer_text + assert "buffer=reverse_grad_carry_cells" in buffer_text + assert f"buffer=transition_forward_linear_output_row_{transition_linear_row_index}_transition_input" in buffer_text + assert ( + f"buffer=transition_forward_matmul_output_row_{transition_matmul_row_index}_recurrent_gate_logits" + in buffer_text + ) + assert f"buffer=transition_forward_state_output_row_{transition_gated_row_index}_next_y" in buffer_text + assert 
f"buffer=transition_forward_norm_output_row_{transition_norm_row_index}_public_y" in buffer_text + assert "runtime_role=forward_cells_prev_artifact" in buffer_text + assert "runtime_role=forward_recurrent_hidden_after" in buffer_text + assert "runtime_role=forward_recurrent_msg" in buffer_text + assert "runtime_role=forward_output_msg" in buffer_text + assert "runtime_role=forward_output_cells" in buffer_text + assert "runtime_role=reverse_grad_recurrent_msg" in buffer_text + assert "runtime_role=forward_message_step_flat" in buffer_text + assert "runtime_role=reverse_message_step_flat" in buffer_text + assert "runtime_role=reverse_grad_cells_work" in buffer_text + assert "runtime_role=reverse_grad_carry_cells" in buffer_text + assert "runtime_role=transition_forward_linear_output" in buffer_text + assert "runtime_role=transition_forward_matmul_output" in buffer_text + assert "runtime_role=transition_forward_state_output" in buffer_text + assert "runtime_role=transition_forward_norm_output" in buffer_text + assert "buffer=workspace_row_" in buffer_text + assert "memory_row=" in buffer_text + assert "workspace=output_workspace" in buffer_text + assert "workspace=reduction_workspace" in buffer_text + assert "workspace=policy_table" not in buffer_text + runtime_buffer_rows = temporal_runtime_buffer_rows_tensor(buffer_plan) + assert runtime_buffer_rows.dtype == torch.long + assert tuple(runtime_buffer_rows.shape) == (len(buffer_plan.specs), 10) + assert int(runtime_buffer_rows[:, 1].min().item()) >= 0 + assert int(runtime_buffer_rows[:, 6].min().item()) >= 0 + assert int(runtime_buffer_rows[:, 8].max().item()) > 0 + validate_temporal_runtime_buffer_plan( + plan, + buffer_plan, + require_runtime_schedule=True, + require_workspace_coverage=True, + ) + assert buffer_plan.fingerprint == buffer_plan.review_summary + + cpu_buffer_plan = build_temporal_runtime_buffer_plan( + plan, + output_seq_shape=(2, 4, 3, 8), + grad_boundary_seq_shape=(2, 4, 5, 8), + forward_message_step_flat_shape=(2,), + reverse_message_step_flat_shape=(2,), + physical_time_steps=4, + cells_prev_shape=(2, 12, 8), + recurrent_hidden_shape=(2, 4, 8), + grad_carry_cells_shape=(2, 12, 8), + forward_recurrent_msg_shape=(2, 4, 8), + forward_output_msg_shape=(2, 3, 8), + forward_output_cells_shape=(2, 3, 8), + reverse_grad_recurrent_msg_shape=(2, 4, 8), + transition_forward_outputs=( + TemporalTransitionForwardRuntimeBufferRequest( + primitive_row_index=transition_linear_row_index, + bucket_ordinal=0, + logical_name="transition_input", + shape=(2, 4, 8), + ), + TemporalTransitionForwardRuntimeBufferRequest( + primitive_row_index=transition_matmul_row_index, + bucket_ordinal=0, + logical_name="recurrent_gate_logits", + shape=(2, 4, 4, 8), + runtime_role="transition_forward_matmul_output", + ), + TemporalTransitionForwardRuntimeBufferRequest( + primitive_row_index=transition_gated_row_index, + bucket_ordinal=0, + logical_name="next_y", + shape=(2, 4, 8), + runtime_role="transition_forward_state_output", + logical_index=101, + ), + TemporalTransitionForwardRuntimeBufferRequest( + primitive_row_index=transition_norm_row_index, + bucket_ordinal=0, + logical_name="public_y", + shape=(2, 4, 8), + runtime_role="transition_forward_norm_output", + ), + ), + dtype="torch.float32", + device="cpu", + include_workspace_rows=True, + ) + output_buffer = allocate_temporal_runtime_buffer( + torch.empty((), dtype=torch.float32), + temporal_runtime_buffer_spec(cpu_buffer_plan, name="output_seq"), + ) + grad_boundary_buffer = 
allocate_temporal_runtime_buffer( + torch.empty((), dtype=torch.float32), + temporal_runtime_buffer_spec(cpu_buffer_plan, name="grad_boundary_seq"), + ) + cells_prev_buffer = allocate_temporal_runtime_buffer( + torch.empty((), dtype=torch.float32), + temporal_runtime_buffer_spec(cpu_buffer_plan, name="forward_cells_prev_artifact_step_3"), + ) + recurrent_hidden_buffer = allocate_temporal_runtime_buffer( + torch.empty((), dtype=torch.float32), + temporal_runtime_buffer_spec(cpu_buffer_plan, name="forward_recurrent_hidden_after_step_3"), + ) + recurrent_msg_buffer = allocate_temporal_runtime_buffer( + torch.empty((), dtype=torch.float32), + temporal_runtime_buffer_spec(cpu_buffer_plan, name="forward_recurrent_msg_step_3"), + ) + output_msg_buffer = allocate_temporal_runtime_buffer( + torch.empty((), dtype=torch.float32), + temporal_runtime_buffer_spec(cpu_buffer_plan, name="forward_output_msg_step_3"), + ) + output_cells_buffer = allocate_temporal_runtime_buffer( + torch.empty((), dtype=torch.float32), + temporal_runtime_buffer_spec(cpu_buffer_plan, name="forward_output_cells_step_3"), + ) + reverse_carry_buffer = allocate_temporal_runtime_buffer( + torch.empty((), dtype=torch.float32), + temporal_runtime_buffer_spec(cpu_buffer_plan, name="reverse_grad_carry_cells"), + ) + reverse_work_buffer = allocate_temporal_runtime_buffer( + torch.empty((), dtype=torch.float32), + temporal_runtime_buffer_spec(cpu_buffer_plan, name="reverse_grad_cells_work"), + ) + reverse_recurrent_msg_grad_buffer = allocate_temporal_runtime_buffer( + torch.empty((), dtype=torch.float32), + temporal_runtime_buffer_spec(cpu_buffer_plan, name="reverse_grad_recurrent_msg"), + ) + forward_step_flat_buffer = allocate_temporal_runtime_buffer( + torch.empty((), dtype=torch.float32), + temporal_runtime_buffer_spec(cpu_buffer_plan, name="forward_message_step_flat"), + ) + reverse_step_flat_buffer = allocate_temporal_runtime_buffer( + torch.empty((), dtype=torch.float32), + temporal_runtime_buffer_spec(cpu_buffer_plan, name="reverse_message_step_flat"), + ) + assert tuple(output_buffer.shape) == (2, 4, 3, 8) + assert tuple(grad_boundary_buffer.shape) == (2, 4, 5, 8) + assert tuple(cells_prev_buffer.shape) == (2, 12, 8) + assert tuple(recurrent_hidden_buffer.shape) == (2, 4, 8) + assert tuple(recurrent_msg_buffer.shape) == (2, 4, 8) + assert tuple(output_msg_buffer.shape) == (2, 3, 8) + assert tuple(output_cells_buffer.shape) == (2, 3, 8) + assert tuple(reverse_carry_buffer.shape) == (2, 12, 8) + assert tuple(reverse_work_buffer.shape) == (2, 12, 8) + assert tuple(reverse_recurrent_msg_grad_buffer.shape) == (2, 4, 8) + assert tuple(forward_step_flat_buffer.shape) == (2,) + assert tuple(reverse_step_flat_buffer.shape) == (2,) + assert forward_step_flat_buffer.dtype == torch.long + assert reverse_step_flat_buffer.dtype == torch.long + assert int(torch.count_nonzero(grad_boundary_buffer).item()) == 0 + assert int(torch.count_nonzero(cells_prev_buffer).item()) == 0 + assert int(torch.count_nonzero(reverse_carry_buffer).item()) == 0 + assert int(torch.count_nonzero(reverse_work_buffer).item()) == 0 + assert int(torch.count_nonzero(reverse_recurrent_msg_grad_buffer).item()) == 0 + all_buffers = allocate_temporal_runtime_buffers(torch.empty((), dtype=torch.float32), cpu_buffer_plan) + assert len(all_buffers) == len(cpu_buffer_plan.specs) + assert all(int(buffer.numel()) > 0 for buffer in all_buffers) + invalid_memory_row_plan = replace( + cpu_buffer_plan, + specs=( + replace( + temporal_runtime_buffer_spec(cpu_buffer_plan, 
name="output_seq"), + memory_row_index=None, + ), + ), + ) + with pytest.raises(RuntimeError, match="must reference a compiler memory row"): + validate_temporal_runtime_buffer_plan(plan, invalid_memory_row_plan) + invalid_workspace_plan = replace( + cpu_buffer_plan, + specs=( + replace( + temporal_runtime_buffer_spec(cpu_buffer_plan, name="output_seq"), + workspace_class="transition_workspace", + ), + ), + ) + with pytest.raises(RuntimeError, match="workspace does not match compiler memory row"): + validate_temporal_runtime_buffer_plan(plan, invalid_workspace_plan) + partial_workspace_plan = replace( + cpu_buffer_plan, + specs=tuple(spec for spec in cpu_buffer_plan.specs if spec.runtime_role != "workspace"), + ) + with pytest.raises(RuntimeError, match="does not cover all executable compiler memory rows"): + validate_temporal_runtime_buffer_plan( + plan, + partial_workspace_plan, + require_workspace_coverage=True, + ) + shared_alias_plan = TemporalRuntimeBufferPlan( + runtime_policy=runtime_policy, + specs=( + TemporalRuntimeBufferSpec( + name="scratch_a", + tensor_role="scratch", + shape=(4,), + dtype="torch.float32", + device="cpu", + workspace_class="primitive_workspace", + alias_set="compiler.alias.shared", + init="empty", + owner="compiler_memory_liveness_plan", + runtime_role="workspace", + ), + TemporalRuntimeBufferSpec( + name="scratch_b", + tensor_role="scratch", + shape=(4,), + dtype="torch.float32", + device="cpu", + workspace_class="primitive_workspace", + alias_set="compiler.alias.shared", + init="empty", + owner="compiler_memory_liveness_plan", + runtime_role="workspace", + ), + ), + ) + shared_alias_buffers = allocate_temporal_runtime_buffers(torch.empty((), dtype=torch.float32), shared_alias_plan) + assert shared_alias_buffers[0].data_ptr() == shared_alias_buffers[1].data_ptr() + public_state_alias_disabled_plan = build_temporal_runtime_buffer_plan( + plan, + physical_time_steps=1, + recurrent_hidden_shape=(2, 4, 8), + transition_forward_outputs=( + TemporalTransitionForwardRuntimeBufferRequest( + primitive_row_index=transition_norm_row_index, + bucket_ordinal=0, + logical_name="public_y", + shape=(2, 4, 8), + runtime_role="transition_forward_norm_output", + alias_runtime_role="forward_recurrent_hidden_after", + ), + ), + dtype="torch.float32", + device="cpu", + ) + assert ( + public_state_alias_disabled_plan.estimated_allocated_buffer_bytes + == public_state_alias_disabled_plan.planned_buffer_bytes + ) + assert "runtime_alias.transition_public_state." 
not in "\n".join(public_state_alias_disabled_plan.review_summary) + + public_state_alias_plan = build_temporal_runtime_buffer_plan( + plan, + physical_time_steps=1, + recurrent_hidden_shape=(2, 4, 8), + transition_forward_outputs=( + TemporalTransitionForwardRuntimeBufferRequest( + primitive_row_index=transition_norm_row_index, + bucket_ordinal=0, + logical_name="public_y", + shape=(2, 4, 8), + runtime_role="transition_forward_norm_output", + alias_runtime_role="forward_recurrent_hidden_after", + ), + ), + dtype="torch.float32", + device="cpu", + enable_public_state_runtime_alias=True, + ) + recurrent_hidden_spec = temporal_runtime_buffer_spec( + public_state_alias_plan, + name="forward_recurrent_hidden_after_step_0", + ) + public_y_spec = temporal_runtime_buffer_spec( + public_state_alias_plan, + name=f"transition_forward_norm_output_row_{transition_norm_row_index}_public_y", + ) + assert recurrent_hidden_spec.alias_set.startswith("runtime_alias.transition_public_state.") + assert public_y_spec.alias_set == recurrent_hidden_spec.alias_set + assert public_state_alias_plan.estimated_allocated_buffer_bytes < public_state_alias_plan.planned_buffer_bytes + public_state_alias_buffers = allocate_temporal_runtime_buffers( + torch.empty((), dtype=torch.float32), + public_state_alias_plan, + ) + assert public_state_alias_buffers[0].data_ptr() == public_state_alias_buffers[1].data_ptr() + + public_state_t2_plan = build_temporal_runtime_buffer_plan( + plan, + physical_time_steps=2, + recurrent_hidden_shape=(2, 4, 8), + transition_forward_outputs=( + TemporalTransitionForwardRuntimeBufferRequest( + primitive_row_index=transition_norm_row_index, + bucket_ordinal=0, + logical_name="public_y", + shape=(2, 4, 8), + runtime_role="transition_forward_norm_output", + alias_runtime_role="forward_recurrent_hidden_after", + ), + ), + dtype="torch.float32", + device="cpu", + enable_public_state_runtime_alias=True, + ) + assert public_state_t2_plan.estimated_allocated_buffer_bytes == public_state_t2_plan.planned_buffer_bytes + assert "runtime_alias.transition_public_state." 
not in "\n".join(public_state_t2_plan.review_summary) + + none_scheduler = build_temporal_runtime_scheduler_plan( + temporal_plan=None, + outer_time_steps=1, + inner_steps=1, + output_boundary="terminal", + output_contract="output_cells", + materialize_final_state=False, + collect_artifacts=False, + ) + none_schedule_plan = build_temporal_memory_runtime_schedule_plan( + plan, + physical_time_steps=1, + collect_artifacts=False, + scheduler_plan=none_scheduler, + ) + deferred_local_plan = build_temporal_runtime_buffer_plan( + plan, + physical_time_steps=1, + recurrent_hidden_shape=(2, 4, 8), + transition_forward_outputs=( + TemporalTransitionForwardRuntimeBufferRequest( + primitive_row_index=transition_linear_row_index, + bucket_ordinal=0, + logical_name="transition_input", + shape=(2, 4, 8), + ), + TemporalTransitionForwardRuntimeBufferRequest( + primitive_row_index=transition_norm_row_index, + bucket_ordinal=0, + logical_name="public_y", + shape=(2, 4, 8), + runtime_role="transition_forward_norm_output", + alias_runtime_role="forward_recurrent_hidden_after", + ), + ), + runtime_schedule_plan=none_schedule_plan, + dtype="torch.float32", + device="cpu", + enable_public_state_runtime_alias=True, + defer_local_transition_outputs=True, + ) + transition_input_spec = temporal_runtime_buffer_spec( + deferred_local_plan, + name=f"transition_forward_linear_output_row_{transition_linear_row_index}_transition_input", + ) + public_y_spec = temporal_runtime_buffer_spec( + deferred_local_plan, + name=f"transition_forward_norm_output_row_{transition_norm_row_index}_public_y", + ) + assert transition_input_spec.allocation == "deferred_local" + assert public_y_spec.allocation == "eager" + assert "allocation=deferred_local" in "\n".join(deferred_local_plan.review_summary) + assert deferred_local_plan.estimated_allocated_buffer_bytes < deferred_local_plan.planned_buffer_bytes + deferred_local_buffers = allocate_temporal_runtime_buffers( + torch.empty((), dtype=torch.float32), + deferred_local_plan, + ) + transition_input_buffer = deferred_local_buffers[deferred_local_plan.specs.index(transition_input_spec)] + assert tuple(transition_input_buffer.shape) == (0,) + assert int(transition_input_buffer.numel()) == 0 + + forward_step_deferred_plan = build_temporal_runtime_buffer_plan( + plan, + physical_time_steps=1, + recurrent_hidden_shape=(2, 4, 8), + forward_recurrent_msg_shape=(2, 4, 8), + forward_output_msg_shape=(2, 3, 8), + forward_output_cells_shape=(2, 3, 8), + runtime_schedule_plan=none_schedule_plan, + dtype="torch.float32", + device="cpu", + defer_forward_step_buffers=True, + ) + for name in ( + "forward_recurrent_hidden_after_step_0", + "forward_recurrent_msg_step_0", + "forward_output_msg_step_0", + "forward_output_cells_step_0", + ): + assert temporal_runtime_buffer_spec(forward_step_deferred_plan, name=name).allocation == "deferred_local" + forward_step_deferred_buffers = allocate_temporal_runtime_buffers( + torch.empty((), dtype=torch.float32), + forward_step_deferred_plan, + ) + for spec in forward_step_deferred_plan.specs: + buffer = forward_step_deferred_buffers[forward_step_deferred_plan.specs.index(spec)] + if spec.allocation == "deferred_local": + assert tuple(buffer.shape) == (0,) + assert int(buffer.numel()) == 0 + + with pytest.raises(RuntimeError, match="only legal for compiler-routed step-local outputs"): + validate_temporal_runtime_buffer_plan( + plan, + replace( + deferred_local_plan, + specs=( + replace( + temporal_runtime_buffer_spec(cpu_buffer_plan, name="output_seq"), + 
allocation="deferred_local", + ), + ), + ), + ) + + missing_alias_policy_plan = replace( + plan, + entries=tuple(entry for entry in plan.entries if entry.effect != "alias_policy"), + ) + with pytest.raises(RuntimeError, match="missing=alias_policy"): + build_temporal_runtime_buffer_plan( + missing_alias_policy_plan, + output_seq_shape=(2, 4, 3, 8), + dtype="torch.float32", + device="cpu", + ) + with pytest.raises(RuntimeError, match="unsupported dtype"): + allocate_temporal_runtime_buffer( + torch.empty((), dtype=torch.float32), + replace(temporal_runtime_buffer_spec(cpu_buffer_plan, name="output_seq"), dtype="torch.complex64"), + ) + + +def test_temporal_reverse_artifact_roles_are_compiler_binding_rows() -> None: + roles = ( + "boundary_step", + "cells_prev", + "output_cells", + ) + rows = temporal_reverse_artifact_role_rows_tensor(roles) + + assert rows.dtype == torch.long + assert rows.device.type == "cpu" + assert tuple(rows.shape) == (len(roles), 3) + assert torch.equal(rows[:, 0], torch.arange(len(roles), dtype=torch.long)) + assert int(rows[1, 1].item()) == temporal_reverse_artifact_role_id("cells_prev") + assert int(rows[2, 1].item()) == temporal_reverse_artifact_role_id("output_cells") + assert temporal_reverse_artifact_role_is_tensor("cells_prev") + assert temporal_reverse_artifact_role_is_tensor("output_cells") + assert temporal_reverse_artifact_role_is_tensor("transition_state_before") + flags = encode_temporal_reverse_transition_state_artifact_flags(bucket_ordinal=7, binding_index=42) + assert decode_temporal_reverse_transition_state_artifact_flags(flags) == (7, 42) + with pytest.raises(RuntimeError, match="Unknown temporal reverse artifact role"): + temporal_reverse_artifact_role_rows_tensor(("not_a_role",)) + with pytest.raises(RuntimeError, match="Unknown temporal reverse artifact role"): + temporal_reverse_artifact_role_is_tensor("backend_state_cache_before") + with pytest.raises(RuntimeError, match="Unknown temporal reverse artifact role"): + temporal_reverse_artifact_role_is_tensor("transition_backward_tape_by_population") + + +def test_temporal_reverse_artifact_access_rows_are_compiler_owned() -> None: + roles = ( + "boundary_step", + "cells_prev", + "input_k", + "input_v", + "recurrent_k_before", + "recurrent_v_before", + "recurrent_k", + "recurrent_v", + "recurrent_hidden_before_backend_order", + "recurrent_hidden_backend_order", + "recurrent_msg_backend_order", + "output_msg", + "output_cells", + "transition_state_before", + ) + rows = temporal_reverse_artifact_access_rows_tensor(roles) + + assert rows.dtype == torch.long + assert rows.device.type == "cpu" + assert tuple(rows.shape) == (len(roles), 3) + assert int(rows[0, 0].item()) == temporal_reverse_artifact_access_id("boundary_step") + assert int(rows[0, 1].item()) == temporal_reverse_artifact_role_id("boundary_step") + assert int(rows[12, 0].item()) == temporal_reverse_artifact_access_id("output_cells") + assert int(rows[12, 1].item()) == temporal_reverse_artifact_role_id("output_cells") + assert temporal_reverse_artifact_access_role_name("transition_state_before") == "transition_state_before" + with pytest.raises(RuntimeError, match="requires a missing role"): + temporal_reverse_artifact_access_rows_tensor(("cells_prev",)) + with pytest.raises(RuntimeError, match="Unknown temporal reverse artifact access"): + temporal_reverse_artifact_access_rows_tensor(roles, accesses=("not_an_access",)) + + +def test_temporal_reverse_reset_rows_are_compiler_owned() -> None: + message_reset = torch.tensor([False, True]) + 
transition_reset = torch.tensor([True, False]) + + tensors, rows = temporal_reverse_reset_tensor_table( + message_reset_step=message_reset, + transition_reset_step=transition_reset, + ) + + assert len(tensors) == 2 + assert rows.dtype == torch.long + assert rows.device.type == "cpu" + assert tuple(rows.shape) == (2, 4) + assert int(rows[0, 0].item()) == temporal_reverse_reset_kind_id("message") + assert int(rows[0, 1].item()) == 0 + assert int(rows[1, 0].item()) == temporal_reverse_reset_kind_id("transition") + assert int(rows[1, 1].item()) == 1 + assert all(tensor.dtype == torch.bool for tensor in tensors) + empty_tensors, empty_rows = temporal_reverse_reset_tensor_table( + message_reset_step=None, + transition_reset_step=None, + ) + assert empty_tensors == () + assert tuple(empty_rows.shape) == (0, 4) + + state_rows = temporal_reverse_transition_state_reset_rows_tensor( + group_logical_slots=( + {"grad_aggregated_message": 0, "grad_y": 1, "grad_c": 2, "grad_bias": 8}, + {"grad_hc1": 3, "grad_hc2": 4}, + ) + ) + assert state_rows.tolist() == [[0, 1], [0, 2], [1, 3], [1, 4]] + + +def test_temporal_forward_program_access_and_state_carry_rows_are_compiler_owned() -> None: + runtime = build(_make_mixed_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + program = build_registered_temporal_executor_program(runtime, static_tensors) + transition_handles = tuple( + program.forward_handle(surface="transition", bucket_ordinal=bucket_ordinal) + for bucket_ordinal in program.transition_bucket_ordinals() + ) + + access_rows = temporal_forward_program_access_rows_tensor( + message_handles=program.forward_surface_handles(surface="message"), + readout_handles=program.forward_surface_handles(surface="readout"), + transition_handles=transition_handles, + ) + carry_rows = temporal_forward_transition_state_carry_rows_tensor( + transition_handles=transition_handles, + ) + + assert access_rows.dtype == torch.long + assert access_rows.device.type == "cpu" + assert tuple(access_rows.shape) == (11 + 2 * len(transition_handles), 6) + assert all( + pattern.program_accesses + for pattern in temporal_forward_executor_patterns() + if pattern.surface in {"message", "readout"} or pattern.state_carry_rules + ) + assert any(pattern.state_carry_rules for pattern in temporal_forward_executor_patterns()) + rows_by_bucket = { + bucket: sorted(int(row[0]) for row in access_rows if int(row[2]) == bucket) + for bucket in {-2, -1, *program.transition_bucket_ordinals()} + } + assert rows_by_bucket[-1] == list(range(8)) + assert rows_by_bucket[-2] == [0, 1, 2] + opcodes_by_bucket = { + bucket: sorted(int(row[5]) for row in access_rows if int(row[2]) == bucket) for bucket in {-2, -1} + } + assert opcodes_by_bucket[-1] == sorted( + [ + temporal_program_access_opcode("message_query_slot_weight"), + temporal_program_access_opcode("message_query_context_scalar"), + temporal_program_access_opcode("message_sender_slot_key_weight"), + temporal_program_access_opcode("message_sender_context_key"), + temporal_program_access_opcode("message_input_value_weight"), + temporal_program_access_opcode("message_input_group_value_weight"), + temporal_program_access_opcode("message_recurrent_value_weight"), + temporal_program_access_opcode("message_output_weight"), + ] + ) + assert opcodes_by_bucket[-2] == sorted( + [ + temporal_program_access_opcode("readout_output_query"), + temporal_program_access_opcode("readout_value_to_output_weight"), + 
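+            # The readout pseudo-bucket (-2) is expected to expose exactly these three
+            # readout program accesses.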
temporal_program_access_opcode("readout_output_cell_bias"), + ] + ) + transition_aggregate_rows = [row.tolist() for row in access_rows if int(row[0]) == 0 and int(row[2]) >= 0] + transition_public_rows = [row.tolist() for row in access_rows if int(row[0]) == 1 and int(row[2]) >= 0] + assert len(transition_aggregate_rows) == len(transition_handles) + assert len(transition_public_rows) == len(transition_handles) + assert sorted(row[2] for row in transition_aggregate_rows) == list(program.transition_bucket_ordinals()) + assert sorted(row[2] for row in transition_public_rows) == list(program.transition_bucket_ordinals()) + assert {int(row[5]) for row in transition_aggregate_rows} == { + temporal_program_access_opcode("transition_aggregated_message_input") + } + assert {int(row[5]) for row in transition_public_rows} == { + temporal_program_access_opcode("transition_public_state_output") + } + assert all(int(row[1]) >= 0 for row in access_rows) + + assert carry_rows.dtype == torch.long + assert carry_rows.device.type == "cpu" + assert tuple(carry_rows.shape[1:]) == (3,) + assert set(int(row[0]) for row in carry_rows) == set(program.transition_bucket_ordinals()) + assert all(int(row[1]) != int(row[2]) for row in carry_rows) + + +def test_temporal_reverse_program_access_rows_are_compiler_owned() -> None: + runtime = build(_make_mixed_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + program = build_registered_temporal_executor_program(runtime, static_tensors) + + access_rows = temporal_reverse_program_access_rows_tensor( + message_handles=program.reverse_surface_handles(surface="message"), + readout_handles=program.reverse_surface_handles(surface="readout"), + ) + + assert access_rows.dtype == torch.long + assert access_rows.device.type == "cpu" + assert tuple(access_rows.shape) == (11, 6) + assert all( + pattern.program_accesses + for pattern in temporal_reverse_executor_patterns() + if pattern.surface in {"message", "readout"} + ) + rows_by_bucket = { + bucket: sorted(int(row[0]) for row in access_rows if int(row[2]) == bucket) for bucket in {-2, -1} + } + assert rows_by_bucket[-1] == list(range(8)) + assert rows_by_bucket[-2] == [0, 1, 2] + assert {int(row[5]) for row in access_rows if int(row[2]) == -1} == { + temporal_program_access_opcode("message_query_slot_weight"), + temporal_program_access_opcode("message_query_context_scalar"), + temporal_program_access_opcode("message_sender_slot_key_weight"), + temporal_program_access_opcode("message_sender_context_key"), + temporal_program_access_opcode("message_input_value_weight"), + temporal_program_access_opcode("message_input_group_value_weight"), + temporal_program_access_opcode("message_recurrent_value_weight"), + temporal_program_access_opcode("message_output_weight"), + } + assert {int(row[5]) for row in access_rows if int(row[2]) == -2} == { + temporal_program_access_opcode("readout_output_query"), + temporal_program_access_opcode("readout_value_to_output_weight"), + temporal_program_access_opcode("readout_output_cell_bias"), + } + assert all(int(row[1]) >= 0 for row in access_rows) + + +def test_temporal_backward_validates_memory_artifact_plan_fingerprint() -> None: + runtime = build(_make_mixed_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + table = build_temporal_primitive_table_plan( + runtime, + 
with_cached_population_static_tensors(runtime, static_tensors), + ) + memory_plan = build_temporal_memory_liveness_plan(table) + scheduler = build_temporal_runtime_scheduler_plan( + temporal_plan=SimpleNamespace( + materialization=SimpleNamespace( + reverse_artifact_kind="checkpoint_recompute", + checkpoint_steps=2, + recompute_window_steps=3, + ) + ), + outer_time_steps=4, + inner_steps=2, + output_boundary="sequence", + output_contract="output_cells", + materialize_final_state=False, + collect_artifacts=True, + ) + runtime_plan = build_temporal_memory_runtime_artifact_plan( + memory_plan, + physical_time_steps=8, + collect_artifacts=True, + scheduler_plan=scheduler, + ) + physical_strategy_plan = build_temporal_physical_strategy_plan( + runtime_plan.runtime_schedule_plan, + inner_steps=2, + output_boundary="sequence", + reset_policy="absent", + ) + physical_strategy_rows = temporal_physical_strategy_rows_tensor(physical_strategy_plan) + artifact_store = SimpleNamespace( + mode=runtime_plan.mode, + checkpoint_stride=runtime_plan.checkpoint_stride, + recompute_window_len=runtime_plan.recompute_window_len, + checkpoint_steps=runtime_plan.checkpoint_steps, + backward_windows=runtime_plan.backward_windows, + memory_plan_fingerprint=memory_plan.fingerprint, + memory_runtime_artifact_fingerprint=runtime_plan.fingerprint, + memory_runtime_policy_fingerprint=runtime_plan.runtime_policy.review_summary, + memory_runtime_schedule_fingerprint=runtime_plan.runtime_schedule_plan.fingerprint, + memory_runtime_schedule_rows=temporal_memory_runtime_schedule_rows_tensor(runtime_plan.runtime_schedule_plan), + physical_strategy_fingerprint=physical_strategy_plan.fingerprint, + physical_strategy_rows=physical_strategy_rows, + ) + + _require_compiler_memory_artifact_plan( + artifact_store, + memory_plan_fingerprint=memory_plan.fingerprint, + memory_runtime_artifact_fingerprint=runtime_plan.fingerprint, + memory_runtime_policy_fingerprint=runtime_plan.runtime_policy.review_summary, + memory_runtime_schedule_fingerprint=runtime_plan.runtime_schedule_plan.fingerprint, + memory_runtime_schedule_rows=temporal_memory_runtime_schedule_rows_tensor(runtime_plan.runtime_schedule_plan), + expected_mode=runtime_plan.mode, + expected_checkpoint_stride=runtime_plan.checkpoint_stride, + expected_recompute_window_len=runtime_plan.recompute_window_len, + expected_checkpoint_steps=runtime_plan.checkpoint_steps, + expected_backward_windows=runtime_plan.backward_windows, + ) + _require_compiler_physical_strategy_plan( + artifact_store, + physical_strategy_fingerprint=physical_strategy_plan.fingerprint, + physical_strategy_rows=physical_strategy_rows, + ) + + stale_store = SimpleNamespace(**{**vars(artifact_store), "memory_plan_fingerprint": ("stale",)}) + with pytest.raises(RuntimeError, match="memory-plan fingerprint mismatch"): + _require_compiler_memory_artifact_plan( + stale_store, + memory_plan_fingerprint=memory_plan.fingerprint, + memory_runtime_artifact_fingerprint=runtime_plan.fingerprint, + memory_runtime_policy_fingerprint=runtime_plan.runtime_policy.review_summary, + memory_runtime_schedule_fingerprint=runtime_plan.runtime_schedule_plan.fingerprint, + memory_runtime_schedule_rows=temporal_memory_runtime_schedule_rows_tensor( + runtime_plan.runtime_schedule_plan + ), + expected_mode=runtime_plan.mode, + expected_checkpoint_stride=runtime_plan.checkpoint_stride, + expected_recompute_window_len=runtime_plan.recompute_window_len, + expected_checkpoint_steps=runtime_plan.checkpoint_steps, + 
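+            # The store carries a stale memory-plan fingerprint while every expected_* value
+            # restates the fresh compiler plan, so the validator is expected to raise here.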
expected_backward_windows=runtime_plan.backward_windows, + ) + + stale_windows = SimpleNamespace(**{**vars(artifact_store), "backward_windows": ((0, 4), (4, 8))}) + with pytest.raises(RuntimeError, match="windows do not match compiler memory plan"): + _require_compiler_memory_artifact_plan( + stale_windows, + memory_plan_fingerprint=memory_plan.fingerprint, + memory_runtime_artifact_fingerprint=runtime_plan.fingerprint, + memory_runtime_policy_fingerprint=runtime_plan.runtime_policy.review_summary, + memory_runtime_schedule_fingerprint=runtime_plan.runtime_schedule_plan.fingerprint, + memory_runtime_schedule_rows=temporal_memory_runtime_schedule_rows_tensor( + runtime_plan.runtime_schedule_plan + ), + expected_mode=runtime_plan.mode, + expected_checkpoint_stride=runtime_plan.checkpoint_stride, + expected_recompute_window_len=runtime_plan.recompute_window_len, + expected_checkpoint_steps=runtime_plan.checkpoint_steps, + expected_backward_windows=runtime_plan.backward_windows, + ) + stale_policy = SimpleNamespace(**{**vars(artifact_store), "memory_runtime_policy_fingerprint": ("stale",)}) + with pytest.raises(RuntimeError, match="runtime policy fingerprint mismatch"): + _require_compiler_memory_artifact_plan( + stale_policy, + memory_plan_fingerprint=memory_plan.fingerprint, + memory_runtime_artifact_fingerprint=runtime_plan.fingerprint, + memory_runtime_policy_fingerprint=runtime_plan.runtime_policy.review_summary, + memory_runtime_schedule_fingerprint=runtime_plan.runtime_schedule_plan.fingerprint, + memory_runtime_schedule_rows=temporal_memory_runtime_schedule_rows_tensor( + runtime_plan.runtime_schedule_plan + ), + expected_mode=runtime_plan.mode, + expected_checkpoint_stride=runtime_plan.checkpoint_stride, + expected_recompute_window_len=runtime_plan.recompute_window_len, + expected_checkpoint_steps=runtime_plan.checkpoint_steps, + expected_backward_windows=runtime_plan.backward_windows, + ) + stale_schedule = SimpleNamespace(**{**vars(artifact_store), "memory_runtime_schedule_fingerprint": ("stale",)}) + with pytest.raises(RuntimeError, match="runtime schedule fingerprint mismatch"): + _require_compiler_memory_artifact_plan( + stale_schedule, + memory_plan_fingerprint=memory_plan.fingerprint, + memory_runtime_artifact_fingerprint=runtime_plan.fingerprint, + memory_runtime_policy_fingerprint=runtime_plan.runtime_policy.review_summary, + memory_runtime_schedule_fingerprint=runtime_plan.runtime_schedule_plan.fingerprint, + memory_runtime_schedule_rows=temporal_memory_runtime_schedule_rows_tensor( + runtime_plan.runtime_schedule_plan + ), + expected_mode=runtime_plan.mode, + expected_checkpoint_stride=runtime_plan.checkpoint_stride, + expected_recompute_window_len=runtime_plan.recompute_window_len, + expected_checkpoint_steps=runtime_plan.checkpoint_steps, + expected_backward_windows=runtime_plan.backward_windows, + ) + stale_schedule_rows = SimpleNamespace( + **{ + **vars(artifact_store), + "memory_runtime_schedule_rows": torch.zeros_like( + temporal_memory_runtime_schedule_rows_tensor(runtime_plan.runtime_schedule_plan) + ), + } + ) + with pytest.raises(RuntimeError, match="runtime schedule rows mismatch"): + _require_compiler_memory_artifact_plan( + stale_schedule_rows, + memory_plan_fingerprint=memory_plan.fingerprint, + memory_runtime_artifact_fingerprint=runtime_plan.fingerprint, + memory_runtime_policy_fingerprint=runtime_plan.runtime_policy.review_summary, + memory_runtime_schedule_fingerprint=runtime_plan.runtime_schedule_plan.fingerprint, + 
memory_runtime_schedule_rows=temporal_memory_runtime_schedule_rows_tensor( + runtime_plan.runtime_schedule_plan + ), + expected_mode=runtime_plan.mode, + expected_checkpoint_stride=runtime_plan.checkpoint_stride, + expected_recompute_window_len=runtime_plan.recompute_window_len, + expected_checkpoint_steps=runtime_plan.checkpoint_steps, + expected_backward_windows=runtime_plan.backward_windows, + ) + stale_physical_rows = SimpleNamespace( + **{ + **vars(artifact_store), + "physical_strategy_rows": torch.zeros_like(physical_strategy_rows), + } + ) + with pytest.raises(RuntimeError, match="physical-strategy rows mismatch"): + _require_compiler_physical_strategy_plan( + stale_physical_rows, + physical_strategy_fingerprint=physical_strategy_plan.fingerprint, + physical_strategy_rows=physical_strategy_rows, + ) + + +def test_temporal_backward_requires_compiler_planned_memory_windows() -> None: + assert _require_compiler_planned_artifact_windows( + SimpleNamespace(backward_windows=((0, 2), (2, 4))), + time_steps=4, + ) == ((0, 2), (2, 4)) + + with pytest.raises(RuntimeError, match="checkpoint and recompute window derivation"): + _require_compiler_planned_artifact_windows(SimpleNamespace(backward_windows=()), time_steps=4) + + with pytest.raises(RuntimeError, match="must exactly cover physical time"): + _require_compiler_planned_artifact_windows( + SimpleNamespace(backward_windows=((0, 2), (3, 4))), + time_steps=4, + ) + + +def test_fabric_cuda_nn_primitives_are_surface_independent_temporal_rows() -> None: + runtime = build(_make_mixed_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + primitive_tensor = temporal_primitive_rows_tensor(table) + + records: list[tuple[str, str, int]] = [] + for row, tensor_row in zip(table.primitive_rows, primitive_tensor.tolist(), strict=True): + surface = next( + item.removeprefix("surface=") for item in row.flat_bucket_identity if item.startswith("surface=") + ) + records.append((surface, row.primitive, int(tensor_row[0]))) + assert int(tensor_row[0]) == temporal_primitive_opcode(row.primitive) + + linear_records = [(surface, opcode) for surface, primitive, opcode in records if primitive == "linear"] + assert {surface for surface, _opcode in linear_records}.issuperset({"message", "transition"}) + assert len({opcode for _surface, opcode in linear_records}) == 1 + + reduction_records = [ + (surface, opcode) for surface, primitive, opcode in records if primitive == "reduction_boundary" + ] + assert {surface for surface, _opcode in reduction_records}.issuperset({"readout_boundary", "parameter_reduction"}) + assert len({opcode for _surface, opcode in reduction_records}) == 1 + assert temporal_transition_tape_kind("diag_rtu") == "diagonal_recurrence" + + executor_rows = temporal_forward_executor_rows_tensor(table) + assert executor_rows.dtype == torch.long + assert executor_rows.shape[0] < primitive_tensor.shape[0] + executor_summary_text = "\n".join(temporal_forward_executor_summaries(table)) + assert "surface=message" in executor_summary_text + assert "surface=transition" in executor_summary_text + assert "executor=linear" not in executor_summary_text + + +def test_planner_cache_keys_flat_bucket_identity_not_binding_slot() -> None: + runtime = build(_make_slstm_spec()) + recurrent_bucket = 
next( + bucket for bucket in runtime.backend_ir.buckets if bucket.receiver_kind == ReceiverKind.RECURRENT_CELL + ) + first = replace( + recurrent_bucket, + bucket_id=0, + population_name="left", + population_index=0, + parameter_binding="population_slot:0", + ) + second = replace( + recurrent_bucket, + bucket_id=1, + population_name="right", + population_index=1, + parameter_binding="population_slot:1", + ) + ir = replace(runtime.backend_ir, population_names=("left", "right"), buckets=(first, second)) + planner = FabricExecutionPlanner(ir=ir) + + planned = planner.plan_execution( + batch_size=4, + time_steps=3, + inner_steps=1, + training=False, + device_caps=_test_device_caps(), + tape_policy=None, + ) + planned_backward = planner.plan_backward_execution( + batch_size=4, + time_steps=3, + inner_steps=1, + training=True, + device_caps=_test_device_caps(), + tape_policy=None, + ) + + assert first.signature != second.signature + assert first.planner_signature == second.planner_signature + assert planned.cache_misses == 1 + assert planned.cache_hits == 1 + assert tuple(plan.bucket_id for plan in planned.bucket_plans) == (0, 1) + assert planned_backward.cache_misses == 2 + assert planned_backward.cache_hits == 2 + assert tuple(plan.bucket_id for plan in planned_backward.receiver_bucket_plans) == (0, 1) + assert tuple(plan.bucket_id for plan in planned_backward.sender_bucket_plans) == (0, 1) + + +def test_scalar_temporal_scan_schedule_marks_outer_emissions_and_resets() -> None: + schedule = build_scalar_temporal_scan_schedule(outer_time_steps=3, inner_steps=2) + + assert schedule.physical_time_steps == 6 + assert schedule.emission_steps == (1, 3, 5) + assert [(step.outer_step, step.inner_step) for step in schedule.steps] == [ + (0, 0), + (0, 1), + (1, 0), + (1, 1), + (2, 0), + (2, 1), + ] + assert tuple(step.physical_step for step in schedule.steps if step.apply_boundary_reset) == (0, 2, 4) + assert all(step.apply_transition_reset for step in schedule.steps) + + +def test_scalar_temporal_scan_schedule_covers_k128_ceiling() -> None: + schedule = build_scalar_temporal_scan_schedule(outer_time_steps=1, inner_steps=128) + + assert schedule.physical_time_steps == 128 + assert schedule.emission_steps == (127,) + assert schedule.steps[0].apply_boundary_reset is True + assert all(not step.apply_boundary_reset for step in schedule.steps[1:]) + assert all(step.apply_transition_reset for step in schedule.steps) + + +def test_scalar_temporal_scan_schedule_does_not_store_large_tk_step_table() -> None: + schedule = build_scalar_temporal_scan_schedule(outer_time_steps=16_384, inner_steps=128) + + assert "steps" not in schedule.__dict__ + assert schedule.physical_time_steps == 2_097_152 + first = schedule.step_at(0) + last = schedule.step_at(schedule.physical_time_steps - 1) + assert (first.physical_step, first.outer_step, first.inner_step, first.apply_boundary_reset) == (0, 0, 0, True) + assert (last.physical_step, last.outer_step, last.inner_step, last.emit_output) == ( + 2_097_151, + 16_383, + 127, + True, + ) + assert [step.physical_step for step in schedule.iter_steps(start=128, end=131)] == [128, 129, 130] + + +def test_scalar_temporal_scan_schedule_maps_sequence_and_terminal_emissions() -> None: + schedule = build_scalar_temporal_scan_schedule(outer_time_steps=3, inner_steps=2) + + assert [ + emitted_output_index_for_scan_step(step, outer_time_steps=3, emitted_time_steps=3) for step in schedule.steps + ] == [None, 0, None, 1, None, 2] + assert [ + emitted_output_index_for_scan_step(step, 
outer_time_steps=3, emitted_time_steps=1) for step in schedule.steps + ] == [None, None, None, None, None, 0] + + +def test_explicit_graph_uses_sparse_message_backend_without_patch_edges() -> None: + runtime = build(_make_explicit_graph_spec("slstm")) + + assert runtime.local_offsets.numel() == 0 + assert not runtime._local_message_step_enabled + assert runtime._uses_sparse_message_backend + + +def test_default_message_rule_contract_is_planner_visible() -> None: + runtime = build(_make_explicit_graph_spec("slstm")) + message_rule = runtime.backend_ir.message_rule + message_program = runtime.backend_ir.message_program + + assert message_rule.name == "dot_product" + assert message_rule.lowering_kind == "dot_product_fixed_slot_context_nudge" + assert message_rule.output_boundary == "projected_message" + assert "receiver_public_prev:reset=zero_source_rows:scope=batch_row" in message_rule.source_signature + assert "sender_public_prev:reset=zero_source_rows:scope=batch_row" in message_rule.source_signature + assert any("sender_group_shared" in entry for entry in message_rule.parameter_sharing_signature) + assert message_program.rule_name == "dot_product" + assert message_program.lowering_kind == "dot_product_fixed_slot_context_nudge" + assert message_program.output_dim_role == "d_msg" + assert message_program.primitive_names == ( + "linear", + "mul", + "concat", + "attention_logits", + "add", + "segment_softmax", + "weighted_sum", + "normalize", + ) + assert temporal_message_output_dim(runtime) == int(runtime.d_msg) + + planned = runtime.plan_backend_execution( + batch_size=1, + time_steps=1, + inner_steps=1, + training=False, + surface_key="slstm_recurrence", + ) + assert planned.bucket_plans + assert all( + plan.message_rule_lowering_kind == "dot_product_fixed_slot_context_nudge" for plan in planned.bucket_plans + ) + assert all(plan.message_rule_name == "dot_product" for plan in planned.bucket_plans) + assert all(plan.message_rule_output_boundary == "projected_message" for plan in planned.bucket_plans) + + +def test_message_rule_backend_specs_are_registered_like_cell_specs() -> None: + assert "dot_product" in registered_message_rule_backend_spec_types() + assert "dot_product_fixed_slot_context_gate" in registered_message_rule_backend_spec_types() + assert "dot_product_fixed_slot_context_nudge" in registered_message_rule_backend_spec_types() + + spec = build_message_rule_backend_spec(rule_type="dot_product", kv_group_count=4, cell_count=64) + message_rule = spec.to_ir(name="registered_dot_product") + message_program = compile_message_rule(message_rule) + + assert spec.rule_type == "dot_product" + assert message_rule.name == "registered_dot_product" + assert message_rule.rule_type == "dot_product" + assert message_rule.lowering_kind == DOT_PRODUCT_SEGMENT_SOFTMAX_WEIGHTED_SUM + assert message_program.output_dim_role == "value_dim" + assert message_rule.primitive_bindings + assert message_program.primitive_names == ("linear", "attention_logits", "add", "segment_softmax", "weighted_sum") + assert tuple(static_tensor.program_access_name for static_tensor in message_program.static_tensors) == ( + "message_recurrent_query", + "message_input_direct_kv_weight", + "message_input_group_kv_weight", + "message_recurrent_kv_weight", + ) + assert {static_tensor.source_kind for static_tensor in message_program.static_tensors} == {"existing_static_tensor"} + assert {executor.direction for executor in message_program.native_executors} == {"forward", "reverse"} + + +def 
test_fixed_slot_context_nudge_dot_product_lowers_as_distinct_message_program() -> None: + spec = build_message_rule_backend_spec( + rule_type="dot_product_fixed_slot_context_nudge", + kv_group_count=4, + cell_count=64, + ) + message_rule = spec.to_ir(name="registered_context_nudge_dot_product") + message_program = compile_message_rule(message_rule) + + assert message_rule.name == "registered_context_nudge_dot_product" + assert message_rule.rule_type == "dot_product_fixed_slot_context_nudge" + assert message_rule.lowering_kind == DOT_PRODUCT_FIXED_SLOT_CONTEXT_NUDGE + assert message_program.output_dim_role == "d_msg" + assert message_program.parameter_reducer_kind == "fixed_slot_context_message" + assert tuple(output.logical_name for output in message_program.param_grad_outputs) == ( + "grad_query_slot_backend", + "grad_query_context_scalar", + "grad_output_weight", + "grad_input_key_bank", + "grad_recurrent_key_bank", + ) + assert tuple(module.name for module in message_program.runtime_modules) == ( + "message_query_slot_proj", + "message_sender_slot_key_proj", + ) + assert tuple(parameter.name for parameter in message_program.runtime_parameters) == ( + "message_query_nudge_scale", + "message_sender_context_key", + ) + assert tuple(static_tensor.name for static_tensor in message_program.static_tensors) == ( + "message_query_slot_weight", + "message_query_nudge_scale", + "message_sender_slot_key_weight", + "message_sender_context_key", + "input_sender_value_weight", + "input_group_value_weight", + "recurrent_sender_value_weight", + "message_output_weight", + ) + assert "receiver_public_prev:reset=zero_source_rows:scope=batch_row" in message_rule.source_signature + assert "sender_slot" in message_rule.source_signature + assert "message_query_nudge_scale:rule_scalar:fabric_global" in message_rule.parameter_sharing_signature + assert "fixed_slot_query(receiver_slot)+context_nudge(receiver_public_prev)->q" in ( + message_rule.expression_signature + ) + assert message_program.primitive_names == ( + "linear", + "mul", + "concat", + "attention_logits", + "add", + "segment_softmax", + "weighted_sum", + "normalize", + ) + + +def test_fixed_slot_context_gate_dot_product_lowers_as_distinct_message_program() -> None: + spec = build_message_rule_backend_spec( + rule_type="dot_product_fixed_slot_context_gate", + kv_group_count=4, + cell_count=64, + ) + message_rule = spec.to_ir(name="registered_context_gate_dot_product") + message_program = compile_message_rule(message_rule) + + assert message_rule.name == "registered_context_gate_dot_product" + assert message_rule.rule_type == "dot_product_fixed_slot_context_gate" + assert message_rule.lowering_kind == DOT_PRODUCT_FIXED_SLOT_CONTEXT_GATE + assert message_program.output_dim_role == "d_msg" + assert message_program.parameter_reducer_kind == "fixed_slot_context_message" + assert tuple(parameter.name for parameter in message_program.runtime_parameters) == ( + "message_query_context_gate", + "message_sender_context_key", + ) + assert tuple(static_tensor.name for static_tensor in message_program.static_tensors) == ( + "message_query_slot_weight", + "message_query_context_gate", + "message_sender_slot_key_weight", + "message_sender_context_key", + "input_sender_value_weight", + "input_group_value_weight", + "recurrent_sender_value_weight", + "message_output_weight", + ) + assert "message_query_context_gate:rule_scalar:fabric_global" in message_rule.parameter_sharing_signature + assert "fixed_slot_query(receiver_slot)+context_gate(receiver_public_prev)->q" in ( 
+ message_rule.expression_signature + ) + assert message_program.primitive_names == ( + "linear", + "mul", + "concat", + "attention_logits", + "add", + "segment_softmax", + "weighted_sum", + "normalize", + ) + + +def test_message_executor_patterns_follow_registered_message_specs() -> None: + def compiled_row_signature(rule_type: str) -> tuple[tuple[str, tuple[str, ...], tuple[tuple[str, str], ...]], ...]: + message_program = compile_message_rule( + build_message_rule_backend_spec( + rule_type=rule_type, + kv_group_count=1, + cell_count=2, + ).to_ir() + ) + return tuple((op.primitive, tuple(op.parameter_bindings), ()) for op in message_program.primitive_ops) + + def declared_program_access_signature(rule_type: str) -> tuple[tuple[str, str, int], ...]: + spec = build_message_rule_backend_spec( + rule_type=rule_type, + kv_group_count=1, + cell_count=2, + ) + return tuple( + ( + static_tensor.program_access_name, + static_tensor.name, + int(static_tensor.program_access_opcode), + ) + for static_tensor in spec.static_tensors + if static_tensor.program_access_name + ) + + def declared_param_grad_signature(rule_type: str) -> tuple[tuple[str, str, int], ...]: + spec = build_message_rule_backend_spec( + rule_type=rule_type, + kv_group_count=1, + cell_count=2, + ) + return tuple( + (output.logical_name, output.source, int(output.source_index)) for output in spec.param_grad_outputs + ) + + def declared_native_executor_signature( + rule_type: str, + direction: str, + ) -> tuple[int, str, str, str, str, tuple[tuple[str, str], ...]]: + spec = build_message_rule_backend_spec( + rule_type=rule_type, + kv_group_count=1, + cell_count=2, + ) + matches = tuple(executor for executor in spec.native_executors if executor.direction == direction) + assert len(matches) == 1 + executor = matches[0] + return ( + int(executor.executor_id), + executor.executor_name, + executor.strategy_id, + executor.native_callable, + executor.implementation_contract, + tuple((entry.phase, entry.symbol) for entry in executor.cxx_entrypoint_contract), + ) + + def pattern_native_signature(pattern: object) -> tuple[int, str, str, str, str, tuple[tuple[str, str], ...]]: + return ( + int(pattern.executor_id), # type: ignore[attr-defined] + str(pattern.executor_name), # type: ignore[attr-defined] + str(pattern.stable_strategy_id), # type: ignore[attr-defined] + str(pattern.stable_native_callable_id), # type: ignore[attr-defined] + str(pattern.implementation_contract), # type: ignore[attr-defined] + tuple( + zip( + tuple(pattern.cxx_entrypoint_phases), # type: ignore[attr-defined] + tuple(pattern.cxx_entrypoints), # type: ignore[attr-defined] + strict=True, + ) + ), + ) + + patterns = { + pattern.executor_name: pattern + for pattern in temporal_forward_executor_patterns() + if pattern.surface == "message" + } + reverse_patterns = { + pattern.executor_name: pattern + for pattern in temporal_reverse_executor_patterns() + if pattern.surface == "message" + } + + assert patterns["neighborhood_attention_project"].row_signature == compiled_row_signature("dot_product") + assert tuple( + (access.access_name, access.logical_name, access.stable_access_opcode) + for access in patterns["neighborhood_attention_project"].program_accesses + ) == declared_program_access_signature("dot_product") + assert pattern_native_signature(patterns["neighborhood_attention_project"]) == declared_native_executor_signature( + "dot_product", + "forward", + ) + assert pattern_native_signature( + reverse_patterns["neighborhood_attention_project_backward"] + ) == 
declared_native_executor_signature( + "dot_product", + "reverse", + ) + assert patterns["fixed_slot_context_nudge_message"].row_signature == compiled_row_signature( + "dot_product_fixed_slot_context_nudge" + ) + assert patterns["fixed_slot_context_gate_message"].row_signature == compiled_row_signature( + "dot_product_fixed_slot_context_gate" + ) + assert tuple( + (access.access_name, access.logical_name, access.stable_access_opcode) + for access in patterns["fixed_slot_context_nudge_message"].program_accesses + ) == declared_program_access_signature("dot_product_fixed_slot_context_nudge") + assert tuple( + (access.access_name, access.logical_name, access.stable_access_opcode) + for access in patterns["fixed_slot_context_gate_message"].program_accesses + ) == declared_program_access_signature("dot_product_fixed_slot_context_gate") + assert reverse_patterns["fixed_slot_context_nudge_message_backward"].parameter_reducer_kind == ( + build_message_rule_backend_spec( + rule_type="dot_product_fixed_slot_context_nudge", + kv_group_count=1, + cell_count=2, + ).parameter_reducer_kind + ) + assert tuple( + (output.logical_name, output.source, int(output.source_index)) + for output in reverse_patterns["fixed_slot_context_nudge_message_backward"].message_param_grad_outputs + ) == declared_param_grad_signature("dot_product_fixed_slot_context_nudge") + assert tuple( + (output.logical_name, output.source, int(output.source_index)) + for output in reverse_patterns["fixed_slot_context_gate_message_backward"].message_param_grad_outputs + ) == declared_param_grad_signature("dot_product_fixed_slot_context_gate") + assert pattern_native_signature(patterns["fixed_slot_context_nudge_message"]) == declared_native_executor_signature( + "dot_product_fixed_slot_context_nudge", + "forward", + ) + assert pattern_native_signature( + reverse_patterns["fixed_slot_context_nudge_message_backward"] + ) == declared_native_executor_signature( + "dot_product_fixed_slot_context_nudge", + "reverse", + ) + fixed_slot_reducer_spec = build_message_rule_backend_spec( + rule_type="dot_product_fixed_slot_context_nudge", + kv_group_count=1, + cell_count=2, + ).parameter_reducer + assert fixed_slot_reducer_spec is not None + fixed_slot_reducer_pattern = temporal_parameter_reducer_pattern(fixed_slot_reducer_spec.reducer_kind) + assert fixed_slot_reducer_pattern.native_callable == fixed_slot_reducer_spec.native_callable + assert fixed_slot_reducer_pattern.implementation_symbol == fixed_slot_reducer_spec.implementation_symbol + assert fixed_slot_reducer_pattern.active_trainable_roles == fixed_slot_reducer_spec.active_trainable_roles + assert fixed_slot_reducer_pattern.required_static_logical_groups == ( + fixed_slot_reducer_spec.required_static_logical_groups + ) + + +def test_readout_executor_patterns_follow_registered_readout_specs() -> None: + lowering_kind = "mean_readout_project" + readout_program = compile_readout_rule(default_readout_rule_ir(readout_pool="mean", readout_slots=1)) + readout_spec = build_readout_rule_backend_spec(lowering_kind=lowering_kind) + + forward_pattern = next( + pattern + for pattern in temporal_forward_executor_patterns() + if pattern.surface == "readout" + and pattern.stable_strategy_id == "forward.readout.mean_projection_reduction_boundary.v1" + ) + reverse_pattern = next( + pattern + for pattern in temporal_reverse_executor_patterns() + if pattern.surface == "readout" + and pattern.stable_strategy_id == "reverse.readout.mean_projection_reduction_boundary.v1" + ) + forward_native = 
readout_rule_native_executor(lowering_kind=lowering_kind, direction="forward") + reverse_native = readout_rule_native_executor(lowering_kind=lowering_kind, direction="reverse") + declared_accesses = tuple( + ( + static_tensor.program_access_name, + static_tensor.name, + int(static_tensor.program_access_opcode), + ) + for static_tensor in readout_spec.static_tensors + if static_tensor.program_access_name + ) + + assert set(registered_readout_rule_backend_spec_lowering_kinds()) >= { + "mean_readout_project", + "flatten_readout_project", + "attn_readout_project", + "attention_readout_project", + } + assert readout_program.static_tensors == readout_spec.static_tensors + assert readout_program.native_executors == readout_spec.native_executors + assert forward_pattern.row_signature == tuple( + ( + op.primitive, + tuple(op.parameter_inputs), + tuple( + (str(key), str(value)) + for key, value in op.attributes + if str(key) in {"lowering_kind", "pool", "output_boundary"} + ), + ) + for op in readout_program.primitive_ops + ) + assert reverse_pattern.row_signature == forward_pattern.row_signature + assert ( + int(forward_pattern.executor_id), + forward_pattern.executor_name, + forward_pattern.stable_strategy_id, + forward_pattern.stable_native_callable_id, + forward_pattern.implementation_contract, + tuple(zip(forward_pattern.cxx_entrypoint_phases, forward_pattern.cxx_entrypoints, strict=True)), + ) == ( + int(forward_native.executor_id), + forward_native.executor_name, + forward_native.strategy_id, + forward_native.native_callable, + forward_native.implementation_contract, + forward_native.cxx_entrypoint_contract, + ) + assert ( + int(reverse_pattern.executor_id), + reverse_pattern.executor_name, + reverse_pattern.stable_strategy_id, + reverse_pattern.stable_native_callable_id, + reverse_pattern.implementation_contract, + tuple(zip(reverse_pattern.cxx_entrypoint_phases, reverse_pattern.cxx_entrypoints, strict=True)), + ) == ( + int(reverse_native.executor_id), + reverse_native.executor_name, + reverse_native.strategy_id, + reverse_native.native_callable, + reverse_native.implementation_contract, + reverse_native.cxx_entrypoint_contract, + ) + assert ( + tuple( + (access.access_name, access.logical_name, access.stable_access_opcode) + for access in forward_pattern.program_accesses + ) + == declared_accesses + ) + assert ( + tuple( + (access.access_name, access.logical_name, access.stable_access_opcode) + for access in reverse_pattern.program_accesses + ) + == declared_accesses + ) + + +def test_fixed_slot_context_nudge_dot_product_selects_registered_strategy() -> None: + spec = _make_slstm_spec() + declared_message_rule = build_message_rule_ir( + rule_type="dot_product_fixed_slot_context_nudge", + kv_group_count=int(spec.num_kv_groups), + cell_count=int(spec.anatomy.num_cells), + name="declared_context_nudge_dot_product", + ) + runtime = build(replace(spec, message_rule=declared_message_rule)) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + message_rows = [row for row in table.primitive_rows if "surface=message" in row.flat_bucket_identity] + + assert message_rows + assert any(row.primitive == "normalize" for row in message_rows) + assert temporal_primitive_opcode("normalize") > 0 + assert any( + ("compiled_lowering_kind", 
DOT_PRODUCT_FIXED_SLOT_CONTEXT_NUDGE) in row.attributes for row in message_rows + ) + forward_rows = temporal_forward_executor_rows_tensor(table) + reverse_rows = temporal_reverse_executor_rows_tensor(table) + assert int(forward_rows[0, 0].item()) == 5 + assert int(reverse_rows[0, 0].item()) == 5 + assert "executor=fixed_slot_context_nudge_message" in "\n".join(temporal_forward_executor_summaries(table)) + assert "reverse_executor=fixed_slot_context_nudge_message_backward" in "\n".join( + temporal_reverse_executor_summaries(table) + ) + + forward_plan = build_temporal_forward_executable_plan(table) + backward_plan = build_temporal_backward_executable_plan(table) + assert forward_plan.strategy_legality_status == "legal" + assert backward_plan.strategy_legality_status == "legal" + forward_review = "\n".join(forward_plan.review_summary) + backward_review = "\n".join(backward_plan.review_summary) + assert "strategy_legality_blocker=UNVERIFIED_REWRITE" not in forward_review + assert "strategy_legality_blocker=UNVERIFIED_REWRITE" not in backward_review + assert "registered_fixed_slot_context_nudge_message_native_callable_pending" not in forward_review + assert "registered_fixed_slot_context_nudge_message_backward_native_callable_pending" not in backward_review + + selection_report = build_temporal_strategy_selection_report(table) + blocked_text = "\n".join(selection_report.blocked_summaries) + assert "strategy_id=forward.message.fixed_slot_context_nudge.v1" not in blocked_text + assert "strategy_id=reverse.message.fixed_slot_context_nudge.v1" not in blocked_text + assert "rejection=UNVERIFIED_REWRITE" not in blocked_text + + native_schema_text = "\n".join(temporal_native_callable_binding_schema_summaries()) + assert "callable=native.forward.msg_fixed_slot_context_nudge.v1" in native_schema_text + assert "callable=native.reverse.msg_fixed_slot_context_nudge.v1" in native_schema_text + transition_grad_binding_text = "\n".join(build_temporal_transition_param_grad_binding_plan(table).summaries) + assert "sources=static_tensor:fused_recurrent_value_to_cell_weight" in transition_grad_binding_text + assert "static_tensor:value_to_cell_weight" in transition_grad_binding_text + assert "selected_static_source=message_to_cell_weight" in transition_grad_binding_text + + +def test_fixed_slot_context_gate_dot_product_selects_registered_strategy_with_access_remap() -> None: + spec = _make_slstm_spec() + declared_message_rule = build_message_rule_ir( + rule_type="dot_product_fixed_slot_context_gate", + kv_group_count=int(spec.num_kv_groups), + cell_count=int(spec.anatomy.num_cells), + name="declared_context_gate_dot_product", + ) + runtime = build(replace(spec, message_rule=declared_message_rule)) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + message_rows = [row for row in table.primitive_rows if "surface=message" in row.flat_bucket_identity] + + assert message_rows + assert any( + ("compiled_lowering_kind", DOT_PRODUCT_FIXED_SLOT_CONTEXT_GATE) in row.attributes for row in message_rows + ) + forward_rows = temporal_forward_executor_rows_tensor(table) + reverse_rows = temporal_reverse_executor_rows_tensor(table) + assert int(forward_rows[0, 0].item()) == 10 + assert int(reverse_rows[0, 0].item()) == 10 + assert 
"executor=fixed_slot_context_gate_message" in "\n".join(temporal_forward_executor_summaries(table)) + assert "reverse_executor=fixed_slot_context_gate_message_backward" in "\n".join( + temporal_reverse_executor_summaries(table) + ) + + forward_plan = build_temporal_forward_executable_plan(table) + backward_plan = build_temporal_backward_executable_plan(table) + assert forward_plan.strategy_legality_status == "legal" + assert backward_plan.strategy_legality_status == "legal" + + selection_report = build_temporal_strategy_selection_report(table) + blocked_text = "\n".join(selection_report.blocked_summaries) + assert "strategy_id=forward.message.fixed_slot_context_gate.v1" not in blocked_text + assert "strategy_id=reverse.message.fixed_slot_context_gate.v1" not in blocked_text + + binding_plan = build_temporal_forward_executor_binding_plan(table) + message_bindings = tuple( + binding + for binding in binding_plan.bindings + if binding.executor_name == "fixed_slot_context_gate_message" + and binding.surface == "message" + and binding.binding_kind == "parameter" + ) + assert "message_query_context_gate" in {binding.logical_name for binding in message_bindings} + assert "message_query_nudge_scale" not in {binding.logical_name for binding in message_bindings} + + native_schema_text = "\n".join(temporal_native_callable_binding_schema_summaries()) + assert "callable=native.forward.msg_fixed_slot_context_gate.v1" in native_schema_text + assert "callable=native.reverse.msg_fixed_slot_context_gate.v1" in native_schema_text + assert ( + "callable=native.forward.msg_fixed_slot_context_gate.v1,direction=forward,surface=message," + "primitive=linear,binding_kind=parameter,logical_name=message_query_context_scalar" + ) in native_schema_text + + +def test_fixed_slot_context_nudge_message_parameters_materialize_from_compiler_bindings() -> None: + spec = _make_slstm_spec() + declared_message_rule = build_message_rule_ir( + rule_type="dot_product_fixed_slot_context_nudge", + kv_group_count=int(spec.num_kv_groups), + cell_count=int(spec.anatomy.num_cells), + name="declared_context_nudge_dot_product", + ) + runtime = build(replace(spec, message_rule=declared_message_rule)) + native_static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=False, + ) + native_recurrent_value_weight = native_static_tensors["recurrent_sender_value_weight"] + assert torch.is_tensor(native_recurrent_value_weight) + assert int(native_recurrent_value_weight.numel()) > 0 + assert tuple(native_recurrent_value_weight.shape) == ( + int(runtime.recurrent_cell_idx.numel()), + runtime.hidden_size, + runtime.value_dim, + ) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + binding_plan = build_temporal_forward_executor_binding_plan(table) + message_bindings = tuple( + binding + for binding in binding_plan.bindings + if binding.executor_name == "fixed_slot_context_nudge_message" + and binding.surface == "message" + and binding.binding_kind == "parameter" + ) + + assert { + "message_query_slot_weight", + "message_query_nudge_scale", + "message_sender_slot_key_weight", + "message_sender_context_key", + "input_sender_value_weight", + "input_group_value_weight", + "recurrent_sender_value_weight", + 
"message_output_weight", + } <= {binding.logical_name for binding in message_bindings} + tensors, rows = surface_parameter_tensor_table( + runtime, + handles=(SimpleNamespace(bindings=message_bindings),), + static_tensors=cached_static_tensors, + reference=torch.empty(1), + ) + slot_by_binding_index = {int(row[0].item()): int(row[1].item()) for row in rows} + tensor_by_logical = { + binding.logical_name: tensors[slot_by_binding_index[int(binding.binding_index)]] for binding in message_bindings + } + + recurrent_count = int(runtime.recurrent_cell_idx.numel()) + sender_count = int(runtime.sender_cell_idx.numel()) + input_count = int(runtime.input_cell_idx.numel()) + assert temporal_message_output_dim(runtime) == int(runtime.d_msg) + assert tuple(tensor_by_logical["message_query_slot_weight"].shape) == (recurrent_count, runtime.head_dim) + assert tuple(tensor_by_logical["message_query_nudge_scale"].shape) == (1,) + assert tuple(tensor_by_logical["message_sender_slot_key_weight"].shape) == (sender_count, runtime.head_dim) + assert tuple(tensor_by_logical["message_sender_context_key"].shape) == (sender_count, runtime.head_dim) + assert tuple(tensor_by_logical["input_sender_value_weight"].shape) == ( + input_count, + runtime.hidden_size, + runtime.value_dim, + ) + assert tensor_by_logical["input_group_value_weight"].dim() in {1, 3} + assert tuple(tensor_by_logical["recurrent_sender_value_weight"].shape) == ( + recurrent_count, + runtime.hidden_size, + runtime.value_dim, + ) + assert tuple(tensor_by_logical["message_output_weight"].shape) == (runtime.d_msg, runtime.value_dim) + + +def test_fixed_slot_context_gate_message_parameters_materialize_from_compiler_bindings() -> None: + spec = _make_slstm_spec() + declared_message_rule = build_message_rule_ir( + rule_type="dot_product_fixed_slot_context_gate", + kv_group_count=int(spec.num_kv_groups), + cell_count=int(spec.anatomy.num_cells), + name="declared_context_gate_dot_product", + ) + runtime = build(replace(spec, message_rule=declared_message_rule)) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + binding_plan = build_temporal_forward_executor_binding_plan(table) + message_bindings = tuple( + binding + for binding in binding_plan.bindings + if binding.executor_name == "fixed_slot_context_gate_message" + and binding.surface == "message" + and binding.binding_kind == "parameter" + ) + + assert { + "message_query_slot_weight", + "message_query_context_gate", + "message_sender_slot_key_weight", + "message_sender_context_key", + "input_sender_value_weight", + "input_group_value_weight", + "recurrent_sender_value_weight", + "message_output_weight", + } <= {binding.logical_name for binding in message_bindings} + tensors, rows = surface_parameter_tensor_table( + runtime, + handles=(SimpleNamespace(bindings=message_bindings),), + static_tensors=cached_static_tensors, + reference=torch.empty(1), + ) + slot_by_binding_index = {int(row[0].item()): int(row[1].item()) for row in rows} + tensor_by_logical = { + binding.logical_name: tensors[slot_by_binding_index[int(binding.binding_index)]] for binding in message_bindings + } + + recurrent_count = int(runtime.recurrent_cell_idx.numel()) + sender_count = int(runtime.sender_cell_idx.numel()) + input_count = int(runtime.input_cell_idx.numel()) + 
assert tuple(tensor_by_logical["message_query_slot_weight"].shape) == (recurrent_count, runtime.head_dim) + assert tuple(tensor_by_logical["message_query_context_gate"].shape) == (1,) + assert tuple(tensor_by_logical["message_sender_slot_key_weight"].shape) == (sender_count, runtime.head_dim) + assert tuple(tensor_by_logical["message_sender_context_key"].shape) == (sender_count, runtime.head_dim) + assert tuple(tensor_by_logical["input_sender_value_weight"].shape) == ( + input_count, + runtime.hidden_size, + runtime.value_dim, + ) + assert tensor_by_logical["input_group_value_weight"].dim() in {1, 3} + assert tuple(tensor_by_logical["recurrent_sender_value_weight"].shape) == ( + recurrent_count, + runtime.hidden_size, + runtime.value_dim, + ) + assert tuple(tensor_by_logical["message_output_weight"].shape) == (runtime.d_msg, runtime.value_dim) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for registered temporal program training") +def test_fixed_slot_context_nudge_trains_through_registered_fused_temporal_program() -> None: + spec = _make_slstm_spec(hidden_size=8, head_dim=8) + declared_message_rule = build_message_rule_ir( + rule_type="dot_product_fixed_slot_context_nudge", + kv_group_count=int(spec.num_kv_groups), + cell_count=int(spec.anatomy.num_cells), + name="declared_context_nudge_dot_product", + ) + runtime = build(replace(spec, message_rule=declared_message_rule)).cuda() + boundary_seq = torch.randn( + 2, + 2, + int(runtime.input_cell_idx.numel()), + int(runtime.hidden_size), + device="cuda", + requires_grad=True, + ) + + output_seq, _state = runtime.forward_output_cells_for_readout( + boundary_input=boundary_seq, + state=None, + resets=None, + k=None, + training_semantics=True, + materialize_final_state=False, + readout_output_boundary="pooled", + ) + output_seq.square().mean().backward() + + assert boundary_seq.grad is not None + assert runtime.msg_to_cell.weight.grad is not None + assert runtime.msg_out.weight.grad is not None + assert runtime.message_query_nudge_scale.grad is not None + assert getattr(runtime, "_last_flat_bucket_temporal_message_strategy_extra_param_grad_roles", ()) == ( + "grad_input_key_bank", + "grad_output_weight", + "grad_query_context_scalar", + "grad_recurrent_key_bank", + ) + assert "registered_temporal_fused_backward_program_cuda" in str( + getattr(runtime, "_last_flat_bucket_temporal_fused_backward_program_output_grad", "") + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for registered temporal program training") +def test_fixed_slot_context_gate_trains_through_registered_fused_temporal_program() -> None: + spec = _make_slstm_spec(hidden_size=8, head_dim=8) + declared_message_rule = build_message_rule_ir( + rule_type="dot_product_fixed_slot_context_gate", + kv_group_count=int(spec.num_kv_groups), + cell_count=int(spec.anatomy.num_cells), + name="declared_context_gate_dot_product", + ) + runtime = build(replace(spec, message_rule=declared_message_rule)).cuda() + boundary_seq = torch.randn( + 2, + 2, + int(runtime.input_cell_idx.numel()), + int(runtime.hidden_size), + device="cuda", + requires_grad=True, + ) + + output_seq, _state = runtime.forward_output_cells_for_readout( + boundary_input=boundary_seq, + state=None, + resets=None, + k=None, + training_semantics=True, + materialize_final_state=False, + readout_output_boundary="pooled", + ) + output_seq.square().mean().backward() + + assert boundary_seq.grad is not None + assert runtime.msg_to_cell.weight.grad is not None + assert 
runtime.msg_out.weight.grad is not None + assert runtime.message_query_context_gate.grad is not None + assert not hasattr(runtime, "message_query_nudge_scale") + assert getattr(runtime, "_last_flat_bucket_temporal_message_strategy_extra_param_grad_slots", 0) == 4 + assert "registered_temporal_fused_backward_program_cuda" in str( + getattr(runtime, "_last_flat_bucket_temporal_fused_backward_program_output_grad", "") + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for registered temporal program parity") +def test_fixed_slot_context_gate_matches_nudge_when_scalar_binding_is_equal() -> None: + spec = _make_slstm_spec(hidden_size=8, head_dim=8) + nudge_rule = build_message_rule_ir( + rule_type="dot_product_fixed_slot_context_nudge", + kv_group_count=int(spec.num_kv_groups), + cell_count=int(spec.anatomy.num_cells), + name="declared_context_nudge_dot_product", + ) + gate_rule = build_message_rule_ir( + rule_type="dot_product_fixed_slot_context_gate", + kv_group_count=int(spec.num_kv_groups), + cell_count=int(spec.anatomy.num_cells), + name="declared_context_gate_dot_product", + ) + nudge_runtime = build(replace(spec, message_rule=nudge_rule)).cuda() + gate_runtime = build(replace(spec, message_rule=gate_rule)).cuda() + nudge_state = nudge_runtime.state_dict() + gate_state = gate_runtime.state_dict() + for name, tensor in gate_state.items(): + if name == "message_query_context_gate": + tensor.copy_(nudge_state["message_query_nudge_scale"]) + elif name in nudge_state and tuple(tensor.shape) == tuple(nudge_state[name].shape): + tensor.copy_(nudge_state[name]) + gate_runtime.load_state_dict(gate_state) + generator = torch.Generator(device="cuda").manual_seed(5151) + boundary_nudge = torch.randn( + 2, + 2, + int(nudge_runtime.input_cell_idx.numel()), + int(nudge_runtime.hidden_size), + device="cuda", + generator=generator, + requires_grad=True, + ) + boundary_gate = boundary_nudge.detach().clone().requires_grad_(True) + + output_nudge, _state_nudge = nudge_runtime.forward_output_cells_for_readout( + boundary_input=boundary_nudge, + state=None, + resets=None, + k=None, + training_semantics=True, + materialize_final_state=False, + readout_output_boundary="pooled", + ) + output_gate, _state_gate = gate_runtime.forward_output_cells_for_readout( + boundary_input=boundary_gate, + state=None, + resets=None, + k=None, + training_semantics=True, + materialize_final_state=False, + readout_output_boundary="pooled", + ) + torch.testing.assert_close(output_gate, output_nudge, rtol=1e-4, atol=2e-5) + + output_nudge.square().mean().backward() + output_gate.square().mean().backward() + + assert boundary_nudge.grad is not None + assert boundary_gate.grad is not None + torch.testing.assert_close(boundary_gate.grad, boundary_nudge.grad, rtol=3e-3, atol=3e-3) + nudge_grads = {name: param.grad for name, param in nudge_runtime.named_parameters() if param.grad is not None} + gate_grads = {name: param.grad for name, param in gate_runtime.named_parameters() if param.grad is not None} + for name in sorted(set(nudge_grads) & set(gate_grads)): + torch.testing.assert_close(gate_grads[name], nudge_grads[name], rtol=3e-3, atol=3e-3) + torch.testing.assert_close( + gate_grads["message_query_context_gate"], + nudge_grads["message_query_nudge_scale"], + rtol=3e-3, + atol=3e-3, + ) + assert getattr(gate_runtime, "_last_flat_bucket_temporal_message_strategy_extra_param_grad_roles", ()) == ( + "grad_input_key_bank", + "grad_output_weight", + "grad_query_context_scalar", + "grad_recurrent_key_bank", 
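+ # Same extra param-grad roles as the nudge runtime: the gate's scalar gradient is reported under grad_query_context_scalar.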
+ ) + assert "registered_temporal_fused_backward_program_cuda" in str( + getattr(gate_runtime, "_last_flat_bucket_temporal_fused_backward_program_output_grad", "") + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for registered temporal program training") +def test_generic_transition_primitive_dag_trains_through_registered_fused_temporal_program() -> None: + spec = _make_slstm_spec(hidden_size=8, head_dim=8) + runtime = build(_with_context_nudge_message_rule(spec)).cuda() + slstm_spec = build_cell_backend_spec( + cell_type="slstm", + hidden_size=8, + d_public=8, + d_msg=8, + head_dim=8, + value_dim=8, + ) + primitive_dag_spec = replace( + slstm_spec, + transition_ir=CellTransitionIR( + state_inputs=(), + message_inputs=("aggregated_message",), + parameter_inputs=("value_to_state_weight", "recurrent_bias", "outnorm_weight", "outnorm_eps"), + ops=( + TransitionOp( + "linear", + ("aggregated_message", "value_to_state_weight", "recurrent_bias"), + ("transition_input",), + ), + TransitionOp( + "norm_or_identity", + ("transition_input", "outnorm_weight", "outnorm_eps"), + ("normalized",), + ), + TransitionOp("tanh", ("normalized",), ("public_y",)), + ), + state_outputs=(), + public_outputs=("public_y",), + recompute_outputs=("transition_input", "normalized", "public_y"), + backward_decomposition=(), + ), + ) + runtime._backend_ir = replace( # noqa: SLF001 + runtime.backend_ir, + transition_programs=(compile_transition_program(primitive_dag_spec, binding_slot=0),), + ) + static_tensors = runtime._get_inference_static_tensors( # noqa: SLF001 + device=torch.device("cuda"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + executor_program = build_registered_temporal_executor_program(runtime, static_tensors) + boundary_seq = torch.randn( + 2, + 2, + int(runtime.input_cell_idx.numel()), + int(runtime.hidden_size), + device="cuda", + requires_grad=True, + ) + + output_seq, _state = runtime.forward_output_cells_for_readout( + boundary_input=boundary_seq, + state=None, + resets=None, + k=None, + training_semantics=True, + materialize_final_state=False, + readout_output_boundary="pooled", + ) + output_seq.square().mean().backward() + + transition_forward = [ + handle.executor_name for handle in executor_program.forward_handles if handle.surface == "transition" + ] + transition_reverse = [ + handle.executor_name for handle in executor_program.reverse_handles if handle.surface == "transition" + ] + + assert transition_forward == [ + "transition_linear_primitive", + "transition_norm_or_identity_primitive", + "tanh_transition", + ] + assert transition_reverse == [ + "tanh_transition_backward", + "transition_norm_or_identity_primitive_backward", + "transition_linear_primitive_backward", + ] + assert boundary_seq.grad is not None + assert runtime.msg_to_cell.weight.grad is not None + assert runtime.cell_bias_proj.weight.grad is not None + assert runtime.population_modules.slstm.outnorm_weight_base.grad is not None + assert "registered_temporal_fused_backward_program_cuda" in str( + getattr(runtime, "_last_flat_bucket_temporal_fused_backward_program_output_grad", "") + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for stateful primitive-DAG training") +def test_stateful_transition_primitive_dag_trains_with_resets_through_registered_temporal_program() -> None: + spec = _make_slstm_spec(hidden_size=8, head_dim=8) + runtime = build(_with_context_nudge_message_rule(spec)).cuda() + runtime._backend_ir = replace( # noqa: SLF001 + 
runtime.backend_ir, + transition_programs=(compile_transition_program(_stateful_tanh_transition_spec(), binding_slot=0),), + ) + static_tensors = runtime._get_inference_static_tensors( # noqa: SLF001 + device=torch.device("cuda"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + executor_program = build_registered_temporal_executor_program(runtime, static_tensors) + state = runtime.init_state(batch=2, device="cuda", dtype=torch.float32) + assert tuple(state["slstm"].keys()) == ("mem",) + state["slstm"]["mem"].normal_(mean=0.0, std=0.25) + state["slstm"]["mem"].requires_grad_(True) + boundary_seq = torch.randn( + 2, + 3, + int(runtime.input_cell_idx.numel()), + int(runtime.hidden_size), + device="cuda", + requires_grad=True, + ) + resets = torch.tensor( + [ + [False, False, False], + [False, True, False], + ], + dtype=torch.bool, + device="cuda", + ) + + output_seq, final_state = runtime.forward_output_cells_for_readout( + boundary_input=boundary_seq, + state=state, + resets=resets, + k=2, + training_semantics=True, + materialize_final_state=True, + output_boundary="sequence", + ) + assert tuple(final_state["slstm"].keys()) == ("mem",) + loss = output_seq.square().mean() + final_state["slstm"]["mem"].square().mean() + loss.backward() + + transition_forward = [ + handle.executor_name for handle in executor_program.forward_handles if handle.surface == "transition" + ] + transition_reverse = [ + handle.executor_name for handle in executor_program.reverse_handles if handle.surface == "transition" + ] + record = runtime.last_backend_execution + assert record is not None + assert record.launch_scan_implementations == ("registered_temporal_fused_forward_program_cuda",) + assert "flat_bucket_temporal_reverse_scan_owner:registered_reverse_program_window" in ( + record.workspace_aliases + record.backward_workspace_aliases + ) + assert transition_forward == ["tanh_transition", "tanh_transition"] + assert transition_reverse == ["tanh_transition_backward", "tanh_transition_backward"] + assert boundary_seq.grad is not None + assert float(boundary_seq.grad.abs().sum().item()) > 0.0 + assert state["slstm"]["mem"].grad is not None + assert float(state["slstm"]["mem"].grad.abs().sum().item()) > 0.0 + assert "registered_temporal_fused_backward_program_cuda" in str( + getattr(runtime, "_last_flat_bucket_temporal_fused_backward_program_output_grad", "") + ) + + +def test_message_rule_builder_rejects_unregistered_rule_type() -> None: + with pytest.raises(ValueError, match="Unsupported Fabric message rule backend"): + build_message_rule_ir(rule_type="not_registered", kv_group_count=1, cell_count=16) + + +def test_backend_ir_uses_declared_spec_message_rule_not_default_substitution() -> None: + spec = _make_slstm_spec() + declared_message_rule = replace( + default_dot_product_message_rule_ir( + kv_group_count=int(spec.num_kv_groups), + cell_count=int(spec.anatomy.num_cells), + ), + name="declared_dot_product", + ) + runtime = build(replace(spec, message_rule=declared_message_rule)) + + assert runtime.backend_ir.message_rule.name == "declared_dot_product" + assert runtime.backend_ir.message_rule is declared_message_rule + assert runtime.backend_ir.message_program.rule_name == "declared_dot_product" + planned = runtime.plan_backend_execution( + batch_size=1, + time_steps=1, + inner_steps=1, + training=False, + surface_key="slstm_recurrence", + ) + assert all(plan.message_rule_name == "declared_dot_product" for plan in planned.bucket_plans) + assert all( + plan.message_rule_lowering_kind == 
"dot_product_fixed_slot_context_nudge" for plan in planned.bucket_plans + ) + + +def test_temporal_message_rows_come_from_compiled_message_program() -> None: + spec = _make_slstm_spec() + declared_message_rule = replace( + default_dot_product_message_rule_ir( + kv_group_count=int(spec.num_kv_groups), + cell_count=int(spec.anatomy.num_cells), + ), + name="declared_dot_product", + ) + runtime = build(replace(spec, message_rule=declared_message_rule)) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + + message_rows = [row for row in table.primitive_rows if "surface=message" in row.flat_bucket_identity] + assert message_rows + assert all("rule=declared_dot_product" in row.flat_bucket_identity for row in message_rows) + assert any(("compiled_rule", "declared_dot_product") in row.attributes for row in message_rows) + assert any( + ("compiled_lowering_kind", "dot_product_fixed_slot_context_nudge") in row.attributes for row in message_rows + ) + assert [row.parameter_inputs for row in message_rows] == [ + op.parameter_bindings for op in runtime.backend_ir.message_program.primitive_ops + ] + assert {parameter for row in message_rows for parameter in row.parameter_inputs} == { + "message_query_slot_weight", + "message_query_nudge_scale", + "message_sender_slot_key_weight", + "message_sender_context_key", + "input_sender_value_weight", + "input_group_value_weight", + "recurrent_sender_value_weight", + "message_output_weight", + } + assert [row.primitive for row in message_rows] == [ + op.primitive for op in runtime.backend_ir.message_program.primitive_ops + ] + + +def test_temporal_transition_rows_come_from_compiled_transition_programs() -> None: + runtime = build(_make_alias_population_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + bucket_plan = temporal_bucket_plan(runtime, cached_static_tensors) + + for bucket_ordinal, bucket in enumerate(bucket_plan.flat_buckets): + program = runtime.backend_ir.transition_program_for_binding_slot(bucket.binding_slot) + transition_rows = [ + row + for row in table.primitive_rows + if row.bucket_ordinal == bucket_ordinal and "surface=transition" in row.flat_bucket_identity + ] + assert transition_rows + assert [row.primitive for row in transition_rows] == [op.primitive for op in program.primitive_ops] + expected_parameter_inputs = [] + for op in program.primitive_ops: + parameter_inputs = tuple(op.parameter_inputs) + if op.primitive == "norm_or_identity": + parameter_inputs = (*parameter_inputs, "outnorm_eps") + elif op.primitive in {"diag_rtu", "diagonal_recurrence"}: + parameter_inputs = (*parameter_inputs, "activation_id") + expected_parameter_inputs.append(tuple(dict.fromkeys(parameter_inputs))) + assert [row.parameter_inputs for row in transition_rows] == expected_parameter_inputs + assert all( + ("compiled_transition_lowering_kind", program.lowering_kind) in row.attributes for row in transition_rows + ) + assert all( + ("compiled_transition_binding_slot", str(int(bucket.binding_slot))) in row.attributes + for row 
in transition_rows + ) + + +def test_slstm_transition_rows_declare_input_projection_executed_by_temporal_scan() -> None: + runtime = build(_with_context_nudge_message_rule(_make_slstm_spec())) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + program = runtime.backend_ir.transition_program_for_binding_slot(0) + transition_rows = [ + row + for row in table.primitive_rows + if row.bucket_ordinal == 0 and "surface=transition" in row.flat_bucket_identity + ] + + assert [op.primitive for op in program.primitive_ops] == [ + "linear", + "linear", + "matmul", + "gated_logspace_recurrence", + "norm_or_identity", + ] + assert program.primitive_ops[0].inputs == ( + "aggregated_message", + "value_to_state_weight", + "recurrent_bias", + ) + assert transition_rows[0].parameter_inputs == ("value_to_state_weight", "recurrent_bias") + assert transition_rows[1].parameter_inputs == ("gate_weight", "bias") + + +def test_temporal_transition_parameters_resolve_from_compiled_bindings() -> None: + runtime = build(_with_context_nudge_message_rule(_make_slstm_spec())) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + bucket = temporal_bucket_plan(runtime, cached_static_tensors).flat_buckets[0] + population_materialized = bucket.static_tensors["population_materialized"] + assert isinstance(population_materialized, dict) + population_params = population_materialized[bucket.name] + assert isinstance(population_params, dict) + + program = compiled_transition_program_for_bucket(runtime, bucket) + value_to_state = resolve_transition_parameter( + program, + population_params, + bucket.static_tensors, + "value_to_state_weight", + num_receivers=bucket.count, + ) + recurrent_bias = resolve_transition_parameter( + program, + population_params, + bucket.static_tensors, + "recurrent_bias", + num_receivers=bucket.count, + ) + + expected_value_to_state = bucket.static_tensors.get("fused_recurrent_value_to_cell_weight") + if not torch.is_tensor(expected_value_to_state): + expected_value_to_state = bucket.static_tensors.get("message_to_cell_weight") + if not torch.is_tensor(expected_value_to_state): + expected_value_to_state = bucket.static_tensors["value_to_cell_weight"] + expected_recurrent_bias = bucket.static_tensors.get("fused_recurrent_cell_bias") + if not torch.is_tensor(expected_recurrent_bias): + expected_recurrent_bias = bucket.static_tensors["recurrent_cell_bias"] + assert value_to_state.data_ptr() == expected_value_to_state.data_ptr() + assert recurrent_bias.data_ptr() == expected_recurrent_bias.data_ptr() + with pytest.raises(RuntimeError, match="has no compiled binding"): + resolve_transition_parameter( + program, + population_params, + bucket.static_tensors, + "value_to_cell_weight", + num_receivers=bucket.count, + ) + + +def test_temporal_tensor_binding_rows_are_compiler_products() -> None: + runtime = build(_make_mixed_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, 
static_tensors) + + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + bindings = table.tensor_bindings + + assert bindings + assert len(bindings) > len(table.primitive_rows) + message_param_rows = [ + binding for binding in bindings if binding.surface == "message" and binding.binding_kind == "parameter" + ] + transition_param_rows = [ + binding for binding in bindings if binding.surface == "transition" and binding.binding_kind == "parameter" + ] + readout_param_rows = [ + binding for binding in bindings if binding.surface == "readout" and binding.binding_kind == "parameter" + ] + parameter_reduction_rows = [ + binding + for binding in bindings + if binding.surface == "parameter_reduction" and binding.binding_kind == "parameter" + ] + + assert {binding.logical_name for binding in message_param_rows} == { + "message_query_slot_weight", + "message_query_nudge_scale", + "message_sender_slot_key_weight", + "message_sender_context_key", + "input_sender_value_weight", + "input_group_value_weight", + "recurrent_sender_value_weight", + "message_output_weight", + } + assert all( + any(source.startswith("message_parameter:") for source in row.source_bindings) for row in message_param_rows + ) + assert any( + "static_tensor:message_query_slot_weight" in row.source_bindings + for row in message_param_rows + if row.logical_name == "message_query_slot_weight" + ) + assert any( + "static_tensor:recurrent_sender_value_weight" in row.source_bindings + for row in message_param_rows + if row.logical_name == "recurrent_sender_value_weight" + ) + assert any( + "static_tensor:input_sender_value_weight" in row.source_bindings + for row in message_param_rows + if row.logical_name == "input_sender_value_weight" + ) + assert any( + "static_tensor:input_group_value_weight" in row.source_bindings + for row in message_param_rows + if row.logical_name == "input_group_value_weight" + ) + assert {"value_to_state_weight", "recurrent_bias"}.issubset( + {binding.logical_name for binding in transition_param_rows} + ) + assert {"output_q", "value_to_output_weight", "output_cell_bias"}.issubset( + {binding.logical_name for binding in readout_param_rows} + ) + assert any("static_tensor:output_q" in row.source_bindings for row in readout_param_rows) + assert any("static_tensor:value_to_output_weight" in row.source_bindings for row in readout_param_rows) + assert any("runtime_attr:output_cell_bias" in row.source_bindings for row in readout_param_rows) + assert parameter_reduction_rows + assert "value_to_cell_weight" not in {binding.logical_name for binding in bindings} + assert temporal_tensor_binding_summaries(table) + projection = validate_temporal_supported_scan_binding_projection(table) + assert "message=fixed_slot_context_nudge_message" in projection + assert "readout=mean_projection_reduction_boundary" in projection + + +def test_forward_executor_bindings_fail_when_compiler_binding_is_missing() -> None: + runtime = build(_with_context_nudge_message_rule(_make_slstm_spec())) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + table_without_message_output_binding = replace( + table, + tensor_bindings=tuple( + binding + for binding in table.tensor_bindings + if not ( + binding.surface == "message" + and binding.binding_kind == 
"parameter" + and binding.logical_name == "message_output_weight" + ) + ), + ) + + with pytest.raises(RuntimeError, match="executor binding plan.*message_output_weight"): + build_temporal_forward_executor_binding_plan(table_without_message_output_binding) + + +def test_reverse_executor_bindings_fail_when_compiler_binding_is_missing() -> None: + runtime = build(_with_context_nudge_message_rule(_make_slstm_spec())) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + table_without_query_slot_binding = replace( + table, + tensor_bindings=tuple( + binding + for binding in table.tensor_bindings + if not ( + binding.surface == "message" + and binding.binding_kind == "parameter" + and binding.logical_name == "message_query_slot_weight" + ) + ), + ) + + binding_plan = build_temporal_reverse_executor_binding_plan(table_without_query_slot_binding) + blocker_text = "\n".join(binding_plan.blocker_summaries) + assert binding_plan.has_blockers + assert "code=MISSING_REQUIRED_BINDING" in blocker_text + assert "message_query_slot_weight" in blocker_text + + +def test_temporal_tensor_binding_rows_fail_when_compiled_signature_drifts() -> None: + runtime = build(_with_context_nudge_message_rule(_make_slstm_spec())) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + message_program = runtime.backend_ir.message_program + first_op = message_program.primitive_ops[0] + broken_first_op = replace( + first_op, + inputs=(*first_op.inputs, "missing_declared_weight"), + parameter_bindings=("missing_declared_weight",), + ) + runtime._backend_ir = replace( + runtime.backend_ir, + message_program=replace(message_program, primitive_ops=(broken_first_op, *message_program.primitive_ops[1:])), + ) + + with pytest.raises(RuntimeError, match="missing compiled parameter binding"): + build_temporal_primitive_table_plan(runtime, cached_static_tensors) + + +def test_surface_parameter_tensor_table_fails_closed_on_unresolved_compiler_source() -> None: + binding = TemporalExecutorTensorBinding( + direction="forward", + executor_row_index=0, + executor_id=1, + executor_name="registered.forward.compiler_handler", + surface="message", + bucket_ordinal=-1, + receiver_start=0, + receiver_count=0, + local_binding_index=0, + primitive_row_index=0, + primitive="neighborhood_attention", + binding_index=17, + binding_kind="parameter", + logical_name="declared_query_weight", + source_bindings=("static_tensor:missing_query", "runtime_attr:missing_runtime_query"), + ) + + with pytest.raises(RuntimeError, match="could not resolve compiler parameter binding") as error: + surface_parameter_tensor_table( + SimpleNamespace(), + handles=(SimpleNamespace(bindings=(binding,)),), + static_tensors={}, + reference=torch.empty(1), + ) + + message = str(error.value) + assert "declared_query_weight" in message + assert "missing_query" in message + assert "missing_runtime_query" in message + + +def test_surface_parameter_tensor_table_materializes_optional_message_projection_as_empty() -> None: + binding = TemporalExecutorTensorBinding( + direction="forward", + executor_row_index=0, + executor_id=1, + 
executor_name="neighborhood_attention_project", + surface="message", + bucket_ordinal=-1, + receiver_start=0, + receiver_count=0, + local_binding_index=1, + primitive_row_index=0, + primitive="linear", + binding_index=18, + binding_kind="parameter", + logical_name="input_sender_kv_weight", + source_bindings=("static_tensor:input_sender_input_to_kv_weight",), + ) + + tensors, rows = surface_parameter_tensor_table( + SimpleNamespace(), + handles=(SimpleNamespace(bindings=(binding,)),), + static_tensors={"input_sender_input_to_kv_weight": None}, + reference=torch.empty(2, 3, dtype=torch.float32), + ) + + assert len(tensors) == 1 + assert tensors[0].shape == (0,) + assert tensors[0].dtype == torch.float32 + assert rows.tolist() == [[18, 0, 0, 1]] + + +def test_temporal_scan_binding_projection_fails_for_consistent_message_signature_drift() -> None: + runtime = build(_make_slstm_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + message_rule = runtime.backend_ir.message_rule + message_program = runtime.backend_ir.message_program + first_op = message_program.primitive_ops[0] + runtime._backend_ir = replace( + runtime.backend_ir, + message_rule=replace( + message_rule, + parameters=( + replace(message_rule.parameters[0], name="alt_query_weight"), + *message_rule.parameters[1:], + ), + ), + message_program=replace( + message_program, + primitive_ops=( + replace( + first_op, + inputs=("receiver_slot", "alt_query_weight"), + parameter_bindings=("alt_query_weight",), + ), + *message_program.primitive_ops[1:], + ), + ), + ) + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + + with pytest.raises(RuntimeError, match="message binding rows"): + validate_temporal_supported_scan_binding_projection(table) + + +def test_temporal_scan_binding_projection_fails_for_readout_signature_drift() -> None: + runtime = build(_make_slstm_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + readout_program = runtime.backend_ir.readout_program + first_op = readout_program.primitive_ops[0] + runtime._backend_ir = replace( + runtime.backend_ir, + readout_program=replace( + readout_program, + primitive_ops=( + replace( + first_op, + inputs=("public_state", "alt_readout_weight"), + parameter_inputs=("alt_readout_weight",), + ), + *readout_program.primitive_ops[1:], + ), + ), + ) + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + + with pytest.raises(RuntimeError, match="readout binding rows"): + validate_temporal_supported_scan_binding_projection(table) + + +def test_temporal_readout_rows_come_from_compiled_readout_program() -> None: + runtime = build(_make_slstm_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + + readout_rows = [ + row + for row in table.primitive_rows + if "surface=readout" in row.flat_bucket_identity or "surface=readout_boundary" in row.flat_bucket_identity + ] + assert [row.primitive for 
row in readout_rows] == [ + op.primitive for op in runtime.backend_ir.readout_program.primitive_ops + ] + assert [row.parameter_inputs for row in readout_rows] == [ + op.parameter_inputs for op in runtime.backend_ir.readout_program.primitive_ops + ] + assert ("output_q", "value_to_output_weight", "output_cell_bias") in [row.parameter_inputs for row in readout_rows] + assert all( + ("compiled_readout_rule", runtime.backend_ir.readout_program.rule_name) in row.attributes + for row in readout_rows + ) + assert all( + ("compiled_readout_lowering_kind", runtime.backend_ir.readout_program.lowering_kind) in row.attributes + for row in readout_rows + ) + + +def test_temporal_forward_primitive_row_tensor_encodes_supported_program_groups() -> None: + runtime = build(_make_alias_population_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + primitive_rows = temporal_primitive_rows_tensor(table) + + def opcodes_for_bucket(bucket_ordinal: int) -> list[int]: + return [int(row[0]) for row in primitive_rows.tolist() if int(row[3]) == bucket_ordinal] + + assert opcodes_for_bucket(-1) == [10, 10, 18, 19, 10, 19, 10, 13, 12, 14, 15, 10, 35] + assert opcodes_for_bucket(-2) == [20, 21] + transition_programs = { + tuple(opcodes_for_bucket(bucket_ordinal)) + for bucket_ordinal in range(len(temporal_bucket_plan(runtime, cached_static_tensors).flat_buckets)) + } + assert (10, 10, 11, 1, 30) in transition_programs + assert (10, 2, 10, 30) in transition_programs + + +def test_message_rule_compiler_rejects_unsupported_rule_before_runtime_execution() -> None: + unsupported = MessageRuleIR( + name="unsupported_message", + sources=(), + parameters=(), + nodes=(), + output_boundary="projected_message", + ) + + with pytest.raises(ValueError, match="Unsupported Fabric message rule"): + compile_message_rule(unsupported) + with pytest.raises(ValueError, match="Unsupported Fabric message rule"): + build(replace(_make_slstm_spec(), message_rule=unsupported)) + + +def test_transition_program_compiler_rejects_unsupported_transition_op() -> None: + slstm_spec = build_cell_backend_spec( + cell_type="slstm", + hidden_size=8, + d_public=8, + d_msg=8, + head_dim=8, + value_dim=8, + ) + broken = replace( + slstm_spec, + transition_ir=replace( + slstm_spec.transition_ir, + ops=(TransitionOp("unsupported_transition", ("x",), ("y",)),), + ), + ) + + with pytest.raises(ValueError, match="Unsupported Fabric transition op"): + compile_transition_program(broken, binding_slot=0) + + +def test_transition_program_compiler_uses_cuda_nn_primitive_registry() -> None: + slstm_spec = build_cell_backend_spec( + cell_type="slstm", + hidden_size=8, + d_public=8, + d_msg=8, + head_dim=8, + value_dim=8, + ) + primitive = replace( + slstm_spec, + transition_ir=replace( + slstm_spec.transition_ir, + ops=(TransitionOp("tanh", ("x",), ("y",)),), + ), + ) + + assert "tanh" in cuda_nn_callable_primitives() + program = compile_transition_program(primitive, binding_slot=0) + + assert program.primitive_names == ("tanh",) + + +def test_temporal_forward_executor_rows_fall_back_to_registered_primitive_dag_strategies() -> None: + runtime = build(_with_context_nudge_message_rule(_make_slstm_spec())) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + 
dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + slstm_spec = build_cell_backend_spec( + cell_type="slstm", + hidden_size=8, + d_public=8, + d_msg=8, + head_dim=8, + value_dim=8, + ) + primitive_dag_spec = replace( + slstm_spec, + transition_ir=CellTransitionIR( + state_inputs=(), + message_inputs=("aggregated_message",), + parameter_inputs=("value_to_state_weight", "recurrent_bias"), + ops=( + TransitionOp( + "linear", + ("aggregated_message", "value_to_state_weight", "recurrent_bias"), + ("transition_input",), + ), + TransitionOp("tanh", ("transition_input",), ("public_y",)), + ), + state_outputs=(), + public_outputs=("public_y",), + recompute_outputs=("transition_input", "public_y"), + backward_decomposition=(), + ), + ) + runtime._backend_ir = replace( # noqa: SLF001 + runtime.backend_ir, + transition_programs=(compile_transition_program(primitive_dag_spec, binding_slot=0),), + ) + + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + forward_rows = temporal_forward_executor_rows(table) + transition_forward_rows = [row for row in forward_rows if row.surface == "transition"] + + assert [row.executor_name for row in transition_forward_rows] == [ + "transition_linear_primitive", + "tanh_transition", + ] + assert [row.primitive_row_count for row in transition_forward_rows] == [1, 1] + program = build_registered_temporal_executor_program(runtime, cached_static_tensors, table_plan=table) + transition_handles = tuple(handle for handle in program.forward_handles if handle.surface == "transition") + access_rows = temporal_forward_program_access_rows_tensor( + message_handles=program.forward_surface_handles(surface="message"), + readout_handles=program.forward_surface_handles(surface="readout"), + transition_handles=transition_handles, + ) + transition_access_opcodes_by_executor = { + handle.executor_name: { + int(row[5]) + for row in access_rows + if int(row[1]) == int(handle.row_index) and int(row[2]) == int(handle.bucket_ordinal) + } + for handle in transition_handles + } + assert transition_access_opcodes_by_executor["transition_linear_primitive"] == { + temporal_program_access_opcode("transition_aggregated_message_input") + } + assert transition_access_opcodes_by_executor["tanh_transition"] == { + temporal_program_access_opcode("transition_public_state_output") + } + assert build_temporal_forward_executable_plan(table).strategy_legality_status == "legal" + assert temporal_forward_executor_handler_rows_tensor(table).shape[0] == len(forward_rows) + tanh_output_contract = temporal_native_callable_transition_forward_output_definition( + primitive="tanh", + output_name="public_y", + output_index=0, + ) + assert tanh_output_contract.output_name == "output" + assert tanh_output_contract.runtime_role == "transition_forward_unary_output" + reverse_rows = [row for row in temporal_reverse_executor_rows(table) if row.surface == "transition"] + assert [row.executor_name for row in reverse_rows] == [ + "tanh_transition_backward", + "transition_linear_primitive_backward", + ] + reverse_binding_plan = build_temporal_reverse_executor_binding_plan(table) + reverse_binding_text = "\n".join(reverse_binding_plan.summaries) + assert "executor=tanh_transition_backward" in reverse_binding_text + assert "logical=grad_public_y,sources=reverse_seed:grad_public_y" in reverse_binding_text + assert "executor=transition_linear_primitive_backward" in reverse_binding_text + assert 
"logical=grad_transition_input,sources=reverse_internal:grad_transition_input" in reverse_binding_text + transition_param_grad_text = "\n".join( + build_temporal_transition_param_grad_binding_plan(table, reverse_binding_plan=reverse_binding_plan).summaries + ) + assert "parameter=value_to_state_weight" in transition_param_grad_text + assert "reducer=input_projection_weight" in transition_param_grad_text + assert "parameter=recurrent_bias" in transition_param_grad_text + assert "reducer=input_projection_bias" in transition_param_grad_text + + +def test_stateful_transition_primitive_dag_derives_carry_rows_from_compiler_bindings() -> None: + runtime = build(_with_context_nudge_message_rule(_make_slstm_spec())) + runtime._backend_ir = replace( # noqa: SLF001 + runtime.backend_ir, + transition_programs=(compile_transition_program(_stateful_tanh_transition_spec(), binding_slot=0),), + ) + static_tensors = runtime._get_inference_static_tensors( # noqa: SLF001 + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + table = build_temporal_primitive_table_plan(runtime, with_cached_population_static_tensors(runtime, static_tensors)) + program = build_registered_temporal_executor_program(runtime, static_tensors, table_plan=table) + transition_handles = tuple(handle for handle in program.forward_handles if handle.surface == "transition") + + carry_rows = temporal_forward_transition_state_carry_rows_tensor(transition_handles=transition_handles) + input_binding = next( + binding + for handle in transition_handles + for binding in handle.bindings + if binding.binding_kind == "input" and binding.logical_name == "mem" + ) + output_binding = next( + binding + for handle in transition_handles + for binding in handle.bindings + if binding.binding_kind == "output" and binding.logical_name == "next_mem" + ) + + assert [0, int(input_binding.binding_index), int(output_binding.binding_index)] in carry_rows.tolist() + assert not any( + pattern.state_carry_rules + for pattern in temporal_forward_executor_patterns() + if pattern.executor_name == "tanh_transition" + ) + + +def test_stateful_transition_primitive_dag_registers_dynamic_reverse_seed_roles() -> None: + runtime = build(_with_context_nudge_message_rule(_make_slstm_spec())) + runtime._backend_ir = replace( # noqa: SLF001 + runtime.backend_ir, + transition_programs=(compile_transition_program(_stateful_tanh_transition_spec(), binding_slot=0),), + ) + static_tensors = runtime._get_inference_static_tensors( # noqa: SLF001 + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + + reverse_binding_plan = build_temporal_reverse_executor_binding_plan(table) + reverse_binding_text = "\n".join(reverse_binding_plan.summaries) + dynamic_role_id = temporal_transition_reverse_seed_role_id("grad_next_mem") + default_role_ids = {int(row[0]) for row in temporal_transition_reverse_seed_role_rows_tensor().tolist()} + dynamic_role_ids = { + int(row[0]) for row in temporal_transition_reverse_seed_role_rows_tensor(("grad_next_mem",)).tolist() + } + program = build_registered_temporal_executor_program(runtime, static_tensors, table_plan=table) + + assert "logical=grad_next_mem,sources=reverse_seed:grad_next_mem" in reverse_binding_text + assert "logical=grad_mem" in reverse_binding_text + assert dynamic_role_id not in default_role_ids + 
assert dynamic_role_id in dynamic_role_ids + assert dynamic_role_id in {int(row[0]) for row in program.transition_reverse_seed_role_rows.tolist()} + + +def test_temporal_reverse_executor_rows_cover_parameterless_primitive_dag_adjoint() -> None: + runtime = build(_with_context_nudge_message_rule(_make_slstm_spec())) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + cached_static_tensors = with_cached_population_static_tensors(runtime, static_tensors) + slstm_spec = build_cell_backend_spec( + cell_type="slstm", + hidden_size=8, + d_public=8, + d_msg=8, + head_dim=8, + value_dim=8, + ) + tanh_spec = replace( + slstm_spec, + transition_ir=CellTransitionIR( + state_inputs=(), + message_inputs=("aggregated_message",), + parameter_inputs=(), + ops=(TransitionOp("tanh", ("aggregated_message",), ("public_y",)),), + state_outputs=(), + public_outputs=("public_y",), + recompute_outputs=("public_y",), + backward_decomposition=(), + ), + ) + runtime._backend_ir = replace( # noqa: SLF001 + runtime.backend_ir, + transition_programs=(compile_transition_program(tanh_spec, binding_slot=0),), + ) + + table = build_temporal_primitive_table_plan(runtime, cached_static_tensors) + forward_rows = [row for row in temporal_forward_executor_rows(table) if row.surface == "transition"] + reverse_rows = [row for row in temporal_reverse_executor_rows(table) if row.surface == "transition"] + + assert [row.executor_name for row in forward_rows] == ["tanh_transition"] + assert [row.executor_name for row in reverse_rows] == ["tanh_transition_backward"] + assert build_temporal_forward_executable_plan(table).strategy_legality_status == "legal" + assert build_temporal_backward_executable_plan(table).strategy_legality_status == "legal" + + +def test_transition_executor_selection_uses_registered_structural_records() -> None: + runtime = build(_make_slstm_spec()) + static_tensors = runtime._get_inference_static_tensors( + device=torch.device("cpu"), + dtype=torch.float32, + include_full_cell_kv_weight=True, + ) + bucket = temporal_bucket_plan(runtime, with_cached_population_static_tensors(runtime, static_tensors)).flat_buckets[ + 0 + ] + program = compiled_transition_program_for_bucket(runtime, bucket) + + plan = select_transition_program_executor(program) + + assert plan.executor == "gated_logspace_recurrence" + assert plan.registry_id == "transition_executor:gated_logspace_recurrence:v1" + assert plan.selection_kind == "fused_program_record" + assert plan.runtime_execution_status == "registered_fused_program" + assert plan.primitive_dag is not None + assert plan.primitive_dag.registry_id == "transition_executor:primitive_dag:v1" + assert tuple(edge.summary for edge in plan.primitive_dag.tensor_edges) == ( + "0:transition_input->1:transition_input", + "1:gate_logits->3:gate_logits", + "2:recurrent_gate_logits->3:recurrent_gate_logits", + "3:next_y->4:next_y", + ) + record_by_executor = {record.executor: record for record in registered_transition_executor_records()} + assert plan.executor in record_by_executor + record = record_by_executor[plan.executor] + assert record.registry_id == plan.registry_id + assert plan.forward_strategy_id == record.forward_strategy_id + assert plan.backward_strategy_id == record.backward_strategy_id + assert plan.program_layer_status == "callable" + assert plan.program_layer_blocker_codes == () + assert "program_transition_linear_forward" in plan.program_forward_symbols + assert 
"program_transition_gated_logspace_recurrence_backward" in plan.program_backward_symbols + assert "program_transition_linear_forward" not in plan.program_missing_symbols + assert "program_transition_linear_backward" not in plan.program_missing_symbols + assert "program_transition_recurrent_matmul_forward" not in plan.program_missing_symbols + assert "program_transition_recurrent_matmul_backward" not in plan.program_missing_symbols + assert "program_transition_gated_logspace_recurrence_forward" not in plan.program_missing_symbols + assert "program_transition_gated_logspace_recurrence_backward" not in plan.program_missing_symbols + assert "program_transition_norm_or_identity_forward" not in plan.program_missing_symbols + assert "program_transition_norm_or_identity_backward" not in plan.program_missing_symbols + assert "program_transition_diag_rtu_forward" not in plan.program_missing_symbols + plan_review = "\n".join(plan.review_summary) + assert "program_layer_status=callable" in plan_review + assert "selection_kind=fused_program_record" in plan_review + assert "runtime_execution_status=registered_fused_program" in plan_review + assert "primitive_dag_registry_id=transition_executor:primitive_dag:v1" in plan_review + assert "program_layer_blocker_codes=-" in plan_review + assert "program_missing_symbols=-" in plan_review + assert "forward_strategy_id=forward.transition.gated_logspace.v1" in record.review_summary + assert "backward_strategy_id=reverse.transition.gated_logspace.v1" in record.review_summary + primitive_record_by_name = { + record.primitive: record for record in registered_transition_primitive_executor_records() + } + assert set(program.primitive_names).issubset(primitive_record_by_name) + assert primitive_record_by_name["linear"].status == "cuda" + assert primitive_record_by_name["linear"].program_layer_status == "callable" + assert primitive_record_by_name["linear"].program_forward_status == "callable" + assert primitive_record_by_name["linear"].program_backward_status == "callable" + assert primitive_record_by_name["linear"].program_layer_blocker_code == "" + assert primitive_record_by_name["linear"].program_forward_symbol == "program_transition_linear_forward" + assert primitive_record_by_name["linear"].program_forward_cxx_entrypoint + assert primitive_record_by_name["linear"].program_forward_input_bindings == ("input",) + assert primitive_record_by_name["linear"].program_forward_output_bindings == (("output", True),) + assert primitive_record_by_name["linear"].program_forward_output_contracts + assert primitive_record_by_name["matmul"].program_layer_status == "callable" + assert primitive_record_by_name["matmul"].program_forward_status == "callable" + assert primitive_record_by_name["matmul"].program_backward_status == "callable" + assert primitive_record_by_name["matmul"].program_layer_blocker_code == "" + assert primitive_record_by_name["matmul"].program_forward_cxx_entrypoint + assert primitive_record_by_name["matmul"].program_reverse_native_callable == ( + "native.reverse.transition_matmul_primitive.v1" + ) + assert primitive_record_by_name["matmul"].param_grad_outputs == (("grad_weight", "weight", "materialized"),) + assert primitive_record_by_name["matmul"].reverse_input_bindings == ("input", "grad_output") + assert primitive_record_by_name["matmul"].parameter_bindings == ("weight",) + assert primitive_record_by_name["matmul"].reverse_output_bindings == ("grad_input", "grad_weight") + assert primitive_record_by_name["gated_logspace_recurrence"].cuda_executor + assert 
primitive_record_by_name["gated_logspace_recurrence"].program_layer_status == "callable" + assert primitive_record_by_name["gated_logspace_recurrence"].program_forward_status == "callable" + assert primitive_record_by_name["gated_logspace_recurrence"].program_backward_status == "callable" + assert primitive_record_by_name["gated_logspace_recurrence"].program_layer_blocker_code == "" + assert primitive_record_by_name["gated_logspace_recurrence"].program_reverse_native_callable + assert primitive_record_by_name["norm_or_identity"].program_layer_status == "callable" + assert primitive_record_by_name["norm_or_identity"].program_forward_status == "callable" + assert primitive_record_by_name["norm_or_identity"].program_backward_status == "callable" + assert primitive_record_by_name["norm_or_identity"].program_layer_blocker_code == "" + assert primitive_record_by_name["norm_or_identity"].program_forward_cxx_entrypoint + assert primitive_record_by_name["diag_rtu"].program_layer_status == "callable" + assert primitive_record_by_name["diag_rtu"].program_forward_status == "callable" + assert primitive_record_by_name["diag_rtu"].program_backward_status == "callable" + assert primitive_record_by_name["diag_rtu"].program_layer_blocker_code == "" + assert primitive_record_by_name["diag_rtu"].program_forward_output_contracts + assert primitive_record_by_name["diag_rtu"].program_reverse_native_callable + assert primitive_record_by_name["tanh"].status == "cuda" + assert primitive_record_by_name["tanh"].program_layer_status == "callable" + assert primitive_record_by_name["tanh"].program_forward_status == "callable" + assert primitive_record_by_name["tanh"].program_backward_status == "callable" + assert primitive_record_by_name["tanh"].program_forward_symbol == "program_transition_tanh_forward" + assert primitive_record_by_name["tanh"].program_backward_symbol == "program_transition_tanh_backward" + assert primitive_record_by_name["tanh"].program_reverse_native_callable == "native.reverse.transition_tanh.v1" + assert primitive_record_by_name["tanh"].reverse_input_bindings == ("output", "grad_output") + assert primitive_record_by_name["tanh"].reverse_output_bindings == ("grad_input",) + assert primitive_record_by_name["tanh"].program_forward_output_contracts == ( + ("output", "transition_forward_unary_output", "hidden", "primitive_row"), + ) + assert all( + transition_primitive_program_contract_blocker_code(record) == "" + for record in primitive_record_by_name.values() + if record.program_layer_status == "callable" + ) + assert ( + transition_primitive_program_contract_blocker_code( + replace(primitive_record_by_name["matmul"], program_reverse_native_callable="") + ) + == "INCOMPLETE_FUSED_TRANSITION_PRIMITIVE_REVERSE_CONTRACT" + ) + assert ( + transition_primitive_program_contract_blocker_code( + replace(primitive_record_by_name["matmul"], program_forward_output_contracts=()) + ) + == "INCOMPLETE_FUSED_TRANSITION_PRIMITIVE_FORWARD_CONTRACT" + ) + assert ( + transition_primitive_program_contract_blocker_code( + replace( + primitive_record_by_name["matmul"], + tape_saved_input_bindings=(), + tape_saved_output_bindings=(), + tape_recompute_input_bindings=(), + tape_recompute_output_bindings=(), + ) + ) + == "INCOMPLETE_FUSED_TRANSITION_PRIMITIVE_TAPE_CONTRACT" + ) + assert ( + transition_primitive_program_contract_blocker_code( + replace( + primitive_record_by_name["matmul"], + program_forward_input_bindings=("input", "input"), + ) + ) + == "INVALID_FUSED_TRANSITION_PRIMITIVE_FORWARD_BINDING_CONTRACT" + ) + 
assert ( + transition_primitive_program_contract_blocker_code( + replace( + primitive_record_by_name["matmul"], + program_forward_output_contracts=( + ("recurrent_gate_logits", "transition_forward_matmul_output", "unknown_shape", "primitive_row"), + ), + ) + ) + == "INVALID_FUSED_TRANSITION_PRIMITIVE_OUTPUT_CONTRACT" + ) + assert ( + transition_primitive_program_contract_blocker_code( + replace(primitive_record_by_name["matmul"], tape_recompute_input_bindings=("unknown_input",)) + ) + == "INVALID_FUSED_TRANSITION_PRIMITIVE_TAPE_INPUT_CONTRACT" + ) + assert ( + transition_primitive_program_contract_blocker_code( + replace(primitive_record_by_name["matmul"], tape_recompute_output_bindings=("unknown_output",)) + ) + == "INVALID_FUSED_TRANSITION_PRIMITIVE_TAPE_OUTPUT_CONTRACT" + ) + assert ( + transition_primitive_program_contract_blocker_code( + replace(primitive_record_by_name["matmul"], param_grad_outputs=(("grad_unknown", "unknown_weight", "x"),)) + ) + == "INVALID_FUSED_TRANSITION_PRIMITIVE_PARAM_GRAD_CONTRACT" + ) + assert primitive_record_by_name["linear"].tape_saved_input_bindings == ("input",) + assert primitive_record_by_name["norm_or_identity"].tape_recompute_output_bindings == ("output",) + assert primitive_record_by_name["tanh"].tape_saved_output_bindings == ("output",) + assert primitive_record_by_name["diag_rtu"].tape_saved_output_bindings == ("preproj",) + + generic_primitive_dag_program = SimpleNamespace( + primitive_ops=( + SimpleNamespace( + name="linear", + inputs=("aggregated_message", "value_to_state_weight", "recurrent_bias"), + outputs=("transition_input",), + parameter_inputs=("value_to_state_weight", "recurrent_bias"), + ), + SimpleNamespace( + name="norm_or_identity", + inputs=("transition_input", "outnorm_weight"), + outputs=("public_y",), + parameter_inputs=("outnorm_weight",), + ), + ), + state_inputs=(), + state_outputs=(), + public_outputs=("public_y",), + message_inputs=("aggregated_message",), + parameter_inputs=("value_to_state_weight", "recurrent_bias", "outnorm_weight"), + ) + generic_plan = select_transition_program_executor(generic_primitive_dag_program) + assert generic_plan.executor == "primitive_dag" + assert generic_plan.registry_id == "transition_executor:primitive_dag:v1" + assert generic_plan.selection_kind == "primitive_dag" + assert generic_plan.runtime_execution_status == "registered_primitive_dag_program" + assert generic_plan.forward_strategy_id == "forward.transition.primitive_dag.v1" + assert generic_plan.backward_strategy_id == "reverse.transition.primitive_dag.v1" + assert generic_plan.program_layer_status == "callable" + assert generic_plan.program_layer_blocker_codes == () + assert generic_plan.primitive_dag is not None + assert tuple(edge.summary for edge in generic_plan.primitive_dag.tensor_edges) == ( + "0:transition_input->1:transition_input", + ) + assert tuple(contract.summary for contract in generic_plan.primitive_dag.tape_contracts) == ( + "op=0:linear,policy=input_projection_tape_or_recompute,save_inputs=aggregated_message," + "save_outputs=-,recompute_inputs=aggregated_message,recompute_outputs=transition_input," + "reverse_inputs=input,grad_output", + "op=1:norm_or_identity,policy=recompute_or_full_tape,save_inputs=transition_input," + "save_outputs=-,recompute_inputs=transition_input,recompute_outputs=public_y," + "reverse_inputs=input,grad_output", + ) + assert generic_plan.primitive_dag.external_inputs == ( + "aggregated_message", + "value_to_state_weight", + "recurrent_bias", + "outnorm_weight", + ) + assert 
"program_transition_linear_forward" in generic_plan.program_forward_symbols + assert "program_transition_norm_or_identity_backward" in generic_plan.program_backward_symbols + + forward_reference_program = SimpleNamespace( + primitive_ops=( + SimpleNamespace( + name="norm_or_identity", + inputs=("future_tensor", "outnorm_weight"), + outputs=("public_y",), + parameter_inputs=("outnorm_weight",), + ), + SimpleNamespace( + name="linear", + inputs=("aggregated_message", "value_to_state_weight", "recurrent_bias"), + outputs=("future_tensor",), + parameter_inputs=("value_to_state_weight", "recurrent_bias"), + ), + ), + state_inputs=(), + state_outputs=(), + public_outputs=("public_y",), + message_inputs=("aggregated_message",), + parameter_inputs=("value_to_state_weight", "recurrent_bias", "outnorm_weight"), + ) + with pytest.raises(TransitionProgramExecutorSelectionError) as dag_error: + select_transition_program_executor(forward_reference_program) + assert dag_error.value.code == "ILLEGAL_TRANSITION_PRIMITIVE_DAG" + + unary_primitive_dag_program = SimpleNamespace( + primitive_ops=(SimpleNamespace(name="tanh", inputs=("x",), outputs=("y",)),), + state_inputs=(), + state_outputs=(), + message_inputs=("x",), + public_outputs=("y",), + ) + unary_plan = select_transition_program_executor(unary_primitive_dag_program) + assert unary_plan.executor == "primitive_dag" + assert unary_plan.program_layer_status == "callable" + assert unary_plan.program_forward_symbols == ("program_transition_tanh_forward",) + assert unary_plan.program_backward_symbols == ("program_transition_tanh_backward",) + assert unary_plan.primitive_dag is not None + assert tuple(contract.summary for contract in unary_plan.primitive_dag.tape_contracts) == ( + "op=0:tanh,policy=input_tape_or_recompute,save_inputs=-,save_outputs=y," + "recompute_inputs=x,recompute_outputs=y,reverse_inputs=output,grad_output", + ) + + unregistered_primitive_program = SimpleNamespace( + primitive_ops=(SimpleNamespace(name="unregistered_op", inputs=("x",), outputs=("y",)),), + state_inputs=(), + state_outputs=(), + message_inputs=("x",), + ) + with pytest.raises(TransitionProgramExecutorSelectionError) as missing_error: + select_transition_program_executor(unregistered_primitive_program) + assert missing_error.value.code == "UNREGISTERED_TRANSITION_PRIMITIVE" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for registered program primitive executor") +def test_registered_parameter_reducer_cuda_executes_transition_trainable_rows() -> None: + from cortical.fabric.backend.cuda.sequence_surface.flat_bucket.flat_bucket_registered_program_cuda import ( + registered_temporal_parameter_reducer_program_cuda, + ) + + transition_reducer_callable = temporal_strategy_id_hash("native.reverse.parameter_reduction.transition.v1") + materialized_base_callable = temporal_strategy_id_hash( + "native.reverse.parameter_reduction.transition.materialized_base.v1" + ) + materialized_delta_callable = temporal_strategy_id_hash( + "native.reverse.parameter_reduction.transition.materialized_delta.v1" + ) + value_to_cell_msg_to_cell_callable = temporal_strategy_id_hash( + "native.reverse.parameter_reduction.transition.value_to_cell_msg_to_cell.v1" + ) + value_to_cell_msg_out_callable = temporal_strategy_id_hash( + "native.reverse.parameter_reduction.transition.value_to_cell_msg_out.v1" + ) + recurrent_bias_slot_embed_callable = temporal_strategy_id_hash( + "native.reverse.parameter_reduction.transition.recurrent_bias_slot_embed.v1" + ) + 
recurrent_bias_cell_bias_proj_callable = temporal_strategy_id_hash( + "native.reverse.parameter_reduction.transition.recurrent_bias_cell_bias_proj.v1" + ) + source_grad = torch.randn(3, 4, device="cuda") + target_base = torch.randn(3, 4, device="cuda") + target_delta = torch.randn(3, 4, device="cuda") + target_params = (target_base, target_delta) + outputs = registered_temporal_parameter_reducer_program_cuda( + parameter_reducer_rows=torch.tensor([[0, 4, 0, 0, 0, 0, 1, 0]], dtype=torch.long), + parameter_reducer_strategy_rows=torch.tensor( + [[0, 4, 4, 0, 0, 0, 0, 0, transition_reducer_callable]], + dtype=torch.long, + ), + parameter_reducer_trainable_role_rows=torch.empty((0, 6), dtype=torch.long), + parameter_reducer_runtime_metadata_rows=torch.empty((0, 4), dtype=torch.long), + transition_source_rows=torch.tensor([[0, 0, 0, 1, 0, 1, 0, 0]], dtype=torch.long), + transition_trainable_rows=torch.tensor( + [ + [0, 0, 0, 1, 0, -1, 0, 0, materialized_base_callable], + [1, 0, 0, 2, 1, -1, 0, 0, materialized_delta_callable], + ], + dtype=torch.long, + ), + sender_grad_weight_tensors=(), + sender_group_id_tensors=(), + sender_grouped_flags=torch.empty(0, dtype=torch.long), + readout_grad_value_to_output_weight_tensors=(), + readout_grad_output_cell_bias_tensors=(), + recurrent_query_grad_tensors=(), + output_query_grad_tensors=(), + message_strategy_grad_tensors=(), + message_strategy_grad_rows=torch.empty((0, 5), dtype=torch.long), + transition_source_tensors=(source_grad,), + transition_source_recurrent_cell_idx_tensors=(torch.empty(0, device="cuda", dtype=torch.long),), + parameter_output_tensors=tuple(torch.zeros_like(param) for param in target_params), + trainable_param_tensors=target_params, + runtime_metadata_tensors=(), + coord_count=3, + head_dim=2, + value_dim=2, + ) + + assert len(outputs) == 2 + torch.testing.assert_close(outputs[0], source_grad) + torch.testing.assert_close(outputs[1], source_grad) + + msg_to_cell = torch.randn(2, 4, device="cuda") + msg_out = torch.randn(4, 3, device="cuda") + static_params = (msg_to_cell, msg_out) + source_value_to_cell = torch.randn(2, 3, device="cuda") + static_outputs = registered_temporal_parameter_reducer_program_cuda( + parameter_reducer_rows=torch.tensor([[0, 4, 0, 0, 0, 0, 1, 0]], dtype=torch.long), + parameter_reducer_strategy_rows=torch.tensor( + [[0, 4, 4, 0, 0, 0, 0, 0, transition_reducer_callable]], + dtype=torch.long, + ), + parameter_reducer_trainable_role_rows=torch.empty((0, 6), dtype=torch.long), + parameter_reducer_runtime_metadata_rows=torch.empty((0, 4), dtype=torch.long), + transition_source_rows=torch.tensor([[0, 0, 0, 2, 0, 1, 0, 0]], dtype=torch.long), + transition_trainable_rows=torch.tensor( + [ + [0, 0, 0, 3, 0, 1, 0, 0, value_to_cell_msg_to_cell_callable], + [1, 0, 0, 4, 1, 0, 0, 0, value_to_cell_msg_out_callable], + ], + dtype=torch.long, + ), + sender_grad_weight_tensors=(), + sender_group_id_tensors=(), + sender_grouped_flags=torch.empty(0, dtype=torch.long), + readout_grad_value_to_output_weight_tensors=(), + readout_grad_output_cell_bias_tensors=(), + recurrent_query_grad_tensors=(), + output_query_grad_tensors=(), + message_strategy_grad_tensors=(), + message_strategy_grad_rows=torch.empty((0, 5), dtype=torch.long), + transition_source_tensors=(source_value_to_cell,), + transition_source_recurrent_cell_idx_tensors=(torch.empty(0, device="cuda", dtype=torch.long),), + parameter_output_tensors=tuple(torch.zeros_like(param) for param in static_params), + trainable_param_tensors=static_params, + 
runtime_metadata_tensors=(), + coord_count=3, + head_dim=2, + value_dim=2, + ) + torch.testing.assert_close(static_outputs[0], source_value_to_cell.matmul(msg_out.t())) + torch.testing.assert_close(static_outputs[1], msg_to_cell.t().matmul(source_value_to_cell)) + + recurrent_bias_grad = torch.randn(2, 3, device="cuda") + recurrent_idx = torch.tensor([0, 2], device="cuda") + slot_embed_param = torch.randn(4, 5, device="cuda") + cell_bias_proj = torch.randn(3, 5, device="cuda") + bias_params = (slot_embed_param, cell_bias_proj) + bias_outputs = registered_temporal_parameter_reducer_program_cuda( + parameter_reducer_rows=torch.tensor([[0, 4, 0, 0, 0, 0, 1, 0]], dtype=torch.long), + parameter_reducer_strategy_rows=torch.tensor( + [[0, 4, 4, 0, 0, 0, 0, 0, transition_reducer_callable]], + dtype=torch.long, + ), + parameter_reducer_trainable_role_rows=torch.empty((0, 6), dtype=torch.long), + parameter_reducer_runtime_metadata_rows=torch.empty((0, 4), dtype=torch.long), + transition_source_rows=torch.tensor([[0, 0, 0, 2, 0, 1, 0, 0]], dtype=torch.long), + transition_trainable_rows=torch.tensor( + [ + [0, 0, 0, 5, 0, 1, 0, 0, recurrent_bias_slot_embed_callable], + [1, 0, 0, 6, 1, 0, 0, 0, recurrent_bias_cell_bias_proj_callable], + ], + dtype=torch.long, + ), + sender_grad_weight_tensors=(), + sender_group_id_tensors=(), + sender_grouped_flags=torch.empty(0, dtype=torch.long), + readout_grad_value_to_output_weight_tensors=(), + readout_grad_output_cell_bias_tensors=(), + recurrent_query_grad_tensors=(), + output_query_grad_tensors=(), + message_strategy_grad_tensors=(), + message_strategy_grad_rows=torch.empty((0, 5), dtype=torch.long), + transition_source_tensors=(recurrent_bias_grad,), + transition_source_recurrent_cell_idx_tensors=(recurrent_idx,), + parameter_output_tensors=tuple(torch.zeros_like(param) for param in bias_params), + trainable_param_tensors=bias_params, + runtime_metadata_tensors=(), + coord_count=4, + head_dim=2, + value_dim=2, + ) + full_bias_grad = torch.zeros(4, 3, device="cuda") + full_bias_grad.index_add_(0, recurrent_idx, recurrent_bias_grad) + torch.testing.assert_close(bias_outputs[0], full_bias_grad.matmul(cell_bias_proj)) + torch.testing.assert_close(bias_outputs[1], full_bias_grad.t().matmul(slot_embed_param)) + + fixed_slot_context_callable = temporal_strategy_id_hash( + "native.reverse.parameter_reduction.fixed_slot_context_message.v1" + ) + slot_embed = torch.randn(4, 5, device="cuda") + query_slot_proj = torch.randn(2, 5, device="cuda") + sender_slot_key_proj = torch.randn(2, 5, device="cuda") + query_context_scalar = torch.randn(1, device="cuda") + sender_context_key = torch.randn(4, 2, device="cuda") + msg_out = torch.randn(3, 3, device="cuda") + recurrent_inverse = torch.tensor([1, 0], device="cuda") + recurrent_cells = torch.tensor([1, 3], device="cuda") + input_cells = torch.tensor([0, 2], device="cuda") + grad_query_slot_backend = torch.randn(2, 2, device="cuda") + grad_input_key_bank = torch.randn(2, 2, 4, device="cuda") + grad_recurrent_key_bank = torch.randn(2, 2, 4, device="cuda") + reduced_input_key_bank = grad_input_key_bank.sum(dim=0).contiguous() + reduced_recurrent_key_bank = grad_recurrent_key_bank.sum(dim=0).contiguous() + grad_query_context_scalar = torch.randn(1, device="cuda") + grad_message_output = torch.randn(3, 3, device="cuda") + context_params = ( + slot_embed, + query_slot_proj, + sender_slot_key_proj, + query_context_scalar, + sender_context_key, + msg_out, + ) + context_outputs = registered_temporal_parameter_reducer_program_cuda( + 
parameter_reducer_rows=torch.tensor([[0, 6, 0, 0, 0, 0, 5, 0]], dtype=torch.long), + parameter_reducer_strategy_rows=torch.tensor( + [[0, 6, 6, 5, 1, 0, 0, 0, fixed_slot_context_callable]], + dtype=torch.long, + ), + parameter_reducer_trainable_role_rows=torch.tensor( + [ + [0, 5, 0, 0, 0, 0], + [1, 9, 1, 0, 0, 0], + [2, 10, 2, 0, 0, 0], + [3, 11, 3, 0, 0, 0], + [4, 12, 4, 0, 0, 0], + [5, 6, 5, 0, 0, 0], + ], + dtype=torch.long, + ), + parameter_reducer_runtime_metadata_rows=torch.tensor( + [ + [0, 1, 0, 0], + [1, 2, 1, 0], + [2, 4, 2, 0], + ], + dtype=torch.long, + ), + transition_source_rows=torch.empty((0, 8), dtype=torch.long), + transition_trainable_rows=torch.empty((0, 9), dtype=torch.long), + sender_grad_weight_tensors=(), + sender_group_id_tensors=(), + sender_grouped_flags=torch.empty(0, dtype=torch.long), + readout_grad_value_to_output_weight_tensors=(), + readout_grad_output_cell_bias_tensors=(), + recurrent_query_grad_tensors=(), + output_query_grad_tensors=(), + message_strategy_grad_tensors=( + grad_query_slot_backend, + reduced_input_key_bank, + reduced_recurrent_key_bank, + grad_query_context_scalar, + grad_message_output, + ), + message_strategy_grad_rows=torch.tensor( + [ + [0, 6, 1, 0, 0], + [1, 6, 2, 1, 0], + [2, 6, 3, 2, 0], + [3, 6, 4, 3, 0], + [4, 6, 5, 4, 0], + ], + dtype=torch.long, + ), + transition_source_tensors=(), + transition_source_recurrent_cell_idx_tensors=(), + parameter_output_tensors=tuple(torch.zeros_like(param) for param in context_params), + trainable_param_tensors=context_params, + runtime_metadata_tensors=(recurrent_inverse, recurrent_cells, input_cells), + coord_count=4, + head_dim=2, + value_dim=3, + ) + query_full = torch.zeros(4, 2, device="cuda") + query_full.index_add_(0, recurrent_cells, grad_query_slot_backend.index_select(0, recurrent_inverse)) + key_full = torch.zeros(4, 2, device="cuda") + context_full = torch.zeros(4, 2, device="cuda") + reduced_input_key = reduced_input_key_bank + key_full.index_add_(0, input_cells, reduced_input_key[:, :2]) + context_full.index_add_(0, input_cells, reduced_input_key[:, 2:]) + reduced_recurrent_key = reduced_recurrent_key_bank.index_select(0, recurrent_inverse) + key_full.index_add_(0, recurrent_cells, reduced_recurrent_key[:, :2]) + context_full.index_add_(0, recurrent_cells, reduced_recurrent_key[:, 2:]) + torch.testing.assert_close( + context_outputs[0], + query_full.matmul(query_slot_proj) + key_full.matmul(sender_slot_key_proj), + ) + torch.testing.assert_close(context_outputs[1], query_full.t().matmul(slot_embed)) + torch.testing.assert_close(context_outputs[2], key_full.t().matmul(slot_embed)) + torch.testing.assert_close(context_outputs[3], grad_query_context_scalar) + torch.testing.assert_close(context_outputs[4], context_full) + torch.testing.assert_close(context_outputs[5], grad_message_output) + + +def _gated_logspace_core_reference( + gate_logits: torch.Tensor, + recurrent_gate_logits: torch.Tensor, + c_prev: torch.Tensor, + n_prev: torch.Tensor, + m_prev: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + raw = gate_logits + recurrent_gate_logits + iraw, fraw, zraw, oraw = raw.unbind(dim=2) + logfplusm = m_prev + torch.nn.functional.logsigmoid(fraw) + is_first = n_prev == 0.0 + m_next = torch.where(is_first, iraw, torch.maximum(iraw, logfplusm)) + i_gate = torch.exp(iraw - m_next).clamp(max=1.0) + f_gate = torch.exp(logfplusm - m_next).clamp(max=1.0) + c_next = f_gate * c_prev + i_gate * torch.tanh(zraw) + n_next = f_gate * n_prev + i_gate + y_next = 
torch.sigmoid(oraw) * c_next / (n_next + 1.0e-6) + return y_next, c_next, n_next, m_next + + +def _norm_or_identity_reference( + input: torch.Tensor, + weight: torch.Tensor | None, *, - hidden_size: int = 8, - max_delay: int | None = None, - patch_edges_per_cell: int = 0, -): - return init( - Config( - width=8, - height=8, - hidden_size=hidden_size, - cell_populations={"slstm": CellPopulationConfig(cell_type="slstm")}, - population_mix={"slstm": 1.0}, - patch_edges_per_cell=patch_edges_per_cell, - patch_min_dist=3.0 if patch_edges_per_cell > 0 else 0.0, - patch_max_dist=4.0 if patch_edges_per_cell > 0 else 0.0, - projection_region_shape=(2, 2), - local_radius=1.5, - conduction_speed=1.0 if max_delay is not None else None, - max_delay=max_delay, - seed=11, - ) + eps: float, +) -> torch.Tensor: + if weight is None: + return input + mean = input.mean(dim=-1, keepdim=True) + var = torch.clamp((input * input).mean(dim=-1, keepdim=True) - mean * mean, min=0.0) + return (input - mean) * torch.rsqrt(var + float(eps)) * weight.unsqueeze(0) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for fused registered program executor") +def test_fused_forward_transition_program_cuda_dispatches_registered_tanh_callable() -> None: + from cortical.fabric.backend.cuda.sequence_surface.flat_bucket.flat_bucket_registered_program_cuda import ( + registered_temporal_fused_forward_transition_program_cuda, ) + batch, receivers, hidden = 2, 3, 7 + primitive_opcode = temporal_primitive_opcode("tanh") + input = torch.randn(batch, receivers, hidden, device="cuda") + output_placeholder = input.new_empty(0) + output_buffer = input.new_empty(batch, receivers, hidden) + primitive_rows = torch.tensor([[primitive_opcode, 0, receivers, 0]], dtype=torch.long) + forward_executor_rows = torch.tensor([[3, 0, 1, 0, 0, receivers]], dtype=torch.long) + strategy_hash = temporal_strategy_id_hash("forward.transition.tanh.v1") + native_callable_hash = temporal_strategy_id_hash("native.forward.transition_tanh.v1") + forward_handler_rows = torch.tensor( + [[3, 4, 3, primitive_opcode, 1, 4, 105, strategy_hash, 1, 0, 0]], + dtype=torch.long, + ) + native_strategy_rows = torch.tensor( + [[1, 4, 3, 3, primitive_opcode, 1, 4, 105, 1, 1, 1, 1, strategy_hash, 1, 0, 0, native_callable_hash]], + dtype=torch.long, + ) + forward_binding_rows = torch.tensor( + [ + [1, 0, 3, 0, 0, 0, 0, 0], + [1, 0, 3, 0, 1, 0, 2, 0], + ], + dtype=torch.long, + ) + program_tensor_binding_rows = torch.tensor( + [ + [0, 0, 0, 0], + [1, 1, 0, 2], + ], + dtype=torch.long, + ) + memory_liveness_rows = torch.tensor( + [ + [0, 0, 0, 4, 13, 4, 3, 9, 5, 1], + [1, 0, 0, 4, 13, 4, 3, 4, 5, 1], + [2, 0, 0, 4, 13, 4, 3, 11, 5, 1], + [3, 0, 0, 4, 12, 7, 8, 12, 7, 1], + ], + dtype=torch.long, + ) + runtime_role = temporal_runtime_buffer_role_opcode("transition_forward_unary_output") + runtime_buffer_rows = torch.tensor( + [[0, 3, 8, 4, 0, 12, 0, 0, runtime_role, 0]], + dtype=torch.long, + ) -def _make_mixed_spec(): - return init( - Config( - width=8, - height=8, - hidden_size=8, - cell_populations={ - "slstm": CellPopulationConfig(cell_type="slstm"), - "axoncell": CellPopulationConfig(cell_type="axoncell"), - }, - population_mix={"slstm": 0.5, "axoncell": 0.5}, - patch_edges_per_cell=2, - patch_min_dist=3.0, - patch_max_dist=4.0, - projection_region_shape=(2, 2), - seed=13, - ) + outputs = registered_temporal_fused_forward_transition_program_cuda( + program_tensors=(input, output_placeholder), + program_tensor_binding_rows=program_tensor_binding_rows, + 
runtime_buffer_tensors=(output_buffer,), + runtime_buffer_rows=runtime_buffer_rows, + primitive_rows=primitive_rows, + forward_executor_rows=forward_executor_rows, + forward_handler_rows=forward_handler_rows, + native_strategy_rows=native_strategy_rows, + native_callable_binding_schema_rows=temporal_native_callable_binding_schema_rows_tensor(), + native_callable_output_rows=temporal_native_callable_output_rows_tensor(), + transition_primitive_callable_rows=temporal_transition_primitive_native_callable_rows_tensor(), + forward_executor_binding_rows=forward_binding_rows, + memory_liveness_rows=memory_liveness_rows, ) + assert len(outputs) == 2 + torch.testing.assert_close(outputs[1], torch.tanh(input), rtol=1.0e-5, atol=1.0e-5) -def _make_axon_spec(*, hidden_size: int = 8): - return init( - Config( - width=8, - height=8, - hidden_size=hidden_size, - cell_populations={"axoncell": CellPopulationConfig(cell_type="axoncell")}, - population_mix={"axoncell": 1.0}, - patch_edges_per_cell=0, - projection_region_shape=(2, 2), - seed=17, - ) + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for fused registered program executor") +def test_fused_reverse_transition_program_cuda_dispatches_registered_tanh_callable() -> None: + from cortical.fabric.backend.cuda.sequence_surface.flat_bucket.flat_bucket_registered_program_cuda import ( + registered_temporal_fused_reverse_transition_program_cuda, + ) + + batch, receivers, hidden = 2, 3, 7 + primitive_opcode = temporal_primitive_opcode("tanh") + output = torch.tanh(torch.randn(batch, receivers, hidden, device="cuda")) + grad_output = torch.randn_like(output) + grad_input_placeholder = output.new_empty(0) + primitive_rows = torch.tensor([[primitive_opcode, 0, receivers, 0]], dtype=torch.long) + reverse_executor_rows = torch.tensor([[6, 0, 1, 0, 0, receivers]], dtype=torch.long) + native_strategy_rows = temporal_native_executor_strategy_rows_tensor() + strategy_row = next( + row + for row in native_strategy_rows.tolist() + if int(row[0]) == 2 and int(row[2]) == 6 and int(row[4]) == primitive_opcode + ) + reverse_handler_rows = torch.tensor( + [ + [ + strategy_row[2], + strategy_row[1], + strategy_row[3], + strategy_row[4], + strategy_row[5], + strategy_row[6], + strategy_row[7], + strategy_row[12], + strategy_row[13], + strategy_row[14], + strategy_row[15], + ] + ], + dtype=torch.long, + ) + reverse_binding_rows = torch.tensor( + [ + [2, 0, 6, 0, 0, 0, 0, 0], + [2, 0, 6, 0, 1, 0, 0, 1], + [2, 0, 6, 0, 2, 0, 2, 0], + ], + dtype=torch.long, + ) + program_tensor_binding_rows = torch.tensor( + [ + [0, 0, 0, 2], + [1, 1, 0, 0], + [2, 2, 0, 2], + ], + dtype=torch.long, + ) + memory_liveness_rows = torch.tensor( + [ + [0, 0, 0, 4, 13, 4, 8, 1, 5, 1], + [1, 0, 0, 4, 13, 4, 8, 4, 5, 1], + [2, 0, 0, 4, 13, 4, 8, 6, 5, 1], + [3, 0, 0, 4, 13, 4, 8, 9, 5, 1], + [4, 0, 0, 4, 13, 4, 8, 11, 5, 1], + [5, 0, 0, 4, 13, 4, 8, 12, 5, 1], + ], + dtype=torch.long, ) + outputs = registered_temporal_fused_reverse_transition_program_cuda( + program_tensors=(output, grad_output, grad_input_placeholder), + program_tensor_binding_rows=program_tensor_binding_rows, + primitive_rows=primitive_rows, + reverse_executor_rows=reverse_executor_rows, + reverse_handler_rows=reverse_handler_rows, + native_strategy_rows=native_strategy_rows, + native_callable_binding_schema_rows=temporal_native_callable_binding_schema_rows_tensor(), + native_callable_output_rows=temporal_native_callable_output_rows_tensor(), + 
transition_primitive_callable_rows=temporal_transition_primitive_native_callable_rows_tensor(), + reverse_executor_binding_rows=reverse_binding_rows, + memory_liveness_rows=memory_liveness_rows, + ) -def _make_explicit_graph_spec(cell_type: str): - return init( - Config( - width=4, - height=4, - hidden_size=8, - cell_populations={cell_type: CellPopulationConfig(cell_type=cell_type)}, - population_mix={cell_type: 1.0}, - input_cell_indices=(0, 1), - output_cell_indices=(14, 15), - graph_edges=((2, 3), (3, 2), (4, 2), (14, 13), (15, 13)), - kv_group_ids=tuple(idx // 2 for idx in range(16)), - seed=19, - ) + assert len(outputs) == 3 + torch.testing.assert_close( + outputs[2], + grad_output * (1.0 - output * output), + rtol=1.0e-5, + atol=1.0e-5, ) -def test_fabric_backend_ir_compiles_receiver_sets_and_buckets() -> None: - runtime = build(_make_mixed_spec()) - ir = runtime.backend_ir +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for fused registered program executor") +def test_fused_reverse_transition_program_cuda_dispatches_registered_linear_callable() -> None: + from cortical.fabric.backend.cuda.sequence_surface.flat_bucket.flat_bucket_registered_program_cuda import ( + registered_temporal_fused_reverse_transition_program_cuda, + ) - assert ir.num_cells == runtime.coords.shape[0] - assert ir.num_input_ports == runtime.input_cell_idx.numel() - assert ir.num_recurrent_cells == runtime.recurrent_cell_idx.numel() - assert ir.num_output_ports == runtime.output_cell_idx.numel() - assert ir.bucket_count > 0 - assert any(bucket.receiver_kind.value == "recurrent_cell" for bucket in ir.buckets) - assert any(bucket.receiver_kind.value == "output_port" for bucket in ir.buckets) - assert any(bucket.has_sparse_overlay for bucket in ir.buckets) - assert ir.graph_summary.node_count == ir.num_cells - assert ir.graph_summary.input_count == ir.num_input_ports - assert ir.graph_summary.output_count == ir.num_output_ports - assert ir.graph_summary.recurrent_count == ir.num_recurrent_cells - assert ir.graph_summary.flat_signature.node_count == ir.num_cells - assert ir.graph_summary.flat_signature.degree_histogram == ir.graph_summary.degree_histogram + batch, receivers, input_dim, output_dim = 2, 3, 4, 5 + input = torch.randn(batch, receivers, input_dim, device="cuda") + weight = torch.randn(receivers, input_dim, output_dim, device="cuda") + bias = torch.randn(receivers, output_dim, device="cuda") + grad_output = torch.randn(batch, receivers, output_dim, device="cuda") + grad_input_placeholder = input.new_empty(0) + grad_weight_placeholder = weight.new_empty(0) + grad_bias_placeholder = bias.new_empty(0) + primitive_opcode = temporal_primitive_opcode("linear") + primitive_rows = torch.tensor([[primitive_opcode, 0, receivers, 0]], dtype=torch.long) + reverse_executor_rows = torch.tensor([[7, 0, 1, 0, 0, receivers]], dtype=torch.long) + native_strategy_rows = temporal_native_executor_strategy_rows_tensor() + strategy_row = next( + row + for row in native_strategy_rows.tolist() + if int(row[0]) == 2 and int(row[2]) == 7 and int(row[4]) == primitive_opcode + ) + reverse_handler_rows = torch.tensor( + [ + [ + strategy_row[2], + strategy_row[1], + strategy_row[3], + strategy_row[4], + strategy_row[5], + strategy_row[6], + strategy_row[7], + strategy_row[12], + strategy_row[13], + strategy_row[14], + strategy_row[15], + ] + ], + dtype=torch.long, + ) + reverse_binding_rows = torch.tensor( + [ + [2, 0, 7, 0, 0, 0, 0, 0], + [2, 0, 7, 0, 1, 0, 0, 1], + [2, 0, 7, 0, 2, 0, 1, 0], + [2, 0, 7, 0, 
3, 0, 1, 1], + [2, 0, 7, 0, 4, 0, 2, 0], + [2, 0, 7, 0, 5, 0, 2, 1], + [2, 0, 7, 0, 6, 0, 2, 2], + ], + dtype=torch.long, + ) + program_tensor_binding_rows = torch.tensor( + [ + [0, 0, 0, 2], + [1, 1, 0, 0], + [2, 2, 0, 1], + [3, 3, 0, 1], + [4, 4, 0, 2], + [5, 5, 0, 2], + [6, 6, 0, 2], + ], + dtype=torch.long, + ) + memory_liveness_rows = torch.tensor( + [ + [0, 0, 0, 4, 13, 4, 8, 1, 5, 1], + [1, 0, 0, 4, 13, 4, 8, 4, 5, 1], + [2, 0, 0, 4, 13, 4, 8, 6, 5, 1], + [3, 0, 0, 4, 13, 4, 8, 9, 5, 1], + [4, 0, 0, 4, 13, 4, 8, 11, 5, 1], + [5, 0, 0, 4, 13, 4, 8, 12, 5, 1], + ], + dtype=torch.long, + ) + outputs = registered_temporal_fused_reverse_transition_program_cuda( + program_tensors=( + input, + grad_output, + weight, + bias, + grad_input_placeholder, + grad_weight_placeholder, + grad_bias_placeholder, + ), + program_tensor_binding_rows=program_tensor_binding_rows, + primitive_rows=primitive_rows, + reverse_executor_rows=reverse_executor_rows, + reverse_handler_rows=reverse_handler_rows, + native_strategy_rows=native_strategy_rows, + native_callable_binding_schema_rows=temporal_native_callable_binding_schema_rows_tensor(), + native_callable_output_rows=temporal_native_callable_output_rows_tensor(), + transition_primitive_callable_rows=temporal_transition_primitive_native_callable_rows_tensor(), + reverse_executor_binding_rows=reverse_binding_rows, + memory_liveness_rows=memory_liveness_rows, + ) -def test_explicit_graph_uses_sparse_message_backend_without_patch_edges() -> None: - runtime = build(_make_explicit_graph_spec("slstm")) + ref_input = input.detach().clone().requires_grad_(True) + ref_weight = weight.detach().clone().requires_grad_(True) + ref_bias = bias.detach().clone().requires_grad_(True) + torch.einsum("brk,rkn->brn", ref_input, ref_weight).add(ref_bias.unsqueeze(0)).backward(grad_output) - assert runtime.config.patch_edges_per_cell == 0 - assert not runtime._local_message_step_enabled - assert runtime._uses_sparse_message_backend + assert len(outputs) == 7 + torch.testing.assert_close(outputs[4], ref_input.grad, rtol=1.0e-5, atol=1.0e-5) + torch.testing.assert_close(outputs[5], ref_weight.grad, rtol=1.0e-5, atol=1.0e-5) + torch.testing.assert_close(outputs[6], ref_bias.grad, rtol=1.0e-5, atol=1.0e-5) -def test_default_message_rule_contract_is_planner_visible() -> None: - runtime = build(_make_explicit_graph_spec("slstm")) - message_rule = runtime.backend_ir.message_rule +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for fused registered program executor") +def test_fused_reverse_transition_program_cuda_dispatches_registered_matmul_callable() -> None: + from cortical.fabric.backend.cuda.sequence_surface.flat_bucket.flat_bucket_registered_program_cuda import ( + registered_temporal_fused_reverse_transition_program_cuda, + ) - assert message_rule.name == "dot_product" - assert message_rule.lowering_kind == "dot_product_segment_softmax_weighted_sum" - assert message_rule.output_boundary == "projected_message" - assert "sender_public_prev:reset=zero_source_rows:scope=batch_row" in message_rule.source_signature - assert any("sender_group_shared" in entry for entry in message_rule.parameter_sharing_signature) + batch, receivers, gates, heads, head_dim = 2, 3, 4, 2, 3 + hidden = heads * head_dim + input = torch.randn(batch, receivers, hidden, device="cuda") + weight = torch.randn(receivers, gates, heads, head_dim, head_dim, device="cuda") + grad_output = torch.randn(batch, receivers, gates, hidden, device="cuda") + grad_input_placeholder = 
input.new_empty(0) + grad_weight_placeholder = weight.new_empty(0) + primitive_opcode = temporal_primitive_opcode("matmul") + primitive_rows = torch.tensor([[primitive_opcode, 0, receivers, 0]], dtype=torch.long) + reverse_executor_rows = torch.tensor([[8, 0, 1, 0, 0, receivers]], dtype=torch.long) + native_strategy_rows = temporal_native_executor_strategy_rows_tensor() + strategy_row = next( + row + for row in native_strategy_rows.tolist() + if int(row[0]) == 2 and int(row[2]) == 8 and int(row[4]) == primitive_opcode + ) + reverse_handler_rows = torch.tensor( + [ + [ + strategy_row[2], + strategy_row[1], + strategy_row[3], + strategy_row[4], + strategy_row[5], + strategy_row[6], + strategy_row[7], + strategy_row[12], + strategy_row[13], + strategy_row[14], + strategy_row[15], + ] + ], + dtype=torch.long, + ) + reverse_binding_rows = torch.tensor( + [ + [2, 0, 8, 0, 0, 0, 0, 0], + [2, 0, 8, 0, 1, 0, 0, 1], + [2, 0, 8, 0, 2, 0, 1, 0], + [2, 0, 8, 0, 3, 0, 2, 0], + [2, 0, 8, 0, 4, 0, 2, 1], + ], + dtype=torch.long, + ) + program_tensor_binding_rows = torch.tensor( + [ + [0, 0, 0, 2], + [1, 1, 0, 0], + [2, 2, 0, 1], + [3, 3, 0, 2], + [4, 4, 0, 2], + ], + dtype=torch.long, + ) + memory_liveness_rows = torch.tensor( + [ + [0, 0, 0, 4, 13, 4, 8, 1, 5, 1], + [1, 0, 0, 4, 13, 4, 8, 4, 5, 1], + [2, 0, 0, 4, 13, 4, 8, 6, 5, 1], + [3, 0, 0, 4, 13, 4, 8, 9, 5, 1], + [4, 0, 0, 4, 13, 4, 8, 11, 5, 1], + [5, 0, 0, 4, 13, 4, 8, 12, 5, 1], + ], + dtype=torch.long, + ) - planned = runtime.plan_backend_execution( - batch_size=1, - time_steps=1, - inner_steps=1, - training=False, - surface_key="slstm_recurrence", + outputs = registered_temporal_fused_reverse_transition_program_cuda( + program_tensors=( + input, + grad_output, + weight, + grad_input_placeholder, + grad_weight_placeholder, + ), + program_tensor_binding_rows=program_tensor_binding_rows, + primitive_rows=primitive_rows, + reverse_executor_rows=reverse_executor_rows, + reverse_handler_rows=reverse_handler_rows, + native_strategy_rows=native_strategy_rows, + native_callable_binding_schema_rows=temporal_native_callable_binding_schema_rows_tensor(), + native_callable_output_rows=temporal_native_callable_output_rows_tensor(), + transition_primitive_callable_rows=temporal_transition_primitive_native_callable_rows_tensor(), + reverse_executor_binding_rows=reverse_binding_rows, + memory_liveness_rows=memory_liveness_rows, ) - assert planned.bucket_plans - assert all( - plan.message_rule_lowering_kind == "dot_product_segment_softmax_weighted_sum" for plan in planned.bucket_plans + + ref_input = input.detach().clone().requires_grad_(True) + ref_weight = weight.detach().clone().requires_grad_(True) + torch.einsum( + "brhi,rghoi->brgho", + ref_input.view(batch, receivers, heads, head_dim), + ref_weight, + ).reshape(batch, receivers, gates, hidden).backward(grad_output) + + assert len(outputs) == 5 + torch.testing.assert_close(outputs[3], ref_input.grad, rtol=1.0e-5, atol=1.0e-5) + torch.testing.assert_close(outputs[4], ref_weight.grad, rtol=1.0e-5, atol=1.0e-5) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for fused registered program executor") +def test_fused_reverse_transition_program_cuda_dispatches_registered_norm_callable() -> None: + from cortical.fabric.backend.cuda.sequence_surface.flat_bucket.flat_bucket_registered_program_cuda import ( + registered_temporal_fused_reverse_transition_program_cuda, ) - assert all(plan.message_rule_name == "dot_product" for plan in planned.bucket_plans) - assert 
all(plan.message_rule_output_boundary == "projected_message" for plan in planned.bucket_plans) + + batch, receivers, hidden = 2, 3, 7 + eps = 1.0e-5 + input = torch.randn(batch, receivers, hidden, device="cuda") + weight = torch.randn(receivers, hidden, device="cuda") + grad_output = torch.randn_like(input) + eps_tensor = input.new_tensor(eps) + grad_input_placeholder = input.new_empty(0) + grad_weight_placeholder = weight.new_empty(0) + primitive_opcode = temporal_primitive_opcode("norm_or_identity") + primitive_rows = torch.tensor([[primitive_opcode, 0, receivers, 0]], dtype=torch.long) + reverse_executor_rows = torch.tensor([[9, 0, 1, 0, 0, receivers]], dtype=torch.long) + native_strategy_rows = temporal_native_executor_strategy_rows_tensor() + strategy_row = next( + row + for row in native_strategy_rows.tolist() + if int(row[0]) == 2 and int(row[2]) == 9 and int(row[4]) == primitive_opcode + ) + reverse_handler_rows = torch.tensor( + [ + [ + strategy_row[2], + strategy_row[1], + strategy_row[3], + strategy_row[4], + strategy_row[5], + strategy_row[6], + strategy_row[7], + strategy_row[12], + strategy_row[13], + strategy_row[14], + strategy_row[15], + ] + ], + dtype=torch.long, + ) + reverse_binding_rows = torch.tensor( + [ + [2, 0, 9, 0, 0, 0, 0, 0], + [2, 0, 9, 0, 1, 0, 0, 1], + [2, 0, 9, 0, 2, 0, 1, 0], + [2, 0, 9, 0, 3, 0, 1, 1], + [2, 0, 9, 0, 4, 0, 2, 0], + [2, 0, 9, 0, 5, 0, 2, 1], + ], + dtype=torch.long, + ) + program_tensor_binding_rows = torch.tensor( + [ + [0, 0, 0, 2], + [1, 1, 0, 0], + [2, 2, 0, 1], + [3, 3, 0, 1], + [4, 4, 0, 2], + [5, 5, 0, 2], + ], + dtype=torch.long, + ) + memory_liveness_rows = torch.tensor( + [ + [0, 0, 0, 4, 13, 4, 8, 1, 5, 1], + [1, 0, 0, 4, 13, 4, 8, 4, 5, 1], + [2, 0, 0, 4, 13, 4, 8, 6, 5, 1], + [3, 0, 0, 4, 13, 4, 8, 9, 5, 1], + [4, 0, 0, 4, 13, 4, 8, 11, 5, 1], + [5, 0, 0, 4, 13, 4, 8, 12, 5, 1], + ], + dtype=torch.long, + ) + + outputs = registered_temporal_fused_reverse_transition_program_cuda( + program_tensors=( + input, + grad_output, + weight, + eps_tensor, + grad_input_placeholder, + grad_weight_placeholder, + ), + program_tensor_binding_rows=program_tensor_binding_rows, + primitive_rows=primitive_rows, + reverse_executor_rows=reverse_executor_rows, + reverse_handler_rows=reverse_handler_rows, + native_strategy_rows=native_strategy_rows, + native_callable_binding_schema_rows=temporal_native_callable_binding_schema_rows_tensor(), + native_callable_output_rows=temporal_native_callable_output_rows_tensor(), + transition_primitive_callable_rows=temporal_transition_primitive_native_callable_rows_tensor(), + reverse_executor_binding_rows=reverse_binding_rows, + memory_liveness_rows=memory_liveness_rows, + ) + + ref_input = input.detach().clone().requires_grad_(True) + ref_weight = weight.detach().clone().requires_grad_(True) + _norm_or_identity_reference(ref_input, ref_weight, eps=eps).backward(grad_output) + + assert len(outputs) == 6 + torch.testing.assert_close(outputs[4], ref_input.grad, rtol=1.0e-4, atol=1.0e-4) + torch.testing.assert_close(outputs[5], ref_weight.grad, rtol=1.0e-4, atol=1.0e-4) + + +def _diag_activation(value: torch.Tensor, activation_id: int) -> torch.Tensor: + if activation_id == 0: + return value * torch.sigmoid(value) + if activation_id == 1: + return torch.relu(value) + if activation_id == 2: + return torch.tanh(value) + return value + + +def _diag_rtu_core_reference( + cell_input: torch.Tensor, + hc1: torch.Tensor, + hc2: torch.Tensor, + nu_log: torch.Tensor, + theta_log: torch.Tensor, + w1: torch.Tensor, + w2: 
torch.Tensor, + *, + activation_id: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + exp_nu = torch.exp(nu_log).unsqueeze(0) + radius = torch.exp(-exp_nu) + theta = torch.exp(theta_log).unsqueeze(0) + g = radius * torch.cos(theta) + phi = radius * torch.sin(theta) + gamma = torch.sqrt(torch.clamp(1.0 - radius * radius, min=0.0)) + c1 = gamma * w1.unsqueeze(0) * cell_input + g * hc1 - phi * hc2 + c2 = gamma * w2.unsqueeze(0) * cell_input + g * hc2 + phi * hc1 + preproj = torch.cat((_diag_activation(c1, activation_id), _diag_activation(c2, activation_id)), dim=-1) + return preproj, c1, c2 + + +def _gate_affine_reference( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: + batch, receivers, hidden = input.shape + heads = int(weight.shape[1]) + head_dim = int(weight.shape[2]) + flat_input = input.view(batch, receivers, heads, head_dim).reshape(batch, receivers * heads, head_dim) + flat_weight = weight.reshape(receivers * heads, head_dim, 4 * head_dim) + flat_bias = bias.permute(0, 2, 1, 3).reshape(receivers * heads, 4 * head_dim) + gate_proj = torch.einsum("bsk,sko->bso", flat_input, flat_weight) + flat_bias.unsqueeze(0) + return ( + gate_proj.view(batch, receivers, heads, 4, head_dim).permute(0, 1, 3, 2, 4).reshape(batch, receivers, 4, hidden) + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for fused registered program executor") +def test_fused_forward_transition_program_cuda_dispatches_compiler_tensor_bindings() -> None: + from cortical.fabric.backend.cuda.sequence_surface.flat_bucket.flat_bucket_registered_program_cuda import ( + registered_temporal_fused_forward_transition_program_cuda, + ) + + batch, receivers, value_dim, heads, head_dim = 2, 3, 5, 2, 4 + hidden = heads * head_dim + eps = 1.0e-5 + recurrent_msg = torch.randn(batch, receivers, value_dim, device="cuda") + value_to_state_weight = torch.randn(receivers, value_dim, hidden, device="cuda") + recurrent_bias = torch.randn(receivers, hidden, device="cuda") + gate_weight = torch.randn(receivers, heads, head_dim, 4 * head_dim, device="cuda") + gate_bias = torch.randn(receivers, 4, heads, head_dim, device="cuda") + y_prev = torch.randn(batch, receivers, hidden, device="cuda") + recurrent_kernel = torch.randn(receivers, 4, heads, head_dim, head_dim, device="cuda") + c_prev = torch.randn(batch, receivers, hidden, device="cuda") + n_prev = torch.rand(batch, receivers, hidden, device="cuda") + 0.25 + m_prev = torch.randn(batch, receivers, hidden, device="cuda") + outnorm_weight = torch.randn(receivers, hidden, device="cuda") + outnorm_eps = torch.tensor(eps, device="cuda") + empty = recurrent_msg.new_empty(0) + primitive_rows = torch.tensor( + [ + [10, 0, receivers, 0], + [10, 0, receivers, 0], + [11, 0, receivers, 0], + [1, 0, receivers, 0], + [30, 0, receivers, 0], + ], + dtype=torch.long, + ) + forward_executor_rows = torch.tensor([[3, 0, 5, 0, 0, receivers]], dtype=torch.long) + strategy_hash = temporal_strategy_id_hash("forward.transition.gated_logspace.v1") + native_callable_hash = temporal_strategy_id_hash("native.forward.transition_gated_logspace.v1") + forward_handler_rows = torch.tensor([[3, 4, 3, 10, 5, 4, 105, strategy_hash, 2, 4, 0]], dtype=torch.long) + native_strategy_rows = torch.tensor( + [[1, 4, 3, 3, 10, 5, 4, 105, 1, 1, 1, 1, strategy_hash, 2, 4, 0, native_callable_hash]], + dtype=torch.long, + ) + forward_binding_rows = torch.tensor( + [ + [1, 0, 3, 0, 0, 0, 0, 0], + [1, 0, 3, 0, 1, 0, 1, 0], + [1, 0, 3, 0, 2, 0, 1, 1], + [1, 0, 
3, 0, 3, 0, 2, 0], + [1, 0, 3, 1, 4, 0, 0, 0], + [1, 0, 3, 1, 5, 0, 1, 0], + [1, 0, 3, 1, 6, 0, 1, 1], + [1, 0, 3, 1, 7, 0, 2, 0], + [1, 0, 3, 2, 8, 0, 0, 0], + [1, 0, 3, 2, 9, 0, 1, 0], + [1, 0, 3, 2, 10, 0, 2, 0], + [1, 0, 3, 3, 11, 0, 0, 0], + [1, 0, 3, 3, 12, 0, 0, 1], + [1, 0, 3, 3, 13, 0, 0, 2], + [1, 0, 3, 3, 14, 0, 0, 3], + [1, 0, 3, 3, 15, 0, 0, 4], + [1, 0, 3, 3, 16, 0, 2, 0], + [1, 0, 3, 3, 17, 0, 2, 1], + [1, 0, 3, 3, 18, 0, 2, 2], + [1, 0, 3, 3, 19, 0, 2, 3], + [1, 0, 3, 4, 20, 0, 0, 0], + [1, 0, 3, 4, 21, 0, 1, 0], + [1, 0, 3, 4, 23, 0, 1, 1], + [1, 0, 3, 4, 22, 0, 2, 0], + ], + dtype=torch.long, + ) + program_tensor_binding_rows = torch.tensor( + [ + [0, 0, 0, 0], + [1, 1, 0, 1], + [2, 2, 0, 1], + [3, 3, 0, 2], + [4, 3, 1, 0], + [5, 4, 1, 1], + [6, 5, 1, 1], + [7, 6, 1, 2], + [8, 7, 2, 0], + [9, 8, 2, 1], + [10, 9, 2, 2], + [11, 6, 3, 0], + [12, 9, 3, 0], + [13, 10, 3, 0], + [14, 11, 3, 0], + [15, 12, 3, 0], + [16, 13, 3, 2], + [17, 14, 3, 2], + [18, 15, 3, 2], + [19, 16, 3, 2], + [20, 13, 4, 0], + [21, 17, 4, 1], + [22, 18, 4, 2], + [23, 19, 4, 1], + ], + dtype=torch.long, + ) + memory_liveness_rows = torch.tensor( + [ + [0, 0, 0, 4, 9, 6, 8, 9, 3, 1], + [1, 1, 0, 4, 3, 4, 8, 4, 7, 1], + [2, 2, 0, 4, 10, 6, 8, 11, 3, 1], + [3, 3, 0, 4, 11, 1, 8, 12, 8, 1], + [4, 4, 0, 4, 11, 1, 8, 12, 8, 1], + ], + dtype=torch.long, + ) + runtime_buffer_tensors: list[torch.Tensor] = [] + runtime_buffer_rows: list[list[int]] = [] + + def append_runtime_buffer( + memory_row: list[int], tensor: torch.Tensor, runtime_role: int, logical_index: int + ) -> None: + runtime_buffer_tensors.append(tensor) + runtime_buffer_rows.append( + [ + len(runtime_buffer_rows), + int(memory_row[0]), + int(memory_row[6]), + int(memory_row[3]), + int(memory_row[2]), + int(memory_row[7]), + len(runtime_buffer_rows), + 0, + runtime_role, + logical_index, + ] + ) + + for memory_row in memory_liveness_rows.tolist(): + memory_row_index = int(memory_row[0]) + primitive_row_index = int(memory_row[1]) + if primitive_row_index == 0: + append_runtime_buffer(memory_row, recurrent_msg.new_empty(batch, receivers, hidden), 7, primitive_row_index) + elif primitive_row_index == 1: + append_runtime_buffer( + memory_row, + recurrent_msg.new_empty(batch, receivers, 4, hidden), + 7, + primitive_row_index, + ) + elif primitive_row_index == 2: + append_runtime_buffer( + memory_row, + recurrent_msg.new_empty(batch, receivers, 4, hidden), + 8, + primitive_row_index, + ) + elif primitive_row_index == 3: + for binding_index in (16, 17, 18, 19): + append_runtime_buffer( + memory_row, + recurrent_msg.new_empty(batch, receivers, hidden), + 9, + binding_index, + ) + elif primitive_row_index == 4: + append_runtime_buffer( + memory_row, + recurrent_msg.new_empty(batch, receivers, hidden), + 10, + primitive_row_index, + ) + else: + append_runtime_buffer(memory_row, recurrent_msg.new_empty(1), 0, memory_row_index) + + outputs = registered_temporal_fused_forward_transition_program_cuda( + program_tensors=( + recurrent_msg, + value_to_state_weight, + recurrent_bias, + empty, + gate_weight, + gate_bias, + empty, + y_prev, + recurrent_kernel, + empty, + c_prev, + n_prev, + m_prev, + empty, + empty, + empty, + empty, + outnorm_weight, + empty, + outnorm_eps, + ), + program_tensor_binding_rows=program_tensor_binding_rows, + runtime_buffer_tensors=tuple(runtime_buffer_tensors), + runtime_buffer_rows=torch.tensor(runtime_buffer_rows, dtype=torch.long), + primitive_rows=primitive_rows, + forward_executor_rows=forward_executor_rows, + 
forward_handler_rows=forward_handler_rows, + native_strategy_rows=native_strategy_rows, + native_callable_binding_schema_rows=temporal_native_callable_binding_schema_rows_tensor(), + native_callable_output_rows=temporal_native_callable_output_rows_tensor(), + transition_primitive_callable_rows=temporal_transition_primitive_native_callable_rows_tensor(), + forward_executor_binding_rows=forward_binding_rows, + memory_liveness_rows=memory_liveness_rows, + ) + transition_input = torch.einsum("brv,rvh->brh", recurrent_msg, value_to_state_weight) + recurrent_bias.unsqueeze(0) + gate_logits = _gate_affine_reference(transition_input, gate_weight, gate_bias) + recurrent_gate_logits = torch.einsum( + "brhi,rghoi->brgho", + y_prev.view(batch, receivers, heads, head_dim), + recurrent_kernel, + ).reshape(batch, receivers, 4, hidden) + gated_expected = _gated_logspace_core_reference( + gate_logits, + recurrent_gate_logits, + c_prev, + n_prev, + m_prev, + ) + norm_expected = _norm_or_identity_reference(gated_expected[0], outnorm_weight, eps=eps) + + assert len(outputs) == 20 + torch.testing.assert_close(outputs[3], transition_input, rtol=1.0e-5, atol=1.0e-5) + torch.testing.assert_close(outputs[6], gate_logits, rtol=1.0e-5, atol=1.0e-5) + torch.testing.assert_close(outputs[9], recurrent_gate_logits, rtol=1.0e-5, atol=1.0e-5) + for actual_tensor, expected_tensor in zip(outputs[13:17], gated_expected, strict=True): + torch.testing.assert_close(actual_tensor, expected_tensor, rtol=1.0e-5, atol=1.0e-5) + torch.testing.assert_close(outputs[18], norm_expected, rtol=1.0e-5, atol=1.0e-5) + + +def test_readout_rule_compiler_rejects_unsupported_rule() -> None: + with pytest.raises(ValueError, match="Unsupported Fabric readout rule"): + compile_readout_rule(ReadoutRuleIR(name="bad_readout", pool="unsupported", readout_slots=1)) def test_default_message_rule_has_python_semantic_equivalent() -> None: @@ -170,13 +6587,21 @@ def test_default_message_rule_has_python_semantic_equivalent() -> None: assert message_rule.name == "dot_product" assert message_rule.output_boundary == "projected_message" - assert summary.lowering_kind == "dot_product_segment_softmax_weighted_sum" + assert summary.lowering_kind == "dot_product_fixed_slot_context_nudge" assert summary.source_signature == ( "receiver_slot", + "receiver_public_prev:reset=zero_source_rows:scope=batch_row", + "sender_slot", "sender_public_prev:reset=zero_source_rows:scope=batch_row", "edge_distance", ) - assert "k_weight:projection:sender_group_shared:groups=4" in summary.parameter_sharing_signature + assert "message_query_nudge_scale:rule_scalar:fabric_global" in summary.parameter_sharing_signature + assert "input_sender_value_weight:projection:sender_group_shared:groups=4" in summary.parameter_sharing_signature + assert ( + "recurrent_sender_value_weight:projection:sender_group_shared:groups=4" in summary.parameter_sharing_signature + ) + assert "context_nudge(receiver_public_prev)" in summary.expression_signature + assert "normalize->projected_message" in summary.expression_signature def test_sparse_message_overlay_does_not_disqualify_receiver_major_transition() -> None: @@ -195,11 +6620,12 @@ def test_sparse_message_overlay_does_not_disqualify_receiver_major_transition() assert all(plan.execution_family == ExecutionFamily.RECEIVER_MAJOR for plan in planned.bucket_plans) -def test_cuda_message_backend_consumes_flat_sender_tables_not_geometry() -> None: - kernel_path = ( - Path(__file__).parents[1] / 
"src/cortical/fabric/backend/cuda/message_passing/local_message_kernels.cu" +def test_registered_message_strategies_consume_compiler_sender_tables_not_geometry() -> None: + strategy_path = ( + Path(__file__).parents[1] + / "src/cortical/fabric/backend/cuda/sequence_surface/flat_bucket/registered_program/native_callables/message_forward_strategies.cuh" ) - text = kernel_path.read_text() + text = strategy_path.read_text() assert "offset_flat_index" not in text assert "coord_dim" not in text @@ -211,17 +6637,46 @@ def test_cuda_message_backend_consumes_flat_sender_tables_not_geometry() -> None def test_fabric_cell_registration_separates_shared_specs_from_backend_implementations() -> None: slstm_spec = get_cell_spec("slstm") axon_spec = get_cell_spec("axoncell") + slstm_backend_spec = build_cell_backend_spec( + cell_type="slstm", + hidden_size=32, + d_public=32, + d_msg=32, + head_dim=4, + value_dim=4, + ) + axon_backend_spec = build_cell_backend_spec( + cell_type="axoncell", + hidden_size=32, + d_public=32, + d_msg=32, + head_dim=4, + value_dim=4, + ) slstm_pytorch = get_cell_backend_implementation("slstm", "pytorch") - slstm_cuda = get_cell_backend_implementation("slstm", "cuda") axon_pytorch = get_cell_backend_implementation("axoncell", "pytorch") - axon_cuda = get_cell_backend_implementation("axoncell", "cuda") assert slstm_spec.cell_kind == 0 assert axon_spec.cell_kind == 1 assert slstm_pytorch.lower_transition_op is not None assert axon_pytorch.lower_transition_op is not None - assert slstm_cuda.metadata == {"native_cell_kind": slstm_spec.cell_kind} - assert axon_cuda.metadata == {"native_cell_kind": axon_spec.cell_kind} + assert tuple(op.name for op in slstm_backend_spec.transition_ir.ops) == ( + "linear", + "linear", + "matmul", + "gated_logspace_recurrence", + "norm_or_identity", + ) + assert tuple(op.name for op in axon_backend_spec.transition_ir.ops) == ( + "linear", + "diag_rtu", + "linear", + "norm_or_identity", + ) + with pytest.raises(ValueError, match="backend=cuda"): + get_cell_backend_implementation("slstm", "cuda") + with pytest.raises(ValueError, match="backend=cuda"): + get_cell_backend_implementation("axoncell", "cuda") def test_fabric_backend_plan_cache_is_bounded_and_reused() -> None: @@ -235,6 +6690,327 @@ def test_fabric_backend_plan_cache_is_bounded_and_reused() -> None: assert all(plan.candidate_count <= 3 for plan in first.bucket_plans) +def test_fabric_temporal_execution_plan_records_output_request_schedule() -> None: + runtime = build(_make_slstm_spec()) + + planned = runtime.plan_temporal_execution( + batch_size=2, + time_steps=5, + k=3, + training=True, + device=torch.device("cuda"), + output_boundary="sequence", + materialize_final_state=False, + state_is_fresh=True, + has_resets=True, + ) + + assert planned.supported + assert planned.schedule.schedule_kind == "scalar_constant_k" + assert planned.schedule.outer_time_steps == 5 + assert planned.schedule.inner_steps == 3 + assert planned.schedule.total_scan_steps == 15 + assert planned.output_request.selector_kind == "all_outer_steps" + assert planned.output_request.first_outer_step == 0 + assert planned.output_request.outer_stride == 1 + assert planned.output_request.emitted_output_count == 5 + assert planned.output_request.first_physical_step == 2 + assert planned.output_request.physical_stride == 3 + assert planned.output_request.output_surface == "full_cells" + assert planned.output_request.materialization == "outputs_only" + assert planned.output_request.autograd_seed_kind == "emitted_output_grad" + assert 
"transition_primitive_adjoint" in planned.output_request.required_backward_surfaces + assert planned.output_request.checkpoint_policy_basis == "emitted_output_schedule" + assert planned.gradient_boundary.mode == "full_horizon" + assert planned.checkpoint.checkpoint_kind == "planner_default" + assert planned.checkpoint.checkpoint_steps == 15 + assert planned.executor.backend_name == "cuda" + assert planned.executor.selected_implementation == "registered_temporal_program" + assert planned.static_values.static_value_mode == "detached_shared_values" + assert planned.static_values.native_static_materialization + assert not planned.static_values.include_full_cell_kv_weight + assert planned.static_values.detach_training_static_tensors + assert planned.engine.forward_owner == "registered_fused_forward_program_cuda" + assert planned.engine.backward_owner == "registered_reverse_executor_bindings" + assert planned.engine.target_owner == "registered_temporal_executor_bindings" + assert planned.engine.status == "registered_executor_bindings" + assert planned.sequence_surface_route.uses_registered_temporal_program + assert planned.boundary.resets == "present" + assert planned.carry.initial_state == "fresh" + assert not planned.carry.materialize_final_state + assert not planned.carry.fresh_state_population_cache + + +def test_cuda_temporal_runtime_scheduler_consumes_planner_records() -> None: + runtime = build(_make_slstm_spec()) + + planned = runtime.plan_temporal_execution( + batch_size=2, + time_steps=5, + k=3, + training=True, + device=torch.device("cuda"), + output_boundary="sequence", + materialize_final_state=False, + gradient_horizon_steps=6, + ) + + scheduler = build_temporal_runtime_scheduler_plan( + temporal_plan=planned, + outer_time_steps=5, + inner_steps=3, + output_boundary="sequence", + output_contract="output_cells", + materialize_final_state=False, + collect_artifacts=True, + ) + + assert scheduler.owner == "planner" + assert scheduler.physical_time_steps == 15 + assert scheduler.output_emissions.physical_to_output_index == ((2, 0), (5, 1), (8, 2), (11, 3), (14, 4)) + assert scheduler.output_emissions.autograd_seed_kind == "emitted_output_grad" + assert scheduler.checkpoint.backward_window_kind == "rolling_horizon" + assert scheduler.checkpoint.backward_window_steps == 6 + assert scheduler.materialization.reverse_artifact_kind == planned.materialization.reverse_artifact_kind + grad_output = torch.ones(2, 5, 1, 4) + replay = scheduler.replay_request_for_window( + grad_output_seq=grad_output, + window_start=4, + window_end=10, + include_final_state_output=True, + ) + assert replay.output_message_physical_steps == (5, 8, 9) + assert replay.final_state_physical_step == 9 + + +def test_cuda_temporal_runtime_scheduler_uses_terminal_output_plan() -> None: + runtime = build(_make_slstm_spec()) + + planned = runtime.plan_temporal_execution( + batch_size=2, + time_steps=5, + k=3, + training=True, + device=torch.device("cuda"), + output_boundary="terminal", + materialize_final_state=False, + ) + + scheduler = build_temporal_runtime_scheduler_plan( + temporal_plan=planned, + outer_time_steps=5, + inner_steps=3, + output_boundary="sequence", + output_contract="output_cells", + materialize_final_state=False, + collect_artifacts=True, + ) + + assert scheduler.output_emissions.selector_kind == "terminal_outer_step" + assert scheduler.output_emissions.physical_to_output_index == ((14, 0),) + assert scheduler.output_emissions.active_local_steps( + torch.ones(2, 1, 1, 4), + window_start=12, + 
window_len=3, + ) == (2,) + + +def test_fabric_temporal_execution_plan_represents_future_k_schedule_shape() -> None: + runtime = build(_make_slstm_spec()) + k_schedule = torch.ones(2, 4, dtype=torch.long) + + planned = runtime.plan_temporal_execution( + batch_size=2, + time_steps=4, + k=k_schedule, + training=False, + output_boundary="terminal", + state_is_fresh=False, + ) + + assert planned.schedule.schedule_kind == "runtime_variable_k" + assert planned.schedule.inner_steps is None + assert planned.schedule.total_scan_steps is None + assert planned.schedule.per_timestep_k_semantic == "represented_not_lowered" + assert planned.output_request.selector_kind == "terminal_outer_step" + assert planned.output_request.first_outer_step == 3 + assert planned.output_request.outer_stride is None + assert planned.output_request.emitted_output_count == 1 + assert planned.output_request.first_physical_step is None + assert planned.output_request.autograd_seed_kind == "none" + assert planned.output_request.required_backward_surfaces == () + assert planned.output_request.checkpoint_policy_basis == "inference" + assert planned.gradient_boundary.mode == "inference" + assert planned.backward_window.window_kind == "none" + assert planned.carry.initial_state == "provided" + assert planned.executor.backend_name == "pytorch" + assert planned.static_values.static_value_mode == "inference_cache" + assert not planned.static_values.native_static_materialization + assert planned.static_values.include_full_cell_kv_weight + assert planned.engine.forward_owner == "pytorch_reference" + assert planned.engine.backward_owner == "none" + assert planned.engine.status == "pytorch_reference" + + +def test_fabric_temporal_execution_plan_records_horizon_and_checkpoint_owner() -> None: + runtime = build(_make_slstm_spec()) + + planned = runtime.plan_temporal_execution( + batch_size=2, + time_steps=16, + k=4, + training=True, + device=torch.device("cuda"), + gradient_horizon_steps=64, + checkpoint_steps=8, + ) + + assert planned.gradient_boundary.mode == "full_horizon" + assert planned.gradient_boundary.horizon_steps == 64 + assert planned.gradient_boundary.owner == "planner" + assert planned.backward_window.window_kind == "full_horizon" + assert planned.backward_window.max_window_steps == 64 + assert planned.backward_window.owner == "planner" + assert planned.checkpoint.checkpoint_kind == "explicit" + assert planned.checkpoint.checkpoint_steps == 8 + assert planned.checkpoint.owner == "planner" + + +def test_fabric_temporal_execution_plan_defaults_checkpoint_to_horizon_when_not_provided() -> None: + runtime = build(_make_slstm_spec()) + + planned = runtime.plan_temporal_execution( + batch_size=2, + time_steps=16, + k=4, + training=True, + device=torch.device("cuda"), + gradient_horizon_steps=64, + ) + + assert planned.gradient_boundary.mode == "full_horizon" + assert planned.checkpoint.checkpoint_kind == "planner_default" + assert planned.checkpoint.checkpoint_steps == 64 + assert planned.checkpoint.owner == "planner" + + +def test_fabric_temporal_execution_plan_clips_horizon_to_available_stream() -> None: + runtime = build(_make_slstm_spec()) + + planned = runtime.plan_temporal_execution( + batch_size=2, + time_steps=3, + k=2, + training=True, + device=torch.device("cuda"), + gradient_horizon_steps=64, + ) + + assert planned.schedule.total_scan_steps == 6 + assert planned.gradient_boundary.mode == "full_horizon" + assert planned.gradient_boundary.horizon_steps == 6 + assert planned.backward_window.window_kind == "full_horizon" + 
assert planned.backward_window.max_window_steps == 6 + assert planned.checkpoint.checkpoint_kind == "planner_default" + assert planned.checkpoint.checkpoint_steps == 6 + + +def test_fabric_temporal_execution_plan_keeps_rolling_horizon_when_stream_exceeds_horizon() -> None: + runtime = build(_make_slstm_spec()) + + planned = runtime.plan_temporal_execution( + batch_size=2, + time_steps=16, + k=8, + training=True, + device=torch.device("cuda"), + gradient_horizon_steps=64, + ) + + assert planned.schedule.total_scan_steps == 128 + assert planned.gradient_boundary.mode == "rolling_horizon" + assert planned.gradient_boundary.horizon_steps == 64 + assert planned.backward_window.window_kind == "rolling_horizon" + assert planned.backward_window.max_window_steps == 64 + assert planned.checkpoint.checkpoint_steps == 64 + + +def test_fabric_temporal_execution_plan_keeps_single_pop_k1_on_flat_temporal_owner() -> None: + runtime = build(_make_slstm_spec()) + + planned = runtime.plan_temporal_execution( + batch_size=2, + time_steps=5, + k=1, + training=False, + device=torch.device("cuda"), + ) + + assert planned.executor.selected_implementation == "registered_temporal_program" + assert planned.substrate.population_cardinality == "single" + assert planned.substrate.bucket_identity == "flat_bucket_identity" + assert planned.engine.forward_owner == "registered_fused_forward_program_cuda" + assert planned.engine.target_owner == "registered_temporal_executor_bindings" + + +def test_fabric_temporal_execution_plan_records_fresh_multi_population_cache_policy() -> None: + runtime = build(_make_mixed_spec()) + + planned = runtime.plan_temporal_execution( + batch_size=2, + time_steps=8, + k=1, + training=False, + device=torch.device("cuda"), + materialize_final_state=False, + state_is_fresh=True, + ) + + assert planned.carry.fresh_state_population_cache + assert planned.carry.fresh_state_population_cache_reason == "fresh_registered_inference_without_final_state" + + +def test_fabric_temporal_execution_plan_records_fresh_single_population_cache_policy() -> None: + runtime = build(_make_slstm_spec()) + + planned = runtime.plan_temporal_execution( + batch_size=2, + time_steps=1, + k=1, + training=False, + device=torch.device("cuda"), + materialize_final_state=False, + state_is_fresh=True, + ) + + assert planned.carry.fresh_state_population_cache + assert planned.carry.fresh_state_population_cache_reason == "fresh_registered_inference_without_final_state" + + +def test_fabric_backward_plan_can_carry_matching_temporal_plan() -> None: + runtime = build(_make_slstm_spec()) + temporal_plan = runtime.plan_temporal_execution( + batch_size=2, + time_steps=8, + k=2, + training=True, + device=torch.device("cuda"), + ) + + backward_plan = runtime.plan_backend_backward_execution( + batch_size=2, + time_steps=8, + inner_steps=2, + training=True, + device=torch.device("cuda"), + temporal_plan=temporal_plan, + ) + + assert backward_plan.temporal_plan is temporal_plan + assert backward_plan.temporal_plan.schedule.total_scan_steps == 16 + assert backward_plan.temporal_plan.gradient_boundary.mode == "full_horizon" + + def test_fabric_backend_plan_carries_scalable_cuda_launch_fields() -> None: runtime = build(_make_slstm_spec(hidden_size=16)) @@ -274,84 +7050,6 @@ def test_fabric_backend_plan_carries_scalable_cuda_launch_fields() -> None: assert plan.emit_static_stage_mode in {"disabled", "shared_full"} -def _minimal_fabric_execution_request( - *, - readout_mode: str = "separate_port_owned", - cell_static_stage_mode: str = 
"shared_full", - message_rule_lowering_kind: str = "dot_product_segment_softmax_weighted_sum", -) -> FabricExecutionRequest: - value = torch.empty(1) - return FabricExecutionRequest( - population_name="test", - cell_core_spec=get_cell_spec("slstm"), - message_backend_name="local", - message_rule_name="dot_product", - message_rule_lowering_kind=message_rule_lowering_kind, - message_rule_expression_signature="dot_product_segment_softmax_weighted_sum", - message_rule_source_signature="receiver_slot;sender_public_prev;edge_distance", - message_rule_parameter_sharing_signature="q_weight:projection:rule_global", - message_rule_output_boundary="projected_message", - readout_backend_name="output_sequence_from_banks", - gradient_enabled=False, - input_k_seq=value.reshape(1, 1, 1, 1), - input_v_seq=value.reshape(1, 1, 1, 1), - packed_state={}, - initial_hidden=value.reshape(1, 1, 1), - initial_recurrent_k=None, - initial_recurrent_v=None, - initial_state_is_fresh=False, - materialize_final_state=True, - resets_u8=torch.zeros(1, 1, dtype=torch.uint8), - reset_rows_present=False, - stage_receiver_static=True, - replication_factor=1, - receiver_tile=1, - batch_tile=1, - edge_tile=1, - hidden_chunk=1, - state_receiver_tile=1, - state_batch_tile=1, - state_hidden_chunk=1, - state_static_stage_mode=cell_static_stage_mode, - emit_receiver_tile=1, - emit_batch_tile=1, - emit_hidden_chunk=1, - emit_static_stage_mode=cell_static_stage_mode, - public_receiver_tile=1, - public_batch_tile=1, - readout_mode=readout_mode, - readout_port_tile=1, - readout_output_chunk=1, - cell_static_stage_mode=cell_static_stage_mode, - routing_tensors={}, - cell_tensors={}, - readout_tensors={}, - static_config={}, - ) - - -@pytest.mark.parametrize( - ("readout_mode", "cell_static_stage_mode", "message_rule_lowering_kind", "match"), - [ - ("fuse_receiver_owned", "shared_full", "dot_product_segment_softmax_weighted_sum", "readout_mode"), - ("separate_port_owned", "shared_partial", "dot_product_segment_softmax_weighted_sum", "cell_static_stage_mode"), - ("separate_port_owned", "shared_full", "unsupported_rule", "message_rule_lowering_kind"), - ], -) -def test_fabric_cuda_execution_request_fails_closed_on_unimplemented_modes( - readout_mode: str, - cell_static_stage_mode: str, - message_rule_lowering_kind: str, - match: str, -) -> None: - with pytest.raises(ValueError, match=match): - _minimal_fabric_execution_request( - readout_mode=readout_mode, - cell_static_stage_mode=cell_static_stage_mode, - message_rule_lowering_kind=message_rule_lowering_kind, - ) - - def test_fabric_backend_workspace_plan_tracks_delay_and_tape_policy() -> None: runtime = build(_make_slstm_spec(max_delay=4)) @@ -384,8 +7082,8 @@ def test_fabric_supported_backend_surface_matrix_exposes_no_fallback_contract() assert any(surface.key == "slstm_recurrence" for surface in slstm_surfaces) assert any(surface.key == "axon_recurrence" for surface in axon_surfaces) - assert any("runtime/reference fallback" in surface.disallowed_fallbacks for surface in slstm_surfaces) - assert any("runtime/reference fallback" in surface.disallowed_fallbacks for surface in axon_surfaces) + assert any("runtime_reference_route" in surface.forbidden_routes for surface in slstm_surfaces) + assert any("runtime_reference_route" in surface.forbidden_routes for surface in axon_surfaces) planned = runtime.plan_backend_execution(batch_size=16, time_steps=8, inner_steps=2, training=True) assert any(plan.math_backend.value == "grouped_gemm" for plan in planned.bucket_plans) diff --git 
a/tests/test_fabric_benchmark_suite_common.py b/tests/test_fabric_benchmark_suite_common.py index 10a771b6..8c3590a9 100644 --- a/tests/test_fabric_benchmark_suite_common.py +++ b/tests/test_fabric_benchmark_suite_common.py @@ -10,9 +10,18 @@ if str(_CORTICAL_ROOT) not in sys.path: sys.path.insert(0, str(_CORTICAL_ROOT)) -from benchmarks.fabric_suite_common import ( # noqa: E402 +from benchmarks.fabric.suite_common import ( # noqa: E402 + _attach_compiler_memory_owner_ledger, + _count_params, + _measure_model, _run_rollout_forward, + build_mixed_fabric_backbone, + build_mixed_stack_backbone, find_param_matched_backbone, + find_param_matched_mixed_fabric_backbone, + find_param_matched_mixed_stack_backbone, + make_mixed_fabric_sequence_model, + make_mixed_stack_sequence_model, make_sequence_training_target, ) @@ -68,6 +77,91 @@ def test_rollout_forward_uses_sequence_path_when_available() -> None: assert model.step_calls == 0 +def test_training_measurement_clears_warmup_grads_before_cuda_peak_reset(monkeypatch) -> None: + model = nn.Linear(1, 1, bias=False) + runtime = type( + "_Runtime", + (), + { + "_last_flat_bucket_temporal_registered_backward_memory_stages": ("warmup_stage",), + "_last_flat_bucket_temporal_frontend_tensor_bytes": ("warmup_frontend_stage",), + }, + )() + model.runtime = runtime + x = torch.zeros(1, 1) + target = torch.zeros(1, 1) + parameter = next(model.parameters()) + events: list[str] = [] + reset_saw_grad: list[bool] = [] + + def iteration_fn(*, model, run_mode, optimizer, x, target) -> None: + del model, run_mode, optimizer, x, target + events.append("iteration") + parameter.grad = torch.ones_like(parameter) + + def synchronize() -> None: + events.append("synchronize") + + def reset_peak_memory_stats() -> None: + events.append("reset_peak") + reset_saw_grad.append(parameter.grad is not None) + + monkeypatch.setattr(torch.cuda, "synchronize", synchronize) + monkeypatch.setattr(torch.cuda, "reset_peak_memory_stats", reset_peak_memory_stats) + monkeypatch.setattr(torch.cuda, "memory_allocated", lambda device=None: 0) + monkeypatch.setattr(torch.cuda, "memory_reserved", lambda device=None: 0) + monkeypatch.setattr(torch.cuda, "max_memory_allocated", lambda device=None: 0) + + _measure_model( + model=model, + run_mode=lambda value: value, + x=x, + target=target, + warmup=1, + iterations=1, + device=torch.device("cuda"), + iteration_fn=iteration_fn, + ) + + assert reset_saw_grad == [False] + assert runtime._last_flat_bucket_temporal_registered_backward_memory_stages == () + assert runtime._last_flat_bucket_temporal_frontend_tensor_bytes == () + assert events[:3] == ["iteration", "synchronize", "reset_peak"] + + +def test_compiler_memory_ledger_records_first_peak_and_max_delta_stage() -> None: + memory_ledger = {"cuda_max_allocated_bytes": 1024, "model_parameter_bytes": 0, "model_parameter_grad_bytes": 0} + planner_signature = { + "workspace_aliases": ( + "flat_bucket_temporal_registered_backward_memory_stage:" + "stage=native_forward_entry_local0;allocated=100;reserved=200;max_allocated=128", + "flat_bucket_temporal_registered_backward_memory_stage:" + "stage=native_forward_message_after_projected_gemm_local0;allocated=180;reserved=260;max_allocated=512", + "flat_bucket_temporal_registered_backward_memory_stage:" + "stage=native_forward_return_local0;allocated=90;reserved=260;max_allocated=512", + ), + } + + _attach_compiler_memory_owner_ledger(memory_ledger, planner_signature) + + assert ( + memory_ledger[ + 
"fabric_registered_backward_stage_max_delta_bytes.native_forward_message_after_projected_gemm_local0" + ] + == 384 + ) + assert ( + memory_ledger[ + "fabric_registered_backward_peak_stage_by_max_delta.native_forward_message_after_projected_gemm_local0" + ] + == 1 + ) + assert ( + memory_ledger["fabric_registered_backward_first_peak_stage.native_forward_message_after_projected_gemm_local0"] + == 1 + ) + + def test_small_hidden_large_fabric_match_prioritizes_parameter_target() -> None: for family in ("slstm", "axoncell"): stack_match = find_param_matched_backbone(target_params=1_000_000_000, kind="stack", family=family) @@ -82,3 +176,67 @@ def test_small_hidden_large_fabric_match_prioritizes_parameter_target() -> None: assert fabric_match.fabric_hidden_size == 8 assert abs(fabric_match.actual_params - stack_match.actual_params) / stack_match.actual_params <= 0.01 + + +def test_mixed_fabric_backbone_uses_public_population_nodes() -> None: + backbone = build_mixed_fabric_backbone(d_hidden=8, width=4, height=4, hidden_size=8) + + population_node_indices = backbone.spec.config.populations.population_node_indices + assert population_node_indices is not None + assert set(population_node_indices) == {"slstm", "axoncell"} + assert set(population_node_indices["slstm"]).isdisjoint(population_node_indices["axoncell"]) + assert len(population_node_indices["slstm"]) + len(population_node_indices["axoncell"]) == 8 + + +def test_mixed_fabric_sequence_model_runs_high_level_forward() -> None: + match = find_param_matched_mixed_fabric_backbone( + target_params=200_000, + family="slstm", + forced_d_hidden=64, + fabric_hidden_grid=(8,), + ) + model = make_mixed_fabric_sequence_model(match, device=torch.device("cpu"), dtype=torch.float32) + x = torch.randn(1, 1, match.d_hidden) + + y, _ = model(x, None, materialize_final_state=False, output_boundary="terminal") + + assert y.shape[-1] == match.d_hidden + assert y.shape[0] == 1 + + +def test_mixed_fabric_match_uses_measured_sequence_model_params() -> None: + match = find_param_matched_mixed_fabric_backbone( + target_params=200_000, + family="slstm", + forced_d_hidden=64, + fabric_hidden_grid=(8,), + ) + model = make_mixed_fabric_sequence_model(match, device=torch.device("cpu"), dtype=torch.float32) + + assert _count_params(model) == match.actual_params + + +def test_mixed_stack_sequence_model_runs_high_level_forward() -> None: + backbone = build_mixed_stack_backbone(d_hidden=16, num_layers=1) + assert backbone is not None + match = find_param_matched_mixed_stack_backbone( + target_params=20_000, + forced_d_hidden=16, + ) + model = make_mixed_stack_sequence_model(match, device=torch.device("cpu"), dtype=torch.float32) + x = torch.randn(1, 1, match.d_hidden) + + y, _ = model(x, None) + + assert y.shape[-1] == match.d_hidden + assert y.shape[0] == 1 + + +def test_mixed_stack_match_uses_measured_sequence_model_params() -> None: + match = find_param_matched_mixed_stack_backbone( + target_params=20_000, + forced_d_hidden=16, + ) + model = make_mixed_stack_sequence_model(match, device=torch.device("cpu"), dtype=torch.float32) + + assert _count_params(model) == match.actual_params diff --git a/tests/test_fabric_execution_imports.py b/tests/test_fabric_execution_imports.py index 8e1f26da..406239c8 100644 --- a/tests/test_fabric_execution_imports.py +++ b/tests/test_fabric_execution_imports.py @@ -1,9 +1,7 @@ from __future__ import annotations -import os +import importlib.util import re -import subprocess -import sys from pathlib import Path from cortical.fabric.backend 
import ( @@ -26,11 +24,11 @@ def test_fabric_cuda_nn_backward_registry_covers_callable_surface() -> None: "message_bucket_degree_bucketed_sparse", "message_bucket_ragged_grouped_sparse", } - legacy_transition_ir_primitives = {"matmul", "diag_rtu", "gated_logspace_recurrence", "norm_or_identity"} + transition_ir_primitives = {"matmul", "diag_rtu", "gated_logspace_recurrence", "norm_or_identity"} registered_primitives = set(cuda_nn_callable_primitives()) assert callable_builder_primitives <= registered_primitives - allowed_primitives = callable_builder_primitives | message_bucket_declarations | legacy_transition_ir_primitives + allowed_primitives = callable_builder_primitives | message_bucket_declarations | transition_ir_primitives assert registered_primitives <= allowed_primitives behaviors = {behavior.primitive: behavior for behavior in cuda_nn_primitive_backward_behaviors()} @@ -70,292 +68,109 @@ def test_fabric_cuda_nn_backward_registry_covers_callable_surface() -> None: assert "autograd" not in behavior_contract -def test_fabric_execution_package_import_stays_on_lightweight_registration_path() -> None: +def test_old_cuda_execution_package_is_deleted() -> None: repo_root = Path(__file__).resolve().parents[1] - env = dict(os.environ) - pythonpath = env.get("PYTHONPATH") - extra_path = str(repo_root / "src") - env["PYTHONPATH"] = extra_path if not pythonpath else f"{extra_path}:{pythonpath}" - code = """ -import importlib.util -import sys -import cortical.fabric.backend.cuda.execution as execution - -def spec_or_none(name: str): - try: - return importlib.util.find_spec(name) - except ModuleNotFoundError: - return None - -registry = execution.run_registered_execution.__globals__["_EXECUTION_REGISTRY"] -assert len(registry) == 4, len(registry) -assert "cortical.fabric.backend.cuda.execution.generic_sequence_helpers" not in sys.modules -assert "cortical.fabric.backend.cuda.execution.receiver_owned_recurrence_cuda" not in sys.modules -assert "cortical.fabric.backend.cuda.execution.persistent_scan_recurrence_cuda" not in sys.modules -assert "cortical.fabric.backend.cuda.execution.edge_owned_accumulation_cuda" not in sys.modules -assert spec_or_none("cortical.fabric.backend.cuda.execution.receiver_owned_recurrence_cuda") is None -assert spec_or_none("cortical.fabric.backend.cuda.execution.persistent_scan_recurrence_cuda") is None -assert spec_or_none("cortical.fabric.backend.cuda.execution.edge_owned_accumulation_cuda") is None -assert spec_or_none("cortical.fabric.backend.cuda.execution.generic_sequence_helpers") is None -assert spec_or_none("cortical.fabric.backend.cuda.execution.generic_cell_step_cuda") is None -assert spec_or_none("cortical.fabric.backend.cuda.cell_core.slstm_cell_core_cuda") is None -assert spec_or_none("cortical.fabric.backend.cuda.cell_core.axon_cell_core_cuda") is None -""" - completed = subprocess.run( - [sys.executable, "-c", code], - cwd=repo_root, - env=env, - capture_output=True, - text=True, - check=False, + cuda_root = repo_root / "src" / "cortical" / "fabric" / "backend" / "cuda" + deleted_paths = ( + cuda_root / "execution", + cuda_root / "cells", + cuda_root / "recurrence_executor.py", + cuda_root / "registry" / "cell_dispatch_registry.cpp", + cuda_root / "registry" / "cell_dispatch_registry.cuh", + cuda_root / "registry" / "cell_registration_helpers.cuh", + cuda_root / "cells" / "slstm.cuh", + cuda_root / "cells" / "slstm_registration.cu", + cuda_root / "cells" / "axon.cuh", + cuda_root / "cells" / "axon_registration.cu", ) - assert completed.returncode == 0, 
completed.stderr or completed.stdout + for path in deleted_paths: + assert not path.exists(), str(path) + + assert importlib.util.find_spec("cortical.fabric.backend.cuda.execution") is None -def test_fabric_cuda_execution_sources_keep_scalable_backend_contract() -> None: +def test_sequence_surface_uses_registered_temporal_compiler_path() -> None: repo_root = Path(__file__).resolve().parents[1] cuda_root = repo_root / "src" / "cortical" / "fabric" / "backend" / "cuda" - hot_path_sources = ( - cuda_root / "execution" / "receiver_owned_stepwise.cuh", - cuda_root / "execution" / "edge_owned_accumulate_stepwise.cu", + registered_program_root = cuda_root / "sequence_surface" / "flat_bucket" + registered_program_sources = ( + "flat_bucket_registered_program_kernels.cu", + "registered_program/common.cuh", + "registered_program/forward_program.cuh", + "registered_program/backward_surface_steps.cuh", + "registered_program/backward_program.cuh", + ) + surface = (cuda_root / "sequence_surface" / "runtime" / "surface.py").read_text() + executor = (cuda_root / "sequence_surface" / "runtime" / "executor.py").read_text() + registered = (cuda_root / "sequence_surface" / "temporal" / "registered_executors.py").read_text() + runtime_dispatch = (repo_root / "src" / "cortical" / "fabric" / "backend" / "runtime_dispatch.py").read_text() + cuda_package_init = (cuda_root / "__init__.py").read_text() + projection_package_init = (cuda_root / "projection" / "__init__.py").read_text() + program_kernel = "\n".join( + (registered_program_root / source_name).read_text(encoding="utf-8") + for source_name in registered_program_sources ) - for source_path in hot_path_sources: - source = source_path.read_text() - assert "forward_core" not in source - assert re.search(r"if\s*\(\s*lane\s*==\s*0\s*\)", source) is None - assert re.search(r"if\s*\(\s*lane\s*!=\s*0\s*\)", source) is None - assert "kBatchTile" not in source - assert "kReceiverTile" not in source + assert "_build_backend_sequence_request" not in surface + assert "_run_backend_sequence_surface_once" not in surface + assert "_execute_compiler_temporal_sequence_surface" in surface + assert "execute_temporal_bucket_sequence" in surface + assert "run_temporal_bucket_sequence_physical_autograd" in executor + assert "run_shared_temporal_bucket_forward_scan" in executor + assert "RegisteredTemporalExecutorProgram" in registered + assert "run_registered_forward_message_carrier_handler" in program_kernel + assert "cortical.fabric.backend.cuda.runtime_ops" not in runtime_dispatch + assert "fabric_local_message_cuda" not in runtime_dispatch + assert "fabric_sparse_message_cuda" not in runtime_dispatch + assert "fabric_grouped_projection_cuda" not in runtime_dispatch + assert not (cuda_root / "registry").exists() + assert not (cuda_root / "reference").exists() + assert not (cuda_root / "message_passing").exists() + assert not (cuda_root / "runtime_ops.py").exists() + assert "fabric_local_message_cuda" not in cuda_package_init + assert "fabric_grouped_projection_cuda" not in cuda_package_init + assert "register_readout_backend" not in projection_package_init - contract = (cuda_root / "contracts" / "cell.cuh").read_text() - dispatcher_loader = (cuda_root / "execution" / "dispatcher_cuda.py").read_text() - dispatcher_cpp = (cuda_root / "execution" / "dispatcher.cpp").read_text() - readout_apply = (cuda_root / "execution" / "readout_apply.cu").read_text() - receiver_stepwise = (cuda_root / "execution" / "receiver_owned_stepwise.cuh").read_text() - cell_registration_helpers = 
(cuda_root / "registry" / "cell_registration_helpers.cuh").read_text() - cell_dispatch_registry = (cuda_root / "registry" / "cell_dispatch_registry.cuh").read_text() - dense_message_header = (cuda_root / "ops" / "dense_message.cuh").read_text() - dense_message_ops = (cuda_root / "ops" / "dense_message_kernels.cu").read_text() - nn_ir = (cuda_root / "nn" / "ir.cuh").read_text() - scaling_profile = (repo_root / "benchmarks" / "run_fabric_scaling_profile.py").read_text() - slstm_cell = (cuda_root / "cells" / "slstm.cuh").read_text() - axon_cell = (cuda_root / "cells" / "axon.cuh").read_text() - assert "forward_state_chunk" in contract - assert "emit_public_chunk" in contract - assert "cell_transition_ir_host" in contract - assert "CellStateAffineSpec" not in contract - assert "stage_state_static" in contract - assert "stage_emit_static" in contract - assert "kReductionStatsDim" in contract - assert "receiver_message_aggregate_kernel" in receiver_stepwise - assert "receiver_message_project_kernel" not in receiver_stepwise - assert "receiver_state_update_kernel" in receiver_stepwise - assert "receiver_emit_raw_public_kernel" in receiver_stepwise - assert "dense_state_affines" in dispatcher_cpp - assert "cell_transition_ir" in dispatcher_cpp - assert "dispatch_entry.state_affine_specs" not in dispatcher_cpp - assert "regular_local_projected_message_boundary" in dispatcher_cpp - assert "sparse_projected_message_boundary" in dispatcher_cpp - assert "dense_message_kernels.cu" in dispatcher_loader - assert "dense_regular_local_message_out_cuda" in dispatcher_cpp - assert "lower_regular_local_message_bucket" in dispatcher_cpp - assert "lower_sparse_message_bucket" in dispatcher_cpp - assert "lower_ragged_sparse_message_bucket" in dispatcher_cpp - assert "fabric::cuda::nn::make_message_op" in dispatcher_cpp - assert "fabric::cuda::nn::lower_single_message_op" in dispatcher_cpp - assert "dense_regular_local_message_out_cuda" in dense_message_header - assert "pack_regular_local_message_keys_kernel" in dense_message_ops - assert "dense regular-local message logits batched GEMM" in dense_message_ops - assert "dense regular-local message weighted-values batched GEMM" in dense_message_ops - assert "regular_local_message_softmax_inplace_kernel" in dense_message_ops - assert "message_gathered_keys=message_gathered_values" in dense_message_ops - assert "message_demotions" in dispatcher_cpp - assert "degree_bucketed_sparse_message_demoted" in dispatcher_cpp - assert "dense_receiver_owned_sparse_ragged_grouped_message_out_cuda" in dispatcher_cpp - assert "dense_edge_owned_sparse_ragged_grouped_message_out_cuda" in dispatcher_cpp - assert "set_launch_granularity_metadata" in dispatcher_cpp - assert "phase_launch_counts" in dispatcher_cpp - assert "small_cublas_launch_counts" in dispatcher_cpp - assert "copy_glue_launch_counts" in dispatcher_cpp - assert "launch_coalescing_modes" in dispatcher_cpp - assert "generic_glue_fusion_modes" in dispatcher_cpp - assert "state_epilogue_modes" in dispatcher_cpp - assert "state_epilogue_saved_launch_counts" in dispatcher_cpp - assert "launch_granularity_modes" in dispatcher_cpp - assert "physical_op_kinds" in dispatcher_cpp - assert "physical_layout_contracts" in dispatcher_cpp - assert "layout_mode" in dispatcher_cpp - assert "copy_elision_mode" in dispatcher_cpp - assert "bias_fusion_mode" in dispatcher_cpp - assert "physical_op_executors" in dispatcher_cpp - assert "physical_op_demotions" in dispatcher_cpp - assert "physical_boundary_contracts" in dispatcher_cpp - assert 
"physical_applicability_predicates" in dispatcher_cpp - assert "physical_workspace_aliases" in dispatcher_cpp - assert "physical_workspace_peak_bytes" in dispatcher_cpp - assert "physical_op_launch_counts" in dispatcher_cpp - assert "physical_op_saved_launch_counts" in dispatcher_cpp - assert "standalone_copy_kernel_count" in dispatcher_cpp - assert "standalone_bias_kernel_count" in dispatcher_cpp - assert "receiver_affine_superop_ineligible:unsupported_affine_count" in dispatcher_cpp - assert "receiver_affine_superop_ineligible:unsupported_source_family" in dispatcher_cpp - assert "receiver_affine_superop_ineligible:mixed_output_dim" in dispatcher_cpp - assert "receiver_affine_superop_ineligible:mixed_chunk_family" in dispatcher_cpp - assert "receiver_affine_superop_ineligible:unsupported_reset_scope" in dispatcher_cpp - assert "receiver_affine_superop_ineligible:non_receiver_major_layout" in dispatcher_cpp - assert "launch_receiver_affine_superop" in dispatcher_cpp - assert "receiver_affine2_superop_workspace_out_cuda" in dispatcher_cpp - assert "receiver_affine2_direct_persistent_out_cuda" in dispatcher_cpp - assert "receiver_affine_superop_physical_mode" in dispatcher_cpp - assert "direct_persistent" in dispatcher_cpp - assert "pack_cublas_transitional" in dispatcher_cpp - assert "fabric.physical.receiver_affine" in dispatcher_cpp - assert "fabric.physical.message" in dispatcher_cpp - assert "fabric.physical.state_epilogue" in dispatcher_cpp - assert "fabric.physical.diagonal_recurrence" in (cuda_root / "ops" / "diagonal_recurrence_kernels.cu").read_text() - assert "PHYSICAL_OP_PROFILE_PATTERNS" in scaling_profile - assert "select_diagonal_recurrence_superop_plan" in dispatcher_cpp - assert "diagonal_recurrence_superop_ineligible:no_declaration" in dispatcher_cpp - assert "diagonal_recurrence_superop_ineligible:reduction_boundary" in dispatcher_cpp - assert "diagonal_recurrence_superop_ineligible:unsupported_reset_scope" in dispatcher_cpp - assert "diagonal_recurrence_complex_exp_update_emit_window_out_cuda" in dispatcher_cpp - assert "diagonal_recurrence_kernels.cu" in dispatcher_loader - assert "set_physical_op_plan_metadata" in dispatcher_cpp - assert "make_runtime_physical_execution_plan" in dispatcher_cpp - assert "dense_affine_receiver_major_copy_or_pad_out_cuda" in dispatcher_cpp - assert "dense_affine_receiver_major_split_last_dim_out_cuda" in dispatcher_cpp - assert "receiver_state_update_emit" in dispatcher_cpp - assert "receiver_state_update_emit_kernel" in receiver_stepwise - assert "launch_state_update_emit_variant" in receiver_stepwise - assert "receiver_message_project_kernel" not in dispatcher_cpp - assert "public_project_from_raw_kernel" not in dispatcher_cpp - assert "readout_apply_kernel" not in dispatcher_cpp - assert "recurrent_k_out.copy_" not in dispatcher_cpp - assert "recurrent_v_out.copy_" not in dispatcher_cpp - assert "public_project_from_raw_kernel" not in receiver_stepwise - assert "receiver_from_accumulated_message" not in receiver_stepwise - assert "launch_receiver_owned_stepwise_typed" not in receiver_stepwise - assert "FABRIC_LAUNCH_STATE_TILE(2, 4, 32)" in receiver_stepwise - assert "FABRIC_LAUNCH_PUBLIC_TILE" not in receiver_stepwise - assert "state_static_bytes" in receiver_stepwise - assert "emit_static_bytes" in receiver_stepwise - assert "receiver_apply_from_message" not in dispatcher_loader - assert "receiver_owned_persistent_scan.cu" not in dispatcher_loader - assert "_state_static_bytes" not in dispatcher_loader - assert 
"dense_affine_out_cuda" in dispatcher_cpp - assert "dense_affine_pack_reset_source_rows_out_cuda" in dispatcher_cpp - assert "state_affine_contributions:combined" in dispatcher_cpp - assert "launch_dense_public_projection" in dispatcher_cpp - assert "launch_grouped_public_projection" in dispatcher_cpp - assert "grouped_projection_forward" in dispatcher_cpp - assert "cortical.fabric.backend.cuda.projection.grouped_projection_cuda" in dispatcher_cpp - assert "receiver_emit_raw_public->dense_public_projection" in dispatcher_cpp - assert "dense_public_projection" in dispatcher_cpp - assert "dense_readout_projection" in dispatcher_cpp - assert "readout_message_aggregate" in dispatcher_cpp - assert "readout_message_aggregate->dense_readout_projection" in dispatcher_cpp - assert "launch_dense_readout" in dispatcher_cpp - assert "public_projection_backends" not in dispatcher_cpp - assert "publish_from_raw_public" not in dispatcher_cpp - assert "publish_from_raw_public" not in receiver_stepwise - assert "readout_apply_kernel" not in readout_apply - assert "launch_readout_apply_cuda" not in readout_apply - assert "readout_message_kernel" in readout_apply - assert '"none"' in dispatcher_cpp - assert "fabric::cuda::nn::ResetPolicy::ZeroSourceRows" in slstm_cell - assert "state_affine_reset_mode" in dispatcher_cpp - assert "state_affine_workspace_mode" in dispatcher_cpp - assert "reset_packed_source_workspaces" in dispatcher_cpp - assert "state_affine_workspace_buffers" in dispatcher_cpp - assert "workspace_peak_bytes" in dispatcher_cpp - slstm_forward_state_chunk = slstm_cell.split( - "__device__ static void forward_state_chunk", - maxsplit=1, - )[1].split( - "__device__ static void forward_state_lane_value", - maxsplit=1, - )[0] - assert "raw_public_out" not in slstm_forward_state_chunk - assert ( - "for (int in_h" - not in slstm_cell.split("forward_state_chunk", maxsplit=1)[1].split( - "emit_public_chunk", - maxsplit=1, - )[0] - ) - assert ( - "raw_public_out" - not in axon_cell.split("forward_state_chunk", maxsplit=1)[1].split( - "emit_public_chunk", - maxsplit=1, - )[0] + forbidden_legacy_symbols = ( + "FabricExecutionRequest", + "run_registered_execution", + "ExecutionVariantSpec", + "pack_tensor_tree", + "cell_dispatch_registry", + "receiver_owned_stepwise", + "temporal_superop.cuh", ) - assert "fabric::cuda::nn::Builder" in slstm_cell - assert "builder.state_affine" in slstm_cell - assert "fabric::cuda::nn::Builder" in axon_cell - assert "builder.diagonal_recurrence" in axon_cell - assert "DiagonalRecurrenceKind::ComplexExponential2D" in axon_cell - assert "MathBackend" not in nn_ir - assert "CellStateAffineSpec" not in slstm_cell - assert "CellStateAffineSpec" not in axon_cell - for physical_operator_name in ( - "PhysicalOpPlan", - "ReceiverAffineSuperOp", - "TinyMessageSuperOp", - "SparseMessageSuperOp", - "DiagonalRecurrenceSuperOp", - ): - assert physical_operator_name not in slstm_cell - assert physical_operator_name not in axon_cell - assert physical_operator_name not in cell_registration_helpers - assert physical_operator_name not in cell_dispatch_registry - slstm_state_chunk = slstm_cell.split("forward_state_chunk", maxsplit=1)[1].split( - "emit_public_chunk", - maxsplit=1, - )[0] - assert "use_dense_state_affines" in slstm_state_chunk - assert "gate_affine" in slstm_state_chunk - for forbidden_affine_loop_fragment in ( - "input_gate_weight", - "recurrent_gate_weight", - "for (int in_h", - "projected_message[", - ): - assert forbidden_affine_loop_fragment not in slstm_state_chunk + 
combined = "\n".join((surface, executor, registered, program_kernel)) + for symbol in forbidden_legacy_symbols: + assert symbol not in combined def test_fabric_cuda_nn_stays_generic_and_cuda_native() -> None: repo_root = Path(__file__).resolve().parents[1] cuda_root = repo_root / "src" / "cortical" / "fabric" / "backend" / "cuda" nn_ir = (cuda_root / "nn" / "ir.cuh").read_text() + message_rule_catalog = (cuda_root / "nn" / "message_rule_lowering_catalog.cuh").read_text() cell_backend_py = (repo_root / "src" / "cortical" / "fabric" / "backend" / "cell_backend.py").read_text() cell_specs_py = (repo_root / "src" / "cortical" / "fabric" / "backend" / "cell_specs.py").read_text() - dispatcher_cpp = (cuda_root / "execution" / "dispatcher.cpp").read_text() sequence_surface_root = cuda_root / "sequence_surface" sequence_surface = "\n".join( - (sequence_surface_root / file_name).read_text() - for file_name in ("__init__.py", "surface.py", "backward.py", "replay.py", "support.py") + path.read_text() for path in sorted(sequence_surface_root.rglob("*.py")) if "__pycache__" not in path.parts + ) + cuda_transition_execution = "\n".join( + path.read_text() + for path in sorted((cuda_root / "transition_execution").glob("*.py")) + if path.name != "__init__.py" ) - local_message_cuda = (cuda_root / "message_passing" / "local_message_cuda.py").read_text() - sparse_message_cuda = (cuda_root / "message_passing" / "sparse_message_cuda.py").read_text() - grouped_projection_cuda = (cuda_root / "projection" / "grouped_projection_cuda.py").read_text() - assert not (cuda_root / "population_execution.py").exists() - cuda_transition_execution = (cuda_root / "transition_execution.py").read_text() runtime_dispatch = (repo_root / "src" / "cortical" / "fabric" / "backend" / "runtime_dispatch.py").read_text() dense_affine_header = (cuda_root / "ops" / "dense_affine.cuh").read_text() dense_affine_ops = (cuda_root / "ops" / "dense_affine_kernels.cu").read_text() dense_message_header = (cuda_root / "ops" / "dense_message.cuh").read_text() dense_message_ops = (cuda_root / "ops" / "dense_message_kernels.cu").read_text() - dot_product_message_rule = (cuda_root / "message_rules" / "dot_product.cuh").read_text() python_message_rules = (repo_root / "src" / "cortical" / "fabric" / "backend" / "message_rules.py").read_text() - diagonal_recurrence_triton = (cuda_root / "ops" / "diagonal_recurrence_triton.py").read_text() - scaling_profile = (repo_root / "benchmarks" / "run_fabric_scaling_profile.py").read_text() - rtu_cuda_diag = (repo_root / "src" / "cortical" / "ops" / "rtu" / "cuda" / "rtu_stream_diag_cuda.py").read_text() - rtu_pytorch_diag = (repo_root / "src" / "cortical" / "ops" / "rtu" / "pytorch" / "rtu_stream_diag.py").read_text() - pytorch_axon_cell = ( - repo_root / "src" / "cortical" / "fabric" / "backend" / "pytorch" / "cells" / "axon.py" + public_message_declarations = ( + repo_root / "src" / "cortical" / "fabric" / "message_rules" / "declarations.py" ).read_text() + cell_specs = (repo_root / "src" / "cortical" / "fabric" / "backend" / "cell_specs.py").read_text() for required in ( "struct CellTransitionIR", @@ -367,89 +182,27 @@ def test_fabric_cuda_nn_stays_generic_and_cuda_native() -> None: "enum class PhysicalLayoutMode", "enum class CopyElisionMode", "enum class BiasFusionMode", - "enum class PhysicalExecutorKind", - "physical_plan", - "make_physical_op_plan", - "make_physical_execution_plan", - "boundary_contract", - "applicability_predicate", - "struct WorkspaceLifetime", - "struct WorkspaceAlias", - 
"workspace_lifetimes", - "workspace_peak_bytes", "class Builder", - "struct AffineSignature", - "enum class AffineBackwardBackend", - "struct BackwardAffineBucket", - "backward_affine_buckets", - "WeightScope weight_scope", - "ShardLayout shard_layout", - "SourceKind source_kind", - "ResetPolicy reset_policy", - "ResetScope reset_scope", - "TransposeMode transpose_a", - "TransposeMode transpose_b", - "int64_t M", - "int64_t K", - "int64_t N", - "EpilogueKind epilogue_kind", - "struct MessageBucketSignature", "struct MessageRuleIR", "class MessageRuleBuilder", - "enum class MessageSourceKind", - "enum class MessageParameterRole", - "enum class MessageSharingScope", - "enum class MessageIndexMapKind", "enum class MessageOpKind", - "enum class MessageBackwardPolicy", - "enum class MessageTapeKind", - "enum class MessageRuleLoweringKind", - "struct MessageRuleSource", - "struct MessageRuleParameter", - "struct MessageRuleNode", - "struct LoweredMessageRule", + "struct MessageRuleLoweringPattern", + "kUnsupportedMessageRuleLowering", "validate_message_rule_ir", - "classify_message_rule_lowering", "lower_message_rule_to_bucket", "emit_projected_message", - "struct StateAffineDeclaration", - "struct DiagonalRecurrenceDeclaration", - "enum class DiagonalRecurrenceKind", - "state_affines", - "diagonal_recurrences", + "parameter_value(", "state_affine(", "diagonal_recurrence(", - "MessageBucketKind bucket_kind", - "MessageTopologyKind topology_kind", - "MessageLayout q_layout", - "MessageLayout kv_layout", - "int64_t degree_or_block", - "DistancePenaltyKind distance_penalty_kind", - "MessageEpilogueKind epilogue_kind", - "enum class StateEpiloguePolicy", - "state_epilogue_policy", - "struct LoweredMessageBucket", "message_buckets", - "make_message_op", - "regular_local_receiver_owned_message_signature", - "degree_bucketed_sparse_message_signature", - "ragged_grouped_sparse_message_signature", - "lower_single_message_op", - "coalesce_message_ops", "select_message_backend", - "coalesce_affine_ops", - "same_stride_regular_subbatch", "select_affine_backend", - "make_backward_affine_bucket", - "make_backward_affine_buckets", - "ReceiverAffineSuperOpBackward", - "TinyMessageSuperOpBackward", - "SparseMessageSuperOpBackward", - "DiagonalRecurrenceSuperOpBackward", - "BackwardGlue", "throw std::invalid_argument", ): assert required in nn_ir + assert "enum class MessageRuleLoweringKind" not in nn_ir + assert "MessageRuleLoweringKind::" not in message_rule_catalog + assert "kDotProductFixedSlotContextGateLoweringId" in message_rule_catalog for builder_entrypoint in ( "linear(", @@ -482,176 +235,45 @@ def test_fabric_cuda_nn_stays_generic_and_cuda_native() -> None: "PhysicalStepPlan", ): assert forbidden not in nn_ir - assert "struct DotProduct" in dot_product_message_rule - assert "message_rule_ir_host" in dot_product_message_rule - assert "fabric::cuda::nn::MessageRuleBuilder" in dot_product_message_rule - assert "emit_projected_message" in dot_product_message_rule + + assert not (cuda_root / "message_rules" / "dot_product.cuh").exists() + assert "class DotProduct" in public_message_declarations + assert "def to_ir(" in public_message_declarations + assert "MessageRuleIR" in public_message_declarations + assert "CellTransitionIR" in cell_specs + assert "TransitionOp(" in cell_specs + assert "dot_product.cuh" not in sequence_surface assert "class MessageRuleIR" in python_message_rules + assert "class MessageRuleBackendSpec" in python_message_rules + assert "register_message_rule_backend_spec_builder" in 
python_message_rules + assert "build_message_rule_backend_spec" in python_message_rules assert "def default_dot_product_message_rule_ir" in python_message_rules assert "def classify_message_rule" in python_message_rules assert "dot_product_segment_softmax_weighted_sum" in python_message_rules - for forbidden_message_rule_fragment in ( - "PhysicalOpKind", - "PhysicalExecutorKind", - "ExecutionFamily", - "receiver_tile", - "blockDim", - "threadIdx", - "cell_kind", - ): - assert forbidden_message_rule_fragment not in dot_product_message_rule + assert "dot_product_fixed_slot_context_gate" in python_message_rules - cell_ir_block = nn_ir.split("struct CellTransitionIR", maxsplit=1)[1].split( - "struct FabricStepIR", - maxsplit=1, - )[0] - builder_block = nn_ir.split("class Builder", maxsplit=1)[1].split("private:", maxsplit=1)[0] - for physical_executor_fragment in ("PhysicalOpKind", "PhysicalExecutorKind", "PhysicalOpPlan"): - assert physical_executor_fragment not in cell_ir_block - assert physical_executor_fragment not in builder_block - for backward_physical_fragment in ("PhysicalBackwardPlan", "PhysicalBackwardOpPlan", "BackwardFamilyBehavior"): - assert backward_physical_fragment not in cell_backend_py - assert backward_physical_fragment not in cell_specs_py - for backward_range in ( - "fabric.backward.total", - "fabric.backward.receiver_affine", - "fabric.backward.message.receiver", - "fabric.backward.message.sender", - "fabric.backward.message.query_param", - "fabric.backward.grouped_projection", - "fabric.backward.diagonal_recurrence", - "fabric.backward.state_epilogue", - "fabric.backward.public_projection", - "fabric.backward.readout", - "fabric.backward.full_replay_autograd", - ): - backward_range_sources = ( - sequence_surface - + local_message_cuda - + sparse_message_cuda - + grouped_projection_cuda - + cuda_transition_execution - ) - assert backward_range in backward_range_sources - for derived_backward_owner in ( - "fabric.backward.derived.diagonal_recurrence", - "fabric.backward.derived.state_public_epilogue", - "fabric.backward.derived.lowered_projection", - "fabric.backward.derived.boundary_glue", - ): - assert derived_backward_owner in scaling_profile - for attribution_field in ( - "backward_explicit_attribution_coverage", - "backward_derived_attribution_coverage", - "backward_derived_owner_source_events", - "backward_explicit_owner_attributed_cuda_total_us", - "backward_derived_owner_attributed_cuda_total_us", - "Physical ownership gate", - "Backward Owner Priority", - "public/readout projection backward", - "state_public_output_probe", - "state_public_state_probe", + for wrong_layer in ( + cell_backend_py, + cell_specs_py, + runtime_dispatch, ): - assert attribution_field in scaling_profile - for forbidden_broad_backward_owner in ( - "fabric.backward.population_receiver", - "population_receiver_output_probe", - "population_receiver_state_probe", - ): - assert forbidden_broad_backward_owner not in sequence_surface - assert forbidden_broad_backward_owner not in scaling_profile - for wrong_layer_range in ("fabric.backward.", "fabric.physical."): - assert wrong_layer_range not in rtu_cuda_diag - assert wrong_layer_range not in rtu_pytorch_diag - assert wrong_layer_range not in pytorch_axon_cell - for forbidden_bridge in ( - "backend.pytorch.cells", - "rtu_stream_diag", - "get_cell_backend_implementation", - "native_cell_kind", - "try_lower_backend_population_transition_shared", - ): - assert forbidden_bridge not in cuda_transition_execution - assert forbidden_bridge not in 
runtime_dispatch - assert "fabric.backward.diagonal_recurrence" in cuda_transition_execution - for forbidden_cell_identity_fragment in ("axoncell", "Axon", "H=32", "H == 32", "T=1", "100m", "500m"): - assert forbidden_cell_identity_fragment not in diagonal_recurrence_triton + assert "PhysicalOpKind" not in wrong_layer + assert "PhysicalExecutorKind" not in wrong_layer + assert "blockDim" not in wrong_layer + assert "threadIdx" not in wrong_layer - physical_plan_block = nn_ir.split("struct PhysicalOpPlan", maxsplit=1)[1].split( - "struct PhysicalExecutionPlan", - maxsplit=1, - )[0] - runtime_physical_plan_block = dispatcher_cpp.split( - "fabric::cuda::nn::PhysicalExecutionPlan make_runtime_physical_execution_plan", - maxsplit=1, - )[1].split("void set_physical_op_plan_metadata", maxsplit=1)[0] - receiver_affine_selector_block = dispatcher_cpp.split( - "ReceiverAffineSuperOpPlan select_receiver_affine_superop_plan", - maxsplit=1, - )[1].split("at::Tensor reset_aware_state_affine_input_tensor", maxsplit=1)[0] - diagonal_recurrence_selector_block = dispatcher_cpp.split( - "DiagonalRecurrenceSuperOpPlan select_diagonal_recurrence_superop_plan", - maxsplit=1, - )[1].split("void allocate_receiver_affine_superop_workspace", maxsplit=1)[0] - for whole_surface_backend_fragment in ("MathBackend", "math_backend"): - assert whole_surface_backend_fragment not in physical_plan_block - assert whole_surface_backend_fragment not in runtime_physical_plan_block - assert whole_surface_backend_fragment not in receiver_affine_selector_block - assert whole_surface_backend_fragment not in diagonal_recurrence_selector_block - for forbidden_cell_identity_fragment in ("cell_core_id", "native_cell_kind", "SLSTM", "Axon"): - assert forbidden_cell_identity_fragment not in receiver_affine_selector_block - assert forbidden_cell_identity_fragment not in diagonal_recurrence_selector_block - for explicit_false_positive_gate in ( - "specs.size() != 2", - "unsupported_source_family", - "unsupported_reset_scope", - "first.op.signature.N != second.op.signature.N", - "mixed_chunk_family", - "non_receiver_major_layout", - ): - assert explicit_false_positive_gate in receiver_affine_selector_block - for explicit_false_positive_gate in ( - "unsupported_declaration_count", - "state_affine_surface", - "reduction_boundary", - "unsupported_reset_scope", - "unsupported_shape", - "unsupported_parameter_layout", - ): - assert explicit_false_positive_gate in diagonal_recurrence_selector_block - - assert "DenseAffineLayout" in dense_affine_header - assert "dense_affine_out_cuda" in dense_affine_header - assert "receiver_affine2_superop_out_cuda" in dense_affine_header - assert "receiver_affine2_superop_workspace_out_cuda" in dense_affine_header - receiver_affine_triton = (repo_root / "src/cortical/fabric/backend/cuda/ops/receiver_affine_triton.py").read_text() - assert "receiver_affine2_direct_persistent_out_cuda" in receiver_affine_triton - assert "receiver_affine2_superop_workspace_out_cuda" in dispatcher_cpp - assert "dense_receiver_linear_out_cuda" not in dense_affine_header - assert '#include "cortical/fabric/backend/cuda/nn/ir.cuh"' in dense_affine_header - assert "fabric::cuda::nn::AffineSignature" in dense_affine_ops - assert "fabric::cuda::nn::select_affine_backend" in dense_affine_ops - assert "DenseMessageExecution" in dense_message_header - assert "const fabric::cuda::nn::LoweredMessageBucket& bucket" in dense_message_header - assert "small_cublas_launch_count" in dense_message_header - assert "receiver_chunk_count" in 
dense_message_header - assert "dense_affine_receiver_major_copy_or_pad_out_cuda" in dense_affine_header - assert "dense_affine_receiver_major_split_last_dim_out_cuda" in dense_affine_header - assert "receiver_major_copy_or_pad_kernel" in dense_affine_ops - assert "receiver_major_split_last_dim_kernel" in dense_affine_ops + assert "dense_affine_out_cuda" in dense_affine_ops assert "dense_regular_local_message_out_cuda" in dense_message_header - assert "dense_receiver_owned_sparse_degree_bucketed_message_out_cuda" in dense_message_header - assert "dense_edge_owned_sparse_degree_bucketed_message_out_cuda" in dense_message_header - assert "dense_receiver_owned_sparse_ragged_grouped_message_out_cuda" in dense_message_header - assert "dense_edge_owned_sparse_ragged_grouped_message_out_cuda" in dense_message_header - assert "spatial_ownership_label" not in dense_message_header - assert "cublasSgemmStridedBatched" in dense_message_ops - assert "pack_sparse_message_keys_kernel" in dense_message_ops - assert "pack_ragged_sparse_message_keys_kernel" in dense_message_ops - assert "ragged_sparse_message_softmax_inplace_kernel" in dense_message_ops - assert "scatter_ordered_receiver_major_message_kernel" in dense_message_ops - assert "sparse_message_softmax_inplace_kernel" in dense_message_ops - assert "pack_sparse_message_values_kernel" in dense_message_ops - assert "receiver_message_aggregate_kernel" not in dense_message_ops - for forbidden in ("SLSTM", "Axon", "cell_core_id"): - assert forbidden not in dense_message_ops + assert "dense regular-local message logits batched GEMM" in dense_message_ops + assert "forward_strategy_id" in cuda_transition_execution + assert "backward_strategy_id" in cuda_transition_execution + assert "reverse.transition.diag_rtu.v1" in cuda_transition_execution + assert "run_registered_temporal_forward_executor_scan" in sequence_surface + assert "RegisteredTemporalExecutorProgram" in sequence_surface + assert "try_flat_bucket_temporal_scan_cuda" not in runtime_dispatch + assert not (cuda_root / "message_passing").exists() + assert not (cuda_root / "projections.py").exists() + assert not (cuda_root / "projection" / "grouped_projection_cuda.py").exists() + assert not (cuda_root / "projection" / "grouped_projection_binding.cpp").exists() + assert not (cuda_root / "projection" / "grouped_projection_kernels.cu").exists() + assert "CellStateAffineSpec" not in dense_affine_header diff --git a/tests/test_fabric_public_api.py b/tests/test_fabric_public_api.py index 42684912..bdc1e364 100644 --- a/tests/test_fabric_public_api.py +++ b/tests/test_fabric_public_api.py @@ -15,7 +15,14 @@ ) from cortical.fabric.anatomy import init from cortical.fabric.blueprint import normalize -from cortical.fabric.config import CellPopulationConfig, Config +from cortical.fabric.config import ( + CellPopulationConfig, + Config, + FabricInterfaceConfig, + MessageConfig, + PopulationLayoutConfig, + RuntimeExecutionConfig, +) from cortical.fabric.graphs import lattice2d from cortical.fabric.runtime import Model, Runtime, build @@ -44,11 +51,12 @@ def test_fabric_public_api_exposes_current_names_only() -> None: def test_internal_fabric_runtime_contract_still_builds_runtime() -> None: spec = init( Config( - width=4, - height=4, - hidden_size=8, - cell_populations={"slstm": CellPopulationConfig(cell_type="slstm")}, - population_mix={"slstm": 1.0}, + graph=lattice2d.Graph(width=4, height=4), + interface=FabricInterfaceConfig(hidden_size=8), + populations=PopulationLayoutConfig( + cell_populations={"slstm": 
CellPopulationConfig(cell_type="slstm")}, + population_mix={"slstm": 1.0}, + ), ) ) @@ -107,20 +115,14 @@ def test_blueprint_normalizes_to_current_lattice_spec() -> None: blueprint_spec = normalize(blueprint) config_spec = init( Config( - width=6, - height=4, - hidden_size=8, - d_public=8, - d_msg=8, - d_slot=16, - head_dim=8, - wrap=False, - cell_populations={"core": CellPopulationConfig(cell_type="slstm")}, - population_mix={"core": 1.0}, - input_cell_indices=graph.input_nodes(), - output_cell_indices=graph.output_nodes(), - default_k=1, - k_max=1, + graph=graph, + interface=FabricInterfaceConfig(hidden_size=8, public_dim=8, message_dim=8, slot_dim=16), + message=MessageConfig(num_heads=1, head_dim=8), + populations=PopulationLayoutConfig( + cell_populations={"core": CellPopulationConfig(cell_type="slstm")}, + population_mix={"core": 1.0}, + ), + execution=RuntimeExecutionConfig(default_k=1, k_max=1), ) ) @@ -167,11 +169,48 @@ def test_message_rule_declaration_lowers_to_projected_message_boundary() -> None ir = rule.to_ir(kv_group_count=4, cell_count=16) + assert ir.name == "dot_product" + assert ir.lowering_kind == "dot_product_fixed_slot_context_nudge" + assert ir.output_boundary == "projected_message" + assert "receiver_public_prev:reset=zero_source_rows:scope=batch_row" in ir.source_signature + assert "sender_slot" in ir.source_signature + assert "normalize->projected_message" in ir.expression_signature + + +def test_dot_product_dynamic_key_value_math_lowers_to_legacy_ir() -> None: + rule = message_rules.DotProduct(head_dim=8, math="dynamic_key_value") + + ir = rule.to_ir(kv_group_count=4, cell_count=16) + assert ir.name == "dot_product" assert ir.lowering_kind == "dot_product_segment_softmax_weighted_sum" assert ir.output_boundary == "projected_message" +def test_dot_product_context_nudge_math_lowers_to_distinct_ir() -> None: + rule = message_rules.DotProduct(head_dim=8, math="fixed_slot_context_nudge") + + ir = rule.to_ir(kv_group_count=4, cell_count=16) + + assert ir.name == "dot_product" + assert ir.lowering_kind == "dot_product_fixed_slot_context_nudge" + assert "receiver_public_prev:reset=zero_source_rows:scope=batch_row" in ir.source_signature + assert "sender_slot" in ir.source_signature + assert "normalize->projected_message" in ir.expression_signature + + +def test_dot_product_context_gate_math_lowers_to_distinct_ir() -> None: + rule = message_rules.DotProduct(head_dim=8, math="fixed_slot_context_gate") + + ir = rule.to_ir(kv_group_count=4, cell_count=16) + + assert ir.name == "dot_product" + assert ir.lowering_kind == "dot_product_fixed_slot_context_gate" + assert "receiver_public_prev:reset=zero_source_rows:scope=batch_row" in ir.source_signature + assert "sender_slot" in ir.source_signature + assert "context_gate(receiver_public_prev)" in ir.expression_signature + + def test_blueprint_message_sharing_lowers_to_kv_region_shape() -> None: blueprint = Blueprint( interface=Interface(public_dim=8, message_dim=8), @@ -190,8 +229,17 @@ def test_blueprint_message_sharing_lowers_to_kv_region_shape() -> None: spec = normalize(blueprint) - assert spec.config.projection_region_shape == (2, 4) + assert spec.config.message.projection_region_shape == (2, 4) assert spec.num_kv_groups == 8 + assert spec.message_rule is not None + assert spec.message_rule.name == "dot_product" + assert spec.message_rule.lowering_kind == "dot_product_fixed_slot_context_nudge" + assert "input_sender_value_weight:projection:sender_group_shared:groups=8" in ( + spec.message_rule.parameter_sharing_signature + 
) + assert "recurrent_sender_value_weight:projection:sender_group_shared:groups=8" in ( + spec.message_rule.parameter_sharing_signature + ) def test_message_rule_rejects_unsupported_public_contracts() -> None: diff --git a/tests/test_fabric_runtime.py b/tests/test_fabric_runtime.py index e595ff41..a58ff627 100644 --- a/tests/test_fabric_runtime.py +++ b/tests/test_fabric_runtime.py @@ -5,18 +5,13 @@ from dataclasses import replace from typing import cast -import cortical.fabric.backend.cuda.message_passing.local_message_cuda as local_message_mod -import cortical.fabric.backend.cuda.message_passing.sparse_message_cuda as sparse_message_mod -import cortical.fabric.backend.cuda.runtime_ops as cuda_runtime_ops -import cortical.fabric.backend.cuda.transition_execution as cuda_transition_execution +import cortical.fabric.backend.cuda.transition_execution.projection as cuda_transition_projection import cortical.fabric.backend.pytorch.population_execution as pytorch_population_execution import pytest import torch from cortical.fabric.anatomy import init from cortical.fabric.backend import ExecutionFamily, TapeMode, TapePolicy from cortical.fabric.backend.cell_backend import TransitionOp -from cortical.fabric.backend.cuda import fabric_grouped_projection_cuda -from cortical.fabric.backend.cuda.message_passing.local_message_cuda import fabric_local_message_partitioned_cuda from cortical.fabric.backend.cuda.ops import ( dense_affine_cuda, dense_affine_out_cuda, @@ -43,21 +38,23 @@ factorized_recurrent_input_projection_grads_cuda, ) from cortical.fabric.backend.cuda.projection.receiver_major_gates import receiver_major_projection_backward_gate -from cortical.fabric.backend.cuda.reference.slstm_parity_reference import ( - _local_message_partitioned_step_backward_manual, -) -from cortical.fabric.backend.cuda.sequence_surface.policy import CudaMemoryBudget -from cortical.fabric.backend.cuda.sequence_surface.temporal_backward import ( - compute_temporal_bucket_step_artifacts, - run_temporal_bucket_step_backward, -) +from cortical.fabric.backend.cuda.sequence_surface.runtime.policy import CudaMemoryBudget from cortical.fabric.backend.planner import SequenceSurfaceRoute from cortical.fabric.backend.pytorch.message_passing import ( compute_messages_sequence_subset_partitioned_raw as pytorch_compute_messages_sequence_subset_partitioned_raw, ) from cortical.fabric.backend.pytorch.cells.slstm import lower_slstm_transition_op from cortical.fabric.cells import build_cell_population_module -from cortical.fabric.config import CellPopulationConfig, Config +from cortical.fabric.config import ( + CellPopulationConfig, + Config, + FabricInterfaceConfig, + InitializationConfig, + MessageConfig, + PopulationLayoutConfig, + RuntimeExecutionConfig, +) +from cortical.fabric.graphs import lattice2d from cortical.fabric.contracts.cells import reset_backend_tensor_rows from cortical.fabric.runtime import Runtime, build from tensordict import TensorDict, TensorDictBase @@ -65,6 +62,94 @@ DENSE_AFFINE_BACKENDS = {"large_gemm", "batched_gemm", "grouped_gemm"} +def _lattice_test_config( + *, + width: int = 4, + height: int = 4, + hidden_size: int = 8, + head_dim: int = 4, + cell_populations: dict[str, CellPopulationConfig] | None = None, + population_mix: dict[str, float] | None = None, + cell_arrangement: str = "random", + local_radius: float = 1.5, + patch_edges_per_cell: int = 0, + patch_min_dist: float = 3.0, + patch_max_dist: float = 4.0, + projection_region_shape: tuple[int, ...] 
| None = (2, 2), + input_cell_indices: tuple[int, ...] | None = None, + output_cell_indices: tuple[int, ...] | None = None, + input_band_width: int = 1, + output_band_width: int = 1, + graph_edges: tuple[tuple[int, int], ...] | None = None, + kv_group_ids: tuple[int, ...] | None = None, + wrap: bool = True, + conduction_speed: float | None = None, + max_delay: int | None = None, + gradient_horizon_steps: int | None = None, + checkpoint_steps: int | None = None, + k_max: int = 4, + default_k: int = 2, + backend: str = "auto", + seed: int = 0, +) -> Config: + cell_populations = cell_populations or {"slstm": CellPopulationConfig(cell_type="slstm")} + population_mix = population_mix or {next(iter(cell_populations)): 1.0} + if graph_edges is None: + connectivity: tuple[object, ...] = (lattice2d.LocalRadius(local_radius),) + if patch_edges_per_cell > 0: + connectivity = ( + *connectivity, + lattice2d.PatchEdges( + per_cell=patch_edges_per_cell, + min_distance=patch_min_dist, + max_distance=patch_max_dist, + ), + ) + else: + connectivity = (lattice2d.ExplicitEdges(edges=graph_edges, kv_group_ids=kv_group_ids),) + graph = lattice2d.Graph( + width=width, + height=height, + wrap=wrap, + conduction_speed=conduction_speed, + max_delay=max_delay, + inputs=input_cell_indices if input_cell_indices is not None else lattice2d.XBand("low", input_band_width), + outputs=( + output_cell_indices + if output_cell_indices is not None + else lattice2d.Output(lattice2d.XBand("high", output_band_width)) + ), + connectivity=connectivity, + ) + return Config( + graph=graph, + interface=FabricInterfaceConfig(hidden_size=hidden_size), + message=MessageConfig(head_dim=head_dim, projection_region_shape=projection_region_shape), + populations=PopulationLayoutConfig( + cell_populations=cell_populations, + population_mix=population_mix, + cell_arrangement=cell_arrangement, # type: ignore[arg-type] + ), + execution=RuntimeExecutionConfig( + backend=backend, # type: ignore[arg-type] + gradient_horizon_steps=gradient_horizon_steps, + checkpoint_steps=checkpoint_steps, + k_max=k_max, + default_k=default_k, + ), + initialization=InitializationConfig(seed=seed), + ) + + +def _spec_with_backend(spec, backend_name: str): + return replace( + spec, + config=spec.config.model_copy( + update={"execution": spec.config.execution.model_copy(update={"backend": backend_name})} + ), + ) + + def _diagonal_recurrence_reference( cell_input: torch.Tensor, hc1: torch.Tensor, @@ -220,17 +305,24 @@ def test_reset_backend_tensors_rows_cuda_accepts_strided_time_step_masks(): @pytest.mark.parametrize( - ("batch_size", "receivers", "input_dim", "output_dim", "biased", "expected_enabled", "expected_mode", "block_b"), - [ - (1, 32256, 32, 32, True, True, "receiver_major_projection_small_batch_cuda", 16), - (64, 31808, 32, 16, True, True, "receiver_major_projection_cuda", 64), - (8, 32256, 32, 32, True, False, "receiver_major_projection_demoted", 0), - (64, 1024, 64, 32, True, True, "receiver_major_projection_cuda", 64), - (16, 2048, 160, 144, True, True, "receiver_major_projection_cuda", 64), - (64, 128, 64, 32, True, False, "receiver_major_projection_demoted", 0), - (4, 65536, 32, 32, False, True, "receiver_major_projection_small_batch_cuda", 16), - (32, 65536, 32, 64, False, True, "receiver_major_projection_cuda", 32), - ], + ( + "batch_size", + "receivers", + "input_dim", + "output_dim", + "biased", + "expected_enabled", + "expected_mode", + "block_b", + ), + ( + (0, 8, 32, 32, True, False, "receiver_major_projection_demoted", 0), + (2, 100, 512, 512, 
True, True, "receiver_major_projection_small_batch_cuda", 16), + (4, 512, 128, 128, False, True, "receiver_major_projection_small_batch_cuda", 16), + (32, 128, 128, 64, True, True, "receiver_major_projection_cuda", 32), + (16, 128, 128, 128, True, True, "receiver_major_projection_cuda", 64), + (8, 16, 32, 32, True, False, "receiver_major_projection_demoted", 0), + ), ) def test_receiver_major_projection_backward_gate_is_work_shape_based( batch_size: int, @@ -259,26 +351,53 @@ def test_receiver_major_projection_backward_gate_is_work_shape_based( assert gate.demotion_reason.startswith("receiver_major_projection_backward_demoted:") -def _fail_torch_cat(*_args, **_kwargs): - raise AssertionError("torch.cat should not be used") +def _make_spec( + *, + gradient_horizon_steps: int | None = None, + checkpoint_steps: int | None = None, + k_max: int = 4, + default_k: int = 2, +): + populations = { + "slstm": CellPopulationConfig(cell_type="slstm"), + "axoncell": CellPopulationConfig(cell_type="axoncell"), + } + return init( + Config( + graph=lattice2d.Graph(width=4, height=4), + interface=FabricInterfaceConfig(hidden_size=8), + message=MessageConfig(projection_region_shape=(2, 2)), + populations=PopulationLayoutConfig( + cell_populations=populations, + population_mix={"slstm": 0.5, "axoncell": 0.5}, + ), + execution=RuntimeExecutionConfig( + gradient_horizon_steps=gradient_horizon_steps, + checkpoint_steps=checkpoint_steps, + k_max=k_max, + default_k=default_k, + ), + initialization=InitializationConfig(seed=11), + ) + ) -def _make_spec(): +def _make_three_population_spec(): return init( Config( - width=4, - height=4, - hidden_size=8, - cell_populations={ - "slstm": CellPopulationConfig(cell_type="slstm"), - "axoncell": CellPopulationConfig(cell_type="axoncell"), - }, - population_mix={"slstm": 0.5, "axoncell": 0.5}, - patch_edges_per_cell=0, - projection_region_shape=(2, 2), - k_max=4, - default_k=2, - seed=11, + graph=lattice2d.Graph(width=4, height=4), + interface=FabricInterfaceConfig(hidden_size=8), + message=MessageConfig(projection_region_shape=(2, 2)), + populations=PopulationLayoutConfig( + cell_populations={ + "left": CellPopulationConfig(cell_type="slstm"), + "middle": CellPopulationConfig(cell_type="axoncell"), + "right": CellPopulationConfig(cell_type="slstm"), + }, + population_mix={"left": 0.34, "middle": 0.33, "right": 0.33}, + ), + execution=RuntimeExecutionConfig(k_max=1, default_k=1), + initialization=InitializationConfig(seed=41), ) ) @@ -287,19 +406,25 @@ def _make_axon_spec( *, k_max: int = 1, default_k: int = 1, + gradient_horizon_steps: int | None = None, + checkpoint_steps: int | None = None, ): return init( Config( - width=4, - height=4, - hidden_size=8, - cell_populations={"axoncell": CellPopulationConfig(cell_type="axoncell")}, - population_mix={"axoncell": 1.0}, - patch_edges_per_cell=0, - projection_region_shape=(2, 2), - k_max=k_max, - default_k=default_k, - seed=11, + graph=lattice2d.Graph(width=4, height=4), + interface=FabricInterfaceConfig(hidden_size=8), + message=MessageConfig(projection_region_shape=(2, 2)), + populations=PopulationLayoutConfig( + cell_populations={"axoncell": CellPopulationConfig(cell_type="axoncell")}, + population_mix={"axoncell": 1.0}, + ), + execution=RuntimeExecutionConfig( + gradient_horizon_steps=gradient_horizon_steps, + checkpoint_steps=checkpoint_steps, + k_max=k_max, + default_k=default_k, + ), + initialization=InitializationConfig(seed=11), ) ) @@ -312,23 +437,34 @@ def _make_slstm_spec( conduction_speed: float | None = None, k_max: 
int = 1, default_k: int = 1, + gradient_horizon_steps: int | None = None, + checkpoint_steps: int | None = None, ): + connectivity: list[object] = [lattice2d.LocalRadius(1.5)] + if patch_edges_per_cell > 0: + connectivity.append(lattice2d.PatchEdges(per_cell=patch_edges_per_cell, min_distance=3.0, max_distance=4.0)) return init( Config( - width=4, - height=4, - hidden_size=hidden_size, - cell_populations={"slstm": CellPopulationConfig(cell_type="slstm")}, - population_mix={"slstm": 1.0}, - patch_edges_per_cell=patch_edges_per_cell, - patch_min_dist=3.0 if patch_edges_per_cell > 0 else 0.0, - patch_max_dist=4.0 if patch_edges_per_cell > 0 else 0.0, - projection_region_shape=(2, 2), - max_delay=max_delay, - conduction_speed=conduction_speed, - k_max=k_max, - default_k=default_k, - seed=11, + graph=lattice2d.Graph( + width=4, + height=4, + conduction_speed=conduction_speed, + max_delay=max_delay, + connectivity=tuple(connectivity), + ), + interface=FabricInterfaceConfig(hidden_size=hidden_size), + message=MessageConfig(projection_region_shape=(2, 2)), + populations=PopulationLayoutConfig( + cell_populations={"slstm": CellPopulationConfig(cell_type="slstm")}, + population_mix={"slstm": 1.0}, + ), + execution=RuntimeExecutionConfig( + gradient_horizon_steps=gradient_horizon_steps, + checkpoint_steps=checkpoint_steps, + k_max=k_max, + default_k=default_k, + ), + initialization=InitializationConfig(seed=11), ) ) @@ -671,8 +807,7 @@ def make_tensor(shape: tuple[int, ...]) -> torch.Tensor: def _build_runtime_for_backend(spec, backend_name: str) -> Runtime: - backend_spec = copy.deepcopy(spec) - backend_spec.config.backend = backend_name + backend_spec = _spec_with_backend(copy.deepcopy(spec), backend_name) runtime = build(backend_spec).cuda() assert isinstance(runtime, Runtime) return runtime @@ -686,8 +821,7 @@ def _build_backend_runtime_pair(spec) -> tuple[Runtime, Runtime]: def _build_fabric_model_for_backend(spec, backend_name: str, *, d_hidden: int): - backend_spec = copy.deepcopy(spec) - backend_spec.config.backend = backend_name + backend_spec = _spec_with_backend(copy.deepcopy(spec), backend_name) return build(backend_spec, d_hidden=d_hidden).cuda() @@ -707,20 +841,13 @@ def _make_10m_fabric_spec(family: str): raise ValueError(f"Unsupported 10M parity family {family!r}") return ( init( - Config( + _lattice_test_config( width=width, height=height, hidden_size=32, cell_populations={family: CellPopulationConfig(cell_type=family)}, population_mix={family: 1.0}, cell_arrangement="x_bands", - local_radius=1.5, - patch_edges_per_cell=0, - graph_edges=None, - kv_group_ids=None, - input_band_width=1, - output_band_width=1, - readout_pool="mean", default_k=1, k_max=1, projection_region_shape=(1, 2), @@ -794,8 +921,8 @@ def plan_sequence_surface_route( True, "test_forced_flat_bucket_single_population_executor", active_populations, - surface_key="flat_bucket_sequence_surface", - implementation_executor="flat_transition_buckets", + surface_key="registered_temporal_sequence_surface", + implementation_executor="registered_temporal_program", bucket_count=len(active_populations), ) @@ -805,21 +932,25 @@ def plan_sequence_surface_route( def _assert_generic_flat_bucket_sequence_record( record, *, - scan_executor: str | None = "flat_bucket_temporal_scan", + scan_executor: str | None = "registered_temporal_fused_forward_program_cuda", ) -> None: assert record is not None assert record.backend_name == "cuda" - assert record.surface_key == "flat_bucket_sequence_surface" + assert record.surface_key == 
"registered_temporal_sequence_surface" assert record.cell_type == "bucketed" - assert set(record.execution_families) == {"message", "transition_buckets", "readout"} + assert set(record.execution_families) == {"message_program", "transition_program", "readout_program"} assert record.math_backends == ("cuda_tensor_ops",) assert record.launch_temporal_executions == ("temporal_bucket_sequence",) if scan_executor is not None: assert record.launch_scan_implementations == (scan_executor,) assert scan_executor in record.physical_op_executors - assert "flat_bucket_sequence_surface" in record.physical_op_executors + assert record.launch_temporal_scan_owners == ("registered_fused_forward_program_cuda",) + assert record.launch_temporal_scan_outer_steps == (str(record.time_steps),) + assert record.launch_temporal_scan_inner_steps == (str(record.inner_steps),) + assert record.launch_temporal_scan_physical_steps == (str(record.time_steps * record.inner_steps),) + assert "registered_temporal_sequence_surface" in record.physical_op_executors assert "shared_graph_message" in record.physical_op_executors - assert any(executor.startswith("transition_buckets=") for executor in record.physical_op_executors) + assert any(executor.startswith("transition_program=") for executor in record.physical_op_executors) assert "readout_projection" in record.physical_op_executors assert "single_bucket_sequence_executor" not in record.physical_op_executors assert "message" in record.physical_op_kinds @@ -829,6 +960,14 @@ def _assert_generic_flat_bucket_sequence_record( assert "fixed_active_spatial_region" in record.physical_boundary_contracts +def _assert_registered_reverse_program_window_owned(record) -> None: + reverse_scan_aliases = record.workspace_aliases + record.backward_workspace_aliases + assert "flat_bucket_temporal_backward_binding_abi:registered_executor_binding_rows" in reverse_scan_aliases + assert "flat_bucket_temporal_reverse_scan_owner:registered_reverse_program_window" in reverse_scan_aliases + assert "flat_bucket_temporal_reverse_scan_owner:registered_reverse_executor_bindings" not in reverse_scan_aliases + assert "flat_bucket_temporal_reverse_scan_owner:python_host_reverse_loop" not in reverse_scan_aliases + + def _floating_state_grad_tensors(state: TensorDictBase) -> dict[str, torch.Tensor]: grads: dict[str, torch.Tensor] = {} for population_name, population_state in state.items(): @@ -874,12 +1013,20 @@ def _clone_fabric_state_with_grad(state: TensorDictBase) -> TensorDict: return cloned +def _fabric_state_grads(state: TensorDictBase) -> dict[str, torch.Tensor]: + grads = _floating_state_grad_tensors(state) + cells = state.get("cells") + if torch.is_tensor(cells) and cells.grad is not None: + grads["cells"] = cells.grad.detach().clone() + return grads + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for small-hidden Fabric parity test") @pytest.mark.parametrize( ("spec_factory", "surface_key"), [ - (_make_slstm_spec, "flat_bucket_sequence_surface"), - (_make_axon_spec, "flat_bucket_sequence_surface"), + (_make_slstm_spec, "registered_temporal_sequence_surface"), + (_make_axon_spec, "registered_temporal_sequence_surface"), ], ) @pytest.mark.parametrize("use_resets", [False, True]) @@ -908,7 +1055,7 @@ def test_fabric_cuda_single_population_small_h_t_gt1_training_matches_pytorch_pa ) y_cuda, _state_cuda = cuda_model(x_cuda, state=None, resets=resets, k=1) - y_pytorch, _state_pytorch = pytorch_model(x_pytorch, state=None, resets=resets, k=1) + y_pytorch, _state_pytorch = 
_fabric_forward_reference(pytorch_model, x_pytorch, state=None, resets=resets, k=1) torch.testing.assert_close(y_cuda, y_pytorch, rtol=1e-4, atol=2e-5) loss_cuda = y_cuda.square().mean() @@ -1040,6 +1187,64 @@ def test_fabric_stream_sequence_matches_repeated_steps_for_cells_and_state(): torch.testing.assert_close(state_seq["cells"], y_seq[:, -1], rtol=0.0, atol=0.0) +def test_fabric_runtime_records_planner_temporal_plan_for_forward_cells(): + runtime = build(_make_spec()) + batch_size = 2 + time_steps = 3 + boundary_seq = torch.randn(batch_size, time_steps, runtime.input_cell_idx.numel(), runtime.hidden_size) + + runtime.forward_cells(boundary_input=boundary_seq, state=None, k=2, materialize_final_state=False) + + temporal_plan = runtime.last_temporal_execution_plan + assert temporal_plan is not None + assert temporal_plan.schedule.outer_time_steps == time_steps + assert temporal_plan.schedule.inner_steps == 2 + assert temporal_plan.output_request.selector_kind == "all_outer_steps" + assert temporal_plan.output_request.emitted_output_count == time_steps + assert temporal_plan.output_request.first_outer_step == 0 + assert temporal_plan.output_request.outer_stride == 1 + assert temporal_plan.output_request.first_physical_step == 1 + assert temporal_plan.output_request.physical_stride == 2 + assert temporal_plan.boundary.input_boundary == "boundary" + assert temporal_plan.boundary.output_contract == "full_cells" + assert temporal_plan.executor.selected_implementation == "none" + assert temporal_plan.static_values.static_value_mode == "detached_shared_values" + assert temporal_plan.engine.forward_owner == "pytorch_reference" + assert temporal_plan.engine.backward_owner == "pytorch_reference" + assert temporal_plan.engine.status == "pytorch_reference" + assert not temporal_plan.supported + assert temporal_plan.reason == "device_not_cuda" + record = runtime.last_backend_execution + assert record is not None + assert record.temporal_plan_schedule_kinds == ("scalar_constant_k",) + assert record.temporal_plan_inner_steps == ("2",) + assert record.temporal_plan_total_scan_steps == ("6",) + assert record.temporal_plan_output_selectors == ("all_outer_steps",) + assert record.temporal_plan_output_first_outer_steps == ("0",) + assert record.temporal_plan_output_outer_strides == ("1",) + assert record.temporal_plan_output_counts == ("3",) + assert record.temporal_plan_output_first_physical_steps == ("1",) + assert record.temporal_plan_output_physical_strides == ("2",) + assert record.temporal_plan_output_surfaces == ("full_cells",) + assert record.temporal_plan_readout_surfaces == ("cells",) + assert record.temporal_plan_output_materializations == ("outputs_only",) + assert record.temporal_plan_autograd_seed_kinds == ("emitted_output_grad",) + assert "transition_primitive_adjoint" in record.temporal_plan_required_backward_surfaces[0] + assert record.temporal_plan_checkpoint_policy_basis == ("emitted_output_schedule",) + assert record.temporal_plan_fresh_state_population_cache == ("False",) + assert record.temporal_plan_fresh_state_population_cache_reasons == ("backend_not_cuda",) + assert record.temporal_plan_static_value_modes == ("detached_shared_values",) + assert record.temporal_plan_native_static_materialization == ("False",) + assert record.temporal_plan_static_include_full_cell_kv == ("True",) + assert record.temporal_plan_static_detach_training == ("True",) + assert record.temporal_plan_backend_names == ("pytorch",) + assert record.temporal_plan_selected_implementations == ("none",) + assert 
record.temporal_plan_forward_owners == ("pytorch_reference",) + assert record.temporal_plan_backward_owners == ("pytorch_reference",) + assert record.temporal_plan_target_owners == ("registered_temporal_executor_bindings",) + assert record.temporal_plan_engine_statuses == ("pytorch_reference",) + + @pytest.mark.parametrize("no_grad_ctx", [torch.no_grad, torch.inference_mode], ids=["no_grad", "inference_mode"]) def test_fabric_stream_step_inference_mode_matches_default_path(no_grad_ctx): runtime = build(_make_spec()) @@ -1056,13 +1261,10 @@ def test_fabric_stream_step_inference_mode_matches_default_path(no_grad_ctx): def test_fabric_backend_cuda_requires_supported_cuda_surface(): spec = init( - Config( - width=4, - height=4, + _lattice_test_config( hidden_size=8, cell_populations={"slstm": CellPopulationConfig(cell_type="slstm")}, population_mix={"slstm": 1.0}, - patch_edges_per_cell=0, projection_region_shape=(2, 2), k_max=1, default_k=1, @@ -1081,20 +1283,12 @@ def test_fabric_backend_cuda_requires_supported_cuda_surface(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") -def test_fabric_backend_pytorch_forces_pure_torch_path_on_cuda(monkeypatch: pytest.MonkeyPatch): +def test_fabric_backend_pytorch_forces_pure_torch_path_on_cuda(): spec = _make_slstm_spec(hidden_size=16) - spec.config.backend = "pytorch" + spec = _spec_with_backend(spec, "pytorch") runtime = build(spec).cuda() boundary_seq = torch.randn(2, 3, runtime.input_cell_idx.numel(), runtime.hidden_size, device="cuda") - def fail_native(*_args, **_kwargs): - raise AssertionError("native CUDA message kernel should not run for Fabric backend='pytorch'") - - monkeypatch.setattr(cuda_runtime_ops, "fabric_local_message_cuda", fail_native) - monkeypatch.setattr(cuda_runtime_ops, "fabric_sparse_message_cuda", fail_native) - monkeypatch.setattr(cuda_runtime_ops, "fabric_local_message_partitioned_cuda", fail_native) - monkeypatch.setattr(cuda_runtime_ops, "fabric_sparse_message_partitioned_cuda", fail_native) - y_cells, next_state = runtime.forward_cells(boundary_input=boundary_seq, state=None, k=1) assert y_cells.shape[:3] == (2, 3, runtime.coords.shape[0]) @@ -1111,21 +1305,18 @@ def fail_native(*_args, **_kwargs): _make_axon_spec, ], ) -def test_fabric_supported_cuda_surface_selection(spec_factory) -> None: +def test_fabric_supported_cuda_route_uses_registered_temporal_program(spec_factory) -> None: runtime = build(spec_factory()) route = runtime._plan_sequence_surface_route( k=1, device=torch.device("cuda"), dtype=torch.float32, ) - backend_population_name = runtime._select_output_cells_stream_backend_population( - k=1, - ) assert route.kind == "sequence_surface" assert route.executor == "temporal_bucket_sequence" - assert route.surface_key == "flat_bucket_sequence_surface" - assert route.implementation_executor == "flat_transition_buckets" + assert route.surface_key == "registered_temporal_sequence_surface" + assert route.implementation_executor == "registered_temporal_program" assert route.bucket_count == 1 assert route.supported assert runtime._supports_cuda_backend_sequence_surface( @@ -1133,16 +1324,7 @@ def test_fabric_supported_cuda_surface_selection(spec_factory) -> None: device=torch.device("cuda"), dtype=torch.float32, ) - assert backend_population_name is not None - - surface = runtime._select_backend_sequence_surface( - training=False, - k=1, - device=torch.device("cuda"), - dtype=torch.float32, - backend_population_name=backend_population_name, - ) - assert surface is None + assert 
runtime._select_output_cells_stream_backend_population(k=1) is None def test_fabric_mixed_population_sequence_route_is_planner_owned() -> None: @@ -1156,188 +1338,29 @@ def test_fabric_mixed_population_sequence_route_is_planner_owned() -> None: assert route.kind == "sequence_surface" assert route.executor == "temporal_bucket_sequence" - assert route.implementation_executor == "flat_transition_buckets" + assert route.implementation_executor == "registered_temporal_program" assert route.bucket_count == 2 assert route.supported assert route.active_populations == ("slstm", "axoncell") -def test_fabric_state_public_backward_owners_are_population_bucket_scoped() -> None: - runtime = build(_make_spec()) - - assert runtime._state_public_backward_profile_name_for_population("slstm") == "fabric.backward.state_epilogue" - assert ( - runtime._state_public_backward_profile_name_for_population("axoncell") == "fabric.backward.diagonal_recurrence" - ) - - @pytest.mark.parametrize("spec_factory", [_make_slstm_spec, _make_axon_spec]) -def test_fabric_supported_cuda_surface_disables_silent_fallback(spec_factory) -> None: +def test_fabric_supported_cuda_route_has_no_legacy_cell_surface(spec_factory) -> None: runtime = build(spec_factory()) + route = runtime._plan_sequence_surface_route( + k=1, + device=torch.device("cuda"), + dtype=torch.float32, + ) assert runtime._supports_cuda_backend_sequence_surface( k=1, device=torch.device("cuda"), dtype=torch.float32, ) - assert ( - runtime._select_backend_sequence_surface( - training=True, - k=1, - device=torch.device("cuda"), - dtype=torch.float32, - backend_population_name=None, - ) - is None - ) - - -def test_sparse_partitioned_forward_wrapper_avoids_torch_cat(monkeypatch: pytest.MonkeyPatch) -> None: - q = torch.randn(3, 4) - input_k = torch.randn(2, 5, 4) - input_v = torch.randn(2, 5, 6) - recurrent_k = torch.randn(2, 7, 4) - recurrent_v = torch.randn(2, 7, 6) - neighbor_idx = torch.zeros(3, 2, dtype=torch.long) - neighbor_valid = torch.ones(3, 2, dtype=torch.bool) - edge_distance = torch.zeros(3, 2) - edge_delay = torch.zeros(3, 2, dtype=torch.long) - step_flat = torch.ones(2, dtype=torch.long) - captured: dict[str, tuple[int, ...]] = {} - - def fake_combined( - q_arg: torch.Tensor, - k_all: torch.Tensor, - v_all: torch.Tensor, - *_args, - **_kwargs, - ) -> torch.Tensor: - del q_arg - captured["k_all_shape"] = tuple(k_all.shape) - captured["v_all_shape"] = tuple(v_all.shape) - return torch.zeros(input_k.shape[0], neighbor_idx.shape[0], input_v.shape[-1]) - - monkeypatch.setattr(sparse_message_mod, "fabric_sparse_message_cuda", fake_combined) - monkeypatch.setattr(torch, "cat", _fail_torch_cat) - - out = sparse_message_mod.fabric_sparse_message_partitioned_cuda( - q, - input_k, - input_v, - recurrent_k, - recurrent_v, - neighbor_idx, - neighbor_valid, - edge_distance, - edge_delay, - step_flat, - distance_scale=1.0, - use_delay=False, - ) - - assert tuple(out.shape) == (2, 3, 6) - assert captured["k_all_shape"] == (2, 12, 4) - assert captured["v_all_shape"] == (2, 12, 6) - - -def test_sparse_partitioned_backward_sender_wrapper_avoids_torch_cat(monkeypatch: pytest.MonkeyPatch) -> None: - grad_msg = torch.randn(2, 3, 6) - q = torch.randn(3, 4) - input_k = torch.randn(2, 5, 4) - input_v = torch.randn(2, 5, 6) - recurrent_k = torch.randn(2, 7, 4) - recurrent_v = torch.randn(2, 7, 6) - neighbor_idx = torch.zeros(3, 2, dtype=torch.long) - neighbor_valid = torch.ones(3, 2, dtype=torch.bool) - edge_distance = torch.zeros(3, 2) - edge_delay = torch.zeros(3, 2, 
dtype=torch.long) - step_flat = torch.ones(2, dtype=torch.long) - captured: dict[str, tuple[int, ...]] = {} - - def fake_combined( - grad_msg_arg: torch.Tensor, - q_arg: torch.Tensor, - k_all: torch.Tensor, - v_all: torch.Tensor, - *_args, - **_kwargs, - ) -> tuple[torch.Tensor, torch.Tensor]: - del grad_msg_arg, q_arg - captured["k_all_shape"] = tuple(k_all.shape) - captured["v_all_shape"] = tuple(v_all.shape) - return torch.zeros_like(k_all), torch.zeros_like(v_all) - - monkeypatch.setattr(sparse_message_mod, "fabric_sparse_message_backward_sender_cuda", fake_combined) - monkeypatch.setattr(torch, "cat", _fail_torch_cat) - - grads = sparse_message_mod.fabric_sparse_message_partitioned_backward_sender_cuda( - grad_msg, - q, - input_k, - input_v, - recurrent_k, - recurrent_v, - neighbor_idx, - neighbor_valid, - edge_distance, - edge_delay, - step_flat, - distance_scale=1.0, - use_delay=False, - ) - - assert [tuple(grad.shape) for grad in grads] == [(2, 5, 4), (2, 5, 6), (2, 7, 4), (2, 7, 6)] - assert captured["k_all_shape"] == (2, 12, 4) - assert captured["v_all_shape"] == (2, 12, 6) - - -def test_local_partitioned_backward_receiver_wrapper_avoids_torch_cat(monkeypatch: pytest.MonkeyPatch) -> None: - grad_msg = torch.randn(2, 3, 6) - q = torch.randn(3, 4) - input_k = torch.randn(2, 5, 4) - input_v = torch.randn(2, 5, 6) - recurrent_k = torch.randn(2, 7, 4) - recurrent_v = torch.randn(2, 7, 6) - receiver_sender_idx = torch.zeros(3, 2, dtype=torch.int32) - offset_distance = torch.zeros(2) - offset_delay = torch.zeros(2, dtype=torch.long) - step_flat = torch.ones(2, dtype=torch.long) - captured: dict[str, tuple[int, ...]] = {} - - def fake_combined( - grad_msg_arg: torch.Tensor, - q_arg: torch.Tensor, - k_all: torch.Tensor, - v_all: torch.Tensor, - *_args, - **_kwargs, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - del grad_msg_arg, q_arg - captured["k_all_shape"] = tuple(k_all.shape) - captured["v_all_shape"] = tuple(v_all.shape) - return torch.zeros_like(q), torch.zeros(3), torch.zeros(3), torch.zeros(3) - - monkeypatch.setattr(local_message_mod, "fabric_local_message_backward_receiver_cuda", fake_combined) - monkeypatch.setattr(torch, "cat", _fail_torch_cat) - - grad_q, *_cache = local_message_mod.fabric_local_message_partitioned_backward_receiver_cuda( - grad_msg, - q, - input_k, - input_v, - recurrent_k, - recurrent_v, - receiver_sender_idx, - offset_distance, - offset_delay, - step_flat, - distance_scale=1.0, - use_delay=False, - ) - - assert tuple(grad_q.shape) == tuple(q.shape) - assert captured["k_all_shape"] == (2, 12, 4) - assert captured["v_all_shape"] == (2, 12, 6) + assert route.uses_registered_temporal_program + assert route.implementation_executor == "registered_temporal_program" + assert runtime._select_output_cells_stream_backend_population(k=1) is None def test_fabric_stream_repeated_inference_steps_match_default_path(): @@ -1511,137 +1534,6 @@ def test_fabric_forward_matches_reference_with_gradients(spec_factory, force_chu torch.testing.assert_close(grads_fast[name], grads_ref[name], rtol=1e-5, atol=1e-5) -@pytest.mark.parametrize("spec_factory", [_make_axon_spec, _make_slstm_spec]) -@pytest.mark.parametrize("force_chunked", [False, True]) -@pytest.mark.parametrize("force_batch_tiled", [False, True]) -def test_fabric_stream_sequence_outputs_matches_forward_and_gradients( - spec_factory, - force_chunked: bool, - force_batch_tiled: bool, -): - model = build(spec_factory(), d_hidden=8) - reference = copy.deepcopy(model) - hidden_fast = torch.randn(3, 
4, 8, requires_grad=True) - hidden_ref = hidden_fast.detach().clone().requires_grad_(True) - target = torch.randn(3, 4, 8) - - if force_chunked: - model.runtime._projected_boundary_time_chunk_len = lambda **_kwargs: 2 - if force_batch_tiled: - model._readout_pooled_batch_tile_len = lambda *_args, **_kwargs: (2, "test_batch_tile") - - y_ref, _state_ref = reference( - hidden_ref, - state=None, - resets=None, - k=1, - materialize_final_state=False, - output_boundary="sequence", - ) - loss_ref = torch.nn.functional.mse_loss(y_ref, target, reduction="mean") - loss_ref.backward() - grads_ref = _param_grads(reference) - hidden_grad_ref = hidden_ref.grad.detach().clone() - - y_fast = torch.empty_like(y_ref) - loss_sum: torch.Tensor | None = None - total_elements = 0 - - def consume_output_chunk( - output_chunk: torch.Tensor, - batch_start: int, - batch_end: int, - time_start: int, - time_end: int, - ) -> None: - nonlocal loss_sum, total_elements - target_chunk = target[batch_start:batch_end, time_start:time_end] - y_fast[batch_start:batch_end, time_start:time_end].copy_(output_chunk.detach()) - chunk_loss = torch.nn.functional.mse_loss(output_chunk, target_chunk, reduction="sum") - loss_sum = chunk_loss if loss_sum is None else loss_sum + chunk_loss - total_elements += int(target_chunk.numel()) - - state_fast = model.stream_sequence_outputs( - hidden_fast, - None, - resets=None, - k=1, - materialize_final_state=False, - output_boundary="sequence", - output_consumer=consume_output_chunk, - ) - assert loss_sum is not None - (loss_sum / float(total_elements)).backward() - grads_fast = _param_grads(model) - hidden_grad_fast = hidden_fast.grad.detach().clone() - - torch.testing.assert_close(y_fast, y_ref, rtol=1e-5, atol=1e-5) - assert tuple(cast(TensorDictBase, state_fast).keys()) == () - torch.testing.assert_close(hidden_grad_fast, hidden_grad_ref, rtol=1e-5, atol=1e-5) - assert grads_fast.keys() == grads_ref.keys() - for name in grads_fast: - torch.testing.assert_close(grads_fast[name], grads_ref[name], rtol=1e-5, atol=1e-5) - - -@pytest.mark.parametrize("spec_factory", [_make_axon_spec, _make_slstm_spec]) -@pytest.mark.parametrize("force_chunked", [False, True]) -def test_fabric_reduce_sequence_outputs_matches_forward_and_gradients( - spec_factory, - force_chunked: bool, -) -> None: - model = build(spec_factory(), d_hidden=8) - reference = copy.deepcopy(model) - hidden_fast = torch.randn(3, 4, 8, requires_grad=True) - hidden_ref = hidden_fast.detach().clone().requires_grad_(True) - target = torch.randn(3, 4, 8) - - if force_chunked: - model._sequence_direct_grad_target_bytes = 0 - model._sequence_checkpoint_target_bytes = 1 - model._sequence_checkpoint_state_overhead_factor = 1.0 - - def reduce_output(output_chunk: torch.Tensor, time_start: int, time_end: int) -> torch.Tensor: - return torch.nn.functional.mse_loss( - output_chunk, - target[:, time_start:time_end], - reduction="sum", - ) - - loss_fast, state_fast = model.reduce_sequence_outputs( - hidden_fast, - reduce_output, - None, - resets=None, - k=1, - materialize_final_state=False, - output_boundary="sequence", - ) - loss_fast = loss_fast / float(target.numel()) - loss_fast.backward() - grads_fast = _param_grads(model) - hidden_grad_fast = hidden_fast.grad.detach().clone() - - y_ref, _state_ref = reference( - hidden_ref, - state=None, - resets=None, - k=1, - materialize_final_state=False, - output_boundary="sequence", - ) - loss_ref = torch.nn.functional.mse_loss(y_ref, target, reduction="mean") - loss_ref.backward() - grads_ref = 
_param_grads(reference) - hidden_grad_ref = hidden_ref.grad.detach().clone() - - torch.testing.assert_close(loss_fast.detach(), loss_ref.detach(), rtol=1e-5, atol=1e-5) - assert tuple(cast(TensorDictBase, state_fast).keys()) == () - torch.testing.assert_close(hidden_grad_fast, hidden_grad_ref, rtol=1e-5, atol=1e-5) - assert grads_fast.keys() == grads_ref.keys() - for name in grads_fast: - torch.testing.assert_close(grads_fast[name], grads_ref[name], rtol=1e-5, atol=1e-5) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for 10M Fabric parity test") @pytest.mark.parametrize("family", ["slstm", "axoncell"]) @pytest.mark.parametrize(("batch_size", "time_steps"), [(1, 1), (2, 3), (4, 2)]) @@ -1723,14 +1615,13 @@ def test_fabric_cuda_10m_model_matches_pytorch_backend_across_bxt( @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for compact carry parity test") @pytest.mark.parametrize("family", ["slstm", "axoncell"]) -def test_fabric_cuda_10m_chunked_sequence_loss_matches_pytorch_without_final_state( +def test_fabric_cuda_10m_sequence_loss_matches_pytorch_without_final_state( family: str, monkeypatch: pytest.MonkeyPatch, ) -> None: spec, d_hidden = _make_10m_fabric_spec(family) cuda_model, pytorch_model = _build_fabric_model_pair(spec, d_hidden=d_hidden) monkeypatch.setattr(cuda_model.runtime, "_should_use_backend_graph_capture", lambda **_kwargs: False) - cuda_model.runtime._projected_boundary_time_chunk_len = lambda **_kwargs: 2 generator = torch.Generator(device="cuda").manual_seed(20260421 + len(family)) hidden_cuda = torch.randn( @@ -1806,60 +1697,17 @@ def consume_output_chunk( @pytest.mark.parametrize("spec_factory", [_make_axon_spec, _make_slstm_spec]) -def test_fabric_stream_sequence_outputs_supports_chunk_backward( - spec_factory, -) -> None: +def test_fabric_forward_matches_reference_when_input_requires_grad(spec_factory) -> None: model = build(spec_factory(), d_hidden=8) + reference = copy.deepcopy(model) hidden_fast = torch.randn(3, 4, 8, requires_grad=True) - target = torch.randn(3, 4, 8) - model.runtime._projected_boundary_time_chunk_len = lambda **kwargs: int( - kwargs["projected_boundary_source_seq"].shape[1] - ) - total_elements = int(target.numel()) - - def consume_output_chunk( - output_chunk: torch.Tensor, - batch_start: int, - batch_end: int, - time_start: int, - time_end: int, - ) -> None: - target_chunk = target[batch_start:batch_end, time_start:time_end] - chunk_loss = torch.nn.functional.mse_loss(output_chunk, target_chunk, reduction="sum") - (chunk_loss / float(total_elements)).backward() + hidden_ref = hidden_fast.detach().clone().requires_grad_(True) - state_fast = model.stream_sequence_outputs( - hidden_fast, - None, - resets=None, - k=1, - materialize_final_state=False, - output_boundary="sequence", - output_consumer=consume_output_chunk, - detach_internal_carry_after_output_chunk=True, - ) + y_fast, _ = model(hidden_fast, state=None, resets=None, k=1) + loss_fast = y_fast.square().sum() + loss_fast.backward() grads_fast = _param_grads(model) - - assert tuple(cast(TensorDictBase, state_fast).keys()) == () - assert hidden_fast.grad is not None - assert bool(torch.isfinite(hidden_fast.grad).all()) - assert grads_fast - for grad in grads_fast.values(): - assert bool(torch.isfinite(grad).all()) - - -@pytest.mark.parametrize("spec_factory", [_make_axon_spec, _make_slstm_spec]) -def test_fabric_forward_matches_reference_when_input_requires_grad(spec_factory) -> None: - model = build(spec_factory(), d_hidden=8) - reference 
= copy.deepcopy(model) - hidden_fast = torch.randn(3, 4, 8, requires_grad=True) - hidden_ref = hidden_fast.detach().clone().requires_grad_(True) - - y_fast, _ = model(hidden_fast, state=None, resets=None, k=1) - loss_fast = y_fast.square().sum() - loss_fast.backward() - grads_fast = _param_grads(model) - hidden_grad_fast = hidden_fast.grad.detach().clone() + hidden_grad_fast = hidden_fast.grad.detach().clone() y_ref, _ = _fabric_forward_reference(reference, hidden_ref, state=None, resets=None, k=1) loss_ref = y_ref.square().sum() @@ -2002,23 +1850,6 @@ def test_fabric_axon_direct_grad_matches_reference(use_resets: bool): torch.testing.assert_close(grads_direct[name], grads_ref[name], rtol=1e-5, atol=1e-5) -def test_fabric_axon_checkpoint_chunking_uses_larger_state_budget(): - axon = build(_make_axon_spec(), d_hidden=8) - slstm = build(_make_slstm_spec(), d_hidden=8) - hidden_seq = torch.randn(2, 8, 8) - - axon_state = axon.runtime._ensure_state(None, batch=2, device=hidden_seq.device, dtype=hidden_seq.dtype) - slstm_state = slstm.runtime._ensure_state(None, batch=2, device=hidden_seq.device, dtype=hidden_seq.dtype) - axon_bytes = axon._estimate_sequence_state_bytes(axon_state) - slstm_bytes = slstm._estimate_sequence_state_bytes(slstm_state) - - axon._sequence_checkpoint_target_bytes = 2 * axon_bytes - slstm._sequence_checkpoint_target_bytes = 2 * slstm_bytes - - assert axon._sequence_checkpoint_chunk_len(hidden_seq, axon_state) == 2 - assert slstm._sequence_checkpoint_chunk_len(hidden_seq, slstm_state) == 1 - - @pytest.mark.parametrize("use_resets", [False, True]) def test_fabric_axon_stream_k1_training_backward_succeeds(use_resets: bool): runtime = build(_make_axon_spec()) @@ -2096,9 +1927,9 @@ def test_fabric_stream_chunking_matches_full_sequence_and_stores_cells(): def test_fabric_defaults_to_single_attention_head(): - spec = init(Config(width=4, height=4, hidden_size=8)) + spec = init(_lattice_test_config(hidden_size=8)) - assert spec.config.num_heads == 1 + assert spec.config.message.resolved_num_heads == 1 def test_fabric_slstm_defaults_to_single_head(): @@ -2226,7 +2057,7 @@ def test_fabric_output_message_sequence_helper_matches_step_projection(spec_fact step_idx=1, head_dim=runtime.head_dim, value_dim=runtime.value_dim, - distance_logit_scale=float(runtime.config.distance_logit_scale), + distance_logit_scale=float(runtime.config.message.distance_logit_scale), ) for step_index in range(boundary_seq.shape[1]): @@ -2253,7 +2084,7 @@ def test_fabric_output_message_sequence_helper_matches_step_projection(spec_fact def test_fabric_stream_sequence_single_population_matches_repeated_steps_for_cells_and_state(cell_type, k): runtime = build( init( - Config( + _lattice_test_config( width=4, height=4, hidden_size=8, @@ -2304,7 +2135,7 @@ def test_fabric_stream_sequence_single_population_matches_repeated_steps_for_cel def test_fabric_message_fast_path_matches_sparse_reference_2d(): runtime = build( init( - Config( + _lattice_test_config( width=5, height=4, hidden_size=8, @@ -2322,7 +2153,7 @@ def test_fabric_message_fast_path_matches_sparse_reference_2d(): ) ) assert isinstance(runtime, Runtime) - z_prev = torch.randn(2, 3, runtime.coords.shape[0], runtime.config.d_public) + z_prev = torch.randn(2, 3, runtime.coords.shape[0], runtime.d_public) q = runtime.q_proj(runtime.slot_embed).view(runtime.coords.shape[0], runtime.head_dim) gathered_kv_weight = torch.cat( ( @@ -2350,18 +2181,17 @@ def test_fabric_message_fast_path_matches_sparse_reference_2d(): torch.testing.assert_close(fast, reference, 
rtol=1e-5, atol=1e-5) -def test_fabric_message_fast_path_matches_sparse_reference_3d(): +def test_fabric_message_fast_path_matches_sparse_reference_unwrapped_lattice(): runtime = build( init( - Config( + _lattice_test_config( width=3, - height=3, - depth=2, + height=6, hidden_size=4, cell_populations={"axoncell": CellPopulationConfig(cell_type="axoncell")}, population_mix={"axoncell": 1.0}, local_radius=1.5, - projection_region_shape=(1, 1, 1), + projection_region_shape=(1, 1), input_band_width=1, output_band_width=1, wrap=False, @@ -2370,7 +2200,7 @@ def test_fabric_message_fast_path_matches_sparse_reference_3d(): ) ) assert isinstance(runtime, Runtime) - z_prev = torch.randn(2, 2, runtime.coords.shape[0], runtime.config.d_public) + z_prev = torch.randn(2, 2, runtime.coords.shape[0], runtime.d_public) q = runtime.q_proj(runtime.slot_embed).view(runtime.coords.shape[0], runtime.head_dim) gathered_kv_weight = torch.cat( ( @@ -2400,7 +2230,7 @@ def test_fabric_message_fast_path_matches_sparse_reference_3d(): def test_fabric_stream_step_subset_messages_match_full_reference(): runtime = build( init( - Config( + _lattice_test_config( width=5, height=4, hidden_size=8, @@ -2418,7 +2248,7 @@ def test_fabric_stream_step_subset_messages_match_full_reference(): ) ) assert isinstance(runtime, Runtime) - z_prev_step = torch.randn(2, runtime.coords.shape[0], runtime.config.d_public) + z_prev_step = torch.randn(2, runtime.coords.shape[0], runtime.d_public) q = runtime.q_proj(runtime.slot_embed).view(runtime.coords.shape[0], runtime.head_dim) sender_kv_weight = torch.cat( ( @@ -2428,7 +2258,7 @@ def test_fabric_stream_step_subset_messages_match_full_reference(): dim=-1, ).index_select(0, runtime.sender_cell_idx) k_all, v_all = runtime._project_sender_kv_step(z_prev_step, sender_kv_weight=sender_kv_weight) - full_reference = _reference_messages_step(runtime, z_prev_step, q=q, step_idx=torch.tensor([1, 1])) + full_reference = _reference_dynamic_messages_step(runtime, z_prev_step, q=q, step_idx=torch.tensor([1, 1])) recurrent_msg = runtime._compute_messages_step_subset( k_all, @@ -2475,10 +2305,10 @@ def test_fabric_stream_step_subset_messages_match_full_reference(): ([0, 1], None), ], ) -def test_fabric_stream_step_k1_fast_path_matches_previous_reference(k_rows_values, all_active): +def test_fabric_stream_step_k1_fast_path_matches_declared_message_rule_reference(k_rows_values, all_active): runtime = build( init( - Config( + _lattice_test_config( width=4, height=4, hidden_size=8, @@ -2542,7 +2372,7 @@ def test_fabric_stream_step_k1_fast_path_matches_previous_reference(k_rows_value boundary_step=boundary_step, population_materialized=population_materialized, ) - reference_y, reference_state = _reference_stream_step_k1_previous( + reference_y, reference_state = _reference_stream_step_k1_declared_message_rule( runtime, cells_prev, population_state=state, @@ -2564,10 +2394,10 @@ def test_fabric_stream_step_k1_fast_path_matches_previous_reference(k_rows_value @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for partitioned k=1 parity test") -def test_fabric_stream_step_k1_fast_path_matches_previous_reference_cuda_inference() -> None: +def test_fabric_stream_step_k1_fast_path_matches_declared_message_rule_reference_cuda_inference() -> None: runtime = build( init( - Config( + _lattice_test_config( width=8, height=8, hidden_size=16, @@ -2632,7 +2462,7 @@ def test_fabric_stream_step_k1_fast_path_matches_previous_reference_cuda_inferen boundary_step=boundary_step, 
population_materialized=population_materialized, ) - reference_y, reference_state = _reference_stream_step_k1_previous( + reference_y, reference_state = _reference_stream_step_k1_declared_message_rule( runtime, cells_prev, population_state=state, @@ -2668,7 +2498,7 @@ def test_fabric_cuda_slstm_sequence_backend_matches_pytorch_reference_forward( monkeypatch: pytest.MonkeyPatch, ) -> None: spec = init( - Config( + _lattice_test_config( width=8, height=8, hidden_size=hidden_size, @@ -2768,7 +2598,7 @@ def test_fabric_cuda_slstm_t1_public_hidden_alias_elides_identity_copy( record = cuda_runtime.last_backend_execution _assert_generic_flat_bucket_sequence_record(record) assert "sequence_output_boundary:all_steps" in record.workspace_aliases - assert "sequence_output_contract:full_cells" in record.workspace_aliases + assert "sequence_output_contract:output_cells" in record.workspace_aliases assert set(record.state_affine_backends) == {"receiver_affine_superop"} torch.testing.assert_close(y_cuda, y_pytorch, rtol=1e-5, atol=1e-5) _assert_fabric_semantic_state_close(next_state_cuda, next_state_pytorch, rtol=5e-5, atol=5e-5) @@ -2912,9 +2742,9 @@ def test_fabric_cuda_fresh_step_wrapper_uses_backend_zero_contract( monkeypatch: pytest.MonkeyPatch, ) -> None: cuda_spec = copy.deepcopy(_make_axon_spec()) - cuda_spec.config.backend = "cuda" + cuda_spec = _spec_with_backend(cuda_spec, "cuda") pytorch_spec = copy.deepcopy(_make_axon_spec()) - pytorch_spec.config.backend = "pytorch" + pytorch_spec = _spec_with_backend(pytorch_spec, "pytorch") cuda_model = build(cuda_spec, d_hidden=8).cuda() pytorch_model = build(pytorch_spec, d_hidden=8).cuda() pytorch_model.load_state_dict(cuda_model.state_dict()) @@ -2938,8 +2768,8 @@ def test_fabric_cuda_fresh_step_wrapper_uses_backend_zero_contract( def test_fabric_cuda_mean_readout_boundary_backward_matches_cell_boundary(cell_type: str) -> None: torch.manual_seed(0) spec = _make_slstm_spec(hidden_size=16) if cell_type == "slstm" else _make_axon_spec() - spec.config.backend = "cuda" - hidden_size = int(spec.config.hidden_size) + spec = _spec_with_backend(spec, "cuda") + hidden_size = int(spec.config.interface.hidden_size) model = build(spec, d_hidden=hidden_size).cuda() tile_len, tile_reason = model._readout_pooled_batch_tile_len( torch.empty(2, 3, hidden_size, device="cuda"), @@ -3142,7 +2972,7 @@ def test_fabric_cuda_unmaterialized_final_state_preserves_sequence_output_math( @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") @pytest.mark.parametrize("use_resets", [False, True]) -def test_fabric_cuda_sparse_message_bucket_matches_pytorch_reference_forward( +def test_fabric_cuda_sparse_message_bucket_fails_closed_until_registered_sparse_executor( use_resets: bool, monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -3165,14 +2995,21 @@ def test_fabric_cuda_sparse_message_bucket_matches_pytorch_reference_forward( ) monkeypatch.setattr(cuda_runtime, "_should_use_backend_graph_capture", lambda **_kwargs: False) - with torch.no_grad(): - y_cuda, next_state_cuda = cuda_runtime.forward_output_cells_for_readout( + with ( + torch.no_grad(), + pytest.raises( + RuntimeError, + match="compiler-owned CUDA temporal table scan", + ), + ): + cuda_runtime.forward_output_cells_for_readout( state=cuda_runtime.init_state(batch_size, device="cuda", dtype=torch.float32), resets=resets, k=1, boundary_input=boundary_seq, training_semantics=False, ) + with torch.no_grad(): y_pytorch, next_state_pytorch = pytorch_runtime.forward_output_cells_for_readout( 
state=pytorch_runtime.init_state(batch_size, device="cuda", dtype=torch.float32), resets=resets, @@ -3181,27 +3018,23 @@ def test_fabric_cuda_sparse_message_bucket_matches_pytorch_reference_forward( training_semantics=False, ) - record = cuda_runtime.last_backend_execution - _assert_generic_flat_bucket_sequence_record(record) - assert record.message_projection_bucket_kinds == ("sparse_projected_message_boundary",) - assert record.message_bucket_kinds == ("ragged_grouped_sparse",) - assert record.message_topology_kinds == ("edge_owned_sparse",) - assert record.message_spatial_ownership == ("edge_owned",) - assert record.message_execution_mode == ("ragged_grouped",) - assert record.message_physical_mode == ("sparse_ragged_grouped_projected",) - assert record.message_use_delay == ("false",) - torch.testing.assert_close(y_cuda, y_pytorch, rtol=1e-5, atol=1e-5) - _assert_fabric_semantic_state_close(next_state_cuda, next_state_pytorch, rtol=5e-5, atol=5e-5) + assert y_pytorch.shape == ( + batch_size, + time_steps, + pytorch_runtime.output_cell_idx.numel(), + pytorch_runtime.hidden_size, + ) + assert isinstance(next_state_pytorch, TensorDictBase) @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") @pytest.mark.parametrize("use_resets", [False, True]) -def test_fabric_cuda_degree_uniform_sparse_message_bucket_uses_batched_gemm( +def test_fabric_cuda_degree_uniform_sparse_message_bucket_fails_closed_until_registered_sparse_executor( use_resets: bool, monkeypatch: pytest.MonkeyPatch, ) -> None: spec = init( - Config( + _lattice_test_config( width=4, height=4, hidden_size=8, @@ -3227,14 +3060,21 @@ def test_fabric_cuda_degree_uniform_sparse_message_bucket_uses_batched_gemm( resets = torch.tensor([[False, True], [True, False]], dtype=torch.bool, device="cuda") if use_resets else None monkeypatch.setattr(cuda_runtime, "_should_use_backend_graph_capture", lambda **_kwargs: False) - with torch.no_grad(): - y_cuda, next_state_cuda = cuda_runtime.forward_output_cells_for_readout( + with ( + torch.no_grad(), + pytest.raises( + RuntimeError, + match="compiler-owned CUDA temporal table scan", + ), + ): + cuda_runtime.forward_output_cells_for_readout( state=cuda_runtime.init_state(batch_size, device="cuda", dtype=torch.float32), resets=resets, k=1, boundary_input=boundary_seq, training_semantics=False, ) + with torch.no_grad(): y_pytorch, next_state_pytorch = pytorch_runtime.forward_output_cells_for_readout( state=pytorch_runtime.init_state(batch_size, device="cuda", dtype=torch.float32), resets=resets, @@ -3243,21 +3083,19 @@ def test_fabric_cuda_degree_uniform_sparse_message_bucket_uses_batched_gemm( training_semantics=False, ) - record = cuda_runtime.last_backend_execution - _assert_generic_flat_bucket_sequence_record(record) - assert record.message_projection_bucket_kinds == ("sparse_projected_message_boundary",) - assert record.message_bucket_kinds == ("degree_bucketed_sparse",) - assert record.message_topology_kinds == ("edge_owned_sparse",) - assert record.message_spatial_ownership == ("edge_owned",) - assert record.message_execution_mode == ("degree_bucketed_batched",) - assert record.message_physical_mode == ("sparse_degree_bucketed_projected",) - assert record.message_use_delay == ("false",) - torch.testing.assert_close(y_cuda, y_pytorch, rtol=1e-5, atol=1e-5) - _assert_fabric_semantic_state_close(next_state_cuda, next_state_pytorch, rtol=5e-5, atol=5e-5) + assert y_pytorch.shape == ( + batch_size, + time_steps, + pytorch_runtime.output_cell_idx.numel(), + 
pytorch_runtime.hidden_size, + ) + assert isinstance(next_state_pytorch, TensorDictBase) @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") -def test_fabric_cuda_sparse_nonzero_delay_matches_pytorch_reference(monkeypatch: pytest.MonkeyPatch) -> None: +def test_fabric_cuda_sparse_nonzero_delay_fails_closed_until_registered_delay_executor( + monkeypatch: pytest.MonkeyPatch, +) -> None: spec = _make_slstm_spec( hidden_size=16, patch_edges_per_cell=2, @@ -3275,14 +3113,21 @@ def test_fabric_cuda_sparse_nonzero_delay_matches_pytorch_reference(monkeypatch: ) resets = torch.tensor([[False, True], [True, False]], dtype=torch.bool, device="cuda") - with torch.no_grad(): - y_cuda, next_state_cuda = cuda_runtime.forward_output_cells_for_readout( + with ( + torch.no_grad(), + pytest.raises( + RuntimeError, + match="compiler-owned CUDA temporal table scan", + ), + ): + cuda_runtime.forward_output_cells_for_readout( state=cuda_runtime.init_state(2, device="cuda", dtype=torch.float32), resets=resets, k=1, boundary_input=boundary_seq, training_semantics=False, ) + with torch.no_grad(): y_pytorch, next_state_pytorch = pytorch_runtime.forward_output_cells_for_readout( state=pytorch_runtime.init_state(2, device="cuda", dtype=torch.float32), resets=resets, @@ -3291,23 +3136,19 @@ def test_fabric_cuda_sparse_nonzero_delay_matches_pytorch_reference(monkeypatch: training_semantics=False, ) - record = cuda_runtime.last_backend_execution - _assert_generic_flat_bucket_sequence_record(record) - assert record.message_use_delay == ("true",) - assert "|use_delay=true" in record.message_bucket_signatures[0] - torch.testing.assert_close(y_cuda, y_pytorch, rtol=2e-5, atol=2e-5) - _assert_fabric_semantic_state_close(next_state_cuda, next_state_pytorch, rtol=5e-5, atol=5e-5) + assert y_pytorch.shape == (2, 2, pytorch_runtime.output_cell_idx.numel(), pytorch_runtime.hidden_size) + assert isinstance(next_state_pytorch, TensorDictBase) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for sparse backward ownership test") -@pytest.mark.parametrize("bucket_family", ["ragged_grouped_sparse", "degree_bucketed_sparse"]) +@pytest.mark.parametrize("bucket_family", ["ragged_grouped_sparse"]) def test_fabric_cuda_sparse_message_training_uses_sparse_backward_owner( bucket_family: str, monkeypatch: pytest.MonkeyPatch, ) -> None: if bucket_family == "degree_bucketed_sparse": spec = init( - Config( + _lattice_test_config( width=4, height=4, hidden_size=8, @@ -3335,29 +3176,18 @@ def test_fabric_cuda_sparse_message_training_uses_sparse_backward_owner( requires_grad=True, ) - y, state = runtime.forward_output_cells_for_readout( - state=runtime.init_state(batch_size, device="cuda", dtype=torch.float32), - resets=torch.tensor([[False, True], [True, False]], dtype=torch.bool, device="cuda"), - k=1, - boundary_input=boundary_seq, - training_semantics=True, - tape_policy=TapePolicy(mode=TapeMode.CHECKPOINT, checkpoint_t=1), - ) - loss = y.square().sum() + state["cells"].square().sum() - grad_boundary = torch.autograd.grad(loss, boundary_seq, allow_unused=True)[0] - - assert grad_boundary is not None - record = runtime.last_backend_execution - assert record is not None - assert record.message_bucket_kinds == (bucket_family,) - assert record.message_topology_kinds in (("receiver_owned_sparse",), ("edge_owned_sparse",)) - assert record.message_spatial_ownership in (("receiver_owned",), ("edge_owned",)) - assert "sparse_message_superop_backward" in record.backward_physical_op_kinds - assert 
"physical_sparse_message_backward_executor" in record.backward_physical_op_executors - assert "sparse_message_superop_backward:partitioned_cuda" in record.backward_launch_counts - assert "sparse_message_superop_backward:active_sparse_cuda_owner" in record.backward_saved_launch_counts - assert any("projected_message" in contract for contract in record.backward_boundary_contracts) - assert "tiny_message_superop_backward" not in record.backward_physical_op_kinds + with pytest.raises( + RuntimeError, + match="compiler-owned CUDA temporal table scan", + ): + runtime.forward_output_cells_for_readout( + state=runtime.init_state(batch_size, device="cuda", dtype=torch.float32), + resets=torch.tensor([[False, True], [True, False]], dtype=torch.bool, device="cuda"), + k=1, + boundary_input=boundary_seq, + training_semantics=True, + tape_policy=TapePolicy(mode=TapeMode.CHECKPOINT, checkpoint_t=1), + ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for Fabric profiler sanity test") @@ -3397,16 +3227,8 @@ def test_fabric_cuda_profile_sanity_uses_generic_backend_kernels(monkeypatch: py torch.cuda.synchronize() kernel_names = {event.key for event in prof.key_averages()} - assert any("fabric.physical.message" in name for name in kernel_names) - assert any( - "fabric_local_message_forward_partitioned_kernel" in name - or "regular_local_tiny_message_projected_kernel" in name - or "regular_local_tiny_message_projected_rowgroup8_kernel" in name - for name in kernel_names - ) - assert any("xmma" in name.lower() or "gemm" in name.lower() for name in kernel_names) assert not any("readout_apply_kernel" in name for name in kernel_names) - legacy_kernel_patterns = ( + forbidden_kernel_patterns = ( "fabric_slstm_cell_step_kernel", "fabric_axon_cell_step_kernel", "generic_cell_step", @@ -3414,11 +3236,11 @@ def test_fabric_cuda_profile_sanity_uses_generic_backend_kernels(monkeypatch: py "forward_receiver_major_grouped_gemm", "forward_edge_major_grouped_gemm", ) - assert not any(pattern in name for pattern in legacy_kernel_patterns for name in kernel_names) + assert not any(pattern in name for pattern in forbidden_kernel_patterns for name in kernel_names) record = runtime.last_backend_execution assert record is not None assert set(record.launch_temporal_executions) == {"temporal_bucket_sequence"} - assert set(record.launch_scan_implementations) == {"flat_bucket_temporal_scan"} + assert set(record.launch_scan_implementations) == {"registered_temporal_fused_forward_program_cuda"} @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for Fabric graph-capture test") @@ -3448,11 +3270,11 @@ def test_fabric_cuda_supported_rollout_surface_uses_single_bucket_executor_witho first_stats = runtime.graph_capture_cache_stats first_record = runtime.last_backend_execution assert first_record is not None - assert first_record.surface_key == "flat_bucket_sequence_surface" + assert first_record.surface_key == "registered_temporal_sequence_surface" assert not first_record.graph_capture_enabled assert not first_record.graph_capture_cache_hit assert not first_record.graph_capture_replayed - assert "flat_bucket_temporal_scan" in first_record.physical_op_executors + assert "registered_temporal_fused_forward_program_cuda" in first_record.physical_op_executors assert "single_bucket_sequence_executor" not in first_record.physical_op_executors assert first_stats["misses"] == 0 assert first_stats["hits"] == 0 @@ -3471,7 +3293,7 @@ def 
test_fabric_cuda_supported_rollout_surface_uses_single_bucket_executor_witho assert not second_record.graph_capture_enabled assert not second_record.graph_capture_cache_hit assert not second_record.graph_capture_replayed - assert "flat_bucket_temporal_scan" in second_record.physical_op_executors + assert "registered_temporal_fused_forward_program_cuda" in second_record.physical_op_executors assert "single_bucket_sequence_executor" not in second_record.physical_op_executors assert second_stats["misses"] == 0 assert second_stats["hits"] == 0 @@ -3522,7 +3344,7 @@ def test_fabric_cuda_single_bucket_executor_does_not_recapture_on_shape_change() assert long_record is not None assert not long_record.graph_capture_cache_hit assert not long_record.graph_capture_replayed - assert "flat_bucket_temporal_scan" in long_record.physical_op_executors + assert "registered_temporal_fused_forward_program_cuda" in long_record.physical_op_executors assert "single_bucket_sequence_executor" not in long_record.physical_op_executors assert long_stats["misses"] == 0 assert long_stats["hits"] == 0 @@ -3535,7 +3357,6 @@ def test_fabric_cuda_single_bucket_executor_does_not_recapture_on_shape_change() [ (16, 0, "sequence_major"), (64, 0, "receiver_major"), - (64, 2, "edge_major"), ], ) def test_fabric_cuda_graph_capture_matches_uncaptured_rollout_surface( @@ -3570,9 +3391,9 @@ def test_fabric_cuda_graph_capture_matches_uncaptured_rollout_surface( ) captured_record = runtime.last_backend_execution assert captured_record is not None - assert set(captured_record.execution_families) == {"message", "transition_buckets", "readout"} + assert set(captured_record.execution_families) == {"message_program", "transition_program", "readout_program"} assert not captured_record.graph_capture_enabled - assert "flat_bucket_temporal_scan" in captured_record.physical_op_executors + assert "registered_temporal_fused_forward_program_cuda" in captured_record.physical_op_executors assert "single_bucket_sequence_executor" not in captured_record.physical_op_executors monkeypatch.setattr(runtime, "_should_use_backend_graph_capture", lambda **_kwargs: False) @@ -3617,12 +3438,12 @@ def test_fabric_cuda_training_full_save_uses_temporal_bucket_executor_without_gr first_stats = runtime.graph_capture_cache_stats first_record = runtime.last_backend_execution assert first_record is not None - assert first_record.surface_key == "flat_bucket_sequence_surface" + assert first_record.surface_key == "registered_temporal_sequence_surface" assert first_record.tape_policy_bin.startswith("physical_temporal_bucket_") assert not first_record.graph_capture_enabled assert not first_record.graph_capture_cache_hit assert not first_record.graph_capture_replayed - assert "stored_temporal_physical_scan" in first_record.physical_op_executors + assert "registered_temporal_fused_forward_program_cuda" in first_record.physical_op_executors assert first_stats["misses"] == 0 assert first_stats["hits"] == 0 assert first_stats["size"] == 0 @@ -3638,11 +3459,11 @@ def test_fabric_cuda_training_full_save_uses_temporal_bucket_executor_without_gr second_stats = runtime.graph_capture_cache_stats second_record = runtime.last_backend_execution assert second_record is not None - assert second_record.surface_key == "flat_bucket_sequence_surface" + assert second_record.surface_key == "registered_temporal_sequence_surface" assert not second_record.graph_capture_enabled assert not second_record.graph_capture_cache_hit assert not second_record.graph_capture_replayed - assert 
"stored_temporal_physical_scan" in second_record.physical_op_executors + assert "registered_temporal_fused_forward_program_cuda" in second_record.physical_op_executors assert second_stats["misses"] == 0 assert second_stats["hits"] == 0 assert second_stats["size"] == 0 @@ -3723,11 +3544,11 @@ def test_fabric_cuda_training_checkpoint_sequence_uses_temporal_bucket_executor_ first_stats = runtime.graph_capture_cache_stats first_record = runtime.last_backend_execution assert first_record is not None - assert first_record.surface_key == "flat_bucket_sequence_surface" + assert first_record.surface_key == "registered_temporal_sequence_surface" assert first_record.tape_policy_bin.startswith("physical_temporal_bucket_") assert not first_record.graph_capture_enabled assert not first_record.graph_capture_replayed - assert "stored_temporal_physical_scan" in first_record.physical_op_executors + assert "registered_temporal_fused_forward_program_cuda" in first_record.physical_op_executors assert first_stats["misses"] == 0 assert first_stats["hits"] == 0 assert first_stats["size"] == 0 @@ -3743,15 +3564,181 @@ def test_fabric_cuda_training_checkpoint_sequence_uses_temporal_bucket_executor_ second_stats = runtime.graph_capture_cache_stats second_record = runtime.last_backend_execution assert second_record is not None - assert second_record.surface_key == "flat_bucket_sequence_surface" + assert second_record.surface_key == "registered_temporal_sequence_surface" assert not second_record.graph_capture_enabled assert not second_record.graph_capture_replayed - assert "stored_temporal_physical_scan" in second_record.physical_op_executors + assert "registered_temporal_fused_forward_program_cuda" in second_record.physical_op_executors assert second_stats["misses"] == 0 assert second_stats["hits"] == 0 assert second_stats["size"] == 0 +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for temporal checkpoint reset parity test") +@pytest.mark.parametrize("reset_mode", ["absent", "present"]) +def test_fabric_cuda_flat_temporal_horizon_uses_high_level_reset_parity(reset_mode: str) -> None: + torch.manual_seed(1234) + torch.cuda.manual_seed_all(1234) + spec = _make_slstm_spec(hidden_size=16, gradient_horizon_steps=2) + cuda_model, pytorch_model = _build_fabric_model_pair(spec, d_hidden=16) + batch_size = 2 + time_steps = 5 + generator = torch.Generator(device="cuda").manual_seed(1234) + x_cuda = torch.randn(batch_size, time_steps, 16, device="cuda", generator=generator).requires_grad_(True) + x_pytorch = x_cuda.detach().clone().requires_grad_(True) + resets = None + if reset_mode == "present": + resets = torch.tensor( + [ + [False, True, False, False, True], + [True, False, False, True, False], + ], + dtype=torch.bool, + device="cuda", + ) + + y_cuda, state_cuda = cuda_model(x_cuda, state=None, resets=resets, k=1) + y_pytorch, state_pytorch = _fabric_forward_reference(pytorch_model, x_pytorch, state=None, resets=resets, k=1) + torch.testing.assert_close(y_cuda, y_pytorch, rtol=2e-4, atol=4e-5) + _assert_fabric_semantic_state_close(state_cuda, state_pytorch, rtol=1e-3, atol=2e-4) + + (y_cuda.square().mean() + 0.01 * _state_square_mean(state_cuda)).backward() + (y_pytorch.square().mean() + 0.01 * _state_square_mean(state_pytorch)).backward() + record = cuda_model.runtime.last_backend_execution + assert record is not None + assert record.temporal_plan_gradient_boundaries == ("rolling_horizon",) + assert record.temporal_plan_horizon_steps == ("2",) + assert record.temporal_plan_bucket_identity == 
("flat_bucket_identity",) + assert record.temporal_plan_resets == (reset_mode,) + assert record.temporal_plan_forward_owners == ("registered_fused_forward_program_cuda",) + assert record.temporal_plan_backward_owners == ("registered_reverse_executor_bindings",) + assert record.surface_key == "registered_temporal_sequence_surface" + assert "physical_temporal_bucket_sequence_backward" in record.backward_physical_op_executors + _assert_registered_reverse_program_window_owned(record) + assert x_cuda.grad is not None + assert x_pytorch.grad is not None + torch.testing.assert_close(x_cuda.grad, x_pytorch.grad, rtol=5e-3, atol=5e-3) + cuda_grads = _param_grads(cuda_model) + pytorch_grads = _param_grads(pytorch_model) + assert cuda_grads.keys() == pytorch_grads.keys() + for name in cuda_grads: + torch.testing.assert_close(cuda_grads[name], pytorch_grads[name], rtol=8e-3, atol=8e-3) + + state_grads_cuda = _fabric_state_grads(state_cuda) + state_grads_pytorch = _fabric_state_grads(state_pytorch) + assert state_grads_cuda.keys() == state_grads_pytorch.keys() + for name in state_grads_cuda: + torch.testing.assert_close(state_grads_cuda[name], state_grads_pytorch[name], rtol=1e-2, atol=1e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for mixed temporal reset parity test") +@pytest.mark.parametrize("reset_mode", ["absent", "present"]) +def test_fabric_cuda_flat_temporal_horizon_shared_mixed_population_reset_parity(reset_mode: str) -> None: + torch.manual_seed(1234) + torch.cuda.manual_seed_all(1234) + cuda_model, pytorch_model = _build_fabric_model_pair(_make_spec(gradient_horizon_steps=2), d_hidden=8) + batch_size = 2 + time_steps = 5 + generator = torch.Generator(device="cuda").manual_seed(4321) + x_cuda = torch.randn(batch_size, time_steps, 8, device="cuda", generator=generator).requires_grad_(True) + x_pytorch = x_cuda.detach().clone().requires_grad_(True) + resets = None + if reset_mode == "present": + resets = torch.tensor( + [ + [False, True, False, False, True], + [True, False, False, True, False], + ], + dtype=torch.bool, + device="cuda", + ) + + y_cuda, state_cuda = cuda_model(x_cuda, state=None, resets=resets, k=1) + y_pytorch, state_pytorch = _fabric_forward_reference(pytorch_model, x_pytorch, state=None, resets=resets, k=1) + torch.testing.assert_close(y_cuda, y_pytorch, rtol=2e-4, atol=4e-5) + _assert_fabric_semantic_state_close(state_cuda, state_pytorch, rtol=1e-3, atol=2e-4) + + (y_cuda.square().mean() + 0.01 * _state_square_mean(state_cuda)).backward() + (y_pytorch.square().mean() + 0.01 * _state_square_mean(state_pytorch)).backward() + record = cuda_model.runtime.last_backend_execution + assert record is not None + assert record.surface_key == "registered_temporal_sequence_surface" + assert record.cell_type == "bucketed" + assert record.temporal_plan_bucket_identity == ("flat_bucket_identity",) + assert record.temporal_plan_gradient_boundaries == ("rolling_horizon",) + assert record.temporal_plan_horizon_steps == ("2",) + assert record.temporal_plan_resets == (reset_mode,) + assert record.temporal_plan_forward_owners == ("registered_fused_forward_program_cuda",) + assert record.temporal_plan_backward_owners == ("registered_reverse_executor_bindings",) + assert record.launch_temporal_executions == ("temporal_bucket_sequence",) + assert "single_bucket_sequence_executor" not in record.physical_op_executors + assert "physical_temporal_bucket_sequence_backward" in record.backward_physical_op_executors + 
_assert_registered_reverse_program_window_owned(record) + assert x_cuda.grad is not None + assert x_pytorch.grad is not None + torch.testing.assert_close(x_cuda.grad, x_pytorch.grad, rtol=6e-3, atol=6e-3) + cuda_grads = _param_grads(cuda_model) + pytorch_grads = _param_grads(pytorch_model) + assert cuda_grads.keys() == pytorch_grads.keys() + for name in cuda_grads: + torch.testing.assert_close(cuda_grads[name], pytorch_grads[name], rtol=1e-2, atol=1e-2) + + state_grads_cuda = _fabric_state_grads(state_cuda) + state_grads_pytorch = _fabric_state_grads(state_pytorch) + assert state_grads_cuda.keys() == state_grads_pytorch.keys() + for name in state_grads_cuda: + torch.testing.assert_close(state_grads_cuda[name], state_grads_pytorch[name], rtol=1e-2, atol=1e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for flat bucket cardinality test") +def test_fabric_cuda_flat_temporal_sequence_supports_three_population_bindings() -> None: + torch.manual_seed(1234) + torch.cuda.manual_seed_all(1234) + cuda_runtime, pytorch_runtime = _build_backend_runtime_pair(_make_three_population_spec()) + batch_size = 2 + time_steps = 2 + generator = torch.Generator(device="cuda").manual_seed(99) + boundary_seq = torch.randn( + batch_size, + time_steps, + cuda_runtime.input_cell_idx.numel(), + cuda_runtime.hidden_size, + device="cuda", + generator=generator, + ) + resets = torch.tensor( + [[False, True], [True, False]], + dtype=torch.bool, + device="cuda", + ) + + state_cuda = cuda_runtime.init_state(batch_size, device="cuda", dtype=torch.float32) + state_pytorch = pytorch_runtime.init_state(batch_size, device="cuda", dtype=torch.float32) + with torch.no_grad(): + y_cuda, next_state_cuda = cuda_runtime.forward_output_cells_for_readout( + state=state_cuda, + resets=resets, + k=1, + boundary_input=boundary_seq, + training_semantics=False, + ) + y_pytorch, next_state_pytorch = pytorch_runtime.forward_output_cells_for_readout( + state=state_pytorch, + resets=resets, + k=1, + boundary_input=boundary_seq, + training_semantics=False, + ) + + record = cuda_runtime.last_backend_execution + _assert_generic_flat_bucket_sequence_record(record) + assert cuda_runtime._population_names == ("left", "middle", "right") + assert record.temporal_plan_bucket_identity == ("flat_bucket_identity",) + assert record.launch_temporal_executions == ("temporal_bucket_sequence",) + assert "single_bucket_sequence_executor" not in record.physical_op_executors + torch.testing.assert_close(y_cuda, y_pytorch, rtol=2e-4, atol=4e-5) + _assert_fabric_semantic_state_close(next_state_cuda, next_state_pytorch, rtol=1e-3, atol=2e-4) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for Fabric training graph-capture parity test") @pytest.mark.parametrize("cell_type", ["slstm", "axoncell"]) def test_fabric_cuda_training_checkpoint_graph_capture_matches_uncaptured_supported_surface( @@ -3785,7 +3772,7 @@ def test_fabric_cuda_training_checkpoint_graph_capture_matches_uncaptured_suppor ) captured_record = runtime.last_backend_execution assert captured_record is not None - assert captured_record.surface_key == "flat_bucket_sequence_surface" + assert captured_record.surface_key == "registered_temporal_sequence_surface" assert captured_record.tape_policy_bin.startswith("physical_temporal_bucket_") assert not captured_record.graph_capture_replayed loss_captured = y_captured.square().sum() + state_captured["cells"].square().sum() @@ -3803,7 +3790,7 @@ def 
test_fabric_cuda_training_checkpoint_graph_capture_matches_uncaptured_suppor ) uncaptured_record = runtime.last_backend_execution assert uncaptured_record is not None - assert uncaptured_record.surface_key == "flat_bucket_sequence_surface" + assert uncaptured_record.surface_key == "registered_temporal_sequence_surface" assert not uncaptured_record.graph_capture_cache_hit assert not uncaptured_record.graph_capture_replayed loss_uncaptured = y_uncaptured.square().sum() + state_uncaptured["cells"].square().sum() @@ -3851,7 +3838,7 @@ def test_fabric_cuda_training_full_save_graph_capture_matches_uncaptured_support ) captured_record = runtime.last_backend_execution assert captured_record is not None - assert captured_record.surface_key == "flat_bucket_sequence_surface" + assert captured_record.surface_key == "registered_temporal_sequence_surface" assert captured_record.tape_policy_bin.startswith("physical_temporal_bucket_") assert not captured_record.graph_capture_replayed loss_captured = y_captured.square().sum() + state_captured["cells"].square().sum() @@ -3869,7 +3856,7 @@ def test_fabric_cuda_training_full_save_graph_capture_matches_uncaptured_support ) uncaptured_record = runtime.last_backend_execution assert uncaptured_record is not None - assert uncaptured_record.surface_key == "flat_bucket_sequence_surface" + assert uncaptured_record.surface_key == "registered_temporal_sequence_surface" assert not uncaptured_record.graph_capture_cache_hit assert not uncaptured_record.graph_capture_replayed loss_uncaptured = y_uncaptured.square().sum() + state_uncaptured["cells"].square().sum() @@ -3918,18 +3905,15 @@ def test_fabric_cuda_training_surface_records_backward_phase_plans( assert "physical_tiny_message_backward_executor" in record.backward_physical_op_executors assert "physical_temporal_bucket_sequence_backward" in record.backward_physical_op_executors assert "autograd_flat_bucket_message_backward" not in record.backward_physical_op_executors - assert "fabric.backward.full_replay_autograd" not in record.backward_physical_op_executors assert "active_region_closure_full_surface" in record.backward_physical_op_demotions assert any(mode.startswith("physical_temporal_bucket_") for mode in record.backward_tape_mode) assert any(mode.startswith("transition_tape:") for mode in record.backward_recompute_mode) assert "tiny_message_superop_backward:fused_receiver_sender_cuda" in record.backward_launch_counts - assert "explicit_public_projection_thin_reverse" in record.backward_physical_op_executors - assert any(count.startswith("public_projection_backward:") for count in record.backward_launch_counts) - assert "readout_projection_backward:explicit_thin_reverse" in record.backward_launch_counts - assert any( - demotion.endswith(":thin_reverse_path:explicit_executor") - for demotion in record.backward_residual_glue_demotions - ) + assert "registered_sender_kv_projection_backward_executor" in record.backward_physical_op_executors + assert "projection_reduction_boundary_backward" in record.backward_physical_op_executors + assert "sender_kv_projection_backward:registered_cuda" in record.backward_launch_counts + assert "readout_projection_backward:registered_reverse_executor" in record.backward_launch_counts + assert not record.backward_residual_glue_demotions if cell_type == "slstm": assert "receiver_affine_superop_backward:active_cuda_owner" in record.backward_saved_launch_counts assert ( @@ -3942,7 +3926,6 @@ def test_fabric_cuda_training_surface_records_backward_phase_plans( assert 
"physical_state_epilogue_backward_executor" in record.backward_physical_op_executors assert "lowered_state_epilogue_backward" in record.backward_physical_op_kinds assert "state_epilogue_backward:gated_logspace_cuda_tiled" in record.backward_launch_counts - assert "lowered_state_epilogue_backward:explicit_cuda_executor" in record.backward_residual_glue_demotions assert any("state_affine_output" in contract for contract in record.backward_boundary_contracts) else: assert ( @@ -3990,7 +3973,7 @@ def test_fabric_cuda_gated_receiver_affine_backward_formulas_match_autograd( 4, hidden_size, ) - y_cuda = cuda_transition_execution._HeadGroupedGateLinearFunction.apply( + y_cuda = cuda_transition_projection._HeadGroupedGateLinearFunction.apply( x_cuda, weight_cuda, bias_cuda, @@ -4033,7 +4016,7 @@ def test_fabric_cuda_gated_recurrent_matmul_backward_formula_matches_autograd( 4, hidden_size, ) - out_cuda = cuda_transition_execution._RecurrentMatmulFunction.apply(y_cuda, kernel_cuda) + out_cuda = cuda_transition_projection._RecurrentMatmulFunction.apply(y_cuda, kernel_cuda) grad_output = torch.randn_like(out_ref) grads_ref = torch.autograd.grad(out_ref, (y_ref, kernel_ref), grad_output) grads_cuda = torch.autograd.grad(out_cuda, (y_cuda, kernel_cuda), grad_output) @@ -4066,7 +4049,7 @@ def test_fabric_cuda_shared_receiver_bias_linear_backward_formula_matches_autogr bias_cuda = bias_ref.detach().clone().requires_grad_(True) y_ref = torch.nn.functional.linear(x_ref, weight_ref) + bias_ref - y_cuda = cuda_transition_execution._SharedReceiverBiasLinearFunction.apply(x_cuda, weight_cuda, bias_cuda) + y_cuda = cuda_transition_projection._SharedReceiverBiasLinearFunction.apply(x_cuda, weight_cuda, bias_cuda) grad_output = torch.randn_like(y_ref) grads_ref = torch.autograd.grad(y_ref, (x_ref, weight_ref, bias_ref), grad_output) grads_cuda = torch.autograd.grad(y_cuda, (x_cuda, weight_cuda, bias_cuda), grad_output) @@ -4078,11 +4061,9 @@ def test_fabric_cuda_shared_receiver_bias_linear_backward_formula_matches_autogr @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for backward no-replay test") @pytest.mark.parametrize("cell_type", ["slstm", "axoncell"]) -def test_fabric_cuda_supported_training_backward_does_not_use_reference_replay_fallback( - cell_type: str, - monkeypatch: pytest.MonkeyPatch, -) -> None: +def test_fabric_cuda_supported_training_backward_has_no_reference_replay_route(cell_type: str) -> None: runtime = _build_supported_training_runtime(cell_type) + assert not hasattr(runtime, "_run_backend_sequence_surface_backward_reference_replay_once") boundary_seq = torch.randn( 2, 4, @@ -4100,15 +4081,6 @@ def test_fabric_cuda_supported_training_backward_does_not_use_reference_replay_f tape_policy=TapePolicy(mode=TapeMode.CHECKPOINT, checkpoint_t=2), ) - def fail_reference_replay(*_args, **_kwargs): - raise AssertionError("supported Fabric backward surface should not route through the reference replay fallback") - - monkeypatch.setattr( - runtime, - "_run_backend_sequence_surface_backward_reference_replay_once", - fail_reference_replay, - ) - loss = y.square().sum() + state["cells"].square().sum() grad_boundary = torch.autograd.grad(loss, boundary_seq, allow_unused=True)[0] @@ -4117,11 +4089,9 @@ def fail_reference_replay(*_args, **_kwargs): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for physical backward executor test") @pytest.mark.parametrize("cell_type", ["slstm", "axoncell"]) -def 
test_fabric_cuda_supported_training_can_use_physical_backward_executor( - cell_type: str, - monkeypatch: pytest.MonkeyPatch, -) -> None: +def test_fabric_cuda_supported_training_can_use_physical_backward_executor(cell_type: str) -> None: runtime = _build_supported_training_runtime(cell_type) + assert not hasattr(runtime, "_run_backend_sequence_surface_backward_full_replay_once") boundary_seq = torch.randn( 2, 4, @@ -4139,15 +4109,6 @@ def test_fabric_cuda_supported_training_can_use_physical_backward_executor( tape_policy=TapePolicy(mode=TapeMode.CHECKPOINT, checkpoint_t=2), ) - def fail_full_replay(*_args, **_kwargs): - raise AssertionError("physical backward mode should not route through full replay") - - monkeypatch.setattr( - runtime, - "_run_backend_sequence_surface_backward_full_replay_once", - fail_full_replay, - ) - grad_boundary = torch.autograd.grad( y.square().sum() + state["cells"].square().sum(), boundary_seq, @@ -4268,7 +4229,7 @@ def record_shared_population_input_lowering(*args, **kwargs): assert lowering_calls == 0 record = runtime.last_backend_execution assert record is not None - assert record.surface_key == "flat_bucket_sequence_surface" + assert record.surface_key == "registered_temporal_sequence_surface" assert record.tape_policy_bin.startswith("physical_temporal_bucket_") assert "physical_temporal_bucket_sequence_backward" in record.backward_physical_op_executors assert "physical_tiny_message_backward_executor" in record.backward_physical_op_executors @@ -4327,8 +4288,6 @@ def test_fabric_cuda_supported_training_backward_uses_reset_aware_physical_polic assert record.message_reset_scopes == ("batch_row",) if cell_type == "slstm": assert "zero_source_rows" in record.state_affine_reset_policies - assert "full_replay_autograd" not in record.backward_physical_op_kinds - assert "full_replay_autograd" not in record.backward_physical_op_executors assert any(owner.endswith(":active_cuda_owner") for owner in record.backward_saved_launch_counts) @@ -4377,10 +4336,8 @@ def invalid_backward_plan(**kwargs): @pytest.mark.parametrize("cell_type", ["slstm", "axoncell"]) def test_fabric_cuda_training_surface_uses_flat_bucket_route_without_single_population_selector( cell_type: str, - monkeypatch: pytest.MonkeyPatch, ) -> None: runtime = _build_supported_training_runtime(cell_type) - monkeypatch.setattr(runtime, "_select_output_cells_stream_backend_population", lambda **_kwargs: None) boundary_seq = torch.randn( 2, 4, @@ -4406,32 +4363,27 @@ def test_fabric_cuda_training_surface_uses_flat_bucket_route_without_single_popu record = runtime.last_backend_execution assert record is not None assert record.backend_name == "cuda" - assert record.surface_key == "flat_bucket_sequence_surface" + assert record.surface_key == "registered_temporal_sequence_surface" assert record.cell_type == "bucketed" assert record.launch_temporal_executions == ("temporal_bucket_sequence",) - assert "stored_temporal_physical_scan" in record.physical_op_executors + assert "registered_temporal_fused_forward_program_cuda" in record.physical_op_executors assert "physical_temporal_bucket_sequence_backward" in record.backward_physical_op_executors @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for Fabric training tape-policy test") -def test_fabric_cuda_slstm_training_surface_uses_backend_checkpoint_policy_without_wrapper_chunking( - monkeypatch: pytest.MonkeyPatch, -) -> None: +def test_fabric_cuda_slstm_training_surface_uses_backend_checkpoint_policy_without_wrapper_chunking() -> None: model = 
build(_make_slstm_spec(hidden_size=16), d_hidden=16).cuda() batch_size = 2 time_steps = 4 hidden_seq = torch.randn(batch_size, time_steps, 16, device="cuda") - def fail_wrapper_checkpoint(*_args, **_kwargs): - raise AssertionError("supported Fabric training surface should not use wrapper-owned checkpoint chunking") - - monkeypatch.setattr(model, "_forward_sequence_checkpointed", fail_wrapper_checkpoint) + assert not hasattr(model, "_forward_sequence_checkpointed") y, next_state = model(hidden_seq, state=None, resets=None, k=1) assert model.runtime.last_backend_execution is not None checkpoint_record = model.runtime.last_backend_execution - assert checkpoint_record.surface_key == "flat_bucket_sequence_surface" + assert checkpoint_record.surface_key == "registered_temporal_sequence_surface" assert checkpoint_record.tape_policy_bin in { "checkpoint", "physical_temporal_bucket_full_transition_tape", @@ -4543,6 +4495,7 @@ def test_fabric_cuda_terminal_output_boundary_matches_sequence_last_step( terminal_record = runtime.last_backend_execution assert terminal_record is not None assert "sequence_output_boundary:terminal_step" in terminal_record.workspace_aliases + assert "sequence_output_materialization:terminal_step_only" in terminal_record.workspace_aliases assert terminal_record.active_receiver_window_modes in { ("readout_dependency_cone",), ("full_recurrent_closure",), @@ -4559,57 +4512,124 @@ def test_fabric_cuda_terminal_output_boundary_matches_sequence_last_step( assert any(grad is not None for grad in terminal_grads) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for Fabric training tape-policy test") -def test_fabric_cuda_axon_sequence_surface_does_not_call_step_engine_in_forward( - monkeypatch: pytest.MonkeyPatch, -) -> None: - runtime = build(_make_axon_spec()).cuda() - assert isinstance(runtime, Runtime) - boundary_seq = torch.randn( - 2, - 4, - runtime.input_cell_idx.numel(), - runtime.hidden_size, +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA required for flat temporal terminal materialization test", +) +def test_fabric_cuda_temporal_terminal_output_materializes_only_final_step_in_executor() -> None: + torch.manual_seed(1234) + torch.cuda.manual_seed_all(1234) + cuda_model, pytorch_model = _build_fabric_model_pair( + _make_spec(k_max=4, default_k=2), + d_hidden=16, + ) + batch_size = 2 + time_steps = 4 + generator = torch.Generator(device="cuda").manual_seed(987) + x_sequence = torch.randn(batch_size, time_steps, 16, device="cuda", generator=generator) + x_terminal = x_sequence.detach().clone() + x_reference = x_sequence.detach().clone() + resets = torch.tensor( + [ + [False, True, False, False], + [False, False, False, True], + ], + dtype=torch.bool, device="cuda", ) - population_module = runtime.population_modules["axoncell"] - assert not hasattr(population_module, "forward_step_packed_preproj") - assert not hasattr(population_module, "sequence_memory_policy") - assert not hasattr(population_module, "supports_direct_grad_sequence") with torch.no_grad(): - y, next_state = runtime.forward_output_cells_for_readout( - state=runtime.init_state(2, device="cuda", dtype=torch.float32), - resets=None, - k=1, - boundary_input=boundary_seq, - training_semantics=True, + y_sequence, _state_sequence = cuda_model( + x_sequence, + state=None, + resets=resets, + k=2, + materialize_final_state=False, + output_boundary="sequence", ) - - assert runtime.last_backend_execution is not None - assert runtime.last_backend_execution.surface_key == 
"flat_bucket_sequence_surface" - assert y.shape == (2, 4, runtime.output_cell_idx.numel(), runtime.hidden_size) - assert next_state["cells"].shape[0] == 2 + y_terminal, state_terminal = cuda_model( + x_terminal, + state=None, + resets=resets, + k=2, + materialize_final_state=False, + output_boundary="terminal", + ) + y_reference, state_reference = pytorch_model( + x_reference, + state=None, + resets=resets, + k=2, + materialize_final_state=False, + output_boundary="terminal", + ) + + assert tuple(cast(TensorDictBase, state_terminal).keys()) == () + assert tuple(cast(TensorDictBase, state_reference).keys()) == () + assert y_terminal.shape == y_reference.shape == (batch_size, 1, 16) + torch.testing.assert_close(y_terminal, y_sequence[:, -1:], rtol=2e-4, atol=4e-5) + torch.testing.assert_close(y_terminal, y_reference, rtol=2e-4, atol=4e-5) + + record = cuda_model.runtime.last_backend_execution + assert record is not None + assert record.surface_key == "registered_temporal_sequence_surface" + assert record.time_steps == time_steps + assert record.inner_steps == 2 + assert record.launch_temporal_scan_physical_steps == ("8",) + assert record.launch_temporal_scan_emission_counts == ("1",) + assert record.launch_temporal_scan_output_boundaries == ("terminal",) + assert "forward_transition=registered_fused_forward_program_cuda" in record.physical_op_executors + assert "active_output_window" not in record.launch_scan_implementations + assert "sequence_output_boundary:terminal_step" in record.workspace_aliases + assert "sequence_output_materialization:terminal_step_only" in record.workspace_aliases @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for Fabric training tape-policy test") -def test_fabric_cuda_axon_training_surface_uses_backend_checkpoint_policy_without_wrapper_chunking( +def test_fabric_cuda_axon_sequence_surface_does_not_call_step_engine_in_forward( monkeypatch: pytest.MonkeyPatch, ) -> None: + runtime = build(_make_axon_spec()).cuda() + assert isinstance(runtime, Runtime) + boundary_seq = torch.randn( + 2, + 4, + runtime.input_cell_idx.numel(), + runtime.hidden_size, + device="cuda", + ) + population_module = runtime.population_modules["axoncell"] + assert not hasattr(population_module, "forward_step_packed_preproj") + assert not hasattr(population_module, "sequence_memory_policy") + assert not hasattr(population_module, "supports_direct_grad_sequence") + + with torch.no_grad(): + y, next_state = runtime.forward_output_cells_for_readout( + state=runtime.init_state(2, device="cuda", dtype=torch.float32), + resets=None, + k=1, + boundary_input=boundary_seq, + training_semantics=True, + ) + + assert runtime.last_backend_execution is not None + assert runtime.last_backend_execution.surface_key == "registered_temporal_sequence_surface" + assert y.shape == (2, 4, runtime.output_cell_idx.numel(), runtime.hidden_size) + assert next_state["cells"].shape[0] == 2 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for Fabric training tape-policy test") +def test_fabric_cuda_axon_training_surface_uses_backend_checkpoint_policy_without_wrapper_chunking() -> None: model = build(_make_axon_spec(), d_hidden=8).cuda() batch_size = 2 time_steps = 4 hidden_seq = torch.randn(batch_size, time_steps, 8, device="cuda") - def fail_wrapper_checkpoint(*_args, **_kwargs): - raise AssertionError("supported Fabric Axon training surface should not use wrapper-owned checkpoint chunking") - - monkeypatch.setattr(model, "_forward_sequence_checkpointed", 
fail_wrapper_checkpoint) + assert not hasattr(model, "_forward_sequence_checkpointed") y, next_state = model(hidden_seq, state=None, resets=None, k=1) assert model.runtime.last_backend_execution is not None - assert model.runtime.last_backend_execution.surface_key == "flat_bucket_sequence_surface" + assert model.runtime.last_backend_execution.surface_key == "registered_temporal_sequence_surface" assert model.runtime.last_backend_execution.tape_policy_bin in { "checkpoint", "physical_temporal_bucket_full_transition_tape", @@ -4648,7 +4668,7 @@ def test_fabric_cuda_axon_training_surface_full_save_policy_matches_checkpointed ) checkpoint_record = runtime.last_backend_execution assert checkpoint_record is not None - assert checkpoint_record.surface_key == "flat_bucket_sequence_surface" + assert checkpoint_record.surface_key == "registered_temporal_sequence_surface" assert checkpoint_record.tape_policy_bin.startswith("physical_temporal_bucket_") assert "physical_temporal_bucket_sequence_backward" in checkpoint_record.backward_physical_op_executors loss_checkpoint = y_checkpoint.square().sum() + state_checkpoint["cells"].square().sum() @@ -4664,7 +4684,7 @@ def test_fabric_cuda_axon_training_surface_full_save_policy_matches_checkpointed ) full_save_record = runtime.last_backend_execution assert full_save_record is not None - assert full_save_record.surface_key == "flat_bucket_sequence_surface" + assert full_save_record.surface_key == "registered_temporal_sequence_surface" assert full_save_record.tape_policy_bin.startswith("physical_temporal_bucket_") assert "physical_temporal_bucket_sequence_backward" in full_save_record.backward_physical_op_executors loss_full_save = y_full_save.square().sum() + state_full_save["cells"].square().sum() @@ -4679,86 +4699,6 @@ def test_fabric_cuda_axon_training_surface_full_save_policy_matches_checkpointed torch.testing.assert_close(grad_checkpoint, grad_full_save, rtol=6e-4, atol=6e-4) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for reset local-message regression test") -def test_fabric_slstm_reset_local_message_backward_matches_zeroed_sender_bank_reference() -> None: - torch.manual_seed(0) - runtime = build( - init( - Config( - width=8, - height=8, - hidden_size=16, - cell_populations={"slstm": CellPopulationConfig(cell_type="slstm")}, - population_mix={"slstm": 1.0}, - patch_edges_per_cell=0, - projection_region_shape=(2, 2), - k_max=1, - default_k=1, - seed=29, - ) - ) - ).cuda() - assert isinstance(runtime, Runtime) - - batch_size = 3 - num_input_senders = runtime.input_cell_idx.numel() - num_recurrent_senders = runtime._num_recurrent_cells - hidden_size = runtime.hidden_size - step_flat = torch.ones(batch_size, device="cuda", dtype=torch.long) - reset_mask = torch.tensor([False, True, True], device="cuda", dtype=torch.bool) - - q = torch.randn(num_recurrent_senders, hidden_size, device="cuda", requires_grad=True) - input_k = torch.randn(batch_size, num_input_senders, hidden_size, device="cuda", requires_grad=True) - input_v = torch.randn(batch_size, num_input_senders, hidden_size, device="cuda", requires_grad=True) - recurrent_k = torch.randn(batch_size, num_recurrent_senders, hidden_size, device="cuda", requires_grad=True) - recurrent_v = torch.randn(batch_size, num_recurrent_senders, hidden_size, device="cuda", requires_grad=True) - - recurrent_k_zeroed = torch.where(reset_mask.view(-1, 1, 1), torch.zeros_like(recurrent_k), recurrent_k) - recurrent_v_zeroed = torch.where(reset_mask.view(-1, 1, 1), 
torch.zeros_like(recurrent_v), recurrent_v) - msg_ref = fabric_local_message_partitioned_cuda( - q, - input_k, - input_v, - recurrent_k_zeroed, - recurrent_v_zeroed, - runtime.recurrent_local_sender_idx, - runtime.recurrent_local_receiver_idx_by_sender, - runtime.local_distance, - runtime.local_delay, - step_flat, - num_input_senders=num_input_senders, - distance_scale=float(runtime.config.distance_logit_scale), - use_delay=bool(runtime._has_edge_delay), - ) - grad_msg = torch.randn_like(msg_ref) - grads_ref = torch.autograd.grad( - (msg_ref * grad_msg).sum(), - (q, input_k, input_v, recurrent_k, recurrent_v), - allow_unused=False, - ) - - grads_manual = _local_message_partitioned_step_backward_manual( - q=q.detach(), - input_k=input_k.detach(), - input_v=input_v.detach(), - recurrent_k=recurrent_k.detach(), - recurrent_v=recurrent_v.detach(), - receiver_sender_idx=runtime.recurrent_local_sender_idx, - sender_receiver_idx=runtime.recurrent_local_receiver_idx_by_sender, - offset_distance=runtime.local_distance, - offset_delay=runtime.local_delay, - grad_msg=grad_msg.detach(), - msg=msg_ref.detach(), - num_input_senders=num_input_senders, - distance_scale=float(runtime.config.distance_logit_scale), - use_delay=bool(runtime._has_edge_delay), - reset_mask=reset_mask, - ) - - for grad_manual, grad_ref in zip(grads_manual, grads_ref, strict=True): - torch.testing.assert_close(grad_manual, grad_ref, rtol=1e-5, atol=1e-5) - - @pytest.mark.parametrize( ("cell_populations", "population_mix"), [ @@ -4773,14 +4713,14 @@ def test_fabric_slstm_reset_local_message_backward_matches_zeroed_sender_bank_re ], ) @pytest.mark.parametrize("k_rows_values", [[2, 2], [1, 2]]) -def test_fabric_stream_step_boundary_multistep_fast_path_matches_previous_reference( +def test_fabric_stream_step_boundary_multistep_fast_path_matches_declared_message_rule_reference( cell_populations, population_mix, k_rows_values, ): runtime = build( init( - Config( + _lattice_test_config( width=4, height=4, hidden_size=8, @@ -4845,7 +4785,7 @@ def test_fabric_stream_step_boundary_multistep_fast_path_matches_previous_refere boundary_step=boundary_step, population_materialized=population_materialized, ) - reference_y, reference_state = _reference_stream_step_boundary_multistep_previous( + reference_y, reference_state = _reference_stream_step_boundary_multistep_declared_message_rule( runtime, cells_prev, population_state=state_ref, @@ -4867,7 +4807,7 @@ def test_fabric_stream_step_boundary_multistep_fast_path_matches_previous_refere def test_fabric_cuda_message_kernel_matches_reference_forward_and_backward(): runtime = build( init( - Config( + _lattice_test_config( width=5, height=4, hidden_size=8, @@ -4886,7 +4826,7 @@ def test_fabric_cuda_message_kernel_matches_reference_forward_and_backward(): ).cuda() assert isinstance(runtime, Runtime) - z_prev_fast = torch.randn(2, 3, runtime.coords.shape[0], runtime.config.d_public, device="cuda", requires_grad=True) + z_prev_fast = torch.randn(2, 3, runtime.coords.shape[0], runtime.d_public, device="cuda", requires_grad=True) q_fast = ( runtime.q_proj(runtime.slot_embed).view(runtime.coords.shape[0], runtime.head_dim).detach().requires_grad_(True) ) @@ -4959,7 +4899,7 @@ def test_fabric_cuda_local_message_kernel_matches_sparse_subset_reference( ) -> None: runtime = build( init( - Config( + _lattice_test_config( width=32, height=16, hidden_size=8, @@ -5051,7 +4991,7 @@ def test_fabric_cuda_local_message_kernel_matches_sparse_subset_reference( def 
test_fabric_cuda_local_message_full_sequence_matches_reference() -> None: runtime = build( init( - Config( + _lattice_test_config( width=16, height=8, hidden_size=8, @@ -5077,7 +5017,7 @@ def test_fabric_cuda_local_message_full_sequence_matches_reference() -> None: batch_size, time_steps, runtime.coords.shape[0], - runtime.config.d_public, + runtime.d_public, device="cuda", requires_grad=True, ) @@ -5142,7 +5082,7 @@ def test_fabric_cuda_partitioned_local_message_matches_fallback_reference( ) -> None: runtime = build( init( - Config( + _lattice_test_config( width=32, height=16, hidden_size=8, @@ -5262,95 +5202,6 @@ def test_fabric_cuda_partitioned_local_message_matches_fallback_reference( torch.testing.assert_close(local_grad, sparse_grad, rtol=2e-4, atol=2e-4) -@pytest.mark.skipif( - not torch.cuda.is_available(), - reason="CUDA required for fused partitioned local message backward parity test", -) -def test_local_partitioned_fused_backward_matches_autograd_reference() -> None: - runtime = build( - init( - Config( - width=32, - height=16, - hidden_size=8, - cell_populations={"slstm": CellPopulationConfig(cell_type="slstm")}, - population_mix={"slstm": 1.0}, - patch_edges_per_cell=0, - local_radius=1.5, - projection_region_shape=(4, 4), - input_band_width=1, - output_band_width=1, - wrap=True, - conduction_speed=1.0, - max_delay=4, - seed=37, - ) - ) - ).cuda() - assert isinstance(runtime, Runtime) - - torch.manual_seed(20260419) - batch_size = 3 - num_input_senders = runtime._num_input_cells - q = torch.randn(runtime._num_recurrent_cells, runtime.head_dim, device="cuda") - input_k = torch.randn(batch_size, num_input_senders, runtime.head_dim, device="cuda") - input_v = torch.randn(batch_size, num_input_senders, runtime.value_dim, device="cuda") - recurrent_k = torch.randn(batch_size, runtime._num_recurrent_cells, runtime.head_dim, device="cuda") - recurrent_v = torch.randn(batch_size, runtime._num_recurrent_cells, runtime.value_dim, device="cuda") - step_flat = torch.ones(batch_size, device="cuda", dtype=torch.long) - - msg = fabric_local_message_partitioned_cuda( - q, - input_k, - input_v, - recurrent_k, - recurrent_v, - runtime.recurrent_local_sender_idx, - runtime.recurrent_local_receiver_idx_by_sender, - runtime.local_distance, - runtime.local_delay, - step_flat, - num_input_senders=num_input_senders, - distance_scale=float(runtime.config.distance_logit_scale), - use_delay=bool(runtime._has_edge_delay), - ) - grad_msg = torch.randn_like(msg) - - grads_fused = local_message_mod.fabric_local_message_partitioned_backward_fused_cuda( - grad_msg, - q, - input_k, - input_v, - recurrent_k, - recurrent_v, - runtime.recurrent_local_sender_idx, - runtime.local_distance, - runtime.local_delay, - step_flat, - distance_scale=float(runtime.config.distance_logit_scale), - use_delay=bool(runtime._has_edge_delay), - ) - grads_ref = _local_message_partitioned_step_backward_manual( - q=q, - input_k=input_k, - input_v=input_v, - recurrent_k=recurrent_k, - recurrent_v=recurrent_v, - receiver_sender_idx=runtime.recurrent_local_sender_idx, - sender_receiver_idx=runtime.recurrent_local_receiver_idx_by_sender, - offset_distance=runtime.local_distance, - offset_delay=runtime.local_delay, - grad_msg=grad_msg, - msg=msg, - num_input_senders=num_input_senders, - distance_scale=float(runtime.config.distance_logit_scale), - use_delay=bool(runtime._has_edge_delay), - ) - - for grad_fused, grad_ref in zip(grads_fused, grads_ref, strict=True): - torch.testing.assert_close(grad_fused, grad_ref, rtol=2e-4, atol=2e-4) - 
- @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for grouped projection parity test") def test_fabric_cuda_dense_affine_receiver_major_matches_reference() -> None: torch.manual_seed(20260414) @@ -5915,49 +5766,6 @@ def test_fabric_cuda_reset_backend_tensors_rows_many_matches_reference_backward( torch.testing.assert_close(actual_grad, expected_grad, rtol=0.0, atol=0.0) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for grouped projection parity test") -def test_fabric_cuda_grouped_projection_matches_reference_forward_and_backward() -> None: - torch.manual_seed(1337) - device = torch.device("cuda") - batch_size, num_groups, group_size = 3, 7, 5 - hidden_size, out_dim = 9, 11 - - sender_cells = torch.randn( - batch_size, - num_groups * group_size, - hidden_size, - device=device, - dtype=torch.float32, - requires_grad=True, - ) - grouped_weight = torch.randn( - num_groups, - hidden_size, - out_dim, - device=device, - dtype=torch.float32, - requires_grad=True, - ) - - projected = fabric_grouped_projection_cuda(sender_cells, grouped_weight, group_size=group_size) - projected.square().mean().backward() - actual_grad_input = sender_cells.grad.detach().clone() - actual_grad_weight = grouped_weight.grad.detach().clone() - - sender_ref = sender_cells.detach().clone().requires_grad_(True) - weight_ref = grouped_weight.detach().clone().requires_grad_(True) - expected = torch.einsum( - "bgsh,ghm->bgsm", - sender_ref.reshape(batch_size, num_groups, group_size, hidden_size), - weight_ref, - ).reshape(batch_size, num_groups * group_size, out_dim) - expected.square().mean().backward() - - torch.testing.assert_close(projected, expected, rtol=1e-6, atol=1e-6) - torch.testing.assert_close(actual_grad_input, sender_ref.grad, rtol=1e-5, atol=1e-5) - torch.testing.assert_close(actual_grad_weight, weight_ref.grad, rtol=1e-5, atol=1e-5) - - @pytest.mark.skipif( not torch.cuda.is_available(), reason="CUDA required for factorized projection backward parity test", @@ -5993,6 +5801,32 @@ def test_fabric_cuda_factorized_recurrent_input_projection_grads_match_reference torch.testing.assert_close(grad_input, expected_input, rtol=2e-4, atol=2e-4) +def test_fabric_transition_unfuse_routes_direct_message_to_cell_grad_from_compiler_source() -> None: + hidden = value_dim = 5 + receivers = 7 + projected = 3 + message_to_cell_weight = torch.randn(hidden, value_dim) + grad_fused_weight = torch.randn(hidden, value_dim) + grad_fused_bias = torch.randn(1, receivers, hidden) + + static_grads, materialized_grads = cuda_transition_projection._unfuse_recurrent_input_projection_grads( # noqa: SLF001 + static_tensors={ + "input_proj_weight_t": torch.randn(receivers, hidden, projected), + "message_to_cell_weight": message_to_cell_weight, + "recurrent_cell_bias": torch.randn(1, receivers, hidden), + "recurrent_message_to_cell_weight_source": "message_to_cell_weight", + }, + grad_fused_weight=grad_fused_weight, + grad_fused_bias=grad_fused_bias, + selected_static_source="message_to_cell_weight", + ) + + assert set(static_grads) == {"message_to_cell_weight", "recurrent_cell_bias"} + assert materialized_grads == {} + torch.testing.assert_close(static_grads["message_to_cell_weight"], grad_fused_weight) + torch.testing.assert_close(static_grads["recurrent_cell_bias"], grad_fused_bias) + + @pytest.mark.skipif( not torch.cuda.is_available(), reason="CUDA required for factorized projection backward parity test", @@ -6036,7 +5870,7 @@ def 
test_fabric_cuda_factorized_recurrent_input_direct_backward_matches_referenc def test_fabric_cuda_message_kernel_supports_more_than_65535_receivers(): runtime = build( init( - Config( + _lattice_test_config( width=258, height=258, hidden_size=4, @@ -6053,7 +5887,7 @@ def test_fabric_cuda_message_kernel_supports_more_than_65535_receivers(): ).cuda() assert runtime.coords.shape[0] > 65_535 - z_prev = torch.randn(1, 1, runtime.coords.shape[0], runtime.config.d_public, device="cuda") + z_prev = torch.randn(1, 1, runtime.coords.shape[0], runtime.d_public, device="cuda") q = runtime.q_proj(runtime.slot_embed).view(runtime.coords.shape[0], runtime.head_dim) gathered = torch.cat( ( @@ -6100,13 +5934,20 @@ def test_fabric_cuda_flat_bucket_sequence_matches_pytorch_reference(): device="cuda", ) - with torch.no_grad(): - y_cuda, state_cuda = cuda_runtime.forward_cells( + with ( + torch.no_grad(), + pytest.raises( + RuntimeError, + match="compiler-owned CUDA temporal table scan", + ), + ): + cuda_runtime.forward_cells( boundary_input=boundary_seq, state=None, resets=resets, k=1, ) + with torch.no_grad(): y_pytorch, state_pytorch = pytorch_runtime.forward_cells( boundary_input=boundary_seq, state=None, @@ -6114,19 +5955,15 @@ def test_fabric_cuda_flat_bucket_sequence_matches_pytorch_reference(): k=1, ) - assert cuda_runtime.last_backend_execution is not None - assert cuda_runtime.last_backend_execution.backend_name == "cuda" - assert cuda_runtime.last_backend_execution.surface_key == "flat_bucket_sequence_surface" - assert cuda_runtime.last_backend_execution.cell_type == "bucketed" - assert cuda_runtime.last_backend_execution.launch_temporal_executions == ("temporal_bucket_sequence",) - assert cuda_runtime.last_backend_execution.launch_scan_implementations == ("flat_bucket_temporal_scan",) - assert "flat_bucket_temporal_scan" in cuda_runtime.last_backend_execution.physical_op_executors - assert cuda_runtime.last_backend_execution.active_receiver_window_modes == ("full_recurrent_closure",) - assert cuda_runtime.last_backend_execution.physical_op_demotions == ("active_region_closure_full_surface",) assert pytorch_runtime.last_backend_execution is not None assert pytorch_runtime.last_backend_execution.backend_name == "pytorch" - torch.testing.assert_close(y_cuda, y_pytorch, rtol=1e-5, atol=1e-5) - _assert_fabric_semantic_state_close(state_cuda, state_pytorch, rtol=1e-5, atol=1e-5) + assert y_pytorch.shape == ( + batch_size, + time_steps, + pytorch_runtime.coords.shape[0], + pytorch_runtime.hidden_size, + ) + assert isinstance(state_pytorch, TensorDictBase) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for mixed population parity test") @@ -6155,14 +5992,21 @@ def test_fabric_cuda_mixed_population_fresh_state_without_final_state_matches_py device="cuda", ) - with torch.no_grad(): - y_cuda, state_cuda = cuda_runtime.forward_cells( + with ( + torch.no_grad(), + pytest.raises( + RuntimeError, + match="compiler-owned CUDA temporal table scan", + ), + ): + cuda_runtime.forward_cells( boundary_input=boundary_seq, state=None, resets=resets, k=1, materialize_final_state=False, ) + with torch.no_grad(): y_pytorch, _state_pytorch = pytorch_runtime.forward_cells( boundary_input=boundary_seq, state=None, @@ -6171,220 +6015,320 @@ def test_fabric_cuda_mixed_population_fresh_state_without_final_state_matches_py materialize_final_state=False, ) - assert cuda_runtime.last_backend_execution is not None - assert cuda_runtime.last_backend_execution.backend_name == "cuda" - assert 
cuda_runtime.last_backend_execution.surface_key == "flat_bucket_sequence_surface" - assert cuda_runtime.last_backend_execution.cell_type == "bucketed" - assert cuda_runtime.last_backend_execution.launch_temporal_executions == ("temporal_bucket_sequence",) - assert cuda_runtime.last_backend_execution.launch_scan_implementations == ("flat_bucket_temporal_scan",) - assert "flat_bucket_temporal_scan" in cuda_runtime.last_backend_execution.physical_op_executors - assert cuda_runtime.last_backend_execution.active_receiver_window_modes == ("full_recurrent_closure",) - assert cuda_runtime.last_backend_execution.physical_op_demotions == () - torch.testing.assert_close(y_cuda, y_pytorch, rtol=1e-4, atol=2e-5) - assert isinstance(state_cuda, TensorDictBase) - for population_name in ("slstm", "axoncell"): - population_state = state_cuda.get(population_name) - if isinstance(population_state, TensorDictBase): - assert not population_state.keys() - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for mixed population backward parity test") -def test_fabric_cuda_backend_order_transition_buckets_backward_matches_autograd_reference(): - torch.manual_seed(1234) - torch.cuda.manual_seed_all(1234) - runtime = _build_runtime_for_backend(_make_spec(), "cuda") - runtime._active_backend_name = "cuda" - batch_size = 3 - generator = torch.Generator(device="cuda").manual_seed(1234) - recurrent_msg_ref = torch.randn( + assert y_pytorch.shape == ( batch_size, - runtime.recurrent_cell_idx.numel(), - runtime.value_dim, - device="cuda", - generator=generator, - requires_grad=True, - ) - recurrent_msg_physical = recurrent_msg_ref.detach().clone() - grad_recurrent_hidden = torch.randn( - (batch_size, runtime.recurrent_cell_idx.numel(), runtime.hidden_size), - device="cuda", - generator=generator, - ) - resets = torch.tensor([False, True, False], dtype=torch.bool, device="cuda") - base_state = runtime.init_state(batch_size, device="cuda", dtype=torch.float32) - state_ref = _clone_fabric_population_state_with_grad(base_state) - state_physical = _clone_fabric_population_state_with_grad(base_state) - static_tensors = runtime._get_training_static_tensors( - device=torch.device("cuda"), - dtype=torch.float32, - include_backend_prepack=True, - include_full_cell_kv_weight=True, - detach_static_tensors=False, - ) - - runtime.zero_grad(set_to_none=True) - recurrent_hidden_ref, _ = runtime._run_backend_order_transition_buckets_step( - recurrent_msg_ref, - state_ref, - resets=resets, - batch_size=batch_size, - static_tensors=static_tensors, - materialize_next_state=True, + time_steps, + pytorch_runtime.coords.shape[0], + pytorch_runtime.hidden_size, ) - loss_ref = (recurrent_hidden_ref * grad_recurrent_hidden).sum() - loss_ref.backward() - grad_recurrent_msg_ref = recurrent_msg_ref.grad.detach().clone() - grad_state_ref = _floating_state_grad_tensors(state_ref) - param_grads_ref = _param_grads(runtime) - - trainable_items = tuple((name, param) for name, param in runtime.named_parameters() if param.requires_grad) - runtime.zero_grad(set_to_none=True) - grad_recurrent_msg, grad_state, param_grads = runtime._run_backend_order_transition_buckets_backward_step( - recurrent_msg_physical, - state_physical, - grad_recurrent_hidden=grad_recurrent_hidden, - resets=resets, - static_tensors=static_tensors, - trainable_params=tuple(param for _, param in trainable_items), - trainable_param_names=tuple(name for name, _ in trainable_items), - ) - - assert grad_recurrent_msg is not None - 
torch.testing.assert_close(grad_recurrent_msg, grad_recurrent_msg_ref, rtol=3e-3, atol=3e-3) - grad_state_actual = { - f"{population_name}.{state_name}": tensor - for population_name, population_state in grad_state.items() - if isinstance(population_state, TensorDictBase) - for state_name, tensor in population_state.items() - if torch.is_tensor(tensor) - } - missing_state_keys = set(grad_state_ref) - set(grad_state_actual) - assert all(name.startswith("axoncell.E_") for name in missing_state_keys) - assert set(grad_state_actual) <= set(grad_state_ref) - for name in grad_state_actual: - torch.testing.assert_close(grad_state_actual[name], grad_state_ref[name], rtol=3e-3, atol=3e-3) - param_grad_actual = { - name: grad.detach() - for (name, _param), grad in zip(trainable_items, param_grads, strict=True) - if grad is not None - } - assert param_grad_actual.keys() == param_grads_ref.keys() - for name in param_grads_ref: - torch.testing.assert_close(param_grad_actual[name], param_grads_ref[name], rtol=3e-3, atol=3e-3) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for mixed population backward parity test") -def test_fabric_cuda_temporal_bucket_step_backward_matches_autograd_reference(): +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for mixed population flat carry cache test") +@pytest.mark.parametrize("use_resets", [False, True]) +def test_fabric_cuda_mixed_population_k1_output_only_uses_shared_temporal_bucket_scan( + use_resets: bool, +) -> None: torch.manual_seed(1234) torch.cuda.manual_seed_all(1234) - runtime = _build_runtime_for_backend(_make_spec(), "cuda") - runtime._active_backend_name = "cuda" - batch_size = 3 - generator = torch.Generator(device="cuda").manual_seed(1234) - boundary_ref = torch.randn( - batch_size, - runtime.input_cell_idx.numel(), - runtime.hidden_size, - device="cuda", - generator=generator, - requires_grad=True, - ) - boundary_physical = boundary_ref.detach().clone().requires_grad_(True) - grad_cells_out = torch.randn( - batch_size, - runtime.coords.shape[0], - runtime.hidden_size, - device="cuda", - generator=generator, - ) - resets = torch.tensor([False, True, False], dtype=torch.bool, device="cuda") - base_state = runtime.init_state(batch_size, device="cuda", dtype=torch.float32) - state_ref = _clone_fabric_state_with_grad(base_state) - state_physical = _clone_fabric_state_with_grad(base_state) - static_tensors = runtime._get_training_static_tensors( - device=torch.device("cuda"), - dtype=torch.float32, - include_backend_prepack=True, - include_full_cell_kv_weight=True, - detach_static_tensors=False, - ) - - runtime.zero_grad(set_to_none=True) - artifacts_ref = compute_temporal_bucket_step_artifacts( - runtime, - boundary_step=boundary_ref, - state=state_ref, - reset_step=resets, - static_tensors=static_tensors, + cuda_model, pytorch_model = _build_fabric_model_pair(_make_spec(), d_hidden=16) + batch_size = 2 + time_steps = 3 + inner_steps = 1 + generator = torch.Generator(device="cuda").manual_seed(13579) + x_cuda = torch.randn(batch_size, time_steps, 16, device="cuda", generator=generator) + x_pytorch = x_cuda.detach().clone() + resets = ( + torch.tensor( + [ + [False, True, False], + [True, False, False], + ], + dtype=torch.bool, + device="cuda", + ) + if use_resets + else None ) - loss_ref = (artifacts_ref.cells_out * grad_cells_out).sum() - loss_ref.backward() - grad_boundary_ref = boundary_ref.grad.detach().clone() - grad_state_ref = { - "cells": state_ref["cells"].grad.detach().clone(), - 
**_floating_state_grad_tensors(state_ref), - } - param_grads_ref = _param_grads(runtime) - trainable_items = tuple((name, param) for name, param in runtime.named_parameters() if param.requires_grad) - runtime.zero_grad(set_to_none=True) with torch.no_grad(): - artifacts_physical = compute_temporal_bucket_step_artifacts( - runtime, - boundary_step=boundary_physical, - state=state_physical, - reset_step=resets, - static_tensors=static_tensors, + y_cuda, state_cuda = cuda_model( + x_cuda, + state=None, + resets=resets, + k=inner_steps, + materialize_final_state=False, + ) + y_pytorch, state_pytorch = pytorch_model( + x_pytorch, + state=None, + resets=resets, + k=inner_steps, + materialize_final_state=False, ) - grad_boundary, grad_state, param_grads, _grad_backend_state_cache = run_temporal_bucket_step_backward( - runtime, - artifacts_physical, - grad_cells_out=grad_cells_out, - static_tensors=static_tensors, - trainable_params=tuple(param for _, param in trainable_items), - trainable_param_names=tuple(name for name, _param in trainable_items), - need_grad_state_before=True, - ) - assert grad_boundary is not None - torch.testing.assert_close(grad_boundary, grad_boundary_ref, rtol=3e-3, atol=3e-3) - grad_state_actual = { - "cells": grad_state["cells"], - **{ - f"{population_name}.{state_name}": tensor - for population_name, population_state in grad_state.items() - if isinstance(population_state, TensorDictBase) - for state_name, tensor in population_state.items() - if torch.is_tensor(tensor) - }, - } - assert set(grad_state_actual) <= set(grad_state_ref) - missing_state_keys = set(grad_state_ref) - set(grad_state_actual) - assert all(name.startswith("axoncell.E_") for name in missing_state_keys) - for name in grad_state_actual: - torch.testing.assert_close(grad_state_actual[name], grad_state_ref[name], rtol=3e-3, atol=3e-3) - param_grad_actual = { - name: grad.detach() - for (name, _param), grad in zip(trainable_items, param_grads, strict=True) - if grad is not None - } - assert param_grad_actual.keys() == param_grads_ref.keys() - for name in param_grads_ref: - torch.testing.assert_close(param_grad_actual[name], param_grads_ref[name], rtol=3e-3, atol=3e-3) + assert tuple(cast(TensorDictBase, state_cuda).keys()) == () + assert tuple(cast(TensorDictBase, state_pytorch).keys()) == () + torch.testing.assert_close(y_cuda, y_pytorch, rtol=2e-4, atol=4e-5) + + record = cuda_model.runtime.last_backend_execution + assert record is not None + assert record.surface_key == "registered_temporal_sequence_surface" + assert record.cell_type == "bucketed" + assert record.time_steps == time_steps + assert record.inner_steps == inner_steps + assert record.launch_temporal_executions == ("temporal_bucket_sequence",) + assert record.launch_temporal_scan_owners == ("registered_fused_forward_program_cuda",) + assert record.launch_scan_implementations == ("registered_temporal_fused_forward_program_cuda",) + assert record.launch_temporal_scan_physical_steps == (str(time_steps * inner_steps),) + assert record.launch_temporal_scan_emission_counts == (str(time_steps),) + assert record.active_receiver_window_modes == ("full_recurrent_closure",) + assert "readout_dependency_active_region" not in record.capability_variants + assert "flat_bucket_temporal_scan:recurrent_kv_carry_reuse" in record.workspace_aliases + assert "flat_bucket_state_cache:registered_fused_program_internal_state" in record.workspace_aliases + assert "flat_bucket_public_projection:registered_fused_forward_program_cuda" in record.workspace_aliases + 
assert "flat_bucket_readout:registered_fused_forward_program_cuda" in record.workspace_aliases + temporal_table_aliases = tuple( + alias for alias in record.workspace_aliases if alias.startswith("flat_bucket_temporal_table") + ) + assert "flat_bucket_temporal_table:temporal_table_abi=flat_bucket_tensor_tables" in temporal_table_aliases + assert any(alias.startswith("flat_bucket_temporal_table_primitive_rows:") for alias in temporal_table_aliases) + assert not any(term in "\n".join(temporal_table_aliases) for term in ("slstm", "axoncell")) + assert "flat_bucket_temporal_scan_binding_abi:registered_executor_binding_rows" in record.workspace_aliases + assert "single_bucket_sequence_executor" not in record.physical_op_executors -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for mixed population training route test") -def test_fabric_cuda_mixed_population_t_gt1_training_uses_flat_bucket_route(): +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for mixed population flat carry cache test") +@pytest.mark.parametrize("use_resets", [False, True]) +def test_fabric_cuda_mixed_population_k_gt1_output_only_uses_backend_order_flat_carry_cache( + use_resets: bool, +) -> None: torch.manual_seed(1234) torch.cuda.manual_seed_all(1234) - cuda_model, pytorch_model = _build_fabric_model_pair(_make_spec(), d_hidden=8) + cuda_model, pytorch_model = _build_fabric_model_pair(_make_spec(), d_hidden=16) batch_size = 2 - time_steps = 4 - generator = torch.Generator(device="cuda").manual_seed(1234) - x_cuda = torch.randn(batch_size, time_steps, 8, device="cuda", generator=generator).requires_grad_(True) - x_pytorch = x_cuda.detach().clone().requires_grad_(True) - resets = torch.tensor( - [ + time_steps = 3 + inner_steps = 2 + generator = torch.Generator(device="cuda").manual_seed(2468) + x_cuda = torch.randn(batch_size, time_steps, 16, device="cuda", generator=generator) + x_pytorch = x_cuda.detach().clone() + resets = ( + torch.tensor( + [ + [False, True, False], + [True, False, False], + ], + dtype=torch.bool, + device="cuda", + ) + if use_resets + else None + ) + + with torch.no_grad(): + y_cuda, state_cuda = cuda_model( + x_cuda, + state=None, + resets=resets, + k=inner_steps, + materialize_final_state=False, + ) + y_pytorch, state_pytorch = pytorch_model( + x_pytorch, + state=None, + resets=resets, + k=inner_steps, + materialize_final_state=False, + ) + + assert tuple(cast(TensorDictBase, state_cuda).keys()) == () + assert tuple(cast(TensorDictBase, state_pytorch).keys()) == () + torch.testing.assert_close(y_cuda, y_pytorch, rtol=2e-4, atol=4e-5) + + record = cuda_model.runtime.last_backend_execution + assert record is not None + assert record.surface_key == "registered_temporal_sequence_surface" + assert record.cell_type == "bucketed" + assert record.time_steps == time_steps + assert record.inner_steps == inner_steps + assert record.launch_temporal_executions == ("temporal_bucket_sequence",) + assert record.launch_temporal_scan_owners == ("registered_fused_forward_program_cuda",) + assert record.launch_scan_implementations == ("registered_temporal_fused_forward_program_cuda",) + assert record.launch_temporal_scan_physical_steps == (str(time_steps * inner_steps),) + assert "flat_bucket_temporal_scan:recurrent_kv_carry_reuse" in record.workspace_aliases + assert "flat_bucket_state_cache:registered_fused_program_internal_state" in record.workspace_aliases + assert "flat_bucket_public_projection:registered_fused_forward_program_cuda" in record.workspace_aliases + 
assert "flat_bucket_readout:registered_fused_forward_program_cuda" in record.workspace_aliases + assert "flat_bucket_state_cache_materialized_steps:0" in record.workspace_aliases + assert "single_bucket_sequence_executor" not in record.physical_op_executors + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for mixed population provided-state scan") +@pytest.mark.parametrize("inner_steps", [1, 2]) +@pytest.mark.parametrize("use_resets", [False, True]) +def test_fabric_cuda_mixed_population_provided_state_output_only_uses_temporal_superop( + inner_steps: int, + use_resets: bool, +) -> None: + torch.manual_seed(1234) + torch.cuda.manual_seed_all(1234) + cuda_model, pytorch_model = _build_fabric_model_pair(_make_spec(), d_hidden=16) + batch_size = 2 + time_steps = 3 + generator = torch.Generator(device="cuda").manual_seed(97531) + warmup_x = torch.randn(batch_size, 1, 16, device="cuda", generator=generator) + x_cuda = torch.randn(batch_size, time_steps, 16, device="cuda", generator=generator) + x_pytorch = x_cuda.detach().clone() + resets = ( + torch.tensor( + [ + [False, True, False], + [True, False, True], + ], + dtype=torch.bool, + device="cuda", + ) + if use_resets + else None + ) + + with torch.no_grad(): + _warmup_y_cuda, provided_state_cuda = cuda_model( + warmup_x, + state=None, + k=1, + materialize_final_state=True, + ) + _warmup_y_pytorch, provided_state_pytorch = pytorch_model( + warmup_x.detach().clone(), + state=None, + k=1, + materialize_final_state=True, + ) + y_cuda, state_cuda = cuda_model( + x_cuda, + state=provided_state_cuda.clone(), + resets=resets, + k=inner_steps, + materialize_final_state=False, + ) + y_pytorch, state_pytorch = pytorch_model( + x_pytorch, + state=provided_state_pytorch.clone(), + resets=resets, + k=inner_steps, + materialize_final_state=False, + ) + + assert tuple(cast(TensorDictBase, state_cuda).keys()) == () + assert tuple(cast(TensorDictBase, state_pytorch).keys()) == () + torch.testing.assert_close(y_cuda, y_pytorch, rtol=2e-4, atol=4e-5) + + record = cuda_model.runtime.last_backend_execution + assert record is not None + assert record.surface_key == "registered_temporal_sequence_surface" + assert record.cell_type == "bucketed" + assert record.time_steps == time_steps + assert record.inner_steps == inner_steps + assert record.launch_temporal_executions == ("temporal_bucket_sequence",) + assert record.launch_temporal_scan_owners == ("registered_fused_forward_program_cuda",) + assert record.launch_scan_implementations == ("registered_temporal_fused_forward_program_cuda",) + assert record.launch_temporal_scan_physical_steps == (str(time_steps * inner_steps),) + assert "flat_bucket_state_cache:registered_fused_program_internal_state" in record.workspace_aliases + assert "flat_bucket_public_projection:registered_fused_forward_program_cuda" in record.workspace_aliases + assert "flat_bucket_readout:registered_fused_forward_program_cuda" in record.workspace_aliases + assert "single_bucket_sequence_executor" not in record.physical_op_executors + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for mixed population final-state scan") +@pytest.mark.parametrize("state_mode", ["fresh", "provided"]) +@pytest.mark.parametrize("inner_steps", [1, 2]) +def test_fabric_cuda_mixed_population_final_state_uses_temporal_superop( + state_mode: str, + inner_steps: int, +) -> None: + torch.manual_seed(1234) + torch.cuda.manual_seed_all(1234) + cuda_model, pytorch_model = _build_fabric_model_pair(_make_spec(), 
d_hidden=16) + batch_size = 2 + time_steps = 3 + generator = torch.Generator(device="cuda").manual_seed(86420) + warmup_x = torch.randn(batch_size, 1, 16, device="cuda", generator=generator) + x_cuda = torch.randn(batch_size, time_steps, 16, device="cuda", generator=generator) + x_pytorch = x_cuda.detach().clone() + resets = torch.tensor( + [ + [False, True, False], + [True, False, True], + ], + dtype=torch.bool, + device="cuda", + ) + + with torch.no_grad(): + state_cuda_in = None + state_pytorch_in = None + if state_mode == "provided": + _warmup_y_cuda, state_cuda_in = cuda_model( + warmup_x, + state=None, + k=1, + materialize_final_state=True, + ) + _warmup_y_pytorch, state_pytorch_in = pytorch_model( + warmup_x.detach().clone(), + state=None, + k=1, + materialize_final_state=True, + ) + state_cuda_in = state_cuda_in.clone() + state_pytorch_in = state_pytorch_in.clone() + y_cuda, state_cuda = cuda_model( + x_cuda, + state=state_cuda_in, + resets=resets, + k=inner_steps, + materialize_final_state=True, + ) + y_pytorch, state_pytorch = pytorch_model( + x_pytorch, + state=state_pytorch_in, + resets=resets, + k=inner_steps, + materialize_final_state=True, + ) + + torch.testing.assert_close(y_cuda, y_pytorch, rtol=2e-4, atol=4e-5) + _assert_fabric_semantic_state_close(state_cuda, state_pytorch, rtol=1e-3, atol=2e-4) + + record = cuda_model.runtime.last_backend_execution + assert record is not None + assert record.surface_key == "registered_temporal_sequence_surface" + assert record.cell_type == "bucketed" + assert record.time_steps == time_steps + assert record.inner_steps == inner_steps + assert record.launch_temporal_executions == ("temporal_bucket_sequence",) + assert record.launch_temporal_scan_owners == ("registered_fused_forward_program_cuda",) + assert record.launch_scan_implementations == ("registered_temporal_fused_forward_program_cuda",) + assert record.launch_temporal_scan_physical_steps == (str(time_steps * inner_steps),) + assert "final_state=materialized" in record.workspace_aliases + assert "flat_bucket_state_cache:registered_fused_program_final_tensor_table" in record.workspace_aliases + assert "flat_bucket_public_projection:registered_fused_forward_program_cuda" in record.workspace_aliases + assert "flat_bucket_readout:registered_fused_forward_program_cuda" in record.workspace_aliases + assert "single_bucket_sequence_executor" not in record.physical_op_executors + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for mixed population training route test") +def test_fabric_cuda_mixed_population_t_gt1_training_uses_flat_bucket_route(): + torch.manual_seed(1234) + torch.cuda.manual_seed_all(1234) + cuda_model, pytorch_model = _build_fabric_model_pair(_make_spec(), d_hidden=8) + batch_size = 2 + time_steps = 4 + generator = torch.Generator(device="cuda").manual_seed(1234) + x_cuda = torch.randn(batch_size, time_steps, 8, device="cuda", generator=generator).requires_grad_(True) + x_pytorch = x_cuda.detach().clone().requires_grad_(True) + resets = torch.tensor( + [ [False, True, False, False], [False, False, True, False], ], @@ -6393,114 +6337,949 @@ def test_fabric_cuda_mixed_population_t_gt1_training_uses_flat_bucket_route(): ) y_cuda, _state_cuda = cuda_model(x_cuda, state=None, resets=resets, k=1) - y_pytorch, _state_pytorch = pytorch_model(x_pytorch, state=None, resets=resets, k=1) + y_pytorch, _state_pytorch = _fabric_forward_reference(pytorch_model, x_pytorch, state=None, resets=resets, k=1) torch.testing.assert_close(y_cuda, y_pytorch, rtol=1e-4, 
atol=2e-5) - loss_cuda = y_cuda.square().mean() - loss_pytorch = y_pytorch.square().mean() - loss_cuda.backward() - loss_pytorch.backward() + loss_cuda = y_cuda.square().mean() + loss_pytorch = y_pytorch.square().mean() + loss_cuda.backward() + loss_pytorch.backward() + + record = cuda_model.runtime.last_backend_execution + assert record is not None + assert record.backend_name == "cuda" + assert record.surface_key == "registered_temporal_sequence_surface" + assert record.cell_type == "bucketed" + assert record.launch_temporal_executions == ("temporal_bucket_sequence",) + assert record.launch_scan_implementations == ("registered_temporal_fused_forward_program_cuda",) + assert "flat_bucket_temporal_scan" in record.capability_variants + assert "registered_temporal_fused_forward_program_cuda" in record.physical_op_executors + assert "flat_bucket_temporal_scan:recurrent_kv_carry_reuse" in record.workspace_aliases + assert record.active_receiver_window_modes == ("full_recurrent_closure",) + assert record.physical_op_demotions == ("active_region_closure_full_surface",) + assert "physical_temporal_bucket_sequence_backward" in record.backward_physical_op_executors + assert "physical_tiny_message_backward_executor" in record.backward_physical_op_executors + assert "physical_receiver_affine_backward_executor" in record.backward_physical_op_executors + assert "physical_state_epilogue_backward_executor" in record.backward_physical_op_executors + assert "physical_diagonal_recurrence_backward_executor" in record.backward_physical_op_executors + assert "projection_reduction_boundary_backward" in record.backward_physical_op_executors + assert "autograd_transition_bucket_backward" not in record.backward_physical_op_executors + assert "autograd_readout_projection_backward" not in record.backward_physical_op_executors + assert record.backward_physical_op_demotions == ("active_region_closure_full_surface",) + assert x_cuda.grad is not None + assert x_pytorch.grad is not None + torch.testing.assert_close(x_cuda.grad, x_pytorch.grad, rtol=2e-3, atol=2e-3) + cuda_grads = _param_grads(cuda_model) + pytorch_grads = _param_grads(pytorch_model) + assert cuda_grads.keys() == pytorch_grads.keys() + for name in cuda_grads: + torch.testing.assert_close(cuda_grads[name], pytorch_grads[name], rtol=3e-3, atol=3e-3) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for mixed population T=1 training route test") +@pytest.mark.parametrize("use_resets", [False, True]) +def test_fabric_cuda_mixed_population_t1_k1_pooled_output_uses_registered_reverse_program_window( + use_resets: bool, +) -> None: + torch.manual_seed(1234) + torch.cuda.manual_seed_all(1234) + cuda_model, pytorch_model = _build_fabric_model_pair( + _make_spec(gradient_horizon_steps=1, k_max=1, default_k=1), + d_hidden=8, + ) + batch_size = 2 + time_steps = 1 + generator = torch.Generator(device="cuda").manual_seed(987) + x_cuda = torch.randn(batch_size, time_steps, 8, device="cuda", generator=generator).requires_grad_(True) + x_pytorch = x_cuda.detach().clone().requires_grad_(True) + resets = ( + torch.tensor( + [ + [False], + [True], + ], + dtype=torch.bool, + device="cuda", + ) + if use_resets + else None + ) + + y_cuda, state_cuda = cuda_model( + x_cuda, + state=None, + resets=resets, + k=1, + materialize_final_state=False, + ) + y_pytorch, state_pytorch = _fabric_forward_reference(pytorch_model, x_pytorch, state=None, resets=resets, k=1) + torch.testing.assert_close(y_cuda, y_pytorch, rtol=2e-4, atol=4e-5) + assert 
isinstance(state_cuda, TensorDictBase) + assert tuple(state_cuda.keys()) == () + assert isinstance(state_pytorch, TensorDictBase) + + y_cuda.square().mean().backward() + y_pytorch.square().mean().backward() + + record = cuda_model.runtime.last_backend_execution + assert record is not None + assert record.backend_name == "cuda" + assert record.surface_key == "registered_temporal_sequence_surface" + assert record.cell_type == "bucketed" + assert record.time_steps == time_steps + assert record.inner_steps == 1 + assert record.temporal_plan_bucket_identity == ("flat_bucket_identity",) + assert record.launch_temporal_executions == ("temporal_bucket_sequence",) + assert record.launch_temporal_scan_physical_steps == ("1",) + assert record.launch_temporal_scan_emission_counts == ("1",) + assert "single_bucket_sequence_executor" not in record.physical_op_executors + assert "physical_temporal_bucket_sequence_backward" in record.backward_physical_op_executors + _assert_registered_reverse_program_window_owned(record) + assert x_cuda.grad is not None + assert x_pytorch.grad is not None + torch.testing.assert_close(x_cuda.grad, x_pytorch.grad, rtol=6e-3, atol=6e-3) + cuda_grads = _param_grads(cuda_model) + pytorch_grads = _param_grads(pytorch_model) + assert cuda_grads.keys() == pytorch_grads.keys() + for name in cuda_grads: + torch.testing.assert_close(cuda_grads[name], pytorch_grads[name], rtol=1e-2, atol=1e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for temporal forward reverse-table test") +def test_fabric_cuda_t1_terminal_loss_uses_fused_forward_artifact_tensor_store() -> None: + torch.manual_seed(4321) + torch.cuda.manual_seed_all(4321) + spec = _make_spec(gradient_horizon_steps=1, k_max=1, default_k=1) + terminal_model = _build_fabric_model_for_backend( + spec, + "cuda", + d_hidden=8, + ) + sequence_model = _build_fabric_model_for_backend( + spec, + "cuda", + d_hidden=8, + ) + sequence_model.load_state_dict(terminal_model.state_dict()) + batch_size = 2 + time_steps = 1 + generator = torch.Generator(device="cuda").manual_seed(4321) + x_terminal = torch.randn(batch_size, time_steps, 8, device="cuda", generator=generator).requires_grad_(True) + x_sequence = x_terminal.detach().clone().requires_grad_(True) + + y_terminal, state_terminal = terminal_model( + x_terminal, + state=None, + resets=None, + k=1, + materialize_final_state=False, + output_boundary="terminal", + ) + y_sequence, state_sequence = sequence_model( + x_sequence, + state=None, + resets=None, + k=1, + materialize_final_state=False, + output_boundary="sequence", + ) + assert tuple(cast(TensorDictBase, state_terminal).keys()) == () + assert tuple(cast(TensorDictBase, state_sequence).keys()) == () + torch.testing.assert_close(y_terminal, y_sequence[:, -1:], rtol=0.0, atol=0.0) + + y_terminal.square().mean().backward() + y_sequence[:, -1:].square().mean().backward() + + record = terminal_model.runtime.last_backend_execution + assert record is not None + _assert_generic_flat_bucket_sequence_record(record) + assert record.temporal_plan_reverse_artifact_kinds == ("store_step_artifacts",) + assert record.temporal_plan_recompute_window_steps == ("1",) + assert any("materialization=store_step_artifacts" in item for item in record.temporal_plan_materialization_reasons) + assert any( + "reason=compiler_owned_fused_forward_artifact_tensor_table" in item + for item in record.temporal_plan_materialization_reasons + ) + assert "temporal_artifacts:store_step_artifacts" in record.workspace_aliases + assert 
any("reverse_artifacts=store_step_artifacts" in item for item in record.backward_recompute_mode) + assert not any("temporal_artifact_recompute" in item for item in record.backward_owner_timing_ms) + assert not any("artifact.recompute.cuda_temporal_replay_scan" in item for item in record.backward_owner_timing_ms) + assert "physical_temporal_bucket_sequence_backward" in record.backward_physical_op_executors + _assert_registered_reverse_program_window_owned(record) + assert x_terminal.grad is not None + assert x_sequence.grad is not None + torch.testing.assert_close(x_terminal.grad, x_sequence.grad, rtol=6e-3, atol=6e-3) + terminal_grads = _param_grads(terminal_model) + sequence_grads = _param_grads(sequence_model) + assert terminal_grads.keys() == sequence_grads.keys() + for name in terminal_grads: + torch.testing.assert_close(terminal_grads[name], sequence_grads[name], rtol=1e-2, atol=1e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for temporal final-state-only loss test") +def test_fabric_cuda_final_state_only_loss_uses_registered_zero_output_grad_window() -> None: + torch.manual_seed(4330) + torch.cuda.manual_seed_all(4330) + spec = _make_spec(gradient_horizon_steps=1, k_max=1, default_k=1) + state_only_model = _build_fabric_model_for_backend( + spec, + "cuda", + d_hidden=8, + ) + explicit_zero_model = _build_fabric_model_for_backend( + spec, + "cuda", + d_hidden=8, + ) + explicit_zero_model.load_state_dict(state_only_model.state_dict()) + batch_size = 2 + time_steps = 1 + generator = torch.Generator(device="cuda").manual_seed(4330) + x_state_only = torch.randn(batch_size, time_steps, 8, device="cuda", generator=generator).requires_grad_(True) + x_explicit_zero = x_state_only.detach().clone().requires_grad_(True) + + y_state_only, state_state_only = state_only_model( + x_state_only, + state=None, + resets=None, + k=1, + materialize_final_state=True, + output_boundary="terminal", + ) + y_explicit_zero, state_explicit_zero = explicit_zero_model( + x_explicit_zero, + state=None, + resets=None, + k=1, + materialize_final_state=True, + output_boundary="terminal", + ) + torch.testing.assert_close(y_state_only, y_explicit_zero, rtol=0.0, atol=0.0) + _assert_fabric_semantic_state_close(state_state_only, state_explicit_zero, rtol=0.0, atol=0.0) + + _state_square_mean(state_state_only).backward() + (_state_square_mean(state_explicit_zero) + y_explicit_zero.square().sum() * 0.0).backward() + + record = state_only_model.runtime.last_backend_execution + assert record is not None + assert record.backend_name == "cuda" + assert record.surface_key == "registered_temporal_sequence_surface" + assert record.launch_temporal_scan_owners == ("registered_fused_forward_program_cuda",) + assert record.launch_scan_implementations == ("registered_temporal_fused_forward_program_cuda",) + _assert_registered_reverse_program_window_owned(record) + assert x_state_only.grad is not None + assert x_explicit_zero.grad is not None + torch.testing.assert_close(x_state_only.grad, x_explicit_zero.grad, rtol=6e-3, atol=6e-3) + state_only_grads = _param_grads(state_only_model) + explicit_zero_grads = _param_grads(explicit_zero_model) + assert set(state_only_grads).issubset(set(explicit_zero_grads)) + for name in state_only_grads: + torch.testing.assert_close(state_only_grads[name], explicit_zero_grads[name], rtol=1e-2, atol=1e-2) + for name in set(explicit_zero_grads) - set(state_only_grads): + torch.testing.assert_close(explicit_zero_grads[name], torch.zeros_like(explicit_zero_grads[name])) + 
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for temporal direct-table T>1 test")
+def test_fabric_cuda_t_gt1_fused_forward_artifact_tensor_store_plan_is_not_t1_specific() -> None:
+    torch.manual_seed(4323)
+    torch.cuda.manual_seed_all(4323)
+    spec = _make_spec(gradient_horizon_steps=None, k_max=1, default_k=1)
+    terminal_model = _build_fabric_model_for_backend(
+        spec,
+        "cuda",
+        d_hidden=8,
+    )
+    sequence_model = _build_fabric_model_for_backend(
+        spec,
+        "cuda",
+        d_hidden=8,
+    )
+    sequence_model.load_state_dict(terminal_model.state_dict())
+    batch_size = 2
+    time_steps = 2
+    generator = torch.Generator(device="cuda").manual_seed(4323)
+    x_terminal = torch.randn(batch_size, time_steps, 8, device="cuda", generator=generator).requires_grad_(True)
+    x_sequence = x_terminal.detach().clone().requires_grad_(True)
+
+    y_terminal, state_terminal = terminal_model(
+        x_terminal,
+        state=None,
+        resets=None,
+        k=1,
+        materialize_final_state=False,
+        output_boundary="terminal",
+    )
+    y_sequence, state_sequence = sequence_model(
+        x_sequence,
+        state=None,
+        resets=None,
+        k=1,
+        materialize_final_state=False,
+        output_boundary="sequence",
+    )
+    assert tuple(cast(TensorDictBase, state_terminal).keys()) == ()
+    assert tuple(cast(TensorDictBase, state_sequence).keys()) == ()
+    torch.testing.assert_close(y_terminal, y_sequence[:, -1:], rtol=0.0, atol=0.0)
+
+    y_terminal.square().mean().backward()
+    y_sequence[:, -1:].square().mean().backward()
+
+    record = terminal_model.runtime.last_backend_execution
+    assert record is not None
+    _assert_generic_flat_bucket_sequence_record(record)
+    assert record.time_steps == time_steps
+    assert record.inner_steps == 1
+    assert record.temporal_plan_total_scan_steps == ("2",)
+    assert record.temporal_plan_reverse_artifact_kinds == ("store_step_artifacts",)
+    assert record.temporal_plan_recompute_window_steps == ("2",)
+    assert any("materialization=store_step_artifacts" in item for item in record.temporal_plan_materialization_reasons)
+    assert any("physical_steps=2" in item for item in record.temporal_plan_materialization_reasons)
+    assert "temporal_artifacts:store_step_artifacts" in record.workspace_aliases
+    assert any("reverse_artifacts=store_step_artifacts" in item for item in record.backward_recompute_mode)
+    assert not any("temporal_artifact_recompute" in item for item in record.backward_owner_timing_ms)
+    assert x_terminal.grad is not None
+    assert x_sequence.grad is not None
+    torch.testing.assert_close(x_terminal.grad, x_sequence.grad, rtol=6e-3, atol=6e-3)
+    terminal_grads = _param_grads(terminal_model)
+    sequence_grads = _param_grads(sequence_model)
+    assert terminal_grads.keys() == sequence_grads.keys()
+    for name in terminal_grads:
+        torch.testing.assert_close(terminal_grads[name], sequence_grads[name], rtol=1e-2, atol=1e-2)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for temporal provided-state test")
+def test_fabric_cuda_t1_provided_state_keeps_recurrent_sender_boundary_gradients() -> None:
+    torch.manual_seed(4322)
+    torch.cuda.manual_seed_all(4322)
+    spec = _make_spec(gradient_horizon_steps=1, k_max=1, default_k=1)
+    terminal_model = _build_fabric_model_for_backend(
+        spec,
+        "cuda",
+        d_hidden=8,
+    )
+    sequence_model = _build_fabric_model_for_backend(
+        spec,
+        "cuda",
+        d_hidden=8,
+    )
+    sequence_model.load_state_dict(terminal_model.state_dict())
+    batch_size = 2
+    time_steps = 1
+    generator = torch.Generator(device="cuda").manual_seed(4322)
+    x_terminal = torch.randn(batch_size, time_steps, 8, device="cuda", generator=generator).requires_grad_(True)
+    x_sequence = x_terminal.detach().clone().requires_grad_(True)
+    base_state = terminal_model.runtime.init_state(batch_size, device="cuda", dtype=torch.float32)
+    state_terminal = _clone_fabric_state_with_grad(base_state)
+    state_sequence = _clone_fabric_state_with_grad(base_state)
+
+    y_terminal, next_terminal_state = terminal_model(
+        x_terminal,
+        state=state_terminal,
+        resets=None,
+        k=1,
+        materialize_final_state=True,
+        output_boundary="terminal",
+    )
+    y_sequence, next_sequence_state = sequence_model(
+        x_sequence,
+        state=state_sequence,
+        resets=None,
+        k=1,
+        materialize_final_state=True,
+        output_boundary="sequence",
+    )
+    torch.testing.assert_close(y_terminal, y_sequence[:, -1:], rtol=0.0, atol=0.0)
+    _assert_fabric_semantic_state_close(next_terminal_state, next_sequence_state, rtol=1e-5, atol=1e-5)
+
+    (y_terminal.square().mean() + 0.01 * _state_square_mean(next_terminal_state)).backward()
+    (y_sequence[:, -1:].square().mean() + 0.01 * _state_square_mean(next_sequence_state)).backward()
+
+    record = terminal_model.runtime.last_backend_execution
+    assert record is not None
+    _assert_generic_flat_bucket_sequence_record(record)
+    assert record.temporal_plan_backward_owners == ("registered_reverse_executor_bindings",)
+    assert record.temporal_plan_reverse_artifact_kinds == ("store_step_artifacts",)
+    assert any("materialization=store_step_artifacts" in item for item in record.temporal_plan_materialization_reasons)
+    assert any("final_state=1" in item for item in record.temporal_plan_materialization_reasons)
+    assert any(
+        "reason=compiler_owned_fused_forward_artifact_tensor_table" in item
+        for item in record.temporal_plan_materialization_reasons
+    )
+    _assert_registered_reverse_program_window_owned(record)
+    state_grads_terminal = _fabric_state_grads(state_terminal)
+    state_grads_sequence = _fabric_state_grads(state_sequence)
+    assert state_grads_terminal.keys() == state_grads_sequence.keys()
+    assert state_grads_terminal
+    for name in state_grads_terminal:
+        torch.testing.assert_close(state_grads_terminal[name], state_grads_sequence[name], rtol=1e-2, atol=1e-2)
+    assert x_terminal.grad is not None
+    assert x_sequence.grad is not None
+    torch.testing.assert_close(x_terminal.grad, x_sequence.grad, rtol=6e-3, atol=6e-3)
+    terminal_grads = _param_grads(terminal_model)
+    sequence_grads = _param_grads(sequence_model)
+    assert terminal_grads.keys() == sequence_grads.keys()
+    for name in terminal_grads:
+        torch.testing.assert_close(terminal_grads[name], sequence_grads[name], rtol=1e-2, atol=1e-2)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for mixed population backward parity test")
+def test_fabric_cuda_mixed_population_t_gt1_recomputed_artifacts_match_pytorch_reference(
+    monkeypatch: pytest.MonkeyPatch,
+):
+    torch.manual_seed(1234)
+    torch.cuda.manual_seed_all(1234)
+    cuda_model, pytorch_model = _build_fabric_model_pair(_make_spec(), d_hidden=8)
+    monkeypatch.setattr(
+        cuda_model.runtime,
+        "_cuda_memory_budget",
+        lambda _device: CudaMemoryBudget(
+            usable_bytes=1 << 20,
+            total_bytes=8 << 30,
+            free_bytes=1 << 20,
+            reusable_reserved_bytes=0,
+        ),
+    )
+    batch_size = 2
+    time_steps = 5
+    generator = torch.Generator(device="cuda").manual_seed(2345)
+    x_cuda = torch.randn(batch_size, time_steps, 8, device="cuda", generator=generator).requires_grad_(True)
+    x_pytorch = x_cuda.detach().clone().requires_grad_(True)
+    resets = torch.tensor(
+        [
+            [False, True, False, False, True],
+            [False, False, True, False, False],
+        ],
+        dtype=torch.bool,
+        device="cuda",
+    )
+
+    y_cuda, _state_cuda = cuda_model(x_cuda, state=None, resets=resets, k=1)
+    y_pytorch, _state_pytorch = _fabric_forward_reference(pytorch_model, x_pytorch, state=None, resets=resets, k=1)
+    torch.testing.assert_close(y_cuda, y_pytorch, rtol=1e-4, atol=2e-5)
+
+    y_cuda.square().mean().backward()
+    y_pytorch.square().mean().backward()
+
+    record = cuda_model.runtime.last_backend_execution
+    assert record is not None
+    assert "temporal_artifacts:store_step_artifacts" in record.backward_recompute_mode
+    assert any("artifact_mode=store_step_artifacts" in item for item in record.backward_recompute_mode)
+    assert record.launch_scan_implementations == ("registered_temporal_fused_forward_program_cuda",)
+    assert "registered_temporal_fused_forward_program_cuda" in record.physical_op_executors
+    assert x_cuda.grad is not None
+    assert x_pytorch.grad is not None
+    torch.testing.assert_close(x_cuda.grad, x_pytorch.grad, rtol=2e-3, atol=2e-3)
+    cuda_grads = _param_grads(cuda_model)
+    pytorch_grads = _param_grads(pytorch_model)
+    assert cuda_grads.keys() == pytorch_grads.keys()
+    for name in cuda_grads:
+        torch.testing.assert_close(cuda_grads[name], pytorch_grads[name], rtol=3e-3, atol=3e-3)
+
+
+@pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="CUDA required for forced flat-bucket single-population parity test",
+)
+@pytest.mark.parametrize(("family", "spec"), [("slstm", _make_slstm_spec()), ("axoncell", _make_axon_spec())])
+def test_fabric_cuda_single_population_flat_bucket_route_matches_pytorch_reference(family: str, spec) -> None:
+    del family
+    torch.manual_seed(1234)
+    torch.cuda.manual_seed_all(1234)
+    cuda_model, pytorch_model = _build_fabric_model_pair(spec, d_hidden=32)
+    _force_flat_bucket_sequence_route(cuda_model.runtime)
+    batch_size = 2
+    time_steps = 3
+    generator = torch.Generator(device="cuda").manual_seed(1234)
+    x_cuda = torch.randn(batch_size, time_steps, 32, device="cuda", generator=generator).requires_grad_(True)
+    x_pytorch = x_cuda.detach().clone().requires_grad_(True)
+    resets = torch.tensor(
+        [
+            [False, True, False],
+            [False, False, True],
+        ],
+        dtype=torch.bool,
+        device="cuda",
+    )
+
+    y_cuda, _state_cuda = cuda_model(x_cuda, state=None, resets=resets, k=1)
+    y_pytorch, _state_pytorch = _fabric_forward_reference(pytorch_model, x_pytorch, state=None, resets=resets, k=1)
+    torch.testing.assert_close(y_cuda, y_pytorch, rtol=1e-4, atol=2e-5)
+
+    y_cuda.square().mean().backward()
+    y_pytorch.square().mean().backward()
+
+    record = cuda_model.runtime.last_backend_execution
+    assert record is not None
+    assert record.backend_name == "cuda"
+    assert record.surface_key == "registered_temporal_sequence_surface"
+    assert record.cell_type == "bucketed"
+    assert record.launch_temporal_executions == ("temporal_bucket_sequence",)
+    assert record.launch_scan_implementations == ("registered_temporal_fused_forward_program_cuda",)
+    assert "registered_temporal_fused_forward_program_cuda" in record.physical_op_executors
+    assert "single_bucket_sequence_executor" not in record.physical_op_executors
+    assert x_cuda.grad is not None
+    assert x_pytorch.grad is not None
+    torch.testing.assert_close(x_cuda.grad, x_pytorch.grad, rtol=5e-3, atol=5e-3)
+    cuda_grads = _param_grads(cuda_model)
+    pytorch_grads = _param_grads(pytorch_model)
+    assert cuda_grads.keys() == pytorch_grads.keys()
+    for name in cuda_grads:
+        torch.testing.assert_close(cuda_grads[name], pytorch_grads[name], rtol=8e-3, atol=8e-3)
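+# The single-population parity tests above and below repeat one evidence pattern:
+# numerical parity against the PyTorch reference plus execution-record assertions
+# that the registered fused CUDA program owned the launch. A minimal sketch of the
+# shared record check (hypothetical helper, not defined in this file):
+#
+#     def _assert_registered_fused_owner_sketch(record) -> None:
+#         assert record.launch_scan_implementations == (
+#             "registered_temporal_fused_forward_program_cuda",
+#         )
+#         assert "registered_temporal_fused_forward_program_cuda" in record.physical_op_executors
+#         assert "single_bucket_sequence_executor" not in record.physical_op_executors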
+ + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA required for K>1 temporal bucket single-population parity test", +) +@pytest.mark.parametrize( + ("family", "spec"), + [ + ("slstm", _make_slstm_spec(k_max=4, default_k=2)), + ("axoncell", _make_axon_spec(k_max=4, default_k=2)), + ], +) +def test_fabric_cuda_single_population_k_gt1_uses_temporal_bucket_sequence( + family: str, + spec, +) -> None: + del family + torch.manual_seed(1234) + torch.cuda.manual_seed_all(1234) + cuda_model, pytorch_model = _build_fabric_model_pair(spec, d_hidden=16) + batch_size = 2 + time_steps = 3 + generator = torch.Generator(device="cuda").manual_seed(1234) + x_cuda = torch.randn(batch_size, time_steps, 16, device="cuda", generator=generator).requires_grad_(True) + x_pytorch = x_cuda.detach().clone().requires_grad_(True) + resets = torch.tensor( + [ + [False, True, False], + [False, False, True], + ], + dtype=torch.bool, + device="cuda", + ) + + y_cuda, state_cuda = cuda_model(x_cuda, state=None, resets=resets, k=2) + y_pytorch, state_pytorch = pytorch_model(x_pytorch, state=None, resets=resets, k=2) + torch.testing.assert_close(y_cuda, y_pytorch, rtol=2e-4, atol=4e-5) + _assert_fabric_semantic_state_close(state_cuda, state_pytorch, rtol=1e-3, atol=2e-4) + + (y_cuda.square().mean() + 0.01 * _state_square_mean(state_cuda)).backward() + (y_pytorch.square().mean() + 0.01 * _state_square_mean(state_pytorch)).backward() + + record = cuda_model.runtime.last_backend_execution + assert record is not None + assert record.backend_name == "cuda" + assert record.surface_key == "registered_temporal_sequence_surface" + assert record.cell_type == "bucketed" + assert record.time_steps == time_steps + assert record.inner_steps == 2 + assert record.launch_temporal_executions == ("temporal_bucket_sequence",) + assert record.launch_temporal_scan_owners == ("registered_fused_forward_program_cuda",) + assert record.launch_scan_implementations == ("registered_temporal_fused_forward_program_cuda",) + assert record.launch_temporal_scan_physical_steps == ("6",) + assert record.launch_temporal_scan_emission_counts == ("3",) + assert record.launch_temporal_scan_output_boundaries == ("sequence",) + assert "single_bucket_sequence_executor" not in record.physical_op_executors + assert "physical_temporal_bucket_sequence_backward" in record.backward_physical_op_executors + assert x_cuda.grad is not None + assert x_pytorch.grad is not None + torch.testing.assert_close(x_cuda.grad, x_pytorch.grad, rtol=5e-3, atol=5e-3) + cuda_grads = _param_grads(cuda_model) + pytorch_grads = _param_grads(pytorch_model) + assert cuda_grads.keys() == pytorch_grads.keys() + for name in cuda_grads: + torch.testing.assert_close(cuda_grads[name], pytorch_grads[name], rtol=8e-3, atol=8e-3) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA required for one-bucket terminal temporal replay parity test", +) +@pytest.mark.parametrize( + ("family", "spec"), + [ + ("slstm", _make_slstm_spec(k_max=4, default_k=2, gradient_horizon_steps=2)), + ("axoncell", _make_axon_spec(k_max=4, default_k=2, gradient_horizon_steps=2)), + ], +) +def test_fabric_cuda_single_population_terminal_replay_uses_temporal_superop( + family: str, + spec, +) -> None: + del family + torch.manual_seed(20260428) + torch.cuda.manual_seed_all(20260428) + cuda_model, pytorch_model = _build_fabric_model_pair(spec, d_hidden=16) + batch_size = 2 + time_steps = 3 + generator = torch.Generator(device="cuda").manual_seed(20260428) + x_cuda = torch.randn(batch_size, 
time_steps, 16, device="cuda", generator=generator).requires_grad_(True) + x_pytorch = x_cuda.detach().clone().requires_grad_(True) + resets = torch.tensor( + [ + [False, True, False], + [True, False, False], + ], + dtype=torch.bool, + device="cuda", + ) + + y_cuda, state_cuda = cuda_model( + x_cuda, + state=None, + resets=resets, + k=2, + materialize_final_state=False, + output_boundary="terminal", + ) + y_pytorch, state_pytorch = pytorch_model( + x_pytorch, + state=None, + resets=resets, + k=2, + materialize_final_state=False, + output_boundary="terminal", + ) + assert tuple(cast(TensorDictBase, state_cuda).keys()) == () + assert tuple(cast(TensorDictBase, state_pytorch).keys()) == () + assert y_cuda.shape == y_pytorch.shape == (batch_size, 1, 16) + torch.testing.assert_close(y_cuda, y_pytorch, rtol=3e-4, atol=8e-5) + + y_cuda.square().mean().backward() + y_pytorch.square().mean().backward() + + record = cuda_model.runtime.last_backend_execution + assert record is not None + assert record.surface_key == "registered_temporal_sequence_surface" + assert ( + cuda_model.runtime._last_flat_bucket_temporal_artifact_recompute_owner + == "registered_fused_forward_program_tensor_store_direct" + ) + assert record.temporal_plan_gradient_boundaries == ("rolling_horizon",) + assert record.temporal_plan_horizon_steps == ("2",) + assert record.launch_temporal_scan_owners == ("registered_fused_forward_program_cuda",) + assert record.launch_scan_implementations == ("registered_temporal_fused_forward_program_cuda",) + assert any( + "reverse_owner=registered_fused_reverse_program_tensor_table" in item for item in record.backward_recompute_mode + ) + assert "physical_temporal_bucket_sequence_backward" in record.backward_physical_op_executors + assert "single_bucket_sequence_executor" not in record.physical_op_executors + assert x_cuda.grad is not None + assert x_pytorch.grad is not None + torch.testing.assert_close(x_cuda.grad, x_pytorch.grad, rtol=7e-3, atol=7e-3) + cuda_grads = _param_grads(cuda_model) + pytorch_grads = _param_grads(pytorch_model) + assert cuda_grads.keys() == pytorch_grads.keys() + for name in cuda_grads: + torch.testing.assert_close(cuda_grads[name], pytorch_grads[name], rtol=1e-2, atol=1e-2) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA required for delayed K>1 temporal bucket parity test", +) +def test_fabric_cuda_single_population_k_gt1_delay_uses_schedule_message_steps() -> None: + torch.manual_seed(20260428) + torch.cuda.manual_seed_all(20260428) + spec = _make_slstm_spec( + hidden_size=8, + patch_edges_per_cell=2, + max_delay=4, + conduction_speed=1.0, + k_max=4, + default_k=3, + ) + cuda_model, pytorch_model = _build_fabric_model_pair(spec, d_hidden=8) + _force_flat_bucket_sequence_route(cuda_model.runtime) + batch_size = 2 + time_steps = 2 + generator = torch.Generator(device="cuda").manual_seed(20260428) + x_cuda = torch.randn(batch_size, time_steps, 8, device="cuda", generator=generator).requires_grad_(True) + x_pytorch = x_cuda.detach().clone().requires_grad_(True) + resets = torch.tensor( + [ + [False, True], + [True, False], + ], + dtype=torch.bool, + device="cuda", + ) + + with pytest.raises( + RuntimeError, + match="compiler-owned CUDA temporal table scan", + ): + cuda_model(x_cuda, state=None, resets=resets, k=3) + y_pytorch, state_pytorch = pytorch_model(x_pytorch, state=None, resets=resets, k=3) + (y_pytorch.square().mean() + 0.01 * _state_square_mean(state_pytorch)).backward() + assert y_pytorch.shape == (batch_size, time_steps, 8) + assert 
x_pytorch.grad is not None + pytorch_grads = _param_grads(pytorch_model) + assert pytorch_grads + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA required for shared temporal engine inference parity test", +) +@pytest.mark.parametrize( + ("family", "spec"), + [ + ("slstm", _make_slstm_spec(k_max=4, default_k=2)), + ("axoncell", _make_axon_spec(k_max=4, default_k=2)), + ], +) +@pytest.mark.parametrize(("k", "expected_inner", "expected_physical"), [(1, 1, "3"), (2, 2, "6")]) +def test_fabric_cuda_single_flat_bucket_inference_uses_shared_temporal_engine( + family: str, + spec, + k: int, + expected_inner: int, + expected_physical: str, +) -> None: + torch.manual_seed(4321) + torch.cuda.manual_seed_all(4321) + cuda_model, pytorch_model = _build_fabric_model_pair(spec, d_hidden=16) + batch_size = 2 + time_steps = 3 + generator = torch.Generator(device="cuda").manual_seed(4321) + x_cuda = torch.randn(batch_size, time_steps, 16, device="cuda", generator=generator) + x_pytorch = x_cuda.detach().clone() + resets = torch.tensor( + [ + [False, True, False], + [True, False, True], + ], + dtype=torch.bool, + device="cuda", + ) + + with torch.no_grad(): + y_cuda, state_cuda = cuda_model(x_cuda, state=None, resets=resets, k=k, materialize_final_state=True) + y_pytorch, state_pytorch = pytorch_model( + x_pytorch, state=None, resets=resets, k=k, materialize_final_state=True + ) + torch.testing.assert_close(y_cuda, y_pytorch, rtol=2e-4, atol=4e-5) + _assert_fabric_semantic_state_close(state_cuda, state_pytorch, rtol=1e-3, atol=2e-4) + + record = cuda_model.runtime.last_backend_execution + assert record is not None + assert record.surface_key == "registered_temporal_sequence_surface" + assert record.time_steps == time_steps + assert record.inner_steps == expected_inner + expected_scan_impl = "registered_temporal_fused_forward_program_cuda" + expected_scan_owner = "registered_fused_forward_program_cuda" + assert record.launch_temporal_executions == ("temporal_bucket_sequence",) + assert record.launch_scan_implementations == (expected_scan_impl,) + assert record.launch_temporal_scan_owners == (expected_scan_owner,) + assert record.launch_temporal_scan_outer_steps == ("3",) + assert record.launch_temporal_scan_inner_steps == (str(expected_inner),) + assert record.launch_temporal_scan_physical_steps == (expected_physical,) + assert record.launch_temporal_scan_emission_counts == ("3",) + assert record.launch_temporal_scan_output_boundaries == ("sequence",) + assert "flat_bucket_temporal_scan" in record.capability_variants + assert expected_scan_impl in record.capability_variants + assert expected_scan_impl in record.physical_op_executors + assert "single_bucket_sequence_executor" not in record.physical_op_executors + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA required for output-only temporal superop pooled readout parity test", +) +@pytest.mark.parametrize( + ("family", "spec"), + [ + ("slstm", _make_slstm_spec(k_max=4, default_k=2)), + ("axoncell", _make_axon_spec(k_max=4, default_k=2)), + ], +) +@pytest.mark.parametrize("k", [1, 2]) +@pytest.mark.parametrize("use_resets", [False, True]) +def test_fabric_cuda_single_flat_bucket_output_only_pooled_readout_uses_temporal_superop( + family: str, + spec, + k: int, + use_resets: bool, +) -> None: + torch.manual_seed(2468) + torch.cuda.manual_seed_all(2468) + cuda_model, pytorch_model = _build_fabric_model_pair(spec, d_hidden=16) + batch_size = 2 + time_steps = 3 + generator = 
torch.Generator(device="cuda").manual_seed(2468) + x_cuda = torch.randn(batch_size, time_steps, 16, device="cuda", generator=generator) + x_pytorch = x_cuda.detach().clone() + resets = ( + torch.tensor( + [ + [False, True, False], + [True, False, False], + ], + dtype=torch.bool, + device="cuda", + ) + if use_resets + else None + ) + + with torch.no_grad(): + y_cuda, state_cuda = cuda_model( + x_cuda, + state=None, + resets=resets, + k=k, + materialize_final_state=False, + ) + y_pytorch, state_pytorch = pytorch_model( + x_pytorch, + state=None, + resets=resets, + k=k, + materialize_final_state=False, + ) + + assert tuple(cast(TensorDictBase, state_cuda).keys()) == () + assert tuple(cast(TensorDictBase, state_pytorch).keys()) == () + torch.testing.assert_close(y_cuda, y_pytorch, rtol=2e-4, atol=4e-5) record = cuda_model.runtime.last_backend_execution assert record is not None - assert record.backend_name == "cuda" - assert record.surface_key == "flat_bucket_sequence_surface" - assert record.cell_type == "bucketed" + assert record.surface_key == "registered_temporal_sequence_surface" + assert record.time_steps == time_steps + assert record.inner_steps == k assert record.launch_temporal_executions == ("temporal_bucket_sequence",) - assert record.launch_scan_implementations == ("stored_temporal_physical_scan",) - assert "flat_bucket_temporal_scan" in record.capability_variants - assert "stored_temporal_physical_scan" in record.physical_op_executors - assert "flat_bucket_temporal_scan:recurrent_kv_carry_reuse" in record.workspace_aliases + assert record.launch_scan_implementations == ("registered_temporal_fused_forward_program_cuda",) + assert record.launch_temporal_scan_owners == ("registered_fused_forward_program_cuda",) + assert record.launch_temporal_scan_outer_steps == ("3",) + assert record.launch_temporal_scan_inner_steps == (str(k),) + assert record.launch_temporal_scan_physical_steps == (str(time_steps * k),) + assert record.launch_temporal_scan_output_boundaries == ("sequence",) assert record.active_receiver_window_modes == ("full_recurrent_closure",) - assert record.physical_op_demotions == ("active_region_closure_full_surface",) - assert "physical_temporal_bucket_sequence_backward" in record.backward_physical_op_executors - assert "physical_tiny_message_backward_executor" in record.backward_physical_op_executors - assert "physical_receiver_affine_backward_executor" in record.backward_physical_op_executors - assert "physical_state_epilogue_backward_executor" in record.backward_physical_op_executors - assert "physical_diagonal_recurrence_backward_executor" in record.backward_physical_op_executors - assert "explicit_readout_projection_thin_reverse" in record.backward_physical_op_executors - assert "autograd_transition_bucket_backward" not in record.backward_physical_op_executors - assert "autograd_readout_projection_backward" not in record.backward_physical_op_executors - assert record.backward_physical_op_demotions == ("active_region_closure_full_surface",) - assert x_cuda.grad is not None - assert x_pytorch.grad is not None - torch.testing.assert_close(x_cuda.grad, x_pytorch.grad, rtol=2e-3, atol=2e-3) - cuda_grads = _param_grads(cuda_model) - pytorch_grads = _param_grads(pytorch_model) - assert cuda_grads.keys() == pytorch_grads.keys() - for name in cuda_grads: - torch.testing.assert_close(cuda_grads[name], pytorch_grads[name], rtol=3e-3, atol=3e-3) + assert record.active_receiver_window_offsets == ("0",) + assert record.active_receiver_window_counts == ("8",) + assert 
"final_state=not_materialized" in record.workspace_aliases + if not use_resets: + assert any( + "flat_bucket_temporal_physical_strategy:selected_strategy=streaming_step_producer_consumer" in item + for item in record.workspace_aliases + ) + assert any( + "flat_bucket_temporal_registered_backward_memory_stage:stage=native_forward_after_streaming_message_release" + in item + for item in record.workspace_aliases + ) + else: + assert any( + "flat_bucket_temporal_physical_strategy:selected_strategy=stage_materialized" in item + for item in record.workspace_aliases + ) + assert "registered_temporal_fused_forward_program_cuda" in record.physical_op_executors + assert "single_bucket_sequence_executor" not in record.physical_op_executors -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for mixed population backward parity test") -def test_fabric_cuda_mixed_population_t_gt1_recomputed_artifacts_match_pytorch_reference( - monkeypatch: pytest.MonkeyPatch, -): - torch.manual_seed(1234) - torch.cuda.manual_seed_all(1234) - cuda_model, pytorch_model = _build_fabric_model_pair(_make_spec(), d_hidden=8) - monkeypatch.setattr( - cuda_model.runtime, - "_cuda_memory_budget", - lambda _device: CudaMemoryBudget( - usable_bytes=1 << 20, - total_bytes=8 << 30, - free_bytes=1 << 20, - reusable_reserved_bytes=0, - ), - ) +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA required for compact temporal superop active-window parity test", +) +@pytest.mark.parametrize( + ("family", "spec"), + [ + ("slstm", _make_slstm_spec(k_max=4, default_k=2)), + ("axoncell", _make_axon_spec(k_max=4, default_k=2)), + ], +) +@pytest.mark.parametrize( + ("k", "expected_window_mode", "expected_window_offset", "expected_window_count"), + [ + (1, "full_recurrent_closure", "0", "8"), + (2, "full_recurrent_closure", "0", "8"), + ], +) +@pytest.mark.parametrize("use_resets", [False, True]) +def test_fabric_cuda_single_flat_bucket_output_only_one_physical_step_window_uses_temporal_superop( + family: str, + spec, + k: int, + expected_window_mode: str, + expected_window_offset: str, + expected_window_count: str, + use_resets: bool, +) -> None: + torch.manual_seed(1357) + torch.cuda.manual_seed_all(1357) + cuda_model, pytorch_model = _build_fabric_model_pair(spec, d_hidden=16) batch_size = 2 - time_steps = 5 - generator = torch.Generator(device="cuda").manual_seed(2345) - x_cuda = torch.randn(batch_size, time_steps, 8, device="cuda", generator=generator).requires_grad_(True) - x_pytorch = x_cuda.detach().clone().requires_grad_(True) - resets = torch.tensor( - [ - [False, True, False, False, True], - [False, False, True, False, False], - ], - dtype=torch.bool, - device="cuda", + time_steps = 1 + generator = torch.Generator(device="cuda").manual_seed(1357) + x_cuda = torch.randn(batch_size, time_steps, 16, device="cuda", generator=generator) + x_pytorch = x_cuda.detach().clone() + resets = ( + torch.tensor( + [ + [False], + [True], + ], + dtype=torch.bool, + device="cuda", + ) + if use_resets + else None ) - y_cuda, _state_cuda = cuda_model(x_cuda, state=None, resets=resets, k=1) - y_pytorch, _state_pytorch = pytorch_model(x_pytorch, state=None, resets=resets, k=1) - torch.testing.assert_close(y_cuda, y_pytorch, rtol=1e-4, atol=2e-5) + with torch.no_grad(): + y_cuda, state_cuda = cuda_model( + x_cuda, + state=None, + resets=resets, + k=k, + materialize_final_state=False, + ) + y_pytorch, state_pytorch = pytorch_model( + x_pytorch, + state=None, + resets=resets, + k=k, + materialize_final_state=False, + ) 
- y_cuda.square().mean().backward() - y_pytorch.square().mean().backward() + assert tuple(cast(TensorDictBase, state_cuda).keys()) == () + assert tuple(cast(TensorDictBase, state_pytorch).keys()) == () + torch.testing.assert_close(y_cuda, y_pytorch, rtol=2e-4, atol=4e-5) record = cuda_model.runtime.last_backend_execution assert record is not None - assert "temporal_artifacts:recompute_step_artifacts" in record.backward_recompute_mode - assert any("artifact_mode=recompute_step_artifacts" in item for item in record.backward_recompute_mode) - assert record.launch_scan_implementations == ("windowed_temporal_physical_scan",) - assert "windowed_temporal_physical_scan" in record.physical_op_executors - assert x_cuda.grad is not None - assert x_pytorch.grad is not None - torch.testing.assert_close(x_cuda.grad, x_pytorch.grad, rtol=2e-3, atol=2e-3) - cuda_grads = _param_grads(cuda_model) - pytorch_grads = _param_grads(pytorch_model) - assert cuda_grads.keys() == pytorch_grads.keys() - for name in cuda_grads: - torch.testing.assert_close(cuda_grads[name], pytorch_grads[name], rtol=3e-3, atol=3e-3) + assert record.surface_key == "registered_temporal_sequence_surface" + assert record.time_steps == time_steps + assert record.inner_steps == k + assert record.launch_temporal_executions == ("temporal_bucket_sequence",) + assert record.launch_scan_implementations == ("registered_temporal_fused_forward_program_cuda",) + assert record.launch_temporal_scan_owners == ("registered_fused_forward_program_cuda",) + assert record.launch_temporal_scan_outer_steps == ("1",) + assert record.launch_temporal_scan_inner_steps == (str(k),) + assert record.launch_temporal_scan_physical_steps == (str(k),) + assert record.active_receiver_window_modes == (expected_window_mode,) + assert record.active_receiver_window_offsets == (expected_window_offset,) + assert record.active_receiver_window_counts == (expected_window_count,) + assert "final_state=not_materialized" in record.workspace_aliases + assert "registered_temporal_fused_forward_program_cuda" in record.physical_op_executors + assert "single_bucket_sequence_executor" not in record.physical_op_executors @pytest.mark.skipif( not torch.cuda.is_available(), - reason="CUDA required for forced flat-bucket single-population parity test", + reason="CUDA required for K>1 temporal bucket mixed-population parity test", ) -@pytest.mark.parametrize(("family", "spec"), [("slstm", _make_slstm_spec()), ("axoncell", _make_axon_spec())]) -def test_fabric_cuda_single_population_flat_bucket_route_matches_pytorch_reference(family: str, spec) -> None: - del family +def test_fabric_cuda_mixed_population_k_gt1_uses_temporal_bucket_sequence() -> None: torch.manual_seed(1234) torch.cuda.manual_seed_all(1234) - cuda_model, pytorch_model = _build_fabric_model_pair(spec, d_hidden=32) - _force_flat_bucket_sequence_route(cuda_model.runtime) + cuda_model, pytorch_model = _build_fabric_model_pair(_make_spec(), d_hidden=16) batch_size = 2 time_steps = 3 - generator = torch.Generator(device="cuda").manual_seed(1234) - x_cuda = torch.randn(batch_size, time_steps, 32, device="cuda", generator=generator).requires_grad_(True) + generator = torch.Generator(device="cuda").manual_seed(4321) + x_cuda = torch.randn(batch_size, time_steps, 16, device="cuda", generator=generator).requires_grad_(True) x_pytorch = x_cuda.detach().clone().requires_grad_(True) resets = torch.tensor( [ @@ -6511,22 +7290,28 @@ def test_fabric_cuda_single_population_flat_bucket_route_matches_pytorch_referen device="cuda", ) - y_cuda, 
_state_cuda = cuda_model(x_cuda, state=None, resets=resets, k=1) - y_pytorch, _state_pytorch = pytorch_model(x_pytorch, state=None, resets=resets, k=1) - torch.testing.assert_close(y_cuda, y_pytorch, rtol=1e-4, atol=2e-5) + y_cuda, state_cuda = cuda_model(x_cuda, state=None, resets=resets, k=2) + y_pytorch, state_pytorch = pytorch_model(x_pytorch, state=None, resets=resets, k=2) + torch.testing.assert_close(y_cuda, y_pytorch, rtol=2e-4, atol=4e-5) + _assert_fabric_semantic_state_close(state_cuda, state_pytorch, rtol=1e-3, atol=2e-4) - y_cuda.square().mean().backward() - y_pytorch.square().mean().backward() + (y_cuda.square().mean() + 0.01 * _state_square_mean(state_cuda)).backward() + (y_pytorch.square().mean() + 0.01 * _state_square_mean(state_pytorch)).backward() record = cuda_model.runtime.last_backend_execution assert record is not None assert record.backend_name == "cuda" - assert record.surface_key == "flat_bucket_sequence_surface" + assert record.surface_key == "registered_temporal_sequence_surface" assert record.cell_type == "bucketed" + assert record.time_steps == time_steps + assert record.inner_steps == 2 assert record.launch_temporal_executions == ("temporal_bucket_sequence",) - assert record.launch_scan_implementations == ("stored_temporal_physical_scan",) - assert "stored_temporal_physical_scan" in record.physical_op_executors + assert record.launch_temporal_scan_owners == ("registered_fused_forward_program_cuda",) + assert record.launch_temporal_scan_physical_steps == ("6",) + assert record.launch_temporal_scan_emission_counts == ("3",) + assert record.launch_temporal_scan_output_boundaries == ("sequence",) assert "single_bucket_sequence_executor" not in record.physical_op_executors + assert "physical_temporal_bucket_sequence_backward" in record.backward_physical_op_executors assert x_cuda.grad is not None assert x_pytorch.grad is not None torch.testing.assert_close(x_cuda.grad, x_pytorch.grad, rtol=5e-3, atol=5e-3) @@ -6539,78 +7324,116 @@ def test_fabric_cuda_single_population_flat_bucket_route_matches_pytorch_referen @pytest.mark.skipif( not torch.cuda.is_available(), - reason="CUDA required for K>1 temporal bucket single-population parity test", + reason="CUDA required for K>1 mixed-population terminal-loss parity test", ) -@pytest.mark.parametrize( - ("family", "spec"), - [ - ("slstm", _make_slstm_spec(k_max=4, default_k=2)), - ("axoncell", _make_axon_spec(k_max=4, default_k=2)), - ], -) -def test_fabric_cuda_single_population_k_gt1_uses_temporal_bucket_sequence( - family: str, - spec, -) -> None: - del family +def test_fabric_cuda_mixed_population_k_gt1_terminal_loss_maps_final_outer_emission_gradient() -> None: torch.manual_seed(1234) torch.cuda.manual_seed_all(1234) - cuda_model, pytorch_model = _build_fabric_model_pair(spec, d_hidden=16) + cuda_model, pytorch_model = _build_fabric_model_pair( + _make_spec(gradient_horizon_steps=2, k_max=4, default_k=2), + d_hidden=16, + ) batch_size = 2 time_steps = 3 - generator = torch.Generator(device="cuda").manual_seed(1234) + generator = torch.Generator(device="cuda").manual_seed(6789) x_cuda = torch.randn(batch_size, time_steps, 16, device="cuda", generator=generator).requires_grad_(True) x_pytorch = x_cuda.detach().clone().requires_grad_(True) resets = torch.tensor( [ [False, True, False], - [False, False, True], + [True, False, False], ], dtype=torch.bool, device="cuda", ) - y_cuda, state_cuda = cuda_model(x_cuda, state=None, resets=resets, k=2) - y_pytorch, state_pytorch = pytorch_model(x_pytorch, state=None, 
resets=resets, k=2) + y_cuda, state_cuda = cuda_model( + x_cuda, + state=None, + resets=resets, + k=2, + materialize_final_state=False, + output_boundary="terminal", + ) + y_pytorch, state_pytorch = pytorch_model( + x_pytorch, + state=None, + resets=resets, + k=2, + materialize_final_state=False, + output_boundary="terminal", + ) + assert tuple(cast(TensorDictBase, state_cuda).keys()) == () + assert tuple(cast(TensorDictBase, state_pytorch).keys()) == () + assert y_cuda.shape == y_pytorch.shape == (batch_size, 1, 16) torch.testing.assert_close(y_cuda, y_pytorch, rtol=2e-4, atol=4e-5) - _assert_fabric_semantic_state_close(state_cuda, state_pytorch, rtol=1e-3, atol=2e-4) - (y_cuda.square().mean() + 0.01 * _state_square_mean(state_cuda)).backward() - (y_pytorch.square().mean() + 0.01 * _state_square_mean(state_pytorch)).backward() + y_cuda.square().mean().backward() + y_pytorch.square().mean().backward() record = cuda_model.runtime.last_backend_execution assert record is not None - assert record.backend_name == "cuda" - assert record.surface_key == "flat_bucket_sequence_surface" - assert record.cell_type == "bucketed" + assert record.surface_key == "registered_temporal_sequence_surface" + assert ( + cuda_model.runtime._last_flat_bucket_temporal_artifact_recompute_owner + == "registered_fused_forward_program_tensor_store_direct" + ) assert record.time_steps == time_steps assert record.inner_steps == 2 - assert record.launch_temporal_executions == ("temporal_bucket_sequence",) - assert "single_bucket_sequence_executor" not in record.physical_op_executors + assert record.temporal_plan_gradient_boundaries == ("rolling_horizon",) + assert record.temporal_plan_horizon_steps == ("2",) + assert record.temporal_plan_checkpoint_steps == ("2",) + assert record.launch_temporal_scan_owners == ("registered_fused_forward_program_cuda",) + assert record.launch_scan_implementations == ("registered_temporal_fused_forward_program_cuda",) + assert record.launch_temporal_scan_physical_steps == ("6",) + assert record.launch_temporal_scan_emission_counts == ("1",) + assert record.launch_temporal_scan_output_boundaries == ("terminal",) + assert "sequence_output_boundary:terminal_step" in record.workspace_aliases + assert "sequence_output_materialization:terminal_step_only" in record.workspace_aliases + assert "temporal_artifacts:store_step_artifacts" in record.workspace_aliases + assert "flat_bucket_state_cache:registered_fused_program_internal_state" in record.workspace_aliases + assert any( + "reverse_owner=registered_fused_reverse_program_tensor_table" in item for item in record.backward_recompute_mode + ) assert "physical_temporal_bucket_sequence_backward" in record.backward_physical_op_executors + assert "cuda_temporal_backward_glue" in record.backward_physical_op_executors + assert "temporal_backward_glue:registered_fused_backward_program_span" in record.backward_launch_counts + assert "temporal_backward_glue:registered_fused_backward_program_span_readout_message_kv" in ( + record.backward_launch_counts + ) + assert "temporal_backward_glue:registered_fused_backward_program_span_transition_boundary" in ( + record.backward_launch_counts + ) + _assert_registered_reverse_program_window_owned(record) assert x_cuda.grad is not None assert x_pytorch.grad is not None - torch.testing.assert_close(x_cuda.grad, x_pytorch.grad, rtol=5e-3, atol=5e-3) + torch.testing.assert_close(x_cuda.grad, x_pytorch.grad, rtol=6e-3, atol=6e-3) cuda_grads = _param_grads(cuda_model) pytorch_grads = _param_grads(pytorch_model) assert 
cuda_grads.keys() == pytorch_grads.keys() for name in cuda_grads: - torch.testing.assert_close(cuda_grads[name], pytorch_grads[name], rtol=8e-3, atol=8e-3) + torch.testing.assert_close(cuda_grads[name], pytorch_grads[name], rtol=1e-2, atol=1e-2) @pytest.mark.skipif( not torch.cuda.is_available(), - reason="CUDA required for K>1 temporal bucket mixed-population parity test", + reason="CUDA required for K>1 mixed-population provided-state gradient parity test", ) -def test_fabric_cuda_mixed_population_k_gt1_uses_temporal_bucket_sequence() -> None: +def test_fabric_cuda_mixed_population_k_gt1_terminal_loss_propagates_provided_state_gradients() -> None: torch.manual_seed(1234) torch.cuda.manual_seed_all(1234) - cuda_model, pytorch_model = _build_fabric_model_pair(_make_spec(), d_hidden=16) + cuda_model, pytorch_model = _build_fabric_model_pair( + _make_spec(gradient_horizon_steps=2, k_max=4, default_k=2), + d_hidden=16, + ) batch_size = 2 time_steps = 3 - generator = torch.Generator(device="cuda").manual_seed(4321) + generator = torch.Generator(device="cuda").manual_seed(2468) x_cuda = torch.randn(batch_size, time_steps, 16, device="cuda", generator=generator).requires_grad_(True) x_pytorch = x_cuda.detach().clone().requires_grad_(True) + base_state = cuda_model.runtime.init_state(batch_size, device="cuda", dtype=torch.float32) + state_cuda_in = _clone_fabric_state_with_grad(base_state) + state_pytorch_in = _clone_fabric_state_with_grad(base_state) resets = torch.tensor( [ [False, True, False], @@ -6620,8 +7443,23 @@ def test_fabric_cuda_mixed_population_k_gt1_uses_temporal_bucket_sequence() -> N device="cuda", ) - y_cuda, state_cuda = cuda_model(x_cuda, state=None, resets=resets, k=2) - y_pytorch, state_pytorch = pytorch_model(x_pytorch, state=None, resets=resets, k=2) + y_cuda, state_cuda = cuda_model( + x_cuda, + state=state_cuda_in, + resets=resets, + k=2, + materialize_final_state=True, + output_boundary="terminal", + ) + y_pytorch, state_pytorch = pytorch_model( + x_pytorch, + state=state_pytorch_in, + resets=resets, + k=2, + materialize_final_state=True, + output_boundary="terminal", + ) + assert y_cuda.shape == y_pytorch.shape == (batch_size, 1, 16) torch.testing.assert_close(y_cuda, y_pytorch, rtol=2e-4, atol=4e-5) _assert_fabric_semantic_state_close(state_cuda, state_pytorch, rtol=1e-3, atol=2e-4) @@ -6630,22 +7468,98 @@ def test_fabric_cuda_mixed_population_k_gt1_uses_temporal_bucket_sequence() -> N record = cuda_model.runtime.last_backend_execution assert record is not None - assert record.backend_name == "cuda" - assert record.surface_key == "flat_bucket_sequence_surface" - assert record.cell_type == "bucketed" + assert record.surface_key == "registered_temporal_sequence_surface" + assert ( + cuda_model.runtime._last_flat_bucket_temporal_artifact_recompute_owner + == "registered_fused_forward_program_tensor_store_direct" + ) assert record.time_steps == time_steps assert record.inner_steps == 2 - assert record.launch_temporal_executions == ("temporal_bucket_sequence",) - assert "single_bucket_sequence_executor" not in record.physical_op_executors + assert record.temporal_plan_gradient_boundaries == ("rolling_horizon",) + assert record.temporal_plan_horizon_steps == ("2",) + assert record.temporal_plan_checkpoint_steps == ("2",) + assert record.temporal_plan_resets == ("present",) + assert record.launch_temporal_scan_owners == ("registered_fused_forward_program_cuda",) + assert record.launch_scan_implementations == ("registered_temporal_fused_forward_program_cuda",) + assert 
record.launch_temporal_scan_physical_steps == ("6",) + assert record.launch_temporal_scan_emission_counts == ("1",) + assert record.launch_temporal_scan_output_boundaries == ("terminal",) + assert "sequence_output_boundary:terminal_step" in record.workspace_aliases + assert "sequence_output_materialization:terminal_step_only" in record.workspace_aliases + assert "final_state=materialized" in record.workspace_aliases + assert "temporal_artifacts:store_step_artifacts" in record.workspace_aliases + assert "flat_bucket_state_cache:registered_fused_program_final_tensor_table" in record.workspace_aliases + assert any( + "reverse_owner=registered_fused_reverse_program_tensor_table" in item for item in record.backward_recompute_mode + ) + assert "physical_temporal_bucket_sequence_backward" in record.backward_physical_op_executors + assert "cuda_temporal_backward_glue" in record.backward_physical_op_executors + assert "temporal_backward_glue:registered_fused_backward_program_span" in record.backward_launch_counts + assert "temporal_backward_glue:registered_fused_backward_program_span_readout_message_kv" in ( + record.backward_launch_counts + ) + assert "temporal_backward_glue:registered_fused_backward_program_span_transition_boundary" in ( + record.backward_launch_counts + ) + _assert_registered_reverse_program_window_owned(record) + assert x_cuda.grad is not None + assert x_pytorch.grad is not None + torch.testing.assert_close(x_cuda.grad, x_pytorch.grad, rtol=6e-3, atol=6e-3) + state_grads_cuda = _fabric_state_grads(state_cuda_in) + state_grads_pytorch = _fabric_state_grads(state_pytorch_in) + assert state_grads_cuda.keys() == state_grads_pytorch.keys() + for name in state_grads_cuda: + torch.testing.assert_close(state_grads_cuda[name], state_grads_pytorch[name], rtol=1e-2, atol=1e-2) + cuda_grads = _param_grads(cuda_model) + pytorch_grads = _param_grads(pytorch_model) + assert cuda_grads.keys() == pytorch_grads.keys() + for name in cuda_grads: + torch.testing.assert_close(cuda_grads[name], pytorch_grads[name], rtol=1e-2, atol=1e-2) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA required for K=128 mixed-population temporal backward smoke", +) +def test_fabric_cuda_mixed_population_k128_backward_maps_outer_emission_gradients() -> None: + torch.manual_seed(1234) + torch.cuda.manual_seed_all(1234) + cuda_model, pytorch_model = _build_fabric_model_pair( + _make_spec(gradient_horizon_steps=64, k_max=128, default_k=128), + d_hidden=8, + ) + batch_size = 1 + time_steps = 1 + generator = torch.Generator(device="cuda").manual_seed(9876) + x_cuda = torch.randn(batch_size, time_steps, 8, device="cuda", generator=generator).requires_grad_(True) + x_pytorch = x_cuda.detach().clone().requires_grad_(True) + + y_cuda, _state_cuda = cuda_model(x_cuda, state=None, k=128) + y_pytorch, _state_pytorch = pytorch_model(x_pytorch, state=None, k=128) + torch.testing.assert_close(y_cuda, y_pytorch, rtol=5e-4, atol=5e-4) + + y_cuda.square().mean().backward() + y_pytorch.square().mean().backward() + + record = cuda_model.runtime.last_backend_execution + assert record is not None + assert record.surface_key == "registered_temporal_sequence_surface" + assert record.inner_steps == 128 + assert record.temporal_plan_horizon_steps == ("64",) + assert record.temporal_plan_checkpoint_steps == ("64",) assert "physical_temporal_bucket_sequence_backward" in record.backward_physical_op_executors + _assert_registered_reverse_program_window_owned(record) + assert 
"temporal_backward_glue:registered_fused_backward_program_span" in record.backward_launch_counts + assert "temporal_backward_glue:registered_fused_backward_program_span_transition_boundary" in ( + record.backward_launch_counts + ) + assert ( + "temporal_backward_glue:cuda_transition_input_projection_param_grad_window" not in record.backward_launch_counts + ) + assert record.temporal_plan_backward_owners == ("registered_reverse_executor_bindings",) assert x_cuda.grad is not None assert x_pytorch.grad is not None - torch.testing.assert_close(x_cuda.grad, x_pytorch.grad, rtol=5e-3, atol=5e-3) - cuda_grads = _param_grads(cuda_model) - pytorch_grads = _param_grads(pytorch_model) - assert cuda_grads.keys() == pytorch_grads.keys() - for name in cuda_grads: - torch.testing.assert_close(cuda_grads[name], pytorch_grads[name], rtol=8e-3, atol=8e-3) + torch.testing.assert_close(x_cuda.grad, x_pytorch.grad, rtol=1e-2, atol=1e-2) @pytest.mark.skipif( @@ -6654,7 +7568,6 @@ def test_fabric_cuda_mixed_population_k_gt1_uses_temporal_bucket_sequence() -> N ) @pytest.mark.parametrize(("family", "spec"), [("slstm", _make_slstm_spec()), ("axoncell", _make_axon_spec())]) def test_fabric_cuda_single_population_flat_bucket_forward_uses_sequence_executor(family: str, spec) -> None: - del family torch.manual_seed(1234) torch.cuda.manual_seed_all(1234) cuda_model, pytorch_model = _build_fabric_model_pair(spec, d_hidden=32) @@ -6675,17 +7588,23 @@ def test_fabric_cuda_single_population_flat_bucket_forward_uses_sequence_executo with torch.no_grad(): y_cuda, _state_cuda = cuda_model(x_cuda, state=None, resets=resets, k=1) - y_pytorch, _state_pytorch = pytorch_model(x_pytorch, state=None, resets=resets, k=1) + y_pytorch, _state_pytorch = _fabric_forward_reference(pytorch_model, x_pytorch, state=None, resets=resets, k=1) torch.testing.assert_close(y_cuda, y_pytorch, rtol=1e-4, atol=2e-5) record = cuda_model.runtime.last_backend_execution assert record is not None assert record.backend_name == "cuda" - assert record.surface_key == "flat_bucket_sequence_surface" + assert record.surface_key == "registered_temporal_sequence_surface" assert record.cell_type == "bucketed" + expected_scan_impl = "registered_temporal_fused_forward_program_cuda" + expected_scan_owner = "registered_fused_forward_program_cuda" assert record.launch_temporal_executions == ("temporal_bucket_sequence",) - assert record.launch_scan_implementations == ("flat_bucket_temporal_scan",) - assert "flat_bucket_temporal_scan" in record.physical_op_executors + assert record.launch_scan_implementations == (expected_scan_impl,) + assert record.launch_temporal_scan_owners == (expected_scan_owner,) + assert record.launch_temporal_scan_outer_steps == (str(time_steps),) + assert record.launch_temporal_scan_inner_steps == ("1",) + assert record.launch_temporal_scan_physical_steps == (str(time_steps),) + assert expected_scan_impl in record.physical_op_executors assert "single_bucket_sequence_executor" not in record.physical_op_executors @@ -6740,17 +7659,267 @@ def test_fabric_cuda_mixed_population_readout_closed_region_matches_pytorch_refe record = cuda_runtime.last_backend_execution assert record is not None assert record.backend_name == "cuda" - assert record.surface_key == "flat_bucket_sequence_surface" + assert record.surface_key == "registered_temporal_sequence_surface" assert record.cell_type == "bucketed" assert record.launch_temporal_executions == ("temporal_bucket_sequence",) - assert record.launch_scan_implementations == ("flat_bucket_temporal_scan",) - assert 
"flat_bucket_temporal_scan" in record.physical_op_executors + assert record.launch_scan_implementations == ("registered_temporal_fused_forward_program_cuda",) + assert "registered_temporal_fused_forward_program_cuda" in record.physical_op_executors assert record.active_receiver_window_modes == ("full_recurrent_closure",) torch.testing.assert_close(y_cuda, y_reference, rtol=1e-4, atol=2e-5) assert isinstance(state_cuda, TensorDictBase) assert tuple(state_cuda.keys()) == () +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA required for fused registered forward program route test", +) +def test_fabric_cuda_fused_forward_program_owns_no_artifact_output_cells() -> None: + torch.manual_seed(1234) + torch.cuda.manual_seed_all(1234) + spec = _make_spec(default_k=1) + cuda_runtime = _build_runtime_for_backend(spec, "cuda") + registered_runtime = _build_runtime_for_backend(spec, "cuda") + registered_runtime.load_state_dict(cuda_runtime.state_dict()) + batch_size = 2 + time_steps = 2 + generator = torch.Generator(device="cuda").manual_seed(4321) + boundary_seq = torch.randn( + batch_size, + time_steps, + cuda_runtime.input_cell_idx.numel(), + cuda_runtime.hidden_size, + device="cuda", + generator=generator, + ) + state_cuda = cuda_runtime.init_state(batch_size, device="cuda", dtype=torch.float32) + state_registered = registered_runtime.init_state(batch_size, device="cuda", dtype=torch.float32) + + with torch.no_grad(): + y_cuda, state_cuda = cuda_runtime.forward_output_cells_for_readout( + boundary_input=boundary_seq, + state=state_cuda, + resets=None, + k=1, + training_semantics=False, + materialize_final_state=False, + output_boundary="sequence", + ) + record = cuda_runtime.last_backend_execution + assert record is not None + assert record.launch_scan_implementations == ("registered_temporal_fused_forward_program_cuda",) + assert record.launch_temporal_scan_owners == ("registered_fused_forward_program_cuda",) + assert "registered_temporal_fused_forward_program_cuda" in record.physical_op_executors + assert any( + "flat_bucket_temporal_physical_strategy:selected_strategy=streaming_step_producer_consumer" in item + for item in record.workspace_aliases + ) + assert any( + "flat_bucket_temporal_physical_strategy:reason=active_strategy=streaming_step_producer_consumer" in item + and "streaming_step_strategy=registered_program_body" in item + for item in record.workspace_aliases + ) + assert any( + "flat_bucket_temporal_registered_backward_memory_stage:stage=native_forward_after_streaming_message_release" + in item + for item in record.workspace_aliases + ) + + with torch.no_grad(): + y_registered, _state_registered = registered_runtime.forward_output_cells_for_readout( + boundary_input=boundary_seq, + state=state_registered, + resets=torch.zeros(batch_size, time_steps, dtype=torch.bool, device="cuda"), + k=1, + training_semantics=False, + materialize_final_state=False, + output_boundary="sequence", + ) + + torch.testing.assert_close(y_cuda, y_registered, rtol=1e-4, atol=2e-5) + assert isinstance(state_cuda, TensorDictBase) + assert tuple(state_cuda.keys()) == () + + materialized_runtime = _build_runtime_for_backend(spec, "cuda") + materialized_reference_runtime = _build_runtime_for_backend(spec, "cuda") + materialized_runtime.load_state_dict(cuda_runtime.state_dict()) + materialized_reference_runtime.load_state_dict(cuda_runtime.state_dict()) + materialized_state = materialized_runtime.init_state(batch_size, device="cuda", dtype=torch.float32) + materialized_reference_state = 
materialized_reference_runtime.init_state( + batch_size, + device="cuda", + dtype=torch.float32, + ) + reset_seq = torch.tensor( + [ + [False, True], + [True, False], + ], + dtype=torch.bool, + device="cuda", + ) + with torch.no_grad(): + y_materialized, state_materialized = materialized_runtime.forward_output_cells_for_readout( + boundary_input=boundary_seq, + state=materialized_state, + resets=reset_seq, + k=1, + training_semantics=False, + materialize_final_state=True, + output_boundary="sequence", + ) + record = materialized_runtime.last_backend_execution + assert record is not None + assert record.launch_scan_implementations == ("registered_temporal_fused_forward_program_cuda",) + assert record.launch_temporal_scan_owners == ("registered_fused_forward_program_cuda",) + assert "flat_bucket_state_cache:registered_fused_program_final_tensor_table" in record.workspace_aliases + + with torch.no_grad(): + y_materialized_reference, state_materialized_reference = ( + materialized_reference_runtime.forward_output_cells_for_readout( + boundary_input=boundary_seq, + state=materialized_reference_state, + resets=reset_seq, + k=1, + training_semantics=True, + materialize_final_state=True, + output_boundary="sequence", + ) + ) + torch.testing.assert_close(y_materialized, y_materialized_reference, rtol=1e-4, atol=2e-5) + _assert_fabric_state_close( + cast(TensorDictBase, state_materialized), + cast(TensorDictBase, state_materialized_reference), + rtol=1e-4, + atol=2e-5, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for reverse program-window test") +def test_fabric_cuda_registered_reverse_program_window_owns_no_carry_output_cells() -> None: + torch.manual_seed(1234) + torch.cuda.manual_seed_all(1234) + spec = _make_spec(default_k=1, gradient_horizon_steps=4) + program_runtime = _build_runtime_for_backend(spec, "cuda") + fallback_runtime = _build_runtime_for_backend(spec, "cuda") + fallback_runtime.load_state_dict(program_runtime.state_dict()) + batch_size = 2 + time_steps = 2 + generator = torch.Generator(device="cuda").manual_seed(4321) + boundary_program = torch.randn( + batch_size, + time_steps, + program_runtime.input_cell_idx.numel(), + program_runtime.hidden_size, + device="cuda", + generator=generator, + requires_grad=True, + ) + boundary_fallback = boundary_program.detach().clone().requires_grad_(True) + state_program = program_runtime.init_state(batch_size, device="cuda", dtype=torch.float32) + state_fallback = fallback_runtime.init_state(batch_size, device="cuda", dtype=torch.float32) + + y_program, state_program_out = program_runtime.forward_output_cells_for_readout( + boundary_input=boundary_program, + state=state_program, + resets=None, + k=1, + training_semantics=True, + materialize_final_state=False, + output_boundary="sequence", + ) + y_fallback, state_fallback_out = fallback_runtime.forward_output_cells_for_readout( + boundary_input=boundary_fallback, + state=state_fallback, + resets=None, + k=1, + training_semantics=True, + materialize_final_state=False, + output_boundary="sequence", + ) + assert isinstance(state_program_out, TensorDictBase) + assert isinstance(state_fallback_out, TensorDictBase) + assert tuple(state_program_out.keys()) == () + assert tuple(state_fallback_out.keys()) == () + torch.testing.assert_close(y_program, y_fallback, rtol=1e-4, atol=2e-5) + + y_program.square().mean().backward() + y_fallback.square().mean().backward() + + record = program_runtime.last_backend_execution + assert record is not None + assert 
record.launch_scan_implementations == ("registered_temporal_fused_forward_program_cuda",) + assert record.launch_temporal_scan_owners == ("registered_fused_forward_program_cuda",) + assert "registered_temporal_fused_forward_program_cuda" in record.physical_op_executors + assert ( + program_runtime._last_flat_bucket_temporal_artifact_recompute_owner + == "registered_fused_forward_program_tensor_store_direct" + ) + reverse_scan_aliases = record.workspace_aliases + record.backward_workspace_aliases + assert "temporal_artifacts:store_step_artifacts" in record.workspace_aliases + assert "flat_bucket_temporal_reverse_scan_owner:registered_reverse_program_window" in reverse_scan_aliases + assert "flat_bucket_temporal_reverse_scan_owner:registered_reverse_executor_bindings" not in reverse_scan_aliases + assert "flat_bucket_temporal_reverse_scan_owner:python_host_reverse_loop" not in reverse_scan_aliases + assert "flat_bucket_temporal_backward_binding_abi:registered_executor_binding_rows" in reverse_scan_aliases + assert "temporal_backward_glue:registered_fused_backward_program_span" in (record.backward_launch_counts) + assert "temporal_backward_glue:registered_fused_backward_program_span_readout_message_kv" in ( + record.backward_launch_counts + ) + assert "temporal_backward_glue:registered_fused_reverse_program_window_step" not in record.backward_launch_counts + assert "temporal_backward_glue:registered_fused_reverse_program_window_step_readout_message_kv" not in ( + record.backward_launch_counts + ) + assert "temporal_backward_glue:registered_fused_backward_program_output_grad_window" not in ( + record.backward_launch_counts + ) + assert "temporal_backward_glue:registered_fused_backward_program_readout_message_kv_step" not in ( + record.backward_launch_counts + ) + assert "temporal_backward_glue:registered_fused_backward_program_span_transition_boundary" in ( + record.backward_launch_counts + ) + assert "temporal_backward_glue:registered_fused_reverse_program_transition_boundary_step" not in ( + record.backward_launch_counts + ) + assert "temporal_backward_glue:registered_fused_backward_program_transition_stage" not in ( + record.backward_launch_counts + ) + assert "temporal_backward_glue:registered_fused_backward_program_readout_step" not in record.backward_launch_counts + assert "temporal_backward_glue:registered_fused_backward_program_output_message_step" not in ( + record.backward_launch_counts + ) + assert "temporal_backward_glue:registered_fused_backward_program_recurrent_kv_projection_step" not in ( + record.backward_launch_counts + ) + assert "temporal_backward_glue:registered_fused_backward_program_span_recurrent_message_boundary_initial_kv" in ( + record.backward_launch_counts + ) + assert ( + "temporal_backward_glue:registered_fused_backward_program_recurrent_message_boundary_initial_kv_step" + not in (record.backward_launch_counts) + ) + assert "temporal_backward_glue:registered_fused_backward_program_recurrent_message_initial_kv_step" not in ( + record.backward_launch_counts + ) + assert "temporal_backward_glue:registered_fused_backward_program_recurrent_message_step" not in ( + record.backward_launch_counts + ) + assert "temporal_backward_glue:registered_fused_backward_program_initial_recurrent_kv_projection_step" not in ( + record.backward_launch_counts + ) + assert "temporal_backward_glue:registered_fused_backward_program_boundary_kv_projection_step" not in ( + record.backward_launch_counts + ) + assert boundary_program.grad is not None + assert boundary_fallback.grad is not None 
+ torch.testing.assert_close(boundary_program.grad, boundary_fallback.grad, rtol=6e-3, atol=6e-3) + program_grads = _param_grads(program_runtime) + fallback_grads = _param_grads(fallback_runtime) + assert program_grads.keys() == fallback_grads.keys() + for name in program_grads: + torch.testing.assert_close(program_grads[name], fallback_grads[name], rtol=1e-2, atol=1e-2) + + def _reference_messages( runtime: Runtime, z_prev: torch.Tensor, @@ -6758,6 +7927,87 @@ def _reference_messages( q: torch.Tensor, gathered_kv_weight: torch.Tensor, step_idx: int | torch.Tensor, +) -> torch.Tensor: + lowering_kind = str(getattr(runtime.backend_ir.message_rule, "lowering_kind", "")) + if lowering_kind not in {"dot_product_fixed_slot_context_nudge", "dot_product_fixed_slot_context_gate"}: + return _reference_dynamic_messages( + runtime, + z_prev, + q=q, + gathered_kv_weight=gathered_kv_weight, + step_idx=step_idx, + ) + batch_size, time_steps, num_cells, _ = z_prev.shape + kv_all = torch.einsum("btnd,ndm->btnm", z_prev, gathered_kv_weight).view( + batch_size * time_steps, + num_cells, + runtime.head_dim + runtime.value_dim, + ) + _k_all, v_all = kv_all.split((runtime.head_dim, runtime.value_dim), dim=-1) + query_module = runtime.message_rule_modules["message_query_slot_proj"] + sender_module = runtime.message_rule_modules["message_sender_slot_key_proj"] + q_slot = query_module(runtime.slot_embed).view(num_cells, runtime.head_dim) + sender_slot_key = sender_module(runtime.slot_embed).view(num_cells, runtime.head_dim) + sender_context_key = runtime.message_rule_parameters["message_sender_context_key"] + query_context_scalar = runtime.message_rule_parameters.get("message_query_nudge_scale") + if query_context_scalar is None: + query_context_scalar = runtime.message_rule_parameters["message_query_context_gate"] + sender_key = torch.cat((sender_slot_key, sender_context_key), dim=-1) + k_neighbors = sender_key.index_select(0, runtime.neighbor_idx.reshape(-1)).view( + num_cells, + runtime.neighbor_idx.shape[1], + 2 * runtime.head_dim, + ) + v_neighbors = v_all.index_select(1, runtime.neighbor_idx.reshape(-1)).view( + batch_size * time_steps, + num_cells, + runtime.neighbor_idx.shape[1], + runtime.value_dim, + ) + q_context = v_all[:, :, : runtime.head_dim] * query_context_scalar.reshape(1, 1, 1).to(dtype=v_all.dtype) + q_full = torch.cat( + (q_slot.view(1, num_cells, runtime.head_dim).expand(batch_size * time_steps, -1, -1), q_context), dim=-1 + ) + logits = ( + q_full.view(batch_size * time_steps, num_cells, 1, 2 * runtime.head_dim) + * k_neighbors.view(1, num_cells, runtime.neighbor_idx.shape[1], 2 * runtime.head_dim) + ).sum(dim=-1) / math.sqrt(float(2 * runtime.head_dim)) + valid_windows = runtime.neighbor_valid.view(1, num_cells, -1) + if float(runtime.config.message.distance_logit_scale) > 0.0: + logits = logits - float(runtime.config.message.distance_logit_scale) * runtime.edge_distance.view( + 1, + num_cells, + -1, + ) + if runtime.spec.anatomy.edge_delay is not None: + valid_windows = valid_windows & _reference_delay_mask( + edge_delay=runtime.edge_delay, + step_idx=step_idx, + batch_size=batch_size, + time_steps=time_steps, + num_receivers=num_cells, + num_neighbors=int(runtime.neighbor_idx.shape[1]), + device=z_prev.device, + ) + weights = torch.softmax(logits.masked_fill(~valid_windows, float("-inf")).to(dtype=torch.float32), dim=-1).to( + dtype=v_neighbors.dtype + ) + weights = torch.where(valid_windows, weights, torch.zeros_like(weights)) + has_valid = valid_windows.any(dim=-1, keepdim=True) + 
weights = torch.where(has_valid, weights, torch.zeros_like(weights)) + weighted_value = torch.matmul(weights.unsqueeze(-2), v_neighbors).squeeze(-2) + projected = torch.nn.functional.linear(weighted_value, runtime.msg_out.weight) + projected = torch.nn.functional.layer_norm(projected, (int(projected.shape[-1]),), eps=1.0e-5) + return projected.view(batch_size, time_steps, num_cells, int(projected.shape[-1])) + + +def _reference_dynamic_messages( + runtime: Runtime, + z_prev: torch.Tensor, + *, + q: torch.Tensor, + gathered_kv_weight: torch.Tensor, + step_idx: int | torch.Tensor, ) -> torch.Tensor: batch_size, time_steps, num_cells, _ = z_prev.shape kv_all = torch.einsum("btnd,ndm->btnm", z_prev, gathered_kv_weight).view( @@ -6771,28 +8021,64 @@ def _reference_messages( v_neighbors = v_all[:, :, runtime.neighbor_idx, :] q_neighbors = q.view(1, 1, num_cells, 1, runtime.head_dim) logits = (q_neighbors * k_neighbors).sum(dim=-1) / math.sqrt(float(runtime.head_dim)) - invalid_mask = ~runtime.neighbor_valid.view(1, 1, num_cells, -1) - logits = logits.masked_fill(invalid_mask, float("-inf")) - if float(runtime.config.distance_logit_scale) > 0.0: - logits = logits - float(runtime.config.distance_logit_scale) * runtime.edge_distance.view(1, 1, num_cells, -1) + valid_windows = runtime.neighbor_valid.view(1, 1, num_cells, -1) + if float(runtime.config.message.distance_logit_scale) > 0.0: + logits = logits - float(runtime.config.message.distance_logit_scale) * runtime.edge_distance.view( + 1, + 1, + num_cells, + -1, + ) if runtime.spec.anatomy.edge_delay is not None: - if isinstance(step_idx, int): - step_view = step_idx + delay_mask = _reference_delay_mask( + edge_delay=runtime.edge_delay, + step_idx=step_idx, + batch_size=batch_size, + time_steps=time_steps, + num_receivers=num_cells, + num_neighbors=int(runtime.neighbor_idx.shape[1]), + device=z_prev.device, + ) + if int(delay_mask.shape[0]) == batch_size * time_steps: + delay_mask = delay_mask.view(batch_size, time_steps, num_cells, -1) else: - step_tensor = torch.as_tensor(step_idx, device=z_prev.device, dtype=runtime.edge_delay.dtype) - if step_tensor.dim() == 1 and step_tensor.shape[0] == batch_size: - step_view = step_tensor.view(batch_size, 1, 1, 1) - elif step_tensor.dim() == 2 and step_tensor.shape == (batch_size, time_steps): - step_view = step_tensor.view(batch_size, time_steps, 1, 1) - else: - raise ValueError(f"Unsupported step_idx shape {tuple(step_tensor.shape)}") - logits = logits.masked_fill(runtime.edge_delay.view(1, 1, num_cells, -1) > step_view, float("-inf")) + delay_mask = delay_mask.view(1, 1, num_cells, -1) + valid_windows = valid_windows & delay_mask + logits = logits.masked_fill(~valid_windows, float("-inf")) weights = torch.softmax(logits.to(dtype=torch.float32), dim=3).to(dtype=v_neighbors.dtype) - weights = torch.where(runtime.neighbor_valid.view(1, 1, num_cells, -1), weights, torch.zeros_like(weights)) + weights = torch.where(valid_windows, weights, torch.zeros_like(weights)) + has_valid = valid_windows.any(dim=3, keepdim=True) + weights = torch.where(has_valid, weights, torch.zeros_like(weights)) msg_heads = (weights.unsqueeze(-1) * v_neighbors).sum(dim=3) return runtime.msg_out(msg_heads.reshape(batch_size, time_steps, num_cells, runtime.value_dim)) +def _reference_delay_mask( + *, + edge_delay: torch.Tensor, + step_idx: int | torch.Tensor, + batch_size: int, + time_steps: int, + num_receivers: int, + num_neighbors: int, + device: torch.device, +) -> torch.Tensor: + if isinstance(step_idx, int): + return 
edge_delay.view(1, num_receivers, num_neighbors) <= step_idx + step_tensor = torch.as_tensor(step_idx, device=device, dtype=edge_delay.dtype) + if time_steps == 1: + if step_tensor.dim() != 1 or step_tensor.shape[0] != batch_size: + raise ValueError(f"step_idx tensor must have shape [B], got {tuple(step_tensor.shape)}") + return edge_delay.view(1, num_receivers, num_neighbors) <= step_tensor.view(batch_size, 1, 1) + if step_tensor.dim() == 1 and step_tensor.shape[0] == batch_size: + step_flat = step_tensor.view(batch_size, 1).expand(batch_size, time_steps).reshape(batch_size * time_steps) + elif step_tensor.dim() == 2 and step_tensor.shape == (batch_size, time_steps): + step_flat = step_tensor.reshape(batch_size * time_steps) + else: + raise ValueError(f"step_idx tensor must have shape [B] or [B,T], got {tuple(step_tensor.shape)}") + return edge_delay.view(1, num_receivers, num_neighbors) <= step_flat.view(batch_size * time_steps, 1, 1) + + def _reference_messages_step( runtime: Runtime, z_prev_step: torch.Tensor, @@ -6816,7 +8102,30 @@ def _reference_messages_step( ).squeeze(1) -def _reference_stream_step_k1_previous( +def _reference_dynamic_messages_step( + runtime: Runtime, + z_prev_step: torch.Tensor, + *, + q: torch.Tensor, + step_idx: int | torch.Tensor, +) -> torch.Tensor: + gathered_kv_weight = torch.cat( + ( + runtime.k_weight.index_select(0, runtime.kv_group_id), + runtime.v_weight.index_select(0, runtime.kv_group_id), + ), + dim=-1, + ) + return _reference_dynamic_messages( + runtime, + z_prev_step.unsqueeze(1), + q=q, + gathered_kv_weight=gathered_kv_weight, + step_idx=step_idx, + ).squeeze(1) + + +def _reference_stream_step_k1_declared_message_rule( runtime: Runtime, cells_prev: torch.Tensor, *, @@ -6871,14 +8180,22 @@ def _reference_stream_step_k1_previous( edge_delay=runtime.recurrent_edge_delay, use_delay=runtime.spec.anatomy.edge_delay is not None, step_idx=1, + owner_tag="recurrent", + ) + recurrent_input, recurrent_input_already_projected = ( + runtime._project_recurrent_message_to_cell_step_for_message_rule( + recurrent_msg, + value_to_cell_weight=value_to_cell_weight, + recurrent_cell_bias=recurrent_cell_bias, + ) ) - recurrent_input = torch.nn.functional.linear(recurrent_msg, value_to_cell_weight) + recurrent_cell_bias recurrent_next, next_population_state = runtime._run_population_updates_recurrent_step( recurrent_input, population_state, # type: ignore[arg-type] resets=None, batch_size=cells_prev.shape[0], population_materialized=population_materialized, + population_input_already_projected=recurrent_input_already_projected, ) if all_active is True: recurrent_mid = recurrent_next @@ -6921,6 +8238,7 @@ def _reference_stream_step_k1_previous( edge_delay=runtime.output_edge_delay, use_delay=runtime.spec.anatomy.edge_delay is not None, step_idx=k_rows, + owner_tag="readout", ) output_cells = _reference_project_output_cells_step_raw( output_msg, @@ -6939,7 +8257,7 @@ def _reference_stream_step_k1_previous( return cells_out, next_state -def _reference_stream_step_boundary_multistep_previous( +def _reference_stream_step_boundary_multistep_declared_message_rule( runtime: Runtime, cells_prev: torch.Tensor, *, @@ -6959,52 +8277,120 @@ def _reference_stream_step_boundary_multistep_previous( else: cells_prev[:, runtime.input_cell_idx, :] = boundary_step - y_prev = cells_prev.unsqueeze(1) + sender_kv_weight = gathered_kv_weight.index_select(0, runtime.sender_cell_idx) + sender_input_to_kv_weight = torch.einsum("dh,sdm->shm", runtime.public_proj.weight, sender_kv_weight) + 
input_sender_input_to_kv_weight = sender_input_to_kv_weight.index_select(0, runtime.input_sender_idx) + recurrent_sender_input_to_kv_weight = sender_input_to_kv_weight.index_select(0, runtime.recurrent_sender_idx) + value_to_cell_weight = runtime.msg_to_cell.weight @ runtime.msg_out.weight + value_to_output_weight = torch.einsum("dv,pdh->pvh", runtime.msg_out.weight, runtime.output_cell_weight) + if runtime._partitioned_layout: + recurrent_mid = cells_prev[:, runtime._recurrent_slice, :] + else: + recurrent_mid = cells_prev[:, runtime.recurrent_cell_idx, :] + input_k, input_v = _reference_project_sender_kv_from_cells_step( + boundary_step, + sender_input_to_kv_weight=input_sender_input_to_kv_weight, + head_dim=runtime.head_dim, + value_dim=runtime.value_dim, + ) running_population_state = population_state - boundary_step_seq = boundary_step.unsqueeze(1) for step_idx in range(max_steps): - z_prev = runtime.public_proj(y_prev) - msg = runtime._compute_messages( - z_prev, - q=q, - gathered_kv_weight=gathered_kv_weight, + recurrent_k, recurrent_v = _reference_project_sender_kv_from_cells_step( + recurrent_mid, + sender_input_to_kv_weight=recurrent_sender_input_to_kv_weight, + head_dim=runtime.head_dim, + value_dim=runtime.value_dim, + ) + if runtime._partitioned_layout: + k_all = torch.cat((input_k, recurrent_k), dim=1) + v_all = torch.cat((input_v, recurrent_v), dim=1) + else: + k_all = input_k.new_zeros(cells_prev.shape[0], runtime.sender_cell_idx.numel(), runtime.head_dim) + v_all = input_v.new_zeros(cells_prev.shape[0], runtime.sender_cell_idx.numel(), runtime.value_dim) + k_all[:, runtime.input_sender_idx, :] = input_k + v_all[:, runtime.input_sender_idx, :] = input_v + k_all[:, runtime.recurrent_sender_idx, :] = recurrent_k + v_all[:, runtime.recurrent_sender_idx, :] = recurrent_v + recurrent_msg = runtime._compute_messages_step_subset_raw( + k_all, + v_all, + q_subset=q.index_select(0, runtime.recurrent_cell_idx), + neighbor_idx=runtime.recurrent_neighbor_idx, + neighbor_valid=runtime.recurrent_neighbor_valid, + edge_distance=runtime.recurrent_edge_distance, + edge_delay=runtime.recurrent_edge_delay, + use_delay=runtime.spec.anatomy.edge_delay is not None, step_idx=step_idx + 1, + owner_tag="recurrent", + ) + recurrent_input, recurrent_input_already_projected = ( + runtime._project_recurrent_message_to_cell_step_for_message_rule( + recurrent_msg, + value_to_cell_weight=value_to_cell_weight, + recurrent_cell_bias=cell_bias[:, :, runtime.recurrent_cell_idx, :].squeeze(1), + ) ) - population_input = runtime.msg_to_cell(msg) + cell_bias - y_next, next_population_state = runtime._run_population_updates( - population_input, + recurrent_next, next_population_state = runtime._run_population_updates_recurrent_step( + recurrent_input, running_population_state, # type: ignore[arg-type] resets=population_resets, batch_size=cells_prev.shape[0], - time_steps=1, population_materialized=population_materialized, + population_input_already_projected=recurrent_input_already_projected, ) - if runtime._partitioned_layout: - y_next[:, :, runtime._input_slice, :] = boundary_step_seq - else: - y_next[:, :, runtime.input_cell_idx, :] = boundary_step_seq active_rows = step_idx < k_rows - y_prev = torch.where(active_rows.view(-1, 1, 1, 1), y_next, y_prev) + recurrent_mid = torch.where(active_rows.view(-1, 1, 1), recurrent_next, recurrent_mid) running_population_state = runtime._blend_population_states( running_population_state, # type: ignore[arg-type] next_population_state, active_rows, ) - final_z = 
runtime.public_proj(y_prev) - final_msg = runtime._compute_messages( - final_z, - q=q, - gathered_kv_weight=gathered_kv_weight, + recurrent_k, recurrent_v = _reference_project_sender_kv_from_cells_step( + recurrent_mid, + sender_input_to_kv_weight=recurrent_sender_input_to_kv_weight, + head_dim=runtime.head_dim, + value_dim=runtime.value_dim, + ) + if runtime._partitioned_layout: + final_k = torch.cat((input_k, recurrent_k), dim=1) + final_v = torch.cat((input_v, recurrent_v), dim=1) + else: + final_k = input_k.new_zeros(cells_prev.shape[0], runtime.sender_cell_idx.numel(), runtime.head_dim) + final_v = input_v.new_zeros(cells_prev.shape[0], runtime.sender_cell_idx.numel(), runtime.value_dim) + final_k[:, runtime.input_sender_idx, :] = input_k + final_v[:, runtime.input_sender_idx, :] = input_v + final_k[:, runtime.recurrent_sender_idx, :] = recurrent_k + final_v[:, runtime.recurrent_sender_idx, :] = recurrent_v + output_msg = runtime._compute_messages_step_subset_raw( + final_k, + final_v, + q_subset=q.index_select(0, runtime.output_cell_idx), + neighbor_idx=runtime.output_neighbor_idx, + neighbor_valid=runtime.output_neighbor_valid, + edge_distance=runtime.output_edge_distance, + edge_delay=runtime.output_edge_delay, + use_delay=runtime.spec.anatomy.edge_delay is not None, step_idx=k_rows, + owner_tag="readout", ) - y_out = y_prev.clone() - y_out[:, :, runtime.output_cell_idx, :] = runtime._project_output_cells(final_msg[:, :, runtime.output_cell_idx, :]) + output_cells = _reference_project_output_cells_step_raw( + output_msg, + value_to_output_weight=value_to_output_weight, + output_cell_bias=runtime.output_cell_bias, + ).to(dtype=cells_prev.dtype) + if runtime._partitioned_layout: + y_out = torch.cat((boundary_step, recurrent_mid, output_cells), dim=1) + else: + y_out = cells_prev.clone() + y_out[:, runtime.input_cell_idx, :] = boundary_step + y_out[:, runtime.recurrent_cell_idx, :] = recurrent_mid + y_out[:, runtime.output_cell_idx, :] = output_cells next_state = runtime.init_state(cells_prev.shape[0], device=cells_prev.device, dtype=cells_prev.dtype) - next_state["cells"] = y_out.squeeze(1) + next_state["cells"] = y_out for cell_type in runtime._population_names: next_state[cell_type] = running_population_state[cell_type] - return y_out.squeeze(1), next_state + return y_out, next_state def _reference_project_sender_kv_from_cells_step(