From 8c293ecaa5e45ed28f48fb23a355acafa67ec7ea Mon Sep 17 00:00:00 2001 From: johnabass Date: Thu, 23 Jun 2022 17:26:50 -0700 Subject: [PATCH 1/5] VerifyPeerCertificate returns success if any certificate passes validation; refactoring tests --- xhttp/xhttpserver/tls.go | 29 ++++++++---- xhttp/xhttpserver/tls_test.go | 87 +++++++++++++++++++---------------- 2 files changed, 69 insertions(+), 47 deletions(-) diff --git a/xhttp/xhttpserver/tls.go b/xhttp/xhttpserver/tls.go index 2feba4c..c5fe5c7 100644 --- a/xhttp/xhttpserver/tls.go +++ b/xhttp/xhttpserver/tls.go @@ -6,6 +6,8 @@ import ( "errors" "io/ioutil" "strings" + + "go.uber.org/multierr" ) var ( @@ -122,23 +124,34 @@ func (pvs PeerVerifiers) Verify(peerCert *x509.Certificate, verifiedChains [][]* } // VerifyPeerCertificate may be used as the closure for crypto/tls.Config.VerifyPeerCertificate -func (pvs PeerVerifiers) VerifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { +// +// If any of the rawCerts passes verification, this method returns nil to indicate that the +// client has supplied a valid certificate. If all rawCerts fail verification or if any certificates +// fail to parse, this method returns an error. +func (pvs PeerVerifiers) VerifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) (err error) { if len(pvs) == 0 { - return nil + return } for _, rawCert := range rawCerts { - peerCert, err := x509.ParseCertificate(rawCert) - if err == nil { - err = pvs.Verify(peerCert, verifiedChains) + peerCert, parseErr := x509.ParseCertificate(rawCert) + if parseErr != nil { + // any parse error invalidates the entire sequence of certificates + err = parseErr + break } - if err != nil { - return err + verifyErr := pvs.Verify(peerCert, verifiedChains) + err = multierr.Append(err, verifyErr) + + if verifyErr == nil { + // we found (1) cert that passes verification, so we're good + err = nil + break } } - return nil + return } // NewPeerVerifiers constructs a chain of verification strategies merged from a set of options with an extra diff --git a/xhttp/xhttpserver/tls_test.go b/xhttp/xhttpserver/tls_test.go index 00621d7..0478474 100644 --- a/xhttp/xhttpserver/tls_test.go +++ b/xhttp/xhttpserver/tls_test.go @@ -10,11 +10,11 @@ import ( "math/big" "math/rand" "os" - "strconv" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" ) func TestPeerVerifyError(t *testing.T) { @@ -26,12 +26,24 @@ func TestPeerVerifyError(t *testing.T) { assert.Equal("expected", err.Error()) } -func testConfiguredPeerVerifierSuccess(t *testing.T) { +type ConfiguredPeerVerifierSuite struct { + suite.Suite +} + +func (suite *ConfiguredPeerVerifierSuite) newConfiguredPeerVerifier(o PeerVerifyOptions) *ConfiguredPeerVerifier { + cpv := NewConfiguredPeerVerifier(o) + suite.Require().NotNil(cpv) + return cpv +} + +func (suite *ConfiguredPeerVerifierSuite) testVerifySuccess() { testData := []struct { - peerCert x509.Certificate - options PeerVerifyOptions + description string + peerCert x509.Certificate + options PeerVerifyOptions }{ { + description: "DNS name", peerCert: x509.Certificate{ DNSNames: []string{"test.foobar.com"}, }, @@ -40,6 +52,7 @@ func testConfiguredPeerVerifierSuccess(t *testing.T) { }, }, { + description: "multiple DNS names", peerCert: x509.Certificate{ DNSNames: []string{"first.foobar.com", "second.something.net"}, }, @@ -48,6 +61,7 @@ func testConfiguredPeerVerifierSuccess(t *testing.T) { }, }, { + description: "common name as host name", peerCert: x509.Certificate{ Subject: pkix.Name{ CommonName: "PCTEST-another.thing.org", @@ -58,6 +72,7 @@ func testConfiguredPeerVerifierSuccess(t *testing.T) { }, }, { + description: "common name", peerCert: x509.Certificate{ Subject: pkix.Name{ CommonName: "A Great Organization", @@ -68,6 +83,7 @@ func testConfiguredPeerVerifierSuccess(t *testing.T) { }, }, { + description: "multiple common names", peerCert: x509.Certificate{ Subject: pkix.Name{ CommonName: "A Great Organization", @@ -79,34 +95,30 @@ func testConfiguredPeerVerifierSuccess(t *testing.T) { }, } - for i, record := range testData { - t.Run(strconv.Itoa(i), func(t *testing.T) { - var ( - assert = assert.New(t) - require = require.New(t) - - verifier = NewConfiguredPeerVerifier(record.options) - ) - - require.NotNil(verifier) - assert.NoError(verifier.Verify(&record.peerCert, nil)) + for _, testCase := range testData { + suite.Run(testCase.description, func() { + verifier := suite.newConfiguredPeerVerifier(testCase.options) + suite.NoError(verifier.Verify(&testCase.peerCert, nil)) }) } } -func testConfiguredPeerVerifierFailure(t *testing.T) { +func (suite *ConfiguredPeerVerifierSuite) testVerifyFailure() { testData := []struct { - peerCert x509.Certificate - options PeerVerifyOptions + description string + peerCert x509.Certificate + options PeerVerifyOptions }{ { - peerCert: x509.Certificate{}, + description: "empty fields", + peerCert: x509.Certificate{}, options: PeerVerifyOptions{ DNSSuffixes: []string{"foobar.net"}, CommonNames: []string{"For Great Justice"}, }, }, { + description: "DNS mismatch", peerCert: x509.Certificate{ DNSNames: []string{"another.company.com"}, }, @@ -116,6 +128,7 @@ func testConfiguredPeerVerifierFailure(t *testing.T) { }, }, { + description: "CommonName mismatch", peerCert: x509.Certificate{ Subject: pkix.Name{ CommonName: "Villains For Hire", @@ -127,6 +140,7 @@ func testConfiguredPeerVerifierFailure(t *testing.T) { }, }, { + description: "DNS and CommonName mismatch", peerCert: x509.Certificate{ DNSNames: []string{"another.company.com"}, Subject: pkix.Name{ @@ -140,36 +154,31 @@ func testConfiguredPeerVerifierFailure(t *testing.T) { }, } - for i, record := range testData { - t.Run(strconv.Itoa(i), func(t *testing.T) { - var ( - assert = assert.New(t) - require = require.New(t) - - verifier = NewConfiguredPeerVerifier(record.options) - ) + for _, testCase := range testData { + suite.Run(testCase.description, func() { + verifier := suite.newConfiguredPeerVerifier(testCase.options) - require.NotNil(verifier) - err := verifier.Verify(&record.peerCert, nil) - assert.Error(err) - require.IsType(PeerVerifyError{}, err) - assert.Equal(&record.peerCert, err.(PeerVerifyError).Certificate) + err := verifier.Verify(&testCase.peerCert, nil) + suite.Error(err) }) } } -func TestConfiguredPeerVerifier(t *testing.T) { - t.Run("Success", testConfiguredPeerVerifierSuccess) - t.Run("Failure", testConfiguredPeerVerifierFailure) +func (suite *ConfiguredPeerVerifierSuite) TestVerify() { + suite.Run("Success", suite.testVerifySuccess) + suite.Run("Failure", suite.testVerifyFailure) } -func TestNewConfiguredPeerVerifier(t *testing.T) { - t.Run("Nil", func(t *testing.T) { - assert := assert.New(t) - assert.Nil(NewConfiguredPeerVerifier(PeerVerifyOptions{})) +func (suite *ConfiguredPeerVerifierSuite) TestNewConfiguredPeerVerifier() { + suite.Run("Nil", func() { + suite.Nil(NewConfiguredPeerVerifier(PeerVerifyOptions{})) }) } +func TestConfiguredPeerVerifier(t *testing.T) { + suite.Run(t, new(ConfiguredPeerVerifierSuite)) +} + func testPeerVerifiersVerifyPeerCertificate(t *testing.T) { t.Run("UnparseableCert", func(t *testing.T) { var ( From f18594021173bacfba5952c0687a9655f6d01833 Mon Sep 17 00:00:00 2001 From: johnabass Date: Fri, 24 Jun 2022 14:08:33 -0700 Subject: [PATCH 2/5] refactored tests to cover more cases --- xhttp/xhttpserver/tls_test.go | 288 ++++++++++++---------------------- 1 file changed, 98 insertions(+), 190 deletions(-) diff --git a/xhttp/xhttpserver/tls_test.go b/xhttp/xhttpserver/tls_test.go index 0478474..566c9a7 100644 --- a/xhttp/xhttpserver/tls_test.go +++ b/xhttp/xhttpserver/tls_test.go @@ -1,6 +1,7 @@ package xhttpserver import ( + "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" @@ -8,7 +9,6 @@ import ( "errors" "fmt" "math/big" - "math/rand" "os" "testing" @@ -179,226 +179,134 @@ func TestConfiguredPeerVerifier(t *testing.T) { suite.Run(t, new(ConfiguredPeerVerifierSuite)) } -func testPeerVerifiersVerifyPeerCertificate(t *testing.T) { - t.Run("UnparseableCert", func(t *testing.T) { - var ( - assert = assert.New(t) +type PeerVerifiersSuite struct { + suite.Suite - unparseable = []byte("unparseable") + key *rsa.PrivateKey +} - m = PeerVerifierFunc(func(*x509.Certificate, [][]*x509.Certificate) error { - assert.Fail("This verifier should not have been called due to an unparseable certificate") - return nil - }) +func (suite *PeerVerifiersSuite) SetupSuite() { + var err error + suite.key, err = rsa.GenerateKey(rand.Reader, 512) + suite.Require().NoError(err) +} - pv = PeerVerifiers{m} - ) +func (suite *PeerVerifiersSuite) createSelfSignedCertificate(template *x509.Certificate) []byte { + if template.SerialNumber == nil { + template.SerialNumber = big.NewInt(1) + } - assert.Error(pv.VerifyPeerCertificate([][]byte{unparseable}, nil)) - }) + raw, err := x509.CreateCertificate( + rand.Reader, + template, + template, + suite.key.PublicKey, + suite.key, + ) + suite.Require().NoError(err) + suite.Require().NotEmpty(raw) + return raw +} + +func (suite *PeerVerifiersSuite) TestUnparseableCertificate() { var ( - random = rand.New(rand.NewSource(1234)) - verifyErr = errors.New("expected Verify error") - - testData = []struct { - results []error - expectedErr error - }{ - { - results: []error{}, - expectedErr: nil, - }, - { - results: []error{nil}, - expectedErr: nil, - }, - { - results: []error{verifyErr}, - expectedErr: verifyErr, - }, - { - results: []error{nil, nil}, - expectedErr: nil, - }, - { - results: []error{nil, verifyErr}, - expectedErr: verifyErr, - }, - { - results: []error{verifyErr, nil}, - expectedErr: verifyErr, - }, - { - results: []error{nil, nil, nil, nil, nil}, - expectedErr: nil, - }, - { - results: []error{nil, nil, verifyErr, nil, nil}, - expectedErr: verifyErr, - }, - { - results: []error{nil, nil, nil, nil, verifyErr}, - expectedErr: verifyErr, - }, - } + unparseable = []byte("unparseable") + + m = PeerVerifierFunc(func(*x509.Certificate, [][]*x509.Certificate) error { + suite.Fail("This verifier should not have been called due to an unparseable certificate") + return nil + }) + + pv = PeerVerifiers{m} ) - for i, record := range testData { - t.Run(fmt.Sprintf("i=%d,len=%d", i, len(record.results)), func(t *testing.T) { + suite.Error(pv.VerifyPeerCertificate([][]byte{unparseable}, nil)) +} + +func (suite *PeerVerifiersSuite) testVerifySuccess() { + for l := 0; l < 3; l++ { + suite.Run(fmt.Sprintf("len=%d", l), func() { var ( - assert = assert.New(t) - require = require.New(t) - peerSerial = rand.Int63() - chainSerial = rand.Int63() - template = stubPeerCert(peerSerial) + callCount int + verifier = PeerVerifierFunc( + func(cert *x509.Certificate, _ [][]*x509.Certificate) error { + callCount++ + return nil + }, + ) pv PeerVerifiers ) - key, err := rsa.GenerateKey(random, 512) - require.NoError(err) - - peerCert, err := x509.CreateCertificate(random, template, template, &key.PublicKey, key) - require.NoError(err) - - errEncountered := false - for _, result := range record.results { - err := result - if errEncountered { - pv = append(pv, PeerVerifierFunc(func(*x509.Certificate, [][]*x509.Certificate) error { - assert.Fail("This verifier should not have been called due to an earlier error") - return err - })) - } else { - pv = append(pv, PeerVerifierFunc(func(peerCert *x509.Certificate, verifiedChains [][]*x509.Certificate) error { - require.NotNil(peerCert) - require.NotNil(peerCert.SerialNumber) - assert.Equal(0, peerCert.SerialNumber.Cmp(big.NewInt(peerSerial))) - - require.Len(verifiedChains, 1) - require.Len(verifiedChains[0], 1) - assert.Equal(0, verifiedChains[0][0].SerialNumber.Cmp(big.NewInt(chainSerial))) - - return err - })) - } - - if result != nil { - errEncountered = true - } + for i := 0; i < l; i++ { + pv = append(pv, verifier) } - assert.Equal( - record.expectedErr, - pv.VerifyPeerCertificate( - [][]byte{peerCert}, - stubChain(chainSerial), - ), - ) + suite.NoError(pv.Verify( + &x509.Certificate{}, + nil, + )) + + suite.Equal(l, callCount) }) } } -func testPeerVerifiersVerify(t *testing.T) { - var ( - verifyErr = errors.New("expected Verify error") - - testData = []struct { - results []error - expectedErr error - }{ - { - results: []error{}, - expectedErr: nil, - }, - { - results: []error{nil}, - expectedErr: nil, - }, - { - results: []error{verifyErr}, - expectedErr: verifyErr, - }, - { - results: []error{nil, nil}, - expectedErr: nil, - }, - { - results: []error{nil, verifyErr}, - expectedErr: verifyErr, - }, - { - results: []error{verifyErr, nil}, - expectedErr: verifyErr, - }, - { - results: []error{nil, nil, nil, nil, nil}, - expectedErr: nil, - }, - { - results: []error{nil, nil, verifyErr, nil, nil}, - expectedErr: verifyErr, - }, - { - results: []error{nil, nil, nil, nil, verifyErr}, - expectedErr: verifyErr, - }, - } - ) - - for i, record := range testData { - t.Run(fmt.Sprintf("i=%d,len=%d", i, len(record.results)), func(t *testing.T) { +func (suite *PeerVerifiersSuite) testVerifyFailure() { + for l := 1; l < 4; l++ { + suite.Run(fmt.Sprintf("len=%d", l), func() { var ( - assert = assert.New(t) - require = require.New(t) - peerSerial = rand.Int63() - chainSerial = rand.Int63() + goodCount int + good = PeerVerifierFunc( + func(cert *x509.Certificate, _ [][]*x509.Certificate) error { + goodCount++ + return nil + }, + ) + + badCount int + bad = PeerVerifierFunc( + func(cert *x509.Certificate, _ [][]*x509.Certificate) error { + badCount++ + return errors.New("expected") + }, + ) + + shouldNotBeCalled = PeerVerifierFunc( + func(cert *x509.Certificate, _ [][]*x509.Certificate) error { + suite.Fail("This peer verifier should not have been called") + return errors.New("this should not have been called") + }, + ) pv PeerVerifiers ) - errEncountered := false - for _, result := range record.results { - err := result - if errEncountered { - pv = append(pv, PeerVerifierFunc(func(*x509.Certificate, [][]*x509.Certificate) error { - assert.Fail("This verifier should not have been called due to an earlier error") - return err - })) - } else { - pv = append(pv, PeerVerifierFunc(func(peerCert *x509.Certificate, verifiedChains [][]*x509.Certificate) error { - require.NotNil(peerCert) - require.NotNil(peerCert.SerialNumber) - assert.Equal(0, peerCert.SerialNumber.Cmp(big.NewInt(peerSerial))) - - require.Len(verifiedChains, 1) - require.Len(verifiedChains[0], 1) - assert.Equal(0, verifiedChains[0][0].SerialNumber.Cmp(big.NewInt(chainSerial))) - - return err - })) - } - - if result != nil { - errEncountered = true - } + for i := 0; i < l-1; i++ { + pv = append(pv, good) } - assert.Equal( - record.expectedErr, - pv.Verify( - stubPeerCert(peerSerial), - stubChain(chainSerial), - ), - ) + pv = append(pv, bad, shouldNotBeCalled) + + suite.Error(pv.Verify( + &x509.Certificate{}, + nil, + )) + + suite.Equal(l-1, goodCount) + suite.Equal(1, badCount) }) } } +func (suite *PeerVerifiersSuite) TestVerify() { + suite.Run("Success", suite.testVerifySuccess) + suite.Run("Failure", suite.testVerifyFailure) +} + func TestPeerVerifiers(t *testing.T) { - t.Run("VerifyPeerCertificate", testPeerVerifiersVerifyPeerCertificate) - t.Run("Verify", testPeerVerifiersVerify) + suite.Run(t, new(PeerVerifiersSuite)) } func TestNewPeerVerifiers(t *testing.T) { From e46d8d19bf07e9f7705505da275736427b96b244 Mon Sep 17 00:00:00 2001 From: johnabass Date: Sat, 25 Jun 2022 15:41:26 -0700 Subject: [PATCH 3/5] refactor into mocks instead of closures --- xhttp/xhttpserver/mocks_test.go | 41 +++++--- xhttp/xhttpserver/tls_test.go | 169 ++++++++++++++++---------------- 2 files changed, 114 insertions(+), 96 deletions(-) diff --git a/xhttp/xhttpserver/mocks_test.go b/xhttp/xhttpserver/mocks_test.go index 231e6fb..3114eaa 100644 --- a/xhttp/xhttpserver/mocks_test.go +++ b/xhttp/xhttpserver/mocks_test.go @@ -4,10 +4,11 @@ import ( "bufio" "context" "crypto/x509" - "math/big" "net" "net/http" + "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -134,18 +135,36 @@ func (m *mockServer) ExpectShutdown(p ...interface{}) *mock.Call { return m.On("Shutdown", p...) } -func stubPeerCert(serialNumber int64) *x509.Certificate { - return &x509.Certificate{ - SerialNumber: big.NewInt(serialNumber), +func newCertificateMatcher(t *testing.T, commonName string, dnsNames ...string) func(*x509.Certificate) bool { + return func(actual *x509.Certificate) bool { + result := assert.Equal(t, commonName, actual.Subject.CommonName, "Subject common names do not match") + result = result && assert.Equal(t, dnsNames, actual.DNSNames, "DNS names do not match") + + return result } } -func stubChain(serialNumber int64) [][]*x509.Certificate { - return [][]*x509.Certificate{ - []*x509.Certificate{ - &x509.Certificate{ - SerialNumber: big.NewInt(serialNumber), - }, - }, +type mockPeerVerifier struct { + mock.Mock +} + +func (m *mockPeerVerifier) Verify(peerCert *x509.Certificate, verifiedChains [][]*x509.Certificate) error { + return m.Called(peerCert, verifiedChains).Error(0) +} + +// ExpectVerify sets up the a mocked call to Verify with a peer certificate with the given +// subject common name and dns names. Since this package doesn't use any other fields, +// this expectation suffices for tests. +func (m *mockPeerVerifier) ExpectVerify(certificateMatcher func(*x509.Certificate) bool) *mock.Call { + return m.On( + "Verify", + mock.MatchedBy(certificateMatcher), + [][]*x509.Certificate(nil), // we always pass nil in tests, since we don't use this parameter + ) +} + +func assertPeerVerifierExpectations(t *testing.T, pvs ...PeerVerifier) { + for _, pv := range pvs { + pv.(*mockPeerVerifier).AssertExpectations(t) } } diff --git a/xhttp/xhttpserver/tls_test.go b/xhttp/xhttpserver/tls_test.go index 566c9a7..d09a1d5 100644 --- a/xhttp/xhttpserver/tls_test.go +++ b/xhttp/xhttpserver/tls_test.go @@ -182,127 +182,126 @@ func TestConfiguredPeerVerifier(t *testing.T) { type PeerVerifiersSuite struct { suite.Suite - key *rsa.PrivateKey + key *rsa.PrivateKey + testCerts []*x509.Certificate + rawCerts [][]byte } func (suite *PeerVerifiersSuite) SetupSuite() { var err error suite.key, err = rsa.GenerateKey(rand.Reader, 512) suite.Require().NoError(err) -} -func (suite *PeerVerifiersSuite) createSelfSignedCertificate(template *x509.Certificate) []byte { - if template.SerialNumber == nil { - template.SerialNumber = big.NewInt(1) + suite.testCerts = make([]*x509.Certificate, 3) + suite.rawCerts = make([][]byte, len(suite.testCerts)) + + for i := 0; i < len(suite.testCerts); i++ { + suite.testCerts[i] = &x509.Certificate{ + SerialNumber: big.NewInt(int64(i + 1)), + DNSNames: []string{ + fmt.Sprintf("host-%d.net", i), + }, + Subject: pkix.Name{ + CommonName: fmt.Sprintf("Organization #%d", i), + }, + } + + var err error + suite.rawCerts[i], err = x509.CreateCertificate( + rand.Reader, + suite.testCerts[i], + suite.testCerts[i], + &suite.key.PublicKey, + suite.key, + ) + + suite.Require().NoError(err) + suite.Require().NotEmpty(suite.rawCerts[i]) } +} - raw, err := x509.CreateCertificate( - rand.Reader, - template, - template, - suite.key.PublicKey, - suite.key, +// useCertificate gives some syntactic sugar for expecting a peer cert +func (suite *PeerVerifiersSuite) useCertificate(expected *x509.Certificate) func(*x509.Certificate) bool { + return newCertificateMatcher( + suite.T(), + expected.Subject.CommonName, + expected.DNSNames..., ) +} - suite.Require().NoError(err) - suite.Require().NotEmpty(raw) - return raw +func (suite *PeerVerifiersSuite) expectVerify(expected *x509.Certificate, result error) *mockPeerVerifier { + m := new(mockPeerVerifier) + m.ExpectVerify( + suite.useCertificate(expected), + ).Return(result) + + return m } func (suite *PeerVerifiersSuite) TestUnparseableCertificate() { var ( unparseable = []byte("unparseable") - m = PeerVerifierFunc(func(*x509.Certificate, [][]*x509.Certificate) error { - suite.Fail("This verifier should not have been called due to an unparseable certificate") - return nil - }) - + m = new(mockPeerVerifier) // no calls pv = PeerVerifiers{m} ) suite.Error(pv.VerifyPeerCertificate([][]byte{unparseable}, nil)) + m.AssertExpectations(suite.T()) } -func (suite *PeerVerifiersSuite) testVerifySuccess() { - for l := 0; l < 3; l++ { - suite.Run(fmt.Sprintf("len=%d", l), func() { - var ( - callCount int - verifier = PeerVerifierFunc( - func(cert *x509.Certificate, _ [][]*x509.Certificate) error { - callCount++ - return nil - }, - ) - - pv PeerVerifiers - ) - - for i := 0; i < l; i++ { - pv = append(pv, verifier) - } +func (suite *PeerVerifiersSuite) testVerifySuccess(expected *x509.Certificate) { + for count := 0; count < 3; count++ { + var pv PeerVerifiers + for i := 0; i < count; i++ { + pv = append(pv, suite.expectVerify(expected, nil)) + } - suite.NoError(pv.Verify( - &x509.Certificate{}, - nil, - )) + suite.NoError( + pv.Verify(expected, nil), + ) - suite.Equal(l, callCount) - }) + assertPeerVerifierExpectations(suite.T(), pv...) } } -func (suite *PeerVerifiersSuite) testVerifyFailure() { - for l := 1; l < 4; l++ { - suite.Run(fmt.Sprintf("len=%d", l), func() { - var ( - goodCount int - good = PeerVerifierFunc( - func(cert *x509.Certificate, _ [][]*x509.Certificate) error { - goodCount++ - return nil - }, - ) - - badCount int - bad = PeerVerifierFunc( - func(cert *x509.Certificate, _ [][]*x509.Certificate) error { - badCount++ - return errors.New("expected") - }, - ) - - shouldNotBeCalled = PeerVerifierFunc( - func(cert *x509.Certificate, _ [][]*x509.Certificate) error { - suite.Fail("This peer verifier should not have been called") - return errors.New("this should not have been called") - }, - ) - - pv PeerVerifiers - ) +func (suite *PeerVerifiersSuite) testVerifyFailure(expected *x509.Certificate) { + for count := 0; count < 3; count++ { + var pv PeerVerifiers - for i := 0; i < l-1; i++ { - pv = append(pv, good) - } + // setup our "good" calls + for i := 0; i < count; i++ { + pv = append(pv, suite.expectVerify(expected, nil)) + } - pv = append(pv, bad, shouldNotBeCalled) + // a failure, followed by a verifier that shouldn't be called + pv = append(pv, + suite.expectVerify(expected, errors.New("expected")), + new(mockPeerVerifier), + ) - suite.Error(pv.Verify( - &x509.Certificate{}, - nil, - )) + suite.Error( + pv.Verify(expected, nil), + ) - suite.Equal(l-1, goodCount) - suite.Equal(1, badCount) - }) + assertPeerVerifierExpectations(suite.T(), pv...) } } func (suite *PeerVerifiersSuite) TestVerify() { - suite.Run("Success", suite.testVerifySuccess) - suite.Run("Failure", suite.testVerifyFailure) + suite.Run("Success", func() { + suite.testVerifySuccess(suite.testCerts[0]) + }) + + suite.Run("Failure", func() { + suite.testVerifyFailure(suite.testCerts[0]) + }) +} + +func (suite *PeerVerifiersSuite) testVerifyPeerCertificateAllGood(testCerts []*x509.Certificate, rawCerts [][]byte) { +} + +func (suite *PeerVerifiersSuite) TestVerifyPeerCertificate() { } func TestPeerVerifiers(t *testing.T) { From da6bc47dc7a38100a7127f7b6070cfc25e38730c Mon Sep 17 00:00:00 2001 From: johnabass Date: Sun, 26 Jun 2022 19:05:42 -0700 Subject: [PATCH 4/5] fleshed out VerifyPeerCertificate cases --- xhttp/xhttpserver/mocks_test.go | 21 ++++- xhttp/xhttpserver/tls_test.go | 151 ++++++++++++++++++++++++++------ 2 files changed, 143 insertions(+), 29 deletions(-) diff --git a/xhttp/xhttpserver/mocks_test.go b/xhttp/xhttpserver/mocks_test.go index 3114eaa..54b5253 100644 --- a/xhttp/xhttpserver/mocks_test.go +++ b/xhttp/xhttpserver/mocks_test.go @@ -8,7 +8,6 @@ import ( "net/http" "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -137,10 +136,24 @@ func (m *mockServer) ExpectShutdown(p ...interface{}) *mock.Call { func newCertificateMatcher(t *testing.T, commonName string, dnsNames ...string) func(*x509.Certificate) bool { return func(actual *x509.Certificate) bool { - result := assert.Equal(t, commonName, actual.Subject.CommonName, "Subject common names do not match") - result = result && assert.Equal(t, dnsNames, actual.DNSNames, "DNS names do not match") + t.Logf("Testing cert: Subject.CommonName=%s, DNSNames=%s", actual.Subject.CommonName, actual.DNSNames) - return result + switch { + case commonName != actual.Subject.CommonName: + return false + + case len(dnsNames) != len(actual.DNSNames): + return false + + default: + for i := 0; i < len(dnsNames); i++ { + if dnsNames[i] != actual.DNSNames[i] { + return false + } + } + } + + return true } } diff --git a/xhttp/xhttpserver/tls_test.go b/xhttp/xhttpserver/tls_test.go index d09a1d5..e0b7113 100644 --- a/xhttp/xhttpserver/tls_test.go +++ b/xhttp/xhttpserver/tls_test.go @@ -26,6 +26,16 @@ func TestPeerVerifyError(t *testing.T) { assert.Equal("expected", err.Error()) } +func TestPeerVerifierFunc(t *testing.T) { + testCert := new(x509.Certificate) + pvf := PeerVerifierFunc(func(actual *x509.Certificate, verifiedChains [][]*x509.Certificate) error { + assert.Equal(t, testCert, actual) + return nil + }) + + assert.NoError(t, pvf.Verify(testCert, nil)) +} + type ConfiguredPeerVerifierSuite struct { suite.Suite } @@ -252,39 +262,43 @@ func (suite *PeerVerifiersSuite) TestUnparseableCertificate() { func (suite *PeerVerifiersSuite) testVerifySuccess(expected *x509.Certificate) { for count := 0; count < 3; count++ { - var pv PeerVerifiers - for i := 0; i < count; i++ { - pv = append(pv, suite.expectVerify(expected, nil)) - } - - suite.NoError( - pv.Verify(expected, nil), - ) + suite.Run(fmt.Sprintf("verifiers=%d", count), func() { + var pv PeerVerifiers + for i := 0; i < count; i++ { + pv = append(pv, suite.expectVerify(expected, nil)) + } + + suite.NoError( + pv.Verify(expected, nil), + ) - assertPeerVerifierExpectations(suite.T(), pv...) + assertPeerVerifierExpectations(suite.T(), pv...) + }) } } func (suite *PeerVerifiersSuite) testVerifyFailure(expected *x509.Certificate) { for count := 0; count < 3; count++ { - var pv PeerVerifiers - - // setup our "good" calls - for i := 0; i < count; i++ { - pv = append(pv, suite.expectVerify(expected, nil)) - } - - // a failure, followed by a verifier that shouldn't be called - pv = append(pv, - suite.expectVerify(expected, errors.New("expected")), - new(mockPeerVerifier), - ) + suite.Run(fmt.Sprintf("goodVerifiers=%d", count), func() { + var pv PeerVerifiers + + // setup our "good" calls + for i := 0; i < count; i++ { + pv = append(pv, suite.expectVerify(expected, nil)) + } + + // a failure, followed by a verifier that shouldn't be called + pv = append(pv, + suite.expectVerify(expected, errors.New("expected")), + new(mockPeerVerifier), + ) - suite.Error( - pv.Verify(expected, nil), - ) + suite.Error( + pv.Verify(expected, nil), + ) - assertPeerVerifierExpectations(suite.T(), pv...) + assertPeerVerifierExpectations(suite.T(), pv...) + }) } } @@ -299,9 +313,96 @@ func (suite *PeerVerifiersSuite) TestVerify() { } func (suite *PeerVerifiersSuite) testVerifyPeerCertificateAllGood(testCerts []*x509.Certificate, rawCerts [][]byte) { + for count := 0; count < 3; count++ { + suite.Run(fmt.Sprintf("verifiers=%d", count), func() { + var pv PeerVerifiers + for i := 0; i < count; i++ { + m := new(mockPeerVerifier) + for j := 0; j < len(testCerts); j++ { + // maybe is used here because a success short-circuits subsequent calls + m.ExpectVerify(suite.useCertificate(testCerts[j])).Return(error(nil)).Maybe() + } + + pv = append(pv, m) + } + + suite.NoError( + pv.VerifyPeerCertificate(rawCerts, nil), + ) + + assertPeerVerifierExpectations(suite.T(), pv...) + }) + } +} + +func (suite *PeerVerifiersSuite) testVerifyPeerCertificateAllBad(testCerts []*x509.Certificate, rawCerts [][]byte) { + for count := 1; count < 3; count++ { + suite.Run(fmt.Sprintf("verifiers=%d", count), func() { + var pv PeerVerifiers + for i := 0; i < count; i++ { + m := new(mockPeerVerifier) + for j := 0; j < len(testCerts); j++ { + // maybe is used here because a failure short-circuits subsequent calls + m.ExpectVerify(suite.useCertificate(testCerts[j])).Return(errors.New("expected")).Maybe() + } + + pv = append(pv, m) + } + + suite.Error( + pv.VerifyPeerCertificate(rawCerts, nil), + ) + + assertPeerVerifierExpectations(suite.T(), pv...) + }) + } +} + +func (suite *PeerVerifiersSuite) testVerifyPeerCertificateOneGood(testCerts []*x509.Certificate, rawCerts [][]byte) { + // a verifier that passes any but the first cert + oneGood := new(mockPeerVerifier) + oneGood.ExpectVerify(func(actual *x509.Certificate) bool { + return actual == testCerts[0] + }).Return(errors.New("oneGood: first cert should fail")).Maybe() + oneGood.ExpectVerify(func(actual *x509.Certificate) bool { + return actual != testCerts[0] + }).Return(error(nil)) + + pv := PeerVerifiers{ + oneGood, + } + + suite.NoError( + pv.VerifyPeerCertificate(rawCerts, nil), + ) + + assertPeerVerifierExpectations(suite.T(), pv...) } func (suite *PeerVerifiersSuite) TestVerifyPeerCertificate() { + suite.Run("AllGood", func() { + for i := 1; i < len(suite.testCerts); i++ { + suite.Run(fmt.Sprintf("certs=%d", i), func() { + suite.testVerifyPeerCertificateAllGood(suite.testCerts[0:i], suite.rawCerts[0:i]) + }) + } + }) + + suite.Run("AllBad", func() { + for i := 1; i < len(suite.testCerts); i++ { + suite.Run(fmt.Sprintf("certs=%d", i), func() { + suite.testVerifyPeerCertificateAllBad(suite.testCerts[0:i], suite.rawCerts[0:i]) + }) + } + }) + + suite.Run("OneGood", func() { + for i := 2; i < len(suite.testCerts); i++ { + suite.Run(fmt.Sprintf("certs=%d", i), func() { + suite.testVerifyPeerCertificateOneGood(suite.testCerts[0:i], suite.rawCerts[0:i]) + }) + } + }) } func TestPeerVerifiers(t *testing.T) { From b52d75bcb2a8aadb060f9362c3af06eb1badb064 Mon Sep 17 00:00:00 2001 From: johnabass Date: Sun, 26 Jun 2022 19:09:27 -0700 Subject: [PATCH 5/5] updated --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d1fdd06..b81f869 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Migrated to github.com/golang-jwt/jwt to address a security vulnerability. [#78](https://github.com/xmidt-org/themis/pull/78) - Updated spec file and rpkg version macro to be able to choose when the 'v' is included in the version. [#80](https://github.com/xmidt-org/themis/pull/80) - Updated transport.go to send a 400 error if there is an issue parsing the URL. [#47](https://github.com/xmidt-org/themis/issues/47) +- Allow any peer certificate to pass validation, instead of requiring all of them to pass. [#91](https://github.com/xmidt-org/themis/issues/91) ## [v0.4.7]