From 492f455184454c9cec5cc95199ee8f5650f97bfe Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Wed, 27 May 2026 09:58:53 -0700 Subject: [PATCH] Add one more 30-bit prime barrett test PiperOrigin-RevId: 922194596 --- jaxite/jaxite_ckks/barrett_test.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/jaxite/jaxite_ckks/barrett_test.py b/jaxite/jaxite_ckks/barrett_test.py index 8554957..c5c9c8f 100644 --- a/jaxite/jaxite_ckks/barrett_test.py +++ b/jaxite/jaxite_ckks/barrett_test.py @@ -42,14 +42,28 @@ class BarrettTest(parameterized.TestCase): @parameterized.parameters( (1073753729, [0, 1, 1073753728, 1073753729, 2000000000]), (65537, [0, 1, 65536, 65537, 1000000]), + (1073741441, [193942197421246]), ) def test_modular_reduction_basic(self, modulus, inputs): expected = [x % modulus for x in inputs] constants = barrett.precompute_barrett_constants(modulus) - unreduced = jnp.array(inputs, dtype=jnp.uint64) + unreduced = jnp.array(inputs, dtype=jnp.uint64).reshape(1, 1, -1) actual = barrett.modular_reduction(unreduced, constants) + actual = np.array(actual).flatten() np.testing.assert_array_equal(actual, expected) + def test_modular_reduction_simulation_data(self): + moduli = [1073741441, 1073740609] + inputs = jnp.array([193942197421246, 49184246388363], dtype=jnp.uint64) + expected = jnp.array([870864944, 484052509], dtype=jnp.uint32) + + constants = barrett.precompute_barrett_constants(moduli) + + unreduced = inputs.reshape(1, 1, -1) + actual = barrett.modular_reduction(unreduced, constants) + + np.testing.assert_array_equal(actual.flatten(), expected) + @hypothesis.settings(deadline=None, max_examples=50) @hypothesis.given(moduli_and_z()) def test_modular_reduction_hypothesis(self, moduli_and_input):