From ad5fb09687f863e7c5d45868d0b8f5ec2d965372 Mon Sep 17 00:00:00 2001 From: hackerman <3372410+aeneasr@users.noreply.github.com> Date: Wed, 7 Aug 2024 15:39:20 +0200 Subject: [PATCH] fix: trigger oidc web hook on sign in after registration (#4027) --- selfservice/strategy/oidc/strategy.go | 5 +++ selfservice/strategy/oidc/strategy_login.go | 2 + .../strategy/oidc/strategy_registration.go | 1 + selfservice/strategy/oidc/strategy_test.go | 42 ++++++++++++++++--- 4 files changed, 44 insertions(+), 6 deletions(-) diff --git a/selfservice/strategy/oidc/strategy.go b/selfservice/strategy/oidc/strategy.go index 901c5636424..603ab5a5ea6 100644 --- a/selfservice/strategy/oidc/strategy.go +++ b/selfservice/strategy/oidc/strategy.go @@ -16,6 +16,8 @@ import ( "strings" "time" + "github.com/ory/x/sqlxx" + "golang.org/x/exp/maps" "github.com/ory/x/urlx" @@ -464,6 +466,7 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt switch a := req.(type) { case *login.Flow: + a.Active = s.ID() a.TransientPayload = cntnr.TransientPayload if ff, err := s.processLogin(w, r, a, et, claims, provider, cntnr); err != nil { if errors.Is(err, flow.ErrCompletedByStrategy) { @@ -477,6 +480,7 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt } return case *registration.Flow: + a.Active = s.ID() a.TransientPayload = cntnr.TransientPayload if ff, err := s.processRegistration(w, r, a, et, claims, provider, cntnr, ""); err != nil { if ff != nil { @@ -487,6 +491,7 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt } return case *settings.Flow: + a.Active = sqlxx.NullString(s.ID()) a.TransientPayload = cntnr.TransientPayload sess, err := s.d.SessionManager().FetchFromRequest(r.Context(), r) if err != nil { diff --git a/selfservice/strategy/oidc/strategy_login.go b/selfservice/strategy/oidc/strategy_login.go index 3b5f7229170..6cb46e7a2bb 100644 --- a/selfservice/strategy/oidc/strategy_login.go +++ b/selfservice/strategy/oidc/strategy_login.go @@ -150,6 +150,8 @@ func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, loginFlo registrationFlow.RawIDTokenNonce = loginFlow.RawIDTokenNonce registrationFlow.RequestURL, err = x.TakeOverReturnToParameter(loginFlow.RequestURL, registrationFlow.RequestURL) registrationFlow.TransientPayload = loginFlow.TransientPayload + registrationFlow.Active = s.ID() + if err != nil { return nil, s.handleError(w, r, loginFlow, provider.Config().ID, nil, err) } diff --git a/selfservice/strategy/oidc/strategy_registration.go b/selfservice/strategy/oidc/strategy_registration.go index 124e8539f6e..999106b374b 100644 --- a/selfservice/strategy/oidc/strategy_registration.go +++ b/selfservice/strategy/oidc/strategy_registration.go @@ -272,6 +272,7 @@ func (s *Strategy) registrationToLogin(w http.ResponseWriter, r *http.Request, r return nil, err } lf.TransientPayload = rf.TransientPayload + lf.Active = s.ID() return lf, nil } diff --git a/selfservice/strategy/oidc/strategy_test.go b/selfservice/strategy/oidc/strategy_test.go index 0f41fd622d0..25d45486391 100644 --- a/selfservice/strategy/oidc/strategy_test.go +++ b/selfservice/strategy/oidc/strategy_test.go @@ -424,21 +424,34 @@ func TestStrategy(t *testing.T) { } t.Run("case=should pass registration", func(t *testing.T) { + postRegistrationWebhook := hooktest.NewServer() + t.Cleanup(postRegistrationWebhook.Close) + postRegistrationWebhook.SetConfig(t, conf.GetProvider(ctx), config.HookStrategyKey(config.ViperKeySelfServiceRegistrationAfter, identity.CredentialsTypeOIDC.String())) + transientPayload := `{"data": "registration-one"}` + r := newBrowserRegistrationFlow(t, returnTS.URL, time.Minute) action := assertFormValues(t, r.ID, "valid") - res, body := makeRequest(t, "valid", action, url.Values{}) + res, body := makeRequest(t, "valid", action, url.Values{"transient_payload": {transientPayload}}) assertIdentity(t, res, body) expectTokens(t, "valid", body) + postRegistrationWebhook.AssertTransientPayload(t, transientPayload) }) t.Run("case=try another registration", func(t *testing.T) { + transientPayload := `{"data": "registration-two"}` + postLoginWebhook := hooktest.NewServer() + t.Cleanup(postLoginWebhook.Close) + postLoginWebhook.SetConfig(t, conf.GetProvider(ctx), config.HookStrategyKey(config.ViperKeySelfServiceLoginAfter, identity.CredentialsTypeOIDC.String())) + returnTo := fmt.Sprintf("%s/home?query=true", returnTS.URL) r := newBrowserRegistrationFlow(t, fmt.Sprintf("%s?return_to=%s", returnTS.URL, url.QueryEscape(returnTo)), time.Minute) action := assertFormValues(t, r.ID, "valid") - res, body := makeRequest(t, "valid", action, url.Values{}) + res, body := makeRequest(t, "valid", action, url.Values{"transient_payload": {transientPayload}}) assert.Equal(t, returnTo, res.Request.URL.String()) assertIdentity(t, res, body) expectTokens(t, "valid", body) + + postLoginWebhook.AssertTransientPayload(t, transientPayload) }) }) @@ -981,31 +994,48 @@ func TestStrategy(t *testing.T) { scope = []string{"openid"} t.Run("case=should pass registration", func(t *testing.T) { + postRegistrationWebhook := hooktest.NewServer() + t.Cleanup(postRegistrationWebhook.Close) + postRegistrationWebhook.SetConfig(t, conf.GetProvider(ctx), config.HookStrategyKey(config.ViperKeySelfServiceRegistrationAfter, identity.CredentialsTypeOIDC.String())) + r := newBrowserRegistrationFlow(t, returnTS.URL, time.Minute) action := assertFormValues(t, r.ID, "valid") - res, body := makeRequest(t, "valid", action, url.Values{}) + transientPayload := `{"data": "registration-one"}` + res, body := makeRequest(t, "valid", action, url.Values{"transient_payload": {transientPayload}}) assertIdentity(t, res, body) + postRegistrationWebhook.AssertTransientPayload(t, transientPayload) }) t.Run("case=should pass second time registration", func(t *testing.T) { + postLoginWebhook := hooktest.NewServer() + t.Cleanup(postLoginWebhook.Close) + postLoginWebhook.SetConfig(t, conf.GetProvider(ctx), config.HookStrategyKey(config.ViperKeySelfServiceLoginAfter, identity.CredentialsTypeOIDC.String())) + r := newBrowserLoginFlow(t, returnTS.URL, time.Minute) action := assertFormValues(t, r.ID, "valid") - res, body := makeRequest(t, "valid", action, url.Values{}) + transientPayload := `{"data": "registration-two"}` + res, body := makeRequest(t, "valid", action, url.Values{"transient_payload": {transientPayload}}) assertIdentity(t, res, body) + postLoginWebhook.AssertTransientPayload(t, transientPayload) }) t.Run("case=should pass third time registration with return to", func(t *testing.T) { + postLoginWebhook := hooktest.NewServer() + t.Cleanup(postLoginWebhook.Close) + postLoginWebhook.SetConfig(t, conf.GetProvider(ctx), config.HookStrategyKey(config.ViperKeySelfServiceLoginAfter, identity.CredentialsTypeOIDC.String())) + returnTo := "/foo" r := newBrowserLoginFlow(t, fmt.Sprintf("%s?return_to=%s", returnTS.URL, returnTo), time.Minute) action := assertFormValues(t, r.ID, "valid") - res, body := makeRequest(t, "valid", action, url.Values{}) + transientPayload := `{"data": "registration-three"}` + res, body := makeRequest(t, "valid", action, url.Values{"transient_payload": {transientPayload}}) assert.True(t, strings.HasSuffix(res.Request.URL.String(), returnTo)) assertIdentity(t, res, body) + postLoginWebhook.AssertTransientPayload(t, transientPayload) }) }) t.Run("case=register, merge, and complete data", func(t *testing.T) { - for _, tc := range []struct{ name, provider string }{ {name: "idtoken", provider: "valid"}, {name: "userinfo", provider: "claimsViaUserInfo"},