From 1ce92bf7a6df0c322250a9d2df455d53d41bf541 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20-=20=E3=82=A2=E3=83=AC=E3=83=83=E3=82=AF=E3=82=B9?= Date: Mon, 4 Nov 2024 09:25:57 +0100 Subject: [PATCH] origin check (#18) * origin check * move var in if * move to the top * minor refactor --- common_test.go | 30 ++++++++++++++++----- middleware.go | 10 +++++++ middleware_test.go | 66 +++++++++++++++++++++++++++++++++++++--------- 3 files changed, 86 insertions(+), 20 deletions(-) diff --git a/common_test.go b/common_test.go index 3b3054a..9c09257 100644 --- a/common_test.go +++ b/common_test.go @@ -21,17 +21,33 @@ func keyFunc(r *http.Request) string { return r.Header.Get(HeaderKey) } -func executeRequest(t *testing.T, ctx context.Context, handler http.Handler, path, accessKey string, jwt string) (bool, error) { +type requestOption func(r *http.Request) + +func accessKey(v string) requestOption { + return func(r *http.Request) { + r.Header.Set(HeaderKey, v) + } +} + +func jwt(v string) requestOption { + return func(r *http.Request) { + r.Header.Set("Authorization", "Bearer "+v) + } +} + +func origin(v string) requestOption { + return func(r *http.Request) { + r.Header.Set("Origin", v) + } +} + +func executeRequest(t *testing.T, ctx context.Context, handler http.Handler, path string, options ...requestOption) (bool, error) { req, err := http.NewRequest("POST", path, nil) require.NoError(t, err) req.Header.Set("X-Real-IP", "127.0.0.1") - if accessKey != "" { - req.Header.Set(HeaderKey, accessKey) - } - - if jwt != "" { - req.Header.Set("Authorization", "Bearer "+jwt) + for _, opt := range options { + opt(req) } rr := httptest.NewRecorder() diff --git a/middleware.go b/middleware.go index 1937e3a..a7486c6 100644 --- a/middleware.go +++ b/middleware.go @@ -3,6 +3,7 @@ package authcontrol import ( "errors" "net/http" + "strings" "github.com/go-chi/jwtauth/v5" "github.com/lestrrat-go/jwx/v2/jwt" @@ -87,6 +88,15 @@ func Session[T any](cfg Options[T]) func(next http.Handler) http.Handler { return } + if originClaim, _ := claims["ogn"].(string); originClaim != "" { + originClaim = strings.TrimSuffix(originClaim, "/") + originHeader := strings.TrimSuffix(r.Header.Get("Origin"), "/") + if originHeader != "" && originHeader != originClaim { + cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("invalid origin claim")) + return + } + } + serviceClaim, _ := claims["service"].(string) accountClaim, _ := claims["account"].(string) adminClaim, _ := claims["admin"].(bool) diff --git a/middleware_test.go b/middleware_test.go index 40e9cf4..5ef937d 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -130,7 +130,15 @@ func TestSession(t *testing.T) { claims = map[string]any{"service": ServiceName} } - ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", service, method), tc.AccessKey, authcontrol.S2SToken(JWTSecret, claims)) + var options []requestOption + if tc.AccessKey != "" { + options = append(options, accessKey(tc.AccessKey)) + } + if claims != nil { + options = append(options, jwt(authcontrol.S2SToken(JWTSecret, claims))) + } + + ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", service, method), options...) session := tc.Session switch { @@ -202,39 +210,39 @@ func TestInvalid(t *testing.T) { })) // Without JWT - ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, "") + ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt("")) assert.True(t, ok) assert.NoError(t, err) // Wrong JWT - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, "wrong-secret") + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt("wrong-secret")) assert.False(t, ok) assert.ErrorIs(t, err, proto.ErrUnauthorized) claims := map[string]any{"service": "client_service"} // Valid Request - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims)) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims))) assert.True(t, ok) assert.NoError(t, err) // Invalid request path with wrong not enough parts in path for valid RPC request - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims)) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims))) assert.False(t, ok) assert.ErrorIs(t, err, proto.ErrUnauthorized) // Invalid request path with wrong "rpc" - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/pcr/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims)) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/pcr/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims))) assert.False(t, ok) assert.ErrorIs(t, err, proto.ErrUnauthorized) // Invalid Service - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims)) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims))) assert.False(t, ok) assert.ErrorIs(t, err, proto.ErrUnauthorized) // Invalid Method - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), AccessKey, authcontrol.S2SToken(JWTSecret, claims)) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims))) assert.False(t, ok) assert.ErrorIs(t, err, proto.ErrUnauthorized) @@ -243,17 +251,17 @@ func TestInvalid(t *testing.T) { expiredJWT := authcontrol.S2SToken(JWTSecret, claims) // Expired JWT Token valid method - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, expiredJWT) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(expiredJWT)) assert.False(t, ok) assert.ErrorIs(t, err, proto.ErrSessionExpired) // Expired JWT Token invalid service - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, expiredJWT) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), accessKey(AccessKey), jwt(expiredJWT)) assert.False(t, ok) assert.ErrorIs(t, err, proto.ErrSessionExpired) // Expired JWT Token invalid method - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), AccessKey, expiredJWT) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), accessKey(AccessKey), jwt(expiredJWT)) assert.False(t, ok) assert.ErrorIs(t, err, proto.ErrSessionExpired) } @@ -315,12 +323,44 @@ func TestCustomErrHandler(t *testing.T) { claims = map[string]any{"service": "client_service"} // Valid Request - ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims)) + ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims))) assert.True(t, ok) assert.NoError(t, err) // Invalid service which should return custom error from overrided ErrHandler - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims)) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims))) assert.False(t, ok) assert.ErrorIs(t, err, customErr) } + +func TestOrigin(t *testing.T) { + ctx := context.Background() + + opts := authcontrol.Options[any]{ + JWTSecret: JWTSecret, + } + + r := chi.NewRouter() + r.Use(authcontrol.Session(opts)) + r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + + token := authcontrol.S2SToken(JWTSecret, map[string]any{ + "user": "123", + "ogn": "http://localhost", + }) + + // No Origin header + ok, err := executeRequest(t, ctx, r, "", jwt(token)) + assert.True(t, ok) + assert.NoError(t, err) + + // Valid Origin header + ok, err = executeRequest(t, ctx, r, "", jwt(token), origin("http://localhost")) + assert.True(t, ok) + assert.NoError(t, err) + + // Invalid Origin header + ok, err = executeRequest(t, ctx, r, "", jwt(token), origin("http://evil.com")) + assert.False(t, ok) + assert.ErrorIs(t, err, proto.ErrUnauthorized) +}