diff --git a/.github/workflows/mutation-testing.yml b/.github/workflows/mutation-testing.yml new file mode 100644 index 0000000..6e68df0 --- /dev/null +++ b/.github/workflows/mutation-testing.yml @@ -0,0 +1,94 @@ +name: Mutation Testing + +# Minimal permissions for security +permissions: + contents: read + +on: + # Run weekly on Sundays at 2 AM UTC + schedule: + - cron: "0 2 * * 0" + # Allow manual triggering + workflow_dispatch: + inputs: + package: + description: "Package to test (select 'all' for all packages)" + required: false + default: "all" + type: choice + options: + - all + - ml-kem + - module-lattice + - dhkem + - frodo-kem + - x-wing + +env: + CARGO_INCREMENTAL: 0 + CARGO_TERM_COLOR: always + +# Only run one mutation test at a time +concurrency: + group: mutation-testing + cancel-in-progress: false + +jobs: + mutants: + name: ${{ matrix.package }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + package: + - ml-kem + - module-lattice + - dhkem + - frodo-kem + - x-wing + # Only filter if a specific package was requested via workflow_dispatch + # For push/pull_request events, inputs.package is undefined, so default to 'all' + exclude: + - package: ${{ (github.event.inputs.package || 'all') != 'all' && (github.event.inputs.package || 'all') != 'ml-kem' && 'ml-kem' || 'NONE' }} + - package: ${{ (github.event.inputs.package || 'all') != 'all' && (github.event.inputs.package || 'all') != 'module-lattice' && 'module-lattice' || 'NONE' }} + - package: ${{ (github.event.inputs.package || 'all') != 'all' && (github.event.inputs.package || 'all') != 'dhkem' && 'dhkem' || 'NONE' }} + - package: ${{ (github.event.inputs.package || 'all') != 'all' && (github.event.inputs.package || 'all') != 'frodo-kem' && 'frodo-kem' || 'NONE' }} + - package: ${{ (github.event.inputs.package || 'all') != 'all' && (github.event.inputs.package || 'all') != 'x-wing' && 'x-wing' || 'NONE' }} + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false + submodules: recursive + + - uses: dtolnay/rust-toolchain@f7ccc83f9ed1e5b9c81d8a67d7ad1a747e22a561 # stable + with: + toolchain: stable + + - uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8 + with: + prefix-key: mutants-${{ matrix.package }} + + - name: Install cargo-mutants + run: cargo install cargo-mutants@26.1.2 --locked + + - name: Run mutation testing + run: | + cargo mutants --package "${{ matrix.package }}" --all-features -j 2 --timeout 300 2>&1 | tee mutants-output.txt + # Extract summary for job summary + { + echo "## Mutation Testing Results: ${{ matrix.package }}" + echo "" + echo '```' + tail -5 mutants-output.txt + echo '```' + } >> "$GITHUB_STEP_SUMMARY" + + - name: Upload mutation results + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + if: always() + with: + name: mutants-${{ matrix.package }} + path: | + mutants.out/ + mutants-output.txt + retention-days: 30 diff --git a/ml-kem/src/algebra.rs b/ml-kem/src/algebra.rs index 79b568d..b1729cb 100644 --- a/ml-kem/src/algebra.rs +++ b/ml-kem/src/algebra.rs @@ -471,6 +471,11 @@ mod test { assert_eq!((&v1 * &v2), const_ntt(6)); assert_eq!((&v1 * &v3), const_ntt(9)); assert_eq!((&v2 * &v3), const_ntt(18)); + + // Verify inequality (catches PartialEq mutation that returns true unconditionally) + assert_ne!(v1, v2); + assert_ne!(v1, v3); + assert_ne!(v2, v3); } #[test] diff --git a/ml-kem/src/kem.rs b/ml-kem/src/kem.rs index 2644d92..4399082 100644 --- a/ml-kem/src/kem.rs +++ b/ml-kem/src/kem.rs @@ -440,4 +440,29 @@ mod test { seed_test::(); seed_test::(); } + + fn key_inequality_test

() + where + P: KemParams, + { + let mut rng = UnwrapErr(SysRng); + + // Generate two different keys + let dk1 = DecapsulationKey::

::generate_from_rng(&mut rng); + let dk2 = DecapsulationKey::

::generate_from_rng(&mut rng); + + let ek1 = dk1.encapsulation_key(); + let ek2 = dk2.encapsulation_key(); + + // Verify inequality (catches PartialEq mutation that returns true unconditionally) + assert_ne!(dk1, dk2); + assert_ne!(ek1, ek2); + } + + #[test] + fn key_inequality() { + key_inequality_test::(); + key_inequality_test::(); + key_inequality_test::(); + } } diff --git a/ml-kem/src/pke.rs b/ml-kem/src/pke.rs index 88007f2..a849b55 100644 --- a/ml-kem/src/pke.rs +++ b/ml-kem/src/pke.rs @@ -264,4 +264,26 @@ mod test { let invalid_key = [0xFF; 1184]; assert!(EncryptionKey::::from_bytes(&invalid_key.into()).is_err()); } + + fn key_inequality_test

() + where + P: PkeParams, + { + let mut rng = UnwrapErr(SysRng); + let d1 = B32::generate_from_rng(&mut rng); + let d2 = B32::generate_from_rng(&mut rng); + + let (dk1, _) = DecryptionKey::

::generate(&d1); + let (dk2, _) = DecryptionKey::

::generate(&d2); + + // Verify inequality (catches PartialEq mutation that returns true unconditionally) + assert_ne!(dk1, dk2); + } + + #[test] + fn key_inequality() { + key_inequality_test::(); + key_inequality_test::(); + key_inequality_test::(); + } } diff --git a/module-lattice/tests/algebra.rs b/module-lattice/tests/algebra.rs index 8cfa1bc..2c0ce3f 100644 --- a/module-lattice/tests/algebra.rs +++ b/module-lattice/tests/algebra.rs @@ -1,6 +1,9 @@ //! Tests for the `algebra` module. -use module_lattice::algebra::Field; +use hybrid_array::typenum::U2; +use module_lattice::algebra::{ + Elem, Field, NttMatrix, NttPolynomial, NttVector, Polynomial, Vector, +}; // Field used by ML-KEM. module_lattice::define_field!(KyberField, u16, u32, u64, 3329); @@ -16,3 +19,528 @@ fn small_reduce() { assert_eq!(DilithiumField::small_reduce(8_380_416), 8_380_416); assert_eq!(DilithiumField::small_reduce(8_380_417), 0); } + +#[test] +fn barrett_reduce() { + // Test Barrett reduction produces values in correct range + assert_eq!(KyberField::barrett_reduce(0), 0); + assert_eq!(KyberField::barrett_reduce(3329), 0); + assert_eq!(KyberField::barrett_reduce(3328), 3328); + assert_eq!(KyberField::barrett_reduce(6658), 0); // 2 * 3329 + + // Large product that requires Barrett reduction + let product: u32 = 3000 * 3000; // 9_000_000 + let reduced = KyberField::barrett_reduce(product); + assert!(reduced < 3329); + assert_eq!(reduced, (product % 3329) as u16); + + // Test with Dilithium field + assert_eq!(DilithiumField::barrett_reduce(0), 0); + assert_eq!(DilithiumField::barrett_reduce(8_380_417), 0); +} + +// ======================================== +// Elem arithmetic tests +// ======================================== + +#[test] +fn elem_negation() { + let a: Elem = Elem::new(100); + let neg_a = -a; + // -100 mod 3329 = 3229 + assert_eq!(neg_a.0, 3229); + + // Double negation returns original + assert_eq!((-neg_a).0, 100); + + // Negation of zero is zero + let zero: Elem = Elem::new(0); + assert_eq!((-zero).0, 0); +} + +#[test] +fn elem_addition() { + let a: Elem = Elem::new(100); + let b: Elem = Elem::new(200); + let sum = a + b; + assert_eq!(sum.0, 300); + + // Test wraparound + let c: Elem = Elem::new(3300); + let d: Elem = Elem::new(100); + let wrapped = c + d; + assert_eq!(wrapped.0, 71); // (3300 + 100) % 3329 = 71 + + // Adding zero is identity + let zero: Elem = Elem::new(0); + assert_eq!((a + zero).0, 100); +} + +#[test] +fn elem_subtraction() { + let a: Elem = Elem::new(300); + let b: Elem = Elem::new(100); + let diff = a - b; + assert_eq!(diff.0, 200); + + // Test negative result wraps correctly + let c: Elem = Elem::new(100); + let d: Elem = Elem::new(300); + let wrapped = c - d; + // 100 - 300 = -200 mod 3329 = 3129 + assert_eq!(wrapped.0, 3129); + + // Subtracting zero is identity + let zero: Elem = Elem::new(0); + assert_eq!((a - zero).0, 300); + + // Subtracting self gives zero + assert_eq!((a - a).0, 0); +} + +#[test] +fn elem_multiplication() { + let a: Elem = Elem::new(100); + let b: Elem = Elem::new(200); + let prod = a * b; + assert_eq!(prod.0, (100 * 200) % 3329); + + // Multiply by one is identity + let one: Elem = Elem::new(1); + assert_eq!((a * one).0, 100); + + // Multiply by zero is zero + let zero: Elem = Elem::new(0); + assert_eq!((a * zero).0, 0); + + // Test large product requiring Barrett reduction + let c: Elem = Elem::new(3000); + let d: Elem = Elem::new(3000); + let large_prod = c * d; + assert_eq!(large_prod.0, ((3000u32 * 3000u32) % 3329) as u16); +} + +#[test] +fn elem_arithmetic_consistency() { + // Test: a + b - b = a + let a: Elem = Elem::new(1234); + let b: Elem = Elem::new(5678 % 3329); + assert_eq!((a + b - b).0, a.0); + + // Test: a - b + b = a + assert_eq!((a - b + b).0, a.0); + + // Test: a + (-a) = 0 + assert_eq!((a + (-a)).0, 0); +} + +// ======================================== +// Polynomial arithmetic tests +// ======================================== + +fn make_test_polynomial(base: F::Int) -> Polynomial +where + F::Int: From, +{ + let mut coeffs = [Elem::new(F::Int::from(0u8)); 256]; + for (i, c) in coeffs.iter_mut().enumerate().take(10) { + *c = Elem::new(base + F::Int::from(i as u8)); + } + Polynomial::new(coeffs.into()) +} + +#[test] +fn polynomial_addition() { + let p1 = make_test_polynomial::(100); + let p2 = make_test_polynomial::(200); + let sum = &p1 + &p2; + + // Check first few coefficients + assert_eq!(sum.0[0].0, 300); // 100 + 200 + assert_eq!(sum.0[1].0, 302); // 101 + 201 + assert_eq!(sum.0[9].0, 318); // 109 + 209 + + // Remaining coefficients should be 0 + assert_eq!(sum.0[10].0, 0); +} + +#[test] +fn polynomial_subtraction() { + let p1 = make_test_polynomial::(300); + let p2 = make_test_polynomial::(100); + let diff = &p1 - &p2; + + // Check first few coefficients + assert_eq!(diff.0[0].0, 200); // 300 - 100 + assert_eq!(diff.0[1].0, 200); // 301 - 101 +} + +#[test] +fn polynomial_negation() { + let p = make_test_polynomial::(100); + let neg_p = -&p; + + // Check negation: -100 mod 3329 = 3229 + assert_eq!(neg_p.0[0].0, 3229); + // -101 mod 3329 = 3228 + assert_eq!(neg_p.0[1].0, 3228); + + // Double negation returns original + let double_neg = -&neg_p; + assert_eq!(double_neg.0[0].0, p.0[0].0); +} + +#[test] +fn polynomial_scalar_multiplication() { + let p = make_test_polynomial::(100); + let scalar: Elem = Elem::new(3); + let scaled = scalar * &p; + + assert_eq!(scaled.0[0].0, 300); // 3 * 100 + assert_eq!(scaled.0[1].0, 303); // 3 * 101 +} + +// ======================================== +// Vector arithmetic tests +// ======================================== + +fn make_test_vector(base: F::Int) -> Vector +where + F::Int: From, +{ + let p1 = make_test_polynomial::(base); + let p2 = make_test_polynomial::(base + F::Int::from(50u8)); + Vector::new([p1, p2].into()) +} + +#[test] +fn vector_addition() { + let v1 = make_test_vector::(100); + let v2 = make_test_vector::(200); + let sum = &v1 + &v2; + + // First polynomial: 100+200=300, second: 150+250=400 + assert_eq!(sum.0[0].0[0].0, 300); + assert_eq!(sum.0[1].0[0].0, 400); +} + +#[test] +fn vector_addition_owned() { + let v1 = make_test_vector::(100); + let v2 = make_test_vector::(200); + let sum = v1 + v2; + + assert_eq!(sum.0[0].0[0].0, 300); + assert_eq!(sum.0[1].0[0].0, 400); +} + +#[test] +fn vector_subtraction() { + let v1 = make_test_vector::(300); + let v2 = make_test_vector::(100); + let diff = &v1 - &v2; + + // 300 - 100 = 200 + assert_eq!(diff.0[0].0[0].0, 200); + // 350 - 150 = 200 + assert_eq!(diff.0[1].0[0].0, 200); +} + +#[test] +fn vector_negation() { + let v = make_test_vector::(100); + let neg_v = -&v; + + // -100 mod 3329 = 3229 + assert_eq!(neg_v.0[0].0[0].0, 3229); +} + +#[test] +fn vector_scalar_multiplication() { + let v = make_test_vector::(100); + let scalar: Elem = Elem::new(2); + let scaled = scalar * &v; + + assert_eq!(scaled.0[0].0[0].0, 200); // 2 * 100 + assert_eq!(scaled.0[1].0[0].0, 300); // 2 * 150 +} + +// ======================================== +// NttPolynomial arithmetic tests +// ======================================== + +fn make_test_ntt_polynomial(base: F::Int) -> NttPolynomial +where + F::Int: From, +{ + let mut coeffs = [Elem::new(F::Int::from(0u8)); 256]; + for (i, c) in coeffs.iter_mut().enumerate().take(10) { + *c = Elem::new(base + F::Int::from(i as u8)); + } + NttPolynomial::new(coeffs.into()) +} + +#[test] +fn ntt_polynomial_addition() { + let p1 = make_test_ntt_polynomial::(100); + let p2 = make_test_ntt_polynomial::(200); + let sum = &p1 + &p2; + + assert_eq!(sum.0[0].0, 300); + assert_eq!(sum.0[1].0, 302); +} + +#[test] +fn ntt_polynomial_subtraction() { + let p1 = make_test_ntt_polynomial::(300); + let p2 = make_test_ntt_polynomial::(100); + let diff = &p1 - &p2; + + assert_eq!(diff.0[0].0, 200); +} + +#[test] +fn ntt_polynomial_negation() { + let p = make_test_ntt_polynomial::(100); + let neg_p = -&p; + + assert_eq!(neg_p.0[0].0, 3229); // -100 mod 3329 +} + +#[test] +fn ntt_polynomial_scalar_multiplication() { + let p = make_test_ntt_polynomial::(100); + let scalar: Elem = Elem::new(3); + let scaled = scalar * &p; + + assert_eq!(scaled.0[0].0, 300); +} + +#[test] +fn ntt_polynomial_from_array() { + use hybrid_array::Array; + + let coeffs: Array, hybrid_array::typenum::U256> = + core::array::from_fn(|i| Elem::new((i % 3329) as u16)).into(); + let p: NttPolynomial = coeffs.into(); + + assert_eq!(p.0[0].0, 0); + assert_eq!(p.0[1].0, 1); + + // Convert back + let arr: Array, hybrid_array::typenum::U256> = p.into(); + assert_eq!(arr[0].0, coeffs[0].0); +} + +// ======================================== +// NttVector arithmetic tests +// ======================================== + +fn make_test_ntt_vector(base: F::Int) -> NttVector +where + F::Int: From, +{ + let p1 = make_test_ntt_polynomial::(base); + let p2 = make_test_ntt_polynomial::(base + F::Int::from(50u8)); + NttVector::new([p1, p2].into()) +} + +#[test] +fn ntt_vector_addition() { + let v1 = make_test_ntt_vector::(100); + let v2 = make_test_ntt_vector::(200); + let sum = &v1 + &v2; + + assert_eq!(sum.0[0].0[0].0, 300); + assert_eq!(sum.0[1].0[0].0, 400); +} + +#[test] +fn ntt_vector_subtraction() { + let v1 = make_test_ntt_vector::(300); + let v2 = make_test_ntt_vector::(100); + let diff = &v1 - &v2; + + assert_eq!(diff.0[0].0[0].0, 200); + assert_eq!(diff.0[1].0[0].0, 200); +} + +// ======================================== +// PartialEq tests (to catch == vs != mutations) +// ======================================== + +#[test] +fn elem_equality() { + let a: Elem = Elem::new(100); + let b: Elem = Elem::new(100); + let c: Elem = Elem::new(200); + + assert_eq!(a, b); + assert_ne!(a, c); +} + +#[test] +fn polynomial_equality() { + let p1 = make_test_polynomial::(100); + let p2 = make_test_polynomial::(100); + let p3 = make_test_polynomial::(200); + + assert_eq!(p1, p2); + assert_ne!(p1, p3); +} + +#[test] +fn vector_equality() { + let v1 = make_test_vector::(100); + let v2 = make_test_vector::(100); + let v3 = make_test_vector::(200); + + assert_eq!(v1, v2); + assert_ne!(v1, v3); +} + +#[test] +fn ntt_polynomial_equality() { + let p1 = make_test_ntt_polynomial::(100); + let p2 = make_test_ntt_polynomial::(100); + let p3 = make_test_ntt_polynomial::(200); + + assert_eq!(p1, p2); + assert_ne!(p1, p3); +} + +#[test] +fn ntt_vector_equality() { + let v1 = make_test_ntt_vector::(100); + let v2 = make_test_ntt_vector::(100); + let v3 = make_test_ntt_vector::(200); + + assert_eq!(v1, v2); + assert_ne!(v1, v3); +} + +#[test] +fn ntt_matrix_equality() { + let v1 = make_test_ntt_vector::(100); + let v2 = make_test_ntt_vector::(150); + let m1: NttMatrix = NttMatrix::new([v1.clone(), v2.clone()].into()); + let m2: NttMatrix = NttMatrix::new([v1.clone(), v2.clone()].into()); + + let v3 = make_test_ntt_vector::(200); + let m3: NttMatrix = NttMatrix::new([v1, v3].into()); + + assert_eq!(m1, m2); + assert_ne!(m1, m3); +} + +#[test] +fn ntt_polynomial_into_array() { + use hybrid_array::Array; + use hybrid_array::typenum::U256; + + let p = make_test_ntt_polynomial::(100); + + // Convert to array and verify contents match + let arr: Array, U256> = p.clone().into(); + assert_eq!(arr[0].0, 100); + assert_eq!(arr[1].0, 101); + assert_eq!(arr[9].0, 109); + assert_eq!(arr[10].0, 0); + + // Verify conversion preserves all data + for i in 0..256 { + assert_eq!(arr[i].0, p.0[i].0); + } +} + +// ======================================== +// Zeroize tests (require zeroize feature) +// ======================================== + +#[cfg(feature = "zeroize")] +mod zeroize_tests { + use super::*; + use zeroize::Zeroize; + + #[test] + fn elem_zeroize() { + let mut a: Elem = Elem::new(1234); + assert_ne!(a.0, 0); + a.zeroize(); + assert_eq!(a.0, 0); + } + + #[test] + fn polynomial_zeroize() { + let mut p = make_test_polynomial::(100); + assert_ne!(p.0[0].0, 0); + p.zeroize(); + for i in 0..256 { + assert_eq!(p.0[i].0, 0, "Coefficient {} not zeroed", i); + } + } + + #[test] + fn vector_zeroize() { + let mut v = make_test_vector::(100); + assert_ne!(v.0[0].0[0].0, 0); + v.zeroize(); + for i in 0..2 { + for j in 0..256 { + assert_eq!(v.0[i].0[j].0, 0, "Element [{i}][{j}] not zeroed"); + } + } + } + + #[test] + fn ntt_polynomial_zeroize() { + let mut p = make_test_ntt_polynomial::(100); + assert_ne!(p.0[0].0, 0); + p.zeroize(); + for i in 0..256 { + assert_eq!(p.0[i].0, 0, "Coefficient {} not zeroed", i); + } + } + + #[test] + fn ntt_vector_zeroize() { + let mut v = make_test_ntt_vector::(100); + assert_ne!(v.0[0].0[0].0, 0); + v.zeroize(); + for i in 0..2 { + for j in 0..256 { + assert_eq!(v.0[i].0[j].0, 0, "Element [{i}][{j}] not zeroed"); + } + } + } +} + +// ======================================== +// ConstantTimeEq tests (require subtle feature) +// ======================================== + +#[cfg(feature = "subtle")] +mod subtle_tests { + use super::*; + use subtle::ConstantTimeEq; + + #[test] + fn elem_ct_eq() { + let a: Elem = Elem::new(100); + let b: Elem = Elem::new(100); + let c: Elem = Elem::new(200); + + assert!(bool::from(a.ct_eq(&b))); + assert!(!bool::from(a.ct_eq(&c))); + } + + #[test] + fn ntt_polynomial_ct_eq() { + let p1 = make_test_ntt_polynomial::(100); + let p2 = make_test_ntt_polynomial::(100); + let p3 = make_test_ntt_polynomial::(200); + + assert!(bool::from(p1.ct_eq(&p2))); + assert!(!bool::from(p1.ct_eq(&p3))); + } +} diff --git a/module-lattice/tests/encode.rs b/module-lattice/tests/encode.rs new file mode 100644 index 0000000..31c30a4 --- /dev/null +++ b/module-lattice/tests/encode.rs @@ -0,0 +1,270 @@ +//! Tests for the `encode` module. + +use hybrid_array::typenum::{U1, U4, U10, U12}; +use module_lattice::algebra::{Elem, NttPolynomial, NttVector, Polynomial, Vector}; +use module_lattice::encode::{Encode, byte_decode, byte_encode}; + +// Field used by ML-KEM. +module_lattice::define_field!(KyberField, u16, u32, u64, 3329); + +// ======================================== +// byte_encode / byte_decode round-trip tests +// ======================================== + +#[test] +fn byte_encode_decode_d1_roundtrip() { + // D=1: Single bit encoding + let vals: [Elem; 256] = core::array::from_fn(|i| Elem::new((i % 2) as u16)); + + let encoded = byte_encode::(&vals.into()); + let decoded = byte_decode::(&encoded); + + for (i, (dec, val)) in decoded.iter().zip(vals.iter()).enumerate() { + assert_eq!(dec.0, val.0, "Mismatch at index {i}"); + } +} + +#[test] +fn byte_encode_decode_d4_roundtrip() { + // D=4: 4-bit encoding + let vals: [Elem; 256] = core::array::from_fn(|i| Elem::new((i % 16) as u16)); + + let encoded = byte_encode::(&vals.into()); + let decoded = byte_decode::(&encoded); + + for (i, (dec, val)) in decoded.iter().zip(vals.iter()).enumerate() { + assert_eq!(dec.0, val.0, "Mismatch at index {i}"); + } +} + +#[test] +fn byte_encode_decode_d10_roundtrip() { + // D=10: 10-bit encoding + let vals: [Elem; 256] = core::array::from_fn(|i| Elem::new((i % 1024) as u16)); + + let encoded = byte_encode::(&vals.into()); + let decoded = byte_decode::(&encoded); + + for (i, (dec, val)) in decoded.iter().zip(vals.iter()).enumerate() { + assert_eq!(dec.0, val.0, "Mismatch at index {i}"); + } +} + +#[test] +fn byte_encode_decode_d12_roundtrip() { + // D=12: 12-bit encoding (special case with modular reduction) + // Values up to q-1 (3328) + let vals: [Elem; 256] = core::array::from_fn(|i| Elem::new((i * 13) as u16 % 3329)); + + let encoded = byte_encode::(&vals.into()); + let decoded = byte_decode::(&encoded); + + for (i, (dec, val)) in decoded.iter().zip(vals.iter()).enumerate() { + assert_eq!(dec.0, val.0, "Mismatch at index {i}"); + } +} + +#[test] +fn byte_encode_decode_d12_modular_reduction() { + // Test that D=12 properly reduces values >= Q + // Fill with values near and above Q + let vals: [Elem; 256] = + core::array::from_fn(|i| Elem::new(3329 + (i as u16) % 100)); + + let encoded = byte_encode::(&vals.into()); + let decoded = byte_decode::(&encoded); + + // After decode, values should be reduced mod Q + for (i, dec) in decoded.iter().enumerate() { + assert!(dec.0 < 3329, "Value at {i} not reduced: {}", dec.0); + } +} + +#[test] +fn byte_encode_zero_values() { + let vals = [Elem::::new(0); 256]; + + let encoded = byte_encode::(&vals.into()); + let decoded = byte_decode::(&encoded); + + for dec in &decoded { + assert_eq!(dec.0, 0); + } +} + +#[test] +fn byte_encode_max_values() { + // D=4: max value is 15 + let vals = [Elem::::new(15); 256]; + + let encoded = byte_encode::(&vals.into()); + let decoded = byte_decode::(&encoded); + + for dec in &decoded { + assert_eq!(dec.0, 15); + } +} + +// ======================================== +// Polynomial encoding tests +// ======================================== + +#[test] +fn polynomial_encode_decode_roundtrip() { + let coeffs: [Elem; 256] = core::array::from_fn(|i| Elem::new((i * 7) as u16 % 16)); + let p = Polynomial::::new(coeffs.into()); + + let encoded = as Encode>::encode(&p); + let decoded = as Encode>::decode(&encoded); + + assert_eq!(p, decoded); +} + +#[test] +fn polynomial_encode_decode_d12() { + let coeffs: [Elem; 256] = + core::array::from_fn(|i| Elem::new((i * 13) as u16 % 3329)); + let p = Polynomial::::new(coeffs.into()); + + let encoded = as Encode>::encode(&p); + let decoded = as Encode>::decode(&encoded); + + assert_eq!(p, decoded); +} + +// ======================================== +// Vector encoding tests +// ======================================== + +#[test] +fn vector_encode_decode_roundtrip() { + use hybrid_array::typenum::U2; + + let coeffs1: [Elem; 256] = core::array::from_fn(|i| Elem::new((i * 3) as u16 % 16)); + let coeffs2: [Elem; 256] = core::array::from_fn(|i| Elem::new((i * 5) as u16 % 16)); + + let p1 = Polynomial::::new(coeffs1.into()); + let p2 = Polynomial::::new(coeffs2.into()); + let v: Vector = Vector::new([p1, p2].into()); + + let encoded = as Encode>::encode(&v); + let decoded = as Encode>::decode(&encoded); + + assert_eq!(v, decoded); +} + +// ======================================== +// NttPolynomial encoding tests +// ======================================== + +#[test] +fn ntt_polynomial_encode_decode_roundtrip() { + let coeffs: [Elem; 256] = core::array::from_fn(|i| Elem::new((i * 7) as u16 % 16)); + let p = NttPolynomial::::new(coeffs.into()); + + let encoded = as Encode>::encode(&p); + let decoded = as Encode>::decode(&encoded); + + assert_eq!(p, decoded); +} + +#[test] +fn ntt_polynomial_encode_decode_d12() { + let coeffs: [Elem; 256] = + core::array::from_fn(|i| Elem::new((i * 13) as u16 % 3329)); + let p = NttPolynomial::::new(coeffs.into()); + + let encoded = as Encode>::encode(&p); + let decoded = as Encode>::decode(&encoded); + + assert_eq!(p, decoded); +} + +// ======================================== +// NttVector encoding tests +// ======================================== + +#[test] +fn ntt_vector_encode_decode_roundtrip() { + use hybrid_array::typenum::U2; + + let coeffs1: [Elem; 256] = core::array::from_fn(|i| Elem::new((i * 3) as u16 % 16)); + let coeffs2: [Elem; 256] = core::array::from_fn(|i| Elem::new((i * 5) as u16 % 16)); + + let p1 = NttPolynomial::::new(coeffs1.into()); + let p2 = NttPolynomial::::new(coeffs2.into()); + let v: NttVector = NttVector::new([p1, p2].into()); + + let encoded = as Encode>::encode(&v); + let decoded = as Encode>::decode(&encoded); + + assert_eq!(v, decoded); +} + +// ======================================== +// Encoding size verification +// ======================================== + +#[test] +fn encoded_polynomial_size_d4() { + // D=4 means 4 bits per coefficient, 256 coefficients = 1024 bits = 128 bytes + let coeffs = [Elem::::new(0); 256]; + let p = Polynomial::::new(coeffs.into()); + + let encoded = as Encode>::encode(&p); + assert_eq!(encoded.len(), 128); +} + +#[test] +fn encoded_polynomial_size_d12() { + // D=12 means 12 bits per coefficient, 256 coefficients = 3072 bits = 384 bytes + let coeffs = [Elem::::new(0); 256]; + let p = Polynomial::::new(coeffs.into()); + + let encoded = as Encode>::encode(&p); + assert_eq!(encoded.len(), 384); +} + +#[test] +fn encoded_vector_size() { + use hybrid_array::typenum::U3; + + // D=4, K=3: 128 bytes per polynomial * 3 = 384 bytes + let coeffs = [Elem::::new(0); 256]; + let p = Polynomial::::new(coeffs.into()); + let v: Vector = Vector::new([p, p, p].into()); + + let encoded = as Encode>::encode(&v); + assert_eq!(encoded.len(), 384); +} + +// ======================================== +// Edge cases and boundary tests +// ======================================== + +#[test] +fn byte_encode_alternating_bits() { + // Test alternating patterns to catch bit manipulation issues + let vals: [Elem; 256] = + core::array::from_fn(|i| Elem::new(if i % 2 == 0 { 0b0101 } else { 0b1010 })); + + let encoded = byte_encode::(&vals.into()); + let decoded = byte_decode::(&encoded); + + for (i, (dec, val)) in decoded.iter().zip(vals.iter()).enumerate() { + assert_eq!(dec.0, val.0, "Mismatch at index {i}"); + } +} + +#[test] +fn byte_encode_sequential_values() { + // Sequential values to catch ordering issues + let vals: [Elem; 256] = core::array::from_fn(|i| Elem::new(i as u16 % 16)); + + let encoded = byte_encode::(&vals.into()); + let decoded = byte_decode::(&encoded); + + for (i, (dec, val)) in decoded.iter().zip(vals.iter()).enumerate() { + assert_eq!(dec.0, val.0, "Mismatch at index {i}"); + } +}