Skip to content

Commit c6b4ae7

Browse files
committed
refactor: optimize core kernel - remove intermediate tensor and replace with fma
1 parent 6e29aea commit c6b4ae7

2 files changed

Lines changed: 126 additions & 26 deletions

File tree

core/algebra.py

Lines changed: 122 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ def __init__(self, p: int, q: int = 0, r: int = 0, device='cuda'):
7171
self.rc_action,
7272
) = CliffordAlgebra._CACHED_TABLES[cache_key]
7373

74+
self.grade_masks_float = [m.float() for m in self.grade_masks]
75+
if self.n >= 2:
76+
self._bv_indices = self.grade_masks[2].nonzero(as_tuple=False).squeeze(-1)
77+
else:
78+
self._bv_indices = torch.zeros(0, dtype=torch.long, device=self.device)
79+
7480
@property
7581
def num_grades(self) -> int:
7682
"""Counts the number of grades (n + 1).
@@ -304,8 +310,7 @@ def geometric_product(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
304310
B_gathered = B[..., idx] # [..., D, D]
305311

306312
# result[..., k] = sum_i A[..., i] * B[..., cayley[i,k]] * signs[i,k]
307-
# = sum_i (A[..., i, None] * B_gathered[..., i, :] * signs[i, :]) summed over i
308-
return (A.unsqueeze(-1) * B_gathered * self.gp_signs).sum(dim=-2)
313+
return torch.matmul(A.unsqueeze(-2), B_gathered * self.gp_signs).squeeze(-2)
309314

310315
def ensure_device(self, device) -> None:
311316
"""Move cached tables to the given device if not already there.
@@ -319,13 +324,20 @@ def ensure_device(self, device) -> None:
319324
self.cayley_signs = self.cayley_signs.to(device)
320325
self.gp_signs = self.gp_signs.to(device)
321326
self.grade_masks = [m.to(device) for m in self.grade_masks]
327+
self.grade_masks_float = [m.to(device) for m in self.grade_masks_float]
328+
self._bv_indices = self._bv_indices.to(device)
322329
self.rev_signs = self.rev_signs.to(device)
323330
self.bv_sq_scalar = self.bv_sq_scalar.to(device)
324331
self.wedge_gp_signs = self.wedge_gp_signs.to(device)
325332
self.inner_gp_signs = self.inner_gp_signs.to(device)
326333
self.grade_index = self.grade_index.to(device)
327334
self.rc_action = self.rc_action.to(device)
328335

336+
# Clear lazy caches (will be recomputed on next use)
337+
self._left_sign_T = None
338+
self._ps_source = None
339+
self._ps_signs = None
340+
329341
cache_key = (self.p, self.q, self.r, str(self.device))
330342
CliffordAlgebra._CACHED_TABLES[cache_key] = (
331343
self.cayley_indices, self.cayley_signs, self.gp_signs,
@@ -337,19 +349,20 @@ def ensure_device(self, device) -> None:
337349
def grade_projection(self, mv: torch.Tensor, grade: int) -> torch.Tensor:
338350
"""Isolates a specific grade.
339351
352+
Uses multiplicative masking (mv * float_mask) instead of boolean
353+
indexing to avoid ``nonzero`` calls that break ``torch.compile``.
354+
340355
Args:
341356
mv (torch.Tensor): Multivector [..., Dim].
342357
grade (int): Target grade.
343358
344359
Returns:
345360
torch.Tensor: Projected multivector [..., Dim].
346361
"""
347-
mask = self.grade_masks[grade]
348-
if mask.device != mv.device:
349-
mask = mask.to(mv.device)
350-
result = torch.zeros_like(mv)
351-
result[..., mask] = mv[..., mask]
352-
return result
362+
mask = self.grade_masks_float[grade]
363+
if mask.device != mv.device or mask.dtype != mv.dtype:
364+
mask = mask.to(device=mv.device, dtype=mv.dtype)
365+
return mv * mask
353366

354367
def reverse(self, mv: torch.Tensor) -> torch.Tensor:
355368
"""Computes the reversion. The Clifford conjugate.
@@ -386,7 +399,7 @@ def wedge(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
386399
self.ensure_device(A.device)
387400
idx = self.cayley_indices
388401
B_gathered = B[..., idx]
389-
return (A.unsqueeze(-1) * B_gathered * self.wedge_gp_signs).sum(dim=-2)
402+
return torch.matmul(A.unsqueeze(-2), B_gathered * self.wedge_gp_signs).squeeze(-2)
390403

391404
def right_contraction(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
392405
"""Computes the right contraction: A _| B.
@@ -405,14 +418,20 @@ def right_contraction(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
405418
Returns:
406419
torch.Tensor: Right contraction A _| B [..., dim].
407420
"""
408-
bv_mask = self.grade_masks[2]
409-
if bv_mask.device != A.device:
421+
bv_idx = self._bv_indices
422+
if bv_idx.device != A.device:
410423
self.ensure_device(A.device)
411-
bv_mask = self.grade_masks[2]
424+
bv_idx = self._bv_indices
425+
426+
# Use gather instead of boolean indexing (compile-friendly)
427+
bv_idx_exp = bv_idx.expand(*A.shape[:-1], -1)
428+
bv_coeffs = torch.gather(A, -1, bv_idx_exp) # [..., num_bv]
412429

413-
bv_coeffs = A[..., bv_mask] # [..., num_bv]
414-
g1_idx = self.grade_masks[1].nonzero(as_tuple=False).squeeze(-1)
415-
v_coeffs = B[..., g1_idx] # [..., n]
430+
# Grade-1 indices: powers of 2 for basis vectors
431+
g1_idx = torch.arange(self.n, device=A.device)
432+
g1_idx = (1 << g1_idx).long() # [n]
433+
g1_idx_exp = g1_idx.expand(*B.shape[:-1], -1)
434+
v_coeffs = torch.gather(B, -1, g1_idx_exp) # [..., n]
416435

417436
rc = self.rc_action
418437
if rc.device != A.device or rc.dtype != A.dtype:
@@ -423,7 +442,7 @@ def right_contraction(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
423442
result_v = torch.matmul(M, v_coeffs.unsqueeze(-1)).squeeze(-1) # [..., n]
424443

425444
result = torch.zeros_like(A)
426-
result[..., g1_idx] = result_v
445+
result.scatter_(-1, g1_idx_exp, result_v)
427446
return result
428447

429448
def inner_product(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
@@ -447,7 +466,7 @@ def inner_product(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
447466
self.ensure_device(A.device)
448467
idx = self.cayley_indices
449468
B_gathered = B[..., idx]
450-
return (A.unsqueeze(-1) * B_gathered * self.inner_gp_signs).sum(dim=-2)
469+
return torch.matmul(A.unsqueeze(-2), B_gathered * self.inner_gp_signs).squeeze(-2)
451470

452471
def blade_inverse(self, blade: torch.Tensor) -> torch.Tensor:
453472
"""Compute the inverse of a blade: B^{-1} = B_rev / <B * B_rev>_0.
@@ -466,6 +485,83 @@ def blade_inverse(self, blade: torch.Tensor) -> torch.Tensor:
466485
scalar = blade_sq[..., 0:1].clamp(min=1e-12)
467486
return blade_rev / scalar
468487

488+
def sandwich_product(self, R: torch.Tensor, x: torch.Tensor,
489+
R_rev: torch.Tensor = None) -> torch.Tensor:
490+
"""Optimized sandwich product R x R~ via action matrix.
491+
492+
Builds a [N, D, D] sandwich action matrix from the rotor, then applies
493+
it to all C channels via a single batched matmul. This is much faster
494+
than two separate ``geometric_product`` calls when x has extra channel
495+
dimensions that R does not.
496+
497+
Memory: O(N*D*D) where N = batch (without channels).
498+
Compare to naive: O(N*C*D*D) — a factor of C improvement.
499+
500+
Args:
501+
R: Rotors [N, D] (2-D, batch-flattened).
502+
x: Multivectors [N, C, D] (3-D, C channels per rotor).
503+
R_rev: Optional precomputed reverse of R [N, D].
504+
505+
Returns:
506+
Sandwiched result [N, C, D].
507+
"""
508+
D = self.dim
509+
self.ensure_device(R.device)
510+
511+
if R_rev is None:
512+
R_rev = self.reverse(R)
513+
514+
ci = self.cayley_indices # [D, D], ci[i, j] = i ^ j
515+
516+
# Precompute left-sign table (lazy, cached per device)
517+
if not hasattr(self, '_left_sign_T') or self._left_sign_T is None \
518+
or self._left_sign_T.device != R.device:
519+
k_range = torch.arange(D, device=R.device)
520+
# ls[j, k] = gp_signs[j^k, k]
521+
ls = self.gp_signs[ci, k_range.unsqueeze(0).expand(D, D)]
522+
self._left_sign_T = ls.T.contiguous() # [D(k), D(j)]
523+
524+
# Left-multiplication matrix L_R: L_R[n, k, j] = R[n, j^k] * gp_signs[j^k, k]
525+
R_gathered = R[:, ci] # [N, D(j), D(k)]
526+
L_R = R_gathered.permute(0, 2, 1) * self._left_sign_T.unsqueeze(0)
527+
528+
# Right-multiplication matrix R_{R~}: R_Rr[n, k, i] = R~[n, i^k] * gp_signs[i, k]
529+
gp_T = self.gp_signs.T
530+
Rr_gathered = R_rev[:, ci] # [N, D(i), D(k)]
531+
R_Rr = Rr_gathered.permute(0, 2, 1) * gp_T.unsqueeze(0)
532+
533+
# Sandwich matrix: M = R_Rr @ L_R → (R x R~)[k] = sum_j M[k, j] * x[j]
534+
M = torch.bmm(R_Rr, L_R) # [N, D, D]
535+
536+
# Apply to all channels: result[n, c, k] = sum_j M[n, k, j] * x[n, c, j]
537+
return torch.matmul(x, M.transpose(-2, -1))
538+
539+
def pseudoscalar_product(self, x: torch.Tensor) -> torch.Tensor:
540+
"""Multiply by the unit pseudoscalar: x * I.
541+
542+
Maps grade-k to grade-(n-k) (Hodge dual). Computed as a simple
543+
permutation with sign flips — no geometric product needed.
544+
545+
Args:
546+
x: Multivector [..., D].
547+
548+
Returns:
549+
Result [..., D].
550+
"""
551+
D = self.dim
552+
if not hasattr(self, '_ps_source') or self._ps_source is None \
553+
or self._ps_source.device != x.device:
554+
self.ensure_device(x.device)
555+
ps_src = torch.arange(D, device=x.device) ^ (D - 1)
556+
self._ps_source = ps_src
557+
self._ps_signs = self.gp_signs[ps_src, torch.arange(D, device=x.device)]
558+
559+
ps_signs = self._ps_signs
560+
if ps_signs.dtype != x.dtype:
561+
ps_signs = ps_signs.to(dtype=x.dtype)
562+
563+
return x[..., self._ps_source] * ps_signs
564+
469565
def blade_project(self, mv: torch.Tensor, blade: torch.Tensor) -> torch.Tensor:
470566
"""Project multivector onto blade subspace: (mv . B) B^{-1}.
471567
@@ -553,10 +649,11 @@ def _exp_bivector_closed(self, B: torch.Tensor) -> torch.Tensor:
553649
Returns:
554650
torch.Tensor: Rotor exp(B) [..., dim].
555651
"""
556-
bv_mask = self.grade_masks[2]
557-
if bv_mask.device != B.device:
558-
bv_mask = bv_mask.to(B.device)
559-
bv_coeffs = B[..., bv_mask] # [..., num_bivectors]
652+
bv_idx = self._bv_indices
653+
if bv_idx.device != B.device:
654+
bv_idx = bv_idx.to(B.device)
655+
idx_expanded = bv_idx.expand(*B.shape[:-1], -1)
656+
bv_coeffs = torch.gather(B, -1, idx_expanded) # [..., num_bivectors]
560657

561658
bv_sq = self.bv_sq_scalar
562659
if bv_sq.device != B.device:
@@ -599,10 +696,10 @@ def _exp_bivector_closed(self, B: torch.Tensor) -> torch.Tensor:
599696
torch.where(is_hyperbolic, sinhc_theta, torch.ones_like(theta))
600697
)
601698

602-
result = coeff_part * B
603-
result[..., 0] = scalar_part.squeeze(-1)
604-
605-
return result
699+
g0_mask = self.grade_masks_float[0]
700+
if g0_mask.device != B.device or g0_mask.dtype != B.dtype:
701+
g0_mask = g0_mask.to(device=B.device, dtype=B.dtype)
702+
return scalar_part * g0_mask + coeff_part * B
606703

607704
def _exp_taylor(self, mv: torch.Tensor, order: int = 8) -> torch.Tensor:
608705
"""Taylor series exponential with scaling-and-squaring (fallback).

core/device.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,12 @@ def __post_init__(self) -> None:
7474
# Public helpers
7575

7676
def apply_backend_settings(self) -> None:
77-
"""Apply ``cudnn.benchmark`` (and future backend knobs)."""
77+
"""Apply ``cudnn.benchmark``, TF32 matmul precision, etc."""
7878
if torch.backends.cudnn.is_available():
7979
torch.backends.cudnn.benchmark = self.cudnn_benchmark
80+
# Enable TF32 tensor cores on Ampere+ GPUs (RTX 30xx, 40xx, Ada)
81+
if self.device.startswith("cuda"):
82+
torch.set_float32_matmul_precision("high")
8083

8184
def maybe_compile(self, model: nn.Module) -> nn.Module:
8285
"""Optionally wrap *model* with :func:`torch.compile`."""

0 commit comments

Comments (0)