Skip to content

Commit

Permalink
add binary edge list shuffle script (again) (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
meyerzinn authored Apr 22, 2024
1 parent e4c0ac3 commit 1b62931
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 51 deletions.
4 changes: 3 additions & 1 deletion scripts/shufbel/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@ module shufbel

go 1.18

require github.com/vmunoz82/shuffle v1.0.2 // indirect
require (
github.com/vmunoz82/shuffle v1.0.2
)
122 changes: 72 additions & 50 deletions scripts/shufbel/shufbel.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,32 @@ package main
import (
"flag"
"fmt"
"io"
"math/rand"
"os"
"runtime"
"sync"
"syscall"
"unsafe"

"github.com/vmunoz82/shuffle"
)

// ReadFullAt fills buf with bytes read from r, starting at byte offset off.
// It keeps calling r.ReadAt, advancing the offset by however many bytes each
// call delivered, until buf is full or a call reports an error.
//
// It returns the number of bytes placed in buf. The error is nil when buf was
// filled completely (even if the final ReadAt also reported an error),
// io.ErrUnexpectedEOF when io.EOF arrived after some but not all bytes were
// read, and otherwise whatever error the last ReadAt returned.
//
// from https://github.com/bramp/dsector/blob/e96a7734bb3f/input/io.go#L41
func ReadFullAt(r io.ReaderAt, buf []byte, off int64) (n int, err error) {
	want := len(buf)
	for n < want {
		got, readErr := r.ReadAt(buf[n:], off)
		n += got
		off += int64(got)
		if readErr != nil {
			// Stop at the first failure; it is classified below.
			err = readErr
			break
		}
	}

	switch {
	case n >= want:
		// A full buffer is a success regardless of any trailing error.
		return n, nil
	case n > 0 && err == io.EOF:
		// The source ended partway through the requested range.
		return n, io.ErrUnexpectedEOF
	default:
		return n, err
	}
}

/* Function applied in each round; it can be anything and does not need to be reversible. */
func roundFunction(v, key shuffle.FeistelWord) shuffle.FeistelWord {
return (v * 941083987) ^ (key >> (v & 7) * 104729)
Expand All @@ -22,6 +38,8 @@ func roundFunction(v, key shuffle.FeistelWord) shuffle.FeistelWord {
func main() {
// parse CLI options
threads := flag.Int("threads", runtime.NumCPU(), "number of threads")
routines := flag.Int("routines", 8*runtime.NumCPU(), "number of goroutines")
big := flag.Bool("big", false, "don't attempt to memory-map the input file")
rseed := flag.Int64("rseed", 0, "random seed")
flag.Parse()

Expand All @@ -38,20 +56,15 @@ func main() {
input_file := files[0]
output_file := files[1]

// mmap input file
// Open input file.
fin, err := os.Open(input_file)
if err != nil {
fmt.Printf("Error opening file %s: %v\n", input_file, err)
return
}
defer fin.Close()

fout, err := os.OpenFile(output_file, os.O_RDWR|os.O_CREATE, 0644)
if err != nil {
fmt.Printf("Error opening file %s: %v\n", output_file, err)
return
}

// Compute file size.
fi, err := fin.Stat()
if err != nil {
fmt.Printf("Error getting file info: %v\n", err)
Expand All @@ -65,64 +78,73 @@ func main() {
fmt.Printf("Error: file size is not divisible by 16")
return
}
num_edges := sizeBytes / 16

fout.Truncate(int64(sizeBytes))

data, err := syscall.Mmap(int(fin.Fd()), 0, int(sizeBytes), syscall.PROT_READ, syscall.MAP_PRIVATE|syscall.MAP_NORESERVE)
if err != nil {
fmt.Printf("Error mmaping file %s: %v\n", input_file, err)
return
}
if uintptr(unsafe.Pointer(&data[0]))%uintptr(pageSize) != 0 {
fmt.Println("mmap returned a non-page-aligned address: %p", uintptr(unsafe.Pointer(&data[0])))
return
}
// madv willneed
err = syscall.Madvise(data, syscall.MADV_WILLNEED)
if err != nil {
fmt.Printf("Error madvise: %v\n", err)
return
var readFromInput func([]byte, int64)
if *big {
readFromInput = func(buf []byte, off int64) {
_, err := ReadFullAt(fin, buf, off)
if err != nil {
panic("Error reading input file")
}
}
} else {
// Memory-map the input file.
data, err := syscall.Mmap(int(fin.Fd()), 0, int(sizeBytes), syscall.PROT_READ, syscall.MAP_PRIVATE|syscall.MAP_POPULATE|syscall.MAP_NONBLOCK)
if err != nil {
fmt.Printf("Error memory-mapping file %s: %v\n", input_file, err)
return
}
defer syscall.Munmap(data)

readFromInput = func(buf []byte, off int64) {
copy(buf, data[off:off+int64(len(buf))])
}
}

// mmap output file
output, err := syscall.Mmap(int(fout.Fd()), 0, int(sizeBytes), syscall.PROT_WRITE, syscall.MAP_SHARED)
// Open output file.
fout, err := os.OpenFile(output_file, os.O_RDWR|os.O_CREATE, 0644)
if err != nil {
fmt.Printf("Error mmaping file %s: %v\n", output_file, err)
fmt.Printf("Error opening file %s: %v\n", output_file, err)
return
}
fout.Truncate(int64(sizeBytes))

num_edges := sizeBytes / 16

rand.Seed(*rseed)
// keys should be an array of 4 random uint64s
keys := []shuffle.FeistelWord{shuffle.FeistelWord(rand.Uint64()), shuffle.FeistelWord(rand.Uint64()), shuffle.FeistelWord(rand.Uint64()), shuffle.FeistelWord(rand.Uint64())}
rand.Seed(*rseed)
var keys []shuffle.FeistelWord
for i := 0; i < 4; i++ {
keys = append(keys, shuffle.FeistelWord(rand.Uint64()))
}

var wg sync.WaitGroup
// spawn a goroutine for every page, since each page can fault
// todo (meyer): should we limit the number of concurrent goroutines?
// limit the in-flight goroutines with a ticketing system
limiter := make(chan struct{}, *routines)
for i := 0; i < *routines; i++ {
limiter <- struct{}{}
}

// Spawn one goroutine per page.
for i := uint64(0); i < num_edges; i += uint64(pageSize / 16) {
wg.Add(1)
// take a ticket
<-limiter
go func(i uint64) {
defer wg.Done()
output := make([]byte, pageSize)
fn := shuffle.NewFeistel(keys, roundFunction)
for j := i; j < i+uint64(pageSize/16) && j < num_edges; j++ {
src, _ := shuffle.GetIndex(shuffle.FeistelWord(j), shuffle.FeistelWord(num_edges), fn)
// move 16 bytes from data to output
copy(output[j*16:], data[src*16:src*16+16])
readFromInput(output[(j-i)*16:(j-i+1)*16], int64(src*16))
}
_, err := fout.WriteAt(output, int64(i*16))
if err != nil {
panic("Error writing output file")
}
// release the ticket
limiter <- struct{}{}
}(i)
}
wg.Wait()

err = syscall.Munmap(data)
if err != nil {
fmt.Printf("Error unmapping file %s: %v\n", input_file, err)
return
}

err = syscall.Munmap(output)
if err != nil {
fmt.Printf("Error unmapping file %s: %v\n", output_file, err)
return
// wait for all tickets to be returned (i.e. for all goroutines to finish)
for i := 0; i < *routines; i++ {
<-limiter
}
}

0 comments on commit 1b62931

Please sign in to comment.