From 217243c4826b01307bc0bf4b919bcff2d3b5e0b5 Mon Sep 17 00:00:00 2001 From: lyric Date: Fri, 29 Jul 2016 09:36:47 +0800 Subject: [PATCH] Fix error handling --- server/handler.go | 9 ++++---- server/server.go | 57 ++++++++++++++++++++++++++++------------------- 2 files changed, 38 insertions(+), 28 deletions(-) diff --git a/server/handler.go b/server/handler.go index 7567405..f1879a6 100644 --- a/server/handler.go +++ b/server/handler.go @@ -2,7 +2,6 @@ package server import ( "net/http" - "time" "gopkg.in/oauth2.v3" @@ -25,7 +24,7 @@ type UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (user type PasswordAuthorizationHandler func(username, password string) (userID string, err error) // RefreshingScopeHandler Check the scope of the refreshing token -type RefreshingScopeHandler func(newScope, oldScope string) (allowed bool) +type RefreshingScopeHandler func(newScope, oldScope string) (allowed bool, err error) // ResponseErrorHandler Response error handing type ResponseErrorHandler func(re *errors.Response) @@ -34,13 +33,13 @@ type ResponseErrorHandler func(re *errors.Response) type InternalErrorHandler func(r *http.Request, err error) // AuthorizeScopeHandler Set the authorized scope -type AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string) +type AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error) // AccessTokenExpHandler Set expiration date for the access token -type AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration) +type AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error) // ExtensionFieldsHandler In response to the access token with the extension of the field -type ExtensionFieldsHandler func(w http.ResponseWriter, r *http.Request) (fieldsValue map[string]interface{}) +type ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) // ClientFormHandler Get client data from form func ClientFormHandler(r *http.Request) (clientID, clientSecret string, err error) { diff --git a/server/server.go b/server/server.go index 3429909..523d1cd 100644 --- a/server/server.go +++ b/server/server.go @@ -322,13 +322,21 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) req.UserID = userID // specify the scope of authorization if fn := s.AuthorizeScopeHandler; fn != nil { - if scope := fn(w, r); scope != "" { + scope, verr := fn(w, r) + if verr != nil { + err = verr + return + } else if scope != "" { req.Scope = scope } } // specify the expiration time of access token if fn := s.AccessTokenExpHandler; fn != nil { - if exp := fn(w, r); exp > 0 { + exp, verr := fn(w, r) + if verr != nil { + err = verr + return + } else if exp > 0 { req.AccessTokenExp = exp } } @@ -403,8 +411,7 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe if err != nil { ierr = err return - } - if !allowed { + } else if !allowed { rerr = errors.ErrUnauthorizedClient return } @@ -427,8 +434,7 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe if err != nil { ierr = err return - } - if !allowed { + } else if !allowed { rerr = errors.ErrInvalidScope return } @@ -441,17 +447,23 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe } } case oauth2.Refreshing: - if scope := tgr.Scope; scope != "" { + // check scope + if scope, scopeFn := tgr.Scope, s.RefreshingScopeHandler; scope != "" && scopeFn != nil { rti, err := s.Manager.LoadRefreshToken(tgr.Refresh) if err != nil { - if err == errors.ErrInvalidRefreshToken { - rerr = err + if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { + rerr = errors.ErrInvalidGrant return } ierr = err return } - if fn := s.RefreshingScopeHandler; fn != nil && !fn(scope, rti.GetScope()) { + + allowed, err := scopeFn(scope, rti.GetScope()) + if err != nil { + ierr = err + return + } else if !allowed { rerr = errors.ErrInvalidScope return } @@ -461,8 +473,8 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe if ierr == errors.ErrInvalidClient { rerr = errors.ErrInvalidClient ierr = nil - } else if ierr == errors.ErrInvalidRefreshToken { - rerr = errors.ErrInvalidRefreshToken + } else if ierr == errors.ErrInvalidRefreshToken || ierr == errors.ErrExpiredRefreshToken { + rerr = errors.ErrInvalidGrant ierr = nil } } @@ -484,6 +496,15 @@ func (s *Server) GetTokenData(ti oauth2.TokenInfo) (data map[string]interface{}) if refresh := ti.GetRefresh(); refresh != "" { data["refresh_token"] = refresh } + if fn := s.ExtensionFieldsHandler; fn != nil { + ext := fn(ti) + for k, v := range ext { + if _, ok := data[k]; ok { + continue + } + data[k] = v + } + } return } @@ -504,17 +525,7 @@ func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) (err err = s.resTokenError(w, r, rerr, ierr) return } - tokenData := s.GetTokenData(ti) - if fn := s.ExtensionFieldsHandler; fn != nil { - ext := fn(w, r) - for k, v := range ext { - if _, ok := tokenData[k]; ok { - continue - } - tokenData[k] = v - } - } - err = s.resToken(w, tokenData) + err = s.resToken(w, s.GetTokenData(ti)) return }