diff --git a/ecc/bls12-377/fr/fft/domain.go b/ecc/bls12-377/fr/fft/domain.go index cca496b23..7577eba2b 100644 --- a/ecc/bls12-377/fr/fft/domain.go +++ b/ecc/bls12-377/fr/fft/domain.go @@ -17,6 +17,7 @@ package fft import ( + "errors" "io" "math/big" "math/bits" @@ -41,28 +42,33 @@ type Domain struct { FrMultiplicativeGen fr.Element // generator of Fr* FrMultiplicativeGenInv fr.Element + // this is set with the WithoutPrecompute option; + // if true, the domain does some pre-computation and stores it. + // if false, the FFT will compute the twiddles on the fly (this is less CPU efficient, but uses less memory) + withPrecompute bool + // the following slices are not serialized and are (re)computed through domain.preComputeTwiddles() - // Twiddles factor for the FFT using Generator for each stage of the recursive FFT - Twiddles [][]fr.Element + // twiddles factor for the FFT using Generator for each stage of the recursive FFT + twiddles [][]fr.Element - // Twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT - TwiddlesInv [][]fr.Element + // twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT + twiddlesInv [][]fr.Element // we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover - // CosetTable u*<1,g,..,g^(n-1)> - CosetTable []fr.Element + // cosetTable u*<1,g,..,g^(n-1)> + cosetTable []fr.Element - // CosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j - CosetTableInv []fr.Element + // cosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j + cosetTableInv []fr.Element } // NewDomain returns a subgroup with a power of 2 cardinality // cardinality >= m // shift: when specified, it's the element by which the set of root of unity is shifted. -func NewDomain(m uint64, shift ...fr.Element) *Domain { - +func NewDomain(m uint64, opts ...DomainOption) *Domain { + opt := domainOptions(opts...) domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) @@ -71,8 +77,8 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { domain.FrMultiplicativeGen.SetUint64(22) - if len(shift) != 0 { - domain.FrMultiplicativeGen.Set(&shift[0]) + if opt.shift != nil { + domain.FrMultiplicativeGen.Set(opt.shift) } domain.FrMultiplicativeGenInv.Inverse(&domain.FrMultiplicativeGen) @@ -85,7 +91,10 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { domain.CardinalityInv.SetUint64(uint64(x)).Inverse(&domain.CardinalityInv) // twiddle factors - domain.preComputeTwiddles() + domain.withPrecompute = opt.withPrecompute + if domain.withPrecompute { + domain.preComputeTwiddles() + } return domain } @@ -96,54 +105,104 @@ func Generator(m uint64) (fr.Element, error) { return fr.Generator(m) } +// Twiddles returns the twiddles factor for the FFT using Generator for each stage of the recursive FFT +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) Twiddles() ([][]fr.Element, error) { + if d.twiddles == nil { + return nil, errors.New("twiddles not precomputed") + } + return d.twiddles, nil +} + +// TwiddlesInv returns the twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) TwiddlesInv() ([][]fr.Element, error) { + if d.twiddlesInv == nil { + return nil, errors.New("twiddles not precomputed") + } + return d.twiddlesInv, nil +} + +// CosetTable returns the cosetTable u*<1,g,..,g^(n-1)> +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) CosetTable() ([]fr.Element, error) { + if d.cosetTable == nil { + return nil, errors.New("cosetTable not precomputed") + } + return d.cosetTable, nil +} + +// CosetTableInv returns the cosetTableInv u*<1,g,..,g^(n-1)> +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) CosetTableInv() ([]fr.Element, error) { + if d.cosetTableInv == nil { + return nil, errors.New("cosetTableInv not precomputed") + } + return d.cosetTableInv, nil +} + func (d *Domain) preComputeTwiddles() { // nb fft stages nbStages := uint64(bits.TrailingZeros64(d.Cardinality)) - d.Twiddles = make([][]fr.Element, nbStages) - d.TwiddlesInv = make([][]fr.Element, nbStages) - d.CosetTable = make([]fr.Element, d.Cardinality) - d.CosetTableInv = make([]fr.Element, d.Cardinality) + d.twiddles = make([][]fr.Element, nbStages) + d.twiddlesInv = make([][]fr.Element, nbStages) + d.cosetTable = make([]fr.Element, d.Cardinality) + d.cosetTableInv = make([]fr.Element, d.Cardinality) var wg sync.WaitGroup - // for each fft stage, we pre compute the twiddle factors - twiddles := func(t [][]fr.Element, omega fr.Element) { - for i := uint64(0); i < nbStages; i++ { - t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1))) - var w fr.Element - if i == 0 { - w = omega - } else { - w = t[i-1][2] - } - t[i][0] = fr.One() - t[i][1] = w - for j := 2; j < len(t[i]); j++ { - t[i][j].Mul(&t[i][j-1], &w) - } - } - wg.Done() - } - expTable := func(sqrt fr.Element, t []fr.Element) { - t[0] = fr.One() - precomputeExpTable(sqrt, t) + BuildExpTable(sqrt, t) wg.Done() } wg.Add(4) - go twiddles(d.Twiddles, d.Generator) - go twiddles(d.TwiddlesInv, d.GeneratorInv) - go expTable(d.FrMultiplicativeGen, d.CosetTable) - go expTable(d.FrMultiplicativeGenInv, d.CosetTableInv) + go func() { + buildTwiddles(d.twiddles, d.Generator, nbStages) + wg.Done() + }() + go func() { + buildTwiddles(d.twiddlesInv, d.GeneratorInv, nbStages) + wg.Done() + }() + go expTable(d.FrMultiplicativeGen, d.cosetTable) + go expTable(d.FrMultiplicativeGenInv, d.cosetTableInv) wg.Wait() } -func precomputeExpTable(w fr.Element, table []fr.Element) { +func buildTwiddles(t [][]fr.Element, omega fr.Element, nbStages uint64) { + if nbStages == 0 { + return + } + if len(t) != int(nbStages) { + panic("invalid twiddle table") + } + // we just compute the first stage + t[0] = make([]fr.Element, 1+(1<<(nbStages-1))) + BuildExpTable(omega, t[0]) + + // for the next stages, we just iterate on the first stage with larger stride + for i := uint64(1); i < nbStages; i++ { + t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1))) + k := 0 + for j := 0; j < len(t[i]); j++ { + t[i][j] = t[0][k] + k += 1 << i + } + } + +} + +// BuildExpTable precomputes the first n powers of w in parallel +// table[0] = w^0 +// table[1] = w^1 +// ... +func BuildExpTable(w fr.Element, table []fr.Element) { + table[0].SetOne() n := len(table) // see if it makes sense to parallelize exp tables pre-computation @@ -153,6 +212,7 @@ func precomputeExpTable(w fr.Element, table []fr.Element) { } // this ratio roughly correspond to the number of multiplication one can do in place of a Exp operation + // TODO @gbotrel revisit this; Exps in this context will be by a "small power of 2" so faster than this ref ratio. const ratioExpMul = 6000 / 17 if interval < ratioExpMul { @@ -194,7 +254,7 @@ func (d *Domain) WriteTo(w io.Writer) (int64, error) { enc := curve.NewEncoder(w) - toEncode := []interface{}{d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} + toEncode := []interface{}{d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute} for _, v := range toEncode { if err := enc.Encode(v); err != nil { @@ -210,7 +270,7 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { dec := curve.NewDecoder(r) - toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} + toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute} for _, v := range toDecode { if err := dec.Decode(v); err != nil { @@ -218,34 +278,9 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { } } - // twiddle factors - d.preComputeTwiddles() - - return dec.BytesRead(), nil -} - -// AsyncReadFrom attempts to decode a domain from Reader. It returns a channel that will be closed -// when the precomputation is done. -func (d *Domain) AsyncReadFrom(r io.Reader) (int64, error, chan struct{}) { - - dec := curve.NewDecoder(r) - - toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err, nil - } - } - - chDone := make(chan struct{}) - - go func() { - // twiddle factors + if d.withPrecompute { d.preComputeTwiddles() + } - close(chDone) - }() - - return dec.BytesRead(), nil, chDone + return dec.BytesRead(), nil } diff --git a/ecc/bls12-377/fr/fft/fft.go b/ecc/bls12-377/fr/fft/fft.go index 20cafffd7..0e181cfc4 100644 --- a/ecc/bls12-377/fr/fft/fft.go +++ b/ecc/bls12-377/fr/fft/fft.go @@ -19,6 +19,7 @@ package fft import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" "math/bits" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" @@ -40,41 +41,71 @@ const butterflyThreshold = 16 // if decimation == DIF (decimation in frequency), the output will be in bit-reversed order func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) { - opt := options(opts...) + opt := fftOptions(opts...) + + // find the stage where we should stop spawning go routines in our recursive calls + // (ie when we have as many go routines running as we have available CPUs) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) + if opt.nbTasks == 1 { + maxSplits = -1 + } // if coset != 0, scale by coset table if opt.coset { if decimation == DIT { // scale by coset table (in bit reversed order) + cosetTable := domain.cosetTable + if !domain.withPrecompute { + // we need to build the full table or do a bit reverse dance. + cosetTable = make([]fr.Element, len(a)) + BuildExpTable(domain.FrMultiplicativeGen, cosetTable) + } parallel.Execute(len(a), func(start, end int) { n := uint64(len(a)) nn := uint64(64 - bits.TrailingZeros64(n)) for i := start; i < end; i++ { irev := int(bits.Reverse64(uint64(i)) >> nn) - a[i].Mul(&a[i], &domain.CosetTable[irev]) + a[i].Mul(&a[i], &cosetTable[irev]) } }, opt.nbTasks) } else { - parallel.Execute(len(a), func(start, end int) { - for i := start; i < end; i++ { - a[i].Mul(&a[i], &domain.CosetTable[i]) - } - }, opt.nbTasks) + if domain.withPrecompute { + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &domain.cosetTable[i]) + } + }, opt.nbTasks) + } else { + c := domain.FrMultiplicativeGen + parallel.Execute(len(a), func(start, end int) { + var at fr.Element + at.Exp(c, big.NewInt(int64(start))) + for i := start; i < end; i++ { + a[i].Mul(&a[i], &at) + at.Mul(&at, &c) + } + }, opt.nbTasks) + } + } } - // find the stage where we should stop spawning go routines in our recursive calls - // (ie when we have as many go routines running as we have available CPUs) - maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) - if opt.nbTasks == 1 { - maxSplits = -1 + twiddles := domain.twiddles + twiddlesStartStage := 0 + if !domain.withPrecompute { + twiddlesStartStage = 3 + nbStages := int(bits.TrailingZeros64(domain.Cardinality)) + twiddles = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1<> nn) - a[i].Mul(&a[i], &domain.CosetTableInv[irev]). + a[i].Mul(&a[i], &cosetTableInv[irev]). Mul(&a[i], &domain.CardinalityInv) } }, opt.nbTasks) } -func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}, nbTasks int) { +func difFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { if chDone != nil { defer close(chDone) } @@ -144,29 +206,38 @@ func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon n := len(a) if n == 1 { return - } else if n == 8 { - kerDIF8(a, twiddles, stage) + } else if n == 256 && stage >= twiddlesStartStage { + kerDIFNP_256(a, twiddles, stage-twiddlesStartStage) return } m := n >> 1 - // if stage < maxSplits, we parallelize this butterfly - // but we have only numCPU / stage cpus available - if (m > butterflyThreshold) && (stage < maxSplits) { - // 1 << stage == estimated used CPUs - numCPU := nbTasks / (1 << (stage)) - parallel.Execute(m, func(start, end int) { - for i := start; i < end; i++ { - fr.Butterfly(&a[i], &a[i+m]) - a[i+m].Mul(&a[i+m], &twiddles[stage][i]) - } - }, numCPU) + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + var at fr.Element + at.Exp(w, big.NewInt(int64(start))) + innerDIFWithoutTwiddles(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDIFWithoutTwiddles(a, w, w, 0, m, m) + } + // compute next twiddle + w.Square(&w) } else { - // i == 0 - fr.Butterfly(&a[0], &a[m]) - for i := 1; i < m; i++ { - fr.Butterfly(&a[i], &a[i+m]) - a[i+m].Mul(&a[i+m], &twiddles[stage][i]) + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks/(1<<(stage))) + } else { + innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) } } @@ -177,104 +248,170 @@ func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon nextStage := stage + 1 if stage < maxSplits { chDone := make(chan struct{}, 1) - go difFFT(a[m:n], twiddles, nextStage, maxSplits, chDone, nbTasks) - difFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) + go difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) <-chDone } else { - difFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) - difFFT(a[m:n], twiddles, nextStage, maxSplits, nil, nbTasks) + difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) } } -func ditFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}, nbTasks int) { +func innerDIFWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fr.Butterfly(&a[i], &a[i+m]) + a[i+m].Mul(&a[i+m], &twiddles[i]) + } +} + +func innerDIFWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fr.Butterfly(&a[i], &a[i+m]) + a[i+m].Mul(&a[i+m], &at) + at.Mul(&at, &w) + } +} + +func ditFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { if chDone != nil { defer close(chDone) } n := len(a) if n == 1 { return - } else if n == 8 { - kerDIT8(a, twiddles, stage) + } else if n == 256 && stage >= twiddlesStartStage { + kerDITNP_256(a, twiddles, stage-twiddlesStartStage) return } m := n >> 1 nextStage := stage + 1 + nextW := w + nextW.Square(&nextW) if stage < maxSplits { // that's the only time we fire go routines chDone := make(chan struct{}, 1) - go ditFFT(a[m:], twiddles, nextStage, maxSplits, chDone, nbTasks) - ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) + go ditFFT(a[m:], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + ditFFT(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) <-chDone } else { - ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) - ditFFT(a[m:n], twiddles, nextStage, maxSplits, nil, nbTasks) - + ditFFT(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + ditFFT(a[m:n], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) } - // if stage < maxSplits, we parallelize this butterfly - // but we have only numCPU / stage cpus available - if (m > butterflyThreshold) && (stage < maxSplits) { - // 1 << stage == estimated used CPUs - numCPU := nbTasks / (1 << (stage)) - parallel.Execute(m, func(start, end int) { - for k := start; k < end; k++ { - a[k+m].Mul(&a[k+m], &twiddles[stage][k]) - fr.Butterfly(&a[k], &a[k+m]) - } - }, numCPU) + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + // we need to compute the twiddles for this stage on the fly. + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + var at fr.Element + at.Exp(w, big.NewInt(int64(start))) + innerDITWithoutTwiddles(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDITWithoutTwiddles(a, w, w, 0, m, m) + } + return + } + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks/(1<<(stage))) } else { + innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) + } +} + +func innerDITWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { + if start == 0 { fr.Butterfly(&a[0], &a[m]) - for k := 1; k < m; k++ { - a[k+m].Mul(&a[k+m], &twiddles[stage][k]) - fr.Butterfly(&a[k], &a[k+m]) - } + start++ + } + for i := start; i < end; i++ { + a[i+m].Mul(&a[i+m], &twiddles[i]) + fr.Butterfly(&a[i], &a[i+m]) } } -// kerDIT8 is a kernel that process a FFT of size 8 -func kerDIT8(a []fr.Element, twiddles [][]fr.Element, stage int) { - - fr.Butterfly(&a[0], &a[1]) - fr.Butterfly(&a[2], &a[3]) - fr.Butterfly(&a[4], &a[5]) - fr.Butterfly(&a[6], &a[7]) - fr.Butterfly(&a[0], &a[2]) - a[3].Mul(&a[3], &twiddles[stage+1][1]) - fr.Butterfly(&a[1], &a[3]) - fr.Butterfly(&a[4], &a[6]) - a[7].Mul(&a[7], &twiddles[stage+1][1]) - fr.Butterfly(&a[5], &a[7]) - fr.Butterfly(&a[0], &a[4]) - a[5].Mul(&a[5], &twiddles[stage+0][1]) - fr.Butterfly(&a[1], &a[5]) - a[6].Mul(&a[6], &twiddles[stage+0][2]) - fr.Butterfly(&a[2], &a[6]) - a[7].Mul(&a[7], &twiddles[stage+0][3]) - fr.Butterfly(&a[3], &a[7]) +func innerDITWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + a[i+m].Mul(&a[i+m], &at) + fr.Butterfly(&a[i], &a[i+m]) + at.Mul(&at, &w) + } } -// kerDIF8 is a kernel that process a FFT of size 8 -func kerDIF8(a []fr.Element, twiddles [][]fr.Element, stage int) { - - fr.Butterfly(&a[0], &a[4]) - fr.Butterfly(&a[1], &a[5]) - fr.Butterfly(&a[2], &a[6]) - fr.Butterfly(&a[3], &a[7]) - a[5].Mul(&a[5], &twiddles[stage+0][1]) - a[6].Mul(&a[6], &twiddles[stage+0][2]) - a[7].Mul(&a[7], &twiddles[stage+0][3]) - fr.Butterfly(&a[0], &a[2]) - fr.Butterfly(&a[1], &a[3]) - fr.Butterfly(&a[4], &a[6]) - fr.Butterfly(&a[5], &a[7]) - a[3].Mul(&a[3], &twiddles[stage+1][1]) - a[7].Mul(&a[7], &twiddles[stage+1][1]) - fr.Butterfly(&a[0], &a[1]) - fr.Butterfly(&a[2], &a[3]) - fr.Butterfly(&a[4], &a[5]) - fr.Butterfly(&a[6], &a[7]) +func kerDIFNP_256(a []fr.Element, twiddles [][]fr.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + + innerDIFWithTwiddles(a[:256], twiddles[stage+0], 0, 128, 128) + for offset := 0; offset < 256; offset += 128 { + innerDIFWithTwiddles(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + for offset := 0; offset < 256; offset += 64 { + innerDIFWithTwiddles(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 32 { + innerDIFWithTwiddles(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 16 { + innerDIFWithTwiddles(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 8 { + innerDIFWithTwiddles(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 4 { + innerDIFWithTwiddles(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 2 { + fr.Butterfly(&a[offset], &a[offset+1]) + } +} + +func kerDITNP_256(a []fr.Element, twiddles [][]fr.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + + for offset := 0; offset < 256; offset += 2 { + fr.Butterfly(&a[offset], &a[offset+1]) + } + for offset := 0; offset < 256; offset += 4 { + innerDITWithTwiddles(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 8 { + innerDITWithTwiddles(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 16 { + innerDITWithTwiddles(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 32 { + innerDITWithTwiddles(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 64 { + innerDITWithTwiddles(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 128 { + innerDITWithTwiddles(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + innerDITWithTwiddles(a[:256], twiddles[stage+0], 0, 128, 128) } diff --git a/ecc/bls12-377/fr/fft/fft_test.go b/ecc/bls12-377/fr/fft/fft_test.go index a443204c1..2fbc579e2 100644 --- a/ecc/bls12-377/fr/fft/fft_test.go +++ b/ecc/bls12-377/fr/fft/fft_test.go @@ -33,208 +33,217 @@ func TestFFT(t *testing.T) { nbCosets := 3 domainWithPrecompute := NewDomain(maxSize) + domainWithoutPrecompute := NewDomain(maxSize, WithoutPrecompute()) parameters := gopter.DefaultTestParameters() parameters.MinSuccessfulTests = 5 properties := gopter.NewProperties(parameters) - properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( + for domainName, domain := range map[string]*Domain{ + "with precompute": domainWithPrecompute, + "without precompute": domainWithoutPrecompute, + } { + domainName := domainName + domain := domain + t.Logf("domain: %s", domainName) + properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFT(pol, DIF) - BitReverse(pol) + domain.FFT(pol, DIF) + BitReverse(pol) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( + properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFT(pol, DIF, OnCoset()) - BitReverse(pol) + domain.FFT(pol, DIF, OnCoset()) + BitReverse(pol) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))). - Mul(&sample, &domainWithPrecompute.FrMultiplicativeGen) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))). + Mul(&sample, &domain.FrMultiplicativeGen) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( + properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT) + BitReverse(pol) + domain.FFT(pol, DIT) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT) - domainWithPrecompute.FFTInverse(pol, DIF) - BitReverse(pol) + BitReverse(pol) + domain.FFT(pol, DIT) + domain.FFTInverse(pol, DIF) + BitReverse(pol) - check := true - for i := 0; i < len(pol); i++ { - check = check && pol[i].Equal(&backupPol[i]) - } - return check - }, - )) + check := true + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } + return check + }, + )) - properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on cosets", prop.ForAll( + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on cosets", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - check := true + check := true - for i := 1; i <= nbCosets; i++ { + for i := 1; i <= nbCosets; i++ { - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT, OnCoset()) - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset()) - BitReverse(pol) + BitReverse(pol) + domain.FFT(pol, DIT, OnCoset()) + domain.FFTInverse(pol, DIF, OnCoset()) + BitReverse(pol) - for i := 0; i < len(pol); i++ { - check = check && pol[i].Equal(&backupPol[i]) + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } } - } - return check - }, - )) + return check + }, + )) - properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( + properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFTInverse(pol, DIF) - domainWithPrecompute.FFT(pol, DIT) + domain.FFTInverse(pol, DIF) + domain.FFT(pol, DIT) - check := true - for i := 0; i < len(pol); i++ { - check = check && (pol[i] == backupPol[i]) - } - return check - }, - )) + check := true + for i := 0; i < len(pol); i++ { + check = check && (pol[i] == backupPol[i]) + } + return check + }, + )) - properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( + properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset()) - domainWithPrecompute.FFT(pol, DIT, OnCoset()) + domain.FFTInverse(pol, DIF, OnCoset()) + domain.FFT(pol, DIT, OnCoset()) - for i := 0; i < len(pol); i++ { - if !(pol[i].Equal(&backupPol[i])) { - return false + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } } - } - // compute with nbTasks == 1 - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset(), WithNbTasks(1)) - domainWithPrecompute.FFT(pol, DIT, OnCoset(), WithNbTasks(1)) + // compute with nbTasks == 1 + domain.FFTInverse(pol, DIF, OnCoset(), WithNbTasks(1)) + domain.FFT(pol, DIT, OnCoset(), WithNbTasks(1)) - for i := 0; i < len(pol); i++ { - if !(pol[i].Equal(&backupPol[i])) { - return false + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } } - } - return true - }, - )) + return true + }, + )) - properties.TestingRun(t, gopter.ConsoleReporter(false)) + properties.TestingRun(t, gopter.ConsoleReporter(false)) + } } diff --git a/ecc/bls12-377/fr/fft/options.go b/ecc/bls12-377/fr/fft/options.go index 02a6000e5..e86234633 100644 --- a/ecc/bls12-377/fr/fft/options.go +++ b/ecc/bls12-377/fr/fft/options.go @@ -16,7 +16,11 @@ package fft -import "runtime" +import ( + "runtime" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" +) // Option defines option for altering the behavior of FFT methods. // See the descriptions of functions returning instances of this type for @@ -48,7 +52,7 @@ func WithNbTasks(nbTasks int) Option { } // default options -func options(opts ...Option) fftConfig { +func fftOptions(opts ...Option) fftConfig { // apply options opt := fftConfig{ coset: false, @@ -59,3 +63,41 @@ func options(opts ...Option) fftConfig { } return opt } + +// DomainOption defines option for altering the definition of the FFT domain +// See the descriptions of functions returning instances of this type for +// particular options. +type DomainOption func(*domainConfig) + +type domainConfig struct { + shift *fr.Element + withPrecompute bool +} + +// WithShift sets the FrMultiplicativeGen of the domain. +// Default is generator of the largest 2-adic subgroup. +func WithShift(shift fr.Element) DomainOption { + return func(opt *domainConfig) { + opt.shift = new(fr.Element).Set(&shift) + } +} + +// WithoutPrecompute disables precomputation of twiddles in the domain. +// When this option is set, FFTs will be slower, but will use less memory. +func WithoutPrecompute() DomainOption { + return func(opt *domainConfig) { + opt.withPrecompute = false + } +} + +// default options +func domainOptions(opts ...DomainOption) domainConfig { + // apply options + opt := domainConfig{ + withPrecompute: true, + } + for _, option := range opts { + option(&opt) + } + return opt +} diff --git a/ecc/bls12-377/fr/iop/ratios.go b/ecc/bls12-377/fr/iop/ratios.go index 3027260bb..1a98136fa 100644 --- a/ecc/bls12-377/fr/iop/ratios.go +++ b/ecc/bls12-377/fr/iop/ratios.go @@ -336,20 +336,12 @@ func getSupportIdentityPermutation(nbCopies int, domain *fft.Domain) []fr.Elemen res := make([]fr.Element, uint64(nbCopies)*domain.Cardinality) sizePoly := int(domain.Cardinality) - // check if we can reuse the pre-computed twiddles from the domain. - if len(domain.Twiddles) == 0 || len(domain.Twiddles[0]) < sizePoly/2 { - res[0].SetOne() - if len(res) > 1 { - res[1].Set(&domain.Generator) - for i := 2; i < len(res); i++ { - res[i].Mul(&res[i-1], &domain.Generator) - } - } - } else { - // re-use pre-computed twiddles from the domain. - copy(res, domain.Twiddles[0]) - for i := (sizePoly / 2) - 1; i < sizePoly-1; i++ { - res[i+1].Mul(&res[i], &domain.Generator) + // TODO @gbotrel check if we can reuse the pre-computed twiddles from the domain. + res[0].SetOne() + if len(res) > 1 { + res[1].Set(&domain.Generator) + for i := 2; i < len(res); i++ { + res[i].Mul(&res[i-1], &domain.Generator) } } @@ -362,15 +354,7 @@ func getSupportIdentityPermutation(nbCopies int, domain *fft.Domain) []fr.Elemen i := i var coset fr.Element - if i == 1 { - coset = domain.FrMultiplicativeGen - } else { - if len(domain.CosetTable) > i { - coset = domain.CosetTable[i] - } else { - coset.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(i))) - } - } + coset.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(i))) go func() { parallel.Execute(sizePoly, func(start, end int) { diff --git a/ecc/bls12-377/fr/sis/sis.go b/ecc/bls12-377/fr/sis/sis.go index 164dc1a0e..1279c8145 100644 --- a/ecc/bls12-377/fr/sis/sis.go +++ b/ecc/bls12-377/fr/sis/sis.go @@ -119,7 +119,7 @@ func NewRSis(seed int64, logTwoDegree, logTwoBound, maxNbElementsToHash int) (*R LogTwoBound: logTwoBound, capacity: capacity, Degree: degree, - Domain: fft.NewDomain(uint64(degree), shift), + Domain: fft.NewDomain(uint64(degree), fft.WithShift(shift)), A: make([][]fr.Element, n), Ag: make([][]fr.Element, n), bufM: make(fr.Vector, degree*n), @@ -129,7 +129,7 @@ func NewRSis(seed int64, logTwoDegree, logTwoBound, maxNbElementsToHash int) (*R } if r.LogTwoBound == 8 && r.Degree == 64 { // TODO @gbotrel fixme, that's dirty. - r.twiddleCosets = precomputeTwiddlesCoset(r.Domain.Twiddles, r.Domain.FrMultiplicativeGen) + r.twiddleCosets = precomputeTwiddlesCoset(r.Domain.Generator, r.Domain.FrMultiplicativeGen) } // filling A diff --git a/ecc/bls12-377/fr/sis/sis_fft.go b/ecc/bls12-377/fr/sis/sis_fft.go index 891b7e677..ae351d5ec 100644 --- a/ecc/bls12-377/fr/sis/sis_fft.go +++ b/ecc/bls12-377/fr/sis/sis_fft.go @@ -18,6 +18,7 @@ package sis import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "math/big" ) // fft64 is generated by gnark-crypto and contains the unrolled code for FFT (DIF) on 64 elements @@ -413,82 +414,154 @@ func fft64(a []fr.Element, twiddlesCoset []fr.Element) { // precomputeTwiddlesCoset precomputes twiddlesCoset from twiddles and coset table // it then return all elements in the correct order for the unrolled FFT. -func precomputeTwiddlesCoset(twiddles [][]fr.Element, shifter fr.Element) []fr.Element { - r := make([][]fr.Element, len(twiddles)) - for i := 0; i < len(twiddles); i++ { - r[i] = make([]fr.Element, len(twiddles[i])) - s := shifter - for k := 0; k < i; k++ { - s.Mul(&s, &s) - } - for j := 0; j < len(twiddles[i]); j++ { - r[i][j].Mul(&twiddles[i][j], &s) - } - } - toReturn := make([]fr.Element, 0, 63) +func precomputeTwiddlesCoset(generator, shifter fr.Element) []fr.Element { + toReturn := make([]fr.Element, 63) + var r, s fr.Element + e := new(big.Int) - toReturn = append(toReturn, r[5][0]) - toReturn = append(toReturn, r[4][0]) - toReturn = append(toReturn, r[4][1]) - toReturn = append(toReturn, r[3][0]) - toReturn = append(toReturn, r[3][2]) - toReturn = append(toReturn, r[3][1]) - toReturn = append(toReturn, r[3][3]) - toReturn = append(toReturn, r[2][0]) - toReturn = append(toReturn, r[2][4]) - toReturn = append(toReturn, r[2][2]) - toReturn = append(toReturn, r[2][6]) - toReturn = append(toReturn, r[2][1]) - toReturn = append(toReturn, r[2][5]) - toReturn = append(toReturn, r[2][3]) - toReturn = append(toReturn, r[2][7]) - toReturn = append(toReturn, r[1][0]) - toReturn = append(toReturn, r[1][8]) - toReturn = append(toReturn, r[1][4]) - toReturn = append(toReturn, r[1][12]) - toReturn = append(toReturn, r[1][2]) - toReturn = append(toReturn, r[1][10]) - toReturn = append(toReturn, r[1][6]) - toReturn = append(toReturn, r[1][14]) - toReturn = append(toReturn, r[1][1]) - toReturn = append(toReturn, r[1][9]) - toReturn = append(toReturn, r[1][5]) - toReturn = append(toReturn, r[1][13]) - toReturn = append(toReturn, r[1][3]) - toReturn = append(toReturn, r[1][11]) - toReturn = append(toReturn, r[1][7]) - toReturn = append(toReturn, r[1][15]) - toReturn = append(toReturn, r[0][0]) - toReturn = append(toReturn, r[0][16]) - toReturn = append(toReturn, r[0][8]) - toReturn = append(toReturn, r[0][24]) - toReturn = append(toReturn, r[0][4]) - toReturn = append(toReturn, r[0][20]) - toReturn = append(toReturn, r[0][12]) - toReturn = append(toReturn, r[0][28]) - toReturn = append(toReturn, r[0][2]) - toReturn = append(toReturn, r[0][18]) - toReturn = append(toReturn, r[0][10]) - toReturn = append(toReturn, r[0][26]) - toReturn = append(toReturn, r[0][6]) - toReturn = append(toReturn, r[0][22]) - toReturn = append(toReturn, r[0][14]) - toReturn = append(toReturn, r[0][30]) - toReturn = append(toReturn, r[0][1]) - toReturn = append(toReturn, r[0][17]) - toReturn = append(toReturn, r[0][9]) - toReturn = append(toReturn, r[0][25]) - toReturn = append(toReturn, r[0][5]) - toReturn = append(toReturn, r[0][21]) - toReturn = append(toReturn, r[0][13]) - toReturn = append(toReturn, r[0][29]) - toReturn = append(toReturn, r[0][3]) - toReturn = append(toReturn, r[0][19]) - toReturn = append(toReturn, r[0][11]) - toReturn = append(toReturn, r[0][27]) - toReturn = append(toReturn, r[0][7]) - toReturn = append(toReturn, r[0][23]) - toReturn = append(toReturn, r[0][15]) - toReturn = append(toReturn, r[0][31]) + s = shifter + for k := 0; k < 5; k++ { + s.Square(&s) + } + toReturn[0] = s + s = shifter + for k := 0; k < 4; k++ { + s.Square(&s) + } + toReturn[1] = s + r.Exp(generator, e.SetUint64(uint64(1<<4*1))) + toReturn[2].Mul(&r, &s) + s = shifter + for k := 0; k < 3; k++ { + s.Square(&s) + } + toReturn[3] = s + r.Exp(generator, e.SetUint64(uint64(1<<3*2))) + toReturn[4].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<3*1))) + toReturn[5].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<3*3))) + toReturn[6].Mul(&r, &s) + s = shifter + for k := 0; k < 2; k++ { + s.Square(&s) + } + toReturn[7] = s + r.Exp(generator, e.SetUint64(uint64(1<<2*4))) + toReturn[8].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*2))) + toReturn[9].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*6))) + toReturn[10].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*1))) + toReturn[11].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*5))) + toReturn[12].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*3))) + toReturn[13].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*7))) + toReturn[14].Mul(&r, &s) + s = shifter + for k := 0; k < 1; k++ { + s.Square(&s) + } + toReturn[15] = s + r.Exp(generator, e.SetUint64(uint64(1<<1*8))) + toReturn[16].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*4))) + toReturn[17].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*12))) + toReturn[18].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*2))) + toReturn[19].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*10))) + toReturn[20].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*6))) + toReturn[21].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*14))) + toReturn[22].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*1))) + toReturn[23].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*9))) + toReturn[24].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*5))) + toReturn[25].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*13))) + toReturn[26].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*3))) + toReturn[27].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*11))) + toReturn[28].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*7))) + toReturn[29].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*15))) + toReturn[30].Mul(&r, &s) + s = shifter + for k := 0; k < 0; k++ { + s.Square(&s) + } + toReturn[31] = s + r.Exp(generator, e.SetUint64(uint64(1<<0*16))) + toReturn[32].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*8))) + toReturn[33].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*24))) + toReturn[34].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*4))) + toReturn[35].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*20))) + toReturn[36].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*12))) + toReturn[37].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*28))) + toReturn[38].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*2))) + toReturn[39].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*18))) + toReturn[40].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*10))) + toReturn[41].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*26))) + toReturn[42].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*6))) + toReturn[43].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*22))) + toReturn[44].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*14))) + toReturn[45].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*30))) + toReturn[46].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*1))) + toReturn[47].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*17))) + toReturn[48].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*9))) + toReturn[49].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*25))) + toReturn[50].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*5))) + toReturn[51].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*21))) + toReturn[52].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*13))) + toReturn[53].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*29))) + toReturn[54].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*3))) + toReturn[55].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*19))) + toReturn[56].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*11))) + toReturn[57].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*27))) + toReturn[58].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*7))) + toReturn[59].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*23))) + toReturn[60].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*15))) + toReturn[61].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*31))) + toReturn[62].Mul(&r, &s) return toReturn } diff --git a/ecc/bls12-377/fr/sis/sis_test.go b/ecc/bls12-377/fr/sis/sis_test.go index ddccf9222..3df4ff66a 100644 --- a/ecc/bls12-377/fr/sis/sis_test.go +++ b/ecc/bls12-377/fr/sis/sis_test.go @@ -16,14 +16,14 @@ package sis import ( "bytes" + "crypto/rand" "encoding/binary" "encoding/json" "fmt" "io" - "io/ioutil" "math/big" "math/bits" - "math/rand" + "os" "testing" "time" @@ -69,7 +69,7 @@ func TestReference(t *testing.T) { // read the test case file var testCases TestCases - data, err := ioutil.ReadFile("test_cases.json") + data, err := os.ReadFile("test_cases.json") assert.NoError(err, "reading test cases failed") err = json.Unmarshal(data, &testCases) assert.NoError(err, "reading test cases failed") @@ -142,7 +142,7 @@ func TestMulMod(t *testing.T) { // and random. var shift fr.Element shift.SetString("19540430494807482326159819597004422086093766032135589407132600596362845576832") - domain := fft.NewDomain(uint64(size), shift) + domain := fft.NewDomain(uint64(size), fft.WithShift(shift)) // mul mod domain.FFT(p, fft.DIF, fft.OnCoset()) @@ -391,8 +391,7 @@ func TestLimbDecompositionFastPath(t *testing.T) { nValues := bitset.New(uint(size)) // Generate a random buffer - rand.Seed(time.Now().UnixNano()) //#nosec G404 weak rng is fine here - _, err := rand.Read(buf) //#nosec G404 weak rng is fine here + _, err := rand.Read(buf) //#nosec G404 weak rng is fine here assert.NoError(err) limbDecomposeBytes8_64(buf, m, mValues) @@ -415,7 +414,7 @@ func TestUnrolledFFT(t *testing.T) { const size = 64 assert := require.New(t) - domain := fft.NewDomain(size, shift) + domain := fft.NewDomain(size, fft.WithShift(shift)) k1 := make([]fr.Element, size) for i := 0; i < size; i++ { @@ -428,7 +427,7 @@ func TestUnrolledFFT(t *testing.T) { domain.FFT(k1, fft.DIF, fft.OnCoset(), fft.WithNbTasks(1)) // unrolled FFT - twiddlesCoset := precomputeTwiddlesCoset(domain.Twiddles, domain.FrMultiplicativeGen) + twiddlesCoset := precomputeTwiddlesCoset(domain.Generator, domain.FrMultiplicativeGen) fft64(k2, twiddlesCoset) // compare results diff --git a/ecc/bls12-378/fr/fft/domain.go b/ecc/bls12-378/fr/fft/domain.go index be38f0602..a9c410b48 100644 --- a/ecc/bls12-378/fr/fft/domain.go +++ b/ecc/bls12-378/fr/fft/domain.go @@ -17,6 +17,7 @@ package fft import ( + "errors" "io" "math/big" "math/bits" @@ -41,28 +42,33 @@ type Domain struct { FrMultiplicativeGen fr.Element // generator of Fr* FrMultiplicativeGenInv fr.Element + // this is set with the WithoutPrecompute option; + // if true, the domain does some pre-computation and stores it. + // if false, the FFT will compute the twiddles on the fly (this is less CPU efficient, but uses less memory) + withPrecompute bool + // the following slices are not serialized and are (re)computed through domain.preComputeTwiddles() - // Twiddles factor for the FFT using Generator for each stage of the recursive FFT - Twiddles [][]fr.Element + // twiddles factor for the FFT using Generator for each stage of the recursive FFT + twiddles [][]fr.Element - // Twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT - TwiddlesInv [][]fr.Element + // twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT + twiddlesInv [][]fr.Element // we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover - // CosetTable u*<1,g,..,g^(n-1)> - CosetTable []fr.Element + // cosetTable u*<1,g,..,g^(n-1)> + cosetTable []fr.Element - // CosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j - CosetTableInv []fr.Element + // cosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j + cosetTableInv []fr.Element } // NewDomain returns a subgroup with a power of 2 cardinality // cardinality >= m // shift: when specified, it's the element by which the set of root of unity is shifted. -func NewDomain(m uint64, shift ...fr.Element) *Domain { - +func NewDomain(m uint64, opts ...DomainOption) *Domain { + opt := domainOptions(opts...) domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) @@ -71,8 +77,8 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { domain.FrMultiplicativeGen.SetUint64(22) - if len(shift) != 0 { - domain.FrMultiplicativeGen.Set(&shift[0]) + if opt.shift != nil { + domain.FrMultiplicativeGen.Set(opt.shift) } domain.FrMultiplicativeGenInv.Inverse(&domain.FrMultiplicativeGen) @@ -85,7 +91,10 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { domain.CardinalityInv.SetUint64(uint64(x)).Inverse(&domain.CardinalityInv) // twiddle factors - domain.preComputeTwiddles() + domain.withPrecompute = opt.withPrecompute + if domain.withPrecompute { + domain.preComputeTwiddles() + } return domain } @@ -96,54 +105,104 @@ func Generator(m uint64) (fr.Element, error) { return fr.Generator(m) } +// Twiddles returns the twiddles factor for the FFT using Generator for each stage of the recursive FFT +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) Twiddles() ([][]fr.Element, error) { + if d.twiddles == nil { + return nil, errors.New("twiddles not precomputed") + } + return d.twiddles, nil +} + +// TwiddlesInv returns the twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) TwiddlesInv() ([][]fr.Element, error) { + if d.twiddlesInv == nil { + return nil, errors.New("twiddles not precomputed") + } + return d.twiddlesInv, nil +} + +// CosetTable returns the cosetTable u*<1,g,..,g^(n-1)> +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) CosetTable() ([]fr.Element, error) { + if d.cosetTable == nil { + return nil, errors.New("cosetTable not precomputed") + } + return d.cosetTable, nil +} + +// CosetTableInv returns the cosetTableInv u*<1,g,..,g^(n-1)> +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) CosetTableInv() ([]fr.Element, error) { + if d.cosetTableInv == nil { + return nil, errors.New("cosetTableInv not precomputed") + } + return d.cosetTableInv, nil +} + func (d *Domain) preComputeTwiddles() { // nb fft stages nbStages := uint64(bits.TrailingZeros64(d.Cardinality)) - d.Twiddles = make([][]fr.Element, nbStages) - d.TwiddlesInv = make([][]fr.Element, nbStages) - d.CosetTable = make([]fr.Element, d.Cardinality) - d.CosetTableInv = make([]fr.Element, d.Cardinality) + d.twiddles = make([][]fr.Element, nbStages) + d.twiddlesInv = make([][]fr.Element, nbStages) + d.cosetTable = make([]fr.Element, d.Cardinality) + d.cosetTableInv = make([]fr.Element, d.Cardinality) var wg sync.WaitGroup - // for each fft stage, we pre compute the twiddle factors - twiddles := func(t [][]fr.Element, omega fr.Element) { - for i := uint64(0); i < nbStages; i++ { - t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1))) - var w fr.Element - if i == 0 { - w = omega - } else { - w = t[i-1][2] - } - t[i][0] = fr.One() - t[i][1] = w - for j := 2; j < len(t[i]); j++ { - t[i][j].Mul(&t[i][j-1], &w) - } - } - wg.Done() - } - expTable := func(sqrt fr.Element, t []fr.Element) { - t[0] = fr.One() - precomputeExpTable(sqrt, t) + BuildExpTable(sqrt, t) wg.Done() } wg.Add(4) - go twiddles(d.Twiddles, d.Generator) - go twiddles(d.TwiddlesInv, d.GeneratorInv) - go expTable(d.FrMultiplicativeGen, d.CosetTable) - go expTable(d.FrMultiplicativeGenInv, d.CosetTableInv) + go func() { + buildTwiddles(d.twiddles, d.Generator, nbStages) + wg.Done() + }() + go func() { + buildTwiddles(d.twiddlesInv, d.GeneratorInv, nbStages) + wg.Done() + }() + go expTable(d.FrMultiplicativeGen, d.cosetTable) + go expTable(d.FrMultiplicativeGenInv, d.cosetTableInv) wg.Wait() } -func precomputeExpTable(w fr.Element, table []fr.Element) { +func buildTwiddles(t [][]fr.Element, omega fr.Element, nbStages uint64) { + if nbStages == 0 { + return + } + if len(t) != int(nbStages) { + panic("invalid twiddle table") + } + // we just compute the first stage + t[0] = make([]fr.Element, 1+(1<<(nbStages-1))) + BuildExpTable(omega, t[0]) + + // for the next stages, we just iterate on the first stage with larger stride + for i := uint64(1); i < nbStages; i++ { + t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1))) + k := 0 + for j := 0; j < len(t[i]); j++ { + t[i][j] = t[0][k] + k += 1 << i + } + } + +} + +// BuildExpTable precomputes the first n powers of w in parallel +// table[0] = w^0 +// table[1] = w^1 +// ... +func BuildExpTable(w fr.Element, table []fr.Element) { + table[0].SetOne() n := len(table) // see if it makes sense to parallelize exp tables pre-computation @@ -153,6 +212,7 @@ func precomputeExpTable(w fr.Element, table []fr.Element) { } // this ratio roughly correspond to the number of multiplication one can do in place of a Exp operation + // TODO @gbotrel revisit this; Exps in this context will be by a "small power of 2" so faster than this ref ratio. const ratioExpMul = 6000 / 17 if interval < ratioExpMul { @@ -194,7 +254,7 @@ func (d *Domain) WriteTo(w io.Writer) (int64, error) { enc := curve.NewEncoder(w) - toEncode := []interface{}{d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} + toEncode := []interface{}{d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute} for _, v := range toEncode { if err := enc.Encode(v); err != nil { @@ -210,7 +270,7 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { dec := curve.NewDecoder(r) - toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} + toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute} for _, v := range toDecode { if err := dec.Decode(v); err != nil { @@ -218,34 +278,9 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { } } - // twiddle factors - d.preComputeTwiddles() - - return dec.BytesRead(), nil -} - -// AsyncReadFrom attempts to decode a domain from Reader. It returns a channel that will be closed -// when the precomputation is done. -func (d *Domain) AsyncReadFrom(r io.Reader) (int64, error, chan struct{}) { - - dec := curve.NewDecoder(r) - - toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err, nil - } - } - - chDone := make(chan struct{}) - - go func() { - // twiddle factors + if d.withPrecompute { d.preComputeTwiddles() + } - close(chDone) - }() - - return dec.BytesRead(), nil, chDone + return dec.BytesRead(), nil } diff --git a/ecc/bls12-378/fr/fft/fft.go b/ecc/bls12-378/fr/fft/fft.go index 9f1527360..6db23a3ac 100644 --- a/ecc/bls12-378/fr/fft/fft.go +++ b/ecc/bls12-378/fr/fft/fft.go @@ -19,6 +19,7 @@ package fft import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" "math/bits" "github.com/consensys/gnark-crypto/ecc/bls12-378/fr" @@ -40,41 +41,71 @@ const butterflyThreshold = 16 // if decimation == DIF (decimation in frequency), the output will be in bit-reversed order func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) { - opt := options(opts...) + opt := fftOptions(opts...) + + // find the stage where we should stop spawning go routines in our recursive calls + // (ie when we have as many go routines running as we have available CPUs) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) + if opt.nbTasks == 1 { + maxSplits = -1 + } // if coset != 0, scale by coset table if opt.coset { if decimation == DIT { // scale by coset table (in bit reversed order) + cosetTable := domain.cosetTable + if !domain.withPrecompute { + // we need to build the full table or do a bit reverse dance. + cosetTable = make([]fr.Element, len(a)) + BuildExpTable(domain.FrMultiplicativeGen, cosetTable) + } parallel.Execute(len(a), func(start, end int) { n := uint64(len(a)) nn := uint64(64 - bits.TrailingZeros64(n)) for i := start; i < end; i++ { irev := int(bits.Reverse64(uint64(i)) >> nn) - a[i].Mul(&a[i], &domain.CosetTable[irev]) + a[i].Mul(&a[i], &cosetTable[irev]) } }, opt.nbTasks) } else { - parallel.Execute(len(a), func(start, end int) { - for i := start; i < end; i++ { - a[i].Mul(&a[i], &domain.CosetTable[i]) - } - }, opt.nbTasks) + if domain.withPrecompute { + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &domain.cosetTable[i]) + } + }, opt.nbTasks) + } else { + c := domain.FrMultiplicativeGen + parallel.Execute(len(a), func(start, end int) { + var at fr.Element + at.Exp(c, big.NewInt(int64(start))) + for i := start; i < end; i++ { + a[i].Mul(&a[i], &at) + at.Mul(&at, &c) + } + }, opt.nbTasks) + } + } } - // find the stage where we should stop spawning go routines in our recursive calls - // (ie when we have as many go routines running as we have available CPUs) - maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) - if opt.nbTasks == 1 { - maxSplits = -1 + twiddles := domain.twiddles + twiddlesStartStage := 0 + if !domain.withPrecompute { + twiddlesStartStage = 3 + nbStages := int(bits.TrailingZeros64(domain.Cardinality)) + twiddles = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1<> nn) - a[i].Mul(&a[i], &domain.CosetTableInv[irev]). + a[i].Mul(&a[i], &cosetTableInv[irev]). Mul(&a[i], &domain.CardinalityInv) } }, opt.nbTasks) } -func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}, nbTasks int) { +func difFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { if chDone != nil { defer close(chDone) } @@ -144,29 +206,38 @@ func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon n := len(a) if n == 1 { return - } else if n == 8 { - kerDIF8(a, twiddles, stage) + } else if n == 256 && stage >= twiddlesStartStage { + kerDIFNP_256(a, twiddles, stage-twiddlesStartStage) return } m := n >> 1 - // if stage < maxSplits, we parallelize this butterfly - // but we have only numCPU / stage cpus available - if (m > butterflyThreshold) && (stage < maxSplits) { - // 1 << stage == estimated used CPUs - numCPU := nbTasks / (1 << (stage)) - parallel.Execute(m, func(start, end int) { - for i := start; i < end; i++ { - fr.Butterfly(&a[i], &a[i+m]) - a[i+m].Mul(&a[i+m], &twiddles[stage][i]) - } - }, numCPU) + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + var at fr.Element + at.Exp(w, big.NewInt(int64(start))) + innerDIFWithoutTwiddles(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDIFWithoutTwiddles(a, w, w, 0, m, m) + } + // compute next twiddle + w.Square(&w) } else { - // i == 0 - fr.Butterfly(&a[0], &a[m]) - for i := 1; i < m; i++ { - fr.Butterfly(&a[i], &a[i+m]) - a[i+m].Mul(&a[i+m], &twiddles[stage][i]) + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks/(1<<(stage))) + } else { + innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) } } @@ -177,104 +248,170 @@ func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon nextStage := stage + 1 if stage < maxSplits { chDone := make(chan struct{}, 1) - go difFFT(a[m:n], twiddles, nextStage, maxSplits, chDone, nbTasks) - difFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) + go difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) <-chDone } else { - difFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) - difFFT(a[m:n], twiddles, nextStage, maxSplits, nil, nbTasks) + difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) } } -func ditFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}, nbTasks int) { +func innerDIFWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fr.Butterfly(&a[i], &a[i+m]) + a[i+m].Mul(&a[i+m], &twiddles[i]) + } +} + +func innerDIFWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fr.Butterfly(&a[i], &a[i+m]) + a[i+m].Mul(&a[i+m], &at) + at.Mul(&at, &w) + } +} + +func ditFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { if chDone != nil { defer close(chDone) } n := len(a) if n == 1 { return - } else if n == 8 { - kerDIT8(a, twiddles, stage) + } else if n == 256 && stage >= twiddlesStartStage { + kerDITNP_256(a, twiddles, stage-twiddlesStartStage) return } m := n >> 1 nextStage := stage + 1 + nextW := w + nextW.Square(&nextW) if stage < maxSplits { // that's the only time we fire go routines chDone := make(chan struct{}, 1) - go ditFFT(a[m:], twiddles, nextStage, maxSplits, chDone, nbTasks) - ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) + go ditFFT(a[m:], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + ditFFT(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) <-chDone } else { - ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) - ditFFT(a[m:n], twiddles, nextStage, maxSplits, nil, nbTasks) - + ditFFT(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + ditFFT(a[m:n], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) } - // if stage < maxSplits, we parallelize this butterfly - // but we have only numCPU / stage cpus available - if (m > butterflyThreshold) && (stage < maxSplits) { - // 1 << stage == estimated used CPUs - numCPU := nbTasks / (1 << (stage)) - parallel.Execute(m, func(start, end int) { - for k := start; k < end; k++ { - a[k+m].Mul(&a[k+m], &twiddles[stage][k]) - fr.Butterfly(&a[k], &a[k+m]) - } - }, numCPU) + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + // we need to compute the twiddles for this stage on the fly. + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + var at fr.Element + at.Exp(w, big.NewInt(int64(start))) + innerDITWithoutTwiddles(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDITWithoutTwiddles(a, w, w, 0, m, m) + } + return + } + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks/(1<<(stage))) } else { + innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) + } +} + +func innerDITWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { + if start == 0 { fr.Butterfly(&a[0], &a[m]) - for k := 1; k < m; k++ { - a[k+m].Mul(&a[k+m], &twiddles[stage][k]) - fr.Butterfly(&a[k], &a[k+m]) - } + start++ + } + for i := start; i < end; i++ { + a[i+m].Mul(&a[i+m], &twiddles[i]) + fr.Butterfly(&a[i], &a[i+m]) } } -// kerDIT8 is a kernel that process a FFT of size 8 -func kerDIT8(a []fr.Element, twiddles [][]fr.Element, stage int) { - - fr.Butterfly(&a[0], &a[1]) - fr.Butterfly(&a[2], &a[3]) - fr.Butterfly(&a[4], &a[5]) - fr.Butterfly(&a[6], &a[7]) - fr.Butterfly(&a[0], &a[2]) - a[3].Mul(&a[3], &twiddles[stage+1][1]) - fr.Butterfly(&a[1], &a[3]) - fr.Butterfly(&a[4], &a[6]) - a[7].Mul(&a[7], &twiddles[stage+1][1]) - fr.Butterfly(&a[5], &a[7]) - fr.Butterfly(&a[0], &a[4]) - a[5].Mul(&a[5], &twiddles[stage+0][1]) - fr.Butterfly(&a[1], &a[5]) - a[6].Mul(&a[6], &twiddles[stage+0][2]) - fr.Butterfly(&a[2], &a[6]) - a[7].Mul(&a[7], &twiddles[stage+0][3]) - fr.Butterfly(&a[3], &a[7]) +func innerDITWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + a[i+m].Mul(&a[i+m], &at) + fr.Butterfly(&a[i], &a[i+m]) + at.Mul(&at, &w) + } } -// kerDIF8 is a kernel that process a FFT of size 8 -func kerDIF8(a []fr.Element, twiddles [][]fr.Element, stage int) { - - fr.Butterfly(&a[0], &a[4]) - fr.Butterfly(&a[1], &a[5]) - fr.Butterfly(&a[2], &a[6]) - fr.Butterfly(&a[3], &a[7]) - a[5].Mul(&a[5], &twiddles[stage+0][1]) - a[6].Mul(&a[6], &twiddles[stage+0][2]) - a[7].Mul(&a[7], &twiddles[stage+0][3]) - fr.Butterfly(&a[0], &a[2]) - fr.Butterfly(&a[1], &a[3]) - fr.Butterfly(&a[4], &a[6]) - fr.Butterfly(&a[5], &a[7]) - a[3].Mul(&a[3], &twiddles[stage+1][1]) - a[7].Mul(&a[7], &twiddles[stage+1][1]) - fr.Butterfly(&a[0], &a[1]) - fr.Butterfly(&a[2], &a[3]) - fr.Butterfly(&a[4], &a[5]) - fr.Butterfly(&a[6], &a[7]) +func kerDIFNP_256(a []fr.Element, twiddles [][]fr.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + + innerDIFWithTwiddles(a[:256], twiddles[stage+0], 0, 128, 128) + for offset := 0; offset < 256; offset += 128 { + innerDIFWithTwiddles(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + for offset := 0; offset < 256; offset += 64 { + innerDIFWithTwiddles(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 32 { + innerDIFWithTwiddles(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 16 { + innerDIFWithTwiddles(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 8 { + innerDIFWithTwiddles(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 4 { + innerDIFWithTwiddles(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 2 { + fr.Butterfly(&a[offset], &a[offset+1]) + } +} + +func kerDITNP_256(a []fr.Element, twiddles [][]fr.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + + for offset := 0; offset < 256; offset += 2 { + fr.Butterfly(&a[offset], &a[offset+1]) + } + for offset := 0; offset < 256; offset += 4 { + innerDITWithTwiddles(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 8 { + innerDITWithTwiddles(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 16 { + innerDITWithTwiddles(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 32 { + innerDITWithTwiddles(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 64 { + innerDITWithTwiddles(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 128 { + innerDITWithTwiddles(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + innerDITWithTwiddles(a[:256], twiddles[stage+0], 0, 128, 128) } diff --git a/ecc/bls12-378/fr/fft/fft_test.go b/ecc/bls12-378/fr/fft/fft_test.go index 0478dd270..da1a29035 100644 --- a/ecc/bls12-378/fr/fft/fft_test.go +++ b/ecc/bls12-378/fr/fft/fft_test.go @@ -33,208 +33,217 @@ func TestFFT(t *testing.T) { nbCosets := 3 domainWithPrecompute := NewDomain(maxSize) + domainWithoutPrecompute := NewDomain(maxSize, WithoutPrecompute()) parameters := gopter.DefaultTestParameters() parameters.MinSuccessfulTests = 5 properties := gopter.NewProperties(parameters) - properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( + for domainName, domain := range map[string]*Domain{ + "with precompute": domainWithPrecompute, + "without precompute": domainWithoutPrecompute, + } { + domainName := domainName + domain := domain + t.Logf("domain: %s", domainName) + properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFT(pol, DIF) - BitReverse(pol) + domain.FFT(pol, DIF) + BitReverse(pol) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( + properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFT(pol, DIF, OnCoset()) - BitReverse(pol) + domain.FFT(pol, DIF, OnCoset()) + BitReverse(pol) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))). - Mul(&sample, &domainWithPrecompute.FrMultiplicativeGen) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))). + Mul(&sample, &domain.FrMultiplicativeGen) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( + properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT) + BitReverse(pol) + domain.FFT(pol, DIT) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT) - domainWithPrecompute.FFTInverse(pol, DIF) - BitReverse(pol) + BitReverse(pol) + domain.FFT(pol, DIT) + domain.FFTInverse(pol, DIF) + BitReverse(pol) - check := true - for i := 0; i < len(pol); i++ { - check = check && pol[i].Equal(&backupPol[i]) - } - return check - }, - )) + check := true + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } + return check + }, + )) - properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on cosets", prop.ForAll( + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on cosets", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - check := true + check := true - for i := 1; i <= nbCosets; i++ { + for i := 1; i <= nbCosets; i++ { - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT, OnCoset()) - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset()) - BitReverse(pol) + BitReverse(pol) + domain.FFT(pol, DIT, OnCoset()) + domain.FFTInverse(pol, DIF, OnCoset()) + BitReverse(pol) - for i := 0; i < len(pol); i++ { - check = check && pol[i].Equal(&backupPol[i]) + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } } - } - return check - }, - )) + return check + }, + )) - properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( + properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFTInverse(pol, DIF) - domainWithPrecompute.FFT(pol, DIT) + domain.FFTInverse(pol, DIF) + domain.FFT(pol, DIT) - check := true - for i := 0; i < len(pol); i++ { - check = check && (pol[i] == backupPol[i]) - } - return check - }, - )) + check := true + for i := 0; i < len(pol); i++ { + check = check && (pol[i] == backupPol[i]) + } + return check + }, + )) - properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( + properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset()) - domainWithPrecompute.FFT(pol, DIT, OnCoset()) + domain.FFTInverse(pol, DIF, OnCoset()) + domain.FFT(pol, DIT, OnCoset()) - for i := 0; i < len(pol); i++ { - if !(pol[i].Equal(&backupPol[i])) { - return false + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } } - } - // compute with nbTasks == 1 - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset(), WithNbTasks(1)) - domainWithPrecompute.FFT(pol, DIT, OnCoset(), WithNbTasks(1)) + // compute with nbTasks == 1 + domain.FFTInverse(pol, DIF, OnCoset(), WithNbTasks(1)) + domain.FFT(pol, DIT, OnCoset(), WithNbTasks(1)) - for i := 0; i < len(pol); i++ { - if !(pol[i].Equal(&backupPol[i])) { - return false + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } } - } - return true - }, - )) + return true + }, + )) - properties.TestingRun(t, gopter.ConsoleReporter(false)) + properties.TestingRun(t, gopter.ConsoleReporter(false)) + } } diff --git a/ecc/bls12-378/fr/fft/options.go b/ecc/bls12-378/fr/fft/options.go index 02a6000e5..81316225a 100644 --- a/ecc/bls12-378/fr/fft/options.go +++ b/ecc/bls12-378/fr/fft/options.go @@ -16,7 +16,11 @@ package fft -import "runtime" +import ( + "runtime" + + "github.com/consensys/gnark-crypto/ecc/bls12-378/fr" +) // Option defines option for altering the behavior of FFT methods. // See the descriptions of functions returning instances of this type for @@ -48,7 +52,7 @@ func WithNbTasks(nbTasks int) Option { } // default options -func options(opts ...Option) fftConfig { +func fftOptions(opts ...Option) fftConfig { // apply options opt := fftConfig{ coset: false, @@ -59,3 +63,41 @@ func options(opts ...Option) fftConfig { } return opt } + +// DomainOption defines option for altering the definition of the FFT domain +// See the descriptions of functions returning instances of this type for +// particular options. +type DomainOption func(*domainConfig) + +type domainConfig struct { + shift *fr.Element + withPrecompute bool +} + +// WithShift sets the FrMultiplicativeGen of the domain. +// Default is generator of the largest 2-adic subgroup. +func WithShift(shift fr.Element) DomainOption { + return func(opt *domainConfig) { + opt.shift = new(fr.Element).Set(&shift) + } +} + +// WithoutPrecompute disables precomputation of twiddles in the domain. +// When this option is set, FFTs will be slower, but will use less memory. +func WithoutPrecompute() DomainOption { + return func(opt *domainConfig) { + opt.withPrecompute = false + } +} + +// default options +func domainOptions(opts ...DomainOption) domainConfig { + // apply options + opt := domainConfig{ + withPrecompute: true, + } + for _, option := range opts { + option(&opt) + } + return opt +} diff --git a/ecc/bls12-378/fr/iop/ratios.go b/ecc/bls12-378/fr/iop/ratios.go index 7bc6a607a..1eac2992b 100644 --- a/ecc/bls12-378/fr/iop/ratios.go +++ b/ecc/bls12-378/fr/iop/ratios.go @@ -336,20 +336,12 @@ func getSupportIdentityPermutation(nbCopies int, domain *fft.Domain) []fr.Elemen res := make([]fr.Element, uint64(nbCopies)*domain.Cardinality) sizePoly := int(domain.Cardinality) - // check if we can reuse the pre-computed twiddles from the domain. - if len(domain.Twiddles) == 0 || len(domain.Twiddles[0]) < sizePoly/2 { - res[0].SetOne() - if len(res) > 1 { - res[1].Set(&domain.Generator) - for i := 2; i < len(res); i++ { - res[i].Mul(&res[i-1], &domain.Generator) - } - } - } else { - // re-use pre-computed twiddles from the domain. - copy(res, domain.Twiddles[0]) - for i := (sizePoly / 2) - 1; i < sizePoly-1; i++ { - res[i+1].Mul(&res[i], &domain.Generator) + // TODO @gbotrel check if we can reuse the pre-computed twiddles from the domain. + res[0].SetOne() + if len(res) > 1 { + res[1].Set(&domain.Generator) + for i := 2; i < len(res); i++ { + res[i].Mul(&res[i-1], &domain.Generator) } } @@ -362,15 +354,7 @@ func getSupportIdentityPermutation(nbCopies int, domain *fft.Domain) []fr.Elemen i := i var coset fr.Element - if i == 1 { - coset = domain.FrMultiplicativeGen - } else { - if len(domain.CosetTable) > i { - coset = domain.CosetTable[i] - } else { - coset.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(i))) - } - } + coset.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(i))) go func() { parallel.Execute(sizePoly, func(start, end int) { diff --git a/ecc/bls12-381/fr/fft/domain.go b/ecc/bls12-381/fr/fft/domain.go index 1c121be0b..30113b466 100644 --- a/ecc/bls12-381/fr/fft/domain.go +++ b/ecc/bls12-381/fr/fft/domain.go @@ -17,6 +17,7 @@ package fft import ( + "errors" "io" "math/big" "math/bits" @@ -41,28 +42,33 @@ type Domain struct { FrMultiplicativeGen fr.Element // generator of Fr* FrMultiplicativeGenInv fr.Element + // this is set with the WithoutPrecompute option; + // if true, the domain does some pre-computation and stores it. + // if false, the FFT will compute the twiddles on the fly (this is less CPU efficient, but uses less memory) + withPrecompute bool + // the following slices are not serialized and are (re)computed through domain.preComputeTwiddles() - // Twiddles factor for the FFT using Generator for each stage of the recursive FFT - Twiddles [][]fr.Element + // twiddles factor for the FFT using Generator for each stage of the recursive FFT + twiddles [][]fr.Element - // Twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT - TwiddlesInv [][]fr.Element + // twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT + twiddlesInv [][]fr.Element // we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover - // CosetTable u*<1,g,..,g^(n-1)> - CosetTable []fr.Element + // cosetTable u*<1,g,..,g^(n-1)> + cosetTable []fr.Element - // CosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j - CosetTableInv []fr.Element + // cosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j + cosetTableInv []fr.Element } // NewDomain returns a subgroup with a power of 2 cardinality // cardinality >= m // shift: when specified, it's the element by which the set of root of unity is shifted. -func NewDomain(m uint64, shift ...fr.Element) *Domain { - +func NewDomain(m uint64, opts ...DomainOption) *Domain { + opt := domainOptions(opts...) domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) @@ -71,8 +77,8 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { domain.FrMultiplicativeGen.SetUint64(7) - if len(shift) != 0 { - domain.FrMultiplicativeGen.Set(&shift[0]) + if opt.shift != nil { + domain.FrMultiplicativeGen.Set(opt.shift) } domain.FrMultiplicativeGenInv.Inverse(&domain.FrMultiplicativeGen) @@ -85,7 +91,10 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { domain.CardinalityInv.SetUint64(uint64(x)).Inverse(&domain.CardinalityInv) // twiddle factors - domain.preComputeTwiddles() + domain.withPrecompute = opt.withPrecompute + if domain.withPrecompute { + domain.preComputeTwiddles() + } return domain } @@ -96,54 +105,104 @@ func Generator(m uint64) (fr.Element, error) { return fr.Generator(m) } +// Twiddles returns the twiddles factor for the FFT using Generator for each stage of the recursive FFT +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) Twiddles() ([][]fr.Element, error) { + if d.twiddles == nil { + return nil, errors.New("twiddles not precomputed") + } + return d.twiddles, nil +} + +// TwiddlesInv returns the twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) TwiddlesInv() ([][]fr.Element, error) { + if d.twiddlesInv == nil { + return nil, errors.New("twiddles not precomputed") + } + return d.twiddlesInv, nil +} + +// CosetTable returns the cosetTable u*<1,g,..,g^(n-1)> +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) CosetTable() ([]fr.Element, error) { + if d.cosetTable == nil { + return nil, errors.New("cosetTable not precomputed") + } + return d.cosetTable, nil +} + +// CosetTableInv returns the cosetTableInv u*<1,g,..,g^(n-1)> +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) CosetTableInv() ([]fr.Element, error) { + if d.cosetTableInv == nil { + return nil, errors.New("cosetTableInv not precomputed") + } + return d.cosetTableInv, nil +} + func (d *Domain) preComputeTwiddles() { // nb fft stages nbStages := uint64(bits.TrailingZeros64(d.Cardinality)) - d.Twiddles = make([][]fr.Element, nbStages) - d.TwiddlesInv = make([][]fr.Element, nbStages) - d.CosetTable = make([]fr.Element, d.Cardinality) - d.CosetTableInv = make([]fr.Element, d.Cardinality) + d.twiddles = make([][]fr.Element, nbStages) + d.twiddlesInv = make([][]fr.Element, nbStages) + d.cosetTable = make([]fr.Element, d.Cardinality) + d.cosetTableInv = make([]fr.Element, d.Cardinality) var wg sync.WaitGroup - // for each fft stage, we pre compute the twiddle factors - twiddles := func(t [][]fr.Element, omega fr.Element) { - for i := uint64(0); i < nbStages; i++ { - t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1))) - var w fr.Element - if i == 0 { - w = omega - } else { - w = t[i-1][2] - } - t[i][0] = fr.One() - t[i][1] = w - for j := 2; j < len(t[i]); j++ { - t[i][j].Mul(&t[i][j-1], &w) - } - } - wg.Done() - } - expTable := func(sqrt fr.Element, t []fr.Element) { - t[0] = fr.One() - precomputeExpTable(sqrt, t) + BuildExpTable(sqrt, t) wg.Done() } wg.Add(4) - go twiddles(d.Twiddles, d.Generator) - go twiddles(d.TwiddlesInv, d.GeneratorInv) - go expTable(d.FrMultiplicativeGen, d.CosetTable) - go expTable(d.FrMultiplicativeGenInv, d.CosetTableInv) + go func() { + buildTwiddles(d.twiddles, d.Generator, nbStages) + wg.Done() + }() + go func() { + buildTwiddles(d.twiddlesInv, d.GeneratorInv, nbStages) + wg.Done() + }() + go expTable(d.FrMultiplicativeGen, d.cosetTable) + go expTable(d.FrMultiplicativeGenInv, d.cosetTableInv) wg.Wait() } -func precomputeExpTable(w fr.Element, table []fr.Element) { +func buildTwiddles(t [][]fr.Element, omega fr.Element, nbStages uint64) { + if nbStages == 0 { + return + } + if len(t) != int(nbStages) { + panic("invalid twiddle table") + } + // we just compute the first stage + t[0] = make([]fr.Element, 1+(1<<(nbStages-1))) + BuildExpTable(omega, t[0]) + + // for the next stages, we just iterate on the first stage with larger stride + for i := uint64(1); i < nbStages; i++ { + t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1))) + k := 0 + for j := 0; j < len(t[i]); j++ { + t[i][j] = t[0][k] + k += 1 << i + } + } + +} + +// BuildExpTable precomputes the first n powers of w in parallel +// table[0] = w^0 +// table[1] = w^1 +// ... +func BuildExpTable(w fr.Element, table []fr.Element) { + table[0].SetOne() n := len(table) // see if it makes sense to parallelize exp tables pre-computation @@ -153,6 +212,7 @@ func precomputeExpTable(w fr.Element, table []fr.Element) { } // this ratio roughly correspond to the number of multiplication one can do in place of a Exp operation + // TODO @gbotrel revisit this; Exps in this context will be by a "small power of 2" so faster than this ref ratio. const ratioExpMul = 6000 / 17 if interval < ratioExpMul { @@ -194,7 +254,7 @@ func (d *Domain) WriteTo(w io.Writer) (int64, error) { enc := curve.NewEncoder(w) - toEncode := []interface{}{d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} + toEncode := []interface{}{d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute} for _, v := range toEncode { if err := enc.Encode(v); err != nil { @@ -210,7 +270,7 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { dec := curve.NewDecoder(r) - toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} + toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute} for _, v := range toDecode { if err := dec.Decode(v); err != nil { @@ -218,34 +278,9 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { } } - // twiddle factors - d.preComputeTwiddles() - - return dec.BytesRead(), nil -} - -// AsyncReadFrom attempts to decode a domain from Reader. It returns a channel that will be closed -// when the precomputation is done. -func (d *Domain) AsyncReadFrom(r io.Reader) (int64, error, chan struct{}) { - - dec := curve.NewDecoder(r) - - toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err, nil - } - } - - chDone := make(chan struct{}) - - go func() { - // twiddle factors + if d.withPrecompute { d.preComputeTwiddles() + } - close(chDone) - }() - - return dec.BytesRead(), nil, chDone + return dec.BytesRead(), nil } diff --git a/ecc/bls12-381/fr/fft/fft.go b/ecc/bls12-381/fr/fft/fft.go index dbc99e444..22b489094 100644 --- a/ecc/bls12-381/fr/fft/fft.go +++ b/ecc/bls12-381/fr/fft/fft.go @@ -19,6 +19,7 @@ package fft import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" "math/bits" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" @@ -40,41 +41,71 @@ const butterflyThreshold = 16 // if decimation == DIF (decimation in frequency), the output will be in bit-reversed order func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) { - opt := options(opts...) + opt := fftOptions(opts...) + + // find the stage where we should stop spawning go routines in our recursive calls + // (ie when we have as many go routines running as we have available CPUs) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) + if opt.nbTasks == 1 { + maxSplits = -1 + } // if coset != 0, scale by coset table if opt.coset { if decimation == DIT { // scale by coset table (in bit reversed order) + cosetTable := domain.cosetTable + if !domain.withPrecompute { + // we need to build the full table or do a bit reverse dance. + cosetTable = make([]fr.Element, len(a)) + BuildExpTable(domain.FrMultiplicativeGen, cosetTable) + } parallel.Execute(len(a), func(start, end int) { n := uint64(len(a)) nn := uint64(64 - bits.TrailingZeros64(n)) for i := start; i < end; i++ { irev := int(bits.Reverse64(uint64(i)) >> nn) - a[i].Mul(&a[i], &domain.CosetTable[irev]) + a[i].Mul(&a[i], &cosetTable[irev]) } }, opt.nbTasks) } else { - parallel.Execute(len(a), func(start, end int) { - for i := start; i < end; i++ { - a[i].Mul(&a[i], &domain.CosetTable[i]) - } - }, opt.nbTasks) + if domain.withPrecompute { + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &domain.cosetTable[i]) + } + }, opt.nbTasks) + } else { + c := domain.FrMultiplicativeGen + parallel.Execute(len(a), func(start, end int) { + var at fr.Element + at.Exp(c, big.NewInt(int64(start))) + for i := start; i < end; i++ { + a[i].Mul(&a[i], &at) + at.Mul(&at, &c) + } + }, opt.nbTasks) + } + } } - // find the stage where we should stop spawning go routines in our recursive calls - // (ie when we have as many go routines running as we have available CPUs) - maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) - if opt.nbTasks == 1 { - maxSplits = -1 + twiddles := domain.twiddles + twiddlesStartStage := 0 + if !domain.withPrecompute { + twiddlesStartStage = 3 + nbStages := int(bits.TrailingZeros64(domain.Cardinality)) + twiddles = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1<> nn) - a[i].Mul(&a[i], &domain.CosetTableInv[irev]). + a[i].Mul(&a[i], &cosetTableInv[irev]). Mul(&a[i], &domain.CardinalityInv) } }, opt.nbTasks) } -func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}, nbTasks int) { +func difFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { if chDone != nil { defer close(chDone) } @@ -144,29 +206,38 @@ func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon n := len(a) if n == 1 { return - } else if n == 8 { - kerDIF8(a, twiddles, stage) + } else if n == 256 && stage >= twiddlesStartStage { + kerDIFNP_256(a, twiddles, stage-twiddlesStartStage) return } m := n >> 1 - // if stage < maxSplits, we parallelize this butterfly - // but we have only numCPU / stage cpus available - if (m > butterflyThreshold) && (stage < maxSplits) { - // 1 << stage == estimated used CPUs - numCPU := nbTasks / (1 << (stage)) - parallel.Execute(m, func(start, end int) { - for i := start; i < end; i++ { - fr.Butterfly(&a[i], &a[i+m]) - a[i+m].Mul(&a[i+m], &twiddles[stage][i]) - } - }, numCPU) + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + var at fr.Element + at.Exp(w, big.NewInt(int64(start))) + innerDIFWithoutTwiddles(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDIFWithoutTwiddles(a, w, w, 0, m, m) + } + // compute next twiddle + w.Square(&w) } else { - // i == 0 - fr.Butterfly(&a[0], &a[m]) - for i := 1; i < m; i++ { - fr.Butterfly(&a[i], &a[i+m]) - a[i+m].Mul(&a[i+m], &twiddles[stage][i]) + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks/(1<<(stage))) + } else { + innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) } } @@ -177,104 +248,170 @@ func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon nextStage := stage + 1 if stage < maxSplits { chDone := make(chan struct{}, 1) - go difFFT(a[m:n], twiddles, nextStage, maxSplits, chDone, nbTasks) - difFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) + go difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) <-chDone } else { - difFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) - difFFT(a[m:n], twiddles, nextStage, maxSplits, nil, nbTasks) + difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) } } -func ditFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}, nbTasks int) { +func innerDIFWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fr.Butterfly(&a[i], &a[i+m]) + a[i+m].Mul(&a[i+m], &twiddles[i]) + } +} + +func innerDIFWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fr.Butterfly(&a[i], &a[i+m]) + a[i+m].Mul(&a[i+m], &at) + at.Mul(&at, &w) + } +} + +func ditFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { if chDone != nil { defer close(chDone) } n := len(a) if n == 1 { return - } else if n == 8 { - kerDIT8(a, twiddles, stage) + } else if n == 256 && stage >= twiddlesStartStage { + kerDITNP_256(a, twiddles, stage-twiddlesStartStage) return } m := n >> 1 nextStage := stage + 1 + nextW := w + nextW.Square(&nextW) if stage < maxSplits { // that's the only time we fire go routines chDone := make(chan struct{}, 1) - go ditFFT(a[m:], twiddles, nextStage, maxSplits, chDone, nbTasks) - ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) + go ditFFT(a[m:], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + ditFFT(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) <-chDone } else { - ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) - ditFFT(a[m:n], twiddles, nextStage, maxSplits, nil, nbTasks) - + ditFFT(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + ditFFT(a[m:n], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) } - // if stage < maxSplits, we parallelize this butterfly - // but we have only numCPU / stage cpus available - if (m > butterflyThreshold) && (stage < maxSplits) { - // 1 << stage == estimated used CPUs - numCPU := nbTasks / (1 << (stage)) - parallel.Execute(m, func(start, end int) { - for k := start; k < end; k++ { - a[k+m].Mul(&a[k+m], &twiddles[stage][k]) - fr.Butterfly(&a[k], &a[k+m]) - } - }, numCPU) + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + // we need to compute the twiddles for this stage on the fly. + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + var at fr.Element + at.Exp(w, big.NewInt(int64(start))) + innerDITWithoutTwiddles(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDITWithoutTwiddles(a, w, w, 0, m, m) + } + return + } + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks/(1<<(stage))) } else { + innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) + } +} + +func innerDITWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { + if start == 0 { fr.Butterfly(&a[0], &a[m]) - for k := 1; k < m; k++ { - a[k+m].Mul(&a[k+m], &twiddles[stage][k]) - fr.Butterfly(&a[k], &a[k+m]) - } + start++ + } + for i := start; i < end; i++ { + a[i+m].Mul(&a[i+m], &twiddles[i]) + fr.Butterfly(&a[i], &a[i+m]) } } -// kerDIT8 is a kernel that process a FFT of size 8 -func kerDIT8(a []fr.Element, twiddles [][]fr.Element, stage int) { - - fr.Butterfly(&a[0], &a[1]) - fr.Butterfly(&a[2], &a[3]) - fr.Butterfly(&a[4], &a[5]) - fr.Butterfly(&a[6], &a[7]) - fr.Butterfly(&a[0], &a[2]) - a[3].Mul(&a[3], &twiddles[stage+1][1]) - fr.Butterfly(&a[1], &a[3]) - fr.Butterfly(&a[4], &a[6]) - a[7].Mul(&a[7], &twiddles[stage+1][1]) - fr.Butterfly(&a[5], &a[7]) - fr.Butterfly(&a[0], &a[4]) - a[5].Mul(&a[5], &twiddles[stage+0][1]) - fr.Butterfly(&a[1], &a[5]) - a[6].Mul(&a[6], &twiddles[stage+0][2]) - fr.Butterfly(&a[2], &a[6]) - a[7].Mul(&a[7], &twiddles[stage+0][3]) - fr.Butterfly(&a[3], &a[7]) +func innerDITWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + a[i+m].Mul(&a[i+m], &at) + fr.Butterfly(&a[i], &a[i+m]) + at.Mul(&at, &w) + } } -// kerDIF8 is a kernel that process a FFT of size 8 -func kerDIF8(a []fr.Element, twiddles [][]fr.Element, stage int) { - - fr.Butterfly(&a[0], &a[4]) - fr.Butterfly(&a[1], &a[5]) - fr.Butterfly(&a[2], &a[6]) - fr.Butterfly(&a[3], &a[7]) - a[5].Mul(&a[5], &twiddles[stage+0][1]) - a[6].Mul(&a[6], &twiddles[stage+0][2]) - a[7].Mul(&a[7], &twiddles[stage+0][3]) - fr.Butterfly(&a[0], &a[2]) - fr.Butterfly(&a[1], &a[3]) - fr.Butterfly(&a[4], &a[6]) - fr.Butterfly(&a[5], &a[7]) - a[3].Mul(&a[3], &twiddles[stage+1][1]) - a[7].Mul(&a[7], &twiddles[stage+1][1]) - fr.Butterfly(&a[0], &a[1]) - fr.Butterfly(&a[2], &a[3]) - fr.Butterfly(&a[4], &a[5]) - fr.Butterfly(&a[6], &a[7]) +func kerDIFNP_256(a []fr.Element, twiddles [][]fr.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + + innerDIFWithTwiddles(a[:256], twiddles[stage+0], 0, 128, 128) + for offset := 0; offset < 256; offset += 128 { + innerDIFWithTwiddles(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + for offset := 0; offset < 256; offset += 64 { + innerDIFWithTwiddles(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 32 { + innerDIFWithTwiddles(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 16 { + innerDIFWithTwiddles(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 8 { + innerDIFWithTwiddles(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 4 { + innerDIFWithTwiddles(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 2 { + fr.Butterfly(&a[offset], &a[offset+1]) + } +} + +func kerDITNP_256(a []fr.Element, twiddles [][]fr.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + + for offset := 0; offset < 256; offset += 2 { + fr.Butterfly(&a[offset], &a[offset+1]) + } + for offset := 0; offset < 256; offset += 4 { + innerDITWithTwiddles(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 8 { + innerDITWithTwiddles(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 16 { + innerDITWithTwiddles(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 32 { + innerDITWithTwiddles(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 64 { + innerDITWithTwiddles(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 128 { + innerDITWithTwiddles(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + innerDITWithTwiddles(a[:256], twiddles[stage+0], 0, 128, 128) } diff --git a/ecc/bls12-381/fr/fft/fft_test.go b/ecc/bls12-381/fr/fft/fft_test.go index 21851d58d..01a3e3215 100644 --- a/ecc/bls12-381/fr/fft/fft_test.go +++ b/ecc/bls12-381/fr/fft/fft_test.go @@ -33,208 +33,217 @@ func TestFFT(t *testing.T) { nbCosets := 3 domainWithPrecompute := NewDomain(maxSize) + domainWithoutPrecompute := NewDomain(maxSize, WithoutPrecompute()) parameters := gopter.DefaultTestParameters() parameters.MinSuccessfulTests = 5 properties := gopter.NewProperties(parameters) - properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( + for domainName, domain := range map[string]*Domain{ + "with precompute": domainWithPrecompute, + "without precompute": domainWithoutPrecompute, + } { + domainName := domainName + domain := domain + t.Logf("domain: %s", domainName) + properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFT(pol, DIF) - BitReverse(pol) + domain.FFT(pol, DIF) + BitReverse(pol) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( + properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFT(pol, DIF, OnCoset()) - BitReverse(pol) + domain.FFT(pol, DIF, OnCoset()) + BitReverse(pol) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))). - Mul(&sample, &domainWithPrecompute.FrMultiplicativeGen) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))). + Mul(&sample, &domain.FrMultiplicativeGen) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( + properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT) + BitReverse(pol) + domain.FFT(pol, DIT) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT) - domainWithPrecompute.FFTInverse(pol, DIF) - BitReverse(pol) + BitReverse(pol) + domain.FFT(pol, DIT) + domain.FFTInverse(pol, DIF) + BitReverse(pol) - check := true - for i := 0; i < len(pol); i++ { - check = check && pol[i].Equal(&backupPol[i]) - } - return check - }, - )) + check := true + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } + return check + }, + )) - properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on cosets", prop.ForAll( + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on cosets", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - check := true + check := true - for i := 1; i <= nbCosets; i++ { + for i := 1; i <= nbCosets; i++ { - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT, OnCoset()) - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset()) - BitReverse(pol) + BitReverse(pol) + domain.FFT(pol, DIT, OnCoset()) + domain.FFTInverse(pol, DIF, OnCoset()) + BitReverse(pol) - for i := 0; i < len(pol); i++ { - check = check && pol[i].Equal(&backupPol[i]) + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } } - } - return check - }, - )) + return check + }, + )) - properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( + properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFTInverse(pol, DIF) - domainWithPrecompute.FFT(pol, DIT) + domain.FFTInverse(pol, DIF) + domain.FFT(pol, DIT) - check := true - for i := 0; i < len(pol); i++ { - check = check && (pol[i] == backupPol[i]) - } - return check - }, - )) + check := true + for i := 0; i < len(pol); i++ { + check = check && (pol[i] == backupPol[i]) + } + return check + }, + )) - properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( + properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset()) - domainWithPrecompute.FFT(pol, DIT, OnCoset()) + domain.FFTInverse(pol, DIF, OnCoset()) + domain.FFT(pol, DIT, OnCoset()) - for i := 0; i < len(pol); i++ { - if !(pol[i].Equal(&backupPol[i])) { - return false + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } } - } - // compute with nbTasks == 1 - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset(), WithNbTasks(1)) - domainWithPrecompute.FFT(pol, DIT, OnCoset(), WithNbTasks(1)) + // compute with nbTasks == 1 + domain.FFTInverse(pol, DIF, OnCoset(), WithNbTasks(1)) + domain.FFT(pol, DIT, OnCoset(), WithNbTasks(1)) - for i := 0; i < len(pol); i++ { - if !(pol[i].Equal(&backupPol[i])) { - return false + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } } - } - return true - }, - )) + return true + }, + )) - properties.TestingRun(t, gopter.ConsoleReporter(false)) + properties.TestingRun(t, gopter.ConsoleReporter(false)) + } } diff --git a/ecc/bls12-381/fr/fft/options.go b/ecc/bls12-381/fr/fft/options.go index 02a6000e5..5ae24b709 100644 --- a/ecc/bls12-381/fr/fft/options.go +++ b/ecc/bls12-381/fr/fft/options.go @@ -16,7 +16,11 @@ package fft -import "runtime" +import ( + "runtime" + + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" +) // Option defines option for altering the behavior of FFT methods. // See the descriptions of functions returning instances of this type for @@ -48,7 +52,7 @@ func WithNbTasks(nbTasks int) Option { } // default options -func options(opts ...Option) fftConfig { +func fftOptions(opts ...Option) fftConfig { // apply options opt := fftConfig{ coset: false, @@ -59,3 +63,41 @@ func options(opts ...Option) fftConfig { } return opt } + +// DomainOption defines option for altering the definition of the FFT domain +// See the descriptions of functions returning instances of this type for +// particular options. +type DomainOption func(*domainConfig) + +type domainConfig struct { + shift *fr.Element + withPrecompute bool +} + +// WithShift sets the FrMultiplicativeGen of the domain. +// Default is generator of the largest 2-adic subgroup. +func WithShift(shift fr.Element) DomainOption { + return func(opt *domainConfig) { + opt.shift = new(fr.Element).Set(&shift) + } +} + +// WithoutPrecompute disables precomputation of twiddles in the domain. +// When this option is set, FFTs will be slower, but will use less memory. +func WithoutPrecompute() DomainOption { + return func(opt *domainConfig) { + opt.withPrecompute = false + } +} + +// default options +func domainOptions(opts ...DomainOption) domainConfig { + // apply options + opt := domainConfig{ + withPrecompute: true, + } + for _, option := range opts { + option(&opt) + } + return opt +} diff --git a/ecc/bls12-381/fr/iop/ratios.go b/ecc/bls12-381/fr/iop/ratios.go index b6dabd4d6..17cd7b98d 100644 --- a/ecc/bls12-381/fr/iop/ratios.go +++ b/ecc/bls12-381/fr/iop/ratios.go @@ -336,20 +336,12 @@ func getSupportIdentityPermutation(nbCopies int, domain *fft.Domain) []fr.Elemen res := make([]fr.Element, uint64(nbCopies)*domain.Cardinality) sizePoly := int(domain.Cardinality) - // check if we can reuse the pre-computed twiddles from the domain. - if len(domain.Twiddles) == 0 || len(domain.Twiddles[0]) < sizePoly/2 { - res[0].SetOne() - if len(res) > 1 { - res[1].Set(&domain.Generator) - for i := 2; i < len(res); i++ { - res[i].Mul(&res[i-1], &domain.Generator) - } - } - } else { - // re-use pre-computed twiddles from the domain. - copy(res, domain.Twiddles[0]) - for i := (sizePoly / 2) - 1; i < sizePoly-1; i++ { - res[i+1].Mul(&res[i], &domain.Generator) + // TODO @gbotrel check if we can reuse the pre-computed twiddles from the domain. + res[0].SetOne() + if len(res) > 1 { + res[1].Set(&domain.Generator) + for i := 2; i < len(res); i++ { + res[i].Mul(&res[i-1], &domain.Generator) } } @@ -362,15 +354,7 @@ func getSupportIdentityPermutation(nbCopies int, domain *fft.Domain) []fr.Elemen i := i var coset fr.Element - if i == 1 { - coset = domain.FrMultiplicativeGen - } else { - if len(domain.CosetTable) > i { - coset = domain.CosetTable[i] - } else { - coset.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(i))) - } - } + coset.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(i))) go func() { parallel.Execute(sizePoly, func(start, end int) { diff --git a/ecc/bls24-315/fr/fft/domain.go b/ecc/bls24-315/fr/fft/domain.go index 65c9a5118..de895ab37 100644 --- a/ecc/bls24-315/fr/fft/domain.go +++ b/ecc/bls24-315/fr/fft/domain.go @@ -17,6 +17,7 @@ package fft import ( + "errors" "io" "math/big" "math/bits" @@ -41,28 +42,33 @@ type Domain struct { FrMultiplicativeGen fr.Element // generator of Fr* FrMultiplicativeGenInv fr.Element + // this is set with the WithoutPrecompute option; + // if true, the domain does some pre-computation and stores it. + // if false, the FFT will compute the twiddles on the fly (this is less CPU efficient, but uses less memory) + withPrecompute bool + // the following slices are not serialized and are (re)computed through domain.preComputeTwiddles() - // Twiddles factor for the FFT using Generator for each stage of the recursive FFT - Twiddles [][]fr.Element + // twiddles factor for the FFT using Generator for each stage of the recursive FFT + twiddles [][]fr.Element - // Twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT - TwiddlesInv [][]fr.Element + // twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT + twiddlesInv [][]fr.Element // we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover - // CosetTable u*<1,g,..,g^(n-1)> - CosetTable []fr.Element + // cosetTable u*<1,g,..,g^(n-1)> + cosetTable []fr.Element - // CosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j - CosetTableInv []fr.Element + // cosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j + cosetTableInv []fr.Element } // NewDomain returns a subgroup with a power of 2 cardinality // cardinality >= m // shift: when specified, it's the element by which the set of root of unity is shifted. -func NewDomain(m uint64, shift ...fr.Element) *Domain { - +func NewDomain(m uint64, opts ...DomainOption) *Domain { + opt := domainOptions(opts...) domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) @@ -71,8 +77,8 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { domain.FrMultiplicativeGen.SetUint64(7) - if len(shift) != 0 { - domain.FrMultiplicativeGen.Set(&shift[0]) + if opt.shift != nil { + domain.FrMultiplicativeGen.Set(opt.shift) } domain.FrMultiplicativeGenInv.Inverse(&domain.FrMultiplicativeGen) @@ -85,7 +91,10 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { domain.CardinalityInv.SetUint64(uint64(x)).Inverse(&domain.CardinalityInv) // twiddle factors - domain.preComputeTwiddles() + domain.withPrecompute = opt.withPrecompute + if domain.withPrecompute { + domain.preComputeTwiddles() + } return domain } @@ -96,54 +105,104 @@ func Generator(m uint64) (fr.Element, error) { return fr.Generator(m) } +// Twiddles returns the twiddles factor for the FFT using Generator for each stage of the recursive FFT +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) Twiddles() ([][]fr.Element, error) { + if d.twiddles == nil { + return nil, errors.New("twiddles not precomputed") + } + return d.twiddles, nil +} + +// TwiddlesInv returns the twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) TwiddlesInv() ([][]fr.Element, error) { + if d.twiddlesInv == nil { + return nil, errors.New("twiddles not precomputed") + } + return d.twiddlesInv, nil +} + +// CosetTable returns the cosetTable u*<1,g,..,g^(n-1)> +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) CosetTable() ([]fr.Element, error) { + if d.cosetTable == nil { + return nil, errors.New("cosetTable not precomputed") + } + return d.cosetTable, nil +} + +// CosetTableInv returns the cosetTableInv u*<1,g,..,g^(n-1)> +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) CosetTableInv() ([]fr.Element, error) { + if d.cosetTableInv == nil { + return nil, errors.New("cosetTableInv not precomputed") + } + return d.cosetTableInv, nil +} + func (d *Domain) preComputeTwiddles() { // nb fft stages nbStages := uint64(bits.TrailingZeros64(d.Cardinality)) - d.Twiddles = make([][]fr.Element, nbStages) - d.TwiddlesInv = make([][]fr.Element, nbStages) - d.CosetTable = make([]fr.Element, d.Cardinality) - d.CosetTableInv = make([]fr.Element, d.Cardinality) + d.twiddles = make([][]fr.Element, nbStages) + d.twiddlesInv = make([][]fr.Element, nbStages) + d.cosetTable = make([]fr.Element, d.Cardinality) + d.cosetTableInv = make([]fr.Element, d.Cardinality) var wg sync.WaitGroup - // for each fft stage, we pre compute the twiddle factors - twiddles := func(t [][]fr.Element, omega fr.Element) { - for i := uint64(0); i < nbStages; i++ { - t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1))) - var w fr.Element - if i == 0 { - w = omega - } else { - w = t[i-1][2] - } - t[i][0] = fr.One() - t[i][1] = w - for j := 2; j < len(t[i]); j++ { - t[i][j].Mul(&t[i][j-1], &w) - } - } - wg.Done() - } - expTable := func(sqrt fr.Element, t []fr.Element) { - t[0] = fr.One() - precomputeExpTable(sqrt, t) + BuildExpTable(sqrt, t) wg.Done() } wg.Add(4) - go twiddles(d.Twiddles, d.Generator) - go twiddles(d.TwiddlesInv, d.GeneratorInv) - go expTable(d.FrMultiplicativeGen, d.CosetTable) - go expTable(d.FrMultiplicativeGenInv, d.CosetTableInv) + go func() { + buildTwiddles(d.twiddles, d.Generator, nbStages) + wg.Done() + }() + go func() { + buildTwiddles(d.twiddlesInv, d.GeneratorInv, nbStages) + wg.Done() + }() + go expTable(d.FrMultiplicativeGen, d.cosetTable) + go expTable(d.FrMultiplicativeGenInv, d.cosetTableInv) wg.Wait() } -func precomputeExpTable(w fr.Element, table []fr.Element) { +func buildTwiddles(t [][]fr.Element, omega fr.Element, nbStages uint64) { + if nbStages == 0 { + return + } + if len(t) != int(nbStages) { + panic("invalid twiddle table") + } + // we just compute the first stage + t[0] = make([]fr.Element, 1+(1<<(nbStages-1))) + BuildExpTable(omega, t[0]) + + // for the next stages, we just iterate on the first stage with larger stride + for i := uint64(1); i < nbStages; i++ { + t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1))) + k := 0 + for j := 0; j < len(t[i]); j++ { + t[i][j] = t[0][k] + k += 1 << i + } + } + +} + +// BuildExpTable precomputes the first n powers of w in parallel +// table[0] = w^0 +// table[1] = w^1 +// ... +func BuildExpTable(w fr.Element, table []fr.Element) { + table[0].SetOne() n := len(table) // see if it makes sense to parallelize exp tables pre-computation @@ -153,6 +212,7 @@ func precomputeExpTable(w fr.Element, table []fr.Element) { } // this ratio roughly correspond to the number of multiplication one can do in place of a Exp operation + // TODO @gbotrel revisit this; Exps in this context will be by a "small power of 2" so faster than this ref ratio. const ratioExpMul = 6000 / 17 if interval < ratioExpMul { @@ -194,7 +254,7 @@ func (d *Domain) WriteTo(w io.Writer) (int64, error) { enc := curve.NewEncoder(w) - toEncode := []interface{}{d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} + toEncode := []interface{}{d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute} for _, v := range toEncode { if err := enc.Encode(v); err != nil { @@ -210,7 +270,7 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { dec := curve.NewDecoder(r) - toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} + toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute} for _, v := range toDecode { if err := dec.Decode(v); err != nil { @@ -218,34 +278,9 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { } } - // twiddle factors - d.preComputeTwiddles() - - return dec.BytesRead(), nil -} - -// AsyncReadFrom attempts to decode a domain from Reader. It returns a channel that will be closed -// when the precomputation is done. -func (d *Domain) AsyncReadFrom(r io.Reader) (int64, error, chan struct{}) { - - dec := curve.NewDecoder(r) - - toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err, nil - } - } - - chDone := make(chan struct{}) - - go func() { - // twiddle factors + if d.withPrecompute { d.preComputeTwiddles() + } - close(chDone) - }() - - return dec.BytesRead(), nil, chDone + return dec.BytesRead(), nil } diff --git a/ecc/bls24-315/fr/fft/fft.go b/ecc/bls24-315/fr/fft/fft.go index bd2eda5fc..40ecd312b 100644 --- a/ecc/bls24-315/fr/fft/fft.go +++ b/ecc/bls24-315/fr/fft/fft.go @@ -19,6 +19,7 @@ package fft import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" "math/bits" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" @@ -40,41 +41,71 @@ const butterflyThreshold = 16 // if decimation == DIF (decimation in frequency), the output will be in bit-reversed order func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) { - opt := options(opts...) + opt := fftOptions(opts...) + + // find the stage where we should stop spawning go routines in our recursive calls + // (ie when we have as many go routines running as we have available CPUs) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) + if opt.nbTasks == 1 { + maxSplits = -1 + } // if coset != 0, scale by coset table if opt.coset { if decimation == DIT { // scale by coset table (in bit reversed order) + cosetTable := domain.cosetTable + if !domain.withPrecompute { + // we need to build the full table or do a bit reverse dance. + cosetTable = make([]fr.Element, len(a)) + BuildExpTable(domain.FrMultiplicativeGen, cosetTable) + } parallel.Execute(len(a), func(start, end int) { n := uint64(len(a)) nn := uint64(64 - bits.TrailingZeros64(n)) for i := start; i < end; i++ { irev := int(bits.Reverse64(uint64(i)) >> nn) - a[i].Mul(&a[i], &domain.CosetTable[irev]) + a[i].Mul(&a[i], &cosetTable[irev]) } }, opt.nbTasks) } else { - parallel.Execute(len(a), func(start, end int) { - for i := start; i < end; i++ { - a[i].Mul(&a[i], &domain.CosetTable[i]) - } - }, opt.nbTasks) + if domain.withPrecompute { + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &domain.cosetTable[i]) + } + }, opt.nbTasks) + } else { + c := domain.FrMultiplicativeGen + parallel.Execute(len(a), func(start, end int) { + var at fr.Element + at.Exp(c, big.NewInt(int64(start))) + for i := start; i < end; i++ { + a[i].Mul(&a[i], &at) + at.Mul(&at, &c) + } + }, opt.nbTasks) + } + } } - // find the stage where we should stop spawning go routines in our recursive calls - // (ie when we have as many go routines running as we have available CPUs) - maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) - if opt.nbTasks == 1 { - maxSplits = -1 + twiddles := domain.twiddles + twiddlesStartStage := 0 + if !domain.withPrecompute { + twiddlesStartStage = 3 + nbStages := int(bits.TrailingZeros64(domain.Cardinality)) + twiddles = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1<> nn) - a[i].Mul(&a[i], &domain.CosetTableInv[irev]). + a[i].Mul(&a[i], &cosetTableInv[irev]). Mul(&a[i], &domain.CardinalityInv) } }, opt.nbTasks) } -func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}, nbTasks int) { +func difFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { if chDone != nil { defer close(chDone) } @@ -144,29 +206,38 @@ func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon n := len(a) if n == 1 { return - } else if n == 8 { - kerDIF8(a, twiddles, stage) + } else if n == 256 && stage >= twiddlesStartStage { + kerDIFNP_256(a, twiddles, stage-twiddlesStartStage) return } m := n >> 1 - // if stage < maxSplits, we parallelize this butterfly - // but we have only numCPU / stage cpus available - if (m > butterflyThreshold) && (stage < maxSplits) { - // 1 << stage == estimated used CPUs - numCPU := nbTasks / (1 << (stage)) - parallel.Execute(m, func(start, end int) { - for i := start; i < end; i++ { - fr.Butterfly(&a[i], &a[i+m]) - a[i+m].Mul(&a[i+m], &twiddles[stage][i]) - } - }, numCPU) + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + var at fr.Element + at.Exp(w, big.NewInt(int64(start))) + innerDIFWithoutTwiddles(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDIFWithoutTwiddles(a, w, w, 0, m, m) + } + // compute next twiddle + w.Square(&w) } else { - // i == 0 - fr.Butterfly(&a[0], &a[m]) - for i := 1; i < m; i++ { - fr.Butterfly(&a[i], &a[i+m]) - a[i+m].Mul(&a[i+m], &twiddles[stage][i]) + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks/(1<<(stage))) + } else { + innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) } } @@ -177,104 +248,170 @@ func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon nextStage := stage + 1 if stage < maxSplits { chDone := make(chan struct{}, 1) - go difFFT(a[m:n], twiddles, nextStage, maxSplits, chDone, nbTasks) - difFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) + go difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) <-chDone } else { - difFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) - difFFT(a[m:n], twiddles, nextStage, maxSplits, nil, nbTasks) + difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) } } -func ditFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}, nbTasks int) { +func innerDIFWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fr.Butterfly(&a[i], &a[i+m]) + a[i+m].Mul(&a[i+m], &twiddles[i]) + } +} + +func innerDIFWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fr.Butterfly(&a[i], &a[i+m]) + a[i+m].Mul(&a[i+m], &at) + at.Mul(&at, &w) + } +} + +func ditFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { if chDone != nil { defer close(chDone) } n := len(a) if n == 1 { return - } else if n == 8 { - kerDIT8(a, twiddles, stage) + } else if n == 256 && stage >= twiddlesStartStage { + kerDITNP_256(a, twiddles, stage-twiddlesStartStage) return } m := n >> 1 nextStage := stage + 1 + nextW := w + nextW.Square(&nextW) if stage < maxSplits { // that's the only time we fire go routines chDone := make(chan struct{}, 1) - go ditFFT(a[m:], twiddles, nextStage, maxSplits, chDone, nbTasks) - ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) + go ditFFT(a[m:], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + ditFFT(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) <-chDone } else { - ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) - ditFFT(a[m:n], twiddles, nextStage, maxSplits, nil, nbTasks) - + ditFFT(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + ditFFT(a[m:n], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) } - // if stage < maxSplits, we parallelize this butterfly - // but we have only numCPU / stage cpus available - if (m > butterflyThreshold) && (stage < maxSplits) { - // 1 << stage == estimated used CPUs - numCPU := nbTasks / (1 << (stage)) - parallel.Execute(m, func(start, end int) { - for k := start; k < end; k++ { - a[k+m].Mul(&a[k+m], &twiddles[stage][k]) - fr.Butterfly(&a[k], &a[k+m]) - } - }, numCPU) + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + // we need to compute the twiddles for this stage on the fly. + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + var at fr.Element + at.Exp(w, big.NewInt(int64(start))) + innerDITWithoutTwiddles(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDITWithoutTwiddles(a, w, w, 0, m, m) + } + return + } + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks/(1<<(stage))) } else { + innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) + } +} + +func innerDITWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { + if start == 0 { fr.Butterfly(&a[0], &a[m]) - for k := 1; k < m; k++ { - a[k+m].Mul(&a[k+m], &twiddles[stage][k]) - fr.Butterfly(&a[k], &a[k+m]) - } + start++ + } + for i := start; i < end; i++ { + a[i+m].Mul(&a[i+m], &twiddles[i]) + fr.Butterfly(&a[i], &a[i+m]) } } -// kerDIT8 is a kernel that process a FFT of size 8 -func kerDIT8(a []fr.Element, twiddles [][]fr.Element, stage int) { - - fr.Butterfly(&a[0], &a[1]) - fr.Butterfly(&a[2], &a[3]) - fr.Butterfly(&a[4], &a[5]) - fr.Butterfly(&a[6], &a[7]) - fr.Butterfly(&a[0], &a[2]) - a[3].Mul(&a[3], &twiddles[stage+1][1]) - fr.Butterfly(&a[1], &a[3]) - fr.Butterfly(&a[4], &a[6]) - a[7].Mul(&a[7], &twiddles[stage+1][1]) - fr.Butterfly(&a[5], &a[7]) - fr.Butterfly(&a[0], &a[4]) - a[5].Mul(&a[5], &twiddles[stage+0][1]) - fr.Butterfly(&a[1], &a[5]) - a[6].Mul(&a[6], &twiddles[stage+0][2]) - fr.Butterfly(&a[2], &a[6]) - a[7].Mul(&a[7], &twiddles[stage+0][3]) - fr.Butterfly(&a[3], &a[7]) +func innerDITWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + a[i+m].Mul(&a[i+m], &at) + fr.Butterfly(&a[i], &a[i+m]) + at.Mul(&at, &w) + } } -// kerDIF8 is a kernel that process a FFT of size 8 -func kerDIF8(a []fr.Element, twiddles [][]fr.Element, stage int) { - - fr.Butterfly(&a[0], &a[4]) - fr.Butterfly(&a[1], &a[5]) - fr.Butterfly(&a[2], &a[6]) - fr.Butterfly(&a[3], &a[7]) - a[5].Mul(&a[5], &twiddles[stage+0][1]) - a[6].Mul(&a[6], &twiddles[stage+0][2]) - a[7].Mul(&a[7], &twiddles[stage+0][3]) - fr.Butterfly(&a[0], &a[2]) - fr.Butterfly(&a[1], &a[3]) - fr.Butterfly(&a[4], &a[6]) - fr.Butterfly(&a[5], &a[7]) - a[3].Mul(&a[3], &twiddles[stage+1][1]) - a[7].Mul(&a[7], &twiddles[stage+1][1]) - fr.Butterfly(&a[0], &a[1]) - fr.Butterfly(&a[2], &a[3]) - fr.Butterfly(&a[4], &a[5]) - fr.Butterfly(&a[6], &a[7]) +func kerDIFNP_256(a []fr.Element, twiddles [][]fr.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + + innerDIFWithTwiddles(a[:256], twiddles[stage+0], 0, 128, 128) + for offset := 0; offset < 256; offset += 128 { + innerDIFWithTwiddles(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + for offset := 0; offset < 256; offset += 64 { + innerDIFWithTwiddles(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 32 { + innerDIFWithTwiddles(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 16 { + innerDIFWithTwiddles(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 8 { + innerDIFWithTwiddles(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 4 { + innerDIFWithTwiddles(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 2 { + fr.Butterfly(&a[offset], &a[offset+1]) + } +} + +func kerDITNP_256(a []fr.Element, twiddles [][]fr.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + + for offset := 0; offset < 256; offset += 2 { + fr.Butterfly(&a[offset], &a[offset+1]) + } + for offset := 0; offset < 256; offset += 4 { + innerDITWithTwiddles(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 8 { + innerDITWithTwiddles(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 16 { + innerDITWithTwiddles(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 32 { + innerDITWithTwiddles(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 64 { + innerDITWithTwiddles(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 128 { + innerDITWithTwiddles(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + innerDITWithTwiddles(a[:256], twiddles[stage+0], 0, 128, 128) } diff --git a/ecc/bls24-315/fr/fft/fft_test.go b/ecc/bls24-315/fr/fft/fft_test.go index 04860ec9f..70934867e 100644 --- a/ecc/bls24-315/fr/fft/fft_test.go +++ b/ecc/bls24-315/fr/fft/fft_test.go @@ -33,208 +33,217 @@ func TestFFT(t *testing.T) { nbCosets := 3 domainWithPrecompute := NewDomain(maxSize) + domainWithoutPrecompute := NewDomain(maxSize, WithoutPrecompute()) parameters := gopter.DefaultTestParameters() parameters.MinSuccessfulTests = 5 properties := gopter.NewProperties(parameters) - properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( + for domainName, domain := range map[string]*Domain{ + "with precompute": domainWithPrecompute, + "without precompute": domainWithoutPrecompute, + } { + domainName := domainName + domain := domain + t.Logf("domain: %s", domainName) + properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFT(pol, DIF) - BitReverse(pol) + domain.FFT(pol, DIF) + BitReverse(pol) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( + properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFT(pol, DIF, OnCoset()) - BitReverse(pol) + domain.FFT(pol, DIF, OnCoset()) + BitReverse(pol) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))). - Mul(&sample, &domainWithPrecompute.FrMultiplicativeGen) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))). + Mul(&sample, &domain.FrMultiplicativeGen) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( + properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT) + BitReverse(pol) + domain.FFT(pol, DIT) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT) - domainWithPrecompute.FFTInverse(pol, DIF) - BitReverse(pol) + BitReverse(pol) + domain.FFT(pol, DIT) + domain.FFTInverse(pol, DIF) + BitReverse(pol) - check := true - for i := 0; i < len(pol); i++ { - check = check && pol[i].Equal(&backupPol[i]) - } - return check - }, - )) + check := true + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } + return check + }, + )) - properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on cosets", prop.ForAll( + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on cosets", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - check := true + check := true - for i := 1; i <= nbCosets; i++ { + for i := 1; i <= nbCosets; i++ { - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT, OnCoset()) - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset()) - BitReverse(pol) + BitReverse(pol) + domain.FFT(pol, DIT, OnCoset()) + domain.FFTInverse(pol, DIF, OnCoset()) + BitReverse(pol) - for i := 0; i < len(pol); i++ { - check = check && pol[i].Equal(&backupPol[i]) + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } } - } - return check - }, - )) + return check + }, + )) - properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( + properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFTInverse(pol, DIF) - domainWithPrecompute.FFT(pol, DIT) + domain.FFTInverse(pol, DIF) + domain.FFT(pol, DIT) - check := true - for i := 0; i < len(pol); i++ { - check = check && (pol[i] == backupPol[i]) - } - return check - }, - )) + check := true + for i := 0; i < len(pol); i++ { + check = check && (pol[i] == backupPol[i]) + } + return check + }, + )) - properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( + properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset()) - domainWithPrecompute.FFT(pol, DIT, OnCoset()) + domain.FFTInverse(pol, DIF, OnCoset()) + domain.FFT(pol, DIT, OnCoset()) - for i := 0; i < len(pol); i++ { - if !(pol[i].Equal(&backupPol[i])) { - return false + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } } - } - // compute with nbTasks == 1 - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset(), WithNbTasks(1)) - domainWithPrecompute.FFT(pol, DIT, OnCoset(), WithNbTasks(1)) + // compute with nbTasks == 1 + domain.FFTInverse(pol, DIF, OnCoset(), WithNbTasks(1)) + domain.FFT(pol, DIT, OnCoset(), WithNbTasks(1)) - for i := 0; i < len(pol); i++ { - if !(pol[i].Equal(&backupPol[i])) { - return false + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } } - } - return true - }, - )) + return true + }, + )) - properties.TestingRun(t, gopter.ConsoleReporter(false)) + properties.TestingRun(t, gopter.ConsoleReporter(false)) + } } diff --git a/ecc/bls24-315/fr/fft/options.go b/ecc/bls24-315/fr/fft/options.go index 02a6000e5..cdb7f7f79 100644 --- a/ecc/bls24-315/fr/fft/options.go +++ b/ecc/bls24-315/fr/fft/options.go @@ -16,7 +16,11 @@ package fft -import "runtime" +import ( + "runtime" + + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" +) // Option defines option for altering the behavior of FFT methods. // See the descriptions of functions returning instances of this type for @@ -48,7 +52,7 @@ func WithNbTasks(nbTasks int) Option { } // default options -func options(opts ...Option) fftConfig { +func fftOptions(opts ...Option) fftConfig { // apply options opt := fftConfig{ coset: false, @@ -59,3 +63,41 @@ func options(opts ...Option) fftConfig { } return opt } + +// DomainOption defines option for altering the definition of the FFT domain +// See the descriptions of functions returning instances of this type for +// particular options. +type DomainOption func(*domainConfig) + +type domainConfig struct { + shift *fr.Element + withPrecompute bool +} + +// WithShift sets the FrMultiplicativeGen of the domain. +// Default is generator of the largest 2-adic subgroup. +func WithShift(shift fr.Element) DomainOption { + return func(opt *domainConfig) { + opt.shift = new(fr.Element).Set(&shift) + } +} + +// WithoutPrecompute disables precomputation of twiddles in the domain. +// When this option is set, FFTs will be slower, but will use less memory. +func WithoutPrecompute() DomainOption { + return func(opt *domainConfig) { + opt.withPrecompute = false + } +} + +// default options +func domainOptions(opts ...DomainOption) domainConfig { + // apply options + opt := domainConfig{ + withPrecompute: true, + } + for _, option := range opts { + option(&opt) + } + return opt +} diff --git a/ecc/bls24-315/fr/iop/ratios.go b/ecc/bls24-315/fr/iop/ratios.go index 889f22d45..1d8832c34 100644 --- a/ecc/bls24-315/fr/iop/ratios.go +++ b/ecc/bls24-315/fr/iop/ratios.go @@ -336,20 +336,12 @@ func getSupportIdentityPermutation(nbCopies int, domain *fft.Domain) []fr.Elemen res := make([]fr.Element, uint64(nbCopies)*domain.Cardinality) sizePoly := int(domain.Cardinality) - // check if we can reuse the pre-computed twiddles from the domain. - if len(domain.Twiddles) == 0 || len(domain.Twiddles[0]) < sizePoly/2 { - res[0].SetOne() - if len(res) > 1 { - res[1].Set(&domain.Generator) - for i := 2; i < len(res); i++ { - res[i].Mul(&res[i-1], &domain.Generator) - } - } - } else { - // re-use pre-computed twiddles from the domain. - copy(res, domain.Twiddles[0]) - for i := (sizePoly / 2) - 1; i < sizePoly-1; i++ { - res[i+1].Mul(&res[i], &domain.Generator) + // TODO @gbotrel check if we can reuse the pre-computed twiddles from the domain. + res[0].SetOne() + if len(res) > 1 { + res[1].Set(&domain.Generator) + for i := 2; i < len(res); i++ { + res[i].Mul(&res[i-1], &domain.Generator) } } @@ -362,15 +354,7 @@ func getSupportIdentityPermutation(nbCopies int, domain *fft.Domain) []fr.Elemen i := i var coset fr.Element - if i == 1 { - coset = domain.FrMultiplicativeGen - } else { - if len(domain.CosetTable) > i { - coset = domain.CosetTable[i] - } else { - coset.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(i))) - } - } + coset.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(i))) go func() { parallel.Execute(sizePoly, func(start, end int) { diff --git a/ecc/bls24-317/fr/fft/domain.go b/ecc/bls24-317/fr/fft/domain.go index 01b57683e..54d98c291 100644 --- a/ecc/bls24-317/fr/fft/domain.go +++ b/ecc/bls24-317/fr/fft/domain.go @@ -17,6 +17,7 @@ package fft import ( + "errors" "io" "math/big" "math/bits" @@ -41,28 +42,33 @@ type Domain struct { FrMultiplicativeGen fr.Element // generator of Fr* FrMultiplicativeGenInv fr.Element + // this is set with the WithoutPrecompute option; + // if true, the domain does some pre-computation and stores it. + // if false, the FFT will compute the twiddles on the fly (this is less CPU efficient, but uses less memory) + withPrecompute bool + // the following slices are not serialized and are (re)computed through domain.preComputeTwiddles() - // Twiddles factor for the FFT using Generator for each stage of the recursive FFT - Twiddles [][]fr.Element + // twiddles factor for the FFT using Generator for each stage of the recursive FFT + twiddles [][]fr.Element - // Twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT - TwiddlesInv [][]fr.Element + // twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT + twiddlesInv [][]fr.Element // we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover - // CosetTable u*<1,g,..,g^(n-1)> - CosetTable []fr.Element + // cosetTable u*<1,g,..,g^(n-1)> + cosetTable []fr.Element - // CosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j - CosetTableInv []fr.Element + // cosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j + cosetTableInv []fr.Element } // NewDomain returns a subgroup with a power of 2 cardinality // cardinality >= m // shift: when specified, it's the element by which the set of root of unity is shifted. -func NewDomain(m uint64, shift ...fr.Element) *Domain { - +func NewDomain(m uint64, opts ...DomainOption) *Domain { + opt := domainOptions(opts...) domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) @@ -71,8 +77,8 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { domain.FrMultiplicativeGen.SetUint64(7) - if len(shift) != 0 { - domain.FrMultiplicativeGen.Set(&shift[0]) + if opt.shift != nil { + domain.FrMultiplicativeGen.Set(opt.shift) } domain.FrMultiplicativeGenInv.Inverse(&domain.FrMultiplicativeGen) @@ -85,7 +91,10 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { domain.CardinalityInv.SetUint64(uint64(x)).Inverse(&domain.CardinalityInv) // twiddle factors - domain.preComputeTwiddles() + domain.withPrecompute = opt.withPrecompute + if domain.withPrecompute { + domain.preComputeTwiddles() + } return domain } @@ -96,54 +105,104 @@ func Generator(m uint64) (fr.Element, error) { return fr.Generator(m) } +// Twiddles returns the twiddles factor for the FFT using Generator for each stage of the recursive FFT +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) Twiddles() ([][]fr.Element, error) { + if d.twiddles == nil { + return nil, errors.New("twiddles not precomputed") + } + return d.twiddles, nil +} + +// TwiddlesInv returns the twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) TwiddlesInv() ([][]fr.Element, error) { + if d.twiddlesInv == nil { + return nil, errors.New("twiddles not precomputed") + } + return d.twiddlesInv, nil +} + +// CosetTable returns the cosetTable u*<1,g,..,g^(n-1)> +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) CosetTable() ([]fr.Element, error) { + if d.cosetTable == nil { + return nil, errors.New("cosetTable not precomputed") + } + return d.cosetTable, nil +} + +// CosetTableInv returns the cosetTableInv u*<1,g,..,g^(n-1)> +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) CosetTableInv() ([]fr.Element, error) { + if d.cosetTableInv == nil { + return nil, errors.New("cosetTableInv not precomputed") + } + return d.cosetTableInv, nil +} + func (d *Domain) preComputeTwiddles() { // nb fft stages nbStages := uint64(bits.TrailingZeros64(d.Cardinality)) - d.Twiddles = make([][]fr.Element, nbStages) - d.TwiddlesInv = make([][]fr.Element, nbStages) - d.CosetTable = make([]fr.Element, d.Cardinality) - d.CosetTableInv = make([]fr.Element, d.Cardinality) + d.twiddles = make([][]fr.Element, nbStages) + d.twiddlesInv = make([][]fr.Element, nbStages) + d.cosetTable = make([]fr.Element, d.Cardinality) + d.cosetTableInv = make([]fr.Element, d.Cardinality) var wg sync.WaitGroup - // for each fft stage, we pre compute the twiddle factors - twiddles := func(t [][]fr.Element, omega fr.Element) { - for i := uint64(0); i < nbStages; i++ { - t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1))) - var w fr.Element - if i == 0 { - w = omega - } else { - w = t[i-1][2] - } - t[i][0] = fr.One() - t[i][1] = w - for j := 2; j < len(t[i]); j++ { - t[i][j].Mul(&t[i][j-1], &w) - } - } - wg.Done() - } - expTable := func(sqrt fr.Element, t []fr.Element) { - t[0] = fr.One() - precomputeExpTable(sqrt, t) + BuildExpTable(sqrt, t) wg.Done() } wg.Add(4) - go twiddles(d.Twiddles, d.Generator) - go twiddles(d.TwiddlesInv, d.GeneratorInv) - go expTable(d.FrMultiplicativeGen, d.CosetTable) - go expTable(d.FrMultiplicativeGenInv, d.CosetTableInv) + go func() { + buildTwiddles(d.twiddles, d.Generator, nbStages) + wg.Done() + }() + go func() { + buildTwiddles(d.twiddlesInv, d.GeneratorInv, nbStages) + wg.Done() + }() + go expTable(d.FrMultiplicativeGen, d.cosetTable) + go expTable(d.FrMultiplicativeGenInv, d.cosetTableInv) wg.Wait() } -func precomputeExpTable(w fr.Element, table []fr.Element) { +func buildTwiddles(t [][]fr.Element, omega fr.Element, nbStages uint64) { + if nbStages == 0 { + return + } + if len(t) != int(nbStages) { + panic("invalid twiddle table") + } + // we just compute the first stage + t[0] = make([]fr.Element, 1+(1<<(nbStages-1))) + BuildExpTable(omega, t[0]) + + // for the next stages, we just iterate on the first stage with larger stride + for i := uint64(1); i < nbStages; i++ { + t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1))) + k := 0 + for j := 0; j < len(t[i]); j++ { + t[i][j] = t[0][k] + k += 1 << i + } + } + +} + +// BuildExpTable precomputes the first n powers of w in parallel +// table[0] = w^0 +// table[1] = w^1 +// ... +func BuildExpTable(w fr.Element, table []fr.Element) { + table[0].SetOne() n := len(table) // see if it makes sense to parallelize exp tables pre-computation @@ -153,6 +212,7 @@ func precomputeExpTable(w fr.Element, table []fr.Element) { } // this ratio roughly correspond to the number of multiplication one can do in place of a Exp operation + // TODO @gbotrel revisit this; Exps in this context will be by a "small power of 2" so faster than this ref ratio. const ratioExpMul = 6000 / 17 if interval < ratioExpMul { @@ -194,7 +254,7 @@ func (d *Domain) WriteTo(w io.Writer) (int64, error) { enc := curve.NewEncoder(w) - toEncode := []interface{}{d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} + toEncode := []interface{}{d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute} for _, v := range toEncode { if err := enc.Encode(v); err != nil { @@ -210,7 +270,7 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { dec := curve.NewDecoder(r) - toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} + toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute} for _, v := range toDecode { if err := dec.Decode(v); err != nil { @@ -218,34 +278,9 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { } } - // twiddle factors - d.preComputeTwiddles() - - return dec.BytesRead(), nil -} - -// AsyncReadFrom attempts to decode a domain from Reader. It returns a channel that will be closed -// when the precomputation is done. -func (d *Domain) AsyncReadFrom(r io.Reader) (int64, error, chan struct{}) { - - dec := curve.NewDecoder(r) - - toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err, nil - } - } - - chDone := make(chan struct{}) - - go func() { - // twiddle factors + if d.withPrecompute { d.preComputeTwiddles() + } - close(chDone) - }() - - return dec.BytesRead(), nil, chDone + return dec.BytesRead(), nil } diff --git a/ecc/bls24-317/fr/fft/fft.go b/ecc/bls24-317/fr/fft/fft.go index 5a205433b..45f361a9e 100644 --- a/ecc/bls24-317/fr/fft/fft.go +++ b/ecc/bls24-317/fr/fft/fft.go @@ -19,6 +19,7 @@ package fft import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" "math/bits" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" @@ -40,41 +41,71 @@ const butterflyThreshold = 16 // if decimation == DIF (decimation in frequency), the output will be in bit-reversed order func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) { - opt := options(opts...) + opt := fftOptions(opts...) + + // find the stage where we should stop spawning go routines in our recursive calls + // (ie when we have as many go routines running as we have available CPUs) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) + if opt.nbTasks == 1 { + maxSplits = -1 + } // if coset != 0, scale by coset table if opt.coset { if decimation == DIT { // scale by coset table (in bit reversed order) + cosetTable := domain.cosetTable + if !domain.withPrecompute { + // we need to build the full table or do a bit reverse dance. + cosetTable = make([]fr.Element, len(a)) + BuildExpTable(domain.FrMultiplicativeGen, cosetTable) + } parallel.Execute(len(a), func(start, end int) { n := uint64(len(a)) nn := uint64(64 - bits.TrailingZeros64(n)) for i := start; i < end; i++ { irev := int(bits.Reverse64(uint64(i)) >> nn) - a[i].Mul(&a[i], &domain.CosetTable[irev]) + a[i].Mul(&a[i], &cosetTable[irev]) } }, opt.nbTasks) } else { - parallel.Execute(len(a), func(start, end int) { - for i := start; i < end; i++ { - a[i].Mul(&a[i], &domain.CosetTable[i]) - } - }, opt.nbTasks) + if domain.withPrecompute { + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &domain.cosetTable[i]) + } + }, opt.nbTasks) + } else { + c := domain.FrMultiplicativeGen + parallel.Execute(len(a), func(start, end int) { + var at fr.Element + at.Exp(c, big.NewInt(int64(start))) + for i := start; i < end; i++ { + a[i].Mul(&a[i], &at) + at.Mul(&at, &c) + } + }, opt.nbTasks) + } + } } - // find the stage where we should stop spawning go routines in our recursive calls - // (ie when we have as many go routines running as we have available CPUs) - maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) - if opt.nbTasks == 1 { - maxSplits = -1 + twiddles := domain.twiddles + twiddlesStartStage := 0 + if !domain.withPrecompute { + twiddlesStartStage = 3 + nbStages := int(bits.TrailingZeros64(domain.Cardinality)) + twiddles = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1<> nn) - a[i].Mul(&a[i], &domain.CosetTableInv[irev]). + a[i].Mul(&a[i], &cosetTableInv[irev]). Mul(&a[i], &domain.CardinalityInv) } }, opt.nbTasks) } -func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}, nbTasks int) { +func difFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { if chDone != nil { defer close(chDone) } @@ -144,29 +206,38 @@ func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon n := len(a) if n == 1 { return - } else if n == 8 { - kerDIF8(a, twiddles, stage) + } else if n == 256 && stage >= twiddlesStartStage { + kerDIFNP_256(a, twiddles, stage-twiddlesStartStage) return } m := n >> 1 - // if stage < maxSplits, we parallelize this butterfly - // but we have only numCPU / stage cpus available - if (m > butterflyThreshold) && (stage < maxSplits) { - // 1 << stage == estimated used CPUs - numCPU := nbTasks / (1 << (stage)) - parallel.Execute(m, func(start, end int) { - for i := start; i < end; i++ { - fr.Butterfly(&a[i], &a[i+m]) - a[i+m].Mul(&a[i+m], &twiddles[stage][i]) - } - }, numCPU) + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + var at fr.Element + at.Exp(w, big.NewInt(int64(start))) + innerDIFWithoutTwiddles(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDIFWithoutTwiddles(a, w, w, 0, m, m) + } + // compute next twiddle + w.Square(&w) } else { - // i == 0 - fr.Butterfly(&a[0], &a[m]) - for i := 1; i < m; i++ { - fr.Butterfly(&a[i], &a[i+m]) - a[i+m].Mul(&a[i+m], &twiddles[stage][i]) + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks/(1<<(stage))) + } else { + innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) } } @@ -177,104 +248,170 @@ func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon nextStage := stage + 1 if stage < maxSplits { chDone := make(chan struct{}, 1) - go difFFT(a[m:n], twiddles, nextStage, maxSplits, chDone, nbTasks) - difFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) + go difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) <-chDone } else { - difFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) - difFFT(a[m:n], twiddles, nextStage, maxSplits, nil, nbTasks) + difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) } } -func ditFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}, nbTasks int) { +func innerDIFWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fr.Butterfly(&a[i], &a[i+m]) + a[i+m].Mul(&a[i+m], &twiddles[i]) + } +} + +func innerDIFWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fr.Butterfly(&a[i], &a[i+m]) + a[i+m].Mul(&a[i+m], &at) + at.Mul(&at, &w) + } +} + +func ditFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { if chDone != nil { defer close(chDone) } n := len(a) if n == 1 { return - } else if n == 8 { - kerDIT8(a, twiddles, stage) + } else if n == 256 && stage >= twiddlesStartStage { + kerDITNP_256(a, twiddles, stage-twiddlesStartStage) return } m := n >> 1 nextStage := stage + 1 + nextW := w + nextW.Square(&nextW) if stage < maxSplits { // that's the only time we fire go routines chDone := make(chan struct{}, 1) - go ditFFT(a[m:], twiddles, nextStage, maxSplits, chDone, nbTasks) - ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) + go ditFFT(a[m:], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + ditFFT(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) <-chDone } else { - ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) - ditFFT(a[m:n], twiddles, nextStage, maxSplits, nil, nbTasks) - + ditFFT(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + ditFFT(a[m:n], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) } - // if stage < maxSplits, we parallelize this butterfly - // but we have only numCPU / stage cpus available - if (m > butterflyThreshold) && (stage < maxSplits) { - // 1 << stage == estimated used CPUs - numCPU := nbTasks / (1 << (stage)) - parallel.Execute(m, func(start, end int) { - for k := start; k < end; k++ { - a[k+m].Mul(&a[k+m], &twiddles[stage][k]) - fr.Butterfly(&a[k], &a[k+m]) - } - }, numCPU) + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + // we need to compute the twiddles for this stage on the fly. + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + var at fr.Element + at.Exp(w, big.NewInt(int64(start))) + innerDITWithoutTwiddles(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDITWithoutTwiddles(a, w, w, 0, m, m) + } + return + } + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks/(1<<(stage))) } else { + innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) + } +} + +func innerDITWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { + if start == 0 { fr.Butterfly(&a[0], &a[m]) - for k := 1; k < m; k++ { - a[k+m].Mul(&a[k+m], &twiddles[stage][k]) - fr.Butterfly(&a[k], &a[k+m]) - } + start++ + } + for i := start; i < end; i++ { + a[i+m].Mul(&a[i+m], &twiddles[i]) + fr.Butterfly(&a[i], &a[i+m]) } } -// kerDIT8 is a kernel that process a FFT of size 8 -func kerDIT8(a []fr.Element, twiddles [][]fr.Element, stage int) { - - fr.Butterfly(&a[0], &a[1]) - fr.Butterfly(&a[2], &a[3]) - fr.Butterfly(&a[4], &a[5]) - fr.Butterfly(&a[6], &a[7]) - fr.Butterfly(&a[0], &a[2]) - a[3].Mul(&a[3], &twiddles[stage+1][1]) - fr.Butterfly(&a[1], &a[3]) - fr.Butterfly(&a[4], &a[6]) - a[7].Mul(&a[7], &twiddles[stage+1][1]) - fr.Butterfly(&a[5], &a[7]) - fr.Butterfly(&a[0], &a[4]) - a[5].Mul(&a[5], &twiddles[stage+0][1]) - fr.Butterfly(&a[1], &a[5]) - a[6].Mul(&a[6], &twiddles[stage+0][2]) - fr.Butterfly(&a[2], &a[6]) - a[7].Mul(&a[7], &twiddles[stage+0][3]) - fr.Butterfly(&a[3], &a[7]) +func innerDITWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + a[i+m].Mul(&a[i+m], &at) + fr.Butterfly(&a[i], &a[i+m]) + at.Mul(&at, &w) + } } -// kerDIF8 is a kernel that process a FFT of size 8 -func kerDIF8(a []fr.Element, twiddles [][]fr.Element, stage int) { - - fr.Butterfly(&a[0], &a[4]) - fr.Butterfly(&a[1], &a[5]) - fr.Butterfly(&a[2], &a[6]) - fr.Butterfly(&a[3], &a[7]) - a[5].Mul(&a[5], &twiddles[stage+0][1]) - a[6].Mul(&a[6], &twiddles[stage+0][2]) - a[7].Mul(&a[7], &twiddles[stage+0][3]) - fr.Butterfly(&a[0], &a[2]) - fr.Butterfly(&a[1], &a[3]) - fr.Butterfly(&a[4], &a[6]) - fr.Butterfly(&a[5], &a[7]) - a[3].Mul(&a[3], &twiddles[stage+1][1]) - a[7].Mul(&a[7], &twiddles[stage+1][1]) - fr.Butterfly(&a[0], &a[1]) - fr.Butterfly(&a[2], &a[3]) - fr.Butterfly(&a[4], &a[5]) - fr.Butterfly(&a[6], &a[7]) +func kerDIFNP_256(a []fr.Element, twiddles [][]fr.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + + innerDIFWithTwiddles(a[:256], twiddles[stage+0], 0, 128, 128) + for offset := 0; offset < 256; offset += 128 { + innerDIFWithTwiddles(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + for offset := 0; offset < 256; offset += 64 { + innerDIFWithTwiddles(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 32 { + innerDIFWithTwiddles(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 16 { + innerDIFWithTwiddles(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 8 { + innerDIFWithTwiddles(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 4 { + innerDIFWithTwiddles(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 2 { + fr.Butterfly(&a[offset], &a[offset+1]) + } +} + +func kerDITNP_256(a []fr.Element, twiddles [][]fr.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + + for offset := 0; offset < 256; offset += 2 { + fr.Butterfly(&a[offset], &a[offset+1]) + } + for offset := 0; offset < 256; offset += 4 { + innerDITWithTwiddles(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 8 { + innerDITWithTwiddles(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 16 { + innerDITWithTwiddles(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 32 { + innerDITWithTwiddles(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 64 { + innerDITWithTwiddles(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 128 { + innerDITWithTwiddles(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + innerDITWithTwiddles(a[:256], twiddles[stage+0], 0, 128, 128) } diff --git a/ecc/bls24-317/fr/fft/fft_test.go b/ecc/bls24-317/fr/fft/fft_test.go index 6a5e4dfdd..12ccded05 100644 --- a/ecc/bls24-317/fr/fft/fft_test.go +++ b/ecc/bls24-317/fr/fft/fft_test.go @@ -33,208 +33,217 @@ func TestFFT(t *testing.T) { nbCosets := 3 domainWithPrecompute := NewDomain(maxSize) + domainWithoutPrecompute := NewDomain(maxSize, WithoutPrecompute()) parameters := gopter.DefaultTestParameters() parameters.MinSuccessfulTests = 5 properties := gopter.NewProperties(parameters) - properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( + for domainName, domain := range map[string]*Domain{ + "with precompute": domainWithPrecompute, + "without precompute": domainWithoutPrecompute, + } { + domainName := domainName + domain := domain + t.Logf("domain: %s", domainName) + properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFT(pol, DIF) - BitReverse(pol) + domain.FFT(pol, DIF) + BitReverse(pol) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( + properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFT(pol, DIF, OnCoset()) - BitReverse(pol) + domain.FFT(pol, DIF, OnCoset()) + BitReverse(pol) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))). - Mul(&sample, &domainWithPrecompute.FrMultiplicativeGen) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))). + Mul(&sample, &domain.FrMultiplicativeGen) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( + properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT) + BitReverse(pol) + domain.FFT(pol, DIT) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT) - domainWithPrecompute.FFTInverse(pol, DIF) - BitReverse(pol) + BitReverse(pol) + domain.FFT(pol, DIT) + domain.FFTInverse(pol, DIF) + BitReverse(pol) - check := true - for i := 0; i < len(pol); i++ { - check = check && pol[i].Equal(&backupPol[i]) - } - return check - }, - )) + check := true + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } + return check + }, + )) - properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on cosets", prop.ForAll( + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on cosets", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - check := true + check := true - for i := 1; i <= nbCosets; i++ { + for i := 1; i <= nbCosets; i++ { - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT, OnCoset()) - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset()) - BitReverse(pol) + BitReverse(pol) + domain.FFT(pol, DIT, OnCoset()) + domain.FFTInverse(pol, DIF, OnCoset()) + BitReverse(pol) - for i := 0; i < len(pol); i++ { - check = check && pol[i].Equal(&backupPol[i]) + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } } - } - return check - }, - )) + return check + }, + )) - properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( + properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFTInverse(pol, DIF) - domainWithPrecompute.FFT(pol, DIT) + domain.FFTInverse(pol, DIF) + domain.FFT(pol, DIT) - check := true - for i := 0; i < len(pol); i++ { - check = check && (pol[i] == backupPol[i]) - } - return check - }, - )) + check := true + for i := 0; i < len(pol); i++ { + check = check && (pol[i] == backupPol[i]) + } + return check + }, + )) - properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( + properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset()) - domainWithPrecompute.FFT(pol, DIT, OnCoset()) + domain.FFTInverse(pol, DIF, OnCoset()) + domain.FFT(pol, DIT, OnCoset()) - for i := 0; i < len(pol); i++ { - if !(pol[i].Equal(&backupPol[i])) { - return false + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } } - } - // compute with nbTasks == 1 - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset(), WithNbTasks(1)) - domainWithPrecompute.FFT(pol, DIT, OnCoset(), WithNbTasks(1)) + // compute with nbTasks == 1 + domain.FFTInverse(pol, DIF, OnCoset(), WithNbTasks(1)) + domain.FFT(pol, DIT, OnCoset(), WithNbTasks(1)) - for i := 0; i < len(pol); i++ { - if !(pol[i].Equal(&backupPol[i])) { - return false + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } } - } - return true - }, - )) + return true + }, + )) - properties.TestingRun(t, gopter.ConsoleReporter(false)) + properties.TestingRun(t, gopter.ConsoleReporter(false)) + } } diff --git a/ecc/bls24-317/fr/fft/options.go b/ecc/bls24-317/fr/fft/options.go index 02a6000e5..0bb039125 100644 --- a/ecc/bls24-317/fr/fft/options.go +++ b/ecc/bls24-317/fr/fft/options.go @@ -16,7 +16,11 @@ package fft -import "runtime" +import ( + "runtime" + + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" +) // Option defines option for altering the behavior of FFT methods. // See the descriptions of functions returning instances of this type for @@ -48,7 +52,7 @@ func WithNbTasks(nbTasks int) Option { } // default options -func options(opts ...Option) fftConfig { +func fftOptions(opts ...Option) fftConfig { // apply options opt := fftConfig{ coset: false, @@ -59,3 +63,41 @@ func options(opts ...Option) fftConfig { } return opt } + +// DomainOption defines option for altering the definition of the FFT domain +// See the descriptions of functions returning instances of this type for +// particular options. +type DomainOption func(*domainConfig) + +type domainConfig struct { + shift *fr.Element + withPrecompute bool +} + +// WithShift sets the FrMultiplicativeGen of the domain. +// Default is generator of the largest 2-adic subgroup. +func WithShift(shift fr.Element) DomainOption { + return func(opt *domainConfig) { + opt.shift = new(fr.Element).Set(&shift) + } +} + +// WithoutPrecompute disables precomputation of twiddles in the domain. +// When this option is set, FFTs will be slower, but will use less memory. +func WithoutPrecompute() DomainOption { + return func(opt *domainConfig) { + opt.withPrecompute = false + } +} + +// default options +func domainOptions(opts ...DomainOption) domainConfig { + // apply options + opt := domainConfig{ + withPrecompute: true, + } + for _, option := range opts { + option(&opt) + } + return opt +} diff --git a/ecc/bls24-317/fr/iop/ratios.go b/ecc/bls24-317/fr/iop/ratios.go index d0a4a030d..e71f9547b 100644 --- a/ecc/bls24-317/fr/iop/ratios.go +++ b/ecc/bls24-317/fr/iop/ratios.go @@ -336,20 +336,12 @@ func getSupportIdentityPermutation(nbCopies int, domain *fft.Domain) []fr.Elemen res := make([]fr.Element, uint64(nbCopies)*domain.Cardinality) sizePoly := int(domain.Cardinality) - // check if we can reuse the pre-computed twiddles from the domain. - if len(domain.Twiddles) == 0 || len(domain.Twiddles[0]) < sizePoly/2 { - res[0].SetOne() - if len(res) > 1 { - res[1].Set(&domain.Generator) - for i := 2; i < len(res); i++ { - res[i].Mul(&res[i-1], &domain.Generator) - } - } - } else { - // re-use pre-computed twiddles from the domain. - copy(res, domain.Twiddles[0]) - for i := (sizePoly / 2) - 1; i < sizePoly-1; i++ { - res[i+1].Mul(&res[i], &domain.Generator) + // TODO @gbotrel check if we can reuse the pre-computed twiddles from the domain. + res[0].SetOne() + if len(res) > 1 { + res[1].Set(&domain.Generator) + for i := 2; i < len(res); i++ { + res[i].Mul(&res[i-1], &domain.Generator) } } @@ -362,15 +354,7 @@ func getSupportIdentityPermutation(nbCopies int, domain *fft.Domain) []fr.Elemen i := i var coset fr.Element - if i == 1 { - coset = domain.FrMultiplicativeGen - } else { - if len(domain.CosetTable) > i { - coset = domain.CosetTable[i] - } else { - coset.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(i))) - } - } + coset.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(i))) go func() { parallel.Execute(sizePoly, func(start, end int) { diff --git a/ecc/bn254/fr/fft/domain.go b/ecc/bn254/fr/fft/domain.go index 191fe1603..6ac79bf0a 100644 --- a/ecc/bn254/fr/fft/domain.go +++ b/ecc/bn254/fr/fft/domain.go @@ -17,6 +17,7 @@ package fft import ( + "errors" "io" "math/big" "math/bits" @@ -41,28 +42,33 @@ type Domain struct { FrMultiplicativeGen fr.Element // generator of Fr* FrMultiplicativeGenInv fr.Element + // this is set with the WithoutPrecompute option; + // if true, the domain does some pre-computation and stores it. + // if false, the FFT will compute the twiddles on the fly (this is less CPU efficient, but uses less memory) + withPrecompute bool + // the following slices are not serialized and are (re)computed through domain.preComputeTwiddles() - // Twiddles factor for the FFT using Generator for each stage of the recursive FFT - Twiddles [][]fr.Element + // twiddles factor for the FFT using Generator for each stage of the recursive FFT + twiddles [][]fr.Element - // Twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT - TwiddlesInv [][]fr.Element + // twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT + twiddlesInv [][]fr.Element // we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover - // CosetTable u*<1,g,..,g^(n-1)> - CosetTable []fr.Element + // cosetTable u*<1,g,..,g^(n-1)> + cosetTable []fr.Element - // CosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j - CosetTableInv []fr.Element + // cosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j + cosetTableInv []fr.Element } // NewDomain returns a subgroup with a power of 2 cardinality // cardinality >= m // shift: when specified, it's the element by which the set of root of unity is shifted. -func NewDomain(m uint64, shift ...fr.Element) *Domain { - +func NewDomain(m uint64, opts ...DomainOption) *Domain { + opt := domainOptions(opts...) domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) @@ -71,8 +77,8 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { domain.FrMultiplicativeGen.SetUint64(5) - if len(shift) != 0 { - domain.FrMultiplicativeGen.Set(&shift[0]) + if opt.shift != nil { + domain.FrMultiplicativeGen.Set(opt.shift) } domain.FrMultiplicativeGenInv.Inverse(&domain.FrMultiplicativeGen) @@ -85,7 +91,10 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { domain.CardinalityInv.SetUint64(uint64(x)).Inverse(&domain.CardinalityInv) // twiddle factors - domain.preComputeTwiddles() + domain.withPrecompute = opt.withPrecompute + if domain.withPrecompute { + domain.preComputeTwiddles() + } return domain } @@ -96,54 +105,104 @@ func Generator(m uint64) (fr.Element, error) { return fr.Generator(m) } +// Twiddles returns the twiddles factor for the FFT using Generator for each stage of the recursive FFT +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) Twiddles() ([][]fr.Element, error) { + if d.twiddles == nil { + return nil, errors.New("twiddles not precomputed") + } + return d.twiddles, nil +} + +// TwiddlesInv returns the twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) TwiddlesInv() ([][]fr.Element, error) { + if d.twiddlesInv == nil { + return nil, errors.New("twiddles not precomputed") + } + return d.twiddlesInv, nil +} + +// CosetTable returns the cosetTable u*<1,g,..,g^(n-1)> +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) CosetTable() ([]fr.Element, error) { + if d.cosetTable == nil { + return nil, errors.New("cosetTable not precomputed") + } + return d.cosetTable, nil +} + +// CosetTableInv returns the cosetTableInv u*<1,g,..,g^(n-1)> +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) CosetTableInv() ([]fr.Element, error) { + if d.cosetTableInv == nil { + return nil, errors.New("cosetTableInv not precomputed") + } + return d.cosetTableInv, nil +} + func (d *Domain) preComputeTwiddles() { // nb fft stages nbStages := uint64(bits.TrailingZeros64(d.Cardinality)) - d.Twiddles = make([][]fr.Element, nbStages) - d.TwiddlesInv = make([][]fr.Element, nbStages) - d.CosetTable = make([]fr.Element, d.Cardinality) - d.CosetTableInv = make([]fr.Element, d.Cardinality) + d.twiddles = make([][]fr.Element, nbStages) + d.twiddlesInv = make([][]fr.Element, nbStages) + d.cosetTable = make([]fr.Element, d.Cardinality) + d.cosetTableInv = make([]fr.Element, d.Cardinality) var wg sync.WaitGroup - // for each fft stage, we pre compute the twiddle factors - twiddles := func(t [][]fr.Element, omega fr.Element) { - for i := uint64(0); i < nbStages; i++ { - t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1))) - var w fr.Element - if i == 0 { - w = omega - } else { - w = t[i-1][2] - } - t[i][0] = fr.One() - t[i][1] = w - for j := 2; j < len(t[i]); j++ { - t[i][j].Mul(&t[i][j-1], &w) - } - } - wg.Done() - } - expTable := func(sqrt fr.Element, t []fr.Element) { - t[0] = fr.One() - precomputeExpTable(sqrt, t) + BuildExpTable(sqrt, t) wg.Done() } wg.Add(4) - go twiddles(d.Twiddles, d.Generator) - go twiddles(d.TwiddlesInv, d.GeneratorInv) - go expTable(d.FrMultiplicativeGen, d.CosetTable) - go expTable(d.FrMultiplicativeGenInv, d.CosetTableInv) + go func() { + buildTwiddles(d.twiddles, d.Generator, nbStages) + wg.Done() + }() + go func() { + buildTwiddles(d.twiddlesInv, d.GeneratorInv, nbStages) + wg.Done() + }() + go expTable(d.FrMultiplicativeGen, d.cosetTable) + go expTable(d.FrMultiplicativeGenInv, d.cosetTableInv) wg.Wait() } -func precomputeExpTable(w fr.Element, table []fr.Element) { +func buildTwiddles(t [][]fr.Element, omega fr.Element, nbStages uint64) { + if nbStages == 0 { + return + } + if len(t) != int(nbStages) { + panic("invalid twiddle table") + } + // we just compute the first stage + t[0] = make([]fr.Element, 1+(1<<(nbStages-1))) + BuildExpTable(omega, t[0]) + + // for the next stages, we just iterate on the first stage with larger stride + for i := uint64(1); i < nbStages; i++ { + t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1))) + k := 0 + for j := 0; j < len(t[i]); j++ { + t[i][j] = t[0][k] + k += 1 << i + } + } + +} + +// BuildExpTable precomputes the first n powers of w in parallel +// table[0] = w^0 +// table[1] = w^1 +// ... +func BuildExpTable(w fr.Element, table []fr.Element) { + table[0].SetOne() n := len(table) // see if it makes sense to parallelize exp tables pre-computation @@ -153,6 +212,7 @@ func precomputeExpTable(w fr.Element, table []fr.Element) { } // this ratio roughly correspond to the number of multiplication one can do in place of a Exp operation + // TODO @gbotrel revisit this; Exps in this context will be by a "small power of 2" so faster than this ref ratio. const ratioExpMul = 6000 / 17 if interval < ratioExpMul { @@ -194,7 +254,7 @@ func (d *Domain) WriteTo(w io.Writer) (int64, error) { enc := curve.NewEncoder(w) - toEncode := []interface{}{d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} + toEncode := []interface{}{d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute} for _, v := range toEncode { if err := enc.Encode(v); err != nil { @@ -210,7 +270,7 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { dec := curve.NewDecoder(r) - toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} + toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute} for _, v := range toDecode { if err := dec.Decode(v); err != nil { @@ -218,34 +278,9 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { } } - // twiddle factors - d.preComputeTwiddles() - - return dec.BytesRead(), nil -} - -// AsyncReadFrom attempts to decode a domain from Reader. It returns a channel that will be closed -// when the precomputation is done. -func (d *Domain) AsyncReadFrom(r io.Reader) (int64, error, chan struct{}) { - - dec := curve.NewDecoder(r) - - toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err, nil - } - } - - chDone := make(chan struct{}) - - go func() { - // twiddle factors + if d.withPrecompute { d.preComputeTwiddles() + } - close(chDone) - }() - - return dec.BytesRead(), nil, chDone + return dec.BytesRead(), nil } diff --git a/ecc/bn254/fr/fft/fft.go b/ecc/bn254/fr/fft/fft.go index a3dfe0d34..bffec1498 100644 --- a/ecc/bn254/fr/fft/fft.go +++ b/ecc/bn254/fr/fft/fft.go @@ -19,6 +19,7 @@ package fft import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" "math/bits" "github.com/consensys/gnark-crypto/ecc/bn254/fr" @@ -40,41 +41,71 @@ const butterflyThreshold = 16 // if decimation == DIF (decimation in frequency), the output will be in bit-reversed order func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) { - opt := options(opts...) + opt := fftOptions(opts...) + + // find the stage where we should stop spawning go routines in our recursive calls + // (ie when we have as many go routines running as we have available CPUs) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) + if opt.nbTasks == 1 { + maxSplits = -1 + } // if coset != 0, scale by coset table if opt.coset { if decimation == DIT { // scale by coset table (in bit reversed order) + cosetTable := domain.cosetTable + if !domain.withPrecompute { + // we need to build the full table or do a bit reverse dance. + cosetTable = make([]fr.Element, len(a)) + BuildExpTable(domain.FrMultiplicativeGen, cosetTable) + } parallel.Execute(len(a), func(start, end int) { n := uint64(len(a)) nn := uint64(64 - bits.TrailingZeros64(n)) for i := start; i < end; i++ { irev := int(bits.Reverse64(uint64(i)) >> nn) - a[i].Mul(&a[i], &domain.CosetTable[irev]) + a[i].Mul(&a[i], &cosetTable[irev]) } }, opt.nbTasks) } else { - parallel.Execute(len(a), func(start, end int) { - for i := start; i < end; i++ { - a[i].Mul(&a[i], &domain.CosetTable[i]) - } - }, opt.nbTasks) + if domain.withPrecompute { + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &domain.cosetTable[i]) + } + }, opt.nbTasks) + } else { + c := domain.FrMultiplicativeGen + parallel.Execute(len(a), func(start, end int) { + var at fr.Element + at.Exp(c, big.NewInt(int64(start))) + for i := start; i < end; i++ { + a[i].Mul(&a[i], &at) + at.Mul(&at, &c) + } + }, opt.nbTasks) + } + } } - // find the stage where we should stop spawning go routines in our recursive calls - // (ie when we have as many go routines running as we have available CPUs) - maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) - if opt.nbTasks == 1 { - maxSplits = -1 + twiddles := domain.twiddles + twiddlesStartStage := 0 + if !domain.withPrecompute { + twiddlesStartStage = 3 + nbStages := int(bits.TrailingZeros64(domain.Cardinality)) + twiddles = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1<> nn) - a[i].Mul(&a[i], &domain.CosetTableInv[irev]). + a[i].Mul(&a[i], &cosetTableInv[irev]). Mul(&a[i], &domain.CardinalityInv) } }, opt.nbTasks) } -func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}, nbTasks int) { +func difFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { if chDone != nil { defer close(chDone) } @@ -144,29 +206,38 @@ func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon n := len(a) if n == 1 { return - } else if n == 8 { - kerDIF8(a, twiddles, stage) + } else if n == 256 && stage >= twiddlesStartStage { + kerDIFNP_256(a, twiddles, stage-twiddlesStartStage) return } m := n >> 1 - // if stage < maxSplits, we parallelize this butterfly - // but we have only numCPU / stage cpus available - if (m > butterflyThreshold) && (stage < maxSplits) { - // 1 << stage == estimated used CPUs - numCPU := nbTasks / (1 << (stage)) - parallel.Execute(m, func(start, end int) { - for i := start; i < end; i++ { - fr.Butterfly(&a[i], &a[i+m]) - a[i+m].Mul(&a[i+m], &twiddles[stage][i]) - } - }, numCPU) + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + var at fr.Element + at.Exp(w, big.NewInt(int64(start))) + innerDIFWithoutTwiddles(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDIFWithoutTwiddles(a, w, w, 0, m, m) + } + // compute next twiddle + w.Square(&w) } else { - // i == 0 - fr.Butterfly(&a[0], &a[m]) - for i := 1; i < m; i++ { - fr.Butterfly(&a[i], &a[i+m]) - a[i+m].Mul(&a[i+m], &twiddles[stage][i]) + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks/(1<<(stage))) + } else { + innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) } } @@ -177,104 +248,170 @@ func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon nextStage := stage + 1 if stage < maxSplits { chDone := make(chan struct{}, 1) - go difFFT(a[m:n], twiddles, nextStage, maxSplits, chDone, nbTasks) - difFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) + go difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) <-chDone } else { - difFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) - difFFT(a[m:n], twiddles, nextStage, maxSplits, nil, nbTasks) + difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) } } -func ditFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}, nbTasks int) { +func innerDIFWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fr.Butterfly(&a[i], &a[i+m]) + a[i+m].Mul(&a[i+m], &twiddles[i]) + } +} + +func innerDIFWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fr.Butterfly(&a[i], &a[i+m]) + a[i+m].Mul(&a[i+m], &at) + at.Mul(&at, &w) + } +} + +func ditFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { if chDone != nil { defer close(chDone) } n := len(a) if n == 1 { return - } else if n == 8 { - kerDIT8(a, twiddles, stage) + } else if n == 256 && stage >= twiddlesStartStage { + kerDITNP_256(a, twiddles, stage-twiddlesStartStage) return } m := n >> 1 nextStage := stage + 1 + nextW := w + nextW.Square(&nextW) if stage < maxSplits { // that's the only time we fire go routines chDone := make(chan struct{}, 1) - go ditFFT(a[m:], twiddles, nextStage, maxSplits, chDone, nbTasks) - ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) + go ditFFT(a[m:], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + ditFFT(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) <-chDone } else { - ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) - ditFFT(a[m:n], twiddles, nextStage, maxSplits, nil, nbTasks) - + ditFFT(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + ditFFT(a[m:n], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) } - // if stage < maxSplits, we parallelize this butterfly - // but we have only numCPU / stage cpus available - if (m > butterflyThreshold) && (stage < maxSplits) { - // 1 << stage == estimated used CPUs - numCPU := nbTasks / (1 << (stage)) - parallel.Execute(m, func(start, end int) { - for k := start; k < end; k++ { - a[k+m].Mul(&a[k+m], &twiddles[stage][k]) - fr.Butterfly(&a[k], &a[k+m]) - } - }, numCPU) + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + // we need to compute the twiddles for this stage on the fly. + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + var at fr.Element + at.Exp(w, big.NewInt(int64(start))) + innerDITWithoutTwiddles(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDITWithoutTwiddles(a, w, w, 0, m, m) + } + return + } + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks/(1<<(stage))) } else { + innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) + } +} + +func innerDITWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { + if start == 0 { fr.Butterfly(&a[0], &a[m]) - for k := 1; k < m; k++ { - a[k+m].Mul(&a[k+m], &twiddles[stage][k]) - fr.Butterfly(&a[k], &a[k+m]) - } + start++ + } + for i := start; i < end; i++ { + a[i+m].Mul(&a[i+m], &twiddles[i]) + fr.Butterfly(&a[i], &a[i+m]) } } -// kerDIT8 is a kernel that process a FFT of size 8 -func kerDIT8(a []fr.Element, twiddles [][]fr.Element, stage int) { - - fr.Butterfly(&a[0], &a[1]) - fr.Butterfly(&a[2], &a[3]) - fr.Butterfly(&a[4], &a[5]) - fr.Butterfly(&a[6], &a[7]) - fr.Butterfly(&a[0], &a[2]) - a[3].Mul(&a[3], &twiddles[stage+1][1]) - fr.Butterfly(&a[1], &a[3]) - fr.Butterfly(&a[4], &a[6]) - a[7].Mul(&a[7], &twiddles[stage+1][1]) - fr.Butterfly(&a[5], &a[7]) - fr.Butterfly(&a[0], &a[4]) - a[5].Mul(&a[5], &twiddles[stage+0][1]) - fr.Butterfly(&a[1], &a[5]) - a[6].Mul(&a[6], &twiddles[stage+0][2]) - fr.Butterfly(&a[2], &a[6]) - a[7].Mul(&a[7], &twiddles[stage+0][3]) - fr.Butterfly(&a[3], &a[7]) +func innerDITWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + a[i+m].Mul(&a[i+m], &at) + fr.Butterfly(&a[i], &a[i+m]) + at.Mul(&at, &w) + } } -// kerDIF8 is a kernel that process a FFT of size 8 -func kerDIF8(a []fr.Element, twiddles [][]fr.Element, stage int) { - - fr.Butterfly(&a[0], &a[4]) - fr.Butterfly(&a[1], &a[5]) - fr.Butterfly(&a[2], &a[6]) - fr.Butterfly(&a[3], &a[7]) - a[5].Mul(&a[5], &twiddles[stage+0][1]) - a[6].Mul(&a[6], &twiddles[stage+0][2]) - a[7].Mul(&a[7], &twiddles[stage+0][3]) - fr.Butterfly(&a[0], &a[2]) - fr.Butterfly(&a[1], &a[3]) - fr.Butterfly(&a[4], &a[6]) - fr.Butterfly(&a[5], &a[7]) - a[3].Mul(&a[3], &twiddles[stage+1][1]) - a[7].Mul(&a[7], &twiddles[stage+1][1]) - fr.Butterfly(&a[0], &a[1]) - fr.Butterfly(&a[2], &a[3]) - fr.Butterfly(&a[4], &a[5]) - fr.Butterfly(&a[6], &a[7]) +func kerDIFNP_256(a []fr.Element, twiddles [][]fr.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + + innerDIFWithTwiddles(a[:256], twiddles[stage+0], 0, 128, 128) + for offset := 0; offset < 256; offset += 128 { + innerDIFWithTwiddles(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + for offset := 0; offset < 256; offset += 64 { + innerDIFWithTwiddles(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 32 { + innerDIFWithTwiddles(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 16 { + innerDIFWithTwiddles(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 8 { + innerDIFWithTwiddles(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 4 { + innerDIFWithTwiddles(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 2 { + fr.Butterfly(&a[offset], &a[offset+1]) + } +} + +func kerDITNP_256(a []fr.Element, twiddles [][]fr.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + + for offset := 0; offset < 256; offset += 2 { + fr.Butterfly(&a[offset], &a[offset+1]) + } + for offset := 0; offset < 256; offset += 4 { + innerDITWithTwiddles(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 8 { + innerDITWithTwiddles(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 16 { + innerDITWithTwiddles(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 32 { + innerDITWithTwiddles(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 64 { + innerDITWithTwiddles(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 128 { + innerDITWithTwiddles(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + innerDITWithTwiddles(a[:256], twiddles[stage+0], 0, 128, 128) } diff --git a/ecc/bn254/fr/fft/fft_test.go b/ecc/bn254/fr/fft/fft_test.go index ac1318130..207ae537c 100644 --- a/ecc/bn254/fr/fft/fft_test.go +++ b/ecc/bn254/fr/fft/fft_test.go @@ -33,208 +33,217 @@ func TestFFT(t *testing.T) { nbCosets := 3 domainWithPrecompute := NewDomain(maxSize) + domainWithoutPrecompute := NewDomain(maxSize, WithoutPrecompute()) parameters := gopter.DefaultTestParameters() parameters.MinSuccessfulTests = 5 properties := gopter.NewProperties(parameters) - properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( + for domainName, domain := range map[string]*Domain{ + "with precompute": domainWithPrecompute, + "without precompute": domainWithoutPrecompute, + } { + domainName := domainName + domain := domain + t.Logf("domain: %s", domainName) + properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFT(pol, DIF) - BitReverse(pol) + domain.FFT(pol, DIF) + BitReverse(pol) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( + properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFT(pol, DIF, OnCoset()) - BitReverse(pol) + domain.FFT(pol, DIF, OnCoset()) + BitReverse(pol) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))). - Mul(&sample, &domainWithPrecompute.FrMultiplicativeGen) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))). + Mul(&sample, &domain.FrMultiplicativeGen) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( + properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT) + BitReverse(pol) + domain.FFT(pol, DIT) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT) - domainWithPrecompute.FFTInverse(pol, DIF) - BitReverse(pol) + BitReverse(pol) + domain.FFT(pol, DIT) + domain.FFTInverse(pol, DIF) + BitReverse(pol) - check := true - for i := 0; i < len(pol); i++ { - check = check && pol[i].Equal(&backupPol[i]) - } - return check - }, - )) + check := true + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } + return check + }, + )) - properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on cosets", prop.ForAll( + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on cosets", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - check := true + check := true - for i := 1; i <= nbCosets; i++ { + for i := 1; i <= nbCosets; i++ { - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT, OnCoset()) - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset()) - BitReverse(pol) + BitReverse(pol) + domain.FFT(pol, DIT, OnCoset()) + domain.FFTInverse(pol, DIF, OnCoset()) + BitReverse(pol) - for i := 0; i < len(pol); i++ { - check = check && pol[i].Equal(&backupPol[i]) + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } } - } - return check - }, - )) + return check + }, + )) - properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( + properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFTInverse(pol, DIF) - domainWithPrecompute.FFT(pol, DIT) + domain.FFTInverse(pol, DIF) + domain.FFT(pol, DIT) - check := true - for i := 0; i < len(pol); i++ { - check = check && (pol[i] == backupPol[i]) - } - return check - }, - )) + check := true + for i := 0; i < len(pol); i++ { + check = check && (pol[i] == backupPol[i]) + } + return check + }, + )) - properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( + properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset()) - domainWithPrecompute.FFT(pol, DIT, OnCoset()) + domain.FFTInverse(pol, DIF, OnCoset()) + domain.FFT(pol, DIT, OnCoset()) - for i := 0; i < len(pol); i++ { - if !(pol[i].Equal(&backupPol[i])) { - return false + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } } - } - // compute with nbTasks == 1 - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset(), WithNbTasks(1)) - domainWithPrecompute.FFT(pol, DIT, OnCoset(), WithNbTasks(1)) + // compute with nbTasks == 1 + domain.FFTInverse(pol, DIF, OnCoset(), WithNbTasks(1)) + domain.FFT(pol, DIT, OnCoset(), WithNbTasks(1)) - for i := 0; i < len(pol); i++ { - if !(pol[i].Equal(&backupPol[i])) { - return false + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } } - } - return true - }, - )) + return true + }, + )) - properties.TestingRun(t, gopter.ConsoleReporter(false)) + properties.TestingRun(t, gopter.ConsoleReporter(false)) + } } diff --git a/ecc/bn254/fr/fft/options.go b/ecc/bn254/fr/fft/options.go index 02a6000e5..d2b44ec8b 100644 --- a/ecc/bn254/fr/fft/options.go +++ b/ecc/bn254/fr/fft/options.go @@ -16,7 +16,11 @@ package fft -import "runtime" +import ( + "runtime" + + "github.com/consensys/gnark-crypto/ecc/bn254/fr" +) // Option defines option for altering the behavior of FFT methods. // See the descriptions of functions returning instances of this type for @@ -48,7 +52,7 @@ func WithNbTasks(nbTasks int) Option { } // default options -func options(opts ...Option) fftConfig { +func fftOptions(opts ...Option) fftConfig { // apply options opt := fftConfig{ coset: false, @@ -59,3 +63,41 @@ func options(opts ...Option) fftConfig { } return opt } + +// DomainOption defines option for altering the definition of the FFT domain +// See the descriptions of functions returning instances of this type for +// particular options. +type DomainOption func(*domainConfig) + +type domainConfig struct { + shift *fr.Element + withPrecompute bool +} + +// WithShift sets the FrMultiplicativeGen of the domain. +// Default is generator of the largest 2-adic subgroup. +func WithShift(shift fr.Element) DomainOption { + return func(opt *domainConfig) { + opt.shift = new(fr.Element).Set(&shift) + } +} + +// WithoutPrecompute disables precomputation of twiddles in the domain. +// When this option is set, FFTs will be slower, but will use less memory. +func WithoutPrecompute() DomainOption { + return func(opt *domainConfig) { + opt.withPrecompute = false + } +} + +// default options +func domainOptions(opts ...DomainOption) domainConfig { + // apply options + opt := domainConfig{ + withPrecompute: true, + } + for _, option := range opts { + option(&opt) + } + return opt +} diff --git a/ecc/bn254/fr/iop/ratios.go b/ecc/bn254/fr/iop/ratios.go index b0a3721d2..7d9000ccd 100644 --- a/ecc/bn254/fr/iop/ratios.go +++ b/ecc/bn254/fr/iop/ratios.go @@ -336,20 +336,12 @@ func getSupportIdentityPermutation(nbCopies int, domain *fft.Domain) []fr.Elemen res := make([]fr.Element, uint64(nbCopies)*domain.Cardinality) sizePoly := int(domain.Cardinality) - // check if we can reuse the pre-computed twiddles from the domain. - if len(domain.Twiddles) == 0 || len(domain.Twiddles[0]) < sizePoly/2 { - res[0].SetOne() - if len(res) > 1 { - res[1].Set(&domain.Generator) - for i := 2; i < len(res); i++ { - res[i].Mul(&res[i-1], &domain.Generator) - } - } - } else { - // re-use pre-computed twiddles from the domain. - copy(res, domain.Twiddles[0]) - for i := (sizePoly / 2) - 1; i < sizePoly-1; i++ { - res[i+1].Mul(&res[i], &domain.Generator) + // TODO @gbotrel check if we can reuse the pre-computed twiddles from the domain. + res[0].SetOne() + if len(res) > 1 { + res[1].Set(&domain.Generator) + for i := 2; i < len(res); i++ { + res[i].Mul(&res[i-1], &domain.Generator) } } @@ -362,15 +354,7 @@ func getSupportIdentityPermutation(nbCopies int, domain *fft.Domain) []fr.Elemen i := i var coset fr.Element - if i == 1 { - coset = domain.FrMultiplicativeGen - } else { - if len(domain.CosetTable) > i { - coset = domain.CosetTable[i] - } else { - coset.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(i))) - } - } + coset.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(i))) go func() { parallel.Execute(sizePoly, func(start, end int) { diff --git a/ecc/bn254/fr/sis/sis.go b/ecc/bn254/fr/sis/sis.go index db90af87e..dce215b14 100644 --- a/ecc/bn254/fr/sis/sis.go +++ b/ecc/bn254/fr/sis/sis.go @@ -119,7 +119,7 @@ func NewRSis(seed int64, logTwoDegree, logTwoBound, maxNbElementsToHash int) (*R LogTwoBound: logTwoBound, capacity: capacity, Degree: degree, - Domain: fft.NewDomain(uint64(degree), shift), + Domain: fft.NewDomain(uint64(degree), fft.WithShift(shift)), A: make([][]fr.Element, n), Ag: make([][]fr.Element, n), bufM: make(fr.Vector, degree*n), @@ -129,7 +129,7 @@ func NewRSis(seed int64, logTwoDegree, logTwoBound, maxNbElementsToHash int) (*R } if r.LogTwoBound == 8 && r.Degree == 64 { // TODO @gbotrel fixme, that's dirty. - r.twiddleCosets = precomputeTwiddlesCoset(r.Domain.Twiddles, r.Domain.FrMultiplicativeGen) + r.twiddleCosets = precomputeTwiddlesCoset(r.Domain.Generator, r.Domain.FrMultiplicativeGen) } // filling A diff --git a/ecc/bn254/fr/sis/sis_fft.go b/ecc/bn254/fr/sis/sis_fft.go index 336805ebe..70b4d32d9 100644 --- a/ecc/bn254/fr/sis/sis_fft.go +++ b/ecc/bn254/fr/sis/sis_fft.go @@ -18,6 +18,7 @@ package sis import ( "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "math/big" ) // fft64 is generated by gnark-crypto and contains the unrolled code for FFT (DIF) on 64 elements @@ -413,82 +414,154 @@ func fft64(a []fr.Element, twiddlesCoset []fr.Element) { // precomputeTwiddlesCoset precomputes twiddlesCoset from twiddles and coset table // it then return all elements in the correct order for the unrolled FFT. -func precomputeTwiddlesCoset(twiddles [][]fr.Element, shifter fr.Element) []fr.Element { - r := make([][]fr.Element, len(twiddles)) - for i := 0; i < len(twiddles); i++ { - r[i] = make([]fr.Element, len(twiddles[i])) - s := shifter - for k := 0; k < i; k++ { - s.Mul(&s, &s) - } - for j := 0; j < len(twiddles[i]); j++ { - r[i][j].Mul(&twiddles[i][j], &s) - } - } - toReturn := make([]fr.Element, 0, 63) +func precomputeTwiddlesCoset(generator, shifter fr.Element) []fr.Element { + toReturn := make([]fr.Element, 63) + var r, s fr.Element + e := new(big.Int) - toReturn = append(toReturn, r[5][0]) - toReturn = append(toReturn, r[4][0]) - toReturn = append(toReturn, r[4][1]) - toReturn = append(toReturn, r[3][0]) - toReturn = append(toReturn, r[3][2]) - toReturn = append(toReturn, r[3][1]) - toReturn = append(toReturn, r[3][3]) - toReturn = append(toReturn, r[2][0]) - toReturn = append(toReturn, r[2][4]) - toReturn = append(toReturn, r[2][2]) - toReturn = append(toReturn, r[2][6]) - toReturn = append(toReturn, r[2][1]) - toReturn = append(toReturn, r[2][5]) - toReturn = append(toReturn, r[2][3]) - toReturn = append(toReturn, r[2][7]) - toReturn = append(toReturn, r[1][0]) - toReturn = append(toReturn, r[1][8]) - toReturn = append(toReturn, r[1][4]) - toReturn = append(toReturn, r[1][12]) - toReturn = append(toReturn, r[1][2]) - toReturn = append(toReturn, r[1][10]) - toReturn = append(toReturn, r[1][6]) - toReturn = append(toReturn, r[1][14]) - toReturn = append(toReturn, r[1][1]) - toReturn = append(toReturn, r[1][9]) - toReturn = append(toReturn, r[1][5]) - toReturn = append(toReturn, r[1][13]) - toReturn = append(toReturn, r[1][3]) - toReturn = append(toReturn, r[1][11]) - toReturn = append(toReturn, r[1][7]) - toReturn = append(toReturn, r[1][15]) - toReturn = append(toReturn, r[0][0]) - toReturn = append(toReturn, r[0][16]) - toReturn = append(toReturn, r[0][8]) - toReturn = append(toReturn, r[0][24]) - toReturn = append(toReturn, r[0][4]) - toReturn = append(toReturn, r[0][20]) - toReturn = append(toReturn, r[0][12]) - toReturn = append(toReturn, r[0][28]) - toReturn = append(toReturn, r[0][2]) - toReturn = append(toReturn, r[0][18]) - toReturn = append(toReturn, r[0][10]) - toReturn = append(toReturn, r[0][26]) - toReturn = append(toReturn, r[0][6]) - toReturn = append(toReturn, r[0][22]) - toReturn = append(toReturn, r[0][14]) - toReturn = append(toReturn, r[0][30]) - toReturn = append(toReturn, r[0][1]) - toReturn = append(toReturn, r[0][17]) - toReturn = append(toReturn, r[0][9]) - toReturn = append(toReturn, r[0][25]) - toReturn = append(toReturn, r[0][5]) - toReturn = append(toReturn, r[0][21]) - toReturn = append(toReturn, r[0][13]) - toReturn = append(toReturn, r[0][29]) - toReturn = append(toReturn, r[0][3]) - toReturn = append(toReturn, r[0][19]) - toReturn = append(toReturn, r[0][11]) - toReturn = append(toReturn, r[0][27]) - toReturn = append(toReturn, r[0][7]) - toReturn = append(toReturn, r[0][23]) - toReturn = append(toReturn, r[0][15]) - toReturn = append(toReturn, r[0][31]) + s = shifter + for k := 0; k < 5; k++ { + s.Square(&s) + } + toReturn[0] = s + s = shifter + for k := 0; k < 4; k++ { + s.Square(&s) + } + toReturn[1] = s + r.Exp(generator, e.SetUint64(uint64(1<<4*1))) + toReturn[2].Mul(&r, &s) + s = shifter + for k := 0; k < 3; k++ { + s.Square(&s) + } + toReturn[3] = s + r.Exp(generator, e.SetUint64(uint64(1<<3*2))) + toReturn[4].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<3*1))) + toReturn[5].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<3*3))) + toReturn[6].Mul(&r, &s) + s = shifter + for k := 0; k < 2; k++ { + s.Square(&s) + } + toReturn[7] = s + r.Exp(generator, e.SetUint64(uint64(1<<2*4))) + toReturn[8].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*2))) + toReturn[9].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*6))) + toReturn[10].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*1))) + toReturn[11].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*5))) + toReturn[12].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*3))) + toReturn[13].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<2*7))) + toReturn[14].Mul(&r, &s) + s = shifter + for k := 0; k < 1; k++ { + s.Square(&s) + } + toReturn[15] = s + r.Exp(generator, e.SetUint64(uint64(1<<1*8))) + toReturn[16].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*4))) + toReturn[17].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*12))) + toReturn[18].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*2))) + toReturn[19].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*10))) + toReturn[20].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*6))) + toReturn[21].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*14))) + toReturn[22].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*1))) + toReturn[23].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*9))) + toReturn[24].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*5))) + toReturn[25].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*13))) + toReturn[26].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*3))) + toReturn[27].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*11))) + toReturn[28].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*7))) + toReturn[29].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<1*15))) + toReturn[30].Mul(&r, &s) + s = shifter + for k := 0; k < 0; k++ { + s.Square(&s) + } + toReturn[31] = s + r.Exp(generator, e.SetUint64(uint64(1<<0*16))) + toReturn[32].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*8))) + toReturn[33].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*24))) + toReturn[34].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*4))) + toReturn[35].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*20))) + toReturn[36].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*12))) + toReturn[37].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*28))) + toReturn[38].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*2))) + toReturn[39].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*18))) + toReturn[40].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*10))) + toReturn[41].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*26))) + toReturn[42].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*6))) + toReturn[43].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*22))) + toReturn[44].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*14))) + toReturn[45].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*30))) + toReturn[46].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*1))) + toReturn[47].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*17))) + toReturn[48].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*9))) + toReturn[49].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*25))) + toReturn[50].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*5))) + toReturn[51].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*21))) + toReturn[52].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*13))) + toReturn[53].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*29))) + toReturn[54].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*3))) + toReturn[55].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*19))) + toReturn[56].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*11))) + toReturn[57].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*27))) + toReturn[58].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*7))) + toReturn[59].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*23))) + toReturn[60].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*15))) + toReturn[61].Mul(&r, &s) + r.Exp(generator, e.SetUint64(uint64(1<<0*31))) + toReturn[62].Mul(&r, &s) return toReturn } diff --git a/ecc/bn254/fr/sis/sis_test.go b/ecc/bn254/fr/sis/sis_test.go index ec5f234e7..480a11277 100644 --- a/ecc/bn254/fr/sis/sis_test.go +++ b/ecc/bn254/fr/sis/sis_test.go @@ -16,14 +16,14 @@ package sis import ( "bytes" + "crypto/rand" "encoding/binary" "encoding/json" "fmt" "io" - "io/ioutil" "math/big" "math/bits" - "math/rand" + "os" "testing" "time" @@ -68,7 +68,7 @@ func TestReference(t *testing.T) { // read the test case file var testCases TestCases - data, err := ioutil.ReadFile("test_cases.json") + data, err := os.ReadFile("test_cases.json") assert.NoError(err, "reading test cases failed") err = json.Unmarshal(data, &testCases) assert.NoError(err, "reading test cases failed") @@ -140,7 +140,7 @@ func TestMulMod(t *testing.T) { // creation of the domain var shift fr.Element shift.SetString("19540430494807482326159819597004422086093766032135589407132600596362845576832") - domain := fft.NewDomain(uint64(size), shift) + domain := fft.NewDomain(uint64(size), fft.WithShift(shift)) // mul mod domain.FFT(p, fft.DIF, fft.OnCoset()) @@ -390,8 +390,7 @@ func TestLimbDecompositionFastPath(t *testing.T) { nValues := bitset.New(uint(size)) // Generate a random buffer - rand.Seed(time.Now().UnixNano()) //#nosec G404 weak rng is fine here - _, err := rand.Read(buf) //#nosec G404 weak rng is fine here + _, err := rand.Read(buf) assert.NoError(err) limbDecomposeBytes8_64(buf, m, mValues) @@ -414,7 +413,7 @@ func TestUnrolledFFT(t *testing.T) { const size = 64 assert := require.New(t) - domain := fft.NewDomain(size, shift) + domain := fft.NewDomain(size, fft.WithShift(shift)) k1 := make([]fr.Element, size) for i := 0; i < size; i++ { @@ -427,7 +426,7 @@ func TestUnrolledFFT(t *testing.T) { domain.FFT(k1, fft.DIF, fft.OnCoset(), fft.WithNbTasks(1)) // unrolled FFT - twiddlesCoset := precomputeTwiddlesCoset(domain.Twiddles, domain.FrMultiplicativeGen) + twiddlesCoset := precomputeTwiddlesCoset(domain.Generator, domain.FrMultiplicativeGen) fft64(k2, twiddlesCoset) // compare results diff --git a/ecc/bw6-633/fr/fft/domain.go b/ecc/bw6-633/fr/fft/domain.go index 8dd5bcad3..795172208 100644 --- a/ecc/bw6-633/fr/fft/domain.go +++ b/ecc/bw6-633/fr/fft/domain.go @@ -17,6 +17,7 @@ package fft import ( + "errors" "io" "math/big" "math/bits" @@ -41,28 +42,33 @@ type Domain struct { FrMultiplicativeGen fr.Element // generator of Fr* FrMultiplicativeGenInv fr.Element + // this is set with the WithoutPrecompute option; + // if true, the domain does some pre-computation and stores it. + // if false, the FFT will compute the twiddles on the fly (this is less CPU efficient, but uses less memory) + withPrecompute bool + // the following slices are not serialized and are (re)computed through domain.preComputeTwiddles() - // Twiddles factor for the FFT using Generator for each stage of the recursive FFT - Twiddles [][]fr.Element + // twiddles factor for the FFT using Generator for each stage of the recursive FFT + twiddles [][]fr.Element - // Twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT - TwiddlesInv [][]fr.Element + // twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT + twiddlesInv [][]fr.Element // we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover - // CosetTable u*<1,g,..,g^(n-1)> - CosetTable []fr.Element + // cosetTable u*<1,g,..,g^(n-1)> + cosetTable []fr.Element - // CosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j - CosetTableInv []fr.Element + // cosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j + cosetTableInv []fr.Element } // NewDomain returns a subgroup with a power of 2 cardinality // cardinality >= m // shift: when specified, it's the element by which the set of root of unity is shifted. -func NewDomain(m uint64, shift ...fr.Element) *Domain { - +func NewDomain(m uint64, opts ...DomainOption) *Domain { + opt := domainOptions(opts...) domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) @@ -71,8 +77,8 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { domain.FrMultiplicativeGen.SetUint64(13) - if len(shift) != 0 { - domain.FrMultiplicativeGen.Set(&shift[0]) + if opt.shift != nil { + domain.FrMultiplicativeGen.Set(opt.shift) } domain.FrMultiplicativeGenInv.Inverse(&domain.FrMultiplicativeGen) @@ -85,7 +91,10 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { domain.CardinalityInv.SetUint64(uint64(x)).Inverse(&domain.CardinalityInv) // twiddle factors - domain.preComputeTwiddles() + domain.withPrecompute = opt.withPrecompute + if domain.withPrecompute { + domain.preComputeTwiddles() + } return domain } @@ -96,54 +105,104 @@ func Generator(m uint64) (fr.Element, error) { return fr.Generator(m) } +// Twiddles returns the twiddles factor for the FFT using Generator for each stage of the recursive FFT +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) Twiddles() ([][]fr.Element, error) { + if d.twiddles == nil { + return nil, errors.New("twiddles not precomputed") + } + return d.twiddles, nil +} + +// TwiddlesInv returns the twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) TwiddlesInv() ([][]fr.Element, error) { + if d.twiddlesInv == nil { + return nil, errors.New("twiddles not precomputed") + } + return d.twiddlesInv, nil +} + +// CosetTable returns the cosetTable u*<1,g,..,g^(n-1)> +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) CosetTable() ([]fr.Element, error) { + if d.cosetTable == nil { + return nil, errors.New("cosetTable not precomputed") + } + return d.cosetTable, nil +} + +// CosetTableInv returns the cosetTableInv u*<1,g,..,g^(n-1)> +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) CosetTableInv() ([]fr.Element, error) { + if d.cosetTableInv == nil { + return nil, errors.New("cosetTableInv not precomputed") + } + return d.cosetTableInv, nil +} + func (d *Domain) preComputeTwiddles() { // nb fft stages nbStages := uint64(bits.TrailingZeros64(d.Cardinality)) - d.Twiddles = make([][]fr.Element, nbStages) - d.TwiddlesInv = make([][]fr.Element, nbStages) - d.CosetTable = make([]fr.Element, d.Cardinality) - d.CosetTableInv = make([]fr.Element, d.Cardinality) + d.twiddles = make([][]fr.Element, nbStages) + d.twiddlesInv = make([][]fr.Element, nbStages) + d.cosetTable = make([]fr.Element, d.Cardinality) + d.cosetTableInv = make([]fr.Element, d.Cardinality) var wg sync.WaitGroup - // for each fft stage, we pre compute the twiddle factors - twiddles := func(t [][]fr.Element, omega fr.Element) { - for i := uint64(0); i < nbStages; i++ { - t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1))) - var w fr.Element - if i == 0 { - w = omega - } else { - w = t[i-1][2] - } - t[i][0] = fr.One() - t[i][1] = w - for j := 2; j < len(t[i]); j++ { - t[i][j].Mul(&t[i][j-1], &w) - } - } - wg.Done() - } - expTable := func(sqrt fr.Element, t []fr.Element) { - t[0] = fr.One() - precomputeExpTable(sqrt, t) + BuildExpTable(sqrt, t) wg.Done() } wg.Add(4) - go twiddles(d.Twiddles, d.Generator) - go twiddles(d.TwiddlesInv, d.GeneratorInv) - go expTable(d.FrMultiplicativeGen, d.CosetTable) - go expTable(d.FrMultiplicativeGenInv, d.CosetTableInv) + go func() { + buildTwiddles(d.twiddles, d.Generator, nbStages) + wg.Done() + }() + go func() { + buildTwiddles(d.twiddlesInv, d.GeneratorInv, nbStages) + wg.Done() + }() + go expTable(d.FrMultiplicativeGen, d.cosetTable) + go expTable(d.FrMultiplicativeGenInv, d.cosetTableInv) wg.Wait() } -func precomputeExpTable(w fr.Element, table []fr.Element) { +func buildTwiddles(t [][]fr.Element, omega fr.Element, nbStages uint64) { + if nbStages == 0 { + return + } + if len(t) != int(nbStages) { + panic("invalid twiddle table") + } + // we just compute the first stage + t[0] = make([]fr.Element, 1+(1<<(nbStages-1))) + BuildExpTable(omega, t[0]) + + // for the next stages, we just iterate on the first stage with larger stride + for i := uint64(1); i < nbStages; i++ { + t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1))) + k := 0 + for j := 0; j < len(t[i]); j++ { + t[i][j] = t[0][k] + k += 1 << i + } + } + +} + +// BuildExpTable precomputes the first n powers of w in parallel +// table[0] = w^0 +// table[1] = w^1 +// ... +func BuildExpTable(w fr.Element, table []fr.Element) { + table[0].SetOne() n := len(table) // see if it makes sense to parallelize exp tables pre-computation @@ -153,6 +212,7 @@ func precomputeExpTable(w fr.Element, table []fr.Element) { } // this ratio roughly correspond to the number of multiplication one can do in place of a Exp operation + // TODO @gbotrel revisit this; Exps in this context will be by a "small power of 2" so faster than this ref ratio. const ratioExpMul = 6000 / 17 if interval < ratioExpMul { @@ -194,7 +254,7 @@ func (d *Domain) WriteTo(w io.Writer) (int64, error) { enc := curve.NewEncoder(w) - toEncode := []interface{}{d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} + toEncode := []interface{}{d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute} for _, v := range toEncode { if err := enc.Encode(v); err != nil { @@ -210,7 +270,7 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { dec := curve.NewDecoder(r) - toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} + toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute} for _, v := range toDecode { if err := dec.Decode(v); err != nil { @@ -218,34 +278,9 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { } } - // twiddle factors - d.preComputeTwiddles() - - return dec.BytesRead(), nil -} - -// AsyncReadFrom attempts to decode a domain from Reader. It returns a channel that will be closed -// when the precomputation is done. -func (d *Domain) AsyncReadFrom(r io.Reader) (int64, error, chan struct{}) { - - dec := curve.NewDecoder(r) - - toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err, nil - } - } - - chDone := make(chan struct{}) - - go func() { - // twiddle factors + if d.withPrecompute { d.preComputeTwiddles() + } - close(chDone) - }() - - return dec.BytesRead(), nil, chDone + return dec.BytesRead(), nil } diff --git a/ecc/bw6-633/fr/fft/fft.go b/ecc/bw6-633/fr/fft/fft.go index 63ba7a84c..af1560b04 100644 --- a/ecc/bw6-633/fr/fft/fft.go +++ b/ecc/bw6-633/fr/fft/fft.go @@ -19,6 +19,7 @@ package fft import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" "math/bits" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" @@ -40,41 +41,71 @@ const butterflyThreshold = 16 // if decimation == DIF (decimation in frequency), the output will be in bit-reversed order func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) { - opt := options(opts...) + opt := fftOptions(opts...) + + // find the stage where we should stop spawning go routines in our recursive calls + // (ie when we have as many go routines running as we have available CPUs) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) + if opt.nbTasks == 1 { + maxSplits = -1 + } // if coset != 0, scale by coset table if opt.coset { if decimation == DIT { // scale by coset table (in bit reversed order) + cosetTable := domain.cosetTable + if !domain.withPrecompute { + // we need to build the full table or do a bit reverse dance. + cosetTable = make([]fr.Element, len(a)) + BuildExpTable(domain.FrMultiplicativeGen, cosetTable) + } parallel.Execute(len(a), func(start, end int) { n := uint64(len(a)) nn := uint64(64 - bits.TrailingZeros64(n)) for i := start; i < end; i++ { irev := int(bits.Reverse64(uint64(i)) >> nn) - a[i].Mul(&a[i], &domain.CosetTable[irev]) + a[i].Mul(&a[i], &cosetTable[irev]) } }, opt.nbTasks) } else { - parallel.Execute(len(a), func(start, end int) { - for i := start; i < end; i++ { - a[i].Mul(&a[i], &domain.CosetTable[i]) - } - }, opt.nbTasks) + if domain.withPrecompute { + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &domain.cosetTable[i]) + } + }, opt.nbTasks) + } else { + c := domain.FrMultiplicativeGen + parallel.Execute(len(a), func(start, end int) { + var at fr.Element + at.Exp(c, big.NewInt(int64(start))) + for i := start; i < end; i++ { + a[i].Mul(&a[i], &at) + at.Mul(&at, &c) + } + }, opt.nbTasks) + } + } } - // find the stage where we should stop spawning go routines in our recursive calls - // (ie when we have as many go routines running as we have available CPUs) - maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) - if opt.nbTasks == 1 { - maxSplits = -1 + twiddles := domain.twiddles + twiddlesStartStage := 0 + if !domain.withPrecompute { + twiddlesStartStage = 3 + nbStages := int(bits.TrailingZeros64(domain.Cardinality)) + twiddles = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1<> nn) - a[i].Mul(&a[i], &domain.CosetTableInv[irev]). + a[i].Mul(&a[i], &cosetTableInv[irev]). Mul(&a[i], &domain.CardinalityInv) } }, opt.nbTasks) } -func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}, nbTasks int) { +func difFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { if chDone != nil { defer close(chDone) } @@ -144,29 +206,38 @@ func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon n := len(a) if n == 1 { return - } else if n == 8 { - kerDIF8(a, twiddles, stage) + } else if n == 256 && stage >= twiddlesStartStage { + kerDIFNP_256(a, twiddles, stage-twiddlesStartStage) return } m := n >> 1 - // if stage < maxSplits, we parallelize this butterfly - // but we have only numCPU / stage cpus available - if (m > butterflyThreshold) && (stage < maxSplits) { - // 1 << stage == estimated used CPUs - numCPU := nbTasks / (1 << (stage)) - parallel.Execute(m, func(start, end int) { - for i := start; i < end; i++ { - fr.Butterfly(&a[i], &a[i+m]) - a[i+m].Mul(&a[i+m], &twiddles[stage][i]) - } - }, numCPU) + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + var at fr.Element + at.Exp(w, big.NewInt(int64(start))) + innerDIFWithoutTwiddles(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDIFWithoutTwiddles(a, w, w, 0, m, m) + } + // compute next twiddle + w.Square(&w) } else { - // i == 0 - fr.Butterfly(&a[0], &a[m]) - for i := 1; i < m; i++ { - fr.Butterfly(&a[i], &a[i+m]) - a[i+m].Mul(&a[i+m], &twiddles[stage][i]) + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks/(1<<(stage))) + } else { + innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) } } @@ -177,104 +248,170 @@ func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon nextStage := stage + 1 if stage < maxSplits { chDone := make(chan struct{}, 1) - go difFFT(a[m:n], twiddles, nextStage, maxSplits, chDone, nbTasks) - difFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) + go difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) <-chDone } else { - difFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) - difFFT(a[m:n], twiddles, nextStage, maxSplits, nil, nbTasks) + difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) } } -func ditFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}, nbTasks int) { +func innerDIFWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fr.Butterfly(&a[i], &a[i+m]) + a[i+m].Mul(&a[i+m], &twiddles[i]) + } +} + +func innerDIFWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fr.Butterfly(&a[i], &a[i+m]) + a[i+m].Mul(&a[i+m], &at) + at.Mul(&at, &w) + } +} + +func ditFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { if chDone != nil { defer close(chDone) } n := len(a) if n == 1 { return - } else if n == 8 { - kerDIT8(a, twiddles, stage) + } else if n == 256 && stage >= twiddlesStartStage { + kerDITNP_256(a, twiddles, stage-twiddlesStartStage) return } m := n >> 1 nextStage := stage + 1 + nextW := w + nextW.Square(&nextW) if stage < maxSplits { // that's the only time we fire go routines chDone := make(chan struct{}, 1) - go ditFFT(a[m:], twiddles, nextStage, maxSplits, chDone, nbTasks) - ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) + go ditFFT(a[m:], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + ditFFT(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) <-chDone } else { - ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) - ditFFT(a[m:n], twiddles, nextStage, maxSplits, nil, nbTasks) - + ditFFT(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + ditFFT(a[m:n], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) } - // if stage < maxSplits, we parallelize this butterfly - // but we have only numCPU / stage cpus available - if (m > butterflyThreshold) && (stage < maxSplits) { - // 1 << stage == estimated used CPUs - numCPU := nbTasks / (1 << (stage)) - parallel.Execute(m, func(start, end int) { - for k := start; k < end; k++ { - a[k+m].Mul(&a[k+m], &twiddles[stage][k]) - fr.Butterfly(&a[k], &a[k+m]) - } - }, numCPU) + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + // we need to compute the twiddles for this stage on the fly. + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + var at fr.Element + at.Exp(w, big.NewInt(int64(start))) + innerDITWithoutTwiddles(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDITWithoutTwiddles(a, w, w, 0, m, m) + } + return + } + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks/(1<<(stage))) } else { + innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) + } +} + +func innerDITWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { + if start == 0 { fr.Butterfly(&a[0], &a[m]) - for k := 1; k < m; k++ { - a[k+m].Mul(&a[k+m], &twiddles[stage][k]) - fr.Butterfly(&a[k], &a[k+m]) - } + start++ + } + for i := start; i < end; i++ { + a[i+m].Mul(&a[i+m], &twiddles[i]) + fr.Butterfly(&a[i], &a[i+m]) } } -// kerDIT8 is a kernel that process a FFT of size 8 -func kerDIT8(a []fr.Element, twiddles [][]fr.Element, stage int) { - - fr.Butterfly(&a[0], &a[1]) - fr.Butterfly(&a[2], &a[3]) - fr.Butterfly(&a[4], &a[5]) - fr.Butterfly(&a[6], &a[7]) - fr.Butterfly(&a[0], &a[2]) - a[3].Mul(&a[3], &twiddles[stage+1][1]) - fr.Butterfly(&a[1], &a[3]) - fr.Butterfly(&a[4], &a[6]) - a[7].Mul(&a[7], &twiddles[stage+1][1]) - fr.Butterfly(&a[5], &a[7]) - fr.Butterfly(&a[0], &a[4]) - a[5].Mul(&a[5], &twiddles[stage+0][1]) - fr.Butterfly(&a[1], &a[5]) - a[6].Mul(&a[6], &twiddles[stage+0][2]) - fr.Butterfly(&a[2], &a[6]) - a[7].Mul(&a[7], &twiddles[stage+0][3]) - fr.Butterfly(&a[3], &a[7]) +func innerDITWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + a[i+m].Mul(&a[i+m], &at) + fr.Butterfly(&a[i], &a[i+m]) + at.Mul(&at, &w) + } } -// kerDIF8 is a kernel that process a FFT of size 8 -func kerDIF8(a []fr.Element, twiddles [][]fr.Element, stage int) { - - fr.Butterfly(&a[0], &a[4]) - fr.Butterfly(&a[1], &a[5]) - fr.Butterfly(&a[2], &a[6]) - fr.Butterfly(&a[3], &a[7]) - a[5].Mul(&a[5], &twiddles[stage+0][1]) - a[6].Mul(&a[6], &twiddles[stage+0][2]) - a[7].Mul(&a[7], &twiddles[stage+0][3]) - fr.Butterfly(&a[0], &a[2]) - fr.Butterfly(&a[1], &a[3]) - fr.Butterfly(&a[4], &a[6]) - fr.Butterfly(&a[5], &a[7]) - a[3].Mul(&a[3], &twiddles[stage+1][1]) - a[7].Mul(&a[7], &twiddles[stage+1][1]) - fr.Butterfly(&a[0], &a[1]) - fr.Butterfly(&a[2], &a[3]) - fr.Butterfly(&a[4], &a[5]) - fr.Butterfly(&a[6], &a[7]) +func kerDIFNP_256(a []fr.Element, twiddles [][]fr.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + + innerDIFWithTwiddles(a[:256], twiddles[stage+0], 0, 128, 128) + for offset := 0; offset < 256; offset += 128 { + innerDIFWithTwiddles(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + for offset := 0; offset < 256; offset += 64 { + innerDIFWithTwiddles(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 32 { + innerDIFWithTwiddles(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 16 { + innerDIFWithTwiddles(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 8 { + innerDIFWithTwiddles(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 4 { + innerDIFWithTwiddles(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 2 { + fr.Butterfly(&a[offset], &a[offset+1]) + } +} + +func kerDITNP_256(a []fr.Element, twiddles [][]fr.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + + for offset := 0; offset < 256; offset += 2 { + fr.Butterfly(&a[offset], &a[offset+1]) + } + for offset := 0; offset < 256; offset += 4 { + innerDITWithTwiddles(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 8 { + innerDITWithTwiddles(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 16 { + innerDITWithTwiddles(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 32 { + innerDITWithTwiddles(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 64 { + innerDITWithTwiddles(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 128 { + innerDITWithTwiddles(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + innerDITWithTwiddles(a[:256], twiddles[stage+0], 0, 128, 128) } diff --git a/ecc/bw6-633/fr/fft/fft_test.go b/ecc/bw6-633/fr/fft/fft_test.go index 190028756..5e87e5516 100644 --- a/ecc/bw6-633/fr/fft/fft_test.go +++ b/ecc/bw6-633/fr/fft/fft_test.go @@ -33,208 +33,217 @@ func TestFFT(t *testing.T) { nbCosets := 3 domainWithPrecompute := NewDomain(maxSize) + domainWithoutPrecompute := NewDomain(maxSize, WithoutPrecompute()) parameters := gopter.DefaultTestParameters() parameters.MinSuccessfulTests = 5 properties := gopter.NewProperties(parameters) - properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( + for domainName, domain := range map[string]*Domain{ + "with precompute": domainWithPrecompute, + "without precompute": domainWithoutPrecompute, + } { + domainName := domainName + domain := domain + t.Logf("domain: %s", domainName) + properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFT(pol, DIF) - BitReverse(pol) + domain.FFT(pol, DIF) + BitReverse(pol) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( + properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFT(pol, DIF, OnCoset()) - BitReverse(pol) + domain.FFT(pol, DIF, OnCoset()) + BitReverse(pol) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))). - Mul(&sample, &domainWithPrecompute.FrMultiplicativeGen) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))). + Mul(&sample, &domain.FrMultiplicativeGen) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( + properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT) + BitReverse(pol) + domain.FFT(pol, DIT) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT) - domainWithPrecompute.FFTInverse(pol, DIF) - BitReverse(pol) + BitReverse(pol) + domain.FFT(pol, DIT) + domain.FFTInverse(pol, DIF) + BitReverse(pol) - check := true - for i := 0; i < len(pol); i++ { - check = check && pol[i].Equal(&backupPol[i]) - } - return check - }, - )) + check := true + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } + return check + }, + )) - properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on cosets", prop.ForAll( + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on cosets", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - check := true + check := true - for i := 1; i <= nbCosets; i++ { + for i := 1; i <= nbCosets; i++ { - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT, OnCoset()) - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset()) - BitReverse(pol) + BitReverse(pol) + domain.FFT(pol, DIT, OnCoset()) + domain.FFTInverse(pol, DIF, OnCoset()) + BitReverse(pol) - for i := 0; i < len(pol); i++ { - check = check && pol[i].Equal(&backupPol[i]) + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } } - } - return check - }, - )) + return check + }, + )) - properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( + properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFTInverse(pol, DIF) - domainWithPrecompute.FFT(pol, DIT) + domain.FFTInverse(pol, DIF) + domain.FFT(pol, DIT) - check := true - for i := 0; i < len(pol); i++ { - check = check && (pol[i] == backupPol[i]) - } - return check - }, - )) + check := true + for i := 0; i < len(pol); i++ { + check = check && (pol[i] == backupPol[i]) + } + return check + }, + )) - properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( + properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset()) - domainWithPrecompute.FFT(pol, DIT, OnCoset()) + domain.FFTInverse(pol, DIF, OnCoset()) + domain.FFT(pol, DIT, OnCoset()) - for i := 0; i < len(pol); i++ { - if !(pol[i].Equal(&backupPol[i])) { - return false + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } } - } - // compute with nbTasks == 1 - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset(), WithNbTasks(1)) - domainWithPrecompute.FFT(pol, DIT, OnCoset(), WithNbTasks(1)) + // compute with nbTasks == 1 + domain.FFTInverse(pol, DIF, OnCoset(), WithNbTasks(1)) + domain.FFT(pol, DIT, OnCoset(), WithNbTasks(1)) - for i := 0; i < len(pol); i++ { - if !(pol[i].Equal(&backupPol[i])) { - return false + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } } - } - return true - }, - )) + return true + }, + )) - properties.TestingRun(t, gopter.ConsoleReporter(false)) + properties.TestingRun(t, gopter.ConsoleReporter(false)) + } } diff --git a/ecc/bw6-633/fr/fft/options.go b/ecc/bw6-633/fr/fft/options.go index 02a6000e5..c77a1c772 100644 --- a/ecc/bw6-633/fr/fft/options.go +++ b/ecc/bw6-633/fr/fft/options.go @@ -16,7 +16,11 @@ package fft -import "runtime" +import ( + "runtime" + + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" +) // Option defines option for altering the behavior of FFT methods. // See the descriptions of functions returning instances of this type for @@ -48,7 +52,7 @@ func WithNbTasks(nbTasks int) Option { } // default options -func options(opts ...Option) fftConfig { +func fftOptions(opts ...Option) fftConfig { // apply options opt := fftConfig{ coset: false, @@ -59,3 +63,41 @@ func options(opts ...Option) fftConfig { } return opt } + +// DomainOption defines option for altering the definition of the FFT domain +// See the descriptions of functions returning instances of this type for +// particular options. +type DomainOption func(*domainConfig) + +type domainConfig struct { + shift *fr.Element + withPrecompute bool +} + +// WithShift sets the FrMultiplicativeGen of the domain. +// Default is generator of the largest 2-adic subgroup. +func WithShift(shift fr.Element) DomainOption { + return func(opt *domainConfig) { + opt.shift = new(fr.Element).Set(&shift) + } +} + +// WithoutPrecompute disables precomputation of twiddles in the domain. +// When this option is set, FFTs will be slower, but will use less memory. +func WithoutPrecompute() DomainOption { + return func(opt *domainConfig) { + opt.withPrecompute = false + } +} + +// default options +func domainOptions(opts ...DomainOption) domainConfig { + // apply options + opt := domainConfig{ + withPrecompute: true, + } + for _, option := range opts { + option(&opt) + } + return opt +} diff --git a/ecc/bw6-633/fr/iop/ratios.go b/ecc/bw6-633/fr/iop/ratios.go index 977a10571..17913fd97 100644 --- a/ecc/bw6-633/fr/iop/ratios.go +++ b/ecc/bw6-633/fr/iop/ratios.go @@ -336,20 +336,12 @@ func getSupportIdentityPermutation(nbCopies int, domain *fft.Domain) []fr.Elemen res := make([]fr.Element, uint64(nbCopies)*domain.Cardinality) sizePoly := int(domain.Cardinality) - // check if we can reuse the pre-computed twiddles from the domain. - if len(domain.Twiddles) == 0 || len(domain.Twiddles[0]) < sizePoly/2 { - res[0].SetOne() - if len(res) > 1 { - res[1].Set(&domain.Generator) - for i := 2; i < len(res); i++ { - res[i].Mul(&res[i-1], &domain.Generator) - } - } - } else { - // re-use pre-computed twiddles from the domain. - copy(res, domain.Twiddles[0]) - for i := (sizePoly / 2) - 1; i < sizePoly-1; i++ { - res[i+1].Mul(&res[i], &domain.Generator) + // TODO @gbotrel check if we can reuse the pre-computed twiddles from the domain. + res[0].SetOne() + if len(res) > 1 { + res[1].Set(&domain.Generator) + for i := 2; i < len(res); i++ { + res[i].Mul(&res[i-1], &domain.Generator) } } @@ -362,15 +354,7 @@ func getSupportIdentityPermutation(nbCopies int, domain *fft.Domain) []fr.Elemen i := i var coset fr.Element - if i == 1 { - coset = domain.FrMultiplicativeGen - } else { - if len(domain.CosetTable) > i { - coset = domain.CosetTable[i] - } else { - coset.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(i))) - } - } + coset.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(i))) go func() { parallel.Execute(sizePoly, func(start, end int) { diff --git a/ecc/bw6-756/fr/fft/domain.go b/ecc/bw6-756/fr/fft/domain.go index d1ae2d2dd..2835b05b8 100644 --- a/ecc/bw6-756/fr/fft/domain.go +++ b/ecc/bw6-756/fr/fft/domain.go @@ -17,6 +17,7 @@ package fft import ( + "errors" "io" "math/big" "math/bits" @@ -41,28 +42,33 @@ type Domain struct { FrMultiplicativeGen fr.Element // generator of Fr* FrMultiplicativeGenInv fr.Element + // this is set with the WithoutPrecompute option; + // if true, the domain does some pre-computation and stores it. + // if false, the FFT will compute the twiddles on the fly (this is less CPU efficient, but uses less memory) + withPrecompute bool + // the following slices are not serialized and are (re)computed through domain.preComputeTwiddles() - // Twiddles factor for the FFT using Generator for each stage of the recursive FFT - Twiddles [][]fr.Element + // twiddles factor for the FFT using Generator for each stage of the recursive FFT + twiddles [][]fr.Element - // Twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT - TwiddlesInv [][]fr.Element + // twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT + twiddlesInv [][]fr.Element // we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover - // CosetTable u*<1,g,..,g^(n-1)> - CosetTable []fr.Element + // cosetTable u*<1,g,..,g^(n-1)> + cosetTable []fr.Element - // CosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j - CosetTableInv []fr.Element + // cosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j + cosetTableInv []fr.Element } // NewDomain returns a subgroup with a power of 2 cardinality // cardinality >= m // shift: when specified, it's the element by which the set of root of unity is shifted. -func NewDomain(m uint64, shift ...fr.Element) *Domain { - +func NewDomain(m uint64, opts ...DomainOption) *Domain { + opt := domainOptions(opts...) domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) @@ -71,8 +77,8 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { domain.FrMultiplicativeGen.SetUint64(5) - if len(shift) != 0 { - domain.FrMultiplicativeGen.Set(&shift[0]) + if opt.shift != nil { + domain.FrMultiplicativeGen.Set(opt.shift) } domain.FrMultiplicativeGenInv.Inverse(&domain.FrMultiplicativeGen) @@ -85,7 +91,10 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { domain.CardinalityInv.SetUint64(uint64(x)).Inverse(&domain.CardinalityInv) // twiddle factors - domain.preComputeTwiddles() + domain.withPrecompute = opt.withPrecompute + if domain.withPrecompute { + domain.preComputeTwiddles() + } return domain } @@ -96,54 +105,104 @@ func Generator(m uint64) (fr.Element, error) { return fr.Generator(m) } +// Twiddles returns the twiddles factor for the FFT using Generator for each stage of the recursive FFT +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) Twiddles() ([][]fr.Element, error) { + if d.twiddles == nil { + return nil, errors.New("twiddles not precomputed") + } + return d.twiddles, nil +} + +// TwiddlesInv returns the twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) TwiddlesInv() ([][]fr.Element, error) { + if d.twiddlesInv == nil { + return nil, errors.New("twiddles not precomputed") + } + return d.twiddlesInv, nil +} + +// CosetTable returns the cosetTable u*<1,g,..,g^(n-1)> +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) CosetTable() ([]fr.Element, error) { + if d.cosetTable == nil { + return nil, errors.New("cosetTable not precomputed") + } + return d.cosetTable, nil +} + +// CosetTableInv returns the cosetTableInv u*<1,g,..,g^(n-1)> +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) CosetTableInv() ([]fr.Element, error) { + if d.cosetTableInv == nil { + return nil, errors.New("cosetTableInv not precomputed") + } + return d.cosetTableInv, nil +} + func (d *Domain) preComputeTwiddles() { // nb fft stages nbStages := uint64(bits.TrailingZeros64(d.Cardinality)) - d.Twiddles = make([][]fr.Element, nbStages) - d.TwiddlesInv = make([][]fr.Element, nbStages) - d.CosetTable = make([]fr.Element, d.Cardinality) - d.CosetTableInv = make([]fr.Element, d.Cardinality) + d.twiddles = make([][]fr.Element, nbStages) + d.twiddlesInv = make([][]fr.Element, nbStages) + d.cosetTable = make([]fr.Element, d.Cardinality) + d.cosetTableInv = make([]fr.Element, d.Cardinality) var wg sync.WaitGroup - // for each fft stage, we pre compute the twiddle factors - twiddles := func(t [][]fr.Element, omega fr.Element) { - for i := uint64(0); i < nbStages; i++ { - t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1))) - var w fr.Element - if i == 0 { - w = omega - } else { - w = t[i-1][2] - } - t[i][0] = fr.One() - t[i][1] = w - for j := 2; j < len(t[i]); j++ { - t[i][j].Mul(&t[i][j-1], &w) - } - } - wg.Done() - } - expTable := func(sqrt fr.Element, t []fr.Element) { - t[0] = fr.One() - precomputeExpTable(sqrt, t) + BuildExpTable(sqrt, t) wg.Done() } wg.Add(4) - go twiddles(d.Twiddles, d.Generator) - go twiddles(d.TwiddlesInv, d.GeneratorInv) - go expTable(d.FrMultiplicativeGen, d.CosetTable) - go expTable(d.FrMultiplicativeGenInv, d.CosetTableInv) + go func() { + buildTwiddles(d.twiddles, d.Generator, nbStages) + wg.Done() + }() + go func() { + buildTwiddles(d.twiddlesInv, d.GeneratorInv, nbStages) + wg.Done() + }() + go expTable(d.FrMultiplicativeGen, d.cosetTable) + go expTable(d.FrMultiplicativeGenInv, d.cosetTableInv) wg.Wait() } -func precomputeExpTable(w fr.Element, table []fr.Element) { +func buildTwiddles(t [][]fr.Element, omega fr.Element, nbStages uint64) { + if nbStages == 0 { + return + } + if len(t) != int(nbStages) { + panic("invalid twiddle table") + } + // we just compute the first stage + t[0] = make([]fr.Element, 1+(1<<(nbStages-1))) + BuildExpTable(omega, t[0]) + + // for the next stages, we just iterate on the first stage with larger stride + for i := uint64(1); i < nbStages; i++ { + t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1))) + k := 0 + for j := 0; j < len(t[i]); j++ { + t[i][j] = t[0][k] + k += 1 << i + } + } + +} + +// BuildExpTable precomputes the first n powers of w in parallel +// table[0] = w^0 +// table[1] = w^1 +// ... +func BuildExpTable(w fr.Element, table []fr.Element) { + table[0].SetOne() n := len(table) // see if it makes sense to parallelize exp tables pre-computation @@ -153,6 +212,7 @@ func precomputeExpTable(w fr.Element, table []fr.Element) { } // this ratio roughly correspond to the number of multiplication one can do in place of a Exp operation + // TODO @gbotrel revisit this; Exps in this context will be by a "small power of 2" so faster than this ref ratio. const ratioExpMul = 6000 / 17 if interval < ratioExpMul { @@ -194,7 +254,7 @@ func (d *Domain) WriteTo(w io.Writer) (int64, error) { enc := curve.NewEncoder(w) - toEncode := []interface{}{d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} + toEncode := []interface{}{d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute} for _, v := range toEncode { if err := enc.Encode(v); err != nil { @@ -210,7 +270,7 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { dec := curve.NewDecoder(r) - toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} + toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute} for _, v := range toDecode { if err := dec.Decode(v); err != nil { @@ -218,34 +278,9 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { } } - // twiddle factors - d.preComputeTwiddles() - - return dec.BytesRead(), nil -} - -// AsyncReadFrom attempts to decode a domain from Reader. It returns a channel that will be closed -// when the precomputation is done. -func (d *Domain) AsyncReadFrom(r io.Reader) (int64, error, chan struct{}) { - - dec := curve.NewDecoder(r) - - toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err, nil - } - } - - chDone := make(chan struct{}) - - go func() { - // twiddle factors + if d.withPrecompute { d.preComputeTwiddles() + } - close(chDone) - }() - - return dec.BytesRead(), nil, chDone + return dec.BytesRead(), nil } diff --git a/ecc/bw6-756/fr/fft/fft.go b/ecc/bw6-756/fr/fft/fft.go index 0adffeb25..20a9da3dd 100644 --- a/ecc/bw6-756/fr/fft/fft.go +++ b/ecc/bw6-756/fr/fft/fft.go @@ -19,6 +19,7 @@ package fft import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" "math/bits" "github.com/consensys/gnark-crypto/ecc/bw6-756/fr" @@ -40,41 +41,71 @@ const butterflyThreshold = 16 // if decimation == DIF (decimation in frequency), the output will be in bit-reversed order func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) { - opt := options(opts...) + opt := fftOptions(opts...) + + // find the stage where we should stop spawning go routines in our recursive calls + // (ie when we have as many go routines running as we have available CPUs) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) + if opt.nbTasks == 1 { + maxSplits = -1 + } // if coset != 0, scale by coset table if opt.coset { if decimation == DIT { // scale by coset table (in bit reversed order) + cosetTable := domain.cosetTable + if !domain.withPrecompute { + // we need to build the full table or do a bit reverse dance. + cosetTable = make([]fr.Element, len(a)) + BuildExpTable(domain.FrMultiplicativeGen, cosetTable) + } parallel.Execute(len(a), func(start, end int) { n := uint64(len(a)) nn := uint64(64 - bits.TrailingZeros64(n)) for i := start; i < end; i++ { irev := int(bits.Reverse64(uint64(i)) >> nn) - a[i].Mul(&a[i], &domain.CosetTable[irev]) + a[i].Mul(&a[i], &cosetTable[irev]) } }, opt.nbTasks) } else { - parallel.Execute(len(a), func(start, end int) { - for i := start; i < end; i++ { - a[i].Mul(&a[i], &domain.CosetTable[i]) - } - }, opt.nbTasks) + if domain.withPrecompute { + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &domain.cosetTable[i]) + } + }, opt.nbTasks) + } else { + c := domain.FrMultiplicativeGen + parallel.Execute(len(a), func(start, end int) { + var at fr.Element + at.Exp(c, big.NewInt(int64(start))) + for i := start; i < end; i++ { + a[i].Mul(&a[i], &at) + at.Mul(&at, &c) + } + }, opt.nbTasks) + } + } } - // find the stage where we should stop spawning go routines in our recursive calls - // (ie when we have as many go routines running as we have available CPUs) - maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) - if opt.nbTasks == 1 { - maxSplits = -1 + twiddles := domain.twiddles + twiddlesStartStage := 0 + if !domain.withPrecompute { + twiddlesStartStage = 3 + nbStages := int(bits.TrailingZeros64(domain.Cardinality)) + twiddles = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1<> nn) - a[i].Mul(&a[i], &domain.CosetTableInv[irev]). + a[i].Mul(&a[i], &cosetTableInv[irev]). Mul(&a[i], &domain.CardinalityInv) } }, opt.nbTasks) } -func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}, nbTasks int) { +func difFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { if chDone != nil { defer close(chDone) } @@ -144,29 +206,38 @@ func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon n := len(a) if n == 1 { return - } else if n == 8 { - kerDIF8(a, twiddles, stage) + } else if n == 256 && stage >= twiddlesStartStage { + kerDIFNP_256(a, twiddles, stage-twiddlesStartStage) return } m := n >> 1 - // if stage < maxSplits, we parallelize this butterfly - // but we have only numCPU / stage cpus available - if (m > butterflyThreshold) && (stage < maxSplits) { - // 1 << stage == estimated used CPUs - numCPU := nbTasks / (1 << (stage)) - parallel.Execute(m, func(start, end int) { - for i := start; i < end; i++ { - fr.Butterfly(&a[i], &a[i+m]) - a[i+m].Mul(&a[i+m], &twiddles[stage][i]) - } - }, numCPU) + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + var at fr.Element + at.Exp(w, big.NewInt(int64(start))) + innerDIFWithoutTwiddles(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDIFWithoutTwiddles(a, w, w, 0, m, m) + } + // compute next twiddle + w.Square(&w) } else { - // i == 0 - fr.Butterfly(&a[0], &a[m]) - for i := 1; i < m; i++ { - fr.Butterfly(&a[i], &a[i+m]) - a[i+m].Mul(&a[i+m], &twiddles[stage][i]) + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks/(1<<(stage))) + } else { + innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) } } @@ -177,104 +248,170 @@ func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon nextStage := stage + 1 if stage < maxSplits { chDone := make(chan struct{}, 1) - go difFFT(a[m:n], twiddles, nextStage, maxSplits, chDone, nbTasks) - difFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) + go difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) <-chDone } else { - difFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) - difFFT(a[m:n], twiddles, nextStage, maxSplits, nil, nbTasks) + difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) } } -func ditFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}, nbTasks int) { +func innerDIFWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fr.Butterfly(&a[i], &a[i+m]) + a[i+m].Mul(&a[i+m], &twiddles[i]) + } +} + +func innerDIFWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fr.Butterfly(&a[i], &a[i+m]) + a[i+m].Mul(&a[i+m], &at) + at.Mul(&at, &w) + } +} + +func ditFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { if chDone != nil { defer close(chDone) } n := len(a) if n == 1 { return - } else if n == 8 { - kerDIT8(a, twiddles, stage) + } else if n == 256 && stage >= twiddlesStartStage { + kerDITNP_256(a, twiddles, stage-twiddlesStartStage) return } m := n >> 1 nextStage := stage + 1 + nextW := w + nextW.Square(&nextW) if stage < maxSplits { // that's the only time we fire go routines chDone := make(chan struct{}, 1) - go ditFFT(a[m:], twiddles, nextStage, maxSplits, chDone, nbTasks) - ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) + go ditFFT(a[m:], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + ditFFT(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) <-chDone } else { - ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) - ditFFT(a[m:n], twiddles, nextStage, maxSplits, nil, nbTasks) - + ditFFT(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + ditFFT(a[m:n], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) } - // if stage < maxSplits, we parallelize this butterfly - // but we have only numCPU / stage cpus available - if (m > butterflyThreshold) && (stage < maxSplits) { - // 1 << stage == estimated used CPUs - numCPU := nbTasks / (1 << (stage)) - parallel.Execute(m, func(start, end int) { - for k := start; k < end; k++ { - a[k+m].Mul(&a[k+m], &twiddles[stage][k]) - fr.Butterfly(&a[k], &a[k+m]) - } - }, numCPU) + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + // we need to compute the twiddles for this stage on the fly. + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + var at fr.Element + at.Exp(w, big.NewInt(int64(start))) + innerDITWithoutTwiddles(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDITWithoutTwiddles(a, w, w, 0, m, m) + } + return + } + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks/(1<<(stage))) } else { + innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) + } +} + +func innerDITWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { + if start == 0 { fr.Butterfly(&a[0], &a[m]) - for k := 1; k < m; k++ { - a[k+m].Mul(&a[k+m], &twiddles[stage][k]) - fr.Butterfly(&a[k], &a[k+m]) - } + start++ + } + for i := start; i < end; i++ { + a[i+m].Mul(&a[i+m], &twiddles[i]) + fr.Butterfly(&a[i], &a[i+m]) } } -// kerDIT8 is a kernel that process a FFT of size 8 -func kerDIT8(a []fr.Element, twiddles [][]fr.Element, stage int) { - - fr.Butterfly(&a[0], &a[1]) - fr.Butterfly(&a[2], &a[3]) - fr.Butterfly(&a[4], &a[5]) - fr.Butterfly(&a[6], &a[7]) - fr.Butterfly(&a[0], &a[2]) - a[3].Mul(&a[3], &twiddles[stage+1][1]) - fr.Butterfly(&a[1], &a[3]) - fr.Butterfly(&a[4], &a[6]) - a[7].Mul(&a[7], &twiddles[stage+1][1]) - fr.Butterfly(&a[5], &a[7]) - fr.Butterfly(&a[0], &a[4]) - a[5].Mul(&a[5], &twiddles[stage+0][1]) - fr.Butterfly(&a[1], &a[5]) - a[6].Mul(&a[6], &twiddles[stage+0][2]) - fr.Butterfly(&a[2], &a[6]) - a[7].Mul(&a[7], &twiddles[stage+0][3]) - fr.Butterfly(&a[3], &a[7]) +func innerDITWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + a[i+m].Mul(&a[i+m], &at) + fr.Butterfly(&a[i], &a[i+m]) + at.Mul(&at, &w) + } } -// kerDIF8 is a kernel that process a FFT of size 8 -func kerDIF8(a []fr.Element, twiddles [][]fr.Element, stage int) { - - fr.Butterfly(&a[0], &a[4]) - fr.Butterfly(&a[1], &a[5]) - fr.Butterfly(&a[2], &a[6]) - fr.Butterfly(&a[3], &a[7]) - a[5].Mul(&a[5], &twiddles[stage+0][1]) - a[6].Mul(&a[6], &twiddles[stage+0][2]) - a[7].Mul(&a[7], &twiddles[stage+0][3]) - fr.Butterfly(&a[0], &a[2]) - fr.Butterfly(&a[1], &a[3]) - fr.Butterfly(&a[4], &a[6]) - fr.Butterfly(&a[5], &a[7]) - a[3].Mul(&a[3], &twiddles[stage+1][1]) - a[7].Mul(&a[7], &twiddles[stage+1][1]) - fr.Butterfly(&a[0], &a[1]) - fr.Butterfly(&a[2], &a[3]) - fr.Butterfly(&a[4], &a[5]) - fr.Butterfly(&a[6], &a[7]) +func kerDIFNP_256(a []fr.Element, twiddles [][]fr.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + + innerDIFWithTwiddles(a[:256], twiddles[stage+0], 0, 128, 128) + for offset := 0; offset < 256; offset += 128 { + innerDIFWithTwiddles(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + for offset := 0; offset < 256; offset += 64 { + innerDIFWithTwiddles(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 32 { + innerDIFWithTwiddles(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 16 { + innerDIFWithTwiddles(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 8 { + innerDIFWithTwiddles(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 4 { + innerDIFWithTwiddles(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 2 { + fr.Butterfly(&a[offset], &a[offset+1]) + } +} + +func kerDITNP_256(a []fr.Element, twiddles [][]fr.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + + for offset := 0; offset < 256; offset += 2 { + fr.Butterfly(&a[offset], &a[offset+1]) + } + for offset := 0; offset < 256; offset += 4 { + innerDITWithTwiddles(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 8 { + innerDITWithTwiddles(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 16 { + innerDITWithTwiddles(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 32 { + innerDITWithTwiddles(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 64 { + innerDITWithTwiddles(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 128 { + innerDITWithTwiddles(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + innerDITWithTwiddles(a[:256], twiddles[stage+0], 0, 128, 128) } diff --git a/ecc/bw6-756/fr/fft/fft_test.go b/ecc/bw6-756/fr/fft/fft_test.go index ae23f5bcb..8f8db5c72 100644 --- a/ecc/bw6-756/fr/fft/fft_test.go +++ b/ecc/bw6-756/fr/fft/fft_test.go @@ -33,208 +33,217 @@ func TestFFT(t *testing.T) { nbCosets := 3 domainWithPrecompute := NewDomain(maxSize) + domainWithoutPrecompute := NewDomain(maxSize, WithoutPrecompute()) parameters := gopter.DefaultTestParameters() parameters.MinSuccessfulTests = 5 properties := gopter.NewProperties(parameters) - properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( + for domainName, domain := range map[string]*Domain{ + "with precompute": domainWithPrecompute, + "without precompute": domainWithoutPrecompute, + } { + domainName := domainName + domain := domain + t.Logf("domain: %s", domainName) + properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFT(pol, DIF) - BitReverse(pol) + domain.FFT(pol, DIF) + BitReverse(pol) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( + properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFT(pol, DIF, OnCoset()) - BitReverse(pol) + domain.FFT(pol, DIF, OnCoset()) + BitReverse(pol) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))). - Mul(&sample, &domainWithPrecompute.FrMultiplicativeGen) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))). + Mul(&sample, &domain.FrMultiplicativeGen) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( + properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT) + BitReverse(pol) + domain.FFT(pol, DIT) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT) - domainWithPrecompute.FFTInverse(pol, DIF) - BitReverse(pol) + BitReverse(pol) + domain.FFT(pol, DIT) + domain.FFTInverse(pol, DIF) + BitReverse(pol) - check := true - for i := 0; i < len(pol); i++ { - check = check && pol[i].Equal(&backupPol[i]) - } - return check - }, - )) + check := true + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } + return check + }, + )) - properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on cosets", prop.ForAll( + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on cosets", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - check := true + check := true - for i := 1; i <= nbCosets; i++ { + for i := 1; i <= nbCosets; i++ { - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT, OnCoset()) - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset()) - BitReverse(pol) + BitReverse(pol) + domain.FFT(pol, DIT, OnCoset()) + domain.FFTInverse(pol, DIF, OnCoset()) + BitReverse(pol) - for i := 0; i < len(pol); i++ { - check = check && pol[i].Equal(&backupPol[i]) + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } } - } - return check - }, - )) + return check + }, + )) - properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( + properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFTInverse(pol, DIF) - domainWithPrecompute.FFT(pol, DIT) + domain.FFTInverse(pol, DIF) + domain.FFT(pol, DIT) - check := true - for i := 0; i < len(pol); i++ { - check = check && (pol[i] == backupPol[i]) - } - return check - }, - )) + check := true + for i := 0; i < len(pol); i++ { + check = check && (pol[i] == backupPol[i]) + } + return check + }, + )) - properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( + properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset()) - domainWithPrecompute.FFT(pol, DIT, OnCoset()) + domain.FFTInverse(pol, DIF, OnCoset()) + domain.FFT(pol, DIT, OnCoset()) - for i := 0; i < len(pol); i++ { - if !(pol[i].Equal(&backupPol[i])) { - return false + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } } - } - // compute with nbTasks == 1 - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset(), WithNbTasks(1)) - domainWithPrecompute.FFT(pol, DIT, OnCoset(), WithNbTasks(1)) + // compute with nbTasks == 1 + domain.FFTInverse(pol, DIF, OnCoset(), WithNbTasks(1)) + domain.FFT(pol, DIT, OnCoset(), WithNbTasks(1)) - for i := 0; i < len(pol); i++ { - if !(pol[i].Equal(&backupPol[i])) { - return false + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } } - } - return true - }, - )) + return true + }, + )) - properties.TestingRun(t, gopter.ConsoleReporter(false)) + properties.TestingRun(t, gopter.ConsoleReporter(false)) + } } diff --git a/ecc/bw6-756/fr/fft/options.go b/ecc/bw6-756/fr/fft/options.go index 02a6000e5..9091f81fc 100644 --- a/ecc/bw6-756/fr/fft/options.go +++ b/ecc/bw6-756/fr/fft/options.go @@ -16,7 +16,11 @@ package fft -import "runtime" +import ( + "runtime" + + "github.com/consensys/gnark-crypto/ecc/bw6-756/fr" +) // Option defines option for altering the behavior of FFT methods. // See the descriptions of functions returning instances of this type for @@ -48,7 +52,7 @@ func WithNbTasks(nbTasks int) Option { } // default options -func options(opts ...Option) fftConfig { +func fftOptions(opts ...Option) fftConfig { // apply options opt := fftConfig{ coset: false, @@ -59,3 +63,41 @@ func options(opts ...Option) fftConfig { } return opt } + +// DomainOption defines option for altering the definition of the FFT domain +// See the descriptions of functions returning instances of this type for +// particular options. +type DomainOption func(*domainConfig) + +type domainConfig struct { + shift *fr.Element + withPrecompute bool +} + +// WithShift sets the FrMultiplicativeGen of the domain. +// Default is generator of the largest 2-adic subgroup. +func WithShift(shift fr.Element) DomainOption { + return func(opt *domainConfig) { + opt.shift = new(fr.Element).Set(&shift) + } +} + +// WithoutPrecompute disables precomputation of twiddles in the domain. +// When this option is set, FFTs will be slower, but will use less memory. +func WithoutPrecompute() DomainOption { + return func(opt *domainConfig) { + opt.withPrecompute = false + } +} + +// default options +func domainOptions(opts ...DomainOption) domainConfig { + // apply options + opt := domainConfig{ + withPrecompute: true, + } + for _, option := range opts { + option(&opt) + } + return opt +} diff --git a/ecc/bw6-756/fr/iop/ratios.go b/ecc/bw6-756/fr/iop/ratios.go index 9e1541129..449ae8f53 100644 --- a/ecc/bw6-756/fr/iop/ratios.go +++ b/ecc/bw6-756/fr/iop/ratios.go @@ -336,20 +336,12 @@ func getSupportIdentityPermutation(nbCopies int, domain *fft.Domain) []fr.Elemen res := make([]fr.Element, uint64(nbCopies)*domain.Cardinality) sizePoly := int(domain.Cardinality) - // check if we can reuse the pre-computed twiddles from the domain. - if len(domain.Twiddles) == 0 || len(domain.Twiddles[0]) < sizePoly/2 { - res[0].SetOne() - if len(res) > 1 { - res[1].Set(&domain.Generator) - for i := 2; i < len(res); i++ { - res[i].Mul(&res[i-1], &domain.Generator) - } - } - } else { - // re-use pre-computed twiddles from the domain. - copy(res, domain.Twiddles[0]) - for i := (sizePoly / 2) - 1; i < sizePoly-1; i++ { - res[i+1].Mul(&res[i], &domain.Generator) + // TODO @gbotrel check if we can reuse the pre-computed twiddles from the domain. + res[0].SetOne() + if len(res) > 1 { + res[1].Set(&domain.Generator) + for i := 2; i < len(res); i++ { + res[i].Mul(&res[i-1], &domain.Generator) } } @@ -362,15 +354,7 @@ func getSupportIdentityPermutation(nbCopies int, domain *fft.Domain) []fr.Elemen i := i var coset fr.Element - if i == 1 { - coset = domain.FrMultiplicativeGen - } else { - if len(domain.CosetTable) > i { - coset = domain.CosetTable[i] - } else { - coset.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(i))) - } - } + coset.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(i))) go func() { parallel.Execute(sizePoly, func(start, end int) { diff --git a/ecc/bw6-761/fr/fft/domain.go b/ecc/bw6-761/fr/fft/domain.go index d40c8ed5a..faac86a33 100644 --- a/ecc/bw6-761/fr/fft/domain.go +++ b/ecc/bw6-761/fr/fft/domain.go @@ -17,6 +17,7 @@ package fft import ( + "errors" "io" "math/big" "math/bits" @@ -41,28 +42,33 @@ type Domain struct { FrMultiplicativeGen fr.Element // generator of Fr* FrMultiplicativeGenInv fr.Element + // this is set with the WithoutPrecompute option; + // if true, the domain does some pre-computation and stores it. + // if false, the FFT will compute the twiddles on the fly (this is less CPU efficient, but uses less memory) + withPrecompute bool + // the following slices are not serialized and are (re)computed through domain.preComputeTwiddles() - // Twiddles factor for the FFT using Generator for each stage of the recursive FFT - Twiddles [][]fr.Element + // twiddles factor for the FFT using Generator for each stage of the recursive FFT + twiddles [][]fr.Element - // Twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT - TwiddlesInv [][]fr.Element + // twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT + twiddlesInv [][]fr.Element // we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover - // CosetTable u*<1,g,..,g^(n-1)> - CosetTable []fr.Element + // cosetTable u*<1,g,..,g^(n-1)> + cosetTable []fr.Element - // CosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j - CosetTableInv []fr.Element + // cosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j + cosetTableInv []fr.Element } // NewDomain returns a subgroup with a power of 2 cardinality // cardinality >= m // shift: when specified, it's the element by which the set of root of unity is shifted. -func NewDomain(m uint64, shift ...fr.Element) *Domain { - +func NewDomain(m uint64, opts ...DomainOption) *Domain { + opt := domainOptions(opts...) domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) @@ -71,8 +77,8 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { domain.FrMultiplicativeGen.SetUint64(15) - if len(shift) != 0 { - domain.FrMultiplicativeGen.Set(&shift[0]) + if opt.shift != nil { + domain.FrMultiplicativeGen.Set(opt.shift) } domain.FrMultiplicativeGenInv.Inverse(&domain.FrMultiplicativeGen) @@ -85,7 +91,10 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { domain.CardinalityInv.SetUint64(uint64(x)).Inverse(&domain.CardinalityInv) // twiddle factors - domain.preComputeTwiddles() + domain.withPrecompute = opt.withPrecompute + if domain.withPrecompute { + domain.preComputeTwiddles() + } return domain } @@ -96,54 +105,104 @@ func Generator(m uint64) (fr.Element, error) { return fr.Generator(m) } +// Twiddles returns the twiddles factor for the FFT using Generator for each stage of the recursive FFT +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) Twiddles() ([][]fr.Element, error) { + if d.twiddles == nil { + return nil, errors.New("twiddles not precomputed") + } + return d.twiddles, nil +} + +// TwiddlesInv returns the twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) TwiddlesInv() ([][]fr.Element, error) { + if d.twiddlesInv == nil { + return nil, errors.New("twiddles not precomputed") + } + return d.twiddlesInv, nil +} + +// CosetTable returns the cosetTable u*<1,g,..,g^(n-1)> +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) CosetTable() ([]fr.Element, error) { + if d.cosetTable == nil { + return nil, errors.New("cosetTable not precomputed") + } + return d.cosetTable, nil +} + +// CosetTableInv returns the cosetTableInv u*<1,g,..,g^(n-1)> +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) CosetTableInv() ([]fr.Element, error) { + if d.cosetTableInv == nil { + return nil, errors.New("cosetTableInv not precomputed") + } + return d.cosetTableInv, nil +} + func (d *Domain) preComputeTwiddles() { // nb fft stages nbStages := uint64(bits.TrailingZeros64(d.Cardinality)) - d.Twiddles = make([][]fr.Element, nbStages) - d.TwiddlesInv = make([][]fr.Element, nbStages) - d.CosetTable = make([]fr.Element, d.Cardinality) - d.CosetTableInv = make([]fr.Element, d.Cardinality) + d.twiddles = make([][]fr.Element, nbStages) + d.twiddlesInv = make([][]fr.Element, nbStages) + d.cosetTable = make([]fr.Element, d.Cardinality) + d.cosetTableInv = make([]fr.Element, d.Cardinality) var wg sync.WaitGroup - // for each fft stage, we pre compute the twiddle factors - twiddles := func(t [][]fr.Element, omega fr.Element) { - for i := uint64(0); i < nbStages; i++ { - t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1))) - var w fr.Element - if i == 0 { - w = omega - } else { - w = t[i-1][2] - } - t[i][0] = fr.One() - t[i][1] = w - for j := 2; j < len(t[i]); j++ { - t[i][j].Mul(&t[i][j-1], &w) - } - } - wg.Done() - } - expTable := func(sqrt fr.Element, t []fr.Element) { - t[0] = fr.One() - precomputeExpTable(sqrt, t) + BuildExpTable(sqrt, t) wg.Done() } wg.Add(4) - go twiddles(d.Twiddles, d.Generator) - go twiddles(d.TwiddlesInv, d.GeneratorInv) - go expTable(d.FrMultiplicativeGen, d.CosetTable) - go expTable(d.FrMultiplicativeGenInv, d.CosetTableInv) + go func() { + buildTwiddles(d.twiddles, d.Generator, nbStages) + wg.Done() + }() + go func() { + buildTwiddles(d.twiddlesInv, d.GeneratorInv, nbStages) + wg.Done() + }() + go expTable(d.FrMultiplicativeGen, d.cosetTable) + go expTable(d.FrMultiplicativeGenInv, d.cosetTableInv) wg.Wait() } -func precomputeExpTable(w fr.Element, table []fr.Element) { +func buildTwiddles(t [][]fr.Element, omega fr.Element, nbStages uint64) { + if nbStages == 0 { + return + } + if len(t) != int(nbStages) { + panic("invalid twiddle table") + } + // we just compute the first stage + t[0] = make([]fr.Element, 1+(1<<(nbStages-1))) + BuildExpTable(omega, t[0]) + + // for the next stages, we just iterate on the first stage with larger stride + for i := uint64(1); i < nbStages; i++ { + t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1))) + k := 0 + for j := 0; j < len(t[i]); j++ { + t[i][j] = t[0][k] + k += 1 << i + } + } + +} + +// BuildExpTable precomputes the first n powers of w in parallel +// table[0] = w^0 +// table[1] = w^1 +// ... +func BuildExpTable(w fr.Element, table []fr.Element) { + table[0].SetOne() n := len(table) // see if it makes sense to parallelize exp tables pre-computation @@ -153,6 +212,7 @@ func precomputeExpTable(w fr.Element, table []fr.Element) { } // this ratio roughly correspond to the number of multiplication one can do in place of a Exp operation + // TODO @gbotrel revisit this; Exps in this context will be by a "small power of 2" so faster than this ref ratio. const ratioExpMul = 6000 / 17 if interval < ratioExpMul { @@ -194,7 +254,7 @@ func (d *Domain) WriteTo(w io.Writer) (int64, error) { enc := curve.NewEncoder(w) - toEncode := []interface{}{d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} + toEncode := []interface{}{d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute} for _, v := range toEncode { if err := enc.Encode(v); err != nil { @@ -210,7 +270,7 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { dec := curve.NewDecoder(r) - toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} + toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute} for _, v := range toDecode { if err := dec.Decode(v); err != nil { @@ -218,34 +278,9 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { } } - // twiddle factors - d.preComputeTwiddles() - - return dec.BytesRead(), nil -} - -// AsyncReadFrom attempts to decode a domain from Reader. It returns a channel that will be closed -// when the precomputation is done. -func (d *Domain) AsyncReadFrom(r io.Reader) (int64, error, chan struct{}) { - - dec := curve.NewDecoder(r) - - toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err, nil - } - } - - chDone := make(chan struct{}) - - go func() { - // twiddle factors + if d.withPrecompute { d.preComputeTwiddles() + } - close(chDone) - }() - - return dec.BytesRead(), nil, chDone + return dec.BytesRead(), nil } diff --git a/ecc/bw6-761/fr/fft/fft.go b/ecc/bw6-761/fr/fft/fft.go index bc32d2165..bcc5f538f 100644 --- a/ecc/bw6-761/fr/fft/fft.go +++ b/ecc/bw6-761/fr/fft/fft.go @@ -19,6 +19,7 @@ package fft import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" "math/bits" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" @@ -40,41 +41,71 @@ const butterflyThreshold = 16 // if decimation == DIF (decimation in frequency), the output will be in bit-reversed order func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) { - opt := options(opts...) + opt := fftOptions(opts...) + + // find the stage where we should stop spawning go routines in our recursive calls + // (ie when we have as many go routines running as we have available CPUs) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) + if opt.nbTasks == 1 { + maxSplits = -1 + } // if coset != 0, scale by coset table if opt.coset { if decimation == DIT { // scale by coset table (in bit reversed order) + cosetTable := domain.cosetTable + if !domain.withPrecompute { + // we need to build the full table or do a bit reverse dance. + cosetTable = make([]fr.Element, len(a)) + BuildExpTable(domain.FrMultiplicativeGen, cosetTable) + } parallel.Execute(len(a), func(start, end int) { n := uint64(len(a)) nn := uint64(64 - bits.TrailingZeros64(n)) for i := start; i < end; i++ { irev := int(bits.Reverse64(uint64(i)) >> nn) - a[i].Mul(&a[i], &domain.CosetTable[irev]) + a[i].Mul(&a[i], &cosetTable[irev]) } }, opt.nbTasks) } else { - parallel.Execute(len(a), func(start, end int) { - for i := start; i < end; i++ { - a[i].Mul(&a[i], &domain.CosetTable[i]) - } - }, opt.nbTasks) + if domain.withPrecompute { + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &domain.cosetTable[i]) + } + }, opt.nbTasks) + } else { + c := domain.FrMultiplicativeGen + parallel.Execute(len(a), func(start, end int) { + var at fr.Element + at.Exp(c, big.NewInt(int64(start))) + for i := start; i < end; i++ { + a[i].Mul(&a[i], &at) + at.Mul(&at, &c) + } + }, opt.nbTasks) + } + } } - // find the stage where we should stop spawning go routines in our recursive calls - // (ie when we have as many go routines running as we have available CPUs) - maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) - if opt.nbTasks == 1 { - maxSplits = -1 + twiddles := domain.twiddles + twiddlesStartStage := 0 + if !domain.withPrecompute { + twiddlesStartStage = 3 + nbStages := int(bits.TrailingZeros64(domain.Cardinality)) + twiddles = make([][]fr.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1<> nn) - a[i].Mul(&a[i], &domain.CosetTableInv[irev]). + a[i].Mul(&a[i], &cosetTableInv[irev]). Mul(&a[i], &domain.CardinalityInv) } }, opt.nbTasks) } -func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}, nbTasks int) { +func difFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { if chDone != nil { defer close(chDone) } @@ -144,29 +206,38 @@ func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon n := len(a) if n == 1 { return - } else if n == 8 { - kerDIF8(a, twiddles, stage) + } else if n == 256 && stage >= twiddlesStartStage { + kerDIFNP_256(a, twiddles, stage-twiddlesStartStage) return } m := n >> 1 - // if stage < maxSplits, we parallelize this butterfly - // but we have only numCPU / stage cpus available - if (m > butterflyThreshold) && (stage < maxSplits) { - // 1 << stage == estimated used CPUs - numCPU := nbTasks / (1 << (stage)) - parallel.Execute(m, func(start, end int) { - for i := start; i < end; i++ { - fr.Butterfly(&a[i], &a[i+m]) - a[i+m].Mul(&a[i+m], &twiddles[stage][i]) - } - }, numCPU) + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + var at fr.Element + at.Exp(w, big.NewInt(int64(start))) + innerDIFWithoutTwiddles(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDIFWithoutTwiddles(a, w, w, 0, m, m) + } + // compute next twiddle + w.Square(&w) } else { - // i == 0 - fr.Butterfly(&a[0], &a[m]) - for i := 1; i < m; i++ { - fr.Butterfly(&a[i], &a[i+m]) - a[i+m].Mul(&a[i+m], &twiddles[stage][i]) + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks/(1<<(stage))) + } else { + innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) } } @@ -177,104 +248,170 @@ func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon nextStage := stage + 1 if stage < maxSplits { chDone := make(chan struct{}, 1) - go difFFT(a[m:n], twiddles, nextStage, maxSplits, chDone, nbTasks) - difFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) + go difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) <-chDone } else { - difFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) - difFFT(a[m:n], twiddles, nextStage, maxSplits, nil, nbTasks) + difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) } } -func ditFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}, nbTasks int) { +func innerDIFWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fr.Butterfly(&a[i], &a[i+m]) + a[i+m].Mul(&a[i+m], &twiddles[i]) + } +} + +func innerDIFWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fr.Butterfly(&a[i], &a[i+m]) + a[i+m].Mul(&a[i+m], &at) + at.Mul(&at, &w) + } +} + +func ditFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { if chDone != nil { defer close(chDone) } n := len(a) if n == 1 { return - } else if n == 8 { - kerDIT8(a, twiddles, stage) + } else if n == 256 && stage >= twiddlesStartStage { + kerDITNP_256(a, twiddles, stage-twiddlesStartStage) return } m := n >> 1 nextStage := stage + 1 + nextW := w + nextW.Square(&nextW) if stage < maxSplits { // that's the only time we fire go routines chDone := make(chan struct{}, 1) - go ditFFT(a[m:], twiddles, nextStage, maxSplits, chDone, nbTasks) - ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) + go ditFFT(a[m:], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + ditFFT(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) <-chDone } else { - ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) - ditFFT(a[m:n], twiddles, nextStage, maxSplits, nil, nbTasks) - + ditFFT(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + ditFFT(a[m:n], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) } - // if stage < maxSplits, we parallelize this butterfly - // but we have only numCPU / stage cpus available - if (m > butterflyThreshold) && (stage < maxSplits) { - // 1 << stage == estimated used CPUs - numCPU := nbTasks / (1 << (stage)) - parallel.Execute(m, func(start, end int) { - for k := start; k < end; k++ { - a[k+m].Mul(&a[k+m], &twiddles[stage][k]) - fr.Butterfly(&a[k], &a[k+m]) - } - }, numCPU) + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + // we need to compute the twiddles for this stage on the fly. + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + var at fr.Element + at.Exp(w, big.NewInt(int64(start))) + innerDITWithoutTwiddles(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDITWithoutTwiddles(a, w, w, 0, m, m) + } + return + } + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks/(1<<(stage))) } else { + innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) + } +} + +func innerDITWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { + if start == 0 { fr.Butterfly(&a[0], &a[m]) - for k := 1; k < m; k++ { - a[k+m].Mul(&a[k+m], &twiddles[stage][k]) - fr.Butterfly(&a[k], &a[k+m]) - } + start++ + } + for i := start; i < end; i++ { + a[i+m].Mul(&a[i+m], &twiddles[i]) + fr.Butterfly(&a[i], &a[i+m]) } } -// kerDIT8 is a kernel that process a FFT of size 8 -func kerDIT8(a []fr.Element, twiddles [][]fr.Element, stage int) { - - fr.Butterfly(&a[0], &a[1]) - fr.Butterfly(&a[2], &a[3]) - fr.Butterfly(&a[4], &a[5]) - fr.Butterfly(&a[6], &a[7]) - fr.Butterfly(&a[0], &a[2]) - a[3].Mul(&a[3], &twiddles[stage+1][1]) - fr.Butterfly(&a[1], &a[3]) - fr.Butterfly(&a[4], &a[6]) - a[7].Mul(&a[7], &twiddles[stage+1][1]) - fr.Butterfly(&a[5], &a[7]) - fr.Butterfly(&a[0], &a[4]) - a[5].Mul(&a[5], &twiddles[stage+0][1]) - fr.Butterfly(&a[1], &a[5]) - a[6].Mul(&a[6], &twiddles[stage+0][2]) - fr.Butterfly(&a[2], &a[6]) - a[7].Mul(&a[7], &twiddles[stage+0][3]) - fr.Butterfly(&a[3], &a[7]) +func innerDITWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + a[i+m].Mul(&a[i+m], &at) + fr.Butterfly(&a[i], &a[i+m]) + at.Mul(&at, &w) + } } -// kerDIF8 is a kernel that process a FFT of size 8 -func kerDIF8(a []fr.Element, twiddles [][]fr.Element, stage int) { - - fr.Butterfly(&a[0], &a[4]) - fr.Butterfly(&a[1], &a[5]) - fr.Butterfly(&a[2], &a[6]) - fr.Butterfly(&a[3], &a[7]) - a[5].Mul(&a[5], &twiddles[stage+0][1]) - a[6].Mul(&a[6], &twiddles[stage+0][2]) - a[7].Mul(&a[7], &twiddles[stage+0][3]) - fr.Butterfly(&a[0], &a[2]) - fr.Butterfly(&a[1], &a[3]) - fr.Butterfly(&a[4], &a[6]) - fr.Butterfly(&a[5], &a[7]) - a[3].Mul(&a[3], &twiddles[stage+1][1]) - a[7].Mul(&a[7], &twiddles[stage+1][1]) - fr.Butterfly(&a[0], &a[1]) - fr.Butterfly(&a[2], &a[3]) - fr.Butterfly(&a[4], &a[5]) - fr.Butterfly(&a[6], &a[7]) +func kerDIFNP_256(a []fr.Element, twiddles [][]fr.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + + innerDIFWithTwiddles(a[:256], twiddles[stage+0], 0, 128, 128) + for offset := 0; offset < 256; offset += 128 { + innerDIFWithTwiddles(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + for offset := 0; offset < 256; offset += 64 { + innerDIFWithTwiddles(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 32 { + innerDIFWithTwiddles(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 16 { + innerDIFWithTwiddles(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 8 { + innerDIFWithTwiddles(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 4 { + innerDIFWithTwiddles(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 2 { + fr.Butterfly(&a[offset], &a[offset+1]) + } +} + +func kerDITNP_256(a []fr.Element, twiddles [][]fr.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + + for offset := 0; offset < 256; offset += 2 { + fr.Butterfly(&a[offset], &a[offset+1]) + } + for offset := 0; offset < 256; offset += 4 { + innerDITWithTwiddles(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 8 { + innerDITWithTwiddles(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 16 { + innerDITWithTwiddles(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 32 { + innerDITWithTwiddles(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 64 { + innerDITWithTwiddles(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 128 { + innerDITWithTwiddles(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + innerDITWithTwiddles(a[:256], twiddles[stage+0], 0, 128, 128) } diff --git a/ecc/bw6-761/fr/fft/fft_test.go b/ecc/bw6-761/fr/fft/fft_test.go index 30b67a45b..ff8de863d 100644 --- a/ecc/bw6-761/fr/fft/fft_test.go +++ b/ecc/bw6-761/fr/fft/fft_test.go @@ -33,208 +33,217 @@ func TestFFT(t *testing.T) { nbCosets := 3 domainWithPrecompute := NewDomain(maxSize) + domainWithoutPrecompute := NewDomain(maxSize, WithoutPrecompute()) parameters := gopter.DefaultTestParameters() parameters.MinSuccessfulTests = 5 properties := gopter.NewProperties(parameters) - properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( + for domainName, domain := range map[string]*Domain{ + "with precompute": domainWithPrecompute, + "without precompute": domainWithoutPrecompute, + } { + domainName := domainName + domain := domain + t.Logf("domain: %s", domainName) + properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFT(pol, DIF) - BitReverse(pol) + domain.FFT(pol, DIF) + BitReverse(pol) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( + properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFT(pol, DIF, OnCoset()) - BitReverse(pol) + domain.FFT(pol, DIF, OnCoset()) + BitReverse(pol) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))). - Mul(&sample, &domainWithPrecompute.FrMultiplicativeGen) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))). + Mul(&sample, &domain.FrMultiplicativeGen) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( + properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT) + BitReverse(pol) + domain.FFT(pol, DIT) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT) - domainWithPrecompute.FFTInverse(pol, DIF) - BitReverse(pol) + BitReverse(pol) + domain.FFT(pol, DIT) + domain.FFTInverse(pol, DIF) + BitReverse(pol) - check := true - for i := 0; i < len(pol); i++ { - check = check && pol[i].Equal(&backupPol[i]) - } - return check - }, - )) + check := true + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } + return check + }, + )) - properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on cosets", prop.ForAll( + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on cosets", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - check := true + check := true - for i := 1; i <= nbCosets; i++ { + for i := 1; i <= nbCosets; i++ { - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT, OnCoset()) - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset()) - BitReverse(pol) + BitReverse(pol) + domain.FFT(pol, DIT, OnCoset()) + domain.FFTInverse(pol, DIF, OnCoset()) + BitReverse(pol) - for i := 0; i < len(pol); i++ { - check = check && pol[i].Equal(&backupPol[i]) + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } } - } - return check - }, - )) + return check + }, + )) - properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( + properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFTInverse(pol, DIF) - domainWithPrecompute.FFT(pol, DIT) + domain.FFTInverse(pol, DIF) + domain.FFT(pol, DIT) - check := true - for i := 0; i < len(pol); i++ { - check = check && (pol[i] == backupPol[i]) - } - return check - }, - )) + check := true + for i := 0; i < len(pol); i++ { + check = check && (pol[i] == backupPol[i]) + } + return check + }, + )) - properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( + properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset()) - domainWithPrecompute.FFT(pol, DIT, OnCoset()) + domain.FFTInverse(pol, DIF, OnCoset()) + domain.FFT(pol, DIT, OnCoset()) - for i := 0; i < len(pol); i++ { - if !(pol[i].Equal(&backupPol[i])) { - return false + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } } - } - // compute with nbTasks == 1 - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset(), WithNbTasks(1)) - domainWithPrecompute.FFT(pol, DIT, OnCoset(), WithNbTasks(1)) + // compute with nbTasks == 1 + domain.FFTInverse(pol, DIF, OnCoset(), WithNbTasks(1)) + domain.FFT(pol, DIT, OnCoset(), WithNbTasks(1)) - for i := 0; i < len(pol); i++ { - if !(pol[i].Equal(&backupPol[i])) { - return false + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } } - } - return true - }, - )) + return true + }, + )) - properties.TestingRun(t, gopter.ConsoleReporter(false)) + properties.TestingRun(t, gopter.ConsoleReporter(false)) + } } diff --git a/ecc/bw6-761/fr/fft/options.go b/ecc/bw6-761/fr/fft/options.go index 02a6000e5..b3372f244 100644 --- a/ecc/bw6-761/fr/fft/options.go +++ b/ecc/bw6-761/fr/fft/options.go @@ -16,7 +16,11 @@ package fft -import "runtime" +import ( + "runtime" + + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" +) // Option defines option for altering the behavior of FFT methods. // See the descriptions of functions returning instances of this type for @@ -48,7 +52,7 @@ func WithNbTasks(nbTasks int) Option { } // default options -func options(opts ...Option) fftConfig { +func fftOptions(opts ...Option) fftConfig { // apply options opt := fftConfig{ coset: false, @@ -59,3 +63,41 @@ func options(opts ...Option) fftConfig { } return opt } + +// DomainOption defines option for altering the definition of the FFT domain +// See the descriptions of functions returning instances of this type for +// particular options. +type DomainOption func(*domainConfig) + +type domainConfig struct { + shift *fr.Element + withPrecompute bool +} + +// WithShift sets the FrMultiplicativeGen of the domain. +// Default is generator of the largest 2-adic subgroup. +func WithShift(shift fr.Element) DomainOption { + return func(opt *domainConfig) { + opt.shift = new(fr.Element).Set(&shift) + } +} + +// WithoutPrecompute disables precomputation of twiddles in the domain. +// When this option is set, FFTs will be slower, but will use less memory. +func WithoutPrecompute() DomainOption { + return func(opt *domainConfig) { + opt.withPrecompute = false + } +} + +// default options +func domainOptions(opts ...DomainOption) domainConfig { + // apply options + opt := domainConfig{ + withPrecompute: true, + } + for _, option := range opts { + option(&opt) + } + return opt +} diff --git a/ecc/bw6-761/fr/iop/ratios.go b/ecc/bw6-761/fr/iop/ratios.go index 59ec755e3..035d558b9 100644 --- a/ecc/bw6-761/fr/iop/ratios.go +++ b/ecc/bw6-761/fr/iop/ratios.go @@ -336,20 +336,12 @@ func getSupportIdentityPermutation(nbCopies int, domain *fft.Domain) []fr.Elemen res := make([]fr.Element, uint64(nbCopies)*domain.Cardinality) sizePoly := int(domain.Cardinality) - // check if we can reuse the pre-computed twiddles from the domain. - if len(domain.Twiddles) == 0 || len(domain.Twiddles[0]) < sizePoly/2 { - res[0].SetOne() - if len(res) > 1 { - res[1].Set(&domain.Generator) - for i := 2; i < len(res); i++ { - res[i].Mul(&res[i-1], &domain.Generator) - } - } - } else { - // re-use pre-computed twiddles from the domain. - copy(res, domain.Twiddles[0]) - for i := (sizePoly / 2) - 1; i < sizePoly-1; i++ { - res[i+1].Mul(&res[i], &domain.Generator) + // TODO @gbotrel check if we can reuse the pre-computed twiddles from the domain. + res[0].SetOne() + if len(res) > 1 { + res[1].Set(&domain.Generator) + for i := 2; i < len(res); i++ { + res[i].Mul(&res[i-1], &domain.Generator) } } @@ -362,15 +354,7 @@ func getSupportIdentityPermutation(nbCopies int, domain *fft.Domain) []fr.Elemen i := i var coset fr.Element - if i == 1 { - coset = domain.FrMultiplicativeGen - } else { - if len(domain.CosetTable) > i { - coset = domain.CosetTable[i] - } else { - coset.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(i))) - } - } + coset.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(i))) go func() { parallel.Execute(sizePoly, func(start, end int) { diff --git a/ecc/stark-curve/g1_test.go b/ecc/stark-curve/g1_test.go index 31abc377f..0e3373eb9 100644 --- a/ecc/stark-curve/g1_test.go +++ b/ecc/stark-curve/g1_test.go @@ -17,8 +17,8 @@ package starkcurve import ( + "crypto/rand" "math/big" - "math/rand" "testing" "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" diff --git a/go.mod b/go.mod index 79d762f08..6ad6dbf6e 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/consensys/gnark-crypto -go 1.18 +go 1.19 require ( github.com/bits-and-blooms/bitset v1.7.0 diff --git a/internal/generator/fft/template/domain.go.tmpl b/internal/generator/fft/template/domain.go.tmpl index e816d6db2..b31367738 100644 --- a/internal/generator/fft/template/domain.go.tmpl +++ b/internal/generator/fft/template/domain.go.tmpl @@ -4,6 +4,7 @@ import ( "math/bits" "runtime" "sync" + "errors" {{ template "import_fr" . }} {{ template "import_curve" . }} @@ -22,29 +23,35 @@ type Domain struct { FrMultiplicativeGen fr.Element // generator of Fr* FrMultiplicativeGenInv fr.Element + // this is set with the WithoutPrecompute option; + // if true, the domain does some pre-computation and stores it. + // if false, the FFT will compute the twiddles on the fly (this is less CPU efficient, but uses less memory) + withPrecompute bool + // the following slices are not serialized and are (re)computed through domain.preComputeTwiddles() - // Twiddles factor for the FFT using Generator for each stage of the recursive FFT - Twiddles [][]fr.Element + // twiddles factor for the FFT using Generator for each stage of the recursive FFT + twiddles [][]fr.Element - // Twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT - TwiddlesInv [][]fr.Element + // twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT + twiddlesInv [][]fr.Element // we precompute these mostly to avoid the memory intensive bit reverse permutation in the groth16.Prover - // CosetTable u*<1,g,..,g^(n-1)> - CosetTable []fr.Element + // cosetTable u*<1,g,..,g^(n-1)> + cosetTable []fr.Element - // CosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j - CosetTableInv []fr.Element + // cosetTable[i][j] = domain.Generator(i-th)SqrtInv ^ j + cosetTableInv []fr.Element } + // NewDomain returns a subgroup with a power of 2 cardinality // cardinality >= m // shift: when specified, it's the element by which the set of root of unity is shifted. -func NewDomain(m uint64, shift ...fr.Element) *Domain { - +func NewDomain(m uint64, opts ...DomainOption) *Domain { + opt := domainOptions(opts...) domain := &Domain{} x := ecc.NextPowerOfTwo(m) domain.Cardinality = uint64(x) @@ -70,8 +77,8 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { domain.FrMultiplicativeGen.SetUint64(7) {{end}} - if len(shift) != 0 { - domain.FrMultiplicativeGen.Set(&shift[0]) + if opt.shift != nil{ + domain.FrMultiplicativeGen.Set(opt.shift) } domain.FrMultiplicativeGenInv.Inverse(&domain.FrMultiplicativeGen) @@ -84,7 +91,10 @@ func NewDomain(m uint64, shift ...fr.Element) *Domain { domain.CardinalityInv.SetUint64(uint64(x)).Inverse(&domain.CardinalityInv) // twiddle factors - domain.preComputeTwiddles() + domain.withPrecompute = opt.withPrecompute + if domain.withPrecompute { + domain.preComputeTwiddles() + } return domain } @@ -95,54 +105,106 @@ func Generator(m uint64) (fr.Element, error) { return fr.Generator(m) } +// Twiddles returns the twiddles factor for the FFT using Generator for each stage of the recursive FFT +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) Twiddles() ([][]fr.Element, error) { + if d.twiddles == nil { + return nil, errors.New("twiddles not precomputed") + } + return d.twiddles, nil +} + +// TwiddlesInv returns the twiddles factor for the FFT using GeneratorInv for each stage of the recursive FFT +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) TwiddlesInv() ([][]fr.Element, error) { + if d.twiddlesInv == nil { + return nil, errors.New("twiddles not precomputed") + } + return d.twiddlesInv, nil +} + +// CosetTable returns the cosetTable u*<1,g,..,g^(n-1)> +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) CosetTable() ([]fr.Element, error) { + if d.cosetTable == nil { + return nil, errors.New("cosetTable not precomputed") + } + return d.cosetTable, nil +} + +// CosetTableInv returns the cosetTableInv u*<1,g,..,g^(n-1)> +// or an error if the domain was created with the WithoutPrecompute option +func (d *Domain) CosetTableInv() ([]fr.Element, error) { + if d.cosetTableInv == nil { + return nil, errors.New("cosetTableInv not precomputed") + } + return d.cosetTableInv, nil +} + + + func (d *Domain) preComputeTwiddles() { // nb fft stages nbStages := uint64(bits.TrailingZeros64(d.Cardinality)) - d.Twiddles = make([][]fr.Element, nbStages) - d.TwiddlesInv = make([][]fr.Element, nbStages) - d.CosetTable = make([]fr.Element, d.Cardinality) - d.CosetTableInv = make([]fr.Element, d.Cardinality) + d.twiddles = make([][]fr.Element, nbStages) + d.twiddlesInv = make([][]fr.Element, nbStages) + d.cosetTable = make([]fr.Element, d.Cardinality) + d.cosetTableInv = make([]fr.Element, d.Cardinality) var wg sync.WaitGroup - // for each fft stage, we pre compute the twiddle factors - twiddles := func(t [][]fr.Element, omega fr.Element) { - for i := uint64(0); i < nbStages; i++ { - t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1))) - var w fr.Element - if i == 0 { - w = omega - } else { - w = t[i-1][2] - } - t[i][0] = fr.One() - t[i][1] = w - for j := 2; j < len(t[i]); j++ { - t[i][j].Mul(&t[i][j-1], &w) - } - } - wg.Done() - } - expTable := func(sqrt fr.Element, t []fr.Element) { - t[0] = fr.One() - precomputeExpTable(sqrt, t) + BuildExpTable(sqrt, t) wg.Done() } wg.Add(4) - go twiddles(d.Twiddles, d.Generator) - go twiddles(d.TwiddlesInv, d.GeneratorInv) - go expTable(d.FrMultiplicativeGen, d.CosetTable) - go expTable(d.FrMultiplicativeGenInv, d.CosetTableInv) + go func() { + buildTwiddles(d.twiddles, d.Generator, nbStages) + wg.Done() + }() + go func() { + buildTwiddles(d.twiddlesInv, d.GeneratorInv, nbStages) + wg.Done() + }() + go expTable(d.FrMultiplicativeGen, d.cosetTable) + go expTable(d.FrMultiplicativeGenInv, d.cosetTableInv) wg.Wait() } -func precomputeExpTable(w fr.Element, table []fr.Element) { +func buildTwiddles(t [][]fr.Element, omega fr.Element,nbStages uint64) { + if nbStages == 0 { + return + } + if len(t) != int(nbStages) { + panic("invalid twiddle table") + } + // we just compute the first stage + t[0] = make([]fr.Element, 1+(1<<(nbStages-1))) + BuildExpTable(omega, t[0]) + + // for the next stages, we just iterate on the first stage with larger stride + for i := uint64(1); i < nbStages; i++ { + t[i] = make([]fr.Element, 1+(1<<(nbStages-i-1))) + k := 0 + for j := 0; j < len(t[i]); j++ { + t[i][j] = t[0][k] + k += 1 << i + } + } + +} + +// BuildExpTable precomputes the first n powers of w in parallel +// table[0] = w^0 +// table[1] = w^1 +// ... +func BuildExpTable(w fr.Element, table []fr.Element) { + table[0].SetOne() n := len(table) // see if it makes sense to parallelize exp tables pre-computation @@ -152,6 +214,7 @@ func precomputeExpTable(w fr.Element, table []fr.Element) { } // this ratio roughly correspond to the number of multiplication one can do in place of a Exp operation + // TODO @gbotrel revisit this; Exps in this context will be by a "small power of 2" so faster than this ref ratio. const ratioExpMul = 6000 / 17 if interval < ratioExpMul { @@ -193,7 +256,7 @@ func (d *Domain) WriteTo(w io.Writer) (int64, error) { enc := curve.NewEncoder(w) - toEncode := []interface{}{d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} + toEncode := []interface{}{d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute} for _, v := range toEncode { if err := enc.Encode(v); err != nil { @@ -209,7 +272,7 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { dec := curve.NewDecoder(r) - toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} + toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv, &d.withPrecompute} for _, v := range toDecode { if err := dec.Decode(v); err != nil { @@ -217,35 +280,9 @@ func (d *Domain) ReadFrom(r io.Reader) (int64, error) { } } - - // twiddle factors - d.preComputeTwiddles() - - return dec.BytesRead(), nil -} - -// AsyncReadFrom attempts to decode a domain from Reader. It returns a channel that will be closed -// when the precomputation is done. -func (d *Domain) AsyncReadFrom(r io.Reader) (int64, error, chan struct{}) { - - dec := curve.NewDecoder(r) - - toDecode := []interface{}{&d.Cardinality, &d.CardinalityInv, &d.Generator, &d.GeneratorInv, &d.FrMultiplicativeGen, &d.FrMultiplicativeGenInv} - - for _, v := range toDecode { - if err := dec.Decode(v); err != nil { - return dec.BytesRead(), err, nil - } - } - - chDone := make(chan struct{}) - - go func() { - // twiddle factors + if d.withPrecompute { d.preComputeTwiddles() + } - close(chDone) - }() - - return dec.BytesRead(), nil, chDone + return dec.BytesRead(), nil } diff --git a/internal/generator/fft/template/fft.go.tmpl b/internal/generator/fft/template/fft.go.tmpl index 7cde1697c..caf7e373a 100644 --- a/internal/generator/fft/template/fft.go.tmpl +++ b/internal/generator/fft/template/fft.go.tmpl @@ -2,10 +2,14 @@ import ( "math/bits" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" {{ template "import_fr" . }} - ) +{{- /* these params set the size of the kernel we generate & unroll */}} +{{ $sizeKernelLog2 := 8}} +{{ $sizeKernel := shl 1 $sizeKernelLog2}} + // Decimation is used in the FFT call to select decimation in time or in frequency type Decimation uint8 @@ -22,53 +26,85 @@ const butterflyThreshold = 16 // if decimation == DIF (decimation in frequency), the output will be in bit-reversed order func (domain *Domain) FFT(a []fr.Element, decimation Decimation, opts ...Option) { - opt := options(opts...) + opt := fftOptions(opts...) + + // find the stage where we should stop spawning go routines in our recursive calls + // (ie when we have as many go routines running as we have available CPUs) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) + if opt.nbTasks == 1 { + maxSplits = -1 + } // if coset != 0, scale by coset table if opt.coset { if decimation == DIT { // scale by coset table (in bit reversed order) + cosetTable := domain.cosetTable + if !domain.withPrecompute { + // we need to build the full table or do a bit reverse dance. + cosetTable = make([]fr.Element, len(a)) + BuildExpTable(domain.FrMultiplicativeGen, cosetTable) + } parallel.Execute(len(a), func(start, end int) { n := uint64(len(a)) nn := uint64(64 - bits.TrailingZeros64(n)) for i := start; i < end; i++ { irev := int(bits.Reverse64(uint64(i)) >> nn) - a[i].Mul(&a[i], &domain.CosetTable[irev]) + a[i].Mul(&a[i], &cosetTable[irev]) } }, opt.nbTasks) } else { - parallel.Execute(len(a), func(start, end int) { - for i := start; i < end; i++ { - a[i].Mul(&a[i], &domain.CosetTable[i]) - } - }, opt.nbTasks) + if domain.withPrecompute { + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &domain.cosetTable[i]) + } + }, opt.nbTasks) + } else { + c := domain.FrMultiplicativeGen + parallel.Execute(len(a), func(start, end int) { + var at fr.Element + at.Exp(c, big.NewInt(int64(start))) + for i := start; i < end; i++ { + a[i].Mul(&a[i], &at) + at.Mul(&at, &c) + } + }, opt.nbTasks) + } + } } - // find the stage where we should stop spawning go routines in our recursive calls - // (ie when we have as many go routines running as we have available CPUs) - maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) - if opt.nbTasks == 1 { - maxSplits = -1 + twiddles := domain.twiddles + twiddlesStartStage := 0 + if !domain.withPrecompute { + twiddlesStartStage = 3 + nbStages := int(bits.TrailingZeros64(domain.Cardinality)) + twiddles = make([][]fr.Element, nbStages - twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1 << twiddlesStartStage))) + buildTwiddles(twiddles, w, uint64(nbStages - twiddlesStartStage)) } switch decimation { case DIF: - difFFT(a, domain.Twiddles, 0, maxSplits, nil, opt.nbTasks) + difFFT(a, domain.Generator, twiddles, twiddlesStartStage, 0, maxSplits, nil, opt.nbTasks) case DIT: - ditFFT(a, domain.Twiddles, 0, maxSplits, nil, opt.nbTasks) + ditFFT(a, domain.Generator, twiddles, twiddlesStartStage, 0, maxSplits, nil, opt.nbTasks) default: panic("not implemented") } } + + // FFTInverse computes (recursively) the inverse discrete Fourier transform of a and stores the result in a // if decimation == DIT (decimation in time), the input must be in bit-reversed order // if decimation == DIF (decimation in frequency), the output will be in bit-reversed order // coset sets the shift of the fft (0 = no shift, standard fft) // len(a) must be a power of 2, and w must be a len(a)th root of unity in field F. func (domain *Domain) FFTInverse(a []fr.Element, decimation Decimation, opts ...Option) { - opt := options(opts...) + opt := fftOptions(opts...) // find the stage where we should stop spawning go routines in our recursive calls // (ie when we have as many go routines running as we have available CPUs) @@ -76,11 +112,23 @@ func (domain *Domain) FFTInverse(a []fr.Element, decimation Decimation, opts ... if opt.nbTasks == 1 { maxSplits = -1 } + + twiddlesInv := domain.twiddlesInv + twiddlesStartStage := 0 + if !domain.withPrecompute { + twiddlesStartStage = 3 + nbStages := int(bits.TrailingZeros64(domain.Cardinality)) + twiddlesInv = make([][]fr.Element, nbStages - twiddlesStartStage) + w := domain.GeneratorInv + w.Exp(w, big.NewInt(int64(1 << twiddlesStartStage))) + buildTwiddles(twiddlesInv, w, uint64(nbStages - twiddlesStartStage)) + } + switch decimation { case DIF: - difFFT(a, domain.TwiddlesInv, 0, maxSplits, nil, opt.nbTasks) + difFFT(a, domain.GeneratorInv, twiddlesInv, twiddlesStartStage, 0, maxSplits, nil, opt.nbTasks) case DIT: - ditFFT(a, domain.TwiddlesInv, 0, maxSplits, nil, opt.nbTasks) + ditFFT(a, domain.GeneratorInv, twiddlesInv, twiddlesStartStage, 0, maxSplits, nil, opt.nbTasks) default: panic("not implemented") } @@ -97,29 +145,48 @@ func (domain *Domain) FFTInverse(a []fr.Element, decimation Decimation, opts ... if decimation == DIT { - parallel.Execute(len(a), func(start, end int) { - for i := start; i < end; i++ { - a[i].Mul(&a[i], &domain.CosetTableInv[i]). - Mul(&a[i], &domain.CardinalityInv) - } - }, opt.nbTasks) + if domain.withPrecompute { + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].Mul(&a[i], &domain.cosetTableInv[i]). + Mul(&a[i], &domain.CardinalityInv) + } + }, opt.nbTasks) + } else { + c := domain.FrMultiplicativeGenInv + parallel.Execute(len(a), func(start, end int) { + var at fr.Element + at.Exp(c, big.NewInt(int64(start))) + at.Mul(&at, &domain.CardinalityInv) + for i := start; i < end; i++ { + a[i].Mul(&a[i], &at) + at.Mul(&at, &c) + } + }, opt.nbTasks) + } return } // decimation == DIF, need to access coset table in bit reversed order. + cosetTableInv := domain.cosetTableInv + if !domain.withPrecompute { + // we need to build the full table or do a bit reverse dance. + cosetTableInv = make([]fr.Element, len(a)) + BuildExpTable(domain.FrMultiplicativeGenInv, cosetTableInv) + } parallel.Execute(len(a), func(start, end int) { n := uint64(len(a)) nn := uint64(64 - bits.TrailingZeros64(n)) for i := start; i < end; i++ { irev := int(bits.Reverse64(uint64(i)) >> nn) - a[i].Mul(&a[i], &domain.CosetTableInv[irev]). + a[i].Mul(&a[i], &cosetTableInv[irev]). Mul(&a[i], &domain.CardinalityInv) } }, opt.nbTasks) } -func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}, nbTasks int) { +func difFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { if chDone != nil { defer close(chDone) } @@ -127,163 +194,220 @@ func difFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDon n := len(a) if n == 1 { return - } else if n == 8 { - kerDIF8(a, twiddles, stage) + } else if n == {{$sizeKernel}} && stage >= twiddlesStartStage { + kerDIFNP_{{$sizeKernel}}(a, twiddles, stage-twiddlesStartStage) return } m := n >> 1 - // if stage < maxSplits, we parallelize this butterfly - // but we have only numCPU / stage cpus available - if (m > butterflyThreshold) && (stage < maxSplits) { - // 1 << stage == estimated used CPUs - numCPU := nbTasks / (1 << (stage)) - parallel.Execute(m, func(start, end int) { - for i := start; i < end; i++ { - fr.Butterfly(&a[i], &a[i+m]) - a[i+m].Mul(&a[i+m], &twiddles[stage][i]) - } - }, numCPU) + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + var at fr.Element + at.Exp(w, big.NewInt(int64(start))) + innerDIFWithoutTwiddles(a, at,w, start, end, m) + }, nbTasks / (1 << (stage))) // 1 << stage == estimated used CPUs + } else { + innerDIFWithoutTwiddles(a, w, w, 0, m, m) + } + // compute next twiddle + w.Square(&w) } else { - // i == 0 - fr.Butterfly(&a[0], &a[m]) - for i := 1; i < m; i++ { - fr.Butterfly(&a[i], &a[i+m]) - a[i+m].Mul(&a[i+m], &twiddles[stage][i]) + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks / (1 << (stage))) + } else { + innerDIFWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) } } if m == 1 { return } - + nextStage := stage + 1 if stage < maxSplits { chDone := make(chan struct{}, 1) - go difFFT(a[m:n], twiddles, nextStage, maxSplits, chDone, nbTasks) - difFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) + go difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) <-chDone } else { - difFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) - difFFT(a[m:n], twiddles, nextStage, maxSplits, nil, nbTasks) + difFFT(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + difFFT(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) } } -func ditFFT(a []fr.Element, twiddles [][]fr.Element, stage, maxSplits int, chDone chan struct{}, nbTasks int) { + +func innerDIFWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fr.Butterfly(&a[i], &a[i+m]) + a[i+m].Mul(&a[i+m], &twiddles[i]) + } +} + +func innerDIFWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fr.Butterfly(&a[i], &a[i+m]) + a[i+m].Mul(&a[i+m], &at) + at.Mul(&at, &w) + } +} + + +func ditFFT(a []fr.Element, w fr.Element, twiddles [][]fr.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { if chDone != nil { defer close(chDone) } n := len(a) if n == 1 { return - } else if n == 8 { - kerDIT8(a, twiddles, stage) + } else if n == {{$sizeKernel}} && stage >= twiddlesStartStage { + kerDITNP_{{$sizeKernel}}(a, twiddles, stage-twiddlesStartStage) return } m := n >> 1 nextStage := stage + 1 + nextW := w + nextW.Square(&nextW) if stage < maxSplits { // that's the only time we fire go routines chDone := make(chan struct{}, 1) - go ditFFT(a[m:], twiddles, nextStage, maxSplits, chDone, nbTasks) - ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) + go ditFFT(a[m:],nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + ditFFT(a[0:m],nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) <-chDone } else { - ditFFT(a[0:m], twiddles, nextStage, maxSplits, nil, nbTasks) - ditFFT(a[m:n], twiddles, nextStage, maxSplits, nil, nbTasks) - + ditFFT(a[0:m],nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + ditFFT(a[m:n],nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) } - // if stage < maxSplits, we parallelize this butterfly - // but we have only numCPU / stage cpus available - if (m > butterflyThreshold) && (stage < maxSplits) { - // 1 << stage == estimated used CPUs - numCPU := nbTasks / (1 << (stage)) - parallel.Execute(m, func(start, end int) { - for k := start; k < end; k++ { - a[k+m].Mul(&a[k+m], &twiddles[stage][k]) - fr.Butterfly(&a[k], &a[k+m]) - } - }, numCPU) + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + // we need to compute the twiddles for this stage on the fly. + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + var at fr.Element + at.Exp(w, big.NewInt(int64(start))) + innerDITWithoutTwiddles(a, at,w, start, end, m) + }, nbTasks / (1 << (stage))) // 1 << stage == estimated used CPUs + } else { + innerDITWithoutTwiddles(a, w,w, 0, m, m) + } + return + } + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks / (1 << (stage))) } else { + innerDITWithTwiddles(a, twiddles[stage-twiddlesStartStage], 0, m, m) + } +} + + +func innerDITWithTwiddles(a []fr.Element, twiddles []fr.Element, start, end, m int) { + if start == 0 { fr.Butterfly(&a[0], &a[m]) - for k := 1; k < m; k++ { - a[k+m].Mul(&a[k+m], &twiddles[stage][k]) - fr.Butterfly(&a[k], &a[k+m]) - } + start++ + } + for i := start; i < end; i++ { + a[i+m].Mul(&a[i+m], &twiddles[i]) + fr.Butterfly(&a[i], &a[i+m]) + } +} + +func innerDITWithoutTwiddles(a []fr.Element, at, w fr.Element, start, end, m int) { + if start == 0 { + fr.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + a[i+m].Mul(&a[i+m], &at) + fr.Butterfly(&a[i], &a[i+m]) + at.Mul(&at, &w) } } -// kerDIT8 is a kernel that process a FFT of size 8 -func kerDIT8(a []fr.Element, twiddles [][]fr.Element, stage int) { - {{- /* notes: - this function can be updated with larger n - nbSteps must be updated too such as 1 << nbSteps == n - butterflies and multiplication are separated for size n = 8, must check perf for larger n - */}} - {{ $n := 2}} + +func kerDIFNP_{{$sizeKernel}}(a []fr.Element, twiddles [][]fr.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + + {{ $n := shl 1 $sizeKernelLog2}} {{ $m := div $n 2}} - {{ $split := 4}} - {{- range $step := reverse (iterate 0 3)}} + {{ $split := 1}} + {{- range $step := iterate 0 $sizeKernelLog2}} {{- $offset := 0}} - {{- range $s := reverse (iterate 0 $split)}} - {{- range $i := iterate 0 $m}} - {{- $j := add $i $offset}} - {{- $k := add $j $m}} - {{- if ne $i 0}} - a[{{$k}}].Mul(&a[{{$k}}], &twiddles[stage+{{$step}}][{{$i}}]) + + {{- $bound := mul $split $n}} + {{- if eq $bound $n}} + innerDIFWithTwiddles(a[:{{$n}}], twiddles[stage + {{$step}}], 0, {{$m}}, {{$m}}) + {{- else}} + for offset := 0; offset < {{$bound}}; offset += {{$n}} { + {{- if eq $m 1}} + fr.Butterfly(&a[offset], &a[offset+1]) + {{- else}} + innerDIFWithTwiddles(a[offset:offset + {{$n}}], twiddles[stage + {{$step}}], 0, {{$m}}, {{$m}}) {{- end}} - fr.Butterfly(&a[{{$j}}], &a[{{$k}}]) - {{- end}} - {{- $offset = add $offset $n}} + } {{- end}} - - {{- $n = mul $n 2}} + + {{- $n = div $n 2}} {{- $m = div $n 2}} - {{- $split = div $split 2}} + {{- $split = mul $split 2}} {{- end}} } -// kerDIF8 is a kernel that process a FFT of size 8 -func kerDIF8(a []fr.Element, twiddles [][]fr.Element, stage int) { - {{- /* notes: - this function can be updated with larger n - nbSteps must be updated too such as 1 << nbSteps == n - butterflies and multiplication are separated for size n = 8, must check perf for larger n - */}} - {{ $n := 8}} + +func kerDITNP_{{$sizeKernel}}(a []fr.Element, twiddles [][]fr.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + + {{ $n := 2}} {{ $m := div $n 2}} - {{ $split := 1}} - {{- range $step := iterate 0 3}} + {{ $split := div (shl 1 $sizeKernelLog2) 2}} + {{- range $step := reverse (iterate 0 $sizeKernelLog2)}} {{- $offset := 0}} - {{- range $s := iterate 0 $split}} - {{- range $i := iterate 0 $m}} - {{- $j := add $i $offset}} - {{- $k := add $j $m}} - fr.Butterfly(&a[{{$j}}], &a[{{$k}}]) - {{- end}} - {{- $offset = add $offset $n}} - {{- end}} - {{- $offset := 0}} - {{- range $s := iterate 0 $split}} - {{- range $i := iterate 0 $m}} - {{- $j := add $i $offset}} - {{- $k := add $j $m}} - {{- if ne $i 0}} - a[{{$k}}].Mul(&a[{{$k}}], &twiddles[stage+{{$step}}][{{$i}}]) + {{- $bound := mul $split $n}} + {{- if eq $bound $n}} + innerDITWithTwiddles(a[:{{$n}}], twiddles[stage + {{$step}}], 0, {{$m}}, {{$m}}) + {{- else}} + for offset := 0; offset < {{$bound}}; offset += {{$n}} { + {{- if eq $m 1}} + fr.Butterfly(&a[offset], &a[offset+1]) + {{- else}} + innerDITWithTwiddles(a[offset:offset + {{$n}}], twiddles[stage + {{$step}}], 0, {{$m}}, {{$m}}) {{- end}} - {{- end}} - {{- $offset = add $offset $n}} + } {{- end}} - {{- $n = div $n 2}} + + {{- $n = mul $n 2}} {{- $m = div $n 2}} - {{- $split = mul $split 2}} + {{- $split = div $split 2}} {{- end}} } - diff --git a/internal/generator/fft/template/options.go.tmpl b/internal/generator/fft/template/options.go.tmpl index 450da97bf..1eb8407e0 100644 --- a/internal/generator/fft/template/options.go.tmpl +++ b/internal/generator/fft/template/options.go.tmpl @@ -1,4 +1,7 @@ -import "runtime" +import ( + "runtime" + {{ template "import_fr" . }} +) // Option defines option for altering the behavior of FFT methods. // See the descriptions of functions returning instances of this type for @@ -6,7 +9,7 @@ import "runtime" type Option func(*fftConfig) type fftConfig struct { - coset bool + coset bool nbTasks int } @@ -30,14 +33,52 @@ func WithNbTasks(nbTasks int) Option { } // default options -func options(opts ...Option) fftConfig { +func fftOptions(opts ...Option) fftConfig { // apply options opt := fftConfig{ - coset: false, + coset: false, nbTasks: runtime.NumCPU(), } for _, option := range opts { option(&opt) } return opt +} + +// DomainOption defines option for altering the definition of the FFT domain +// See the descriptions of functions returning instances of this type for +// particular options. +type DomainOption func(*domainConfig) + +type domainConfig struct { + shift *fr.Element + withPrecompute bool +} + +// WithShift sets the FrMultiplicativeGen of the domain. +// Default is generator of the largest 2-adic subgroup. +func WithShift(shift fr.Element) DomainOption { + return func(opt *domainConfig) { + opt.shift = new(fr.Element).Set(&shift) + } +} + +// WithoutPrecompute disables precomputation of twiddles in the domain. +// When this option is set, FFTs will be slower, but will use less memory. +func WithoutPrecompute() DomainOption { + return func(opt *domainConfig) { + opt.withPrecompute = false + } +} + +// default options +func domainOptions(opts ...DomainOption) domainConfig { + // apply options + opt := domainConfig{ + withPrecompute: true, + } + for _, option := range opts { + option(&opt) + } + return opt } \ No newline at end of file diff --git a/internal/generator/fft/template/tests/fft.go.tmpl b/internal/generator/fft/template/tests/fft.go.tmpl index 5fb9742e9..6fb288557 100644 --- a/internal/generator/fft/template/tests/fft.go.tmpl +++ b/internal/generator/fft/template/tests/fft.go.tmpl @@ -16,208 +16,219 @@ func TestFFT(t *testing.T) { nbCosets := 3 domainWithPrecompute := NewDomain(maxSize) + domainWithoutPrecompute := NewDomain(maxSize, WithoutPrecompute()) parameters := gopter.DefaultTestParameters() parameters.MinSuccessfulTests = 5 properties := gopter.NewProperties(parameters) - properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( + for domainName, domain := range map[string]*Domain{ + "with precompute": domainWithPrecompute, + "without precompute": domainWithoutPrecompute, + } { + domainName := domainName + domain := domain + t.Logf("domain: %s", domainName) + properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFT(pol, DIF) - BitReverse(pol) + domain.FFT(pol, DIF) + BitReverse(pol) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) - eval := evaluatePolynomial(backupPol, sample) + eval := evaluatePolynomial(backupPol, sample) - return eval.Equal(&pol[ithpower]) + return eval.Equal(&pol[ithpower]) - }, - gen.IntRange(0, maxSize-1), - )) + }, + gen.IntRange(0, maxSize-1), + )) - properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( + - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - domainWithPrecompute.FFT(pol, DIF, OnCoset()) - BitReverse(pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))). - Mul(&sample, &domainWithPrecompute.FrMultiplicativeGen) + domain.FFT(pol, DIF, OnCoset()) + BitReverse(pol) - eval := evaluatePolynomial(backupPol, sample) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))). + Mul(&sample, &domain.FrMultiplicativeGen) - return eval.Equal(&pol[ithpower]) + eval := evaluatePolynomial(backupPol, sample) - }, - gen.IntRange(0, maxSize-1), - )) + return eval.Equal(&pol[ithpower]) - properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( + }, + gen.IntRange(0, maxSize-1), + )) - // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result - func(ithpower int) bool { + properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - sample := domainWithPrecompute.Generator - sample.Exp(sample, big.NewInt(int64(ithpower))) + BitReverse(pol) + domain.FFT(pol, DIT) - eval := evaluatePolynomial(backupPol, sample) + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) - return eval.Equal(&pol[ithpower]) + eval := evaluatePolynomial(backupPol, sample) - }, - gen.IntRange(0, maxSize-1), - )) + return eval.Equal(&pol[ithpower]) - properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( + }, + gen.IntRange(0, maxSize-1), + )) - func() bool { + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + func() bool { - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT) - domainWithPrecompute.FFTInverse(pol, DIF) - BitReverse(pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - check := true - for i := 0; i < len(pol); i++ { - check = check && pol[i].Equal(&backupPol[i]) - } - return check - }, - )) + BitReverse(pol) + domain.FFT(pol, DIT) + domain.FFTInverse(pol, DIF) + BitReverse(pol) - properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on cosets", prop.ForAll( + check := true + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } + return check + }, + )) - func() bool { + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on cosets", prop.ForAll( - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + func() bool { - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - check := true + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - for i := 1; i <= nbCosets; i++ { + check := true - BitReverse(pol) - domainWithPrecompute.FFT(pol, DIT, OnCoset()) - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset()) - BitReverse(pol) + for i := 1; i <= nbCosets; i++ { - for i := 0; i < len(pol); i++ { - check = check && pol[i].Equal(&backupPol[i]) + BitReverse(pol) + domain.FFT(pol, DIT, OnCoset()) + domain.FFTInverse(pol, DIF, OnCoset()) + BitReverse(pol) + + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } } - } - return check - }, - )) + return check + }, + )) - properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( + properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFTInverse(pol, DIF) - domainWithPrecompute.FFT(pol, DIT) + domain.FFTInverse(pol, DIF) + domain.FFT(pol, DIT) - check := true - for i := 0; i < len(pol); i++ { - check = check && (pol[i] == backupPol[i]) - } - return check - }, - )) + check := true + for i := 0; i < len(pol); i++ { + check = check && (pol[i] == backupPol[i]) + } + return check + }, + )) - properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( + properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( - func() bool { + func() bool { - pol := make([]fr.Element, maxSize) - backupPol := make([]fr.Element, maxSize) + pol := make([]fr.Element, maxSize) + backupPol := make([]fr.Element, maxSize) - for i := 0; i < maxSize; i++ { - pol[i].SetRandom() - } - copy(backupPol, pol) + for i := 0; i < maxSize; i++ { + pol[i].SetRandom() + } + copy(backupPol, pol) - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset()) - domainWithPrecompute.FFT(pol, DIT, OnCoset()) + domain.FFTInverse(pol, DIF, OnCoset()) + domain.FFT(pol, DIT, OnCoset()) - for i := 0; i < len(pol); i++ { - if !(pol[i].Equal(&backupPol[i])) { - return false + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } } - } - // compute with nbTasks == 1 - domainWithPrecompute.FFTInverse(pol, DIF, OnCoset(), WithNbTasks(1)) - domainWithPrecompute.FFT(pol, DIT, OnCoset(), WithNbTasks(1)) + // compute with nbTasks == 1 + domain.FFTInverse(pol, DIF, OnCoset(), WithNbTasks(1)) + domain.FFT(pol, DIT, OnCoset(), WithNbTasks(1)) - for i := 0; i < len(pol); i++ { - if !(pol[i].Equal(&backupPol[i])) { - return false + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } } - } - return true - }, - )) + return true + }, + )) - properties.TestingRun(t, gopter.ConsoleReporter(false)) + properties.TestingRun(t, gopter.ConsoleReporter(false)) + } } diff --git a/internal/generator/iop/template/ratios.go.tmpl b/internal/generator/iop/template/ratios.go.tmpl index dcebd51f4..8dab0d14c 100644 --- a/internal/generator/iop/template/ratios.go.tmpl +++ b/internal/generator/iop/template/ratios.go.tmpl @@ -318,20 +318,12 @@ func getSupportIdentityPermutation(nbCopies int, domain *fft.Domain) []fr.Elemen res := make([]fr.Element, uint64(nbCopies)*domain.Cardinality) sizePoly := int(domain.Cardinality) - // check if we can reuse the pre-computed twiddles from the domain. - if len(domain.Twiddles) == 0 || len(domain.Twiddles[0]) < sizePoly/2 { - res[0].SetOne() - if len(res) > 1 { - res[1].Set(&domain.Generator) - for i := 2; i < len(res); i++ { - res[i].Mul(&res[i-1], &domain.Generator) - } - } - } else { - // re-use pre-computed twiddles from the domain. - copy(res, domain.Twiddles[0]) - for i := (sizePoly / 2) - 1; i < sizePoly-1; i++ { - res[i+1].Mul(&res[i], &domain.Generator) + // TODO @gbotrel check if we can reuse the pre-computed twiddles from the domain. + res[0].SetOne() + if len(res) > 1 { + res[1].Set(&domain.Generator) + for i := 2; i < len(res); i++ { + res[i].Mul(&res[i-1], &domain.Generator) } } @@ -344,15 +336,7 @@ func getSupportIdentityPermutation(nbCopies int, domain *fft.Domain) []fr.Elemen i := i var coset fr.Element - if i == 1 { - coset = domain.FrMultiplicativeGen - } else { - if len(domain.CosetTable) > i { - coset = domain.CosetTable[i] - } else { - coset.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(i))) - } - } + coset.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(i))) go func() { parallel.Execute(sizePoly, func(start, end int) { diff --git a/internal/generator/main.go b/internal/generator/main.go index 3c9903b8a..389f96c2e 100644 --- a/internal/generator/main.go +++ b/internal/generator/main.go @@ -92,7 +92,7 @@ func main() { // generate fft on fr assertNoError(fft.Generate(conf, filepath.Join(curveDir, "fr", "fft"), bgen)) - if conf.Equal(config.BN254) { + if conf.Equal(config.BN254) || conf.Equal(config.BLS12_377) { assertNoError(sis.Generate(conf, filepath.Join(curveDir, "fr", "sis"), bgen)) } diff --git a/internal/generator/sis/template/fft.go.tmpl b/internal/generator/sis/template/fft.go.tmpl index 36cdace6c..91938f81d 100644 --- a/internal/generator/sis/template/fft.go.tmpl +++ b/internal/generator/sis/template/fft.go.tmpl @@ -1,5 +1,6 @@ import ( - "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/{{ .Name }}/fr" + "math/big" ) // fft64 is generated by gnark-crypto and contains the unrolled code for FFT (DIF) on 64 elements @@ -48,28 +49,31 @@ func fft64(a []fr.Element, twiddlesCoset []fr.Element) { // precomputeTwiddlesCoset precomputes twiddlesCoset from twiddles and coset table // it then return all elements in the correct order for the unrolled FFT. -func precomputeTwiddlesCoset(twiddles [][]fr.Element, shifter fr.Element) []fr.Element { - r := make([][]fr.Element, len(twiddles)) - for i := 0; i < len(twiddles); i++ { - r[i] = make([]fr.Element, len(twiddles[i])) - s := shifter - for k:=0; k