diff --git a/datasquare.go b/datasquare.go index 4a5ca0c..45df481 100644 --- a/datasquare.go +++ b/datasquare.go @@ -26,20 +26,15 @@ type dataSquare struct { createTreeFn TreeConstructorFn } -func newDataSquare(data [][]byte, treeCreator TreeConstructorFn) (*dataSquare, error) { +func newDataSquare(data [][]byte, treeCreator TreeConstructorFn, chunkSize uint) (*dataSquare, error) { width := int(math.Ceil(math.Sqrt(float64(len(data))))) if width*width != len(data) { return nil, errors.New("number of chunks must be a square number") } - var chunkSize int for _, d := range data { - if d != nil { - if chunkSize == 0 { - chunkSize = len(d) - } else if chunkSize != len(d) { - return nil, ErrUnevenChunks - } + if d != nil && len(d) != int(chunkSize) { + return nil, ErrUnevenChunks } } @@ -48,7 +43,7 @@ func newDataSquare(data [][]byte, treeCreator TreeConstructorFn) (*dataSquare, e squareRow[i] = data[i*width : i*width+width] for j := 0; j < width; j++ { - if squareRow[i][j] != nil && len(squareRow[i][j]) != chunkSize { + if squareRow[i][j] != nil && len(squareRow[i][j]) != int(chunkSize) { return nil, ErrUnevenChunks } } diff --git a/datasquare_test.go b/datasquare_test.go index 4771b74..45a449a 100644 --- a/datasquare_test.go +++ b/datasquare_test.go @@ -12,17 +12,18 @@ import ( func TestNewDataSquare(t *testing.T) { tests := []struct { - name string - cells [][]byte - expected [][][]byte + name string + cells [][]byte + expected [][][]byte + chunkSize uint }{ - {"1x1", [][]byte{{1, 2}}, [][][]byte{{{1, 2}}}}, - {"2x2", [][]byte{{1, 2}, {3, 4}, {5, 6}, {7, 8}}, [][][]byte{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}}, + {"1x1", [][]byte{{1, 2}}, [][][]byte{{{1, 2}}}, 2}, + {"2x2", [][]byte{{1, 2}, {3, 4}, {5, 6}, {7, 8}}, [][][]byte{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}, 2}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - result, err := newDataSquare(test.cells, NewDefaultTree) + result, err := newDataSquare(test.cells, NewDefaultTree, test.chunkSize) if err != nil { panic(err) } @@ -35,15 +36,16 @@ func TestNewDataSquare(t *testing.T) { func TestInvalidDataSquareCreation(t *testing.T) { tests := []struct { - name string - cells [][]byte + name string + cells [][]byte + chunkSize uint }{ - {"InconsistentChunkNumber", [][]byte{{1, 2}, {3, 4}, {5, 6}}}, - {"UnequalChunkSize", [][]byte{{1, 2}, {3, 4}, {5, 6}, {7}}}, + {"InconsistentChunkNumber", [][]byte{{1, 2}, {3, 4}, {5, 6}}, 2}, + {"UnequalChunkSize", [][]byte{{1, 2}, {3, 4}, {5, 6}, {7}}, 2}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - _, err := newDataSquare(test.cells, NewDefaultTree) + _, err := newDataSquare(test.cells, NewDefaultTree, test.chunkSize) if err == nil { t.Errorf("newDataSquare failed; chunks accepted with %v", test.name) } @@ -81,7 +83,7 @@ func TestSetCell(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - ds, err := newDataSquare([][]byte{tc.originalCell, {2}, {3}, {4}}, NewDefaultTree) + ds, err := newDataSquare([][]byte{tc.originalCell, {2}, {3}, {4}}, NewDefaultTree, 1) assert.NoError(t, err) err = ds.SetCell(0, 0, tc.newCell) @@ -124,7 +126,7 @@ func Test_setCell(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - ds, err := newDataSquare([][]byte{tc.original, {2}, {3}, {4}}, NewDefaultTree) + ds, err := newDataSquare([][]byte{tc.original, {2}, {3}, {4}}, NewDefaultTree, 1) assert.NoError(t, err) ds.setCell(0, 0, tc.new) @@ -134,7 +136,7 @@ func Test_setCell(t *testing.T) { } func TestGetCell(t *testing.T) { - ds, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree) + ds, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree, 1) if err != nil { panic(err) } @@ -148,7 +150,7 @@ func TestGetCell(t *testing.T) { } func TestFlattened(t *testing.T) { - ds, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree) + ds, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree, 1) if err != nil { panic(err) } @@ -162,7 +164,7 @@ func TestFlattened(t *testing.T) { } func TestExtendSquare(t *testing.T) { - ds, err := newDataSquare([][]byte{{1, 2}}, NewDefaultTree) + ds, err := newDataSquare([][]byte{{1, 2}}, NewDefaultTree, 2) if err != nil { panic(err) } @@ -176,7 +178,7 @@ func TestExtendSquare(t *testing.T) { } func TestInvalidSquareExtension(t *testing.T) { - ds, err := newDataSquare([][]byte{{1, 2}}, NewDefaultTree) + ds, err := newDataSquare([][]byte{{1, 2}}, NewDefaultTree, 2) if err != nil { panic(err) } @@ -189,7 +191,7 @@ func TestInvalidSquareExtension(t *testing.T) { // TestRoots verifies that the row roots and column roots are equal for a 1x1 // square. func TestRoots(t *testing.T) { - result, err := newDataSquare([][]byte{{1, 2}}, NewDefaultTree) + result, err := newDataSquare([][]byte{{1, 2}}, NewDefaultTree, 2) assert.NoError(t, err) rowRoots, err := result.getRowRoots() @@ -202,7 +204,7 @@ func TestRoots(t *testing.T) { } func TestLazyRootGeneration(t *testing.T) { - square, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree) + square, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree, 1) if err != nil { panic(err) } @@ -228,13 +230,13 @@ func TestLazyRootGeneration(t *testing.T) { func TestComputeRoots(t *testing.T) { t.Run("default tree computeRoots() returns no error", func(t *testing.T) { - square, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree) + square, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree, 1) assert.NoError(t, err) err = square.computeRoots() assert.NoError(t, err) }) t.Run("error tree computeRoots() returns an error", func(t *testing.T) { - square, err := newDataSquare([][]byte{{1}}, newErrorTree) + square, err := newDataSquare([][]byte{{1}}, newErrorTree, 1) assert.NoError(t, err) err = square.computeRoots() assert.Error(t, err) @@ -242,7 +244,7 @@ func TestComputeRoots(t *testing.T) { } func TestRootAPI(t *testing.T) { - square, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree) + square, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree, 1) if err != nil { panic(err) } @@ -267,7 +269,7 @@ func TestRootAPI(t *testing.T) { } func TestDefaultTreeProofs(t *testing.T) { - result, err := newDataSquare([][]byte{{1, 2}, {3, 4}, {5, 6}, {7, 8}}, NewDefaultTree) + result, err := newDataSquare([][]byte{{1, 2}, {3, 4}, {5, 6}, {7, 8}}, NewDefaultTree, 2) if err != nil { panic(err) } @@ -330,7 +332,7 @@ func Test_setRowSlice(t *testing.T) { } for _, tc := range testCases { - ds, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree) + ds, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree, 1) assert.NoError(t, err) err = ds.setRowSlice(tc.x, tc.y, tc.newRow) @@ -386,7 +388,7 @@ func Test_setColSlice(t *testing.T) { } for _, tc := range testCases { - ds, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree) + ds, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree, 1) assert.NoError(t, err) err = ds.setColSlice(tc.x, tc.y, tc.newCol) @@ -401,7 +403,8 @@ func Test_setColSlice(t *testing.T) { func BenchmarkEDSRoots(b *testing.B) { for i := 32; i < 513; i *= 2 { - square, err := newDataSquare(genRandDS(i*2), NewDefaultTree) + chunkSize := uint(256) + square, err := newDataSquare(genRandDS(i*2, int(chunkSize)), NewDefaultTree, chunkSize) if err != nil { b.Errorf("Failure to create square of size %d: %s", i, err) } diff --git a/extendeddatacrossword_test.go b/extendeddatacrossword_test.go index 9524a80..149f635 100644 --- a/extendeddatacrossword_test.go +++ b/extendeddatacrossword_test.go @@ -261,6 +261,7 @@ func TestCorruptedEdsReturnsErrByzantineData(t *testing.T) { } func BenchmarkRepair(b *testing.B) { + chunkSize := uint(256) // For different ODS sizes for originalDataWidth := 4; originalDataWidth <= 512; originalDataWidth *= 2 { for codecName, codec := range codecs { @@ -270,7 +271,7 @@ func BenchmarkRepair(b *testing.B) { } // Generate a new range original data square then extend it - square := genRandDS(originalDataWidth) + square := genRandDS(originalDataWidth, int(chunkSize)) eds, err := ComputeExtendedDataSquare(square, codec, NewDefaultTree) if err != nil { b.Error(err) diff --git a/extendeddatasquare.go b/extendeddatasquare.go index 1e71184..f46abdb 100644 --- a/extendeddatasquare.go +++ b/extendeddatasquare.go @@ -54,7 +54,8 @@ func ComputeExtendedDataSquare( return nil, errors.New("number of chunks exceeds the maximum") } - ds, err := newDataSquare(data, treeCreatorFn) + chunkSize := getChunkSize(data) + ds, err := newDataSquare(data, treeCreatorFn, uint(chunkSize)) if err != nil { return nil, err } @@ -78,14 +79,16 @@ func ImportExtendedDataSquare( return nil, errors.New("number of chunks exceeds the maximum") } - ds, err := newDataSquare(data, treeCreatorFn) + chunkSize := getChunkSize(data) + ds, err := newDataSquare(data, treeCreatorFn, uint(chunkSize)) if err != nil { return nil, err } eds := ExtendedDataSquare{dataSquare: ds, codec: codec} - if eds.width%2 != 0 { - return nil, errors.New("square width must be even") + err = validateEdsWidth(eds.width) + if err != nil { + return nil, err } eds.originalDataWidth = eds.width / 2 @@ -93,6 +96,30 @@ func ImportExtendedDataSquare( return &eds, nil } +// NewExtendedDataSquare returns a new extended data square with a width of +// edsWidth. All shares are initialized to nil so that the returned extended +// data square can be populated via subsequent SetCell invocations. +func NewExtendedDataSquare(codec Codec, treeCreatorFn TreeConstructorFn, edsWidth uint, chunkSize uint) (*ExtendedDataSquare, error) { + err := validateEdsWidth(edsWidth) + if err != nil { + return nil, err + } + + data := make([][]byte, edsWidth*edsWidth) + dataSquare, err := newDataSquare(data, treeCreatorFn, chunkSize) + if err != nil { + return nil, err + } + + originalDataWidth := edsWidth / 2 + eds := ExtendedDataSquare{ + dataSquare: dataSquare, + codec: codec, + originalDataWidth: originalDataWidth, + } + return &eds, nil +} + func (eds *ExtendedDataSquare) erasureExtendSquare(codec Codec) error { eds.originalDataWidth = eds.width @@ -232,3 +259,23 @@ func deepCopy(original [][]byte) [][]byte { func (eds *ExtendedDataSquare) Width() uint { return eds.width } + +// validateEdsWidth returns an error if edsWidth is not a valid width for an +// extended data square. +func validateEdsWidth(edsWidth uint) error { + if edsWidth%2 != 0 { + return errors.New("square width must be even") + } + + return nil +} + +// getChunkSize returns the size of the first non-nil chunk in data. +func getChunkSize(data [][]byte) (chunkSize int) { + for _, d := range data { + if d != nil { + return len(d) + } + } + return 0 +} diff --git a/extendeddatasquare_test.go b/extendeddatasquare_test.go index 683d9f4..ddf351a 100644 --- a/extendeddatasquare_test.go +++ b/extendeddatasquare_test.go @@ -96,6 +96,49 @@ func TestMarshalJSON(t *testing.T) { } } +func TestNewExtendedDataSquare(t *testing.T) { + t.Run("returns an error if edsWidth is not even", func(t *testing.T) { + edsWidth := uint(1) + chunkSize := uint(512) + + _, err := NewExtendedDataSquare(NewLeoRSCodec(), NewDefaultTree, edsWidth, chunkSize) + assert.Error(t, err) + }) + t.Run("returns a 4x4 EDS", func(t *testing.T) { + edsWidth := uint(4) + chunkSize := uint(512) + + got, err := NewExtendedDataSquare(NewLeoRSCodec(), NewDefaultTree, edsWidth, chunkSize) + assert.NoError(t, err) + assert.Equal(t, edsWidth, got.width) + assert.Equal(t, chunkSize, got.chunkSize) + }) + t.Run("returns a 4x4 EDS that can be populated via SetCell", func(t *testing.T) { + edsWidth := uint(4) + chunkSize := uint(512) + + got, err := NewExtendedDataSquare(NewLeoRSCodec(), NewDefaultTree, edsWidth, chunkSize) + assert.NoError(t, err) + + chunk := bytes.Repeat([]byte{1}, int(chunkSize)) + err = got.SetCell(0, 0, chunk) + assert.NoError(t, err) + assert.Equal(t, chunk, got.squareRow[0][0]) + }) + t.Run("returns an error when SetCell is invoked on an EDS with a chunk that is not the correct size", func(t *testing.T) { + edsWidth := uint(4) + chunkSize := uint(512) + incorrectChunkSize := uint(513) + + got, err := NewExtendedDataSquare(NewLeoRSCodec(), NewDefaultTree, edsWidth, chunkSize) + assert.NoError(t, err) + + chunk := bytes.Repeat([]byte{1}, int(incorrectChunkSize)) + err = got.SetCell(0, 0, chunk) + assert.Error(t, err) + }) +} + func TestImmutableRoots(t *testing.T) { codec := NewLeoRSCodec() result, err := ComputeExtendedDataSquare([][]byte{ @@ -161,6 +204,7 @@ var dump *ExtendedDataSquare // BenchmarkExtension benchmarks extending datasquares sizes 4-128 using all // supported codecs (encoding only) func BenchmarkExtensionEncoding(b *testing.B) { + chunkSize := 256 for i := 4; i < 513; i *= 2 { for codecName, codec := range codecs { if codec.MaxChunks() < i*i { @@ -168,7 +212,7 @@ func BenchmarkExtensionEncoding(b *testing.B) { continue } - square := genRandDS(i) + square := genRandDS(i, chunkSize) b.Run( fmt.Sprintf("%s %dx%dx%d ODS", codecName, i, i, len(square[0])), func(b *testing.B) { @@ -188,6 +232,7 @@ func BenchmarkExtensionEncoding(b *testing.B) { // BenchmarkExtension benchmarks extending datasquares sizes 4-128 using all // supported codecs (both encoding and root computation) func BenchmarkExtensionWithRoots(b *testing.B) { + chunkSize := 256 for i := 4; i < 513; i *= 2 { for codecName, codec := range codecs { if codec.MaxChunks() < i*i { @@ -195,7 +240,7 @@ func BenchmarkExtensionWithRoots(b *testing.B) { continue } - square := genRandDS(i) + square := genRandDS(i, chunkSize) b.Run( fmt.Sprintf("%s %dx%dx%d ODS", codecName, i, i, len(square[0])), func(b *testing.B) { @@ -216,11 +261,11 @@ func BenchmarkExtensionWithRoots(b *testing.B) { // genRandDS make a datasquare of random data, with width describing the number // of shares on a single side of the ds -func genRandDS(width int) [][]byte { +func genRandDS(width int, chunkSize int) [][]byte { var ds [][]byte count := width * width for i := 0; i < count; i++ { - share := make([]byte, 256) + share := make([]byte, chunkSize) _, err := rand.Read(share) if err != nil { panic(err)