diff --git a/context_errors.go b/context_errors.go new file mode 100644 index 0000000..dbfa329 --- /dev/null +++ b/context_errors.go @@ -0,0 +1,28 @@ +package errors + +import ( + "context" + "fmt" +) + +// ContextError represents a standard error +// that can also encapsulate a context. +type ContextError struct { + Err error + Ctx context.Context +} + +func WrapWithContext(err error, ctx context.Context) *ContextError { + return &ContextError{ + Err: err, + Ctx: ctx, + } +} + +func (ce *ContextError) Error() string { + return fmt.Sprintf("%s", ce.Err) +} + +func (ce *ContextError) Unwrap() error { + return ce.Err +} diff --git a/errors.go b/errors.go index 6fe33cb..7c08371 100644 --- a/errors.go +++ b/errors.go @@ -9,7 +9,6 @@ import ( "strings" "time" - "github.com/aserto-dev/errors/werr" "github.com/pkg/errors" "github.com/rs/zerolog" "google.golang.org/genproto/googleapis/rpc/errdetails" @@ -29,7 +28,7 @@ var ( ) func NewAsertoError(code string, statusCode codes.Code, httpCode int, msg string) *AsertoError { - asertoError := &AsertoError{code, statusCode, msg, httpCode, map[string]string{}, nil, nil} + asertoError := &AsertoError{code, statusCode, msg, httpCode, map[string]string{}, nil} asertoErrors[code] = asertoError return asertoError } @@ -43,14 +42,6 @@ type AsertoError struct { HTTPCode int data map[string]string errs []error - Ctx context.Context -} - -// Associates a context with the AsertoError. -func (e *AsertoError) WithContext(ctx context.Context) *AsertoError { - c := e.Copy() - c.Ctx = ctx - return c } func (e *AsertoError) Data() map[string]string { @@ -82,7 +73,6 @@ func (e *AsertoError) Copy() *AsertoError { data: dataCopy, errs: e.errs, HTTPCode: e.HTTPCode, - Ctx: e.Ctx, } } @@ -283,6 +273,10 @@ func (e *AsertoError) WithHTTPStatus(httpStatus int) *AsertoError { return c } +func (e *AsertoError) Ctx(ctx context.Context) error { + return WrapWithContext(e, ctx) +} + // Returns an Aserto error based on a given grpcStatus. The details that are not of type errdetails.ErrorInfo are dropped. // and if there are details from multiple errors, the aserto error will be constructed based on the first one. func FromGRPCStatus(grpcStatus status.Status) *AsertoError { @@ -319,19 +313,11 @@ func Logger(err error) *zerolog.Logger { } for { - wErr, ok := err.(*werr.WrappedError) - if ok { - aErr, aOk := wErr.Err.(*AsertoError) - if aOk { - setLogger(aErr.Ctx, &logger) + if ce, ok := err.(*ContextError); ok { + newLogger := extractLogger(ce.Ctx) + if newLogger != nil { + logger = newLogger } - setLogger(wErr.Ctx, &logger) - } - - aErr, ok := err.(*AsertoError) - if ok { - setLogger(aErr.Ctx, &logger) - } err = errors.Unwrap(err) @@ -343,27 +329,6 @@ func Logger(err error) *zerolog.Logger { return logger } -/** - * setLogger sets the logger pointer to the logger stored in the provided context. - * If the context is nil or the logger in the context is nil, the logger pointer remains unchanged. - * If the logger in the context is the default context logger or has a disabled level, the logger pointer remains unchanged. - * - * @param ctx The context from which to retrieve the logger. - * @param logger The pointer to the logger to be set. - */ -func setLogger(ctx context.Context, logger **zerolog.Logger) { - if ctx == nil { - return - } - - newLogger := zerolog.Ctx(ctx) - if newLogger == nil || newLogger == zerolog.DefaultContextLogger || newLogger.GetLevel() == zerolog.Disabled { - return - } - - *logger = newLogger -} - func UnwrapAsertoError(err error) *AsertoError { if err == nil { return nil @@ -376,14 +341,6 @@ func UnwrapAsertoError(err error) *AsertoError { // try to process Aserto error. for { - wErr, ok := err.(*werr.WrappedError) - if ok { - aErr, aOk := wErr.Err.(*AsertoError) - if aOk { - return aErr - } - } - aErr, ok := err.(*AsertoError) if ok { return aErr @@ -426,3 +383,15 @@ func Equals(err1, err2 error) bool { func CodeToAsertoError(code string) *AsertoError { return asertoErrors[code] } + +func extractLogger(ctx context.Context) *zerolog.Logger { + if ctx == nil { + return nil + } + logger := zerolog.Ctx(ctx) + if logger == nil || logger == zerolog.DefaultContextLogger || logger.GetLevel() == zerolog.Disabled { + logger = nil + } + + return logger +} diff --git a/errors_test.go b/errors_test.go index 577e936..574ce79 100644 --- a/errors_test.go +++ b/errors_test.go @@ -10,7 +10,6 @@ import ( "github.com/rs/zerolog" cerr "github.com/aserto-dev/errors" - "github.com/aserto-dev/errors/werr" "github.com/stretchr/testify/require" "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc/codes" @@ -218,30 +217,44 @@ func TestLoggerWithWrappedNilError(t *testing.T) { var err error ctx := context.Background() - logger := cerr.Logger(werr.Wrap(err, ctx)) + logger := cerr.Logger(cerr.WrapWithContext(err, ctx)) assert.Nil(logger) } -func TestLoggerWithWrappedErrorsWithContext(t *testing.T) { +func TestLoggerWithWrappedErrorsWithEmptyContext(t *testing.T) { assert := require.New(t) ctx := context.Background() - err := cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error").WithContext(ctx) + err := cerr.WrapWithContext(cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error"), ctx) wrappedErr := errors.Wrap(err, "wrapped error") logger := cerr.Logger(wrappedErr) assert.Nil(logger) } +func TestLoggerWithWrappedErrorsWithLoggerContext(t *testing.T) { + assert := require.New(t) + initialLogger := zerolog.New(os.Stderr) + + ctx := context.Background() + ctx = initialLogger.WithContext(ctx) + err := cerr.WrapWithContext(cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error"), ctx) + wrappedErr := errors.Wrap(err, "wrapped error") + + logger := cerr.Logger(wrappedErr) + assert.NotNil(logger) + assert.Equal(logger, zerolog.Ctx(ctx)) +} + func TestLoggerWithWrappedMultipleWithoutErrorsWithContext(t *testing.T) { assert := require.New(t) initialLogger := zerolog.New(os.Stderr) - ctx1 := context.Background() - ctx := initialLogger.WithContext(ctx1) - err := cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error").WithContext(ctx) + ctx := context.Background() + ctx = initialLogger.WithContext(ctx) + err := cerr.WrapWithContext(cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error"), ctx) errWithoutCtx := cerr.NewAsertoError("E00002", codes.Internal, http.StatusInternalServerError, "internal error") - wrappedErr := errors.Wrap(errWithoutCtx.Err(err), "wrapped error") + wrappedErr := errWithoutCtx.Err(errors.Wrap(err, "wrapped error")) logger := cerr.Logger(wrappedErr) assert.NotNil(logger) @@ -252,11 +265,11 @@ func TestLoggerWithWrappedMultipleErrorsWithContext(t *testing.T) { assert := require.New(t) initialLogger := zerolog.New(os.Stderr) - ctx1 := context.Background() - ctx := initialLogger.WithContext(ctx1) - err := cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error").WithContext(ctx) + ctx := context.Background() + ctx = initialLogger.WithContext(ctx) + err := cerr.WrapWithContext(cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error"), ctx) errWithoutCtx := cerr.NewAsertoError("E00002", codes.Internal, http.StatusInternalServerError, "internal error") - wrappedErr := errors.Wrap(err.Err(errWithoutCtx), "wrapped error") + wrappedErr := errors.Wrap(errWithoutCtx.Err(err), "wrapped error") logger := cerr.Logger(wrappedErr) assert.NotNil(logger) @@ -268,9 +281,8 @@ func TestLoggerWithWrappedMultipleErrorsWithMultipleContexts(t *testing.T) { initialLogger := zerolog.New(os.Stderr) ctx1 := context.Background() ctx2 := initialLogger.WithContext(ctx1) - err := cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error").WithContext(ctx1) - err2 := cerr.NewAsertoError("E00002", codes.Internal, http.StatusInternalServerError, "internal error").WithContext(ctx2) - wrappedErr := errors.Wrap(err.Err(err2), "wrapped error") + err := cerr.WrapWithContext(cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error"), ctx1) + wrappedErr := cerr.WrapWithContext(cerr.WrapWithContext(err, ctx2), ctx1) logger := cerr.Logger(wrappedErr) ctx1Logger := zerolog.Ctx(ctx1) @@ -286,9 +298,27 @@ func TestLoggerWithWrappedMultipleErrorsWithMultipleContextsOuter(t *testing.T) initialLogger := zerolog.New(os.Stderr) ctx1 := context.Background() ctx2 := initialLogger.WithContext(ctx1) - err := cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error").WithContext(ctx1) - err2 := cerr.NewAsertoError("E00002", codes.Internal, http.StatusInternalServerError, "internal error").WithContext(ctx2) - wrappedErr := errors.Wrap(err2.Err(err), "wrapped error") + err := cerr.WrapWithContext(cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error"), ctx1) + err2 := cerr.WrapWithContext(cerr.NewAsertoError("E00002", codes.Internal, http.StatusInternalServerError, "internal error"), ctx2) + wrappedErr := errors.Wrap(errors.Wrap(err2, err.Error()), "wrapped error") + + logger := cerr.Logger(wrappedErr) + ctx1Logger := zerolog.Ctx(ctx1) + ctx2Logger := zerolog.Ctx(ctx2) + + assert.NotNil(logger) + assert.NotEqual(logger, ctx1Logger) + assert.Equal(logger, ctx2Logger) +} + +func TestLoggerWithWrappedMultipleAsertoErrorsWithMultipleContextsOuter(t *testing.T) { + assert := require.New(t) + initialLogger := zerolog.New(os.Stderr) + ctx1 := context.Background() + ctx2 := initialLogger.WithContext(ctx1) + err := cerr.NewAsertoError("E00001", codes.Internal, http.StatusInternalServerError, "internal error").Ctx(ctx1) + err2 := cerr.NewAsertoError("E00002", codes.Internal, http.StatusInternalServerError, "internal error").Ctx(ctx2) + wrappedErr := errors.Wrap(errors.Wrap(err2, err.Error()), "wrapped error") logger := cerr.Logger(wrappedErr) ctx1Logger := zerolog.Ctx(ctx1) diff --git a/werr/errors.go b/werr/errors.go deleted file mode 100644 index 1890733..0000000 --- a/werr/errors.go +++ /dev/null @@ -1,35 +0,0 @@ -package werr - -import ( - "context" - "fmt" -) - -// WrappedError represents a standard error -// that can also encapsulate a context. -type WrappedError struct { - Ctx context.Context - Err error -} - -func (w *WrappedError) Error() string { - return fmt.Sprintf("%s", w.Err) -} - -func Wrap(err error, ctx context.Context) *WrappedError { - return &WrappedError{ - Ctx: ctx, - Err: err, - } -} - -func (w *WrappedError) WithContext(ctx context.Context) *WrappedError { - return &WrappedError{ - Ctx: ctx, - Err: w.Err, - } -} - -func (w *WrappedError) Unwrap() error { - return w.Err -}