Skip to content

Commit

Permalink
Merge pull request #275 from balena/pull-out-random
Browse files Browse the repository at this point in the history
Exposed random sources
  • Loading branch information
ZhAnGeek authored Jan 12, 2024
2 parents f67a429 + 6c233c6 commit 87f7e12
Show file tree
Hide file tree
Showing 49 changed files with 369 additions and 306 deletions.
7 changes: 4 additions & 3 deletions common/hash_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package common_test

import (
"crypto/rand"
"math/big"
"reflect"
"testing"
Expand All @@ -15,12 +16,12 @@ import (
)

func TestRejectionSample(t *testing.T) {
curveQ := common.GetRandomPrimeInt(256)
randomQ := common.MustGetRandomInt(64)
curveQ := common.GetRandomPrimeInt(rand.Reader, 256)
randomQ := common.MustGetRandomInt(rand.Reader, 64)
hash := common.SHA512_256iOne(big.NewInt(123))
rs1 := common.RejectionSample(curveQ, hash)
rs2 := common.RejectionSample(randomQ, hash)
rs3 := common.RejectionSample(common.MustGetRandomInt(64), hash)
rs3 := common.RejectionSample(common.MustGetRandomInt(rand.Reader, 64), hash)
type args struct {
q *big.Int
eHash *big.Int
Expand Down
33 changes: 17 additions & 16 deletions common/random.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
package common

import (
"crypto/rand"
cryptorand "crypto/rand"
"fmt"
"io"
"math/big"

"github.com/pkg/errors"
Expand All @@ -18,8 +19,8 @@ const (
mustGetRandomIntMaxBits = 5000
)

// MustGetRandomInt panics if it is unable to gather entropy from `rand.Reader` or when `bits` is <= 0
func MustGetRandomInt(bits int) *big.Int {
// MustGetRandomInt panics if it is unable to gather entropy from `io.Reader` or when `bits` is <= 0
func MustGetRandomInt(rand io.Reader, bits int) *big.Int {
if bits <= 0 || mustGetRandomIntMaxBits < bits {
panic(fmt.Errorf("MustGetRandomInt: bits should be positive, non-zero and less than %d", mustGetRandomIntMaxBits))
}
Expand All @@ -28,37 +29,37 @@ func MustGetRandomInt(bits int) *big.Int {
max = max.Exp(two, big.NewInt(int64(bits)), nil).Sub(max, one)

// Generate cryptographically strong pseudo-random int between 0 - max
n, err := rand.Int(rand.Reader, max)
n, err := cryptorand.Int(rand, max)
if err != nil {
panic(errors.Wrap(err, "rand.Int failure in MustGetRandomInt!"))
}
return n
}

func GetRandomPositiveInt(lessThan *big.Int) *big.Int {
func GetRandomPositiveInt(rand io.Reader, lessThan *big.Int) *big.Int {
if lessThan == nil || zero.Cmp(lessThan) != -1 {
return nil
}
var try *big.Int
for {
try = MustGetRandomInt(lessThan.BitLen())
try = MustGetRandomInt(rand, lessThan.BitLen())
if try.Cmp(lessThan) < 0 {
break
}
}
return try
}

func GetRandomPrimeInt(bits int) *big.Int {
func GetRandomPrimeInt(rand io.Reader, bits int) *big.Int {
if bits <= 0 {
return nil
}
try, err := rand.Prime(rand.Reader, bits)
try, err := cryptorand.Prime(rand, bits)
if err != nil ||
try.Cmp(zero) == 0 {
// fallback to older method
for {
try = MustGetRandomInt(bits)
try = MustGetRandomInt(rand, bits)
if probablyPrime(try) {
break
}
Expand All @@ -69,13 +70,13 @@ func GetRandomPrimeInt(bits int) *big.Int {

// Generate a random element in the group of all the elements in Z/nZ that
// has a multiplicative inverse.
func GetRandomPositiveRelativelyPrimeInt(n *big.Int) *big.Int {
func GetRandomPositiveRelativelyPrimeInt(rand io.Reader, n *big.Int) *big.Int {
if n == nil || zero.Cmp(n) != -1 {
return nil
}
var try *big.Int
for {
try = MustGetRandomInt(n.BitLen())
try = MustGetRandomInt(rand, n.BitLen())
if IsNumberInMultiplicativeGroup(n, try) {
break
}
Expand All @@ -96,24 +97,24 @@ func IsNumberInMultiplicativeGroup(n, v *big.Int) bool {
// THIS METHOD ONLY WORKS IF N IS THE PRODUCT OF TWO SAFE PRIMES!
//
// https://github.com/didiercrunch/paillier/blob/d03e8850a8e4c53d04e8016a2ce8762af3278b71/utils.go#L39
func GetRandomGeneratorOfTheQuadraticResidue(n *big.Int) *big.Int {
f := GetRandomPositiveRelativelyPrimeInt(n)
func GetRandomGeneratorOfTheQuadraticResidue(rand io.Reader, n *big.Int) *big.Int {
f := GetRandomPositiveRelativelyPrimeInt(rand, n)
fSq := new(big.Int).Mul(f, f)
return fSq.Mod(fSq, n)
}

// GetRandomQuadraticNonResidue returns a quadratic non residue of odd n.
func GetRandomQuadraticNonResidue(n *big.Int) *big.Int {
func GetRandomQuadraticNonResidue(rand io.Reader, n *big.Int) *big.Int {
for {
w := GetRandomPositiveInt(n)
w := GetRandomPositiveInt(rand, n)
if big.Jacobi(w, n) == -1 {
return w
}
}
}

// GetRandomBytes returns random bytes of length.
func GetRandomBytes(length int) ([]byte, error) {
func GetRandomBytes(rand io.Reader, length int) ([]byte, error) {
// Per [BIP32], the seed must be in range [MinSeedBytes, MaxSeedBytes].
if length <= 0 {
return nil, errors.New("invalid length")
Expand Down
13 changes: 7 additions & 6 deletions common/random_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package common_test

import (
"crypto/rand"
"math/big"
"testing"

Expand All @@ -20,28 +21,28 @@ const (
)

func TestGetRandomInt(t *testing.T) {
rnd := common.MustGetRandomInt(randomIntBitLen)
rnd := common.MustGetRandomInt(rand.Reader, randomIntBitLen)
assert.NotZero(t, rnd, "rand int should not be zero")
}

func TestGetRandomPositiveInt(t *testing.T) {
rnd := common.MustGetRandomInt(randomIntBitLen)
rndPos := common.GetRandomPositiveInt(rnd)
rnd := common.MustGetRandomInt(rand.Reader, randomIntBitLen)
rndPos := common.GetRandomPositiveInt(rand.Reader, rnd)
assert.NotZero(t, rndPos, "rand int should not be zero")
assert.True(t, rndPos.Cmp(big.NewInt(0)) == 1, "rand int should be positive")
}

func TestGetRandomPositiveRelativelyPrimeInt(t *testing.T) {
rnd := common.MustGetRandomInt(randomIntBitLen)
rndPosRP := common.GetRandomPositiveRelativelyPrimeInt(rnd)
rnd := common.MustGetRandomInt(rand.Reader, randomIntBitLen)
rndPosRP := common.GetRandomPositiveRelativelyPrimeInt(rand.Reader, rnd)
assert.NotZero(t, rndPosRP, "rand int should not be zero")
assert.True(t, common.IsNumberInMultiplicativeGroup(rnd, rndPosRP))
assert.True(t, rndPosRP.Cmp(big.NewInt(0)) == 1, "rand int should be positive")
// TODO test for relative primeness
}

func TestGetRandomPrimeInt(t *testing.T) {
prime := common.GetRandomPrimeInt(randomIntBitLen)
prime := common.GetRandomPrimeInt(rand.Reader, randomIntBitLen)
assert.NotZero(t, prime, "rand prime should not be zero")
assert.True(t, prime.ProbablyPrime(50), "rand prime should be prime")
}
63 changes: 31 additions & 32 deletions common/safe_prime.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ package common

import (
"context"
"crypto/rand"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -125,7 +124,7 @@ var ErrGeneratorCancelled = fmt.Errorf("generator work cancelled")
// This function generates safe primes of at least 6 `bitLen`. For every
// generated safe prime, the two most significant bits are always set to `1`
// - we don't want the generated number to be too small.
func GetRandomSafePrimesConcurrent(ctx context.Context, bitLen, numPrimes int, concurrency int) ([]*GermainSafePrime, error) {
func GetRandomSafePrimesConcurrent(ctx context.Context, bitLen, numPrimes int, concurrency int, rand io.Reader) ([]*GermainSafePrime, error) {
if bitLen < 6 {
return nil, errors.New("safe prime size must be at least 6 bits")
}
Expand All @@ -149,7 +148,7 @@ func GetRandomSafePrimesConcurrent(ctx context.Context, bitLen, numPrimes int, c
for i := 0; i < concurrency; i++ {
waitGroup.Add(1)
runGenPrimeRoutine(
generatorCtx, primeCh, errCh, waitGroup, rand.Reader, bitLen,
generatorCtx, primeCh, errCh, waitGroup, rand, bitLen,
)
}

Expand All @@ -175,35 +174,35 @@ func GetRandomSafePrimesConcurrent(ctx context.Context, bitLen, numPrimes int, c
// a bit length equal to `pBitLen-1`.
//
// The algorithm is as follows:
// 1. Generate a random odd number `q` of length `pBitLen-1` with two the most
// significant bits set to `1`.
// 2. Execute preliminary primality test on `q` checking whether it is coprime
// to all the elements of `smallPrimes`. It allows to eliminate trivial
// cases quickly, when `q` is obviously no prime, without running an
// expensive final primality tests.
// If `q` is coprime to all of the `smallPrimes`, then go to the point 3.
// If not, add `2` and try again. Do it at most 10 times.
// 3. Check the potentially prime `q`, whether `q = 1 (mod 3)`. This will
// happen for 50% of cases.
// If it is, then `p = 2q+1` will be a multiple of 3, so it will be obviously
// not a prime number. In this case, add `2` and try again. Do it at most 10
// times. If `q != 1 (mod 3)`, go to the point 4.
// 4. Now we know `q` is potentially prime and `p = 2q+1` is not a multiple of
// 3. We execute a preliminary primality test on `p`, checking whether
// it is coprime to all the elements of `smallPrimes` just like we did for
// `q` in point 2. If `p` is not coprime to at least one element of the
// `smallPrimes`, then go back to point 1.
// If `p` is coprime to all the elements of `smallPrimes`, go to point 5.
// 5. At this point, we know `q` is potentially prime, and `p=2q+1` is also
// potentially prime. We need to execute a final primality test for `q`.
// We apply Miller-Rabin and Baillie-PSW tests. If they succeed, it means
// that `q` is prime with a very high probability. Knowing `q` is prime,
// we use Pocklington's criterion to prove the primality of `p=2q+1`, that
// is, we execute Fermat primality test to base 2 checking whether
// `2^{p-1} = 1 (mod p)`. It's significantly faster than running full
// Miller-Rabin and Baillie-PSW for `p`.
// If `q` and `p` are found to be prime, return them as a result. If not, go
// back to the point 1.
// 1. Generate a random odd number `q` of length `pBitLen-1` with two the most
// significant bits set to `1`.
// 2. Execute preliminary primality test on `q` checking whether it is coprime
// to all the elements of `smallPrimes`. It allows to eliminate trivial
// cases quickly, when `q` is obviously no prime, without running an
// expensive final primality tests.
// If `q` is coprime to all of the `smallPrimes`, then go to the point 3.
// If not, add `2` and try again. Do it at most 10 times.
// 3. Check the potentially prime `q`, whether `q = 1 (mod 3)`. This will
// happen for 50% of cases.
// If it is, then `p = 2q+1` will be a multiple of 3, so it will be obviously
// not a prime number. In this case, add `2` and try again. Do it at most 10
// times. If `q != 1 (mod 3)`, go to the point 4.
// 4. Now we know `q` is potentially prime and `p = 2q+1` is not a multiple of
// 3. We execute a preliminary primality test on `p`, checking whether
// it is coprime to all the elements of `smallPrimes` just like we did for
// `q` in point 2. If `p` is not coprime to at least one element of the
// `smallPrimes`, then go back to point 1.
// If `p` is coprime to all the elements of `smallPrimes`, go to point 5.
// 5. At this point, we know `q` is potentially prime, and `p=2q+1` is also
// potentially prime. We need to execute a final primality test for `q`.
// We apply Miller-Rabin and Baillie-PSW tests. If they succeed, it means
// that `q` is prime with a very high probability. Knowing `q` is prime,
// we use Pocklington's criterion to prove the primality of `p=2q+1`, that
// is, we execute Fermat primality test to base 2 checking whether
// `2^{p-1} = 1 (mod p)`. It's significantly faster than running full
// Miller-Rabin and Baillie-PSW for `p`.
// If `q` and `p` are found to be prime, return them as a result. If not, go
// back to the point 1.
func runGenPrimeRoutine(
ctx context.Context,
primeCh chan<- *GermainSafePrime,
Expand Down
3 changes: 2 additions & 1 deletion common/safe_prime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package common

import (
"context"
"crypto/rand"
"math/big"
"runtime"
"testing"
Expand Down Expand Up @@ -45,7 +46,7 @@ func Test_Validate_Bad(t *testing.T) {
func TestGetRandomGermainPrimeConcurrent(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Minute)
defer cancel()
sgps, err := GetRandomSafePrimesConcurrent(ctx, 1024, 2, runtime.NumCPU())
sgps, err := GetRandomSafePrimesConcurrent(ctx, 1024, 2, runtime.NumCPU(), rand.Reader)
assert.NoError(t, err)
assert.Equal(t, 2, len(sgps))
for _, sgp := range sgps {
Expand Down
5 changes: 3 additions & 2 deletions crypto/commitments/commitment.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
package commitments

import (
"io"
"math/big"

"github.com/bnb-chain/tss-lib/v2/common"
Expand Down Expand Up @@ -43,8 +44,8 @@ func NewHashCommitmentWithRandomness(r *big.Int, secrets ...*big.Int) *HashCommi
return cmt
}

func NewHashCommitment(secrets ...*big.Int) *HashCommitDecommit {
r := common.MustGetRandomInt(HashLength) // r
func NewHashCommitment(rand io.Reader, secrets ...*big.Int) *HashCommitDecommit {
r := common.MustGetRandomInt(rand, HashLength) // r
return NewHashCommitmentWithRandomness(r, secrets...)
}

Expand Down
5 changes: 3 additions & 2 deletions crypto/commitments/commitment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package commitments_test

import (
"crypto/rand"
"math/big"
"testing"

Expand All @@ -19,7 +20,7 @@ func TestCreateVerify(t *testing.T) {
one := big.NewInt(1)
zero := big.NewInt(0)

commitment := NewHashCommitment(zero, one)
commitment := NewHashCommitment(rand.Reader, zero, one)
pass := commitment.Verify()

assert.True(t, pass, "must pass")
Expand All @@ -29,7 +30,7 @@ func TestDeCommit(t *testing.T) {
one := big.NewInt(1)
zero := big.NewInt(0)

commitment := NewHashCommitment(zero, one)
commitment := NewHashCommitment(rand.Reader, zero, one)
pass, secrets := commitment.DeCommit()

assert.True(t, pass, "must pass")
Expand Down
9 changes: 4 additions & 5 deletions crypto/dlnproof/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ package dlnproof

import (
"fmt"
"io"
"math/big"

"github.com/bnb-chain/tss-lib/v2/common"
Expand All @@ -28,17 +29,15 @@ type (
}
)

var (
one = big.NewInt(1)
)
var one = big.NewInt(1)

func NewDLNProof(h1, h2, x, p, q, N *big.Int) *Proof {
func NewDLNProof(h1, h2, x, p, q, N *big.Int, rand io.Reader) *Proof {
pMulQ := new(big.Int).Mul(p, q)
modN, modPQ := common.ModInt(N), common.ModInt(pMulQ)
a := make([]*big.Int, Iterations)
alpha := [Iterations]*big.Int{}
for i := range alpha {
a[i] = common.GetRandomPositiveInt(pMulQ)
a[i] = common.GetRandomPositiveInt(rand, pMulQ)
alpha[i] = modN.Exp(h1, a[i])
}
msg := append([]*big.Int{h1, h2, N}, alpha[:]...)
Expand Down
Loading

0 comments on commit 87f7e12

Please sign in to comment.