diff --git a/BUILD b/BUILD index 6bd7589..9eda046 100644 --- a/BUILD +++ b/BUILD @@ -178,6 +178,22 @@ tpu_test( ], ) +tpu_test( + name = "add_test", + size = "large", + timeout = "eternal", + srcs = ["jaxite_word/add_test.py"], + shard_count = 3, + deps = [ + ":jaxite", + "@com_google_absl_py//absl/testing:absltest", + "@com_google_absl_py//absl/testing:parameterized", + "@jaxite_deps_jax//:pkg", + "@jaxite_deps_jaxlib//:pkg", + "@jaxite_deps_numpy//:pkg", + ], +) + cpu_gpu_tpu_test( name = "decomposition_test", size = "small", diff --git a/jaxite_word/add.py b/jaxite_word/add.py new file mode 100644 index 0000000..a07bcd0 --- /dev/null +++ b/jaxite_word/add.py @@ -0,0 +1,61 @@ +"""TPU kernels for Evaluation of the CKKS algorithm.""" + +import jax +import jax.numpy as jnp + + +def jax_add(value_a: jax.Array, value_b: jax.Array, modulus_list: jax.Array): + """This function processes all degree of the two input polynomials in parallel using multi-trheading. + + Assuming the input data type is jax array. + + Args: + value_a: the first operand of the addition. + value_b: the second operand of the addition. + modulus_list: the list of moduli for each degree. + + Returns: + The result of the addition. + """ + num_elements, _, degree = value_a.shape + modulus_broadcast = jnp.tile( + modulus_list[None, :, None], (num_elements, 1, degree) + ) + result = value_a + value_b + return jnp.where( + result > modulus_broadcast, result - modulus_broadcast, result + ) # jnp.mod(value_a + value_b, modulus_broadcast) + + +def vmap_add(value_a: jax.Array, value_b: jax.Array, modulus_list: jax.Array): + """This function processes all degree of the two input polynomials in SIMD using jax.vmap. + + Assumes the input data type is a 3-dimensional jax Array as (num_elements, num_towers, degree) + where: + * num_elements: number of polynomials + * num_towers: number of RNS limbs + * degree: degree of the polynomials + + This vmap_add can later be extended to batch ciphertexts. + + Args: + value_a: the first operand of the addition. + value_b: the second operand of the addition. + modulus_list: the list of moduli for each degree. + + Returns: + The result of the addition. + """ + num_elements, num_towers, degree = value_a.shape + modulus_broadcast = jnp.tile( + modulus_list[None, :, None], (num_elements, 1, degree) + ) + + def chunk_wise_add(value_a, value_b): + return value_a + value_b + + def chunk_wise_subtract(value_a, value_b): + return jnp.where(value_a > value_b, value_a - value_b, value_a) + + result = jax.vmap(chunk_wise_add)(value_a, value_b) + return jax.vmap(chunk_wise_subtract)(result, modulus_broadcast) diff --git a/jaxite_word/add_test.py b/jaxite_word/add_test.py new file mode 100644 index 0000000..e49b831 --- /dev/null +++ b/jaxite_word/add_test.py @@ -0,0 +1,108 @@ +"""A module for operations on test CKKS evaluation kernels including. + +- ModAdd +- HEAdd +- HESub +- HEMul +- HERotate +""" + +from concurrent import futures +from typing import Any, Callable + +import jax +import jax.numpy as jnp +from jaxite.jaxite_word import add + +from absl.testing import absltest +from absl.testing import parameterized + + +ProcessPoolExecutor = futures.ProcessPoolExecutor + +jax.config.update("jax_enable_x64", True) +jax.config.update("jax_traceback_filtering", "off") + + +class CKKSEvalKernelsTest(parameterized.TestCase): + """A base class for running bootstrap tests.""" + + def __init__(self, *args, **kwargs): + super(CKKSEvalKernelsTest, self).__init__(*args, **kwargs) + self.debug = False # dsiable it from printing the test input values + self.modulus_element_0_tower_0 = 1152921504606748673 + self.modulus_element_0_tower_1 = 268664833 + self.modulus_element_0_tower_2 = 557057 + self.random_key = jax.random.key(0) + + def random(self, shape, modulus_list, dtype=jnp.int32): + assert len(modulus_list) == shape[1] + + return jnp.concatenate( + [ + jax.random.randint( + self.random_key, + shape=(shape[0], 1, shape[2]), + minval=0, + maxval=bound, + dtype=dtype, + ) + for bound in modulus_list + ], + axis=1, + ) + + @parameterized.named_parameters( + dict( + testcase_name="jax_add", + test_target=add.jax_add, + modulus_list=[1152921504606748673, 268664833, 557057], + shape=(2, 3, 16384), # number of elements, number of towers, degree + ), + dict( + testcase_name="vmap_add", + test_target=add.vmap_add, + modulus_list=[1152921504606748673, 268664833, 557057], + shape=(2, 3, 16384), # number of elements, number of towers, degree + ), + ) + def test_add( + self, + test_target: Callable[[Any, Any, Any], Any], + modulus_list=jax.Array, + shape=tuple[int, int, int], + ): + """This function tests the add function using Python native integer data type with arbitrary precision. + + This test finishes in 1.05 second. + + Args: + test_target: The function to test. + modulus_list: A jax.Array of integers. + shape: A tuple of integers representing the shape of the input arrays. + """ + # Only test a single element to save comparison time, + # Correctness-wise, it's sufficient for add. + value_a = self.random(shape, modulus_list, dtype=jnp.uint64) + value_b = self.random(shape, modulus_list, dtype=jnp.uint64) + assert value_a.shape == shape + assert value_b.shape == shape + result_a_plus_b = [] + for element_id in range(value_a.shape[0]): + result_a_plus_b_one_element = [] + for tower_id in range(value_a.shape[1]): + add_res = int(value_b[element_id, tower_id, 0]) + int( + value_a[element_id, tower_id, 0] + ) + if add_res > modulus_list[tower_id]: + add_res = add_res - modulus_list[tower_id] + result_a_plus_b_one_element.append(add_res) + result_a_plus_b.append(result_a_plus_b_one_element) + result_a_plus_b = jnp.array(result_a_plus_b, dtype=jnp.uint64) + modulus_list = jnp.array(modulus_list, dtype=jnp.uint64) + result = test_target(value_a, value_b, modulus_list) + self.assertEqual(result[:, :, 0].all(), result_a_plus_b.all()) + + +if __name__ == "__main__": + absltest.main()