Skip to content

Commit

Permalink
allow passing custom logger in
Browse files Browse the repository at this point in the history
  • Loading branch information
chrispatrick committed Oct 14, 2024
1 parent a229916 commit 77b5c36
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 30 deletions.
8 changes: 6 additions & 2 deletions handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (
// Start initiates startup of the skills given the provided Handlers
func Start(handlers Handlers) {
Log.Info("Starting skill...")
http.HandleFunc("/", CreateHttpHandler(handlers))
http.HandleFunc("/", CreateHttpHandlerWithLogger(handlers, nil))

port := os.Getenv("PORT")
if port == "" {
Expand All @@ -45,6 +45,10 @@ func Start(handlers Handlers) {
}

func CreateHttpHandler(handlers Handlers) func(http.ResponseWriter, *http.Request) {
return CreateHttpHandlerWithLogger(handlers, nil)
}

func CreateHttpHandlerWithLogger(handlers Handlers, loggerCreator CreateLogger) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
handleStart := time.Now()
buf := new(strings.Builder)
Expand All @@ -66,7 +70,7 @@ func CreateHttpHandler(handlers Handlers) func(http.ResponseWriter, *http.Reques

name := NameFromEvent(event)
ctx := context.Background()
logger := createLogger(ctx, event, r.Header)
logger := createLogger(ctx, event, r.Header, loggerCreator)
req := RequestContext{
Event: event,
Log: logger,
Expand Down
157 changes: 131 additions & 26 deletions log.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ var (
instanceID string
)

type CreateLogger func(ctx context.Context, labels map[string]string) *Logger

func init() {
Log = logrus.New()
Log.SetOutput(os.Stdout)
Expand Down Expand Up @@ -88,11 +90,7 @@ type Logger struct {
Close func()
}

func createLogger(ctx context.Context, event EventIncoming, headers http.Header) Logger {
logger := Logger{}

var gcpLogger *logging.Logger
var client *logging.Client
func createCommonLabels(event EventIncoming, headers http.Header) map[string]string {
labels := make(map[string]string)

labels["correlation_id"] = event.ExecutionId
Expand All @@ -111,11 +109,20 @@ func createLogger(ctx context.Context, event EventIncoming, headers http.Header)
labels["cloud_trace_context"] = headers.Get("X-Cloud-Trace-Context")
labels["trace_parent"] = headers.Get("traceparent")

if projectID != "" {
client, _ = logging.NewClient(ctx, projectID)
gcpLogger = client.Logger("skill_logging")
return labels
}

func createGcpLogger(ctx context.Context, labels map[string]string) *Logger {
if projectID == "" {
return nil
}

var gcpLogger *logging.Logger
var client *logging.Client

client, _ = logging.NewClient(ctx, projectID)
gcpLogger = client.Logger("skill_logging")

var doGcpLog = func(msg string, level edn.Keyword) {
if gcpLogger != nil {
var severity logging.Severity
Expand All @@ -137,47 +144,145 @@ func createLogger(ctx context.Context, event EventIncoming, headers http.Header)
}
}

logger := Logger{
Debug: func(msg string) {
doGcpLog(msg, internal.Debug)
},
Debugf: func(format string, a ...any) {
doGcpLog(fmt.Sprintf(format, a...), internal.Debug)
},
Info: func(msg string) {
doGcpLog(msg, internal.Info)
},
Infof: func(format string, a ...any) {
doGcpLog(fmt.Sprintf(format, a...), internal.Info)
},
Warn: func(msg string) {
doGcpLog(msg, internal.Warn)
},
Warnf: func(format string, a ...any) {
doGcpLog(fmt.Sprintf(format, a...), internal.Warn)
},
Error: func(msg string) {
doGcpLog(msg, internal.Error)
},
Errorf: func(format string, a ...any) {
doGcpLog(fmt.Sprintf(format, a...), internal.Error)
},
Close: func() {
if client != nil {
_ = client.Close()
}
},
}

return &logger
}

func createDefaultLogger(ctx context.Context, labels map[string]string) *Logger {
localLabels := make(map[string]interface{})
for k, v := range labels {
localLabels[k] = v
}

logger := Logger{
Debug: func(msg string) {
Log.WithFields(localLabels).Debug(msg)
},
Debugf: func(format string, a ...any) {
Log.WithFields(localLabels).Debugf(format, a...)
},
Info: func(msg string) {
Log.WithFields(localLabels).Info(msg)
},
Infof: func(format string, a ...any) {
Log.WithFields(localLabels).Infof(format, a...)
},
Warn: func(msg string) {
Log.WithFields(localLabels).Warn(msg)
},
Warnf: func(format string, a ...any) {
Log.WithFields(localLabels).Warnf(format, a...)
},
Error: func(msg string) {
Log.WithFields(localLabels).Error(msg)
},
Errorf: func(format string, a ...any) {
Log.WithFields(localLabels).Errorf(format, a...)
},
Close: func() {
},
}

return &logger
}

func createLogger(ctx context.Context, event EventIncoming, headers http.Header, loggerCreator CreateLogger) Logger {
labels := createCommonLabels(event, headers)

loggerCreators := []CreateLogger{
createGcpLogger,
}

if loggerCreator != nil {
loggerCreators = append(loggerCreators, loggerCreator)
} else {
loggerCreators = append(loggerCreators, createDefaultLogger)
}

loggers := []Logger{}
for _, creator := range loggerCreators {
l := creator(ctx, labels)
if l != nil {
loggers = append(loggers, *l)
}
}

logger := Logger{}
logger.Debug = func(msg string) {
Log.WithFields(localLabels).Debug(msg)
doGcpLog(msg, internal.Debug)
for _, l := range loggers {
l.Debug(msg)
}
}
logger.Debugf = func(format string, a ...any) {
a = expandFuncs(a, logrus.DebugLevel)
Log.WithFields(localLabels).Debugf(format, a...)
doGcpLog(fmt.Sprintf(format, a...), internal.Debug)
for _, l := range loggers {
l.Debugf(format, a...)
}
}
logger.Info = func(msg string) {
Log.WithFields(localLabels).Info(msg)
doGcpLog(msg, internal.Info)
for _, l := range loggers {
l.Info(msg)
}
}
logger.Infof = func(format string, a ...any) {
Log.WithFields(localLabels).Infof(format, a...)
doGcpLog(fmt.Sprintf(format, a...), internal.Info)
for _, l := range loggers {
l.Infof(format, a...)
}
}
logger.Warn = func(msg string) {
Log.WithFields(localLabels).Warn(msg)
doGcpLog(msg, internal.Warn)
for _, l := range loggers {
l.Warn(msg)
}
}
logger.Warnf = func(format string, a ...any) {
Log.WithFields(localLabels).Warnf(format, a...)
doGcpLog(fmt.Sprintf(format, a...), internal.Warn)
for _, l := range loggers {
l.Warnf(format, a...)
}
}
logger.Error = func(msg string) {
Log.WithFields(localLabels).Error(msg)
doGcpLog(msg, internal.Error)
for _, l := range loggers {
l.Error(msg)
}
}
logger.Errorf = func(format string, a ...any) {
Log.WithFields(localLabels).Errorf(format, a...)
doGcpLog(fmt.Sprintf(format, a...), internal.Error)
for _, l := range loggers {
l.Errorf(format, a...)
}
}
logger.Close = func() {
if client != nil {
_ = client.Close()
for _, l := range loggers {
l.Close()
}
}

Expand Down
4 changes: 2 additions & 2 deletions log_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func TestSuccessfulLogging(t *testing.T) {
Logs: server.URL,
},
Token: "token",
}, http.Header{})
}, http.Header{}, nil)
logger.Infof("This is a %s message", "test")
}

Expand Down Expand Up @@ -98,7 +98,7 @@ func TestLoggingWithFunc(t *testing.T) {
var buf bytes.Buffer
Log.SetOutput(&buf)
Log.SetLevel(logrus.DebugLevel)
logger := createLogger(context.Background(), EventIncoming{}, http.Header{})
logger := createLogger(context.Background(), EventIncoming{}, http.Header{}, nil)
logger.Debugf("This is a %s message", func() interface{} { return "test" })

if !strings.Contains(buf.String(), "This is a test message") {
Expand Down

0 comments on commit 77b5c36

Please sign in to comment.