From a39da78b39fd2e1cc119f865abef14d37d800fde Mon Sep 17 00:00:00 2001 From: Jianming Tong Date: Tue, 8 Apr 2025 08:15:31 -0700 Subject: [PATCH] Update new 28-bit RNS PiperOrigin-RevId: 745151531 --- BUILD | 2 +- jaxite_ec/elliptic_curve.py | 229 +++++++++++++++++++-- jaxite_ec/elliptic_curve_test.py | 337 ++++++++++++++++++++++++------- jaxite_ec/finite_field.py | 231 +++++++++++---------- jaxite_ec/finite_field_test.py | 23 ++- jaxite_ec/msm_test.py | 213 +++++++++++++++---- jaxite_ec/pippenger.py | 24 +-- jaxite_ec/pippenger_rns.py | 295 +++++++++++++++++++++++++-- jaxite_ec/util.py | 233 ++++++++++----------- 9 files changed, 1187 insertions(+), 400 deletions(-) diff --git a/BUILD b/BUILD index 5490cd8..2169fb7 100644 --- a/BUILD +++ b/BUILD @@ -145,7 +145,7 @@ tpu_test( "jaxite_ec/test_case/t8/zprize_msm_curve_377_scalars_dim_8_seed_0.csv", ], python_version = "PY3", - shard_count = 3, + shard_count = 6, srcs_version = "PY3ONLY", deps = [ ":jaxite", diff --git a/jaxite_ec/elliptic_curve.py b/jaxite_ec/elliptic_curve.py index bf846d2..b4057d9 100644 --- a/jaxite_ec/elliptic_curve.py +++ b/jaxite_ec/elliptic_curve.py @@ -21,6 +21,7 @@ from jaxite.jaxite_ec import util + add_3u16 = finite_field.add_3u16 add_2u16 = finite_field.add_2u16 sub_2u16 = finite_field.sub_2u16 @@ -33,15 +34,12 @@ add_rns_3u16 = finite_field.add_rns_3u16 add_sub_rns_var = finite_field.add_sub_rns_var negate_rns_for_var_add = finite_field.negate_rns_for_var_add -negate_rns_for_var_add_zero_check = ( - finite_field.negate_rns_for_var_add_zero_check -) -rns_constant = finite_field.rns_constant +negate_rns = finite_field.negate_rns # Barrett Reduction Based Functions @jax.named_call -def padd_barret_xyzz( +def padd_barrett_xyzz( x1: jax.Array, y1: jax.Array, zz1: jax.Array, @@ -108,7 +106,7 @@ def padd_barret_xyzz( @jax.named_call -def pdul_barret_xyzz( +def pdul_barrett_xyzz( x1: jax.Array, y1: jax.Array, zz1: jax.Array, zzz1: jax.Array ): """PDUL-BARRET elliptic curve operation with packed 
arguments. @@ -163,7 +161,7 @@ def pdul_barret_xyzz( @jax.named_call def pdul_barrett_xyzz_pack(x1_y1_zz1_zzz1: jax.Array): - return pdul_barret_xyzz( + return pdul_barrett_xyzz( x1_y1_zz1_zzz1[0], x1_y1_zz1_zzz1[1], x1_y1_zz1_zzz1[2], x1_y1_zz1_zzz1[3] ) @@ -172,7 +170,7 @@ def pdul_barrett_xyzz_pack(x1_y1_zz1_zzz1: jax.Array): def padd_barrett_xyzz_pack( x1_y1_zz1_zzz1: jax.Array, x2_y2_zz2_zzz2: jax.Array ): - return padd_barret_xyzz( + return padd_barrett_xyzz( x1_y1_zz1_zzz1[0], x1_y1_zz1_zzz1[1], x1_y1_zz1_zzz1[2], @@ -188,7 +186,7 @@ def padd_barrett_xyzz_pack( def pdul_barrett_xyzz_pack_batch_first( x1_y1_zz1_zzz1: jax.Array, transpose=(0, 1, 2) ): - return pdul_barret_xyzz( + return pdul_barrett_xyzz( x1_y1_zz1_zzz1[:, 0], x1_y1_zz1_zzz1[:, 1], x1_y1_zz1_zzz1[:, 2], @@ -200,7 +198,7 @@ def pdul_barrett_xyzz_pack_batch_first( def padd_barrett_xyzz_pack_batch_first( x1_y1_zz1_zzz1: jax.Array, x2_y2_zz2_zzz2: jax.Array, transpose=(0, 1, 2) ): - return padd_barret_xyzz( + return padd_barrett_xyzz( x1_y1_zz1_zzz1[:, 0], x1_y1_zz1_zzz1[:, 1], x1_y1_zz1_zzz1[:, 2], @@ -212,6 +210,151 @@ def padd_barrett_xyzz_pack_batch_first( ).transpose(transpose[0], transpose[1], transpose[2]) +@jax.named_call +@functools.partial(jax.jit, static_argnames="twisted_d_chunk") +def padd_barrett_twisted( + x1: jax.Array, + y1: jax.Array, + z1: jax.Array, + t1: jax.Array, + x2: jax.Array, + y2: jax.Array, + z2: jax.Array, + t2: jax.Array, + twisted_d_chunk=util.TWIST_D_INT_CHUNK_BARRETT, +): + """PADD-BARRETT elliptic curve operation with packed arguments. + + As for the algorithm, pls refer to + jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassTwisted::add_general + + This function implements the PADD-LAZY elliptic curve operation with packed + arguments, which is used to compute the elliptic curve points of a given + group. + + Args: + x1: The first generator element. + y1: The second generator element. + z1: The third generator element. + t1: The fourth generator element. 
+ x2: The first generator element. + y2: The second generator element. + z2: The third generator element. + t2: The fourth generator element. + twisted_d_chunk: The twisted d parameter. + + Returns: + A tuple containing the third generator element and the elliptic curve points + of the group. + """ + twisted_d = jnp.asarray(twisted_d_chunk, dtype=jnp.uint16) + twisted_d = jax.lax.broadcast(twisted_d, [x1.shape[0]]) + + a = mod_mul_barrett_2u16(x1, x2) + b = mod_mul_barrett_2u16(y1, y2) + d = mod_mul_barrett_2u16(z1, z2) + c = mod_mul_barrett_2u16(t1, t2) + c = mod_mul_barrett_2u16(c, twisted_d) + + h = add_2u16(a, b) + h = cond_sub_mod_u16(h) + e1 = add_2u16(x1, y1) + e1 = cond_sub_mod_u16(e1) + e2 = add_2u16(x2, y2) + e2 = cond_sub_mod_u16(e2) + e = mod_mul_barrett_2u16(e1, e2) + + e = cond_sub_2u16(e, h) + + f = cond_sub_2u16(d, c) + g = add_2u16(d, c) + g = cond_sub_mod_u16(g) + + x3 = mod_mul_barrett_2u16(e, f) + y3 = mod_mul_barrett_2u16(g, h) + z3 = mod_mul_barrett_2u16(f, g) + t3 = mod_mul_barrett_2u16(e, h) + + return jnp.array([x3, y3, z3, t3]) + + +def pdul_barrett_twisted( + x1: jax.Array, + y1: jax.Array, + z1: jax.Array, + t1: jax.Array, +): + """PDUL-LAZY elliptic curve operation with packed arguments. + + As for the algorithm, pls refer to + jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassTwisted::double_general + + This function implements the PDUL-LAZY elliptic curve operation with packed + arguments, which is used to compute the elliptic curve points of a given + group. + + Args: + x1: The first generator element. + y1: The second generator element. + z1: The third generator element. + t1: The fourth generator element. + + Returns: + A tuple containing the third generator element and the elliptic curve points + of the group. 
+ """ + modulus_377_int_array = jnp.asarray( + util.MODULUS_377_INT_CHUNK, jnp.uint16 + ) + + a = mod_mul_barrett_2u16(x1, x1) + b = mod_mul_barrett_2u16(y1, y1) + + ct = mod_mul_barrett_2u16(z1, z1) # + ct2 = add_2u16(ct, ct) # + ct2 = cond_sub_mod_u16(ct2) # + + h = add_2u16(a, b) + h = cond_sub_2u16(modulus_377_int_array, h) # + + et = add_2u16(x1, y1) # + et = cond_sub_mod_u16(et) # + e = mod_mul_barrett_2u16(et, et) # + e = add_2u16(e, h) # + e = cond_sub_mod_u16(e) # + + g = cond_sub_2u16(b, a) # + f = cond_sub_2u16(g, ct2) # + x3 = mod_mul_barrett_2u16(e, f) # + y3 = mod_mul_barrett_2u16(g, h) + z3 = mod_mul_barrett_2u16(f, g) + t3 = mod_mul_barrett_2u16(e, h) + return jnp.array([x3, y3, z3, t3]) + + +@jax.named_call +def pdul_barrett_twisted_pack(x1_y1_zz1_zzz1: jax.Array): + return pdul_barrett_twisted( + x1_y1_zz1_zzz1[0], x1_y1_zz1_zzz1[1], x1_y1_zz1_zzz1[2], x1_y1_zz1_zzz1[3] + ) + + +@jax.named_call +def padd_barrett_twisted_pack( + x1_y1_zz1_zzz1: jax.Array, x2_y2_zz2_zzz2: jax.Array +): + return padd_barrett_twisted( + x1_y1_zz1_zzz1[0], + x1_y1_zz1_zzz1[1], + x1_y1_zz1_zzz1[2], + x1_y1_zz1_zzz1[3], + x2_y2_zz2_zzz2[0], + x2_y2_zz2_zzz2[1], + x2_y2_zz2_zzz2[2], + x2_y2_zz2_zzz2[3], + ) + + # Lazy Reduction Based Functions @jax.named_call def padd_lazy_xyzz( @@ -505,7 +648,7 @@ def pdul_lazy_twisted( a = mod_mul_lazy_2u16(x1, x1) b = mod_mul_lazy_2u16(y1, y1) - ct = mod_mul_lazy_2u16(z1, z1) + ct = mod_mul_lazy_2u16(z1, z1) # ct2 = add_2u16(ct, ct) # ct2 = cond_sub_mod_u16_ext(ct2) # @@ -786,8 +929,14 @@ def padd_rns_twisted_pack( group. Args: - x1_y1_zz1_zzz1: The first point. - x2_y2_zz2_zzz2: The second point. + x1: The first generator element. + y1: The second generator element. + z1: The third generator element. + t1: The third generator element. + x2: The first generator element. + y2: The second generator element. + z2: The third generator element. + t2: The third generator element. rns_mat: The RNS matrix. twist_d: curve parameter. 
@@ -815,10 +964,10 @@ def padd_rns_twisted_pack( # Issue happens here e = add_sub_rns_var( e3, - negate_rns_for_var_add_zero_check(a), - negate_rns_for_var_add_zero_check(b), + negate_rns_for_var_add(a), + negate_rns_for_var_add(b), ) - f = add_sub_rns_var(d, negate_rns_for_var_add_zero_check(c)) + f = add_sub_rns_var(d, negate_rns_for_var_add(c)) g = add_rns_2u16(d, c) h = add_rns_2u16(a, b) @@ -890,7 +1039,47 @@ def pdul_rns_twisted_pack(x1_y1_zz1_zzz1: jax.Array, rns_mat=util.RNS_MAT): @jax.named_call -def rns_twist_zero(): - return jnp.array( - [rns_constant(0), rns_constant(1), rns_constant(1), rns_constant(0)] - ) +@functools.partial(jax.jit) +def pneg_rns_twisted( + x1: jax.Array, + y1: jax.Array, + z1: jax.Array, + t1: jax.Array, +): + """PNEG-RNS elliptic curve operation with packed arguments. + + As for the algorithm, pls refer to + jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassTwisted::double_general + + This function implements the PNEG-RNS elliptic curve operation with packed + arguments, which is used to compute the elliptic curve points of a given + group. + + Args: + x1: The first generator element. + y1: The second generator element. + z1: The third generator element. + t1: The fourth generator element. + + Returns: + A tuple containing the third generator element and the elliptic curve points + of the group. 
+ """ + + x2 = negate_rns(x1) + y2 = y1 + z2 = z1 + t2 = negate_rns(t1) + + return jnp.array([x2, y2, z2, t2]) + + +@jax.named_call +@jax.jit +def pneg_rns_twisted_pack(x1_y1_zz1_zzz1: jax.Array): + return pneg_rns_twisted( + x1_y1_zz1_zzz1[0], + x1_y1_zz1_zzz1[1], + x1_y1_zz1_zzz1[2], + x1_y1_zz1_zzz1[3], + ) \ No newline at end of file diff --git a/jaxite_ec/elliptic_curve_test.py b/jaxite_ec/elliptic_curve_test.py index 4b09116..7c9d68e 100644 --- a/jaxite_ec/elliptic_curve_test.py +++ b/jaxite_ec/elliptic_curve_test.py @@ -81,7 +81,7 @@ def test_pdul_barrett_xyzz(self): profile_name = "jit_pdul_barrett_xyzz_pack" # copybara: util.profile_jax_functions(tasks, profile_name) - def test_jit_pdul_barrett_xyzz_pack_two_no_batch(self): + def test_pdul_barrett_xyzz_pack_two_no_batch(self): point_a_jax = util.int_point_batch_to_jax_point_pack( [self.point_a + [1] * (self.coordinate_num - len(self.point_a))] ) @@ -112,7 +112,7 @@ def test_jit_pdul_barrett_xyzz_pack_two_no_batch(self): profile_name = "jit_pdul_barrett_xyzz_pack" # copybara: util.profile_jax_functions(tasks, profile_name) - def test_jit_pdul_barrett_xyzz_pack_two_batch(self): + def test_pdul_barrett_xyzz_pack_two_batch(self): point_a_jax = util.int_point_batch_to_jax_point_pack( [self.point_a + [1] * (self.coordinate_num - len(self.point_a))] ) @@ -143,11 +143,11 @@ def test_jit_pdul_barrett_xyzz_pack_two_batch(self): def test_padd_lazy_xyzz_pack(self): point_a_jax = util.int_point_batch_to_jax_point_pack( [self.point_a + [1] * (self.coordinate_num - len(self.point_a))], - chunk_num=util.U16_EXT_CHUNK_NUM, + array_size=util.U16_EXT_CHUNK_NUM, ) point_b_jax = util.int_point_batch_to_jax_point_pack( [self.point_b + [1] * (self.coordinate_num - len(self.point_b))], - chunk_num=util.U16_EXT_CHUNK_NUM, + array_size=util.U16_EXT_CHUNK_NUM, ) # lazy_mat = util.construct_lazy_matrix(util.MODULUS_377_INT) jit_padd_lazy_xyzz_pack = jax.jit(jec.padd_lazy_xyzz_pack) @@ -181,7 +181,7 @@ def 
test_padd_lazy_xyzz_pack(self): def test_pdul_lazy_xyzz_pack(self): point_a_jax = util.int_point_batch_to_jax_point_pack( [self.point_a + [1] * (self.coordinate_num - len(self.point_a))], - chunk_num=util.U16_EXT_CHUNK_NUM, + array_size=util.U16_EXT_CHUNK_NUM, ) # lazy_mat = util.construct_lazy_matrix(util.MODULUS_377_INT) @@ -213,6 +213,69 @@ def test_pdul_lazy_xyzz_pack(self): profile_name = "jit_pdul_lazy_xyzz_pack" # copybara: util.profile_jax_functions(tasks, profile_name) + def test_padd_barrett_twisted_pack(self): + twisted_ec_sys = ec.ECCSTwistedEdwardsExtended( + config_file.config_BLS12_377_t + ) + twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a) + twist_b = twisted_ec_sys.twist_int_coordinates(self.point_b) + + point_a_jax = util.int_point_batch_to_jax_point_pack( + [twist_a], array_size=util.U16_CHUNK_NUM + ) + point_b_jax = util.int_point_batch_to_jax_point_pack( + [twist_b], array_size=util.U16_CHUNK_NUM + ) + + jit_padd_barrett_twisted_pack = jax.jit(jec.padd_barrett_twisted_pack) + result_jax = jit_padd_barrett_twisted_pack(point_a_jax, point_b_jax) + result_int = util.jax_point_pack_to_int_point_batch(result_jax) + + result_affine_point = twisted_ec_sys.generate_point( + result_int[0], twist=False + ).convert_to_affine() + + self.assertEqual( + result_affine_point[0].get_value(), + self.true_result_padd_affine[0].get_value(), + ) + self.assertEqual( + result_affine_point[1].get_value(), + self.true_result_padd_affine[1].get_value(), + ) + + def test_pdul_barrett_twisted_pack(self): + twisted_ec_sys = ec.ECCSTwistedEdwardsExtended( + config_file.config_BLS12_377_t + ) + twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a) + point_a_jax = util.int_point_batch_to_jax_point_pack( + [twist_a], array_size=util.U16_CHUNK_NUM + ) + + jit_pdul_barrett_twisted_pack = jax.jit(jec.pdul_barrett_twisted_pack) + result_jax = jit_pdul_barrett_twisted_pack(point_a_jax) + result_int = util.jax_point_pack_to_int_point_batch(result_jax) + + 
result_affine_point = twisted_ec_sys.generate_point( + result_int[0], twist=False + ).convert_to_affine() + self.assertEqual( + result_affine_point[0].get_value(), + self.true_result_pdub_a_affine[0].get_value(), + ) + self.assertEqual( + result_affine_point[1].get_value(), + self.true_result_pdub_a_affine[1].get_value(), + ) + + # performance measurement + tasks = [ + (jit_pdul_barrett_twisted_pack, (point_a_jax,)), + ] + profile_name = "jit_pdul_barrett_twisted_pack" + # copybara: util.profile_jax_functions(tasks, profile_name) + def test_padd_lazy_twisted_pack(self): twisted_ec_sys = ec.ECCSTwistedEdwardsExtended( config_file.config_BLS12_377_t @@ -221,10 +284,10 @@ def test_padd_lazy_twisted_pack(self): twist_b = twisted_ec_sys.twist_int_coordinates(self.point_b) point_a_jax = util.int_point_batch_to_jax_point_pack( - [twist_a], chunk_num=util.U16_EXT_CHUNK_NUM + [twist_a], array_size=util.U16_EXT_CHUNK_NUM ) point_b_jax = util.int_point_batch_to_jax_point_pack( - [twist_b], chunk_num=util.U16_EXT_CHUNK_NUM + [twist_b], array_size=util.U16_EXT_CHUNK_NUM ) jit_padd_lazy_twisted_pack = jax.jit(jec.padd_lazy_twisted_pack) @@ -244,6 +307,7 @@ def test_padd_lazy_twisted_pack(self): self.true_result_padd_affine[1].get_value(), ) + @absltest.skip("skip current test") def test_padd_lazy_twisted_pack_batch(self): for batch_size in [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]: twisted_ec_sys = ec.ECCSTwistedEdwardsExtended( @@ -253,10 +317,10 @@ def test_padd_lazy_twisted_pack_batch(self): twist_b = twisted_ec_sys.twist_int_coordinates(self.point_b) point_a_jax = util.int_point_batch_to_jax_point_pack( - [twist_a], chunk_num=util.U16_EXT_CHUNK_NUM + [twist_a], array_size=util.U16_EXT_CHUNK_NUM ) point_b_jax = util.int_point_batch_to_jax_point_pack( - [twist_b], chunk_num=util.U16_EXT_CHUNK_NUM + [twist_b], array_size=util.U16_EXT_CHUNK_NUM ) point_a_jax = jnp.broadcast_to( point_a_jax, (point_a_jax.shape[0], batch_size, point_a_jax.shape[-1]) @@ -309,10 +373,10 @@ 
def test_padd_same_lazy_twisted_pack(self): twist_a2 = twisted_ec_sys.twist_int_coordinates(self.point_a) point_a1_jax = util.int_point_batch_to_jax_point_pack( - [twist_a1], chunk_num=util.U16_EXT_CHUNK_NUM + [twist_a1], array_size=util.U16_EXT_CHUNK_NUM ) point_a2_jax = util.int_point_batch_to_jax_point_pack( - [twist_a2], chunk_num=util.U16_EXT_CHUNK_NUM + [twist_a2], array_size=util.U16_EXT_CHUNK_NUM ) jit_padd_lazy_twisted_pack = jax.jit(jec.padd_lazy_twisted_pack) @@ -345,7 +409,7 @@ def test_pdul_lazy_twisted_pack(self): ) twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a) point_a_jax = util.int_point_batch_to_jax_point_pack( - [twist_a], chunk_num=util.U16_EXT_CHUNK_NUM + [twist_a], array_size=util.U16_EXT_CHUNK_NUM ) jit_pdul_lazy_twisted_pack = jax.jit(jec.pdul_lazy_twisted_pack) @@ -371,7 +435,7 @@ def test_pdul_lazy_twisted_pack(self): profile_name = "jit_pdul_lazy_twisted_pack" # copybara: util.profile_jax_functions(tasks, profile_name) - def test_jit_pneg_lazy_twisted_pack(self): + def test_pneg_lazy_twisted_pack(self): twisted_ec_sys = ec.ECCSTwistedEdwardsExtended( config_file.config_BLS12_377_t ) @@ -379,10 +443,10 @@ def test_jit_pneg_lazy_twisted_pack(self): twist_b = twisted_ec_sys.twist_int_coordinates(self.point_b) point_a_jax = util.int_point_batch_to_jax_point_pack( - [twist_a], chunk_num=util.U16_EXT_CHUNK_NUM + [twist_a], array_size=util.U16_EXT_CHUNK_NUM ) point_b_jax = util.int_point_batch_to_jax_point_pack( - [twist_b], chunk_num=util.U16_EXT_CHUNK_NUM + [twist_b], array_size=util.U16_EXT_CHUNK_NUM ) jit_padd_lazy_twisted_pack = jax.jit(jec.padd_lazy_twisted_pack) @@ -451,55 +515,6 @@ def test_padd_rns_xyzz(self): profile_name = "jit_padd_rns_xyzz_pack" # copybara: util.profile_jax_functions(tasks, profile_name) - def test_padd_rns_xyzz_batch(self): - for batch_size in [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]: - point_a_jax = util.int_point_batch_to_jax_rns_point_pack( - [self.point_a + [1] * (self.coordinate_num - 
len(self.point_a))] - ) - point_b_jax = util.int_point_batch_to_jax_rns_point_pack( - [self.point_b + [1] * (self.coordinate_num - len(self.point_b))] - ) - point_a_jax = jnp.broadcast_to( - point_a_jax, (point_a_jax.shape[0], batch_size, point_a_jax.shape[-1]) - ) - point_b_jax = jnp.broadcast_to( - point_b_jax, (point_b_jax.shape[0], batch_size, point_b_jax.shape[-1]) - ) - rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT) - - jit_padd_rns_xyzz_pack_batch = jax.jit( - jax.named_call( - functools.partial(jec.padd_rns_xyzz_pack, rns_mat=rns_mat), - name="jit_padd_rns_xyzz_pack_batch", - ), - ) - result_jax = jit_padd_rns_xyzz_pack_batch(point_a_jax, point_b_jax) - result_jax = util.jax_rns_point_pack_to_int_point_batch(result_jax) - - self.assertEqual( - result_jax[0][0] % util.MODULUS_377_INT, - self.true_result_padd[0].get_value(), - ) - self.assertEqual( - result_jax[0][1] % util.MODULUS_377_INT, - self.true_result_padd[1].get_value(), - ) - self.assertEqual( - result_jax[0][2] % util.MODULUS_377_INT, - self.true_result_padd[2].get_value(), - ) - self.assertEqual( - result_jax[0][3] % util.MODULUS_377_INT, - self.true_result_padd[3].get_value(), - ) - - # performance measurement - tasks = [ - (jit_padd_rns_xyzz_pack_batch, (point_a_jax, point_b_jax)), - ] - profile_name = f"jit_padd_rns_xyzz_pack_batch_{batch_size}" - # copybara: util.profile_jax_functions(tasks, profile_name) - def test_pdul_rns_xyzz_pack(self): point_a_jax = util.int_point_batch_to_jax_rns_point_pack( [self.point_a + [1] * (self.coordinate_num - len(self.point_a))] @@ -811,10 +826,10 @@ def test_padd_zero_twisted_pack_new_twisted(self): ) twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a) point_a_jax = util.int_point_batch_to_jax_point_pack( - [twist_a], chunk_num=util.U16_EXT_CHUNK_NUM + [twist_a], array_size=util.U16_EXT_CHUNK_NUM ) point_zero_jax = util.int_point_batch_to_jax_point_pack( - [self.zero_twisted], chunk_num=util.U16_EXT_CHUNK_NUM + [self.zero_twisted], 
array_size=util.U16_EXT_CHUNK_NUM ) jit_padd_lazy_twisted_pack = jax.jit( @@ -885,10 +900,10 @@ def test_padd_rns_a_point_add_zero_correctness(self): twist_a = twisted_ec_sys.twist_int_coordinates(test_in_point) twist_b = [0, 1, 1, 0] point_a_jax = util.int_point_batch_to_jax_point_pack( - [twist_a], chunk_num=util.U16_EXT_CHUNK_NUM + [twist_a], array_size=util.U16_EXT_CHUNK_NUM ) point_b_jax = util.int_point_batch_to_jax_point_pack( - [twist_b], chunk_num=util.U16_EXT_CHUNK_NUM + [twist_b], array_size=util.U16_EXT_CHUNK_NUM ) point_a_jax_rns = util.int_point_batch_to_jax_rns_point_pack([twist_a]) point_b_jax_rns = util.int_point_batch_to_jax_rns_point_pack([twist_b]) @@ -935,5 +950,191 @@ def test_padd_rns_a_point_add_zero_correctness(self): affine_sum_point_rns[1].get_value() % util.MODULUS_377_INT, ) + def test_padd_barrett_xyzz_pack_batch(self): + point_a_in = util.int_point_batch_to_jax_point_pack( + [self.point_a + [1] * (self.coordinate_num - len(self.point_a))] + ) + point_b_in = util.int_point_batch_to_jax_point_pack( + [self.point_b + [1] * (self.coordinate_num - len(self.point_b))] + ) + for batch_size in [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]: + point_a_jax = jnp.broadcast_to( + point_a_in, (point_a_in.shape[0], batch_size, point_a_in.shape[-1]) + ) + point_b_jax = jnp.broadcast_to( + point_b_in, (point_b_in.shape[0], batch_size, point_b_in.shape[-1]) + ) + jit_padd_barrett_xyzz_pack = jax.jit( + jax.named_call(jec.padd_barrett_xyzz_pack, + name=f"jit_padd_barrett_xyzz_pack_{batch_size}") + ) + + result_jax = jit_padd_barrett_xyzz_pack(point_a_jax, point_b_jax) + result_jax = util.jax_point_pack_to_int_point_batch(result_jax) + + self.assertEqual(result_jax[0][0], self.true_result_padd[0].get_value()) + self.assertEqual(result_jax[0][1], self.true_result_padd[1].get_value()) + self.assertEqual(result_jax[0][2], self.true_result_padd[2].get_value()) + self.assertEqual(result_jax[0][3], self.true_result_padd[3].get_value()) + + # performance 
measurement + tasks = [ + (jit_padd_barrett_xyzz_pack, (point_a_jax, point_b_jax)), + ] + profile_name = f"jit_padd_barrett_xyzz_pack_{batch_size}" + # copybara: util.profile_jax_functions(tasks, profile_name) + + def test_padd_barrett_twisted_pack_batch(self): + twisted_ec_sys = ec.ECCSTwistedEdwardsExtended( + config_file.config_BLS12_377_t + ) + twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a) + twist_b = twisted_ec_sys.twist_int_coordinates(self.point_b) + + point_a_in = util.int_point_batch_to_jax_point_pack( + [twist_a], array_size=util.U16_CHUNK_NUM + ) + point_b_in = util.int_point_batch_to_jax_point_pack( + [twist_b], array_size=util.U16_CHUNK_NUM + ) + for batch_size in [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]: + point_a_jax = jnp.broadcast_to( + point_a_in, (point_a_in.shape[0], batch_size, point_a_in.shape[-1]) + ) + point_b_jax = jnp.broadcast_to( + point_b_in, (point_b_in.shape[0], batch_size, point_b_in.shape[-1]) + ) + jit_padd_barrett_twisted_pack = jax.jit( + jax.named_call( + jec.padd_barrett_twisted_pack, + name=f"jit_padd_barrett_twisted_pack_{batch_size}", + ) + ) + result_jax = jit_padd_barrett_twisted_pack(point_a_jax, point_b_jax) + result_int = util.jax_point_pack_to_int_point_batch(result_jax) + + result_affine_point = twisted_ec_sys.generate_point( + result_int[0], twist=False + ).convert_to_affine() + + self.assertEqual( + result_affine_point[0].get_value(), + self.true_result_padd_affine[0].get_value(), + ) + self.assertEqual( + result_affine_point[1].get_value(), + self.true_result_padd_affine[1].get_value(), + ) + # performance measurement + tasks = [ + (jit_padd_barrett_twisted_pack, (point_a_jax, point_b_jax)), + ] + profile_name = f"jit_padd_barrett_twisted_pack_{batch_size}" + # copybara: util.profile_jax_functions(tasks, profile_name) + + # @absltest.skip("skip current test") + def test_padd_rns_xyzz_batch(self): + for batch_size in [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]: + point_a_jax = 
util.int_point_batch_to_jax_rns_point_pack( + [self.point_a + [1] * (self.coordinate_num - len(self.point_a))] + ) + point_b_jax = util.int_point_batch_to_jax_rns_point_pack( + [self.point_b + [1] * (self.coordinate_num - len(self.point_b))] + ) + point_a_jax = jnp.broadcast_to( + point_a_jax, (point_a_jax.shape[0], batch_size, point_a_jax.shape[-1]) + ) + point_b_jax = jnp.broadcast_to( + point_b_jax, (point_b_jax.shape[0], batch_size, point_b_jax.shape[-1]) + ) + rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT) + + jit_padd_rns_xyzz_pack_batch = jax.jit( + jax.named_call( + functools.partial(jec.padd_rns_xyzz_pack, rns_mat=rns_mat), + name=f"jit_padd_rns_xyzz_pack_batch_{batch_size}", + ), + ) + result_jax = jit_padd_rns_xyzz_pack_batch(point_a_jax, point_b_jax) + result_jax = util.jax_rns_point_pack_to_int_point_batch(result_jax) + + self.assertEqual( + result_jax[0][0] % util.MODULUS_377_INT, + self.true_result_padd[0].get_value(), + ) + self.assertEqual( + result_jax[0][1] % util.MODULUS_377_INT, + self.true_result_padd[1].get_value(), + ) + self.assertEqual( + result_jax[0][2] % util.MODULUS_377_INT, + self.true_result_padd[2].get_value(), + ) + self.assertEqual( + result_jax[0][3] % util.MODULUS_377_INT, + self.true_result_padd[3].get_value(), + ) + + # performance measurement + tasks = [ + (jit_padd_rns_xyzz_pack_batch, (point_a_jax, point_b_jax)), + ] + profile_name = f"jit_padd_rns_xyzz_pack_batch_{batch_size}" + # copybara: util.profile_jax_functions(tasks, profile_name) + + # @absltest.skip("Skip for now") + def test_padd_rns_twisted_pack_new_twist_batch(self): + twisted_ec_sys = ec.ECCSTwistedEdwardsExtended( + config_file.config_BLS12_377_t + ) + rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT) + twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a) + twist_b = twisted_ec_sys.twist_int_coordinates(self.point_b) + twist_a_jax = util.int_point_to_jax_rns_point_pack(twist_a).reshape( + util.COORDINATE_NUM, 1, util.NUM_MODULI + ) 
+ twist_b_jax = util.int_point_to_jax_rns_point_pack(twist_b).reshape( + util.COORDINATE_NUM, 1, util.NUM_MODULI + ) + + for batch_size in [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]: + point_a_jax = jnp.broadcast_to( + twist_a_jax, (twist_a_jax.shape[0], batch_size, twist_a_jax.shape[-1]) + ) + point_b_jax = jnp.broadcast_to( + twist_b_jax, (twist_b_jax.shape[0], batch_size, twist_b_jax.shape[-1]) + ) + + jit_padd_rns_twisted_pack_batch = jax.jit( + jax.named_call( + functools.partial(jec.padd_rns_twisted_pack, rns_mat=rns_mat), + name=f"jit_padd_rns_twisted_pack_batch_{batch_size}", + ), + ) + result_batch = jit_padd_rns_twisted_pack_batch(point_a_jax, point_b_jax) + project_twist_sum = util.jax_rns_point_pack_to_int_point_batch( + result_batch + ) + project_twist_jax = twisted_ec_sys.generate_point( + project_twist_sum[0], twist=False + ).convert_to_affine() + + self.assertEqual( + project_twist_jax[0].get_value() % util.MODULUS_377_INT, + self.true_result_padd_affine[0].get_value(), + ) + self.assertEqual( + project_twist_jax[1].get_value() % util.MODULUS_377_INT, + self.true_result_padd_affine[1].get_value(), + ) + + # performance measurement + tasks = [ + (jit_padd_rns_twisted_pack_batch, (point_a_jax, point_b_jax)), + ] + profile_name = f"jit_padd_rns_twisted_pack_batch_{batch_size}" + # copybara: util.profile_jax_functions(tasks, profile_name) + + if __name__ == "__main__": absltest.main() diff --git a/jaxite_ec/finite_field.py b/jaxite_ec/finite_field.py index 863cd21..5222092 100644 --- a/jaxite_ec/finite_field.py +++ b/jaxite_ec/finite_field.py @@ -32,7 +32,6 @@ import jax import jax.numpy as jnp from jaxite.jaxite_ec import util -import numpy as np total_modulus = util.total_modulus @@ -320,6 +319,11 @@ def cond_sub_mod_u16( compare_u16, chunk_num_u16=chunk_num_u16 ) sub_2u16_local = functools.partial(sub_2u16, chunk_num_u16=chunk_num_u16) + # if value_a.shape[0] > 1: + # # Input is batch (Vector, Constant) + # compare_u16_local = 
jax.vmap(compare_u16_local, in_axes=(0, None)) + # sub_2u16_local = jax.vmap(sub_2u16_local, in_axes=(0, None)) + modulus_377_int_array = jnp.asarray(modulus_377_int_chunk, jnp.uint16) cond = compare_u16_local(value_a, modulus_377_int_array) @@ -456,7 +460,7 @@ def mul_shift_2u16x2x1( mask: The mask to apply to the value. barrett_shift_u8: The number of bits to shift the value. chunk_num_u16: The number of chunks in the u16 value. - chunk_num_u32: The number of chunks in the u32 value. + chunk_num_u32: The number of chunks in the u16 value. chunk_shift_bits: The number of bits to shift the value. vmap_axes: (0, None) means axis 0 is the mapped access, and The rest is not. @@ -654,77 +658,49 @@ def mod_mul_lazy_2u16( return value_c_u16 -def split_view_32_to_16(a: jnp.ndarray): - # Interpret each 32-bit element as two 16-bit numbers - # and reshape to add an extra dimension of size 2. - v = a.view(jnp.uint16).reshape(a.shape + (2,)) - # Assuming little-endian storage, the lower 16 bits are at index 0 - # and the upper 16 bits are at index 1. - lower = v[..., 0] - upper = v[..., 1] - return upper, lower - - -def split_view_32_to_16_8(a: jnp.ndarray): - # First, reshape the 32-bit integers as groups of 4 bytes. - v8 = a.view(jnp.uint8).reshape(a.shape + (4,)) - # Also, reshape as 16-bit integers (2 per 32-bit element) - v16 = a.view(jnp.uint16).reshape(a.shape + (2,)) - # For each 32-bit integer: - # v16[..., 0] gives the lower 16 bits. - # v8[..., 2] gives the third byte (i.e. the lower 8 bits of the upper 16 bits) - lower = v16[..., 0] - upper8 = v8[..., 2] - return upper8, lower - - -# Reduce via RNS modulus -@jax.named_call @functools.partial( jax.jit, - static_argnames="moduli_t", + static_argnames=("moduli", "s", "m"), ) -def moduli_rns_red_internal_2u16(vals, moduli_t=util.RNS_MODULI_T): - """Reduce via RNS modulus. +def barret_reduction_u32( + z, moduli=util.MODULI, s=util.S_BARRETT, m=util.M_BARRETT +): + """Vectorized implementation of the Barrett reduction. 
+ + Works for modulus `q` less than 31 bits. + + This implementation sets the internal shift width `w` to `min(s, 32)` so it + works with small modulus `moduli < 2^16`. Args: - vals: The values to reduce. - moduli_t: The moduli for the target. + z: The input value. + moduli: The RNS moduli. + s: The bit width of moduli. + m: The precomputed value for Barrett reduction. Returns: - The reduced values. + The result of the Barrett reduction. """ - # See jaxite_ec/advanced_algorithm/rns_red.py for description - moduli_t = jnp.array(moduli_t, dtype=jnp.uint8) - u1, l1 = split_view_32_to_16(vals) - i1 = jnp.add( - l1.astype(jnp.uint32), - jnp.multiply(u1.astype(jnp.uint32), moduli_t), - ) - u2, l2 = split_view_32_to_16_8(i1) - i2 = jnp.add( - l2.astype(jnp.uint32), - jnp.multiply(u2.astype(jnp.uint16), moduli_t).astype( - jnp.uint32 - ), - ) - u3, l3 = split_view_32_to_16_8(i2) - out = jnp.add(l3, jnp.multiply(u3, moduli_t).astype(jnp.uint16)) - return out + m = jnp.array(m, dtype=jnp.uint64) + moduli = jnp.array(moduli, dtype=jnp.uint32) + s = jnp.array(s, dtype=jnp.uint16) + t = ((z.astype(jnp.uint64) * m) >> s).astype(jnp.uint32) + z = z - t * moduli + pred = z >= moduli + return jnp.where(pred, z - moduli, z).astype(jnp.uint16) # Reduce via prime modulus @jax.named_call @functools.partial( jax.jit, - static_argnames=("rns_mat", "moduli_t", "num_moduli", "precision"), + static_argnames=("rns_mat", "num_moduli", "precision"), ) def mod_red_rns_2u16( c_rns_reduced, rns_mat=util.RNS_MAT, - moduli_t=util.RNS_MODULI_T, num_moduli=util.NUM_MODULI, - precision=util.RNS_PRECISION, + precision=util.moduli_precision, ): """Reduce via RNS modulus. 
@@ -761,35 +737,32 @@ def mod_red_rns_2u16( k, cor_mat, preferred_element_type=jnp.uint32 ) - return moduli_rns_red_internal_2u16(c_corrected, moduli_t) + return barret_reduction_u32(c_corrected) # Multiply, without reducing @jax.named_call @functools.partial( jax.jit, - static_argnames="moduli_t", ) def mul_unreduced_rns_2u16( value_a, value_b, - moduli_t=util.RNS_MODULI_T, ): ab = jnp.multiply(value_a.astype(jnp.uint32), value_b.astype(jnp.uint32)) - return moduli_rns_red_internal_2u16(ab, moduli_t) + return barret_reduction_u32(ab) # Multiply and reduce @jax.named_call @functools.partial( jax.jit, - static_argnames=("rns_mat", "moduli_t"), + static_argnames="rns_mat", ) def mod_mul_rns_2u16( value_a, value_b, rns_mat=util.RNS_MAT, - moduli_t=util.RNS_MODULI_T, ): """Multiply two u16 values with RNS reduction. @@ -802,19 +775,82 @@ def mod_mul_rns_2u16( Returns: The product of the two u16 values. """ - ab = mul_unreduced_rns_2u16(value_a, value_b, moduli_t) - return mod_red_rns_2u16(ab, rns_mat, moduli_t) + ab = mul_unreduced_rns_2u16(value_a, value_b) + return mod_red_rns_2u16(ab, rns_mat) + + +@jax.named_call +@functools.partial( + jax.jit, + static_argnames=("rns_mat", "num_moduli", "precision", "moduli", "s", "m"), +) +def mod_mul_rns_unified_2u16( + value_a, + value_b, + rns_mat=util.RNS_MAT, + num_moduli=util.NUM_MODULI, + precision=util.moduli_precision, + moduli=util.MODULI, + s=util.S_BARRETT, + m=util.M_BARRETT +): + """Multiply two u16 values with RNS reduction. + + Args: + value_a: The first u16 value. + value_b: The second u16 value. + rns_mat: The RNS precompute. + moduli_t: The moduli for the target. + + Returns: + The product of the two u16 values. 
+ """ + m = jnp.array(m, dtype=jnp.uint64) + moduli = jnp.array(moduli, dtype=jnp.uint32) + s = jnp.array(s, dtype=jnp.uint16) + + rns_stacked_mat = jnp.array(rns_mat[0], jnp.uint8) + cor_mat = jnp.array(rns_mat[1], jnp.uint16) + + ab = jnp.multiply(value_a.astype(jnp.uint32), value_b.astype(jnp.uint32)) + t = ((ab.astype(jnp.uint64) * m) >> s).astype(jnp.uint32) + ab = (ab - t * moduli) + pred = ab >= moduli + c_rns_reduced = jnp.where(pred, ab - moduli, ab).astype(jnp.uint16) + + c_target = jnp.matmul( + c_rns_reduced.view(jnp.uint8), + rns_stacked_mat, + preferred_element_type=jnp.uint32, + ) + + mul_res_glb_red_u32 = c_target.reshape(*c_target.shape[:-1], -1, 2) + mul_res_glb_red_u32 = mul_res_glb_red_u32[..., 0] + ( + mul_res_glb_red_u32[..., 1] << 8 + ) + rns_reduce_u32, qe_u32 = jnp.split( + mul_res_glb_red_u32, [num_moduli], axis=1 + ) + + # obtain the high 32 bits from the quotient estimation results qe_u32 + k = (qe_u32 >> precision).astype(jnp.uint16) + c_corrected = rns_reduce_u32 + jnp.matmul( + k, cor_mat, preferred_element_type=jnp.uint32 + ) + + t = ((c_corrected.astype(jnp.uint64) * m) >> s).astype(jnp.uint32) + c_corrected = (c_corrected - t * moduli) + pred2 = c_corrected >= moduli + return jnp.where(pred2, c_corrected - moduli, c_corrected).astype(jnp.uint16) @jax.named_call @functools.partial( jax.jit, - static_argnames="moduli_t", ) def add_rns_2u16( value_a: jax.Array, value_b: jax.Array, - moduli_t=util.RNS_MODULI_T, ): """Add two u16 values with RNS reduction. @@ -826,19 +862,17 @@ def add_rns_2u16( Returns: The sum of the two u16 values. """ - return add_sub_rns_var(value_a, value_b, moduli_t=moduli_t) + return add_sub_rns_var(value_a, value_b) @jax.named_call @functools.partial( jax.jit, - static_argnames="moduli_t", ) def add_rns_3u16( value_a: jax.Array, value_b: jax.Array, value_c: jax.Array, - moduli_t=util.RNS_MODULI_T, ): """Add three u16 values with RNS reduction. 
@@ -846,12 +880,11 @@ def add_rns_3u16( value_a: The first u16 value. value_b: The second u16 value. value_c: The third u16 value. - moduli_t: The moduli for the target. Returns: The sum of the three u16 values. """ - return add_sub_rns_var(value_a, value_b, value_c, moduli_t=moduli_t) + return add_sub_rns_var(value_a, value_b, value_c) @jax.named_call @@ -862,6 +895,7 @@ def add_rns_3u16( def negate_rns_for_var_add( value_a: jax.Array, moduli_sub=util.MODULI_SUB, + moduli=util.MODULI, ): """Negate a value for use in subtraction. @@ -872,17 +906,19 @@ def negate_rns_for_var_add( value_a: RNS array to negate moduli_sub: Precomputed constants for performing negation, that depend on the target modulus + moduli: RNS moduli Returns: An intermediate representing the negation of values_a in the target modulus in RNS form. - Note: original data precision is 16 bit, using uint32 to avoid overflow + Note: original data precision is 32 bit, using uint64 to avoid overflow """ - moduli_sub = jnp.array(moduli_sub, dtype=jnp.uint32) + moduli_sub = jnp.array(moduli_sub, dtype=jnp.uint16) + moduli = jnp.array(moduli, dtype=jnp.uint32) return jnp.add( - jnp.negative(value_a.astype(jnp.uint16)).astype(jnp.uint32), + jnp.subtract(moduli, value_a.astype(jnp.uint32)), moduli_sub, ) @@ -892,9 +928,10 @@ def negate_rns_for_var_add( jax.jit, static_argnames="moduli_sub", ) -def negate_rns_for_var_add_zero_check( +def negate_rns( value_a: jax.Array, moduli_sub=util.MODULI_SUB, + moduli=util.MODULI, ): """Negate a value for use in subtraction. @@ -905,35 +942,29 @@ def negate_rns_for_var_add_zero_check( value_a: RNS array to negate moduli_sub: Precomputed constants for performing negation, that depend on the target modulus + moduli: RNS moduli Returns: An intermediate representing the negation of values_a in the target modulus in RNS form. 
- Note: original data precision is 16 bit, using uint32 to avoid overflow + Note: original data precision is 32 bit, using uint64 to avoid overflow """ + moduli_sub = jnp.array(moduli_sub, dtype=jnp.uint16) + moduli = jnp.array(moduli, dtype=jnp.uint32) - moduli_sub = jnp.array(moduli_sub, dtype=jnp.uint32) - a = value_a.astype(jnp.uint16) - - # Compute two's complement negation: for nonzero a, jnp.negative(a) computes - # (2^16 - a). - neg = jnp.negative(a).astype(jnp.uint32) - - # Build a branchless mask: 0 if a==0, 1 otherwise. - mask = (a != 0).astype(jnp.uint32) - - # For nonzero a: (2^16 - a) + moduli_sub; for zero: m + 0 multiplied by 0 - # gives 0. - return (neg + moduli_sub) * mask + negate_a = jnp.add( + jnp.subtract(moduli, value_a.astype(jnp.uint32)), + moduli_sub, + ) + return barret_reduction_u32(negate_a) @jax.named_call @functools.partial( jax.jit, - static_argnames="moduli_t", ) -def add_sub_rns_var(*values, moduli_t=util.RNS_MODULI_T): +def add_sub_rns_var(*values): """Evaluate an static set of additions and subtractions. Subtractions are implemented by calling negate_rns_for_var_add on inputs to @@ -943,7 +974,6 @@ def add_sub_rns_var(*values, moduli_t=util.RNS_MODULI_T): Args: *values: A list of RNS values to accumulate - moduli_t: The moduli for the RNS form. Returns: The RNS form of the evaluation of the expession. 
@@ -954,23 +984,6 @@ def add_sub_rns_var(*values, moduli_t=util.RNS_MODULI_T): if acc != None: acc = jnp.add(v.astype(jnp.uint32), acc) else: - acc = v.astype(jnp.uint32) + acc = v assert len(values) < 256 - moduli_t = jnp.array(moduli_t, dtype=jnp.uint8) - # u1 < 254 - u1, l1 = split_view_32_to_16_8(acc) - # i1 < 2**16 - 1 + 255t < 2**17 - t for 8 bit t - i1 = jnp.add( - jnp.multiply(u1.astype(np.uint16), moduli_t).astype(jnp.uint32), - l1.astype(jnp.uint32), - ) - # u2 = 0 or 1, but if u2 = 1 then l < 2**16 - t, so 2**16 - t + t < 2**16 - u2, l2 = split_view_32_to_16_8(i1) - return jnp.add(jnp.multiply(u2, moduli_t).astype(jnp.uint16), l2) - - -@functools.partial(jax.jit, static_argnames=("c", "num_moduli")) -def rns_constant(c, num_moduli=util.NUM_MODULI): - assert c >= 0 - assert c < 2**14 # small constants only please - return jnp.repeat(jnp.array([c], dtype=jnp.uint16), num_moduli) + return barret_reduction_u32(acc) diff --git a/jaxite_ec/finite_field_test.py b/jaxite_ec/finite_field_test.py index 208eda0..cb1d8f6 100644 --- a/jaxite_ec/finite_field_test.py +++ b/jaxite_ec/finite_field_test.py @@ -201,6 +201,23 @@ def test_jax_mod_mul_rns_reduction(self): (a_list[i] * b_list[i]) % util.MODULUS_377_INT, ) + def test_jax_mod_mul_rns_unified(self): + """This test case check the jax version (TPU deployment) of the rns reduction based modular multiplication algorithm.""" + batch_size = 16 + a_list = [randint(0, util.MODULUS_377_INT) for _ in range(batch_size)] + b_list = [randint(0, util.MODULUS_377_INT) for _ in range(batch_size)] + + modulus_rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT) + a_batch = util.int_list_to_array_rns(a_list) + b_batch = util.int_list_to_array_rns(b_list) + c_batch = ff.mod_mul_rns_unified_2u16(a_batch, b_batch, modulus_rns_mat) + c_list = util.array_rns_to_int_list(c_batch) + for i in range(len(a_list)): + np.testing.assert_equal( + c_list[i] % util.MODULUS_377_INT, + (a_list[i] * b_list[i]) % util.MODULUS_377_INT, + ) + def 
test_jax_add_rns(self): max_val = [2**16 - 1 for _ in range(util.NUM_MODULI)] max_normal_val = [m - 1 for m in util.MODULI] @@ -210,10 +227,8 @@ def test_jax_add_rns(self): for b in values: jax_a = jnp.array(a, dtype=jnp.uint16).reshape((1, util.NUM_MODULI)) jax_b = jnp.array(b, dtype=jnp.uint16).reshape((1, util.NUM_MODULI)) - jax_sum = ff.add_rns_2u16(jax_a, jax_b, tuple(util.RNS_MODULI_T)) - jax_3sum = ff.add_rns_3u16( - jax_a, jax_b, jax_a, tuple(util.RNS_MODULI_T) - ) + jax_sum = ff.add_rns_2u16(jax_a, jax_b) + jax_3sum = ff.add_rns_3u16(jax_a, jax_b, jax_a) for i in range(util.NUM_MODULI): np.testing.assert_equal( int(jax_sum[0, i]) % util.MODULI[i], diff --git a/jaxite_ec/msm_test.py b/jaxite_ec/msm_test.py index cd78971..12a4b7d 100644 --- a/jaxite_ec/msm_test.py +++ b/jaxite_ec/msm_test.py @@ -17,6 +17,7 @@ script_path = os.path.abspath(sys.argv[0]) script_dir = os.path.dirname(script_path) +config_BLS12_377 = config_file.config_BLS12_377 jax.config.update("jax_traceback_filtering", "off") @@ -26,18 +27,18 @@ ) TEST_PARAMS = [ - ( - "test_4_degree", - os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), - f"{script_dir}/jaxite_ec/test_case/t4/zprize_msm_curve_377_scalars_dim_4_seed_0.csv" - ), - os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), - f"{script_dir}/jaxite_ec/test_case/t4/zprize_msm_curve_377_bases_dim_4_seed_0.csv" - ), - os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), - f"{script_dir}/jaxite_ec/test_case/t4/zprize_msm_curve_377_res_dim_4_seed_0.csv" - ), - ), + # ( + # "test_4_degree", + # os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), + # f"{script_dir}/jaxite_ec/test_case/t4/zprize_msm_curve_377_scalars_dim_4_seed_0.csv" + # ), + # os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), + # f"{script_dir}/jaxite_ec/test_case/t4/zprize_msm_curve_377_bases_dim_4_seed_0.csv" + # ), + # os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), + # 
f"{script_dir}/jaxite_ec/test_case/t4/zprize_msm_curve_377_res_dim_4_seed_0.csv" + # ), + # ), ( "test_1024_degree", os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), @@ -66,6 +67,7 @@ def twist_coordinates_list(ec_config, coordinates_list): class MSMTest(parameterized.TestCase): + def read_external_file(self, scalar_path, base_path, result_path): scalars = [] with open( @@ -93,9 +95,9 @@ def read_external_file(self, scalar_path, base_path, result_path): result_ref.append(int(row[-1][13:-2], 16)) return scalars, points, result_ref + # @absltest.skip("skip this test for now") @parameterized.named_parameters(*TEST_PARAMS) def test_pippenger_index_selection(self, scalar_path, base_path, result_path): - """Normal version Pippenger.""" scalars, points, result_ref = self.read_external_file( scalar_path, base_path, result_path ) @@ -179,11 +181,13 @@ def test_pippenger_index_selection(self, scalar_path, base_path, result_path): .compile() ) + # HERE msm_algo.bucket_accumulation(bucket_accumulation_index_scan_jit) msm_algo.bucket_reduction(bucket_reduction_scan_jit) result = msm_algo.window_merge(window_merge_scan_jit) result = util.jax_point_pack_to_int_point(result) - ec_sys = ec.ECCSWeierstrassXYZZ(config_file.config_BLS12_377) + # TO HERE + ec_sys = ec.ECCSWeierstrassXYZZ(config_BLS12_377) result_affine_point = ec_sys.generate_point(result).convert_to_affine() coordinates = ( result_affine_point[0].get_value(), @@ -191,16 +195,15 @@ def test_pippenger_index_selection(self, scalar_path, base_path, result_path): ) self.assertEqual(coordinates[0], result_ref[0]) self.assertEqual(coordinates[1], result_ref[1]) - - # performance measurement tasks = [ (msm_algo.bucket_accumulation, (bucket_accumulation_index_scan_jit,)), (msm_algo.bucket_reduction, (bucket_reduction_scan_jit,)), (msm_algo.window_merge, (window_merge_scan_jit,)), ] - profile_name = "normal_pippenger_index_selection" + profile_name = "test_pippenger_index_selection" # copybara: 
util.profile_jax_functions(tasks, profile_name) + # @absltest.skip("skip this test for now") @parameterized.named_parameters(*TEST_PARAMS) def test_pippenger_index_selection_twisted_edwards( self, scalar_path, base_path, result_path @@ -211,7 +214,7 @@ def test_pippenger_index_selection_twisted_edwards( twisted_points, untwisted_coordinates_indeices = twist_coordinates_list( config_file.config_BLS12_377_t, points ) - assert not untwisted_coordinates_indeices + assert len(untwisted_coordinates_indeices) == 0 slice_length = 4 parallel_num = 4 msm_algo = pippenger.MSMPippengerTwisted(slice_length, parallel_num) @@ -328,16 +331,16 @@ def test_pippenger_index_selection_twisted_edwards( self.assertEqual(coordinates[0], result_ref[0]) self.assertEqual(coordinates[1], result_ref[1]) - # performance measurement tasks = [ (msm_algo.bucket_accumulation, (bucket_accumulation_index_scan_jit,)), (msm_algo.bucket_reduction, (bucket_reduction_scan_jit,)), (msm_algo.batch_window_summation, (batch_window_summation_jit,)), (msm_algo.window_merge, (window_merge_scan_jit,)), ] - profile_name = "pippenger_index_selection_twisted_edwards" + profile_name = "test_pippenger_index_selection_twisted_edwards" # copybara: util.profile_jax_functions(tasks, profile_name) + # @absltest.skip("skip test for now.") @parameterized.named_parameters(*TEST_PARAMS) def test_pippenger_signed_index_selection_twisted_edwards( self, scalar_path, base_path, result_path @@ -348,7 +351,7 @@ def test_pippenger_signed_index_selection_twisted_edwards( twisted_points, untwisted_coordinates_indeices = twist_coordinates_list( config_file.config_BLS12_377_t, points ) - assert not untwisted_coordinates_indeices + assert len(untwisted_coordinates_indeices) == 0 slice_length = 4 parallel_num = 4 msm_algo = pippenger.MSMPippengerTwistedSigned(slice_length, parallel_num) @@ -467,23 +470,20 @@ def test_pippenger_signed_index_selection_twisted_edwards( ) self.assertEqual(coordinates[0], result_ref[0]) 
self.assertEqual(coordinates[1], result_ref[1]) - - # performance measurement tasks = [ (msm_algo.bucket_accumulation, (bucket_accumulation_index_scan_jit,)), (msm_algo.bucket_reduction, (bucket_reduction_scan_jit,)), (msm_algo.batch_window_summation, (batch_window_summation_jit,)), (msm_algo.window_merge, (window_merge_scan_jit,)), ] - profile_name = "pippenger_signed_index_selection_twisted_edwards" + profile_name = "test_pippenger_signed_index_selection_twisted_edwards" # copybara: util.profile_jax_functions(tasks, profile_name) # @absltest.skip("test pass") @parameterized.named_parameters(*TEST_PARAMS) - def test_pippenger_index_rns_selection( + def test_rns_pippenger_index_selection( self, scalar_path, base_path, result_path ): - """RNS version Pippenger - XYZZ.""" scalars, points, result_ref = self.read_external_file( scalar_path, base_path, result_path ) @@ -569,7 +569,7 @@ def test_pippenger_index_rns_selection( result = msm_algo.window_merge(window_merge_scan_jit) result = util.jax_rns_point_pack_to_int_point(result) # TO HERE - ec_sys = ec.ECCSWeierstrassXYZZ(config_file.config_BLS12_377) + ec_sys = ec.ECCSWeierstrassXYZZ(config_BLS12_377) result_affine_point = ec_sys.generate_point(result).convert_to_affine() coordinates = ( result_affine_point[0].get_value(), @@ -577,28 +577,28 @@ def test_pippenger_index_rns_selection( ) self.assertEqual(coordinates[0], result_ref[0]) self.assertEqual(coordinates[1], result_ref[1]) - - # performance measurement tasks = [ (msm_algo.bucket_accumulation, (bucket_accumulation_index_scan_jit,)), (msm_algo.bucket_reduction, (bucket_reduction_scan_jit,)), (msm_algo.window_merge, (window_merge_scan_jit,)), ] - profile_name = "pippenger_index_rns_selection" + profile_name = "test_rns_pippenger_index_selection" # copybara: util.profile_jax_functions(tasks, profile_name) - # @absltest.skip("has some bug in result") + # @absltest.skip("test pass") @parameterized.named_parameters(*TEST_PARAMS) - def 
test_pippenger_index_selection_rns_twisted_edwards( + def test_rns_pippenger_index_selection_twisted_edwards( self, scalar_path, base_path, result_path ): scalars, points, result_ref = self.read_external_file( scalar_path, base_path, result_path ) + print(points[0]) + print(points[1]) twisted_points, untwisted_coordinates_indeices = twist_coordinates_list( config_file.config_BLS12_377_t, points ) - assert not untwisted_coordinates_indeices + assert len(untwisted_coordinates_indeices) == 0 slice_length = 4 parallel_num = 4 msm_algo = pippenger_rns.MSMPippengerTwisted(slice_length, parallel_num) @@ -612,7 +612,6 @@ def test_pippenger_index_selection_rns_twisted_edwards( batch_window_num = window_num * parallel_num batch_mem_length = msm_length // parallel_num - bucket_accumulation_index_scan_jit = ( jax.jit( pippenger_rns.bucket_accumulation_index_scan_parallel_algorithm_twisted, @@ -704,6 +703,7 @@ def test_pippenger_index_selection_rns_twisted_edwards( result = msm_algo.window_merge(window_merge_scan_jit) result = util.jax_rns_point_pack_to_int_point(result) # TO HERE + ec_sys = ec.ECCSTwistedEdwardsExtended(config_file.config_BLS12_377_t) result_affine_point = ec_sys.generate_point( result, twist=False @@ -711,18 +711,159 @@ def test_pippenger_index_selection_rns_twisted_edwards( coordinates = ( result_affine_point[0].get_value() % util.MODULUS_377_INT, result_affine_point[1].get_value() % util.MODULUS_377_INT, + result_affine_point[2].get_value() % util.MODULUS_377_INT, + result_affine_point[3].get_value() % util.MODULUS_377_INT, ) self.assertEqual(coordinates[0], result_ref[0]) self.assertEqual(coordinates[1], result_ref[1]) + tasks = [ + (msm_algo.bucket_accumulation, (bucket_accumulation_index_scan_jit,)), + (msm_algo.bucket_reduction, (bucket_reduction_scan_jit,)), + (msm_algo.batch_window_summation, (batch_window_summation_jit,)), + (msm_algo.window_merge, (window_merge_scan_jit,)), + ] + profile_name = "test_rns_pippenger_index_selection_twisted_edwards" + 
# copybara: util.profile_jax_functions(tasks, profile_name) + + # @absltest.skip("test pass") + @parameterized.named_parameters(*TEST_PARAMS) + def test_rns_pippenger_signed_index_selection_twisted_edwards( + self, scalar_path, base_path, result_path + ): + scalars, points, result_ref = self.read_external_file( + scalar_path, base_path, result_path + ) + twisted_points, untwisted_coordinates_indeices = twist_coordinates_list( + config_file.config_BLS12_377_t, points + ) + assert len(untwisted_coordinates_indeices) == 0 + slice_length = 4 + parallel_num = 4 + msm_algo = pippenger_rns.MSMPippengerTwistedSigned( + slice_length, parallel_num + ) + msm_algo.initialize(scalars, twisted_points) + + window_num = msm_algo.window_num + bucket_num_per_window = msm_algo.bucket_num_per_window + msm_length = msm_algo.msm_length + coordinate_num = msm_algo.coordinate_num + chunk_num = util.NUM_MODULI + + batch_window_num = window_num * parallel_num + batch_mem_length = msm_length // parallel_num + + bucket_accumulation_index_scan_jit = ( + jax.jit( + pippenger_rns.bucket_accumulation_signed_index_scan_parallel_algorithm_twisted, + static_argnames="msm_length", + ) + .lower( + jax.ShapeDtypeStruct( + ( + coordinate_num, + batch_window_num, + bucket_num_per_window, + chunk_num, + ), + dtype=jnp.uint16, + ), + jax.ShapeDtypeStruct( + (batch_mem_length, coordinate_num, parallel_num, chunk_num), + dtype=jnp.uint16, + ), + jax.ShapeDtypeStruct( + (batch_mem_length, batch_window_num), dtype=jnp.uint16 + ), + jax.ShapeDtypeStruct( + (batch_mem_length, batch_window_num), dtype=jnp.uint8 + ), + batch_mem_length, + ) + .compile() + ) + + bucket_reduction_scan_jit = ( + jax.jit( + pippenger_rns.bucket_reduction_scan_algorithm_twisted, + static_argnames="bucket_num_in_window", + ) + .lower( + jax.ShapeDtypeStruct( + ( + coordinate_num, + batch_window_num, + bucket_num_per_window, + chunk_num, + ), + dtype=jnp.uint16, + ), + jax.ShapeDtypeStruct( + (coordinate_num, batch_window_num, 
chunk_num), dtype=jnp.uint16 + ), + jax.ShapeDtypeStruct( + (coordinate_num, batch_window_num, chunk_num), dtype=jnp.uint16 + ), + bucket_num_per_window, + ) + .compile() + ) - # performance measurement + batch_window_summation_jit = ( + jax.jit( + pippenger_rns.batch_window_summation_algorithm_twisted, + static_argnames="point_parallel", + ) + .lower( + jax.ShapeDtypeStruct( + (coordinate_num, window_num, chunk_num), dtype=jnp.uint16 + ), + jax.ShapeDtypeStruct( + (coordinate_num, batch_window_num, chunk_num), dtype=jnp.uint16 + ), + parallel_num, + ) + .compile() + ) + + window_merge_scan_jit = ( + jax.jit( + pippenger_rns.window_merge_scan_algorithm_twisted, + static_argnames="slice_length", + ) + .lower( + jax.ShapeDtypeStruct( + (coordinate_num, window_num, chunk_num), dtype=jnp.uint16 + ), + slice_length, + ) + .compile() + ) + + # HERE + msm_algo.bucket_accumulation(bucket_accumulation_index_scan_jit) + msm_algo.bucket_reduction(bucket_reduction_scan_jit) + msm_algo.batch_window_summation(batch_window_summation_jit) + result = msm_algo.window_merge(window_merge_scan_jit) + result = util.jax_rns_point_pack_to_int_point(result) + # TO HERE + ec_sys = ec.ECCSTwistedEdwardsExtended(config_file.config_BLS12_377_t) + result_affine_point = ec_sys.generate_point( + result, twist=False + ).convert_to_affine() + coordinates = ( + result_affine_point[0].get_value(), + result_affine_point[1].get_value(), + ) + self.assertEqual(coordinates[0], result_ref[0]) + self.assertEqual(coordinates[1], result_ref[1]) tasks = [ (msm_algo.bucket_accumulation, (bucket_accumulation_index_scan_jit,)), (msm_algo.bucket_reduction, (bucket_reduction_scan_jit,)), (msm_algo.batch_window_summation, (batch_window_summation_jit,)), (msm_algo.window_merge, (window_merge_scan_jit,)), ] - profile_name = "pippenger_index_selection_rns_twisted_edwards" + profile_name = "test_rns_pippenger_signed_index_selection_twisted_edwards" # copybara: util.profile_jax_functions(tasks, profile_name) diff --git 
a/jaxite_ec/pippenger.py b/jaxite_ec/pippenger.py index 63d5372..13293b5 100644 --- a/jaxite_ec/pippenger.py +++ b/jaxite_ec/pippenger.py @@ -723,17 +723,6 @@ def padd(partial_sum, single_point): return jec.padd_lazy_twisted_pack(partial_sum, single_point) -def padd_with_pdul_check(partial_sum, single_point): - # coordinate_dim, batch_dim, precision_dim = partial_sum.shape - _, batch_dim, _ = partial_sum.shape - new_partial_sum = jec.padd_lazy_twisted_pack(partial_sum, single_point) - double_partial_sum = jec.pdul_lazy_twisted_pack(partial_sum) - cond_equal = jnp.all(partial_sum == single_point, axis=(0, 2)).reshape( - 1, batch_dim, 1 - ) - return jnp.where(cond_equal, double_partial_sum, new_partial_sum) - - def bucket_accumulation_index_scan_parallel_algorithm_twisted( all_buckets: jnp.ndarray, all_points: jnp.ndarray, @@ -788,7 +777,7 @@ def bucket_reduction_scan_algorithm_twisted( def scan_body(temp_and_window_sum_pack, buckets): temp_sum, window_sum = temp_and_window_sum_pack temp_sum = padd(temp_sum, buckets) - window_sum = padd_with_pdul_check(window_sum, temp_sum) + window_sum = padd(window_sum, temp_sum) return (temp_sum, window_sum), None (_, window_sum), _ = jax.lax.scan( @@ -828,10 +817,10 @@ def window_merge_scan_algorithm_twisted( ): """Scan version WM.""" coordinate_dim, window_dim, precision_dim = window_sum.shape - window_sum = window_sum.transpose(1, 0, 2) - result = window_sum[window_dim - 1, :, :].reshape( - (coordinate_dim, 1, precision_dim) + window_sum = window_sum.transpose(1, 0, 2).reshape( + (window_dim, coordinate_dim, 1, precision_dim) ) + result = window_sum[window_dim - 1] def fori_loop_body(_, result): result = jec.pdul_lazy_twisted_pack(result) @@ -839,9 +828,7 @@ def fori_loop_body(_, result): def scan_body(result, window_sum): result = jax.lax.fori_loop(0, slice_length, fori_loop_body, result) - result = jec.padd_lazy_twisted_pack( - result, window_sum.reshape((coordinate_dim, 1, util.U16_EXT_CHUNK_NUM)) - ) + result = 
jec.padd_lazy_twisted_pack(result, window_sum) return result, None result, _ = jax.lax.scan( @@ -1001,6 +988,7 @@ def bucket_accumulation(self, bucket_accumulation_index_func): self.all_buckets = bucket_accumulation_index_func( self.all_buckets, self.all_points, self.selection_index_list ) + # for i in range(self.coordinate_num): return self.all_buckets def bucket_reduction(self, bucket_reduction_func): diff --git a/jaxite_ec/pippenger_rns.py b/jaxite_ec/pippenger_rns.py index fd27be3..87db8ce 100644 --- a/jaxite_ec/pippenger_rns.py +++ b/jaxite_ec/pippenger_rns.py @@ -715,21 +715,6 @@ def construct_br_zero_states(self, bucket_zero_states): ######################### -def padd(partial_sum, single_point): - return jec.padd_rns_twisted_pack(partial_sum, single_point) - - -def padd_with_pdul_check(partial_sum, single_point): - # coordinate_dim, batch_dim, precision_dim = partial_sum.shape - _, batch_dim, _ = partial_sum.shape - new_partial_sum = jec.padd_rns_twisted_pack(partial_sum, single_point) - double_partial_sum = jec.pdul_rns_twisted_pack(partial_sum) - cond_equal = jnp.all(partial_sum == single_point, axis=(0, 2)).reshape( - 1, batch_dim, 1 - ) - return jnp.where(cond_equal, double_partial_sum, new_partial_sum) - - def bucket_accumulation_index_scan_parallel_algorithm_twisted( all_buckets: jnp.ndarray, all_points: jnp.ndarray, @@ -755,7 +740,7 @@ def scan_body(buckets, point_with_cond_pack): selective_buckets = buckets[ :, jnp.arange(batch_window_dim), selection_index, : ] - selective_update = padd(selective_buckets, point) + selective_update = jec.padd_rns_twisted_pack(selective_buckets, point) return ( buckets.at[:, jnp.arange(batch_window_dim), selection_index, :].set( selective_update @@ -779,12 +764,12 @@ def bucket_reduction_scan_algorithm_twisted( bucket_num_in_window: int, ): """Scan version BR.""" - all_buckets = jnp.flip(all_buckets.transpose(2, 0, 1, 3), axis=0) + all_buckets = all_buckets.transpose(2, 0, 1, 3) def scan_body(temp_and_window_sum_pack, 
buckets): temp_sum, window_sum = temp_and_window_sum_pack - temp_sum = padd(temp_sum, buckets) - window_sum = padd_with_pdul_check(window_sum, temp_sum) + temp_sum = jec.padd_rns_twisted_pack(temp_sum, buckets) + window_sum = jec.padd_rns_twisted_pack(window_sum, temp_sum) return (temp_sum, window_sum), None (_, window_sum), _ = jax.lax.scan( @@ -792,6 +777,7 @@ def scan_body(temp_and_window_sum_pack, buckets): (temp_sum, window_sum), all_buckets[:bucket_num_in_window], length=bucket_num_in_window, + reverse=True, ) return window_sum @@ -810,7 +796,9 @@ def batch_window_summation_algorithm_twisted( ).transpose(1, 0, 2, 3) def scan_body(batch_window_sum, single_window_sum): - batch_window_sum = padd(batch_window_sum, single_window_sum) + batch_window_sum = jec.padd_rns_twisted_pack( + batch_window_sum, single_window_sum + ) return batch_window_sum, None batch_window_sum, _ = jax.lax.scan( @@ -998,6 +986,7 @@ def bucket_accumulation(self, bucket_accumulation_index_func): self.all_buckets = bucket_accumulation_index_func( self.all_buckets, self.all_points, self.selection_index_list ) + # for q in range(self.coordinate_num): return self.all_buckets def bucket_reduction(self, bucket_reduction_func): @@ -1039,7 +1028,14 @@ def batch_window_summation(self, batch_window_summation_func): return self.window_sum def window_merge(self, window_merge_func): - """Merge the windows to form the final elliptic curve.""" + """Merge the windows to form the final elliptic curve. + + Args: + window_merge_func: The function to merge the windows. + + Returns: + The final elliptic curve. 
+ """ self.result = window_merge_func(self.window_sum) return self.result @@ -1055,3 +1051,260 @@ def construct_ba_selection(self): selection_index.append(slice_index) selection_index_list.append(deepcopy(selection_index)) return selection_index_list + + +######################### +# Functions for Signed bucket + twisted curve +######################### +def padd_with_sign(partial_sum, single_point, sign): + neg_single_point = jec.pneg_rns_twisted_pack(single_point) + _, batch_dim, _ = partial_sum.shape + cond_neg = jnp.equal(sign, 1).reshape(1, batch_dim, 1) + signed_point = jnp.where(cond_neg, neg_single_point, single_point) + result = jec.padd_rns_twisted_pack(partial_sum, signed_point) + return result + + +def bucket_accumulation_signed_index_scan_parallel_algorithm_twisted( + all_buckets: jnp.ndarray, + all_points: jnp.ndarray, + selection_index_list: jnp.ndarray, + selection_sign_list: jnp.ndarray, + msm_length: int, +): + """Scan version BA with index selection.""" + window_selection = jnp.arange(256) + coordinate_dim, batch_window_dim, _, precision_dim = all_buckets.shape + _, _, parallel_dim, _ = ( + all_points.shape + ) # (serial_dim, coordinate_dim, parallel_dim, precision_dim) + single_window_dim = batch_window_dim // parallel_dim + + def scan_body(buckets, point_with_cond_pack): + point, selection_index, selection_sign = point_with_cond_pack + # point = jnp.repeat(point, repeats=single_window_dim, axis=1) + point = jax.lax.broadcast_in_dim( + point, + (coordinate_dim, parallel_dim, single_window_dim, precision_dim), + (0, 1, 3), + ) + point = point.reshape((coordinate_dim, batch_window_dim, precision_dim)) + selective_buckets = buckets[:, window_selection, selection_index, :] + selective_update = padd_with_sign(selective_buckets, point, selection_sign) + return ( + buckets.at[:, window_selection, selection_index, :].set( + selective_update + ), + None, + ) + + all_buckets, _ = jax.lax.scan( + scan_body, + all_buckets, + (all_points, 
selection_index_list, selection_sign_list), + length=msm_length, + ) + return all_buckets + + +class MSMPippengerTwistedSigned: + """Pippenger algorithm for elliptic curves with twisted and signed points. + + Attributes: + coordinate_num: The number of coordinates in the elliptic curve. + slice_length: The length of each slice in the elliptic curve. + point_parallel: The number of parallel points in the elliptic curve. + window_num: The number of windows in the elliptic curve. + batch_window_num: The number of batch windows in the elliptic curve. + bucket_num_per_window: The number of buckets in each window. + slice_mask: The mask for the slices in the elliptic curve. + blank_point: A JAX array of zeros, used to initialize the buckets. + all_buckets: A JAX array of all the buckets in the elliptic curve. + points: A list of JAX arrays, where each array represents an Orignal point + from the trace. + scalars: A list of integers, where each integer represents an Orignal scalar + from the trace. + all_points: A JAX array of all the points in the elliptic curve. from the + trace. + window_sum: A JAX array of the window sum. + zero_states_list: A JAX array of the zero states for the buckets. + selection_list: A JAX array of the selection states for the buckets. + selection_index_list: A JAX array of the selection index for the buckets. + selection_sign_list: A JAX array of the selection sign for the buckets. + all_points: A JAX array of all the points in the elliptic curve. from the + trace. + msm_length: The length of the MSM trace. + result: The final elliptic curve. + rns_mat: The rns matrix used for padding and doubling. 
+ """ + + def __init__(self, slice_length: int, point_parallel: int): + self.coordinate_num = util.COORDINATE_NUM + + self.slice_length = slice_length + self.point_parallel = point_parallel + self.window_num = int(math.ceil(254 / self.slice_length)) # + self.batch_window_num = self.window_num * self.point_parallel + self.bucket_num_per_window = 2 ** (self.slice_length - 1) + self.slice_mask = 2**self.slice_length - 1 + self.blank_point = ( + util.int_list_to_array_rns([0, 1, 1, 0]) + .reshape(self.coordinate_num, 1, util.NUM_MODULI) + .astype(jnp.uint16) + ) + + self.all_buckets = jnp.broadcast_to( + self.blank_point.reshape( + 1, self.coordinate_num, 1, util.NUM_MODULI + ).transpose(1, 0, 2, 3), + ( + self.coordinate_num, + self.batch_window_num, + self.bucket_num_per_window, + util.NUM_MODULI, + ), + ) + + self.window_sum: jnp.ndarray + + self.msm_length = 0 + + self.zero_states_list: jnp.ndarray + self.selection_list: jnp.ndarray + self.selection_index_list: jnp.ndarray + self.selection_sign_list: jnp.ndarray + self.all_points: jnp.ndarray + + self.scalars: List[int] = [] # Orignal scalar from the trace + # [Points, Points, ..., Points] + self.points: List[jnp.ndarray] = [] # Orignal points from the trace + self.rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT) + + self.result = None + + def initialize(self, scalars, points): + """Initialize the Pippenger algorithm. + + Args: + scalars: A list of integers, where each integer represents an Orignal + scalar from the trace. + points: A list of JAX arrays, where each array represents an Orignal point + from the trace. 
+ """ + # Initial internal selection from the scalar + self.scalars = scalars + self.msm_length = len(scalars) + + # Convert high-precision points into a vector of low-precision chunks + self.points = [ + util.int_list_to_array_rns(coordinates) for coordinates in points + ] + self.all_points = jnp.array(self.points).astype(jnp.uint16) + _, coordinate_dim, precision_dim = self.all_points.shape + + # For BA + selection_index_pylist, selection_sign_pylist = ( + self.construct_ba_selection_with_sign() + ) + self.selection_index_list = jnp.asarray(selection_index_pylist).astype( + jnp.uint16 + ) + self.selection_sign_list = jnp.array(selection_sign_pylist, dtype=jnp.uint8) + _, window_dim = self.selection_index_list.shape + + # Batch construction + self.all_points = self.all_points.reshape( + (-1, self.point_parallel, coordinate_dim, precision_dim) + ).transpose(0, 2, 1, 3) + self.selection_index_list = self.selection_index_list.reshape( + (-1, window_dim * self.point_parallel) + ) + self.selection_sign_list = self.selection_sign_list.reshape( + (-1, window_dim * self.point_parallel) + ) + + def bucket_accumulation(self, bucket_accumulation_index_algorithm): + """BA index selection version.""" + self.all_buckets = bucket_accumulation_index_algorithm( + self.all_buckets, + self.all_points, + self.selection_index_list, + self.selection_sign_list, + ) + + return self.all_buckets + + def bucket_reduction(self, bucket_reduction_func): + """Reduce the buckets to a single point for each window.""" + temp_sum = jnp.broadcast_to( + self.blank_point, + ( + self.coordinate_num, + self.batch_window_num, + util.NUM_MODULI, + ), + ) + window_sum = jnp.broadcast_to( + self.blank_point, + ( + self.coordinate_num, + self.batch_window_num, + util.NUM_MODULI, + ), + ) + self.window_sum = bucket_reduction_func( + self.all_buckets, temp_sum, window_sum + ) + return self.window_sum + + def batch_window_summation(self, batch_window_summation_algorithm): + """Sum the batch windows to form the 
final window sum.""" + batch_window_sum = jnp.broadcast_to( + self.blank_point, + ( + self.coordinate_num, + self.window_num, + util.NUM_MODULI, + ), + ) + self.window_sum = batch_window_summation_algorithm( + batch_window_sum, self.window_sum + ) + return self.window_sum + + def window_merge(self, window_merge_func): + """Merge the window sums to form the final result.""" + self.result = window_merge_func(self.window_sum) + return self.result + + def construct_ba_selection_with_sign(self): + """Construct the selection index and sign for the bucket accumulation (BA) step. + + Returns: + A tuple of two per-scalar lists: each window's bucket index (the signed + digit's magnitude minus one) and its sign flag (1 iff a carry occurs). + """ + selection_index_list = [] # Used for index selection + selection_sign_list = [] + slice_max = 2**self.slice_length + slice_half = 2 ** (self.slice_length - 1) + for scalar in self.scalars: + # Recode the scalar into signed slices, carrying between windows + selection_index = [] + selection_sign = [] + carry = 0 + for w in range(self.window_num): + slice_index = (scalar >> (w * self.slice_length)) & self.slice_mask + slice_index = slice_index + carry + if slice_index >= slice_half: + new_slice_index = abs(slice_index - slice_max) + carry = 1 + else: + new_slice_index = slice_index + carry = 0 + selection_index.append(new_slice_index - 1) + selection_sign.append(carry) + assert carry == 0 + selection_index_list.append(deepcopy(selection_index)) + selection_sign_list.append(deepcopy(selection_sign)) + return selection_index_list, selection_sign_list diff --git a/jaxite_ec/util.py b/jaxite_ec/util.py index 17afce2..69827a5 100644 --- a/jaxite_ec/util.py +++ b/jaxite_ec/util.py @@ -4,7 +4,6 @@ jitted.
""" -import csv import json import math from typing import Any, Callable, List, Tuple @@ -55,80 +54,6 @@ # Pippenger Logics COORDINATE_NUM = 4 -# RNS Reduction Logics -# Hardware friendly moduli factors are 2**16 - v for v in the following list -RNS_MODULI_T = ( - 0, - 1, - 3, - 5, - 9, - 15, - 17, - 27, - 33, - 39, - 45, - 47, - 57, - 59, - 63, - 77, - 87, - 89, - 99, - 105, - 113, - 117, - 123, - 125, - 129, - 143, - 153, - 155, - 165, - 167, - 173, - 179, - 183, - 189, - 197, - 209, - 213, - 215, - 225, - 227, - 243, - 249, - 14, - 38, - 50, - 54, - 98, - 102, - 110, - 122, -) - -MODULI = tuple([ - 2**16 if i == 0 else 2**16 - int(i) if i % 2 == 1 else 2**15 - (int(i) // 2) - for i in RNS_MODULI_T -]) - - -RNS_PRECISION = 16 -NUM_MODULI = len(RNS_MODULI_T) -# Maximum number of consecutive additions/subtractions -ADDITION_BOUND = 4 - -# Warning: specific to target modulus and addition bound -MODULI_SUB = tuple([ - ((512 * NUM_MODULI * MODULUS_377_INT * ADDITION_BOUND) - 2**16) % m - for m in MODULI -]) -TWIST_D_RNS = tuple([TWIST_D_INT % MODULI[i] for i in range(len(MODULI))]) - - #################################### # Utility Functions #################################### @@ -150,7 +75,7 @@ def array_to_int(jax_array: jax.Array, base) -> int: def int_to_array( - python_int, base=BASE, dtype=jnp.uint16, array_size=U16_CHUNK_NUM + python_int, base=BASE, dtype=BASE_TYPE, array_size=U16_CHUNK_NUM ): """Converts a Python integer to a JAX array.""" mask = (1 << base) - 1 @@ -165,7 +90,7 @@ def int_to_array( assert array_size >= len(elements) elements = elements[:array_size] + [0] * (array_size - len(elements)) - return jnp.array(elements, dtype=dtype) + return jnp.array(elements, dtype=dtype).astype(BASE_TYPE) def array_to_int_list(jax_array, base): @@ -183,16 +108,16 @@ def int_list_to_array(int_list, base=BASE, array_size=U16_CHUNK_NUM): chunked_arrays = [] for int_value in int_list: chunked_arrays.append(int_to_array(int_value, base, array_size=array_size)) - 
return jnp.array(chunked_arrays) + return jnp.array(chunked_arrays).astype(BASE_TYPE) def int_point_to_jax_point_pack( - coordinates: List[int], base=BASE, chunk_num=U16_CHUNK_NUM + coordinates: List[int], base=BASE, array_size=U16_CHUNK_NUM ): result = [] for i in range(len(coordinates)): - result.append(int_to_array(coordinates[i], base, array_size=chunk_num)) - return jnp.array(result) + result.append(int_to_array(coordinates[i], base, array_size=array_size)) + return jnp.array(result).astype(BASE_TYPE) def jax_point_pack_to_int_point(point: jax.Array): @@ -219,7 +144,7 @@ def int_list_to_array_rns(int_list) -> jnp.ndarray: limbs = [] for int_value in int_list: limbs.append(int_to_array_rns(int_value)) - return jnp.array(limbs) + return jnp.array(limbs).astype(BASE_TYPE) def array_rns_to_int_list(jax_array): @@ -236,7 +161,7 @@ def int_point_to_jax_rns_point_pack(coordinates: List[int]): result = [] for i in range(len(coordinates)): result.append(int_to_array_rns(coordinates[i])) - return jnp.array(result) + return jnp.array(result).astype(BASE_TYPE) def jax_rns_point_pack_to_int_point(point: jax.Array): @@ -249,11 +174,11 @@ def jax_rns_point_pack_to_int_point(point: jax.Array): def int_point_batch_to_jax_point_pack( - points: List[List[int]], base=BASE, chunk_num=U16_CHUNK_NUM + points: List[List[int]], base=BASE, array_size=U16_CHUNK_NUM ): result = [] for i in range(len(points)): - result.append(int_point_to_jax_point_pack(points[i], base, chunk_num)) + result.append(int_point_to_jax_point_pack(points[i], base, array_size)) return jnp.transpose(jnp.array(result), (1, 0, 2)) @@ -324,7 +249,7 @@ def to_tuple(a): # The following function achieves the same function as int_to_array, but it # can be pre-run (Google restriction), and returns a tuple. 
def int_to_precomputed_array( - python_int, base=BASE, dtype=jnp.uint16, array_size=U16_CHUNK_NUM + python_int, base=BASE, dtype=BASE_TYPE, array_size=U16_CHUNK_NUM ): """Converts a Python integer to a JAX array.""" mask = (1 << base) - 1 @@ -456,7 +381,7 @@ def find_moduli(total_modulus, precision): overall_moduli = [] overall_constant_offset = [] overall_modulus = 1 - for i in range(2 ** (precision >> 1) - 1): + for i in range(2 ** precision - 1): cur_moduli = initial_moduli - i if math.gcd(cur_moduli, overall_modulus) == 1: overall_moduli.append(cur_moduli) @@ -465,17 +390,59 @@ def find_moduli(total_modulus, precision): if overall_modulus > total_modulus: return to_tuple(overall_moduli), to_tuple(overall_constant_offset) - # Find 2**15 - v too - initial_moduli = 2 ** (precision - 1) - if overall_modulus < total_modulus: - for i in range(2 ** (precision >> 1) - 1): - cur_moduli = initial_moduli - i - if math.gcd(cur_moduli, overall_modulus) == 1: - overall_moduli.append(cur_moduli) - overall_constant_offset.append(i << 1) - overall_modulus *= cur_moduli - if overall_modulus > total_modulus: - return to_tuple(overall_moduli), to_tuple(overall_constant_offset) + return to_tuple(overall_moduli), to_tuple(overall_constant_offset) + + +def find_moduli_specified_number(total_number, precision): + """Finds a given number of coprime moduli just below 2**precision. + + Args: + total_number: The number of moduli required. + precision: The bit width; candidates are 2**precision - i for i >= 1. + + Returns: + A tuple of the selected moduli, each coprime to all earlier ones. It has + total_number entries unless the candidate range is exhausted first. No + constant offsets are returned.
+ """ + initial_moduli = 2**precision + overall_moduli = [] + overall_modulus = 1 + for i in range(1, 2 ** precision - 1): + cur_moduli = initial_moduli - i + if math.gcd(cur_moduli, overall_modulus) == 1: + overall_moduli.append(cur_moduli) + overall_modulus *= cur_moduli + if len(overall_moduli) >= total_number: + return to_tuple(overall_moduli) + + return to_tuple(overall_moduli) + + +def find_moduli_barrett(total_modulus, precision): + """Finds a list of moduli close to the given precision. + + Args: + total_modulus: The target modulus. + precision: The desired precision of the moduli. + + Returns: + A tuple containing two lists: + - overall_moduli: A list of moduli close to the given precision. + - overall_constant_offset: A list of constant offsets for the moduli. + """ + initial_moduli = 2**precision + overall_moduli = [] + overall_constant_offset = [] + overall_modulus = 1 + for i in range(1, 2 ** (precision >> 1) - 1): + cur_moduli = initial_moduli - i + if math.gcd(cur_moduli, overall_modulus) == 1: + overall_moduli.append(cur_moduli) + overall_constant_offset.append(i) + overall_modulus *= cur_moduli + if overall_modulus > total_modulus: + return to_tuple(overall_moduli), to_tuple(overall_constant_offset) return to_tuple(overall_moduli), to_tuple(overall_constant_offset) @@ -532,9 +499,8 @@ def rns_coefficients_precompute( ] rns_mat = np.array( - icrt_factors_byteshifted_modq_rns, dtype=np.uint16 + icrt_factors_byteshifted_modq_rns, dtype=BASE_TYPE ).reshape(-1, num_residues) - # calculate quotient estimation fix_point = 1 << moduli_precision @@ -544,46 +510,46 @@ def rns_coefficients_precompute( shifted_quotient_estimations.append( [math.ceil((chunk * fix_point) / overall_modulus)] ) - sqe_mat = np.array(shifted_quotient_estimations, dtype=np.uint16) + sqe_mat = np.array(shifted_quotient_estimations, dtype=BASE_TYPE) cor_mat = np.array( - [to_rns(-overall_modulus % q, overall_moduli)], dtype=np.uint16 + [to_rns(-overall_modulus % q, overall_moduli)], 
dtype=BASE_TYPE ) - # Convert rns_mat and sqe_mat into various bytes. - # Version 1: split precision into different chunks. - # rns_mat_u8 = rns_mat.view(np.uint8).reshape(*rns_mat.shape, num_bytes) - # seq_mat_u8 = sqe_mat.view(np.uint8).reshape(*sqe_mat.shape, num_bytes) - # rns_stack_mat_u8 = np.hstack(( - # rns_mat_u8[..., 0], - # seq_mat_u8[..., 0], - # rns_mat_u8[..., 1], - # seq_mat_u8[..., 1], - # )) - # Version 2: interleave precision -- tested to be faster. rns_stack_mat_u8 = np.hstack( (rns_mat.view(jnp.uint8), sqe_mat.view(jnp.uint8)) ) return to_tuple(rns_stack_mat_u8.tolist()), to_tuple(cor_mat.tolist()) -def get_parts(u16mat): - assert u16mat.dtype == np.uint16 - u16bytes = u16mat.view(np.uint8) - return [u16bytes[:, ::2], u16bytes[:, 1::2]] - - -M = MODULUS_377_INT * MODULUS_377_INT * 256 * 256 * 50 * 50 * 4 * 2 moduli_precision = 16 -num_bytes = moduli_precision // 8 # 2 +num_bytes = math.ceil(moduli_precision / 8) # 2 +num_residues_for_q = ( + int(MODULUS_377_INT).bit_length() + moduli_precision - 1 +) // moduli_precision +NUM_MODULI = 56 +extra_bit_to_avoid_addition_overflow = 4 +minimal_modulus = ( + MODULUS_377_INT + * 256 + * NUM_MODULI + * 4 + * 2 + * extra_bit_to_avoid_addition_overflow +) ** 2 * 2 + +# ((MODULUS_377_INT * 255 * num_bytes * num_residues_for_q)**2) +# * extra_bit_to_avoid_addition_overflow * 2 +# M = MODULUS_377_INT * MODULUS_377_INT * 256 * 256 * 50 * 50 * 4 * 2 # hardware friendly moduli is 2**precision - t # overall_moduli is the jax.array of "2**precision - t" # overall_constant_offset is the jax.array of "t" -overall_moduli, overall_constant_offset = find_moduli(M, moduli_precision) +overall_moduli = find_moduli_specified_number(NUM_MODULI, moduli_precision) M = 1 for moduli in overall_moduli: M *= moduli M = int(M) +assert M > minimal_modulus assert len(overall_moduli) == ( (M.bit_length() + moduli_precision - 1) // moduli_precision ) @@ -612,7 +578,6 @@ def construct_rns_matrix(q): MODULI = overall_moduli -RNS_MODULI_T 
= overall_constant_offset RNS_MAT = (RNS_STACK_MAT_NEW, COR_MAT_NEW) MODULUS_377_INT_CHUNK = int_to_precomputed_array( MODULUS_377_INT, base=BASE, array_size=U16_CHUNK_NUM @@ -623,6 +588,28 @@ def construct_rns_matrix(q): TWIST_D_INT_CHUNK = int_to_precomputed_array( TWIST_D_INT, base=BASE, array_size=U16_EXT_CHUNK_NUM ) +TWIST_D_INT_CHUNK_BARRETT = int_to_precomputed_array( + TWIST_D_INT, base=BASE, array_size=U16_CHUNK_NUM +) MODULUS_377_S16_INT_CHUNK = int_to_precomputed_array( MODULUS_377_S16_INT, base=BASE, array_size=U16_EXT_CHUNK_NUM ) +# Maximum number of consecutive additions/subtractions +# Warning: specific to target modulus and addition bound +# SUB_MODULI_CONSTANT = (MODULUS_377_INT << 24) +# print(RNS_MODULI_T) +# print(MODULI) +# print((256*NUM_MODULI *4* 2*MODULUS_377_INT) - 2**RNS_PRECISION) +MODULI_SUB = tuple([ + ((256 * NUM_MODULI * 4 * 2 * MODULUS_377_INT)) % m + # SUB_MODULI_CONSTANT % m + for m in MODULI +]) + +TWIST_D_RNS = tuple([TWIST_D_INT % MODULI[i] for i in range(len(MODULI))]) + +S_BARRETT = to_tuple([2 * math.ceil(math.log2(q)) for q in MODULI]) +M_BARRETT = [math.floor(2**s / q) for (s, q) in zip(S_BARRETT, MODULI)] +W_BARRETT = to_tuple([min(s, 32) for s in S_BARRETT]) +MASK_BARRETT = to_tuple([2**w - 1 for w in W_BARRETT]) +S_W_BARRETT = to_tuple([s - w for (s, w) in zip(S_BARRETT, W_BARRETT)])