From 20aea98e366484881bbbf8679da64d18627a4e82 Mon Sep 17 00:00:00 2001 From: Jake Harmon Date: Wed, 4 Dec 2024 15:46:25 -0800 Subject: [PATCH] Update references to JAX's GitHub repo JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax PiperOrigin-RevId: 702886997 --- README.md | 2 +- jaxite/jaxite_lib/matrix_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index db8df36..bd8eddf 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Jaxite Jaxite is a fully homomorphic encryption backend targeting TPUs and GPUs, -written in [JAX](https://github.com/google/jax). +written in [JAX](https://github.com/jax-ml/jax). It implements the [CGGI cryptosystem](https://eprint.iacr.org/2018/421) with some optimizations, and is a supported backend for [Google's FHE diff --git a/jaxite/jaxite_lib/matrix_utils.py b/jaxite/jaxite_lib/matrix_utils.py index 100cd10..b30b0b4 100644 --- a/jaxite/jaxite_lib/matrix_utils.py +++ b/jaxite/jaxite_lib/matrix_utils.py @@ -181,7 +181,7 @@ def hpmatmul_conv_adapt_conv(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: lhs: jax.Array = jax.lax.bitcast_convert_type(x, new_dtype=jnp.uint8) # bnmp rhs: jax.Array = jax.lax.bitcast_convert_type(y, new_dtype=jnp.uint8) # nk1q - # https://github.com/google/jax/issues/11483 + # https://github.com/jax-ml/jax/issues/11483 rhs = jax.lax.rev(rhs, [2]) # rhs = jlax.rev(rhs, dimensions=[3])