We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 270c330 commit f309c6bCopy full SHA for f309c6b
1 file changed
layers/primitives/projection.py
@@ -143,13 +143,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
143
144
# 3. Perform Projection
145
# Solve: cur_cov_bb * W = cur_cov_bs => W = inv(cur_cov_bb) * cur_cov_bs
146
- # Use pseudo-inverse for stability
147
- if cur_cov_bb.device.type == 'mps':
148
- inv_bb = torch.linalg.pinv(cur_cov_bb.cpu()).to(cur_cov_bb.device)
149
- else:
150
- inv_bb = torch.linalg.pinv(cur_cov_bb)
151
-
152
- weights = torch.matmul(inv_bb, cur_cov_bs)
+ reg = 1e-6 * torch.eye(
+ self.d2, device=cur_cov_bb.device, dtype=cur_cov_bb.dtype
+ ).unsqueeze(0)
+ weights = torch.linalg.solve(cur_cov_bb + reg, cur_cov_bs)
153
154
# Center based on current means
155
b_centered = bivec - cur_mean_b.unsqueeze(0)
0 commit comments