diff --git a/types/encoding.go b/types/encoding.go index 58586630..0c8677df 100644 --- a/types/encoding.go +++ b/types/encoding.go @@ -1113,10 +1113,7 @@ func (txn *Transaction) DecodeFrom(d *Decoder) { // DecodeFrom implements types.DecoderFrom. func (p *SpendPolicy) DecodeFrom(d *Decoder) { - const ( - version = 1 - maxPolicies = 1024 - ) + const version = 1 const ( opInvalid = iota opAbove @@ -1128,7 +1125,6 @@ func (p *SpendPolicy) DecodeFrom(d *Decoder) { opUnlockConditions ) - var totalPolicies int var readPolicy func() (SpendPolicy, error) readPolicy = func() (SpendPolicy, error) { switch op := d.ReadUint8(); op { @@ -1147,9 +1143,6 @@ func (p *SpendPolicy) DecodeFrom(d *Decoder) { case opThreshold: n := d.ReadUint8() of := make([]SpendPolicy, d.ReadUint8()) - if totalPolicies += len(of); totalPolicies > maxPolicies { - return SpendPolicy{}, errors.New("policy is too complex") - } var err error for i := range of { if of[i], err = readPolicy(); err != nil { diff --git a/types/policy.go b/types/policy.go index 794dc85e..ec6e7242 100644 --- a/types/policy.go +++ b/types/policy.go @@ -132,8 +132,9 @@ func (p SpendPolicy) Verify(height uint64, medianTimestamp time.Time, sigHash Ha } return } - errInvalidSignature := errors.New("invalid signature") - errInvalidPreimage := errors.New("invalid preimage") + const maxPolicies = 1024 + var totalPolicies int + errOpaque := errors.New("opaque policy") var verify func(SpendPolicy) error verify = func(p SpendPolicy) error { switch p := p.Type.(type) { @@ -151,28 +152,38 @@ func (p SpendPolicy) Verify(height uint64, medianTimestamp time.Time, sigHash Ha if sig, ok := nextSig(); ok && PublicKey(p).VerifyHash(sigHash, sig) { return nil } - return errInvalidSignature + return errors.New("invalid signature") case PolicyTypeHash: if preimage, ok := nextPreimage(); ok && p == sha256.Sum256(preimage[:]) { return nil } - return errInvalidPreimage + return errors.New("invalid preimage") case PolicyTypeThreshold: - for i := 0; i < len(p.Of) && p.N > 0 && len(p.Of[i:]) >= int(p.N); i++ { - if _, ok := p.Of[i].Type.(PolicyTypeUnlockConditions); ok { + if totalPolicies += len(p.Of); totalPolicies > maxPolicies || len(p.Of) > 255 { + return errors.New("policy is too complex") + } + var satisfied uint8 + for _, sp := range p.Of { + switch sp.Type.(type) { + case PolicyTypeUnlockConditions: return errors.New("unlock conditions cannot be sub-policies") - } else if err := verify(p.Of[i]); err == errInvalidSignature || err == errInvalidPreimage { - return err // fatal; should have been opaque - } else if err == nil { - p.N-- + case PolicyTypeOpaque: + continue + default: + if satisfied == p.N { + return errors.New("threshold exceeded") + } else if err := verify(sp); err != nil { + return err // fatal; should have been opaque + } + satisfied++ } } - if p.N == 0 { + if satisfied == p.N { return nil } return errors.New("threshold not reached") case PolicyTypeOpaque: - return errors.New("opaque policy") + return errOpaque case PolicyTypeUnlockConditions: if err := verify(PolicyAbove(p.Timelock)); err != nil { return err diff --git a/types/policy_test.go b/types/policy_test.go index 1ec8179a..8e3f700e 100644 --- a/types/policy_test.go +++ b/types/policy_test.go @@ -264,7 +264,7 @@ func TestPolicyVerify(t *testing.T) { if err := test.p.Verify(test.height, time.Time{}, sigHash, test.sigs, nil); err != nil && test.valid { t.Fatalf("%v: %v", test.desc, err) } else if err == nil && !test.valid { - t.Fatal("expected error") + t.Fatalf("%v: expected error", test.desc) } } }