From 7a9e0e2f1db615a5681f537a9325dcbb153096d6 Mon Sep 17 00:00:00 2001
From: Shruthi Gorantala
Date: Fri, 30 Jan 2026 11:43:38 -0800
Subject: [PATCH] Add Montgomery Reduction and Barrett Reduction from CROSS

PiperOrigin-RevId: 863329143
---
 BUILD                                   |  35 +++
 jaxite/jaxite_ckks/finite_field.py      | 294 ++++++++++++++++++++++++
 jaxite/jaxite_ckks/finite_field_test.py |  75 ++++++
 jaxite/jaxite_ckks/util.py              | 113 +++++++++
 jaxite/jaxite_ckks/util_test.py         |  26 ++
 5 files changed, 543 insertions(+)
 create mode 100644 jaxite/jaxite_ckks/finite_field.py
 create mode 100644 jaxite/jaxite_ckks/finite_field_test.py
 create mode 100644 jaxite/jaxite_ckks/util.py
 create mode 100644 jaxite/jaxite_ckks/util_test.py

diff --git a/BUILD b/BUILD
index e0bb789..288ea2f 100644
--- a/BUILD
+++ b/BUILD
@@ -471,3 +471,38 @@ gpu_tpu_test(
         "@jaxite_deps_parameterized//:pkg",
     ],
 )
+
+gpu_tpu_test(
+    name = "finite_field_test",
+    size = "small",
+    timeout = "moderate",
+    srcs = ["jaxite/jaxite_ckks/finite_field_test.py"],
+    deps = [
+        ":jaxite",
+        ":test_utils",
+        "@com_google_absl_py//absl/testing:absltest",
+        "@com_google_absl_py//absl/testing:parameterized",
+        "@jaxite_deps_hypothesis//:pkg",
+        "@jaxite_deps_jax//:pkg",
+        "@jaxite_deps_jaxlib//:pkg",
+        "@jaxite_deps_numpy//:pkg",
+        "@jaxite_deps_parameterized//:pkg",
+    ],
+)
+
+py_test(
+    name = "util_test",
+    size = "small",
+    timeout = "moderate",
+    srcs = ["jaxite/jaxite_ckks/util_test.py"],
+    deps = [
+        ":jaxite",
+        "@com_google_absl_py//absl/testing:absltest",
+        "@com_google_absl_py//absl/testing:parameterized",
+        "@jaxite_deps_hypothesis//:pkg",
+        "@jaxite_deps_jax//:pkg",
+        "@jaxite_deps_jaxlib//:pkg",
+        "@jaxite_deps_numpy//:pkg",
+        "@jaxite_deps_parameterized//:pkg",
+    ],
+)
diff --git a/jaxite/jaxite_ckks/finite_field.py b/jaxite/jaxite_ckks/finite_field.py
new file mode 100644
index 0000000..273775a
--- /dev/null
+++ b/jaxite/jaxite_ckks/finite_field.py
@@ -0,0 +1,294 @@
+"""Name: JAX Finite Field Context Integration
+
+Name Template: <Backend><Representation><Reduction><Variant>Context<Base>
+  - <Backend>: (JAX accelerator backend).
+  - <Representation>: [Optional]
+    - Empty: Standard scalar.
+    - RNS: Residue Number System.
+    - DRNS: Digitized RNS.
+    - RD: Radix Decomposition (Big Integer simulation).
+  - <Reduction>: [Optional]
+    - Montgomery: Montgomery reduction.
+    - Barrett: Barrett reduction.
+    - Shoup: Shoup reduction.
+  - <Variant>: [Optional]
+    - MultipleModuli: vectorized over moduli.
+    - Lazy: Lazy reduction.
+    - Opt/Opt2: Optimization levels or specific variants.
+  - Context: Class suffix.
+  - <Base>: [Optional] Abstract base class.
+
+Explanation: This module adapts the generic finite field contexts for use with
+JAX. It inherits from the base contexts in `finite_field_context.py` and adds
+functionality to precompute and format parameters (such as modular inverses, RNS
+matrices, and bit-shifted constants) into JAX-compatible arrays. It serves as
+the configuration bridge between the mathematical specifications and the JAX
+kernels.
+"""
+
+import math
+from typing import List, Union
+import jax
+import jax.numpy as jnp
+from jaxite.jaxite_ckks import util
+
+jax.config.update("jax_enable_x64", True)
+
+
+########################
+# Base Context Class
+########################
+class FiniteFieldContextBase:
+  """Interface shared by the modular-reduction contexts in this module."""
+
+  def __init__(self, moduli: Union[List[int], int]):
+    self.moduli = moduli
+
+  def to_computation_format(self, a: int):
+    """Maps `a` into the representation the reduction works on (identity)."""
+    return a
+
+  def to_original_format(self, a: jnp.ndarray):
+    """Maps `a` back from the computation representation (identity)."""
+    return a
+
+  def get_jax_parameters(self):
+    """Returns the precomputed constants as a dict (empty for the base)."""
+    return {}
+
+  def modular_reduction(self, a: jnp.ndarray) -> jnp.ndarray:
+    raise NotImplementedError("Subclasses must implement this method")
+
+  def drop_last_modulus(self):
+    raise NotImplementedError("Subclasses must implement this method")
+
+
+########################
+# Montgomery Modulus Reduction Context
+########################
+class MontgomeryContext(FiniteFieldContextBase):
+  """Context for Montgomery reduction with the radix R = 2^32.
+
+  Accepts a single modulus or a list of moduli and precomputes, per modulus,
+  the inverse of R, the inverse of the modulus mod 2^32 and the 16-bit halves
+  of the modulus. Each modulus must be invertible mod 2^32 and fit in 32 bits.
+  """
+
+  def __init__(self, moduli: Union[List[int], int]):
+    super().__init__(moduli)
+    self.moduli = [moduli] if isinstance(moduli, int) else moduli
+    self.w = 32
+    self.w_inv = [util.modinv(1 << self.w, m) for m in self.moduli]
+    self.w_inv_reduction = jnp.array(self.w_inv, jnp.uint64)
+
+    self.moduli_reduction = jnp.array(self.moduli, jnp.uint64)
+
+    self.moduli_inv_32 = [util.modinv(m, 2**32) for m in self.moduli]
+    self.moduli_low16 = [m & 0xFFFF for m in self.moduli]
+    self.moduli_high16 = [m >> 16 for m in self.moduli]
+
+    self.q = jnp.array(self.moduli, dtype=jnp.uint32)
+    self.q_low = jnp.array(self.moduli_low16, dtype=jnp.uint32)
+    self.q_high = jnp.array(self.moduli_high16, dtype=jnp.uint32)
+    self.q_inv_32 = jnp.array(self.moduli_inv_32, dtype=jnp.uint32)
+
+  def to_computation_format(self, a: int):
+    """Converts `a` to Montgomery form, i.e. `a * 2^32 mod q` per modulus."""
+    # The algorithm being performed:
+    # [(a * (1 << self.w)) % m for m in self.moduli]
+    # Widen first: shifting a 32-bit value left by 32 bits would lose it.
+    a = jnp.asarray(a, dtype=jnp.uint64)
+    return ((a << self.w) % self.moduli_reduction).astype(jnp.uint32)
+
+  def to_original_format(self, a: jnp.ndarray):
+    """Converts `a` out of Montgomery form, i.e. `a * 2^-32 mod q`."""
+    return (a * self.w_inv_reduction) % self.moduli_reduction
+
+  def get_jax_parameters(self):
+    """Returns the per-modulus constants as tuples keyed by name."""
+    return {
+        "moduli": util.to_tuple(self.moduli),
+        "moduli_inv_32": util.to_tuple(self.moduli_inv_32),
+        "moduli_low": util.to_tuple(self.moduli_low16),
+        "moduli_high": util.to_tuple(self.moduli_high16),
+    }
+
+  def modular_reduction(self, z: jnp.ndarray) -> jnp.ndarray:
+    """Montgomery reduction from u64 to u32 using only 32-bit operations.
+
+    The result is congruent to `z * 2^-32 mod q`. The reduction is lazy: no
+    final conditional subtraction is applied, so for `z < q * 2^32` the output
+    lies in `(0, 2q)` rather than `[0, q)`.
+
+    Args:
+      z: u64 array of shape (B, M), where M is the number of moduli.
+
+    Returns:
+      u32 array of shape (B, M) holding the lazily reduced values.
+    """
+    # Local constants
+    MASK32 = 0xFFFFFFFF
+    MASK16 = 0xFFFF
+    SHIFT16 = 16
+    SHIFT32 = 32
+    q = self.q
+    q_low = self.q_low
+    q_high = self.q_high
+    q_inv_32 = self.q_inv_32
+
+    # t = (z mod 2^32) * q^-1 mod 2^32, so z and t * q share their low 32 bits.
+    z_low = z.astype(jnp.uint32)
+    z_high = (z >> SHIFT32).astype(jnp.uint32)
+    t = (z_low * q_inv_32) & MASK32
+    t_low = t & MASK16
+    t_high = (t >> SHIFT16) & MASK16
+
+    # Upper 32 bits of t * q, assembled from 16 x 16 bit partial products.
+    prod_high = t_high * q_high  # This contributes directly to upper 32 bits
+    prod_mid_high = t_high * q_low  # Upper 16 bits go to upper 32 bits
+    prod_mid_low = t_low * q_high  # Upper 16 bits go to upper 32 bits
+    prod_low = t_low * q_low  # Upper 16 bits contribute to middle part
+    mid_low = (
+        (prod_mid_high & MASK16)
+        + (prod_mid_low & MASK16)
+        + (prod_low >> SHIFT16)
+    )
+    mid_high = (
+        (prod_mid_high >> SHIFT16)
+        + (prod_mid_low >> SHIFT16)
+        + (mid_low >> SHIFT16)
+    )
+
+    # Final upper 32 bits
+    t_final = prod_high + mid_high
+    # Adding q keeps the difference non-negative in unsigned arithmetic.
+    b = z_high + q - t_final
+    # The conditional subtraction of q (strict reduction) is not applied here.
+    return b.astype(jnp.uint32)
+
+  def drop_last_modulus(self):
+    """Removes the last modulus from every array used at run time."""
+    # Only the JAX arrays are trimmed; the Python lists backing
+    # get_jax_parameters() (self.moduli, self.moduli_inv_32, ...) keep all
+    # moduli.
+    self.moduli_reduction = self.moduli_reduction[:-1]
+    # Keep the 2^-32 table aligned with moduli_reduction in to_original_format.
+    self.w_inv_reduction = self.w_inv_reduction[:-1]
+    self.q = self.q[:-1]
+    self.q_low = self.q_low[:-1]
+    self.q_high = self.q_high[:-1]
+    self.q_inv_32 = self.q_inv_32[:-1]
+
+
+########################
+# Barrett Modulus Reduction Context
+########################
+class BarrettContext(FiniteFieldContextBase):
+  """Context for performing modular reduction using Barrett's algorithm.
+
+  This class precomputes parameters necessary for efficient Barrett reduction
+  within JAX, supporting both single and multiple moduli.
+  """
+
+  def __init__(self, moduli: Union[List[int], int]):
+    super().__init__(moduli)
+    self.moduli = [moduli] if isinstance(moduli, int) else moduli
+
+    # s = 2 * ceil(log2(q)) bits cover any product of two residues mod q.
+    self.barrett_s = [2 * math.ceil(math.log2(m)) for m in self.moduli]
+    self.barrett_w = [min(s, 32) for s in self.barrett_s]
+    self.barrett_s_w = [s - w for s, w in zip(self.barrett_s, self.barrett_w)]
+    # Exact integer floor division; `2**s / m` would pass through a rounded float.
+    self.barrett_m = [
+        (1 << s) // m for s, m in zip(self.barrett_s, self.moduli)
+    ]
+    # used for run-time reduction
+    self.m = jnp.array(self.barrett_m, dtype=jnp.uint64)
+    self.moduli_reduction = jnp.array(self.moduli, dtype=jnp.uint64)
+    self.w = jnp.array(self.barrett_w, dtype=jnp.uint16)
+    self.s_w = jnp.array(self.barrett_s_w, dtype=jnp.uint16)
+
+  def to_computation_format(self, a):
+    """Barrett reduction works on plain residues; returns `a` unchanged."""
+    return a
+
+  def to_original_format(self, a):
+    """Barrett reduction works on plain residues; returns `a` unchanged."""
+    return a
+
+  def get_jax_parameters(self):
+    """Returns the per-modulus Barrett constants as tuples keyed by name."""
+    return {
+        "barrett_m": util.to_tuple(self.barrett_m),
+        "moduli": util.to_tuple(self.moduli),
+        "barrett_w": util.to_tuple(self.barrett_w),
+        "barrett_s_w": util.to_tuple(self.barrett_s_w),
+    }
+
+  def modular_reduction(self, z: jnp.ndarray) -> jnp.ndarray:
+    """Vectorized implementation of the Barrett reduction.
+
+    Works for modulus `q` less than 31 bits.
+
+    This implementation sets the internal shift width `w` to `min(s, 32)` so it
+    works with small modulus `moduli < 2^16`.
+
+    Args:
+      z: The input values; they are broadcast against the per-modulus arrays.
+
+    Returns:
+      The result of the Barrett reduction, cast to u32.
+    """
+    m = self.m
+    moduli = self.moduli_reduction
+    w = self.w
+    s_w = self.s_w
+
+    z1 = z & 0xFFFFFFFF
+    z2 = z >> w
+    # Quotient estimate (z * m) >> s, split in two so each product fits in u64.
+    t = ((z1 * m) >> w) + (z2 * m)
+    t = t >> s_w
+    z = z - t * moduli
+    # A single conditional subtraction finishes the reduction.
+    pred = z >= moduli
+    return jnp.where(pred, z - moduli, z).astype(jnp.uint32)
+
+  def modular_reduction_single_modulus(
+      self, z: jnp.ndarray, modulus_index: int
+  ) -> jnp.ndarray:
+    """Barrett reduction of `z` by the modulus stored at `modulus_index`.
+
+    Works for modulus `q` less than 31 bits.
+
+    This implementation sets the internal shift width `w` to `min(s, 32)` so it
+    works with small modulus `moduli < 2^16`.
+
+    Args:
+      z: The input values, all reduced by the same modulus.
+      modulus_index: Position of the modulus within the context's moduli.
+
+    Returns:
+      The result of the Barrett reduction, cast to u32.
+    """
+    m = self.m[modulus_index]
+    moduli = self.moduli_reduction[modulus_index]
+    w = self.w[modulus_index]
+    s_w = self.s_w[modulus_index]
+
+    z1 = z.astype(jnp.uint32)
+    z2 = (z >> w).astype(jnp.uint32)
+    t = ((z1 * m) >> w) + (z2 * m)
+    t = t >> s_w
+    z = z - t * moduli
+    pred = z >= moduli
+    return jnp.where(pred, z - moduli, z).astype(jnp.uint32)
+
+  def drop_last_modulus(self):
+    """Removes the last modulus from every array used at run time."""
+    # Only the JAX arrays are trimmed; the Python lists backing
+    # get_jax_parameters() (self.moduli, self.barrett_m, ...) keep all moduli.
+    self.m = self.m[:-1]
+    self.moduli_reduction = self.moduli_reduction[:-1]
+    self.w = self.w[:-1]
+    self.s_w = self.s_w[:-1]
diff --git a/jaxite/jaxite_ckks/finite_field_test.py b/jaxite/jaxite_ckks/finite_field_test.py
new file mode 100644
index 0000000..3884dad
--- /dev/null
+++ b/jaxite/jaxite_ckks/finite_field_test.py
@@ -0,0 +1,75 @@
+"""Finite Field Test Suite
+
+Test cases:
+- Montgomery Single Modulus Context
+- Montgomery Context after dropping the last modulus
+- Barrett Single Modulus Context
+
+Terminology:
+- Modulus: Single form of modulus.
+
+Usage:
+- Specify the overall modulus for the context, and corresponding parameter
+required for the modular reduction.
+- Then feed "modulus" and "parameters" to the context constructor.
+- Then context->modular_reduction(input) to get the reduced result for certain
+inputs.
+"""
+
+import jax
+import jax.numpy as jnp
+from jaxite.jaxite_ckks import finite_field as ff_context
+from jaxite.jaxite_ckks import util
+import numpy as np
+from absl.testing import absltest
+from absl.testing import parameterized
+
+testing_params = [{"testcase_name": "0"}]
+
+
+@parameterized.named_parameters(testing_params)
+class FiniteFieldTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    # Setup random input data and their modmul reference results.
+    self.modulus = util.find_moduli_ntt(1, 31, 16)[0]
+    self.random_key = jax.random.key(0)
+    # Independent keys, otherwise `a` and `b` would always be equal.
+    key_a, key_b = jax.random.split(self.random_key)
+    self.a = jax.random.randint(
+        key_a, (1,), 0, self.modulus - 1, dtype=jnp.int32
+    )
+    self.b = jax.random.randint(
+        key_b, (1,), 0, self.modulus - 1, dtype=jnp.int32
+    )
+    self.ab = self.a.astype(jnp.uint64) * self.b.astype(jnp.uint64)
+    self.ab_modq = (self.ab % self.modulus).astype(jnp.uint32)
+
+  def test_montgomery_single_moduli_context(self):
+    context = ff_context.MontgomeryContext(self.modulus)
+    a_mont = context.to_computation_format(self.a[0].astype(jnp.uint64))
+    b_mont = context.to_computation_format(self.b[0].astype(jnp.uint64))
+    ab_mont = a_mont.astype(jnp.uint64) * b_mont.astype(jnp.uint64)
+    result_mont = context.modular_reduction(ab_mont)
+    result = context.to_original_format(result_mont.astype(jnp.uint64))
+    np.testing.assert_array_equal(result[0], self.ab_modq)
+
+  def test_montgomery_drop_last_modulus_round_trip(self):
+    moduli = util.find_moduli_ntt(3, 31, 16)
+    context = ff_context.MontgomeryContext(moduli)
+    context.drop_last_modulus()
+    value_mont = context.to_computation_format(self.a[0].astype(jnp.uint64))
+    result = context.to_original_format(value_mont.astype(jnp.uint64))
+    expected = [int(self.a[0]) % m for m in moduli[:-1]]
+    np.testing.assert_array_equal(result, expected)
+
+  def test_barrett_single_moduli_context(self):
+    context = ff_context.BarrettContext(self.modulus)
+    ab = self.a.astype(jnp.uint64) * self.b.astype(jnp.uint64)
+    result = context.modular_reduction(ab)
+    np.testing.assert_array_equal(result[0], self.ab_modq)
+
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/jaxite/jaxite_ckks/util.py b/jaxite/jaxite_ckks/util.py
new file mode 100644
index 0000000..59adcae
--- /dev/null
+++ b/jaxite/jaxite_ckks/util.py
@@ -0,0 +1,113 @@
+"""Utils for jaxite_ckks."""
+
+
+def is_prime_deterministic(n):
+  """Deterministic primality test for n < 2^64.
+
+  Uses Trial Division for speed + Deterministic Miller-Rabin for correctness.
+
+  Args:
+    n: The number to test for primality.
+
+  Returns:
+    True if n is prime, False otherwise.
+  """
+  if n < 2:
+    return False
+  if n == 2 or n == 3:
+    return True
+  if n % 2 == 0:
+    return False
+
+  # 1. SPEED OPTIMIZATION: Trial Division
+  # Check divisibility by small primes to fail fast on obvious composites.
+  # This filters out ~85% of candidates without expensive modular
+  # exponentiation.
+  small_primes = [3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
+  for p in small_primes:
+    if n == p:
+      return True
+    if n % p == 0:
+      return False
+
+  # 2. DETERMINISTIC MILLER-RABIN
+  # For n < 2^64, verifying these specific bases guarantees primality.
+  # No randomness involved.
+  d = n - 1
+  s = 0
+  while d % 2 == 0:
+    d //= 2
+    s += 1
+
+  # Bases required for deterministic check up to 2^64
+  bases = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]
+
+  for a in bases:
+    if a >= n:
+      break
+
+    x = pow(a, d, n)
+    if x == 1 or x == n - 1:
+      continue
+
+    for _ in range(s - 1):
+      x = pow(x, 2, n)
+      if x == n - 1:
+        break
+    else:
+      return False  # Composite
+
+  return True  # Prime
+
+
+def find_moduli_ntt(total_number, precision, ntt_length):
+  """Deterministically finds the largest valid NTT moduli.
+
+  Candidates have the form `k * ntt_length + 1` and are scanned downwards from
+  the largest one below `2**precision`.
+
+  Args:
+    total_number: Number of moduli to find.
+    precision: Bit-width (e.g., 60 for < 2^60).
+    ntt_length: The required N-th root of unity (e.g., 1024).
+
+  Returns:
+    A list of the largest valid NTT moduli found, up to `total_number`, in
+    decreasing order.
+  """
+  overall_moduli = []
+
+  # Upper bound (exclusive)
+  limit = 2**precision
+
+  # Start search from the largest possible k
+  # P = k * ntt_length + 1 < limit
+  k = (limit - 2) // ntt_length
+
+  while len(overall_moduli) < total_number and k > 0:
+    candidate_p = k * ntt_length + 1
+    # Check candidate
+    if is_prime_deterministic(candidate_p):
+      overall_moduli.append(candidate_p)
+    k -= 1
+
+  return overall_moduli
+
+
+def to_tuple(a):
+  """Recursively converts a nested iterable (e.g. a numpy array) to tuples.
+
+  Non-iterable values and strings are returned unchanged.
+  """
+  # A one-character string iterates to itself, which would recurse forever.
+  if isinstance(a, str):
+    return a
+  try:
+    return tuple(to_tuple(i) for i in a)
+  except TypeError:
+    return a
+
+
+def modinv(x: int, q: int) -> int:
+  """Returns the inverse of x mod q."""
+  return int(pow(x, -1, q))
diff --git a/jaxite/jaxite_ckks/util_test.py b/jaxite/jaxite_ckks/util_test.py
new file mode 100644
index 0000000..77953d0
--- /dev/null
+++ b/jaxite/jaxite_ckks/util_test.py
@@ -0,0 +1,26 @@
+"""Tests for the jaxite_ckks utils."""
+
+from jaxite.jaxite_ckks import util
+from absl.testing import absltest
+
+
+class UtilsTest(absltest.TestCase):
+
+  def test_is_prime_deterministic(self):
+    self.assertTrue(util.is_prime_deterministic(17))
+    self.assertFalse(util.is_prime_deterministic(18))
+
+  def test_find_moduli_ntt(self):
+    self.assertEqual(util.find_moduli_ntt(1, 31, 16), [2147483489])
+    self.assertEqual(util.find_moduli_ntt(2, 31, 16), [2147483489, 2147483249])
+
+  def test_to_tuple(self):
+    self.assertEqual(util.to_tuple([[1, 2], [3]]), ((1, 2), (3,)))
+    self.assertEqual(util.to_tuple("ab"), "ab")
+
+  def test_modinv(self):
+    self.assertEqual(util.modinv(3, 7), 5)
+
+
+if __name__ == "__main__":
+  absltest.main()