Skip to content

Commit

Permalink
cache: Use generic and context (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
sunshineplan authored Jun 11, 2024
1 parent a60b344 commit 66f3f16
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 149 deletions.
182 changes: 82 additions & 100 deletions cache/cache.go
Original file line number Diff line number Diff line change
@@ -1,137 +1,119 @@
package cache

import (
"context"
"log"
"sync"
"time"
)

type item struct {
sync.Mutex
Value any
Duration time.Duration
Expiration int64
Regenerate func() (any, error)
}
var valueKey int

func (i *item) Expired() bool {
if i.Duration == 0 {
return false
func newContext(value any, lifecycle time.Duration) (ctx context.Context, cancel context.CancelFunc) {
ctx = context.WithValue(context.Background(), &valueKey, value)
if lifecycle > 0 {
ctx, cancel = context.WithTimeout(ctx, lifecycle)
}

return time.Now().UnixNano() > i.Expiration
return
}

// Cache is cache struct.
type Cache struct {
cache sync.Map
autoClean bool
type item[T any] struct {
sync.Mutex
context.Context
cancel context.CancelFunc
lifecycle time.Duration
fn func() (T, error)
}

// New creates a new cache with auto clean or not.
func New(autoClean bool) *Cache {
c := &Cache{autoClean: autoClean}
func (i *item[T]) value() T {
i.Lock()
defer i.Unlock()
return i.Value(&valueKey).(T)
}

if autoClean {
go c.check()
func (i *item[T]) renew() T {
v, err := i.fn()
if err != nil {
log.Print(err)
v = i.value()
}

return c
i.Lock()
defer i.Unlock()
i.Context, i.cancel = newContext(v, i.lifecycle)
return v
}

// Set sets cache value for a key, if f is presented, this value will regenerate when expired.
func (c *Cache) Set(key, value any, d time.Duration, f func() (any, error)) {
c.cache.Store(key, &item{
Value: value,
Duration: d,
Expiration: time.Now().Add(d).UnixNano(),
Regenerate: f,
})
// Cache is cache struct.
type Cache[Key, Value any] struct {
cache sync.Map
autoRenew bool
}

func (c *Cache) regenerate(i *item) {
i.Expiration = 0
i.Unlock()

go func() {
value, err := i.Regenerate()

i.Lock()
defer i.Unlock()
// New creates a new cache with auto clean or not.
func New[Key, Value any](autoRenew bool) *Cache[Key, Value] {
return &Cache[Key, Value]{autoRenew: autoRenew}
}

if err != nil {
log.Print(err)
} else {
i.Value = value
}
i.Expiration = time.Now().Add(i.Duration).UnixNano()
}()
// Set sets cache value for a key, if fn is presented, this value will regenerate when expired.
func (c *Cache[Key, Value]) Set(key Key, value Value, lifecycle time.Duration, fn func() (Value, error)) {
i := &item[Value]{lifecycle: lifecycle, fn: fn}
i.Context, i.cancel = newContext(value, lifecycle)
if c.autoRenew && lifecycle > 0 {
go func() {
for {
<-i.Done()
if err := i.Err(); err == context.DeadlineExceeded {
if i.fn != nil {
i.renew()
} else {
c.Delete(key)
}
} else {
return
}
}
}()
}
c.cache.Store(key, i)
}

// Get gets cache value by key and whether value was found.
func (c *Cache) Get(key any) (any, bool) {
value, ok := c.cache.Load(key)
func (c *Cache[Key, Value]) Get(key Key) (Value, bool) {
v, ok := c.cache.Load(key)
if !ok {
return nil, false
return *new(Value), false
}

i := value.(*item)
i.Lock()

if i.Expired() && !c.autoClean {
if i.Regenerate == nil {
c.cache.Delete(key)
i.Unlock()

return nil, false
if i := v.(*item[Value]); !c.autoRenew && i.Err() == context.DeadlineExceeded {
if i.fn == nil {
c.Delete(key)
return *new(Value), false
}

defer c.regenerate(i)

return i.Value, true
return i.renew(), true
} else {
return i.value(), true
}

i.Unlock()

return i.Value, true
}

// Delete deletes the value for a key.
func (c *Cache) Delete(key any) {
c.cache.Delete(key)
func (c *Cache[Key, Value]) Delete(key Key) {
if v, ok := c.cache.LoadAndDelete(key); ok {
if v, ok := v.(*item[Value]); ok {
if v.cancel != nil {
v.cancel()
}
}
}
}

// Empty deletes all values in cache.
func (c *Cache) Empty() {
c.cache.Range(func(key, _ any) bool {
c.cache.Delete(key)
func (c *Cache[Key, Value]) Empty() {
c.cache.Range(func(k, v any) bool {
c.cache.Delete(k)
if v, ok := v.(*item[Value]); ok {
if v.cancel != nil {
v.cancel()
}
}
return true
})
}

func (c *Cache) check() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()

for range ticker.C {
c.cache.Range(func(key, value any) bool {
i := value.(*item)
i.Lock()

if i.Expired() {
if i.Regenerate == nil {
c.cache.Delete(key)
i.Unlock()
} else {
defer c.regenerate(i)
}

return true
}

i.Unlock()

return true
})
}
}
62 changes: 20 additions & 42 deletions cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,89 +6,67 @@ import (
)

func TestSetGetDelete(t *testing.T) {
cache := New(false)

cache := New[string, string](false)
cache.Set("key", "value", 0, nil)

value, ok := cache.Get("key")
if !ok {
if value, ok := cache.Get("key"); !ok {
t.Fatal("expected ok; got not")
}
if value != "value" {
} else if value != "value" {
t.Errorf("expected value; got %q", value)
}

cache.Delete("key")
_, ok = cache.Get("key")
if ok {
if _, ok := cache.Get("key"); ok {
t.Error("expected not ok; got ok")
}
}

func TestEmpty(t *testing.T) {
cache := New(false)

cache := New[string, int](false)
cache.Set("a", 1, 0, nil)
cache.Set("b", 2, 0, nil)
cache.Set("c", 3, 0, nil)

for _, i := range []string{"a", "b", "c"} {
_, ok := cache.Get(i)
if !ok {
if _, ok := cache.Get(i); !ok {
t.Error("expected ok; got not")
}
}

cache.Empty()

for _, i := range []string{"a", "b", "c"} {
_, ok := cache.Get(i)
if ok {
if _, ok := cache.Get(i); ok {
t.Error("expected not ok; got ok")
}
}
}

func TestAutoCleanRegenerate(t *testing.T) {
cache := New(true)

done := make(chan bool)
cache.Set("regenerate", "old", 2*time.Second, func() (any, error) {
defer func() { done <- true }()
func TestRenew(t *testing.T) {
cache := New[string, string](true)
expire := make(chan struct{})
cache.Set("renew", "old", 2*time.Second, func() (string, error) {
defer func() { close(expire) }()
return "new", nil
})
cache.Set("expire", "value", 1*time.Second, nil)

value, ok := cache.Get("expire")
if !ok {
if value, ok := cache.Get("expire"); !ok {
t.Fatal("expected ok; got not")
}
if expect := "value"; value != expect {
} else if expect := "value"; value != expect {
t.Errorf("expected %q; got %q", expect, value)
}

value, ok = cache.Get("regenerate")
if !ok {
if value, ok := cache.Get("renew"); !ok {
t.Fatal("expected ok; got not")
}
if expect := "old"; value != expect {
} else if expect := "old"; value != expect {
t.Errorf("expected %q; got %q", expect, value)
}

ticker := time.NewTicker(4 * time.Second)
defer ticker.Stop()

select {
case <-done:
case <-expire:
time.Sleep(100 * time.Millisecond)
if _, ok := cache.Get("expire"); ok {
t.Error("expected not ok; got ok")
}

value, ok := cache.Get("regenerate")
value, ok := cache.Get("renew")
if !ok {
t.Fatal("expected ok; got not")
}
if expect := "new"; value != expect {
} else if expect := "new"; value != expect {
t.Errorf("expected %q; got %q", expect, value)
}
case <-ticker.C:
Expand Down
4 changes: 1 addition & 3 deletions counter/listener.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package counter

import (
"net"
)
import "net"

var (
_ net.Listener = &Listener{}
Expand Down
8 changes: 4 additions & 4 deletions httpsvr/httpsvr.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
"github.com/sunshineplan/utils/log"
)

var certCache = cache.New(false)
var certCache = cache.New[string, *tls.Certificate](false)

var defaultReload = 24 * time.Hour

Expand Down Expand Up @@ -124,7 +124,7 @@ func (s *Server) run() error {
return nil
}

func (s *Server) loadCertificate() (any, error) {
func (s *Server) loadCertificate() (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(s.certFile, s.keyFile)
if err != nil {
return nil, err
Expand All @@ -135,7 +135,7 @@ func (s *Server) loadCertificate() (any, error) {
func (s *Server) getCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
v, ok := certCache.Get("cert")
if ok {
return v.(*tls.Certificate), nil
return v, nil
}

cert, err := s.loadCertificate()
Expand All @@ -148,7 +148,7 @@ func (s *Server) getCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error
}
certCache.Set("cert", cert, s.reload, s.loadCertificate)

return cert.(*tls.Certificate), nil
return cert, nil
}

// Run runs an HTTP server which can be gracefully shut down.
Expand Down

0 comments on commit 66f3f16

Please sign in to comment.