Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ tpu_test(
"jaxite_ec/test_case/t8/zprize_msm_curve_377_scalars_dim_8_seed_0.csv",
],
python_version = "PY3",
shard_count = 3,
shard_count = 6,
srcs_version = "PY3ONLY",
deps = [
":jaxite",
Expand Down
229 changes: 209 additions & 20 deletions jaxite_ec/elliptic_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from jaxite.jaxite_ec import util



add_3u16 = finite_field.add_3u16
add_2u16 = finite_field.add_2u16
sub_2u16 = finite_field.sub_2u16
Expand All @@ -33,15 +34,12 @@
add_rns_3u16 = finite_field.add_rns_3u16
add_sub_rns_var = finite_field.add_sub_rns_var
negate_rns_for_var_add = finite_field.negate_rns_for_var_add
negate_rns_for_var_add_zero_check = (
finite_field.negate_rns_for_var_add_zero_check
)
rns_constant = finite_field.rns_constant
negate_rns = finite_field.negate_rns


# Barrett Reduction Based Functions
@jax.named_call
def padd_barret_xyzz(
def padd_barrett_xyzz(
x1: jax.Array,
y1: jax.Array,
zz1: jax.Array,
Expand Down Expand Up @@ -108,7 +106,7 @@ def padd_barret_xyzz(


@jax.named_call
def pdul_barret_xyzz(
def pdul_barrett_xyzz(
x1: jax.Array, y1: jax.Array, zz1: jax.Array, zzz1: jax.Array
):
"""PDUL-BARRET elliptic curve operation with packed arguments.
Expand Down Expand Up @@ -163,7 +161,7 @@ def pdul_barret_xyzz(

@jax.named_call
def pdul_barrett_xyzz_pack(x1_y1_zz1_zzz1: jax.Array):
return pdul_barret_xyzz(
return pdul_barrett_xyzz(
x1_y1_zz1_zzz1[0], x1_y1_zz1_zzz1[1], x1_y1_zz1_zzz1[2], x1_y1_zz1_zzz1[3]
)

Expand All @@ -172,7 +170,7 @@ def pdul_barrett_xyzz_pack(x1_y1_zz1_zzz1: jax.Array):
def padd_barrett_xyzz_pack(
x1_y1_zz1_zzz1: jax.Array, x2_y2_zz2_zzz2: jax.Array
):
return padd_barret_xyzz(
return padd_barrett_xyzz(
x1_y1_zz1_zzz1[0],
x1_y1_zz1_zzz1[1],
x1_y1_zz1_zzz1[2],
Expand All @@ -188,7 +186,7 @@ def padd_barrett_xyzz_pack(
def pdul_barrett_xyzz_pack_batch_first(
x1_y1_zz1_zzz1: jax.Array, transpose=(0, 1, 2)
):
return pdul_barret_xyzz(
return pdul_barrett_xyzz(
x1_y1_zz1_zzz1[:, 0],
x1_y1_zz1_zzz1[:, 1],
x1_y1_zz1_zzz1[:, 2],
Expand All @@ -200,7 +198,7 @@ def pdul_barrett_xyzz_pack_batch_first(
def padd_barrett_xyzz_pack_batch_first(
x1_y1_zz1_zzz1: jax.Array, x2_y2_zz2_zzz2: jax.Array, transpose=(0, 1, 2)
):
return padd_barret_xyzz(
return padd_barrett_xyzz(
x1_y1_zz1_zzz1[:, 0],
x1_y1_zz1_zzz1[:, 1],
x1_y1_zz1_zzz1[:, 2],
Expand All @@ -212,6 +210,151 @@ def padd_barrett_xyzz_pack_batch_first(
).transpose(transpose[0], transpose[1], transpose[2])


@jax.named_call
@functools.partial(jax.jit, static_argnames="twisted_d_chunk")
def padd_barrett_twisted(
x1: jax.Array,
y1: jax.Array,
z1: jax.Array,
t1: jax.Array,
x2: jax.Array,
y2: jax.Array,
z2: jax.Array,
t2: jax.Array,
twisted_d_chunk=util.TWIST_D_INT_CHUNK_BARRETT,
):
"""PADD-BARRETT elliptic curve operation with packed arguments.

As for the algorithm, pls refer to
jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassTwisted::add_general

This function implements the PADD-LAZY elliptic curve operation with packed
arguments, which is used to compute the elliptic curve points of a given
group.

Args:
x1: The first generator element.
y1: The second generator element.
z1: The third generator element.
t1: The fourth generator element.
x2: The first generator element.
y2: The second generator element.
z2: The third generator element.
t2: The fourth generator element.
twisted_d_chunk: The twisted d parameter.

Returns:
A tuple containing the third generator element and the elliptic curve points
of the group.
"""
twisted_d = jnp.asarray(twisted_d_chunk, dtype=jnp.uint16)
twisted_d = jax.lax.broadcast(twisted_d, [x1.shape[0]])

a = mod_mul_barrett_2u16(x1, x2)
b = mod_mul_barrett_2u16(y1, y2)
d = mod_mul_barrett_2u16(z1, z2)
c = mod_mul_barrett_2u16(t1, t2)
c = mod_mul_barrett_2u16(c, twisted_d)

h = add_2u16(a, b)
h = cond_sub_mod_u16(h)
e1 = add_2u16(x1, y1)
e1 = cond_sub_mod_u16(e1)
e2 = add_2u16(x2, y2)
e2 = cond_sub_mod_u16(e2)
e = mod_mul_barrett_2u16(e1, e2)

e = cond_sub_2u16(e, h)

f = cond_sub_2u16(d, c)
g = add_2u16(d, c)
g = cond_sub_mod_u16(g)

x3 = mod_mul_barrett_2u16(e, f)
y3 = mod_mul_barrett_2u16(g, h)
z3 = mod_mul_barrett_2u16(f, g)
t3 = mod_mul_barrett_2u16(e, h)

return jnp.array([x3, y3, z3, t3])


def pdul_barrett_twisted(
x1: jax.Array,
y1: jax.Array,
z1: jax.Array,
t1: jax.Array,
):
"""PDUL-LAZY elliptic curve operation with packed arguments.

As for the algorithm, pls refer to
jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassTwisted::double_general

This function implements the PDUL-LAZY elliptic curve operation with packed
arguments, which is used to compute the elliptic curve points of a given
group.

Args:
x1: The first generator element.
y1: The second generator element.
z1: The third generator element.
t1: The fourth generator element.

Returns:
A tuple containing the third generator element and the elliptic curve points
of the group.
"""
modulus_377_int_array = jnp.asarray(
util.MODULUS_377_INT_CHUNK, jnp.uint16
)

a = mod_mul_barrett_2u16(x1, x1)
b = mod_mul_barrett_2u16(y1, y1)

ct = mod_mul_barrett_2u16(z1, z1) #
ct2 = add_2u16(ct, ct) #
ct2 = cond_sub_mod_u16(ct2) #

h = add_2u16(a, b)
h = cond_sub_2u16(modulus_377_int_array, h) #

et = add_2u16(x1, y1) #
et = cond_sub_mod_u16(et) #
e = mod_mul_barrett_2u16(et, et) #
e = add_2u16(e, h) #
e = cond_sub_mod_u16(e) #

g = cond_sub_2u16(b, a) #
f = cond_sub_2u16(g, ct2) #
x3 = mod_mul_barrett_2u16(e, f) #
y3 = mod_mul_barrett_2u16(g, h)
z3 = mod_mul_barrett_2u16(f, g)
t3 = mod_mul_barrett_2u16(e, h)
return jnp.array([x3, y3, z3, t3])


@jax.named_call
def pdul_barrett_twisted_pack(x1_y1_zz1_zzz1: jax.Array):
return pdul_barrett_twisted(
x1_y1_zz1_zzz1[0], x1_y1_zz1_zzz1[1], x1_y1_zz1_zzz1[2], x1_y1_zz1_zzz1[3]
)


@jax.named_call
def padd_barrett_twisted_pack(
x1_y1_zz1_zzz1: jax.Array, x2_y2_zz2_zzz2: jax.Array
):
return padd_barrett_twisted(
x1_y1_zz1_zzz1[0],
x1_y1_zz1_zzz1[1],
x1_y1_zz1_zzz1[2],
x1_y1_zz1_zzz1[3],
x2_y2_zz2_zzz2[0],
x2_y2_zz2_zzz2[1],
x2_y2_zz2_zzz2[2],
x2_y2_zz2_zzz2[3],
)


# Lazy Reduction Based Functions
@jax.named_call
def padd_lazy_xyzz(
Expand Down Expand Up @@ -505,7 +648,7 @@ def pdul_lazy_twisted(
a = mod_mul_lazy_2u16(x1, x1)
b = mod_mul_lazy_2u16(y1, y1)

ct = mod_mul_lazy_2u16(z1, z1)
ct = mod_mul_lazy_2u16(z1, z1) #
ct2 = add_2u16(ct, ct) #
ct2 = cond_sub_mod_u16_ext(ct2) #

Expand Down Expand Up @@ -786,8 +929,14 @@ def padd_rns_twisted_pack(
group.

Args:
x1_y1_zz1_zzz1: The first point.
x2_y2_zz2_zzz2: The second point.
x1: The first generator element.
y1: The second generator element.
z1: The third generator element.
t1: The third generator element.
x2: The first generator element.
y2: The second generator element.
z2: The third generator element.
t2: The third generator element.
rns_mat: The RNS matrix.
twist_d: curve parameter.

Expand Down Expand Up @@ -815,10 +964,10 @@ def padd_rns_twisted_pack(
# Issue happens here
e = add_sub_rns_var(
e3,
negate_rns_for_var_add_zero_check(a),
negate_rns_for_var_add_zero_check(b),
negate_rns_for_var_add(a),
negate_rns_for_var_add(b),
)
f = add_sub_rns_var(d, negate_rns_for_var_add_zero_check(c))
f = add_sub_rns_var(d, negate_rns_for_var_add(c))
g = add_rns_2u16(d, c)
h = add_rns_2u16(a, b)

Expand Down Expand Up @@ -890,7 +1039,47 @@ def pdul_rns_twisted_pack(x1_y1_zz1_zzz1: jax.Array, rns_mat=util.RNS_MAT):


@jax.named_call
def rns_twist_zero():
return jnp.array(
[rns_constant(0), rns_constant(1), rns_constant(1), rns_constant(0)]
)
@functools.partial(jax.jit)
def pneg_rns_twisted(
x1: jax.Array,
y1: jax.Array,
z1: jax.Array,
t1: jax.Array,
):
"""PNEG-RNS elliptic curve operation with packed arguments.

As for the algorithm, pls refer to
jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassTwisted::double_general

This function implements the PNEG-RNS elliptic curve operation with packed
arguments, which is used to compute the elliptic curve points of a given
group.

Args:
x1: The first generator element.
y1: The second generator element.
z1: The third generator element.
t1: The fourth generator element.

Returns:
A tuple containing the third generator element and the elliptic curve points
of the group.
"""

x2 = negate_rns(x1)
y2 = y1
z2 = z1
t2 = negate_rns(t1)

return jnp.array([x2, y2, z2, t2])


@jax.named_call
@jax.jit
def pneg_rns_twisted_pack(x1_y1_zz1_zzz1: jax.Array):
return pneg_rns_twisted(
x1_y1_zz1_zzz1[0],
x1_y1_zz1_zzz1[1],
x1_y1_zz1_zzz1[2],
x1_y1_zz1_zzz1[3],
)
Loading
Loading