Skip to content

Commit

Permalink
Merge pull request #261 from xmidt-org/feature/clientauth-configuration
Browse files Browse the repository at this point in the history
Feature/clientauth configuration
  • Loading branch information
johnabass authored Nov 21, 2024
2 parents 40ab59f + 8dd765d commit b7ebbc1
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 709 deletions.
50 changes: 0 additions & 50 deletions xhttp/xhttpserver/mocks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@ package xhttpserver
import (
"bufio"
"context"
"crypto/x509"
"net"
"net/http"
"testing"

"github.com/stretchr/testify/mock"
)
Expand Down Expand Up @@ -123,51 +121,3 @@ func (m *mockServer) Shutdown(ctx context.Context) error {
func (m *mockServer) ExpectShutdown(p ...interface{}) *mock.Call {
return m.On("Shutdown", p...)
}

func newCertificateMatcher(t *testing.T, commonName string, dnsNames ...string) func(*x509.Certificate) bool {
return func(actual *x509.Certificate) bool {
t.Logf("Testing cert: Subject.CommonName=%s, DNSNames=%s", actual.Subject.CommonName, actual.DNSNames)

switch {
case commonName != actual.Subject.CommonName:
return false

case len(dnsNames) != len(actual.DNSNames):
return false

default:
for i := 0; i < len(dnsNames); i++ {
if dnsNames[i] != actual.DNSNames[i] {
return false
}
}
}

return true
}
}

type mockPeerVerifier struct {
mock.Mock
}

func (m *mockPeerVerifier) Verify(peerCert *x509.Certificate, verifiedChains [][]*x509.Certificate) error {
return m.Called(peerCert, verifiedChains).Error(0)
}

// ExpectVerify sets up the a mocked call to Verify with a peer certificate with the given
// subject common name and dns names. Since this package doesn't use any other fields,
// this expectation suffices for tests.
func (m *mockPeerVerifier) ExpectVerify(certificateMatcher func(*x509.Certificate) bool) *mock.Call {
return m.On(
"Verify",
mock.MatchedBy(certificateMatcher),
[][]*x509.Certificate(nil), // we always pass nil in tests, since we don't use this parameter
)
}

func assertPeerVerifierExpectations(t *testing.T, pvs ...PeerVerifier) {
for _, pv := range pvs {
pv.(*mockPeerVerifier).AssertExpectations(t)
}
}
19 changes: 5 additions & 14 deletions xhttp/xhttpserver/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (

"github.com/xmidt-org/sallust"
"github.com/xmidt-org/sallust/sallusthttp"
"go.uber.org/zap"

"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -128,7 +127,7 @@ func testNewServerChainFull(t *testing.T) {
assert = assert.New(t)
require = require.New(t)

output, base = sallust.NewTestLogger(zap.DebugLevel)
base = sallust.Default()

next = http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
assert.Implements((*TrackingWriter)(nil), response)
Expand Down Expand Up @@ -157,10 +156,6 @@ func testNewServerChainFull(t *testing.T) {
require.NotNil(decorated)
decorated.ServeHTTP(response, request)
assert.Equal(299, response.Code)
assert.Contains(output.String(), "requestMethod")
assert.Contains(output.String(), "POST")
assert.Contains(output.String(), "requestURI")
assert.Contains(output.String(), "/foo")
}

func TestNewServerChain(t *testing.T) {
Expand All @@ -175,8 +170,8 @@ func testNewSimple(t *testing.T) {
assert = assert.New(t)
require = require.New(t)

output, base = sallust.NewTestLogger(zap.DebugLevel)
router = mux.NewRouter()
base = sallust.Default()
router = mux.NewRouter()

s = New(
Options{
Expand Down Expand Up @@ -209,7 +204,6 @@ func testNewSimple(t *testing.T) {

require.NotNil(s.(*http.Server).ErrorLog)
s.(*http.Server).ErrorLog.Print("foo", "bar")
assert.Greater(output.Len(), 0)

assert.Nil(s.(*http.Server).ConnState)
}
Expand All @@ -219,8 +213,8 @@ func testNewFull(t *testing.T) {
assert = assert.New(t)
require = require.New(t)

output, base = sallust.NewTestLogger(zap.DebugLevel)
router = mux.NewRouter()
base = sallust.Default()
router = mux.NewRouter()

s = New(
Options{
Expand Down Expand Up @@ -252,12 +246,9 @@ func testNewFull(t *testing.T) {

require.NotNil(s.(*http.Server).ErrorLog)
s.(*http.Server).ErrorLog.Print("foo", "bar")
assert.Greater(output.Len(), 0)

require.NotNil(s.(*http.Server).ConnState)
output.Reset()
s.(*http.Server).ConnState(new(net.IPConn), http.StateNew)
assert.Greater(output.Len(), 0)
}

func TestNew(t *testing.T) {
Expand Down
Loading

0 comments on commit b7ebbc1

Please sign in to comment.