diff --git a/common.go b/common.go index 71c8ca6..78587c1 100644 --- a/common.go +++ b/common.go @@ -36,8 +36,8 @@ func errHandler(r *http.Request, w http.ResponseWriter, err error) { w.Write(respBody) } -type UserStore interface { - GetUser(ctx context.Context, address string) (any, bool, error) +type UserStore[T any] interface { + GetUser(ctx context.Context, address string) (user *T, isAdmin bool, err error) } // Config is a generic map of services/methods to a config value. diff --git a/middleware.go b/middleware.go index cb25efb..6479f91 100644 --- a/middleware.go +++ b/middleware.go @@ -11,7 +11,7 @@ import ( ) // Options for the authcontrol middleware handlers Session and AccessControl. -type Options struct { +type Options[T any] struct { // JWT secret used to verify the JWT token. JWTSecret string @@ -21,13 +21,13 @@ type Options struct { // UserStore is a function that is used to get the user from the request // with pluggable backends. - UserStore UserStore + UserStore UserStore[T] // ErrHandler is a function that is used to handle and respond to errors. ErrHandler ErrHandler } -func (o *Options) ApplyDefaults() { +func (o *Options[T]) ApplyDefaults() { // Set default access key functions if not provided. // We intentionally check for nil instead of len == 0 because // if you can pass an empty slice to have no access key defaults. @@ -41,7 +41,7 @@ func (o *Options) ApplyDefaults() { } } -func Session(cfg *Options) func(next http.Handler) http.Handler { +func Session[T any](cfg *Options[T]) func(next http.Handler) http.Handler { cfg.ApplyDefaults() auth := jwtauth.New("HS256", []byte(cfg.JWTSecret), nil) @@ -144,7 +144,7 @@ func Session(cfg *Options) func(next http.Handler) http.Handler { // AccessControl middleware that checks if the session type is allowed to access the endpoint. // It also sets the compute units on the context if the endpoint requires it. -func AccessControl(acl Config[ACL], cfg *Options) func(next http.Handler) http.Handler { +func AccessControl[T any](acl Config[ACL], cfg *Options[T]) func(next http.Handler) http.Handler { cfg.ApplyDefaults() return func(next http.Handler) http.Handler { diff --git a/middleware_test.go b/middleware_test.go index 2fc098e..209835e 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "net/http" - "path" "strings" "testing" "time" @@ -17,34 +16,21 @@ import ( "github.com/0xsequence/authcontrol/proto" ) -type mockStore map[string]bool +// JWTSecret is the secret used to sign the JWT token in the tests. +const JWTSecret = "secret" -var secret = "secret" +type User struct{} -func (m mockStore) GetUser(ctx context.Context, address string) (any, bool, error) { +// MockStore is a simple in-memory User store for testing, it stores the address and admin status. +type MockStore map[string]bool + +// GetUser returns the user and the admin status from the store. +func (m MockStore) GetUser(ctx context.Context, address string) (user *User, isAdmin bool, err error) { v, ok := m[address] if !ok { return nil, false, nil } - return struct{}{}, v, nil -} - -type testCase struct { - AccessKey string - Session proto.SessionType - Admin bool -} - -func (t testCase) String() string { - s := strings.Builder{} - s.WriteString(t.Session.String()) - if t.AccessKey != "" { - s.WriteString("/WithKey") - } - if t.Admin { - s.WriteString("/Admin") - } - return s.String() + return &User{}, v, nil } func TestSession(t *testing.T) { @@ -78,9 +64,9 @@ func TestSession(t *testing.T) { ServiceName = "serviceName" ) - options := &authcontrol.Options{ - JWTSecret: secret, - UserStore: mockStore{ + options := &authcontrol.Options[User]{ + JWTSecret: JWTSecret, + UserStore: MockStore{ UserAddress: false, AdminAddress: true, }, @@ -95,7 +81,11 @@ func TestSession(t *testing.T) { r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) ctx := context.Background() - testCases := []testCase{ + testCases := []struct { + AccessKey string + Session proto.SessionType + Admin bool + }{ {Session: proto.SessionType_Public}, {Session: proto.SessionType_Public, AccessKey: AccessKey}, {Session: proto.SessionType_Wallet}, @@ -113,7 +103,15 @@ func TestSession(t *testing.T) { for _, method := range Methods { types := ACLConfig[service][method] for _, tc := range testCases { - t.Run(path.Join(method, tc.String()), func(t *testing.T) { + s := strings.Builder{} + fmt.Fprintf(&s, "%s/%s", method, tc.Session) + if tc.AccessKey != "" { + s.WriteString("+AccessKey") + } + if tc.Admin { + s.WriteString("+Admin") + } + t.Run(s.String(), func(t *testing.T) { var claims map[string]any switch tc.Session { case proto.SessionType_Wallet: @@ -132,7 +130,7 @@ func TestSession(t *testing.T) { claims = map[string]any{"service": ServiceName} } - ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", service, method), tc.AccessKey, authcontrol.S2SToken(secret, claims)) + ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", service, method), tc.AccessKey, authcontrol.S2SToken(JWTSecret, claims)) session := tc.Session switch { @@ -178,9 +176,9 @@ func TestInvalid(t *testing.T) { AdminAddress = "adminAddress" ) - options := &authcontrol.Options{ - JWTSecret: secret, - UserStore: mockStore{ + options := &authcontrol.Options[User]{ + JWTSecret: JWTSecret, + UserStore: MockStore{ UserAddress: false, AdminAddress: true, }, @@ -216,33 +214,33 @@ func TestInvalid(t *testing.T) { claims := map[string]any{"service": "client_service"} // Valid Request - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(secret, claims)) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims)) assert.True(t, ok) assert.NoError(t, err) // Invalid request path with wrong not enough parts in path for valid RPC request - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(secret, claims)) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims)) assert.False(t, ok) assert.ErrorIs(t, err, proto.ErrUnauthorized) // Invalid request path with wrong "rpc" - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/pcr/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(secret, claims)) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/pcr/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims)) assert.False(t, ok) assert.ErrorIs(t, err, proto.ErrUnauthorized) // Invalid Service - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, authcontrol.S2SToken(secret, claims)) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims)) assert.False(t, ok) assert.ErrorIs(t, err, proto.ErrUnauthorized) // Invalid Method - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), AccessKey, authcontrol.S2SToken(secret, claims)) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), AccessKey, authcontrol.S2SToken(JWTSecret, claims)) assert.False(t, ok) assert.ErrorIs(t, err, proto.ErrUnauthorized) // Expired JWT Token claims["exp"] = time.Now().Add(-time.Second).Unix() - expiredJWT := authcontrol.S2SToken(secret, claims) + expiredJWT := authcontrol.S2SToken(JWTSecret, claims) // Expired JWT Token valid method ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, expiredJWT) @@ -288,9 +286,9 @@ func TestCustomErrHandler(t *testing.T) { HTTPStatus: 400, } - opts := &authcontrol.Options{ - JWTSecret: secret, - UserStore: mockStore{ + opts := &authcontrol.Options[User]{ + JWTSecret: JWTSecret, + UserStore: MockStore{ UserAddress: false, AdminAddress: true, }, @@ -317,12 +315,12 @@ func TestCustomErrHandler(t *testing.T) { claims = map[string]any{"service": "client_service"} // Valid Request - ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(secret, claims)) + ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims)) assert.True(t, ok) assert.NoError(t, err) // Invalid service which should return custom error from overrided ErrHandler - ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, authcontrol.S2SToken(secret, claims)) + ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims)) assert.False(t, ok) assert.ErrorIs(t, err, customErr) }