Skip to content

Commit

Permalink
refactor: extract protocol-specific session configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
palkan committed Aug 9, 2023
1 parent 6859a86 commit ae86aed
Show file tree
Hide file tree
Showing 13 changed files with 132 additions and 110 deletions.
4 changes: 0 additions & 4 deletions broker/broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@ import (
"github.com/anycable/anycable-go/common"
)

const (
SESSION_ID_HEADER = "X-ANYCABLE-RESTORE-SID"
)

// Broadcaster is responsible for fanning-out messages to the stream clients
// and other nodes
type Broadcaster interface {
Expand Down
6 changes: 0 additions & 6 deletions cli/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"regexp"
"strings"

"github.com/anycable/anycable-go/broker"
"github.com/anycable/anycable-go/config"
"github.com/anycable/anycable-go/node"
"github.com/anycable/anycable-go/version"
Expand Down Expand Up @@ -125,11 +124,6 @@ func NewConfigFromCLI(args []string, opts ...cliOption) (*config.Config, error,

c.Headers = strings.Split(strings.ToLower(headers), ",")

// Read session ID header if using a broker
if c.BrokerAdapter != "" {
c.Headers = append(c.Headers, broker.SESSION_ID_HEADER)
}

if len(cookieFilter) > 0 {
c.Cookies = strings.Split(cookieFilter, ",")
}
Expand Down
14 changes: 14 additions & 0 deletions cli/session_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"strconv"
"time"

"github.com/anycable/anycable-go/common"
"github.com/anycable/anycable-go/node"
"github.com/anycable/anycable-go/server"
"github.com/apex/log"
Expand All @@ -12,11 +13,18 @@ import (
const (
pingIntervalParameter = "pi"
pingPrecisionParameter = "ptp"

prevSessionHeader = "X-ANYCABLE-RESTORE-SID"
prevSessionParam = "sid"
)

func (r *Runner) sessionOptionsFromProtocol(protocol string) []node.SessionOption {
opts := []node.SessionOption{}

if common.IsExtendedActionCableProtocol(protocol) {
opts = append(opts, node.WithResumable(true))
}

return opts
}

Expand All @@ -36,5 +44,11 @@ func (r *Runner) sessionOptionsFromParams(info *server.RequestInfo) []node.Sessi
opts = append(opts, node.WithPingPrecision(val))
}

if hval := info.AnyCableHeader(prevSessionHeader); hval != "" {
opts = append(opts, node.WithPrevSID(hval))
} else if pval := info.Param(prevSessionParam); pval != "" {
opts = append(opts, node.WithPrevSID(pval))
}

return opts
}
14 changes: 14 additions & 0 deletions common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,20 @@ func ActionCableProtocols() []string {
return []string{ActionCableV1JSON, ActionCableV1ExtJSON}
}

func ActionCableExtendedProtocols() []string {
return []string{ActionCableV1ExtJSON}
}

func IsExtendedActionCableProtocol(protocol string) bool {
for _, p := range ActionCableExtendedProtocols() {
if p == protocol {
return true
}
}

return false
}

// Outgoing message types (according to Action Cable protocol)
const (
WelcomeType = "welcome"
Expand Down
5 changes: 3 additions & 2 deletions etc/anyt/broker_tests/restore_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def subscribed
end

before do
client = build_client(ignore: ["ping"])
client = build_client(ignore: ["ping"], protocol: "actioncable-v1-ext-json")

welcome_msg = client.receive
assert_message({ "type" => "welcome" }, welcome_msg)
Expand Down Expand Up @@ -50,7 +50,8 @@ def subscribed
ignore: ["ping"],
headers: {
"X-ANYCABLE-RESTORE-SID" => @sid
}
},
protocol: "actioncable-v1-ext-json"
)

assert_message({ "type" => "welcome", "restored" => true }, another_client.receive)
Expand Down
6 changes: 3 additions & 3 deletions node/broker_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func sharedIntegrationRestore(t *testing.T, node *Node, controller *mocks.Contro
sid := "s18"
ids := "user:jack"

prev_session := NewMockSessionWithEnv(sid, node, "ws://test.anycable.io/cable", nil)
prev_session := NewMockSessionWithEnv(sid, node, "ws://test.anycable.io/cable", nil, WithResumable(true))

controller.
On("Authenticate", sid, prev_session.env).
Expand Down Expand Up @@ -115,7 +115,7 @@ func sharedIntegrationRestore(t *testing.T, node *Node, controller *mocks.Contro

prev_session.Disconnect("normal", ws.CloseNormalClosure)

session := NewMockSessionWithEnv("s21", node, fmt.Sprintf("ws://test.anycable.io/cable?sid=%s", sid), nil)
session := NewMockSessionWithEnv("s21", node, fmt.Sprintf("ws://test.anycable.io/cable?sid=%s", sid), nil, WithResumable(true), WithPrevSID(sid))

_, err = node.Authenticate(session)
require.NoError(t, err)
Expand Down Expand Up @@ -173,7 +173,7 @@ func sharedIntegrationRestore(t *testing.T, node *Node, controller *mocks.Contro
Transmissions: []string{`{"type":"welcome","restored":false}`},
}, nil)

new_session := NewMockSessionWithEnv("s42", node, fmt.Sprintf("ws://test.anycable.io/cable?sid=%s", sid), nil)
new_session := NewMockSessionWithEnv("s42", node, fmt.Sprintf("ws://test.anycable.io/cable?sid=%s", sid), nil, WithResumable(true), WithPrevSID(sid))

time.Sleep(4 * time.Second)

Expand Down
42 changes: 28 additions & 14 deletions node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,12 @@ func WithDisconnectOnFailure(disconnect bool) AuthOption {
func (n *Node) Authenticate(s *Session, options ...AuthOption) (res *common.ConnectResult, err error) {
opts := newAuthOptions(options)

restored := n.TryRestoreSession(s)
if s.IsResumeable() {
restored := n.TryRestoreSession(s)

if restored {
return &common.ConnectResult{Status: common.SUCCESS}, nil
if restored {
return &common.ConnectResult{Status: common.SUCCESS}, nil
}
}

res, err = n.controller.Authenticate(s.GetID(), s.env)
Expand All @@ -291,8 +293,10 @@ func (n *Node) Authenticate(s *Session, options ...AuthOption) (res *common.Conn
n.handleCallReply(s, res.ToCallResult())
n.markDisconnectable(s, res.DisconnectInterest)

if berr := n.broker.CommitSession(s.GetID(), s); berr != nil {
s.Log.Errorf("Failed to persist session in cache: %v", berr)
if s.IsResumeable() {
if berr := n.broker.CommitSession(s.GetID(), s); berr != nil {
s.Log.Errorf("Failed to persist session in cache: %v", berr)
}
}

return
Expand Down Expand Up @@ -354,8 +358,10 @@ func (n *Node) TryRestoreSession(s *Session) (restored bool) {
RestoredIDs: utils.Keys(s.subscriptions.channels),
})

if berr := n.broker.CommitSession(s.GetID(), s); berr != nil {
s.Log.Errorf("Failed to persist session in cache: %v", berr)
if s.IsResumeable() {
if berr := n.broker.CommitSession(s.GetID(), s); berr != nil {
s.Log.Errorf("Failed to persist session in cache: %v", berr)
}
}

return true
Expand Down Expand Up @@ -395,8 +401,10 @@ func (n *Node) Subscribe(s *Session, msg *common.Message) (res *common.CommandRe
}

if confirmed {
if berr := n.broker.CommitSession(s.GetID(), s); berr != nil {
s.Log.Errorf("Failed to persist session in cache: %v", berr)
if s.IsResumeable() {
if berr := n.broker.CommitSession(s.GetID(), s); berr != nil {
s.Log.Errorf("Failed to persist session in cache: %v", berr)
}
}

if msg.History.Since > 0 || msg.History.Streams != nil {
Expand Down Expand Up @@ -438,8 +446,10 @@ func (n *Node) Unsubscribe(s *Session, msg *common.Message) (res *common.Command
n.handleCommandReply(s, msg, res)
}

if berr := n.broker.CommitSession(s.GetID(), s); berr != nil {
s.Log.Errorf("Failed to persist session in cache: %v", berr)
if s.IsResumeable() {
if berr := n.broker.CommitSession(s.GetID(), s); berr != nil {
s.Log.Errorf("Failed to persist session in cache: %v", berr)
}
}

return
Expand Down Expand Up @@ -476,8 +486,10 @@ func (n *Node) Perform(s *Session, msg *common.Message) (res *common.CommandResu

if res != nil {
if n.handleCommandReply(s, msg, res) {
if berr := n.broker.CommitSession(s.GetID(), s); berr != nil {
s.Log.Errorf("Failed to persist session in cache: %v", berr)
if s.IsResumeable() {
if berr := n.broker.CommitSession(s.GetID(), s); berr != nil {
s.Log.Errorf("Failed to persist session in cache: %v", berr)
}
}
}
}
Expand Down Expand Up @@ -590,7 +602,9 @@ func (n *Node) ExecuteRemoteCommand(msg *common.RemoteCommandMessage) {

// Disconnect adds session to disconnector queue and unregister session from hub
func (n *Node) Disconnect(s *Session) error {
n.broker.FinishSession(s.GetID()) // nolint:errcheck
if s.IsResumeable() {
n.broker.FinishSession(s.GetID()) // nolint:errcheck
}

if n.IsShuttingDown() {
if s.IsDisconnectable() {
Expand Down
11 changes: 8 additions & 3 deletions node/node_mocks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func NewMockNode() *Node {
}

// NewMockSession returns a new session with a specified uid and identifiers equal to uid
func NewMockSession(uid string, node *Node) *Session {
func NewMockSession(uid string, node *Node, opts ...SessionOption) *Session {
session := Session{
executor: node,
closed: true,
Expand All @@ -41,14 +41,19 @@ func NewMockSession(uid string, node *Node) *Session {

session.SetIdentifiers(uid)
session.conn = mocks.NewMockConnection()

for _, opt := range opts {
opt(&session)
}

go session.SendMessages()

return &session
}

// NewMockSession returns a new session with a specified uid, path and headers, and identifiers equal to uid
func NewMockSessionWithEnv(uid string, node *Node, url string, headers *map[string]string) *Session {
session := NewMockSession(uid, node)
func NewMockSessionWithEnv(uid string, node *Node, url string, headers *map[string]string, opts ...SessionOption) *Session {
session := NewMockSession(uid, node, opts...)
session.env = common.NewSessionEnv(url, headers)
return session
}
26 changes: 2 additions & 24 deletions node/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ func TestRestoreSession(t *testing.T) {
go node.hub.Run()
defer node.hub.Shutdown()

prev_session := NewMockSession("114", node)
prev_session := NewMockSession("114", node, WithResumable(true))
prev_session.subscriptions.AddChannel("fruits_channel")
prev_session.subscriptions.AddChannelStream("fruits_channel", "arancia")
prev_session.subscriptions.AddChannelStream("fruits_channel", "limoni")
Expand All @@ -584,11 +584,9 @@ func TestRestoreSession(t *testing.T) {
On("Subscribe", mock.Anything).
Return(func(name string) string { return name })

session := NewMockSession("214", node)
session := NewMockSession("214", node, WithResumable(true), WithPrevSID("114"))

t.Run("Successful restore via header", func(t *testing.T) {
session.env.SetHeader("X-ANYCABLE-RESTORE-SID", "114")

res, err := node.Authenticate(session)
require.NoError(t, err)
assert.Equal(t, common.SUCCESS, res.Status)
Expand Down Expand Up @@ -634,31 +632,11 @@ func TestRestoreSession(t *testing.T) {
)
})

t.Run("Successful restore via url", func(t *testing.T) {
session.env.URL = "/cable-test?sid=114"

res, err := node.Authenticate(session)
require.NoError(t, err)
assert.Equal(t, common.SUCCESS, res.Status)

welcome, err := session.conn.Read()
require.NoError(t, err)

require.Equalf(
t,
`{"type":"welcome","sid":"214","restored":true,"restored_ids":["fruits_channel"]}`,
string(welcome),
"Sent message is invalid: %s", welcome,
)
})

t.Run("Failed to restore", func(t *testing.T) {
broker.
On("RestoreSession", "114").
Return(nil, nil)

session.env.SetHeader("-anycable-restore-sid", "114")

session = NewMockSession("154", node)

res, err := node.Authenticate(session)
Expand Down
Loading

0 comments on commit ae86aed

Please sign in to comment.