@@ -71,6 +71,12 @@ def __init__(self, p: int, q: int = 0, r: int = 0, device='cuda'):
7171 self .rc_action ,
7272 ) = CliffordAlgebra ._CACHED_TABLES [cache_key ]
7373
74+ self .grade_masks_float = [m .float () for m in self .grade_masks ]
75+ if self .n >= 2 :
76+ self ._bv_indices = self .grade_masks [2 ].nonzero (as_tuple = False ).squeeze (- 1 )
77+ else :
78+ self ._bv_indices = torch .zeros (0 , dtype = torch .long , device = self .device )
79+
7480 @property
7581 def num_grades (self ) -> int :
7682 """Counts the number of grades (n + 1).
@@ -304,8 +310,7 @@ def geometric_product(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
304310 B_gathered = B [..., idx ] # [..., D, D]
305311
306312 # result[..., k] = sum_i A[..., i] * B[..., cayley[i,k]] * signs[i,k]
307- # = sum_i (A[..., i, None] * B_gathered[..., i, :] * signs[i, :]) summed over i
308- return (A .unsqueeze (- 1 ) * B_gathered * self .gp_signs ).sum (dim = - 2 )
313+ return torch .matmul (A .unsqueeze (- 2 ), B_gathered * self .gp_signs ).squeeze (- 2 )
309314
310315 def ensure_device (self , device ) -> None :
311316 """Move cached tables to the given device if not already there.
@@ -319,13 +324,20 @@ def ensure_device(self, device) -> None:
319324 self .cayley_signs = self .cayley_signs .to (device )
320325 self .gp_signs = self .gp_signs .to (device )
321326 self .grade_masks = [m .to (device ) for m in self .grade_masks ]
327+ self .grade_masks_float = [m .to (device ) for m in self .grade_masks_float ]
328+ self ._bv_indices = self ._bv_indices .to (device )
322329 self .rev_signs = self .rev_signs .to (device )
323330 self .bv_sq_scalar = self .bv_sq_scalar .to (device )
324331 self .wedge_gp_signs = self .wedge_gp_signs .to (device )
325332 self .inner_gp_signs = self .inner_gp_signs .to (device )
326333 self .grade_index = self .grade_index .to (device )
327334 self .rc_action = self .rc_action .to (device )
328335
336+ # Clear lazy caches (will be recomputed on next use)
337+ self ._left_sign_T = None
338+ self ._ps_source = None
339+ self ._ps_signs = None
340+
329341 cache_key = (self .p , self .q , self .r , str (self .device ))
330342 CliffordAlgebra ._CACHED_TABLES [cache_key ] = (
331343 self .cayley_indices , self .cayley_signs , self .gp_signs ,
@@ -337,19 +349,20 @@ def ensure_device(self, device) -> None:
def grade_projection(self, mv: torch.Tensor, grade: int) -> torch.Tensor:
    """Isolate the components of a single grade in a multivector.

    Implemented as an elementwise multiply against a precomputed float
    mask rather than boolean indexing, so no ``nonzero`` call is needed
    and the op stays ``torch.compile``-friendly.

    Args:
        mv (torch.Tensor): Multivector [..., Dim].
        grade (int): Target grade.

    Returns:
        torch.Tensor: Multivector [..., Dim] with only the requested
        grade retained; every other component is zeroed.
    """
    selector = self.grade_masks_float[grade]
    needs_cast = selector.device != mv.device or selector.dtype != mv.dtype
    if needs_cast:
        selector = selector.to(device=mv.device, dtype=mv.dtype)
    return mv * selector
353366
354367 def reverse (self , mv : torch .Tensor ) -> torch .Tensor :
355368 """Computes the reversion. The Clifford conjugate.
@@ -386,7 +399,7 @@ def wedge(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
386399 self .ensure_device (A .device )
387400 idx = self .cayley_indices
388401 B_gathered = B [..., idx ]
389- return (A .unsqueeze (- 1 ) * B_gathered * self .wedge_gp_signs ).sum ( dim = - 2 )
402+ return torch . matmul (A .unsqueeze (- 2 ), B_gathered * self .wedge_gp_signs ).squeeze ( - 2 )
390403
391404 def right_contraction (self , A : torch .Tensor , B : torch .Tensor ) -> torch .Tensor :
392405 """Computes the right contraction: A _| B.
@@ -405,14 +418,20 @@ def right_contraction(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
405418 Returns:
406419 torch.Tensor: Right contraction A _| B [..., dim].
407420 """
408- bv_mask = self .grade_masks [ 2 ]
409- if bv_mask .device != A .device :
421+ bv_idx = self ._bv_indices
422+ if bv_idx .device != A .device :
410423 self .ensure_device (A .device )
411- bv_mask = self .grade_masks [2 ]
424+ bv_idx = self ._bv_indices
425+
426+ # Use gather instead of boolean indexing (compile-friendly)
427+ bv_idx_exp = bv_idx .expand (* A .shape [:- 1 ], - 1 )
428+ bv_coeffs = torch .gather (A , - 1 , bv_idx_exp ) # [..., num_bv]
412429
413- bv_coeffs = A [..., bv_mask ] # [..., num_bv]
414- g1_idx = self .grade_masks [1 ].nonzero (as_tuple = False ).squeeze (- 1 )
415- v_coeffs = B [..., g1_idx ] # [..., n]
430+ # Grade-1 indices: powers of 2 for basis vectors
431+ g1_idx = torch .arange (self .n , device = A .device )
432+ g1_idx = (1 << g1_idx ).long () # [n]
433+ g1_idx_exp = g1_idx .expand (* B .shape [:- 1 ], - 1 )
434+ v_coeffs = torch .gather (B , - 1 , g1_idx_exp ) # [..., n]
416435
417436 rc = self .rc_action
418437 if rc .device != A .device or rc .dtype != A .dtype :
@@ -423,7 +442,7 @@ def right_contraction(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
423442 result_v = torch .matmul (M , v_coeffs .unsqueeze (- 1 )).squeeze (- 1 ) # [..., n]
424443
425444 result = torch .zeros_like (A )
426- result [..., g1_idx ] = result_v
445+ result . scatter_ ( - 1 , g1_idx_exp , result_v )
427446 return result
428447
429448 def inner_product (self , A : torch .Tensor , B : torch .Tensor ) -> torch .Tensor :
@@ -447,7 +466,7 @@ def inner_product(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
447466 self .ensure_device (A .device )
448467 idx = self .cayley_indices
449468 B_gathered = B [..., idx ]
450- return (A .unsqueeze (- 1 ) * B_gathered * self .inner_gp_signs ).sum ( dim = - 2 )
469+ return torch . matmul (A .unsqueeze (- 2 ), B_gathered * self .inner_gp_signs ).squeeze ( - 2 )
451470
452471 def blade_inverse (self , blade : torch .Tensor ) -> torch .Tensor :
453472 """Compute the inverse of a blade: B^{-1} = B_rev / <B * B_rev>_0.
@@ -466,6 +485,83 @@ def blade_inverse(self, blade: torch.Tensor) -> torch.Tensor:
466485 scalar = blade_sq [..., 0 :1 ].clamp (min = 1e-12 )
467486 return blade_rev / scalar
468487
def sandwich_product(self, R: torch.Tensor, x: torch.Tensor,
                     R_rev: torch.Tensor = None) -> torch.Tensor:
    """Optimized sandwich product R x R~ via a precomposed action matrix.

    Builds one [N, D, D] matrix encoding "left-multiply by R, then
    right-multiply by R~" and applies it to every channel of x with a
    single batched matmul. This beats two separate ``geometric_product``
    calls when x carries channel dimensions that R does not share.

    Memory: O(N*D*D) with N the rotor batch size (channels excluded),
    versus O(N*C*D*D) for the naive route — a factor-of-C saving.

    Args:
        R: Rotors [N, D] (2-D, batch-flattened).
        x: Multivectors [N, C, D] (3-D, C channels per rotor).
        R_rev: Optional precomputed reverse of R [N, D].

    Returns:
        Sandwiched result [N, C, D].
    """
    D = self.dim
    self.ensure_device(R.device)

    if R_rev is None:
        R_rev = self.reverse(R)

    ci = self.cayley_indices  # [D, D] with ci[i, j] = i ^ j

    # Lazily (re)build the transposed left-sign table; invalidated by
    # ensure_device or a device change.
    cached = getattr(self, '_left_sign_T', None)
    if cached is None or cached.device != R.device:
        ks = torch.arange(D, device=R.device)
        # left_signs[j, k] = gp_signs[j^k, k]
        left_signs = self.gp_signs[ci, ks.unsqueeze(0).expand(D, D)]
        cached = left_signs.T.contiguous()  # [D(k), D(j)]
        self._left_sign_T = cached

    # Left action: left_mat[n, k, j] = R[n, j^k] * gp_signs[j^k, k]
    left_mat = R[:, ci].permute(0, 2, 1) * cached.unsqueeze(0)

    # Right action: right_mat[n, k, i] = R~[n, i^k] * gp_signs[i, k]
    right_mat = R_rev[:, ci].permute(0, 2, 1) * self.gp_signs.T.unsqueeze(0)

    # Compose: (R x R~)[k] = sum_j action[k, j] * x[j]
    action = torch.bmm(right_mat, left_mat)  # [N, D, D]

    # Broadcast across channels: out[n, c, k] = sum_j action[n, k, j] * x[n, c, j]
    return torch.matmul(x, action.transpose(-2, -1))
538+
def pseudoscalar_product(self, x: torch.Tensor) -> torch.Tensor:
    """Multiply by the unit pseudoscalar: x * I.

    Sends grade-k components to grade-(n-k) (Hodge dual). Because I is a
    single basis blade, the whole product reduces to a fixed index
    permutation plus sign flips — no geometric product required.

    Args:
        x: Multivector [..., D].

    Returns:
        Result [..., D].
    """
    D = self.dim
    # Lazily (re)build the permutation/sign cache; invalidated by
    # ensure_device or a device change.
    src = getattr(self, '_ps_source', None)
    if src is None or src.device != x.device:
        self.ensure_device(x.device)
        positions = torch.arange(D, device=x.device)
        src = positions ^ (D - 1)  # component k draws from index k^(D-1)
        self._ps_source = src
        self._ps_signs = self.gp_signs[src, positions]

    signs = self._ps_signs
    if signs.dtype != x.dtype:
        signs = signs.to(dtype=x.dtype)

    return x[..., self._ps_source] * signs
564+
469565 def blade_project (self , mv : torch .Tensor , blade : torch .Tensor ) -> torch .Tensor :
470566 """Project multivector onto blade subspace: (mv . B) B^{-1}.
471567
@@ -553,10 +649,11 @@ def _exp_bivector_closed(self, B: torch.Tensor) -> torch.Tensor:
553649 Returns:
554650 torch.Tensor: Rotor exp(B) [..., dim].
555651 """
556- bv_mask = self .grade_masks [2 ]
557- if bv_mask .device != B .device :
558- bv_mask = bv_mask .to (B .device )
559- bv_coeffs = B [..., bv_mask ] # [..., num_bivectors]
652+ bv_idx = self ._bv_indices
653+ if bv_idx .device != B .device :
654+ bv_idx = bv_idx .to (B .device )
655+ idx_expanded = bv_idx .expand (* B .shape [:- 1 ], - 1 )
656+ bv_coeffs = torch .gather (B , - 1 , idx_expanded ) # [..., num_bivectors]
560657
561658 bv_sq = self .bv_sq_scalar
562659 if bv_sq .device != B .device :
@@ -599,10 +696,10 @@ def _exp_bivector_closed(self, B: torch.Tensor) -> torch.Tensor:
599696 torch .where (is_hyperbolic , sinhc_theta , torch .ones_like (theta ))
600697 )
601698
602- result = coeff_part * B
603- result [..., 0 ] = scalar_part . squeeze ( - 1 )
604-
605- return result
699+ g0_mask = self . grade_masks_float [ 0 ]
700+ if g0_mask . device != B . device or g0_mask . dtype != B . dtype :
701+ g0_mask = g0_mask . to ( device = B . device , dtype = B . dtype )
702+ return scalar_part * g0_mask + coeff_part * B
606703
607704 def _exp_taylor (self , mv : torch .Tensor , order : int = 8 ) -> torch .Tensor :
608705 """Taylor series exponential with scaling-and-squaring (fallback).
0 commit comments