Skip to content

Commit

Permalink
fix: sso provider id is not added to session data when credentials li…
Browse files Browse the repository at this point in the history
…nked (PS-362)
  • Loading branch information
splaunov committed Aug 2, 2024
1 parent 1a70648 commit 1b832af
Show file tree
Hide file tree
Showing 13 changed files with 71 additions and 37 deletions.
8 changes: 6 additions & 2 deletions selfservice/flow/login/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -811,8 +811,12 @@ continueLogin:
sess = session.NewInactiveSession()
}

method := ss.CompletedAuthenticationMethod(r.Context(), sess.AMR)
sess.CompletedLoginForMethod(method)
method, err := ss.CompletedAuthenticationMethod(r.Context(), sess.AMR, nil)
if err != nil {
h.d.LoginFlowErrorHandler().WriteFlowError(w, r, f, group, err)
return
}
sess.CompletedLoginForMethod(*method)
i = interim
break
}
Expand Down
7 changes: 5 additions & 2 deletions selfservice/flow/login/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,11 @@ func (e *HookExecutor) maybeLinkCredentials(ctx context.Context, sess *session.S
return err
}

method := strategy.CompletedAuthenticationMethod(ctx, sess.AMR)
sess.CompletedLoginForMethod(method)
method, err := strategy.CompletedAuthenticationMethod(ctx, sess.AMR, lc.CredentialsConfig)
if err != nil {
return err
}
sess.CompletedLoginForMethod(*method)

return nil
}
Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/login/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type Strategy interface {
RegisterLoginRoutes(*x.RouterPublic)
PopulateLoginMethod(r *http.Request, requestedAAL identity.AuthenticatorAssuranceLevel, sr *Flow) error
Login(w http.ResponseWriter, r *http.Request, f *Flow, sess *session.Session) (i *identity.Identity, err error)
CompletedAuthenticationMethod(ctx context.Context, methods session.AuthenticationMethods) session.AuthenticationMethod
CompletedAuthenticationMethod(ctx context.Context, methods session.AuthenticationMethods, credentialsConfig sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error)
}

type Strategies []Strategy
Expand Down
11 changes: 6 additions & 5 deletions selfservice/strategy/code/strategy_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"database/sql"
"encoding/json"
"github.com/ory/x/sqlxx"
"net/http"
"strings"

Expand Down Expand Up @@ -67,20 +68,20 @@ type updateLoginFlowWithCodeMethod struct {

func (s *Strategy) RegisterLoginRoutes(*x.RouterPublic) {}

func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, amr session.AuthenticationMethods) session.AuthenticationMethod {
func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, amr session.AuthenticationMethods, _ sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error) {
aal1Satisfied := lo.ContainsBy(amr, func(am session.AuthenticationMethod) bool {
return am.Method != identity.CredentialsTypeCodeAuth && am.AAL == identity.AuthenticatorAssuranceLevel1
})
if aal1Satisfied {
return session.AuthenticationMethod{
return &session.AuthenticationMethod{
Method: identity.CredentialsTypeCodeAuth,
AAL: identity.AuthenticatorAssuranceLevel2,
}
}, nil
}
return session.AuthenticationMethod{
return &session.AuthenticationMethod{
Method: identity.CredentialsTypeCodeAuth,
AAL: identity.AuthenticatorAssuranceLevel1,
}
}, nil
}

func (s *Strategy) HandleLoginError(r *http.Request, f *login.Flow, body *updateLoginFlowWithCodeMethod, err error) error {
Expand Down
7 changes: 4 additions & 3 deletions selfservice/strategy/lookup/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package lookup
import (
"context"
"encoding/json"
"github.com/ory/x/sqlxx"

"github.com/pkg/errors"

Expand Down Expand Up @@ -106,9 +107,9 @@ func (s *Strategy) NodeGroup() node.UiNodeGroup {
return node.LookupGroup
}

func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod {
return session.AuthenticationMethod{
func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods, _ sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error) {
return &session.AuthenticationMethod{
Method: s.ID(),
AAL: identity.AuthenticatorAssuranceLevel2,
}
}, nil
}
27 changes: 23 additions & 4 deletions selfservice/strategy/oidc/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"github.com/ory/x/sqlxx"
"net/http"
"net/url"
"path/filepath"
Expand Down Expand Up @@ -719,11 +720,17 @@ func (s *Strategy) NodeGroup() node.UiNodeGroup {
return node.OpenIDConnectGroup
}

func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod {
return session.AuthenticationMethod{
Method: s.ID(),
AAL: identity.AuthenticatorAssuranceLevel1,
func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods, credentialsConfig sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error) {
credentialsOIDCProvider, err := s.getProvider(credentialsConfig)
if err != nil {
return nil, err
}

return &session.AuthenticationMethod{
Method: s.ID(),
AAL: identity.AuthenticatorAssuranceLevel1,
Provider: credentialsOIDCProvider.Provider,
}, nil
}

func (s *Strategy) processIDToken(w http.ResponseWriter, r *http.Request, provider Provider, idToken, idTokenNonce string) (*Claims, error) {
Expand Down Expand Up @@ -840,3 +847,15 @@ func (s *Strategy) encryptOAuth2Tokens(ctx context.Context, token *oauth2.Token)

return et, nil
}

func (s *Strategy) getProvider(credentialsConfig sqlxx.JSONRawMessage) (identity.CredentialsOIDCProvider, error) {
var credentialsOIDCConfig identity.CredentialsOIDC
if err := json.Unmarshal(credentialsConfig, &credentialsOIDCConfig); err != nil {
return identity.CredentialsOIDCProvider{}, err
}
if len(credentialsOIDCConfig.Providers) != 1 {
return identity.CredentialsOIDCProvider{}, errors.New("No oidc provider was set")
}
credentialsOIDCProvider := credentialsOIDCConfig.Providers[0]
return credentialsOIDCProvider, nil
}
8 changes: 2 additions & 6 deletions selfservice/strategy/oidc/strategy_settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,14 +514,10 @@ func (s *Strategy) handleSettingsError(w http.ResponseWriter, r *http.Request, c
}

func (s *Strategy) Link(ctx context.Context, i *identity.Identity, credentialsConfig sqlxx.JSONRawMessage) error {
var credentialsOIDCConfig identity.CredentialsOIDC
if err := json.Unmarshal(credentialsConfig, &credentialsOIDCConfig); err != nil {
credentialsOIDCProvider, err := s.getProvider(credentialsConfig)
if err != nil {
return err
}
if len(credentialsOIDCConfig.Providers) != 1 {
return errors.New("No oidc provider was set")
}
credentialsOIDCProvider := credentialsOIDCConfig.Providers[0]

if err := s.linkCredentials(
ctx,
Expand Down
2 changes: 2 additions & 0 deletions selfservice/strategy/oidc/strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,7 @@ func TestStrategy(t *testing.T) {
}`,
expect: func(t *testing.T, res *http.Response, body []byte) {
require.NotEmpty(t, gjson.GetBytes(body, "session_token").String(), "%s", body)
require.Equal(t, "test-provider", gjson.GetBytes(body, "session.authentication_methods.0.provider").String(), "%s", body)
},
},
{
Expand Down Expand Up @@ -1273,6 +1274,7 @@ func TestStrategy(t *testing.T) {
assert.Equal(t, provider, gjson.GetBytes(i.Credentials["oidc"].Config, "providers.0.provider").String(),
"%s", string(i.Credentials["oidc"].Config[:]))
assert.Contains(t, gjson.GetBytes(body, "authentication_methods").String(), "oidc", "%s", body)
assert.Equal(t, "valid", gjson.GetBytes(body, "authentication_methods.1.provider").String(), "%s", body)
}

t.Run("case=second login is password", func(t *testing.T) {
Expand Down
7 changes: 4 additions & 3 deletions selfservice/strategy/passkey/passkey_strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package passkey
import (
"context"
"encoding/json"
"github.com/ory/x/sqlxx"

"github.com/pkg/errors"

Expand Down Expand Up @@ -88,11 +89,11 @@ func (*Strategy) NodeGroup() node.UiNodeGroup {
return node.PasskeyGroup
}

func (s *Strategy) CompletedAuthenticationMethod(context.Context, session.AuthenticationMethods) session.AuthenticationMethod {
return session.AuthenticationMethod{
func (s *Strategy) CompletedAuthenticationMethod(context.Context, session.AuthenticationMethods, sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error) {
return &session.AuthenticationMethod{
Method: identity.CredentialsTypePasskey,
AAL: identity.AuthenticatorAssuranceLevel1,
}
}, nil
}

func (s *Strategy) CountActiveMultiFactorCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) {
Expand Down
7 changes: 4 additions & 3 deletions selfservice/strategy/password/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package password
import (
"context"
"encoding/json"
"github.com/ory/x/sqlxx"

"github.com/ory/kratos/ui/node"

Expand Down Expand Up @@ -109,11 +110,11 @@ func (s *Strategy) ID() identity.CredentialsType {
return identity.CredentialsTypePassword
}

func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod {
return session.AuthenticationMethod{
func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods, _ sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error) {
return &session.AuthenticationMethod{
Method: s.ID(),
AAL: identity.AuthenticatorAssuranceLevel1,
}
}, nil
}

func (s *Strategy) NodeGroup() node.UiNodeGroup {
Expand Down
7 changes: 4 additions & 3 deletions selfservice/strategy/totp/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package totp
import (
"context"
"encoding/json"
"github.com/ory/x/sqlxx"

"github.com/pkg/errors"
"github.com/pquerna/otp"
Expand Down Expand Up @@ -109,9 +110,9 @@ func (s *Strategy) NodeGroup() node.UiNodeGroup {
return node.TOTPGroup
}

func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod {
return session.AuthenticationMethod{
func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods, _ sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error) {
return &session.AuthenticationMethod{
Method: s.ID(),
AAL: identity.AuthenticatorAssuranceLevel2,
}
}, nil
}
7 changes: 4 additions & 3 deletions selfservice/strategy/webauthn/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package webauthn
import (
"context"
"encoding/json"
"github.com/ory/x/sqlxx"

"github.com/pkg/errors"

Expand Down Expand Up @@ -114,13 +115,13 @@ func (s *Strategy) NodeGroup() node.UiNodeGroup {
return node.WebAuthnGroup
}

func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod {
func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods, _ sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error) {
aal := identity.AuthenticatorAssuranceLevel1
if !s.d.Config().WebAuthnForPasswordless(ctx) {
aal = identity.AuthenticatorAssuranceLevel2
}
return session.AuthenticationMethod{
return &session.AuthenticationMethod{
Method: s.ID(),
AAL: aal,
}
}, nil
}
8 changes: 6 additions & 2 deletions selfservice/strategy/webauthn/strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,20 @@ func TestCompletedAuthenticationMethod(t *testing.T) {
conf, reg := internal.NewFastRegistryWithMocks(t)
strategy := webauthn.NewStrategy(reg)

method, err := strategy.CompletedAuthenticationMethod(context.Background(), session.AuthenticationMethods{}, nil)
assert.NoError(t, err)
assert.Equal(t, session.AuthenticationMethod{
Method: strategy.ID(),
AAL: identity.AuthenticatorAssuranceLevel2,
}, strategy.CompletedAuthenticationMethod(context.Background(), session.AuthenticationMethods{}))
}, *method)

conf.MustSet(ctx, config.ViperKeyWebAuthnPasswordless, true)
method, err = strategy.CompletedAuthenticationMethod(context.Background(), session.AuthenticationMethods{}, nil)
assert.NoError(t, err)
assert.Equal(t, session.AuthenticationMethod{
Method: strategy.ID(),
AAL: identity.AuthenticatorAssuranceLevel1,
}, strategy.CompletedAuthenticationMethod(context.Background(), session.AuthenticationMethods{}))
}, *method)
}

func TestCountActiveFirstFactorCredentials(t *testing.T) {
Expand Down

0 comments on commit 1b832af

Please sign in to comment.