diff --git a/xhttp/xhttpserver/mocks_test.go b/xhttp/xhttpserver/mocks_test.go index d967035..f3caa88 100644 --- a/xhttp/xhttpserver/mocks_test.go +++ b/xhttp/xhttpserver/mocks_test.go @@ -5,10 +5,8 @@ package xhttpserver import ( "bufio" "context" - "crypto/x509" "net" "net/http" - "testing" "github.com/stretchr/testify/mock" ) @@ -123,51 +121,3 @@ func (m *mockServer) Shutdown(ctx context.Context) error { func (m *mockServer) ExpectShutdown(p ...interface{}) *mock.Call { return m.On("Shutdown", p...) } - -func newCertificateMatcher(t *testing.T, commonName string, dnsNames ...string) func(*x509.Certificate) bool { - return func(actual *x509.Certificate) bool { - t.Logf("Testing cert: Subject.CommonName=%s, DNSNames=%s", actual.Subject.CommonName, actual.DNSNames) - - switch { - case commonName != actual.Subject.CommonName: - return false - - case len(dnsNames) != len(actual.DNSNames): - return false - - default: - for i := 0; i < len(dnsNames); i++ { - if dnsNames[i] != actual.DNSNames[i] { - return false - } - } - } - - return true - } -} - -type mockPeerVerifier struct { - mock.Mock -} - -func (m *mockPeerVerifier) Verify(peerCert *x509.Certificate, verifiedChains [][]*x509.Certificate) error { - return m.Called(peerCert, verifiedChains).Error(0) -} - -// ExpectVerify sets up the a mocked call to Verify with a peer certificate with the given -// subject common name and dns names. Since this package doesn't use any other fields, -// this expectation suffices for tests. -func (m *mockPeerVerifier) ExpectVerify(certificateMatcher func(*x509.Certificate) bool) *mock.Call { - return m.On( - "Verify", - mock.MatchedBy(certificateMatcher), - [][]*x509.Certificate(nil), // we always pass nil in tests, since we don't use this parameter - ) -} - -func assertPeerVerifierExpectations(t *testing.T, pvs ...PeerVerifier) { - for _, pv := range pvs { - pv.(*mockPeerVerifier).AssertExpectations(t) - } -} diff --git a/xhttp/xhttpserver/server_test.go b/xhttp/xhttpserver/server_test.go index f893e72..4a977ad 100644 --- a/xhttp/xhttpserver/server_test.go +++ b/xhttp/xhttpserver/server_test.go @@ -11,7 +11,6 @@ import ( "github.com/xmidt-org/sallust" "github.com/xmidt-org/sallust/sallusthttp" - "go.uber.org/zap" "github.com/gorilla/mux" "github.com/stretchr/testify/assert" @@ -128,7 +127,7 @@ func testNewServerChainFull(t *testing.T) { assert = assert.New(t) require = require.New(t) - output, base = sallust.NewTestLogger(zap.DebugLevel) + base = sallust.Default() next = http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { assert.Implements((*TrackingWriter)(nil), response) @@ -157,10 +156,6 @@ func testNewServerChainFull(t *testing.T) { require.NotNil(decorated) decorated.ServeHTTP(response, request) assert.Equal(299, response.Code) - assert.Contains(output.String(), "requestMethod") - assert.Contains(output.String(), "POST") - assert.Contains(output.String(), "requestURI") - assert.Contains(output.String(), "/foo") } func TestNewServerChain(t *testing.T) { @@ -175,8 +170,8 @@ func testNewSimple(t *testing.T) { assert = assert.New(t) require = require.New(t) - output, base = sallust.NewTestLogger(zap.DebugLevel) - router = mux.NewRouter() + base = sallust.Default() + router = mux.NewRouter() s = New( Options{ @@ -209,7 +204,6 @@ func testNewSimple(t *testing.T) { require.NotNil(s.(*http.Server).ErrorLog) s.(*http.Server).ErrorLog.Print("foo", "bar") - assert.Greater(output.Len(), 0) assert.Nil(s.(*http.Server).ConnState) } @@ -219,8 +213,8 @@ func testNewFull(t *testing.T) { assert = assert.New(t) require = require.New(t) - output, base = sallust.NewTestLogger(zap.DebugLevel) - router = mux.NewRouter() + base = sallust.Default() + router = mux.NewRouter() s = New( Options{ @@ -252,12 +246,9 @@ func testNewFull(t *testing.T) { require.NotNil(s.(*http.Server).ErrorLog) s.(*http.Server).ErrorLog.Print("foo", "bar") - assert.Greater(output.Len(), 0) require.NotNil(s.(*http.Server).ConnState) - output.Reset() s.(*http.Server).ConnState(new(net.IPConn), http.StateNew) - assert.Greater(output.Len(), 0) } func TestNew(t *testing.T) { diff --git a/xhttp/xhttpserver/tls.go b/xhttp/xhttpserver/tls.go index 2caebce..ae00d25 100644 --- a/xhttp/xhttpserver/tls.go +++ b/xhttp/xhttpserver/tls.go @@ -6,246 +6,114 @@ import ( "crypto/tls" "crypto/x509" "errors" + "fmt" "os" - "strings" - - "go.uber.org/multierr" ) var ( - ErrTlsCertificateRequired = errors.New("Both a certificateFile and keyFile are required") - ErrUnableToAddClientCACertificate = errors.New("Unable to add client CA certificate") + ErrTlsCertificateRequired = errors.New("Both a certificateFile and keyFile are required") ) -// PeerVerifyError represents a verification error for a particular certificate -type PeerVerifyError struct { - Certificate *x509.Certificate - Reason string -} - -func (pve PeerVerifyError) Error() string { - return pve.Reason -} - -// PeerVerifyOptions allows common checks against a client-side certificate to be configured externally. Any constraint that matches -// will result in a valid peer cert. -type PeerVerifyOptions struct { - // DNSSuffixes enumerates any DNS suffixes that are checked. A DNSName field of at least (1) peer cert - // must have one of these suffixes. If this field is not supplied, no DNS suffix checking is performed. - // Matching is case insensitive. - // - // If any DNS suffix matches, that is sufficient for the peer cert to be valid. No further checking is done in that case. - DNSSuffixes []string - - // CommonNames lists the subject common names that at least (1) peer cert must have. If not supplied, - // no checking is done on the common name. Matching common names is case sensitive. - // - // If any common name matches, that is sufficient for the peer cert to be valid. No further checking is done in that case. - CommonNames []string -} - -// PeerVerifier is a verification strategy for a peer (client) certificate. -type PeerVerifier interface { - Verify(peerCert *x509.Certificate, verifiedChains [][]*x509.Certificate) error -} - -type PeerVerifierFunc func(*x509.Certificate, [][]*x509.Certificate) error - -func (pvf PeerVerifierFunc) Verify(peerCert *x509.Certificate, verifiedChains [][]*x509.Certificate) error { - return pvf(peerCert, verifiedChains) -} - -// ConfiguredPeerVerifier is a PeerVerifier strategy synthesized from a PeerVerifyOptions. This type is the built-in -// PeerVerifier strategy for this package. -type ConfiguredPeerVerifier struct { - dnsSuffixes []string - commonNames []string -} - -func (cpv *ConfiguredPeerVerifier) Verify(peerCert *x509.Certificate, _ [][]*x509.Certificate) error { - for _, suffix := range cpv.dnsSuffixes { - for _, dnsName := range peerCert.DNSNames { - if strings.HasSuffix(strings.ToLower(dnsName), suffix) { - return nil - } - } - - // Allow the common name to be suffixed by a DNS suffix - if strings.HasSuffix(strings.ToLower(peerCert.Subject.CommonName), suffix) { - return nil +// ReadCertPool reads a file that is expected to contain a certificate bundle +// and returns that bundle as a pool. +func ReadCertPool(path string) (cp *x509.CertPool, err error) { + var contents []byte + contents, err = os.ReadFile(path) + if err == nil { + cp = x509.NewCertPool() + if !cp.AppendCertsFromPEM(contents) { + err = fmt.Errorf("Unable to add certificates from %s", path) } } - for _, commonName := range cpv.commonNames { - if commonName == peerCert.Subject.CommonName { - return nil - } - } - - return PeerVerifyError{ - Certificate: peerCert, - Reason: "No DNS name or common name matched", - } + return } -// NewConfiguredPeerVerifier returns a ConfiguredPeerVerifier from a set of options. If the given options -// do not represent any constraints, i.e. if every field is unset, then this function returns nil. -func NewConfiguredPeerVerifier(pvo PeerVerifyOptions) *ConfiguredPeerVerifier { - if len(pvo.DNSSuffixes) == 0 && len(pvo.CommonNames) == 0 { - return nil - } - - cpv := new(ConfiguredPeerVerifier) - if len(pvo.DNSSuffixes) > 0 { - cpv.dnsSuffixes = make([]string, len(pvo.DNSSuffixes)) - for i, suffix := range pvo.DNSSuffixes { - cpv.dnsSuffixes[i] = strings.ToLower(suffix) - } - } - - if len(pvo.CommonNames) > 0 { - cpv.commonNames = append(cpv.commonNames, pvo.CommonNames...) - } - - return cpv +// Mtls configures the mutual TLS settings for a tls.Config. +type Mtls struct { + ClientCACertificateFile string + DisableRequire bool + DisableVerify bool } -// PeerVerifiers is a sequence of verification strategies. All of the verifiers must return nil errors for -// a given peer cert to be considered valid. -type PeerVerifiers []PeerVerifier - -// Verify allows a PeerVerifiers to itself be used as a PeerVerifier -func (pvs PeerVerifiers) Verify(peerCert *x509.Certificate, verifiedChains [][]*x509.Certificate) error { - for _, pv := range pvs { - if err := pv.Verify(peerCert, verifiedChains); err != nil { - return err - } - } - - return nil +// Tls represents the set of configurable options for a serverside tls.Config associated with a server. +type Tls struct { + CertificateFile string + KeyFile string + Mtls *Mtls + ServerName string + NextProtos []string + MinVersion uint16 + MaxVersion uint16 } -// VerifyPeerCertificate may be used as the closure for crypto/tls.Config.VerifyPeerCertificate -// -// If any of the rawCerts passes verification, this method returns nil to indicate that the -// client has supplied a valid certificate. If all rawCerts fail verification or if any certificates -// fail to parse, this method returns an error. -func (pvs PeerVerifiers) VerifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) (err error) { - if len(pvs) == 0 { +// configureMtls sets up mtls on the given TLS configuration. +func configureMtls(tc *tls.Config, mtls *Mtls) (err error) { + if mtls == nil { return } - for _, rawCert := range rawCerts { - peerCert, parseErr := x509.ParseCertificate(rawCert) - if parseErr != nil { - // any parse error invalidates the entire sequence of certificates - err = parseErr - break - } + switch { + case mtls.DisableRequire && mtls.DisableVerify: + tc.ClientAuth = tls.RequestClientCert - verifyErr := pvs.Verify(peerCert, verifiedChains) - err = multierr.Append(err, verifyErr) + case !mtls.DisableRequire && mtls.DisableVerify: + tc.ClientAuth = tls.RequireAnyClientCert - if verifyErr == nil { - // we found (1) cert that passes verification, so we're good - err = nil - break - } - } - - return -} + case mtls.DisableRequire && !mtls.DisableVerify: + tc.ClientAuth = tls.VerifyClientCertIfGiven -// NewPeerVerifiers constructs a chain of verification strategies merged from a set of options with an extra -// set of application-layer strategies. The extra verifiers are run first. This function will return an empty -// chain of verifiers if both (1) the options do not have any constraints, and (2) there are no extra verifiers. -func NewPeerVerifiers(pvo PeerVerifyOptions, extra ...PeerVerifier) PeerVerifiers { - pvs := append(PeerVerifiers{}, extra...) - - if cpv := NewConfiguredPeerVerifier(pvo); cpv != nil { - pvs = append(pvs, cpv) + case !mtls.DisableRequire && !mtls.DisableVerify: + tc.ClientAuth = tls.RequireAndVerifyClientCert } - return pvs -} - -// Tls represents the set of configurable options for a serverside tls.Config associated with a server. -type Tls struct { - CertificateFile string - KeyFile string - ClientCACertificateFile string - ServerName string - NextProtos []string - MinVersion uint16 - MaxVersion uint16 - PeerVerify PeerVerifyOptions + tc.ClientCAs, err = ReadCertPool(mtls.ClientCACertificateFile) + return } // NewTlsConfig produces a *tls.Config from a set of configuration options. If the supplied set of options // is nil, this function returns nil with no error. -// -// If supplied, the PeerVerifier strategies will be executed as part of peer verification. This allows application-layer -// logic to be injected. -func NewTlsConfig(t *Tls, extra ...PeerVerifier) (*tls.Config, error) { +func NewTlsConfig(t *Tls) (tc *tls.Config, err error) { if t == nil { - return nil, nil - } - - if len(t.CertificateFile) == 0 || len(t.KeyFile) == 0 { - return nil, ErrTlsCertificateRequired - } - - var nextProtos []string - if len(t.NextProtos) > 0 { - nextProtos = append(nextProtos, t.NextProtos...) - } else { - // assume http/1.1 by default - nextProtos = append(nextProtos, "http/1.1") - } - - tc := &tls.Config{ // nolint: gosec - MinVersion: t.MinVersion, - MaxVersion: t.MaxVersion, - ServerName: t.ServerName, - NextProtos: nextProtos, - - // disable vulnerable ciphers - CipherSuites: []uint16{ - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - }, - } - - // if no MinVersion was set, default to TLS 1.2 - if tc.MinVersion == 0 { - tc.MinVersion = tls.VersionTLS12 - } - - if pvs := NewPeerVerifiers(t.PeerVerify, extra...); len(pvs) > 0 { - tc.VerifyPeerCertificate = pvs.VerifyPeerCertificate + return + } else if len(t.CertificateFile) == 0 || len(t.KeyFile) == 0 { + err = ErrTlsCertificateRequired } - if cert, err := tls.LoadX509KeyPair(t.CertificateFile, t.KeyFile); err != nil { - return nil, err - } else { - tc.Certificates = []tls.Certificate{cert} - } + if err == nil { + var nextProtos []string + if len(t.NextProtos) > 0 { + nextProtos = append(nextProtos, t.NextProtos...) + } else { + // assume http/1.1 by default + nextProtos = append(nextProtos, "http/1.1") + } - if len(t.ClientCACertificateFile) > 0 { - caCert, err := os.ReadFile(t.ClientCACertificateFile) - if err != nil { - return nil, err + tc = &tls.Config{ // nolint: gosec + MinVersion: t.MinVersion, + MaxVersion: t.MaxVersion, + ServerName: t.ServerName, + NextProtos: nextProtos, + + // disable vulnerable ciphers + CipherSuites: []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + }, } - caCertPool := x509.NewCertPool() - if !caCertPool.AppendCertsFromPEM(caCert) { - return nil, ErrUnableToAddClientCACertificate + // if no MinVersion was set, default to TLS 1.2 + if tc.MinVersion == 0 { + tc.MinVersion = tls.VersionTLS12 } - tc.ClientCAs = caCertPool - tc.ClientAuth = tls.RequestClientCert + tc.Certificates = make([]tls.Certificate, 1) + tc.Certificates[0], err = tls.LoadX509KeyPair(t.CertificateFile, t.KeyFile) + } + + if err == nil { + err = configureMtls(tc, t.Mtls) } - tc.BuildNameToCertificate() // nolint: staticcheck - return tc, nil + return } diff --git a/xhttp/xhttpserver/tls_test.go b/xhttp/xhttpserver/tls_test.go index 4a03952..784ae89 100644 --- a/xhttp/xhttpserver/tls_test.go +++ b/xhttp/xhttpserver/tls_test.go @@ -3,454 +3,58 @@ package xhttpserver import ( - "crypto/rand" - "crypto/rsa" "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "errors" - "fmt" - "math/big" "os" + "strconv" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" ) -func TestPeerVerifyError(t *testing.T) { - var ( - assert = assert.New(t) - err = PeerVerifyError{Reason: "expected"} - ) - - assert.Equal("expected", err.Error()) -} - -func TestPeerVerifierFunc(t *testing.T) { - testCert := new(x509.Certificate) - pvf := PeerVerifierFunc(func(actual *x509.Certificate, verifiedChains [][]*x509.Certificate) error { - assert.Equal(t, testCert, actual) - return nil - }) - - assert.NoError(t, pvf.Verify(testCert, nil)) -} - -type ConfiguredPeerVerifierSuite struct { - suite.Suite -} - -func (suite *ConfiguredPeerVerifierSuite) newConfiguredPeerVerifier(o PeerVerifyOptions) *ConfiguredPeerVerifier { - cpv := NewConfiguredPeerVerifier(o) - suite.Require().NotNil(cpv) - return cpv -} - -func (suite *ConfiguredPeerVerifierSuite) testVerifySuccess() { - testData := []struct { - description string - peerCert x509.Certificate - options PeerVerifyOptions +func TestConfigureMtls(t *testing.T) { + testCases := []struct { + Mtls *Mtls + Expected tls.ClientAuthType }{ { - description: "DNS name", - peerCert: x509.Certificate{ - DNSNames: []string{"test.foobar.com"}, - }, - options: PeerVerifyOptions{ - DNSSuffixes: []string{"foobar.com"}, - }, + Mtls: nil, + Expected: tls.NoClientCert, }, { - description: "multiple DNS names", - peerCert: x509.Certificate{ - DNSNames: []string{"first.foobar.com", "second.something.net"}, - }, - options: PeerVerifyOptions{ - DNSSuffixes: []string{"another.thing.org", "something.net"}, - }, - }, - { - description: "common name as host name", - peerCert: x509.Certificate{ - Subject: pkix.Name{ - CommonName: "PCTEST-another.thing.org", - }, - }, - options: PeerVerifyOptions{ - DNSSuffixes: []string{"another.thing.org", "something.net"}, - }, - }, - { - description: "common name", - peerCert: x509.Certificate{ - Subject: pkix.Name{ - CommonName: "A Great Organization", - }, - }, - options: PeerVerifyOptions{ - CommonNames: []string{"A Great Organization"}, - }, - }, - { - description: "multiple common names", - peerCert: x509.Certificate{ - Subject: pkix.Name{ - CommonName: "A Great Organization", - }, - }, - options: PeerVerifyOptions{ - CommonNames: []string{"First Organization Doesn't Match", "A Great Organization"}, - }, - }, - } - - for _, testCase := range testData { - suite.Run(testCase.description, func() { - verifier := suite.newConfiguredPeerVerifier(testCase.options) - peerCert := testCase.peerCert - suite.NoError(verifier.Verify(&peerCert, nil)) - }) - } -} - -func (suite *ConfiguredPeerVerifierSuite) testVerifyFailure() { - testData := []struct { - description string - peerCert x509.Certificate - options PeerVerifyOptions - }{ - { - description: "empty fields", - peerCert: x509.Certificate{}, - options: PeerVerifyOptions{ - DNSSuffixes: []string{"foobar.net"}, - CommonNames: []string{"For Great Justice"}, - }, + Mtls: &Mtls{}, + Expected: tls.RequireAndVerifyClientCert, }, { - description: "DNS mismatch", - peerCert: x509.Certificate{ - DNSNames: []string{"another.company.com"}, - }, - options: PeerVerifyOptions{ - DNSSuffixes: []string{"foobar.net"}, - CommonNames: []string{"For Great Justice"}, + Mtls: &Mtls{ + DisableRequire: true, }, + Expected: tls.VerifyClientCertIfGiven, }, { - description: "CommonName mismatch", - peerCert: x509.Certificate{ - Subject: pkix.Name{ - CommonName: "Villains For Hire", - }, - }, - options: PeerVerifyOptions{ - DNSSuffixes: []string{"foobar.net"}, - CommonNames: []string{"For Great Justice"}, + Mtls: &Mtls{ + DisableVerify: true, }, + Expected: tls.RequireAnyClientCert, }, { - description: "DNS and CommonName mismatch", - peerCert: x509.Certificate{ - DNSNames: []string{"another.company.com"}, - Subject: pkix.Name{ - CommonName: "Villains For Hire", - }, - }, - options: PeerVerifyOptions{ - DNSSuffixes: []string{"foobar.net"}, - CommonNames: []string{"For Great Justice"}, + Mtls: &Mtls{ + DisableRequire: true, + DisableVerify: true, }, + Expected: tls.RequestClientCert, }, } - for _, testCase := range testData { - suite.Run(testCase.description, func() { - verifier := suite.newConfiguredPeerVerifier(testCase.options) - - peerCert := testCase.peerCert - err := verifier.Verify(&peerCert, nil) - suite.Error(err) - }) - } -} - -func (suite *ConfiguredPeerVerifierSuite) TestVerify() { - suite.Run("Success", suite.testVerifySuccess) - suite.Run("Failure", suite.testVerifyFailure) -} - -func (suite *ConfiguredPeerVerifierSuite) TestNewConfiguredPeerVerifier() { - suite.Run("Nil", func() { - suite.Nil(NewConfiguredPeerVerifier(PeerVerifyOptions{})) - }) -} - -func TestConfiguredPeerVerifier(t *testing.T) { - suite.Run(t, new(ConfiguredPeerVerifierSuite)) -} - -type PeerVerifiersSuite struct { - suite.Suite - - key *rsa.PrivateKey - testCerts []*x509.Certificate - rawCerts [][]byte -} - -func (suite *PeerVerifiersSuite) SetupSuite() { - var err error - suite.key, err = rsa.GenerateKey(rand.Reader, 512) // nolint: gosec - suite.Require().NoError(err) - - suite.testCerts = make([]*x509.Certificate, 3) - suite.rawCerts = make([][]byte, len(suite.testCerts)) - - for i := 0; i < len(suite.testCerts); i++ { - suite.testCerts[i] = &x509.Certificate{ - SerialNumber: big.NewInt(int64(i + 1)), - DNSNames: []string{ - fmt.Sprintf("host-%d.net", i), - }, - Subject: pkix.Name{ - CommonName: fmt.Sprintf("Organization #%d", i), - }, - } - - var err error - suite.rawCerts[i], err = x509.CreateCertificate( - rand.Reader, - suite.testCerts[i], - suite.testCerts[i], - &suite.key.PublicKey, - suite.key, - ) - - suite.Require().NoError(err) - suite.Require().NotEmpty(suite.rawCerts[i]) - } -} - -// useCertificate gives some syntactic sugar for expecting a peer cert -func (suite *PeerVerifiersSuite) useCertificate(expected *x509.Certificate) func(*x509.Certificate) bool { - return newCertificateMatcher( - suite.T(), - expected.Subject.CommonName, - expected.DNSNames..., - ) -} - -func (suite *PeerVerifiersSuite) expectVerify(expected *x509.Certificate, result error) *mockPeerVerifier { - m := new(mockPeerVerifier) - m.ExpectVerify( - suite.useCertificate(expected), - ).Return(result) - - return m -} - -func (suite *PeerVerifiersSuite) TestUnparseableCertificate() { - var ( - unparseable = []byte("unparseable") - - m = new(mockPeerVerifier) // no calls - pv = PeerVerifiers{m} - ) - - suite.Error(pv.VerifyPeerCertificate([][]byte{unparseable}, nil)) - m.AssertExpectations(suite.T()) -} - -func (suite *PeerVerifiersSuite) testVerifySuccess(expected *x509.Certificate) { - for count := 0; count < 3; count++ { - suite.Run(fmt.Sprintf("verifiers=%d", count), func() { - var pv PeerVerifiers - for i := 0; i < count; i++ { - pv = append(pv, suite.expectVerify(expected, nil)) - } - - suite.NoError( - pv.Verify(expected, nil), - ) - - assertPeerVerifierExpectations(suite.T(), pv...) - }) - } -} - -func (suite *PeerVerifiersSuite) testVerifyFailure(expected *x509.Certificate) { - for count := 0; count < 3; count++ { - suite.Run(fmt.Sprintf("goodVerifiers=%d", count), func() { - var pv PeerVerifiers - - // setup our "good" calls - for i := 0; i < count; i++ { - pv = append(pv, suite.expectVerify(expected, nil)) - } - - // a failure, followed by a verifier that shouldn't be called - pv = append(pv, - suite.expectVerify(expected, errors.New("expected")), - new(mockPeerVerifier), - ) - - suite.Error( - pv.Verify(expected, nil), - ) - - assertPeerVerifierExpectations(suite.T(), pv...) + for i, testCase := range testCases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + var tc tls.Config + configureMtls(&tc, testCase.Mtls) + assert.Equal(t, testCase.Expected, tc.ClientAuth) }) } } -func (suite *PeerVerifiersSuite) TestVerify() { - suite.Run("Success", func() { - suite.testVerifySuccess(suite.testCerts[0]) - }) - - suite.Run("Failure", func() { - suite.testVerifyFailure(suite.testCerts[0]) - }) -} - -func (suite *PeerVerifiersSuite) testVerifyPeerCertificateAllGood(testCerts []*x509.Certificate, rawCerts [][]byte) { - for count := 0; count < 3; count++ { - suite.Run(fmt.Sprintf("verifiers=%d", count), func() { - var pv PeerVerifiers - for i := 0; i < count; i++ { - m := new(mockPeerVerifier) - for j := 0; j < len(testCerts); j++ { - // maybe is used here because a success short-circuits subsequent calls - m.ExpectVerify(suite.useCertificate(testCerts[j])).Return(error(nil)).Maybe() - } - - pv = append(pv, m) - } - - suite.NoError( - pv.VerifyPeerCertificate(rawCerts, nil), - ) - - assertPeerVerifierExpectations(suite.T(), pv...) - }) - } -} - -func (suite *PeerVerifiersSuite) testVerifyPeerCertificateAllBad(testCerts []*x509.Certificate, rawCerts [][]byte) { - for count := 1; count < 3; count++ { - suite.Run(fmt.Sprintf("verifiers=%d", count), func() { - var pv PeerVerifiers - for i := 0; i < count; i++ { - m := new(mockPeerVerifier) - for j := 0; j < len(testCerts); j++ { - // maybe is used here because a failure short-circuits subsequent calls - m.ExpectVerify(suite.useCertificate(testCerts[j])).Return(errors.New("expected")).Maybe() - } - - pv = append(pv, m) - } - - suite.Error( - pv.VerifyPeerCertificate(rawCerts, nil), - ) - - assertPeerVerifierExpectations(suite.T(), pv...) - }) - } -} - -func (suite *PeerVerifiersSuite) testVerifyPeerCertificateOneGood(testCerts []*x509.Certificate, rawCerts [][]byte) { - // a verifier that passes any but the first cert - oneGood := new(mockPeerVerifier) - oneGood.ExpectVerify(func(actual *x509.Certificate) bool { - return actual == testCerts[0] - }).Return(errors.New("oneGood: first cert should fail")).Maybe() - oneGood.ExpectVerify(func(actual *x509.Certificate) bool { - return actual != testCerts[0] - }).Return(error(nil)) - - pv := PeerVerifiers{ - oneGood, - } - - suite.NoError( - pv.VerifyPeerCertificate(rawCerts, nil), - ) - - assertPeerVerifierExpectations(suite.T(), pv...) -} - -func (suite *PeerVerifiersSuite) TestVerifyPeerCertificate() { - suite.Run("AllGood", func() { - for i := 1; i < len(suite.testCerts); i++ { - suite.Run(fmt.Sprintf("certs=%d", i), func() { - suite.testVerifyPeerCertificateAllGood(suite.testCerts[0:i], suite.rawCerts[0:i]) - }) - } - }) - - suite.Run("AllBad", func() { - for i := 1; i < len(suite.testCerts); i++ { - suite.Run(fmt.Sprintf("certs=%d", i), func() { - suite.testVerifyPeerCertificateAllBad(suite.testCerts[0:i], suite.rawCerts[0:i]) - }) - } - }) - - suite.Run("OneGood", func() { - for i := 2; i < len(suite.testCerts); i++ { - suite.Run(fmt.Sprintf("certs=%d", i), func() { - suite.testVerifyPeerCertificateOneGood(suite.testCerts[0:i], suite.rawCerts[0:i]) - }) - } - }) -} - -func TestPeerVerifiers(t *testing.T) { - suite.Run(t, new(PeerVerifiersSuite)) -} - -func TestNewPeerVerifiers(t *testing.T) { - t.Run("Empty", func(t *testing.T) { - var ( - assert = assert.New(t) - pv = NewPeerVerifiers(PeerVerifyOptions{}) - ) - - assert.Len(pv, 0) - }) - - t.Run("Configured", func(t *testing.T) { - var ( - assert = assert.New(t) - require = require.New(t) - pv = NewPeerVerifiers(PeerVerifyOptions{DNSSuffixes: []string{"foobar.com"}}) - ) - - require.Len(pv, 1) - assert.IsType((*ConfiguredPeerVerifier)(nil), pv[0]) - }) - - t.Run("ConfiguredWithExtra", func(t *testing.T) { - var ( - assert = assert.New(t) - require = require.New(t) - extra = make([]PeerVerifier, 2) - - pv = NewPeerVerifiers( - PeerVerifyOptions{DNSSuffixes: []string{"foobar.com"}}, - extra..., - ) - ) - - require.Len(pv, 3) - assert.IsType((*ConfiguredPeerVerifier)(nil), pv[2]) - }) -} - func testNewTlsConfigNil(t *testing.T) { assert := assert.New(t) tc, err := NewTlsConfig(nil) @@ -474,12 +78,11 @@ func testNewTlsConfigNoKeyFile(t *testing.T) { func testNewTlsConfigLoadCertificateError(t *testing.T) { assert := assert.New(t) - tc, err := NewTlsConfig(&Tls{ + _, err := NewTlsConfig(&Tls{ CertificateFile: "nosuch", KeyFile: "nosuch", }) - assert.Nil(tc) assert.Error(err) } @@ -535,11 +138,10 @@ func testNewTlsConfigWithClientCACertificateFile(t *testing.T, certificateFile, require = require.New(t) tc, err = NewTlsConfig(&Tls{ - CertificateFile: certificateFile, - KeyFile: keyFile, - ClientCACertificateFile: certificateFile, - PeerVerify: PeerVerifyOptions{ - CommonNames: []string{"Hippies, Inc."}, + CertificateFile: certificateFile, + KeyFile: keyFile, + Mtls: &Mtls{ + ClientCACertificateFile: certificateFile, }, }) ) @@ -553,21 +155,22 @@ func testNewTlsConfigWithClientCACertificateFile(t *testing.T, certificateFile, assert.Equal([]string{"http/1.1"}, tc.NextProtos) assert.NotEmpty(tc.Certificates) assert.NotNil(tc.ClientCAs) - assert.Equal(tls.RequestClientCert, tc.ClientAuth) + assert.Equal(tls.RequireAndVerifyClientCert, tc.ClientAuth) } func testNewTlsConfigLoadClientCACertificateError(t *testing.T, certificateFile, keyFile string) { var ( assert = assert.New(t) - tc, err = NewTlsConfig(&Tls{ - CertificateFile: certificateFile, - KeyFile: keyFile, - ClientCACertificateFile: "nosuch", + _, err = NewTlsConfig(&Tls{ + CertificateFile: certificateFile, + KeyFile: keyFile, + Mtls: &Mtls{ + ClientCACertificateFile: "nosuch", + }, }) ) - assert.Nil(tc) assert.Error(err) } @@ -575,15 +178,16 @@ func testNewTlsConfigAppendClientCACertificateError(t *testing.T, certificateFil var ( assert = assert.New(t) - tc, err = NewTlsConfig(&Tls{ - CertificateFile: certificateFile, - KeyFile: keyFile, - ClientCACertificateFile: keyFile, // not a certificate, but still valid PEM + _, err = NewTlsConfig(&Tls{ + CertificateFile: certificateFile, + KeyFile: keyFile, + Mtls: &Mtls{ + ClientCACertificateFile: keyFile, // not a certificate, but still valid PEM + }, }) ) - assert.Nil(tc) - assert.Equal(ErrUnableToAddClientCACertificate, err) + assert.Error(err) } func TestNewTlsConfig(t *testing.T) { diff --git a/xhttp/xhttpserver/unmarshal_test.go b/xhttp/xhttpserver/unmarshal_test.go index 0ec6c86..16c4c47 100644 --- a/xhttp/xhttpserver/unmarshal_test.go +++ b/xhttp/xhttpserver/unmarshal_test.go @@ -43,6 +43,7 @@ func testUnmarshalProvideFull(t *testing.T) { router *mux.Router app = fxtest.New(t, + fx.NopLogger, fx.Provide( sallust.Default, config.ProvideViper( @@ -94,6 +95,7 @@ func testUnmarshalProvideOptional(t *testing.T) { router *mux.Router app = fxtest.New(t, + fx.NopLogger, fx.Provide( sallust.Default, config.ProvideViper(), @@ -225,8 +227,9 @@ func testUnmarshalAnnotatedFull(t *testing.T) { router *mux.Router app = fxtest.New(t, - fx.Supply(sallust.Default()), + fx.NopLogger, fx.Provide( + sallust.Default, config.ProvideViper( config.Json(` { @@ -291,6 +294,7 @@ func testUnmarshalAnnotatedNamed(t *testing.T) { router *mux.Router app = fxtest.New(t, + fx.NopLogger, fx.Provide( sallust.Default, config.ProvideViper(