diff --git a/ws/schemaless/schemaless.go b/ws/schemaless/schemaless.go index 7bd3e1a..f22c7a6 100644 --- a/ws/schemaless/schemaless.go +++ b/ws/schemaless/schemaless.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "net" "net/url" "sync" "time" @@ -164,12 +165,17 @@ func (s *Schemaless) Insert(lines string, protocol int, precision string, ttl in if !s.autoReconnect { return err } - err = s.reconnect() - if err != nil { - return err - } - respBytes, err = s.sendText(uint64(reqID), envelope) - if err != nil { + var opError *net.OpError + if errors.Is(err, client.ClosedError) || errors.As(err, &opError) { + err = s.reconnect() + if err != nil { + return err + } + respBytes, err = s.sendText(uint64(reqID), envelope) + if err != nil { + return err + } + } else { return err } } diff --git a/ws/schemaless/schemaless_test.go b/ws/schemaless/schemaless_test.go index 1d7cc8f..dc9caa6 100644 --- a/ws/schemaless/schemaless_test.go +++ b/ws/schemaless/schemaless_test.go @@ -14,6 +14,7 @@ import ( "time" jsoniter "github.com/json-iterator/go" + "github.com/stretchr/testify/assert" taosErrors "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/ws/client" ) @@ -198,8 +199,8 @@ func TestSchemalessReconnect(t *testing.T) { } s, err := NewSchemaless(NewConfig(fmt.Sprintf("ws://localhost:%s", port), 1, SetDb("test_schemaless_reconnect"), - SetReadTimeout(10*time.Second), - SetWriteTimeout(10*time.Second), + SetReadTimeout(3*time.Second), + SetWriteTimeout(3*time.Second), SetUser("root"), SetPassword("taosdata"), //SetEnableCompression(true), @@ -214,10 +215,12 @@ func TestSchemalessReconnect(t *testing.T) { t.Fatal(err) } stopTaosadapter(cmd) - time.Sleep(time.Second * 5) + time.Sleep(time.Second * 3) + startChan := make(chan struct{}) go func() { - time.Sleep(time.Second * 3) + time.Sleep(time.Second * 10) err = startTaosadapter(cmd, port) + startChan <- struct{}{} if err != nil { t.Error(err) return @@ -228,7 +231,11 @@ func TestSchemalessReconnect(t *testing.T) { "measurement,host=host1 field1=2i,field2=2.0 1577837500000\n" + "measurement,host=host1 field1=2i,field2=2.0 1577837600000" err = s.Insert(data, InfluxDBLineProtocol, "ms", 0, 0) - if err != nil { - t.Fatal(err) - } + assert.Error(t, err) + <-startChan + time.Sleep(time.Second) + err = s.Insert(data, InfluxDBLineProtocol, "ms", 0, 0) + assert.NoError(t, err) + err = s.Insert(data, InfluxDBLineProtocol, "ms", 0, 0) + assert.NoError(t, err) } diff --git a/ws/stmt/connector.go b/ws/stmt/connector.go index f92596b..a08cd2e 100644 --- a/ws/stmt/connector.go +++ b/ws/stmt/connector.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "errors" "fmt" + "net" "net/url" "sync" "sync/atomic" @@ -352,12 +353,17 @@ func (c *Connector) Init() (*Stmt, error) { if !c.autoReconnect { return nil, err } - err = c.reconnect() - if err != nil { - return nil, err - } - respBytes, err = c.sendText(reqID, envelope) - if err != nil { + var opError *net.OpError + if errors.Is(err, client.ClosedError) || errors.As(err, &opError) { + err = c.reconnect() + if err != nil { + return nil, err + } + respBytes, err = c.sendText(reqID, envelope) + if err != nil { + return nil, err + } + } else { return nil, err } } diff --git a/ws/stmt/stmt_test.go b/ws/stmt/stmt_test.go index 6baefc5..652766e 100644 --- a/ws/stmt/stmt_test.go +++ b/ws/stmt/stmt_test.go @@ -1074,8 +1074,8 @@ func TestSTMTReconnect(t *testing.T) { config := NewConfig("ws://127.0.0.1:"+port, 0) config.SetConnectUser("root") config.SetConnectPass("taosdata") - config.SetMessageTimeout(10 * time.Second) - config.SetWriteWait(common.DefaultWriteWait) + config.SetMessageTimeout(3 * time.Second) + config.SetWriteWait(3 * time.Second) config.SetEnableCompression(true) config.SetErrorHandler(func(connector *Connector, err error) { t.Log(err) @@ -1095,11 +1095,21 @@ func TestSTMTReconnect(t *testing.T) { assert.NoError(t, err) stmt.Close() stopTaosadapter(cmd) + startChan := make(chan struct{}) go func() { time.Sleep(time.Second * 3) - startTaosadapter(cmd, port) + err = startTaosadapter(cmd, port) + startChan <- struct{}{} + if err != nil { + t.Error(err) + return + } }() stmt, err = connector.Init() + assert.Error(t, err) + <-startChan + time.Sleep(time.Second) + stmt, err = connector.Init() assert.NoError(t, err) stmt.Close() } diff --git a/ws/tmq/consumer.go b/ws/tmq/consumer.go index 5352c2c..64bb1ed 100644 --- a/ws/tmq/consumer.go +++ b/ws/tmq/consumer.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "errors" "fmt" + "net" "net/url" "strconv" "sync" @@ -365,9 +366,6 @@ const ( var ClosedErr = errors.New("connection closed") func (c *Consumer) sendText(reqID uint64, envelope *client.Envelope) ([]byte, error) { - if !c.client.IsRunning() { - return nil, ClosedErr - } channel := &IndexedChan{ index: reqID, channel: make(chan []byte, 1), @@ -449,12 +447,17 @@ func (c *Consumer) doSubscribe(topics []string, reconnect bool) error { if !reconnect { return err } - err = c.reconnect() - if err != nil { - return err - } - respBytes, err = c.sendText(reqID, envelope) - if err != nil { + var opError *net.OpError + if errors.Is(err, ClosedErr) || errors.Is(err, client.ClosedError) || errors.As(err, &opError) { + err = c.reconnect() + if err != nil { + return err + } + respBytes, err = c.sendText(reqID, envelope) + if err != nil { + return err + } + } else { return err } } @@ -510,12 +513,17 @@ func (c *Consumer) Poll(timeoutMs int) tmq.Event { if !c.autoReconnect { return tmq.NewTMQErrorWithErr(err) } - err = c.reconnect() - if err != nil { - return tmq.NewTMQErrorWithErr(err) - } - respBytes, err = c.sendText(reqID, envelope) - if err != nil { + var opError *net.OpError + if errors.Is(err, ClosedErr) || errors.Is(err, client.ClosedError) || errors.As(err, &opError) { + err = c.reconnect() + if err != nil { + return tmq.NewTMQErrorWithErr(err) + } + respBytes, err = c.sendText(reqID, envelope) + if err != nil { + return tmq.NewTMQErrorWithErr(err) + } + } else { return tmq.NewTMQErrorWithErr(err) } } diff --git a/ws/tmq/consumer_test.go b/ws/tmq/consumer_test.go index ef8e0ac..37dd34b 100644 --- a/ws/tmq/consumer_test.go +++ b/ws/tmq/consumer_test.go @@ -909,7 +909,7 @@ func TestSubscribeReconnect(t *testing.T) { consumer, err := NewConsumer(&tmq.ConfigMap{ "ws.url": "ws://127.0.0.1:" + port, "ws.message.channelLen": uint(0), - "ws.message.timeout": time.Second * 10, + "ws.message.timeout": time.Second * 5, "ws.message.writeWait": common.DefaultWriteWait, "td.connect.user": "root", "td.connect.pass": "taosdata", @@ -925,11 +925,22 @@ func TestSubscribeReconnect(t *testing.T) { }) assert.NoError(t, err) stopTaosadapter(cmd) + time.Sleep(time.Second) + startChan := make(chan struct{}) go func() { time.Sleep(time.Second * 3) - startTaosadapter(cmd, port) + err = startTaosadapter(cmd, port) + if err != nil { + t.Error(err) + return + } + startChan <- struct{}{} }() err = consumer.Subscribe("test_ws_tmq_sub_reconnect_topic", nil) + assert.Error(t, err) + <-startChan + time.Sleep(time.Second) + err = consumer.Subscribe("test_ws_tmq_sub_reconnect_topic", nil) assert.NoError(t, err) doRequest("create table test_ws_tmq_sub_reconnect.st(ts timestamp,v int) tags (cn binary(20))") doRequest("create table test_ws_tmq_sub_reconnect.t1 using test_ws_tmq_sub_reconnect.st tags ('t1')") @@ -938,7 +949,14 @@ func TestSubscribeReconnect(t *testing.T) { go func() { time.Sleep(time.Second * 3) startTaosadapter(cmd, port) + startChan <- struct{}{} }() + time.Sleep(time.Second) + event := consumer.Poll(500) + assert.NotNil(t, event) + _, ok := event.(tmq.Error) + assert.True(t, ok) + <-startChan haveMessage := false for i := 0; i < 10; i++ { event := consumer.Poll(500) @@ -951,6 +969,8 @@ func TestSubscribeReconnect(t *testing.T) { assert.Equal(t, "test_ws_tmq_sub_reconnect", e.DBName()) haveMessage = true break + default: + t.Log(e) } } assert.True(t, haveMessage)