diff --git a/CHANGELOG.md b/CHANGELOG.md index d1fdd06..b81f869 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Migrated to github.com/golang-jwt/jwt to address a security vulnerability. [#78](https://github.com/xmidt-org/themis/pull/78) - Updated spec file and rpkg version macro to be able to choose when the 'v' is included in the version. [#80](https://github.com/xmidt-org/themis/pull/80) - Updated transport.go to send a 400 error if there is an issue parsing the URL. [#47](https://github.com/xmidt-org/themis/issues/47) +- Allow any peer certificate to pass validation, instead of requiring all of them to pass. [#91](https://github.com/xmidt-org/themis/issues/91) ## [v0.4.7] diff --git a/xhttp/xhttpserver/mocks_test.go b/xhttp/xhttpserver/mocks_test.go index 231e6fb..54b5253 100644 --- a/xhttp/xhttpserver/mocks_test.go +++ b/xhttp/xhttpserver/mocks_test.go @@ -4,9 +4,9 @@ import ( "bufio" "context" "crypto/x509" - "math/big" "net" "net/http" + "testing" "github.com/stretchr/testify/mock" ) @@ -134,18 +134,50 @@ func (m *mockServer) ExpectShutdown(p ...interface{}) *mock.Call { return m.On("Shutdown", p...) } -func stubPeerCert(serialNumber int64) *x509.Certificate { - return &x509.Certificate{ - SerialNumber: big.NewInt(serialNumber), +func newCertificateMatcher(t *testing.T, commonName string, dnsNames ...string) func(*x509.Certificate) bool { + return func(actual *x509.Certificate) bool { + t.Logf("Testing cert: Subject.CommonName=%s, DNSNames=%s", actual.Subject.CommonName, actual.DNSNames) + + switch { + case commonName != actual.Subject.CommonName: + return false + + case len(dnsNames) != len(actual.DNSNames): + return false + + default: + for i := 0; i < len(dnsNames); i++ { + if dnsNames[i] != actual.DNSNames[i] { + return false + } + } + } + + return true } } -func stubChain(serialNumber int64) [][]*x509.Certificate { - return [][]*x509.Certificate{ - []*x509.Certificate{ - &x509.Certificate{ - SerialNumber: big.NewInt(serialNumber), - }, - }, +type mockPeerVerifier struct { + mock.Mock +} + +func (m *mockPeerVerifier) Verify(peerCert *x509.Certificate, verifiedChains [][]*x509.Certificate) error { + return m.Called(peerCert, verifiedChains).Error(0) +} + +// ExpectVerify sets up the a mocked call to Verify with a peer certificate with the given +// subject common name and dns names. Since this package doesn't use any other fields, +// this expectation suffices for tests. +func (m *mockPeerVerifier) ExpectVerify(certificateMatcher func(*x509.Certificate) bool) *mock.Call { + return m.On( + "Verify", + mock.MatchedBy(certificateMatcher), + [][]*x509.Certificate(nil), // we always pass nil in tests, since we don't use this parameter + ) +} + +func assertPeerVerifierExpectations(t *testing.T, pvs ...PeerVerifier) { + for _, pv := range pvs { + pv.(*mockPeerVerifier).AssertExpectations(t) } } diff --git a/xhttp/xhttpserver/tls.go b/xhttp/xhttpserver/tls.go index 2feba4c..c5fe5c7 100644 --- a/xhttp/xhttpserver/tls.go +++ b/xhttp/xhttpserver/tls.go @@ -6,6 +6,8 @@ import ( "errors" "io/ioutil" "strings" + + "go.uber.org/multierr" ) var ( @@ -122,23 +124,34 @@ func (pvs PeerVerifiers) Verify(peerCert *x509.Certificate, verifiedChains [][]* } // VerifyPeerCertificate may be used as the closure for crypto/tls.Config.VerifyPeerCertificate -func (pvs PeerVerifiers) VerifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { +// +// If any of the rawCerts passes verification, this method returns nil to indicate that the +// client has supplied a valid certificate. If all rawCerts fail verification or if any certificates +// fail to parse, this method returns an error. +func (pvs PeerVerifiers) VerifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) (err error) { if len(pvs) == 0 { - return nil + return } for _, rawCert := range rawCerts { - peerCert, err := x509.ParseCertificate(rawCert) - if err == nil { - err = pvs.Verify(peerCert, verifiedChains) + peerCert, parseErr := x509.ParseCertificate(rawCert) + if parseErr != nil { + // any parse error invalidates the entire sequence of certificates + err = parseErr + break } - if err != nil { - return err + verifyErr := pvs.Verify(peerCert, verifiedChains) + err = multierr.Append(err, verifyErr) + + if verifyErr == nil { + // we found (1) cert that passes verification, so we're good + err = nil + break } } - return nil + return } // NewPeerVerifiers constructs a chain of verification strategies merged from a set of options with an extra diff --git a/xhttp/xhttpserver/tls_test.go b/xhttp/xhttpserver/tls_test.go index 00621d7..e0b7113 100644 --- a/xhttp/xhttpserver/tls_test.go +++ b/xhttp/xhttpserver/tls_test.go @@ -1,6 +1,7 @@ package xhttpserver import ( + "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" @@ -8,13 +9,12 @@ import ( "errors" "fmt" "math/big" - "math/rand" "os" - "strconv" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" ) func TestPeerVerifyError(t *testing.T) { @@ -26,12 +26,34 @@ func TestPeerVerifyError(t *testing.T) { assert.Equal("expected", err.Error()) } -func testConfiguredPeerVerifierSuccess(t *testing.T) { +func TestPeerVerifierFunc(t *testing.T) { + testCert := new(x509.Certificate) + pvf := PeerVerifierFunc(func(actual *x509.Certificate, verifiedChains [][]*x509.Certificate) error { + assert.Equal(t, testCert, actual) + return nil + }) + + assert.NoError(t, pvf.Verify(testCert, nil)) +} + +type ConfiguredPeerVerifierSuite struct { + suite.Suite +} + +func (suite *ConfiguredPeerVerifierSuite) newConfiguredPeerVerifier(o PeerVerifyOptions) *ConfiguredPeerVerifier { + cpv := NewConfiguredPeerVerifier(o) + suite.Require().NotNil(cpv) + return cpv +} + +func (suite *ConfiguredPeerVerifierSuite) testVerifySuccess() { testData := []struct { - peerCert x509.Certificate - options PeerVerifyOptions + description string + peerCert x509.Certificate + options PeerVerifyOptions }{ { + description: "DNS name", peerCert: x509.Certificate{ DNSNames: []string{"test.foobar.com"}, }, @@ -40,6 +62,7 @@ func testConfiguredPeerVerifierSuccess(t *testing.T) { }, }, { + description: "multiple DNS names", peerCert: x509.Certificate{ DNSNames: []string{"first.foobar.com", "second.something.net"}, }, @@ -48,6 +71,7 @@ func testConfiguredPeerVerifierSuccess(t *testing.T) { }, }, { + description: "common name as host name", peerCert: x509.Certificate{ Subject: pkix.Name{ CommonName: "PCTEST-another.thing.org", @@ -58,6 +82,7 @@ func testConfiguredPeerVerifierSuccess(t *testing.T) { }, }, { + description: "common name", peerCert: x509.Certificate{ Subject: pkix.Name{ CommonName: "A Great Organization", @@ -68,6 +93,7 @@ func testConfiguredPeerVerifierSuccess(t *testing.T) { }, }, { + description: "multiple common names", peerCert: x509.Certificate{ Subject: pkix.Name{ CommonName: "A Great Organization", @@ -79,34 +105,30 @@ func testConfiguredPeerVerifierSuccess(t *testing.T) { }, } - for i, record := range testData { - t.Run(strconv.Itoa(i), func(t *testing.T) { - var ( - assert = assert.New(t) - require = require.New(t) - - verifier = NewConfiguredPeerVerifier(record.options) - ) - - require.NotNil(verifier) - assert.NoError(verifier.Verify(&record.peerCert, nil)) + for _, testCase := range testData { + suite.Run(testCase.description, func() { + verifier := suite.newConfiguredPeerVerifier(testCase.options) + suite.NoError(verifier.Verify(&testCase.peerCert, nil)) }) } } -func testConfiguredPeerVerifierFailure(t *testing.T) { +func (suite *ConfiguredPeerVerifierSuite) testVerifyFailure() { testData := []struct { - peerCert x509.Certificate - options PeerVerifyOptions + description string + peerCert x509.Certificate + options PeerVerifyOptions }{ { - peerCert: x509.Certificate{}, + description: "empty fields", + peerCert: x509.Certificate{}, options: PeerVerifyOptions{ DNSSuffixes: []string{"foobar.net"}, CommonNames: []string{"For Great Justice"}, }, }, { + description: "DNS mismatch", peerCert: x509.Certificate{ DNSNames: []string{"another.company.com"}, }, @@ -116,6 +138,7 @@ func testConfiguredPeerVerifierFailure(t *testing.T) { }, }, { + description: "CommonName mismatch", peerCert: x509.Certificate{ Subject: pkix.Name{ CommonName: "Villains For Hire", @@ -127,6 +150,7 @@ func testConfiguredPeerVerifierFailure(t *testing.T) { }, }, { + description: "DNS and CommonName mismatch", peerCert: x509.Certificate{ DNSNames: []string{"another.company.com"}, Subject: pkix.Name{ @@ -140,256 +164,249 @@ func testConfiguredPeerVerifierFailure(t *testing.T) { }, } - for i, record := range testData { - t.Run(strconv.Itoa(i), func(t *testing.T) { - var ( - assert = assert.New(t) - require = require.New(t) - - verifier = NewConfiguredPeerVerifier(record.options) - ) + for _, testCase := range testData { + suite.Run(testCase.description, func() { + verifier := suite.newConfiguredPeerVerifier(testCase.options) - require.NotNil(verifier) - err := verifier.Verify(&record.peerCert, nil) - assert.Error(err) - require.IsType(PeerVerifyError{}, err) - assert.Equal(&record.peerCert, err.(PeerVerifyError).Certificate) + err := verifier.Verify(&testCase.peerCert, nil) + suite.Error(err) }) } } -func TestConfiguredPeerVerifier(t *testing.T) { - t.Run("Success", testConfiguredPeerVerifierSuccess) - t.Run("Failure", testConfiguredPeerVerifierFailure) +func (suite *ConfiguredPeerVerifierSuite) TestVerify() { + suite.Run("Success", suite.testVerifySuccess) + suite.Run("Failure", suite.testVerifyFailure) } -func TestNewConfiguredPeerVerifier(t *testing.T) { - t.Run("Nil", func(t *testing.T) { - assert := assert.New(t) - assert.Nil(NewConfiguredPeerVerifier(PeerVerifyOptions{})) +func (suite *ConfiguredPeerVerifierSuite) TestNewConfiguredPeerVerifier() { + suite.Run("Nil", func() { + suite.Nil(NewConfiguredPeerVerifier(PeerVerifyOptions{})) }) } -func testPeerVerifiersVerifyPeerCertificate(t *testing.T) { - t.Run("UnparseableCert", func(t *testing.T) { - var ( - assert = assert.New(t) +func TestConfiguredPeerVerifier(t *testing.T) { + suite.Run(t, new(ConfiguredPeerVerifierSuite)) +} - unparseable = []byte("unparseable") +type PeerVerifiersSuite struct { + suite.Suite - m = PeerVerifierFunc(func(*x509.Certificate, [][]*x509.Certificate) error { - assert.Fail("This verifier should not have been called due to an unparseable certificate") - return nil - }) + key *rsa.PrivateKey + testCerts []*x509.Certificate + rawCerts [][]byte +} - pv = PeerVerifiers{m} - ) +func (suite *PeerVerifiersSuite) SetupSuite() { + var err error + suite.key, err = rsa.GenerateKey(rand.Reader, 512) + suite.Require().NoError(err) - assert.Error(pv.VerifyPeerCertificate([][]byte{unparseable}, nil)) - }) + suite.testCerts = make([]*x509.Certificate, 3) + suite.rawCerts = make([][]byte, len(suite.testCerts)) - var ( - random = rand.New(rand.NewSource(1234)) - verifyErr = errors.New("expected Verify error") - - testData = []struct { - results []error - expectedErr error - }{ - { - results: []error{}, - expectedErr: nil, - }, - { - results: []error{nil}, - expectedErr: nil, - }, - { - results: []error{verifyErr}, - expectedErr: verifyErr, + for i := 0; i < len(suite.testCerts); i++ { + suite.testCerts[i] = &x509.Certificate{ + SerialNumber: big.NewInt(int64(i + 1)), + DNSNames: []string{ + fmt.Sprintf("host-%d.net", i), }, - { - results: []error{nil, nil}, - expectedErr: nil, - }, - { - results: []error{nil, verifyErr}, - expectedErr: verifyErr, - }, - { - results: []error{verifyErr, nil}, - expectedErr: verifyErr, - }, - { - results: []error{nil, nil, nil, nil, nil}, - expectedErr: nil, - }, - { - results: []error{nil, nil, verifyErr, nil, nil}, - expectedErr: verifyErr, - }, - { - results: []error{nil, nil, nil, nil, verifyErr}, - expectedErr: verifyErr, + Subject: pkix.Name{ + CommonName: fmt.Sprintf("Organization #%d", i), }, } + + var err error + suite.rawCerts[i], err = x509.CreateCertificate( + rand.Reader, + suite.testCerts[i], + suite.testCerts[i], + &suite.key.PublicKey, + suite.key, + ) + + suite.Require().NoError(err) + suite.Require().NotEmpty(suite.rawCerts[i]) + } +} + +// useCertificate gives some syntactic sugar for expecting a peer cert +func (suite *PeerVerifiersSuite) useCertificate(expected *x509.Certificate) func(*x509.Certificate) bool { + return newCertificateMatcher( + suite.T(), + expected.Subject.CommonName, + expected.DNSNames..., ) +} - for i, record := range testData { - t.Run(fmt.Sprintf("i=%d,len=%d", i, len(record.results)), func(t *testing.T) { - var ( - assert = assert.New(t) - require = require.New(t) - peerSerial = rand.Int63() - chainSerial = rand.Int63() - template = stubPeerCert(peerSerial) +func (suite *PeerVerifiersSuite) expectVerify(expected *x509.Certificate, result error) *mockPeerVerifier { + m := new(mockPeerVerifier) + m.ExpectVerify( + suite.useCertificate(expected), + ).Return(result) - pv PeerVerifiers - ) + return m +} - key, err := rsa.GenerateKey(random, 512) - require.NoError(err) - - peerCert, err := x509.CreateCertificate(random, template, template, &key.PublicKey, key) - require.NoError(err) - - errEncountered := false - for _, result := range record.results { - err := result - if errEncountered { - pv = append(pv, PeerVerifierFunc(func(*x509.Certificate, [][]*x509.Certificate) error { - assert.Fail("This verifier should not have been called due to an earlier error") - return err - })) - } else { - pv = append(pv, PeerVerifierFunc(func(peerCert *x509.Certificate, verifiedChains [][]*x509.Certificate) error { - require.NotNil(peerCert) - require.NotNil(peerCert.SerialNumber) - assert.Equal(0, peerCert.SerialNumber.Cmp(big.NewInt(peerSerial))) - - require.Len(verifiedChains, 1) - require.Len(verifiedChains[0], 1) - assert.Equal(0, verifiedChains[0][0].SerialNumber.Cmp(big.NewInt(chainSerial))) - - return err - })) - } +func (suite *PeerVerifiersSuite) TestUnparseableCertificate() { + var ( + unparseable = []byte("unparseable") - if result != nil { - errEncountered = true - } + m = new(mockPeerVerifier) // no calls + pv = PeerVerifiers{m} + ) + + suite.Error(pv.VerifyPeerCertificate([][]byte{unparseable}, nil)) + m.AssertExpectations(suite.T()) +} + +func (suite *PeerVerifiersSuite) testVerifySuccess(expected *x509.Certificate) { + for count := 0; count < 3; count++ { + suite.Run(fmt.Sprintf("verifiers=%d", count), func() { + var pv PeerVerifiers + for i := 0; i < count; i++ { + pv = append(pv, suite.expectVerify(expected, nil)) } - assert.Equal( - record.expectedErr, - pv.VerifyPeerCertificate( - [][]byte{peerCert}, - stubChain(chainSerial), - ), + suite.NoError( + pv.Verify(expected, nil), ) + + assertPeerVerifierExpectations(suite.T(), pv...) }) } } -func testPeerVerifiersVerify(t *testing.T) { - var ( - verifyErr = errors.New("expected Verify error") - - testData = []struct { - results []error - expectedErr error - }{ - { - results: []error{}, - expectedErr: nil, - }, - { - results: []error{nil}, - expectedErr: nil, - }, - { - results: []error{verifyErr}, - expectedErr: verifyErr, - }, - { - results: []error{nil, nil}, - expectedErr: nil, - }, - { - results: []error{nil, verifyErr}, - expectedErr: verifyErr, - }, - { - results: []error{verifyErr, nil}, - expectedErr: verifyErr, - }, - { - results: []error{nil, nil, nil, nil, nil}, - expectedErr: nil, - }, - { - results: []error{nil, nil, verifyErr, nil, nil}, - expectedErr: verifyErr, - }, - { - results: []error{nil, nil, nil, nil, verifyErr}, - expectedErr: verifyErr, - }, - } - ) +func (suite *PeerVerifiersSuite) testVerifyFailure(expected *x509.Certificate) { + for count := 0; count < 3; count++ { + suite.Run(fmt.Sprintf("goodVerifiers=%d", count), func() { + var pv PeerVerifiers + + // setup our "good" calls + for i := 0; i < count; i++ { + pv = append(pv, suite.expectVerify(expected, nil)) + } - for i, record := range testData { - t.Run(fmt.Sprintf("i=%d,len=%d", i, len(record.results)), func(t *testing.T) { - var ( - assert = assert.New(t) - require = require.New(t) - peerSerial = rand.Int63() - chainSerial = rand.Int63() + // a failure, followed by a verifier that shouldn't be called + pv = append(pv, + suite.expectVerify(expected, errors.New("expected")), + new(mockPeerVerifier), + ) - pv PeerVerifiers + suite.Error( + pv.Verify(expected, nil), ) - errEncountered := false - for _, result := range record.results { - err := result - if errEncountered { - pv = append(pv, PeerVerifierFunc(func(*x509.Certificate, [][]*x509.Certificate) error { - assert.Fail("This verifier should not have been called due to an earlier error") - return err - })) - } else { - pv = append(pv, PeerVerifierFunc(func(peerCert *x509.Certificate, verifiedChains [][]*x509.Certificate) error { - require.NotNil(peerCert) - require.NotNil(peerCert.SerialNumber) - assert.Equal(0, peerCert.SerialNumber.Cmp(big.NewInt(peerSerial))) - - require.Len(verifiedChains, 1) - require.Len(verifiedChains[0], 1) - assert.Equal(0, verifiedChains[0][0].SerialNumber.Cmp(big.NewInt(chainSerial))) - - return err - })) + assertPeerVerifierExpectations(suite.T(), pv...) + }) + } +} + +func (suite *PeerVerifiersSuite) TestVerify() { + suite.Run("Success", func() { + suite.testVerifySuccess(suite.testCerts[0]) + }) + + suite.Run("Failure", func() { + suite.testVerifyFailure(suite.testCerts[0]) + }) +} + +func (suite *PeerVerifiersSuite) testVerifyPeerCertificateAllGood(testCerts []*x509.Certificate, rawCerts [][]byte) { + for count := 0; count < 3; count++ { + suite.Run(fmt.Sprintf("verifiers=%d", count), func() { + var pv PeerVerifiers + for i := 0; i < count; i++ { + m := new(mockPeerVerifier) + for j := 0; j < len(testCerts); j++ { + // maybe is used here because a success short-circuits subsequent calls + m.ExpectVerify(suite.useCertificate(testCerts[j])).Return(error(nil)).Maybe() } - if result != nil { - errEncountered = true + pv = append(pv, m) + } + + suite.NoError( + pv.VerifyPeerCertificate(rawCerts, nil), + ) + + assertPeerVerifierExpectations(suite.T(), pv...) + }) + } +} + +func (suite *PeerVerifiersSuite) testVerifyPeerCertificateAllBad(testCerts []*x509.Certificate, rawCerts [][]byte) { + for count := 1; count < 3; count++ { + suite.Run(fmt.Sprintf("verifiers=%d", count), func() { + var pv PeerVerifiers + for i := 0; i < count; i++ { + m := new(mockPeerVerifier) + for j := 0; j < len(testCerts); j++ { + // maybe is used here because a failure short-circuits subsequent calls + m.ExpectVerify(suite.useCertificate(testCerts[j])).Return(errors.New("expected")).Maybe() } + + pv = append(pv, m) } - assert.Equal( - record.expectedErr, - pv.Verify( - stubPeerCert(peerSerial), - stubChain(chainSerial), - ), + suite.Error( + pv.VerifyPeerCertificate(rawCerts, nil), ) + + assertPeerVerifierExpectations(suite.T(), pv...) }) } } +func (suite *PeerVerifiersSuite) testVerifyPeerCertificateOneGood(testCerts []*x509.Certificate, rawCerts [][]byte) { + // a verifier that passes any but the first cert + oneGood := new(mockPeerVerifier) + oneGood.ExpectVerify(func(actual *x509.Certificate) bool { + return actual == testCerts[0] + }).Return(errors.New("oneGood: first cert should fail")).Maybe() + oneGood.ExpectVerify(func(actual *x509.Certificate) bool { + return actual != testCerts[0] + }).Return(error(nil)) + + pv := PeerVerifiers{ + oneGood, + } + + suite.NoError( + pv.VerifyPeerCertificate(rawCerts, nil), + ) + + assertPeerVerifierExpectations(suite.T(), pv...) +} + +func (suite *PeerVerifiersSuite) TestVerifyPeerCertificate() { + suite.Run("AllGood", func() { + for i := 1; i < len(suite.testCerts); i++ { + suite.Run(fmt.Sprintf("certs=%d", i), func() { + suite.testVerifyPeerCertificateAllGood(suite.testCerts[0:i], suite.rawCerts[0:i]) + }) + } + }) + + suite.Run("AllBad", func() { + for i := 1; i < len(suite.testCerts); i++ { + suite.Run(fmt.Sprintf("certs=%d", i), func() { + suite.testVerifyPeerCertificateAllBad(suite.testCerts[0:i], suite.rawCerts[0:i]) + }) + } + }) + + suite.Run("OneGood", func() { + for i := 2; i < len(suite.testCerts); i++ { + suite.Run(fmt.Sprintf("certs=%d", i), func() { + suite.testVerifyPeerCertificateOneGood(suite.testCerts[0:i], suite.rawCerts[0:i]) + }) + } + }) +} + func TestPeerVerifiers(t *testing.T) { - t.Run("VerifyPeerCertificate", testPeerVerifiersVerifyPeerCertificate) - t.Run("Verify", testPeerVerifiersVerify) + suite.Run(t, new(PeerVerifiersSuite)) } func TestNewPeerVerifiers(t *testing.T) {