Skip to content

Commit

Permalink
chore: fixup OIDC function signatures and improve tests
Browse files Browse the repository at this point in the history
Removes a lot of unused params etc.
  • Loading branch information
alnr committed Sep 4, 2024
1 parent 1a9ade0 commit 6eda038
Show file tree
Hide file tree
Showing 12 changed files with 166 additions and 136 deletions.
34 changes: 21 additions & 13 deletions internal/testhelpers/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,34 @@ func NewDebugClient(t *testing.T) *http.Client {
return &http.Client{Transport: NewTransportWithLogger(http.DefaultTransport, t)}
}

func NewClientWithCookieJar(t *testing.T, jar *cookiejar.Jar, debugRedirects bool) *http.Client {
func NewClientWithCookieJar(t *testing.T, jar *cookiejar.Jar, checkRedirect CheckRedirectFunc) *http.Client {
if jar == nil {
j, err := cookiejar.New(nil)
jar = j
require.NoError(t, err)
}
if checkRedirect == nil {
checkRedirect = DebugRedirects(t)
}
return &http.Client{
Jar: jar,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if debugRedirects {
t.Logf("Redirect: %s", req.URL.String())
}
if len(via) >= 20 {
for k, v := range via {
t.Logf("Failed with redirect (%d): %s", k, v.URL.String())
}
return errors.New("stopped after 20 redirects")
Jar: jar,
CheckRedirect: checkRedirect,
}
}

type CheckRedirectFunc func(req *http.Request, via []*http.Request) error

func DebugRedirects(t *testing.T) CheckRedirectFunc {
return func(req *http.Request, via []*http.Request) error {
t.Logf("Redirect: %s", req.URL.String())

if len(via) >= 20 {
for k, v := range via {
t.Logf("Failed with redirect (%d): %s", k, v.URL.String())
}
return nil
},
return errors.New("stopped after 20 redirects")
}
return nil
}
}

Expand Down
4 changes: 2 additions & 2 deletions selfservice/strategy/code/strategy_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
}
return nil, nil
case flow.StateEmailSent:
i, err := s.loginVerifyCode(ctx, r, f, &p, sess)
i, err := s.loginVerifyCode(ctx, f, &p, sess)
if err != nil {
return nil, s.HandleLoginError(r, f, &p, err)
}
Expand Down Expand Up @@ -437,7 +437,7 @@ func maybeNormalizeEmail(input string) string {
return input
}

func (s *Strategy) loginVerifyCode(ctx context.Context, r *http.Request, f *login.Flow, p *updateLoginFlowWithCodeMethod, sess *session.Session) (_ *identity.Identity, err error) {
func (s *Strategy) loginVerifyCode(ctx context.Context, f *login.Flow, p *updateLoginFlowWithCodeMethod, sess *session.Session) (_ *identity.Identity, err error) {
ctx, span := s.deps.Tracer(ctx).Tracer().Start(ctx, "selfservice.strategy.code.strategy.loginVerifyCode")
defer otelx.End(span, &err)

Expand Down
10 changes: 5 additions & 5 deletions selfservice/strategy/oidc/provider_facebook.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ func NewProviderFacebook(
}
}

func (g *ProviderFacebook) generateAppSecretProof(ctx context.Context, exchange *oauth2.Token) string {
func (g *ProviderFacebook) generateAppSecretProof(token *oauth2.Token) string {
secret := g.config.ClientSecret
data := exchange.AccessToken
data := token.AccessToken

h := hmac.New(sha256.New, []byte(secret))
h.Write([]byte(data))
Expand All @@ -62,19 +62,19 @@ func (g *ProviderFacebook) OAuth2(ctx context.Context) (*oauth2.Config, error) {
return g.oauth2ConfigFromEndpoint(ctx, endpoint), nil
}

func (g *ProviderFacebook) Claims(ctx context.Context, exchange *oauth2.Token, query url.Values) (*Claims, error) {
func (g *ProviderFacebook) Claims(ctx context.Context, token *oauth2.Token, query url.Values) (*Claims, error) {
o, err := g.OAuth2(ctx)
if err != nil {
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err))
}

appSecretProof := g.generateAppSecretProof(ctx, exchange)
appSecretProof := g.generateAppSecretProof(token)
u, err := url.Parse(fmt.Sprintf("https://graph.facebook.com/me?fields=id,name,first_name,last_name,middle_name,email,picture,birthday,gender&appsecret_proof=%s", appSecretProof))
if err != nil {
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err))
}

ctx, client := httpx.SetOAuth2(ctx, g.reg.HTTPClient(ctx), o, exchange)
ctx, client := httpx.SetOAuth2(ctx, g.reg.HTTPClient(ctx), o, token)
req, err := retryablehttp.NewRequestWithContext(ctx, "GET", u.String(), nil)
if err != nil {
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("%s", err))
Expand Down
6 changes: 3 additions & 3 deletions selfservice/strategy/oidc/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ func NewTestProvider(c *Configuration, reg Dependencies) Provider {
}
}

func RegisterTestProvider(id string) func() {
func RegisterTestProvider(t *testing.T, id string) {
supportedProviders[id] = func(c *Configuration, reg Dependencies) Provider {
return NewTestProvider(c, reg)
}
return func() {
t.Cleanup(func() {
delete(supportedProviders, id)
}
})
}

var _ IDTokenVerifier = new(TestProvider)
Expand Down
35 changes: 18 additions & 17 deletions selfservice/strategy/oidc/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ const (

RouteAuth = RouteBase + "/auth/:flow"
RouteCallback = RouteBase + "/callback/:provider"
RouteCallbackGeneric = RouteBase + "/callback"
RouteOrganizationCallback = RouteBase + "/organization/:organization/callback/:provider"
)

Expand Down Expand Up @@ -403,22 +404,22 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt
req, cntnr, err := s.ValidateCallback(w, r)
if err != nil {
if req != nil {
s.forwardError(w, r, req, s.handleError(ctx, w, r, req, pid, nil, err))
s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err))
} else {
s.d.SelfServiceErrorManager().Forward(ctx, w, r, s.handleError(ctx, w, r, nil, pid, nil, err))
s.d.SelfServiceErrorManager().Forward(ctx, w, r, s.handleError(w, r, nil, pid, nil, err))
}
return
}

if authenticated, err := s.alreadyAuthenticated(w, r, req); err != nil {
s.forwardError(w, r, req, s.handleError(ctx, w, r, req, pid, nil, err))
s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err))
} else if authenticated {
return
}

provider, err := s.provider(r.Context(), pid)
if err != nil {
s.forwardError(w, r, req, s.handleError(ctx, w, r, req, pid, nil, err))
s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err))
return
}

Expand All @@ -428,37 +429,37 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt
case OAuth2Provider:
token, err := s.ExchangeCode(r.Context(), provider, code)
if err != nil {
s.forwardError(w, r, req, s.handleError(ctx, w, r, req, pid, nil, err))
s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err))
return
}

et, err = s.encryptOAuth2Tokens(r.Context(), token)
if err != nil {
s.forwardError(w, r, req, s.handleError(ctx, w, r, req, pid, nil, err))
s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err))
return
}

claims, err = p.Claims(r.Context(), token, r.URL.Query())
if err != nil {
s.forwardError(w, r, req, s.handleError(ctx, w, r, req, pid, nil, err))
s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err))
return
}
case OAuth1Provider:
token, err := p.ExchangeToken(r.Context(), r)
if err != nil {
s.forwardError(w, r, req, s.handleError(ctx, w, r, req, pid, nil, err))
s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err))
return
}

claims, err = p.Claims(r.Context(), token)
if err != nil {
s.forwardError(w, r, req, s.handleError(ctx, w, r, req, pid, nil, err))
s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err))
return
}
}

if err = claims.Validate(); err != nil {
s.forwardError(w, r, req, s.handleError(ctx, w, r, req, pid, nil, err))
s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err))
return
}

Expand All @@ -482,7 +483,7 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt
case *registration.Flow:
a.Active = s.ID()
a.TransientPayload = cntnr.TransientPayload
if ff, err := s.processRegistration(ctx, w, r, a, et, claims, provider, cntnr, ""); err != nil {
if ff, err := s.processRegistration(ctx, w, r, a, et, claims, provider, cntnr); err != nil {
if ff != nil {
s.forwardError(w, r, ff, err)
return
Expand All @@ -495,16 +496,16 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt
a.TransientPayload = cntnr.TransientPayload
sess, err := s.d.SessionManager().FetchFromRequest(r.Context(), r)
if err != nil {
s.forwardError(w, r, a, s.handleError(ctx, w, r, a, pid, nil, err))
s.forwardError(w, r, a, s.handleError(w, r, a, pid, nil, err))
return
}
if err := s.linkProvider(w, r, &settings.UpdateContext{Session: sess, Flow: a}, et, claims, provider); err != nil {
s.forwardError(w, r, a, s.handleError(ctx, w, r, a, pid, nil, err))
s.forwardError(w, r, a, s.handleError(w, r, a, pid, nil, err))
return
}
return
default:
s.forwardError(w, r, req, s.handleError(ctx, w, r, req, pid, nil, errors.WithStack(x.PseudoPanic.
s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, errors.WithStack(x.PseudoPanic.
WithDetailf("cause", "Unexpected type in OpenID Connect flow: %T", a))))
return
}
Expand Down Expand Up @@ -588,7 +589,7 @@ func (s *Strategy) forwardError(w http.ResponseWriter, r *http.Request, f flow.F
}
}

func (s *Strategy) handleError(ctx context.Context, w http.ResponseWriter, r *http.Request, f flow.Flow, usedProviderID string, traits []byte, err error) error {
func (s *Strategy) handleError(w http.ResponseWriter, r *http.Request, f flow.Flow, usedProviderID string, traits []byte, err error) error {
switch rf := f.(type) {
case *login.Flow:
return err
Expand All @@ -608,7 +609,7 @@ func (s *Strategy) handleError(ctx context.Context, w http.ResponseWriter, r *ht
rf.UI.Messages.Add(text.NewErrorValidationDuplicateCredentialsOnOIDCLink())
}

lf, err := s.registrationToLogin(w, r, rf, usedProviderID)
lf, err := s.registrationToLogin(w, r, rf)
if err != nil {
return err
}
Expand Down Expand Up @@ -741,7 +742,7 @@ func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context) session.Au
}
}

func (s *Strategy) processIDToken(w http.ResponseWriter, r *http.Request, provider Provider, idToken, idTokenNonce string) (*Claims, error) {
func (s *Strategy) processIDToken(r *http.Request, provider Provider, idToken, idTokenNonce string) (*Claims, error) {
verifier, ok := provider.(IDTokenVerifier)
if !ok {
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The provider %s does not support id_token verification", provider.Config().Provider))
Expand Down
38 changes: 31 additions & 7 deletions selfservice/strategy/oidc/strategy_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,43 @@ func (token *idTokenClaims) MarshalJSON() ([]byte, error) {
})
}

func createClient(t *testing.T, remote string, redir string) (id, secret string) {
func createClient(t *testing.T, remote string, redir []string) (id, secret string) {
require.NoError(t, resilience.Retry(logrusx.New("", ""), time.Second*10, time.Minute*2, func() error {
var b bytes.Buffer
require.NoError(t, json.NewEncoder(&b).Encode(&struct {
Scope string `json:"scope"`
GrantTypes []string `json:"grant_types"`
ResponseTypes []string `json:"response_types"`
RedirectURIs []string `json:"redirect_uris"`
Scope string `json:"scope"`
GrantTypes []string `json:"grant_types"`
ResponseTypes []string `json:"response_types"`
RedirectURIs []string `json:"redirect_uris"`
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"`
}{
GrantTypes: []string{"authorization_code", "refresh_token"},
ResponseTypes: []string{"code"},
Scope: "offline offline_access openid",
RedirectURIs: []string{redir},
RedirectURIs: redir,

// This is a workaround to prevent golang.org/x/oauth2 from
// swallowing the actual error messages from failed token exchanges.
//
// The library first attempts to use the Authorization header to
// pass Client ID+secret during token exchange (client_secret_basic
// in Hydra terminology). If that fails (with any error), it tries
// again with the Client ID+secret passed in the HTTP POST body
// (client_secret_post in Hydra). If that also fails, this second
// error is returned.
//
// Now, if the the client was indeed configured to use
// client_secret_basic, but the token exchange fails for another
// reason, the error message will be swallowed and replaced with
// "invalid_client".
//
// Manually setting this to client_secret_post means that during
// tests, all token exchanges will first fail with `invalid_client`
// and then be retried with the correct method. This is the only way
// to get the actual error message from the server, however.
//
// https://github.com/golang/oauth2/blob/5fd42413edb3b1699004a31b72e485e0e4ba1b13/internal/token.go#L227-L242
TokenEndpointAuthMethod: "client_secret_post",
}))

res, err := http.Post(remote+"/admin/clients", "application/json", &b)
Expand Down Expand Up @@ -317,7 +341,7 @@ func newOIDCProvider(
id string,
opts ...func(*oidc.Configuration),
) oidc.Configuration {
clientID, secret := createClient(t, hydraAdmin, kratos.URL+oidc.RouteBase+"/callback/"+id)
clientID, secret := createClient(t, hydraAdmin, []string{kratos.URL + oidc.RouteBase + "/callback/" + id, kratos.URL + oidc.RouteCallbackGeneric})

cfg := oidc.Configuration{
Provider: "generic",
Expand Down
Loading

0 comments on commit 6eda038

Please sign in to comment.