Skip to content

Commit f309c6b

Browse files
committed
refactor: change projection logic not svd using direct solve with Tikhonov
1 parent 270c330 commit f309c6b

1 file changed

Lines changed: 4 additions & 7 deletions

File tree

layers/primitives/projection.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
143143

144144
# 3. Perform Projection
145145
# Solve: cur_cov_bb * W = cur_cov_bs => W = inv(cur_cov_bb) * cur_cov_bs
146-
# Use pseudo-inverse for stability
147-
if cur_cov_bb.device.type == 'mps':
148-
inv_bb = torch.linalg.pinv(cur_cov_bb.cpu()).to(cur_cov_bb.device)
149-
else:
150-
inv_bb = torch.linalg.pinv(cur_cov_bb)
151-
152-
weights = torch.matmul(inv_bb, cur_cov_bs)
146+
reg = 1e-6 * torch.eye(
147+
self.d2, device=cur_cov_bb.device, dtype=cur_cov_bb.dtype
148+
).unsqueeze(0)
149+
weights = torch.linalg.solve(cur_cov_bb + reg, cur_cov_bs)
153150

154151
# Center based on current means
155152
b_centered = bivec - cur_mean_b.unsqueeze(0)

0 commit comments

Comments
 (0)