diff --git a/api/admin.go b/api/admin.go index afd299cfa..7a4ddb3e7 100644 --- a/api/admin.go +++ b/api/admin.go @@ -6,14 +6,82 @@ package api import ( + "fmt" "log/slog" + "net" "net/http" + "sync/atomic" + "time" + "github.com/gorilla/handlers" + "github.com/gorilla/mux" "github.com/pkg/errors" "github.com/vechain/thor/v2/api/utils" + "github.com/vechain/thor/v2/co" "github.com/vechain/thor/v2/log" ) +type Admin struct { + address string + logLevel *slog.LevelVar + logRequests *atomic.Bool +} + +func NewAdmin(addr string, logLevel *slog.LevelVar, logRequests *atomic.Bool) *Admin { + return &Admin{ + address: addr, + logLevel: logLevel, + logRequests: logRequests, + } +} + +// Start the admin server. +func (a *Admin) Start() (string, func(), error) { + listener, err := net.Listen("tcp", a.address) + if err != nil { + return "", nil, errors.Wrapf(err, "listen admin API addr [%v]", a.address) + } + + router := mux.NewRouter() + handler := handlers.CompressHandler(router) + sub := router.PathPrefix("/admin").Subrouter() + + // GET /admin/loglevel + sub.Path("/loglevel"). + Methods(http.MethodGet). + Name("get-log-level"). + HandlerFunc(utils.WrapHandlerFunc(a.getLogLevelHandler)) + // POST /admin/loglevel + sub.Path("/loglevel"). + Methods(http.MethodPost). + Name("post-log-level"). + HandlerFunc(utils.WrapHandlerFunc(a.postLogLevelHandler)) + + // GET /admin/apilogs + sub.Path("/apilogs"). + Methods(http.MethodGet). + Name("get-api-logs-enabled"). + Handler(utils.WrapHandlerFunc(a.getRequestLoggerEnabled)) + // POST /admin/apilogs + sub.Path("/apilogs"). + Methods(http.MethodPost). + Name("post-api-logs-enabled"). + Handler(utils.WrapHandlerFunc(a.postRequestLogger)) + + server := &http.Server{Handler: handler, ReadHeaderTimeout: time.Second, ReadTimeout: 5 * time.Second} + var goes co.Goes + goes.Go(func() { + server.Serve(listener) + }) + + cancel := func() { + server.Close() + goes.Wait() + } + + return "http://" + listener.Addr().String() + "/admin", cancel, nil +} + type logLevelRequest struct { Level string `json:"level"` } @@ -22,41 +90,69 @@ type logLevelResponse struct { CurrentLevel string `json:"currentLevel"` } -func getLogLevelHandler(logLevel *slog.LevelVar) utils.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) error { - return utils.WriteJSON(w, logLevelResponse{ - CurrentLevel: logLevel.Level().String(), - }) +func (a *Admin) getLogLevelHandler(w http.ResponseWriter, r *http.Request) error { + return utils.WriteJSON(w, logLevelResponse{ + CurrentLevel: a.logLevel.Level().String(), + }) +} + +func (a *Admin) postLogLevelHandler(w http.ResponseWriter, r *http.Request) error { + var req logLevelRequest + + if err := utils.ParseJSON(r.Body, &req); err != nil { + return utils.BadRequest(errors.WithMessage(err, "invalid request body")) + } + + switch req.Level { + case "debug": + a.logLevel.Set(log.LevelDebug) + case "info": + a.logLevel.Set(log.LevelInfo) + case "warn": + a.logLevel.Set(log.LevelWarn) + case "error": + a.logLevel.Set(log.LevelError) + case "trace": + a.logLevel.Set(log.LevelTrace) + case "crit": + a.logLevel.Set(log.LevelCrit) + default: + return utils.BadRequest(fmt.Errorf("invalid verbosity level: %s", req.Level)) } + + log.Warn("admin changed the log level", "level", log.LevelString(a.logLevel.Level())) + + return utils.WriteJSON(w, logLevelResponse{ + CurrentLevel: a.logLevel.Level().String(), + }) } -func postLogLevelHandler(logLevel *slog.LevelVar) utils.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) error { - var req logLevelRequest +type apiLogRequests struct { + Enabled *bool `json:"enabled"` +} - if err := utils.ParseJSON(r.Body, &req); err != nil { - return utils.BadRequest(errors.WithMessage(err, "Invalid request body")) - } +func (a *Admin) getRequestLoggerEnabled(w http.ResponseWriter, r *http.Request) error { + enabled := a.logRequests.Load() + res := apiLogRequests{ + Enabled: &enabled, + } + return utils.WriteJSON(w, res) +} - switch req.Level { - case "debug": - logLevel.Set(log.LevelDebug) - case "info": - logLevel.Set(log.LevelInfo) - case "warn": - logLevel.Set(log.LevelWarn) - case "error": - logLevel.Set(log.LevelError) - case "trace": - logLevel.Set(log.LevelTrace) - case "crit": - logLevel.Set(log.LevelCrit) - default: - return utils.BadRequest(errors.New("Invalid verbosity level")) - } +func (a *Admin) postRequestLogger(w http.ResponseWriter, r *http.Request) error { + var req apiLogRequests + + if err := utils.ParseJSON(r.Body, &req); err != nil { + return utils.BadRequest(errors.WithMessage(err, "invalid request body")) + } - return utils.WriteJSON(w, logLevelResponse{ - CurrentLevel: logLevel.Level().String(), - }) + if req.Enabled == nil { + return utils.BadRequest(errors.New("missing 'enabled' field")) } + + log.Warn("admin changed the request logger", "enabled", *req.Enabled) + + a.logRequests.Store(*req.Enabled) + + return utils.WriteJSON(w, req) } diff --git a/api/admin_server.go b/api/admin_server.go deleted file mode 100644 index 26054e908..000000000 --- a/api/admin_server.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) 2024 The VeChainThor developers - -// Distributed under the GNU Lesser General Public License v3.0 software license, see the accompanying -// file LICENSE or - -package api - -import ( - "log/slog" - "net" - "net/http" - "time" - - "github.com/gorilla/handlers" - "github.com/gorilla/mux" - "github.com/pkg/errors" - "github.com/vechain/thor/v2/api/utils" - "github.com/vechain/thor/v2/co" -) - -func HTTPHandler(logLevel *slog.LevelVar) http.Handler { - router := mux.NewRouter() - sub := router.PathPrefix("/admin").Subrouter() - sub.Path("/loglevel"). - Methods(http.MethodGet). - Name("get-log-level"). - HandlerFunc(utils.WrapHandlerFunc(getLogLevelHandler(logLevel))) - - sub.Path("/loglevel"). - Methods(http.MethodPost). - Name("post-log-level"). - HandlerFunc(utils.WrapHandlerFunc(postLogLevelHandler(logLevel))) - - return handlers.CompressHandler(router) -} - -func StartAdminServer(addr string, logLevel *slog.LevelVar) (string, func(), error) { - listener, err := net.Listen("tcp", addr) - if err != nil { - return "", nil, errors.Wrapf(err, "listen admin API addr [%v]", addr) - } - - router := mux.NewRouter() - router.PathPrefix("/admin").Handler(HTTPHandler(logLevel)) - handler := handlers.CompressHandler(router) - - srv := &http.Server{Handler: handler, ReadHeaderTimeout: time.Second, ReadTimeout: 5 * time.Second} - var goes co.Goes - goes.Go(func() { - srv.Serve(listener) - }) - return "http://" + listener.Addr().String() + "/admin", func() { - srv.Close() - goes.Wait() - }, nil -} diff --git a/api/admin_test.go b/api/admin_test.go index be2847cbf..26b8d93ee 100644 --- a/api/admin_test.go +++ b/api/admin_test.go @@ -8,92 +8,128 @@ package api import ( "bytes" "encoding/json" + "fmt" "log/slog" "net/http" "net/http/httptest" - "strings" + "sync/atomic" "testing" "github.com/stretchr/testify/assert" + "github.com/vechain/thor/v2/api/utils" + "github.com/vechain/thor/v2/log" ) -type TestCase struct { - name string - method string - body interface{} - expectedStatus int - expectedLevel string - expectedErrorMsg string -} +func TestAdmin_postLogLevel(t *testing.T) { + tests := []struct { + level string + httpCode int + }{ + {"debug", http.StatusOK}, + {"info", http.StatusOK}, + {"warn", http.StatusOK}, + {"error", http.StatusOK}, + {"crit", http.StatusOK}, + {"invalid", http.StatusBadRequest}, + } -func marshalBody(tt TestCase, t *testing.T) []byte { - var reqBody []byte - var err error - if tt.body != nil { - reqBody, err = json.Marshal(tt.body) - if err != nil { - t.Fatalf("could not marshal request body: %v", err) - } + for _, tt := range tests { + t.Run(tt.level, func(t *testing.T) { + admin := newAdmin() + req := newRequest(t, http.MethodPost, "/admin/loglevel", map[string]string{"level": tt.level}) + res := newHTTPTest(req, admin.postLogLevelHandler) + + assert.Equal(t, tt.httpCode, res.Code) + if tt.httpCode == http.StatusOK { + assert.Equal(t, tt.level, log.LevelString(admin.logLevel.Level())) + } + }) } - return reqBody } -func TestLogLevelHandler(t *testing.T) { - tests := []TestCase{ - { - name: "Valid POST input - set level to DEBUG", - method: "POST", - body: map[string]string{"level": "debug"}, - expectedStatus: http.StatusOK, - expectedLevel: "DEBUG", - }, - { - name: "Invalid POST input - invalid level", - method: "POST", - body: map[string]string{"level": "invalid_body"}, - expectedStatus: http.StatusBadRequest, - expectedErrorMsg: "Invalid verbosity level", - }, - { - name: "GET request - get current level INFO", - method: "GET", - body: nil, - expectedStatus: http.StatusOK, - expectedLevel: "INFO", - }, - } +func TestAdmin_getLogLevel(t *testing.T) { + admin := newAdmin() + initialLevel := log.LevelString(admin.logLevel.Level()) + req := newRequest(t, http.MethodGet, "/admin/loglevel", nil) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var logLevel slog.LevelVar - logLevel.Set(slog.LevelInfo) + res := newHTTPTest(req, admin.getRequestLoggerEnabled) - reqBodyBytes := marshalBody(tt, t) + assert.Equal(t, http.StatusOK, res.Code) + assert.Equal(t, initialLevel, log.LevelString(admin.logLevel.Level())) +} - req, err := http.NewRequest(tt.method, "/admin/loglevel", bytes.NewBuffer(reqBodyBytes)) - if err != nil { - t.Fatal(err) - } +// Update TestAdmin_postRequestLogger +func TestAdmin_postRequestLogger(t *testing.T) { + testCases := []struct { + enabled interface{} + httpCode int + }{ + {true, http.StatusOK}, + {false, http.StatusOK}, + {"invalid", http.StatusBadRequest}, + {nil, http.StatusBadRequest}, + } - rr := httptest.NewRecorder() - handler := http.HandlerFunc(HTTPHandler(&logLevel).ServeHTTP) - handler.ServeHTTP(rr, req) + for _, tt := range testCases { + t.Run(fmt.Sprintf("enabled=%v", tt.enabled), func(t *testing.T) { + admin := newAdmin() + req := newRequest(t, http.MethodPost, "/admin/apilogs", map[string]interface{}{"enabled": tt.enabled}) - if status := rr.Code; status != tt.expectedStatus { - t.Errorf("handler returned wrong status code: got %v want %v", status, tt.expectedStatus) - } + res := newHTTPTest(req, admin.postRequestLogger) - if tt.expectedLevel != "" { - var response logLevelResponse - if err := json.NewDecoder(rr.Body).Decode(&response); err != nil { - t.Fatalf("could not decode response: %v", err) - } - if response.CurrentLevel != tt.expectedLevel { - t.Errorf("handler returned unexpected log level: got %v want %v", response.CurrentLevel, tt.expectedLevel) - } - } else { - assert.Equal(t, tt.expectedErrorMsg, strings.Trim(rr.Body.String(), "\n")) + assert.Equal(t, tt.httpCode, res.Code) + if res.Code == http.StatusOK { + assert.Equal(t, tt.enabled, admin.logRequests.Load()) } }) } } + +// Update TestAdmin_getRequestLoggerEnabled +func TestAdmin_getRequestLoggerEnabled(t *testing.T) { + admin := newAdmin() + req := newRequest(t, http.MethodGet, "/admin/apilogs", nil) + + res := newHTTPTest(req, admin.getRequestLoggerEnabled) + + assert.Equal(t, http.StatusOK, res.Code) + assert.True(t, admin.logRequests.Load()) +} + +func newHTTPTest(req *http.Request, handlerFunc utils.HandlerFunc) *httptest.ResponseRecorder { + rr := httptest.NewRecorder() + handler := utils.WrapHandlerFunc(handlerFunc) + handler.ServeHTTP(rr, req) + return rr +} + +func newAdmin() *Admin { + var lvl slog.LevelVar + lvl.Set(slog.LevelDebug) + + var enabled atomic.Bool + enabled.Store(true) + + return NewAdmin("localhost:0", &lvl, &enabled) +} + +func newRequest(t *testing.T, method, url string, body interface{}) *http.Request { + reqBody := marshalBody(t, body) + req, err := http.NewRequest(method, url, bytes.NewBuffer(reqBody)) + if err != nil { + t.Fatal(err) + } + return req +} + +func marshalBody(t *testing.T, body interface{}) []byte { + var reqBody []byte + var err error + if body != nil { + reqBody, err = json.Marshal(body) + if err != nil { + t.Fatalf("could not marshal request body: %v", err) + } + } + return reqBody +} diff --git a/api/api.go b/api/api.go index 38b412a97..b47f86d5f 100644 --- a/api/api.go +++ b/api/api.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/pprof" "strings" + "sync/atomic" "github.com/gorilla/handlers" "github.com/gorilla/mux" @@ -47,7 +48,7 @@ func New( pprofOn bool, skipLogs bool, allowCustomTracer bool, - enableReqLogger bool, + enableReqLogger *atomic.Bool, enableMetrics bool, logsLimit uint64, allowedTracers []string, @@ -110,9 +111,7 @@ func New( handlers.ExposedHeaders([]string{"x-genesis-id", "x-thorest-ver"}), )(handler) - if enableReqLogger { - handler = RequestLoggerHandler(handler, logger) - } + handler = RequestLoggerHandler(handler, logger, enableReqLogger) return handler.ServeHTTP, subs.Close // subscriptions handles hijacked conns, which need to be closed } diff --git a/api/request_logger.go b/api/request_logger.go index 3d48a2d36..d10b1d055 100644 --- a/api/request_logger.go +++ b/api/request_logger.go @@ -9,14 +9,21 @@ import ( "bytes" "io" "net/http" + "sync/atomic" "time" "github.com/vechain/thor/v2/log" ) // RequestLoggerHandler returns a http handler to ensure requests are syphoned into the writer -func RequestLoggerHandler(handler http.Handler, logger log.Logger) http.Handler { +func RequestLoggerHandler(handler http.Handler, logger log.Logger, enabled *atomic.Bool) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { + // If logging is disabled, just call the original handler + if !enabled.Load() { + handler.ServeHTTP(w, r) + return + } + // Read and log the body (note: this can only be done once) // Ensure you don't disrupt the request body for handlers that need to read it var bodyBytes []byte diff --git a/api/request_logger_test.go b/api/request_logger_test.go index 6b8ddcd91..76026e6b4 100644 --- a/api/request_logger_test.go +++ b/api/request_logger_test.go @@ -10,6 +10,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync/atomic" "testing" "github.com/stretchr/testify/assert" @@ -67,7 +68,9 @@ func TestRequestLoggerHandler(t *testing.T) { }) // Create the RequestLoggerHandler - loggerHandler := RequestLoggerHandler(testHandler, mockLog) + enabled := &atomic.Bool{} + enabled.Store(true) + loggerHandler := RequestLoggerHandler(testHandler, mockLog, enabled) // Create a test HTTP request reqBody := "test body" diff --git a/cmd/thor/main.go b/cmd/thor/main.go index 4b934bc13..4b217551d 100644 --- a/cmd/thor/main.go +++ b/cmd/thor/main.go @@ -12,6 +12,7 @@ import ( "os" "path/filepath" "strings" + "sync/atomic" "time" "github.com/ethereum/go-ethereum/accounts/keystore" @@ -179,13 +180,16 @@ func defaultAction(ctx *cli.Context) error { defer func() { log.Info("stopping metrics server..."); closeFunc() }() } + logAPIRequests := atomic.Bool{} + logAPIRequests.Store(ctx.Bool(enableAPILogsFlag.Name)) adminURL := "" if ctx.Bool(enableAdminFlag.Name) { - url, closeFunc, err := api.StartAdminServer(ctx.String(adminAddrFlag.Name), logLevel) + admin := api.NewAdmin(ctx.String(adminAddrFlag.Name), logLevel, &logAPIRequests) + var closeFunc func() + adminURL, closeFunc, err = admin.Start() if err != nil { - return fmt.Errorf("unable to start admin server - %w", err) + return err } - adminURL = url defer func() { log.Info("stopping admin server..."); closeFunc() }() } @@ -261,7 +265,7 @@ func defaultAction(ctx *cli.Context) error { ctx.Bool(pprofFlag.Name), skipLogs, ctx.Bool(apiAllowCustomTracerFlag.Name), - ctx.Bool(enableAPILogsFlag.Name), + &logAPIRequests, ctx.Bool(enableMetricsFlag.Name), ctx.Uint64(apiLogsLimitFlag.Name), parseTracerList(strings.TrimSpace(ctx.String(allowedTracersFlag.Name))), @@ -322,13 +326,16 @@ func soloAction(ctx *cli.Context) error { defer func() { log.Info("stopping metrics server..."); closeFunc() }() } + logAPIRequests := atomic.Bool{} + logAPIRequests.Store(ctx.Bool(enableAPILogsFlag.Name)) adminURL := "" if ctx.Bool(enableAdminFlag.Name) { - url, closeFunc, err := api.StartAdminServer(ctx.String(adminAddrFlag.Name), logLevel) + admin := api.NewAdmin(ctx.String(adminAddrFlag.Name), logLevel, &logAPIRequests) + var closeFunc func() + adminURL, closeFunc, err = admin.Start() if err != nil { - return fmt.Errorf("unable to start admin server - %w", err) + return err } - adminURL = url defer func() { log.Info("stopping admin server..."); closeFunc() }() } @@ -413,7 +420,7 @@ func soloAction(ctx *cli.Context) error { ctx.Bool(pprofFlag.Name), skipLogs, ctx.Bool(apiAllowCustomTracerFlag.Name), - ctx.Bool(enableAPILogsFlag.Name), + &logAPIRequests, ctx.Bool(enableMetricsFlag.Name), ctx.Uint64(apiLogsLimitFlag.Name), parseTracerList(strings.TrimSpace(ctx.String(allowedTracersFlag.Name))),