Skip to content

Commit

Permalink
origin check (#18)
Browse files Browse the repository at this point in the history
* origin check

* move var in if

* move to the top

* minor refactor
  • Loading branch information
klaidliadon authored Nov 4, 2024
1 parent 029fab0 commit 1ce92bf
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 20 deletions.
30 changes: 23 additions & 7 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,33 @@ func keyFunc(r *http.Request) string {
return r.Header.Get(HeaderKey)
}

func executeRequest(t *testing.T, ctx context.Context, handler http.Handler, path, accessKey string, jwt string) (bool, error) {
type requestOption func(r *http.Request)

func accessKey(v string) requestOption {
return func(r *http.Request) {
r.Header.Set(HeaderKey, v)
}
}

func jwt(v string) requestOption {
return func(r *http.Request) {
r.Header.Set("Authorization", "Bearer "+v)
}
}

func origin(v string) requestOption {
return func(r *http.Request) {
r.Header.Set("Origin", v)
}
}

func executeRequest(t *testing.T, ctx context.Context, handler http.Handler, path string, options ...requestOption) (bool, error) {
req, err := http.NewRequest("POST", path, nil)
require.NoError(t, err)

req.Header.Set("X-Real-IP", "127.0.0.1")
if accessKey != "" {
req.Header.Set(HeaderKey, accessKey)
}

if jwt != "" {
req.Header.Set("Authorization", "Bearer "+jwt)
for _, opt := range options {
opt(req)
}

rr := httptest.NewRecorder()
Expand Down
10 changes: 10 additions & 0 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package authcontrol
import (
"errors"
"net/http"
"strings"

"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/v2/jwt"
Expand Down Expand Up @@ -87,6 +88,15 @@ func Session[T any](cfg Options[T]) func(next http.Handler) http.Handler {
return
}

if originClaim, _ := claims["ogn"].(string); originClaim != "" {
originClaim = strings.TrimSuffix(originClaim, "/")
originHeader := strings.TrimSuffix(r.Header.Get("Origin"), "/")
if originHeader != "" && originHeader != originClaim {
cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("invalid origin claim"))
return
}
}

serviceClaim, _ := claims["service"].(string)
accountClaim, _ := claims["account"].(string)
adminClaim, _ := claims["admin"].(bool)
Expand Down
66 changes: 53 additions & 13 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,15 @@ func TestSession(t *testing.T) {
claims = map[string]any{"service": ServiceName}
}

ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", service, method), tc.AccessKey, authcontrol.S2SToken(JWTSecret, claims))
var options []requestOption
if tc.AccessKey != "" {
options = append(options, accessKey(tc.AccessKey))
}
if claims != nil {
options = append(options, jwt(authcontrol.S2SToken(JWTSecret, claims)))
}

ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", service, method), options...)

session := tc.Session
switch {
Expand Down Expand Up @@ -202,39 +210,39 @@ func TestInvalid(t *testing.T) {
}))

// Without JWT
ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, "")
ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(""))
assert.True(t, ok)
assert.NoError(t, err)

// Wrong JWT
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, "wrong-secret")
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt("wrong-secret"))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrUnauthorized)

claims := map[string]any{"service": "client_service"}

// Valid Request
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims))
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
assert.True(t, ok)
assert.NoError(t, err)

// Invalid request path with wrong not enough parts in path for valid RPC request
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims))
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrUnauthorized)

// Invalid request path with wrong "rpc"
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/pcr/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims))
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/pcr/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrUnauthorized)

// Invalid Service
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims))
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrUnauthorized)

// Invalid Method
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), AccessKey, authcontrol.S2SToken(JWTSecret, claims))
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrUnauthorized)

Expand All @@ -243,17 +251,17 @@ func TestInvalid(t *testing.T) {
expiredJWT := authcontrol.S2SToken(JWTSecret, claims)

// Expired JWT Token valid method
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, expiredJWT)
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(expiredJWT))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrSessionExpired)

// Expired JWT Token invalid service
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, expiredJWT)
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), accessKey(AccessKey), jwt(expiredJWT))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrSessionExpired)

// Expired JWT Token invalid method
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), AccessKey, expiredJWT)
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), accessKey(AccessKey), jwt(expiredJWT))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrSessionExpired)
}
Expand Down Expand Up @@ -315,12 +323,44 @@ func TestCustomErrHandler(t *testing.T) {
claims = map[string]any{"service": "client_service"}

// Valid Request
ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims))
ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
assert.True(t, ok)
assert.NoError(t, err)

// Invalid service which should return custom error from overrided ErrHandler
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims))
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
assert.False(t, ok)
assert.ErrorIs(t, err, customErr)
}

func TestOrigin(t *testing.T) {
ctx := context.Background()

opts := authcontrol.Options[any]{
JWTSecret: JWTSecret,
}

r := chi.NewRouter()
r.Use(authcontrol.Session(opts))
r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))

token := authcontrol.S2SToken(JWTSecret, map[string]any{
"user": "123",
"ogn": "http://localhost",
})

// No Origin header
ok, err := executeRequest(t, ctx, r, "", jwt(token))
assert.True(t, ok)
assert.NoError(t, err)

// Valid Origin header
ok, err = executeRequest(t, ctx, r, "", jwt(token), origin("http://localhost"))
assert.True(t, ok)
assert.NoError(t, err)

// Invalid Origin header
ok, err = executeRequest(t, ctx, r, "", jwt(token), origin("http://evil.com"))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrUnauthorized)
}

0 comments on commit 1ce92bf

Please sign in to comment.