From 1505d282ac345ac771ed8e4315e1ee6d10024336 Mon Sep 17 00:00:00 2001 From: Prashant Varanasi Date: Mon, 6 Feb 2023 01:17:57 -0800 Subject: [PATCH] Fix Swap and CompareAndSwap for Value wrappers (#130) * Regenerate code to update copyright end year to 2023 * Test behaviour of default values initialized in different ways This adds repro tests for #126 and #129 * Fix Swap and CompareAndSwap for Value wrappers Fixes #126, #129 All atomic types can be used without initialization, e.g., `var v `. This works fine for integer types as the initialized value of 0 matches the default value for the user-facing type. However, for Value wrappers, they are initialized to `nil`, which is a value that can't be set (triggers a panic) so the default value for the user-facing type is forced to be stored as a different value. This leads to multiple possible values representing the default user-facing type. E.g., an `atomic.String` with value `""` may be represented by the underlying atomic as either `nil`, or `""`. This causes issues when we don't handle the `nil` value correctly, causing to panics in `Swap` and incorrectly not swapping values in `CompareAndSwap`. This change fixes the above issues by: * Requiring `pack` and `unpack` function in gen-atomicwrapper as the only place we weren't supplying them was for `String`, and the branching adds unnecessary complexity, especially with added `nil` handling. * Extending `CompareAndSwap` for `Value` wrappers to try an additional `CompareAndSwap(nil, )` only if the original `CompareAndSwap` fails and the old value is the zero value. --- bool.go | 2 +- bool_test.go | 64 +++++++++++++++++++++++++ duration.go | 2 +- error.go | 14 +++++- error_test.go | 53 ++++++++++++++++++++ float32.go | 2 +- float64.go | 2 +- int32.go | 2 +- int64.go | 2 +- internal/gen-atomicwrapper/main.go | 15 +++--- internal/gen-atomicwrapper/wrapper.tmpl | 28 ++++++----- string.go | 23 +++++---- string_ext.go | 15 +++++- string_test.go | 64 +++++++++++++++++++++++++ time.go | 2 +- uint32.go | 2 +- uint64.go | 2 +- uintptr.go | 2 +- 18 files changed, 252 insertions(+), 44 deletions(-) diff --git a/bool.go b/bool.go index dfa2085..f0a2ddd 100644 --- a/bool.go +++ b/bool.go @@ -1,6 +1,6 @@ // @generated Code generated by gen-atomicwrapper. -// Copyright (c) 2020-2022 Uber Technologies, Inc. +// Copyright (c) 2020-2023 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal diff --git a/bool_test.go b/bool_test.go index bcba01d..fa73325 100644 --- a/bool_test.go +++ b/bool_test.go @@ -84,3 +84,67 @@ func TestBool(t *testing.T) { }) }) } + +func TestBool_InitializeDefaults(t *testing.T) { + tests := []struct { + msg string + newBool func() *Bool + }{ + { + msg: "Uninitialized", + newBool: func() *Bool { + var b Bool + return &b + }, + }, + { + msg: "NewBool with default", + newBool: func() *Bool { + return NewBool(false) + }, + }, + { + msg: "Bool swapped with default", + newBool: func() *Bool { + b := NewBool(true) + b.Swap(false) + return b + }, + }, + { + msg: "Bool CAS'd with default", + newBool: func() *Bool { + b := NewBool(true) + b.CompareAndSwap(true, false) + return b + }, + }, + } + + for _, tt := range tests { + t.Run(tt.msg, func(t *testing.T) { + t.Run("MarshalJSON", func(t *testing.T) { + b := tt.newBool() + marshalled, err := b.MarshalJSON() + require.NoError(t, err) + assert.Equal(t, "false", string(marshalled)) + }) + + t.Run("String", func(t *testing.T) { + b := tt.newBool() + assert.Equal(t, "false", b.String()) + }) + + t.Run("CompareAndSwap", func(t *testing.T) { + b := tt.newBool() + require.True(t, b.CompareAndSwap(false, true)) + assert.Equal(t, true, b.Load()) + }) + + t.Run("Swap", func(t *testing.T) { + b := tt.newBool() + assert.Equal(t, false, b.Swap(true)) + }) + }) + } +} diff --git a/duration.go b/duration.go index 6f41574..7c23868 100644 --- a/duration.go +++ b/duration.go @@ -1,6 +1,6 @@ // @generated Code generated by gen-atomicwrapper. -// Copyright (c) 2020-2022 Uber Technologies, Inc. +// Copyright (c) 2020-2023 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal diff --git a/error.go b/error.go index 27b23ea..b7e3f12 100644 --- a/error.go +++ b/error.go @@ -1,6 +1,6 @@ // @generated Code generated by gen-atomicwrapper. -// Copyright (c) 2020-2022 Uber Technologies, Inc. +// Copyright (c) 2020-2023 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -52,7 +52,17 @@ func (x *Error) Store(val error) { // CompareAndSwap is an atomic compare-and-swap for error values. func (x *Error) CompareAndSwap(old, new error) (swapped bool) { - return x.v.CompareAndSwap(packError(old), packError(new)) + if x.v.CompareAndSwap(packError(old), packError(new)) { + return true + } + + if old == _zeroError { + // If the old value is the empty value, then it's possible the + // underlying Value hasn't been set and is nil, so retry with nil. + return x.v.CompareAndSwap(nil, packError(new)) + } + + return false } // Swap atomically stores the given error and returns the old diff --git a/error_test.go b/error_test.go index c9ac2ef..1f02e6d 100644 --- a/error_test.go +++ b/error_test.go @@ -24,6 +24,7 @@ import ( "errors" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -81,3 +82,55 @@ func TestErrorCompareAndSwap(t *testing.T) { require.True(t, swapped, "Expected swapped to be true") require.Equal(t, err2, atom.Load(), "Expected Load to return overridden value") } + +func TestError_InitializeDefaults(t *testing.T) { + tests := []struct { + msg string + newError func() *Error + }{ + { + msg: "Uninitialized", + newError: func() *Error { + var e Error + return &e + }, + }, + { + msg: "NewError with default", + newError: func() *Error { + return NewError(nil) + }, + }, + { + msg: "Error swapped with default", + newError: func() *Error { + e := NewError(assert.AnError) + e.Swap(nil) + return e + }, + }, + { + msg: "Error CAS'd with default", + newError: func() *Error { + e := NewError(assert.AnError) + e.CompareAndSwap(assert.AnError, nil) + return e + }, + }, + } + + for _, tt := range tests { + t.Run(tt.msg, func(t *testing.T) { + t.Run("CompareAndSwap", func(t *testing.T) { + e := tt.newError() + require.True(t, e.CompareAndSwap(nil, assert.AnError)) + assert.Equal(t, assert.AnError, e.Load()) + }) + + t.Run("Swap", func(t *testing.T) { + e := tt.newError() + assert.Equal(t, nil, e.Swap(assert.AnError)) + }) + }) + } +} diff --git a/float32.go b/float32.go index 5d535a6..62c3633 100644 --- a/float32.go +++ b/float32.go @@ -1,6 +1,6 @@ // @generated Code generated by gen-atomicwrapper. -// Copyright (c) 2020-2022 Uber Technologies, Inc. +// Copyright (c) 2020-2023 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal diff --git a/float64.go b/float64.go index 11d5189..5bc11ca 100644 --- a/float64.go +++ b/float64.go @@ -1,6 +1,6 @@ // @generated Code generated by gen-atomicwrapper. -// Copyright (c) 2020-2022 Uber Technologies, Inc. +// Copyright (c) 2020-2023 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal diff --git a/int32.go b/int32.go index b9a68f4..5320eac 100644 --- a/int32.go +++ b/int32.go @@ -1,6 +1,6 @@ // @generated Code generated by gen-atomicint. -// Copyright (c) 2020-2022 Uber Technologies, Inc. +// Copyright (c) 2020-2023 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal diff --git a/int64.go b/int64.go index 78d2609..460821d 100644 --- a/int64.go +++ b/int64.go @@ -1,6 +1,6 @@ // @generated Code generated by gen-atomicint. -// Copyright (c) 2020-2022 Uber Technologies, Inc. +// Copyright (c) 2020-2023 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal diff --git a/internal/gen-atomicwrapper/main.go b/internal/gen-atomicwrapper/main.go index aec93b8..982254d 100644 --- a/internal/gen-atomicwrapper/main.go +++ b/internal/gen-atomicwrapper/main.go @@ -47,9 +47,6 @@ // // The packing/unpacking logic allows the stored value to be different from // the user-facing value. -// -// Without -pack and -unpack, the output will be cast to the target type, -// defaulting to the zero value. package main import ( @@ -143,12 +140,12 @@ func run(args []string) error { return err } - if len(opts.Name) == 0 || len(opts.Wrapped) == 0 || len(opts.Type) == 0 { - return errors.New("flags -name, -wrapped, and -type are required") - } - - if (len(opts.Pack) == 0) != (len(opts.Unpack) == 0) { - return errors.New("either both, or neither of -pack and -unpack must be specified") + if len(opts.Name) == 0 || + len(opts.Wrapped) == 0 || + len(opts.Type) == 0 || + len(opts.Pack) == 0 || + len(opts.Unpack) == 0 { + return errors.New("flags -name, -wrapped, -pack, -unpack and -type are required") } if opts.CAS { diff --git a/internal/gen-atomicwrapper/wrapper.tmpl b/internal/gen-atomicwrapper/wrapper.tmpl index 2e078e2..47e0253 100644 --- a/internal/gen-atomicwrapper/wrapper.tmpl +++ b/internal/gen-atomicwrapper/wrapper.tmpl @@ -61,11 +61,7 @@ func (x *{{ .Name }}) Load() {{ .Type }} { // Store atomically stores the passed {{ .Type }}. func (x *{{ .Name }}) Store(val {{ .Type }}) { - {{ if .Pack -}} - x.v.Store({{ .Pack }}(val)) - {{- else -}} - x.v.Store(val) - {{- end }} + x.v.Store({{ .Pack }}(val)) } {{ if .CAS -}} @@ -80,10 +76,20 @@ func (x *{{ .Name }}) Store(val {{ .Type }}) { {{ if .CompareAndSwap -}} // CompareAndSwap is an atomic compare-and-swap for {{ .Type }} values. func (x *{{ .Name }}) CompareAndSwap(old, new {{ .Type }}) (swapped bool) { - {{ if .Pack -}} + {{ if eq .Wrapped "Value" -}} + if x.v.CompareAndSwap({{ .Pack }}(old), {{ .Pack }}(new)) { + return true + } + + if old == _zero{{ .Name }} { + // If the old value is the empty value, then it's possible the + // underlying Value hasn't been set and is nil, so retry with nil. + return x.v.CompareAndSwap(nil, {{ .Pack }}(new)) + } + + return false + {{- else -}} return x.v.CompareAndSwap({{ .Pack }}(old), {{ .Pack }}(new)) - {{- else -}}{{- /* assume go.uber.org/atomic.Value */ -}} - return x.v.CompareAndSwap(old, new) {{- end }} } {{- end }} @@ -92,11 +98,7 @@ func (x *{{ .Name }}) Store(val {{ .Type }}) { // Swap atomically stores the given {{ .Type }} and returns the old // value. func (x *{{ .Name }}) Swap(val {{ .Type }}) (old {{ .Type }}) { - {{ if .Pack -}} - return {{ .Unpack }}(x.v.Swap({{ .Pack }}(val))) - {{- else -}}{{- /* assume go.uber.org/atomic.Value */ -}} - return x.v.Swap(val).({{ .Type }}) - {{- end }} + return {{ .Unpack }}(x.v.Swap({{ .Pack }}(val))) } {{- end }} diff --git a/string.go b/string.go index c4bea70..061466c 100644 --- a/string.go +++ b/string.go @@ -1,6 +1,6 @@ // @generated Code generated by gen-atomicwrapper. -// Copyright (c) 2020-2022 Uber Technologies, Inc. +// Copyright (c) 2020-2023 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -42,24 +42,31 @@ func NewString(val string) *String { // Load atomically loads the wrapped string. func (x *String) Load() string { - if v := x.v.Load(); v != nil { - return v.(string) - } - return _zeroString + return unpackString(x.v.Load()) } // Store atomically stores the passed string. func (x *String) Store(val string) { - x.v.Store(val) + x.v.Store(packString(val)) } // CompareAndSwap is an atomic compare-and-swap for string values. func (x *String) CompareAndSwap(old, new string) (swapped bool) { - return x.v.CompareAndSwap(old, new) + if x.v.CompareAndSwap(packString(old), packString(new)) { + return true + } + + if old == _zeroString { + // If the old value is the empty value, then it's possible the + // underlying Value hasn't been set and is nil, so retry with nil. + return x.v.CompareAndSwap(nil, packString(new)) + } + + return false } // Swap atomically stores the given string and returns the old // value. func (x *String) Swap(val string) (old string) { - return x.v.Swap(val).(string) + return unpackString(x.v.Swap(packString(val))) } diff --git a/string_ext.go b/string_ext.go index 1f63dfd..019109c 100644 --- a/string_ext.go +++ b/string_ext.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020-2022 Uber Technologies, Inc. +// Copyright (c) 2020-2023 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,7 +20,18 @@ package atomic -//go:generate bin/gen-atomicwrapper -name=String -type=string -wrapped=Value -compareandswap -swap -file=string.go +//go:generate bin/gen-atomicwrapper -name=String -type=string -wrapped Value -pack packString -unpack unpackString -compareandswap -swap -file=string.go + +func packString(s string) interface{} { + return s +} + +func unpackString(v interface{}) string { + if s, ok := v.(string); ok { + return s + } + return "" +} // String returns the wrapped value. func (s *String) String() string { diff --git a/string_test.go b/string_test.go index 54121d1..3fb5fc1 100644 --- a/string_test.go +++ b/string_test.go @@ -103,3 +103,67 @@ func TestString(t *testing.T) { require.Equal(t, atom.Load(), "bar", "Load returned wrong value") }) } + +func TestString_InitializeDefault(t *testing.T) { + tests := []struct { + msg string + newStr func() *String + }{ + { + msg: "Uninitialized", + newStr: func() *String { + var s String + return &s + }, + }, + { + msg: "NewString with default", + newStr: func() *String { + return NewString("") + }, + }, + { + msg: "String swapped with default", + newStr: func() *String { + s := NewString("initial") + s.Swap("") + return s + }, + }, + { + msg: "String CAS'd with default", + newStr: func() *String { + s := NewString("initial") + s.CompareAndSwap("initial", "") + return s + }, + }, + } + + for _, tt := range tests { + t.Run(tt.msg, func(t *testing.T) { + t.Run("MarshalText", func(t *testing.T) { + str := tt.newStr() + text, err := str.MarshalText() + require.NoError(t, err) + assert.Equal(t, "", string(text), "") + }) + + t.Run("String", func(t *testing.T) { + str := tt.newStr() + assert.Equal(t, "", str.String()) + }) + + t.Run("CompareAndSwap", func(t *testing.T) { + str := tt.newStr() + require.True(t, str.CompareAndSwap("", "new")) + assert.Equal(t, "new", str.Load()) + }) + + t.Run("Swap", func(t *testing.T) { + str := tt.newStr() + assert.Equal(t, "", str.Swap("new")) + }) + }) + } +} diff --git a/time.go b/time.go index 1660feb..cc2a230 100644 --- a/time.go +++ b/time.go @@ -1,6 +1,6 @@ // @generated Code generated by gen-atomicwrapper. -// Copyright (c) 2020-2022 Uber Technologies, Inc. +// Copyright (c) 2020-2023 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal diff --git a/uint32.go b/uint32.go index d6f04a9..4adc294 100644 --- a/uint32.go +++ b/uint32.go @@ -1,6 +1,6 @@ // @generated Code generated by gen-atomicint. -// Copyright (c) 2020-2022 Uber Technologies, Inc. +// Copyright (c) 2020-2023 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal diff --git a/uint64.go b/uint64.go index 2574bdd..0e2eddb 100644 --- a/uint64.go +++ b/uint64.go @@ -1,6 +1,6 @@ // @generated Code generated by gen-atomicint. -// Copyright (c) 2020-2022 Uber Technologies, Inc. +// Copyright (c) 2020-2023 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal diff --git a/uintptr.go b/uintptr.go index 81b275a..7d5b000 100644 --- a/uintptr.go +++ b/uintptr.go @@ -1,6 +1,6 @@ // @generated Code generated by gen-atomicint. -// Copyright (c) 2020-2022 Uber Technologies, Inc. +// Copyright (c) 2020-2023 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal