Skip to content

Commit

Permalink
enh: websocket reconnect
Browse files Browse the repository at this point in the history
  • Loading branch information
huskar-t committed Jun 12, 2024
1 parent befeed1 commit dd12077
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 39 deletions.
18 changes: 12 additions & 6 deletions ws/schemaless/schemaless.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"errors"
"fmt"
"net"
"net/url"
"sync"
"time"
Expand Down Expand Up @@ -164,12 +165,17 @@ func (s *Schemaless) Insert(lines string, protocol int, precision string, ttl in
if !s.autoReconnect {
return err
}
err = s.reconnect()
if err != nil {
return err
}
respBytes, err = s.sendText(uint64(reqID), envelope)
if err != nil {
var opError *net.OpError
if errors.Is(err, client.ClosedError) || errors.As(err, &opError) {
err = s.reconnect()
if err != nil {
return err
}
respBytes, err = s.sendText(uint64(reqID), envelope)
if err != nil {
return err
}
} else {
return err
}
}
Expand Down
21 changes: 14 additions & 7 deletions ws/schemaless/schemaless_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"time"

jsoniter "github.com/json-iterator/go"
"github.com/stretchr/testify/assert"
taosErrors "github.com/taosdata/driver-go/v3/errors"
"github.com/taosdata/driver-go/v3/ws/client"
)
Expand Down Expand Up @@ -198,8 +199,8 @@ func TestSchemalessReconnect(t *testing.T) {
}
s, err := NewSchemaless(NewConfig(fmt.Sprintf("ws://localhost:%s", port), 1,
SetDb("test_schemaless_reconnect"),
SetReadTimeout(10*time.Second),
SetWriteTimeout(10*time.Second),
SetReadTimeout(3*time.Second),
SetWriteTimeout(3*time.Second),
SetUser("root"),
SetPassword("taosdata"),
//SetEnableCompression(true),
Expand All @@ -214,10 +215,12 @@ func TestSchemalessReconnect(t *testing.T) {
t.Fatal(err)
}
stopTaosadapter(cmd)
time.Sleep(time.Second * 5)
time.Sleep(time.Second * 3)
startChan := make(chan struct{})
go func() {
time.Sleep(time.Second * 3)
time.Sleep(time.Second * 10)
err = startTaosadapter(cmd, port)
startChan <- struct{}{}
if err != nil {
t.Error(err)
return
Expand All @@ -228,7 +231,11 @@ func TestSchemalessReconnect(t *testing.T) {
"measurement,host=host1 field1=2i,field2=2.0 1577837500000\n" +
"measurement,host=host1 field1=2i,field2=2.0 1577837600000"
err = s.Insert(data, InfluxDBLineProtocol, "ms", 0, 0)
if err != nil {
t.Fatal(err)
}
assert.Error(t, err)
<-startChan
time.Sleep(time.Second)
err = s.Insert(data, InfluxDBLineProtocol, "ms", 0, 0)
assert.NoError(t, err)
err = s.Insert(data, InfluxDBLineProtocol, "ms", 0, 0)
assert.NoError(t, err)
}
18 changes: 12 additions & 6 deletions ws/stmt/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"net"
"net/url"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -352,12 +353,17 @@ func (c *Connector) Init() (*Stmt, error) {
if !c.autoReconnect {
return nil, err
}
err = c.reconnect()
if err != nil {
return nil, err
}
respBytes, err = c.sendText(reqID, envelope)
if err != nil {
var opError *net.OpError
if errors.Is(err, client.ClosedError) || errors.As(err, &opError) {
err = c.reconnect()
if err != nil {
return nil, err
}
respBytes, err = c.sendText(reqID, envelope)
if err != nil {
return nil, err
}
} else {
return nil, err
}
}
Expand Down
16 changes: 13 additions & 3 deletions ws/stmt/stmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1074,8 +1074,8 @@ func TestSTMTReconnect(t *testing.T) {
config := NewConfig("ws://127.0.0.1:"+port, 0)
config.SetConnectUser("root")
config.SetConnectPass("taosdata")
config.SetMessageTimeout(10 * time.Second)
config.SetWriteWait(common.DefaultWriteWait)
config.SetMessageTimeout(3 * time.Second)
config.SetWriteWait(3 * time.Second)
config.SetEnableCompression(true)
config.SetErrorHandler(func(connector *Connector, err error) {
t.Log(err)
Expand All @@ -1095,11 +1095,21 @@ func TestSTMTReconnect(t *testing.T) {
assert.NoError(t, err)
stmt.Close()
stopTaosadapter(cmd)
startChan := make(chan struct{})
go func() {
time.Sleep(time.Second * 3)
startTaosadapter(cmd, port)
err = startTaosadapter(cmd, port)
startChan <- struct{}{}
if err != nil {
t.Error(err)
return
}
}()
stmt, err = connector.Init()
assert.Error(t, err)
<-startChan
time.Sleep(time.Second)
stmt, err = connector.Init()
assert.NoError(t, err)
stmt.Close()
}
38 changes: 23 additions & 15 deletions ws/tmq/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"net"
"net/url"
"strconv"
"sync"
Expand Down Expand Up @@ -365,9 +366,6 @@ const (
var ClosedErr = errors.New("connection closed")

func (c *Consumer) sendText(reqID uint64, envelope *client.Envelope) ([]byte, error) {
if !c.client.IsRunning() {
return nil, ClosedErr
}
channel := &IndexedChan{
index: reqID,
channel: make(chan []byte, 1),
Expand Down Expand Up @@ -449,12 +447,17 @@ func (c *Consumer) doSubscribe(topics []string, reconnect bool) error {
if !reconnect {
return err
}
err = c.reconnect()
if err != nil {
return err
}
respBytes, err = c.sendText(reqID, envelope)
if err != nil {
var opError *net.OpError
if errors.Is(err, ClosedErr) || errors.Is(err, client.ClosedError) || errors.As(err, &opError) {
err = c.reconnect()
if err != nil {
return err
}
respBytes, err = c.sendText(reqID, envelope)
if err != nil {
return err
}
} else {
return err
}
}
Expand Down Expand Up @@ -510,12 +513,17 @@ func (c *Consumer) Poll(timeoutMs int) tmq.Event {
if !c.autoReconnect {
return tmq.NewTMQErrorWithErr(err)
}
err = c.reconnect()
if err != nil {
return tmq.NewTMQErrorWithErr(err)
}
respBytes, err = c.sendText(reqID, envelope)
if err != nil {
var opError *net.OpError
if errors.Is(err, ClosedErr) || errors.Is(err, client.ClosedError) || errors.As(err, &opError) {
err = c.reconnect()
if err != nil {
return tmq.NewTMQErrorWithErr(err)
}
respBytes, err = c.sendText(reqID, envelope)
if err != nil {
return tmq.NewTMQErrorWithErr(err)
}
} else {
return tmq.NewTMQErrorWithErr(err)
}
}
Expand Down
24 changes: 22 additions & 2 deletions ws/tmq/consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -909,7 +909,7 @@ func TestSubscribeReconnect(t *testing.T) {
consumer, err := NewConsumer(&tmq.ConfigMap{
"ws.url": "ws://127.0.0.1:" + port,
"ws.message.channelLen": uint(0),
"ws.message.timeout": time.Second * 10,
"ws.message.timeout": time.Second * 5,
"ws.message.writeWait": common.DefaultWriteWait,
"td.connect.user": "root",
"td.connect.pass": "taosdata",
Expand All @@ -925,11 +925,22 @@ func TestSubscribeReconnect(t *testing.T) {
})
assert.NoError(t, err)
stopTaosadapter(cmd)
time.Sleep(time.Second)
startChan := make(chan struct{})
go func() {
time.Sleep(time.Second * 3)
startTaosadapter(cmd, port)
err = startTaosadapter(cmd, port)
if err != nil {
t.Error(err)
return
}
startChan <- struct{}{}
}()
err = consumer.Subscribe("test_ws_tmq_sub_reconnect_topic", nil)
assert.Error(t, err)
<-startChan
time.Sleep(time.Second)
err = consumer.Subscribe("test_ws_tmq_sub_reconnect_topic", nil)
assert.NoError(t, err)
doRequest("create table test_ws_tmq_sub_reconnect.st(ts timestamp,v int) tags (cn binary(20))")
doRequest("create table test_ws_tmq_sub_reconnect.t1 using test_ws_tmq_sub_reconnect.st tags ('t1')")
Expand All @@ -938,7 +949,14 @@ func TestSubscribeReconnect(t *testing.T) {
go func() {
time.Sleep(time.Second * 3)
startTaosadapter(cmd, port)
startChan <- struct{}{}
}()
time.Sleep(time.Second)
event := consumer.Poll(500)
assert.NotNil(t, event)
_, ok := event.(tmq.Error)
assert.True(t, ok)
<-startChan
haveMessage := false
for i := 0; i < 10; i++ {
event := consumer.Poll(500)
Expand All @@ -951,6 +969,8 @@ func TestSubscribeReconnect(t *testing.T) {
assert.Equal(t, "test_ws_tmq_sub_reconnect", e.DBName())
haveMessage = true
break
default:
t.Log(e)
}
}
assert.True(t, haveMessage)
Expand Down

0 comments on commit dd12077

Please sign in to comment.