From da5af9b9c855b0bea62b38ada6ba059e813fe566 Mon Sep 17 00:00:00 2001 From: lyric Date: Tue, 12 Jul 2016 17:25:07 +0800 Subject: [PATCH] Add fasthttp server handle --- README.md | 44 ++++++++-- example/README.md | 5 +- example/fastserver/main.go | 50 +++++++++++ server/authorize.go | 78 +++++++++++++---- server/fastserver.go | 175 +++++++++++++++++++++++++++++++++++++ server/server.go | 89 +++++++++---------- 6 files changed, 365 insertions(+), 76 deletions(-) create mode 100644 example/fastserver/main.go create mode 100644 server/fastserver.go diff --git a/README.md b/README.md index c807e07..e6c9532 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ -OAuth2服务端 -=========== +基于Golang的OAuth2服务实现 +======================= -> 基于Golang实现的OAuth2协议,具有简单化、模块化的特点 +> 完全模块化、支持http/fasthttp的服务端处理、令牌存储支持redis/mongodb [![GoDoc](https://godoc.org/gopkg.in/oauth2.v2?status.svg)](https://godoc.org/gopkg.in/oauth2.v2) [![Go Report Card](https://goreportcard.com/badge/gopkg.in/oauth2.v2)](https://goreportcard.com/report/gopkg.in/oauth2.v2) @@ -9,12 +9,12 @@ OAuth2服务端 获取 ---- -```bash +``` bash $ go get -u gopkg.in/oauth2.v2/... ``` -使用 ----- +HTTP服务端 +-------- ``` go package main @@ -64,15 +64,43 @@ func main() { ``` +FastHTTP服务端 +------------- + +``` go +srv := server.NewFastServer(server.NewConfig(), manager) + +fasthttp.ListenAndServe(":9096", func(ctx *fasthttp.RequestCtx) { + switch string(ctx.Request.URI().Path()) { + case "/authorize": + authReq, err := srv.GetAuthorizeRequest(ctx) + if err != nil { + ctx.Error(err.Error(), 400) + return + } + authReq.UserID = "000000" + // TODO: 登录验证、授权处理 + err = srv.HandleAuthorizeRequest(ctx, authReq) + if err != nil { + ctx.Error(err.Error(), 400) + } + case "/token": + err := srv.HandleTokenRequest(ctx) + if err != nil { + ctx.Error(err.Error(), 400) + } + } +}) +``` + 测试 ---- +> [goconvey](https://github.com/smartystreets/goconvey) ``` bash $ goconvey -port=9092 ``` -> goconvey使用明细[https://github.com/smartystreets/goconvey](https://github.com/smartystreets/goconvey) - 范例 ---- diff --git a/example/README.md b/example/README.md index 2e40d62..ea29300 100644 --- a/example/README.md +++ b/example/README.md @@ -1,8 +1,9 @@ -OAuth2 服务端/客户端模拟 -===================== +OAuth2授权码模式模拟 +================= 运行服务端 -------- +> 运行fasthttp服务端,请使用`cd example/fastserver` ``` $ cd example/server diff --git a/example/fastserver/main.go b/example/fastserver/main.go new file mode 100644 index 0000000..fa07b5d --- /dev/null +++ b/example/fastserver/main.go @@ -0,0 +1,50 @@ +package main + +import ( + "log" + + "github.com/valyala/fasthttp" + "gopkg.in/oauth2.v2/manage" + "gopkg.in/oauth2.v2/models" + "gopkg.in/oauth2.v2/server" + "gopkg.in/oauth2.v2/store/client" + "gopkg.in/oauth2.v2/store/token" +) + +func main() { + // 创建基于redis的oauth2管理实例 + manager := manage.NewRedisManager( + &token.RedisConfig{Addr: "192.168.33.70:6379"}, + ) + // 使用临时客户端存储 + manager.MapClientStorage(client.NewTempStore(&models.Client{ + ID: "222222", + Secret: "22222222", + Domain: "http://localhost:9094", + })) + + srv := server.NewFastServer(server.NewConfig(), manager) + + log.Println("OAuth2 server is running at 9096 port.") + fasthttp.ListenAndServe(":9096", func(ctx *fasthttp.RequestCtx) { + switch string(ctx.Request.URI().Path()) { + case "/authorize": + authReq, err := srv.GetAuthorizeRequest(ctx) + if err != nil { + ctx.Error(err.Error(), 400) + return + } + authReq.UserID = "000000" + // TODO: 登录验证、授权处理 + err = srv.HandleAuthorizeRequest(ctx, authReq) + if err != nil { + ctx.Error(err.Error(), 400) + } + case "/token": + err := srv.HandleTokenRequest(ctx) + if err != nil { + ctx.Error(err.Error(), 400) + } + } + }) +} diff --git a/server/authorize.go b/server/authorize.go index 0f59523..a018e37 100644 --- a/server/authorize.go +++ b/server/authorize.go @@ -1,37 +1,39 @@ package server import ( + "encoding/base64" "net/http" + "strings" + "github.com/valyala/fasthttp" "gopkg.in/oauth2.v2" ) // AuthorizeRequest 授权请求 type AuthorizeRequest struct { - Type oauth2.ResponseType - ClientID string - Scope string - RedirectURI string - State string - UserID string + Type oauth2.ResponseType // 授权类型 + ClientID string // 客户端标识 + Scope string // 授权范围 + RedirectURI string // 重定向URI + State string // 状态 + UserID string // 用户标识 } -// ClientHandler 客户端处理(获取请求的客户端认证信息) -type ClientHandler func(r *http.Request) (clientID, clientSecret string, err error) - -// UserHandler 用户处理(密码模式,根据用户名、密码获取用户标识) -type UserHandler func(username, password string) (userID string, err error) - -// ScopeHandler 授权范围处理(更新令牌时的授权范围检查) -type ScopeHandler func(new, old string) (err error) - // TokenRequestHandler 令牌请求处理 type TokenRequestHandler struct { + // 客户端信息处理 ClientHandler ClientHandler - UserHandler UserHandler - ScopeHandler ScopeHandler + // 客户端信息处理(基于fasthttp) + ClientFastHandler ClientFastHandler + // 用户信息处理 + UserHandler UserHandler + // 授权范围处理 + ScopeHandler ScopeHandler } +// ClientHandler 获取请求的客户端认证信息 +type ClientHandler func(r *http.Request) (clientID, clientSecret string, err error) + // ClientFormHandler 客户端表单信息 func ClientFormHandler(r *http.Request) (clientID, clientSecret string, err error) { clientID = r.Form.Get("client_id") @@ -53,3 +55,45 @@ func ClientBasicHandler(r *http.Request) (clientID, clientSecret string, err err clientSecret = password return } + +// ClientFastHandler 基于fasthttp获取客户端认证信息 +type ClientFastHandler func(ctx *fasthttp.RequestCtx) (clientID, clientSecret string, err error) + +// ClientFormFastHandler 客户端表单信息(基于fasthttp) +func ClientFormFastHandler(ctx *fasthttp.RequestCtx) (clientID, clientSecret string, err error) { + clientID = string(ctx.FormValue("client_id")) + clientSecret = string(ctx.FormValue("client_secret")) + if clientID == "" || clientSecret == "" { + err = ErrAuthorizationFormInvalid + } + return +} + +// ClientBasicFastHandler 客户端基础认证信息(基于fasthttp) +func ClientBasicFastHandler(ctx *fasthttp.RequestCtx) (clientID, clientSecret string, err error) { + auth := string(ctx.Request.Header.Peek("Authorization")) + const prefix = "Basic " + if auth == "" || !strings.HasPrefix(auth, prefix) { + err = ErrAuthorizationHeaderInvalid + return + } + c, err := base64.StdEncoding.DecodeString(auth[len(prefix):]) + if err != nil { + return + } + cs := string(c) + s := strings.IndexByte(cs, ':') + if s < 0 { + err = ErrAuthorizationHeaderInvalid + return + } + clientID = cs[:s] + clientSecret = cs[s+1:] + return +} + +// UserHandler 密码模式下,根据用户名、密码获取用户标识 +type UserHandler func(username, password string) (userID string, err error) + +// ScopeHandler 更新令牌时的授权范围检查 +type ScopeHandler func(new, old string) (err error) diff --git a/server/fastserver.go b/server/fastserver.go new file mode 100644 index 0000000..4203b7a --- /dev/null +++ b/server/fastserver.go @@ -0,0 +1,175 @@ +package server + +import ( + "encoding/json" + "net/url" + "time" + + "github.com/valyala/fasthttp" + "gopkg.in/oauth2.v2" +) + +// NewFastServer 创建基于fasthttp的OAuth2服务实例 +func NewFastServer(cfg *Config, manager oauth2.Manager) *FastServer { + srv := &FastServer{} + srv.cfg = cfg + srv.manager = manager + srv.SetClientHandler(ClientFormFastHandler) + return srv +} + +// FastServer 基于fasthttp(https://github.com/valyala/fasthttp)的OAuth2服务处理 +type FastServer struct { + Server +} + +// SetClientHandler 设置客户端处理 +func (fs *FastServer) SetClientHandler(handler ClientFastHandler) { + fs.cfg.Handler.ClientFastHandler = handler +} + +// GetAuthorizeRequest 获取授权请求参数 +func (fs *FastServer) GetAuthorizeRequest(ctx *fasthttp.RequestCtx) (authReq *AuthorizeRequest, err error) { + if !ctx.IsGet() { + err = ErrRequestMethodInvalid + return + } + redirectURI, err := url.QueryUnescape(string(ctx.FormValue("redirect_uri"))) + if err != nil { + return + } + authReq = &AuthorizeRequest{ + Type: oauth2.ResponseType(string(ctx.FormValue("response_type"))), + RedirectURI: redirectURI, + State: string(ctx.FormValue("state")), + Scope: string(ctx.FormValue("scope")), + ClientID: string(ctx.FormValue("client_id")), + } + if authReq.Type == "" || !fs.checkResponseType(authReq.Type) { + err = ErrResponseTypeInvalid + } else if authReq.ClientID == "" { + err = ErrClientInvalid + } + return +} + +// HandleAuthorizeRequest 处理授权请求 +func (fs *FastServer) HandleAuthorizeRequest(ctx *fasthttp.RequestCtx, authReq *AuthorizeRequest) (err error) { + if authReq.UserID == "" { + err = ErrUserInvalid + return + } + tgr := &oauth2.TokenGenerateRequest{ + ClientID: authReq.ClientID, + UserID: authReq.UserID, + RedirectURI: authReq.RedirectURI, + Scope: authReq.Scope, + } + ti, terr := fs.manager.GenerateAuthToken(oauth2.Code, tgr) + if terr != nil { + err = terr + return + } + redirectURI, err := fs.GetRedirectURI(authReq, ti) + if err != nil { + return + } + ctx.Redirect(redirectURI, 302) + return +} + +// HandleTokenRequest 处理令牌请求 +func (fs *FastServer) HandleTokenRequest(ctx *fasthttp.RequestCtx) (err error) { + if !ctx.IsPost() { + err = ErrRequestMethodInvalid + return + } + gt := oauth2.GrantType(string(ctx.FormValue("grant_type"))) + if gt == "" || !fs.checkGrantType(gt) { + err = ErrGrantTypeInvalid + return + } + + var ti oauth2.TokenInfo + clientID, clientSecret, err := fs.cfg.Handler.ClientFastHandler(ctx) + if err != nil { + return + } + tgr := &oauth2.TokenGenerateRequest{ + ClientID: clientID, + ClientSecret: clientSecret, + } + + switch gt { + case oauth2.AuthorizationCodeCredentials: + tgr.RedirectURI = string(ctx.FormValue("redirect_uri")) + tgr.Code = string(ctx.FormValue("code")) + tgr.IsGenerateRefresh = true + ti, err = fs.manager.GenerateAccessToken(oauth2.AuthorizationCodeCredentials, tgr) + case oauth2.PasswordCredentials: + userID, uerr := fs.cfg.Handler.UserHandler(string(ctx.FormValue("username")), string(ctx.FormValue("password"))) + if uerr != nil { + err = uerr + return + } + tgr.UserID = userID + tgr.Scope = string(ctx.FormValue("scope")) + tgr.IsGenerateRefresh = true + ti, err = fs.manager.GenerateAccessToken(oauth2.PasswordCredentials, tgr) + case oauth2.ClientCredentials: + tgr.Scope = string(ctx.FormValue("scope")) + ti, err = fs.manager.GenerateAccessToken(oauth2.ClientCredentials, tgr) + case oauth2.RefreshCredentials: + tgr.Refresh = string(ctx.FormValue("refresh_token")) + tgr.Scope = string(ctx.FormValue("scope")) + if tgr.Scope != "" { // 检查授权范围 + rti, rerr := fs.manager.LoadRefreshToken(tgr.Refresh) + if rerr != nil { + err = rerr + return + } else if rti.GetClientID() != tgr.ClientID { + err = ErrRefreshInvalid + return + } else if verr := fs.cfg.Handler.ScopeHandler(tgr.Scope, rti.GetScope()); verr != nil { + err = verr + return + } + } + ti, err = fs.manager.RefreshAccessToken(tgr) + if err == nil { + ti.SetRefresh("") + } + } + + if err != nil { + return + } + err = fs.ResJSON(ctx, ti) + return +} + +// ResJSON 响应Json数据 +func (fs *FastServer) ResJSON(ctx *fasthttp.RequestCtx, ti oauth2.TokenInfo) (err error) { + data := map[string]interface{}{ + "access_token": ti.GetAccess(), + "token_type": fs.cfg.TokenType, + "expires_in": ti.GetAccessExpiresIn() / time.Second, + } + if scope := ti.GetScope(); scope != "" { + data["scope"] = scope + } + if refresh := ti.GetRefresh(); refresh != "" { + data["refresh_token"] = refresh + } + buf, err := json.Marshal(data) + if err != nil { + return + } + ctx.Response.Header.Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate") + ctx.Response.Header.Set("Pragma", "no-cache") + ctx.Response.Header.Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT") + ctx.SetContentType("application/json;charset=UTF-8") + ctx.SetStatusCode(200) + _, err = ctx.Write(buf) + return nil +} diff --git a/server/server.go b/server/server.go index 590841b..15db1db 100644 --- a/server/server.go +++ b/server/server.go @@ -115,33 +115,58 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, authReq *Authoriz RedirectURI: authReq.RedirectURI, Scope: authReq.Scope, } - ti, terr := s.manager.GenerateAuthToken(oauth2.Code, tgr) - if terr != nil { - err = terr + ti, err := s.manager.GenerateAuthToken(oauth2.Code, tgr) + if err != nil { + return + } + redirectURI, err := s.GetRedirectURI(authReq, ti) + if err != nil { return } - s.ResRedirectURI(w, authReq, ti) + w.Header().Set("Location", redirectURI) + w.WriteHeader(302) + return +} + +// GetRedirectURI 获取重定向URI +func (s *Server) GetRedirectURI(authReq *AuthorizeRequest, ti oauth2.TokenInfo) (uri string, err error) { + u, err := url.Parse(authReq.RedirectURI) + if err != nil { + return + } + q := u.Query() + q.Set("state", authReq.State) + switch authReq.Type { + case oauth2.Code: + q.Set("code", ti.GetAccess()) + u.RawQuery = q.Encode() + case oauth2.Token: + q.Set("access_token", ti.GetAccess()) + q.Set("token_type", s.cfg.TokenType) + q.Set("expires_in", strconv.FormatInt(int64(ti.GetAccessExpiresIn()/time.Second), 10)) + q.Set("scope", ti.GetScope()) + u.RawQuery = "" + u.Fragment, err = url.QueryUnescape(q.Encode()) + if err != nil { + return + } + } + uri = u.String() return } // HandleTokenRequest 处理令牌请求 -// cli 获取客户端信息 -// user 获取用户信息 func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) (err error) { if r.Method != "POST" { err = ErrRequestMethodInvalid return } - if verr := r.ParseForm(); verr != nil { - err = verr - return - } + r.ParseForm() gt := oauth2.GrantType(r.Form.Get("grant_type")) if gt == "" || !s.checkGrantType(gt) { err = ErrGrantTypeInvalid return } - var ti oauth2.TokenInfo clientID, clientSecret, err := s.cfg.Handler.ClientHandler(r) if err != nil { @@ -151,8 +176,7 @@ func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) (err ClientID: clientID, ClientSecret: clientSecret, } - - switch oauth2.GrantType(r.Form.Get("grant_type")) { + switch gt { case oauth2.AuthorizationCodeCredentials: tgr.RedirectURI = r.Form.Get("redirect_uri") tgr.Code = r.Form.Get("code") @@ -200,41 +224,6 @@ func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) (err return } -func (s *Server) handleReponse(w http.ResponseWriter) { - w.Header().Add("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate") - w.Header().Add("Pragma", "no-cache") - w.Header().Add("Expires", "Fri, 01 Jan 1990 00:00:00 GMT") -} - -// ResRedirectURI 响应数据到重定向URI -func (s *Server) ResRedirectURI(w http.ResponseWriter, authReq *AuthorizeRequest, ti oauth2.TokenInfo) (err error) { - u, err := url.Parse(authReq.RedirectURI) - if err != nil { - return - } - q := u.Query() - q.Set("state", authReq.State) - switch authReq.Type { - case oauth2.Code: - q.Set("code", ti.GetAccess()) - u.RawQuery = q.Encode() - case oauth2.Token: - q.Set("access_token", ti.GetAccess()) - q.Set("token_type", s.cfg.TokenType) - q.Set("expires_in", strconv.FormatInt(int64(ti.GetAccessExpiresIn()/time.Second), 10)) - q.Set("scope", ti.GetScope()) - u.RawQuery = "" - u.Fragment, err = url.QueryUnescape(q.Encode()) - if err != nil { - return - } - } - s.handleReponse(w) - w.Header().Add("Location", u.String()) - w.WriteHeader(302) - return -} - // ResJSON 响应Json数据 func (s *Server) ResJSON(w http.ResponseWriter, ti oauth2.TokenInfo) (err error) { data := map[string]interface{}{ @@ -248,7 +237,9 @@ func (s *Server) ResJSON(w http.ResponseWriter, ti oauth2.TokenInfo) (err error) if refresh := ti.GetRefresh(); refresh != "" { data["refresh_token"] = refresh } - s.handleReponse(w) + w.Header().Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate") + w.Header().Set("Pragma", "no-cache") + w.Header().Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT") w.Header().Set("Content-Type", "application/json;charset=UTF-8") w.WriteHeader(http.StatusOK) return json.NewEncoder(w).Encode(data)