Skip to content

Commit

Permalink
refactor: extract broadcast and pubsub configs
Browse files Browse the repository at this point in the history
  • Loading branch information
palkan committed Oct 9, 2024
1 parent 8ea4e9e commit 61f6b05
Show file tree
Hide file tree
Showing 17 changed files with 392 additions and 145 deletions.
32 changes: 27 additions & 5 deletions broadcast/legacy_nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,46 @@ package broadcast

import (
"context"
"fmt"
"log/slog"
"strings"

"github.com/nats-io/nats.go"

nconfig "github.com/anycable/anycable-go/nats"
)

type LegacyNATSConfig struct {
Channel string `toml:"channel"`
NATS *nconfig.NATSConfig `toml:"nats"`
}

func NewLegacyNATSConfig() LegacyNATSConfig {
return LegacyNATSConfig{
Channel: "__anycable__",
}
}

func (c LegacyNATSConfig) ToToml() string {
var result strings.Builder
result.WriteString(fmt.Sprintf("channel = \"%s\"\n", c.Channel))

result.WriteString("\n")

return result.String()
}

type LegacyNATSBroadcaster struct {
conn *nats.Conn
handler Handler
config *nconfig.NATSConfig
config *LegacyNATSConfig

log *slog.Logger
}

var _ Broadcaster = (*LegacyNATSBroadcaster)(nil)

func NewLegacyNATSBroadcaster(node Handler, c *nconfig.NATSConfig, l *slog.Logger) *LegacyNATSBroadcaster {
func NewLegacyNATSBroadcaster(node Handler, c *LegacyNATSConfig, l *slog.Logger) *LegacyNATSBroadcaster {
return &LegacyNATSBroadcaster{
config: c,
handler: node,
Expand All @@ -34,7 +56,7 @@ func (LegacyNATSBroadcaster) IsFanout() bool {
func (s *LegacyNATSBroadcaster) Start(done chan (error)) error {
connectOptions := []nats.Option{
nats.RetryOnFailedConnect(true),
nats.MaxReconnects(s.config.MaxReconnectAttempts),
nats.MaxReconnects(s.config.NATS.MaxReconnectAttempts),
nats.DisconnectErrHandler(func(nc *nats.Conn, err error) {
if err != nil {
s.log.Warn("connection failed", "error", err)
Expand All @@ -45,11 +67,11 @@ func (s *LegacyNATSBroadcaster) Start(done chan (error)) error {
}),
}

if s.config.DontRandomizeServers {
if s.config.NATS.DontRandomizeServers {
connectOptions = append(connectOptions, nats.DontRandomize())
}

nc, err := nats.Connect(s.config.Servers, connectOptions...)
nc, err := nats.Connect(s.config.NATS.Servers, connectOptions...)

if err != nil {
return err
Expand Down
26 changes: 26 additions & 0 deletions broadcast/legacy_nats_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package broadcast

import (
"testing"

"github.com/BurntSushi/toml"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestLegacyNATSConfig__ToToml(t *testing.T) {
conf := NewLegacyNATSConfig()
conf.Channel = "_test_"

tomlStr := conf.ToToml()

assert.Contains(t, tomlStr, "channel = \"_test_\"")

// Round-trip test
conf2 := NewLegacyNATSConfig()

_, err := toml.Decode(tomlStr, &conf2)
require.NoError(t, err)

assert.Equal(t, conf, conf2)
}
34 changes: 27 additions & 7 deletions broadcast/legacy_redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,26 @@ import (
"github.com/gomodule/redigo/redis"
)

type LegacyRedisConfig struct {
Channel string `toml:"channel"`
Redis *rconfig.RedisConfig `toml:"redis"`
}

func NewLegacyRedisConfig() LegacyRedisConfig {
return LegacyRedisConfig{
Channel: "__anycable__",
}
}

func (c LegacyRedisConfig) ToToml() string {
var result strings.Builder
result.WriteString(fmt.Sprintf("channel = \"%s\"\n", c.Channel))

result.WriteString("\n")

return result.String()
}

// LegacyRedisBroadcaster contains information about Redis pubsub connection
type LegacyRedisBroadcaster struct {
node Handler
Expand All @@ -33,18 +53,18 @@ type LegacyRedisBroadcaster struct {
}

// NewLegacyRedisBroadcaster returns new RedisSubscriber struct
func NewLegacyRedisBroadcaster(node Handler, config *rconfig.RedisConfig, l *slog.Logger) *LegacyRedisBroadcaster {
func NewLegacyRedisBroadcaster(node Handler, config *LegacyRedisConfig, l *slog.Logger) *LegacyRedisBroadcaster {
return &LegacyRedisBroadcaster{
node: node,
url: config.URL,
sentinels: config.Sentinels,
sentinelDiscoveryInterval: time.Duration(config.SentinelDiscoveryInterval),
url: config.Redis.URL,
sentinels: config.Redis.Sentinels,
sentinelDiscoveryInterval: time.Duration(config.Redis.SentinelDiscoveryInterval),
channel: config.Channel,
pingInterval: time.Duration(config.KeepalivePingInterval),
pingInterval: time.Duration(config.Redis.KeepalivePingInterval),
reconnectAttempt: 0,
maxReconnectAttempts: config.MaxReconnectAttempts,
maxReconnectAttempts: config.Redis.MaxReconnectAttempts,
log: l.With("context", "broadcast").With("provider", "redis"),
tlsVerify: config.TLSVerify,
tlsVerify: config.Redis.TLSVerify,
}
}

Expand Down
26 changes: 26 additions & 0 deletions broadcast/legacy_redis_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package broadcast

import (
"testing"

"github.com/BurntSushi/toml"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestLegacyRedisConfig__ToToml(t *testing.T) {
conf := NewLegacyRedisConfig()
conf.Channel = "_test_"

tomlStr := conf.ToToml()

assert.Contains(t, tomlStr, "channel = \"_test_\"")

// Round-trip test
conf2 := NewLegacyRedisConfig()

_, err := toml.Decode(tomlStr, &conf2)
require.NoError(t, err)

assert.Equal(t, conf, conf2)
}
66 changes: 50 additions & 16 deletions broadcast/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,44 @@ import (
"github.com/redis/rueidis"
)

type RedisConfig struct {
Stream string `toml:"stream"`
Group string `toml:"group"`
// Redis stream read wait time in milliseconds
StreamReadBlockMilliseconds int64 `toml:"stream_read_block_milliseconds"`

Redis *rconfig.RedisConfig `toml:"redis"`
}

func NewRedisConfig() RedisConfig {
return RedisConfig{
Stream: "__anycable__",
Group: "bx",
StreamReadBlockMilliseconds: 2000,
}
}

func (c RedisConfig) ToToml() string {
var result strings.Builder

result.WriteString("# Redis stream name for broadcasts\n")
result.WriteString(fmt.Sprintf("stream = \"%s\"\n", c.Stream))

result.WriteString("# Stream consumer group name\n")
result.WriteString(fmt.Sprintf("group = \"%s\"\n", c.Group))

result.WriteString("# Streams read wait time in milliseconds\n")
result.WriteString(fmt.Sprintf("stream_read_block_milliseconds = %d\n", c.StreamReadBlockMilliseconds))

result.WriteString("\n")

return result.String()
}

// RedisBroadcaster represents Redis broadcaster using Redis streams
type RedisBroadcaster struct {
node Handler
config *rconfig.RedisConfig
config *RedisConfig

// Unique consumer identifier
consumerName string
Expand All @@ -39,7 +73,7 @@ type RedisBroadcaster struct {
var _ Broadcaster = (*RedisBroadcaster)(nil)

// NewRedisBroadcaster builds a new RedisSubscriber struct
func NewRedisBroadcaster(node Handler, config *rconfig.RedisConfig, l *slog.Logger) *RedisBroadcaster {
func NewRedisBroadcaster(node Handler, config *RedisConfig, l *slog.Logger) *RedisBroadcaster {
name, _ := nanoid.Nanoid(6)

return &RedisBroadcaster{
Expand All @@ -57,18 +91,18 @@ func (s *RedisBroadcaster) IsFanout() bool {
}

func (s *RedisBroadcaster) Start(done chan error) error {
options, err := s.config.ToRueidisOptions()
options, err := s.config.Redis.ToRueidisOptions()

if err != nil {
return err
}

if s.config.IsSentinel() { //nolint:gocritic
s.log.With("stream", s.config.Channel).With("consumer", s.consumerName).Info(fmt.Sprintf("Starting Redis broadcaster at %v (sentinels)", s.config.Hostnames()))
} else if s.config.IsCluster() {
s.log.With("stream", s.config.Channel).With("consumer", s.consumerName).Info(fmt.Sprintf("Starting Redis broadcaster at %v (cluster)", s.config.Hostnames()))
if s.config.Redis.IsSentinel() { //nolint:gocritic
s.log.With("stream", s.config.Stream).With("consumer", s.consumerName).Info(fmt.Sprintf("Starting Redis broadcaster at %v (sentinels)", s.config.Redis.Hostnames()))
} else if s.config.Redis.IsCluster() {
s.log.With("stream", s.config.Stream).With("consumer", s.consumerName).Info(fmt.Sprintf("Starting Redis broadcaster at %v (cluster)", s.config.Redis.Hostnames()))
} else {
s.log.With("stream", s.config.Channel).With("consumer", s.consumerName).Info(fmt.Sprintf("Starting Redis broadcaster at %s", s.config.Hostname()))
s.log.With("stream", s.config.Stream).With("consumer", s.consumerName).Info(fmt.Sprintf("Starting Redis broadcaster at %s", s.config.Redis.Hostname()))
}

s.clientOptions = options
Expand All @@ -94,7 +128,7 @@ func (s *RedisBroadcaster) Shutdown(ctx context.Context) error {

res := s.client.Do(
context.Background(),
s.client.B().XgroupDelconsumer().Key(s.config.Channel).Group(s.config.Group).Consumername(s.consumerName).Build(),
s.client.B().XgroupDelconsumer().Key(s.config.Stream).Group(s.config.Group).Consumername(s.consumerName).Build(),
)

err := res.Error()
Expand Down Expand Up @@ -144,7 +178,7 @@ func (s *RedisBroadcaster) runReader(done chan (error)) {

// First, create a consumer group for the stream
err = s.client.Do(context.Background(),
s.client.B().XgroupCreate().Key(s.config.Channel).Group(s.config.Group).Id("$").Mkstream().Build(),
s.client.B().XgroupCreate().Key(s.config.Stream).Group(s.config.Group).Id("$").Mkstream().Build(),
).Error()

if err != nil {
Expand Down Expand Up @@ -204,7 +238,7 @@ func (s *RedisBroadcaster) runReader(done chan (error)) {

func (s *RedisBroadcaster) readFromStream(blockTime int64) ([]rueidis.XRangeEntry, error) {
streamRes := s.client.Do(context.Background(),
s.client.B().Xreadgroup().Group(s.config.Group, s.consumerName).Block(blockTime).Streams().Key(s.config.Channel).Id(">").Build(),
s.client.B().Xreadgroup().Group(s.config.Group, s.consumerName).Block(blockTime).Streams().Key(s.config.Stream).Id(">").Build(),
)

res, _ := streamRes.AsXRead()
Expand All @@ -218,7 +252,7 @@ func (s *RedisBroadcaster) readFromStream(blockTime int64) ([]rueidis.XRangeEntr
return nil, nil
}

if messages, ok := res[s.config.Channel]; ok {
if messages, ok := res[s.config.Stream]; ok {
return messages, nil
}

Expand All @@ -227,7 +261,7 @@ func (s *RedisBroadcaster) readFromStream(blockTime int64) ([]rueidis.XRangeEntr

func (s *RedisBroadcaster) autoclaimMessages(blockTime int64) ([]rueidis.XRangeEntry, error) {
claimRes := s.client.Do(context.Background(),
s.client.B().Xautoclaim().Key(s.config.Channel).Group(s.config.Group).Consumer(s.consumerName).MinIdleTime(fmt.Sprintf("%d", blockTime)).Start("0-0").Build(),
s.client.B().Xautoclaim().Key(s.config.Stream).Group(s.config.Group).Consumer(s.consumerName).MinIdleTime(fmt.Sprintf("%d", blockTime)).Start("0-0").Build(),
)

arr, err := claimRes.ToArray()
Expand Down Expand Up @@ -260,8 +294,8 @@ func (s *RedisBroadcaster) broadcastXrange(messages []rueidis.XRangeEntry) {
s.node.HandleBroadcast([]byte(payload))

ackRes := s.client.DoMulti(context.Background(),
s.client.B().Xack().Key(s.config.Channel).Group(s.config.Group).Id(message.ID).Build(),
s.client.B().Xdel().Key(s.config.Channel).Id(message.ID).Build(),
s.client.B().Xack().Key(s.config.Stream).Group(s.config.Group).Id(message.ID).Build(),
s.client.B().Xdel().Key(s.config.Stream).Id(message.ID).Build(),
)

err := ackRes[0].Error()
Expand All @@ -274,7 +308,7 @@ func (s *RedisBroadcaster) broadcastXrange(messages []rueidis.XRangeEntry) {
}

func (s *RedisBroadcaster) maybeReconnect(done chan (error)) {
if s.reconnectAttempt >= s.config.MaxReconnectAttempts {
if s.reconnectAttempt >= s.config.Redis.MaxReconnectAttempts {
close(s.finishedCh)
done <- errors.New("failed to reconnect to Redis: attempts exceeded") //nolint:stylecheck
return
Expand Down
36 changes: 31 additions & 5 deletions broadcast/redis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"testing"
"time"

"github.com/BurntSushi/toml"
"github.com/anycable/anycable-go/mocks"
rconfig "github.com/anycable/anycable-go/redis"
"github.com/anycable/anycable-go/utils"
Expand Down Expand Up @@ -64,10 +65,12 @@ func TestRedisBroadcaster(t *testing.T) {
return
}

config := rconfig.NewRedisConfig()
rconfig := rconfig.NewRedisConfig()
config := NewRedisConfig()
config.Redis = &rconfig

if redisURL != "" {
config.URL = redisURL
rconfig.URL = redisURL
}

config.StreamReadBlockMilliseconds = 500
Expand Down Expand Up @@ -154,12 +157,14 @@ func TestRedisBroadcasterAcksClaims(t *testing.T) {
return
}

config := rconfig.NewRedisConfig()
rconfig := rconfig.NewRedisConfig()
config := NewRedisConfig()
config.Redis = &rconfig
// Make it short to avoid sleeping for too long in tests
config.StreamReadBlockMilliseconds = 100

if redisURL != "" {
config.URL = redisURL
rconfig.URL = redisURL
}

handler := &mocks.Handler{}
Expand All @@ -181,7 +186,7 @@ func TestRedisBroadcasterAcksClaims(t *testing.T) {
closed = true
// Close the connection to prevent consumer from ack-ing the message
broadcaster.client.Close()
broadcaster.reconnectAttempt = config.MaxReconnectAttempts + 1
broadcaster.reconnectAttempt = config.Redis.MaxReconnectAttempts + 1
}
})

Expand Down Expand Up @@ -285,3 +290,24 @@ func waitRedisStreamConsumers(client rueidis.Client, count int) error {
attempts++
}
}

func TestRedisConfig__ToToml(t *testing.T) {
config := NewRedisConfig()
config.Stream = "test_stream"
config.Group = "test_group"
config.StreamReadBlockMilliseconds = 3000

tomlStr := config.ToToml()

assert.Contains(t, tomlStr, "stream = \"test_stream\"")
assert.Contains(t, tomlStr, "group = \"test_group\"")
assert.Contains(t, tomlStr, "stream_read_block_milliseconds = 3000")

// Round-trip test
config2 := NewRedisConfig()

_, err := toml.Decode(tomlStr, &config2)
require.NoError(t, err)

assert.Equal(t, config, config2)
}
Loading

0 comments on commit 61f6b05

Please sign in to comment.