diff --git a/arbitrum/backend.go b/arbitrum/backend.go index 3e78d8d124..b1bb9c89de 100644 --- a/arbitrum/backend.go +++ b/arbitrum/backend.go @@ -51,6 +51,14 @@ func NewBackend(stack *node.Node, config *Config, chainDb ethdb.Database, publis chanNewBlock: make(chan struct{}, 1), } + if len(config.AllowMethod) > 0 { + rpcFilter := make(map[string]bool) + for _, method := range config.AllowMethod { + rpcFilter[method] = true + } + backend.stack.ApplyAPIFilter(rpcFilter) + } + backend.bloomIndexer.Start(backend.arb.BlockChain()) filterSystem, err := createRegisterAPIBackend(backend, filterConfig, config.ClassicRedirect, config.ClassicRedirectTimeout) if err != nil { diff --git a/arbitrum/config.go b/arbitrum/config.go index fbae58588e..308cf9aee1 100644 --- a/arbitrum/config.go +++ b/arbitrum/config.go @@ -37,6 +37,8 @@ type Config struct { ClassicRedirect string `koanf:"classic-redirect"` ClassicRedirectTimeout time.Duration `koanf:"classic-redirect-timeout"` MaxRecreateStateDepth int64 `koanf:"max-recreate-state-depth"` + + AllowMethod []string `koanf:"allow-method"` } type ArbDebugConfig struct { @@ -57,6 +59,7 @@ func ConfigAddOptions(prefix string, f *flag.FlagSet) { f.Int(prefix+".filter-log-cache-size", DefaultConfig.FilterLogCacheSize, "log filter system maximum number of cached blocks") f.Duration(prefix+".filter-timeout", DefaultConfig.FilterTimeout, "log filter system maximum time filters stay active") f.Int64(prefix+".max-recreate-state-depth", DefaultConfig.MaxRecreateStateDepth, "maximum depth for recreating state, measured in l2 gas (0=don't recreate state, -1=infinite, -2=use default value for archive or non-archive node (whichever is configured))") + f.StringSlice(prefix+".allow-method", DefaultConfig.AllowMethod, "list of whitelisted rpc methods") arbDebug := DefaultConfig.ArbDebug f.Uint64(prefix+".arbdebug.block-range-bound", arbDebug.BlockRangeBound, "bounds the number of blocks arbdebug calls may return") f.Uint64(prefix+".arbdebug.timeout-queue-bound", arbDebug.TimeoutQueueBound, "bounds the length of timeout queues arbdebug calls may return") @@ -81,6 +84,7 @@ var DefaultConfig = Config{ FeeHistoryMaxBlockCount: 1024, ClassicRedirect: "", MaxRecreateStateDepth: UninitializedMaxRecreateStateDepth, // default value should be set for depending on node type (archive / non-archive) + AllowMethod: []string{}, ArbDebug: ArbDebugConfig{ BlockRangeBound: 256, TimeoutQueueBound: 512, diff --git a/node/node.go b/node/node.go index c1530b6371..1cd0388ccd 100644 --- a/node/node.go +++ b/node/node.go @@ -69,6 +69,8 @@ type Node struct { inprocHandler *rpc.Server // In-process RPC request handler to process the API requests databases map[*closeTrackingDB]struct{} // All open databases + + apiFilter map[string]bool // Whitelisting API methods } const ( @@ -378,6 +380,11 @@ func (n *Node) obtainJWTSecret(cliParam string) ([]byte, error) { return jwtSecret, nil } +// ApplyAPIFilter is the first step in whitelisting given rpc methods inside apiFilter +func (n *Node) ApplyAPIFilter(apiFilter map[string]bool) { + n.apiFilter = apiFilter +} + // startRPC is a helper method to configure all the various RPC endpoints during node // startup. It's not meant to be called at any time afterwards as it makes certain // assumptions about the state of the node. @@ -418,6 +425,7 @@ func (n *Node) startRPC() error { Vhosts: n.config.HTTPVirtualHosts, Modules: n.config.HTTPModules, prefix: n.config.HTTPPathPrefix, + apiFilter: n.apiFilter, }); err != nil { return err } @@ -431,9 +439,10 @@ func (n *Node) startRPC() error { return err } if err := server.enableWS(openAPIs, wsConfig{ - Modules: n.config.WSModules, - Origins: n.config.WSOrigins, - prefix: n.config.WSPathPrefix, + Modules: n.config.WSModules, + Origins: n.config.WSOrigins, + prefix: n.config.WSPathPrefix, + apiFilter: n.apiFilter, }); err != nil { return err } @@ -453,6 +462,7 @@ func (n *Node) startRPC() error { Modules: n.config.AuthModules, prefix: DefaultAuthPrefix, jwtSecret: secret, + apiFilter: n.apiFilter, }); err != nil { return err } @@ -467,6 +477,7 @@ func (n *Node) startRPC() error { Origins: n.config.AuthOrigins, prefix: DefaultAuthPrefix, jwtSecret: secret, + apiFilter: n.apiFilter, }); err != nil { return err } diff --git a/node/rpcstack.go b/node/rpcstack.go index 16d51b3371..513179ea78 100644 --- a/node/rpcstack.go +++ b/node/rpcstack.go @@ -42,6 +42,7 @@ type httpConfig struct { Vhosts []string prefix string // path prefix on which to mount http handler jwtSecret []byte // optional JWT secret + apiFilter map[string]bool } // wsConfig is the JSON-RPC/Websocket configuration @@ -50,6 +51,7 @@ type wsConfig struct { Modules []string prefix string // path prefix on which to mount ws handler jwtSecret []byte // optional JWT secret + apiFilter map[string]bool } type rpcHandler struct { @@ -302,6 +304,7 @@ func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig) error { // Create RPC server and handler. srv := rpc.NewServer() + srv.ApplyAPIFilter(config.apiFilter) if err := RegisterApis(apis, config.Modules, srv); err != nil { return err } @@ -339,6 +342,7 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error { } // Create RPC server and handler. srv := rpc.NewServer() + srv.ApplyAPIFilter(config.apiFilter) if err := RegisterApis(apis, config.Modules, srv); err != nil { return err } diff --git a/rpc/server.go b/rpc/server.go index 089bbb1fd5..ff6534d284 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -65,6 +65,10 @@ func NewServer() *Server { return server } +func (s *Server) ApplyAPIFilter(apiFilter map[string]bool) { + s.services.apiFilter = apiFilter +} + // RegisterName creates a service for the given receiver type under the given name. When no // methods on the given receiver match the criteria to be either a RPC method or a // subscription an error is returned. Otherwise a new service is created and added to the diff --git a/rpc/service.go b/rpc/service.go index 8485cab3aa..7821d5fc2a 100644 --- a/rpc/service.go +++ b/rpc/service.go @@ -38,6 +38,8 @@ var ( type serviceRegistry struct { mu sync.Mutex services map[string]service + + apiFilter map[string]bool } // service represents a registered object. @@ -81,11 +83,17 @@ func (r *serviceRegistry) registerName(name string, rcvr interface{}) error { } r.services[name] = svc } - for name, cb := range callbacks { + for methodName, cb := range callbacks { + if r.apiFilter != nil { + key := name + "_" + methodName + if _, ok := r.apiFilter[key]; !ok { + continue + } + } if cb.isSubscribe { - svc.subscriptions[name] = cb + svc.subscriptions[methodName] = cb } else { - svc.callbacks[name] = cb + svc.callbacks[methodName] = cb } } return nil