Skip to content

Commit

Permalink
batched ECIP: py & cairo
Browse files Browse the repository at this point in the history
  • Loading branch information
feltroidprime committed Dec 5, 2024
1 parent 80cde40 commit c81f436
Show file tree
Hide file tree
Showing 14 changed files with 9,417 additions and 22,426 deletions.
26 changes: 26 additions & 0 deletions hydra/garaga/hints/ecip.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,32 @@ def print_ff(ff: FF):
return string


def n_points_from_n_coeffs(n_coeffs: int, batched: bool) -> int:
if batched:
extra = 4 * 2
else:
extra = 0

# n_coeffs = 10 + 4n_points => 4n_points = n_coeffs - 10
assert n_coeffs >= 10 + extra
assert (n_coeffs - 10 - extra) % 4 == 0
return (n_coeffs - 10 - extra) // 4


def n_coeffs_from_n_points(n_points: int, batched: bool) -> tuple[int, int, int, int]:
if batched:
extra = 2
else:
extra = 0

return (
1 + n_points + extra,
1 + n_points + 1 + extra,
1 + n_points + 1 + extra,
1 + n_points + 4 + extra,
)


if __name__ == "__main__":
import random

Expand Down
7 changes: 5 additions & 2 deletions hydra/garaga/precompiled_circuits/all_circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,15 @@ class CircuitID(Enum):
},
CircuitID.EVAL_FUNCTION_CHALLENGE_DUPL: {
"class": EvalFunctionChallengeDuplCircuit,
"params": [{"n_points": k} for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]],
"params": [
{"n_points": k, "batched": True} for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
]
+ [{"n_points": k, "batched": False} for k in [1, 2]],
"filename": "ec",
},
CircuitID.INIT_FUNCTION_CHALLENGE_DUPL: {
"class": InitFunctionChallengeDuplCircuit,
"params": [{"n_points": k} for k in [11]],
"params": [{"n_points": k, "batched": True} for k in [11]],
"filename": "ec",
},
CircuitID.ACC_FUNCTION_CHALLENGE_DUPL: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import garaga.modulo_circuit_structs as structs
from garaga.definitions import CURVES, CurveID, G1Point, G2Point
from garaga.hints import neg_3
from garaga.hints.ecip import slope_intercept
from garaga.hints.ecip import (
n_coeffs_from_n_points,
n_points_from_n_coeffs,
slope_intercept,
)
from garaga.modulo_circuit import WriteOps
from garaga.modulo_circuit_structs import G1PointCircuit, G2PointCircuit, u384
from garaga.precompiled_circuits.compilable_circuits.base import (
Expand Down Expand Up @@ -398,26 +402,19 @@ def __init__(
n_points: int = 1,
auto_run: bool = True,
compilation_mode: int = 0,
batched: bool = False,
generic_circuit: bool = True,
) -> None:
self.n_points = n_points
self.batched = batched
self.generic_circuit = generic_circuit
super().__init__(
name=f"eval_fn_challenge_dupl_{n_points}P",
name=f"eval_fn_challenge_dupl_{n_points}P" + ("_rlc" if batched else ""),
curve_id=curve_id,
auto_run=auto_run,
compilation_mode=compilation_mode,
)

@staticmethod
def _n_coeffs_from_n_points(n_points: int) -> tuple[int, int, int, int]:
return (1 + n_points, 1 + n_points + 1, 1 + n_points + 1, 1 + n_points + 4)

@staticmethod
def _n_points_from_n_coeffs(n_coeffs: int) -> int:
# n_coeffs = 10 + 4n_points => 4n_points = n_coeffs - 10
assert n_coeffs >= 10
assert (n_coeffs - 10) % 4 == 0
return (n_coeffs - 10) // 4

def build_input(self) -> list[PyFelt]:
input = []
circuit = SlopeInterceptSamePointCircuit(self.curve_id, auto_run=False)
Expand All @@ -426,14 +423,17 @@ def build_input(self) -> list[PyFelt]:
[xA, _yA, _A]
).output
input.extend([xA0.felt, _yA.felt, xA2.felt, yA2.felt, coeff0.felt, coeff2.felt])
n_coeffs = self._n_coeffs_from_n_points(self.n_points)
n_coeffs = n_coeffs_from_n_points(self.n_points, self.batched)
for _ in range(sum(n_coeffs)):
input.append(self.field(randint(0, CURVES[self.curve_id].p - 1)))
return input

def _run_circuit_inner(self, input: list[PyFelt]) -> ModuloCircuit:
circuit = ECIPCircuits(
self.name, self.curve_id, compilation_mode=self.compilation_mode
self.name,
self.curve_id,
compilation_mode=self.compilation_mode,
generic_circuit=self.generic_circuit,
)

xA0, yA0 = circuit.write_struct(
Expand All @@ -454,14 +454,14 @@ def split_list(input_list, lengths):
start_idx += length
return result

n_points = self._n_points_from_n_coeffs(len(all_coeffs))
n_points = n_points_from_n_coeffs(len(all_coeffs), self.batched)
_log_div_a_num, _log_div_a_den, _log_div_b_num, _log_div_b_den = split_list(
all_coeffs, self._n_coeffs_from_n_points(n_points)
all_coeffs, n_coeffs_from_n_points(n_points, self.batched)
)
log_div_a_num, log_div_a_den, log_div_b_num, log_div_b_den = (
circuit.write_struct(
structs.FunctionFeltCircuit(
name="SumDlogDiv",
name="SumDlogDiv" + ("Batched" if self.batched else ""),
elmts=[
structs.u384Span("log_div_a_num", _log_div_a_num),
structs.u384Span("log_div_a_den", _log_div_a_den),
Expand Down Expand Up @@ -494,31 +494,22 @@ def __init__(
curve_id: int,
n_points: int = 1,
auto_run: bool = True,
batched: bool = False,
compilation_mode: int = 0,
) -> None:
self.n_points = n_points
self.batched = batched
super().__init__(
name=f"init_fn_challenge_dupl_{n_points}P",
name=f"init_fn_challenge_dupl_{n_points}P" + ("_rlc" if batched else ""),
curve_id=curve_id,
auto_run=auto_run,
compilation_mode=compilation_mode,
)

@staticmethod
def _n_coeffs_from_n_points(n_points: int) -> tuple[int, int, int, int]:
return (1 + n_points, 1 + n_points + 1, 1 + n_points + 1, 1 + n_points + 4)

@staticmethod
def _n_points_from_n_coeffs(n_coeffs: int) -> int:
# n_coeffs = 10 + 4n_points => 4n_points = n_coeffs - 10
assert n_coeffs >= 10
assert (n_coeffs - 10) % 4 == 0
return (n_coeffs - 10) // 4

def build_input(self) -> list[PyFelt]:
input = []
input.extend([self.field.random(), self.field.random()]) # xA0, xA2
n_coeffs = self._n_coeffs_from_n_points(self.n_points)
n_coeffs = n_coeffs_from_n_points(self.n_points, self.batched)
for _ in range(sum(n_coeffs)):
input.append(self.field(randint(0, CURVES[self.curve_id].p - 1)))
return input
Expand All @@ -539,9 +530,9 @@ def split_list(input_list, lengths):
start_idx += length
return result

n_points = self._n_points_from_n_coeffs(len(all_coeffs))
n_points = n_points_from_n_coeffs(len(all_coeffs), self.batched)
_log_div_a_num, _log_div_a_den, _log_div_b_num, _log_div_b_den = split_list(
all_coeffs, self._n_coeffs_from_n_points(n_points)
all_coeffs, n_coeffs_from_n_points(n_points, self.batched)
)

log_div_a_num, log_div_a_den, log_div_b_num, log_div_b_den = (
Expand Down
10 changes: 8 additions & 2 deletions hydra/garaga/precompiled_circuits/ec.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,17 @@ def _derive_point_from_x(


class ECIPCircuits(ModuloCircuit):
def __init__(self, name: str, curve_id: int, compilation_mode: int = 0):
def __init__(
self,
name: str,
curve_id: int,
compilation_mode: int = 0,
generic_circuit: bool = True,
):
super().__init__(
name=name,
curve_id=curve_id,
generic_circuit=True,
generic_circuit=generic_circuit,
compilation_mode=compilation_mode,
)
self.curve = CURVES[curve_id]
Expand Down
5 changes: 2 additions & 3 deletions hydra/garaga/starknet/groth16_contract_generator/calldata.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


def groth16_calldata_from_vk_and_proof(
vk: Groth16VerifyingKey, proof: Groth16Proof, use_rust: bool = True
vk: Groth16VerifyingKey, proof: Groth16Proof, use_rust: bool = False
) -> list[int]:
if use_rust:
return _groth16_calldata_from_vk_and_proof_rust(vk, proof)
Expand Down Expand Up @@ -45,13 +45,13 @@ def groth16_calldata_from_vk_and_proof(
curve_id=vk.curve_id,
points=[vk.ic[3], vk.ic[4]],
scalars=[proof.public_inputs[2], proof.public_inputs[3]],
risc0_mode=True,
)
calldata.extend(
msm.serialize_to_calldata(
include_digits_decomposition=True,
include_points_and_scalars=False,
serialize_as_pure_felt252_array=True,
risc0_mode=True,
)
)
else:
Expand All @@ -66,7 +66,6 @@ def groth16_calldata_from_vk_and_proof(
include_digits_decomposition=True,
include_points_and_scalars=False,
serialize_as_pure_felt252_array=True,
risc0_mode=False,
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from garaga.starknet.cli.utils import create_directory, get_package_version
from garaga.starknet.groth16_contract_generator.parsing_utils import Groth16VerifyingKey

ECIP_OPS_CLASS_HASH = 0x70C1D1C709C75E3CF51D79D19CF7C84A0D4521F3A2B8BF7BFF5CB45EE0DD289
ECIP_OPS_CLASS_HASH = 0x223A0051C2E31EDE1FD33DB4F01BC979901FD80F3429017710176CCE6AADA3B


def precompute_lines_from_vk(vk: Groth16VerifyingKey) -> StructArray:
Expand Down
Loading

0 comments on commit c81f436

Please sign in to comment.