From bbf7c1a9bae31b4ec20b3e0853eaee8de3eaf43a Mon Sep 17 00:00:00 2001 From: rosstimothy <39066650+rosstimothy@users.noreply.github.com> Date: Fri, 20 Dec 2024 22:24:11 -0500 Subject: [PATCH] Convert lib/auth/middleware to use slog (#50521) --- lib/auth/init.go | 8 ++++--- lib/auth/middleware.go | 50 ++++++++++++++++++++++++++---------------- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/lib/auth/init.go b/lib/auth/init.go index c3707012cccf9..9bcd669a27313 100644 --- a/lib/auth/init.go +++ b/lib/auth/init.go @@ -71,11 +71,13 @@ import ( "github.com/gravitational/teleport/lib/tlsca" usagereporter "github.com/gravitational/teleport/lib/usagereporter/teleport" "github.com/gravitational/teleport/lib/utils" + logutils "github.com/gravitational/teleport/lib/utils/log" ) -var log = logrus.WithFields(logrus.Fields{ - teleport.ComponentKey: teleport.ComponentAuth, -}) +var ( + log = logrus.WithField(teleport.ComponentKey, teleport.ComponentAuth) + logger = logutils.NewPackageLogger(teleport.ComponentKey, teleport.ComponentAuth) +) // VersionStorage local storage for saving the version. type VersionStorage interface { diff --git a/lib/auth/middleware.go b/lib/auth/middleware.go index 8e21c071eb4a5..4d5b913f9db56 100644 --- a/lib/auth/middleware.go +++ b/lib/auth/middleware.go @@ -24,6 +24,7 @@ import ( "crypto/x509" "encoding/json" "fmt" + "log/slog" "net" "net/http" "os" @@ -36,7 +37,6 @@ import ( "github.com/gravitational/trace" grpcprom "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "golang.org/x/net/http2" "google.golang.org/grpc" @@ -57,6 +57,7 @@ import ( "github.com/gravitational/teleport/lib/observability/metrics" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" + logutils "github.com/gravitational/teleport/lib/utils/log" ) const ( @@ -139,7 +140,7 @@ type TLSServer struct { // cfg is TLS server configuration used for auth server cfg TLSServerConfig // log is TLS server logging entry - log *logrus.Entry + log *slog.Logger // mux is a listener that multiplexes HTTP/2 and HTTP/1.1 // on different listeners mux *multiplexer.TLSListener @@ -215,9 +216,7 @@ func NewTLSServer(ctx context.Context, cfg TLSServerConfig) (*TLSServer, error) return authz.ContextWithConn(ctx, c) }, }, - log: logrus.WithFields(logrus.Fields{ - teleport.ComponentKey: cfg.Component, - }), + log: slog.With(teleport.ComponentKey, cfg.Component), } tlsConfig := cfg.TLS.Clone() @@ -306,7 +305,7 @@ func (t *TLSServer) Serve() error { errC := make(chan error, 2) go func() { err := t.mux.Serve() - t.log.WithError(err).Warningf("Mux serve failed.") + t.log.WarnContext(context.Background(), "Mux serve failed", "error", err) }() go func() { errC <- t.httpServer.Serve(t.mux.HTTP()) @@ -372,7 +371,9 @@ func getCustomRate(endpoint string) *limiter.RateSet { rates := limiter.NewRateSet() // This limit means: 1 request per minute with bursts up to 10 requests. if err := rates.Add(time.Minute, 1, 10); err != nil { - log.WithError(err).Debugf("Failed to define a custom rate for rpc method %q, using default rate", endpoint) + logger.DebugContext(context.Background(), "Failed to define a custom rate for rpc method, using default rate", + "error", err, + "rpc_method", endpoint) return nil } return rates @@ -383,7 +384,10 @@ func getCustomRate(endpoint string) *limiter.RateSet { const burst = defaults.LimiterBurst rates := limiter.NewRateSet() if err := rates.Add(period, average, burst); err != nil { - log.WithError(err).Debugf("Failed to define a custom rate for rpc method %q, using default rate", endpoint) + logger.DebugContext(context.Background(), "Failed to define a custom rate for rpc method, using default rate", + "error", err, + "rpc_method", endpoint, + ) return nil } return rates @@ -406,24 +410,29 @@ func (a *Middleware) ValidateClientVersion(ctx context.Context, info IdentityInf ua := metadata.UserAgentFromContext(ctx) - logger := log.WithFields(logrus.Fields{"user_agent": ua, "identity": info.IdentityGetter.GetIdentity().Username, "version": clientVersionString, "addr": info.Conn.RemoteAddr().String()}) + logger := slog.With( + "user_agent", ua, + "identity", info.IdentityGetter.GetIdentity().Username, + "version", clientVersionString, + "addr", logutils.StringerAttr(info.Conn.RemoteAddr()), + ) clientVersion, err := semver.NewVersion(clientVersionString) if err != nil { - logger.WithError(err).Warn("Failed to determine client version") + logger.WarnContext(ctx, "Failed to determine client version", "error", err) a.displayRejectedClientAlert(ctx, clientVersionString, info.Conn.RemoteAddr(), ua, info.IdentityGetter) if err := info.Conn.Close(); err != nil { - logger.WithError(err).Warn("Failed to close client connection") + logger.WarnContext(ctx, "Failed to close client connection", "error", err) } return trace.AccessDenied("client version is unsupported") } if clientVersion.LessThan(*a.OldestSupportedVersion) { - logger.Info("Terminating connection of client using unsupported version") + logger.InfoContext(ctx, "Terminating connection of client using unsupported version") a.displayRejectedClientAlert(ctx, clientVersionString, info.Conn.RemoteAddr(), ua, info.IdentityGetter) if err := info.Conn.Close(); err != nil { - logger.WithError(err).Warn("Failed to close client connection") + logger.WarnContext(ctx, "Failed to close client connection", "error", err) } return trace.AccessDenied("client version is unsupported") @@ -486,12 +495,12 @@ func (a *Middleware) displayRejectedClientAlert(ctx context.Context, clientVersi types.WithAlertLabel(types.AlertVerbPermit, fmt.Sprintf("%s:%s", types.KindToken, types.VerbCreate)), ) if err != nil { - log.WithError(err).Warn("failed to create rejected-unsupported-connection alert") + logger.WarnContext(ctx, "failed to create rejected-unsupported-connection alert", "error", err) return } if err := a.AlertCreator(ctx, alert); err != nil { - log.WithError(err).Warn("failed to persist rejected-unsupported-connection alert") + logger.WarnContext(ctx, "failed to persist rejected-unsupported-connection alert", "error", err) return } } @@ -656,7 +665,7 @@ func (a *Middleware) GetUser(connState tls.ConnectionState) (authz.IdentityGette if certClusterName == "" { certClusterName, err = tlsca.ClusterName(clientCert.Issuer) if err != nil { - log.Warnf("Failed to parse client certificate %v.", err) + logger.WarnContext(context.Background(), "Failed to parse client certificate", "error", err) return nil, trace.AccessDenied("access denied: invalid client certificate") } identity.TeleportCluster = certClusterName @@ -667,8 +676,11 @@ func (a *Middleware) GetUser(connState tls.ConnectionState) (authz.IdentityGette // against auth server. Later on we can extend more // advanced cert usage, but for now this is the safest option. if len(identity.Usage) != 0 && !slices.Equal(a.AcceptedUsage, identity.Usage) { - log.Warningf("Restricted certificate of user %q with usage %v rejected while accessing the auth endpoint with acceptable usage %v.", - identity.Username, identity.Usage, a.AcceptedUsage) + logger.WarnContext(context.Background(), "Restricted certificate rejected while accessing the auth endpoint", + "user", identity.Username, + "cert_usage", identity.Usage, + "acceptable_usage", a.AcceptedUsage, + ) return nil, trace.AccessDenied("access denied: invalid client certificate") } @@ -734,7 +746,7 @@ func extractAdditionalSystemRoles(roles []string) types.SystemRoles { if err != nil { // ignore unknown system roles rather than rejecting them, since new unknown system // roles may be present on certs if we rolled back from a newer version. - log.Warnf("Ignoring unknown system role: %q", role) + logger.WarnContext(context.Background(), "Ignoring unknown system role", "unknown_role", role) continue } systemRoles = append(systemRoles, systemRole)