diff --git a/README.md b/README.md index e6c9532..177bf2b 100644 --- a/README.md +++ b/README.md @@ -1,111 +1,78 @@ -基于Golang的OAuth2服务实现 -======================= +OAuth 2.0 +========= +> [OAuth 2.0](http://oauth.net/2/) is the next evolution of the OAuth protocol which was originally created in late 2006. -> 完全模块化、支持http/fasthttp的服务端处理、令牌存储支持redis/mongodb +[![GoDoc](https://godoc.org/gopkg.in/oauth2.v3?status.svg)](https://godoc.org/gopkg.in/oauth2.v3) +[![Go Report Card](https://goreportcard.com/badge/gopkg.in/oauth2.v3)](https://goreportcard.com/report/gopkg.in/oauth2.v3) -[![GoDoc](https://godoc.org/gopkg.in/oauth2.v2?status.svg)](https://godoc.org/gopkg.in/oauth2.v2) -[![Go Report Card](https://goreportcard.com/badge/gopkg.in/oauth2.v2)](https://goreportcard.com/report/gopkg.in/oauth2.v2) +Quick Start +----------- -获取 ----- +### Download and install ``` bash -$ go get -u gopkg.in/oauth2.v2/... +$ go get -u gopkg.in/oauth2.v3/... ``` -HTTP服务端 --------- +### Create file `server.go` ``` go package main import ( - "log" "net/http" - "gopkg.in/oauth2.v2/manage" - "gopkg.in/oauth2.v2/models" - "gopkg.in/oauth2.v2/server" - "gopkg.in/oauth2.v2/store/client" - "gopkg.in/oauth2.v2/store/token" + "gopkg.in/oauth2.v3/manage" + "gopkg.in/oauth2.v3/server" + "gopkg.in/oauth2.v3/store/token" ) func main() { manager := manage.NewRedisManager( &token.RedisConfig{Addr: "192.168.33.70:6379"}, ) - manager.MapClientStorage(client.NewTempStore()) srv := server.NewServer(server.NewConfig(), manager) - + srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) { + // validation and to get the user id + userID = "000000" + return + }) http.HandleFunc("/authorize", func(w http.ResponseWriter, r *http.Request) { - authReq, err := srv.GetAuthorizeRequest(r) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - // TODO: 登录验证、授权处理 - authReq.UserID = "000000" - - err = srv.HandleAuthorizeRequest(w, authReq) + err := srv.HandleAuthorizeRequest(w, r) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) } }) - http.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { err := srv.HandleTokenRequest(w, r) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) } }) - - log.Fatal(http.ListenAndServe(":9096", nil)) + http.ListenAndServe(":9096", nil) } - -``` - -FastHTTP服务端 -------------- - -``` go -srv := server.NewFastServer(server.NewConfig(), manager) - -fasthttp.ListenAndServe(":9096", func(ctx *fasthttp.RequestCtx) { - switch string(ctx.Request.URI().Path()) { - case "/authorize": - authReq, err := srv.GetAuthorizeRequest(ctx) - if err != nil { - ctx.Error(err.Error(), 400) - return - } - authReq.UserID = "000000" - // TODO: 登录验证、授权处理 - err = srv.HandleAuthorizeRequest(ctx, authReq) - if err != nil { - ctx.Error(err.Error(), 400) - } - case "/token": - err := srv.HandleTokenRequest(ctx) - if err != nil { - ctx.Error(err.Error(), 400) - } - } -}) ``` -测试 ----- -> [goconvey](https://github.com/smartystreets/goconvey) +### Build and run ``` bash -$ goconvey -port=9092 +$ go build server.go +$ ./server ``` -范例 ----- +Features +-------- + +* Based on the [RFC 6749](https://tools.ietf.org/html/rfc6749) implementation +* Easy to use +* Modularity +* Flexible +* Elegant -模拟授权码模式的测试范例,请查看[example](/example) +Example +------- +Simulation examples of authorization code model, please check [example](/example) License ------- diff --git a/const.go b/const.go index 29d9efd..49cfe2e 100644 --- a/const.go +++ b/const.go @@ -1,12 +1,12 @@ package oauth2 -// ResponseType 定义授权类型 +// ResponseType Response Type type ResponseType string const ( - // Code 授权码类型 + // Code Authorization code type Code ResponseType = "code" - // Token 令牌类型 + // Token Token type Token ResponseType = "token" ) @@ -14,18 +14,20 @@ func (rt ResponseType) String() string { return string(rt) } -// GrantType 定义授权模式 +// GrantType Authorization Grant type GrantType string const ( - // AuthorizationCodeCredentials 授权码模式 - AuthorizationCodeCredentials GrantType = "authorization_code" - // PasswordCredentials 密码模式 + // AuthorizationCode Authorization Code + AuthorizationCode GrantType = "authorization_code" + // PasswordCredentials Resource Owner Password Credentials PasswordCredentials GrantType = "password" - // ClientCredentials 客户端模式 + // ClientCredentials Client Credentials ClientCredentials GrantType = "clientcredentials" - // RefreshCredentials 更新令牌模式 - RefreshCredentials GrantType = "refreshtoken" + // Refreshing Refresh Token + Refreshing GrantType = "refreshtoken" + // Implicit Implicit Grant + Implicit GrantType = "__implicit" ) func (gt GrantType) String() string { diff --git a/errors/error.go b/errors/error.go new file mode 100644 index 0000000..2e36800 --- /dev/null +++ b/errors/error.go @@ -0,0 +1,55 @@ +package errors + +import "errors" + +var ( + // ErrUnauthorizedClient unauthorized client + ErrUnauthorizedClient = errors.New("unauthorized_client") + + // ErrAccessDenied access denied + ErrAccessDenied = errors.New("access_denied") + + // ErrUnsupportedResponseType unsupported response type + ErrUnsupportedResponseType = errors.New("unsupported_response_type") + + // ErrInvalidScope invalid scope + ErrInvalidScope = errors.New("invalid_scope") + + // ErrInvalidRequest invalid request + ErrInvalidRequest = errors.New("invalid_request") + + // ErrInvalidClient invalid client + ErrInvalidClient = errors.New("invalid_client") + + // ErrInvalidGrant invalid grant + ErrInvalidGrant = errors.New("invalid_grant") + + // ErrUnsupportedGrantType unsupported grant type + ErrUnsupportedGrantType = errors.New("unsupported_grant_type") + + // ErrServerError server error + ErrServerError = errors.New("server_error") +) + +var ( + // ErrNilValue Nil Value + ErrNilValue = errors.New("nil value") + + // ErrInvalidRedirectURI invalid redirect uri + ErrInvalidRedirectURI = errors.New("invalid redirect uri") + + // ErrInvalidAuthorizeCode invalid authorize code + ErrInvalidAuthorizeCode = errors.New("invalid authorize code") + + // ErrInvalidAccessToken invalid access token + ErrInvalidAccessToken = errors.New("invalid access token") + + // ErrInvalidRefreshToken invalid refresh token + ErrInvalidRefreshToken = errors.New("invalid refresh token") + + // ErrExpiredAccessToken expired access token + ErrExpiredAccessToken = errors.New("expired access token") + + // ErrExpiredRefreshToken expired refresh token + ErrExpiredRefreshToken = errors.New("expired refresh token") +) diff --git a/example/README.md b/example/README.md index ea29300..8b19fd2 100644 --- a/example/README.md +++ b/example/README.md @@ -1,25 +1,24 @@ -OAuth2授权码模式模拟 -================= +Authorization code simulation +============================= -运行服务端 --------- -> 运行fasthttp服务端,请使用`cd example/fastserver` +Run Server +--------- -``` +``` bash $ cd example/server $ go run main.go ``` -运行客户端 --------- +Run Client +---------- ``` $ cd example/client $ go run main.go ``` -打开浏览器 --------- +Open the browser +---------------- [http://localhost:9094](http://localhost:9094) diff --git a/example/client/main.go b/example/client/main.go index 031a194..8aaa1a5 100644 --- a/example/client/main.go +++ b/example/client/main.go @@ -54,6 +54,6 @@ func main() { io.Copy(w, resp.Body) }) - log.Println("OAuth2 client is running at 9094 port.") + log.Println("Client is running at 9094 port.") log.Fatal(http.ListenAndServe(":9094", nil)) } diff --git a/example/fastserver/main.go b/example/fastserver/main.go deleted file mode 100644 index fa07b5d..0000000 --- a/example/fastserver/main.go +++ /dev/null @@ -1,50 +0,0 @@ -package main - -import ( - "log" - - "github.com/valyala/fasthttp" - "gopkg.in/oauth2.v2/manage" - "gopkg.in/oauth2.v2/models" - "gopkg.in/oauth2.v2/server" - "gopkg.in/oauth2.v2/store/client" - "gopkg.in/oauth2.v2/store/token" -) - -func main() { - // 创建基于redis的oauth2管理实例 - manager := manage.NewRedisManager( - &token.RedisConfig{Addr: "192.168.33.70:6379"}, - ) - // 使用临时客户端存储 - manager.MapClientStorage(client.NewTempStore(&models.Client{ - ID: "222222", - Secret: "22222222", - Domain: "http://localhost:9094", - })) - - srv := server.NewFastServer(server.NewConfig(), manager) - - log.Println("OAuth2 server is running at 9096 port.") - fasthttp.ListenAndServe(":9096", func(ctx *fasthttp.RequestCtx) { - switch string(ctx.Request.URI().Path()) { - case "/authorize": - authReq, err := srv.GetAuthorizeRequest(ctx) - if err != nil { - ctx.Error(err.Error(), 400) - return - } - authReq.UserID = "000000" - // TODO: 登录验证、授权处理 - err = srv.HandleAuthorizeRequest(ctx, authReq) - if err != nil { - ctx.Error(err.Error(), 400) - } - case "/token": - err := srv.HandleTokenRequest(ctx) - if err != nil { - ctx.Error(err.Error(), 400) - } - } - }) -} diff --git a/example/server/main.go b/example/server/main.go index c82f4ff..fe43f61 100644 --- a/example/server/main.go +++ b/example/server/main.go @@ -1,22 +1,23 @@ package main import ( + "fmt" "log" "net/http" - "gopkg.in/oauth2.v2/manage" - "gopkg.in/oauth2.v2/models" - "gopkg.in/oauth2.v2/server" - "gopkg.in/oauth2.v2/store/client" - "gopkg.in/oauth2.v2/store/token" + "gopkg.in/oauth2.v3" + "gopkg.in/oauth2.v3/manage" + "gopkg.in/oauth2.v3/models" + "gopkg.in/oauth2.v3/server" + "gopkg.in/oauth2.v3/store/client" + "gopkg.in/oauth2.v3/store/token" ) func main() { - // 创建基于redis的oauth2管理实例 manager := manage.NewRedisManager( &token.RedisConfig{Addr: "192.168.33.70:6379"}, ) - // 使用临时客户端存储 + // Create the client temporary storage manager.MapClientStorage(client.NewTempStore(&models.Client{ ID: "222222", Secret: "22222222", @@ -24,16 +25,18 @@ func main() { })) srv := server.NewServer(server.NewConfig(), manager) + srv.SetAllowedResponseType(oauth2.Code) + srv.SetAllowedGrantType(oauth2.AuthorizationCode) + srv.SetErrorHandler(func(err error) { + fmt.Println("OAuth2 Error:", err.Error()) + }) + srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) { + userID = "000000" + return + }) http.HandleFunc("/authorize", func(w http.ResponseWriter, r *http.Request) { - authReq, err := srv.GetAuthorizeRequest(r) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - authReq.UserID = "000000" - // TODO: 登录验证、授权处理 - err = srv.HandleAuthorizeRequest(w, authReq) + err := srv.HandleAuthorizeRequest(w, r) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) } @@ -46,6 +49,6 @@ func main() { } }) - log.Println("OAuth2 server is running at 9096 port.") + log.Println("Server is running at 9096 port.") log.Fatal(http.ListenAndServe(":9096", nil)) } diff --git a/generate.go b/generate.go index 7eb0f06..49fb210 100644 --- a/generate.go +++ b/generate.go @@ -3,22 +3,20 @@ package oauth2 import "time" type ( - // GenerateBasic 提供生成令牌的基础数据 + // GenerateBasic Provide the basis of the generated token data GenerateBasic struct { - Client ClientInfo // 客户端信息 - UserID string // 用户标识 - CreateAt time.Time // 创建时间 + Client ClientInfo // The client information + UserID string // The user id + CreateAt time.Time // Creation time } - // AuthorizeGenerate 授权令牌生成接口 + // AuthorizeGenerate Generate the authorization code interface AuthorizeGenerate interface { - // 授权令牌 Token(data *GenerateBasic) (code string, err error) } - // AccessGenerate 访问令牌生成接口 + // AccessGenerate Generate the access and refresh tokens interface AccessGenerate interface { - // 访问令牌、更新令牌 Token(data *GenerateBasic, isGenRefresh bool) (access, refresh string, err error) } ) diff --git a/generates/access.go b/generates/access.go index 2d7c666..c22f6cf 100644 --- a/generates/access.go +++ b/generates/access.go @@ -2,39 +2,34 @@ package generates import ( "bytes" + "encoding/base64" "strconv" "strings" "github.com/LyricTian/go.uuid" - "gopkg.in/LyricTian/lib.v2" - "gopkg.in/oauth2.v2" + "gopkg.in/oauth2.v3" ) -// NewAccessGenerate 创建访问令牌生成实例 +// NewAccessGenerate Create to generate the access token instance func NewAccessGenerate() *AccessGenerate { return &AccessGenerate{} } -// AccessGenerate 访问令牌生成 +// AccessGenerate Generate the access token type AccessGenerate struct { } -// Token 生成令牌 +// Token Based on the UUID generated token func (ag *AccessGenerate) Token(data *oauth2.GenerateBasic, isGenRefresh bool) (access, refresh string, err error) { buf := bytes.NewBufferString(data.Client.GetID()) buf.WriteString(data.UserID) buf.WriteString(strconv.FormatInt(data.CreateAt.UnixNano(), 10)) - access, err = lib.NewEncryption(uuid.NewV3(uuid.NewV4(), buf.String()).Bytes()).MD5() - if err != nil { - return - } - access = strings.ToUpper(access) + + access = base64.URLEncoding.EncodeToString(uuid.NewV3(uuid.NewV4(), buf.String()).Bytes()) + access = strings.ToUpper(strings.TrimRight(access, "=")) if isGenRefresh { - refresh, err = lib.NewEncryption(uuid.NewV5(uuid.NewV4(), buf.String()).Bytes()).Sha1() - if err != nil { - return - } - refresh = strings.ToUpper(refresh) + refresh = base64.URLEncoding.EncodeToString(uuid.NewV5(uuid.NewV4(), buf.String()).Bytes()) + refresh = strings.ToUpper(strings.TrimRight(refresh, "=")) } return diff --git a/generates/access_test.go b/generates/access_test.go index ff533ec..b431eb7 100644 --- a/generates/access_test.go +++ b/generates/access_test.go @@ -1,12 +1,14 @@ -package generates +package generates_test import ( "testing" "time" + "gopkg.in/oauth2.v3" + "gopkg.in/oauth2.v3/generates" + "gopkg.in/oauth2.v3/models" + . "github.com/smartystreets/goconvey/convey" - "gopkg.in/oauth2.v2" - "gopkg.in/oauth2.v2/models" ) func TestAccess(t *testing.T) { @@ -19,10 +21,12 @@ func TestAccess(t *testing.T) { UserID: "000000", CreateAt: time.Now(), } - gen := NewAccessGenerate() + gen := generates.NewAccessGenerate() access, refresh, err := gen.Token(data, true) So(err, ShouldBeNil) So(access, ShouldNotBeEmpty) So(refresh, ShouldNotBeEmpty) + Println("\nAccess Token:" + access) + Println("Refresh Token:" + refresh) }) } diff --git a/generates/authorize.go b/generates/authorize.go index 32fc50f..8eafdbf 100644 --- a/generates/authorize.go +++ b/generates/authorize.go @@ -2,30 +2,28 @@ package generates import ( "bytes" + "encoding/base64" "strings" "github.com/LyricTian/go.uuid" - "gopkg.in/LyricTian/lib.v2" - "gopkg.in/oauth2.v2" + "gopkg.in/oauth2.v3" ) -// NewAuthorizeGenerate 创建授权令牌生成实例 +// NewAuthorizeGenerate Create to generate the authorize code instance func NewAuthorizeGenerate() *AuthorizeGenerate { return &AuthorizeGenerate{} } -// AuthorizeGenerate 授权令牌生成 +// AuthorizeGenerate Generate the authorize code type AuthorizeGenerate struct{} -// Token 生成令牌 +// Token Based on the UUID generated token func (ag *AuthorizeGenerate) Token(data *oauth2.GenerateBasic) (code string, err error) { - buf := bytes.NewBuffer(uuid.NewV1().Bytes()) + buf := bytes.NewBufferString(data.Client.GetID()) buf.WriteString(data.UserID) - buf.WriteString(data.Client.GetID()) - code, err = lib.NewEncryption(buf.Bytes()).MD5() - if err != nil { - return - } - code = strings.ToUpper(code) + token := uuid.NewV3(uuid.NewV1(), buf.String()) + code = base64.URLEncoding.EncodeToString(token.Bytes()) + code = strings.ToUpper(strings.TrimRight(code, "=")) + return } diff --git a/generates/authorize_test.go b/generates/authorize_test.go index d94d67b..c62dbd6 100644 --- a/generates/authorize_test.go +++ b/generates/authorize_test.go @@ -1,12 +1,14 @@ -package generates +package generates_test import ( "testing" "time" + "gopkg.in/oauth2.v3" + "gopkg.in/oauth2.v3/generates" + "gopkg.in/oauth2.v3/models" + . "github.com/smartystreets/goconvey/convey" - "gopkg.in/oauth2.v2" - "gopkg.in/oauth2.v2/models" ) func TestAuthorize(t *testing.T) { @@ -19,9 +21,10 @@ func TestAuthorize(t *testing.T) { UserID: "000000", CreateAt: time.Now(), } - gen := NewAuthorizeGenerate() + gen := generates.NewAuthorizeGenerate() code, err := gen.Token(data) So(err, ShouldBeNil) So(code, ShouldNotBeEmpty) + Println("\nAuthorize Code:" + code) }) } diff --git a/manage.go b/manage.go index d4635ce..5c5e502 100644 --- a/manage.go +++ b/manage.go @@ -1,50 +1,39 @@ package oauth2 -// TokenGenerateRequest 提供生成令牌的请求参数 +// TokenGenerateRequest Provide to generate the token request parameters type TokenGenerateRequest struct { - ClientID string // 客户端标识 - ClientSecret string // 客户端密钥 - UserID string // 用户标识 - RedirectURI string // 重定向URI - Scope string // 授权范围 - Code string // 授权码(授权码模式使用) - Refresh string // 刷新令牌 - IsGenerateRefresh bool // 是否生成更新令牌 + ClientID string // The client information + ClientSecret string // The client secret + UserID string // The user id + RedirectURI string // Redirect URI + Scope string // Scope of authorization + Code string // Authorization code + Refresh string // Refreshing token } -// Manager OAuth2授权管理接口 +// Manager Authorization management interface type Manager interface { - // GetClient 获取客户端信息 - // clientID 客户端标识 + // GetClient Get the client information GetClient(clientID string) (cli ClientInfo, err error) - // GenerateAuthToken 生成授权令牌 - // rt 授权类型 - // tgr 生成令牌的请求参数 + // GenerateAuthToken Generate the authorization token(code) GenerateAuthToken(rt ResponseType, tgr *TokenGenerateRequest) (authToken TokenInfo, err error) - // GenerateAccessToken 生成访问令牌、更新令牌 - // rt 授权模式 - // tgr 生成令牌的请求参数 + // GenerateAccessToken Generate the access token GenerateAccessToken(rt GrantType, tgr *TokenGenerateRequest) (accessToken TokenInfo, err error) - // RefreshAccessToken 更新访问令牌 - // tgr 生成令牌的请求参数 + // RefreshAccessToken Refreshing an access token RefreshAccessToken(tgr *TokenGenerateRequest) (accessToken TokenInfo, err error) - // RemoveAccessToken 删除访问令牌 - // access 访问令牌 + // RemoveAccessToken Use the access token to delete the token information RemoveAccessToken(access string) (err error) - // RemoveRefreshToken 删除更新令牌 - // refresh 更新令牌 + // RemoveRefreshToken Use the refresh token to delete the token information RemoveRefreshToken(refresh string) (err error) - // LoadAccessToken 加载访问令牌信息 - // access 访问令牌 + // LoadAccessToken According to the access token for corresponding token information LoadAccessToken(access string) (ti TokenInfo, err error) - // LoadRefreshToken 加载更新令牌信息 - // refresh 更新令牌 + // LoadRefreshToken According to the refresh token for corresponding token information LoadRefreshToken(refresh string) (ti TokenInfo, err error) } diff --git a/manage/error.go b/manage/error.go deleted file mode 100644 index c5ec58d..0000000 --- a/manage/error.go +++ /dev/null @@ -1,29 +0,0 @@ -package manage - -import "errors" - -var ( - // ErrNilValue Nil Value - ErrNilValue = errors.New("nil value") - - // ErrClientNotFound Client not Found - ErrClientNotFound = errors.New("client not found") - - // ErrClientInvalid Client invalid - ErrClientInvalid = errors.New("client invalid") - - // ErrAuthCodeInvalid Authorize token invalid - ErrAuthCodeInvalid = errors.New("authorize code invalid") - - // ErrAccessInvalid Access token expired - ErrAccessInvalid = errors.New("access token invalid") - - // ErrAccessExpired Access token expired - ErrAccessExpired = errors.New("access token expired") - - // ErrRefreshInvalid Refresh token invalid - ErrRefreshInvalid = errors.New("refresh token invalid") - - // ErrRefreshExpired Refresh token expired - ErrRefreshExpired = errors.New("refresh token expired") -) diff --git a/manage/manage_test.go b/manage/manage_test.go index 8f495e5..32ce33a 100644 --- a/manage/manage_test.go +++ b/manage/manage_test.go @@ -1,19 +1,21 @@ -package manage +package manage_test import ( "testing" + "gopkg.in/oauth2.v3" + "gopkg.in/oauth2.v3/generates" + "gopkg.in/oauth2.v3/manage" + "gopkg.in/oauth2.v3/models" + "gopkg.in/oauth2.v3/store/client" + "gopkg.in/oauth2.v3/store/token" + . "github.com/smartystreets/goconvey/convey" - "gopkg.in/oauth2.v2" - "gopkg.in/oauth2.v2/generates" - "gopkg.in/oauth2.v2/models" - "gopkg.in/oauth2.v2/store/client" - "gopkg.in/oauth2.v2/store/token" ) func TestManager(t *testing.T) { Convey("Manager test", t, func() { - manager := NewManager() + manager := manage.NewManager() manager.MapClientModel(models.NewClient()) manager.MapTokenModel(models.NewToken()) @@ -33,14 +35,6 @@ func TestManager(t *testing.T) { )) testManager(manager) }) - - Convey("MongoDB store test", func() { - manager.MustTokenStorage(token.NewMongoStore( - &token.MongoConfig{URL: "mongodb://admin:123456@192.168.33.70:27017"}, - )) - testManager(manager) - }) - }) } @@ -58,13 +52,12 @@ func testManager(manager oauth2.Manager) { So(code, ShouldNotBeEmpty) atParams := &oauth2.TokenGenerateRequest{ - ClientID: reqParams.ClientID, - ClientSecret: "11", - RedirectURI: reqParams.RedirectURI, - Code: code, - IsGenerateRefresh: true, + ClientID: reqParams.ClientID, + ClientSecret: "11", + RedirectURI: reqParams.RedirectURI, + Code: code, } - ati, err := manager.GenerateAccessToken(oauth2.AuthorizationCodeCredentials, atParams) + ati, err := manager.GenerateAccessToken(oauth2.AuthorizationCode, atParams) So(err, ShouldBeNil) accessToken, refreshToken := ati.GetAccess(), ati.GetRefresh() @@ -97,12 +90,15 @@ func testManager(manager oauth2.Manager) { So(err, ShouldBeNil) So(refreshAInfo.GetScope(), ShouldEqual, "owner") - err = manager.RemoveRefreshToken(refreshToken) + err = manager.RemoveAccessToken(refreshAT) So(err, ShouldBeNil) _, err = manager.LoadAccessToken(refreshAT) So(err, ShouldNotBeNil) + err = manager.RemoveRefreshToken(refreshToken) + So(err, ShouldBeNil) + _, err = manager.LoadRefreshToken(refreshToken) So(err, ShouldNotBeNil) } diff --git a/manage/manager.go b/manage/manager.go index 202aa38..ba29c83 100644 --- a/manage/manager.go +++ b/manage/manager.go @@ -4,19 +4,22 @@ import ( "time" "github.com/LyricTian/inject" - "gopkg.in/oauth2.v2" - "gopkg.in/oauth2.v2/generates" - "gopkg.in/oauth2.v2/models" - "gopkg.in/oauth2.v2/store/token" + + "gopkg.in/oauth2.v3" + "gopkg.in/oauth2.v3/errors" + "gopkg.in/oauth2.v3/generates" + "gopkg.in/oauth2.v3/models" + "gopkg.in/oauth2.v3/store/token" ) -// Config 授权配置参数 +// Config Configuration parameters type Config struct { - TokenExp time.Duration // 令牌有效期 - RefreshExp time.Duration // 更新令牌有效期 + AccessTokenExp time.Duration // Access token expiration time (in seconds) + RefreshTokenExp time.Duration // Refresh token expiration time + IsGenerateRefresh bool // Whether to generate the refreshing token } -// NewRedisManager 创建基于redis存储的管理实例 +// NewRedisManager Create to based on redis store authorization management instance func NewRedisManager(redisCfg *token.RedisConfig) *Manager { m := NewManager() m.MapClientModel(models.NewClient()) @@ -28,141 +31,137 @@ func NewRedisManager(redisCfg *token.RedisConfig) *Manager { return m } -// NewMongoManager 创建基于mongodb存储的管理实例 -func NewMongoManager(mongoCfg *token.MongoConfig) *Manager { - m := NewManager() - m.MapClientModel(models.NewClient()) - m.MapTokenModel(models.NewToken()) - m.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) - m.MapAccessGenerate(generates.NewAccessGenerate()) - m.MustTokenStorage(token.NewMongoStore(mongoCfg)) - - return m -} - -// NewManager 创建Manager的实例 +// NewManager Create to authorization management instance func NewManager() *Manager { m := &Manager{ injector: inject.New(), - rtcfg: make(map[oauth2.ResponseType]*Config), gtcfg: make(map[oauth2.GrantType]*Config), } - // 设定参数默认值 - // 设定授权码的有效期为10分钟 - m.SetRTConfig(oauth2.Code, &Config{TokenExp: time.Minute * 10}) - // 设定简化模式授权令牌的有效期为1小时 - m.SetRTConfig(oauth2.Token, &Config{TokenExp: time.Hour * 1}) - - // 设定授权码模式令牌的有效期为2小时,更新令牌的有效期为3天 - m.SetGTConfig(oauth2.AuthorizationCodeCredentials, &Config{TokenExp: time.Hour * 2, RefreshExp: time.Hour * 24 * 3}) - // 设定密码模式令牌的有效期为2小时,更新令牌的有效期为7天 - m.SetGTConfig(oauth2.PasswordCredentials, &Config{TokenExp: time.Hour * 2, RefreshExp: time.Hour * 24 * 7}) - // 设定客户端模式令牌的有效期为1小时 - m.SetGTConfig(oauth2.ClientCredentials, &Config{TokenExp: time.Hour * 2}) + m.SetAuthorizeCodeExp(time.Minute * 10) + m.SetAuthorizeCodeTokenExp(&Config{IsGenerateRefresh: true, AccessTokenExp: time.Hour * 2, RefreshTokenExp: time.Hour * 24 * 3}) + m.SetImplicitTokenExp(&Config{AccessTokenExp: time.Hour * 1}) + m.SetPasswordTokenExp(&Config{IsGenerateRefresh: true, AccessTokenExp: time.Hour * 2, RefreshTokenExp: time.Hour * 24 * 7}) + m.SetClientTokenExp(&Config{AccessTokenExp: time.Hour * 2}) return m } -// Manager OAuth2授权管理 +// Manager Provide authorization management type Manager struct { - injector inject.Injector // 注入器 - rtcfg map[oauth2.ResponseType]*Config // 授权类型配置参数 - gtcfg map[oauth2.GrantType]*Config // 授权模式配置参数 + injector inject.Injector // Dependency injection + codeExp time.Duration // Authorize code expiration time + gtcfg map[oauth2.GrantType]*Config // Authorization grant configuration +} + +// SetAuthorizeCodeExp Set the authorization code expiration time +func (m *Manager) SetAuthorizeCodeExp(exp time.Duration) { + m.codeExp = exp +} + +// SetAuthorizeCodeTokenExp Set the authorization code grant token expiration time +func (m *Manager) SetAuthorizeCodeTokenExp(cfg *Config) { + m.gtcfg[oauth2.AuthorizationCode] = cfg } -// SetRTConfig 设定授权类型配置参数 -// rt 授权类型 -// cfg 配置参数 -func (m *Manager) SetRTConfig(rt oauth2.ResponseType, cfg *Config) { - m.rtcfg[rt] = cfg +// SetImplicitTokenExp Set the implicit grant token expiration time +func (m *Manager) SetImplicitTokenExp(cfg *Config) { + m.gtcfg[oauth2.Implicit] = cfg } -// SetGTConfig 设定授权模式配置参数 -// gt 授权模式 -// cfg 配置参数 -func (m *Manager) SetGTConfig(gt oauth2.GrantType, cfg *Config) { - m.gtcfg[gt] = cfg +// SetPasswordTokenExp Set the password grant token expiration time +func (m *Manager) SetPasswordTokenExp(cfg *Config) { + m.gtcfg[oauth2.PasswordCredentials] = cfg } -// MapClientModel 注入客户端信息模型 -func (m *Manager) MapClientModel(cli oauth2.ClientInfo) { +// SetClientTokenExp Set the client grant token expiration time +func (m *Manager) SetClientTokenExp(cfg *Config) { + m.gtcfg[oauth2.ClientCredentials] = cfg +} + +// MapClientModel Mapping the client information model +func (m *Manager) MapClientModel(cli oauth2.ClientInfo) error { if cli == nil { - panic(ErrNilValue) + return errors.ErrNilValue } m.injector.Map(cli) + return nil } -// MapTokenModel 注入令牌信息模型 -func (m *Manager) MapTokenModel(token oauth2.TokenInfo) { +// MapTokenModel Mapping the token information model +func (m *Manager) MapTokenModel(token oauth2.TokenInfo) error { if token == nil { - panic(ErrNilValue) + return errors.ErrNilValue } m.injector.Map(token) + return nil } -// MapAuthorizeGenerate 注入授权令牌生成接口 -func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) { +// MapAuthorizeGenerate Mapping the authorize code generate interface +func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) error { if gen == nil { - panic(ErrNilValue) + return errors.ErrNilValue } m.injector.Map(gen) + return nil } -// MapAccessGenerate 注入访问令牌生成接口 -func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) { +// MapAccessGenerate Mapping the access token generate interface +func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) error { if gen == nil { - panic(ErrNilValue) + return errors.ErrNilValue } m.injector.Map(gen) + return nil } -// MapClientStorage 注入客户端信息存储接口 -func (m *Manager) MapClientStorage(stor oauth2.ClientStore) { +// MapClientStorage Mapping the client store interface +func (m *Manager) MapClientStorage(stor oauth2.ClientStore) error { if stor == nil { - panic(ErrNilValue) + return errors.ErrNilValue } m.injector.Map(stor) + return nil } -// MustClientStorage 强制注入客户端信息存储接口 +// MustClientStorage Mandatory mapping the client store interface func (m *Manager) MustClientStorage(stor oauth2.ClientStore, err error) { if err != nil { - panic(err) + panic(err.Error()) } if stor == nil { - panic(ErrNilValue) + panic("client store can't be nil value") } m.injector.Map(stor) } -// MapTokenStorage 注入令牌信息存储接口 -func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) { +// MapTokenStorage Mapping the token store interface +func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) error { if stor == nil { - panic(ErrNilValue) + return (errors.ErrNilValue) } m.injector.Map(stor) + return nil } -// MustTokenStorage 强制注入令牌信息存储接口 +// MustTokenStorage Mandatory mapping the token store interface func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) { if err != nil { panic(err) } if stor == nil { - panic(ErrNilValue) + panic("token store can't be nil value") } m.injector.Map(stor) } -// GetClient 获取客户端信息 -// clientID 客户端标识 +// GetClient Get the client information func (m *Manager) GetClient(clientID string) (cli oauth2.ClientInfo, err error) { _, ierr := m.injector.Invoke(func(stor oauth2.ClientStore) { cli, err = stor.GetByID(clientID) if err != nil { return } else if cli == nil { - err = ErrClientNotFound + err = errors.ErrInvalidClient } }) if err == nil && ierr != nil { @@ -171,9 +170,7 @@ func (m *Manager) GetClient(clientID string) (cli oauth2.ClientInfo, err error) return } -// GenerateAuthToken 生成授权令牌 -// rt 授权类型 -// tgr 生成令牌的配置参数 +// GenerateAuthToken Generate the authorization token(code) func (m *Manager) GenerateAuthToken(rt oauth2.ResponseType, tgr *oauth2.TokenGenerateRequest) (authToken oauth2.TokenInfo, err error) { cli, err := m.GetClient(tgr.ClientID) if err != nil { @@ -182,25 +179,33 @@ func (m *Manager) GenerateAuthToken(rt oauth2.ResponseType, tgr *oauth2.TokenGen err = verr return } - _, ierr := m.injector.Invoke(func(ti oauth2.TokenInfo, gen oauth2.AuthorizeGenerate, stor oauth2.TokenStore) { + _, ierr := m.injector.Invoke(func(ti oauth2.TokenInfo, gen oauth2.AuthorizeGenerate, tgen oauth2.AccessGenerate, stor oauth2.TokenStore) { + var ( + tv string + terr error + ) td := &oauth2.GenerateBasic{ Client: cli, UserID: tgr.UserID, CreateAt: time.Now(), } - tv, terr := gen.Token(td) + if rt == oauth2.Code { + ti.SetAccessExpiresIn(m.codeExp) + tv, terr = gen.Token(td) + } else { + ti.SetAccessExpiresIn(m.gtcfg[oauth2.Implicit].AccessTokenExp) + tv, _, terr = tgen.Token(td, false) + } if terr != nil { err = terr return } + ti.SetAccess(tv) + ti.SetAccessCreateAt(td.CreateAt) ti.SetClientID(tgr.ClientID) ti.SetUserID(tgr.UserID) ti.SetRedirectURI(tgr.RedirectURI) ti.SetScope(tgr.Scope) - ti.SetAuthType(rt.String()) - ti.SetAccess(tv) - ti.SetAccessCreateAt(td.CreateAt) - ti.SetAccessExpiresIn(m.rtcfg[rt].TokenExp) err = stor.Create(ti) if err != nil { return @@ -213,19 +218,21 @@ func (m *Manager) GenerateAuthToken(rt oauth2.ResponseType, tgr *oauth2.TokenGen return } -// GenerateAccessToken 生成访问令牌、更新令牌 -// gt 授权模式 -// tgr 生成令牌的参数 +// GenerateAccessToken Generate the access token func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (accessToken oauth2.TokenInfo, err error) { - if gt == oauth2.AuthorizationCodeCredentials { // 授权码模式 + if gt == oauth2.AuthorizationCode { ti, terr := m.LoadAccessToken(tgr.Code) if terr != nil { + if terr == errors.ErrInvalidAccessToken { + err = errors.ErrInvalidAuthorizeCode + return + } err = terr return } else if ti.GetRedirectURI() != tgr.RedirectURI || ti.GetClientID() != tgr.ClientID { - err = ErrAuthCodeInvalid + err = errors.ErrInvalidAuthorizeCode return - } else if verr := m.RemoveAccessToken(tgr.Code); verr != nil { // 删除授权码 + } else if verr := m.RemoveAccessToken(tgr.Code); verr != nil { // remove authorize code err = verr return } @@ -235,8 +242,8 @@ func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGene cli, err := m.GetClient(tgr.ClientID) if err != nil { return - } else if tgr.ClientSecret != "" && tgr.ClientSecret != cli.GetSecret() { - err = ErrClientInvalid + } else if tgr.ClientSecret != cli.GetSecret() { + err = errors.ErrInvalidClient return } _, ierr := m.injector.Invoke(func(ti oauth2.TokenInfo, gen oauth2.AccessGenerate, stor oauth2.TokenStore) { @@ -245,7 +252,7 @@ func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGene UserID: tgr.UserID, CreateAt: time.Now(), } - av, rv, terr := gen.Token(td, tgr.IsGenerateRefresh) + av, rv, terr := gen.Token(td, m.gtcfg[gt].IsGenerateRefresh) if terr != nil { err = terr return @@ -254,13 +261,12 @@ func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGene ti.SetUserID(tgr.UserID) ti.SetRedirectURI(tgr.RedirectURI) ti.SetScope(tgr.Scope) - ti.SetAuthType(gt.String()) ti.SetAccessCreateAt(td.CreateAt) - ti.SetAccessExpiresIn(m.gtcfg[gt].TokenExp) + ti.SetAccessExpiresIn(m.gtcfg[gt].AccessTokenExp) ti.SetAccess(av) - if rv != "" { + if m.gtcfg[gt].IsGenerateRefresh && rv != "" { ti.SetRefreshCreateAt(td.CreateAt) - ti.SetRefreshExpiresIn(m.gtcfg[gt].RefreshExp) + ti.SetRefreshExpiresIn(m.gtcfg[gt].RefreshTokenExp) ti.SetRefresh(rv) } err = stor.Create(ti) @@ -275,28 +281,24 @@ func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGene return } -// RefreshAccessToken 更新访问令牌 +// RefreshAccessToken Refreshing an access token func (m *Manager) RefreshAccessToken(tgr *oauth2.TokenGenerateRequest) (accessToken oauth2.TokenInfo, err error) { cli, err := m.GetClient(tgr.ClientID) if err != nil { return - } else if tgr.ClientSecret != "" && tgr.ClientSecret != cli.GetSecret() { - err = ErrClientInvalid + } else if tgr.ClientSecret != cli.GetSecret() { + err = errors.ErrInvalidClient return } ti, err := m.LoadRefreshToken(tgr.Refresh) if err != nil { return } else if ti.GetClientID() != tgr.ClientID { - err = ErrRefreshInvalid + err = errors.ErrInvalidRefreshToken return } + oldAccess := ti.GetAccess() _, ierr := m.injector.Invoke(func(stor oauth2.TokenStore, gen oauth2.AccessGenerate) { - // 移除旧的访问令牌 - if verr := stor.RemoveByAccess(ti.GetAccess()); verr != nil { - err = verr - return - } td := &oauth2.GenerateBasic{ Client: cli, UserID: ti.GetUserID(), @@ -316,6 +318,11 @@ func (m *Manager) RefreshAccessToken(tgr *oauth2.TokenGenerateRequest) (accessTo err = verr return } + // remove the old access token + if verr := stor.RemoveByAccess(oldAccess); verr != nil { + err = verr + return + } accessToken = ti }) if ierr != nil && err == nil { @@ -324,10 +331,10 @@ func (m *Manager) RefreshAccessToken(tgr *oauth2.TokenGenerateRequest) (accessTo return } -// RemoveAccessToken 删除访问令牌 +// RemoveAccessToken Use the access token to delete the token information func (m *Manager) RemoveAccessToken(access string) (err error) { if access == "" { - err = ErrAccessInvalid + err = errors.ErrInvalidAccessToken return } _, ierr := m.injector.Invoke(func(stor oauth2.TokenStore) { @@ -339,10 +346,10 @@ func (m *Manager) RemoveAccessToken(access string) (err error) { return } -// RemoveRefreshToken 删除更新令牌 +// RemoveRefreshToken Use the refresh token to delete the token information func (m *Manager) RemoveRefreshToken(refresh string) (err error) { if refresh == "" { - err = ErrAccessInvalid + err = errors.ErrInvalidAccessToken return } _, ierr := m.injector.Invoke(func(stor oauth2.TokenStore) { @@ -354,10 +361,10 @@ func (m *Manager) RemoveRefreshToken(refresh string) (err error) { return } -// LoadAccessToken 加载访问令牌信息 +// LoadAccessToken According to the access token for corresponding token information func (m *Manager) LoadAccessToken(access string) (info oauth2.TokenInfo, err error) { if access == "" { - err = ErrAccessInvalid + err = errors.ErrInvalidAccessToken return } _, ierr := m.injector.Invoke(func(stor oauth2.TokenStore) { @@ -367,12 +374,12 @@ func (m *Manager) LoadAccessToken(access string) (info oauth2.TokenInfo, err err err = terr return } else if ti == nil { - err = ErrAccessInvalid + err = errors.ErrInvalidAccessToken return - } else if ti.GetRefresh() != "" && ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { // 检查更新令牌是否过期 - err = ErrRefreshExpired - } else if ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) { // 检查访问令牌是否过期 - err = ErrAccessExpired + } else if ti.GetRefresh() != "" && ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { + err = errors.ErrExpiredRefreshToken + } else if ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) { + err = errors.ErrExpiredAccessToken return } info = ti @@ -383,10 +390,10 @@ func (m *Manager) LoadAccessToken(access string) (info oauth2.TokenInfo, err err return } -// LoadRefreshToken 加载更新令牌信息 +// LoadRefreshToken According to the refresh token for corresponding token information func (m *Manager) LoadRefreshToken(refresh string) (info oauth2.TokenInfo, err error) { if refresh == "" { - err = ErrRefreshInvalid + err = errors.ErrInvalidRefreshToken return } _, ierr := m.injector.Invoke(func(stor oauth2.TokenStore) { @@ -395,10 +402,10 @@ func (m *Manager) LoadRefreshToken(refresh string) (info oauth2.TokenInfo, err e err = terr return } else if ti == nil { - err = ErrRefreshInvalid + err = errors.ErrInvalidRefreshToken return } else if ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) { - err = ErrRefreshExpired + err = errors.ErrExpiredRefreshToken return } info = ti diff --git a/manage/util.go b/manage/util.go index 2677f80..beed15d 100644 --- a/manage/util.go +++ b/manage/util.go @@ -1,25 +1,24 @@ package manage import ( - "errors" "net/url" + "strings" + + "gopkg.in/oauth2.v3/errors" ) -// ValidateURI 校验重定向的URI与域名的一致性 -func ValidateURI(domain string, redirectURI string) error { +// ValidateURI Validates that RedirectURI is contained in domain +func ValidateURI(domain string, redirectURI string) (err error) { base, err := url.Parse(domain) if err != nil { - return err + return } redirect, err := url.Parse(redirectURI) if err != nil { - return err - } else if base.Fragment != "" || redirect.Fragment != "" { - return errors.New("Url must not include fragment.") - } else if base.Scheme != redirect.Scheme { - return errors.New("Scheme don't match.") - } else if base.Host != redirect.Host { - return errors.New("Host don't match.") + return + } + if !strings.HasSuffix(redirect.Host, base.Host) { + err = errors.ErrInvalidRedirectURI } - return nil + return } diff --git a/manage/util_test.go b/manage/util_test.go index f936cbd..1faa89e 100644 --- a/manage/util_test.go +++ b/manage/util_test.go @@ -1,15 +1,17 @@ -package manage +package manage_test import ( "testing" + "gopkg.in/oauth2.v3/manage" + . "github.com/smartystreets/goconvey/convey" ) func TestUtil(t *testing.T) { Convey("Util Test", t, func() { Convey("ValidateURI Test", func() { - err := ValidateURI("http://www.example.com", "http://www.example.com/cb?code=xxx") + err := manage.ValidateURI("http://www.example.com", "http://www.example.com/cb?code=xxx") So(err, ShouldBeNil) }) }) diff --git a/model.go b/model.go index 163a4b4..9db768b 100644 --- a/model.go +++ b/model.go @@ -2,67 +2,62 @@ package oauth2 import "time" -// 相关模型接口的定义 type ( - // ClientInfo 客户端信息模型接口 + // ClientInfo The client information model interface ClientInfo interface { - // 客户端ID + // The client id GetID() string - // 客户端秘钥 + // The client secret GetSecret() string - // 客户端域名URL + // The client domain GetDomain() string - // 扩展数据 + // The extension data related to the client GetExtraData() interface{} } - // TokenInfo 令牌信息模型接口 + // TokenInfo The token information model interface TokenInfo interface { - // 客户端ID + // Get client id GetClientID() string - // 设置客户端ID + // Set client id SetClientID(string) - // 用户ID + // Get user id GetUserID() string - // 设置用户ID + // Set user id SetUserID(string) - // 重定向URI + // Get Redirect URI GetRedirectURI() string - // 设置重定向URI + // Set Redirect URI SetRedirectURI(string) - // 权限范围 + // Get Scope of authorization GetScope() string - // 设置权限范围 + // Set Scope of authorization SetScope(string) - // 令牌授权类型 - GetAuthType() string - // 设置令牌授权类型 - SetAuthType(string) - // 访问令牌(或授权令牌) + // Get Access Token GetAccess() string - // 设置访问令牌(或授权令牌) + // Set Access Token SetAccess(string) - // 访问令牌(或授权令牌)创建时间 + // Get Create Time GetAccessCreateAt() time.Time - // 设置访问令牌(或授权令牌)创建时间 + // Set Create Time SetAccessCreateAt(time.Time) - // 访问令牌(或授权令牌)有效期 + // Get The lifetime in seconds of the access token GetAccessExpiresIn() time.Duration - // 设置访问令牌(或授权令牌)有效期 + // Set The lifetime in seconds of the access token SetAccessExpiresIn(time.Duration) - // 更新令牌 + // Get Refresh Token GetRefresh() string - // 设置更新令牌 + // Set Refresh Token SetRefresh(string) - // 更新令牌创建时间 + // Get Create Time GetRefreshCreateAt() time.Time - // 设置更新令牌创建时间 + // Set Create Time SetRefreshCreateAt(time.Time) - // 更新令牌有效期 + // Get The lifetime in seconds of the access token GetRefreshExpiresIn() time.Duration - // 设置更新令牌有效期 + // Set The lifetime in seconds of the access token SetRefreshExpiresIn(time.Duration) } ) diff --git a/models/client.go b/models/client.go index 9f6871b..35199e6 100644 --- a/models/client.go +++ b/models/client.go @@ -1,33 +1,33 @@ package models -// NewClient 创建客户端模型实例 +// NewClient Create to client model instance func NewClient() *Client { return &Client{} } -// Client 客户端信息 +// Client Client model type Client struct { - ID string // 客户端ID - Secret string // 密钥 - Domain string // 域名url + ID string // The client id + Secret string // The client secret + Domain string // The client domain } -// GetID 客户端ID +// GetID The client id func (c *Client) GetID() string { return c.ID } -// GetSecret 客户端秘钥 +// GetSecret The client domain func (c *Client) GetSecret() string { return c.Secret } -// GetDomain 域名URL +// GetDomain The client domain func (c *Client) GetDomain() string { return c.Domain } -// GetExtraData 扩展数据 +// GetExtraData The extension data related to the client func (c *Client) GetExtraData() interface{} { return nil } diff --git a/models/token.go b/models/token.go index 60b13c5..6cac968 100644 --- a/models/token.go +++ b/models/token.go @@ -2,132 +2,121 @@ package models import "time" -// NewToken 创建令牌模型实例 +// NewToken Create to token model instance func NewToken() *Token { return &Token{} } -// Token 令牌信息 +// Token Token model type Token struct { - ClientID string `bson:"ClientID"` // 客户端标识 - UserID string `bson:"UserID"` // 用户标识 - RedirectURI string `bson:"RedirectURI"` // 重定向URI - Scope string `bson:"Scope"` // 权限范围 - AuthType string `bson:"AuthType"` // 令牌授权类型 - Access string `bson:"Access"` // 访问令牌 - AccessCreateAt time.Time `bson:"AccessCreateAt"` // 访问令牌创建时间 - AccessExpiresIn time.Duration `bson:"AccessExpiresIn"` // 访问令牌有效期 - Refresh string `bson:"Refresh"` // 更新令牌 - RefreshCreateAt time.Time `bson:"RefreshCreateAt"` // 更新令牌创建时间 - RefreshExpiresIn time.Duration `bson:"RefreshExpiresIn"` // 更新令牌有效期 -} - -// GetClientID 客户端ID + ClientID string `bson:"ClientID"` // The client id + UserID string `bson:"UserID"` // The user id + RedirectURI string `bson:"RedirectURI"` // Redirect URI + Scope string `bson:"Scope"` // Scope of authorization + Access string `bson:"Access"` // Access Token + AccessCreateAt time.Time `bson:"AccessCreateAt"` // Create Time + AccessExpiresIn time.Duration `bson:"AccessExpiresIn"` // The lifetime in seconds of the access token + Refresh string `bson:"Refresh"` // Refresh Token + RefreshCreateAt time.Time `bson:"RefreshCreateAt"` // Create Time + RefreshExpiresIn time.Duration `bson:"RefreshExpiresIn"` // The lifetime in seconds of the access token +} + +// GetClientID The client id func (t *Token) GetClientID() string { return t.ClientID } -// SetClientID 设置客户端ID +// SetClientID The client id func (t *Token) SetClientID(clientID string) { t.ClientID = clientID } -// GetUserID 用户ID +// GetUserID The user id func (t *Token) GetUserID() string { return t.UserID } -// SetUserID 设置用户ID +// SetUserID The user id func (t *Token) SetUserID(userID string) { t.UserID = userID } -// GetRedirectURI 重定向URI +// GetRedirectURI Redirect URI func (t *Token) GetRedirectURI() string { return t.RedirectURI } -// SetRedirectURI 设置重定向URI +// SetRedirectURI Redirect URI func (t *Token) SetRedirectURI(redirectURI string) { t.RedirectURI = redirectURI } -// GetScope 权限范围 +// GetScope Get Scope of authorization func (t *Token) GetScope() string { return t.Scope } -// SetScope 设置权限范围 +// SetScope Get Scope of authorization func (t *Token) SetScope(scope string) { t.Scope = scope } -// GetAuthType 授权类型 -func (t *Token) GetAuthType() string { - return t.AuthType -} - -// SetAuthType 设置授权类型 -func (t *Token) SetAuthType(authType string) { - t.AuthType = authType -} - -// GetAccess 访问令牌 +// GetAccess Access Token func (t *Token) GetAccess() string { return t.Access } -// SetAccess 设置访问令牌 +// SetAccess Access Token func (t *Token) SetAccess(access string) { t.Access = access } -// GetAccessCreateAt 访问令牌创建时间 +// GetAccessCreateAt Create Time func (t *Token) GetAccessCreateAt() time.Time { return t.AccessCreateAt } -// SetAccessCreateAt 设置访问令牌创建时间 +// SetAccessCreateAt Create Time func (t *Token) SetAccessCreateAt(createAt time.Time) { t.AccessCreateAt = createAt } -// GetAccessExpiresIn 访问令牌有效期 +// GetAccessExpiresIn The lifetime in seconds of the access token func (t *Token) GetAccessExpiresIn() time.Duration { return t.AccessExpiresIn } -// SetAccessExpiresIn 设置访问令牌有效期 +// SetAccessExpiresIn The lifetime in seconds of the access token func (t *Token) SetAccessExpiresIn(exp time.Duration) { t.AccessExpiresIn = exp } -// GetRefresh 更新令牌 +// GetRefresh Refresh Token func (t *Token) GetRefresh() string { return t.Refresh } -// SetRefresh 设置更新令牌 +// SetRefresh Refresh Token func (t *Token) SetRefresh(refresh string) { t.Refresh = refresh } -// GetRefreshCreateAt 更新令牌创建时间 +// GetRefreshCreateAt Create Time func (t *Token) GetRefreshCreateAt() time.Time { return t.RefreshCreateAt } -// SetRefreshCreateAt 设置更新令牌创建时间 +// SetRefreshCreateAt Create Time func (t *Token) SetRefreshCreateAt(createAt time.Time) { t.RefreshCreateAt = createAt } -// GetRefreshExpiresIn 更新令牌有效期 +// GetRefreshExpiresIn The lifetime in seconds of the access token func (t *Token) GetRefreshExpiresIn() time.Duration { return t.RefreshExpiresIn } -// SetRefreshExpiresIn 设置更新令牌有效期 +// SetRefreshExpiresIn The lifetime in seconds of the access token func (t *Token) SetRefreshExpiresIn(exp time.Duration) { t.RefreshExpiresIn = exp } diff --git a/server/authorize.go b/server/authorize.go deleted file mode 100644 index a018e37..0000000 --- a/server/authorize.go +++ /dev/null @@ -1,99 +0,0 @@ -package server - -import ( - "encoding/base64" - "net/http" - "strings" - - "github.com/valyala/fasthttp" - "gopkg.in/oauth2.v2" -) - -// AuthorizeRequest 授权请求 -type AuthorizeRequest struct { - Type oauth2.ResponseType // 授权类型 - ClientID string // 客户端标识 - Scope string // 授权范围 - RedirectURI string // 重定向URI - State string // 状态 - UserID string // 用户标识 -} - -// TokenRequestHandler 令牌请求处理 -type TokenRequestHandler struct { - // 客户端信息处理 - ClientHandler ClientHandler - // 客户端信息处理(基于fasthttp) - ClientFastHandler ClientFastHandler - // 用户信息处理 - UserHandler UserHandler - // 授权范围处理 - ScopeHandler ScopeHandler -} - -// ClientHandler 获取请求的客户端认证信息 -type ClientHandler func(r *http.Request) (clientID, clientSecret string, err error) - -// ClientFormHandler 客户端表单信息 -func ClientFormHandler(r *http.Request) (clientID, clientSecret string, err error) { - clientID = r.Form.Get("client_id") - clientSecret = r.Form.Get("client_secret") - if clientID == "" || clientSecret == "" { - err = ErrAuthorizationFormInvalid - } - return -} - -// ClientBasicHandler 客户端基础认证信息 -func ClientBasicHandler(r *http.Request) (clientID, clientSecret string, err error) { - username, password, ok := r.BasicAuth() - if !ok { - err = ErrAuthorizationHeaderInvalid - return - } - clientID = username - clientSecret = password - return -} - -// ClientFastHandler 基于fasthttp获取客户端认证信息 -type ClientFastHandler func(ctx *fasthttp.RequestCtx) (clientID, clientSecret string, err error) - -// ClientFormFastHandler 客户端表单信息(基于fasthttp) -func ClientFormFastHandler(ctx *fasthttp.RequestCtx) (clientID, clientSecret string, err error) { - clientID = string(ctx.FormValue("client_id")) - clientSecret = string(ctx.FormValue("client_secret")) - if clientID == "" || clientSecret == "" { - err = ErrAuthorizationFormInvalid - } - return -} - -// ClientBasicFastHandler 客户端基础认证信息(基于fasthttp) -func ClientBasicFastHandler(ctx *fasthttp.RequestCtx) (clientID, clientSecret string, err error) { - auth := string(ctx.Request.Header.Peek("Authorization")) - const prefix = "Basic " - if auth == "" || !strings.HasPrefix(auth, prefix) { - err = ErrAuthorizationHeaderInvalid - return - } - c, err := base64.StdEncoding.DecodeString(auth[len(prefix):]) - if err != nil { - return - } - cs := string(c) - s := strings.IndexByte(cs, ':') - if s < 0 { - err = ErrAuthorizationHeaderInvalid - return - } - clientID = cs[:s] - clientSecret = cs[s+1:] - return -} - -// UserHandler 密码模式下,根据用户名、密码获取用户标识 -type UserHandler func(username, password string) (userID string, err error) - -// ScopeHandler 更新令牌时的授权范围检查 -type ScopeHandler func(new, old string) (err error) diff --git a/server/config.go b/server/config.go index eb2d923..f76e68f 100644 --- a/server/config.go +++ b/server/config.go @@ -1,25 +1,24 @@ package server -import "gopkg.in/oauth2.v2" +import "gopkg.in/oauth2.v3" -// Config 配置参数 +// Config Configuration parameters type Config struct { - // TokenType 令牌类型(默认为Bearer) - TokenType string - // AllowedResponseType 允许的授权类型(默认code) - AllowedResponseType []oauth2.ResponseType - // AllowedGrantType 允许的授权模式(默认authorization_code) - AllowedGrantType []oauth2.GrantType - // Handler 令牌请求处理 - Handler *TokenRequestHandler + TokenType string // TokenType token type(Default is Bearer) + AllowedResponseTypes []oauth2.ResponseType // Allow the authorization type(Default is all) + AllowedGrantTypes []oauth2.GrantType // Allow the grant type(Default is all) } -// NewConfig 创建默认的配置参数 +// NewConfig Create to configuration instance func NewConfig() *Config { return &Config{ - TokenType: "Bearer", - AllowedResponseType: []oauth2.ResponseType{oauth2.Code}, - AllowedGrantType: []oauth2.GrantType{oauth2.AuthorizationCodeCredentials}, - Handler: &TokenRequestHandler{}, + TokenType: "Bearer", + AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code, oauth2.Token}, + AllowedGrantTypes: []oauth2.GrantType{ + oauth2.AuthorizationCode, + oauth2.PasswordCredentials, + oauth2.ClientCredentials, + oauth2.Refreshing, + }, } } diff --git a/server/error.go b/server/error.go deleted file mode 100644 index f138392..0000000 --- a/server/error.go +++ /dev/null @@ -1,29 +0,0 @@ -package server - -import "errors" - -var ( - // ErrRequestMethodInvalid Request method invalid - ErrRequestMethodInvalid = errors.New("request method invalid") - - // ErrResponseTypeInvalid Response type invalid - ErrResponseTypeInvalid = errors.New("response type invalid") - - // ErrGrantTypeInvalid Grant type invalid - ErrGrantTypeInvalid = errors.New("grant type invalid") - - // ErrClientInvalid Client invalid - ErrClientInvalid = errors.New("client invalid") - - // ErrUserInvalid User invalid - ErrUserInvalid = errors.New("user invalid") - - // ErrAuthorizationFormInvalid Authorization form invalid - ErrAuthorizationFormInvalid = errors.New("authorization form invalid") - - // ErrAuthorizationHeaderInvalid Authorization header invalid - ErrAuthorizationHeaderInvalid = errors.New("authorization header invalid") - - // ErrRefreshInvalid Refresh token invalid - ErrRefreshInvalid = errors.New("refresh token invalid") -) diff --git a/server/fastserver.go b/server/fastserver.go deleted file mode 100644 index 4203b7a..0000000 --- a/server/fastserver.go +++ /dev/null @@ -1,175 +0,0 @@ -package server - -import ( - "encoding/json" - "net/url" - "time" - - "github.com/valyala/fasthttp" - "gopkg.in/oauth2.v2" -) - -// NewFastServer 创建基于fasthttp的OAuth2服务实例 -func NewFastServer(cfg *Config, manager oauth2.Manager) *FastServer { - srv := &FastServer{} - srv.cfg = cfg - srv.manager = manager - srv.SetClientHandler(ClientFormFastHandler) - return srv -} - -// FastServer 基于fasthttp(https://github.com/valyala/fasthttp)的OAuth2服务处理 -type FastServer struct { - Server -} - -// SetClientHandler 设置客户端处理 -func (fs *FastServer) SetClientHandler(handler ClientFastHandler) { - fs.cfg.Handler.ClientFastHandler = handler -} - -// GetAuthorizeRequest 获取授权请求参数 -func (fs *FastServer) GetAuthorizeRequest(ctx *fasthttp.RequestCtx) (authReq *AuthorizeRequest, err error) { - if !ctx.IsGet() { - err = ErrRequestMethodInvalid - return - } - redirectURI, err := url.QueryUnescape(string(ctx.FormValue("redirect_uri"))) - if err != nil { - return - } - authReq = &AuthorizeRequest{ - Type: oauth2.ResponseType(string(ctx.FormValue("response_type"))), - RedirectURI: redirectURI, - State: string(ctx.FormValue("state")), - Scope: string(ctx.FormValue("scope")), - ClientID: string(ctx.FormValue("client_id")), - } - if authReq.Type == "" || !fs.checkResponseType(authReq.Type) { - err = ErrResponseTypeInvalid - } else if authReq.ClientID == "" { - err = ErrClientInvalid - } - return -} - -// HandleAuthorizeRequest 处理授权请求 -func (fs *FastServer) HandleAuthorizeRequest(ctx *fasthttp.RequestCtx, authReq *AuthorizeRequest) (err error) { - if authReq.UserID == "" { - err = ErrUserInvalid - return - } - tgr := &oauth2.TokenGenerateRequest{ - ClientID: authReq.ClientID, - UserID: authReq.UserID, - RedirectURI: authReq.RedirectURI, - Scope: authReq.Scope, - } - ti, terr := fs.manager.GenerateAuthToken(oauth2.Code, tgr) - if terr != nil { - err = terr - return - } - redirectURI, err := fs.GetRedirectURI(authReq, ti) - if err != nil { - return - } - ctx.Redirect(redirectURI, 302) - return -} - -// HandleTokenRequest 处理令牌请求 -func (fs *FastServer) HandleTokenRequest(ctx *fasthttp.RequestCtx) (err error) { - if !ctx.IsPost() { - err = ErrRequestMethodInvalid - return - } - gt := oauth2.GrantType(string(ctx.FormValue("grant_type"))) - if gt == "" || !fs.checkGrantType(gt) { - err = ErrGrantTypeInvalid - return - } - - var ti oauth2.TokenInfo - clientID, clientSecret, err := fs.cfg.Handler.ClientFastHandler(ctx) - if err != nil { - return - } - tgr := &oauth2.TokenGenerateRequest{ - ClientID: clientID, - ClientSecret: clientSecret, - } - - switch gt { - case oauth2.AuthorizationCodeCredentials: - tgr.RedirectURI = string(ctx.FormValue("redirect_uri")) - tgr.Code = string(ctx.FormValue("code")) - tgr.IsGenerateRefresh = true - ti, err = fs.manager.GenerateAccessToken(oauth2.AuthorizationCodeCredentials, tgr) - case oauth2.PasswordCredentials: - userID, uerr := fs.cfg.Handler.UserHandler(string(ctx.FormValue("username")), string(ctx.FormValue("password"))) - if uerr != nil { - err = uerr - return - } - tgr.UserID = userID - tgr.Scope = string(ctx.FormValue("scope")) - tgr.IsGenerateRefresh = true - ti, err = fs.manager.GenerateAccessToken(oauth2.PasswordCredentials, tgr) - case oauth2.ClientCredentials: - tgr.Scope = string(ctx.FormValue("scope")) - ti, err = fs.manager.GenerateAccessToken(oauth2.ClientCredentials, tgr) - case oauth2.RefreshCredentials: - tgr.Refresh = string(ctx.FormValue("refresh_token")) - tgr.Scope = string(ctx.FormValue("scope")) - if tgr.Scope != "" { // 检查授权范围 - rti, rerr := fs.manager.LoadRefreshToken(tgr.Refresh) - if rerr != nil { - err = rerr - return - } else if rti.GetClientID() != tgr.ClientID { - err = ErrRefreshInvalid - return - } else if verr := fs.cfg.Handler.ScopeHandler(tgr.Scope, rti.GetScope()); verr != nil { - err = verr - return - } - } - ti, err = fs.manager.RefreshAccessToken(tgr) - if err == nil { - ti.SetRefresh("") - } - } - - if err != nil { - return - } - err = fs.ResJSON(ctx, ti) - return -} - -// ResJSON 响应Json数据 -func (fs *FastServer) ResJSON(ctx *fasthttp.RequestCtx, ti oauth2.TokenInfo) (err error) { - data := map[string]interface{}{ - "access_token": ti.GetAccess(), - "token_type": fs.cfg.TokenType, - "expires_in": ti.GetAccessExpiresIn() / time.Second, - } - if scope := ti.GetScope(); scope != "" { - data["scope"] = scope - } - if refresh := ti.GetRefresh(); refresh != "" { - data["refresh_token"] = refresh - } - buf, err := json.Marshal(data) - if err != nil { - return - } - ctx.Response.Header.Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate") - ctx.Response.Header.Set("Pragma", "no-cache") - ctx.Response.Header.Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT") - ctx.SetContentType("application/json;charset=UTF-8") - ctx.SetStatusCode(200) - _, err = ctx.Write(buf) - return nil -} diff --git a/server/handler.go b/server/handler.go new file mode 100644 index 0000000..42b721a --- /dev/null +++ b/server/handler.go @@ -0,0 +1,51 @@ +package server + +import ( + "net/http" + + "gopkg.in/oauth2.v3" + "gopkg.in/oauth2.v3/errors" +) + +// ClientInfoHandler Get client info from request +type ClientInfoHandler func(r *http.Request) (clientID, clientSecret string, err error) + +// ClientAuthorizedHandler Check the client allows to use this authorization grant type +type ClientAuthorizedHandler func(clientID string, grantType oauth2.GrantType) (allowed bool, err error) + +// ClientScopeHandler Check the client allows to use scope +type ClientScopeHandler func(clientID, scope string) (allowed bool, err error) + +// UserAuthorizationHandler Get user id from request authorization +type UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error) + +// PasswordAuthorizationHandler Get user id from username and password +type PasswordAuthorizationHandler func(username, password string) (userID string, err error) + +// RefreshingScopeHandler Check the scope of the refreshing token +type RefreshingScopeHandler func(newScope, oldScope string) (allowed bool) + +// ErrorHandler Error handling +type ErrorHandler func(err error) + +// ClientFormHandler Get client data from form +func ClientFormHandler(r *http.Request) (clientID, clientSecret string, err error) { + clientID = r.Form.Get("client_id") + clientSecret = r.Form.Get("client_secret") + if clientID == "" || clientSecret == "" { + err = errors.ErrInvalidRequest + } + return +} + +// ClientBasicHandler Get client data from basic authorization +func ClientBasicHandler(r *http.Request) (clientID, clientSecret string, err error) { + username, password, ok := r.BasicAuth() + if !ok { + err = errors.ErrInvalidClient + return + } + clientID = username + clientSecret = password + return +} diff --git a/server/request.go b/server/request.go new file mode 100644 index 0000000..1ded3a3 --- /dev/null +++ b/server/request.go @@ -0,0 +1,13 @@ +package server + +import "gopkg.in/oauth2.v3" + +// AuthorizeRequest The authorization request +type AuthorizeRequest struct { + ResponseType oauth2.ResponseType + ClientID string + Scope string + RedirectURI string + State string + UserID string +} diff --git a/server/server.go b/server/server.go index 15db1db..52c0ca1 100644 --- a/server/server.go +++ b/server/server.go @@ -2,63 +2,86 @@ package server import ( "encoding/json" + "fmt" "net/http" "net/url" - "strconv" "time" - "gopkg.in/oauth2.v2" + "gopkg.in/oauth2.v3" + "gopkg.in/oauth2.v3/errors" ) -// NewServer 创建OAuth2服务实例 +// NewServer Create to authorization server instance func NewServer(cfg *Config, manager oauth2.Manager) *Server { srv := &Server{ - cfg: cfg, - manager: manager, + Config: cfg, + Manager: manager, + ClientInfoHandler: ClientFormHandler, } - srv.SetClientHandler(ClientFormHandler) return srv } -// Server OAuth2服务处理 +// Server Provide authorization server type Server struct { - cfg *Config - manager oauth2.Manager + Config *Config + Manager oauth2.Manager + ClientInfoHandler ClientInfoHandler + ClientAuthorizedHandler ClientAuthorizedHandler + ClientScopeHandler ClientScopeHandler + UserAuthorizationHandler UserAuthorizationHandler + PasswordAuthorizationHandler PasswordAuthorizationHandler + RefreshingScopeHandler RefreshingScopeHandler + ErrorHandler ErrorHandler } -// SetTokenType 设置令牌类型 -func (s *Server) SetTokenType(tokenType string) { - s.cfg.TokenType = tokenType +// SetAllowedResponseType Allow the authorization types +func (s *Server) SetAllowedResponseType(types ...oauth2.ResponseType) { + s.Config.AllowedResponseTypes = types } -// SetAllowedResponseType 设置允许的授权类型 -func (s *Server) SetAllowedResponseType(allowedTypes ...oauth2.ResponseType) { - s.cfg.AllowedResponseType = allowedTypes +// SetAllowedGrantType Allow the grant types +func (s *Server) SetAllowedGrantType(types ...oauth2.GrantType) { + s.Config.AllowedGrantTypes = types } -// SetAllowedGrantType 允许的授权模式 -func (s *Server) SetAllowedGrantType(allowedTypes ...oauth2.GrantType) { - s.cfg.AllowedGrantType = allowedTypes +// SetClientInfoHandler Get client info from request +func (s *Server) SetClientInfoHandler(handler ClientInfoHandler) { + s.ClientInfoHandler = handler } -// SetClientHandler 设置客户端处理 -func (s *Server) SetClientHandler(handler ClientHandler) { - s.cfg.Handler.ClientHandler = handler +// SetClientAuthorizedHandler Check the client allows to use this authorization grant type +func (s *Server) SetClientAuthorizedHandler(handler ClientAuthorizedHandler) { + s.ClientAuthorizedHandler = handler } -// SetUserHandler 设置用户处理 -func (s *Server) SetUserHandler(handler UserHandler) { - s.cfg.Handler.UserHandler = handler +// SetClientScopeHandler Check the client allows to use scope +func (s *Server) SetClientScopeHandler(handler ClientScopeHandler) { + s.ClientScopeHandler = handler } -// SetScopeHandler 设置授权范围处理 -func (s *Server) SetScopeHandler(handler ScopeHandler) { - s.cfg.Handler.ScopeHandler = handler +// SetUserAuthorizationHandler Get user id from request authorization +func (s *Server) SetUserAuthorizationHandler(handler UserAuthorizationHandler) { + s.UserAuthorizationHandler = handler } -// checkResponseType 检查允许的授权类型 -func (s *Server) checkResponseType(rt oauth2.ResponseType) bool { - for _, art := range s.cfg.AllowedResponseType { +// SetPasswordAuthorizationHandler Get user id from username and password +func (s *Server) SetPasswordAuthorizationHandler(handler PasswordAuthorizationHandler) { + s.PasswordAuthorizationHandler = handler +} + +// SetRefreshingScopeHandler Check the scope of the refreshing token +func (s *Server) SetRefreshingScopeHandler(handler RefreshingScopeHandler) { + s.RefreshingScopeHandler = handler +} + +// SetErrorHandler Error handling +func (s *Server) SetErrorHandler(handler ErrorHandler) { + s.ErrorHandler = handler +} + +// CheckResponseType Check allows response type +func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool { + for _, art := range s.Config.AllowedResponseTypes { if art == rt { return true } @@ -66,9 +89,9 @@ func (s *Server) checkResponseType(rt oauth2.ResponseType) bool { return false } -// checkGrantType 检查允许的授权模式 -func (s *Server) checkGrantType(gt oauth2.GrantType) bool { - for _, agt := range s.cfg.AllowedGrantType { +// CheckGrantType Check allows grant type +func (s *Server) CheckGrantType(gt oauth2.GrantType) bool { + for _, agt := range s.Config.AllowedGrantTypes { if agt == gt { return true } @@ -76,75 +99,106 @@ func (s *Server) checkGrantType(gt oauth2.GrantType) bool { return false } -// GetAuthorizeRequest 获取授权请求参数 -func (s *Server) GetAuthorizeRequest(r *http.Request) (authReq *AuthorizeRequest, err error) { - if r.Method != "GET" { - err = ErrRequestMethodInvalid +// ValidationAuthorizeRequest The authorization request validation +func (s *Server) ValidationAuthorizeRequest(r *http.Request) (req *AuthorizeRequest, rerr, ierr error) { + if err := r.ParseForm(); err != nil { + ierr = err return } - r.ParseForm() redirectURI, err := url.QueryUnescape(r.Form.Get("redirect_uri")) if err != nil { + ierr = err return } - authReq = &AuthorizeRequest{ - Type: oauth2.ResponseType(r.Form.Get("response_type")), - RedirectURI: redirectURI, - State: r.Form.Get("state"), - Scope: r.Form.Get("scope"), - ClientID: r.Form.Get("client_id"), + req = &AuthorizeRequest{ + RedirectURI: redirectURI, + ResponseType: oauth2.ResponseType(r.Form.Get("response_type")), + ClientID: r.Form.Get("client_id"), + State: r.Form.Get("state"), + Scope: r.Form.Get("scope"), } - if authReq.Type == "" || !s.checkResponseType(authReq.Type) { - err = ErrResponseTypeInvalid - return - } else if authReq.ClientID == "" { - err = ErrClientInvalid + if r.Method != "GET" { + rerr = errors.ErrInvalidRequest } return } -// HandleAuthorizeRequest 处理授权请求 -func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, authReq *AuthorizeRequest) (err error) { - if authReq.UserID == "" { - err = ErrUserInvalid +// GetAuthorizeToken Get authorization token(code) +func (s *Server) GetAuthorizeToken(req *AuthorizeRequest) (ti oauth2.TokenInfo, rerr, ierr error) { + if req.RedirectURI == "" || + req.ClientID == "" || + req.UserID == "" { + rerr = errors.ErrInvalidRequest + return + } else if req.ResponseType == "" { + rerr = errors.ErrUnsupportedResponseType return } - tgr := &oauth2.TokenGenerateRequest{ - ClientID: authReq.ClientID, - UserID: authReq.UserID, - RedirectURI: authReq.RedirectURI, - Scope: authReq.Scope, - } - ti, err := s.manager.GenerateAuthToken(oauth2.Code, tgr) - if err != nil { + if allowed := s.CheckResponseType(req.ResponseType); !allowed { + rerr = errors.ErrUnauthorizedClient return } - redirectURI, err := s.GetRedirectURI(authReq, ti) + if fn := s.ClientAuthorizedHandler; fn != nil { + gt := oauth2.AuthorizationCode + if req.ResponseType == oauth2.Token { + gt = oauth2.Implicit + } + allowed, err := fn(req.ClientID, gt) + if err != nil { + ierr = err + return + } + if !allowed { + rerr = errors.ErrUnauthorizedClient + return + } + } + if fn := s.ClientScopeHandler; fn != nil { + allowed, err := fn(req.ClientID, req.Scope) + if err != nil { + ierr = err + return + } + if !allowed { + rerr = errors.ErrInvalidScope + return + } + } + tgr := &oauth2.TokenGenerateRequest{ + ClientID: req.ClientID, + UserID: req.UserID, + RedirectURI: req.RedirectURI, + Scope: req.Scope, + } + ti, err := s.Manager.GenerateAuthToken(req.ResponseType, tgr) if err != nil { - return + if err == errors.ErrInvalidClient { + rerr = err + } else { + ierr = err + } } - w.Header().Set("Location", redirectURI) - w.WriteHeader(302) return } -// GetRedirectURI 获取重定向URI -func (s *Server) GetRedirectURI(authReq *AuthorizeRequest, ti oauth2.TokenInfo) (uri string, err error) { - u, err := url.Parse(authReq.RedirectURI) +// GetRedirectURI Get redirect uri +func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (uri string, err error) { + if req == nil { + return + } + u, err := url.Parse(req.RedirectURI) if err != nil { return } q := u.Query() - q.Set("state", authReq.State) - switch authReq.Type { + q.Set("state", req.State) + for k, v := range data { + q.Set(k, fmt.Sprint(v)) + } + switch req.ResponseType { case oauth2.Code: - q.Set("code", ti.GetAccess()) u.RawQuery = q.Encode() case oauth2.Token: - q.Set("access_token", ti.GetAccess()) - q.Set("token_type", s.cfg.TokenType) - q.Set("expires_in", strconv.FormatInt(int64(ti.GetAccessExpiresIn()/time.Second), 10)) - q.Set("scope", ti.GetScope()) u.RawQuery = "" u.Fragment, err = url.QueryUnescape(q.Encode()) if err != nil { @@ -155,81 +209,229 @@ func (s *Server) GetRedirectURI(authReq *AuthorizeRequest, ti oauth2.TokenInfo) return } -// HandleTokenRequest 处理令牌请求 -func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) (err error) { +// GetAuthorizeData Get authorization response data +func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) (data map[string]interface{}) { + if rt == oauth2.Code { + data = map[string]interface{}{ + "code": ti.GetAccess(), + } + } else { + data = s.GetTokenData(ti) + } + return +} + +// GetErrorData Get error response data +func (s *Server) GetErrorData(rerr, ierr error) (data map[string]interface{}) { + var err error + if ierr != nil { + rerr = errors.ErrServerError + err = ierr + } else if rerr != nil { + err = rerr + ierr = rerr + } + if err == nil { + return + } + if fn := s.ErrorHandler; fn != nil { + s.ErrorHandler(err) + } + data = map[string]interface{}{ + "error": err.Error(), + } + return +} + +// HandleAuthorizeRequest The authorization request handling +func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) (err error) { + var ( + ti oauth2.TokenInfo + req *AuthorizeRequest + rerr error + ierr error + ) + defer func() { + if verr := recover(); verr != nil { + err = fmt.Errorf("%v", verr) + return + } + data := s.GetErrorData(rerr, ierr) + if data != nil { + if req == nil { + err = ierr + return + } + } else { + data = s.GetAuthorizeData(req.ResponseType, ti) + } + uri, verr := s.GetRedirectURI(req, data) + if verr != nil { + err = verr + return + } + w.Header().Set("Location", uri) + w.WriteHeader(302) + }() + req, rerr, ierr = s.ValidationAuthorizeRequest(r) + if rerr != nil || ierr != nil { + return + } + userID, err := s.UserAuthorizationHandler(w, r) + if err != nil { + ierr = err + return + } + req.UserID = userID + ti, rerr, ierr = s.GetAuthorizeToken(req) + return +} + +// ValidationTokenRequest The token request validation +func (s *Server) ValidationTokenRequest(r *http.Request) (gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest, rerr, ierr error) { if r.Method != "POST" { - err = ErrRequestMethodInvalid + rerr = errors.ErrInvalidRequest return } - r.ParseForm() - gt := oauth2.GrantType(r.Form.Get("grant_type")) - if gt == "" || !s.checkGrantType(gt) { - err = ErrGrantTypeInvalid + if err := r.ParseForm(); err != nil { + ierr = err return } - var ti oauth2.TokenInfo - clientID, clientSecret, err := s.cfg.Handler.ClientHandler(r) + gt = oauth2.GrantType(r.Form.Get("grant_type")) + if gt == "" { + rerr = errors.ErrUnsupportedGrantType + return + } + clientID, clientSecret, err := s.ClientInfoHandler(r) if err != nil { + ierr = err return } - tgr := &oauth2.TokenGenerateRequest{ + tgr = &oauth2.TokenGenerateRequest{ ClientID: clientID, ClientSecret: clientSecret, } switch gt { - case oauth2.AuthorizationCodeCredentials: + case oauth2.AuthorizationCode: tgr.RedirectURI = r.Form.Get("redirect_uri") tgr.Code = r.Form.Get("code") - tgr.IsGenerateRefresh = true - ti, err = s.manager.GenerateAccessToken(oauth2.AuthorizationCodeCredentials, tgr) + if tgr.RedirectURI == "" || + tgr.Code == "" { + rerr = errors.ErrInvalidRequest + } case oauth2.PasswordCredentials: - userID, uerr := s.cfg.Handler.UserHandler(r.Form.Get("username"), r.Form.Get("password")) - if uerr != nil { - err = uerr + tgr.Scope = r.Form.Get("scope") + userID, verr := s.PasswordAuthorizationHandler(r.Form.Get("username"), r.Form.Get("password")) + if verr != nil { + ierr = verr + return + } + if userID == "" { + rerr = errors.ErrInvalidRequest return } tgr.UserID = userID - tgr.Scope = r.Form.Get("scope") - tgr.IsGenerateRefresh = true - ti, err = s.manager.GenerateAccessToken(oauth2.PasswordCredentials, tgr) case oauth2.ClientCredentials: tgr.Scope = r.Form.Get("scope") - ti, err = s.manager.GenerateAccessToken(oauth2.ClientCredentials, tgr) - case oauth2.RefreshCredentials: + case oauth2.Refreshing: tgr.Refresh = r.Form.Get("refresh_token") tgr.Scope = r.Form.Get("scope") - if tgr.Scope != "" { // 检查授权范围 - rti, rerr := s.manager.LoadRefreshToken(tgr.Refresh) - if rerr != nil { - err = rerr + if tgr.Refresh == "" { + rerr = errors.ErrInvalidRequest + } + } + return +} + +// GetAccessToken Get access token +func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (ti oauth2.TokenInfo, rerr, ierr error) { + if allowed := s.CheckGrantType(gt); !allowed { + rerr = errors.ErrUnauthorizedClient + return + } + if fn := s.ClientAuthorizedHandler; fn != nil { + allowed, err := fn(tgr.ClientID, gt) + if err != nil { + ierr = err + return + } + if !allowed { + rerr = errors.ErrUnauthorizedClient + return + } + } + switch gt { + case oauth2.AuthorizationCode: + ti, ierr = s.Manager.GenerateAccessToken(gt, tgr) + if ierr != nil { + if ierr == errors.ErrInvalidAuthorizeCode { + rerr = errors.ErrInvalidGrant + ierr = nil + } else if ierr == errors.ErrInvalidClient { + rerr = errors.ErrInvalidClient + ierr = nil + } + } + case oauth2.PasswordCredentials: + fallthrough + case oauth2.ClientCredentials: + if fn := s.ClientScopeHandler; fn != nil { + allowed, err := fn(tgr.ClientID, tgr.Scope) + if err != nil { + ierr = err + return + } + if !allowed { + rerr = errors.ErrInvalidScope return - } else if rti.GetClientID() != tgr.ClientID { - err = ErrRefreshInvalid + } + } + ti, ierr = s.Manager.GenerateAccessToken(gt, tgr) + if ierr != nil { + if ierr == errors.ErrInvalidClient { + rerr = errors.ErrInvalidClient + ierr = nil + } + } + case oauth2.Refreshing: + if scope := tgr.Scope; scope != "" { + rti, err := s.Manager.LoadRefreshToken(tgr.Refresh) + if err != nil { + if err == errors.ErrInvalidRefreshToken { + rerr = err + return + } + ierr = err return - } else if verr := s.cfg.Handler.ScopeHandler(tgr.Scope, rti.GetScope()); verr != nil { - err = verr + } + if fn := s.RefreshingScopeHandler; fn != nil && !fn(scope, rti.GetScope()) { + rerr = errors.ErrInvalidScope return } } - ti, err = s.manager.RefreshAccessToken(tgr) - if err == nil { + ti, ierr = s.Manager.RefreshAccessToken(tgr) + if ierr != nil { + if ierr == errors.ErrInvalidClient { + rerr = errors.ErrInvalidClient + ierr = nil + } else if ierr == errors.ErrInvalidRefreshToken { + rerr = errors.ErrInvalidRefreshToken + ierr = nil + } + } else { ti.SetRefresh("") } } - if err != nil { - return - } - err = s.ResJSON(w, ti) return } -// ResJSON 响应Json数据 -func (s *Server) ResJSON(w http.ResponseWriter, ti oauth2.TokenInfo) (err error) { - data := map[string]interface{}{ +// GetTokenData Get token data +func (s *Server) GetTokenData(ti oauth2.TokenInfo) (data map[string]interface{}) { + data = map[string]interface{}{ "access_token": ti.GetAccess(), - "token_type": s.cfg.TokenType, - "expires_in": ti.GetAccessExpiresIn() / time.Second, + "token_type": s.Config.TokenType, + "expires_in": int64(ti.GetAccessExpiresIn() / time.Second), } if scope := ti.GetScope(); scope != "" { data["scope"] = scope @@ -237,10 +439,35 @@ func (s *Server) ResJSON(w http.ResponseWriter, ti oauth2.TokenInfo) (err error) if refresh := ti.GetRefresh(); refresh != "" { data["refresh_token"] = refresh } - w.Header().Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate") - w.Header().Set("Pragma", "no-cache") - w.Header().Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT") - w.Header().Set("Content-Type", "application/json;charset=UTF-8") - w.WriteHeader(http.StatusOK) - return json.NewEncoder(w).Encode(data) + return +} + +// HandleTokenRequest The token request handling +func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) (err error) { + var ( + ti oauth2.TokenInfo + rerr error + ierr error + ) + defer func() { + if verr := recover(); verr != nil { + err = fmt.Errorf("%v", verr) + return + } + data := s.GetErrorData(rerr, ierr) + if data == nil { + data = s.GetTokenData(ti) + } + w.Header().Set("Content-Type", "application/json;charset=UTF-8") + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + w.WriteHeader(http.StatusOK) + err = json.NewEncoder(w).Encode(data) + }() + gt, tgr, rerr, ierr := s.ValidationTokenRequest(r) + if rerr != nil || ierr != nil { + return + } + ti, rerr, ierr = s.GetAccessToken(gt, tgr) + return } diff --git a/server/server_test.go b/server/server_test.go new file mode 100644 index 0000000..1867b9f --- /dev/null +++ b/server/server_test.go @@ -0,0 +1,260 @@ +package server_test + +import ( + "errors" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gavv/httpexpect" + "gopkg.in/oauth2.v3/manage" + "gopkg.in/oauth2.v3/models" + "gopkg.in/oauth2.v3/server" + "gopkg.in/oauth2.v3/store/client" + "gopkg.in/oauth2.v3/store/token" +) + +var ( + srv *server.Server + tsrv *httptest.Server + manager *manage.Manager + csrv *httptest.Server +) + +func init() { + manager = manage.NewRedisManager( + &token.RedisConfig{Addr: "192.168.33.70:6379"}, + ) +} + +func testServer(t *testing.T, w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/authorize": + err := srv.HandleAuthorizeRequest(w, r) + if err != nil { + t.Error(err) + } + case "/token": + err := srv.HandleTokenRequest(w, r) + if err != nil { + t.Error(err) + } + } +} + +func TestAuthorizeCode(t *testing.T) { + tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + testServer(t, w, r) + })) + defer tsrv.Close() + e := httpexpect.New(t, tsrv.URL) + + csrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth2": + r.ParseForm() + code, state := r.Form.Get("code"), r.Form.Get("state") + if state != "123" { + t.Error("unrecognized state:", state) + return + } + val := e.POST("/token"). + WithFormField("redirect_uri", csrv.URL+"/oauth2"). + WithFormField("code", code). + WithFormField("grant_type", "authorization_code"). + WithFormField("client_id", "333333"). + WithFormField("client_secret", "33333333"). + Expect(). + Status(http.StatusOK). + JSON().Raw() + + t.Log(val) + } + })) + defer csrv.Close() + + manager.MapClientStorage(client.NewTempStore(&models.Client{ + ID: "333333", + Secret: "33333333", + Domain: csrv.URL, + })) + + srv = server.NewServer(server.NewConfig(), manager) + srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) { + userID = "111111" + return + }) + + e.GET("/authorize"). + WithQuery("response_type", "code"). + WithQuery("client_id", "333333"). + WithQuery("scope", "all"). + WithQuery("state", "123"). + WithQuery("redirect_uri", url.QueryEscape(csrv.URL+"/oauth2")). + Expect().Status(http.StatusOK) +} + +func TestImplicit(t *testing.T) { + tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + testServer(t, w, r) + })) + defer tsrv.Close() + e := httpexpect.New(t, tsrv.URL) + + csrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth2": + t.Log(r.RequestURI) + } + })) + defer csrv.Close() + + manager.MapClientStorage(client.NewTempStore(&models.Client{ + ID: "55555", + Secret: "5555555", + Domain: csrv.URL, + })) + + srv = server.NewServer(server.NewConfig(), manager) + srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) { + userID = "222222" + return + }) + + e.GET("/authorize"). + WithQuery("response_type", "token"). + WithQuery("client_id", "55555"). + WithQuery("scope", "all"). + WithQuery("state", "123"). + WithQuery("redirect_uri", url.QueryEscape(csrv.URL+"/oauth2")). + Expect().Status(http.StatusOK) +} + +func TestPasswordCredentials(t *testing.T) { + tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + testServer(t, w, r) + })) + defer tsrv.Close() + e := httpexpect.New(t, tsrv.URL) + + manager.MapClientStorage(client.NewTempStore(&models.Client{ + ID: "666666", + Secret: "66666666", + Domain: csrv.URL, + })) + + srv = server.NewServer(server.NewConfig(), manager) + srv.SetPasswordAuthorizationHandler(func(username, password string) (userID string, err error) { + if username == "admin" && password == "123456" { + userID = "666666" + return + } + err = errors.New("user not found") + return + }) + + val := e.POST("/token"). + WithFormField("grant_type", "password"). + WithFormField("client_id", "666666"). + WithFormField("client_secret", "66666666"). + WithFormField("username", "admin"). + WithFormField("password", "123456"). + WithFormField("scope", "all"). + Expect(). + Status(http.StatusOK). + JSON().Raw() + + t.Log(val) +} + +func TestClientCredentials(t *testing.T) { + tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + testServer(t, w, r) + })) + defer tsrv.Close() + e := httpexpect.New(t, tsrv.URL) + + manager.MapClientStorage(client.NewTempStore(&models.Client{ + ID: "777777", + Secret: "77777777", + Domain: csrv.URL, + })) + + srv = server.NewServer(server.NewConfig(), manager) + + val := e.POST("/token"). + WithFormField("grant_type", "clientcredentials"). + WithFormField("client_id", "777777"). + WithFormField("client_secret", "77777777"). + WithFormField("scope", "all"). + Expect(). + Status(http.StatusOK). + JSON().Raw() + + t.Log(val) +} + +func TestRefreshing(t *testing.T) { + tsrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + testServer(t, w, r) + })) + defer tsrv.Close() + e := httpexpect.New(t, tsrv.URL) + + csrv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth2": + r.ParseForm() + code, state := r.Form.Get("code"), r.Form.Get("state") + if state != "123" { + t.Error("unrecognized state:", state) + return + } + jval := e.POST("/token"). + WithFormField("redirect_uri", csrv.URL+"/oauth2"). + WithFormField("code", code). + WithFormField("grant_type", "authorization_code"). + WithFormField("client_id", "888888"). + WithFormField("client_secret", "88888888"). + Expect(). + Status(http.StatusOK). + JSON() + + refresh := jval.Object().Value("refresh_token").String().Raw() + + rval := e.POST("/token"). + WithFormField("grant_type", "refreshtoken"). + WithFormField("client_id", "888888"). + WithFormField("client_secret", "88888888"). + WithFormField("scope", "one"). + WithFormField("refresh_token", refresh). + Expect(). + Status(http.StatusOK). + JSON().Raw() + + t.Log(rval) + } + })) + defer csrv.Close() + + manager.MapClientStorage(client.NewTempStore(&models.Client{ + ID: "888888", + Secret: "88888888", + Domain: csrv.URL, + })) + + srv = server.NewServer(server.NewConfig(), manager) + srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (userID string, err error) { + userID = "888888" + return + }) + + e.GET("/authorize"). + WithQuery("response_type", "code"). + WithQuery("client_id", "888888"). + WithQuery("scope", "all"). + WithQuery("state", "123"). + WithQuery("redirect_uri", url.QueryEscape(csrv.URL+"/oauth2")). + Expect().Status(http.StatusOK) +} diff --git a/store.go b/store.go index c2ac40d..60c37b2 100644 --- a/store.go +++ b/store.go @@ -1,28 +1,27 @@ package oauth2 -// 提供存储接口 type ( - // ClientStore 客户端信息存储接口 + // ClientStore The client information storage interface ClientStore interface { - // GetByID 根据ID获取客户端信息 + // GetByID According to the ID for the client information GetByID(id string) (ClientInfo, error) } - // TokenStore 令牌信息存储接口 + // TokenStore The token information storage interface TokenStore interface { - // Create 创建并存储新的令牌信息 + // Create Create and store the new token information Create(info TokenInfo) error - // RemoveByAccess 使用访问令牌删除令牌信息 + // RemoveByAccess Use the access token to delete the token information(Along with the refresh token) RemoveByAccess(access string) error - // RemoveByRefresh 使用更新令牌删除令牌信息 + // RemoveByRefresh Use the refresh token to delete the token information RemoveByRefresh(refresh string) error - // 使用访问令牌获取令牌信息数据 + // Use the access token for token information data GetByAccess(access string) (TokenInfo, error) - // 根据更新令牌获取令牌信息数据 + // Use the refresh token for token information data GetByRefresh(refresh string) (TokenInfo, error) } ) diff --git a/store/client/temp.go b/store/client/temp.go index c0e126c..a59293f 100644 --- a/store/client/temp.go +++ b/store/client/temp.go @@ -1,13 +1,11 @@ package client import ( - "errors" - - "gopkg.in/oauth2.v2" - "gopkg.in/oauth2.v2/models" + "gopkg.in/oauth2.v3" + "gopkg.in/oauth2.v3/models" ) -// NewTempStore 创建客户端临时存储实例 +// NewTempStore Create to client information temporary store instance func NewTempStore(clients ...*models.Client) oauth2.ClientStore { data := map[string]*models.Client{ "1": &models.Client{ @@ -24,17 +22,15 @@ func NewTempStore(clients ...*models.Client) oauth2.ClientStore { } } -// TempStore 客户端信息的临时存储 +// TempStore Client information store type TempStore struct { data map[string]*models.Client } -// GetByID 获取客户端信息 +// GetByID According to the ID for the client information func (ts *TempStore) GetByID(id string) (cli oauth2.ClientInfo, err error) { if c, ok := ts.data[id]; ok { cli = c - return } - err = errors.New("not found") return } diff --git a/store/token/mongo.go b/store/token/mongo.go deleted file mode 100644 index ee699f3..0000000 --- a/store/token/mongo.go +++ /dev/null @@ -1,148 +0,0 @@ -package token - -import ( - "time" - - "gopkg.in/LyricTian/lib.v2/mongo" - "gopkg.in/mgo.v2" - "gopkg.in/mgo.v2/bson" - "gopkg.in/oauth2.v2" - "gopkg.in/oauth2.v2/models" -) - -// MongoConfig MongoDB Configuration -type MongoConfig struct { - // Connection String - URL string - // DB Name(default oauth2) - DB string - // Collection Name(default tokens) - C string -} - -// NewMongoStore 创建MongoDB的令牌存储 -func NewMongoStore(cfg *MongoConfig) (store oauth2.TokenStore, err error) { - if cfg.DB == "" { - cfg.DB = "oauth2" - } - if cfg.C == "" { - cfg.C = "tokens" - } - handler, err := mongo.InitHandlerWithDB(cfg.URL, cfg.DB) - if err != nil { - return - } - // 创建自动过期索引 - err = handler.C(cfg.C).EnsureIndex(mgo.Index{ - Key: []string{"ExpiredAt"}, - ExpireAfter: time.Second * 1, - }) - if err != nil { - return - } - err = handler.C(cfg.C).EnsureIndexKey("Access") - if err != nil { - return - } - err = handler.C(cfg.C).EnsureIndexKey("Refresh") - if err != nil { - return - } - store = &MongoStore{ - handler: handler, - cfg: cfg, - } - return -} - -// MongoStore MongoDB Store -type MongoStore struct { - cfg *MongoConfig - handler *mongo.Handler -} - -// Create 存储令牌信息 -func (ms *MongoStore) Create(info oauth2.TokenInfo) (err error) { - tm := info.(*models.Token) - var expiredAt time.Time - if refresh := tm.Refresh; refresh != "" { - expiredAt = tm.RefreshCreateAt.Add(tm.RefreshExpiresIn) - rinfo, rerr := ms.GetByRefresh(refresh) - if rerr != nil { - err = rerr - return - } - if rinfo != nil { - expiredAt = rinfo.GetRefreshCreateAt().Add(rinfo.GetRefreshExpiresIn()) - } - } - if expiredAt.IsZero() { - expiredAt = tm.AccessCreateAt.Add(tm.AccessExpiresIn) - } - doc := map[string]interface{}{ - "ExpiredAt": expiredAt, - "ClientID": tm.ClientID, - "UserID": tm.UserID, - "RedirectURI": tm.RedirectURI, - "Scope": tm.Scope, - "AuthType": tm.AuthType, - "Access": tm.Access, - "AccessCreateAt": tm.AccessCreateAt, - "AccessExpiresIn": tm.AccessExpiresIn, - "Refresh": tm.Refresh, - "RefreshCreateAt": tm.RefreshCreateAt, - "RefreshExpiresIn": tm.RefreshExpiresIn, - } - - ms.handler.CHandle(ms.cfg.C, func(c *mgo.Collection) { - err = c.Insert(doc) - }) - return -} - -func (ms *MongoStore) remove(selector interface{}) (err error) { - ms.handler.CHandle(ms.cfg.C, func(c *mgo.Collection) { - err = c.Remove(selector) - }) - return -} - -// RemoveByAccess 移除令牌 -func (ms *MongoStore) RemoveByAccess(access string) (err error) { - err = ms.remove(bson.M{"Access": access}) - return -} - -// RemoveByRefresh 移除令牌 -func (ms *MongoStore) RemoveByRefresh(refresh string) (err error) { - err = ms.remove(bson.M{"Refresh": refresh}) - return -} - -func (ms *MongoStore) get(find interface{}) (info oauth2.TokenInfo, err error) { - ms.handler.CHandle(ms.cfg.C, func(c *mgo.Collection) { - var tm models.Token - aerr := c.Find(find).Select(bson.M{"_id": 0}).Sort("-_id").One(&tm) - if aerr != nil { - if aerr == mgo.ErrNotFound { - return - } - err = aerr - return - } - info = &tm - }) - return -} - -// GetByAccess 获取令牌数据 -func (ms *MongoStore) GetByAccess(access string) (info oauth2.TokenInfo, err error) { - info, err = ms.get(bson.M{"Access": access}) - return -} - -// GetByRefresh 获取令牌数据 -func (ms *MongoStore) GetByRefresh(refresh string) (info oauth2.TokenInfo, err error) { - info, err = ms.get(bson.M{"Refresh": refresh}) - return -} diff --git a/store/token/mongo_test.go b/store/token/mongo_test.go deleted file mode 100644 index 6e5a2fd..0000000 --- a/store/token/mongo_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package token - -import ( - "testing" - - . "github.com/smartystreets/goconvey/convey" -) - -const ( - mongoURL = "mongodb://admin:123456@192.168.33.70:27017" -) - -func TestMongoStore(t *testing.T) { - Convey("Test mongo store", t, func() { - cfg := &MongoConfig{ - URL: mongoURL, - } - store, err := NewMongoStore(cfg) - So(err, ShouldBeNil) - - Convey("Test access token store", func() { - testAccessStore(store) - }) - - Convey("Test refresh token store", func() { - testRefreshStore(store) - }) - }) -} diff --git a/store/token/redis.go b/store/token/redis.go index 54f84f7..575f26d 100644 --- a/store/token/redis.go +++ b/store/token/redis.go @@ -3,12 +3,18 @@ package token import ( "encoding/json" - "gopkg.in/oauth2.v2" - "gopkg.in/oauth2.v2/models" "gopkg.in/redis.v4" + + "strconv" + + "gopkg.in/oauth2.v3" + "gopkg.in/oauth2.v3/models" ) -// NewRedisStore 创建redis存储的实例 +// DefaultIncrKey store incr id +const DefaultIncrKey = "oauth2_incr" + +// NewRedisStore Create a token store instance based on redis func NewRedisStore(cfg *RedisConfig) (store oauth2.TokenStore, err error) { opt := &redis.Options{ Network: cfg.Network, @@ -31,36 +37,47 @@ func NewRedisStore(cfg *RedisConfig) (store oauth2.TokenStore, err error) { return } -// RedisStore 令牌的redis存储 +// RedisStore Redis Store type RedisStore struct { cli *redis.Client } -// Create 存储令牌信息 +func (rs *RedisStore) getBasicID(id int64, info oauth2.TokenInfo) string { + return "oauth2_" + info.GetClientID() + "_" + strconv.FormatInt(id, 10) +} + +// Create Create and store the new token information func (rs *RedisStore) Create(info oauth2.TokenInfo) (err error) { jv, err := json.Marshal(info) if err != nil { return } + id, err := rs.cli.Incr(DefaultIncrKey).Result() + if err != nil { + return + } pipe := rs.cli.Pipeline() - + basicID := rs.getBasicID(id, info) aexp := info.GetAccessExpiresIn() + rexp := aexp + if refresh := info.GetRefresh(); refresh != "" { - exp := info.GetRefreshExpiresIn() + rexp = info.GetRefreshExpiresIn() ttl := rs.cli.TTL(refresh) if verr := ttl.Err(); verr != nil { err = verr return } if v := ttl.Val(); v.Seconds() > 0 { - exp = v + rexp = v } - if aexp.Seconds() > exp.Seconds() { - aexp = exp + if aexp.Seconds() > rexp.Seconds() { + aexp = rexp } - pipe.Set(refresh, jv, exp) + pipe.Set(refresh, basicID, rexp) } - pipe.Set(info.GetAccess(), jv, aexp) + pipe.Set(info.GetAccess(), basicID, aexp) + pipe.Set(basicID, jv, rexp) if _, verr := pipe.Exec(); verr != nil { err = verr @@ -70,44 +87,49 @@ func (rs *RedisStore) Create(info oauth2.TokenInfo) (err error) { // remove func (rs *RedisStore) remove(key string) (err error) { - info, err := rs.get(key) - if err != nil || info == nil { - return - } - pipe := rs.cli.Pipeline() - pipe.Del(info.GetAccess()) - if v := info.GetRefresh(); v != "" { - pipe.Del(v) - } - if _, verr := pipe.Exec(); verr != nil { + _, verr := rs.cli.Del(key).Result() + if verr != redis.Nil { err = verr } return } -// RemoveByAccess 移除令牌 +// RemoveByAccess Use the access token to delete the token information(Along with the refresh token) func (rs *RedisStore) RemoveByAccess(access string) (err error) { err = rs.remove(access) return } -// RemoveByRefresh 移除令牌 +// RemoveByRefresh Use the refresh token to delete the token information func (rs *RedisStore) RemoveByRefresh(refresh string) (err error) { err = rs.remove(refresh) return } -func (rs *RedisStore) get(key string) (ti oauth2.TokenInfo, err error) { - gv, gerr := rs.cli.Get(key).Result() - if gerr != nil { - if gerr == redis.Nil { +// get +func (rs *RedisStore) get(token string) (ti oauth2.TokenInfo, err error) { + tv, verr := rs.cli.Get(token).Result() + if verr != nil { + if verr == redis.Nil { + return + } + err = verr + return + } + result := rs.cli.Get(tv) + if verr := result.Err(); verr != nil { + if verr == redis.Nil { return } - err = gerr + err = verr + return + } + iv, err := result.Bytes() + if err != nil { return } var tm models.Token - if verr := json.Unmarshal([]byte(gv), &tm); verr != nil { + if verr := json.Unmarshal(iv, &tm); verr != nil { err = verr return } @@ -115,13 +137,13 @@ func (rs *RedisStore) get(key string) (ti oauth2.TokenInfo, err error) { return } -// GetByAccess 获取令牌数据 +// GetByAccess Use the access token for token information data func (rs *RedisStore) GetByAccess(access string) (ti oauth2.TokenInfo, err error) { ti, err = rs.get(access) return } -// GetByRefresh 获取令牌数据 +// GetByRefresh Use the refresh token for token information data func (rs *RedisStore) GetByRefresh(refresh string) (ti oauth2.TokenInfo, err error) { ti, err = rs.get(refresh) return diff --git a/store/token/redis_test.go b/store/token/redis_test.go index d658656..fefa63c 100644 --- a/store/token/redis_test.go +++ b/store/token/redis_test.go @@ -1,25 +1,74 @@ -package token +package token_test import ( "testing" + "time" + + "gopkg.in/oauth2.v3/models" + "gopkg.in/oauth2.v3/store/token" . "github.com/smartystreets/goconvey/convey" ) func TestRedisStore(t *testing.T) { Convey("Test redis store", t, func() { - cfg := &RedisConfig{ + cfg := &token.RedisConfig{ Addr: "192.168.33.70:6379", } - store, err := NewRedisStore(cfg) + store, err := token.NewRedisStore(cfg) So(err, ShouldBeNil) Convey("Test access token store", func() { - testAccessStore(store) + info := &models.Token{ + ClientID: "1", + UserID: "1_1", + RedirectURI: "http://localhost/", + Scope: "all", + Access: "1_1_1", + AccessCreateAt: time.Now(), + AccessExpiresIn: time.Second * 5, + } + err := store.Create(info) + So(err, ShouldBeNil) + + ainfo, err := store.GetByAccess(info.GetAccess()) + So(err, ShouldBeNil) + So(ainfo.GetUserID(), ShouldEqual, info.GetUserID()) + + err = store.RemoveByAccess(info.GetAccess()) + So(err, ShouldBeNil) + + ainfo, err = store.GetByAccess(info.GetAccess()) + So(err, ShouldNotBeNil) + So(ainfo, ShouldBeNil) }) Convey("Test refresh token store", func() { - testRefreshStore(store) + info := &models.Token{ + ClientID: "1", + UserID: "1_2", + RedirectURI: "http://localhost/", + Scope: "all", + Access: "1_2_1", + AccessCreateAt: time.Now(), + AccessExpiresIn: time.Second * 5, + Refresh: "1_2_2", + RefreshCreateAt: time.Now(), + RefreshExpiresIn: time.Minute * 1, + } + err := store.Create(info) + So(err, ShouldBeNil) + + rinfo, err := store.GetByRefresh(info.GetRefresh()) + So(err, ShouldBeNil) + So(rinfo.GetUserID(), ShouldEqual, info.GetUserID()) + + err = store.RemoveByRefresh(info.GetRefresh()) + So(err, ShouldBeNil) + + rinfo, err = store.GetByRefresh(info.GetRefresh()) + So(err, ShouldNotBeNil) + So(rinfo, ShouldBeNil) }) }) } diff --git a/store/token/token_test.go b/store/token/token_test.go deleted file mode 100644 index 87742a5..0000000 --- a/store/token/token_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package token - -import ( - "time" - - . "github.com/smartystreets/goconvey/convey" - "gopkg.in/oauth2.v2" - "gopkg.in/oauth2.v2/models" -) - -func testAccessStore(store oauth2.TokenStore) { - info := &models.Token{ - ClientID: "1", - UserID: "1_1", - RedirectURI: "http://localhost/", - Scope: "all", - AuthType: oauth2.Code.String(), - Access: "1_1_1", - AccessCreateAt: time.Now(), - AccessExpiresIn: time.Second * 5, - } - err := store.Create(info) - So(err, ShouldBeNil) - - ainfo, err := store.GetByAccess(info.GetAccess()) - So(err, ShouldBeNil) - So(ainfo.GetUserID(), ShouldEqual, info.GetUserID()) - - err = store.RemoveByAccess(info.GetAccess()) - So(err, ShouldBeNil) - - ainfo, err = store.GetByAccess(info.GetAccess()) - So(err, ShouldBeNil) - So(ainfo, ShouldBeNil) -} - -func testRefreshStore(store oauth2.TokenStore) { - info := &models.Token{ - ClientID: "1", - UserID: "1_2", - RedirectURI: "http://localhost/", - Scope: "all", - AuthType: oauth2.Code.String(), - Access: "1_2_1", - AccessCreateAt: time.Now(), - AccessExpiresIn: time.Second * 5, - Refresh: "1_2_2", - RefreshCreateAt: time.Now(), - RefreshExpiresIn: time.Minute * 1, - } - err := store.Create(info) - So(err, ShouldBeNil) - - rinfo, err := store.GetByRefresh(info.GetRefresh()) - So(err, ShouldBeNil) - So(rinfo.GetUserID(), ShouldEqual, info.GetUserID()) - - err = store.RemoveByRefresh(info.GetRefresh()) - So(err, ShouldBeNil) - - rinfo, err = store.GetByRefresh(info.GetRefresh()) - So(err, ShouldBeNil) - So(rinfo, ShouldBeNil) -}