From 12030ae502701fae69e90bcf599cbedf4f6ab6fb Mon Sep 17 00:00:00 2001 From: Concode0 Date: Wed, 18 Mar 2026 23:01:02 +0900 Subject: [PATCH 01/16] feat: add Geometric Turing Machine v4 (PGA motor, rule memory, cross-grade attention) for ARC-AGI --- conf/task/gtm.yaml | 76 +++++++ datalib/__init__.py | 4 + datalib/arc.py | 412 ++++++++++++++++++++++++++++++++++ datalib/gtm.py | 17 ++ main.py | 2 + models/gtm/__init__.py | 35 +++ models/gtm/adaptive_halt.py | 95 ++++++++ models/gtm/analysis.py | 433 ++++++++++++++++++++++++++++++++++++ models/gtm/control_plane.py | 122 ++++++++++ models/gtm/cpu.py | 174 +++++++++++++++ models/gtm/grid_codec.py | 136 +++++++++++ models/gtm/gtm_net.py | 247 ++++++++++++++++++++ models/gtm/heads.py | 48 ++++ models/gtm/rule_memory.py | 88 ++++++++ models/gtm/superposition.py | 127 +++++++++++ models/gtm/turing_step.py | 225 +++++++++++++++++++ models/gtm/turing_vm.py | 161 ++++++++++++++ models/vm/__init__.py | 19 ++ models/vm/attention.py | 75 +++++++ models/vm/bridge.py | 46 ++++ models/vm/projections.py | 80 +++++++ pyproject.toml | 7 +- scripts/analyze_gtm.py | 97 ++++++++ tasks/__init__.py | 4 + tasks/gtm.py | 345 ++++++++++++++++++++++++++++ uv.lock | 17 +- 26 files changed, 3087 insertions(+), 5 deletions(-) create mode 100644 conf/task/gtm.yaml create mode 100644 datalib/arc.py create mode 100644 datalib/gtm.py create mode 100644 models/gtm/__init__.py create mode 100644 models/gtm/adaptive_halt.py create mode 100644 models/gtm/analysis.py create mode 100644 models/gtm/control_plane.py create mode 100644 models/gtm/cpu.py create mode 100644 models/gtm/grid_codec.py create mode 100644 models/gtm/gtm_net.py create mode 100644 models/gtm/heads.py create mode 100644 models/gtm/rule_memory.py create mode 100644 models/gtm/superposition.py create mode 100644 models/gtm/turing_step.py create mode 100644 models/gtm/turing_vm.py create mode 100644 models/vm/__init__.py create mode 100644 models/vm/attention.py create mode 100644 
models/vm/bridge.py create mode 100644 models/vm/projections.py create mode 100644 scripts/analyze_gtm.py create mode 100644 tasks/gtm.py diff --git a/conf/task/gtm.yaml b/conf/task/gtm.yaml new file mode 100644 index 0000000..a7dd780 --- /dev/null +++ b/conf/task/gtm.yaml @@ -0,0 +1,76 @@ +# @package _global_ +name: gtm + +# ── RTX Pro 4500 (32 GB VRAM, Ada Lovelace) tuning notes ────────────── +# +# VRAM budget breakdown (fp16 activations via AMP): +# Phase 1 (demo): B=24, K=3, grid=30×30 → N_demo = 5400 cells +# Attention [B,H,N,N]: 24×4×5400×5400×2B ≈ 5.6 GB +# CPU state × 12 steps: 24×5400×16×2B × 12 ≈ 50 MB +# Phase 2 (test): N_test = 900 cells → attention < 150 MB +# Model params + optimizer: < 1 GB +# Total: ~8–10 GB in fp16, safely within 32 GB +# +# Key CUDA flags: +# amp: true — bf16 forward/backward (Ada Lovelace tensor cores) +# compile: true — torch.compile for kernel fusion +# cudnn_benchmark: true +# pin_memory: true — async CPU→GPU transfer +# num_workers: 4 — parallel data loading + +algebra: + p: 3 + q: 0 + r: 1 + device: cuda + +model: + channels: 32 + num_steps: 12 + max_steps: 24 + num_hypotheses: 8 + top_k: 1 + coord_scale: 1.0 + head_hidden: 128 + gumbel_temperature: 1.0 + num_rule_slots: 8 + act: + enabled: true + lambda_p: 0.5 + color_unit: + K_color: 4 + attention: + num_heads: 4 + head_dim: 8 + +dataset: + data_dir: data/arc + include_toy: true + toy_n_examples: 20000 + toy_max_grid_size: 15 + num_demos: 3 + epoch_samples: 0 # 0 = full dataset shuffle; set >0 for capped-epoch sampling + +training: + epochs: 150 + lr: 0.0005 + batch_size: 24 + optimizer_type: riemannian_adam + max_bivector_norm: 10.0 + + # CUDA acceleration + num_workers: 4 + pin_memory: true + amp: true + compile: true + cudnn_benchmark: true + + # Three-phase schedule (scaled for 150 epochs) + warmup_epochs: 8 + trim_epochs: 72 + act_epochs: 70 + act_weight: 0.01 + act_ramp_epochs: 20 + gate_entropy_weight: 0.001 + grad_clip: 1.0 + eval_every: 5 diff --git 
a/datalib/__init__.py b/datalib/__init__.py index 35f75ba..ced95a2 100644 --- a/datalib/__init__.py +++ b/datalib/__init__.py @@ -11,6 +11,7 @@ from .md17 import get_md17_loaders from .deap import DEAPDataset, get_deap_loaders, get_group_sizes from .lqa import CLUTRRDataset, HANSDataset, BoolQNegDataset, get_lqa_loaders +from .arc import ToyARCDataset, ARCDataset, get_arc_loaders __all__ = [ "SRDataset", @@ -27,4 +28,7 @@ "HANSDataset", "BoolQNegDataset", "get_lqa_loaders", + "ToyARCDataset", + "ARCDataset", + "get_arc_loaders", ] diff --git a/datalib/arc.py b/datalib/arc.py new file mode 100644 index 0000000..fc51661 --- /dev/null +++ b/datalib/arc.py @@ -0,0 +1,412 @@ +# Versor: Universal Geometric Algebra Neural Network +# Copyright (C) 2026 Eunkyum Kim +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# + +"""ARC dataset loaders for the Geometric Turing Machine. + +Key design choices: + - Few-shot format: each example = (demo_pairs, test_input, test_output). + The model sees K demo (input,output) pairs to infer the rule, then + applies it to a test input to produce the test output. + - 2D grid preservation: grids are padded to (H_max, W_max) and kept as + 2D tensors so GridCodec can directly read row/col without re-parsing. + +Provides: + - ToyARCDataset: procedurally generated ARC-like tasks (few-shot) + - ARCDataset: original ARC-AGI JSON tasks (few-shot) + - collate_arc: custom collation for variable-size few-shot ARC examples + - get_arc_loaders: factory function for train/val DataLoaders +""" + +import json +import os +import random + +import torch +from torch.utils.data import Dataset, DataLoader, RandomSampler + + +class ToyARCDataset(Dataset): + """Procedurally generated ARC-like tasks in few-shot format. + + Each item is a task: K demo pairs + 1 test pair, all sharing + the same transformation rule. 
+ + Task types: + 0 - color_swap: swap two colors + 1 - rotate_90: rotate grid 90° clockwise + 2 - reflect_h: reflect horizontally + 3 - fill_rect: fill a rectangular region with a color + 4 - pattern_tile: tile a small pattern across the grid + 5 - border: add a colored border + """ + + TASK_TYPES = ['color_swap', 'rotate_90', 'reflect_h', 'fill_rect', + 'pattern_tile', 'border'] + + def __init__(self, n_examples: int = 5000, max_grid_size: int = 10, + min_grid_size: int = 3, task_types: list = None, + num_demos: int = 3, seed: int = 42): + super().__init__() + self.n_examples = n_examples + self.max_grid_size = max_grid_size + self.min_grid_size = min_grid_size + self.task_types = task_types or list(range(len(self.TASK_TYPES))) + self.num_demos = num_demos + self.seed = seed + self.examples = self._generate_all() + + def _generate_all(self): + rng = random.Random(self.seed) + examples = [] + for _ in range(self.n_examples): + task_type = rng.choice(self.task_types) + # All pairs in one task share the same rule parameters + rule_params = self._sample_rule_params(task_type, rng) + + demo_pairs = [] + for _ in range(self.num_demos): + h = rng.randint(self.min_grid_size, self.max_grid_size) + w = rng.randint(self.min_grid_size, self.max_grid_size) + inp, out = self._generate_one(task_type, h, w, rng, rule_params) + demo_pairs.append({'input': inp, 'output': out}) + + # Test pair (same rule, different grid) + h = rng.randint(self.min_grid_size, self.max_grid_size) + w = rng.randint(self.min_grid_size, self.max_grid_size) + test_in, test_out = self._generate_one(task_type, h, w, rng, rule_params) + + examples.append({ + 'demo_pairs': demo_pairs, + 'test_input': test_in, + 'test_output': test_out, + 'task_type': task_type, + }) + return examples + + def _sample_rule_params(self, task_type, rng): + """Sample rule-specific parameters (shared across all pairs in a task).""" + if task_type == 0: # color_swap + c1, c2 = rng.sample(range(10), 2) + return {'c1': c1, 'c2': c2} 
+ elif task_type == 3: # fill_rect — color is shared, position varies + return {'color': rng.randint(0, 9)} + elif task_type == 4: # pattern_tile — pattern is shared across demos + ph = rng.randint(1, 3) + pw = rng.randint(1, 3) + pattern = [[rng.randint(0, 9) for _ in range(pw)] for _ in range(ph)] + return {'pattern': pattern} + elif task_type == 5: # border + return {'color': rng.randint(1, 9)} + return {} + + def _generate_one(self, task_type, h, w, rng, rule_params): + """Generate a single (input, output) grid pair using shared rule params.""" + grid = [[rng.randint(0, 9) for _ in range(w)] for _ in range(h)] + inp = torch.tensor(grid, dtype=torch.long) + + if task_type == 0: # color_swap + c1, c2 = rule_params['c1'], rule_params['c2'] + out = inp.clone() + out[inp == c1] = c2 + out[inp == c2] = c1 + elif task_type == 1: # rotate_90 clockwise + out = inp.rot90(-1, [0, 1]) + elif task_type == 2: # reflect_h + out = inp.flip(1) + elif task_type == 3: # fill_rect + out = inp.clone() + r1 = rng.randint(0, max(0, h - 2)) + r2 = rng.randint(r1 + 1, h) + c1_r = rng.randint(0, max(0, w - 2)) + c2_r = rng.randint(c1_r + 1, w) + out[r1:r2, c1_r:c2_r] = rule_params['color'] + elif task_type == 4: # pattern_tile (shared pattern from rule_params) + pattern_data = rule_params['pattern'] + ph = min(len(pattern_data), h) + pw = min(len(pattern_data[0]), w) + pattern = torch.tensor( + [row[:pw] for row in pattern_data[:ph]], + dtype=torch.long, + ) + out = pattern.repeat( + (h + ph - 1) // ph, (w + pw - 1) // pw + )[:h, :w] + inp = out.clone() + n_corrupt = max(1, h * w // 5) + for _ in range(n_corrupt): + ri, ci = rng.randint(0, h - 1), rng.randint(0, w - 1) + inp[ri, ci] = rng.randint(0, 9) + elif task_type == 5: # border + color = rule_params['color'] + out = inp.clone() + out[0, :] = color + out[-1, :] = color + out[:, 0] = color + out[:, -1] = color + else: + out = inp.clone() + + return inp, out + + def __len__(self): + return self.n_examples + + def __getitem__(self, 
idx): + return self.examples[idx] + + +class ARCDataset(Dataset): + """Original ARC-AGI JSON dataset in few-shot format. + + Each item is one ARC task: the 'train' pairs are demos, and each + 'test' pair becomes a separate example (with all train pairs as demos). + + Expects directory structure: + data_dir/training/*.json + data_dir/evaluation/*.json + """ + + def __init__(self, data_dir: str, split: str = 'training'): + super().__init__() + self.data_dir = data_dir + self.split = split + self.examples = self._load_all() + + def _load_all(self): + examples = [] + task_dir = os.path.join(self.data_dir, self.split) + if not os.path.isdir(task_dir): + return examples + + for fname in sorted(os.listdir(task_dir)): + if not fname.endswith('.json'): + continue + task_id = fname.replace('.json', '') + fpath = os.path.join(task_dir, fname) + with open(fpath, 'r') as f: + task_data = json.load(f) + + # Demo pairs from 'train' + demo_pairs = [] + for pair in task_data.get('train', []): + demo_pairs.append({ + 'input': torch.tensor(pair['input'], dtype=torch.long), + 'output': torch.tensor(pair['output'], dtype=torch.long), + }) + + # Each test pair becomes a separate example with shared demos + for pair in task_data.get('test', []): + examples.append({ + 'demo_pairs': demo_pairs, + 'test_input': torch.tensor(pair['input'], dtype=torch.long), + 'test_output': torch.tensor(pair['output'], dtype=torch.long), + 'task_id': task_id, + }) + + # If no test pairs, use last train pair as test + if not task_data.get('test') and demo_pairs: + last = demo_pairs[-1] + examples.append({ + 'demo_pairs': demo_pairs[:-1] if len(demo_pairs) > 1 else demo_pairs, + 'test_input': last['input'], + 'test_output': last['output'], + 'task_id': task_id, + }) + + return examples + + def __len__(self): + return len(self.examples) + + def __getitem__(self, idx): + return self.examples[idx] + + +def _pad_grid_2d(grid, H_max, W_max, pad_value=0): + """Pad a 2D grid [H, W] to [H_max, W_max].""" + H, W = 
grid.shape + padded = torch.full((H_max, W_max), pad_value, dtype=grid.dtype) + padded[:H, :W] = grid + return padded + + +def collate_arc(batch): + """Custom collation for few-shot ARC examples. + + Each batch item has: + demo_pairs: list of K dicts with 'input' [H,W] and 'output' [H,W] + test_input: [H,W] + test_output: [H,W] + + Returns dict with: + 'demo_inputs': [B, K, H_max, W_max] padded demo inputs + 'demo_outputs': [B, K, H_max, W_max] padded demo outputs + 'demo_masks': [B, K, H_max, W_max] bool (True=valid input cell) + 'demo_output_masks': [B, K, H_max, W_max] bool (True=valid output cell) + 'test_inputs': [B, H_max, W_max] padded test input + 'test_outputs': [B, H_max, W_max] padded test output (-1 = pad) + 'test_masks': [B, H_max, W_max] bool (True=valid cell) + 'num_demos': [B] int — actual number of demo pairs per example + 'test_sizes': list of (H, W) tuples for test outputs + 'input_sizes': list of (H, W) tuples for test inputs + """ + B = len(batch) + + # Find max K (demo pairs), max H, max W across all grids + max_K = max(len(item['demo_pairs']) for item in batch) + all_grids = [] + for item in batch: + for dp in item['demo_pairs']: + all_grids.append(dp['input']) + all_grids.append(dp['output']) + all_grids.append(item['test_input']) + all_grids.append(item['test_output']) + + H_max = max(g.shape[0] for g in all_grids) + W_max = max(g.shape[1] for g in all_grids) + + # Allocate tensors + demo_inputs = torch.zeros(B, max_K, H_max, W_max, dtype=torch.long) + demo_outputs = torch.full((B, max_K, H_max, W_max), -1, dtype=torch.long) + demo_masks = torch.zeros(B, max_K, H_max, W_max, dtype=torch.bool) + demo_output_masks = torch.zeros(B, max_K, H_max, W_max, dtype=torch.bool) + test_inputs = torch.zeros(B, H_max, W_max, dtype=torch.long) + test_outputs = torch.full((B, H_max, W_max), -1, dtype=torch.long) + test_masks = torch.zeros(B, H_max, W_max, dtype=torch.bool) + num_demos = torch.zeros(B, dtype=torch.long) + test_sizes = [] + input_sizes = [] 
+ + for i, item in enumerate(batch): + # Demo pairs (input and output may have different dimensions) + K = len(item['demo_pairs']) + num_demos[i] = K + for k, dp in enumerate(item['demo_pairs']): + di = dp['input'] + do = dp['output'] + dH, dW = di.shape + demo_inputs[i, k, :dH, :dW] = di + demo_masks[i, k, :dH, :dW] = True + doH, doW = do.shape + demo_outputs[i, k, :doH, :doW] = do + demo_output_masks[i, k, :doH, :doW] = True + + # Test pair + ti = item['test_input'] + to = item['test_output'] + tH, tW = ti.shape + toH, toW = to.shape + test_inputs[i, :tH, :tW] = ti + test_outputs[i, :toH, :toW] = to + test_masks[i, :tH, :tW] = True + test_sizes.append((toH, toW)) + input_sizes.append((tH, tW)) + + return { + 'demo_inputs': demo_inputs, + 'demo_outputs': demo_outputs, + 'demo_masks': demo_masks, + 'demo_output_masks': demo_output_masks, + 'test_inputs': test_inputs, + 'test_outputs': test_outputs, + 'test_masks': test_masks, + 'num_demos': num_demos, + 'test_sizes': test_sizes, + 'input_sizes': input_sizes, + } + + +def get_arc_loaders(data_dir: str = 'data/arc', batch_size: int = 8, + include_toy: bool = True, toy_n_examples: int = 5000, + toy_max_grid_size: int = 10, num_workers: int = 0, + num_demos: int = 3, seed: int = 42, + pin_memory: bool = False, + epoch_samples: int = 0): + """Create ARC train/val DataLoaders. + + Args: + data_dir: Path to ARC JSON directory. + batch_size: Batch size. + include_toy: If True, augment training with ToyARC examples. + toy_n_examples: Number of synthetic examples. + toy_max_grid_size: Max grid dimension for synthetic examples. + num_workers: DataLoader workers. + num_demos: Number of demo pairs per task (for ToyARC). + seed: Random seed. + pin_memory: Pin memory for CUDA async transfers. + epoch_samples: Samples per epoch (0 = use full dataset with shuffle). + + Returns: + dict with 'train', 'val' DataLoaders and 'num_colors' (10). 
+ """ + datasets = [] + + # Original ARC training data + arc_train = ARCDataset(data_dir, split='training') + if len(arc_train) > 0: + datasets.append(arc_train) + + # Synthetic ToyARC data + if include_toy: + toy = ToyARCDataset( + n_examples=toy_n_examples, + max_grid_size=toy_max_grid_size, + num_demos=num_demos, + seed=seed, + ) + datasets.append(toy) + + if datasets: + train_dataset = torch.utils.data.ConcatDataset(datasets) + else: + train_dataset = ToyARCDataset( + n_examples=toy_n_examples, + max_grid_size=toy_max_grid_size, + num_demos=num_demos, + seed=seed, + ) + + # Validation: ARC evaluation set, fallback to small ToyARC + val_dataset = ARCDataset(data_dir, split='evaluation') + if len(val_dataset) == 0: + val_dataset = ToyARCDataset( + n_examples=min(500, toy_n_examples // 10), + max_grid_size=toy_max_grid_size, + num_demos=num_demos, + seed=seed + 1, + ) + + persistent = num_workers > 0 + + if epoch_samples > 0: + sampler = RandomSampler( + train_dataset, + replacement=True, + num_samples=epoch_samples, + ) + train_loader = DataLoader( + train_dataset, batch_size=batch_size, sampler=sampler, + collate_fn=collate_arc, num_workers=num_workers, drop_last=True, + pin_memory=pin_memory, persistent_workers=persistent, + ) + else: + train_loader = DataLoader( + train_dataset, batch_size=batch_size, shuffle=True, + collate_fn=collate_arc, num_workers=num_workers, drop_last=True, + pin_memory=pin_memory, persistent_workers=persistent, + ) + val_loader = DataLoader( + val_dataset, batch_size=batch_size, shuffle=False, + collate_fn=collate_arc, num_workers=num_workers, + pin_memory=pin_memory, persistent_workers=persistent, + ) + + return { + 'train': train_loader, + 'val': val_loader, + 'num_colors': 10, + } diff --git a/datalib/gtm.py b/datalib/gtm.py new file mode 100644 index 0000000..3a4ab6f --- /dev/null +++ b/datalib/gtm.py @@ -0,0 +1,17 @@ +# Versor: Universal Geometric Algebra Neural Network +# Copyright (C) 2026 Eunkyum Kim +# +# Licensed under the 
Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# + +"""GTM data — re-exports ARC dataset loaders.""" + +from .arc import ToyARCDataset, ARCDataset, collate_arc, get_arc_loaders + +__all__ = [ + "ToyARCDataset", + "ARCDataset", + "collate_arc", + "get_arc_loaders", +] diff --git a/main.py b/main.py index 182d3bd..0e587fd 100644 --- a/main.py +++ b/main.py @@ -16,6 +16,7 @@ from tasks.symbolic_regression import SRTask from tasks.lqa import LQATask from tasks.deap_eeg import DEAPEEGTask +from tasks.gtm import GTMTask EXAMPLE_TASKS = {'manifold', 'hyperbolic', 'sanity'} @@ -33,6 +34,7 @@ def main(cfg: DictConfig): 'sr': SRTask, 'lqa': LQATask, 'deap_eeg': DEAPEEGTask, + 'gtm': GTMTask, } if task_name in EXAMPLE_TASKS: diff --git a/models/gtm/__init__.py b/models/gtm/__init__.py new file mode 100644 index 0000000..7f420e7 --- /dev/null +++ b/models/gtm/__init__.py @@ -0,0 +1,35 @@ +# Versor: Universal Geometric Algebra Neural Network +# Copyright (C) 2026 Eunkyum Kim +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# + +"""Geometric Turing Machine (GTM) package — ARC-AGI v4.""" + +from .grid_codec import GridCodec +from .cpu import GeometricCPU, ColorUnit +from .control_plane import ControlPlane +from .superposition import GeometricSuperpositionSearch +from .turing_step import TuringStep +from .adaptive_halt import AdaptiveHalt +from .turing_vm import TuringVM +from .heads import GridReconstructionHead +from .rule_memory import RuleAggregator +from .gtm_net import GTMNet +from .analysis import GTMAnalyzer + +__all__ = [ + "GridCodec", + "GeometricCPU", + "ColorUnit", + "ControlPlane", + "GeometricSuperpositionSearch", + "TuringStep", + "AdaptiveHalt", + "TuringVM", + "GridReconstructionHead", + "RuleAggregator", + "GTMNet", + "GTMAnalyzer", +] diff --git a/models/gtm/adaptive_halt.py b/models/gtm/adaptive_halt.py new file mode 100644 index 0000000..be73401 --- /dev/null +++ b/models/gtm/adaptive_halt.py @@ -0,0 +1,95 @@ +# Versor: Universal Geometric Algebra Neural Network +# Copyright (C) 2026 Eunkyum Kim +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# + +"""PonderNet-style adaptive computation controller. + +Takes per-step halting probabilities and produces: +- Mixing weights for per-step outputs (geometric distribution) +- KL divergence against a geometric prior for regularization +- Expected number of computation steps per example +""" + +import torch +import torch.nn as nn + + +class AdaptiveHalt(nn.Module): + """PonderNet adaptive computation time controller. + + Computes mixing weights from per-step halt probabilities using a + geometric distribution: p(halt at t) = lambda_t * prod_{s dict: + """Compute mixing weights and KL loss from per-step halt probabilities. + + Args: + halt_probs: List of T tensors, each [B] (mean halt prob per example). 
+ + Returns: + dict with: + 'weights': [B, T] mixing weights for per-step outputs + 'expected_steps': [B] expected computation depth + 'kl_loss': scalar KL divergence against geometric prior + """ + T = len(halt_probs) + B = halt_probs[0].shape[0] + device = halt_probs[0].device + eps = self.eps + + # Stack halt probs: [T, B] + lambdas = torch.stack(halt_probs, dim=0) # [T, B] + lambdas = lambdas.clamp(eps, 1.0 - eps) + + # Compute geometric distribution weights + # p(halt at t) = lambda_t * prod_{s +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# + +"""GTM Explainability Analysis — post-training inspection tools (v4 PGA). + +Usage (from checkpoint): + analyzer = GTMAnalyzer.from_checkpoint( + 'gtm_arc_best.pt', device='cuda' + ) + + # Static: what did each instruction template learn? + instr = analyzer.analyze_instructions() + + # Dynamic: full analysis on a batch + report = analyzer.analyze(batch) + +Usage (from existing model): + analyzer = GTMAnalyzer(model, device='cuda') + report = analyzer.analyze(batch) + +Standalone script: + uv run python scripts/analyze_gtm.py --checkpoint gtm_arc_best.pt +""" + +import math +import torch +import torch.nn as nn +from core.algebra import CliffordAlgebra + + +class GTMAnalyzer: + """Post-training analysis for pretrained GTM v4 models. 
+ + Provides: + - Instruction template decomposition (rotation + translation motors) + - Color remapping table inspection + - Cursor trajectory through both phases + - Hypothesis selection analysis (scores, weights, temperature) + - Write gate analysis (per-cell acceptance/rejection) + - Rule memory analysis + - Per-cell prediction vs target comparison + """ + + def __init__(self, model: nn.Module, device: str = 'cpu'): + self.model = model.to(device).eval() + self.device = device + + @staticmethod + def from_checkpoint(path: str, device: str = 'cpu') -> 'GTMAnalyzer': + """Load GTMAnalyzer from a BaseTask checkpoint. + + Args: + path: Path to checkpoint saved by BaseTask.save_checkpoint(). + device: Target device. + + Returns: + GTMAnalyzer instance with loaded model. + """ + from models.gtm import GTMNet + + checkpoint = torch.load(path, map_location=device, weights_only=False) + cfg = checkpoint['config'] + mcfg = cfg.model + act_cfg = mcfg.get('act', {}) + color_cfg = mcfg.get('color_unit', {}) + attn_cfg = mcfg.get('attention', {}) + + algebra_cpu = CliffordAlgebra(3, 0, 1, device=device) + algebra_ctrl = CliffordAlgebra(1, 1, 0, device=device) + + model = GTMNet( + algebra_cpu=algebra_cpu, + algebra_ctrl=algebra_ctrl, + channels=mcfg.get('channels', 16), + num_steps=mcfg.get('num_steps', 8), + max_steps=mcfg.get('max_steps', 20), + num_hypotheses=mcfg.get('num_hypotheses', 4), + top_k=mcfg.get('top_k', 1), + head_hidden=mcfg.get('head_hidden', 64), + temperature_init=mcfg.get('gumbel_temperature', 1.0), + use_act=act_cfg.get('enabled', True), + lambda_p=act_cfg.get('lambda_p', 0.5), + coord_scale=mcfg.get('coord_scale', 1.0), + K_color=color_cfg.get('K_color', 4), + num_attn_heads=attn_cfg.get('num_heads', 4), + attn_head_dim=attn_cfg.get('head_dim', 8), + num_rule_slots=mcfg.get('num_rule_slots', 8), + ) + model.load_state_dict(checkpoint['model_state_dict']) + return GTMAnalyzer(model, device) + + # 
------------------------------------------------------------------ + # Static analysis (no data required) + # ------------------------------------------------------------------ + + def analyze_instructions(self) -> dict: + """Decompose instruction templates into geometric components. + + For each of the K trainable instruction templates in Cl(3,0,1): + - Rotation bivectors (e01, e02, e12) -> rotation angle and plane + - Translation bivectors (e03, e13, e23) -> translation vector + - Scalar (grade-0) and pseudoscalar (grade-4) -> color control signals + + Returns: + dict with keys per template index: + 'templates_raw': [K, 16] raw parameter values + 'rotation_angles': [K] angle in radians (= 2 * ||B_rot||) + 'rotation_planes': [K, 3] unit bivector (e01, e02, e12) + 'rotation_degrees': [K] angle in degrees + 'translation_vectors': [K, 3] translation (e03, e13, e23) magnitudes + 'translation_norms': [K] translation magnitude + 'color_control': [K, 2] (grade-0, grade-4) values + 'near_identity': [K] bool — True if template ~ no-op + """ + templates = self._get_templates() # [K, 16] + K = templates.shape[0] + + # Rotation bivectors: e01(idx3), e02(idx5), e12(idx6) + bv_rot = templates[:, [3, 5, 6]] # [K, 3] + bv_rot_norm = bv_rot.norm(dim=-1) # [K] + rotation_angles = 2.0 * bv_rot_norm + + safe_norm = bv_rot_norm.clamp(min=1e-8).unsqueeze(-1) + rotation_planes = bv_rot / safe_norm + + # Translation bivectors: e03(idx9), e13(idx10), e23(idx12) + bv_trans = templates[:, [9, 10, 12]] # [K, 3] + trans_norms = bv_trans.norm(dim=-1) # [K] + + # Color control signals + color_control = templates[:, [0, 15]] # [K, 2] (grade-0, grade-4) + + # Near-identity: small rotation + small translation + small color signal + near_identity = ( + (rotation_angles < 0.1) & + (trans_norms < 0.05) & + (color_control.abs().max(dim=-1).values < 0.05) + ) + + return { + 'templates_raw': templates, + 'rotation_angles': rotation_angles, + 'rotation_planes': rotation_planes, + 'rotation_degrees': 
rotation_angles * (180.0 / math.pi), + 'translation_vectors': bv_trans, + 'translation_norms': trans_norms, + 'color_control': color_control, + 'near_identity': near_identity, + } + + def analyze_color_unit(self) -> dict: + """Inspect ColorUnit remapping tables. + + Returns: + dict with: + 'remap_tables': [K_color, 10, 10] learned tables + 'table_diag_dominance': [K_color] how close to identity each table is + """ + # Get color unit from first step's search module + color_unit = self.model.vm.steps[0].search.pga_cpu.color_unit + tables = color_unit.remap_tables.detach() # [K_color, 10, 10] + + # Diagonal dominance: fraction of mass on diagonal + diags = torch.diagonal(tables, dim1=-2, dim2=-1) # [K_color, 10] + row_sums = tables.abs().sum(dim=-1) # [K_color, 10] + diag_dominance = (diags.abs() / row_sums.clamp(min=1e-8)).mean(dim=-1) + + return { + 'remap_tables': tables, + 'table_diag_dominance': diag_dominance, + } + + def analyze_temperature(self) -> dict: + """Analyze Gumbel-Softmax temperature across all steps. + + Returns: + dict with: + 'temperatures': [num_steps] current temperature per step + 'is_sharp': [num_steps] bool — True if tau < 0.5 (near-discrete) + """ + temps = [] + for step in self.model.vm.steps: + tau = step.search.log_temperature.exp().clamp(0.1, 5.0) + temps.append(tau.item()) + + temps_t = torch.tensor(temps) + return { + 'temperatures': temps_t, + 'is_sharp': temps_t < 0.5, + } + + # ------------------------------------------------------------------ + # Dynamic analysis (requires a batch) + # ------------------------------------------------------------------ + + def analyze(self, batch: dict) -> dict: + """Full analysis of one batch through both phases. + + Args: + batch: Collated ARC batch from collate_arc. 
+ + Returns: + dict with: + 'instructions': instruction decomposition (static) + 'color_unit': color remapping analysis (static) + 'phase1': {cursors, search_scores, search_weights, + gate_values, halt_probs} + 'phase2': same structure as phase1 + 'cursor_after_phase1': [B, 4] + 'cursor_after_phase2': [B, 4] + 'predictions': [B, N_test] predicted colors + 'targets': [B, N_test] ground truth + 'cell_accuracy': float + 'grid_correct': [B] bool per example + 'test_masks': [B, N_test] validity mask + """ + num_steps = self.model.vm.num_steps + + # Run full forward with trace + with torch.no_grad(): + result = self._run_forward(batch) + + logits = result['logits'] + preds = logits.argmax(dim=-1) + trace = result['trace'] + + # Split trace into Phase 1 and Phase 2 + phase1_trace = {k: v[:num_steps] for k, v in trace.items()} + phase2_trace = {k: v[num_steps:] for k, v in trace.items()} + + # Targets + test_outputs = batch['test_outputs'].to(self.device) + test_masks = batch['test_masks'].to(self.device) + B, H_max, W_max = test_outputs.shape + targets = test_outputs.reshape(B, H_max * W_max) + valid = test_masks.reshape(B, H_max * W_max) + + # Metrics + matches = (preds == targets) & valid + cell_acc = matches.sum().item() / max(valid.sum().item(), 1) + + grid_correct = torch.zeros(B, dtype=torch.bool) + test_sizes = batch['test_sizes'] + for i in range(B): + toH, toW = test_sizes[i] + N = toH * toW + grid_correct[i] = (preds[i, :N] == targets[i, :N]).all() + + return { + 'instructions': self.analyze_instructions(), + 'color_unit': self.analyze_color_unit(), + 'phase1': phase1_trace, + 'phase2': phase2_trace, + 'cursor_after_phase1': phase1_trace['cursors'][-1] if phase1_trace['cursors'] else None, + 'cursor_after_phase2': phase2_trace['cursors'][-1] if phase2_trace['cursors'] else None, + 'predictions': preds, + 'targets': targets, + 'cell_accuracy': cell_acc, + 'grid_correct': grid_correct, + 'test_masks': valid, + } + + def predict(self, batch: dict) -> dict: + 
def format_instruction_report(self) -> str:
    """Summarize the learned instruction templates in human-readable form."""
    info = self.analyze_instructions()
    num_templates = info['templates_raw'].shape[0]
    out = ['=== Instruction Template Analysis (PGA) ===', '']

    for idx in range(num_templates):
        rot_deg = info['rotation_degrees'][idx].item()
        rot_plane = info['rotation_planes'][idx]
        t_vec = info['translation_vectors'][idx]
        t_norm = info['translation_norms'][idx].item()
        color = info['color_control'][idx]
        is_noop = info['near_identity'][idx].item()

        out.append(f'Template {idx}:')
        out.append(f'  Rotation: {rot_deg:6.1f}deg '
                   f'plane=({rot_plane[0]:.2f}*e01 + {rot_plane[1]:.2f}*e02 + {rot_plane[2]:.2f}*e12)')
        out.append(f'  Translation: |t|={t_norm:.3f} '
                   f'({t_vec[0]:.3f}*e03 + {t_vec[1]:.3f}*e13 + {t_vec[2]:.3f}*e23)')
        out.append(f'  Color ctrl: grade0={color[0]:.3f} pseudoscalar={color[1]:.3f}')
        if is_noop:
            out.append('  ** NEAR IDENTITY (no-op) **')
        out.append('')

    return '\n'.join(out)

def format_cursor_report(self, report: dict) -> str:
    """Summarize the Cl(1,1) cursor trajectory for both phases."""
    out = ['=== Cursor Trajectory ===', '']

    # Cl(1,1) components: {1, e3, e4, e34}
    labels = ['scalar(confidence)', 'e3(hypothesis)', 'e4(depth)', 'e34(phase)']
    phases = [('Phase 1 (Rule Inference)', 'phase1'),
              ('Phase 2 (Rule Application)', 'phase2')]

    for title, key in phases:
        trajectory = report[key]['cursors']
        if not trajectory:
            continue
        out.append(f'{title}:')
        for step_idx, snapshot in enumerate(trajectory):
            vals = snapshot[0]  # first batch element only
            parts = ' '.join(f'{labels[j]}={vals[j]:+.4f}' for j in range(4))
            out.append(f'  Step {step_idx}: {parts}')
        out.append('')

    return '\n'.join(out)

def format_search_report(self, report: dict) -> str:
    """Summarize which hypothesis dominated the search at each step."""
    out = ['=== Hypothesis Selection ===', '']

    for title, key in [('Phase 1', 'phase1'), ('Phase 2', 'phase2')]:
        per_step = report[key]['search_weights']
        if not per_step:
            continue
        out.append(f'{title}:')
        for step_idx, w in enumerate(per_step):
            first = w[0]  # first batch element only
            best = first.argmax().item()
            rendered = ' '.join(f'H{k}={first[k]:.3f}' for k in range(first.shape[0]))
            out.append(f'  Step {step_idx}: [{rendered}] dominant=H{best}')
        out.append('')

    return '\n'.join(out)

def format_gate_report(self, report: dict) -> str:
    """Summarize per-step write-gate statistics (mean/min/max/accept rate)."""
    out = ['=== Write Gate Analysis ===', '']

    for title, key in [('Phase 1', 'phase1'), ('Phase 2', 'phase2')]:
        gates = report[key]['gate_values']
        if not gates:
            continue
        out.append(f'{title}:')
        for step_idx, g in enumerate(gates):
            cell_gates = g[0].squeeze(-1)  # [N] for first batch element
            out.append(
                f'  Step {step_idx}: mean={cell_gates.mean():.3f} '
                f'min={cell_gates.min():.3f} max={cell_gates.max():.3f} '
                f'accept(>0.5)={(cell_gates > 0.5).float().mean():.1%}'
            )
        out.append('')

    return '\n'.join(out)

def full_report(self, batch: dict) -> str:
    """Generate the complete human-readable analysis report for one batch."""
    report = self.analyze(batch)
    n_correct = report['grid_correct'].sum().item()
    n_total = report['grid_correct'].shape[0]

    sections = [
        self.format_instruction_report(),
        self.format_cursor_report(report),
        self.format_search_report(report),
        self.format_gate_report(report),
        '',
        '=== Prediction Summary ===',
        f'  Cell accuracy: {report["cell_accuracy"]:.4f}',
        f'  Grid correct: {n_correct}/{n_total}',
    ]
    return '\n'.join(sections)

# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------

def _get_templates(self) -> torch.Tensor:
    """Instruction templates from step 0 (steps share one template bank)."""
    return self.model.vm.steps[0].search.instruction_templates.detach()

def _run_forward(self, batch: dict) -> dict:
    """Move a collated batch onto the model device and run a traced forward."""
    dev = self.device
    return self.model(
        batch['demo_inputs'].to(dev),
        batch['demo_outputs'].to(dev),
        batch['demo_masks'].to(dev),
        batch['test_inputs'].to(dev),
        batch['test_masks'].to(dev),
        batch['num_demos'].to(dev),
        # Fall back to the input masks when output masks are absent.
        demo_output_masks=batch.get('demo_output_masks', batch['demo_masks']).to(dev),
        return_trace=True,
    )
class ControlPlane(CliffordModule):
    """Learnable search controller over the 4D control algebra Cl(1,1).

    Cursor layout: s*1 + h*e3 + d*e4 + p*e34 — confidence, hypothesis
    index, depth, and phase. Each step applies a hyperbolic e34 boost
    either "horizontally" (explore hypotheses) or "vertically" (go
    deeper); a sigmoid gate blends the two boosted cursors — the "rook"
    movement pattern. e34 squares to +1 in Cl(1,1), so the exponential
    uses cosh/sinh.
    """

    def __init__(self, algebra_ctrl: CliffordAlgebra, channels: int,
                 max_hypotheses: int = 4):
        assert algebra_ctrl.p == 1 and algebra_ctrl.q == 1, \
            f"ControlPlane requires Cl(1,1), got Cl({algebra_ctrl.p},{algebra_ctrl.q})"
        super().__init__(algebra_ctrl)
        self.channels = channels
        self.max_hypotheses = max_hypotheses
        # Cl(1,1) has dim 4: basis {1, e3, e4, e34} at indices {0, 1, 2, 3}.

        # Boost magnitudes, conditioned on cursor + CPU context.
        self.alpha_mlp = nn.Sequential(
            nn.Linear(channels + 4, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )
        self.beta_mlp = nn.Sequential(
            nn.Linear(channels + 4, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

        # Horizontal-vs-vertical move selector.
        self.direction_gate = nn.Sequential(
            nn.Linear(channels + 4, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

        # Halt signal computed from the updated cursor.
        self.halt_mlp = nn.Sequential(
            nn.Linear(4, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
        )

    def step(self, cursor: torch.Tensor,
             cpu_context: torch.Tensor) -> tuple:
        """Advance the control cursor one step.

        Args:
            cursor: [B, 4] current cursor in Cl(1,1).
            cpu_context: [B, channels] CPU-state summary (e.g. pooled
                grade norms).

        Returns:
            Tuple of (new_cursor [B, 4], direction_logit [B, 1],
            halt_prob [B]).
        """
        B = cursor.shape[0]
        dev = cursor.device
        self.algebra.ensure_device(dev)

        # Cursor and CPU context jointly condition all three MLPs.
        features = torch.cat([cursor, cpu_context], dim=-1)  # [B, 4 + channels]

        alpha = self.alpha_mlp(features).squeeze(-1)  # [B]
        beta = self.beta_mlp(features).squeeze(-1)    # [B]

        def boosted(magnitude: torch.Tensor) -> torch.Tensor:
            # Rotor exp(-magnitude/2 * e34) applied as a sandwich product.
            bv = torch.zeros(B, 4, device=dev, dtype=cursor.dtype)
            bv[:, 3] = magnitude  # e34 component
            rotor = self.algebra.exp(-0.5 * bv)
            return self.algebra.geometric_product(
                self.algebra.geometric_product(rotor, cursor),
                self.algebra.reverse(rotor),
            )

        cursor_h = boosted(alpha)   # horizontal: explore hypotheses (e3)
        cursor_v = boosted(-beta)   # vertical: go deeper (e4)

        # Soft rook move: blend the two boosted cursors.
        direction_logit = self.direction_gate(features)  # [B, 1]
        gate = torch.sigmoid(direction_logit)
        new_cursor = gate * cursor_h + (1.0 - gate) * cursor_v  # [B, 4]

        # Halt probability from the full updated 4D cursor.
        halt_prob = torch.sigmoid(self.halt_mlp(new_cursor)).squeeze(-1)  # [B]

        return new_cursor, direction_logit, halt_prob
class ColorUnit(nn.Module):
    """Instruction-conditioned discrete color remapping.

    Holds K_color learnable [10, 10] remap tables, initialized near the
    identity map. The instruction's grade-0 (idx 0) and grade-4 (idx 15)
    components select a soft blend of tables, applied to a soft one-hot
    encoding of each cell's color.
    """

    def __init__(self, K_color: int = 4):
        super().__init__()
        self.K_color = K_color
        # Near-identity init: eye(10) plus small noise per table.
        self.remap_tables = nn.Parameter(
            torch.eye(10).unsqueeze(0).expand(K_color, -1, -1).clone()
            + torch.randn(K_color, 10, 10) * 0.01
        )
        # Maps (grade-0, grade-4) of the instruction to table weights.
        self.selector = nn.Linear(2, K_color)

    def forward(self, state: torch.Tensor,
                instruction: torch.Tensor) -> torch.Tensor:
        """Remap colors; rewrites grade-0 (color) and grade-4 (occupancy) only.

        Args:
            state: [L, N, 16] PGA multivectors after the motor transform.
            instruction: [L, 16] instruction multivectors.

        Returns:
            [L, N, 16] state with indices 0 and 15 updated.
        """
        L, N, D = state.shape

        # Select a per-example blend of remap tables from the instruction.
        sel_feats = torch.stack([instruction[:, 0], instruction[:, 15]], dim=-1)  # [L, 2]
        mix = F.softmax(self.selector(sel_feats), dim=-1)                         # [L, K_color]
        table = torch.einsum('lk,kij->lij', mix, self.remap_tables)               # [L, 10, 10]

        # Soft one-hot over the 10 color classes (sharpness fixed at 4).
        approx_color = state[:, :, 0] * 9.0  # [L, N], nominally in [0, 9]
        palette = torch.arange(10, device=state.device, dtype=state.dtype)  # [10]
        delta = approx_color.unsqueeze(-1) - palette                        # [L, N, 10]
        soft = F.softmax(-4.0 * delta.pow(2), dim=-1)                       # [L, N, 10]

        # Remap, then collapse back to an expected color in [0, 1].
        remapped = torch.bmm(soft, table)                              # [L, N, 10]
        new_color = torch.einsum('lni,i->ln', remapped, palette) / 9.0  # [L, N]

        # Occupancy = probability of NOT being background (color 0).
        new_occupancy = 1.0 - remapped[:, :, 0]

        out = state.clone()
        out[:, :, 0] = new_color
        out[:, :, 15] = new_occupancy
        return out


class GeometricCPU(nn.Module):
    """PGA Cl(3,0,1) engine: motor sandwich for geometry, ColorUnit for color.

    A single sandwich product handles both rotation (e01,e02,e12) and
    translation (e03,e13,e23 null bivectors); the ColorUnit then performs
    the discrete color remap on top.
    """

    def __init__(self, algebra_cpu: CliffordAlgebra, K_color: int = 4):
        super().__init__()
        assert algebra_cpu.p == 3 and algebra_cpu.r == 1, \
            f"GeometricCPU requires Cl(3,0,1), got Cl({algebra_cpu.p},{algebra_cpu.q},{algebra_cpu.r})"
        self.algebra = algebra_cpu
        self.color_unit = ColorUnit(K_color)

    def _transform(self, state: torch.Tensor, instruction: torch.Tensor) -> torch.Tensor:
        """Motor sandwich followed by color remap; L may be B or B*K."""
        L, N, D = state.shape

        # Part A: motor from the instruction's bivector part.
        bivector = self.algebra.grade_projection(instruction, 2)  # [L, 16]
        motor = self.algebra.exp(-0.5 * bivector)                 # [L, 16]
        motor_rev = self.algebra.reverse(motor)                   # [L, 16]

        # Broadcast the per-example motor over all N cells.
        motor_flat = motor.unsqueeze(1).expand(L, N, D).reshape(L * N, D)
        motor_rev_flat = motor_rev.unsqueeze(1).expand(L, N, D).reshape(L * N, D)

        moved = self.algebra.sandwich_product(
            motor_flat, state.reshape(L * N, 1, D), motor_rev_flat
        ).reshape(L, N, D)

        # Part B/C: color remap on the spatially transformed state.
        return self.color_unit(moved, instruction)

    def execute(self, state: torch.Tensor, instruction: torch.Tensor) -> torch.Tensor:
        """Apply one instruction: state [B, N, 16], instruction [B, 16]."""
        self.algebra.ensure_device(state.device)
        return self._transform(state, instruction)

    def execute_all(self, state: torch.Tensor,
                    instructions: torch.Tensor) -> torch.Tensor:
        """Apply K instructions in one batched call; returns [B, K, N, 16]."""
        B, N, D = state.shape
        K = instructions.shape[1]
        self.algebra.ensure_device(state.device)

        # [B, N, D] -> [B*K, N, D] so a single _transform covers all K.
        tiled = state.unsqueeze(1).expand(B, K, N, D).reshape(B * K, N, D)
        flat = self._transform(tiled, instructions.reshape(B * K, D))
        return flat.reshape(B, K, N, D)
class GridCodec:
    """Deterministic ARC-grid <-> Cl(3,0,1) multivector codec (no parameters).

    Operates on single grids [H, W] or padded batches [B, H_max, W_max].
    Per-cell layout of the 16-dim PGA multivector:
      idx 0  (scalar): color / 9
      idx 1  (e0): row * coord_scale
      idx 2  (e1): col * coord_scale
      idx 3  (e01): row * col * coord_scale^2
      idx 8  (e3): 1.0 homogeneous coordinate (enables PGA translation)
      idx 15 (e0123): occupancy flag (color != 0)
    idx 4 (e2) stays zero — reserved for role embeddings.
    """

    def __init__(self, algebra_cpu: CliffordAlgebra, coord_scale: float = 1.0):
        assert algebra_cpu.p == 3 and algebra_cpu.r == 1, \
            f"GridCodec requires Cl(3,0,1), got Cl({algebra_cpu.p},{algebra_cpu.q},{algebra_cpu.r})"
        self.algebra = algebra_cpu
        self.coord_scale = coord_scale

    def encode_grid(self, grid: torch.Tensor) -> torch.Tensor:
        """Encode one integer grid [H, W] (values 0-9) to [H, W, 16]."""
        H, W = grid.shape
        device = grid.device
        scale = self.coord_scale

        color = grid.float()
        row_idx = torch.arange(H, device=device).float().unsqueeze(1).expand(H, W)
        col_idx = torch.arange(W, device=device).float().unsqueeze(0).expand(H, W)

        mv = torch.zeros(H, W, 16, device=device, dtype=torch.float32)
        mv[:, :, 0] = color / 9.0                              # normalized color
        mv[:, :, 1] = row_idx * scale                          # e0: row
        mv[:, :, 2] = col_idx * scale                          # e1: col
        mv[:, :, 3] = row_idx * col_idx * (scale * scale)      # e01: correlation
        mv[:, :, 8] = 1.0                                      # e3: homogeneous
        mv[:, :, 15] = (color > 0).float()                     # e0123: occupancy
        return mv

    def encode_batch(self, grids: torch.Tensor,
                     masks: torch.Tensor) -> tuple:
        """Encode padded grids [B, H_max, W_max] into flat sequences.

        Args:
            grids: Padded integer grids [B, H_max, W_max] (long).
            masks: Validity masks [B, H_max, W_max] (bool, True=valid).

        Returns:
            (mv [B, H_max*W_max, 16], flat_masks [B, H_max*W_max]).
        """
        B, H_max, W_max = grids.shape
        device = grids.device
        scale = self.coord_scale

        color = grids.float()
        row_idx = torch.arange(H_max, device=device).float().view(1, H_max, 1).expand(B, H_max, W_max)
        col_idx = torch.arange(W_max, device=device).float().view(1, 1, W_max).expand(B, H_max, W_max)

        mv = torch.zeros(B, H_max, W_max, 16, device=device, dtype=torch.float32)
        mv[:, :, :, 0] = color / 9.0
        mv[:, :, :, 1] = row_idx * scale
        mv[:, :, :, 2] = col_idx * scale
        mv[:, :, :, 3] = row_idx * col_idx * (scale * scale)
        mv[:, :, :, 8] = 1.0
        mv[:, :, :, 15] = (color > 0).float()

        # Padding cells carry no signal at all.
        mv = mv * masks.unsqueeze(-1).float()

        return mv.reshape(B, H_max * W_max, 16), masks.reshape(B, H_max * W_max)

    def decode(self, mv: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """Decode multivectors ([H*W, 16] or [H, W, 16]) back to an integer grid."""
        values = mv.reshape(-1, 16)[:H * W, 0] * 9.0
        return values.round().long().clamp(0, 9).reshape(H, W)
class GTMNet(nn.Module):
    """Grid-native Geometric Turing Machine Network.

    Two sub-algebras (Mother algebra removed):
      CPU Cl(3,0,1): PGA computation engine (motor + color)
      Control Cl(1,1): learnable search controller
    """

    def __init__(
        self,
        algebra_cpu: CliffordAlgebra,
        algebra_ctrl: CliffordAlgebra,
        channels: int = 16,
        num_steps: int = 8,
        max_steps: int = 20,
        num_hypotheses: int = 4,
        top_k: int = 1,
        head_hidden: int = 64,
        temperature_init: float = 1.0,
        use_act: bool = False,
        lambda_p: float = 0.5,
        coord_scale: float = 1.0,
        K_color: int = 4,
        num_attn_heads: int = 4,
        attn_head_dim: int = 8,
        num_rule_slots: int = 8,
    ):
        super().__init__()
        self.algebra_cpu = algebra_cpu
        self.algebra_ctrl = algebra_ctrl
        self.channels = channels

        D_cpu = algebra_cpu.dim  # 16

        # Grid codec (deterministic, no params)
        self.codec = GridCodec(algebra_cpu, coord_scale)

        # Learnable initial control cursor [4] in Cl(1,1)
        self.init_cursor = nn.Parameter(torch.randn(4) * 0.01)

        # Learnable role markers injected into geometrically reserved slots:
        #   idx 4 (e2): reserved auxiliary vector — never used by GridCodec
        #   idx 15 (pseudoscalar e0123): occupancy/role flag
        # Shape: [3, 2] for (e2_value, pseudoscalar_value) per role
        # role 0 = demo input, role 1 = demo output, role 2 = test input
        self.role_embed = nn.Parameter(torch.randn(3, 2) * 0.01)

        # Rule Memory Aggregator: compresses Phase-1 CPU state into slots.
        self.rule_aggregator = RuleAggregator(
            d_cpu=D_cpu, num_slots=num_rule_slots, num_heads=num_attn_heads,
        )

        # Turing VM: shared between the demo and test phases.
        self.vm = TuringVM(
            algebra_cpu, algebra_ctrl,
            channels, num_steps, max_steps,
            num_hypotheses, top_k, temperature_init,
            use_act, lambda_p,
            num_attn_heads, attn_head_dim,
            K_color, num_rule_slots,
        )

        # Per-cell color classification head.
        self.head = GridReconstructionHead(algebra_cpu, head_hidden)

    def forward(self, demo_inputs: torch.Tensor, demo_outputs: torch.Tensor,
                demo_masks: torch.Tensor,
                test_inputs: torch.Tensor,
                test_masks: torch.Tensor, num_demos: torch.Tensor,
                demo_output_masks: torch.Tensor = None,
                input_sizes: list = None,
                return_trace: bool = False) -> dict:
        """Two-phase forward pass: Rule Inference -> Rule Application.

        Phase 1 processes demo pairs through the VM to encode transformation
        patterns. RuleAggregator compresses these into rule_memory slots.
        Phase 2 processes test input using ctrl_cursor + rule_memory.

        Args:
            demo_inputs: [B, K, H_max, W_max] padded demo input grids.
            demo_outputs: [B, K, H_max, W_max] padded demo output grids.
            demo_masks: [B, K, H_max, W_max] bool (True=valid input cell).
            test_inputs: [B, H_max, W_max] padded test input.
            test_masks: [B, H_max, W_max] bool (True=valid).
            num_demos: [B] int — actual demo count per example.
            demo_output_masks: [B, K, H_max, W_max] bool (True=valid output cell).
                If None, falls back to demo_masks (same dims assumed).
            input_sizes: Optional list of (H, W) for test inputs.
            return_trace: Collect per-step diagnostics.

        Returns:
            dict with:
              'logits': [B, N_test, 10] color logits for test cells
              'test_flat_masks': [B, N_test] bool
              optionally 'act_info', 'trace'
        """
        B, K, H_max, W_max = demo_inputs.shape
        N_grid = H_max * W_max
        device = demo_inputs.device
        D_cpu = self.algebra_cpu.dim  # 16

        if demo_output_masks is None:
            demo_output_masks = demo_masks

        # --- Encode demo pairs ---
        di_flat = demo_inputs.reshape(B * K, H_max, W_max)
        # clamp(min=0): padded output cells presumably use a negative
        # sentinel — map them to background before encoding. TODO confirm
        # against the collate function.
        do_flat = demo_outputs.clamp(min=0).reshape(B * K, H_max, W_max)
        dim_flat = demo_masks.reshape(B * K, H_max, W_max)
        dom_flat = demo_output_masks.reshape(B * K, H_max, W_max)

        di_mv, di_fm = self.codec.encode_batch(di_flat, dim_flat)  # [B*K, N_grid, 16]
        do_mv, do_fm = self.codec.encode_batch(do_flat, dom_flat)  # [B*K, N_grid, 16]

        # Add role markers into reserved slots.
        # NOTE(review): the markers are added to every cell, including
        # padding cells that encode_batch just zeroed — confirm intended.
        di_mv[:, :, 4] = di_mv[:, :, 4] + self.role_embed[0, 0]
        di_mv[:, :, 15] = di_mv[:, :, 15] + self.role_embed[0, 1]
        do_mv[:, :, 4] = do_mv[:, :, 4] + self.role_embed[1, 0]
        do_mv[:, :, 15] = do_mv[:, :, 15] + self.role_embed[1, 1]

        # Interleave demo input + output: [B*K, 2*N_grid, 16]
        demo_mv = torch.cat([di_mv, do_mv], dim=1)
        demo_fm = torch.cat([di_fm, do_fm], dim=1)

        # Reshape: [B, K * 2 * N_grid, 16]
        N_demo_per_pair = 2 * N_grid
        demo_mv = demo_mv.reshape(B, K * N_demo_per_pair, D_cpu)
        demo_fm = demo_fm.reshape(B, K * N_demo_per_pair)

        # Mask out unused demo pairs — vectorized, no .item() calls.
        # Positions beyond num_demos[b] * N_demo_per_pair are zeroed and
        # removed from the flat mask.
        total_demo_len = K * N_demo_per_pair
        pos_idx = torch.arange(total_demo_len, device=device).unsqueeze(0)  # [1, L]
        limit = (num_demos * N_demo_per_pair).unsqueeze(1)  # [B, 1]
        valid_demo = pos_idx < limit  # [B, L]
        demo_mv = demo_mv * valid_demo.unsqueeze(-1).float()
        demo_fm = demo_fm & valid_demo

        # --- Encode test input ---
        test_mv, test_fm = self.codec.encode_batch(test_inputs, test_masks)
        test_mv[:, :, 4] = test_mv[:, :, 4] + self.role_embed[2, 0]
        test_mv[:, :, 15] = test_mv[:, :, 15] + self.role_embed[2, 1]

        # --- Init control cursor ---
        ctrl_cursor = self.init_cursor.unsqueeze(0).expand(B, -1).clone()

        # === Phase 1: Rule Inference (demo only) ===
        # VM processes demo pairs -> cpu_state encodes patterns, ctrl_cursor updated
        demo_state, ctrl_cursor, act_info_demo, trace_demo = self.vm(
            demo_mv, ctrl_cursor, demo_fm, return_trace,
        )

        # Compress demo state into rule memory slots
        rule_memory = self.rule_aggregator(demo_state, demo_fm)  # [B, M, 16]

        # === Phase 2: Rule Application (test only) ===
        # VM processes test input using ctrl_cursor + rule_memory from Phase 1
        test_state, ctrl_cursor, act_info_test, trace_test = self.vm(
            test_mv, ctrl_cursor, test_fm, return_trace,
            rule_memory=rule_memory,
        )

        # --- Decode ---
        logits = self.head(test_state, test_fm)  # [B, N_grid, 10]

        result = {
            'logits': logits,
            'test_flat_masks': test_fm,
        }

        # ACT info: combine KL loss from both phases.
        # NOTE(review): assumes act_info_demo is non-None whenever
        # act_info_test is (both come from the same VM/use_act flag).
        if act_info_test is not None:
            result['act_info'] = {
                'kl_loss': act_info_test['kl_loss'] + act_info_demo['kl_loss'],
                'expected_steps': act_info_test['expected_steps'],
                'weights': act_info_test['weights'],
            }

        # Merge traces from both phases (demo steps first, then test steps).
        if return_trace:
            trace_keys = ['search_scores', 'search_weights', 'halt_probs',
                          'cursors', 'gate_values']
            trace = {k: [] for k in trace_keys}
            for t in (trace_demo, trace_test):
                if t is not None:
                    for k in trace_keys:
                        trace[k].extend(t.get(k, []))
            result['trace'] = trace

        return result

    def freeze_vm(self):
        """Freeze all VM parameters (Phase 1: warmup)."""
        for param in self.vm.parameters():
            param.requires_grad = False

    def unfreeze_vm(self):
        """Unfreeze all VM parameters (Phase 2+)."""
        for param in self.vm.parameters():
            param.requires_grad = True

    def enable_act(self):
        """Enable adaptive computation time (only if the VM has a halt module)."""
        if self.vm.adaptive_halt is not None:
            self.vm.use_act = True

    def disable_act(self):
        """Disable adaptive computation time."""
        self.vm.use_act = False

    def trainable_parameters(self):
        """Yield only the parameters with requires_grad=True."""
        for param in self.parameters():
            if param.requires_grad:
                yield param

    def count_trainable_params(self) -> int:
        """Total number of trainable scalar parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
+ """ + return self.mlp(cpu_state) diff --git a/models/gtm/rule_memory.py b/models/gtm/rule_memory.py new file mode 100644 index 0000000..f9e28cc --- /dev/null +++ b/models/gtm/rule_memory.py @@ -0,0 +1,88 @@ +# Versor: Universal Geometric Algebra Neural Network +# Copyright (C) 2026 Eunkyum Kim +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# + +"""Rule Memory Bank — cross-attention aggregator for demo→test information flow. + +Compresses Phase 1 (demo) CPU state into M learnable rule slots via +cross-attention. This replaces the 4-float ctrl_cursor bottleneck as the +primary information bridge between demo and test phases. + +Information capacity: M=8 slots * 16 dims = 128 floats (vs 4 floats before). +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class RuleAggregator(nn.Module): + """Cross-attention from M learnable queries to demo cpu_state. + + Compresses Phase 1 output into M rule slots that encode the + transformation rule learned from demo pairs. + """ + + def __init__(self, d_cpu: int = 16, num_slots: int = 8, num_heads: int = 4): + super().__init__() + self.d_cpu = d_cpu + self.num_slots = num_slots + self.num_heads = num_heads + self.head_dim = d_cpu // num_heads + assert d_cpu % num_heads == 0, f"d_cpu={d_cpu} must be divisible by num_heads={num_heads}" + + self.scale = self.head_dim ** -0.5 + + # Learnable query templates + self.query_templates = nn.Parameter(torch.randn(num_slots, d_cpu) * 0.02) + + # Projections for cross-attention + self.q_proj = nn.Linear(d_cpu, d_cpu) + self.k_proj = nn.Linear(d_cpu, d_cpu) + # V = raw demo state (no projection — preserves geometric structure) + + def forward(self, demo_cpu_state: torch.Tensor, + demo_mask: torch.Tensor) -> torch.Tensor: + """Aggregate demo state into rule memory slots. + + Args: + demo_cpu_state: [B, N_demo, 16] CPU state after Phase 1. 
+ demo_mask: [B, N_demo] bool (True=valid). + + Returns: + rule_memory: [B, M, 16] compressed rule representation. + """ + B, N_demo, D = demo_cpu_state.shape + M = self.num_slots + H = self.num_heads + hd = self.head_dim + + # Query from learnable templates: [M, D] -> [B, M, D] + Q = self.q_proj(self.query_templates).unsqueeze(0).expand(B, -1, -1) + # Key from demo state: [B, N_demo, D] + K = self.k_proj(demo_cpu_state) + + # Multi-head reshape + Q = Q.reshape(B, M, H, hd).transpose(1, 2) # [B, H, M, hd] + K = K.reshape(B, N_demo, H, hd).transpose(1, 2) # [B, H, N_demo, hd] + + # Attention scores + scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale # [B, H, M, N_demo] + + # Mask invalid demo cells + if demo_mask is not None: + pad_mask = ~demo_mask # [B, N_demo] + scores = scores.masked_fill( + pad_mask.unsqueeze(1).unsqueeze(2), float('-inf') + ) + + attn = F.softmax(scores, dim=-1) # [B, H, M, N_demo] + attn_avg = attn.mean(dim=1) # [B, M, N_demo] + + # Values: raw demo state (preserves geometric structure) + rule_memory = torch.bmm(attn_avg, demo_cpu_state) # [B, M, 16] + + return rule_memory diff --git a/models/gtm/superposition.py b/models/gtm/superposition.py new file mode 100644 index 0000000..dba3470 --- /dev/null +++ b/models/gtm/superposition.py @@ -0,0 +1,127 @@ +# Versor: Universal Geometric Algebra Neural Network +# Copyright (C) 2026 Eunkyum Kim +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# + +"""Geometric Superposition Search — simplified scoring via CPU grade norms. + +Scores K instruction hypotheses using CPU state grade norms + ctrl_cursor, +dispatches trainable instruction templates (optionally modulated by rule memory) +to the PGA CPU, executes K outcomes in parallel, and selects via Gumbel-Softmax. + +Mother algebra is no longer needed — scoring uses CPU grade norms directly. 
+""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from core.algebra import CliffordAlgebra +from .cpu import GeometricCPU + + +class GeometricSuperpositionSearch(nn.Module): + """Geometric Superposition Search over CPU Cl(3,0,1). + + Trainable parameters: + instruction_templates: [K, 16] full Cl(3,0,1) multivectors + score_mlp: CPU grade norms + ctrl_cursor -> K scores + rule_proj: rule_memory -> per-template modulation (if rule_memory provided) + log_temperature: Gumbel-Softmax temperature (learnable) + """ + + def __init__(self, algebra_cpu: CliffordAlgebra, + algebra_ctrl: CliffordAlgebra, + channels: int, + num_hypotheses: int = 4, + top_k: int = 1, + temperature_init: float = 1.0, + K_color: int = 4, + num_rule_slots: int = 8): + super().__init__() + self.algebra_cpu = algebra_cpu + self.algebra_ctrl = algebra_ctrl + self.channels = channels + self.num_hypotheses = num_hypotheses + self.top_k = top_k + + D_cpu = algebra_cpu.dim # 16 + + # CPU engine (has ColorUnit params) + self.pga_cpu = GeometricCPU(algebra_cpu, K_color) + + # Trainable instruction templates — full Cl(3,0,1) multivectors + self.instruction_templates = nn.Parameter( + torch.randn(num_hypotheses, D_cpu) * 0.01 + ) + + # Scoring MLP: CPU grade norms + ctrl_cursor -> K scores + cpu_grades = algebra_cpu.num_grades # 5 for Cl(3,0,1) + self.score_mlp = nn.Sequential( + nn.Linear(cpu_grades + algebra_ctrl.dim, 64), + nn.ReLU(), + nn.Linear(64, num_hypotheses), + ) + + # Rule-conditioned instruction modulation + self.rule_proj = nn.Linear(D_cpu, num_hypotheses * D_cpu) + + # Gumbel temperature (learnable) + self.log_temperature = nn.Parameter( + torch.tensor(float(torch.tensor(temperature_init).log())) + ) + + def step(self, cpu_state: torch.Tensor, + ctrl_cursor: torch.Tensor, + rule_memory: torch.Tensor = None) -> tuple: + """One superposition search step. + + Args: + cpu_state: [B, N, 16] CPU state in Cl(3,0,1). + ctrl_cursor: [B, 4] control cursor in Cl(1,1). 
+ rule_memory: Optional [B, M, 16] rule slots from RuleAggregator. + + Returns: + Tuple of (new_cpu_state [B, N, 16], search_info dict). + """ + B, N, D_cpu = cpu_state.shape + device = cpu_state.device + K = self.num_hypotheses + + # STEP 1 — SCORE: CPU grade norms + ctrl_cursor + cpu_summary = cpu_state.mean(dim=1) # [B, 16] + self.algebra_cpu.ensure_device(device) + cpu_grade_norms = self.algebra_cpu.get_grade_norms(cpu_summary) # [B, 5] + score_input = torch.cat([cpu_grade_norms, ctrl_cursor], dim=-1) # [B, 9] + scores = self.score_mlp(score_input) # [B, K] + + # STEP 2 — DISPATCH: templates optionally modulated by rule memory + templates = self.instruction_templates.unsqueeze(0).expand(B, -1, -1) # [B, K, 16] + + if rule_memory is not None: + rule_summary = rule_memory.mean(dim=1) # [B, 16] + rule_features = self.rule_proj(rule_summary) # [B, K * 16] + rule_modulation = rule_features.view(B, K, D_cpu) # [B, K, 16] + templates = templates + rule_modulation + + # Score-dependent modulation + instructions = scores.unsqueeze(-1) * templates # [B, K, 16] + + # STEP 3 — EXECUTE: CPU applies PGA Motor + ColorUnit, K× batched + outcomes = self.pga_cpu.execute_all(cpu_state, instructions) # [B, K, N, 16] + + # STEP 4 — SELECT: Gumbel-Softmax, differentiable discrete selection + tau = self.log_temperature.exp().clamp(0.1, 5.0) + weights = F.gumbel_softmax(scores, tau=tau, hard=False) # [B, K] + + # Weighted sum via einsum (no Python loop) + new_cpu_state = torch.einsum('bk,bknd->bnd', weights, outcomes) + + search_info = { + 'scores': scores, + 'weights': weights, + 'temperature': tau.detach(), + } + + return new_cpu_state, search_info diff --git a/models/gtm/turing_step.py b/models/gtm/turing_step.py new file mode 100644 index 0000000..f810da1 --- /dev/null +++ b/models/gtm/turing_step.py @@ -0,0 +1,225 @@ +# Versor: Universal Geometric Algebra Neural Network +# Copyright (C) 2026 Eunkyum Kim +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you 
# Versor: Universal Geometric Algebra Neural Network
# Copyright (C) 2026 Eunkyum Kim
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
#

"""Single GTM step: SuperpositionSearch + Cross-Grade Attention + ControlPlane.

Key design choices:
  - NO additive residual (addition destroys geometric structure after
    rotations); the instruction can instead learn B~0 to approximate identity.
  - Cross-grade dense Q/K attention: allows learning diagonal, distance,
    and other 2D spatial relationships.
  - Values remain raw multivectors with per-grade gain (preserves geometry).
  - Scalar write gate interpolates old/new state instead of an additive skip.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from core.algebra import CliffordAlgebra
from layers.primitives.normalization import CliffordLayerNorm
from .superposition import GeometricSuperpositionSearch
from .control_plane import ControlPlane


# Grade of each of the 16 basis blades of Cl(3,0,1), indexed by blade position:
# grade 0: [0]; grade 1: [1, 2, 4, 8]; grade 2: [3, 5, 6, 9, 10, 12];
# grade 3: [7, 11, 13, 14]; grade 4: [15].  (Index 0 stays 0 from the init.)
_GRADE_MAP_16 = torch.zeros(16, dtype=torch.long)
_GRADE_MAP_16[[1, 2, 4, 8]] = 1
_GRADE_MAP_16[[3, 5, 6, 9, 10, 12]] = 2
_GRADE_MAP_16[[7, 11, 13, 14]] = 3
_GRADE_MAP_16[15] = 4


class CellAttention(nn.Module):
    """Cross-grade self-attention over grid cells in Cl(3,0,1).

    Dense Q/K projections allow learning cross-grade features like
    diagonals (e0+e1), distances, and 2D spatial relationships.
    Values remain raw multivectors with a per-grade gain so the convex
    combination preserves geometric structure.
    """

    def __init__(self, algebra_cpu: "CliffordAlgebra", num_heads: int = 4,
                 head_dim: int = 8, dropout: float = 0.0):
        super().__init__()
        D = algebra_cpu.dim  # 16
        attn_dim = num_heads * head_dim

        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5

        # Dense Q, K projections: allow cross-grade mixing for scoring.
        self.q_proj = nn.Linear(D, attn_dim)
        self.k_proj = nn.Linear(D, attn_dim)

        # One learnable isotropic gain per grade, applied to values.
        self.v_gain = nn.ParameterDict({
            f'g{k}': nn.Parameter(torch.ones(1)) for k in range(5)
        })

        self.dropout = nn.Dropout(dropout)

        # Grade map buffer for applying per-grade gains.
        self.register_buffer('grade_map', _GRADE_MAP_16.clone())

    def _apply_grade_gains(self, x: torch.Tensor) -> torch.Tensor:
        """Scale each multivector component by its grade's learned gain.

        A single gather through ``grade_map`` replaces the previous
        per-grade masked in-place writes; the final cast keeps the output
        in x's dtype (the gains are fp32 parameters, while x may be
        reduced precision under AMP — without the cast the product would
        be promoted to fp32).
        """
        gains = torch.cat([self.v_gain[f'g{k}'] for k in range(5)])  # [5]
        return x * gains[self.grade_map].to(dtype=x.dtype)           # [16] broadcast

    def forward(self, x: torch.Tensor,
                mask: torch.Tensor = None) -> torch.Tensor:
        """Cross-grade self-attention over cells.

        Args:
            x: [B, N, 16] multivectors in Cl(3,0,1).
            mask: [B, N] bool, True=valid.

        Returns:
            [B, N, 16] attended multivectors.

        NOTE(review): if a sample's mask is all-False, every score row is
        -inf and softmax yields NaN — assumes callers keep >=1 valid cell.
        """
        B, N, D = x.shape

        # Dense Q, K projections.
        Q = self.q_proj(x)  # [B, N, attn_dim]
        K = self.k_proj(x)  # [B, N, attn_dim]

        # Multi-head reshape.
        Q = Q.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)  # [B, H, N, hd]
        K = K.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)  # [B, H, N, hd]

        # Scaled dot-product attention.
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale  # [B, H, N, N]

        if mask is not None:
            pad_mask = ~mask
            # [B, 1, 1, N]: mask invalid keys for every head/query.
            scores = scores.masked_fill(
                pad_mask.unsqueeze(1).unsqueeze(2), float('-inf')
            )

        attn = F.softmax(scores, dim=-1)  # [B, H, N, N]
        attn = self.dropout(attn)
        attn_avg = attn.mean(dim=1)  # [B, N, N]

        # Values: raw multivectors, weighted average (convex combination).
        attended = torch.bmm(attn_avg, x)  # [B, N, 16]

        # Per-grade gain on output.
        return self._apply_grade_gains(attended)


class TuringStep(nn.Module):
    """One step of the Geometric Turing Machine.

    Composes:
        1. Cell attention (cross-cell communication with cross-grade features)
        2. Superposition search (per-cell transformation via PGA motor)
        3. Scalar write gate interpolating old/new state (no additive residual)
        4. CliffordLayerNorm per cell
        5. Control plane step
    """

    def __init__(self, algebra_cpu: "CliffordAlgebra",
                 algebra_ctrl: "CliffordAlgebra",
                 channels: int,
                 num_hypotheses: int = 4,
                 top_k: int = 1,
                 temperature_init: float = 1.0,
                 num_attn_heads: int = 4,
                 attn_head_dim: int = 8,
                 attn_dropout: float = 0.0,
                 K_color: int = 4,
                 num_rule_slots: int = 8):
        super().__init__()
        self.channels = channels
        D_cpu = algebra_cpu.dim  # 16

        # Cell-to-cell attention (cross-grade features).
        self.cell_attn = CellAttention(
            algebra_cpu, num_attn_heads, attn_head_dim, attn_dropout,
        )

        # Superposition search module (no mother algebra).
        self.search = GeometricSuperpositionSearch(
            algebra_cpu, algebra_ctrl,
            channels, num_hypotheses, top_k, temperature_init,
            K_color, num_rule_slots,
        )

        # Control plane.
        self.control = ControlPlane(algebra_ctrl, channels)

        # CPU state normalization (C=1: applied per cell).
        self.norm = CliffordLayerNorm(algebra_cpu, 1)

        # Context projection: cpu summary -> ctrl context.
        self.context_proj = nn.Linear(D_cpu, channels)

        # Scalar write gate per cell, conditioned on concat(old, new).
        self.write_gate = nn.Sequential(
            nn.Linear(D_cpu * 2, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, cpu_state: torch.Tensor,
                ctrl_cursor: torch.Tensor,
                mask: torch.Tensor = None,
                rule_memory: torch.Tensor = None) -> dict:
        """Execute one GTM step.

        Args:
            cpu_state: [B, N, 16] CPU state in Cl(3,0,1).
            ctrl_cursor: [B, 4] control cursor in Cl(1,1).
            mask: Optional [B, N] validity mask (True=valid).
            rule_memory: Optional [B, M, 16] rule slots from RuleAggregator.

        Returns:
            dict with 'cpu_state', 'ctrl_cursor', 'halt_prob',
            'search_info', and 'gate_values'.
        """
        old_state = cpu_state

        # 1. Cell attention (cross-cell communication).
        attended = self.cell_attn(cpu_state, mask)

        # 2. Superposition search (per-cell transformation via PGA motor).
        new_cpu, search_info = self.search.step(attended, ctrl_cursor, rule_memory)

        # 3. Write gate: convex blend of old/new (NO additive residual —
        #    addition would destroy geometric structure after rotations).
        gate_input = torch.cat([old_state, new_cpu], dim=-1)  # [B, N, 32]
        gate = torch.sigmoid(self.write_gate(gate_input))  # [B, N, 1]
        new_cpu = gate * new_cpu + (1.0 - gate) * old_state

        # 4. CliffordLayerNorm — the norm expects [B*N, C=1, D] per cell.
        B, N, D = new_cpu.shape
        new_cpu = self.norm(new_cpu.reshape(B * N, 1, D)).reshape(B, N, D)

        # 5. Control plane step conditioned on a mean-pooled CPU summary.
        cpu_summary = new_cpu.mean(dim=1)  # [B, 16]
        cpu_context = self.context_proj(cpu_summary)  # [B, channels]
        new_cursor, direction_logit, halt_prob = self.control.step(
            ctrl_cursor, cpu_context
        )

        return {
            'cpu_state': new_cpu,
            'ctrl_cursor': new_cursor,
            'halt_prob': halt_prob,
            'search_info': search_info,
            'gate_values': gate,
        }
+ """ + + def __init__(self, algebra_cpu: CliffordAlgebra, + algebra_ctrl: CliffordAlgebra, + channels: int, + num_steps: int = 8, + max_steps: int = 20, + num_hypotheses: int = 4, + top_k: int = 1, + temperature_init: float = 1.0, + use_act: bool = False, + lambda_p: float = 0.5, + num_attn_heads: int = 4, + attn_head_dim: int = 8, + K_color: int = 4, + num_rule_slots: int = 8): + super().__init__() + self.channels = channels + self.num_steps = num_steps + self.max_steps = max_steps + self.use_act = use_act + + # Create steps up to max_steps (ACT) or num_steps (fixed) + effective_steps = max_steps if use_act else num_steps + self.steps = nn.ModuleList([ + TuringStep( + algebra_cpu, algebra_ctrl, + channels, num_hypotheses, top_k, temperature_init, + num_attn_heads, attn_head_dim, 0.0, + K_color, num_rule_slots, + ) + for _ in range(effective_steps) + ]) + + # Adaptive halt controller + self.adaptive_halt = AdaptiveHalt(lambda_p, max_steps) if use_act else None + + # Final normalization on CPU state + self.final_norm = CliffordLayerNorm(algebra_cpu, 1) + + def forward(self, cpu_state: torch.Tensor, ctrl_cursor: torch.Tensor, + mask: torch.Tensor = None, + return_trace: bool = False, + rule_memory: torch.Tensor = None) -> tuple: + """Execute the GTM program. + + Args: + cpu_state: Initial CPU state [B, N, 16]. + ctrl_cursor: Initial control cursor [B, 4]. + mask: Optional validity mask [B, N] (True=valid). + return_trace: If True, collect per-step diagnostics. + rule_memory: Optional [B, M, 16] rule slots from RuleAggregator. + + Returns: + Tuple of (cpu_state, ctrl_cursor, act_info or None, trace or None). 
+ """ + trace = { + 'search_scores': [], + 'search_weights': [], + 'halt_probs': [], + 'cursors': [], + 'gate_values': [], + } if return_trace else None + + if self.use_act: + return self._forward_act(cpu_state, ctrl_cursor, mask, trace, rule_memory) + else: + return self._forward_fixed(cpu_state, ctrl_cursor, mask, trace, rule_memory) + + def _forward_fixed(self, cpu_state, ctrl_cursor, mask, trace, rule_memory): + """Fixed-step execution.""" + for i in range(self.num_steps): + result = self.steps[i](cpu_state, ctrl_cursor, mask, rule_memory) + cpu_state = result['cpu_state'] + ctrl_cursor = result['ctrl_cursor'] + + if trace is not None: + trace['search_scores'].append(result['search_info']['scores'].detach()) + trace['search_weights'].append(result['search_info']['weights'].detach()) + trace['halt_probs'].append(result['halt_prob'].detach()) + trace['cursors'].append(ctrl_cursor.detach()) + trace['gate_values'].append(result['gate_values'].detach()) + + # Final norm + B, N, D = cpu_state.shape + cpu_state = self.final_norm( + cpu_state.reshape(B * N, 1, D) + ).reshape(B, N, D) + + return cpu_state, ctrl_cursor, None, trace + + def _forward_act(self, cpu_state, ctrl_cursor, mask, trace, rule_memory): + """Adaptive computation with PonderNet halting.""" + per_step_outputs = [] + halt_probs = [] + + for i, step in enumerate(self.steps): + result = step(cpu_state, ctrl_cursor, mask, rule_memory) + cpu_state = result['cpu_state'] + ctrl_cursor = result['ctrl_cursor'] + + per_step_outputs.append(cpu_state) + halt_probs.append(result['halt_prob']) + + if trace is not None: + trace['search_scores'].append(result['search_info']['scores'].detach()) + trace['search_weights'].append(result['search_info']['weights'].detach()) + trace['halt_probs'].append(result['halt_prob'].detach()) + trace['cursors'].append(ctrl_cursor.detach()) + trace['gate_values'].append(result['gate_values'].detach()) + + # Compute ACT mixing weights + act_result = self.adaptive_halt(halt_probs) + 
weights = act_result['weights'] # [B, T] + + # Weighted sum of per-step CPU states via einsum (no Python loop) + stacked = torch.stack(per_step_outputs, dim=1) # [B, T, N, D] + output = torch.einsum('bt,btnd->bnd', weights, stacked) + + # Final norm + B, N, D = output.shape + output = self.final_norm( + output.reshape(B * N, 1, D) + ).reshape(B, N, D) + + act_info = { + 'kl_loss': act_result['kl_loss'], + 'expected_steps': act_result['expected_steps'], + 'weights': act_result['weights'], + } + + # ctrl_cursor is from last step (not mixed — control is sequential) + return output, ctrl_cursor, act_info, trace diff --git a/models/vm/__init__.py b/models/vm/__init__.py new file mode 100644 index 0000000..9cd7a49 --- /dev/null +++ b/models/vm/__init__.py @@ -0,0 +1,19 @@ +# Versor: Universal Geometric Algebra Neural Network +# Copyright (C) 2026 Eunkyum Kim +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# + +"""Virtual Machine components: LLM bridge, grade-aware projections, attention.""" + +from .bridge import LLMBridge +from .projections import GradeAwareProjectionIn, GradeWeightedProjectionOut +from .attention import GradeMaskedAttention + +__all__ = [ + "LLMBridge", + "GradeAwareProjectionIn", + "GradeWeightedProjectionOut", + "GradeMaskedAttention", +] diff --git a/models/vm/attention.py b/models/vm/attention.py new file mode 100644 index 0000000..b61ebc9 --- /dev/null +++ b/models/vm/attention.py @@ -0,0 +1,75 @@ +# Versor: Universal Geometric Algebra Neural Network +# Copyright (C) 2026 Eunkyum Kim +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# + +"""Grade-masked multihead attention for the Geometric Turing Machine. + +Operates on grade-1 + grade-2 components (the "heap") of Clifford multivectors, +leaving other grades untouched. 
# Versor: Universal Geometric Algebra Neural Network
# Copyright (C) 2026 Eunkyum Kim
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
#

"""Grade-masked multihead attention for the Geometric Turing Machine.

Operates on grade-1 + grade-2 components (the "heap") of Clifford multivectors,
leaving other grades untouched. This focuses attention on the most
computation-relevant subspaces: vectors (data bus) and bivectors (instructions).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class GradeMaskedAttention(nn.Module):
    """Multihead attention restricted to the grade-1/grade-2 subspace.

    The grade-1 and grade-2 coefficients are gathered into a flat "heap"
    vector per token, run through standard scaled-dot-product MHA, and
    scattered back into the multivector; every other grade is returned
    unchanged.
    """

    def __init__(self, algebra, channels, num_heads=4, dropout=0.0):
        super().__init__()
        # Indices of the grade-1 and grade-2 basis blades.
        heap_mask = algebra.grade_masks[1] | algebra.grade_masks[2]
        heap_idx = heap_mask.nonzero(as_tuple=False).squeeze(-1)
        self.register_buffer('heap_idx', heap_idx)
        self.heap_dim = len(heap_idx)
        self.channels = channels
        self.num_heads = num_heads
        self.algebra_dim = algebra.dim

        # All four projections act on the flattened heap (channels * heap_dim).
        width = channels * self.heap_dim
        self.q_proj = nn.Linear(width, width)
        self.k_proj = nn.Linear(width, width)
        self.v_proj = nn.Linear(width, width)
        self.out_proj = nn.Linear(width, width)
        self.dropout = nn.Dropout(dropout)
        self.head_dim = width // num_heads

    def _split_heads(self, t, batch, length):
        """[B, L, width] -> [B, H, L, head_dim]."""
        return t.reshape(batch, length, self.num_heads,
                         self.head_dim).transpose(1, 2)

    def forward(self, mv, key_padding_mask=None):
        """Attend over the heap subspace of ``mv`` ([B, L, C, D]).

        key_padding_mask: optional [B, L] bool, True = padded position
        (masked out of the keys for every head and query).
        """
        batch, length, chans, _ = mv.shape
        # Gather the heap coefficients and flatten channels into features.
        heap_flat = mv[..., self.heap_idx].reshape(batch, length, -1)

        q = self._split_heads(self.q_proj(heap_flat), batch, length)
        k = self._split_heads(self.k_proj(heap_flat), batch, length)
        v = self._split_heads(self.v_proj(heap_flat), batch, length)

        logits = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if key_padding_mask is not None:
            logits = logits.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf')
            )
        probs = self.dropout(F.softmax(logits, dim=-1))

        mixed = torch.matmul(probs, v).transpose(1, 2).reshape(batch, length, -1)
        new_heap = self.out_proj(mixed).reshape(batch, length, chans, self.heap_dim)

        # Scatter the attended heap back; untouched grades come from the clone.
        result = mv.clone()
        idx = self.heap_idx.expand(batch, length, chans, -1)
        result.scatter_(3, idx, new_heap)
        return result


class LLMBridge(nn.Module):
    """Frozen GPT-2 prefix bridge.

    Keeps the full GPT-2 model for state_dict compatibility, but only the
    hidden state after the first ``prefix_layers`` transformer blocks is
    returned from forward.
    """

    def __init__(self, model_name='gpt2', prefix_layers=4):
        super().__init__()
        from transformers import GPT2Model
        self.model = GPT2Model.from_pretrained(model_name)
        self.hidden_dim = self.model.config.hidden_size
        self.prefix_layers = prefix_layers
        # Freeze every parameter — the bridge is never trained.
        for p in self.model.parameters():
            p.requires_grad = False

    def forward(self, input_ids, attention_mask=None):
        """Return the hidden state after ``prefix_layers`` blocks (no ln_f)."""
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )
        # hidden_states[0] is the embedding output, so index prefix_layers
        # is the state after that many transformer blocks. ln_f is NOT
        # applied — this is an intermediate representation, not the final one.
        return outputs.hidden_states[self.prefix_layers]
# Versor: Universal Geometric Algebra Neural Network
# Copyright (C) 2026 Eunkyum Kim
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
#

"""Grade-aware projections between LLM hidden states and Clifford multivectors."""

import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.primitives.normalization import CliffordLayerNorm


class GradeAwareProjectionIn(nn.Module):
    """Project LLM hidden states into grade-1 multivectors.

    Linear(llm_dim -> channels * g1_dim), scattered into the grade-1 blades
    of a zero multivector, followed by CliffordLayerNorm.
    """

    def __init__(self, algebra, llm_dim, channels):
        super().__init__()
        g1_idx = algebra.grade_masks[1].nonzero(as_tuple=False).squeeze(-1)
        self.register_buffer('g1_idx', g1_idx)
        self.g1_dim = len(g1_idx)
        self.channels = channels
        self.algebra_dim = algebra.dim
        self.linear = nn.Linear(llm_dim, channels * self.g1_dim)
        self.norm = CliffordLayerNorm(algebra, channels)

    def forward(self, x):
        """[B, L, llm_dim] -> normalized [B, L, C, D] multivectors."""
        batch, length, _ = x.shape
        g1_vals = self.linear(x).reshape(batch, length, self.channels, self.g1_dim)
        # Zero multivector, with the projection scattered into grade 1.
        mv = torch.zeros(batch, length, self.channels, self.algebra_dim,
                         device=x.device, dtype=x.dtype)
        idx = self.g1_idx.expand(batch, length, self.channels, -1)
        mv.scatter_(3, idx, g1_vals)
        # The norm expects [B*L, C, D].
        flat = self.norm(mv.reshape(batch * length, self.channels, self.algebra_dim))
        return flat.reshape(batch, length, self.channels, self.algebra_dim)


class GradeWeightedProjectionOut(nn.Module):
    """Project multivectors back to LLM dim with learned grade weights.

    One linear projection per grade; per-grade contributions are mixed via
    a softmax over a learnable weight vector, then LayerNorm'd.
    """

    def __init__(self, algebra, channels, llm_dim, task_type='generic'):
        super().__init__()
        self.channels = channels
        self.llm_dim = llm_dim
        self.num_grades = algebra.num_grades
        self.grade_weights = nn.Parameter(torch.zeros(algebra.num_grades))
        self.grade_projections = nn.ModuleList()
        for g in range(algebra.num_grades):
            g_mask = algebra.grade_masks[g]
            g_idx = g_mask.nonzero(as_tuple=False).squeeze(-1)
            # Blade indices for grade g, registered so they move with the module.
            self.register_buffer(f'_grade_idx_{g}', g_idx)
            self.grade_projections.append(
                nn.Linear(channels * int(g_mask.sum().item()), llm_dim)
            )
        self.layer_norm = nn.LayerNorm(llm_dim)

    def forward(self, mv):
        """[B, L, C, D] -> [B, L, llm_dim]."""
        batch, length, _, _ = mv.shape
        mix = F.softmax(self.grade_weights, dim=0)
        total = None
        for g, project in enumerate(self.grade_projections):
            idx = getattr(self, f'_grade_idx_{g}')
            term = mix[g] * project(mv[..., idx].reshape(batch, length, -1))
            total = term if total is None else total + term
        return self.layer_norm(total)
"License"); +# you may not use this file except in compliance with the License. +# + +"""Standalone GTM analysis script — run on a pretrained checkpoint. + +Usage: + uv run python scripts/analyze_gtm.py --checkpoint gtm_arc_best.pt + uv run python scripts/analyze_gtm.py --checkpoint gtm_arc_best.pt --device cuda + uv run python scripts/analyze_gtm.py --checkpoint gtm_arc_best.pt --n-batches 5 +""" + +import argparse +import sys +import torch + +sys.path.insert(0, '.') + +from models.gtm.analysis import GTMAnalyzer +from datalib.arc import get_arc_loaders + + +def main(): + parser = argparse.ArgumentParser(description='GTM Explainability Analysis') + parser.add_argument('--checkpoint', required=True, help='Path to checkpoint') + parser.add_argument('--device', default='cpu', help='Device (cpu/cuda)') + parser.add_argument('--data-dir', default='data/arc', help='ARC data directory') + parser.add_argument('--n-batches', type=int, default=1, + help='Number of batches to analyze') + parser.add_argument('--batch-size', type=int, default=4) + parser.add_argument('--output', default=None, help='Save report to file') + args = parser.parse_args() + + # Load model + print(f'Loading checkpoint: {args.checkpoint}') + analyzer = GTMAnalyzer.from_checkpoint(args.checkpoint, device=args.device) + + # Print static analysis (no data needed) + print() + print(analyzer.format_instruction_report()) + temp_info = analyzer.analyze_temperature() + print('=== Gumbel Temperature ===') + for i, t in enumerate(temp_info['temperatures']): + sharp = '*' if temp_info['is_sharp'][i] else '' + print(f' Step {i}: tau={t:.4f} {sharp}') + print() + + # Load validation data + loaders = get_arc_loaders( + data_dir=args.data_dir, + batch_size=args.batch_size, + include_toy=True, + toy_n_examples=500, + num_demos=3, + seed=123, + ) + val_loader = loaders['val'] + + # Dynamic analysis + full_text = [] + total_cell_acc = 0 + total_grid_correct = 0 + total_grid_count = 0 + + for batch_idx, batch in 
enumerate(val_loader): + if batch_idx >= args.n_batches: + break + + print(f'--- Batch {batch_idx} ---') + report_text = analyzer.full_report(batch) + print(report_text) + full_text.append(f'=== Batch {batch_idx} ===\n{report_text}') + + report = analyzer.analyze(batch) + total_cell_acc += report['cell_accuracy'] + total_grid_correct += report['grid_correct'].sum().item() + total_grid_count += report['grid_correct'].shape[0] + + # Summary + n = min(args.n_batches, len(val_loader)) + if n > 0: + print(f'\n=== Overall ({n} batches) ===') + print(f' Avg cell accuracy: {total_cell_acc / n:.4f}') + print(f' Grid correct: {total_grid_correct}/{total_grid_count}') + + if args.output: + with open(args.output, 'w') as f: + f.write('\n\n'.join(full_text)) + print(f'\nReport saved to {args.output}') + + +if __name__ == '__main__': + main() diff --git a/tasks/__init__.py b/tasks/__init__.py index b01e12a..94aa2cc 100644 --- a/tasks/__init__.py +++ b/tasks/__init__.py @@ -12,6 +12,7 @@ "MD17Task", "LQATask", "DEAPEEGTask", + "GTMTask", ] @@ -28,4 +29,7 @@ def __getattr__(name): if name == "DEAPEEGTask": from .deap_eeg import DEAPEEGTask return DEAPEEGTask + if name == "GTMTask": + from .gtm import GTMTask + return GTMTask raise AttributeError(f"module 'tasks' has no attribute {name!r}") diff --git a/tasks/gtm.py b/tasks/gtm.py new file mode 100644 index 0000000..6914f3f --- /dev/null +++ b/tasks/gtm.py @@ -0,0 +1,345 @@ +# Versor: Universal Geometric Algebra Neural Network +# Copyright (C) 2026 Eunkyum Kim +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# + +"""Geometric Turing Machine Task — ARC-AGI v4. + +Few-shot format: each training example = (demo_pairs, test_input, test_output). +The model sees K demo (input,output) pairs to infer the rule, then applies +it to a test input to produce the test output. + +Three-phase training (anti-lazy-optimization): +1. 
"""Geometric Turing Machine Task — ARC-AGI v4.

Few-shot format: each training example = (demo_pairs, test_input, test_output).
The model sees K demo (input,output) pairs to infer the rule, then applies
it to a test input to produce the test output.

Three-phase training (anti-lazy-optimization):
1. Warmup: freeze VM, train head + init_cursor + role_embed
2. Circuit Search: unfreeze VM, fixed steps, gate entropy loss
3. ACT: enable adaptive computation, KL ramp-up

Two algebras (Mother algebra removed):
    CPU Cl(3,0,1): PGA computation engine (motor + color)
    Control Cl(1,1): learnable search
"""

import torch
import torch.nn as nn
from tqdm import tqdm
from core.algebra import CliffordAlgebra
from tasks.base import BaseTask
from models.gtm import GTMNet
from log import get_logger

logger = get_logger(__name__)


def _gate_entropy_loss(scores: torch.Tensor) -> torch.Tensor:
    """Mean entropy of softmaxed search scores.

    Minimizing this pushes the hypothesis distribution toward one-hot,
    encouraging instruction specialization.
    """
    eps = 1e-8
    probs = torch.softmax(scores, dim=-1)
    entropy = -(probs * torch.log(probs + eps)).sum(dim=-1)
    return entropy.mean()


class GTMTask(BaseTask):
    """Geometric Turing Machine task for ARC-AGI v4."""

    def __init__(self, cfg):
        """Read phase-schedule hyperparameters, then defer to BaseTask."""
        self.warmup_epochs = cfg.training.get('warmup_epochs', 5)
        self.trim_epochs = cfg.training.get('trim_epochs', 50)
        self.act_epochs = cfg.training.get('act_epochs', 45)
        self.act_weight = cfg.training.get('act_weight', 0.01)
        self.act_ramp_epochs = cfg.training.get('act_ramp_epochs', 15)
        self.gate_entropy_weight = cfg.training.get('gate_entropy_weight', 0.001)
        self.grad_clip = cfg.training.get('grad_clip', 1.0)
        self.eval_every = cfg.training.get('eval_every', 5)

        super().__init__(cfg)

    def setup_algebra(self):
        """Initialize CPU and Control algebras. Returns CPU algebra for BaseTask."""
        self.algebra_cpu = CliffordAlgebra(3, 0, 1, device=self.device)
        self.algebra_ctrl = CliffordAlgebra(1, 1, 0, device=self.device)
        return self.algebra_cpu

    def setup_model(self):
        """Build GTMNet from the model config (with defaults for every knob)."""
        mcfg = self.cfg.model
        act_cfg = mcfg.get('act', {})
        color_cfg = mcfg.get('color_unit', {})
        attn_cfg = mcfg.get('attention', {})

        return GTMNet(
            algebra_cpu=self.algebra_cpu,
            algebra_ctrl=self.algebra_ctrl,
            channels=mcfg.get('channels', 16),
            num_steps=mcfg.get('num_steps', 8),
            max_steps=mcfg.get('max_steps', 20),
            num_hypotheses=mcfg.get('num_hypotheses', 4),
            top_k=mcfg.get('top_k', 1),
            head_hidden=mcfg.get('head_hidden', 64),
            temperature_init=mcfg.get('gumbel_temperature', 1.0),
            use_act=act_cfg.get('enabled', True),
            lambda_p=act_cfg.get('lambda_p', 0.5),
            coord_scale=mcfg.get('coord_scale', 1.0),
            K_color=color_cfg.get('K_color', 4),
            num_attn_heads=attn_cfg.get('num_heads', 4),
            attn_head_dim=attn_cfg.get('head_dim', 8),
            num_rule_slots=mcfg.get('num_rule_slots', 8),
        )

    def _setup_optimizer(self):
        """Build the optimizer over ONLY the currently-trainable parameters.

        Rebuilt at each phase transition since freeze/unfreeze changes the
        trainable set.
        """
        opt_type = self.cfg.training.get('optimizer_type', 'riemannian_adam')
        lr = self.cfg.training.lr
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]

        if not trainable_params:
            # Degenerate case (everything frozen): hand the optimizer one
            # dummy leaf that requires grad so construction/step are no-ops
            # instead of errors.
            return torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=lr)

        if opt_type == 'riemannian_adam':
            from optimizers.riemannian import RiemannianAdam
            return RiemannianAdam(
                trainable_params, lr=lr,
                betas=self.cfg.training.get('betas', (0.9, 0.999)),
                algebra=self.algebra,
                max_bivector_norm=self.cfg.training.get('max_bivector_norm', 10.0),
            )
        elif opt_type == 'exponential_sgd':
            from optimizers.riemannian import ExponentialSGD
            return ExponentialSGD(
                trainable_params, lr=lr,
                momentum=self.cfg.training.get('momentum', 0.9),
                algebra=self.algebra,
                max_bivector_norm=self.cfg.training.get('max_bivector_norm', 10.0),
            )
        else:
            return torch.optim.AdamW(trainable_params, lr=lr)

    def setup_criterion(self):
        # ignore_index=-1: padded target cells contribute no loss.
        return nn.CrossEntropyLoss(ignore_index=-1)

    def get_data(self):
        """Build the few-shot ARC train/val loaders from the dataset config."""
        from datalib.arc import get_arc_loaders
        dcfg = self.cfg.dataset
        loaders = get_arc_loaders(
            data_dir=dcfg.get('data_dir', 'data/arc'),
            batch_size=self.cfg.training.batch_size,
            include_toy=dcfg.get('include_toy', True),
            toy_n_examples=dcfg.get('toy_n_examples', 5000),
            toy_max_grid_size=dcfg.get('toy_max_grid_size', 10),
            num_workers=self.cfg.training.get('num_workers', 0),
            num_demos=dcfg.get('num_demos', 3),
            pin_memory=self.device_config.pin_memory,
            epoch_samples=dcfg.get('epoch_samples', 0),
        )
        return loaders['train'], loaders['val']

    def _run_model(self, batch, return_trace=False):
        """Move a few-shot batch to the device and run the model on it."""
        demo_inputs = batch['demo_inputs'].to(self.device)
        demo_outputs = batch['demo_outputs'].to(self.device)
        demo_masks = batch['demo_masks'].to(self.device)
        demo_output_masks = batch['demo_output_masks'].to(self.device)
        test_inputs = batch['test_inputs'].to(self.device)
        test_masks = batch['test_masks'].to(self.device)
        num_demos = batch['num_demos'].to(self.device)

        return self.model(
            demo_inputs, demo_outputs, demo_masks,
            test_inputs, test_masks, num_demos,
            demo_output_masks=demo_output_masks,
            input_sizes=batch.get('input_sizes'),
            return_trace=return_trace,
        )

    def train_step(self, batch):
        """One optimization step: CE loss + optional ACT-KL and gate-entropy terms."""
        self.optimizer.zero_grad(set_to_none=True)

        # The trace is only needed for the gate-entropy regularizer (phases 2-3).
        need_trace = (self._phase >= 2 and self.gate_entropy_weight > 0)
        result = self._run_model(batch, return_trace=need_trace)

        logits = result['logits']  # [B, N_grid, 10]

        # Target: test output grid flattened to match the logits layout.
        test_outputs = batch['test_outputs'].to(self.device)  # [B, H_max, W_max]
        B, H_max, W_max = test_outputs.shape
        targets = test_outputs.reshape(B, H_max * W_max)  # [B, N_grid]

        loss = self.criterion(
            logits.reshape(-1, 10),
            targets.reshape(-1),
        )

        # ACT KL loss (Phase 3 only; weight ramped by run()).
        act_kl = torch.tensor(0.0, device=self.device)
        if 'act_info' in result and result['act_info'] is not None:
            act_kl = result['act_info']['kl_loss']
            loss = loss + self._current_act_weight * act_kl

        # Gate entropy loss (Phases 2-3): encourages instruction specialization.
        gate_ent = torch.tensor(0.0, device=self.device)
        if need_trace and 'trace' in result and result['trace'] is not None:
            trace = result['trace']
            if trace['search_scores']:
                ent_sum = sum(_gate_entropy_loss(s) for s in trace['search_scores'])
                gate_ent = ent_sum / len(trace['search_scores'])
                loss = loss + self.gate_entropy_weight * gate_ent

        self._backward(loss)

        if self.grad_clip > 0:
            trainable = [p for p in self.model.parameters() if p.requires_grad]
            if trainable:
                torch.nn.utils.clip_grad_norm_(trainable, self.grad_clip)

        self._optimizer_step()

        logs = {'Loss': loss.item()}
        if act_kl.item() > 0:
            logs['ACT_KL'] = act_kl.item()
        if gate_ent.item() != 0:
            logs['GateEnt'] = gate_ent.item()
        return loss.item(), logs

    def evaluate(self, val_loader):
        """Compute cell- and grid-level accuracy over the validation set."""
        self.model.eval()
        cell_correct = 0
        cell_total = 0
        grid_correct = 0
        grid_total = 0

        with torch.no_grad():
            for batch in val_loader:
                result = self._run_model(batch)
                logits = result['logits']  # [B, N_grid, 10]
                preds = logits.argmax(dim=-1)  # [B, N_grid]

                test_outputs = batch['test_outputs'].to(self.device)
                test_masks = batch['test_masks'].to(self.device)
                B, H_max, W_max = test_outputs.shape
                targets = test_outputs.reshape(B, H_max * W_max)
                valid = test_masks.reshape(B, H_max * W_max)

                # Cell accuracy (non-padded cells only)
                matches = (preds == targets) & valid
                cell_correct += matches.sum().item()
                cell_total += valid.sum().item()

                # Grid accuracy: EVERY valid cell must match. Slicing the
                # first H*W flattened cells (as before) is only right when
                # a sample's width equals W_max — row-major padding of the
                # [H_max, W_max] grid scatters valid cells otherwise — so
                # compare through the validity mask instead.
                for i in range(B):
                    if ((preds[i] == targets[i]) | ~valid[i]).all():
                        grid_correct += 1
                    grid_total += 1

        cell_acc = cell_correct / max(cell_total, 1)
        grid_acc = grid_correct / max(grid_total, 1)
        logger.info("Cell accuracy: %.4f | Grid accuracy: %.4f (%d/%d)",
                    cell_acc, grid_acc, grid_correct, grid_total)
        return {'cell_accuracy': cell_acc, 'grid_accuracy': grid_acc}

    def visualize(self, val_loader):
        # No visualization for this task.
        pass

    def run(self):
        """Three-phase training loop with ACT ramp-up."""
        logger.info("Starting GTM ARC-AGI v4 Task")
        train_loader, val_loader = self.get_data()

        total_epochs = self.warmup_epochs + self.trim_epochs + self.act_epochs
        self.epochs = total_epochs

        self._phase = 0
        self._current_act_weight = 0.0
        best_val_metric = 0.0
        metric_key = 'cell_accuracy'

        pbar = tqdm(range(total_epochs))

        for epoch in pbar:
            if epoch < self.warmup_epochs:
                phase = 1
            elif epoch < self.warmup_epochs + self.trim_epochs:
                phase = 2
            else:
                phase = 3

            if phase != self._phase:
                self._phase = phase
                if phase == 1:
                    logger.info("Phase 1: Warmup (VM frozen, train head + init_cursor)")
                    self.model.freeze_vm()
                    self.model.disable_act()
                elif phase == 2:
                    logger.info("Phase 2: Circuit Search (fixed steps)")
                    self.model.unfreeze_vm()
                    self.model.disable_act()
                elif phase == 3:
                    act_cfg = self.cfg.model.get('act', {})
                    if act_cfg.get('enabled', True):
                        logger.info("Phase 3: ACT activation (adaptive computation)")
                        self.model.enable_act()
                    else:
                        logger.info("Phase 3: Extended training (ACT disabled)")
                self.optimizer = self._setup_optimizer()
                self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                    self.optimizer, mode='min', factor=0.5, patience=10)

            # ACT weight ramp
            if phase == 3:
                act_epoch = epoch - (self.warmup_epochs + self.trim_epochs)
                ramp = min(1.0, act_epoch / self.act_ramp_epochs) if self.act_ramp_epochs > 0 else 1.0
                self._current_act_weight = self.act_weight * ramp
            else:
                self._current_act_weight = 0.0

            # Training
            self.model.train()
            total_loss = 0
            n_batches = 0
            for batch in train_loader:
                loss, _ = self.train_step(batch)
                total_loss += loss
                n_batches += 1
avg_loss = total_loss / max(n_batches, 1) + self.scheduler.step(avg_loss) + + # Validation + do_eval = ( + epoch < self.warmup_epochs or + phase == 3 or + (epoch + 1) % self.eval_every == 0 or + epoch == total_epochs - 1 + ) + + if do_eval: + val_metrics = self.evaluate(val_loader) + val_metric = val_metrics.get(metric_key, 0.0) + if val_metric > best_val_metric: + best_val_metric = val_metric + self.save_checkpoint("gtm_arc_best.pt") + else: + val_metric = best_val_metric + + display = { + 'P': phase, 'Loss': avg_loss, + metric_key: val_metric, + 'LR': self.optimizer.param_groups[0]['lr'], + } + if self._current_act_weight > 0: + display['ACT_w'] = self._current_act_weight + desc = " | ".join( + f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}" + for k, v in display.items() + ) + pbar.set_description(desc) + + logger.info("Training complete. Best %s: %.4f", metric_key, best_val_metric) + + self.load_checkpoint("gtm_arc_best.pt") + final_metrics = self.evaluate(val_loader) + logger.info("Final metrics: %s", final_metrics) + return final_metrics diff --git a/uv.lock b/uv.lock index fa465a4..f128a96 100644 --- a/uv.lock +++ b/uv.lock @@ -5096,6 +5096,13 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0f/8b/4b61d6e13f7108f36910df9ab4b58fd389cc2520d54d81b88660804aad99/torch-2.10.0-2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:418997cb02d0a0f1497cf6a09f63166f9f5df9f3e16c8a716ab76a72127c714f", size = 79423467, upload-time = "2026-02-10T21:44:48.711Z" }, { url = "https://files.pythonhosted.org/packages/d3/54/a2ba279afcca44bbd320d4e73675b282fcee3d81400ea1b53934efca6462/torch-2.10.0-2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:13ec4add8c3faaed8d13e0574f5cd4a323c11655546f91fbe6afa77b57423574", size = 79498202, upload-time = "2026-02-10T21:44:52.603Z" }, { url = "https://files.pythonhosted.org/packages/ec/23/2c9fe0c9c27f7f6cb865abcea8a4568f29f00acaeadfc6a37f6801f84cb4/torch-2.10.0-2-cp313-none-macosx_11_0_arm64.whl", hash = 
"sha256:e521c9f030a3774ed770a9c011751fb47c4d12029a3d6522116e48431f2ff89e", size = 79498254, upload-time = "2026-02-10T21:44:44.095Z" }, + { url = "https://files.pythonhosted.org/packages/16/ee/efbd56687be60ef9af0c9c0ebe106964c07400eade5b0af8902a1d8cd58c/torch-2.10.0-3-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a1ff626b884f8c4e897c4c33782bdacdff842a165fee79817b1dd549fdda1321", size = 915510070, upload-time = "2026-03-11T14:16:39.386Z" }, + { url = "https://files.pythonhosted.org/packages/36/ab/7b562f1808d3f65414cd80a4f7d4bb00979d9355616c034c171249e1a303/torch-2.10.0-3-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:ac5bdcbb074384c66fa160c15b1ead77839e3fe7ed117d667249afce0acabfac", size = 915518691, upload-time = "2026-03-11T14:15:43.147Z" }, + { url = "https://files.pythonhosted.org/packages/b3/7a/abada41517ce0011775f0f4eacc79659bc9bc6c361e6bfe6f7052a6b9363/torch-2.10.0-3-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:98c01b8bb5e3240426dcde1446eed6f40c778091c8544767ef1168fc663a05a6", size = 915622781, upload-time = "2026-03-11T14:17:11.354Z" }, + { url = "https://files.pythonhosted.org/packages/ab/c6/4dfe238342ffdcec5aef1c96c457548762d33c40b45a1ab7033bb26d2ff2/torch-2.10.0-3-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:80b1b5bfe38eb0e9f5ff09f206dcac0a87aadd084230d4a36eea5ec5232c115b", size = 915627275, upload-time = "2026-03-11T14:16:11.325Z" }, + { url = "https://files.pythonhosted.org/packages/d8/f0/72bf18847f58f877a6a8acf60614b14935e2f156d942483af1ffc081aea0/torch-2.10.0-3-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:46b3574d93a2a8134b3f5475cfb98e2eb46771794c57015f6ad1fb795ec25e49", size = 915523474, upload-time = "2026-03-11T14:17:44.422Z" }, + { url = "https://files.pythonhosted.org/packages/f4/39/590742415c3030551944edc2ddc273ea1fdfe8ffb2780992e824f1ebee98/torch-2.10.0-3-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:b1d5e2aba4eb7f8e87fbe04f86442887f9167a35f092afe4c237dfcaaef6e328", size = 915632474, 
upload-time = "2026-03-11T14:15:13.666Z" }, + { url = "https://files.pythonhosted.org/packages/b6/8e/34949484f764dde5b222b7fe3fede43e4a6f0da9d7f8c370bb617d629ee2/torch-2.10.0-3-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:0228d20b06701c05a8f978357f657817a4a63984b0c90745def81c18aedfa591", size = 915523882, upload-time = "2026-03-11T14:14:46.311Z" }, { url = "https://files.pythonhosted.org/packages/0c/1a/c61f36cfd446170ec27b3a4984f072fd06dab6b5d7ce27e11adb35d6c838/torch-2.10.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:5276fa790a666ee8becaffff8acb711922252521b28fbce5db7db5cf9cb2026d", size = 145992962, upload-time = "2026-01-21T16:24:14.04Z" }, { url = "https://files.pythonhosted.org/packages/b5/60/6662535354191e2d1555296045b63e4279e5a9dbad49acf55a5d38655a39/torch-2.10.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:aaf663927bcd490ae971469a624c322202a2a1e68936eb952535ca4cd3b90444", size = 915599237, upload-time = "2026-01-21T16:23:25.497Z" }, { url = "https://files.pythonhosted.org/packages/40/b8/66bbe96f0d79be2b5c697b2e0b187ed792a15c6c4b8904613454651db848/torch-2.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:a4be6a2a190b32ff5c8002a0977a25ea60e64f7ba46b1be37093c141d9c49aeb", size = 113720931, upload-time = "2026-01-21T16:24:23.743Z" }, @@ -5416,6 +5423,9 @@ dev = [ { name = "pytest", version = "9.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, { name = "pytest-xdist" }, ] +gtm = [ + { name = "datasets" }, +] lqa = [ { name = "datasets" }, { name = "sentence-transformers", version = "5.1.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, @@ -5449,6 +5459,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "datasets", marker = "extra == 'gtm'", specifier = ">=2.16.0,<4.0" }, { name = "datasets", marker = "extra == 'lqa'", specifier = ">=2.16.0,<4.0" }, { name = "hydra-core", specifier = ">=1.3.2" }, { name = "matplotlib", 
marker = "extra == 'viz'", specifier = ">=3.8.0" }, @@ -5465,10 +5476,10 @@ requires-dist = [ { name = "torch", specifier = ">=2.0.0" }, { name = "torch-geometric", marker = "extra == 'md17'", specifier = ">=2.6.1" }, { name = "tqdm", specifier = ">=4.67.3" }, - { name = "versor", extras = ["sr", "md17", "lqa"], marker = "extra == 'all-tasks'" }, - { name = "versor", extras = ["sr", "md17", "lqa", "viz", "demo", "dev"], marker = "extra == 'all'" }, + { name = "versor", extras = ["sr", "md17", "lqa", "gtm"], marker = "extra == 'all-tasks'" }, + { name = "versor", extras = ["sr", "md17", "lqa", "gtm", "viz", "demo", "dev"], marker = "extra == 'all'" }, ] -provides-extras = ["sr", "md17", "lqa", "viz", "demo", "dev", "all-tasks", "all"] +provides-extras = ["sr", "md17", "lqa", "gtm", "viz", "demo", "dev", "all-tasks", "all"] [package.metadata.requires-dev] dev = [ From e9de721fc82faa559e7787738edf823f2fccf030 Mon Sep 17 00:00:00 2001 From: Concode0 Date: Thu, 19 Mar 2026 18:27:41 +0900 Subject: [PATCH 02/16] fix: the color can't changed by geometric algebra cause grade-0 invariant and energy imbalance issue it cause model make color parts gradient weak and indirect --- conf/task/gtm.yaml | 29 +--- datalib/benchmarks.py | 316 +++++++++++++++++++++++++++++++++++ experiments/navier_stokes.py | 162 ++++++++---------- experiments/yang_mills.py | 230 +++++++++++++------------ models/gtm/analysis.py | 39 +++-- models/gtm/control_plane.py | 25 +++ models/gtm/cpu.py | 143 ++++++---------- models/gtm/grid_codec.py | 113 +++++-------- models/gtm/gtm_net.py | 4 + models/gtm/rule_memory.py | 60 ++----- models/gtm/superposition.py | 106 +++++------- models/gtm/turing_step.py | 129 ++++---------- models/gtm/turing_vm.py | 22 +-- tasks/gtm.py | 27 ++- 14 files changed, 789 insertions(+), 616 deletions(-) create mode 100644 datalib/benchmarks.py diff --git a/conf/task/gtm.yaml b/conf/task/gtm.yaml index a7dd780..7a8d51e 100644 --- a/conf/task/gtm.yaml +++ b/conf/task/gtm.yaml @@ -1,23 
+1,6 @@ # @package _global_ name: gtm -# ── RTX Pro 4500 (32 GB VRAM, Ada Lovelace) tuning notes ────────────── -# -# VRAM budget breakdown (fp16 activations via AMP): -# Phase 1 (demo): B=24, K=3, grid=30×30 → N_demo = 5400 cells -# Attention [B,H,N,N]: 24×4×5400×5400×2B ≈ 5.6 GB -# CPU state × 12 steps: 24×5400×16×2B × 12 ≈ 50 MB -# Phase 2 (test): N_test = 900 cells → attention < 150 MB -# Model params + optimizer: < 1 GB -# Total: ~8–10 GB in fp16, safely within 32 GB -# -# Key CUDA flags: -# amp: true — bf16 forward/backward (Ada Lovelace tensor cores) -# compile: true — torch.compile for kernel fusion -# cudnn_benchmark: true -# pin_memory: true — async CPU→GPU transfer -# num_workers: 4 — parallel data loading - algebra: p: 3 q: 0 @@ -49,12 +32,12 @@ dataset: toy_n_examples: 20000 toy_max_grid_size: 15 num_demos: 3 - epoch_samples: 0 # 0 = full dataset shuffle; set >0 for capped-epoch sampling + epoch_samples: 4000 # 0 = full dataset shuffle; set >0 for capped-epoch sampling training: epochs: 150 lr: 0.0005 - batch_size: 24 + batch_size: 16 optimizer_type: riemannian_adam max_bivector_norm: 10.0 @@ -62,7 +45,7 @@ training: num_workers: 4 pin_memory: true amp: true - compile: true + compile: false cudnn_benchmark: true # Three-phase schedule (scaled for 150 epochs) @@ -71,6 +54,10 @@ training: act_epochs: 70 act_weight: 0.01 act_ramp_epochs: 20 - gate_entropy_weight: 0.001 + gate_entropy_weight: 0.01 grad_clip: 1.0 eval_every: 5 + tau_start: 1.0 + tau_act_restart: 0.7 + tau_end: 0.1 + diff --git a/datalib/benchmarks.py b/datalib/benchmarks.py new file mode 100644 index 0000000..9952f88 --- /dev/null +++ b/datalib/benchmarks.py @@ -0,0 +1,316 @@ +# Versor: Universal Geometric Algebra Neural Network +# Copyright (C) 2026 Eunkyum Kim +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# + +"""BIG-Bench Hard (BBH) data loading with curriculum learning support.""" + +import re +import torch +from torch.utils.data import Dataset, DataLoader, Sampler + + +# --------------------------------------------------------------------------- +# Task difficulty tiers for curriculum learning +# --------------------------------------------------------------------------- + +TASK_TIERS = { + 1: [ # Binary (2 choices) — basic pattern matching + 'boolean_expressions', + 'navigate', + 'sports_understanding', + 'web_of_lies', + 'causal_judgement', + 'formal_fallacies', + ], + 2: [ # Simple MC (2-4 choices) — moderate reasoning + 'disambiguation_qa', + 'hyperbaton', + 'snarks', + 'ruin_names', + 'logical_deduction_three_objects', + 'tracking_shuffled_objects_three_objects', + 'temporal_sequences', + ], + 3: [ # Complex MC (5+ choices) — multi-step reasoning + 'date_understanding', + 'movie_recommendation', + 'penguins_in_a_table', + 'salient_translation_error_detection', + 'logical_deduction_five_objects', + 'tracking_shuffled_objects_five_objects', + 'reasoning_about_colored_objects', + 'geometric_shapes', + ], +} + +ALL_CURRICULUM_TASKS = [t for tier in sorted(TASK_TIERS) for t in TASK_TIERS[tier]] + + +def get_tier_for_task(task_name: str) -> int: + for tier, tasks in TASK_TIERS.items(): + if task_name in tasks: + return tier + return 3 + + +# --------------------------------------------------------------------------- +# Answer parsing +# --------------------------------------------------------------------------- + +BINARY_ANSWERS = { + 'True': 1, 'False': 0, + 'true': 1, 'false': 0, + 'Yes': 1, 'No': 0, + 'yes': 1, 'no': 0, + 'Valid': 1, 'Invalid': 0, + 'valid': 1, 'invalid': 0, +} + +MC_PATTERN = re.compile(r'^\(([A-Z])\)$') + + +def _parse_answer(target: str, task_name: str) -> tuple: + """Parse a BBH target string into (label_index, num_choices). + + Returns: + (label_index, num_choices) tuple, or (None, None) if unparseable. 
+ """ + target = target.strip() + + if target in BINARY_ANSWERS: + return BINARY_ANSWERS[target], 2 + + mc_match = MC_PATTERN.match(target) + if mc_match: + letter = mc_match.group(1) + idx = ord(letter) - ord('A') + return idx, None # num_choices determined by scanning all examples + + # Unparseable (free-text answer) — skip gracefully + return None, None + + +# --------------------------------------------------------------------------- +# Single-task dataset +# --------------------------------------------------------------------------- + +class BBHDataset(Dataset): + """BIG-Bench Hard dataset for a single task. + + Loads from the lukaemon/bbh HuggingFace dataset, tokenizes with a + provided tokenizer, and maps answers to class indices. + Examples with unparseable answers are silently skipped. + """ + + def __init__(self, task_name: str, tokenizer, max_len: int = 512, + split: str = 'test', num_choices: int = None): + from datasets import load_dataset + ds = load_dataset("lukaemon/bbh", task_name, trust_remote_code=True) + + if split in ds: + raw = ds[split] + else: + raw = ds[list(ds.keys())[0]] + + # Parse answers, skip unparseable + parsed = [] + texts = [] + max_choice = 0 + for example in raw: + label, nc = _parse_answer(example['target'], task_name) + if label is None: + continue + if nc is not None: + max_choice = max(max_choice, nc) + else: + max_choice = max(max_choice, label + 1) + parsed.append(label) + texts.append(example['input']) + + self.num_choices = num_choices or max_choice + self.labels = parsed + + encodings = tokenizer( + texts, + max_length=max_len, + padding='max_length', + truncation=True, + return_tensors='pt', + ) + self.input_ids = encodings['input_ids'] + self.attention_mask = encodings['attention_mask'] + + def __len__(self): + return len(self.labels) + + def __getitem__(self, idx): + return { + 'input_ids': self.input_ids[idx], + 'attention_mask': self.attention_mask[idx], + 'labels': torch.tensor(self.labels[idx], dtype=torch.long), 
+ } + + +# --------------------------------------------------------------------------- +# Multi-task curriculum dataset +# --------------------------------------------------------------------------- + +class BBHCurriculumDataset(Dataset): + """Multi-task BBH dataset with per-example curriculum metadata. + + Each example carries its task_id, tier, and num_valid_choices so the + training loop can mask invalid logits and the curriculum sampler can + select examples by difficulty tier. + """ + + def __init__(self, task_names, tokenizer, max_len: int = 512): + all_input_ids = [] + all_attention_masks = [] + all_labels = [] + all_num_valid = [] + all_task_ids = [] + all_tiers = [] + + self.task_names = [] + self.task_num_choices = {} + max_choices = 0 + + for task_id, task_name in enumerate(task_names): + try: + ds = BBHDataset(task_name, tokenizer, max_len) + except Exception as e: + print(f" Warning: skipping task {task_name}: {e}") + continue + + if len(ds) == 0: + print(f" Warning: no parseable examples for {task_name}") + continue + + nc = ds.num_choices + self.task_names.append(task_name) + self.task_num_choices[task_name] = nc + max_choices = max(max_choices, nc) + tier = get_tier_for_task(task_name) + + for i in range(len(ds)): + all_input_ids.append(ds.input_ids[i]) + all_attention_masks.append(ds.attention_mask[i]) + all_labels.append(ds.labels[i]) + all_num_valid.append(nc) + all_task_ids.append(task_id) + all_tiers.append(tier) + + self.input_ids = torch.stack(all_input_ids) + self.attention_mask = torch.stack(all_attention_masks) + self.labels = all_labels + self.num_valid_choices = all_num_valid + self.task_ids = all_task_ids + self.tiers = all_tiers + self.max_choices = max_choices + + def __len__(self): + return len(self.labels) + + def __getitem__(self, idx): + return { + 'input_ids': self.input_ids[idx], + 'attention_mask': self.attention_mask[idx], + 'labels': torch.tensor(self.labels[idx], dtype=torch.long), + 'num_valid_choices': 
torch.tensor(self.num_valid_choices[idx], + dtype=torch.long), + 'task_id': torch.tensor(self.task_ids[idx], dtype=torch.long), + } + + +# --------------------------------------------------------------------------- +# Curriculum sampler +# --------------------------------------------------------------------------- + +class CurriculumSampler(Sampler): + """Samples only from examples whose tier is in the active set.""" + + def __init__(self, tiers: list, active_tier_ids: set): + active = set(active_tier_ids) + self.indices = [i for i, t in enumerate(tiers) if t in active] + + def __iter__(self): + perm = torch.randperm(len(self.indices)) + return iter([self.indices[i] for i in perm]) + + def __len__(self): + return len(self.indices) + + +# --------------------------------------------------------------------------- +# Loader factories +# --------------------------------------------------------------------------- + +def get_bbh_loaders( + task_name: str, + tokenizer, + batch_size: int = 16, + max_len: int = 512, + train_ratio: float = 0.8, + num_workers: int = 0, + num_choices: int = None, +) -> dict: + """Create train/val DataLoaders for a single BBH task.""" + dataset = BBHDataset(task_name, tokenizer, max_len, num_choices=num_choices) + + n = len(dataset) + n_train = int(n * train_ratio) + n_val = n - n_train + + generator = torch.Generator().manual_seed(42) + train_ds, val_ds = torch.utils.data.random_split( + dataset, [n_train, n_val], generator=generator, + ) + + return { + 'train': DataLoader(train_ds, batch_size=batch_size, shuffle=True, + num_workers=num_workers), + 'val': DataLoader(val_ds, batch_size=batch_size, shuffle=False, + num_workers=num_workers), + 'num_choices': dataset.num_choices, + } + + +def get_curriculum_loaders( + task_names: list, + tokenizer, + max_len: int = 512, + train_ratio: float = 0.8, +) -> dict: + """Load all tasks into a single curriculum dataset with train/val split. + + Returns a dict with dataset objects and tier metadata. 
The experiment + script builds DataLoaders on the fly with CurriculumSampler. + """ + dataset = BBHCurriculumDataset(task_names, tokenizer, max_len) + + n = len(dataset) + n_train = int(n * train_ratio) + n_val = n - n_train + + generator = torch.Generator().manual_seed(42) + train_ds, val_ds = torch.utils.data.random_split( + dataset, [n_train, n_val], generator=generator, + ) + + # Map tiers through the subset indices + train_tiers = [dataset.tiers[i] for i in train_ds.indices] + val_tiers = [dataset.tiers[i] for i in val_ds.indices] + + return { + 'full_dataset': dataset, + 'train_dataset': train_ds, + 'val_dataset': val_ds, + 'train_tiers': train_tiers, + 'val_tiers': val_tiers, + 'max_choices': dataset.max_choices, + 'task_names': dataset.task_names, + 'task_num_choices': dataset.task_num_choices, + } diff --git a/experiments/navier_stokes.py b/experiments/navier_stokes.py index 7f17264..ef6192d 100644 --- a/experiments/navier_stokes.py +++ b/experiments/navier_stokes.py @@ -279,8 +279,8 @@ class GaugeFluidNet(nn.Module): Returns full multivector encoding pressure, velocity, vorticity, helicity. """ - def __init__(self, algebra, hidden_dim: int = 64, num_layers: int = 6, - num_spatial_freqs: int = 8, num_temporal_freqs: int = 16): + def __init__(self, algebra, hidden_dim: int = 32, num_layers: int = 4, + num_spatial_freqs: int = 4, num_temporal_freqs: int = 8): super().__init__() self.algebra = algebra self.hidden_dim = hidden_dim @@ -386,101 +386,81 @@ def compute_ns_residual(model: GaugeFluidNet, coords_raw: torch.Tensor, algebra, nu: float) -> Dict[str, torch.Tensor]: """Compute Navier-Stokes PDE residuals via autograd. + Optimised: uses a single [B,4] leaf tensor for all (x,y,z,t) instead of + four separate leaf scalars, reducing first-derivative calls from 15 → 4 + (one per output variable) and total autograd calls from 25 → 14. + Args: model: The gauge fluid network. - coords_raw: [B, 5] — (x, y, z, t, log_re). 
Spatial/temporal coords - are re-created as leaf tensors with requires_grad=True. + coords_raw: [B, 5] — (x, y, z, t, log_re). algebra: CliffordAlgebra(3, 0). nu: Kinematic viscosity. Returns: - Dict of loss terms: ns_residual, div_residual, lagrangian, vorticity_consistency. + Dict: ns_residual, div_residual, lagrangian, vorticity_consistency, mv. """ - # Create leaf tensors for autograd differentiation - x = coords_raw[:, 0:1].detach().requires_grad_(True) - y = coords_raw[:, 1:2].detach().requires_grad_(True) - z = coords_raw[:, 2:3].detach().requires_grad_(True) - t = coords_raw[:, 3:4].detach().requires_grad_(True) + # Single leaf tensor for all differentiable coordinates [B, 4] + xyzt = coords_raw[:, :4].detach().requires_grad_(True) log_re = coords_raw[:, 4:5].detach() + coords = torch.cat([xyzt, log_re], dim=-1) # [B, 5] - # Reconstruct coords from leaf tensors - coords = torch.cat([x, y, z, t, log_re], dim=-1) - - # Forward pass mv, _ = model(coords) - # Extract fields from multivector - p = mv[:, 0] # pressure (grade-0) - u1 = mv[:, 1] # velocity e₁ - u2 = mv[:, 2] # velocity e₂ - u3 = mv[:, 4] # velocity e₃ - w3_pred = mv[:, 3] # vorticity e₁₂ (ω₃) - w2_pred = -mv[:, 5] # vorticity e₁₃: stored as -ω₂, negate to recover ω₂ - w1_pred = mv[:, 6] # vorticity e₂₃ (ω₁) - - # --- First derivatives --- - # ∂u_i/∂x_j and ∂u_i/∂t - du1_dx = torch.autograd.grad(u1.sum(), x, create_graph=True)[0].squeeze(-1) - du1_dy = torch.autograd.grad(u1.sum(), y, create_graph=True)[0].squeeze(-1) - du1_dz = torch.autograd.grad(u1.sum(), z, create_graph=True)[0].squeeze(-1) - du1_dt = torch.autograd.grad(u1.sum(), t, create_graph=True)[0].squeeze(-1) - - du2_dx = torch.autograd.grad(u2.sum(), x, create_graph=True)[0].squeeze(-1) - du2_dy = torch.autograd.grad(u2.sum(), y, create_graph=True)[0].squeeze(-1) - du2_dz = torch.autograd.grad(u2.sum(), z, create_graph=True)[0].squeeze(-1) - du2_dt = torch.autograd.grad(u2.sum(), t, create_graph=True)[0].squeeze(-1) - - du3_dx = 
torch.autograd.grad(u3.sum(), x, create_graph=True)[0].squeeze(-1) - du3_dy = torch.autograd.grad(u3.sum(), y, create_graph=True)[0].squeeze(-1) - du3_dz = torch.autograd.grad(u3.sum(), z, create_graph=True)[0].squeeze(-1) - du3_dt = torch.autograd.grad(u3.sum(), t, create_graph=True)[0].squeeze(-1) - - # Pressure gradients - dp_dx = torch.autograd.grad(p.sum(), x, create_graph=True)[0].squeeze(-1) - dp_dy = torch.autograd.grad(p.sum(), y, create_graph=True)[0].squeeze(-1) - dp_dz = torch.autograd.grad(p.sum(), z, create_graph=True)[0].squeeze(-1) - - # --- Second derivatives (Laplacian) --- - d2u1_dx2 = torch.autograd.grad(du1_dx.sum(), x, create_graph=True)[0].squeeze(-1) - d2u1_dy2 = torch.autograd.grad(du1_dy.sum(), y, create_graph=True)[0].squeeze(-1) - d2u1_dz2 = torch.autograd.grad(du1_dz.sum(), z, create_graph=True)[0].squeeze(-1) - - d2u2_dx2 = torch.autograd.grad(du2_dx.sum(), x, create_graph=True)[0].squeeze(-1) - d2u2_dy2 = torch.autograd.grad(du2_dy.sum(), y, create_graph=True)[0].squeeze(-1) - d2u2_dz2 = torch.autograd.grad(du2_dz.sum(), z, create_graph=True)[0].squeeze(-1) - - d2u3_dx2 = torch.autograd.grad(du3_dx.sum(), x, create_graph=True)[0].squeeze(-1) - d2u3_dy2 = torch.autograd.grad(du3_dy.sum(), y, create_graph=True)[0].squeeze(-1) - d2u3_dz2 = torch.autograd.grad(du3_dz.sum(), z, create_graph=True)[0].squeeze(-1) + p = mv[:, 0] + u1 = mv[:, 1] + u2 = mv[:, 2] + u3 = mv[:, 4] + w3_pred = mv[:, 3] + w2_pred = -mv[:, 5] + w1_pred = mv[:, 6] - # --- NS momentum residual --- - # R_i = ∂u_i/∂t + u_j·∂u_i/∂x_j + ∂p/∂x_i - ν∇²u_i - R1 = du1_dt + u1 * du1_dx + u2 * du1_dy + u3 * du1_dz + dp_dx - nu * (d2u1_dx2 + d2u1_dy2 + d2u1_dz2) - R2 = du2_dt + u1 * du2_dx + u2 * du2_dy + u3 * du2_dz + dp_dy - nu * (d2u2_dx2 + d2u2_dy2 + d2u2_dz2) - R3 = du3_dt + u1 * du3_dx + u2 * du3_dy + u3 * du3_dz + dp_dz - nu * (d2u3_dx2 + d2u3_dy2 + d2u3_dz2) + _grad = torch.autograd.grad - ns_residual = (R1 ** 2 + R2 ** 2 + R3 ** 2).mean() + # --- First derivatives — 4 
backward passes (was 15) --- + gu1 = _grad(u1.sum(), xyzt, create_graph=True, retain_graph=True)[0] # [B,4] + gu2 = _grad(u2.sum(), xyzt, create_graph=True, retain_graph=True)[0] + gu3 = _grad(u3.sum(), xyzt, create_graph=True, retain_graph=True)[0] + gp = _grad(p.sum(), xyzt, create_graph=True, retain_graph=True)[0] - # --- Incompressibility (divergence-free) --- - div_u = du1_dx + du2_dy + du3_dz - div_residual = (div_u ** 2).mean() + du1_dx, du1_dy, du1_dz, du1_dt = gu1[:,0], gu1[:,1], gu1[:,2], gu1[:,3] + du2_dx, du2_dy, du2_dz, du2_dt = gu2[:,0], gu2[:,1], gu2[:,2], gu2[:,3] + du3_dx, du3_dy, du3_dz, du3_dt = gu3[:,0], gu3[:,1], gu3[:,2], gu3[:,3] + dp_dx, dp_dy, dp_dz = gp[:,0], gp[:,1], gp[:,2] + + # --- Second derivatives (Laplacian) — 9 backward passes (unchanged) --- + d2u1_dx2 = _grad(gu1[:,0].sum(), xyzt, create_graph=True, retain_graph=True)[0][:,0] + d2u1_dy2 = _grad(gu1[:,1].sum(), xyzt, create_graph=True, retain_graph=True)[0][:,1] + d2u1_dz2 = _grad(gu1[:,2].sum(), xyzt, create_graph=True, retain_graph=True)[0][:,2] + + d2u2_dx2 = _grad(gu2[:,0].sum(), xyzt, create_graph=True, retain_graph=True)[0][:,0] + d2u2_dy2 = _grad(gu2[:,1].sum(), xyzt, create_graph=True, retain_graph=True)[0][:,1] + d2u2_dz2 = _grad(gu2[:,2].sum(), xyzt, create_graph=True, retain_graph=True)[0][:,2] - # --- Vorticity consistency: ω_pred ≈ ∇×u --- - # curl(u) components: - curl_x = du3_dy - du2_dz # ω₁ - curl_y = du1_dz - du3_dx # ω₂ - curl_z = du2_dx - du1_dy # ω₃ + d2u3_dx2 = _grad(gu3[:,0].sum(), xyzt, create_graph=True, retain_graph=True)[0][:,0] + d2u3_dy2 = _grad(gu3[:,1].sum(), xyzt, create_graph=True, retain_graph=True)[0][:,1] + d2u3_dz2 = _grad(gu3[:,2].sum(), xyzt, create_graph=True, retain_graph=True)[0][:,2] - vort_consistency = ((w1_pred - curl_x) ** 2 + - (w2_pred - curl_y) ** 2 + - (w3_pred - curl_z) ** 2).mean() + # --- NS momentum residual --- + R1 = du1_dt + u1*du1_dx + u2*du1_dy + u3*du1_dz + dp_dx - nu*(d2u1_dx2+d2u1_dy2+d2u1_dz2) + R2 = du2_dt + 
u1*du2_dx + u2*du2_dy + u3*du2_dz + dp_dy - nu*(d2u2_dx2+d2u2_dy2+d2u2_dz2) + R3 = du3_dt + u1*du3_dx + u2*du3_dy + u3*du3_dz + dp_dz - nu*(d2u3_dx2+d2u3_dy2+d2u3_dz2) + ns_residual = (R1**2 + R2**2 + R3**2).mean() + + # --- Incompressibility --- + div_u = du1_dx + du2_dy + du3_dz + div_residual = (div_u**2).mean() - # --- Lagrangian energy balance: dE/dt + 2ν·Ω = 0 --- - # E = ½(u1² + u2² + u3²), Ω = ½(ω1² + ω2² + ω3²) (enstrophy) - E = 0.5 * (u1 ** 2 + u2 ** 2 + u3 ** 2) - dE_dt = torch.autograd.grad(E.sum(), t, create_graph=True)[0].squeeze(-1) - enstrophy = 0.5 * (w1_pred ** 2 + w2_pred ** 2 + w3_pred ** 2) - lagrangian = ((dE_dt + 2.0 * nu * enstrophy) ** 2).mean() + # --- Vorticity consistency --- + curl_x = du3_dy - du2_dz + curl_y = du1_dz - du3_dx + curl_z = du2_dx - du1_dy + vort_consistency = ((w1_pred-curl_x)**2 + (w2_pred-curl_y)**2 + (w3_pred-curl_z)**2).mean() + + # --- Lagrangian energy balance — 1 backward pass --- + E = 0.5 * (u1**2 + u2**2 + u3**2) + dE_dt = _grad(E.sum(), xyzt, create_graph=True)[0][:,3] + enstrophy = 0.5 * (w1_pred**2 + w2_pred**2 + w3_pred**2) + lagrangian = ((dE_dt + 2.0*nu*enstrophy)**2).mean() return { 'ns_residual': ns_residual, @@ -1117,11 +1097,15 @@ def train(args): mv_ic, intermediates = model(coords_grad[:1]) ic_loss = torch.tensor(0.0, device=device) - # Gauge covariance loss - mv_all, intermediates = model(coords_grad) - gauge_loss = compute_gauge_covariance_loss(mv_all, algebra) + # Gauge covariance loss — reuse mv from residual pass (collocation points) + # Avoids a third full-batch forward pass; gauge invariance sampled on + # collocation subset which is representative. 
+ if colloc_mask.any(): + gauge_loss = compute_gauge_covariance_loss(residuals['mv'], algebra) + else: + gauge_loss = torch.tensor(0.0, device=device) - # Orthogonality loss + # Orthogonality loss (uses intermediates from IC pass above) ortho_loss = torch.tensor(0.0, device=device) if args.strict_ortho and intermediates: eff_weight = ortho.anneal_weight(epoch, @@ -1260,19 +1244,19 @@ def parse_args() -> argparse.Namespace: help='Maximum Reynolds number (curriculum target)') p.add_argument('--t-max', type=float, default=1.0, help='Maximum time') - p.add_argument('--num-collocation', type=int, default=3000, + p.add_argument('--num-collocation', type=int, default=2000, help='Number of collocation points') - p.add_argument('--num-ic', type=int, default=1000, + p.add_argument('--num-ic', type=int, default=500, help='Number of initial condition points') # Model - p.add_argument('--hidden-dim', type=int, default=64) - p.add_argument('--num-layers', type=int, default=6) + p.add_argument('--hidden-dim', type=int, default=32) + p.add_argument('--num-layers', type=int, default=4) # Training p.add_argument('--epochs', type=int, default=300) p.add_argument('--lr', type=float, default=0.001) - p.add_argument('--batch-size', type=int, default=256) + p.add_argument('--batch-size', type=int, default=128) p.add_argument('--seed', type=int, default=42) p.add_argument('--device', type=str, default='mps') diff --git a/experiments/yang_mills.py b/experiments/yang_mills.py index b0f82e9..0461a94 100644 --- a/experiments/yang_mills.py +++ b/experiments/yang_mills.py @@ -61,6 +61,41 @@ from functional.orthogonality import StrictOrthogonality, OrthogonalitySettings +# ============================================================================ # +# Vectorised Jacobian helper +# ============================================================================ # + +def _batch_jacobian( + output: torch.Tensor, + inputs: torch.Tensor, + create_graph: bool = True, + retain_graph: bool = True, +) -> 
torch.Tensor: + """Per-sample Jacobian via K VJPs. + + Args: + output: [B, K] — each row output[b,:] depends only on inputs[b,:]. + inputs: [B, D] leaf tensor with requires_grad=True. + create_graph: retain meta-graph for higher-order differentiation. + retain_graph: keep the original graph between successive calls. + + Returns: + jac: [B, K, D] where jac[b,k,d] = ∂output[b,k] / ∂inputs[b,d]. + """ + B, K = output.shape + D = inputs.shape[1] + jac = output.new_zeros(B, K, D) + for k in range(K): + (g,) = torch.autograd.grad( + output[:, k].sum(), + inputs, + create_graph=create_graph, + retain_graph=retain_graph or (k < K - 1), + ) + jac[:, k, :] = g + return jac + + # ============================================================================ # # 't Hooft Symbols and BPST Instanton # ============================================================================ # @@ -254,8 +289,8 @@ class YangMillsNet(nn.Module): Returns A_mu [B, 4, 8] and intermediates. """ - def __init__(self, algebra, hidden_dim: int = 64, num_layers: int = 6, - num_freqs: int = 32): + def __init__(self, algebra, hidden_dim: int = 32, num_layers: int = 4, + num_freqs: int = 16): super().__init__() self.algebra = algebra self.hidden_dim = hidden_dim @@ -335,9 +370,15 @@ def forward(self, coords: torch.Tensor): # Field Strength Computation # ============================================================================ # -def compute_field_strength(algebra, A_mu: torch.Tensor, - coords: torch.Tensor) -> Dict[Tuple[int, int], torch.Tensor]: - """Compute Yang-Mills field strength F_μν = ∂_μA_ν - ∂_νA_μ + [A_μ, A_ν]. +def compute_field_strength( + algebra, + A_mu: torch.Tensor, + coords: torch.Tensor, +) -> Tuple[Dict[Tuple[int, int], torch.Tensor], torch.Tensor]: + """Compute F_μν = ∂_μA_ν - ∂_νA_μ + [A_μ, A_ν] and the A-Jacobian. + + Uses a single batch of 32 VJPs (one per A_mu output component) instead of + 96 separate per-component grad calls, giving ~3× fewer autograd operations. 
Args: algebra: CliffordAlgebra(3, 0). @@ -345,43 +386,27 @@ def compute_field_strength(algebra, A_mu: torch.Tensor, coords: [B, 4] with requires_grad=True. Returns: - Dict mapping (mu, nu) pairs (mu < nu) to [B, 8] field strength bivectors. + F_dict: (mu, nu) → [B, 8] field-strength tensors. + jac_A: [B, 4, 8, 4] — jac_A[:, chan, comp, x_dim]. """ B = A_mu.shape[0] - F_dict = {} + # 32 VJPs to get the full Jacobian ∂A/∂x (was 96 per-component calls) + A_flat = A_mu.reshape(B, 32) # [B, 32] + jac_flat = _batch_jacobian(A_flat, coords, + create_graph=True, retain_graph=True) # [B, 32, 4] + jac_A = jac_flat.reshape(B, 4, 8, 4) # [B, chan, comp, x_dim] + + F_dict: Dict[Tuple[int, int], torch.Tensor] = {} for mu in range(4): for nu in range(mu + 1, 4): - # Abelian part: ∂_μA_ν - ∂_νA_μ - # Compute ∂A_ν/∂x_μ for each of the 8 components - dAnu_dxmu = torch.zeros(B, 8, device=A_mu.device, dtype=A_mu.dtype) - dAmu_dxnu = torch.zeros(B, 8, device=A_mu.device, dtype=A_mu.dtype) - - for comp in range(8): - # ∂A_ν[comp]/∂x_μ - grad_nu = torch.autograd.grad( - A_mu[:, nu, comp].sum(), coords, - create_graph=True, retain_graph=True - )[0] # [B, 4] - dAnu_dxmu[:, comp] = grad_nu[:, mu] - - # ∂A_μ[comp]/∂x_ν - grad_mu = torch.autograd.grad( - A_mu[:, mu, comp].sum(), coords, - create_graph=True, retain_graph=True - )[0] # [B, 4] - dAmu_dxnu[:, comp] = grad_mu[:, nu] - - abelian = dAnu_dxmu - dAmu_dxnu # [B, 8] - - # Non-abelian part: [A_μ, A_ν] = A_μ A_ν - A_ν A_μ + dAnu_dxmu = jac_A[:, nu, :, mu] # [B, 8] + dAmu_dxnu = jac_A[:, mu, :, nu] # [B, 8] AB = algebra.geometric_product(A_mu[:, mu], A_mu[:, nu]) BA = algebra.geometric_product(A_mu[:, nu], A_mu[:, mu]) - commutator = AB - BA # [B, 8] + F_dict[(mu, nu)] = (dAnu_dxmu - dAmu_dxnu) + (AB - BA) - F_dict[(mu, nu)] = abelian + commutator - - return F_dict + return F_dict, jac_A def hodge_dual_4d(F_dict: Dict[Tuple[int, int], torch.Tensor]) -> Dict[Tuple[int, int], torch.Tensor]: @@ -425,6 +450,11 @@ def compute_ym_losses(algebra, 
model: YangMillsNet, coords: torch.Tensor, rho: float) -> Dict[str, torch.Tensor]: """Compute all Yang-Mills loss terms. + Optimised autograd schedule: + • Field strength : 32 VJPs (was 96 per-component calls) + • F Jacobians : 48 VJPs (was 224 per-component calls for YM eq + Bianchi) + Total : 80 VJPs (was 320+) + Args: algebra: CliffordAlgebra(3, 0). model: YangMillsNet. @@ -441,8 +471,8 @@ def compute_ym_losses(algebra, model: YangMillsNet, coords: torch.Tensor, # --- Supervised loss on A --- supervised_loss = nn.functional.mse_loss(A_pred, A_exact) - # --- Field strength --- - F_dict = compute_field_strength(algebra, A_pred, coords) + # --- Field strength (32 VJPs via _batch_jacobian) --- + F_dict, _ = compute_field_strength(algebra, A_pred, coords) # --- Self-duality: F = *F --- F_dual = hodge_dual_4d(F_dict) @@ -469,83 +499,77 @@ def compute_ym_losses(algebra, model: YangMillsNet, coords: torch.Tensor, purity_loss = purity_loss + (residual ** 2).mean() purity_loss = purity_loss / 4.0 - # --- Gauge covariance: _H = _H --- + # --- Gauge covariance --- gauge_loss = _gauge_covariance_loss(algebra, F_dict) - # --- YM equation residual: D_μ F^μν = 0 --- - # Simplified: check that ∂_μ F^μν + [A_μ, F^μν] ≈ 0 for each ν + # --- Precompute F Jacobians once: 48 VJPs for all 6 pairs --- + # jac_F[(mu,nu)][b, comp, d] = ∂F_{μν}[b,comp] / ∂coords[b,d] + # retain_graph=True on all calls: the same forward graph (A_pred) is also + # needed for supervised_loss and commutator terms in loss.backward(). + # The graph is freed by loss.backward() at the end of the training step. 
+ pairs = list(F_dict.keys()) # 6 pairs + jac_F: Dict[Tuple[int, int], torch.Tensor] = {} + for key in pairs: + jac_F[key] = _batch_jacobian( + F_dict[key], coords, + create_graph=True, + retain_graph=True, + ) # [B, 8, 4] + + _zeros8 = torch.zeros(coords.shape[0], 8, device=coords.device) + + # --- YM equation: D_μ F^μν = ∂_μ F^μν + [A_μ, F^μν] = 0 --- ym_loss = torch.tensor(0.0, device=coords.device) for nu_target in range(4): - residual_nu = torch.zeros(coords.shape[0], 8, device=coords.device) + residual_nu = torch.zeros_like(_zeros8) for mu in range(4): if mu == nu_target: continue - # Get F^μν (with sign convention) if mu < nu_target: - F_mn = F_dict.get((mu, nu_target), torch.zeros(coords.shape[0], 8, device=coords.device)) + key = (mu, nu_target) + sign = 1.0 else: - F_mn = -F_dict.get((nu_target, mu), torch.zeros(coords.shape[0], 8, device=coords.device)) - - # ∂F^μν/∂x_μ - for comp in range(8): - grad = torch.autograd.grad( - F_mn[:, comp].sum(), coords, - create_graph=True, retain_graph=True - )[0] # [B, 4] - residual_nu[:, comp] = residual_nu[:, comp] + grad[:, mu] - - # [A_μ, F^μν] - comm = algebra.geometric_product(A_pred[:, mu], F_mn) - \ - algebra.geometric_product(F_mn, A_pred[:, mu]) + key = (nu_target, mu) + sign = -1.0 + F_mn = F_dict.get(key, _zeros8) + jac = jac_F.get(key) + # ∂_μ F^μν: column mu of the Jacobian + dF_dxmu = sign * jac[:, :, mu] if jac is not None else _zeros8 + residual_nu = residual_nu + dF_dxmu + comm = (algebra.geometric_product(A_pred[:, mu], sign * F_mn) - + algebra.geometric_product(sign * F_mn, A_pred[:, mu])) residual_nu = residual_nu + comm - ym_loss = ym_loss + (residual_nu ** 2).mean() ym_loss = ym_loss / 4.0 - # --- Bianchi identity: D_{[μ} F_{νρ]} = 0 --- - # Check for one cyclic triple: (0,1,2) + # --- Bianchi identity: D_{[μ}F_{νρ]} = 0 --- bianchi_loss = torch.tensor(0.0, device=coords.device) triples = [(0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3)] for mu, nu, rho_idx in triples: - F_nu_rho = F_dict.get((nu, 
rho_idx), torch.zeros(coords.shape[0], 8, device=coords.device)) - F_rho_mu = F_dict.get((min(rho_idx, mu), max(rho_idx, mu)), - torch.zeros(coords.shape[0], 8, device=coords.device)) - if rho_idx > mu: - pass # already correct sign - else: - F_rho_mu = -F_rho_mu - F_mu_nu = F_dict.get((mu, nu), torch.zeros(coords.shape[0], 8, device=coords.device)) - - # D_μ F_{νρ} = ∂_μ F_{νρ} + [A_μ, F_{νρ}] - # Compute ∂_μ F_{νρ} via autograd - dF_nu_rho = torch.zeros_like(F_nu_rho) - dF_rho_mu = torch.zeros_like(F_rho_mu) - dF_mu_nu = torch.zeros_like(F_mu_nu) - for comp in range(F_nu_rho.shape[1]): - g1 = torch.autograd.grad( - F_nu_rho[:, comp].sum(), coords, - create_graph=True, retain_graph=True - )[0] - dF_nu_rho[:, comp] = g1[:, mu] - - g2 = torch.autograd.grad( - F_rho_mu[:, comp].sum(), coords, - create_graph=True, retain_graph=True - )[0] - dF_rho_mu[:, comp] = g2[:, nu] - - g3 = torch.autograd.grad( - F_mu_nu[:, comp].sum(), coords, - create_graph=True, retain_graph=True - )[0] - dF_mu_nu[:, comp] = g3[:, rho_idx] - - comm1 = algebra.geometric_product(A_pred[:, mu], F_nu_rho) - \ - algebra.geometric_product(F_nu_rho, A_pred[:, mu]) - comm2 = algebra.geometric_product(A_pred[:, nu], F_rho_mu) - \ - algebra.geometric_product(F_rho_mu, A_pred[:, nu]) - comm3 = algebra.geometric_product(A_pred[:, rho_idx], F_mu_nu) - \ - algebra.geometric_product(F_mu_nu, A_pred[:, rho_idx]) + F_nu_rho = F_dict.get((nu, rho_idx), _zeros8) + j_nu_rho = jac_F.get((nu, rho_idx)) + + key_rmu = (min(rho_idx, mu), max(rho_idx, mu)) + sign_rmu = 1.0 if rho_idx < mu else -1.0 + F_rho_mu_raw = F_dict.get(key_rmu, _zeros8) + j_rho_mu_raw = jac_F.get(key_rmu) + F_rho_mu = sign_rmu * F_rho_mu_raw + + F_mu_nu = F_dict.get((mu, nu), _zeros8) + j_mu_nu = jac_F.get((mu, nu)) + + # ∂_μ F_{νρ}, ∂_ν F_{ρμ}, ∂_ρ F_{μν} — reuse precomputed Jacobians + dF_nu_rho = j_nu_rho[:, :, mu] if j_nu_rho is not None else _zeros8 + dF_rho_mu = (sign_rmu * j_rho_mu_raw[:, :, nu] + if j_rho_mu_raw is not None else 
_zeros8) + dF_mu_nu = j_mu_nu[:, :, rho_idx] if j_mu_nu is not None else _zeros8 + + comm1 = (algebra.geometric_product(A_pred[:, mu], F_nu_rho) - + algebra.geometric_product(F_nu_rho, A_pred[:, mu])) + comm2 = (algebra.geometric_product(A_pred[:, nu], F_rho_mu) - + algebra.geometric_product(F_rho_mu, A_pred[:, nu])) + comm3 = (algebra.geometric_product(A_pred[:, rho_idx], F_mu_nu) - + algebra.geometric_product(F_mu_nu, A_pred[:, rho_idx])) bianchi_residual = (dF_nu_rho + comm1) + (dF_rho_mu + comm2) + (dF_mu_nu + comm3) bianchi_loss = bianchi_loss + (bianchi_residual ** 2).mean() @@ -1272,18 +1296,18 @@ def parse_args() -> argparse.Namespace: help='Instanton size parameter') p.add_argument('--sampling-radius', type=float, default=5.0, help='Maximum sampling radius around instanton core') - p.add_argument('--num-train', type=int, default=3000) - p.add_argument('--num-test', type=int, default=500) + p.add_argument('--num-train', type=int, default=2000) + p.add_argument('--num-test', type=int, default=300) # Model - p.add_argument('--hidden-dim', type=int, default=64) - p.add_argument('--num-layers', type=int, default=6) - p.add_argument('--num-freqs', type=int, default=32) + p.add_argument('--hidden-dim', type=int, default=32) + p.add_argument('--num-layers', type=int, default=4) + p.add_argument('--num-freqs', type=int, default=16) # Training p.add_argument('--epochs', type=int, default=300) p.add_argument('--lr', type=float, default=0.001) - p.add_argument('--batch-size', type=int, default=128) + p.add_argument('--batch-size', type=int, default=64) p.add_argument('--seed', type=int, default=42) p.add_argument('--device', type=str, default='cpu') diff --git a/models/gtm/analysis.py b/models/gtm/analysis.py index 02328b6..3f17a4f 100644 --- a/models/gtm/analysis.py +++ b/models/gtm/analysis.py @@ -90,7 +90,7 @@ def from_checkpoint(path: str, device: str = 'cpu') -> 'GTMAnalyzer': attn_head_dim=attn_cfg.get('head_dim', 8), num_rule_slots=mcfg.get('num_rule_slots', 
8), ) - model.load_state_dict(checkpoint['model_state_dict']) + model.load_state_dict(checkpoint['model_state_dict'], strict=False) return GTMAnalyzer(model, device) # ------------------------------------------------------------------ @@ -184,7 +184,7 @@ def analyze_temperature(self) -> dict: """ temps = [] for step in self.model.vm.steps: - tau = step.search.log_temperature.exp().clamp(0.1, 5.0) + tau = step.search._temperature.clamp(0.1, 5.0) temps.append(tau.item()) temps_t = torch.tensor(temps) @@ -218,8 +218,6 @@ def analyze(self, batch: dict) -> dict: 'grid_correct': [B] bool per example 'test_masks': [B, N_test] validity mask """ - num_steps = self.model.vm.num_steps - # Run full forward with trace with torch.no_grad(): result = self._run_forward(batch) @@ -228,9 +226,13 @@ def analyze(self, batch: dict) -> dict: preds = logits.argmax(dim=-1) trace = result['trace'] - # Split trace into Phase 1 and Phase 2 - phase1_trace = {k: v[:num_steps] for k, v in trace.items()} - phase2_trace = {k: v[num_steps:] for k, v in trace.items()} + # Split trace into Phase 1 (demo) and Phase 2 (test). + # When ACT is enabled, each VM call produces max_steps entries; + # when disabled, num_steps entries. Both phases use the same mode. + vm = self.model.vm + steps_per_phase = vm.max_steps if vm.use_act else vm.num_steps + phase1_trace = {k: v[:steps_per_phase] for k, v in trace.items()} + phase2_trace = {k: v[steps_per_phase:] for k, v in trace.items()} # Targets test_outputs = batch['test_outputs'].to(self.device) @@ -354,7 +356,11 @@ def format_cursor_report(self, report: dict) -> str: return '\n'.join(lines) def format_search_report(self, report: dict) -> str: - """Human-readable hypothesis selection summary.""" + """Human-readable hypothesis selection summary. + + Handles both per-cell weights [B, N, K] (v4.1+) and legacy + global weights [B, K] via dimension check. 
+ """ lines = ['=== Hypothesis Selection ===', ''] for phase_name, phase_key in [('Phase 1', 'phase1'), ('Phase 2', 'phase2')]: @@ -364,9 +370,20 @@ def format_search_report(self, report: dict) -> str: lines.append(f'{phase_name}:') for t, w in enumerate(weights_list): w0 = w[0] # first batch element - dominant = w0.argmax().item() - w_str = ' '.join(f'H{k}={w0[k]:.3f}' for k in range(w0.shape[0])) - lines.append(f' Step {t}: [{w_str}] dominant=H{dominant}') + if w0.dim() == 2: + # Per-cell weights: [N, K] + K = w0.shape[-1] + mean_w = w0.mean(dim=0) # [K] + dominant_per_cell = w0.argmax(dim=-1) # [N] + hist = torch.bincount(dominant_per_cell, minlength=K) + mean_str = ' '.join(f'H{k}={mean_w[k]:.3f}' for k in range(K)) + hist_str = ' '.join(f'H{k}:{hist[k].item()}' for k in range(K)) + lines.append(f' Step {t}: mean=[{mean_str}] cells=[{hist_str}]') + else: + # Legacy global weights: [K] + dominant = w0.argmax().item() + w_str = ' '.join(f'H{k}={w0[k]:.3f}' for k in range(w0.shape[0])) + lines.append(f' Step {t}: [{w_str}] dominant=H{dominant}') lines.append('') return '\n'.join(lines) diff --git a/models/gtm/control_plane.py b/models/gtm/control_plane.py index b6c8701..d827149 100644 --- a/models/gtm/control_plane.py +++ b/models/gtm/control_plane.py @@ -58,6 +58,21 @@ def __init__(self, algebra_ctrl: CliffordAlgebra, channels: int, nn.Linear(32, 1), ) + # Residual correction for boost-invariant components (scalar, e34) + # Sandwich product with e34 bivector only boosts grade-1 (e3, e4); + # scalar and pseudoscalar are algebraically invariant. This MLP + # provides a learned additive update so those components can evolve. + # Outputs 2 values: (delta_scalar, delta_e34), NOT all 4 components, + # to avoid double-counting with the boost on e3/e4. 
+ self.cursor_residual = nn.Sequential( + nn.Linear(channels + 4, 32), + nn.Tanh(), + nn.Linear(32, 2), + ) + # Initialize near-zero so early training is dominated by the boost + nn.init.zeros_(self.cursor_residual[-1].weight) + nn.init.zeros_(self.cursor_residual[-1].bias) + # Halt signal from cursor self.halt_mlp = nn.Sequential( nn.Linear(4, 16), @@ -116,6 +131,16 @@ def step(self, cursor: torch.Tensor, gate = torch.sigmoid(direction_logit) # [B, 1] new_cursor = gate * cursor_h + (1.0 - gate) * cursor_v # [B, 4] + # Residual correction: only update boost-invariant components + delta = self.cursor_residual(combined) # [B, 2] -> (delta_scalar, delta_e34) + new_cursor = new_cursor.clone() + new_cursor[:, 0] = new_cursor[:, 0] + delta[:, 0] # scalar + new_cursor[:, 3] = new_cursor[:, 3] + delta[:, 1] # e34 + + # Symmlog normalization: prevents unbounded drift across steps + # while preserving gradient (grad = 1/(1+|x|), never zero) + new_cursor = torch.sign(new_cursor) * torch.log1p(new_cursor.abs()) + # Halt probability from grade-0 of cursor halt_prob = torch.sigmoid(self.halt_mlp(new_cursor)).squeeze(-1) # [B] diff --git a/models/gtm/cpu.py b/models/gtm/cpu.py index be489a7..2182e02 100644 --- a/models/gtm/cpu.py +++ b/models/gtm/cpu.py @@ -5,17 +5,7 @@ # you may not use this file except in compliance with the License. # -"""PGA Motor CPU + ColorUnit — Cl(3,0,1) computation engine. - -Core operations (three-part transform): - Part A — Motor transform: M = exp(-grade_2(instr)/2), X' = MXM~ - The 6 bivectors split into 3 rotation (e01,e02,e12) and - 3 translation (e03,e13,e23) components. The parabolic exp branch - in core/algebra.py handles null bivectors: exp(t*e03) = 1 + t*e03. - Part B — ColorUnit: discrete color remapping conditioned on instruction - K_color learnable tables [K_color, 10, 10], selected by grade-0 + grade-4. - Part C — Merge: spatial from motor, color from ColorUnit. 
-""" +"""PGA Cl(3,0,1) computation engine: motor sandwich + color remapping.""" import torch import torch.nn as nn @@ -24,78 +14,81 @@ class ColorUnit(nn.Module): - """Discrete color remapping conditioned on instruction. + """Position-conditioned discrete color remapping via K blended [10, 10] tables. - K_color learnable remapping tables [K_color, 10, 10]. - Instruction's grade-0 and grade-4 select and blend tables. + Table selection conditioned on per-cell spatial features (grade-1 post-motor + position) plus instruction grade-0/grade-4. This breaks the fundamental + GA bottleneck where the motor sandwich product leaves grade-0 (color) + invariant — by conditioning the remap on post-motor position, each cell + can receive a different color transformation. + + Spatial indices used: + idx 1 (e0): row, idx 2 (e1): col, idx 8 (e3): homogeneous coord """ + # Grade-1 spatial component indices in Cl(3,0,1) + _SPATIAL_IDX = [1, 2, 8] # e0(row), e1(col), e3(homo) + def __init__(self, K_color: int = 4): super().__init__() self.K_color = K_color - # Initialize as near-identity: eye(10) + small noise per table self.remap_tables = nn.Parameter( torch.eye(10).unsqueeze(0).expand(K_color, -1, -1).clone() + torch.randn(K_color, 10, 10) * 0.01 ) - # Selector: grade-0 (idx 0) + grade-4 (idx 15) → table weights - self.selector = nn.Linear(2, K_color) + # 2 (instruction g0 + g4) + 3 (cell spatial) = 5 inputs + self.selector = nn.Linear(5, K_color) def forward(self, state: torch.Tensor, instruction: torch.Tensor) -> torch.Tensor: - """Apply color remapping to grade-0 and update occupancy. + """Apply position-conditioned color remapping. Args: - state: [L, N, 16] PGA multivectors after motor transform. - instruction: [L, 16] instruction multivectors. - - Returns: - [L, N, 16] state with grade-0 (color) and grade-4 (occupancy) updated. + state: [L, N, 16] after motor transform. + instruction: [L, 16]. 
""" L, N, D = state.shape - # Extract selector features from instruction - sel_input = torch.stack([instruction[:, 0], instruction[:, 15]], dim=-1) # [L, 2] - table_weights = F.softmax(self.selector(sel_input), dim=-1) # [L, K_color] + # Per-cell spatial features from post-motor grade-1 components + cell_spatial = state[:, :, self._SPATIAL_IDX] # [L, N, 3] - # Blend remapping tables: [L, 10, 10] - # table_weights: [L, K] @ remap_tables: [K, 10, 10] -> [L, 10, 10] - blended = torch.einsum('lk,kij->lij', table_weights, self.remap_tables) + # Instruction features broadcast to every cell + instr_feat = torch.stack( + [instruction[:, 0], instruction[:, 15]], dim=-1, + ).unsqueeze(1).expand(L, N, 2) # [L, N, 2] - # Extract current color: grade-0 → soft 10-class - raw_color = state[:, :, 0] * 9.0 # [L, N] in [0, 9] range - # Create soft one-hot via distance to each integer class - centers = torch.arange(10, device=state.device, dtype=state.dtype) # [10] - # Soft assignment: exp(-4 * (color - center)^2) - diffs = raw_color.unsqueeze(-1) - centers # [L, N, 10] - soft_color = F.softmax(-4.0 * diffs.pow(2), dim=-1) # [L, N, 10] + sel_input = torch.cat([instr_feat, cell_spatial], dim=-1) # [L, N, 5] + table_weights = F.softmax(self.selector(sel_input), dim=-1) # [L, N, K] + + # Per-cell blended remap table + blended = torch.einsum( + 'lnk,kij->lnij', table_weights, self.remap_tables, + ) # [L, N, 10, 10] - # Apply remapping: [L, N, 10] @ [L, 10, 10] -> [L, N, 10] - remapped = torch.bmm( - soft_color.reshape(L, N, 10), - blended - ) # [L, N, 10] + raw_color = state[:, :, 0] * 9.0 + centers = torch.arange(10, device=state.device, dtype=state.dtype) + diffs = raw_color.unsqueeze(-1) - centers + soft_color = F.softmax(-4.0 * diffs.pow(2), dim=-1) # [L, N, 10] - # Convert back to scalar: expected value / 9.0 - new_color = torch.einsum('lni,i->ln', remapped, centers) / 9.0 # [L, N] + # Per-cell remap: [L, N, 1, 10] @ [L, N, 10, 10] -> [L, N, 1, 10] + remapped = torch.matmul( + 
soft_color.unsqueeze(2), blended, + ).squeeze(2) # [L, N, 10] - # Update occupancy flag (grade-4 pseudoscalar idx 15) - new_occupancy = 1.0 - remapped[:, :, 0] # prob of NOT being color 0 + new_color = torch.einsum('lni,i->ln', remapped, centers) / 9.0 + new_occupancy = 1.0 - remapped[:, :, 0] - # Construct output: only modify grade-0 and grade-4 out = state.clone() out[:, :, 0] = new_color out[:, :, 15] = new_occupancy - return out class GeometricCPU(nn.Module): - """PGA Cl(3,0,1) computation engine with Motor + ColorUnit. + """PGA Cl(3,0,1) computation engine. - The motor transform handles both rotation (e01,e02,e12 bivectors) - and translation (e03,e13,e23 null bivectors) via a single sandwich product. - The ColorUnit handles discrete color remapping. + Bivectors e01/e02/e12 produce rotations; null bivectors e03/e13/e23 + produce translations. Both composed into a single motor via exp map. """ def __init__(self, algebra_cpu: CliffordAlgebra, K_color: int = 4): @@ -106,21 +99,12 @@ def __init__(self, algebra_cpu: CliffordAlgebra, K_color: int = 4): self.color_unit = ColorUnit(K_color) def _transform(self, state: torch.Tensor, instruction: torch.Tensor) -> torch.Tensor: - """Core transform: PGA motor sandwich + color remapping. - - Args: - state: [L, N, 16] — L can be B (single) or B*K (batched). - instruction: [L, 16]. - - Returns: - [L, N, 16] transformed state. - """ + """Motor sandwich + color remapping. 
[L, N, 16] -> [L, N, 16].""" L, N, D = state.shape - # Part A: Motor Transform (rotation + translation via PGA sandwich) - bv = self.algebra.grade_projection(instruction, 2) # [L, 16] - M = self.algebra.exp(-0.5 * bv) # [L, 16] — motor (rotation + translation) - M_rev = self.algebra.reverse(M) # [L, 16] + bv = self.algebra.grade_projection(instruction, 2) + M = self.algebra.exp(-0.5 * bv) + M_rev = self.algebra.reverse(M) M_exp = M.unsqueeze(1).expand(L, N, D).reshape(L * N, D) M_rev_exp = M_rev.unsqueeze(1).expand(L, N, D).reshape(L * N, D) @@ -130,45 +114,22 @@ def _transform(self, state: torch.Tensor, instruction: torch.Tensor) -> torch.Te M_exp, state_flat, M_rev_exp ).reshape(L, N, D) - # Part B: Color Remapping (grade-0 and grade-4 only) - color_out = self.color_unit(spatial_out, instruction) - - return color_out + return self.color_unit(spatial_out, instruction) def execute(self, state: torch.Tensor, instruction: torch.Tensor) -> torch.Tensor: - """Apply PGA Motor + ColorUnit to state. - - Args: - state: CPU state [B, N, 16] — per-cell multivectors. - instruction: Instruction multivector [B, 16]. - - Returns: - New state [B, N, 16]. - """ + """Apply transform to [B, N, 16] state with [B, 16] instruction.""" self.algebra.ensure_device(state.device) return self._transform(state, instruction) def execute_all(self, state: torch.Tensor, instructions: torch.Tensor) -> torch.Tensor: - """Execute K instructions in a single batched call. - - Reshapes [B, N, 16] x [B, K, 16] into [B*K, N, 16] x [B*K, 16], - runs one _transform call, then reshapes back to [B, K, N, 16]. - - Args: - state: CPU state [B, N, 16]. - instructions: K instruction multivectors [B, K, 16]. - - Returns: - Tensor [B, K, N, 16] — all K outcomes stacked. - """ + """Execute K instructions batched. 
[B,N,16] x [B,K,16] -> [B,K,N,16].""" B, N, D = state.shape K = instructions.shape[1] self.algebra.ensure_device(state.device) - # Expand state for all K instructions: [B, K, N, D] -> [B*K, N, D] state_exp = state.unsqueeze(1).expand(B, K, N, D).reshape(B * K, N, D) instr_flat = instructions.reshape(B * K, D) - result = self._transform(state_exp, instr_flat) # [B*K, N, D] + result = self._transform(state_exp, instr_flat) return result.reshape(B, K, N, D) diff --git a/models/gtm/grid_codec.py b/models/gtm/grid_codec.py index 06c7348..aa7e72d 100644 --- a/models/gtm/grid_codec.py +++ b/models/gtm/grid_codec.py @@ -5,36 +5,34 @@ # you may not use this file except in compliance with the License. # -"""Deterministic ARC grid <-> Cl(3,0,1) PGA multivector codec. - -Grids are kept as 2D tensors [H, W] (or padded [B, H_max, W_max]). -Row and column are directly read from 2D indices — no flattening required. - -PGA encoding (each cell -> 1 multivector in Cl(3,0,1), dim=16): - Grade-0 (scalar, idx 0): color / 9.0 (invariant under all sandwich products) - Grade-1 (vectors): - idx 1 (e0): row (integer) — spatial position - idx 2 (e1): col (integer) — spatial position - idx 4 (e2): 0.0 — reserved for role embed / auxiliary - idx 8 (e3): 1.0 — homogeneous coord (enables PGA translation) - Grade-2 (bivectors): - idx 3 (e01): row * col — spatial correlation - Grade-4 (pseudoscalar): - idx 15 (e0123): 1.0 if non-background (color!=0), else 0.0 - -Integer coordinates: no max_grid_size normalization. CliffordLayerNorm -in the VM handles normalization across steps. -""" +"""Deterministic ARC grid <-> Cl(3,0,1) PGA multivector codec.""" import torch from core.algebra import CliffordAlgebra class GridCodec: - """Deterministic encoder/decoder for ARC grids. No learnable parameters. - - Operates on 2D grids [H, W] or batched [B, H_max, W_max] with masks. - Uses PGA Cl(3,0,1) with dim=16. + """Deterministic encoder/decoder for ARC grids in PGA Cl(3,0,1). 
+ + Proper PGA point encoding — each cell is a grade-1 element plus + scalar color and pseudoscalar occupancy: + + idx 0 (1): color / 9.0 (grade 0 — motor-invariant) + idx 1 (e0): row / (H-1) (grade 1 — motor-transformable) + idx 2 (e1): col / (W-1) (grade 1 — motor-transformable) + idx 4 (e2): reserved (role embed) + idx 8 (e3): 1.0 (grade 1 — homogeneous coord) + idx 15 (e0123): occupancy flag (grade 4) + + Coordinates are normalized to [0, 1] by grid dimensions, then scaled + by coord_scale. This balances energy with color (also [0, 1]) and + makes the motor operate in a grid-size-invariant coordinate system. + + Note: the old encoding placed row*col in idx 3 (e01, grade 2). + This was removed because (a) it's not a proper PGA point component + — points are grade-1 in PGA, (b) it dominated 89% of encoding energy, + drowning out the color signal, (c) it confused the motor which treats + grade-2 as bivectors (lines/planes). """ def __init__(self, algebra_cpu: CliffordAlgebra, coord_scale: float = 1.0): @@ -44,55 +42,31 @@ def __init__(self, algebra_cpu: CliffordAlgebra, coord_scale: float = 1.0): self.coord_scale = coord_scale def encode_grid(self, grid: torch.Tensor) -> torch.Tensor: - """Encode a single 2D grid into multivectors. - - Args: - grid: Integer grid [H, W] with values in [0, 9]. - - Returns: - Multivectors [H, W, 16] in Cl(3,0,1). 
- """ + """Encode a single [H, W] grid into [H, W, 16] multivectors.""" H, W = grid.shape device = grid.device cs = self.coord_scale mv = torch.zeros(H, W, 16, device=device, dtype=torch.float32) colors = grid.float() - - # Row and col coordinate grids (integer, no normalization) rows = torch.arange(H, device=device).float().unsqueeze(1).expand(H, W) cols = torch.arange(W, device=device).float().unsqueeze(0).expand(H, W) - # Grade-0 (idx 0): normalized color - mv[:, :, 0] = colors / 9.0 - - # Grade-1 (idx 1=e0, 2=e1, 4=e2, 8=e3): spatial position + homogeneous - mv[:, :, 1] = rows * cs - mv[:, :, 2] = cols * cs - # idx 4 (e2) left zero — reserved for auxiliary features / role embed - mv[:, :, 8] = 1.0 # e3 homogeneous coord (enables PGA translations) - - # Grade-2 (idx 3=e01): spatial correlation - mv[:, :, 3] = rows * cols * (cs * cs) + # Normalize coords to [0, 1] by grid dims for energy balance + row_norm = rows / max(H - 1, 1) + col_norm = cols / max(W - 1, 1) - # Grade-4 pseudoscalar (idx 15=e0123): occupancy flag + mv[:, :, 0] = colors / 9.0 + mv[:, :, 1] = row_norm * cs + mv[:, :, 2] = col_norm * cs + mv[:, :, 8] = 1.0 mv[:, :, 15] = (colors > 0).float() return mv def encode_batch(self, grids: torch.Tensor, masks: torch.Tensor) -> tuple: - """Encode a batch of padded 2D grids into flat multivector sequences. - - Args: - grids: Padded grids [B, H_max, W_max] (long). - masks: Validity masks [B, H_max, W_max] (bool). 
- - Returns: - Tuple of: - mv: [B, N_max, 16] flattened multivectors (N_max = H_max * W_max) - flat_masks: [B, N_max] bool - """ + """Encode padded [B, H_max, W_max] grids into [B, N_max, 16] multivectors.""" B, H_max, W_max = grids.shape N_max = H_max * W_max device = grids.device @@ -102,34 +76,25 @@ def encode_batch(self, grids: torch.Tensor, rows = torch.arange(H_max, device=device).float().view(1, H_max, 1).expand(B, H_max, W_max) cols = torch.arange(W_max, device=device).float().view(1, 1, W_max).expand(B, H_max, W_max) + # Normalize coords to [0, 1] by grid dims for energy balance + row_norm = rows / max(H_max - 1, 1) + col_norm = cols / max(W_max - 1, 1) + mv = torch.zeros(B, H_max, W_max, 16, device=device, dtype=torch.float32) mv[:, :, :, 0] = colors / 9.0 - mv[:, :, :, 1] = rows * cs - mv[:, :, :, 2] = cols * cs - mv[:, :, :, 8] = 1.0 # e3 homogeneous coord - mv[:, :, :, 3] = rows * cols * (cs * cs) + mv[:, :, :, 1] = row_norm * cs + mv[:, :, :, 2] = col_norm * cs + mv[:, :, :, 8] = 1.0 mv[:, :, :, 15] = (colors > 0).float() - # Zero out padding cells mv = mv * masks.unsqueeze(-1).float() - - # Flatten spatial dims: [B, H_max, W_max, 16] -> [B, N_max, 16] mv = mv.reshape(B, N_max, 16) flat_masks = masks.reshape(B, N_max) return mv, flat_masks def decode(self, mv: torch.Tensor, H: int, W: int) -> torch.Tensor: - """Decode multivectors back to a 2D grid. - - Args: - mv: Multivectors [H*W, 16] or [H, W, 16]. - H: Grid height. - W: Grid width. - - Returns: - Integer grid [H, W] with values in [0, 9]. 
- """ + """Decode [H*W, 16] multivectors back to [H, W] integer grid.""" flat = mv.reshape(-1, 16) colors = flat[:H * W, 0] * 9.0 colors = colors.round().long().clamp(0, 9) diff --git a/models/gtm/gtm_net.py b/models/gtm/gtm_net.py index 4415f93..a8aba98 100644 --- a/models/gtm/gtm_net.py +++ b/models/gtm/gtm_net.py @@ -219,6 +219,10 @@ def forward(self, demo_inputs: torch.Tensor, demo_outputs: torch.Tensor, return result + def set_temperature(self, tau: float): + """Set Gumbel-Softmax temperature for all VM steps.""" + self.vm.set_temperature(tau) + def freeze_vm(self): """Freeze all VM parameters (Phase 1: warmup).""" for param in self.vm.parameters(): diff --git a/models/gtm/rule_memory.py b/models/gtm/rule_memory.py index f9e28cc..604f799 100644 --- a/models/gtm/rule_memory.py +++ b/models/gtm/rule_memory.py @@ -5,14 +5,7 @@ # you may not use this file except in compliance with the License. # -"""Rule Memory Bank — cross-attention aggregator for demo→test information flow. - -Compresses Phase 1 (demo) CPU state into M learnable rule slots via -cross-attention. This replaces the 4-float ctrl_cursor bottleneck as the -primary information bridge between demo and test phases. - -Information capacity: M=8 slots * 16 dims = 128 floats (vs 4 floats before). -""" +"""Rule memory bank: cross-attention aggregator for demo-to-test information flow.""" import torch import torch.nn as nn @@ -20,69 +13,42 @@ class RuleAggregator(nn.Module): - """Cross-attention from M learnable queries to demo cpu_state. + """Compresses demo CPU state into M rule slots via cross-attention. - Compresses Phase 1 output into M rule slots that encode the - transformation rule learned from demo pairs. + Learnable query templates attend over demo cells; values are raw + multivectors to preserve geometric structure. 
""" def __init__(self, d_cpu: int = 16, num_slots: int = 8, num_heads: int = 4): super().__init__() - self.d_cpu = d_cpu self.num_slots = num_slots self.num_heads = num_heads self.head_dim = d_cpu // num_heads - assert d_cpu % num_heads == 0, f"d_cpu={d_cpu} must be divisible by num_heads={num_heads}" + assert d_cpu % num_heads == 0 self.scale = self.head_dim ** -0.5 - - # Learnable query templates self.query_templates = nn.Parameter(torch.randn(num_slots, d_cpu) * 0.02) - - # Projections for cross-attention self.q_proj = nn.Linear(d_cpu, d_cpu) self.k_proj = nn.Linear(d_cpu, d_cpu) - # V = raw demo state (no projection — preserves geometric structure) def forward(self, demo_cpu_state: torch.Tensor, demo_mask: torch.Tensor) -> torch.Tensor: - """Aggregate demo state into rule memory slots. - - Args: - demo_cpu_state: [B, N_demo, 16] CPU state after Phase 1. - demo_mask: [B, N_demo] bool (True=valid). - - Returns: - rule_memory: [B, M, 16] compressed rule representation. - """ + """[B, N_demo, 16] -> [B, M, 16] rule memory slots.""" B, N_demo, D = demo_cpu_state.shape - M = self.num_slots - H = self.num_heads - hd = self.head_dim + M, H, hd = self.num_slots, self.num_heads, self.head_dim - # Query from learnable templates: [M, D] -> [B, M, D] Q = self.q_proj(self.query_templates).unsqueeze(0).expand(B, -1, -1) - # Key from demo state: [B, N_demo, D] K = self.k_proj(demo_cpu_state) - # Multi-head reshape - Q = Q.reshape(B, M, H, hd).transpose(1, 2) # [B, H, M, hd] - K = K.reshape(B, N_demo, H, hd).transpose(1, 2) # [B, H, N_demo, hd] + Q = Q.reshape(B, M, H, hd).transpose(1, 2) + K = K.reshape(B, N_demo, H, hd).transpose(1, 2) - # Attention scores - scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale # [B, H, M, N_demo] + scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale - # Mask invalid demo cells if demo_mask is not None: - pad_mask = ~demo_mask # [B, N_demo] scores = scores.masked_fill( - pad_mask.unsqueeze(1).unsqueeze(2), float('-inf') + 
(~demo_mask).unsqueeze(1).unsqueeze(2), float('-inf') ) - attn = F.softmax(scores, dim=-1) # [B, H, M, N_demo] - attn_avg = attn.mean(dim=1) # [B, M, N_demo] - - # Values: raw demo state (preserves geometric structure) - rule_memory = torch.bmm(attn_avg, demo_cpu_state) # [B, M, 16] - - return rule_memory + attn = F.softmax(scores, dim=-1).mean(dim=1) + return torch.bmm(attn, demo_cpu_state) diff --git a/models/gtm/superposition.py b/models/gtm/superposition.py index dba3470..a927800 100644 --- a/models/gtm/superposition.py +++ b/models/gtm/superposition.py @@ -5,14 +5,7 @@ # you may not use this file except in compliance with the License. # -"""Geometric Superposition Search — simplified scoring via CPU grade norms. - -Scores K instruction hypotheses using CPU state grade norms + ctrl_cursor, -dispatches trainable instruction templates (optionally modulated by rule memory) -to the PGA CPU, executes K outcomes in parallel, and selects via Gumbel-Softmax. - -Mother algebra is no longer needed — scoring uses CPU grade norms directly. -""" +"""Geometric Superposition Search: score, dispatch, execute, select.""" import torch import torch.nn as nn @@ -22,13 +15,9 @@ class GeometricSuperpositionSearch(nn.Module): - """Geometric Superposition Search over CPU Cl(3,0,1). - - Trainable parameters: - instruction_templates: [K, 16] full Cl(3,0,1) multivectors - score_mlp: CPU grade norms + ctrl_cursor -> K scores - rule_proj: rule_memory -> per-template modulation (if rule_memory provided) - log_temperature: Gumbel-Softmax temperature (learnable) + """Scores K hypotheses via CPU grade norms, executes PGA motor transforms + in parallel, and selects via Gumbel-Softmax. Instruction templates are + optionally modulated by rule memory. 
""" def __init__(self, algebra_cpu: CliffordAlgebra, @@ -46,82 +35,73 @@ def __init__(self, algebra_cpu: CliffordAlgebra, self.num_hypotheses = num_hypotheses self.top_k = top_k - D_cpu = algebra_cpu.dim # 16 + D_cpu = algebra_cpu.dim - # CPU engine (has ColorUnit params) self.pga_cpu = GeometricCPU(algebra_cpu, K_color) - - # Trainable instruction templates — full Cl(3,0,1) multivectors self.instruction_templates = nn.Parameter( - torch.randn(num_hypotheses, D_cpu) * 0.01 + torch.randn(num_hypotheses, D_cpu) * 0.1 ) - - # Scoring MLP: CPU grade norms + ctrl_cursor -> K scores - cpu_grades = algebra_cpu.num_grades # 5 for Cl(3,0,1) self.score_mlp = nn.Sequential( - nn.Linear(cpu_grades + algebra_ctrl.dim, 64), + nn.Linear(algebra_cpu.num_grades + algebra_ctrl.dim, 64), nn.ReLU(), nn.Linear(64, num_hypotheses), ) + # Per-cell routing: each cell scores hypotheses independently + self.cell_router = nn.Linear(D_cpu, num_hypotheses) + # Rule memory bias on hypothesis scores + self.rule_score_proj = nn.Linear(D_cpu, num_hypotheses) + + # Small-weight init so initial behavior ≈ old global-only scoring + nn.init.normal_(self.cell_router.weight, std=0.01) + nn.init.zeros_(self.cell_router.bias) + nn.init.normal_(self.rule_score_proj.weight, std=0.01) + nn.init.zeros_(self.rule_score_proj.bias) - # Rule-conditioned instruction modulation self.rule_proj = nn.Linear(D_cpu, num_hypotheses * D_cpu) + self.register_buffer('_temperature', torch.tensor(float(temperature_init))) - # Gumbel temperature (learnable) - self.log_temperature = nn.Parameter( - torch.tensor(float(torch.tensor(temperature_init).log())) - ) + def set_temperature(self, tau: float): + """Set Gumbel-Softmax temperature (called by external annealing schedule).""" + self._temperature.fill_(tau) def step(self, cpu_state: torch.Tensor, ctrl_cursor: torch.Tensor, rule_memory: torch.Tensor = None) -> tuple: - """One superposition search step. - - Args: - cpu_state: [B, N, 16] CPU state in Cl(3,0,1). 
- ctrl_cursor: [B, 4] control cursor in Cl(1,1). - rule_memory: Optional [B, M, 16] rule slots from RuleAggregator. - - Returns: - Tuple of (new_cpu_state [B, N, 16], search_info dict). - """ + """One search step. Returns (new_cpu_state, search_info).""" B, N, D_cpu = cpu_state.shape device = cpu_state.device K = self.num_hypotheses - # STEP 1 — SCORE: CPU grade norms + ctrl_cursor - cpu_summary = cpu_state.mean(dim=1) # [B, 16] + cpu_summary = cpu_state.mean(dim=1) self.algebra_cpu.ensure_device(device) - cpu_grade_norms = self.algebra_cpu.get_grade_norms(cpu_summary) # [B, 5] - score_input = torch.cat([cpu_grade_norms, ctrl_cursor], dim=-1) # [B, 9] - scores = self.score_mlp(score_input) # [B, K] + grade_norms = self.algebra_cpu.get_grade_norms(cpu_summary) - # STEP 2 — DISPATCH: templates optionally modulated by rule memory - templates = self.instruction_templates.unsqueeze(0).expand(B, -1, -1) # [B, K, 16] + # Per-cell logits + global bias from cursor/grade norms + cell_logits = self.cell_router(cpu_state) # [B, N, K] + global_bias = self.score_mlp( + torch.cat([grade_norms, ctrl_cursor], dim=-1) + ) # [B, K] + scores = cell_logits + global_bias.unsqueeze(1) # [B, N, K] + templates = self.instruction_templates.unsqueeze(0).expand(B, -1, -1) if rule_memory is not None: - rule_summary = rule_memory.mean(dim=1) # [B, 16] - rule_features = self.rule_proj(rule_summary) # [B, K * 16] - rule_modulation = rule_features.view(B, K, D_cpu) # [B, K, 16] + rule_summary = rule_memory.mean(dim=1) + rule_modulation = self.rule_proj(rule_summary).view(B, K, D_cpu) templates = templates + rule_modulation + # Rule memory biases scoring (which instructions cells prefer) + rule_score_bias = self.rule_score_proj(rule_summary) # [B, K] + scores = scores + rule_score_bias.unsqueeze(1) # [B, N, K] - # Score-dependent modulation - instructions = scores.unsqueeze(-1) * templates # [B, K, 16] + outcomes = self.pga_cpu.execute_all(cpu_state, templates) # [B, K, N, D] - # STEP 3 — EXECUTE: 
CPU applies PGA Motor + ColorUnit, K× batched - outcomes = self.pga_cpu.execute_all(cpu_state, instructions) # [B, K, N, 16] + tau = self._temperature.clamp(0.1, 5.0) + weights = F.gumbel_softmax( + scores.reshape(B * N, K), tau=tau, hard=False + ).reshape(B, N, K) # [B, N, K] + new_cpu_state = torch.einsum('bnk,bknd->bnd', weights, outcomes) - # STEP 4 — SELECT: Gumbel-Softmax, differentiable discrete selection - tau = self.log_temperature.exp().clamp(0.1, 5.0) - weights = F.gumbel_softmax(scores, tau=tau, hard=False) # [B, K] - - # Weighted sum via einsum (no Python loop) - new_cpu_state = torch.einsum('bk,bknd->bnd', weights, outcomes) - - search_info = { + return new_cpu_state, { 'scores': scores, 'weights': weights, 'temperature': tau.detach(), } - - return new_cpu_state, search_info diff --git a/models/gtm/turing_step.py b/models/gtm/turing_step.py index f810da1..e1892c8 100644 --- a/models/gtm/turing_step.py +++ b/models/gtm/turing_step.py @@ -5,16 +5,7 @@ # you may not use this file except in compliance with the License. # -"""Single GTM step: SuperpositionSearch + Cross-Grade Attention + ControlPlane. - -Key design choices: - - NO additive residual (addition destroys geometric structure after rotations). - Instead, the instruction can learn B~0 to approximate identity. - - Cross-grade dense Q/K attention: allows learning diagonal, distance, - and other 2D spatial relationships. - - Values remain raw multivectors with per-grade gain (preserves geometry). - - Geometric gating (rotor interpolation) instead of additive skip connections. 
-""" +"""Single GTM step: SuperpositionSearch + Cross-Grade Attention + ControlPlane.""" import torch import torch.nn as nn @@ -25,12 +16,6 @@ from .control_plane import ControlPlane -# Grade-to-index mapping for Cl(3,0,1), dim=16 -# Grade 0: [0] -# Grade 1: [1, 2, 4, 8] -# Grade 2: [3, 5, 6, 9, 10, 12] -# Grade 3: [7, 11, 13, 14] -# Grade 4: [15] _GRADE_MAP_16 = torch.zeros(16, dtype=torch.long) _GRADE_MAP_16[0] = 0 _GRADE_MAP_16[[1, 2, 4, 8]] = 1 @@ -40,36 +25,24 @@ class CellAttention(nn.Module): - """Cross-grade self-attention over grid cells in Cl(3,0,1). - - Dense Q/K projections allow learning cross-grade features like - diagonals (e0+e1), distances, and 2D spatial relationships. - Values remain raw multivectors with per-grade gain to preserve - geometric structure in the convex combination. - """ + """Cross-grade self-attention over grid cells in Cl(3,0,1).""" def __init__(self, algebra_cpu: CliffordAlgebra, num_heads: int = 4, head_dim: int = 8, dropout: float = 0.0): super().__init__() - D = algebra_cpu.dim # 16 + D = algebra_cpu.dim attn_dim = num_heads * head_dim self.num_heads = num_heads self.head_dim = head_dim self.scale = head_dim ** -0.5 - # Dense Q, K projections: allow cross-grade mixing for scoring self.q_proj = nn.Linear(D, attn_dim) self.k_proj = nn.Linear(D, attn_dim) - - # Per-grade gain on values (preserves geometric structure) self.v_gain = nn.ParameterDict({ f'g{k}': nn.Parameter(torch.ones(1)) for k in range(5) }) - self.dropout = nn.Dropout(dropout) - - # Grade map buffer for applying per-grade gains self.register_buffer('grade_map', _GRADE_MAP_16.clone()) def _apply_grade_gains(self, x: torch.Tensor) -> torch.Tensor: @@ -82,55 +55,32 @@ def _apply_grade_gains(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor: - """Cross-grade self-attention over cells. - - Args: - x: [B, N, 16] multivectors in Cl(3,0,1). - mask: [B, N] bool, True=valid. 
- - Returns: - [B, N, 16] attended multivectors. - """ + """[B, N, 16] -> [B, N, 16] with optional mask [B, N].""" B, N, D = x.shape - # Dense Q, K projections - Q = self.q_proj(x) # [B, N, attn_dim] - K = self.k_proj(x) # [B, N, attn_dim] + Q = self.q_proj(x) + K = self.k_proj(x) - # Multi-head reshape - Q = Q.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) # [B, H, N, hd] - K = K.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) # [B, H, N, hd] + Q = Q.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) + K = K.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) - # Scaled dot-product attention - scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale # [B, H, N, N] + scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale if mask is not None: - pad_mask = ~mask scores = scores.masked_fill( - pad_mask.unsqueeze(1).unsqueeze(2), float('-inf') + (~mask).unsqueeze(1).unsqueeze(2), float('-inf') ) - attn = F.softmax(scores, dim=-1) # [B, H, N, N] + attn = F.softmax(scores, dim=-1) attn = self.dropout(attn) - attn_avg = attn.mean(dim=1) # [B, N, N] + attn_avg = attn.mean(dim=1) - # Values: raw multivector, weighted average (convex combination) - attended = torch.bmm(attn_avg, x) # [B, N, 16] - - # Per-grade gain on output + attended = torch.bmm(attn_avg, x) return self._apply_grade_gains(attended) class TuringStep(nn.Module): - """One step of the Geometric Turing Machine. - - Composes: - 1. Cell attention (cross-cell communication with cross-grade features) - 2. Superposition search (per-cell transformation via PGA motor) - 3. Geometric write gate (interpolates via scalar gating, no additive residual) - 4. CliffordLayerNorm - 5. 
Control plane step - """ + """One step of the Geometric Turing Machine.""" def __init__(self, algebra_cpu: CliffordAlgebra, algebra_ctrl: CliffordAlgebra, @@ -145,73 +95,52 @@ def __init__(self, algebra_cpu: CliffordAlgebra, num_rule_slots: int = 8): super().__init__() self.channels = channels - D_cpu = algebra_cpu.dim # 16 + D_cpu = algebra_cpu.dim - # Cell-to-cell attention (cross-grade features) self.cell_attn = CellAttention( algebra_cpu, num_attn_heads, attn_head_dim, attn_dropout, ) - - # Superposition search module (no mother algebra) self.search = GeometricSuperpositionSearch( algebra_cpu, algebra_ctrl, channels, num_hypotheses, top_k, temperature_init, K_color, num_rule_slots, ) - - # Control plane self.control = ControlPlane(algebra_ctrl, channels) - - # CPU state normalization - self.norm = CliffordLayerNorm(algebra_cpu, 1) # per-cell norm (C=1) - - # Context projection: cpu summary -> ctrl context + self.norm = CliffordLayerNorm(algebra_cpu, 1) self.context_proj = nn.Linear(D_cpu, channels) - - # Geometric write gate: scalar gate per cell + # Per-component gate: enables cross-grade mixing. + # With scalar gate (old), color (g0) and position (g1) always move + # together. Per-component gate lets the model selectively update + # color based on position context and vice versa. self.write_gate = nn.Sequential( - nn.Linear(D_cpu * 2, 64), # concat(old, new) -> 64 + nn.Linear(D_cpu * 2, 64), nn.ReLU(), - nn.Linear(64, 1), + nn.Linear(64, D_cpu), ) + def set_temperature(self, tau: float): + self.search.set_temperature(tau) + def forward(self, cpu_state: torch.Tensor, ctrl_cursor: torch.Tensor, mask: torch.Tensor = None, rule_memory: torch.Tensor = None) -> dict: - """Execute one GTM step. - - Args: - cpu_state: [B, N, 16] CPU state in Cl(3,0,1). - ctrl_cursor: [B, 4] control cursor in Cl(1,1). - mask: Optional [B, N] validity mask (True=valid). - rule_memory: Optional [B, M, 16] rule slots from RuleAggregator. 
- - Returns: - dict with 'cpu_state', 'ctrl_cursor', 'halt_prob', 'search_info'. - """ old_state = cpu_state - # 1. Cell attention (cross-cell communication) attended = self.cell_attn(cpu_state, mask) - - # 2. Superposition search (per-cell transformation via PGA motor) new_cpu, search_info = self.search.step(attended, ctrl_cursor, rule_memory) - # 3. Geometric write gate (NO additive residual) - gate_input = torch.cat([old_state, new_cpu], dim=-1) # [B, N, 32] - gate = torch.sigmoid(self.write_gate(gate_input)) # [B, N, 1] + gate_input = torch.cat([old_state, new_cpu], dim=-1) + gate = torch.sigmoid(self.write_gate(gate_input)) new_cpu = gate * new_cpu + (1.0 - gate) * old_state - # 4. CliffordLayerNorm (per-cell) B, N, D = new_cpu.shape new_cpu_flat = new_cpu.reshape(B * N, 1, D) new_cpu_flat = self.norm(new_cpu_flat) new_cpu = new_cpu_flat.reshape(B, N, D) - # 5. Control plane step - cpu_summary = new_cpu.mean(dim=1) # [B, 16] - cpu_context = self.context_proj(cpu_summary) # [B, channels] + cpu_summary = new_cpu.mean(dim=1) + cpu_context = self.context_proj(cpu_summary) new_cursor, direction_logit, halt_prob = self.control.step( ctrl_cursor, cpu_context ) diff --git a/models/gtm/turing_vm.py b/models/gtm/turing_vm.py index 715887d..791a4fb 100644 --- a/models/gtm/turing_vm.py +++ b/models/gtm/turing_vm.py @@ -5,12 +5,7 @@ # you may not use this file except in compliance with the License. # -"""Geometric Turing Machine execution engine — ARC-AGI v4. - -Chains TuringSteps with dual-state (cpu_state + ctrl_cursor) threading. -Supports both fixed-step and adaptive computation (PonderNet) modes. -Optionally threads rule_memory from Phase 1 to each step. -""" +"""Geometric Turing Machine execution engine.""" import torch import torch.nn as nn @@ -21,12 +16,7 @@ class TuringVM(nn.Module): - """Geometric Turing Machine execution engine. - - Chains N TuringSteps with dual-state (cpu_state + ctrl_cursor). - Supports both fixed-step and adaptive computation modes. 
- Threads rule_memory to each step when provided. - """ + """Chains TuringSteps with dual-state threading and optional PonderNet halting.""" def __init__(self, algebra_cpu: CliffordAlgebra, algebra_ctrl: CliffordAlgebra, @@ -48,7 +38,6 @@ def __init__(self, algebra_cpu: CliffordAlgebra, self.max_steps = max_steps self.use_act = use_act - # Create steps up to max_steps (ACT) or num_steps (fixed) effective_steps = max_steps if use_act else num_steps self.steps = nn.ModuleList([ TuringStep( @@ -60,12 +49,13 @@ def __init__(self, algebra_cpu: CliffordAlgebra, for _ in range(effective_steps) ]) - # Adaptive halt controller self.adaptive_halt = AdaptiveHalt(lambda_p, max_steps) if use_act else None - - # Final normalization on CPU state self.final_norm = CliffordLayerNorm(algebra_cpu, 1) + def set_temperature(self, tau: float): + for step in self.steps: + step.set_temperature(tau) + def forward(self, cpu_state: torch.Tensor, ctrl_cursor: torch.Tensor, mask: torch.Tensor = None, return_trace: bool = False, diff --git a/tasks/gtm.py b/tasks/gtm.py index 6914f3f..569af79 100644 --- a/tasks/gtm.py +++ b/tasks/gtm.py @@ -50,10 +50,17 @@ def __init__(self, cfg): self.act_epochs = cfg.training.get('act_epochs', 45) self.act_weight = cfg.training.get('act_weight', 0.01) self.act_ramp_epochs = cfg.training.get('act_ramp_epochs', 15) - self.gate_entropy_weight = cfg.training.get('gate_entropy_weight', 0.001) + self.gate_entropy_weight = cfg.training.get('gate_entropy_weight', 0.01) self.grad_clip = cfg.training.get('grad_clip', 1.0) self.eval_every = cfg.training.get('eval_every', 5) + # Gumbel temperature annealing schedule + self.tau_start = cfg.training.get('tau_start', 1.0) + self.tau_end = cfg.training.get('tau_end', 0.1) + # Warm restart at Phase 3: steps[num_steps:max_steps] are untrained, + # need high tau for exploration before annealing down + self.tau_act_restart = cfg.training.get('tau_act_restart', 0.7) + super().__init__(cfg) def setup_algebra(self): @@ -295,6 
+302,23 @@ def run(self): else: self._current_act_weight = 0.0 + # Gumbel temperature annealing: + # Phase 1 (warmup): hold at tau_start + # Phase 2 (circuit): anneal tau_start -> tau_act_restart + # Phase 3 (ACT): warm restart at tau_act_restart, anneal -> tau_end + # Warm restart needed because ACT activates steps[num_steps:max_steps] + # which have untrained weights and need exploration room. + if phase == 1: + tau = self.tau_start + elif phase == 2: + progress = min(1.0, (epoch - self.warmup_epochs) / max(self.trim_epochs, 1)) + tau = self.tau_start + (self.tau_act_restart - self.tau_start) * progress + else: # phase 3 + act_epoch = epoch - (self.warmup_epochs + self.trim_epochs) + progress = min(1.0, act_epoch / max(self.act_epochs, 1)) + tau = self.tau_act_restart + (self.tau_end - self.tau_act_restart) * progress + self.model.set_temperature(tau) + # Training self.model.train() total_loss = 0 @@ -328,6 +352,7 @@ def run(self): 'P': phase, 'Loss': avg_loss, metric_key: val_metric, 'LR': self.optimizer.param_groups[0]['lr'], + 'tau': tau, } if self._current_act_weight > 0: display['ACT_w'] = self._current_act_weight From 402febb4dd4d546e5eec990da3a29e91c393ced1 Mon Sep 17 00:00:00 2001 From: Concode0 Date: Thu, 19 Mar 2026 20:25:20 +0900 Subject: [PATCH 03/16] chore: remove useless files --- experiments/svp_lattice_breaker.py | 1206 ---------------------------- 1 file changed, 1206 deletions(-) delete mode 100644 experiments/svp_lattice_breaker.py diff --git a/experiments/svp_lattice_breaker.py b/experiments/svp_lattice_breaker.py deleted file mode 100644 index e8b8c17..0000000 --- a/experiments/svp_lattice_breaker.py +++ /dev/null @@ -1,1206 +0,0 @@ -# Versor: Universal Geometric Algebra Neural Network -# Copyright (C) 2026 Eunkyum Kim -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# - -"""Kyber MLWE Lattice Solver via Kannan's Embedding. 
- -Solves Module-LWE instances from the Kyber KEM (NIST FIPS 203) using -Geometric Algebra: - 1. KyberInstance — MLWE generation + Kannan embedding (CVP→SVP) - 2. GALatticeReducer — LLL/BKZ reduction with proper GSO Lovász condition - 3. BladeEnumerator — Tree enumeration with blade rejection pruning - 4. RotorSearchLayer — Neural SVP with GeometricNeutralizer + PhysicalNormBreaker - 5. KyberSolver — Orchestrator: reduce → search → extract secret - -Key features: - - Rigorous Kyber-512 params: eta1=3 (secret), eta2=2 (error) per FIPS 203 - - NO lattice truncation: full Kannan embedding (2m+1) preserves target - - Correct negacyclic matrix: M[i,j] = p[i-j] if i>=j, else -p[n+i-j] - - Proper LLL with GSO-based Lovász condition - - Diet loss: 3 soft terms + 2 hard constraints (last coeff ±1, integrality via snap) - - Physical norm breaker: clamps model output to prevent large-vector false positives - - Safe LAPACK: pre-scaled slogdet/QR avoids DLASCL overflow at high dimensions - - Phase 4: checks shortest basis vectors + pairwise combos before neural candidates - -Performance: - - LLL uses incremental GSO updates (Cohen 2.6.3): O(n) per step vs O(n²d) - - Vectorized Gram-Schmidt and negacyclic matrix construction - - Supports CUDA for all phases (float64 throughout) - -All arithmetic in torch float64. No mpmath. 
-""" - -import os -import sys -import math -import time -import argparse -import torch -import torch.nn as nn - -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) - -from core.algebra import CliffordAlgebra -from core.metric import induced_norm -from layers.primitives.multi_rotor import MultiRotorLayer -from layers.primitives.normalization import CliffordLayerNorm -from layers.primitives.projection import GeometricNeutralizer - -try: - from fpylll import IntegerMatrix, LLL, BKZ, GSO, FPLLL - HAS_FPYLLL = True -except ImportError: - HAS_FPYLLL = False - - -def resolve_device(requested: str) -> str: - """Resolve device string, with auto-detection for 'auto'.""" - if requested == 'auto': - if torch.cuda.is_available(): - return 'cuda' - return 'cpu' - return requested - - -# --------------------------------------------------------------------------- -# 1. KyberInstance — MLWE Instance + Kannan Embedding (no truncation) -# --------------------------------------------------------------------------- - -class KyberInstance: - """Generate a Kyber-style MLWE instance and construct Kannan embedding. - - Polynomial ring R_q = Z_q[X]/(X^n + 1) represented via anti-circulant - (negacyclic convolution) matrices. Secret s drawn from CBD(eta1), - error e drawn from CBD(eta2). - - Standard Kyber-512 (NIST FIPS 203): n=256, k=2, q=3329, eta1=3, eta2=2. - - The full Kannan embedding has dimension 2m+1 where m = k*n. No truncation - is applied — truncation destroys the target vector (e, -s, M). - - Note: Kyber uses compress/decompress for ciphertext components (du, dv). - This affects ciphertext size but NOT the lattice structure. The Kannan - embedding operates on the algebraic relation b = As + e (mod q), which - is pre-compression. 
- """ - - def __init__(self, n: int = 256, k: int = 2, q: int = 3329, - eta1: int = 3, eta2: int = 2, eta: int = None, - seed: int = 42, device: str = 'cpu'): - if eta is not None: # backward compat: single eta overrides both - eta1, eta2 = eta, eta - self.n, self.k, self.q, self.eta1, self.eta2 = n, k, q, eta1, eta2 - self.device = device - self.m = k * n - self.full_dim = 2 * self.m + 1 - - torch.manual_seed(seed) - - # Secret via CBD(eta1), error via CBD(eta2) — per NIST FIPS 203 - self.s = self._cbd(self.m, self.eta1) - self.e = self._cbd(self.m, self.eta2) - - # Public matrix A (block-negacyclic) - self.A = self._random_module_matrix() - - # b = A s + e (mod q) - self.b = (self.A @ self.s + self.e) % self.q - - # Adaptive M: expected error norm from CBD(eta2) over m dimensions - # Var(CBD(eta)) = eta/2, so E[||e||] = sqrt(m * eta2 / 2) - self.M_embed = max(math.sqrt(self.m * self.eta2 / 2.0), 1.0) - - # Full Kannan embedding basis - self.basis = self._kannan_embedding() - - # Precompute the correction vector k = floor((As + e) / q) for verification - As_plus_e = self.A.long() @ self.s.long() + self.e.long() - self.k_correction = torch.div(As_plus_e, self.q, rounding_mode='floor') - - # Target vector norm: ||(e, -s, M)|| - self.target_norm = math.sqrt( - torch.norm(self.e.double()).item()**2 + - torch.norm(self.s.double()).item()**2 + - self.M_embed**2 - ) - - print(f" KyberInstance: n={n}, k={k}, q={q}, eta1={eta1}, eta2={eta2}") - print(f" Ring dim m={self.m}, full lattice dim={self.full_dim}") - print(f" ||s||={torch.norm(self.s.double()).item():.2f}, " - f"||e||={torch.norm(self.e.double()).item():.2f}, " - f"M={self.M_embed:.2f}") - print(f" Expected target vector norm: {self.target_norm:.2f}") - - @staticmethod - def _cbd(length: int, eta: int) -> torch.Tensor: - """Centered binomial distribution CBD(eta).""" - a = torch.randint(0, 2, (length, eta), dtype=torch.long) - b = torch.randint(0, 2, (length, eta), dtype=torch.long) - return a.sum(dim=1) - 
b.sum(dim=1) - - def _negacyclic_matrix(self, poly: torch.Tensor) -> torch.Tensor: - """Anti-circulant matrix for R_q = Z_q[X]/(X^n + 1). - - M[i,j] = poly[i - j] if i >= j (no wraparound) - M[i,j] = -poly[n + i - j] if i < j (X^n = -1 wraparound) - - Verified against NIST FIPS 203 (Kyber spec): negacyclic NTT matrix - M[i,j] = f[(i-j) mod n] * (-1)^{floor((i-j)/n)} matches X^n + 1 quotient. - - Vectorized: (row - col) % n gives the correct index for both cases. - """ - n = len(poly) - rows = torch.arange(n, dtype=torch.long).unsqueeze(1) - cols = torch.arange(n, dtype=torch.long).unsqueeze(0) - idx = (rows - cols) % n - lower = rows >= cols - vals = poly[idx] - M = torch.where(lower, vals, -vals) % self.q - return M - - def _random_module_matrix(self) -> torch.Tensor: - """Build k×k block matrix of negacyclic polynomials → m×m matrix.""" - A = torch.zeros(self.m, self.m, dtype=torch.long) - for bi in range(self.k): - for bj in range(self.k): - poly = torch.randint(0, self.q, (self.n,), dtype=torch.long) - block = self._negacyclic_matrix(poly) - A[bi*self.n:(bi+1)*self.n, bj*self.n:(bj+1)*self.n] = block - return A - - def _kannan_embedding(self) -> torch.Tensor: - """Construct full Kannan embedding lattice basis. 
- - Layout (2m+1) × (2m+1): - [[q*I_m, 0, 0 ], rows 0..m-1 - [ A^T, I_m, 0 ], rows m..2m-1 - [ b^T, 0, M ]] row 2m - - The short vector in this lattice is (e, -s, M) with coefficients: - c[0:m] = -k (mod-q correction) - c[m:2m] = -s (secret) - c[2m] = 1 (Kannan embedding coefficient) - - Lattice vector: c @ B gives: - cols 0..m-1: q*(-k) + A^T*(-s) + b = e (since b = As + e + qk) - cols m..2m-1: -s - col 2m: M - """ - m, dim = self.m, self.full_dim - dev = self.device - - B = torch.zeros(dim, dim, dtype=torch.float64, device=dev) - - # Top-left: q * I_m - B[:m, :m] = self.q * torch.eye(m, dtype=torch.float64, device=dev) - - # Middle block: A^T (transpose gives correct product A @ c_secret) - B[m:2*m, :m] = self.A.to(dtype=torch.float64, device=dev).T - - # Middle diagonal: I_m - B[m:2*m, m:2*m] = torch.eye(m, dtype=torch.float64, device=dev) - - # Bottom row: (b, 0, M) - B[2*m, :m] = self.b.to(dtype=torch.float64, device=dev) - B[2*m, 2*m] = self.M_embed - - return B - - def verify_solution(self, s_recovered: torch.Tensor, - e_recovered: torch.Tensor) -> bool: - """Check As + e ≡ b (mod q).""" - b_check = (self.A.long() @ s_recovered.long() + - e_recovered.long()) % self.q - b_ref = self.b.long() % self.q - return torch.all(b_check == b_ref).item() - - -# --------------------------------------------------------------------------- -# 2. GALatticeReducer — LLL/BKZ with proper GSO Lovász condition -# --------------------------------------------------------------------------- - -class GALatticeReducer: - """Lattice basis reduction: LLL (proper GSO) + BKZ enumeration. - - Size reduction via GSO mu coefficients. Swap via Lovász condition on - GSO norms. BKZ blocks use BladeEnumerator for GA-pruned enumeration. 
- """ - - def __init__(self, block_dim: int = 8, device: str = 'cpu', use_fpylll: bool = True): - self.block_dim = block_dim - self.device = device - self.use_fpylll = use_fpylll and HAS_FPYLLL - - def _torch_to_imatrix(self, basis: torch.Tensor): - """Convert float64 torch basis to fpylll IntegerMatrix (rounds to nearest int).""" - return IntegerMatrix.from_matrix(basis.round().long().tolist()) - - def _imatrix_to_torch(self, A, device: str) -> torch.Tensor: - """Convert fpylll IntegerMatrix back to float64 torch tensor.""" - n, d = A.nrows, A.ncols - rows = [[A[i, j] for j in range(d)] for i in range(n)] - return torch.tensor(rows, dtype=torch.float64, device=device) - - def _fpylll_reduce(self, basis: torch.Tensor, rounds: int) -> torch.Tensor: - """LLL + progressive BKZ reduction via fpylll backend. - - Uses MPFR extended precision for large lattices (dim > 160) to - maintain GSO numerical stability. Progressive BKZ warms up from - block_size 4 to the target, giving each stage a better starting - point. - """ - A = self._torch_to_imatrix(basis) - n = A.nrows - - # Select float type: MPFR for large lattices - ft = "double" - if n > 160: - try: - FPLLL.set_precision(max(150, n)) - ft = "mpfr" - except Exception: - ft = "double" - - M = GSO.Mat(A, float_type=ft) - M.update_gso() - - # LLL reduction - lll_obj = LLL.Reduction(M) - lll_obj() - - # Progressive BKZ: warm up from small block sizes - if self.block_dim >= 4: - for bs in range(4, self.block_dim, 2): - params = BKZ.Param(block_size=bs, max_loops=max(1, rounds // 2)) - BKZ.Reduction(M, lll_obj, params)() - - # Final pass at target block size - if self.block_dim >= 2: - params = BKZ.Param(block_size=self.block_dim, max_loops=rounds) - BKZ.Reduction(M, lll_obj, params)() - - return self._imatrix_to_torch(A, self.device) - - def _compute_gso(self, basis: torch.Tensor): - """Gram-Schmidt orthogonalization with mu coefficients. 
- - Vectorized: inner loop replaced with batched dot products and - a single matrix-vector multiply per row. - - Returns: - gso: Orthogonal vectors [n, d]. - mu: Projection coefficients [n, n] (lower triangular). - B_sq: Squared norms of GSO vectors [n]. - """ - n = basis.shape[0] - gso = basis.clone() - mu = torch.zeros(n, n, dtype=basis.dtype, device=basis.device) - B_sq = torch.zeros(n, dtype=basis.dtype, device=basis.device) - - B_sq[0] = (gso[0] ** 2).sum() - for i in range(1, n): - # Batched dot products: basis[i] against all previous GSO vectors - dots = torch.mv(gso[:i], basis[i]) # [i] - valid = B_sq[:i] > 1e-30 - mu[i, :i] = torch.where(valid, dots / B_sq[:i].clamp(min=1e-30), - torch.zeros_like(dots)) - # Subtract all projections at once - gso[i] = basis[i] - torch.mv(gso[:i].T, mu[i, :i]) - B_sq[i] = (gso[i] ** 2).sum() - - return gso, mu, B_sq - - def _lll_reduce(self, basis: torch.Tensor, delta: float = 0.99) -> torch.Tensor: - """LLL reduction with incremental GSO updates (Cohen Algorithm 2.6.3). - - Key optimization: GSO is computed once upfront. Size reduction updates - only mu coefficients (O(k) per step instead of O(n²d) full recompute). - Swaps use the standard incremental GSO update formulas. - """ - n = basis.shape[0] - basis = basis.clone() - _, mu, B_sq = self._compute_gso(basis) # Single full GSO computation - k = 1 - iterations = 0 - max_iter = n * n * 10 # Safety bound - - while k < n and iterations < max_iter: - iterations += 1 - - # Size reduce b_k against b_{k-1}, ..., b_0 - for j in range(k - 1, -1, -1): - if abs(mu[k, j].item()) > 0.5: - r = torch.round(mu[k, j]) - basis[k] = basis[k] - r * basis[j] - # Incremental mu update: O(j) instead of O(n²d) GSO recompute - # basis[j] = gso[j] + sum_{i/B_sq[m] decreases by r*mu[j,m] for m/B_sq[j] = 1). 
- if j > 0: - mu[k, :j] -= r * mu[j, :j] - mu[k, j] -= r - - # Lovász condition: ||b*_k||^2 >= (delta - mu[k,k-1]^2) * ||b*_{k-1}||^2 - if B_sq[k-1] > 1e-30: - lovasz_ok = B_sq[k] >= (delta - mu[k, k-1]**2) * B_sq[k-1] - else: - lovasz_ok = True - - if lovasz_ok: - k += 1 - else: - # Swap b_k and b_{k-1} - basis[[k, k-1]] = basis[[k-1, k]] - - # Incremental GSO update after swap (standard LLL formulas) - mu_bar = mu[k, k-1].clone() - B = B_sq[k] + mu_bar ** 2 * B_sq[k-1] - - if B > 1e-30: - old_Bk = B_sq[k].clone() - mu[k, k-1] = mu_bar * B_sq[k-1] / B - B_sq[k] = B_sq[k-1] * old_Bk / B - B_sq[k-1] = B - - # Swap mu rows for j < k-1 - if k >= 2: - temp = mu[k-1, :k-1].clone() - mu[k-1, :k-1] = mu[k, :k-1] - mu[k, :k-1] = temp - - # Update mu for all rows i > k (vectorized) - if k + 1 < n: - t = mu[k+1:, k].clone() - mu[k+1:, k] = mu[k+1:, k-1] - mu_bar * t - mu[k+1:, k-1] = t + mu[k, k-1] * mu[k+1:, k] - - k = max(k - 1, 1) - - return basis - - def reduce(self, basis: torch.Tensor, rounds: int = 5) -> torch.Tensor: - """LLL + BKZ reduction. - - Args: - basis: Lattice basis [dim, dim], float64. - rounds: Number of BKZ tours after initial LLL. - - Returns: - Reduced basis [dim, dim]. 
- """ - n = basis.shape[0] - basis = basis.clone().to(dtype=torch.float64, device=self.device) - - if self.use_fpylll: - print(f" Using fpylll backend (LLL + BKZ-{self.block_dim})") - basis = self._fpylll_reduce(basis, rounds) - metrics = self._compute_metrics(basis) - print(f" fpylll done: shortest={metrics['shortest']:.4f}, " - f"log_defect={metrics['log_defect']:.2f}, " - f"rhf={metrics['rhf']:.6f}") - return basis - - # Phase 1: Full LLL reduction - basis = self._lll_reduce(basis) - metrics = self._compute_metrics(basis) - print(f" LLL done: shortest={metrics['shortest']:.4f}, " - f"log_defect={metrics['log_defect']:.2f}, " - f"rhf={metrics['rhf']:.6f}") - - # Phase 2: BKZ tours with block enumeration - for rnd in range(rounds): - improved = False - - for start in range(0, n - 1, max(1, self.block_dim // 2)): - end = min(start + self.block_dim, n) - block_size = end - start - if block_size < 2: - continue - - # Extract block and enumerate - block = basis[start:end, :].clone() - enumerator = BladeEnumerator(block, device=self.device) - short_vec, short_norm = enumerator.enumerate() - - if short_vec is not None: - current_norm = torch.norm(basis[start]).item() - if short_norm < current_norm * 0.999: - basis[start] = short_vec - # Local LLL re-reduction around insertion point - lo = max(0, start - 2) - hi = min(n, end + 2) - local = basis[lo:hi].clone() - local = self._lll_reduce(local) - basis[lo:hi] = local - improved = True - - metrics = self._compute_metrics(basis) - print(f" BKZ round {rnd+1}/{rounds}: " - f"shortest={metrics['shortest']:.4f}, " - f"log_defect={metrics['log_defect']:.2f}, " - f"rhf={metrics['rhf']:.6f}" - f"{' (improved)' if improved else ''}") - - if not improved: - print(f" BKZ converged at round {rnd+1}.") - break - - return basis - - def _compute_metrics(self, basis: torch.Tensor) -> dict: - """Compute reduction quality metrics with numerically safe LAPACK calls. 
- - Uses QR decomposition (Householder reflections) instead of slogdet - (LU) to avoid DLASCL errors on ill-conditioned lattices at high - dimensions. log|det(B)| = sum(log|R_ii|) + sum(log(row_norms)). - """ - norms = torch.norm(basis, dim=1) - shortest = norms.min().item() - log_norms_sum = torch.log(norms.clamp(min=1e-100)).sum().item() - - # QR-based log|det|: det(B) = det(B/norms) * prod(norms) - # QR of B_scaled gives |det(B_scaled)| = prod|R_ii| - try: - norms_safe = norms.clamp(min=1e-100) - basis_scaled = basis / norms_safe.unsqueeze(1) - _Q, R = torch.linalg.qr(basis_scaled) - diag_abs = torch.abs(torch.diag(R)).clamp(min=1e-100) - log_det_scaled = torch.log(diag_abs).sum().item() - log_det = log_det_scaled + torch.log(norms_safe).sum().item() - except Exception: - log_det = log_norms_sum # Fallback: assume orthogonal (defect ≈ 0) - - log_defect = log_norms_sum - log_det - n = basis.shape[0] - det_root = math.exp(log_det / n) if log_det > -1e30 else 1e-30 - rhf = (shortest / max(det_root, 1e-100)) ** (1.0 / n) - return {'log_defect': log_defect, 'rhf': rhf, 'shortest': shortest, - 'log_det': log_det} - - -# --------------------------------------------------------------------------- -# 3. BladeEnumerator — GA-Enhanced Enumeration -# --------------------------------------------------------------------------- - -class BladeEnumerator: - """Tree-based enumeration with blade rejection pruning. - - Precomputes blade hierarchy B_k = b_1 ∧ ... ∧ b_k and uses - blade_reject norm for pruning (GA Schnorr-Euchner analogue). 
- """ - - def __init__(self, block_basis: torch.Tensor, search_range: int = 3, - device: str = 'cpu'): - self.device = device - self.search_range = search_range - self.block_dim = block_basis.shape[0] - self.vec_dim = block_basis.shape[1] - self.block_basis = block_basis.to(dtype=torch.float64, device=device) - - # Cap algebra dimension for tractability (Cl(p,0) with p <= 10) - self.alg_dim = min(self.vec_dim, 10) - self.algebra = CliffordAlgebra(p=self.alg_dim, q=0, device=device) - - # Embed basis vectors as multivectors (using first alg_dim components) - self.mv_basis = [] - for i in range(self.block_dim): - v = self.block_basis[i, :self.alg_dim] - self.mv_basis.append( - self.algebra.embed_vector(v.unsqueeze(0)).squeeze(0) - ) - - # Blade hierarchy - self.blades = self._precompute_blades() - - # Gaussian heuristic bound (QR-based log|det| to avoid DLASCL errors) - sq = min(self.block_dim, self.vec_dim) - sub = self.block_basis[:, :sq] - try: - row_norms = torch.norm(sub, dim=1).clamp(min=1e-100) - sub_scaled = sub / row_norms.unsqueeze(1) - _Q, R = torch.linalg.qr(sub_scaled) - diag_abs = torch.abs(torch.diag(R)).clamp(min=1e-100) - log_det_scaled = torch.log(diag_abs).sum().item() - log_det = log_det_scaled + torch.log(row_norms).sum().item() - det_root = math.exp(log_det / sq) - self.bound = math.sqrt(sq / (2 * math.pi * math.e)) * det_root * 1.05 - except Exception: - self.bound = float('inf') - - def _precompute_blades(self) -> dict: - blades = {} - if self.block_dim == 0: - return blades - current = self.mv_basis[0] - blades[1] = current - for k in range(2, min(self.alg_dim, self.block_dim) + 1): - current = self.algebra.wedge( - current.unsqueeze(0), self.mv_basis[k-1].unsqueeze(0) - ).squeeze(0) - if induced_norm(self.algebra, current.unsqueeze(0)).item() < 1e-12: - break - blades[k] = current - return blades - - def enumerate(self) -> tuple: - """Enumerate with blade rejection pruning. - - Returns: - (best_vector, best_norm) or (None, inf). 
- """ - best_vec = None - best_norm = float('inf') - sr = self.search_range - - def search(depth, current_vec, current_mv): - nonlocal best_vec, best_norm - - # Blade rejection pruning - k = self.block_dim - depth - if 0 < k <= len(self.blades) and k in self.blades: - B_k = self.blades[k] - q_reject = self.algebra.blade_reject( - current_mv.unsqueeze(0), B_k.unsqueeze(0) - ) - if induced_norm(self.algebra, q_reject).item() > min(self.bound, best_norm): - return - - if depth == 0: - norm = torch.norm(current_vec).item() - if 0 < norm < best_norm: - best_norm = norm - best_vec = current_vec.clone() - return - - idx = depth - 1 - for z in range(-sr, sr + 1): - next_vec = current_vec + z * self.block_basis[idx] - next_mv = (current_mv + z * self.mv_basis[idx] - if idx < len(self.mv_basis) else current_mv) - search(depth - 1, next_vec, next_mv) - - zero_vec = torch.zeros(self.vec_dim, dtype=torch.float64, device=self.device) - zero_mv = torch.zeros(self.algebra.dim, dtype=torch.float64, device=self.device) - search(self.block_dim, zero_vec, zero_mv) - return best_vec, best_norm - - -# --------------------------------------------------------------------------- -# 4. RotorSearchLayer — Neural SVP with safe Neutralizer -# --------------------------------------------------------------------------- - -class RotorSearchLayer(nn.Module): - """Neural short vector search: MultiRotor + GeometricNeutralizer. - - embed_vector → expand channels → CliffordLayerNorm → - MultiRotorLayer(K=4) → GeometricNeutralizer (safe) → extract grade-1 - → PhysicalNormBreaker (clamp oversized outputs) - - ~126 parameters. Pre-normalizes multivectors before neutralizer - to prevent singular covariance (DLASCL crash). Physical norm breaker - prevents large vectors from being mistaken for correct short vectors. 
- """ - - def __init__(self, block_dim: int = 8, channels: int = 2, - num_rotors: int = 4, target_norm: float = None, - norm_multiplier: float = 2.0, device: str = 'cpu'): - super().__init__() - self.block_dim = block_dim - self.channels = channels - self.device = device - self.target_norm = target_norm - self.norm_multiplier = norm_multiplier - - alg_dim = min(block_dim, 10) - self.algebra = CliffordAlgebra(p=alg_dim, q=0, device=device) - - self.norm_layer = CliffordLayerNorm( - self.algebra, channels=channels, recover=False) - self.multi_rotor = MultiRotorLayer( - self.algebra, channels=channels, num_rotors=num_rotors) - self.neutralizer = GeometricNeutralizer( - self.algebra, channels=channels, momentum=0.3) - - def forward(self, blocks: torch.Tensor): - """Process blocks through GA pipeline. - - Args: - blocks: [B, L, block_dim]. - - Returns: - guided_blocks: [B, L, block_dim]. - """ - B, L, D = blocks.shape - self.algebra.ensure_device(blocks.device) - - d = min(D, self.algebra.n) - x_mv = self.algebra.embed_vector( - blocks[..., :d].reshape(-1, d) - ) # [B*L, 2^n] - - # Expand to channels - x_mv = x_mv.unsqueeze(1).expand(-1, self.channels, -1).contiguous() - - # CliffordLayerNorm → MultiRotor - x_mv = self.norm_layer(x_mv) - x_mv = self.multi_rotor(x_mv) - - # Safe Neutralizer: pre-normalize to unit multivectors, then restore scale. - # This prevents ill-conditioned covariance in the EMA statistics. 
- mv_norms = x_mv.norm(dim=-1, keepdim=True).clamp(min=1e-8) - x_normed = x_mv / mv_norms - try: - x_normed = self.neutralizer(x_normed) - except Exception: - pass # Skip neutralization if covariance is still singular - x_mv = x_normed * mv_norms - - # Extract grade-1 - g1_idx = [1 << i for i in range(self.algebra.n)] - guided = x_mv[..., g1_idx].mean(dim=1) # [B*L, n] - - if d < D: - pad = torch.zeros(guided.shape[0], D - d, - dtype=guided.dtype, device=guided.device) - guided = torch.cat([guided, pad], dim=-1) - - # Physical Norm Breaker: clamp outputs exceeding norm_multiplier * target_norm. - # Pattern from optimizers/riemannian.py: scale = clamp(norm/max, min=1.0) - if self.target_norm is not None: - max_allowed = self.target_norm * self.norm_multiplier - out_norms = torch.norm(guided, dim=-1, keepdim=True).clamp(min=1e-12) - scale = torch.clamp(out_norms / max_allowed, min=1.0) - guided = guided / scale - - return guided.view(B, L, D) - - -# --------------------------------------------------------------------------- -# 5. KyberSolver — Orchestrator with hard constraints + diet loss -# --------------------------------------------------------------------------- - -class KyberSolver: - """End-to-end Kyber MLWE solver. - - Phase 1: Generate MLWE instance + full Kannan embedding - Phase 2: LLL + BKZ reduction - Phase 3: Neural search (hard constraint: last coeff ±1, diet loss) - Phase 4: Extract (s, e) from short vector → verify - - Supports CPU and CUDA. All arithmetic in float64. 
- """ - - def __init__(self, n: int = 256, k: int = 2, q: int = 3329, - eta1: int = 3, eta2: int = 2, eta: int = None, - block_dim: int = 8, bkz_rounds: int = 5, - search_steps: int = 300, hunts: int = 10, - seed: int = 42, device: str = 'cpu', - use_fpylll: bool = True): - if eta is not None: - eta1, eta2 = eta, eta - self.n, self.k, self.q, self.eta1, self.eta2 = n, k, q, eta1, eta2 - self.block_dim = block_dim - self.bkz_rounds = bkz_rounds - self.search_steps = search_steps - self.hunts = hunts - self.seed = seed - self.device = device - self.use_fpylll = use_fpylll - self.batch_size = 2 - self.stride = block_dim - - def solve(self): - print(f"\n{'='*60}") - print(f" Kyber MLWE Solver via Kannan Embedding + GA") - print(f" n={self.n}, k={self.k}, q={self.q}, eta1={self.eta1}, eta2={self.eta2}") - print(f" device={self.device}") - print(f"{'='*60}") - start_total = time.time() - - # Phase 1: Generate instance (full embedding, no truncation) - print(f"\n--- Phase 1: MLWE Instance Generation ---") - instance = KyberInstance( - n=self.n, k=self.k, q=self.q, eta1=self.eta1, eta2=self.eta2, - seed=self.seed, device=self.device - ) - basis = instance.basis.clone() - wd = basis.shape[0] # = 2*m + 1 - m = instance.m - - # Phase 2: LLL + BKZ reduction - print(f"\n--- Phase 2: Lattice Reduction ---") - reducer = GALatticeReducer(block_dim=self.block_dim, device=self.device, - use_fpylll=self.use_fpylll) - basis = reducer.reduce(basis, rounds=self.bkz_rounds) - metrics = reducer._compute_metrics(basis) - print(f" Post-reduction: shortest={metrics['shortest']:.4f}, " - f"rhf={metrics['rhf']:.6f}") - - # Phase 3: Neural search - print(f"\n--- Phase 3: Neural SVP Search ---") - - # Compute physical norm bound: min of instance target norm and Gaussian heuristic - gh_norm = (math.sqrt(wd / (2 * math.pi * math.e)) * - math.exp(metrics['log_det'] / wd)) if metrics['log_det'] > -1e30 else float('inf') - target_norm = min(instance.target_norm, gh_norm) if gh_norm > 0 else 
instance.target_norm - print(f" Physical norm bound: target={target_norm:.4f} " - f"(instance={instance.target_norm:.4f}, GH={gh_norm:.4f})") - - model = RotorSearchLayer( - block_dim=self.block_dim, channels=2, num_rotors=4, - target_norm=target_norm, norm_multiplier=2.0, - device=self.device - ).to(device=self.device, dtype=torch.float64) - num_params = sum(p.numel() for p in model.parameters()) - print(f" RotorSearchLayer: {num_params} params") - - num_blocks = max(1, (wd - self.block_dim) // self.stride + 1) - found_vectors = [] - existing_blades = [None] * num_blocks - - def _make_model(): - return RotorSearchLayer( - block_dim=self.block_dim, channels=2, num_rotors=4, - target_norm=target_norm, norm_multiplier=2.0, - device=self.device - ).to(device=self.device, dtype=torch.float64) - - for hunt in range(self.hunts): - print(f"\n Hunt {hunt+1}/{self.hunts}") - t0 = time.time() - - # Reinitialize model if NaN-corrupted from previous hunt - if any(torch.isnan(p).any() for p in model.parameters()): - model = _make_model() - - # Try both signs for the Kannan embedding coefficient - best_result = None - for sign in [1.0, -1.0]: - result = self._neural_search( - basis, model, num_blocks, existing_blades, - attempt=hunt, embed_sign=sign - ) - if result is not None: - _, _, norm = result - if best_result is None or norm < best_result[2]: - best_result = result - - if best_result is not None: - coeffs, vec, norm = best_result - - # Unimodular basis update - new_basis = self._unimodular_update(basis, coeffs, vec) - if new_basis is not None: - old_m = reducer._compute_metrics(basis) - new_m = reducer._compute_metrics(new_basis) - if new_m['log_defect'] < old_m['log_defect'] - 1e-4: - basis = new_basis - basis = reducer.reduce(basis, rounds=1) - print(f" Basis updated: log_defect " - f"{old_m['log_defect']:.2f} → {new_m['log_defect']:.2f}") - - found_vectors.append({'norm': norm, 'vector': vec, 'coeffs': coeffs}) - print(f" Found: norm={norm:.4f} 
(time={time.time()-t0:.1f}s)") - else: - print(f" No improvement.") - - # Phase 4: Solution Extraction - print(f"\n--- Phase 4: Solution Extraction ---") - solution_found = False - - # Sub-phase 4a: Check reduced basis rows directly (cheapest check first) - basis_norms = torch.norm(basis, dim=1) - sorted_basis_idx = torch.argsort(basis_norms) - n_check = min(20, wd) - print(f" Scanning {n_check} shortest basis vectors...") - - for rank, idx in enumerate(sorted_basis_idx[:n_check]): - bvec = basis[idx.item()] - bnorm = basis_norms[idx.item()].item() - - e_cand = bvec[:m].round().long() - s_cand = -bvec[m:2*m].round().long() - - if instance.verify_solution(s_cand, e_cand): - print(f" *** SOLUTION FROM BASIS ROW {idx.item()} " - f"(norm={bnorm:.4f}) ***") - print(f" ||e||={torch.norm(e_cand.double()).item():.2f}, " - f"||s||={torch.norm(s_cand.double()).item():.2f}") - solution_found = True - break - - s_cand2 = bvec[m:2*m].round().long() - if instance.verify_solution(s_cand2, e_cand): - print(f" *** SOLUTION FROM BASIS ROW {idx.item()} " - f"(alt sign, norm={bnorm:.4f}) ***") - solution_found = True - break - - # Sub-phase 4b: Check pairwise combinations of shortest basis vectors - if not solution_found: - top_k = min(10, wd) - top_idx = sorted_basis_idx[:top_k] - print(f" Scanning pairwise combinations of shortest {top_k} vectors...") - - for i in range(top_k): - if solution_found: - break - for j in range(i + 1, top_k): - if solution_found: - break - for ci in [-1, 1]: - if solution_found: - break - for cj in [-1, 1]: - combo = (ci * basis[top_idx[i].item()] + - cj * basis[top_idx[j].item()]) - - e_cand = combo[:m].round().long() - s_cand = -combo[m:2*m].round().long() - - if instance.verify_solution(s_cand, e_cand): - print(f" *** SOLUTION FROM COMBINATION " - f"({ci}*row{top_idx[i].item()}" - f"+{cj}*row{top_idx[j].item()}) " - f"norm={torch.norm(combo).item():.4f} ***") - solution_found = True - break - - s_cand2 = combo[m:2*m].round().long() - if 
instance.verify_solution(s_cand2, e_cand): - print(f" *** SOLUTION FROM COMBINATION " - f"(alt sign) ***") - solution_found = True - break - - # Sub-phase 4c: Check neural search found_vectors - if not solution_found: - found_vectors.sort(key=lambda x: x['norm']) - - for i, fv in enumerate(found_vectors[:10]): - vec = fv['vector'] - print(f" Candidate {i+1}: norm={fv['norm']:.4f}") - - e_cand = vec[:m].round().long() - s_cand = -vec[m:2*m].round().long() - - if instance.verify_solution(s_cand, e_cand): - print(f" *** SOLUTION VERIFIED: As + e ≡ b (mod q) ***") - print(f" ||e||={torch.norm(e_cand.double()).item():.2f}, " - f"||s||={torch.norm(s_cand.double()).item():.2f}") - solution_found = True - break - - s_cand2 = vec[m:2*m].round().long() - if instance.verify_solution(s_cand2, e_cand): - print(f" *** SOLUTION VERIFIED (alt sign): " - f"As + e ≡ b (mod q) ***") - solution_found = True - break - - if not solution_found: - print(f" Solution not extracted from found vectors.") - if found_vectors: - sv = found_vectors[0]['vector'] - res = self._lattice_membership_check(instance.basis, sv) - print(f" Lattice membership residual: {res:.2e}") - - total_time = time.time() - start_total - print(f"\nTotal time: {total_time:.1f}s") - final = reducer._compute_metrics(basis) - print(f"Final: shortest={final['shortest']:.4f}, " - f"log_defect={final['log_defect']:.2f}, rhf={final['rhf']:.6f}") - - def _slice_blocks(self, v: torch.Tensor, num_blocks: int) -> torch.Tensor: - """Slice vector into non-overlapping blocks.""" - blocks = [] - for i in range(num_blocks): - start = i * self.stride - end = start + self.block_dim - if end <= v.shape[1]: - blocks.append(v[:, start:end]) - if not blocks: - return v[:, :self.block_dim].unsqueeze(1) - return torch.stack(blocks, dim=1) - - def _neural_search(self, basis: torch.Tensor, model: nn.Module, - num_blocks: int, existing_blades: list, - attempt: int = 0, embed_sign: float = 1.0) -> tuple: - """Neural search with hard constraints + 
diet loss. - - Hard constraints: - 1. Last coefficient (Kannan row) fixed to ±1 - 2. Periodic snap: near-integer coefficients rounded in-place - 3. Element-wise coefficient clamp (overflow safety) - - Diet loss (3 terms only): - L = norm_sq + 20·guide_loss + 10·ortho_penalty - Guide loss uses cosine dissimilarity (scale-invariant). - """ - wd = basis.shape[0] - dev = self.device - - # Stabilize basis - norms = torch.norm(basis, dim=1, keepdim=True).clamp(min=1e-12) - basis_stab = basis / norms - norms_flat = norms.squeeze(1) - - # Learnable free coefficients (all except last) - c_free = nn.Parameter(torch.zeros( - self.batch_size, wd - 1, dtype=torch.float64, device=dev - )) - # Last coefficient: hard constraint ±1 - c_embed = embed_sign * torch.ones( - self.batch_size, 1, dtype=torch.float64, device=dev - ) - - # Seed initialization - with torch.no_grad(): - sorted_idx = torch.argsort(norms_flat[:-1]) - for b in range(self.batch_size): - seed_idx = sorted_idx[(attempt * self.batch_size + b) % (wd - 1)].item() - c_free[b, seed_idx] = 1.0 - c_free.add_(torch.randn_like(c_free) * 0.001) - - optimizer = torch.optim.Adam(list(model.parameters()) + [c_free], lr=0.002) - alg = model.algebra - - best_norm = float('inf') - best_coeffs = None - best_vec = None - patience_count = 0 - patience = max(self.search_steps // 2, 100) - - for step in range(self.search_steps): - optimizer.zero_grad() - - # Assemble full coefficient vector with hard constraint - c = torch.cat([c_free, c_embed], dim=1) # [B, wd] - - # Reconstruct lattice vector - v = (c * norms_flat.unsqueeze(0)) @ basis_stab # [B, wd] - - # Block processing through GA model - blocks = self._slice_blocks(v, num_blocks) - guided = model(blocks) - - # Reconstruct guided vector - guided_v = torch.zeros_like(v) - counts = torch.zeros_like(v) - for i in range(guided.shape[1]): - start = i * self.stride - end = start + self.block_dim - if end <= wd: - guided_v[:, start:end] += guided[:, i] - counts[:, start:end] += 1 - 
guided_v = guided_v / counts.clamp(min=1) - - # === Diet loss: 3 soft terms only === - norm_sq = (v ** 2).sum(dim=1, keepdim=True) - - # Cosine-based guide loss (scale-invariant): the model suggests - # a direction via guided_v; v should align with it regardless - # of their relative magnitudes. - v_n = torch.norm(v, dim=1, keepdim=True).clamp(min=1e-8) - gv_n = torch.norm(guided_v, dim=1, keepdim=True).clamp(min=1e-8) - guide_loss = 1.0 - (v * guided_v).sum(dim=1, keepdim=True) / (v_n * gv_n) - - # Ortho penalty via blade rejection - ortho_penalty = torch.zeros( - self.batch_size, 1, dtype=torch.float64, device=dev) - for i in range(min(guided.shape[1], len(existing_blades))): - if existing_blades[i] is not None: - start = i * self.stride - end = start + self.block_dim - if end <= wd: - d = min(self.block_dim, alg.n) - mv_v = alg.embed_vector(v[:, start:start+d]) - blade = existing_blades[i].to( - dtype=torch.float64, device=dev - ).expand(self.batch_size, -1) - w = alg.wedge(blade, mv_v) - ortho_penalty += 1.0 / (induced_norm(alg, w) + 1e-4) - - loss = (norm_sq + 20.0 * guide_loss + 10.0 * ortho_penalty).mean() - - if torch.isnan(loss) or torch.isinf(loss): - break - - loss.backward() - torch.nn.utils.clip_grad_norm_( - list(model.parameters()) + [c_free], max_norm=1.0) - optimizer.step() - - # Clip model bivector parameters (prevents exp() overflow, - # mirrors RiemannianAdam's max_bivector_norm=10.0) - with torch.no_grad(): - for p in model.parameters(): - p_norm = torch.norm(p) - if p_norm > 10.0: - p.mul_(10.0 / p_norm) - - # === Hard constraint 3: element-wise coefficient clamp === - # Prevents numerical overflow while preserving the Kannan - # structure (c_embed stays fixed at ±1, c_free keeps its - # per-element direction unlike uniform norm scaling). 
- with torch.no_grad(): - c_free.data.clamp_(-100, 100) - - # === Hard constraint 2: periodic snap === - if step > 0 and step % 50 == 0: - with torch.no_grad(): - residual = (c_free - c_free.round()).abs() - snap_mask = residual < 0.15 - c_free.data[snap_mask] = c_free.data[snap_mask].round() - - # Track best rounded vector - improved = False - with torch.no_grad(): - c_full = torch.cat([c_free.round(), c_embed], dim=1) - for b in range(self.batch_size): - if torch.all(c_full[b] == 0): - continue - v_round = (c_full[b] * norms_flat) @ basis_stab - n_round = torch.norm(v_round).item() - if n_round < best_norm * 0.99999 and n_round > 0: - best_norm = n_round - best_coeffs = c_full[b].clone() - best_vec = v_round.clone() - improved = True - - if improved: - patience_count = 0 - else: - patience_count += 1 - if patience_count >= patience: - break - - if step % 100 == 0: - print(f" Step {step:3d} | Loss: {loss.item():.2e} " - f"| Best: {best_norm:.4f}") - - # Update existing blades for independence tracking - if best_vec is not None: - blks = self._slice_blocks(best_vec.unsqueeze(0), num_blocks).squeeze(0) - for i in range(min(blks.shape[0], len(existing_blades))): - d = min(blks.shape[1], alg.n) - mv_v = alg.embed_vector(blks[i, :d].unsqueeze(0)) - if existing_blades[i] is None: - existing_blades[i] = mv_v.squeeze(0) - else: - nb = alg.wedge(existing_blades[i].unsqueeze(0), mv_v).squeeze(0) - if induced_norm(alg, nb.unsqueeze(0)).item() > 1e-4: - existing_blades[i] = nb - - if best_coeffs is not None: - return best_coeffs, best_vec, best_norm - return None - - def _unimodular_update(self, basis: torch.Tensor, coeffs: torch.Tensor, - new_vec: torch.Tensor) -> torch.Tensor: - """Replace pivot basis row with new short vector + size reduce.""" - c = coeffs.round().long() - if torch.all(c == 0): - return None - - pivot_idx = torch.argmax(torch.abs(c)).item() - if c[pivot_idx].abs() == 0: - return None - - new_basis = basis.clone() - new_basis[pivot_idx] = new_vec - - # 
Size-reduce other rows against the new vector - n = new_basis.shape[0] - bi_sq = (new_basis[pivot_idx] ** 2).sum() - if bi_sq > 1e-30: - for j in range(n): - if j == pivot_idx: - continue - mu = (new_basis[j] * new_basis[pivot_idx]).sum() / bi_sq - r = torch.round(mu) - if r.abs() > 0: - new_basis[j] = new_basis[j] - r * new_basis[pivot_idx] - - return new_basis - - def _lattice_membership_check(self, original_basis: torch.Tensor, - vec: torch.Tensor) -> float: - """Check approximate lattice membership using QR decomposition. - - Uses pre-conditioned QR solve instead of lstsq to avoid DLASCL - errors on ill-conditioned lattice bases. - """ - try: - B_T = original_basis.T - col_norms = torch.norm(B_T, dim=0).clamp(min=1e-100) - B_scaled = B_T / col_norms.unsqueeze(0) - Q, R = torch.linalg.qr(B_scaled) - rhs = Q.T @ vec - coeffs_scaled = torch.linalg.solve_triangular( - R, rhs.unsqueeze(1), upper=True).squeeze(1) - coeffs = coeffs_scaled / col_norms - residual = torch.norm(B_T @ coeffs - vec).item() - int_residual = torch.norm(coeffs - coeffs.round()).item() - return residual + int_residual - except Exception: - return float('inf') - - -# --------------------------------------------------------------------------- -# CLI -# --------------------------------------------------------------------------- - -def main(): - parser = argparse.ArgumentParser( - description="Kyber MLWE Lattice Solver via Kannan Embedding + GA" - ) - parser.add_argument('--n', type=int, default=256, - help='Polynomial degree (default: 256)') - parser.add_argument('--k', type=int, default=2, - help='Module rank (default: 2)') - parser.add_argument('--q', type=int, default=3329, - help='Modulus (default: 3329)') - parser.add_argument('--eta1', type=int, default=3, - help='CBD parameter for secret (Kyber-512: 3)') - parser.add_argument('--eta2', type=int, default=2, - help='CBD parameter for error (Kyber-512: 2)') - parser.add_argument('--eta', type=int, default=None, - help='Set both eta1 and eta2 
(overrides individual values)') - parser.add_argument('--block-dim', type=int, default=8, - help='BKZ block dimension (default: 8)') - parser.add_argument('--bkz-rounds', type=int, default=5, - help='BKZ reduction rounds (default: 5)') - parser.add_argument('--search-steps', type=int, default=300, - help='Neural search GD steps (default: 300)') - parser.add_argument('--hunts', type=int, default=10, - help='Number of search hunts (default: 10)') - parser.add_argument('--seed', type=int, default=42, - help='Random seed (default: 42)') - parser.add_argument('--device', type=str, default='cpu', - help='Device: cpu, cuda, or auto (default: cpu)') - parser.add_argument('--no-fpylll', action='store_true', - help='Disable fpylll backend, use PyTorch LLL/BKZ') - args = parser.parse_args() - - args.device = resolve_device(args.device) - print(f"Using device: {args.device}") - - eta1 = args.eta if args.eta is not None else args.eta1 - eta2 = args.eta if args.eta is not None else args.eta2 - - solver = KyberSolver( - n=args.n, k=args.k, q=args.q, eta1=eta1, eta2=eta2, - block_dim=args.block_dim, bkz_rounds=args.bkz_rounds, - search_steps=args.search_steps, hunts=args.hunts, - seed=args.seed, device=args.device, - use_fpylll=not args.no_fpylll - ) - solver.solve() - - -if __name__ == '__main__': - main() From bb60ab4e38617e0e6a1d8922d5dcba596fb9a644 Mon Sep 17 00:00:00 2001 From: Concode0 Date: Thu, 19 Mar 2026 20:26:54 +0900 Subject: [PATCH 04/16] fix: remove chore coments --- scripts/analyze_gtm.py | 15 ++++++++++++++- tasks/gtm.py | 6 ------ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/scripts/analyze_gtm.py b/scripts/analyze_gtm.py index 904ba19..59b75b8 100644 --- a/scripts/analyze_gtm.py +++ b/scripts/analyze_gtm.py @@ -42,11 +42,24 @@ def main(): # Print static analysis (no data needed) print() print(analyzer.format_instruction_report()) + + # Temperature from checkpoint reflects the buffer state at save time. 
+ # During training, temperature is set externally by the annealing schedule + # (tau_start -> tau_end over phases 2+3), so the saved value shows the + # temperature at the epoch when the best checkpoint was saved. temp_info = analyzer.analyze_temperature() - print('=== Gumbel Temperature ===') + print('=== Gumbel Temperature (at checkpoint save) ===') for i, t in enumerate(temp_info['temperatures']): sharp = '*' if temp_info['is_sharp'][i] else '' print(f' Step {i}: tau={t:.4f} {sharp}') + + # Show annealing config if available + checkpoint = torch.load(args.checkpoint, map_location='cpu', weights_only=False) + cfg = checkpoint.get('config', {}) + tau_start = cfg.get('training', {}).get('tau_start', None) + tau_end = cfg.get('training', {}).get('tau_end', None) + if tau_start is not None: + print(f' Schedule: {tau_start} -> {tau_end} (linear over phases 2+3)') print() # Load validation data diff --git a/tasks/gtm.py b/tasks/gtm.py index 569af79..bf47f9f 100644 --- a/tasks/gtm.py +++ b/tasks/gtm.py @@ -302,12 +302,6 @@ def run(self): else: self._current_act_weight = 0.0 - # Gumbel temperature annealing: - # Phase 1 (warmup): hold at tau_start - # Phase 2 (circuit): anneal tau_start -> tau_act_restart - # Phase 3 (ACT): warm restart at tau_act_restart, anneal -> tau_end - # Warm restart needed because ACT activates steps[num_steps:max_steps] - # which have untrained weights and need exploration room. 
if phase == 1: tau = self.tau_start elif phase == 2: From adba6b859aee5bf3f24533512147ebf039a14b86 Mon Sep 17 00:00:00 2001 From: Concode0 Date: Sat, 21 Mar 2026 11:15:13 +0900 Subject: [PATCH 05/16] fix: core nan fix caused by sqrt --- core/metric.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/metric.py b/core/metric.py index 8d187db..d4353e0 100644 --- a/core/metric.py +++ b/core/metric.py @@ -94,8 +94,7 @@ def induced_norm(algebra: CliffordAlgebra, A: torch.Tensor) -> torch.Tensor: sq_norm = inner_product(algebra, A, A_rev) # In mixed signatures, sq_norm can be negative. - # We return sqrt(|sq_norm|) - return torch.sqrt(torch.abs(sq_norm)) + return torch.sqrt(torch.abs(sq_norm).clamp(min=1e-12)) def geometric_distance(algebra: CliffordAlgebra, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: """Computes geometric distance. @@ -238,7 +237,8 @@ def hermitian_norm(algebra: CliffordAlgebra, A: torch.Tensor) -> torch.Tensor: Norm [..., 1]. Always >= 0. """ sq = hermitian_inner_product(algebra, A, A) - return torch.sqrt(torch.abs(sq)) + # Clamp before sqrt to avoid inf gradient when sq ≈ 0 (e.g. null multivectors in PGA). + return torch.sqrt(torch.abs(sq).clamp(min=1e-12)) def hermitian_distance(algebra: CliffordAlgebra, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: @@ -277,7 +277,7 @@ def hermitian_angle(algebra: CliffordAlgebra, A: torch.Tensor, B: torch.Tensor) sq_b = (signs * B * B).sum(dim=-1, keepdim=True) # Use sqrt(sq_a * sq_b) instead of sqrt(sq_a)*sqrt(sq_b) to avoid # float32 precision loss from two separate sqrt operations. 
- denom = torch.sqrt(torch.abs(sq_a) * torch.abs(sq_b)).clamp(min=1e-6) + denom = torch.sqrt((torch.abs(sq_a) * torch.abs(sq_b)).clamp(min=1e-12)).clamp(min=1e-6) cos_theta = ip / denom cos_theta = torch.clamp(cos_theta, -1.0, 1.0) return torch.acos(cos_theta) From e87e3eba7304a8812ee56065d67a14b4a21b1438 Mon Sep 17 00:00:00 2001 From: Concode0 Date: Sat, 21 Mar 2026 11:40:46 +0900 Subject: [PATCH 06/16] feat: fundamental change about gtm, previous version is too rigid and rigor cause focus on the turing machine --- models/gtm/__init__.py | 32 +-- models/gtm/action_engine.py | 202 +++++++++++++++++ models/gtm/adaptive_halt.py | 92 +++----- models/gtm/analysis.py | 278 +++++------------------ models/gtm/control_plane.py | 147 ------------- models/gtm/cpu.py | 135 ------------ models/gtm/gtm_net.py | 196 ++++++++--------- models/gtm/heads.py | 14 +- models/gtm/info_geometry.py | 111 ++++++++++ models/gtm/log_manifold.py | 71 ++++++ models/gtm/search_plane.py | 247 +++++++++++++++++++++ models/gtm/superposition.py | 107 --------- models/gtm/turing_step.py | 154 ------------- models/gtm/turing_vm.py | 151 ------------- models/gtm/world_model.py | 426 ++++++++++++++++++++++++++++++++++++ 15 files changed, 1260 insertions(+), 1103 deletions(-) create mode 100644 models/gtm/action_engine.py delete mode 100644 models/gtm/control_plane.py delete mode 100644 models/gtm/cpu.py create mode 100644 models/gtm/info_geometry.py create mode 100644 models/gtm/log_manifold.py create mode 100644 models/gtm/search_plane.py delete mode 100644 models/gtm/superposition.py delete mode 100644 models/gtm/turing_step.py delete mode 100644 models/gtm/turing_vm.py create mode 100644 models/gtm/world_model.py diff --git a/models/gtm/__init__.py b/models/gtm/__init__.py index 7f420e7..826ea56 100644 --- a/models/gtm/__init__.py +++ b/models/gtm/__init__.py @@ -5,15 +5,15 @@ # you may not use this file except in compliance with the License. 
# -"""Geometric Turing Machine (GTM) package — ARC-AGI v4.""" +"""Geometric Turing Machine (GTM) package — v5 World Model + Search Plane.""" from .grid_codec import GridCodec -from .cpu import GeometricCPU, ColorUnit -from .control_plane import ControlPlane -from .superposition import GeometricSuperpositionSearch -from .turing_step import TuringStep -from .adaptive_halt import AdaptiveHalt -from .turing_vm import TuringVM +from .action_engine import ActionEngine, DiscreteActionHead +from .log_manifold import LogManifoldProjector +from .info_geometry import FIMEvaluator +from .search_plane import AlgebraicProjection, AlgebraicLift, SearchPlane +from .adaptive_halt import FIMAdaptiveHalt +from .world_model import CellAttention, WorldModelStep, WorldModel from .heads import GridReconstructionHead from .rule_memory import RuleAggregator from .gtm_net import GTMNet @@ -21,13 +21,17 @@ __all__ = [ "GridCodec", - "GeometricCPU", - "ColorUnit", - "ControlPlane", - "GeometricSuperpositionSearch", - "TuringStep", - "AdaptiveHalt", - "TuringVM", + "ActionEngine", + "DiscreteActionHead", + "LogManifoldProjector", + "FIMEvaluator", + "AlgebraicProjection", + "AlgebraicLift", + "SearchPlane", + "FIMAdaptiveHalt", + "CellAttention", + "WorldModelStep", + "WorldModel", "GridReconstructionHead", "RuleAggregator", "GTMNet", diff --git a/models/gtm/action_engine.py b/models/gtm/action_engine.py new file mode 100644 index 0000000..59abe23 --- /dev/null +++ b/models/gtm/action_engine.py @@ -0,0 +1,202 @@ +# Versor: Universal Geometric Algebra Neural Network +# Copyright (C) 2026 Eunkyum Kim +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# + +"""ActionEngine: generalized PGA Cl(3,0,1) action proposer. + +Combines continuous motor transforms (rotation + translation via sandwich +product) with discrete color operations (DiscreteActionHead). 
A learnable +per-component gate controls the blend between paths. Proposes K candidate +states for all hypotheses in parallel. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from core.algebra import CliffordAlgebra + + +class DiscreteActionHead(nn.Module): + """Discrete color update conditioned on spatial context. + + Predicts a 10-class color distribution and converts to grade-0 via + differentiable soft-argmax. This gives much richer color control than + a single additive delta, since the model can target any of 10 colors + in one shot. + + Operates on grade-0 (motor-invariant) using grade-1 spatial features + as context. Bypasses the motor sandwich invariance limitation by + directly setting the scalar component. + """ + + # Grade-1 spatial indices in Cl(3,0,1): e0(1), e1(2), e2(4), e3(8) + _SPATIAL_IDX = [1, 2, 4, 8] + _NUM_COLORS = 10 + + def __init__(self): + super().__init__() + self.spatial_proj = nn.Linear(4, 32) + self.color_mlp = nn.Sequential( + nn.Linear(32 + 16, 64), + nn.ReLU(), + nn.Linear(64, self._NUM_COLORS), + ) + # Anchors: 10 evenly-spaced grade-0 values for each color + self.register_buffer( + '_color_anchors', + torch.arange(self._NUM_COLORS, dtype=torch.float32) / (self._NUM_COLORS - 1), + ) + + def forward(self, state: torch.Tensor, + instr: torch.Tensor) -> torch.Tensor: + """Apply discrete color update via soft color selection. + + Args: + state: Cell states [L, N, 16]. + instr: Instructions [L, 16]. + + Returns: + Updated state [L, N, 16] with modified grade-0. 
+ """ + L, N, D = state.shape + spatial = state[:, :, self._SPATIAL_IDX] # [L, N, 4] + feat = F.relu(self.spatial_proj(spatial)) # [L, N, 32] + ctx = torch.cat([feat, instr.unsqueeze(1).expand(-1, N, -1)], dim=-1) + color_logits = self.color_mlp(ctx) # [L, N, 10] + + # Soft color selection: differentiable weighted sum of anchor values + color_probs = F.softmax(color_logits, dim=-1) # [L, N, 10] + new_color = (color_probs * self._color_anchors).sum(dim=-1) # [L, N] + + out = state.clone() + out[:, :, 0] = new_color + return out + + +class ActionEngine(nn.Module): + """Generalized action proposer combining continuous motors and discrete updates. + + For each of K hypotheses, modulates instruction templates with rule memory, + applies both continuous motor transforms and discrete color updates, then + blends via a learnable per-component gate. + """ + + def __init__(self, algebra_cpu: CliffordAlgebra, + num_hypotheses: int = 8, + gate_init: float = 0.0): + super().__init__() + assert algebra_cpu.p == 3 and algebra_cpu.r == 1, \ + f"ActionEngine requires Cl(3,0,1), got Cl({algebra_cpu.p},{algebra_cpu.q},{algebra_cpu.r})" + + self.algebra = algebra_cpu + D = algebra_cpu.dim # 16 + K = num_hypotheses + self.num_hypotheses = K + + # Instruction templates + self.instruction_templates = nn.Parameter(torch.randn(K, D) * 0.1) + + # Discrete action head + self.discrete_head = DiscreteActionHead() + + # Per-component gate: sigmoid(gate_init=0) = 0.5 balanced start + # Grade-0 biased toward discrete (motor can't change scalars) + gate_vals = torch.full((D,), gate_init) + gate_vals[0] = -2.0 # sigmoid(-2) ≈ 0.12 → mostly discrete for color + self.action_gate = nn.Parameter(gate_vals) + + # Rule memory modulation + self.rule_proj = nn.Linear(D, K * D) + + def _motor_transform(self, state: torch.Tensor, + instruction: torch.Tensor) -> torch.Tensor: + """Apply motor sandwich product: R x R~ where R = exp(-B/2). + + Pure geometric transform — no color remapping. 
The discrete head + handles color updates separately. + + Args: + state: [L, N, D] + instruction: [L, D] + + Returns: + Transformed state [L, N, D]. + """ + L, N, D = state.shape + self.algebra.ensure_device(state.device) + + bv = self.algebra.grade_projection(instruction, 2) + M = self.algebra.exp(-0.5 * bv) + M_rev = self.algebra.reverse(M) + + M_exp = M.unsqueeze(1).expand(L, N, D).reshape(L * N, D) + M_rev_exp = M_rev.unsqueeze(1).expand(L, N, D).reshape(L * N, D) + state_flat = state.reshape(L * N, 1, D) + + out = self.algebra.sandwich_product( + M_exp, state_flat, M_rev_exp, + ).reshape(L, N, D) + + return out + + def propose_all(self, state: torch.Tensor, + hypotheses: torch.Tensor, + rule_memory: torch.Tensor = None) -> torch.Tensor: + """Propose K candidate states for all hypotheses. + + Args: + state: Attended cell states [B, N, D]. + hypotheses: Current hypothesis states [B, K, 4] (reserved for + future hypothesis-conditioned actions). + rule_memory: Optional rule slots [B, M, D]. + + Returns: + Candidate states [B, K, N, D]. 
+ """ + B, N, D = state.shape + K = self.num_hypotheses + self.algebra.ensure_device(state.device) + + templates = self.instruction_templates.unsqueeze(0).expand(B, -1, -1) # [B, K, D] + if rule_memory is not None: + rule_mod = self.rule_proj(rule_memory.mean(dim=1)).view(B, K, D) + templates = templates + rule_mod + + # Batch all K hypotheses: [B*K, N, D] + state_exp = state.unsqueeze(1).expand(B, K, N, D).reshape(B * K, N, D) + instr_flat = templates.reshape(B * K, D) + + # Continuous motor transform (no ColorUnit — pure geometric) + continuous = self._motor_transform(state_exp, instr_flat) # [B*K, N, D] + + # Discrete color update + discrete = self.discrete_head(state_exp, instr_flat) # [B*K, N, D] + + # Blend via per-component gate + gate = torch.sigmoid(self.action_gate) # [D] + result = gate * continuous + (1.0 - gate) * discrete + + return result.reshape(B, K, N, D) + + def get_combined_rotor(self, weights: torch.Tensor) -> torch.Tensor: + """Compute weighted combination of rotors from instruction templates. + + Uses weighted bivector averaging in Lie algebra (log-space) then + a single exp map, which is more numerically stable than weighting + post-exp rotors. + + Args: + weights: Hypothesis attention weights [B, K]. + + Returns: + Combined rotor [B, D]. + """ + self.algebra.ensure_device(weights.device) + # Weighted sum of bivectors (Lie algebra averaging) + bv = self.algebra.grade_projection(self.instruction_templates, 2) # [K, D] + combined_bv = torch.einsum('bk,kd->bd', weights, bv) # [B, D] + # Single exp map from the averaged bivector + return self.algebra.exp(-0.5 * combined_bv) diff --git a/models/gtm/adaptive_halt.py b/models/gtm/adaptive_halt.py index be73401..a21e47c 100644 --- a/models/gtm/adaptive_halt.py +++ b/models/gtm/adaptive_halt.py @@ -5,91 +5,59 @@ # you may not use this file except in compliance with the License. # -"""PonderNet-style adaptive computation controller. +"""FIM-based adaptive computation halt. 
-Takes per-step halting probabilities and produces: -- Mixing weights for per-step outputs (geometric distribution) -- KL divergence against a geometric prior for regularization -- Expected number of computation steps per example +Steps that produce more information gain get higher mixing weight. +At inference, halts when weighted information gain drops below a threshold. """ import torch import torch.nn as nn -class AdaptiveHalt(nn.Module): - """PonderNet adaptive computation time controller. +class FIMAdaptiveHalt(nn.Module): + """FIM-based adaptive computation time controller. - Computes mixing weights from per-step halt probabilities using a - geometric distribution: p(halt at t) = lambda_t * prod_{s dict: - """Compute mixing weights and KL loss from per-step halt probabilities. + def forward(self, delta_infos: list, weights_list: list) -> dict: + """Compute mixing weights from per-step FIM information gains. Args: - halt_probs: List of T tensors, each [B] (mean halt prob per example). + delta_infos: List of T tensors, each [B, K] (per-hypothesis info gain). + weights_list: List of T tensors, each [B, K] (hypothesis attention weights). Returns: dict with: - 'weights': [B, T] mixing weights for per-step outputs + 'mixing_weights': [B, T] mixing weights for per-step outputs 'expected_steps': [B] expected computation depth - 'kl_loss': scalar KL divergence against geometric prior """ - T = len(halt_probs) - B = halt_probs[0].shape[0] - device = halt_probs[0].device - eps = self.eps - - # Stack halt probs: [T, B] - lambdas = torch.stack(halt_probs, dim=0) # [T, B] - lambdas = lambdas.clamp(eps, 1.0 - eps) - - # Compute geometric distribution weights - # p(halt at t) = lambda_t * prod_{s 'GTMAnalyzer': - """Load GTMAnalyzer from a BaseTask checkpoint. - - Args: - path: Path to checkpoint saved by BaseTask.save_checkpoint(). - device: Target device. - - Returns: - GTMAnalyzer instance with loaded model. 
- """ + """Load GTMAnalyzer from a BaseTask checkpoint.""" from models.gtm import GTMNet checkpoint = torch.load(path, map_location=device, weights_only=False) cfg = checkpoint['config'] mcfg = cfg.model - act_cfg = mcfg.get('act', {}) - color_cfg = mcfg.get('color_unit', {}) attn_cfg = mcfg.get('attention', {}) + sp_cfg = mcfg.get('search_plane', {}) + lm_cfg = mcfg.get('log_manifold', {}) + ig_cfg = mcfg.get('info_geometry', {}) + ae_cfg = mcfg.get('action_engine', {}) algebra_cpu = CliffordAlgebra(3, 0, 1, device=device) algebra_ctrl = CliffordAlgebra(1, 1, 0, device=device) @@ -75,66 +66,50 @@ def from_checkpoint(path: str, device: str = 'cpu') -> 'GTMAnalyzer': model = GTMNet( algebra_cpu=algebra_cpu, algebra_ctrl=algebra_ctrl, - channels=mcfg.get('channels', 16), - num_steps=mcfg.get('num_steps', 8), - max_steps=mcfg.get('max_steps', 20), - num_hypotheses=mcfg.get('num_hypotheses', 4), - top_k=mcfg.get('top_k', 1), - head_hidden=mcfg.get('head_hidden', 64), - temperature_init=mcfg.get('gumbel_temperature', 1.0), - use_act=act_cfg.get('enabled', True), - lambda_p=act_cfg.get('lambda_p', 0.5), + channels=mcfg.get('channels', 32), + num_steps=mcfg.get('num_steps', 12), + max_steps=mcfg.get('max_steps', 24), + num_hypotheses=mcfg.get('num_hypotheses', 8), + head_hidden=mcfg.get('head_hidden', 128), coord_scale=mcfg.get('coord_scale', 1.0), - K_color=color_cfg.get('K_color', 4), num_attn_heads=attn_cfg.get('num_heads', 4), attn_head_dim=attn_cfg.get('head_dim', 8), num_rule_slots=mcfg.get('num_rule_slots', 8), + num_memory_channels=mcfg.get('num_memory_channels', 4), + weight_share_steps=mcfg.get('weight_share_steps', False), + log_manifold_gate_init=lm_cfg.get('gate_init', -5.0), + evolve_hidden=sp_cfg.get('evolve_hidden', 64), + halt_eps=ig_cfg.get('halt_eps', 0.01), + use_supervised_fim=ig_cfg.get('use_supervised_fim', True), + action_gate_init=ae_cfg.get('gate_init', 0.0), ) model.load_state_dict(checkpoint['model_state_dict'], strict=False) return 
GTMAnalyzer(model, device) - # ------------------------------------------------------------------ - # Static analysis (no data required) - # ------------------------------------------------------------------ - def analyze_instructions(self) -> dict: """Decompose instruction templates into geometric components. - For each of the K trainable instruction templates in Cl(3,0,1): - - Rotation bivectors (e01, e02, e12) -> rotation angle and plane - - Translation bivectors (e03, e13, e23) -> translation vector - - Scalar (grade-0) and pseudoscalar (grade-4) -> color control signals - Returns: - dict with keys per template index: - 'templates_raw': [K, 16] raw parameter values - 'rotation_angles': [K] angle in radians (= 2 * ||B_rot||) - 'rotation_planes': [K, 3] unit bivector (e01, e02, e12) - 'rotation_degrees': [K] angle in degrees - 'translation_vectors': [K, 3] translation (e03, e13, e23) magnitudes - 'translation_norms': [K] translation magnitude - 'color_control': [K, 2] (grade-0, grade-4) values - 'near_identity': [K] bool — True if template ~ no-op + dict with rotation angles/planes, translation vectors, color control. 
""" templates = self._get_templates() # [K, 16] K = templates.shape[0] # Rotation bivectors: e01(idx3), e02(idx5), e12(idx6) - bv_rot = templates[:, [3, 5, 6]] # [K, 3] - bv_rot_norm = bv_rot.norm(dim=-1) # [K] + bv_rot = templates[:, [3, 5, 6]] + bv_rot_norm = bv_rot.norm(dim=-1) rotation_angles = 2.0 * bv_rot_norm safe_norm = bv_rot_norm.clamp(min=1e-8).unsqueeze(-1) rotation_planes = bv_rot / safe_norm # Translation bivectors: e03(idx9), e13(idx10), e23(idx12) - bv_trans = templates[:, [9, 10, 12]] # [K, 3] - trans_norms = bv_trans.norm(dim=-1) # [K] + bv_trans = templates[:, [9, 10, 12]] + trans_norms = bv_trans.norm(dim=-1) # Color control signals - color_control = templates[:, [0, 15]] # [K, 2] (grade-0, grade-4) + color_control = templates[:, [0, 15]] - # Near-identity: small rotation + small translation + small color signal near_identity = ( (rotation_angles < 0.1) & (trans_norms < 0.05) & @@ -152,96 +127,40 @@ def analyze_instructions(self) -> dict: 'near_identity': near_identity, } - def analyze_color_unit(self) -> dict: - """Inspect ColorUnit remapping tables. 
- - Returns: - dict with: - 'remap_tables': [K_color, 10, 10] learned tables - 'table_diag_dominance': [K_color] how close to identity each table is - """ - # Get color unit from first step's search module - color_unit = self.model.vm.steps[0].search.pga_cpu.color_unit - tables = color_unit.remap_tables.detach() # [K_color, 10, 10] - - # Diagonal dominance: fraction of mass on diagonal - diags = torch.diagonal(tables, dim1=-2, dim2=-1) # [K_color, 10] - row_sums = tables.abs().sum(dim=-1) # [K_color, 10] - diag_dominance = (diags.abs() / row_sums.clamp(min=1e-8)).mean(dim=-1) - + def analyze_action_gate(self) -> dict: + """Inspect per-component continuous/discrete blend.""" + # All steps share the same ActionEngine parameters if weight-shared + step0 = self.model.world_model.steps[0] + gate = torch.sigmoid(step0.action_engine.action_gate).detach() return { - 'remap_tables': tables, - 'table_diag_dominance': diag_dominance, + 'gate_values': gate, + 'continuous_dominant': (gate > 0.5).sum().item(), + 'discrete_dominant': (gate <= 0.5).sum().item(), } - def analyze_temperature(self) -> dict: - """Analyze Gumbel-Softmax temperature across all steps. 
- - Returns: - dict with: - 'temperatures': [num_steps] current temperature per step - 'is_sharp': [num_steps] bool — True if tau < 0.5 (near-discrete) - """ - temps = [] - for step in self.model.vm.steps: - tau = step.search._temperature.clamp(0.1, 5.0) - temps.append(tau.item()) - - temps_t = torch.tensor(temps) + def analyze_hypothesis_init(self) -> dict: + """Inspect initial hypothesis positions in Cl(1,1).""" + h_init = self.model.world_model.hypothesis_init.detach() + labels = ['scalar', 'e+', 'e-', 'e+e-'] return { - 'temperatures': temps_t, - 'is_sharp': temps_t < 0.5, + 'hypothesis_init': h_init, + 'component_labels': labels, } - # ------------------------------------------------------------------ - # Dynamic analysis (requires a batch) - # ------------------------------------------------------------------ - def analyze(self, batch: dict) -> dict: - """Full analysis of one batch through both phases. - - Args: - batch: Collated ARC batch from collate_arc. - - Returns: - dict with: - 'instructions': instruction decomposition (static) - 'color_unit': color remapping analysis (static) - 'phase1': {cursors, search_scores, search_weights, - gate_values, halt_probs} - 'phase2': same structure as phase1 - 'cursor_after_phase1': [B, 4] - 'cursor_after_phase2': [B, 4] - 'predictions': [B, N_test] predicted colors - 'targets': [B, N_test] ground truth - 'cell_accuracy': float - 'grid_correct': [B] bool per example - 'test_masks': [B, N_test] validity mask - """ - # Run full forward with trace + """Full analysis of one batch through both phases.""" with torch.no_grad(): result = self._run_forward(batch) logits = result['logits'] preds = logits.argmax(dim=-1) - trace = result['trace'] - - # Split trace into Phase 1 (demo) and Phase 2 (test). - # When ACT is enabled, each VM call produces max_steps entries; - # when disabled, num_steps entries. Both phases use the same mode. 
- vm = self.model.vm - steps_per_phase = vm.max_steps if vm.use_act else vm.num_steps - phase1_trace = {k: v[:steps_per_phase] for k, v in trace.items()} - phase2_trace = {k: v[steps_per_phase:] for k, v in trace.items()} - # Targets test_outputs = batch['test_outputs'].to(self.device) test_masks = batch['test_masks'].to(self.device) B, H_max, W_max = test_outputs.shape targets = test_outputs.reshape(B, H_max * W_max) valid = test_masks.reshape(B, H_max * W_max) - # Metrics matches = (preds == targets) & valid cell_acc = matches.sum().item() / max(valid.sum().item(), 1) @@ -254,11 +173,10 @@ def analyze(self, batch: dict) -> dict: return { 'instructions': self.analyze_instructions(), - 'color_unit': self.analyze_color_unit(), - 'phase1': phase1_trace, - 'phase2': phase2_trace, - 'cursor_after_phase1': phase1_trace['cursors'][-1] if phase1_trace['cursors'] else None, - 'cursor_after_phase2': phase2_trace['cursors'][-1] if phase2_trace['cursors'] else None, + 'action_gate': self.analyze_action_gate(), + 'hypothesis_init': self.analyze_hypothesis_init(), + 'trace': result.get('trace'), + 'world_model_info': result.get('world_model_info'), 'predictions': preds, 'targets': targets, 'cell_accuracy': cell_acc, @@ -267,14 +185,7 @@ def analyze(self, batch: dict) -> dict: } def predict(self, batch: dict) -> dict: - """Lightweight prediction — just logits and accuracy. - - Args: - batch: Collated ARC batch. - - Returns: - dict with 'predictions', 'targets', 'cell_accuracy', 'grid_correct'. 
- """ + """Lightweight prediction — just logits and accuracy.""" with torch.no_grad(): result = self._run_forward(batch) @@ -304,10 +215,6 @@ def predict(self, batch: dict) -> dict: 'grid_correct': grid_correct, } - # ------------------------------------------------------------------ - # Report formatting - # ------------------------------------------------------------------ - def format_instruction_report(self) -> str: """Human-readable instruction template summary.""" info = self.analyze_instructions() @@ -334,89 +241,16 @@ def format_instruction_report(self) -> str: return '\n'.join(lines) - def format_cursor_report(self, report: dict) -> str: - """Human-readable cursor trajectory summary.""" - lines = ['=== Cursor Trajectory ===', ''] - - # Cl(1,1) components: {1, e3, e4, e34} - labels = ['scalar(confidence)', 'e3(hypothesis)', 'e4(depth)', 'e34(phase)'] - - for phase_name, phase_key in [('Phase 1 (Rule Inference)', 'phase1'), - ('Phase 2 (Rule Application)', 'phase2')]: - cursors = report[phase_key]['cursors'] - if not cursors: - continue - lines.append(f'{phase_name}:') - for t, cursor in enumerate(cursors): - vals = cursor[0] # first batch element - components = ' '.join(f'{labels[j]}={vals[j]:+.4f}' for j in range(4)) - lines.append(f' Step {t}: {components}') - lines.append('') - - return '\n'.join(lines) - - def format_search_report(self, report: dict) -> str: - """Human-readable hypothesis selection summary. - - Handles both per-cell weights [B, N, K] (v4.1+) and legacy - global weights [B, K] via dimension check. 
- """ - lines = ['=== Hypothesis Selection ===', ''] - - for phase_name, phase_key in [('Phase 1', 'phase1'), ('Phase 2', 'phase2')]: - weights_list = report[phase_key]['search_weights'] - if not weights_list: - continue - lines.append(f'{phase_name}:') - for t, w in enumerate(weights_list): - w0 = w[0] # first batch element - if w0.dim() == 2: - # Per-cell weights: [N, K] - K = w0.shape[-1] - mean_w = w0.mean(dim=0) # [K] - dominant_per_cell = w0.argmax(dim=-1) # [N] - hist = torch.bincount(dominant_per_cell, minlength=K) - mean_str = ' '.join(f'H{k}={mean_w[k]:.3f}' for k in range(K)) - hist_str = ' '.join(f'H{k}:{hist[k].item()}' for k in range(K)) - lines.append(f' Step {t}: mean=[{mean_str}] cells=[{hist_str}]') - else: - # Legacy global weights: [K] - dominant = w0.argmax().item() - w_str = ' '.join(f'H{k}={w0[k]:.3f}' for k in range(w0.shape[0])) - lines.append(f' Step {t}: [{w_str}] dominant=H{dominant}') - lines.append('') - - return '\n'.join(lines) - - def format_gate_report(self, report: dict) -> str: - """Human-readable write gate summary.""" - lines = ['=== Write Gate Analysis ===', ''] - - for phase_name, phase_key in [('Phase 1', 'phase1'), ('Phase 2', 'phase2')]: - gates = report[phase_key]['gate_values'] - if not gates: - continue - lines.append(f'{phase_name}:') - for t, g in enumerate(gates): - g0 = g[0].squeeze(-1) # [N] for first batch element - lines.append( - f' Step {t}: mean={g0.mean():.3f} ' - f'min={g0.min():.3f} max={g0.max():.3f} ' - f'accept(>0.5)={(g0 > 0.5).float().mean():.1%}' - ) - lines.append('') - - return '\n'.join(lines) - def full_report(self, batch: dict) -> str: """Generate complete human-readable analysis report.""" report = self.analyze(batch) sections = [ self.format_instruction_report(), - self.format_cursor_report(report), - self.format_search_report(report), - self.format_gate_report(report), + '', + '=== Action Gate ===', + f' Continuous-dominant components: {report["action_gate"]["continuous_dominant"]}/16', + f' 
Discrete-dominant components: {report["action_gate"]["discrete_dominant"]}/16', '', '=== Prediction Summary ===', f' Cell accuracy: {report["cell_accuracy"]:.4f}', @@ -424,13 +258,9 @@ def full_report(self, batch: dict) -> str: ] return '\n'.join(sections) - # ------------------------------------------------------------------ - # Internal helpers - # ------------------------------------------------------------------ - def _get_templates(self) -> torch.Tensor: - """Get instruction templates from the first step (shared across steps).""" - return self.model.vm.steps[0].search.instruction_templates.detach() + """Get instruction templates from the first WorldModel step.""" + return self.model.world_model.steps[0].action_engine.instruction_templates.detach() def _run_forward(self, batch: dict) -> dict: """Run model forward with trace, handling device transfer.""" diff --git a/models/gtm/control_plane.py b/models/gtm/control_plane.py deleted file mode 100644 index d827149..0000000 --- a/models/gtm/control_plane.py +++ /dev/null @@ -1,147 +0,0 @@ -# Versor: Universal Geometric Algebra Neural Network -# Copyright (C) 2026 Eunkyum Kim -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# - -"""Cl(1,1) learnable search controller with rook movement. - -4D state: cursor = s*1 + h*e3 + d*e4 + p*e34 - s (scalar): confidence - h (e3): hypothesis index - d (e4): depth - p (e34): phase (bivector) - -Rook movement: only horizontal OR vertical per step. - Horizontal boost: R_h = exp(alpha * e34) — boosts e3 (explore hypotheses) - Vertical boost: R_v = exp(-beta * e34) — boosts e4 (go deeper) - Direction gate (sigmoid MLP): selects pos' = gate*R_h(pos) + (1-gate)*R_v(pos) - -e34 is hyperbolic (sq=+1 in Cl(1,1)), so boosts use cosh/sinh. 
-""" - -import torch -import torch.nn as nn -from core.algebra import CliffordAlgebra -from layers.primitives.base import CliffordModule - - -class ControlPlane(CliffordModule): - """Cl(1,1) learnable search controller.""" - - def __init__(self, algebra_ctrl: CliffordAlgebra, channels: int, - max_hypotheses: int = 4): - assert algebra_ctrl.p == 1 and algebra_ctrl.q == 1, \ - f"ControlPlane requires Cl(1,1), got Cl({algebra_ctrl.p},{algebra_ctrl.q})" - super().__init__(algebra_ctrl) - self.channels = channels - self.max_hypotheses = max_hypotheses - # Cl(1,1) dim = 4: {1, e3, e4, e34} mapped to indices {0, 1, 2, 3} - - # Boost parameters (learnable) - self.alpha_mlp = nn.Sequential( - nn.Linear(channels + 4, 32), - nn.ReLU(), - nn.Linear(32, 1), - ) - self.beta_mlp = nn.Sequential( - nn.Linear(channels + 4, 32), - nn.ReLU(), - nn.Linear(32, 1), - ) - - # Direction gate: horizontal vs vertical - self.direction_gate = nn.Sequential( - nn.Linear(channels + 4, 32), - nn.ReLU(), - nn.Linear(32, 1), - ) - - # Residual correction for boost-invariant components (scalar, e34) - # Sandwich product with e34 bivector only boosts grade-1 (e3, e4); - # scalar and pseudoscalar are algebraically invariant. This MLP - # provides a learned additive update so those components can evolve. - # Outputs 2 values: (delta_scalar, delta_e34), NOT all 4 components, - # to avoid double-counting with the boost on e3/e4. - self.cursor_residual = nn.Sequential( - nn.Linear(channels + 4, 32), - nn.Tanh(), - nn.Linear(32, 2), - ) - # Initialize near-zero so early training is dominated by the boost - nn.init.zeros_(self.cursor_residual[-1].weight) - nn.init.zeros_(self.cursor_residual[-1].bias) - - # Halt signal from cursor - self.halt_mlp = nn.Sequential( - nn.Linear(4, 16), - nn.ReLU(), - nn.Linear(16, 1), - ) - - def step(self, cursor: torch.Tensor, - cpu_context: torch.Tensor) -> tuple: - """Advance the control cursor one step. - - Args: - cursor: Current cursor [B, 4] in Cl(1,1). 
- cpu_context: Summary of CPU state [B, channels] (e.g., mean-pooled grade norms). - - Returns: - Tuple of (new_cursor [B, 4], direction_logit [B, 1], halt_prob [B]). - """ - B = cursor.shape[0] - device = cursor.device - self.algebra.ensure_device(device) - - # Combine cursor with CPU context for MLPs - combined = torch.cat([cursor, cpu_context], dim=-1) # [B, 4 + channels] - - # Compute boost magnitudes - alpha = self.alpha_mlp(combined).squeeze(-1) # [B] - beta = self.beta_mlp(combined).squeeze(-1) # [B] - - # Build bivector for horizontal boost: alpha * e34 - bv_h = torch.zeros(B, 4, device=device, dtype=cursor.dtype) - bv_h[:, 3] = alpha # e34 component - - # Build bivector for vertical boost: -beta * e34 - bv_v = torch.zeros(B, 4, device=device, dtype=cursor.dtype) - bv_v[:, 3] = -beta # e34 component - - # Exponentiate boosts - R_h = self.algebra.exp(-0.5 * bv_h) # [B, 4] - R_v = self.algebra.exp(-0.5 * bv_v) # [B, 4] - - # Apply boosts to cursor via sandwich product - # For Cl(1,1), we can use geometric_product directly (1D batch) - R_h_rev = self.algebra.reverse(R_h) - R_v_rev = self.algebra.reverse(R_v) - - cursor_h = self.algebra.geometric_product( - self.algebra.geometric_product(R_h, cursor), R_h_rev - ) - cursor_v = self.algebra.geometric_product( - self.algebra.geometric_product(R_v, cursor), R_v_rev - ) - - # Direction gate - direction_logit = self.direction_gate(combined) # [B, 1] - gate = torch.sigmoid(direction_logit) # [B, 1] - new_cursor = gate * cursor_h + (1.0 - gate) * cursor_v # [B, 4] - - # Residual correction: only update boost-invariant components - delta = self.cursor_residual(combined) # [B, 2] -> (delta_scalar, delta_e34) - new_cursor = new_cursor.clone() - new_cursor[:, 0] = new_cursor[:, 0] + delta[:, 0] # scalar - new_cursor[:, 3] = new_cursor[:, 3] + delta[:, 1] # e34 - - # Symmlog normalization: prevents unbounded drift across steps - # while preserving gradient (grad = 1/(1+|x|), never zero) - new_cursor = torch.sign(new_cursor) 
* torch.log1p(new_cursor.abs()) - - # Halt probability from grade-0 of cursor - halt_prob = torch.sigmoid(self.halt_mlp(new_cursor)).squeeze(-1) # [B] - - return new_cursor, direction_logit, halt_prob diff --git a/models/gtm/cpu.py b/models/gtm/cpu.py deleted file mode 100644 index 2182e02..0000000 --- a/models/gtm/cpu.py +++ /dev/null @@ -1,135 +0,0 @@ -# Versor: Universal Geometric Algebra Neural Network -# Copyright (C) 2026 Eunkyum Kim -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# - -"""PGA Cl(3,0,1) computation engine: motor sandwich + color remapping.""" - -import torch -import torch.nn as nn -import torch.nn.functional as F -from core.algebra import CliffordAlgebra - - -class ColorUnit(nn.Module): - """Position-conditioned discrete color remapping via K blended [10, 10] tables. - - Table selection conditioned on per-cell spatial features (grade-1 post-motor - position) plus instruction grade-0/grade-4. This breaks the fundamental - GA bottleneck where the motor sandwich product leaves grade-0 (color) - invariant — by conditioning the remap on post-motor position, each cell - can receive a different color transformation. - - Spatial indices used: - idx 1 (e0): row, idx 2 (e1): col, idx 8 (e3): homogeneous coord - """ - - # Grade-1 spatial component indices in Cl(3,0,1) - _SPATIAL_IDX = [1, 2, 8] # e0(row), e1(col), e3(homo) - - def __init__(self, K_color: int = 4): - super().__init__() - self.K_color = K_color - self.remap_tables = nn.Parameter( - torch.eye(10).unsqueeze(0).expand(K_color, -1, -1).clone() - + torch.randn(K_color, 10, 10) * 0.01 - ) - # 2 (instruction g0 + g4) + 3 (cell spatial) = 5 inputs - self.selector = nn.Linear(5, K_color) - - def forward(self, state: torch.Tensor, - instruction: torch.Tensor) -> torch.Tensor: - """Apply position-conditioned color remapping. - - Args: - state: [L, N, 16] after motor transform. - instruction: [L, 16]. 
- """ - L, N, D = state.shape - - # Per-cell spatial features from post-motor grade-1 components - cell_spatial = state[:, :, self._SPATIAL_IDX] # [L, N, 3] - - # Instruction features broadcast to every cell - instr_feat = torch.stack( - [instruction[:, 0], instruction[:, 15]], dim=-1, - ).unsqueeze(1).expand(L, N, 2) # [L, N, 2] - - sel_input = torch.cat([instr_feat, cell_spatial], dim=-1) # [L, N, 5] - table_weights = F.softmax(self.selector(sel_input), dim=-1) # [L, N, K] - - # Per-cell blended remap table - blended = torch.einsum( - 'lnk,kij->lnij', table_weights, self.remap_tables, - ) # [L, N, 10, 10] - - raw_color = state[:, :, 0] * 9.0 - centers = torch.arange(10, device=state.device, dtype=state.dtype) - diffs = raw_color.unsqueeze(-1) - centers - soft_color = F.softmax(-4.0 * diffs.pow(2), dim=-1) # [L, N, 10] - - # Per-cell remap: [L, N, 1, 10] @ [L, N, 10, 10] -> [L, N, 1, 10] - remapped = torch.matmul( - soft_color.unsqueeze(2), blended, - ).squeeze(2) # [L, N, 10] - - new_color = torch.einsum('lni,i->ln', remapped, centers) / 9.0 - new_occupancy = 1.0 - remapped[:, :, 0] - - out = state.clone() - out[:, :, 0] = new_color - out[:, :, 15] = new_occupancy - return out - - -class GeometricCPU(nn.Module): - """PGA Cl(3,0,1) computation engine. - - Bivectors e01/e02/e12 produce rotations; null bivectors e03/e13/e23 - produce translations. Both composed into a single motor via exp map. - """ - - def __init__(self, algebra_cpu: CliffordAlgebra, K_color: int = 4): - super().__init__() - assert algebra_cpu.p == 3 and algebra_cpu.r == 1, \ - f"GeometricCPU requires Cl(3,0,1), got Cl({algebra_cpu.p},{algebra_cpu.q},{algebra_cpu.r})" - self.algebra = algebra_cpu - self.color_unit = ColorUnit(K_color) - - def _transform(self, state: torch.Tensor, instruction: torch.Tensor) -> torch.Tensor: - """Motor sandwich + color remapping. 
[L, N, 16] -> [L, N, 16].""" - L, N, D = state.shape - - bv = self.algebra.grade_projection(instruction, 2) - M = self.algebra.exp(-0.5 * bv) - M_rev = self.algebra.reverse(M) - - M_exp = M.unsqueeze(1).expand(L, N, D).reshape(L * N, D) - M_rev_exp = M_rev.unsqueeze(1).expand(L, N, D).reshape(L * N, D) - state_flat = state.reshape(L * N, 1, D) - - spatial_out = self.algebra.sandwich_product( - M_exp, state_flat, M_rev_exp - ).reshape(L, N, D) - - return self.color_unit(spatial_out, instruction) - - def execute(self, state: torch.Tensor, instruction: torch.Tensor) -> torch.Tensor: - """Apply transform to [B, N, 16] state with [B, 16] instruction.""" - self.algebra.ensure_device(state.device) - return self._transform(state, instruction) - - def execute_all(self, state: torch.Tensor, - instructions: torch.Tensor) -> torch.Tensor: - """Execute K instructions batched. [B,N,16] x [B,K,16] -> [B,K,N,16].""" - B, N, D = state.shape - K = instructions.shape[1] - self.algebra.ensure_device(state.device) - - state_exp = state.unsqueeze(1).expand(B, K, N, D).reshape(B * K, N, D) - instr_flat = instructions.reshape(B * K, D) - - result = self._transform(state_exp, instr_flat) - return result.reshape(B, K, N, D) diff --git a/models/gtm/gtm_net.py b/models/gtm/gtm_net.py index a8aba98..18647b3 100644 --- a/models/gtm/gtm_net.py +++ b/models/gtm/gtm_net.py @@ -5,27 +5,19 @@ # you may not use this file except in compliance with the License. # -"""GTMNet: Grid-native Geometric Turing Machine for ARC-AGI v4. - -Two-phase few-shot pipeline with Rule Memory Bank: - Phase 1 — Rule Inference: - 1. Encode demo (input,output) pairs -> PGA multivectors - 2. TuringVM processes demo cells -> cpu_state encodes patterns - 3. RuleAggregator compresses demo cpu_state into M rule slots - Phase 2 — Rule Application: - 4. Encode test input -> PGA multivectors - 5. TuringVM processes test cells, using ctrl_cursor + rule_memory - 6. 
GridReconstructionHead predicts color logits - -Information bridge: M rule slots * 16 dims = 128 floats (vs 4 floats before) -plus the 4D ctrl_cursor for halt control / step navigation. +"""GTMNet: World Model + Search Plane architecture. + +Two-phase few-shot pipeline with Rule Memory Bank. Phase 1 (Rule Inference) +encodes demo pairs, processes them through the WorldModel, and compresses +into rule memory. Phase 2 (Rule Application) encodes the test input, +processes with rule memory, and predicts color logits. """ import torch import torch.nn as nn from core.algebra import CliffordAlgebra from .grid_codec import GridCodec -from .turing_vm import TuringVM +from .world_model import WorldModel from .heads import GridReconstructionHead from .rule_memory import RuleAggregator @@ -33,45 +25,45 @@ class GTMNet(nn.Module): """Grid-native Geometric Turing Machine Network. - Two sub-algebras (Mother algebra removed): + Two sub-algebras: CPU Cl(3,0,1): PGA computation engine (motor + color) - Control Cl(1,1): learnable search controller + Control Cl(1,1): hypothesis search plane """ def __init__( self, algebra_cpu: CliffordAlgebra, algebra_ctrl: CliffordAlgebra, - channels: int = 16, - num_steps: int = 8, - max_steps: int = 20, - num_hypotheses: int = 4, - top_k: int = 1, - head_hidden: int = 64, - temperature_init: float = 1.0, - use_act: bool = False, - lambda_p: float = 0.5, + channels: int = 32, + num_steps: int = 12, + max_steps: int = 24, + num_hypotheses: int = 8, + head_hidden: int = 128, coord_scale: float = 1.0, - K_color: int = 4, num_attn_heads: int = 4, attn_head_dim: int = 8, num_rule_slots: int = 8, + num_memory_channels: int = 4, + weight_share_steps: bool = False, + log_manifold_gate_init: float = -5.0, + evolve_hidden: int = 64, + halt_eps: float = 0.01, + use_supervised_fim: bool = True, + action_gate_init: float = 0.0, ): super().__init__() self.algebra_cpu = algebra_cpu self.algebra_ctrl = algebra_ctrl self.channels = channels + 
self.num_memory_channels = num_memory_channels D_cpu = algebra_cpu.dim # 16 # Grid codec (deterministic, no params) self.codec = GridCodec(algebra_cpu, coord_scale) - # Learnable initial control cursor [4] in Cl(1,1) - self.init_cursor = nn.Parameter(torch.randn(4) * 0.01) - # Learnable role markers injected into geometrically reserved slots: - # idx 4 (e2): reserved auxiliary vector — never used by GridCodec + # idx 4 (e2): reserved auxiliary vector # idx 15 (pseudoscalar e0123): occupancy/role flag # Shape: [3, 2] for (e2_value, pseudoscalar_value) per role # role 0 = demo input, role 1 = demo output, role 2 = test input @@ -82,32 +74,37 @@ def __init__( d_cpu=D_cpu, num_slots=num_rule_slots, num_heads=num_attn_heads, ) - # Turing VM - self.vm = TuringVM( + self.world_model = WorldModel( algebra_cpu, algebra_ctrl, - channels, num_steps, max_steps, - num_hypotheses, top_k, temperature_init, - use_act, lambda_p, - num_attn_heads, attn_head_dim, - K_color, num_rule_slots, + num_steps=num_steps, + max_steps=max_steps, + num_hypotheses=num_hypotheses, + num_attn_heads=num_attn_heads, + attn_head_dim=attn_head_dim, + num_rule_slots=num_rule_slots, + evolve_hidden=evolve_hidden, + gate_init=action_gate_init, + log_manifold_gate_init=log_manifold_gate_init, + halt_eps=halt_eps, + use_supervised_fim=use_supervised_fim, + weight_share_steps=weight_share_steps, ) # Reconstruction head - self.head = GridReconstructionHead(algebra_cpu, head_hidden) + self.head = GridReconstructionHead( + algebra_cpu, head_hidden, num_memory_channels, + ) def forward(self, demo_inputs: torch.Tensor, demo_outputs: torch.Tensor, demo_masks: torch.Tensor, test_inputs: torch.Tensor, test_masks: torch.Tensor, num_demos: torch.Tensor, demo_output_masks: torch.Tensor = None, + test_targets: torch.Tensor = None, input_sizes: list = None, return_trace: bool = False) -> dict: """Two-phase forward pass: Rule Inference -> Rule Application. 
- Phase 1 processes demo pairs through the VM to encode transformation - patterns. RuleAggregator compresses these into rule_memory slots. - Phase 2 processes test input using ctrl_cursor + rule_memory. - Args: demo_inputs: [B, K, H_max, W_max] padded demo input grids. demo_outputs: [B, K, H_max, W_max] padded demo output grids. @@ -116,15 +113,12 @@ def forward(self, demo_inputs: torch.Tensor, demo_outputs: torch.Tensor, test_masks: [B, H_max, W_max] bool (True=valid). num_demos: [B] int — actual demo count per example. demo_output_masks: [B, K, H_max, W_max] bool (True=valid output cell). - If None, falls back to demo_masks (same dims assumed). + test_targets: [B, H_max, W_max] optional targets for supervised FIM. input_sizes: Optional list of (H, W) for test inputs. return_trace: Collect per-step diagnostics. Returns: - dict with: - 'logits': [B, N_test, 10] color logits for test cells - 'test_flat_masks': [B, N_test] bool - optionally 'act_info', 'trace' + dict with 'logits', 'test_flat_masks', 'world_model_info', optionally 'trace'. 
""" B, K, H_max, W_max = demo_inputs.shape N_grid = H_max * W_max @@ -134,113 +128,105 @@ def forward(self, demo_inputs: torch.Tensor, demo_outputs: torch.Tensor, if demo_output_masks is None: demo_output_masks = demo_masks - # --- Encode demo pairs --- di_flat = demo_inputs.reshape(B * K, H_max, W_max) do_flat = demo_outputs.clamp(min=0).reshape(B * K, H_max, W_max) dim_flat = demo_masks.reshape(B * K, H_max, W_max) dom_flat = demo_output_masks.reshape(B * K, H_max, W_max) - di_mv, di_fm = self.codec.encode_batch(di_flat, dim_flat) # [B*K, N_grid, 16] - do_mv, do_fm = self.codec.encode_batch(do_flat, dom_flat) # [B*K, N_grid, 16] + di_mv, di_fm = self.codec.encode_batch(di_flat, dim_flat) + do_mv, do_fm = self.codec.encode_batch(do_flat, dom_flat) - # Add role markers into reserved slots + # Add role markers di_mv[:, :, 4] = di_mv[:, :, 4] + self.role_embed[0, 0] di_mv[:, :, 15] = di_mv[:, :, 15] + self.role_embed[0, 1] do_mv[:, :, 4] = do_mv[:, :, 4] + self.role_embed[1, 0] do_mv[:, :, 15] = do_mv[:, :, 15] + self.role_embed[1, 1] - # Interleave demo input + output: [B*K, 2*N_grid, 16] + # Interleave: [B*K, 2*N_grid, 16] -> [B, K*2*N_grid, 16] demo_mv = torch.cat([di_mv, do_mv], dim=1) demo_fm = torch.cat([di_fm, do_fm], dim=1) - # Reshape: [B, K * 2 * N_grid, 16] N_demo_per_pair = 2 * N_grid demo_mv = demo_mv.reshape(B, K * N_demo_per_pair, D_cpu) demo_fm = demo_fm.reshape(B, K * N_demo_per_pair) - # Mask out unused demo pairs — vectorized, no .item() calls + # Mask out unused demo pairs total_demo_len = K * N_demo_per_pair - pos_idx = torch.arange(total_demo_len, device=device).unsqueeze(0) # [1, L] - limit = (num_demos * N_demo_per_pair).unsqueeze(1) # [B, 1] - valid_demo = pos_idx < limit # [B, L] + pos_idx = torch.arange(total_demo_len, device=device).unsqueeze(0) + limit = (num_demos * N_demo_per_pair).unsqueeze(1) + valid_demo = pos_idx < limit demo_mv = demo_mv * valid_demo.unsqueeze(-1).float() demo_fm = demo_fm & valid_demo - # --- Encode test input --- 
test_mv, test_fm = self.codec.encode_batch(test_inputs, test_masks) test_mv[:, :, 4] = test_mv[:, :, 4] + self.role_embed[2, 0] test_mv[:, :, 15] = test_mv[:, :, 15] + self.role_embed[2, 1] - # --- Init control cursor --- - ctrl_cursor = self.init_cursor.unsqueeze(0).expand(B, -1).clone() + # Flatten targets for supervised FIM + flat_targets = None + if test_targets is not None: + flat_targets = test_targets.reshape(B, N_grid) - # === Phase 1: Rule Inference (demo only) === - # VM processes demo pairs -> cpu_state encodes patterns, ctrl_cursor updated - demo_state, ctrl_cursor, act_info_demo, trace_demo = self.vm( - demo_mv, ctrl_cursor, demo_fm, return_trace, + demo_result = self.world_model( + demo_mv, demo_fm, return_trace=return_trace, ) + demo_state = demo_result['output'] - # Compress demo state into rule memory slots - rule_memory = self.rule_aggregator(demo_state, demo_fm) # [B, M, 16] + # Compress demo state into rule memory + rule_memory = self.rule_aggregator(demo_state, demo_fm) - # === Phase 2: Rule Application (test only) === - # VM processes test input using ctrl_cursor + rule_memory from Phase 1 - test_state, ctrl_cursor, act_info_test, trace_test = self.vm( - test_mv, ctrl_cursor, test_fm, return_trace, + test_result = self.world_model( + test_mv, test_fm, rule_memory=rule_memory, + targets=flat_targets, + return_trace=return_trace, ) + test_state = test_result['output'] - # --- Decode --- - logits = self.head(test_state, test_fm) # [B, N_grid, 10] + logits = self.head(test_state, test_fm) result = { 'logits': logits, 'test_flat_masks': test_fm, + 'world_model_info': { + 'step_deltas': test_result['step_deltas'], + 'step_weights': test_result['step_weights'], + 'mixing_weights': test_result['mixing_weights'], + 'hypotheses': test_result['hypotheses'], + 'R_accum': test_result['R_accum'], + }, } - # ACT info: combine KL loss from both phases - if act_info_test is not None: - result['act_info'] = { - 'kl_loss': act_info_test['kl_loss'] + 
act_info_demo['kl_loss'], - 'expected_steps': act_info_test['expected_steps'], - 'weights': act_info_test['weights'], - } - - # Merge traces from both phases + # Merge traces if return_trace: - trace_keys = ['search_scores', 'search_weights', 'halt_probs', - 'cursors', 'gate_values'] - trace = {k: [] for k in trace_keys} - for t in (trace_demo, trace_test): - if t is not None: - for k in trace_keys: - trace[k].extend(t.get(k, [])) - result['trace'] = trace + result['trace'] = { + 'demo': demo_result.get('trace'), + 'test': test_result.get('trace'), + } return result def set_temperature(self, tau: float): - """Set Gumbel-Softmax temperature for all VM steps.""" - self.vm.set_temperature(tau) + """Set softmax temperature for all WorldModel steps.""" + self.world_model.set_temperature(tau) - def freeze_vm(self): - """Freeze all VM parameters (Phase 1: warmup).""" - for param in self.vm.parameters(): + def freeze_world_model(self): + """Freeze all WorldModel parameters for warmup.""" + for param in self.world_model.parameters(): param.requires_grad = False - def unfreeze_vm(self): - """Unfreeze all VM parameters (Phase 2+).""" - for param in self.vm.parameters(): + def unfreeze_world_model(self): + """Unfreeze all WorldModel parameters.""" + for param in self.world_model.parameters(): param.requires_grad = True - def enable_act(self): - """Enable adaptive computation time.""" - if self.vm.adaptive_halt is not None: - self.vm.use_act = True + def enable_fim_halt(self): + """Enable FIM-based adaptive halt.""" + self.world_model.use_fim_halt = True - def disable_act(self): - """Disable adaptive computation time.""" - self.vm.use_act = False + def disable_fim_halt(self): + """Disable FIM-based adaptive halt.""" + self.world_model.use_fim_halt = False def trainable_parameters(self): for param in self.parameters(): diff --git a/models/gtm/heads.py b/models/gtm/heads.py index a85839f..876a2ef 100644 --- a/models/gtm/heads.py +++ b/models/gtm/heads.py @@ -5,7 +5,7 @@ # you may 
not use this file except in compliance with the License. # -"""Grid reconstruction head for ARC-AGI v4. +"""Grid reconstruction head for ARC-AGI. Per-cell color classification from final CPU state multivectors. """ @@ -19,13 +19,17 @@ class GridReconstructionHead(nn.Module): """Per-cell color classification from Cl(3,0,1) multivectors. Maps each cell's 16-component multivector to 10-class color logits. + Optionally masks out memory channels before decoding. """ - def __init__(self, algebra_cpu: CliffordAlgebra, hidden_dim: int = 64): + def __init__(self, algebra_cpu: CliffordAlgebra, hidden_dim: int = 64, + num_memory_channels: int = 0): super().__init__() self.algebra = algebra_cpu + self.num_memory_channels = num_memory_channels + input_dim = algebra_cpu.dim - num_memory_channels self.mlp = nn.Sequential( - nn.Linear(algebra_cpu.dim, hidden_dim), + nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), @@ -37,7 +41,7 @@ def forward(self, cpu_state: torch.Tensor, """Predict per-cell color logits. Args: - cpu_state: Final CPU state [B, N, 16]. + cpu_state: Final CPU state [B, N, D]. mask: Optional validity mask [B, N] (True=valid). Not used in forward (handled by loss function ignore_index), but kept for interface compatibility. @@ -45,4 +49,6 @@ def forward(self, cpu_state: torch.Tensor, Returns: Logits [B, N, 10]. """ + if self.num_memory_channels > 0: + cpu_state = cpu_state[..., :-self.num_memory_channels] return self.mlp(cpu_state) diff --git a/models/gtm/info_geometry.py b/models/gtm/info_geometry.py new file mode 100644 index 0000000..7506fce --- /dev/null +++ b/models/gtm/info_geometry.py @@ -0,0 +1,111 @@ +# Versor: Universal Geometric Algebra Neural Network +# Copyright (C) 2026 Eunkyum Kim +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# + +"""FIMEvaluator: target-free Fisher Information proxy via grade-wise variance. 
+ +Scores hypotheses by measuring how much each deviates from the mean +across hypotheses, weighted by learnable per-grade importance. Higher +FIM proxy = more informative hypothesis (more distinctive structure). +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from core.algebra import CliffordAlgebra +from core.metric import hermitian_norm +from layers.primitives.base import CliffordModule + + +class FIMEvaluator(CliffordModule): + """Fisher Information Matrix proxy using grade-wise variance across hypotheses. + + For each hypothesis k, measures how its grade-wise structure deviates from + the mean across all hypotheses. Learnable grade_weights control per-grade + importance. Optionally provides supervised FIM using target comparison. + """ + + def __init__(self, algebra: CliffordAlgebra): + super().__init__(algebra) + self.grade_weights = nn.Parameter(torch.zeros(algebra.num_grades)) + + def fim_proxy(self, states: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor: + """Unsupervised FIM proxy: grade-wise variance across hypotheses. + + Args: + states: Candidate states [B, K, N, D]. + mask: Optional validity mask [B, N] (True=valid). + + Returns: + FIM proxy values [B, K]. 
+ """ + B, K, N, D = states.shape + + # Mean across hypotheses per position + mean_state = states.mean(dim=1, keepdim=True) # [B, 1, N, D] + + # Per-grade deviation for each hypothesis + diff = states - mean_state # [B, K, N, D] + weights = F.softplus(self.grade_weights) # [num_grades], positive + + # Compute per-grade squared norm of deviation + fim = torch.zeros(B, K, device=states.device, dtype=states.dtype) + for g in range(self.algebra.num_grades): + diff_g = self.algebra.grade_projection( + diff.reshape(B * K * N, D), g + ).reshape(B, K, N, D) + # Hermitian norm squared per position + sq = (diff_g ** 2).sum(dim=-1) # [B, K, N] + if mask is not None: + sq = sq * mask.unsqueeze(1).float() + fim = fim + weights[g] * sq.mean(dim=-1) # [B, K] + + return fim + + def information_gain(self, fim_cur: torch.Tensor, + fim_prev: torch.Tensor) -> torch.Tensor: + """Compute information gain between successive FIM evaluations. + + Args: + fim_cur: Current FIM values [B, K]. + fim_prev: Previous FIM values [B, K]. + + Returns: + Information gain [B, K]. + """ + return fim_cur - fim_prev + + def supervised_fim(self, states: torch.Tensor, + targets: torch.Tensor, + mask: torch.Tensor = None) -> torch.Tensor: + """Training-only: compare grade-0 of each candidate to target encoding. + + Provides stronger signal than the unsupervised proxy by directly + measuring how well each hypothesis's scalar component matches targets. + + Args: + states: Candidate states [B, K, N, D]. + targets: Target color indices [B, N] (long, 0-9). + mask: Optional validity mask [B, N]. + + Returns: + Supervised FIM values [B, K] (higher = better match). 
+        """
+        B, K, N, D = states.shape
+        # Grade-0 of candidates = predicted color signal
+        pred_color = states[:, :, :, 0]  # [B, K, N]
+
+        # Target as normalized float
+        target_color = targets.float() / 9.0  # [B, N]
+
+        # Negative squared error (higher = better)
+        sq_err = -(pred_color - target_color.unsqueeze(1)) ** 2  # [B, K, N]
+        if mask is not None:
+            sq_err = sq_err * mask.unsqueeze(1).float()
+            denom = mask.float().sum(dim=-1).clamp(min=1.0)  # [B]
+            return sq_err.sum(dim=-1) / denom.unsqueeze(1)  # [B, K]
+
+        return sq_err.mean(dim=-1)  # [B, K]
diff --git a/models/gtm/log_manifold.py b/models/gtm/log_manifold.py
new file mode 100644
index 0000000..b16f8e7
--- /dev/null
+++ b/models/gtm/log_manifold.py
@@ -0,0 +1,71 @@
+# Versor: Universal Geometric Algebra Neural Network
+# Copyright (C) 2026 Eunkyum Kim
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+#
+
+"""LogManifoldProjector: mantissa/exponent split with cross-modulation.
+
+Splits x = mantissa * exp(exponent) for high-depth stability.
+The mantissa lives on the unit manifold; the exponent captures log-scale.
+Cross-modulation gates (initialized near zero) gradually couple the two
+as training progresses.
+"""
+
+import torch
+import torch.nn as nn
+from core.algebra import CliffordAlgebra
+from core.metric import hermitian_norm
+from layers.primitives.base import CliffordModule
+
+
+class LogManifoldProjector(CliffordModule):
+    """Split/merge multivectors into mantissa (unit) and exponent (log-scale).
+
+    Split: x -> (mantissa, exponent) where mantissa = x / ||x||_H, exponent = log(||x||_H)
+    Merge: (mantissa, exponent) -> mantissa * exp(exponent) with cross-modulation
+
+    Cross-modulation gates start near zero (logit=-5 -> sigmoid~0.007) so the
+    projector initially behaves as a pure split/merge. As training progresses,
+    the gates open to enable information flow between scale and direction.
+ """ + + def __init__(self, algebra: CliffordAlgebra, gate_init: float = -5.0): + super().__init__(algebra) + # Cross-modulation gates (initialized near zero) + self.gate_e = nn.Parameter(torch.tensor(gate_init)) # exponent <- mantissa feedback + self.gate_m = nn.Parameter(torch.tensor(gate_init)) # mantissa <- exponent feedback + + def split(self, x: torch.Tensor) -> tuple: + """Split multivector into unit-norm mantissa and log-scale exponent. + + Args: + x: Multivector [B, N, D]. + + Returns: + (mantissa [B, N, D], exponent [B, N, 1]). + """ + norm = hermitian_norm(self.algebra, x).clamp(min=1e-8) # [B, N, 1] + mantissa = x / norm + exponent = torch.log(norm) + return mantissa, exponent + + def merge(self, mantissa: torch.Tensor, exponent: torch.Tensor) -> torch.Tensor: + """Merge mantissa and exponent back into full multivector. + + Cross-modulation allows information flow between scale and direction: + - gate_e: exponent adjusted by mantissa's scalar component + - gate_m: mantissa scaled by original exponent + + Args: + mantissa: Unit-norm multivector [B, N, D]. + exponent: Log-scale [B, N, 1]. + + Returns: + Reconstructed multivector [B, N, D]. + """ + # Cross-modulation + e_mod = exponent + torch.sigmoid(self.gate_e) * mantissa[..., 0:1] + m_mod = mantissa * (1.0 + torch.sigmoid(self.gate_m) * torch.tanh(exponent)) + return m_mod * torch.exp(e_mod.clamp(-10.0, 10.0)) diff --git a/models/gtm/search_plane.py b/models/gtm/search_plane.py new file mode 100644 index 0000000..14bd994 --- /dev/null +++ b/models/gtm/search_plane.py @@ -0,0 +1,247 @@ +# Versor: Universal Geometric Algebra Neural Network +# Copyright (C) 2026 Eunkyum Kim +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# + +"""SearchPlane + Algebraic Projection/Lift for hypothesis management in Cl(1,1). 
+ +Three classes: +- AlgebraicProjection (phi): Cl(3,0,1) -> Cl(1,1) via principled grade-norm decomposition +- AlgebraicLift (psi): Cl(1,1) -> grade-wise multiplicative modulation of Cl(3,0,1) +- SearchPlane: Active hypothesis management via Cl(1,1) hyperbolic rotors + FIM scoring +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from core.algebra import CliffordAlgebra +from core.metric import _hermitian_signs +from layers.primitives.base import CliffordModule + + +class AlgebraicProjection(nn.Module): + """phi: Cl(3,0,1) world summary -> Cl(1,1) search state. + + Compresses 16D multivector into 4D via principled decomposition: + - Scalar component preserved directly + - Grade norms -> e+ and e- components (energy distribution) + - Bivector coefficients -> e+e- component (relational phase) + """ + + def __init__(self, algebra_cpu: CliffordAlgebra): + super().__init__() + num_grades = algebra_cpu.num_grades # 5 for Cl(3,0,1) + num_bivectors = len(algebra_cpu._bv_indices) # 6 for Cl(3,0,1) + + self.f_plus = nn.Linear(num_grades, 1) + self.f_minus = nn.Linear(num_grades, 1) + self.f_phase = nn.Linear(num_bivectors, 1) + + self._algebra_cpu = algebra_cpu + + def forward(self, world_summary: torch.Tensor) -> torch.Tensor: + """Project world summary to search plane. + + Args: + world_summary: Mean-pooled CPU state [B, 16]. + + Returns: + Search state [B, 4] in Cl(1,1). 
+ """ + self._algebra_cpu.ensure_device(world_summary.device) + grade_norms = self._algebra_cpu.get_grade_norms(world_summary) # [B, 5] + bv_idx = self._algebra_cpu._bv_indices + bv_coeffs = world_summary[:, bv_idx] # [B, 6] + + return torch.stack([ + world_summary[:, 0], # scalar preserved + self.f_plus(grade_norms).squeeze(-1), # e+: positive energy + self.f_minus(grade_norms).squeeze(-1), # e-: negative/null energy + self.f_phase(bv_coeffs).squeeze(-1), # e+e-: relational phase + ], dim=-1) + + +class AlgebraicLift(nn.Module): + """psi: Cl(1,1) hypotheses -> grade-wise modulation of Cl(3,0,1). + + Converts weighted hypothesis mean into per-grade multiplicative scales. + Centered at 1.0 (no effect initially) via 1 + tanh(...), range [0, 2]. + """ + + def __init__(self, algebra_cpu: CliffordAlgebra): + super().__init__() + num_grades = algebra_cpu.num_grades # 5 + self.lift_mlp = nn.Sequential( + nn.Linear(4, 32), + nn.ReLU(), + nn.Linear(32, num_grades), + ) + # Precompute grade masks as float for broadcasting + self._grade_masks = [m.float() for m in algebra_cpu.grade_masks] + self._num_grades = num_grades + self._dim = algebra_cpu.dim + + def forward(self, hypotheses: torch.Tensor, + weights: torch.Tensor) -> torch.Tensor: + """Compute grade-wise modulation from weighted hypotheses. + + Args: + hypotheses: Hypothesis states [B, K, 4]. + weights: Attention weights [B, K]. + + Returns: + Modulation vector [B, D] (multiplicative, centered at 1.0). 
+ """ + # Conviction-weighted mean hypothesis + weighted = torch.einsum('bk,bkd->bd', weights, hypotheses) # [B, 4] + grade_scales = 1.0 + torch.tanh(self.lift_mlp(weighted)) # [B, num_grades] + + B = hypotheses.shape[0] + device = hypotheses.device + modulation = torch.ones(B, self._dim, device=device, dtype=hypotheses.dtype) + + for g in range(self._num_grades): + mask = self._grade_masks[g].to(device=device, dtype=hypotheses.dtype) + modulation = modulation + (grade_scales[:, g:g+1] - 1.0) * mask.unsqueeze(0) + + return modulation + + +class SearchPlane(CliffordModule): + """Active hypothesis management in Cl(1,1). + + Evolves K hypotheses via hyperbolic rotors, scores them using FIM, + and computes soft attention weights. All operations are differentiable + (no hard selection). Temperature controls exploration vs exploitation. + """ + + def __init__(self, algebra_ctrl: CliffordAlgebra, + num_hypotheses: int = 8, + evolve_hidden: int = 64): + assert algebra_ctrl.p == 1 and algebra_ctrl.q == 1, \ + f"SearchPlane requires Cl(1,1), got Cl({algebra_ctrl.p},{algebra_ctrl.q})" + super().__init__(algebra_ctrl) + + K = num_hypotheses + self.num_hypotheses = K + + # Initial hypothesis states in Cl(1,1) + self.hypothesis_init = nn.Parameter(torch.randn(K, 4) * 0.1) + + # Evolution network: context -> boost magnitude + # Input: hypothesis (4) + world_summary projected to 16D -> concatenated + self.evolve_net = nn.Sequential( + nn.Linear(4 + 16, evolve_hidden), + nn.ReLU(), + nn.Linear(evolve_hidden, 1), + ) + + # Temperature buffer (annealed externally) + self.register_buffer('_temperature', torch.tensor(1.0)) + + # Precompute hermitian signs for Gram matrix + self._ctrl_signs = None + + def set_temperature(self, tau: float): + """Set softmax temperature for attention weights.""" + self._temperature.fill_(tau) + + def _get_signs(self, device: torch.device) -> torch.Tensor: + """Get cached hermitian signs for Cl(1,1).""" + if self._ctrl_signs is None or 
self._ctrl_signs.device != device: + self._ctrl_signs = _hermitian_signs(self.algebra).to(device) + return self._ctrl_signs + + def forward(self, hypotheses: torch.Tensor, + world_summary: torch.Tensor, + fim_values: torch.Tensor, + fim_prev: torch.Tensor = None) -> dict: + """Evolve hypotheses and compute attention weights. + + Args: + hypotheses: Current hypothesis states [B, K, 4]. + world_summary: Mean-pooled CPU state [B, 16]. + fim_values: FIM scores for candidates [B, K]. + fim_prev: Previous FIM values [B, K] or None. + + Returns: + dict with hypotheses, weights, fim_values, delta_info, gram, conviction. + """ + B, K, _ = hypotheses.shape + device = hypotheses.device + self.algebra.ensure_device(device) + + world_exp = world_summary.unsqueeze(1).expand(B, K, -1).reshape(B * K, -1) + h_flat = hypotheses.reshape(B * K, 4) + ctx = torch.cat([h_flat, world_exp], dim=-1) # [B*K, 20] + raw_theta = self.evolve_net(ctx).reshape(B, K) + # Smooth bounding via tanh: always has gradient, range [-3, 3] + theta = torch.tanh(raw_theta) * 3.0 + + # Build e+e- bivector (index 3 in Cl(1,1)) + bv = torch.zeros(B * K, 4, device=device, dtype=hypotheses.dtype) + bv[:, 3] = theta.reshape(B * K) + + # Exponentiate and sandwich + R = self.algebra.exp(-0.5 * bv) # [B*K, 4] + R_rev = self.algebra.reverse(R) + evolved = self.algebra.geometric_product( + self.algebra.geometric_product(R, h_flat), R_rev + ) + # Symmlog: prevents unbounded drift, gradient = 1/(1+|x|), never zero + evolved = torch.sign(evolved) * torch.log1p(evolved.abs()) + hypotheses = evolved.reshape(B, K, 4) + + delta_info = fim_values - fim_prev if fim_prev is not None else fim_values + + tau = self._temperature.clamp(min=0.01) + weights = F.softmax(fim_values / tau, dim=-1) # [B, K] + + gram = self.hermitian_gram(hypotheses) # [B, K, K] + + conviction = weights.max(dim=-1).values # [B] + + return { + 'hypotheses': hypotheses, + 'weights': weights, + 'fim_values': fim_values, + 'delta_info': delta_info, + 'gram': 
gram, + 'conviction': conviction, + } + + def hermitian_gram(self, hypotheses: torch.Tensor) -> torch.Tensor: + """Compute Hermitian Gram matrix for hypothesis orthogonality. + + Args: + hypotheses: [B, K, 4]. + + Returns: + Gram matrix [B, K, K]. + """ + signs = self._get_signs(hypotheses.device).to(dtype=hypotheses.dtype) + h_signed = hypotheses * signs # [B, K, 4] + return torch.einsum('bkd,bld->bkl', h_signed, hypotheses) + + @staticmethod + def orthogonality_loss(gram: torch.Tensor) -> torch.Tensor: + """Compute orthogonality loss from Gram matrix. + + L_ortho = ||G_normalized - I||_F^2 + + Args: + gram: Gram matrix [B, K, K]. + + Returns: + Scalar loss. + """ + K = gram.shape[1] + device = gram.device + # Normalize: G_norm[i,j] = G[i,j] / sqrt(|G[i,i]| * |G[j,j]|) + diag = gram.diagonal(dim1=-2, dim2=-1).abs().clamp(min=1e-8) # [B, K] + norm_factor = torch.sqrt(diag.unsqueeze(-1) * diag.unsqueeze(-2)) # [B, K, K] + gram_norm = gram / norm_factor + eye = torch.eye(K, device=device, dtype=gram.dtype).unsqueeze(0) + return ((gram_norm - eye) ** 2).mean() diff --git a/models/gtm/superposition.py b/models/gtm/superposition.py deleted file mode 100644 index a927800..0000000 --- a/models/gtm/superposition.py +++ /dev/null @@ -1,107 +0,0 @@ -# Versor: Universal Geometric Algebra Neural Network -# Copyright (C) 2026 Eunkyum Kim -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# - -"""Geometric Superposition Search: score, dispatch, execute, select.""" - -import torch -import torch.nn as nn -import torch.nn.functional as F -from core.algebra import CliffordAlgebra -from .cpu import GeometricCPU - - -class GeometricSuperpositionSearch(nn.Module): - """Scores K hypotheses via CPU grade norms, executes PGA motor transforms - in parallel, and selects via Gumbel-Softmax. Instruction templates are - optionally modulated by rule memory. 
- """ - - def __init__(self, algebra_cpu: CliffordAlgebra, - algebra_ctrl: CliffordAlgebra, - channels: int, - num_hypotheses: int = 4, - top_k: int = 1, - temperature_init: float = 1.0, - K_color: int = 4, - num_rule_slots: int = 8): - super().__init__() - self.algebra_cpu = algebra_cpu - self.algebra_ctrl = algebra_ctrl - self.channels = channels - self.num_hypotheses = num_hypotheses - self.top_k = top_k - - D_cpu = algebra_cpu.dim - - self.pga_cpu = GeometricCPU(algebra_cpu, K_color) - self.instruction_templates = nn.Parameter( - torch.randn(num_hypotheses, D_cpu) * 0.1 - ) - self.score_mlp = nn.Sequential( - nn.Linear(algebra_cpu.num_grades + algebra_ctrl.dim, 64), - nn.ReLU(), - nn.Linear(64, num_hypotheses), - ) - # Per-cell routing: each cell scores hypotheses independently - self.cell_router = nn.Linear(D_cpu, num_hypotheses) - # Rule memory bias on hypothesis scores - self.rule_score_proj = nn.Linear(D_cpu, num_hypotheses) - - # Small-weight init so initial behavior ≈ old global-only scoring - nn.init.normal_(self.cell_router.weight, std=0.01) - nn.init.zeros_(self.cell_router.bias) - nn.init.normal_(self.rule_score_proj.weight, std=0.01) - nn.init.zeros_(self.rule_score_proj.bias) - - self.rule_proj = nn.Linear(D_cpu, num_hypotheses * D_cpu) - self.register_buffer('_temperature', torch.tensor(float(temperature_init))) - - def set_temperature(self, tau: float): - """Set Gumbel-Softmax temperature (called by external annealing schedule).""" - self._temperature.fill_(tau) - - def step(self, cpu_state: torch.Tensor, - ctrl_cursor: torch.Tensor, - rule_memory: torch.Tensor = None) -> tuple: - """One search step. 
Returns (new_cpu_state, search_info).""" - B, N, D_cpu = cpu_state.shape - device = cpu_state.device - K = self.num_hypotheses - - cpu_summary = cpu_state.mean(dim=1) - self.algebra_cpu.ensure_device(device) - grade_norms = self.algebra_cpu.get_grade_norms(cpu_summary) - - # Per-cell logits + global bias from cursor/grade norms - cell_logits = self.cell_router(cpu_state) # [B, N, K] - global_bias = self.score_mlp( - torch.cat([grade_norms, ctrl_cursor], dim=-1) - ) # [B, K] - scores = cell_logits + global_bias.unsqueeze(1) # [B, N, K] - - templates = self.instruction_templates.unsqueeze(0).expand(B, -1, -1) - if rule_memory is not None: - rule_summary = rule_memory.mean(dim=1) - rule_modulation = self.rule_proj(rule_summary).view(B, K, D_cpu) - templates = templates + rule_modulation - # Rule memory biases scoring (which instructions cells prefer) - rule_score_bias = self.rule_score_proj(rule_summary) # [B, K] - scores = scores + rule_score_bias.unsqueeze(1) # [B, N, K] - - outcomes = self.pga_cpu.execute_all(cpu_state, templates) # [B, K, N, D] - - tau = self._temperature.clamp(0.1, 5.0) - weights = F.gumbel_softmax( - scores.reshape(B * N, K), tau=tau, hard=False - ).reshape(B, N, K) # [B, N, K] - new_cpu_state = torch.einsum('bnk,bknd->bnd', weights, outcomes) - - return new_cpu_state, { - 'scores': scores, - 'weights': weights, - 'temperature': tau.detach(), - } diff --git a/models/gtm/turing_step.py b/models/gtm/turing_step.py deleted file mode 100644 index e1892c8..0000000 --- a/models/gtm/turing_step.py +++ /dev/null @@ -1,154 +0,0 @@ -# Versor: Universal Geometric Algebra Neural Network -# Copyright (C) 2026 Eunkyum Kim -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# - -"""Single GTM step: SuperpositionSearch + Cross-Grade Attention + ControlPlane.""" - -import torch -import torch.nn as nn -import torch.nn.functional as F -from core.algebra import CliffordAlgebra -from layers.primitives.normalization import CliffordLayerNorm -from .superposition import GeometricSuperpositionSearch -from .control_plane import ControlPlane - - -_GRADE_MAP_16 = torch.zeros(16, dtype=torch.long) -_GRADE_MAP_16[0] = 0 -_GRADE_MAP_16[[1, 2, 4, 8]] = 1 -_GRADE_MAP_16[[3, 5, 6, 9, 10, 12]] = 2 -_GRADE_MAP_16[[7, 11, 13, 14]] = 3 -_GRADE_MAP_16[15] = 4 - - -class CellAttention(nn.Module): - """Cross-grade self-attention over grid cells in Cl(3,0,1).""" - - def __init__(self, algebra_cpu: CliffordAlgebra, num_heads: int = 4, - head_dim: int = 8, dropout: float = 0.0): - super().__init__() - D = algebra_cpu.dim - attn_dim = num_heads * head_dim - - self.num_heads = num_heads - self.head_dim = head_dim - self.scale = head_dim ** -0.5 - - self.q_proj = nn.Linear(D, attn_dim) - self.k_proj = nn.Linear(D, attn_dim) - self.v_gain = nn.ParameterDict({ - f'g{k}': nn.Parameter(torch.ones(1)) for k in range(5) - }) - self.dropout = nn.Dropout(dropout) - self.register_buffer('grade_map', _GRADE_MAP_16.clone()) - - def _apply_grade_gains(self, x: torch.Tensor) -> torch.Tensor: - """Apply per-grade isotropic gains to multivector components.""" - gains = torch.ones(16, device=x.device, dtype=x.dtype) - for k in range(5): - mask = self.grade_map == k - gains[mask] = self.v_gain[f'g{k}'] - return x * gains - - def forward(self, x: torch.Tensor, - mask: torch.Tensor = None) -> torch.Tensor: - """[B, N, 16] -> [B, N, 16] with optional mask [B, N].""" - B, N, D = x.shape - - Q = self.q_proj(x) - K = self.k_proj(x) - - Q = Q.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) - K = K.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) - - scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale - - if mask is not None: - scores = scores.masked_fill( - 
(~mask).unsqueeze(1).unsqueeze(2), float('-inf') - ) - - attn = F.softmax(scores, dim=-1) - attn = self.dropout(attn) - attn_avg = attn.mean(dim=1) - - attended = torch.bmm(attn_avg, x) - return self._apply_grade_gains(attended) - - -class TuringStep(nn.Module): - """One step of the Geometric Turing Machine.""" - - def __init__(self, algebra_cpu: CliffordAlgebra, - algebra_ctrl: CliffordAlgebra, - channels: int, - num_hypotheses: int = 4, - top_k: int = 1, - temperature_init: float = 1.0, - num_attn_heads: int = 4, - attn_head_dim: int = 8, - attn_dropout: float = 0.0, - K_color: int = 4, - num_rule_slots: int = 8): - super().__init__() - self.channels = channels - D_cpu = algebra_cpu.dim - - self.cell_attn = CellAttention( - algebra_cpu, num_attn_heads, attn_head_dim, attn_dropout, - ) - self.search = GeometricSuperpositionSearch( - algebra_cpu, algebra_ctrl, - channels, num_hypotheses, top_k, temperature_init, - K_color, num_rule_slots, - ) - self.control = ControlPlane(algebra_ctrl, channels) - self.norm = CliffordLayerNorm(algebra_cpu, 1) - self.context_proj = nn.Linear(D_cpu, channels) - # Per-component gate: enables cross-grade mixing. - # With scalar gate (old), color (g0) and position (g1) always move - # together. Per-component gate lets the model selectively update - # color based on position context and vice versa. 
- self.write_gate = nn.Sequential( - nn.Linear(D_cpu * 2, 64), - nn.ReLU(), - nn.Linear(64, D_cpu), - ) - - def set_temperature(self, tau: float): - self.search.set_temperature(tau) - - def forward(self, cpu_state: torch.Tensor, - ctrl_cursor: torch.Tensor, - mask: torch.Tensor = None, - rule_memory: torch.Tensor = None) -> dict: - old_state = cpu_state - - attended = self.cell_attn(cpu_state, mask) - new_cpu, search_info = self.search.step(attended, ctrl_cursor, rule_memory) - - gate_input = torch.cat([old_state, new_cpu], dim=-1) - gate = torch.sigmoid(self.write_gate(gate_input)) - new_cpu = gate * new_cpu + (1.0 - gate) * old_state - - B, N, D = new_cpu.shape - new_cpu_flat = new_cpu.reshape(B * N, 1, D) - new_cpu_flat = self.norm(new_cpu_flat) - new_cpu = new_cpu_flat.reshape(B, N, D) - - cpu_summary = new_cpu.mean(dim=1) - cpu_context = self.context_proj(cpu_summary) - new_cursor, direction_logit, halt_prob = self.control.step( - ctrl_cursor, cpu_context - ) - - return { - 'cpu_state': new_cpu, - 'ctrl_cursor': new_cursor, - 'halt_prob': halt_prob, - 'search_info': search_info, - 'gate_values': gate, - } diff --git a/models/gtm/turing_vm.py b/models/gtm/turing_vm.py deleted file mode 100644 index 791a4fb..0000000 --- a/models/gtm/turing_vm.py +++ /dev/null @@ -1,151 +0,0 @@ -# Versor: Universal Geometric Algebra Neural Network -# Copyright (C) 2026 Eunkyum Kim -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# - -"""Geometric Turing Machine execution engine.""" - -import torch -import torch.nn as nn -from core.algebra import CliffordAlgebra -from layers.primitives.normalization import CliffordLayerNorm -from .turing_step import TuringStep -from .adaptive_halt import AdaptiveHalt - - -class TuringVM(nn.Module): - """Chains TuringSteps with dual-state threading and optional PonderNet halting.""" - - def __init__(self, algebra_cpu: CliffordAlgebra, - algebra_ctrl: CliffordAlgebra, - channels: int, - num_steps: int = 8, - max_steps: int = 20, - num_hypotheses: int = 4, - top_k: int = 1, - temperature_init: float = 1.0, - use_act: bool = False, - lambda_p: float = 0.5, - num_attn_heads: int = 4, - attn_head_dim: int = 8, - K_color: int = 4, - num_rule_slots: int = 8): - super().__init__() - self.channels = channels - self.num_steps = num_steps - self.max_steps = max_steps - self.use_act = use_act - - effective_steps = max_steps if use_act else num_steps - self.steps = nn.ModuleList([ - TuringStep( - algebra_cpu, algebra_ctrl, - channels, num_hypotheses, top_k, temperature_init, - num_attn_heads, attn_head_dim, 0.0, - K_color, num_rule_slots, - ) - for _ in range(effective_steps) - ]) - - self.adaptive_halt = AdaptiveHalt(lambda_p, max_steps) if use_act else None - self.final_norm = CliffordLayerNorm(algebra_cpu, 1) - - def set_temperature(self, tau: float): - for step in self.steps: - step.set_temperature(tau) - - def forward(self, cpu_state: torch.Tensor, ctrl_cursor: torch.Tensor, - mask: torch.Tensor = None, - return_trace: bool = False, - rule_memory: torch.Tensor = None) -> tuple: - """Execute the GTM program. - - Args: - cpu_state: Initial CPU state [B, N, 16]. - ctrl_cursor: Initial control cursor [B, 4]. - mask: Optional validity mask [B, N] (True=valid). - return_trace: If True, collect per-step diagnostics. - rule_memory: Optional [B, M, 16] rule slots from RuleAggregator. - - Returns: - Tuple of (cpu_state, ctrl_cursor, act_info or None, trace or None). 
- """ - trace = { - 'search_scores': [], - 'search_weights': [], - 'halt_probs': [], - 'cursors': [], - 'gate_values': [], - } if return_trace else None - - if self.use_act: - return self._forward_act(cpu_state, ctrl_cursor, mask, trace, rule_memory) - else: - return self._forward_fixed(cpu_state, ctrl_cursor, mask, trace, rule_memory) - - def _forward_fixed(self, cpu_state, ctrl_cursor, mask, trace, rule_memory): - """Fixed-step execution.""" - for i in range(self.num_steps): - result = self.steps[i](cpu_state, ctrl_cursor, mask, rule_memory) - cpu_state = result['cpu_state'] - ctrl_cursor = result['ctrl_cursor'] - - if trace is not None: - trace['search_scores'].append(result['search_info']['scores'].detach()) - trace['search_weights'].append(result['search_info']['weights'].detach()) - trace['halt_probs'].append(result['halt_prob'].detach()) - trace['cursors'].append(ctrl_cursor.detach()) - trace['gate_values'].append(result['gate_values'].detach()) - - # Final norm - B, N, D = cpu_state.shape - cpu_state = self.final_norm( - cpu_state.reshape(B * N, 1, D) - ).reshape(B, N, D) - - return cpu_state, ctrl_cursor, None, trace - - def _forward_act(self, cpu_state, ctrl_cursor, mask, trace, rule_memory): - """Adaptive computation with PonderNet halting.""" - per_step_outputs = [] - halt_probs = [] - - for i, step in enumerate(self.steps): - result = step(cpu_state, ctrl_cursor, mask, rule_memory) - cpu_state = result['cpu_state'] - ctrl_cursor = result['ctrl_cursor'] - - per_step_outputs.append(cpu_state) - halt_probs.append(result['halt_prob']) - - if trace is not None: - trace['search_scores'].append(result['search_info']['scores'].detach()) - trace['search_weights'].append(result['search_info']['weights'].detach()) - trace['halt_probs'].append(result['halt_prob'].detach()) - trace['cursors'].append(ctrl_cursor.detach()) - trace['gate_values'].append(result['gate_values'].detach()) - - # Compute ACT mixing weights - act_result = self.adaptive_halt(halt_probs) - 
weights = act_result['weights'] # [B, T] - - # Weighted sum of per-step CPU states via einsum (no Python loop) - stacked = torch.stack(per_step_outputs, dim=1) # [B, T, N, D] - output = torch.einsum('bt,btnd->bnd', weights, stacked) - - # Final norm - B, N, D = output.shape - output = self.final_norm( - output.reshape(B * N, 1, D) - ).reshape(B, N, D) - - act_info = { - 'kl_loss': act_result['kl_loss'], - 'expected_steps': act_result['expected_steps'], - 'weights': act_result['weights'], - } - - # ctrl_cursor is from last step (not mixed — control is sequential) - return output, ctrl_cursor, act_info, trace diff --git a/models/gtm/world_model.py b/models/gtm/world_model.py new file mode 100644 index 0000000..056972c --- /dev/null +++ b/models/gtm/world_model.py @@ -0,0 +1,426 @@ +# Versor: Universal Geometric Algebra Neural Network +# Copyright (C) 2026 Eunkyum Kim +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# + +"""WorldModel + WorldModelStep + CellAttention: the core computation loop. + +CellAttention provides spatial context via cross-grade self-attention. +WorldModelStep chains attention, action proposals, search, modulation, +gating, and rotor accumulation. WorldModel wraps the step loop with +log-manifold stability and FIM-based adaptive halt. 
+""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from core.algebra import CliffordAlgebra +from layers.primitives.normalization import CliffordLayerNorm +from .search_plane import AlgebraicProjection, AlgebraicLift, SearchPlane +from .action_engine import ActionEngine +from .info_geometry import FIMEvaluator +from .log_manifold import LogManifoldProjector +from .adaptive_halt import FIMAdaptiveHalt + + +_GRADE_MAP_16 = torch.zeros(16, dtype=torch.long) +_GRADE_MAP_16[0] = 0 +_GRADE_MAP_16[[1, 2, 4, 8]] = 1 +_GRADE_MAP_16[[3, 5, 6, 9, 10, 12]] = 2 +_GRADE_MAP_16[[7, 11, 13, 14]] = 3 +_GRADE_MAP_16[15] = 4 + + +class CellAttention(nn.Module): + """Cross-grade self-attention over grid cells in Cl(3,0,1). + + Multi-head attention where Q/K are projected from the full multivector + and values are the raw multivector with per-grade learnable gains. + """ + + def __init__(self, algebra_cpu: CliffordAlgebra, num_heads: int = 4, + head_dim: int = 8, dropout: float = 0.0): + super().__init__() + D = algebra_cpu.dim + attn_dim = num_heads * head_dim + + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = head_dim ** -0.5 + + self.q_proj = nn.Linear(D, attn_dim) + self.k_proj = nn.Linear(D, attn_dim) + self.v_gain = nn.ParameterDict({ + f'g{k}': nn.Parameter(torch.ones(1)) for k in range(5) + }) + self.dropout = nn.Dropout(dropout) + self.register_buffer('grade_map', _GRADE_MAP_16.clone()) + + def _apply_grade_gains(self, x: torch.Tensor) -> torch.Tensor: + """Apply per-grade isotropic gains to multivector components.""" + gains = torch.ones(16, device=x.device, dtype=x.dtype) + for k in range(5): + mask = self.grade_map == k + gains[mask] = self.v_gain[f'g{k}'] + return x * gains + + def forward(self, x: torch.Tensor, + mask: torch.Tensor = None) -> torch.Tensor: + """[B, N, 16] -> [B, N, 16] with optional mask [B, N].""" + B, N, D = x.shape + + Q = self.q_proj(x) + K = self.k_proj(x) + + Q = Q.reshape(B, N, self.num_heads, 
self.head_dim).transpose(1, 2) + K = K.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2) + + scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale + + if mask is not None: + scores = scores.masked_fill( + (~mask).unsqueeze(1).unsqueeze(2), float('-inf'), + ) + + attn = F.softmax(scores, dim=-1) + attn = self.dropout(attn) + attn_avg = attn.mean(dim=1) + + attended = torch.bmm(attn_avg, x) + return self._apply_grade_gains(attended) + + +def _identity_rotor(batch_size: int, dim: int, device: torch.device, + dtype: torch.dtype) -> torch.Tensor: + """Create identity rotor (scalar=1, rest=0).""" + R = torch.zeros(batch_size, dim, device=device, dtype=dtype) + R[:, 0] = 1.0 + return R + + +def _normalize_rotor(algebra: CliffordAlgebra, R: torch.Tensor) -> torch.Tensor: + """Normalize rotor to stay on the Spin group: R / sqrt(|R R~|_0).""" + R_rev = algebra.reverse(R) + sq = algebra.geometric_product(R, R_rev)[..., 0:1].abs().clamp(min=1e-6) + return R / sq.sqrt() + + +class WorldModelStep(nn.Module): + """One step of the World Model computation loop. + + Chains cell attention, action proposals, FIM scoring, search-plane + evolution, weighted candidate selection, grade-wise modulation, + gated residual update, and rotor accumulation. 
+ """ + + def __init__(self, algebra_cpu: CliffordAlgebra, + algebra_ctrl: CliffordAlgebra, + num_hypotheses: int = 8, + num_attn_heads: int = 4, + attn_head_dim: int = 8, + num_rule_slots: int = 8, + evolve_hidden: int = 64, + gate_init: float = 0.0, + use_supervised_fim: bool = True): + super().__init__() + D = algebra_cpu.dim # 16 + self._algebra_cpu = algebra_cpu + + # Attention + self.cell_attn = CellAttention( + algebra_cpu, num_attn_heads, attn_head_dim, + ) + + # Algebraic projection/lift + self.phi = AlgebraicProjection(algebra_cpu) + self.psi = AlgebraicLift(algebra_cpu) + + # Search plane + self.search_plane = SearchPlane( + algebra_ctrl, num_hypotheses, evolve_hidden, + ) + + # Action engine + self.action_engine = ActionEngine( + algebra_cpu, num_hypotheses, gate_init, + ) + + # FIM evaluator + self.fim_evaluator = FIMEvaluator(algebra_cpu) + self.use_supervised_fim = use_supervised_fim + + # Write gate — spatial components (grade 1+) + self.write_gate = nn.Sequential( + nn.Linear(D * 2, 64), + nn.ReLU(), + nn.Linear(64, D), + ) + # Separate color write gate for grade-0, so spatial gate can't suppress + # color flow. Shared MLP would learn to block grade-0 because motor + # transforms can't change scalars, making the blended grade-0 noisy. + self.color_write_gate = nn.Sequential( + nn.Linear(2, 16), + nn.ReLU(), + nn.Linear(16, 1), + ) + + # Normalization + self.norm = CliffordLayerNorm(algebra_cpu, 1) + + def set_temperature(self, tau: float): + self.search_plane.set_temperature(tau) + + def forward(self, state: torch.Tensor, hypotheses: torch.Tensor, + R_accum: torch.Tensor, mask: torch.Tensor = None, + rule_memory: torch.Tensor = None, + fim_prev: torch.Tensor = None, + targets: torch.Tensor = None) -> dict: + """Execute one world model step. + + Args: + state: Current state (mantissa) [B, N, D]. + hypotheses: Current hypotheses [B, K, 4]. + R_accum: Accumulated rotor [B, D]. + mask: Validity mask [B, N]. + rule_memory: Optional rule slots [B, M, D]. 
+ fim_prev: Previous FIM values [B, K] or None. + targets: Optional target colors [B, N] for supervised FIM. + + Returns: + dict with world_state, hypotheses, R_accum, fim_values, search_info, gate. + """ + B, N, D = state.shape + old = state + + attended = self.cell_attn(state, mask) # [B, N, D] + + if mask is not None: + mask_f = mask.float().unsqueeze(-1) # [B, N, 1] + world_summary = (attended * mask_f).sum(dim=1) / mask_f.sum(dim=1).clamp(min=1.0) + else: + world_summary = attended.mean(dim=1) # [B, D] + + candidates = self.action_engine.propose_all( + attended, hypotheses, rule_memory, + ) # [B, K, N, D] + + if self.training and self.use_supervised_fim and targets is not None: + fim_values = self.fim_evaluator.supervised_fim(candidates, targets, mask) + else: + fim_values = self.fim_evaluator.fim_proxy(candidates, mask) + + search_result = self.search_plane( + hypotheses, world_summary, fim_values, fim_prev, + ) + weights = search_result['weights'] # [B, K] + hypotheses = search_result['hypotheses'] + + new_state = torch.einsum('bk,bknd->bnd', weights, candidates) # [B, N, D] + + modulation = self.psi(hypotheses, weights) # [B, D] + new_state = new_state * modulation.unsqueeze(1) + + gate_input = torch.cat([old, new_state], dim=-1) + gate = torch.sigmoid(self.write_gate(gate_input)) # [B, N, D] + + # Separate color gate for grade-0 — avoids clone() by computing + # the residual blend per-component before combining + color_gate_in = torch.stack([old[:, :, 0], new_state[:, :, 0]], dim=-1) + color_gate = torch.sigmoid(self.color_write_gate(color_gate_in)) # [B, N, 1] + + new_state_spatial = gate[:, :, 1:] * new_state[:, :, 1:] + (1.0 - gate[:, :, 1:]) * old[:, :, 1:] + new_state_color = color_gate * new_state[:, :, 0:1] + (1.0 - color_gate) * old[:, :, 0:1] + new_state = torch.cat([new_state_color, new_state_spatial], dim=-1) + + new_state = self.norm( + new_state.reshape(B * N, 1, D), + ).reshape(B, N, D) + + # Detach: chained geometric products create 
exponentially deep + # graphs that amplify gradients. R_accum is memory output only. + self._algebra_cpu.ensure_device(state.device) + R_t = self.action_engine.get_combined_rotor(weights) # [B, D] + R_accum = self._algebra_cpu.geometric_product(R_t, R_accum.detach()) + R_accum = _normalize_rotor(self._algebra_cpu, R_accum) + + return { + 'world_state': new_state, + 'hypotheses': hypotheses, + 'R_accum': R_accum, + 'fim_values': search_result['fim_values'], + 'search_info': search_result, + 'gate': gate, + } + + +class WorldModel(nn.Module): + """Main World Model: chains WorldModelSteps with log-manifold stability. + + Splits input into mantissa/exponent, processes mantissa through T steps, + then merges back. Supports FIM-based adaptive halt at inference and + FIM-weighted mixing at training (Phase 3). + """ + + def __init__(self, algebra_cpu: CliffordAlgebra, + algebra_ctrl: CliffordAlgebra, + num_steps: int = 12, + max_steps: int = 24, + num_hypotheses: int = 8, + num_attn_heads: int = 4, + attn_head_dim: int = 8, + num_rule_slots: int = 8, + evolve_hidden: int = 64, + gate_init: float = 0.0, + log_manifold_gate_init: float = -5.0, + halt_eps: float = 0.01, + use_supervised_fim: bool = True, + weight_share_steps: bool = False): + super().__init__() + self.num_steps = num_steps + self.max_steps = max_steps + self.num_hypotheses = num_hypotheses + self._algebra_cpu = algebra_cpu + + D = algebra_cpu.dim + + # Log-manifold projector + self.log_projector = LogManifoldProjector(algebra_cpu, log_manifold_gate_init) + + # World model steps + if weight_share_steps: + shared_step = WorldModelStep( + algebra_cpu, algebra_ctrl, + num_hypotheses, num_attn_heads, attn_head_dim, + num_rule_slots, evolve_hidden, gate_init, + use_supervised_fim, + ) + self.steps = nn.ModuleList([shared_step] * num_steps) + else: + self.steps = nn.ModuleList([ + WorldModelStep( + algebra_cpu, algebra_ctrl, + num_hypotheses, num_attn_heads, attn_head_dim, + num_rule_slots, evolve_hidden, gate_init, 
+ use_supervised_fim, + ) + for _ in range(num_steps) + ]) + + # Initial hypotheses + self.hypothesis_init = nn.Parameter(torch.randn(num_hypotheses, 4) * 0.1) + + # FIM adaptive halt + self.fim_halt = FIMAdaptiveHalt(halt_eps) + self.use_fim_halt = False + + # Final norm + self.final_norm = CliffordLayerNorm(algebra_cpu, 1) + + def set_temperature(self, tau: float): + for step in self.steps: + step.set_temperature(tau) + + def forward(self, cpu_state: torch.Tensor, + mask: torch.Tensor = None, + rule_memory: torch.Tensor = None, + targets: torch.Tensor = None, + return_trace: bool = False) -> dict: + """Execute the world model loop. + + Args: + cpu_state: Initial CPU state [B, N, D] from GridCodec. + mask: Validity mask [B, N]. + rule_memory: Optional rule slots [B, M, D]. + targets: Optional target colors [B, N] for supervised FIM. + return_trace: Collect per-step diagnostics. + + Returns: + dict with output, hypotheses, R_accum, step_outputs, etc. + """ + B, N, D = cpu_state.shape + device = cpu_state.device + dtype = cpu_state.dtype + + # Split into mantissa + exponent + mantissa, exponent = self.log_projector.split(cpu_state) + + # Initialize + hypotheses = self.hypothesis_init.unsqueeze(0).expand(B, -1, -1).clone() + R_accum = _identity_rotor(B, D, device, dtype) + fim_prev = None + + step_outputs = [] + step_deltas = [] + step_weights = [] + trace = { + 'search_info': [], + 'gate_values': [], + 'fim_values': [], + } if return_trace else None + + for t, step in enumerate(self.steps): + result = step( + mantissa, hypotheses, R_accum, mask, + rule_memory, fim_prev, targets, + ) + mantissa = result['world_state'] + hypotheses = result['hypotheses'] + R_accum = result['R_accum'] + fim_prev = result['fim_values'] + + step_outputs.append(mantissa) + + search_info = result['search_info'] + step_deltas.append(search_info['delta_info']) + step_weights.append(search_info['weights']) + + if trace is not None: + trace['search_info'].append({ + k: v.detach() if 
torch.is_tensor(v) else v + for k, v in search_info.items() + }) + trace['gate_values'].append(result['gate'].detach()) + trace['fim_values'].append(result['fim_values'].detach()) + + # FIM-based halt (inference only) + if not self.training and t > 0: + delta = search_info['delta_info'] + w = search_info['weights'] + weighted_delta = (delta * w).sum(dim=-1).mean() + if weighted_delta < self.fim_halt.halt_eps: + break + + # Merge mantissa with exponent + output = self.log_projector.merge(mantissa, exponent) + + # Final norm + output = self.final_norm( + output.reshape(B * N, 1, D), + ).reshape(B, N, D) + + # FIM-weighted mixing during training + mixing_weights = None + if self.training and self.use_fim_halt and len(step_deltas) > 1: + halt_result = self.fim_halt(step_deltas, step_weights) + mixing_weights = halt_result['mixing_weights'] # [B, T] + + # Weighted sum of per-step mantissa outputs + stacked = torch.stack(step_outputs, dim=1) # [B, T, N, D] + mixed_mantissa = torch.einsum('bt,btnd->bnd', mixing_weights, stacked) + output = self.log_projector.merge(mixed_mantissa, exponent) + output = self.final_norm( + output.reshape(B * N, 1, D), + ).reshape(B, N, D) + + return { + 'output': output, + 'hypotheses': hypotheses, + 'R_accum': R_accum, + 'step_outputs': step_outputs, + 'step_deltas': step_deltas, + 'step_weights': step_weights, + 'mixing_weights': mixing_weights, + 'trace': trace, + } From 663a5fb52dc7516dbd8f67b9a2e4de8744ec8ccc Mon Sep 17 00:00:00 2001 From: Concode0 Date: Sat, 21 Mar 2026 11:41:03 +0900 Subject: [PATCH 07/16] chore: update config files need test --- conf/task/gtm.yaml | 45 +++++++++++++++++++++++++++++---------------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/conf/task/gtm.yaml b/conf/task/gtm.yaml index 7a8d51e..7c38e4b 100644 --- a/conf/task/gtm.yaml +++ b/conf/task/gtm.yaml @@ -12,16 +12,26 @@ model: num_steps: 12 max_steps: 24 num_hypotheses: 8 - top_k: 1 coord_scale: 1.0 head_hidden: 128 - gumbel_temperature: 1.0 
num_rule_slots: 8 - act: - enabled: true - lambda_p: 0.5 - color_unit: - K_color: 4 + num_memory_channels: 4 + weight_share_steps: false + + log_manifold: + gate_init: -5.0 + + search_plane: + conviction_threshold: 0.9 + evolve_hidden: 64 + + info_geometry: + halt_eps: 0.01 + use_supervised_fim: true + + action_engine: + gate_init: 0.0 + attention: num_heads: 4 head_dim: 8 @@ -32,7 +42,7 @@ dataset: toy_n_examples: 20000 toy_max_grid_size: 15 num_demos: 3 - epoch_samples: 4000 # 0 = full dataset shuffle; set >0 for capped-epoch sampling + epoch_samples: 4000 training: epochs: 150 @@ -40,6 +50,7 @@ training: batch_size: 16 optimizer_type: riemannian_adam max_bivector_norm: 10.0 + grad_clip: 1.0 # CUDA acceleration num_workers: 4 @@ -48,16 +59,18 @@ training: compile: false cudnn_benchmark: true - # Three-phase schedule (scaled for 150 epochs) + # Three-phase schedule warmup_epochs: 8 trim_epochs: 72 act_epochs: 70 - act_weight: 0.01 - act_ramp_epochs: 20 - gate_entropy_weight: 0.01 - grad_clip: 1.0 - eval_every: 5 + + # Temperature schedule tau_start: 1.0 - tau_act_restart: 0.7 - tau_end: 0.1 + tau_mid: 0.5 + tau_end: 0.05 + # Loss weights + ortho_weight: 0.005 + gate_entropy_weight: 0.01 + info_gain_weight: 0.01 + eval_every: 5 From 5997260dd6a9ec43ef906ab9e70e63b41eb5c758 Mon Sep 17 00:00:00 2001 From: Concode0 Date: Sat, 21 Mar 2026 11:41:35 +0900 Subject: [PATCH 08/16] feat: change task structure for using updated gtm --- tasks/gtm.py | 179 +++++++++++++++++++++++++++------------------------ 1 file changed, 95 insertions(+), 84 deletions(-) diff --git a/tasks/gtm.py b/tasks/gtm.py index bf47f9f..856f691 100644 --- a/tasks/gtm.py +++ b/tasks/gtm.py @@ -5,20 +5,14 @@ # you may not use this file except in compliance with the License. # -"""Geometric Turing Machine Task — ARC-AGI v4. +"""Geometric Turing Machine Task. Few-shot format: each training example = (demo_pairs, test_input, test_output). 
The model sees K demo (input,output) pairs to infer the rule, then applies it to a test input to produce the test output. -Three-phase training (anti-lazy-optimization): -1. Warmup: freeze VM, train head + init_cursor + role_embed -2. Circuit Search: unfreeze VM, fixed steps, gate entropy loss -3. ACT: enable adaptive computation, KL ramp-up - -Two algebras (Mother algebra removed): - CPU Cl(3,0,1): PGA computation engine (motor + color) - Control Cl(1,1): learnable search +Three-phase training: warmup (WorldModel frozen), world model training +(ortho + gate entropy losses), FIM halt + conviction collapse. """ import torch @@ -27,71 +21,76 @@ from core.algebra import CliffordAlgebra from tasks.base import BaseTask from models.gtm import GTMNet +from models.gtm.search_plane import SearchPlane from log import get_logger logger = get_logger(__name__) -def _gate_entropy_loss(scores: torch.Tensor) -> torch.Tensor: - """Entropy of search scores — minimizing this encourages instruction specialization.""" +def _gate_entropy_loss(gate_values: torch.Tensor) -> torch.Tensor: + """Entropy of write gate values — encourages decisive gating.""" eps = 1e-8 - probs = torch.softmax(scores, dim=-1) - entropy = -(probs * torch.log(probs + eps)).sum(dim=-1) + p = gate_values.mean(dim=(0, 1)) # [D] average gate per component + entropy = -(p * torch.log(p + eps) + (1 - p) * torch.log(1 - p + eps)) return entropy.mean() class GTMTask(BaseTask): - """Geometric Turing Machine task for ARC-AGI v4.""" + """Geometric Turing Machine task for ARC-AGI.""" def __init__(self, cfg): # Training phase config - self.warmup_epochs = cfg.training.get('warmup_epochs', 5) - self.trim_epochs = cfg.training.get('trim_epochs', 50) - self.act_epochs = cfg.training.get('act_epochs', 45) - self.act_weight = cfg.training.get('act_weight', 0.01) - self.act_ramp_epochs = cfg.training.get('act_ramp_epochs', 15) - self.gate_entropy_weight = cfg.training.get('gate_entropy_weight', 0.01) + self.warmup_epochs = 
cfg.training.get('warmup_epochs', 8) + self.trim_epochs = cfg.training.get('trim_epochs', 72) + self.act_epochs = cfg.training.get('act_epochs', 70) self.grad_clip = cfg.training.get('grad_clip', 1.0) self.eval_every = cfg.training.get('eval_every', 5) - # Gumbel temperature annealing schedule + # Temperature schedule self.tau_start = cfg.training.get('tau_start', 1.0) - self.tau_end = cfg.training.get('tau_end', 0.1) - # Warm restart at Phase 3: steps[num_steps:max_steps] are untrained, - # need high tau for exploration before annealing down - self.tau_act_restart = cfg.training.get('tau_act_restart', 0.7) + self.tau_mid = cfg.training.get('tau_mid', 0.5) + self.tau_end = cfg.training.get('tau_end', 0.05) + + # Loss weights + self.ortho_weight = cfg.training.get('ortho_weight', 0.005) + self.gate_entropy_weight = cfg.training.get('gate_entropy_weight', 0.01) + self.info_gain_weight = cfg.training.get('info_gain_weight', 0.01) super().__init__(cfg) def setup_algebra(self): - """Initialize CPU and Control algebras. 
Returns CPU algebra for BaseTask.""" + """Initialize CPU and Control algebras.""" self.algebra_cpu = CliffordAlgebra(3, 0, 1, device=self.device) self.algebra_ctrl = CliffordAlgebra(1, 1, 0, device=self.device) return self.algebra_cpu def setup_model(self): mcfg = self.cfg.model - act_cfg = mcfg.get('act', {}) - color_cfg = mcfg.get('color_unit', {}) attn_cfg = mcfg.get('attention', {}) + sp_cfg = mcfg.get('search_plane', {}) + lm_cfg = mcfg.get('log_manifold', {}) + ig_cfg = mcfg.get('info_geometry', {}) + ae_cfg = mcfg.get('action_engine', {}) return GTMNet( algebra_cpu=self.algebra_cpu, algebra_ctrl=self.algebra_ctrl, - channels=mcfg.get('channels', 16), - num_steps=mcfg.get('num_steps', 8), - max_steps=mcfg.get('max_steps', 20), - num_hypotheses=mcfg.get('num_hypotheses', 4), - top_k=mcfg.get('top_k', 1), - head_hidden=mcfg.get('head_hidden', 64), - temperature_init=mcfg.get('gumbel_temperature', 1.0), - use_act=act_cfg.get('enabled', True), - lambda_p=act_cfg.get('lambda_p', 0.5), + channels=mcfg.get('channels', 32), + num_steps=mcfg.get('num_steps', 12), + max_steps=mcfg.get('max_steps', 24), + num_hypotheses=mcfg.get('num_hypotheses', 8), + head_hidden=mcfg.get('head_hidden', 128), coord_scale=mcfg.get('coord_scale', 1.0), - K_color=color_cfg.get('K_color', 4), num_attn_heads=attn_cfg.get('num_heads', 4), attn_head_dim=attn_cfg.get('head_dim', 8), num_rule_slots=mcfg.get('num_rule_slots', 8), + num_memory_channels=mcfg.get('num_memory_channels', 4), + weight_share_steps=mcfg.get('weight_share_steps', False), + log_manifold_gate_init=lm_cfg.get('gate_init', -5.0), + evolve_hidden=sp_cfg.get('evolve_hidden', 64), + halt_eps=ig_cfg.get('halt_eps', 0.01), + use_supervised_fim=ig_cfg.get('use_supervised_fim', True), + action_gate_init=ae_cfg.get('gate_init', 0.0), ) def _setup_optimizer(self): @@ -151,10 +150,16 @@ def _run_model(self, batch, return_trace=False): test_masks = batch['test_masks'].to(self.device) num_demos = batch['num_demos'].to(self.device) + # 
Pass targets for supervised FIM during training + test_targets = None + if self.model.training and 'test_outputs' in batch: + test_targets = batch['test_outputs'].to(self.device) + return self.model( demo_inputs, demo_outputs, demo_masks, test_inputs, test_masks, num_demos, demo_output_masks=demo_output_masks, + test_targets=test_targets, input_sizes=batch.get('input_sizes'), return_trace=return_trace, ) @@ -167,31 +172,50 @@ def train_step(self, batch): logits = result['logits'] # [B, N_grid, 10] - # Target: test output grid flattened - test_outputs = batch['test_outputs'].to(self.device) # [B, H_max, W_max] + # Target + test_outputs = batch['test_outputs'].to(self.device) B, H_max, W_max = test_outputs.shape - targets = test_outputs.reshape(B, H_max * W_max) # [B, N_grid] + targets = test_outputs.reshape(B, H_max * W_max) loss = self.criterion( logits.reshape(-1, 10), targets.reshape(-1), ) - # ACT KL loss (Phase 3 only) - act_kl = torch.tensor(0.0, device=self.device) - if 'act_info' in result and result['act_info'] is not None: - act_kl = result['act_info']['kl_loss'] - loss = loss + self._current_act_weight * act_kl + wm_info = result.get('world_model_info', {}) + + ortho_loss = torch.tensor(0.0, device=self.device) + if self._phase >= 2 and self.ortho_weight > 0: + # Get Gram from last step's search info via trace + if 'trace' in result and result['trace'] is not None: + test_trace = result['trace'].get('test') + if test_trace and test_trace['search_info']: + last_search = test_trace['search_info'][-1] + gram = last_search.get('gram') + if gram is not None: + ortho_loss = SearchPlane.orthogonality_loss(gram) + loss = loss + self.ortho_weight * ortho_loss - # Gate entropy loss (Phases 2-3) gate_ent = torch.tensor(0.0, device=self.device) if need_trace and 'trace' in result and result['trace'] is not None: - trace = result['trace'] - if trace['search_scores']: - ent_sum = sum(_gate_entropy_loss(s) for s in trace['search_scores']) - gate_ent = ent_sum / 
len(trace['search_scores']) + test_trace = result['trace'].get('test') + if test_trace and test_trace['gate_values']: + ent_sum = sum(_gate_entropy_loss(g) for g in test_trace['gate_values']) + gate_ent = ent_sum / len(test_trace['gate_values']) loss = loss + self.gate_entropy_weight * gate_ent + # Penalize negative info gain (monotonic progress) + info_loss = torch.tensor(0.0, device=self.device) + if self._phase >= 3 and self.info_gain_weight > 0: + step_deltas = wm_info.get('step_deltas', []) + if step_deltas: + all_deltas = torch.stack( + [(d).mean() for d in step_deltas] + ) + # Penalize negative information gain (want monotonic progress) + info_loss = torch.relu(-all_deltas).mean() + loss = loss + self.info_gain_weight * info_loss + self._backward(loss) if self.grad_clip > 0: @@ -202,10 +226,12 @@ def train_step(self, batch): self._optimizer_step() logs = {'Loss': loss.item()} - if act_kl.item() > 0: - logs['ACT_KL'] = act_kl.item() + if ortho_loss.item() > 0: + logs['Ortho'] = ortho_loss.item() if gate_ent.item() != 0: logs['GateEnt'] = gate_ent.item() + if info_loss.item() > 0: + logs['InfoGain'] = info_loss.item() return loss.item(), logs def evaluate(self, val_loader): @@ -218,8 +244,8 @@ def evaluate(self, val_loader): with torch.no_grad(): for batch in val_loader: result = self._run_model(batch) - logits = result['logits'] # [B, N_grid, 10] - preds = logits.argmax(dim=-1) # [B, N_grid] + logits = result['logits'] + preds = logits.argmax(dim=-1) test_outputs = batch['test_outputs'].to(self.device) test_masks = batch['test_masks'].to(self.device) @@ -227,12 +253,10 @@ def evaluate(self, val_loader): targets = test_outputs.reshape(B, H_max * W_max) valid = test_masks.reshape(B, H_max * W_max) - # Cell accuracy (non-padded cells only) matches = (preds == targets) & valid cell_correct += matches.sum().item() cell_total += valid.sum().item() - # Grid accuracy (entire grid must match) test_sizes = batch['test_sizes'] for i in range(B): toH, toW = test_sizes[i] 
@@ -251,15 +275,14 @@ def visualize(self, val_loader): pass def run(self): - """Three-phase training loop with ACT ramp-up.""" - logger.info("Starting GTM ARC-AGI v4 Task") + """Three-phase training loop with FIM-based computation budget.""" + logger.info("Starting GTM training") train_loader, val_loader = self.get_data() total_epochs = self.warmup_epochs + self.trim_epochs + self.act_epochs self.epochs = total_epochs self._phase = 0 - self._current_act_weight = 0.0 best_val_metric = 0.0 metric_key = 'cell_accuracy' @@ -276,41 +299,31 @@ def run(self): if phase != self._phase: self._phase = phase if phase == 1: - logger.info("Phase 1: Warmup (VM frozen, train head + init_cursor)") - self.model.freeze_vm() - self.model.disable_act() + logger.info("Phase 1: Warmup (WorldModel frozen)") + self.model.freeze_world_model() + self.model.disable_fim_halt() elif phase == 2: - logger.info("Phase 2: Circuit Search (fixed steps)") - self.model.unfreeze_vm() - self.model.disable_act() + logger.info("Phase 2: World Model Training") + self.model.unfreeze_world_model() + self.model.disable_fim_halt() elif phase == 3: - act_cfg = self.cfg.model.get('act', {}) - if act_cfg.get('enabled', True): - logger.info("Phase 3: ACT activation (adaptive computation)") - self.model.enable_act() - else: - logger.info("Phase 3: Extended training (ACT disabled)") + logger.info("Phase 3: FIM Halt + Conviction Collapse") + self.model.enable_fim_halt() + # Rebuild optimizer for new trainable params self.optimizer = self._setup_optimizer() self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( self.optimizer, mode='min', factor=0.5, patience=10) - # ACT weight ramp - if phase == 3: - act_epoch = epoch - (self.warmup_epochs + self.trim_epochs) - ramp = min(1.0, act_epoch / self.act_ramp_epochs) if self.act_ramp_epochs > 0 else 1.0 - self._current_act_weight = self.act_weight * ramp - else: - self._current_act_weight = 0.0 - + # Temperature schedule if phase == 1: tau = self.tau_start elif phase 
== 2: progress = min(1.0, (epoch - self.warmup_epochs) / max(self.trim_epochs, 1)) - tau = self.tau_start + (self.tau_act_restart - self.tau_start) * progress + tau = self.tau_start + (self.tau_mid - self.tau_start) * progress else: # phase 3 act_epoch = epoch - (self.warmup_epochs + self.trim_epochs) progress = min(1.0, act_epoch / max(self.act_epochs, 1)) - tau = self.tau_act_restart + (self.tau_end - self.tau_act_restart) * progress + tau = self.tau_mid + (self.tau_end - self.tau_mid) * progress self.model.set_temperature(tau) # Training @@ -348,8 +361,6 @@ def run(self): 'LR': self.optimizer.param_groups[0]['lr'], 'tau': tau, } - if self._current_act_weight > 0: - display['ACT_w'] = self._current_act_weight desc = " | ".join( f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}" for k, v in display.items() From efc3449a57dd8390d4718192b23d292f3e9c6d5e Mon Sep 17 00:00:00 2001 From: Concode0 Date: Sat, 21 Mar 2026 11:41:58 +0900 Subject: [PATCH 09/16] fix: relax the hermitian_norm zero-input test tolerance --- tests/test_hermitian_metrics.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_hermitian_metrics.py b/tests/test_hermitian_metrics.py index 49a0d0f..2689474 100644 --- a/tests/test_hermitian_metrics.py +++ b/tests/test_hermitian_metrics.py @@ -145,7 +145,8 @@ def test_non_negative(self, algebra_minkowski): def test_zero_for_zero(self, algebra_3d): mv = torch.zeros(algebra_3d.dim) n = hermitian_norm(algebra_3d, mv) - assert torch.allclose(n, torch.tensor([0.0])) + # hermitian_norm clamps at 1e-12 before sqrt for gradient safety + assert n < 1e-5 def test_positive_for_nonzero(self, algebra_3d): mv = torch.randn(algebra_3d.dim) From 3976012e121f2b70681f41305672ff73cec10d66 Mon Sep 17 00:00:00 2001 From: Concode0 Date: Sat, 21 Mar 2026 19:49:45 +0900 Subject: [PATCH 10/16] fix: increase tau_end to 0.2 to prevent wrong crystallization and early GSS collapse --- conf/task/gtm.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/conf/task/gtm.yaml b/conf/task/gtm.yaml index 7c38e4b..9f77016 100644 --- a/conf/task/gtm.yaml +++ b/conf/task/gtm.yaml @@ -67,7 +67,7 @@ training: # Temperature schedule tau_start: 1.0 tau_mid: 0.5 - tau_end: 0.05 + tau_end: 0.2 # Loss weights ortho_weight: 0.005 From b3d9b3e9b5f33dc45446c6db754cf07ac82778cd Mon Sep 17 00:00:00 2001 From: Concode0 Date: Sat, 21 Mar 2026 19:50:41 +0900 Subject: [PATCH 11/16] fix: adopt blend FIM info for gradual changing --- models/gtm/world_model.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/models/gtm/world_model.py b/models/gtm/world_model.py index 056972c..eb9bc78 100644 --- a/models/gtm/world_model.py +++ b/models/gtm/world_model.py @@ -314,6 +314,8 @@ def __init__(self, algebra_cpu: CliffordAlgebra, # FIM adaptive halt self.fim_halt = FIMAdaptiveHalt(halt_eps) self.use_fim_halt = False + # Ramp for gradual FIM mixing blend (0=last-step only, 1=full FIM mix) + self.fim_mix_ramp = 0.0 # Final norm self.final_norm = CliffordLayerNorm(algebra_cpu, 1) @@ -400,16 +402,21 @@ def forward(self, cpu_state: torch.Tensor, output.reshape(B * N, 1, D), ).reshape(B, N, D) - # FIM-weighted mixing during training + # FIM-weighted mixing during training: gradually blend between + # last-step output and FIM-weighted mix using fim_mix_ramp (0->1) mixing_weights = None if self.training and self.use_fim_halt and len(step_deltas) > 1: halt_result = self.fim_halt(step_deltas, step_weights) mixing_weights = halt_result['mixing_weights'] # [B, T] - # Weighted sum of per-step mantissa outputs stacked = torch.stack(step_outputs, dim=1) # [B, T, N, D] mixed_mantissa = torch.einsum('bt,btnd->bnd', mixing_weights, stacked) - output = self.log_projector.merge(mixed_mantissa, exponent) + + # Blend: ramp=0 uses last-step, ramp=1 uses full FIM mix + ramp = self.fim_mix_ramp + blended_mantissa = (1.0 - ramp) * mantissa + ramp * mixed_mantissa + + output = self.log_projector.merge(blended_mantissa, exponent) 
output = self.final_norm( output.reshape(B * N, 1, D), ).reshape(B, N, D) From cf29c770ead9909874226d5846581bfac70586a9 Mon Sep 17 00:00:00 2001 From: Concode0 Date: Sat, 21 Mar 2026 19:51:02 +0900 Subject: [PATCH 12/16] fix: preserve LR across the phase 2->3 transition --- tasks/gtm.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/tasks/gtm.py b/tasks/gtm.py index 856f691..bb90058 100644 --- a/tasks/gtm.py +++ b/tasks/gtm.py @@ -297,6 +297,7 @@ def run(self): phase = 3 if phase != self._phase: + prev_phase = self._phase self._phase = phase if phase == 1: logger.info("Phase 1: Warmup (WorldModel frozen)") @@ -309,8 +310,15 @@ def run(self): elif phase == 3: logger.info("Phase 3: FIM Halt + Conviction Collapse") self.model.enable_fim_halt() - # Rebuild optimizer for new trainable params + + # Preserve LR from previous phase when transitioning 2->3 + # to avoid destabilizing learned representations + prev_lr = (self.optimizer.param_groups[0]['lr'] + if prev_phase > 0 else self.cfg.training.lr) self.optimizer = self._setup_optimizer() + if prev_phase > 0: + for pg in self.optimizer.param_groups: + pg['lr'] = prev_lr self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( self.optimizer, mode='min', factor=0.5, patience=10) @@ -326,6 +334,14 @@ def run(self): tau = self.tau_mid + (self.tau_end - self.tau_mid) * progress self.model.set_temperature(tau) + # FIM mixing ramp: gradually blend in FIM-weighted output over Phase 3 + if phase == 3: + act_epoch = epoch - (self.warmup_epochs + self.trim_epochs) + self.model.world_model.fim_mix_ramp = min( + 1.0, act_epoch / max(self.act_epochs * 0.5, 1)) + else: + self.model.world_model.fim_mix_ramp = 0.0 + # Training self.model.train() total_loss = 0 From 20613ff43271490510c8e22d8aa7782e216cab58 Mon Sep 17 00:00:00 2001 From: Concode0 Date: Sat, 21 Mar 2026 22:01:38 +0900 Subject: [PATCH 13/16] fix: relax write_gate init bias --- models/gtm/world_model.py | 7 +++---- 1 file changed, 3
insertions(+), 4 deletions(-) diff --git a/models/gtm/world_model.py b/models/gtm/world_model.py index eb9bc78..5d71364 100644 --- a/models/gtm/world_model.py +++ b/models/gtm/world_model.py @@ -151,20 +151,19 @@ def __init__(self, algebra_cpu: CliffordAlgebra, self.fim_evaluator = FIMEvaluator(algebra_cpu) self.use_supervised_fim = use_supervised_fim - # Write gate — spatial components (grade 1+) self.write_gate = nn.Sequential( nn.Linear(D * 2, 64), nn.ReLU(), nn.Linear(64, D), ) - # Separate color write gate for grade-0, so spatial gate can't suppress - # color flow. Shared MLP would learn to block grade-0 because motor - # transforms can't change scalars, making the blended grade-0 noisy. + nn.init.constant_(self.write_gate[-1].bias, -3.0) + self.color_write_gate = nn.Sequential( nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 1), ) + nn.init.constant_(self.color_write_gate[-1].bias, -3.0) # Normalization self.norm = CliffordLayerNorm(algebra_cpu, 1) From 65d04e094a40e678adac736af41b958282c3e3ea Mon Sep 17 00:00:00 2001 From: Concode0 Date: Sat, 21 Mar 2026 22:02:09 +0900 Subject: [PATCH 14/16] fix: add analyze_temperature functions --- models/gtm/analysis.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/models/gtm/analysis.py b/models/gtm/analysis.py index 139114c..a4fc966 100644 --- a/models/gtm/analysis.py +++ b/models/gtm/analysis.py @@ -138,6 +138,17 @@ def analyze_action_gate(self) -> dict: 'discrete_dominant': (gate <= 0.5).sum().item(), } + def analyze_temperature(self) -> dict: + """Inspect per-step softmax temperature from SearchPlane buffers.""" + temps = [] + for step in self.model.world_model.steps: + tau = step.search_plane._temperature.item() + temps.append(tau) + return { + 'temperatures': temps, + 'is_sharp': [t < 0.1 for t in temps], + } + def analyze_hypothesis_init(self) -> dict: """Inspect initial hypothesis positions in Cl(1,1).""" h_init = self.model.world_model.hypothesis_init.detach() From 
8dbdf990f5c6c58abc4b97c5f9af500d7e489a9a Mon Sep 17 00:00:00 2001 From: Concode0 Date: Mon, 23 Mar 2026 16:53:24 +0900 Subject: [PATCH 15/16] feat: rework the LR schedule and add more constraints for proper crystallization --- conf/task/gtm.yaml | 7 ++++- models/gtm/action_engine.py | 37 +++++++++++++++++++++---- models/gtm/gtm_net.py | 2 ++ models/gtm/search_plane.py | 21 +++++++------- models/gtm/world_model.py | 55 ++++++++++++++++++++++++++++++++----- 5 files changed, 99 insertions(+), 23 deletions(-) diff --git a/conf/task/gtm.yaml b/conf/task/gtm.yaml index 9f77016..748fcce 100644 --- a/conf/task/gtm.yaml +++ b/conf/task/gtm.yaml @@ -17,6 +17,7 @@ model: num_rule_slots: 8 num_memory_channels: 4 weight_share_steps: false + gradient_horizon: 2 log_manifold: gate_init: -5.0 @@ -51,14 +52,18 @@ training: optimizer_type: riemannian_adam max_bivector_norm: 10.0 grad_clip: 1.0 + min_lr: 1.0e-5 # CUDA acceleration num_workers: 4 pin_memory: true - amp: true + amp: false compile: false cudnn_benchmark: true + # Phase 2 LR cap (fraction of base LR) — prevents instability in 12-step model + phase2_lr_scale: 0.5 + # Three-phase schedule warmup_epochs: 8 trim_epochs: 72 diff --git a/models/gtm/action_engine.py b/models/gtm/action_engine.py index 59abe23..6b934b7 100644 --- a/models/gtm/action_engine.py +++ b/models/gtm/action_engine.py @@ -40,7 +40,7 @@ def __init__(self): super().__init__() self.spatial_proj = nn.Linear(4, 32) self.color_mlp = nn.Sequential( - nn.Linear(32 + 16, 64), + nn.Linear(32 + 16 + 4, 64), # +4 for hypothesis state nn.ReLU(), nn.Linear(64, self._NUM_COLORS), ) @@ -51,12 +51,14 @@ def __init__(self): ) def forward(self, state: torch.Tensor, - instr: torch.Tensor) -> torch.Tensor: + instr: torch.Tensor, + hypothesis: torch.Tensor = None) -> torch.Tensor: """Apply discrete color update via soft color selection. Args: state: Cell states [L, N, 16]. instr: Instructions [L, 16]. + hypothesis: Hypothesis state [L, 4] from SearchPlane.
Returns: Updated state [L, N, 16] with modified grade-0. @@ -64,7 +66,14 @@ def forward(self, state: torch.Tensor, L, N, D = state.shape spatial = state[:, :, self._SPATIAL_IDX] # [L, N, 4] feat = F.relu(self.spatial_proj(spatial)) # [L, N, 32] - ctx = torch.cat([feat, instr.unsqueeze(1).expand(-1, N, -1)], dim=-1) + instr_exp = instr.unsqueeze(1).expand(-1, N, -1) # [L, N, 16] + if hypothesis is not None: + h_exp = hypothesis.unsqueeze(1).expand(-1, N, -1) # [L, N, 4] + ctx = torch.cat([feat, instr_exp, h_exp], dim=-1) # [L, N, 52] + else: + # Fallback: zero-pad hypothesis dims for compat + h_zeros = torch.zeros(L, N, 4, device=state.device, dtype=state.dtype) + ctx = torch.cat([feat, instr_exp, h_zeros], dim=-1) # [L, N, 52] color_logits = self.color_mlp(ctx) # [L, N, 10] # Soft color selection: differentiable weighted sum of anchor values @@ -108,6 +117,16 @@ def __init__(self, algebra_cpu: CliffordAlgebra, gate_vals[0] = -2.0 # sigmoid(-2) ≈ 0.12 → mostly discrete for color self.action_gate = nn.Parameter(gate_vals) + # FiLM modulation: hypothesis -> instruction template modulation + # Creates the missing gradient path: loss -> candidates -> templates -> hypotheses + self.hypothesis_modulator = nn.Sequential( + nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2 * D), + ) + # Initialize at identity: scale=1, shift=0 + nn.init.zeros_(self.hypothesis_modulator[-1].weight) + nn.init.zeros_(self.hypothesis_modulator[-1].bias) + self.hypothesis_modulator[-1].bias.data[:D] = 1.0 # scale starts at 1 + # Rule memory modulation self.rule_proj = nn.Linear(D, K * D) @@ -165,15 +184,23 @@ def propose_all(self, state: torch.Tensor, rule_mod = self.rule_proj(rule_memory.mean(dim=1)).view(B, K, D) templates = templates + rule_mod + # FiLM: modulate templates with hypothesis state + # Bounded to prevent runaway amplification: scale ∈ [0, 2], shift ∈ [-0.5, 0.5] + h_mod = self.hypothesis_modulator(hypotheses) # [B, K, 2*D] + scale = 1.0 + torch.tanh(h_mod[..., :D] - 1.0) # init: 
tanh(0)=0 → scale=1 + shift = torch.tanh(h_mod[..., D:]) * 0.5 # init: tanh(0)=0 → shift=0 + templates = templates * scale + shift # [B, K, D] + # Batch all K hypotheses: [B*K, N, D] state_exp = state.unsqueeze(1).expand(B, K, N, D).reshape(B * K, N, D) instr_flat = templates.reshape(B * K, D) + hyp_flat = hypotheses.reshape(B * K, 4) # Continuous motor transform (no ColorUnit — pure geometric) continuous = self._motor_transform(state_exp, instr_flat) # [B*K, N, D] - # Discrete color update - discrete = self.discrete_head(state_exp, instr_flat) # [B*K, N, D] + # Discrete color update (hypothesis-conditioned) + discrete = self.discrete_head(state_exp, instr_flat, hyp_flat) # [B*K, N, D] # Blend via per-component gate gate = torch.sigmoid(self.action_gate) # [D] diff --git a/models/gtm/gtm_net.py b/models/gtm/gtm_net.py index 18647b3..1b36601 100644 --- a/models/gtm/gtm_net.py +++ b/models/gtm/gtm_net.py @@ -50,6 +50,7 @@ def __init__( halt_eps: float = 0.01, use_supervised_fim: bool = True, action_gate_init: float = 0.0, + gradient_horizon: int = 2, ): super().__init__() self.algebra_cpu = algebra_cpu @@ -88,6 +89,7 @@ def __init__( halt_eps=halt_eps, use_supervised_fim=use_supervised_fim, weight_share_steps=weight_share_steps, + gradient_horizon=gradient_horizon, ) # Reconstruction head diff --git a/models/gtm/search_plane.py b/models/gtm/search_plane.py index 14bd994..7c5287a 100644 --- a/models/gtm/search_plane.py +++ b/models/gtm/search_plane.py @@ -130,14 +130,17 @@ def __init__(self, algebra_ctrl: CliffordAlgebra, # Initial hypothesis states in Cl(1,1) self.hypothesis_init = nn.Parameter(torch.randn(K, 4) * 0.1) - # Evolution network: context -> boost magnitude + # Evolution network: context -> full Cl(1,1) bivector for evolution # Input: hypothesis (4) + world_summary projected to 16D -> concatenated self.evolve_net = nn.Sequential( nn.Linear(4 + 16, evolve_hidden), nn.ReLU(), - nn.Linear(evolve_hidden, 1), + nn.Linear(evolve_hidden, 4), ) + # Learnable scale 
for RMS-normalized hypotheses + self._evolve_scale = nn.Parameter(torch.tensor(1.0)) + # Temperature buffer (annealed externally) self.register_buffer('_temperature', torch.tensor(1.0)) @@ -176,13 +179,10 @@ def forward(self, hypotheses: torch.Tensor, world_exp = world_summary.unsqueeze(1).expand(B, K, -1).reshape(B * K, -1) h_flat = hypotheses.reshape(B * K, 4) ctx = torch.cat([h_flat, world_exp], dim=-1) # [B*K, 20] - raw_theta = self.evolve_net(ctx).reshape(B, K) - # Smooth bounding via tanh: always has gradient, range [-3, 3] - theta = torch.tanh(raw_theta) * 3.0 - # Build e+e- bivector (index 3 in Cl(1,1)) - bv = torch.zeros(B * K, 4, device=device, dtype=hypotheses.dtype) - bv[:, 3] = theta.reshape(B * K) + # Full 4D bivector evolution (all Cl(1,1) dimensions) + bv = self.evolve_net(ctx) # [B*K, 4] + bv = torch.tanh(bv) * 3.0 # bound all components # Exponentiate and sandwich R = self.algebra.exp(-0.5 * bv) # [B*K, 4] @@ -190,8 +190,9 @@ def forward(self, hypotheses: torch.Tensor, evolved = self.algebra.geometric_product( self.algebra.geometric_product(R, h_flat), R_rev ) - # Symmlog: prevents unbounded drift, gradient = 1/(1+|x|), never zero - evolved = torch.sign(evolved) * torch.log1p(evolved.abs()) + # RMS normalization: O(1) gradient regardless of input magnitude + rms = evolved.pow(2).mean(dim=-1, keepdim=True).add(1e-8).sqrt() + evolved = evolved / rms * self._evolve_scale hypotheses = evolved.reshape(B, K, 4) delta_info = fim_values - fim_prev if fim_prev is not None else fim_values diff --git a/models/gtm/world_model.py b/models/gtm/world_model.py index 5d71364..8afa1ff 100644 --- a/models/gtm/world_model.py +++ b/models/gtm/world_model.py @@ -175,7 +175,9 @@ def forward(self, state: torch.Tensor, hypotheses: torch.Tensor, R_accum: torch.Tensor, mask: torch.Tensor = None, rule_memory: torch.Tensor = None, fim_prev: torch.Tensor = None, - targets: torch.Tensor = None) -> dict: + targets: torch.Tensor = None, + step_idx: int = 0, total_steps: int = 12, 
+ gradient_horizon: int = 2) -> dict: """Execute one world model step. Args: @@ -186,6 +188,9 @@ def forward(self, state: torch.Tensor, hypotheses: torch.Tensor, rule_memory: Optional rule slots [B, M, D]. fim_prev: Previous FIM values [B, K] or None. targets: Optional target colors [B, N] for supervised FIM. + step_idx: Current step index in the loop. + total_steps: Total number of steps. + gradient_horizon: Allow gradient for last N steps. Returns: dict with world_state, hypotheses, R_accum, fim_values, search_info, gate. @@ -237,8 +242,8 @@ def forward(self, state: torch.Tensor, hypotheses: torch.Tensor, new_state.reshape(B * N, 1, D), ).reshape(B, N, D) - # Detach: chained geometric products create exponentially deep - # graphs that amplify gradients. R_accum is memory output only. + # R_accum tracks the accumulated rotor for diagnostics but is not + # in the loss path, so always detach to save memory. self._algebra_cpu.ensure_device(state.device) R_t = self.action_engine.get_combined_rotor(weights) # [B, D] R_accum = self._algebra_cpu.geometric_product(R_t, R_accum.detach()) @@ -275,11 +280,13 @@ def __init__(self, algebra_cpu: CliffordAlgebra, log_manifold_gate_init: float = -5.0, halt_eps: float = 0.01, use_supervised_fim: bool = True, - weight_share_steps: bool = False): + weight_share_steps: bool = False, + gradient_horizon: int = 2): super().__init__() self.num_steps = num_steps self.max_steps = max_steps self.num_hypotheses = num_hypotheses + self.gradient_horizon = gradient_horizon self._algebra_cpu = algebra_cpu D = algebra_cpu.dim @@ -307,15 +314,33 @@ def __init__(self, algebra_cpu: CliffordAlgebra, for _ in range(num_steps) ]) - # Initial hypotheses + # Initial hypotheses (base, shared across problems) self.hypothesis_init = nn.Parameter(torch.randn(num_hypotheses, 4) * 0.1) + # Demo-conditioned hypothesis offset: rule_memory -> per-problem initial hypotheses + # Creates different starting hypotheses for different problems, enabling + # per-problem 
adaptation instead of fitting one global average. + # Zero-initialized so it starts as identity (no offset at init). + self.hypothesis_projector = nn.Sequential( + nn.Linear(D, 64), nn.ReLU(), nn.Linear(64, num_hypotheses * 4), + ) + nn.init.zeros_(self.hypothesis_projector[-1].weight) + nn.init.zeros_(self.hypothesis_projector[-1].bias) + # FIM adaptive halt self.fim_halt = FIMAdaptiveHalt(halt_eps) self.use_fim_halt = False # Ramp for gradual FIM mixing blend (0=last-step only, 1=full FIM mix) self.fim_mix_ramp = 0.0 + # Gated exponent update: enables magnitude learning during the step loop + self.exponent_update = nn.Sequential( + nn.Linear(D, 32), nn.ReLU(), nn.Linear(32, 1), + ) + nn.init.zeros_(self.exponent_update[-1].weight) + nn.init.zeros_(self.exponent_update[-1].bias) + self.exponent_gate = nn.Parameter(torch.tensor(-5.0)) # sigmoid(-5)≈0.007 + # Final norm self.final_norm = CliffordLayerNorm(algebra_cpu, 1) @@ -347,8 +372,14 @@ def forward(self, cpu_state: torch.Tensor, # Split into mantissa + exponent mantissa, exponent = self.log_projector.split(cpu_state) - # Initialize - hypotheses = self.hypothesis_init.unsqueeze(0).expand(B, -1, -1).clone() + # Initialize hypotheses — conditioned on rule_memory if available + K = self.num_hypotheses + if rule_memory is not None: + rule_ctx = rule_memory.mean(dim=1) # [B, D] + h_offset = self.hypothesis_projector(rule_ctx).view(B, K, 4) + hypotheses = self.hypothesis_init.unsqueeze(0).expand(B, -1, -1) + h_offset + else: + hypotheses = self.hypothesis_init.unsqueeze(0).expand(B, -1, -1).clone() R_accum = _identity_rotor(B, D, device, dtype) fim_prev = None @@ -361,16 +392,26 @@ def forward(self, cpu_state: torch.Tensor, 'fim_values': [], } if return_trace else None + num_active_steps = len(self.steps) + exp_gate = torch.sigmoid(self.exponent_gate) + for t, step in enumerate(self.steps): result = step( mantissa, hypotheses, R_accum, mask, rule_memory, fim_prev, targets, + step_idx=t, total_steps=num_active_steps, + 
gradient_horizon=self.gradient_horizon, ) mantissa = result['world_state'] hypotheses = result['hypotheses'] R_accum = result['R_accum'] fim_prev = result['fim_values'] + # Gated exponent update: enables magnitude learning per step + # tanh bounds delta to [-1, 1] — max exponent shift per step is ~0.007 + exp_delta = torch.tanh(self.exponent_update(mantissa.mean(dim=1, keepdim=True))) + exponent = exponent + exp_gate * exp_delta + step_outputs.append(mantissa) search_info = result['search_info'] From 1cb57317a3a56d5003e5d54529944785621a267f Mon Sep 17 00:00:00 2001 From: Concode0 Date: Mon, 23 Mar 2026 16:54:26 +0900 Subject: [PATCH 16/16] fix: unscale gradients before clipping — AMP scaling made grad clip ineffective, leaving the model memory-blind --- tasks/gtm.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/tasks/gtm.py b/tasks/gtm.py index bb90058..e3b64e5 100644 --- a/tasks/gtm.py +++ b/tasks/gtm.py @@ -91,6 +91,7 @@ def setup_model(self): halt_eps=ig_cfg.get('halt_eps', 0.01), use_supervised_fim=ig_cfg.get('use_supervised_fim', True), action_gate_init=ae_cfg.get('gate_init', 0.0), + gradient_horizon=mcfg.get('gradient_horizon', 2), ) def _setup_optimizer(self): @@ -218,6 +219,11 @@ def train_step(self, batch): self._backward(loss) + # Unscale BEFORE grad clip so clipping operates on real gradient magnitudes. + # Without this, AMP's scale factor (65536) makes the effective clip ~1e-5.
+ if self._scaler is not None: + self._scaler.unscale_(self.optimizer) + if self.grad_clip > 0: trainable = [p for p in self.model.parameters() if p.requires_grad] if trainable: @@ -311,16 +317,23 @@ def run(self): logger.info("Phase 3: FIM Halt + Conviction Collapse") self.model.enable_fim_halt() - # Preserve LR from previous phase when transitioning 2->3 - # to avoid destabilizing learned representations + min_lr = self.cfg.training.get('min_lr', 1e-5) + + # LR handling per phase transition prev_lr = (self.optimizer.param_groups[0]['lr'] if prev_phase > 0 else self.cfg.training.lr) self.optimizer = self._setup_optimizer() - if prev_phase > 0: + if phase == 3 and prev_phase == 2: + # Reset to a viable LR floor in case scheduler killed it + phase3_lr = max(self.cfg.training.lr * 0.1, min_lr) + for pg in self.optimizer.param_groups: + pg['lr'] = phase3_lr + elif prev_phase > 0: for pg in self.optimizer.param_groups: pg['lr'] = prev_lr self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - self.optimizer, mode='min', factor=0.5, patience=10) + self.optimizer, mode='min', factor=0.5, patience=5, + min_lr=min_lr) # Temperature schedule if phase == 1: @@ -334,6 +347,18 @@ def run(self): tau = self.tau_mid + (self.tau_end - self.tau_mid) * progress self.model.set_temperature(tau) + # Phase 2 LR warmup: ramp from 1/10 to capped peak over first 10 epochs. + # Caps at phase2_lr_scale * base_lr to prevent instability in 12-step model. + if phase == 2: + phase2_epoch = epoch - self.warmup_epochs + warmup_len = min(10, max(self.trim_epochs // 4, 1)) + if phase2_epoch < warmup_len: + base_lr = self.cfg.training.lr + phase2_scale = self.cfg.training.get('phase2_lr_scale', 0.5) + peak_lr = base_lr * phase2_scale + for pg in self.optimizer.param_groups: + pg['lr'] = peak_lr * (phase2_epoch + 1) / warmup_len + # FIM mixing ramp: gradually blend in FIM-weighted output over Phase 3 if phase == 3: act_epoch = epoch - (self.warmup_epochs + self.trim_epochs)