diff --git a/handlers.go b/handlers.go index 1cd7b72..df5e861 100644 --- a/handlers.go +++ b/handlers.go @@ -12,14 +12,37 @@ import ( "github.com/golang-jwt/jwt" ) -const ( - IssuerBase = "/oidc" - AuthorizationEndpoint = "/oidc/authorize" - TokenEndpoint = "/oidc/token" - UserinfoEndpoint = "/oidc/userinfo" - JWKSEndpoint = "/oidc/.well-known/jwks.json" - DiscoveryEndpoint = "/oidc/.well-known/openid-configuration" +type EndpointConfig struct { + IssuerBase string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserinfoEndpoint string `json:"userinfo_endpoint"` + JWKSEndpoint string `json:"jwks_uri"` + DiscoveryEndpoint string +} + +func (e *EndpointConfig) Defaults() { + if e.IssuerBase == "" { + e.IssuerBase = "/oidc" + } + if e.AuthorizationEndpoint == "" { + e.AuthorizationEndpoint = e.IssuerBase + "/authorize" + } + if e.TokenEndpoint == "" { + e.TokenEndpoint = e.IssuerBase + "/token" + } + if e.UserinfoEndpoint == "" { + e.UserinfoEndpoint = e.IssuerBase + "/userinfo" + } + if e.JWKSEndpoint == "" { + e.JWKSEndpoint = e.IssuerBase + "/.well-known/jwks.json" + } + if e.DiscoveryEndpoint == "" { + e.DiscoveryEndpoint = e.IssuerBase + "/.well-known/openid-configuration" + } +} +const ( InvalidRequest = "invalid_request" InvalidClient = "invalid_client" InvalidGrant = "invalid_grant" diff --git a/handlers_test.go b/handlers_test.go index 0e9e31c..0f20511 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -27,26 +27,26 @@ func TestMockOIDC_Authorize(t *testing.T) { data.Set("client_id", m.ClientID) data.Set("code_challenge", "somehash") data.Set("code_challenge_method", "S256") - assert.HTTPError(t, m.Authorize, http.MethodGet, mockoidc.AuthorizationEndpoint, nil) + assert.HTTPError(t, m.Authorize, http.MethodGet, m.EndpointConfig.AuthorizationEndpoint, nil) // valid request assert.HTTPStatusCode(t, m.Authorize, http.MethodGet, - mockoidc.AuthorizationEndpoint, data, http.StatusFound) + m.EndpointConfig.AuthorizationEndpoint, data, http.StatusFound) // Bad client ID data.Set("client_id", "wrong_id") assert.HTTPStatusCode(t, m.Authorize, http.MethodGet, - mockoidc.AuthorizationEndpoint, data, http.StatusUnauthorized) + m.EndpointConfig.AuthorizationEndpoint, data, http.StatusUnauthorized) assert.HTTPBodyContains(t, m.Authorize, http.MethodGet, - mockoidc.AuthorizationEndpoint, data, mockoidc.InvalidClient) + m.EndpointConfig.AuthorizationEndpoint, data, mockoidc.InvalidClient) // Bad code challenge method data.Set("client_id", m.ClientID) data.Set("code_challenge_method", "does not exist") assert.HTTPStatusCode(t, m.Authorize, http.MethodGet, - mockoidc.AuthorizationEndpoint, data, http.StatusBadRequest) + m.EndpointConfig.AuthorizationEndpoint, data, http.StatusBadRequest) assert.HTTPBodyContains(t, m.Authorize, http.MethodGet, - mockoidc.AuthorizationEndpoint, data, mockoidc.InvalidRequest) + m.EndpointConfig.AuthorizationEndpoint, data, mockoidc.InvalidRequest) // Missing required form values for key := range data { @@ -60,9 +60,9 @@ func TestMockOIDC_Authorize(t *testing.T) { badData.Del(key) assert.HTTPStatusCode(t, m.Authorize, http.MethodGet, - mockoidc.AuthorizationEndpoint, badData, http.StatusBadRequest) + m.EndpointConfig.AuthorizationEndpoint, badData, http.StatusBadRequest) assert.HTTPBodyContains(t, m.Authorize, http.MethodGet, - mockoidc.AuthorizationEndpoint, badData, mockoidc.InvalidRequest) + m.EndpointConfig.AuthorizationEndpoint, badData, mockoidc.InvalidRequest) }) } } @@ -74,7 +74,7 @@ func TestMockOIDC_Token_CodeGrant(t *testing.T) { session, _ := m.SessionStore.NewSession( "openid email profile", "nonce", mockoidc.DefaultUser(), "", "") - assert.HTTPError(t, m.Token, http.MethodPost, mockoidc.TokenEndpoint, nil) + assert.HTTPError(t, m.Token, http.MethodPost, m.EndpointConfig.TokenEndpoint, nil) data := url.Values{} data.Set("client_id", m.ClientID) @@ -88,7 +88,7 @@ func TestMockOIDC_Token_CodeGrant(t *testing.T) { badData, _ := url.ParseQuery(data.Encode()) badData.Del(key) - rr := testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, badData) + rr := testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, badData) assert.Equal(t, http.StatusBadRequest, rr.Code) body, err := ioutil.ReadAll(rr.Body) @@ -104,7 +104,7 @@ func TestMockOIDC_Token_CodeGrant(t *testing.T) { assert.NoError(t, err) badData.Set(key, "WRONG") - rr := testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, badData) + rr := testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, badData) if key == "grant_type" { assert.Equal(t, http.StatusBadRequest, rr.Code) } else { @@ -114,7 +114,7 @@ func TestMockOIDC_Token_CodeGrant(t *testing.T) { } // good request; check responses - rr := testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, data) + rr := testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, data) assert.Equal(t, http.StatusOK, rr.Code) tokenResp := make(map[string]interface{}) @@ -139,7 +139,7 @@ func TestMockOIDC_Token_CodeGrant(t *testing.T) { } // duplicate attempts are rejects - rrDup := testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, data) + rrDup := testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, data) assert.Equal(t, http.StatusUnauthorized, rrDup.Code) } @@ -153,7 +153,7 @@ func TestMockOIDC_Token_CodeGrant_CodeChallengePlain(t *testing.T) { "openid email profile", "nonce", mockoidc.DefaultUser(), codeChallenge, mockoidc.CodeChallengeMethodPlain) - assert.HTTPError(t, m.Token, http.MethodPost, mockoidc.TokenEndpoint, nil) + assert.HTTPError(t, m.Token, http.MethodPost, m.EndpointConfig.TokenEndpoint, nil) data := url.Values{} data.Set("client_id", m.ClientID) @@ -163,7 +163,7 @@ func TestMockOIDC_Token_CodeGrant_CodeChallengePlain(t *testing.T) { data.Set("code_verifier", "sum") // good request; good response - rr := testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, data) + rr := testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, data) assert.Equal(t, http.StatusOK, rr.Code) tokenResp := make(map[string]interface{}) @@ -174,7 +174,7 @@ func TestMockOIDC_Token_CodeGrant_CodeChallengePlain(t *testing.T) { badData, _ := url.ParseQuery(data.Encode()) badData.Del("code_verifier") - rr = testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, badData) + rr = testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, badData) assert.Equal(t, http.StatusUnauthorized, rr.Code) body, err := ioutil.ReadAll(rr.Body) @@ -185,7 +185,7 @@ func TestMockOIDC_Token_CodeGrant_CodeChallengePlain(t *testing.T) { badData, _ = url.ParseQuery(data.Encode()) badData.Set("code_verifier", "WRONG") - rr = testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, badData) + rr = testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, badData) assert.Equal(t, http.StatusUnauthorized, rr.Code) body, err = ioutil.ReadAll(rr.Body) @@ -203,7 +203,7 @@ func TestMockOIDC_Token_CodeGrant_CodeChallengeHash(t *testing.T) { "openid email profile", "nonce", mockoidc.DefaultUser(), codeChallenge, mockoidc.CodeChallengeMethodS256) - assert.HTTPError(t, m.Token, http.MethodPost, mockoidc.TokenEndpoint, nil) + assert.HTTPError(t, m.Token, http.MethodPost, m.EndpointConfig.TokenEndpoint, nil) data := url.Values{} data.Set("client_id", m.ClientID) @@ -213,7 +213,7 @@ func TestMockOIDC_Token_CodeGrant_CodeChallengeHash(t *testing.T) { data.Set("code_verifier", "sum") // good request; good response - rr := testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, data) + rr := testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, data) assert.Equal(t, http.StatusOK, rr.Code) tokenResp := make(map[string]interface{}) @@ -224,7 +224,7 @@ func TestMockOIDC_Token_CodeGrant_CodeChallengeHash(t *testing.T) { badData, _ := url.ParseQuery(data.Encode()) badData.Del("code_verifier") - rr = testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, badData) + rr = testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, badData) assert.Equal(t, http.StatusUnauthorized, rr.Code) body, err := ioutil.ReadAll(rr.Body) @@ -235,7 +235,7 @@ func TestMockOIDC_Token_CodeGrant_CodeChallengeHash(t *testing.T) { badData, _ = url.ParseQuery(data.Encode()) badData.Set("code_verifier", "WRONG") - rr = testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, badData) + rr = testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, badData) assert.Equal(t, http.StatusUnauthorized, rr.Code) body, err = ioutil.ReadAll(rr.Body) @@ -251,7 +251,7 @@ func TestMockOIDC_Token_RefreshGrant(t *testing.T) { "openid email profile", "sessionNonce", mockoidc.DefaultUser(), "", "") refreshToken, _ := session.RefreshToken(m.Config(), m.Keypair, m.Now()) - assert.HTTPError(t, m.Token, http.MethodPost, mockoidc.TokenEndpoint, nil) + assert.HTTPError(t, m.Token, http.MethodPost, m.EndpointConfig.TokenEndpoint, nil) data := url.Values{} data.Set("client_id", m.ClientID) @@ -260,7 +260,7 @@ func TestMockOIDC_Token_RefreshGrant(t *testing.T) { data.Set("grant_type", "refresh_token") // good request; check responses - rr := testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, data) + rr := testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, data) assert.Equal(t, http.StatusOK, rr.Code) tokenResp := make(map[string]interface{}) @@ -291,7 +291,7 @@ func TestMockOIDC_Token_RefreshGrant(t *testing.T) { data.Set("refresh_token", expiredToken) - rr = testResponse(t, mockoidc.TokenEndpoint, m.Token, http.MethodPost, data) + rr = testResponse(t, m.EndpointConfig.TokenEndpoint, m.Token, http.MethodPost, data) assert.Equal(t, http.StatusUnauthorized, rr.Code) body, err := ioutil.ReadAll(rr.Body) diff --git a/mockoidc.go b/mockoidc.go index 990923f..20fe3fa 100644 --- a/mockoidc.go +++ b/mockoidc.go @@ -8,6 +8,8 @@ import ( "fmt" "net" "net/http" + "net/url" + "strings" "time" "github.com/golang-jwt/jwt" @@ -39,6 +41,8 @@ type MockOIDC struct { tlsConfig *tls.Config middleware []func(http.Handler) http.Handler fastForward time.Duration + + EndpointConfig EndpointConfig } // Config gives the various settings MockOIDC starts with that a test @@ -71,6 +75,9 @@ func NewServer(key *rsa.PrivateKey) (*MockOIDC, error) { return nil, err } + ecfg := EndpointConfig{} + ecfg.Defaults() + return &MockOIDC{ ClientID: clientID, ClientSecret: clientSecret, @@ -81,6 +88,7 @@ func NewServer(key *rsa.PrivateKey) (*MockOIDC, error) { SessionStore: NewSessionStore(), UserQueue: &UserQueue{}, ErrorQueue: &ErrorQueue{}, + EndpointConfig: ecfg, }, nil } @@ -110,12 +118,20 @@ func (m *MockOIDC) Start(ln net.Listener, cfg *tls.Config) error { return errors.New("server already started") } + var pathOf = func(s string) string { + u, err := url.Parse(s) + if err != nil { + return s + } + return u.Path + } + handler := http.NewServeMux() - handler.Handle(AuthorizationEndpoint, m.chainMiddleware(m.Authorize)) - handler.Handle(TokenEndpoint, m.chainMiddleware(m.Token)) - handler.Handle(UserinfoEndpoint, m.chainMiddleware(m.Userinfo)) - handler.Handle(JWKSEndpoint, m.chainMiddleware(m.JWKS)) - handler.Handle(DiscoveryEndpoint, m.chainMiddleware(m.Discovery)) + handler.Handle(pathOf(m.EndpointConfig.AuthorizationEndpoint), m.chainMiddleware(m.Authorize)) + handler.Handle(pathOf(m.EndpointConfig.TokenEndpoint), m.chainMiddleware(m.Token)) + handler.Handle(pathOf(m.EndpointConfig.UserinfoEndpoint), m.chainMiddleware(m.Userinfo)) + handler.Handle(pathOf(m.EndpointConfig.JWKSEndpoint), m.chainMiddleware(m.JWKS)) + handler.Handle(pathOf(m.EndpointConfig.DiscoveryEndpoint), m.chainMiddleware(m.Discovery)) m.Server = &http.Server{ Addr: ln.Addr().String(), @@ -220,12 +236,21 @@ func (m *MockOIDC) Addr() string { return fmt.Sprintf("%s://%s", proto, m.Server.Addr) } +// applyBase adds a the server scheme and host to the given url, unless it is already absolute. +func (m *MockOIDC) applyBase(u string) string { + if strings.Contains(u, "://") { + return u + } + + return m.Addr() + u +} + // Issuer returns the OIDC Issuer that will be in `iss` token claims func (m *MockOIDC) Issuer() string { if m.Server == nil { return "" } - return m.Addr() + IssuerBase + return m.applyBase(m.EndpointConfig.IssuerBase) } // DiscoveryEndpoint returns the full `/.well-known/openid-configuration` URL @@ -233,7 +258,7 @@ func (m *MockOIDC) DiscoveryEndpoint() string { if m.Server == nil { return "" } - return m.Addr() + DiscoveryEndpoint + return m.applyBase(m.EndpointConfig.DiscoveryEndpoint) } // AuthorizationEndpoint returns the OIDC `authorization_endpoint` @@ -241,7 +266,7 @@ func (m *MockOIDC) AuthorizationEndpoint() string { if m.Server == nil { return "" } - return m.Addr() + AuthorizationEndpoint + return m.applyBase(m.EndpointConfig.AuthorizationEndpoint) } // TokenEndpoint returns the OIDC `token_endpoint` @@ -249,7 +274,7 @@ func (m *MockOIDC) TokenEndpoint() string { if m.Server == nil { return "" } - return m.Addr() + TokenEndpoint + return m.applyBase(m.EndpointConfig.TokenEndpoint) } // UserinfoEndpoint returns the OIDC `userinfo_endpoint` @@ -257,7 +282,7 @@ func (m *MockOIDC) UserinfoEndpoint() string { if m.Server == nil { return "" } - return m.Addr() + UserinfoEndpoint + return m.applyBase(m.EndpointConfig.UserinfoEndpoint) } // JWKSEndpoint returns the OIDC `jwks_uri` @@ -265,7 +290,7 @@ func (m *MockOIDC) JWKSEndpoint() string { if m.Server == nil { return "" } - return m.Addr() + JWKSEndpoint + return m.applyBase(m.EndpointConfig.JWKSEndpoint) } func (m *MockOIDC) chainMiddleware(endpoint func(http.ResponseWriter, *http.Request)) http.Handler {