Skip to content

Commit

Permalink
refactoring random generation functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Eduardo committed Dec 20, 2019
1 parent 5071687 commit a0a4911
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 177 deletions.
9 changes: 5 additions & 4 deletions key.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package tcrsa

import (
"crypto/rand"
"crypto/rsa"
"fmt"
"math/big"
Expand Down Expand Up @@ -97,7 +98,7 @@ func NewKey(bitSize int, k, l uint16, args *KeyMetaArgs) (shares KeyShareList, m
p.Set(args.P)
pr.Sub(p, big.NewInt(1)).Div(pr, big.NewInt(2))
} else {
if p, pr, err = generateSafePrimes(pPrimeSize, randomDev); err != nil {
if p, pr, err = generateSafePrimes(pPrimeSize, rand.Reader); err != nil {
return
}
}
Expand All @@ -110,7 +111,7 @@ func NewKey(bitSize int, k, l uint16, args *KeyMetaArgs) (shares KeyShareList, m
q.Set(args.Q)
qr.Sub(q, big.NewInt(1)).Div(qr, big.NewInt(2))
} else {
if q, qr, err = generateSafePrimes(qPrimeSize, randomDev); err != nil {
if q, qr, err = generateSafePrimes(qPrimeSize, rand.Reader); err != nil {
return
}
}
Expand Down Expand Up @@ -143,7 +144,7 @@ func NewKey(bitSize int, k, l uint16, args *KeyMetaArgs) (shares KeyShareList, m
// generate v
if args.R == nil {
for divisor.Cmp(big.NewInt(1)) != 0 {
r, err = randomDev(n.BitLen())
r, err = randInt(n.BitLen())
if err != nil {
return
}
Expand All @@ -165,7 +166,7 @@ func NewKey(bitSize int, k, l uint16, args *KeyMetaArgs) (shares KeyShareList, m
// generate u
if args.U == nil {
for cond := true; cond; cond = big.Jacobi(vku, n) != -1 {
vku, err = randomDev(n.BitLen())
vku, err = randInt(n.BitLen())
if err != nil {
return
}
Expand Down
2 changes: 1 addition & 1 deletion key_share.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (keyShare KeyShare) Sign(doc []byte, hashType crypto.Hash, info *KeyMeta) (
xi2.Exp(xi, big.NewInt(2), n)

// r = abs(random(bytes_len))
r, err := randomDev(n.BitLen() + 2*hashType.Size()*8)
r, err := randInt(n.BitLen() + 2*hashType.Size()*8)
if err != nil {
return
}
Expand Down
2 changes: 1 addition & 1 deletion polynomial.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func createRandomPolynomial(d int, x0, m *big.Int) (polynomial, error) {
poly[0].Set(x0)

for i := 1; i < len(poly); i++ {
rand, err := randomDev(bitLen)
rand, err := randInt(bitLen)
if err != nil {
return polynomial{}, err
}
Expand Down
122 changes: 15 additions & 107 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@ package tcrsa

import (
"crypto/rand"
mathRand "math/rand"
"fmt"
"io"
"math/big"
)

// Number of Miller-Rabin tests
const c = 25
const c = 20

// randomDev is a function which generates a random big number, using crypto/rand
// randInt is a function which generates a random big number, using crypto/rand
// crypto-secure Golang library.
func randomDev(bitLen int) (randNum *big.Int, err error) {
func randInt(bitLen int) (randNum *big.Int, err error) {
randNum = big.NewInt(0)
if bitLen <= 0 {
err = fmt.Errorf("bitlen should be greater than 0, but it is %d", bitLen)
Expand Down Expand Up @@ -43,117 +43,25 @@ func randomDev(bitLen int) (randNum *big.Int, err error) {
return
}

// randomFixed returns a seeded pseudorandom function that returns a random number of bitLen bits.
func randomFixed(seed int64) func(int) (*big.Int, error) {
seededRand := mathRand.New(mathRand.NewSource(seed))
return func(bitLen int) (randNum *big.Int, err error) {
randNum = big.NewInt(0)
if bitLen <= 0 {
err = fmt.Errorf("bitlen should be greater than 0, but it is %d", bitLen)
return
}
byteLen := bitLen / 8
if bitLen % 8 != 0 {
byteLen++
}
rawRand := make([]byte, byteLen)
for randNum.BitLen() == 0 || randNum.BitLen() > bitLen {
_, err = seededRand.Read(rawRand)
if err != nil {
return
}
randNum.SetBytes(rawRand)
// set MSBs to 0 to get a bitLen equal to bitLen param.
for bit := bitLen; bit < randNum.BitLen(); bit++ {
randNum.SetBit(randNum, bit, 0)
}
}

if randNum.BitLen() == 0 || randNum.BitLen() > bitLen {
err = fmt.Errorf("random number returned should have length at most %d, but its length is %d", bitLen, randNum.BitLen())
return
}
return
}
}

// randomPrime returns a random prime of length bitLen, using a given random function randFn.
func randomPrime(bitLen int, randFn func(int) (*big.Int, error)) (randPrime *big.Int, err error) {
randPrime = new(big.Int)

if randFn == nil {
err = fmt.Errorf("random function cannot be nil")
return
}
if bitLen <= 0 {
err = fmt.Errorf("bit length must be positive")
return
}

for randPrime.BitLen() == 0 || randPrime.BitLen() > bitLen {
randPrime, err = randFn(bitLen)
if err != nil {
return
}
setAsNextPrime(randPrime, c)
}

if randPrime.BitLen() == 0 || randPrime.BitLen() > bitLen {
err = fmt.Errorf("random number returned should have length at most %d, but its length is %d", bitLen, randPrime.BitLen())
return
}

if !randPrime.ProbablyPrime(c) {
err = fmt.Errorf("random number returned is not prime")
return
}
return
}

// setAsNextPrime edits the number as the next prime number from it, checking for its prime condition
// using ProbablyPrime function.
func setAsNextPrime(num *big.Int, n int) {
// Possible prime should be odd
num.SetBit(num, 0, 1)
two := big.NewInt(2)
for !num.ProbablyPrime(n) {
// I add two to the number to obtain another odd number
num.Add(num, two)
}
}

// generateSafePrimes generates two primes p and q, in a way that q
// is equal to (p-1)/2. The greatest prime bit length is at least bitLen bits.
func generateSafePrimes(bitLen int, randFn func(int) (*big.Int, error)) (*big.Int, *big.Int, error) {
if randFn == nil {
return big.NewInt(0), big.NewInt(0), fmt.Errorf("random function cannot be nil")
func generateSafePrimes(bitLen int, randSource io.Reader) (*big.Int, *big.Int, error) {
if randSource == nil {
return big.NewInt(0), big.NewInt(0), fmt.Errorf("random source cannot be nil")
}

q := new(big.Int)
r := new(big.Int)
p := new(big.Int)

for {
p, err := randomPrime(bitLen, randFn)
q, err := rand.Prime(randSource, bitLen-1)
if err != nil {
return big.NewInt(0), big.NewInt(0), err
}

// If the number will be odd after right shift
if p.Bit(1) == 1 {
// q = (p - 1) / 2
q.Rsh(p, 1)
if q.ProbablyPrime(c) {
return p, q, nil
}
}

if p.BitLen() < bitLen {
// r = 2p + 1
r.Lsh(p, 1)
r.SetBit(r,0,1)
if r.ProbablyPrime(c) {
return r, p, nil
}
// p = 2q + 1
p.Lsh(q, 1)
p.SetBit(p,0,1)
if p.ProbablyPrime(c) {
return p, q, nil
}
}
}

81 changes: 17 additions & 64 deletions utils_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package tcrsa

import (
"crypto/rand"
"fmt"
"math/big"
"testing"
"time"
)

const utilsTestBitlen = 256
Expand All @@ -13,11 +16,11 @@ const utilsTestC = 25
// Tests that two consecutive outputs from random dev are different.
// TODO: Test how much different are the numbers generated
func TestRandomDev_different(t *testing.T) {
rand1, err := randomDev(utilsTestBitlen)
rand1, err := randInt(utilsTestBitlen)
if err != nil {
t.Errorf("first random number generation failed: %v", err)
}
rand2, err := randomDev(utilsTestBitlen)
rand2, err := randInt(utilsTestBitlen)
if err != nil {
t.Errorf("second random number generation failed: %v", err)
}
Expand All @@ -28,7 +31,7 @@ func TestRandomDev_different(t *testing.T) {

// Tests that the bit size of the output of a random dev function is the desired.
func TestRandomDev_bitSize(t *testing.T) {
rand1, err := randomDev(utilsTestBitlen)
rand1, err := randInt(utilsTestBitlen)
if err != nil {
t.Errorf("first random number generation failed: %v", err)
}
Expand All @@ -37,61 +40,11 @@ func TestRandomDev_bitSize(t *testing.T) {
}
}

// Tests that two consecutive random primes are different.
// TODO: Test how much different are the numbers generated
func TestRandomPrimes_different(t *testing.T) {
rand1, err := randomPrime(utilsTestBitlen, randomDev)
if err != nil {
t.Errorf("first random prime number generation failed: %v", err)
}
rand2, err := randomPrime(utilsTestBitlen, randomDev)
if err != nil {
t.Errorf("second random prime number generation failed: %v", err)
}
if rand1.Cmp(rand2) == 0 {
t.Errorf("both random numbers are equal!")
}
}

// Tests that the output size of a random prime function is the desired.
func TestRandomPrimes_bitSize(t *testing.T) {
rand1, err := randomPrime(utilsTestBitlen, randomDev)
if err != nil {
t.Errorf("first random prime number generation failed: %v", err)
}
if rand1.BitLen() > utilsTestBitlen {
t.Errorf("random number bit length should have been at most %d, but it was %d", rand1.BitLen(), utilsTestBitlen)
}
}

// Tests that the output of RandomPrimes is a prime.
func TestRandomPrimes_isPrime(t *testing.T) {
rand1, err := randomPrime(utilsTestBitlen, randomDev)
if err != nil {
t.Errorf("first random prime number generation failed: %v", err)
}
if !rand1.ProbablyPrime(utilsTestC) {
t.Errorf("random number is not prime")
}
}

// Tests that NextPrime returns the next prime of a number greater than 2.
func TestNextPrime(t *testing.T) {
number := big.NewInt(4)
firstNumber := big.NewInt(0)
firstNumber.Set(number)
expected := big.NewInt(5)
setAsNextPrime(number, utilsTestC)
if number.Cmp(expected) != 0 {
t.Errorf("expecting %s as next prime of %s, but obtained %s", expected, firstNumber, number)
}
}

func TestGenerateSafePrimes(t *testing.T) {

pExpected := new(big.Int)

p, pr, err := generateSafePrimes(utilsTestBitlen, randomDev)
p, pr, err := generateSafePrimes(utilsTestBitlen, rand.Reader)
if err != nil {
t.Errorf("safe prime generation failed: %v", err)
}
Expand All @@ -113,12 +66,12 @@ func TestGenerateSafePrimes_keyGeneration(t *testing.T) {
d := new(big.Int)
r := new(big.Int)

_, pr, err := generateSafePrimes(utilsTestBitlen, randomDev)
_, pr, err := generateSafePrimes(utilsTestBitlen, rand.Reader)
if err != nil {
t.Errorf("safe prime generation failed: %v", err)
}

_, qr, err := generateSafePrimes(utilsTestBitlen, randomDev)
_, qr, err := generateSafePrimes(utilsTestBitlen, rand.Reader)
if err != nil {
t.Errorf("safe prime generation failed: %v", err)
}
Expand All @@ -135,14 +88,14 @@ func TestGenerateSafePrimes_keyGeneration(t *testing.T) {

}


func BenchmarkSetAsPrime(b *testing.B) {
randFn := randomFixed(12345)
for i := 0; i < b.N; i++ {
randPrime := big.NewInt(0)
for randPrime.BitLen() == 0 || randPrime.BitLen() > utilsTestBitlen {
randPrime, _ = randFn(utilsTestBitlen)
setAsNextPrime(randPrime, c)
func TestGenerateSafePrimes_Time(t *testing.T) {
for i := 4; i <11; i++ {
keyLength := 1 << uint(i)
start := time.Now()
_, _, err := generateSafePrimes(keyLength, rand.Reader)
if err != nil {
t.Errorf("error generating safe primes: %d", err)
}
fmt.Printf("- %d byte safe prime pair obtained in %f seconds\n", keyLength, time.Since(start).Seconds())
}
}

0 comments on commit a0a4911

Please sign in to comment.