Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "la-stack"
version = "0.4.0"
edition = "2024"
rust-version = "1.94"
rust-version = "1.95"
license = "BSD-3-Clause"
description = "Fast, stack-allocated linear algebra for fixed dimensions"
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[toolchain]
# Pin to MSRV as specified in Cargo.toml
channel = "1.94.0"
channel = "1.95.0"

# Essential components for development
components = [
Expand Down
26 changes: 20 additions & 6 deletions src/exact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
//! \[10\] for background on floating-point representation and exact
//! rational reconstruction. Reference numbers refer to `REFERENCES.md`.

use core::hint::cold_path;
use std::array::from_fn;

use num_bigint::{BigInt, Sign};
Expand All @@ -61,6 +62,7 @@ fn validate_finite<const D: usize>(m: &Matrix<D>) -> Result<(), LaError> {
for r in 0..D {
for c in 0..D {
if !m.rows[r][c].is_finite() {
cold_path();
return Err(LaError::NonFinite {
row: Some(r),
col: c,
Expand All @@ -78,6 +80,7 @@ fn validate_finite<const D: usize>(m: &Matrix<D>) -> Result<(), LaError> {
fn validate_finite_vec<const D: usize>(v: &Vector<D>) -> Result<(), LaError> {
for (i, &x) in v.data.iter().enumerate() {
if !x.is_finite() {
cold_path();
return Err(LaError::NonFinite { row: None, col: i });
}
}
Expand Down Expand Up @@ -324,6 +327,7 @@ fn gauss_solve<const D: usize>(m: &Matrix<D>, b: &Vector<D>) -> Result<[BigRatio
mat.swap(k, swap_row);
rhs.swap(k, swap_row);
} else {
cold_path();
return Err(LaError::Singular { pivot_col: k });
}
}
Expand Down Expand Up @@ -419,6 +423,7 @@ impl<const D: usize> Matrix<D> {
if val.is_finite() {
Ok(val)
} else {
cold_path();
Err(LaError::Overflow { index: None })
}
}
Expand Down Expand Up @@ -490,6 +495,7 @@ impl<const D: usize> Matrix<D> {
for (i, val) in exact.iter().enumerate() {
let f = val.to_f64().unwrap_or(f64::INFINITY);
if !f.is_finite() {
cold_path();
return Err(LaError::Overflow { index: Some(i) });
}
result[i] = f;
Expand Down Expand Up @@ -537,23 +543,31 @@ impl<const D: usize> Matrix<D> {
validate_finite(self)?;

// Stage 1: f64 fast filter for D ≤ 4.
if let (Some(det_f64), Some(err)) = (self.det_direct(), self.det_errbound()) {
// When entries are large (e.g. near f64::MAX) the determinant can
// overflow to infinity even though every individual entry is finite.
// In that case the fast filter is inconclusive; fall through to the
// exact Bareiss path.
if det_f64.is_finite() {
//
// When entries are large (e.g. near f64::MAX) the determinant can
// overflow to infinity even though every individual entry is finite.
// In that case the fast filter is inconclusive; fall through to the
// exact Bareiss path.
match self.det_direct() {
Some(det_f64)
if let Some(err) = self.det_errbound()
&& det_f64.is_finite() =>
{
if det_f64 > err {
return Ok(1);
}
if det_f64 < -err {
return Ok(-1);
}
}
_ => {}
}

// Stage 2: integer Bareiss fallback — the 2^(D×e_min) scale factor
// is always positive, so det_int.sign() == det(A).sign().
// This is the cold path: the fast filter resolves the vast majority of
// well-conditioned calls without allocating.
cold_path();
let (det_int, _) = bareiss_det_int(self);
Ok(match det_int.sign() {
Sign::Plus => 1,
Expand Down
26 changes: 26 additions & 0 deletions src/ldlt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
//! symmetric positive definite (SPD) and positive semi-definite (PSD) matrices (e.g. Gram
//! matrices) without pivoting.

use core::hint::cold_path;

use crate::LaError;
use crate::matrix::Matrix;
use crate::vector::Vector;
Expand Down Expand Up @@ -39,19 +41,22 @@ impl<const D: usize> Ldlt<D> {
for j in 0..D {
let d = f.rows[j][j];
if !d.is_finite() {
cold_path();
return Err(LaError::NonFinite {
row: Some(j),
col: j,
});
}
if d <= tol {
cold_path();
return Err(LaError::Singular { pivot_col: j });
}

// Compute L multipliers below the diagonal in column j.
for i in (j + 1)..D {
let l = f.rows[i][j] / d;
if !l.is_finite() {
cold_path();
return Err(LaError::NonFinite {
row: Some(i),
col: j,
Expand All @@ -69,6 +74,7 @@ impl<const D: usize> Ldlt<D> {
let l_k = f.rows[k][j];
let new_val = (-l_i_d).mul_add(l_k, f.rows[i][k]);
if !new_val.is_finite() {
cold_path();
return Err(LaError::NonFinite {
row: Some(i),
col: k,
Expand Down Expand Up @@ -141,6 +147,7 @@ impl<const D: usize> Ldlt<D> {
sum = (-row[j]).mul_add(*x_j, sum);
}
if !sum.is_finite() {
cold_path();
return Err(LaError::NonFinite { row: None, col: i });
}
x[i] = sum;
Expand All @@ -150,14 +157,17 @@ impl<const D: usize> Ldlt<D> {
for (i, x_i) in x.iter_mut().enumerate().take(D) {
let diag = self.factors.rows[i][i];
if !diag.is_finite() {
cold_path();
return Err(LaError::NonFinite { row: None, col: i });
}
if diag <= self.tol {
cold_path();
return Err(LaError::Singular { pivot_col: i });
}

let v = *x_i / diag;
if !v.is_finite() {
cold_path();
return Err(LaError::NonFinite { row: None, col: i });
}
*x_i = v;
Expand All @@ -171,6 +181,7 @@ impl<const D: usize> Ldlt<D> {
sum = (-self.factors.rows[j][i]).mul_add(*x_j, sum);
}
if !sum.is_finite() {
cold_path();
return Err(LaError::NonFinite { row: None, col: i });
}
x[i] = sum;
Expand Down Expand Up @@ -407,4 +418,19 @@ mod tests {
let err = ldlt.solve_vec(b).unwrap_err();
assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
}

#[test]
fn nonfinite_solve_vec_diagonal_solve_overflow() {
// Diagonal SPD matrix with a tiny diagonal entry just above the
// singularity tolerance. Forward substitution passes through the
// large RHS unchanged, then the diagonal solve z[1] = y[1] / D[1]
// = 1e300 / 1e-11 = 1e311 overflows f64, exercising the
// `!v.is_finite()` branch of the diagonal solve.
let a = Matrix::<2>::from_rows([[1.0, 0.0], [0.0, 1.0e-11]]);
let ldlt = a.ldlt(DEFAULT_SINGULAR_TOL).unwrap();

let b = Vector::<2>::new([0.0, 1.0e300]);
let err = ldlt.solve_vec(b).unwrap_err();
assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
}
}
27 changes: 27 additions & 0 deletions src/lu.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! LU decomposition and solves.

use core::hint::cold_path;

use crate::LaError;
use crate::matrix::Matrix;
use crate::vector::Vector;
Expand Down Expand Up @@ -31,6 +33,7 @@ impl<const D: usize> Lu<D> {
let mut pivot_row = k;
let mut pivot_abs = lu.rows[k][k].abs();
if !pivot_abs.is_finite() {
cold_path();
return Err(LaError::NonFinite {
row: Some(k),
col: k,
Expand All @@ -40,6 +43,7 @@ impl<const D: usize> Lu<D> {
for r in (k + 1)..D {
let v = lu.rows[r][k].abs();
if !v.is_finite() {
cold_path();
return Err(LaError::NonFinite {
row: Some(r),
col: k,
Expand All @@ -52,6 +56,7 @@ impl<const D: usize> Lu<D> {
}

if pivot_abs <= tol {
cold_path();
return Err(LaError::Singular { pivot_col: k });
}

Expand All @@ -63,6 +68,7 @@ impl<const D: usize> Lu<D> {

let pivot = lu.rows[k][k];
if !pivot.is_finite() {
cold_path();
return Err(LaError::NonFinite {
row: Some(k),
col: k,
Expand All @@ -73,6 +79,7 @@ impl<const D: usize> Lu<D> {
for r in (k + 1)..D {
let mult = lu.rows[r][k] / pivot;
if !mult.is_finite() {
cold_path();
return Err(LaError::NonFinite {
row: Some(r),
col: k,
Expand Down Expand Up @@ -132,6 +139,7 @@ impl<const D: usize> Lu<D> {
sum = (-row[j]).mul_add(*x_j, sum);
}
if !sum.is_finite() {
cold_path();
return Err(LaError::NonFinite { row: None, col: i });
}
x[i] = sum;
Expand All @@ -148,14 +156,17 @@ impl<const D: usize> Lu<D> {

let diag = row[i];
if !diag.is_finite() || !sum.is_finite() {
cold_path();
return Err(LaError::NonFinite { row: None, col: i });
}
if diag.abs() <= self.tol {
cold_path();
return Err(LaError::Singular { pivot_col: i });
}

let q = sum / diag;
if !q.is_finite() {
cold_path();
return Err(LaError::NonFinite { row: None, col: i });
}
x[i] = q;
Expand Down Expand Up @@ -474,4 +485,20 @@ mod tests {
let err = lu.solve_vec(b).unwrap_err();
assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
}

#[test]
fn solve_vec_nonfinite_back_substitution_sum_overflow() {
// Upper-triangular U with a very large off-diagonal in row 1 and a
// very large x[2] produced by the RHS. The back-substitution
// accumulator `sum = (-row[j]).mul_add(x[j], sum)` overflows while
// reducing row 1, so the failure is detected via the `!sum.is_finite()`
// branch of the combined diag/sum check (distinct from the
// `q = sum / diag` overflow path covered above).
let a = Matrix::<3>::from_rows([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0e200], [0.0, 0.0, 1.0]]);
let lu = a.lu(DEFAULT_PIVOT_TOL).unwrap();

let b = Vector::<3>::new([0.0, 0.0, 1.0e200]);
let err = lu.solve_vec(b).unwrap_err();
assert_eq!(err, LaError::NonFinite { row: None, col: 1 });
}
}
22 changes: 21 additions & 1 deletion src/matrix.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! Fixed-size, stack-allocated square matrices.

use core::hint::cold_path;

use crate::LaError;
use crate::ldlt::Ldlt;
use crate::lu::Lu;
Expand Down Expand Up @@ -266,7 +268,11 @@ impl<const D: usize> Matrix<D> {
(-r[0][1]).mul_add(c01, r[0][2].mul_add(c02, -(r[0][3] * c03))),
))
}
_ => None,
_ => {
// Cold in the common D ≤ 4 case; callers fall back to LU for D ≥ 5.
cold_path();
None
}
}
}

Expand Down Expand Up @@ -296,6 +302,7 @@ impl<const D: usize> Matrix<D> {
return if d.is_finite() {
Ok(d)
} else {
cold_path();
// Scan for the first non-finite entry to preserve coordinates.
for r in 0..D {
for c in 0..D {
Expand Down Expand Up @@ -703,6 +710,19 @@ mod tests {
);
}

#[test]
fn det_returns_nonfinite_error_for_overflow_with_finite_entries() {
// det_direct produces an overflowing f64 (1e300 * 1e300 = ∞) even
// though every matrix entry is finite. The entry scan in `det`
// falls through and returns NonFinite { row: None, col: 0 } to signal
// a computed overflow rather than a NaN/∞ input.
let m = Matrix::<2>::from_rows([[1e300, 0.0], [0.0, 1e300]]);
assert_eq!(
m.det(DEFAULT_PIVOT_TOL),
Err(LaError::NonFinite { row: None, col: 0 })
);
}

#[test]
fn det_direct_is_const_evaluable_d2() {
// Const evaluation proves the function is truly const fn.
Expand Down
Loading