Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mockoidc: enable server impersonation #36

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 30 additions & 7 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,37 @@ import (
"github.com/golang-jwt/jwt"
)

const (
IssuerBase = "/oidc"
AuthorizationEndpoint = "/oidc/authorize"
TokenEndpoint = "/oidc/token"
UserinfoEndpoint = "/oidc/userinfo"
JWKSEndpoint = "/oidc/.well-known/jwks.json"
DiscoveryEndpoint = "/oidc/.well-known/openid-configuration"
type EndpointConfig struct {
IssuerBase string `json:"issuer"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserinfoEndpoint string `json:"userinfo_endpoint"`
JWKSEndpoint string `json:"jwks_uri"`
DiscoveryEndpoint string
}

func (e *EndpointConfig) Defaults() {
if e.IssuerBase == "" {
e.IssuerBase = "/oidc"
}
if e.AuthorizationEndpoint == "" {
e.AuthorizationEndpoint = e.IssuerBase + "/authorize"
}
if e.TokenEndpoint == "" {
e.TokenEndpoint = e.IssuerBase + "/token"
}
if e.UserinfoEndpoint == "" {
e.UserinfoEndpoint = e.IssuerBase + "/userinfo"
}
if e.JWKSEndpoint == "" {
e.JWKSEndpoint = e.IssuerBase + "/.well-known/jwks.json"
}
if e.DiscoveryEndpoint == "" {
e.DiscoveryEndpoint = e.IssuerBase + "/.well-known/openid-configuration"
}
}

const (
InvalidRequest = "invalid_request"
InvalidClient = "invalid_client"
InvalidGrant = "invalid_grant"
Expand Down
48 changes: 24 additions & 24 deletions handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,26 @@ func TestMockOIDC_Authorize(t *testing.T) {
data.Set("client_id", m.ClientID)
data.Set("code_challenge", "somehash")
data.Set("code_challenge_method", "S256")
assert.HTTPError(t, m.Authorize, http.MethodGet, mockoidc.AuthorizationEndpoint, nil)
assert.HTTPError(t, m.Authorize, http.MethodGet, m.EndpointConfig.AuthorizationEndpoint, nil)

// valid request
assert.HTTPStatusCode(t, m.Authorize, http.MethodGet,
mockoidc.AuthorizationEndpoint, data, http.StatusFound)
m.EndpointConfig.AuthorizationEndpoint, data, http.StatusFound)

// Bad client ID
data.Set("client_id", "wrong_id")
assert.HTTPStatusCode(t, m.Authorize, http.MethodGet,
mockoidc.AuthorizationEndpoint, data, http.StatusUnauthorized)
m.EndpointConfig.AuthorizationEndpoint, data, http.StatusUnauthorized)
assert.HTTPBodyContains(t, m.Authorize, http.MethodGet,
mockoidc.AuthorizationEndpoint, data, mockoidc.InvalidClient)
m.EndpointConfig.AuthorizationEndpoint, data, mockoidc.InvalidClient)

// Bad code challenge method
data.Set("client_id", m.ClientID)
data.Set("code_challenge_method", "does not exist")
assert.HTTPStatusCode(t, m.Authorize, http.MethodGet,
mockoidc.AuthorizationEndpoint, data, http.StatusBadRequest)
m.EndpointConfig.AuthorizationEndpoint, data, http.StatusBadRequest)
assert.HTTPBodyContains(t, m.Authorize, http.MethodGet,
mockoidc.AuthorizationEndpoint, data, mockoidc.InvalidRequest)
m.EndpointConfig.AuthorizationEndpoint, data, mockoidc.InvalidRequest)

// Missing required form values
for key := range data {
Expand All @@ -60,9 +60,9 @@ func TestMockOIDC_Authorize(t *testing.T) {
badData.Del(key)

assert.HTTPStatusCode(t, m.Authorize, http.MethodGet,
mockoidc.AuthorizationEndpoint, badData, http.StatusBadRequest)
m.EndpointConfig.AuthorizationEndpoint, badData, http.StatusBadRequest)
assert.HTTPBodyContains(t, m.Authorize, http.MethodGet,
mockoidc.AuthorizationEndpoint, badData, mockoidc.InvalidRequest)
m.EndpointConfig.AuthorizationEndpoint, badData, mockoidc.InvalidRequest)
})
}
}
Expand All @@ -74,7 +74,7 @@ func TestMockOIDC_Token_CodeGrant(t *testing.T) {
session, _ := m.SessionStore.NewSession(
"openid email profile", "nonce", mockoidc.DefaultUser(), "", "")

assert.HTTPError(t, m.Token, http.MethodPost, mockoidc.TokenEndpoint, nil)
assert.HTTPError(t, m.Token, http.MethodPost, m.EndpointConfig.TokenEndpoint, nil)

data := url.Values{}
data.Set("client_id", m.ClientID)
Expand All @@ -88,7 +88,7 @@ func TestMockOIDC_Token_CodeGrant(t *testing.T) {
badData, _ := url.ParseQuery(data.Encode())
badData.Del(key)

rr := testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, badData)
rr := testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, badData)
assert.Equal(t, http.StatusBadRequest, rr.Code)

body, err := ioutil.ReadAll(rr.Body)
Expand All @@ -104,7 +104,7 @@ func TestMockOIDC_Token_CodeGrant(t *testing.T) {
assert.NoError(t, err)

badData.Set(key, "WRONG")
rr := testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, badData)
rr := testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, badData)
if key == "grant_type" {
assert.Equal(t, http.StatusBadRequest, rr.Code)
} else {
Expand All @@ -114,7 +114,7 @@ func TestMockOIDC_Token_CodeGrant(t *testing.T) {
}

// good request; check responses
rr := testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, data)
rr := testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, data)
assert.Equal(t, http.StatusOK, rr.Code)

tokenResp := make(map[string]interface{})
Expand All @@ -139,7 +139,7 @@ func TestMockOIDC_Token_CodeGrant(t *testing.T) {
}

// duplicate attempts are rejects
rrDup := testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, data)
rrDup := testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, data)
assert.Equal(t, http.StatusUnauthorized, rrDup.Code)
}

Expand All @@ -153,7 +153,7 @@ func TestMockOIDC_Token_CodeGrant_CodeChallengePlain(t *testing.T) {
"openid email profile", "nonce", mockoidc.DefaultUser(),
codeChallenge, mockoidc.CodeChallengeMethodPlain)

assert.HTTPError(t, m.Token, http.MethodPost, mockoidc.TokenEndpoint, nil)
assert.HTTPError(t, m.Token, http.MethodPost, m.EndpointConfig.TokenEndpoint, nil)

data := url.Values{}
data.Set("client_id", m.ClientID)
Expand All @@ -163,7 +163,7 @@ func TestMockOIDC_Token_CodeGrant_CodeChallengePlain(t *testing.T) {
data.Set("code_verifier", "sum")

// good request; good response
rr := testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, data)
rr := testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, data)
assert.Equal(t, http.StatusOK, rr.Code)

tokenResp := make(map[string]interface{})
Expand All @@ -174,7 +174,7 @@ func TestMockOIDC_Token_CodeGrant_CodeChallengePlain(t *testing.T) {
badData, _ := url.ParseQuery(data.Encode())
badData.Del("code_verifier")

rr = testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, badData)
rr = testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, badData)
assert.Equal(t, http.StatusUnauthorized, rr.Code)

body, err := ioutil.ReadAll(rr.Body)
Expand All @@ -185,7 +185,7 @@ func TestMockOIDC_Token_CodeGrant_CodeChallengePlain(t *testing.T) {
badData, _ = url.ParseQuery(data.Encode())
badData.Set("code_verifier", "WRONG")

rr = testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, badData)
rr = testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, badData)
assert.Equal(t, http.StatusUnauthorized, rr.Code)

body, err = ioutil.ReadAll(rr.Body)
Expand All @@ -203,7 +203,7 @@ func TestMockOIDC_Token_CodeGrant_CodeChallengeHash(t *testing.T) {
"openid email profile", "nonce", mockoidc.DefaultUser(),
codeChallenge, mockoidc.CodeChallengeMethodS256)

assert.HTTPError(t, m.Token, http.MethodPost, mockoidc.TokenEndpoint, nil)
assert.HTTPError(t, m.Token, http.MethodPost, m.EndpointConfig.TokenEndpoint, nil)

data := url.Values{}
data.Set("client_id", m.ClientID)
Expand All @@ -213,7 +213,7 @@ func TestMockOIDC_Token_CodeGrant_CodeChallengeHash(t *testing.T) {
data.Set("code_verifier", "sum")

// good request; good response
rr := testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, data)
rr := testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, data)
assert.Equal(t, http.StatusOK, rr.Code)

tokenResp := make(map[string]interface{})
Expand All @@ -224,7 +224,7 @@ func TestMockOIDC_Token_CodeGrant_CodeChallengeHash(t *testing.T) {
badData, _ := url.ParseQuery(data.Encode())
badData.Del("code_verifier")

rr = testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, badData)
rr = testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, badData)
assert.Equal(t, http.StatusUnauthorized, rr.Code)

body, err := ioutil.ReadAll(rr.Body)
Expand All @@ -235,7 +235,7 @@ func TestMockOIDC_Token_CodeGrant_CodeChallengeHash(t *testing.T) {
badData, _ = url.ParseQuery(data.Encode())
badData.Set("code_verifier", "WRONG")

rr = testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, badData)
rr = testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, badData)
assert.Equal(t, http.StatusUnauthorized, rr.Code)

body, err = ioutil.ReadAll(rr.Body)
Expand All @@ -251,7 +251,7 @@ func TestMockOIDC_Token_RefreshGrant(t *testing.T) {
"openid email profile", "sessionNonce", mockoidc.DefaultUser(), "", "")
refreshToken, _ := session.RefreshToken(m.Config(), m.Keypair, m.Now())

assert.HTTPError(t, m.Token, http.MethodPost, mockoidc.TokenEndpoint, nil)
assert.HTTPError(t, m.Token, http.MethodPost, m.EndpointConfig.TokenEndpoint, nil)

data := url.Values{}
data.Set("client_id", m.ClientID)
Expand All @@ -260,7 +260,7 @@ func TestMockOIDC_Token_RefreshGrant(t *testing.T) {
data.Set("grant_type", "refresh_token")

// good request; check responses
rr := testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, data)
rr := testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, data)
assert.Equal(t, http.StatusOK, rr.Code)

tokenResp := make(map[string]interface{})
Expand Down Expand Up @@ -291,7 +291,7 @@ func TestMockOIDC_Token_RefreshGrant(t *testing.T) {

data.Set("refresh_token", expiredToken)

rr = testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, data)
rr = testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, data)
assert.Equal(t, http.StatusUnauthorized, rr.Code)

body, err := ioutil.ReadAll(rr.Body)
Expand Down
47 changes: 36 additions & 11 deletions mockoidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"fmt"
"net"
"net/http"
"net/url"
"strings"
"time"

"github.com/golang-jwt/jwt"
Expand Down Expand Up @@ -39,6 +41,8 @@ type MockOIDC struct {
tlsConfig *tls.Config
middleware []func(http.Handler) http.Handler
fastForward time.Duration

EndpointConfig EndpointConfig
}

// Config gives the various settings MockOIDC starts with that a test
Expand Down Expand Up @@ -71,6 +75,9 @@ func NewServer(key *rsa.PrivateKey) (*MockOIDC, error) {
return nil, err
}

ecfg := EndpointConfig{}
ecfg.Defaults()

return &MockOIDC{
ClientID: clientID,
ClientSecret: clientSecret,
Expand All @@ -81,6 +88,7 @@ func NewServer(key *rsa.PrivateKey) (*MockOIDC, error) {
SessionStore: NewSessionStore(),
UserQueue: &UserQueue{},
ErrorQueue: &ErrorQueue{},
EndpointConfig: ecfg,
}, nil
}

Expand Down Expand Up @@ -110,12 +118,20 @@ func (m *MockOIDC) Start(ln net.Listener, cfg *tls.Config) error {
return errors.New("server already started")
}

var pathOf = func(s string) string {
u, err := url.Parse(s)
if err != nil {
return s
}
return u.Path
}

handler := http.NewServeMux()
handler.Handle(AuthorizationEndpoint, m.chainMiddleware(m.Authorize))
handler.Handle(TokenEndpoint, m.chainMiddleware(m.Token))
handler.Handle(UserinfoEndpoint, m.chainMiddleware(m.Userinfo))
handler.Handle(JWKSEndpoint, m.chainMiddleware(m.JWKS))
handler.Handle(DiscoveryEndpoint, m.chainMiddleware(m.Discovery))
handler.Handle(pathOf(m.EndpointConfig.AuthorizationEndpoint), m.chainMiddleware(m.Authorize))
handler.Handle(pathOf(m.EndpointConfig.TokenEndpoint), m.chainMiddleware(m.Token))
handler.Handle(pathOf(m.EndpointConfig.UserinfoEndpoint), m.chainMiddleware(m.Userinfo))
handler.Handle(pathOf(m.EndpointConfig.JWKSEndpoint), m.chainMiddleware(m.JWKS))
handler.Handle(pathOf(m.EndpointConfig.DiscoveryEndpoint), m.chainMiddleware(m.Discovery))

m.Server = &http.Server{
Addr: ln.Addr().String(),
Expand Down Expand Up @@ -220,52 +236,61 @@ func (m *MockOIDC) Addr() string {
return fmt.Sprintf("%s://%s", proto, m.Server.Addr)
}

// applyBase adds a the server scheme and host to the given url, unless it is already absolute.
func (m *MockOIDC) applyBase(u string) string {
if strings.Contains(u, "://") {
return u
}

return m.Addr() + u
}

// Issuer returns the OIDC Issuer that will be in `iss` token claims
func (m *MockOIDC) Issuer() string {
if m.Server == nil {
return ""
}
return m.Addr() + IssuerBase
return m.applyBase(m.EndpointConfig.IssuerBase)
}

// DiscoveryEndpoint returns the full `/.well-known/openid-configuration` URL
func (m *MockOIDC) DiscoveryEndpoint() string {
if m.Server == nil {
return ""
}
return m.Addr() + DiscoveryEndpoint
return m.applyBase(m.EndpointConfig.DiscoveryEndpoint)
}

// AuthorizationEndpoint returns the OIDC `authorization_endpoint`
func (m *MockOIDC) AuthorizationEndpoint() string {
if m.Server == nil {
return ""
}
return m.Addr() + AuthorizationEndpoint
return m.applyBase(m.EndpointConfig.AuthorizationEndpoint)
}

// TokenEndpoint returns the OIDC `token_endpoint`
func (m *MockOIDC) TokenEndpoint() string {
if m.Server == nil {
return ""
}
return m.Addr() + TokenEndpoint
return m.applyBase(m.EndpointConfig.TokenEndpoint)
}

// UserinfoEndpoint returns the OIDC `userinfo_endpoint`
func (m *MockOIDC) UserinfoEndpoint() string {
if m.Server == nil {
return ""
}
return m.Addr() + UserinfoEndpoint
return m.applyBase(m.EndpointConfig.UserinfoEndpoint)
}

// JWKSEndpoint returns the OIDC `jwks_uri`
func (m *MockOIDC) JWKSEndpoint() string {
if m.Server == nil {
return ""
}
return m.Addr() + JWKSEndpoint
return m.applyBase(m.EndpointConfig.JWKSEndpoint)
}

func (m *MockOIDC) chainMiddleware(endpoint func(http.ResponseWriter, *http.Request)) http.Handler {
Expand Down