diff --git a/xhttp/xhttpserver/mocks_test.go b/xhttp/xhttpserver/mocks_test.go index 9a52768..f3caa88 100644 --- a/xhttp/xhttpserver/mocks_test.go +++ b/xhttp/xhttpserver/mocks_test.go @@ -5,10 +5,8 @@ package xhttpserver import ( "bufio" "context" - "crypto/x509" "net" "net/http" - "testing" "github.com/stretchr/testify/mock" ) @@ -123,26 +121,3 @@ func (m *mockServer) Shutdown(ctx context.Context) error { func (m *mockServer) ExpectShutdown(p ...interface{}) *mock.Call { return m.On("Shutdown", p...) } - -func newCertificateMatcher(t *testing.T, commonName string, dnsNames ...string) func(*x509.Certificate) bool { - return func(actual *x509.Certificate) bool { - t.Logf("Testing cert: Subject.CommonName=%s, DNSNames=%s", actual.Subject.CommonName, actual.DNSNames) - - switch { - case commonName != actual.Subject.CommonName: - return false - - case len(dnsNames) != len(actual.DNSNames): - return false - - default: - for i := 0; i < len(dnsNames); i++ { - if dnsNames[i] != actual.DNSNames[i] { - return false - } - } - } - - return true - } -} diff --git a/xhttp/xhttpserver/tls.go b/xhttp/xhttpserver/tls.go index f42ed78..ae00d25 100644 --- a/xhttp/xhttpserver/tls.go +++ b/xhttp/xhttpserver/tls.go @@ -32,7 +32,7 @@ func ReadCertPool(path string) (cp *x509.CertPool, err error) { // Mtls configures the mutual TLS settings for a tls.Config. type Mtls struct { ClientCACertificateFile string - DisableRequired bool + DisableRequire bool DisableVerify bool } @@ -54,16 +54,16 @@ func configureMtls(tc *tls.Config, mtls *Mtls) (err error) { } switch { - case mtls.DisableRequired && mtls.DisableVerify: + case mtls.DisableRequire && mtls.DisableVerify: tc.ClientAuth = tls.RequestClientCert - case !mtls.DisableRequired && mtls.DisableVerify: + case !mtls.DisableRequire && mtls.DisableVerify: tc.ClientAuth = tls.RequireAnyClientCert - case mtls.DisableRequired && !mtls.DisableVerify: + case mtls.DisableRequire && !mtls.DisableVerify: tc.ClientAuth = tls.VerifyClientCertIfGiven - case !mtls.DisableRequired && !mtls.DisableVerify: + case !mtls.DisableRequire && !mtls.DisableVerify: tc.ClientAuth = tls.RequireAndVerifyClientCert } diff --git a/xhttp/xhttpserver/tls_test.go b/xhttp/xhttpserver/tls_test.go index 1bc6e80..784ae89 100644 --- a/xhttp/xhttpserver/tls_test.go +++ b/xhttp/xhttpserver/tls_test.go @@ -5,12 +5,56 @@ package xhttpserver import ( "crypto/tls" "os" + "strconv" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestConfigureMtls(t *testing.T) { + testCases := []struct { + Mtls *Mtls + Expected tls.ClientAuthType + }{ + { + Mtls: nil, + Expected: tls.NoClientCert, + }, + { + Mtls: &Mtls{}, + Expected: tls.RequireAndVerifyClientCert, + }, + { + Mtls: &Mtls{ + DisableRequire: true, + }, + Expected: tls.VerifyClientCertIfGiven, + }, + { + Mtls: &Mtls{ + DisableVerify: true, + }, + Expected: tls.RequireAnyClientCert, + }, + { + Mtls: &Mtls{ + DisableRequire: true, + DisableVerify: true, + }, + Expected: tls.RequestClientCert, + }, + } + + for i, testCase := range testCases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + var tc tls.Config + configureMtls(&tc, testCase.Mtls) + assert.Equal(t, testCase.Expected, tc.ClientAuth) + }) + } +} + func testNewTlsConfigNil(t *testing.T) { assert := assert.New(t) tc, err := NewTlsConfig(nil)