Migrate lib/events to slog (#46875)
Converts all events backend logging from logrus to slog.
rosstimothy authored Sep 30, 2024
1 parent 56b1bab commit 95e3d60
Showing 29 changed files with 478 additions and 576 deletions.
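
The pattern repeated across these files is the same: each backend's Config loses its logrus `LogEntry` field in favor of a `*slog.Logger`, and `CheckAndSetDefaults` falls back to the process-wide logger annotated with the component key. A minimal, self-contained sketch of that shape (the `componentKey` constant and `config` struct below are illustrative stand-ins, not code from this commit):

```go
package main

import "log/slog"

// componentKey stands in for teleport.ComponentKey.
const componentKey = "component"

// config mirrors the shape of the backend Config structs in this diff.
type config struct {
	// Logger emits log messages; it replaces the old logrus *log.Entry field.
	Logger *slog.Logger
}

func (c *config) checkAndSetDefaults() {
	if c.Logger == nil {
		// slog.With annotates the default logger with key/value attributes,
		// replacing logrus.WithFields(log.Fields{...}).
		c.Logger = slog.With(componentKey, "athena")
	}
}

func main() {
	c := &config{}
	c.checkAndSetDefaults()
	c.Logger.Info("backend configured", "bucket", "example-bucket")
}
```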
14 changes: 6 additions & 8 deletions lib/events/athena/athena.go
@@ -21,6 +21,7 @@ package athena
import (
"context"
"io"
"log/slog"
"net/url"
"regexp"
"strconv"
@@ -32,7 +33,6 @@ import (
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/contrib/instrumentation/github.com/aws/aws-sdk-go-v2/otelaws"
oteltrace "go.opentelemetry.io/otel/trace"

@@ -130,8 +130,8 @@ type Config struct {
Clock clockwork.Clock
// UIDGenerator is unique ID generator.
UIDGenerator utils.UID
// LogEntry is a log entry.
LogEntry *log.Entry
// Logger emits log messages.
Logger *slog.Logger

// PublisherConsumerAWSConfig is an AWS config which can be used to
// construct AWS Clients using aws-sdk-go-v2, used by the publisher and
@@ -276,10 +276,8 @@ func (cfg *Config) CheckAndSetDefaults(ctx context.Context) error {
cfg.UIDGenerator = utils.NewRealUID()
}

if cfg.LogEntry == nil {
cfg.LogEntry = log.WithFields(log.Fields{
teleport.ComponentKey: teleport.ComponentAthena,
})
if cfg.Logger == nil {
cfg.Logger = slog.With(teleport.ComponentKey, teleport.ComponentAthena)
}

if cfg.PublisherConsumerAWSConfig == nil {
@@ -476,7 +474,7 @@ func New(ctx context.Context, cfg Config) (*Log, error) {
getQueryResultsInterval: cfg.GetQueryResultsInterval,
disableQueryCostOptimization: cfg.DisableSearchCostOptimization,
awsCfg: cfg.StorerQuerierAWSConfig,
logger: cfg.LogEntry,
logger: cfg.Logger,
clock: cfg.Clock,
tracer: cfg.Tracer,
})
2 changes: 1 addition & 1 deletion lib/events/athena/athena_test.go
@@ -310,7 +310,7 @@ func TestConfig_CheckAndSetDefaults(t *testing.T) {
err := cfg.CheckAndSetDefaults(context.Background())
if tt.wantErr == "" {
require.NoError(t, err, "CheckAndSetDefaults return unexpected err")
require.Empty(t, cmp.Diff(tt.want, cfg, cmpopts.EquateApprox(0, 0.0001), cmpopts.IgnoreFields(Config{}, "Clock", "UIDGenerator", "LogEntry", "Tracer", "metrics", "ObserveWriteEventsError"), cmp.AllowUnexported(Config{})))
require.Empty(t, cmp.Diff(tt.want, cfg, cmpopts.EquateApprox(0, 0.0001), cmpopts.IgnoreFields(Config{}, "Clock", "UIDGenerator", "Logger", "Tracer", "metrics", "ObserveWriteEventsError"), cmp.AllowUnexported(Config{})))
} else {
require.ErrorContains(t, err, tt.wantErr)
}
36 changes: 17 additions & 19 deletions lib/events/athena/consumer.go
@@ -24,6 +24,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"slices"
"strconv"
@@ -40,7 +41,6 @@
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/parquet-go/parquet-go"
log "github.com/sirupsen/logrus"

"github.com/gravitational/teleport"
apievents "github.com/gravitational/teleport/api/types/events"
@@ -73,7 +73,7 @@ const (
// consumer is responsible for receiving messages from SQS, batching them up to
// certain size or interval, and writes to s3 as parquet file.
type consumer struct {
logger log.FieldLogger
logger *slog.Logger
backend backend.Backend
storeLocationPrefix string
storeLocationBucket string
@@ -138,7 +138,7 @@ func newConsumer(cfg Config, cancelFn context.CancelFunc) (*consumer, error) {
visibilityTimeout: int32(cfg.BatchMaxInterval.Seconds()),
batchMaxItems: cfg.BatchMaxItems,
errHandlingFn: errHandlingFnFromSQS(&cfg),
logger: cfg.LogEntry,
logger: cfg.Logger,
metrics: cfg.metrics,
}
err := collectCfg.CheckAndSetDefaults()
@@ -151,7 +151,7 @@ func newConsumer(cfg Config, cancelFn context.CancelFunc) (*consumer, error) {
}

return &consumer{
logger: cfg.LogEntry,
logger: cfg.Logger,
backend: cfg.Backend,
storeLocationPrefix: cfg.locationS3Prefix,
storeLocationBucket: cfg.locationS3Bucket,
@@ -189,7 +189,7 @@ func newConsumer(cfg Config, cancelFn context.CancelFunc) (*consumer, error) {
func (c *consumer) run(ctx context.Context) {
defer func() {
close(c.finished)
c.logger.Debug("Consumer finished")
c.logger.DebugContext(ctx, "Consumer finished")
}()
c.runContinuouslyOnSingleAuth(ctx, c.processEventsContinuously)
}
@@ -219,14 +219,14 @@ func (c *consumer) processEventsContinuously(ctx context.Context) {
if ctx.Err() != nil {
return false
}
c.logger.Errorf("Batcher single run failed: %v", err)
c.logger.ErrorContext(ctx, "Batcher single run failed", "error", err)
return false
}
return reachedMaxBatch
}

c.logger.Debug("Processing of events started on this instance")
defer c.logger.Debug("Processing of events finished on this instance")
c.logger.DebugContext(ctx, "Processing of events started on this instance")
defer c.logger.DebugContext(ctx, "Processing of events finished on this instance")

// If batch took 90% of specified interval, we don't want to wait just little bit.
// It's mainly to avoid cases when we will wait like 10ms.
@@ -276,7 +276,7 @@ func (c *consumer) runContinuouslyOnSingleAuth(ctx context.Context, eventsProces
}
// Ending up here means something went wrong in the backend while locking/waiting
// for lock. What we can do is log and retry whole operation.
c.logger.WithError(err).Warn("Could not get consumer to run with lock")
c.logger.WarnContext(ctx, "Could not get consumer to run with lock", "error", err)
select {
// Use wait to make sure we won't spam CPU with a lot requests
// if something goes wrong during acquire lock.
@@ -372,7 +372,7 @@ type sqsCollectConfig struct {
// noOfWorkers defines how many workers are processing messages from queue.
noOfWorkers int

logger log.FieldLogger
logger *slog.Logger
errHandlingFn func(ctx context.Context, errC chan error)

metrics *athenaMetrics
@@ -413,9 +413,7 @@ func (cfg *sqsCollectConfig) CheckAndSetDefaults() error {
cfg.noOfWorkers = 5
}
if cfg.logger == nil {
cfg.logger = log.WithFields(log.Fields{
teleport.ComponentKey: teleport.ComponentAthena,
})
cfg.logger = slog.With(teleport.ComponentKey, teleport.ComponentAthena)
}
if cfg.errHandlingFn == nil {
return trace.BadParameter("errHandlingFn is not specified")
@@ -499,7 +497,7 @@ func (s *sqsMessagesCollector) fromSQS(ctx context.Context) {
if isOverBatch || isOverMaximumUniqueDays {
fullBatchMetadataMu.Unlock()
cancel()
s.cfg.logger.Debugf("Batcher aborting early because of maxSize: %v, or maxUniqueDays %v", isOverBatch, isOverMaximumUniqueDays)
s.cfg.logger.DebugContext(ctx, "Batcher aborting early", "max_size", isOverBatch, "max_unique_days", isOverMaximumUniqueDays)
return
}
fullBatchMetadataMu.Unlock()
@@ -621,7 +619,7 @@ func (s *sqsMessagesCollector) receiveMessagesAndSendOnChan(ctx context.Context,
}
messageSentTimestamp, err := getMessageSentTimestamp(msg)
if err != nil {
s.cfg.logger.Debugf("Failed to get sentTimestamp: %v", err)
s.cfg.logger.DebugContext(ctx, "Failed to get sentTimestamp", "error", err)
}
singleReceiveMetadata.MergeWithEvent(event, messageSentTimestamp)
}
@@ -707,7 +705,7 @@ func errHandlingFnFromSQS(cfg *Config) func(ctx context.Context, errC chan error

defer func() {
if errorsCount > maxErrorCountForLogsOnSQSReceive {
cfg.LogEntry.Errorf("Got %d errors from SQS collector, printed only first %d", errorsCount, maxErrorCountForLogsOnSQSReceive)
cfg.Logger.ErrorContext(ctx, "Got errors from SQS collector", "error_count", errorsCount)
}
cfg.metrics.consumerNumberOfErrorsFromSQSCollect.Add(float64(errorsCount))
}()
@@ -723,7 +721,7 @@ func errHandlingFnFromSQS(cfg *Config) func(ctx context.Context, errC chan error
}
errorsCount++
if errorsCount <= maxErrorCountForLogsOnSQSReceive {
cfg.LogEntry.WithError(err).Error("Failure processing SQS messages")
cfg.Logger.ErrorContext(ctx, "Failure processing SQS messages", "error", err)
}
}
}
@@ -744,7 +742,7 @@ func (s *sqsMessagesCollector) downloadEventFromS3(ctx context.Context, payload
path := s3Payload.GetPath()
versionID := s3Payload.GetVersionId()

s.cfg.logger.Debugf("Downloading %v %v [%v].", s.cfg.payloadBucket, path, versionID)
s.cfg.logger.DebugContext(ctx, "Downloading event from S3", "bucket", s.cfg.payloadBucket, "path", path, "version", versionID)

var versionIDPtr *string
if versionID != "" {
@@ -785,7 +783,7 @@ eventLoop:
}
pqtEvent, err := auditEventToParquet(eventAndAckID.event)
if err != nil {
c.logger.WithError(err).Error("Could not convert event to parquet format")
c.logger.ErrorContext(ctx, "Could not convert event to parquet format", "error", err)
continue
}
date := pqtEvent.EventTime.Format(time.DateOnly)
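
At the call sites in consumer.go above, the conversion follows a consistent rule: format-string methods (`Debugf`, `Errorf`) and `WithError` chains become context-aware slog calls with explicit key/value attributes. A sketch of the before/after pattern (the logger, error, and attribute values below are illustrative, not taken from the diff):

```go
package main

import (
	"context"
	"errors"
	"log/slog"
)

func main() {
	ctx := context.Background()
	logger := slog.With("component", "athena")
	err := errors.New("receive failed")

	// Before: logger.Errorf("Batcher single run failed: %v", err)
	logger.ErrorContext(ctx, "Batcher single run failed", "error", err)

	// Before: logger.WithError(err).Warn("Could not get consumer to run with lock")
	logger.WarnContext(ctx, "Could not get consumer to run with lock", "error", err)

	// Before: logger.Debugf("Downloading %v %v [%v].", bucket, path, versionID)
	// Each formatted value becomes a named attribute instead.
	logger.DebugContext(ctx, "Downloading event from S3",
		"bucket", "audit-bucket", "path", "events/2024-09-30", "version", "1")
}
```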
103 changes: 3 additions & 100 deletions lib/events/athena/consumer_test.go
@@ -19,17 +19,16 @@
package athena

import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"io"
"log/slog"
"math/big"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"
@@ -44,7 +43,6 @@ import (
"github.com/parquet-go/parquet-go"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/backend/memory"
"github.com/gravitational/teleport/lib/events"
@@ -251,7 +249,7 @@ func validCollectCfgForTests(t *testing.T) sqsCollectConfig {
queueURL: "test-queue",
payloadBucket: "bucket",
payloadDownloader: &fakeS3manager{},
logger: utils.NewLoggerForTests(),
logger: slog.Default(),
errHandlingFn: func(ctx context.Context, errC chan error) {
err, ok := <-errC
if ok && err != nil {
@@ -416,7 +414,7 @@ func (m *mockReceiver) ReceiveMessage(ctx context.Context, params *sqs.ReceiveMe
}

func TestConsumerRunContinuouslyOnSingleAuth(t *testing.T) {
log := utils.NewLoggerForTests()
log := slog.Default()
backend, err := memory.New(memory.Config{})
require.NoError(t, err)
defer backend.Close()
@@ -561,101 +559,6 @@ func TestRunWithMinInterval(t *testing.T) {
})
}

func TestErrHandlingFnFromSQS(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()

log := utils.NewLoggerForTests()
// buf is used as output of logs, that we will use for assertions.
var buf bytes.Buffer
log.SetOutput(&buf)

metrics, err := newAthenaMetrics(athenaMetricsConfig{
batchInterval: defaultBatchInterval,
externalAuditStorage: false,
})
require.NoError(t, err)

cfg := &Config{
LogEntry: log.WithField(teleport.ComponentKey, "test"),
metrics: metrics,
}

t.Run("a lot of errors, make sure only up to maxErrorCountForLogsOnSQSReceive are printed and total count", func(t *testing.T) {
buf.Reset()
noOfErrors := maxErrorCountForLogsOnSQSReceive + 1
errorC := make(chan error, noOfErrors)
go func() {
for i := 0; i < noOfErrors; i++ {
errorC <- errors.New("some error")
}
close(errorC)
}()
errHandlingFnFromSQS(cfg)(ctx, errorC)
require.Equal(t, maxErrorCountForLogsOnSQSReceive, strings.Count(buf.String(), "some error"), "number of error log messages does not match")
require.Contains(t, buf.String(), fmt.Sprintf("Got %d errors from SQS collector, printed only first", noOfErrors))
})

t.Run("few errors, no total count should be printed", func(t *testing.T) {
buf.Reset()
noOfErrors := 5
errorC := make(chan error, noOfErrors)
go func() {
for i := 0; i < noOfErrors; i++ {
errorC <- errors.New("some error")
}
close(errorC)
}()
errHandlingFnFromSQS(cfg)(ctx, errorC)
require.Equal(t, noOfErrors, strings.Count(buf.String(), "some error"), "number of error log messages does not match")
require.NotContains(t, buf.String(), "printed only first")
})
t.Run("no errors at all", func(t *testing.T) {
buf.Reset()
errorC := make(chan error, 10)
go func() {
// close without any errors sent means receiving loop finished without any err
close(errorC)
}()
errHandlingFnFromSQS(cfg)(ctx, errorC)
require.Empty(t, buf.String())
})
t.Run("no errors at all - stopped via ctx cancel", func(t *testing.T) {
buf.Reset()
errorC := make(chan error, 10)
defer close(errorC)

ctx, inCancel := context.WithCancel(ctx)
inCancel()

errHandlingFnFromSQS(cfg)(ctx, errorC)
require.Empty(t, buf.String())
})

t.Run("there were a lot of errors, stopped via ctx cancel", func(t *testing.T) {
buf.Reset()
// unbuffered channel and a more messages,
// just make sure that errors are processed
// before cancel happen, used to avoid sleeping.
noOfErrors := maxErrorCountForLogsOnSQSReceive + 10

errorC := make(chan error)
defer close(errorC)

ctx, inCancel := context.WithCancel(ctx)
go func() {
for i := 0; i < noOfErrors; i++ {
errorC <- errors.New("some error")
}
inCancel()
}()

errHandlingFnFromSQS(cfg)(ctx, errorC)
require.Equal(t, maxErrorCountForLogsOnSQSReceive, strings.Count(buf.String(), "some error"), "number of error log messages does not match")
require.Contains(t, buf.String(), "printed only first")
})
}

// TestConsumerWriteToS3 checks if writing parquet files per date works.
// It receives events from different dates and make sure that multiple
// files are created and compare it against file in testdata.
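
The deleted TestErrHandlingFnFromSQS relied on redirecting logrus output into a bytes.Buffer for its assertions; this commit removes the test rather than porting it. A hypothetical slog equivalent (not part of this change) would build the test logger over a standard-library handler that writes to a buffer:

```go
package main

import (
	"bytes"
	"fmt"
	"log/slog"
	"strings"
)

func main() {
	var buf bytes.Buffer
	// A text handler writing to buf captures everything the logger emits.
	logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug}))

	logger.Error("Failure processing SQS messages", "error", "some error")

	// Assertions would then run against buf.String(), e.g. counting occurrences.
	fmt.Println(strings.Count(buf.String(), "some error")) // prints 1
}
```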
