Skip to content

Commit

Permalink
fix: trigger oidc web hook on sign in after registration (#4027)
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Aug 7, 2024
1 parent 4fb28b3 commit ad5fb09
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 6 deletions.
5 changes: 5 additions & 0 deletions selfservice/strategy/oidc/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
"strings"
"time"

"github.com/ory/x/sqlxx"

"golang.org/x/exp/maps"

"github.com/ory/x/urlx"
Expand Down Expand Up @@ -464,6 +466,7 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt

switch a := req.(type) {
case *login.Flow:
a.Active = s.ID()
a.TransientPayload = cntnr.TransientPayload
if ff, err := s.processLogin(w, r, a, et, claims, provider, cntnr); err != nil {
if errors.Is(err, flow.ErrCompletedByStrategy) {
Expand All @@ -477,6 +480,7 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt
}
return
case *registration.Flow:
a.Active = s.ID()
a.TransientPayload = cntnr.TransientPayload
if ff, err := s.processRegistration(w, r, a, et, claims, provider, cntnr, ""); err != nil {
if ff != nil {
Expand All @@ -487,6 +491,7 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt
}
return
case *settings.Flow:
a.Active = sqlxx.NullString(s.ID())
a.TransientPayload = cntnr.TransientPayload
sess, err := s.d.SessionManager().FetchFromRequest(r.Context(), r)
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions selfservice/strategy/oidc/strategy_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, loginFlo
registrationFlow.RawIDTokenNonce = loginFlow.RawIDTokenNonce
registrationFlow.RequestURL, err = x.TakeOverReturnToParameter(loginFlow.RequestURL, registrationFlow.RequestURL)
registrationFlow.TransientPayload = loginFlow.TransientPayload
registrationFlow.Active = s.ID()

if err != nil {
return nil, s.handleError(w, r, loginFlow, provider.Config().ID, nil, err)
}
Expand Down
1 change: 1 addition & 0 deletions selfservice/strategy/oidc/strategy_registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ func (s *Strategy) registrationToLogin(w http.ResponseWriter, r *http.Request, r
return nil, err
}
lf.TransientPayload = rf.TransientPayload
lf.Active = s.ID()

return lf, nil
}
Expand Down
42 changes: 36 additions & 6 deletions selfservice/strategy/oidc/strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -424,21 +424,34 @@ func TestStrategy(t *testing.T) {
}

t.Run("case=should pass registration", func(t *testing.T) {
postRegistrationWebhook := hooktest.NewServer()
t.Cleanup(postRegistrationWebhook.Close)
postRegistrationWebhook.SetConfig(t, conf.GetProvider(ctx), config.HookStrategyKey(config.ViperKeySelfServiceRegistrationAfter, identity.CredentialsTypeOIDC.String()))
transientPayload := `{"data": "registration-one"}`

r := newBrowserRegistrationFlow(t, returnTS.URL, time.Minute)
action := assertFormValues(t, r.ID, "valid")
res, body := makeRequest(t, "valid", action, url.Values{})
res, body := makeRequest(t, "valid", action, url.Values{"transient_payload": {transientPayload}})
assertIdentity(t, res, body)
expectTokens(t, "valid", body)
postRegistrationWebhook.AssertTransientPayload(t, transientPayload)
})

t.Run("case=try another registration", func(t *testing.T) {
transientPayload := `{"data": "registration-two"}`
postLoginWebhook := hooktest.NewServer()
t.Cleanup(postLoginWebhook.Close)
postLoginWebhook.SetConfig(t, conf.GetProvider(ctx), config.HookStrategyKey(config.ViperKeySelfServiceLoginAfter, identity.CredentialsTypeOIDC.String()))

returnTo := fmt.Sprintf("%s/home?query=true", returnTS.URL)
r := newBrowserRegistrationFlow(t, fmt.Sprintf("%s?return_to=%s", returnTS.URL, url.QueryEscape(returnTo)), time.Minute)
action := assertFormValues(t, r.ID, "valid")
res, body := makeRequest(t, "valid", action, url.Values{})
res, body := makeRequest(t, "valid", action, url.Values{"transient_payload": {transientPayload}})
assert.Equal(t, returnTo, res.Request.URL.String())
assertIdentity(t, res, body)
expectTokens(t, "valid", body)

postLoginWebhook.AssertTransientPayload(t, transientPayload)
})
})

Expand Down Expand Up @@ -981,31 +994,48 @@ func TestStrategy(t *testing.T) {
scope = []string{"openid"}

t.Run("case=should pass registration", func(t *testing.T) {
postRegistrationWebhook := hooktest.NewServer()
t.Cleanup(postRegistrationWebhook.Close)
postRegistrationWebhook.SetConfig(t, conf.GetProvider(ctx), config.HookStrategyKey(config.ViperKeySelfServiceRegistrationAfter, identity.CredentialsTypeOIDC.String()))

r := newBrowserRegistrationFlow(t, returnTS.URL, time.Minute)
action := assertFormValues(t, r.ID, "valid")
res, body := makeRequest(t, "valid", action, url.Values{})
transientPayload := `{"data": "registration-one"}`
res, body := makeRequest(t, "valid", action, url.Values{"transient_payload": {transientPayload}})
assertIdentity(t, res, body)
postRegistrationWebhook.AssertTransientPayload(t, transientPayload)
})

t.Run("case=should pass second time registration", func(t *testing.T) {
postLoginWebhook := hooktest.NewServer()
t.Cleanup(postLoginWebhook.Close)
postLoginWebhook.SetConfig(t, conf.GetProvider(ctx), config.HookStrategyKey(config.ViperKeySelfServiceLoginAfter, identity.CredentialsTypeOIDC.String()))

r := newBrowserLoginFlow(t, returnTS.URL, time.Minute)
action := assertFormValues(t, r.ID, "valid")
res, body := makeRequest(t, "valid", action, url.Values{})
transientPayload := `{"data": "registration-two"}`
res, body := makeRequest(t, "valid", action, url.Values{"transient_payload": {transientPayload}})
assertIdentity(t, res, body)
postLoginWebhook.AssertTransientPayload(t, transientPayload)
})

t.Run("case=should pass third time registration with return to", func(t *testing.T) {
postLoginWebhook := hooktest.NewServer()
t.Cleanup(postLoginWebhook.Close)
postLoginWebhook.SetConfig(t, conf.GetProvider(ctx), config.HookStrategyKey(config.ViperKeySelfServiceLoginAfter, identity.CredentialsTypeOIDC.String()))

returnTo := "/foo"
r := newBrowserLoginFlow(t, fmt.Sprintf("%s?return_to=%s", returnTS.URL, returnTo), time.Minute)
action := assertFormValues(t, r.ID, "valid")
res, body := makeRequest(t, "valid", action, url.Values{})
transientPayload := `{"data": "registration-three"}`
res, body := makeRequest(t, "valid", action, url.Values{"transient_payload": {transientPayload}})
assert.True(t, strings.HasSuffix(res.Request.URL.String(), returnTo))
assertIdentity(t, res, body)
postLoginWebhook.AssertTransientPayload(t, transientPayload)
})
})

t.Run("case=register, merge, and complete data", func(t *testing.T) {

for _, tc := range []struct{ name, provider string }{
{name: "idtoken", provider: "valid"},
{name: "userinfo", provider: "claimsViaUserInfo"},
Expand Down

0 comments on commit ad5fb09

Please sign in to comment.