Skip to content

Commit

Permalink
Merge pull request #19 from LyricTian/develop
Browse files Browse the repository at this point in the history
Fix error handling
  • Loading branch information
LyricTian authored Jul 29, 2016
2 parents 69079cf + 217243c commit becfcf0
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 28 deletions.
9 changes: 4 additions & 5 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package server

import (
"net/http"

"time"

"gopkg.in/oauth2.v3"
Expand All @@ -25,7 +24,7 @@ type UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (user
type PasswordAuthorizationHandler func(username, password string) (userID string, err error)

// RefreshingScopeHandler Check the scope of the refreshing token
type RefreshingScopeHandler func(newScope, oldScope string) (allowed bool)
type RefreshingScopeHandler func(newScope, oldScope string) (allowed bool, err error)

// ResponseErrorHandler Response error handing
type ResponseErrorHandler func(re *errors.Response)
Expand All @@ -34,13 +33,13 @@ type ResponseErrorHandler func(re *errors.Response)
type InternalErrorHandler func(r *http.Request, err error)

// AuthorizeScopeHandler Set the authorized scope
type AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string)
type AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error)

// AccessTokenExpHandler Set expiration date for the access token
type AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration)
type AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error)

// ExtensionFieldsHandler In response to the access token with the extension of the field
type ExtensionFieldsHandler func(w http.ResponseWriter, r *http.Request) (fieldsValue map[string]interface{})
type ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{})

// ClientFormHandler Get client data from form
func ClientFormHandler(r *http.Request) (clientID, clientSecret string, err error) {
Expand Down
57 changes: 34 additions & 23 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,13 +322,21 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request)
req.UserID = userID
// specify the scope of authorization
if fn := s.AuthorizeScopeHandler; fn != nil {
if scope := fn(w, r); scope != "" {
scope, verr := fn(w, r)
if verr != nil {
err = verr
return
} else if scope != "" {
req.Scope = scope
}
}
// specify the expiration time of access token
if fn := s.AccessTokenExpHandler; fn != nil {
if exp := fn(w, r); exp > 0 {
exp, verr := fn(w, r)
if verr != nil {
err = verr
return
} else if exp > 0 {
req.AccessTokenExp = exp
}
}
Expand Down Expand Up @@ -403,8 +411,7 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe
if err != nil {
ierr = err
return
}
if !allowed {
} else if !allowed {
rerr = errors.ErrUnauthorizedClient
return
}
Expand All @@ -427,8 +434,7 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe
if err != nil {
ierr = err
return
}
if !allowed {
} else if !allowed {
rerr = errors.ErrInvalidScope
return
}
Expand All @@ -441,17 +447,23 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe
}
}
case oauth2.Refreshing:
if scope := tgr.Scope; scope != "" {
// check scope
if scope, scopeFn := tgr.Scope, s.RefreshingScopeHandler; scope != "" && scopeFn != nil {
rti, err := s.Manager.LoadRefreshToken(tgr.Refresh)
if err != nil {
if err == errors.ErrInvalidRefreshToken {
rerr = err
if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken {
rerr = errors.ErrInvalidGrant
return
}
ierr = err
return
}
if fn := s.RefreshingScopeHandler; fn != nil && !fn(scope, rti.GetScope()) {

allowed, err := scopeFn(scope, rti.GetScope())
if err != nil {
ierr = err
return
} else if !allowed {
rerr = errors.ErrInvalidScope
return
}
Expand All @@ -461,8 +473,8 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe
if ierr == errors.ErrInvalidClient {
rerr = errors.ErrInvalidClient
ierr = nil
} else if ierr == errors.ErrInvalidRefreshToken {
rerr = errors.ErrInvalidRefreshToken
} else if ierr == errors.ErrInvalidRefreshToken || ierr == errors.ErrExpiredRefreshToken {
rerr = errors.ErrInvalidGrant
ierr = nil
}
}
Expand All @@ -484,6 +496,15 @@ func (s *Server) GetTokenData(ti oauth2.TokenInfo) (data map[string]interface{})
if refresh := ti.GetRefresh(); refresh != "" {
data["refresh_token"] = refresh
}
if fn := s.ExtensionFieldsHandler; fn != nil {
ext := fn(ti)
for k, v := range ext {
if _, ok := data[k]; ok {
continue
}
data[k] = v
}
}
return
}

Expand All @@ -504,17 +525,7 @@ func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) (err
err = s.resTokenError(w, r, rerr, ierr)
return
}
tokenData := s.GetTokenData(ti)
if fn := s.ExtensionFieldsHandler; fn != nil {
ext := fn(w, r)
for k, v := range ext {
if _, ok := tokenData[k]; ok {
continue
}
tokenData[k] = v
}
}
err = s.resToken(w, tokenData)
err = s.resToken(w, s.GetTokenData(ti))
return
}

Expand Down

0 comments on commit becfcf0

Please sign in to comment.