Skip to content

Commit

Permalink
fix: deadlock when signing with the same validator in parallel (#108)
Browse files Browse the repository at this point in the history
* test: sign with the same validator in parallel on the same time.

* fix: return lock instead of locking inside `lock` function to prevent deadlock.

* make account context thread-safe

* use value mutex to not initialize

* fix golangci-lint issues

* fix test

* defer without function wrapping

* make test wait for both goroutines

* typo

* add many validator parallel test

---------

Co-authored-by: moshe-blox <[email protected]>
  • Loading branch information
y0sher and moshe-blox authored Oct 28, 2024
1 parent a34b29d commit e6bc06b
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 30 deletions.
6 changes: 3 additions & 3 deletions cli/cmd/wallet/cmd/account/handler/handler_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,13 @@ func ValidateHighestValues(accountFlagValues CreateAccountFlagValues) error {
privateKeysCount := len(accountFlagValues.privateKeys)

if len(accountFlagValues.highestSources) != privateKeysCount {
return errors.Errorf("highest sources " + errorExplain)
return errors.Errorf("highest sources %v", errorExplain)
}
if len(accountFlagValues.highestTargets) != privateKeysCount {
return errors.Errorf("highest targets " + errorExplain)
return errors.Errorf("highest targets %v", errorExplain)
}
if len(accountFlagValues.highestProposals) != privateKeysCount {
return errors.Errorf("highest proposals " + errorExplain)
return errors.Errorf("highest proposals %v", errorExplain)
}
} else if accountFlagValues.accumulate {
if len(accountFlagValues.highestSources) != (accountFlagValues.index + 1) {
Expand Down
7 changes: 3 additions & 4 deletions signer/sign_attestation.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@ func (signer *SimpleSigner) SignBeaconAttestation(attestation *phase0.Attestatio
}

// 2. lock for current account
signer.lock(account.ID(), "attestation")
defer func() {
signer.unlock(account.ID(), "attestation")
}()
val := signer.lock(account.ID(), "attestation")
val.Lock()
defer val.Unlock()

// 3. far future check
if !IsValidFarFutureEpoch(signer.network, attestation.Target.Epoch) {
Expand Down
173 changes: 173 additions & 0 deletions signer/sign_attestation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package signer
import (
"encoding/hex"
"fmt"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -73,6 +74,178 @@ func TestReferenceAttestation(t *testing.T) {
require.EqualValues(t, sig, actualSig)
}

// tested against a block and sig generated from https://github.com/prysmaticlabs/prysm/blob/master/shared/testutil/block.go#L357
func TestLockSameValidatorInParallel(t *testing.T) {
sk := _byteArray("2c083f2c8fc923fa2bd32a70ab72b4b46247e8c1f347adc30b2f8036a355086c")
pk := _byteArray("a9cf360aa15fb1d1d30ee2b578dc5884823c19661886ae8b892775ccb3bd96b7d7345569a2aa0b14e4d015c54a6a0c54")
domain := _byteArray32("0100000081509579e35e84020ad8751eca180b44df470332d3ad17fc6fd52459")

store := inmemStorage()
options := &eth2keymanager.KeyVaultOptions{}
options.SetStorage(store)
options.SetWalletType(core.NDWallet)
vault, err := eth2keymanager.NewKeyVault(options)
require.NoError(t, err)
wallet, err := vault.Wallet()
require.NoError(t, err)

k, err := core.NewHDKeyFromPrivateKey(sk, "")
require.NoError(t, err)
acc := wallets.NewValidatorAccount("1", k, nil, "", vault.Context)
require.NoError(t, err)
require.NoError(t, wallet.AddValidatorAccount(acc))

//// setup signer
signer := NewSimpleSigner(wallet, &prot.NoProtection{}, core.MainNetwork)

attestationDataByts := _byteArray("000000000000000000000000000000003a43a4bf26fb5947e809c1f24f7dc6857c8ac007e535d48e6e4eca2122fd776b0000000000000000000000000000000000000000000000000000000000000000000000000000000002000000000000003a43a4bf26fb5947e809c1f24f7dc6857c8ac007e535d48e6e4eca2122fd776b")

// decode attestation
attData := &phase0.AttestationData{}
require.NoError(t, attData.UnmarshalSSZ(attestationDataByts))

ch := make(chan struct{})

go func() {
_, _, err := signer.SignBeaconAttestation(attData, phase0.Domain{0}, pk)
require.NoError(t, err)
close(ch)
}()

ch2 := make(chan struct{})

go func() {
_, _, err := signer.SignBeaconAttestation(attData, domain, pk)
require.NoError(t, err)
close(ch2)

}()

select {
case <-ch2:
case <-time.After(200 * time.Millisecond):
t.Fatal("timeout")
}

select {
case <-ch:
case <-time.After(200 * time.Millisecond):
t.Fatal("timeout")
}

}

func TestManyValidatorsParallel(t *testing.T) {
type testValidator struct {
sk []byte
pk []byte
id string
}

testValidators := []testValidator{
{
sk: _byteArray("2c083f2c8fc923fa2bd32a70ab72b4b46247e8c1f347adc30b2f8036a355086c"),
pk: _byteArray("a9cf360aa15fb1d1d30ee2b578dc5884823c19661886ae8b892775ccb3bd96b7d7345569a2aa0b14e4d015c54a6a0c54"),
id: "1",
},
{
sk: _byteArray("6327b1e58c41d60dd7c3c8b9634204255707c2d12e2513c345001d8926745eea"),
pk: _byteArray("954eb88ed1207f891dc3c28fa6cfdf8f53bf0ed3d838f3476c0900a61314d22d4f0a300da3cd010444dd5183e35a593c"),
id: "2",
},
{
sk: _byteArray("5470813f7deef638dc531188ca89e36976d536f680e89849cd9077fd096e20bc"),
pk: _byteArray("a3862121db5914d7272b0b705e6e3c5336b79e316735661873566245207329c30f9a33d4fb5f5857fc6fd0a368186972"),
id: "3",
},
}

attestationDataByts := _byteArray("000000000000000000000000000000003a43a4bf26fb5947e809c1f24f7dc6857c8ac007e535d48e6e4eca2122fd776b0000000000000000000000000000000000000000000000000000000000000000000000000000000002000000000000003a43a4bf26fb5947e809c1f24f7dc6857c8ac007e535d48e6e4eca2122fd776b")
domain := _byteArray32("0100000081509579e35e84020ad8751eca180b44df470332d3ad17fc6fd52459")

// setup KeyVault
store := inmemStorage()
options := &eth2keymanager.KeyVaultOptions{}
options.SetStorage(store)
options.SetWalletType(core.NDWallet)
vault, err := eth2keymanager.NewKeyVault(options)
require.NoError(t, err)
wallet, err := vault.Wallet()
require.NoError(t, err)

// create accounts
protector := prot.NewNormalProtection(store)
for i := range testValidators {
k, err := core.NewHDKeyFromPrivateKey(testValidators[i].sk, "")
require.NoError(t, err)
require.EqualValues(t, testValidators[i].pk, k.PublicKey().Serialize())

acc := wallets.NewValidatorAccount(testValidators[i].id, k, nil, "", vault.Context)
require.NoError(t, err)
require.EqualValues(t, testValidators[i].pk, acc.ValidatorPublicKey())
require.NoError(t, wallet.AddValidatorAccount(acc))

// setup base attestation data
baseAttData := &phase0.AttestationData{}
require.NoError(t, baseAttData.UnmarshalSSZ(attestationDataByts))
err = protector.UpdateHighestAttestation(acc.ValidatorPublicKey(), baseAttData)
require.NoError(t, err)
}

// setup signer
signer := NewSimpleSigner(wallet, protector, core.PraterNetwork)

// Sign attestation in parallel.
type validatorResult struct {
signs int
errs int
}
var validatorResults = map[string]*validatorResult{}
var mu sync.Mutex
for _, v := range testValidators {
validatorResults[string(v.pk)] = &validatorResult{}
}

var wg sync.WaitGroup
const goroutinesPerValidator = 10
for _, v := range testValidators {
v := v
for i := 0; i < goroutinesPerValidator; i++ {
wg.Add(1)
go func() {
defer wg.Done()

// decode attestation to be signed
attData := &phase0.AttestationData{}
require.NoError(t, attData.UnmarshalSSZ(attestationDataByts))
attData.Slot += phase0.Slot(core.PraterNetwork.SlotsPerEpoch())
attData.Source.Epoch++
attData.Target.Epoch++

_, _, err := signer.SignBeaconAttestation(attData, domain, v.pk)
// require.EqualValues(t, sig, actualSig)

mu.Lock()
defer mu.Unlock()
if err != nil {
validatorResults[string(v.pk)].errs++
require.ErrorContains(t, err, "slashable attestation (HighestAttestationVote), not signing")
} else {
validatorResults[string(v.pk)].signs++
}
}()
}
}
wg.Wait()

for pk, v := range validatorResults {
t.Logf("pk: %x, signs: %d, errs: %d", []byte(pk), v.signs, v.errs)

require.Equal(t, 1, v.signs)
require.Equal(t, goroutinesPerValidator-1, v.errs)
}
}

func TestAttestationSlashingSignatures(t *testing.T) {
t.Run("valid attestation, sign using public key", func(t *testing.T) {
seed, _ := hex.DecodeString("0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1fff")
Expand Down
5 changes: 3 additions & 2 deletions signer/sign_block.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ func (signer *SimpleSigner) SignBlock(block ssz.HashRoot, slot phase0.Slot, doma
}

// 2. lock for current account
signer.lock(account.ID(), "proposal")
defer signer.unlock(account.ID(), "proposal")
val := signer.lock(account.ID(), "proposal")
val.Lock()
defer val.Unlock()

// 3. far future check
if !IsValidFarFutureSlot(signer.network, slot) {
Expand Down
15 changes: 9 additions & 6 deletions signer/sign_sync_committee.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ func (signer *SimpleSigner) SignSyncCommittee(msgBlockRoot []byte, domain phase0
}

// 2. lock for current account
signer.lock(account.ID(), "sync_committee")
defer signer.unlock(account.ID(), "sync_committee")
val := signer.lock(account.ID(), "sync_committee")
val.Lock()
defer val.Unlock()

// 3. sign
sszRoot := SSZBytes(msgBlockRoot)
Expand Down Expand Up @@ -51,8 +52,9 @@ func (signer *SimpleSigner) SignSyncCommitteeSelectionData(data *altair.SyncAggr
}

// 2. lock for current account
signer.lock(account.ID(), "sync_committee_selection_data")
defer signer.unlock(account.ID(), "sync_committee_selection_data")
val := signer.lock(account.ID(), "sync_committee_selection_data")
val.Lock()
defer val.Unlock()

// 3. sign
if data == nil {
Expand Down Expand Up @@ -83,8 +85,9 @@ func (signer *SimpleSigner) SignSyncCommitteeContributionAndProof(contribAndProo
}

// 2. lock for current account
signer.lock(account.ID(), "sync_committee_selection_and_proof")
defer signer.unlock(account.ID(), "sync_committee_selection_and_proof")
val := signer.lock(account.ID(), "sync_committee_selection_and_proof")
val.Lock()
defer val.Unlock()

// 3. sign
if contribAndProof == nil {
Expand Down
16 changes: 3 additions & 13 deletions signer/validator_signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,26 +51,16 @@ func NewSimpleSigner(wallet core.Wallet, slashingProtector core.SlashingProtecto
}

// lock locks signer
func (signer *SimpleSigner) lock(accountID uuid.UUID, operation string) {
func (signer *SimpleSigner) lock(accountID uuid.UUID, operation string) *sync.RWMutex {
signer.mapLock.Lock()
defer signer.mapLock.Unlock()

k := accountID.String() + "_" + operation
if val, ok := signer.signLocks[k]; ok {
val.Lock()
return val
} else {
signer.signLocks[k] = &sync.RWMutex{}
signer.signLocks[k].Lock()
}
}

func (signer *SimpleSigner) unlock(accountID uuid.UUID, operation string) {
signer.mapLock.RLock()
defer signer.mapLock.RUnlock()

k := accountID.String() + "_" + operation
if val, ok := signer.signLocks[k]; ok {
val.Unlock()
return signer.signLocks[k]
}
}

Expand Down
15 changes: 13 additions & 2 deletions wallets/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/hex"
"encoding/json"
"strings"
"sync"

"github.com/google/uuid"
"github.com/pkg/errors"
Expand All @@ -21,6 +22,7 @@ type HDAccount struct {
id uuid.UUID
validationKey *core.HDKey
withdrawalPubKey []byte
contextMtx sync.RWMutex
context *core.WalletContext
}

Expand Down Expand Up @@ -161,7 +163,7 @@ func (account *HDAccount) GetDepositData() (map[string]interface{}, error) {
depositData, root, err := eth1deposit.DepositData(
account.validationKey,
account.withdrawalPubKey,
account.context.Storage.Network(),
account.GetContext().Storage.Network(),
eth1deposit.MaxEffectiveBalanceInGwei,
)
if err != nil {
Expand All @@ -173,11 +175,20 @@ func (account *HDAccount) GetDepositData() (map[string]interface{}, error) {
"signature": strings.TrimPrefix(depositData.Signature.String(), "0x"),
"withdrawalCredentials": hex.EncodeToString(depositData.WithdrawalCredentials),
"depositDataRoot": hex.EncodeToString(root[:]),
"depositContractAddress": account.context.Storage.Network().DepositContractAddress(),
"depositContractAddress": account.GetContext().Storage.Network().DepositContractAddress(),
}, nil
}

// SetContext is the context setter
func (account *HDAccount) SetContext(ctx *core.WalletContext) {
account.contextMtx.Lock()
defer account.contextMtx.Unlock()
account.context = ctx
}

// GetContext is the context getter
func (account *HDAccount) GetContext() *core.WalletContext {
account.contextMtx.RLock()
defer account.contextMtx.RUnlock()
return account.context
}

0 comments on commit e6bc06b

Please sign in to comment.