From 5c03098438e49414eb4a83c9d371ade72d8784c8 Mon Sep 17 00:00:00 2001 From: lyric Date: Tue, 19 Jul 2016 09:51:58 +0800 Subject: [PATCH] Add server tests --- README.md | 90 ++++--------- example/server/main.go | 3 +- generates/access_test.go | 5 +- generates/authorize_test.go | 5 +- manage/manage_test.go | 14 +- manage/manager.go | 4 + manage/util_test.go | 6 +- server/server.go | 16 ++- server/server_test.go | 260 ++++++++++++++++++++++++++++++++++++ store/client/temp.go | 4 - store/token/redis.go | 25 +++- store/token/redis_test.go | 7 +- 12 files changed, 348 insertions(+), 91 deletions(-) create mode 100644 server/server_test.go diff --git a/README.md b/README.md index 2ea2876..177bf2b 100644 --- a/README.md +++ b/README.md @@ -1,31 +1,29 @@ -基于Golang的OAuth2服务实现 -======================= - -> 完全模块化、支持http/fasthttp的服务端处理、令牌存储支持redis/mongodb +OAuth 2.0 +========= +> [OAuth 2.0](http://oauth.net/2/) is the next evolution of the OAuth protocol which was originally created in late 2006. [![GoDoc](https://godoc.org/gopkg.in/oauth2.v3?status.svg)](https://godoc.org/gopkg.in/oauth2.v3) [![Go Report Card](https://goreportcard.com/badge/gopkg.in/oauth2.v3)](https://goreportcard.com/report/gopkg.in/oauth2.v3) -获取 ----- +Quick Start +----------- + +### Download and install ``` bash $ go get -u gopkg.in/oauth2.v3/... ``` -HTTP服务端 --------- +### Create file `server.go` ``` go package main import ( - "log" "net/http" "gopkg.in/oauth2.v3/manage" "gopkg.in/oauth2.v3/server" - "gopkg.in/oauth2.v3/store/client" "gopkg.in/oauth2.v3/store/token" ) @@ -33,78 +31,48 @@ func main() { manager := manage.NewRedisManager( &token.RedisConfig{Addr: "192.168.33.70:6379"}, ) - manager.MapClientStorage(client.NewTempStore()) srv := server.NewServer(server.NewConfig(), manager) - + srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) { + // validation and to get the user id + userID = "000000" + return + }) http.HandleFunc("/authorize", func(w http.ResponseWriter, r *http.Request) { - authReq, err := srv.GetAuthorizeRequest(r) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - // TODO: 登录验证、授权处理 - authReq.UserID = "000000" - - err = srv.HandleAuthorizeRequest(w, authReq) + err := srv.HandleAuthorizeRequest(w, r) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) } }) - http.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { err := srv.HandleTokenRequest(w, r) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) } }) - - log.Fatal(http.ListenAndServe(":9096", nil)) + http.ListenAndServe(":9096", nil) } - -``` - -FastHTTP服务端 -------------- - -``` go -srv := server.NewFastServer(server.NewConfig(), manager) - -fasthttp.ListenAndServe(":9096", func(ctx *fasthttp.RequestCtx) { - switch string(ctx.Request.URI().Path()) { - case "/authorize": - authReq, err := srv.GetAuthorizeRequest(ctx) - if err != nil { - ctx.Error(err.Error(), 400) - return - } - authReq.UserID = "000000" - // TODO: 登录验证、授权处理 - err = srv.HandleAuthorizeRequest(ctx, authReq) - if err != nil { - ctx.Error(err.Error(), 400) - } - case "/token": - err := srv.HandleTokenRequest(ctx) - if err != nil { - ctx.Error(err.Error(), 400) - } - } -}) ``` -测试 ----- -> [goconvey](https://github.com/smartystreets/goconvey) +### Build and run ``` bash -$ goconvey -port=9092 +$ go build server.go +$ ./server ``` -范例 ----- +Features +-------- + +* Based on the [RFC 6749](https://tools.ietf.org/html/rfc6749) implementation +* Easy to use +* Modularity +* Flexible +* Elegant -模拟授权码模式的测试范例,请查看[example](/example) +Example +------- +Simulation examples of authorization code model, please check [example](/example) License ------- diff --git a/example/server/main.go b/example/server/main.go index 21c19a1..fe43f61 100644 --- a/example/server/main.go +++ b/example/server/main.go @@ -1,11 +1,10 @@ package main import ( + "fmt" "log" "net/http" - "fmt" - "gopkg.in/oauth2.v3" "gopkg.in/oauth2.v3/manage" "gopkg.in/oauth2.v3/models" diff --git a/generates/access_test.go b/generates/access_test.go index 2b04a79..b431eb7 100644 --- a/generates/access_test.go +++ b/generates/access_test.go @@ -1,10 +1,11 @@ -package generates +package generates_test import ( "testing" "time" "gopkg.in/oauth2.v3" + "gopkg.in/oauth2.v3/generates" "gopkg.in/oauth2.v3/models" . "github.com/smartystreets/goconvey/convey" @@ -20,7 +21,7 @@ func TestAccess(t *testing.T) { UserID: "000000", CreateAt: time.Now(), } - gen := NewAccessGenerate() + gen := generates.NewAccessGenerate() access, refresh, err := gen.Token(data, true) So(err, ShouldBeNil) So(access, ShouldNotBeEmpty) diff --git a/generates/authorize_test.go b/generates/authorize_test.go index b50b230..c62dbd6 100644 --- a/generates/authorize_test.go +++ b/generates/authorize_test.go @@ -1,10 +1,11 @@ -package generates +package generates_test import ( "testing" "time" "gopkg.in/oauth2.v3" + "gopkg.in/oauth2.v3/generates" "gopkg.in/oauth2.v3/models" . "github.com/smartystreets/goconvey/convey" @@ -20,7 +21,7 @@ func TestAuthorize(t *testing.T) { UserID: "000000", CreateAt: time.Now(), } - gen := NewAuthorizeGenerate() + gen := generates.NewAuthorizeGenerate() code, err := gen.Token(data) So(err, ShouldBeNil) So(code, ShouldNotBeEmpty) diff --git a/manage/manage_test.go b/manage/manage_test.go index c7c1c9d..32ce33a 100644 --- a/manage/manage_test.go +++ b/manage/manage_test.go @@ -1,10 +1,11 @@ -package manage +package manage_test import ( "testing" "gopkg.in/oauth2.v3" "gopkg.in/oauth2.v3/generates" + "gopkg.in/oauth2.v3/manage" "gopkg.in/oauth2.v3/models" "gopkg.in/oauth2.v3/store/client" "gopkg.in/oauth2.v3/store/token" @@ -14,7 +15,7 @@ import ( func TestManager(t *testing.T) { Convey("Manager test", t, func() { - manager := NewManager() + manager := manage.NewManager() manager.MapClientModel(models.NewClient()) manager.MapTokenModel(models.NewToken()) @@ -51,11 +52,10 @@ func testManager(manager oauth2.Manager) { So(code, ShouldNotBeEmpty) atParams := &oauth2.TokenGenerateRequest{ - ClientID: reqParams.ClientID, - ClientSecret: "11", - RedirectURI: reqParams.RedirectURI, - Code: code, - IsGenerateRefresh: true, + ClientID: reqParams.ClientID, + ClientSecret: "11", + RedirectURI: reqParams.RedirectURI, + Code: code, } ati, err := manager.GenerateAccessToken(oauth2.AuthorizationCode, atParams) So(err, ShouldBeNil) diff --git a/manage/manager.go b/manage/manager.go index 5211eaa..ba29c83 100644 --- a/manage/manager.go +++ b/manage/manager.go @@ -223,6 +223,10 @@ func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGene if gt == oauth2.AuthorizationCode { ti, terr := m.LoadAccessToken(tgr.Code) if terr != nil { + if terr == errors.ErrInvalidAccessToken { + err = errors.ErrInvalidAuthorizeCode + return + } err = terr return } else if ti.GetRedirectURI() != tgr.RedirectURI || ti.GetClientID() != tgr.ClientID { diff --git a/manage/util_test.go b/manage/util_test.go index f936cbd..1faa89e 100644 --- a/manage/util_test.go +++ b/manage/util_test.go @@ -1,15 +1,17 @@ -package manage +package manage_test import ( "testing" + "gopkg.in/oauth2.v3/manage" + . "github.com/smartystreets/goconvey/convey" ) func TestUtil(t *testing.T) { Convey("Util Test", t, func() { Convey("ValidateURI Test", func() { - err := ValidateURI("http://www.example.com", "http://www.example.com/cb?code=xxx") + err := manage.ValidateURI("http://www.example.com", "http://www.example.com/cb?code=xxx") So(err, ShouldBeNil) }) }) diff --git a/server/server.go b/server/server.go index 8225da0..52c0ca1 100644 --- a/server/server.go +++ b/server/server.go @@ -252,6 +252,10 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) ierr error ) defer func() { + if verr := recover(); verr != nil { + err = fmt.Errorf("%v", verr) + return + } data := s.GetErrorData(rerr, ierr) if data != nil { if req == nil { @@ -303,8 +307,10 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (gt oauth2.GrantType, t ierr = err return } - tgr.ClientID = clientID - tgr.ClientSecret = clientSecret + tgr = &oauth2.TokenGenerateRequest{ + ClientID: clientID, + ClientSecret: clientSecret, + } switch gt { case oauth2.AuthorizationCode: tgr.RedirectURI = r.Form.Get("redirect_uri") @@ -425,7 +431,7 @@ func (s *Server) GetTokenData(ti oauth2.TokenInfo) (data map[string]interface{}) data = map[string]interface{}{ "access_token": ti.GetAccess(), "token_type": s.Config.TokenType, - "expires_in": ti.GetAccessExpiresIn() / time.Second, + "expires_in": int64(ti.GetAccessExpiresIn() / time.Second), } if scope := ti.GetScope(); scope != "" { data["scope"] = scope @@ -444,6 +450,10 @@ func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) (err ierr error ) defer func() { + if verr := recover(); verr != nil { + err = fmt.Errorf("%v", verr) + return + } data := s.GetErrorData(rerr, ierr) if data == nil { data = s.GetTokenData(ti) diff --git a/server/server_test.go b/server/server_test.go new file mode 100644 index 0000000..1867b9f --- /dev/null +++ b/server/server_test.go @@ -0,0 +1,260 @@ +package server_test + +import ( + "errors" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gavv/httpexpect" + "gopkg.in/oauth2.v3/manage" + "gopkg.in/oauth2.v3/models" + "gopkg.in/oauth2.v3/server" + "gopkg.in/oauth2.v3/store/client" + "gopkg.in/oauth2.v3/store/token" +) + +var ( + srv *server.Server + tsrv *httptest.Server + manager *manage.Manager + csrv *httptest.Server +) + +func init() { + manager = manage.NewRedisManager( + &token.RedisConfig{Addr: "192.168.33.70:6379"}, + ) +} + +func testServer(t *testing.T, w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/authorize": + err := srv.HandleAuthorizeRequest(w, r) + if err != nil { + t.Error(err) + } + case "/token": + err := srv.HandleTokenRequest(w, r) + if err != nil { + t.Error(err) + } + } +} + +func TestAuthorizeCode(t *testing.T) { + tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + testServer(t, w, r) + })) + defer tsrv.Close() + e := httpexpect.New(t, tsrv.URL) + + csrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth2": + r.ParseForm() + code, state := r.Form.Get("code"), r.Form.Get("state") + if state != "123" { + t.Error("unrecognized state:", state) + return + } + val := e.POST("/token"). + WithFormField("redirect_uri", csrv.URL+"/oauth2"). + WithFormField("code", code). + WithFormField("grant_type", "authorization_code"). + WithFormField("client_id", "333333"). + WithFormField("client_secret", "33333333"). + Expect(). + Status(http.StatusOK). + JSON().Raw() + + t.Log(val) + } + })) + defer csrv.Close() + + manager.MapClientStorage(client.NewTempStore(&models.Client{ + ID: "333333", + Secret: "33333333", + Domain: csrv.URL, + })) + + srv = server.NewServer(server.NewConfig(), manager) + srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) { + userID = "111111" + return + }) + + e.GET("/authorize"). + WithQuery("response_type", "code"). + WithQuery("client_id", "333333"). + WithQuery("scope", "all"). + WithQuery("state", "123"). + WithQuery("redirect_uri", url.QueryEscape(csrv.URL+"/oauth2")). + Expect().Status(http.StatusOK) +} + +func TestImplicit(t *testing.T) { + tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + testServer(t, w, r) + })) + defer tsrv.Close() + e := httpexpect.New(t, tsrv.URL) + + csrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth2": + t.Log(r.RequestURI) + } + })) + defer csrv.Close() + + manager.MapClientStorage(client.NewTempStore(&models.Client{ + ID: "55555", + Secret: "5555555", + Domain: csrv.URL, + })) + + srv = server.NewServer(server.NewConfig(), manager) + srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) { + userID = "222222" + return + }) + + e.GET("/authorize"). + WithQuery("response_type", "token"). + WithQuery("client_id", "55555"). + WithQuery("scope", "all"). + WithQuery("state", "123"). + WithQuery("redirect_uri", url.QueryEscape(csrv.URL+"/oauth2")). + Expect().Status(http.StatusOK) +} + +func TestPasswordCredentials(t *testing.T) { + tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + testServer(t, w, r) + })) + defer tsrv.Close() + e := httpexpect.New(t, tsrv.URL) + + manager.MapClientStorage(client.NewTempStore(&models.Client{ + ID: "666666", + Secret: "66666666", + Domain: csrv.URL, + })) + + srv = server.NewServer(server.NewConfig(), manager) + srv.SetPasswordAuthorizationHandler(func(username, password string) (userID string, err error) { + if username == "admin" && password == "123456" { + userID = "666666" + return + } + err = errors.New("user not found") + return + }) + + val := e.POST("/token"). + WithFormField("grant_type", "password"). + WithFormField("client_id", "666666"). + WithFormField("client_secret", "66666666"). + WithFormField("username", "admin"). + WithFormField("password", "123456"). + WithFormField("scope", "all"). + Expect(). + Status(http.StatusOK). + JSON().Raw() + + t.Log(val) +} + +func TestClientCredentials(t *testing.T) { + tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + testServer(t, w, r) + })) + defer tsrv.Close() + e := httpexpect.New(t, tsrv.URL) + + manager.MapClientStorage(client.NewTempStore(&models.Client{ + ID: "777777", + Secret: "77777777", + Domain: csrv.URL, + })) + + srv = server.NewServer(server.NewConfig(), manager) + + val := e.POST("/token"). + WithFormField("grant_type", "clientcredentials"). + WithFormField("client_id", "777777"). + WithFormField("client_secret", "77777777"). + WithFormField("scope", "all"). + Expect(). + Status(http.StatusOK). + JSON().Raw() + + t.Log(val) +} + +func TestRefreshing(t *testing.T) { + tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + testServer(t, w, r) + })) + defer tsrv.Close() + e := httpexpect.New(t, tsrv.URL) + + csrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth2": + r.ParseForm() + code, state := r.Form.Get("code"), r.Form.Get("state") + if state != "123" { + t.Error("unrecognized state:", state) + return + } + jval := e.POST("/token"). + WithFormField("redirect_uri", csrv.URL+"/oauth2"). + WithFormField("code", code). + WithFormField("grant_type", "authorization_code"). + WithFormField("client_id", "888888"). + WithFormField("client_secret", "88888888"). + Expect(). + Status(http.StatusOK). + JSON() + + refresh := jval.Object().Value("refresh_token").String().Raw() + + rval := e.POST("/token"). + WithFormField("grant_type", "refreshtoken"). + WithFormField("client_id", "888888"). + WithFormField("client_secret", "88888888"). + WithFormField("scope", "one"). + WithFormField("refresh_token", refresh). + Expect(). + Status(http.StatusOK). + JSON().Raw() + + t.Log(rval) + } + })) + defer csrv.Close() + + manager.MapClientStorage(client.NewTempStore(&models.Client{ + ID: "888888", + Secret: "88888888", + Domain: csrv.URL, + })) + + srv = server.NewServer(server.NewConfig(), manager) + srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) { + userID = "888888" + return + }) + + e.GET("/authorize"). + WithQuery("response_type", "code"). + WithQuery("client_id", "888888"). + WithQuery("scope", "all"). + WithQuery("state", "123"). + WithQuery("redirect_uri", url.QueryEscape(csrv.URL+"/oauth2")). + Expect().Status(http.StatusOK) +} diff --git a/store/client/temp.go b/store/client/temp.go index b9c671e..a59293f 100644 --- a/store/client/temp.go +++ b/store/client/temp.go @@ -1,8 +1,6 @@ package client import ( - "errors" - "gopkg.in/oauth2.v3" "gopkg.in/oauth2.v3/models" ) @@ -33,8 +31,6 @@ type TempStore struct { func (ts *TempStore) GetByID(id string) (cli oauth2.ClientInfo, err error) { if c, ok := ts.data[id]; ok { cli = c - return } - err = errors.New("not found") return } diff --git a/store/token/redis.go b/store/token/redis.go index 43e9dd4..575f26d 100644 --- a/store/token/redis.go +++ b/store/token/redis.go @@ -87,7 +87,10 @@ func (rs *RedisStore) Create(info oauth2.TokenInfo) (err error) { // remove func (rs *RedisStore) remove(key string) (err error) { - _, err = rs.cli.Del(key).Result() + _, verr := rs.cli.Del(key).Result() + if verr != redis.Nil { + err = verr + } return } @@ -105,16 +108,28 @@ func (rs *RedisStore) RemoveByRefresh(refresh string) (err error) { // get func (rs *RedisStore) get(token string) (ti oauth2.TokenInfo, err error) { - tv, err := rs.cli.Get(token).Result() - if err != nil { + tv, verr := rs.cli.Get(token).Result() + if verr != nil { + if verr == redis.Nil { + return + } + err = verr + return + } + result := rs.cli.Get(tv) + if verr := result.Err(); verr != nil { + if verr == redis.Nil { + return + } + err = verr return } - iv, err := rs.cli.Get(tv).Result() + iv, err := result.Bytes() if err != nil { return } var tm models.Token - if verr := json.Unmarshal([]byte(iv), &tm); verr != nil { + if verr := json.Unmarshal(iv, &tm); verr != nil { err = verr return } diff --git a/store/token/redis_test.go b/store/token/redis_test.go index 7a0e1a3..fefa63c 100644 --- a/store/token/redis_test.go +++ b/store/token/redis_test.go @@ -1,20 +1,21 @@ -package token +package token_test import ( "testing" "time" "gopkg.in/oauth2.v3/models" + "gopkg.in/oauth2.v3/store/token" . "github.com/smartystreets/goconvey/convey" ) func TestRedisStore(t *testing.T) { Convey("Test redis store", t, func() { - cfg := &RedisConfig{ + cfg := &token.RedisConfig{ Addr: "192.168.33.70:6379", } - store, err := NewRedisStore(cfg) + store, err := token.NewRedisStore(cfg) So(err, ShouldBeNil) Convey("Test access token store", func() {