Skip to content

Commit

Permalink
serial witness retrieval
Browse files Browse the repository at this point in the history
  • Loading branch information
ToniRamirezM committed Jul 14, 2024
1 parent 93f964f commit 4078e4e
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 38 deletions.
30 changes: 15 additions & 15 deletions aggregator/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ func (a *Aggregator) handleReceivedDataStream(entry *datastreamer.FileEntry, cli
a.currentStreamBatch.Timestamp = sequence.Timestamp

// Calculate Acc Input Hash
oldBatch, _, err := a.state.GetBatch(ctx, a.currentStreamBatch.BatchNumber-1, nil)
oldBatch, _, _, err := a.state.GetBatch(ctx, a.currentStreamBatch.BatchNumber-1, nil)
if err != nil {
log.Errorf("Error getting batch %d: %v", a.currentStreamBatch.BatchNumber-1, err)
return err
Expand All @@ -343,7 +343,14 @@ func (a *Aggregator) handleReceivedDataStream(entry *datastreamer.FileEntry, cli

a.currentStreamBatch.AccInputHash = accInputHash

err = a.state.AddBatch(ctx, &a.currentStreamBatch, a.currentBatchStreamData, nil)
// Get Witness
witness, err := getWitness(a.currentStreamBatch.BatchNumber, a.cfg.WitnessURL, a.cfg.UseFullWitness)
if err != nil {
log.Errorf("Failed to get witness for batch %d, err: %v", a.currentStreamBatch.BatchNumber, err)
return err
}

err = a.state.AddBatch(ctx, &a.currentStreamBatch, a.currentBatchStreamData, witness, nil)
if err != nil {
log.Errorf("Error adding batch: %v", err)
return err
Expand Down Expand Up @@ -467,7 +474,7 @@ func (a *Aggregator) Start(ctx context.Context) error {

// Store Acc Input Hash of the latest verified batch
dummyBatch := state.Batch{BatchNumber: lastVerifiedBatchNumber, AccInputHash: *accInputHash}
err = a.state.AddBatch(ctx, &dummyBatch, []byte{0}, nil)
err = a.state.AddBatch(ctx, &dummyBatch, []byte{0}, []byte{0}, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -624,7 +631,7 @@ func (a *Aggregator) sendFinalProof() {

a.startProofVerification()

finalBatch, _, err := a.state.GetBatch(ctx, proof.BatchNumberFinal, nil)
finalBatch, _, _, err := a.state.GetBatch(ctx, proof.BatchNumberFinal, nil)
if err != nil {
log.Errorf("Failed to retrieve batch with number [%d]: %v", proof.BatchNumberFinal, err)
a.endProofVerification()
Expand Down Expand Up @@ -770,7 +777,7 @@ func (a *Aggregator) buildFinalProof(ctx context.Context, prover proverInterface
if string(finalProof.Public.NewStateRoot) == mockedStateRoot && string(finalProof.Public.NewLocalExitRoot) == mockedLocalExitRoot {
// This local exit root and state root come from the mock
// prover, use the one captured by the executor instead
finalBatch, _, err := a.state.GetBatch(ctx, proof.BatchNumberFinal, nil)
finalBatch, _, _, err := a.state.GetBatch(ctx, proof.BatchNumberFinal, nil)
if err != nil {
return nil, fmt.Errorf("failed to retrieve batch with number [%d]", proof.BatchNumberFinal)
}
Expand Down Expand Up @@ -1226,7 +1233,7 @@ func (a *Aggregator) getAndLockBatchToProve(ctx context.Context, prover proverIn
return nil, nil, err
}

batch, _, err := a.state.GetBatch(ctx, batchNumberToVerify, nil)
batch, _, _, err := a.state.GetBatch(ctx, batchNumberToVerify, nil)
if err != nil {
return batch, nil, err
}
Expand Down Expand Up @@ -1491,15 +1498,8 @@ func (a *Aggregator) buildInputProver(ctx context.Context, batchToVerify *state.
}*/
}

// Get Witness
witness, err := getWitness(batchToVerify.BatchNumber, a.cfg.WitnessURL, a.cfg.UseFullWitness)
if err != nil {
log.Errorf("Failed to get witness, err: %v", err)
return nil, err
}

// Get Old Acc Input Hash
oldBatch, _, err := a.state.GetBatch(ctx, batchToVerify.BatchNumber-1, nil)
// Get Old Acc Input Hash and witness
oldBatch, _, witness, err := a.state.GetBatch(ctx, batchToVerify.BatchNumber-1, nil)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions aggregator/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ type stateInterface interface {
CleanupLockedProofs(ctx context.Context, duration string, dbTx pgx.Tx) (int64, error)
CheckProofExistsForBatch(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) (bool, error)
AddSequence(ctx context.Context, sequence state.Sequence, dbTx pgx.Tx) error
AddBatch(ctx context.Context, batch *state.Batch, datastream []byte, dbTx pgx.Tx) error
GetBatch(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) (*state.Batch, []byte, error)
AddBatch(ctx context.Context, batch *state.Batch, datastream []byte, witness []byte, dbTx pgx.Tx) error
GetBatch(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) (*state.Batch, []byte, []byte, error)
DeleteBatchesOlderThanBatchNumber(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) error
DeleteBatchesNewerThanBatchNumber(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) error
}
15 changes: 4 additions & 11 deletions aggregator/prover/prover.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ func (p *Prover) Addr() string {

// Status gets the prover status.
func (p *Prover) Status() (*GetStatusResponse, error) {
start := time.Now()
req := &AggregatorMessage{
Request: &AggregatorMessage_GetStatusRequest{
GetStatusRequest: &GetStatusRequest{},
Expand All @@ -79,11 +78,9 @@ func (p *Prover) Status() (*GetStatusResponse, error) {
if err != nil {
return nil, err
}
log.Infof("Prover status call")
if msg, ok := res.Response.(*ProverMessage_GetStatusResponse); ok {
return msg.GetStatusResponse, nil
}
log.Infof("Prover %s status call took %v", p.ID(), time.Since(start))
return nil, fmt.Errorf("%w, wanted %T, got %T", ErrBadProverResponse, &ProverMessage_GetStatusResponse{}, res.Response)
}

Expand Down Expand Up @@ -119,12 +116,11 @@ func (p *Prover) BatchProof(input *StatelessInputProver) (*string, error) {
GenStatelessBatchProofRequest: &GenStatelessBatchProofRequest{Input: input},
},
}
start := time.Now()

res, err := p.call(req)
if err != nil {
return nil, err
}
log.Infof("Prover %s batch proof call took %v", p.ID(), time.Since(start))

if msg, ok := res.Response.(*ProverMessage_GenBatchProofResponse); ok {
switch msg.GenBatchProofResponse.Result {
Expand Down Expand Up @@ -157,12 +153,11 @@ func (p *Prover) AggregatedProof(inputProof1, inputProof2 string) (*string, erro
},
},
}
start := time.Now()

res, err := p.call(req)
if err != nil {
return nil, err
}
log.Infof("Prover %s aggregated proof call took %v", p.ID(), time.Since(start))

if msg, ok := res.Response.(*ProverMessage_GenAggregatedProofResponse); ok {
switch msg.GenAggregatedProofResponse.Result {
Expand Down Expand Up @@ -199,12 +194,11 @@ func (p *Prover) FinalProof(inputProof string, aggregatorAddr string) (*string,
},
},
}
start := time.Now()

res, err := p.call(req)
if err != nil {
return nil, err
}
log.Infof("Prover %s final proof call took %v", p.ID(), time.Since(start))

if msg, ok := res.Response.(*ProverMessage_GenFinalProofResponse); ok {
switch msg.GenFinalProofResponse.Result {
Expand Down Expand Up @@ -235,12 +229,11 @@ func (p *Prover) CancelProofRequest(proofID string) error {
CancelRequest: &CancelRequest{Id: proofID},
},
}
start := time.Now()

res, err := p.call(req)
if err != nil {
return err
}
log.Infof("Prover %s cancel proof call took %v", p.ID(), time.Since(start))

if msg, ok := res.Response.(*ProverMessage_CancelResponse); ok {
switch msg.CancelResponse.Result {
Expand Down
8 changes: 8 additions & 0 deletions db/migrations/aggregator/002.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- +migrate Up
DELETE FROM aggregator.batch;
ALTER TABLE aggregator.batch
ADD COLUMN IF NOT EXISTS witness varchar NOT NULL;

-- +migrate Down
ALTER TABLE aggregator.batch
DROP COLUMN IF NOT EXISTS witness;
4 changes: 2 additions & 2 deletions state/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ type storage interface {
CleanupGeneratedProofs(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) error
CleanupLockedProofs(ctx context.Context, duration string, dbTx pgx.Tx) (int64, error)
CheckProofExistsForBatch(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) (bool, error)
AddBatch(ctx context.Context, batch *Batch, datastream []byte, dbTx pgx.Tx) error
GetBatch(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) (*Batch, []byte, error)
AddBatch(ctx context.Context, batch *Batch, datastream []byte, witness []byte, dbTx pgx.Tx) error
GetBatch(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) (*Batch, []byte, []byte, error)
DeleteBatchesOlderThanBatchNumber(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) error
DeleteBatchesNewerThanBatchNumber(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) error
}
17 changes: 9 additions & 8 deletions state/pgstatestorage/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,25 @@ import (
)

// AddBatch stores a batch
func (p *PostgresStorage) AddBatch(ctx context.Context, batch *state.Batch, datastream []byte, dbTx pgx.Tx) error {
const addInputHashSQL = "INSERT INTO aggregator.batch (batch_num, batch, datastream) VALUES ($1, $2, $3) ON CONFLICT (batch_num) DO UPDATE SET batch = $2, datastream = $3"
func (p *PostgresStorage) AddBatch(ctx context.Context, batch *state.Batch, datastream []byte, witness []byte, dbTx pgx.Tx) error {
const addInputHashSQL = "INSERT INTO aggregator.batch (batch_num, batch, datastream, witness) VALUES ($1, $2, $3, $4) ON CONFLICT (batch_num) DO UPDATE SET batch = $2, datastream = $3, witness = $4"
e := p.getExecQuerier(dbTx)
_, err := e.Exec(ctx, addInputHashSQL, batch.BatchNumber, &batch, common.Bytes2Hex(datastream))
_, err := e.Exec(ctx, addInputHashSQL, batch.BatchNumber, &batch, common.Bytes2Hex(datastream), common.Bytes2Hex(witness))
return err
}

// GetBatch gets a batch by a given batch number
func (p *PostgresStorage) GetBatch(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) (*state.Batch, []byte, error) {
const getInputHashSQL = "SELECT batch, datastream FROM aggregator.batch WHERE batch_num = $1"
func (p *PostgresStorage) GetBatch(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) (*state.Batch, []byte, []byte, error) {
const getInputHashSQL = "SELECT batch, datastream, witness FROM aggregator.batch WHERE batch_num = $1"
e := p.getExecQuerier(dbTx)
var batch *state.Batch
var streamStr string
err := e.QueryRow(ctx, getInputHashSQL, batchNumber).Scan(&batch, &streamStr)
var witnessStr string
err := e.QueryRow(ctx, getInputHashSQL, batchNumber).Scan(&batch, &streamStr, &witnessStr)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
return batch, common.Hex2Bytes(streamStr), nil
return batch, common.Hex2Bytes(streamStr), common.Hex2Bytes(witnessStr), nil
}

// DeleteBatchesOlderThanBatchNumber deletes batches previous to the given batch number
Expand Down

0 comments on commit 4078e4e

Please sign in to comment.