Skip to content

Commit

Permalink
refactor(ipld): use Set/GetCell API from rstm2d (#1173)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wondertan authored Nov 11, 2022
1 parent 8356c21 commit 19dd38f
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 93 deletions.
98 changes: 43 additions & 55 deletions share/eds/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,23 +104,18 @@ func (r *Retriever) Retrieve(ctx context.Context, dah *da.DataAvailabilityHeader
// quadrant request retries. Also, provides an API
// to reconstruct the block once enough shares are fetched.
type retrievalSession struct {
dah *da.DataAvailabilityHeader
bget blockservice.BlockGetter
adder *ipld.NmtNodeAdder

treeFn rsmt2d.TreeConstructorFn
codec rsmt2d.Codec

dah *da.DataAvailabilityHeader
squareImported *rsmt2d.ExtendedDataSquare

quadrants []*quadrant
sharesLks []sync.Mutex
sharesCount uint32

squareLk sync.RWMutex
square [][]byte
squareSig chan struct{}
squareDn chan struct{}
// TODO(@Wondertan): Extract into a separate data structure https://github.com/celestiaorg/rsmt2d/issues/135
squareQuadrants []*quadrant
squareCellsLks [][]sync.Mutex
squareCellsCount uint32
squareSig chan struct{}
squareDn chan struct{}
squareLk sync.RWMutex
square *rsmt2d.ExtendedDataSquare

span trace.Span
}
Expand All @@ -133,29 +128,31 @@ func (r *Retriever) newSession(ctx context.Context, dah *da.DataAvailabilityHead
r.bServ,
ipld.MaxSizeBatchOption(size),
)
ses := &retrievalSession{
bget: blockservice.NewSession(ctx, r.bServ),
adder: adder,
treeFn: func(_ rsmt2d.Axis, index uint) rsmt2d.Tree {
tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(size)/2, index, nmt.NodeVisitor(adder.Visit))
return &tree
},
codec: share.DefaultRSMT2DCodec(),
dah: dah,
quadrants: newQuadrants(dah),
sharesLks: make([]sync.Mutex, size*size),
square: make([][]byte, size*size),
squareSig: make(chan struct{}, 1),
squareDn: make(chan struct{}),
span: trace.SpanFromContext(ctx),

treeFn := func(_ rsmt2d.Axis, index uint) rsmt2d.Tree {
tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(size)/2, index, nmt.NodeVisitor(adder.Visit))
return &tree
}

square, err := rsmt2d.ImportExtendedDataSquare(ses.square, ses.codec, ses.treeFn)
square, err := rsmt2d.ImportExtendedDataSquare(make([][]byte, size*size), share.DefaultRSMT2DCodec(), treeFn)
if err != nil {
return nil, err
}

ses.squareImported = square
ses := &retrievalSession{
dah: dah,
bget: blockservice.NewSession(ctx, r.bServ),
adder: adder,
squareQuadrants: newQuadrants(dah),
squareCellsLks: make([][]sync.Mutex, size),
squareSig: make(chan struct{}, 1),
squareDn: make(chan struct{}),
square: square,
span: trace.SpanFromContext(ctx),
}
for i := range ses.squareCellsLks {
ses.squareCellsLks[i] = make([]sync.Mutex, size)
}
go ses.request(ctx)
return ses, nil
}
Expand All @@ -170,36 +167,24 @@ func (rs *retrievalSession) Done() <-chan struct{} {
// Reconstruct tries to reconstruct the data square and returns it on success.
func (rs *retrievalSession) Reconstruct(ctx context.Context) (*rsmt2d.ExtendedDataSquare, error) {
if rs.isReconstructed() {
return rs.squareImported, nil
return rs.square, nil
}
// prevent further writes to the square
rs.squareLk.Lock()
defer rs.squareLk.Unlock()

// TODO(@Wondertan): This is bad!
// * We should not reimport the square multiple times
// * We should set shares into imported square via
// SetShare(https://github.com/celestiaorg/rsmt2d/issues/83) to accomplish the above point.
{
squareImported, err := rsmt2d.ImportExtendedDataSquare(rs.square, rs.codec, rs.treeFn)
if err != nil {
return nil, err
}
rs.squareImported = squareImported
}

_, span := tracer.Start(ctx, "reconstruct-square")
defer span.End()

// and try to repair with what we have
err := rs.squareImported.Repair(rs.dah.RowsRoots, rs.dah.ColumnRoots)
err := rs.square.Repair(rs.dah.RowsRoots, rs.dah.ColumnRoots)
if err != nil {
span.RecordError(err)
return nil, err
}
log.Infow("data square reconstructed", "data_hash", hex.EncodeToString(rs.dah.Hash()), "size", len(rs.dah.RowsRoots))
close(rs.squareDn)
return rs.squareImported, nil
return rs.square, nil
}

// isReconstructed report true whether the square attached to the session
Expand Down Expand Up @@ -232,16 +217,16 @@ func (rs *retrievalSession) Close() error {
func (rs *retrievalSession) request(ctx context.Context) {
t := time.NewTicker(RetrieveQuadrantTimeout)
defer t.Stop()
for retry := 0; retry < len(rs.quadrants); retry++ {
q := rs.quadrants[retry]
for retry := 0; retry < len(rs.squareQuadrants); retry++ {
q := rs.squareQuadrants[retry]
log.Debugw("requesting quadrant",
"axis", q.source,
"x", q.x,
"y", q.y,
"size", len(q.roots),
)
rs.span.AddEvent("requesting quadrant", trace.WithAttributes(
attribute.Int("axis", q.source),
attribute.Int("axis", int(q.source)),
attribute.Int("x", q.x),
attribute.Int("y", q.y),
attribute.Int("size", len(q.roots)),
Expand All @@ -260,7 +245,7 @@ func (rs *retrievalSession) request(ctx context.Context) {
"size", len(q.roots),
)
rs.span.AddEvent("quadrant request timeout", trace.WithAttributes(
attribute.Int("axis", q.source),
attribute.Int("axis", int(q.source)),
attribute.Int("x", q.x),
attribute.Int("y", q.y),
attribute.Int("size", len(q.roots)),
Expand Down Expand Up @@ -292,10 +277,10 @@ func (rs *retrievalSession) doRequest(ctx context.Context, q *quadrant) {
// in the square.
// NOTE-2: We never actually fetch shares from the network *twice*.
// Once a share is downloaded from the network it is cached on the IPLD(blockservice) level.
// calc index of the share
idx := q.index(i, j)
// calc position of the share
x, y := q.pos(i, j)
// try to lock the share
ok := rs.sharesLks[idx].TryLock()
ok := rs.squareCellsLks[x][y].TryLock()
if !ok {
// if already locked and written - do nothing
return
Expand All @@ -312,14 +297,17 @@ func (rs *retrievalSession) doRequest(ctx context.Context, q *quadrant) {
if rs.isReconstructed() {
return
}
rs.square[idx] = share
if rs.square.GetCell(uint(x), uint(y)) != nil {
return
}
rs.square.SetCell(uint(x), uint(y), share)
// if we have >= 1/4 of the square we can start trying to Reconstruct
// TODO(@Wondertan): This is not an ideal way to know when to start
// reconstruction and can cause idle reconstruction tries in some cases,
// but it is totally fine for the happy case and for now.
// The earlier we correctly know that we have the full square - the earlier
// we cancel ongoing requests - the less data is being wastedly transferred.
if atomic.AddUint32(&rs.sharesCount, 1) >= uint32(size*size) {
if atomic.AddUint32(&rs.squareCellsCount, 1) >= uint32(size*size) {
select {
case rs.squareSig <- struct{}{}:
default:
Expand Down
55 changes: 17 additions & 38 deletions share/eds/retriever_quadrant.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package eds

import (
"math"
"math/rand"
"time"

"github.com/ipfs/go-cid"

"github.com/celestiaorg/celestia-app/pkg/da"
"github.com/celestiaorg/rsmt2d"

"github.com/celestiaorg/celestia-node/share/ipld"
)
Expand Down Expand Up @@ -42,10 +42,8 @@ type quadrant struct {
// |(0;1)| |(1;1)|
// ------ -------
x, y int
// source defines the axis for quadrant
// it can be either 1 or 0 similar to x and y
// where 0 is Row source and 1 is Col respectively
source int
// source defines the axis(Row or Col) to fetch the quadrant from
source rsmt2d.Axis
}

// newQuadrants constructs a slice of quadrants from DAHeader.
Expand All @@ -70,17 +68,13 @@ func newQuadrants(dah *da.DataAvailabilityHeader) []*quadrant {
}

for i := range quadrants {
// convert quadrant index into coordinates
// convert quadrant 1D into into 2D coordinates
x, y := i%2, i/2
if source == 1 { // swap coordinates for column
x, y = y, x
}

quadrants[i] = &quadrant{
roots: roots[qsize*y : qsize*(y+1)],
x: x,
y: y,
source: source,
source: rsmt2d.Axis(source),
}
}
}
Expand All @@ -93,31 +87,16 @@ func newQuadrants(dah *da.DataAvailabilityHeader) []*quadrant {
return quadrants
}

// index calculates index for a share in a data square slice flattened by rows.
//
// NOTE: The complexity of the formula below comes from:
// - Goal to avoid share copying
// - Goal to make formula generic for both rows and cols
// - While data square is flattened by rows only
//
// TODO(@Wondertan): This can be simplified by making rsmt2d working over 3D byte slice(not
// flattened)
func (q *quadrant) index(rootIdx, cellIdx int) int {
size := len(q.roots)
// half square offsets, e.g. share is from Q3,
// so we add to index Q1+Q2
halfSquareOffsetCol := pow(size*2, q.source)
halfSquareOffsetRow := pow(size*2, q.source^1)
// offsets for the axis, e.g. share is from Q4.
// so we add to index Q3
offsetX := q.x * halfSquareOffsetCol * size
offsetY := q.y * halfSquareOffsetRow * size

rootIdx *= halfSquareOffsetRow
cellIdx *= halfSquareOffsetCol
return rootIdx + cellIdx + offsetX + offsetY
}

func pow(x, y int) int {
return int(math.Pow(float64(x), float64(y)))
// pos calculates position of a share in a data square.
func (q *quadrant) pos(rootIdx, cellIdx int) (int, int) {
cellIdx += len(q.roots) * q.x
rootIdx += len(q.roots) * q.y
switch q.source {
case rsmt2d.Row:
return rootIdx, cellIdx
case rsmt2d.Col:
return cellIdx, rootIdx
default:
panic("unknown axis")
}
}

0 comments on commit 19dd38f

Please sign in to comment.