Skip to content

Commit

Permalink
switching jwt library (#31)
Browse files Browse the repository at this point in the history
* switching jwt library

* copy the middleware from legacy goa
  • Loading branch information
glaslos authored Oct 4, 2023
1 parent b1e100f commit b3491e3
Show file tree
Hide file tree
Showing 5 changed files with 308 additions and 16 deletions.
43 changes: 33 additions & 10 deletions api/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@ import (
"github.com/Vivino/rankdb"
"github.com/Vivino/rankdb/api/app"
"github.com/Vivino/rankdb/log"
"github.com/dgrijalva/jwt-go"
"github.com/goadesign/goa"
goajwt "github.com/goadesign/goa/middleware/security/jwt"
"github.com/golang-jwt/jwt/v4"
)

// Change to enable JWT for all api endpoints.
Expand All @@ -36,7 +35,7 @@ var (
// Note: the code below assumes the example is compiled against the master branch of goa.
// If compiling against goa v1 the call to jwt.New needs to be:
//
// middleware := jwt.New(keys, ForceFail(), app.NewJWTSecurity())
// middleware := jwt.New(keys, ForceFail(), app.NewJWTSecurity())
func NewJWTMiddleware() (goa.Middleware, error) {
// Skip all JWT if needed.
if !enableJWT {
Expand All @@ -53,7 +52,7 @@ func NewJWTMiddleware() (goa.Middleware, error) {
if len(keys) == 0 {
return nil, errors.New("no public keys found")
}
return goajwt.New(keys, nil, app.NewJWTSecurity()), nil
return NewJWT(keys, nil, app.NewJWTSecurity()), nil
}

// LoadJWTPublicKeys loads PEM encoded RSA public keys used to validate and decrypt the JWT.
Expand Down Expand Up @@ -82,13 +81,13 @@ func LoadJWTPublicKeys() ([]*rsa.PublicKey, error) {
}

func HasAccessToList(ctx context.Context, ids ...rankdb.ListID) error {
token := goajwt.ContextJWT(ctx)
token := ContextJWT(ctx)
if token == nil {
return nil
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return goajwt.ErrJWTError("unsupported claims shape")
return ErrJWTError("unsupported claims shape")
}
if val, ok := claims["only_lists"].(string); ok {
// Check all ids
Expand All @@ -102,23 +101,47 @@ func HasAccessToList(ctx context.Context, ids ...rankdb.ListID) error {
}
}
if !ok {
return goajwt.ErrJWTError("Access not granted to list " + string(id))
return ErrJWTError("Access not granted to list " + string(id))
}
}
}
return nil
}

type contextKey int

const (
jwtKey contextKey = iota + 1
)

// WithJWT creates a child context containing the given JWT.
func WithJWT(ctx context.Context, t *jwt.Token) context.Context {
return context.WithValue(ctx, jwtKey, t)
}

// ContextJWT retrieves the JWT token from a `context` that went through our security middleware.
func ContextJWT(ctx context.Context) *jwt.Token {
token, ok := ctx.Value(jwtKey).(*jwt.Token)
if !ok {
return nil
}
return token
}

// ErrJWTError is the error returned by this middleware when any sort of validation or assertion
// fails during processing.
var ErrJWTError = goa.NewErrorClass("jwt_security_error", 401)

// HasAccessToElement returns true
// Not optimized for big lists of IDs.
func HasAccessToElement(ctx context.Context, ids ...rankdb.ElementID) error {
token := goajwt.ContextJWT(ctx)
token := ContextJWT(ctx)
if token == nil {
return nil
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return goajwt.ErrJWTError("unsupported claims shape")
return ErrJWTError("unsupported claims shape")
}
if val, ok := claims["only_elements"].(string); ok {
lists := strings.Split(val, ",")
Expand All @@ -132,7 +155,7 @@ func HasAccessToElement(ctx context.Context, ids ...rankdb.ElementID) error {
}
}
if !ok {
return goajwt.ErrJWTError(fmt.Sprint("Access not granted to element ", id))
return ErrJWTError(fmt.Sprint("Access not granted to element ", id))
}
}
}
Expand Down
271 changes: 271 additions & 0 deletions api/jwt_middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
package api

import (
"context"
"crypto/ecdsa"
"crypto/rsa"
"fmt"
"net/http"
"sort"
"strings"

"github.com/goadesign/goa"
"github.com/golang-jwt/jwt/v4"
)

// partitionKeys sorts keys by their type.
func partitionKeys(k interface{}) ([]*rsa.PublicKey, []*ecdsa.PublicKey, [][]byte) {
var (
rsaKeys []*rsa.PublicKey
ecdsaKeys []*ecdsa.PublicKey
hmacKeys [][]byte
)

switch typed := k.(type) {
case []byte:
hmacKeys = append(hmacKeys, typed)
case [][]byte:
hmacKeys = typed
case string:
hmacKeys = append(hmacKeys, []byte(typed))
case []string:
for _, s := range typed {
hmacKeys = append(hmacKeys, []byte(s))
}
case *rsa.PublicKey:
rsaKeys = append(rsaKeys, typed)
case []*rsa.PublicKey:
rsaKeys = typed
case *ecdsa.PublicKey:
ecdsaKeys = append(ecdsaKeys, typed)
case []*ecdsa.PublicKey:
ecdsaKeys = typed
}

return rsaKeys, ecdsaKeys, hmacKeys
}

func extractTokenFromQueryParam(schemeName string, req *http.Request) (string, error) {
incomingToken := req.URL.Query().Get(schemeName)
if incomingToken == "" {
return "", ErrJWTError(fmt.Sprintf("missing parameter %q", schemeName))
}

return incomingToken, nil
}

func validateRSAKeys(rsaKeys []*rsa.PublicKey, algo, incomingToken string) (token *jwt.Token, err error) {
for _, pubkey := range rsaKeys {
token, err = jwt.Parse(incomingToken, func(token *jwt.Token) (interface{}, error) {
if !strings.HasPrefix(token.Method.Alg(), algo) {
return nil, ErrJWTError(fmt.Sprintf("Unexpected signing method: %v", token.Header["alg"]))
}
return pubkey, nil
})
if err == nil {
return
}
}
return
}

func validateECDSAKeys(ecdsaKeys []*ecdsa.PublicKey, algo, incomingToken string) (token *jwt.Token, err error) {
for _, pubkey := range ecdsaKeys {
token, err = jwt.Parse(incomingToken, func(token *jwt.Token) (interface{}, error) {
if !strings.HasPrefix(token.Method.Alg(), algo) {
return nil, ErrJWTError(fmt.Sprintf("Unexpected signing method: %v", token.Header["alg"]))
}
return pubkey, nil
})
if err == nil {
return
}
}
return
}

func validateHMACKeys(hmacKeys [][]byte, algo, incomingToken string) (token *jwt.Token, err error) {
for _, key := range hmacKeys {
token, err = jwt.Parse(incomingToken, func(token *jwt.Token) (interface{}, error) {
if !strings.HasPrefix(token.Method.Alg(), algo) {
return nil, ErrJWTError(fmt.Sprintf("Unexpected signing method: %v", token.Header["alg"]))
}
return key, nil
})
if err == nil {
return
}
}
return
}

// validScopeClaimKeys are the claims under which scopes may be found in a token
var validScopeClaimKeys = []string{"scope", "scopes"}

// parseClaimScopes parses the "scope" or "scopes" parameter in the Claims. It
// supports two formats:
//
// * a list of strings
//
// * a single string with space-separated scopes (akin to OAuth2's "scope").
//
// An empty string is an explicit claim of no scopes.
func parseClaimScopes(token *jwt.Token) (map[string]bool, []string, error) {
scopesInClaim := make(map[string]bool)
var scopesInClaimList []string
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, nil, fmt.Errorf("unsupport claims shape")
}
for _, k := range validScopeClaimKeys {
if rawscopes, ok := claims[k]; ok && rawscopes != nil {
switch scopes := rawscopes.(type) {
case string:
for _, scope := range strings.Split(scopes, " ") {
scopesInClaim[scope] = true
scopesInClaimList = append(scopesInClaimList, scope)
}
case []interface{}:
for _, scope := range scopes {
if val, ok := scope.(string); ok {
scopesInClaim[val] = true
scopesInClaimList = append(scopesInClaimList, val)
}
}
default:
return nil, nil, fmt.Errorf("unsupported scope format in incoming JWT claim, was type %T", scopes)
}
break
}
}
sort.Strings(scopesInClaimList)
return scopesInClaim, scopesInClaimList, nil
}

// NewJWT returns a middleware to be used with the JWTSecurity DSL definitions of goa. It supports the
// scopes claim in the JWT and ensures goa-defined Security DSLs are properly validated.
//
// The steps taken by the middleware are:
//
// 1. Extract the "Bearer" token from the Authorization header or query parameter
// 2. Validate the "Bearer" token against the key(s)
// given to NewJWT
// 3. If scopes are defined in the design for the action, validate them
// against the scopes presented by the JWT in the claim "scope", or if
// that's not defined, "scopes".
//
// The `exp` (expiration) and `nbf` (not before) date checks are validated by the JWT library.
//
// validationKeys can be one of these:
//
// - a string (for HMAC)
// - a []byte (for HMAC)
// - an rsa.PublicKey
// - an ecdsa.PublicKey
// - a slice of any of the above
//
// The type of the keys determine the algorithm that will be used to do the check. The goal of
// having lists of keys is to allow for key rotation, still check the previous keys until rotation
// has been completed.
//
// You can define an optional function to do additional validations on the token once the signature
// and the claims requirements are proven to be valid. Example:
//
// validationHandler, _ := goa.NewMiddleware(func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
// token := jwt.ContextJWT(ctx)
// if val, ok := token.Claims["is_uncle"].(string); !ok || val != "ben" {
// return jwt.ErrJWTError("you are not uncle ben's")
// }
// })
//
// Mount the middleware with the generated UseXX function where XX is the name of the scheme as
// defined in the design, e.g.:
//
// app.UseJWT(jwt.NewJWT("secret", validationHandler, app.NewJWTSecurity()))
func NewJWT(validationKeys interface{}, validationFunc goa.Middleware, scheme *goa.JWTSecurity) goa.Middleware {
var rsaKeys []*rsa.PublicKey
var hmacKeys [][]byte

rsaKeys, ecdsaKeys, hmacKeys := partitionKeys(validationKeys)

return func(nextHandler goa.Handler) goa.Handler {
return func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
var (
incomingToken string
err error
)

if scheme.In == goa.LocHeader {
if incomingToken, err = extractTokenFromHeader(scheme.Name, req); err != nil {
return err
}
} else if scheme.In == goa.LocQuery {
if incomingToken, err = extractTokenFromQueryParam(scheme.Name, req); err != nil {
return err
}
} else {
return fmt.Errorf("whoops, security scheme with location (in) %q not supported", scheme.In)
}

var (
token *jwt.Token
validated = false
)

if len(rsaKeys) > 0 {
token, err = validateRSAKeys(rsaKeys, "RS", incomingToken)
validated = err == nil
}

if !validated && len(ecdsaKeys) > 0 {
token, err = validateECDSAKeys(ecdsaKeys, "ES", incomingToken)
validated = err == nil
}

if !validated && len(hmacKeys) > 0 {
token, err = validateHMACKeys(hmacKeys, "HS", incomingToken)
//validated = err == nil
}

if err != nil {
return ErrJWTError(fmt.Sprintf("JWT validation failed: %s", err))
}

scopesInClaim, scopesInClaimList, err := parseClaimScopes(token)
if err != nil {
goa.LogError(ctx, err.Error())
return ErrJWTError(err)
}

requiredScopes := goa.ContextRequiredScopes(ctx)

for _, scope := range requiredScopes {
if !scopesInClaim[scope] {
msg := "authorization failed: required 'scope' or 'scopes' not present in JWT claim"
return ErrJWTError(msg, "required", requiredScopes, "scopes", scopesInClaimList)
}
}

ctx = WithJWT(ctx, token)
if validationFunc != nil {
nextHandler = validationFunc(nextHandler)
}
return nextHandler(ctx, rw, req)
}
}
}

func extractTokenFromHeader(schemeName string, req *http.Request) (string, error) {
val := req.Header.Get(schemeName)
if val == "" {
return "", ErrJWTError(fmt.Sprintf("missing header %q", schemeName))
}

if !strings.HasPrefix(strings.ToLower(val), "bearer ") {
return "", ErrJWTError(fmt.Sprintf("invalid or malformed %q header, expected 'Bearer JWT-token...'", val))
}

incomingToken := strings.Split(val, " ")[1]

return incomingToken, nil
}
2 changes: 1 addition & 1 deletion api/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ import (
"github.com/Vivino/rankdb/log"
"github.com/Vivino/rankdb/log/loggoa"
"github.com/Vivino/rankdb/log/testlogger"
jwtgo "github.com/dgrijalva/jwt-go"
"github.com/goadesign/goa"
goaclient "github.com/goadesign/goa/client"
goalogrus "github.com/goadesign/goa/logging/logrus"
jwtgo "github.com/golang-jwt/jwt/v4"
shutdown "github.com/klauspost/shutdown2"
"github.com/sirupsen/logrus"
)
Expand Down
Loading

0 comments on commit b3491e3

Please sign in to comment.