Skip to content

Commit

Permalink
Add documentation explaining how intn works
Browse files Browse the repository at this point in the history
  • Loading branch information
tchajed committed Mar 7, 2019
1 parent 61495c0 commit bc7c618
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 2 deletions.
28 changes: 26 additions & 2 deletions shuffle/shuffle.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,39 @@ func (s Shuffler) Unshuffle(x [][]byte) {
}
}

// maxMultiple computes the largest value representable in a uint32 that is
// evenly divisible by n.
//
// The result is obtained by rounding the maximum uint32 (2^32-1) down to a
// multiple of n. n must be nonzero; n == 0 triggers an integer
// divide-by-zero runtime panic.
func maxMultiple(n uint32) uint32 {
	// Largest uint32 value, 2^32-1.
	const top uint32 = 1<<32 - 1
	// Integer division truncates, so multiplying back by n discards the
	// remainder top%n.
	return (top / n) * n
}

// intn returns a random number uniformly distributed between 0 and n (not
// including n).
//
// rand should be a source of random bytes
//
// buf should be a temporary buffer with length at least 4
func intn(rand *bufio.Reader, n uint32, buf []byte) int {
max := ^uint32(0)
m := max - (max % n)
// intn does not simply take a random uint32 mod n because this is biased.
// Consider n=3 and a random uint32 u. (2^32-2)%3 == 2, so for u from 0 to
// 2^32-2, u%3 evenly rotates among 0, 1, and 2. However, (2^32-1)%3 == 0,
// so there is a slight bias in favor of u%3 == 0 in the case where u ==
// 2^32-1.
//
// To solve this problem, intn rejection-samples a number x between 0 and a
// multiple of n (not including the upper bound), which is truly uniform,
// then takes x%n.

m := maxMultiple(n)
for {
if _, err := rand.Read(buf); err != nil {
panic(err)
}
// Get a uniform random number in [0, 2^32)
x := binary.BigEndian.Uint32(buf)
if x < m {
// Accept only random numbers in [0, m). Because m is a multiple of
// n, x % n is uniformly distributed in [0, n).
return int(x % n)
}
}
Expand Down
16 changes: 16 additions & 0 deletions shuffle/shuffle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,22 @@ func TestShuffle(t *testing.T) {
}
}

// TestMaxMultiple checks, for a handful of divisors, that maxMultiple returns
// a value that is both a multiple of n and the largest such value that fits
// in a uint32.
func TestMaxMultiple(t *testing.T) {
	divisors := []uint32{2, 3, 5, 10, 15, 1 << 10}
	for i := range divisors {
		n := divisors[i]
		got := maxMultiple(n)
		switch {
		case got%n != 0:
			t.Errorf("maxMultiple(%d) is not a multiple", n)
		case got < got+n:
			// If got were maximal, got+n would exceed the uint32 range and
			// wrap around to a smaller value (uint32 arithmetic is modular).
			// A sum that is still larger means another multiple of n fits.
			t.Errorf("maxMultiple(%d) is not maximal", n)
		}
	}
}

func BenchmarkNew(b *testing.B) {
for i := 0; i < b.N; i++ {
New(rand.Reader, 50000)
Expand Down

0 comments on commit bc7c618

Please sign in to comment.