Skip to content

Commit

Permalink
Feature/integrate client credentials nd pcke verifie (#5)
Browse files Browse the repository at this point in the history
* implement client_credentials grant type

Signed-off-by: Houssem Ben Mabrouk <[email protected]>

* include pkce_verifier + upgrade oauth2

Signed-off-by: Houssem Ben Mabrouk <[email protected]>

* append issuer prefix to device redirectURI

Signed-off-by: Houssem Ben Mabrouk <[email protected]>

* fix lint?

Signed-off-by: Houssem Ben Mabrouk <[email protected]>

* fix test

Signed-off-by: Houssem Ben Mabrouk <[email protected]>

---------

Signed-off-by: Houssem Ben Mabrouk <[email protected]>
  • Loading branch information
orange-hbenmabrouk authored Mar 26, 2024
1 parent 7267b11 commit 9f14114
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 10 deletions.
83 changes: 80 additions & 3 deletions connector/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -40,6 +41,11 @@ type Config struct {

Scopes []string `json:"scopes"` // defaults to "profile" and "email"

PKCE struct {
// Configurable key which controls if pkce challenge should be created or not
Enabled bool `json:"enabled"` // defaults to "false"
} `json:"pkce"`

// HostedDomains was an optional list of whitelisted domains when using the OIDC connector with Google.
// Only users from a whitelisted domain were allowed to log in.
// Support for this option was removed from the OIDC connector.
Expand Down Expand Up @@ -247,6 +253,12 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e
promptType = *c.PromptType
}

// pkce
pkceVerifier := ""
if c.PKCE.Enabled {
pkceVerifier = oauth2.GenerateVerifier()
}

clientID := c.ClientID
return &oidcConnector{
provider: provider,
Expand All @@ -261,6 +273,7 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e
verifier: provider.Verifier(
&oidc.Config{ClientID: clientID},
),
pkceVerifier: pkceVerifier,
logger: logger,
cancel: cancel,
httpClient: httpClient,
Expand Down Expand Up @@ -290,6 +303,7 @@ type oidcConnector struct {
redirectURI string
oauth2Config *oauth2.Config
verifier *oidc.IDTokenVerifier
pkceVerifier string
cancel context.CancelFunc
logger log.Logger
httpClient *http.Client
Expand Down Expand Up @@ -328,6 +342,10 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string)
if s.OfflineAccess {
opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", c.promptType))
}

if c.pkceVerifier != "" {
opts = append(opts, oauth2.S256ChallengeOption(c.pkceVerifier))
}
return c.oauth2Config.AuthCodeURL(state, opts...), nil
}

Expand All @@ -351,17 +369,76 @@ const (
exchangeCaller
)

func (c *oidcConnector) getTokenViaClientCredentials() (token *oauth2.Token, err error) {
data := url.Values{
"grant_type": {"client_credentials"},
"client_id": {c.oauth2Config.ClientID},
"client_secret": {c.oauth2Config.ClientSecret},
"scope": {strings.Join(c.oauth2Config.Scopes, " ")},
}

resp, err := c.httpClient.PostForm(c.oauth2Config.Endpoint.TokenURL, data)
if err != nil {
return nil, fmt.Errorf("oidc: failed to get token: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("oidc: issuer returned an error: %v", resp.Status)
}

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("oidc: failed to get read token body: %v", err)
}

type AccessTokenType struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
}
response := AccessTokenType{}
if err = json.Unmarshal(body, &response); err != nil {
return nil, fmt.Errorf("oidc: unable to parse response: %v", err)
}

token = &oauth2.Token{
AccessToken: response.AccessToken,
Expiry: time.Now().Add(time.Second * time.Duration(response.ExpiresIn)),
}
raw := make(map[string]interface{})
json.Unmarshal(body, &raw) // no error checks for optional fields
token = token.WithExtra(raw)

return token, nil
}

func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}
}

ctx := context.WithValue(r.Context(), oauth2.HTTPClient, c.httpClient)
var token *oauth2.Token
if q.Has("code") {
// exchange code to token
var opts []oauth2.AuthCodeOption

token, err := c.oauth2Config.Exchange(ctx, q.Get("code"))
if err != nil {
return identity, fmt.Errorf("oidc: failed to get token: %v", err)
if c.pkceVerifier != "" {
opts = append(opts, oauth2.VerifierOption(c.pkceVerifier))
}

token, err = c.oauth2Config.Exchange(ctx, q.Get("code"), opts...)
if err != nil {
return identity, fmt.Errorf("oidc: failed to get token: %v", err)
}
} else {
// get token via client_credentials
token, err = c.getTokenViaClientCredentials()
if err != nil {
return identity, err
}
}
return c.createIdentity(ctx, identity, token, createCaller)
}
Expand Down
36 changes: 36 additions & 0 deletions connector/oidc/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ func TestHandleCallback(t *testing.T) {
expectPreferredUsername string
expectedEmailField string
token map[string]interface{}
pkce bool
newGroupFromClaims []NewGroupFromClaims
}{
{
Expand Down Expand Up @@ -363,6 +364,40 @@ func TestHandleCallback(t *testing.T) {
"non-string-claim2": 666,
},
},
{
name: "withPKCE",
userIDKey: "", // not configured
userNameKey: "", // not configured
expectUserID: "subvalue",
expectUserName: "namevalue",
expectGroups: []string{"group1", "group2"},
expectedEmailField: "emailvalue",
token: map[string]interface{}{
"sub": "subvalue",
"name": "namevalue",
"groups": []string{"group1", "group2"},
"email": "emailvalue",
"email_verified": true,
},
pkce: true,
},
{
name: "withoutPKCE",
userIDKey: "", // not configured
userNameKey: "", // not configured
expectUserID: "subvalue",
expectUserName: "namevalue",
expectGroups: []string{"group1", "group2"},
expectedEmailField: "emailvalue",
token: map[string]interface{}{
"sub": "subvalue",
"name": "namevalue",
"groups": []string{"group1", "group2"},
"email": "emailvalue",
"email_verified": true,
},
pkce: false,
},
}

for _, tc := range tests {
Expand Down Expand Up @@ -399,6 +434,7 @@ func TestHandleCallback(t *testing.T) {
config.ClaimMapping.EmailKey = tc.emailKey
config.ClaimMapping.GroupsKey = tc.groupsKey
config.ClaimMutations.NewGroupFromClaims = tc.newGroupFromClaims
config.PKCE.Enabled = tc.pkce

conn, err := newConnector(config)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions server/deviceflowhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
// Make device code
deviceCode := storage.NewDeviceCode()

// make user code
// Make user code
userCode := storage.NewUserCode()

// Generate the expire time
Expand Down Expand Up @@ -434,7 +434,7 @@ func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) {
q.Set("client_secret", deviceRequest.ClientSecret)
q.Set("state", deviceRequest.UserCode)
q.Set("response_type", "code")
q.Set("redirect_uri", "/device/callback")
q.Set("redirect_uri", fmt.Sprintf("%s/device/callback", s.issuerURL.Path))
q.Set("scope", strings.Join(deviceRequest.Scopes, " "))
u.RawQuery = q.Encode()

Expand Down
2 changes: 2 additions & 0 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -853,6 +853,8 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
s.withClientFromStorage(w, r, s.handleAuthCode)
case grantTypeRefreshToken:
s.withClientFromStorage(w, r, s.handleRefreshToken)
case grantTypeClientCredentials:
s.withClientFromStorage(w, r, s.handleClientCredentials)
case grantTypePassword:
s.withClientFromStorage(w, r, s.handlePasswordGrant)
case grantTypeTokenExchange:
Expand Down
1 change: 1 addition & 0 deletions server/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ const (
grantTypeImplicit = "implicit"
grantTypePassword = "password"
grantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
grantTypeClientCredentials = "client_credentials"
grantTypeTokenExchange = "urn:ietf:params:oauth:grant-type:token-exchange"
)

Expand Down
61 changes: 61 additions & 0 deletions server/refreshhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,3 +385,64 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry)
s.writeAccessToken(w, resp)
}

func (s *Server) handleClientCredentials(w http.ResponseWriter, r *http.Request, client storage.Client) {
// Parse the fields
if err := r.ParseForm(); err != nil {
s.tokenErrHelper(w, errInvalidRequest, "Couldn't parse data", http.StatusBadRequest)
return
}
q := r.Form

scopes := strings.Fields(q.Get("scope"))
nonce := ""
connID := q.Get("connector_id")

// Which connector
conn, err := s.getConnector(connID)
if err != nil {
s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest)
return
}

callbackConnector, ok := conn.Connector.(connector.CallbackConnector)
if !ok {
s.tokenErrHelper(w, errInvalidRequest, "Requested callback connector does not correct type.", http.StatusBadRequest)
return
}

// Login
identity, err := callbackConnector.HandleCallback(parseScopes(scopes), r)
if err != nil {
s.logger.Errorf("Failed to login user: %v", err)
s.tokenErrHelper(w, errInvalidRequest, "Could not login user", http.StatusBadRequest)
return
}

// Build the claims to send the id token
claims := storage.Claims{
UserID: identity.UserID,
Username: identity.Username,
PreferredUsername: identity.PreferredUsername,
Email: identity.Email,
EmailVerified: identity.EmailVerified,
Groups: identity.Groups,
}

accessToken, _, err := s.newAccessToken(client.ID, claims, scopes, nonce, connID)
if err != nil {
s.logger.Errorf("client grant failed to create new access token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}

idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, nonce, accessToken, "", connID)
if err != nil {
s.logger.Errorf("client grant failed to create new ID token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}

resp := s.toAccessTokenResponse(idToken, accessToken, "", expiry)
s.writeAccessToken(w, resp)
}
1 change: 1 addition & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
grantTypeRefreshToken: true,
grantTypeDeviceCode: true,
grantTypeTokenExchange: true,
grantTypeClientCredentials: true,
}
supportedRes := make(map[string]bool)

Expand Down
11 changes: 6 additions & 5 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi
grantTypeTokenExchange,
grantTypeImplicit,
grantTypePassword,
grantTypeClientCredentials,
},
}
if updateConfig != nil {
Expand Down Expand Up @@ -1631,7 +1632,7 @@ func TestOAuth2DeviceFlow(t *testing.T) {
// Add the Clients to the test server
client := storage.Client{
ID: clientID,
RedirectURIs: []string{deviceCallbackURI},
RedirectURIs: []string{"/non-root-path" + deviceCallbackURI},
Public: true,
}
if err := s.storage.CreateClient(ctx, client); err != nil {
Expand Down Expand Up @@ -1765,7 +1766,7 @@ func TestServerSupportedGrants(t *testing.T) {
{
name: "Simple",
config: func(c *Config) {},
resGrants: []string{grantTypeAuthorizationCode, grantTypeRefreshToken, grantTypeDeviceCode, grantTypeTokenExchange},
resGrants: []string{grantTypeAuthorizationCode, grantTypeClientCredentials, grantTypeRefreshToken, grantTypeDeviceCode, grantTypeTokenExchange},
},
{
name: "Minimal",
Expand All @@ -1775,20 +1776,20 @@ func TestServerSupportedGrants(t *testing.T) {
{
name: "With password connector",
config: func(c *Config) { c.PasswordConnector = "local" },
resGrants: []string{grantTypeAuthorizationCode, grantTypePassword, grantTypeRefreshToken, grantTypeDeviceCode, grantTypeTokenExchange},
resGrants: []string{grantTypeAuthorizationCode, grantTypeClientCredentials, grantTypePassword, grantTypeRefreshToken, grantTypeDeviceCode, grantTypeTokenExchange},
},
{
name: "With token response",
config: func(c *Config) { c.SupportedResponseTypes = append(c.SupportedResponseTypes, responseTypeToken) },
resGrants: []string{grantTypeAuthorizationCode, grantTypeImplicit, grantTypeRefreshToken, grantTypeDeviceCode, grantTypeTokenExchange},
resGrants: []string{grantTypeAuthorizationCode, grantTypeClientCredentials, grantTypeImplicit, grantTypeRefreshToken, grantTypeDeviceCode, grantTypeTokenExchange},
},
{
name: "All",
config: func(c *Config) {
c.PasswordConnector = "local"
c.SupportedResponseTypes = append(c.SupportedResponseTypes, responseTypeToken)
},
resGrants: []string{grantTypeAuthorizationCode, grantTypeImplicit, grantTypePassword, grantTypeRefreshToken, grantTypeDeviceCode, grantTypeTokenExchange},
resGrants: []string{grantTypeAuthorizationCode, grantTypeClientCredentials, grantTypeImplicit, grantTypePassword, grantTypeRefreshToken, grantTypeDeviceCode, grantTypeTokenExchange},
},
}

Expand Down

0 comments on commit 9f14114

Please sign in to comment.