Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: oidc provider id is not added to session data when credentials linked #4022

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/client-go/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
Expand Down
8 changes: 6 additions & 2 deletions selfservice/flow/login/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -811,8 +811,12 @@ continueLogin:
sess = session.NewInactiveSession()
}

method := ss.CompletedAuthenticationMethod(r.Context(), sess.AMR)
sess.CompletedLoginForMethod(method)
method, err := ss.CompletedAuthenticationMethod(r.Context(), sess.AMR, nil)
if err != nil {
h.d.LoginFlowErrorHandler().WriteFlowError(w, r, f, group, err)
return
}
sess.CompletedLoginForMethod(*method)
i = interim
break
}
Expand Down
7 changes: 5 additions & 2 deletions selfservice/flow/login/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,11 @@ func (e *HookExecutor) maybeLinkCredentials(ctx context.Context, sess *session.S
return err
}

method := strategy.CompletedAuthenticationMethod(ctx, sess.AMR)
sess.CompletedLoginForMethod(method)
method, err := strategy.CompletedAuthenticationMethod(ctx, sess.AMR, lc.CredentialsConfig)
if err != nil {
return err
}
sess.CompletedLoginForMethod(*method)

return nil
}
Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/login/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type Strategy interface {
RegisterLoginRoutes(*x.RouterPublic)
PopulateLoginMethod(r *http.Request, requestedAAL identity.AuthenticatorAssuranceLevel, sr *Flow) error
Login(w http.ResponseWriter, r *http.Request, f *Flow, sess *session.Session) (i *identity.Identity, err error)
CompletedAuthenticationMethod(ctx context.Context, methods session.AuthenticationMethods) session.AuthenticationMethod
CompletedAuthenticationMethod(ctx context.Context, methods session.AuthenticationMethods, credentialsConfig sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error)
}

type Strategies []Strategy
Expand Down
12 changes: 7 additions & 5 deletions selfservice/strategy/code/strategy_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"net/http"
"strings"

"github.com/ory/x/sqlxx"

"github.com/ory/x/sqlcon"

"github.com/pkg/errors"
Expand Down Expand Up @@ -67,20 +69,20 @@ type updateLoginFlowWithCodeMethod struct {

func (s *Strategy) RegisterLoginRoutes(*x.RouterPublic) {}

func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, amr session.AuthenticationMethods) session.AuthenticationMethod {
func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, amr session.AuthenticationMethods, _ sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error) {
aal1Satisfied := lo.ContainsBy(amr, func(am session.AuthenticationMethod) bool {
return am.Method != identity.CredentialsTypeCodeAuth && am.AAL == identity.AuthenticatorAssuranceLevel1
})
if aal1Satisfied {
return session.AuthenticationMethod{
return &session.AuthenticationMethod{
Method: identity.CredentialsTypeCodeAuth,
AAL: identity.AuthenticatorAssuranceLevel2,
}
}, nil
}
return session.AuthenticationMethod{
return &session.AuthenticationMethod{
Method: identity.CredentialsTypeCodeAuth,
AAL: identity.AuthenticatorAssuranceLevel1,
}
}, nil
}

func (s *Strategy) HandleLoginError(r *http.Request, f *login.Flow, body *updateLoginFlowWithCodeMethod, err error) error {
Expand Down
8 changes: 5 additions & 3 deletions selfservice/strategy/lookup/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"context"
"encoding/json"

"github.com/ory/x/sqlxx"

"github.com/pkg/errors"

"github.com/ory/kratos/continuity"
Expand Down Expand Up @@ -106,9 +108,9 @@ func (s *Strategy) NodeGroup() node.UiNodeGroup {
return node.LookupGroup
}

func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod {
return session.AuthenticationMethod{
func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods, _ sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error) {
return &session.AuthenticationMethod{
Method: s.ID(),
AAL: identity.AuthenticatorAssuranceLevel2,
}
}, nil
}
28 changes: 24 additions & 4 deletions selfservice/strategy/oidc/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
"strings"
"time"

"github.com/ory/x/sqlxx"

"golang.org/x/exp/maps"

"github.com/ory/x/urlx"
Expand Down Expand Up @@ -719,11 +721,17 @@ func (s *Strategy) NodeGroup() node.UiNodeGroup {
return node.OpenIDConnectGroup
}

func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod {
return session.AuthenticationMethod{
Method: s.ID(),
AAL: identity.AuthenticatorAssuranceLevel1,
func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods, credentialsConfig sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error) {
credentialsOIDCProvider, err := s.getProvider(credentialsConfig)
if err != nil {
return nil, err
}

return &session.AuthenticationMethod{
Method: s.ID(),
AAL: identity.AuthenticatorAssuranceLevel1,
Provider: credentialsOIDCProvider.Provider,
}, nil
}

func (s *Strategy) processIDToken(w http.ResponseWriter, r *http.Request, provider Provider, idToken, idTokenNonce string) (*Claims, error) {
Expand Down Expand Up @@ -840,3 +848,15 @@ func (s *Strategy) encryptOAuth2Tokens(ctx context.Context, token *oauth2.Token)

return et, nil
}

func (s *Strategy) getProvider(credentialsConfig sqlxx.JSONRawMessage) (identity.CredentialsOIDCProvider, error) {
var credentialsOIDCConfig identity.CredentialsOIDC
if err := json.Unmarshal(credentialsConfig, &credentialsOIDCConfig); err != nil {
return identity.CredentialsOIDCProvider{}, err
}
if len(credentialsOIDCConfig.Providers) != 1 {
return identity.CredentialsOIDCProvider{}, errors.New("No oidc provider was set")
}
credentialsOIDCProvider := credentialsOIDCConfig.Providers[0]
return credentialsOIDCProvider, nil
}
8 changes: 2 additions & 6 deletions selfservice/strategy/oidc/strategy_settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,14 +514,10 @@ func (s *Strategy) handleSettingsError(w http.ResponseWriter, r *http.Request, c
}

func (s *Strategy) Link(ctx context.Context, i *identity.Identity, credentialsConfig sqlxx.JSONRawMessage) error {
var credentialsOIDCConfig identity.CredentialsOIDC
if err := json.Unmarshal(credentialsConfig, &credentialsOIDCConfig); err != nil {
credentialsOIDCProvider, err := s.getProvider(credentialsConfig)
if err != nil {
return err
}
if len(credentialsOIDCConfig.Providers) != 1 {
return errors.New("No oidc provider was set")
}
credentialsOIDCProvider := credentialsOIDCConfig.Providers[0]

if err := s.linkCredentials(
ctx,
Expand Down
2 changes: 2 additions & 0 deletions selfservice/strategy/oidc/strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,7 @@ func TestStrategy(t *testing.T) {
}`,
expect: func(t *testing.T, res *http.Response, body []byte) {
require.NotEmpty(t, gjson.GetBytes(body, "session_token").String(), "%s", body)
require.Equal(t, "test-provider", gjson.GetBytes(body, "session.authentication_methods.0.provider").String(), "%s", body)
},
},
{
Expand Down Expand Up @@ -1273,6 +1274,7 @@ func TestStrategy(t *testing.T) {
assert.Equal(t, provider, gjson.GetBytes(i.Credentials["oidc"].Config, "providers.0.provider").String(),
"%s", string(i.Credentials["oidc"].Config[:]))
assert.Contains(t, gjson.GetBytes(body, "authentication_methods").String(), "oidc", "%s", body)
assert.Equal(t, "valid", gjson.GetBytes(body, "authentication_methods.1.provider").String(), "%s", body)
}

t.Run("case=second login is password", func(t *testing.T) {
Expand Down
8 changes: 5 additions & 3 deletions selfservice/strategy/passkey/passkey_strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"context"
"encoding/json"

"github.com/ory/x/sqlxx"

"github.com/pkg/errors"

"github.com/ory/kratos/continuity"
Expand Down Expand Up @@ -88,11 +90,11 @@ func (*Strategy) NodeGroup() node.UiNodeGroup {
return node.PasskeyGroup
}

func (s *Strategy) CompletedAuthenticationMethod(context.Context, session.AuthenticationMethods) session.AuthenticationMethod {
return session.AuthenticationMethod{
func (s *Strategy) CompletedAuthenticationMethod(context.Context, session.AuthenticationMethods, sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error) {
return &session.AuthenticationMethod{
Method: identity.CredentialsTypePasskey,
AAL: identity.AuthenticatorAssuranceLevel1,
}
}, nil
}

func (s *Strategy) CountActiveMultiFactorCredentials(cc map[identity.CredentialsType]identity.Credentials) (count int, err error) {
Expand Down
8 changes: 5 additions & 3 deletions selfservice/strategy/password/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"context"
"encoding/json"

"github.com/ory/x/sqlxx"

"github.com/ory/kratos/ui/node"

"github.com/go-playground/validator/v10"
Expand Down Expand Up @@ -109,11 +111,11 @@ func (s *Strategy) ID() identity.CredentialsType {
return identity.CredentialsTypePassword
}

func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod {
return session.AuthenticationMethod{
func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods, _ sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error) {
return &session.AuthenticationMethod{
Method: s.ID(),
AAL: identity.AuthenticatorAssuranceLevel1,
}
}, nil
}

func (s *Strategy) NodeGroup() node.UiNodeGroup {
Expand Down
8 changes: 5 additions & 3 deletions selfservice/strategy/totp/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"context"
"encoding/json"

"github.com/ory/x/sqlxx"

"github.com/pkg/errors"
"github.com/pquerna/otp"

Expand Down Expand Up @@ -109,9 +111,9 @@ func (s *Strategy) NodeGroup() node.UiNodeGroup {
return node.TOTPGroup
}

func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod {
return session.AuthenticationMethod{
func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods, _ sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error) {
return &session.AuthenticationMethod{
Method: s.ID(),
AAL: identity.AuthenticatorAssuranceLevel2,
}
}, nil
}
8 changes: 5 additions & 3 deletions selfservice/strategy/webauthn/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"context"
"encoding/json"

"github.com/ory/x/sqlxx"

"github.com/pkg/errors"

"github.com/ory/kratos/continuity"
Expand Down Expand Up @@ -114,13 +116,13 @@ func (s *Strategy) NodeGroup() node.UiNodeGroup {
return node.WebAuthnGroup
}

func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods) session.AuthenticationMethod {
func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context, _ session.AuthenticationMethods, _ sqlxx.JSONRawMessage) (*session.AuthenticationMethod, error) {
aal := identity.AuthenticatorAssuranceLevel1
if !s.d.Config().WebAuthnForPasswordless(ctx) {
aal = identity.AuthenticatorAssuranceLevel2
}
return session.AuthenticationMethod{
return &session.AuthenticationMethod{
Method: s.ID(),
AAL: aal,
}
}, nil
}
8 changes: 6 additions & 2 deletions selfservice/strategy/webauthn/strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,20 @@ func TestCompletedAuthenticationMethod(t *testing.T) {
conf, reg := internal.NewFastRegistryWithMocks(t)
strategy := webauthn.NewStrategy(reg)

method, err := strategy.CompletedAuthenticationMethod(context.Background(), session.AuthenticationMethods{}, nil)
assert.NoError(t, err)
assert.Equal(t, session.AuthenticationMethod{
Method: strategy.ID(),
AAL: identity.AuthenticatorAssuranceLevel2,
}, strategy.CompletedAuthenticationMethod(context.Background(), session.AuthenticationMethods{}))
}, *method)

conf.MustSet(ctx, config.ViperKeyWebAuthnPasswordless, true)
method, err = strategy.CompletedAuthenticationMethod(context.Background(), session.AuthenticationMethods{}, nil)
assert.NoError(t, err)
assert.Equal(t, session.AuthenticationMethod{
Method: strategy.ID(),
AAL: identity.AuthenticatorAssuranceLevel1,
}, strategy.CompletedAuthenticationMethod(context.Background(), session.AuthenticationMethods{}))
}, *method)
}

func TestCountActiveFirstFactorCredentials(t *testing.T) {
Expand Down