Skip to content

Commit

Permalink
Convert lib/client to use slog (#50498)
Browse files Browse the repository at this point in the history
  • Loading branch information
rosstimothy authored Dec 23, 2024
1 parent d070ce0 commit c394596
Show file tree
Hide file tree
Showing 27 changed files with 307 additions and 333 deletions.
2 changes: 1 addition & 1 deletion lib/client/alpn.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func RunALPNAuthTunnel(ctx context.Context, cfg ALPNAuthTunnelConfig) error {
go func() {
defer cfg.Listener.Close()
if err := lp.Start(ctx); err != nil {
log.WithError(err).Info("ALPN proxy stopped.")
log.InfoContext(ctx, "ALPN proxy stopped", "error", err)
}
}()

Expand Down
116 changes: 69 additions & 47 deletions lib/client/api.go

Large diffs are not rendered by default.

24 changes: 2 additions & 22 deletions lib/client/api_login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ import (
"github.com/google/uuid"
"github.com/jonboulle/clockwork"
"github.com/pquerna/otp/totp"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -64,8 +63,6 @@ import (
func TestTeleportClient_Login_local(t *testing.T) {
t.Parallel()

silenceLogger(t)

type webauthnFunc func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt) (*proto.MFAAuthenticateResponse, error)

waitForCancelFn := func(ctx context.Context) (string, error) {
Expand Down Expand Up @@ -324,8 +321,6 @@ func TestTeleportClient_Login_local(t *testing.T) {
}

func TestTeleportClient_DeviceLogin(t *testing.T) {
silenceLogger(t)

clock := clockwork.NewFakeClockAt(time.Now())
sa := newStandaloneTeleport(t, clock)
username := sa.Username
Expand Down Expand Up @@ -521,10 +516,6 @@ type standaloneBundle struct {
func newStandaloneTeleport(t *testing.T, clock clockwork.Clock) *standaloneBundle {
randomAddr := utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}

// Silent logger and console.
logger := utils.NewLoggerForTests()
logger.SetLevel(log.PanicLevel)
logger.SetOutput(io.Discard)
console := io.Discard

staticToken := uuid.New().String()
Expand Down Expand Up @@ -559,7 +550,7 @@ func newStandaloneTeleport(t *testing.T, clock clockwork.Clock) *standaloneBundl
cfg.Hostname = "localhost"
cfg.Clock = clock
cfg.Console = console
cfg.Log = logger
cfg.Logger = utils.NewSlogLoggerForTests()
cfg.SetAuthServerAddress(randomAddr) // must be present
cfg.Auth.Preference, err = types.NewAuthPreferenceFromConfigFile(types.AuthPreferenceSpecV2{
Type: constants.Local,
Expand Down Expand Up @@ -643,7 +634,7 @@ func newStandaloneTeleport(t *testing.T, clock clockwork.Clock) *standaloneBundl
cfg.SetToken(staticToken)
cfg.Clock = clock
cfg.Console = console
cfg.Log = logger
cfg.Logger = utils.NewSlogLoggerForTests()
cfg.SetAuthServerAddress(*authAddr)
cfg.Auth.Enabled = false
cfg.Proxy.Enabled = true
Expand Down Expand Up @@ -683,17 +674,6 @@ func startAndWait(t *testing.T, cfg *servicecfg.Config, eventName string) *servi
return instance
}

// silenceLogger silences logger during testing.
func silenceLogger(t *testing.T) {
lvl := log.GetLevel()
t.Cleanup(func() {
log.SetOutput(os.Stderr)
log.SetLevel(lvl)
})
log.SetOutput(io.Discard)
log.SetLevel(log.PanicLevel)
}

func TestRetryWithRelogin(t *testing.T) {
clock := clockwork.NewFakeClockAt(time.Now())
sa := newStandaloneTeleport(t, clock)
Expand Down
76 changes: 47 additions & 29 deletions lib/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ import (
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/sshutils/sftp"
"github.com/gravitational/teleport/lib/utils"
logutils "github.com/gravitational/teleport/lib/utils/log"
"github.com/gravitational/teleport/lib/utils/socks"
)

Expand Down Expand Up @@ -217,7 +218,7 @@ func makeDatabaseClientPEM(proto string, cert []byte, pk *keys.PrivateKey) ([]by
} else if !trace.IsBadParameter(err) {
return nil, trace.Wrap(err)
}
log.WithError(err).Warn("MongoDB integration is not supported when logging in with a hardware private key.")
log.WarnContext(context.Background(), "MongoDB integration is not supported when logging in with a hardware private key", "error", err)
}
return cert, nil
}
Expand Down Expand Up @@ -333,7 +334,11 @@ func NewNodeClient(ctx context.Context, sshConfig *ssh.ClientConfig, conn net.Co
// TODO(codingllama): Improve error message below for device trust.
// An alternative we have here is querying the cluster to check if device
// trust is required, a check similar to `IsMFARequired`.
log.Infof("Access denied to %v connecting to %v: %v", sshConfig.User, nodeName, err)
log.InfoContext(ctx, "Access denied connecting to host",
"login", sshConfig.User,
"target_host", nodeName,
"error", err,
)
return nil, trace.AccessDenied(`access denied to %v connecting to %v`, sshConfig.User, nodeName)
}
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -640,7 +645,7 @@ func (c *NodeClient) handleGlobalRequests(ctx context.Context, requestCh <-chan
switch r.Type {
case teleport.MFAPresenceRequest:
if c.OnMFA == nil {
log.Warn("Received MFA presence request, but no callback was provided.")
log.WarnContext(ctx, "Received MFA presence request, but no callback was provided")
continue
}

Expand All @@ -651,21 +656,21 @@ func (c *NodeClient) handleGlobalRequests(ctx context.Context, requestCh <-chan
var e events.EventFields
err := json.Unmarshal(r.Payload, &e)
if err != nil {
log.Warnf("Unable to parse event: %v: %v.", string(r.Payload), err)
log.WarnContext(ctx, "Unable to parse event", "event", string(r.Payload), "error", err)
continue
}

// Send event to event channel.
err = c.TC.SendEvent(ctx, e)
if err != nil {
log.Warnf("Unable to send event %v: %v.", string(r.Payload), err)
log.WarnContext(ctx, "Unable to send event", "event", string(r.Payload), "error", err)
continue
}
default:
// This handles keep-alive messages and matches the behavior of OpenSSH.
err := r.Reply(false, nil)
if err != nil {
log.Warnf("Unable to reply to %v request.", r.Type)
log.WarnContext(ctx, "Unable to reply to request", "request_type", r.Type, "error", err)
continue
}
}
Expand Down Expand Up @@ -707,7 +712,7 @@ func newClientConn(
case <-ctx.Done():
errClose := conn.Close()
if errClose != nil {
log.Error(errClose)
log.ErrorContext(ctx, "Failed closing connection", "error", errClose)
}
// drain the channel
resp := <-respCh
Expand All @@ -732,11 +737,16 @@ type netDialer interface {
}

func proxyConnection(ctx context.Context, conn net.Conn, remoteAddr string, dialer netDialer) error {
logger := log.With(
"source_addr", logutils.StringerAttr(conn.RemoteAddr()),
"target_addr", remoteAddr,
)

defer conn.Close()
defer log.Debugf("Finished proxy from %v to %v.", conn.RemoteAddr(), remoteAddr)
defer logger.DebugContext(ctx, "Finished proxy connection")

var remoteConn net.Conn
log.Debugf("Attempting to connect proxy from %v to %v.", conn.RemoteAddr(), remoteAddr)
logger.DebugContext(ctx, "Attempting to proxy connection")

retry, err := retryutils.NewLinear(retryutils.LinearConfig{
First: 100 * time.Millisecond,
Expand All @@ -756,7 +766,7 @@ func proxyConnection(ctx context.Context, conn net.Conn, remoteAddr string, dial
break
}

log.Debugf("Proxy connection attempt %v: %v.", attempt, err)
logger.DebugContext(ctx, "Proxy connection attempt", "attempt", attempt, "error", err)
// Wait and attempt to connect again, if the context has closed, exit
// right away.
select {
Expand Down Expand Up @@ -806,16 +816,19 @@ func acceptWithContext(ctx context.Context, l net.Listener) (net.Conn, error) {
func (c *NodeClient) listenAndForward(ctx context.Context, ln net.Listener, localAddr string, remoteAddr string) {
defer ln.Close()

log := log.WithField("localAddr", localAddr).WithField("remoteAddr", remoteAddr)
log := log.With(
"local_addr", localAddr,
"remote_addr", remoteAddr,
)

log.Infof("Starting port forwarding")
log.InfoContext(ctx, "Starting port forwarding")

for ctx.Err() == nil {
// Accept connections from the client.
conn, err := acceptWithContext(ctx, ln)
if err != nil {
if ctx.Err() == nil {
log.WithError(err).Errorf("Port forwarding failed.")
log.ErrorContext(ctx, "Port forwarding failed", "error", err)
}
continue
}
Expand All @@ -824,30 +837,32 @@ func (c *NodeClient) listenAndForward(ctx context.Context, ln net.Listener, loca
go func() {
// `err` must be a fresh variable, hence `:=` instead of `=`.
if err := proxyConnection(ctx, conn, remoteAddr, c.Client); err != nil {
log.WithError(err).Warnf("Failed to proxy connection.")
log.WarnContext(ctx, "Failed to proxy connection", "error", err)
}
}()
}

log.WithError(ctx.Err()).Infof("Shutting down port forwarding.")
log.InfoContext(ctx, "Shutting down port forwarding", "error", ctx.Err())
}

// dynamicListenAndForward listens for connections, performs a SOCKS5
// handshake, and then proxies the connection to the requested address.
func (c *NodeClient) dynamicListenAndForward(ctx context.Context, ln net.Listener, localAddr string) {
defer ln.Close()

log := log.WithField("localAddr", localAddr)
log := log.With(
"local_addr", localAddr,
)

log.Infof("Starting dynamic port forwarding.")
log.InfoContext(ctx, "Starting dynamic port forwarding")

for ctx.Err() == nil {
// Accept connection from the client. Here the client is typically
// something like a web browser or other SOCKS5 aware application.
conn, err := acceptWithContext(ctx, ln)
if err != nil {
if ctx.Err() == nil {
log.WithError(err).Errorf("Dynamic port forwarding (SOCKS5) failed.")
log.ErrorContext(ctx, "Dynamic port forwarding (SOCKS5) failed", "error", err)
}
continue
}
Expand All @@ -856,52 +871,55 @@ func (c *NodeClient) dynamicListenAndForward(ctx context.Context, ln net.Listene
// address to proxy.
remoteAddr, err := socks.Handshake(conn)
if err != nil {
log.WithError(err).Errorf("SOCKS5 handshake failed.")
log.ErrorContext(ctx, "SOCKS5 handshake failed", "error", err)
if err = conn.Close(); err != nil {
log.WithError(err).Errorf("Error closing failed proxy connection.")
log.ErrorContext(ctx, "Error closing failed proxy connection", "error", err)
}
continue
}
log.Debugf("SOCKS5 proxy forwarding requests to %v.", remoteAddr)
log.DebugContext(ctx, "SOCKS5 proxy forwarding requests", "remote_addr", remoteAddr)

// Proxy the connection to the remote address.
go func() {
// `err` must be a fresh variable, hence `:=` instead of `=`.
if err := proxyConnection(ctx, conn, remoteAddr, c.Client); err != nil {
log.WithError(err).Warnf("Failed to proxy connection.")
log.WarnContext(ctx, "Failed to proxy connection", "error", err)
if err = conn.Close(); err != nil {
log.WithError(err).Errorf("Error closing failed proxy connection.")
log.ErrorContext(ctx, "Error closing failed proxy connection", "error", err)
}
}
}()
}

log.WithError(ctx.Err()).Infof("Shutting down dynamic port forwarding.")
log.InfoContext(ctx, "Shutting down dynamic port forwarding", "error", ctx.Err())
}

// remoteListenAndForward requests a listening socket and forwards all incoming
// commands to the local address through the SSH tunnel.
func (c *NodeClient) remoteListenAndForward(ctx context.Context, ln net.Listener, localAddr, remoteAddr string) {
defer ln.Close()
log := log.WithField("localAddr", localAddr).WithField("remoteAddr", remoteAddr)
log.Infof("Starting remote port forwarding")
log := log.With(
"local_addr", localAddr,
"remote_addr", remoteAddr,
)
log.InfoContext(ctx, "Starting remote port forwarding")

for ctx.Err() == nil {
conn, err := acceptWithContext(ctx, ln)
if err != nil {
if ctx.Err() == nil {
log.WithError(err).Errorf("Remote port forwarding failed.")
log.ErrorContext(ctx, "Remote port forwarding failed", "error", err)
}
continue
}

go func() {
if err := proxyConnection(ctx, conn, localAddr, &net.Dialer{}); err != nil {
log.WithError(err).Warnf("Failed to proxy connection")
log.WarnContext(ctx, "Failed to proxy connection", "error", err)
}
}()
}
log.WithError(ctx.Err()).Infof("Shutting down remote port forwarding.")
log.InfoContext(ctx, "Shutting down remote port forwarding", "error", ctx.Err())
}

// GetRemoteTerminalSize fetches the terminal size of a given SSH session.
Expand Down
14 changes: 9 additions & 5 deletions lib/client/client_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
package client

import (
"context"
"errors"
"fmt"
"log/slog"
"net/url"
"time"

"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport"
Expand All @@ -42,7 +43,7 @@ import (
// when using `tsh --add-keys-to-agent=only`, Store will be made up of an in-memory
// key store and an FS (~/.tsh) profile and trusted certs store.
type Store struct {
log *logrus.Entry
log *slog.Logger

KeyStore
TrustedCertsStore
Expand All @@ -53,7 +54,7 @@ type Store struct {
func NewFSClientStore(dirPath string) *Store {
dirPath = profile.FullProfilePath(dirPath)
return &Store{
log: logrus.WithField(teleport.ComponentKey, teleport.ComponentKeyStore),
log: slog.With(teleport.ComponentKey, teleport.ComponentKeyStore),
KeyStore: NewFSKeyStore(dirPath),
TrustedCertsStore: NewFSTrustedCertsStore(dirPath),
ProfileStore: NewFSProfileStore(dirPath),
Expand All @@ -63,7 +64,7 @@ func NewFSClientStore(dirPath string) *Store {
// NewMemClientStore initializes a new in-memory client store.
func NewMemClientStore() *Store {
return &Store{
log: logrus.WithField(teleport.ComponentKey, teleport.ComponentKeyStore),
log: slog.With(teleport.ComponentKey, teleport.ComponentKeyStore),
KeyStore: NewMemKeyStore(),
TrustedCertsStore: NewMemTrustedCertsStore(),
ProfileStore: NewMemProfileStore(),
Expand Down Expand Up @@ -261,7 +262,10 @@ func (s *Store) FullProfileStatus() (*ProfileStatus, []*ProfileStatus, error) {
}
status, err := s.ReadProfileStatus(profileName)
if err != nil {
s.log.WithError(err).Warnf("skipping profile %q due to error", profileName)
s.log.WarnContext(context.Background(), "skipping profile due to error",
"profile_name", profileName,
"error", err,
)
continue
}
profiles = append(profiles, status)
Expand Down
Loading

0 comments on commit c394596

Please sign in to comment.