Skip to content

Commit

Permalink
add context_errors
Browse files Browse the repository at this point in the history
  • Loading branch information
gimmyxd committed May 2, 2024
1 parent e34572c commit 2468a96
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 105 deletions.
28 changes: 28 additions & 0 deletions context_errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package errors

import (
"context"
"fmt"
)

// ContextError represents a standard error
// that can also encapsulate a context.
type ContextError struct {
Err error
Ctx context.Context
}

func WrapWithContext(err error, ctx context.Context) *ContextError {
return &ContextError{
Err: err,
Ctx: ctx,
}
}

func (ce *ContextError) Error() string {
return fmt.Sprintf("%s", ce.Err)
}

func (ce *ContextError) Unwrap() error {
return ce.Err
}
73 changes: 21 additions & 52 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"strings"
"time"

"github.com/aserto-dev/errors/werr"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"google.golang.org/genproto/googleapis/rpc/errdetails"
Expand All @@ -29,7 +28,7 @@ var (
)

func NewAsertoError(code string, statusCode codes.Code, httpCode int, msg string) *AsertoError {
asertoError := &AsertoError{code, statusCode, msg, httpCode, map[string]string{}, nil, nil}
asertoError := &AsertoError{code, statusCode, msg, httpCode, map[string]string{}, nil}
asertoErrors[code] = asertoError
return asertoError
}
Expand All @@ -43,14 +42,6 @@ type AsertoError struct {
HTTPCode int
data map[string]string
errs []error
Ctx context.Context
}

// Associates a context with the AsertoError.
func (e *AsertoError) WithContext(ctx context.Context) *AsertoError {
c := e.Copy()
c.Ctx = ctx
return c
}

func (e *AsertoError) Data() map[string]string {
Expand Down Expand Up @@ -82,7 +73,6 @@ func (e *AsertoError) Copy() *AsertoError {
data: dataCopy,
errs: e.errs,
HTTPCode: e.HTTPCode,
Ctx: e.Ctx,
}
}

Expand Down Expand Up @@ -283,6 +273,10 @@ func (e *AsertoError) WithHTTPStatus(httpStatus int) *AsertoError {
return c
}

func (e *AsertoError) Ctx(ctx context.Context) error {
return WrapWithContext(e, ctx)
}

// Returns an Aserto error based on a given grpcStatus. The details that are not of type errdetails.ErrorInfo are dropped.
// and if there are details from multiple errors, the aserto error will be constructed based on the first one.
func FromGRPCStatus(grpcStatus status.Status) *AsertoError {
Expand Down Expand Up @@ -319,19 +313,11 @@ func Logger(err error) *zerolog.Logger {
}

for {
wErr, ok := err.(*werr.WrappedError)
if ok {
aErr, aOk := wErr.Err.(*AsertoError)
if aOk {
setLogger(aErr.Ctx, &logger)
if ce, ok := err.(*ContextError); ok {
newLogger := extractLogger(ce.Ctx)
if newLogger != nil {
logger = newLogger
}
setLogger(wErr.Ctx, &logger)
}

aErr, ok := err.(*AsertoError)
if ok {
setLogger(aErr.Ctx, &logger)

}

err = errors.Unwrap(err)
Expand All @@ -343,27 +329,6 @@ func Logger(err error) *zerolog.Logger {
return logger
}

/**
* setLogger sets the logger pointer to the logger stored in the provided context.
* If the context is nil or the logger in the context is nil, the logger pointer remains unchanged.
* If the logger in the context is the default context logger or has a disabled level, the logger pointer remains unchanged.
*
* @param ctx The context from which to retrieve the logger.
* @param logger The pointer to the logger to be set.
*/
func setLogger(ctx context.Context, logger **zerolog.Logger) {
if ctx == nil {
return
}

newLogger := zerolog.Ctx(ctx)
if newLogger == nil || newLogger == zerolog.DefaultContextLogger || newLogger.GetLevel() == zerolog.Disabled {
return
}

*logger = newLogger
}

func UnwrapAsertoError(err error) *AsertoError {
if err == nil {
return nil
Expand All @@ -376,14 +341,6 @@ func UnwrapAsertoError(err error) *AsertoError {

// try to process Aserto error.
for {
wErr, ok := err.(*werr.WrappedError)
if ok {
aErr, aOk := wErr.Err.(*AsertoError)
if aOk {
return aErr
}
}

aErr, ok := err.(*AsertoError)
if ok {
return aErr
Expand Down Expand Up @@ -426,3 +383,15 @@ func Equals(err1, err2 error) bool {
func CodeToAsertoError(code string) *AsertoError {
return asertoErrors[code]
}

func extractLogger(ctx context.Context) *zerolog.Logger {
if ctx == nil {
return nil
}
logger := zerolog.Ctx(ctx)
if logger == nil || logger == zerolog.DefaultContextLogger || logger.GetLevel() == zerolog.Disabled {
logger = nil
}

return logger
}
66 changes: 48 additions & 18 deletions errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"github.com/rs/zerolog"

cerr "github.com/aserto-dev/errors"
"github.com/aserto-dev/errors/werr"
"github.com/stretchr/testify/require"
"google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -218,30 +217,44 @@ func TestLoggerWithWrappedNilError(t *testing.T) {
var err error
ctx := context.Background()

logger := cerr.Logger(werr.Wrap(err, ctx))
logger := cerr.Logger(cerr.WrapWithContext(err, ctx))
assert.Nil(logger)
}

func TestLoggerWithWrappedErrorsWithContext(t *testing.T) {
func TestLoggerWithWrappedErrorsWithEmptyContext(t *testing.T) {
assert := require.New(t)

ctx := context.Background()
err := cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error").WithContext(ctx)
err := cerr.WrapWithContext(cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error"), ctx)
wrappedErr := errors.Wrap(err, "wrapped error")

logger := cerr.Logger(wrappedErr)
assert.Nil(logger)
}

func TestLoggerWithWrappedErrorsWithLoggerContext(t *testing.T) {
assert := require.New(t)
initialLogger := zerolog.New(os.Stderr)

ctx := context.Background()
ctx = initialLogger.WithContext(ctx)
err := cerr.WrapWithContext(cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error"), ctx)
wrappedErr := errors.Wrap(err, "wrapped error")

logger := cerr.Logger(wrappedErr)
assert.NotNil(logger)
assert.Equal(logger, zerolog.Ctx(ctx))
}

func TestLoggerWithWrappedMultipleWithoutErrorsWithContext(t *testing.T) {
assert := require.New(t)
initialLogger := zerolog.New(os.Stderr)

ctx1 := context.Background()
ctx := initialLogger.WithContext(ctx1)
err := cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error").WithContext(ctx)
ctx := context.Background()
ctx = initialLogger.WithContext(ctx)
err := cerr.WrapWithContext(cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error"), ctx)
errWithoutCtx := cerr.NewAsertoError("E00002", codes.Internal, http.StatusInternalServerError, "internal error")
wrappedErr := errors.Wrap(errWithoutCtx.Err(err), "wrapped error")
wrappedErr := errWithoutCtx.Err(errors.Wrap(err, "wrapped error"))

logger := cerr.Logger(wrappedErr)
assert.NotNil(logger)
Expand All @@ -252,11 +265,11 @@ func TestLoggerWithWrappedMultipleErrorsWithContext(t *testing.T) {
assert := require.New(t)
initialLogger := zerolog.New(os.Stderr)

ctx1 := context.Background()
ctx := initialLogger.WithContext(ctx1)
err := cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error").WithContext(ctx)
ctx := context.Background()
ctx = initialLogger.WithContext(ctx)
err := cerr.WrapWithContext(cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error"), ctx)
errWithoutCtx := cerr.NewAsertoError("E00002", codes.Internal, http.StatusInternalServerError, "internal error")
wrappedErr := errors.Wrap(err.Err(errWithoutCtx), "wrapped error")
wrappedErr := errors.Wrap(errWithoutCtx.Err(err), "wrapped error")

logger := cerr.Logger(wrappedErr)
assert.NotNil(logger)
Expand All @@ -268,9 +281,8 @@ func TestLoggerWithWrappedMultipleErrorsWithMultipleContexts(t *testing.T) {
initialLogger := zerolog.New(os.Stderr)
ctx1 := context.Background()
ctx2 := initialLogger.WithContext(ctx1)
err := cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error").WithContext(ctx1)
err2 := cerr.NewAsertoError("E00002", codes.Internal, http.StatusInternalServerError, "internal error").WithContext(ctx2)
wrappedErr := errors.Wrap(err.Err(err2), "wrapped error")
err := cerr.WrapWithContext(cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error"), ctx1)
wrappedErr := cerr.WrapWithContext(cerr.WrapWithContext(err, ctx2), ctx1)

logger := cerr.Logger(wrappedErr)
ctx1Logger := zerolog.Ctx(ctx1)
Expand All @@ -286,9 +298,27 @@ func TestLoggerWithWrappedMultipleErrorsWithMultipleContextsOuter(t *testing.T)
initialLogger := zerolog.New(os.Stderr)
ctx1 := context.Background()
ctx2 := initialLogger.WithContext(ctx1)
err := cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error").WithContext(ctx1)
err2 := cerr.NewAsertoError("E00002", codes.Internal, http.StatusInternalServerError, "internal error").WithContext(ctx2)
wrappedErr := errors.Wrap(err2.Err(err), "wrapped error")
err := cerr.WrapWithContext(cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error"), ctx1)
err2 := cerr.WrapWithContext(cerr.NewAsertoError("E00002", codes.Internal, http.StatusInternalServerError, "internal error"), ctx2)
wrappedErr := errors.Wrap(errors.Wrap(err2, err.Error()), "wrapped error")

logger := cerr.Logger(wrappedErr)
ctx1Logger := zerolog.Ctx(ctx1)
ctx2Logger := zerolog.Ctx(ctx2)

assert.NotNil(logger)
assert.NotEqual(logger, ctx1Logger)
assert.Equal(logger, ctx2Logger)
}

func TestLoggerWithWrappedMultipleAsertoErrorsWithMultipleContextsOuter(t *testing.T) {
assert := require.New(t)
initialLogger := zerolog.New(os.Stderr)
ctx1 := context.Background()
ctx2 := initialLogger.WithContext(ctx1)
err := cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error").Ctx(ctx1)
err2 := cerr.NewAsertoError("E00002", codes.Internal, http.StatusInternalServerError, "internal error").Ctx(ctx2)
wrappedErr := errors.Wrap(errors.Wrap(err2, err.Error()), "wrapped error")

logger := cerr.Logger(wrappedErr)
ctx1Logger := zerolog.Ctx(ctx1)
Expand Down
35 changes: 0 additions & 35 deletions werr/errors.go

This file was deleted.

0 comments on commit 2468a96

Please sign in to comment.