Skip to content

Commit

Permalink
Small cleanups / simplifications
Browse files Browse the repository at this point in the history
  • Loading branch information
jannotti committed Dec 20, 2022
1 parent 68aab14 commit 0503c64
Showing 1 changed file with 45 additions and 61 deletions.
106 changes: 45 additions & 61 deletions data/transactions/logic/pairing.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,19 @@ import (
"math/big"

"github.com/consensys/gnark-crypto/ecc"
BLS12381fr "github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
"github.com/consensys/gnark-crypto/ecc/bn254"
BN254fp "github.com/consensys/gnark-crypto/ecc/bn254/fp"
BN254fr "github.com/consensys/gnark-crypto/ecc/bn254/fr"
bn254fp "github.com/consensys/gnark-crypto/ecc/bn254/fp"
bn254fr "github.com/consensys/gnark-crypto/ecc/bn254/fr"

bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381"
BLS12381fp "github.com/consensys/gnark-crypto/ecc/bls12-381/fp"
bls12381fp "github.com/consensys/gnark-crypto/ecc/bls12-381/fp"
bls12381fr "github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
)

/*Remaining questions
->What conditions should cause pairing to error vs put false on stack vs ignore point?
->Empty inputs (currently pairing and multiexp panic on empty inputs)
->Is subgroup check necessary for multiexp? Precompile does not seem to think so, but should ask Fabris
->Is subgroup check necessary for multiexp? Precompile does not seem to think so, but should ask Fabrice
->Confirm with gnark whether or not IsInSubgroup() also checks if point on curve. If not, they have a problem
->For now our code is written as if IsInSubgroup() does not check if point is on curve but is set up to be easily changed
*/
Expand All @@ -53,12 +53,12 @@ const (
scalarSize = 32
)

func bytesToBLS12381Field(b []byte) (BLS12381fp.Element, error) {
func bytesToBLS12381Field(b []byte) (bls12381fp.Element, error) {
intRepresentation := new(big.Int).SetBytes(b)
if intRepresentation.Cmp(BLS12381fp.Modulus()) >= 0 {
return BLS12381fp.Element{}, errors.New("Field element larger than modulus")
if intRepresentation.Cmp(bls12381fp.Modulus()) >= 0 {
return bls12381fp.Element{}, errors.New("Field element larger than modulus")
}
return *new(BLS12381fp.Element).SetBigInt(intRepresentation), nil
return *new(bls12381fp.Element).SetBigInt(intRepresentation), nil
}

func bytesToBLS12381G1(b []byte, checkCurve bool) (bls12381.G1Affine, error) {
Expand Down Expand Up @@ -189,14 +189,10 @@ func opBLS12381G1Add(cx *EvalContext) error {
if err != nil {
return err
}
// Would be slightly more efficient to use global variable instead of constantly creating new points
// But would mess with parallelization
res := new(bls12381.G1Affine).Add(&a, &b)
// It's possible it's more efficient to only check if the sum is on the curve as opposed to the summands,
// but I doubt that's safe
resBytes := bls12381G1ToBytes(res)
cx.stack = cx.stack[:last]
cx.stack[prev].Bytes = resBytes
cx.stack[prev].Bytes = bls12381G1ToBytes(a.Add(&a, &b))
return nil
}

Expand All @@ -213,10 +209,9 @@ func opBLS12381G2Add(cx *EvalContext) error {
if err != nil {
return err
}
res := new(bls12381.G2Affine).Add(&a, &b)
resBytes := bls12381G2ToBytes(res)

cx.stack = cx.stack[:last]
cx.stack[prev].Bytes = resBytes
cx.stack[prev].Bytes = bls12381G2ToBytes(a.Add(&a, &b))
return nil
}

Expand All @@ -233,10 +228,9 @@ func opBLS12381G1ScalarMul(cx *EvalContext) error {
return err
}
k := new(big.Int).SetBytes(cx.stack[last].Bytes)
res := new(bls12381.G1Affine).ScalarMultiplication(&a, k)
resBytes := bls12381G1ToBytes(res)
a.ScalarMultiplication(&a, k)
cx.stack = cx.stack[:last]
cx.stack[prev].Bytes = resBytes
cx.stack[prev].Bytes = bls12381G1ToBytes(&a)
return nil
}

Expand All @@ -249,10 +243,8 @@ func opBLS12381G2ScalarMul(cx *EvalContext) error {
return err
}
k := new(big.Int).SetBytes(cx.stack[last].Bytes)
res := new(bls12381.G2Affine).ScalarMultiplication(&a, k)
resBytes := bls12381G2ToBytes(res)
cx.stack = cx.stack[:last]
cx.stack[prev].Bytes = resBytes
cx.stack[prev].Bytes = bls12381G2ToBytes(a.ScalarMultiplication(&a, k))
return nil
}

Expand All @@ -272,11 +264,12 @@ func opBLS12381Pairing(cx *EvalContext) error {
return err
}
ok, err := bls12381.PairingCheck(g1, g2)
if err != nil {
return err
}
cx.stack = cx.stack[:last]
cx.stack[prev].Uint = boolToUint(ok)
cx.stack[prev].Bytes = nil
// I'm assuming it's significantly more likely that err is nil than not
return err
cx.stack[prev] = boolToSV(ok)
return nil
}

// Input: Top of stack is slice of k scalars, second to top is slice of k G1 points as uncompressed bytes
Expand All @@ -286,15 +279,15 @@ func opBLS12381G1MultiExponentiation(cx *EvalContext) error {
prev := last - 1
g1Bytes := cx.stack[prev].Bytes
scalarBytes := cx.stack[last].Bytes
// Precompile does not list subgroup check as mandatory for multiexponentiation, but should ask Fabris about this
// Precompile does not list subgroup check as mandatory for multiexponentiation, but should ask Fabrice about this
g1Points, err := bytesToBLS12381G1s(g1Bytes, false)
if err != nil {
return err
}
if len(scalarBytes)%scalarSize != 0 || len(scalarBytes)/scalarSize != len(g1Points) {
return errors.New("Bad input")
}
scalars := make([]BLS12381fr.Element, len(g1Points))
scalars := make([]bls12381fr.Element, len(g1Points))
for i := 0; i < len(g1Points); i++ {
scalars[i].SetBytes(scalarBytes[i*scalarSize : (i+1)*scalarSize])
}
Expand All @@ -316,7 +309,7 @@ func opBLS12381G2MultiExponentiation(cx *EvalContext) error {
if len(scalarBytes)%scalarSize != 0 || len(scalarBytes)/scalarSize != len(g2Points) {
return errors.New("Bad input")
}
scalars := make([]BLS12381fr.Element, len(g2Points))
scalars := make([]bls12381fr.Element, len(g2Points))
for i := 0; i < len(g2Points); i++ {
scalars[i].SetBytes(scalarBytes[i*scalarSize : (i+1)*scalarSize])
}
Expand Down Expand Up @@ -374,9 +367,8 @@ func opBLS12381G1SubgroupCheck(cx *EvalContext) error {
if err != nil {
return err
}
cx.stack[last].Uint = boolToUint(point.IsInSubGroup())
cx.stack[last].Bytes = nil
return err
cx.stack[last] = boolToSV(point.IsInSubGroup())
return nil
}

func opBLS12381G2SubgroupCheck(cx *EvalContext) error {
Expand All @@ -386,17 +378,16 @@ func opBLS12381G2SubgroupCheck(cx *EvalContext) error {
if err != nil {
return err
}
cx.stack[last].Uint = boolToUint(point.IsInSubGroup())
cx.stack[last].Bytes = nil
return err
cx.stack[last] = boolToSV(point.IsInSubGroup())
return nil
}

func bytesToBN254Field(b []byte) (BN254fp.Element, error) {
func bytesToBN254Field(b []byte) (bn254fp.Element, error) {
intRepresentation := new(big.Int).SetBytes(b)
if intRepresentation.Cmp(BN254fp.Modulus()) >= 0 {
return BN254fp.Element{}, errors.New("Field element larger than modulus")
if intRepresentation.Cmp(bn254fp.Modulus()) >= 0 {
return bn254fp.Element{}, errors.New("Field element larger than modulus")
}
return *new(BN254fp.Element).SetBigInt(intRepresentation), nil
return *new(bn254fp.Element).SetBigInt(intRepresentation), nil
}

func bytesToBN254G1(b []byte, checkCurve bool) (bn254.G1Affine, error) {
Expand Down Expand Up @@ -523,10 +514,8 @@ func opBN254G1Add(cx *EvalContext) error {
if err != nil {
return err
}
res := new(bn254.G1Affine).Add(&a, &b)
resBytes := bn254G1ToBytes(res)
cx.stack = cx.stack[:last]
cx.stack[prev].Bytes = resBytes
cx.stack[prev].Bytes = bn254G1ToBytes(a.Add(&a, &b))
return nil
}

Expand All @@ -543,10 +532,8 @@ func opBN254G2Add(cx *EvalContext) error {
if err != nil {
return err
}
res := new(bn254.G2Affine).Add(&a, &b)
resBytes := bn254G2ToBytes(res)
cx.stack = cx.stack[:last]
cx.stack[prev].Bytes = resBytes
cx.stack[prev].Bytes = bn254G2ToBytes(a.Add(&a, &b))
return nil
}

Expand All @@ -559,10 +546,8 @@ func opBN254G1ScalarMul(cx *EvalContext) error {
return err
}
k := new(big.Int).SetBytes(cx.stack[last].Bytes)
res := new(bn254.G1Affine).ScalarMultiplication(&a, k)
resBytes := bn254G1ToBytes(res)
cx.stack = cx.stack[:last]
cx.stack[prev].Bytes = resBytes
cx.stack[prev].Bytes = bn254G1ToBytes(a.ScalarMultiplication(&a, k))
return nil
}

Expand All @@ -575,10 +560,8 @@ func opBN254G2ScalarMul(cx *EvalContext) error {
return err
}
k := new(big.Int).SetBytes(cx.stack[last].Bytes)
res := new(bn254.G2Affine).ScalarMultiplication(&a, k)
resBytes := bn254G2ToBytes(res)
cx.stack = cx.stack[:last]
cx.stack[prev].Bytes = resBytes
cx.stack[prev].Bytes = bn254G2ToBytes(a.ScalarMultiplication(&a, k))
return nil
}

Expand All @@ -596,9 +579,12 @@ func opBN254Pairing(cx *EvalContext) error {
return err
}
ok, err := bn254.PairingCheck(g1, g2)
if err != nil {
return err
}
cx.stack[prev] = boolToSV(ok)
cx.stack = cx.stack[:last]
return err
return nil
}

func opBN254G1MultiExponentiation(cx *EvalContext) error {
Expand All @@ -613,7 +599,7 @@ func opBN254G1MultiExponentiation(cx *EvalContext) error {
if len(scalarBytes)%scalarSize != 0 || len(scalarBytes)/scalarSize != len(g1Points) {
return errors.New("Bad input")
}
scalars := make([]BN254fr.Element, len(g1Points))
scalars := make([]bn254fr.Element, len(g1Points))
for i := 0; i < len(g1Points); i++ {
scalars[i].SetBytes(scalarBytes[i*scalarSize : (i+1)*scalarSize])
}
Expand All @@ -635,7 +621,7 @@ func opBN254G2MultiExponentiation(cx *EvalContext) error {
if len(scalarBytes)%scalarSize != 0 || len(scalarBytes)/scalarSize != len(g2Points) {
return errors.New("Bad input")
}
scalars := make([]BN254fr.Element, len(g2Points))
scalars := make([]bn254fr.Element, len(g2Points))
for i := 0; i < len(g2Points); i++ {
scalars[i].SetBytes(scalarBytes[i*scalarSize : (i+1)*scalarSize])
}
Expand Down Expand Up @@ -689,9 +675,8 @@ func opBN254G1SubgroupCheck(cx *EvalContext) error {
if err != nil {
return err
}
cx.stack[last].Uint = boolToUint(point.IsInSubGroup())
cx.stack[last].Bytes = nil
return err
cx.stack[last] = boolToSV(point.IsInSubGroup())
return nil
}

func opBN254G2SubgroupCheck(cx *EvalContext) error {
Expand All @@ -701,7 +686,6 @@ func opBN254G2SubgroupCheck(cx *EvalContext) error {
if err != nil {
return err
}
cx.stack[last].Uint = boolToUint(point.IsInSubGroup())
cx.stack[last].Bytes = nil
return err
cx.stack[last] = boolToSV(point.IsInSubGroup())
return nil
}

0 comments on commit 0503c64

Please sign in to comment.