Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

types: Move complexity check to (SpendPolicy).Verify #236

Merged
merged 2 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions types/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -1113,10 +1113,7 @@ func (txn *Transaction) DecodeFrom(d *Decoder) {

// DecodeFrom implements types.DecoderFrom.
func (p *SpendPolicy) DecodeFrom(d *Decoder) {
const (
version = 1
maxPolicies = 1024
)
const version = 1
const (
opInvalid = iota
opAbove
Expand All @@ -1128,7 +1125,6 @@ func (p *SpendPolicy) DecodeFrom(d *Decoder) {
opUnlockConditions
)

var totalPolicies int
var readPolicy func() (SpendPolicy, error)
readPolicy = func() (SpendPolicy, error) {
switch op := d.ReadUint8(); op {
Expand All @@ -1147,9 +1143,6 @@ func (p *SpendPolicy) DecodeFrom(d *Decoder) {
case opThreshold:
n := d.ReadUint8()
of := make([]SpendPolicy, d.ReadUint8())
if totalPolicies += len(of); totalPolicies > maxPolicies {
return SpendPolicy{}, errors.New("policy is too complex")
}
var err error
for i := range of {
if of[i], err = readPolicy(); err != nil {
Expand Down
35 changes: 23 additions & 12 deletions types/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,9 @@ func (p SpendPolicy) Verify(height uint64, medianTimestamp time.Time, sigHash Ha
}
return
}
errInvalidSignature := errors.New("invalid signature")
lukechampine marked this conversation as resolved.
Show resolved Hide resolved
errInvalidPreimage := errors.New("invalid preimage")
const maxPolicies = 1024
var totalPolicies int
errOpaque := errors.New("opaque policy")
var verify func(SpendPolicy) error
verify = func(p SpendPolicy) error {
switch p := p.Type.(type) {
Expand All @@ -151,28 +152,38 @@ func (p SpendPolicy) Verify(height uint64, medianTimestamp time.Time, sigHash Ha
if sig, ok := nextSig(); ok && PublicKey(p).VerifyHash(sigHash, sig) {
return nil
}
return errInvalidSignature
return errors.New("invalid signature")
case PolicyTypeHash:
if preimage, ok := nextPreimage(); ok && p == sha256.Sum256(preimage[:]) {
return nil
}
return errInvalidPreimage
return errors.New("invalid preimage")
case PolicyTypeThreshold:
for i := 0; i < len(p.Of) && p.N > 0 && len(p.Of[i:]) >= int(p.N); i++ {
if _, ok := p.Of[i].Type.(PolicyTypeUnlockConditions); ok {
if totalPolicies += len(p.Of); totalPolicies > maxPolicies || len(p.Of) > 255 {
return errors.New("policy is too complex")
}
var satisfied uint8
for _, sp := range p.Of {
switch sp.Type.(type) {
case PolicyTypeUnlockConditions:
return errors.New("unlock conditions cannot be sub-policies")
} else if err := verify(p.Of[i]); err == errInvalidSignature || err == errInvalidPreimage {
return err // fatal; should have been opaque
} else if err == nil {
p.N--
case PolicyTypeOpaque:
continue
default:
if satisfied == p.N {
return errors.New("threshold exceeded")
} else if err := verify(sp); err != nil {
return err // fatal; should have been opaque
}
satisfied++
}
}
if p.N == 0 {
if satisfied == p.N {
return nil
}
return errors.New("threshold not reached")
case PolicyTypeOpaque:
return errors.New("opaque policy")
return errOpaque
case PolicyTypeUnlockConditions:
if err := verify(PolicyAbove(p.Timelock)); err != nil {
return err
Expand Down
2 changes: 1 addition & 1 deletion types/policy_test.go
lukechampine marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ func TestPolicyVerify(t *testing.T) {
if err := test.p.Verify(test.height, time.Time{}, sigHash, test.sigs, nil); err != nil && test.valid {
t.Fatalf("%v: %v", test.desc, err)
} else if err == nil && !test.valid {
t.Fatal("expected error")
t.Fatalf("%v: expected error", test.desc)
}
}
}
Expand Down