Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for Proof Key For Code Exchange (PKCE) in OIDC social providers #4033

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .schemastore/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,12 @@
"enum": ["id_token", "userinfo"],
"default": "id_token",
"examples": ["id_token", "userinfo"]
},
"pkcs_method": {
"title": "PKCS Method",
"description": "PKCSMethod is a config to enable PKCS (Proof Key for Code Exchange) using the generic provider. Can be either `S256` (sends code_challenge and code_challenge_method=S256) to authorization endpoint) and `code_verifier` to token endpoint. Can be `plain` if its impossible to support S256. (sends code verifier == code_challenge and code_challenge_method=plain to authorization endpoint). Can be empty, in which case PKCS is disabled.",
"type": "string",
"enum": ["S256", "plain"]
}
},
"additionalProperties": false,
Expand Down
6 changes: 6 additions & 0 deletions embedx/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,12 @@
"enum": ["id_token", "userinfo"],
"default": "id_token",
"examples": ["id_token", "userinfo"]
},
"pkcs_method": {
"title": "PKCS Method",
"description": "PKCSMethod is a config to enable PKCS (Proof Key for Code Exchange) using the generic provider. Can be either `S256` (sends code_challenge and code_challenge_method=S256) to authorization endpoint) and `code_verifier` to token endpoint. Can be `plain` if its impossible to support S256. (sends code verifier == code_challenge and code_challenge_method=plain to authorization endpoint). Can be empty, in which case PKCS is disabled.",
"type": "string",
"enum": ["S256", "plain"]
}
},
"additionalProperties": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,28 @@
}
}
},
{
"type": "input",
"group": "oidc",
"attributes": {
"name": "provider",
"type": "submit",
"value": "providerWithPKCS",
"disabled": false,
"node_type": "input"
},
"messages": [],
"meta": {
"label": {
"id": 1040002,
"text": "Sign up with providerWithPKCS",
"type": "info",
"context": {
"provider": "providerWithPKCS"
}
}
}
},
{
"type": "input",
"group": "oidc",
Expand Down
7 changes: 7 additions & 0 deletions selfservice/strategy/oidc/provider_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,13 @@ type Configuration struct {
// endpoint to get the claims) or `id_token` (takes the claims from the id
// token). It defaults to `id_token`.
ClaimsSource string `json:"claims_source"`

// PKCSMethod is a config to enable PKCS (Proof Key for Code Exchange)
// using the generic provider. Can be either `S256` (sends code_challenge and code_challenge_method=S256)
// to authorization endpoint) and `code_verifier` to token endpoint.
// Can be `plain` if its impossible to support S256. (sends code verifier == code_challenge and code_challenge_method=plain to authorization endpoint)
// Can be empty, in which case PKCS is disabled.
PKCSMethod string `json:"pkcs_method"`
}

func (p Configuration) Redir(public *url.URL) string {
Expand Down
23 changes: 22 additions & 1 deletion selfservice/strategy/oidc/provider_generic_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ func (g *ProviderGenericOIDC) OAuth2(ctx context.Context) (*oauth2.Config, error

func (g *ProviderGenericOIDC) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption {
var options []oauth2.AuthCodeOption

if g.config.PKCSMethod != "" {
options = g.addPKCSURLOptions(r, options)
}
if isForced(r) {
options = append(options, oauth2.SetAuthURLParam("prompt", "login"))
}
Expand All @@ -96,6 +98,25 @@ func (g *ProviderGenericOIDC) AuthCodeURLOptions(r ider) []oauth2.AuthCodeOption
return options
}

func (g *ProviderGenericOIDC) addPKCSURLOptions(r ider, options []oauth2.AuthCodeOption) []oauth2.AuthCodeOption {
flow, err := g.reg.LoginFlowPersister().GetLoginFlow(context.Background(), r.GetID())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

r could also be a registration flow, correct? In that case, we won't find a login flow. Also, the context needs to be bound to the user's request context, not context.Background().

I think you can just add flow.InternalContexter to the ider interface, and use r directly.

if err != nil {
return options
}
pkcsContext, err := GetPKCSContext(flow)
if err != nil {
return options
}
if pkcsContext.Verifier != "" && pkcsContext.Method == "S256" {
options = append(options, oauth2.S256ChallengeOption(pkcsContext.Verifier))
}
if pkcsContext.Verifier != "" && pkcsContext.Method == "plain" {
options = append(options, oauth2.SetAuthURLParam("code_challenge", string(pkcsContext.Verifier)))
options = append(options, oauth2.SetAuthURLParam("code_challenge_method", string(pkcsContext.Method)))
}
return options
}

func (g *ProviderGenericOIDC) verifyAndDecodeClaimsWithProvider(ctx context.Context, provider *gooidc.Provider, raw string) (*Claims, error) {
token, err := provider.VerifierContext(g.withHTTPClientContext(ctx), &gooidc.Config{ClientID: g.config.ClientID}).Verify(ctx, raw)
if err != nil {
Expand Down
63 changes: 62 additions & 1 deletion selfservice/strategy/oidc/provider_generic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"net/url"
"testing"

"golang.org/x/oauth2"

"github.com/ory/kratos/driver"
"github.com/ory/kratos/driver/config"
"github.com/ory/kratos/internal"
Expand All @@ -35,14 +37,21 @@ func makeOIDCClaims() json.RawMessage {
return claims
}

func makeAuthCodeURL(t *testing.T, r *login.Flow, reg *driver.RegistryDefault) string {
func makeAuthCodeURL(t *testing.T, r *login.Flow, reg *driver.RegistryDefault, pkcsMethods ...string) string {
var pkcsMethod string
if len(pkcsMethods) > 0 {
pkcsMethod = pkcsMethods[0]
} else {
pkcsMethod = ""
}
p := oidc.NewProviderGenericOIDC(&oidc.Configuration{
Provider: "generic",
ID: "valid",
ClientID: "client",
ClientSecret: "secret",
IssuerURL: "https://accounts.google.com",
Mapper: "file://./stub/hydra.schema.json",
PKCSMethod: pkcsMethod,
RequestedClaims: makeOIDCClaims(),
}, reg)
c, err := p.(oidc.OAuth2Provider).OAuth2(context.Background())
Expand Down Expand Up @@ -94,3 +103,55 @@ func TestProviderGenericOIDC_AddAuthCodeURLOptions(t *testing.T) {
assert.Contains(t, makeAuthCodeURL(t, r, reg), "claims="+url.QueryEscape(string(makeOIDCClaims())))
})
}

func TestProviderGenericOIDC_PKCS(t *testing.T) {
ctx := context.Background()
conf, reg := internal.NewFastRegistryWithMocks(t)
conf.MustSet(ctx, config.ViperKeyPublicBaseURL, "https://ory.sh")

t.Run("case=PKCSMethod is set to S256", func(t *testing.T) {
r := &login.Flow{ID: x.NewUUID(), Refresh: true}
reg.LoginFlowPersister().CreateLoginFlow(ctx, r)
err := oidc.SetPKCSContext(r, oidc.PkcsContext{
Method: "S256",
Verifier: oauth2.GenerateVerifier(),
})
require.NoError(t, err)
err = reg.LoginFlowPersister().UpdateLoginFlow(ctx, r)
require.NoError(t, err)
actual, err := url.ParseRequestURI(makeAuthCodeURL(t, r, reg, "S256"))
require.NoError(t, err)
assert.Contains(t, actual.Query(), "code_challenge")
t.Logf("code_challenge: %s", actual.Query().Get("code_challenge"))
assert.Contains(t, actual.Query().Get("code_challenge_method"), "S256")
t.Logf("code_challenge_method: %s", actual.Query().Get("code_challenge_method"))
})
t.Run("case=PKCSMethod is set to plain", func(t *testing.T) {
r := &login.Flow{ID: x.NewUUID(), Refresh: true}
reg.LoginFlowPersister().CreateLoginFlow(ctx, r)
verifier := oauth2.GenerateVerifier()
err := oidc.SetPKCSContext(r, oidc.PkcsContext{
Method: "plain",
Verifier: verifier,
})
require.NoError(t, err)
err = reg.LoginFlowPersister().UpdateLoginFlow(ctx, r)
require.NoError(t, err)
actual, err := url.ParseRequestURI(makeAuthCodeURL(t, r, reg, "plain"))
require.NoError(t, err)
assert.Contains(t, actual.Query(), "code_challenge")
t.Logf("code_challenge: %s", actual.Query().Get("code_challenge"))
assert.Contains(t, actual.Query().Get("code_challenge_method"), "plain")
t.Logf("code_challenge_method: %s", actual.Query().Get("code_challenge_method"))
assert.Equal(t, actual.Query().Get("code_challenge"), verifier)
})
t.Run("case=PKCSMethod is empty", func(t *testing.T) {
r := &login.Flow{ID: x.NewUUID(), Refresh: true}
actual, err := url.ParseRequestURI(makeAuthCodeURL(t, r, reg))
require.NoError(t, err)
assert.NotContains(t, actual.Query(), "code_challenge")
t.Logf("code_challenge: %s", actual.Query().Get("code_challenge"))
assert.NotContains(t, actual.Query(), "code_challenge_method")
t.Logf("code_challenge_method: %s", actual.Query().Get("code_challenge_method"))
})
}
26 changes: 19 additions & 7 deletions selfservice/strategy/oidc/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ import (
"strings"
"time"

"github.com/ory/x/sqlxx"

"golang.org/x/exp/maps"

"github.com/ory/x/sqlxx"

"github.com/ory/x/urlx"

"go.opentelemetry.io/otel/attribute"
Expand Down Expand Up @@ -426,7 +426,7 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt
var et *identity.CredentialsOIDCEncryptedTokens
switch p := provider.(type) {
case OAuth2Provider:
token, err := s.ExchangeCode(r.Context(), provider, code)
token, err := s.ExchangeCode(r.Context(), provider, code, req)
if err != nil {
s.forwardError(w, r, req, s.handleError(w, r, req, pid, nil, err))
return
Expand Down Expand Up @@ -510,7 +510,7 @@ func (s *Strategy) HandleCallback(w http.ResponseWriter, r *http.Request, ps htt
}
}

func (s *Strategy) ExchangeCode(ctx context.Context, provider Provider, code string) (token *oauth2.Token, err error) {
func (s *Strategy) ExchangeCode(ctx context.Context, provider Provider, code string, flow flow.Flow) (token *oauth2.Token, err error) {
ctx, span := s.d.Tracer(ctx).Tracer().Start(ctx, "strategy.oidc.ExchangeCode")
defer otelx.End(span, &err)
span.SetAttributes(attribute.String("provider_id", provider.Config().ID))
Expand All @@ -525,11 +525,23 @@ func (s *Strategy) ExchangeCode(ctx context.Context, provider Provider, code str
return nil, err
}
}

client := s.d.HTTPClient(ctx)
ctx = context.WithValue(ctx, oauth2.HTTPClient, client.HTTPClient)
token, err = te.Exchange(ctx, code)
return token, err
switch loginFlow := flow.(type) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this not work for registration flows?

case *login.Flow:
if provider.Config().PKCSMethod != "" {
pkcsContext, err := GetPKCSContext(loginFlow)
if err != nil {
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Unable to decode PKCS context: %s", err))
}
if pkcsContext.Verifier != "" && (pkcsContext.Method == "S256" || pkcsContext.Method == "plain") {
return te.Exchange(ctx, code, oauth2.VerifierOption(pkcsContext.Verifier))
} else {
return nil, errors.Errorf("Invalid PKCS method: %s or empty verifier: %s", pkcsContext.Method, pkcsContext.Verifier)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be a herodot error as well.

}
}
}
return te.Exchange(ctx, code)
default:
return nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("The chosen provider is not capable of exchanging an OAuth 2.0 code for an access token."))
}
Expand Down
63 changes: 56 additions & 7 deletions selfservice/strategy/oidc/strategy_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,22 @@ import (
"strings"
"time"

"github.com/ory/kratos/selfservice/strategy/idfirst"
"github.com/ory/x/stringsx"

"github.com/ory/kratos/selfservice/flowhelpers"

"github.com/julienschmidt/httprouter"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"golang.org/x/oauth2"

"github.com/ory/kratos/session"
"github.com/ory/kratos/text"

"github.com/ory/kratos/ui/node"
"github.com/ory/x/otelx"
"github.com/ory/x/sqlcon"
"github.com/ory/x/stringsx"

"github.com/ory/kratos/selfservice/flow/registration"

"github.com/ory/kratos/text"
"github.com/ory/kratos/selfservice/flowhelpers"
"github.com/ory/kratos/selfservice/strategy/idfirst"

"github.com/ory/kratos/continuity"

Expand All @@ -48,6 +48,13 @@ func (s *Strategy) RegisterLoginRoutes(r *x.RouterPublic) {
s.setRoutes(r)
}

const internalContextPKCSPath = "pkcs"

type PkcsContext struct {
Method string `json:"method"`
Verifier string `json:"verifier"`
}

// Update Login Flow with OpenID Connect Method
//
// swagger:model updateLoginFlowWithOidcMethod
Expand Down Expand Up @@ -255,6 +262,17 @@ func (s *Strategy) Login(w http.ResponseWriter, r *http.Request, f *login.Flow,
}

state := generateState(f.ID.String())

if provider.Config().PKCSMethod != "" {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also needs to be done for registration. And how about settings?

err := SetPKCSContext(f, PkcsContext{
Method: provider.Config().PKCSMethod,
Verifier: oauth2.GenerateVerifier(),
})
if err != nil {
return nil, s.handleError(w, r, f, pid, nil, err)
}
}

if code, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(ctx, f.ID); hasCode {
state.setCode(code.InitCode)
}
Expand Down Expand Up @@ -386,3 +404,34 @@ func (s *Strategy) PopulateLoginMethodIdentifierFirstCredentials(r *http.Request
func (s *Strategy) PopulateLoginMethodIdentifierFirstIdentification(r *http.Request, f *login.Flow) error {
return s.populateMethod(r, f, text.NewInfoLoginWith)
}

func SetPKCSContext(flow flow.InternalContexter, context PkcsContext) error {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd move this and the next method into the selfservice/flow package. There is probably also some opportunity to make a generic getter/setter, as these methods are very similar to Get/SetDuplicateCredentials.

if flow.GetInternalContext() == nil {
flow.EnsureInternalContext()
}
bytes, err := sjson.SetBytes(
flow.GetInternalContext(),
internalContextPKCSPath,
context,
)
if err != nil {
return err
}
flow.SetInternalContext(bytes)

return nil
}

func GetPKCSContext(flow flow.InternalContexter) (*PkcsContext, error) {
if flow.GetInternalContext() == nil {
flow.EnsureInternalContext()
}
raw := gjson.GetBytes(flow.GetInternalContext(), internalContextPKCSPath)
if !raw.IsObject() {
return nil, nil
}
var context PkcsContext
err := json.Unmarshal([]byte(raw.Raw), &context)

return &context, err
}
4 changes: 3 additions & 1 deletion selfservice/strategy/oidc/strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ func TestStrategy(t *testing.T) {
newOIDCProvider(t, ts, remotePublic, remoteAdmin, "claimsViaUserInfo", func(c *oidc.Configuration) {
c.ClaimsSource = oidc.ClaimsSourceUserInfo
}),
newOIDCProvider(t, ts, remotePublic, remoteAdmin, "providerWithPKCS", func(c *oidc.Configuration) {
c.PKCSMethod = "S256"
}),
oidc.Configuration{
Provider: "generic",
ID: "invalid-issuer",
Expand Down Expand Up @@ -1072,7 +1075,6 @@ func TestStrategy(t *testing.T) {
})
})
}

})

t.Run("case=should fail to register and return fresh login flow if email is already being used by password credentials", func(t *testing.T) {
Expand Down
Loading