From dc2a944a5da7ad703633e6a3c7ad04b5bc81fb60 Mon Sep 17 00:00:00 2001 From: RelicOfTesla Date: Mon, 15 Apr 2024 20:37:51 +0800 Subject: [PATCH] WithKeepTTL --- cache.go | 33 ++++++++++++++++++++++++++------- cache_internal_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 7 deletions(-) diff --git a/cache.go b/cache.go index e9f6dea..837a176 100644 --- a/cache.go +++ b/cache.go @@ -72,6 +72,7 @@ type ItemOption func(*itemOptions) type itemOptions struct { expiration time.Time // default none referenceCount int + keepTTL bool } // WithExpiration is an option to set expiration time for any items. @@ -82,6 +83,17 @@ func WithExpiration(exp time.Duration) ItemOption { } } +// WithKeepTTL is the expiration time to keep existing keys when calling Set. By default, it is replaced by the new time or never expires. +func WithKeepTTL(_keep ...bool) ItemOption { + return func(o *itemOptions) { + if len(_keep) > 0 { + o.keepTTL = _keep[0] + } else { + o.keepTTL = true + } + } +} + // WithReferenceCount is an option to set reference count for any items. // This option is only applicable to cache policies that have a reference count (e.g., Clock, LFU). // referenceCount specifies the reference count value to set for the cache item. @@ -94,7 +106,7 @@ func WithReferenceCount(referenceCount int) ItemOption { } // newItem creates a new item with specified any options. -func newItem[K comparable, V any](key K, val V, opts ...ItemOption) *Item[K, V] { +func newItem[K comparable, V any](key K, val V, opts ...ItemOption) (*Item[K, V], *itemOptions) { o := new(itemOptions) for _, optFunc := range opts { optFunc(o) @@ -104,7 +116,7 @@ func newItem[K comparable, V any](key K, val V, opts ...ItemOption) *Item[K, V] Value: val, Expiration: o.expiration, InitialReferenceCount: o.referenceCount, - } + }, o } // Cache is a thread safe cache. @@ -231,8 +243,11 @@ func (c *Cache[K, V]) GetOrSet(key K, val V, opts ...ItemOption) (actual V, load item, ok := c.cache.Get(key) if !ok || item.Expired() { - item := newItem(key, val, opts...) - c.cache.Set(key, item) + replaceItem, o := newItem(key, val, opts...) + if o.keepTTL && ok && !replaceItem.hasExpiration() { + replaceItem.Expiration = item.Expiration + } + c.cache.Set(key, replaceItem) return val, false } @@ -273,9 +288,13 @@ func (c *Cache[K, V]) DeleteExpired() { func (c *Cache[K, V]) Set(key K, val V, opts ...ItemOption) { c.mu.Lock() defer c.mu.Unlock() - item := newItem(key, val, opts...) + item, o := newItem(key, val, opts...) if item.hasExpiration() { c.expManager.update(key, item.Expiration) + } else if o.keepTTL { + if old, has := c.cache.Get(key); has { + item.Expiration = old.Expiration + } } c.cache.Set(key, item) } @@ -334,7 +353,7 @@ func (nc *NumberCache[K, V]) Increment(key K, n V) V { defer nc.nmu.Unlock() got, _ := nc.Cache.Get(key) nv := got + n - nc.Cache.Set(key, nv) + nc.Cache.Set(key, nv, WithKeepTTL()) return nv } @@ -345,6 +364,6 @@ func (nc *NumberCache[K, V]) Decrement(key K, n V) V { defer nc.nmu.Unlock() got, _ := nc.Cache.Get(key) nv := got - n - nc.Cache.Set(key, nv) + nc.Cache.Set(key, nv, WithKeepTTL()) return nv } diff --git a/cache_internal_test.go b/cache_internal_test.go index 580b726..17116bd 100644 --- a/cache_internal_test.go +++ b/cache_internal_test.go @@ -143,6 +143,48 @@ func TestDeleteExpired(t *testing.T) { t.Errorf("want %d items but got %d", want, got) } }) + + t.Run("keepTTL", func(t *testing.T) { + + t.Run("incr must keep ttl", func(t *testing.T) { + defer restore() + c := NewNumber[string, int]() + + c.Set("1", 10, WithExpiration(10*time.Millisecond)) + c.Increment("1", 20) + nowFunc = func() time.Time { + return now.Add(30 * time.Millisecond).Add(time.Millisecond) + } + c.DeleteExpired() + if c.Len() != 0 { + t.Fail() + } + }) + + testCase := func(t *testing.T, wantN int, opts ...ItemOption) { + defer restore() + c := NewNumber[string, int]() + + c.Set("1", 10, WithExpiration(10*time.Millisecond)) + c.Set("1", 20, opts...) + nowFunc = func() time.Time { + return now.Add(300 * time.Millisecond).Add(time.Millisecond) + } + c.DeleteExpired() + if c.Len() != wantN { + t.Fail() + } + } + t.Run("must forever when default set", func(t *testing.T) { + testCase(t, 1) + }) + t.Run("must expired when KeepTTL=true", func(t *testing.T) { + testCase(t, 0, WithKeepTTL()) + }) + t.Run("must forever when KeepTTL=false", func(t *testing.T) { + testCase(t, 1, WithKeepTTL(false)) + }) + }) } func max(x, y int) int {