diff --git a/broker/broker.go b/broker/broker.go index fd950778..692fd054 100644 --- a/broker/broker.go +++ b/broker/broker.go @@ -8,10 +8,6 @@ import ( "github.com/anycable/anycable-go/common" ) -const ( - SESSION_ID_HEADER = "X-ANYCABLE-RESTORE-SID" -) - // Broadcaster is responsible for fanning-out messages to the stream clients // and other nodes type Broadcaster interface { diff --git a/cli/options.go b/cli/options.go index d3786dc1..0859b6b8 100644 --- a/cli/options.go +++ b/cli/options.go @@ -5,7 +5,6 @@ import ( "regexp" "strings" - "github.com/anycable/anycable-go/broker" "github.com/anycable/anycable-go/config" "github.com/anycable/anycable-go/node" "github.com/anycable/anycable-go/version" @@ -125,11 +124,6 @@ func NewConfigFromCLI(args []string, opts ...cliOption) (*config.Config, error, c.Headers = strings.Split(strings.ToLower(headers), ",") - // Read session ID header if using a broker - if c.BrokerAdapter != "" { - c.Headers = append(c.Headers, broker.SESSION_ID_HEADER) - } - if len(cookieFilter) > 0 { c.Cookies = strings.Split(cookieFilter, ",") } diff --git a/cli/session_options.go b/cli/session_options.go index 7f51af27..f091ab4c 100644 --- a/cli/session_options.go +++ b/cli/session_options.go @@ -4,6 +4,7 @@ import ( "strconv" "time" + "github.com/anycable/anycable-go/common" "github.com/anycable/anycable-go/node" "github.com/anycable/anycable-go/server" "github.com/apex/log" @@ -12,11 +13,18 @@ import ( const ( pingIntervalParameter = "pi" pingPrecisionParameter = "ptp" + + prevSessionHeader = "X-ANYCABLE-RESTORE-SID" + prevSessionParam = "sid" ) func (r *Runner) sessionOptionsFromProtocol(protocol string) []node.SessionOption { opts := []node.SessionOption{} + if common.IsExtendedActionCableProtocol(protocol) { + opts = append(opts, node.WithResumable(true)) + } + return opts } @@ -36,5 +44,11 @@ func (r *Runner) sessionOptionsFromParams(info *server.RequestInfo) []node.Sessi opts = append(opts, node.WithPingPrecision(val)) } + if hval := info.AnyCableHeader(prevSessionHeader); hval != "" { + opts = append(opts, node.WithPrevSID(hval)) + } else if pval := info.Param(prevSessionParam); pval != "" { + opts = append(opts, node.WithPrevSID(pval)) + } + return opts } diff --git a/common/common.go b/common/common.go index eaba984a..b313a1f0 100644 --- a/common/common.go +++ b/common/common.go @@ -21,6 +21,20 @@ func ActionCableProtocols() []string { return []string{ActionCableV1JSON, ActionCableV1ExtJSON} } +func ActionCableExtendedProtocols() []string { + return []string{ActionCableV1ExtJSON} +} + +func IsExtendedActionCableProtocol(protocol string) bool { + for _, p := range ActionCableExtendedProtocols() { + if p == protocol { + return true + } + } + + return false +} + // Outgoing message types (according to Action Cable protocol) const ( WelcomeType = "welcome" diff --git a/etc/anyt/broker_tests/restore_test.rb b/etc/anyt/broker_tests/restore_test.rb index 3d80d96f..249dc35b 100644 --- a/etc/anyt/broker_tests/restore_test.rb +++ b/etc/anyt/broker_tests/restore_test.rb @@ -8,7 +8,7 @@ def subscribed end before do - client = build_client(ignore: ["ping"]) + client = build_client(ignore: ["ping"], protocol: "actioncable-v1-ext-json") welcome_msg = client.receive assert_message({ "type" => "welcome" }, welcome_msg) @@ -50,7 +50,8 @@ def subscribed ignore: ["ping"], headers: { "X-ANYCABLE-RESTORE-SID" => @sid - } + }, + protocol: "actioncable-v1-ext-json" ) assert_message({ "type" => "welcome", "restored" => true }, another_client.receive) diff --git a/node/broker_integration_test.go b/node/broker_integration_test.go index 557c1a81..52e3ae09 100644 --- a/node/broker_integration_test.go +++ b/node/broker_integration_test.go @@ -52,7 +52,7 @@ func sharedIntegrationRestore(t *testing.T, node *Node, controller *mocks.Contro sid := "s18" ids := "user:jack" - prev_session := NewMockSessionWithEnv(sid, node, "ws://test.anycable.io/cable", nil) + prev_session := NewMockSessionWithEnv(sid, node, "ws://test.anycable.io/cable", nil, WithResumable(true)) controller. On("Authenticate", sid, prev_session.env). @@ -115,7 +115,7 @@ func sharedIntegrationRestore(t *testing.T, node *Node, controller *mocks.Contro prev_session.Disconnect("normal", ws.CloseNormalClosure) - session := NewMockSessionWithEnv("s21", node, fmt.Sprintf("ws://test.anycable.io/cable?sid=%s", sid), nil) + session := NewMockSessionWithEnv("s21", node, fmt.Sprintf("ws://test.anycable.io/cable?sid=%s", sid), nil, WithResumable(true), WithPrevSID(sid)) _, err = node.Authenticate(session) require.NoError(t, err) @@ -173,7 +173,7 @@ func sharedIntegrationRestore(t *testing.T, node *Node, controller *mocks.Contro Transmissions: []string{`{"type":"welcome","restored":false}`}, }, nil) - new_session := NewMockSessionWithEnv("s42", node, fmt.Sprintf("ws://test.anycable.io/cable?sid=%s", sid), nil) + new_session := NewMockSessionWithEnv("s42", node, fmt.Sprintf("ws://test.anycable.io/cable?sid=%s", sid), nil, WithResumable(true), WithPrevSID(sid)) time.Sleep(4 * time.Second) diff --git a/node/node.go b/node/node.go index 6b415be4..d9345957 100644 --- a/node/node.go +++ b/node/node.go @@ -263,10 +263,12 @@ func WithDisconnectOnFailure(disconnect bool) AuthOption { func (n *Node) Authenticate(s *Session, options ...AuthOption) (res *common.ConnectResult, err error) { opts := newAuthOptions(options) - restored := n.TryRestoreSession(s) + if s.IsResumeable() { + restored := n.TryRestoreSession(s) - if restored { - return &common.ConnectResult{Status: common.SUCCESS}, nil + if restored { + return &common.ConnectResult{Status: common.SUCCESS}, nil + } } res, err = n.controller.Authenticate(s.GetID(), s.env) @@ -291,8 +293,10 @@ func (n *Node) Authenticate(s *Session, options ...AuthOption) (res *common.Conn n.handleCallReply(s, res.ToCallResult()) n.markDisconnectable(s, res.DisconnectInterest) - if berr := n.broker.CommitSession(s.GetID(), s); berr != nil { - s.Log.Errorf("Failed to persist session in cache: %v", berr) + if s.IsResumeable() { + if berr := n.broker.CommitSession(s.GetID(), s); berr != nil { + s.Log.Errorf("Failed to persist session in cache: %v", berr) + } } return @@ -354,8 +358,10 @@ func (n *Node) TryRestoreSession(s *Session) (restored bool) { RestoredIDs: utils.Keys(s.subscriptions.channels), }) - if berr := n.broker.CommitSession(s.GetID(), s); berr != nil { - s.Log.Errorf("Failed to persist session in cache: %v", berr) + if s.IsResumeable() { + if berr := n.broker.CommitSession(s.GetID(), s); berr != nil { + s.Log.Errorf("Failed to persist session in cache: %v", berr) + } } return true @@ -395,8 +401,10 @@ func (n *Node) Subscribe(s *Session, msg *common.Message) (res *common.CommandRe } if confirmed { - if berr := n.broker.CommitSession(s.GetID(), s); berr != nil { - s.Log.Errorf("Failed to persist session in cache: %v", berr) + if s.IsResumeable() { + if berr := n.broker.CommitSession(s.GetID(), s); berr != nil { + s.Log.Errorf("Failed to persist session in cache: %v", berr) + } } if msg.History.Since > 0 || msg.History.Streams != nil { @@ -438,8 +446,10 @@ func (n *Node) Unsubscribe(s *Session, msg *common.Message) (res *common.Command n.handleCommandReply(s, msg, res) } - if berr := n.broker.CommitSession(s.GetID(), s); berr != nil { - s.Log.Errorf("Failed to persist session in cache: %v", berr) + if s.IsResumeable() { + if berr := n.broker.CommitSession(s.GetID(), s); berr != nil { + s.Log.Errorf("Failed to persist session in cache: %v", berr) + } } return @@ -476,8 +486,10 @@ func (n *Node) Perform(s *Session, msg *common.Message) (res *common.CommandResu if res != nil { if n.handleCommandReply(s, msg, res) { - if berr := n.broker.CommitSession(s.GetID(), s); berr != nil { - s.Log.Errorf("Failed to persist session in cache: %v", berr) + if s.IsResumeable() { + if berr := n.broker.CommitSession(s.GetID(), s); berr != nil { + s.Log.Errorf("Failed to persist session in cache: %v", berr) + } } } } @@ -590,7 +602,9 @@ func (n *Node) ExecuteRemoteCommand(msg *common.RemoteCommandMessage) { // Disconnect adds session to disconnector queue and unregister session from hub func (n *Node) Disconnect(s *Session) error { - n.broker.FinishSession(s.GetID()) // nolint:errcheck + if s.IsResumeable() { + n.broker.FinishSession(s.GetID()) // nolint:errcheck + } if n.IsShuttingDown() { if s.IsDisconnectable() { diff --git a/node/node_mocks_test.go b/node/node_mocks_test.go index 38e9815c..1395dda2 100644 --- a/node/node_mocks_test.go +++ b/node/node_mocks_test.go @@ -26,7 +26,7 @@ func NewMockNode() *Node { } // NewMockSession returns a new session with a specified uid and identifiers equal to uid -func NewMockSession(uid string, node *Node) *Session { +func NewMockSession(uid string, node *Node, opts ...SessionOption) *Session { session := Session{ executor: node, closed: true, @@ -41,14 +41,19 @@ func NewMockSession(uid string, node *Node) *Session { session.SetIdentifiers(uid) session.conn = mocks.NewMockConnection() + + for _, opt := range opts { + opt(&session) + } + go session.SendMessages() return &session } // NewMockSession returns a new session with a specified uid, path and headers, and identifiers equal to uid -func NewMockSessionWithEnv(uid string, node *Node, url string, headers *map[string]string) *Session { - session := NewMockSession(uid, node) +func NewMockSessionWithEnv(uid string, node *Node, url string, headers *map[string]string, opts ...SessionOption) *Session { + session := NewMockSession(uid, node, opts...) session.env = common.NewSessionEnv(url, headers) return session } diff --git a/node/node_test.go b/node/node_test.go index ee30fc97..ee3a4336 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -564,7 +564,7 @@ func TestRestoreSession(t *testing.T) { go node.hub.Run() defer node.hub.Shutdown() - prev_session := NewMockSession("114", node) + prev_session := NewMockSession("114", node, WithResumable(true)) prev_session.subscriptions.AddChannel("fruits_channel") prev_session.subscriptions.AddChannelStream("fruits_channel", "arancia") prev_session.subscriptions.AddChannelStream("fruits_channel", "limoni") @@ -584,11 +584,9 @@ func TestRestoreSession(t *testing.T) { On("Subscribe", mock.Anything). Return(func(name string) string { return name }) - session := NewMockSession("214", node) + session := NewMockSession("214", node, WithResumable(true), WithPrevSID("114")) t.Run("Successful restore via header", func(t *testing.T) { - session.env.SetHeader("X-ANYCABLE-RESTORE-SID", "114") - res, err := node.Authenticate(session) require.NoError(t, err) assert.Equal(t, common.SUCCESS, res.Status) @@ -634,31 +632,11 @@ func TestRestoreSession(t *testing.T) { ) }) - t.Run("Successful restore via url", func(t *testing.T) { - session.env.URL = "/cable-test?sid=114" - - res, err := node.Authenticate(session) - require.NoError(t, err) - assert.Equal(t, common.SUCCESS, res.Status) - - welcome, err := session.conn.Read() - require.NoError(t, err) - - require.Equalf( - t, - `{"type":"welcome","sid":"214","restored":true,"restored_ids":["fruits_channel"]}`, - string(welcome), - "Sent message is invalid: %s", welcome, - ) - }) - t.Run("Failed to restore", func(t *testing.T) { broker. On("RestoreSession", "114"). Return(nil, nil) - session.env.SetHeader("-anycable-restore-sid", "114") - session = NewMockSession("154", node) res, err := node.Authenticate(session) diff --git a/node/session.go b/node/session.go index 754ae2f0..68225f7f 100644 --- a/node/session.go +++ b/node/session.go @@ -3,7 +3,6 @@ package node import ( "encoding/json" "errors" - "net/url" "sync" "time" @@ -16,9 +15,6 @@ import ( const ( writeWait = 10 * time.Second - - prevSessionHeader = "X-ANYCABLE-RESTORE-SID" - prevSessionParam = "sid" ) // Executor handles incoming commands (messages) @@ -55,6 +51,9 @@ type Session struct { handshakeDeadline time.Time + resumable bool + prevSid string + Connected bool // Could be used to store arbitrary data within a session InternalState map[string]interface{} @@ -107,6 +106,20 @@ func WithMetrics(m metrics.Instrumenter) SessionOption { } } +// WithResumable allows marking session as resumable (so we store its state in cache) +func WithResumable(val bool) SessionOption { + return func(s *Session) { + s.resumable = val + } +} + +// WithPrevSID allows providing the previous session ID to restore from +func WithPrevSID(sid string) SessionOption { + return func(s *Session) { + s.prevSid = sid + } +} + // NewSession build a new Session struct from ws connetion and http request func NewSession(node *Node, conn Connection, url string, headers *map[string]string, uid string, opts ...SessionOption) *Session { session := &Session{ @@ -174,6 +187,10 @@ func (s *Session) IsConnected() bool { return s.Connected } +func (s *Session) IsResumeable() bool { + return s.resumable +} + func (s *Session) maybeDisconnectIdle() { s.mu.Lock() @@ -441,34 +458,8 @@ func (s *Session) RestoreFromCache(cached []byte) error { return nil } -func (s *Session) PrevSid() (psid string) { - if s.env.Headers != nil { - if v, ok := (*s.env.Headers)[prevSessionHeader]; ok { - psid = v - // This header is of one-time use, - // no need to leak it to the RPC app - delete(*s.env.Headers, prevSessionHeader) - return - } - } - - u, err := url.Parse(s.env.URL) - - if err != nil { - return - } - - m, err := url.ParseQuery(u.RawQuery) - - if err != nil { - return - } - - if v, ok := m[prevSessionParam]; ok { - psid = v[0] - } - - return +func (s *Session) PrevSid() string { + return s.prevSid } func (s *Session) disconnectFromNode() { diff --git a/node/session_test.go b/node/session_test.go index ce8723c8..d1c6457e 100644 --- a/node/session_test.go +++ b/node/session_test.go @@ -204,24 +204,6 @@ func TestCacheEntryEmptySession(t *testing.T) { require.NoError(t, err) } -func TestPrevSid(t *testing.T) { - session := Session{} - headers := make(map[string]string) - session.env = common.NewSessionEnv("ws://example.dev/cable", nil) - assert.Equal(t, "", session.PrevSid()) - - session.env = common.NewSessionEnv("ws://example.dev/cable?sid=123", &headers) - assert.Equal(t, "123", session.PrevSid()) - - session.env = common.NewSessionEnv("http://example.dev/cable?jid=xxxx&sid=213", &headers) - assert.Equal(t, "213", session.PrevSid()) - - headers["X-ANYCABLE-RESTORE-SID"] = "456" - - session.env = common.NewSessionEnv("http://example.dev/cable?jid=xxxx&sid=213", &headers) - assert.Equal(t, "456", session.PrevSid()) -} - func TestMarkDisconnectable(t *testing.T) { session := Session{} diff --git a/server/request_info.go b/server/request_info.go index d87c378c..87a22f45 100644 --- a/server/request_info.go +++ b/server/request_info.go @@ -17,7 +17,9 @@ type RequestInfo struct { UID string URL string Headers *map[string]string - Params map[string]string + + anycableHeaders map[string]string + params map[string]string } func NewRequestInfo(r *http.Request, extractor HeadersExtractor) (*RequestInfo, error) { @@ -29,6 +31,15 @@ func NewRequestInfo(r *http.Request, extractor HeadersExtractor) (*RequestInfo, headers = extractor.FromRequest(r) } + anycableHeaders := make(map[string]string) + + // Extract headers prefixed with `X-AnyCable-` from request headers + for k, v := range r.Header { + if strings.HasPrefix(strings.ToLower(k), "x-anycable-") { + anycableHeaders[strings.ToLower(k)] = v[len(v)-1] + } + } + uid, err := FetchUID(r) if err != nil { @@ -53,15 +64,23 @@ func NewRequestInfo(r *http.Request, extractor HeadersExtractor) (*RequestInfo, params[k] = v[len(v)-1] } - return &RequestInfo{UID: uid, Headers: &headers, URL: url, Params: params}, nil + return &RequestInfo{UID: uid, Headers: &headers, URL: url, params: params, anycableHeaders: anycableHeaders}, nil } func (i *RequestInfo) Param(key string) string { - if i.Params == nil { + if i.params == nil { + return "" + } + + return i.params[key] +} + +func (i *RequestInfo) AnyCableHeader(key string) string { + if i.anycableHeaders == nil { return "" } - return i.Params[key] + return i.anycableHeaders[strings.ToLower(key)] } // FetchUID safely extracts uid from `X-Request-ID` header or generates a new one diff --git a/server/request_info_test.go b/server/request_info_test.go index 8d9ca677..bac65767 100644 --- a/server/request_info_test.go +++ b/server/request_info_test.go @@ -64,4 +64,18 @@ func TestRequestInfo(t *testing.T) { blank_info := RequestInfo{} assert.Equal(t, "", blank_info.Param("pi")) }) + + t.Run("With AnyCable Headers", func(t *testing.T) { + req := httptest.NewRequest("GET", "ws://anycable.io/cable?pi=3&pp=no&pi=5", nil) + req.Header.Set("x-anycable-sid", "432") + req.Header.Set("X-AnyCable-Ping-Interval", "10") + info, err := NewRequestInfo(req, nil) + + require.NoError(t, err) + assert.Equal(t, "432", info.AnyCableHeader("X-ANYCABLE-SID")) + assert.Equal(t, "10", info.AnyCableHeader("x-anycable-ping-interval")) + + blank_info := RequestInfo{} + assert.Equal(t, "", blank_info.AnyCableHeader("SID")) + }) }