Skip to content

Commit

Permalink
add generics to UserStore (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
klaidliadon authored Oct 30, 2024
1 parent c5f92fb commit 3fd30b9
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 50 deletions.
4 changes: 2 additions & 2 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ func errHandler(r *http.Request, w http.ResponseWriter, err error) {
w.Write(respBody)
}

type UserStore interface {
GetUser(ctx context.Context, address string) (any, bool, error)
type UserStore[T any] interface {
GetUser(ctx context.Context, address string) (user *T, isAdmin bool, err error)
}

// Config is a generic map of services/methods to a config value.
Expand Down
10 changes: 5 additions & 5 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

// Options for the authcontrol middleware handlers Session and AccessControl.
type Options struct {
type Options[T any] struct {
// JWT secret used to verify the JWT token.
JWTSecret string

Expand All @@ -21,13 +21,13 @@ type Options struct {

// UserStore is a function that is used to get the user from the request
// with pluggable backends.
UserStore UserStore
UserStore UserStore[T]

// ErrHandler is a function that is used to handle and respond to errors.
ErrHandler ErrHandler
}

func (o *Options) ApplyDefaults() {
func (o *Options[T]) ApplyDefaults() {
// Set default access key functions if not provided.
// We intentionally check for nil instead of len == 0 because
// if you can pass an empty slice to have no access key defaults.
Expand All @@ -41,7 +41,7 @@ func (o *Options) ApplyDefaults() {
}
}

func Session(cfg *Options) func(next http.Handler) http.Handler {
func Session[T any](cfg *Options[T]) func(next http.Handler) http.Handler {
cfg.ApplyDefaults()
auth := jwtauth.New("HS256", []byte(cfg.JWTSecret), nil)

Expand Down Expand Up @@ -144,7 +144,7 @@ func Session(cfg *Options) func(next http.Handler) http.Handler {

// AccessControl middleware that checks if the session type is allowed to access the endpoint.
// It also sets the compute units on the context if the endpoint requires it.
func AccessControl(acl Config[ACL], cfg *Options) func(next http.Handler) http.Handler {
func AccessControl[T any](acl Config[ACL], cfg *Options[T]) func(next http.Handler) http.Handler {
cfg.ApplyDefaults()

return func(next http.Handler) http.Handler {
Expand Down
84 changes: 41 additions & 43 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"encoding/json"
"fmt"
"net/http"
"path"
"strings"
"testing"
"time"
Expand All @@ -17,34 +16,21 @@ import (
"github.com/0xsequence/authcontrol/proto"
)

type mockStore map[string]bool
// JWTSecret is the secret used to sign the JWT token in the tests.
const JWTSecret = "secret"

var secret = "secret"
type User struct{}

func (m mockStore) GetUser(ctx context.Context, address string) (any, bool, error) {
// MockStore is a simple in-memory User store for testing, it stores the address and admin status.
type MockStore map[string]bool

// GetUser returns the user and the admin status from the store.
func (m MockStore) GetUser(ctx context.Context, address string) (user *User, isAdmin bool, err error) {
v, ok := m[address]
if !ok {
return nil, false, nil
}
return struct{}{}, v, nil
}

type testCase struct {
AccessKey string
Session proto.SessionType
Admin bool
}

func (t testCase) String() string {
s := strings.Builder{}
s.WriteString(t.Session.String())
if t.AccessKey != "" {
s.WriteString("/WithKey")
}
if t.Admin {
s.WriteString("/Admin")
}
return s.String()
return &User{}, v, nil
}

func TestSession(t *testing.T) {
Expand Down Expand Up @@ -78,9 +64,9 @@ func TestSession(t *testing.T) {
ServiceName = "serviceName"
)

options := &authcontrol.Options{
JWTSecret: secret,
UserStore: mockStore{
options := &authcontrol.Options[User]{
JWTSecret: JWTSecret,
UserStore: MockStore{
UserAddress: false,
AdminAddress: true,
},
Expand All @@ -95,7 +81,11 @@ func TestSession(t *testing.T) {
r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))

ctx := context.Background()
testCases := []testCase{
testCases := []struct {
AccessKey string
Session proto.SessionType
Admin bool
}{
{Session: proto.SessionType_Public},
{Session: proto.SessionType_Public, AccessKey: AccessKey},
{Session: proto.SessionType_Wallet},
Expand All @@ -113,7 +103,15 @@ func TestSession(t *testing.T) {
for _, method := range Methods {
types := ACLConfig[service][method]
for _, tc := range testCases {
t.Run(path.Join(method, tc.String()), func(t *testing.T) {
s := strings.Builder{}
fmt.Fprintf(&s, "%s/%s", method, tc.Session)
if tc.AccessKey != "" {
s.WriteString("+AccessKey")
}
if tc.Admin {
s.WriteString("+Admin")
}
t.Run(s.String(), func(t *testing.T) {
var claims map[string]any
switch tc.Session {
case proto.SessionType_Wallet:
Expand All @@ -132,7 +130,7 @@ func TestSession(t *testing.T) {
claims = map[string]any{"service": ServiceName}
}

ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", service, method), tc.AccessKey, authcontrol.S2SToken(secret, claims))
ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", service, method), tc.AccessKey, authcontrol.S2SToken(JWTSecret, claims))

session := tc.Session
switch {
Expand Down Expand Up @@ -178,9 +176,9 @@ func TestInvalid(t *testing.T) {
AdminAddress = "adminAddress"
)

options := &authcontrol.Options{
JWTSecret: secret,
UserStore: mockStore{
options := &authcontrol.Options[User]{
JWTSecret: JWTSecret,
UserStore: MockStore{
UserAddress: false,
AdminAddress: true,
},
Expand Down Expand Up @@ -216,33 +214,33 @@ func TestInvalid(t *testing.T) {
claims := map[string]any{"service": "client_service"}

// Valid Request
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(secret, claims))
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims))
assert.True(t, ok)
assert.NoError(t, err)

// Invalid request path with wrong not enough parts in path for valid RPC request
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(secret, claims))
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrUnauthorized)

// Invalid request path with wrong "rpc"
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/pcr/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(secret, claims))
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/pcr/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrUnauthorized)

// Invalid Service
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, authcontrol.S2SToken(secret, claims))
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrUnauthorized)

// Invalid Method
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), AccessKey, authcontrol.S2SToken(secret, claims))
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), AccessKey, authcontrol.S2SToken(JWTSecret, claims))
assert.False(t, ok)
assert.ErrorIs(t, err, proto.ErrUnauthorized)

// Expired JWT Token
claims["exp"] = time.Now().Add(-time.Second).Unix()
expiredJWT := authcontrol.S2SToken(secret, claims)
expiredJWT := authcontrol.S2SToken(JWTSecret, claims)

// Expired JWT Token valid method
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, expiredJWT)
Expand Down Expand Up @@ -288,9 +286,9 @@ func TestCustomErrHandler(t *testing.T) {
HTTPStatus: 400,
}

opts := &authcontrol.Options{
JWTSecret: secret,
UserStore: mockStore{
opts := &authcontrol.Options[User]{
JWTSecret: JWTSecret,
UserStore: MockStore{
UserAddress: false,
AdminAddress: true,
},
Expand All @@ -317,12 +315,12 @@ func TestCustomErrHandler(t *testing.T) {
claims = map[string]any{"service": "client_service"}

// Valid Request
ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(secret, claims))
ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims))
assert.True(t, ok)
assert.NoError(t, err)

// Invalid service which should return custom error from overrided ErrHandler
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, authcontrol.S2SToken(secret, claims))
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), AccessKey, authcontrol.S2SToken(JWTSecret, claims))
assert.False(t, ok)
assert.ErrorIs(t, err, customErr)
}

0 comments on commit 3fd30b9

Please sign in to comment.