Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions jaxite/jaxite_lib/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,15 @@ def jit_external_product(
decomposed_rlwe = decomposition.decompose_rlwe_ciphertext(
rlwe_ct, decomposition_params
)
return polymul_kernel.negacyclic_vector_matrix_polymul(
decomposed_rlwe, rgsw_ct
)
use_bat = True
if use_bat:
return polymul_kernel.negacyclic_vector_matrix_polymul_bat(
decomposed_rlwe, rgsw_ct
)
else:
return polymul_kernel.negacyclic_vector_matrix_polymul(
decomposed_rlwe, rgsw_ct
)


def cmux(
Expand Down
126 changes: 126 additions & 0 deletions jaxite/jaxite_lib/polymul_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,29 @@ def _i32_matmul_unreduced_CGGI(lhs, rhs):
return acc


@jax.jit
def bat_matmul(lhs: jax.Array, y: jax.Array) -> jax.Array:
  """Batched byte-wise matmul u8(c, m, n, 4, 4) @ u32(c, n, k), summed over c.

  Each 32-bit entry of `y` is reinterpreted as 4 bytes, contracted against the
  4x4 byte block stored for every entry of `lhs`, and the four partial results
  are recombined with shifts of 0, 8, 16 and 24 bits. All accumulation happens
  in uint32, so the result wraps modulo 2**32.

  When `lhs[c, m, n]` is the lower-triangular byte matrix whose entry [p, q]
  (p >= q) is byte (p - q) of a 32-bit value A[c, m, n] (the layout built by
  `bat_offline_compile_cggi`), the output equals
  sum_c (A[c] @ y[c]) mod 2**32. This assumes byte index 0 produced by the
  bitcast is the least-significant byte, as the shift table implies.

  Args:
    lhs: u8 array of shape (c, m, n, 4, 4); one 4x4 byte block per entry of
      the left matrices.
    y: 32-bit integer array of shape (c, n, k); bitcast to u8(c, n, k, 4)
      before the contraction.

  Returns:
    u32 array of shape (m, k): the per-batch products reduced over the batch
    axis c as well as over the byte-position axis.
  """
  # Splitting each 32-bit word into bytes appends a trailing axis of size 4.
  rhs_bytes = jax.lax.bitcast_convert_type(y, new_dtype=jnp.uint8)
  # Contract over n and the rhs byte index q; p indexes the output byte
  # position whose weight is 2**(8 * p).
  byte_partials = jnp.einsum(
      "cmnpq,cnkq->cmkp",
      lhs,
      rhs_bytes,
      preferred_element_type=jnp.uint32,
  )
  byte_shifts = jnp.array([0, 8, 16, 24], dtype=jnp.uint32)
  # Axis 0 is the batch c, axis 3 is the byte position p.
  return jnp.sum(byte_partials << byte_shifts, axis=(0, 3))


def _vector_matrix_polymul(poly_vec1: jnp.ndarray, poly_mat2: jnp.ndarray):
# b is the product of the RLWE dimension (e.g., 3) and the number of
# decomposition levels in the decomposition parameters (e.g., 6).
Expand Down Expand Up @@ -211,6 +234,109 @@ def negacyclic_vector_matrix_polymul(
return fallback_vector_matrix_polymul(vec, matrix)


@jax.named_call
@jax.jit
def negacyclic_vector_matrix_polymul_bat(
    poly_vec1: jnp.ndarray, poly_mat2: jnp.ndarray,
) -> jnp.ndarray:
  """Vector-matrix negacyclic polynomial multiply using the byte-wise matmul.

  On TPU v5+ with a polynomial degree that is a multiple of 128, a Pallas
  kernel expands every polynomial of `poly_vec1` into an (n, n) signed matrix,
  `poly_mat2` is re-encoded as 4x4 byte blocks, and the two are multiplied
  with `bat_matmul`. Otherwise the call is forwarded unchanged to
  `fallback_vector_matrix_polymul`.

  Args:
    poly_vec1: uint32 array of shape (b, n). b is the product of the RLWE
      dimension (e.g., 3) and the number of decomposition levels (e.g., 6); n
      is the degree of the RLWE polynomials.
    poly_mat2: uint32 array of shape (b, m, n). m is the number of polynomials
      in the RLWE dimension (e.g., 3).

  Returns:
    On the TPU path, the output of `bat_matmul`: a uint32 array of shape
    (m, n), reduced over b. On the fallback path, whatever
    `fallback_vector_matrix_polymul` returns.

  Raises:
    ValueError: if the trailing (degree) dimensions of the inputs differ.
    AssertionError: if either input is not uint32 or the leading dimensions
      differ.
  """
  assert poly_vec1.dtype == jnp.uint32
  assert poly_mat2.dtype == jnp.uint32

  # Check the degree before the shape assert so that a mismatch surfaces as
  # the descriptive ValueError instead of a bare AssertionError.
  n_matrix = poly_mat2.shape[-1]
  n_vec = poly_vec1.shape[-1]
  if n_matrix != n_vec:
    raise ValueError(
        "Expected polynomial degree of the inputs to match, "
        f"but found {n_vec} != {n_matrix}"
    )

  # Unpacking also enforces the expected ranks (2 and 3).
  b, n = poly_vec1.shape
  b2, _, _ = poly_mat2.shape
  assert b == b2

  tpu_version = jax_helpers.get_tpu_version()
  if n_vec % 128 != 0 or tpu_version < 5:
    return fallback_vector_matrix_polymul(poly_vec1, poly_mat2)

  def _toeplitz_chunk(vec_ref, toeplitz_ref):
    """Pallas kernel: expand one (1, 1, n) block into a (1, n, n) block."""
    degree = vec_ref.shape[2]
    chunk = jnp.broadcast_to(vec_ref[...][0], (128, degree))
    # Strided roll: presumably rotates row i by i positions along axis 1
    # (stride 1 along axis 0) - confirm against the Pallas TPU roll docs.
    chunk = pltpu.roll(chunk, 0, 1, stride=1, stride_axis=0)
    row_idx = jax.lax.broadcasted_iota(
        dtype=jnp.int32, shape=(128, degree), dimension=0
    )
    col_idx = jax.lax.broadcasted_iota(
        dtype=jnp.int32, shape=(128, degree), dimension=1
    )
    toeplitz_chunks = []
    for _ in range(0, degree, 128):
      # Entries whose row index exceeds the column index are negated.
      toeplitz_chunks.append(jnp.where(row_idx > col_idx, -chunk, chunk))
      # Because the vector registers are aligned to size 128, this
      # roll operation lowers to telling the TPU to refer to a different
      # register, rather than actually applying any rolling operation.
      # Hence, the op produces no hardware instructions.
      chunk = pltpu.roll(chunk, 128, 1)
      row_idx = row_idx + 128
    toeplitz_ref[...] = jax.lax.concatenate(
        toeplitz_chunks, dimension=0
    ).reshape(vec_ref.shape[0], degree, degree)

  vec_toeplitz_list = pl.pallas_call(
      _toeplitz_chunk,
      grid_spec=pltpu.PrefetchScalarGridSpec(
          num_scalar_prefetch=0,
          in_specs=(pl.BlockSpec((1, 1, n), lambda b: (b, 0, 0)),),
          out_specs=pl.BlockSpec((1, n, n), lambda b: (b, 0, 0)),
          # One grid step per batch entry. This used to be hard-coded to 18,
          # which only matched inputs with b == 18.
          grid=(b,),
      ),
      out_shape=jax.ShapeDtypeStruct((b, n, n), jnp.int32),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel",)
      ),
  )(poly_vec1[:, None].astype(jnp.int32))

  def bat_offline_compile_cggi(mat_a):
    """Convert the input matrix with 32 bit elements into u8(*matrix.shape,4,4).

    i.e. replace each element in the original matrix by a p*q matrix (p=q=4).
    Column q holds the element's 4 bytes rolled down by q positions, and the
    block is then made lower triangular, so entry [p, q] is byte (p - q) for
    p >= q and zero otherwise.

    Args:
      mat_a: The input matrix.

    Returns:
      The converted matrix.
    """
    mat_a_u8 = jax.lax.bitcast_convert_type(
        mat_a, new_dtype=jnp.uint8
    ).reshape(*mat_a.shape, 4, 1)
    rolled_columns = [mat_a_u8] + [
        jnp.roll(mat_a_u8, shift, axis=-2) for shift in (1, 2, 3)
    ]
    return jnp.tril(jnp.concatenate(rolled_columns, axis=-1))

  mat_bat = bat_offline_compile_cggi(poly_mat2)

  return bat_matmul(mat_bat, vec_toeplitz_list)


def i32_matmul_unreduced(lhs, rhs, out):
  """Ref-style wrapper so the matmul step of the kernel can be tested alone.

  Reads the full contents of `lhs` and `rhs`, multiplies them with
  `_i32_matmul_unreduced`, and stores the product into `out` in place.

  Args:
    lhs: Indexable holder of the left operand (read with `[...]`).
    rhs: Indexable holder of the right operand (read with `[...]`).
    out: Indexable holder that receives the product (written with `[...]`).
  """
  lhs_value = lhs[...]
  rhs_value = rhs[...]
  out[...] = _i32_matmul_unreduced(lhs_value, rhs_value)
Expand Down