From 458d7a1ef44e878545df9fa6991360502ffcc9f1 Mon Sep 17 00:00:00 2001
From: Jianming Tong
Date: Fri, 13 Dec 2024 15:32:57 -0800
Subject: [PATCH] No public description

PiperOrigin-RevId: 706021243
---
Review notes (this section is ignored by `git am`):
  * polymul_kernel.py: the Pallas grid is now sized by `b` instead of the
    literal 18, and the two new docstrings describe the shapes the code uses.
  * Added-line counts are unchanged, so hunk headers and the diffstat still
    hold; the post-image blob id on the polymul_kernel.py `index` line was
    computed for the pre-review content.

 jaxite/jaxite_lib/bootstrap.py      |  12 ++-
 jaxite/jaxite_lib/polymul_kernel.py | 126 ++++++++++++++++++++++++++++
 2 files changed, 135 insertions(+), 3 deletions(-)

diff --git a/jaxite/jaxite_lib/bootstrap.py b/jaxite/jaxite_lib/bootstrap.py
index d49a399..fd08cb5 100644
--- a/jaxite/jaxite_lib/bootstrap.py
+++ b/jaxite/jaxite_lib/bootstrap.py
@@ -355,9 +355,15 @@ def jit_external_product(
   decomposed_rlwe = decomposition.decompose_rlwe_ciphertext(
       rlwe_ct, decomposition_params
   )
-  return polymul_kernel.negacyclic_vector_matrix_polymul(
-      decomposed_rlwe, rgsw_ct
-  )
+  use_bat = True
+  if use_bat:
+    return polymul_kernel.negacyclic_vector_matrix_polymul_bat(
+        decomposed_rlwe, rgsw_ct
+    )
+  else:
+    return polymul_kernel.negacyclic_vector_matrix_polymul(
+        decomposed_rlwe, rgsw_ct
+    )
 
 
 def cmux(
diff --git a/jaxite/jaxite_lib/polymul_kernel.py b/jaxite/jaxite_lib/polymul_kernel.py
index 18139ff..e8b699f 100644
--- a/jaxite/jaxite_lib/polymul_kernel.py
+++ b/jaxite/jaxite_lib/polymul_kernel.py
@@ -92,6 +92,29 @@ def _i32_matmul_unreduced_CGGI(lhs, rhs):
   return acc
 
 
+@jax.jit
+def bat_matmul(lhs: jax.Array, y: jax.Array):
+  """Perform batched matrix multiplication u8(c, m, n, 4, 4)@u32(c, n, k).
+
+  Args:
+    lhs: Input u8(c, m, n, 4, 4) Left Matrix
+    y: Input 32-bit integer (c, n, k) Right Matrix, bitcast to u8(c, n, k, 4)
+
+  Returns:
+    Output u32(m, k), summed over the batch axis c
+  """
+
+  rhs: jax.Array = jax.lax.bitcast_convert_type(y, new_dtype=jnp.uint8)
+  i8_products = jnp.einsum(
+      "cmnpq,cnkq->cmkp",
+      lhs,
+      rhs,
+      preferred_element_type=jnp.uint32,
+  )
+  shift_factors = jnp.array([0, 8, 16, 24], dtype=jnp.uint32)
+  return jnp.sum(i8_products << shift_factors, axis=(0, 3,))
+
+
 def _vector_matrix_polymul(poly_vec1: jnp.ndarray, poly_mat2: jnp.ndarray):
   # b is the product of the RLWE dimension (e.g., 3) and the number of
   # decomposition levels in the decomposition parameters (e.g., 6).
@@ -211,6 +234,109 @@ def negacyclic_vector_matrix_polymul(
   return fallback_vector_matrix_polymul(vec, matrix)
 
 
+@jax.named_call
+@jax.jit
+def negacyclic_vector_matrix_polymul_bat(
+    poly_vec1: jnp.ndarray, poly_mat2: jnp.ndarray,
+) -> jnp.ndarray:
+  """Negacyclic vector-matrix polynomial product via byte-decomposed matmul.
+
+  Args:
+    poly_vec1: u32(b, n) vector of b polynomials with n coefficients each
+    poly_mat2: u32(b, m, n) matrix of polynomials matching the vector's b, n
+
+  Returns:
+    u32(m, n) on the TPU path; otherwise the fallback's result
+  """
+  # b is the product of the RLWE dimension (e.g., 3) and the number of
+  # decomposition levels in the decomposition parameters (e.g., 6).
+  # n is the degree of the RLWE polynomials.
+  assert poly_vec1.dtype == jnp.uint32
+  assert poly_mat2.dtype == jnp.uint32
+  b, n = poly_vec1.shape
+  # m is the number of polynomials in the RLWE dimension (e.g., 3)
+  b2, m, n2 = poly_mat2.shape
+  assert b == b2 and n == n2
+
+  n_matrix = poly_mat2.shape[-1]
+  n_vec = poly_vec1.shape[-1]
+  if n_matrix != n_vec:
+    raise ValueError(
+        "Expected polynomial degree of the inputs to match, "
+        f"but found {n_vec} != {n_matrix}"
+    )
+
+  tpu_version = jax_helpers.get_tpu_version()
+  if n_vec % 128 == 0 and tpu_version >= 5:
+    def _toeplitz_chunk(poly_vec1, vec_toeplitz):
+      n = poly_vec1.shape[2]
+      chunk = jnp.broadcast_to(poly_vec1[...][0], (128, n))
+      chunk = pltpu.roll(chunk, 0, 1, stride=1, stride_axis=0)
+      chunk_row_indices = jax.lax.broadcasted_iota(
+          dtype=jnp.int32, shape=(128, n), dimension=0
+      )
+      chunk_col_indices = jax.lax.broadcasted_iota(
+          dtype=jnp.int32, shape=(128, n), dimension=1
+      )
+      toeplitz_chunks = []
+      for _ in range(0, n, 128):
+        toeplitz_chunks.append(
+            jnp.where(chunk_row_indices > chunk_col_indices, -chunk, chunk)
+        )
+        # Because the vector registers are aligned to size 128, this
+        # roll operation lowers to telling the TPU to refer to a different
+        # register, rather than actually applying any rolling operation.
+        # Hence, the op produces no hardware instructions.
+        chunk = pltpu.roll(chunk, 128, 1)
+        chunk_row_indices = chunk_row_indices + 128
+      vec_toeplitz[...] = jax.lax.concatenate(
+          toeplitz_chunks, dimension=0
+      ).reshape(poly_vec1.shape[0], n, n)
+
+    vec_toeplitz_list = pl.pallas_call(
+        _toeplitz_chunk,
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=0,
+            in_specs=(pl.BlockSpec((1, 1, n), lambda b: (b, 0, 0)),),
+            out_specs=pl.BlockSpec((1, n, n), lambda b: (b, 0, 0)),
+            grid=(b,),  # one grid step per polynomial, matching out_shape
+        ),
+        out_shape=jax.ShapeDtypeStruct((b, n, n), jnp.int32),
+        compiler_params=pltpu.TPUCompilerParams(
+            dimension_semantics=("parallel",)
+        ),
+    )(poly_vec1[:, None].astype(jnp.int32))
+
+    def bat_offline_compile_cggi(mat_a):
+      """Convert the input matrix with 32 bit elements into u8(*matrix.shape,4,4).
+
+      i.e. replace each element in the original matrix by a p*q matrix (p=q=4).
+
+      Args:
+        mat_a: The input matrix.
+
+      Returns:
+        The converted matrix.
+      """
+      mat_a_u8 = jax.lax.bitcast_convert_type(mat_a, new_dtype=jnp.uint8).reshape(
+          *mat_a.shape, 4, 1
+      )
+      mat_a_u8_r1 = jnp.roll(mat_a_u8, 1, axis=-2)
+      mat_a_u8_r2 = jnp.roll(mat_a_u8, 2, axis=-2)
+      mat_a_u8_r3 = jnp.roll(mat_a_u8, 3, axis=-2)
+      mat_a_u8_array = jnp.concatenate(
+          [mat_a_u8, mat_a_u8_r1, mat_a_u8_r2, mat_a_u8_r3], axis=-1
+      )
+      return jnp.tril(mat_a_u8_array)
+
+    poly_mat2 = bat_offline_compile_cggi(poly_mat2)
+
+    return bat_matmul(poly_mat2, vec_toeplitz_list)
+
+  else:
+    return fallback_vector_matrix_polymul(poly_vec1, poly_mat2)
+
+
 def i32_matmul_unreduced(lhs, rhs, out):
   """A helper to isolate the matmul part of the kernel to test in isolation."""
   out[...] = _i32_matmul_unreduced(lhs[...], rhs[...])