Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,22 @@ tpu_test(
],
)

tpu_test(
name = "add_test",
size = "large",
timeout = "eternal",
srcs = ["jaxite_word/add_test.py"],
shard_count = 3,
deps = [
":jaxite",
"@com_google_absl_py//absl/testing:absltest",
"@com_google_absl_py//absl/testing:parameterized",
"@jaxite_deps_jax//:pkg",
"@jaxite_deps_jaxlib//:pkg",
"@jaxite_deps_numpy//:pkg",
],
)

cpu_gpu_tpu_test(
name = "decomposition_test",
size = "small",
Expand Down
61 changes: 61 additions & 0 deletions jaxite_word/add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""TPU kernels for Evaluation of the CKKS algorithm."""

import jax
import jax.numpy as jnp


def jax_add(value_a: jax.Array, value_b: jax.Array, modulus_list: jax.Array):
"""This function processes all degree of the two input polynomials in parallel using multi-trheading.

Assuming the input data type is jax array.

Args:
value_a: the first operand of the addition.
value_b: the second operand of the addition.
modulus_list: the list of moduli for each degree.

Returns:
The result of the addition.
"""
num_elements, _, degree = value_a.shape
modulus_broadcast = jnp.tile(
modulus_list[None, :, None], (num_elements, 1, degree)
)
result = value_a + value_b
return jnp.where(
result > modulus_broadcast, result - modulus_broadcast, result
) # jnp.mod(value_a + value_b, modulus_broadcast)


def vmap_add(value_a: jax.Array, value_b: jax.Array, modulus_list: jax.Array):
"""This function processes all degree of the two input polynomials in SIMD using jax.vmap.

Assumes the input data type is a 3-dimensional jax Array as (num_elements, num_towers, degree)
where:
* num_elements: number of polynomials
* num_towers: number of RNS limbs
* degree: degree of the polynomials

This vmap_add can later be extended to batch ciphertexts.

Args:
value_a: the first operand of the addition.
value_b: the second operand of the addition.
modulus_list: the list of moduli for each degree.

Returns:
The result of the addition.
"""
num_elements, num_towers, degree = value_a.shape
modulus_broadcast = jnp.tile(
modulus_list[None, :, None], (num_elements, 1, degree)
)

def chunk_wise_add(value_a, value_b):
return value_a + value_b

def chunk_wise_subtract(value_a, value_b):
return jnp.where(value_a > value_b, value_a - value_b, value_a)

result = jax.vmap(chunk_wise_add)(value_a, value_b)
return jax.vmap(chunk_wise_subtract)(result, modulus_broadcast)
108 changes: 108 additions & 0 deletions jaxite_word/add_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""A module for operations on test CKKS evaluation kernels including.

- ModAdd
- HEAdd
- HESub
- HEMul
- HERotate
"""

from concurrent import futures
from typing import Any, Callable

import jax
import jax.numpy as jnp
from jaxite.jaxite_word import add

from absl.testing import absltest
from absl.testing import parameterized


ProcessPoolExecutor = futures.ProcessPoolExecutor

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_traceback_filtering", "off")


class CKKSEvalKernelsTest(parameterized.TestCase):
"""A base class for running bootstrap tests."""

def __init__(self, *args, **kwargs):
super(CKKSEvalKernelsTest, self).__init__(*args, **kwargs)
self.debug = False # dsiable it from printing the test input values
self.modulus_element_0_tower_0 = 1152921504606748673
self.modulus_element_0_tower_1 = 268664833
self.modulus_element_0_tower_2 = 557057
self.random_key = jax.random.key(0)

def random(self, shape, modulus_list, dtype=jnp.int32):
assert len(modulus_list) == shape[1]

return jnp.concatenate(
[
jax.random.randint(
self.random_key,
shape=(shape[0], 1, shape[2]),
minval=0,
maxval=bound,
dtype=dtype,
)
for bound in modulus_list
],
axis=1,
)

@parameterized.named_parameters(
dict(
testcase_name="jax_add",
test_target=add.jax_add,
modulus_list=[1152921504606748673, 268664833, 557057],
shape=(2, 3, 16384), # number of elements, number of towers, degree
),
dict(
testcase_name="vmap_add",
test_target=add.vmap_add,
modulus_list=[1152921504606748673, 268664833, 557057],
shape=(2, 3, 16384), # number of elements, number of towers, degree
),
)
def test_add(
self,
test_target: Callable[[Any, Any, Any], Any],
modulus_list=jax.Array,
shape=tuple[int, int, int],
):
"""This function tests the add function using Python native integer data type with arbitrary precision.

This test finishes in 1.05 second.

Args:
test_target: The function to test.
modulus_list: A jax.Array of integers.
shape: A tuple of integers representing the shape of the input arrays.
"""
# Only test a single element to save comparison time,
# Correctness-wise, it's sufficient for add.
value_a = self.random(shape, modulus_list, dtype=jnp.uint64)
value_b = self.random(shape, modulus_list, dtype=jnp.uint64)
assert value_a.shape == shape
assert value_b.shape == shape
result_a_plus_b = []
for element_id in range(value_a.shape[0]):
result_a_plus_b_one_element = []
for tower_id in range(value_a.shape[1]):
add_res = int(value_b[element_id, tower_id, 0]) + int(
value_a[element_id, tower_id, 0]
)
if add_res > modulus_list[tower_id]:
add_res = add_res - modulus_list[tower_id]
result_a_plus_b_one_element.append(add_res)
result_a_plus_b.append(result_a_plus_b_one_element)
result_a_plus_b = jnp.array(result_a_plus_b, dtype=jnp.uint64)
modulus_list = jnp.array(modulus_list, dtype=jnp.uint64)
result = test_target(value_a, value_b, modulus_list)
self.assertEqual(result[:, :, 0].all(), result_a_plus_b.all())


if __name__ == "__main__":
absltest.main()