From befeed16f703c8c2828a0d13370cc39dc4aa83d5 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 9 May 2024 16:23:01 +0800 Subject: [PATCH 1/2] enh: websocket reconnect --- .github/workflows/go.yml | 6 + .github/workflows/push.yml | 6 + .github/workflows/taos.cfg | 5 + .github/workflows/taosadapter.toml | 115 +++++++++++++ ws/client/conn.go | 68 +++++--- ws/schemaless/config.go | 41 +++-- ws/schemaless/schemaless.go | 169 +++++++++++++++----- ws/schemaless/schemaless_test.go | 99 ++++++++++++ ws/stmt/config.go | 41 +++-- ws/stmt/connector.go | 183 +++++++++++++++------ ws/stmt/rows.go | 13 +- ws/stmt/stmt.go | 32 ++-- ws/stmt/stmt_test.go | 86 ++++++++++ ws/tmq/config.go | 15 ++ ws/tmq/consumer.go | 248 ++++++++++++++++++++--------- ws/tmq/consumer_test.go | 136 ++++++++++++++++ 16 files changed, 1041 insertions(+), 222 deletions(-) create mode 100644 .github/workflows/taos.cfg create mode 100644 .github/workflows/taosadapter.toml diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 8e35cbb..03dc0dc 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -144,6 +144,12 @@ jobs: - name: checkout uses: actions/checkout@v3 + - name: copy taos cfg + run: | + sudo mkdir -p /etc/taos + sudo cp ./.github/workflows/taos.cfg /etc/taos/taos.cfg + sudo cp ./.github/workflows/taosadapter.toml /etc/taos/taosadapter.toml + - name: shell run: | cat >start.sh<start.sh< 64*1024 { + e.Msg = new(bytes.Buffer) + } else { + e.Msg.Reset() + } + if len(e.ErrorChan) > 0 { + e.ErrorChan = make(chan error, 1) + } } +var ClosedError = errors.New("websocket closed") + type Client struct { conn *websocket.Conn status uint32 @@ -63,9 +74,10 @@ type Client struct { TextMessageHandler func(message []byte) BinaryMessageHandler func(message []byte) ErrorHandler func(err error) - SendMessageHandler func(envelope *Envelope) - once sync.Once - errHandlerOnce sync.Once + //SendMessageHandler func(envelope *Envelope) + once sync.Once + errHandlerOnce sync.Once + err error } func NewClient(conn *websocket.Conn, sendChanLength uint) *Client { @@ -80,9 +92,9 @@ func NewClient(conn *websocket.Conn, sendChanLength uint) *Client { TextMessageHandler: func(message []byte) {}, BinaryMessageHandler: func(message []byte) {}, ErrorHandler: func(err error) {}, - SendMessageHandler: func(envelope *Envelope) { - GlobalEnvelopePool.Put(envelope) - }, + //SendMessageHandler: func(envelope *Envelope) { + // GlobalEnvelopePool.Put(envelope) + //}, } } @@ -117,41 +129,61 @@ func (c *Client) WritePump() { defer func() { ticker.Stop() }() + for { select { case message, ok := <-c.sendChan: if !ok { - return + if message == nil { + return + } + message.ErrorChan <- ClosedError + continue } c.conn.SetWriteDeadline(time.Now().Add(c.WriteWait)) err := c.conn.WriteMessage(message.Type, message.Msg.Bytes()) if err != nil { + message.ErrorChan <- err c.handleError(err) - return + c.Close() + for message := range c.sendChan { + if message == nil { + return + } + message.ErrorChan <- ClosedError + } } - c.SendMessageHandler(message) + message.ErrorChan <- nil case <-ticker.C: c.conn.SetWriteDeadline(time.Now().Add(c.WriteWait)) if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { c.handleError(err) - return + c.Close() + for message := range c.sendChan { + if message == nil { + return + } + message.ErrorChan <- ClosedError + } } } } } -func (c *Client) Send(envelope *Envelope) { +func (c *Client) Send(envelope *Envelope) error { if !c.IsRunning() { - return + return ClosedError } + var err error defer func() { // maybe closed if recover() != nil { - + err = ClosedError return } }() c.sendChan <- envelope + return err } func (c *Client) GetEnvelope() *Envelope { @@ -168,8 +200,8 @@ func (c *Client) IsRunning() bool { func (c *Client) Close() { c.once.Do(func() { - close(c.sendChan) atomic.StoreUint32(&c.status, StatusStop) + close(c.sendChan) if c.conn != nil { c.conn.Close() } diff --git a/ws/schemaless/config.go b/ws/schemaless/config.go index d62eb3b..7599984 100644 --- a/ws/schemaless/config.go +++ b/ws/schemaless/config.go @@ -10,19 +10,22 @@ const ( ) type Config struct { - url string - chanLength uint - user string - password string - db string - readTimeout time.Duration - writeTimeout time.Duration - errorHandler func(error) - enableCompression bool + url string + chanLength uint + user string + password string + db string + readTimeout time.Duration + writeTimeout time.Duration + errorHandler func(error) + enableCompression bool + autoReconnect bool + reconnectIntervalMs int + reconnectRetryCount int } func NewConfig(url string, chanLength uint, opts ...func(*Config)) *Config { - c := Config{url: url, chanLength: chanLength} + c := Config{url: url, chanLength: chanLength, reconnectRetryCount: 3, reconnectIntervalMs: 2000} for _, opt := range opts { opt(&c) } @@ -71,3 +74,21 @@ func SetEnableCompression(enableCompression bool) func(*Config) { c.enableCompression = enableCompression } } + +func SetAutoReconnect(reconnect bool) func(*Config) { + return func(c *Config) { + c.autoReconnect = reconnect + } +} + +func SetReconnectIntervalMs(reconnectIntervalMs int) func(*Config) { + return func(c *Config) { + c.reconnectIntervalMs = reconnectIntervalMs + } +} + +func SetReconnectRetryCount(reconnectRetryCount int) func(*Config) { + return func(c *Config) { + c.reconnectRetryCount = reconnectRetryCount + } +} diff --git a/ws/schemaless/schemaless.go b/ws/schemaless/schemaless.go index 5494745..7bd3e1a 100644 --- a/ws/schemaless/schemaless.go +++ b/ws/schemaless/schemaless.go @@ -23,17 +23,23 @@ const ( ) type Schemaless struct { - client *client.Client - sendList *list.List - url string - user string - password string - db string - readTimeout time.Duration - lock sync.Mutex - once sync.Once - closeChan chan struct{} - errorHandler func(error) + client *client.Client + sendList *list.List + url string + user string + password string + db string + readTimeout time.Duration + writeTimeout time.Duration + lock sync.Mutex + once sync.Once + closeChan chan struct{} + errorHandler func(error) + dialer *websocket.Dialer + chanLength uint + autoReconnect bool + reconnectIntervalMs int + reconnectRetryCount int } func NewSchemaless(config *Config) (*Schemaless, error) { @@ -47,21 +53,28 @@ func NewSchemaless(config *Config) (*Schemaless, error) { wsUrl.Path = "/ws" dialer := common.DefaultDialer dialer.EnableCompression = config.enableCompression - ws, _, err := dialer.Dial(wsUrl.String(), nil) - ws.EnableWriteCompression(config.enableCompression) + conn, _, err := dialer.Dial(wsUrl.String(), nil) if err != nil { return nil, fmt.Errorf("dial ws error: %s", err) } - + conn.EnableWriteCompression(config.enableCompression) s := Schemaless{ - client: client.NewClient(ws, config.chanLength), + client: client.NewClient(conn, config.chanLength), sendList: list.New(), - url: config.url, + url: wsUrl.String(), user: config.user, password: config.password, db: config.db, closeChan: make(chan struct{}), errorHandler: config.errorHandler, + dialer: &dialer, + chanLength: config.chanLength, + } + + if config.autoReconnect { + s.autoReconnect = true + s.reconnectIntervalMs = config.reconnectIntervalMs + s.reconnectRetryCount = config.reconnectRetryCount } if config.readTimeout > 0 { @@ -69,21 +82,59 @@ func NewSchemaless(config *Config) (*Schemaless, error) { } if config.writeTimeout > 0 { - s.client.WriteWait = config.writeTimeout + s.writeTimeout = config.writeTimeout } - s.client.ErrorHandler = s.handleError - s.client.TextMessageHandler = s.handleTextMessage - go s.client.ReadPump() - go s.client.WritePump() - - if err = s.connect(); err != nil { + if err = connect(conn, s.user, s.password, s.db, s.writeTimeout, s.readTimeout); err != nil { return nil, fmt.Errorf("connect ws error: %s", err) } + s.initClient(s.client) return &s, nil } +func (s *Schemaless) initClient(c *client.Client) { + if s.writeTimeout > 0 { + c.WriteWait = s.writeTimeout + } + c.ErrorHandler = s.handleError + c.TextMessageHandler = s.handleTextMessage + + go c.ReadPump() + go c.WritePump() +} + +func (s *Schemaless) reconnect() error { + reconnected := false + for i := 0; i < s.reconnectRetryCount; i++ { + time.Sleep(time.Duration(s.reconnectIntervalMs) * time.Millisecond) + conn, _, err := s.dialer.Dial(s.url, nil) + if err != nil { + continue + } + conn.EnableWriteCompression(s.dialer.EnableCompression) + if err = connect(conn, s.user, s.password, s.db, s.writeTimeout, s.readTimeout); err != nil { + conn.Close() + continue + } + if s.client != nil { + s.client.Close() + } + c := client.NewClient(conn, s.chanLength) + s.initClient(c) + s.client = c + reconnected = true + break + } + if !reconnected { + if s.client != nil { + s.client.Close() + } + return errors.New("reconnect failed") + } + return nil +} + func (s *Schemaless) Insert(lines string, protocol int, precision string, ttl int, reqID int64) error { if reqID == 0 { reqID = common.GetReqID() @@ -102,15 +153,25 @@ func (s *Schemaless) Insert(lines string, protocol int, precision string, ttl in return err } action := &client.WSAction{Action: insertAction, Args: args} - envelope := s.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - s.client.PutEnvelope(envelope) return err } respBytes, err := s.sendText(uint64(reqID), envelope) if err != nil { - return err + if !s.autoReconnect { + return err + } + err = s.reconnect() + if err != nil { + return err + } + respBytes, err = s.sendText(uint64(reqID), envelope) + if err != nil { + return err + } } var resp schemalessResp err = client.JsonI.Unmarshal(respBytes, &resp) @@ -133,13 +194,16 @@ func (s *Schemaless) Close() { }) } -func (s *Schemaless) connect() error { - reqID := uint64(common.GetReqID()) +var ( + ConnectTimeoutErr = errors.New("schemaless connect timeout") +) + +func connect(ws *websocket.Conn, user string, password string, db string, writeTimeout time.Duration, readTimeout time.Duration) error { req := &wsConnectReq{ - ReqID: reqID, - User: s.user, - Password: s.password, - DB: s.db, + ReqID: 0, + User: user, + Password: password, + DB: db, } args, err := client.JsonI.Marshal(req) if err != nil { @@ -149,14 +213,29 @@ func (s *Schemaless) connect() error { Action: connAction, Args: args, } - envelope := s.client.GetEnvelope() - err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + connectAction, err := client.JsonI.Marshal(action) if err != nil { - s.client.PutEnvelope(envelope) return err } - - respBytes, err := s.sendText(reqID, envelope) + ws.SetWriteDeadline(time.Now().Add(writeTimeout)) + err = ws.WriteMessage(websocket.TextMessage, connectAction) + if err != nil { + return err + } + done := make(chan struct{}) + ctx, cancel := context.WithTimeout(context.Background(), readTimeout) + var respBytes []byte + go func() { + _, respBytes, err = ws.ReadMessage() + close(done) + }() + select { + case <-done: + cancel() + case <-ctx.Done(): + cancel() + return ConnectTimeoutErr + } if err != nil { return err } @@ -182,7 +261,20 @@ func (s *Schemaless) send(reqID uint64, envelope *client.Envelope) ([]byte, erro channel: make(chan []byte, 1), } element := s.addMessageOutChan(channel) - s.client.Send(envelope) + err := s.client.Send(envelope) + if err != nil { + s.lock.Lock() + s.sendList.Remove(element) + s.lock.Unlock() + return nil, err + } + err = <-envelope.ErrorChan + if err != nil { + s.lock.Lock() + s.sendList.Remove(element) + s.lock.Unlock() + return nil, err + } ctx, cancel := context.WithTimeout(context.Background(), s.readTimeout) defer cancel() select { @@ -259,5 +351,4 @@ func (s *Schemaless) handleError(err error) { if s.errorHandler != nil { s.errorHandler(err) } - s.Close() } diff --git a/ws/schemaless/schemaless_test.go b/ws/schemaless/schemaless_test.go index 83cf47d..1d7cc8f 100644 --- a/ws/schemaless/schemaless_test.go +++ b/ws/schemaless/schemaless_test.go @@ -1,10 +1,15 @@ package schemaless import ( + "errors" "fmt" "io/ioutil" "net/http" + "os" + "os/exec" + "runtime" "strings" + "syscall" "testing" "time" @@ -133,3 +138,97 @@ func before() error { func after() error { return doRequest("drop database test_schemaless_ws") } + +func newTaosadapter(port string) *exec.Cmd { + command := "taosadapter" + if runtime.GOOS == "windows" { + command = "C:\\TDengine\\taosadapter.exe" + + } + return exec.Command(command, "--port", port, "--logLevel", "debug") +} + +func startTaosadapter(cmd *exec.Cmd, port string) error { + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err := cmd.Start() + if err != nil { + return err + } + for i := 0; i < 30; i++ { + time.Sleep(time.Millisecond * 100) + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/-/ping", port)) + if err != nil { + continue + } + resp.Body.Close() + time.Sleep(time.Second) + return nil + } + return errors.New("taosadapter start failed") +} + +func stopTaosadapter(cmd *exec.Cmd) { + if cmd.Process == nil { + return + } + cmd.Process.Signal(syscall.SIGINT) + cmd.Process.Wait() + cmd.Process = nil + time.Sleep(time.Second) +} + +func TestSchemalessReconnect(t *testing.T) { + port := "36041" + cmd := newTaosadapter(port) + err := startTaosadapter(cmd, port) + if err != nil { + t.Fatal(err) + } + defer func() { + stopTaosadapter(cmd) + }() + err = doRequest("drop database if exists test_schemaless_reconnect") + if err != nil { + t.Fatal(err) + } + err = doRequest("create database if not exists test_schemaless_reconnect") + if err != nil { + t.Fatal(err) + } + s, err := NewSchemaless(NewConfig(fmt.Sprintf("ws://localhost:%s", port), 1, + SetDb("test_schemaless_reconnect"), + SetReadTimeout(10*time.Second), + SetWriteTimeout(10*time.Second), + SetUser("root"), + SetPassword("taosdata"), + //SetEnableCompression(true), + SetErrorHandler(func(err error) { + t.Log(err) + }), + SetAutoReconnect(true), + SetReconnectIntervalMs(2000), + SetReconnectRetryCount(3), + )) + if err != nil { + t.Fatal(err) + } + stopTaosadapter(cmd) + time.Sleep(time.Second * 5) + go func() { + time.Sleep(time.Second * 3) + err = startTaosadapter(cmd, port) + if err != nil { + t.Error(err) + return + } + }() + data := "measurement,host=host1 field1=2i,field2=2.0 1577837300000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837400000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837500000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837600000" + err = s.Insert(data, InfluxDBLineProtocol, "ms", 0, 0) + if err != nil { + t.Fatal(err) + } +} diff --git a/ws/stmt/config.go b/ws/stmt/config.go index 7eab614..3b533cc 100644 --- a/ws/stmt/config.go +++ b/ws/stmt/config.go @@ -6,22 +6,27 @@ import ( ) type Config struct { - Url string - ChanLength uint - MessageTimeout time.Duration - WriteWait time.Duration - ErrorHandler func(connector *Connector, err error) - CloseHandler func() - User string - Password string - DB string - EnableCompression bool + Url string + ChanLength uint + MessageTimeout time.Duration + WriteWait time.Duration + ErrorHandler func(connector *Connector, err error) + CloseHandler func() + User string + Password string + DB string + EnableCompression bool + AutoReconnect bool + ReconnectIntervalMs int + ReconnectRetryCount int } func NewConfig(url string, chanLength uint) *Config { return &Config{ - Url: url, - ChanLength: chanLength, + Url: url, + ChanLength: chanLength, + ReconnectRetryCount: 3, + ReconnectIntervalMs: 2000, } } func (c *Config) SetConnectUser(user string) error { @@ -65,3 +70,15 @@ func (c *Config) SetCloseHandler(f func()) { func (c *Config) SetEnableCompression(enableCompression bool) { c.EnableCompression = enableCompression } + +func (c *Config) SetAutoReconnect(reconnect bool) { + c.AutoReconnect = reconnect +} + +func (c *Config) SetReconnectIntervalMs(reconnectIntervalMs int) { + c.ReconnectIntervalMs = reconnectIntervalMs +} + +func (c *Config) SetReconnectRetryCount(reconnectRetryCount int) { + c.ReconnectRetryCount = reconnectRetryCount +} diff --git a/ws/stmt/connector.go b/ws/stmt/connector.go index 0cb2d39..f92596b 100644 --- a/ws/stmt/connector.go +++ b/ws/stmt/connector.go @@ -19,17 +19,26 @@ import ( ) type Connector struct { - client *client.Client - requestID uint64 - listLock sync.RWMutex - sendChanList *list.List - writeTimeout time.Duration - readTimeout time.Duration - config *Config - closeOnce sync.Once - closeChan chan struct{} - customErrorHandler func(*Connector, error) - customCloseHandler func() + client *client.Client + requestID uint64 + listLock sync.RWMutex + sendChanList *list.List + writeTimeout time.Duration + readTimeout time.Duration + config *Config + closeOnce sync.Once + closeChan chan struct{} + customErrorHandler func(*Connector, error) + customCloseHandler func() + url string + chanLength uint + dialer *websocket.Dialer + autoReconnect bool + reconnectIntervalMs int + reconnectRetryCount int + user string + password string + db string } var ( @@ -66,15 +75,58 @@ func NewConnector(config *Config) (*Connector, error) { if config.MessageTimeout <= 0 { config.MessageTimeout = common.DefaultMessageTimeout } + err = connect(ws, config.User, config.Password, config.DB, writeTimeout, readTimeout) + if err != nil { + return nil, err + } + wsClient := client.NewClient(ws, config.ChanLength) + connector = &Connector{ + client: wsClient, + requestID: 0, + listLock: sync.RWMutex{}, + sendChanList: list.New(), + writeTimeout: writeTimeout, + readTimeout: readTimeout, + config: config, + closeOnce: sync.Once{}, + closeChan: make(chan struct{}), + customErrorHandler: config.ErrorHandler, + customCloseHandler: config.CloseHandler, + url: u.String(), + dialer: &dialer, + chanLength: config.ChanLength, + autoReconnect: config.AutoReconnect, + reconnectIntervalMs: config.ReconnectIntervalMs, + reconnectRetryCount: config.ReconnectRetryCount, + user: config.User, + password: config.Password, + db: config.DB, + } + connector.initClient(connector.client) + return connector, nil +} + +func (c *Connector) initClient(client *client.Client) { + if c.writeTimeout > 0 { + client.WriteWait = c.writeTimeout + } + client.TextMessageHandler = c.handleTextMessage + client.BinaryMessageHandler = c.handleBinaryMessage + client.ErrorHandler = c.handleError + go client.WritePump() + go client.ReadPump() +} + +func connect(ws *websocket.Conn, user string, password string, db string, writeTimeout time.Duration, readTimeout time.Duration) error { req := &ConnectReq{ ReqID: 0, - User: config.User, - Password: config.Password, - DB: config.DB, + User: user, + Password: password, + DB: db, } args, err := client.JsonI.Marshal(req) if err != nil { - return nil, err + return err } action := &client.WSAction{ Action: STMTConnect, @@ -82,12 +134,12 @@ func NewConnector(config *Config) (*Connector, error) { } connectAction, err := client.JsonI.Marshal(action) if err != nil { - return nil, err + return err } ws.SetWriteDeadline(time.Now().Add(writeTimeout)) err = ws.WriteMessage(websocket.TextMessage, connectAction) if err != nil { - return nil, err + return err } done := make(chan struct{}) ctx, cancel := context.WithTimeout(context.Background(), readTimeout) @@ -101,41 +153,20 @@ func NewConnector(config *Config) (*Connector, error) { cancel() case <-ctx.Done(): cancel() - return nil, ConnectTimeoutErr + return ConnectTimeoutErr } if err != nil { - return nil, err + return err } var resp ConnectResp err = client.JsonI.Unmarshal(respBytes, &resp) if err != nil { - return nil, err + return err } if resp.Code != 0 { - return nil, taosErrors.NewError(resp.Code, resp.Message) - } - wsClient := client.NewClient(ws, config.ChanLength) - wsClient.WriteWait = writeTimeout - connector = &Connector{ - client: wsClient, - requestID: 0, - listLock: sync.RWMutex{}, - sendChanList: list.New(), - writeTimeout: writeTimeout, - readTimeout: readTimeout, - config: config, - closeOnce: sync.Once{}, - closeChan: make(chan struct{}), - customErrorHandler: config.ErrorHandler, - customCloseHandler: config.CloseHandler, + return taosErrors.NewError(resp.Code, resp.Message) } - - wsClient.TextMessageHandler = connector.handleTextMessage - wsClient.BinaryMessageHandler = connector.handleBinaryMessage - wsClient.ErrorHandler = connector.handleError - go wsClient.WritePump() - go wsClient.ReadPump() - return connector, nil + return nil } func (c *Connector) handleTextMessage(message []byte) { @@ -191,7 +222,20 @@ func (c *Connector) send(reqID uint64, envelope *client.Envelope) ([]byte, error channel: make(chan []byte, 1), } element := c.addMessageOutChan(channel) - c.client.Send(envelope) + err := c.client.Send(envelope) + if err != nil { + c.listLock.Lock() + c.sendChanList.Remove(element) + c.listLock.Unlock() + return nil, err + } + err = <-envelope.ErrorChan + if err != nil { + c.listLock.Lock() + c.sendChanList.Remove(element) + c.listLock.Unlock() + return nil, err + } ctx, cancel := context.WithTimeout(context.Background(), c.readTimeout) defer cancel() select { @@ -210,6 +254,7 @@ func (c *Connector) send(reqID uint64, envelope *client.Envelope) ([]byte, error func (c *Connector) sendTextWithoutResp(envelope *client.Envelope) { envelope.Type = websocket.TextMessage c.client.Send(envelope) + <-envelope.ErrorChan } func (c *Connector) findOutChanByID(index uint64) *list.Element { @@ -244,13 +289,45 @@ func (c *Connector) handleError(err error) { if c.customErrorHandler != nil { c.customErrorHandler(c, err) } - c.Close() + //c.Close() } func (c *Connector) generateReqID() uint64 { return atomic.AddUint64(&c.requestID, 1) } +func (c *Connector) reconnect() error { + reconnected := false + for i := 0; i < c.reconnectRetryCount; i++ { + time.Sleep(time.Duration(c.reconnectIntervalMs) * time.Millisecond) + conn, _, err := c.dialer.Dial(c.url, nil) + if err != nil { + continue + } + conn.EnableWriteCompression(c.dialer.EnableCompression) + err = connect(conn, c.user, c.password, c.db, c.writeTimeout, c.readTimeout) + if err != nil { + conn.Close() + continue + } + if c.client != nil { + c.client.Close() + } + cl := client.NewClient(conn, c.chanLength) + c.initClient(cl) + c.client = cl + reconnected = true + break + } + if !reconnected { + if c.client != nil { + c.client.Close() + } + return errors.New("reconnect failed") + } + return nil +} + func (c *Connector) Init() (*Stmt, error) { reqID := c.generateReqID() req := &InitReq{ @@ -264,15 +341,25 @@ func (c *Connector) Init() (*Stmt, error) { Action: STMTInit, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return nil, err } respBytes, err := c.sendText(reqID, envelope) if err != nil { - return nil, err + if !c.autoReconnect { + return nil, err + } + err = c.reconnect() + if err != nil { + return nil, err + } + respBytes, err = c.sendText(reqID, envelope) + if err != nil { + return nil, err + } } var resp InitResp err = client.JsonI.Unmarshal(respBytes, &resp) diff --git a/ws/stmt/rows.go b/ws/stmt/rows.go index 78f6c75..5247b55 100644 --- a/ws/stmt/rows.go +++ b/ws/stmt/rows.go @@ -23,6 +23,7 @@ type Rows struct { resultID uint64 block []byte conn *Connector + client *client.Client fieldsCount int fieldsNames []string fieldsTypes []uint8 @@ -88,10 +89,10 @@ func (rs *Rows) taosFetchBlock() error { Args: args, } rs.buf.Reset() - envelope := rs.conn.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - rs.conn.client.PutEnvelope(envelope) return err } respBytes, err := rs.conn.sendText(reqID, envelope) @@ -129,10 +130,10 @@ func (rs *Rows) fetchBlock() error { Args: args, } rs.buf.Reset() - envelope := rs.conn.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - rs.conn.client.PutEnvelope(envelope) return err } respBytes, err := rs.conn.sendText(rs.resultID, envelope) @@ -160,10 +161,10 @@ func (rs *Rows) freeResult() error { Args: args, } rs.buf.Reset() - envelope := rs.conn.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - rs.conn.client.PutEnvelope(envelope) return err } rs.conn.sendTextWithoutResp(envelope) diff --git a/ws/stmt/stmt.go b/ws/stmt/stmt.go index 9833647..373b763 100644 --- a/ws/stmt/stmt.go +++ b/ws/stmt/stmt.go @@ -31,10 +31,10 @@ func (s *Stmt) Prepare(sql string) error { Action: STMTPrepare, Args: args, } - envelope := s.connector.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - s.connector.client.PutEnvelope(envelope) return err } respBytes, err := s.connector.sendText(reqID, envelope) @@ -67,10 +67,10 @@ func (s *Stmt) SetTableName(name string) error { Action: STMTSetTableName, Args: args, } - envelope := s.connector.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - s.connector.client.PutEnvelope(envelope) return err } respBytes, err := s.connector.sendText(reqID, envelope) @@ -103,7 +103,8 @@ func (s *Stmt) SetTags(tags *param.Param, bindType *param.ColumnType) error { binary.LittleEndian.PutUint64(reqData, reqID) binary.LittleEndian.PutUint64(reqData[8:], s.id) binary.LittleEndian.PutUint64(reqData[16:], SetTagsMessage) - envelope := s.connector.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) envelope.Msg.Grow(24 + len(block)) envelope.Msg.Write(reqData) envelope.Msg.Write(block) @@ -132,13 +133,13 @@ func (s *Stmt) BindParam(params []*param.Param, bindType *param.ColumnType) erro binary.LittleEndian.PutUint64(reqData, reqID) binary.LittleEndian.PutUint64(reqData[8:], s.id) binary.LittleEndian.PutUint64(reqData[16:], BindMessage) - envelope := s.connector.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) envelope.Msg.Grow(24 + len(block)) envelope.Msg.Write(reqData) envelope.Msg.Write(block) err = client.JsonI.NewEncoder(envelope.Msg).Encode(reqData) if err != nil { - s.connector.client.PutEnvelope(envelope) return err } respBytes, err := s.connector.sendBinary(reqID, envelope) @@ -170,10 +171,10 @@ func (s *Stmt) AddBatch() error { Action: STMTAddBatch, Args: args, } - envelope := s.connector.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - s.connector.client.PutEnvelope(envelope) return err } respBytes, err := s.connector.sendText(reqID, envelope) @@ -205,10 +206,10 @@ func (s *Stmt) Exec() error { Action: STMTExec, Args: args, } - envelope := s.connector.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - s.connector.client.PutEnvelope(envelope) return err } respBytes, err := s.connector.sendText(reqID, envelope) @@ -245,10 +246,10 @@ func (s *Stmt) UseResult() (*Rows, error) { Action: STMTUseResult, Args: args, } - envelope := s.connector.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - s.connector.client.PutEnvelope(envelope) return nil, err } respBytes, err := s.connector.sendText(reqID, envelope) @@ -266,6 +267,7 @@ func (s *Stmt) UseResult() (*Rows, error) { return &Rows{ buf: &bytes.Buffer{}, conn: s.connector, + client: s.connector.client, resultID: resp.ResultID, fieldsCount: resp.FieldsCount, fieldsNames: resp.FieldsNames, @@ -289,10 +291,10 @@ func (s *Stmt) Close() error { Action: STMTClose, Args: args, } - envelope := s.connector.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - s.connector.client.PutEnvelope(envelope) return err } s.connector.sendTextWithoutResp(envelope) diff --git a/ws/stmt/stmt_test.go b/ws/stmt/stmt_test.go index 7cd9496..6baefc5 100644 --- a/ws/stmt/stmt_test.go +++ b/ws/stmt/stmt_test.go @@ -2,11 +2,16 @@ package stmt import ( "database/sql/driver" + "errors" "fmt" "io" "io/ioutil" "net/http" + "os" + "os/exec" + "runtime" "strings" + "syscall" "testing" "time" @@ -1017,3 +1022,84 @@ func TestSTMTQuery(t *testing.T) { assert.Equal(t, "tb2", row3[27]) } } + +func newTaosadapter(port string) *exec.Cmd { + command := "taosadapter" + if runtime.GOOS == "windows" { + command = "C:\\TDengine\\taosadapter.exe" + + } + return exec.Command(command, "--port", port) +} + +func startTaosadapter(cmd *exec.Cmd, port string) error { + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err := cmd.Start() + if err != nil { + return err + } + for i := 0; i < 10; i++ { + time.Sleep(time.Millisecond * 100) + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/-/ping", port)) + if err != nil { + continue + } + resp.Body.Close() + time.Sleep(time.Second) + return nil + } + return errors.New("taosadapter start failed") +} + +func stopTaosadapter(cmd *exec.Cmd) { + if cmd.Process == nil { + return + } + cmd.Process.Signal(syscall.SIGINT) + cmd.Process.Wait() + cmd.Process = nil +} + +func TestSTMTReconnect(t *testing.T) { + port := "36042" + cmd := newTaosadapter(port) + err := startTaosadapter(cmd, port) + if err != nil { + t.Fatal(err) + } + defer func() { + stopTaosadapter(cmd) + }() + config := NewConfig("ws://127.0.0.1:"+port, 0) + config.SetConnectUser("root") + config.SetConnectPass("taosdata") + config.SetMessageTimeout(10 * time.Second) + config.SetWriteWait(common.DefaultWriteWait) + config.SetEnableCompression(true) + config.SetErrorHandler(func(connector *Connector, err error) { + t.Log(err) + }) + config.SetCloseHandler(func() { + t.Log("stmt websocket closed") + }) + config.SetAutoReconnect(true) + config.SetReconnectRetryCount(3) + config.SetReconnectIntervalMs(2000) + connector, err := NewConnector(config) + if err != nil { + t.Error(err) + return + } + stmt, err := connector.Init() + assert.NoError(t, err) + stmt.Close() + stopTaosadapter(cmd) + go func() { + time.Sleep(time.Second * 3) + startTaosadapter(cmd, port) + }() + stmt, err = connector.Init() + assert.NoError(t, err) + stmt.Close() +} diff --git a/ws/tmq/config.go b/ws/tmq/config.go index 99c96a3..e119dcf 100644 --- a/ws/tmq/config.go +++ b/ws/tmq/config.go @@ -20,6 +20,9 @@ type config struct { SnapshotEnable string WithTableName string EnableCompression bool + AutoReconnect bool + ReconnectIntervalMs int + ReconnectRetryCount int } func newConfig(url string, chanLength uint) *config { @@ -84,3 +87,15 @@ func (c *config) setWithTableName(withTableName string) { func (c *config) setEnableCompression(enableCompression bool) { c.EnableCompression = enableCompression } + +func (c *config) setAutoReconnect(autoReconnect bool) { + c.AutoReconnect = autoReconnect +} + +func (c *config) setReconnectIntervalMs(reconnectIntervalMs int) { + c.ReconnectIntervalMs = reconnectIntervalMs +} + +func (c *config) setReconnectRetryCount(reconnectRetryCount int) { + c.ReconnectRetryCount = reconnectRetryCount +} diff --git a/ws/tmq/consumer.go b/ws/tmq/consumer.go index b4f3e8b..5352c2c 100644 --- a/ws/tmq/consumer.go +++ b/ws/tmq/consumer.go @@ -23,27 +23,33 @@ import ( ) type Consumer struct { - client *client.Client - requestID uint64 - err error - dataParser *parser.TMQRawDataParser - listLock sync.RWMutex - sendChanList *list.List - messageTimeout time.Duration - autoCommit bool - autoCommitInterval time.Duration - nextAutoCommitTime time.Time - url string - user string - password string - groupID string - clientID string - offsetRest string - snapshotEnable string - withTableName string - closeOnce sync.Once - closeChan chan struct{} - topics []string + client *client.Client + requestID uint64 + err error + dataParser *parser.TMQRawDataParser + listLock sync.RWMutex + sendChanList *list.List + messageTimeout time.Duration + autoCommit bool + autoCommitInterval time.Duration + nextAutoCommitTime time.Time + url string + user string + password string + groupID string + clientID string + offsetRest string + snapshotEnable string + withTableName string + closeOnce sync.Once + closeChan chan struct{} + topics []string + autoReconnect bool + reconnectIntervalMs int + reconnectRetryCount int + chanLength uint + writeWait time.Duration + dialer *websocket.Dialer } type IndexedChan struct { @@ -94,34 +100,75 @@ func NewConsumer(conf *tmq.ConfigMap) (*Consumer, error) { wsClient := client.NewClient(ws, config.ChanLength) consumer := &Consumer{ - client: wsClient, - requestID: 0, - sendChanList: list.New(), - messageTimeout: config.MessageTimeout, - url: config.Url, - user: config.User, - password: config.Password, - groupID: config.GroupID, - clientID: config.ClientID, - offsetRest: config.OffsetRest, - autoCommit: autoCommit, - autoCommitInterval: autoCommitInterval, - snapshotEnable: config.SnapshotEnable, - withTableName: config.WithTableName, - closeChan: make(chan struct{}), - dataParser: parser.NewTMQRawDataParser(), - } - if config.WriteWait > 0 { - wsClient.WriteWait = config.WriteWait - } - wsClient.BinaryMessageHandler = consumer.handleBinaryMessage - wsClient.TextMessageHandler = consumer.handleTextMessage - wsClient.ErrorHandler = consumer.handleError - go wsClient.WritePump() - go wsClient.ReadPump() + client: wsClient, + requestID: 0, + sendChanList: list.New(), + messageTimeout: config.MessageTimeout, + url: u.String(), + user: config.User, + password: config.Password, + groupID: config.GroupID, + clientID: config.ClientID, + offsetRest: config.OffsetRest, + autoCommit: autoCommit, + autoCommitInterval: autoCommitInterval, + snapshotEnable: config.SnapshotEnable, + withTableName: config.WithTableName, + closeChan: make(chan struct{}), + dataParser: parser.NewTMQRawDataParser(), + autoReconnect: config.AutoReconnect, + reconnectIntervalMs: config.ReconnectIntervalMs, + reconnectRetryCount: config.ReconnectRetryCount, + chanLength: config.ChanLength, + writeWait: config.WriteWait, + dialer: &dialer, + } + consumer.initClient(consumer.client) return consumer, nil } +func (c *Consumer) initClient(client *client.Client) { + if c.writeWait > 0 { + client.WriteWait = c.writeWait + } + client.BinaryMessageHandler = c.handleBinaryMessage + client.TextMessageHandler = c.handleTextMessage + client.ErrorHandler = c.handleError + go client.WritePump() + go client.ReadPump() +} + +func (c *Consumer) reconnect() error { + reconnected := false + for i := 0; i < c.reconnectRetryCount; i++ { + time.Sleep(time.Duration(c.reconnectIntervalMs) * time.Millisecond) + conn, _, err := c.dialer.Dial(c.url, nil) + if err != nil { + continue + } + conn.EnableWriteCompression(c.dialer.EnableCompression) + cl := client.NewClient(conn, c.chanLength) + c.initClient(cl) + if c.client != nil { + c.client.Close() + } + c.client = cl + if len(c.topics) > 0 { + err = c.doSubscribe(c.topics, false) + if err != nil { + c.client.Close() + continue + } + } + reconnected = true + break + } + if !reconnected { + return errors.New("reconnect failed") + } + return nil +} + func configMapToConfig(m *tmq.ConfigMap) (*config, error) { url, err := m.Get("ws.url", "") if err != nil { @@ -183,6 +230,18 @@ func configMapToConfig(m *tmq.ConfigMap) (*config, error) { if err != nil { return nil, err } + autoReconnect, err := m.Get("ws.autoReconnect", false) + if err != nil { + return nil, err + } + reconnectIntervalMs, err := m.Get("ws.reconnectIntervalMs", int(2000)) + if err != nil { + return nil, err + } + reconnectRetryCount, err := m.Get("ws.reconnectRetryCount", int(3)) + if err != nil { + return nil, err + } config := newConfig(url.(string), chanLen.(uint)) err = config.setMessageTimeout(messageTimeout.(time.Duration)) if err != nil { @@ -202,6 +261,9 @@ func configMapToConfig(m *tmq.ConfigMap) (*config, error) { config.setSnapshotEnable(enableSnapshot.(string)) config.setWithTableName(withTableName.(string)) config.setEnableCompression(enableCompression.(bool)) + config.setAutoReconnect(autoReconnect.(bool)) + config.setReconnectIntervalMs(reconnectIntervalMs.(int)) + config.setReconnectRetryCount(reconnectRetryCount.(int)) return config, nil } @@ -240,8 +302,9 @@ func (c *Consumer) handleBinaryMessage(message []byte) { } func (c *Consumer) handleError(err error) { - c.err = &WSError{err: err} - c.Close() + if !c.autoReconnect { + c.err = &WSError{err: err} + } } func (c *Consumer) generateReqID() uint64 { @@ -303,7 +366,6 @@ var ClosedErr = errors.New("connection closed") func (c *Consumer) sendText(reqID uint64, envelope *client.Envelope) ([]byte, error) { if !c.client.IsRunning() { - c.client.PutEnvelope(envelope) return nil, ClosedErr } channel := &IndexedChan{ @@ -312,7 +374,20 @@ func (c *Consumer) sendText(reqID uint64, envelope *client.Envelope) ([]byte, er } element := c.addMessageOutChan(channel) envelope.Type = websocket.TextMessage - c.client.Send(envelope) + err := c.client.Send(envelope) + if err != nil { + c.listLock.Lock() + c.sendChanList.Remove(element) + c.listLock.Unlock() + return nil, err + } + err = <-envelope.ErrorChan + if err != nil { + c.listLock.Lock() + c.sendChanList.Remove(element) + c.listLock.Unlock() + return nil, err + } ctx, cancel := context.WithTimeout(context.Background(), c.messageTimeout) defer cancel() select { @@ -335,6 +410,10 @@ func (c *Consumer) Subscribe(topic string, rebalanceCb RebalanceCb) error { } func (c *Consumer) SubscribeTopics(topics []string, rebalanceCb RebalanceCb) error { + return c.doSubscribe(topics, c.autoReconnect) +} + +func (c *Consumer) doSubscribe(topics []string, reconnect bool) error { if c.err != nil { return c.err } @@ -359,15 +438,25 @@ func (c *Consumer) SubscribeTopics(topics []string, rebalanceCb RebalanceCb) err Action: TMQSubscribe, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return err } respBytes, err := c.sendText(reqID, envelope) if err != nil { - return err + if !reconnect { + return err + } + err = c.reconnect() + if err != nil { + return err + } + respBytes, err = c.sendText(reqID, envelope) + if err != nil { + return err + } } var resp SubscribeResp err = client.JsonI.Unmarshal(respBytes, &resp) @@ -410,15 +499,25 @@ func (c *Consumer) Poll(timeoutMs int) tmq.Event { Action: TMQPoll, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return tmq.NewTMQErrorWithErr(err) } respBytes, err := c.sendText(reqID, envelope) if err != nil { - return tmq.NewTMQErrorWithErr(err) + if !c.autoReconnect { + return tmq.NewTMQErrorWithErr(err) + } + err = c.reconnect() + if err != nil { + return tmq.NewTMQErrorWithErr(err) + } + respBytes, err = c.sendText(reqID, envelope) + if err != nil { + return tmq.NewTMQErrorWithErr(err) + } } var resp PollResp err = client.JsonI.Unmarshal(respBytes, &resp) @@ -510,10 +609,10 @@ func (c *Consumer) fetchJsonMeta(messageID uint64) (*tmq.Meta, error) { Action: TMQFetchJsonMeta, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return nil, err } respBytes, err := c.sendText(reqID, envelope) @@ -550,10 +649,10 @@ func (c *Consumer) fetch(messageID uint64) ([]*tmq.Data, error) { Action: TMQFetchRaw, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return nil, err } respBytes, err := c.sendText(reqID, envelope) @@ -603,10 +702,10 @@ func (c *Consumer) doCommit() error { Action: TMQCommit, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return err } respBytes, err := c.sendText(reqID, envelope) @@ -640,10 +739,10 @@ func (c *Consumer) Unsubscribe() error { Action: TMQUnsubscribe, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return err } respBytes, err := c.sendText(reqID, envelope) @@ -679,10 +778,10 @@ func (c *Consumer) Assignment() (partitions []tmq.TopicPartition, err error) { Action: TMQGetTopicAssignment, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return nil, err } respBytes, err := c.sendText(reqID, envelope) @@ -729,10 +828,10 @@ func (c *Consumer) Seek(partition tmq.TopicPartition, ignoredTimeoutMs int) erro Action: TMQSeek, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return err } respBytes, err := c.sendText(reqID, envelope) @@ -771,10 +870,10 @@ func (c *Consumer) Committed(partitions []tmq.TopicPartition, timeoutMs int) (of Action: TMQCommitted, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return nil, err } respBytes, err := c.sendText(reqID, envelope) @@ -803,6 +902,8 @@ func (c *Consumer) CommitOffsets(offsets []tmq.TopicPartition) ([]tmq.TopicParti if c.err != nil { return nil, c.err } + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) for i := 0; i < len(offsets); i++ { reqID := c.generateReqID() req := &CommitOffsetReq{ @@ -819,10 +920,9 @@ func (c *Consumer) CommitOffsets(offsets []tmq.TopicPartition) ([]tmq.TopicParti Action: TMQCommitOffset, Args: args, } - envelope := c.client.GetEnvelope() + envelope.Reset() err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return nil, err } respBytes, err := c.sendText(reqID, envelope) @@ -862,10 +962,10 @@ func (c *Consumer) Position(partitions []tmq.TopicPartition) (offsets []tmq.Topi Action: TMQPosition, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return nil, err } respBytes, err := c.sendText(reqID, envelope) diff --git a/ws/tmq/consumer_test.go b/ws/tmq/consumer_test.go index 6407984..ef8e0ac 100644 --- a/ws/tmq/consumer_test.go +++ b/ws/tmq/consumer_test.go @@ -1,10 +1,15 @@ package tmq import ( + "errors" "fmt" "io/ioutil" "net/http" + "os" + "os/exec" + "runtime" "strings" + "syscall" "testing" "time" @@ -819,3 +824,134 @@ func TestMeta(t *testing.T) { } } } + +func newTaosadapter(port string) *exec.Cmd { + command := "taosadapter" + if runtime.GOOS == "windows" { + command = "C:\\TDengine\\taosadapter.exe" + + } + return exec.Command(command, "--port", port) +} + +func startTaosadapter(cmd *exec.Cmd, port string) error { + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err := cmd.Start() + if err != nil { + return err + } + for i := 0; i < 10; i++ { + time.Sleep(time.Millisecond * 100) + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/-/ping", port)) + if err != nil { + continue + } + resp.Body.Close() + time.Sleep(time.Second) + return nil + } + return errors.New("taosadapter start failed") +} + +func stopTaosadapter(cmd *exec.Cmd) { + if cmd.Process == nil { + return + } + cmd.Process.Signal(syscall.SIGINT) + cmd.Process.Wait() + cmd.Process = nil +} + +func prepareSubReconnectEnv() error { + var err error + steps := []string{ + "drop topic if exists test_ws_tmq_sub_reconnect_topic", + "drop database if exists test_ws_tmq_sub_reconnect", + "create database test_ws_tmq_sub_reconnect vgroups 1 WAL_RETENTION_PERIOD 86400", + "create topic test_ws_tmq_sub_reconnect_topic as database test_ws_tmq_sub_reconnect", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func cleanSubReconnectEnv() error { + var err error + time.Sleep(2 * time.Second) + steps := []string{ + "drop topic if exists test_ws_tmq_sub_reconnect_topic", + "drop database if exists test_ws_tmq_sub_reconnect", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func TestSubscribeReconnect(t *testing.T) { + port := "36043" + cmd := newTaosadapter(port) + err := startTaosadapter(cmd, port) + assert.NoError(t, err) + defer func() { + stopTaosadapter(cmd) + }() + prepareSubReconnectEnv() + defer cleanSubReconnectEnv() + consumer, err := NewConsumer(&tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:" + port, + "ws.message.channelLen": uint(0), + "ws.message.timeout": time.Second * 10, + "ws.message.writeWait": common.DefaultWriteWait, + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "group.id": "test", + "client.id": "test_consumer", + "auto.offset.reset": "earliest", + "enable.auto.commit": "true", + "auto.commit.interval.ms": "1000", + "msg.with.table.name": "true", + "ws.autoReconnect": true, + "ws.reconnectIntervalMs": 3000, + "ws.reconnectRetryCount": 3, + }) + assert.NoError(t, err) + stopTaosadapter(cmd) + go func() { + time.Sleep(time.Second * 3) + startTaosadapter(cmd, port) + }() + err = consumer.Subscribe("test_ws_tmq_sub_reconnect_topic", nil) + assert.NoError(t, err) + doRequest("create table test_ws_tmq_sub_reconnect.st(ts timestamp,v int) tags (cn binary(20))") + doRequest("create table test_ws_tmq_sub_reconnect.t1 using test_ws_tmq_sub_reconnect.st tags ('t1')") + doRequest("insert into test_ws_tmq_sub_reconnect.t1 values (now,1)") + stopTaosadapter(cmd) + go func() { + time.Sleep(time.Second * 3) + startTaosadapter(cmd, port) + }() + haveMessage := false + for i := 0; i < 10; i++ { + event := consumer.Poll(500) + if event == nil { + continue + } + switch e := event.(type) { + case *tmq.DataMessage: + t.Log(e) + assert.Equal(t, "test_ws_tmq_sub_reconnect", e.DBName()) + haveMessage = true + break + } + } + assert.True(t, haveMessage) +} From dd120779d332695b9804d2baa02f871967bd38f2 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Wed, 12 Jun 2024 17:53:44 +0800 Subject: [PATCH 2/2] enh: websocket reconnect --- ws/schemaless/schemaless.go | 18 ++++++++++----- ws/schemaless/schemaless_test.go | 21 ++++++++++++------ ws/stmt/connector.go | 18 ++++++++++----- ws/stmt/stmt_test.go | 16 +++++++++++--- ws/tmq/consumer.go | 38 +++++++++++++++++++------------- ws/tmq/consumer_test.go | 24 ++++++++++++++++++-- 6 files changed, 96 insertions(+), 39 deletions(-) diff --git a/ws/schemaless/schemaless.go b/ws/schemaless/schemaless.go index 7bd3e1a..f22c7a6 100644 --- a/ws/schemaless/schemaless.go +++ b/ws/schemaless/schemaless.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "net" "net/url" "sync" "time" @@ -164,12 +165,17 @@ func (s *Schemaless) Insert(lines string, protocol int, precision string, ttl in if !s.autoReconnect { return err } - err = s.reconnect() - if err != nil { - return err - } - respBytes, err = s.sendText(uint64(reqID), envelope) - if err != nil { + var opError *net.OpError + if errors.Is(err, client.ClosedError) || errors.As(err, &opError) { + err = s.reconnect() + if err != nil { + return err + } + respBytes, err = s.sendText(uint64(reqID), envelope) + if err != nil { + return err + } + } else { return err } } diff --git a/ws/schemaless/schemaless_test.go b/ws/schemaless/schemaless_test.go index 1d7cc8f..dc9caa6 100644 --- a/ws/schemaless/schemaless_test.go +++ b/ws/schemaless/schemaless_test.go @@ -14,6 +14,7 @@ import ( "time" jsoniter "github.com/json-iterator/go" + "github.com/stretchr/testify/assert" taosErrors "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/ws/client" ) @@ -198,8 +199,8 @@ func TestSchemalessReconnect(t *testing.T) { } s, err := NewSchemaless(NewConfig(fmt.Sprintf("ws://localhost:%s", port), 1, SetDb("test_schemaless_reconnect"), - SetReadTimeout(10*time.Second), - SetWriteTimeout(10*time.Second), + SetReadTimeout(3*time.Second), + SetWriteTimeout(3*time.Second), SetUser("root"), SetPassword("taosdata"), //SetEnableCompression(true), @@ -214,10 +215,12 @@ func TestSchemalessReconnect(t *testing.T) { t.Fatal(err) } stopTaosadapter(cmd) - time.Sleep(time.Second * 5) + time.Sleep(time.Second * 3) + startChan := make(chan struct{}) go func() { - time.Sleep(time.Second * 3) + time.Sleep(time.Second * 10) err = startTaosadapter(cmd, port) + startChan <- struct{}{} if err != nil { t.Error(err) return @@ -228,7 +231,11 @@ func TestSchemalessReconnect(t *testing.T) { "measurement,host=host1 field1=2i,field2=2.0 1577837500000\n" + "measurement,host=host1 field1=2i,field2=2.0 1577837600000" err = s.Insert(data, InfluxDBLineProtocol, "ms", 0, 0) - if err != nil { - t.Fatal(err) - } + assert.Error(t, err) + <-startChan + time.Sleep(time.Second) + err = s.Insert(data, InfluxDBLineProtocol, "ms", 0, 0) + assert.NoError(t, err) + err = s.Insert(data, InfluxDBLineProtocol, "ms", 0, 0) + assert.NoError(t, err) } diff --git a/ws/stmt/connector.go b/ws/stmt/connector.go index f92596b..a08cd2e 100644 --- a/ws/stmt/connector.go +++ b/ws/stmt/connector.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "errors" "fmt" + "net" "net/url" "sync" "sync/atomic" @@ -352,12 +353,17 @@ func (c *Connector) Init() (*Stmt, error) { if !c.autoReconnect { return nil, err } - err = c.reconnect() - if err != nil { - return nil, err - } - respBytes, err = c.sendText(reqID, envelope) - if err != nil { + var opError *net.OpError + if errors.Is(err, client.ClosedError) || errors.As(err, &opError) { + err = c.reconnect() + if err != nil { + return nil, err + } + respBytes, err = c.sendText(reqID, envelope) + if err != nil { + return nil, err + } + } else { return nil, err } } diff --git a/ws/stmt/stmt_test.go b/ws/stmt/stmt_test.go index 6baefc5..652766e 100644 --- a/ws/stmt/stmt_test.go +++ b/ws/stmt/stmt_test.go @@ -1074,8 +1074,8 @@ func TestSTMTReconnect(t *testing.T) { config := NewConfig("ws://127.0.0.1:"+port, 0) config.SetConnectUser("root") config.SetConnectPass("taosdata") - config.SetMessageTimeout(10 * time.Second) - config.SetWriteWait(common.DefaultWriteWait) + config.SetMessageTimeout(3 * time.Second) + config.SetWriteWait(3 * time.Second) config.SetEnableCompression(true) config.SetErrorHandler(func(connector *Connector, err error) { t.Log(err) @@ -1095,11 +1095,21 @@ func TestSTMTReconnect(t *testing.T) { assert.NoError(t, err) stmt.Close() stopTaosadapter(cmd) + startChan := make(chan struct{}) go func() { time.Sleep(time.Second * 3) - startTaosadapter(cmd, port) + err = startTaosadapter(cmd, port) + startChan <- struct{}{} + if err != nil { + t.Error(err) + return + } }() stmt, err = connector.Init() + assert.Error(t, err) + <-startChan + time.Sleep(time.Second) + stmt, err = connector.Init() assert.NoError(t, err) stmt.Close() } diff --git a/ws/tmq/consumer.go b/ws/tmq/consumer.go index 5352c2c..64bb1ed 100644 --- a/ws/tmq/consumer.go +++ b/ws/tmq/consumer.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "errors" "fmt" + "net" "net/url" "strconv" "sync" @@ -365,9 +366,6 @@ const ( var ClosedErr = errors.New("connection closed") func (c *Consumer) sendText(reqID uint64, envelope *client.Envelope) ([]byte, error) { - if !c.client.IsRunning() { - return nil, ClosedErr - } channel := &IndexedChan{ index: reqID, channel: make(chan []byte, 1), @@ -449,12 +447,17 @@ func (c *Consumer) doSubscribe(topics []string, reconnect bool) error { if !reconnect { return err } - err = c.reconnect() - if err != nil { - return err - } - respBytes, err = c.sendText(reqID, envelope) - if err != nil { + var opError *net.OpError + if errors.Is(err, ClosedErr) || errors.Is(err, client.ClosedError) || errors.As(err, &opError) { + err = c.reconnect() + if err != nil { + return err + } + respBytes, err = c.sendText(reqID, envelope) + if err != nil { + return err + } + } else { return err } } @@ -510,12 +513,17 @@ func (c *Consumer) Poll(timeoutMs int) tmq.Event { if !c.autoReconnect { return tmq.NewTMQErrorWithErr(err) } - err = c.reconnect() - if err != nil { - return tmq.NewTMQErrorWithErr(err) - } - respBytes, err = c.sendText(reqID, envelope) - if err != nil { + var opError *net.OpError + if errors.Is(err, ClosedErr) || errors.Is(err, client.ClosedError) || errors.As(err, &opError) { + err = c.reconnect() + if err != nil { + return tmq.NewTMQErrorWithErr(err) + } + respBytes, err = c.sendText(reqID, envelope) + if err != nil { + return tmq.NewTMQErrorWithErr(err) + } + } else { return tmq.NewTMQErrorWithErr(err) } } diff --git a/ws/tmq/consumer_test.go b/ws/tmq/consumer_test.go index ef8e0ac..37dd34b 100644 --- a/ws/tmq/consumer_test.go +++ b/ws/tmq/consumer_test.go @@ -909,7 +909,7 @@ func TestSubscribeReconnect(t *testing.T) { consumer, err := NewConsumer(&tmq.ConfigMap{ "ws.url": "ws://127.0.0.1:" + port, "ws.message.channelLen": uint(0), - "ws.message.timeout": time.Second * 10, + "ws.message.timeout": time.Second * 5, "ws.message.writeWait": common.DefaultWriteWait, "td.connect.user": "root", "td.connect.pass": "taosdata", @@ -925,11 +925,22 @@ func TestSubscribeReconnect(t *testing.T) { }) assert.NoError(t, err) stopTaosadapter(cmd) + time.Sleep(time.Second) + startChan := make(chan struct{}) go func() { time.Sleep(time.Second * 3) - startTaosadapter(cmd, port) + err = startTaosadapter(cmd, port) + if err != nil { + t.Error(err) + return + } + startChan <- struct{}{} }() err = consumer.Subscribe("test_ws_tmq_sub_reconnect_topic", nil) + assert.Error(t, err) + <-startChan + time.Sleep(time.Second) + err = consumer.Subscribe("test_ws_tmq_sub_reconnect_topic", nil) assert.NoError(t, err) doRequest("create table test_ws_tmq_sub_reconnect.st(ts timestamp,v int) tags (cn binary(20))") doRequest("create table test_ws_tmq_sub_reconnect.t1 using test_ws_tmq_sub_reconnect.st tags ('t1')") @@ -938,7 +949,14 @@ func TestSubscribeReconnect(t *testing.T) { go func() { time.Sleep(time.Second * 3) startTaosadapter(cmd, port) + startChan <- struct{}{} }() + time.Sleep(time.Second) + event := consumer.Poll(500) + assert.NotNil(t, event) + _, ok := event.(tmq.Error) + assert.True(t, ok) + <-startChan haveMessage := false for i := 0; i < 10; i++ { event := consumer.Poll(500) @@ -951,6 +969,8 @@ func TestSubscribeReconnect(t *testing.T) { assert.Equal(t, "test_ws_tmq_sub_reconnect", e.DBName()) haveMessage = true break + default: + t.Log(e) } } assert.True(t, haveMessage)