From c138f6ba5acbd67ec192073bfd1d43bcd5f78710 Mon Sep 17 00:00:00 2001 From: Scott Arciszewski Date: Thu, 29 Jan 2026 15:18:12 -0500 Subject: [PATCH 1/2] ml-kem: add mutation testing and expand test coverage --- .github/workflows/mutation-testing.yml | 87 ++++ ml-kem/src/algebra.rs | 5 + ml-kem/src/kem.rs | 25 ++ ml-kem/src/pke.rs | 22 + module-lattice/tests/algebra.rs | 530 ++++++++++++++++++++++++- module-lattice/tests/encode.rs | 312 +++++++++++++++ 6 files changed, 980 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/mutation-testing.yml create mode 100644 module-lattice/tests/encode.rs diff --git a/.github/workflows/mutation-testing.yml b/.github/workflows/mutation-testing.yml new file mode 100644 index 0000000..3e68491 --- /dev/null +++ b/.github/workflows/mutation-testing.yml @@ -0,0 +1,87 @@ +name: Mutation Testing + +# Minimal permissions for security +permissions: + contents: read + +on: + # Run weekly on Sundays at 2 AM UTC + schedule: + - cron: "0 2 * * 0" + # Allow manual triggering + workflow_dispatch: + inputs: + package: + description: "Package to test (select 'all' for all packages)" + required: false + default: "all" + type: choice + options: + - all + - ml-kem + - module-lattice + - dhkem + - frodo-kem + - x-wing + +env: + CARGO_INCREMENTAL: 0 + CARGO_TERM_COLOR: always + +# Only run one mutation test at a time +concurrency: + group: mutation-testing + cancel-in-progress: false + +jobs: + mutants: + name: ${{ matrix.package }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + package: + - ml-kem + - module-lattice + # Only filter if a specific package was requested via workflow_dispatch + exclude: + - package: ${{ github.event.inputs.package != 'all' && github.event.inputs.package != 'ml-kem' && 'ml-kem' || 'NONE' }} + - package: ${{ github.event.inputs.package != 'all' && github.event.inputs.package != 'module-lattice' && 'module-lattice' || 'NONE' }} + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false + submodules: recursive + + - uses: dtolnay/rust-toolchain@f7ccc83f9ed1e5b9c81d8a67d7ad1a747e22a561 # stable + with: + toolchain: stable + + - uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8 + with: + prefix-key: mutants-${{ matrix.package }} + + - name: Install cargo-mutants + run: cargo install cargo-mutants@26.1.2 --locked + + - name: Run mutation testing + run: | + cargo mutants --package "${{ matrix.package }}" --all-features -j 2 --timeout 300 2>&1 | tee mutants-output.txt + # Extract summary for job summary + { + echo "## Mutation Testing Results: ${{ matrix.package }}" + echo "" + echo '```' + tail -5 mutants-output.txt + echo '```' + } >> "$GITHUB_STEP_SUMMARY" + + - name: Upload mutation results + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + if: always() + with: + name: mutants-${{ matrix.package }} + path: | + mutants.out/ + mutants-output.txt + retention-days: 30 diff --git a/ml-kem/src/algebra.rs b/ml-kem/src/algebra.rs index 79b568d..b1729cb 100644 --- a/ml-kem/src/algebra.rs +++ b/ml-kem/src/algebra.rs @@ -471,6 +471,11 @@ mod test { assert_eq!((&v1 * &v2), const_ntt(6)); assert_eq!((&v1 * &v3), const_ntt(9)); assert_eq!((&v2 * &v3), const_ntt(18)); + + // Verify inequality (catches PartialEq mutation that returns true unconditionally) + assert_ne!(v1, v2); + assert_ne!(v1, v3); + assert_ne!(v2, v3); } #[test] diff --git a/ml-kem/src/kem.rs b/ml-kem/src/kem.rs index 2644d92..4399082 100644 --- a/ml-kem/src/kem.rs +++ b/ml-kem/src/kem.rs @@ -440,4 +440,29 @@ mod test { seed_test::(); seed_test::(); } + + fn key_inequality_test

() + where + P: KemParams, + { + let mut rng = UnwrapErr(SysRng); + + // Generate two different keys + let dk1 = DecapsulationKey::

::generate_from_rng(&mut rng); + let dk2 = DecapsulationKey::

::generate_from_rng(&mut rng); + + let ek1 = dk1.encapsulation_key(); + let ek2 = dk2.encapsulation_key(); + + // Verify inequality (catches PartialEq mutation that returns true unconditionally) + assert_ne!(dk1, dk2); + assert_ne!(ek1, ek2); + } + + #[test] + fn key_inequality() { + key_inequality_test::(); + key_inequality_test::(); + key_inequality_test::(); + } } diff --git a/ml-kem/src/pke.rs b/ml-kem/src/pke.rs index 88007f2..a849b55 100644 --- a/ml-kem/src/pke.rs +++ b/ml-kem/src/pke.rs @@ -264,4 +264,26 @@ mod test { let invalid_key = [0xFF; 1184]; assert!(EncryptionKey::::from_bytes(&invalid_key.into()).is_err()); } + + fn key_inequality_test

() + where + P: PkeParams, + { + let mut rng = UnwrapErr(SysRng); + let d1 = B32::generate_from_rng(&mut rng); + let d2 = B32::generate_from_rng(&mut rng); + + let (dk1, _) = DecryptionKey::

::generate(&d1); + let (dk2, _) = DecryptionKey::

::generate(&d2); + + // Verify inequality (catches PartialEq mutation that returns true unconditionally) + assert_ne!(dk1, dk2); + } + + #[test] + fn key_inequality() { + key_inequality_test::(); + key_inequality_test::(); + key_inequality_test::(); + } } diff --git a/module-lattice/tests/algebra.rs b/module-lattice/tests/algebra.rs index 8cfa1bc..23d1908 100644 --- a/module-lattice/tests/algebra.rs +++ b/module-lattice/tests/algebra.rs @@ -1,6 +1,9 @@ //! Tests for the `algebra` module. -use module_lattice::algebra::Field; +use hybrid_array::typenum::U2; +use module_lattice::algebra::{ + Elem, Field, NttMatrix, NttPolynomial, NttVector, Polynomial, Vector, +}; // Field used by ML-KEM. module_lattice::define_field!(KyberField, u16, u32, u64, 3329); @@ -16,3 +19,528 @@ fn small_reduce() { assert_eq!(DilithiumField::small_reduce(8_380_416), 8_380_416); assert_eq!(DilithiumField::small_reduce(8_380_417), 0); } + +#[test] +fn barrett_reduce() { + // Test Barrett reduction produces values in correct range + assert_eq!(KyberField::barrett_reduce(0), 0); + assert_eq!(KyberField::barrett_reduce(3329), 0); + assert_eq!(KyberField::barrett_reduce(3328), 3328); + assert_eq!(KyberField::barrett_reduce(6658), 0); // 2 * 3329 + + // Large product that requires Barrett reduction + let product: u32 = 3000 * 3000; // 9_000_000 + let reduced = KyberField::barrett_reduce(product); + assert!(reduced < 3329); + assert_eq!(reduced, (product % 3329) as u16); + + // Test with Dilithium field + assert_eq!(DilithiumField::barrett_reduce(0), 0); + assert_eq!(DilithiumField::barrett_reduce(8_380_417), 0); +} + +// ======================================== +// Elem arithmetic tests +// ======================================== + +#[test] +fn elem_negation() { + let a: Elem = Elem::new(100); + let neg_a = -a; + // -100 mod 3329 = 3229 + assert_eq!(neg_a.0, 3229); + + // Double negation returns original + assert_eq!((-neg_a).0, 100); + + // Negation of zero is zero + let zero: Elem = Elem::new(0); + assert_eq!((-zero).0, 0); +} + +#[test] +fn elem_addition() { + let a: Elem = Elem::new(100); + let b: Elem = Elem::new(200); + let sum = a + b; + assert_eq!(sum.0, 300); + + // Test wraparound + let c: Elem = Elem::new(3300); + let d: Elem = Elem::new(100); + let wrapped = c + d; + assert_eq!(wrapped.0, 71); // (3300 + 100) % 3329 = 71 + + // Adding zero is identity + let zero: Elem = Elem::new(0); + assert_eq!((a + zero).0, 100); +} + +#[test] +fn elem_subtraction() { + let a: Elem = Elem::new(300); + let b: Elem = Elem::new(100); + let diff = a - b; + assert_eq!(diff.0, 200); + + // Test negative result wraps correctly + let c: Elem = Elem::new(100); + let d: Elem = Elem::new(300); + let wrapped = c - d; + // 100 - 300 = -200 mod 3329 = 3129 + assert_eq!(wrapped.0, 3129); + + // Subtracting zero is identity + let zero: Elem = Elem::new(0); + assert_eq!((a - zero).0, 300); + + // Subtracting self gives zero + assert_eq!((a - a).0, 0); +} + +#[test] +fn elem_multiplication() { + let a: Elem = Elem::new(100); + let b: Elem = Elem::new(200); + let prod = a * b; + assert_eq!(prod.0, (100 * 200) % 3329); + + // Multiply by one is identity + let one: Elem = Elem::new(1); + assert_eq!((a * one).0, 100); + + // Multiply by zero is zero + let zero: Elem = Elem::new(0); + assert_eq!((a * zero).0, 0); + + // Test large product requiring Barrett reduction + let c: Elem = Elem::new(3000); + let d: Elem = Elem::new(3000); + let large_prod = c * d; + assert_eq!(large_prod.0, ((3000u32 * 3000u32) % 3329) as u16); +} + +#[test] +fn elem_arithmetic_consistency() { + // Test: a + b - b = a + let a: Elem = Elem::new(1234); + let b: Elem = Elem::new(5678 % 3329); + assert_eq!((a + b - b).0, a.0); + + // Test: a - b + b = a + assert_eq!((a - b + b).0, a.0); + + // Test: a + (-a) = 0 + assert_eq!((a + (-a)).0, 0); +} + +// ======================================== +// Polynomial arithmetic tests +// ======================================== + +fn make_test_polynomial(base: F::Int) -> Polynomial +where + F::Int: From, +{ + let mut coeffs = [Elem::new(F::Int::from(0u8)); 256]; + for (i, c) in coeffs.iter_mut().enumerate().take(10) { + *c = Elem::new(base + F::Int::from(i as u8)); + } + Polynomial::new(coeffs.into()) +} + +#[test] +fn polynomial_addition() { + let p1 = make_test_polynomial::(100); + let p2 = make_test_polynomial::(200); + let sum = &p1 + &p2; + + // Check first few coefficients + assert_eq!(sum.0[0].0, 300); // 100 + 200 + assert_eq!(sum.0[1].0, 302); // 101 + 201 + assert_eq!(sum.0[9].0, 318); // 109 + 209 + + // Remaining coefficients should be 0 + assert_eq!(sum.0[10].0, 0); +} + +#[test] +fn polynomial_subtraction() { + let p1 = make_test_polynomial::(300); + let p2 = make_test_polynomial::(100); + let diff = &p1 - &p2; + + // Check first few coefficients + assert_eq!(diff.0[0].0, 200); // 300 - 100 + assert_eq!(diff.0[1].0, 200); // 301 - 101 +} + +#[test] +fn polynomial_negation() { + let p = make_test_polynomial::(100); + let neg_p = -&p; + + // Check negation: -100 mod 3329 = 3229 + assert_eq!(neg_p.0[0].0, 3229); + // -101 mod 3329 = 3228 + assert_eq!(neg_p.0[1].0, 3228); + + // Double negation returns original + let double_neg = -&neg_p; + assert_eq!(double_neg.0[0].0, p.0[0].0); +} + +#[test] +fn polynomial_scalar_multiplication() { + let p = make_test_polynomial::(100); + let scalar: Elem = Elem::new(3); + let scaled = scalar * &p; + + assert_eq!(scaled.0[0].0, 300); // 3 * 100 + assert_eq!(scaled.0[1].0, 303); // 3 * 101 +} + +// ======================================== +// Vector arithmetic tests +// ======================================== + +fn make_test_vector(base: F::Int) -> Vector +where + F::Int: From, +{ + let p1 = make_test_polynomial::(base); + let p2 = make_test_polynomial::(base + F::Int::from(50u8)); + Vector::new([p1, p2].into()) +} + +#[test] +fn vector_addition() { + let v1 = make_test_vector::(100); + let v2 = make_test_vector::(200); + let sum = &v1 + &v2; + + // First polynomial: 100+200=300, second: 150+250=400 + assert_eq!(sum.0[0].0[0].0, 300); + assert_eq!(sum.0[1].0[0].0, 400); +} + +#[test] +fn vector_addition_owned() { + let v1 = make_test_vector::(100); + let v2 = make_test_vector::(200); + let sum = v1 + v2; + + assert_eq!(sum.0[0].0[0].0, 300); + assert_eq!(sum.0[1].0[0].0, 400); +} + +#[test] +fn vector_subtraction() { + let v1 = make_test_vector::(300); + let v2 = make_test_vector::(100); + let diff = &v1 - &v2; + + // 300 - 100 = 200 + assert_eq!(diff.0[0].0[0].0, 200); + // 350 - 150 = 200 + assert_eq!(diff.0[1].0[0].0, 200); +} + +#[test] +fn vector_negation() { + let v = make_test_vector::(100); + let neg_v = -&v; + + // -100 mod 3329 = 3229 + assert_eq!(neg_v.0[0].0[0].0, 3229); +} + +#[test] +fn vector_scalar_multiplication() { + let v = make_test_vector::(100); + let scalar: Elem = Elem::new(2); + let scaled = scalar * &v; + + assert_eq!(scaled.0[0].0[0].0, 200); // 2 * 100 + assert_eq!(scaled.0[1].0[0].0, 300); // 2 * 150 +} + +// ======================================== +// NttPolynomial arithmetic tests +// ======================================== + +fn make_test_ntt_polynomial(base: F::Int) -> NttPolynomial +where + F::Int: From, +{ + let mut coeffs = [Elem::new(F::Int::from(0u8)); 256]; + for (i, c) in coeffs.iter_mut().enumerate().take(10) { + *c = Elem::new(base + F::Int::from(i as u8)); + } + NttPolynomial::new(coeffs.into()) +} + +#[test] +fn ntt_polynomial_addition() { + let p1 = make_test_ntt_polynomial::(100); + let p2 = make_test_ntt_polynomial::(200); + let sum = &p1 + &p2; + + assert_eq!(sum.0[0].0, 300); + assert_eq!(sum.0[1].0, 302); +} + +#[test] +fn ntt_polynomial_subtraction() { + let p1 = make_test_ntt_polynomial::(300); + let p2 = make_test_ntt_polynomial::(100); + let diff = &p1 - &p2; + + assert_eq!(diff.0[0].0, 200); +} + +#[test] +fn ntt_polynomial_negation() { + let p = make_test_ntt_polynomial::(100); + let neg_p = -&p; + + assert_eq!(neg_p.0[0].0, 3229); // -100 mod 3329 +} + +#[test] +fn ntt_polynomial_scalar_multiplication() { + let p = make_test_ntt_polynomial::(100); + let scalar: Elem = Elem::new(3); + let scaled = scalar * &p; + + assert_eq!(scaled.0[0].0, 300); +} + +#[test] +fn ntt_polynomial_from_array() { + use hybrid_array::Array; + + let coeffs: Array, hybrid_array::typenum::U256> = + core::array::from_fn(|i| Elem::new((i % 3329) as u16)).into(); + let p: NttPolynomial = coeffs.clone().into(); + + assert_eq!(p.0[0].0, 0); + assert_eq!(p.0[1].0, 1); + + // Convert back + let arr: Array, hybrid_array::typenum::U256> = p.into(); + assert_eq!(arr[0].0, coeffs[0].0); +} + +// ======================================== +// NttVector arithmetic tests +// ======================================== + +fn make_test_ntt_vector(base: F::Int) -> NttVector +where + F::Int: From, +{ + let p1 = make_test_ntt_polynomial::(base); + let p2 = make_test_ntt_polynomial::(base + F::Int::from(50u8)); + NttVector::new([p1, p2].into()) +} + +#[test] +fn ntt_vector_addition() { + let v1 = make_test_ntt_vector::(100); + let v2 = make_test_ntt_vector::(200); + let sum = &v1 + &v2; + + assert_eq!(sum.0[0].0[0].0, 300); + assert_eq!(sum.0[1].0[0].0, 400); +} + +#[test] +fn ntt_vector_subtraction() { + let v1 = make_test_ntt_vector::(300); + let v2 = make_test_ntt_vector::(100); + let diff = &v1 - &v2; + + assert_eq!(diff.0[0].0[0].0, 200); + assert_eq!(diff.0[1].0[0].0, 200); +} + +// ======================================== +// PartialEq tests (to catch == vs != mutations) +// ======================================== + +#[test] +fn elem_equality() { + let a: Elem = Elem::new(100); + let b: Elem = Elem::new(100); + let c: Elem = Elem::new(200); + + assert_eq!(a, b); + assert_ne!(a, c); +} + +#[test] +fn polynomial_equality() { + let p1 = make_test_polynomial::(100); + let p2 = make_test_polynomial::(100); + let p3 = make_test_polynomial::(200); + + assert_eq!(p1, p2); + assert_ne!(p1, p3); +} + +#[test] +fn vector_equality() { + let v1 = make_test_vector::(100); + let v2 = make_test_vector::(100); + let v3 = make_test_vector::(200); + + assert_eq!(v1, v2); + assert_ne!(v1, v3); +} + +#[test] +fn ntt_polynomial_equality() { + let p1 = make_test_ntt_polynomial::(100); + let p2 = make_test_ntt_polynomial::(100); + let p3 = make_test_ntt_polynomial::(200); + + assert_eq!(p1, p2); + assert_ne!(p1, p3); +} + +#[test] +fn ntt_vector_equality() { + let v1 = make_test_ntt_vector::(100); + let v2 = make_test_ntt_vector::(100); + let v3 = make_test_ntt_vector::(200); + + assert_eq!(v1, v2); + assert_ne!(v1, v3); +} + +#[test] +fn ntt_matrix_equality() { + let v1 = make_test_ntt_vector::(100); + let v2 = make_test_ntt_vector::(150); + let m1: NttMatrix = NttMatrix::new([v1.clone(), v2.clone()].into()); + let m2: NttMatrix = NttMatrix::new([v1.clone(), v2.clone()].into()); + + let v3 = make_test_ntt_vector::(200); + let m3: NttMatrix = NttMatrix::new([v1, v3].into()); + + assert_eq!(m1, m2); + assert_ne!(m1, m3); +} + +#[test] +fn ntt_polynomial_into_array() { + use hybrid_array::Array; + use hybrid_array::typenum::U256; + + let p = make_test_ntt_polynomial::(100); + + // Convert to array and verify contents match + let arr: Array, U256> = p.clone().into(); + assert_eq!(arr[0].0, 100); + assert_eq!(arr[1].0, 101); + assert_eq!(arr[9].0, 109); + assert_eq!(arr[10].0, 0); + + // Verify conversion preserves all data + for i in 0..256 { + assert_eq!(arr[i].0, p.0[i].0); + } +} + +// ======================================== +// Zeroize tests (require zeroize feature) +// ======================================== + +#[cfg(feature = "zeroize")] +mod zeroize_tests { + use super::*; + use zeroize::Zeroize; + + #[test] + fn elem_zeroize() { + let mut a: Elem = Elem::new(1234); + assert_ne!(a.0, 0); + a.zeroize(); + assert_eq!(a.0, 0); + } + + #[test] + fn polynomial_zeroize() { + let mut p = make_test_polynomial::(100); + assert_ne!(p.0[0].0, 0); + p.zeroize(); + for i in 0..256 { + assert_eq!(p.0[i].0, 0, "Coefficient {} not zeroed", i); + } + } + + #[test] + fn vector_zeroize() { + let mut v = make_test_vector::(100); + assert_ne!(v.0[0].0[0].0, 0); + v.zeroize(); + for i in 0..2 { + for j in 0..256 { + assert_eq!(v.0[i].0[j].0, 0, "Element [{i}][{j}] not zeroed"); + } + } + } + + #[test] + fn ntt_polynomial_zeroize() { + let mut p = make_test_ntt_polynomial::(100); + assert_ne!(p.0[0].0, 0); + p.zeroize(); + for i in 0..256 { + assert_eq!(p.0[i].0, 0, "Coefficient {} not zeroed", i); + } + } + + #[test] + fn ntt_vector_zeroize() { + let mut v = make_test_ntt_vector::(100); + assert_ne!(v.0[0].0[0].0, 0); + v.zeroize(); + for i in 0..2 { + for j in 0..256 { + assert_eq!(v.0[i].0[j].0, 0, "Element [{i}][{j}] not zeroed"); + } + } + } +} + +// ======================================== +// ConstantTimeEq tests (require subtle feature) +// ======================================== + +#[cfg(feature = "subtle")] +mod subtle_tests { + use super::*; + use subtle::ConstantTimeEq; + + #[test] + fn elem_ct_eq() { + let a: Elem = Elem::new(100); + let b: Elem = Elem::new(100); + let c: Elem = Elem::new(200); + + assert!(bool::from(a.ct_eq(&b))); + assert!(!bool::from(a.ct_eq(&c))); + } + + #[test] + fn ntt_polynomial_ct_eq() { + let p1 = make_test_ntt_polynomial::(100); + let p2 = make_test_ntt_polynomial::(100); + let p3 = make_test_ntt_polynomial::(200); + + assert!(bool::from(p1.ct_eq(&p2))); + assert!(!bool::from(p1.ct_eq(&p3))); + } +} diff --git a/module-lattice/tests/encode.rs b/module-lattice/tests/encode.rs new file mode 100644 index 0000000..2777d42 --- /dev/null +++ b/module-lattice/tests/encode.rs @@ -0,0 +1,312 @@ +//! Tests for the `encode` module. + +use hybrid_array::typenum::{U1, U4, U10, U12}; +use module_lattice::algebra::{Elem, NttPolynomial, NttVector, Polynomial, Vector}; +use module_lattice::encode::{byte_decode, byte_encode, Encode}; + +// Field used by ML-KEM. +module_lattice::define_field!(KyberField, u16, u32, u64, 3329); + +// ======================================== +// byte_encode / byte_decode round-trip tests +// ======================================== + +#[test] +fn byte_encode_decode_d1_roundtrip() { + // D=1: Single bit encoding + let mut vals = [Elem::::new(0); 256]; + for i in 0..256 { + vals[i] = Elem::new((i % 2) as u16); + } + + let encoded = byte_encode::(&vals.into()); + let decoded = byte_decode::(&encoded); + + for i in 0..256 { + assert_eq!(decoded[i].0, vals[i].0, "Mismatch at index {i}"); + } +} + +#[test] +fn byte_encode_decode_d4_roundtrip() { + // D=4: 4-bit encoding + let mut vals = [Elem::::new(0); 256]; + for i in 0..256 { + vals[i] = Elem::new((i % 16) as u16); + } + + let encoded = byte_encode::(&vals.into()); + let decoded = byte_decode::(&encoded); + + for i in 0..256 { + assert_eq!(decoded[i].0, vals[i].0, "Mismatch at index {i}"); + } +} + +#[test] +fn byte_encode_decode_d10_roundtrip() { + // D=10: 10-bit encoding + let mut vals = [Elem::::new(0); 256]; + for i in 0..256 { + vals[i] = Elem::new((i % 1024) as u16); + } + + let encoded = byte_encode::(&vals.into()); + let decoded = byte_decode::(&encoded); + + for i in 0..256 { + assert_eq!(decoded[i].0, vals[i].0, "Mismatch at index {i}"); + } +} + +#[test] +fn byte_encode_decode_d12_roundtrip() { + // D=12: 12-bit encoding (special case with modular reduction) + let mut vals = [Elem::::new(0); 256]; + for i in 0..256 { + // Values up to q-1 (3328) + vals[i] = Elem::new((i * 13) as u16 % 3329); + } + + let encoded = byte_encode::(&vals.into()); + let decoded = byte_decode::(&encoded); + + for i in 0..256 { + assert_eq!(decoded[i].0, vals[i].0, "Mismatch at index {i}"); + } +} + +#[test] +fn byte_encode_decode_d12_modular_reduction() { + // Test that D=12 properly reduces values >= Q + let mut vals = [Elem::::new(0); 256]; + + // Fill with values near and above Q + for i in 0..256 { + vals[i] = Elem::new(3329 + (i as u16) % 100); // Values >= Q + } + + let encoded = byte_encode::(&vals.into()); + let decoded = byte_decode::(&encoded); + + // After decode, values should be reduced mod Q + for i in 0..256 { + assert!( + decoded[i].0 < 3329, + "Value at {i} not reduced: {}", + decoded[i].0 + ); + } +} + +#[test] +fn byte_encode_zero_values() { + let vals = [Elem::::new(0); 256]; + + let encoded = byte_encode::(&vals.into()); + let decoded = byte_decode::(&encoded); + + for i in 0..256 { + assert_eq!(decoded[i].0, 0); + } +} + +#[test] +fn byte_encode_max_values() { + // D=4: max value is 15 + let vals = [Elem::::new(15); 256]; + + let encoded = byte_encode::(&vals.into()); + let decoded = byte_decode::(&encoded); + + for i in 0..256 { + assert_eq!(decoded[i].0, 15); + } +} + +// ======================================== +// Polynomial encoding tests +// ======================================== + +#[test] +fn polynomial_encode_decode_roundtrip() { + let mut coeffs = [Elem::::new(0); 256]; + for i in 0..256 { + coeffs[i] = Elem::new((i * 7) as u16 % 16); + } + let p = Polynomial::::new(coeffs.into()); + + let encoded = as Encode>::encode(&p); + let decoded = as Encode>::decode(&encoded); + + assert_eq!(p, decoded); +} + +#[test] +fn polynomial_encode_decode_d12() { + let mut coeffs = [Elem::::new(0); 256]; + for i in 0..256 { + coeffs[i] = Elem::new((i * 13) as u16 % 3329); + } + let p = Polynomial::::new(coeffs.into()); + + let encoded = as Encode>::encode(&p); + let decoded = as Encode>::decode(&encoded); + + assert_eq!(p, decoded); +} + +// ======================================== +// Vector encoding tests +// ======================================== + +#[test] +fn vector_encode_decode_roundtrip() { + use hybrid_array::typenum::U2; + + let mut coeffs1 = [Elem::::new(0); 256]; + let mut coeffs2 = [Elem::::new(0); 256]; + for i in 0..256 { + coeffs1[i] = Elem::new((i * 3) as u16 % 16); + coeffs2[i] = Elem::new((i * 5) as u16 % 16); + } + + let p1 = Polynomial::::new(coeffs1.into()); + let p2 = Polynomial::::new(coeffs2.into()); + let v: Vector = Vector::new([p1, p2].into()); + + let encoded = as Encode>::encode(&v); + let decoded = as Encode>::decode(&encoded); + + assert_eq!(v, decoded); +} + +// ======================================== +// NttPolynomial encoding tests +// ======================================== + +#[test] +fn ntt_polynomial_encode_decode_roundtrip() { + let mut coeffs = [Elem::::new(0); 256]; + for i in 0..256 { + coeffs[i] = Elem::new((i * 7) as u16 % 16); + } + let p = NttPolynomial::::new(coeffs.into()); + + let encoded = as Encode>::encode(&p); + let decoded = as Encode>::decode(&encoded); + + assert_eq!(p, decoded); +} + +#[test] +fn ntt_polynomial_encode_decode_d12() { + let mut coeffs = [Elem::::new(0); 256]; + for i in 0..256 { + coeffs[i] = Elem::new((i * 13) as u16 % 3329); + } + let p = NttPolynomial::::new(coeffs.into()); + + let encoded = as Encode>::encode(&p); + let decoded = as Encode>::decode(&encoded); + + assert_eq!(p, decoded); +} + +// ======================================== +// NttVector encoding tests +// ======================================== + +#[test] +fn ntt_vector_encode_decode_roundtrip() { + use hybrid_array::typenum::U2; + + let mut coeffs1 = [Elem::::new(0); 256]; + let mut coeffs2 = [Elem::::new(0); 256]; + for i in 0..256 { + coeffs1[i] = Elem::new((i * 3) as u16 % 16); + coeffs2[i] = Elem::new((i * 5) as u16 % 16); + } + + let p1 = NttPolynomial::::new(coeffs1.into()); + let p2 = NttPolynomial::::new(coeffs2.into()); + let v: NttVector = NttVector::new([p1, p2].into()); + + let encoded = as Encode>::encode(&v); + let decoded = as Encode>::decode(&encoded); + + assert_eq!(v, decoded); +} + +// ======================================== +// Encoding size verification +// ======================================== + +#[test] +fn encoded_polynomial_size_d4() { + // D=4 means 4 bits per coefficient, 256 coefficients = 1024 bits = 128 bytes + let coeffs = [Elem::::new(0); 256]; + let p = Polynomial::::new(coeffs.into()); + + let encoded = as Encode>::encode(&p); + assert_eq!(encoded.len(), 128); +} + +#[test] +fn encoded_polynomial_size_d12() { + // D=12 means 12 bits per coefficient, 256 coefficients = 3072 bits = 384 bytes + let coeffs = [Elem::::new(0); 256]; + let p = Polynomial::::new(coeffs.into()); + + let encoded = as Encode>::encode(&p); + assert_eq!(encoded.len(), 384); +} + +#[test] +fn encoded_vector_size() { + use hybrid_array::typenum::U3; + + // D=4, K=3: 128 bytes per polynomial * 3 = 384 bytes + let coeffs = [Elem::::new(0); 256]; + let p = Polynomial::::new(coeffs.into()); + let v: Vector = Vector::new([p.clone(), p.clone(), p].into()); + + let encoded = as Encode>::encode(&v); + assert_eq!(encoded.len(), 384); +} + +// ======================================== +// Edge cases and boundary tests +// ======================================== + +#[test] +fn byte_encode_alternating_bits() { + // Test alternating patterns to catch bit manipulation issues + let mut vals = [Elem::::new(0); 256]; + for i in 0..256 { + vals[i] = Elem::new(if i % 2 == 0 { 0b0101 } else { 0b1010 }); + } + + let encoded = byte_encode::(&vals.into()); + let decoded = byte_decode::(&encoded); + + for i in 0..256 { + assert_eq!(decoded[i].0, vals[i].0, "Mismatch at index {i}"); + } +} + +#[test] +fn byte_encode_sequential_values() { + // Sequential values to catch ordering issues + let mut vals = [Elem::::new(0); 256]; + for i in 0..256 { + vals[i] = Elem::new(i as u16 % 16); + } + + let encoded = byte_encode::(&vals.into()); + let decoded = byte_decode::(&encoded); + + for i in 0..256 { + assert_eq!(decoded[i].0, vals[i].0, "Mismatch at index {i}"); + } +} From 69ea52355a6a8b41d0a97d46397a5f3694f2e5f7 Mon Sep 17 00:00:00 2001 From: Scott Arciszewski Date: Thu, 29 Jan 2026 15:25:43 -0500 Subject: [PATCH 2/2] module-lattice: fix ci failures --- .github/workflows/mutation-testing.yml | 11 ++- module-lattice/tests/algebra.rs | 2 +- module-lattice/tests/encode.rs | 122 ++++++++----------------- 3 files changed, 50 insertions(+), 85 deletions(-) diff --git a/.github/workflows/mutation-testing.yml b/.github/workflows/mutation-testing.yml index 3e68491..6e68df0 100644 --- a/.github/workflows/mutation-testing.yml +++ b/.github/workflows/mutation-testing.yml @@ -43,10 +43,17 @@ jobs: package: - ml-kem - module-lattice + - dhkem + - frodo-kem + - x-wing # Only filter if a specific package was requested via workflow_dispatch + # For push/pull_request events, inputs.package is undefined, so default to 'all' exclude: - - package: ${{ github.event.inputs.package != 'all' && github.event.inputs.package != 'ml-kem' && 'ml-kem' || 'NONE' }} - - package: ${{ github.event.inputs.package != 'all' && github.event.inputs.package != 'module-lattice' && 'module-lattice' || 'NONE' }} + - package: ${{ (github.event.inputs.package || 'all') != 'all' && (github.event.inputs.package || 'all') != 'ml-kem' && 'ml-kem' || 'NONE' }} + - package: ${{ (github.event.inputs.package || 'all') != 'all' && (github.event.inputs.package || 'all') != 'module-lattice' && 'module-lattice' || 'NONE' }} + - package: ${{ (github.event.inputs.package || 'all') != 'all' && (github.event.inputs.package || 'all') != 'dhkem' && 'dhkem' || 'NONE' }} + - package: ${{ (github.event.inputs.package || 'all') != 'all' && (github.event.inputs.package || 'all') != 'frodo-kem' && 'frodo-kem' || 'NONE' }} + - package: ${{ (github.event.inputs.package || 'all') != 'all' && (github.event.inputs.package || 'all') != 'x-wing' && 'x-wing' || 'NONE' }} steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: diff --git a/module-lattice/tests/algebra.rs b/module-lattice/tests/algebra.rs index 23d1908..2c0ce3f 100644 --- a/module-lattice/tests/algebra.rs +++ b/module-lattice/tests/algebra.rs @@ -322,7 +322,7 @@ fn ntt_polynomial_from_array() { let coeffs: Array, hybrid_array::typenum::U256> = core::array::from_fn(|i| Elem::new((i % 3329) as u16)).into(); - let p: NttPolynomial = coeffs.clone().into(); + let p: NttPolynomial = coeffs.into(); assert_eq!(p.0[0].0, 0); assert_eq!(p.0[1].0, 1); diff --git a/module-lattice/tests/encode.rs b/module-lattice/tests/encode.rs index 2777d42..31c30a4 100644 --- a/module-lattice/tests/encode.rs +++ b/module-lattice/tests/encode.rs @@ -2,7 +2,7 @@ use hybrid_array::typenum::{U1, U4, U10, U12}; use module_lattice::algebra::{Elem, NttPolynomial, NttVector, Polynomial, Vector}; -use module_lattice::encode::{byte_decode, byte_encode, Encode}; +use module_lattice::encode::{Encode, byte_decode, byte_encode}; // Field used by ML-KEM. module_lattice::define_field!(KyberField, u16, u32, u64, 3329); @@ -14,88 +14,69 @@ module_lattice::define_field!(KyberField, u16, u32, u64, 3329); #[test] fn byte_encode_decode_d1_roundtrip() { // D=1: Single bit encoding - let mut vals = [Elem::::new(0); 256]; - for i in 0..256 { - vals[i] = Elem::new((i % 2) as u16); - } + let vals: [Elem; 256] = core::array::from_fn(|i| Elem::new((i % 2) as u16)); let encoded = byte_encode::(&vals.into()); let decoded = byte_decode::(&encoded); - for i in 0..256 { - assert_eq!(decoded[i].0, vals[i].0, "Mismatch at index {i}"); + for (i, (dec, val)) in decoded.iter().zip(vals.iter()).enumerate() { + assert_eq!(dec.0, val.0, "Mismatch at index {i}"); } } #[test] fn byte_encode_decode_d4_roundtrip() { // D=4: 4-bit encoding - let mut vals = [Elem::::new(0); 256]; - for i in 0..256 { - vals[i] = Elem::new((i % 16) as u16); - } + let vals: [Elem; 256] = core::array::from_fn(|i| Elem::new((i % 16) as u16)); let encoded = byte_encode::(&vals.into()); let decoded = byte_decode::(&encoded); - for i in 0..256 { - assert_eq!(decoded[i].0, vals[i].0, "Mismatch at index {i}"); + for (i, (dec, val)) in decoded.iter().zip(vals.iter()).enumerate() { + assert_eq!(dec.0, val.0, "Mismatch at index {i}"); } } #[test] fn byte_encode_decode_d10_roundtrip() { // D=10: 10-bit encoding - let mut vals = [Elem::::new(0); 256]; - for i in 0..256 { - vals[i] = Elem::new((i % 1024) as u16); - } + let vals: [Elem; 256] = core::array::from_fn(|i| Elem::new((i % 1024) as u16)); let encoded = byte_encode::(&vals.into()); let decoded = byte_decode::(&encoded); - for i in 0..256 { - assert_eq!(decoded[i].0, vals[i].0, "Mismatch at index {i}"); + for (i, (dec, val)) in decoded.iter().zip(vals.iter()).enumerate() { + assert_eq!(dec.0, val.0, "Mismatch at index {i}"); } } #[test] fn byte_encode_decode_d12_roundtrip() { // D=12: 12-bit encoding (special case with modular reduction) - let mut vals = [Elem::::new(0); 256]; - for i in 0..256 { - // Values up to q-1 (3328) - vals[i] = Elem::new((i * 13) as u16 % 3329); - } + // Values up to q-1 (3328) + let vals: [Elem; 256] = core::array::from_fn(|i| Elem::new((i * 13) as u16 % 3329)); let encoded = byte_encode::(&vals.into()); let decoded = byte_decode::(&encoded); - for i in 0..256 { - assert_eq!(decoded[i].0, vals[i].0, "Mismatch at index {i}"); + for (i, (dec, val)) in decoded.iter().zip(vals.iter()).enumerate() { + assert_eq!(dec.0, val.0, "Mismatch at index {i}"); } } #[test] fn byte_encode_decode_d12_modular_reduction() { // Test that D=12 properly reduces values >= Q - let mut vals = [Elem::::new(0); 256]; - // Fill with values near and above Q - for i in 0..256 { - vals[i] = Elem::new(3329 + (i as u16) % 100); // Values >= Q - } + let vals: [Elem; 256] = + core::array::from_fn(|i| Elem::new(3329 + (i as u16) % 100)); let encoded = byte_encode::(&vals.into()); let decoded = byte_decode::(&encoded); // After decode, values should be reduced mod Q - for i in 0..256 { - assert!( - decoded[i].0 < 3329, - "Value at {i} not reduced: {}", - decoded[i].0 - ); + for (i, dec) in decoded.iter().enumerate() { + assert!(dec.0 < 3329, "Value at {i} not reduced: {}", dec.0); } } @@ -106,8 +87,8 @@ fn byte_encode_zero_values() { let encoded = byte_encode::(&vals.into()); let decoded = byte_decode::(&encoded); - for i in 0..256 { - assert_eq!(decoded[i].0, 0); + for dec in &decoded { + assert_eq!(dec.0, 0); } } @@ -119,8 +100,8 @@ fn byte_encode_max_values() { let encoded = byte_encode::(&vals.into()); let decoded = byte_decode::(&encoded); - for i in 0..256 { - assert_eq!(decoded[i].0, 15); + for dec in &decoded { + assert_eq!(dec.0, 15); } } @@ -130,10 +111,7 @@ fn byte_encode_max_values() { #[test] fn polynomial_encode_decode_roundtrip() { - let mut coeffs = [Elem::::new(0); 256]; - for i in 0..256 { - coeffs[i] = Elem::new((i * 7) as u16 % 16); - } + let coeffs: [Elem; 256] = core::array::from_fn(|i| Elem::new((i * 7) as u16 % 16)); let p = Polynomial::::new(coeffs.into()); let encoded = as Encode>::encode(&p); @@ -144,10 +122,8 @@ fn polynomial_encode_decode_roundtrip() { #[test] fn polynomial_encode_decode_d12() { - let mut coeffs = [Elem::::new(0); 256]; - for i in 0..256 { - coeffs[i] = Elem::new((i * 13) as u16 % 3329); - } + let coeffs: [Elem; 256] = + core::array::from_fn(|i| Elem::new((i * 13) as u16 % 3329)); let p = Polynomial::::new(coeffs.into()); let encoded = as Encode>::encode(&p); @@ -164,12 +140,8 @@ fn polynomial_encode_decode_d12() { fn vector_encode_decode_roundtrip() { use hybrid_array::typenum::U2; - let mut coeffs1 = [Elem::::new(0); 256]; - let mut coeffs2 = [Elem::::new(0); 256]; - for i in 0..256 { - coeffs1[i] = Elem::new((i * 3) as u16 % 16); - coeffs2[i] = Elem::new((i * 5) as u16 % 16); - } + let coeffs1: [Elem; 256] = core::array::from_fn(|i| Elem::new((i * 3) as u16 % 16)); + let coeffs2: [Elem; 256] = core::array::from_fn(|i| Elem::new((i * 5) as u16 % 16)); let p1 = Polynomial::::new(coeffs1.into()); let p2 = Polynomial::::new(coeffs2.into()); @@ -187,10 +159,7 @@ fn vector_encode_decode_roundtrip() { #[test] fn ntt_polynomial_encode_decode_roundtrip() { - let mut coeffs = [Elem::::new(0); 256]; - for i in 0..256 { - coeffs[i] = Elem::new((i * 7) as u16 % 16); - } + let coeffs: [Elem; 256] = core::array::from_fn(|i| Elem::new((i * 7) as u16 % 16)); let p = NttPolynomial::::new(coeffs.into()); let encoded = as Encode>::encode(&p); @@ -201,10 +170,8 @@ fn ntt_polynomial_encode_decode_roundtrip() { #[test] fn ntt_polynomial_encode_decode_d12() { - let mut coeffs = [Elem::::new(0); 256]; - for i in 0..256 { - coeffs[i] = Elem::new((i * 13) as u16 % 3329); - } + let coeffs: [Elem; 256] = + core::array::from_fn(|i| Elem::new((i * 13) as u16 % 3329)); let p = NttPolynomial::::new(coeffs.into()); let encoded = as Encode>::encode(&p); @@ -221,12 +188,8 @@ fn ntt_polynomial_encode_decode_d12() { fn ntt_vector_encode_decode_roundtrip() { use hybrid_array::typenum::U2; - let mut coeffs1 = [Elem::::new(0); 256]; - let mut coeffs2 = [Elem::::new(0); 256]; - for i in 0..256 { - coeffs1[i] = Elem::new((i * 3) as u16 % 16); - coeffs2[i] = Elem::new((i * 5) as u16 % 16); - } + let coeffs1: [Elem; 256] = core::array::from_fn(|i| Elem::new((i * 3) as u16 % 16)); + let coeffs2: [Elem; 256] = core::array::from_fn(|i| Elem::new((i * 5) as u16 % 16)); let p1 = NttPolynomial::::new(coeffs1.into()); let p2 = NttPolynomial::::new(coeffs2.into()); @@ -269,7 +232,7 @@ fn encoded_vector_size() { // D=4, K=3: 128 bytes per polynomial * 3 = 384 bytes let coeffs = [Elem::::new(0); 256]; let p = Polynomial::::new(coeffs.into()); - let v: Vector = Vector::new([p.clone(), p.clone(), p].into()); + let v: Vector = Vector::new([p, p, p].into()); let encoded = as Encode>::encode(&v); assert_eq!(encoded.len(), 384); @@ -282,31 +245,26 @@ fn encoded_vector_size() { #[test] fn byte_encode_alternating_bits() { // Test alternating patterns to catch bit manipulation issues - let mut vals = [Elem::::new(0); 256]; - for i in 0..256 { - vals[i] = Elem::new(if i % 2 == 0 { 0b0101 } else { 0b1010 }); - } + let vals: [Elem; 256] = + core::array::from_fn(|i| Elem::new(if i % 2 == 0 { 0b0101 } else { 0b1010 })); let encoded = byte_encode::(&vals.into()); let decoded = byte_decode::(&encoded); - for i in 0..256 { - assert_eq!(decoded[i].0, vals[i].0, "Mismatch at index {i}"); + for (i, (dec, val)) in decoded.iter().zip(vals.iter()).enumerate() { + assert_eq!(dec.0, val.0, "Mismatch at index {i}"); } } #[test] fn byte_encode_sequential_values() { // Sequential values to catch ordering issues - let mut vals = [Elem::::new(0); 256]; - for i in 0..256 { - vals[i] = Elem::new(i as u16 % 16); - } + let vals: [Elem; 256] = core::array::from_fn(|i| Elem::new(i as u16 % 16)); let encoded = byte_encode::(&vals.into()); let decoded = byte_decode::(&encoded); - for i in 0..256 { - assert_eq!(decoded[i].0, vals[i].0, "Mismatch at index {i}"); + for (i, (dec, val)) in decoded.iter().zip(vals.iter()).enumerate() { + assert_eq!(dec.0, val.0, "Mismatch at index {i}"); } }