diff --git a/internal/freelist/array_test.go b/internal/freelist/array_test.go index 31b0702dc..4d1306102 100644 --- a/internal/freelist/array_test.go +++ b/internal/freelist/array_test.go @@ -4,6 +4,8 @@ import ( "reflect" "testing" + "github.com/stretchr/testify/require" + "go.etcd.io/bbolt/internal/common" ) @@ -50,3 +52,30 @@ func TestFreelistArray_allocate(t *testing.T) { t.Fatalf("exp=%v; got=%v", exp, f.freePageIds()) } } + +func Test_Freelist_Array_Rollback(t *testing.T) { + f := newTestArrayFreelist() + + f.Init([]common.Pgid{3, 5, 6, 7, 12, 13}) + + f.Free(100, common.NewPage(20, 0, 0, 1)) + f.Allocate(100, 3) + f.Free(100, common.NewPage(25, 0, 0, 0)) + f.Allocate(100, 2) + + require.Equal(t, map[common.Pgid]common.Txid{5: 100, 12: 100}, f.allocs) + require.Equal(t, map[common.Txid]*txPending{100: { + ids: []common.Pgid{20, 21, 25}, + alloctx: []common.Txid{0, 0, 0}, + }}, f.pending) + + f.Rollback(100) + + require.Equal(t, map[common.Pgid]common.Txid{}, f.allocs) + require.Equal(t, map[common.Txid]*txPending{}, f.pending) +} + +func newTestArrayFreelist() *array { + f := NewArrayFreelist() + return f.(*array) +} diff --git a/internal/freelist/hashmap_test.go b/internal/freelist/hashmap_test.go index 32cc5dfa0..c77a05800 100644 --- a/internal/freelist/hashmap_test.go +++ b/internal/freelist/hashmap_test.go @@ -6,6 +6,8 @@ import ( "sort" "testing" + "github.com/stretchr/testify/require" + "go.etcd.io/bbolt/internal/common" ) @@ -128,6 +130,28 @@ func TestFreelistHashmap_GetFreePageIDs(t *testing.T) { } } +func Test_Freelist_Hashmap_Rollback(t *testing.T) { + f := newTestHashMapFreelist() + + f.Init([]common.Pgid{3, 5, 6, 7, 12, 13}) + + f.Free(100, common.NewPage(20, 0, 0, 1)) + f.Allocate(100, 3) + f.Free(100, common.NewPage(25, 0, 0, 0)) + f.Allocate(100, 2) + + require.Equal(t, map[common.Pgid]common.Txid{5: 100, 12: 100}, f.allocs) + require.Equal(t, map[common.Txid]*txPending{100: { + ids: []common.Pgid{20, 21, 25}, + alloctx: []common.Txid{0, 0, 0}, + }}, f.pending) + + f.Rollback(100) + + require.Equal(t, map[common.Pgid]common.Txid{}, f.allocs) + require.Equal(t, map[common.Txid]*txPending{}, f.pending) +} + func Benchmark_freelist_hashmapGetFreePageIDs(b *testing.B) { f := newTestHashMapFreelist() N := int32(100000) diff --git a/internal/freelist/shared.go b/internal/freelist/shared.go index 16a5b3286..f2d113008 100644 --- a/internal/freelist/shared.go +++ b/internal/freelist/shared.go @@ -108,6 +108,13 @@ func (t *shared) Rollback(txid common.Txid) { } // Remove pages from pending list and mark as free if allocated by txid. delete(t.pending, txid) + + // Remove pgids which are allocated by this txid + for pgid, tid := range t.allocs { + if tid == txid { + delete(t.allocs, pgid) + } + } } func (t *shared) AddReadonlyTXID(tid common.Txid) {