diff --git a/cache/cache.go b/cache/cache.go new file mode 100644 index 0000000..21747dd --- /dev/null +++ b/cache/cache.go @@ -0,0 +1,106 @@ +package cache + +import ( + "time" + + "github.com/Code-Hex/synchro" + "github.com/Code-Hex/synchro/tz" +) + +type CacheStorage[T comparable] interface { + Set(key string, value T) + Get(key string) (T, bool) + Delete(key string) + BulkSet(data map[string]T) + BulkGet(keys []string) map[string]T + BulkDelete(keys []string) + GetAll(mustIncludeKeys ...string) map[string]T + DeleteAll() +} + +type cacheStorage[T comparable] struct { + data map[string]T + exipresAt synchro.Time[tz.AsiaTokyo] + exiresIn time.Duration +} + +func NewCacheStorage[T comparable](exiresIn time.Duration) CacheStorage[T] { + return &cacheStorage[T]{ + data: make(map[string]T), + exipresAt: synchro.Now[tz.AsiaTokyo](), + exiresIn: exiresIn, + } +} + +func (c *cacheStorage[T]) Set(key string, value T) { + c.data[key] = value + c.exipresAt = synchro.Now[tz.AsiaTokyo]().Add(c.exiresIn) +} + +func (c *cacheStorage[T]) Get(key string) (T, bool) { + if synchro.Now[tz.AsiaTokyo]().After(c.exipresAt) { + c.DeleteAll() + var v T + return v, false + } + + value, ok := c.data[key] + return value, ok +} + +func (c *cacheStorage[T]) Delete(key string) { + delete(c.data, key) + c.exipresAt = synchro.Now[tz.AsiaTokyo]().Add(c.exiresIn) +} + +func (c *cacheStorage[T]) BulkSet(data map[string]T) { + for k, v := range data { + c.data[k] = v + } + c.exipresAt = synchro.Now[tz.AsiaTokyo]().Add(c.exiresIn) +} + +func (c *cacheStorage[T]) BulkGet(keys []string) map[string]T { + if synchro.Now[tz.AsiaTokyo]().After(c.exipresAt) { + c.DeleteAll() + return make(map[string]T, len(keys)) + } + data := make(map[string]T) + for _, k := range keys { + if v, ok := c.data[k]; ok { + data[k] = v + } + } + return data +} + +func (c *cacheStorage[T]) BulkDelete(keys []string) { + for _, k := range keys { + delete(c.data, k) + } + c.exipresAt = synchro.Now[tz.AsiaTokyo]().Add(c.exiresIn) +} + +func (c *cacheStorage[T]) GetAll(mustIncludeKeys ...string) map[string]T { + include := true + if len(mustIncludeKeys) > 0 { + for _, k := range mustIncludeKeys { + if _, ok := c.data[k]; !ok { + include = false + break + } + } + } + + if synchro.Now[tz.AsiaTokyo]().After(c.exipresAt) || !include { + c.DeleteAll() + return make(map[string]T) + } + + return c.data +} + +func (c *cacheStorage[T]) DeleteAll() { + c.data = make(map[string]T) + c.exipresAt = synchro.Now[tz.AsiaTokyo]().Add(c.exiresIn) +} diff --git a/slack/slack.go b/slack/slack.go index f7aef0e..268134d 100644 --- a/slack/slack.go +++ b/slack/slack.go @@ -11,6 +11,7 @@ import ( "github.com/Code-Hex/synchro" "github.com/Code-Hex/synchro/tz" + "github.com/kmc-jp/inviteallmcg/cache" "github.com/kmc-jp/inviteallmcg/config" "github.com/slack-go/slack" "github.com/slack-go/slack/slackevents" @@ -25,16 +26,13 @@ type Client struct { cacheDuration time.Duration - prefixedChannelCache map[string]map[string]string - prefixedChannelCacheExpiresAt map[string]synchro.Time[tz.AsiaTokyo] - - mcgMemberCache map[string]struct{} - mcgMemberCacheExpiresAt synchro.Time[tz.AsiaTokyo] + prefixedChannelIDCache cache.CacheStorage[string] + everythingChannelIDCache cache.CacheStorage[string] + mcgMemberCache cache.CacheStorage[struct{}] + determineYearCache cache.CacheStorage[ObservTarget] determineYearRegex *regexp.Regexp determineYearCacheDuration time.Duration - determineYearCache ObservTarget - determineYearExpiresAt synchro.Time[tz.AsiaTokyo] } type SlackLogger struct { @@ -66,10 +64,10 @@ func NewSlackClient(cfg config.Config) Client { cacheDuration: cfg.SlackCacheDuration, - prefixedChannelCache: make(map[string]map[string]string, 2), - prefixedChannelCacheExpiresAt: make(map[string]synchro.Time[tz.AsiaTokyo], 2), - - mcgMemberCache: make(map[string]struct{}, 100), + prefixedChannelIDCache: cache.NewCacheStorage[string](cfg.SlackCacheDuration), + everythingChannelIDCache: cache.NewCacheStorage[string](cfg.SlackCacheDuration), + mcgMemberCache: cache.NewCacheStorage[struct{}](cfg.SlackCacheDuration), + determineYearCache: cache.NewCacheStorage[ObservTarget](cfg.SlackDetermineYearCacheDuration), determineYearRegex: regexp.MustCompile(cfg.MCGJoinChannelRegex), determineYearCacheDuration: cfg.SlackDetermineYearCacheDuration, @@ -98,22 +96,35 @@ func (c *Client) InviteUsersToChannels(ctx context.Context, channelIDs []string, return nil } -func (c *Client) GetPrefixedChannels(ctx context.Context, prefix string, mustIncludeChannelIDs ...string) (map[string]string, error) { - now := synchro.Now[tz.AsiaTokyo]() - if cache, ok := c.prefixedChannelCache[prefix]; ok && now.Before(c.prefixedChannelCacheExpiresAt[prefix]) { - useCache := true - for _, mustIncludeChannelID := range mustIncludeChannelIDs { - if _, ok := cache[mustIncludeChannelID]; !ok { - useCache = false - break - } - } +func (c *Client) GetPrefixedEverythingChannel(ctx context.Context, prefix string) (string, error) { + const cacheKey = "key" + cache, ok := c.everythingChannelIDCache.Get(cacheKey) + if ok { + return cache, nil + } - if useCache { - return cache, nil + channels, err := c.GetPublicChannels(ctx) + if err != nil { + return "", err + } + + for _, channel := range channels { + if channel.Name == fmt.Sprintf("%s-everything", prefix) { + c.everythingChannelIDCache.Set(prefix, channel.ID) + return channel.ID, nil } } + return "", fmt.Errorf("everything channel not found") +} + +// key: channelID, value: channelName +func (c *Client) GetPrefixedChannels(ctx context.Context, prefix string, mustIncludeChannelIDs ...string) (map[string]string, error) { + caches := c.prefixedChannelIDCache.GetAll(mustIncludeChannelIDs...) + if len(caches) > 0 { + return caches, nil + } + channels, err := c.GetPublicChannels(ctx) if err != nil { return nil, err @@ -126,9 +137,7 @@ func (c *Client) GetPrefixedChannels(ctx context.Context, prefix string, mustInc } } - c.prefixedChannelCache[prefix] = prefixedChannels - c.prefixedChannelCacheExpiresAt[prefix] = now.Add(c.cacheDuration) - + c.prefixedChannelIDCache.BulkSet(prefixedChannels) return prefixedChannels, nil } @@ -161,16 +170,9 @@ func (c *Client) GetPublicChannels(ctx context.Context) ([]slack.Channel, error) } func (c *Client) GetAllMCGMembers(ctx context.Context, mustIncludeUsers ...string) (map[string]struct{}, error) { - now := synchro.Now[tz.AsiaTokyo]() - - var include bool - for _, mustIncludeUser := range mustIncludeUsers { - _, ok := c.mcgMemberCache[mustIncludeUser] - include = include || ok - } - - if c.mcgMemberCache != nil && now.Before(c.mcgMemberCacheExpiresAt) && include { - return c.mcgMemberCache, nil + cache := c.mcgMemberCache.GetAll(mustIncludeUsers...) + if len(cache) > 0 { + return cache, nil } users, err := c.slackUserClient.GetUsersContext(ctx) @@ -186,9 +188,7 @@ func (c *Client) GetAllMCGMembers(ctx context.Context, mustIncludeUsers ...strin } } - c.mcgMemberCache = mcgMembers - c.mcgMemberCacheExpiresAt = now.Add(c.cacheDuration) - + c.mcgMemberCache.BulkSet(mcgMembers) return mcgMembers, nil } @@ -233,12 +233,14 @@ func (c *Client) ForwardMessage(ctx context.Context, everythingChannelID string, slog.Error("Error getting permalink", "error", err) } + mentionReplaced := strings.ReplaceAll(message.Text, "@", "@\u200B") + blocks := []slack.Block{ &slack.SectionBlock{ Type: slack.MBTSection, Text: &slack.TextBlockObject{ Type: slack.MarkdownType, - Text: fmt.Sprintf("<%s|`#%s`> %s", permalink, sourceChannelName, message.Text), + Text: fmt.Sprintf("<%s|`#%s`> %s", permalink, sourceChannelName, mentionReplaced), }, }, } @@ -312,15 +314,8 @@ func (c *Client) HandleSlackEvents(ctx context.Context) error { continue } - var everythingChannelID string - for id, name := range shinkanChannels { - if name == fmt.Sprintf("%s-everything", observTarget.year) { - everythingChannelID = id - break - } - } - - if everythingChannelID == "" { + everythingChannelID, err := c.GetPrefixedEverythingChannel(ctx, observTarget.year) + if err != nil { slog.Error("Everything channel not found", "year", observTarget.year) continue } @@ -330,13 +325,18 @@ func (c *Client) HandleSlackEvents(ctx context.Context) error { continue } - sourceChanName, ok := shinkanChannels[ev.Channel] + sourceChannelName, ok := shinkanChannels[ev.Channel] if !ok { slog.Error("Source channel not found", "channel", ev.Channel) continue } - err = c.ForwardMessage(ctx, everythingChannelID, sourceChanName, *ev) + if strings.Contains(sourceChannelName, "announce") { + slog.Info("Ignored message event from announce channel", "channel", sourceChannelName) + return nil + } + + err = c.ForwardMessage(ctx, everythingChannelID, sourceChannelName, *ev) if err != nil { slog.Error("Error forwarding message", "error", err) continue @@ -408,6 +408,7 @@ func (c *Client) HandleSlackEvents(ctx context.Context) error { mcgMembers["UQYG1JA95"] = struct{}{} } + c.prefixedChannelIDCache.Set(ev.Channel.ID, ev.Channel.Name) channels, err := c.GetPrefixedChannels(ctx, fmt.Sprintf("%s-", observTarget.year)) if err != nil { slog.Error("Error getting prefixed channels", "error", err) @@ -442,10 +443,12 @@ type ObservTarget struct { } func (c *Client) DetermineObservTarget(ctx context.Context) (ObservTarget, error) { + const cacheKey = "key" now := synchro.Now[tz.AsiaTokyo]() - if c.determineYearCache != (ObservTarget{}) && now.Before(c.determineYearExpiresAt) { - return c.determineYearCache, nil + cache, ok := c.determineYearCache.Get(cacheKey) + if ok { + return cache, nil } year := now.Year() @@ -499,8 +502,6 @@ func (c *Client) DetermineObservTarget(ctx context.Context) (ObservTarget, error slog.Debug("Determined general channel", "target", target) - c.determineYearCache = target - c.determineYearExpiresAt = now.Add(c.determineYearCacheDuration) - + c.determineYearCache.Set(cacheKey, target) return target, nil }