diff --git a/internal/common/convert.go b/internal/common/convert.go index 2f30788c7..ffeb7501a 100644 --- a/internal/common/convert.go +++ b/internal/common/convert.go @@ -95,3 +95,17 @@ func QueryResultTypePtr(t s.QueryResultType) *s.QueryResultType { func PtrOf[T any](v T) *T { return &v } + +// ValueFromPtr returns the value from a pointer. +func ValueFromPtr[T any](v *T) T { + if v == nil { + return Zero[T]() + } + return *v +} + +// Zero returns the zero value of a type by return type. +func Zero[T any]() T { + var zero T + return zero +} diff --git a/internal/common/convert_test.go b/internal/common/convert_test.go index f09c90191..a88b4d020 100644 --- a/internal/common/convert_test.go +++ b/internal/common/convert_test.go @@ -55,3 +55,19 @@ func TestCeilHelpers(t *testing.T) { assert.Equal(t, int32(2), Int32Ceil(1.1)) assert.Equal(t, int64(2), Int64Ceil(1.1)) } + +func TestValueFromPtr(t *testing.T) { + assert.Equal(t, "a", ValueFromPtr(PtrOf("a"))) + assert.Equal(t, 1, ValueFromPtr(PtrOf(1))) + assert.Equal(t, int32(1), ValueFromPtr(PtrOf(int32(1)))) + assert.Equal(t, int64(1), ValueFromPtr(PtrOf(int64(1)))) + assert.Equal(t, 1.1, ValueFromPtr(PtrOf(1.1))) + assert.Equal(t, true, ValueFromPtr(PtrOf(true))) + assert.Equal(t, []string{"a"}, ValueFromPtr(PtrOf([]string{"a"}))) +} + +func TestZero(t *testing.T) { + assert.Equal(t, "", Zero[string]()) + assert.Equal(t, 0, Zero[int]()) + assert.Equal(t, (*int)(nil), Zero[*int]()) +} diff --git a/internal/common/testlogger/testlogger.go b/internal/common/testlogger/testlogger.go index 85dddd74d..a48ddc656 100644 --- a/internal/common/testlogger/testlogger.go +++ b/internal/common/testlogger/testlogger.go @@ -22,12 +22,12 @@ package testlogger import ( "fmt" + "go.uber.org/cadence/internal/common" "slices" "strings" "sync" "github.com/stretchr/testify/require" - "go.uber.org/atomic" "go.uber.org/zap" "go.uber.org/zap/zapcore" "go.uber.org/zap/zaptest" @@ -58,10 +58,11 @@ func NewZap(t TestingT) *zap.Logger { logAfterComplete, err := zap.NewDevelopment() require.NoError(t, err, "could not build a fallback zap logger") replaced := &fallbackTestCore{ + mu: &sync.RWMutex{}, t: t, fallback: logAfterComplete.Core(), testing: zaptest.NewLogger(t).Core(), - completed: &atomic.Bool{}, + completed: common.PtrOf(false), } t.Cleanup(replaced.UseFallback) // switch to fallback before ending the test @@ -81,30 +82,38 @@ func NewObserved(t TestingT) (*zap.Logger, *observer.ObservedLogs) { } type fallbackTestCore struct { - sync.Mutex + mu *sync.RWMutex t TestingT fallback zapcore.Core testing zapcore.Core - completed *atomic.Bool + completed *bool } var _ zapcore.Core = (*fallbackTestCore)(nil) func (f *fallbackTestCore) UseFallback() { - f.completed.Store(true) + f.mu.Lock() + defer f.mu.Unlock() + *f.completed = true } func (f *fallbackTestCore) Enabled(level zapcore.Level) bool { - if f.completed.Load() { + f.mu.RLock() + defer f.mu.RUnlock() + if f.completed != nil && *f.completed { return f.fallback.Enabled(level) } return f.testing.Enabled(level) } func (f *fallbackTestCore) With(fields []zapcore.Field) zapcore.Core { + f.mu.Lock() + defer f.mu.Unlock() + // need to copy and defer, else the returned core will be used at an // arbitrarily later point in time, possibly after the test has completed. return &fallbackTestCore{ + mu: f.mu, t: f.t, fallback: f.fallback.With(fields), testing: f.testing.With(fields), @@ -113,6 +122,8 @@ func (f *fallbackTestCore) With(fields []zapcore.Field) zapcore.Core { } func (f *fallbackTestCore) Check(entry zapcore.Entry, checked *zapcore.CheckedEntry) *zapcore.CheckedEntry { + f.mu.RLock() + defer f.mu.RUnlock() // see other Check impls, all look similar. // this defers the "where to log" decision to Write, as `f` is the core that will write. if f.fallback.Enabled(entry.Level) { @@ -122,7 +133,10 @@ func (f *fallbackTestCore) Check(entry zapcore.Entry, checked *zapcore.CheckedEn } func (f *fallbackTestCore) Write(entry zapcore.Entry, fields []zapcore.Field) error { - if f.completed.Load() { + f.mu.RLock() + defer f.mu.RUnlock() + + if common.ValueFromPtr(f.completed) { entry.Message = fmt.Sprintf("COULD FAIL TEST %q, logged too late: %v", f.t.Name(), entry.Message) hasStack := slices.ContainsFunc(fields, func(field zapcore.Field) bool { @@ -134,14 +148,14 @@ func (f *fallbackTestCore) Write(entry zapcore.Entry, fields []zapcore.Field) er } return f.fallback.Write(entry, fields) } - // Ensure no concurrent writes to the test logger. - f.Lock() - defer f.Unlock() return f.testing.Write(entry, fields) } func (f *fallbackTestCore) Sync() error { - if f.completed.Load() { + f.mu.RLock() + defer f.mu.RUnlock() + + if common.ValueFromPtr(f.completed) { return f.fallback.Sync() } return f.testing.Sync() diff --git a/internal/common/testlogger/testlogger_test.go b/internal/common/testlogger/testlogger_test.go index e29db3139..5409ea4b8 100644 --- a/internal/common/testlogger/testlogger_test.go +++ b/internal/common/testlogger/testlogger_test.go @@ -22,7 +22,9 @@ package testlogger import ( "fmt" + "go.uber.org/cadence/internal/common" "os" + "sync" "testing" "time" @@ -30,7 +32,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.uber.org/atomic" "go.uber.org/zap" ) @@ -47,7 +48,7 @@ func TestMain(m *testing.M) { select { case <-logged: os.Exit(code) - case <-time.After(time.Second): // should be MUCH faster + case <-time.After(time.Millisecond): // should be MUCH faster _, _ = fmt.Fprintln(os.Stderr, "timed out waiting for test to log") os.Exit(1) } @@ -131,10 +132,11 @@ func TestFallbackTestCore_Enabled(t *testing.T) { require.NoError(t, err) core := &fallbackTestCore{ + mu: &sync.RWMutex{}, t: t, fallback: fallbackLogger.Core(), testing: zaptest.NewLogger(t).Core(), - completed: &atomic.Bool{}, + completed: common.PtrOf(false), } // Debug is enabled in zaptest.Logger assert.True(t, core.Enabled(zap.DebugLevel)) @@ -144,16 +146,11 @@ func TestFallbackTestCore_Enabled(t *testing.T) { } func TestFallbackTestCore_Sync(t *testing.T) { - - core := &fallbackTestCore{ - t: t, - fallback: zaptest.NewLogger(t).Core(), - testing: zaptest.NewLogger(t).Core(), - completed: &atomic.Bool{}, - } + core := NewZap(t).Core().(*fallbackTestCore) + core.fallback = zap.NewNop().Core() // Sync for testing logger must not fail. assert.NoError(t, core.Sync(), "normal sync must not fail") core.UseFallback() // Sync for fallback logger must not fail. - assert.NoError(t, core.Sync(), "fallback sync must not fail") + assert.NoError(t, core.Sync(), "fallback sync must not fail") }